In [3]:
import torch
import torchvision
import torchvision.transforms as transforms

In [5]:
"""
    A convolutional classifier will need:
    Input
        torch.Size([3, 32, 32])
        torch.float32
    Output
        tensor(5) # class no.6
        torch.int64
    """
transform = transforms.Compose([transforms.ToTensor()])
trainset = torchvision.datasets.CIFAR10(
    root="../data", transform=transform, train=True
)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=8, shuffle=True)
valset = torchvision.datasets.CIFAR10(
    root="../data", train=False, transform=transform
)
valloader = torch.utils.data.DataLoader(valset, batch_size=8, shuffle=True)

In [66]:
## NETWORK
import torch.nn as nn
import torch.nn.functional as F

class MyConv(nn.Module):

    def __init__(self,in_channels=3, out_channels=16, kernel_size=2, padding=2) -> None:
        super().__init__()

        self.net = nn.Sequential(
            nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size, padding=padding),
            nn.ReLU()
        )
    
    def forward(self, x: torch.Tensor):
        return self.net(x)

class VanillaNet(nn.Module):
    def __init__(self, n_classes) -> None:
        super().__init__()
        #
        self.conv1 = MyConv(in_channels=3, out_channels=16, kernel_size=2, padding=0)
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
        self.conv2 = MyConv(16, 32, 2, padding=0)
        self.fc1 = nn.Linear(in_features=(32 * 7 * 7), out_features=256) # Calculated manually 
        self.fc2 = nn.Linear(in_features=256, out_features=32) 
        self.fc3 = nn.Linear(in_features=32, out_features=n_classes) 
        
    def forward(self, x: torch.Tensor):
        x = self.conv1(x)
        print(x.shape)
        x = self.pool(x)
        x = self.pool(self.conv2(x))
        x = torch.flatten(x, start_dim=1)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = F.softmax(self.fc3(x), dim=1)
        return x

classes = ('plane', 'car', 'bird', 'cat',
           'deer', 'dog', 'frog', 'horse', 'ship', 'truck')
model = VanillaNet(n_classes=len(classes))

In [67]:
def get_class_idx(out: torch.tensor):
    return torch.argmax(out).item()
for x, y in trainloader:
    out = model(x)
    print(classes[get_class_idx(out[0])])
    break

torch.Size([8, 16, 31, 31])
deer
