In [2]:
import torch
from torch import nn, optim
from torch.utils.data import DataLoader, Dataset, random_split

from torchvision.datasets import CIFAR10
from torchvision import transforms as T

import matplotlib.pyplot as plt

from models import *
from utils import *
from train import train
from groups import *

##### Data

In [3]:
# Mini-batch size shared by the train, validation and test DataLoaders below.
batch_size = 128

In [4]:
# CIFAR-10 training split; images are converted to tensors and normalised
# with mean 0.5 / std 0.5.
cifar_train_transform = T.Compose([T.ToTensor(), T.Normalize(.5, .5)])
cifar_train_full = CIFAR10('./data/', train=True, transform=cifar_train_transform, download=True)

# Hold out 5000 of the 50000 training images for validation. A dedicated,
# seeded generator keeps the partition identical across kernel restarts so
# that runs of the different models are comparable.
split_seed = 0
train_subset, val_subset = random_split(
    cifar_train_full,
    [45000, 5000],
    generator=torch.Generator().manual_seed(split_seed),
)

train_loader = DataLoader(train_subset, batch_size=batch_size, shuffle=True, drop_last=True)

# Validation is not shuffled: combined with drop_last=True, shuffling would
# discard a different random remainder on every pass and make the validation
# metrics noisy.
# NOTE(review): drop_last=True still excludes the final partial batch
# (5000 % 128 = 8 images); kept because it is not visible here whether
# train() tolerates a short batch -- confirm and switch to False if it does.
val_loader = DataLoader(val_subset, batch_size=batch_size, shuffle=False, drop_last=True)

Files already downloaded and verified


In [5]:
# CIFAR-10 test split with the same ToTensor + Normalize(0.5, 0.5)
# preprocessing that is applied to the training data.
test_set = CIFAR10('./data/', train=False, transform = T.Compose([T.ToTensor(), T.Normalize(.5, .5)]), download=True)

# Evaluation is not shuffled: with drop_last=True a shuffled loader drops a
# different random remainder each pass, so the reported test accuracy would
# change from run to run.
# NOTE(review): drop_last=True still leaves out the final partial batch of
# the test set; kept because calc_accuracy's handling of a short batch is not
# visible here -- confirm and switch to False to score every test image.
test_loader = DataLoader(test_set, batch_size=batch_size, shuffle=False, drop_last=True)

Files already downloaded and verified


##### CNN

In [6]:
# Baseline CNN. The arguments are positional; they presumably correspond to
# channels, kernel size, padding and output dims, like the keyword arguments
# GroupCNN takes elsewhere in this notebook -- TODO confirm against models.CNN.
# NOTE(review): the first channel entry is 1 although CIFAR-10 images have
# three colour channels; verify how the model treats the input channels.
cnnmodel = CNN(
    [1, 8, 64, 128],
    3,
    0,
    (128, 10),
)
print(cnnmodel.n_params)

79898


In [None]:
# Train the baseline CNN with Adam and cross-entropy loss. The positional
# values 1 and 1e-3 are presumably the epoch count and learning rate -- TODO
# confirm against train.train. reg_str / reg_ord are forwarded unchanged.
cnn_history = train(
    cnnmodel, optim.Adam, nn.CrossEntropyLoss(),
    1, 1e-3,
    train_loader, val_loader,
    reg_str=1e-1, reg_ord=1,
)

# train() yields five values; unpack them under the names the plotting cell uses.
train_loss, reg_loss, train_acc, val_loss, val_acc = cnn_history

In [None]:
# Plot the curves returned by the preceding train(...) call; the two values
# that plot_losses returns are unpacked and discarded.
_, _ = plot_losses(train_loss, reg_loss, val_loss, train_acc, val_acc)

In [None]:
# Score the baseline CNN on the test loader. The model defined in this
# notebook is named `cnnmodel`; the previous `model` was never assigned and
# raised NameError on a fresh kernel.
calc_accuracy(cnnmodel, test_loader)

##### p4CNN

In [None]:
# Group CNN built on the C4 group. The first two pooling stages use a
# (1, 2, 2) kernel and the last uses (4, 2, 2); every stage shares the same
# stride and zero padding.
# NOTE(review): the leading 4 in the last kernel presumably pools over the
# four group elements -- confirm against models.GroupCNN.
p4model = GroupCNN(
    group=C4,
    channels=[1, 8, 32, 64],
    kernel_sizes=3,
    paddings=0,
    pooling_kernels=[(1, 2, 2)] * 2 + [(4, 2, 2)],
    pooling_strides=[(1, 2, 2)] * 3,
    pooling_paddings=[(0, 0, 0)] * 3,
    output_dims=(64, 10),
)
print(p4model.n_params)

In [None]:
# Train the p4 model with Adam and cross-entropy loss. The positional values
# 10 and 1e-3 are presumably the epoch count and learning rate -- TODO confirm
# against train.train.
train_loss, reg_loss, train_acc, val_loss, val_acc = train(
    p4model,
    optim.Adam,
    nn.CrossEntropyLoss(),
    10,
    1e-3,  # the comma was missing here, which made the whole cell a SyntaxError
    train_loader,
    val_loader,
    reg_str = 1e-3,
    reg_ord = 1
)

In [None]:
# Plot the curves returned by the preceding train(...) call; the two values
# that plot_losses returns are unpacked and discarded.
_, _ = plot_losses(train_loss, reg_loss, val_loss, train_acc, val_acc)

In [None]:
# Run calc_accuracy for the p4 model on the test loader; as the last
# expression in the cell, its return value is shown as the cell output.
calc_accuracy(p4model, test_loader)

##### p4mCNN

In [None]:
# Group CNN built on the D4 group. The first two pooling stages use a
# (1, 2, 2) kernel and the last uses (8, 2, 2); every stage shares the same
# stride and zero padding.
# NOTE(review): the leading 8 in the last kernel presumably pools over the
# eight group elements -- confirm against models.GroupCNN.
p4mmodel = GroupCNN(
    group=D4,
    channels=[1, 8, 16, 64],
    kernel_sizes=3,
    paddings=0,
    pooling_kernels=[(1, 2, 2)] * 2 + [(8, 2, 2)],
    pooling_strides=[(1, 2, 2)] * 3,
    pooling_paddings=[(0, 0, 0)] * 3,
    output_dims=(64, 10),
)
print(p4mmodel.n_params)

In [None]:
# Train the p4m model with Adam and cross-entropy loss. The positional values
# 10 and 1e-3 are presumably the epoch count and learning rate -- TODO confirm
# against train.train.
train_loss, reg_loss, train_acc, val_loss, val_acc = train(
    p4mmodel,
    optim.Adam,
    nn.CrossEntropyLoss(),
    10,
    1e-3,  # the comma was missing here, which made the whole cell a SyntaxError
    train_loader,
    val_loader,
    reg_str = 1e-3,
    reg_ord = 1
)

In [None]:
# Plot the curves returned by the preceding train(...) call; the two values
# that plot_losses returns are unpacked and discarded.
_, _ = plot_losses(train_loss, reg_loss, val_loss, train_acc, val_acc)

In [None]:
# Score the p4m model on the test loader. The model defined in this notebook
# is named `p4mmodel`; the previous `mp4model` was a typo that raised
# NameError.
calc_accuracy(p4mmodel, test_loader)