In [None]:
from google.colab import drive
drive.mount('/content/drive/')

In [None]:
# Root folder for data/ and experiment outputs (local folder by default).
# On Colab, point it at the Drive folder instead and copy the parameter file next to the notebook:
# directory='/content/drive/MyDrive/esn2sparse'; 
# !cp "/content/drive/MyDrive/esn2sparse/params.py" "."
# NOTE(review): the notebook imports `params_feedback`, not `params` — confirm which file to copy.
directory = "."

In [None]:
import torch
import numpy as np
import sympy as sp
from scipy import stats
import matplotlib as mpl
import matplotlib.pyplot as pl
import torchvision
import torchvision.datasets as datasets
import time
import torch.jit as jit
from torch import nn
from torch import optim
from os.path import exists
import gc
import importlib
import params_feedback as par
import time
import os
# Compute device: GPU when available (assign device = 'cpu' here to force CPU execution)
device = 'cuda' if torch.cuda.is_available() else 'cpu'
print(device)

In [None]:
# Download (if needed) and load the MNIST train and test splits, without any transform.
mnist_trainset, mnist_testset = (
    datasets.MNIST(root=directory+'/data', train=is_train, download=True, transform=None)
    for is_train in (True, False)
)

In [None]:
X_te=mnist_testset.data               ## Test set images
y_te=mnist_testset.targets            ## Test set labels (`test_labels` is a deprecated alias of `targets`)

N_o=10                                ## Number of output nodes/classes
N_te=y_te.size()[0]                   ## Number of test samples
Y_te=torch.zeros([N_te,N_o])          ## Initialisation of the one-hot encoded labels for the test set
Y_te[np.arange(0,N_te),y_te]=1        ## From labels to one-hot encoded labels for the test set

X_tr=mnist_trainset.data              ## Train set images
y_tr=mnist_trainset.targets           ## Train labels (`train_labels` is a deprecated alias of `targets`)
N_tr=y_tr.size()[0]                   ## Number of training samples
N_i = X_tr.size()[1]                  ## Number of inputs to ESN

Y_tr=torch.zeros([N_tr,N_o])          ## Initialisation of one-hot encoded labels for training
Y_tr[np.arange(0,N_tr),y_tr]=1        ## From labels to one-hot encoded labels for the training set

## Hold out N_val random training samples for validation.
## NOTE(review): the permutation uses numpy's global RNG and no seed is set before this cell,
## so the train/validation split changes between kernel restarts.
N_val=10000
i_val=np.random.permutation(np.arange(0,N_tr))[0:N_val]

X_val=X_tr[i_val,:,:]
Y_val=Y_tr[i_val,:]

i_tr=np.delete(np.arange(0,N_tr),i_val)
N_tr=N_tr-N_val

X_tr=X_tr[i_tr,:,:]
Y_tr=Y_tr[i_tr,:]

T=X_tr.size()[2]
N_in=X_tr.size()[1]

## Normalisation and conversion to float
X_M=255                               ## Max pixel value; only used by the disabled min-max variant below
# X_tr=torch.reshape( (X_tr.float()/X_M),[-1,784]) 
# X_val=torch.reshape((X_val.float()/X_M),[-1,784])
# X_te=torch.reshape((X_te.float()/X_M),[-1,784])
X_tr=torch.reshape(X_tr.float(),[-1,784]).to(device)
X_val=torch.reshape(X_val.float(),[-1,784]).to(device)
X_te=torch.reshape(X_te.float(),[-1,784]).to(device)

def _zscore_rows(X):
    """Z-score every row of a 2-D tensor (zero mean, unit unbiased std), vectorised over rows."""
    return (X - X.mean(dim=1, keepdim=True)) / X.std(dim=1, keepdim=True)

## Per-image z-scoring (one vectorised op per split instead of a Python loop over images)
X_tr=_zscore_rows(X_tr)
X_val=_zscore_rows(X_val)
X_te=_zscore_rows(X_te)

X_tr=torch.reshape(X_tr,[-1,28,28]) 
X_val=torch.reshape(X_val,[-1,28,28])
X_te=torch.reshape(X_te,[-1,28,28])

Y_tr=Y_tr.float().to(device)
Y_val=Y_val.float().to(device)
Y_te=Y_te.float().to(device)

print(f'X_tr shape = {X_tr.shape},    Y_tr shape = {Y_tr.shape}')


In [None]:
# Add random contrast transformations to input data (OLD VERSION - results in outlines).
# Every pixel equal to an image's minimum is raised by rand * 2 * (max - min) of that image.
# NOTE(review): superseded by the next cell; both cells modify X_tr/X_val/X_te in place, so
# running both stacks the two augmentations.
for X_set in (X_tr, X_val, X_te):          # order matters: one torch.rand draw per image
    for j in range(X_set.shape[0]):
        img = X_set[j,:]
        lo, hi = img.min(), img.max()
        background = img == lo
        X_set[j, background] = X_set[j, background] + torch.rand(1).to(device) * torch.tensor(2.).to(device) * (hi - lo)

# Z-score inputs (each image separately)
for X_set in (X_tr, X_val, X_te):
    for j in range(X_set.shape[0]):
        X_set[j,:] -= torch.mean(X_set[j,:])
        X_set[j,:] /= torch.std(X_set[j,:])

In [None]:
# Add random contrast transformations to input data.
# Per image: min-max rescale to [0, 1], invert with probability 0.5, shift to zero mean, then
# rescale so the image std equals a random contrast level
# contrast_mean + contrast_range*(U - 0.5), U ~ Uniform[0, 1)  ->  std in [0.005, 0.995).
contrast_mean = torch.tensor(0.5).to(device)
contrast_range = torch.tensor(0.99).to(device)
# Training, validation, then test data (order matters: two torch.rand draws per image)
for X_set in (X_tr, X_val, X_te):
    for j in range(X_set.shape[0]):
        # Rescale values to range 0-1
        X_set[j,:] -= torch.min(X_set[j,:])
        X_set[j,:] /= torch.max(X_set[j,:])
        # Reverse contrast with prob. = 0.5
        invert = torch.rand(1) > 0.5
        if invert:
            X_set[j,:] = torch.tensor(1.0).to(device) - X_set[j,:]
        # Zero mean, then scale to the random target std
        X_set[j,:] -= torch.mean(X_set[j,:])
        current_std = torch.std(X_set[j,:])
        target_std = contrast_mean + contrast_range*(torch.rand(1).to(device) - torch.tensor(0.5).to(device))
        X_set[j,:] /= (current_std / target_std)

In [None]:
# IF NOT USING ESN
# Make inputs (par.N_esn*28)-dim and z-score: flatten every 28x28 image, project it through one
# fixed Gaussian random matrix, then z-score each projected sample across its features
# (the per-row equivalent of stats.zscore(..., axis=1)). Work in batches to bound GPU memory;
# the results are collected on the CPU.
from scipy import stats

def _project_zscore(X, w, batchsize=100):
    """Return the row-wise z-scored projection ``X.reshape(n, -1) @ w`` as a CPU tensor.

    Each row is normalised with its own mean and std, so a sample's result does not depend on
    which other samples share its batch. The final batch may be shorter than ``batchsize``.
    """
    n_samples = X.shape[0]
    out = torch.zeros(n_samples, w.shape[1])
    for start in range(0, n_samples, batchsize):
        proj = torch.matmul(X[start:start+batchsize].reshape(-1, w.shape[0]), w)
        proj = proj - proj.mean(dim=1, keepdim=True)
        proj = proj / proj.std(dim=1, keepdim=True)
        out[start:start+proj.shape[0], :] = proj.cpu()
    return out

with torch.no_grad():
    w = torch.randn([784, par.N_esn*28]).to(device)/torch.sqrt(torch.tensor(par.N_esn*28+784)).to(device)
    X_tr = _project_zscore(X_tr, w)
    X_val = _project_zscore(X_val, w)
    X_te = _project_zscore(X_te, w)



In [None]:
def Data2Classes(X,Y):
    """Split a dataset into per-class subsets.

    Args:
        X: samples, indexed along dim 0 (any trailing shape); may live on any device.
        Y: one-hot labels, shape (nSamples, nClasses); each row is expected to be one-hot.

    Returns:
        (X1, Y1): lists with one entry per class 0 .. max label, holding the rows of X and Y of
        that class, both moved to ``device``.
    """
    # argmax yields exactly one label per row, so label positions stay aligned with sample
    # indices (torch.where(Y==1)[1] misaligns as soon as a row is not exactly one-hot).
    labels = torch.argmax(Y, dim=1)
    N_class = int(labels.max()) + 1

    X1 = []
    Y1 = []
    for n in range(N_class):
        idx = torch.where(labels == n)[0]
        # Index each tensor with indices on its own device: X may be on the CPU (e.g. after the
        # random-projection cell) while Y is on `device`.
        X1.append(X[idx.to(X.device)].to(device))
        Y1.append(Y[idx].to(device))

    return X1, Y1

# X_tr/X_val/X_te become lists with one entry per class
X_tr, Y_tr=Data2Classes(X_tr,Y_tr)

X_val, Y_val=Data2Classes(X_val,Y_val)

X_te, Y_te=Data2Classes(X_te,Y_te)

In [None]:
# Plot activity and correlation between classes
N_esn = par.N_esn
n_per_class = 100   # samples per class used for the correlation matrix
cm = 1/2.54         # matplotlib figure sizes are given in inches

# Stack the first n_per_class samples of every class. The per-class tensors may live on the GPU,
# so move them to the CPU before converting to numpy.
a = np.concatenate(
    [X_tr[j][0:n_per_class,:].detach().cpu().numpy() for j in range(len(X_tr))], axis=0
).astype(np.float64)
print(np.max(a), np.min(a))
# Pearson correlation between samples: z-score each row across features, then average products
a_z = stats.zscore(a,axis=1)
c = np.matmul(a_z, np.transpose(a_z)) / a.shape[1]

for cls in (0, 2):
    fig, ax = pl.subplots(figsize=(50*cm, 20*cm))
    imdata = ax.imshow(X_tr[cls][0:1000,:].detach().cpu().numpy(), vmin=-2.0, vmax=2.0)
    ax.set(title=f'Class {cls}: first 1000 samples', xlabel='feature', ylabel='sample')

fig, ax = pl.subplots(figsize=(5*cm, 5*cm))
imdata = ax.imshow(c, vmin=-1.0, vmax=1.0)
ax.set(title='Sample-sample correlation', xlabel='sample', ylabel='sample')
# cb = fig.colorbar(imdata, ticks=[0, 1.0])

xent (cross-entropy training, no metric learning):
 - train: layers 0 to -1 (all layers)
 - train: forward pass up to layer -1
 - train: loss computed at layer -1
 - transfer: layers tranFrom to -1
 - transfer: forward pass up to layer -1
 - transfer: loss computed at layer -1

metric (metric learning):
 - train: layers 0 to tranFrom-1
 - train: forward pass up to layer tranFrom-1
 - train: loss computed at layer tranFrom-1
 - transfer: layers tranFrom to -1
 - transfer: forward pass up to layer -1
 - transfer: loss computed at layer -1

In [None]:
class MLPclassic(nn.Module):
    """Echo-state-network reservoir followed by a trainable MLP read-out, trained with cross-entropy.

    Layer activations are kept in ``self.x``: ``x[0]`` is the reservoir state, ``x[1:-1]`` the ReLU
    hidden layers and ``x[-1]`` the linear output (logits). The reservoir weights ``W``/``W_in`` are
    fixed; ``Ws``/``bs`` and, when ``par.fbLayer`` is truthy, ``W_fb`` (hidden layer ``fbLayer`` fed
    back into the reservoir) are trained by ``iniOpt``.

    Batches are assumed class-ordered: rows ``[j*nSampPerClassPerBatch, (j+1)*nSampPerClassPerBatch)``
    belong to class ``j`` (this is how ``self.target`` is built).

    NOTE(review): ``Ws``/``bs`` are plain Python lists, so they are not registered with ``nn.Module``
    and do not appear in ``self.parameters()`` / ``state_dict()``.
    """

    def __init__(self,par):
        super().__init__()

        self.N_class=par.nClass
        self.batch_size = par.batch_size
        self.nSampPerClassPerBatch = int(par.batch_size/par.nClass) # No. input samples per class, per batch
        self.lossfn = nn.BCEWithLogitsLoss()  # NOTE(review): not used by any method of this class
        # Setup target vector to compute Loss and Accuracy: class-ordered labels 0..N_class-1
        self.target = torch.zeros(self.batch_size).long().to(device)
        for j in range(1,self.N_class):
            self.target[j*self.nSampPerClassPerBatch:(j+1)*self.nSampPerClassPerBatch] = j

        self.N = par.N_esn
        self.alpha = par.alpha
        self.rho = par.rho
        self.N_av = par.N_av
        self.N_i = par.nInputs
        self.gamma = par.gamma
        self.fbLayer = par.fbLayer
        self.lossLayer = len(par.Ns)-1
        self.etaInitial = par.etaInitial
        self.etaTransfer = par.etaTransfer

        self.Ns = par.Ns

        # Fixed reservoir: uniform(-1, 1) weights, each kept with probability N_av/N, rescaled to
        # spectral radius rho.
        dilution = 1-self.N_av/self.N
        W = np.random.uniform(-1, 1, [self.N, self.N])
        W = W*(np.random.uniform(0, 1, [self.N, self.N]) > dilution)
        eig = np.linalg.eigvals(W)
        self.W = torch.from_numpy(
            self.rho*W/(np.max(np.absolute(eig)))).float().to(device)

        self.x = []

        # Fixed input weights: random +/-gamma for a single input, gamma*N(0, 1) otherwise
        if self.N_i == 1:

            self.W_in = 2*np.random.randint(0, 2, [self.N_i, self.N])-1
            self.W_in = torch.from_numpy(self.W_in*self.gamma).float().to(device)

        else:

            self.W_in = np.random.randn(self.N_i, self.N)
            self.W_in = torch.from_numpy(self.gamma*self.W_in).float().to(device)

        # Trainable feedforward weights, N(0, 1) scaled by 1/sqrt(fan_in + fan_out), and zero biases
        self.Ws=[]
        self.bs=[]

        for n in range(1,np.shape(self.Ns)[0]):

            self.Ws.append(nn.Parameter((torch.randn([self.Ns[n-1],self.Ns[n]])/torch.sqrt(torch.tensor(self.Ns[n-1]+self.Ns[n]))).to(device)))
            self.bs.append(nn.Parameter(torch.zeros([self.Ns[n]]).to(device)))

        # Trainable feedback weights from hidden layer fbLayer back into the reservoir
        if par.fbLayer:
            self.W_fb = nn.Parameter((torch.randn([self.Ns[self.fbLayer],self.N])/torch.sqrt(torch.tensor(self.Ns[self.fbLayer] + self.N))).to(device))

    def initialOptimiser(self):
        """Create ``self.iniOpt``: Adam over Ws, bs (and W_fb if feedback is enabled) at lr etaInitial."""
        if self.fbLayer:
            self.iniOpt=optim.Adam([{ 'params': self.Ws+self.bs+[self.W_fb], 'lr':self.etaInitial }])
        else:
            self.iniOpt=optim.Adam([{ 'params': self.Ws+self.bs, 'lr':self.etaInitial }])

    def Forward(self, input):
        """Advance the network by one time step with ``input`` (one row per sample)."""
        ### ESN: leaky-integrator update, optionally driven by the fed-back hidden layer
        if self.fbLayer:
            self.x[0] = (1-self.alpha)*self.x[0]+self.alpha * \
                torch.tanh(torch.matmul(input, self.W_in)+torch.matmul(self.x[0], self.W)+torch.matmul(self.xFB, self.W_fb))
        else:
            self.x[0] = (1-self.alpha)*self.x[0]+self.alpha * \
                torch.tanh(torch.matmul(input, self.W_in)+torch.matmul(self.x[0], self.W))

        ### Hidden layers (ReLU)
        for n in range(1,len(self.Ns)-1):
            self.x[n] = torch.relu( torch.add(torch.matmul(self.x[n-1],self.Ws[n-1]),self.bs[n-1]) )
            if n==self.fbLayer:
                # Copy used as feedback into the reservoir on the next time step
                self.xFB = torch.clone(self.x[n])

        ### Output (linear logits)
        self.x[-1] = torch.add(torch.matmul(self.x[-2],self.Ws[-1]),self.bs[-1])

    def Reset(self, nSamples):
        """Reset every layer's state (and the feedback buffer) to zeros for ``nSamples`` samples."""
        self.x = []
        for n in range(0,len(self.Ns)): # For each layer
            self.x.append(torch.zeros(nSamples, self.Ns[n], requires_grad=True).to(device))
            if n==self.fbLayer:
                self.xFB = torch.clone(self.x[n])

    def lossCrossEntropy(self, nSamples):
        """Mean cross-entropy of the logits ``x[-1]`` against ``self.target`` over the first ``nSamples`` rows.

        Uses log-softmax instead of exp/sum/log so that large logits cannot overflow to inf/NaN.
        """
        logp = torch.log_softmax(self.x[-1], dim=1)
        L = torch.mean(- logp[range(nSamples), self.target])
        return L

    def accuracyTarget(self):
        """Fraction of samples whose arg-max logit equals ``self.target``."""
        acc = torch.mean(torch.eq( torch.argmax(self.x[-1],1), self.target ).type(torch.float))
        return acc

    def accuracyClassCentroid(self):
        """Nearest-centroid accuracy of the output ``x[-1]`` for a class-ordered batch.

        Class centroids are the mean outputs of each class block; a sample counts as correct when
        its closest centroid (squared Euclidean distance, first index on ties) is its own class.
        """
        out = self.x[-1]
        nSamples = out.shape[0]
        k = self.nSampPerClassPerBatch
        centroids = torch.stack([out[j*k:(j+1)*k,:].mean(0) for j in range(self.N_class)])
        # (nSamples, N_class) matrix of squared distances, computed in one vectorised op
        dist = (out.unsqueeze(1) - centroids.unsqueeze(0)).pow(2).sum(2)
        true_class = torch.arange(nSamples, device=out.device) // k
        return (torch.argmin(dist, dim=1) == true_class).float().mean()

    def getLossAccuracy(self, backwardFlag=False, opt=None):
        """Return (loss, target accuracy, centroid accuracy); optionally take one optimiser step.

        Args:
            backwardFlag: if True, back-propagate the loss and step/zero ``opt``.
            opt: optimiser to step; required when ``backwardFlag`` is True.

        Raises:
            ValueError: if ``backwardFlag`` is True but no optimiser is given.
        """
        L = self.lossCrossEntropy(self.x[-1].shape[0])
        accTa = self.accuracyTarget()
        accClCe = self.accuracyClassCentroid()

        if backwardFlag:
            if opt is None:
                raise ValueError('backwardFlag=True requires an optimiser (opt)')
            L.backward()
            opt.step()
            opt.zero_grad()

        return L,accTa,accClCe

    def response(self, Input):
        """Reset the state and run the network over all time steps of ``Input`` (time on dim 2)."""
        N_samples = Input.shape[0]
        T = Input.shape[2]

        self.Reset(N_samples)

        for t in range(T):
            self.Forward(Input[:, :, t])

    def responseSave(self, Input,saveLayers=[]):
        """Like ``response`` but also record layers ``saveLayers`` at every time step.

        Returns:
            list with one CPU tensor of shape (nSamples, Ns[l], T) per requested layer ``l``.
        """
        N_samples = Input.shape[0]
        T = Input.shape[2]

        sav = []
        for l in saveLayers:
            sav.append(torch.zeros(N_samples, self.Ns[l], T))

        self.Reset(N_samples)
        for t in range(T):
            self.Forward(Input[:, :, t])

            for li, l in enumerate(saveLayers):
                sav[li][:,:,t] = torch.clone(self.x[l].detach())

        return sav

In [None]:
###
### WITHOUT metric learning
###
def xent_esn_fb(expName,rngSeed):
    """Train an ``MLPclassic`` (ESN + MLP read-out) with cross-entropy and save the training artefacts.

    Uses the module-level per-class data lists ``X_tr``/``X_val``, the parameter module ``par`` and
    ``directory``. Every output file is written to ``<directory>/data/<expName>/`` with ``rngSeed``
    appended to its name.

    Args:
        expName: experiment name, used as the output sub-directory.
        rngSeed: seed for torch's and numpy's global RNGs.

    Returns:
        1 on success; -1 if the training-batch output became NaN (nothing is saved in that case).
    """

    ### Initialise RNGs
    torch.manual_seed(rngSeed)
    # Seed numpy's *global* RNG: the batch indices below and the reservoir/input weights drawn in
    # MLPclassic.__init__ all come from np.random.*. (np.random.default_rng(rngSeed) only returns a
    # new, independent Generator and leaves the global state unseeded.)
    np.random.seed(rngSeed)

    ### Setup directory names
    experiment = expName
    expDir = directory+'/data/'+experiment
    outputDir = expDir    # Storage directory for this experiment's outputs
    os.makedirs(outputDir, exist_ok=True)

    ### Other parameters
    # Save weights every <save_every> episodes, up to episode nSaveMaxT. Clamped to >= 1 so the
    # modulo tests below never divide by zero when nSaveMaxT < nWeightSave.
    save_every = max(1, int(np.floor(par.nSaveMaxT / par.nWeightSave)))
    tMax = X_tr[0].shape[2]

    ###############################
    #### First phase of training
    ###############################

    # Init memory to save responses (always defined so the save step below never hits a NameError)
    resp_saveind = 0
    RESP = []
    if par.saveFlag_RESP:
        for layer in par.saveLayers:
            RESP.append(np.zeros((len(par.saveRespAtN)+1, par.nClass * par.nSaveSamples, par.Ns[layer], tMax)))
    # Init memory to save feedback weights
    if par.saveFlag_FBWeights:
        savWeights = np.zeros((par.Ns[par.fbLayer], par.Ns[0], par.nWeightSave+1))
    # Init memory to save effective learning rate
    dw_saveind = 0
    dwEdges = torch.logspace(-10,-3,51)
    savDW = [] # List of arrays to store DW histograms over time
    for k in range(len(par.Ns)-1):
        savDW.append(np.zeros((dwEdges.shape[0]-1, par.nWeightSave))) # Histograms for feedforward layers
    if par.fbLayer:
        savDW.append(np.zeros((dwEdges.shape[0]-1, par.nWeightSave))) # Histogram for feedback layer

    ### Initialise model
    MOD = MLPclassic(par)
    MOD.initialOptimiser()

    def _sample_batch(X):
        """Class-ordered batch: MOD.nSampPerClassPerBatch random rows (with replacement) per class of X."""
        parts = []
        for k in range(par.nClass):
            rand_ind = np.random.randint(0, X[k].size()[0], (MOD.nSampPerClassPerBatch,))
            parts.append(X[k][rand_ind,:])
        return torch.concat(parts, 0)

    def _snapshot_weights():
        """Copies of the feedforward weight matrices, plus the feedback weights if enabled."""
        ws = [torch.clone(Wk.data) for Wk in MOD.Ws]
        if par.fbLayer:
            ws.append(torch.clone(MOD.W_fb.data))
        return ws

    L_tr = [np.zeros([par.nEpisodes])]
    A_trTa = [np.zeros([par.nEpisodes])]
    A_trClCe = [np.zeros([par.nEpisodes])]
    L_val = [np.zeros([par.nEpisodes])]
    A_valTa = [np.zeros([par.nEpisodes])]
    A_valClCe = [np.zeros([par.nEpisodes])]

    print('**********************START TRAINING')
    t=time.time()
    for n in range(par.nEpisodes):

        ### Save met responses before updates
        if par.saveFlag_RESP and (np.any(np.isin(n, par.saveRespAtN)) or n==(par.nEpisodes-1)):
            print(f'Saving response on iteration {n} of {par.nEpisodes}')
            with torch.no_grad():
                s = torch.concat([X_tr[k][0:par.nSaveSamples,:] for k in range(par.nClass)], 0)
                resp = MOD.responseSave(s, par.saveLayers)
                for li in range(len(par.saveLayers)):
                    RESP[li][resp_saveind,:,:,:] = resp[li].numpy()
            resp_saveind += 1

        ### Save feedback weights (skip once the snapshot array is full, i.e. after nSaveMaxT)
        if par.saveFlag_FBWeights and n % save_every == 0 and n // save_every < savWeights.shape[2]:
            savWeights[:,:,n // save_every] = MOD.W_fb.data.cpu().numpy()

        ### Training data
        Im = _sample_batch(X_tr)
        # Get response to input
        MOD.response(Im)
        # Compute loss and accuracy, and take one optimiser step
        loss, accTrTa, accTrClCe = MOD.getLossAccuracy(backwardFlag=True, opt=MOD.iniOpt)

        ##### CHECKING FOR NANS
        if torch.any(torch.isnan(MOD.x[-1])):
            print(f'Quitting at batch {n} of {par.nEpisodes}')
            return -1

        # Store loss and accuracy
        L_tr[0][n] = float(loss)
        A_trTa[0][n] = float(accTrTa)
        A_trClCe[0][n] = float(accTrClCe)

        # Store weight changes for the effective learning rate: weights after the update at
        # n = k*save_every are compared with those after the update at n = k*save_every + 1.
        # Stop once all nWeightSave histogram columns are filled.
        if par.saveFlag_DW and dw_saveind < par.nWeightSave:
            if n % save_every == 0:
                with torch.no_grad():
                    w1 = _snapshot_weights()
            if (n-1) % save_every == 0:
                with torch.no_grad():
                    w2 = _snapshot_weights()
                    for k in range(len(w1)):
                        dw = torch.abs(torch.subtract(w2[k],w1[k])).cpu()
                        savDW[k][:,dw_saveind] = torch.histogram(dw[dw>0], dwEdges).hist.numpy()
                dw_saveind += 1

        ### Validation data
        with torch.no_grad():
            Im = _sample_batch(X_val)
            # Get response to input
            MOD.response(Im)
            # Compute loss and accuracy
            loss, accValTa, accValClCe = MOD.getLossAccuracy()

            # Store loss and accuracy
            L_val[0][n] = float(loss)
            A_valTa[0][n] = float(accValTa)
            A_valClCe[0][n] = float(accValClCe)

        # Update learning rate (exponential decay with time constant eta_tau)
        MOD.iniOpt.param_groups[0]['lr'] = par.etaInitial * np.exp(-float(n)/par.eta_tau)

        if ((n+1)%par.reportTime==0) and (n>0):
            print(f'Time per stage: {time.time()-t}')
            window = slice((n+1)-par.reportTime, n+1)
            mseTr_mean=np.mean(L_tr[0][window])
            accTrTa_mean=np.mean(A_trTa[0][window])
            accTrClCe_mean=np.mean(A_trClCe[0][window])
            mseVal_mean=np.mean(L_val[0][window])
            accValTa_mean=np.mean(A_valTa[0][window])
            accValClCe_mean=np.mean(A_valClCe[0][window])
            t=time.time()

            print(f'Progress: {np.float32(n+1)/np.float32(par.nEpisodes)*100.:.3}%   Mean Tr Er: {mseTr_mean}, Mean Val Er: {mseVal_mean};   Mean Tr AccTa: {accTrTa_mean}, Mean Val AccTa: {accValTa_mean}')
            print(f'Progress:                                                                                           Mean Tr AccClCe: {accTrClCe_mean}, Mean Val AccClCe: {accValClCe_mean}')

    ### Save hidden layer representations for transfer learning
    H_tr = []
    H_val = []
    with torch.no_grad():
        for j in range(par.nClass):
            MOD.response(X_tr[j])
            H_tr.append(MOD.x[par.saveRepLayer])
            MOD.response(X_val[j])
            H_val.append(MOD.x[par.saveRepLayer])

    ### Save outputs
    prefix = outputDir + '/'
    suffix = str(rngSeed)+'.pt'
    torch.save(L_tr, prefix+'lossTr'+suffix)
    torch.save(A_trTa, prefix+'accTrTa'+suffix)
    torch.save(A_trClCe, prefix+'accTrClCe'+suffix)
    torch.save(L_val, prefix+'lossVal'+suffix)
    torch.save(A_valTa, prefix+'accValTa'+suffix)
    torch.save(A_valClCe, prefix+'accValClCe'+suffix)
    if par.saveFlag_RESP:
        torch.save(RESP, prefix+'respSave'+suffix)
    torch.save(MOD, prefix+'model'+suffix)
    torch.save(savDW, prefix+'dw'+suffix)
    torch.save([H_tr, H_val], prefix+'hResp'+suffix)
    if par.saveFlag_FBWeights:
        torch.save(savWeights, prefix+'weightSave'+suffix)

    return 1


In [None]:
class MLPmetric(nn.Module):
    
    def __init__(self,par):
        """Build the fixed ESN reservoir and the trainable MLP used for metric and transfer learning.

        Args:
            par: parameter namespace (the notebook passes the ``params_feedback`` module). Fields read:
                nClass, batch_size, nInputs, N_esn, alpha, rho, N_av, gamma, tMax, fbLayer,
                metricLossType, margin, etaInitial, etaTransfer, tauPerf, Ns and, for the
                prototypical loss only, nSampProto.
        """
        super().__init__()

        self.N_class=par.nClass
        self.batch_size = par.batch_size
        self.nSampPerClassPerBatch = int(par.batch_size/par.nClass) # No. input samples per class, per batch
        # Setup target vector to compute Loss and Accuracy: class-ordered labels, the first
        # nSampPerClassPerBatch entries are class 0, the next class 1, and so on
        self.target = torch.zeros(self.batch_size).long().to(device)    
        for j in range(1,self.N_class):
            self.target[j*self.nSampPerClassPerBatch:(j+1)*self.nSampPerClassPerBatch] = j

        # NOTE(review): nInputs and N_i hold the same value (par.nInputs)
        self.nInputs = par.nInputs
        self.N = par.N_esn
        self.alpha = par.alpha
        self.rho = par.rho
        self.N_av = par.N_av
        self.N_i = par.nInputs
        self.gamma = par.gamma
        self.tMax = par.tMax
        self.fbLayer = par.fbLayer
        self.metricLossType = par.metricLossType
        self.margin = par.margin
        self.etaInitial = par.etaInitial
        self.etaTransfer = par.etaTransfer
        # NOTE(review): duplicate of the batch_size assignment above (same value)
        self.batch_size = par.batch_size
        # Prototypical loss only: nP support samples and nQ query samples per class
        # (nP/nQ are not defined for other loss types)
        if par.metricLossType=='prototypicalLoss':
            self.nP = par.nSampProto
            self.nQ = int(par.batch_size / par.nClass)
        self.tri = 1 if par.metricLossType=='tripletLoss' else 0
        # Exponential weights over time steps, 1 at the last step and decaying backwards with
        # time constant tauPerf. NOTE(review): created on the CPU, unlike the other tensors.
        self.wPerf = torch.exp(-torch.arange(self.tMax).flip(0)/par.tauPerf)

        # Fixed reservoir: uniform(-1, 1) weights, each kept with probability N_av/N, rescaled to
        # spectral radius rho
        dilution = 1-self.N_av/self.N
        W = np.random.uniform(-1, 1, [self.N, self.N])
        W = W*(np.random.uniform(0, 1, [self.N, self.N]) > dilution)
        eig = np.linalg.eigvals(W)
        self.W = torch.from_numpy(
            self.rho*W/(np.max(np.absolute(eig)))).float().to(device)

        self.x = []

        # Fixed input weights: random +/-gamma for a single input, gamma*N(0, 1) otherwise
        if self.N_i == 1:

            self.W_in = 2*np.random.randint(0, 2, [self.N_i, self.N])-1
            self.W_in = torch.from_numpy(self.W_in*self.gamma).float().to(device)

        else:

            self.W_in = np.random.randn(self.N_i, self.N)
            self.W_in = torch.from_numpy(self.gamma*self.W_in).float().to(device)

        # Trainable feedforward weights (N(0, 1) scaled by 1/sqrt(fan_in + fan_out)) and zero biases.
        # NOTE(review): plain lists, so these Parameters are not registered with nn.Module.
        self.Ws=[]
        self.bs=[]
        
        self.Ns=par.Ns
        
        for n in range(1,np.shape(self.Ns)[0]):
        
            self.Ws.append(nn.Parameter((torch.randn([self.Ns[n-1],self.Ns[n]])/torch.sqrt(torch.tensor(self.Ns[n-1]+self.Ns[n]))).to(device)))
            self.bs.append(nn.Parameter(torch.zeros([self.Ns[n]]).to(device)))
        
        # Trainable feedback weights from hidden layer fbLayer into the reservoir, initialised
        # very small (N(0, 1) * 1e-4)
        if par.fbLayer:
            self.W_fb = nn.Parameter((torch.randn([self.Ns[self.fbLayer],self.N])/10**4).to(device))

    def metricOptimiser(self):

        if self.fbLayer:
            self.metOpt=optim.Adam([{ 'params': self.Ws+self.bs+[self.W_fb], 'lr':self.etaInitial }])
        else:
            self.metOpt=optim.Adam([{ 'params': self.Ws+self.bs, 'lr':self.etaInitial }])

    def Forward(self, input):

        ### ESN
        if self.fbLayer:
            self.x[0] = (1-self.alpha)*self.x[0]+self.alpha * \
                torch.tanh(torch.matmul(input, self.W_in)+torch.matmul(self.x[0], self.W)+torch.matmul(self.xFB, self.W_fb))
        else:
            self.x[0] = (1-self.alpha)*self.x[0]+self.alpha * \
                torch.tanh(torch.matmul(input, self.W_in)+torch.matmul(self.x[0], self.W))
        
        ### Hidden layers
        for n in range(1,len(self.Ns)): # For each layer
            # Build up training data so that outputs learn from all previous layers
            self.x[n] = torch.relu( torch.add(torch.matmul(self.x[n-1],self.Ws[n-1]),self.bs[n-1]) )
            if n==self.fbLayer:
                self.xFB = torch.clone(self.x[n])

    def Reset(self, nSamples):
        """Zero the state of every layer for ``nSamples`` samples; xFB starts as a copy of layer fbLayer."""
        states = []
        for idx, width in enumerate(self.Ns):
            states.append(torch.zeros(nSamples, width, requires_grad=True).to(device))
            if idx == self.fbLayer:
                self.xFB = torch.clone(states[-1])
        self.x = states

    def lossCrossEntropy(self, nSamples, r):
        # Compute softmax probabilities for responses
        p = torch.div( torch.exp(r), torch.sum(torch.exp(r), 1, keepdim=True).tile((1,r.shape[1])) )
        L = torch.mean(- torch.log(p[range(nSamples),self.target]))
        return L
    
    def tripletLoss(self, resp):
        # Implement hard negative mining

        dAP = (resp[0] - resp[1]).pow(2).sum(1).sqrt() # anchor-positive
        dAN = (resp[0] - resp[2]).pow(2).sum(1).sqrt() # anchor-negative
        
        dPN = (resp[1] - resp[2]).pow(2).sum(1).sqrt() # positive-negative
        ind = torch.le(dPN, dAN)
        ind1 = torch.nonzero(ind)
        ind2 = torch.nonzero(torch.logical_not(ind))
        
        L = torch.concat((torch.maximum(torch.zeros(ind2.shape).to(device), dAP[ind2] - dAN[ind2] + self.margin), 
                         torch.maximum(torch.zeros(ind1.shape).to(device), dAP[ind1] - dPN[ind1] + self.margin)), dim=0).mean()
        # if t%50==0:
        #     print(f'DP = {dAP.mean()}                            DN = {dAN.mean()}                           maxRESP = {resp[0].max()}')

        return L.mean()
    
    def prototypicalLoss(self, r):
        proto = torch.zeros(self.N_class, self.Ns[-1]).to(device) # init mem for prototypes
        dist = torch.zeros(self.N_class*self.nQ, self.N_class).to(device) # init mem for distances
        prob = torch.zeros(self.N_class*self.nQ).to(device) # Init mem for probabilities
        # Compute prototypes
        for j in range(self.N_class):
            proto[j,:] = torch.clone(r[j*self.nP:(j+1)*self.nP,:]).mean(0)
        for j in range(self.N_class): # for each class
            for k in range(self.nQ): # for each query in class j
                # Compute distances between queries and prototypes
                fx = torch.clone(r[self.nP*self.N_class + j*self.nQ + k,:]).unsqueeze(0).tile(self.N_class,1)
                dist[j*self.nQ+k,:] = (fx - proto).pow(2).sum(1).sqrt()
            # Compute probabilities using softmax
            prob[j*self.nQ:(j+1)*self.nQ] = torch.divide(torch.exp(-dist[j*self.nQ:(j+1)*self.nQ,j]), 
                                               torch.sum(torch.exp(-dist[j*self.nQ:(j+1)*self.nQ,:]), 1))
        
        L = - torch.log(prob)

        return L.mean(), proto
    
    def accuracyTarget(self, r):
        acc = torch.mean(torch.eq( torch.argmax(r,1), self.target ).type(torch.float))
        return acc
    
    def accuracyClassCentroid(self, r, prototypes=None, nPerClass=None):
        """Nearest-centroid accuracy of the class-ordered rows of `r`.

        Row j is assigned true class j // nPerClass. It counts as correct when its nearest
        centroid (squared Euclidean distance) has that index.

        Args:
            r: (nSamples, dim) responses, grouped nPerClass consecutive rows per class.
            prototypes: optional (N_class, dim) centroids. When None or empty, centroids
                are the per-class means of `r`. An empty list is still accepted for
                backward compatibility.
            nPerClass: rows per class; defaults to self.nSampPerClassPerBatch.
                NOTE(review): prototypical queries come nQ per class, so callers passing
                prototypes may need nPerClass=self.nQ.

        Returns:
            0-dim float tensor with the fraction of correctly assigned rows.
        """
        if nPerClass is None:
            nPerClass = self.nSampPerClassPerBatch
        nSamples = r.shape[0]

        # Class centroids: given prototypes, or per-class means of r
        if prototypes is not None and len(prototypes) > 0:
            centroids = prototypes
        else:
            centroids = torch.stack([r[j*nPerClass:(j+1)*nPerClass, :].mean(0)
                                     for j in range(self.N_class)])

        # Squared distance from every sample to every centroid: (nSamples, N_class)
        dist = (r.unsqueeze(1) - centroids.unsqueeze(0)).pow(2).sum(2)
        predicted = dist.argmin(1)
        trueClass = torch.arange(nSamples, device=r.device) // nPerClass
        return (predicted == trueClass).float().mean()

    def getLossAccuracy(self, learning, r, backwardFlag=False, opt=None):
        """Compute loss and two accuracies, optionally taking one optimiser step.

        Args:
            learning: 'metric' (loss chosen by self.metricLossType) or 'transfer'.
            r: for 'metric', the list returned by self.response; for 'transfer', the
               tensor passed to self.lossCrossEntropy.
            backwardFlag: if True, backpropagate L, call opt.step() and opt.zero_grad().
            opt: optimiser to step; required when backwardFlag is True.

        Returns:
            (L, accTa, accClCe): loss, target accuracy, nearest-centroid accuracy.

        Raises:
            ValueError: unknown `learning` / `metricLossType`, or backwardFlag without `opt`.
        """
        if learning == 'metric':
            if self.metricLossType == 'tripletLoss':
                L = self.tripletLoss(r)
                accTa = self.accuracyTarget(r[0])
                accClCe = self.accuracyClassCentroid(r[0])
            elif self.metricLossType == 'prototypicalLoss':
                L, proto = self.prototypicalLoss(r[0])
                # Accuracies are measured on the query rows only (they follow the support rows)
                queries = r[0][self.N_class*self.nP:, :]
                accTa = self.accuracyTarget(queries)
                # NOTE(review): accuracyClassCentroid groups rows by nSampPerClassPerBatch;
                # confirm that equals nQ for prototypical batches.
                accClCe = self.accuracyClassCentroid(queries, proto)
            else:
                raise ValueError(f'Unknown metricLossType: {self.metricLossType!r}')
        elif learning == 'transfer':
            L = self.lossCrossEntropy(r.shape[0], r)
            accTa = self.accuracyTarget(r)
            accClCe = self.accuracyClassCentroid(r)
        else:
            raise ValueError(f"Unknown learning mode: {learning!r} (expected 'metric' or 'transfer')")

        if backwardFlag:
            if not opt:
                raise ValueError('backwardFlag=True requires an optimiser (opt)')
            L.backward()
            opt.step()
            opt.zero_grad()

        return L,accTa,accClCe
      
    def response(self, Input, tripletFlag=False):
        """Run the network over all self.tMax time steps of `Input` and return the final last-layer state.

        Input is indexed as Input[:, :, t]; dim 0 is the sample. With tripletFlag, the
        batch is treated as three equal consecutive blocks [anchor | positive | negative]
        that are simulated together; otherwise as a single block.

        Returns:
            List of cloned slices of self.x[-1], one per block (3 with tripletFlag, else 1).
        """
        nBlocks = 1 + 2 * int(tripletFlag)
        blockSize = int(Input.shape[0] / 3) if tripletFlag else Input.shape[0]

        # Every block shares one reset and one pass through time
        self.Reset(blockSize * nBlocks)
        for step in range(self.tMax):
            self.Forward(Input[:, :, step])

        final = self.x[-1]
        return [torch.clone(final[b*blockSize:(b+1)*blockSize, :]) for b in range(nBlocks)]
    
    def generateBatch(self, method, X):
        """Draw a class-ordered batch from per-class samples.

        Args:
            method: 'simple', 'tripletLoss' or 'prototypicalLoss'.
            X: sequence indexed by class; X[k] holds the samples of class k. The code
               indexes X[k][idx, :, :] and allocates with X[0].shape[1:], so samples are
               presumably (nInputs, tMax) sequences -- TODO confirm.

        Returns:
            Batch tensor on `device`, laid out as:
              'simple':           batch_size rows, nSampPerClassPerBatch consecutive rows per
                                  class (sampled with replacement). Also sets self.target to
                                  the matching one-hot (batch_size, N_class) matrix.
              'tripletLoss':      3*batch_size rows = [anchors | positives | negatives], each
                                  block class-ordered. Anchor/positive are distinct samples of
                                  class k; each negative comes from a random other class.
              'prototypicalLoss': N_class*nP support rows followed by N_class*nQ query rows,
                                  both class-ordered, drawn without replacement per class.

        Raises:
            ValueError: if `method` is not one of the above (previously `batch` was
                left unbound and an UnboundLocalError was raised).
        """
        nPer = self.nSampPerClassPerBatch

        if method == 'simple':
            ### For normal batches, e.g. for transfer learning
            batch = torch.zeros(self.batch_size, X[0].shape[1], X[0].shape[2], dtype=torch.float32, requires_grad=False).to(device)
            # N_class, as used by every other method (this branch referenced self.nClass)
            self.target = torch.zeros(self.batch_size, self.N_class, dtype=torch.float32).to(device)
            for k in range(self.N_class):
                rand_ind = np.random.randint(0, X[k].shape[0], (nPer,))
                batch[k*nPer:(k+1)*nPer, :, :] = torch.clone(X[k][rand_ind, :, :])
                self.target[k*nPer:(k+1)*nPer, k] = 1

        elif method == 'tripletLoss':
            batch = torch.zeros([3*self.batch_size, self.nInputs, self.tMax]).to(device)
            posOffset = self.batch_size
            negOffset = 2*self.batch_size
            for k in range(self.N_class):
                # Two distinct samples per row: column 0 -> anchor, column 1 -> positive
                ind_ap = np.random.choice(X[k].shape[0], (nPer, 2), replace=False)
                batch[k*nPer:(k+1)*nPer, :] = torch.clone(X[k][ind_ap[:, 0], :])
                batch[posOffset+k*nPer:posOffset+(k+1)*nPer, :] = torch.clone(X[k][ind_ap[:, 1], :])
                # Negatives: one random sample from a random class other than k, per row
                randClass = np.random.choice((np.arange(self.N_class) != k).nonzero()[0], nPer)
                for m, cl in enumerate(randClass):
                    batch[negOffset + k*nPer + m, :] = torch.clone(X[cl][np.random.randint(X[cl].shape[0]), :])

        elif method == 'prototypicalLoss':
            nSupport = self.nP*self.N_class
            batch = torch.zeros((self.nP+self.nQ)*self.N_class, X[0].shape[1], X[0].shape[2]).to(device)
            for j, x in enumerate(X): # For each class (X is a list)
                # nP support + nQ query samples, all distinct
                ind = np.random.choice(x.shape[0], self.nP+self.nQ, replace=False)
                batch[j*self.nP:(j+1)*self.nP, :] = torch.clone(x[ind[:self.nP], :])
                batch[nSupport+j*self.nQ:nSupport+(j+1)*self.nQ] = torch.clone(x[ind[self.nP:], :])

        else:
            raise ValueError(f"Unknown batch method: {method!r} (expected 'simple', 'tripletLoss' or 'prototypicalLoss')")

        return batch
    
    def responseSave(self, Input,saveLayers=[]):
        """Run the network over `Input` and record chosen layers at every time step.

        Args:
            Input: indexed as Input[:, :, t] for t < self.tMax; dim 0 is the sample.
            saveLayers: indices into self.x / self.Ns of the layers to record.

        Returns:
            One tensor per entry of saveLayers, created with torch.zeros on the default
            device and shaped (nSamples, Ns[layer], tMax). Each holds detached copies of
            that layer's activations.
        """
        nBatch = Input.shape[0]
        nSteps = self.tMax
        recorded = [torch.zeros(nBatch, self.Ns[layer], nSteps) for layer in saveLayers]

        self.Reset(nBatch)
        for step in range(nSteps):
            self.Forward(Input[:, :, step])
            for buf, layer in zip(recorded, saveLayers):
                buf[:, :, step] = torch.clone(self.x[layer].detach())

        return recorded

In [None]:
# swLR = 0.00001#0.00046 # Sweep over these learning rates
# rngSeed = 11117 # No. runs per hyperparameter

# # Update hyperparameter
# importlib.reload(par)
# par.eta = swLR
# par.fbLayer = 2
# par.saveFlag_FBWeights = True
# par.nEpisodes = 50
# par.nSaveMaxT = par.nEpisodes
# expName = 'test'
###
### WITH metric learning
###
def metric_esn_fb(expName, rngSeed):
    """Train MLPmetric with a metric-learning loss, then save metrics, weights and representations.

    Reads module-level state: `par` (hyperparameter module), `X_tr` / `X_val` (indexed by
    class and passed to MLPmetric.generateBatch / MLPmetric.response), `directory` and the
    `MLPmetric` class.

    Args:
        expName: experiment name; outputs go to <directory>/data/<expName>/.
        rngSeed: seed for torch and for the global numpy RNG; also the suffix of every
            output file name.

    Returns:
        1 when training finishes and all outputs are saved. -1 if the last layer produced
        NaNs; training then stops early and nothing is saved.
    """
    ### Initialise RNGs
    # np.random.default_rng(seed) only builds a new Generator (which was discarded). It
    # does not seed the global RNG that generateBatch draws from (np.random.choice /
    # np.random.randint), so runs were not reproducible. Seed the global state instead.
    torch.manual_seed(rngSeed)
    np.random.seed(rngSeed)

    ### Setup directory names
    experiment = expName
    expDir = directory+'/data/'+experiment
    outputDir = expDir    # Outputs are stored in the experiment directory
    os.makedirs(outputDir, exist_ok=True)  # also creates <directory>/data if it is missing

    ### Other parameters
    # Save data every <> episodes, up to episode nSaveMaxT. Clamped to >= 1 so that
    # nSaveMaxT < nWeightSave cannot cause a modulo-by-zero below.
    save_every = max(1, int(np.floor(par.nSaveMaxT / par.nWeightSave)))
    tMax = X_tr[0].shape[2] # No. time steps per input sequence

    ###############################
    #### First phase of training
    ###############################

    # Init memory to save responses. RESP is always defined because it is saved
    # unconditionally at the end (previously a NameError when saveFlag_RESP was off).
    resp_saveind = 0
    RESP = []
    if par.saveFlag_RESP:
        for layer in par.saveLayers:
            RESP.append(np.zeros((len(par.saveRespAtN)+1, par.nClass * par.nSaveSamples, par.Ns[layer], tMax)))
    # Init memory to save feedback weights (one slot every save_every episodes, plus one)
    if par.saveFlag_FBWeights:
        print('Saving weights')
        savWeights = np.zeros((par.Ns[par.fbLayer], par.Ns[0], par.nWeightSave+1))
    # Init memory to save effective learning rate (histograms of |dW| for single updates)
    dw_saveind = 0
    dwEdges = torch.logspace(-10,-3,51)
    savDW = [] # List of arrays to store DW histograms over time
    for k in range(len(par.Ns)-1):
        savDW.append(np.zeros((dwEdges.shape[0]-1, par.nWeightSave))) # Histograms for feedforward layers
    if par.fbLayer:
        savDW.append(np.zeros((dwEdges.shape[0]-1, par.nWeightSave))) # Histogram for feedback layer

    ### Initialise model
    MOD = MLPmetric(par)
    MOD.metricOptimiser()

    def _snapshotWeights():
        # Detached copies of the tracked weights, in the same order as savDW
        ws = [torch.clone(MOD.Ws[k].data) for k in range(len(MOD.Ns)-1)]
        if par.fbLayer:
            ws.append(torch.clone(MOD.W_fb.data))
        return ws

    ### Init memory for saving data (each a one-element list of arrays: the saved format)
    L_tr = [np.zeros([par.nEpisodes])]
    A_trTa = [np.zeros([par.nEpisodes])]
    A_trClCe = [np.zeros([par.nEpisodes])]
    L_val = [np.zeros([par.nEpisodes])]
    A_valTa = [np.zeros([par.nEpisodes])]
    A_valClCe = [np.zeros([par.nEpisodes])]

    print('**********************START TRAINING')
    t=time.time()

    for n in range(par.nEpisodes):

        ### Save met responses before updates
        if par.saveFlag_RESP and (np.any(np.isin(n, par.saveRespAtN)) or n==(par.nEpisodes-1)):
            print(f'Saving response on iteration {n} of {par.nEpisodes}')
            with torch.no_grad():
                s = X_tr[0][0:par.nSaveSamples,:]
                for k in range(1,par.nClass):
                    s = torch.concat([s, X_tr[k][0:par.nSaveSamples,:]],0)
                resp = MOD.responseSave(s, par.saveLayers)
                for li, l in enumerate(par.saveLayers):
                    RESP[li][resp_saveind,:,:,:] = resp[li].numpy()
            resp_saveind += 1

        ### Save feedback weights (indices past the preallocated slots, e.g. when
        ### nEpisodes > nSaveMaxT, are skipped instead of raising IndexError)
        if par.saveFlag_FBWeights and n % save_every == 0:
            wInd = n // save_every
            if wInd < savWeights.shape[2]:
                savWeights[:, :, wInd] = MOD.W_fb.data.cpu().numpy()

        ### Training data
        # Prepare batch
        Im = MOD.generateBatch(par.metricLossType, X_tr)

        # Get response to input
        if MOD.metricLossType=='tripletLoss':
            r = MOD.response(Im, tripletFlag=True)
        else:
            r = MOD.response(Im)

        # Snapshot the weights right before the optimiser step whose change is recorded.
        # Previously the "before" copy was taken after step n and the "after" copy after
        # step n+1; with save_every == 1 both were taken in the same iteration (dW == 0).
        trackDW = par.saveFlag_DW and n % save_every == 0 and dw_saveind < par.nWeightSave
        if trackDW:
            with torch.no_grad():
                wBefore = _snapshotWeights()

        # Compute loss and accuracy, and take one optimiser step
        loss, accTrTa, accTrClCe = MOD.getLossAccuracy('metric',r,backwardFlag=True, opt=MOD.metOpt)

        #### CHECKING FOR NANS
        if torch.any(torch.isnan(MOD.x[-1])):
            print(f'Quitting at batch {n} of {par.nEpisodes}')
            return -1

        # Store training loss and accuracy
        L_tr[0][n]=np.copy(np.array(loss.to('cpu').detach()))
        A_trTa[0][n]=np.copy(np.array(accTrTa.to('cpu').detach()))
        A_trClCe[0][n]=np.copy(np.array(accTrClCe.to('cpu').detach()))

        # Store weight changes for effective learning rate
        if trackDW:
            with torch.no_grad():
                for k, (w1, w2) in enumerate(zip(wBefore, _snapshotWeights())):
                    dw = torch.abs(torch.subtract(w2, w1)).cpu()
                    savDW[k][:, dw_saveind] = torch.histogram(dw[dw>0], dwEdges).hist.numpy()
            dw_saveind += 1

        ### Validation data
        with torch.no_grad():
            Im = MOD.generateBatch(par.metricLossType, X_val)

            # Get response to input
            if MOD.metricLossType=='tripletLoss':
                r = MOD.response(Im, tripletFlag=True)
            else:
                r = MOD.response(Im)

            # Compute loss and accuracy
            loss, accValTa, accValClCe = MOD.getLossAccuracy('metric',r)

            # Store validation loss and accuracy
            L_val[0][n]=np.copy(np.array(loss.to('cpu').detach()))
            A_valTa[0][n]=np.copy(np.array(accValTa.to('cpu').detach()))
            A_valClCe[0][n]=np.copy(np.array(accValClCe.to('cpu').detach()))

        # Update learning rate (exponential decay from etaInitial with time constant eta_tau)
        MOD.metOpt.param_groups[0]['lr'] = par.etaInitial * np.exp(-float(n)/par.eta_tau)

        if ((n+1)%par.reportTime==0) and (n>0):
            print(f'Time per stage: {time.time()-t}')
            window = slice((n+1)-par.reportTime, n+1)
            mseTr_mean=np.mean(L_tr[0][window])
            accTrTa_mean=np.mean(A_trTa[0][window])
            accTrClCe_mean=np.mean(A_trClCe[0][window])
            mseVal_mean=np.mean(L_val[0][window])
            accValTa_mean=np.mean(A_valTa[0][window])
            accValClCe_mean=np.mean(A_valClCe[0][window])
            t=time.time()

            print(f'Progress: {np.float32(n+1)/np.float32(par.nEpisodes)*100.:.3}%   Tr-Loss: {mseTr_mean}, Val-Loss: {mseVal_mean};  Tr-AccTa: {accTrTa_mean}, Val-AccTa: {accValTa_mean}')
            print(f'          Tr-AccClCe: {accTrClCe_mean}, Val-AccClCe: {accValClCe_mean}')

    ### Save hidden layer representations for transfer learning
    H_tr = []
    H_val = []
    with torch.no_grad():
        for j in range(par.nClass):
            MOD.response(X_tr[j])
            H_tr.append(MOD.x[-1])
            MOD.response(X_val[j])
            H_val.append(MOD.x[-1])

    ### Save outputs
    torch.save(L_tr, outputDir + '/' + 'lossTr'+str(rngSeed)+'.pt')
    torch.save(A_trTa, outputDir + '/' + 'accTrTa'+str(rngSeed)+'.pt')
    torch.save(A_trClCe, outputDir + '/' + 'accTrClCe'+str(rngSeed)+'.pt')
    torch.save(L_val, outputDir + '/' + 'lossVal'+str(rngSeed)+'.pt')
    torch.save(A_valTa, outputDir + '/' + 'accValTa'+str(rngSeed)+'.pt')
    torch.save(A_valClCe, outputDir + '/' + 'accValClCe'+str(rngSeed)+'.pt')
    torch.save(RESP, outputDir + '/' + 'respSave'+str(rngSeed)+'.pt')
    torch.save(MOD, outputDir + '/' + 'model'+str(rngSeed)+'.pt')
    torch.save(savDW, outputDir + '/' + 'dw'+str(rngSeed)+'.pt')
    torch.save([H_tr, H_val], outputDir + '/' + 'hResp'+str(rngSeed)+'.pt')
    if par.saveFlag_FBWeights:
        torch.save(savWeights, outputDir + '/' + 'weightSave'+str(rngSeed)+'.pt')

    return 1


In [None]:
class MLPtransfer(nn.Module):
    """Feed-forward read-out trained with softmax cross-entropy on saved representations.

    Layer sizes come from par.Ns. Hidden layers use ReLU; the last layer is linear
    (logits). Loss and accuracy assume class-ordered batches: rows
    [j*nSampPerClassPerBatch, (j+1)*nSampPerClassPerBatch) belong to class j, which is
    how self.target is built.
    """

    def __init__(self,par):
        super().__init__()

        self.N_class=par.nClass
        self.batch_size = par.batch_size
        self.nSampPerClassPerBatch = int(par.batch_size/par.nClass) # No. input samples per class, per batch
        self.lossfn = nn.BCEWithLogitsLoss()  # NOTE(review): not used by this class
        # Setup target vector to compute Loss and Accuracy (class-ordered batch)
        self.target = torch.zeros(self.batch_size).long().to(device)
        for j in range(1,self.N_class):
            self.target[j*self.nSampPerClassPerBatch:(j+1)*self.nSampPerClassPerBatch] = j

        self.N_i = par.nInputs
        self.lossLayer = len(par.Ns)-1
        self.etaInitial = par.etaInitial
        self.etaTransfer = par.etaTransfer

        self.Ns = par.Ns

        self.x = []  # Per-layer activations; rebuilt by Reset, filled by Forward

        # nn.ParameterList registers the weights with the Module, so parameters(),
        # state_dict() and .to() see them (a plain Python list hides them).
        self.Ws = nn.ParameterList()
        self.bs = nn.ParameterList()
        for n in range(1, len(self.Ns)):
            # Weights ~ N(0, 1/(fan_in + fan_out)), biases zero
            scale = torch.sqrt(torch.tensor(self.Ns[n-1]+self.Ns[n]))
            self.Ws.append(nn.Parameter((torch.randn([self.Ns[n-1],self.Ns[n]])/scale).to(device)))
            self.bs.append(nn.Parameter(torch.zeros([self.Ns[n]]).to(device)))

    def transferOptimiser(self):
        """Create self.tranOpt: Adam over all weights and biases with lr = etaTransfer."""
        self.tranOpt=optim.Adam([{ 'params': list(self.Ws)+list(self.bs), 'lr':self.etaTransfer }])

    def Forward(self, input):
        """One feed-forward pass: ReLU hidden layers, linear output stored in self.x[-1]."""
        self.x[0] = torch.clone(input)
        ### Hidden layers
        for n in range(1,len(self.Ns)-1): # For each layer
            self.x[n] = torch.relu( torch.add(torch.matmul(self.x[n-1],self.Ws[n-1]),self.bs[n-1]) )

        ### Output (logits)
        self.x[-1] = torch.add(torch.matmul(self.x[-2],self.Ws[-1]),self.bs[-1])

    def Reset(self, nSamples):
        """Replace self.x with zero activations of shape (nSamples, Ns[n]) for every layer."""
        self.x = []
        for n in range(0,len(self.Ns)): # For each layer
            self.x.append(torch.zeros(nSamples, self.Ns[n], requires_grad=True).to(device))

    def lossCrossEntropy(self, nSamples):
        """Mean softmax cross-entropy of the logits self.x[-1] against self.target.

        Uses log_softmax: the former exp(x)/sum(exp(x)) overflowed to inf/inf = NaN for
        large logits.
        """
        logProb = torch.log_softmax(self.x[-1], dim=1)
        rows = torch.arange(nSamples, device=logProb.device)
        return torch.mean(-logProb[rows, self.target])

    def accuracyTarget(self):
        """Fraction of rows whose argmax logit equals self.target (0-dim float tensor)."""
        acc = torch.mean(torch.eq( torch.argmax(self.x[-1],1), self.target ).type(torch.float))
        return acc

    def accuracyClassCentroid(self):
        """Nearest-centroid accuracy of self.x[-1] with per-class mean centroids.

        Row j has true class j // nSampPerClassPerBatch and is correct when its nearest
        centroid (squared Euclidean distance) has that index.
        """
        out = self.x[-1]
        nPer = self.nSampPerClassPerBatch
        centroids = torch.stack([out[j*nPer:(j+1)*nPer, :].mean(0) for j in range(self.N_class)])
        # Squared distance from every sample to every centroid: (nSamples, N_class)
        dist = (out.unsqueeze(1) - centroids.unsqueeze(0)).pow(2).sum(2)
        predicted = dist.argmin(1)
        trueClass = torch.arange(out.shape[0], device=out.device) // nPer
        return (predicted == trueClass).float().mean()

    def getLossAccuracy(self, backwardFlag=False, opt=None):
        """Return (loss, target accuracy, centroid accuracy); optionally take one optimiser step.

        Raises:
            ValueError: if backwardFlag is True but no optimiser is given.
        """
        L = self.lossCrossEntropy(self.x[-1].shape[0])
        accTa = self.accuracyTarget()
        accClCe = self.accuracyClassCentroid()

        if backwardFlag:
            if not opt:
                raise ValueError('backwardFlag=True requires an optimiser (opt)')
            L.backward()
            opt.step()
            opt.zero_grad()

        return L,accTa,accClCe

    def response(self, Input):
        """Run Forward on every time slice Input[:, :, t]; results are left in self.x."""
        N_samples = Input.shape[0]
        T = Input.shape[2]

        self.Reset(N_samples)

        for t in range(T):
            self.Forward(Input[:, :, t])

    def responseSave(self, Input,saveLayers=[]):
        """Like response, but also returns (nSamples, Ns[l], T) copies of the layers in saveLayers."""
        N_samples = Input.shape[0]
        T = Input.shape[2]

        sav = []
        for l in saveLayers:
            sav.append(torch.zeros(N_samples, self.Ns[l], T))

        self.Reset(N_samples)
        for t in range(T):
            # Forward takes a single input (it was called with an extra `t` -> TypeError)
            self.Forward(Input[:, :, t])

            for li, l in enumerate(saveLayers):
                sav[li][:,:,t] = torch.clone(self.x[l].detach())

        return sav

In [None]:
###############################
#### Transfer Learning
###############################
def tran_esn_fb(expName, rngSeed):
    """Train an MLPtransfer read-out on hidden representations saved by an earlier run.

    Loads <directory>/data/<expName>/hResp<rngSeed>.pt, a two-element [train, val] list
    (as written by metric_esn_fb), each element indexed by class. It then trains
    MLPtransfer with cross-entropy and writes results to <directory>/data/tran_<expName>/.
    Also reads the module-level `par` and `directory`.

    Args:
        expName: name of the experiment whose representations are loaded.
        rngSeed: seed for torch and the global numpy RNG; also the file-name suffix.

    Returns:
        1 on success. -1 if the output layer produced NaNs; training then stops and
        nothing is saved.
    """
    ### Initialise RNGs
    # np.random.default_rng() does not seed the global RNG used by np.random.randint below
    torch.manual_seed(rngSeed)
    np.random.seed(rngSeed)

    ### Setup directory names
    experiment = expName
    expDir = directory+'/data/'+experiment # Storage directory for inputs
    outputDir = directory+'/data/'+'tran_'+experiment # Storage directory for outputs
    os.makedirs(outputDir, exist_ok=True)

    ### Load inputs (torch.load unpickles: only use on trusted, locally produced files)
    print(f'Loading inputs from: {expDir}/hResp{str(rngSeed)}.pt')
    inputs = torch.load(expDir + '/' + 'hResp'+str(rngSeed)+'.pt')
    X_tr = inputs[0]
    X_val = inputs[1]
    print(f'Shape of X_tr is {X_tr[0].shape}')

    ### Other parameters
    # Save weights every <> episodes, up to episode nSaveMaxT (clamped to >= 1 to avoid modulo-by-zero)
    save_every = max(1, int(np.floor(par.nSaveMaxT / par.nWeightSave)))
    tMax = X_tr[0].shape[2] if len(X_tr[0].shape) > 2 else 0

    ### Initialise model
    MOD = MLPtransfer(par)

    ### Init lists for saving data
    outL_tr = np.zeros([par.nEpisodesTran])
    outA_trTa = np.zeros([par.nEpisodesTran])
    outA_trClCe = np.zeros([par.nEpisodesTran])
    outL_val = np.zeros([par.nEpisodesTran])
    outA_valTa = np.zeros([par.nEpisodesTran])
    outA_valClCe = np.zeros([par.nEpisodesTran])
    outsavDW = []

    print(f'**********************START TRANSFER TRAINING')

    # Init memory to save responses. outRESP is always defined because it is saved
    # unconditionally at the end (previously a NameError when saveFlag_RESP was off).
    resp_saveind = 0
    outRESP = []
    if par.saveFlag_RESP:
        for layer in range(len(par.Ns)):
            if len(X_tr[0].shape) > 2:
                outRESP.append(np.zeros((len(par.saveRespAtN)+1, par.nClass * par.nSaveSamples, par.Ns[layer], tMax)))
            else:
                outRESP.append(np.zeros((len(par.saveRespAtN)+1, par.nClass * par.nSaveSamples, par.Ns[layer])))

    # Init memory to save dw histograms
    dw_saveind = 0
    dwEdges = torch.logspace(-10,-3,51)
    for k in range(len(par.Ns)-1):
        outsavDW.append(np.zeros((dwEdges.shape[0]-1, par.nWeightSave))) # Histograms for feedforward layers
    if par.fbLayer:
        outsavDW.append(np.zeros((dwEdges.shape[0]-1, par.nWeightSave))) # Histogram for feedback layer

    ### Initialise optimiser
    MOD.transferOptimiser()

    def _snapshotWeights():
        # Detached copies of the tracked weights, in the same order as outsavDW.
        # MLPtransfer has no W_fb, so the feedback entry is only added if one exists.
        ws = [torch.clone(MOD.Ws[k].data) for k in range(len(par.Ns)-1)]
        W_fb = getattr(MOD, 'W_fb', None)
        if par.fbLayer and W_fb is not None:
            ws.append(torch.clone(W_fb.data))
        return ws

    t=time.time()

    for n in range(par.nEpisodesTran):
        # ### Save output responses before updates
        # if par.saveFlag_RESP and (np.any(np.isin(n, par.saveRespAtN)) or n==(par.nEpisodesTran-1)):
        #     print(f'Saving response on iteration {n} of {par.nEpisodesTran}')
        #     with torch.no_grad():
        #         s = X_tr[0][0:par.nSaveSamples,:]
        #         for k in range(1,par.nClass):
        #             s = torch.concat([s, X_tr[k][0:par.nSaveSamples,:]],0)
        #         resp = MOD.responseSave(s, range(len(par.Ns)))
        #         for li, l in enumerate(range(len(par.Ns))):
        #             outRESP[li][resp_saveind,:,:,:] = resp[li].numpy()
        #     resp_saveind += 1

        ### Training data
        # Prepare a class-ordered batch, matching the layout of MOD.target
        rand_ind=np.random.randint(0,X_tr[0].size()[0],(MOD.nSampPerClassPerBatch,))
        Im=X_tr[0][rand_ind,:]
        for k in range(1,par.nClass):
            rand_ind=np.random.randint(0,X_tr[k].size()[0],(MOD.nSampPerClassPerBatch,))
            Im=torch.concat([Im,X_tr[k][rand_ind,:]],0)

        # Get response to input (Forward copies Im into layer 0 itself)
        MOD.Reset(Im.shape[0])
        try:
            MOD.Forward(Im)
        except Exception:
            # Print shapes to diagnose a size mismatch, then re-raise. A bare except used
            # to swallow the error and keep training on stale activations.
            print(f'Size of layer 0 activity: {MOD.x[0].shape}')
            for k in range(len(par.Ns)-1):
                print(f'Size of layer {k} activity: {MOD.x[k+1].shape}')
                print(f'Size of layer {k} weights: {MOD.Ws[k].shape}')
            raise

        # Snapshot the weights right before the optimiser step whose change is recorded
        # (the previous before/after scheme gave dW == 0 when save_every == 1)
        trackDW = par.saveFlag_DW and n % save_every == 0 and dw_saveind < par.nWeightSave
        if trackDW:
            with torch.no_grad():
                wBefore = _snapshotWeights()

        # Compute training loss and accuracy, and take one optimiser step
        loss, accTrTa, accTrClCe = MOD.getLossAccuracy(backwardFlag=True, opt=MOD.tranOpt)

        #### CHECKING FOR NANS
        if torch.any(torch.isnan(MOD.x[-1])):
            print(f'Quitting at batch {n} of {par.nEpisodesTran}')
            return -1

        # Store loss and accuracy
        outL_tr[n]=np.copy(np.array(loss.to('cpu').detach()))
        outA_trTa[n]=np.copy(np.array(accTrTa.to('cpu').detach()))
        outA_trClCe[n]=np.copy(np.array(accTrClCe.to('cpu').detach()))

        # Store weight changes for effective learning rate
        if trackDW:
            with torch.no_grad():
                for k, (w1, w2) in enumerate(zip(wBefore, _snapshotWeights())):
                    dw = torch.abs(torch.subtract(w2, w1)).cpu()
                    outsavDW[k][:, dw_saveind] = torch.histogram(dw[dw>0], dwEdges).hist.numpy()
            dw_saveind += 1

        ### Validation data
        with torch.no_grad():
            # Prepare a class-ordered batch. The former `Y=Y_val[0][rand_ind,:]` was unused
            # and indexed a 1-D tensor with two indices, so it has been removed.
            rand_ind=np.random.randint(0,X_val[0].size()[0],(int(par.batch_size/par.nClass),))
            Im=X_val[0][rand_ind,:]
            for k in range(1,par.nClass):
                rand_ind=np.random.randint(0,X_val[k].size()[0],(int(par.batch_size/par.nClass),))
                Im=torch.concat([Im,X_val[k][rand_ind,:]],0)

            # Get response to input
            MOD.Reset(Im.shape[0])
            MOD.Forward(Im)

            # Compute validation loss and accuracy
            loss,accValTa,accValClCe = MOD.getLossAccuracy()

            # Store loss and accuracy
            outL_val[n] = np.copy(np.array(loss.to('cpu').detach()))
            outA_valTa[n] = np.copy(np.array(accValTa.to('cpu').detach()))
            outA_valClCe[n] = np.copy(np.array(accValClCe.to('cpu').detach()))

        # Update learning rate (exponential decay from etaTransfer with time constant eta_tau)
        MOD.tranOpt.param_groups[0]['lr'] = par.etaTransfer * np.exp(-float(n)/par.eta_tau)

        if ((n+1)%par.reportTime==0) and (n>0):
            print(f'Time per stage: {time.time()-t}')
            window = slice((n+1)-par.reportTime, n+1)
            mseTr_mean=np.mean(outL_tr[window])
            accTrTa_mean=np.mean(outA_trTa[window])
            accTrClCe_mean=np.mean(outA_trClCe[window])
            mseVal_mean=np.mean(outL_val[window])
            accValTa_mean=np.mean(outA_valTa[window])
            accValClCe_mean=np.mean(outA_valClCe[window])
            t=time.time()

            print(f'Progress: {np.float32(n+1)/np.float32(par.nEpisodesTran)*100.:.3}%   Mean Tr Er: {mseTr_mean}, Mean Val Er: {mseVal_mean};   Mean Tr AccTa: {accTrTa_mean}, Mean Val AccTa: {accValTa_mean}')
            print(f'Progress:                                                                               Mean Tr AccClCe: {accTrClCe_mean}, Mean Val AccClCe: {accValClCe_mean}')

    # Final read-out weights and bias of the output layer
    wTran = [torch.clone(MOD.Ws[-1]).detach().cpu().numpy(), torch.clone(MOD.bs[-1]).detach().cpu().numpy()]

    torch.save(outL_tr, outputDir + '/' + 'lossTr'+str(rngSeed)+'.pt')
    torch.save(outA_trTa, outputDir + '/' + 'accTrTa'+str(rngSeed)+'.pt')
    torch.save(outA_trClCe, outputDir + '/' + 'accTrClCe'+str(rngSeed)+'.pt')
    torch.save(outL_val, outputDir + '/' + 'lossVal'+str(rngSeed)+'.pt')
    torch.save(outA_valTa, outputDir + '/' + 'accValTa'+str(rngSeed)+'.pt')
    torch.save(outA_valClCe, outputDir + '/' + 'accValClCe'+str(rngSeed)+'.pt')
    torch.save(outRESP, outputDir + '/' + 'respSave'+str(rngSeed)+'.pt')
    torch.save(outsavDW, outputDir + '/' + 'dw'+str(rngSeed)+'.pt')
    torch.save(wTran, outputDir + '/' + 'wTran'+str(rngSeed)+'.pt')

    return 1

# Test run

In [None]:
### Initial training: a single metric-learning run with a fixed seed
swLR = 0.001     # learning rate for this run (0.00046 was used previously)
seed = 11117     # RNG seed for this run

# Start from the defaults in params_feedback, then override for this run
importlib.reload(par)
par.etaInitial = swLR
par.etaTransfer = 0.001
par.fbLayer = 2
par.saveRepLayer = 2
par.saveFlag_FBWeights = bool(par.fbLayer)   # only save W_fb when a feedback layer exists
par.nEpisodes = 5000
par.nSaveMaxT = par.nEpisodes
par.reportTime = 250
# No. neurons in each layer - FOR METRIC (use [par.N_esn, 100, par.nClass] for cross-entropy)
par.Ns = [par.N_esn, 100, 100]
par.metricLossType = 'prototypicalLoss'

expName = 'test'
# For the cross-entropy variant call xent_esn_fb(expName, seed) instead
metric_esn_fb(expName, seed)

In [None]:
### Test transfer learning on the representations saved by the run above
par.fbLayer = []                      # MLPtransfer has no feedback weights
par.nInputs = 100                     # size of the last layer of the model trained above
par.Ns = [par.nInputs, par.nClass]    # single linear read-out layer
par.etaTransfer = 0.001
par.nEpisodesTran = 1000
tran_esn_fb(expName, seed)

# Run parameter sweep (CrossEntropy)

In [None]:
### Define parameters
swLR = np.logspace(-15/3,-7/3,5) # Sweep over these learning rates
nSeeds = 5 # No. runs per hyperparameter
seeds = list(sp.primerange(11111,33333))[0:nSeeds] # RNG seeds: the first nSeeds primes >= 11111
experiment = 'swLR'


def _xentSweep(expNameSuffix, fbLayer, label):
    """Run xent_esn_fb for every (learning rate, seed) pair and save the return codes.

    Args:
        expNameSuffix: experiment-name prefix; learning-rate index j runs as <expNameSuffix><j>.
        fbLayer: value assigned to par.fbLayer (1 = with feedback, [] = without).
        label: tag shown in the per-run timing printout (e.g. 'FB1').

    Returns:
        (len(swLR), len(seeds)) tensor of xent_esn_fb return values. It is also saved to
        <directory>/data/<expNameSuffix>0/complete.pt.
    """
    complete = torch.zeros((len(swLR), len(seeds)))
    for j, lr in enumerate(swLR):
        for k, sd in enumerate(seeds):
            t = time.time()
            # Start each run from the default hyperparameters, then override
            importlib.reload(par)
            par.nSaveMaxT = par.nEpisodes
            par.etaInitial = lr
            par.fbLayer = fbLayer
            par.saveFlag_FBWeights = bool(par.fbLayer)  # only models with feedback have W_fb
            complete[j, k] = xent_esn_fb(expNameSuffix+str(j), sd)
            print(f'j={j}/{len(swLR)};     k={k}/{len(seeds)}')
            print(f'*************************{label} Run time: {time.time()-t}')
    completeName = directory+'/data/'+expNameSuffix+'0/complete.pt'
    os.makedirs(os.path.dirname(completeName), exist_ok=True)
    torch.save(complete, completeName)
    return complete


modName = 'xenFB1'
expNameSuffix = modName+'_'+experiment+'_'
complete = _xentSweep(expNameSuffix, fbLayer=1, label='FB1')

print('*******************************FINISHED FB1********************************')

modName = 'xenFB0'
expNameSuffix = modName+'_'+experiment+'_'
complete = _xentSweep(expNameSuffix, fbLayer=[], label='FB0')

print('*******************************FINISHED FB0********************************')

# Run parameter sweep (MetricLearning)

In [None]:
### Define parameters
swLR = np.logspace(-15/3,-7/3,5) # Sweep over these learning rates
nSeeds = 2 # No. runs per hyperparameter
seeds = list(sp.primerange(11111,33333))[0:nSeeds] # RNG seeds: the first nSeeds primes >= 11111
experiment = 'swLR'
losstype = 'prototypicalLoss'   # 'tripletLoss' or 'prototypicalLoss'
if losstype=='tripletLoss':
    lossSuffix = 'tri'
elif losstype=='prototypicalLoss':
    lossSuffix = 'pro'
else:
    raise ValueError(f'Unknown losstype: {losstype!r}')


def _metricSweep(expNameSuffix, fbLayer, label):
    """Run metric_esn_fb with loss `losstype` for every (learning rate, seed) pair.

    Args:
        expNameSuffix: experiment-name prefix; learning-rate index j runs as <expNameSuffix><j>.
        fbLayer: value assigned to par.fbLayer (1 = with feedback, [] = without).
        label: tag shown in the per-run timing printout (e.g. 'metFB1').

    Returns:
        (len(swLR), len(seeds)) tensor of metric_esn_fb return values. It is also saved
        to <directory>/data/<expNameSuffix>0/complete.pt.
    """
    complete = torch.zeros((len(swLR), len(seeds)))
    for j, lr in enumerate(swLR):
        for k, sd in enumerate(seeds):
            t = time.time()
            # Start each run from the default hyperparameters, then override
            importlib.reload(par)
            par.Ns = [par.N_esn, 100]
            # Train with the loss selected above. This was hard-coded to 'prototypicalLoss',
            # so losstype='tripletLoss' only changed the folder name, not the loss.
            par.metricLossType = losstype
            par.etaInitial = lr
            par.fbLayer = fbLayer
            par.saveFlag_FBWeights = bool(par.fbLayer)  # only models with feedback have W_fb
            complete[j, k] = metric_esn_fb(expNameSuffix+str(j), sd)
            print(f'j={j}/{len(swLR)};     k={k}/{len(seeds)}')
            print(f'*************************{label} Run time: {time.time()-t}')
    completeName = directory+'/data/'+expNameSuffix+'0/complete.pt'
    os.makedirs(os.path.dirname(completeName), exist_ok=True)
    torch.save(complete, completeName)
    return complete


modName = lossSuffix+'FB1'
expNameSuffix = modName+'_'+experiment+'_'
complete = _metricSweep(expNameSuffix, fbLayer=1, label='metFB1')

print('*******************************FINISHED FB1********************************')

modName = lossSuffix+'FB0'
expNameSuffix = modName+'_'+experiment+'_'
complete = _metricSweep(expNameSuffix, fbLayer=[], label='metFB0')

print('*******************************FINISHED FB0********************************')

# Run transfer learning, using learned representations as inputs

In [None]:
### Transfer learning
swLR = np.logspace(-15/3,-7/3,5) # Sweep over these learning rates
experiment = 'swLR'
lossSuffixes = {'tripletLoss': 'tri', 'prototypicalLoss': 'pro', 'xentropy': 'xen'}
# Learned representations to transfer from: (loss they were learned with, no. seeds per learning rate).
transferSources = (('xentropy', 5), ('prototypicalLoss', 2))

for losstype, nSeeds in transferSources:
    lossSuffix = lossSuffixes[losstype]
    seeds = list(sp.primerange(11111,33333))[0:nSeeds] # RNG seeds
    # Representations learned with feedback into layer 1 (FB1) and without feedback (FB0)
    for fbTag in ('FB1', 'FB0'):
        modName = lossSuffix+fbTag
        expNameSuffix = modName+'_'+experiment+'_'
        complete = torch.zeros((len(swLR), len(seeds)))
        # Per-run results of the source sweep, read from the same location that sweep saves to
        # (this used to be a hard-coded absolute path under /its/home/jb739/esn_feedback).
        sourceName = directory+'/data/'+expNameSuffix+'0/complete.pt'
        c = torch.load(sourceName)
        # Fail before any (expensive) run instead of with an IndexError part-way through the sweep.
        if c.shape[0] < len(swLR) or c.shape[1] < len(seeds):
            raise ValueError(f'{sourceName} has shape {tuple(c.shape)}; need at least ({len(swLR)}, {len(seeds)})')
        for j, lr in enumerate(swLR):
            for k, sd in enumerate(seeds):
                t = time.time()  # reset every iteration so skipped runs do not report a stale time
                if c[j,k]>0:
                    # Source run has a positive entry (presumably it completed): transfer from it
                    importlib.reload(par)
                    par.Ns = [100, par.nClass]
                    par.nInputs = 100
                    par.fbLayer = []
                    expName = expNameSuffix+str(j)
                    complete[j, k] = tran_esn_fb(expName, sd)
                else:
                    complete[j, k] = -1  # mark as skipped
                print(f'j={j}/{len(swLR)};     k={k}/{len(seeds)}')
                print(f'*************************tran_{modName} Run time: {time.time()-t}')
        completeName = directory+'/data/tran_'+expNameSuffix+'0/complete.pt'
        torch.save(complete, completeName)

### Run several transfer-learning repeats per trained network, to check whether performance is limited by training getting stuck in local minima.

In [None]:
### Define parameters
# NOTE(review): this 7-point grid differs from the 5-point np.logspace(-15/3,-7/3,5) grid used by
# the sweeps directly above; presumably it matches an earlier sweep — confirm the intended LR.
swLR = [np.logspace(-5,-7/3,7)[4]]  # a single learning rate: point 4 of that grid
nSeeds = 1 # No. runs per hyperparameter

seeds = list(sp.primerange(11111,33333))[0:nSeeds] # RNG seeds

complete = torch.zeros((len(swLR), len(seeds)))
for k, sd in enumerate(seeds):
    t = time.time()
    # Reload params so the run starts from the values defined in params_feedback.py
    importlib.reload(par)
    par.maxLayer = 3
    par.lossLayer = 3
    par.nSaveMaxT = par.nEpisodes
    par.etaInitial = swLR[0]
    par.fbLayer = 2
    par.saveFlag_FBWeights = bool(par.fbLayer)  # feedback is on, so FB weights are saved
    par.nTransferRuns = 10  # presumably the number of transfer-learning repeats per trained network
    expName = 'FB2_LocMin'
    # NOTE(review): this replaces the `complete` tensor rather than filling complete[0, k] as the
    # sweeps above do; harmless with nSeeds == 1, but with more seeds only the last result is kept.
    complete = xent_esn_fb(expName, sd)
    print(f'k={k}/{len(seeds)}')
    print(f'*************************FB2_LocMin Run time: {time.time()-t}')

print('*******************************FINISHED FB2********************************')


# Plot data

In [None]:
# ### Load Triplet data
# inputDir = directory+'/data/'+experiment    # Storage directory for input/label data
# met = torch.load(inputDir + 'met_save.pt')
# l_tri = torch.load(inputDir + 'loss_triplet.pt')
# l_out = torch.load(inputDir + 'loss_out.pt')
# a_tri = torch.load(inputDir + 'acc_triplet.pt')
# a_out = torch.load(inputDir + 'acc_out.pt')
# dist = torch.load(inputDir + 'dist_triplet.pt')
# nt_dist = dist.shape[0] 
# nt_out = l_out.shape[0]
# nt_met = met.shape[0]
# N_triplet=45000
# N_out=5000
# batch_size=64
# eta_t=0.0002
# eta_o = 0.001
# eta_t_tau = 40000.0
# eta_o_tau = 4000.0
# N_class=10
# margin=2
# save_N = 100 #100 # # of saved epochs
# save_every = np.floor(N_triplet / save_N) # Save data every <> epohsChoice 3
# save_Nsamples = 100 # # of inputs from each class for which to save resonses

### Load Classic data
experiment = 'met_FB0_swLR_0'
expDir = directory+'/data/'+experiment # To read in saved data
sd_name='11113'  # seed of the run to load (presumably the first prime from sp.primerange(11111, 33333))
inputDir = expDir#+'/'+sd_name  
figDir = directory+'/figs/met/'+experiment # To export figures
print(figDir)
# makedirs also creates the intermediate figs/ and figs/met/ directories, which os.mkdir required
# to exist already; exist_ok keeps re-running this cell harmless.
os.makedirs(figDir, exist_ok=True)

# NOTE(review): recent torch versions default torch.load to weights_only=True, which rejects
# pickled non-tensor objects; if these files hold e.g. numpy arrays, weights_only=False may be needed.
def _load_run(name):
    """Load <inputDir>/<name><sd_name>.pt for the selected seed."""
    return torch.load(inputDir + '/' + name + sd_name + '.pt')

RESP = _load_run('respSave')
print(f'Num layers = {len(RESP)}')
lossTr = _load_run('lossTr')
accTr = _load_run('accTr')
accVal = _load_run('accVal')
weights = _load_run('weightSave')
nt = RESP[0].shape[-1]
kernel = np.ones(50)/50  # 50-sample moving-average window

### Set figure properties
import matplotlib
matplotlib.rcParams['savefig.dpi'] = 300
matplotlib.rcParams.update({'font.size': 6})
matplotlib.rcParams['svg.fonttype'] = 'none'
matplotlib.rcParams['savefig.format'] = 'svg'
matplotlib.rcParams['font.family'] = 'sans-serif'

In [None]:
# Display the shape of the loaded feedback-weight array. The weight-evolution plot in the next
# cell indexes it as weights[::2, j, :], so it must have at least three dimensions.
weights.shape

In [None]:
# Toggle: set saveflag = True to write this cell's figures as SVG files.
# saveflag = True
saveflag = False

# fig = pl.figure(figsize=tuple(np.array((6.,4.))/2.54)); ax = pl.axes()
# ax.spines[['top','right']].set_visible(False)
# layer = 2
# c=0
# times=[0, 4]

# ### Plot responses
# # t=0
# # for i in range(c*par.nClass,(c+1)*par.nClass):
# #             pl.plot(RESP[layer][t,i,::10,:].transpose(), linewidth=0.5)
# for c in [0, 1, 9]:
#     for j, t in enumerate(times):
#         for i in range(c*par.nClass,(c+1)*par.nClass):
#             pl.plot(RESP[layer][t,i,:,:].transpose(), linewidth=0.5)
#         if j==0:
#             ax.xaxis.set_ticks((0,nt)); ax.yaxis.set_ticks((0,0.2))
#         elif j==1:
#             ax.xaxis.set_ticks((0,nt)); ax.yaxis.set_ticks((0,10,20))
#         if saveflag:
#             pl.savefig(figDir+'/resp_c'+str(c)+'t'+str(j)+'.svg', format="svg")
#         ax.cla() 

# ### Plot accuracy
# fig = pl.figure(figsize=tuple(np.array((6.,4.))/2.54)); ax = pl.axes()
# ax.spines[['top','right']].set_visible(False)
# pl.plot(np.linspace(1,par.nEpisodes,par.nEpisodes),np.convolve(np.pad(accTr, (50,50), mode='edge'),kernel,mode='same')[50:-50], linewidth=1.0)
# pl.plot(np.linspace(1,par.nEpisodes,par.nEpisodes),np.convolve(np.pad(accVal, (50,50), mode='edge'),kernel,mode='same')[50:-50], linewidth=1.0)
# # pl.plot(accTr, linewidth=0.5); pl.plot(accVal, linewidth=0.5)
# ax.xaxis.set_ticks((0,par.nEpisodes)); ax.yaxis.set_ticks((0,1))
# if saveflag:
#     pl.savefig(figDir+'/accTrVal.svg', format="svg")

### Plot Feedback Weight Evolution
# NOTE(review): `par` reflects the last configuration reloaded/edited in this kernel (e.g. by a
# sweep cell above), not necessarily the run loaded into `weights` — confirm the flag matches it.
if par.saveFlag_FBWeights:
    fig = pl.figure(figsize=tuple(np.array((20.,5.))/2.54)); ax = pl.axes()
    ax.spines[['top','right']].set_visible(False)
    # Every 100th index along axis 1 (capped at 1000 and at the array size) and every 2nd entry of
    # axis 0; the last axis is plotted against the nSave+1 save points.
    for j in range(0, min(1000, weights.shape[1]), 100):
        pl.plot(np.linspace(1,par.nEpisodes,par.nSave+1),weights[::2,j,:].transpose(), linewidth=0.5)
    # Axis formatting and saving belong to this figure, so they stay inside the guard; outside it
    # they touched a stale (or undefined) `ax` and could save an unrelated figure as fbWeights.svg.
    ax.xaxis.set_ticks((0,par.nEpisodes)); ax.yaxis.set_ticks((-.05,0,.05))
    if saveflag:
        pl.savefig(figDir+'/fbWeights.svg', format="svg")

# ### Correlation between met responses
# layer = 1
# fig = pl.figure(figsize=tuple(np.array((4.,4.))/2.54)); ax = pl.axes()
# # imdata = ax.imshow(stats.zscore(np.reshape(RESP[1][0,:,:,:],[RESP[1].shape[1], RESP[1].shape[2]*RESP[1].shape[3]]), 1).transpose(),vmin=-1.0, vmax=1.0)
# c0 = np.matmul(stats.zscore(np.reshape(RESP[layer][0,:,:,-1],[RESP[layer].shape[1], RESP[layer].shape[2]]), 1), stats.zscore(np.reshape(RESP[layer][0,:,:,-1],[RESP[layer].shape[1], RESP[layer].shape[2]]), 1).transpose()) / (RESP[layer].shape[2])
# imdata = ax.imshow(c0,vmin=-1.0, vmax=1.0)
# cb = fig.colorbar(imdata, ticks=[-1.0, 0.0, 1.0])
# if saveflag:
#     pl.savefig(directory+'/figs/'+prefix+'met0_corr.svg', format="svg")
# fig = pl.figure(figsize=tuple(np.array((4.,4.))/2.54)); ax = pl.axes()
# c1 = np.matmul(stats.zscore(np.reshape(RESP[layer][19,:,:,-1],[RESP[layer].shape[1], RESP[layer].shape[2]]), 1), stats.zscore(np.reshape(RESP[layer][19,:,:,-1],[RESP[layer].shape[1], RESP[layer].shape[2]]), 1).transpose()) / (RESP[layer].shape[2])
# imdata = ax.imshow(c1,vmin=-1.0, vmax=1.0)
# cb = fig.colorbar(imdata, ticks=[-1.0, 0.0, 1.0])
# if saveflag:
#     pl.savefig(directory+'/figs/'+prefix+'met1_corr.svg', format="svg") 


### Plot triplet distances, loss and accuracy, and output-layer loss and accuracy

In [None]:
import matplotlib
from matplotlib import pyplot as pl
# saveflag = True
saveflag = False
tripletflag = True
# tripletflag = False
# NOTE(review): this cell needs dist/nt_dist/l_tri/a_tri (when tripletflag) and l_out/nt_out/a_out,
# which are only defined by the commented-out "Load Triplet data" block in the data-loading cell
# above; uncomment that block first, otherwise this cell raises NameError on a fresh kernel.

kernel = np.ones(50)/50  # 50-sample moving-average window
prefix = experiment#'classic_original_'
matplotlib.rcParams['savefig.dpi'] = 300
matplotlib.rcParams['axes.labelsize'] = 6
matplotlib.rcParams['svg.fonttype'] = 'none'
matplotlib.rcParams['savefig.format'] = 'svg'
matplotlib.rcParams['font.family'] = 'sans-serif'


def _new_axes():
    """Create a 6 x 4 cm figure with the top and right spines hidden; return (fig, ax)."""
    fig = pl.figure(figsize=tuple(np.array((6.,4.))/2.54)); ax = pl.axes()
    ax.spines[['top','right']].set_visible(False)
    return fig, ax


def _smooth(y):
    """Moving average of y with `kernel`; edge padding stops the ends sagging towards zero."""
    pad = len(kernel)
    return np.convolve(np.pad(y, (pad,pad), mode='edge'),kernel,mode='same')[pad:-pad]


def _plot_raw_and_smoothed(ax, y, nt):
    """Plot y (thin line) and its moving average (thick line) against 1..nt."""
    x = np.linspace(1,nt,nt)
    ax.plot(x, y, linewidth=0.5)
    ax.plot(x, _smooth(y), linewidth=1.0)


def _save(fname):
    """Save the current figure to <directory>/figs/<prefix><fname> when saveflag is set."""
    if saveflag:
        pl.savefig(directory+'/figs/'+prefix+fname, format="svg")


# x values are linear in iteration number, so the x limits below are (0, nt); the previous
# np.log10(nt) upper limit was a leftover from a log-x version of these plots.

###Plot distances
if tripletflag:
    fig, ax = _new_axes()
    x = np.linspace(1,nt_dist,nt_dist)
    ax.plot(x, dist[:,0], linewidth=0.5)
    ax.plot(x, dist[:,1], linewidth=0.5)
    ax.set_xlim(0, nt_dist); ax.set_ylim(0.0, 20)
    ax.xaxis.set_ticks((0,nt_dist)); ax.yaxis.set_ticks((0,20))
    _save('distances.svg')

###Plot triplet loss
if tripletflag:
    fig, ax = _new_axes()
    _plot_raw_and_smoothed(ax, l_tri, nt_dist)
    ax.set_xlim(0, nt_dist); ax.set_ylim(0.0, 1.0)
    ax.xaxis.set_ticks((0,nt_dist)); ax.yaxis.set_ticks((0,1))
    ax.tick_params(axis='both', which='major', labelsize=6.0)
    _save('loss_triplet.svg')

###Plot triplet accuracy
if tripletflag:
    fig, ax = _new_axes()
    _plot_raw_and_smoothed(ax, a_tri*100, nt_dist)
    ax.set_xlim(0, nt_dist); ax.set_ylim(0.0, 100.)
    ax.xaxis.set_ticks((0,nt_dist)); ax.yaxis.set_ticks((0,100))
    ax.tick_params(axis='both', which='major', labelsize=6.0)
    _save('acc_triplet.svg')

###Plot output loss
fig, ax = _new_axes()
_plot_raw_and_smoothed(ax, l_out, nt_out)
ax.set_xlim(0, nt_out); ax.set_ylim(0.0, 0.8)
ax.xaxis.set_ticks((0,nt_out)); ax.yaxis.set_ticks((0.0,0.8))
ax.tick_params(axis='both', which='major', labelsize=6.0)
ax.set_xticklabels((0, nt_out), fontdict={'family': 'sans-serif', 'size':5})
# Labels must match the tick positions set above (the 0.0 tick was previously labelled 0.5).
ax.set_yticklabels((0.0,0.8), fontdict={'family': 'sans-serif', 'size':5})
_save('loss_out.svg')

###Plot output accuracy
fig, ax = _new_axes()
_plot_raw_and_smoothed(ax, a_out*100, nt_out)
ax.set_xlim(0, nt_out); ax.set_ylim(0.0, 100.)
ax.xaxis.set_ticks((0,nt_out)); ax.yaxis.set_ticks((0,100))
ax.tick_params(axis='both', which='major', labelsize=6.0)
_save('acc_out.svg')

### Plot met responses over training

In [None]:
from matplotlib import pyplot as pl
from scipy import stats
pl.rcParams['savefig.dpi'] = 400

saveflag = True
# saveflag = False

# NOTE(review): met, nt_met and save_every are only defined by the commented-out
# "Load Triplet data" block in the data-loading cell above; uncomment it before running this cell.


def _met_zscore(idx):
    """Return met[idx] squeezed to 2-D and z-scored along axis 0."""
    return stats.zscore(np.squeeze(met[idx,:,:]),0)


def _met_image(data, size, vmin, vmax, ticks, fname):
    """Show `data` with imshow and a colorbar in a new figure (size in cm); save as SVG if saveflag."""
    fig = pl.figure(figsize=tuple(np.array(size)/2.54)); ax = pl.axes()
    imdata = ax.imshow(data,vmin=vmin, vmax=vmax)
    cb = fig.colorbar(imdata, ticks=ticks)
    if saveflag:
        pl.savefig(directory+'/figs/'+prefix+fname, format="svg")
    return fig, ax, imdata, cb


### Raw met responses
fig = pl.figure(figsize=tuple(np.array((6.,5.))/2.54)); ax = pl.axes()
ax.spines[['top','right']].set_visible(False)
plmet = np.squeeze(met[0,:,199])
tMax = nt_met*save_every
# NOTE(review): x has nt_met points while met[0,:,199] has met.shape[1] entries; these only
# agree when those sizes match — confirm the intended indexing of met here.
pl.plot(np.linspace(1,nt_met,nt_met)*save_every,plmet)
ax.set_xlim(left=0, right=tMax); ax.set_ylim(bottom=0.0, top=1.05*np.max(plmet))
ax.xaxis.set_ticks((0,tMax)); ax.yaxis.set_ticks((0,))
pl.show()

### z-scored met responses at the first and last entries along met's axis 0
fig, ax, imdata, cb = _met_image(_met_zscore(0), (20.,5.), 0.0, 2.0, [0., 2.0], 'met0.svg')
fig, ax, imdata, cb = _met_image(_met_zscore(-1), (20.,5.), 0.0, 2.0, [0., 2.0], 'met1.svg')

### Correlation between met responses (z.T @ z / N of the column-wise z-scores)
z0 = _met_zscore(0)
c0 = np.matmul(z0.transpose(), z0) / met.shape[1]
fig, ax, imdata, cb = _met_image(c0, (5.,5.), -1.0, 1.0, [-1.0, 0.0, 1.0], 'met0_corr.svg')
z1 = _met_zscore(-1)
c1 = np.matmul(z1.transpose(), z1) / met.shape[1]
fig, ax, imdata, cb = _met_image(c1, (5.,5.), -1.0, 1.0, [-1.0, 0.0, 1.0], 'met1_corr.svg')


# Plot weights