# Basic net

## 4-conv, 3-dense net on CIFAR-10

In [None]:
import torch.nn as nn

class Model(nn.Module):
    """Small CNN: four conv blocks followed by a three-layer dense head.

    Each conv block is Conv2d(3x3, stride 1, padding 1) -> ReLU -> BatchNorm2d
    -> MaxPool2d(2, 2), with channels growing 8 -> 16 -> 32 -> 64. The dense
    head is Linear(40) -> ReLU -> BatchNorm1d, Linear(10) -> ReLU ->
    BatchNorm1d, then Linear(n_classes) -> BatchNorm1d. The output is the raw
    score tensor of shape (batch, n_classes); no softmax is applied.

    Args:
        in_resolution: number of pixels per channel of the input (H * W).
            The four 2x2 max-pools shrink the spatial area by 4**4 = 256, so
            H and W should each be divisible by 16 for the flattened size to
            match ``dense1``'s input size.
        in_channels: number of channels of the input images.
        n_classes: number of output scores.

    Note:
        In training mode the BatchNorm1d layers need a batch size > 1.
    """

    def __init__(self, in_resolution=32*32, in_channels=3, n_classes=10):
        super().__init__()

        self.conv1 = nn.Conv2d(in_channels=in_channels, out_channels=8, kernel_size=3, stride=1, padding=1)
        self.conv1_bn = nn.BatchNorm2d(num_features=8)
        self.conv2 = nn.Conv2d(in_channels=8, out_channels=16, kernel_size=3, stride=1, padding=1)
        self.conv2_bn = nn.BatchNorm2d(num_features=16)
        self.conv3 = nn.Conv2d(in_channels=16, out_channels=32, kernel_size=3, stride=1, padding=1)
        self.conv3_bn = nn.BatchNorm2d(num_features=32)
        self.conv4 = nn.Conv2d(in_channels=32, out_channels=64, kernel_size=3, stride=1, padding=1)
        self.conv4_bn = nn.BatchNorm2d(num_features=64)

        # 64 channels after conv4; the four 2x2 pools divide H*W by 4*4*4*4.
        self.dense1 = nn.Linear(in_features=64*in_resolution//(4*4*4*4), out_features=40)
        self.dense1_bn = nn.BatchNorm1d(num_features=40)
        self.dense2 = nn.Linear(in_features=40, out_features=10)
        self.dense2_bn = nn.BatchNorm1d(num_features=10)
        self.dense3 = nn.Linear(in_features=10, out_features=n_classes)
        # NOTE(review): batch norm on the final scores is unusual; kept to
        # preserve the existing behavior.
        self.dense3_bn = nn.BatchNorm1d(num_features=n_classes)

    @staticmethod
    def _conv_block(conv, bn, x):
        """Apply conv -> ReLU -> batch norm -> 2x2 max-pool (halves H and W)."""
        x = bn(nn.functional.relu(conv(x)))
        return nn.functional.max_pool2d(x, kernel_size=2, stride=2)

    def forward(self, x):
        """Map an image batch (batch, in_channels, H, W) to (batch, n_classes)."""
        # Stateless functional ops replace per-call nn.ReLU()/nn.MaxPool2d()
        # module construction; there are no parameters, so state_dict is unchanged.
        x = self._conv_block(self.conv1, self.conv1_bn, x)
        x = self._conv_block(self.conv2, self.conv2_bn, x)
        x = self._conv_block(self.conv3, self.conv3_bn, x)
        x = self._conv_block(self.conv4, self.conv4_bn, x)
        x = x.flatten(start_dim=1)  # same as nn.Flatten(): keep the batch dim
        x = self.dense1_bn(nn.functional.relu(self.dense1(x)))
        x = self.dense2_bn(nn.functional.relu(self.dense2(x)))
        x = self.dense3_bn(self.dense3(x))
        return x

In [None]:
from torchvision import datasets, transforms
from torch.utils.data import DataLoader

# Per-channel (RGB) normalization statistics and the CIFAR-10 class names,
# in label-index order.
mean, std = (0.4914, 0.4822, 0.4465), (0.2470, 0.2435, 0.2616)
labels = 'airplane automobile bird cat deer dog frog horse ship truck'.split()

# Shared preprocessing for both splits: PIL image -> tensor, then normalize.
cifar_transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize(mean, std)])

dataset_train = datasets.CIFAR10('data/CIFAR-10', train=True, download=True, transform=cifar_transform)
dataset_val = datasets.CIFAR10('data/CIFAR-10', train=False, download=True, transform=cifar_transform)

dataloader_train = DataLoader(dataset_train, batch_size=64, shuffle=True)
# Validate on the held-out split (previously this wrapped dataset_train by
# mistake); evaluation does not need shuffling.
dataloader_val = DataLoader(dataset_val, batch_size=64, shuffle=False)
dataloaders = {'train': dataloader_train, 'val': dataloader_val}

In [None]:
from torch.optim import Adam, lr_scheduler

# CIFAR-10 run: build the CNN, an Adam optimizer with light L2 weight decay,
# and a step learning-rate schedule, then train and plot the history.
net = Model()
optimizer = Adam(params=net.parameters(), lr=1e-2, weight_decay=1e-5)
# gamma keeps StepLR's default (0.1): the lr is divided by 10 every 5 scheduler
# steps. NOTE(review): how often the scheduler steps is up to Trainer — confirm.
scheduler = lr_scheduler.StepLR(optimizer=optimizer, step_size=5)

trainer = Trainer(net, optimizer, scheduler)
trainer.train(dataloaders, epochs=30, early_stopping=5)
trainer.plot_training()

## Linear net on fake data

In [None]:
from torchvision.datasets import FakeData
from torch.utils.data import DataLoader
from torchvision import transforms
from torch.optim import Adam
import torch.nn as nn
from torch.optim import lr_scheduler

# Smoke test: 1000 random 3x224x224 images over 10 classes, fitted by a single
# linear layer. The same loader deliberately serves as both train and val.
dataset = FakeData(
    size=1000,
    image_size=(3, 224, 224),
    num_classes=10,
    transform=transforms.ToTensor(),
)
dataloader = DataLoader(dataset, batch_size=64)
data = dict(train=dataloader, val=dataloader)

net = nn.Sequential(
    nn.Flatten(),
    nn.Linear(in_features=3 * 224 * 224, out_features=10),
)
optimizer = Adam(params=net.parameters(), lr=1e-3)
scheduler = lr_scheduler.StepLR(optimizer=optimizer, step_size=5)

trainer = Trainer(net, optimizer, scheduler)
trainer.train(data, epochs=10, early_stopping=5)
trainer.plot_training()