In [None]:
import torch
import numpy as np
from matplotlib import pyplot as plt
import copy
import pickle
import yaml
import os
import plotly.express as px
import pandas as pd
import random
import time
from tqdm.notebook import trange, tqdm
from scipy.integrate import solve_ivp
from scipy import signal


In [None]:
def sigmoid(x,w = 0.1):
    """Logistic squashing function ``1 / (1 + exp(-w * x))``.

    Args:
        x: Scalar or NumPy array to squash.
        w: Gain applied to ``x`` before squashing (default 0.1).

    Returns:
        The logistic of ``w * x``; for array input the result has the
        same shape as ``x``.
    """
    scaled = -w * x
    denominator = 1 + np.exp(scaled)
    return 1 / denominator

def a_to_eps(a):
    """Convert the Rossler parameter ``a`` into the corresponding epsilon.

    This inverts the linear mapping ``a = 0.5 * eps + 9e-5`` that the rest
    of the notebook uses to derive ``a`` from epsilon.

    Args:
        a: Value of the Rossler parameter ``a``.

    Returns:
        The epsilon value that maps to ``a``.  The value is also printed,
        as before; previously the function returned ``None`` so the result
        could not be used programmatically.
    """
    eps = (a - 9 * 1e-5)/0.5
    print("Corresponding value for epsilon:",eps)
    return eps

def eps_to_a(eps):
    """Map epsilon to the Rossler parameter ``a`` via ``a = 0.5 * eps + 9e-5``.

    Args:
        eps: Epsilon value to convert.

    Returns:
        A 2-tuple ``(label, a)`` where ``label`` is a fixed description
        string and ``a`` is the computed parameter value.
    """
    offset = 9 * 1e-5
    a = 0.5 * eps + offset
    label = "Corresponding value for a:"
    return (label, a)

In [None]:
def rossler(t, xyz, a):
    """Right-hand side of a population of mean-field coupled Rossler units.

    The state is laid out as three equal consecutive segments,
    ``[x_1..x_N, y_1..y_N, z_1..z_N]``.  The number of units ``N`` is
    derived from the state length (previously it was hard-coded to 100),
    so the same function works for any population size.

    Args:
        t: Current time.  Unused (the system is autonomous) but required
            by the ``solve_ivp`` callback signature.
        xyz: State vector of length ``3 * N``.  May be 1-D, or a single
            column of shape ``(3 * N, 1)`` as passed by ``solve_ivp`` when
            ``vectorized=True``.
        a: Rossler parameter multiplying ``y`` in ``dy/dt``.

    Returns:
        1-D array ``[dx/dt, dy/dt, dz/dt]`` with the same segment layout
        as ``xyz``.

    Raises:
        ValueError: If the state length is not a multiple of three.
    """
    b = 0.2
    c = 5.7
    k = 0.001  # gain of the term pulling every unit's x towards (d - mean(x))
    Iext = 0
    d = 8

    state = np.asarray(xyz)
    if state.shape[0] % 3 != 0:
        raise ValueError(
            "state length must be a multiple of 3 (x, y, z segments), got "
            + str(state.shape[0])
        )
    n_units = state.shape[0] // 3

    x = state[:n_units]
    y = state[n_units:2 * n_units]
    z = state[2 * n_units:]
    # Mean over every element of the x segment: the only coupling between units.
    x_mean = np.mean(x)

    dxdt = -y - z + k * (d - x_mean) + Iext
    dydt = x + a * y
    dzdt = b + z * (x - c)
    # One concatenate replaces element-wise list building; flatten keeps the
    # 1-D return shape even when the input was a column vector.
    return np.concatenate((dxdt, dydt, dzdt)).flatten()

def STN_GPe(eps):
    """Simulate the coupled Rossler population and return windowed outputs.

    The Rossler parameter is derived from ``eps`` as ``a = 0.5 * eps + 9e-5``.
    A population of 100 units is integrated over t in [0, 200] with LSODA,
    sampled at 100000 evenly spaced points.  Four unit indices are drawn
    (with replacement) and, for each, the second half of its x trajectory is
    passed through ``sigmoid`` and averaged over 100 consecutive windows of
    500 samples.

    Initial conditions and the unit selection come from a private
    ``RandomState(70)``.  This produces the same numbers as the previous
    ``np.random.seed(70)`` followed by global draws, but no longer resets
    NumPy's global generator on every call; resetting it made every later
    ``np.random`` draw in the notebook (for example the loss draw in
    ``IGT.reward``) repeat the same value trial after trial.

    Args:
        eps: Epsilon value controlling the Rossler parameter ``a``.

    Returns:
        Array of shape ``(4, 100)``: one row per selected unit, one column
        per averaging window.

    Raises:
        RuntimeError: If the ODE solver reports failure.
    """
    a = 0.5 * eps + 9 * 1e-5

    # Rossler solver
    t_span = (0, 200) #start to end
    n_samples = 100000
    t_eval = np.linspace(*t_span, n_samples) # dt=0.002, i.e. 500 samples per unit time
    N = 100
    n_selected = 4

    # Local generator: same stream as seeding the global RNG with 70, without
    # the global side effect.
    rng = np.random.RandomState(70)
    x_initial = rng.rand(N)
    y_initial = rng.rand(N)
    z_initial = rng.rand(N)
    initial_vals = np.concatenate((x_initial, y_initial, z_initial))
    stn_units = rng.choice(np.arange(N), n_selected) # choose 4 random clusters in STN

    sol = solve_ivp(rossler, t_span, initial_vals, args=(a,), t_eval=t_eval,vectorized=True,method='LSODA')
    if not sol.success:
        # Fail loudly here instead of with an opaque shape error further down.
        raise RuntimeError("Rossler integration failed: " + str(sol.message))

    # Keep only the second half of each selected trajectory.
    half = n_samples // 2
    stn_output = np.zeros((n_selected, n_samples - half))
    for i in range(len(stn_units)):
        stn_output[i] = sigmoid(sol.y[stn_units[i], half:n_samples])

    # Slicing STN output: average 100 windows of 500 samples each.
    stn_avg = np.mean(stn_output.reshape(n_selected, 100, 500), axis = 2)
    return stn_avg

    

In [None]:
# eps= 1 #0.001
# print(eps_to_a(eps))
# stn_out= STN_GPe(eps)


In [None]:
# stn = torch.tensor(stn_out)

In [None]:
# plt.plot(stn[0])
# plt.plot(stn[1])
# plt.plot(stn[2])

In [None]:
# corr = np.corrcoef((stn_all))
# plt.imshow(np.abs(corr))
# plt.colorbar()
# plt.title('Correlation matrix: Oscillatory phase for ' + str(eps_to_a(eps)), fontsize = 10)
# plt.xticks(fontsize = 8)
# plt.yticks(fontsize = 8)
# plt.show()

In [None]:
class IGT():
    """Agent that plays the four-deck Iowa Gambling Task (IGT).

    Each call to ``train`` runs one trial: a direct-pathway drive and an
    STN-driven indirect-pathway drive are integrated into a GPi voltage
    until one card crosses a threshold (race model), the chosen card is
    rewarded via ``reward``, and the pathway weights and ``eps`` are
    updated from the resulting prediction error.
    """

    def __init__(self,
                 iter = 50, 
                 gpi_threshold = 0.3, 
                 lr = 0.1, 
                 d1_amp = 1, 
                 d2_amp = 0.5,
                 ip_size = 4,
                 PD = False):
        """Create the agent and initialise weights, monitors and deck payoffs.

        Args:
            iter: Maximum number of GPi integration steps per trial.
            gpi_threshold: Race-model threshold on the (negated) GPi voltage.
            lr: Learning rate for the weight updates.
            d1_amp: Weight of the direct-pathway input at the GPi.
            d2_amp: Weight of the indirect-pathway input at the GPi.
            ip_size: Number of input cards.  The reward/loss tables below
                are fixed at four decks.
            PD: If True, the prediction error is capped at ``del_clamp``
                during ``train``.
        """

        # Hyper parameters
        self.iter = iter # gpi iters
        self.gpi_threshold = gpi_threshold # Race model threshold
        self.lr = lr # learning rate
        self.d1_amp = d1_amp # d1 weightage at GPi
        self.d2_amp = d2_amp # d2 weightage at GPi
        self.ip_size = ip_size # Number of input cards
        self.PD = PD # PD or Normal
        
        # Input vector
        self.input = torch.ones((1,self.ip_size)) 

        # Initialising weights
        # weights from input to striatum
        self.w_strd1 = torch.ones((1, self.ip_size)) 
        self.w_strd2 = torch.ones((1, self.ip_size))

        # d2 pathway STN to GPe weight
        self.w_d2_gpe = torch.ones((1,self.ip_size))
        
        # Direct-pathway weights start random; this is the only torch RNG draw.
        self.W_DP = torch.rand((1,self.ip_size)) 
        self.W_IP =  torch.ones((1, self.ip_size)) 
        self.W_snc = torch.ones((1,self.ip_size)) 

        # Initialising deltaV, delta, VF and Q value
        self.deltav = torch.ones((1,self.ip_size))
        self.delta = torch.zeros((1,self.ip_size))
        self.prevv = torch.zeros((1,self.ip_size))
        self.q_val = torch.zeros((1,self.ip_size))
        self.value = torch.zeros((1, self.ip_size))
        self.delta_W_DP = torch.zeros((1, self.ip_size))  
        self.delta_W_IP = torch.zeros((1, self.ip_size))  

        # gpi voltage
        self.v_gpi = torch.zeros((1, self.ip_size))

        # Counts
        self.count = torch.zeros(self.ip_size, dtype= int)

        # Per-trial monitors filled by train()
        self.iter_mon = []
        self.gpi_mon = []

        # Rewards (per deck A, B, C, D)
        self.rew = np.array([100, 100, 50, 50])
        self.rew_a = 100
        self.rew_b = 100
        self.rew_c = 50
        self.rew_d = 50

        # Loss (per deck A, B, C, D)
        self.loss = np.array([-250,-1250,-50, -250])
        self.loss_a = -250
        self.loss_b = -1250
        self.loss_c = -50
        self.loss_d = -250

        # Probability that the loss is applied on a draw from each deck
        self.loss_probs = np.array([0.5, 0.1, 0.5,0.1])

        # epsilon
        self.eps = torch.tensor(0.75) #changed

        # epsilon monitor
        self.eps_mon = []

    def reward(self, card_chosen):
        """Draw the net outcome for one card and scale it.

        The deck's fixed reward is always paid; its loss is applied with
        probability ``loss_probs[card_chosen]`` using NumPy's global RNG.

        Args:
            card_chosen: Integer index of the chosen deck.

        Returns:
            ``(reward + loss) / 1150``.  1150 is the magnitude of the worst
            net outcome with the default decks (100 - 1250), so default
            values lie in [-1, 1].
        """
        reward_ = self.rew[card_chosen]
        p_loss = self.loss_probs[card_chosen]
        loss_ = np.random.choice([self.loss[card_chosen],0], p = [p_loss,1-p_loss])
        tot_reward = reward_+ loss_
        return tot_reward/1150
    
    def train(self, 
              del_clamp = -0.8, 
              delv_clamp= -0.3,
              IP_train = False):
        """Run one IGT trial and update the model.

        Args:
            del_clamp: Upper cap applied to the prediction error when
                ``self.PD`` is True.
            delv_clamp: Accepted for interface compatibility; not used by
                the current implementation.
            IP_train: If True, the indirect-pathway weights are updated too.

        Returns:
            Tuple ``(card_chosen, reward, gpi_mon, iter_mon)``: the chosen
            card index (int), the scaled reward, and the cumulative GPi and
            iteration monitors (the lists themselves, not copies).

        Raises:
            ValueError: If ``self.iter`` is smaller than 1, because no card
                could then be selected.
        """
        # Fail before the expensive STN simulation: with no GPi step the
        # selection below would never run and the chosen card would be unbound.
        if self.iter < 1:
            raise ValueError("iter must be >= 1 so that a card can be selected")

        self.eps_mon.append(self.eps)

        # Parameters
        del_clamp = torch.tensor(del_clamp)
        del_clamp = torch.unsqueeze(del_clamp, dim = 0)
        tau_gpi = 0.1 # 100ms
        dt = 0.01 # 10 ms

        # Direct pathway
        d1_out = self.input * self.w_strd1 * (1-self.eps)
        v_d1_gpi = d1_out * self.W_DP

        # Indirect pathway (computed but not used further in this method)
        d2_out = self.input * self.w_strd2 * self.eps
        
        # STN solver: one column of stn_output is consumed per GPi step below,
        # so self.iter must not exceed the number of columns it returns.
        stn_output= torch.tensor(STN_GPe(self.eps.item()))
        
        # NOTE(review): self.v_gpi is not reset between trials, so each trial
        # starts from the previous trial's voltage -- confirm this is intended.
        t = 0
        while t < self.iter:
            v_d2_gpi = stn_output[:,t] * self.W_IP
            # Euler step of the leaky GPi integrator.
            self.v_gpi = self.v_gpi + (dt/tau_gpi) * (-self.v_gpi - self.d1_amp * v_d1_gpi + self.d2_amp * v_d2_gpi)
            self.v_gpi_out = -self.v_gpi
            self.gpi_mon.append(self.v_gpi_out)
            # 'peak' rather than 'max' so the builtin is not shadowed.
            peak, card_chosen = torch.max(self.v_gpi_out,1)
            t += 1
            
            if peak > self.gpi_threshold:
                break

        self.iter_mon.append(t-1) # index of the last GPi step executed in this trial
        reward_ = self.reward(card_chosen.item())

        # ******************* UPDATING value functions **************************
        self.q_val = v_d1_gpi
        self.delta = reward_ - self.q_val[:,card_chosen]
        self.eps = self.eps + (1-torch.exp(-(self.delta**2)/1000)-self.eps)*0.05 #*0.25 # epsilon update
        print(f'Epsilon = {self.eps}')
        print('Loss = ', self.loss)
        print('card chosen=', card_chosen)

        # ****************** Clamping values **********************************
        if self.PD:
            # Delta clamp: cap the prediction error at del_clamp from above.
            if self.delta > del_clamp:
                self.delta = del_clamp
        
        # W_DP training
        # NOTE(review): delta_W_DP / delta_W_IP are never zeroed, so entries
        # written for other cards on earlier trials are re-applied on every
        # later trial -- confirm this is intended.
        self.delta_W_DP[:,card_chosen] = self.lr * self.delta
        self.W_DP = self.W_DP + self.delta_W_DP

        if IP_train:

            self.delta_W_IP[:,card_chosen] = self.lr * self.delta
            self.W_IP = self.W_IP - self.delta_W_IP

        return card_chosen.item(), reward_, self.gpi_mon, self.iter_mon
    

In [None]:
def load_yaml(yaml_file):
    """Read a YAML file and return its parsed content.

    Args:
        yaml_file: Path of the YAML file to open for reading.

    Returns:
        Whatever ``yaml.safe_load`` produces for the file content.
    """
    with open(yaml_file,'r') as stream:
        return yaml.safe_load(stream)

def experiment(exp_yaml_path):
    """Run the IGT simulation sweep described by a YAML configuration file.

    The configuration is read with ``load_yaml`` and must provide:
    ``epochs``, ``bins``, ``binsize``, the iterables ``del_clamp`` and
    ``delv_clamp``, the ``IGT`` constructor arguments (``iter``,
    ``gpi_threshold``, ``lr``, ``d1_amp``, ``d2_amp``, ``ip_size``, ``PD``)
    and ``IP_train``.

    For every ``(del_clamp, delv_clamp)`` pair, ``epochs`` fresh models are
    each trained for ``bins * binsize`` trials.  The IGT score of a bin is
    the number of picks of cards 2 and 3 minus the number of picks of
    cards 0 and 1.

    The ``return`` used to sit inside the outer loop, so only the first
    ``del_clamp`` value was ever simulated; it now runs after the full sweep.

    Args:
        exp_yaml_path: Path to the experiment YAML file.

    Returns:
        Tuple ``(igt_dict, igt_se_dict, wt_DP_monA, wt_DP_monB, wt_DP_monC,
        wt_DP_monD, wt_IP_monA, wt_IP_monB, wt_IP_monC, wt_IP_monD)``.  The
        two dicts map ``'del <|del_c|> delv <|del_v|>'`` to the per-bin mean
        IGT score and its standard error across epochs.  The eight lists
        hold the final per-epoch pathway weights of cards A-D for the last
        clamp pair processed (they are reset for each pair).
    """

    arguments = load_yaml(exp_yaml_path) # dict

    # run = wandb.init(project=arguments.get('project_name'),
    #                  name=arguments.get('exp_name'),
    #                  config= arguments)

    epochs = arguments.get('epochs')
    bins = arguments.get('bins')
    binsize = arguments.get('binsize')

    del_clamp = arguments.get('del_clamp')
    delv_clamp = arguments.get('delv_clamp')

    # Loop-invariant settings, looked up once.
    model_kwargs = dict(iter = arguments.get('iter'),
                        gpi_threshold = arguments.get('gpi_threshold'),
                        lr = arguments.get('lr'),
                        d1_amp = arguments.get('d1_amp'),
                        d2_amp = arguments.get('d2_amp'),
                        ip_size = arguments.get('ip_size'),
                        PD = arguments.get('PD'))
    ip_train = arguments.get('IP_train')

    igt_dict = {}
    igt_se_dict = {}

    # Defined up front so the return statement is valid even for an empty sweep.
    wt_DP_monA, wt_DP_monB, wt_DP_monC, wt_DP_monD = [], [], [], []
    wt_IP_monA, wt_IP_monB, wt_IP_monC, wt_IP_monD = [], [], [], []
    
    for del_c in del_clamp: # loop for delta clamp
        for del_v in delv_clamp: # %%%% Change to eps%%%
            print(f'Clamping values: delv = {del_v}, del = {del_c}')
            name = 'del '+str(np.abs(del_c))+' delv ' + str(np.abs(del_v))
            igt_monitor = np.zeros((epochs, bins))
            # Cumulative reward per trial; recorded but not returned.
            rew_monitor = np.zeros((epochs,bins,binsize))
            wt_DP_monA, wt_DP_monB, wt_DP_monC, wt_DP_monD = [], [], [], []
            wt_IP_monA, wt_IP_monB, wt_IP_monC, wt_IP_monD = [], [], [], []
            
            for i in tqdm(range(epochs)):

                model = IGT(**model_kwargs)
                
                reward_acc = 0
                for bin_idx in range(bins): # bins
                    num_cards_chosen = [0 for _ in range(4)]
                            
                    for trial in range(binsize): # binsize
                        card_chosen, reward_, _,_ = model.train(del_clamp = del_c, 
                                                                delv_clamp= del_v, 
                                                                IP_train = ip_train)
                        
                        reward_acc += reward_
                        rew_monitor[i,bin_idx, trial] = reward_acc
                        num_cards_chosen[card_chosen] += 1
                    
                    # IGT score for this bin: picks of cards 2,3 minus picks of cards 0,1
                    igt_score = np.sum(num_cards_chosen[2:4]) - np.sum(num_cards_chosen[0:2])
                    igt_monitor[i,bin_idx] = igt_score
                
                wt_DP_monA.append(model.W_DP[:,0].item())
                wt_DP_monB.append(model.W_DP[:,1].item())
                wt_DP_monC.append(model.W_DP[:,2].item())
                wt_DP_monD.append(model.W_DP[:,3].item())
                wt_IP_monA.append(model.W_IP[:,0].item())
                wt_IP_monB.append(model.W_IP[:,1].item())
                wt_IP_monC.append(model.W_IP[:,2].item())
                wt_IP_monD.append(model.W_IP[:,3].item())
                # Report only this epoch's row instead of the whole matrix.
                print(f'epochs = {i} : IGT score = {igt_monitor[i]}')
                

            avg_igt_score = np.mean(igt_monitor, axis = 0)
            igt_se = np.std(igt_monitor, axis = 0)/np.sqrt(epochs)

            igt_dict[name] = avg_igt_score
            igt_se_dict[name] = igt_se

    # Return only after every clamp combination has been simulated.
    return igt_dict,igt_se_dict,wt_DP_monA,wt_DP_monB, wt_DP_monC, wt_DP_monD, wt_IP_monA, wt_IP_monB, wt_IP_monC, wt_IP_monD


In [None]:
# Run the full experiment.  NOTE: the configuration path is an absolute,
# machine-specific Windows path and must be adjusted on other machines.
(igt_dict, igt_se_dict,
 wt_DP_monA, wt_DP_monB, wt_DP_monC, wt_DP_monD,
 wt_IP_monA, wt_IP_monB, wt_IP_monC, wt_IP_monD) = experiment(
    r'D:\CNS IITM\Chaotic Network\experiment_args_IGT.yaml'
)

In [None]:
# Per-bin mean IGT score for the run stored under the key 'del 0 delv 0'
# (keys are built in experiment() from the absolute clamp values).
# Raises KeyError if that clamp combination was not part of the sweep.
igt_score = igt_dict['del 0 delv 0']
print(f'IGT score: {igt_score}')

In [None]:
# Mean final weight per card (A-D), averaged over the epochs returned by
# experiment(): first the W_DP weights, then the W_IP weights.
print(f'DP => A: {np.mean(wt_DP_monA)}, B: {np.mean(wt_DP_monB)}, C: {np.mean(wt_DP_monC)},D: {np.mean(wt_DP_monD)}')
print(f'IP => A: {np.mean(wt_IP_monA)}, B: {np.mean(wt_IP_monB)}, C: {np.mean(wt_IP_monC)},D: {np.mean(wt_IP_monD)}')

In [None]:
# Bar chart of the mean IGT score per bin with standard-error bars.
# Bin positions (1..n) are derived from the data instead of being fixed at 5,
# so the plot still works when the experiment uses a different number of bins.
bins = np.arange(len(igt_score))+1
plt.bar(bins, igt_score)
plt.errorbar(bins, igt_score, yerr =igt_se_dict['del 0 delv 0'], fmt = 'o', color = 'black')
plt.xlabel('BINS')
plt.ylabel('IGT Score')
plt.title('IGT Score')
plt.show()