To run this notebook, you need PyTorch and torchvision installed on your machine.

In [None]:

import os,sys,inspect
# Put the parent directory of this file on the front of sys.path so that
# modules living one level above can be imported.
currentdir = os.path.dirname(os.path.abspath(inspect.getfile(inspect.currentframe())))
parentdir = os.path.dirname(currentdir)
sys.path.insert(0,parentdir) 
%pylab
%matplotlib inline
import IPython

import numpy as np

In [None]:
import torch
import torch.nn as nn
import numpy as np
from torch.autograd import Variable
import sys
import torch.nn.functional as func
import pickle
# Tensor types that inputs and labels are cast to before they are fed to a network.
dtype = torch.FloatTensor
dtype_labels = torch.LongTensor

no_of_hl= 30   # number of hidden layers

HUs=128 # number of hidden units per hidden layer
step_size =0.01 # default SGD step size
min_batch_size = 32 # mini-batch size used by the data loader
batch_norm_size = 10  # NOTE(review): not referenced anywhere in the visible code
# One entry per hidden layer, each equal to HUs.
hidden_layers=np.ones(no_of_hl,dtype=int)*HUs  

#### DATA

In [None]:
import torchvision
import torchvision.transforms as transforms

In [None]:
# Fetch the Fashion-MNIST training split through torchvision and wrap it in a
# shuffling mini-batch loader.
transform = transforms.Compose([transforms.ToTensor()])
train_dataset = torchvision.datasets.FashionMNIST(
    root='Fashion_MNIST_data',
    train=True,
    download=True,
    transform=transform,
)
train_loader = torch.utils.data.DataLoader(
    train_dataset,
    batch_size=min_batch_size,
    shuffle=True,
    num_workers=1,
)

In [None]:
# Derive the layer sizes from one mini-batch of the dataset.
dataiter = iter(train_loader)
# Use the built-in next(): DataLoader iterators implement the standard iterator
# protocol, and the legacy `.next()` method is not available on current PyTorch.
images, labels = next(dataiter)
images = images.view(-1, 1 * 28 * 28)  # flatten every 1x28x28 image into one row
D_in = images[0].shape[0]
D_out = 10  # number of output units (one per class)
_layers = np.append(hidden_layers, [D_out])
# Network architecture in units per layer, e.g. 5,10,10,1 (D_in, hidden 1, hidden 2, D_out).
layers = np.append([D_in], _layers)
print('Network architecture (no of units in layer): ',layers)


In [None]:
#networks 
def normalize_center(A):
    """Standardize every column of the 2-D tensor ``A``.

    Each column is shifted by its mean and divided by its standard deviation
    (``torch.std`` along dim 0) plus 0.001.

    Returns:
        A tuple ``(means, stds, A_scaled)``: the detached per-column means,
        the detached per-column standard deviations (already including the
        0.001 offset) and the standardized tensor.
    """
    n_features = A.shape[1]
    col_mean = A.mean(dim=0)
    # The small offset keeps the division finite for (nearly) constant columns.
    col_std = A.std(dim=0) + 0.001
    centered = A - col_mean.reshape(1, n_features)
    A_scaled = centered / col_std.reshape(1, n_features)
    return col_mean.detach(), col_std.detach(), A_scaled



# MLP without batch normalization 
class MlpPlane(torch.nn.Module):
    """Fully connected ReLU network without batch normalization.

    Args:
        h_sizes: sequence of layer widths ``(D_in, hidden_1, ..., D_out)``.
            Entries may be Python ints or NumPy integers.

    Every weight matrix is drawn from a zero-mean normal distribution with
    standard deviation ``sqrt(2 / (fan_in + fan_out))``; biases keep the
    default ``nn.Linear`` initialisation.
    """

    def __init__(self, h_sizes):
        super(MlpPlane, self).__init__()
        self.h_sizes = h_sizes
        self.layers = nn.ModuleList()
        for k in range(len(h_sizes) - 1):
            # int() accepts both plain ints and NumPy integer scalars.
            fan_in, fan_out = int(h_sizes[k]), int(h_sizes[k + 1])
            linear_module = nn.Linear(fan_in, fan_out)
            std = np.sqrt(2.0 / (fan_in + fan_out))
            linear_module.weight.data.normal_(0.0, std)
            self.layers.append(linear_module)

    def forward(self, x):
        """Apply all layers; ReLU follows every layer except the last one."""
        for layer in self.layers[:-1]:
            x = torch.relu(layer(x))
        return self.layers[-1](x)

    def get_weights(self):
        """Return the weight parameters of all linear layers, in order."""
        return [layer.weight for layer in self.layers]

    def getlayerloss(self, x, layer_num):
        """Return ``trace(M @ M) / trace(M)**2`` for layer ``layer_num``.

        ``x`` is passed through layers ``0 .. layer_num - 1`` (each followed
        by ReLU) and then through linear layer ``layer_num`` without an
        activation. ``M`` is the second-moment matrix ``x.T @ x / batch`` of
        that output.
        """
        for k in range(layer_num):
            x = torch.relu(self.layers[k](x))
        x = self.layers[layer_num](x)

        M = x.t().mm(x) / x.size(0)
        return torch.trace(M.mm(M)) / torch.trace(M) ** 2

    def getblanceloss(self, x):
        """Sum ``trace(M @ M) / trace(M)**2`` over all layers.

        Here ``M = x @ x.T / batch`` is built from the ReLU output of each
        layer (the last layer included). The batch size is read from ``x``;
        the scale factor cancels in the ratio, so it only affects rounding.
        """
        lo = 0
        for layer in self.layers:
            x = torch.relu(layer(x))
            M = x.mm(x.t()) / float(x.size(0))
            lo = lo + torch.trace(M.mm(M)) / torch.trace(M) ** 2
        return lo
##### BATCH Normalization 


class BNN(nn.Module): # note: the actual number of hidden layers is no_of_hidden_layers+1
    """Fully connected network with BatchNorm1d after every inner linear layer.

    Structure: ``Win`` -> act -> ``no_of_hidden_layers`` x (Linear ->
    BatchNorm1d -> act) -> ``Wout``. The input layer ``Win`` plus its
    activation is the extra hidden layer mentioned above.

    Args:
        input_dim: number of input features.
        hidden_dim: width of every hidden layer.
        output_dim: number of output units.
        no_of_hidden_layers: number of (Linear, BatchNorm1d) pairs.
        seed: if not None, passed to ``torch.manual_seed`` before the layers
            are created.
        act: activation applied after ``Win`` and after every BatchNorm1d.
    """
    def __init__(self, input_dim=784, hidden_dim=128, output_dim=10,no_of_hidden_layers=no_of_hl,seed=None, act=torch.tanh):
        super(BNN, self).__init__()
        if seed is not None:
            torch.manual_seed(seed)
        self.Win = nn.Linear(input_dim, hidden_dim, bias=True)
        self.layers = torch.nn.ModuleList(
            [nn.Linear(hidden_dim, hidden_dim, bias=True) for _ in range(no_of_hidden_layers)]
        )
        # NOTE(review): momentum=0.0 should leave the running statistics at
        # their initial values -- confirm against the BatchNorm1d documentation.
        self.BNlayers = torch.nn.ModuleList(
            [nn.BatchNorm1d(hidden_dim, momentum=0.0) for _ in range(no_of_hidden_layers)]
        )
        self.Wout = nn.Linear(hidden_dim, output_dim, bias=True)
        self.act = act

    def forward(self, input):
        """Return the output of ``Wout`` for a batch ``input``."""
        x = self.act(self.Win(input))
        for layer, BN in zip(self.layers, self.BNlayers):
            x = self.act(BN(layer(x)))
        return self.Wout(x)

    def getlayerloss(self, x, layer_num):
        """Return ``trace(M @ M) / trace(M)**2`` after ``layer_num + 1`` inner blocks.

        ``x`` goes through ``Win`` + act and then through the first
        ``layer_num + 1`` (Linear, BatchNorm1d, act) blocks, or through all
        of them if fewer exist. ``M`` is ``x.T @ x / batch`` of the result.
        """
        x = self.act(self.Win(x))
        for depth, (layer, BN) in enumerate(zip(self.layers, self.BNlayers)):
            if depth > layer_num:
                break  # deeper blocks do not contribute to the requested statistic
            x = self.act(BN(layer(x)))

        M = x.t().mm(x) / x.size(0)
        return torch.trace(M.mm(M)) / torch.trace(M) ** 2

class MlpBatch(MlpPlane):
    """``MlpPlane`` variant with BatchNorm1d after every hidden linear layer.

    Args:
        h_sizes: layer widths, forwarded unchanged to ``MlpPlane``.
    """

    def __init__(self, h_sizes):
        super(MlpBatch, self).__init__(h_sizes)
        self.batches = nn.ModuleList()
        # One BatchNorm1d per hidden layer; the output layer is not normalised.
        for k in range(len(h_sizes) - 2):
            self.batches.append(
                torch.nn.BatchNorm1d(num_features=int(h_sizes[k + 1]), momentum=0.0)
            )

    def forward(self, x):
        """Apply Linear -> BatchNorm1d -> ReLU per hidden layer, then the output layer."""
        for k in range(len(self.h_sizes) - 2):
            x = torch.relu(self.batches[k](self.layers[k](x)))
        return self.layers[len(self.h_sizes) - 2](x)

    def getlayerloss(self, x, layer_num):
        """Return ``trace(M @ M) / trace(M)**2`` after ``layer_num`` hidden blocks.

        ``M`` is ``x.T @ x / batch`` of the ReLU output of block
        ``layer_num - 1``. Unlike ``MlpPlane.getlayerloss``, linear layer
        ``layer_num`` itself is not applied.
        """
        for k in range(layer_num):
            x = torch.relu(self.batches[k](self.layers[k](x)))

        M = x.t().mm(x) / x.size(0)
        return torch.trace(M.mm(M)) / torch.trace(M) ** 2
            


In [None]:

import torch.nn.functional as f
def _avg_layer_rank(mlp, inputs, layer_ids=(1, 5, 10, 15, 20, 25, 29)):
    """Average ``mlp.getlayerloss(inputs, k)`` over the probed layer indices."""
    return sum(mlp.getlayerloss(inputs, k) for k in layer_ids) / len(layer_ids)


def _run_epoch(mlp, criterion, optimizer=None):
    """Run one pass over ``train_loader``.

    Takes an SGD step per mini-batch when ``optimizer`` is given. Returns the
    sample-weighted mean loss of the pass and the inputs of the last
    mini-batch.
    """
    total_loss = 0.0
    n_seen = 0
    inputs = None
    for inputs, labels in train_loader:
        inputs = inputs.view(-1, 28 * 28).type(dtype)
        labels = labels.type(dtype_labels)
        if optimizer is not None:
            optimizer.zero_grad()
        loss = criterion(mlp(inputs), labels)
        if optimizer is not None:
            loss.backward()
            optimizer.step()
        total_loss += loss.item() * inputs.shape[0]
        n_seen += inputs.shape[0]
    if n_seen == 0:
        raise ValueError("train_loader yielded no samples")
    # Normalise by the number of samples actually seen rather than by a
    # hard-coded dataset size, so the result is a true per-sample mean.
    return total_loss / n_seen, inputs


def run_training(mlp, epochs = 15,ss = step_size):
    """Train ``mlp`` with plain SGD on the module-level ``train_loader``.

    Args:
        mlp: network exposing ``getlayerloss(inputs, layer_num)``.
        epochs: number of passes over the training data.
        ss: SGD step size.

    Returns:
        ``(errors, h_ranks)``, two lists of length ``epochs + 1``. Entry 0 is
        measured before any update, entry ``e`` after epoch ``e``.
        ``errors`` holds the mean cross-entropy loss per sample. ``h_ranks``
        holds the ``getlayerloss`` value averaged over layers
        1, 5, 10, 15, 20, 25 and 29, evaluated on the last mini-batch.
    """
    errors = []
    h_ranks = []
    criterion = torch.nn.CrossEntropyLoss(reduction='mean')
    opt2 = torch.optim.SGD(mlp.parameters(), lr=ss)

    # Baseline before training; no gradients are needed for this pass.
    with torch.no_grad():
        loss_epoch, inputs = _run_epoch(mlp, criterion)
        print(loss_epoch)
        avg_rank = _avg_layer_rank(mlp, inputs)
    h_ranks.append(avg_rank.item())
    errors.append(loss_epoch)

    for epoch in range(epochs):  # loop over the dataset multiple times
        print('new epoch--------')
        loss_epoch, inputs = _run_epoch(mlp, criterion, optimizer=opt2)
        print('loss:',loss_epoch)
        with torch.no_grad():
            avg_rank = _avg_layer_rank(mlp, inputs)
        h_ranks.append(avg_rank.item())
        print('hrank:',1./avg_rank)
        errors.append(loss_epoch)
    return errors,h_ranks

### Linear Nets

In [None]:
# Load one mini-batch from the dataset and flatten every image into a row.
dataiter = iter(train_loader)
# Use the built-in next(): the legacy `.next()` method is not available on
# current PyTorch DataLoader iterators.
xb, labels = next(dataiter)
xb = xb.view(-1, 1 * 28 * 28)

In [None]:
# Four independently initialised plain MLPs; `images` serves as the probe batch.
xb = images
models = [MlpPlane(layers) for _ in range(4)]

# Record the layer-9 statistic of the first model as the first soft-rank entry.
loss = models[0].getlayerloss(xb, 9)
print(loss.data)
soft_ranks = [loss.data]

#### Pre train

In [None]:

bsize = xb.shape[0]
mlp = MlpPlane(layers)
ss = 0.1  # step size of the layer-wise pre-training updates
outnet = mlp.forward(xb)
outnet = outnet.detach().numpy()
_, s, _ = np.linalg.svd(outnet)
rank_track = []

print('===============')
print('Before optimization: normalized eigenvalues of input-output mapping (normalized by the trace)')
print(s/sum(s))
print('===============')
ranges = [5,75]  # gradient steps per optimised layer in each pre-training stage


# One stage per entry of `ranges`. Iterating over `ranges` itself (instead of
# range(len(models)-1)) avoids indexing past its end when there are more models
# than stages; models without a stage keep their initial parameters. `mlp` is
# not reset between stages, so stages accumulate.
for mm, n_iters in enumerate(ranges):
    print('============================')
    print('new model!!')
    avg_loss = []
    # Layer-wise descent on the trace-ratio statistic of every 4th layer.
    for kk in range(no_of_hl + 1):
        if kk == 0 or kk % 4 != 0:
            continue

        loss = mlp.getlayerloss(xb, kk)  # statistic of layer kk before the updates
        print("layer "+str(kk)+' --before--')
        print(loss.data)
        print('-------')

        # Only the weight matrices of layers 0..kk are updated; biases are not.
        weights = [mlp.layers[j].weight for j in range(kk + 1)]
        for i in range(n_iters):
            # A single backward pass yields the gradients for all layers.
            grads = torch.autograd.grad(loss, weights)
            for w, g in zip(weights, grads):
                w.data = w.data - ss * g

            loss = mlp.getlayerloss(xb, kk)
            avg_loss.append(loss.item())

        print("layer "+str(kk)+" --after " + str(n_iters) + " iterations--")
        print(loss.data)
        print('-------')
    # Copy the parameters reached after this stage into the matching model.
    models[mm + 1].load_state_dict(mlp.state_dict())
    soft_ranks.append(np.mean(avg_loss))

## save an extra model that will be regularized
extra_model = MlpPlane(layers)
extra_model.load_state_dict(mlp.state_dict())

outnet = mlp.forward(xb)
outnet = outnet.detach().numpy()
_, s, _ = np.linalg.svd(outnet)
print('===============')
print('After optimization: normalized eigenvalues of input-output mapping (normalized by the trace)')
print(s/sum(s))
print('===============')

In [None]:
# Copy the parameters of every model into a freshly constructed MlpPlane.
xb = images
models_train = []
for source_model in models:
    mod = MlpPlane(layers)
    mod.load_state_dict(source_model.state_dict())
    models_train.append(mod)

# Print the statistic of layer `no_of_hl` and that layer's weight shape per copy.
for model in models_train:
    loss = model.getlayerloss(xb, no_of_hl)
    print(loss.data)
    print(model.layers[no_of_hl].weight.size())

In [None]:
# Train every plain MLP with SGD and keep its loss and rank-statistic curves.
errors, rankz = [], []
for model in models:
    terror, _rank = run_training(model, epochs=15, ss=0.01)
    errors.append(terror)
    rankz.append(_rank)

In [None]:
# Replace every stored curve, in place, by the element-wise reciprocal as an array.
rankz[:] = [1. / np.array(curve) for curve in rankz]

### BN nets

In [None]:
# Batch-normalised network with its default initialisation, trained with SGD.
mlp_batch = BNN(no_of_hidden_layers=no_of_hl)
bn_curves = run_training(mlp_batch, epochs=15, ss=0.01)
errors_batch, rankz_batch = bn_curves

#### Screw BN up

In [None]:
def bad_initialization(model,hl,input_dim=784, hidden_dim=128, output_dim=10):
    """Overwrite every parameter of ``model`` with samples from U(0, 0.1).

    All tensors returned by ``model.parameters()`` are refilled in place, so
    each keeps its own dtype and device and no module-level type alias is
    needed.

    Args:
        model: module whose parameters are re-initialised.
        hl, input_dim, hidden_dim, output_dim: unused; kept so existing calls
            keep working.
    """
    for p in model.parameters():
        torch.nn.init.uniform_(p, 0.0, 0.1)
        

In [None]:
# New batch-normalised network, seeded with a randomly drawn integer below 100,
# whose parameters are then all replaced by bad_initialization before training.
random_seed = torch.LongTensor(1).random_(0, 100)
mlp_batch = BNN(no_of_hidden_layers=no_of_hl, seed=random_seed)
bad_initialization(mlp_batch, no_of_hl)
errors_batch_uni, rankz_batch_uni = run_training(mlp_batch, epochs=15, ss=0.01)

### PLOT

In [None]:
import pandas as pd

In [None]:
# Keep the results of this run available under "_1" names.
errors_1, rankz_1 = errors, rankz
errors_batch_1, rankz_batch_1 = errors_batch, rankz_batch
errors_batch_uni_1, rankz_batch_uni_1 = errors_batch_uni, rankz_batch_uni

In [None]:
# Gather all curves of this run as columns of one table; the insertion order
# below fixes the column order.
dict_run1 = dict()
dict_run1['run_id'] = np.ones(len(errors_batch), dtype=np.int8) * 1
dict_run1['errors_mlp_1'] = errors[0]
dict_run1['errors_mlp_2'] = errors[1]
dict_run1['errors_mlp_3'] = errors[2]
dict_run1['errors_mlp_4'] = errors[3]
dict_run1['ranks_mlp_1'] = rankz[0]
dict_run1['ranks_mlp_2'] = rankz[1]
dict_run1['ranks_mlp_3'] = rankz[2]
dict_run1['ranks_mlp_4'] = rankz[3]
dict_run1['errors_batch'] = errors_batch
dict_run1['errors_batch_uni'] = errors_batch_uni
dict_run1['rankz_batch'] = rankz_batch
dict_run1['rankz_batch_uni'] = rankz_batch_uni

pd_run1 = pd.DataFrame(dict_run1)

In [None]:
# Only a single run is used here; the commented-out line shows how several
# runs would be stacked into one table instead.
#data=pd.concat([pd_run1,pd_run2,pd_run3])
data=pd_run1

In [None]:
# Replace both BN rank-statistic columns by their reciprocals.
for column in ("rankz_batch_uni", "rankz_batch"):
    data[column] = 1. / data[column]

In [None]:
import seaborn as sns

In [None]:
fig = plt.figure()

# Training-loss curves of the SGD-trained plain MLPs (errors_mlp_4 is not plotted).
for column, label in (
    ("errors_mlp_1", "SGD no pre-training"),
    ("errors_mlp_2", "SGD low pre-training"),
    ("errors_mlp_3", "SGD high pre-training"),
):
    sns.lineplot(data=data, x=data.index, y=column, label=label, ci=95)

# Raw strings keep the mathtext "\sim" from being read as an (invalid) string escape.
sns.lineplot(data=data, x=data.index, y="errors_batch", label=r"BN $W\sim U[-a,a]$", marker="X", ci=95)
ax = sns.lineplot(data=data, x=data.index, y="errors_batch_uni", label=r"BN $W\sim U[0,2a]$", marker="X", ci=95)

# Single legend, placed to the right of the axes.
ax.legend(loc='center left', bbox_to_anchor=(1, 0.5), fontsize=14)

ax.xaxis.set_tick_params(labelsize=12.5)
ax.yaxis.set_tick_params(labelsize=12.5)
plt.ylabel('Training loss', fontsize=14)
plt.xlabel('Epochs', fontsize=14)
plt.savefig("fig_pretrain_a.pdf")

In [None]:
%matplotlib inline

In [None]:
fig = plt.figure()

# Rank-statistic curves of the SGD-trained plain MLPs (ranks_mlp_4 is not plotted).
for column, label in (
    ("ranks_mlp_1", "SGD no pre-training"),
    ("ranks_mlp_2", "SGD low pre-training"),
    ("ranks_mlp_3", "SGD high pre-training"),
):
    sns.lineplot(data=data, x=data.index, y=column, label=label, ci=95)

# Raw strings keep the mathtext "\sim" from being read as an (invalid) string
# escape. The first label had a misplaced "$" that broke its math markup.
sns.lineplot(data=data, x=data.index, y="rankz_batch", label=r"BN $W\sim U[-a,a]$", marker="X", ci=95)
ax = sns.lineplot(data=data, x=data.index, y="rankz_batch_uni", label=r"BN $W\sim U[0,2a]$", marker="X", ci=95)

ax.xaxis.set_tick_params(labelsize=12.5)
ax.yaxis.set_tick_params(labelsize=12.5)
plt.ylabel('Lower bound on rank', fontsize=14)
plt.xlabel('Epochs', fontsize=14)
plt.yscale("log")

ax.legend(loc='center left', bbox_to_anchor=(1, 0.5), fontsize=14)
