In [1]:
import torch
import torchvision
import torch.nn as nn
import time
import json
import datetime

In [2]:
# File based on https://github.com/huyvnphan/PyTorch_CIFAR10/
class VGG(nn.Module):
    def __init__(self, features, num_classes=10, avgpool_size=(1,1)):
        super(VGG, self).__init__()
        self.features = features

        self.avgpool = nn.AdaptiveAvgPool2d(avgpool_size)

        self.classifier = nn.Sequential(
            nn.Linear(512 * avgpool_size[0] * avgpool_size[1], 4096),
            nn.ReLU(True),
            nn.Dropout(),
            nn.Linear(4096, 4096),
            nn.ReLU(True),
            nn.Dropout(),
            nn.Linear(4096, num_classes),
        )

    def forward(self, x):
        x = self.features(x)
        x = self.avgpool(x)
        x = torch.flatten(x, 1)
        x = self.classifier(x)
        return x

 # Create convolution part of VGG11_bn archhitecture
def make_vgg11_bn_layers(cfg = None):
    if cfg == None:
        cfg = [64, "M", 128, "M", 256, 256, "M", 512, 512, "M", 512, 512, "M"]
    layers = []
    in_channels = 3
    for v in cfg:
        if v == "M":
            layers += [nn.MaxPool2d(kernel_size=2, stride=2)]
        else:
            conv2d = nn.Conv2d(in_channels, v, kernel_size=3, padding=1)
            layers += [conv2d, nn.BatchNorm2d(v), nn.ReLU(inplace=True)]
            in_channels = v
    return nn.Sequential(*layers)

 # Create VGG11_bn model
def vgg11_bn(device="cpu", num_classes=10):
    model = VGG(make_vgg11_bn_layers(), num_classes=num_classes)
    return model

vgg_cfg = {
    '8':  [64, 'M', 128, 'M', 256, 'M', 512, 'M', 512, 'M'],
    '11': [64, 'M', 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'],
    '13': [64, 64, 'M', 128, 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'],
    '16': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'M', 512, 512, 512, 'M', 512, 512, 512, 'M'],
    '19': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 256, 'M', 512, 512, 512, 512, 'M', 
          512, 512, 512, 512, 'M'],
}



def eval_accuracy(model, dataloader, training_device='cpu'):
    model.eval()
    with torch.no_grad():
        model.to(training_device)
        correct = 0
        all_so_far = 0
        for inputs, labels in dataloader:
            inputs, labels = inputs.to(training_device), labels.to(training_device)
            pred = torch.argmax(model(inputs), dim=1)

            all_so_far += labels.size().numel()
            correct += torch.sum(pred.eq(labels)).item()
    model.train()
    return correct/all_so_far

In [3]:
def backup_to_ram(model):
    from copy import deepcopy
    return deepcopy(model).cpu()

class EarlyStopper:
    def __init__(self, patience = 3, backup_method=backup_to_ram):
        self.patience = patience
        self.current = 0
        
        self.backup_method = backup_method
        
        self.best_backup = None
        self.best_accuracy = 0.

    def should_continue(self, accuracy, model = None):
        if self.best_accuracy < accuracy:
            self.current = 0
            self.best_accuracy = accuracy
            if model is not None:
                self.best_backup = self.backup_method(model)
            return True
        
        self.current += 1
        
        if self.current >= self.patience:
            return False
        return True

In [4]:
def train_one_epoch(model, optimizer, criterion, dataloader_train, training_device):
    for inputs, labels in dataloader_train:
        inputs, labels = inputs.to(training_device), labels.to(training_device)
        optimizer.zero_grad()
        loss = criterion(model(inputs), labels)
        loss.backward()
        optimizer.step()

def train_one_run(model, optimizer, criterion,
                  dataloader_train, dataloader_val,
                  max_epochs, early_stopper, 
                  trajectory, 
                  training_device='cuda', mem_bg_allocated=0, mem_bg_reserved=0,
                  *_args, **_kwargs):
    model.train()
    model.to(training_device)

    for epoch in range(max_epochs):
        start_time = time.time()
        
        train_one_epoch(model, optimizer, criterion, dataloader_train, training_device)
        
        end_time = time.time()
        
        training_accuracy = eval_accuracy(model, dataloader_train, training_device)
        validation_accuracy = eval_accuracy(model, dataloader_val, training_device)
        print("Epoch: {}, Accuracy on validation set: {}".format(epoch, validation_accuracy))
        
        trajectory.append({
            "epoch": epoch,
            "train": training_accuracy,
            "validation": validation_accuracy,
            "start_time": start_time,
            "duration": end_time - start_time,
            "memory_allocated_mb": (torch.cuda.memory_allocated() - mem_bg_allocated)/1024/1024,
            "memory_reserved_mb": (torch.cuda.memory_reserved() - mem_bg_reserved)/1024/1024,
        })
        
        
        if not early_stopper.should_continue(validation_accuracy, model):
            print("Early stop")
            return early_stopper.best_backup
    
    return model

In [5]:
def run_train_experiment(arch_name, model_factory, aug_name, aug_factory, train_func, train_name, run):
    path = f"experiments/train_{train_name}_aug_{aug_name}_arch_{arch_name}_{run}_"
    
    # 
    import os
    try:
        if os.stat(path + "report.json").st_size != 0:
            print("Report exists already for " + path[:-1] + ". Skipping...")
            return
    except OSError:
        pass
    
    model, trajectory, validation_accuracy = train_name(aug_factory, model_factory)
    
    with open(path + "report.json", "w") as f:
        json.dump(
            {
                "name": arch_name,
                "run": run,
                "augment": aug_name,
                "train": train_name, 
                "best_accuracy_validation": validation_accuracy,
                "time_generated": datetime.datetime.now().isoformat(),
                "trajectory": trajectory
            },
            f
        )
    torch.save(model, path + "model.pt")

In [6]:
def fasterlearn(aug_factory, model_factory):
    torch.cuda.empty_cache()
    mem_bg_allocated = torch.cuda.memory_allocated()
    mem_bg_reserved = torch.cuda.memory_reserved()
    
    train, test, val = aug_factory()
    model = model_factory()
    criterion = torch.nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=0.1)
    early_stopper = EarlyStopper(patience = 5)
    trajectory = []
    model = train_one_run(model, optimizer, criterion,
                          train, val,
                          200, early_stopper,
                          trajectory, 
                          mem_bg_allocated=mem_bg_allocated,
                          mem_bg_reserved=mem_bg_reserved)
    validation_accuracy = eval_accuracy(model, val, "cuda")
    return model, trajectory, validation_accuracy

def morepatient(aug_factory, model_factory):
    torch.cuda.empty_cache()
    mem_bg_allocated = torch.cuda.memory_allocated()
    mem_bg_reserved = torch.cuda.memory_reserved()
    
    train, test, val = aug_factory()
    model = model_factory()
    criterion = torch.nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
    early_stopper = EarlyStopper(patience = 10)
    trajectory = []
    model = train_one_run(model, optimizer, criterion,
                          train, val,
                          200, early_stopper,
                          trajectory, 
                          mem_bg_allocated=mem_bg_allocated,
                          mem_bg_reserved=mem_bg_reserved)
    validation_accuracy = eval_accuracy(model, val, "cuda")    
    return model, trajectory, validation_accuracy

def smallbatch(aug_factory, model_factory):
    torch.cuda.empty_cache()
    mem_bg_allocated = torch.cuda.memory_allocated()
    mem_bg_reserved = torch.cuda.memory_reserved()
    
    train, test, val = aug_factory(bs=16)
    model = model_factory()
    criterion = torch.nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
    early_stopper = EarlyStopper(patience = 5)
    trajectory = []
    model = train_one_run(model, optimizer, criterion,
                          train, val,
                          200, early_stopper,
                          trajectory, 
                          mem_bg_allocated=mem_bg_allocated,
                          mem_bg_reserved=mem_bg_reserved)
    validation_accuracy = eval_accuracy(model, val, "cuda")    
    return model, trajectory, validation_accuracy


def bigbatch(aug_factory, model_factory):
    torch.cuda.empty_cache()
    mem_bg_allocated = torch.cuda.memory_allocated()
    mem_bg_reserved = torch.cuda.memory_reserved()
    
    train, test, val = aug_factory(bs=1024)
    model = model_factory()
    criterion = torch.nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
    early_stopper = EarlyStopper(patience = 5)
    trajectory = []
    model = train_one_run(model, optimizer, criterion,
                          train, val,
                          200, early_stopper,
                          trajectory, 
                          mem_bg_allocated=mem_bg_allocated,
                          mem_bg_reserved=mem_bg_reserved)
    validation_accuracy = eval_accuracy(model, val, "cuda")    
    return model, trajectory, validation_accuracy

In [7]:
def flip(bs=128):
    return augmented_cifar10_dataset_randomflip(bs=bs)

def smallrotate(bs=128):
    return augmented_cifar10_dataset_rotate_randomapply(5, bs=bs)

def fliprotate(bs=128):
    return augmented_cifar10_dataset_randomflip_rotate_randomapply(5, bs=bs)


In [8]:
def vgg_from_cfg(cfg_key):
    return VGG(
        make_vgg11_bn_layers(cfg=vgg_cfg[cfg_key]),
        num_classes=10,
        avgpool_size=(1,1)
    ).to("cuda")

def vgg13():
    return vgg_from_cfg("13")

def vgg16():
    return vgg_from_cfg("16")


In [9]:
experiment_list = [
    (
        archfactory.__name__, 
        archfactory, 
        augfactory.__name__, 
        augfactory, 
        trainfunc.__name__,
        trainfunc,
        str(run)
    )
    for run in range(1, 10) 
    for archfactory in [vgg13, vgg16]
    for augfactory in [flip, smallrotate,  fliprotate]
    for trainfunc in [fasterlearn, morepatient]
] + [
    (
        archfactory.__name__, 
        archfactory, 
        "none", # memory is the only interesting part for batch size
        load_cifar10_dataloaders_validation, 
        trainfunc.__name__,
        trainfunc,
        str(1) # memory is the only interesting part for batch size
    )
    for archfactory in [vgg13, vgg16]
    for trainfunc in [smallbatch, bigbatch]    
]

NameError: name 'load_cifar10_dataloaders_validation' is not defined

In [20]:
len(experiment_list)

64

In [11]:
for experiment in experiment_list:
    print(
        "Time:", datetime.datetime.now().isoformat(),
        *experiment[::2]
    )
    try:
        run_arch_experiment(*experiment)
    except Exception as e:
        print("Error occured, skipping...", repr(e))

Time: 2023-03-16T06:48:24.402368 Arch:  vgg8 Run:  1
Report exists already for experiments/arch_vgg8_1. Skipping...
Time: 2023-03-16T06:48:24.402548 Arch:  vgg8 Run:  2
Report exists already for experiments/arch_vgg8_2. Skipping...
Time: 2023-03-16T06:48:24.404434 Arch:  vgg8 Run:  3
Report exists already for experiments/arch_vgg8_3. Skipping...
Time: 2023-03-16T06:48:24.404495 Arch:  vgg8 Run:  4
Report exists already for experiments/arch_vgg8_4. Skipping...
Time: 2023-03-16T06:48:24.404513 Arch:  vgg8 Run:  5
Report exists already for experiments/arch_vgg8_5. Skipping...
Time: 2023-03-16T06:48:24.404527 Arch:  vgg8 Run:  6
Report exists already for experiments/arch_vgg8_6. Skipping...
Time: 2023-03-16T06:48:24.404557 Arch:  vgg8 Run:  7
Report exists already for experiments/arch_vgg8_7. Skipping...
Time: 2023-03-16T06:48:24.404587 Arch:  vgg8 Run:  8
Report exists already for experiments/arch_vgg8_8. Skipping...
Time: 2023-03-16T06:48:24.404605 Arch:  vgg8 Run:  9
Report exists alrea

Epoch: 10, Accuracy on validation set: 0.7724
Epoch: 11, Accuracy on validation set: 0.7932
Epoch: 12, Accuracy on validation set: 0.7936
Epoch: 13, Accuracy on validation set: 0.794
Epoch: 14, Accuracy on validation set: 0.7992
Epoch: 15, Accuracy on validation set: 0.8052
Epoch: 16, Accuracy on validation set: 0.802
Epoch: 17, Accuracy on validation set: 0.797
Epoch: 18, Accuracy on validation set: 0.8034
Epoch: 19, Accuracy on validation set: 0.7986
Epoch: 20, Accuracy on validation set: 0.8134
Epoch: 21, Accuracy on validation set: 0.8144
Epoch: 22, Accuracy on validation set: 0.819
Epoch: 23, Accuracy on validation set: 0.8086
Epoch: 24, Accuracy on validation set: 0.8122
Epoch: 25, Accuracy on validation set: 0.8046
Epoch: 26, Accuracy on validation set: 0.811
Epoch: 27, Accuracy on validation set: 0.816
Early stop
Time: 2023-03-16T07:41:45.285752 Arch:  vgg11_fat_classifier Run:  6
Files already downloaded and verified
Files already downloaded and verified
Epoch: 0, Accuracy on 

Epoch: 20, Accuracy on validation set: 0.8124
Epoch: 21, Accuracy on validation set: 0.8066
Epoch: 22, Accuracy on validation set: 0.8234
Epoch: 23, Accuracy on validation set: 0.8064
Epoch: 24, Accuracy on validation set: 0.8142
Epoch: 25, Accuracy on validation set: 0.8142
Epoch: 26, Accuracy on validation set: 0.8058
Epoch: 27, Accuracy on validation set: 0.8166
Early stop
Time: 2023-03-16T08:35:48.191569 Arch:  vgg13 Run:  1
Files already downloaded and verified
Files already downloaded and verified
Epoch: 0, Accuracy on validation set: 0.3634
Epoch: 1, Accuracy on validation set: 0.5174
Epoch: 2, Accuracy on validation set: 0.6494
Epoch: 3, Accuracy on validation set: 0.7
Epoch: 4, Accuracy on validation set: 0.7426
Epoch: 5, Accuracy on validation set: 0.7704
Epoch: 6, Accuracy on validation set: 0.7756
Epoch: 7, Accuracy on validation set: 0.7952
Epoch: 8, Accuracy on validation set: 0.8006
Epoch: 9, Accuracy on validation set: 0.8068
Epoch: 10, Accuracy on validation set: 0.807

Epoch: 25, Accuracy on validation set: 0.841
Epoch: 26, Accuracy on validation set: 0.8418
Epoch: 27, Accuracy on validation set: 0.8368
Early stop
Time: 2023-03-16T09:35:50.093759 Arch:  vgg13 Run:  6
Files already downloaded and verified
Files already downloaded and verified
Epoch: 0, Accuracy on validation set: 0.269
Epoch: 1, Accuracy on validation set: 0.4862
Epoch: 2, Accuracy on validation set: 0.612
Epoch: 3, Accuracy on validation set: 0.7022
Epoch: 4, Accuracy on validation set: 0.7406
Epoch: 5, Accuracy on validation set: 0.7736
Epoch: 6, Accuracy on validation set: 0.7894
Epoch: 7, Accuracy on validation set: 0.8048
Epoch: 8, Accuracy on validation set: 0.8134
Epoch: 9, Accuracy on validation set: 0.8116
Epoch: 10, Accuracy on validation set: 0.8146
Epoch: 11, Accuracy on validation set: 0.8102
Epoch: 12, Accuracy on validation set: 0.8208
Epoch: 13, Accuracy on validation set: 0.8206
Epoch: 14, Accuracy on validation set: 0.8126
Epoch: 15, Accuracy on validation set: 0.825

Epoch: 3, Accuracy on validation set: 0.6376
Epoch: 4, Accuracy on validation set: 0.6772
Epoch: 5, Accuracy on validation set: 0.7362
Epoch: 6, Accuracy on validation set: 0.746
Epoch: 7, Accuracy on validation set: 0.7748
Epoch: 8, Accuracy on validation set: 0.7756
Epoch: 9, Accuracy on validation set: 0.795
Epoch: 10, Accuracy on validation set: 0.7982
Epoch: 11, Accuracy on validation set: 0.804
Epoch: 12, Accuracy on validation set: 0.8084
Epoch: 13, Accuracy on validation set: 0.814
Epoch: 14, Accuracy on validation set: 0.8196
Epoch: 15, Accuracy on validation set: 0.823
Epoch: 16, Accuracy on validation set: 0.8258
Epoch: 17, Accuracy on validation set: 0.8202
Epoch: 18, Accuracy on validation set: 0.8244
Epoch: 19, Accuracy on validation set: 0.8318
Epoch: 20, Accuracy on validation set: 0.837
Epoch: 21, Accuracy on validation set: 0.8298
Epoch: 22, Accuracy on validation set: 0.8362
Epoch: 23, Accuracy on validation set: 0.843
Epoch: 24, Accuracy on validation set: 0.8402
Ep

Epoch: 22, Accuracy on validation set: 0.8292
Epoch: 23, Accuracy on validation set: 0.8324
Epoch: 24, Accuracy on validation set: 0.8304
Epoch: 25, Accuracy on validation set: 0.8376
Epoch: 26, Accuracy on validation set: 0.8308
Epoch: 27, Accuracy on validation set: 0.832
Epoch: 28, Accuracy on validation set: 0.8292
Epoch: 29, Accuracy on validation set: 0.8294
Epoch: 30, Accuracy on validation set: 0.837
Early stop
Time: 2023-03-16T11:50:17.919590 Arch:  vgg16 Run:  7
Files already downloaded and verified
Files already downloaded and verified
Epoch: 0, Accuracy on validation set: 0.3076
Epoch: 1, Accuracy on validation set: 0.412
Epoch: 2, Accuracy on validation set: 0.5498
Epoch: 3, Accuracy on validation set: 0.633
Epoch: 4, Accuracy on validation set: 0.7058
Epoch: 5, Accuracy on validation set: 0.7298
Epoch: 6, Accuracy on validation set: 0.7612
Epoch: 7, Accuracy on validation set: 0.77
Epoch: 8, Accuracy on validation set: 0.7802
Epoch: 9, Accuracy on validation set: 0.796
Ep

Epoch: 26, Accuracy on validation set: 0.8282
Epoch: 27, Accuracy on validation set: 0.8408
Epoch: 28, Accuracy on validation set: 0.838
Epoch: 29, Accuracy on validation set: 0.8426
Epoch: 30, Accuracy on validation set: 0.836
Epoch: 31, Accuracy on validation set: 0.8326
Epoch: 32, Accuracy on validation set: 0.8422
Epoch: 33, Accuracy on validation set: 0.8436
Epoch: 34, Accuracy on validation set: 0.838
Epoch: 35, Accuracy on validation set: 0.8362
Epoch: 36, Accuracy on validation set: 0.8492
Epoch: 37, Accuracy on validation set: 0.8376
Epoch: 38, Accuracy on validation set: 0.8454
Epoch: 39, Accuracy on validation set: 0.841
Epoch: 40, Accuracy on validation set: 0.8428
Epoch: 41, Accuracy on validation set: 0.8468
Early stop
Time: 2023-03-16T13:04:47.938005 Arch:  vgg19 Run:  2
Files already downloaded and verified
Files already downloaded and verified
Epoch: 0, Accuracy on validation set: 0.2542
Epoch: 1, Accuracy on validation set: 0.3534
Epoch: 2, Accuracy on validation set:

Epoch: 27, Accuracy on validation set: 0.8422
Epoch: 28, Accuracy on validation set: 0.8478
Epoch: 29, Accuracy on validation set: 0.8374
Epoch: 30, Accuracy on validation set: 0.8504
Epoch: 31, Accuracy on validation set: 0.8384
Epoch: 32, Accuracy on validation set: 0.8482
Epoch: 33, Accuracy on validation set: 0.8446
Epoch: 34, Accuracy on validation set: 0.8354
Epoch: 35, Accuracy on validation set: 0.8476
Early stop
Time: 2023-03-16T14:19:15.307152 Arch:  vgg19 Run:  7
Files already downloaded and verified
Files already downloaded and verified
Epoch: 0, Accuracy on validation set: 0.2034
Epoch: 1, Accuracy on validation set: 0.32
Epoch: 2, Accuracy on validation set: 0.3764
Epoch: 3, Accuracy on validation set: 0.4598
Epoch: 4, Accuracy on validation set: 0.5402
Epoch: 5, Accuracy on validation set: 0.6424
Epoch: 6, Accuracy on validation set: 0.6654
Epoch: 7, Accuracy on validation set: 0.7016
Epoch: 8, Accuracy on validation set: 0.7324
Epoch: 9, Accuracy on validation set: 0.74