In [1]:
!pip install ipywidgets

Looking in indexes: https://pypi.org/simple, https://pypi.ngc.nvidia.com


In [10]:
import torch
from torch import nn
from torch.utils.data import DataLoader
from torchvision import datasets
from torchvision.transforms import (
    CenterCrop,
    Compose,
    Lambda,
    RandomResizedCrop,
    Resize,
    ToTensor,
    functional,
)


In [11]:
# Training data
# Pipeline: take a random 500x500 crop of the image, convert it to a tensor,
# then shrink it to 256x256 (a sample comes out as [3, 256, 256]).
data_transform = Compose([RandomResizedCrop(500), ToTensor(), Resize(256)])

# Flowers102 "train" split; files are fetched into ./data on first use.
training_data = datasets.Flowers102(
    root="data", split="train", download=True, transform=data_transform
)


In [12]:
# Evaluation data gets its own deterministic pipeline: resize the shorter side
# to 256, centre-crop to 256x256, convert to a tensor. Reusing the random-crop
# training transform here would hand the model a different crop on every
# access, making test metrics noisy and non-reproducible.
test_transform = Compose([
                        Resize(256),
                        CenterCrop(256),
                        ToTensor(),
                        ])

# Flowers102 "test" split; files are fetched into ./data on first use.
test_data = datasets.Flowers102(
                        root="data",
                        split="test",
                        download=True,
                        transform=test_transform,
                        )


In [13]:
# Sanity check: shape of the image tensor of the second training sample.
training_data[1][0].shape

torch.Size([3, 256, 256])

In [14]:
batch_sz = 32
# Shuffle the training set so each epoch sees differently composed batches
# instead of replaying the dataset's stored order every time.
train_dataloader = DataLoader(training_data, batch_size=batch_sz, shuffle=True)
# Evaluation order does not affect metrics, so the test loader stays unshuffled.
test_dataloader = DataLoader(test_data, batch_size=batch_sz)


In [15]:
# Peek at a single test batch to confirm tensor shapes and the label dtype.
first_batch = next(iter(test_dataloader), None)
if first_batch is not None:
    X, y = first_batch
    print(f"Shape of X [N, C, H, W]: {X.shape}")
    print(f"Shape of y: {y.shape} {y.dtype}")

# Prefer the GPU when one is visible to PyTorch, otherwise stay on the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
print(f"Using {device} device")


Shape of X [N, C, H, W]: torch.Size([32, 3, 256, 256])
Shape of y: torch.Size([32]) torch.int64
Using cuda device


In [None]:
class FlowerNet(nn.Module):
    """Small CNN classifier for RGB images.

    Two conv -> ReLU -> max-pool stages feed a two-layer fully connected head
    that returns one row of unnormalised class scores (logits) per image.

    Args:
        image_size: Height/width of the square input images. Only used to size
            the first linear layer. Defaults to 256, the size produced by the
            notebook's preprocessing.
        num_classes: Number of logits produced per image. Defaults to 102.
    """

    def __init__(self, image_size=256, num_classes=102):
        super().__init__()
        self.flatten = nn.Flatten()
        self.conv_stack = nn.Sequential(
            nn.Conv2d(3, 20, 5),
            nn.ReLU(),
            nn.MaxPool2d(5),
            nn.Conv2d(20, 50, 5),
            nn.ReLU(),
            nn.MaxPool2d(5),
        )

        # Measure the flattened feature count with a dummy pass rather than
        # hard-coding it, so the linear layer always matches what the conv
        # stack really emits (50 * 9 * 9 = 4050 for 256x256 inputs).
        with torch.no_grad():
            probe = torch.zeros(1, 3, image_size, image_size)
            n_features = self.flatten(self.conv_stack(probe)).shape[1]

        self.linear_stack = nn.Sequential(
            nn.Linear(n_features, 512),
            nn.ReLU(),
            nn.Linear(512, num_classes),
        )

    def forward(self, x):
        """Return logits of shape (N, num_classes) for a batch x of shape (N, 3, H, W)."""
        features = self.conv_stack(x)
        # nn.Flatten keeps dim 0, so every image retains its own feature row;
        # a bare torch.flatten() would merge the whole batch into one vector.
        features = self.flatten(features)
        return self.linear_stack(features)

# Build the network, move its parameters to the selected device, and show
# the layer summary.
model = FlowerNet()
model = model.to(device)
print(model)