In [1]:
import numpy as np
import pandas as pd
import random
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable
import torch.optim as optim
import torchvision
from torchvision import datasets, models, transforms, utils
from torch.utils.data import Dataset, DataLoader
import matplotlib.pyplot as plt
import time
import os
import copy
import math

print("PyTorch Version:", torch.__version__)
print("Torchvision Version:", torchvision.__version__)
print("GPU is available?", torch.cuda.is_available())

PyTorch Version: 1.11.0+cu113
Torchvision Version: 0.12.0+cu113
GPU is available? True


In [2]:
dtype = torch.float
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Imported datasets
For the testing and comparison of our algorithms we will use the following datasets:

1. MNIST

In [3]:
ts = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0,), (1,))])
mnist_trainset = datasets.MNIST('../data', train=True, download=True, transform=ts)
mnist_testset = datasets.MNIST(root='../data', train=False, download=True, transform=ts)

Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz to ../data/MNIST/raw/train-images-idx3-ubyte.gz


  0%|          | 0/9912422 [00:00<?, ?it/s]

Extracting ../data/MNIST/raw/train-images-idx3-ubyte.gz to ../data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz to ../data/MNIST/raw/train-labels-idx1-ubyte.gz


  0%|          | 0/28881 [00:00<?, ?it/s]

Extracting ../data/MNIST/raw/train-labels-idx1-ubyte.gz to ../data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz to ../data/MNIST/raw/t10k-images-idx3-ubyte.gz


  0%|          | 0/1648877 [00:00<?, ?it/s]

Extracting ../data/MNIST/raw/t10k-images-idx3-ubyte.gz to ../data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz to ../data/MNIST/raw/t10k-labels-idx1-ubyte.gz


  0%|          | 0/4542 [00:00<?, ?it/s]

Extracting ../data/MNIST/raw/t10k-labels-idx1-ubyte.gz to ../data/MNIST/raw



# Train - test split

Code taken from https://github.com/timlautk/BCD-for-DNNs-PyTorch/blob/master/bcd_dnn_mlp_mnist.ipynb

In [4]:
#train-set initialization
x_d0 = mnist_trainset[0][0].size()[0]
x_d1 = mnist_trainset[0][0].size()[1]
x_d2 = mnist_trainset[0][0].size()[2]
N = x_d3 = len(mnist_trainset)
K = 10
x_train = torch.empty((N,x_d0*x_d1*x_d2), device=device)
y_train = torch.empty(N, dtype=torch.long)
for i in range(N): 
    x_train[i,:] = torch.reshape(mnist_trainset[i][0], (1, x_d0*x_d1*x_d2))
    y_train[i] = mnist_trainset[i][1]
x_train = torch.t(x_train)
#y_one_hot = torch.zeros(N, K).scatter_(1, torch.reshape(y_train, (N, 1)), 1)
#y_one_hot = torch.t(y_one_hot).to(device=device)
y_train = y_train.to(device=device)

#test-set initialization
N_test = x_d3_test = len(mnist_testset)
x_test = torch.empty((N_test,x_d0*x_d1*x_d2), device=device)
y_test = torch.empty(N_test, dtype=torch.long)
for i in range(N_test): 
    x_test[i,:] = torch.reshape(mnist_testset[i][0], (1, x_d0*x_d1*x_d2))
    y_test[i] = mnist_testset[i][1]
x_test = torch.t(x_test)
#y_test_one_hot = torch.zeros(N_test, K).scatter_(1, torch.reshape(y_test, (N_test, 1)), 1)
#y_test_one_hot = torch.t(y_test_one_hot).to(device=device)
y_test = y_test.to(device=device)

# Architecture initialization

For the MultiLayerPerceptron we have the parameters **input_size** , **hidden_size**,**output_size** corresponding to the size of the input layer, the hidden layer and the output layer, respectively.

The MLP only has 3 layers like https://github.com/timlautk/BCD-for-DNNs-PyTorch/blob/master/bcd_dnn_mlp_mnist.ipynb as a starting point.

Also we use ReLU currently for the same reason.

In [5]:
input_size = 28*28
hidden_size = 1500
output_size = 10

In [6]:
torch.manual_seed(32)
d0 = input_size
d1 = d2 = hidden_size
d3 = output_size 

# The layers are: input + 2 hidden + output

# we represent the weigths of each layer as matrices with d_{i-1} columns and d_{i} rows

# Weight initialization (we replicate pytorch initialization)
std_1 = math.sqrt(1/d0)
W1 = torch.FloatTensor(d1, d0).uniform_(-std_1, std_1)
b1 = torch.FloatTensor(d1, 1).uniform_(-std_1, std_1)

# we move them to GPU
b1 = b1.to('cuda:0')
W1 = W1.to('cuda:0')


std_2 = math.sqrt(1/d1)
W2 = torch.FloatTensor(d2, d1).uniform_(-std_2, std_2)
b2 = torch.FloatTensor(d2, 1).uniform_(-std_2, std_2)

# we move them to GPU
b2 = b2.to('cuda:0')
W2 = W2.to('cuda:0')

std_3 = math.sqrt(1/d2)
W3 = torch.FloatTensor(d3, d2).uniform_(-std_3, std_3)
b3 = torch.FloatTensor(d3, 1).uniform_(-std_3, std_3)

# we move them to GPU
b3 = b3.to('cuda:0')
W3 = W3.to('cuda:0')


U1 = torch.addmm(b1.repeat(1, N), W1, x_train) # equivalent to W1@x_train+b1.repeat(1,N)
V1 = nn.ReLU()(U1)
U2 = torch.addmm(b2.repeat(1, N), W2, V1)
V2 = nn.ReLU()(U2)
U3 = torch.addmm(b3.repeat(1, N), W3, V2)
V3 = U3

# constant initializations
gamma = 1
gamma1 = gamma2 = gamma3 = gamma4 = gamma

rho = gamma
rho1 = rho2 = rho3 = rho4 = rho


alpha = 5
alpha1 = alpha2 = alpha3 = alpha4 = alpha5 = alpha6 = alpha7 \
= alpha8 = alpha9 = alpha10 = alpha

# initialization of the vectors of losses and accuracies
niter = 500
loss1 = np.empty(niter)
loss2 = np.empty(niter)
accuracy_train = np.empty(niter)
accuracy_test = np.empty(niter)
time1 = np.empty(niter)

In [7]:
def updateV_js(U1,U2,W,b,rho,gamma): 
    _, d = W.size()
    I = torch.eye(d, device=device)
    U1 = nn.ReLU()(U1)
    _, col_U2 = U2.size()
    Vstar = torch.mm(torch.inverse(rho*(torch.mm(torch.t(W),W))+gamma*I), rho*torch.mm(torch.t(W),U2-b.repeat(1,col_U2))+gamma*U1)
    return Vstar

In [8]:
def updateWb_js(U, V, W, b, alpha, rho): 
    d,N = V.size()
    I = torch.eye(d, device=device)
    _, col_U = U.size()
    Wstar = torch.mm(alpha*W+rho*torch.mm(U-b.repeat(1,col_U),torch.t(V)),torch.inverse(alpha*I+rho*(torch.mm(V,torch.t(V)))))
    bstar = (alpha*b+rho*torch.sum(U-torch.mm(W,V), dim=1).reshape(b.size()))/(rho*N+alpha)
    return Wstar, bstar

In [9]:
def relu_prox(a, b, gamma, d, N):
    val = torch.empty(d,N, device=device)
    x = (a+gamma*b)/(1+gamma)
    y = torch.min(b,torch.zeros(d,N, device=device))
    # torch.zeros(d,N, device=device)
    val = torch.where(a+gamma*b < 0, y, torch.zeros(d,N, device=device))
    val = torch.where(((a+gamma*b >= 0) & (b >=0)) | ((a*(gamma-np.sqrt(gamma*(gamma+1))) <= gamma*b) & (b < 0)), x, val)
    val = torch.where((-a <= gamma*b) & (gamma*b <= a*(gamma-np.sqrt(gamma*(gamma+1)))), b, val)
    return val

In [10]:
def flatten(t):
    return [item for sublist in t for item in sublist]

def random_select(min,max,nodes,dims):
  """
  This function generates a set of indexes to chose which nodes of the layer we will update
  :param min: the minimum index from which the nodes will be chosen
  :param max: the maximum index from which the nodes will be chosen
  :param nodes: the number of nodes that should be chosen
  :param dims: the second dimension of the layer 
  :return layers,res: two lists that index together randomly selected elements in the layers
  """
  res = []
  layers = []
  for i in range(dims):
    res.append(random.sample(range(min,max),nodes))
    layers.append([i]*nodes)
  return flatten(layers),flatten(res)

def rand(min,max,nodes):
   return random.sample(range(min,max),nodes)

In [11]:
nodes = 1000
layers = 10

In [12]:
for i in range(10,1,-1):
  print(i)

10
9
8
7
6
5
4
3
2


In [13]:
def index_lists(l1,l2):
  ret1 = flatten(list(map(lambda x : np.repeat(x,len(l2)).tolist(),l1)))
  ret2 = np.tile(l2,len(l1)).tolist()
  return ret1,ret2

In [14]:
def make_pred(Ws,bs,input,N):
  a1_train = input
  for i in range(0,len(Ws)-1):
    a1_train = nn.ReLU()(torch.addmm(bs[i].repeat(1, N), Ws[i], a1_train))
  pred = torch.argmax(torch.addmm(bs[len(Ws)-1].repeat(1, N), Ws[len(Ws)-1], a1_train), dim=0)
  return pred

In [23]:
#The function requires at least 1 hidden layer otherwise it need some rewrriting
def execute_function(layers,input_size,hidden_size,output_size,train_set,val_set,train_labels,val_labels,niter = 100, gamma = 1, alpha = 5):
  """
  The function takes the following arguements and produces a list of weights and biases with which 
  you can use the make_pred function to get a list of predictions
  :param layers: The total number of layers of the network
  :param input_size: The total size of the input layer
  :param hidden_size: The size of the hidden layer
  :param output_size: The size of the output layer (usefull for multiclass classification)
  :param train_set: The training set
  :param val_set: The validation set
  :param train_labels: The training labels
  :param test labels: The validation labels
  :param niter: The default number of epochs to train the network
  :param gamma: The gamma parameter of the algorithm
  :param alpha: The alpha parameter of the algorithm
  :return Ws,bs: Returns two lists that go in order from the input to the output layer of the weights and the biases of each layer
  """

  N = len(train_labels)
  N_test = len(val_labels)

  std = math.sqrt(1/input_size)
  W = torch.FloatTensor(hidden_size, input_size).uniform_(-std_1, std_1)
  b = torch.FloatTensor(hidden_size, 1).uniform_(-std_1, std_1)

  b = b.to('cuda:0')
  W = W.to('cuda:0')

  U = torch.addmm(b.repeat(1, N), W, x_train) # equivalent to W1@x_train+b1.repeat(1,N)
  V = nn.ReLU()(U)

  Ws = [W]
  bs = [b]
  Us = [U]
  Vs = [V]

  for i in range(1,layers-1):
    std = math.sqrt(1/hidden_size)
    W = torch.FloatTensor(hidden_size, hidden_size).uniform_(-std, std)
    b = torch.FloatTensor(hidden_size, 1).uniform_(-std, std)
    b = b.to('cuda:0')
    W = W.to('cuda:0')
    U = torch.addmm(b.repeat(1, N), W, Vs[-1])
    V = nn.ReLU()(U)
    Ws.append(W)
    bs.append(b)
    Us.append(U)
    Vs.append(V)
  
  std = math.sqrt(1/hidden_size)
  W = torch.FloatTensor(output_size, hidden_size).uniform_(-std, std)
  b = torch.FloatTensor(output_size, 1).uniform_(-std, std)

  # we move them to GPU
  b = b.to('cuda:0')
  W = W.to('cuda:0')
  U = torch.addmm(b.repeat(1, N), W, Vs[-1])
  V = U
  Ws.append(W)
  bs.append(b)
  Us.append(U)
  Vs.append(V)
  
  gamma1 = gamma2 = gamma3 = gamma4 = gamma

  rho = gamma
  rho1 = rho2 = rho3 = rho4 = rho

  alpha1 = alpha2 = alpha3 = alpha4 = alpha5 = alpha6 = alpha7 \
  = alpha8 = alpha9 = alpha10 = alpha

  loss1 = np.empty(niter)
  loss2 = np.empty(niter)
  accuracy_train = np.empty(niter)
  accuracy_test = np.empty(niter)
  time1 = np.empty(niter)

  y_one_hot = torch.zeros(N, output_size).to(device = device).scatter_(1, torch.reshape(train_labels,(N,1)), 1)
  y_one_hot = torch.t(y_one_hot).to(device=device)

  y_test_one_hot = torch.zeros(N_test, output_size).to(device = device).scatter_(1, torch.reshape(val_labels,(N_test,1)), 1)
  y_test_one_hot = torch.t(y_test_one_hot).to(device=device)

  print('Train on', N, 'samples, validate on', N_test, 'samples')
  for k in range(niter):

    start = time.time()

    # update V3
    Vs[-1] = (y_one_hot + gamma3*Us[-1] + alpha1*Vs[-1])/(1+ gamma3 + alpha1)
    
    # update U3 
    Us[-1] = (gamma3*Vs[-1] + rho3*(torch.mm(Ws[-1],Vs[-2]) + bs[-1].repeat(1,N)))/(gamma3 + rho3)

    # update W3 and b3
    W, b = updateWb_js(Us[-1],Vs[-2],Ws[-1],bs[-1],alpha1,rho3)
    Ws[-1] = W
    bs[-1] = b

    for i in range(len(Vs)-2,0,-1):
      Vs[i] = updateV_js(Us[i],Us[i+1],Ws[i+1],bs[i+1],rho3,gamma2)
      Us[i] = relu_prox(Vs[i],(rho2*torch.addmm(bs[i].repeat(1,N), Ws[i], Vs[i-1]) + alpha2*Us[i])/(rho2 + alpha2),(rho2 + alpha2)/gamma2,1500,N)
      W,b = updateWb_js(Us[i],Vs[i-1],Ws[i],bs[i],alpha3,rho2)
      Ws[i] = W
      bs[i]= b
    
    # update V1
    Vs[0] = updateV_js(Us[0],Us[1],Ws[1],bs[1],rho2,gamma1)
        
    # update U1
    Us[0] = relu_prox(Vs[0],(rho1*torch.addmm(bs[0].repeat(1,N), Ws[0], x_train) + alpha7*Us[0])/(rho1 + alpha7),(rho1 + alpha7)/gamma1,1500,N)
    
    # update W1 and b1
    W, b = updateWb_js(Us[0],x_train,Ws[0],bs[0],alpha8,rho1)
    Ws[0] = W
    bs[0] = b

    #a1_train = nn.ReLU()(torch.addmm(b1.repeat(1, N), W1, x_train))
    #a1_train = x_train
    #for i in range(len(Vs)-1,0,-1):
    #  a1_train = nn.ReLU()(torch.addmm(bs[i].repeat(1, N), Ws[i], a1_train))
    #pred = torch.argmax(torch.addmm(bs[0].repeat(1, N), Ws[0], a1_train), dim=0)
    pred = make_pred(Ws,bs,x_train,N)

    #a1_test = x_test
    #a1_test = nn.ReLU()(torch.addmm(b1.repeat(1, N_test), W1, x_test))
    #for i in range(len(Vs)-1,0,-1):
    #  a1_test = nn.ReLU()(torch.addmm(bs[i].repeat(1, N_test), Ws[i], a1_test))
    #pred_test = torch.argmax(torch.addmm(bs[0].repeat(1, N_test), Ws[0], a1_test), dim=0)
    pred_test = make_pred(Ws,bs,x_test,N_test) 

    loss1[k] = gamma3/2*torch.pow(torch.dist(V3,y_one_hot,2),2).cpu().numpy()
    loss2[k] = loss1[k] + rho1/2*torch.pow(torch.dist(torch.addmm(b1.repeat(1,N), W1, x_train),U1,2),2).cpu().numpy() \
    +rho2/2*torch.pow(torch.dist(torch.addmm(b2.repeat(1,N), W2, V1),U2,2),2).cpu().numpy() \
    +rho3/2*torch.pow(torch.dist(torch.addmm(b3.repeat(1,N), W3, V2),U3,2),2).cpu().numpy()
        
    # compute training accuracy
    correct_train = pred == train_labels
    accuracy_train[k] = np.mean(correct_train.cpu().numpy())
        
    # compute validation accuracy
    correct_test = pred_test == val_labels
    accuracy_test[k] = np.mean(correct_test.cpu().numpy())
        
    # compute training time
    stop = time.time()
    duration = stop - start
    time1[k] = duration
        
    # print results
    print('Epoch', k + 1, '/', niter, '\n', 
          '-', 'time:', time1[k], '-', 'sq_loss:', loss1[k], '-', 'tot_loss:', loss2[k], 
          '-', 'acc:', accuracy_train[k], '-', 'val_acc:', accuracy_test[k])
  return Ws,bs

In [25]:
execute_function(3,input_size,hidden_size,output_size,x_train,x_test,y_train,y_test,niter = 100, gamma = 1, alpha = 5)

Train on 60000 samples, validate on 10000 samples
Epoch 1 / 100 
 - time: 1.276848554611206 - sq_loss: 9066.9150390625 - tot_loss: 475391.3671875 - acc: 0.48468333333333335 - val_acc: 0.4893
Epoch 2 / 100 
 - time: 1.312650442123413 - sq_loss: 9066.9150390625 - tot_loss: 475391.3671875 - acc: 0.89395 - val_acc: 0.8921
Epoch 3 / 100 
 - time: 1.334385633468628 - sq_loss: 9066.9150390625 - tot_loss: 475391.3671875 - acc: 0.9387166666666666 - val_acc: 0.9379
Epoch 4 / 100 
 - time: 1.315220832824707 - sq_loss: 9066.9150390625 - tot_loss: 475391.3671875 - acc: 0.9483833333333334 - val_acc: 0.945
Epoch 5 / 100 
 - time: 1.317286729812622 - sq_loss: 9066.9150390625 - tot_loss: 475391.3671875 - acc: 0.9522333333333334 - val_acc: 0.9476
Epoch 6 / 100 
 - time: 1.3165884017944336 - sq_loss: 9066.9150390625 - tot_loss: 475391.3671875 - acc: 0.95385 - val_acc: 0.9497
Epoch 7 / 100 
 - time: 1.3227264881134033 - sq_loss: 9066.9150390625 - tot_loss: 475391.3671875 - acc: 0.9555333333333333 - val_ac

([tensor([[ 0.0174, -0.0334,  0.0232,  ...,  0.0034, -0.0328,  0.0047],
          [-0.0149, -0.0180,  0.0315,  ..., -0.0259,  0.0087, -0.0141],
          [-0.0162, -0.0152, -0.0008,  ..., -0.0079,  0.0037,  0.0256],
          ...,
          [ 0.0340, -0.0022, -0.0194,  ...,  0.0034, -0.0239, -0.0309],
          [ 0.0070,  0.0139,  0.0051,  ...,  0.0115, -0.0077,  0.0283],
          [ 0.0143, -0.0072, -0.0281,  ..., -0.0046,  0.0289,  0.0094]],
         device='cuda:0'),
  tensor([[-0.0168,  0.0198, -0.0152,  ...,  0.0159, -0.0082,  0.0115],
          [ 0.0030,  0.0082,  0.0224,  ...,  0.0068,  0.0042,  0.0172],
          [ 0.0198, -0.0201, -0.0165,  ...,  0.0123,  0.0048,  0.0037],
          ...,
          [-0.0036, -0.0142,  0.0082,  ..., -0.0151, -0.0089, -0.0229],
          [ 0.0225, -0.0239,  0.0111,  ..., -0.0077,  0.0232,  0.0042],
          [-0.0026, -0.0185,  0.0199,  ..., -0.0191, -0.0016, -0.0111]],
         device='cuda:0'),
  tensor([[ 3.2860e-01, -8.1814e-01,  8.8845e-02, 

In [21]:
# we perform one-hot encoding on the y-> y_one_hot will be a tensor of dimension 10*training_samples, for each sample there
# will be a one in the row corresponding to the class.

y_one_hot = torch.zeros(N, K).to(device = device).scatter_(1, torch.reshape(y_train,(N,1)), 1)
y_one_hot = torch.t(y_one_hot).to(device=device)

y_test_one_hot = torch.zeros(N_test, K).to(device = device).scatter_(1, torch.reshape(y_test,(N_test,1)), 1)
y_test_one_hot = torch.t(y_test_one_hot).to(device=device)

# Iterations
print('Train on', N, 'samples, validate on', N_test, 'samples')
for k in range(niter):

  start = time.time()

  #select_the nodes for the third layer
  nodes_up3 = rand(0,10,5)
  nodes_up2 = rand(0,1500,750)
  nodes_up1 = rand(0,1500,750)
  V3_sample = V3[nodes_up3]
  U3_sample = U3[nodes_up3]
  W3_sample = torch.transpose(torch.transpose(W3[nodes_up3],0,1)[nodes_up2],0,1)
  b3_sample = b3[nodes_up3]
  V2_sample = V2[nodes_up2]
  U2_sample = U2[nodes_up2]
  W2_sample = torch.transpose(torch.transpose(W2[nodes_up2],0,1)[nodes_up1],0,1)
  b2_sample = b2[nodes_up2]
  V1_sample = V1[nodes_up1]
  U1_sample = U1[nodes_up1]
  W1_sample = W1[nodes_up1]
  b1_sample = b1[nodes_up1]
  # update V3
  V3_sample = (y_one_hot[nodes_up3] + gamma3*U3_sample + alpha1*V3_sample)/(1+ gamma3 + alpha1)
      
  # update U3 
  U3_sample = (gamma3*V3_sample + rho3*(torch.mm(W3_sample,V2_sample) + b3_sample.repeat(1,N)))/(gamma3 + rho3)

  # update W3 and b3
  W3_sample, b3_sample = updateWb_js(U3_sample,V2_sample,W3_sample,b3_sample,alpha1,rho3)
  
  # update V2
  V2_sample = updateV_js(U2_sample,U3_sample,W3_sample,b3_sample,rho3,gamma2)
      
  # update U2
  U2_sample = relu_prox(V2_sample,(rho2*torch.addmm(b2_sample.repeat(1,N), W2_sample, V1_sample) + alpha2*U2_sample)/(rho2 + alpha2),(rho2 + alpha2)/gamma2,750,N)
      
  # update W2 and b2
  W2_sample, b2_sample = updateWb_js(U2_sample,V1_sample,W2_sample,b2_sample,alpha3,rho2)
  
  # update V1
  V1_sample = updateV_js(U1_sample,U2_sample,W2_sample,b2_sample,rho2,gamma1)
      
  # update U1
  U1_sample = relu_prox(V1_sample,(rho1*torch.addmm(b1_sample.repeat(1,N), W1_sample, x_train) + alpha7*U1_sample)/(rho1 + alpha7),(rho1 + alpha7)/gamma1,750,N)
  
  # update W1 and b1
  W1_sample, b1_sample = updateWb_js(U1_sample,x_train,W1_sample,b1_sample,alpha8,rho1)

  print(list(zip(nodes_up3,nodes_up2)))
  #print(W3[].shape)
  V3[nodes_up3] = V3_sample
  U3[nodes_up3] = U3_sample
  W3[index_lists(nodes_up3,nodes_up2)] = torch.reshape(W3_sample,(750*5,))
  b3[nodes_up3] = b3_sample
  V2[nodes_up2] = V2_sample
  U2[nodes_up2] = U2_sample
  W2[index_lists(nodes_up2,nodes_up1)] = torch.reshape(W2_sample,(750*750,))
  b2[nodes_up2] = b2_sample
  V1[nodes_up1] = V1_sample
  U1[nodes_up1] = V1_sample
  W1[nodes_up1] = W1_sample
  b1[nodes_up1] = b1_sample

  a1_train = nn.ReLU()(torch.addmm(b1_sample.repeat(1, N), W1_sample, x_train))
  a2_train = nn.ReLU()(torch.addmm(b2_sample.repeat(1, N), W2_sample, a1_train))
  pred = torch.argmax(torch.addmm(b3_sample.repeat(1, N), W3_sample, a2_train), dim=0)

  a1_test = nn.ReLU()(torch.addmm(b1_sample.repeat(1, N_test), W1_sample, x_test))
  a2_test = nn.ReLU()(torch.addmm(b2_sample.repeat(1, N_test), W2_sample, a1_test))
  pred_test = torch.argmax(torch.addmm(b3_sample.repeat(1, N_test), W3_sample, a2_test), dim=0)
      
  loss1[k] = gamma3/2*torch.pow(torch.dist(V3,y_one_hot,2),2).cpu().numpy()
  loss2[k] = loss1[k] + rho1/2*torch.pow(torch.dist(torch.addmm(b1.repeat(1,N), W1, x_train),U1,2),2).cpu().numpy() \
  +rho2/2*torch.pow(torch.dist(torch.addmm(b2.repeat(1,N), W2, V1),U2,2),2).cpu().numpy() \
  +rho3/2*torch.pow(torch.dist(torch.addmm(b3.repeat(1,N), W3, V2),U3,2),2).cpu().numpy()
      
  # compute training accuracy
  correct_train = pred == y_train
  accuracy_train[k] = np.mean(correct_train.cpu().numpy())
      
  # compute validation accuracy
  correct_test = pred_test == y_test
  accuracy_test[k] = np.mean(correct_test.cpu().numpy())
      
  # compute training time
  stop = time.time()
  duration = stop - start
  time1[k] = duration
      
  # print results
  print('Epoch', k + 1, '/', niter, '\n', 
        '-', 'time:', time1[k], '-', 'sq_loss:', loss1[k], '-', 'tot_loss:', loss2[k], 
        '-', 'acc:', accuracy_train[k], '-', 'val_acc:', accuracy_test[k])

Train on 60000 samples, validate on 10000 samples
[(2, 1495), (6, 1368), (8, 209), (5, 408), (3, 1431)]
Epoch 1 / 500 
 - time: 0.7344505786895752 - sq_loss: 29407.458984375 - tot_loss: 588855.4241790771 - acc: 0.042083333333333334 - val_acc: 0.0427
[(7, 1288), (6, 478), (2, 1138), (4, 273), (8, 772)]
Epoch 2 / 500 
 - time: 0.656104326248169 - sq_loss: 27555.2109375 - tot_loss: 685162.0559082031 - acc: 0.16803333333333334 - val_acc: 0.1727
[(1, 317), (0, 287), (5, 937), (7, 1338), (6, 599)]
Epoch 3 / 500 
 - time: 0.6585822105407715 - sq_loss: 25683.33203125 - tot_loss: 663158.2809906006 - acc: 0.08121666666666667 - val_acc: 0.0773
[(0, 426), (1, 1361), (6, 551), (3, 626), (7, 698)]
Epoch 4 / 500 
 - time: 0.6542847156524658 - sq_loss: 24441.55078125 - tot_loss: 618961.4534759521 - acc: 0.17298333333333332 - val_acc: 0.1796
[(4, 1427), (5, 1362), (1, 1088), (6, 1453), (2, 910)]
Epoch 5 / 500 
 - time: 0.6665594577789307 - sq_loss: 23358.212890625 - tot_loss: 593552.0503387451 - acc: 0

KeyboardInterrupt: ignored

# Training

Note: Fix it so that it moves everything to device in the following function and that it does the label sample split here

In [None]:
## We plot the train losses

plt.xlabel('epochs')
plt.ylabel('train losses')
plt.plot(np.arange(0,loss1.shape[0]), loss1)
plt.show()

In [None]:
## We plot the test accuracy

plt.xlabel('epochs')
plt.ylabel('test accuracy')
plt.plot(np.arange(0, accuracy_test.shape[0]), accuracy_test)
plt.show()