In [1]:
import torch
from torch import nn, optim
from torch.utils.data import DataLoader, Dataset, random_split
from torch.nn import functional as F

from torchvision.datasets import CIFAR10, CelebA, ImageNet
from torchvision import transforms as T

import matplotlib.pyplot as plt

from models import *
from utils import *
from train import train
from groups import *

##### Data

In [2]:
# Samples per mini-batch; shared by every DataLoader built below.
batch_size: int = 2

In [3]:
# traindata = CelebA('./data', split='train', transform = T.Compose([T.ToTensor(), T.Normalize(.5, .5), T.Resize((218, 218), antialias=True)]), download=True)
# testdata = CelebA('./data', split='test', transform = T.Compose([T.ToTensor(), T.Normalize(.5, .5), T.Resize((218, 218), antialias=True)]))#, download=True)
valdata = CelebA('./data', split='valid', transform = T.Compose([T.ToTensor(), T.Normalize(.5, .5), T.Resize((218, 218), antialias=True)]))#, download=True)

In [4]:
# trainloader = DataLoader(traindata, batch_size=batch_size, shuffle=True, drop_last=True)
# testloader = DataLoader(testdata, batch_size=batch_size, shuffle=True, drop_last=True)
valloader = DataLoader(valdata, batch_size=batch_size, shuffle=True, drop_last=True)

In [5]:
# Use the GPU when PyTorch can see one; otherwise fall back to the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
device

device(type='cpu')

##### CNN

In [None]:
# Baseline (non-group) CNN. `CNN` comes from the project's `models` module
# (star-imported above) and its positional parameters are not visible here,
# so the per-argument notes below are assumptions to confirm.
cnnmodel = CNN(
    [3, 16, 32, 64, 128, 256],  # presumably channel widths: 3 input channels -> 5 conv layers
    [9, 9, 9, 9, 5],  # presumably one kernel size per conv layer
    [0, 0, 0, 0, 0],  # presumably per-layer padding -- TODO confirm
    (256, 40),  # presumably (flattened features, outputs); 40 presumably = number of CelebA attributes
    0,  # NOTE(review): meaning not visible here (dropout?) -- confirm
    True,  # NOTE(review): likely sigmoid output, cf. sigmoid_out=True in the GroupCNN cells -- confirm
    device
)
# Parameter count, used to compare model sizes across the three architectures.
print(cnnmodel.n_params)

In [None]:
# Train the baseline CNN with the project's `train` helper, which returns
# per-run training/regularisation/validation curves.
# Requires `trainloader` and `valloader` to be defined by the data-loading cells.
train_loss, reg_loss, train_acc, val_loss, val_acc = train(
    cnnmodel,
    optim.Adam,  # optimizer class, not an instance; `train` presumably instantiates it
    nn.BCELoss(),  # binary cross-entropy on the model's (presumably sigmoid) outputs
    1,  # presumably number of epochs -- TODO confirm
    1e-3,  # presumably learning rate -- TODO confirm
    trainloader, 
    valloader,
    reg_str = 1e-1,  # regularisation strength
    reg_ord = 2,  # presumably the norm order of the weight regulariser (2 = L2) -- confirm
    celeba = True
)

In [None]:
# Plot the curves returned by `train` for the baseline CNN. The return value
# is suppressed with a trailing `;` instead of `_, _ = ...`: assigning to `_`
# in IPython shadows the output-history variable and stops IPython updating it.
plot_losses(train_loss, reg_loss, val_loss, train_acc, val_acc);

In [None]:
# Evaluate the trained baseline CNN on the held-out test loader
# (requires `testloader` to be defined by the data-loading cells).
calc_accuracy(cnnmodel, testloader)

##### p4CNN

In [None]:
# p4 group CNN (group C4, 4 elements). Channel widths are divided by
# sqrt(4) = 2, presumably so the parameter count stays comparable to the
# baseline CNN (the p4m cell uses ~sqrt(8) for its 8-element group).
scale = 2
p4model = GroupCNN(
    group=C4,
    channels=[3, int(16//scale), int(32//scale), int(64//scale), int(128//scale), int(256//scale)], 
    kernel_sizes=[9, 9, 9, 9, 5],
    paddings=0,
    # The final pooling kernel's first dimension is 4, presumably spanning the
    # 4 group elements so only the channel dimension remains -- confirm.
    pooling_kernels=[(1,2,2), (1,2,2), (1,2,2), (1,2,2), (4,2,2)],
    pooling_strides=[(1,2,2), (1,2,2), (1,2,2), (1,2,2), (1,2,2)],
    pooling_paddings=[(0,0,0), (0,0,0), (0,0,0), (0,0,0), (0,0,0)],
    # Classifier input must equal the last layer's channel count (256//scale = 128),
    # not 256. This matches how the p4m cell sizes its output layer.
    output_dims=(int(256//scale), 40),
    sigmoid_out=True,
    device=device
)
print(p4model.n_params)

In [None]:
# Train the p4 model with the same objective as the baseline CNN and p4m cells:
# BCELoss on the sigmoid outputs (sigmoid_out=True) for the presumably
# multi-label attribute targets. nn.CrossEntropyLoss applies log-softmax
# internally, i.e. it expects raw logits of mutually exclusive classes. That
# does not fit sigmoid outputs, and it made this run incomparable with the others.
train_loss, reg_loss, train_acc, val_loss, val_acc = train(
    p4model,
    optim.Adam,
    nn.BCELoss(),
    1,
    1e-3,
    trainloader, 
    valloader,
    reg_str = 1e-1,
    reg_ord = 2,
    celeba=True
)

In [None]:
# Plot the curves returned by `train` for the p4 model. The return value is
# suppressed with a trailing `;` instead of `_, _ = ...`: assigning to `_`
# in IPython shadows the output-history variable and stops IPython updating it.
plot_losses(train_loss, reg_loss, val_loss, train_acc, val_acc);

In [None]:
# Evaluate the trained p4 model on the held-out test loader.
# The loader is named `testloader` (as in the CNN cell); `test_loader` was a
# typo that raised NameError.
calc_accuracy(p4model, testloader)

##### p4mCNN

In [8]:
# p4m group CNN (group D4, 8 elements). Channel widths are divided by ~sqrt(8),
# presumably to keep the parameter count comparable to the other models.
scale =  2.8 # ~sqrt(8)
p4mmodel = GroupCNN(
    group=D4,
    channels=[3, int(16//scale), int(32//scale), int(64//scale), int(128//scale), int(256//scale)], 
    kernel_sizes=[9, 9, 9, 9, 5],
    paddings=0,
    # The final pooling kernel's first dimension is 8, presumably spanning the
    # 8 group elements -- confirm.
    pooling_kernels=[(1,2,2), (1,2,2), (1,2,2), (1,2,2), (8,2,2)],
    pooling_strides=[(1,2,2), (1,2,2), (1,2,2), (1,2,2), (1,2,2)],
    pooling_paddings=[(0,0,0), (0,0,0), (0,0,0), (0,0,0), (0,0,0)],
    # Classifier input equals the last layer's channel count.
    output_dims=(int(256//scale), 40),
    sigmoid_out=True,
    # Place the model on the same device as the p4 and CNN models.
    # NOTE(review): this was commented out; restore it unless GroupCNN+D4 had a device issue.
    device=device
)
print(p4mmodel.n_params)

1658045


In [None]:
# Train the p4m model with the same settings as the other models
# (Adam, BCELoss on sigmoid outputs, same regularisation).
# Requires `trainloader` and `valloader` to be defined by the data-loading cells.
train_loss, reg_loss, train_acc, val_loss, val_acc = train(
    p4mmodel,
    optim.Adam,  # optimizer class, not an instance; `train` presumably instantiates it
    nn.BCELoss(),
    1,  # presumably number of epochs -- TODO confirm
    1e-3,  # presumably learning rate -- TODO confirm
    trainloader, 
    valloader,
    reg_str = 1e-1,  # regularisation strength
    reg_ord = 2,  # presumably the norm order of the weight regulariser (2 = L2) -- confirm
    celeba=True,
)

In [None]:
# Plot the curves returned by `train` for the p4m model. The return value is
# suppressed with a trailing `;` instead of `_, _ = ...`: assigning to `_`
# in IPython shadows the output-history variable and stops IPython updating it.
plot_losses(train_loss, reg_loss, val_loss, train_acc, val_acc);

In [None]:
# Evaluate the trained p4m model on the held-out test loader.
# The loader is named `testloader` (as in the CNN cell); `test_loader` was a
# typo that raised NameError.
calc_accuracy(p4mmodel, testloader)