<a href="https://colab.research.google.com/github/james-simon/varnets/blob/master/varnets_cifar10.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
import os
import sys
import time
import math
import inspect
from datetime import datetime

import numpy as np

import torch
import torchvision
from torchvision import datasets, transforms
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
# from torch.nn.modules.utils import _single, _pair, _triple, _reverse_repeat_tuple

import matplotlib.pyplot as plt

In [2]:
RANDOM_SEED = 42
BATCH_SIZE = 128
N_EPOCHS = 15
N_WORKERS = 4

IMG_SIZE = 32
N_CLASSES = 10

dataset_name = "CIFAR10"
data_path = ""

architecture = "VGG"

optimizer_class = torch.optim.SGD
LEARNING_RATE = 0.05
MOMENTUM = .9

# Apply the seed: RANDOM_SEED was defined but never used, so weight init and
# data shuffling differed on every run. torch.manual_seed seeds all devices.
np.random.seed(RANDOM_SEED)
torch.manual_seed(RANDOM_SEED)

# check device
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
print("using " + DEVICE)

# Later cells monkeypatch F.linear / F.conv2d / F.relu. Save the stock
# implementations only once per kernel: `latch` exists only after this cell has
# run, so re-running it never overwrites the saved originals with patched ones.
try:
  latch
except NameError:
  old_linear_def = F.linear
  old_conv_def = F.conv2d
  old_relu_def = F.relu
  print('latched')
latch = True

using cuda
latched


In [3]:
def count_parameters(model):
    """Return the total number of scalar entries across all trainable parameters."""
    total = 0
    for param in model.parameters():
        if not param.requires_grad:
            continue
        total += param.numel()
    return total

def all_parameters(model):
  """Concatenate every trainable parameter of `model` into one flat 1-D tensor.

  Parameters have different shapes (e.g. 4-D conv kernels next to 1-D biases),
  so each is flattened before concatenation; a bare torch.cat over them fails
  on mismatched ranks. The result is not detached, so it keeps autograd history.
  Returns an empty tensor when the model has no trainable parameters
  (torch.cat rejects an empty list).
  """
  flat = [p.reshape(-1) for p in model.parameters() if p.requires_grad]
  if not flat:
    return torch.empty(0)
  return torch.cat(flat)

def sample_loader(loader, n_samples=1):
  """Take up to `n_samples` inputs from the first batch of `loader`, moved to DEVICE.

  Labels are discarded. Returns None if the loader yields no batches.
  """
  missing = object()
  first_batch = next(iter(loader), missing)
  if first_batch is missing:
    return None
  inputs, _ = first_batch
  take = min(n_samples, inputs.size()[0])
  return inputs[0:take].to(DEVICE)

## Dataset setup

In [4]:
transform_set_train = None
transform_set_test = None

# Per-channel normalization statistics. The ALEXNET branch uses commonly quoted
# CIFAR-10 values and the default branch uses ImageNet values; both are kept as
# in the original experiments so logged results stay comparable.
CIFAR10_MEAN = (0.4914, 0.4822, 0.4465)
CIFAR10_STD = (0.2023, 0.1994, 0.2010)
IMAGENET_MEAN = [0.485, 0.456, 0.406]
IMAGENET_STD = [0.229, 0.224, 0.225]

if dataset_name == "MNIST":
  transform_set_train = transforms.Compose([transforms.Resize((IMG_SIZE, IMG_SIZE)),
                                            transforms.ToTensor()])
  transform_set_test = transforms.Compose([transforms.Resize((IMG_SIZE, IMG_SIZE)),
                                           transforms.ToTensor()])
elif dataset_name == "CIFAR10":
  if architecture == "ALEXNET":
    # Augment at CIFAR's native 32x32 resolution, then upsample to AlexNet's
    # 224x224 input. Cropping 32x32 *after* resizing to 224 would keep only a
    # tiny patch of each image. transforms.Scale is deprecated/removed in
    # torchvision; Resize is its replacement.
    transform_set_train = transforms.Compose([transforms.RandomHorizontalFlip(),
                                              transforms.RandomCrop(32, padding=4),
                                              transforms.Resize((224, 224)),
                                              transforms.ToTensor(),
                                              transforms.Normalize(CIFAR10_MEAN, CIFAR10_STD)])
    transform_set_test = transforms.Compose([transforms.Resize((224, 224)),
                                             transforms.ToTensor(),
                                             transforms.Normalize(CIFAR10_MEAN, CIFAR10_STD)])
  else:
    transform_set_train = transforms.Compose([transforms.RandomHorizontalFlip(),
                                              transforms.RandomCrop(32, padding=4),
                                              transforms.ToTensor(),
                                              transforms.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD)])
    transform_set_test = transforms.Compose([transforms.ToTensor(),
                                             transforms.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD)])
else:
  # Without a transform the dataset yields PIL images, which DataLoader's
  # default collate cannot batch; fail here with a clear message instead.
  raise ValueError(f"no transforms defined for dataset_name={dataset_name!r}")

ds = getattr(torchvision.datasets, dataset_name)
dataset_path = os.path.join(data_path, dataset_name.lower())

train_set = ds(dataset_path, train=True, download=True, transform=transform_set_train)
test_set = ds(dataset_path, train=False, download=True, transform=transform_set_test)

train_loader = DataLoader(dataset=train_set,
                          batch_size=BATCH_SIZE,
                          shuffle=True,
                          num_workers=N_WORKERS)
test_loader = DataLoader(dataset=test_set,
                         batch_size=BATCH_SIZE,
                         shuffle=False,
                         num_workers=N_WORKERS)

Downloading https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz to cifar10/cifar-10-python.tar.gz


HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))

Extracting cifar10/cifar-10-python.tar.gz to cifar10
Files already downloaded and verified


## Train and test methods

In [5]:
def run_epoch(model, optimizer, criterion, device, loader, mode):
  """Run one pass over `loader` in training or evaluation mode.

  Args:
    model: network to run; put in train() or eval() mode according to `mode`.
    optimizer: zeroed and stepped once per batch when mode == "train";
      unused otherwise.
    criterion: loss called as criterion(y_hat, y). The epoch loss accumulates
      loss.item() * batch_size, which assumes a mean-reduced loss.
    device: device each (X, y) batch is moved to.
    loader: DataLoader yielding (X, y) batches. len(loader.dataset) normalizes
      both returned metrics.
    mode: "train", or "test"/"eval" for evaluation.

  Returns:
    (epoch_loss, epoch_acc): the sample-weighted mean loss, and the fraction of
    samples whose argmax prediction equals the label.

  Raises:
    ValueError: if `mode` is not "train", "test" or "eval".
  """
  # memory_allocated is only meaningful when CUDA is present.
  mem_gb = torch.cuda.memory_allocated()/10**9 if torch.cuda.is_available() else 0.0
  print("starting " + mode + "\t\t cuda using " + str(mem_gb) + " GB")

  if mode == "train":
    is_train = True
  elif mode in ("test", "eval"):
    is_train = False
  else:
    print("INVALID MODE")
    raise ValueError(f"mode must be 'train', 'test' or 'eval', got {mode!r}")

  model.train(is_train)  # same as model.train() / model.eval()

  epoch_loss = 0.0
  epoch_correct = 0

  # Only record the autograd graph when we will back-propagate. Evaluating
  # under no-grad avoids keeping activations alive for every test batch.
  with torch.set_grad_enabled(is_train):
    for X, y in loader:
      X = X.to(device)
      y = y.to(device)

      if is_train:
        optimizer.zero_grad()

      y_hat = model(X)
      loss = criterion(y_hat, y)

      if is_train:
        loss.backward()
        optimizer.step()

      epoch_loss += loss.item() * X.size(0)
      _, predicted_labels = torch.max(y_hat, 1)
      epoch_correct += (predicted_labels == y).sum().item()

  n_samples = len(loader.dataset)
  return (epoch_loss / n_samples, epoch_correct / n_samples)

def train(model, optimizer, criterion, device, train_loader, test_loader, n_epochs, print_every_n_epochs=1):
    """Run `n_epochs` rounds of one training epoch followed by one test epoch.

    Both epochs go through run_epoch(). A summary line is printed for every
    epoch index where epoch % print_every_n_epochs == print_every_n_epochs - 1.

    Returns:
        Tuple of four per-epoch lists: (train_accs, train_losses, test_accs, test_losses).
    """
    history = {"train_acc": [], "train_loss": [], "test_acc": [], "test_loss": []}

    for epoch in range(n_epochs):
        epoch_started = time.time()

        train_loss, train_acc = run_epoch(model, optimizer, criterion, device, train_loader, "train")
        history["train_loss"].append(train_loss)
        history["train_acc"].append(train_acc)

        test_loss, test_acc = run_epoch(model, optimizer, criterion, device, test_loader, "test")
        history["test_loss"].append(test_loss)
        history["test_acc"].append(test_acc)

        should_report = epoch % print_every_n_epochs == (print_every_n_epochs - 1)
        if should_report:
            elapsed = time.time() - epoch_started
            print(f'Epoch: {epoch}\t'
                  f'Epoch time: {elapsed} --- '
                  f'Train loss: {train_loss:.4f}\t'
                  f'test loss: {test_loss:.4f}\t'
                  f'Train accuracy: {100 * train_acc:.2f}\t'
                  f'test accuracy: {100 * test_acc:.2f}')

    return (history["train_acc"], history["train_loss"], history["test_acc"], history["test_loss"])

## Define model initialization: $w \sim \mathcal{N}(0, 1/n)$ and $b = 0$, where $n$ is the layer's fan-in

In [6]:
def initialize(model):
  """Re-initialize the layers of `model` in place.

  - nn.Conv2d / nn.Linear: weights ~ N(0, 1/n), where n is the fan-in
    (in_channels/groups * kH * kW for conv, in_features for linear).
    Biases, when present, are set to 0.
  - nn.BatchNorm2d: weight set to 1 and bias to 0, when affine parameters exist.

  Other module types are left untouched. Layers without a bias (bias=False)
  or without affine parameters (affine=False) are handled instead of crashing.
  """
  for m in model.modules():
    if isinstance(m, (nn.Conv2d, nn.Linear)):
      # Every dim after the output dim feeds one output unit, so the product
      # of the remaining sizes is the fan-in for both conv and linear weights.
      n = m.weight.shape[1:].numel()
      std = 1/n**(1/2)
      nn.init.normal_(m.weight, std=std)
      if m.bias is not None:
        nn.init.constant_(m.bias, 0)
    elif isinstance(m, nn.BatchNorm2d):
      if m.weight is not None:
        nn.init.constant_(m.weight, 1)
      if m.bias is not None:
        nn.init.constant_(m.bias, 0)

## Define special varnet functions

In [7]:
def normal_distribution_moment(a):
  """Absolute moment E[|x|^a] of a standard normal x ~ N(0, 1).

  Closed form: 2^(a/2) * Gamma((a + 1) / 2) / sqrt(pi); e.g. a=2 -> 1, a=4 -> 3.
  The moment is finite for a > -1.
  """
  scale = 2 ** (a / 2)
  gamma_term = math.gamma((a + 1) / 2)
  return scale * gamma_term / math.pi ** 0.5

def power_series_var(terms):
  """Return sum_i sum_j c_i * c_j * M(a_i + a_j) * M(b_i + b_j).

  `terms` is a sequence of (a, b, c) triples and M is
  normal_distribution_moment. Assuming x, w are independent N(0, 1), this is
  E[f^2] for f = sum_k c_k * |x|^a_k * |w|^b_k. It is a second moment: no mean
  is subtracted, despite the name. Returns 0 for empty `terms`.
  """
  return sum(
      ci * cj * normal_distribution_moment(ai + aj) * normal_distribution_moment(bi + bj)
      for (ai, bi, ci) in terms
      for (aj, bj, cj) in terms
  )

def normalize_terms(terms):
  """Rescale every coefficient so power_series_var of the result equals 1.

  Each (a, b, c) triple becomes (a, b, c / sqrt(power_series_var(terms))).
  """
  std = power_series_var(terms) ** (1 / 2)
  normalized = []
  for a, b, c in terms:
    normalized.append((a, b, c / std))
  return normalized

In [37]:
# (a, b, c_ab) triples: coefficient c_ab on |x|^a * |w|^b, rescaled to unit
# second moment. NOTE(review): `terms` is not read by var_linear / var_conv2d
# below; it is presumably left over from an earlier power-series variant.
terms = normalize_terms([(1, 1, 1), (3, 3, 1)])

def var_linear(input, weights, bias=None, scale=1.5):
  """Drop-in replacement for F.linear used by the varnet experiments.

  Computes
      scale * ( g(x) @ W.T + (1/sqrt(n)) * tanh(relu(x)) @ tanh(sqrt(n) * W).T + bias )
  with g(x) = relu(x)^2 / (relu(x) + 1) and n = weights.size(1) (in_features).

  Args:
    input: activations whose last dim is in_features (as for F.linear).
    weights: (out_features, in_features) weight matrix.
    bias: optional bias, added exactly once. Passing it to both branches would
      add it a second time, scaled by 1/sqrt(n).
    scale: output gain; the default 1.5 matches the previous hard-coded value.
  """
  sqrt_n = weights.size()[1] ** .5
  pos = torch.relu(input)  # computed once, shared by both branches
  output = old_linear_def(pos ** 2 / (pos + 1), weights, bias)
  output = output + (1 / sqrt_n) * old_linear_def(torch.tanh(pos), torch.tanh(sqrt_n * weights))
  return output * scale

def var_conv2d(input, weights, bias=None, stride=(1,1), padding=(0,0), dilation=(1,1), groups=1, scale=1.5):
  """Drop-in replacement for F.conv2d used by the varnet experiments.

  An input whose per-sample shape is (3, 32, 32) (taken to be the first layer,
  i.e. raw images) gets an ordinary convolution. Every other input gets
      scale * ( conv(g(x), W) + (1/sqrt(n)) * conv(tanh(relu(x)), tanh(sqrt(n) * W)) + bias )
  with g(x) = relu(x)^2 / (relu(x) + 1) and n the fan-in,
  weights.size(1) * kH * kW.

  The bias is added exactly once. Passing it to both convolutions would add it
  a second time, scaled by 1/sqrt(n). `scale` is the output gain; the default
  1.5 matches the previous hard-coded value.

  NOTE(review): the first-layer check is shape-based, so any deeper layer that
  happens to see a (3, 32, 32) input would also take the plain path.
  """
  # if this is the first layer
  if input.size()[1:] == torch.Size([3, 32, 32]):
    return old_conv_def(input, weights, bias, stride, padding, dilation, groups)

  sqrt_n = weights.shape[1:].numel() ** .5
  pos = torch.relu(input)  # computed once, shared by both branches
  output = old_conv_def(pos ** 2 / (pos + 1), weights, bias, stride, padding, dilation, groups)
  output = output + (1 / sqrt_n) * old_conv_def(torch.tanh(pos), torch.tanh(sqrt_n * weights), None,
                                                stride, padding, dilation, groups)
  return output * scale

# Route nn.Linear / nn.Conv2d through the varnet versions above by
# monkeypatching torch.nn.functional. To restore stock behaviour instead,
# assign old_linear_def / old_conv_def back to F.linear / F.conv2d.
F.linear, F.conv2d = var_linear, var_conv2d

In [38]:
def relu(input: torch.Tensor, inplace: bool = False) -> torch.Tensor:
    """Out-of-place ReLU with F.relu's signature; `inplace` is accepted but ignored."""
    return torch.relu(input)

def identity(input: torch.Tensor, inplace: bool = False) -> torch.Tensor:
    """Return `input` itself, unchanged; `inplace` is accepted but ignored."""
    return input

# nn.ReLU (used inside torchvision's vgg19) dispatches to F.relu, so swapping
# F.relu for the identity turns those ReLU modules into no-ops.
F.relu = identity

In [39]:
# Smoke test: build a fresh VGG-19 (with whatever F.linear / F.conv2d / F.relu
# are currently patched in), push one training sample through it, and report
# CUDA memory usage.
model = torchvision.models.vgg19(num_classes=10, init_weights=True).to(DEVICE)
# initialize() re-initializes every Conv2d / BatchNorm2d / Linear layer,
# overriding the scheme requested by init_weights=True.
initialize(model)
output = model(sample_loader(train_loader, n_samples=1))
print(output)
# output.sum().backward()

# Total CUDA memory currently held by tensors (model, output, and anything
# else alive in the kernel), not only this forward pass.
mem_used = torch.cuda.memory_allocated()
print(mem_used/10**9)

# Dropping `output` releases the autograd graph it keeps alive (it has a
# grad_fn), presumably including the saved activations. The drop in allocated
# memory is reported as the cost of one forward pass. empty_cache() only
# returns cached blocks to the driver; it does not change memory_allocated().
del output
torch.cuda.empty_cache()
mem_used -= torch.cuda.memory_allocated()
print("one forward pass uses " + str(mem_used/10**9) + " GB; that's a lot")

tensor([[ 0.7108, -1.6850,  0.2018,  1.3550, -0.4648,  2.0402,  0.4250, -1.4768,
         -0.2197,  0.8804]], device='cuda:0', grad_fn=<MulBackward0>)
4.486951936
one forward pass uses 0.566934016 GB; that's a lot


## Set up and run

In [None]:
def run_tests(n_trials=3, n_epochs=15, print_every=10**7):
  """Train `n_trials` freshly initialized VGG-19 models and stack their histories.

  Each trial builds a new vgg19 with N_CLASSES outputs and re-initializes it
  with initialize(). It then trains the model with
  optimizer_class(lr=LEARNING_RATE, momentum=MOMENTUM) and cross-entropy loss.
  LEARNING_RATE is read at call time, so a cell may change it just before
  calling.

  Args:
    n_trials: number of independent training runs.
    n_epochs: epochs per run.
    print_every: print an epoch summary every `print_every` epochs (forwarded
      to train() as print_every_n_epochs).

  Returns:
    Tensor of shape (n_trials, 4, n_epochs) holding, per trial,
    (train_accs, train_losses, test_accs, test_losses).

  Side effect: rebinds the global `model` to the most recently trained network.
  """
  global model
  trajectories = []
  for _ in range(n_trials):
    model = torchvision.models.vgg19(num_classes=N_CLASSES, init_weights=True).to(DEVICE)
    initialize(model)
    optimizer = optimizer_class(model.parameters(), lr=LEARNING_RATE, momentum=MOMENTUM)
    criterion = nn.CrossEntropyLoss()

    traj = train(model, optimizer, criterion, DEVICE, train_loader, test_loader, n_epochs,
                 print_every_n_epochs=print_every)
    trajectories.append(traj)
  return torch.tensor(trajectories)

# Baseline (stock PyTorch ops), kept for reference. Uncomment to restore the
# saved originals and train an unmodified VGG-19.
# print("STANDARD")
# F.linear = old_linear_def
# F.conv2d = old_conv_def
# F.relu = old_relu_def
# LEARNING_RATE = .005
# additive_results = run_tests(n_trials=1, n_epochs=50, print_every=1)
# print(additive_results)

# Earlier variant, no longer implemented in this notebook:
# # WX + TANH(W)TANH(X)
# print("WX + TANH(W)TANH(X)")

# Current variant: var_linear / var_conv2d with the identity in place of ReLU.
# NOTE(review): the implemented second term applies tanh to relu(X), not to X
# as the printed label says; the label is kept as-is to match logged runs.
print("(1/sqrtN)*TANH(sqrtN*W)*TANH(X) + W*RELU(X)^2/(RELU(X) + 1)")
F.linear = var_linear
F.conv2d = var_conv2d
F.relu = identity
# Overrides the config-cell value; run_tests reads LEARNING_RATE at call time.
LEARNING_RATE = .005
var_results = run_tests(n_trials=1, n_epochs=100, print_every=1)
print(var_results)

(1/sqrtN)*TANH(sqrtN*W)*TANH(X) + W*RELU(X)^2/(RELU(X) + 1)
starting train		 cuda using 3.92001792 GB
starting test		 cuda using 5.041366528 GB
Epoch: 0	Epoch time: 110.54276156425476 --- Train loss: 2.3031	test loss: 2.3026	Train accuracy: 10.24	test accuracy: 10.00
starting train		 cuda using 5.041366528 GB
starting test		 cuda using 5.041366528 GB
Epoch: 1	Epoch time: 113.32418417930603 --- Train loss: 2.3027	test loss: 2.3012	Train accuracy: 10.17	test accuracy: 10.24
starting train		 cuda using 5.041366528 GB
starting test		 cuda using 5.041366528 GB
Epoch: 2	Epoch time: 114.49362063407898 --- Train loss: 2.0778	test loss: 1.8686	Train accuracy: 18.09	test accuracy: 25.44
starting train		 cuda using 5.041366528 GB
starting test		 cuda using 5.041366528 GB
Epoch: 3	Epoch time: 114.83872723579407 --- Train loss: 1.8010	test loss: 1.6589	Train accuracy: 27.44	test accuracy: 34.24
starting train		 cuda using 5.041366528 GB
starting test		 cuda using 5.041366528 GB
Epoch: 4	Epoch time:

In [None]:
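# Experiment log, kept as comments so this cell runs without a SyntaxError.
# Entry format: <layer form> with <first-layer form> | <learning rate> | <epochs> | <gain>
# followed by final accuracies (presumably train/test, in %).
# NOTE(review): <gain> is presumably the constant output multiplier in
# var_linear / var_conv2d (1.5 in the current code), and "linear start"
# presumably means a plain first convolution; confirm both.
#
# W*RELU(X) with linear start | .005 | 100e | 1
# 99.01/88.91
#
# W*RELU(X) with linear start | .001 | 100e | 1.4
# 97.88/87.58
#
# W*RELU(X) with ELU(X) start | .005 | 100e | 1
#
# W*ELU(X) with linear start | .001 | 100e | 1.1
# 99.10/85.67
#
# (1/sqrtN)*TANH(sqrtN*W)*TANH(X) + W*RELU(X) with linear start | .001 | 100e | .9
# 99.16/86.92
#
# W*ELU(X) - (1/sqrtN)*TANH(sqrtN*W)RELU(X) with linear start | .001 | 100e | 2
# 98.87/85.65
#
# W*TANH(X) with linear start | .001 | 100e | 1.1
# 98.80/85.04
#
# W*MIN(X, SIN(X)) with linear start | .0001 | 100e | 1.1
# 88.83/80.38
#
# W*TANH(RELU(X)) with linear start | .001 | 100e | 1.5
# 98.23/88.94
#
# WX + TANH(W)RELU(X) with linear start | .0005 | 100e | .7
# 98.13/84.53
#
# W*RELU(X)^2/(RELU(X) + 1) + (1/sqrtN)TANH(sqrtN*W)*TANH(RELU(X)) with linear start | .005 | 100e | 1.1
# running #1
#
# W*LEAKYRELU(X, .2) with linear start | .005 | 100e | 1.3
# running #3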