In [10]:
!pip install ipywidgets



In [11]:
import torch
from torch import nn
from torch.utils.data import DataLoader
from torch import optim
from torchvision import datasets
from torchvision.transforms import ToTensor, Lambda, Compose, functional, RandomResizedCrop, Resize


In [12]:
# Training data
data_transform = Compose([
                                        RandomResizedCrop(500),
                                        ToTensor(),
                                        Resize(256),
                                        ])

training_data = datasets.Flowers102(
                        root="data",
                        split="train",
                        download=True,
                        transform=data_transform,
                        )

In [13]:
test_data = datasets.Flowers102(
                        root="data",
                        split="test",
                        download=True,
                        transform=data_transform,
                        )


In [14]:
training_data[1][0].size()

torch.Size([3, 256, 256])

In [12]:
batch_sz = 32
train_dataloader = DataLoader(training_data, batch_size=batch_sz)
test_dataloader = DataLoader(test_data, batch_size=batch_sz)


In [13]:
for X, y in test_dataloader:
    print(f"Shape of X [N, C, H, W]: {X.shape}")
    print(f"Shape of y: {y.shape} {y.dtype}")
    break
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using {device} device")


Shape of X [N, C, H, W]: torch.Size([32, 3, 256, 256])
Shape of y: torch.Size([32]) torch.int64
Using cuda device


In [14]:
class FlowerNet(nn.Module):
    def __init__(self):
        super(FlowerNet, self).__init__()
        self.flatten = nn.Flatten()
        self.conv_stack = nn.Sequential(
            nn.Conv2d(3, 10, 3, stride=2 ),
            nn.ReLU(),
            nn.Conv2d(10, 20, 5, stride=2 ),
            nn.ReLU(),
            nn.MaxPool2d(5,stride=1),
            nn.Conv2d(20, 30, 7, stride=1),
            nn.ReLU(),
            nn.Conv2d(30, 40, 5, stride=1),
            nn.ReLU(),
            nn.MaxPool2d(5,stride=1)
        )
        
        self.linear_stack = nn.Sequential(
            nn.Linear(77440, 512),
            nn.ReLU(),
            nn.Linear(512, 102),
            nn.ReLU()
        )

    def forward(self, x):
        logits = self.conv_stack(x)
        logits = torch.flatten(logits,1,-1)
        # print(f"flatted output{logits.size}")
        logits = self.linear_stack(logits)

        return logits

model = FlowerNet().to(device)
print(model)

FlowerNet(
  (flatten): Flatten(start_dim=1, end_dim=-1)
  (conv_stack): Sequential(
    (0): Conv2d(3, 10, kernel_size=(3, 3), stride=(2, 2))
    (1): ReLU()
    (2): Conv2d(10, 20, kernel_size=(5, 5), stride=(2, 2))
    (3): ReLU()
    (4): MaxPool2d(kernel_size=5, stride=1, padding=0, dilation=1, ceil_mode=False)
    (5): Conv2d(20, 30, kernel_size=(7, 7), stride=(1, 1))
    (6): ReLU()
    (7): Conv2d(30, 40, kernel_size=(5, 5), stride=(1, 1))
    (8): ReLU()
    (9): MaxPool2d(kernel_size=5, stride=1, padding=0, dilation=1, ceil_mode=False)
  )
  (linear_stack): Sequential(
    (0): Linear(in_features=77440, out_features=512, bias=True)
    (1): ReLU()
    (2): Linear(in_features=512, out_features=102, bias=True)
    (3): ReLU()
  )
)


In [19]:
# Define optimizer
learning_rate = 0.001
optimizer = optim.SGD(model.parameters(), lr=learning_rate, momentum=0.9)
loss_criterion = nn.CrossEntropyLoss()
num_epochs = 500
# in your training loop:
for i in range(num_epochs):
  print(f"Running Epoch number: {i}")
  batch = 0;
  running_loss = 0
  for x,y in train_dataloader:
    if(y.shape[0] != batch_sz):
      continue
    optimizer.zero_grad()   # zero the gradient buffers
    # print(x.shape, x.dtype)
    x = x.to(device)
    y = y.to(device)
    output = model(x)
    # print(output.shape)
    # print(y.shape)
    loss = loss_criterion(output, y)
    running_loss += loss.item()
    loss.backward()
    optimizer.step()    # Does the updatefrom torch.utils.data import DataLoader
    batch+=1
  print(f"Loss value is: {running_loss/batch}")


Running Epoch number: 0
Loss value is: 4.36894472952812
Running Epoch number: 1
Loss value is: 4.346402914293351
Running Epoch number: 2
Loss value is: 4.3473436909337195
Running Epoch number: 3
Loss value is: 4.340477320455736
Running Epoch number: 4
Loss value is: 4.325738122386317
Running Epoch number: 5
Loss value is: 4.327804188574514
Running Epoch number: 6
Loss value is: 4.309423131327475
Running Epoch number: 7
Loss value is: 4.323642238493888
Running Epoch number: 8
Loss value is: 4.301605839883128
Running Epoch number: 9
Loss value is: 4.3150522631983605
Running Epoch number: 10
Loss value is: 4.32473994839576
Running Epoch number: 11
Loss value is: 4.291014771307668
Running Epoch number: 12
Loss value is: 4.297827959060669
Running Epoch number: 13
Loss value is: 4.280280843857796
Running Epoch number: 14
Loss value is: 4.2752006207743
Running Epoch number: 15
Loss value is: 4.247879658975909
Running Epoch number: 16
Loss value is: 4.2641380063949095
Running Epoch number: 17


In [20]:
PATH = './flowernet.pth'
torch.save(model.state_dict(), PATH)

# New Section