<a href="https://colab.research.google.com/github/james-simon/varnets/blob/master/multnets_cifar10.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

## Imports and constants

In [None]:
import os
import sys
import time
import math
import inspect
from datetime import datetime

import numpy as np

import torch
import torchvision
from torchvision import datasets, transforms
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torch.nn.modules.utils import _single, _pair, _triple, _reverse_repeat_tuple

import matplotlib.pyplot as plt

# Select the compute device: the GPU when CUDA is available, otherwise the CPU.
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'

# One-time "latch": the first time this cell runs in a kernel, `latch` is
# undefined, so the NameError branch saves the current F.linear / F.conv2d.
# On re-runs `latch` already exists and the saved references are left alone,
# so they still point at the originals even after F.linear / F.conv2d have
# been reassigned further down in the notebook.
try:
    latch
except NameError:
  old_linear_def = F.linear
  old_conv_def = F.conv2d
  print('latched')
latch = True

latched


In [None]:
RANDOM_SEED = 42  # only referenced from commented-out code in the visible notebook
LEARNING_RATE = 0.00001  # only referenced from commented-out Adam optimizer lines
BATCH_SIZE = 128  # batch size for both the train and the test DataLoader
N_EPOCHS = 15  # not referenced elsewhere in the visible notebook
N_WORKERS = 4  # num_workers for both DataLoaders

IMG_SIZE = 32  # resize target used by the MNIST transforms
N_CLASSES = 10  # mult_linear compares output shapes against this value

dataset_name = "CIFAR10"  # name of the torchvision.datasets class to load
data_path = ""  # base directory; data goes in <data_path>/<dataset_name.lower()>

architecture = "VGG"  # "ALEXNET" selects torchvision's AlexNet; anything else selects VGGBase

In [None]:
def count_parameters(model):
    """Return the total number of scalar entries in the model's trainable parameters.

    Only parameters with ``requires_grad`` set are counted.
    """
    total = 0
    for param in model.parameters():
        if param.requires_grad:
            total += param.numel()
    return total

def all_parameters(model):
  """Return all trainable parameters of ``model`` as one flat 1-D tensor.

  Each parameter with ``requires_grad`` set is flattened before concatenation,
  so parameters of different shapes (e.g. a 2-D weight and a 1-D bias) can be
  combined; concatenating them unflattened raises a shape error in torch.cat.
  Parameters appear in ``model.parameters()`` order.

  Returns:
    A 1-D tensor; an empty tensor when the model has no trainable parameters
    (torch.cat rejects an empty list).
  """
  flat_params = [p.reshape(-1) for p in model.parameters() if p.requires_grad]
  if not flat_params:
    return torch.empty(0)
  return torch.cat(flat_params)
  
def identity(x):
  """Pass-through: hand the argument back unchanged."""
  return x

# Configuration hooks. None of these names is referenced again in the visible
# part of the notebook.
include_scale_factor = False

# Each hook currently defaults to a pass-through (or a plain product).
pre_weight_nonlinearity = identity
post_weight_nonlinearity = identity
combination_function = lambda x,y: x*y  # product of its two arguments
post_sum_nonlinearity = identity
neuron_nonlinearity = nn.Identity  # note: the class itself, not an instance

net_start_nonlinearity = identity
net_end_nonlinearity = identity

## Dataset setup

In [None]:
# Preprocessing pipelines per dataset / architecture. Both stay None for any
# dataset name that is not handled below.
transform_set_train = None
transform_set_test = None

if dataset_name == "MNIST":
  # Resize to IMG_SIZE x IMG_SIZE and convert to a tensor; no augmentation.
  transform_set_train = transforms.Compose([transforms.Resize((IMG_SIZE, IMG_SIZE)),
                                 transforms.ToTensor()])
  transform_set_test = transforms.Compose([transforms.Resize((IMG_SIZE, IMG_SIZE)),
                                 transforms.ToTensor()])
elif dataset_name == "CIFAR10":
  if architecture == "ALEXNET":
    # transforms.Resize replaces the removed transforms.Scale alias.
    # Augment at the native 32x32 resolution first and upscale afterwards, so
    # the training images reach the network at 224x224 exactly like the test
    # images (cropping to 32 after the upscale would feed 32x32 patches).
    transform_set_train = transforms.Compose([transforms.RandomHorizontalFlip(),
                                    transforms.RandomCrop(32, padding=4),
                                    transforms.Resize((224, 224)),
                                    transforms.ToTensor(),
                                    transforms.Normalize((0.4914, 0.4822, 0.4465),
                                                          (0.2023, 0.1994, 0.2010))])
    transform_set_test = transforms.Compose([transforms.Resize((224, 224)),
                                   transforms.ToTensor(),
                                   transforms.Normalize((0.4914, 0.4822, 0.4465),
                                                        (0.2023, 0.1994, 0.2010))])
  else:
    # Random flip + padded random crop for training; test images are only
    # converted and normalized.
    # NOTE(review): this branch normalizes with different mean/std values than
    # the ALEXNET branch above — confirm that is intentional.
    transform_set_train = transforms.Compose([
                    transforms.RandomHorizontalFlip(),
                    transforms.RandomCrop(32, padding=4),
                    transforms.ToTensor(),
                    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
                ])
    transform_set_test = transforms.Compose([
                    transforms.ToTensor(),
                    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
                ])

# Look up the dataset class by name on torchvision.datasets (e.g. "CIFAR10").
ds = getattr(torchvision.datasets, dataset_name)
# Data directory: <data_path>/<lower-cased dataset name>.
dataset_path = os.path.join(data_path, dataset_name.lower())

# Both splits are requested with download=True and get their own transform.
train_set = ds(dataset_path, train=True, download=True, transform=transform_set_train)
test_set = ds(dataset_path, train=False, download=True, transform=transform_set_test)

# Training batches are shuffled; test batches keep the dataset order.
train_loader = DataLoader(dataset=train_set, 
                          batch_size=BATCH_SIZE, 
                          shuffle=True,
                          num_workers=N_WORKERS)
test_loader = DataLoader(dataset=test_set, 
                          batch_size=BATCH_SIZE, 
                          shuffle=False,
                          num_workers=N_WORKERS)

def sample_loader(loader, n_samples=1):
  """Return up to ``n_samples`` inputs from the first batch of ``loader``.

  Labels are discarded and the inputs are moved to ``DEVICE``. If the loader
  yields no batches the result is None.
  """
  first_batch = next(iter(loader), None)
  if first_batch is None:
    return None
  inputs, _ = first_batch
  limit = min(n_samples, inputs.size()[0])
  return inputs[0:limit].to(DEVICE)

Downloading https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz to cifar10/cifar-10-python.tar.gz


HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))

Extracting cifar10/cifar-10-python.tar.gz to cifar10
Files already downloaded and verified


## Train and test methods

In [None]:
def train_epoch(train_loader, model, criterion, optimizer, device):
    """Run one optimisation pass over ``train_loader``.

    Puts the model in training mode and takes one optimizer step per batch.

    Returns:
        (model, optimizer, epoch_loss, epoch_acc), where the loss is the
        batch-size-weighted sum of batch losses and the accuracy the number of
        correct argmax predictions, both divided by ``len(train_loader.dataset)``.
    """
    model.train()
    loss_total = 0
    correct_total = 0

    for inputs, targets in train_loader:
        optimizer.zero_grad()

        inputs = inputs.to(device)
        targets = targets.to(device)

        logits = model(inputs)
        loss = criterion(logits, targets)
        loss.backward()
        optimizer.step()

        # criterion is assumed to return a per-batch mean, hence the re-weighting
        # by batch size before averaging over the dataset.
        loss_total += loss.item() * inputs.size(0)

        predicted = torch.max(logits, 1)[1]
        correct_total += (predicted == targets).sum().item()

    n_examples = len(train_loader.dataset)
    return model, optimizer, loss_total / n_examples, correct_total / n_examples

def test(test_loader, model, criterion, device):

  model.eval()
  running_loss = 0
  running_acc = 0
  
  for X, y_true in test_loader:
  
      X = X.to(device)
      y_true = y_true.to(device)

      y_hat = model(X) 
      loss = criterion(y_hat, y_true) 
      running_loss += loss.item() * X.size(0)

      _, predicted_labels = torch.max(y_hat, 1)
      batch_acc = (predicted_labels == y_true).sum().item()
      running_acc += batch_acc

  epoch_loss = running_loss / len(test_loader.dataset)
  epoch_acc = running_acc / len(test_loader.dataset)
      
  return model, epoch_loss, epoch_acc

def training_loop(model, criterion, optimizer, train_loader, test_loader, epochs, device, print_every=1):
    """Alternate one training epoch and one evaluation pass for ``epochs`` epochs.

    A progress line (wall time of the epoch, losses and accuracies) is printed
    whenever ``epoch % print_every == print_every - 1``.

    Returns:
        (model, optimizer, (train_accs, train_losses, test_accs, test_losses)),
        each of the four lists holding one value per epoch.
    """
    acc_train, loss_train, acc_test, loss_test = [], [], [], []

    for epoch in range(epochs):
        epoch_start = time.time()

        # training pass
        model, optimizer, tr_loss, tr_acc = train_epoch(train_loader, model, criterion, optimizer, device)
        loss_train.append(tr_loss)
        acc_train.append(tr_acc)

        # evaluation pass, without gradient tracking
        with torch.no_grad():
            model, te_loss, te_acc = test(test_loader, model, criterion, device)
            loss_test.append(te_loss)
            acc_test.append(te_acc)

        should_report = epoch % print_every == (print_every - 1)
        if should_report:
            print(f'{time.time()-epoch_start} --- '
                  f'Epoch: {epoch}\t'
                  f'Train loss: {tr_loss:.4f}\t'
                  f'test loss: {te_loss:.4f}\t'
                  f'Train accuracy: {100 * tr_acc:.2f}\t'
                  f'test accuracy: {100 * te_acc:.2f}')

    return model, optimizer, (acc_train, loss_train, acc_test, loss_test)

## Model definition

In [None]:
# Scale factor used by mult_linear / mult_conv2d: values are multiplied by Q
# after the log and by 1/Q before the exp.
Q = 10**9
# Debug caches, overwritten on every call: [input, call arguments, output].
cache_lin = [None, None, None]  # filled by mult_linear
cache_mult = [None, None, None]  # filled by mult_conv2d

def mult_linear(input, weights, bias=None):
  """Replacement for ``F.linear`` that applies the linear map in log space.

  Computes ``old_linear_def(Q*log(input), weights, bias)``. Unless the
  result's trailing shape equals ``[N_CLASSES]`` it is mapped back with
  ``exp((1/Q)*result)``; an output of that shape is taken to be the final
  classifier layer and is returned without the exponential.

  The input, the ``(weights, bias)`` pair and the output are stored in
  ``cache_lin`` for post-mortem inspection. The output is cached before the
  sanity checks run, so a failing tensor can be examined afterwards.

  Raises:
    FloatingPointError: if the input contains a zero (its log would be
      -inf), or if the output contains a NaN, an infinity or a zero. This
      replaces the earlier ``1 + '1'`` trick, which aborted with an unrelated
      TypeError.
  """
  # The list is mutated in place, so no `global` declaration is needed.
  cache_lin[0] = input
  cache_lin[1] = (weights, bias)

  if (input == 0).any():
    raise FloatingPointError('zero found in mult_linear input')

  output = old_linear_def(Q*torch.log(input), weights, bias)
  if output.size()[1:] != torch.Size([N_CLASSES]):
    output = torch.exp((1/Q)*output) + 0

  cache_lin[2] = output

  if output.isnan().any():
    raise FloatingPointError("nan found in mult_linear")

  if output.isinf().any():
    raise FloatingPointError("inf found in mult_linear")

  if (output == 0).any():
    raise FloatingPointError("zero found in mult_linear")

  return output

def mult_conv2d(input, weights, bias=None, stride=(1,1), padding=(0,0), dilation=(1,1), groups=1):
  """Replacement for ``F.conv2d`` that applies the convolution in log space.

  The input is replaced by ``Q*log(input)`` unless its trailing shape is
  ``[3, 32, 32]``; an input of that shape is taken to be the raw image batch
  entering the network and is convolved as-is. The convolution result is
  always mapped back with ``exp((1/Q)*result)``.

  The original input, the remaining call arguments and the output are stored
  in ``cache_mult`` for post-mortem inspection. The output is cached before
  the sanity checks run, so a failing tensor can be examined afterwards.

  Raises:
    FloatingPointError: if the output contains a NaN or a zero. This
      replaces the earlier ``1 + '1'`` trick, which aborted with an unrelated
      TypeError.
  """
  # The list is mutated in place, so no `global` declaration is needed.
  cache_mult[0] = input
  cache_mult[1] = (weights, bias, stride, padding, dilation, groups)

  if input.size()[1:] != torch.Size([3,32,32]):
    input = Q*torch.log(input)
  output = torch.exp((1/Q)*old_conv_def(input, weights, bias, stride, padding, dilation, groups)) + 0

  cache_mult[2] = output

  # NOTE(review): unlike mult_linear there is no check for infinities here.
  if output.isnan().any():
    raise FloatingPointError("nan found in mult_conv2d")

  if (output == 0).any():
    raise FloatingPointError("zero found in mult_conv2d")

  return output

# Monkeypatch torch.nn.functional so that lookups of F.linear / F.conv2d
# resolve to the log-space versions defined above.
# To restore the originals saved by the latch in the first cell, use instead:
# F.linear = old_linear_def
# F.conv2d = old_conv_def
F.linear = mult_linear
F.conv2d = mult_conv2d

In [None]:
## test plan
#  for different values of Q, check what the typical highest intermediate value is in mult_conv2d

In [None]:
# input = torch.log(cache_mult[0])
# print(input.size())
# print(cache_mult[1][0].size())

# # print(torch.max(torch.abs(input)))
# output = old_conv_def(input, cache_mult[1][0], None, cache_mult[1][2], cache_mult[1][3], cache_mult[1][4], cache_mult[1][5])
# print(output.size())

# print(torch.max(torch.abs(input)))
# print(torch.max(torch.abs(output)))
# print(output[0][0][0])
# print(torch.std(cache_mult[1][0]))
# print(torch.std(input))
# print(torch.std(output))

In [None]:
class Nonlinearity(nn.Module):
    """Activation module used by the model below; currently a ReLU.

    ``inplace`` is forwarded to ``F.relu``. The ReLU is applied to ``input - 0``,
    a freshly created tensor, not to ``input`` itself.
    """
    __constants__ = ['inplace']
    inplace: bool

    def __init__(self, inplace: bool = False):
        super().__init__()
        self.inplace = inplace

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        # The "- 0" / "0 +" terms are kept from the original; the constants are
        # presumably placeholders for an adjustable shift.
        shifted = input - 0
        rectified = F.relu(shifted, inplace=self.inplace)
        return 0 + rectified

    def extra_repr(self) -> str:
        if self.inplace:
            return 'inplace=True'
        return ''


# Dropout probability used in the VGGBase classifier head.
DROPOUT_PROB = 0

# Convolution widths keyed by VGG depth: each inner list is one block of
# convolutions, and make_layers appends a max-pool after every block.
config = {
    16: [[64, 64], [128, 128], [256, 256, 256], [512, 512, 512], [512, 512, 512]],
    19: [[64, 64], [128, 128], [256, 256, 256, 256], [512, 512, 512, 512], [512, 512, 512, 512]],
}


def make_layers(config, batch_norm=False, fix_points=None):
    """Build the convolutional part of a VGG network.

    Args:
        config: list of blocks, each a list of output-channel counts.
        batch_norm: when true, a batch-norm layer is appended to the block's
            layer list right after each convolution.
        fix_points: when not None, layers come from ``curves`` and receive
            ``fix_points`` as a keyword argument.
            NOTE(review): ``curves`` is not imported in the visible notebook.

    Returns:
        (layer_blocks, activation_blocks, poolings): three ModuleLists with
        one entry per block; the first two hold a nested ModuleList per block.
    """
    if fix_points is None:
        conv_cls, bn_cls, extra = nn.Conv2d, nn.BatchNorm2d, {}
    else:
        conv_cls, bn_cls, extra = curves.Conv2d, curves.BatchNorm2d, {'fix_points': fix_points}

    layer_blocks = nn.ModuleList()
    activation_blocks = nn.ModuleList()
    poolings = nn.ModuleList()

    prev_channels = 3  # the network input has three channels
    for widths in config:
        block_layers = nn.ModuleList()
        block_activations = nn.ModuleList()
        for width in widths:
            block_layers.append(conv_cls(prev_channels, width, kernel_size=3, padding=1, **extra))
            if batch_norm:
                block_layers.append(bn_cls(width, **extra))
            block_activations.append(Nonlinearity(inplace=True))
            prev_channels = width
        layer_blocks.append(block_layers)
        activation_blocks.append(block_activations)
        poolings.append(nn.MaxPool2d(kernel_size=2, stride=2))
    return layer_blocks, activation_blocks, poolings

class VGGBase(nn.Module):
    """VGG-style network: blocks of convolutions from ``make_layers`` plus a classifier head.

    Args:
        num_classes: width of the final linear layer.
        depth: key into the module-level ``config`` dict (16 or 19).
        batch_norm: forwarded to ``make_layers``; when true each convolution
            is followed by a batch-norm layer.
    """

    def __init__(self, num_classes, depth=16, batch_norm=False):
        super(VGGBase, self).__init__()
        layer_blocks, activation_blocks, poolings = make_layers(config[depth], batch_norm)
        self.layer_blocks = layer_blocks
        self.activation_blocks = activation_blocks
        self.poolings = poolings

        self.classifier = nn.Sequential(
            nn.Dropout(p=DROPOUT_PROB),
            nn.Linear(512, 513),
            Nonlinearity(inplace=True),
            nn.Dropout(p=DROPOUT_PROB),
            nn.Linear(513, 514),
            Nonlinearity(inplace=True),
            nn.Linear(514, num_classes),
        )

        # Re-initialise every convolution: zero-mean normal weights with
        # std sqrt(2 / (kernel_h * kernel_w * out_channels)) and zero biases.
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
                m.weight.data.normal_(0, math.sqrt(2. / n))
                m.bias.data.zero_()

    def forward(self, x):
        """Apply all blocks, flatten, then run the classifier head."""
        for layers, activations, pooling in zip(self.layer_blocks, self.activation_blocks,
                                                self.poolings):
            # make_layers stores one layer per activation ([conv]) or two
            # ([conv, bn]) when batch_norm is on. Zipping the two lists
            # directly would stop at the shorter one and silently skip layers
            # in the batch-norm case, so group the layers per activation.
            layer_list = list(layers)
            per_activation = len(layer_list) // max(len(activations), 1)
            for idx, activation in enumerate(activations):
                for layer in layer_list[idx * per_activation:(idx + 1) * per_activation]:
                    x = layer(x)
                x = activation(x)
            x = pooling(x)
        x = x.view(x.size(0), -1)
        x = self.classifier(x)
        return x

# Smoke test: build the model once to check the definition; the trailing ';'
# suppresses the cell's repr output.
VGGBase(num_classes=10, depth=16, batch_norm=False);

## Timing tests

In [None]:
def new_model():
  """Create a fresh 10-class model on ``DEVICE``.

  Returns torchvision's AlexNet when ``architecture`` is "ALEXNET", and a
  depth-16 ``VGGBase`` without batch norm otherwise.
  """
  if architecture == "ALEXNET":
    return torchvision.models.alexnet(num_classes=10).to(DEVICE)
  return VGGBase(num_classes=10, depth=16, batch_norm=False).to(DEVICE)

# Build a model and push a 5-image batch through it as a smoke test.
model = new_model()
x = sample_loader(train_loader, 5)
print(model(x))

# Anomaly detection adds extra checks to autograd; it stays enabled for the
# rest of the session, including the passes timed below.
torch.autograd.set_detect_anomaly(True)

model.train()

# CUDA kernels are launched asynchronously, so wait for the device before each
# clock reading; otherwise the intervals mostly measure launch overhead.
if DEVICE == 'cuda':
  torch.cuda.synchronize()
t0 = time.time()
model(x)
if DEVICE == 'cuda':
  torch.cuda.synchronize()
t1 = time.time()
model(x).sum().backward()
if DEVICE == 'cuda':
  torch.cuda.synchronize()
t2 = time.time()
# forward-only time, then forward + backward time (seconds)
print(t1-t0, t2-t1)

tensor([[-0.6970, -0.8308, -0.1197,  0.6774,  0.9251, -0.5817, -0.4519, -0.0798,
         -0.1254,  0.0082],
        [-0.5738, -1.3904, -0.0795,  0.7063,  1.4924, -0.6672, -0.4227, -0.4048,
          0.0498, -0.0287],
        [-0.8445, -1.1013,  0.0061,  0.6581,  1.4424, -0.7201, -0.3603, -0.1992,
         -0.0254,  0.0642],
        [-0.3861, -0.5445,  0.0092,  0.5160,  0.7854, -0.3774, -0.4700, -0.1882,
         -0.0198,  0.2068],
        [-0.8695, -1.3404, -0.1985,  1.0008,  1.4824, -0.6239, -0.7236, -0.3206,
         -0.1788,  0.1319]], device='cuda:0', grad_fn=<AddmmBackward>)
0.04085278511047363 0.06316518783569336


## Set up and run

In [None]:
def leaky_relu(x):
  """Leaky ReLU with slope 0.5 on the negative side: x for x > 0, 0.5*x for x < 0."""
  positive_part = torch.relu(x)
  negative_part = torch.relu(-x)
  return positive_part - .5*negative_part

def run_tests(n_trials=3, n_epochs=15, print_every=10**7):
  """Train ``n_trials`` fresh models and collect their metric histories.

  Each trial builds a new model, trains it with SGD (lr 0.05, momentum 0.9)
  and cross-entropy loss for ``n_epochs`` epochs on the global train/test
  loaders. The global ``model`` is rebound to the most recent trial's model.

  Returns:
    A tensor built from the list of per-trial
    (train_accs, train_losses, test_accs, test_losses) tuples.
  """
  global model
  all_trials = []
  for _ in range(n_trials):
    # No seed is set per trial (a RANDOM_SEED-based manual_seed call was
    # disabled in the original), so trials are not reproducible.
    model = new_model()
    sgd = torch.optim.SGD(model.parameters(), lr=.05, momentum=.9)
    loss_fn = nn.CrossEntropyLoss()
    model, sgd, history = training_loop(
        model, loss_fn, sgd, train_loader, test_loader, n_epochs, DEVICE,
        print_every=print_every)
    all_trials.append(history)
  return torch.tensor(all_trials)

# ADDITIVE baseline (disabled): it would restore the original F.linear /
# F.conv2d and train with them. NOTE: the plotting cell further down reads
# `additive_results`, which is only created if this section is re-enabled.
# # ADDITIVE
# print("ADDITIVE")
# F.linear = old_linear_def
# F.conv2d = old_conv_def
# additive_results = run_tests(n_trials=1, n_epochs=150, print_every=1)
# print(additive_results)

# MULTIPLICATIVE: patch in the log-space ops and train one trial for 50 epochs.
print("\n\nMULTIPLICATIVE")
F.linear = mult_linear
F.conv2d = mult_conv2d
multiplicative_results = run_tests(n_trials=1, n_epochs=50, print_every=1)
print(multiplicative_results)



MULTIPLICATIVE
zero found in mult_conv2d


TypeError: ignored

In [None]:
# Post-mortem of the last mult_linear call: replay the plain linear op on the
# cached input, weights and bias.
# NOTE(review): cache_lin is still [None, None, None] if mult_linear has not
# run yet, in which case this line fails.
output = old_linear_def(cache_lin[0], cache_lin[1][0], cache_lin[1][1])
# (cache_lin[0].isinf()).sum()
# Count infinite entries in the input of the last mult_conv2d call.
cache_mult[0].isinf().sum()

# optimizer = torch.optim.Adam(model.parameters(), lr=10*LEARNING_RATE)
# criterion = nn.CrossEntropyLoss()
# model, optimizer, traj = training_loop(model, criterion, optimizer, train_loader, test_loader, 3, DEVICE, print_every=1)

In [None]:
# Accuracy curves for the first trial of each run: index [0][0] is that trial's
# train-accuracy list and [0][2] its test-accuracy list.
# A run whose results do not exist (e.g. the additive run is commented out
# above) is skipped instead of aborting the cell with a NameError.
fig, ax = plt.subplots()

if 'additive_results' in globals():
  line_a_tr, = ax.plot(additive_results[0][0], color=(0,0,1), label='additive train')
  line_a_te, = ax.plot(additive_results[0][2], color=(.5,.5,1), label='additive test')
  # line_a_te.set_dashes([1,1])
else:
  print('additive_results is not defined; skipping the additive curves')

if 'multiplicative_results' in globals():
  line_m_tr, = ax.plot(multiplicative_results[0][0], color=(1,0,0), label='multiplicative train')
  line_m_te, = ax.plot(multiplicative_results[0][2], color=(1,.5,.5), label='multiplicative test')
else:
  print('multiplicative_results is not defined; skipping the multiplicative curves')

ax.set_xlabel('epoch')
ax.set_ylabel('accuracy (fraction correct)')
if ax.get_lines():
  ax.legend()

In [None]:
# NOTE(review): `cache` is not defined anywhere in the visible notebook (only
# cache_lin and cache_mult are) — presumably a leftover from an earlier
# version; on a fresh kernel this cell fails with NameError. TODO confirm.
# Prints the sizes of the three cache entries, counts non-positive entries in
# entries 0 and 2, and counts infinities after taking the log of entry 0.
print(cache[0].size())
print(cache[1].size())
print(cache[2].size())
print((cache[0] <= 0).sum())
print((cache[2] <= 0).sum())
print(torch.log(cache[0]).isinf().sum())
# Count NaNs produced by the plain linear op applied to log(cache[0]).
old_linear_def(torch.log(cache[0]), cache[1]).isnan().sum()

In [None]:
# Print the (row, column) index of every infinite entry of log(cache[0]).
# NOTE(review): `cache` is not defined in the visible notebook — see the
# cache_lin / cache_mult lists instead. The double loop assumes cache[0] is
# 2-D — TODO confirm.
q = torch.log(cache[0])
for i in range(len(q)):
  for j in range(len(q[i])):
    if q[i][j].isinf():
      print(i,j)

In [None]:
# Run up to 128 training images from the first batch through the current
# global `model` and count NaN entries in its output.
xs = sample_loader(train_loader, n_samples=128)
model(xs).isnan().sum()

In [None]:
# Show the infinite entries of cache[0].
# NOTE(review): `cache` is not defined in the visible notebook; only
# cache_lin and cache_mult are. TODO confirm which one was meant.
cache[0][cache[0].isinf()]

In [None]:
# Scan the training set for the first batch whose model output contains an
# infinite value; `x` keeps that (inputs, labels) batch, or stays None.
x = None
for sample in train_loader:
  # Use DEVICE rather than a hard-coded .cuda() so the cell also runs on CPU.
  if model(sample[0].to(DEVICE)).isinf().sum() > 0:
    x = sample
    # The check above is for infinities, so report an inf (not a nan).
    print('found an inf')
    break

In [None]:
# Replay the last mult_conv2d call from its cached arguments.
# NOTE(review): `input` (which also shadows the builtin) is computed but never
# used — the convolution below is fed cache_mult[0] directly. Possibly
# old_conv_def(input, ...) was intended; confirm before relying on this cell.
input = torch.log(cache_mult[0])
output = torch.exp(old_conv_def(cache_mult[0], *cache_mult[1])) + 0
# Count infinite entries in the replayed output.
output.isinf().sum()

In [None]:
# Largest value in the input of the last mult_conv2d call. The call
# parentheses are required: without them the cell only displays the bound
# method object.
cache_mult[0].max()