In [1]:
import torch
import torchvision
import torch.nn as nn
import time
import json
import datetime


# Data loading and augmentation

def load_cifar10_dataloaders():
    transform = torchvision.transforms.Compose([
        torchvision.transforms.ToTensor(),
        torchvision.transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
    ])
    dataset_train = torchvision.datasets.CIFAR10(".data", download=True, transform=transform)
    dataloader_train = torch.utils.data.DataLoader(dataset_train, batch_size=128)
    dataset_test = torchvision.datasets.CIFAR10(".data", download=True, train=False, transform=transform)
    dataloader_test = torch.utils.data.DataLoader(dataset_test, batch_size=128)
    return dataloader_train, dataloader_test

def load_cifar10_dataloaders_validation():
    transform = torchvision.transforms.Compose([
        torchvision.transforms.ToTensor(),
        torchvision.transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
    ])
    dataset = torchvision.datasets.CIFAR10(".data", download=True, transform=transform)
    size_train = 0.9*len(dataset)
    size_val = len(dataset) - size_train
    dataset_train, dataset_val = torch.utils.data.random_split(dataset, [int(size_train), int(size_val)])
    dataloader_train = torch.utils.data.DataLoader(dataset_train, batch_size=128)
    dataloader_val = torch.utils.data.DataLoader(dataset_val, batch_size=128)
    dataset_test = torchvision.datasets.CIFAR10(".data", download=True, train=False, transform=transform)
    dataloader_test = torch.utils.data.DataLoader(dataset_test, batch_size=128)
    return dataloader_train, dataloader_test, dataloader_val

In [2]:
# File based on https://github.com/huyvnphan/PyTorch_CIFAR10/
class VGG(nn.Module):
    def __init__(self, features, num_classes=10, avgpool_size=(1,1)):
        super(VGG, self).__init__()
        self.features = features

        self.avgpool = nn.AdaptiveAvgPool2d(avgpool_size)

        self.classifier = nn.Sequential(
            nn.Linear(512 * avgpool_size[0] * avgpool_size[1], 4096),
            nn.ReLU(True),
            nn.Dropout(),
            nn.Linear(4096, 4096),
            nn.ReLU(True),
            nn.Dropout(),
            nn.Linear(4096, num_classes),
        )

    def forward(self, x):
        x = self.features(x)
        x = self.avgpool(x)
        x = torch.flatten(x, 1)
        x = self.classifier(x)
        return x

 # Create convolution part of VGG11_bn archhitecture
def make_vgg11_bn_layers(cfg = None):
    if cfg == None:
        cfg = [64, "M", 128, "M", 256, 256, "M", 512, 512, "M", 512, 512, "M"]
    layers = []
    in_channels = 3
    for v in cfg:
        if v == "M":
            layers += [nn.MaxPool2d(kernel_size=2, stride=2)]
        else:
            conv2d = nn.Conv2d(in_channels, v, kernel_size=3, padding=1)
            layers += [conv2d, nn.BatchNorm2d(v), nn.ReLU(inplace=True)]
            in_channels = v
    return nn.Sequential(*layers)

 # Create VGG11_bn model
def vgg11_bn(device="cpu", num_classes=10):
    model = VGG(make_vgg11_bn_layers(), num_classes=num_classes)
    return model


def eval_accuracy(model, dataloader, training_device='cpu'):
    with torch.no_grad():
        model.to(training_device)
        correct = 0
        all_so_far = 0
        for inputs, labels in dataloader:
            inputs, labels = inputs.to(training_device), labels.to(training_device)
            pred = torch.argmax(model(inputs), dim=1)

            all_so_far += labels.size().numel()
            correct += torch.sum(pred.eq(labels)).item()
    return correct/all_so_far

In [3]:
def backup_to_ram(model):
    from copy import deepcopy
    return deepcopy(model).cpu()

class EarlyStopper:
    def __init__(self, patience = 3, backup_method=backup_to_ram):
        self.patience = patience
        self.current = 0
        
        self.backup_method = backup_method
        
        self.best_backup = None
        self.best_accuracy = 0.

    def should_continue(self, accuracy, model = None):
        if self.best_accuracy < accuracy:
            self.current = 0
            self.best_accuracy = accuracy
            if model is not None:
                self.best_backup = self.backup_method(model)
            return True
        
        self.current += 1
        
        if self.current >= self.patience:
            return False
        return True

In [4]:
def train_one_epoch(model, optimizer, criterion, dataloader_train, training_device):
    for inputs, labels in dataloader_train:
        inputs, labels = inputs.to(training_device), labels.to(training_device)
        optimizer.zero_grad()
        loss = criterion(model(inputs), labels)
        loss.backward()
        optimizer.step()

def train_one_run(model, optimizer, criterion, dataloader_train, dataloader_val, max_epochs, early_stopper, trajectory, training_device='cuda', *_args, **_kwargs):
    model.train()
    model.to(training_device)

    for epoch in range(max_epochs):
        start_time = time.time()
        
        train_one_epoch(model, optimizer, criterion, dataloader_train, training_device)
        
        end_time = time.time()
        
        training_accuracy = eval_accuracy(model, dataloader_train, training_device)
        validation_accuracy = eval_accuracy(model, dataloader_val, training_device)
        print("Epoch: {}, Accuracy on validation set: {}".format(epoch, validation_accuracy))
        
        trajectory.append({
            "epoch": epoch,
            "train": training_accuracy,
            "validation": validation_accuracy,
            "start_time": start_time,
            "duration": end_time - start_time,
            "memory_allocated_mb": torch.cuda.memory_allocated()/1024/1024,
            "memory_reserved_mb": torch.cuda.memory_reserved()/1024/1024,
        })
        
        
        if not early_stopper.should_continue(validation_accuracy, model):
            print("Early stop")
            return early_stopper.best_backup
    
    return model

In [5]:
def run_arch_experiment(arch_name, model_factory, run):
    path = "experiments/arch/" + arch_name + "/" + run + "/"
    
    import os
    try:
        os.makedirs(path)
    except FileExistsError:
        print("Directory exists, skipping...")
        return
    
    train, test, val = load_cifar10_dataloaders_validation()
    model = model_factory()
    criterion = torch.nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
    early_stopper = EarlyStopper(patience = 5)
    trajectory = []
    model = train_one_run(model, optimizer, criterion, train, val, 200, early_stopper, trajectory)
    validation_accuracy = eval_accuracy(model, val, "cuda")
    
    with open(path + "report.json", "w") as f:
        json.dump(
            {
                "name": arch_name,
                "run": run,
                "best_accuracy_validation": validation_accuracy,
                "time_generated": datetime.datetime.now().isoformat(),
                "trajectory": trajectory
            },
            f
        )
    torch.save(model, path + "model.pt")

In [6]:
def small_vgg11():
    return VGG(make_vgg11_bn_layers(), num_classes=10, avgpool_size=(1,1)).to("cuda")

def medium_vgg11():
    return VGG(make_vgg11_bn_layers(), num_classes=10, avgpool_size=(2,2)).to("cuda")

def full_vgg11():
    return VGG(make_vgg11_bn_layers(), num_classes=10, avgpool_size=(7,7)).to("cuda")

In [7]:
experiment_list = [
    (factory.__name__, factory, str(run))
    for factory in [small_vgg11, medium_vgg11, full_vgg11]
    for run in range(1, 11)
]

In [8]:
for experiment in experiment_list:
    print(
        "Time:", datetime.datetime.now().isoformat(), 
        "Arch: ", experiment[0], 
        "Run: ", experiment[2]
    )
    run_arch_experiment(*experiment)

Time: 2023-03-13T22:41:04.364512 Arch:  small_vgg11 Run:  1
Directory exists, skipping...
Time: 2023-03-13T22:41:04.365353 Arch:  small_vgg11 Run:  2
Directory exists, skipping...
Time: 2023-03-13T22:41:04.365409 Arch:  small_vgg11 Run:  3
Directory exists, skipping...
Time: 2023-03-13T22:41:04.365440 Arch:  small_vgg11 Run:  4
Directory exists, skipping...
Time: 2023-03-13T22:41:04.365491 Arch:  small_vgg11 Run:  5
Directory exists, skipping...
Time: 2023-03-13T22:41:04.365522 Arch:  small_vgg11 Run:  6
Directory exists, skipping...
Time: 2023-03-13T22:41:04.365552 Arch:  small_vgg11 Run:  7
Directory exists, skipping...
Time: 2023-03-13T22:41:04.365720 Arch:  small_vgg11 Run:  8
Directory exists, skipping...
Time: 2023-03-13T22:41:04.365759 Arch:  small_vgg11 Run:  9
Directory exists, skipping...
Time: 2023-03-13T22:41:04.365789 Arch:  small_vgg11 Run:  10
Directory exists, skipping...
Time: 2023-03-13T22:41:04.365837 Arch:  medium_vgg11 Run:  1
Files already downloaded and verified


Epoch: 3, Accuracy on validation set: 0.5394
Epoch: 4, Accuracy on validation set: 0.622
Epoch: 5, Accuracy on validation set: 0.7082
Epoch: 6, Accuracy on validation set: 0.7272
Epoch: 7, Accuracy on validation set: 0.7522
Epoch: 8, Accuracy on validation set: 0.7562
Epoch: 9, Accuracy on validation set: 0.7766
Epoch: 10, Accuracy on validation set: 0.7708
Epoch: 11, Accuracy on validation set: 0.7848
Epoch: 12, Accuracy on validation set: 0.7934
Epoch: 13, Accuracy on validation set: 0.7784
Epoch: 14, Accuracy on validation set: 0.7892
Epoch: 15, Accuracy on validation set: 0.7856
Epoch: 16, Accuracy on validation set: 0.794
Epoch: 17, Accuracy on validation set: 0.8022
Epoch: 18, Accuracy on validation set: 0.8034
Epoch: 19, Accuracy on validation set: 0.8018
Epoch: 20, Accuracy on validation set: 0.81
Epoch: 21, Accuracy on validation set: 0.7996
Epoch: 22, Accuracy on validation set: 0.8104
Epoch: 23, Accuracy on validation set: 0.8056
Epoch: 24, Accuracy on validation set: 0.811


Epoch: 3, Accuracy on validation set: 0.1786
Epoch: 4, Accuracy on validation set: 0.2088
Epoch: 5, Accuracy on validation set: 0.2092
Epoch: 6, Accuracy on validation set: 0.223
Epoch: 7, Accuracy on validation set: 0.2208
Epoch: 8, Accuracy on validation set: 0.241
Epoch: 9, Accuracy on validation set: 0.235
Epoch: 10, Accuracy on validation set: 0.2376
Epoch: 11, Accuracy on validation set: 0.2462
Epoch: 12, Accuracy on validation set: 0.2376
Epoch: 13, Accuracy on validation set: 0.2568
Epoch: 14, Accuracy on validation set: 0.2594
Epoch: 15, Accuracy on validation set: 0.2706
Epoch: 16, Accuracy on validation set: 0.279
Epoch: 17, Accuracy on validation set: 0.3022
Epoch: 18, Accuracy on validation set: 0.3548
Epoch: 19, Accuracy on validation set: 0.3734
Epoch: 20, Accuracy on validation set: 0.4148
Epoch: 21, Accuracy on validation set: 0.5806
Epoch: 22, Accuracy on validation set: 0.6588
Epoch: 23, Accuracy on validation set: 0.7008
Epoch: 24, Accuracy on validation set: 0.7158

Epoch: 34, Accuracy on validation set: 0.7668
Epoch: 35, Accuracy on validation set: 0.7624
Epoch: 36, Accuracy on validation set: 0.7562
Epoch: 37, Accuracy on validation set: 0.7642
Early stop
Time: 2023-03-14T01:48:49.697168 Arch:  full_vgg11 Run:  5
Files already downloaded and verified
Files already downloaded and verified
Epoch: 0, Accuracy on validation set: 0.1624
Epoch: 1, Accuracy on validation set: 0.186
Epoch: 2, Accuracy on validation set: 0.1868
Epoch: 3, Accuracy on validation set: 0.1978
Epoch: 4, Accuracy on validation set: 0.2694
Epoch: 5, Accuracy on validation set: 0.2586
Epoch: 6, Accuracy on validation set: 0.3452
Epoch: 7, Accuracy on validation set: 0.3762
Epoch: 8, Accuracy on validation set: 0.3282
Epoch: 9, Accuracy on validation set: 0.5018
Epoch: 10, Accuracy on validation set: 0.519
Epoch: 11, Accuracy on validation set: 0.6308
Epoch: 12, Accuracy on validation set: 0.6706
Epoch: 13, Accuracy on validation set: 0.6332
Epoch: 14, Accuracy on validation set:

Epoch: 29, Accuracy on validation set: 0.7754
Epoch: 30, Accuracy on validation set: 0.7372
Epoch: 31, Accuracy on validation set: 0.771
Epoch: 32, Accuracy on validation set: 0.7658
Epoch: 33, Accuracy on validation set: 0.7604
Epoch: 34, Accuracy on validation set: 0.7618
Early stop
Time: 2023-03-14T03:11:34.822278 Arch:  full_vgg11 Run:  10
Files already downloaded and verified
Files already downloaded and verified
Epoch: 0, Accuracy on validation set: 0.1416
Epoch: 1, Accuracy on validation set: 0.135
Epoch: 2, Accuracy on validation set: 0.148
Epoch: 3, Accuracy on validation set: 0.1344
Epoch: 4, Accuracy on validation set: 0.1338
Epoch: 5, Accuracy on validation set: 0.1428
Epoch: 6, Accuracy on validation set: 0.1564
Epoch: 7, Accuracy on validation set: 0.1554
Epoch: 8, Accuracy on validation set: 0.1628
Epoch: 9, Accuracy on validation set: 0.155
Epoch: 10, Accuracy on validation set: 0.1902
Epoch: 11, Accuracy on validation set: 0.1662
Epoch: 12, Accuracy on validation set: 