<a href="https://colab.research.google.com/github/james-simon/varnets/blob/master/varnets_cifar10.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

## Imports and constants

In [None]:
import os
import sys
import time
import math
import inspect
from datetime import datetime

import numpy as np

import torch
import torchvision
from torchvision import datasets, transforms
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
# from torch.nn.modules.utils import _single, _pair, _triple, _reverse_repeat_tuple

import matplotlib.pyplot as plt

In [None]:
# --- Experiment configuration ---------------------------------------------
RANDOM_SEED = 42
BATCH_SIZE = 128
N_EPOCHS = 15
N_WORKERS = 4

IMG_SIZE = 32
N_CLASSES = 10

dataset_name = "CIFAR10"
data_path = ""

architecture = "VGG"

optimizer_class = torch.optim.SGD
LEARNING_RATE = 0.05
MOMENTUM = .9

# Apply RANDOM_SEED so weight initialization and shuffling start from a
# known RNG state after a kernel restart.
np.random.seed(RANDOM_SEED)
torch.manual_seed(RANDOM_SEED)

# check device
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
print("using " + DEVICE)

# Remember the stock torch.nn.functional ops exactly once per kernel session.
# Later cells overwrite F.linear / F.conv2d / F.relu; the `latch` sentinel keeps
# a re-run of this cell from capturing the already-patched versions.
try:
  latch
except NameError:
  old_linear_def = F.linear
  old_conv_def = F.conv2d
  old_relu_def = F.relu
  print('latched')
latch = True

using cuda
latched


In [None]:
def count_parameters(model):
    """Return the total number of scalar entries in the model's trainable parameters.

    Only parameters with ``requires_grad`` set are counted.
    """
    n_trainable = 0
    for param in model.parameters():
        if param.requires_grad:
            n_trainable += param.numel()
    return n_trainable

def all_parameters(model):
  """Return every trainable parameter of ``model`` concatenated into one 1-D tensor.

  Each parameter is flattened before concatenation, so parameters of
  different shapes (e.g. a 2-D weight and a 1-D bias) can be combined;
  concatenating them unflattened makes ``torch.cat`` raise.

  Args:
    model: module whose ``parameters()`` are collected; only those with
      ``requires_grad`` set are included, in iteration order.

  Returns:
    A 1-D tensor holding all selected parameter values. An empty 1-D tensor
    is returned when the model has no trainable parameters.
  """
  flat_params = [p.reshape(-1) for p in model.parameters() if p.requires_grad]
  if not flat_params:
    # torch.cat rejects an empty list; give callers a well-defined empty vector.
    return torch.empty(0)
  return torch.cat(flat_params)

def sample_loader(loader, n_samples=1):
  """Return the leading inputs of the first batch produced by ``loader``.

  At most ``n_samples`` inputs are taken (fewer if the batch is smaller),
  the labels are discarded, and the result is moved to ``DEVICE``.
  Returns ``None`` if the loader yields no batches.
  """
  first_batch = next(iter(loader), None)
  if first_batch is None:
    return None
  inputs, _ = first_batch
  n_taken = min(n_samples, inputs.size()[0])
  return inputs[0:n_taken].to(DEVICE)

## Dataset setup

In [None]:
# Build the train/test preprocessing pipelines for the configured dataset and
# architecture. Both stay None for a dataset name not handled below.
transform_set_train = None
transform_set_test = None

if dataset_name == "MNIST":
  transform_set_train = transforms.Compose([transforms.Resize((IMG_SIZE, IMG_SIZE)),
                                 transforms.ToTensor()])
  transform_set_test = transforms.Compose([transforms.Resize((IMG_SIZE, IMG_SIZE)),
                                 transforms.ToTensor()])
elif dataset_name == "CIFAR10":
  if architecture == "ALEXNET":
    # Augment at the native 32x32 resolution first, then upsample, so that the
    # training images end up 224x224 like the test images. (Cropping to 32
    # after the resize left train and test inputs at different sizes.)
    # transforms.Resize replaces the deprecated transforms.Scale.
    transform_set_train = transforms.Compose([transforms.RandomHorizontalFlip(),
                                    transforms.RandomCrop(32, padding=4),
                                    transforms.Resize((224, 224)),
                                    transforms.ToTensor(),
                                    transforms.Normalize((0.4914, 0.4822, 0.4465),
                                                          (0.2023, 0.1994, 0.2010))])
    transform_set_test = transforms.Compose([transforms.Resize((224, 224)),
                                   transforms.ToTensor(),
                                   transforms.Normalize((0.4914, 0.4822, 0.4465),
                                                        (0.2023, 0.1994, 0.2010))])
  else:
    # NOTE(review): these normalization constants look like the usual ImageNet
    # statistics rather than the CIFAR10 ones used in the ALEXNET branch -
    # confirm this is intended.
    transform_set_train = transforms.Compose([
                    transforms.RandomHorizontalFlip(),
                    transforms.RandomCrop(32, padding=4),
                    transforms.ToTensor(),
                    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
                ])
    transform_set_test = transforms.Compose([
                    transforms.ToTensor(),
                    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
                ])

# Look the dataset class up by name and store its files under
# <data_path>/<lower-cased dataset name>.
ds = getattr(torchvision.datasets, dataset_name)
dataset_path = os.path.join(data_path, dataset_name.lower())

train_set = ds(dataset_path, train=True, download=True, transform=transform_set_train)
test_set = ds(dataset_path, train=False, download=True, transform=transform_set_test)

# Only the training loader shuffles; the test loader keeps dataset order.
train_loader = DataLoader(dataset=train_set,
                          batch_size=BATCH_SIZE,
                          shuffle=True,
                          num_workers=N_WORKERS)
test_loader = DataLoader(dataset=test_set,
                          batch_size=BATCH_SIZE,
                          shuffle=False,
                          num_workers=N_WORKERS)

Downloading https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz to cifar10/cifar-10-python.tar.gz


HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))

Extracting cifar10/cifar-10-python.tar.gz to cifar10
Files already downloaded and verified


## Train and test methods

In [None]:
def run_epoch(model, optimizer, criterion, device, loader, mode):
  """Run one full pass over ``loader`` in training or evaluation mode.

  Args:
    model: network to run; switched to ``train()`` or ``eval()`` per ``mode``.
    optimizer: optimizer stepped after every batch when ``mode == "train"``;
      not used otherwise.
    criterion: callable ``criterion(y_hat, y)`` returning the mean batch loss.
    device: device that each batch is moved to before the forward pass.
    loader: iterable yielding ``(X, y)`` batches.
    mode: ``"train"`` to update the model, ``"test"`` or ``"eval"`` to only
      measure it.

  Returns:
    ``(epoch_loss, epoch_acc)``: loss and top-1 accuracy averaged over the
    samples processed in this pass. ``(0.0, 0.0)`` if the loader is empty.

  Raises:
    ValueError: if ``mode`` is not one of the supported strings.
  """
  print("starting " + mode + "\t\t cuda using " + str(torch.cuda.memory_allocated()/10**9) + " GB")

  if mode == "train":
    model.train()
  elif mode == "test" or mode == "eval":
    model.eval()
  else:
    raise ValueError("INVALID MODE: " + repr(mode) + " (expected 'train', 'test' or 'eval')")

  is_training = mode == "train"
  total_loss = 0.0
  n_correct = 0
  n_seen = 0

  for X, y in loader:
    X = X.to(device)
    y = y.to(device)

    if is_training:
      optimizer.zero_grad()

    # Only record the autograd graph when it will be used; evaluation batches
    # then do not hold on to activations for a backward pass that never runs.
    with torch.set_grad_enabled(is_training):
      y_hat = model(X)
      loss = criterion(y_hat, y)

    if is_training:
      loss.backward()
      optimizer.step()

    # criterion returns a batch mean, so weight it by the batch size to get a
    # correct per-sample average when the last batch is smaller.
    batch_size = X.size(0)
    total_loss += loss.item()*batch_size

    _, predicted_labels = torch.max(y_hat, 1)
    n_correct += (predicted_labels == y).sum().item()
    n_seen += batch_size

  if n_seen == 0:
    return (0.0, 0.0)

  # Normalize by the samples actually processed so the averages stay correct
  # when the loader covers only part of its dataset.
  return (total_loss/n_seen, n_correct/n_seen)

def train(model, optimizer, criterion, device, train_loader, test_loader, n_epochs, print_every_n_epochs=1):
    """Alternate one training pass and one test pass for ``n_epochs`` epochs.

    Each epoch calls ``run_epoch`` in "train" mode on ``train_loader`` and then
    in "test" mode on ``test_loader``. A summary line is printed on every
    ``print_every_n_epochs``-th epoch.

    Returns:
        ``(train_accs, train_losses, test_accs, test_losses)``, each a list
        with one entry per epoch.
    """
    train_accs, train_losses = [], []
    test_accs, test_losses = [], []

    for epoch in range(n_epochs):
        e_start_t = time.time()

        # optimize on the training data, then measure on the held-out data
        tr_loss, tr_acc = run_epoch(model, optimizer, criterion, device, train_loader, "train")
        te_loss, te_acc = run_epoch(model, optimizer, criterion, device, test_loader, "test")

        train_losses.append(tr_loss)
        train_accs.append(tr_acc)
        test_losses.append(te_loss)
        test_accs.append(te_acc)

        should_report = epoch % print_every_n_epochs == (print_every_n_epochs - 1)
        if not should_report:
            continue

        print(f'Epoch: {epoch}\t'
              f'Epoch time: {time.time() - e_start_t} --- '
              f'Train loss: {tr_loss:.4f}\t'
              f'test loss: {te_loss:.4f}\t'
              f'Train accuracy: {100 * tr_acc:.2f}\t'
              f'test accuracy: {100 * te_acc:.2f}')

    return (train_accs, train_losses, test_accs, test_losses)

## Define model initialization with $w \sim \mathcal{N}(0, 1/n)$ and $b = 0$, where $n$ is the layer's fan-in

In [None]:
def initialize(model):
  """Re-initialize ``model`` in place with w ~ N(0, 1/n) and b = 0.

  For every ``nn.Conv2d`` and ``nn.Linear`` submodule, weights are drawn from
  a zero-mean normal with standard deviation ``1/sqrt(n)``, where ``n`` is the
  number of weights feeding one output unit (``weight.size(1)`` for linear
  layers, ``size(1)*size(2)*size(3)`` for convolutions). Biases are set to 0.
  ``nn.BatchNorm2d`` layers get weight 1 and bias 0.

  Layers built without a bias (``bias=False``) or without affine parameters
  (``affine=False``) are handled: the missing parameters are skipped.
  """
  for m in model.modules():
    if isinstance(m, (nn.Conv2d, nn.Linear)):
      # Weights per output unit: total weight count divided by the size of
      # the leading (output) dimension.
      n = m.weight.numel() // m.weight.size(0)
      std = 1/n**(1/2)
      nn.init.normal_(m.weight, std=std)
      if m.bias is not None:
        nn.init.constant_(m.bias, 0)
    elif isinstance(m, nn.BatchNorm2d):
      # affine=False leaves weight and bias as None
      if m.weight is not None:
        nn.init.constant_(m.weight, 1)
      if m.bias is not None:
        nn.init.constant_(m.bias, 0)

## Define special varnet functions

In [None]:
def normal_distribution_moment(a):
  """Return E[|x|^a] for x ~ N(0, 1).

  Evaluated with the closed form 2^(a/2) * Gamma((a + 1)/2) / sqrt(pi).
  """
  power_of_two = 2**(a/2)
  gamma_factor = math.gamma((a + 1)/2)
  return power_of_two*gamma_factor/math.pi**(1/2)

def power_series_var(terms):
  """Sum c0*c1 * E[|x|^(a0+a1)] * E[|x|^(b0+b1)] over all ordered pairs of terms.

  ``terms`` is a sequence of ``(a, b, c)`` triples; every triple is paired
  with every triple (including itself), and the moments are those of a
  standard normal as given by ``normal_distribution_moment``.
  """
  accumulated = 0
  for first_a, first_b, first_c in terms:
    for second_a, second_b, second_c in terms:
      moment_a = normal_distribution_moment(first_a + second_a)
      moment_b = normal_distribution_moment(first_b + second_b)
      accumulated += first_c*second_c*moment_a*moment_b
  return accumulated

def normalize_terms(terms):
  """Rescale the coefficients of ``terms`` by 1/sqrt(power_series_var(terms)).

  Returns a new list of ``(a, b, c/std)`` triples; the exponents ``a`` and
  ``b`` are passed through unchanged.
  """
  std = power_series_var(terms)**(1/2)
  rescaled = []
  for a, b, c in terms:
    rescaled.append((a, b, c/std))
  return rescaled

In [None]:
# Each entry is (a, b, c_ab); the coefficients are rescaled by normalize_terms.
terms = normalize_terms([(1, 1, 1), (3, 3, 1)])

def var_linear(input, weights, bias=None):
  """Drop-in replacement for ``F.linear`` computing 0.7 * (W x + tanh(W) relu(x) + b).

  Both products are evaluated with the latched stock implementation
  ``old_linear_def``. The bias is added once, with the first product only;
  passing it to both calls would count it twice in the sum.

  Args:
    input: input tensor, forwarded to ``old_linear_def``.
    weights: weight matrix, used as-is in the first product and passed
      through ``tanh`` in the second.
    bias: optional bias, added once before the final scaling.

  Returns:
    The sum of the two products (plus bias) multiplied by 0.7.
    NOTE(review): 0.7 looks like an approximation of 1/sqrt(2) chosen to
    keep the variance of the two-term sum near 1 - confirm.
  """
  # An earlier variant built from signed powers of the inputs and weights
  # (driven by the module-level `terms` list) was tried here and disabled.
  output = old_linear_def(input, weights, bias)
  # bias is already part of `output`; do not add it a second time
  output += old_linear_def(torch.relu(input), torch.tanh(weights), None)
  return output*.7

def var_conv2d(input, weights, bias=None, stride=(1,1), padding=(0,0), dilation=(1,1), groups=1):
  """Drop-in replacement for ``F.conv2d`` computing 0.7 * (W * x + tanh(W) * relu(x) + b).

  Both convolutions are evaluated with the latched stock implementation
  ``old_conv_def`` using the same stride, padding, dilation and groups. The
  bias is added once, with the first convolution only; passing it to both
  calls would count it twice in the sum.

  Args:
    input: input tensor, forwarded to ``old_conv_def``.
    weights: convolution kernel, used as-is in the first convolution and
      passed through ``tanh`` in the second.
    bias: optional bias, added once before the final scaling.
    stride, padding, dilation, groups: forwarded unchanged to ``old_conv_def``.

  Returns:
    The sum of the two convolutions (plus bias) multiplied by 0.7.
    NOTE(review): 0.7 looks like an approximation of 1/sqrt(2) chosen to
    keep the variance of the two-term sum near 1 - confirm.
  """
  # An earlier variant built from signed powers of the inputs and weights
  # (driven by the module-level `terms` list) was tried here and disabled.
  output = old_conv_def(input, weights, bias, stride, padding, dilation, groups)
  # bias is already part of `output`; do not add it a second time
  output += old_conv_def(torch.relu(input), torch.tanh(weights), None, stride, padding, dilation, groups)
  return output*.7

# Route F.linear / F.conv2d through the variant ops defined in this cell.
# To go back to the stock behaviour, assign old_linear_def / old_conv_def instead.
F.linear, F.conv2d = var_linear, var_conv2d

In [None]:
def relu(input: torch.Tensor, inplace: bool = False) -> torch.Tensor:
    """Apply ``torch.relu`` to ``input``.

    ``inplace`` is accepted so the signature matches ``F.relu`` but is ignored;
    the result is always a new tensor.
    """
    return torch.relu(input)

def identity(input: torch.Tensor, inplace: bool = False) -> torch.Tensor:
    """Return ``input`` itself, unchanged.

    ``inplace`` is accepted so the signature matches ``F.relu`` but is ignored.
    """
    return input

# Redefine F.relu so that when torchvision.models.vgg19 calls nn.ReLU, which in turn calls F.relu,
# it finds this function instead. With `identity` installed those activations pass their input through
# unchanged; the relu/tanh terms inside var_linear / var_conv2d remain in effect.
F.relu = identity

In [None]:
# Smoke test: build a VGG19, re-initialize it, and push one training image through it.
model = torchvision.models.vgg19(num_classes=10, init_weights=True)
model = model.to(DEVICE)
initialize(model)

output = model(sample_loader(train_loader, n_samples=1))
print(output)
# (a backward pass can be checked here with output.sum().backward())

# Memory allocated while the forward result is still alive ...
mem_used = torch.cuda.memory_allocated()
print(mem_used/10**9)

# ... minus what remains allocated once that result has been released.
del output
torch.cuda.empty_cache()
mem_used = mem_used - torch.cuda.memory_allocated()
print(f"one forward pass uses {mem_used/10**9} GB; that's a lot")

tensor([[-1.5291, -1.4181, -4.1588,  0.1082,  4.1351, -3.6158,  0.6514, -2.7504,
          0.3193,  0.0853]], device='cuda:0', grad_fn=<MulBackward0>)
1.121393664
one forward pass uses 0.562947584 GB; that's a lot


## Set up and run

In [None]:
def run_tests(n_trials=3, n_epochs=15, print_every=10**7):
  """Train a freshly initialized VGG19 ``n_trials`` times and collect the metrics.

  Every trial builds a new model, re-initializes it with ``initialize``,
  creates an ``optimizer_class`` optimizer from the current module-level
  ``LEARNING_RATE`` and ``MOMENTUM``, and trains with cross-entropy loss on
  ``train_loader`` / ``test_loader``.

  Args:
    n_trials: number of independent training runs.
    n_epochs: number of epochs per run (forwarded to ``train``).
    print_every: forwarded to ``train`` as ``print_every_n_epochs``; the
      large default means no per-epoch summary is printed unless a run has
      at least that many epochs.

  Returns:
    A tensor built from the per-trial results of ``train``, i.e. indexed as
    ``[trial, metric, epoch]`` with metrics ordered
    (train_accs, train_losses, test_accs, test_losses).

  Side effects:
    Rebinds the module-level ``model`` to the most recently trained model.
  """
  global model
  trajectories = []
  for _ in range(n_trials):
    model = torchvision.models.vgg19(num_classes=10, init_weights=True).to(DEVICE)
    initialize(model)
    optimizer = optimizer_class(model.parameters(), lr=LEARNING_RATE, momentum=MOMENTUM)
    criterion = nn.CrossEntropyLoss()

    # Honor the caller's epoch count and print interval instead of fixed values.
    traj = train(model, optimizer, criterion, DEVICE, train_loader, test_loader,
                 n_epochs, print_every_n_epochs=print_every)
    trajectories.append(traj)
  return torch.tensor(trajectories)

# STANDARD
# Baseline experiment (currently disabled): restores the latched stock
# F.linear / F.conv2d / F.relu and runs with its own learning rate.
# print("STANDARD")
# F.linear = old_linear_def
# F.conv2d = old_conv_def
# F.relu = old_relu_def
# LEARNING_RATE = .005
# additive_results = run_tests(n_trials=1, n_epochs=50, print_every=1)
# print(additive_results)

# # WX + TANH(W)TANH(X)
# print("WX + TANH(W)TANH(X)")

# WX + TANH(W)RELU(X)
# Active experiment. The order matters: install the patched ops first, then
# rebind the module-level LEARNING_RATE (run_tests reads it when it builds the
# optimizer), and only then start the run.
print("WX + TANH(W)RELU(X)")
F.linear = var_linear
F.conv2d = var_conv2d
F.relu = identity
LEARNING_RATE = .0005
var_results = run_tests(n_trials=1, n_epochs=50, print_every=1)
print(var_results)

WX + TANH(W)RELU(X)
starting train		 cuda using 0.560281088 GB
starting test		 cuda using 1.68189184 GB
Epoch: 0	Epoch time: 60.63836908340454 --- Train loss: 1.8423	test loss: 1.4027	Train accuracy: 35.13	test accuracy: 49.19
starting train		 cuda using 1.68189184 GB
starting test		 cuda using 1.68189184 GB
Epoch: 1	Epoch time: 60.66513013839722 --- Train loss: 1.3714	test loss: 1.1924	Train accuracy: 50.58	test accuracy: 56.45
starting train		 cuda using 1.68189184 GB
starting test		 cuda using 1.68189184 GB
Epoch: 2	Epoch time: 60.65664982795715 --- Train loss: 1.1990	test loss: 1.0599	Train accuracy: 57.27	test accuracy: 62.70
starting train		 cuda using 1.68189184 GB
starting test		 cuda using 1.68189184 GB
Epoch: 3	Epoch time: 60.720757484436035 --- Train loss: 1.0768	test loss: 0.9426	Train accuracy: 62.02	test accuracy: 67.16
starting train		 cuda using 1.68189184 GB
starting test		 cuda using 1.68189184 GB
Epoch: 4	Epoch time: 60.598742723464966 --- Train loss: 0.9828	test los