In [None]:
import torch
import numpy as np
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader

from sklearn.metrics.pairwise import euclidean_distances
from torchsummary import summary
from tensorboardX import SummaryWriter
import argparse
import torchvision.utils as vutils
import matplotlib.pyplot as plt

import datetime
import os
class arg_struct:
    """Hard-coded run configuration, used in place of argparse.

    A single instance, ``args``, is read as a global by every helper and
    model class below.
    """

    def __init__(self):
        self.batch_size = 1024          # mini-batch size for the DataLoader
        self.latent_shape = [40, 40]    # latent units laid out on a 2-D grid
        self.n_epochs = 600             # training epochs per run
        self.sampling = 'poisson'       # 'poisson', 'bernoulli' or 'none' (see VAE.forward)
        self.n_samples = 20             # Poisson rate scale / number of Bernoulli draws
        self.layers = [20, 40]          # hidden widths of the Encoder
        self.cuda = True                # move model and data to the GPU
        self.sigma = 2.0                # distance scale in lateral_effect
        self.eta = 0.00001              # Adam learning rate
        self.lateral = 'mexican'        # 'mexican', 'rbf', anything else -> no lateral term
        self.lambda_l = 20              # weight of the lateral loss term
        self.default_rate = 6.0         # the expected number of spikes in each bin
        self.save_path = '/home/jts3256/projects/stimModel/models'
        self.dropout = 93               # dropout probability, as a percentage

args = arg_struct()



In [None]:
# Small numerical-stability constant (not referenced elsewhere in this notebook view).
EPS = 1e-6

# Announce GPU use up front; the flag itself comes from the config cell.
if args.cuda:
    print("Using CUDA")

def locmap(shape=None):
    '''
    Return the (x, y) grid coordinate of every latent neuron.

    :param shape: (n_x, n_y) grid size; defaults to ``args.latent_shape``.
    :return: float32 array of shape (n_x * n_y, 2). Row k is
        (k % n_x, k // n_x): x varies fastest, following np.meshgrid's
        default 'xy' indexing.
    '''
    if shape is None:
        shape = args.latent_shape
    x = np.arange(0, shape[0], dtype=np.float32)
    y = np.arange(0, shape[1], dtype=np.float32)
    xv, yv = np.meshgrid(x, y)
    # ravel() is the same row-major flattening as the original reshape to (size, 1).
    return np.column_stack((xv.ravel(), yv.ravel()))


def lateral_effect():
    '''
    Build the lateral-interaction matrix between latent neurons.

    Pairwise grid distances between neuron locations (from locmap) are
    divided by ``args.sigma`` (giving d) and passed through the kernel
    chosen by ``args.lateral``:

    * ``'mexican'``: (1 - d^2/2) * exp(-d^2/2)
    * ``'rbf'``:     exp(-d^2/2)

    The identity is subtracted from the kernel matrix to remove each
    neuron's self-interaction. Any other ``args.lateral`` value prints a
    message and returns an all-zero float32 matrix.

    :return: (N, N) array, where N is the number of latent neurons
    '''
    locations = locmap()
    scaled_dist = euclidean_distances(locations, locations) / args.sigma
    sq_dist = np.square(scaled_dist)

    # Compare strings with ==, not `is`: identity of equal strings depends on
    # interning, and Python 3.8+ warns about `is` with a literal.
    if args.lateral == 'mexican':
        S = (1.0 - 0.5 * sq_dist) * np.exp(-0.5 * sq_dist)
        return S - np.eye(len(locations))

    if args.lateral == 'rbf':
        S = np.exp(-0.5 * sq_dist)
        return S - np.eye(len(locations))

    print('no lateral effect is chosen')
    return np.zeros(scaled_dist.shape, dtype=np.float32)


class Encoder(nn.Module):
    """Map an input vector to non-negative latent rates.

    Three bias-free linear stages, input_size -> args.layers[0] ->
    args.layers[1] -> latent_size (a global), with a ReLU only after the
    last stage.

    NOTE(review): with no nonlinearity between the first two stages
    (Tanh/Softplus variants were commented out in earlier versions), the
    three linear maps compose into a single linear map before the ReLU.
    """

    def __init__(self, input_size):
        super().__init__()
        hidden_a, hidden_b = args.layers[0], args.layers[1]
        # Sequential wrappers keep the state_dict keys (layerN.0.weight) stable.
        self.layer1 = nn.Sequential(nn.Linear(input_size, hidden_a, bias=False))
        self.layer2 = nn.Sequential(nn.Linear(hidden_a, hidden_b, bias=False))
        self.layer3 = nn.Sequential(
            nn.Linear(hidden_b, latent_size, bias=False),
            nn.ReLU(),
        )

    def forward(self, x):
        if args.cuda:
            self.cuda()
        out = x
        for stage in (self.layer1, self.layer2, self.layer3):
            out = stage(out)
        return out


class Decoder(nn.Module):
    """Bias-free linear read-out from latent_size (a global) units to input_size outputs."""

    def __init__(self, input_size):
        super().__init__()
        readout = nn.Linear(latent_size, input_size, bias=False)
        # Wrapped in Sequential so the parameter is stored as layer1.0.weight.
        self.layer1 = nn.Sequential(readout)

    def forward(self, x):
        return self.layer1(x)


class VAE(nn.Module):
    """Encoder -> dropout -> spiking posterior -> linear decoder.

    The encoder's non-negative rates pass through dropout. Then, depending
    on ``args.sampling``, they are used directly (``'none'``), used as
    Poisson means scaled by ``args.n_samples`` (``'poisson'``), or used as
    Bernoulli probabilities (``'bernoulli'``). The distribution (or raw
    rates) from the latest forward pass is stored on ``self.posterior``, so
    :meth:`kl_divergence` and :meth:`lateral_loss` must be called after a
    forward pass.

    NOTE(review): ``Poisson.sample`` / ``Bernoulli.sample`` do not track
    gradients. In those modes the reconstruction loss therefore does not
    reach the encoder, which is trained only through the KL and lateral
    terms. Confirm this is intended.
    """

    _SAMPLING_MODES = ('bernoulli', 'poisson', 'none')

    def __init__(self, encoder, decoder, lateral):
        super(VAE, self).__init__()
        self.encoder = encoder
        self.decoder = decoder
        # Plain attribute, not a registered buffer, so it stays out of
        # state_dict. It is moved to the rates' device lazily in lateral_loss.
        # Not necessarily positive definite.
        self.lateral = torch.from_numpy(lateral).type(torch.FloatTensor)
        self.dropout = nn.Dropout(args.dropout / 100)  # args.dropout is a percentage

    def _check_sampling(self):
        """Raise ValueError for an unsupported ``args.sampling`` value."""
        if args.sampling not in self._SAMPLING_MODES:
            raise ValueError("args.sampling must be one of %s, got %r"
                             % (self._SAMPLING_MODES, args.sampling))

    def forward(self, inputs):
        self._check_sampling()
        if args.cuda:
            self.cuda()
            inputs = inputs.cuda()
        rates = self.encoder(inputs)

        # The offset keeps rates strictly positive. This avoids log(0) in the
        # Poisson KL and a zero row norm in lateral_loss.
        rates = self.dropout(rates) + 0.0001

        if args.sampling == 'bernoulli':
            # NOTE(review): ReLU rates are unbounded above, so probs > 1 are
            # possible here; torch rejects them when argument validation is on.
            self.posterior = torch.distributions.Bernoulli(probs=rates)
            samples = self.posterior.sample([args.n_samples])  # prepends an n_samples dim
            samples = torch.transpose(samples, 0, 1)            # move the draws to dim 1
            return torch.mean(self.decoder(samples), 1)         # average decodings over draws

        if args.sampling == 'poisson':
            self.posterior = torch.distributions.Poisson(rates * args.n_samples)
            samples = self.posterior.sample()
            return self.decoder(samples / args.n_samples)

        # args.sampling == 'none'
        self.posterior = rates
        return self.decoder(rates)

    def kl_divergence(self):
        """Mean element-wise KL(posterior || constant-rate prior); 0.0 for 'none'."""
        self._check_sampling()
        if args.sampling == 'bernoulli':
            # NOTE(review): args.default_rate (6.0 in the config) is not a
            # valid Bernoulli probability; this prior needs a value in [0, 1].
            prior = torch.distributions.Bernoulli(
                probs=torch.ones_like(self.posterior.probs) * args.default_rate)
            return torch.mean(torch.distributions.kl_divergence(self.posterior, prior))

        if args.sampling == 'poisson':
            prior = torch.distributions.Poisson(
                torch.ones_like(self.posterior.mean) * args.default_rate * args.n_samples)
            return torch.mean(torch.distributions.kl_divergence(self.posterior, prior))

        return 0.0

    def _posterior_rates(self):
        """Return the per-unit rates underlying the latest posterior."""
        if args.sampling == 'bernoulli':
            return self.posterior.probs
        if args.sampling == 'poisson':
            return self.posterior.mean
        return self.posterior

    def lateral_loss(self):
        """Negative mean of r^T L r / n_units over L2-normalised rate rows r."""
        self._check_sampling()
        rates = self._posterior_rates()
        # Collapse to (rows, units). Unlike squeeze(), this keeps a batch of
        # one two-dimensional.
        rates = rates.reshape(-1, rates.shape[-1])
        rates = rates / rates.norm(2, 1, keepdim=True)

        # Cache the device copy instead of re-uploading the matrix every batch.
        if self.lateral.device != rates.device:
            self.lateral = self.lateral.to(rates.device)

        # Row-wise quadratic form r_i^T L r_i, i.e. diag(R L R^T), computed
        # without materialising the (rows, rows) product.
        per_row = (rates.mm(self.lateral) * rates).sum(1) / rates.shape[1]
        return -torch.mean(per_row)

    def normalise_weight(self):
        """Rescale every decoder weight column to unit L2 norm, in place."""
        weight = self.decoder.layer1[0].weight.data  # (n_outputs, n_latent)
        col_norms = torch.norm(weight, dim=0, keepdim=True).clamp_min(1e-12)
        weight.div_(col_norms)

    def save(self, path=None):
        """Save ``state_dict()`` to ``path`` (default: ``args.save_path``).

        NOTE(review): elsewhere in this notebook ``args.save_path`` is used
        as a directory, so the default target is likely to fail. Pass an
        explicit file path.
        """
        torch.save(self.state_dict(), args.save_path if path is None else path)

class ConcatDataset(torch.utils.data.Dataset):
    """Pair up several equally indexed datasets.

    Despite the name, this zips rather than concatenates. Item ``i`` is the
    tuple of item ``i`` from each wrapped dataset, and the length is that of
    the shortest one.
    """

    def __init__(self, *datasets):
        self.datasets = datasets

    def __len__(self):
        return min(map(len, self.datasets))

    def __getitem__(self, i):
        picked = []
        for source in self.datasets:
            picked.append(source[i])
        return tuple(picked)

def vaf(x, xhat):
    """Percent variance accounted for by ``xhat``.

    Column means are removed from ``x`` and ``xhat`` separately, so a
    constant per-column offset in the prediction is not penalised. Returns
    100 * (1 - SS_residual / SS_total).
    """
    centred_true = x - x.mean(axis=0)
    centred_pred = xhat - xhat.mean(axis=0)
    residual_ss = np.sum(np.square(centred_true - centred_pred))
    total_ss = np.sum(np.square(centred_true))
    return 100 * (1 - residual_ss / total_ss)

In [None]:
#%%
training_fname = '/home/jts3256/projects/stimModel/training_data/Han_20160315_RW_SmoothNormalizedJointVel_uniformAngDist_50ms.txt'
all_data_fname = '/home/jts3256/projects/stimModel/training_data/Han_20160315_RW_SmoothNormalizedJointVel_50ms.txt'

# [:, :] also acts as a guard: it raises if the file did not parse as a 2-D table.
my_data = np.genfromtxt(training_fname, delimiter=',')[:,:]
my_data_test = np.genfromtxt(all_data_fname, delimiter=',')[:,:]
my_data_test = torch.from_numpy(my_data_test).type(torch.FloatTensor)

# First 21000 rows for training, the remainder held out for evaluation.
train = my_data[:21000]
test = my_data[21000:]

# Autoencoder set-up: the targets are the inputs themselves.
x_tr = torch.from_numpy(train[:,:]).type(torch.FloatTensor)
y_tr = torch.from_numpy(train[:,:]).type(torch.FloatTensor)

x_te = torch.from_numpy(test[:,:]).type(torch.FloatTensor)
y_te = test[:,:]  # kept as numpy for vaf() and plotting

dataloader = DataLoader(ConcatDataset(x_tr, y_tr), batch_size=args.batch_size,
                        shuffle=True)

test_data = (x_te, y_te)

# %Y is the calendar year that matches %m/%d. The previous %G is the ISO
# week-based year, which disagrees with %m/%d around New Year.
x = datetime.datetime.now()
dname = '_' + x.strftime('%Y-%m-%d-%H%M%S')

# File names look like <monkey>_<date>_<task>_..., e.g. Han_20160315_RW_...
split_fname = os.path.basename(all_data_fname)
underscore_fname = split_fname.split('_')

monkey, date_task, task = underscore_fname[:3]



for i_run in range(15):

    latent_size = args.latent_shape[0]*args.latent_shape[1]

    # Each run gets its own timestamp. Reusing one timestamp for all runs made
    # every run target the same directory, so os.mkdir raised FileExistsError
    # on the second run. %Y (calendar year) matches %m/%d, unlike %G.
    run_started = datetime.datetime.now()
    dname = '_' + run_started.strftime('%Y-%m-%d-%H%M%S')

    writer = SummaryWriter()

    input_size = len(x_tr[0])
    output_size = len(y_tr[0])
    encoder = Encoder(input_size=input_size)
    decoder = Decoder(input_size=output_size)
    lateral = lateral_effect()

    vae = VAE(encoder, decoder, lateral)

    # `device` is defined unconditionally: the export at the end of the run
    # needs it even when args.cuda is False.
    device = torch.device("cuda:0" if args.cuda and torch.cuda.is_available() else "cpu")
    if args.cuda:
        vae.cuda()
    vae.to(device)

    criterion = nn.MSELoss()
    optimizer = optim.Adam(vae.parameters(), lr=args.eta)

    pathname = os.path.join(args.save_path, monkey + '_' + date_task + '_' + task + dname)
    os.mkdir(pathname)

    for epoch in range(args.n_epochs):
        print(epoch)
        vae.train()  # dropout active while fitting
        for i_batch, (x_batch, y_batch) in enumerate(dataloader):
            if args.cuda:
                x_batch = x_batch.cuda()
                y_batch = y_batch.cuda()

            yhat = vae(x_batch)
            recon_error = criterion(yhat, y_batch)
            kl = vae.kl_divergence()
            lateral_loss = vae.lateral_loss()
            loss = 10.0*recon_error + args.lambda_l*lateral_loss + kl*0.005  # usually 0.005
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        # Evaluate the held-out rows with dropout disabled and without
        # building an autograd graph over the whole test set.
        vae.eval()
        with torch.no_grad():
            test_result = vae(x_te)
        y_te_hat = test_result.cpu().numpy()

        weight = vae.decoder.layer1[0].weight.data
        w = weight.reshape([-1, 1, args.latent_shape[0], args.latent_shape[1]])
        imgs = vutils.make_grid(w, normalize=True, scale_each=False)
        writer.add_image('Model/Weight', imgs, epoch)

        # The loss scalars are from the last training batch of the epoch.
        # float() also covers kl == 0.0 (a plain float) when sampling is 'none'.
        writer.add_scalar('loss/total_loss', float(loss), epoch)
        writer.add_scalar('loss/kl', float(kl), epoch)
        writer.add_scalar('loss/lateral', float(lateral_loss), epoch)
        writer.add_scalar('loss/recon', float(recon_error), epoch)
        writer.add_scalar('loss/VAF', vaf(y_te, y_te_hat), epoch)

        if epoch % 2 == 0:
            # One subplot per output column, rather than a hard-coded 10.
            n_outputs = y_te.shape[1]
            fig, ax = plt.subplots(n_outputs, 1, squeeze=False)
            for i in range(n_outputs):
                ax[i, 0].plot(y_te[:, i])
                ax[i, 0].plot(y_te_hat[:, i])

            # Rates behind the posterior from the evaluation pass above.
            if args.sampling == 'bernoulli':
                rates = vae.posterior.probs
            elif args.sampling == 'poisson':
                rates = vae.posterior.mean
            else:
                rates = vae.posterior
            rates = torch.squeeze(rates).detach().cpu()
            rates = rates.reshape([-1, 1, args.latent_shape[0], args.latent_shape[1]])
            response = vutils.make_grid(rates[0:1000:50], normalize=True, scale_each=False)
            writer.add_image('Model/Response', response, epoch)
            writer.add_figure('Model/test', fig, epoch)

        writer.flush()

    writer.export_scalars_to_json(os.path.join(pathname, 'all_scalers.json'))
    writer.close()

    #%% run test data set through model and save firing rates

    vae.eval()
    my_data_test = my_data_test.to(device)
    with torch.no_grad():
        rates = vae.encoder(my_data_test)
    rates = rates.cpu().numpy()

    fname = 'rates_' + monkey + '_' + date_task + '_' + task + '_sigma'
    fname = fname + str(args.sigma) + '_drop' + str(args.dropout) + '_lambda' + str(args.lambda_l) + '_learning' + str(args.eta)
    fname = fname + '_n-epochs' + str(args.n_epochs) + '_n-neurons' + str(args.latent_shape[0]*args.latent_shape[1]) + '_rate' + str(args.default_rate)

    np.savetxt(os.path.join(pathname, fname + dname + '.csv'), rates, delimiter=",")
    torch.save(vae.state_dict(), os.path.join(pathname, monkey + '_' + date_task + '_' + task + dname + '_model_params'))