In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import models, datasets, transforms
import math
import time
from collections import defaultdict
import numpy as np
import pandas as pd
from tqdm import tqdm
import os

from project import DATASETS_DIR, get_weights_path

In [None]:
# Pick the compute device for the whole notebook.
# The second GPU ("cuda:1") is preferred, but it only exists on multi-GPU
# machines; on a single-GPU host fall back to "cuda:0" instead of failing
# later when tensors are moved to a non-existent device.
if torch.cuda.is_available():
    device = torch.device("cuda:1" if torch.cuda.device_count() > 1 else "cuda:0")
else:
    device = torch.device("cpu")
device

In [None]:
# Mini-batch size; passed to both the train and test DataLoaders.
batch_size = 256

In [None]:
# Per-image preprocessing applied to every CIFAR-10 sample: convert to a tensor,
# then normalize each of the three channels with a fixed mean / std.
transform = transforms.Compose([
    transforms.ToTensor(),
    # TODO: is this right? These values look like the commonly used ImageNet
    # channel statistics rather than ones computed on CIFAR-10 — confirm whether
    # dataset-specific statistics should be used instead.
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

In [None]:
# Build the CIFAR-10 train split first, then the test split, both stored under
# DATASETS_DIR (downloaded on demand) and sharing the same preprocessing.
cifar_train, cifar_test = (
    datasets.CIFAR10(root=DATASETS_DIR, train=is_train, download=True, transform=transform)
    for is_train in (True, False)
)

In [None]:
# Batched iterators over the two splits: the training split is reshuffled on
# every pass, the test split keeps a fixed order.
train = torch.utils.data.DataLoader(
    dataset=cifar_train,
    batch_size=batch_size,
    shuffle=True,
)
test = torch.utils.data.DataLoader(
    dataset=cifar_test,
    batch_size=batch_size,
    shuffle=False,
)

In [None]:
def replace_vgg_classifier_(model, classes=10):
    """
    Replace `model.classifier` in place with a new fully connected head.

    The head is Linear -> ReLU -> Dropout -> Linear -> ReLU -> Dropout -> Linear,
    taking a flattened 512 * 7 * 7 feature vector, using two hidden layers of
    width 4096 and producing `classes` outputs.

    Returns
        model (the same object that was passed in)
    """
    in_features = 512 * 7 * 7
    hidden = 4096
    # Layers are instantiated in the same order they appear in the head.
    head_layers = [
        nn.Linear(in_features, hidden),
        nn.ReLU(True),
        nn.Dropout(),
        nn.Linear(hidden, hidden),
        nn.ReLU(True),
        nn.Dropout(),
        nn.Linear(hidden, classes),
    ]
    model.classifier = nn.Sequential(*head_layers)
    return model
    
def _kaiming_normal_scaled_(tensor, a=0, mode='fan_in', nonlinearity='leaky_relu', scale=1):
    """
    He init with std scaled by `scale`

    Fills `tensor` in place with samples from a zero-mean normal distribution
    whose standard deviation is `scale * gain / sqrt(fan)`.

    Args
        tensor: tensor to overwrite in place
        a: parameter forwarded to `nn.init.calculate_gain` together with `nonlinearity`
        mode: 'fan_in' or 'fan_out', selects which fan of `tensor` is used
        nonlinearity: name of the nonlinearity used to compute the gain
        scale: multiplier applied to the He standard deviation (1 = unscaled)
    Returns
        the same `tensor`, after the in-place fill
    """
    fan = nn.init._calculate_correct_fan(tensor, mode)
    gain = nn.init.calculate_gain(nonlinearity, a)
    std = gain / math.sqrt(fan) * scale
    # The fill must not be recorded by autograd: the tensor is typically a
    # parameter that requires grad.
    with torch.no_grad():
        return tensor.normal_(0, std)

def reinitialize_vgg_(model, scale=1):
    """
    Re-initialize every Conv2d, BatchNorm2d and Linear submodule of `model` in place.

    - Conv2d: weights from the scaled He normal init (fan_out, relu gain) with the
      std multiplied by `scale`; bias set to 0.
    - BatchNorm2d: weight set to 1, bias set to 0.
    - Linear: weights from N(0, 0.01**2); bias set to 0.

    `scale` only affects convolution weights. Parameters that are absent
    (bias=False, affine=False) are skipped.
    """
    for m in model.modules():
        if isinstance(m, nn.Conv2d):
            _kaiming_normal_scaled_(m.weight, mode='fan_out', nonlinearity='relu', scale=scale)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.BatchNorm2d):
            # With affine=False both parameters are None.
            if m.weight is not None:
                nn.init.constant_(m.weight, 1)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.Linear):
            nn.init.normal_(m.weight, 0, 0.01)
            # Same None guard as the Conv2d branch: Linear(bias=False) has no bias.
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)

In [None]:
def train_epoch_(model, data_train, data_val, criterion, opt, scheduler=None, history=None, verbose=1):
    """
    Train epoch and in-place update model and history

    Runs one optimisation pass over `data_train`, then one evaluation pass over
    `data_val`, and appends the sample-weighted mean loss and accuracy of each
    pass to `history`. Finally prints summary statistics of the per-batch
    gradient norms seen during training.

    Args
        model: network to train; batches are moved to the device of its first parameter
        data_train: iterable of (inputs, labels) batches used for optimisation
        data_val: iterable of (inputs, labels) batches used for evaluation only
        criterion: callable `criterion(outputs, labels)` returning a scalar loss tensor
        opt: optimizer over the model's parameters
        scheduler: optional object whose `.step()` is called once per training batch
        history: mapping of metric name -> list of per-epoch values; created when None
        verbose: when truthy, wrap each loader in a tqdm progress bar
    Returns
        model, history
    """
    if history is None:
        history = defaultdict(list)

    # Follow the model's own device instead of relying on a notebook-level
    # global, so the function works for any model regardless of cell state.
    first_param = next(model.parameters(), None)
    target_device = first_param.device if first_param is not None else None

    norms = []

    for val, data in [(False, data_train), (True, data_val)]:
        if val:
            model.eval()
        else:
            model.train()

        total = 0
        cum_loss = 0
        cum_correct = 0

        for inputs, labels in tqdm(data) if verbose else data:
            if target_device is not None:
                inputs, labels = inputs.to(target_device), labels.to(target_device)
            opt.zero_grad()
            with torch.set_grad_enabled(not val):
                outputs = model(inputs)
                _, preds = outputs.max(dim=1)
                loss = criterion(outputs, labels)

                if not val:
                    loss.backward()
                    # Clip the global gradient norm at 1000 and keep the
                    # norm returned by the call for the summary printed at the end.
                    norm = torch.nn.utils.clip_grad_norm_(model.parameters(), 1000)
                    norms.append(norm.item())
                    opt.step()

            # The scheduler advances per training batch, not per epoch.
            if not val and scheduler:
                scheduler.step()

            batch = inputs.size(0)
            total += batch
            # Criterion output is weighted by batch size so the epoch mean is
            # not skewed by a smaller final batch.
            cum_loss += loss.item() * batch
            cum_correct += torch.sum(labels == preds).item()

        # An empty loader yields NaN metrics instead of a ZeroDivisionError.
        mean_acc = cum_correct / total if total else float("nan")
        mean_loss = cum_loss / total if total else float("nan")

        # setdefault keeps this working when a plain dict is passed as history.
        prefix = "val" if val else "train"
        history.setdefault(prefix + "_acc", []).append(mean_acc)
        history.setdefault(prefix + "_loss", []).append(mean_loss)

    print(" GRAD NORMS ".center(80, "#"))
    print(pd.Series(norms).describe())

    return model, history

In [None]:
# Build an untrained VGG-11 with a 10-class head, re-initialize it with the
# conv weight std multiplied by `scale`, then move it to the compute device.
scale = 5
print("Rescaling VGG conv weights by {}".format(scale))
model = replace_vgg_classifier_(models.vgg11(), classes=10)
reinitialize_vgg_(model, scale=scale)
model = model.to(device)

In [None]:
# SGD over all model parameters with a fixed learning rate; no other optimizer
# options are set here.
opt = optim.SGD(model.parameters(), lr=0.01)
# Cross-entropy between the network outputs and the integer class labels.
criterion = nn.CrossEntropyLoss()

In [None]:
# Main training loop: run `epochs` epochs, print one status line per epoch
# (epoch number, wall-clock time elapsed since `start`, latest metrics) and
# save the model weights every `save_interval` epochs.
start = time.time()
history = None
epochs = 500
save_interval = 20
weights_dir = "weights"
# Create the checkpoint directory up front so a path/permission problem
# surfaces before any training time has been spent.
os.makedirs(weights_dir, exist_ok=True)
for e in range(epochs):
    # Work in whole seconds: formatting float seconds with rounding could
    # display values such as "00:00:60".
    elapsed = int(time.time() - start)
    m, s = divmod(elapsed, 60)
    h, m = divmod(m, 60)
    print("Epoch {:04d}".format(e + 1), end="")
    print(" | {:02d}:{:02d}:{:02d}".format(h, m, s), end="")
    print(" | ", end="")
    # The test loader is used as the validation set for the per-epoch metrics.
    model, history = train_epoch_(model, train, test, opt=opt, criterion=criterion, history=history, verbose=False)
    # Keys are truncated to 7 characters to keep the status line short.
    print({k[:7]: "{:.4f}".format(v[-1]) for k, v in history.items()})

    if (e + 1) % save_interval == 0:
        # NOTE(review): the file name carries the 0-based index `e`, so the
        # checkpoint written after the epoch printed as 0020 is "..._e0019.pth".
        # Kept as is because other code may load these files by name — confirm
        # before renaming.
        path = os.path.join(weights_dir, "vgg_weights_e{:04d}.pth".format(e))
        torch.save(model.state_dict(), path)