In [None]:
#v4 without Re
#v4_1 revised forward function

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import time
import math
import matplotlib.pyplot as plt
import matplotlib.animation as animation
import sklearn 
import os
from sklearn.decomposition import PCA
from sklearn.cluster import KMeans
from mpl_toolkits.mplot3d import Axes3D

%matplotlib notebook

def Culc_gate(input, params, hiddens):
    """Compute sigmoid-squashed gate pre-activations for the PFC and HPC cells.

    Only the first batch element is used (see the ``[0]`` indexing below).

    Args:
        input: tensor concatenated with ``Re`` along dim 1 to form the HPC input.
        params: list of parameter tensors; presumably built from
            ``rnn.parameters()`` so that params[0:4] are the PFC cell's
            weight_ih / weight_hh / bias_ih / bias_hh and params[4:8] the HPC
            cell's -- TODO confirm against the model actually in use.
        hiddens: hidden-state list as returned by the models' ``forward``
            (``[PFC state, HPC state, Re state]``).

    Returns:
        ``[PFC_gates, HPC_gates]`` as nested Python lists.
    """
    # NOTE(review): "Re" is read from hiddens[0][0], which is the PFC hidden
    # state for the models' forward() output -- presumably a leftover from a
    # variant without a separate Re module; confirm before relying on this.
    Re = hiddens[0][0]
    HPC_input = torch.cat([input,Re],dim=1)
    PFC_input = hiddens[1][0][0]
    # NOTE(review): the recurrent weights (params[1], params[5]) multiply
    # hiddens[*][1], i.e. the second element of an LSTMCell (h, c) pair (the
    # cell state), not the hidden state h; confirm this is intended.
    PFC_gates = params[0]@PFC_input+params[1]@hiddens[0][1][0]+params[2]+params[3]
    HPC_gates = params[4]@HPC_input[0]+params[5]@hiddens[1][1][0]+params[6]+params[7]
    # NOTE(review): sigmoid is applied to every row, including the block an
    # LSTM normally passes through tanh.
    PFC_gates = torch.sigmoid(PFC_gates)
    HPC_gates = torch.sigmoid(HPC_gates)
    return [PFC_gates.tolist(), HPC_gates.tolist()]
    

def MakeAnimation(data_x, data_y, traj_x, traj_y, data_limit,
                  filename='anim_v4_2_1_r.gif', interval=10):
    """Animate a generated trajectory against reference data and save a GIF.

    One frame is produced per element of ``traj_x``. Before ``data_limit`` the
    reference (blue) and generated (red) paths are drawn up to time ``t``.
    Afterwards the reference is frozen at ``data_limit`` and its continuation
    is drawn dashed.

    Args:
        data_x, data_y: reference trajectory coordinates.
        traj_x, traj_y: generated trajectory coordinates.
        data_limit: frame index at which the reference stops being extended.
        filename: output GIF path.
        interval: delay between frames in milliseconds.

    Returns:
        0 (kept for backward compatibility).
    """
    fig, ax = plt.subplots()
    ims = []
    for t in range(traj_x.shape[0]):
        # Marker on the most recent point. Clamp so frame 0 marks the first
        # point instead of wrapping to the last one via index -1.
        head = max(t - 1, 0)
        if t < data_limit:
            traj = ax.plot(data_x[:t], data_y[:t], "b",
                           traj_x[:t], traj_y[:t], "r",
                           traj_x[head], traj_y[head], "or")
        else:
            traj = ax.plot(data_x[:data_limit], data_y[:data_limit], "b",
                           traj_x[:t], traj_y[:t], "r",
                           traj_x[head], traj_y[head], "or",
                           data_x[data_limit-1:t], data_y[data_limit-1:t], "--b")
        title = ax.text(0.5, 1.01, 'time={}'.format(t),
                        ha='center', va='bottom',
                        transform=ax.transAxes, fontsize='large')
        ims.append(traj + [title])
    ani = animation.ArtistAnimation(fig, ims, interval=interval)
    ani.save(filename, writer='pillow')
    plt.show()

    return 0

def MakeAnimation_attracter(data_x, data_y, filename='anim_attracter2.gif', interval=100):
    """Animate a single 2-D trajectory point by point and save it as a GIF.

    Args:
        data_x, data_y: trajectory coordinates; one frame per element of ``data_x``.
        filename: output GIF path.
        interval: delay between frames in milliseconds.

    Returns:
        0 (kept for backward compatibility).
    """
    fig, ax = plt.subplots()
    ims = []
    for t in range(data_x.shape[0]):
        # Clamp so frame 0 marks the first point, not data_x[-1].
        head = max(t - 1, 0)
        traj = ax.plot(data_x[:t], data_y[:t], "b", data_x[head], data_y[head], "ob")
        title = ax.text(0.5, 1.01, 'time={}'.format(t),
                        ha='center', va='bottom',
                        transform=ax.transAxes, fontsize='large')
        ims.append(traj + [title])
    ani = animation.ArtistAnimation(fig, ims, interval=interval)
    ani.save(filename, writer='pillow')
    plt.show()

    return 0


    
def MakeAnimation_img(data, module):
    """Animate per-step activity vectors as a one-row heat map and save a GIF.

    Args:
        data: 2-D array indexed as (step, unit); frame ``t`` shows ``data[t]``.
        module: label used in each frame title and in the output file name
            ``anim_fire_<module>_.gif``.

    Returns:
        0.
    """
    fig, ax = plt.subplots()
    n_steps, n_units = data.shape[0], data.shape[1]
    frames = []
    for step in range(n_steps):
        heat = ax.imshow(data[step].reshape(1, n_units))
        label = ax.text(0.5, 1.01, module + ' time={}'.format(step),
                        ha='center', va='bottom',
                        transform=ax.transAxes, fontsize='large')
        frames.append([heat, label])
    ani = animation.ArtistAnimation(fig, frames)
    ani.save('anim_fire_' + module + '_.gif', writer='pillow')
    plt.show()
    return 0

def mkOwnDataSet(data_size, data_length=100, freq=60., noise=0.01,
                 x_path="primal_long131test_r.csv",
                 y_path="primal_long131test_l.csv"):
    """Build noisy, subsampled copies of two recorded 2-D trajectories.

    Each CSV is loaded and subsampled with a fixed stride so that roughly
    ``data_length`` rows remain. ``data_size`` copies are then produced, each
    with independent Gaussian noise added to the first two columns.

    Args:
        data_size: number of noisy copies to generate.
        data_length: approximate number of samples kept per copy.
        freq: unused; kept for call compatibility.
        noise: standard deviation of the additive Gaussian noise.
        x_path, y_path: comma-delimited CSV files for the two trajectories.

    Returns:
        ``(train_x, train_y)``: lists of length ``data_size``; each element is
        a list of ``[x, y]`` pairs.
    """
    x = np.loadtxt(x_path, delimiter=',')
    y = np.loadtxt(y_path, delimiter=',')

    def _stride(arr):
        # round() yields 0 when the file has fewer than data_length / 2 rows,
        # which makes np.arange raise; never step by less than one row.
        return max(1, round(arr.shape[0] / data_length))

    x_idx = np.arange(0, x.shape[0], _stride(x))
    y_idx = np.arange(0, y.shape[0], _stride(y))

    train_x = []
    train_y = []
    for _ in range(data_size):
        train_x.append([[x[i][0] + np.random.normal(loc=0.0, scale=noise),
                         x[i][1] + np.random.normal(loc=0.0, scale=noise)]
                        for i in x_idx])
        train_y.append([[y[i][0] + np.random.normal(loc=0.0, scale=noise),
                         y[i][1] + np.random.normal(loc=0.0, scale=noise)]
                        for i in y_idx])
    return train_x, train_y

def make_Ttraj(direction):
    """Return a 20-point T-shaped trajectory as an array of shape (20, 2).

    The path has four legs of 5 points each:
      * up the centre line from (0.5, 0) to (0.5, 0.5);
      * sideways to x = 0.75 (``direction == 1``) or x = 0.25 (``direction == 2``);
      * back to (0.5, 0.5);
      * down the centre line to (0.5, 0).

    Raises:
        ValueError: if ``direction`` is not 1 or 2.
    """
    arm_ends = {1: 0.75, 2: 0.25}
    if direction not in arm_ends:
        raise ValueError("direction must be 1 or 2, got {!r}".format(direction))
    arm = arm_ends[direction]
    n = 5

    traj_x = np.concatenate([
        np.full(n, 0.5),
        np.linspace(0.5, arm, n),
        np.linspace(arm, 0.5, n),
        np.full(n, 0.5),
    ])
    traj_y = np.concatenate([
        np.linspace(0, 0.5, n),
        np.full(n, 0.5),
        np.full(n, 0.5),
        np.linspace(0.5, 0, n),
    ])
    return np.stack([traj_x, traj_y], axis=1)

def make_Htraj(direction1, direction2):
    """Return a 40-point closed trajectory of shape (40, 2).

    Legs, in order:
      * 10 points up the centre line from (0.5, 0) to (0.5, 0.5);
      * 5 points sideways to x = 0.75 (``direction1 == 1``) or x = 0.25 (``2``);
      * 5 points along that side to y = 0.25 (``direction2 == 1``) or
        y = 0.75 (``2``), then 5 points back to y = 0.5;
      * 5 points back to (0.5, 0.5);
      * 10 points down the centre line to (0.5, 0).

    Raises:
        ValueError: if either direction is not 1 or 2.
    """
    side_x = {1: 0.75, 2: 0.25}
    tip_y = {1: 0.25, 2: 0.75}
    if direction1 not in side_x:
        raise ValueError("direction1 must be 1 or 2, got {!r}".format(direction1))
    if direction2 not in tip_y:
        raise ValueError("direction2 must be 1 or 2, got {!r}".format(direction2))
    sx = side_x[direction1]
    ty = tip_y[direction2]

    traj_x = np.concatenate([
        np.full(10, 0.5),
        np.linspace(0.5, sx, 5),
        np.full(5, sx),
        np.full(5, sx),
        np.linspace(sx, 0.5, 5),
        np.full(10, 0.5),
    ])
    traj_y = np.concatenate([
        np.linspace(0, 0.5, 10),
        np.full(5, 0.5),
        np.linspace(0.5, ty, 5),
        np.linspace(ty, 0.5, 5),
        np.full(5, 0.5),
        np.linspace(0.5, 0, 10),
    ])
    return np.stack([traj_x, traj_y], axis=1)


def make_traj(delay_length, stay_path="stay1.csv",
              pairs=((1, 4), (2, 3), (3, 1), (4, 2))):
    """Build one noisy sequence per ``(first, last)`` pair in ``pairs``.

    Each sequence is trajectory ``first``, then ``delay_length`` "stay"
    segments, then trajectory ``last``, stacked along axis 0. Trajectory codes:
    1 = make_Htraj(1, 1), 2 = make_Htraj(1, 2), 3 = make_Htraj(2, 1),
    4 = make_Htraj(2, 2). Each trajectory gets fresh Gaussian noise
    (std 0.005). A stay segment is 40 rows drawn at random (with replacement)
    from the first two columns of ``stay_path``.

    The default ``pairs`` is the "ver1" ordering. The previously used "ver2"
    ordering was ``((1, 2), (2, 3), (3, 4), (4, 1))``.

    Args:
        delay_length: number of stay segments between the two trajectories.
        stay_path: comma-delimited CSV with the stay samples.
        pairs: sequence of ``(first, last)`` trajectory codes.

    Returns:
        Tuple with one ``(n_rows, 2)`` array per entry of ``pairs``
        (four with the default).
    """
    noise = 0.005
    data_length = 40  # rows sampled per stay segment

    templates = {
        1: make_Htraj(1, 1),
        2: make_Htraj(1, 2),
        3: make_Htraj(2, 1),
        4: make_Htraj(2, 2),
    }
    z = np.loadtxt(stay_path, delimiter=',')

    data_list = []
    for first, last in pairs:
        order = [first] + [0] * delay_length + [last]
        segments = []
        for idx in order:
            if idx == 0:
                picks = np.random.choice(z.shape[0], data_length)
                segments.append(z[picks, :2])
            else:
                template = templates[idx]
                # Build a new array. An in-place `+=` would mutate the shared
                # template, so a trajectory reused later in this call would
                # accumulate the earlier noise as well.
                segments.append(template + np.random.normal(loc=0.0, scale=noise, size=template.shape))
        data_list.append(np.concatenate(segments, axis=0))
    return tuple(data_list)

def mkOwnDataSet_auto(data_size, delay_length, freq=60., noise=0.01):
    """Create ``data_size`` jittered copies of each of the four make_traj sequences.

    Args:
        data_size: number of noisy copies per sequence.
        delay_length: forwarded to ``make_traj``.
        freq: unused; kept for call compatibility.
        noise: standard deviation of the per-coordinate Gaussian jitter.

    Returns:
        Four lists (x1, x2, y1, y2); each holds ``data_size`` lists of
        ``[x, y]`` pairs.
    """
    templates = make_traj(delay_length)
    datasets = ([], [], [], [])

    for _ in range(data_size):
        # Per copy, jitter the four templates in their original order.
        for template, bucket in zip(templates, datasets):
            bucket.append([[row[0] + np.random.normal(loc=0.0, scale=noise),
                            row[1] + np.random.normal(loc=0.0, scale=noise)]
                           for row in template])

    return datasets[0], datasets[1], datasets[2], datasets[3]

def mkOwnRandomBatch(train_x, train_t, batch_size=10):
    """Sample a random mini-batch of sequences from ``train_x``.

    Sequences are drawn uniformly, with replacement, from all of ``train_x``.

    Args:
        train_x: list of equal-length sequences, each a list of per-step
            feature lists (e.g. ``[x, y]``).
        train_t: unused; kept for call compatibility.
        batch_size: number of sequences to draw.

    Returns:
        Tensor of shape ``(seq_len, batch_size, features)`` (time-major).

    Raises:
        ValueError: if ``train_x`` is empty.
    """
    batch_x = []
    for _ in range(batch_size):
        # Sample from the whole pool. The former upper bound
        # len(train_x) - batch_size never picked the last batch_size items
        # and raised when len(train_x) <= batch_size.
        idx = np.random.randint(0, len(train_x))
        batch_x.append(train_x[idx])
    return torch.tensor(batch_x).transpose(0, 1)


class MyLSTM(nn.Module):
    """Three coupled LSTM cells (PFC, HPC, Re) with a linear read-out from HPC.

    Each step, every cell reads the *previous* hidden states of the others:
    Re <- (PFC h, HPC h), HPC <- (input, Re h), PFC <- (HPC h, Re h).
    The output is a linear projection of the new HPC hidden state.

    Hidden-state layout: ``[PFC (h, c), HPC (h, c), Re (h, c)]``.
    """

    def __init__(self, input_size, hidden_size, output_size, batch_size):
        super().__init__()

        self.hidden_size = hidden_size
        # Registration order (PFC, HPC, Re, linear) fixes the order of
        # self.parameters() and of the random initialisation draws.
        self.PFC = nn.LSTMCell(2 * hidden_size, hidden_size)
        self.HPC = nn.LSTMCell(input_size + hidden_size, hidden_size)
        self.Re = nn.LSTMCell(2 * hidden_size, hidden_size)
        self.linear = nn.Linear(hidden_size, output_size)
        self.batch_size = batch_size
        # NOTE(review): sparsity=1 asks nn.init.sparse_ to zero every entry of
        # each column, so these four matrices start at all zeros. The sibling
        # models use 0.1 * sparse instead; confirm this is intended.
        for weight in (self.PFC.weight_ih, self.HPC.weight_ih,
                       self.PFC.weight_hh, self.HPC.weight_hh):
            nn.init.sparse_(weight.data, 1)

    def forward(self, input, hiddens):
        """Advance all three cells by one step; return ``(output, new_hiddens)``."""
        pfc_prev, hpc_prev, re_prev = hiddens[0], hiddens[1], hiddens[2]
        re_next = self.Re(torch.cat([pfc_prev[0], hpc_prev[0]], dim=1), re_prev)
        hpc_next = self.HPC(torch.cat([input, re_prev[0]], dim=1), hpc_prev)
        pfc_next = self.PFC(torch.cat([hpc_prev[0], re_prev[0]], dim=1), pfc_prev)
        return self.linear(hpc_next[0]), [pfc_next, hpc_next, re_next]

    def initHidden(self):
        """Return all-zero states ``[PFC (h, c), HPC (h, c), Re (h, c)]``."""
        shape = (self.batch_size, self.hidden_size)
        return [[torch.zeros(shape) for _ in range(2)] for _ in range(3)]

    def initHidden_rand(self):
        """Return states drawn uniformly from [0, 0.1), same layout as initHidden."""
        scale = 0.1
        shape = (self.batch_size, self.hidden_size)
        # Draw order HPC, PFC, Re (h before c) is kept so seeded runs reproduce.
        hpc_state, pfc_state, re_state = (
            [torch.rand(shape) * scale, torch.rand(shape) * scale] for _ in range(3)
        )
        return [pfc_state, hpc_state, re_state]

class MyLSTM_RNN(nn.Module):
    """PFC and HPC LSTM cells coupled through a vanilla RNN cell (Re).

    Each step, every module reads the *previous* states of the others:
    Re <- (PFC h, HPC h), HPC <- (input, Re h), PFC <- (HPC h, Re h).
    The output is a linear read-out of the new HPC hidden state.

    Hidden-state layout used by every method: ``[PFC (h, c), HPC (h, c), Re h]``.
    Re is an ``nn.RNNCell``, so its state is a single tensor.
    """
    def __init__(self, input_size, hidden_size, output_size, batch_size, sparse=5):
        super(MyLSTM_RNN, self).__init__()

        self.hidden_size_PFC = hidden_size
        self.hidden_size_HPC = hidden_size
        self.hidden_size_Re = hidden_size+0
        # Registration order (PFC, HPC, Re, linear) fixes the order of
        # self.parameters(); keep it stable for code that indexes parameters.
        self.PFC = nn.LSTMCell(self.hidden_size_HPC+self.hidden_size_Re, self.hidden_size_PFC)
        self.HPC = nn.LSTMCell(input_size+self.hidden_size_Re, self.hidden_size_HPC)
        self.Re = nn.RNNCell(self.hidden_size_HPC+self.hidden_size_PFC, self.hidden_size_Re)
        self.linear = nn.Linear(self.hidden_size_HPC, output_size)
        self.batch_size = batch_size
        # ``sparse`` is given in tenths: 0.1 * sparse of the entries in each
        # column of the PFC/HPC weights are zeroed. Re keeps PyTorch's default init.
        sparse = 0.1*int(sparse)
        nn.init.sparse_(self.PFC.weight_ih.data,sparse)
        nn.init.sparse_(self.HPC.weight_ih.data,sparse)
        nn.init.sparse_(self.PFC.weight_hh.data,sparse)
        nn.init.sparse_(self.HPC.weight_hh.data,sparse)

    def forward(self, input, hiddens):
        """Advance all modules by one step; return ``(output, new_hiddens)``."""
        Re_input = torch.cat([hiddens[0][0],hiddens[1][0]],dim=1)
        Re_hidden = self.Re(Re_input, hiddens[2])
        HPC_input = torch.cat([input,hiddens[2]],dim=1)
        HPC_hidden = self.HPC(HPC_input, hiddens[1])
        PFC_input = torch.cat([hiddens[1][0],hiddens[2]],dim=1)
        PFC_hidden = self.PFC(PFC_input, hiddens[0])
        output = self.linear(HPC_hidden[0])
        return output, [PFC_hidden, HPC_hidden, Re_hidden]

    def initHidden(self):
        """Return all-zero initial states ``[PFC (h, c), HPC (h, c), Re h]``."""
        # Use the per-module sizes: this class never defines self.hidden_size,
        # so the previous version raised AttributeError.
        HPC_hidden = [torch.zeros(self.batch_size, self.hidden_size_HPC), torch.zeros(self.batch_size, self.hidden_size_HPC)]
        PFC_hidden = [torch.zeros(self.batch_size, self.hidden_size_PFC), torch.zeros(self.batch_size, self.hidden_size_PFC)]
        Re_hidden = torch.zeros(self.batch_size, self.hidden_size_Re)
        return [PFC_hidden,HPC_hidden,Re_hidden]

    def initHidden_rand(self):
        """Return states drawn uniformly from [0, 0.1) (HPC, Re) and [0, 0.1) * var (PFC)."""
        const = 0.1
        var = 0.1
        HPC_hidden = [torch.rand(self.batch_size, self.hidden_size_HPC)*const, torch.rand(self.batch_size, self.hidden_size_HPC)*const]
        PFC_hidden = [torch.rand(self.batch_size, self.hidden_size_PFC)*var, torch.rand(self.batch_size, self.hidden_size_PFC)*var]
        Re_hidden = torch.rand(self.batch_size, self.hidden_size_Re)*const
        return [PFC_hidden,HPC_hidden,Re_hidden]

    def initHidden_test(self):
        """Return zero HPC/Re states and PFC states drawn uniformly from [0, 1)."""
        const = 0
        var = 1
        HPC_hidden = [torch.ones(self.batch_size, self.hidden_size_HPC)*const, torch.ones(self.batch_size, self.hidden_size_HPC)*const]
        PFC_hidden = [torch.rand(self.batch_size, self.hidden_size_PFC)*var, torch.rand(self.batch_size, self.hidden_size_PFC)*var]
        Re_hidden = torch.ones(self.batch_size, self.hidden_size_Re)*const
        return [PFC_hidden,HPC_hidden,Re_hidden]

    def noiseHidden_rand(self, hiddens):
        """Add Gaussian noise to the hidden states in place and return ``hiddens``.

        With ``c = 0`` the HPC and PFC draws are scaled to zero. They still
        consume random numbers, but only Re receives noise (std ``v``).
        """
        c = 0
        v = 0.05
        HPC_hidden = torch.randn(self.batch_size, self.hidden_size_HPC)*c
        hiddens[1][0].add_(HPC_hidden)
        PFC_hidden = torch.randn(self.batch_size, self.hidden_size_PFC)*c
        hiddens[0][0].add_(PFC_hidden)
        Re_hidden = torch.randn(self.batch_size, self.hidden_size_Re)*v
        hiddens[2].add_(Re_hidden)
        return hiddens
    
class MyLSTM_3lay(nn.Module):
    """PFC, HPC and Re LSTM cells with a linear read-out from HPC.

    Each step, every cell reads the *previous* states:
    HPC <- (input, Re h), Re <- (PFC h, HPC h), PFC <- Re h.
    Hidden-state layout: ``[PFC (h, c), HPC (h, c), Re (h, c)]``.
    """
    def __init__(self, input_size, hidden_size, output_size, batch_size):
        super(MyLSTM_3lay, self).__init__()

        self.hidden_size = hidden_size
        self.PFC = nn.LSTMCell(hidden_size, hidden_size)
        self.HPC = nn.LSTMCell(input_size+hidden_size, hidden_size)
        self.Re = nn.LSTMCell(hidden_size*2, hidden_size)
        self.linear = nn.Linear(hidden_size, output_size)
        self.batch_size = batch_size
        # Zero 10% of the entries in each column of every cell's weights.
        sparse = 0.1
        nn.init.sparse_(self.PFC.weight_ih.data,sparse)
        nn.init.sparse_(self.HPC.weight_ih.data,sparse)
        nn.init.sparse_(self.Re.weight_ih.data,sparse)
        nn.init.sparse_(self.PFC.weight_hh.data,sparse)
        nn.init.sparse_(self.HPC.weight_hh.data,sparse)
        nn.init.sparse_(self.Re.weight_hh.data,sparse)

    def forward(self, input, hiddens):
        """Advance all three cells by one step; return ``(output, new_hiddens)``."""
        HPC_input = torch.cat([input,hiddens[2][0]],dim=1)
        HPC_hidden = self.HPC(HPC_input, hiddens[1])
        Re_input = torch.cat([hiddens[0][0],hiddens[1][0]],dim=1)
        Re_hidden = self.Re(Re_input,hiddens[2])
        PFC_input = hiddens[2][0]
        PFC_hidden = self.PFC(PFC_input, hiddens[0])
        output = self.linear(HPC_hidden[0])
        return output, [PFC_hidden, HPC_hidden, Re_hidden]

    def initHidden(self):
        """Return all-zero states ``[PFC (h, c), HPC (h, c), Re (h, c)]``."""
        HPC_hidden = [torch.zeros(self.batch_size, self.hidden_size), torch.zeros(self.batch_size, self.hidden_size)]
        PFC_hidden = [torch.zeros(self.batch_size, self.hidden_size), torch.zeros(self.batch_size, self.hidden_size)]
        Re_hidden = [torch.zeros(self.batch_size, self.hidden_size), torch.zeros(self.batch_size, self.hidden_size)]
        return [PFC_hidden,HPC_hidden,Re_hidden]

    def initHidden_rand(self):
        """Return initial states, for API parity with the other models.

        NOTE(review): despite the name this returns all zeros, exactly like
        initHidden(). Confirm whether random initial states were intended.
        """
        return self.initHidden()
    
class MyLSTM_RNN_uniPFC(nn.Module):
    """Variant of MyLSTM_RNN in which the PFC cell reads only the HPC state.

    Each step, every module reads the *previous* states:
    Re <- (PFC h, HPC h), HPC <- (input, Re h), PFC <- HPC h.
    Hidden-state layout: ``[PFC (h, c), HPC (h, c), Re h]``.
    Re is an ``nn.RNNCell``, so its state is a single tensor.
    """
    def __init__(self, input_size, hidden_size, output_size, batch_size, sparse=1):
        super(MyLSTM_RNN_uniPFC, self).__init__()

        self.hidden_size_PFC = hidden_size+0
        self.hidden_size_HPC = hidden_size+0
        self.hidden_size_Re = hidden_size+0
        self.PFC = nn.LSTMCell(self.hidden_size_HPC, self.hidden_size_PFC)
        self.HPC = nn.LSTMCell(input_size+self.hidden_size_Re, self.hidden_size_HPC)
        self.Re = nn.RNNCell(self.hidden_size_HPC+self.hidden_size_PFC, self.hidden_size_Re)
        self.linear = nn.Linear(self.hidden_size_HPC, output_size)
        self.batch_size = batch_size
        # ``sparse`` is given in tenths: 0.1 * sparse of each weight column is zeroed.
        sparse = 0.1*int(sparse)
        nn.init.sparse_(self.PFC.weight_ih.data,sparse)
        nn.init.sparse_(self.HPC.weight_ih.data,sparse)
        nn.init.sparse_(self.Re.weight_ih.data,sparse)
        nn.init.sparse_(self.PFC.weight_hh.data,sparse)
        nn.init.sparse_(self.HPC.weight_hh.data,sparse)
        nn.init.sparse_(self.Re.weight_hh.data,sparse)
        # Alternative initialisation previously tried:
        # nn.init.normal_(self.PFC.weight_ih.data,0,1/10)
        # nn.init.normal_(self.HPC.weight_ih.data,0,1/10)
        # nn.init.normal_(self.Re.weight_ih.data,0,0.1/10)
        # nn.init.normal_(self.PFC.weight_hh.data,0,1/10)
        # nn.init.normal_(self.HPC.weight_hh.data,0,1/10)
        # nn.init.normal_(self.Re.weight_hh.data,0,1/10)

    def forward(self, input, hiddens):
        """Advance all modules by one step; return ``(output, new_hiddens)``."""
        Re_input = torch.cat([hiddens[0][0],hiddens[1][0]],dim=1)
        Re_hidden = self.Re(Re_input, hiddens[2])
        HPC_input = torch.cat([input,hiddens[2]],dim=1)
        HPC_hidden = self.HPC(HPC_input, hiddens[1])
        PFC_input = hiddens[1][0]
        PFC_hidden = self.PFC(PFC_input, hiddens[0])
        output = self.linear(HPC_hidden[0])
        return output, [PFC_hidden, HPC_hidden, Re_hidden]

    def initHidden(self):
        """Return all-zero initial states ``[PFC (h, c), HPC (h, c), Re h]``."""
        # Use the per-module sizes: this class never defines self.hidden_size,
        # so the previous version raised AttributeError.
        HPC_hidden = [torch.zeros(self.batch_size, self.hidden_size_HPC), torch.zeros(self.batch_size, self.hidden_size_HPC)]
        PFC_hidden = [torch.zeros(self.batch_size, self.hidden_size_PFC), torch.zeros(self.batch_size, self.hidden_size_PFC)]
        Re_hidden = torch.zeros(self.batch_size, self.hidden_size_Re)
        return [PFC_hidden,HPC_hidden,Re_hidden]

    def initHidden_rand(self):
        """Return states drawn uniformly from [0, 0.01)."""
        const = 0.01
        var = 0.01
        HPC_hidden = [torch.rand(self.batch_size, self.hidden_size_HPC)*const, torch.rand(self.batch_size, self.hidden_size_HPC)*const]
        PFC_hidden = [torch.rand(self.batch_size, self.hidden_size_PFC)*var, torch.rand(self.batch_size, self.hidden_size_PFC)*var]
        Re_hidden = torch.rand(self.batch_size, self.hidden_size_Re)*const
        return [PFC_hidden,HPC_hidden,Re_hidden]

    def noiseHidden_rand(self, hiddens):
        """Add Gaussian noise (std 0.02) to the HPC h, PFC h and Re states in place; return ``hiddens``."""
        c = 0.02
        v = 0.02
        HPC_hidden = torch.randn(self.batch_size, self.hidden_size_HPC)*c
        hiddens[1][0].add_(HPC_hidden)
        PFC_hidden = torch.randn(self.batch_size, self.hidden_size_PFC)*v
        hiddens[0][0].add_(PFC_hidden)
        Re_hidden = torch.randn(self.batch_size, self.hidden_size_Re)*c
        hiddens[2].add_(Re_hidden)
        return hiddens

    def hidden_size_list(self):
        """Return ``[PFC, HPC, Re]`` hidden sizes."""
        return [self.hidden_size_PFC,self.hidden_size_HPC,self.hidden_size_Re]
    

def moving_average(data, window=5, mode="full"):
    """Return the ``window``-point moving average of a 1-D sequence.

    Computed as ``np.convolve(data, ones(window) / window, mode=mode)``.

    With the default ``mode="full"`` the result has
    ``len(data) + window - 1`` samples. The first and last ``window - 1``
    values average over fewer real points, since missing ones count as zero.
    Pass ``mode="valid"`` to keep only fully overlapping windows, or
    ``mode="same"`` for a centred result of length ``max(len(data), window)``.

    Raises:
        ValueError: if ``window < 1``.
    """
    if window < 1:
        raise ValueError("window must be >= 1, got {!r}".format(window))
    kernel = np.ones(window) / window
    return np.convolve(data, kernel, mode=mode)


def _rollout(rnn, data, data_limit, n_free):
    """Run ``rnn`` teacher-forced for ``data_limit`` steps, then closed-loop for ``n_free``.

    Args:
        rnn: model exposing ``initHidden_rand()`` and ``forward(input, hidden)``.
        data: time-major input tensor; ``data[k]`` is fed at teacher-forced step ``k``.
        data_limit: number of teacher-forced steps.
        n_free: number of following steps that feed back the previous output.

    Returns:
        ``(traj, PFCstate, HPCstate, Restate)`` lists with one entry per step:
        the output, ``hidden[0][0]``, ``hidden[1][0]`` and ``hidden[2][0]``,
        each converted with ``tolist()``.
    """
    traj, PFCstate, HPCstate, Restate = [], [], [], []
    hidden = rnn.initHidden_rand()
    output = None
    # Inference only. Without no_grad the autograd graph would keep growing,
    # because each output is fed back as the next input.
    with torch.no_grad():
        for k in range(data_limit + n_free):
            # hidden = rnn.noiseHidden_rand(hidden)
            step_input = data[k] if k < data_limit else output
            output, hidden = rnn(step_input, hidden)
            traj.append(output.tolist())
            PFCstate.append(hidden[0][0].tolist())
            HPCstate.append(hidden[1][0].tolist())
            Restate.append(hidden[2][0].tolist())
    return traj, PFCstate, HPCstate, Restate


def _plot_weight_matrices(rnn, kind, include_output):
    """Show the PFC/HPC/Re ``weight_<kind>`` matrices (optionally the read-out) as heat maps."""
    params = dict(rnn.named_parameters())
    names = ["PFC.weight_" + kind, "HPC.weight_" + kind, "Re.weight_" + kind]
    if include_output:
        names.append("linear.weight")
    fig = plt.figure(figsize=(10, 5)) if include_output else plt.figure()
    for pos, name in enumerate(names, start=1):
        w = params[name].detach().cpu().numpy()
        ax = fig.add_subplot(1, len(names), pos)
        ax.imshow(w, cmap="coolwarm")
        ax.set_title("max = {:.2f},min = {:.2f}".format(np.max(w), np.min(w)))


def _plot_pca_3d(states):
    """Fit PCA on ``states`` (step x feature) and plot the first 3 components in 3-D.

    Returns:
        ``(feature, ax)``: the PCA projection and the 3-D axes (for extra plots).
    """
    pca = PCA()
    pca.fit(states)
    feature = pca.transform(states)
    fig = plt.figure()
    # Axes3D(fig) no longer attaches itself to the figure in recent matplotlib.
    ax = fig.add_subplot(111, projection="3d")
    ax.plot(feature[:, 0], feature[:, 1], feature[:, 2], alpha=0.8)
    return feature, ax


def main(num):
    """Load checkpoint ``num``, roll the model out and plot its internal dynamics.

    Steps:
      1. Build the training sequences and sample one batch of the chosen pattern.
      2. Run the model teacher-forced for ``data_limit`` steps, then closed-loop
         for ``6 * seq_len`` steps.
      3. Plot the trajectory, the weight matrices, and 3-D PCA projections of
         the HPC, PFC and Re states.

    Args:
        num: index substituted into the checkpoint file name.

    Returns:
        PCA projection of the PFC hidden states of batch element 0 (one row per
        step), or ``None`` if the checkpoint file does not exist.
    """
    training_size = 200
    hidden_size = 20
    batch_size = 10
    inputsize = 2
    outputsize = 2
    delay_length = 2
    sparse = 1
#     model_path = 'model/ReModeltest_30reset3addlong_estref_L2_interRNNrand_Reinh_AddRe_OUT5_131_s'+str(num)+'_100_4.pth'
#     model_path = 'model/ReModel_L2_interRNNrand_AddRe_OUT5_1212121_s'+str(num)+'_100_1_2.pth'
#     model_path = 'model/ReModeltest_30reset3addlong_L2_interRNNrand_Reinh_AddRe_OUT5_131_s10_100_1.pth'
#     model_path = 'model/ReModeltest_30addlong_interRNNrand_Reinh_AddRe_OUT5_131_s3_100_2.pth'
#     model_path = 'model/ReModel_interRNNrand_AddRe_OUT5_long_s7_200_2_1_'+str(num)+'.pth'
#     model_path = 'model/ReModel_interRNNrand_AddRe_OUT5_181_s'+str(num)+'_100_1.pth'
#     model_path = 'model/ReModel_LSTMrand_noise_long131test_s'+str(num)+'_100_1.pth'
#     model_path = 'model/ReModel_L2_interRNNrand_Reinh_AddRe_OUT5_11311_s6_100_1.pth'
#     model_path = 'model/ReModel_L2_interRNNrand_AddRe_OUT5_H121_s'+str(num)+'_100_2.pth'
#     model_path = 'model/ReModel_L2_interRNNrand_AddRe_OUT5_H121_s3_100_2.pth'
#     model_path = 'model/ReModel_L2_interRNNrand_AddRe_OUT5_H2v121_s'+str(num)+'_100_1.pth'
#     model_path = 'model/ReModel_L2_interRNNrand_AddRe_OUT5_H2v121_s9_100_1.pth'
#     model_path = 'model/R20_H/ReModel_L2_interRNNrand_OUT1_121H_s'+str(num)+'_100_1_epoch155.pth'
    model_path = 'model/R20_H_bigbatch/ReModel_L2_interRNNrand_OUT1_121_s'+str(num)+'_100_1_epoch160.pth'
#     model_path = 'model/R20_uniPFC_H/ReModel_L2_interRNNrand_OUT1_uniPFC_121H_s'+str(num)+'_100_1_epoch95.pth'

    if not os.path.exists(model_path):
        print("Not exist")
        return
    print(model_path)

    train_x1, train_x2, train_y1, train_y2 = mkOwnDataSet_auto(training_size, delay_length)
    # Dropped: a mkOwnDataSet(1000, 100) "test set" built here was never used.

    rnn = MyLSTM_RNN(inputsize, hidden_size, outputsize, batch_size, sparse)
    # rnn = MyLSTM_RNN_uniPFC(inputsize, hidden_size, outputsize, batch_size, sparse)
    # map_location lets checkpoints saved on a GPU load on a CPU-only machine.
    rnn.load_state_dict(torch.load(model_path, map_location="cpu"))

    pattern = 1  # 1: train_x1, 2: train_x2, 3: train_y1, anything else: train_y2
    pool = {1: train_x1, 2: train_x2, 3: train_y1}.get(pattern, train_y2)
    data = mkOwnRandomBatch(pool, pool, batch_size).float()

    data_limit = 120
    traj, PFCstate, HPCstate, Restate = _rollout(rnn, data, data_limit, data.shape[0]*6)

    pltdata = torch.squeeze(data).numpy()
    traj = torch.squeeze(torch.tensor(traj)).numpy()
    plt.figure()
    plt.plot(pltdata[:,0,0], pltdata[:,0,1], "--")
    # Split at data_limit: teacher-forced part vs closed-loop part. The old
    # hard-coded 100 drew the last 20 teacher-forced steps as closed-loop.
    plt.plot(traj[:data_limit,0,0], traj[:data_limit,0,1])
    plt.plot(traj[data_limit:,0,0], traj[data_limit:,0,1])
    plt.show()

    print(np.array(PFCstate)[:,0].shape)
    # MakeAnimation(pltdata[:,0,0], pltdata[:,0,1], traj[:,0,0], traj[:,0,1], data_limit)
    # MakeAnimation_img(np.array(PFCstate)[:,0], "PFC")
    # MakeAnimation_img(np.array(HPCstate)[:,0], "HPC")

    _plot_weight_matrices(rnn, "ih", include_output=True)
    _plot_weight_matrices(rnn, "hh", include_output=False)

    # HPC states of batch element 0, with four 120-step segments highlighted
    # and markers at offset 40, the midpoint and the end of each segment.
    feature, ax3d = _plot_pca_3d(np.array(HPCstate)[:,0])
    colors = ['C1','C2','C3','C4']
    for i, color in enumerate(colors):
        start = i*120
        end = (i+1)*120
        mid = (start+end)//2
        ax3d.plot(feature[start:end, 0], feature[start:end, 1], feature[start:end, 2], color=color, alpha=0.8)
        ax3d.plot(feature[start+40:start+41, 0], feature[start+40:start+41, 1], feature[start+40:start+41, 2], "o", color=color, alpha=0.8)
        ax3d.plot(feature[mid:mid+1, 0], feature[mid:mid+1, 1], feature[mid:mid+1, 2], "o", color=color, alpha=0.5)
        ax3d.plot(feature[end:end+1, 0], feature[end:end+1, 1], feature[end:end+1, 2], "o", color=color, alpha=0.5)
    plt.show()

    PFC_feature, _ = _plot_pca_3d(np.array(PFCstate)[:,0])
    plt.show()

    # Restate already holds one row per step (hidden[2][0]), so no batch index here.
    feature, ax3d = _plot_pca_3d(np.array(Restate))
    ax3d.plot(moving_average(feature[:, 0]), moving_average(feature[:, 1]), moving_average(feature[:, 2]), alpha=0.8)
    plt.show()

    return PFC_feature
    
    
if __name__ == '__main__':
    # Analyse checkpoint index 1 only; widen the range to compare several.
    features = [main(seed) for seed in range(1, 2)]

# Optional overlay of the returned PFC projections:
#     fig2d = plt.figure()
#     for feature in features:
#         plt.plot(feature[:200, 0], feature[:200, 1], alpha=0.2)
#         plt.plot(feature[0:1, 0], feature[0:1, 1], "o", alpha=1)
#     plt.show()

In [None]:
#v4 without Re
#v4_1 revised forward function

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import time
import math
import matplotlib.pyplot as plt
import matplotlib.animation as animation
import sklearn 
import os
from sklearn.decomposition import PCA
from sklearn.cluster import KMeans
from mpl_toolkits.mplot3d import Axes3D

%matplotlib notebook

def Culc_gate(input, params, hiddens):
    """Recompute the sigmoid of the PFC and HPC LSTM gate pre-activations.

    Written for a two-module model whose state is ``[PFC (h, c), HPC (h, c)]``,
    where PFC is driven by the HPC hidden state and HPC by
    ``cat([input, PFC h])``. Only batch element 0 is evaluated.

    Args:
        input: model input, presumably (batch, input_size) -- TODO confirm.
        params: parameter tensors in ``Module.parameters()`` order, assumed to
            start with PFC (weight_ih, weight_hh, bias_ih, bias_hh) followed
            by HPC (the same four).
        hiddens: ``[PFC (h, c), HPC (h, c), ...]``.

    Returns:
        ``[PFC_gates, HPC_gates]`` as Python lists, one value per row of the
        corresponding ``weight_ih``.

    NOTE(review): nn.LSTMCell squashes the cell-candidate block with tanh,
    not sigmoid; here all four gate blocks go through sigmoid. The input
    widths also do not match MyLSTM_RNN, whose PFC takes 2 * hidden_size
    inputs -- confirm which model this is meant for.
    """
    pfc_h = hiddens[0][0]
    hpc_h = hiddens[1][0]
    HPC_input = torch.cat([input, pfc_h], dim=1)
    # The recurrent term of an LSTM gate uses the previous hidden state h
    # (index 0 of each (h, c) pair), not the cell state c (index 1).
    PFC_gates = params[0] @ hpc_h[0] + params[1] @ pfc_h[0] + params[2] + params[3]
    HPC_gates = params[4] @ HPC_input[0] + params[5] @ hpc_h[0] + params[6] + params[7]
    return [torch.sigmoid(PFC_gates).tolist(), torch.sigmoid(HPC_gates).tolist()]
    

def MakeAnimation(data_x,data_y,traj_x,traj_y, data_limit):
    """Animate a generated trajectory over the reference data and save a GIF.

    Frame ``t`` shows the first ``t`` points of the reference (blue) and of
    the generated trajectory (red), with a red marker on the newest generated
    point. From ``t == data_limit`` onward the solid reference stops at
    ``data_limit`` and its continuation is drawn dashed. The animation is
    written to ``anim_v4_2_1_r.gif``.

    Returns:
        0
    """
    fig, ax = plt.subplots()
    ims = []
    for t in range(traj_x.shape[0]):
        # Newest drawn point. Clamp at 0 so frame 0 does not wrap to index -1,
        # which would put the marker on the last point of the whole trajectory.
        head = max(t - 1, 0)
        if t < data_limit:
            traj = ax.plot(data_x[:t], data_y[:t], "b",
                           traj_x[:t], traj_y[:t], "r",
                           traj_x[head], traj_y[head], "or")
        else:
            traj = ax.plot(data_x[:data_limit], data_y[:data_limit], "b",
                           traj_x[:t], traj_y[:t], "r",
                           traj_x[head], traj_y[head], "or",
                           data_x[data_limit-1:t], data_y[data_limit-1:t], "--b")
        title = ax.text(0.5, 1.01, 'time={}'.format(t),
                        ha='center', va='bottom',
                        transform=ax.transAxes, fontsize='large')
        ims.append(traj+[title])
    ani = animation.ArtistAnimation(fig, ims, interval=100)
    ani.save('anim_v4_2_1_r.gif', writer='pillow')
    plt.show()

    return 0

def MakeAnimation_attracter(data_x,data_y):
    """Animate a 2-D trajectory growing point by point and save it as a GIF.

    Frame ``t`` draws the first ``t`` points in blue with a marker on the
    newest one; the result is written to ``anim_attracter2.gif``.

    Returns:
        0
    """
    fig, ax = plt.subplots()
    ims = []
    for t in range(data_x.shape[0]):
        # Clamp so frame 0 marks the first point instead of wrapping to -1.
        head = max(t - 1, 0)
        traj = ax.plot(data_x[:t], data_y[:t], "b", data_x[head], data_y[head], "ob")
        title = ax.text(0.5, 1.01, 'time={}'.format(t),
                        ha='center', va='bottom',
                        transform=ax.transAxes, fontsize='large')
        ims.append(traj+[title])
    ani = animation.ArtistAnimation(fig, ims, interval=100)
    ani.save('anim_attracter2.gif', writer='pillow')
    plt.show()

    return 0


    
def MakeAnimation_img(data,module):
    """Animate each row of *data* as a one-pixel-high image strip.

    Frame ``t`` shows ``data[t]`` reshaped to ``(1, data.shape[1])`` with a
    caption ``"<module> time=<t>"``. The animation is saved as
    ``anim_fire_<module>_.gif``.

    Returns:
        0
    """
    fig, ax = plt.subplots()
    frames = []
    for step in range(data.shape[0]):
        strip = data[step].reshape(1, data.shape[1])
        artist = ax.imshow(strip)
        caption = ax.text(
            0.5, 1.01, module + ' time={}'.format(step),
            ha='center', va='bottom',
            transform=ax.transAxes, fontsize='large',
        )
        frames.append([artist, caption])
    ani = animation.ArtistAnimation(fig, frames)
    ani.save('anim_fire_' + module + '_.gif', writer='pillow')
    plt.show()

    return 0

def mkOwnDataSet(data_size, filename, data_length=100, freq=60., noise=0.01):
    """Build ``data_size`` noisy, subsampled copies of two recorded trajectories.

    Reads ``<filename>_r.csv`` and ``<filename>_l.csv`` (comma separated) and
    keeps columns 0 and 1 of every ``round(rows / data_length)``-th row.
    Each kept coordinate gets independent Gaussian noise
    ``N(0, noise**2)`` drawn from ``np.random``.

    Args:
        data_size: number of noisy copies to generate.
        filename: path stem of the two CSV files.
        data_length: approximate number of points per copy.
        freq: unused; kept for signature compatibility.
        noise: standard deviation of the added noise.

    Returns:
        ``(train_x, train_y)``: lists of ``data_size`` copies, each a list of
        ``[x, y]`` points (from the ``_r`` and ``_l`` file respectively).
    """
    x = np.loadtxt(filename + "_r.csv", delimiter=',')
    y = np.loadtxt(filename + "_l.csv", delimiter=',')

    def _indices(arr):
        # round() gives 0 when the file has fewer than data_length / 2 rows,
        # and np.arange cannot take a zero step; fall back to every row.
        step = max(1, round(arr.shape[0] / data_length))
        return np.arange(0, arr.shape[0], step)

    def _noisy_copy(arr, indices):
        # x-noise is drawn before y-noise for every point, in index order.
        return [[arr[i][0] + np.random.normal(loc=0.0, scale=noise),
                 arr[i][1] + np.random.normal(loc=0.0, scale=noise)]
                for i in indices]

    x_idx = _indices(x)
    y_idx = _indices(y)
    train_x = []
    train_y = []
    for _ in range(data_size):
        train_x.append(_noisy_copy(x, x_idx))
        train_y.append(_noisy_copy(y, y_idx))
    return train_x, train_y


def mkOwnRandomBatch(train_x, train_t=None, batch_size=10):
    """Draw a random mini-batch (with replacement) from ``train_x``.

    Args:
        train_x: sequence of samples, each a ``(seq_len, features)`` nested
            list of numbers.
        train_t: targets are not used. The call sites visible in this
            notebook call ``mkOwnRandomBatch(train_x, batch_size)``, which
            lands the batch size in this slot, so an integer here is taken as
            the batch size.
        batch_size: number of samples to draw.

    Returns:
        ``torch.Tensor`` of shape ``(seq_len, batch_size, features)``. Only
        the inputs are returned; no targets.
    """
    if isinstance(train_t, (int, np.integer)) and not isinstance(train_t, bool):
        batch_size = int(train_t)
    batch_x = []
    for _ in range(batch_size):
        # Samples are picked independently, so every index is eligible.
        idx = np.random.randint(0, len(train_x))
        batch_x.append(train_x[idx])
    return torch.tensor(batch_x).transpose(0,1)


class MyLSTM(nn.Module):
    """Three coupled LSTMCell modules (PFC, HPC, Re) with a linear readout.

    Each step reads only the *previous* hidden states:

    * Re  <- cat([PFC h, HPC h])
    * HPC <- cat([input, Re h])
    * PFC <- cat([HPC h, Re h])

    and the prediction is ``linear`` applied to the new HPC hidden state.
    The state is ``[PFC (h, c), HPC (h, c), Re (h, c)]``, every tensor of
    shape ``(batch_size, hidden_size)``.
    """

    def __init__(self, input_size, hidden_size, output_size, batch_size):
        super().__init__()

        self.hidden_size = hidden_size
        # Creation order fixes the RNG draws of the default initialisation
        # and the order of self.parameters(): PFC, HPC, Re, linear.
        self.PFC = nn.LSTMCell(2 * hidden_size, hidden_size)
        self.HPC = nn.LSTMCell(hidden_size + input_size, hidden_size)
        self.Re = nn.LSTMCell(2 * hidden_size, hidden_size)
        self.linear = nn.Linear(hidden_size, output_size)
        self.batch_size = batch_size
        # NOTE(review): nn.init.sparse_ takes the *fraction* of each column to
        # zero, so sparsity 1 leaves these matrices entirely zero -- confirm.
        for weight in (self.PFC.weight_ih, self.HPC.weight_ih,
                       self.PFC.weight_hh, self.HPC.weight_hh):
            nn.init.sparse_(weight.data, 1)

    def forward(self, input, hiddens):
        """Advance all three modules one step; return (prediction, new state)."""
        pfc_state, hpc_state, re_state = hiddens[0], hiddens[1], hiddens[2]
        pfc_h, hpc_h, re_h = pfc_state[0], hpc_state[0], re_state[0]

        new_re = self.Re(torch.cat([pfc_h, hpc_h], dim=1), re_state)
        new_hpc = self.HPC(torch.cat([input, re_h], dim=1), hpc_state)
        new_pfc = self.PFC(torch.cat([hpc_h, re_h], dim=1), pfc_state)
        return self.linear(new_hpc[0]), [new_pfc, new_hpc, new_re]

    def initHidden(self):
        """Return an all-zero state ``[PFC (h, c), HPC (h, c), Re (h, c)]``."""
        shape = (self.batch_size, self.hidden_size)
        pfc, hpc, re = ([torch.zeros(shape), torch.zeros(shape)] for _ in range(3))
        return [pfc, hpc, re]

    def initHidden_rand(self):
        """Return a state with entries drawn uniformly from [0, 0.1).

        Draw order is HPC (h, c), then PFC (h, c), then Re (h, c).
        """
        scale = 0.1
        shape = (self.batch_size, self.hidden_size)
        hpc = [torch.rand(shape) * scale for _ in range(2)]
        pfc = [torch.rand(shape) * scale for _ in range(2)]
        re = [torch.rand(shape) * scale for _ in range(2)]
        return [pfc, hpc, re]

class MyLSTM_RNN(nn.Module):
    """PFC and HPC LSTMCells coupled through a plain RNNCell "Re" module.

    Each step reads only the *previous* states:

    * Re  <- cat([PFC h, HPC h])
    * HPC <- cat([input, Re h])
    * PFC <- cat([HPC h, Re h])

    and the prediction is ``linear`` applied to the new HPC hidden state.
    The state is ``[PFC (h, c), HPC (h, c), Re h]``: Re is an RNNCell and so
    carries a single tensor. Every tensor has shape
    ``(batch_size, hidden_size)``.
    """

    def __init__(self, input_size, hidden_size, output_size, batch_size):
        super().__init__()

        self.hidden_size = hidden_size
        # Creation order fixes the RNG draws of the default initialisation
        # and the order of self.parameters(): PFC, HPC, Re, linear.
        self.PFC = nn.LSTMCell(2 * hidden_size, hidden_size)
        self.HPC = nn.LSTMCell(hidden_size + input_size, hidden_size)
        self.Re = nn.RNNCell(2 * hidden_size, hidden_size)
        self.linear = nn.Linear(hidden_size, output_size)
        self.batch_size = batch_size
        # NOTE(review): nn.init.sparse_ takes the *fraction* of each column to
        # zero; any value >= 1 (here 10) leaves these matrices entirely zero.
        sparsity = 10
        for weight in (self.PFC.weight_ih, self.HPC.weight_ih,
                       self.PFC.weight_hh, self.HPC.weight_hh):
            nn.init.sparse_(weight.data, sparsity)

    def forward(self, input, hiddens):
        """Advance all three modules one step; return (prediction, new state)."""
        pfc_state, hpc_state, re_h = hiddens[0], hiddens[1], hiddens[2]
        pfc_h = pfc_state[0]
        hpc_h = hpc_state[0]

        new_re = self.Re(torch.cat([pfc_h, hpc_h], dim=1), re_h)
        new_hpc = self.HPC(torch.cat([input, re_h], dim=1), hpc_state)
        new_pfc = self.PFC(torch.cat([hpc_h, re_h], dim=1), pfc_state)
        return self.linear(new_hpc[0]), [new_pfc, new_hpc, new_re]

    def initHidden(self):
        """Return an all-zero state ``[PFC (h, c), HPC (h, c), Re h]``."""
        shape = (self.batch_size, self.hidden_size)
        pfc = [torch.zeros(shape), torch.zeros(shape)]
        hpc = [torch.zeros(shape), torch.zeros(shape)]
        return [pfc, hpc, torch.zeros(shape)]

    def initHidden_rand(self):
        """Return a state with entries drawn uniformly from [0, 0.1).

        Draw order is HPC (h, c), then PFC (h, c), then Re.
        """
        scale = 0.1
        shape = (self.batch_size, self.hidden_size)
        hpc = [torch.rand(shape) * scale for _ in range(2)]
        pfc = [torch.rand(shape) * scale for _ in range(2)]
        re = torch.rand(shape) * scale
        return [pfc, hpc, re]

    def initHidden_test(self):
        """Return zero HPC/Re states and a uniform [0, 1) PFC state (h, c)."""
        zero_scale = 0
        pfc_scale = 1
        shape = (self.batch_size, self.hidden_size)
        hpc = [torch.ones(shape) * zero_scale, torch.ones(shape) * zero_scale]
        pfc = [torch.rand(shape) * pfc_scale for _ in range(2)]
        re = torch.ones(shape) * zero_scale
        return [pfc, hpc, re]

    def hidden_randinsert(self,hidden):
        """Return *hidden* with the HPC (h, c) pair replaced by uniform [0, 1) noise.

        The PFC and Re entries are passed through unchanged (same objects).
        """
        scale = 1
        shape = (self.batch_size, self.hidden_size)
        fresh_hpc = [torch.rand(shape) * scale for _ in range(2)]
        return [hidden[0], fresh_hpc, hidden[2]]
    
class MyLSTM_3lay(nn.Module):
    """Three LSTMCell modules (PFC, HPC, Re) with a linear readout from HPC.

    Each step reads only the *previous* states:

    * HPC <- cat([input, Re h])
    * Re  <- cat([PFC h, HPC h])
    * PFC <- Re h

    and the prediction is ``linear`` applied to the new HPC hidden state.
    The state is ``[PFC (h, c), HPC (h, c), Re (h, c)]``, every tensor of
    shape ``(batch_size, hidden_size)``.
    """

    def __init__(self, input_size, hidden_size, output_size, batch_size):
        super(MyLSTM_3lay, self).__init__()

        self.hidden_size = hidden_size
        # Creation order fixes the RNG draws of the default initialisation
        # and the order of self.parameters(): PFC, HPC, Re, linear.
        self.PFC = nn.LSTMCell(hidden_size, hidden_size)
        self.HPC = nn.LSTMCell(input_size+hidden_size, hidden_size)
        self.Re = nn.LSTMCell(hidden_size*2, hidden_size)
        self.linear = nn.Linear(hidden_size, output_size)
        self.batch_size = batch_size
        # nn.init.sparse_ zeroes this fraction (10%) of every column.
        sparse = 0.1
        for weight in (self.PFC.weight_ih, self.HPC.weight_ih, self.Re.weight_ih,
                       self.PFC.weight_hh, self.HPC.weight_hh, self.Re.weight_hh):
            nn.init.sparse_(weight.data, sparse)

    def forward(self, input, hiddens):
        """Advance all three modules one step; return (prediction, new state)."""
        HPC_input = torch.cat([input,hiddens[2][0]],dim=1)
        HPC_hidden = self.HPC(HPC_input, hiddens[1])
        Re_input = torch.cat([hiddens[0][0],hiddens[1][0]],dim=1)
        Re_hidden = self.Re(Re_input,hiddens[2])
        PFC_input = hiddens[2][0]
        PFC_hidden = self.PFC(PFC_input, hiddens[0])
        output = self.linear(HPC_hidden[0])
        return output, [PFC_hidden, HPC_hidden, Re_hidden]

    def initHidden(self):
        """Return an all-zero state ``[PFC (h, c), HPC (h, c), Re (h, c)]``."""
        shape = (self.batch_size, self.hidden_size)
        return [[torch.zeros(shape), torch.zeros(shape)] for _ in range(3)]

    def initHidden_rand(self):
        """Return an all-zero state, exactly like :meth:`initHidden`.

        NOTE(review): unlike the other models' ``initHidden_rand``, nothing is
        randomised here -- confirm whether a random start was intended.
        """
        return self.initHidden()
    

def moving_average(data, window=5, mode='full'):
    """Return the ``window``-point moving average of a 1-D sequence.

    Convolves ``data`` with a uniform kernel of ``window`` taps.

    Args:
        data: 1-D array-like.
        window: number of samples averaged per output value.
        mode: ``np.convolve`` mode. With the default ``'full'`` the result has
            ``len(data) + window - 1`` samples, and the first and last
            ``window - 1`` values are partial sums divided by ``window``, so
            they taper toward zero. Use ``'valid'`` for full windows only.

    Returns:
        1-D ndarray of averaged values.
    """
    kernel = np.ones(window) / window
    return np.convolve(data, kernel, mode=mode)


    

def main(num):
    """Inspect a trained MyLSTM_RNN checkpoint on one random batch.

    Loads the checkpoint at ``model_path`` and runs a teacher-forced phase
    followed by a free-running phase. It then plots:

    * the reference vs. generated trajectory;
    * the recurrent (``weight_hh``) matrices of PFC, HPC and Re;
    * 3-D PCA projections of the recorded PFC, HPC and Re states.

    Args:
        num: run index. Only referenced by the commented-out ``model_path``
            variants; the active path is fixed.

    Returns:
        None. Returns early (after printing "Not exist") when the checkpoint
        file is missing.
    """
    training_size = 100
    test_size = 1000
    epochs_num = 10
    hidden_size = 20
    batch_size = 10
    data_length = 100
    inputsize = 2
    outputsize = 2
#     model_path = 'model/ReModel_interRNN_long131test_s'+str(num)+'_100_3.pth'
#     model_path = 'model/ReModel_interRNN_long131test_s1_100_3_plus_'+str(num)+'.pth'
    model_path = 'model/ReModel_interRNN_long131test_s1_100_1.pth'
    filename = "primal_long131test"

    if os.path.exists(model_path):
        print(model_path)
    else:
        print("Not exist")
        return

    train_x,train_y = mkOwnDataSet(training_size,filename,data_length)

    rnn = MyLSTM_RNN(inputsize, hidden_size, outputsize, batch_size)

    rnn.load_state_dict(torch.load(model_path))

    # data: (seq_len, batch_size, 2) tensor (mkOwnRandomBatch transposes batch/time).
    data = mkOwnRandomBatch(train_x, batch_size)

    # Only consumed by the commented-out Culc_gate calls below.
    params = []
    for p in rnn.parameters():
        params.append(p.data)





    traj = []
    PFCstate = []
    HPCstate = []
    Restate = []
    Gate_states = []
    # Small random start: every entry uniform in [0, 0.1) (MyLSTM_RNN.initHidden_rand).
    hidden = rnn.initHidden_rand()
    data_limit = 20
    # NOTE(review): the rollout runs with autograd enabled; wrapping it in
    # torch.no_grad() would avoid building a graph nothing back-propagates.
    # Teacher-forced phase: feed the first data_limit reference points.
    for k in range(data_limit):
#             print(data[k].shape)
            output,hidden = rnn(data[k],hidden)
#             output,hidden = rnn(torch.rand(10,2),hidden)
            traj.append(output.tolist())
            PFCstate.append(hidden[0][0].tolist())
            HPCstate.append(hidden[1][0].tolist())
            # Re state is a single tensor, so [0] is already batch element 0.
            Restate.append(hidden[2][0].tolist())
            #Gate_states.append(Culc_gate(output,params,hidden))
    # Perturbation: replace HPC (h, c) with uniform [0, 1) noise; PFC and Re are kept.
    hidden = rnn.hidden_randinsert(hidden)
    # Free-running phase: the model's output is fed back as its next input,
    # for four times the sequence length.
    for k in range(data.shape[0]*4):
            output,hidden = rnn(output,hidden) 
            traj.append(output.tolist())
            PFCstate.append(hidden[0][0].tolist())
            HPCstate.append(hidden[1][0].tolist())
            Restate.append(hidden[2][0].tolist())
            #Gate_states.append(Culc_gate(output,params,hidden))
    traj = torch.tensor(traj)
    pltdata = torch.squeeze(data).numpy()
    traj = torch.squeeze(traj).numpy()
    # Reference (dashed) vs. generated trajectory for batch element 0.
    # NOTE(review): the colour split at index 100 does not coincide with
    # data_limit (20), where the free-running phase starts -- confirm intended.
    fig = plt.figure()
    plt.plot(pltdata[:,0,0],pltdata[:,0,1],"--")
    plt.plot(traj[:100,0,0],traj[:100,0,1])
    plt.plot(traj[100:,0,0],traj[100:,0,1])
    plt.show()


    print(np.array(PFCstate)[:,0].shape)
#     MakeAnimation(pltdata[:,0,0],pltdata[:,0,1], traj[:,0,0], traj[:,0,1], data_limit)
    #MakeAnimation_img(np.array(PFCstate)[:,0],"PFC")
    #MakeAnimation_img(np.array(HPCstate)[:,0],"HPC")

#     for n, p in rnn.named_parameters():
#             if n == "PFC.weight_ih":
#                 PFC_w = np.array(p.data)
#             if n == "HPC.weight_ih":
#                 HPC_w = np.array(p.data)
#             if n == "Re.weight_ih":
#                 Re_w = np.array(p.data)

#     fig2 = plt.figure()
#     ax1 = fig2.add_subplot(131)
#     ax2 = fig2.add_subplot(132)
#     axre = fig2.add_subplot(133)

#     ax1.imshow(PFC_w,cmap="coolwarm")
#     ax2.imshow(HPC_w,cmap="coolwarm")
#     axre.imshow(Re_w,cmap="coolwarm")
#     ax1.set_title("max = {:.2f},min = {:.2f}".format(np.max(PFC_w),np.min(PFC_w)))
#     ax2.set_title("max = {:.2f},min = {:.2f}".format(np.max(HPC_w),np.min(HPC_w)))
#     axre.set_title("max = {:.2f},min = {:.2f}".format(np.max(Re_w),np.min(Re_w)))

    # Heat maps of the recurrent weight matrices, titled with their max/min.
    for n, p in rnn.named_parameters():
            if n == "PFC.weight_hh":
                PFC_w = np.array(p.data)
            if n == "HPC.weight_hh":
                HPC_w = np.array(p.data)
            if n == "Re.weight_hh":
                Re_w = np.array(p.data)

    fig2 = plt.figure()
    ax1 = fig2.add_subplot(131)
    ax2 = fig2.add_subplot(132)
    axre = fig2.add_subplot(133)

    ax1.imshow(PFC_w,cmap="coolwarm")
    ax2.imshow(HPC_w,cmap="coolwarm")
    axre.imshow(Re_w,cmap="coolwarm")
    ax1.set_title("max = {:.2f},min = {:.2f}".format(np.max(PFC_w),np.min(PFC_w)))
    ax2.set_title("max = {:.2f},min = {:.2f}".format(np.max(HPC_w),np.min(HPC_w)))
    axre.set_title("max = {:.2f},min = {:.2f}".format(np.max(Re_w),np.min(Re_w)))                


    # PCA of HPC hidden states of batch element 0.
    pca = PCA()
    dfs = np.array(HPCstate)[0:,0]
    pca.fit(dfs)
    feature = pca.transform(dfs)

    # Cluster labels are computed but not used further.
    pred = KMeans(n_clusters=2).fit_predict(feature)

    # NOTE(review): since matplotlib 3.6, Axes3D(fig) no longer adds itself to
    # the figure; fig3d.add_subplot(projection='3d') is the supported form.
    fig3d = plt.figure()
    ax3d = Axes3D(fig3d)
    ax3d.plot(feature[:100, 0], feature[:100, 1], feature[:100, 2], alpha=0.8)
    ax3d.plot(feature[100:, 0], feature[100:, 1], feature[100:, 2], alpha=0.8)
    # Markers: the first recorded step and step data_limit (the first free-running step).
    ax3d.plot(feature[0:1, 0], feature[0:1, 1], feature[0:1, 2],"o", alpha=1)
    ax3d.plot(feature[20:21, 0], feature[20:21, 1], feature[20:21, 2],"o", alpha=1)
    plt.show()
#     for i in range(20):
#         fig_place = plt.figure()
#         place3d = Axes3D(fig_place)
#         place3d.plot(traj[:,0,0], traj[:,0,1], np.array(HPCstate)[:,0,i])
#         plt.show()

#     moving_feature = np.concatenate(([moving_average(feature[:, 0])], [moving_average(feature[:, 1])], [moving_average(feature[:, 2])]),axis=0).T
#     print(moving_feature.shape)
#     plot_distance(feature,moving_feature)
#     plt.figure()
#     plt.plot(overlap_coefficient(moving_feature,feature))
#     plt.plot(moving_average(overlap_coefficient(moving_feature,feature)))

    # Same PCA view for PFC hidden states of batch element 0.
    pca = PCA()
    dfs = np.array(PFCstate)[0:,0]
    pca.fit(dfs)
    feature = pca.transform(dfs)

    pred = KMeans(n_clusters=2).fit_predict(feature)

    fig3d = plt.figure()
    ax3d = Axes3D(fig3d)
    ax3d.plot(feature[:100, 0], feature[:100, 1], feature[:100, 2], alpha=0.8)
    ax3d.plot(feature[100:, 0], feature[100:, 1], feature[100:, 2], alpha=0.8)
    ax3d.plot(feature[0:1, 0], feature[0:1, 1], feature[0:1, 2],"o", alpha=1)
    ax3d.plot(feature[20:21, 0], feature[20:21, 1], feature[20:21, 2],"o", alpha=1)
    plt.show()
#     for i in range(20):
#         fig_place = plt.figure()
#         place3d = Axes3D(fig_place)
#         place3d.plot(traj[:,0,0], traj[:,0,1], np.array(PFCstate)[:,0,i])
#         plt.show()

#     moving_feature = np.concatenate(([moving_average(feature[:, 0])], [moving_average(feature[:, 1])], [moving_average(feature[:, 2])]),axis=0).T
#     print(moving_feature.shape)
#     plot_distance(feature,moving_feature)
#     plt.figure()
#     plt.plot(overlap_coefficient(moving_feature,feature))
#     plt.plot(moving_average(overlap_coefficient(moving_feature,feature)))

    # Re states were stored as batch element 0 already, hence no [:, 0] here.
    pca = PCA()
    dfs = np.array(Restate)
    pca.fit(dfs)
    feature = pca.transform(dfs)

    pred = KMeans(n_clusters=2).fit_predict(feature)

    fig3d = plt.figure()
    ax3d = Axes3D(fig3d)
    ax3d.plot(feature[:100, 0], feature[:100, 1], feature[:100, 2], alpha=0.8)
    ax3d.plot(feature[100:, 0], feature[100:, 1], feature[100:, 2], alpha=0.8)
    ax3d.plot(feature[0:1, 0], feature[0:1, 1], feature[0:1, 2],"o", alpha=1)
    ax3d.plot(feature[20:21, 0], feature[20:21, 1], feature[20:21, 2],"o", alpha=1)
    # 5-point smoothed curve (moving_average uses 'full' convolution, so it has 4 extra samples).
    ax3d.plot(moving_average(feature[:, 0]), moving_average(feature[:, 1]), moving_average(feature[:, 2]), alpha=0.8)
    plt.show()
#     for i in range(20):
#         fig_place = plt.figure()
#         place3d = Axes3D(fig_place)
#         place3d.plot(traj[:,0,0], traj[:,0,1], np.array(Restate)[:,i])
#         plt.show()

#     moving_feature = np.concatenate(([moving_average(feature[:, 0])], [moving_average(feature[:, 1])], [moving_average(feature[:, 2])]),axis=0).T
#     print(moving_feature.shape)
#     plot_distance(feature,moving_feature)

#     plt.figure()
#     plt.plot(overlap_coefficient(moving_feature,feature))
#     plt.plot(moving_average(overlap_coefficient(moving_feature,feature)))
    
    
if __name__ == '__main__':
    # Evaluate runs 1..10 in sequence.
    for i in range(10):
        main(i+1)

In [None]:

def main(num, model_path=None,
         second_model_path='model/ReModel_interRNN_long_s5_200_1.pth'):
    """Roll out two MyLSTM_RNN checkpoints and PCA-plot their pooled states.

    Each checkpoint is loaded into a fresh MyLSTM_RNN and driven for
    ``data_limit`` teacher-forced steps on a random batch of its own dataset.
    It then free-runs, its output fed back as input, for one sequence length.
    The PFC, HPC and Re states of batch element 0 from both rollouts are
    concatenated. Each set is projected with PCA and drawn in 3-D, with the
    first ``data_limit`` recorded steps as a separate line.

    Args:
        num: run index; unused by the active code.
        model_path: checkpoint evaluated on the "primal_Y_long" data. Required.
        second_model_path: checkpoint evaluated on the "primal_long" data.

    Raises:
        ValueError: if ``model_path`` is not supplied.
    """
    # There is no usable default for the first checkpoint; fail with a clear
    # message instead of on the first use of an unassigned name.
    if model_path is None:
        raise ValueError("main() needs model_path for the 'primal_Y_long' checkpoint")

    training_size = 100
    hidden_size = 20
    batch_size = 10
    data_length = 200
    inputsize = 2
    outputsize = 2
    data_limit = 200

    traj = []
    PFCstate = []
    HPCstate = []
    Restate = []

    def record(output, hidden):
        # Keep the prediction and batch element 0 of every module's state.
        traj.append(output.tolist())
        PFCstate.append(hidden[0][0].tolist())
        HPCstate.append(hidden[1][0].tolist())
        Restate.append(hidden[2][0].tolist())

    for path, filename in ((model_path, "primal_Y_long"),
                           (second_model_path, "primal_long")):
        print(path)
        train_x, train_y = mkOwnDataSet(training_size, filename, data_length)
        rnn = MyLSTM_RNN(inputsize, hidden_size, outputsize, batch_size)
        rnn.load_state_dict(torch.load(path))
        data = mkOwnRandomBatch(train_x, batch_size)
        hidden = rnn.initHidden()
        # Inference only: no autograd graph is needed for the recorded states.
        with torch.no_grad():
            # Teacher-forced phase.
            for k in range(data_limit):
                output, hidden = rnn(data[k], hidden)
                record(output, hidden)
            # Free-running phase: feed the prediction back as the next input.
            for _ in range(data.shape[0]):
                output, hidden = rnn(output, hidden)
                record(output, hidden)

    def plot_pca_3d(states, mark_start=False, print_tail=False):
        # Fit PCA on the pooled states and draw the first three components,
        # splitting the curve at data_limit.
        pca = PCA()
        pca.fit(states)
        feature = pca.transform(states)
        if print_tail:
            print(feature[data_limit:])
        fig3d = plt.figure()
        ax3d = fig3d.add_subplot(111, projection='3d')
        ax3d.plot(feature[:data_limit, 0], feature[:data_limit, 1], feature[:data_limit, 2], alpha=0.8)
        ax3d.plot(feature[data_limit:, 0], feature[data_limit:, 1], feature[data_limit:, 2], alpha=0.8)
        if mark_start:
            ax3d.plot(feature[0:1, 0], feature[0:1, 1], feature[0:1, 2], "o", alpha=1)
        # Show once, after all artists are on the axes.
        plt.show()

    # LSTM states were stored as (batch, hidden); [:, 0] keeps batch element 0.
    plot_pca_3d(np.array(HPCstate)[:, 0], mark_start=True)
    plot_pca_3d(np.array(PFCstate)[:, 0], print_tail=True)
    # The Re state is a single tensor, so hidden[2][0] already is batch element 0.
    plot_pca_3d(np.array(Restate))
    
    
if __name__ == '__main__':
    # Single run with index 1.
    for i in range(1):
        main(i+1)

In [None]:
def check_distance(list_a,list_b):
    """Return, for each point of *list_b*, the distance to its nearest point in *list_a*.

    Distances are Euclidean over the first three coordinates only.

    Args:
        list_a: reference points, array-like of shape (n_a, d).
        list_b: query points, array-like of shape (n_b, d).

    Returns:
        1-D ndarray of length n_b (empty when *list_b* is empty).

    Raises:
        ValueError: if *list_a* is empty while *list_b* is not.
    """
    if len(list_b) == 0:
        return np.array([])
    if len(list_a) == 0:
        raise ValueError("list_a must contain at least one point")
    a = np.asarray(list_a, dtype=float)[:, :3]
    b = np.asarray(list_b, dtype=float)[:, :3]
    result = np.empty(len(b))
    # Process query points in chunks so the (chunk, n_a, 3) difference array
    # stays at roughly 2**20 point pairs regardless of input size.
    chunk = max(1, 2**20 // len(a))
    for start in range(0, len(b), chunk):
        diff = b[start:start + chunk, None, :] - a[None, :, :]
        result[start:start + chunk] = np.linalg.norm(diff, axis=-1).min(axis=1)
    return result

def overlap_coefficient(list_a,list_b):
    """Return the nearest-neighbour distance from every point of *list_b* to *list_a*.

    Thin wrapper around ``check_distance(list_a, list_b)``: element ``j`` is
    the distance from ``list_b[j]`` to the closest point of *list_a*, using
    only the first three coordinates. Despite the name, no overlap
    coefficient (no thresholding or normalisation) is computed.
    """
    return check_distance(list_a, list_b)

def plot_distance(data1,data2):
    """Plot, in a new figure, the distance from each point of *data2* to the nearest point of *data1*."""
    distances = overlap_coefficient(data1, data2)
    plt.figure()
    # plt.ylim(0, 1)
    plt.plot(distances)

In [None]:

def main():
    """Map each HPC hidden unit's response over a grid of constant 2-D inputs.

    First the Re -> HPC input weights are zeroed, so HPC is driven only by
    the 2-D input and its own recurrence. For every grid point the network is
    reset with ``initHidden`` and fed that point ``data_limit`` times. The
    final HPC hidden state of batch element 0 is stored. For every unit the
    map is shown as a heat map and as a 3-D surface.
    """
    hidden_size = 20
    batch_size = 10
    inputsize = 2
    outputsize = 2
    model_path = 'model/ReModel_interRNN_long_s9_200_1.pth'

    rnn = MyLSTM_RNN(inputsize, hidden_size, outputsize, batch_size)
    rnn.load_state_dict(torch.load(model_path))

    # HPC's input is cat([input, Re state]); columns from inputsize onward
    # weight the Re state, so zeroing them cuts the Re -> HPC pathway.
    rnn.HPC.weight_ih.data[:, inputsize:].zero_()

    xnum = 100
    ynum = 100
    initx = 0.2
    endx = 0.8
    widex = (endx - initx) / xnum

    inity = 0.0
    endy = 0.6
    widey = (endy - inity) / ynum

    data_limit = 20
    # place[unit][y][x]: rows are y so imshow(origin='lower') draws x
    # horizontally. A (xnum+1, ynum+1) allocation breaks whenever xnum != ynum.
    place = np.zeros((hidden_size, ynum + 1, xnum + 1))
    with torch.no_grad():
        for y in range(ynum + 1):
            for x in range(xnum + 1):
                hidden = rnn.initHidden()
                point = torch.tensor([[initx + widex * x, inity + widey * y] for _ in range(batch_size)])
                for _ in range(data_limit):
                    output, hidden = rnn(point, hidden)
                # HPC hidden state of batch element 0, converted once per point.
                place[:, y, x] = hidden[1][0][0].tolist()

    # Grid coordinates; linspace yields exactly n + 1 samples, whereas a
    # float-step arange can produce one extra and mismatch place[i].
    xs, ys = np.meshgrid(np.linspace(initx, endx, xnum + 1),
                         np.linspace(inity, endy, ynum + 1))
    for i in range(hidden_size):
        plt.figure()
        plt.imshow(place[i], cmap='jet', origin='lower', interpolation='bilinear', vmin=0)
        plt.show()

        fig = plt.figure()
        ax = fig.add_subplot(111, projection='3d')
        ax.plot_surface(xs, ys, place[i], cmap='jet')
        plt.show()
    
if __name__ == '__main__':
    # Run this cell's grid probe.
    main()

In [None]:
#init ari


def MakeAnimation_img(data,module):
    """Animate the 2-D maps ``data[0], data[1], ...`` and save them as a GIF.

    Every frame is drawn with ``imshow`` (jet colormap, origin at the lower
    left, bilinear interpolation), 50 ms apart. The output file is
    ``anim_fire_<module>_.gif``.

    Returns:
        0
    """
    fig, ax = plt.subplots()
    frames = [
        [ax.imshow(data[step], cmap='jet', origin='lower', interpolation='bilinear')]
        for step in range(data.shape[0])
    ]
    ani = animation.ArtistAnimation(fig, frames, interval=50)
    ani.save('anim_fire_' + module + '_.gif', writer='pillow')
    plt.show()

    return 0


def main():
    """Animate how HPC unit responses over a grid of constant inputs evolve.

    Grid points are visited in a shuffled order: rows once, and the columns
    reshuffled for every row. At each point the network is reset with
    ``initHidden`` and fed the point ``data_limit`` times. The HPC hidden
    state of batch element 0 is recorded after every step. For each unit,
    the map after the first step is shown and an animation over all steps is
    saved via ``MakeAnimation_img``.
    """
    hidden_size = 20
    batch_size = 10
    inputsize = 2
    outputsize = 2
    model_path = 'model/ReModel_interRNN_long_s5_200_1.pth'

    rnn = MyLSTM_RNN(inputsize, hidden_size, outputsize, batch_size)
    rnn.load_state_dict(torch.load(model_path))

    xnum = 20
    ynum = 20
    initx = 0
    endx = 1
    widex = (endx - initx) / xnum

    inity = 0.0
    endy = 0.6
    widey = (endy - inity) / ynum

    data_limit = 100
    # place[step][unit][y][x]: rows are y so imshow(origin='lower') draws x
    # horizontally; also correct when xnum != ynum.
    place = np.zeros((data_limit, hidden_size, ynum + 1, xnum + 1))

    x_list = np.array(range(xnum + 1))
    y_list = np.array(range(ynum + 1))
    np.random.shuffle(y_list)
    # Inference only: no autograd graph is needed.
    with torch.no_grad():
        for y in y_list:
            # x_list is reshuffled in place for every row.
            np.random.shuffle(x_list)
            for x in x_list:
                hidden = rnn.initHidden()
                point = torch.tensor([[initx + widex * x, inity + widey * y] for _ in range(batch_size)])
                for k in range(data_limit):
                    output, hidden = rnn(point, hidden)
                    # HPC hidden state of batch element 0, converted once per step.
                    place[k, :, y, x] = hidden[1][0][0].tolist()

    for i in range(hidden_size):
        plt.figure()
        plt.imshow(place[0][i], cmap='jet', origin='lower', interpolation='bilinear')
        plt.show()

        MakeAnimation_img(place[:, i], "test!!!" + str(i))
        
    
if __name__ == '__main__':
    # Run this cell's grid-probe animation.
    main()

In [None]:
#init nashi

def main():
    """Map HPC unit responses over a grid of inputs without resetting the state.

    The Re -> HPC input weights are zeroed first. The hidden state is
    initialised once and carried from one grid point to the next. Points are
    visited in random order (rows shuffled, then the columns of each row) and
    each is presented for a single step. The HPC hidden state of batch
    element 0 after that step is stored. For every unit the map is shown as a
    heat map and as a 3-D surface.
    """
    hidden_size = 20
    batch_size = 10
    inputsize = 2
    outputsize = 2
    model_path = 'model/ReModel_interRNN_long_s5_200_1.pth'

    rnn = MyLSTM_RNN(inputsize, hidden_size, outputsize, batch_size)
    rnn.load_state_dict(torch.load(model_path))

    # HPC's input is cat([input, Re state]); columns from inputsize onward
    # weight the Re state, so zeroing them cuts the Re -> HPC pathway.
    rnn.HPC.weight_ih.data[:, inputsize:].zero_()

    xnum = 50
    ynum = 50
    initx = 0.2
    endx = 0.8
    widex = (endx - initx) / xnum

    inity = 0.0
    endy = 0.6
    widey = (endy - inity) / ynum

    # place[unit][y][x]: rows are y so imshow(origin='lower') draws x
    # horizontally; also correct when xnum != ynum.
    place = np.zeros((hidden_size, ynum + 1, xnum + 1))
    hidden = rnn.initHidden()

    y_list = np.array(range(ynum + 1))
    np.random.shuffle(y_list)
    # The state is carried across all grid points; without no_grad the
    # autograd graph would grow over every step of the sweep.
    with torch.no_grad():
        for y in y_list:
            x_list = np.array(range(xnum + 1))
            np.random.shuffle(x_list)
            for x in x_list:
                point = torch.tensor([[initx + widex * x, inity + widey * y] for _ in range(batch_size)])
                output, hidden = rnn(point, hidden)
                place[:, y, x] = hidden[1][0][0].tolist()

    # Grid coordinates; linspace yields exactly n + 1 samples, whereas a
    # float-step arange can produce one extra and mismatch place[i].
    xs, ys = np.meshgrid(np.linspace(initx, endx, xnum + 1),
                         np.linspace(inity, endy, ynum + 1))
    for i in range(hidden_size):
        plt.figure()
        plt.imshow(place[i], cmap='jet', origin='lower', interpolation='bilinear')
        plt.show()

        fig = plt.figure()
        ax = fig.add_subplot(111, projection='3d')
        ax.plot_surface(xs, ys, place[i], cmap='jet')
        plt.show()
    
if __name__ == '__main__':
    # Run this cell's state-carrying grid sweep.
    main()

In [None]:
#v4 without Re
#v4_1 revised forward function

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import time
import math
import matplotlib.pyplot as plt
import matplotlib.animation as animation
import sklearn 
from sklearn.decomposition import PCA
from sklearn.cluster import KMeans
from mpl_toolkits.mplot3d import Axes3D
import networkx as nx
from networkx.algorithms import bipartite

%matplotlib notebook

def _noisy_subsample(points, data_length, noise):
    """Return every ``stride``-th row of *points* (first two columns) plus Gaussian noise.

    ``stride = round(n_rows / data_length)``, clamped to at least 1 so a file
    with fewer than ``data_length / 2`` rows no longer makes the stride 0.
    Noise is drawn row by row, column 0 then column 1, which matches the
    order of the original element-by-element ``np.random.normal`` calls.
    """
    stride = max(1, round(points.shape[0] / data_length))
    picked = points[::stride, :2]
    jitter = np.random.normal(loc=0.0, scale=noise, size=picked.shape)
    return (picked + jitter).tolist()


def mkOwnDataSet(data_size, data_length=100, freq=60., noise=0.005,
                 r_path="primal_long_r.csv", l_path="primal_long_l.csv"):
    """Build ``data_size`` noisy, down-sampled copies of two recorded 2-D trajectories.

    Args:
        data_size: number of noisy copies to generate.
        data_length: target number of samples per copy.  It sets the stride
            ``round(n_rows / data_length)`` (at least 1), so the actual count is
            ``ceil(n_rows / stride)``.
        freq: unused; kept so existing calls keep working.
        noise: standard deviation of the Gaussian noise added to every coordinate.
        r_path: comma-separated file read for ``train_x`` (first two columns used).
        l_path: comma-separated file read for ``train_y`` (first two columns used).

    Returns:
        ``(train_x, train_y)``: nested lists shaped ``(data_size, n_samples, 2)``.
    """
    # ndmin=2 keeps a single-row file 2-D so row/column indexing still works.
    x = np.loadtxt(r_path, delimiter=',', ndmin=2)
    y = np.loadtxt(l_path, delimiter=',', ndmin=2)
    train_x = []
    train_y = []
    for _ in range(data_size):
        # x then y for each copy: the same RNG consumption order as before.
        train_x.append(_noisy_subsample(x, data_length, noise))
        train_y.append(_noisy_subsample(y, data_length, noise))
    return train_x, train_y

def _edges_and_weights(graph):
    """Return ``(edges, weights)`` lists for every edge of *graph* with a 'weight'.

    Same content as ``zip(*nx.get_edge_attributes(graph, 'weight').items())``,
    but returns two empty lists instead of raising ValueError when no edge
    survived the weight threshold.
    """
    items = list(nx.get_edge_attributes(graph, 'weight').items())
    edges = [edge for edge, _ in items]
    weights = [weight for _, weight in items]
    return edges, weights


def main():
    """Plot the input-weight structure of a trained ``MyLSTM_RNN`` as graphs.

    Loads the checkpoint at ``model_path`` and extracts, each divided by its
    own maximum:
      * ``PFC_w``: rows ``target[0]:target[1]``, columns ``0:hidden_size`` of ``PFC.weight_ih``;
      * ``HPC_w``: rows ``target[0]:target[1]``, columns ``inputsize:`` of ``HPC.weight_ih``;
      * ``Re_w``: rows and columns ``0:hidden_size`` of ``Re.weight_ih``.
    It then shows:
      1. a spring-layout directed graph with one edge per non-zero ``PFC_w`` entry;
      2. heat maps of ``PFC_w`` and ``HPC_w`` titled with their max/min;
      3. a layered HPC | PFC | Re graph of the positive ``PFC_w`` / ``Re_w``
         entries, nodes coloured by weighted eigenvector centrality;
      4. a layered PFC | HPC | third-column graph of entries above 0.5;
      5. a PFC-HPC bipartite graph in which entries below 0.8 get weight 0.
    Also prints, per column, how many entries of each matrix exceed 0.3.
    """
    hidden_size = 20
    batch_size = 10
    inputsize = 2
    outputsize = 2
    model_path = 'model/ReModel_interRNN_long131test_s1_100_3.pth'

    rnn = MyLSTM_RNN(inputsize, hidden_size, outputsize, batch_size)
    rnn.load_state_dict(torch.load(model_path))

    # Row window of each weight_ih matrix to inspect.  NOTE(review): if PFC/HPC
    # are nn.LSTMCell (gate rows ordered i, f, g, o), 40:60 is the g block -
    # confirm against the model definition.
    target = [40, 60]

    for n, p in rnn.named_parameters():
        # Each slice is rescaled so its largest entry is 1 (assumes that
        # maximum is positive; otherwise the signs flip).
        if n == "PFC.weight_ih":
            PFC_w = np.array(p.data)[target[0]:target[1], 0:hidden_size]
            PFC_w /= np.max(PFC_w)
        if n == "HPC.weight_ih":
            HPC_w = np.array(p.data)[target[0]:target[1], inputsize:]
            HPC_w /= np.max(HPC_w)
        if n == "Re.weight_ih":
            Re_w = np.array(p.data)[0:hidden_size, 0:hidden_size]
            Re_w /= np.max(Re_w)

    # 1) Directed graph of PFC_w.  from_numpy_matrix was removed in
    #    networkx 3.0; from_numpy_array builds the same graph from an ndarray.
    plt.figure()
    G = nx.from_numpy_array(PFC_w, create_using=nx.DiGraph)
    layout = nx.spring_layout(G)
    nx.draw(G, layout)
    plt.show()

    # 2) Heat maps.
    fig2 = plt.figure()
    ax1 = fig2.add_subplot(121)
    ax2 = fig2.add_subplot(122)
    ax1.imshow(PFC_w, cmap="coolwarm")
    ax2.imshow(HPC_w, cmap="coolwarm")
    ax1.set_title("max = {:.2f},min = {:.2f}".format(np.max(PFC_w), np.min(PFC_w)))
    ax2.set_title("max = {:.2f},min = {:.2f}".format(np.max(HPC_w), np.min(HPC_w)))

    PFCnode = [("PFC" + str(i)) for i in range(hidden_size)]
    HPCnode = [("HPC" + str(i)) for i in range(hidden_size)]
    Renode = [("Re" + str(i)) for i in range(hidden_size)]

    # 3) Layered graph: HPC at x=1, PFC at x=2, Re at x=3; positive weights only.
    plt.figure()
    A = nx.Graph()
    A.add_nodes_from(PFCnode, bipartite=0)
    A.add_nodes_from(HPCnode, bipartite=1)
    A.add_nodes_from(Renode, bipartite=2)
    for i in range(hidden_size):
        for k in range(hidden_size):
            if PFC_w[i, k] > 0:
                A.add_edge(PFCnode[i], HPCnode[k], weight=PFC_w[i, k])
            if Re_w[i, k] > 0:
                A.add_edge(Renode[i], PFCnode[k], weight=Re_w[i, k])
    pos = {}
    pos.update((node, (2, index)) for index, node in enumerate(PFCnode))
    pos.update((node, (1, index)) for index, node in enumerate(HPCnode))
    pos.update((node, (3, index)) for index, node in enumerate(Renode))
    edges, weights = _edges_and_weights(A)
    # `weight` must name the edge attribute.  weight=True looked up a key
    # `True` that no edge has, so the centrality was silently unweighted.
    measures = nx.eigenvector_centrality(A, weight='weight')
    nx.draw(A, pos=pos, nodelist=list(measures.keys()), node_color=list(measures.values()), edgelist=edges, edge_color=weights, width=1.0, edge_cmap=plt.cm.Reds, with_labels=True)
    plt.show()

    print(np.count_nonzero(PFC_w > 0.3, axis=0), np.count_nonzero(HPC_w > 0.3, axis=0), np.count_nonzero(Re_w > 0.3, axis=0))

    # 4) Layered graph: PFC at x=1, HPC at x=2, a second copy fed by PFC_w at
    #    x=3; weights > 0.5 only.  NOTE(review): the third column re-uses
    #    PFC_w but its nodes are labelled "Re<i>" (same names as Renode);
    #    confirm the intended label.
    PFC2node = [("Re" + str(i)) for i in range(hidden_size)]
    plt.figure()
    A = nx.Graph()
    A.add_nodes_from(PFCnode, bipartite=0)
    A.add_nodes_from(HPCnode, bipartite=1)
    A.add_nodes_from(PFC2node, bipartite=2)
    for i in range(hidden_size):
        for k in range(hidden_size):
            if HPC_w[i, k] > 0.5:
                A.add_edge(HPCnode[i], PFCnode[k], weight=HPC_w[i, k])
            if PFC_w[i, k] > 0.5:
                A.add_edge(PFC2node[i], HPCnode[k], weight=PFC_w[i, k])
    pos = {}
    pos.update((node, (1, index)) for index, node in enumerate(PFCnode))
    pos.update((node, (2, index)) for index, node in enumerate(HPCnode))
    pos.update((node, (3, index)) for index, node in enumerate(PFC2node))
    edges, weights = _edges_and_weights(A)
    measures = nx.eigenvector_centrality(A, weight='weight')
    nx.draw(A, pos=pos, nodelist=list(measures.keys()), node_color=list(measures.values()), edgelist=edges, edge_color=weights, width=1.0, edge_cmap=plt.cm.Reds, with_labels=True)
    plt.show()

    # 5) Complete PFC-HPC bipartite graph; weights below 0.8 are drawn as 0.
    plt.figure()
    B = nx.Graph()
    B.add_nodes_from(PFCnode, bipartite=0)
    B.add_nodes_from(HPCnode, bipartite=1)
    for i in range(hidden_size):
        for k in range(hidden_size):
            if PFC_w[i, k] < 0.8:
                B.add_edge(PFCnode[i], HPCnode[k], weight=0)
            else:
                B.add_edge(PFCnode[i], HPCnode[k], weight=PFC_w[i, k])
    bottom_nodes, top_nodes = bipartite.sets(B)
    pos = {}
    pos.update((node, (2, index)) for index, node in enumerate(bottom_nodes))
    pos.update((node, (1, index)) for index, node in enumerate(top_nodes))
    print(pos)
    edges, weights = _edges_and_weights(B)
    nx.draw(B, pos=pos, node_color='b', edgelist=edges, edge_color=weights, width=1.0, edge_cmap=plt.cm.Reds)
    plt.show()
    
# Execute this cell's analysis when run as the top-level module.
if __name__ == '__main__':
    main()

In [None]:
#v4 without Re
#v4_1 revised forward function

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import time
import math
import matplotlib.pyplot as plt
import matplotlib.animation as animation
import sklearn 
from sklearn.decomposition import PCA
from sklearn.cluster import KMeans
from mpl_toolkits.mplot3d import Axes3D
import networkx as nx
from networkx.algorithms import bipartite

%matplotlib notebook

def _noisy_subsample(points, data_length, noise):
    """Return every ``stride``-th row of *points* (first two columns) plus Gaussian noise.

    ``stride = round(n_rows / data_length)``, clamped to at least 1 so a file
    with fewer than ``data_length / 2`` rows no longer makes the stride 0.
    Noise is drawn row by row, column 0 then column 1, which matches the
    order of the original element-by-element ``np.random.normal`` calls.
    """
    stride = max(1, round(points.shape[0] / data_length))
    picked = points[::stride, :2]
    jitter = np.random.normal(loc=0.0, scale=noise, size=picked.shape)
    return (picked + jitter).tolist()


def mkOwnDataSet(data_size, data_length=100, freq=60., noise=0.005,
                 r_path="primal0_r.csv", l_path="primal0_l.csv"):
    """Build ``data_size`` noisy, down-sampled copies of two recorded 2-D trajectories.

    Args:
        data_size: number of noisy copies to generate.
        data_length: target number of samples per copy.  It sets the stride
            ``round(n_rows / data_length)`` (at least 1), so the actual count is
            ``ceil(n_rows / stride)``.
        freq: unused; kept so existing calls keep working.
        noise: standard deviation of the Gaussian noise added to every coordinate.
        r_path: comma-separated file read for ``train_x`` (first two columns used).
        l_path: comma-separated file read for ``train_y`` (first two columns used).

    Returns:
        ``(train_x, train_y)``: nested lists shaped ``(data_size, n_samples, 2)``.
    """
    # ndmin=2 keeps a single-row file 2-D so row/column indexing still works.
    x = np.loadtxt(r_path, delimiter=',', ndmin=2)
    y = np.loadtxt(l_path, delimiter=',', ndmin=2)
    train_x = []
    train_y = []
    for _ in range(data_size):
        # x then y for each copy: the same RNG consumption order as before.
        train_x.append(_noisy_subsample(x, data_length, noise))
        train_y.append(_noisy_subsample(y, data_length, noise))
    return train_x, train_y

def main():
    """Roll out a trained MyLSTM and draw an unrolled HPC/PFC weight graph.

    1. Loads the checkpoint ``model_path`` into ``MyLSTM``.
    2. Feeds ``data[k]`` for ``data_limit`` steps, then runs closed-loop
       (each output is the next input) for ``3 * data.shape[0]`` steps,
       recording outputs plus ``hidden[0][0]`` / ``hidden[1][0]``.
    3. Plots index 0 of axis 1 (presumably the first batch element) of the
       input data (dashed) and of the generated trajectory.
    4. Draws a 10-column graph alternating HPC and PFC layers, built from the
       normalised ``PFC.weight_ih`` / ``HPC.weight_ih`` slices and keeping
       only weights above 0.5, nodes coloured by degree centrality.
    """
    training_size = 100
    test_size = 1000
    epochs_num = 10
    hidden_size = 20
    batch_size = 10
    data_length = 100
    inputsize = 2
    outputsize = 2
    model_path = 'model/v4_2Model_0_3.pth'

    train_x,train_y = mkOwnDataSet(training_size,data_length)

    rnn = MyLSTM(inputsize, hidden_size, outputsize, batch_size)

    rnn.load_state_dict(torch.load(model_path))
    
    data = mkOwnRandomBatch(train_x, batch_size)
    
    # Raw parameter tensors in registration order; only consumed by Culc_gate.
    params = []
    for p in rnn.parameters():
        params.append(p.data)
     
    traj = []
    PFCstate = []
    HPCstate = []
    Gate_states = []  # appended to below but not read again in this function
    hidden = rnn.initHidden()
    data_limit = 100
    # Phase 1: drive the network with the recorded data.
    for k in range(data_limit):
            #print(data[k].shape)
            output,hidden = rnn(data[k],hidden)
            traj.append(output.tolist())
            # NOTE(review): assumes hidden[0] / hidden[1] are the PFC / HPC states - confirm.
            PFCstate.append(hidden[0][0].tolist())
            HPCstate.append(hidden[1][0].tolist())
            # NOTE(review): passes the model output, not the input data[k] fed this
            # step - confirm that is what Culc_gate should see here.
            Gate_states.append(Culc_gate(output,params,hidden))
    # Phase 2: closed loop - each output becomes the next input.
    for k in range(data.shape[0]*3):
            output,hidden = rnn(output,hidden) 
            traj.append(output.tolist())
            PFCstate.append(hidden[0][0].tolist())
            HPCstate.append(hidden[1][0].tolist())
            #print(output)
            Gate_states.append(Culc_gate(output,params,hidden))
    traj = torch.tensor(traj)
    pltdata = torch.squeeze(data).numpy()
    traj = torch.squeeze(traj).numpy()
    fig = plt.figure()
    print(pltdata.shape)
    # Input (dashed) vs generated trajectory, coordinates 0 and 1.
    plt.plot(pltdata[:,0,0],pltdata[:,0,1],"--")
    plt.plot(traj[:,0,0],traj[:,0,1])
    plt.show()
    
    
    print(np.array(PFCstate)[:,0].shape)
    #MakeAnimation(pltdata[:,0,0],pltdata[:,0,1], traj[:,0,0], traj[:,0,1], data_limit)
    #MakeAnimation_img(np.array(PFCstate)[:,0],"PFC")
    #MakeAnimation_img(np.array(HPCstate)[:,0],"HPC")

    # Slice the first 20 rows of each input-weight matrix and scale so its
    # largest entry is 1 (assumes that maximum is positive; otherwise signs flip).
    for n, p in rnn.named_parameters():
            if n == "PFC.weight_ih":
                PFC_w = np.array(p.data)[0:20,:]
                PFC_w /= np.max(PFC_w)
            if n == "HPC.weight_ih":
                # Columns before `inputsize` are dropped (presumably the external input).
                HPC_w = np.array(p.data)[0:20,inputsize:]
                HPC_w /= np.max(HPC_w)
    
    # Unrolled graph: x=1 holds HPC0..19; for t = 1..9, column x=t+1 holds a
    # PFC layer (odd t) or an HPC layer (even t).  Node ids are offset by
    # (t//2)*20 so every column has distinct names.  Each new column is joined
    # to the previous one by the weights above 0.5.
    plt.figure()
    A = nx.Graph()
    HPCnode = [("HPC"+str(i)) for i in range(20)]
    A.add_nodes_from(HPCnode, bipartite=0)
    pos = {}
    pos.update((node, (1, index)) for index, node in enumerate(HPCnode))
    for t in range(1,10):
        if t % 2 == 1:
            PFCnode = [("PFC"+str((t//2)*20+i)) for i in range(20)]
            A.add_nodes_from(PFCnode, bipartite=t+1)
            pos.update((node, (t+1, index)) for index, node in enumerate(PFCnode))
            PFC = True
            HPC = False
        else:
            HPCnode = [("HPC"+str((t//2)*20+i)) for i in range(20)]
            A.add_nodes_from(HPCnode, bipartite=t+1)
            pos.update((node, (t+1, index)) for index, node in enumerate(HPCnode))
            HPC = True
            PFC = False
        for i in range(20):
            for k in range(20):
                if PFC and PFC_w[i,k]>0.5:
                    A.add_edge(PFCnode[i],HPCnode[k],weight=PFC_w[i,k])
                if HPC and HPC_w[i,k]>0.5:
                    A.add_edge(HPCnode[i],PFCnode[k],weight=HPC_w[i,k])
    # NOTE(review): raises ValueError if no weight exceeded 0.5 (empty zip).
    edges,weights = zip(*nx.get_edge_attributes(A,'weight').items())
    measures = nx.degree_centrality(A)
    nx.draw(A, pos=pos, nodelist=list(measures.keys()), node_color=list(measures.values()),  edgelist=edges, edge_color=weights, edge_cmap=plt.cm.Reds, with_labels=True)
    plt.show()
    
# Execute this cell's analysis when run as the top-level module.
if __name__ == '__main__':
    main()

In [None]:
#v4 without Re
#v4_1 revised forward function

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import time
import math
import matplotlib.pyplot as plt
import matplotlib.animation as animation
import sklearn 
from sklearn.decomposition import PCA
from sklearn.cluster import KMeans
from mpl_toolkits.mplot3d import Axes3D
import networkx as nx
from networkx.algorithms import bipartite

%matplotlib notebook

def _noisy_subsample(points, data_length, noise):
    """Return every ``stride``-th row of *points* (first two columns) plus Gaussian noise.

    ``stride = round(n_rows / data_length)``, clamped to at least 1 so a file
    with fewer than ``data_length / 2`` rows no longer makes the stride 0.
    Noise is drawn row by row, column 0 then column 1, which matches the
    order of the original element-by-element ``np.random.normal`` calls.
    """
    stride = max(1, round(points.shape[0] / data_length))
    picked = points[::stride, :2]
    jitter = np.random.normal(loc=0.0, scale=noise, size=picked.shape)
    return (picked + jitter).tolist()


def mkOwnDataSet(data_size, data_length=100, freq=60., noise=0.005,
                 r_path="primal0_r.csv", l_path="primal0_l.csv"):
    """Build ``data_size`` noisy, down-sampled copies of two recorded 2-D trajectories.

    Args:
        data_size: number of noisy copies to generate.
        data_length: target number of samples per copy.  It sets the stride
            ``round(n_rows / data_length)`` (at least 1), so the actual count is
            ``ceil(n_rows / stride)``.
        freq: unused; kept so existing calls keep working.
        noise: standard deviation of the Gaussian noise added to every coordinate.
        r_path: comma-separated file read for ``train_x`` (first two columns used).
        l_path: comma-separated file read for ``train_y`` (first two columns used).

    Returns:
        ``(train_x, train_y)``: nested lists shaped ``(data_size, n_samples, 2)``.
    """
    # ndmin=2 keeps a single-row file 2-D so row/column indexing still works.
    x = np.loadtxt(r_path, delimiter=',', ndmin=2)
    y = np.loadtxt(l_path, delimiter=',', ndmin=2)
    train_x = []
    train_y = []
    for _ in range(data_size):
        # x then y for each copy: the same RNG consumption order as before.
        train_x.append(_noisy_subsample(x, data_length, noise))
        train_y.append(_noisy_subsample(y, data_length, noise))
    return train_x, train_y

def main():
    """Draw unrolled HPC/PFC weight graphs for a trained MyLSTM_RNN.

    Loads ``model_path`` and slices and normalises the first 20 rows of
    ``PFC.weight_ih`` and ``HPC.weight_ih``.  It then shows two figures:

    1. Every edge with weight > 0.5 between alternating HPC/PFC columns
       (t = 1..9).  Only nodes with at least one edge enter the graph.
    2. The same unrolling, restricted to edges reachable column by column
       from ``HPC<initnode>`` in the first column.

    Nodes are coloured by degree centrality.
    """
    training_size = 100
    test_size = 1000
    epochs_num = 10
    hidden_size = 20
    batch_size = 10
    data_length = 100
    inputsize = 2
    outputsize = 2
    model_path = 'model/ReModel_interRNN_long131test_s1_100_3.pth'

    train_x,train_y = mkOwnDataSet(training_size,data_length)

    rnn = MyLSTM_RNN(inputsize, hidden_size, outputsize, batch_size)

    rnn.load_state_dict(torch.load(model_path))
    
    # NOTE(review): `data` and `params` are not used further in this function.
    data = mkOwnRandomBatch(train_x, batch_size)
    
    params = []
    for p in rnn.parameters():
        params.append(p.data)

    # First 20 rows of each input-weight matrix, scaled so the largest entry
    # is 1 (assumes that maximum is positive; otherwise signs flip).
    for n, p in rnn.named_parameters():
            if n == "PFC.weight_ih":
                PFC_w = np.array(p.data)[0:20,:]
                PFC_w /= np.max(PFC_w)
            if n == "HPC.weight_ih":
                # Columns before `inputsize` are dropped (presumably the external input).
                HPC_w = np.array(p.data)[0:20,inputsize:]
                HPC_w /= np.max(HPC_w)
    
    # Figure 1: full unrolled graph.  Column t+1 is a PFC layer for odd t and
    # an HPC layer for even t.  Ids are offset by (t//2)*20 to stay distinct.
    # Positions are set only for nodes touched by an edge (plus the initial
    # HPC column).
    plt.figure()
    A = nx.Graph()
    HPCnode = [("HPC"+str(i)) for i in range(20)]
    pos = {}
    pos.update((node, (1, index)) for index, node in enumerate(HPCnode))
    for t in range(1,10):
        if t % 2 == 1:
            PFCnode = [("PFC"+str((t//2)*20+i)) for i in range(20)]
            PFC = True
            HPC = False
        else:
            HPCnode = [("HPC"+str((t//2)*20+i)) for i in range(20)]
            HPC = True
            PFC = False
        for i in range(20):
            for k in range(20):
                if PFC and PFC_w[i,k]>0.5:
                    A.add_edge(PFCnode[i],HPCnode[k],weight=PFC_w[i,k])
                    pos.update({PFCnode[i]: (t+1, i)})
                    pos.update({HPCnode[k]: (t, k)})
                if HPC and HPC_w[i,k]>0.5:
                    A.add_edge(HPCnode[i],PFCnode[k],weight=HPC_w[i,k])
                    pos.update({HPCnode[i]: (t+1, i)})
                    pos.update({PFCnode[k]: (t, k)})
    
    # NOTE(review): raises ValueError if no weight exceeded 0.5 (empty zip).
    edges,weights = zip(*nx.get_edge_attributes(A,'weight').items())
    measures = nx.degree_centrality(A)
    nx.draw(A, pos=pos, nodelist=list(measures.keys()), node_color=list(measures.values()),  edgelist=edges, edge_color=weights, edge_cmap=plt.cm.Reds, with_labels=True)
    plt.show()
    
    
    # Figure 2: same unrolling, but an edge is added only if its node in the
    # previous column is already in `nodelinks`.  This traces the forward
    # fan-out from HPC<initnode>.
    plt.figure()
    A = nx.Graph()
    HPCnode = [("HPC"+str(i)) for i in range(20)]
    pos = {}
    pos.update((node, (1, index)) for index, node in enumerate(HPCnode))
    initnode = 12
    # List membership test; entries may repeat (one append per qualifying edge).
    nodelinks = [HPCnode[initnode]]
    for t in range(1,10):
        if t % 2 == 1:
            PFCnode = [("PFC"+str((t//2)*20+i)) for i in range(20)]
            PFC = True
            HPC = False
        else:
            HPCnode = [("HPC"+str((t//2)*20+i)) for i in range(20)]
            HPC = True
            PFC = False
        for i in range(20):
            for k in range(20):
                if PFC and PFC_w[i,k]>0.5 and HPCnode[k] in nodelinks:
                    A.add_edge(PFCnode[i],HPCnode[k],weight=PFC_w[i,k])
                    pos.update({PFCnode[i]: (t+1, i)})
                    pos.update({HPCnode[k]: (t, k)})
                    nodelinks.append(PFCnode[i])
                if HPC and HPC_w[i,k]>0.5 and PFCnode[k] in nodelinks:
                    A.add_edge(HPCnode[i],PFCnode[k],weight=HPC_w[i,k])
                    pos.update({HPCnode[i]: (t+1, i)})
                    pos.update({PFCnode[k]: (t, k)})
                    nodelinks.append(HPCnode[i])

    edges,weights = zip(*nx.get_edge_attributes(A,'weight').items())
    measures = nx.degree_centrality(A)
    nx.draw(A, pos=pos, nodelist=list(measures.keys()), node_color=list(measures.values()),  edgelist=edges, edge_color=weights, edge_cmap=plt.cm.Reds, with_labels=True)
    plt.show()
    
# Execute this cell's analysis when run as the top-level module.
if __name__ == '__main__':
    main()

In [None]:
#v4 without Re
#v4_1 revised forward function

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import time
import math
import matplotlib.pyplot as plt
import matplotlib.animation as animation
import sklearn 
from sklearn.decomposition import PCA
from sklearn.cluster import KMeans
from mpl_toolkits.mplot3d import Axes3D
import networkx as nx
from networkx.algorithms import bipartite

%matplotlib notebook

def _noisy_subsample(points, data_length, noise):
    """Return every ``stride``-th row of *points* (first two columns) plus Gaussian noise.

    ``stride = round(n_rows / data_length)``, clamped to at least 1 so a file
    with fewer than ``data_length / 2`` rows no longer makes the stride 0.
    Noise is drawn row by row, column 0 then column 1, which matches the
    order of the original element-by-element ``np.random.normal`` calls.
    """
    stride = max(1, round(points.shape[0] / data_length))
    picked = points[::stride, :2]
    jitter = np.random.normal(loc=0.0, scale=noise, size=picked.shape)
    return (picked + jitter).tolist()


def mkOwnDataSet(data_size, data_length=100, freq=60., noise=0.005,
                 r_path="primal0_r.csv", l_path="primal0_l.csv"):
    """Build ``data_size`` noisy, down-sampled copies of two recorded 2-D trajectories.

    Args:
        data_size: number of noisy copies to generate.
        data_length: target number of samples per copy.  It sets the stride
            ``round(n_rows / data_length)`` (at least 1), so the actual count is
            ``ceil(n_rows / stride)``.
        freq: unused; kept so existing calls keep working.
        noise: standard deviation of the Gaussian noise added to every coordinate.
        r_path: comma-separated file read for ``train_x`` (first two columns used).
        l_path: comma-separated file read for ``train_y`` (first two columns used).

    Returns:
        ``(train_x, train_y)``: nested lists shaped ``(data_size, n_samples, 2)``.
    """
    # ndmin=2 keeps a single-row file 2-D so row/column indexing still works.
    x = np.loadtxt(r_path, delimiter=',', ndmin=2)
    y = np.loadtxt(l_path, delimiter=',', ndmin=2)
    train_x = []
    train_y = []
    for _ in range(data_size):
        # x then y for each copy: the same RNG consumption order as before.
        train_x.append(_noisy_subsample(x, data_length, noise))
        train_y.append(_noisy_subsample(y, data_length, noise))
    return train_x, train_y

def main():
    """Roll out a trained MyLSTM and draw the directed weight graph reachable from HPC12.

    1. Loads ``model_path`` into ``MyLSTM``.  Feeds ``data[k]`` for
       ``data_limit`` steps, then runs closed-loop for ``3 * data.shape[0]``
       steps, and plots the input (dashed) against the generated trajectory.
    2. Builds a directed, 10-column unrolled graph alternating PFC (odd t) and
       HPC (even t) layers from weights above 0.5 of the normalised
       ``PFC.weight_ih`` / ``HPC.weight_ih`` slices.
    3. Removes every node not reachable from ``'HPC12'`` along edge direction
       and draws the rest, nodes coloured by degree centrality.
    """
    training_size = 100
    test_size = 1000
    epochs_num = 10
    hidden_size = 20
    batch_size = 10
    data_length = 100
    inputsize = 2
    outputsize = 2
    model_path = 'model/v4_2Model_0_3.pth'

    train_x,train_y = mkOwnDataSet(training_size,data_length)

    rnn = MyLSTM(inputsize, hidden_size, outputsize, batch_size)

    rnn.load_state_dict(torch.load(model_path))
    
    data = mkOwnRandomBatch(train_x, batch_size)
    
    # Raw parameter tensors in registration order; only consumed by Culc_gate.
    params = []
    for p in rnn.parameters():
        params.append(p.data)
     
    traj = []
    PFCstate = []
    HPCstate = []
    Gate_states = []  # appended to below but not read again in this function
    hidden = rnn.initHidden()
    data_limit = 100
    # Phase 1: drive the network with the recorded data.
    for k in range(data_limit):
            #print(data[k].shape)
            output,hidden = rnn(data[k],hidden)
            traj.append(output.tolist())
            # NOTE(review): assumes hidden[0] / hidden[1] are the PFC / HPC states - confirm.
            PFCstate.append(hidden[0][0].tolist())
            HPCstate.append(hidden[1][0].tolist())
            # NOTE(review): passes the model output, not the input data[k] fed this
            # step - confirm that is what Culc_gate should see here.
            Gate_states.append(Culc_gate(output,params,hidden))
    # Phase 2: closed loop - each output becomes the next input.
    for k in range(data.shape[0]*3):
            output,hidden = rnn(output,hidden) 
            traj.append(output.tolist())
            PFCstate.append(hidden[0][0].tolist())
            HPCstate.append(hidden[1][0].tolist())
            #print(output)
            Gate_states.append(Culc_gate(output,params,hidden))
    traj = torch.tensor(traj)
    pltdata = torch.squeeze(data).numpy()
    traj = torch.squeeze(traj).numpy()
    fig = plt.figure()
    print(pltdata.shape)
    # Input (dashed) vs generated trajectory, coordinates 0 and 1.
    plt.plot(pltdata[:,0,0],pltdata[:,0,1],"--")
    plt.plot(traj[:,0,0],traj[:,0,1])
    plt.show()
    
    
    print(np.array(PFCstate)[:,0].shape)
    #MakeAnimation(pltdata[:,0,0],pltdata[:,0,1], traj[:,0,0], traj[:,0,1], data_limit)
    #MakeAnimation_img(np.array(PFCstate)[:,0],"PFC")
    #MakeAnimation_img(np.array(HPCstate)[:,0],"HPC")

    # First 20 rows of each input-weight matrix, scaled so the largest entry
    # is 1 (assumes that maximum is positive; otherwise signs flip).
    for n, p in rnn.named_parameters():
            if n == "PFC.weight_ih":
                PFC_w = np.array(p.data)[0:20,:]
                PFC_w /= np.max(PFC_w)
            if n == "HPC.weight_ih":
                # Columns before `inputsize` are dropped (presumably the external input).
                HPC_w = np.array(p.data)[0:20,inputsize:]
                HPC_w /= np.max(HPC_w)
    
    # Directed unrolled graph.  x=1 holds HPC0..19; column t+1 is a PFC layer
    # (odd t) or HPC layer (even t), with ids offset by (t//2)*20.  Edges point
    # from the previous column into the new one: HPC[k] -> PFC[i] uses
    # PFC_w[i,k], and PFC[k] -> HPC[i] uses HPC_w[i,k].
    plt.figure()
    A = nx.DiGraph()
    HPCnode = [("HPC"+str(i)) for i in range(20)]
    A.add_nodes_from(HPCnode, bipartite=0)
    pos = {}
    pos.update((node, (1, index)) for index, node in enumerate(HPCnode))
    for t in range(1,10):
        if t % 2 == 1:
            PFCnode = [("PFC"+str((t//2)*20+i)) for i in range(20)]
            A.add_nodes_from(PFCnode, bipartite=t+1)
            pos.update((node, (t+1, index)) for index, node in enumerate(PFCnode))
            PFC = True
            HPC = False
        else:
            HPCnode = [("HPC"+str((t//2)*20+i)) for i in range(20)]
            A.add_nodes_from(HPCnode, bipartite=t+1)
            pos.update((node, (t+1, index)) for index, node in enumerate(HPCnode))
            HPC = True
            PFC = False
        for i in range(20):
            for k in range(20):
                if PFC and PFC_w[i,k]>0.5:
                    A.add_edge(HPCnode[k],PFCnode[i],weight=PFC_w[i,k])
                if HPC and HPC_w[i,k]>0.5:
                    A.add_edge(PFCnode[k],HPCnode[i],weight=HPC_w[i,k])
    sourcenode = 'HPC12'
    # Only used by the commented-out backward-reachability variant below.
    targetnode = PFCnode[14]
#     links = list(nx.shortest_path(A,target=targetnode).keys())
    # With only `source` given, shortest_path returns a dict keyed by every
    # node reachable from it; all other nodes are removed.
    links = list(nx.shortest_path(A,source=sourcenode).keys())
    for node in list(A.nodes(data=False)):
        if node not in links:
            A.remove_node(node)
            pos.pop(node)
    
    
    # NOTE(review): raises ValueError if no edges survive the pruning (empty zip).
    edges,weights = zip(*nx.get_edge_attributes(A,'weight').items())
    measures = nx.degree_centrality(A)
    nx.draw(A, pos=pos, nodelist=list(measures.keys()), node_color=list(measures.values()),  edgelist=edges, edge_color=weights, edge_cmap=plt.cm.Reds, with_labels=True)
    plt.show()
    
# Execute this cell's analysis when run as the top-level module.
if __name__ == '__main__':
    main()

In [None]:
import glob

# Weight-distribution overlay: for each checkpoint in target_list, draw
# density histograms of the non-zero weights of the PFC, HPC and Re cells.
# weight_ih and weight_hh share one subplot per region (PFC left, HPC middle,
# Re right), with 40 bins over (-4, 4).  Values outside that range are not
# counted.
# glob on a literal path returns [] when the file is missing, in which case
# the loop below does nothing and the figure stays empty.
# model_list = glob.glob('model/ReModel_interRNN_long131test_s5_100_4.pth')
model_list = glob.glob('model/R20_H_bigbatch/ReModel_L2_interRNNrand_OUT1_121_s1_100_1_epoch150.pth')
print(model_list)

# Model constructor arguments; training_size, test_size, epochs_num and
# data_length are not used in this cell.
training_size = 100
test_size = 1000
epochs_num = 10
hidden_size = 20
batch_size = 10
data_length = 100
inputsize = 2
outputsize = 2

fig_weight = plt.figure(figsize=(10,5))
fig1 = fig_weight.add_subplot(131)
fig2 = fig_weight.add_subplot(132)
fig3 = fig_weight.add_subplot(133)

# target = [2,4]
# target_list = []

# for i in target:
#     target_list.append(model_list[i])

target_list = model_list

for model_path in target_list:
    rnn = MyLSTM_RNN(inputsize, hidden_size, outputsize, batch_size)

    rnn.load_state_dict(torch.load(model_path))

    # *_w = weight_ih, *_inw = weight_hh; *_b / *_inb = bias_ih / bias_hh.
    # The biases are only used by the commented-out histograms below.
    for n, p in rnn.named_parameters():
            if n == "PFC.weight_ih":
                PFC_w = np.array(p.data)
            if n == "PFC.weight_hh":
                PFC_inw = np.array(p.data)
            if n == "HPC.weight_ih":
                HPC_w = np.array(p.data)
            if n == "HPC.weight_hh":
                HPC_inw = np.array(p.data)
            if n == "Re.weight_ih":
                Re_w = np.array(p.data)
            if n == "Re.weight_hh":
                Re_inw = np.array(p.data)
            if n == "PFC.bias_ih":
                PFC_b = np.array(p.data)
            if n == "PFC.bias_hh":
                PFC_inb = np.array(p.data)
            if n == "HPC.bias_ih":
                HPC_b = np.array(p.data)
            if n == "HPC.bias_hh":
                HPC_inb = np.array(p.data)
            if n == "Re.bias_ih":
                Re_b = np.array(p.data)
            if n == "Re.bias_hh":
                Re_inb = np.array(p.data)

    # Exact zeros are excluded via .nonzero() before histogramming.
    fig1.hist(PFC_w[PFC_w.nonzero()],bins=40,range=(-4,4),alpha=0.5,density=True)
    fig1.hist(PFC_inw[PFC_inw.nonzero()],bins=40,range=(-4,4),alpha=0.5,density=True)
    fig2.hist(HPC_w[HPC_w.nonzero()],bins=40,range=(-4,4),alpha=0.5,density=True)  
    fig2.hist(HPC_inw[HPC_inw.nonzero()],bins=40,range=(-4,4),alpha=0.5,density=True)  
    fig3.hist(Re_w[Re_w.nonzero()],bins=40,range=(-4,4),alpha=0.5,density=True) 
    fig3.hist(Re_inw[Re_inw.nonzero()],bins=40,range=(-4,4),alpha=0.5,density=True) 

    
#     fig1.hist(PFC_b[PFC_b.nonzero()],bins=20,alpha=0.5,density=False)
#     fig1.hist(PFC_inb[PFC_inb.nonzero()],bins=10,alpha=0.5,density=False)
#     fig2.hist(HPC_b[HPC_b.nonzero()],bins=20,alpha=0.5,density=False)  
#     fig2.hist(HPC_inb[HPC_inb.nonzero()],bins=10,alpha=0.5,density=False)  
#     fig3.hist(Re_b[Re_b.nonzero()],bins=20,alpha=0.5,density=False) 
#     fig3.hist(Re_inb[Re_inb.nonzero()],bins=10,alpha=0.5,density=False) 

In [None]:
import glob
from sklearn import linear_model

def make_bins(data):
    """Return the midpoints between consecutive entries of *data*.

    In this notebook it is applied to the ``bin_edges`` returned by
    ``np.histogram`` to get one centre per bin.

    Args:
        data: 1-D sequence of numbers (list or ndarray) of length ``n``.

    Returns:
        float64 ndarray of length ``max(n - 1, 0)``.
    """
    # Vectorised replacement for a per-element np.append loop, which
    # re-allocated the output every iteration (quadratic).  The sum is formed
    # in the input's own dtype and only then widened to float64, matching the
    # element-wise arithmetic exactly.
    edges = np.asarray(data)
    return ((edges[:-1] + edges[1:]) / 2).astype(np.float64)

def regression(datax,datay):
    """Fit ``datay = a * datax + b`` by ordinary least squares and return the slope ``a``.

    Pairs where either value is non-finite are dropped before fitting.  This
    covers ``-inf`` from ``np.log`` of an empty histogram bin, and also
    ``nan`` / ``+inf``, which LinearRegression would otherwise reject.
    Prints the coefficient array and the R^2 score of the fit.

    Args:
        datax: 1-D array of x values (e.g. histogram bin centres).
        datay: 1-D array of y values, same length as *datax*.

    Returns:
        The fitted slope as a scalar.

    Raises:
        ValueError: from scikit-learn if no finite pair remains.
    """
    datax = np.asarray(datax)
    datay = np.asarray(datay)
    keep = np.isfinite(datax) & np.isfinite(datay)
    x = datax[keep].reshape(-1,1)
    y = datay[keep].reshape(-1,1)
    clf = linear_model.LinearRegression()
    clf.fit(x,y)
    coef = clf.coef_
    score = clf.score(x,y)
    print(coef,score)
    return coef[0][0]

# Slope survey: for each checkpoint, histogram weight magnitudes of the PFC,
# HPC and Re cells, plot log(density) against bin centre, and record the
# least-squares slope of that curve (via `regression`).  The six slopes per
# model are appended as one row of PCA_list.
# model_list = glob.glob('model/ReModel_interRNN_long_s5_200_1.pth')
model_list = glob.glob('model/R20_H_bigbatch/ReModel_L2_interRNNrand_OUT1_121_s1_100_1_epoch190.pth')

print(model_list)

training_size = 100
test_size = 1000
epochs_num = 10
hidden_size = 20
batch_size = 10
data_length = 100
inputsize = 2
outputsize = 2

# One subplot per region: PFC (left), HPC (middle), Re (right).
fig_weight = plt.figure(figsize=(10,5))
fig1 = fig_weight.add_subplot(131)
fig2 = fig_weight.add_subplot(132)
fig3 = fig_weight.add_subplot(133)

# target = [2,4,7,8,9]
# target_list = []

# for i in target:
#     target_list.append(model_list[i])

target_list = model_list
# Placeholder first row so np.concatenate has a 6-column array to extend; the
# (commented) PCA code at the bottom drops it with PCA_list[1:].
PCA_list = np.array([[0,1,2,3,4,5]])

for model_path in target_list:
    print(model_path)
    rnn = MyLSTM_RNN(inputsize, hidden_size, outputsize, batch_size)

    rnn.load_state_dict(torch.load(model_path))

    # *_w = weight_ih, *_inw = weight_hh; *_b / *_inb = bias_ih / bias_hh.
    for n, p in rnn.named_parameters():
            if n == "PFC.weight_ih":
                PFC_w = np.array(p.data)
            if n == "PFC.weight_hh":
                PFC_inw = np.array(p.data)
            if n == "HPC.weight_ih":
                HPC_w = np.array(p.data)
            if n == "HPC.weight_hh":
                HPC_inw = np.array(p.data)
            if n == "Re.weight_ih":
                Re_w = np.array(p.data)
            if n == "Re.weight_hh":
                Re_inw = np.array(p.data)
            if n == "PFC.bias_ih":
                PFC_b = np.array(p.data)
            if n == "PFC.bias_hh":
                PFC_inb = np.array(p.data)
            if n == "HPC.bias_ih":
                HPC_b = np.array(p.data)
            if n == "HPC.bias_hh":
                HPC_inb = np.array(p.data)
            if n == "Re.bias_ih":
                Re_b = np.array(p.data)
            if n == "Re.bias_hh":
                Re_inb = np.array(p.data)

    # ---- PFC ----
    # Positive recurrent weights: plotted only (overwritten just below).
#     PFC_w_hist = np.histogram(np.abs(PFC_w[PFC_w>0]),bins=80,density=True)
    PFC_inw_hist = np.histogram(np.abs(PFC_inw[PFC_inw>0]),bins=40,density=True)
#     fig1.plot((make_bins(PFC_w_hist[1])),np.log(PFC_w_hist[0]))
    fig1.plot((make_bins(PFC_inw_hist[1])),np.log(PFC_inw_hist[0]))
    # Magnitudes of the negative weights; these feed the two fits below.
    # FIX: PFC_w_hist was only assigned in commented-out lines, so the a1 fit
    # raised NameError in a fresh kernel (or silently reused a stale value).
    # NOTE(review): this restores the negative-tail / 80-bin line paired with
    # the active PFC_inw_hist line - confirm it is the intended histogram.
    PFC_w_hist = np.histogram(np.abs(PFC_w[PFC_w<0]),bins=80,density=True)
    PFC_inw_hist = np.histogram(np.abs(PFC_inw[PFC_inw<0]),bins=40,density=True)
#     PFC_w_hist = np.histogram(np.abs(PFC_w[PFC_w.nonzero()]),bins=80,density=True)
#     PFC_inw_hist = np.histogram(np.abs(PFC_inw[PFC_inw.nonzero()]),bins=40,density=True)
#     PFC_b_hist = np.histogram(np.abs(PFC_b[PFC_b.nonzero()]),bins=80,density=True)
#     PFC_inb_hist = np.histogram(np.abs(PFC_inb[PFC_inb.nonzero()]),bins=40,density=True)
#     fig1.plot((make_bins(PFC_w_hist[1])),np.log(PFC_w_hist[0]))
    fig1.plot((make_bins(PFC_inw_hist[1])),np.log(PFC_inw_hist[0]))
    a1 = regression(make_bins(PFC_w_hist[1]),np.log(PFC_w_hist[0]))
    a2 = regression(make_bins(PFC_inw_hist[1]),np.log(PFC_inw_hist[0]))

    # ---- HPC ----
#     HPC_w_hist = np.histogram(HPC_w[HPC_w>0],bins=80,density=True)
    HPC_inw_hist = np.histogram(HPC_inw[HPC_inw>0],bins=40,density=True)
#     fig2.plot((make_bins(HPC_w_hist[1])),np.log(HPC_w_hist[0]))
    fig2.plot((make_bins(HPC_inw_hist[1])),np.log(HPC_inw_hist[0]))
    # FIX: HPC_w_hist was only assigned in commented-out lines but is plotted
    # and fitted below (NameError in a fresh kernel).  Same pairing as PFC.
    HPC_w_hist = np.histogram(np.abs(HPC_w[HPC_w<0]),bins=80,density=True)
    HPC_inw_hist = np.histogram(np.abs(HPC_inw[HPC_inw<0]),bins=40,density=True)
#     HPC_w_hist = np.histogram(np.abs(HPC_w[HPC_w.nonzero()]),bins=80,density=True)
#     HPC_inw_hist = np.histogram(np.abs(HPC_inw[HPC_inw.nonzero()]),bins=40,density=True)
    fig2.plot((make_bins(HPC_w_hist[1])),np.log(HPC_w_hist[0]))
#     fig2.plot((make_bins(HPC_inw_hist[1])),np.log(HPC_inw_hist[0]))
    b1 = regression(make_bins(HPC_w_hist[1]),np.log(HPC_w_hist[0]))
    b2 = regression(make_bins(HPC_inw_hist[1]),np.log(HPC_inw_hist[0]))


    # ---- Re: magnitudes of all non-zero weights ----
#     Re_w_hist = np.histogram(Re_w[Re_w>0],bins=100,density=True)
#     Re_inw_hist = np.histogram(Re_inw[Re_inw>0],bins=100,density=True)
#     Re_w_hist = np.histogram(np.abs(Re_w[Re_w<0]),bins=100,density=True)
#     Re_inw_hist = np.histogram(np.abs(Re_inw[Re_inw<0]),bins=100,density=True)
    Re_w_hist = np.histogram(np.abs(Re_w[Re_w.nonzero()]),bins=20,density=True)
    Re_inw_hist = np.histogram(np.abs(Re_inw[Re_inw.nonzero()]),bins=10,density=True)
    fig3.plot((make_bins(Re_w_hist[1])),np.log(Re_w_hist[0]))
#     fig3.plot((make_bins(Re_inw_hist[1])),np.log(Re_inw_hist[0]))
    c1 = regression(make_bins(Re_w_hist[1]),np.log(Re_w_hist[0]))
    c2 = regression(make_bins(Re_inw_hist[1]),np.log(Re_inw_hist[0]))
    
    PCA_list = np.concatenate((PCA_list,[[a1,a2,b1,b2,c1,c2]]),axis=0)
    
#     fig1.hist(PFC_w[PFC_w>0],bins=40,alpha=0.5,density=True)
#     fig1.hist(np.abs(PFC_w[PFC_w<0]),bins=40,alpha=0.5,density=True)
#     fig2.hist(HPC_w[HPC_w>0],bins=40,alpha=0.5,log=True,density=True)
#     fig2.hist(np.abs(HPC_w[HPC_w<0]),bins=40,alpha=0.5,log=True,density=True)
#     fig3.hist(Re_w[Re_w>0],bins=40,alpha=0.5,log=True,density=True)
#     fig3.hist(np.abs(Re_w[Re_w<0]),bins=40,alpha=0.5,log=True,density=True)

# # print(PCA_list[1:])
# pca = PCA()
# dfs = PCA_list[1:]
# pca.fit(dfs)
# feature = pca.transform(dfs)

# plt.figure(figsize=(6, 6))
# plt.scatter(feature[:, 0], feature[:, 1], alpha=0.8)
# #plt.scatter(feature[100:200, 0], feature[100:200, 1], alpha=0.8)
# plt.grid()
# plt.xlabel("PC1")
# plt.ylabel("PC2")
# for i,n in enumerate(target_list):
#     plt.annotate(n[-12:],(feature[i, 0], feature[i, 1]))
# plt.show()

# fig3d = plt.figure()
# ax3d = Axes3D(fig3d)
# ax3d.plot(feature[:, 0], feature[:, 1], feature[:, 2], "o", alpha=0.8)
# plt.show()

In [None]:
import glob
import pandas as pd
from sklearn import linear_model

def make_bins(data):
    """Return the midpoints between consecutive entries of *data*.

    In this notebook it is applied to the ``bin_edges`` returned by
    ``np.histogram`` to get one centre per bin.

    Args:
        data: 1-D sequence of numbers (list or ndarray) of length ``n``.

    Returns:
        float64 ndarray of length ``max(n - 1, 0)``.
    """
    # Vectorised replacement for a per-element np.append loop, which
    # re-allocated the output every iteration (quadratic).  The sum is formed
    # in the input's own dtype and only then widened to float64, matching the
    # element-wise arithmetic exactly.
    edges = np.asarray(data)
    return ((edges[:-1] + edges[1:]) / 2).astype(np.float64)

def regression(datax, datay):
    """Fit ``y = a * x + b`` by ordinary least squares and return the slope ``a``.

    Intended for (bin-centre, log-density) pairs: empty histogram bins become
    ``-inf`` after ``np.log``.  Only points with a finite ``datay`` are used,
    so such bins (and any NaN/inf) are dropped before fitting.

    Parameters
    ----------
    datax, datay : numpy.ndarray
        1-D arrays of equal length.

    Returns
    -------
    float
        The fitted slope.
    """
    datax = np.asarray(datax)
    datay = np.asarray(datay)
    keep = np.isfinite(datay)
    x = datax[keep].reshape(-1, 1)
    y = datay[keep].reshape(-1, 1)
    clf = linear_model.LinearRegression()
    clf.fit(x, y)
    return clf.coef_[0][0]

# For every "long" model: fit the slope of log(histogram density) against
# weight magnitude, separately for positive and negative entries of each
# cell's input and recurrent weight matrix, then run PCA over the per-model
# slope vectors.

# glob returns files in file-system order; sort so runs are reproducible.
model_list = sorted(glob.glob('model/ReModel_interRNN_long_s*_200_?.pth'))
print(model_list)

training_size = 100
test_size = 1000
epochs_num = 10
hidden_size = 20
batch_size = 10
data_length = 100
inputsize = 2
outputsize = 2

# Models to highlight in red, identified by the file-name suffix before
# ".pth".  The list is currently disabled, but NGlist_name must still be
# defined because the loop below reads it.
# NGlist_name = np.array(["s3_100_4","s1_100_3","s9_100_2","s4_100_1"])
NGlist_name = np.array([])
NGlist = np.array([])

for NG in NGlist_name:
    for i, n in enumerate(model_list):
        if n[-12:-4] == NG:
            NGlist = np.append(NGlist, int(i))
print(NGlist)

target_list = model_list


def _slope_of_log_hist(values, bins):
    """Slope of log(histogram density) against bin centre for ``values``."""
    density, edges = np.histogram(values, bins=bins, density=True)
    return regression(make_bins(edges), np.log(density))


PCA_rows = []
for model_path in target_list:
    print(model_path)
    rnn = MyLSTM_RNN(inputsize, hidden_size, outputsize, batch_size)
    rnn.load_state_dict(torch.load(model_path))

    weights = {name: np.array(p.data) for name, p in rnn.named_parameters()}
    PFC_w, PFC_inw = weights["PFC.weight_ih"], weights["PFC.weight_hh"]
    HPC_w, HPC_inw = weights["HPC.weight_ih"], weights["HPC.weight_hh"]
    Re_w, Re_inw = weights["Re.weight_ih"], weights["Re.weight_hh"]

    # *1 / *2: positive entries of the input / recurrent matrix;
    # *3 / *4: magnitudes of the negative entries.
    a1 = _slope_of_log_hist(PFC_w[PFC_w > 0], 40)
    a2 = _slope_of_log_hist(PFC_inw[PFC_inw > 0], 20)
    a3 = _slope_of_log_hist(np.abs(PFC_w[PFC_w < 0]), 40)
    a4 = _slope_of_log_hist(np.abs(PFC_inw[PFC_inw < 0]), 20)

    b1 = _slope_of_log_hist(HPC_w[HPC_w > 0], 40)
    b2 = _slope_of_log_hist(HPC_inw[HPC_inw > 0], 20)
    b3 = _slope_of_log_hist(np.abs(HPC_w[HPC_w < 0]), 40)
    b4 = _slope_of_log_hist(np.abs(HPC_inw[HPC_inw < 0]), 20)

    c1 = _slope_of_log_hist(Re_w[Re_w > 0], 40)
    c2 = _slope_of_log_hist(Re_inw[Re_inw > 0], 20)
    c3 = _slope_of_log_hist(np.abs(Re_w[Re_w < 0]), 40)
    c4 = _slope_of_log_hist(np.abs(Re_inw[Re_inw < 0]), 20)

    PCA_rows.append([a2, b2, a4, b4])

# One row per model, columns [a2, b2, a4, b4].
dfs = np.array(PCA_rows, dtype=float).reshape(-1, 4)
pca = PCA()
pca.fit(dfs)
feature = pca.transform(dfs)
NG_indices = NGlist.astype(int)

plt.figure(figsize=(6, 6))
plt.scatter(feature[:, 0], feature[:, 1], alpha=0.8)
plt.grid()
plt.xlabel("PC1")
plt.ylabel("PC2")
for i, n in enumerate(target_list):
    plt.annotate(n[-12:], (feature[i, 0], feature[i, 1]))
for i in NG_indices:
    plt.scatter(feature[i, 0], feature[i, 1], color="r", alpha=0.8)
plt.show()

plt.figure(figsize=(6, 6))
plt.scatter(dfs[:, 2], dfs[:, 3], alpha=0.8)
plt.grid()
plt.xlabel("a4")
plt.ylabel("b4")
for i, n in enumerate(target_list):
    plt.annotate(n[-12:], (dfs[i, 2], dfs[i, 3]))
for i in NG_indices:
    plt.scatter(dfs[i, 2], dfs[i, 3], color="r", alpha=0.8)
plt.show()

fig3d = plt.figure()
# Axes3D(fig) no longer attaches itself to the figure on Matplotlib >= 3.6.
ax3d = fig3d.add_subplot(projection="3d")
ax3d.plot(feature[:, 0], feature[:, 1], feature[:, 2], "o", alpha=0.8)
for i in NG_indices:
    ax3d.plot(feature[[i], 0], feature[[i], 1], feature[[i], 2], "o", color="r", alpha=0.8)
plt.show()

In [None]:
import glob
import pandas as pd
from sklearn import linear_model

# For every long131test model: fit an exponential distribution to the
# positive entries and to |negative entries| of each cell's weight matrices,
# collect the fitted scales and run PCA over them.  Also overlays the fitted
# densities of the Re recurrent weights.
import scipy.stats  # used below; it was never imported in this cell

# glob returns files in file-system order; sort so runs are reproducible.
model_list = sorted(glob.glob('model/ReModel_interRNN_long131test_s*_100_*.pth'))
print(model_list)

training_size = 100
test_size = 1000
epochs_num = 10
hidden_size = 20
batch_size = 10
data_length = 100
inputsize = 2
outputsize = 2

# Models to highlight in red, identified by the file-name suffix before ".pth".
NGlist_name = np.array(["s3_100_4","s1_100_3","s9_100_2","s4_100_1"])
# NGlist_name = np.array([])
NGlist = np.array([])

for NG in NGlist_name:
    for i, n in enumerate(model_list):
        if n[-12:-4] == NG:
            NGlist = np.append(NGlist, int(i))
print(NGlist)

target_list = model_list
# Column order of the per-model feature rows built below.
feature_names = ["b2", "b4", "a2", "a4", "c2", "c4"]


def _expon_fit_split(w):
    """Fit ``scipy.stats.expon`` to the positive entries of ``w`` and to |negative entries|.

    Returns a pair of ``(loc, scale)`` tuples: (positive fit, negative fit).
    """
    positive = scipy.stats.expon.fit(w[w > 0])
    negative = scipy.stats.expon.fit(np.abs(w[w < 0]))
    return positive, negative


exp_plot = plt.figure()
exp_plot1 = exp_plot.add_subplot(111)
xvalues = np.linspace(0, 2, 100)
PCA_rows = []

for model_path in target_list:
    rnn = MyLSTM_RNN(inputsize, hidden_size, outputsize, batch_size)
    rnn.load_state_dict(torch.load(model_path))

    weights = {name: np.array(p.data) for name, p in rnn.named_parameters()}
    PFC_w, PFC_inw = weights["PFC.weight_ih"], weights["PFC.weight_hh"]
    HPC_w, HPC_inw = weights["HPC.weight_ih"], weights["HPC.weight_hh"]
    Re_w, Re_inw = weights["Re.weight_ih"], weights["Re.weight_hh"]

    # expon.fit returns (loc, scale); keep the scale.  *1 / *2 use the positive
    # entries of the input / recurrent matrix, *3 / *4 the negative ones.
    (_, a1), (_, a3) = _expon_fit_split(PFC_w)
    (_, a2), (_, a4) = _expon_fit_split(PFC_inw)
    (_, b1), (_, b3) = _expon_fit_split(HPC_w)
    (_, b2), (_, b4) = _expon_fit_split(HPC_inw)
    (_, c1), (_, c3) = _expon_fit_split(Re_w)
    Re_inw_exp, Re_inw_exp_n = _expon_fit_split(Re_inw)
    c2, c4 = Re_inw_exp[1], Re_inw_exp_n[1]

    PCA_rows.append([b2, b4, a2, a4, c2, c4])

    # Overlay the fitted Re recurrent-weight densities against the weight
    # value (not the sample index).  Two models are highlighted; the others
    # are drawn faintly, positive fit only.
    pdf1 = scipy.stats.expon.pdf(xvalues, Re_inw_exp[0], Re_inw_exp[1])
    pdf2 = scipy.stats.expon.pdf(xvalues, Re_inw_exp_n[0], Re_inw_exp_n[1])
    tag = model_path[-12:-4]
    if tag == "s9_100_2":
        exp_plot1.plot(xvalues, pdf1, c="r", alpha=1)
        exp_plot1.plot(xvalues, pdf2, c="r", alpha=1)
    elif tag == "s4_100_1":
        exp_plot1.plot(xvalues, pdf1, c="g", alpha=1)
        exp_plot1.plot(xvalues, pdf2, c="g", alpha=1)
    else:
        exp_plot1.plot(xvalues, pdf1, c="b", alpha=0.01)

dfs = np.array(PCA_rows, dtype=float).reshape(-1, len(feature_names))
pca = PCA()
pca.fit(dfs)
feature = pca.transform(dfs)
# PCA keeps min(n_models, n_features) components, so size the labels from it.
pc_names = ["PC{}".format(x + 1) for x in range(pca.n_components_)]

print(pd.DataFrame(feature, columns=pc_names).head())
print(pd.DataFrame(pca.explained_variance_ratio_, index=pc_names))
print(pd.DataFrame(pca.components_, columns=feature_names, index=pc_names))

NG_indices = NGlist.astype(int)

plt.figure(figsize=(6, 6))
plt.scatter(feature[:, 0], feature[:, 1], alpha=0.8)
plt.grid()
plt.xlabel("PC1")
plt.ylabel("PC2")
for i, n in enumerate(target_list):
    plt.annotate(n[-12:], (feature[i, 0], feature[i, 1]))
for i in NG_indices:
    plt.scatter(feature[i, 0], feature[i, 1], color="r", alpha=0.8)
plt.show()

# Columns 0 and 1 of dfs are b2 and b4 (see feature_names).
plt.figure(figsize=(6, 6))
plt.scatter(dfs[:, 0], dfs[:, 1], alpha=0.8)
plt.grid()
plt.xlabel(feature_names[0])
plt.ylabel(feature_names[1])
for i, n in enumerate(target_list):
    plt.annotate(n[-12:], (dfs[i, 0], dfs[i, 1]))
for i in NG_indices:
    plt.scatter(dfs[i, 0], dfs[i, 1], color="r", alpha=0.8)
plt.show()

fig3d = plt.figure()
# Axes3D(fig) no longer attaches itself to the figure on Matplotlib >= 3.6.
ax3d = fig3d.add_subplot(projection="3d")
ax3d.plot(feature[:, 0], feature[:, 1], feature[:, 2], "o", alpha=0.8)
for i in NG_indices:
    ax3d.plot(feature[[i], 0], feature[[i], 1], feature[[i], 2], "o", color="r", alpha=0.8)
plt.show()

In [None]:
def calc_KL(data1, data2, x_min=-2.0, x_max=2.0, num=200):
    """Approximate KL(N1 || N2) between normal fits of two samples.

    Each sample is fitted with ``scipy.stats.norm.fit`` (maximum-likelihood
    mean and standard deviation).  Both fitted densities are evaluated on
    ``num`` evenly spaced points in ``[x_min, x_max]`` and passed to
    ``scipy.stats.entropy``, which normalises them and returns
    ``sum(p * log(p / q))``.

    Returns
    -------
    float
        The discretised KL divergence (0 for identical fits).
    """
    # Import locally: the notebook's `import scipy` only happens in a later
    # cell, and a bare `import scipy` does not load `scipy.stats` on older
    # SciPy releases.
    import scipy.stats

    x = np.linspace(x_min, x_max, num)
    loc1, scale1 = scipy.stats.norm.fit(data1)
    loc2, scale2 = scipy.stats.norm.fit(data2)
    pdf1 = scipy.stats.norm.pdf(x, loc1, scale1)
    pdf2 = scipy.stats.norm.pdf(x, loc2, scale2)
    return scipy.stats.entropy(pdf1, pdf2)

def calc_KL_exp(data1, data2, x_min=0.0, x_max=2.0, num=200):
    """Approximate KL divergence between exponential fits of the non-negative parts of two samples.

    Only entries ``>= 0`` of each sample are used (zeros included).  Each is
    fitted with ``scipy.stats.expon.fit`` (loc and scale).  The two densities
    are evaluated on ``num`` points in ``[x_min, x_max]`` and compared with
    ``scipy.stats.entropy``.  The result is ``inf`` when the first density
    is positive where the second is zero, e.g. when ``data2`` has the
    larger fitted loc.
    """
    # Local import so the function does not depend on a later cell's
    # `import scipy` having loaded the stats submodule.
    import scipy.stats

    x = np.linspace(x_min, x_max, num)
    loc1, scale1 = scipy.stats.expon.fit(np.abs(data1[data1 >= 0]))
    loc2, scale2 = scipy.stats.expon.fit(np.abs(data2[data2 >= 0]))
    pdf1 = scipy.stats.expon.pdf(x, loc1, scale1)
    pdf2 = scipy.stats.expon.pdf(x, loc2, scale2)
    return scipy.stats.entropy(pdf1, pdf2)

def calc_KL_exp2(data1, x_min=0.0, x_max=2.0, num=200):
    """Approximate KL divergence between exponential fits of the positive and negative halves of one sample.

    The entries ``>= 0`` and the magnitudes of the entries ``<= 0`` (zeros
    fall in both halves) are each fitted with ``scipy.stats.expon.fit``.
    The densities are evaluated on ``num`` points in ``[x_min, x_max]`` and
    compared with ``scipy.stats.entropy`` as KL(positive || negative).
    """
    # Local import so the function does not depend on a later cell's
    # `import scipy` having loaded the stats submodule.
    import scipy.stats

    x = np.linspace(x_min, x_max, num)
    loc_pos, scale_pos = scipy.stats.expon.fit(np.abs(data1[data1 >= 0]))
    loc_neg, scale_neg = scipy.stats.expon.fit(np.abs(data1[data1 <= 0]))
    pdf1 = scipy.stats.expon.pdf(x, loc_pos, scale_pos)
    pdf2 = scipy.stats.expon.pdf(x, loc_neg, scale_neg)
    return scipy.stats.entropy(pdf1, pdf2)

In [None]:
import glob
import scipy

# For every long131test model with suffix "_100_2":
# - strip-plot its non-zero Re input weights (one column of dots per model);
# - print the KL divergence between exponential fits of the positive and
#   negative Re recurrent weights;
# - run a one-way ANOVA between two chosen models.
import scipy.stats  # make sure the stats submodule is loaded

# glob returns files in file-system order; sort so positional indices
# (ANOVA_pair below) refer to the same models on every machine.
model_list = sorted(glob.glob('model/ReModel_interRNN_long131test_s*_100_2.pth'))
print(model_list)

training_size = 100
test_size = 1000
epochs_num = 10
hidden_size = 20
batch_size = 10
data_length = 100
inputsize = 2
outputsize = 2

ANOVA_list = []

# To restrict the analysis to a subset of models:
# target_list = [model_list[i] for i in [2, 4]]
target_list = model_list

fig_Re = plt.figure(figsize=(10, 5))
ax_Re = fig_Re.add_subplot(111)
for i, model_path in enumerate(target_list):
    rnn = MyLSTM_RNN(inputsize, hidden_size, outputsize, batch_size)
    rnn.load_state_dict(torch.load(model_path))

    weights = {name: np.array(p.data) for name, p in rnn.named_parameters()}
    # Flattened non-zero entries of every cell weight matrix.
    nonzero = {name: w[w.nonzero()] for name, w in weights.items() if ".weight_" in name}
    weight_Re_w = nonzero["Re.weight_ih"]
    weight_Re_inw = nonzero["Re.weight_hh"]

    ax_Re.plot(np.ones(weight_Re_w.shape[0]) * i, weight_Re_w, "o")
    ANOVA_list.append(weight_Re_w)

    print(model_path)
    print(calc_KL_exp2(weight_Re_inw))

# Indices into the sorted model_list printed above.
ANOVA_pair = (3, 5)
if len(ANOVA_list) > max(ANOVA_pair):
    print(scipy.stats.f_oneway(*(ANOVA_list[k] for k in ANOVA_pair)))
else:
    print("ANOVA skipped: need at least {} models, found {}".format(
        max(ANOVA_pair) + 1, len(ANOVA_list)))

In [None]:
import glob
import scipy

# Compare the Re weights of two "plus" variants of model s1_100_3:
# - strip plot of the input weights;
# - KL divergence between exponential fits of the positive and negative
#   recurrent weights;
# - per-column ANOVA and a difference image of the recurrent matrices.
import scipy.stats  # make sure the stats submodule is loaded

# glob returns files in file-system order; sort so the positional indices in
# `target` select the same models on every machine.
model_list = sorted(glob.glob('model/ReModel_interRNN_long131test_s1_100_3_plus_*.pth'))
print(model_list)

training_size = 100
test_size = 1000
epochs_num = 10
hidden_size = 20
batch_size = 10
data_length = 100
inputsize = 2
outputsize = 2

ANOVA_list = []

# Positions in the sorted model_list printed above.
target = [4, 8]
target_list = [model_list[k] for k in target]
# target_list = model_list

fig_Re = plt.figure(figsize=(10, 5))
ax_Re = fig_Re.add_subplot(111)
for i, model_path in enumerate(target_list):
    rnn = MyLSTM_RNN(inputsize, hidden_size, outputsize, batch_size)
    rnn.load_state_dict(torch.load(model_path))

    weights = {name: np.array(p.data) for name, p in rnn.named_parameters()}
    # Full matrices here (zeros included), unlike the non-zero filtering used
    # in the previous cell.
    weight_Re_w = weights["Re.weight_ih"]
    weight_Re_inw = weights["Re.weight_hh"]

    # Each column of the 2-D matrix is drawn as its own series at x = i.
    ax_Re.plot(np.ones(weight_Re_w.shape[0]) * i, weight_Re_w, "o")
    ANOVA_list.append(weight_Re_inw)

    print(model_path)
    print(calc_KL_exp2(weight_Re_inw))

# f_oneway on 2-D arrays tests each column separately (axis=0 by default).
print(scipy.stats.f_oneway(ANOVA_list[0], ANOVA_list[1]))
plt.figure()
plt.imshow(ANOVA_list[0] - ANOVA_list[1])

In [None]:
#v4 without Re
#v4_1 revised forward function

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import time
import math
import matplotlib.pyplot as plt
import matplotlib.animation as animation
import sklearn 
from sklearn.decomposition import PCA
from sklearn.cluster import KMeans
from mpl_toolkits.mplot3d import Axes3D

%matplotlib notebook

def Culc_gate(input, params, hiddens):
    """Return element-wise sigmoids of PFC and HPC gate pre-activations as nested lists.

    ``params`` is indexed in groups of four: ``params[0:4]`` for PFC and
    ``params[4:8]`` for HPC.  In each group the first two are matrix-multiplied
    and the last two are added, presumably as (weight_ih, weight_hh, bias_ih,
    bias_hh) from ``list(rnn.parameters())`` -- TODO confirm.

    NOTE(review): the hidden-state indexing does not match the
    ``[PFC, HPC, Re]`` layout returned by ``MyLSTM_RNN``:
    - ``hiddens[0][0]`` is used as ``Re``;
    - ``hiddens[1][0][0]`` is used as the PFC input;
    - ``hiddens[0][1][0]`` / ``hiddens[1][1][0]`` (the second element of each
      pair) are multiplied by ``params[1]`` / ``params[5]``.
    This looks like a leftover from an earlier model version.  The calls
    visible in this section are commented out.  If the params are LSTMCell
    weights, the result stacks the i, f, g, o gate blocks; an LSTM applies
    tanh (not sigmoid) to the g block.

    Returns
    -------
    list
        ``[PFC_gates, HPC_gates]`` converted with ``tolist()``.
    """
    Re = hiddens[0][0]
    HPC_input = torch.cat([input,Re],dim=1)
    PFC_input = hiddens[1][0][0]
    # Pre-activation: W_a @ x + W_b @ state + b_a + b_b (first batch row only).
    PFC_gates = params[0]@PFC_input+params[1]@hiddens[0][1][0]+params[2]+params[3]
    HPC_gates = params[4]@HPC_input[0]+params[5]@hiddens[1][1][0]+params[6]+params[7]
    PFC_gates = torch.sigmoid(PFC_gates)
    HPC_gates = torch.sigmoid(HPC_gates)
    return [PFC_gates.tolist(), HPC_gates.tolist()]
    

def MakeAnimation(data_x, data_y, traj_x, traj_y, data_limit,
                  filename='anim_v4_2_1_r.gif', interval=100):
    """Animate a generated trajectory against reference data and save it as a GIF.

    Frame ``t`` draws the first ``t`` generated points in red, with a red dot
    on the latest one.  While ``t < data_limit`` the reference is drawn in
    blue up to ``t``.  From then on the reference is drawn solid up to
    ``data_limit`` and dashed from ``data_limit - 1`` to ``t``.

    Parameters
    ----------
    data_x, data_y : array_like
        Reference coordinates.
    traj_x, traj_y : numpy.ndarray
        Generated coordinates; ``traj_x.shape[0]`` is the number of frames.
    data_limit : int
        Step at which the reference switches from solid to dashed.
    filename : str, optional
        Output path, written with the ``pillow`` writer.
    interval : int, optional
        Delay between frames in milliseconds.

    Returns
    -------
    int
        Always 0.
    """
    fig, ax = plt.subplots()
    ims = []
    for t in range(traj_x.shape[0]):
        # Latest drawn point.  Clamp so frame 0 marks the first point instead
        # of wrapping around to traj_x[-1] (the final point).
        last = max(t - 1, 0)
        if t < data_limit:
            traj = ax.plot(data_x[:t], data_y[:t], "b",
                           traj_x[:t], traj_y[:t], "r",
                           traj_x[last], traj_y[last], "or")
        else:
            traj = ax.plot(data_x[:data_limit], data_y[:data_limit], "b",
                           traj_x[:t], traj_y[:t], "r",
                           traj_x[last], traj_y[last], "or",
                           data_x[data_limit - 1:t], data_y[data_limit - 1:t], "--b")
        title = ax.text(0.5, 1.01, 'time={}'.format(t),
                        ha='center', va='bottom',
                        transform=ax.transAxes, fontsize='large')
        ims.append(traj + [title])
    ani = animation.ArtistAnimation(fig, ims, interval=interval)
    ani.save(filename, writer='pillow')
    plt.show()

    return 0

def MakeAnimation_attracter(data_x, data_y, filename='anim_attracter2.gif', interval=100):
    """Animate a single 2-D trajectory (blue line with a dot on its latest point) and save a GIF.

    Parameters
    ----------
    data_x, data_y : numpy.ndarray
        Coordinates; ``data_x.shape[0]`` is the number of frames.
    filename : str, optional
        Output path, written with the ``pillow`` writer.
    interval : int, optional
        Delay between frames in milliseconds.

    Returns
    -------
    int
        Always 0.
    """
    fig, ax = plt.subplots()
    ims = []
    for t in range(data_x.shape[0]):
        # Clamp so frame 0 marks the first point rather than data_x[-1].
        last = max(t - 1, 0)
        traj = ax.plot(data_x[:t], data_y[:t], "b", data_x[last], data_y[last], "ob")
        title = ax.text(0.5, 1.01, 'time={}'.format(t),
                        ha='center', va='bottom',
                        transform=ax.transAxes, fontsize='large')
        ims.append(traj + [title])
    ani = animation.ArtistAnimation(fig, ims, interval=interval)
    ani.save(filename, writer='pillow')
    plt.show()

    return 0


    
def MakeAnimation_img(data, module):
    """Save a GIF showing one row of ``data`` per frame as a single-row image.

    Frame ``t`` shows ``data[t]`` reshaped to ``(1, data.shape[1])`` and is
    titled ``"<module> time=<t>"``.  The file is written to
    ``anim_fire_<module>_.gif`` with the ``pillow`` writer, and the figure is
    then shown.

    Returns
    -------
    int
        Always 0.
    """
    fig, ax = plt.subplots()
    width = data.shape[1]
    frames = []
    for step in range(data.shape[0]):
        picture = ax.imshow(data[step].reshape(1, width))
        caption = ax.text(0.5, 1.01, module + ' time={}'.format(step),
                          ha='center', va='bottom',
                          transform=ax.transAxes, fontsize='large')
        frames.append([picture, caption])
    anim = animation.ArtistAnimation(fig, frames)
    anim.save('anim_fire_' + module + '_.gif', writer='pillow')
    plt.show()

    return 0

def mkOwnDataSet(data_size, filename, data_length=100, freq=60., noise=0.005):
    """Build ``data_size`` noisy, subsampled copies of two recorded trajectories.

    Reads ``<filename>_r.csv`` and ``<filename>_l.csv`` (comma-separated) and
    keeps every ``stride``-th row, with ``stride = max(1, round(rows / data_length))``.
    Each kept row becomes ``[row[0] + n0, row[1] + n1]``, where ``n0`` and
    ``n1`` are independent Gaussian noise with standard deviation ``noise``.

    Parameters
    ----------
    data_size : int
        Number of noisy copies to generate from each file.
    filename : str
        Path prefix; ``"_r.csv"`` and ``"_l.csv"`` are appended.
    data_length : int
        Target number of samples per copy.  The actual length is
        ``ceil(rows / stride)``, so it can differ slightly.
    freq : float
        Unused; kept for backward compatibility.
    noise : float
        Standard deviation of the additive Gaussian noise.

    Returns
    -------
    tuple of list
        ``(train_x, train_y)``; each is a list of ``data_size`` sequences of
        ``[x, y]`` pairs, from the ``_r`` and ``_l`` files respectively.
    """
    x = np.loadtxt(filename + "_r.csv", delimiter=',')
    y = np.loadtxt(filename + "_l.csv", delimiter=',')

    def _sample_indices(arr):
        # round() yields 0 when the file has fewer than data_length / 2 rows,
        # which would make np.arange raise; fall back to every row.
        stride = max(1, round(arr.shape[0] / data_length))
        return np.arange(0, arr.shape[0], stride)

    def _noisy_copy(arr, indices):
        # Noise is drawn per row, x-column then y-column.
        return [[arr[i][0] + np.random.normal(loc=0.0, scale=noise),
                 arr[i][1] + np.random.normal(loc=0.0, scale=noise)]
                for i in indices]

    x_indices = _sample_indices(x)
    y_indices = _sample_indices(y)
    train_x = []
    train_y = []
    for _ in range(data_size):
        train_x.append(_noisy_copy(x, x_indices))
        train_y.append(_noisy_copy(y, y_indices))
    return train_x, train_y


def mkOwnRandomBatch(train_x, train_t=None, batch_size=10):
    """Draw a random mini-batch of input sequences from ``train_x``.

    Sequences are picked uniformly at random, with replacement, from all of
    ``train_x``.  ``train_t`` is accepted only for backward compatibility and
    is ignored; no target batch is returned.

    Parameters
    ----------
    train_x : sequence
        Training sequences of identical shape that ``torch.tensor`` can
        convert, e.g. the lists produced by ``mkOwnDataSet``.
    train_t : ignored
    batch_size : int
        Number of sequences in the batch.

    Returns
    -------
    torch.Tensor
        The batch with its first two axes swapped: shape
        ``(seq_len, batch_size, ...)`` for sequences of shape ``(seq_len, ...)``.

    Raises
    ------
    ValueError
        If ``train_x`` is empty.
    """
    n_sequences = len(train_x)
    if n_sequences == 0:
        raise ValueError("train_x must contain at least one sequence")
    batch_x = []
    for _ in range(batch_size):
        # Exclusive upper bound: every sequence, including the last ones, can
        # be chosen, and small datasets (n_sequences <= batch_size) work too.
        idx = np.random.randint(0, n_sequences)
        batch_x.append(train_x[idx])
    return torch.tensor(batch_x).transpose(0, 1)

def plot_distance_bet2traj(traj1, traj2):
    """Plot and return the step-wise Euclidean distance between two trajectories.

    Steps are taken over ``range(traj1.shape[0])``, so ``traj2`` must be at
    least as long as ``traj1``.  The distance curve is drawn on a new figure.

    Returns
    -------
    list
        ``np.linalg.norm(traj1[i] - traj2[i])`` for every step ``i``.
    """
    distances = [np.linalg.norm(traj1[step] - traj2[step])
                 for step in range(traj1.shape[0])]
    plt.figure()
    plt.plot(distances)

    return distances


class MyLSTM(nn.Module):
    """Three coupled LSTM cells (PFC, HPC, Re) with a linear read-out from HPC.

    On each step every cell reads the *previous* hidden outputs of the others:
    - Re reads ``[PFC h, HPC h]``;
    - HPC reads ``[input, Re h]``;
    - PFC reads ``[HPC h, Re h]``.
    The output is a linear projection of the new HPC hidden output.
    """

    def __init__(self, input_size, hidden_size, output_size, batch_size):
        super().__init__()

        self.hidden_size = hidden_size
        # Creation order also fixes the order in which default initialisers
        # consume the torch RNG.
        self.PFC = nn.LSTMCell(2 * hidden_size, hidden_size)
        self.HPC = nn.LSTMCell(input_size + hidden_size, hidden_size)
        self.Re = nn.LSTMCell(2 * hidden_size, hidden_size)
        self.linear = nn.Linear(hidden_size, output_size)
        self.batch_size = batch_size
        # NOTE(review): sparsity=1 makes nn.init.sparse_ zero every entry of
        # each column, so these four matrices start as all zeros.
        for weight in (self.PFC.weight_ih, self.HPC.weight_ih,
                       self.PFC.weight_hh, self.HPC.weight_hh):
            nn.init.sparse_(weight.data, 1)

    def forward(self, input, hiddens):
        """Advance all three cells one step; return ``(output, [PFC, HPC, Re])`` (each an (h, c) pair)."""
        pfc_prev, hpc_prev, re_prev = hiddens[0], hiddens[1], hiddens[2]
        re_next = self.Re(torch.cat([pfc_prev[0], hpc_prev[0]], dim=1), re_prev)
        hpc_next = self.HPC(torch.cat([input, re_prev[0]], dim=1), hpc_prev)
        pfc_next = self.PFC(torch.cat([hpc_prev[0], re_prev[0]], dim=1), pfc_prev)
        return self.linear(hpc_next[0]), [pfc_next, hpc_next, re_next]

    def initHidden(self):
        """Return zero ``[h, c]`` pairs for PFC, HPC and Re, each ``(batch_size, hidden_size)``."""
        def zero_pair():
            return [torch.zeros(self.batch_size, self.hidden_size),
                    torch.zeros(self.batch_size, self.hidden_size)]
        return [zero_pair(), zero_pair(), zero_pair()]

class MyLSTM_RNN(nn.Module):
    """PFC and HPC LSTM cells coupled through a vanilla RNN cell (Re), with a linear read-out from HPC.

    On each step the cells read the *previous* states:
    - Re reads ``[PFC h, HPC h]``;
    - HPC reads ``[input, Re h]``;
    - PFC reads ``[HPC h, Re h]``.
    PFC and HPC states are ``(h, c)`` pairs; the Re state is a single tensor.
    """

    def __init__(self, input_size, hidden_size, output_size, batch_size):
        super().__init__()

        self.hidden_size = hidden_size
        # Attribute names are the state_dict prefixes of the saved models.
        self.PFC = nn.LSTMCell(2 * hidden_size, hidden_size)
        self.HPC = nn.LSTMCell(input_size + hidden_size, hidden_size)
        self.Re = nn.RNNCell(2 * hidden_size, hidden_size)
        self.linear = nn.Linear(hidden_size, output_size)
        self.batch_size = batch_size
        # Zero ceil(0.5 * rows) entries in every column of the PFC/HPC weights.
        sparsity = 0.5
        for weight in (self.PFC.weight_ih, self.HPC.weight_ih,
                       self.PFC.weight_hh, self.HPC.weight_hh):
            nn.init.sparse_(weight.data, sparsity)

    def forward(self, input, hiddens):
        """Advance all cells one step; return ``(output, [PFC (h, c), HPC (h, c), Re h])``."""
        pfc_prev, hpc_prev, re_prev = hiddens[0], hiddens[1], hiddens[2]
        re_next = self.Re(torch.cat([pfc_prev[0], hpc_prev[0]], dim=1), re_prev)
        hpc_next = self.HPC(torch.cat([input, re_prev], dim=1), hpc_prev)
        pfc_next = self.PFC(torch.cat([hpc_prev[0], re_prev], dim=1), pfc_prev)
        return self.linear(hpc_next[0]), [pfc_next, hpc_next, re_next]

    def initHidden(self):
        """Return all-zero states ``[[PFC h, c], [HPC h, c], Re h]``, each ``(batch_size, hidden_size)``."""
        rows, cols = self.batch_size, self.hidden_size
        pfc = [torch.zeros(rows, cols), torch.zeros(rows, cols)]
        hpc = [torch.zeros(rows, cols), torch.zeros(rows, cols)]
        re = torch.zeros(rows, cols)
        return [pfc, hpc, re]

    def initHidden_rand(self):
        """Return states drawn uniformly from ``[0, 0.01)``, laid out like ``initHidden``."""
        scale = 0.01
        rows, cols = self.batch_size, self.hidden_size
        # Draw order (HPC h, HPC c, PFC h, PFC c, Re) is kept fixed so the
        # torch RNG stream is consumed the same way on every call.
        hpc = [torch.rand(rows, cols) * scale, torch.rand(rows, cols) * scale]
        pfc = [torch.rand(rows, cols) * scale, torch.rand(rows, cols) * scale]
        re = torch.rand(rows, cols) * scale
        return [pfc, hpc, re]


def _rollout(rnn, data, data_limit, traj, pfc_states, hpc_states, re_states):
    """Run one primed then free-running rollout of ``rnn`` and record every step.

    The model starts from ``rnn.initHidden_rand()``.  It is driven by
    ``data[0:data_limit]`` and then fed its own output for
    ``2 * data.shape[0]`` more steps.  After each step the following are
    appended as nested lists:
    - the output batch to ``traj``;
    - the full PFC/HPC ``h`` tensors to ``pfc_states`` / ``hpc_states``;
    - the first batch row of the Re state to ``re_states``.
    """
    def record(output, hidden):
        traj.append(output.tolist())
        pfc_states.append(hidden[0][0].tolist())
        hpc_states.append(hidden[1][0].tolist())
        re_states.append(hidden[2][0].tolist())

    hidden = rnn.initHidden_rand()
    output = None
    for k in range(data_limit):
        output, hidden = rnn(data[k], hidden)
        record(output, hidden)
    for _ in range(data.shape[0] * 2):
        output, hidden = rnn(output, hidden)
        record(output, hidden)


def _plot_state_pca(states, dividenum):
    """PCA-project ``states`` and plot rows before/after ``dividenum`` as two 3-D lines.

    The first state is marked with a dot.  Returns the projected states
    (one row per time step).
    """
    pca = PCA()
    pca.fit(states)
    feature = pca.transform(states)

    fig3d = plt.figure()
    # Axes3D(fig) no longer attaches itself to the figure on Matplotlib >= 3.6.
    ax3d = fig3d.add_subplot(projection="3d")
    ax3d.plot(feature[:dividenum, 0], feature[:dividenum, 1], feature[:dividenum, 2], alpha=0.8)
    ax3d.plot(feature[dividenum:, 0], feature[dividenum:, 1], feature[dividenum:, 2], alpha=0.8)
    ax3d.plot(feature[0:1, 0], feature[0:1, 1], feature[0:1, 2], "o", alpha=1)
    plt.show()
    return feature


def main():
    """Load a trained MyLSTM_RNN, run two rollouts and compare their hidden-state trajectories.

    The first rollout is primed with a batch from the ``_r`` trajectory file,
    the second with a batch from the ``_l`` file.  For each cell (HPC, PFC,
    Re) the states of the first batch element are projected with PCA and the
    two rollouts are compared with ``plot_distance`` and
    ``overlap_coefficient``; those helpers are not defined in this cell.
    """
    training_size = 100
    test_size = 1000
    epochs_num = 10
    hidden_size = 20
    batch_size = 10
    data_length = 100
    inputsize = 2
    outputsize = 2
    model_path = 'model/ReModel_interRNN_long131test_s6_100_2.pth'
    filename = "primal_long131test"

    train_x, train_y = mkOwnDataSet(training_size, filename, data_length)

    rnn = MyLSTM_RNN(inputsize, hidden_size, outputsize, batch_size)
    rnn.load_state_dict(torch.load(model_path))

    traj = []
    PFCstate = []
    HPCstate = []
    Restate = []
    data_limit = 10
    # Inference only: no_grad stops autograd from keeping the graph of the
    # whole rollout alive.
    with torch.no_grad():
        # batch_size goes to the batch_size parameter, not the unused train_t.
        data = mkOwnRandomBatch(train_x, None, batch_size)
        _rollout(rnn, data, data_limit, traj, PFCstate, HPCstate, Restate)
        data = mkOwnRandomBatch(train_y, None, batch_size)
        _rollout(rnn, data, data_limit, traj, PFCstate, HPCstate, Restate)

    traj = torch.squeeze(torch.tensor(traj)).numpy()
    pltdata = torch.squeeze(data).numpy()

    print(np.array(PFCstate)[:, 0].shape)
    #MakeAnimation(pltdata[:,0,0],pltdata[:,0,1], traj[:,0,0], traj[:,0,1], data_limit)
    #MakeAnimation_img(np.array(PFCstate)[:,0],"PFC")
    #MakeAnimation_img(np.array(HPCstate)[:,0],"HPC")

    # Split point between the two rollouts (assumes both have equal length).
    dividenum = len(HPCstate) // 2

    plt.figure()
    plt.plot(pltdata[:, 0, 0], pltdata[:, 0, 1], "--")
    plt.plot(traj[:dividenum, 0, 0], traj[:dividenum, 0, 1])
    plt.plot(traj[dividenum:, 0, 0], traj[dividenum:, 0, 1])
    plt.show()

    feature = _plot_state_pca(np.array(HPCstate)[:, 0], dividenum)
    plot_distance(feature[:dividenum], feature[dividenum:])
    HPC_dis = overlap_coefficient(feature[:dividenum], feature[dividenum:])

    feature = _plot_state_pca(np.array(PFCstate)[:, 0], dividenum)
    plot_distance(feature[:dividenum], feature[dividenum:])
    PFC_dis = overlap_coefficient(feature[:dividenum], feature[dividenum:])

    feature = _plot_state_pca(np.array(Restate), dividenum)
    plot_distance(feature[:dividenum], feature[dividenum:])
    Re_dis = overlap_coefficient(feature[:dividenum], feature[dividenum:])

    plot_distance_bet2traj(feature[:dividenum], feature[dividenum:])

    plt.figure()
    plt.plot(HPC_dis)
    plt.plot(PFC_dis)
    plt.plot(Re_dis)


if __name__ == '__main__':
    main()

In [None]:
#v4 without Re
#v4_1 revised forward function

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import time
import math
import matplotlib.pyplot as plt
import matplotlib.animation as animation
import sklearn 
from sklearn.decomposition import PCA
from sklearn.cluster import KMeans
from mpl_toolkits.mplot3d import Axes3D
from tslearn.clustering import TimeSeriesKMeans
import pandas as pd
import scipy

%matplotlib notebook

def Culc_gate(input, params, hiddens):
    """Recompute LSTM gate activations of the PFC and HPC cells from raw parameters.

    Mirrors the ``nn.LSTMCell`` gate pre-activation
    ``W_ih @ x + W_hh @ h + b_ih + b_hh`` for the first batch element only.
    A sigmoid is then applied to every gate row. LSTMCell itself passes the
    cell-candidate block through tanh instead.

    NOTE(review): the wiring here matches an earlier two-module model:
    PFC is driven by HPC's hidden state, and HPC by ``[input, PFC hidden]``.
    It does not match ``MyLSTM_RNN``, whose PFC/HPC also read the Re state.
    With that model the matrix shapes will not line up. The visible call
    sites are commented out.

    Args:
        input: external input for the HPC cell, shape (batch, input_size).
        params: parameter tensors in ``parameters()`` order. This assumes
            indices 0-3 are PFC (weight_ih, weight_hh, bias_ih, bias_hh) and
            4-7 are HPC in the same order.
        hiddens: ``[(h_PFC, c_PFC), (h_HPC, c_HPC), ...]``.

    Returns:
        ``[PFC_gates, HPC_gates]`` as nested Python lists.
    """
    PFC_h = hiddens[0][0]
    HPC_h = hiddens[1][0]
    # HPC sees the external input concatenated with the PFC hidden state.
    HPC_input = torch.cat([input, PFC_h], dim=1)
    PFC_input = HPC_h[0]
    # The recurrent weight_hh multiplies the previous hidden state h (index 0
    # of each (h, c) pair). The cell state c (index 1), used previously, is
    # not an input to the gates.
    PFC_gates = params[0] @ PFC_input + params[1] @ PFC_h[0] + params[2] + params[3]
    HPC_gates = params[4] @ HPC_input[0] + params[5] @ HPC_h[0] + params[6] + params[7]
    PFC_gates = torch.sigmoid(PFC_gates)
    HPC_gates = torch.sigmoid(HPC_gates)
    return [PFC_gates.tolist(), HPC_gates.tolist()]
    

def MakeAnimation(data_x, data_y, traj_x, traj_y, data_limit,
                  filename='anim_v4_2_1_r.gif', interval=10):
    """Animate a predicted trajectory against reference data and save it as a GIF.

    Frame ``t`` draws the first ``t`` reference points (blue line) and the first
    ``t`` predicted points (red line), with a red marker on the latest predicted
    point. From ``t >= data_limit`` on, a blue marker on the latest reference
    point is drawn as well.

    Args:
        data_x, data_y: reference coordinates (sliceable sequences).
        traj_x, traj_y: predicted coordinates; ``traj_x.shape[0]`` frames are made.
        data_limit: first frame index at which the reference marker is shown.
        filename: output GIF path.
        interval: delay between frames in milliseconds.

    Returns:
        0, kept for compatibility with existing callers.
    """
    fig, ax = plt.subplots()
    frames = []
    for t in range(traj_x.shape[0]):
        # ``x[t-1:t]`` is the latest point, and is empty at t == 0. The former
        # ``x[t-1]`` wrapped around to the *last* point in the first frame.
        if t < data_limit:
            artists = ax.plot(data_x[:t], data_y[:t], "b",
                              traj_x[:t], traj_y[:t], "r",
                              traj_x[t-1:t], traj_y[t-1:t], "or")
        else:
            artists = ax.plot(data_x[:t], data_y[:t], "b",
                              data_x[t-1:t], data_y[t-1:t], "ob",
                              traj_x[:t], traj_y[:t], "r",
                              traj_x[t-1:t], traj_y[t-1:t], "or")
        title = ax.text(0.5, 1.01, 'time={}'.format(t),
                        ha='center', va='bottom',
                        transform=ax.transAxes, fontsize='large')
        frames.append(artists + [title])
    ani = animation.ArtistAnimation(fig, frames, interval=interval)
    ani.save(filename, writer='pillow')
    plt.show()

    return 0

def MakeAnimation2(data_x, data_y, traj_x, traj_y, data_limit,
                   filename='anim_Re_0.gif', interval=30):
    """Animate only the predicted trajectory and save it as a GIF.

    Frame ``t`` draws the first ``t`` predicted points (red line) and a red
    marker on the latest one.

    Args:
        data_x, data_y, data_limit: unused. They are accepted so the signature
            matches MakeAnimation; both branches of the previous
            ``t < data_limit`` test drew exactly the same thing.
        traj_x, traj_y: predicted coordinates; ``traj_x.shape[0]`` frames are made.
        filename: output GIF path.
        interval: delay between frames in milliseconds.

    Returns:
        0, kept for compatibility with existing callers.
    """
    fig, ax = plt.subplots()
    frames = []
    for t in range(traj_x.shape[0]):
        # ``[t-1:t]`` is empty at t == 0 instead of wrapping to the last point.
        artists = ax.plot(traj_x[:t], traj_y[:t], "r",
                          traj_x[t-1:t], traj_y[t-1:t], "or")
        title = ax.text(0.5, 1.01, 'time={}'.format(t),
                        ha='center', va='bottom',
                        transform=ax.transAxes, fontsize='large')
        frames.append(artists + [title])
    ani = animation.ArtistAnimation(fig, frames, interval=interval)
    ani.save(filename, writer='pillow')
    plt.show()

    return 0

def MakeAnimation_attracter(data_x, data_y, filename='anim_attracter2.gif', interval=10):
    """Animate a single 2-D trajectory growing over time and save it as a GIF.

    Frame ``t`` draws the first ``t`` points (blue line) and a blue marker on
    the latest one.

    Args:
        data_x, data_y: coordinates; ``data_x.shape[0]`` frames are made.
        filename: output GIF path.
        interval: delay between frames in milliseconds.

    Returns:
        0, kept for compatibility with existing callers.
    """
    fig, ax = plt.subplots()
    frames = []
    for t in range(data_x.shape[0]):
        # ``[t-1:t]`` is empty at t == 0 instead of wrapping to the last point.
        artists = ax.plot(data_x[:t], data_y[:t], "b",
                          data_x[t-1:t], data_y[t-1:t], "ob")
        title = ax.text(0.5, 1.01, 'time={}'.format(t),
                        ha='center', va='bottom',
                        transform=ax.transAxes, fontsize='large')
        frames.append(artists + [title])
    ani = animation.ArtistAnimation(fig, frames, interval=interval)
    ani.save(filename, writer='pillow')
    plt.show()

    return 0


    
def MakeAnimation_img(data, module, filename=None, interval=200, vmin=None, vmax=None):
    """Animate a (time, units) activity matrix, one time step per frame, and save a GIF.

    Each frame shows ``data[t]`` as a 1 x units image. All frames share one
    colour scale, which defaults to the overall min/max of ``data``. Previously
    each frame was autoscaled on its own, so the same colour meant different
    values in different frames.

    Args:
        data: 2-D array indexed as ``data[t]`` with ``data.shape[1]`` units.
        module: label used in frame titles and in the default filename.
        filename: output path; defaults to ``'anim_fire_' + module + '_.gif'``.
        interval: delay between frames in ms. 200 is matplotlib's
            ArtistAnimation default, which was used implicitly before.
        vmin, vmax: colour limits; ``None`` means the data's nan-aware min/max.

    Returns:
        0, kept for compatibility with existing callers.
    """
    if filename is None:
        filename = 'anim_fire_' + module + '_.gif'
    if vmin is None:
        vmin = np.nanmin(data)
    if vmax is None:
        vmax = np.nanmax(data)
    fig, ax = plt.subplots()
    frames = []
    n_units = data.shape[1]
    for t in range(data.shape[0]):
        img = ax.imshow(data[t].reshape(1, n_units), vmin=vmin, vmax=vmax)
        title = ax.text(0.5, 1.01, module + ' time={}'.format(t),
                        ha='center', va='bottom',
                        transform=ax.transAxes, fontsize='large')
        frames.append([img, title])
    ani = animation.ArtistAnimation(fig, frames, interval=interval)
    ani.save(filename, writer='pillow')
    plt.show()

    return 0

def mkOwnDataSet(data_size, filename, data_length=100, freq=60., noise=0.01):
    """Load a right/left trajectory pair from CSV and build noisy, subsampled copies.

    Reads ``filename + "_r.csv"`` and ``filename + "_l.csv"`` (comma-separated,
    first two columns used). It keeps every ``step``-th row, where
    ``step = round(rows / data_length)``, so that roughly ``data_length`` points
    remain. It then returns ``data_size`` copies of each trajectory with
    independent Gaussian noise added to every coordinate.

    Args:
        data_size: number of noisy copies per trajectory.
        filename: path prefix of the two CSV files.
        data_length: approximate number of points to keep per trajectory.
        freq: unused; kept for signature compatibility.
        noise: standard deviation of the additive Gaussian noise.

    Returns:
        ``(train_x, train_y)``, each a list of ``data_size`` lists of ``[x, y]`` points.
    """
    right = np.loadtxt(filename + "_r.csv", delimiter=',', ndmin=2)
    left = np.loadtxt(filename + "_l.csv", delimiter=',', ndmin=2)

    def _stride(points):
        # round() yields 0 when there are fewer than data_length / 2 rows, and
        # a zero step made np.arange raise ZeroDivisionError. Keep every row then.
        return max(1, round(points.shape[0] / data_length))

    right = right[::_stride(right), :2]
    left = left[::_stride(left), :2]

    def _jitter(points):
        # One N(0, noise) draw per coordinate, row by row (x then y).
        noisy = points + np.random.normal(loc=0.0, scale=noise, size=points.shape)
        return [list(row) for row in noisy]

    train_x = []
    train_y = []
    for _ in range(data_size):
        train_x.append(_jitter(right))
        train_y.append(_jitter(left))
    return train_x, train_y

def make_Ttraj(direction):
    """Build a synthetic out-and-back trajectory on a T-shaped track (20 points).

    The path goes up the stem from (0.5, 0) to (0.5, 0.5), out along one arm
    and back, then down the stem to (0.5, 0). Each of the four legs has 5 points.

    Args:
        direction: 1 for the arm at x = 0.75, 2 for the arm at x = 0.25.

    Returns:
        ndarray of shape (20, 2) with columns (x, y).

    Raises:
        ValueError: if ``direction`` is not 1 or 2. Previously this surfaced
            as an UnboundLocalError.
    """
    if direction == 1:
        arm_x = 0.75
    elif direction == 2:
        arm_x = 0.25
    else:
        raise ValueError("direction must be 1 or 2, got {!r}".format(direction))

    traj_x = np.concatenate([
        np.full(5, 0.5),               # up the stem
        np.linspace(0.5, arm_x, 5),    # out along the arm
        np.linspace(arm_x, 0.5, 5),    # back to the junction
        np.full(5, 0.5),               # down the stem
    ])
    traj_y = np.concatenate([
        np.linspace(0, 0.5, 5),
        np.full(5, 0.5),
        np.full(5, 0.5),
        np.linspace(0.5, 0, 5),
    ])
    return np.stack([traj_x, traj_y], axis=1)

def make_Htraj(direction1, direction2):
    """Build a synthetic out-and-back trajectory on an H-shaped track (40 points).

    The path does the following:
      - goes up the central stem from (0.5, 0) to (0.5, 0.5) in 10 points;
      - crosses to a side arm in 5 points;
      - makes a 10-point excursion along that arm and back;
      - returns to the stem in 5 points;
      - walks back down to (0.5, 0) in 10 points.

    Args:
        direction1: side arm. 1 means x = 0.75, 2 means x = 0.25.
        direction2: excursion along the arm. 1 goes to y = 0.25, 2 goes to y = 0.75.

    Returns:
        ndarray of shape (40, 2) with columns (x, y).

    Raises:
        ValueError: if either direction is not 1 or 2. Previously this
            surfaced as an UnboundLocalError.
    """
    if direction1 == 1:
        arm_x = 0.75
    elif direction1 == 2:
        arm_x = 0.25
    else:
        raise ValueError("direction1 must be 1 or 2, got {!r}".format(direction1))
    if direction2 == 1:
        turn_y = 0.25
    elif direction2 == 2:
        turn_y = 0.75
    else:
        raise ValueError("direction2 must be 1 or 2, got {!r}".format(direction2))

    traj_x = np.concatenate([
        np.full(10, 0.5),              # up the stem
        np.linspace(0.5, arm_x, 5),    # across to the arm
        np.full(5, arm_x),             # out along the arm
        np.full(5, arm_x),             # back along the arm
        np.linspace(arm_x, 0.5, 5),    # back to the stem
        np.full(10, 0.5),              # down the stem
    ])
    traj_y = np.concatenate([
        np.linspace(0, 0.5, 10),
        np.full(5, 0.5),
        np.linspace(0.5, turn_y, 5),
        np.linspace(turn_y, 0.5, 5),
        np.full(5, 0.5),
        np.linspace(0.5, 0, 10),
    ])
    return np.stack([traj_x, traj_y], axis=1)


def make_traj(delay_length, noise=0.005, stay_file="stay1.csv"):
    """Build the four task sequences: cue trajectory, delay, response trajectory.

    Each sequence has three parts:
      - one H-maze trajectory;
      - ``delay_length`` blocks of 40 points, each drawn at random (with
        replacement) from the rows of ``stay_file``;
      - a second H-maze trajectory.

    Trajectory ids map to make_Htraj calls: 1 = (1, 1), 2 = (1, 2),
    3 = (2, 1), 4 = (2, 2). The (first, last) pairs are (1, 4), (2, 3),
    (3, 1), (4, 2). An alternative ordering, (1, 2), (2, 3), (3, 4), (4, 1),
    existed as commented-out "ver2" code.

    Args:
        delay_length: number of 40-point stay blocks between the two trajectories.
        noise: std of Gaussian noise added to every trajectory point. The stay
            blocks are not noised here.
        stay_file: comma-separated file whose first two columns are the
            positions sampled for the delay.

    Returns:
        Tuple of four (n_points, 2) float arrays, in the pair order above.
    """
    stay_length = 40  # points per delay block
    templates = {
        1: make_Htraj(1, 1),
        2: make_Htraj(1, 2),
        3: make_Htraj(2, 1),
        4: make_Htraj(2, 2),
    }
    z = np.loadtxt(stay_file, delimiter=',')

    pairs = [(1, 4), (2, 3), (3, 1), (4, 2)]
    sequences = []
    for first, last in pairs:
        order = [first] + [0] * delay_length + [last]
        segments = []
        for idx in order:
            if idx == 0:
                rows = np.random.choice(z.shape[0], stay_length)
                segments.append(z[rows, :2])
            else:
                template = templates[idx]
                # Add noise to a fresh array. The old ``target += noise``
                # modified the shared template in place, so every later use of
                # the same trajectory carried the noise of all earlier uses.
                segments.append(template + np.random.normal(loc=0.0, scale=noise,
                                                            size=template.shape))
        sequences.append(np.concatenate(segments, axis=0))
    return tuple(sequences)

def mkOwnDataSet_auto(data_size, delay_length, freq=60., noise=0.01):
    """Make ``data_size`` noisy copies of each of the four sequences from make_traj.

    Args:
        data_size: number of copies per sequence.
        delay_length: forwarded to make_traj.
        freq: unused; kept for signature compatibility.
        noise: std of the Gaussian noise drawn independently for every coordinate.

    Returns:
        Four lists ``(train_x1, train_x2, train_y1, train_y2)``. Each holds
        ``data_size`` lists of ``[x, y]`` points.
    """
    def jittered(points):
        # Draw order: row by row, x coordinate before y.
        return [[row[0] + np.random.normal(loc=0.0, scale=noise),
                 row[1] + np.random.normal(loc=0.0, scale=noise)]
                for row in points]

    sequences = make_traj(delay_length)
    copies = ([], [], [], [])
    for _ in range(data_size):
        for seq, bucket in zip(sequences, copies):
            bucket.append(jittered(seq))
    return copies

def mkOwnRandomBatch(train_x, train_t, batch_size=10):
    """Sample a random mini-batch of sequences (with replacement) from ``train_x``.

    Args:
        train_x: list of sequences; each is a list of equal-size feature vectors.
            All sequences are assumed to have the same length.
        train_t: unused; kept for signature compatibility. Callers pass
            train_x again.
        batch_size: number of sequences to draw.

    Returns:
        Tensor of shape (seq_len, batch_size, n_features), i.e. time-major.
        Only the inputs are returned, despite the train_t argument.

    Raises:
        ValueError: if ``train_x`` is empty.
    """
    # Draw from the whole set. The old upper bound ``len(train_x) - batch_size``
    # never picked the last ``batch_size`` sequences, and raised ValueError
    # whenever the set held no more than ``batch_size`` sequences.
    batch_x = [train_x[np.random.randint(0, len(train_x))] for _ in range(batch_size)]
    return torch.tensor(batch_x).transpose(0, 1)


def plot_distance_bet2traj(traj1, traj2, linelist):
    """Plot the step-wise Euclidean distance between two trajectories.

    The distance at step k is ``norm(traj1[k] - traj2[k])`` for every row of
    ``traj1``. Vertical lines are drawn at ``linelist``, spanning the range of
    the distances.

    Returns:
        The list of per-step distances.
    """
    first = np.array(traj1)
    second = np.array(traj2)
    gaps = [np.linalg.norm(first[step] - second[step]) for step in range(first.shape[0])]
    plt.figure()
    plt.plot(gaps)

    plt.vlines(linelist, np.min(gaps), np.max(gaps))

    return gaps

def distance_bet2traj(traj1, traj2):
    """Return ``[norm(traj1[k] - traj2[k]) for each row k of traj1]``.

    Both inputs are converted with ``np.array``. ``traj2`` must have at least
    as many rows as ``traj1``.
    """
    first, second = np.array(traj1), np.array(traj2)
    return [np.linalg.norm(first[k] - second[k]) for k in range(first.shape[0])]

def distance_bet2traj_ave(traj1, traj2):
    """Average local cross-distance between two trajectories, skipping the endpoints.

    For every interior step ``i`` (1 <= i <= len(traj1) - 2), the value is the
    mean of six distances:
      - ``traj1[i]`` against ``traj2[i-1]``, ``traj2[i]`` and ``traj2[i+1]``;
      - ``traj2[i]`` against ``traj1[i-1]``, ``traj1[i]`` and ``traj1[i+1]``.
    The centre pair is therefore counted twice.

    Returns:
        List with ``len(traj1) - 2`` values (empty for fewer than 3 steps).
    """
    first = np.array(traj1)
    second = np.array(traj2)
    result = []
    for i in range(1, first.shape[0] - 1):
        window = (i - 1, i, i + 1)
        forward = sum(np.linalg.norm(first[i] - second[j]) for j in window)
        backward = sum(np.linalg.norm(first[j] - second[i]) for j in window)
        # Both sums are needed. The second one used to overwrite the first,
        # while the result was still divided by 6.
        result.append((forward + backward) / 6)
    return result

def plot_activity_bet2traj(traj1, traj2, linelist, threshold=0.5, window=(20, 105)):
    """Plot the per-unit absolute difference between two state trajectories.

    A unit is drawn (and its index printed) only if its difference exceeds
    ``threshold`` somewhere inside ``window``. Both values were previously
    hard-coded; the defaults keep the old behaviour.

    Args:
        traj1, traj2: state sequences indexed as ``traj[step]``.
        linelist: unused; accepted for signature parity with plot_distance_bet2traj.
        threshold: minimum absolute difference for a unit to be plotted.
        window: ``(start, stop)`` slice of time steps searched for the threshold.

    Returns:
        List of per-step absolute difference arrays.
    """
    first = np.array(traj1)
    second = np.array(traj2)
    result = [np.abs(first[i] - second[i]) for i in range(first.shape[0])]
    start, stop = window

    plt.figure()
    for i, data in enumerate(np.array(result).T):
        if np.any(data[start:stop] > threshold):
            print(i)
            plt.plot(data)

    return result

def search_delay(traj):
    """Return the steps at which ``traj[:, 1]`` crosses the 0.2 level.

    Crossings alternate, starting with an upward one (value > 0.2), then a
    downward one (value < 0.2), and so on. A value exactly equal to 0.2 never
    counts as a crossing.

    Returns:
        float64 array of step indices.
    """
    crossings = []
    above = False
    for step, y in enumerate(traj[:, 1]):
        if not above and y > 0.2:
            crossings.append(step)
            above = True
        elif above and y < 0.2:
            crossings.append(step)
            above = False
    return np.array(crossings, dtype=float)

def pick_delay(traj, states):
    """Cut ``states`` into the stretches where ``traj[:, 1]`` lies below 0.1.

    Crossings of the 0.1 level are found from step 5 on, alternating:
    a downward crossing (< 0.1), then an upward one (> 0.1), and so on.
    Each (down, up) pair gives ``states[down:up]``, and its bounds are printed.
    Finally ``states[last_crossing:]`` is appended, or all of ``states`` if
    there was no crossing.

    NOTE(review): when the last crossing is an upward one, that trailing
    segment lies above the threshold rather than in a delay. Confirm this is
    intended.
    """
    marks = []
    below = False
    for step in range(5, traj.shape[0]):
        y = traj[step, 1]
        if not below and y < 0.1:
            marks.append(step)
            below = True
        elif below and y > 0.1:
            marks.append(step)
            below = False

    segments = []
    for begin, end in zip(marks[0::2], marks[1::2]):
        segments.append(states[begin:end])
        print(begin, end)
    tail_start = marks[-1] if marks else 0
    segments.append(states[tail_start:])
    return segments

def pick_traj(traj, states):
    """Cut ``states`` into cycles that end at each upward crossing of 0.1.

    Crossings of ``traj[:, 1]`` through 0.1 are found from step 5 on,
    alternating down (< 0.1) and up (> 0.1). The upward crossings, preceded
    by 0, are the cut points. Each consecutive pair gives one segment, and its
    bounds are printed. The remainder after the last cut point is appended
    at the end.
    """
    marks = []
    below = False
    for step in range(5, traj.shape[0]):
        y = traj[step, 1]
        if not below and y < 0.1:
            marks.append(step)
            below = True
        elif below and y > 0.1:
            marks.append(step)
            below = False

    cuts = [0] + marks[1::2]
    segments = []
    for begin, end in zip(cuts, cuts[1:]):
        segments.append(states[begin:end])
        print(begin, end)
    segments.append(states[cuts[-1]:])
    return segments

def match_length(states_list):
    """Trim every sequence to the length of the shortest one, keeping its last steps.

    Args:
        states_list: sequences supporting ``len`` and slicing.

    Returns:
        List of trimmed sequences, all of the same length (empty for empty input).
    """
    if len(states_list) == 0:
        return []
    # No arbitrary cap: the old starting value of 1000 silently truncated
    # longer inputs.
    shortest = min(len(state) for state in states_list)
    # Slice from an explicit start index. ``state[-0:]`` would return the whole
    # sequence when the shortest one is empty.
    return [state[len(state) - shortest:] for state in states_list]

def moving_average(data, window=5, mode='full'):
    """Smooth ``data`` with a uniform (boxcar) kernel of length ``window``.

    Args:
        data: 1-D sequence.
        window: kernel length (default 5, as before).
        mode: forwarded to ``np.convolve``.
            - ``'full'`` (the default, as before) returns ``len(data) + window - 1``
              values; the first and last ``window - 1`` are averaged with
              implicit zeros.
            - ``'valid'`` keeps only fully-overlapping windows.
            - ``'same'`` returns ``max(len(data), window)`` values.

    Returns:
        ndarray of smoothed values.
    """
    kernel = np.ones(window) / window
    return np.convolve(data, kernel, mode=mode)
    
    
    
class MyLSTM(nn.Module):
    """Three coupled LSTM cells (PFC, HPC, Re) with a linear read-out from HPC.

    One ``forward`` call advances every cell by one step. Each cell reads the
    hidden states passed in, i.e. the previous step's:
      * Re  <- [PFC h, HPC h]
      * HPC <- [external input, Re h]
      * PFC <- [HPC h, Re h]
    The output is ``linear(new HPC h)``.

    Hidden-state layout: ``[(h, c)_PFC, (h, c)_HPC, (h, c)_Re]``. Each tensor
    has shape (batch_size, hidden_size).
    """

    def __init__(self, input_size, hidden_size, output_size, batch_size, sparsity=1):
        """
        Args:
            input_size: size of the external input fed to HPC.
            hidden_size: hidden size shared by all three cells.
            output_size: size of the linear read-out.
            batch_size: batch size used by initHidden.
            sparsity: fraction of each column zeroed by ``nn.init.sparse_`` in the
                PFC/HPC input and recurrent weight matrices.

        NOTE(review): the default 1 is the previously hard-coded value. It zeroes
        every entry of those four matrices, so at initialisation PFC and HPC are
        driven only by their biases. MyLSTM_RNN uses 0.5 by default. Confirm
        that 1 is intended before training new models.
        """
        super(MyLSTM, self).__init__()

        self.hidden_size = hidden_size
        # Attribute names are state_dict keys, and creation order fixes the
        # order of parameters().
        self.PFC = nn.LSTMCell(hidden_size*2, hidden_size)
        self.HPC = nn.LSTMCell(input_size+hidden_size, hidden_size)
        self.Re = nn.LSTMCell(hidden_size*2, hidden_size)
        self.linear = nn.Linear(hidden_size, output_size)
        self.batch_size = batch_size
        for weight in (self.PFC.weight_ih, self.HPC.weight_ih,
                       self.PFC.weight_hh, self.HPC.weight_hh):
            nn.init.sparse_(weight.data, sparsity)

    def forward(self, input, hiddens):
        """Advance all cells one step; return ``(output, [PFC, HPC, Re] hiddens)``."""
        Re_input = torch.cat([hiddens[0][0],hiddens[1][0]],dim=1)
        Re_hidden = self.Re(Re_input, hiddens[2])
        HPC_input = torch.cat([input,hiddens[2][0]],dim=1)
        HPC_hidden = self.HPC(HPC_input, hiddens[1])
        PFC_input = torch.cat([hiddens[1][0],hiddens[2][0]],dim=1)
        PFC_hidden = self.PFC(PFC_input, hiddens[0])
        output = self.linear(HPC_hidden[0])
        return output, [PFC_hidden, HPC_hidden, Re_hidden]

    def initHidden(self):
        """Return all-zero hidden states ``[(h, c)_PFC, (h, c)_HPC, (h, c)_Re]``."""
        HPC_hidden = [torch.zeros(self.batch_size, self.hidden_size), torch.zeros(self.batch_size, self.hidden_size)]
        PFC_hidden = [torch.zeros(self.batch_size, self.hidden_size), torch.zeros(self.batch_size, self.hidden_size)]
        Re_hidden = [torch.zeros(self.batch_size, self.hidden_size), torch.zeros(self.batch_size, self.hidden_size)]
        return [PFC_hidden,HPC_hidden,Re_hidden]

class MyLSTM_RNN(nn.Module):
    """PFC and HPC LSTM cells coupled through a vanilla-RNN "Re" cell.

    One ``forward`` call advances every cell by one step. Each cell reads the
    hidden states passed in, i.e. the previous step's:
      * Re  (nn.RNNCell)  <- [PFC h, HPC h]
      * HPC (nn.LSTMCell) <- [external input, Re h]
      * PFC (nn.LSTMCell) <- [HPC h, Re h]
    The output is ``linear(new HPC h)``.

    Hidden-state layout: ``[(h, c)_PFC, (h, c)_HPC, h_Re]``.
    """

    # Hidden units perturbed by noiseHidden_select / noiseHidden_dis when no
    # ``index`` is given. These only exist when hidden_size > 17.
    DEFAULT_NOISE_INDEX = (1, 2, 8, 9, 17)

    def __init__(self, input_size, hidden_size, output_size, batch_size, sparse=5):
        """
        Args:
            input_size: size of the external input fed to HPC.
            hidden_size: hidden size of each of PFC, HPC and Re.
            output_size: size of the linear read-out.
            batch_size: batch size used by the initHidden* / noise* helpers.
            sparse: sparsity level in tenths. For example, 5 makes
                ``nn.init.sparse_`` zero 50 % of each column of the PFC/HPC
                input and recurrent weight matrices.
        """
        super(MyLSTM_RNN, self).__init__()

        self.hidden_size_PFC = hidden_size
        self.hidden_size_HPC = hidden_size
        self.hidden_size_Re = hidden_size
        # Attribute names are state_dict keys of saved checkpoints, and creation
        # order fixes the order of parameters(). Keep both unchanged.
        self.PFC = nn.LSTMCell(self.hidden_size_HPC+self.hidden_size_Re, self.hidden_size_PFC)
        self.HPC = nn.LSTMCell(input_size+self.hidden_size_Re, self.hidden_size_HPC)
        self.Re = nn.RNNCell(self.hidden_size_HPC+self.hidden_size_PFC, self.hidden_size_Re)
        self.linear = nn.Linear(self.hidden_size_HPC, output_size)
        self.batch_size = batch_size
        sparsity = 0.1*int(sparse)
        for weight in (self.PFC.weight_ih, self.HPC.weight_ih,
                       self.PFC.weight_hh, self.HPC.weight_hh):
            nn.init.sparse_(weight.data, sparsity)

    def forward(self, input, hiddens):
        """Advance all cells one step; return ``(output, [PFC, HPC, Re] hiddens)``."""
        Re_input = torch.cat([hiddens[0][0],hiddens[1][0]],dim=1)
        Re_hidden = self.Re(Re_input, hiddens[2])
        HPC_input = torch.cat([input,hiddens[2]],dim=1)
        HPC_hidden = self.HPC(HPC_input, hiddens[1])
        PFC_input = torch.cat([hiddens[1][0],hiddens[2]],dim=1)
        PFC_hidden = self.PFC(PFC_input, hiddens[0])
        output = self.linear(HPC_hidden[0])
        return output, [PFC_hidden, HPC_hidden, Re_hidden]

    def initHidden(self):
        """Return all-zero hidden states ``[(h, c)_PFC, (h, c)_HPC, h_Re]``."""
        # Use the per-module sizes. This class never sets ``self.hidden_size``,
        # so the old version raised AttributeError.
        HPC_hidden = [torch.zeros(self.batch_size, self.hidden_size_HPC), torch.zeros(self.batch_size, self.hidden_size_HPC)]
        PFC_hidden = [torch.zeros(self.batch_size, self.hidden_size_PFC), torch.zeros(self.batch_size, self.hidden_size_PFC)]
        Re_hidden = torch.zeros(self.batch_size, self.hidden_size_Re)
        return [PFC_hidden,HPC_hidden,Re_hidden]

    def initHidden_rand(self):
        """Return hidden states drawn uniformly from [0, 0.1)."""
        const = 0.1
        var = 0.1
        HPC_hidden = [torch.rand(self.batch_size, self.hidden_size_HPC)*const, torch.rand(self.batch_size, self.hidden_size_HPC)*const]
        PFC_hidden = [torch.rand(self.batch_size, self.hidden_size_PFC)*var, torch.rand(self.batch_size, self.hidden_size_PFC)*var]
        Re_hidden = torch.rand(self.batch_size, self.hidden_size_Re)*const
        return [PFC_hidden,HPC_hidden,Re_hidden]

    def initHidden_test(self):
        """Return zero PFC/HPC states and an Re state uniform in [0, 0.01)."""
        const = 0
        var = 0.01
        HPC_hidden = [torch.ones(self.batch_size, self.hidden_size_HPC)*const, torch.ones(self.batch_size, self.hidden_size_HPC)*const]
        PFC_hidden = [torch.ones(self.batch_size, self.hidden_size_PFC)*const, torch.ones(self.batch_size, self.hidden_size_PFC)*const]
        Re_hidden = torch.rand(self.batch_size, self.hidden_size_Re)*var
        return [PFC_hidden,HPC_hidden,Re_hidden]

    def noiseHidden_rand(self, hiddens):
        """Add Gaussian noise (std 0.02) to the Re state in place and return ``hiddens``.

        The PFC/HPC h noise is scaled by 0, but it is still drawn, which keeps
        the torch RNG stream unchanged.
        """
        c = 0
        v = 0.02
        HPC_hidden = torch.randn(self.batch_size, self.hidden_size_HPC)*c
        hiddens[1][0].add_(HPC_hidden)
        PFC_hidden = torch.randn(self.batch_size, self.hidden_size_PFC)*c
        hiddens[0][0].add_(PFC_hidden)
        Re_hidden = torch.randn(self.batch_size, self.hidden_size_Re)*v
        hiddens[2].add_(Re_hidden)
        return hiddens

    def noiseHidden_select(self, hiddens, index=None):
        """Add Gaussian noise (std 0.01) to selected Re units in place; return ``hiddens``.

        Args:
            hiddens: hidden-state list as returned by forward / initHidden*.
            index: unit indices to perturb; defaults to DEFAULT_NOISE_INDEX.
        """
        c = 0
        v = 0.01
        index = np.asarray(self.DEFAULT_NOISE_INDEX if index is None else index)
        HPC_hidden = torch.zeros(self.batch_size, self.hidden_size_HPC)
        HPC_hidden[:,index] += torch.randn(self.batch_size, index.size)*c
        hiddens[1][0].add_(HPC_hidden)
        PFC_hidden = torch.zeros(self.batch_size, self.hidden_size_PFC)
        PFC_hidden[:,index] += torch.randn(self.batch_size, index.size)*c
        hiddens[0][0].add_(PFC_hidden)
        Re_hidden = torch.zeros(self.batch_size, self.hidden_size_Re)
        Re_hidden[:,index] += torch.randn(self.batch_size, index.size)*v
        hiddens[2].add_(Re_hidden)
        return hiddens

    def noiseHidden_dis(self, hiddens, statr, statl, index=None):
        """Shift the Re state by ``(statl - statr) * -0.1`` in place; return ``hiddens``.

        Args:
            hiddens: hidden-state list as returned by forward / initHidden*.
            statr, statl: Re states (numpy arrays or tensors) broadcastable to
                (batch_size, hidden_size_Re).
            index: units used for the (zero-scaled) PFC/HPC perturbation;
                defaults to DEFAULT_NOISE_INDEX.
        """
        c = 0
        v = -0.1
        index = np.asarray(self.DEFAULT_NOISE_INDEX if index is None else index)
        HPC_hidden = torch.zeros(self.batch_size, self.hidden_size_HPC)
        HPC_hidden[:,index] += torch.randn(self.batch_size, index.size)*c
        hiddens[1][0].add_(HPC_hidden)
        PFC_hidden = torch.zeros(self.batch_size, self.hidden_size_PFC)
        PFC_hidden[:,index] += torch.randn(self.batch_size, index.size)*c
        hiddens[0][0].add_(PFC_hidden)
        Re_hidden = torch.randn(self.batch_size, self.hidden_size_Re)*0.0
        # Convert explicitly: adding a numpy array to a tensor in place fails
        # (callers pass rows of np.array(Restate)), and the dtype must match.
        Re_hidden += torch.as_tensor((-statr+statl)*v, dtype=Re_hidden.dtype)
        hiddens[2].add_(Re_hidden)
        return hiddens


def main():
    training_size = 100
    test_size = 1000
    epochs_num = 10
    hidden_size = 20
    batch_size = 10
    data_length = 100
    inputsize = 2
    outputsize = 2
    delay_length = 2
#     model_path = 'model/ReModel_interRNN_long131test_s7_100_1.pth'
#     model_path = 'model/ReModel_interRNNrand_AddRe_OUT5_161_s5_100_1.pth'
#     model_path = 'model/ReModel_L2_interRNNrand_AddRe_OUT5_1212121_s6_100_1_2.pth'
#     model_path = 'model/ReModel_L2_interRNNrand_AddRe_OUT5_H2v121_s9_100_1.pth'
#     model_path = 'model/ReModel_L2_interRNNrand_Reinh_AddRe_OUT5_H121_s2_100_1_2.pth'
#     model_path = 'model/ReModel_L2_interRNNrand_AddRe_OUT5_H121_s3_100_2.pth'
#     model_path = "model/ReModel_L2_interRNNrand_Reinh_AddRe_OUT5_11311_s6_100_1.pth"
    model_path = 'model/R20_H/ReModel_L2_interRNNrand_OUT1_121H_s5_100_3_epoch90.pth'

    filename = "primal_long131test"
    
    train_x,train_y = mkOwnDataSet(training_size,filename,data_length)
    
    colors = ['C0','C1','C2','C3','C4','C9']


    
    train_x1,train_x2,train_y1,train_y2 = mkOwnDataSet_auto(training_size,delay_length)

    rnn = MyLSTM_RNN(inputsize, hidden_size, outputsize, batch_size)

    rnn.load_state_dict(torch.load(model_path))
    
    pattern = 1
    if  pattern == 1:
        data = mkOwnRandomBatch(train_x1, train_x1, batch_size)
    elif pattern == 2:
        data = mkOwnRandomBatch(train_x2, train_x2, batch_size)
    elif pattern == 3:
        data = mkOwnRandomBatch(train_y1, train_y1, batch_size)                
    else:
        data = mkOwnRandomBatch(train_y2, train_y2, batch_size)
    
    
    params = []
    for p in rnn.parameters():
        params.append(p.data)
     
    traj = []
    PFCstate = []
    HPCstate = []
    Restate = []
    Gate_states = []
    hidden = rnn.initHidden_test()
    data_limit = 20
    est_length = 5
    init_point = torch.rand(10,2)*1
    for k in range(data_limit):
            #print(data[k].shape)
            output,hidden = rnn(data[k],hidden)
#             output,hidden = rnn(init_point,hidden)
            traj.append(output.tolist())
            PFCstate.append(hidden[0][0].tolist())
            HPCstate.append(hidden[1][0].tolist())
            Restate.append(hidden[2][0].tolist())
            #Gate_states.append(Culc_gate(output,params,hidden))
    for k in range(data.shape[0]*est_length):
            output,hidden = rnn(output,hidden)
            traj.append(output.tolist())
            PFCstate.append(hidden[0][0].tolist())
            HPCstate.append(hidden[1][0].tolist())
            Restate.append(hidden[2][0].tolist())
            #Gate_states.append(Culc_gate(output,params,hidden))
            
    hidden = rnn.initHidden_test()
    pattern = 2
    if  pattern == 1:
        data = mkOwnRandomBatch(train_x1, train_x1, batch_size)
    elif pattern == 2:
        data = mkOwnRandomBatch(train_x2, train_x2, batch_size)
    elif pattern == 3:
        data = mkOwnRandomBatch(train_y1, train_y1, batch_size)                
    else:
        data = mkOwnRandomBatch(train_y2, train_y2, batch_size)
    init_point = init_point*1
    for k in range(data_limit):
            #print(data[k].shape)
            output,hidden = rnn(data[k],hidden)
#             output,hidden = rnn(init_point,hidden)
            traj.append(output.tolist())
            PFCstate.append(hidden[0][0].tolist())
            HPCstate.append(hidden[1][0].tolist())
            Restate.append(hidden[2][0].tolist())
            #Gate_states.append(Culc_gate(output,params,hidden))
    for k in range(data.shape[0]*est_length):
#             hidden = rnn.noiseHidden_rand(hidden)
            output,hidden = rnn(output,hidden) 
            traj.append(output.tolist())
            PFCstate.append(hidden[0][0].tolist())
            HPCstate.append(hidden[1][0].tolist())
            Restate.append(hidden[2][0].tolist())
            #Gate_states.append(Culc_gate(output,params,hidden))
    traj = torch.tensor(traj)
    pltdata = torch.squeeze(data).numpy()
    traj = torch.squeeze(traj).numpy()
    
    dividenum = int(np.array(PFCstate)[0:,0].shape[0]/2)
    traj_noise = []
    PFCstate_noise = []
    HPCstate_noise = []
    Restate_noise = []
    hidden = rnn.initHidden_test()
    pattern = 1
    if  pattern == 1:
        data = mkOwnRandomBatch(train_x1, train_x1, batch_size)
    elif pattern == 2:
        data = mkOwnRandomBatch(train_x2, train_x2, batch_size)
    elif pattern == 3:
        data = mkOwnRandomBatch(train_y1, train_y1, batch_size)                
    else:
        data = mkOwnRandomBatch(train_y2, train_y2, batch_size)
    init_point = init_point*1
    for k in range(data_limit):
            #print(data[k].shape)
            output,hidden = rnn(data[k],hidden)
#             output,hidden = rnn(init_point,hidden)
            traj_noise.append(output.tolist())
            PFCstate_noise.append(hidden[0][0].tolist())
            HPCstate_noise.append(hidden[1][0].tolist())
            Restate_noise.append(hidden[2][0].tolist())
            #Gate_states.append(Culc_gate(output,params,hidden))
    for k in range(data.shape[0]*est_length):
#             hidden = rnn.noiseHidden_rand(hidden)
            output,hidden = rnn(output,hidden) 
#             hidden = rnn.noiseHidden_dis(hidden, np.array(Restate)[k+data_limit], np.array(Restate)[dividenum+k+data_limit])
#             hidden = rnn.noiseHidden_rand(hidden)
            traj_noise.append(output.tolist())
            PFCstate_noise.append(hidden[0][0].tolist())
            HPCstate_noise.append(hidden[1][0].tolist())
            Restate_noise.append(hidden[2][0].tolist())
            #Gate_states.append(Culc_gate(output,params,hidden))
    traj_noise = torch.tensor(traj_noise)
    traj_noise = torch.squeeze(traj_noise).numpy()

    
    print(np.array(PFCstate)[:,0].shape)
#     MakeAnimation(pltdata[:,0,0],pltdata[:,0,1], traj[:,0,0], traj[:,0,1], data_limit)
    #MakeAnimation_img(np.array(PFCstate)[:,0],"PFC")
    #MakeAnimation_img(np.array(HPCstate)[:,0],"HPC")

    for n, p in rnn.named_parameters():
            if n == "PFC.weight_ih":
                PFC_w = np.array(p.data)
            if n == "HPC.weight_ih":
                HPC_w = np.array(p.data)
    
    pca = PCA()
    dfs = np.array(HPCstate)[0:,0]
    pca.fit(dfs)
    feature = pca.transform(dfs)
    
    dividenum = int(np.array(feature).shape[0]/2)
    
#     MakeAnimation(traj[:dividenum,0,0],traj[:dividenum,0,1], traj[dividenum:,0,0], traj[dividenum:,0,1], data_limit)
#     MakeAnimation(traj[:dividenum,0,0],traj[:dividenum,0,1], traj_noise[:,0,0], traj_noise[:,0,1], data_limit)
#     MakeAnimation2(traj[:dividenum,0,0],traj[:dividenum,0,1], traj_noise[:,0,0], traj_noise[:,0,1], data_limit)
#     MakeAnimation_img(np.array(PFCstate)[:dividenum,0],"PFCr")
#     MakeAnimation_img(np.array(PFCstate)[dividenum:,0],"PFCl")
    
    fig = plt.figure()
#     plt.plot(pltdata[:,0,0],pltdata[:,0,1],"--")
    plt.plot(traj[:dividenum,0,0],traj[:dividenum,0,1])
#     plt.plot(traj[dividenum:,0,0],traj[dividenum:,0,1])
    plt.show()
    
    linelist = search_delay(traj[:dividenum,0])
    linelist2 = search_delay(traj[dividenum:,0])
    
    pred = KMeans(n_clusters=2).fit_predict(feature)
    
    fig3d = plt.figure()
    ax3d = Axes3D(fig3d)
    ax3d.plot(feature[:dividenum, 0], feature[:dividenum, 1], feature[:dividenum, 2], alpha=0.8)
    ax3d.plot(feature[dividenum:, 0], feature[dividenum:, 1], feature[dividenum:, 2], alpha=0.8)
    ax3d.plot(feature[0:1, 0], feature[0:1, 1], feature[0:1, 2],"o", alpha=1)
    plt.show()
#     for i in range(20):
#         fig_place = plt.figure()
#         place3d = Axes3D(fig_place)
#         place3d.plot(traj[:,0,0], traj[:,0,1], np.array(HPCstate)[:,0,i])
#         plt.show()
    
    plot_distance(feature[:dividenum],feature[dividenum:])
    HPC_dis = overlap_coefficient(feature[:dividenum],feature[dividenum:])
    
    HPC_samedis = plot_distance_bet2traj(np.array(HPCstate)[:dividenum,0],np.array(HPCstate)[dividenum:,0],linelist)
    
    pca = PCA()
    dfs = np.array(PFCstate)[0:,0]
    pca.fit(dfs)
    feature = pca.transform(dfs)
    
    pred = KMeans(n_clusters=2).fit_predict(feature)
    
    fig3d = plt.figure()
    ax3d = Axes3D(fig3d)
#     ax3d.plot(feature[:dividenum, 0], feature[:dividenum, 1], feature[:dividenum, 2], alpha=0.8)
    ax3d.plot(feature[dividenum:, 0], feature[dividenum:, 1], feature[dividenum:, 2], alpha=0.8)
#     ax3d.plot(feature[0:1, 0], feature[0:1, 1], feature[0:1, 2],"o", alpha=1)
#     ax3d.plot(feature[20:21, 0], feature[20:21, 1], feature[20:21, 2],"o", alpha=1)
    plt.show()
#     for i in range(20):
#         fig_place = plt.figure()
#         place3d = Axes3D(fig_place)
#         place3d.plot(traj[:,0,0], traj[:,0,1], np.array(PFCstate)[:,0,i])
#         plt.show()

# #   ### for 3d divided plot ###
#     fig3d = plt.figure()
#     ax3d = Axes3D(fig3d)
#     end = 0
#     for i in range(5):
#         start = end
#         end = int(linelist[i*2+1])
#         ax3d.plot(feature[start:start+1, 0], feature[start:start+1, 1], feature[start:start+1, 2], "o", color=colors[i], alpha=0.8)
#         ax3d.plot(feature[start+20:start+21, 0], feature[start+20:start+21, 1], feature[start+20:start+21, 2], "o", color=colors[i], alpha=0.8)
#         ax3d.plot(feature[start+40:start+41, 0], feature[start+40:start+41, 1], feature[start+40:start+41, 2], "o", color=colors[i], alpha=0.8)
#         ax3d.plot(feature[start:end, 0], feature[start:end, 1], feature[start:end, 2], color=colors[i], alpha=0.8)
#     plt.show()
#     print(linelist)

#     plot_distance(feature[:dividenum],feature[dividenum:])
#     PFC_dis = overlap_coefficient(feature[:dividenum],feature[dividenum:])
    
#     PFC_samedis = plot_distance_bet2traj(np.array(PFCstate)[:dividenum,0],np.array(PFCstate)[dividenum:,0],linelist)

#     pca = PCA()
#     dfs = np.array(Restate)
#     pca.fit(dfs)
#     feature = pca.transform(dfs)
    
#     pred = KMeans(n_clusters=2).fit_predict(feature)
    
#     fig3d = plt.figure()
#     ax3d = Axes3D(fig3d)
#     ax3d.plot(feature[:dividenum, 0], feature[:dividenum, 1], feature[:dividenum, 2], alpha=0.8)
#     ax3d.plot(feature[dividenum:, 0], feature[dividenum:, 1], feature[dividenum:, 2], alpha=0.8)
#     ax3d.plot(feature[0:1, 0], feature[0:1, 1], feature[0:1, 2],"o", alpha=1)
#     plt.show()

#     plot_distance(feature[:dividenum],feature[dividenum:])
#     Re_dis = overlap_coefficient(feature[:dividenum],feature[dividenum:])
    
# #     MakeAnimation(feature[:dividenum, 0], feature[:dividenum, 1], feature[dividenum:, 0], feature[dividenum:, 1], data_limit)

#     Re_samedis = plot_distance_bet2traj(Restate[:dividenum],Restate[dividenum:],linelist)
    
    
    
#     plt.figure()
#     plt.plot(HPC_dis)
#     plt.plot(PFC_dis)
#     plt.plot(Re_dis)
#     plt.vlines(linelist,0,1)
#     plt.show()
    
#     plt.figure()
#     plt.plot(HPC_samedis)
#     plt.plot(PFC_samedis)
#     plt.plot(Re_samedis)
#     plt.vlines(linelist,0,3)
#     plt.show()
    
#     Re_noisedisr = plot_distance_bet2traj(Restate[:dividenum],Restate_noise[:],linelist)
#     Re_noisedisl = plot_distance_bet2traj(Restate[dividenum:],Restate_noise[:],linelist)
#     PFC_noisedisr = plot_distance_bet2traj(np.array(PFCstate)[:dividenum,0],np.array(PFCstate_noise)[:,0],linelist)
#     PFC_noisedisl = plot_distance_bet2traj(np.array(PFCstate)[dividenum:,0],np.array(PFCstate_noise)[:,0],linelist)
#     plt.figure()
#     plt.plot(Re_noisedisr)
#     plt.plot(Re_noisedisl)
#     plt.vlines(linelist,0,3)
#     plt.show()
    
#     plt.figure()
#     plt.plot(PFC_noisedisr)
#     plt.plot(PFC_noisedisl)
#     plt.vlines(linelist,0,3)
#     plt.show()
    
#     plot_activity_bet2traj(Restate[dividenum:],Restate[:dividenum],linelist)
# #     plot_activity_bet2traj(np.array(PFCstate)[:dividenum,0],np.array(PFCstate)[dividenum:,0],linelist)

#     points = np.array(np.where(traj_noise[:,0,1]<0.1)[0]).astype("int64")
#     Replot = plt.figure()
#     axre2 = Replot.add_subplot(111)
#     axre2.imshow(np.corrcoef(np.array(np.array(Restate_noise)[points])))
#     axre2.set_title("Re")  
    
#     pca = PCA()
#     dfs = np.array(np.array(Restate_noise)[points])
# #     dfs = np.array(np.array(Restate_noise))
# #     dfs = np.array(HPCstate)[0:,0]
#     pca.fit(dfs)
#     feature = pca.transform(dfs)    
    
# #     delays = pick_delay(traj_noise[:,0], np.array(PFCstate_noise)[:,0])
# #     delays = pick_delay(traj_noise[:,0], np.array(Restate_noise))
# #     delays = pick_traj(traj_noise[:,0], np.array(PFCstate_noise)[:,0])
#     delays = pick_traj(traj_noise[:,0], np.array(Restate_noise))
#     linelist_delay = [0]
#     for i in delays:
#         print(np.array(i).shape)
#         linelist_delay.append(linelist_delay[-1]+len(i))
    
#     fig = plt.figure()
#     plt.plot(feature[:dividenum, 0], alpha=0.8)
#     plt.vlines(linelist_delay,0,3)
#     plt.show()
    
#     fig = plt.figure()
#     plt.plot(feature[:dividenum, 1], alpha=0.8)
#     plt.show()
    
#     fig = plt.figure()
#     plt.plot(feature[:dividenum, 2], alpha=0.8)
#     plt.show()
    
#     fig = plt.figure()
#     plt.plot(feature[:dividenum, 0]+feature[:dividenum, 1]+feature[:dividenum, 2], alpha=0.8)
# #     plt.vlines(linelist,0,3)
#     plt.show()

#     print(np.array(delays[:-1]).shape)
#     tskm = TimeSeriesKMeans(n_clusters=3,metric='euclidean',max_iter=300)
    
#     data_kmeans = match_length(delays[:-1])
#     tskm_pred = tskm.fit_predict(np.array(data_kmeans))
#     print(tskm_pred)
#     tskm_feature = tskm.transform(np.array(data_kmeans))
#     print(tskm_feature)

#     data = match_length(delays[:-1])
#     for k in range(5):
#         fig = plt.figure()
#         target = k
#         for i in range(5):
#             plt.plot(distance_bet2traj(data[target][:],data[i][:]))
#         plt.show()
    
    
    pca = PCA()
    dfs = np.array(np.array(PFCstate)[:dividenum,0])
#     dfs = np.array(np.array(Restate)[:dividenum])
    pca.fit(dfs)
    feature = pca.transform(dfs)    
    print(pd.DataFrame(feature, columns=["PC{}".format(x + 1) for x in range(dfs.shape[1])]).head())
    print(pd.DataFrame(pca.explained_variance_ratio_, index=["PC{}".format(x + 1) for x in range(dfs.shape[1])]))
    
    
    
#     delays = pick_delay(traj_noise[:,0], np.array(PFCstate_noise)[:,0])
#     delays = pick_delay(traj_noise[:,0], np.array(Restate_noise))
#     delays = pick_traj(traj[:,0], np.array(PFCstate)[:,0])
#     delays = pick_traj(traj_noise[:,0], np.array(Restate_noise))
    delays = pick_traj(traj[:dividenum,0], np.array(feature)[:])
    linelist_delay = [0]
    for i in delays:
        print(np.array(i).shape)
        linelist_delay.append(linelist_delay[-1]+len(i))

    print(np.array(delays[:-1]).shape)
    tskm = TimeSeriesKMeans(n_clusters=3,metric='euclidean',max_iter=300)
    
    data_kmeans = match_length(delays[:-1])
    tskm_pred = tskm.fit_predict(np.array(data_kmeans))
    print(tskm_pred)
    tskm_feature = tskm.transform(np.array(data_kmeans))
    print(tskm_feature)

#     data = match_length(delays[:-1])
#     for k in range(5):
#         fig = plt.figure()
#         target = k
#         for i in range(5):
#             plt.plot(distance_bet2traj(data[target][:],data[i][:]))
# #             plt.plot(distance_bet2traj_ave(data[target][:],data[i][:]))
#         plt.show()
    
    fig = plt.figure()
    for data in delays[1:-1]:
        plt.plot(np.array(data)[:,0],alpha=0.5)
        print(len(data))


#     pick_delay(traj_noise[:,0], Restate_noise)

    

#     for i in range(20):
#         plt.figure()
#         plt.plot(np.array(Restate)[:dividenum,i])
#         plt.title("neuron#"+str(i+1))

    plt.figure()
    plt.plot(traj[:dividenum,0,0])
    plt.plot(traj[:dividenum,0,1])
    
    
# Entry point: call main() when this module/cell is executed directly.
if __name__ == '__main__':
    main()

In [None]:
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Created on Fri Feb 19 10:54:37 2021

@author: munenori
"""

import numpy as np
import os
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import time
import math
import matplotlib.pyplot as plt
import matplotlib.animation as animation
import sklearn 
from sklearn.decomposition import PCA
from sklearn.cluster import KMeans
from mpl_toolkits.mplot3d import Axes3D
import pandas as pd
import scipy

%matplotlib notebook


def Culc_gate(input, params, hiddens):
    """Recompute the PFC and HPC LSTM gate activations for one step (batch item 0).

    Args:
        input: external input at this step; concatenated with the Re state
            along dim 1 to form the HPC input.
        params: parameter list whose indices 0-3 are the PFC cell's
            (weight_ih, weight_hh, bias_ih, bias_hh) and 4-7 the HPC cell's.
            Presumably ``list(model.parameters())`` of a model that registers
            PFC then HPC first -- TODO confirm at the call site.
        hiddens: ``[PFC [h, c], HPC [h, c], Re h]`` *before* the step.
            PFC input is ``[HPC h, Re h]``.

    Returns:
        ``[PFC_gates, HPC_gates]`` as Python lists of length ``4 * hidden`` in
        PyTorch's (input, forget, cell, output) order, each passed through a
        sigmoid. Inside an LSTM the cell-candidate quarter is squashed with
        tanh instead; apply ``logit`` then ``tanh`` to it if that value is needed.
    """
    PFC_h = hiddens[0][0][0]
    HPC_h = hiddens[1][0][0]
    Re_h = hiddens[2]
    HPC_input = torch.cat([input, Re_h], dim=1)[0]
    PFC_input = torch.cat([HPC_h, Re_h[0]])
    # The recurrent weights act on the previous hidden output h, not on the
    # cell state c: gates = W_ih x + b_ih + W_hh h + b_hh.
    PFC_gates = torch.sigmoid(params[0] @ PFC_input + params[1] @ PFC_h + params[2] + params[3])
    HPC_gates = torch.sigmoid(params[4] @ HPC_input + params[5] @ HPC_h + params[6] + params[7])

    return [PFC_gates.tolist(), HPC_gates.tolist()]

def Culc_gate_uniPFC(input, params, hiddens):
    """Like ``Culc_gate`` but for a PFC cell whose only input is the HPC hidden output.

    Uses batch item 0 only.

    Args:
        input: external input; concatenated with the Re state along dim 1
            to form the HPC input.
        params: indices 0-3 = PFC (weight_ih, weight_hh, bias_ih, bias_hh),
            4-7 = the same for HPC (assumed ``model.parameters()`` order --
            TODO confirm at the call site).
        hiddens: ``[PFC [h, c], HPC [h, c], Re h]`` *before* the step.

    Returns:
        ``[PFC_gates, HPC_gates]`` as Python lists, sigmoid of every gate
        pre-activation in (input, forget, cell, output) order. The
        cell-candidate quarter would be tanh-squashed inside a real LSTM.
    """
    PFC_h = hiddens[0][0][0]
    HPC_h = hiddens[1][0][0]
    HPC_input = torch.cat([input, hiddens[2]], dim=1)[0]
    PFC_input = HPC_h
    # W_hh multiplies the previous hidden output h, not the cell state c.
    PFC_gates = torch.sigmoid(params[0] @ PFC_input + params[1] @ PFC_h + params[2] + params[3])
    HPC_gates = torch.sigmoid(params[4] @ HPC_input + params[5] @ HPC_h + params[6] + params[7])

    return [PFC_gates.tolist(), HPC_gates.tolist()]
    

def MakeAnimation(data_x, data_y, traj_x, traj_y, data_limit, filename='anim_Re_0.gif'):
    """Animate a generated trajectory over reference data and save it as a GIF.

    Frame ``t`` shows:
      * the reference (solid blue) up to ``min(t, data_limit)``;
      * the generated trajectory (red) up to ``t``, with a red marker on its
        newest point;
      * once ``t >= data_limit``, the reference from ``data_limit - 1`` to
        ``t`` as a dashed blue line.

    Args:
        data_x, data_y: reference coordinates (sliceable, e.g. 1-D arrays).
        traj_x, traj_y: generated coordinates; ``traj_x.shape[0]`` is the
            number of frames.
        data_limit: frame at which the solid reference stops growing.
        filename: output GIF path; the default is the previous hard-coded name.

    Returns:
        0.
    """
    fig, ax = plt.subplots()
    ims = []
    for t in range(traj_x.shape[0]):
        # Newest drawn point. Clamped so frame 0 marks the first point instead
        # of wrapping around to the last one via index -1.
        cur = max(t - 1, 0)
        if t < data_limit:
            traj = ax.plot(data_x[:t], data_y[:t], "b",
                           traj_x[:t], traj_y[:t], "r",
                           traj_x[cur], traj_y[cur], "or")
        else:
            traj = ax.plot(data_x[:data_limit], data_y[:data_limit], "b",
                           traj_x[:t], traj_y[:t], "r",
                           traj_x[cur], traj_y[cur], "or",
                           data_x[data_limit-1:t], data_y[data_limit-1:t], "--b")
        title = ax.text(0.5, 1.01, 'time={}'.format(t),
                        ha='center', va='bottom',
                        transform=ax.transAxes, fontsize='large')
        ims.append(traj + [title])
    ani = animation.ArtistAnimation(fig, ims, interval=100)
    ani.save(filename, writer='pillow')
    plt.show()

    return 0

def MakeAnimation2(data_x, data_y, traj_x, traj_y, data_limit, filename='anim_Re_0.gif'):
    """Animate only the generated trajectory and save it as a GIF.

    Frame ``t`` draws ``traj`` up to ``t`` in red, with a red marker on the
    newest point.

    Args:
        data_x, data_y, data_limit: unused; kept so the signature matches
            ``MakeAnimation``.
        traj_x, traj_y: generated coordinates; ``traj_x.shape[0]`` is the
            number of frames.
        filename: output GIF path; the default is the previous hard-coded name.

    Returns:
        0.
    """
    fig, ax = plt.subplots()
    ims = []
    for t in range(traj_x.shape[0]):
        # Clamp so frame 0 does not wrap to the last point via index -1.
        cur = max(t - 1, 0)
        traj = ax.plot(traj_x[:t], traj_y[:t], "r", traj_x[cur], traj_y[cur], "or")
        title = ax.text(0.5, 1.01, 'time={}'.format(t),
                        ha='center', va='bottom',
                        transform=ax.transAxes, fontsize='large')
        ims.append(traj + [title])
    ani = animation.ArtistAnimation(fig, ims, interval=100)
    ani.save(filename, writer='pillow')
    plt.show()

    return 0

def MakeAnimation_data(data_x, data_y, data_limit, filename='anim_Re_0.gif',
                       xlim=(0.2, 0.8), ylim=(-0.05, 0.8)):
    """Animate a trajectory being drawn point by point on fixed axes and save a GIF.

    Args:
        data_x, data_y: coordinates; ``data_x.shape[0]`` is the number of frames.
        data_limit: unused; kept so existing calls keep working.
        filename: output GIF path. The default is the previous hard-coded name,
            which ``MakeAnimation``/``MakeAnimation2`` also use by default.
        xlim, ylim: axis limits applied before drawing (defaults are the
            previous hard-coded values).

    Returns:
        0.
    """
    fig, ax = plt.subplots()
    ax.set_xlim(*xlim)
    ax.set_ylim(*ylim)
    ims = []
    for t in range(data_x.shape[0]):
        traj = ax.plot(data_x[:t], data_y[:t], "b")
        title = ax.text(0.5, 1.01, 'time={}'.format(t),
                        ha='center', va='bottom',
                        transform=ax.transAxes, fontsize='large')
        ims.append(traj + [title])
    ani = animation.ArtistAnimation(fig, ims, interval=30)
    ani.save(filename, writer='pillow')
    plt.show()

    return 0

def MakeAnimation_testdata(data_x, data_y, filename='anim_test_1.gif'):
    """Animate a trajectory being drawn point by point and save it as a GIF.

    Args:
        data_x, data_y: coordinates; ``data_x.shape[0]`` is the number of frames.
        filename: output GIF path; the default is the previous hard-coded name.

    Returns:
        0.
    """
    fig, ax = plt.subplots()
    ims = []
    for t in range(data_x.shape[0]):
        traj = ax.plot(data_x[:t], data_y[:t], "b")
        title = ax.text(0.5, 1.01, 'time={}'.format(t),
                        ha='center', va='bottom',
                        transform=ax.transAxes, fontsize='large')
        ims.append(traj + [title])
    ani = animation.ArtistAnimation(fig, ims, interval=100)
    ani.save(filename, writer='pillow')
    plt.show()

    return 0

def MakeAnimation_attracter(data_x, data_y, filename='anim_attracter_gate_PFC.gif'):
    """Animate a 2-D state-space path with a marker on its newest point and save a GIF.

    Args:
        data_x, data_y: coordinates; ``data_x.shape[0]`` is the number of frames.
        filename: output GIF path; the default is the previous hard-coded name.

    Returns:
        0.
    """
    fig, ax = plt.subplots()
    ims = []
    for t in range(data_x.shape[0]):
        # Clamp so frame 0 does not mark the last point via index -1.
        cur = max(t - 1, 0)
        traj = ax.plot(data_x[:t], data_y[:t], "b", data_x[cur], data_y[cur], "ob")
        title = ax.text(0.5, 1.01, 'time={}'.format(t),
                        ha='center', va='bottom',
                        transform=ax.transAxes, fontsize='large')
        ims.append(traj + [title])
    ani = animation.ArtistAnimation(fig, ims, interval=100)
    ani.save(filename, writer='pillow')
    plt.show()

    return 0

def MakeAnimation_3D(data_x, data_y, filename='anim_attracter2.gif'):
    """Animate a path with a marker on its newest point and save a GIF.

    Despite the name, this draws on ordinary 2-D axes using only ``data_x``
    and ``data_y``.

    Args:
        data_x, data_y: coordinates; ``data_x.shape[0]`` is the number of frames.
        filename: output GIF path; the default is the previous hard-coded name.

    Returns:
        0.
    """
    fig, ax = plt.subplots()
    ims = []
    for t in range(data_x.shape[0]):
        # Clamp so frame 0 does not mark the last point via index -1.
        cur = max(t - 1, 0)
        traj = ax.plot(data_x[:t], data_y[:t], "b", data_x[cur], data_y[cur], "ob")
        title = ax.text(0.5, 1.01, 'time={}'.format(t),
                        ha='center', va='bottom',
                        transform=ax.transAxes, fontsize='large')
        ims.append(traj + [title])
    ani = animation.ArtistAnimation(fig, ims, interval=100)
    ani.save(filename, writer='pillow')
    plt.show()

    return 0

def MakeAnimation_img(data, module):
    """Show each row of ``data`` as a one-pixel-high image strip, frame by frame.

    The animation is saved to ``'anim_fire_' + module + '_.gif'``.

    Args:
        data: 2-D array-like; row ``t`` becomes frame ``t``.
        module: label used in the frame titles and the output file name.

    Returns:
        0.
    """
    fig, ax = plt.subplots()
    frames = []
    for step in range(data.shape[0]):
        strip = data[step].reshape(1, data.shape[1])
        picture = ax.imshow(strip)
        caption = module + ' time={}'.format(step)
        label = ax.text(0.5, 1.01, caption, ha='center', va='bottom',
                        transform=ax.transAxes, fontsize='large')
        frames.append([picture, label])
    anim = animation.ArtistAnimation(fig, frames, interval=100)
    anim.save('anim_fire_' + module + '_.gif', writer='pillow')
    plt.show()

    return 0


def mkOwnDataSet(data_size, data_length=100, freq=60., noise=0.01,
                 path_r="primal_long131test_r.csv", path_l="primal_long131test_l.csv"):
    """Build noisy, down-sampled copies of two recorded 2-D trajectories.

    Each CSV is read as rows whose first two columns are (x, y). Every
    ``max(1, round(n_rows / data_length))``-th row is kept. Each copy gets
    independent Gaussian noise on every coordinate.

    Args:
        data_size: number of noisy copies per trajectory.
        data_length: approximate number of points to keep per trajectory.
        freq: unused; kept for call compatibility.
        noise: standard deviation of the added Gaussian noise.
        path_r, path_l: CSV files for the two trajectories (defaults are the
            previously hard-coded names).

    Returns:
        ``(train_x, train_y)``: each a list of ``data_size`` samples, each
        sample a list of ``[x, y]`` points.
    """
    # ndmin=2 keeps a single-row CSV two-dimensional.
    x = np.loadtxt(path_r, delimiter=',', ndmin=2)
    y = np.loadtxt(path_l, delimiter=',', ndmin=2)

    def _subsample(arr):
        # round(n / data_length) is 0 for short recordings, which would make
        # np.arange raise; always advance by at least one row.
        step = max(1, round(arr.shape[0] / data_length))
        return arr[np.arange(0, arr.shape[0], step), :2]

    x_pts = _subsample(x)
    y_pts = _subsample(y)
    train_x = []
    train_y = []
    for _ in range(data_size):
        # Row-major draws: x then y coordinate of each point, all of x_pts
        # before y_pts.
        train_x.append((x_pts + np.random.normal(loc=0.0, scale=noise, size=x_pts.shape)).tolist())
        train_y.append((y_pts + np.random.normal(loc=0.0, scale=noise, size=y_pts.shape)).tolist())
    return train_x, train_y

def make_Ttraj(direction):
    """Return a 20-point out-and-back path on a T shape as an array of shape (20, 2).

    The four 5-point segments are:
      1. up the stem at x=0.5 from y=0 to y=0.5;
      2. along y=0.5 to the arm end;
      3. back along y=0.5 to x=0.5;
      4. down the stem to y=0.

    Args:
        direction: 1 for the arm ending at x=0.75, 2 for the arm ending at x=0.25.

    Raises:
        ValueError: if ``direction`` is neither 1 nor 2.
    """
    if direction == 1:
        arm_end = 0.75
    elif direction == 2:
        arm_end = 0.25
    else:
        raise ValueError("direction must be 1 or 2, got {!r}".format(direction))

    xs = np.concatenate([
        np.full(5, 0.5),
        np.linspace(0.5, arm_end, 5),
        np.linspace(arm_end, 0.5, 5),
        np.full(5, 0.5),
    ])
    ys = np.concatenate([
        np.linspace(0, 0.5, 5),
        np.full(5, 0.5),
        np.full(5, 0.5),
        np.linspace(0.5, 0, 5),
    ])
    return np.stack([xs, ys], axis=1)

def make_Htraj(direction1, direction2):
    """Return a 40-point out-and-back path on an H shape as an array of shape (40, 2).

    The segments are:
      1. up the stem at x=0.5 from y=0 to y=0.5 (10 points);
      2. across to the side bar at y=0.5 (5);
      3. along the side bar to its tip (5);
      4. back from the tip (5);
      5. back across to x=0.5 (5);
      6. down the stem to y=0 (10).

    Args:
        direction1: 1 for the side bar at x=0.75, 2 for the side bar at x=0.25.
        direction2: 1 for a tip at y=0.25, 2 for a tip at y=0.75.

    Raises:
        ValueError: if either direction is neither 1 nor 2.
    """
    if direction1 == 1:
        side = 0.75
    elif direction1 == 2:
        side = 0.25
    else:
        raise ValueError("direction1 must be 1 or 2, got {!r}".format(direction1))
    if direction2 == 1:
        tip = 0.25
    elif direction2 == 2:
        tip = 0.75
    else:
        raise ValueError("direction2 must be 1 or 2, got {!r}".format(direction2))

    xs = np.concatenate([
        np.full(10, 0.5),
        np.linspace(0.5, side, 5),
        np.full(5, side),
        np.full(5, side),
        np.linspace(side, 0.5, 5),
        np.full(10, 0.5),
    ])
    ys = np.concatenate([
        np.linspace(0, 0.5, 10),
        np.full(5, 0.5),
        np.linspace(0.5, tip, 5),
        np.linspace(tip, 0.5, 5),
        np.full(5, 0.5),
        np.linspace(0.5, 0, 10),
    ])
    return np.stack([xs, ys], axis=1)


def make_traj(delay_length):
    """Build four test sequences by chaining noisy H-shape runs with stay periods.

    Run codes map to ``make_Htraj`` arguments:
    1 -> (1, 1), 2 -> (1, 2), 3 -> (2, 1), 4 -> (2, 2).

    Each sequence visits five runs. After every run, ``delay_length`` stay
    blocks of 40 rows at (0.5, 0) are appended. Gaussian noise (std 0.005) is
    added independently to every run occurrence; stay blocks are noise-free.

    Returns:
        Tuple of four float arrays of shape ``(n, 2)``, one per sequence.
    """
    noise = 0.005
    data_length = 40
    runs = {
        1: make_Htraj(1, 1),
        2: make_Htraj(1, 2),
        3: make_Htraj(2, 1),
        4: make_Htraj(2, 2),
    }
    # One stay block: a constant position, so no data file is needed.
    stay = np.tile([0.5, 0.0], (data_length, 1))

    # Test orderings. Each is a cyclic walk over the transitions
    # 1->4, 2->3, 3->1, 4->2, i.e. the pairs used by the "ver1" training layout.
    sequences = [
        [1, 4, 2, 3, 1],
        [2, 3, 1, 4, 2],
        [3, 1, 4, 2, 3],
        [4, 2, 3, 1, 4],
    ]

    data_list = []
    for seq in sequences:
        pieces = []
        for code in seq:
            base = runs[code]
            # Build a new noisy array. An in-place `+=` would modify the shared
            # template, so noise would pile up on every repeat of that run.
            pieces.append(base + np.random.normal(loc=0.0, scale=noise, size=base.shape))
            pieces.extend([stay] * delay_length)
        data_list.append(np.concatenate(pieces, axis=0))
    return data_list[0], data_list[1], data_list[2], data_list[3]

def mkOwnDataSet_auto(data_size, delay_length, freq=60., noise=0.0001):
    """Return ``data_size`` jittered copies of each of the four ``make_traj`` sequences.

    Args:
        data_size: number of copies per sequence.
        delay_length: forwarded to ``make_traj``.
        freq: unused; kept for call compatibility.
        noise: standard deviation of the Gaussian jitter, added independently
            to every coordinate.

    Returns:
        ``(train_x1, train_x2, train_y1, train_y2)``: each a list of
        ``data_size`` samples, each sample a list of ``[x, y]`` points.
    """
    templates = [np.asarray(t)[:, :2] for t in make_traj(delay_length)]
    outputs = ([], [], [], [])
    for _ in range(data_size):
        # Draw order per copy: x1, x2, y1, y2; row-major within each.
        for pts, out in zip(templates, outputs):
            jitter = np.random.normal(loc=0.0, scale=noise, size=pts.shape)
            out.append((pts + jitter).tolist())
    return outputs

def mkOwnRandomBatch(train_x, train_t, batch_size=10):
    """Sample a random mini-batch (with replacement) from ``train_x``.

    Args:
        train_x: sequence of samples; each sample must be convertible by
            ``torch.tensor`` (e.g. a list of per-time-step feature lists).
        train_t: unused; kept for call compatibility (no targets are returned).
        batch_size: number of samples to draw.

    Returns:
        A ``torch.Tensor`` of the drawn samples with the first two axes
        swapped, so the sample axis is dim 1.

    Raises:
        ValueError: if ``train_x`` is empty.
    """
    if len(train_x) == 0:
        raise ValueError("train_x must contain at least one sample")
    # Each draw picks one whole sample, so any index in the dataset is valid.
    # Bounding by len(train_x) - batch_size would skip the last samples and
    # fail whenever batch_size >= len(train_x).
    idxs = np.random.randint(0, len(train_x), size=batch_size)
    batch_x = [train_x[i] for i in idxs]
    return torch.tensor(batch_x).transpose(0, 1)

class MyLSTM(nn.Module):
    """Three coupled LSTM cells ("PFC", "HPC", "Re") with a linear read-out from HPC.

    Hidden state layout: ``[PFC [h, c], HPC [h, c], Re [h, c]]``. Each step
    reads only the previous states:
      * Re  <- [PFC h, HPC h]
      * HPC <- [input, Re h]
      * PFC <- [HPC h, Re h]
    The output is ``linear(new HPC h)``.

    NOTE(review): ``nn.init.sparse_`` is called with sparsity 1, which zeroes
    every element of the PFC/HPC ``weight_ih``/``weight_hh`` matrices (sibling
    classes use 0.1). Confirm an all-zero initialisation is intended.
    """

    def __init__(self, input_size, hidden_size, output_size, batch_size):
        super(MyLSTM, self).__init__()

        self.hidden_size = hidden_size
        # Registration order (PFC, HPC, Re, linear) fixes the order of
        # self.parameters(); keep it stable.
        self.PFC = nn.LSTMCell(2 * hidden_size, hidden_size)
        self.HPC = nn.LSTMCell(input_size + hidden_size, hidden_size)
        self.Re = nn.LSTMCell(2 * hidden_size, hidden_size)
        self.linear = nn.Linear(hidden_size, output_size)
        self.batch_size = batch_size
        for weight in (self.PFC.weight_ih, self.HPC.weight_ih,
                       self.PFC.weight_hh, self.HPC.weight_hh):
            nn.init.sparse_(weight.data, 1)

    def forward(self, input, hiddens):
        """Advance all three cells one step; return ``(output, new_hiddens)``."""
        pfc_state = hiddens[0]
        hpc_state = hiddens[1]
        re_state = hiddens[2]
        pfc_h = pfc_state[0]
        hpc_h = hpc_state[0]
        re_h = re_state[0]

        new_re = self.Re(torch.cat([pfc_h, hpc_h], dim=1), re_state)
        new_hpc = self.HPC(torch.cat([input, re_h], dim=1), hpc_state)
        new_pfc = self.PFC(torch.cat([hpc_h, re_h], dim=1), pfc_state)
        return self.linear(new_hpc[0]), [new_pfc, new_hpc, new_re]

    def initHidden(self):
        """Return zero ``[h, c]`` pairs for ``[PFC, HPC, Re]``."""
        def zeros():
            return torch.zeros(self.batch_size, self.hidden_size)
        return [[zeros(), zeros()] for _ in range(3)]

class MyLSTM_RNN(nn.Module):
    """PFC/HPC LSTM cells coupled through an Elman-RNN "Re" cell, read out from HPC.

    Hidden state layout: ``[PFC [h, c], HPC [h, c], Re h]``. Each step reads
    only the previous states:
      * Re  <- [PFC h, HPC h]
      * HPC <- [input, Re h]
      * PFC <- [HPC h, Re h]
    The output is ``linear(new HPC h)``.

    Args:
        input_size, hidden_size, output_size: layer sizes; PFC, HPC and Re all
            use ``hidden_size`` units.
        batch_size: batch dimension used by ``initHidden*`` and
            ``noiseHidden_rand``.
        sparse: integer number of tenths passed to ``nn.init.sparse_`` for the
            PFC/HPC weight matrices (e.g. 1 -> sparsity 0.1).
    """

    def __init__(self, input_size, hidden_size, output_size, batch_size, sparse):
        super(MyLSTM_RNN, self).__init__()

        self.hidden_size_PFC = hidden_size
        self.hidden_size_HPC = hidden_size
        self.hidden_size_Re = hidden_size+0
        # Registration order (PFC, HPC, Re, linear) fixes self.parameters() order.
        self.PFC = nn.LSTMCell(self.hidden_size_HPC+self.hidden_size_Re, self.hidden_size_PFC)
        self.HPC = nn.LSTMCell(input_size+self.hidden_size_Re, self.hidden_size_HPC)
        self.Re = nn.RNNCell(self.hidden_size_HPC+self.hidden_size_PFC, self.hidden_size_Re)
        self.linear = nn.Linear(self.hidden_size_HPC, output_size)
        self.batch_size = batch_size
        sparse = 0.1*int(sparse)
        nn.init.sparse_(self.PFC.weight_ih.data,sparse)
        nn.init.sparse_(self.HPC.weight_ih.data,sparse)
        nn.init.sparse_(self.PFC.weight_hh.data,sparse)
        nn.init.sparse_(self.HPC.weight_hh.data,sparse)

    def forward(self, input, hiddens):
        """Advance all three cells one step; return ``(output, new_hiddens)``."""
        Re_input = torch.cat([hiddens[0][0],hiddens[1][0]],dim=1)
        Re_hidden = self.Re(Re_input, hiddens[2])
        HPC_input = torch.cat([input,hiddens[2]],dim=1)
        HPC_hidden = self.HPC(HPC_input, hiddens[1])
        PFC_input = torch.cat([hiddens[1][0],hiddens[2]],dim=1)
        PFC_hidden = self.PFC(PFC_input, hiddens[0])
        output = self.linear(HPC_hidden[0])
        return output, [PFC_hidden, HPC_hidden, Re_hidden]

    def initHidden(self):
        """Return zero states ``[PFC [h, c], HPC [h, c], Re h]``."""
        # Per-module sizes: this class never defines a plain `hidden_size`.
        HPC_hidden = [torch.zeros(self.batch_size, self.hidden_size_HPC), torch.zeros(self.batch_size, self.hidden_size_HPC)]
        PFC_hidden = [torch.zeros(self.batch_size, self.hidden_size_PFC), torch.zeros(self.batch_size, self.hidden_size_PFC)]
        Re_hidden = torch.zeros(self.batch_size, self.hidden_size_Re)
        return [PFC_hidden,HPC_hidden,Re_hidden]

    def initHidden_rand(self):
        """Return states drawn uniformly from [0, 0.01) in the same layout as ``initHidden``."""
        const = 0.01
        var = 0.01
        HPC_hidden = [torch.rand(self.batch_size, self.hidden_size_HPC)*const, torch.rand(self.batch_size, self.hidden_size_HPC)*const]
        PFC_hidden = [torch.rand(self.batch_size, self.hidden_size_PFC)*var, torch.rand(self.batch_size, self.hidden_size_PFC)*var]
        Re_hidden = torch.rand(self.batch_size, self.hidden_size_Re)*const
        return [PFC_hidden,HPC_hidden,Re_hidden]

    def noiseHidden_rand(self, hiddens):
        """Add Gaussian noise in place to the hidden outputs and return ``hiddens``.

        Only PFC h gets non-zero noise (std 0.01). HPC h and Re h are scaled by
        0.0, but their random draws are still consumed.
        """
        c = 0.0
        v = 0.01
        HPC_hidden = torch.randn(self.batch_size, self.hidden_size_HPC)*c
        hiddens[1][0].add_(HPC_hidden)
        PFC_hidden = torch.randn(self.batch_size, self.hidden_size_PFC)*v
        hiddens[0][0].add_(PFC_hidden)
        Re_hidden = torch.randn(self.batch_size, self.hidden_size_Re)*c
        hiddens[2].add_(Re_hidden)
        return hiddens

    def hidden_size_list(self):
        """Return ``[PFC size, HPC size, Re size]``."""
        return [self.hidden_size_PFC,self.hidden_size_HPC,self.hidden_size_Re]



class MyLSTM_RNN_noise(nn.Module):
    """``MyLSTM_RNN`` variant that injects Gaussian noise into the states it reads.

    Hidden state layout: ``[PFC [h, c], HPC [h, c], Re h]``. The Re cell has
    ``hidden_size + 10`` units. In ``forward``, noise (std 0.02) is added to
    PFC h, HPC h and Re h. The noisy copies feed Re (as input and previous
    state), HPC (as input) and PFC (as input). The HPC and PFC recurrent
    states passed to their cells are the un-noised ``hiddens[1]``/``hiddens[0]``.

    Args:
        input_size, hidden_size, output_size: layer sizes.
        batch_size: batch dimension used by the ``initHidden*`` helpers.
        sparse: integer number of tenths passed to ``nn.init.sparse_`` for the
            PFC/HPC weight matrices (e.g. 1 -> sparsity 0.1).
    """

    def __init__(self, input_size, hidden_size, output_size, batch_size, sparse):
        super(MyLSTM_RNN_noise, self).__init__()

        self.hidden_size_PFC = hidden_size
        self.hidden_size_HPC = hidden_size
        self.hidden_size_Re = hidden_size+10
        # Registration order (PFC, HPC, Re, linear) fixes self.parameters() order.
        self.PFC = nn.LSTMCell(self.hidden_size_HPC+self.hidden_size_Re, self.hidden_size_PFC)
        self.HPC = nn.LSTMCell(input_size+self.hidden_size_Re, self.hidden_size_HPC)
        self.Re = nn.RNNCell(self.hidden_size_HPC+self.hidden_size_PFC, self.hidden_size_Re)
        self.linear = nn.Linear(self.hidden_size_HPC, output_size)
        self.batch_size = batch_size
        sparse = 0.1*int(sparse)
        nn.init.sparse_(self.PFC.weight_ih.data,sparse)
        nn.init.sparse_(self.HPC.weight_ih.data,sparse)
        nn.init.sparse_(self.PFC.weight_hh.data,sparse)
        nn.init.sparse_(self.HPC.weight_hh.data,sparse)

    def forward(self, input, hiddens):
        """Advance all three cells one step with noisy inter-module signals."""
        c = 0.02
        v = 0.02
        # randn_like follows the actual state shape, so any batch size works.
        # Draws are identical to torch.randn(batch_size, size) when the batch
        # matches self.batch_size.
        HPC_noised = hiddens[1][0] + torch.randn_like(hiddens[1][0])*c
        PFC_noised = hiddens[0][0] + torch.randn_like(hiddens[0][0])*v
        Re_noised = hiddens[2] + torch.randn_like(hiddens[2])*c
        Re_input = torch.cat([PFC_noised,HPC_noised],dim=1)
        Re_hidden = self.Re(Re_input, Re_noised)
        HPC_input = torch.cat([input,Re_noised],dim=1)
        HPC_hidden = self.HPC(HPC_input, hiddens[1])
        PFC_input = torch.cat([HPC_noised,Re_noised],dim=1)
        PFC_hidden = self.PFC(PFC_input, hiddens[0])
        output = self.linear(HPC_hidden[0])
        return output, [PFC_hidden, HPC_hidden, Re_hidden]

    def initHidden(self):
        """Return zero states ``[PFC [h, c], HPC [h, c], Re h]``."""
        # Per-module sizes: there is no plain `hidden_size`, and Re is wider.
        HPC_hidden = [torch.zeros(self.batch_size, self.hidden_size_HPC), torch.zeros(self.batch_size, self.hidden_size_HPC)]
        PFC_hidden = [torch.zeros(self.batch_size, self.hidden_size_PFC), torch.zeros(self.batch_size, self.hidden_size_PFC)]
        Re_hidden = torch.zeros(self.batch_size, self.hidden_size_Re)
        return [PFC_hidden,HPC_hidden,Re_hidden]

    def initHidden_rand(self):
        """Return states drawn uniformly from [0, 0.01) in the same layout as ``initHidden``."""
        const = 0.01
        var = 0.01
        HPC_hidden = [torch.rand(self.batch_size, self.hidden_size_HPC)*const, torch.rand(self.batch_size, self.hidden_size_HPC)*const]
        PFC_hidden = [torch.rand(self.batch_size, self.hidden_size_PFC)*var, torch.rand(self.batch_size, self.hidden_size_PFC)*var]
        Re_hidden = torch.rand(self.batch_size, self.hidden_size_Re)*const
        return [PFC_hidden,HPC_hidden,Re_hidden]

    def noiseHidden_rand(self, hiddens):
        """Add Gaussian noise (std 0.02) in place to HPC h, PFC h and Re h; return ``hiddens``."""
        c = 0.02
        v = 0.02
        hiddens[1][0].add_(torch.randn_like(hiddens[1][0])*c)
        hiddens[0][0].add_(torch.randn_like(hiddens[0][0])*v)
        hiddens[2].add_(torch.randn_like(hiddens[2])*c)
        return hiddens

    def hidden_size_list(self):
        """Return ``[PFC size, HPC size, Re size]``."""
        return [self.hidden_size_PFC,self.hidden_size_HPC,self.hidden_size_Re]


class MyLSTM_vHPC(nn.Module):
    """Four LSTM cells (PFC, HPC, vHPC, Re) with a linear read-out from HPC.

    Hidden state layout: ``[PFC, HPC, vHPC, Re]``, each an ``[h, c]`` pair.
    Each step reads only the previous states:
      * Re   <- [PFC h, HPC h]
      * HPC  <- [input, Re h]
      * vHPC <- HPC h
      * PFC  <- [vHPC h, Re h]
    The output is ``linear(new HPC h)``. All eight cell weight matrices are
    sparse-initialised with sparsity 0.1.
    """

    def __init__(self, input_size, hidden_size, output_size, batch_size):
        super(MyLSTM_vHPC, self).__init__()

        self.hidden_size = hidden_size
        self.PFC = nn.LSTMCell(hidden_size*2, hidden_size)
        self.HPC = nn.LSTMCell(input_size+hidden_size, hidden_size)
        self.vHPC = nn.LSTMCell(hidden_size, hidden_size)
        self.Re = nn.LSTMCell(hidden_size*2, hidden_size)
        self.linear = nn.Linear(hidden_size, output_size)
        self.batch_size = batch_size
        sparse = 0.1
        nn.init.sparse_(self.PFC.weight_ih.data,sparse)
        nn.init.sparse_(self.HPC.weight_ih.data,sparse)
        nn.init.sparse_(self.vHPC.weight_ih.data,sparse)
        nn.init.sparse_(self.Re.weight_ih.data,sparse)
        nn.init.sparse_(self.PFC.weight_hh.data,sparse)
        nn.init.sparse_(self.HPC.weight_hh.data,sparse)
        nn.init.sparse_(self.vHPC.weight_hh.data,sparse)
        nn.init.sparse_(self.Re.weight_hh.data,sparse)

    def forward(self, input, hiddens):
        """Advance all four cells one step; return ``(output, new_hiddens)``."""
        Re_input = torch.cat([hiddens[0][0],hiddens[1][0]],dim=1)
        # Re's own previous state is at index 3; index 2 holds vHPC's state.
        Re_hidden = self.Re(Re_input, hiddens[3])
        HPC_input = torch.cat([input,hiddens[3][0]],dim=1)
        HPC_hidden = self.HPC(HPC_input, hiddens[1])
        vHPC_input = hiddens[1][0]
        vHPC_hidden = self.vHPC(vHPC_input,hiddens[2])
        PFC_input = torch.cat([hiddens[2][0],hiddens[3][0]],dim=1)
        PFC_hidden = self.PFC(PFC_input, hiddens[0])
        output = self.linear(HPC_hidden[0])
        return output, [PFC_hidden, HPC_hidden, vHPC_hidden, Re_hidden]

    def initHidden(self):
        """Return zero ``[h, c]`` pairs for ``[PFC, HPC, vHPC, Re]``."""
        HPC_hidden = [torch.zeros(self.batch_size, self.hidden_size), torch.zeros(self.batch_size, self.hidden_size)]
        vHPC_hidden = [torch.zeros(self.batch_size, self.hidden_size), torch.zeros(self.batch_size, self.hidden_size)]
        PFC_hidden = [torch.zeros(self.batch_size, self.hidden_size), torch.zeros(self.batch_size, self.hidden_size)]
        Re_hidden = [torch.zeros(self.batch_size, self.hidden_size), torch.zeros(self.batch_size, self.hidden_size)]
        return [PFC_hidden,HPC_hidden,vHPC_hidden,Re_hidden]

class MyLSTM_3lay(nn.Module):
    """Three LSTM cells in a loop (HPC -> Re -> PFC -> Re -> HPC) with a read-out from HPC.

    Hidden state layout: ``[PFC, HPC, Re]``, each an ``[h, c]`` pair. Each
    step reads only the previous states:
      * HPC <- [input, Re h]
      * Re  <- [PFC h, HPC h]
      * PFC <- Re h
    The output is ``linear(new HPC h)``. All six cell weight matrices are
    sparse-initialised with sparsity 0.1.
    """

    def __init__(self, input_size, hidden_size, output_size, batch_size):
        super(MyLSTM_3lay, self).__init__()

        self.hidden_size = hidden_size
        self.PFC = nn.LSTMCell(hidden_size, hidden_size)
        self.HPC = nn.LSTMCell(input_size+hidden_size, hidden_size)
        self.Re = nn.LSTMCell(hidden_size*2, hidden_size)
        self.linear = nn.Linear(hidden_size, output_size)
        self.batch_size = batch_size
        sparse = 0.1
        nn.init.sparse_(self.PFC.weight_ih.data,sparse)
        nn.init.sparse_(self.HPC.weight_ih.data,sparse)
        nn.init.sparse_(self.Re.weight_ih.data,sparse)
        nn.init.sparse_(self.PFC.weight_hh.data,sparse)
        nn.init.sparse_(self.HPC.weight_hh.data,sparse)
        nn.init.sparse_(self.Re.weight_hh.data,sparse)

    def forward(self, input, hiddens):
        """Advance all three cells one step; return ``(output, new_hiddens)``."""
        HPC_input = torch.cat([input,hiddens[2][0]],dim=1)
        HPC_hidden = self.HPC(HPC_input, hiddens[1])
        Re_input = torch.cat([hiddens[0][0],hiddens[1][0]],dim=1)
        Re_hidden = self.Re(Re_input,hiddens[2])
        PFC_input = hiddens[2][0]
        PFC_hidden = self.PFC(PFC_input, hiddens[0])
        output = self.linear(HPC_hidden[0])
        return output, [PFC_hidden, HPC_hidden, Re_hidden]

    def initHidden(self):
        """Return zero ``[h, c]`` pairs for ``[PFC, HPC, Re]``."""
        HPC_hidden = [torch.zeros(self.batch_size, self.hidden_size), torch.zeros(self.batch_size, self.hidden_size)]
        PFC_hidden = [torch.zeros(self.batch_size, self.hidden_size), torch.zeros(self.batch_size, self.hidden_size)]
        Re_hidden = [torch.zeros(self.batch_size, self.hidden_size), torch.zeros(self.batch_size, self.hidden_size)]
        return [PFC_hidden,HPC_hidden,Re_hidden]
    
class MyLSTM_feedforward(nn.Module):
    """PFC/HPC/Re network in which Re is a memoryless ``nn.Linear`` layer.

    Per step (``forward``):
      * Re is recomputed from the previous PFC and HPC hidden states ``h``;
      * HPC (LSTMCell) receives the external input plus the previous Re
        activity;
      * PFC (LSTMCell) receives the previous HPC ``h`` plus the previous Re
        activity;
      * the output is a linear read-out of the new HPC ``h``.

    Hidden state layout: ``[PFC_hidden, HPC_hidden, Re_hidden]`` where
    PFC_hidden / HPC_hidden are ``(h, c)`` pairs and Re_hidden is a single
    ``(batch_size, hidden_size_Re)`` tensor.

    Args:
        input_size: size of the external input fed to HPC.
        hidden_size: hidden size of every region.
        output_size: size of the linear read-out.
        batch_size: batch dimension used by the ``initHidden*`` helpers.
        sparse: converted with ``int()`` and multiplied by 0.1 to give the
            sparsity passed to ``nn.init.sparse_`` for the PFC/HPC weights.
    """

    def __init__(self, input_size, hidden_size, output_size, batch_size, sparse):
        super(MyLSTM_feedforward, self).__init__()

        self.hidden_size_PFC = hidden_size
        self.hidden_size_HPC = hidden_size
        self.hidden_size_Re = hidden_size+0
        self.PFC = nn.LSTMCell(self.hidden_size_HPC+self.hidden_size_Re, self.hidden_size_PFC)
        # self.PFC = nn.LSTMCell(hidden_size*2, hidden_size)
        self.HPC = nn.LSTMCell(input_size+self.hidden_size_Re, self.hidden_size_HPC)
        self.Re = nn.Linear(self.hidden_size_HPC+self.hidden_size_PFC, self.hidden_size_Re)
        self.linear = nn.Linear(self.hidden_size_HPC, output_size)
        self.batch_size = batch_size
        # Fraction of entries zeroed per column of the LSTM weight matrices
        # (the Re linear layer keeps PyTorch's default init).
        sparse = 0.1*int(sparse)
        nn.init.sparse_(self.PFC.weight_ih.data,sparse)
        nn.init.sparse_(self.HPC.weight_ih.data,sparse)
        nn.init.sparse_(self.PFC.weight_hh.data,sparse)
        nn.init.sparse_(self.HPC.weight_hh.data,sparse)

    def forward(self, input, hiddens):
        """Advance the network by one step.

        Args:
            input: external input, presumably ``(batch_size, input_size)``.
            hiddens: ``[PFC_hidden, HPC_hidden, Re_hidden]`` from the
                previous step.

        Returns:
            ``(output, [PFC_hidden, HPC_hidden, Re_hidden])``.
        """
        Re_input = torch.cat([hiddens[0][0],hiddens[1][0]],dim=1)
        Re_hidden = self.Re(Re_input)
        # HPC and PFC read the *previous* Re activity (hiddens[2]), not the
        # Re_hidden computed just above.
        HPC_input = torch.cat([input,hiddens[2]],dim=1)
        HPC_hidden = self.HPC(HPC_input, hiddens[1])
        PFC_input = torch.cat([hiddens[1][0],hiddens[2]],dim=1)
        PFC_hidden = self.PFC(PFC_input, hiddens[0])
        output = self.linear(HPC_hidden[0])
        return output, [PFC_hidden, HPC_hidden, Re_hidden]

    def initHidden(self):
        """Return an all-zero hidden state ``[PFC, HPC, Re]``."""
        # Use the per-region sizes: this class has no ``self.hidden_size``
        # attribute, so reading it raised AttributeError.
        HPC_hidden = [torch.zeros(self.batch_size, self.hidden_size_HPC), torch.zeros(self.batch_size, self.hidden_size_HPC)]
        PFC_hidden = [torch.zeros(self.batch_size, self.hidden_size_PFC), torch.zeros(self.batch_size, self.hidden_size_PFC)]
        Re_hidden = torch.zeros(self.batch_size, self.hidden_size_Re)
        return [PFC_hidden,HPC_hidden,Re_hidden]

    def initHidden_rand(self):
        """Return a hidden state ``[PFC, HPC, Re]`` drawn uniformly from [0, 0.01)."""
        const = 0.01
        var = 0.01
        HPC_hidden = [torch.rand(self.batch_size, self.hidden_size_HPC)*const, torch.rand(self.batch_size, self.hidden_size_HPC)*const]
        PFC_hidden = [torch.rand(self.batch_size, self.hidden_size_PFC)*var, torch.rand(self.batch_size, self.hidden_size_PFC)*var]
        Re_hidden = torch.rand(self.batch_size, self.hidden_size_Re)*const
        return [PFC_hidden,HPC_hidden,Re_hidden]

    def noiseHidden_rand(self, hiddens):
        """Add Gaussian noise (std 0.02) in place to PFC h, HPC h and Re.

        LSTM cell states are left untouched. ``hiddens`` is modified in place
        and also returned.
        """
        c = 0.02
        v = 0.02
        HPC_hidden = torch.randn(self.batch_size, self.hidden_size_HPC)*c
        hiddens[1][0].add_(HPC_hidden)
        PFC_hidden = torch.randn(self.batch_size, self.hidden_size_PFC)*v
        hiddens[0][0].add_(PFC_hidden)
        Re_hidden = torch.randn(self.batch_size, self.hidden_size_Re)*c
        hiddens[2].add_(Re_hidden)
        return hiddens

    def hidden_size_list(self):
        """Return ``[hidden_size_PFC, hidden_size_HPC, hidden_size_Re]``."""
        return [self.hidden_size_PFC,self.hidden_size_HPC,self.hidden_size_Re]
    

class MyLSTM_feedforward_Thalamus2(nn.Module):
    """PFC/HPC/Re network with an extra rectified loop Re -> THinh -> Re.

    Per step (``forward``):
      * Re (linear) is computed from the previous PFC ``h``, HPC ``h`` and
        THinh activity;
      * THinh = ReLU(linear(previous Re));
      * HPC (LSTMCell) receives the external input plus the previous Re;
      * PFC (LSTMCell) receives the previous HPC ``h`` plus the previous Re;
      * the output is a linear read-out of the new HPC ``h``.

    Hidden state layout: ``[PFC_hidden, HPC_hidden, Re_hidden, THinh_hidden]``
    where PFC/HPC are ``(h, c)`` pairs and Re/THinh are single tensors.

    Args:
        input_size: size of the external input fed to HPC.
        hidden_size: hidden size of every region.
        output_size: size of the linear read-out.
        batch_size: batch dimension used by the ``initHidden*`` helpers.
        sparse: converted with ``int()`` and multiplied by 0.1 to give the
            sparsity passed to ``nn.init.sparse_`` for the PFC/HPC weights.
    """

    def __init__(self, input_size, hidden_size, output_size, batch_size, sparse):
        super(MyLSTM_feedforward_Thalamus2, self).__init__()

        self.hidden_size_PFC = hidden_size
        self.hidden_size_HPC = hidden_size
        self.hidden_size_Re = hidden_size+0
        self.hidden_size_THinh = hidden_size
        self.PFC = nn.LSTMCell(self.hidden_size_HPC+self.hidden_size_Re, self.hidden_size_PFC)
        # self.PFC = nn.LSTMCell(hidden_size*2, hidden_size)
        self.HPC = nn.LSTMCell(input_size+self.hidden_size_Re, self.hidden_size_HPC)
        self.Re = nn.Linear(self.hidden_size_PFC+self.hidden_size_HPC+self.hidden_size_THinh, self.hidden_size_Re)
        self.THinh = nn.Linear(self.hidden_size_Re, self.hidden_size_THinh)
        self.linear = nn.Linear(self.hidden_size_HPC, output_size)
        self.batch_size = batch_size
        self.ReLU = nn.ReLU()
        # Fraction of entries zeroed per column of the LSTM weight matrices.
        sparse = 0.1*int(sparse)
        nn.init.sparse_(self.PFC.weight_ih.data,sparse)
        nn.init.sparse_(self.HPC.weight_ih.data,sparse)
        nn.init.sparse_(self.PFC.weight_hh.data,sparse)
        nn.init.sparse_(self.HPC.weight_hh.data,sparse)

    def forward(self, input, hiddens):
        """Advance the network by one step.

        Args:
            input: external input, presumably ``(batch_size, input_size)``.
            hiddens: ``[PFC_hidden, HPC_hidden, Re_hidden, THinh_hidden]``
                from the previous step.

        Returns:
            ``(output, [PFC_hidden, HPC_hidden, Re_hidden, THinh_hidden])``.
        """
        Re_input = torch.cat([hiddens[0][0],hiddens[1][0],hiddens[3]],dim=1)
        Re_hidden = self.Re(Re_input)
        # THinh, HPC and PFC all read the *previous* Re activity (hiddens[2]).
        THinh_input = hiddens[2]
        THinh_hidden = self.ReLU(self.THinh(THinh_input))
        HPC_input = torch.cat([input,hiddens[2]],dim=1)
        HPC_hidden = self.HPC(HPC_input, hiddens[1])
        PFC_input = torch.cat([hiddens[1][0],hiddens[2]],dim=1)
        PFC_hidden = self.PFC(PFC_input, hiddens[0])
        output = self.linear(HPC_hidden[0])
        return output, [PFC_hidden, HPC_hidden, Re_hidden,THinh_hidden]

    def initHidden(self):
        """Return zero PFC/HPC/Re states and a uniform-[0, 1) THinh state."""
        # Use the per-region sizes: this class has no ``self.hidden_size``
        # attribute, so reading it raised AttributeError.
        HPC_hidden = [torch.zeros(self.batch_size, self.hidden_size_HPC), torch.zeros(self.batch_size, self.hidden_size_HPC)]
        PFC_hidden = [torch.zeros(self.batch_size, self.hidden_size_PFC), torch.zeros(self.batch_size, self.hidden_size_PFC)]
        Re_hidden = torch.zeros(self.batch_size, self.hidden_size_Re)
        THinh_hidden = self.ReLU(torch.rand(self.batch_size, self.hidden_size_THinh))
        return [PFC_hidden,HPC_hidden,Re_hidden,THinh_hidden]

    def initHidden_rand(self):
        """Return a hidden state with every tensor drawn uniformly from [0, 0.01)."""
        const = 0.01
        var = 0.01
        HPC_hidden = [torch.rand(self.batch_size, self.hidden_size_HPC)*const, torch.rand(self.batch_size, self.hidden_size_HPC)*const]
        PFC_hidden = [torch.rand(self.batch_size, self.hidden_size_PFC)*var, torch.rand(self.batch_size, self.hidden_size_PFC)*var]
        Re_hidden = torch.rand(self.batch_size, self.hidden_size_Re)*const
        THinh_hidden = self.ReLU(torch.rand(self.batch_size, self.hidden_size_THinh)*const)
        return [PFC_hidden,HPC_hidden,Re_hidden,THinh_hidden]

    def noiseHidden_rand(self, hiddens):
        """Add Gaussian noise (std 0.02) in place to PFC h, HPC h and Re.

        Cell states and the THinh state are left untouched. ``hiddens`` is
        modified in place and also returned.
        """
        c = 0.02
        v = 0.02
        HPC_hidden = torch.randn(self.batch_size, self.hidden_size_HPC)*c
        hiddens[1][0].add_(HPC_hidden)
        PFC_hidden = torch.randn(self.batch_size, self.hidden_size_PFC)*v
        hiddens[0][0].add_(PFC_hidden)
        Re_hidden = torch.randn(self.batch_size, self.hidden_size_Re)*c
        hiddens[2].add_(Re_hidden)
        return hiddens

    def hidden_size_list(self):
        """Return ``[hidden_size_PFC, hidden_size_HPC, hidden_size_Re, hidden_size_THinh]``."""
        return [self.hidden_size_PFC,self.hidden_size_HPC,self.hidden_size_Re,self.hidden_size_THinh]



    
    
class MyLSTM_RNN_uniPFC(nn.Module):
    """PFC/HPC/Re network where PFC receives HPC only and Re is an ``nn.RNNCell``.

    Per step (``forward``):
      * Re (RNNCell, PyTorch default tanh) is driven by the previous PFC and
        HPC hidden states ``h``, with its own recurrence;
      * HPC (LSTMCell) receives the external input plus the previous Re;
      * PFC (LSTMCell) receives only the previous HPC ``h``;
      * the output is a linear read-out of the new HPC ``h``.

    Hidden state layout: ``[PFC_hidden, HPC_hidden, Re_hidden]`` where
    PFC/HPC are ``(h, c)`` pairs and Re is a single tensor.

    Args:
        input_size: size of the external input fed to HPC.
        hidden_size: hidden size of every region.
        output_size: size of the linear read-out.
        batch_size: batch dimension used by the ``initHidden*`` helpers.
        sparse: converted with ``int()`` and multiplied by 0.1 to give the
            sparsity passed to ``nn.init.sparse_`` for all recurrent weights.
    """

    def __init__(self, input_size, hidden_size, output_size, batch_size, sparse=1):
        super(MyLSTM_RNN_uniPFC, self).__init__()

        self.hidden_size_PFC = hidden_size+0
        self.hidden_size_HPC = hidden_size+0
        self.hidden_size_Re = hidden_size+0
        self.PFC = nn.LSTMCell(self.hidden_size_HPC, self.hidden_size_PFC)
        # self.PFC = nn.LSTMCell(hidden_size*2, hidden_size)
        self.HPC = nn.LSTMCell(input_size+self.hidden_size_Re, self.hidden_size_HPC)
        self.Re = nn.RNNCell(self.hidden_size_HPC+self.hidden_size_PFC, self.hidden_size_Re)
        self.linear = nn.Linear(self.hidden_size_HPC, output_size)
        self.batch_size = batch_size
        # Fraction of entries zeroed per column of every recurrent weight matrix.
        sparse = 0.1*int(sparse)
        nn.init.sparse_(self.PFC.weight_ih.data,sparse)
        nn.init.sparse_(self.HPC.weight_ih.data,sparse)
        nn.init.sparse_(self.Re.weight_ih.data,sparse)
        nn.init.sparse_(self.PFC.weight_hh.data,sparse)
        nn.init.sparse_(self.HPC.weight_hh.data,sparse)
        nn.init.sparse_(self.Re.weight_hh.data,sparse)
        # nn.init.normal_(self.PFC.weight_ih.data,0,1/10)
        # nn.init.normal_(self.HPC.weight_ih.data,0,1/10)
        # nn.init.normal_(self.Re.weight_ih.data,0,0.1/10)
        # nn.init.normal_(self.PFC.weight_hh.data,0,1/10)
        # nn.init.normal_(self.HPC.weight_hh.data,0,1/10)
        # nn.init.normal_(self.Re.weight_hh.data,0,1/10)

    def forward(self, input, hiddens):
        """Advance the network by one step.

        Args:
            input: external input, presumably ``(batch_size, input_size)``.
            hiddens: ``[PFC_hidden, HPC_hidden, Re_hidden]`` from the
                previous step.

        Returns:
            ``(output, [PFC_hidden, HPC_hidden, Re_hidden])``.
        """
        Re_input = torch.cat([hiddens[0][0],hiddens[1][0]],dim=1)
        Re_hidden = self.Re(Re_input, hiddens[2])
        # HPC reads the *previous* Re state, not the one computed above.
        HPC_input = torch.cat([input,hiddens[2]],dim=1)
        HPC_hidden = self.HPC(HPC_input, hiddens[1])
        PFC_input = hiddens[1][0]
        PFC_hidden = self.PFC(PFC_input, hiddens[0])
        output = self.linear(HPC_hidden[0])
        return output, [PFC_hidden, HPC_hidden, Re_hidden]

    def initHidden(self):
        """Return an all-zero hidden state ``[PFC, HPC, Re]``."""
        # Use the per-region sizes: this class has no ``self.hidden_size``
        # attribute, so reading it raised AttributeError.
        HPC_hidden = [torch.zeros(self.batch_size, self.hidden_size_HPC), torch.zeros(self.batch_size, self.hidden_size_HPC)]
        PFC_hidden = [torch.zeros(self.batch_size, self.hidden_size_PFC), torch.zeros(self.batch_size, self.hidden_size_PFC)]
        Re_hidden = torch.zeros(self.batch_size, self.hidden_size_Re)
        return [PFC_hidden,HPC_hidden,Re_hidden]

    def initHidden_rand(self):
        """Return a hidden state ``[PFC, HPC, Re]`` drawn uniformly from [0, 0.01)."""
        const = 0.01
        var = 0.01
        HPC_hidden = [torch.rand(self.batch_size, self.hidden_size_HPC)*const, torch.rand(self.batch_size, self.hidden_size_HPC)*const]
        PFC_hidden = [torch.rand(self.batch_size, self.hidden_size_PFC)*var, torch.rand(self.batch_size, self.hidden_size_PFC)*var]
        Re_hidden = torch.rand(self.batch_size, self.hidden_size_Re)*const
        return [PFC_hidden,HPC_hidden,Re_hidden]

    def noiseHidden_rand(self, hiddens):
        """Add Gaussian noise in place: std 0.01 to HPC h, std 0.0 to PFC h and Re.

        The zero-std draws still consume the RNG. ``hiddens`` is modified in
        place and also returned.
        """
        c = 0.0
        v = 0.01
        HPC_hidden = torch.randn(self.batch_size, self.hidden_size_HPC)*v
        hiddens[1][0].add_(HPC_hidden)
        PFC_hidden = torch.randn(self.batch_size, self.hidden_size_PFC)*c
        hiddens[0][0].add_(PFC_hidden)
        Re_hidden = torch.randn(self.batch_size, self.hidden_size_Re)*c
        hiddens[2].add_(Re_hidden)
        return hiddens

    def hidden_size_list(self):
        """Return ``[hidden_size_PFC, hidden_size_HPC, hidden_size_Re]``."""
        return [self.hidden_size_PFC,self.hidden_size_HPC,self.hidden_size_Re]

    

def moving_average(data, window=5):
    """Return the ``window``-point moving average of a 1-D sequence.

    Uses ``np.convolve`` in its default 'full' mode, so the result has
    ``len(data) + window - 1`` samples and the first/last ``window - 1``
    values average over a partial window (zero-padded). Callers that want the
    series aligned with ``data`` for ``window=5`` slice ``[2:-2]``.

    Args:
        data: 1-D array-like of numbers.
        window: number of samples averaged; defaults to 5 (the previously
            hard-coded value).

    Raises:
        ValueError: if ``window`` is smaller than 1.
    """
    if window < 1:
        raise ValueError("window must be >= 1, got {}".format(window))
    kernel = np.ones(window)/window
    return np.convolve(data, kernel)

def main(num):
    """Load a trained PFC/HPC/Re model, roll it out, and plot/print diagnostics.

    Steps:
      1. Build a batch with ``mkOwnDataSet_auto`` / ``mkOwnRandomBatch``
         (``pattern`` picks which of the four returned sets is used).
      2. Load the weights at ``model_path`` into ``MyLSTM_RNN``.
      3. Drive the model with ``data[k]`` for ``data_limit`` steps, then run
         it closed-loop (its previous output is the next input) for
         ``data.shape[0]*2+10`` further steps, recording outputs, PFC/HPC/Re
         hidden states and ``Culc_gate`` gate values.
      4. Plot trajectories, per-neuron traces of the segments returned by
         ``pick_traj``, PCA projections and a histogram of HPC gate values,
         and print summary statistics.

    Args:
        num: run index; only referenced by the commented-out ``model_path``
            alternatives (the active path is hard-coded).

    Returns:
        None (also returned early when ``model_path`` does not exist).
    """
    # Imported locally: pandas (pd) and scipy.stats are used below but are not
    # imported in the notebook's import cell. Re-importing is a cheap no-op.
    import pandas as pd
    import scipy.stats

    training_size = 100
    test_size = 1000
    epochs_num = 10
    hidden_size = 20
    batch_size = 10
    data_length = 100
    inputsize = 2
    outputsize = 2
    # model_path = 'model/ReModel_L2_interRNNrand_Reinh_AddRe_OUT5_H121_s'+str(num)+'_100_2_2.pth'
    # model_path = 'model/ReModel_L2_interRNNrand_Reinh_AddRe_OUT5_H121_s2_100_1_2.pth'
    model_path = 'model/R20_H_bigbatch/ReModel_L2_interRNNrand_OUT1_121_s2_100_3_epoch180.pth'
#     model_path = 'model/R20_H_stopinit_bigbatch/ReModel_L2_interRNNrand_OUT1_121_s2_100_3_epoch10.pth'
#     model_path = 'model/R20_H_uniPFC_bigbatch/ReModel_L2_interRNNrand_OUT1_121_s3_100_1_epoch145.pth'
#     model_path = 'model/R20_H_uniHPC_bigbatch/ReModel_L2_interRNNrand_OUT1_121_s1_100_3_epoch190.pth'
#     model_path = 'model/R20_uniPFC_H/ReModel_L2_interRNNrand_OUT1_uniPFC_121H_s5_100_2_epoch85.pth'
#     model_path = 'model/R20_uniHPC_H/ReModel_L2_interRNNrand_OUT1_121H_s2_100_1_epoch100.pth'
#     model_path = 'model/R20_H_transfer2B/ReModel_L2_interRNNrand_OUT1_transfers14_121_s2_100_2_epoch60.pth'
#     model_path = 'model/R20FF_H_bigbatch/ReModel_L2_interRNNrand_OUT1_121_s4_100_3_epoch190.pth'
#     model_path = 'model/R20FF_H_bigbatch/ReModel_L2_interRNNrand_OUT1_121_s4_100_1_epoch100.pth'
#     model_path = 'model_test/R20_H_retrainv2_bigbatch/ReModel_L2_interRNNrand_OUT1_121_s9_100_2_epoch215.pth'
#     model_path = 'model/R20_feedReinhReLU_H_bigbatch/ReModel_L2_interRNNrand_OUT1_121_s8_100_1_epoch150.pth'


    sparse = 1
    delay_length = 2

    if os.path.exists(model_path):
        print(model_path)
    else:
        print("Not exist")
        return

    # train_x,train_y = mkOwnDataSet(training_size,data_length)
    train_x1,train_x2,train_y1,train_y2 = mkOwnDataSet_auto(training_size,delay_length)
#     test_x = mkOwnDataSet(test_size,data_length)

    rnn = MyLSTM_RNN(inputsize, hidden_size, outputsize, batch_size, sparse)
#     rnn = MyLSTM_RNN_uniPFC(inputsize, hidden_size, outputsize, batch_size, sparse)
#     rnn = MyLSTM_feedforward(inputsize, hidden_size, outputsize, batch_size, sparse)
#     rnn = MyLSTM_feedforward_Thalamus2(inputsize, hidden_size, outputsize, batch_size, sparse)

    # map_location="cpu" lets a checkpoint saved from GPU tensors load on a
    # CPU-only machine; this function never moves the model to a GPU.
    rnn.load_state_dict(torch.load(model_path, map_location="cpu"))

    pattern = 1
    if  pattern == 1:
        data = mkOwnRandomBatch(train_x1, train_x1, batch_size).float()
    elif pattern == 2:
        data = mkOwnRandomBatch(train_x2, train_x2, batch_size).float()
    elif pattern == 3:
        data = mkOwnRandomBatch(train_y1, train_y1, batch_size).float()
    else:
        data = mkOwnRandomBatch(train_y2, train_y2, batch_size).float()

    # Parameter tensors in rnn.parameters() order; Culc_gate indexes this
    # list positionally.
    params = [p.data for p in rnn.parameters()]

    traj = []
    PFCstate = []
    HPCstate = []
    Restate = []
    Gate_states = []
    hidden = rnn.initHidden_rand()
    data_limit = 120
    # Inference only: without no_grad every rollout step would extend one
    # autograd graph that stays alive for the whole run.
    with torch.no_grad():
        # Phase 1: drive the network with the batch inputs data[k].
        for k in range(data_limit):
            #print(data[k].shape)
            output,hidden = rnn(data[k],hidden)
            traj.append(output.tolist())
            PFCstate.append(hidden[0][0].tolist())
            HPCstate.append(hidden[1][0].tolist())
            Restate.append(hidden[2][0].tolist())
            Gate_states.append(Culc_gate(output,params,hidden))
#             Gate_states.append(Culc_gate_uniPFC(output,params,hidden))
        # Phase 2: closed loop -- the previous output is fed back as input.
        for k in range(data.shape[0]*2+10):
#             if output.tolist()[0][1]<0.05:
#                 hidden = rnn.noiseHidden_rand(hidden)
            output,hidden = rnn(output,hidden)
            traj.append(output.tolist())
            PFCstate.append(hidden[0][0].tolist())
            HPCstate.append(hidden[1][0].tolist())
            Restate.append(hidden[2][0].tolist())
#             print(output)
            Gate_states.append(Culc_gate(output,params,hidden))
#             Gate_states.append(Culc_gate_uniPFC(output,params,hidden))
    traj = torch.tensor(traj)
    pltdata = torch.squeeze(data).numpy()
    traj = torch.squeeze(traj).numpy()
    fig = plt.figure()
    print(pltdata.shape)
    plt.plot(pltdata[:,0,0],pltdata[:,0,1],"--")
    plt.plot(traj[:,0,0],traj[:,0,1])
    plt.show()

    # MakeAnimation2(pltdata[:,0,0],pltdata[:,0,1], traj[:,0,0], traj[:,0,1], data_limit)
    # MakeAnimation_img(np.array(PFCstate)[:,0],"PFC")
    # MakeAnimation_img(np.array(HPCstate)[:,0],"HPC")
    # MakeAnimation_img(np.array(Restate),"Re")
    #MakeAnimation_testdata(pltdata[:,0,0],pltdata[:,0,1])
    # MakeAnimation_data(pltdata[:,0,0],pltdata[:,0,1],data_limit)

    # for n, p in rnn.named_parameters():
    #         if n == "PFC.weight_ih":
    #             PFC_w = np.array(p.data)
    #         if n == "HPC.weight_ih":
    #             HPC_w = np.array(p.data)
    #         if n == "Re.weight_ih":
    #             Re_w = np.array(p.data)

    # # for n, p in rnn.named_parameters():
    # #         if n == "PFC.weight_hh":
    # #             PFC_w = np.array(p.data)
    # #         if n == "HPC.weight_hh":
    # #             HPC_w = np.array(p.data)
    # #         if n == "Re.weight_hh":
    # #             Re_w = np.array(p.data)

    # fig2 = plt.figure()
    # ax1 = fig2.add_subplot(131)
    # ax2 = fig2.add_subplot(132)
    # axre = fig2.add_subplot(133)

    # ax1.imshow(PFC_w,cmap="coolwarm")
    # ax2.imshow(HPC_w,cmap="coolwarm")
    # axre.imshow(Re_w,cmap="coolwarm")
    # ax1.set_title("max = {:.2f},min = {:.2f}".format(np.max(PFC_w),np.min(PFC_w)))
    # ax2.set_title("max = {:.2f},min = {:.2f}".format(np.max(HPC_w),np.min(HPC_w)))
    # axre.set_title("max = {:.2f},min = {:.2f}".format(np.max(Re_w),np.min(Re_w)))


    # fig3 = plt.figure()
    # ax3 = fig3.add_subplot(131)
    # ax4 = fig3.add_subplot(132)
    # axre2 = fig3.add_subplot(133)
    # ax3.imshow(np.corrcoef(np.array(PFCstate)[:,0]))
    # ax4.imshow(np.corrcoef(np.array(HPCstate)[:,0]))
    # axre2.imshow(np.corrcoef(np.array(Restate)))
    # ax3.set_title("PFC")
    # ax4.set_title("HPC")
    # axre2.set_title("Re")

    # Replot = plt.figure()
    # axre2 = Replot.add_subplot(111)
    # axre2.imshow(np.corrcoef(np.array(Restate)))
    # axre2.set_title("Re")

#     pca = PCA()
#     # dfs = np.concatenate((np.array(PFCstate)[:,0],np.array(HPCstate)[:,0],np.array(Restate)),axis=1)
#     # dfs = np.concatenate((np.array(PFCstate)[:,0],np.array(HPCstate)[:,0]),axis=1)
#     dfs = np.array(HPCstate)[:,0]
# #     dfs = np.array(Restate)
#     pca.fit(dfs)
#     feature = pca.transform(dfs)
    # print(pd.DataFrame(feature, columns=["PC{}".format(x + 1) for x in range(dfs.shape[1])]).head())
    # print(pd.DataFrame(pca.explained_variance_ratio_, index=["PC{}".format(x + 1) for x in range(dfs.shape[1])]))

    # print(pd.DataFrame(pca.components_, columns=["Hidden{}".format(x + 1) for x in range(dfs.shape[1])], index=["PC{}".format(x + 1) for x in range(dfs.shape[1])]))

    # pred = KMeans(n_clusters=2).fit_predict(feature)

    # plt.figure(figsize=(6, 6))
    # plt.scatter(feature[:, 0], feature[:, 1], alpha=0.8)
    # #plt.scatter(feature[100:200, 0], feature[100:200, 1], alpha=0.8)
    # plt.scatter(feature[0, 0], feature[0, 1], c="r", alpha=0.8)
    # plt.grid()
    # plt.xlabel("PC1")
    # plt.ylabel("PC2")
    # for i in range(np.min([200,data_limit+k])):
    #     plt.annotate(i,(feature[i, 0], feature[i, 1]))
    # plt.show()

    # # MakeAnimation_attracter(feature[:, 0], feature[:, 1])

#     fig3d = plt.figure()
#     ax3d = Axes3D(fig3d)
#     ax3d.plot(feature[:, 0], feature[:, 1], feature[:, 2], alpha=0.3)

    # fig_place = plt.figure()
    # place3d = Axes3D(fig_place)
    # place3d.plot(traj[:,0,0], traj[:,0,1], np.array(HPCstate)[:,0,0])

    # fig4 = plt.figure(figsize=(10,5))
    # ax5 = fig4.add_subplot(311)
    # ax6 = fig4.add_subplot(312)
    # axRe3 = fig4.add_subplot(313)
    # #ax5.plot(np.array(PFCstate)[:,0,pred==1])
    # ax5.imshow(np.array(PFCstate)[:200,0,:].T)
    # # ax5.plot(np.average(np.array(PFCstate)[:,0,pred==0],axis=1))
    # # ax5.plot(np.average(np.array(PFCstate)[:,0,pred==1],axis=1))
    # # ax5.plot(np.average(np.array(HPCstate)[:,0],axis=1))
    # ax6.imshow(np.array(HPCstate)[:200,0,:].T)
    # axRe3.imshow(np.array(Restate)[:200,:].T)
    # ax5.set_title("PFC")
    # ax6.set_title("HPC")
    # axRe3.set_title("Re")
    # print(pred==0)

    # # fig5 = plt.figure(figsize=(5,5))
    # # plt.plot(np.average(np.array(PFCstate)[:,0,pred==0],axis=1),np.average(np.array(PFCstate)[:,0,pred==1],axis=1))
    # # plt.show()
    # #MakeAnimation_attracter(np.average(np.array(PFCstate)[:,0,pred==0],axis=1),np.average(np.array(PFCstate)[:,0,pred==1],axis=1))


    # fig6 = plt.figure(figsize=(10,20))
    # plt.imshow(np.array(Gate_states)[:200,0,0:].T)
    # print(np.max(np.array(Gate_states)[:200,0,0:].T),np.min(np.array(Gate_states)[:,0,0:].T),np.average(np.array(Gate_states)[:,0,0:].T))
    # plt.show()
    # fig7 = plt.figure(figsize=(10,20))
    # plt.imshow(np.array(Gate_states)[:200,1,0:].T)
    # print(np.max(np.array(Gate_states)[:200,1,0:].T),np.min(np.array(Gate_states)[:,1,0:].T),np.average(np.array(Gate_states)[:,1,0:].T))
    # plt.show()

    # plt.figure()
    # plt.plot(np.array(Gate_states)[:,1,14])
    # plt.plot(np.array(Gate_states)[:,1,34])
    # plt.plot(np.array(Gate_states)[:,1,54])
    # plt.plot(np.array(Gate_states)[:,1,74])
    # plt.ylim(0,1)
    # plt.show()

    # plt.figure()
    # plt.plot(np.array(Gate_states)[:,1,13])
    # plt.plot(np.array(Gate_states)[:,1,33])
    # plt.plot(np.array(Gate_states)[:,1,53])
    # plt.plot(np.array(Gate_states)[:,1,73])
    # plt.ylim(0,1)
    # plt.show()


    # plt.figure(figsize=(20,20))
    # plt.plot(np.array(HPCstate)[:200,0,:10])
    # plt.show()

    #np.save("right_traj.npy",dfs)
    #np.save("left_traj.npy",dfs)

    # fig5 = plt.figure()
    # plt.hist(PFC_w[PFC_w.nonzero()],bins=400,range=(-2,2))
    # fig6 = plt.figure()
    # plt.hist(HPC_w[HPC_w.nonzero()],bins=400,range=(-2,2))
    # fig7 = plt.figure()
    # plt.hist(Re_w[Re_w.nonzero()],bins=400,range=(-2,2))


    # pca = PCA()
    # dfs = Re_w
    # pca.fit(dfs)
    # feature = pca.transform(dfs)
    # pred = KMeans(n_clusters=2).fit_predict(feature)

    # plt.figure(figsize=(6, 6))
    # plt.scatter(feature[:, 0], feature[:, 1], alpha=0.8)
    # plt.grid()
    # plt.xlabel("PC1")
    # plt.ylabel("PC2")

    # plt.show()

    plt.figure()
    plt.plot(traj[:,0,0])
    plt.plot(traj[:,0,1])

    # For each output segment, take output[...,0] at the step where it is
    # farthest from 0.5 and collect those values in `bifur`.
    delays = pick_traj(traj[:,0], np.array(traj)[:])[3:-1]
    bifur = np.array([])
    for segment in delays:
        check_bifur = np.argmax(np.abs(np.array(segment)[:,0,0] - 0.5))
        bifur = np.append(bifur,segment[check_bifur,:,0])
    plt.figure()
    plt.plot(traj[:,0,0])
    plt.plot(traj[:,0,1])
    if bifur.shape[0] == 0:
        # np.var/np.median of an empty array only produce NaN plus warnings.
        print("No segments found by pick_traj; skipping bifurcation statistics.")
    else:
        print(np.var(bifur),np.abs(np.median(bifur)-np.mean(bifur)))
    print(bifur)

    pca = PCA()
    # dfs = np.concatenate((np.array(PFCstate)[:,0],np.array(HPCstate)[:,0],np.array(Restate)),axis=1)
    # dfs = np.concatenate((np.array(PFCstate)[:,0],np.array(HPCstate)[:,0]),axis=1)
    dfs = np.array(HPCstate)[:,0]
#     dfs = np.array(Restate)
    pca.fit(dfs)
    feature = pca.transform(dfs)

    print(pd.DataFrame(pca.explained_variance_ratio_, index=["PC{}".format(x + 1) for x in range(dfs.shape[1])]))

    # Correlate the change in output[...,1] between steps [0, 120) and
    # [240, 360) with the change in HPC PC1 over the same windows.
    datalen = 120
    print(np.corrcoef(np.abs(traj[:datalen*1,0,1]-traj[datalen*2:datalen*3,0,1]),np.abs(feature[:datalen*1,0]-feature[datalen*2:datalen*3,0]))[0,1])

#     delays = pick_traj(traj_noise[:,0], np.array(PFCstate_noise)[:,0])
#     delays = pick_traj(traj[:,0], np.array(PFCstate)[:,0])
#     delays = pick_traj(traj_noise[:,0], np.array(HPCstate_noise)[:,0])
#     delays = pick_traj(traj[:,0], np.array(HPCstate)[:,0])
#     delays = pick_traj(traj_noise[:,0], np.array(Restate_noise)[:])
    delays = pick_traj(traj[:,0], np.array(Restate)[:])
#     delays = pick_delay(traj[:,0], np.array(Restate)[:])
#     delays = pick_traj(traj_noise[:,0], np.array(traj_noise)[:])[3:-1]
#     delays = pick_traj(traj[:,0], np.array(Gate_states)[:,0,0:20]-np.array(Gate_states)[:,0,20:40])
#     delays = pick_traj(traj[:,0], np.array(Gate_states)[:,0,0:20])
#     delays = pick_delay(traj[:,0], np.array(Gate_states)[:,1])
#     delays = pick_traj(traj_noise[:,0], np.array(Gate_states_noise)[:,1])
#     delays = pick_delay(traj_noise[:,0], np.array(Gate_states_noise)[:,0])
#     delays = pick_traj(traj[:,0], feature[:])
#     delays = pick_delay(traj[:,0], feature[:])

    bifur = np.array([])
    for i in range(20):
        plt.figure()
        for segment in delays[:-1]:
#             avedata = moving_average(segment[:,0+i])
#             plt.plot(avedata[2:-2],alpha=0.5)
#             plt.plot(segment[:,0+i]-avedata[2:-2],alpha=0.5)
            plt.plot(np.array(segment)[:,0+i],alpha=0.5)
#             plt.plot(np.array(segment)[:,0+i]-np.array(delays[1:-1][0])[:,0+i],alpha=0.5)
#             plt.plot(np.log(segment)[:,0+i],alpha=0.5)
#             print(len(segment),len(avedata))
        plt.title("neuron#"+str(i+1))
#     print(np.var(bifur),np.median(bifur)-np.mean(bifur))

#     for i in range(20):
#         plt.figure()
#         delays = delays[1:-1]
#         plt.plot(np.array(delays)[,:,0+i],alpha=0.5)
#         plt.title("neuron#"+str(i+1))
# #     print(np.var(bifur),np.median(bifur)-np.mean(bifur))

    # Truncate the inner segments (first and last dropped) to a common length
    # once -- this does not depend on the unit index -- then print, for
    # units 0..2, the across-segment variance averaged over time.
    inner_segments = delays[1:-1]
    data_len = 1000
    for segment in inner_segments:
        data_len = np.min([data_len,len(segment)])
    delays_samelen = np.array([segment[:data_len] for segment in inner_segments])
    for i in range(3):
        result = 0
        for k in range(data_len):
            result += np.var(delays_samelen[:,k,i])
        result /= data_len
        print(result)
#     print(np.var(bifur),np.median(bifur)-np.mean(bifur))


    pca = PCA()
    dfs = np.array(PFCstate)[0:,0]
#     dfs = np.array(Restate)[0:]
    pca.fit(dfs)
    PFCfeature = pca.transform(dfs)
    print(pd.DataFrame(pca.explained_variance_ratio_, index=["PC{}".format(x + 1) for x in range(dfs.shape[1])]))

    pca = PCA()
#     dfs = np.array(PFCstate)[0:,0]
    dfs = np.array(Restate)[0:]
    pca.fit(dfs)
    Refeature = pca.transform(dfs)
    print(pd.DataFrame(pca.explained_variance_ratio_, index=["PC{}".format(x + 1) for x in range(dfs.shape[1])]))

#     plt.figure()
    for i in range(20):
        plt.figure()
#         plt.ylim(-1,1)
        plt.plot(traj[:,0,0],alpha=0.5)
#         plt.plot(traj_noise[:,0,1],alpha=0.5)
#         plt.plot(np.array(PFCstate_noise)[:400,0,17],alpha=0.5)
        for k in range(1):
#             plt.plot(np.array(HPCstate)[:,k,i],alpha=0.5)
#             plt.plot(np.array(HPCstate_noise)[:,k,i],alpha=0.5)
#             plt.plot(np.array(HPCstate)[:,k,i]-moving_average(np.array(HPCstate)[2:-2,k,i])[:],alpha=0.3)
#             plt.plot(np.array(PFCstate)[:,k,i],alpha=0.5)
#             plt.plot(np.array(PFCstate_noise)[:,k,i],alpha=0.5)
#             plt.plot(np.array(PFCstate)[:120,k,i],alpha=0.5)
#             plt.plot(np.array(PFCstate)[120:240,k,i],alpha=0.5)
#             plt.plot(np.array(PFCstate)[:120,k,i]-np.array(PFCstate)[120:240,k,i],alpha=0.5)
#             plt.plot(np.array(PFCstate)[:,k,i]-moving_average(np.array(PFCstate)[2:-2,k,i])[:],alpha=0.3)
#             plt.plot(np.array(PFCstate_noise)[:120,k,i],alpha=0.5)
#             plt.plot(np.array(PFCstate_noise)[120:240,k,i],alpha=0.5)
#             plt.plot(np.array(PFCstate_noise)[:120,k,i]-np.array(PFCstate_noise)[120:240,k,i],alpha=0.5)
#             plt.plot(np.array(Restate)[:,i],alpha=0.5)
#             plt.plot(np.array(Restate_noise)[:,i],alpha=0.5)
#             plt.plot(np.array(Restate)[:,i]-np.array(Restate_noise)[:,i],alpha=0.5)
#             plt.plot(np.array(Restate)[:,i]-moving_average(np.array(Restate)[2:-2,i])[:],alpha=0.3)
#             plt.plot(moving_average(np.array(Restate_noise)[:,i]),alpha=0.5)
#             plt.plot(moving_average(np.array(Restate)[:,i])-moving_average(np.array(Restate_noise)[:,i]),alpha=0.5)
#             plt.plot(np.array(Restate)[:,i]-moving_average(np.array(Restate)[:,i])[4:-4],alpha=0.5)
#             plt.plot(np.array(Gate_states)[:,0,0+i],"o",alpha=0.5)
#             plt.plot(np.array(Gate_states)[:,1,0+i],"o",alpha=0.5)
#             plt.plot(np.array(Gate_states_noise)[:,1,0+i],alpha=0.5)
#             plt.plot(np.abs(np.array(Gate_states)[:,1,0+i]-np.array(Gate_states_noise)[:,1,0+i]),alpha=0.5)
#             plt.plot(np.log(np.array(Gate_states)[:,1,0+i]/np.array(Gate_states)[:,1,20+i]),alpha=0.5)
#             plt.plot(np.log(np.array(Gate_states_noise)[:,1,0+i]/np.array(Gate_states_noise)[:,1,20+i]),alpha=0.5)
            plt.plot(np.log(np.array(Gate_states)[:,0,0+i]/np.array(Gate_states)[:,0,20+i]),alpha=0.5)
#             plt.plot(np.log(np.array(Gate_states_noise)[:,0,0+i]/np.array(Gate_states_noise)[:,0,20+i]),alpha=0.5)
#             plt.plot(np.array(PFCfeature)[:,0],alpha=0.5)
#             plt.plot(np.array(PFCfeature)[:,0]-moving_average(np.array(PFCfeature)[2:-2,0]),alpha=0.5)
#             plt.plot(np.array(Refeature)[:,1],alpha=0.5)
#             plt.plot(np.array(Refeature)[:,1]-moving_average(np.array(Refeature)[:,1])[2:-2],alpha=0.5)

        plt.title("neuron#"+str(i+1))

#     fig = plt.figure()
#     fig1 = fig.add_subplot(111)
#     fig1.hist(np.array(Gate_states)[:,1,0:20].ravel(),bins=30,range=(-0.1,1.1),alpha=0.5,density=True)
#     plt.show()

    # Gate_states[:, 1] holds the HPC gates returned by Culc_gate; use the
    # first 20 entries.
    hpc_gate_values = np.array(Gate_states)[:,1,0:20].ravel()
    fig = plt.figure()
    fig1 = fig.add_subplot(111)
    hist = np.histogram(hpc_gate_values,bins=30,range=(-0.1,1.1),density=True)
    print(hist)
    plt.plot(hist[1][3:-2],hist[0][2:-2])
    plt.show()

    # Fit the beta distribution to the gate *samples*. Passing the histogram
    # heights (as before) treats densities as observations, which does not
    # describe the gate distribution. Culc_gate produces sigmoid outputs, so
    # the support is fixed to [0, 1]; samples are clipped away from the
    # endpoints, where the beta log-pdf diverges.
    eps = 1e-6
    beta = scipy.stats.beta.fit(np.clip(hpc_gate_values, eps, 1 - eps), floc=0, fscale=1)
    print(beta)


    pca = PCA()
    dfs = np.array(HPCstate)[0:,0]
#     dfs = np.array(Restate)[0:]
    pca.fit(dfs)
    feature = pca.transform(dfs)

#     pred = KMeans(n_clusters=2).fit_predict(feature)

    fig3d = plt.figure()
    # Axes3D(fig) no longer attaches itself to the figure in matplotlib >= 3.6
    # (the plot comes out blank); add_subplot(projection="3d") works across
    # versions.
    ax3d = fig3d.add_subplot(111, projection="3d")
#     ax3d.plot(feature[:100, 0], feature[:100, 1], feature[:100, 2], alpha=0.8)
#     ax3d.plot(feature[100:, 0], feature[100:, 1], feature[100:, 2], alpha=0.8)
    ax3d.plot(feature[:, 0], feature[:, 1], feature[:, 2], alpha=0.8)
    ax3d.plot(feature[0:1, 0], feature[0:1, 1], feature[0:1, 2],"o", alpha=1)
    plt.show()

    plt.figure()
    plt.plot(feature[:240, 0], feature[:240, 1], alpha=0.8)
    plt.show()

    print(pd.DataFrame(pca.explained_variance_ratio_, index=["PC{}".format(x + 1) for x in range(dfs.shape[1])]))
    
if __name__ == '__main__':
    # Run the analysis for run index 1 only. main() returns None, so the
    # collected list holds one None per completed run.
    features = []
    for run_idx in range(1, 2):
        features.append(main(run_idx))

In [None]:
    
class MyLSTM_RNN_uniPFC(nn.Module):
    """PFC/HPC/Re network where PFC receives HPC only and Re is an ``nn.RNNCell``.

    Per step (``forward``):
      * Re (RNNCell, PyTorch default tanh) is driven by the previous PFC and
        HPC hidden states ``h``, with its own recurrence;
      * HPC (LSTMCell) receives the external input plus the previous Re;
      * PFC (LSTMCell) receives only the previous HPC ``h``;
      * the output is a linear read-out of the new HPC ``h``.

    Hidden state layout: ``[PFC_hidden, HPC_hidden, Re_hidden]`` where
    PFC/HPC are ``(h, c)`` pairs and Re is a single tensor.

    Args:
        input_size: size of the external input fed to HPC.
        hidden_size: hidden size of every region.
        output_size: size of the linear read-out.
        batch_size: batch dimension used by the ``initHidden*`` helpers.
        sparse: converted with ``int()`` and multiplied by 0.1 to give the
            sparsity passed to ``nn.init.sparse_`` for all recurrent weights.
    """

    def __init__(self, input_size, hidden_size, output_size, batch_size, sparse=1):
        super(MyLSTM_RNN_uniPFC, self).__init__()

        self.hidden_size_PFC = hidden_size+0
        self.hidden_size_HPC = hidden_size+0
        self.hidden_size_Re = hidden_size+0
        self.PFC = nn.LSTMCell(self.hidden_size_HPC, self.hidden_size_PFC)
        # self.PFC = nn.LSTMCell(hidden_size*2, hidden_size)
        self.HPC = nn.LSTMCell(input_size+self.hidden_size_Re, self.hidden_size_HPC)
        self.Re = nn.RNNCell(self.hidden_size_HPC+self.hidden_size_PFC, self.hidden_size_Re)
        self.linear = nn.Linear(self.hidden_size_HPC, output_size)
        self.batch_size = batch_size
        # Fraction of entries zeroed per column of every recurrent weight matrix.
        sparse = 0.1*int(sparse)
        nn.init.sparse_(self.PFC.weight_ih.data,sparse)
        nn.init.sparse_(self.HPC.weight_ih.data,sparse)
        nn.init.sparse_(self.Re.weight_ih.data,sparse)
        nn.init.sparse_(self.PFC.weight_hh.data,sparse)
        nn.init.sparse_(self.HPC.weight_hh.data,sparse)
        nn.init.sparse_(self.Re.weight_hh.data,sparse)
        # nn.init.normal_(self.PFC.weight_ih.data,0,1/10)
        # nn.init.normal_(self.HPC.weight_ih.data,0,1/10)
        # nn.init.normal_(self.Re.weight_ih.data,0,0.1/10)
        # nn.init.normal_(self.PFC.weight_hh.data,0,1/10)
        # nn.init.normal_(self.HPC.weight_hh.data,0,1/10)
        # nn.init.normal_(self.Re.weight_hh.data,0,1/10)

    def forward(self, input, hiddens):
        """Advance the network by one step.

        Args:
            input: external input, presumably ``(batch_size, input_size)``.
            hiddens: ``[PFC_hidden, HPC_hidden, Re_hidden]`` from the
                previous step.

        Returns:
            ``(output, [PFC_hidden, HPC_hidden, Re_hidden])``.
        """
        Re_input = torch.cat([hiddens[0][0],hiddens[1][0]],dim=1)
        Re_hidden = self.Re(Re_input, hiddens[2])
        # HPC reads the *previous* Re state, not the one computed above.
        HPC_input = torch.cat([input,hiddens[2]],dim=1)
        HPC_hidden = self.HPC(HPC_input, hiddens[1])
        PFC_input = hiddens[1][0]
        PFC_hidden = self.PFC(PFC_input, hiddens[0])
        output = self.linear(HPC_hidden[0])
        return output, [PFC_hidden, HPC_hidden, Re_hidden]

    def initHidden(self):
        """Return an all-zero hidden state ``[PFC, HPC, Re]``."""
        # Use the per-region sizes: this class has no ``self.hidden_size``
        # attribute, so reading it raised AttributeError.
        HPC_hidden = [torch.zeros(self.batch_size, self.hidden_size_HPC), torch.zeros(self.batch_size, self.hidden_size_HPC)]
        PFC_hidden = [torch.zeros(self.batch_size, self.hidden_size_PFC), torch.zeros(self.batch_size, self.hidden_size_PFC)]
        Re_hidden = torch.zeros(self.batch_size, self.hidden_size_Re)
        return [PFC_hidden,HPC_hidden,Re_hidden]

    def initHidden_rand(self):
        """Return a hidden state ``[PFC, HPC, Re]`` drawn uniformly from [0, 0.01)."""
        const = 0.01
        var = 0.01
        HPC_hidden = [torch.rand(self.batch_size, self.hidden_size_HPC)*const, torch.rand(self.batch_size, self.hidden_size_HPC)*const]
        PFC_hidden = [torch.rand(self.batch_size, self.hidden_size_PFC)*var, torch.rand(self.batch_size, self.hidden_size_PFC)*var]
        Re_hidden = torch.rand(self.batch_size, self.hidden_size_Re)*const
        return [PFC_hidden,HPC_hidden,Re_hidden]

    def noiseHidden_rand(self, hiddens):
        """Add Gaussian noise (std 0.02) in place to PFC h, HPC h and Re.

        LSTM cell states are left untouched. ``hiddens`` is modified in place
        and also returned.
        """
        c = 0.02
        v = 0.02
        HPC_hidden = torch.randn(self.batch_size, self.hidden_size_HPC)*c
        hiddens[1][0].add_(HPC_hidden)
        PFC_hidden = torch.randn(self.batch_size, self.hidden_size_PFC)*v
        hiddens[0][0].add_(PFC_hidden)
        Re_hidden = torch.randn(self.batch_size, self.hidden_size_Re)*c
        hiddens[2].add_(Re_hidden)
        return hiddens

    def hidden_size_list(self):
        """Return ``[hidden_size_PFC, hidden_size_HPC, hidden_size_Re]``."""
        return [self.hidden_size_PFC,self.hidden_size_HPC,self.hidden_size_Re]
    
    
class MyLSTM_RNN_uniHPC(nn.Module):
    """PFC/HPC/Re circuit in which the Re RNN cell is driven by PFC only.

    Per step (each module reads the previous step's states):
      * Re  (RNNCell)  <- PFC h
      * HPC (LSTMCell) <- external input ++ Re state
      * PFC (LSTMCell) <- HPC h ++ Re state
    The output is a linear readout of the new HPC hidden state.

    Hidden-state layout: ``[PFC (h, c), HPC (h, c), Re]``.

    Args:
        input_size: width of the external input.
        hidden_size: width used for the PFC, HPC and Re modules.
        output_size: width of the linear readout.
        batch_size: batch size used by the ``initHidden*`` / noise helpers.
        sparse: sparsity in tenths. ``nn.init.sparse_`` zeroes a
            ``0.1 * int(sparse)`` fraction of each weight column.
    """

    def __init__(self, input_size, hidden_size, output_size, batch_size, sparse=1):
        super(MyLSTM_RNN_uniHPC, self).__init__()

        self.hidden_size_PFC = hidden_size+0
        self.hidden_size_HPC = hidden_size+0
        self.hidden_size_Re = hidden_size+0
        # Registration order (PFC, HPC, Re, linear) fixes the order of
        # self.parameters(); keep it stable.
        self.PFC = nn.LSTMCell(self.hidden_size_HPC+self.hidden_size_Re, self.hidden_size_PFC)
        self.HPC = nn.LSTMCell(input_size+self.hidden_size_Re, self.hidden_size_HPC)
        self.Re = nn.RNNCell(self.hidden_size_PFC, self.hidden_size_Re)
        self.linear = nn.Linear(self.hidden_size_HPC, output_size)
        self.batch_size = batch_size
        sparse = 0.1*int(sparse)
        nn.init.sparse_(self.PFC.weight_ih.data,sparse)
        nn.init.sparse_(self.HPC.weight_ih.data,sparse)
        nn.init.sparse_(self.Re.weight_ih.data,sparse)
        nn.init.sparse_(self.PFC.weight_hh.data,sparse)
        nn.init.sparse_(self.HPC.weight_hh.data,sparse)
        nn.init.sparse_(self.Re.weight_hh.data,sparse)

    def forward(self, input, hiddens):
        """One step; returns ``(output, [PFC, HPC, Re])``."""
        Re_input = hiddens[0][0]
        Re_hidden = self.Re(Re_input, hiddens[2])
        HPC_input = torch.cat([input,hiddens[2]],dim=1)
        HPC_hidden = self.HPC(HPC_input, hiddens[1])
        PFC_input = torch.cat([hiddens[1][0],hiddens[2]],dim=1)
        PFC_hidden = self.PFC(PFC_input, hiddens[0])
        output = self.linear(HPC_hidden[0])
        return output, [PFC_hidden, HPC_hidden, Re_hidden]

    def initHidden(self):
        """Return an all-zero initial state ``[PFC (h, c), HPC (h, c), Re]``.

        Uses the per-module widths. This class never sets a single
        ``self.hidden_size``, so reading it would raise AttributeError.
        """
        b = self.batch_size
        HPC_hidden = [torch.zeros(b, self.hidden_size_HPC), torch.zeros(b, self.hidden_size_HPC)]
        PFC_hidden = [torch.zeros(b, self.hidden_size_PFC), torch.zeros(b, self.hidden_size_PFC)]
        Re_hidden = torch.zeros(b, self.hidden_size_Re)
        return [PFC_hidden,HPC_hidden,Re_hidden]

    def initHidden_rand(self):
        """Return a U(0, 0.01) random initial state ``[PFC, HPC, Re]``."""
        const = 0.01
        var = 0.01
        HPC_hidden = [torch.rand(self.batch_size, self.hidden_size_HPC)*const, torch.rand(self.batch_size, self.hidden_size_HPC)*const]
        PFC_hidden = [torch.rand(self.batch_size, self.hidden_size_PFC)*var, torch.rand(self.batch_size, self.hidden_size_PFC)*var]
        Re_hidden = torch.rand(self.batch_size, self.hidden_size_Re)*const
        return [PFC_hidden,HPC_hidden,Re_hidden]

    def noiseHidden_rand(self, hiddens):
        """Add N(0, 0.02^2) noise in place to PFC h, HPC h and Re; return ``hiddens``."""
        c = 0.02
        v = 0.02
        HPC_hidden = torch.randn(self.batch_size, self.hidden_size_HPC)*c
        hiddens[1][0].add_(HPC_hidden)
        PFC_hidden = torch.randn(self.batch_size, self.hidden_size_PFC)*v
        hiddens[0][0].add_(PFC_hidden)
        Re_hidden = torch.randn(self.batch_size, self.hidden_size_Re)*c
        hiddens[2].add_(Re_hidden)
        return hiddens

    def hidden_size_list(self):
        """Return the module widths ``[PFC, HPC, Re]``."""
        return [self.hidden_size_PFC,self.hidden_size_HPC,self.hidden_size_Re]



class MyLSTM_RNN_uniPFCHPC(nn.Module):
    """PFC/HPC/Re circuit with one-way PFC->Re->HPC->PFC routing.

    Per step (each module reads the previous step's states):
      * Re  (RNNCell)  <- PFC h
      * HPC (LSTMCell) <- external input ++ Re state
      * PFC (LSTMCell) <- HPC h
    The output is a linear readout of the new HPC hidden state.

    Hidden-state layout: ``[PFC (h, c), HPC (h, c), Re]``.

    Args:
        input_size: width of the external input.
        hidden_size: width used for the PFC, HPC and Re modules.
        output_size: width of the linear readout.
        batch_size: batch size used by the ``initHidden*`` / noise helpers.
        sparse: sparsity in tenths. ``nn.init.sparse_`` zeroes a
            ``0.1 * int(sparse)`` fraction of each weight column.
    """

    def __init__(self, input_size, hidden_size, output_size, batch_size, sparse=1):
        super(MyLSTM_RNN_uniPFCHPC, self).__init__()

        self.hidden_size_PFC = hidden_size+0
        self.hidden_size_HPC = hidden_size+0
        self.hidden_size_Re = hidden_size+0
        # Registration order (PFC, HPC, Re, linear) fixes the order of
        # self.parameters(); keep it stable.
        self.PFC = nn.LSTMCell(self.hidden_size_HPC, self.hidden_size_PFC)
        self.HPC = nn.LSTMCell(input_size+self.hidden_size_Re, self.hidden_size_HPC)
        self.Re = nn.RNNCell(self.hidden_size_PFC, self.hidden_size_Re)
        self.linear = nn.Linear(self.hidden_size_HPC, output_size)
        self.batch_size = batch_size
        sparse = 0.1*int(sparse)
        nn.init.sparse_(self.PFC.weight_ih.data,sparse)
        nn.init.sparse_(self.HPC.weight_ih.data,sparse)
        nn.init.sparse_(self.Re.weight_ih.data,sparse)
        nn.init.sparse_(self.PFC.weight_hh.data,sparse)
        nn.init.sparse_(self.HPC.weight_hh.data,sparse)
        nn.init.sparse_(self.Re.weight_hh.data,sparse)

    def forward(self, input, hiddens):
        """One step; returns ``(output, [PFC, HPC, Re])``."""
        Re_input = hiddens[0][0]
        Re_hidden = self.Re(Re_input, hiddens[2])
        HPC_input = torch.cat([input,hiddens[2]],dim=1)
        HPC_hidden = self.HPC(HPC_input, hiddens[1])
        PFC_input = hiddens[1][0]
        PFC_hidden = self.PFC(PFC_input, hiddens[0])
        output = self.linear(HPC_hidden[0])
        return output, [PFC_hidden, HPC_hidden, Re_hidden]

    def initHidden(self):
        """Return an all-zero initial state ``[PFC (h, c), HPC (h, c), Re]``.

        Uses the per-module widths. This class never sets a single
        ``self.hidden_size``, so reading it would raise AttributeError.
        """
        b = self.batch_size
        HPC_hidden = [torch.zeros(b, self.hidden_size_HPC), torch.zeros(b, self.hidden_size_HPC)]
        PFC_hidden = [torch.zeros(b, self.hidden_size_PFC), torch.zeros(b, self.hidden_size_PFC)]
        Re_hidden = torch.zeros(b, self.hidden_size_Re)
        return [PFC_hidden,HPC_hidden,Re_hidden]

    def initHidden_rand(self):
        """Return a U(0, 0.01) random initial state ``[PFC, HPC, Re]``."""
        const = 0.01
        var = 0.01
        HPC_hidden = [torch.rand(self.batch_size, self.hidden_size_HPC)*const, torch.rand(self.batch_size, self.hidden_size_HPC)*const]
        PFC_hidden = [torch.rand(self.batch_size, self.hidden_size_PFC)*var, torch.rand(self.batch_size, self.hidden_size_PFC)*var]
        Re_hidden = torch.rand(self.batch_size, self.hidden_size_Re)*const
        return [PFC_hidden,HPC_hidden,Re_hidden]

    def noiseHidden_rand(self, hiddens):
        """Add N(0, 0.02^2) noise in place to PFC h, HPC h and Re; return ``hiddens``."""
        c = 0.02
        v = 0.02
        HPC_hidden = torch.randn(self.batch_size, self.hidden_size_HPC)*c
        hiddens[1][0].add_(HPC_hidden)
        PFC_hidden = torch.randn(self.batch_size, self.hidden_size_PFC)*v
        hiddens[0][0].add_(PFC_hidden)
        Re_hidden = torch.randn(self.batch_size, self.hidden_size_Re)*c
        hiddens[2].add_(Re_hidden)
        return hiddens

    def hidden_size_list(self):
        """Return the module widths ``[PFC, HPC, Re]``."""
        return [self.hidden_size_PFC,self.hidden_size_HPC,self.hidden_size_Re]


In [None]:
##########    Test for mixture weight!!!!!!!!!!!!!   #########################
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import time
import math
import matplotlib.pyplot as plt
import matplotlib.animation as animation
import sklearn 
import nolds
from sklearn.decomposition import PCA
from sklearn.cluster import KMeans
from mpl_toolkits.mplot3d import Axes3D
from scipy.special import expit
import pandas as pd
import glob
import scipy

%matplotlib notebook

def Culc_gate(input, params, hiddens):
    """Return sigmoid-squashed PFC/HPC LSTM gate pre-activations for batch item 0.

    For the first batch sample this recomputes
    ``W_ih @ x + W_hh @ h + b_ih + b_hh`` for the PFC and HPC cells and
    applies a sigmoid to every entry.

    Args:
        input: external input, indexed as ``(batch, features)``.
        params: parameter list. ``params[0:4]`` and ``params[4:8]`` are used
            as ``(weight_ih, weight_hh, bias_ih, bias_hh)`` of the PFC and HPC
            cells respectively (presumably ``list(model.parameters())`` of a
            model that registers PFC then HPC first; TODO confirm at call site).
        hiddens: ``[PFC (h, c), HPC (h, c), Re]``. PFC is driven by
            ``HPC h ++ Re`` and HPC by ``input ++ Re``.

    Returns:
        ``[PFC_gates, HPC_gates]`` as nested Python lists.

    Note:
        The recurrent term uses the hidden state ``h`` (``hiddens[k][0]``),
        matching ``nn.LSTMCell``. The cell state ``c`` does not enter the gate
        pre-activations. The sigmoid is applied to all four gate blocks,
        including the candidate block that LSTMCell passes through ``tanh``.
    """
    pfc_h = hiddens[0][0]
    hpc_h = hiddens[1][0]
    re_state = hiddens[2]
    PFC_input = torch.cat([hpc_h[0], re_state[0]])
    HPC_input = torch.cat([input, re_state], dim=1)[0]
    PFC_gates = params[0] @ PFC_input + params[1] @ pfc_h[0] + params[2] + params[3]
    HPC_gates = params[4] @ HPC_input + params[5] @ hpc_h[0] + params[6] + params[7]
    return [torch.sigmoid(PFC_gates).tolist(), torch.sigmoid(HPC_gates).tolist()]

def Culc_gate_uniPFC(input, params, hiddens):
    """Like ``Culc_gate`` but for models whose PFC is driven by HPC h only.

    For batch item 0, returns sigmoid(``W_ih @ x + W_hh @ h + b_ih + b_hh``)
    for the PFC (``x = HPC h``) and HPC (``x = input ++ Re``) cells.

    Args:
        input: external input, indexed as ``(batch, features)``.
        params: ``params[0:4]`` / ``params[4:8]`` are used as
            ``(weight_ih, weight_hh, bias_ih, bias_hh)`` of PFC / HPC
            (presumably ``list(model.parameters())``; TODO confirm).
        hiddens: ``[PFC (h, c), HPC (h, c), Re]``.

    Returns:
        ``[PFC_gates, HPC_gates]`` as nested Python lists.

    Note:
        The recurrent term uses the hidden state ``h``, as ``nn.LSTMCell``
        does, not the cell state ``c``. The sigmoid is applied to all four
        gate blocks.
    """
    pfc_h = hiddens[0][0]
    hpc_h = hiddens[1][0]
    PFC_input = hpc_h[0]
    HPC_input = torch.cat([input, hiddens[2]], dim=1)[0]
    PFC_gates = params[0] @ PFC_input + params[1] @ pfc_h[0] + params[2] + params[3]
    HPC_gates = params[4] @ HPC_input + params[5] @ hpc_h[0] + params[6] + params[7]
    return [torch.sigmoid(PFC_gates).tolist(), torch.sigmoid(HPC_gates).tolist()]

def Culc_gate_uniHPC(input, params, hiddens):
    """Gate read-out for models whose PFC is driven by ``HPC h ++ Re``.

    For batch item 0, returns sigmoid(``W_ih @ x + W_hh @ h + b_ih + b_hh``)
    for the PFC (``x = HPC h ++ Re``) and HPC (``x = input ++ Re``) cells.

    Args:
        input: external input, indexed as ``(batch, features)``.
        params: ``params[0:4]`` / ``params[4:8]`` are used as
            ``(weight_ih, weight_hh, bias_ih, bias_hh)`` of PFC / HPC
            (presumably ``list(model.parameters())``; TODO confirm).
        hiddens: ``[PFC (h, c), HPC (h, c), Re]``.

    Returns:
        ``[PFC_gates, HPC_gates]`` as nested Python lists.

    Note:
        The recurrent term uses the hidden state ``h``, as ``nn.LSTMCell``
        does, not the cell state ``c``. The sigmoid is applied to all four
        gate blocks.
    """
    pfc_h = hiddens[0][0]
    hpc_h = hiddens[1][0]
    re_state = hiddens[2]
    PFC_input = torch.cat([hpc_h[0], re_state[0]])
    HPC_input = torch.cat([input, re_state], dim=1)[0]
    PFC_gates = params[0] @ PFC_input + params[1] @ pfc_h[0] + params[2] + params[3]
    HPC_gates = params[4] @ HPC_input + params[5] @ hpc_h[0] + params[6] + params[7]
    return [torch.sigmoid(PFC_gates).tolist(), torch.sigmoid(HPC_gates).tolist()]
    
    

def MakeAnimation(data_x, data_y, traj_x, traj_y, data_limit,
                  filename='anim_v4_2_1_r.gif', interval=30):
    """Animate a generated trajectory against the target data and save a GIF.

    Frame ``t`` shows the trajectory prefix ``traj[:t]`` in red, with a red
    marker on its most recent point. While ``t < data_limit`` the data prefix
    ``data[:t]`` is also drawn in blue; after that only the trajectory is
    shown.

    Args:
        data_x, data_y: target coordinates.
        traj_x, traj_y: generated coordinates (arrays with ``.shape``); the
            number of frames is ``traj_x.shape[0]``.
        data_limit: frame index after which the target data is hidden.
        filename: output GIF path (default keeps the previous hard-coded name).
        interval: delay between frames in milliseconds.

    Returns:
        0 (kept for backward compatibility).
    """
    fig, ax = plt.subplots()
    ims = []
    for t in range(traj_x.shape[0]):
        # Most recent drawn point. Clamp at 0 so frame 0 does not wrap to
        # index -1 and put the marker on the trajectory's final point.
        head = max(t - 1, 0)
        if t < data_limit:
            traj = ax.plot(data_x[:t], data_y[:t], "b",
                           traj_x[:t], traj_y[:t], "r",
                           traj_x[head], traj_y[head], "or")
        else:
            traj = ax.plot(traj_x[:t], traj_y[:t], "r",
                           traj_x[head], traj_y[head], "or")
        title = ax.text(0.5, 1.01, 'time={}'.format(t),
                        ha='center', va='bottom',
                        transform=ax.transAxes, fontsize='large')
        ims.append(traj + [title])
    # Keep a reference until plt.show() so the animation is not collected.
    ani = animation.ArtistAnimation(fig, ims, interval=interval)
    ani.save(filename, writer='pillow')
    plt.show()

    return 0

def MakeAnimation_data(data_x, data_y, data_limit):
    """Save a growing blue-line animation of ``(data_x, data_y)`` to ``anim_Re_0.gif``.

    Axes are fixed to x in [0.2, 0.8] and y in [-0.05, 0.8]. Frame ``t`` draws
    the first ``t`` points. ``data_limit`` is accepted but not used.

    Returns:
        0.
    """
    fig, ax = plt.subplots()
    ax.set_xlim(0.2, 0.8)
    ax.set_ylim(-0.05, 0.8)

    frames = []
    n_frames = data_x.shape[0]
    for step in range(n_frames):
        line = ax.plot(data_x[:step], data_y[:step], "b")
        label = ax.text(0.5, 1.01, 'time={}'.format(step),
                        ha='center', va='bottom',
                        transform=ax.transAxes, fontsize='large')
        frames.append(line + [label])

    ani = animation.ArtistAnimation(fig, frames, interval=30)
    ani.save('anim_Re_0.gif', writer='pillow')
    plt.show()
    return 0

def MakeAnimation_attracter(data_x, data_y, filename='anim_attracter2.gif', interval=10):
    """Animate a 2-D trajectory (blue line plus current-point marker) and save a GIF.

    Args:
        data_x, data_y: coordinates (arrays with ``.shape``); the number of
            frames is ``data_x.shape[0]``.
        filename: output GIF path (default keeps the previous hard-coded name).
        interval: delay between frames in milliseconds.

    Returns:
        0 (kept for backward compatibility).
    """
    fig, ax = plt.subplots()
    ims = []
    for t in range(data_x.shape[0]):
        # Clamp so frame 0 does not wrap to index -1 (the last point).
        head = max(t - 1, 0)
        traj = ax.plot(data_x[:t], data_y[:t], "b", data_x[head], data_y[head], "ob")
        title = ax.text(0.5, 1.01, 'time={}'.format(t),
                        ha='center', va='bottom',
                        transform=ax.transAxes, fontsize='large')
        ims.append(traj + [title])
    # Keep a reference until plt.show() so the animation is not collected.
    ani = animation.ArtistAnimation(fig, ims, interval=interval)
    ani.save(filename, writer='pillow')
    plt.show()

    return 0


    
def MakeAnimation_img(data, module):
    """Animate each row of ``data`` as a 1-pixel-high image strip.

    Frame ``t`` shows ``data[t]`` reshaped to ``(1, data.shape[1])`` and is
    titled with ``module``. The result is saved to
    ``'anim_fire_' + module + '_.gif'``.

    Returns:
        0.
    """
    fig, ax = plt.subplots()
    width = data.shape[1]
    frames = []
    for step in range(data.shape[0]):
        strip = data[step].reshape(1, width)
        picture = ax.imshow(strip)
        label = ax.text(0.5, 1.01, module + ' time={}'.format(step),
                        ha='center', va='bottom',
                        transform=ax.transAxes, fontsize='large')
        frames.append([picture, label])

    ani = animation.ArtistAnimation(fig, frames)
    ani.save('anim_fire_' + module + '_.gif', writer='pillow')
    plt.show()
    return 0

def mkOwnDataSet(data_size, filename, data_length=100, freq=60., noise=0.005):
    """Build noisy, subsampled copies of a recorded right/left trajectory pair.

    Loads ``<filename>_r.csv`` and ``<filename>_l.csv`` (comma-separated; the
    first two columns are used). Each file is subsampled with stride
    ``round(n_rows / data_length)``, clamped to at least 1 so files shorter
    than ``data_length / 2`` rows still work. Independent N(0, noise^2) noise
    is added to every kept coordinate.

    Args:
        data_size: number of noisy copies to generate.
        filename: path prefix of the two CSV files.
        data_length: approximate number of samples to keep per copy.
        freq: unused; kept for call-site compatibility.
        noise: standard deviation of the added Gaussian noise.

    Returns:
        ``(train_x, train_y)``: lists of ``data_size`` sequences, each a list
        of ``[x, y]`` pairs.
    """
    # ndmin=2 keeps a single-row CSV indexable as rows.
    x = np.loadtxt(str(filename + "_r.csv"), delimiter=',', ndmin=2)
    y = np.loadtxt(str(filename + "_l.csv"), delimiter=',', ndmin=2)

    def _subsample_with_noise(arr):
        # A stride of 0 would make np.arange raise, so clamp it to 1.
        step = max(1, round(arr.shape[0] / data_length))
        return [[arr[i][0] + np.random.normal(loc=0.0, scale=noise),
                 arr[i][1] + np.random.normal(loc=0.0, scale=noise)]
                for i in np.arange(0, arr.shape[0], step)]

    train_x = []
    train_y = []
    for _ in range(data_size):
        train_x.append(_subsample_with_noise(x))
        train_y.append(_subsample_with_noise(y))
    return train_x, train_y


def mkOwnRandomBatch(train_x, train_t, batch_size=10):
    """Sample a random batch of sequences from ``train_x`` (with replacement).

    Args:
        train_x: non-empty list of sequences; each sequence is a list of
            equal-length feature vectors, and all sequences share one length.
        train_t: unused; kept for call-site compatibility (no targets are
            returned).
        batch_size: number of sequences to draw.

    Returns:
        ``torch.Tensor`` of shape ``(seq_len, batch_size, features)``.

    Raises:
        ValueError: if ``train_x`` is empty.
    """
    if len(train_x) == 0:
        raise ValueError("train_x must contain at least one sequence")
    # Every sequence is eligible. The upper bound is exclusive, so use
    # len(train_x), not len(train_x) - batch_size.
    batch_x = [train_x[np.random.randint(0, len(train_x))] for _ in range(batch_size)]
    return torch.tensor(batch_x).transpose(0, 1)

def make_Ttraj(direction):
    """Build a synthetic T-maze trajectory of 20 points (4 segments x 5).

    The path goes up the stem from (0.5, 0) to (0.5, 0.5), out along the arm
    to the tip and back at y = 0.5, then down the stem to (0.5, 0).

    Args:
        direction: 1 for the arm ending at x = 0.75, 2 for x = 0.25.

    Returns:
        ``ndarray`` of shape ``(20, 2)`` with columns ``(x, y)``.

    Raises:
        ValueError: if ``direction`` is not 1 or 2.
    """
    arm_tips = {1: 0.75, 2: 0.25}
    if direction not in arm_tips:
        raise ValueError(f"direction must be 1 or 2, got {direction!r}")
    tip = arm_tips[direction]

    traj_x = np.concatenate([
        np.ones(5) * 0.5,            # stem (up)
        np.linspace(0.5, tip, 5),    # out along the arm
        np.linspace(tip, 0.5, 5),    # back along the arm
        np.ones(5) * 0.5,            # stem (down)
    ])
    traj_y = np.concatenate([
        np.linspace(0, 0.5, 5),
        np.linspace(0.5, 0.5, 5),
        np.linspace(0.5, 0.5, 5),
        np.linspace(0.5, 0, 5),
    ])
    return np.concatenate([[traj_x], [traj_y]], axis=0).T

def make_Htraj(direction1, direction2):
    """Build a synthetic H-maze trajectory of 40 points.

    Segments (point counts in brackets):
      stem up from (0.5, 0) to (0.5, 0.5) [10];
      across to the side x_tip at y = 0.5 [5];
      along the side to y_tip [5] and back to y = 0.5 [5];
      back across to x = 0.5 [5];
      stem down to (0.5, 0) [10].

    Args:
        direction1: 1 for x_tip = 0.75, 2 for x_tip = 0.25.
        direction2: 1 for y_tip = 0.25, 2 for y_tip = 0.75.

    Returns:
        ``ndarray`` of shape ``(40, 2)`` with columns ``(x, y)``.

    Raises:
        ValueError: if either direction is not 1 or 2.
    """
    x_tips = {1: 0.75, 2: 0.25}
    y_tips = {1: 0.25, 2: 0.75}
    if direction1 not in x_tips:
        raise ValueError(f"direction1 must be 1 or 2, got {direction1!r}")
    if direction2 not in y_tips:
        raise ValueError(f"direction2 must be 1 or 2, got {direction2!r}")
    xt = x_tips[direction1]
    yt = y_tips[direction2]

    traj_x = np.concatenate([
        np.ones(10) * 0.5,
        np.linspace(0.5, xt, 5),
        np.linspace(xt, xt, 5),
        np.linspace(xt, xt, 5),
        np.linspace(xt, 0.5, 5),
        np.ones(10) * 0.5,
    ])
    traj_y = np.concatenate([
        np.linspace(0, 0.5, 10),
        np.linspace(0.5, 0.5, 5),
        np.linspace(0.5, yt, 5),
        np.linspace(yt, 0.5, 5),
        np.linspace(0.5, 0.5, 5),
        np.linspace(0.5, 0, 10),
    ])
    return np.concatenate([[traj_x], [traj_y]], axis=0).T


def make_traj(delay_length, stay_file="stay1.csv"):
    """Build the four H-maze task sequences.

    Each sequence follows a fixed schedule of four route codes. Codes 1-4 map
    to ``make_Htraj(1, 1)``, ``(1, 2)``, ``(2, 1)`` and ``(2, 2)``.
    ``delay_length`` stay slots (code 0) are inserted between consecutive
    routes. A route slot contributes its template plus fresh N(0, 0.005^2)
    noise. A stay slot contributes 40 rows sampled with replacement from the
    first two columns of ``stay_file``.

    Args:
        delay_length: number of stay slots between consecutive routes.
        stay_file: CSV of stay-position samples (default ``"stay1.csv"``).

    Returns:
        Four ``(N, 2)`` float arrays, one per schedule.
    """
    noise = 0.005
    data_length = 40  # rows sampled per stay slot

    templates = {
        1: make_Htraj(1, 1),
        2: make_Htraj(1, 2),
        3: make_Htraj(2, 1),
        4: make_Htraj(2, 2),
    }
    z = np.loadtxt(stay_file, delimiter=',', ndmin=2)

    def _build_order(codes):
        # Insert `delay_length` zeros (stay slots) between consecutive codes.
        order = np.array([codes[0]])
        for code in codes[1:]:
            order = np.append(order, np.zeros(delay_length))
            order = np.append(order, [code])
        return order

    # Active schedule ("ver1's test"). Earlier variants used two-route
    # sequences: ver1 [1,4],[2,3],[3,1],[4,2]; ver2 [1,2],[2,3],[3,4],[4,1].
    schedules = ([1, 4, 2, 3], [2, 3, 1, 4], [3, 1, 4, 2], [4, 2, 3, 1])
    orders = [_build_order(seq) for seq in schedules]

    data_list = []
    for order in orders:
        segments = []
        for idx in order:
            if idx == 0:
                picks = np.random.choice(z.shape[0], data_length)
                segment = np.array([[z[i][0], z[i][1]] for i in picks])
            else:
                template = templates[int(idx)]
                # Out-of-place add. `+=` would mutate the shared template, so
                # noise would accumulate each time a route is reused.
                segment = template + np.random.normal(loc=0.0, scale=noise, size=template.shape)
            segments.append(segment)
        data_list.append(np.concatenate(segments, axis=0).astype(float))
    return data_list[0], data_list[1], data_list[2], data_list[3]

def mkOwnDataSet_auto(data_size, delay_length, freq=60., noise=0.01):
    """Generate noisy copies of the four ``make_traj`` sequences.

    For each of ``data_size`` repetitions, every row of every sequence gets
    independent N(0, noise^2) noise on x and then y. Sequences are processed
    in the order x1, x2, y1, y2. ``freq`` is unused.

    Returns:
        ``(train_x1, train_x2, train_y1, train_y2)``: lists of ``data_size``
        sequences, each a list of ``[x, y]`` pairs.
    """
    sources = make_traj(delay_length)

    def _jittered(arr):
        return [[arr[i][0] + np.random.normal(loc=0.0, scale=noise),
                 arr[i][1] + np.random.normal(loc=0.0, scale=noise)]
                for i in np.arange(arr.shape[0])]

    buckets = ([], [], [], [])
    for _ in range(data_size):
        for bucket, arr in zip(buckets, sources):
            bucket.append(_jittered(arr))
    return buckets[0], buckets[1], buckets[2], buckets[3]

def plot_distance_bet2traj(traj1, traj2, linelist):
    """Plot the per-step Euclidean distance between two trajectories.

    Vertical lines are drawn at the x positions in ``linelist``, spanning the
    distance range. Steps are taken over the length of ``traj1``.

    Returns:
        List of per-step distances.
    """
    first = np.array(traj1)
    second = np.array(traj2)
    distances = [np.linalg.norm(first[k] - second[k]) for k in range(first.shape[0])]

    plt.figure()
    plt.plot(distances)
    plt.vlines(linelist, np.min(distances), np.max(distances))
    return distances

def distance_bet2traj(traj1, traj2):
    """Return the per-step Euclidean distance between two trajectories.

    Steps are taken over the length of ``traj1``.
    """
    first = np.array(traj1)
    second = np.array(traj2)
    return [np.linalg.norm(first[k] - second[k]) for k in range(first.shape[0])]

def plot_activity_bet2traj(traj1, traj2, linelist):
    """Plot units whose absolute activity difference between two runs is large.

    Computes ``|traj1[t] - traj2[t]|`` per step. Any unit (column) whose
    difference exceeds 0.5 somewhere in steps 20..104 has its index printed
    and its difference curve plotted. ``linelist`` is unused.

    Returns:
        List of per-step absolute-difference arrays.
    """
    first = np.array(traj1)
    second = np.array(traj2)
    threshold = 0.5
    diffs = [np.abs(first[k] - second[k]) for k in range(first.shape[0])]

    plt.figure()
    for unit, series in enumerate(np.array(diffs).T):
        window = series[20:105]
        if np.any(window > threshold):
            print(unit)
            plt.plot(series)
    return diffs

def search_delay(traj):
    """Return the steps where the y coordinate crosses the 0.45 level.

    A rise is recorded at the first step with ``y > 0.45`` and a fall at the
    first later step with ``y < 0.45``, alternating. Values exactly equal to
    0.45 never trigger a crossing.

    Returns:
        1-D float array of crossing indices.
    """
    crossings = []
    above = False
    for step, y in enumerate(traj[:, 1]):
        if not above and y > 0.45:
            crossings.append(step)
            above = True
        if above and y < 0.45:
            crossings.append(step)
            above = False
    return np.array(crossings, dtype=float)

def pick_delay(traj, states):
    """Split ``states`` into the low-y (delay) segments of ``traj``.

    Scans ``traj[:, 1]`` from step 5 onward. A "drop" is the first step with
    y < 0.2 and a "rise" is the next step with y > 0.2. Each drop..rise
    window of ``states`` is collected. The remainder of ``states`` after the
    last crossing is appended at the end.

    Args:
        traj: array indexed as ``traj[i, 1]`` (y coordinate per step).
        states: sequence aligned step-by-step with ``traj``.

    Returns:
        List of ``states`` slices.

    Note:
        The trailing remainder is now sliced from ``states``, like every
        other entry. It was previously sliced from ``traj``, which mixed
        positions into a list of state segments.
    """
    threshold = 0.2
    crossings = []
    below = False
    for i in range(traj.shape[0]):
        if i < 5:
            continue  # ignore the first few steps (presumably the run start; TODO confirm)
        if not below and traj[i, 1] < threshold:
            crossings.append(i)
            below = True
        if below and traj[i, 1] > threshold:
            crossings.append(i)
            below = False

    states_list = []
    start = 0
    for n, point in enumerate(crossings):
        if n % 2 != 0:  # odd entries are rises: close a drop..rise window
            states_list.append(states[start:point])
        start = point
    states_list.append(states[start:])
    return states_list

def pick_traj(traj, states):
    """Split ``states`` into one segment per trajectory cycle.

    Scans ``traj[:, 1]`` from step 5 onward for threshold crossings at
    0.2. A drop is y < 0.2; the following rise is y > 0.2. ``states`` is cut
    at each rise, so each segment runs from the previous rise (or step 0) up
    to the next rise. The remainder after the last rise is appended at the
    end.

    Args:
        traj: array indexed as ``traj[i, 1]`` (y coordinate per step).
        states: sequence aligned step-by-step with ``traj``.

    Returns:
        List of ``states`` slices.

    Note:
        The trailing remainder is now sliced from ``states``, like every
        other entry. It was previously sliced from ``traj``, which mixed
        positions into a list of state segments.
    """
    threshold = 0.2
    crossings = []
    below = False
    for i in range(traj.shape[0]):
        if i < 5:
            continue  # ignore the first few steps (presumably the run start; TODO confirm)
        if not below and traj[i, 1] < threshold:
            crossings.append(i)
            below = True
        if below and traj[i, 1] > threshold:
            crossings.append(i)
            below = False

    states_list = []
    start = 0
    for n, point in enumerate(crossings):
        if n % 2 != 0:  # odd entries are rises: one full cycle ends here
            states_list.append(states[start:point])
            start = point
    states_list.append(states[start:])
    return states_list

def match_length(states_list, max_length=1000):
    """Trim every sequence to a common length, keeping each sequence's tail.

    The common length is the shortest sequence length, capped at
    ``max_length``. The cap of 1000 matches the previous implicit behaviour.

    Args:
        states_list: iterable of sliceable sequences (lists, arrays, ...).
        max_length: upper bound on the common length.

    Returns:
        List of the trimmed sequences. A zero common length yields empty
        slices. The old ``state[-0:]`` form returned the whole sequence in
        that case.
    """
    states_list = list(states_list)
    target = min([max_length] + [len(state) for state in states_list])
    return [state[len(state) - target:] for state in states_list]

def vec_var(datas):
    """Return the mean Euclidean distance of the samples from their centroid.

    Despite the name, this is a dispersion measure (mean norm of the
    deviations), not a variance. Each sample along axis 0 is compared with
    the axis-0 average.
    """
    samples = np.array(datas)
    centroid = np.average(samples, axis=0)
    total = 0
    for sample in samples:
        total += np.linalg.norm(sample - centroid)
    return total / samples.shape[0]
    
def lyapunov_exp(data):
    """Return the mean of ``log|Δdata|`` over consecutive samples.

    This is a crude stretching-rate statistic, not a full Lyapunov-exponent
    estimate. A zero difference contributes ``-inf`` (NumPy emits a warning).
    """
    step_sizes = np.abs(np.diff(data))
    return np.mean(np.log(step_sizes))

class MyLSTM(nn.Module):
    """PFC/HPC/Re circuit built from three LSTM cells.

    Per step (each module reads the previous step's states):
      * Re  <- PFC h ++ HPC h
      * HPC <- external input ++ Re h
      * PFC <- HPC h ++ Re h
    The output is a linear readout of the new HPC hidden state. All three
    states, including Re, are ``(h, c)`` pairs.
    """

    def __init__(self, input_size, hidden_size, output_size, batch_size):
        super(MyLSTM, self).__init__()

        self.hidden_size = hidden_size
        width2 = hidden_size * 2
        # Registration order (PFC, HPC, Re, linear) fixes the order of
        # self.parameters(); keep it stable.
        self.PFC = nn.LSTMCell(width2, hidden_size)
        self.HPC = nn.LSTMCell(input_size + hidden_size, hidden_size)
        self.Re = nn.LSTMCell(width2, hidden_size)
        self.linear = nn.Linear(hidden_size, output_size)
        self.batch_size = batch_size
        # NOTE(review): sparsity=1 makes nn.init.sparse_ zero every entry of
        # these matrices, so PFC/HPC start with no input/recurrent weights
        # (biases only). Sibling classes use 0.1 * int(sparse); confirm this
        # is intended.
        for weight in (self.PFC.weight_ih, self.HPC.weight_ih,
                       self.PFC.weight_hh, self.HPC.weight_hh):
            nn.init.sparse_(weight.data, 1)

    def forward(self, input, hiddens):
        """One step; returns ``(output, [PFC (h, c), HPC (h, c), Re (h, c)])``."""
        pfc_prev = hiddens[0]
        hpc_prev = hiddens[1]
        re_prev = hiddens[2]

        re_next = self.Re(torch.cat([pfc_prev[0], hpc_prev[0]], dim=1), re_prev)
        hpc_next = self.HPC(torch.cat([input, re_prev[0]], dim=1), hpc_prev)
        pfc_next = self.PFC(torch.cat([hpc_prev[0], re_prev[0]], dim=1), pfc_prev)

        return self.linear(hpc_next[0]), [pfc_next, hpc_next, re_next]

    def initHidden(self):
        """Return all-zero ``(h, c)`` pairs for ``[PFC, HPC, Re]``."""
        def _zero_pair():
            return [torch.zeros(self.batch_size, self.hidden_size),
                    torch.zeros(self.batch_size, self.hidden_size)]

        hpc_state = _zero_pair()
        pfc_state = _zero_pair()
        re_state = _zero_pair()
        return [pfc_state, hpc_state, re_state]
    
    
    
class MyLSTM_feedforward(nn.Module):
    """PFC/HPC circuit with a feed-forward (linear) Re relay.

    Per step (each module reads the previous step's states):
      * Re  (Linear)   <- PFC h ++ HPC h
      * HPC (LSTMCell) <- external input ++ Re
      * PFC (LSTMCell) <- HPC h ++ Re
    The output is a linear readout of the new HPC hidden state.

    Hidden-state layout: ``[PFC (h, c), HPC (h, c), Re]``.

    Args:
        input_size: width of the external input.
        hidden_size: width used for the PFC, HPC and Re modules.
        output_size: width of the linear readout.
        batch_size: batch size used by the ``initHidden*`` / noise helpers.
        sparse: sparsity in tenths. ``nn.init.sparse_`` zeroes a
            ``0.1 * int(sparse)`` fraction of each LSTM weight column.
    """

    def __init__(self, input_size, hidden_size, output_size, batch_size, sparse=5):
        super(MyLSTM_feedforward, self).__init__()

        self.hidden_size_PFC = hidden_size
        self.hidden_size_HPC = hidden_size
        self.hidden_size_Re = hidden_size+0
        # Registration order (PFC, HPC, Re, linear) fixes the order of
        # self.parameters(); keep it stable.
        self.PFC = nn.LSTMCell(self.hidden_size_HPC+self.hidden_size_Re, self.hidden_size_PFC)
        self.HPC = nn.LSTMCell(input_size+self.hidden_size_Re, self.hidden_size_HPC)
        self.Re = nn.Linear(self.hidden_size_HPC+self.hidden_size_PFC, self.hidden_size_Re)
        self.linear = nn.Linear(self.hidden_size_HPC, output_size)
        self.batch_size = batch_size
        sparse = 0.1*int(sparse)
        nn.init.sparse_(self.PFC.weight_ih.data,sparse)
        nn.init.sparse_(self.HPC.weight_ih.data,sparse)
        nn.init.sparse_(self.PFC.weight_hh.data,sparse)
        nn.init.sparse_(self.HPC.weight_hh.data,sparse)

    def forward(self, input, hiddens):
        """One step; returns ``(output, [PFC, HPC, Re])``."""
        Re_input = torch.cat([hiddens[0][0],hiddens[1][0]],dim=1)
        Re_hidden = self.Re(Re_input)
        HPC_input = torch.cat([input,hiddens[2]],dim=1)
        HPC_hidden = self.HPC(HPC_input, hiddens[1])
        PFC_input = torch.cat([hiddens[1][0],hiddens[2]],dim=1)
        PFC_hidden = self.PFC(PFC_input, hiddens[0])
        output = self.linear(HPC_hidden[0])
        return output, [PFC_hidden, HPC_hidden, Re_hidden]

    def initHidden(self):
        """Return an all-zero initial state ``[PFC (h, c), HPC (h, c), Re]``.

        Uses the per-module widths. This class never sets a single
        ``self.hidden_size``, so reading it would raise AttributeError.
        """
        b = self.batch_size
        HPC_hidden = [torch.zeros(b, self.hidden_size_HPC), torch.zeros(b, self.hidden_size_HPC)]
        PFC_hidden = [torch.zeros(b, self.hidden_size_PFC), torch.zeros(b, self.hidden_size_PFC)]
        Re_hidden = torch.zeros(b, self.hidden_size_Re)
        return [PFC_hidden,HPC_hidden,Re_hidden]

    def initHidden_rand(self):
        """Return a U(0, 0.01) random initial state ``[PFC, HPC, Re]``."""
        const = 0.01
        var = 0.01
        HPC_hidden = [torch.rand(self.batch_size, self.hidden_size_HPC)*const, torch.rand(self.batch_size, self.hidden_size_HPC)*const]
        PFC_hidden = [torch.rand(self.batch_size, self.hidden_size_PFC)*var, torch.rand(self.batch_size, self.hidden_size_PFC)*var]
        Re_hidden = torch.rand(self.batch_size, self.hidden_size_Re)*const
        return [PFC_hidden,HPC_hidden,Re_hidden]

    def noiseHidden_rand(self, hiddens):
        """Add N(0, 0.02^2) noise in place to PFC h, HPC h and Re; return ``hiddens``."""
        c = 0.02
        v = 0.02
        HPC_hidden = torch.randn(self.batch_size, self.hidden_size_HPC)*c
        hiddens[1][0].add_(HPC_hidden)
        PFC_hidden = torch.randn(self.batch_size, self.hidden_size_PFC)*v
        hiddens[0][0].add_(PFC_hidden)
        Re_hidden = torch.randn(self.batch_size, self.hidden_size_Re)*c
        hiddens[2].add_(Re_hidden)
        return hiddens

    def hidden_size_list(self):
        """Return the module widths ``[PFC, HPC, Re]``."""
        return [self.hidden_size_PFC,self.hidden_size_HPC,self.hidden_size_Re]
    

class MyLSTM_feedforward_Thalamus2(nn.Module):
    """Feed-forward Re relay plus an inhibitory thalamic (THinh) population.

    Per step (each module reads the previous step's states):
      * Re    (Linear)        <- PFC h ++ HPC h ++ THinh
      * THinh (Linear + ReLU) <- Re
      * HPC   (LSTMCell)      <- external input ++ Re
      * PFC   (LSTMCell)      <- HPC h ++ Re
    The output is a linear readout of the new HPC hidden state.

    Hidden-state layout: ``[PFC (h, c), HPC (h, c), Re, THinh]``.

    Args:
        input_size: width of the external input.
        hidden_size: width used for every module.
        output_size: width of the linear readout.
        batch_size: batch size used by the ``initHidden*`` / noise helpers.
        sparse: sparsity in tenths. ``nn.init.sparse_`` zeroes a
            ``0.1 * int(sparse)`` fraction of each LSTM weight column.
    """

    def __init__(self, input_size, hidden_size, output_size, batch_size, sparse=5):
        super(MyLSTM_feedforward_Thalamus2, self).__init__()

        self.hidden_size_PFC = hidden_size
        self.hidden_size_HPC = hidden_size
        self.hidden_size_Re = hidden_size+0
        self.hidden_size_THinh = hidden_size
        # Registration order fixes the order of self.parameters(); keep it stable.
        self.PFC = nn.LSTMCell(self.hidden_size_HPC+self.hidden_size_Re, self.hidden_size_PFC)
        self.HPC = nn.LSTMCell(input_size+self.hidden_size_Re, self.hidden_size_HPC)
        self.Re = nn.Linear(self.hidden_size_PFC+self.hidden_size_HPC+self.hidden_size_THinh, self.hidden_size_Re)
        self.THinh = nn.Linear(self.hidden_size_Re, self.hidden_size_THinh)
        self.linear = nn.Linear(self.hidden_size_HPC, output_size)
        self.batch_size = batch_size
        self.ReLU = nn.ReLU()
        sparse = 0.1*int(sparse)
        nn.init.sparse_(self.PFC.weight_ih.data,sparse)
        nn.init.sparse_(self.HPC.weight_ih.data,sparse)
        nn.init.sparse_(self.PFC.weight_hh.data,sparse)
        nn.init.sparse_(self.HPC.weight_hh.data,sparse)

    def forward(self, input, hiddens):
        """One step; returns ``(output, [PFC, HPC, Re, THinh])``."""
        Re_input = torch.cat([hiddens[0][0],hiddens[1][0],hiddens[3]],dim=1)
        Re_hidden = self.Re(Re_input)
        THinh_input = hiddens[2]
        THinh_hidden = self.ReLU(self.THinh(THinh_input))
        HPC_input = torch.cat([input,hiddens[2]],dim=1)
        HPC_hidden = self.HPC(HPC_input, hiddens[1])
        PFC_input = torch.cat([hiddens[1][0],hiddens[2]],dim=1)
        PFC_hidden = self.PFC(PFC_input, hiddens[0])
        output = self.linear(HPC_hidden[0])
        return output, [PFC_hidden, HPC_hidden, Re_hidden,THinh_hidden]

    def initHidden(self):
        """Return the initial state ``[PFC (h, c), HPC (h, c), Re, THinh]``.

        PFC/HPC/Re are zeros and THinh is ReLU(U(0, 1)). Uses the per-module
        widths. This class never sets a single ``self.hidden_size``, so
        reading it would raise AttributeError.
        """
        b = self.batch_size
        HPC_hidden = [torch.zeros(b, self.hidden_size_HPC), torch.zeros(b, self.hidden_size_HPC)]
        PFC_hidden = [torch.zeros(b, self.hidden_size_PFC), torch.zeros(b, self.hidden_size_PFC)]
        Re_hidden = torch.zeros(b, self.hidden_size_Re)
        THinh_hidden = self.ReLU(torch.rand(b, self.hidden_size_THinh))
        return [PFC_hidden,HPC_hidden,Re_hidden,THinh_hidden]

    def initHidden_rand(self):
        """Return a U(0, 0.01) random initial state ``[PFC, HPC, Re, THinh]``."""
        const = 0.01
        var = 0.01
        HPC_hidden = [torch.rand(self.batch_size, self.hidden_size_HPC)*const, torch.rand(self.batch_size, self.hidden_size_HPC)*const]
        PFC_hidden = [torch.rand(self.batch_size, self.hidden_size_PFC)*var, torch.rand(self.batch_size, self.hidden_size_PFC)*var]
        Re_hidden = torch.rand(self.batch_size, self.hidden_size_Re)*const
        THinh_hidden = self.ReLU(torch.rand(self.batch_size, self.hidden_size_THinh)*const)
        return [PFC_hidden,HPC_hidden,Re_hidden,THinh_hidden]

    def noiseHidden_rand(self, hiddens):
        """Add N(0, 0.02^2) noise in place to PFC h, HPC h and Re (not THinh); return ``hiddens``."""
        c = 0.02
        v = 0.02
        HPC_hidden = torch.randn(self.batch_size, self.hidden_size_HPC)*c
        hiddens[1][0].add_(HPC_hidden)
        PFC_hidden = torch.randn(self.batch_size, self.hidden_size_PFC)*v
        hiddens[0][0].add_(PFC_hidden)
        Re_hidden = torch.randn(self.batch_size, self.hidden_size_Re)*c
        hiddens[2].add_(Re_hidden)
        return hiddens

    def hidden_size_list(self):
        """Return the module widths ``[PFC, HPC, Re, THinh]``."""
        return [self.hidden_size_PFC,self.hidden_size_HPC,self.hidden_size_Re,self.hidden_size_THinh]

    
    

class MyLSTM_RNN(nn.Module):
    """PFC and HPC LSTM cells coupled through an Re ``nn.RNNCell``; readout from HPC.

    Hidden-state layout shared by every method: ``[PFC, HPC, Re]`` where PFC
    and HPC are ``[h, c]`` pairs (``nn.LSTMCell`` state) and Re is a single
    tensor (``nn.RNNCell`` state), each of shape ``(batch_size, width)``.

    Args:
        input_size: features per time step fed to HPC.
        hidden_size: width used for all three regions.
        output_size: width of the linear readout applied to HPC ``h``.
        batch_size: batch dimension used by the ``initHidden*``/``noise*`` helpers.
        sparse: integer in tenths; ``0.1 * int(sparse)`` is passed as the
            ``sparsity`` argument of ``nn.init.sparse_`` for both LSTM cells'
            input and recurrent weights.
    """

    def __init__(self, input_size, hidden_size, output_size, batch_size, sparse=5):
        super(MyLSTM_RNN, self).__init__()

        self.hidden_size_PFC = hidden_size
        self.hidden_size_HPC = hidden_size
        self.hidden_size_Re = hidden_size + 0
        # Common width, kept for callers that expect a single ``hidden_size``
        # attribute (plain int, so it does not affect the state_dict).
        self.hidden_size = hidden_size
        # Submodule registration order (PFC, HPC, Re, linear) fixes the order of
        # ``parameters()``, which external code indexes positionally.
        # PFC reads [HPC h, Re]; HPC reads [input, Re]; Re reads [PFC h, HPC h].
        self.PFC = nn.LSTMCell(self.hidden_size_HPC + self.hidden_size_Re, self.hidden_size_PFC)
        self.HPC = nn.LSTMCell(input_size + self.hidden_size_Re, self.hidden_size_HPC)
        self.Re = nn.RNNCell(self.hidden_size_HPC + self.hidden_size_PFC, self.hidden_size_Re)
        self.linear = nn.Linear(self.hidden_size_HPC, output_size)
        self.batch_size = batch_size
        sparse = 0.1 * int(sparse)
        nn.init.sparse_(self.PFC.weight_ih.data, sparse)
        nn.init.sparse_(self.HPC.weight_ih.data, sparse)
        nn.init.sparse_(self.PFC.weight_hh.data, sparse)
        nn.init.sparse_(self.HPC.weight_hh.data, sparse)

    def forward(self, input, hiddens):
        """Advance all three regions by one time step.

        Every region reads the *previous* states in ``hiddens``; in particular
        the new Re state only reaches HPC/PFC on the next call.

        Returns:
            ``(output, [PFC_hidden, HPC_hidden, Re_hidden])`` where ``output`` is
            the linear readout of the new HPC ``h``.
        """
        input = input.float()
        Re_input = torch.cat([hiddens[0][0], hiddens[1][0]], dim=1)
        Re_hidden = self.Re(Re_input, hiddens[2])
        HPC_input = torch.cat([input, hiddens[2]], dim=1)
        HPC_hidden = self.HPC(HPC_input, hiddens[1])
        PFC_input = torch.cat([hiddens[1][0], hiddens[2]], dim=1)
        PFC_hidden = self.PFC(PFC_input, hiddens[0])
        output = self.linear(HPC_hidden[0])
        return output, [PFC_hidden, HPC_hidden, Re_hidden]

    def initHidden(self):
        """Return an all-zero initial state ``[PFC, HPC, Re]``.

        Uses the per-region widths; the previous version read
        ``self.hidden_size``, which ``__init__`` never defined.
        """
        bs = self.batch_size
        HPC_hidden = [torch.zeros(bs, self.hidden_size_HPC), torch.zeros(bs, self.hidden_size_HPC)]
        PFC_hidden = [torch.zeros(bs, self.hidden_size_PFC), torch.zeros(bs, self.hidden_size_PFC)]
        Re_hidden = torch.zeros(bs, self.hidden_size_Re)
        return [PFC_hidden, HPC_hidden, Re_hidden]

    def initHidden_rand(self):
        """Return a random initial state: ``torch.rand(...) * 0.1`` for every tensor."""
        const = 0.1
        var = 0.1
        HPC_hidden = [torch.rand(self.batch_size, self.hidden_size_HPC) * const, torch.rand(self.batch_size, self.hidden_size_HPC) * const]
        PFC_hidden = [torch.rand(self.batch_size, self.hidden_size_PFC) * var, torch.rand(self.batch_size, self.hidden_size_PFC) * var]
        Re_hidden = torch.rand(self.batch_size, self.hidden_size_Re) * const
        return [PFC_hidden, HPC_hidden, Re_hidden]

    def initHidden_test(self):
        """Return a random initial state: ``torch.rand(...) * 0.2`` for every tensor."""
        const = 0.2
        var = 0.2
        HPC_hidden = [torch.rand(self.batch_size, self.hidden_size_HPC) * const, torch.rand(self.batch_size, self.hidden_size_HPC) * const]
        PFC_hidden = [torch.rand(self.batch_size, self.hidden_size_PFC) * var, torch.rand(self.batch_size, self.hidden_size_PFC) * var]
        Re_hidden = torch.rand(self.batch_size, self.hidden_size_Re) * var
        return [PFC_hidden, HPC_hidden, Re_hidden]

    def noiseHidden_rand(self, hiddens):
        """Add Gaussian noise in place, effectively only to the Re state.

        HPC h and PFC h receive ``randn * 0`` (the draws still advance the torch
        RNG); Re receives ``randn * 0.1``. Returns the mutated ``hiddens``.
        """
        c = 0
        v = 0.1
        HPC_hidden = torch.randn(self.batch_size, self.hidden_size_HPC) * c
        hiddens[1][0].add_(HPC_hidden)
        PFC_hidden = torch.randn(self.batch_size, self.hidden_size_PFC) * c
        hiddens[0][0].add_(PFC_hidden)
        Re_hidden = torch.randn(self.batch_size, self.hidden_size_Re) * v
        hiddens[2].add_(Re_hidden)
        return hiddens

    def noiseHidden_select(self, hiddens):
        """Add Gaussian noise (std 0.01) in place to selected Re units only.

        The unit indices are hard-coded, so every region must be at least 18
        wide. HPC/PFC h get ``randn * 0`` at those units (RNG is still
        advanced). Returns the mutated ``hiddens``.
        """
        c = 0
        v = 0.01
        index = np.array([1, 2, 8, 9, 17])
        HPC_hidden = torch.zeros(self.batch_size, self.hidden_size_HPC)
        HPC_hidden[:, index] += torch.randn(self.batch_size, index.size) * c
        hiddens[1][0].add_(HPC_hidden)
        PFC_hidden = torch.zeros(self.batch_size, self.hidden_size_PFC)
        PFC_hidden[:, index] += torch.randn(self.batch_size, index.size) * c
        hiddens[0][0].add_(PFC_hidden)
        Re_hidden = torch.zeros(self.batch_size, self.hidden_size_Re) * c
        Re_hidden[:, index] += torch.randn(self.batch_size, index.size) * v
        hiddens[2].add_(Re_hidden)
        return hiddens

    def noiseHidden_dis(self, hiddens, statr, statl):
        """Shift the Re state in place by ``0.1 * (statr - statl)``.

        ``statr``/``statl`` must broadcast against ``(batch_size, hidden_size_Re)``
        (presumably recorded Re states — confirm at the call site). HPC/PFC h get
        ``randn * 0`` at the hard-coded units, and the Re ``randn * 0.0`` term
        contributes nothing but still advances the RNG. Returns ``hiddens``.
        """
        c = 0
        v = -0.1
        index = np.array([1, 2, 8, 9, 17])
        HPC_hidden = torch.zeros(self.batch_size, self.hidden_size_HPC)
        HPC_hidden[:, index] += torch.randn(self.batch_size, index.size) * c
        hiddens[1][0].add_(HPC_hidden)
        PFC_hidden = torch.zeros(self.batch_size, self.hidden_size_PFC)
        PFC_hidden[:, index] += torch.randn(self.batch_size, index.size) * c
        hiddens[0][0].add_(PFC_hidden)
        Re_hidden = torch.randn(self.batch_size, self.hidden_size_Re) * 0.0
        # With v = -0.1 this adds 0.1 * (statr - statl).
        Re_hidden += (-statr + statl) * v
        hiddens[2].add_(Re_hidden)
        return hiddens


def moving_average(data):
    """Convolve ``data`` with a 5-tap box filter in numpy's ``"full"`` mode.

    The result has ``len(data) + 4`` samples. Callers slice ``[2:-2]`` to get
    a centred 5-point moving average aligned with ``data``; the first and last
    two aligned values average over fewer real samples (the convolution pads
    with zeros), so they are pulled toward zero.
    """
    window = 5
    kernel = np.ones(window) / window
    return np.convolve(data, kernel, mode="full")


def main(model):
    """Run a trained checkpoint on two input patterns and correlate output vs. state divergence.

    The checkpoint at ``model`` is loaded into a ``MyLSTM_RNN_uniHPC`` network
    and driven for ``data_limit`` (120) steps twice, each time starting from
    ``rnn.initHidden_rand()``: once on a batch drawn from ``train_x1``
    (pattern 1) and once on a batch drawn from ``train_y1`` (pattern 3).
    PCA is fitted on the batch-element-0 PFC states of both runs stacked in
    time, and the per-step absolute difference of PC1..PC3 between the runs
    is correlated with the per-step absolute difference of output dimension 0.
    The same is done for HPC.

    Args:
        model: path to a saved ``state_dict`` passed to ``torch.load``.

    Returns:
        ``(mean PFC corr, mean HPC corr, max PFC corr, max HPC corr, frac_amp)``
        where each "corr" is taken over PC1..PC3, and ``frac_amp`` is the
        Euclidean distance between the first run's Re PC1 trajectory and its
        centred 5-point moving average.

    NOTE(review): relies on names not defined in this cell
    (``mkOwnDataSet_auto``, ``mkOwnRandomBatch``, ``MyLSTM_RNN_uniHPC``,
    ``Culc_gate``, ``pick_delay``, ``pd``, ``scipy``) and assumes the model's
    hidden layout is ``[PFC (h, c), HPC (h, c), Re, ...]`` — confirm against
    the other cells.
    """
    # Experiment constants. NOTE(review): test_size, epochs_num, data_length and
    # colors are not used on the active code path.
    training_size = 100
    test_size = 1000
    epochs_num = 10
    hidden_size = 20
    batch_size = 10
    data_length = 100
    inputsize = 2
    outputsize = 2
    delay_length = 2
    
    model_path = model
    
    colors = ['C0','C1','C2','C3','C4','C5','C6','C7','C8','C9']
    
#     delay_length = np.random.randint(1, 3)
    delay_length = 2
    # Four data sets from a helper defined elsewhere; only train_x1 and train_y1
    # are used below (patterns 1 and 3).
    train_x1,train_x2,train_y1,train_y2 = mkOwnDataSet_auto(training_size,delay_length)

    # The class must match the checkpoint's architecture; alternatives kept below.
#     rnn = MyLSTM_comp(inputsize, hidden_size, outputsize, batch_size)
#     rnn = MyLSTM_RNN(inputsize, hidden_size, outputsize, batch_size)
    rnn = MyLSTM_RNN_uniHPC(inputsize, hidden_size, outputsize, batch_size)
#     rnn = MyMTRNN2(inputsize, hidden_size, outputsize, batch_size)
#     rnn = MyLSTM_feedforward(inputsize, hidden_size, outputsize, batch_size)
#     rnn = MyLSTM_feedforward_Thalamus2(inputsize, hidden_size, outputsize, batch_size)

    rnn.load_state_dict(torch.load(model_path))
    
    
    # Input batch for the first run (pattern is hard-coded to 1 -> train_x1).
    pattern = 1
    if  pattern == 1:
        data = mkOwnRandomBatch(train_x1, train_x1, batch_size).float()
    elif pattern == 2:
        data = mkOwnRandomBatch(train_x2, train_x2, batch_size).float()
    elif pattern == 3:
        data = mkOwnRandomBatch(train_y1, train_y1, batch_size).float()
    else:
        data = mkOwnRandomBatch(train_y2, train_y2, batch_size).float()
    
    # Parameter tensors in registration order; Culc_gate indexes them by position.
    params = []
    for p in rnn.parameters():
        params.append(p.data)
    
        
    # NOTE(review): this reloads the checkpoint already loaded above. The entries of
    # `params` share storage with the module parameters, so this appears redundant.
#     model_path = 'model/R20_131/ReModel_L2_interRNNrand_OUT1_131_s6_100_1_epoch125.pth'
    model_path = model
    rnn.load_state_dict(torch.load(model_path))
        
    # Snapshot of selected weight matrices.
    # NOTE(review): none of these copies is used later in this function.
    for n, p in rnn.named_parameters():
            if n == "PFC.weight_ih":
                PFC_w_pre = torch.clone(p.data)
            if n == "HPC.weight_ih":
                HPC_w_pre = torch.clone(p.data)
            if n == "Re.weight_ih":
                Re_w_pre = torch.clone(p.data)
            if n == "Re.weight_hh":
                Re_inw = torch.clone(p.data)
            if n == "linear.weight":
                OUT_w_pre = torch.clone(p.data)
                   
    # Run 1: drive the network for data_limit steps, recording the output and
    # states (assumes `data` has at least data_limit time steps).
    traj = []
    PFCstate = []
    HPCstate = []
    Restate = []
    Gate_states = []
    hidden = rnn.initHidden_rand()
    data_limit = 120
    est_length = 0
    # NOTE(review): init_point is only used by commented-out code, but torch.rand
    # still advances the torch RNG and so affects the second initHidden_rand below.
    init_point = torch.rand(10,2)*1
    for k in range(data_limit):
            #print(data[k].shape)
            output,hidden = rnn(data[k],hidden)
#             output,hidden = rnn(init_point,hidden)
            traj.append(output.tolist())
            # Whole-batch PFC/HPC h, but only batch element 0 of Re.
            PFCstate.append(hidden[0][0].tolist())
            HPCstate.append(hidden[1][0].tolist())
            Restate.append(hidden[2][0].tolist())
            # Culc_gate receives the model *output* as its `input` argument.
            Gate_states.append(Culc_gate(output,params,hidden))
#             Gate_states.append(Culc_gate_uniPFC(output,params,hidden))
#     for k in range(data.shape[0]*est_length):
#             output,hidden = rnn(output,hidden)
#             traj.append(output.tolist())
#             PFCstate.append(hidden[0][0].tolist())
#             HPCstate.append(hidden[1][0].tolist())
#             Restate.append(hidden[2][0].tolist())
#             Gate_states.append(Culc_gate(output,params,hidden))
    # (data_limit, batch, outputsize); squeeze only matters if batch or output is 1.
    traj = torch.tensor(traj)
    traj = torch.squeeze(traj).numpy()
    
    # Input batch for the second run (pattern is hard-coded to 3 -> train_y1).
    pattern = 3
    if  pattern == 1:
        data = mkOwnRandomBatch(train_x1, train_x1, batch_size).float()
    elif pattern == 2:
        data = mkOwnRandomBatch(train_x2, train_x2, batch_size).float()
    elif pattern == 3:
        data = mkOwnRandomBatch(train_y1, train_y1, batch_size).float()
    else:
        data = mkOwnRandomBatch(train_y2, train_y2, batch_size).float()
    
    # NOTE(review): dividenum is only referenced by commented-out code.
    dividenum = int(np.array(PFCstate)[0:,0].shape[0]/2)
    # Run 2. The "_noise" suffix is historical: no noise is injected on the active path.
    traj_noise = []
    PFCstate_noise = []
    HPCstate_noise = []
    Restate_noise = []
    Gate_states_noise = []
    hidden = rnn.initHidden_rand()
#     data = mkOwnRandomBatch(train_y, batch_size)
#     init_point = torch.rand(10,2)*1
#     data_limit = 125
    for k in range(data_limit):
            #print(data[k].shape)
            output,hidden = rnn(data[k],hidden)
#             output,hidden = rnn(init_point,hidden)
            traj_noise.append(output.tolist())
            PFCstate_noise.append(hidden[0][0].tolist())
            HPCstate_noise.append(hidden[1][0].tolist())
            Restate_noise.append(hidden[2][0].tolist())
#             Gate_states_noise.append(Culc_gate(output,params,hidden))
#     for k in range(data.shape[0]*est_length):
# #             hidden = rnn.noiseHidden_rand(hidden)
#             output,hidden = rnn(output,hidden) 
# #             hidden = rnn.noiseHidden_dis(hidden, np.array(Restate)[k+data_limit], np.array(Restate)[dividenum+k+data_limit])
# #             hidden = rnn.noiseHidden_rand(hidden)
#             traj_noise.append(output.tolist())
#             PFCstate_noise.append(hidden[0][0].tolist())
#             HPCstate_noise.append(hidden[1][0].tolist())
#             Restate_noise.append(hidden[2][0].tolist())
#             Gate_states_noise.append(Culc_gate(output,params,hidden))
    traj_noise = torch.tensor(traj_noise)
    traj_noise = torch.squeeze(traj_noise).numpy()

    
    print(np.array(PFCstate)[:,0].shape)
#     MakeAnimation(pltdata[:,0,0],pltdata[:,0,1], traj[:,0,0], traj[:,0,1], data_limit)
    #MakeAnimation_img(np.array(PFCstate)[:,0],"PFC")
    #MakeAnimation_img(np.array(HPCstate)[:,0],"HPC")


#     plt.figure()
#     plt.plot(traj[:,0,0],color=colors[-1])
#     plt.plot(traj_noise[:,0,0],color=colors[-2])


#     PFC_corrlist = np.array([])
#     HPC_corrlist = np.array([])
#     cross_corrlist = np.array([])
#     for i in range(20):
#         plt.figure()
# #         plt.ylim(-1,1)
#         plt.plot(traj[:,0,0])
#         plt.plot(traj_noise[:,0,0])
# #         plt.plot(np.abs(traj[:,0,0]-traj_noise[:,0,0]))
#     #         plt.plot(np.array(PFCstate_noise)[:400,0,17],alpha=0.5)
#         for k in range(1):
# #                 plt.plot(np.array(PFCstate)[:,k,i],alpha=0.5)
# #                 plt.plot(np.array(PFCstate_noise)[:,k,i],alpha=0.5)
#                 plt.plot(np.abs(np.array(PFCstate)[:,k,i]-np.array(PFCstate_noise)[:,k,i]),alpha=0.5)
# #                 plt.plot(np.array(HPCstate)[:,k,i],alpha=0.5)
# #                 plt.plot(np.array(HPCstate_noise)[:,k,i],alpha=0.5)
#                 plt.plot(np.abs(np.array(HPCstate)[:,k,i]-np.array(HPCstate_noise)[:,k,i]),alpha=0.5)
#     #             plt.plot(np.array(PFCstate)[:120,k,i],alpha=0.5)
#     #             plt.plot(np.array(PFCstate)[120:240,k,i],alpha=0.5)
#     #             plt.plot(np.array(PFCstate)[:120,k,i]-np.array(PFCstate)[120:240,k,i],alpha=0.5)
#     #             plt.plot(np.array(PFCstate_noise)[:120,k,i],alpha=0.5)
#     #             plt.plot(np.array(PFCstate_noise)[120:240,k,i],alpha=0.5)
#     #             plt.plot(np.array(PFCstate_noise)[:120,k,i]-np.array(PFCstate_noise)[120:240,k,i],alpha=0.5)
# #             plt.plot(np.array(Restate)[:,i],alpha=0.5)
# #             plt.plot(np.array(Restate_noise)[:,i],alpha=0.5)
# #             plt.plot(np.array(Restate)[:,i]-np.array(Restate_noise)[:,i],alpha=0.5)
# #                 plt.plot(moving_average(np.array(Restate)[:,i]),alpha=0.5)
# #                 plt.plot(moving_average(np.array(Restate_noise)[:,i]),alpha=0.5)
# #                 plt.plot(moving_average(np.array(Restate)[:,i])-moving_average(np.array(Restate_noise)[:,i]),alpha=0.5)
# #                 plt.plot(np.array(Gate_states)[:,0,0+i],alpha=0.5)
# #                 plt.plot(np.array(Gate_states_noise)[:,1,0+i],alpha=0.5)
# #                 plt.plot(np.array(Gate_states_noise)[:,1,20+i],alpha=0.5)
# #                 plt.plot(np.log(np.array(Gate_states)[:,1,0+i]/np.array(Gate_states)[:,1,20+i]),alpha=0.5)
# #                 plt.plot(np.log(np.array(Gate_states_noise)[:,1,0+i]/np.array(Gate_states_noise)[:,1,20+i]),alpha=0.5)
# #                 print(np.corrcoef(np.abs(traj[:,0,0]-traj_noise[:,0,0]),np.abs(np.array(PFCstate)[:,k,i]-np.array(PFCstate_noise)[:,k,i]))[0,1])
# #                 print(np.corrcoef(np.abs(traj[:,0,0]-traj_noise[:,0,0]),np.abs(np.array(HPCstate)[:,k,i]-np.array(HPCstate_noise)[:,k,i]))[0,1])
# #                 print(np.corrcoef(np.abs(np.array(PFCstate)[:,k,i]-np.array(PFCstate_noise)[:,k,i]),np.abs(np.array(HPCstate)[:,k,i]-np.array(HPCstate_noise)[:,k,i])))
#                 cross_corrlist = np.append(cross_corrlist,np.corrcoef(np.abs(np.array(PFCstate)[:,k,i]-np.array(PFCstate_noise)[:,k,i]),np.abs(np.array(HPCstate)[:,k,i]-np.array(HPCstate_noise)[:,k,i])))
#                 PFC_corrlist = np.append(PFC_corrlist,np.corrcoef(np.abs(traj[:,0,0]-traj_noise[:,0,0]),np.abs(np.array(PFCstate)[:,k,i]-np.array(PFCstate_noise)[:,k,i]))[0,1])
#                 HPC_corrlist = np.append(HPC_corrlist,np.corrcoef(np.abs(traj[:,0,0]-traj_noise[:,0,0]),np.abs(np.array(HPCstate)[:,k,i]-np.array(HPCstate_noise)[:,k,i]))[0,1])
# #         plt.title("neuron#"+str(i+1))
#     print(np.mean(PFC_corrlist),np.mean(HPC_corrlist),np.mean(cross_corrlist))
    
    # PFC: PCA on batch-element-0 states of both runs stacked along time.
    # Rows [:data_limit] are run 1, rows [data_limit:] are run 2.
    pca = PCA()
#     dfs = np.array(PFCstate)[:,0]
    dfs = np.concatenate((np.array(PFCstate)[:,0],np.array(PFCstate_noise)[:,0]),axis=0)
#     dfs = np.concatenate((np.array(HPCstate)[:,0],np.array(HPCstate_noise)[:,0]),axis=0)
    print(dfs.shape)
    pca.fit(dfs)
    feature = pca.transform(dfs)
    print("PFC correlation")
#     print(np.corrcoef(np.abs(traj[:,0,0]-traj_noise[:,0,0]),np.abs(feature[:100,0]-feature[100:,0])))
#     print(np.corrcoef(np.abs(traj[:,0,0]-traj_noise[:,0,0]),np.abs(feature[:100,1]-feature[100:,1])))
#     print(np.corrcoef(np.abs(traj[:,0,0]-traj_noise[:,0,0]),np.abs(feature[:100,2]-feature[100:,2])))
    # Pearson correlation of |output dim 0 difference| with |PC k difference|, k = 1..3.
    PFC_a = np.corrcoef(np.abs(traj[:,0,0]-traj_noise[:,0,0]),np.abs(feature[:data_limit,0]-feature[data_limit:,0]))[0,1]
    PFC_b = np.corrcoef(np.abs(traj[:,0,0]-traj_noise[:,0,0]),np.abs(feature[:data_limit,1]-feature[data_limit:,1]))[0,1]
    PFC_c = np.corrcoef(np.abs(traj[:,0,0]-traj_noise[:,0,0]),np.abs(feature[:data_limit,2]-feature[data_limit:,2]))[0,1]
#     PFC_a = np.corrcoef(np.abs(traj[:,0,1]-traj_noise[:,0,1]),np.abs(feature[:data_limit,0]-feature[data_limit:,0]))[0,1]
#     PFC_b = np.corrcoef(np.abs(traj[:,0,1]-traj_noise[:,0,1]),np.abs(feature[:data_limit,1]-feature[data_limit:,1]))[0,1]
#     PFC_c = np.corrcoef(np.abs(traj[:,0,1]-traj_noise[:,0,1]),np.abs(feature[:data_limit,2]-feature[data_limit:,2]))[0,1]
#     PFC_a = np.corrcoef(np.abs(traj[:,0,0]-traj_noise[:,0,0]),np.abs(feature[:data_limit,0]))[0,1]
#     PFC_b = np.corrcoef(np.abs(traj[:,0,0]-traj_noise[:,0,0]),np.abs(feature[:data_limit,1]))[0,1]
#     PFC_c = np.corrcoef(np.abs(traj[:,0,0]-traj_noise[:,0,0]),np.abs(feature[:data_limit,2]))[0,1]
    # Alternative scores (Euclidean distance in PC1-3 space; L1 distance in raw state
    # space). NOTE(review): PFC_d and PFC_e only appear in commented-out returns.
    PFC_diff = np.sqrt(np.power(feature[:data_limit,0]-feature[data_limit:,0],2) + np.power(feature[:data_limit,1]-feature[data_limit:,1],2) + np.power(feature[:data_limit,2]-feature[data_limit:,2],2))
    PFC_d = np.corrcoef(np.abs(traj[:,0,0]-traj_noise[:,0,0]),PFC_diff)[0,1]
    PFC_diff2 = np.sum(np.abs(np.array(PFCstate)[:,0]-np.array(PFCstate_noise)[:,0]),axis=1)
    PFC_e = np.corrcoef(np.abs(traj[:,0,0]-traj_noise[:,0,0]),PFC_diff2)[0,1]
    
#     plt.figure()
#     plt.plot(np.abs(feature[:100,0]-feature[100:,0]))
#     plt.plot(np.abs(feature[:100,1]-feature[100:,1]))
#     plt.plot(np.abs(feature[:100,2]-feature[100:,2]))
#     plt.show()

#     plt.figure()
#     plt.plot(np.abs(feature[:100,0]),color=colors[-3])
#     plt.plot(np.abs(feature[100:,0]),color=colors[-4])
#     plt.show()

#     print(pd.DataFrame(feature, columns=["PC{}".format(x + 1) for x in range(dfs.shape[1])]).head())
    # Requires pandas imported as `pd` in an earlier cell.
    print(pd.DataFrame(pca.explained_variance_ratio_, index=["PC{}".format(x + 1) for x in range(dfs.shape[1])]))

    
    # HPC: same analysis as the PFC section above.
    pca = PCA()
#     dfs = np.array(PFCstate)[:,0]
#     dfs = np.concatenate((np.array(PFCstate)[:,0],np.array(PFCstate_noise)[:,0]),axis=0)
    dfs = np.concatenate((np.array(HPCstate)[:,0],np.array(HPCstate_noise)[:,0]),axis=0)
    print(dfs.shape)
    pca.fit(dfs)
    feature = pca.transform(dfs)
    print("HPC correlation")
#     print(np.corrcoef(np.abs(traj[:,0,0]-traj_noise[:,0,0]),np.abs(feature[:100,0]-feature[100:,0])))
#     print(np.corrcoef(np.abs(traj[:,0,0]-traj_noise[:,0,0]),np.abs(feature[:100,1]-feature[100:,1])))
#     print(np.corrcoef(np.abs(traj[:,0,0]-traj_noise[:,0,0]),np.abs(feature[:100,2]-feature[100:,2])))
    HPC_a = np.corrcoef(np.abs(traj[:,0,0]-traj_noise[:,0,0]),np.abs(feature[:data_limit,0]-feature[data_limit:,0]))[0,1]
    HPC_b = np.corrcoef(np.abs(traj[:,0,0]-traj_noise[:,0,0]),np.abs(feature[:data_limit,1]-feature[data_limit:,1]))[0,1]
    HPC_c = np.corrcoef(np.abs(traj[:,0,0]-traj_noise[:,0,0]),np.abs(feature[:data_limit,2]-feature[data_limit:,2]))[0,1]
#     HPC_a = np.corrcoef(np.abs(traj[:,0,1]-traj_noise[:,0,1]),np.abs(feature[:data_limit,0]-feature[data_limit:,0]))[0,1]
#     HPC_b = np.corrcoef(np.abs(traj[:,0,1]-traj_noise[:,0,1]),np.abs(feature[:data_limit,1]-feature[data_limit:,1]))[0,1]
#     HPC_c = np.corrcoef(np.abs(traj[:,0,1]-traj_noise[:,0,1]),np.abs(feature[:data_limit,2]-feature[data_limit:,2]))[0,1]
    HPC_diff = np.sqrt(np.power(feature[:data_limit,0]-feature[data_limit:,0],2) + np.power(feature[:data_limit,1]-feature[data_limit:,1],2) + np.power(feature[:data_limit,2]-feature[data_limit:,2],2))
    HPC_d = np.corrcoef(np.abs(traj[:,0,0]-traj_noise[:,0,0]),HPC_diff)[0,1]
    HPC_diff2 = np.sum(np.abs(np.array(HPCstate)[:,0]-np.array(HPCstate_noise)[:,0]),axis=1)
    HPC_e = np.corrcoef(np.abs(traj[:,0,0]-traj_noise[:,0,0]),HPC_diff2)[0,1]
    print(HPC_e)
    
#     plt.figure()
#     plt.plot(np.abs(feature[:100,3]-feature[100:,3]))
#     plt.plot(np.abs(feature[:100,4]-feature[100:,4]))
#     plt.plot(np.abs(feature[:100,5]-feature[100:,5]))
#     plt.show()
    
#     plt.figure()
#     plt.plot(np.abs(feature[:100,0]),color=colors[-3])
#     plt.plot(np.abs(feature[100:,0]),color=colors[-4])
#     plt.show()

#     print(pd.DataFrame(feature, columns=["PC{}".format(x + 1) for x in range(dfs.shape[1])]).head())
#     print(pd.DataFrame(pca.explained_variance_ratio_, index=["PC{}".format(x + 1) for x in range(dfs.shape[1])]))


    # Re: PCA on both runs (Restate already holds batch element 0 only). Only the
    # shape and header are printed; `feature` is overwritten in the next section.
    pca = PCA()
#     dfs = np.array(PFCstate)[:,0]
#     dfs = np.concatenate((np.array(PFCstate)[:,0],np.array(PFCstate_noise)[:,0]),axis=0)
    dfs = np.concatenate((np.array(Restate)[:],np.array(Restate_noise)[:]),axis=0)
    print(dfs.shape)
    pca.fit(dfs)
    feature = pca.transform(dfs)
    print("Re correlation")
#     print(np.corrcoef(np.abs(traj[:,0,0]-traj_noise[:,0,0]),np.abs(feature[:100,0]-feature[100:,0])))
#     print(np.corrcoef(np.abs(traj[:,0,0]-traj_noise[:,0,0]),np.abs(feature[:100,1]-feature[100:,1])))
#     print(np.corrcoef(np.abs(traj[:,0,0]-traj_noise[:,0,0]),np.abs(feature[:100,2]-feature[100:,2])))
#     Re_a = np.corrcoef(np.abs(traj[:,0,0]-traj_noise[:,0,0]),np.abs(feature[:100,0]-feature[100:,0]))[0,1]
#     Re_b = np.corrcoef(np.abs(traj[:,0,0]-traj_noise[:,0,0]),np.abs(feature[:100,1]-feature[100:,1]))[0,1]
#     Re_c = np.corrcoef(np.abs(traj[:,0,0]-traj_noise[:,0,0]),np.abs(feature[:100,2]-feature[100:,2]))[0,1]
    
#     plt.figure()
# #     plt.plot(np.abs(feature[:100,0]-feature[100:,0]))
# #     plt.plot(np.abs(feature[:100,1]-feature[100:,1]))
# #     plt.plot(np.abs(feature[:100,2]-feature[100:,2]))
# #     plt.plot(np.abs(moving_average(feature[:100,0])-moving_average(feature[100:,0])))
# #     plt.plot(np.abs(moving_average(feature[:100,1])-moving_average(feature[100:,1])))
# #     plt.plot(np.abs(moving_average(feature[:100,2])-moving_average(feature[100:,2])))
#     plt.plot(np.abs(feature[:100,0]),color=colors[-3])
#     plt.plot(np.abs(feature[100:,0]),color=colors[-4])
#     plt.show()


####################### Fluc(frac) and coherence part #############################


    # Fluctuation amplitude: PCA on run-1 Re states only; frac_amp is the Euclidean
    # distance between PC1 and its centred 5-point moving average.
    pca = PCA()
#     dfs = np.array(HPCstate)[40:,0]
    dfs = np.array(Restate)
    pca.fit(dfs)
    feature = pca.transform(dfs)
    
    # NOTE(review): pick_delay is defined elsewhere and delaysB is never used here.
    delaysB = pick_delay(traj[:,0], feature[:])

    PCAnum = 0
    data = feature[:,PCAnum]
    movingA = moving_average(data)
    fracA = data - movingA[2:-2]
    print(math.dist(data,movingA[2:-2]))
    frac_amp = math.dist(data,movingA[2:-2])
    
#     frac_amp = 0
#     for i in range(20):
#         data = np.array(HPCstate)[:,0,i]
#         movingA = moving_average(data)
#         fracA = data - movingA[2:-2]
#         frac_amp += math.dist(data,movingA[2:-2])
#     frac_amp = frac_amp/20
    
#     fig = plt.figure()
#     fig1 = fig.add_subplot(111)
#     fig1.hist(np.array(Gate_states)[:,0,0:20].ravel(),bins=30,range=(-0.1,1.1),alpha=0.5,density=True)
#     fig1.hist(np.array(Gate_states)[:,1,0:20].ravel(),bins=30,range=(-0.1,1.1),alpha=0.5,density=True)
#     plt.show()

#     hist = np.histogram(np.array(Gate_states)[:,0,0:20].ravel(),bins=30,range=(-0.1,1.1),density=True)
#     data = np.array([np.clip(hist[1][3:-2],0,1),hist[0][2:-2]]).T
    # Fit a Beta distribution (loc fixed at 0) to the first 20 entries of the PFC gate
    # vector returned by Culc_gate. NOTE(review): requires `scipy` (with scipy.stats)
    # imported elsewhere; a missing import would also land in this broad except and be
    # reported with the fitting message. beta_param only feeds a commented-out return.
    data = np.array(Gate_states)[:,0,0:20].ravel()
    try:
        beta_PFC = scipy.stats.beta.fit(data, floc=0)
        print(beta_PFC)
        beta_param = np.average([beta_PFC[0],beta_PFC[1]])
    except Exception:
        print("Error: maybe takes negative a or b")
        beta_param = 0.5
        
#     data = np.array(Gate_states)[:,1,0:20].ravel()
#     try:
#         beta_HPC = scipy.stats.beta.fit(data)
#         print(beta_HPC)
#     except Exception:
#         print("Error: maybe takes negative a or b")
        

#     pca = PCA()
# #     dfs = np.array(HPCstate)[:,0]
#     dfs = np.array(Restate)
#     pca.fit(dfs)
#     Refeature = pca.transform(dfs)

#     shift = 2
#     seglen = 60
        
#     coherence_diff_list = []
#     for i in range(20):
#         for k in range(1):
#             data = np.array(Refeature)[:,0]-moving_average(np.array(Refeature)[:,0])[2:-2]
#             freqs,times,sx1 = signal.stft(data,fs=1,window="boxcar",nperseg=seglen,noverlap=seglen-shift,detrend=False,boundary=None)
#             data = np.array(PFCstate)[:,k,i]-moving_average(np.array(PFCstate)[2:-2,k,i])[:]
#             freqs,times,sx2 = signal.stft(data*1,fs=1,window="boxcar",nperseg=seglen,noverlap=seglen-shift,detrend=False,boundary=None)
#             data = np.array(HPCstate)[:,k,i]-moving_average(np.array(HPCstate)[2:-2,k,i])[:]
#             freqs,times,sx3 = signal.stft(data,fs=1,window="boxcar",nperseg=seglen,noverlap=seglen-shift,detrend=False,boundary=None)
            
#             xsp = sx1*np.conjugate(sx2)
#             coherence = (np.abs(xsp)**2)/(np.abs(sx1)*np.abs(sx2))
# #             plt.figure()
# #             plt.pcolormesh(times,freqs,coherence)
# #             plt.title("Re-PFC neuron#"+str(i+1))
# #             print(np.max(coherence),np.min(coherence))
            
# #             plt.figure()
# #             plt.plot(times, np.sum(coherence.T,axis=1))
            
# #             xsp = sx1*np.conjugate(sx3)
# #             coherence = (np.abs(xsp)**2)/(np.abs(sx1)*np.abs(sx3))            
# #             plt.figure()
# #             plt.pcolormesh(times,freqs,coherence)
# #             plt.title("Re-HPC neuron#"+str(i+1))
# #             print(np.max(coherence),np.min(coherence))

            
# #             xsp = sx2*np.conjugate(sx3)
# #             coherence = (np.abs(xsp)**2)/(np.abs(sx2)*np.abs(sx3))
# #             plt.figure()
# #             plt.pcolormesh(times,freqs,coherence)
# #             plt.title("PFC-HPC neuron#"+str(i+1))
            
# #             degree = np.degrees(np.angle(xsp))
# #             print(degree.shape)
# #             plt.figure()
# #             plt.plot(degree)

# #             plt.figure()
# #             plt.plot(times, np.sum(coherence.T,axis=1))
#             coherence_diff_list.append(np.max(np.sum(coherence.T,axis=1))-np.min(np.sum(coherence.T,axis=1)))
    
    
    # Active return; alternative return signatures used in earlier experiments follow.
    return np.mean([PFC_a,PFC_b,PFC_c]),np.mean([HPC_a,HPC_b,HPC_c]),np.max([PFC_a,PFC_b,PFC_c]),np.max([HPC_a,HPC_b,HPC_c]),frac_amp
#     return np.mean([PFC_a,PFC_b,PFC_c]),np.mean([HPC_a,HPC_b,HPC_c]),np.max([PFC_d]),HPC_d
#     return np.mean([PFC_a,PFC_b,PFC_c]),np.mean([HPC_a,HPC_b,HPC_c]),np.max([PFC_d]), HPC_e, frac_amp
#     return np.mean([PFC_a,PFC_b,PFC_c]),np.mean([HPC_a,HPC_b,HPC_c]),np.max([PFC_d]), HPC_e, np.sum(coherence_diff_list)
#     return np.max([PFC_a,PFC_b,PFC_c]),np.max([HPC_a,HPC_b,HPC_c])
#     return 0,0,0,0,beta_param


if __name__ == '__main__':
    # Sweep saved checkpoints, score each one with main(), and collect the score
    # (currently main()'s max PFC correlation) per checkpoint. Scores of checkpoints
    # listed in <path>/good_list.txt are saved separately from the rest, and the
    # mean +/- std score across runs is plotted against checkpoint index.
    fig = plt.figure()
    correlation_fig = fig.subplots()

    # Checkpoint directory. Directories used in other experiments:
    #   'model/R20_H_bigbatch/', 'model/R20_H_stopinit_bigbatch/',
    #   'model/R20FF_H_bigbatch/', 'model/R20_feedReinhReLU_H_bigbatch/'
    path = 'model/R20_H_uniHPC_bigbatch/'
    # Checkpoints with a larger epoch number are ignored.
    max_epoch = 194
    # (seed index, run number) combinations, both 1-based, excluded from the sweep.
    skipped_runs = {(4, 3), (5, 2)}

    def _epoch_of(model_path):
        """Return the integer N from a checkpoint path of the form '...epoch<N>.<ext>'."""
        return int(model_path.split("epoch")[-1].split(".")[0])

    allratio_list = []
    good_points = [[], []]  # [scores, epochs] of checkpoints listed in good_list.txt
    bad_points = [[], []]   # [scores, epochs] of all other checkpoints
    for num in range(3):
        for i in range(10):
            # Skip excluded runs before touching the filesystem.
            if (i + 1, num + 1) in skipped_runs:
                continue
            model_list = glob.glob(path + '*s' + str(i + 1) + '_100_' + str(num + 1) + '_*epoch*.pth')
            # Lexicographic sort followed by a stable sort on length puts shorter
            # epoch numbers first (epoch5 before epoch10) for otherwise equal names.
            model_list = sorted(sorted(model_list), key=len)
            with open(path + "good_list.txt", mode="r") as f:
                good_list = set(f.read().splitlines())

            ratio_list_max = []
            for model in model_list:
                print(model)
                epoch = _epoch_of(model)
                if epoch > max_epoch:
                    continue
                _pfc_mean, _hpc_mean, PFC_max, HPC_max, frac = main(model)
                # Per-checkpoint score. Alternatives tried before:
                # np.abs(PFC_max - HPC_max), HPC_max, frac.
                ratio_list_max.append(PFC_max)
                print(PFC_max, HPC_max)
                print("dist:" + str(frac))
                points = good_points if model in good_list else bad_points
                points[0].append(ratio_list_max[-1])
                points[1].append(epoch)
            allratio_list.append(np.array(ratio_list_max))

    # One row per run; every run must contribute the same number of checkpoints.
    scores = np.array(allratio_list)
    # X positions 0, 5, 10, ... per checkpoint index, taken from the score matrix
    # itself rather than from whichever run happened to be processed last.
    x_axis = np.arange(0, scores.shape[1] * 5, 5)
    mean_scores = np.average(scores, axis=0)
    correlation_fig.plot(x_axis, mean_scores, color="b")
    correlation_fig.errorbar(x_axis, mean_scores, yerr=np.sqrt(np.var(scores, axis=0)), color="b", alpha=0.3)
    np.save("uniHPC_PFCmax_good.npy", np.array(good_points))
    np.save("uniHPC_PFCmax_bad.npy", np.array(bad_points))

        

# Compare!!!!!!!!!!!!!!!!

In [None]:
##################################  Compare!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!   #############################################

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import time
import math
import matplotlib.pyplot as plt
import matplotlib.animation as animation
import sklearn 
from sklearn.decomposition import PCA
from sklearn.cluster import KMeans
from mpl_toolkits.mplot3d import Axes3D
import pandas as pd
import glob

%matplotlib notebook


def mkOwnRandomBatch(train_x, train_t, batch_size=10):
    """Sample ``batch_size`` sequences from ``train_x`` with replacement and batch them.

    Args:
        train_x: indexable collection of equally-shaped sequences; each item is
            presumably ``(time, features)`` — TODO confirm against the dataset
            builder.
        train_t: unused; kept so existing call sites (which pass the same data
            set twice) keep working.
        batch_size: number of sequences to draw.

    Returns:
        Tensor of the stacked samples with its first two axes swapped, i.e.
        ``(time, batch, ...)`` when each item is ``(time, ...)``.

    Raises:
        ValueError: if ``train_x`` is empty.
    """
    n_items = len(train_x)
    if n_items == 0:
        raise ValueError("train_x must contain at least one sequence")
    # Each item is a complete sequence, so every index in [0, n_items) is valid.
    # The previous upper bound (n_items - batch_size) silently excluded the last
    # batch_size sequences and raised whenever n_items <= batch_size.
    batch_x = [train_x[np.random.randint(0, n_items)] for _ in range(batch_size)]
    return torch.tensor(batch_x).transpose(0, 1)

class MyLSTM_comp(nn.Module):
    """Two stacked LSTM cells followed by a linear read-out (comparison model).

    Args:
        input_size: feature size of one input step.
        hidden_size: state size of both LSTM cells.
        output_size: size of the linear read-out.
        batch_size: batch dimension used by the ``initHidden*`` helpers.
    """

    def __init__(self, input_size, hidden_size, output_size, batch_size):
        super().__init__()
        self.hidden_size = hidden_size
        # Registration order (LSTM1, LSTM2, linear) fixes both the state_dict
        # keys and the order of parameters(); saved checkpoints rely on it.
        self.LSTM1 = nn.LSTMCell(input_size, hidden_size)
        self.LSTM2 = nn.LSTMCell(hidden_size, hidden_size)
        self.linear = nn.Linear(hidden_size, output_size)
        self.batch_size = batch_size

    def forward(self, input, hiddens):
        """Advance both layers by one step.

        ``hiddens[0]`` / ``hiddens[1]`` are the (h, c) states of the lower and
        upper cell. Returns the read-out of the new upper hidden state together
        with ``[lower_state, upper_state]``.
        """
        lower_state = self.LSTM1(input.float(), hiddens[0])
        upper_state = self.LSTM2(lower_state[0], hiddens[1])
        return self.linear(upper_state[0]), [lower_state, upper_state]

    def initHidden(self):
        """Return an all-zero [h, c] pair, used for both layers."""
        shape = (self.batch_size, self.hidden_size)
        pair = [torch.zeros(*shape), torch.zeros(*shape)]
        # The same [h, c] list object is handed to both layers.
        return [pair, pair]

    def initHidden_rand(self):
        """Return a uniform [0, 0.01) random [h, c] pair, used for both layers."""
        shape = (self.batch_size, self.hidden_size)
        pair = [torch.rand(*shape) * 0.01, torch.rand(*shape) * 0.01]
        return [pair, pair]

class MyLSTM_PFCHPC(nn.Module):
    """Coupled PFC / HPC LSTM cells with a linear read-out from the HPC state.

    Each step, the HPC cell receives the external input concatenated with the
    previous PFC hidden state. The PFC cell receives the previous HPC hidden
    state. The output is a linear map of the new HPC hidden state.

    Args:
        input_size: size of the external input.
        hidden_size: state size of both cells.
        output_size: size of the linear read-out.
        batch_size: batch dimension used by the ``initHidden*`` helpers.
        sparse: sparsity in tenths. ``nn.init.sparse_`` is applied with
            ``0.1 * int(sparse)`` to every input and recurrent weight matrix.
    """

    def __init__(self, input_size, hidden_size, output_size, batch_size, sparse=1):
        super().__init__()
        self.hidden_size = hidden_size
        # Registration order PFC -> HPC -> linear fixes parameters() order and
        # the state_dict keys of saved checkpoints.
        self.PFC = nn.LSTMCell(hidden_size, hidden_size)
        self.HPC = nn.LSTMCell(input_size + hidden_size, hidden_size)
        self.linear = nn.Linear(hidden_size, output_size)
        self.batch_size = batch_size
        sparsity = 0.1 * int(sparse)
        # The order of these re-initialisations determines the RNG stream.
        for weight in (self.PFC.weight_ih, self.HPC.weight_ih,
                       self.PFC.weight_hh, self.HPC.weight_hh):
            nn.init.sparse_(weight.data, sparsity)

    def forward(self, input, hiddens):
        """Advance both cells one step from ``hiddens = [PFC (h, c), HPC (h, c)]``."""
        pfc_prev = hiddens[0]
        hpc_prev = hiddens[1]
        hpc_in = torch.cat([input.float(), pfc_prev[0]], dim=1)
        hpc_next = self.HPC(hpc_in, hpc_prev)
        # PFC is driven by the HPC hidden state from the previous step.
        pfc_next = self.PFC(hpc_prev[0], pfc_prev)
        return self.linear(hpc_next[0]), [pfc_next, hpc_next]

    def initHidden(self):
        """Return zero states as ``[PFC [h, c], HPC [h, c]]``."""
        shape = (self.batch_size, self.hidden_size)
        hpc_pair = [torch.zeros(*shape), torch.zeros(*shape)]
        pfc_pair = [torch.zeros(*shape), torch.zeros(*shape)]
        return [pfc_pair, hpc_pair]

    def initHidden_rand(self):
        """Return uniform [0, 0.01) random states as ``[PFC [h, c], HPC [h, c]]``."""
        hpc_scale = 0.01
        pfc_scale = 0.01
        shape = (self.batch_size, self.hidden_size)
        # HPC is drawn before PFC; keep this order for reproducible seeds.
        hpc_pair = [torch.rand(*shape) * hpc_scale for _ in range(2)]
        pfc_pair = [torch.rand(*shape) * pfc_scale for _ in range(2)]
        return [pfc_pair, hpc_pair]
    
    
def pick_traj(traj,states):
    """Split ``states`` into segments using threshold crossings of ``traj[:, 1]``.

    From step 5 onward, a crossing is recorded each time column 1 of ``traj``
    drops below 0.2 and each time it rises back above 0.2. Every rising
    crossing closes a segment. The tail after the last rising crossing is
    always appended as a final segment.

    Args:
        traj: 2-D array indexed as ``traj[step, column]``.
        states: sliceable sequence aligned with the rows of ``traj``.

    Returns:
        List of slices of ``states``.
    """
    threshold = 0.2
    crossings = []
    below = False
    for step in range(5, traj.shape[0]):
        y = traj[step, 1]
        if not below and y < threshold:
            crossings.append(step)
            below = True
        elif below and y > threshold:
            crossings.append(step)
            below = False

    segments = []
    seg_start = 0
    # Crossings alternate down/up; the odd-positioned (upward) ones end a segment.
    for rise in crossings[1::2]:
        segments.append(states[seg_start:rise])
        seg_start = rise
    segments.append(states[seg_start:])
    return segments


def moving_average(data, window=5):
    """Return the full-mode running mean of ``data`` over ``window`` samples.

    Uses ``np.convolve`` in its default "full" mode, so the result has length
    ``len(data) + window - 1``. The first and last ``window - 1`` entries
    average over fewer than ``window`` real samples but are still divided by
    ``window``.

    Args:
        data: 1-D sequence of numbers.
        window: number of samples per average (default 5).

    Raises:
        ValueError: if ``window`` is smaller than 1.
    """
    if window < 1:
        raise ValueError("window must be a positive integer")
    kernel = np.ones(window)/window
    return np.convolve(data, kernel)

    

def main(model):
    """Correlate output divergence with PFC / HPC state divergence for one checkpoint.

    Loads the ``state_dict`` at ``model`` into a ``MyLSTM_PFCHPC`` and runs it
    twice for ``data_limit`` (=120) steps. Each run starts from a fresh
    ``initHidden_rand()`` state. The first run is driven by a batch drawn from
    ``train_x1`` (pattern 1), the second by a batch from ``train_y1``
    (pattern 3).

    For each module, the batch-element-0 hidden states of both runs are
    stacked and fitted with PCA. The absolute difference between the two runs
    along PC1, PC2 and PC3 is then correlated (``np.corrcoef``) with the
    absolute difference of the first output coordinate of batch element 0.

    Args:
        model: path of the checkpoint passed to ``torch.load``.

    Returns:
        Tuple ``(mean PFC corr, mean HPC corr, max PFC corr, max HPC corr)``,
        each taken over the PC1-PC3 correlations.

    NOTE(review): ``mkOwnDataSet_auto`` is not defined in this cell; it must
    come from elsewhere in the notebook. Each sequence it yields is assumed to
    have at least ``data_limit`` steps (``data[k]`` for k < 120) -- TODO confirm.
    """
    # NOTE(review): training_size is the only one of the settings below that is
    # read (by mkOwnDataSet_auto); test_size, epochs_num, data_length, filename
    # and sparse are assigned but never used in this function.
    training_size = 100
    test_size = 1000
    epochs_num = 10
    hidden_size = 30
    batch_size = 10
    data_length = 100
    inputsize = 2
    outputsize = 2
    # model_path = 'model/ReModel_L2_interRNNrand_AddRe_OUT5_cue18_s5_100_1_2.pth'
    # model_path = 'model/ReModel_L2_interRNNrand_AddRe_OUT5_cue13_s9_100_1.pth'
    # model_path = 'model/R20_cue_131/ReModel_L2_interRNNrand_OUT1_cue7_131_s9_100_2_epoch150.pth'
#     model_path = 'model/R20_cue_131/ReModel_L2_interRNNrand_OUT1_cue7_131_s9_100_3_epoch195.pth'
#     model_path = 'model/R20_cue_131/ReModel_L2_interRNNrand_OUT1_cue7_131_s7_100_1_epoch195.pth'
#     model_path = 'model/R20_cue_131_1to3/ReModel_L2_interRNNrand_OUT1_cue7_131_1to3_s4_100_1_epoch175.pth'
    model_path = model
    # model_path = 'model/v4_2Model_MTRNN2_cue_9.pth'
    filename = "primal_long131test"
    sparse = 1
    
#     delay_length = np.random.randint(1, 3)
    delay_length = 2
    train_x1,train_x2,train_y1,train_y2 = mkOwnDataSet_auto(training_size,delay_length)

#     rnn = MyLSTM_comp(inputsize, hidden_size, outputsize, batch_size)
    rnn = MyLSTM_PFCHPC(inputsize, hidden_size, outputsize, batch_size)
#     rnn = MyMTRNN2(inputsize, hidden_size, outputsize, batch_size)

    # NOTE(review): no map_location is given; a checkpoint saved from GPU
    # tensors may fail to load on a CPU-only machine -- confirm how models were saved.
    rnn.load_state_dict(torch.load(model_path))
    
    
    # --- Run 1: drive the network with a batch of the selected pattern. ---
    pattern = 1
    if  pattern == 1:
        data = mkOwnRandomBatch(train_x1, train_x1, batch_size).float()
    elif pattern == 2:
        data = mkOwnRandomBatch(train_x2, train_x2, batch_size).float()
    elif pattern == 3:
        data = mkOwnRandomBatch(train_y1, train_y1, batch_size).float()
    else:
        data = mkOwnRandomBatch(train_y2, train_y2, batch_size).float()
    
    # NOTE(review): params, Restate and Gate_states are filled or created here
    # but not used by the active code below.
    params = []
    for p in rnn.parameters():
        params.append(p.data)
    
    traj = []
    PFCstate = []
    HPCstate = []
    Restate = []
    Gate_states = []
    hidden = rnn.initHidden_rand()
    # hidden = rnn.initHidden()
    data_limit = 120
    for k in range(data_limit):
            #print(data[k].shape)
            output,hidden = rnn(data[k],hidden)
            traj.append(output.tolist())
            # hidden = [PFC (h, c), HPC (h, c)]; record the hidden states h.
            PFCstate.append(hidden[0][0].tolist())
            HPCstate.append(hidden[1][0].tolist())
#     for k in range(data.shape[0]*8):
#             output,hidden = rnn(output,hidden) 
#             traj.append(output.tolist())
#             PFCstate.append(hidden[0][0].tolist())
#             HPCstate.append(hidden[1][0].tolist())
#             Restate.append(hidden[2][0].tolist())
#             if np.any(cue_point==k+data_limit):
#                 output = torch.cat([output,torch.ones(10,1)],axis=1)
#             else:
#                 output = torch.cat([output,torch.zeros(10,1)],axis=1)
    traj = torch.tensor(traj)
    pltdata = torch.squeeze(data).numpy()
    traj = torch.squeeze(traj).numpy()
#     fig = plt.figure()
#     print(pltdata.shape)
#     plt.plot(pltdata[:,0,0],pltdata[:,0,1],"--")
#     plt.plot(traj[:,0,0],traj[:,0,1])
#     plt.show()
#     MakeAnimation2(pltdata[:,0,0],pltdata[:,0,1], traj[:,0,0], traj[:,0,1], data_limit)
    # MakeAnimation_img(np.array(PFCstate)[:,0],"PFC")
    # MakeAnimation_img(np.array(HPCstate)[:,0],"HPC")
    # MakeAnimation_img(np.array(Restate),"Re")
    #MakeAnimation_testdata(pltdata[:,0,0],pltdata[:,0,1])
    
    
    # --- Run 2: same network, new random initial state, pattern 3 input. ---
    # The "_noise" suffix is historical: no noise is injected in the active code;
    # these lists hold the run driven by the second pattern.
    traj_noise = []
    PFCstate_noise = []
    HPCstate_noise = []
    Restate_noise = []
    Gate_states_noise = []
    hidden = rnn.initHidden_rand()
    
    pattern = 3
    if  pattern == 1:
        data = mkOwnRandomBatch(train_x1, train_x1, batch_size).float()
    elif pattern == 2:
        data = mkOwnRandomBatch(train_x2, train_x2, batch_size).float()
    elif pattern == 3:
        data = mkOwnRandomBatch(train_y1, train_y1, batch_size).float()
    else:
        data = mkOwnRandomBatch(train_y2, train_y2, batch_size).float()
    
#     data_limit = 125
    for k in range(data_limit):
            #print(data[k].shape)
            output,hidden = rnn(data[k],hidden)
#             output,hidden = rnn(init_point,hidden)
            traj_noise.append(output.tolist())
            PFCstate_noise.append(hidden[0][0].tolist())
            HPCstate_noise.append(hidden[1][0].tolist())
#     for k in range(data.shape[0]*8):
# #             hidden = rnn.noiseHidden_rand(hidden)
#             output,hidden = rnn(output,hidden) 
# #             hidden = rnn.noiseHidden_dis(hidden, np.array(Restate)[k+data_limit], np.array(Restate)[dividenum+k+data_limit])
# #             hidden = rnn.noiseHidden_rand(hidden)
#             traj_noise.append(output.tolist())
#             PFCstate_noise.append(hidden[0][0].tolist())
#             HPCstate_noise.append(hidden[1][0].tolist())
#             Restate_noise.append(hidden[2][0].tolist())
#             if np.any(cue_point_noise==k+data_limit):
#                 output = torch.cat([output,torch.ones(10,1)],axis=1)
#             else:
#                 output = torch.cat([output,torch.zeros(10,1)],axis=1)
    traj_noise = torch.tensor(traj_noise)
    traj_noise = torch.squeeze(traj_noise).numpy()
    
    print(np.array(PFCstate)[:,0].shape)
#     MakeAnimation(pltdata[:,0,0],pltdata[:,0,1], traj[:,0,0], traj[:,0,1], data_limit)
#     MakeAnimation_img(np.array(PFCstate)[:,0],"PFC")
    #MakeAnimation_img(np.array(HPCstate)[:,0],"HPC")
    # NOTE(review): the extracted weight matrices are not used afterwards, and
    # MyLSTM_PFCHPC has no "Re.*" parameters (only PFC.*, HPC.*, linear.*), so
    # the last two branches never match.
    for n, p in rnn.named_parameters():
            if n == "PFC.weight_ih":
                PFC_w_b = np.array(p.data)
            if n == "PFC.weight_hh":
                PFC_inw_b = np.array(p.data)
            if n == "HPC.weight_ih":
                HPC_w = np.array(p.data)
            if n == "Re.weight_ih":
                Re_w_b = np.array(p.data)
            if n == "Re.weight_hh":
                Re_inw_b = np.array(p.data)
    
    print(np.array(PFCstate)[:,0].shape)
#     MakeAnimation(pltdata[:,0,0],pltdata[:,0,1], traj[:,0,0], traj[:,0,1], data_limit)
#     MakeAnimation(traj[:,0,0],traj[:,0,1], traj_noise[:,0,0], traj_noise[:,0,1], data_limit)
    #MakeAnimation_img(np.array(PFCstate)[:,0],"PFC")
    #MakeAnimation_img(np.array(HPCstate)[:,0],"HPC")

#     fig = plt.figure()
#     plt.plot(traj[:,0,0],traj[:,0,1])
# #     plt.plot(traj_noise[:,0,0],traj_noise[:,0,1])
#     plt.show()

#     plt.figure()
#     plt.plot(traj[:,0,0])
#     plt.plot(traj_noise[:,0,0])


#     PFC_corrlist = np.array([])
#     HPC_corrlist = np.array([])
#     cross_corrlist = np.array([])
#     for i in range(20):
#         plt.figure()
#         plt.ylim(-1,1)
#         plt.plot(traj[:,0,0])
#         plt.plot(traj_noise[:,0,0])
# #         plt.plot(np.abs(traj[:,0,0]-traj_noise[:,0,0]))
#     #         plt.plot(np.array(PFCstate_noise)[:400,0,17],alpha=0.5)
#         for k in range(1):
# #                 plt.plot(np.array(PFCstate)[:,k,i],alpha=0.5)
# #                 plt.plot(np.array(PFCstate_noise)[:,k,i],alpha=0.5)
#                 plt.plot(np.abs(np.array(PFCstate)[:,k,i]-np.array(PFCstate_noise)[:,k,i]),alpha=0.5)
# #                 plt.plot(np.array(HPCstate)[:,k,i],alpha=0.5)
# #                 plt.plot(np.array(HPCstate_noise)[:,k,i],alpha=0.5)
#                 plt.plot(np.abs(np.array(HPCstate)[:,k,i]-np.array(HPCstate_noise)[:,k,i]),alpha=0.5)
#     #             plt.plot(np.array(PFCstate)[:120,k,i],alpha=0.5)
#     #             plt.plot(np.array(PFCstate)[120:240,k,i],alpha=0.5)
#     #             plt.plot(np.array(PFCstate)[:120,k,i]-np.array(PFCstate)[120:240,k,i],alpha=0.5)
#     #             plt.plot(np.array(PFCstate_noise)[:120,k,i],alpha=0.5)
#     #             plt.plot(np.array(PFCstate_noise)[120:240,k,i],alpha=0.5)
#     #             plt.plot(np.array(PFCstate_noise)[:120,k,i]-np.array(PFCstate_noise)[120:240,k,i],alpha=0.5)
# #             plt.plot(np.array(Restate)[:,i],alpha=0.5)
# #             plt.plot(np.array(Restate_noise)[:,i],alpha=0.5)
# #             plt.plot(np.array(Restate)[:,i]-np.array(Restate_noise)[:,i],alpha=0.5)
# #                 plt.plot(moving_average(np.array(Restate)[:,i]),alpha=0.5)
# #                 plt.plot(moving_average(np.array(Restate_noise)[:,i]),alpha=0.5)
# #                 plt.plot(moving_average(np.array(Restate)[:,i])-moving_average(np.array(Restate_noise)[:,i]),alpha=0.5)
# #                 plt.plot(np.array(Gate_states)[:,0,0+i],alpha=0.5)
# #                 plt.plot(np.array(Gate_states_noise)[:,1,0+i],alpha=0.5)
# #                 plt.plot(np.array(Gate_states_noise)[:,1,20+i],alpha=0.5)
# #                 plt.plot(np.log(np.array(Gate_states)[:,1,0+i]/np.array(Gate_states)[:,1,20+i]),alpha=0.5)
# #                 plt.plot(np.log(np.array(Gate_states_noise)[:,1,0+i]/np.array(Gate_states_noise)[:,1,20+i]),alpha=0.5)
# #                 print(np.corrcoef(np.abs(traj[:,0,0]-traj_noise[:,0,0]),np.abs(np.array(PFCstate)[:,k,i]-np.array(PFCstate_noise)[:,k,i]))[0,1])
# #                 print(np.corrcoef(np.abs(traj[:,0,0]-traj_noise[:,0,0]),np.abs(np.array(HPCstate)[:,k,i]-np.array(HPCstate_noise)[:,k,i]))[0,1])
# #                 print(np.corrcoef(np.abs(np.array(PFCstate)[:,k,i]-np.array(PFCstate_noise)[:,k,i]),np.abs(np.array(HPCstate)[:,k,i]-np.array(HPCstate_noise)[:,k,i])))
#                 cross_corrlist = np.append(cross_corrlist,np.corrcoef(np.abs(np.array(PFCstate)[:,k,i]-np.array(PFCstate_noise)[:,k,i]),np.abs(np.array(HPCstate)[:,k,i]-np.array(HPCstate_noise)[:,k,i])))
#                 PFC_corrlist = np.append(PFC_corrlist,np.corrcoef(np.abs(traj[:,0,0]-traj_noise[:,0,0]),np.abs(np.array(PFCstate)[:,k,i]-np.array(PFCstate_noise)[:,k,i]))[0,1])
#                 HPC_corrlist = np.append(HPC_corrlist,np.corrcoef(np.abs(traj[:,0,0]-traj_noise[:,0,0]),np.abs(np.array(HPCstate)[:,k,i]-np.array(HPCstate_noise)[:,k,i]))[0,1])
# #         plt.title("neuron#"+str(i+1))
#     print(np.mean(PFC_corrlist),np.mean(HPC_corrlist),np.mean(cross_corrlist))
    
    # --- PFC: PCA over both runs (rows 0..data_limit-1 = run 1, the rest = run 2). ---
    pca = PCA()
#     dfs = np.array(PFCstate)[:,0]
    dfs = np.concatenate((np.array(PFCstate)[:,0],np.array(PFCstate_noise)[:,0]),axis=0)
#     dfs = np.concatenate((np.array(HPCstate)[:,0],np.array(HPCstate_noise)[:,0]),axis=0)
    print(dfs.shape)
    pca.fit(dfs)
    feature = pca.transform(dfs)
    ratios = pca.explained_variance_ratio_
    print("PFC correlation")
#     print(np.corrcoef(np.abs(traj[:,0,0]-traj_noise[:,0,0]),np.abs(feature[:100,0]-feature[100:,0])))
#     print(np.corrcoef(np.abs(traj[:,0,0]-traj_noise[:,0,0]),np.abs(feature[:100,1]-feature[100:,1])))
#     print(np.corrcoef(np.abs(traj[:,0,0]-traj_noise[:,0,0]),np.abs(feature[:100,2]-feature[100:,2])))
    # Correlate |output difference| with |run-1 minus run-2 score| on PC1..PC3.
    PFC_a = np.corrcoef(np.abs(traj[:,0,0]-traj_noise[:,0,0]),np.abs(feature[:data_limit,0]-feature[data_limit:,0]))[0,1]
    PFC_b = np.corrcoef(np.abs(traj[:,0,0]-traj_noise[:,0,0]),np.abs(feature[:data_limit,1]-feature[data_limit:,1]))[0,1]
    PFC_c = np.corrcoef(np.abs(traj[:,0,0]-traj_noise[:,0,0]),np.abs(feature[:data_limit,2]-feature[data_limit:,2]))[0,1]
#     PFC_a = np.corrcoef(np.abs(traj[:,0,1]-traj_noise[:,0,1]),np.abs(feature[:data_limit,0]-feature[data_limit:,0]))[0,1]
#     PFC_b = np.corrcoef(np.abs(traj[:,0,1]-traj_noise[:,0,1]),np.abs(feature[:data_limit,1]-feature[data_limit:,1]))[0,1]
#     PFC_c = np.corrcoef(np.abs(traj[:,0,1]-traj_noise[:,0,1]),np.abs(feature[:data_limit,2]-feature[data_limit:,2]))[0,1]
#     PFC_a = np.corrcoef(np.abs(traj[:,0,0]-traj_noise[:,0,0]),np.abs(feature[:100,0]-feature[100:,0]))[0,1]*ratios[0]
#     PFC_b = np.corrcoef(np.abs(traj[:,0,0]-traj_noise[:,0,0]),np.abs(feature[:100,1]-feature[100:,1]))[0,1]*ratios[1]
#     PFC_c = np.corrcoef(np.abs(traj[:,0,0]-traj_noise[:,0,0]),np.abs(feature[:100,2]-feature[100:,2]))[0,1]*ratios[2]

#     plt.figure()
#     plt.plot(np.abs(feature[:100,0]-feature[100:,0]))
#     plt.plot(np.abs(feature[:100,1]-feature[100:,1]))
#     plt.plot(np.abs(feature[:100,2]-feature[100:,2]))
#     plt.show()
    
#     plt.figure()
#     plt.plot(np.abs(feature[:100,0]),color=colors[-3])
#     plt.plot(np.abs(feature[100:,0]),color=colors[-4])
#     plt.show()

    print(pd.DataFrame(pca.explained_variance_ratio_, index=["PC{}".format(x + 1) for x in range(dfs.shape[1])]))
    
    # --- HPC: same analysis on the HPC hidden states. ---
    pca = PCA()
#     dfs = np.array(PFCstate)[:,0]
#     dfs = np.concatenate((np.array(PFCstate)[:,0],np.array(PFCstate_noise)[:,0]),axis=0)
    dfs = np.concatenate((np.array(HPCstate)[:,0],np.array(HPCstate_noise)[:,0]),axis=0)
    print(dfs.shape)
    pca.fit(dfs)
    feature = pca.transform(dfs)
    ratios = pca.explained_variance_ratio_
    print("HPC correlation")
#     print(np.corrcoef(np.abs(traj[:,0,0]-traj_noise[:,0,0]),np.abs(feature[:100,0]-feature[100:,0])))
#     print(np.corrcoef(np.abs(traj[:,0,0]-traj_noise[:,0,0]),np.abs(feature[:100,1]-feature[100:,1])))
#     print(np.corrcoef(np.abs(traj[:,0,0]-traj_noise[:,0,0]),np.abs(feature[:100,2]-feature[100:,2])))
    HPC_a = np.corrcoef(np.abs(traj[:,0,0]-traj_noise[:,0,0]),np.abs(feature[:data_limit,0]-feature[data_limit:,0]))[0,1]
    HPC_b = np.corrcoef(np.abs(traj[:,0,0]-traj_noise[:,0,0]),np.abs(feature[:data_limit,1]-feature[data_limit:,1]))[0,1]
    HPC_c = np.corrcoef(np.abs(traj[:,0,0]-traj_noise[:,0,0]),np.abs(feature[:data_limit,2]-feature[data_limit:,2]))[0,1]
#     HPC_a = np.corrcoef(np.abs(traj[:,0,1]-traj_noise[:,0,1]),np.abs(feature[:data_limit,0]-feature[data_limit:,0]))[0,1]
#     HPC_b = np.corrcoef(np.abs(traj[:,0,1]-traj_noise[:,0,1]),np.abs(feature[:data_limit,1]-feature[data_limit:,1]))[0,1]
#     HPC_c = np.corrcoef(np.abs(traj[:,0,1]-traj_noise[:,0,1]),np.abs(feature[:data_limit,2]-feature[data_limit:,2]))[0,1]
#     HPC_a = np.corrcoef(np.abs(traj[:,0,0]-traj_noise[:,0,0]),np.abs(feature[:100,0]-feature[100:,0]))[0,1]*ratios[0]
#     HPC_b = np.corrcoef(np.abs(traj[:,0,0]-traj_noise[:,0,0]),np.abs(feature[:100,1]-feature[100:,1]))[0,1]*ratios[1]
#     HPC_c = np.corrcoef(np.abs(traj[:,0,0]-traj_noise[:,0,0]),np.abs(feature[:100,2]-feature[100:,2]))[0,1]*ratios[2]
    # Extra diagnostics: Euclidean distance in PC1-PC3 space (HPC_d, unused)
    # and summed absolute raw-state difference (HPC_e, printed only).
    HPC_diff = np.sqrt(np.power(feature[:data_limit,0]-feature[data_limit:,0],2) + np.power(feature[:data_limit,1]-feature[data_limit:,1],2) + np.power(feature[:data_limit,2]-feature[data_limit:,2],2))
    HPC_d = np.corrcoef(np.abs(traj[:,0,0]-traj_noise[:,0,0]),HPC_diff)[0,1]
    HPC_diff2 = np.sum(np.abs(np.array(HPCstate)[:,0]-np.array(HPCstate_noise)[:,0]),axis=1)
    HPC_e = np.corrcoef(np.abs(traj[:,0,0]-traj_noise[:,0,0]),HPC_diff2)[0,1]
    print(HPC_e)
    
#     plt.figure()
#     plt.plot(np.abs(feature[:100,0]-feature[100:,0]))
#     plt.plot(np.abs(feature[:100,1]-feature[100:,1]))
#     plt.plot(np.abs(feature[:100,2]-feature[100:,2]))
#     plt.show()

#     print(pd.DataFrame(pca.explained_variance_ratio_, index=["PC{}".format(x + 1) for x in range(dfs.shape[1])]))

    return np.mean([PFC_a,PFC_b,PFC_c]),np.mean([HPC_a,HPC_b,HPC_c]),np.max([PFC_a,PFC_b,PFC_c]),np.max([HPC_a,HPC_b,HPC_c])
#     return np.mean([PFC_a,PFC_b,PFC_c]),np.mean([HPC_a,HPC_b,HPC_c]),np.max([PFC_a,PFC_b,PFC_c]),HPC_e
#     print(traj[:,0,1],traj_noise[:,0,1],np.abs(traj[:,0,1]-traj_noise[:,0,1]))
#     return np.max([PFC_a,PFC_b,PFC_c]),np.max([HPC_a,HPC_b,HPC_c])

if __name__ == '__main__':
    # Sweep all checkpoints of every (run, seed), compute the HPC max correlation
    # with main(), then plot and save the across-seed mean / variance curve.
    fig = plt.figure()
    correlation_fig = fig.subplots()
#     correlation_fig.set_xlim(-1,1)
#     correlation_fig.set_ylim(0,1)

#     model_list = glob.glob('model/compare30_131/*s10_100_1_*epoch*.pth')
#     model_list = sorted(model_list)
#     model_list = sorted(model_list,key=len,reverse=False)
#     k=0
#     for model in model_list:
#         print(model)
#         PFC,HPC = main(model)
# #         correlation_fig.plot(PFC,HPC,"o")
# #         correlation_fig.text(PFC,HPC,model.split("epoch")[-1].split(".")[0])
#         correlation_fig.plot(k*5,PFC/HPC,"o")
# #         correlation_fig.text(k,PFC/HPC,model.split("epoch")[-1].split(".")[0])
#         k+=1
    
#     for i in range(5):
# #         path = 'model/compare30_cue_131/'
#         path = 'model/PFCHPC30_H/'
#         model_list = glob.glob(path+'*s'+str(i+6)+'_100_2_*epoch*.pth')
#         model_list = sorted(model_list)
#         model_list = sorted(model_list,key=len,reverse=False)
#         ratio_list = []
#         ratio_list_max = []
#         with open(path+"good_list.txt", mode="r") as f:
#             good_list = f.read().splitlines()
#         first_goodmodel = [0,0]
#         good_flag = False
#         k=0
#         for model in model_list:
#             print(model)
# #             PFC,HPC = main(model)
#             PFC,HPC,PFC_max,HPC_max = main(model)
# #             ratio_list.append(PFC/HPC)
#             ratio_list.append(np.abs(PFC-HPC))
# #             ratio_list_max.append(PFC_max/HPC_max)
#             ratio_list_max.append(np.abs(PFC_max-HPC_max))
# #             correlation_fig.plot(PFC,HPC,"o")
# #             correlation_fig.text(PFC,HPC,model.split("epoch")[-1].split(".")[0])
# #             correlation_fig.plot(k*5,PFC/HPC,"o")
#     #         correlation_fig.text(k*5,PFC/HPC,model.split("epoch")[-1].split(".")[0])
#             print(PFC_max,HPC_max)
#             if good_flag != True and model in good_list:
#                 good_flag = True
#                 first_goodmodel[0] = int(model.split("epoch")[-1].split(".")[0])
# #                 first_goodmodel[1] = ratio_list[-1]
#                 first_goodmodel[1] = ratio_list_max[-1]
#             k+=1
# #         correlation_fig.plot(np.arange(0,200,5),np.array(ratio_list)-np.mean(ratio_list),"o")
# #         ratio_list = np.array(ratio_list).clip(-2,2)
# #         ratio_list = moving_average(ratio_list)[2:-2]
# #         correlation_fig.plot(np.arange(0,len(ratio_list)*5,5),np.array(ratio_list))
        
#         ratio_list_max = np.array(ratio_list_max).clip(-2,2)
# #         ratio_list_max = moving_average(ratio_list_max)[4:-2]
#         correlation_fig.plot(np.arange(0,len(ratio_list_max)*5,5),np.array(ratio_list_max),color="C{}".format(i))
#         if good_flag == True:
#             correlation_fig.plot(first_goodmodel[0],first_goodmodel[1],"o",color="C{}".format(i))

#     path = 'model/compare30_H_bigbatch/'
    path = 'model/PFCHPC30_H_bigbatch/'
    # The list of "good" checkpoints does not depend on the seed; read it once.
    with open(path+"good_list.txt", mode="r") as f:
        good_set = set(f.read().splitlines())
#     good_set = set()

    # One HPC max-correlation curve per (run, seed); each entry is one checkpoint.
    allratio_list = []
    for num in range(3):
        for i in range(10):
            model_list = glob.glob(path+'*s'+str(i+1)+'_100_'+str(num+1)+'_*epoch*.pth')
            # Lexicographic sort, then a stable sort by length, so that epoch
            # numbers with the same prefix come out in numeric order.
            model_list = sorted(model_list)
            model_list = sorted(model_list,key=len,reverse=False)
            ratio_list = []
            ratio_list_max = []
            first_goodmodel = [0,0]
            good_flag = False
            for model in model_list:
                print(model)
                epoch = int(model.split("epoch")[-1].split(".")[0])
                # Only checkpoints up to epoch 199 are compared.
                if epoch>199:
                    continue
#                 PFC,HPC = main(model)
                PFC,HPC,PFC_max,HPC_max = main(model)
#                 ratio_list.append(PFC/HPC)
                ratio_list.append(np.abs(PFC-HPC))
#                 ratio_list_max.append(PFC_max/HPC_max)
#                 ratio_list_max.append(np.abs(PFC_max-HPC_max))
                ratio_list_max.append(HPC_max)
#                 if model in good_set:
#                     correlation_fig.plot(PFC_max,HPC_max,"o",color="b")
#                 else:
#                     correlation_fig.plot(PFC_max,HPC_max,"o",color="r")
                print(PFC_max,HPC_max)
                if not good_flag and model in good_set:
                    good_flag = True
                    first_goodmodel[0] = epoch
#                     first_goodmodel[1] = ratio_list[-1]
                    first_goodmodel[1] = ratio_list_max[-1]
#             ratio_list_max = np.array(ratio_list_max).clip(-2,2)
#             ratio_list_max = moving_average(ratio_list_max)[4:-2]
#             correlation_fig.plot(np.arange(0,len(ratio_list_max)*5,5),np.array(ratio_list_max),color="C{}".format(i))
#             if good_flag == True:
#                 correlation_fig.plot(first_goodmodel[0],first_goodmodel[1],"o",color="C{}".format(i))
            allratio_list.append(np.array(ratio_list_max))

    # NOTE: np.array() below requires every (run, seed) to have the same number
    # of checkpoints; ragged curves cannot be averaged.
    all_curves = np.array(allratio_list)
    mean_curve = np.mean(all_curves, axis=0)
    # Take the x-axis (one point per 5 epochs) from the averaged curve itself,
    # not from whichever seed happened to be processed last.
    epochs_axis = np.arange(0,len(mean_curve)*5,5)
    correlation_fig.plot(epochs_axis, mean_curve, color="b")
    np.save("PFCHPC_HPCmaxave.npy", mean_curve)
    np.save("PFCHPC_HPCmaxvar.npy",np.var(all_curves,axis=0))
        
        

In [None]:
### あとち　PFCHPCのPFCのcriterion

In [None]:
##########    Test for mixture weight!!!!!!!!!!!!!   #########################
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import time
import math
import matplotlib.pyplot as plt
import matplotlib.animation as animation
import sklearn 
import nolds
from sklearn.decomposition import PCA
from sklearn.cluster import KMeans
from mpl_toolkits.mplot3d import Axes3D
from scipy.special import expit
import pandas as pd

%matplotlib notebook

def Culc_gate(input, params, hiddens):
    """Recompute the sigmoid-squashed LSTM gate activations of the PFC and HPC cells.

    Only batch element 0 is evaluated.

    Args:
        input: external input batch; it is concatenated with ``hiddens[2]``
            and row 0 is used.
        params: parameter tensors, assumed in ``rnn.parameters()`` order:
            ``[PFC.weight_ih, PFC.weight_hh, PFC.bias_ih, PFC.bias_hh,
            HPC.weight_ih, HPC.weight_hh, HPC.bias_ih, HPC.bias_hh, ...]``
            (TODO confirm against the model used with this cell).
        hiddens: ``[PFC (h, c), HPC (h, c), Re]``. ``Re`` is a 2-D tensor
            whose rows are batch elements (inferred from the indexing below).

    Returns:
        ``[PFC_gates, HPC_gates]`` as Python lists of length
        ``params[0].shape[0]``, every entry passed through a sigmoid.
        NOTE: with PyTorch's (i, f, g, o) weight layout, the third quarter is
        the cell candidate ``g``. LSTMCell applies tanh to it, not a sigmoid.
    """
    re_state = hiddens[2]
    HPC_input = torch.cat([input, re_state], dim=1)
    PFC_input = torch.cat([hiddens[1][0][0], re_state[0]])
    # weight_hh multiplies the previous *hidden* state h (index 0 of each
    # (h, c) pair), not the cell state c, matching nn.LSTMCell.
    PFC_gates = params[0]@PFC_input + params[1]@hiddens[0][0][0] + params[2] + params[3]
    HPC_gates = params[4]@HPC_input[0] + params[5]@hiddens[1][0][0] + params[6] + params[7]
    PFC_gates = torch.sigmoid(PFC_gates)
    HPC_gates = torch.sigmoid(HPC_gates)

    return [PFC_gates.tolist(), HPC_gates.tolist()]
    
    

def MakeAnimation(data_x,data_y,traj_x,traj_y, data_limit, filename='anim_v4_2_1_r.gif'):
    """Animate a generated trajectory and save it as a GIF.

    Frame ``t`` draws ``traj[:t]`` as a red line and marks its most recent
    point (step ``t-1``) with a red dot. While ``t < data_limit``, the
    reference ``data[:t]`` is drawn in blue as well. After that, only the
    generated trajectory is shown.

    Args:
        data_x, data_y: reference coordinates.
        traj_x, traj_y: generated coordinates; ``traj_x.shape[0]`` frames are produced.
        data_limit: number of frames in which the reference is shown.
        filename: output GIF path (written with the pillow writer).

    Returns:
        0
    """
    fig, ax = plt.subplots()
    ims = []
    for t in range(traj_x.shape[0]):
        # Slice instead of indexing with t-1: at t == 0 the index would wrap to
        # the last sample and mark the trajectory's end in the first frame.
        last = slice(max(t - 1, 0), t)
        if t < data_limit:
            traj = ax.plot(data_x[:t], data_y[:t], "b", traj_x[:t], traj_y[:t], "r", traj_x[last], traj_y[last], "or")
#         else:
#             traj = ax.plot(data_x[:data_limit], data_y[:data_limit], "b", traj_x[:t], traj_y[:t], "r", traj_x[t-1], traj_y[t-1], "or", data_x[data_limit-1:t], data_y[data_limit-1:t], "--b")
        else:
            traj = ax.plot(traj_x[:t], traj_y[:t], "r", traj_x[last], traj_y[last], "or")
#             traj = ax.plot(data_x[:t], data_y[:t], "b", data_x[t-1], data_y[t-1], "ob", traj_x[:t], traj_y[:t], "r", traj_x[t-1], traj_y[t-1], "or")
        title = ax.text(0.5, 1.01, 'time={}'.format(t),
                     ha='center', va='bottom',
                     transform=ax.transAxes, fontsize='large')
        ims.append(traj+[title])
    ani = animation.ArtistAnimation(fig, ims, interval=30)
    ani.save(filename, writer='pillow')
    plt.show()

    return 0

def MakeAnimation_data(data_x,data_y, data_limit):
    """Animate the trajectory (data_x, data_y) growing frame by frame.

    The axes are fixed to x in [0.2, 0.8] and y in [-0.05, 0.8]. The result is
    saved to 'anim_Re_0.gif' with the pillow writer and then shown.
    ``data_limit`` is accepted but not used.

    Returns:
        0
    """
    fig, ax = plt.subplots()
    ax.set_xlim(0.2,0.8)
    ax.set_ylim(-0.05,0.8)
    # Each frame: the blue path up to (excluding) step t plus a time caption.
    frames = [
        ax.plot(data_x[:step], data_y[:step], "b")
        + [ax.text(0.5, 1.01, 'time={}'.format(step), ha='center', va='bottom',
                   transform=ax.transAxes, fontsize='large')]
        for step in range(data_x.shape[0])
    ]
    anim = animation.ArtistAnimation(fig, frames, interval=30)
    anim.save('anim_Re_0.gif', writer='pillow')
    plt.show()
    return 0

def MakeAnimation_attracter(data_x,data_y, filename='anim_attracter2.gif'):
    """Animate a 2-D trajectory with a dot on its most recent point and save it as a GIF.

    Frame ``t`` draws ``data[:t]`` in blue and marks step ``t-1``.

    Args:
        data_x, data_y: trajectory coordinates; ``data_x.shape[0]`` frames are produced.
        filename: output GIF path (written with the pillow writer).

    Returns:
        0
    """
    fig, ax = plt.subplots()
    ims = []
    for t in range(data_x.shape[0]):
        # Slice rather than index t-1 so the first frame does not wrap around
        # and mark the final point of the trajectory.
        last = slice(max(t - 1, 0), t)
        traj = ax.plot(data_x[:t], data_y[:t], "b", data_x[last], data_y[last], "ob")
        title = ax.text(0.5, 1.01, 'time={}'.format(t),
                     ha='center', va='bottom',
                     transform=ax.transAxes, fontsize='large')
        ims.append(traj+[title])
    ani = animation.ArtistAnimation(fig, ims, interval=10)
    ani.save(filename, writer='pillow')
    plt.show()

    return 0


    
def MakeAnimation_img(data,module):
    """Animate each row of ``data`` as a one-row image strip and save it as a GIF.

    Args:
        data: 2-D array; row ``t`` is shown (reshaped to ``1 x data.shape[1]``)
            in frame ``t``.
        module: label used in the frame caption and in the output name
            ``'anim_fire_<module>_.gif'``.

    Returns:
        0
    """
    fig, ax = plt.subplots()
    n_units = data.shape[1]
    frames = [
        [ax.imshow(data[t].reshape(1, n_units)),
         ax.text(0.5, 1.01, module + ' time={}'.format(t), ha='center', va='bottom',
                 transform=ax.transAxes, fontsize='large')]
        for t in range(data.shape[0])
    ]
    anim = animation.ArtistAnimation(fig, frames)
    anim.save('anim_fire_' + module + '_.gif', writer='pillow')
    plt.show()
    return 0

def mkOwnDataSet(data_size, filename, data_length=100, freq=60., noise=0.005):
    """Build ``data_size`` noisy, subsampled copies of two recorded 2-D trajectories.

    Reads ``<filename>_r.csv`` and ``<filename>_l.csv`` (comma-separated, two
    columns used). Each copy keeps every ``step``-th row, where ``step`` is
    ``round(n_rows / data_length)`` but at least 1. Gaussian noise with std
    ``noise`` is added independently to every kept coordinate.

    Args:
        data_size: number of noisy copies to produce.
        filename: path prefix of the two CSV files.
        data_length: target number of samples per copy (approximate).
        freq: accepted for compatibility; not used.
        noise: standard deviation of the additive Gaussian noise.

    Returns:
        ``(train_x, train_y)``: lists of ``data_size`` sequences of ``[x, y]`` pairs.
    """
    x = np.loadtxt(str(filename+"_r.csv"),delimiter=',')
    y = np.loadtxt(str(filename+"_l.csv"),delimiter=',')
    # round() yields 0 for recordings shorter than data_length / 2, and
    # np.arange with step 0 would raise; clamp the step to at least 1.
    x_step = max(1, round(x.shape[0]/data_length))
    y_step = max(1, round(y.shape[0]/data_length))
    train_x = []
    train_y = []

    for _ in range(data_size):
        train_x.append([[x[i][0] + np.random.normal(loc=0.0, scale=noise),x[i][1]+ np.random.normal(loc=0.0, scale=noise)] for i in np.arange(0,x.shape[0],x_step)])
        train_y.append([[y[i][0] + np.random.normal(loc=0.0, scale=noise),y[i][1]+ np.random.normal(loc=0.0, scale=noise)] for i in np.arange(0,y.shape[0],y_step)])
    return train_x, train_y


def mkOwnRandomBatch(train_x, train_t, batch_size=10):
    """Sample a random mini-batch of sequences from ``train_x`` (with replacement).

    Args:
        train_x: indexable collection of sequences; each sequence must be
            convertible by ``torch.tensor`` and all must share one shape.
        train_t: accepted for call compatibility; it is not used.
        batch_size: number of sequences to draw.

    Returns:
        Tensor of the stacked sequences with the first two axes swapped, i.e.
        sequence-step first and batch second.

    Raises:
        ValueError: if ``train_x`` is empty.
    """
    if len(train_x) == 0:
        raise ValueError("train_x must contain at least one sequence")
    # Draw from the whole dataset. An upper bound of len(train_x) - batch_size
    # would never pick the last batch_size sequences and would fail whenever
    # the dataset is not larger than one batch.
    indices = np.random.randint(0, len(train_x), size=batch_size)
    batch_x = [train_x[idx] for idx in indices]
    return torch.tensor(batch_x).transpose(0,1)

def make_Ttraj(direction):
    """Build a T-maze trajectory as a (20, 2) array of (x, y) points.

    The path runs up the stem at x = 0.5 (y 0 -> 0.5) in 5 points. It then
    goes out along the arm (y = 0.5) in 5 points and back to x = 0.5 in
    5 points, then down the stem again in 5 points.

    Args:
        direction: 1 for the right arm (x up to 0.75), 2 for the left arm
            (x down to 0.25).

    Raises:
        ValueError: for any other ``direction``.
    """
    if direction == 1:
        arm_x = 0.75
    elif direction == 2:
        arm_x = 0.25
    else:
        raise ValueError("direction must be 1 (right) or 2 (left), got {!r}".format(direction))

    straight_x = np.ones(5)*0.5
    branch1_x = np.linspace(0.5, arm_x, 5)
    branch2_x = np.linspace(arm_x, 0.5, 5)
    backstraight_x = np.ones(5)*0.5

    straight_y = np.linspace(0, 0.5, 5)
    branch1_y = np.linspace(0.5, 0.5, 5)
    branch2_y = np.linspace(0.5, 0.5, 5)
    backstraight_y = np.linspace(0.5 , 0, 5)

    traj_x = np.concatenate([straight_x, branch1_x, branch2_x, backstraight_x])
    traj_y = np.concatenate([straight_y, branch1_y, branch2_y, backstraight_y])
    traj = np.concatenate([[traj_x],[traj_y]],axis=0).T
    return traj

def make_Htraj(direction1, direction2):
    """Build an H-maze trajectory as a (40, 2) array of (x, y) points.

    The segments are:
    - stem at x = 0.5 going up (y 0 -> 0.5), 10 points;
    - out to the side arm at y = 0.5, 5 points;
    - a vertical excursion along the arm and back, 5 + 5 points;
    - back to x = 0.5, 5 points;
    - down the stem, 10 points.

    Args:
        direction1: 1 for the right arm (x = 0.75), 2 for the left arm (x = 0.25).
        direction2: 1 for a downward excursion (to y = 0.25), 2 for an
            upward one (to y = 0.75).

    Raises:
        ValueError: if either direction is not 1 or 2.
    """
    if direction1 == 1:
        arm_x = 0.75
    elif direction1 == 2:
        arm_x = 0.25
    else:
        raise ValueError("direction1 must be 1 (right) or 2 (left), got {!r}".format(direction1))
    if direction2 == 1:
        turn_y = 0.25
    elif direction2 == 2:
        turn_y = 0.75
    else:
        raise ValueError("direction2 must be 1 (down) or 2 (up), got {!r}".format(direction2))

    straight_x = np.ones(10)*0.5
    branch1_x = np.linspace(0.5, arm_x, 5)
    branch2_x = np.linspace(arm_x, arm_x, 5)
    branch3_x = np.linspace(arm_x, arm_x, 5)
    branch4_x = np.linspace(arm_x, 0.5, 5)
    backstraight_x = np.ones(10)*0.5

    straight_y = np.linspace(0, 0.5, 10)
    branch1_y = np.linspace(0.5, 0.5, 5)
    branch2_y = np.linspace(0.5, turn_y, 5)
    branch3_y = np.linspace(turn_y, 0.5, 5)
    branch4_y = np.linspace(0.5, 0.5, 5)
    backstraight_y = np.linspace(0.5 , 0, 10)

    traj_x = np.concatenate([straight_x, branch1_x, branch2_x, branch3_x, branch4_x, backstraight_x])
    traj_y = np.concatenate([straight_y, branch1_y, branch2_y, branch3_y, branch4_y, backstraight_y])
    traj = np.concatenate([[traj_x],[traj_y]],axis=0).T
    return traj


def make_traj(delay_length):
    """Build the four pattern sequences used by ``mkOwnDataSet_auto``.

    Each sequence concatenates noisy ``make_Htraj`` trajectories, with
    ``delay_length`` "stay" segments inserted after every trajectory
    (including the last one).  A stay segment is ``data_length`` rows drawn
    at random (with replacement) from ``stay1.csv`` (first two columns).

    Pattern ids: 1 = make_Htraj(1, 1), 2 = make_Htraj(1, 2),
                 3 = make_Htraj(2, 1), 4 = make_Htraj(2, 2).

    Args:
        delay_length: number of stay segments after each trajectory.

    Returns:
        tuple of four np.ndarray of shape (T, 2), one per pattern sequence.
    """
    noise = 0.005
    data_length = 40  # rows per stay segment

    # Pattern id -> noiseless template.  (Earlier versions loaded recorded
    # templates from right1.csv / left1.csv instead.)
    templates = {
        1: make_Htraj(1, 1),
        2: make_Htraj(1, 2),
        3: make_Htraj(2, 1),
        4: make_Htraj(2, 2),
    }
    z = np.loadtxt("stay1.csv", delimiter=',')

    # Active protocol ("ver1's test"): a delay follows every pattern.
    # Earlier protocols, kept for reference (delay only between the two ids):
    #   ver1: [1, 4], [2, 3], [3, 1], [4, 2]
    #   ver2: [1, 2], [2, 3], [3, 4], [4, 1]
    sequences = [
        [1, 4, 2, 3, 1],
        [2, 3, 1, 4, 2],
        [3, 1, 4, 2, 3],
        [4, 2, 3, 1, 4],
    ]

    data_list = []
    for sequence in sequences:
        segments = []
        for pid in sequence:
            template = templates[pid]
            # Add noise to a *copy*.  The old ``target = x_1; target += noise``
            # mutated the shared template in place, so noise compounded every
            # time a pattern was reused within and across sequences.
            segments.append(template + np.random.normal(loc=0.0, scale=noise, size=template.shape))
            for _ in range(delay_length):
                rows = np.random.choice(z.shape[0], data_length)
                segments.append(z[rows, :2])
        data_list.append(np.concatenate(segments, axis=0))
    return data_list[0], data_list[1], data_list[2], data_list[3]

def mkOwnDataSet_auto(data_size, delay_length, freq=60., noise=0.01):
    """Make ``data_size`` noisy copies of each of the four ``make_traj`` sequences.

    Every copy adds independent Gaussian noise (std ``noise``) to each
    coordinate of every point.  ``freq`` is accepted but unused.

    Returns:
        (train_x1, train_x2, train_y1, train_y2): four lists, each holding
        ``data_size`` copies, each copy a list of ``[x, y]`` pairs.
    """
    patterns = make_traj(delay_length)

    def jitter(sequence):
        # Noise order per point: x first, then y (keeps the RNG stream stable).
        return [[point[0] + np.random.normal(loc=0.0, scale=noise),
                 point[1] + np.random.normal(loc=0.0, scale=noise)]
                for point in sequence]

    datasets = ([], [], [], [])
    for _ in range(data_size):
        for bucket, sequence in zip(datasets, patterns):
            bucket.append(jitter(sequence))
    return datasets

def plot_distance_bet2traj(traj1,traj2,linelist):
    """Plot the step-wise Euclidean distance between two trajectories.

    Opens a new figure, plots ``||traj1[i] - traj2[i]||`` for every step of
    ``traj1`` and draws vertical lines at the x positions in ``linelist``
    spanning the distance range.

    Returns:
        list of per-step distances.
    """
    first = np.array(traj1)
    second = np.array(traj2)
    distances = [np.linalg.norm(step - second[i]) for i, step in enumerate(first)]

    plt.figure()
    plt.plot(distances)
    low, high = np.min(distances), np.max(distances)
    plt.vlines(linelist, low, high)

    return distances

def distance_bet2traj(traj1,traj2):
    """Return ``[||traj1[i] - traj2[i]|| for each step i of traj1]``."""
    first = np.array(traj1)
    second = np.array(traj2)
    return [np.linalg.norm(step - second[i]) for i, step in enumerate(first)]

def plot_activity_bet2traj(traj1,traj2,linelist):
    """Plot units whose absolute difference between two state trajectories is large.

    Computes ``|traj1[i] - traj2[i]|`` per step, then on a new figure plots
    (and prints the index of) every unit whose difference exceeds 0.5
    anywhere in steps 20..104.  ``linelist`` is accepted but not used.

    Returns:
        list of per-step absolute-difference arrays.
    """
    first = np.array(traj1)
    second = np.array(traj2)
    threshold = 0.5
    diffs = [np.abs(step - second[i]) for i, step in enumerate(first)]

    plt.figure()
    for unit, series in enumerate(np.array(diffs).T):
        if not np.any(series[20:105] > threshold):
            continue
        print(unit)
        plt.plot(series)

    return diffs

def search_delay(traj):
    """Return the steps where ``traj[:, 1]`` crosses the 0.45 level.

    Crossings alternate: an upward crossing is the first step with
    ``y > 0.45`` while below, a downward one the first step with
    ``y < 0.45`` while above.  Indices are returned as a float array.
    """
    crossings = []
    above = False
    for step, y in enumerate(traj[:, 1]):
        if not above and y > 0.45:
            crossings.append(step)
            above = True
        if above and y < 0.45:
            crossings.append(step)
            above = False
    return np.array(crossings, dtype=float)

def pick_delay(traj,states, threshold=0.2, skip=5):
    """Cut ``states`` into the segments where ``traj[:, 1]`` stays below ``threshold``.

    A segment opens at the first step with ``traj[i, 1] < threshold`` and
    closes at the first later step with ``traj[i, 1] > threshold``.  The
    first ``skip`` steps are ignored.

    Args:
        traj: array indexed as ``traj[i, 1]`` (the y coordinate per step).
        states: sequence aligned with ``traj`` along its first axis.
        threshold: y level that delimits the segments (default 0.2).
        skip: number of leading steps to ignore (default 5).

    Returns:
        list: one ``states`` slice per closed below-threshold segment,
        followed by the remainder ``states[last_crossing:]``.  (The remainder
        used to be taken from ``traj`` by mistake.)
    """
    crossings = []
    below = False
    for i in range(skip, traj.shape[0]):
        y = traj[i, 1]
        if y < threshold and not below:
            crossings.append(i)
            below = True
        if y > threshold and below:
            crossings.append(i)
            below = False

    states_list = []
    start = 0
    for n, point in enumerate(crossings):
        # Odd entries close a segment opened by the preceding even entry.
        if n % 2 != 0:
            states_list.append(states[start:point])
        start = point
    states_list.append(states[start:])
    return states_list

def pick_traj(traj,states, threshold=0.2, skip=5):
    """Cut ``states`` into cycles delimited by upward crossings of ``traj[:, 1]``.

    Crossings alternate between "drops below ``threshold``" and "rises
    above ``threshold``" (the first ``skip`` steps are ignored).  Each cycle
    runs from one upward crossing to the next; the first cycle starts at
    step 0.

    Args:
        traj: array indexed as ``traj[i, 1]`` (the y coordinate per step).
        states: sequence aligned with ``traj`` along its first axis.
        threshold: y level that delimits the cycles (default 0.2).
        skip: number of leading steps to ignore (default 5).

    Returns:
        list: one ``states`` slice per cycle, followed by the remainder
        ``states[last_upward_crossing:]``.  (The remainder used to be taken
        from ``traj`` by mistake.)
    """
    crossings = []
    below = False
    for i in range(skip, traj.shape[0]):
        y = traj[i, 1]
        if y < threshold and not below:
            crossings.append(i)
            below = True
        if y > threshold and below:
            crossings.append(i)
            below = False

    states_list = []
    start = 0
    for n, point in enumerate(crossings):
        # Odd entries are the upward crossings that end a cycle.
        if n % 2 != 0:
            states_list.append(states[start:point])
            start = point
    states_list.append(states[start:])
    return states_list

def match_length(states_list):
    """Trim every sequence to the length of the shortest one, keeping the tail.

    Fixes two defects of the previous version:
    * an empty element gave ``min_length == 0`` and ``state[-0:]`` returned
      the whole sequence instead of an empty one;
    * the sentinel start value of 1000 capped every result at 1000 items.

    Args:
        states_list: list of sliceable sequences (lists, arrays, ...).

    Returns:
        list of the same sequences, each reduced to its last ``m`` items,
        where ``m`` is the shortest length (empty input -> empty list).
    """
    if len(states_list) == 0:
        return []
    min_length = min(len(state) for state in states_list)
    return [state[len(state) - min_length:] for state in states_list]

def vec_var(datas):
    """Return the mean Euclidean distance of the samples from their centroid.

    Despite the name this is not a variance: the distances are not squared.
    Each element of ``datas`` is one sample; the centroid is the mean
    along axis 0.
    """
    samples = np.array(datas)
    centroid = np.average(samples, axis=0)
    total = sum(np.linalg.norm(sample - centroid) for sample in samples)
    return total / samples.shape[0]
    
def lyapunov_exp(data):
    """Return the mean of ``log|diff(data)|`` (average log step size).

    NOTE: this is not a standard Lyapunov-exponent estimate (no divergence
    of nearby trajectories is measured).  Zero steps produce ``-inf``.
    """
    step_sizes = np.abs(np.diff(data))
    return np.log(step_sizes).mean()

class MyLSTM(nn.Module):
    """Three ``nn.LSTMCell`` modules (PFC, HPC, Re) coupled through their hidden states.

    Per step, every module reads the *previous* step's states:
        Re  <- LSTMCell(cat(PFC h, HPC h), Re (h, c))
        HPC <- LSTMCell(cat(input, Re h), HPC (h, c))
        PFC <- LSTMCell(cat(HPC h, Re h), PFC (h, c))
        output = linear(HPC h)
    Hidden layout: ``[PFC (h, c), HPC (h, c), Re (h, c)]``.
    """

    def __init__(self, input_size, hidden_size, output_size, batch_size):
        super().__init__()

        self.hidden_size = hidden_size
        # Registration order (PFC, HPC, Re, linear) fixes parameters() /
        # state_dict() ordering.
        self.PFC = nn.LSTMCell(2 * hidden_size, hidden_size)
        self.HPC = nn.LSTMCell(input_size + hidden_size, hidden_size)
        self.Re = nn.LSTMCell(2 * hidden_size, hidden_size)
        self.linear = nn.Linear(hidden_size, output_size)
        self.batch_size = batch_size
        # NOTE(review): sparsity=1 asks torch to zero every element of each
        # column, i.e. these four matrices start at all zeros — confirm intended.
        for weight in (self.PFC.weight_ih, self.HPC.weight_ih,
                       self.PFC.weight_hh, self.HPC.weight_hh):
            nn.init.sparse_(weight.data, 1)

    def forward(self, input, hiddens):
        pfc_state, hpc_state, re_state = hiddens[0], hiddens[1], hiddens[2]
        re_next = self.Re(torch.cat([pfc_state[0], hpc_state[0]], dim=1), re_state)
        hpc_next = self.HPC(torch.cat([input, re_state[0]], dim=1), hpc_state)
        pfc_next = self.PFC(torch.cat([hpc_state[0], re_state[0]], dim=1), pfc_state)
        return self.linear(hpc_next[0]), [pfc_next, hpc_next, re_next]

    def initHidden(self):
        def zero_pair():
            return [torch.zeros(self.batch_size, self.hidden_size),
                    torch.zeros(self.batch_size, self.hidden_size)]
        # Order: PFC, HPC, Re.
        return [zero_pair(), zero_pair(), zero_pair()]

class MyLSTM_RNN(nn.Module):
    """PFC and HPC ``nn.LSTMCell`` modules coupled through an ``nn.RNNCell`` "Re".

    Per step, every module reads the *previous* step's states:
        Re  <- RNNCell(cat(PFC h, HPC h), Re h)
        HPC <- LSTMCell(cat(input, Re h), HPC (h, c))
        PFC <- LSTMCell(cat(HPC h, Re h), PFC (h, c))
        output = linear(HPC h)
    Hidden layout: ``[PFC (h, c), HPC (h, c), Re h]``.
    """

    def __init__(self, input_size, hidden_size, output_size, batch_size, sparse=5):
        super(MyLSTM_RNN, self).__init__()

        self.hidden_size_PFC = hidden_size
        self.hidden_size_HPC = hidden_size
        self.hidden_size_Re = hidden_size + 0
        # Registration order (PFC, HPC, Re, linear) fixes parameters() order,
        # which the gate helpers index positionally.
        self.PFC = nn.LSTMCell(self.hidden_size_HPC + self.hidden_size_Re, self.hidden_size_PFC)
        self.HPC = nn.LSTMCell(input_size + self.hidden_size_Re, self.hidden_size_HPC)
        self.Re = nn.RNNCell(self.hidden_size_HPC + self.hidden_size_PFC, self.hidden_size_Re)
        self.linear = nn.Linear(self.hidden_size_HPC, output_size)
        self.batch_size = batch_size
        # ``sparse`` is given in tenths: the default 5 zeroes about half of
        # each weight column.
        sparse = 0.1 * int(sparse)
        nn.init.sparse_(self.PFC.weight_ih.data, sparse)
        nn.init.sparse_(self.HPC.weight_ih.data, sparse)
        nn.init.sparse_(self.PFC.weight_hh.data, sparse)
        nn.init.sparse_(self.HPC.weight_hh.data, sparse)

    def forward(self, input, hiddens):
        """One step; returns ``(output, [PFC (h, c), HPC (h, c), Re h])``."""
        input = input.float()
        Re_input = torch.cat([hiddens[0][0], hiddens[1][0]], dim=1)
        Re_hidden = self.Re(Re_input, hiddens[2])
        HPC_input = torch.cat([input, hiddens[2]], dim=1)
        HPC_hidden = self.HPC(HPC_input, hiddens[1])
        PFC_input = torch.cat([hiddens[1][0], hiddens[2]], dim=1)
        PFC_hidden = self.PFC(PFC_input, hiddens[0])
        output = self.linear(HPC_hidden[0])
        return output, [PFC_hidden, HPC_hidden, Re_hidden]

    def initHidden(self):
        """Return all-zero states ``[PFC (h, c), HPC (h, c), Re h]``.

        Uses the per-module sizes; the previous version read a non-existent
        ``self.hidden_size`` and raised AttributeError.
        """
        HPC_hidden = [torch.zeros(self.batch_size, self.hidden_size_HPC), torch.zeros(self.batch_size, self.hidden_size_HPC)]
        PFC_hidden = [torch.zeros(self.batch_size, self.hidden_size_PFC), torch.zeros(self.batch_size, self.hidden_size_PFC)]
        Re_hidden = torch.zeros(self.batch_size, self.hidden_size_Re)
        return [PFC_hidden, HPC_hidden, Re_hidden]

    def initHidden_rand(self):
        """Return uniform-random states in [0, 0.1) with the same layout as initHidden."""
        const = 0.1
        var = 0.1
        HPC_hidden = [torch.rand(self.batch_size, self.hidden_size_HPC) * const, torch.rand(self.batch_size, self.hidden_size_HPC) * const]
        PFC_hidden = [torch.rand(self.batch_size, self.hidden_size_PFC) * var, torch.rand(self.batch_size, self.hidden_size_PFC) * var]
        Re_hidden = torch.rand(self.batch_size, self.hidden_size_Re) * const
        return [PFC_hidden, HPC_hidden, Re_hidden]

    def initHidden_test(self):
        """Return uniform-random states in [0, 0.2) with the same layout as initHidden."""
        const = 0.2
        var = 0.2
        HPC_hidden = [torch.rand(self.batch_size, self.hidden_size_HPC) * const, torch.rand(self.batch_size, self.hidden_size_HPC) * const]
        PFC_hidden = [torch.rand(self.batch_size, self.hidden_size_PFC) * var, torch.rand(self.batch_size, self.hidden_size_PFC) * var]
        Re_hidden = torch.rand(self.batch_size, self.hidden_size_Re) * var
        return [PFC_hidden, HPC_hidden, Re_hidden]

    def noiseHidden_rand(self, hiddens):
        """Add N(0, 0.1^2) noise to Re h in place (PFC/HPC h get scale-0 noise).

        The zero-scale draws are kept so the torch RNG stream is unchanged.
        Returns the same (mutated) ``hiddens`` object.
        """
        c = 0
        v = 0.1
        HPC_hidden = torch.randn(self.batch_size, self.hidden_size_HPC) * c
        hiddens[1][0].add_(HPC_hidden)
        PFC_hidden = torch.randn(self.batch_size, self.hidden_size_PFC) * c
        hiddens[0][0].add_(PFC_hidden)
        Re_hidden = torch.randn(self.batch_size, self.hidden_size_Re) * v
        hiddens[2].add_(Re_hidden)
        return hiddens

    def noiseHidden_select(self, hiddens):
        """Add N(0, 0.01^2) noise in place to Re units 1, 2, 8, 9 and 17 only.

        Requires hidden sizes of at least 18.  PFC/HPC get scale-0 noise on
        the same units.  Returns the same (mutated) ``hiddens`` object.
        """
        c = 0
        v = 0.01
        index = np.array([1, 2, 8, 9, 17])
        HPC_hidden = torch.zeros(self.batch_size, self.hidden_size_HPC)
        HPC_hidden[:, index] += torch.randn(self.batch_size, index.size) * c
        hiddens[1][0].add_(HPC_hidden)
        PFC_hidden = torch.zeros(self.batch_size, self.hidden_size_PFC)
        PFC_hidden[:, index] += torch.randn(self.batch_size, index.size) * c
        hiddens[0][0].add_(PFC_hidden)
        Re_hidden = torch.zeros(self.batch_size, self.hidden_size_Re) * c
        Re_hidden[:, index] += torch.randn(self.batch_size, index.size) * v
        hiddens[2].add_(Re_hidden)
        return hiddens

    def noiseHidden_dis(self, hiddens, statr, statl):
        """Shift Re h in place by ``-0.1 * (statl - statr)``.

        ``statr`` / ``statl`` presumably are reference Re states for the two
        sides — must broadcast to (batch, hidden_size_Re); confirm with the
        caller.  Requires hidden sizes of at least 18.  Returns the same
        (mutated) ``hiddens`` object.
        """
        c = 0
        v = -0.1
        index = np.array([1, 2, 8, 9, 17])
        HPC_hidden = torch.zeros(self.batch_size, self.hidden_size_HPC)
        HPC_hidden[:, index] += torch.randn(self.batch_size, index.size) * c
        hiddens[1][0].add_(HPC_hidden)
        PFC_hidden = torch.zeros(self.batch_size, self.hidden_size_PFC)
        PFC_hidden[:, index] += torch.randn(self.batch_size, index.size) * c
        hiddens[0][0].add_(PFC_hidden)
        Re_hidden = torch.randn(self.batch_size, self.hidden_size_Re) * 0.0
        Re_hidden += (-statr + statl) * v
        hiddens[2].add_(Re_hidden)
        return hiddens


def moving_average(data, window=5, mode="full"):
    """Smooth ``data`` with a uniform (boxcar) kernel of length ``window``.

    Args:
        data: 1-D array-like.
        window: kernel length (default 5, the previously hard-coded value).
        mode: passed to ``np.convolve``.  The default ``"full"`` keeps the old
            output: length ``len(data) + window - 1``, with partial-window
            ramps at both ends.  Use ``"valid"`` for fully covered positions
            only, or ``"same"`` for output aligned with ``data``.

    Returns:
        np.ndarray of smoothed values.

    Raises:
        ValueError: if ``window < 1``.
    """
    if window < 1:
        raise ValueError("window must be >= 1, got {!r}".format(window))
    kernel = np.ones(window) / window
    return np.convolve(data, kernel, mode=mode)


def main(model):
    """Score one trained ``MyLSTM_RNN`` checkpoint by its cycle-to-cycle PFC spread.

    The network is driven by a batch of noisy pattern-1 sequences for
    ``data_limit`` steps, then runs closed-loop on its own output for 10
    more steps.  PFC hidden states of batch element 0 are projected with
    PCA and cut into cycles by ``pick_traj``.  The across-cycle variance of
    the first three principal components is then averaged over the aligned
    time steps.

    Args:
        model: path to a ``MyLSTM_RNN`` ``state_dict`` checkpoint.

    Returns:
        float: sum over PC1..PC3 of the time-averaged across-cycle variance,
        or ``nan`` when ``pick_traj`` yields no interior cycle.
    """
    # Only needed for the PCA summary; imported here so this cell does not
    # depend on another cell having defined ``pd``.
    import pandas as pd

    training_size = 100
    hidden_size = 20
    batch_size = 10
    inputsize = 2
    outputsize = 2
    delay_length = 2
    pattern = 1
    data_limit = 600   # teacher-forced steps
    est_length = 0     # extra closed-loop sequence lengths beyond the fixed 10 steps
    n_pcs = 3          # leading principal components entering the score

    train_x1, train_x2, train_y1, train_y2 = mkOwnDataSet_auto(training_size, delay_length)

    rnn = MyLSTM_RNN(inputsize, hidden_size, outputsize, batch_size)
    # rnn = MyLSTM_RNN_uniPFC(inputsize, hidden_size, outputsize, batch_size)
    rnn.load_state_dict(torch.load(model))

    source = {1: train_x1, 2: train_x2, 3: train_y1}.get(pattern, train_y2)
    data = mkOwnRandomBatch(source, source, batch_size).float()

    traj, PFCstate, HPCstate, Restate = [], [], [], []

    def record(output, hidden):
        traj.append(output.tolist())
        PFCstate.append(hidden[0][0].tolist())
        HPCstate.append(hidden[1][0].tolist())
        Restate.append(hidden[2][0].tolist())

    hidden = rnn.initHidden_rand()
    # Alternative constant drive (feed ``init_point`` instead of ``data[k]``).
    # It is still drawn so the torch RNG stream matches earlier runs.
    init_point = torch.rand(10, 2)
    with torch.no_grad():  # inference only: no autograd graph over ~600 steps
        for k in range(data_limit):
            output, hidden = rnn(data[k], hidden)
            record(output, hidden)
        for _ in range(data.shape[0] * est_length + 10):
            output, hidden = rnn(output, hidden)
            record(output, hidden)
    traj = torch.squeeze(torch.tensor(traj)).numpy()

    dfs = np.array(PFCstate)[:, 0]
    # dfs = np.array(Restate)   # alternative: Re states
    pca = PCA()
    pca.fit(dfs)
    feature = pca.transform(dfs)
    n_components = len(pca.explained_variance_ratio_)
    print(pd.DataFrame(pca.explained_variance_ratio_,
                       index=["PC{}".format(x + 1) for x in range(n_components)]))

    # Cut the PC trajectory into cycles.  The first piece (start-up) and the
    # last piece (tail) are dropped before comparing cycles.
    delays = pick_traj(traj[:, 0], feature)
    return _cycle_spread(delays[1:-1], n_pcs)


def _cycle_spread(cycles, n_pcs=3):
    """Return the summed, time-averaged across-cycle variance of the first ``n_pcs`` columns.

    Cycles are aligned at their start and truncated to the shortest one.
    Returns ``nan`` when there are no cycles or the shortest one is empty.
    """
    if len(cycles) == 0:
        return float("nan")
    data_len = min(len(cycle) for cycle in cycles)
    if data_len == 0:
        return float("nan")
    aligned = np.array([np.asarray(cycle)[:data_len] for cycle in cycles])
    # aligned: (n_cycles, data_len, n_features); variance over cycles per step.
    per_step_var = np.var(aligned[:, :, :n_pcs], axis=0)
    return float(per_step_var.mean(axis=0).sum())


if __name__ == '__main__':
    # ``glob`` is not imported at the top of this cell.
    import glob

    fig = plt.figure()
    correlation_fig = fig.subplots()

    path = 'model/R20_H_bigbatch/'
    # path = 'model/R20_H_uniPFC_bigbatch/'
    for i in range(10):
        model_list = glob.glob(path + '*s' + str(i + 1) + '_100_3_*epoch*.pth')
        # Lexicographic sort, then a stable sort by length: "epoch5" comes
        # before "epoch10" for otherwise identical names.
        model_list = sorted(sorted(model_list), key=len)
        var_list = []
        for model in model_list:
            print(model)
            var_list.append(main(model))
        # One curve per seed: score vs. checkpoint index.
        correlation_fig.plot(var_list)
        

In [None]:
#v4 without Re
#v4_1 revised forward function

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import time
import math
import matplotlib.pyplot as plt
import matplotlib.animation as animation
import sklearn 
from sklearn.decomposition import PCA
from sklearn.cluster import KMeans
from mpl_toolkits.mplot3d import Axes3D
import pandas as pd


def Culc_gate(input, params, hiddens):
    """Return PFC and HPC LSTM gate activations (batch element 0) for the two-module ``MyLSTM``.

    The gate pre-activations are recomputed as ``W_ih x + W_hh h + b_ih + b_hh``,
    with the same inputs ``MyLSTM.forward`` uses:
    * HPC: ``x = cat(input, PFC h)``;
    * PFC: ``x = HPC h``.

    Args:
        input: HPC input for the step, (batch, input_size).
        params: model parameters in registration order:
            [PFC.W_ih, PFC.W_hh, PFC.b_ih, PFC.b_hh,
             HPC.W_ih, HPC.W_hh, HPC.b_ih, HPC.b_hh, ...].
        hiddens: ``[[PFC_h, PFC_c], [HPC_h, HPC_c]]``.

    Returns:
        ``[PFC_gates, HPC_gates]`` as Python lists of length
        ``4 * hidden_size``, in ``nn.LSTMCell`` chunk order
        (input, forget, cell, output).  Every chunk is sigmoid-squashed,
        including the cell chunk, which an LSTM would pass through tanh.
    """
    pfc_h = hiddens[0][0]
    hpc_h = hiddens[1][0]
    HPC_input = torch.cat([input, pfc_h], dim=1)
    PFC_input = hpc_h[0]
    # The recurrent term uses the hidden state h (index 0).  The previous
    # version used the cell state c (index 1), which is not what an LSTM
    # gate sees.
    PFC_gates = params[0] @ PFC_input + params[1] @ pfc_h[0] + params[2] + params[3]
    HPC_gates = params[4] @ HPC_input[0] + params[5] @ hpc_h[0] + params[6] + params[7]
    return [torch.sigmoid(PFC_gates).tolist(), torch.sigmoid(HPC_gates).tolist()]
    

def MakeAnimation(data_x,data_y,traj_x,traj_y, data_limit, filename='anim_v4_2_NN2r.gif', interval=100):
    """Animate a generated trajectory against the reference data and save a GIF.

    Frame ``t`` draws the first ``t`` generated points (red) and a marker on
    the most recent one.  The reference (blue) is handled in two phases:
    * while ``t < data_limit`` it is drawn up to ``t``;
    * afterwards it is frozen at ``data_limit`` and its continuation is drawn dashed.

    Args:
        data_x, data_y: reference coordinates (1-D arrays).
        traj_x, traj_y: generated coordinates (1-D arrays); their length
            sets the number of frames.
        data_limit: number of steps drawn as the solid reference.
        filename: output GIF path (default: the previously hard-coded name).
        interval: delay between frames in ms.

    Returns:
        0
    """
    fig, ax = plt.subplots()
    ims = []
    for t in range(traj_x.shape[0]):
        # Latest point as a 0- or 1-element slice.  At t == 0 the old
        # ``traj_x[t-1]`` wrapped around and marked the *last* point.
        last = slice(max(t - 1, 0), t)
        if t < data_limit:
            traj = ax.plot(data_x[:t], data_y[:t], "b", traj_x[:t], traj_y[:t], "r", traj_x[last], traj_y[last], "or")
        else:
            traj = ax.plot(data_x[:data_limit], data_y[:data_limit], "b", traj_x[:t], traj_y[:t], "r", traj_x[last], traj_y[last], "or", data_x[data_limit-1:t], data_y[data_limit-1:t], "--b")
        title = ax.text(0.5, 1.01, 'time={}'.format(t),
                     ha='center', va='bottom',
                     transform=ax.transAxes, fontsize='large')
        ims.append(traj + [title])
    ani = animation.ArtistAnimation(fig, ims, interval=interval)
    ani.save(filename, writer='pillow')
    plt.show()

    return 0

def MakeAnimation_testdata(data_x,data_y):
    """Animate the reference path growing one point per frame.

    Saves the result as 'anim_test_1.gif' (100 ms per frame) and returns 0.
    """
    fig, ax = plt.subplots()

    def build_frame(t):
        line = ax.plot(data_x[:t], data_y[:t], "b")
        caption = ax.text(0.5, 1.01, 'time={}'.format(t),
                          ha='center', va='bottom',
                          transform=ax.transAxes, fontsize='large')
        return line + [caption]

    frames = [build_frame(t) for t in range(data_x.shape[0])]
    ani = animation.ArtistAnimation(fig, frames, interval=100)
    ani.save('anim_test_1.gif', writer='pillow')
    plt.show()

    return 0

def MakeAnimation_attracter(data_x,data_y, filename='anim_attracter_v420_PFC.gif', interval=100):
    """Animate a 2-D state-space path (e.g. two PCA components) and save a GIF.

    Frame ``t`` draws the first ``t`` points as a blue line plus a marker on
    the most recent point.

    Args:
        data_x, data_y: 1-D arrays; their length sets the number of frames.
        filename: output GIF path (default: the previously hard-coded name).
        interval: delay between frames in ms.

    Returns:
        0
    """
    fig, ax = plt.subplots()
    ims = []
    for t in range(data_x.shape[0]):
        # 0/1-element slice for the latest point.  At t == 0 the old
        # ``data_x[t-1]`` wrapped around and marked the final point.
        last = slice(max(t - 1, 0), t)
        traj = ax.plot(data_x[:t], data_y[:t], "b", data_x[last], data_y[last], "ob")
        title = ax.text(0.5, 1.01, 'time={}'.format(t),
                     ha='center', va='bottom',
                     transform=ax.transAxes, fontsize='large')
        ims.append(traj + [title])
    ani = animation.ArtistAnimation(fig, ims, interval=interval)
    ani.save(filename, writer='pillow')
    plt.show()

    return 0



def MakeAnimation_3D(data_x,data_y, filename='anim_attracter2.gif', interval=100):
    """Animate a 2-D path point by point and save a GIF.

    NOTE: despite the name this draws on ordinary 2-D axes.  Frame ``t``
    shows the first ``t`` points plus a marker on the latest one.

    Args:
        data_x, data_y: 1-D arrays; their length sets the number of frames.
        filename: output GIF path (default: the previously hard-coded name).
        interval: delay between frames in ms.

    Returns:
        0
    """
    fig, ax = plt.subplots()
    ims = []
    for t in range(data_x.shape[0]):
        # 0/1-element slice for the latest point.  At t == 0 the old
        # ``data_x[t-1]`` wrapped around and marked the final point.
        last = slice(max(t - 1, 0), t)
        traj = ax.plot(data_x[:t], data_y[:t], "b", data_x[last], data_y[last], "ob")
        title = ax.text(0.5, 1.01, 'time={}'.format(t),
                     ha='center', va='bottom',
                     transform=ax.transAxes, fontsize='large')
        ims.append(traj + [title])
    ani = animation.ArtistAnimation(fig, ims, interval=interval)
    ani.save(filename, writer='pillow')
    plt.show()

    return 0

def MakeAnimation_img(data,module):
    """Animate per-step activity vectors as a one-row heat map.

    Each frame ``t`` shows ``data[t]`` reshaped to (1, n_units), titled with
    ``module`` and the step.  The result is saved as
    ``'anim_fire_<module>_.gif'`` (default frame interval).

    Returns:
        0
    """
    fig, ax = plt.subplots()
    frames = []
    for step in range(data.shape[0]):
        row = data[step].reshape(1, data.shape[1])
        heat = ax.imshow(row)
        caption = ax.text(0.5, 1.01, module + ' time={}'.format(step),
                          ha='center', va='bottom',
                          transform=ax.transAxes, fontsize='large')
        frames.append([heat, caption])
    ani = animation.ArtistAnimation(fig, frames)
    ani.save('anim_fire_' + module + '_.gif', writer='pillow')
    plt.show()

    return 0

def mkOwnDataSet(data_size, filename, data_length=100, freq=60., noise=0.005):
    """Load a right/left trajectory pair from CSV and make noisy, subsampled copies.

    Reads ``<filename>_r.csv`` and ``<filename>_l.csv`` (comma-separated; the
    first two columns are used).  It keeps every ``step``-th row, where
    ``step = round(n_rows / data_length)`` clamped to at least 1, and adds
    independent Gaussian noise of std ``noise`` to each coordinate.

    Args:
        data_size: number of noisy copies per side.
        filename: path prefix of the CSV pair.
        data_length: target number of points per copy.  This is approximate:
            the stride is rounded, so the actual length is
            ``ceil(n_rows / step)``.
        freq: unused; kept for call compatibility.
        noise: std of the additive Gaussian noise.

    Returns:
        (train_x, train_y): lists of ``data_size`` copies, each a list of
        ``[x, y]`` pairs.
    """
    x = np.loadtxt(str(filename + "_r.csv"), delimiter=',')
    y = np.loadtxt(str(filename + "_l.csv"), delimiter=',')

    def stride(arr):
        # round() gives 0 when the file has fewer than ~data_length / 2 rows,
        # which made np.arange fail.  Clamp to 1 (keep every row).
        return max(1, round(arr.shape[0] / data_length))

    x_idx = np.arange(0, x.shape[0], stride(x))
    y_idx = np.arange(0, y.shape[0], stride(y))
    train_x = []
    train_y = []
    for _ in range(data_size):
        train_x.append([[x[i][0] + np.random.normal(loc=0.0, scale=noise), x[i][1] + np.random.normal(loc=0.0, scale=noise)] for i in x_idx])
        train_y.append([[y[i][0] + np.random.normal(loc=0.0, scale=noise), y[i][1] + np.random.normal(loc=0.0, scale=noise)] for i in y_idx])
    return train_x, train_y

def mkOwnRandomBatch(train_x, train_t, batch_size=10):
    """
    Sample ``batch_size`` entries of ``train_x`` (with replacement) into one tensor.

    Args:
        train_x: non-empty sequence of samples.  Presumably each sample is an
            equal-length list/array of per-step feature vectors — confirm
            with the callers.
        train_t: unused; kept so existing ``(train_x, train_x, ...)`` calls
            keep working.  Only ``batch_x`` is returned, no ``batch_t``.
        batch_size: number of samples to draw.

    Returns:
        torch.Tensor with the sample axis moved to dimension 1
        (``transpose(0, 1)``), e.g. (seq_len, batch_size, n_features) for
        (seq_len, n_features) samples.

    Raises:
        ValueError: if ``train_x`` is empty.
    """
    if len(train_x) == 0:
        raise ValueError("train_x is empty")
    # Draw from all samples.  The previous upper bound len(train_x) - batch_size
    # silently excluded the last batch_size samples, and raised ValueError
    # whenever len(train_x) <= batch_size.
    batch_x = [train_x[np.random.randint(0, len(train_x))] for _ in range(batch_size)]
    return torch.tensor(batch_x).transpose(0, 1)


class MyLSTM(nn.Module):
    """Two ``nn.LSTMCell`` modules: HPC (input + PFC feedback) and PFC (driven by HPC).

    Per step, both cells read the *previous* step's states:
        HPC <- LSTMCell(cat(input, PFC h), HPC (h, c))
        PFC <- LSTMCell(HPC h, PFC (h, c))
        output = linear(HPC h)
    Hidden layout: ``[PFC (h, c), HPC (h, c)]``.
    """

    def __init__(self, input_size, hidden_size, output_size, batch_size):
        super().__init__()

        self.hidden_size = hidden_size
        # Registration order (PFC, HPC, linear) fixes parameters() order,
        # which the gate helper indexes positionally.
        self.PFC = nn.LSTMCell(hidden_size, hidden_size)
        self.HPC = nn.LSTMCell(input_size + hidden_size, hidden_size)
        self.linear = nn.Linear(hidden_size, output_size)
        self.batch_size = batch_size

    def forward(self, input, hiddens):
        x = input.float()
        pfc_state, hpc_state = hiddens[0], hiddens[1]
        hpc_next = self.HPC(torch.cat([x, pfc_state[0]], dim=1), hpc_state)
        pfc_next = self.PFC(hpc_state[0], pfc_state)
        return self.linear(hpc_next[0]), [pfc_next, hpc_next]

    def initHidden(self):
        shape = (self.batch_size, self.hidden_size)
        # Order: PFC, HPC; each a fresh [h, c] pair of zeros.
        return [[torch.zeros(shape), torch.zeros(shape)] for _ in range(2)]



def main(model_path='model/PFCHPC30_H/v4_3_121_s3_100_1_epoch170.pth'):
    """Inspect a trained two-module ``MyLSTM`` checkpoint with a set of plots.

    The network is teacher-forced with a noisy pattern-``pattern`` batch for
    ``data_limit`` steps.  It then runs closed-loop on its own output for one
    sequence length.  Plots produced:
    * generated vs. reference path;
    * recurrent weight matrices;
    * time x time state correlations;
    * PCA of the PFC states (2-D scatter and 3-D line);
    * raw PFC/HPC activity;
    * individual HPC units against the x output.

    Args:
        model_path: ``state_dict`` checkpoint for ``MyLSTM`` with
            ``hidden_size = 30`` (default: the previously hard-coded path).
    """
    training_size = 100
    hidden_size = 30
    batch_size = 10
    inputsize = 2
    outputsize = 2
    delay_length = 2
    pattern = 2
    data_limit = 2      # teacher-forced steps before closed-loop generation
    n_annotate = 200    # PCA points labelled with their time index
    n_unit_plots = 20   # HPC units plotted one figure each
    # Other checkpoints used previously:
    #   'model/v4_2Model_inter_long_s10.pth'
    #   'model/PFCHPC_131/v4_3_131_s8_100_1_epoch150.pth'
    #   'model/PFCHPC_30_131/v4_3_N30_131_s3_100_2_epoch85.pth'

    train_x1, train_x2, train_y1, train_y2 = mkOwnDataSet_auto(training_size, delay_length)
    source = {1: train_x1, 2: train_x2, 3: train_y1}.get(pattern, train_y2)
    data = mkOwnRandomBatch(source, source, batch_size).float()

    rnn = MyLSTM(inputsize, hidden_size, outputsize, batch_size)
    rnn.load_state_dict(torch.load(model_path))
    params = [p.data for p in rnn.parameters()]

    traj, PFCstate, HPCstate, Gate_states = [], [], [], []

    def record(output, hidden):
        traj.append(output.tolist())
        PFCstate.append(hidden[0][0].tolist())
        HPCstate.append(hidden[1][0].tolist())
        Gate_states.append(Culc_gate(output, params, hidden))

    hidden = rnn.initHidden()
    with torch.no_grad():  # inference only
        for k in range(data_limit):
            output, hidden = rnn(data[k], hidden)
            record(output, hidden)
        # Closed loop: each prediction becomes the next input.
        for _ in range(data.shape[0]):
            output, hidden = rnn(output, hidden)
            record(output, hidden)

    pltdata = torch.squeeze(data).numpy()
    traj = torch.squeeze(torch.tensor(traj)).numpy()
    PFC_arr = np.array(PFCstate)
    HPC_arr = np.array(HPCstate)

    # Reference (dashed) vs generated path, batch element 0.
    plt.figure()
    print(pltdata.shape)
    plt.plot(pltdata[:, 0, 0], pltdata[:, 0, 1], "--")
    plt.plot(traj[:, 0, 0], traj[:, 0, 1])
    plt.show()
    # MakeAnimation(pltdata[:,0,0], pltdata[:,0,1], traj[:,0,0], traj[:,0,1], data_limit)

    # Recurrent (hidden-to-hidden) weight matrices.
    weights = {n: np.array(p.data) for n, p in rnn.named_parameters()}
    PFC_inw = weights["PFC.weight_hh"]
    HPC_inw = weights["HPC.weight_hh"]
    fig2 = plt.figure()
    ax1 = fig2.add_subplot(121)
    ax2 = fig2.add_subplot(122)
    ax1.imshow(PFC_inw, cmap="coolwarm")
    ax2.imshow(HPC_inw, cmap="coolwarm")
    ax1.set_title("max = {:.2f},min = {:.2f}".format(np.max(PFC_inw), np.min(PFC_inw)))
    ax2.set_title("max = {:.2f},min = {:.2f}".format(np.max(HPC_inw), np.min(HPC_inw)))

    # Time x time correlation of the hidden-state vectors (batch element 0).
    fig3 = plt.figure()
    ax3 = fig3.add_subplot(121)
    ax4 = fig3.add_subplot(122)
    ax3.imshow(np.corrcoef(PFC_arr[:, 0]))
    ax4.imshow(np.corrcoef(HPC_arr[:, 0]))
    ax3.set_title("PFC")
    ax4.set_title("HPC")

    pca = PCA()
    dfs = PFC_arr[:, 0]
    pca.fit(dfs)
    feature = pca.transform(dfs)
    pc_names = ["PC{}".format(x + 1) for x in range(len(pca.explained_variance_ratio_))]
    print(pd.DataFrame(feature, columns=pc_names).head())
    print(pd.DataFrame(pca.explained_variance_ratio_, index=pc_names))
    print(pd.DataFrame(pca.components_, columns=["Hidden{}".format(x + 1) for x in range(dfs.shape[1])], index=pc_names))

    pred = KMeans(n_clusters=2).fit_predict(feature)

    plt.figure(figsize=(6, 6))
    plt.scatter(feature[:, 0], feature[:, 1], alpha=0.8)
    plt.grid()
    plt.xlabel("PC1")
    plt.ylabel("PC2")
    # Bounded by the run length (a fixed 200 overran shorter runs).
    for i in range(min(n_annotate, len(feature))):
        plt.annotate(i, (feature[i, 0], feature[i, 1]))
    plt.show()
    # MakeAnimation_attracter(feature[:, 0], feature[:, 1])

    # ``Axes3D(fig)`` is no longer attached to the figure in recent
    # Matplotlib; add_subplot(projection="3d") always is.
    fig3d = plt.figure()
    ax3d = fig3d.add_subplot(projection="3d")
    ax3d.plot(feature[:, 0], feature[:, 1], feature[:, 2], alpha=0.3)

    fig4 = plt.figure(figsize=(10, 5))
    ax5 = fig4.add_subplot(211)
    ax6 = fig4.add_subplot(212)
    ax5.imshow(PFC_arr[:, 0, :].T)
    ax6.imshow(HPC_arr[:, 0, :].T)
    ax5.set_title("PFC")
    ax6.set_title("HPC")
    print(pred == 0)

    # Gate_states (shape: steps x [PFC, HPC] x 4*hidden) is collected for
    # gate-activation plots, which are currently disabled.

    for i in range(min(n_unit_plots, hidden_size)):
        plt.figure()
        plt.plot(traj[:, 0, 0])
        plt.plot(HPC_arr[:, 0, i])
        plt.show()

    plt.figure()
    plt.plot(traj[:, 0, 0])
    plt.plot(traj[:, 0, 1])
    
    
if __name__ == "__main__":
    # Script entry point for this cell.
    main()

In [None]:
##########    Test for mixture weight!!!!!!!!!!!!!   #########################
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import time
import math
import matplotlib.pyplot as plt
import matplotlib.animation as animation
import sklearn 
import nolds
from sklearn.decomposition import PCA
from sklearn.cluster import KMeans
from mpl_toolkits.mplot3d import Axes3D
from scipy.special import expit
import pandas as pd

%matplotlib notebook

def Culc_gate(input, params, hiddens):
    """Recompute the PFC and HPC LSTM gate activations for batch element 0.

    Args:
        input: external input batch, assumed (batch, input_size) as in
            MyLSTM_RNN.forward, where it is concatenated with the Re state.
        params: tensors in ``rnn.parameters()`` order for a MyLSTM_RNN:
            params[0:4] are PFC (weight_ih, weight_hh, bias_ih, bias_hh) and
            params[4:8] are the same four tensors for HPC.
        hiddens: ``[[PFC_h, PFC_c], [HPC_h, HPC_c], Re_h]``.

    Returns:
        ``[PFC_gates, HPC_gates]`` as Python lists of length 4 * hidden_size,
        each equal to sigmoid(W_ih @ x + W_hh @ h + b_ih + b_hh) with x built
        as in MyLSTM_RNN.forward. The sigmoid is applied to all four gate chunks.
    """
    # Module inputs, built the same way MyLSTM_RNN.forward builds them.
    HPC_input = torch.cat([input, hiddens[2]], dim=1)
    PFC_input = torch.cat([hiddens[1][0][0], hiddens[2][0]])
    # The recurrent term of an LSTM gate uses the hidden state h (index 0 of
    # each [h, c] pair), not the cell state c (index 1).
    PFC_gates = params[0] @ PFC_input + params[1] @ hiddens[0][0][0] + params[2] + params[3]
    HPC_gates = params[4] @ HPC_input[0] + params[5] @ hiddens[1][0][0] + params[6] + params[7]
    PFC_gates = torch.sigmoid(PFC_gates)
    HPC_gates = torch.sigmoid(HPC_gates)

    return [PFC_gates.tolist(), HPC_gates.tolist()]
    
    

def MakeAnimation(data_x, data_y, traj_x, traj_y, data_limit, filename='anim_v4_2_1_r.gif'):
    """Animate a generated path (red) against reference data (blue) and save a GIF.

    Frame t draws traj[:t] in red and marks the newest drawn point with a red
    dot. For frames t < data_limit the reference path data[:t] is drawn in
    blue as well; later frames show only the generated path.

    Args:
        data_x, data_y: reference coordinates.
        traj_x, traj_y: generated coordinates; one frame per element of traj_x.
        data_limit: number of leading frames that also show the reference path.
        filename: output GIF path (written with the 'pillow' writer).

    Returns:
        0.
    """
    fig, ax = plt.subplots()
    frames = []
    for t in range(traj_x.shape[0]):
        # Index of the newest drawn point. At t == 0 nothing is drawn yet and
        # t - 1 would wrap around to the *last* point, so clamp to the first.
        cur = max(t - 1, 0)
        if t < data_limit:
            artists = ax.plot(data_x[:t], data_y[:t], "b",
                              traj_x[:t], traj_y[:t], "r",
                              traj_x[cur], traj_y[cur], "or")
        else:
            artists = ax.plot(traj_x[:t], traj_y[:t], "r",
                              traj_x[cur], traj_y[cur], "or")
        title = ax.text(0.5, 1.01, 'time={}'.format(t),
                        ha='center', va='bottom',
                        transform=ax.transAxes, fontsize='large')
        frames.append(artists + [title])
    ani = animation.ArtistAnimation(fig, frames, interval=30)
    ani.save(filename, writer='pillow')
    plt.show()

    return 0

def MakeAnimation_data(data_x,data_y, data_limit):
    """Animate the reference path (blue) growing point by point; save 'anim_Re_0.gif'.

    The axes are fixed to x in [0.2, 0.8] and y in [-0.05, 0.8]. ``data_limit``
    is accepted for signature parity with MakeAnimation but is not used.
    Returns 0.
    """
    fig, ax = plt.subplots()
    ax.set_xlim(0.2,0.8)
    ax.set_ylim(-0.05,0.8)
    # Each frame: the path drawn so far plus a "time=" caption above the axes.
    frames = [
        ax.plot(data_x[:step], data_y[:step], "b")
        + [ax.text(0.5, 1.01, 'time={}'.format(step),
                   ha='center', va='bottom',
                   transform=ax.transAxes, fontsize='large')]
        for step in range(data_x.shape[0])
    ]
    movie = animation.ArtistAnimation(fig, frames, interval=30)
    movie.save('anim_Re_0.gif', writer='pillow')
    plt.show()
    return 0

def MakeAnimation_attracter(data_x, data_y, filename='anim_attracter2.gif'):
    """Animate a 2-D path growing point by point and save it as a GIF.

    Frame t draws data[:t] as a blue line and marks the newest drawn point
    with a blue dot.

    Args:
        data_x, data_y: coordinates; one frame per element of data_x.
        filename: output GIF path (written with the 'pillow' writer).

    Returns:
        0.
    """
    fig, ax = plt.subplots()
    frames = []
    for t in range(data_x.shape[0]):
        # At t == 0, t - 1 would wrap to the last point; clamp to the first.
        cur = max(t - 1, 0)
        artists = ax.plot(data_x[:t], data_y[:t], "b", data_x[cur], data_y[cur], "ob")
        title = ax.text(0.5, 1.01, 'time={}'.format(t),
                        ha='center', va='bottom',
                        transform=ax.transAxes, fontsize='large')
        frames.append(artists + [title])
    ani = animation.ArtistAnimation(fig, frames, interval=10)
    ani.save(filename, writer='pillow')
    plt.show()

    return 0


    
def MakeAnimation_img(data,module):
    """Animate each row of ``data`` as a one-row heat-map; save 'anim_fire_<module>_.gif'.

    Row t is reshaped to (1, data.shape[1]) and drawn with imshow. The caption
    is ``module`` followed by ' time=<t>'. Returns 0.
    """
    fig, ax = plt.subplots()
    frames = []
    step = 0
    while step < data.shape[0]:
        strip = data[step].reshape(1, data.shape[1])
        image = ax.imshow(strip)
        caption = ax.text(0.5, 1.01, module + ' time={}'.format(step),
                          ha='center', va='bottom',
                          transform=ax.transAxes, fontsize='large')
        frames.append([image, caption])
        step += 1
    movie = animation.ArtistAnimation(fig, frames)
    movie.save('anim_fire_' + module + '_.gif', writer='pillow')
    plt.show()
    return 0

def mkOwnDataSet(data_size, filename, data_length=100, freq=60., noise=0.005):
    """Load a right/left recording pair and build noisy, subsampled copies of each.

    Reads ``filename + "_r.csv"`` and ``filename + "_l.csv"``. Each copy keeps
    every ``step``-th row, where step = round(n_rows / data_length) clamped to
    at least 1. Independent Gaussian noise (std ``noise``) is added to the
    first two columns of every kept row. ``freq`` is accepted but unused.

    Returns:
        (train_x, train_y): two lists of ``data_size`` sequences, each a list
        of [x, y] pairs (from the "_r" and "_l" files respectively).
    """
    x = np.loadtxt(str(filename+"_r.csv"),delimiter=',')
    y = np.loadtxt(str(filename+"_l.csv"),delimiter=',')

    def _subsample_with_noise(rows):
        # Clamp the stride to 1: for recordings shorter than data_length / 2 the
        # rounded stride is 0, and np.arange rejects a zero step.
        step = max(1, round(rows.shape[0] / data_length))
        return [[rows[i][0] + np.random.normal(loc=0.0, scale=noise),
                 rows[i][1] + np.random.normal(loc=0.0, scale=noise)]
                for i in np.arange(0, rows.shape[0], step)]

    train_x = []
    train_y = []
    for _ in range(data_size):
        train_x.append(_subsample_with_noise(x))
        train_y.append(_subsample_with_noise(y))
    return train_x, train_y


def mkOwnRandomBatch(train_x, train_t, batch_size=10):
    """Draw a random batch of sequences from ``train_x`` (with replacement).

    Args:
        train_x: list of sequences of equal length, each a list of feature vectors.
        train_t: unused; kept for call-site compatibility.
        batch_size: number of sequences to draw.

    Returns:
        Time-major tensor of shape (seq_len, batch_size, n_features).
    """
    # Every index is drawn independently from the whole pool. An upper bound of
    # len(train_x) - batch_size would never pick the last batch_size sequences
    # and would raise ValueError whenever len(train_x) <= batch_size.
    batch_x = [train_x[np.random.randint(0, len(train_x))] for _ in range(batch_size)]
    return torch.tensor(batch_x).transpose(0,1)

def make_Ttraj(direction):
    """Build a synthetic T-maze trajectory as a (20, 2) array of (x, y) points.

    There are four 5-point legs, and leg endpoints repeat at the junctions:
    1. Up the stem from (0.5, 0) to (0.5, 0.5).
    2. Out to the arm end along y = 0.5.
    3. Back to x = 0.5 along y = 0.5.
    4. Down the stem to (0.5, 0).

    Args:
        direction: 1 for the arm ending at x = 0.75, 2 for the arm at x = 0.25.

    Raises:
        ValueError: if ``direction`` is not 1 or 2.
    """
    arm_end_x = {1: 0.75, 2: 0.25}
    if direction not in arm_end_x:
        raise ValueError(f"direction must be 1 or 2, got {direction!r}")
    end_x = arm_end_x[direction]

    straight_x = np.full(5, 0.5)
    branch1_x = np.linspace(0.5, end_x, 5)
    branch2_x = np.linspace(end_x, 0.5, 5)
    backstraight_x = np.full(5, 0.5)

    straight_y = np.linspace(0, 0.5, 5)
    branch_y = np.full(5, 0.5)
    backstraight_y = np.linspace(0.5, 0, 5)

    traj_x = np.concatenate([straight_x, branch1_x, branch2_x, backstraight_x])
    traj_y = np.concatenate([straight_y, branch_y, branch_y, backstraight_y])
    return np.stack([traj_x, traj_y], axis=1)

def make_Htraj(direction1, direction2):
    """Build a synthetic H-maze trajectory as a (40, 2) array of (x, y) points.

    The path has these legs (leg endpoints repeat at the junctions):
    1. Stem (10 points): up from (0.5, 0) to (0.5, 0.5).
    2. Out to the side arm along y = 0.5 (5 points).
    3. Along the side arm to its end and back (5 + 5 points).
    4. Back to x = 0.5 along y = 0.5 (5 points).
    5. Down the stem to (0.5, 0) (10 points).

    Args:
        direction1: 1 for the side arm at x = 0.75, 2 for the arm at x = 0.25.
        direction2: 1 to visit y = 0.25 along the arm, 2 to visit y = 0.75.

    Raises:
        ValueError: if either direction is not 1 or 2.
    """
    arm_x = {1: 0.75, 2: 0.25}
    turn_y = {1: 0.25, 2: 0.75}
    if direction1 not in arm_x:
        raise ValueError(f"direction1 must be 1 or 2, got {direction1!r}")
    if direction2 not in turn_y:
        raise ValueError(f"direction2 must be 1 or 2, got {direction2!r}")
    ax_ = arm_x[direction1]
    ty = turn_y[direction2]

    straight_x = np.full(10, 0.5)
    branch1_x = np.linspace(0.5, ax_, 5)
    branch2_x = np.full(5, ax_)
    branch3_x = np.full(5, ax_)
    branch4_x = np.linspace(ax_, 0.5, 5)
    backstraight_x = np.full(10, 0.5)

    straight_y = np.linspace(0, 0.5, 10)
    branch1_y = np.full(5, 0.5)
    branch2_y = np.linspace(0.5, ty, 5)
    branch3_y = np.linspace(ty, 0.5, 5)
    branch4_y = np.full(5, 0.5)
    backstraight_y = np.linspace(0.5, 0, 10)

    traj_x = np.concatenate([straight_x, branch1_x, branch2_x, branch3_x, branch4_x, backstraight_x])
    traj_y = np.concatenate([straight_y, branch1_y, branch2_y, branch3_y, branch4_y, backstraight_y])
    return np.stack([traj_x, traj_y], axis=1)


def make_traj(delay_length):
    """Build the four synthetic H-maze input sequences.

    Each sequence is a chain of events. Events 1-4 insert a noisy copy of one
    of the make_Htraj trajectories:
        1 -> make_Htraj(1, 1), 2 -> make_Htraj(1, 2),
        3 -> make_Htraj(2, 1), 4 -> make_Htraj(2, 2).
    Every event is followed by ``delay_length`` delay periods. Each delay
    period inserts ``data_length`` (x, y) rows sampled with replacement from
    "stay1.csv".

    Args:
        delay_length: number of delay periods after each event.

    Returns:
        Four (n_points, 2) arrays, one per event sequence.
    """
    noise = 0.005       # std of the Gaussian jitter added to each trajectory event
    data_length = 40    # points per delay period

    templates = {
        1: make_Htraj(1, 1),
        2: make_Htraj(1, 2),
        3: make_Htraj(2, 1),
        4: make_Htraj(2, 2),
    }
    z = np.loadtxt("stay1.csv",delimiter=',')

    # Event sequences ("ver1's test"): every event is followed by a delay.
    # Variants used previously, with no delay after the last event:
    #   ver1: [1, 4], [2, 3], [3, 1], [4, 2]
    #   ver2: [1, 2], [2, 3], [3, 4], [4, 1]
    sequences = [
        [1, 4, 2, 3, 1],
        [2, 3, 1, 4, 2],
        [3, 1, 4, 2, 3],
        [4, 2, 3, 1, 4],
    ]
    orders = []
    for seq in sequences:
        order = []
        for event in seq:
            order.append(event)
            order.extend([0] * delay_length)
        orders.append(order)

    data_list = []
    for order in orders:
        segments = []
        for idx in order:
            if idx == 0:
                rows = np.random.choice(z.shape[0], data_length)
                segments.append(z[rows, :2])
            else:
                template = templates[idx]
                # Build a new array: an in-place `+=` would write the noise into
                # the shared template and accumulate it over every later reuse.
                segments.append(template + np.random.normal(loc=0.0, scale=noise, size=template.shape))
        data_list.append(np.concatenate(segments, axis=0))
    return data_list[0],data_list[1],data_list[2],data_list[3]

def mkOwnDataSet_auto(data_size, delay_length, freq=60., noise=0.01):
    """Build four noisy training sets from the sequences returned by make_traj.

    make_traj(delay_length) is called once. Each of its four sequences is then
    copied ``data_size`` times with independent Gaussian noise (std ``noise``)
    added to both coordinates of every point. ``freq`` is accepted but unused.

    Returns:
        (train_x1, train_x2, train_y1, train_y2): lists of ``data_size``
        sequences, each a list of [x, y] pairs.
    """
    templates = make_traj(delay_length)

    def jittered(seq):
        # Noise for the x and then the y coordinate of each point, row by row.
        return [[seq[row][0] + np.random.normal(loc=0.0, scale=noise),
                 seq[row][1] + np.random.normal(loc=0.0, scale=noise)]
                for row in np.arange(seq.shape[0])]

    buckets = tuple([] for _ in templates)
    for _ in range(data_size):
        for bucket, seq in zip(buckets, templates):
            bucket.append(jittered(seq))

    train_x1, train_x2, train_y1, train_y2 = buckets
    return train_x1, train_x2, train_y1, train_y2

def plot_distance_bet2traj(traj1,traj2,linelist):
    """Plot the per-step distance between two trajectories and return it.

    The distance at step t is np.linalg.norm(traj1[t] - traj2[t]). Vertical
    lines are drawn at the x positions in ``linelist``, spanning from the
    smallest to the largest distance.
    """
    first = np.array(traj1)
    second = np.array(traj2)
    distances = [np.linalg.norm(first[t] - second[t]) for t in range(first.shape[0])]

    plt.figure()
    plt.plot(distances)
    plt.vlines(linelist, np.min(distances), np.max(distances))
    return distances

def distance_bet2traj(traj1,traj2):
    """Return the list of per-step distances np.linalg.norm(traj1[t] - traj2[t]).

    The number of steps is taken from ``traj1``.
    """
    first = np.array(traj1)
    second = np.array(traj2)
    return [np.linalg.norm(first[t] - second[t]) for t in range(first.shape[0])]

def plot_activity_bet2traj(traj1,traj2,linelist):
    """Plot |traj1 - traj2| for every unit whose difference exceeds 0.5 in steps 20-104.

    The index of each such unit is printed before its curve is drawn.
    ``linelist`` is accepted but unused. Returns the list of per-step
    absolute-difference arrays.
    """
    first = np.array(traj1)
    second = np.array(traj2)
    threshold = 0.5
    diffs = [np.abs(first[t] - second[t]) for t in range(first.shape[0])]

    plt.figure()
    for unit, series in enumerate(np.array(diffs).T):
        # Only units that diverge inside the fixed window [20, 105) are shown.
        if (series[20:105] > threshold).any():
            print(unit)
            plt.plot(series)

    return diffs

def search_delay(traj):
    """Return the step indices where traj[:, 1] crosses the level 0.45.

    Crossings alternate: first a step where the value rises above 0.45, then
    a step where it falls below 0.45, and so on. The indices are returned as
    a float array.
    """
    level = 0.45
    marks = []
    above = False
    for step, y in enumerate(traj[:, 1]):
        if not above and y > level:
            marks.append(step)
            above = True
        elif above and y < level:
            marks.append(step)
            above = False
    return np.array(marks, dtype=float)

def pick_delay(traj,states):
    """Extract the slices of ``states`` during which traj[:, 1] stays below 0.2.

    Steps 0-4 are ignored. Edges are found as alternating steps: one where
    traj[:, 1] drops below 0.2, then one where it rises above 0.2. Every
    (falling, rising) pair yields ``states[falling:rising]``.

    NOTE(review): the final list element is ``traj[last_edge:]``, a slice of
    *traj*, not *states*. Callers seen in this file drop it; confirm intended.
    """
    threshold = 0.2
    edges = []
    below = False
    for t in range(5, traj.shape[0]):
        if not below and traj[t, 1] < threshold:
            edges.append(t)
            below = True
        elif below and traj[t, 1] > threshold:
            edges.append(t)
            below = False

    pieces = []
    start = 0
    for n, edge in enumerate(edges):
        if n % 2 == 1:      # rising edge: close the below-threshold stretch
            pieces.append(states[start:edge])
        start = edge
    pieces.append(traj[start:])
    return pieces

def pick_traj(traj,states):
    """Cut ``states`` into cycles ending at each rise of traj[:, 1] back above 0.2.

    Steps 0-4 are ignored. Edges are found as alternating steps: one where
    traj[:, 1] drops below 0.2, then one where it rises above 0.2. Each rising
    edge closes a cycle that starts at the previous rising edge (or at step 0
    for the first cycle).

    NOTE(review): the final list element is ``traj[last_rising_edge:]``, a
    slice of *traj*, not *states*. main() drops it via ``[1:-1]``; confirm
    intended.
    """
    threshold = 0.2
    edges = []
    below = False
    for t in range(5, traj.shape[0]):
        if not below and traj[t, 1] < threshold:
            edges.append(t)
            below = True
        elif below and traj[t, 1] > threshold:
            edges.append(t)
            below = False

    cycles = []
    start = 0
    for n, edge in enumerate(edges):
        if n % 2 == 1:      # rising edge closes the current cycle
            cycles.append(states[start:edge])
            start = edge
    cycles.append(traj[start:])
    return cycles

def match_length(states_list):
    """Trim every sequence to the shortest length, keeping the most recent samples.

    Args:
        states_list: iterable of sliceable sequences (lists, arrays, ...).

    Returns:
        A list holding the last ``min_len`` elements of each sequence, where
        ``min_len`` is the shortest length in ``states_list``. Returns an
        empty list for empty input.
    """
    states_list = list(states_list)
    if not states_list:
        return []
    # The true minimum, with no cap: a fixed sentinel would silently truncate
    # sequences longer than the sentinel.
    min_length = min(len(state) for state in states_list)
    # Slice from an explicit start index, because state[-0:] is the whole sequence.
    return [state[len(state) - min_length:] for state in states_list]

def vec_var(datas):
    """Return the mean Euclidean distance of the entries of ``datas`` from their centroid.

    Despite the name, this is a mean distance (a spread measure), not a
    variance. Each distance is np.linalg.norm of (entry - centroid).
    """
    points = np.array(datas)
    centroid = np.average(points, axis=0)
    distances = [np.linalg.norm(point - centroid) for point in points]
    return sum(distances) / points.shape[0]
    
def lyapunov_exp(data):
    """Return the mean of log|data[t+1] - data[t]| over the series.

    NOTE(review): this is the average log step size of a single series. It
    does not compare nearby trajectories, so confirm it is the intended
    measure. Zero steps give -inf.
    """
    step_sizes = np.abs(np.diff(data))
    return np.log(step_sizes).mean()

class MyLSTM(nn.Module):
    """PFC / HPC / Re model in which all three modules are LSTMCells.

    On each step, every module reads the previous states of the others:
      * Re  <- [PFC_h, HPC_h]
      * HPC <- [input, Re_h]
      * PFC <- [HPC_h, Re_h]
    The output is a linear readout of the new HPC hidden state.
    Hidden-state layout: [[PFC_h, PFC_c], [HPC_h, HPC_c], [Re_h, Re_c]].
    """
    def __init__(self, input_size, hidden_size, output_size, batch_size):
        super(MyLSTM, self).__init__()

        self.hidden_size = hidden_size
        # Registration order (PFC, HPC, Re, linear) fixes the order of
        # self.parameters(); Culc_gate-style code indexes that list positionally.
        self.PFC = nn.LSTMCell(hidden_size*2, hidden_size)
        self.HPC = nn.LSTMCell(input_size+hidden_size, hidden_size)
        self.Re = nn.LSTMCell(hidden_size*2, hidden_size)
        self.linear = nn.Linear(hidden_size, output_size)
        self.batch_size = batch_size
        # NOTE(review): per the torch.nn.init.sparse_ docs, the second argument
        # is the fraction of each column set to zero, so 1 zeroes these weight
        # matrices at init (MyLSTM_RNN uses 0.5). Confirm intended.
        nn.init.sparse_(self.PFC.weight_ih.data,1)
        nn.init.sparse_(self.HPC.weight_ih.data,1)
        nn.init.sparse_(self.PFC.weight_hh.data,1)
        nn.init.sparse_(self.HPC.weight_hh.data,1)

    def forward(self, input, hiddens):
        """One step: returns (output, [PFC_hidden, HPC_hidden, Re_hidden])."""
        # Re reads the previous PFC and HPC hidden states.
        Re_input = torch.cat([hiddens[0][0],hiddens[1][0]],dim=1)
        Re_hidden = self.Re(Re_input, hiddens[2])
        # HPC reads the input plus the *previous* Re hidden state (not Re_hidden).
        HPC_input = torch.cat([input,hiddens[2][0]],dim=1)
        HPC_hidden = self.HPC(HPC_input, hiddens[1])
        # PFC reads the previous HPC and Re hidden states.
        PFC_input = torch.cat([hiddens[1][0],hiddens[2][0]],dim=1)
        PFC_hidden = self.PFC(PFC_input, hiddens[0])
        output = self.linear(HPC_hidden[0])
        return output, [PFC_hidden, HPC_hidden, Re_hidden]

    def initHidden(self):
        """Return all-zero [[PFC_h, PFC_c], [HPC_h, HPC_c], [Re_h, Re_c]], each (batch_size, hidden_size)."""
        HPC_hidden = [torch.zeros(self.batch_size, self.hidden_size), torch.zeros(self.batch_size, self.hidden_size)]
        PFC_hidden = [torch.zeros(self.batch_size, self.hidden_size), torch.zeros(self.batch_size, self.hidden_size)]
        Re_hidden = [torch.zeros(self.batch_size, self.hidden_size), torch.zeros(self.batch_size, self.hidden_size)]
        return [PFC_hidden,HPC_hidden,Re_hidden]

class MyLSTM_RNN(nn.Module):
    """PFC / HPC / Re model: PFC and HPC are LSTMCells, Re is a plain RNNCell.

    On each step, every module reads the previous states of the others:
      * Re  <- [PFC_h, HPC_h]
      * HPC <- [input, Re_h]
      * PFC <- [HPC_h, Re_h]
    The output is a linear readout of the new HPC hidden state.

    Hidden-state layout: [[PFC_h, PFC_c], [HPC_h, HPC_c], Re_h], each tensor
    of shape (batch_size, hidden_size).
    """
    def __init__(self, input_size, hidden_size, output_size, batch_size, sparse=5):
        super(MyLSTM_RNN, self).__init__()

        self.hidden_size_PFC = hidden_size
        self.hidden_size_HPC = hidden_size
        self.hidden_size_Re = hidden_size+0
        # Registration order (PFC, HPC, Re, linear) fixes the order of
        # self.parameters(), which Culc_gate indexes positionally. Keep it.
        self.PFC = nn.LSTMCell(self.hidden_size_HPC+self.hidden_size_Re, self.hidden_size_PFC)
        self.HPC = nn.LSTMCell(input_size+self.hidden_size_Re, self.hidden_size_HPC)
        self.Re = nn.RNNCell(self.hidden_size_HPC+self.hidden_size_PFC, self.hidden_size_Re)
        self.linear = nn.Linear(self.hidden_size_HPC, output_size)
        self.batch_size = batch_size
        # `sparse` is given in tenths: sparse=5 -> sparsity 0.5 for nn.init.sparse_.
        sparse = 0.1*int(sparse)
        nn.init.sparse_(self.PFC.weight_ih.data,sparse)
        nn.init.sparse_(self.HPC.weight_ih.data,sparse)
        nn.init.sparse_(self.PFC.weight_hh.data,sparse)
        nn.init.sparse_(self.HPC.weight_hh.data,sparse)

    def forward(self, input, hiddens):
        """One step: returns (output, [PFC_hidden, HPC_hidden, Re_hidden])."""
        input = input.float()
        # Re reads the previous PFC and HPC hidden states.
        Re_input = torch.cat([hiddens[0][0],hiddens[1][0]],dim=1)
        Re_hidden = self.Re(Re_input, hiddens[2])
        # HPC and PFC both read the *previous* Re state (hiddens[2]), not Re_hidden.
        HPC_input = torch.cat([input,hiddens[2]],dim=1)
        HPC_hidden = self.HPC(HPC_input, hiddens[1])
        PFC_input = torch.cat([hiddens[1][0],hiddens[2]],dim=1)
        PFC_hidden = self.PFC(PFC_input, hiddens[0])
        output = self.linear(HPC_hidden[0])
        return output, [PFC_hidden, HPC_hidden, Re_hidden]

    def initHidden(self):
        """Return an all-zero hidden state [[PFC_h, PFC_c], [HPC_h, HPC_c], Re_h]."""
        # Uses the per-module sizes: this class never defines self.hidden_size.
        HPC_hidden = [torch.zeros(self.batch_size, self.hidden_size_HPC), torch.zeros(self.batch_size, self.hidden_size_HPC)]
        PFC_hidden = [torch.zeros(self.batch_size, self.hidden_size_PFC), torch.zeros(self.batch_size, self.hidden_size_PFC)]
        Re_hidden = torch.zeros(self.batch_size, self.hidden_size_Re)
        return [PFC_hidden,HPC_hidden,Re_hidden]

    def initHidden_rand(self):
        """Return a hidden state with every entry drawn uniformly from [0, 0.1)."""
        const = 0.1
        var = 0.1
        HPC_hidden = [torch.rand(self.batch_size, self.hidden_size_HPC)*const, torch.rand(self.batch_size, self.hidden_size_HPC)*const]
        PFC_hidden = [torch.rand(self.batch_size, self.hidden_size_PFC)*var, torch.rand(self.batch_size, self.hidden_size_PFC)*var]
        Re_hidden = torch.rand(self.batch_size, self.hidden_size_Re)*const
        return [PFC_hidden,HPC_hidden,Re_hidden]

    def initHidden_test(self):
        """Return a hidden state with every entry drawn uniformly from [0, 0.2)."""
        const = 0.2
        var = 0.2
        HPC_hidden = [torch.rand(self.batch_size, self.hidden_size_HPC)*const, torch.rand(self.batch_size, self.hidden_size_HPC)*const]
        PFC_hidden = [torch.rand(self.batch_size, self.hidden_size_PFC)*var, torch.rand(self.batch_size, self.hidden_size_PFC)*var]
        Re_hidden = torch.rand(self.batch_size, self.hidden_size_Re)*var
        return [PFC_hidden,HPC_hidden,Re_hidden]

    def noiseHidden_rand(self, hiddens):
        """Add Gaussian noise in place to PFC_h, HPC_h (std c) and Re_h (std v); return hiddens.

        With c = 0 only Re_h is perturbed, but the PFC/HPC draws still consume
        RNG state.
        """
        c = 0
        v = 0.1
        HPC_hidden = torch.randn(self.batch_size, self.hidden_size_HPC)*c
        hiddens[1][0].add_(HPC_hidden)
        PFC_hidden = torch.randn(self.batch_size, self.hidden_size_PFC)*c
        hiddens[0][0].add_(PFC_hidden)
        Re_hidden = torch.randn(self.batch_size, self.hidden_size_Re)*v
        hiddens[2].add_(Re_hidden)
        return hiddens

    def noiseHidden_select(self, hiddens):
        """Add Gaussian noise in place to units [1, 2, 8, 9, 17] only; return hiddens.

        PFC_h / HPC_h get std c (0 here), Re_h gets std v. Requires hidden
        sizes > 17.
        """
        c = 0
        v = 0.01
        index = np.array([1,2,8,9,17])
        HPC_hidden = torch.zeros(self.batch_size, self.hidden_size_HPC)
        HPC_hidden[:,index] += torch.randn(self.batch_size, index.size)*c
        hiddens[1][0].add_(HPC_hidden)
        PFC_hidden = torch.zeros(self.batch_size, self.hidden_size_PFC)
        PFC_hidden[:,index] += torch.randn(self.batch_size, index.size)*c
        hiddens[0][0].add_(PFC_hidden)
        Re_hidden = torch.zeros(self.batch_size, self.hidden_size_Re)*c
        Re_hidden[:,index] += torch.randn(self.batch_size, index.size)*v
        hiddens[2].add_(Re_hidden)
        return hiddens

    def noiseHidden_dis(self, hiddens, statr, statl):
        """Shift Re_h in place by (statl - statr) * 0.1; return hiddens.

        PFC_h / HPC_h units [1, 2, 8, 9, 17] get noise of std c (0 here).
        statr / statl are assumed broadcastable to (batch_size, hidden_size_Re);
        TODO confirm.
        """
        c = 0
        v =-0.1
        index = np.array([1,2,8,9,17])
        HPC_hidden = torch.zeros(self.batch_size, self.hidden_size_HPC)
        HPC_hidden[:,index] += torch.randn(self.batch_size, index.size)*c
        hiddens[1][0].add_(HPC_hidden)
        PFC_hidden = torch.zeros(self.batch_size, self.hidden_size_PFC)
        PFC_hidden[:,index] += torch.randn(self.batch_size, index.size)*c
        hiddens[0][0].add_(PFC_hidden)
        Re_hidden = torch.randn(self.batch_size, self.hidden_size_Re)*0.0
        Re_hidden += (-statr+statl)*v
        hiddens[2].add_(Re_hidden)
        return hiddens


def moving_average(data, window=5):
    """Return the length-``window`` running mean of ``data`` (np.convolve 'full' mode).

    The result has len(data) + window - 1 samples. The first and last
    window - 1 entries are still divided by ``window`` even though fewer real
    samples overlap them, so they are pulled toward zero.

    Raises:
        ValueError: if ``window`` < 1.
    """
    if window < 1:
        raise ValueError(f"window must be >= 1, got {window!r}")
    kernel = np.ones(window) / window
    return np.convolve(data, kernel)


def main(model):
    """Score how consistently a trained MyLSTM_RNN repeats its PFC dynamics across cycles.

    Steps:
    1. Load the checkpoint at ``model``.
    2. Drive the network with one teacher-forced batch of pattern-1 sequences
       for 600 steps, then run it closed-loop for 10 more steps.
    3. Project the PFC hidden state of batch element 0 onto its principal
       components.
    4. Cut that projection into cycles with ``pick_traj``.

    Args:
        model: path to a ``state_dict`` file for ``MyLSTM_RNN``.

    Returns:
        For each of the first three principal components, the variance across
        cycles at each time step, averaged over time; returns the sum of these
        three values. Lower means the cycles are more alike.

    Raises:
        ValueError: if ``pick_traj`` yields no interior cycles to compare.
    """
    training_size = 100
    hidden_size = 20
    batch_size = 10
    inputsize = 2
    outputsize = 2
    delay_length = 2
    n_components = 3    # leading principal components that are scored

    train_x1, train_x2, train_y1, train_y2 = mkOwnDataSet_auto(training_size, delay_length)

    rnn = MyLSTM_RNN(inputsize, hidden_size, outputsize, batch_size)
    rnn.load_state_dict(torch.load(model))

    # Which of the four training sets drives the network (1, 2, 3, anything else -> y2).
    pattern = 1
    pool = {1: train_x1, 2: train_x2, 3: train_y1}.get(pattern, train_y2)
    data = mkOwnRandomBatch(pool, pool, batch_size).float()

    traj = []
    PFCstate = []
    hidden = rnn.initHidden_rand()
    data_limit = 600    # teacher-forced steps (equals len(data) when delay_length == 2)
    est_length = 0      # extra closed-loop passes over len(data)
    # Inference only: without no_grad every step would extend one long autograd graph.
    with torch.no_grad():
        for k in range(data_limit):
            output, hidden = rnn(data[k], hidden)
            traj.append(output.tolist())
            PFCstate.append(hidden[0][0].tolist())
        # Closed loop: feed the network's own output back as its input.
        for _ in range(data.shape[0] * est_length + 10):
            output, hidden = rnn(output, hidden)
            traj.append(output.tolist())
            PFCstate.append(hidden[0][0].tolist())
    traj = torch.squeeze(torch.tensor(traj)).numpy()

    # PCA of the PFC hidden state of batch element 0.
    dfs = np.array(PFCstate)[:, 0]
    pca = PCA()
    pca.fit(dfs)
    feature = pca.transform(dfs)
    print(pd.DataFrame(pca.explained_variance_ratio_, index=["PC{}".format(x + 1) for x in range(dfs.shape[1])]))

    # pick_traj's first segment starts at t = 0 and its last element is a slice
    # of traj rather than of feature, so only the interior cycles are compared.
    cycles = pick_traj(traj[:, 0], feature)[1:-1]
    if not cycles:
        raise ValueError(f"no complete cycles found for model {model!r}")
    data_len = min(len(c) for c in cycles)
    aligned = np.array([c[:data_len] for c in cycles])     # (n_cycles, data_len, n_pcs)
    per_step_var = np.var(aligned[:, :, :n_components], axis=0)   # (data_len, n_components)
    return per_step_var.mean(axis=0).sum()


if __name__ == '__main__':
    import glob  # used below; this cell's imports do not include it

    fig = plt.figure()
    correlation_fig = fig.subplots()

    # For each seed s1..s10, score every checkpoint matching
    # "*s<seed>_100_<k+3>_*epoch1??.pth" with main() and plot the scores
    # in checkpoint order.
    foraverage = []
    for k in range(1):
        for i in range(10):
            path = 'model/R20_H_bigbatch/'
            # path = 'model/R20_H_uniPFC_bigbatch/'
            glob_pattern = path + '*s' + str(i + 1) + '_100_' + str(k + 3) + '_*epoch1??.pth'
            # Shorter paths first; ties broken lexicographically.
            model_list = sorted(glob.glob(glob_pattern), key=lambda p: (len(p), p))

            var_list = []
            for model in model_list:
                print(model)
                var_list.append(main(model))
            correlation_fig.plot(var_list)
            foraverage.append(var_list)
            print(foraverage)

    # Mean score across seeds at each checkpoint position.
    plt.figure()
    plt.plot(np.mean(foraverage, axis=0))
    # np.save("uniPFCave.npy", np.mean(foraverage, axis=0))
    # np.save("uniPFCvar.npy", np.var(foraverage, axis=0))
        

In [None]:
from scipy.stats import linregress
import scipy.stats as st  # provides f_oneway, used below


def _load_trimmed(ave_file, var_file):
    """Load a mean curve and a variance curve; return (mean, std) minus the first two and last entries."""
    mean = np.load(ave_file)[2:-1]
    std = np.sqrt(np.load(var_file))[2:-1]
    return mean, std


def _plot_trend_comparison(a, a_err, b, b_err, b_color, b_err_color, b_fit_color):
    """Plot two curves with error bars and least-squares lines, then print statistics.

    x is 0, 5, 10, ... (one value per entry of ``a``). Prints the correlation
    matrices of (x, a) and (x, b) and a one-way ANOVA between a and b.
    """
    x = np.arange(0, a.size * 5, 5)
    model_LR_a = linregress(x, a)
    model_LR_b = linregress(x, b)

    plt.figure()
    plt.plot(x, a, "o")
    plt.errorbar(x, a, yerr=a_err, fmt="o", color="b", alpha=0.3)
    plt.plot(x, b, "o", color=b_color)
    plt.errorbar(x, b, yerr=b_err, fmt="o", color=b_err_color, alpha=0.3)
    plt.plot(x, model_LR_a.intercept + model_LR_a.slope * x, color="C0")
    plt.plot(x, model_LR_b.intercept + model_LR_b.slope * x, color=b_fit_color)
    plt.show()
    print(np.corrcoef(x, a), np.corrcoef(x, b))
    print(st.f_oneway(np.array(a), np.array(b)))


# PFC max curves: Re_* files vs uniPFC_* files.
_plot_trend_comparison(*_load_trimmed("Re_PFCmaxave.npy", "Re_PFCmaxvar.npy"),
                       *_load_trimmed("uniPFC_PFCmaxave.npy", "uniPFC_PFCmaxvar.npy"),
                       b_color="green", b_err_color="green", b_fit_color="C2")

# HPC max curves: Re_* files vs UniHPC_* files.
_plot_trend_comparison(*_load_trimmed("Re_HPCmaxave.npy", "Re_HPCmaxvar.npy"),
                       *_load_trimmed("UniHPC_HPCmaxave.npy", "UniHPC_HPCmaxvar.npy"),
                       b_color="C1", b_err_color="orange", b_fit_color="C1")

In [None]:
##########    Eigenvalue and PFC/HPCcorrelation check   #########################
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import time
import math
import matplotlib.pyplot as plt
import matplotlib.animation as animation
import sklearn 
import nolds
from sklearn.decomposition import PCA
from sklearn.cluster import KMeans
from mpl_toolkits.mplot3d import Axes3D
from scipy.special import expit
import pandas as pd
import numpy.linalg as LA

%matplotlib notebook

def Culc_gate(input, params, hiddens):
    """Recompute the LSTM gate activations of the PFC and HPC cells for batch element 0.

    The inputs fed to each cell match MyLSTM_RNN.forward:
    PFC receives [HPC_h, Re_h] and HPC receives [input, Re_h].

    Args:
        input: external HPC input; row 0 is used (assumed shape (batch, input_size)).
        params: parameter tensors in ``rnn.parameters()`` order, i.e.
            [PFC.weight_ih, PFC.weight_hh, PFC.bias_ih, PFC.bias_hh,
             HPC.weight_ih, HPC.weight_hh, HPC.bias_ih, HPC.bias_hh, ...].
        hiddens: ``[[PFC_h, PFC_c], [HPC_h, HPC_c], Re_h]``.

    Returns:
        ``[PFC_gates, HPC_gates]``: Python lists of length 4*hidden_size, laid out
        like torch.nn.LSTMCell as [input gate, forget gate, cell candidate, output gate].
        The three gates are sigmoids; the cell candidate is a tanh.
    """
    pfc_h = hiddens[0][0][0]
    hpc_h = hiddens[1][0][0]
    re_h = hiddens[2][0]

    def _activate(pre):
        # LSTMCell chunk order is (i, f, g, o); only g goes through tanh.
        i, f, g, o = pre.chunk(4)
        return torch.cat([torch.sigmoid(i), torch.sigmoid(f), torch.tanh(g), torch.sigmoid(o)])

    PFC_input = torch.cat([hpc_h, re_h])
    HPC_input = torch.cat([input[0], re_h])
    # The recurrent term uses the hidden state h (index 0), not the cell state c.
    PFC_pre = params[0]@PFC_input + params[1]@pfc_h + params[2] + params[3]
    HPC_pre = params[4]@HPC_input + params[5]@hpc_h + params[6] + params[7]

    return [_activate(PFC_pre).tolist(), _activate(HPC_pre).tolist()]
    
    

def MakeAnimation(data_x,data_y,traj_x,traj_y, data_limit, filename='anim_v4_2_1_r.gif'):
    """Animate a generated trajectory, overlaying the reference data before `data_limit`.

    For frames t < data_limit the reference path (data_x, data_y) up to t is
    drawn in blue alongside the generated path (traj_x, traj_y) in red. From
    data_limit onward only the generated path is drawn. A red dot marks the
    latest generated point. The animation is saved with the pillow writer to
    `filename` and then shown.

    Returns:
        0
    """
    fig, ax = plt.subplots()
    ims = []
    for t in range(traj_x.shape[0]):
        # Index of the latest drawn point. At t == 0 nothing is drawn yet, and
        # t - 1 would wrap to the final point, so clamp to the first point.
        head = max(t - 1, 0)
        if t < data_limit:
            traj = ax.plot(data_x[:t], data_y[:t], "b", traj_x[:t], traj_y[:t], "r", traj_x[head], traj_y[head], "or")
        else:
            traj = ax.plot(traj_x[:t], traj_y[:t], "r", traj_x[head], traj_y[head], "or")
        title = ax.text(0.5, 1.01, 'time={}'.format(t),
                     ha='center', va='bottom',
                     transform=ax.transAxes, fontsize='large')
        ims.append(traj+[title])
    ani = animation.ArtistAnimation(fig, ims, interval=30)
    ani.save(filename, writer='pillow')
    plt.show()

    return 0

def MakeAnimation_data(data_x,data_y, data_limit):
    """Animate the reference path (data_x, data_y) growing one point per frame.

    Axis limits are fixed to x in [0.2, 0.8] and y in [-0.05, 0.8]. The GIF is
    written to 'anim_Re_0.gif' with the pillow writer and then shown.
    `data_limit` is accepted but not used.

    Returns:
        0
    """
    fig, ax = plt.subplots()
    ax.set_xlim(0.2,0.8)
    ax.set_ylim(-0.05,0.8)

    def _frame(step):
        # The path up to (not including) `step`, plus a time caption.
        line = ax.plot(data_x[:step], data_y[:step], "b")
        caption = ax.text(0.5, 1.01, 'time={}'.format(step),
                          ha='center', va='bottom',
                          transform=ax.transAxes, fontsize='large')
        return line + [caption]

    frames = [_frame(step) for step in range(data_x.shape[0])]
    ani = animation.ArtistAnimation(fig, frames, interval=30)
    ani.save('anim_Re_0.gif', writer='pillow')
    plt.show()
    return 0

def MakeAnimation_attracter(data_x,data_y, filename='anim_attracter2.gif'):
    """Animate a 2-D path, drawing it up to each frame with a dot at the latest point.

    The animation is saved with the pillow writer to `filename` and then shown.

    Returns:
        0
    """
    fig, ax = plt.subplots()
    ims = []
    for t in range(data_x.shape[0]):
        # At t == 0, t - 1 would wrap to the final point; clamp to the first one.
        head = max(t - 1, 0)
        traj = ax.plot(data_x[:t], data_y[:t], "b", data_x[head], data_y[head], "ob",)
        title = ax.text(0.5, 1.01, 'time={}'.format(t),
                     ha='center', va='bottom',
                     transform=ax.transAxes, fontsize='large')
        ims.append(traj+[title])
    ani = animation.ArtistAnimation(fig, ims, interval=10)
    ani.save(filename, writer='pillow')
    plt.show()

    return 0


    
def MakeAnimation_img(data,module):
    """Animate a (time, units) activity array as a one-row heat map per frame.

    All frames share one colour scale spanning the whole of `data`, so colours
    can be compared across time steps. The GIF is written to
    'anim_fire_<module>_.gif' with the pillow writer and then shown.

    Args:
        data: array indexed as data[t], with data.shape[1] values per step
            (assumed a 2-D numpy array — TODO confirm for other callers).
        module: label used in the frame title and the output filename.

    Returns:
        0
    """
    fig, ax = plt.subplots()
    # Without explicit limits imshow rescales each frame to its own min/max,
    # which makes brightness incomparable between frames.
    vmin = np.min(data)
    vmax = np.max(data)
    ims = []
    for t in range(data.shape[0]):
        img = ax.imshow(data[t].reshape(1,data.shape[1]), vmin=vmin, vmax=vmax)
        title = ax.text(0.5, 1.01, module+' time={}'.format(t),
             ha='center', va='bottom',
             transform=ax.transAxes, fontsize='large')
        ims.append([img,title])
    ani = animation.ArtistAnimation(fig, ims)
    ani.save('anim_fire_'+module+'_.gif', writer='pillow')
    plt.show()

    return 0

def mkOwnDataSet(data_size, filename, data_length=100, freq=60., noise=0.005):
    """Load a recorded pair of paths and return `data_size` noisy, downsampled copies.

    Reads ``<filename>_r.csv`` and ``<filename>_l.csv`` (comma-separated; the
    first two columns are used). Each file is subsampled with stride
    ``round(rows / data_length)`` (at least 1). Independent N(0, noise) jitter
    is then added to every coordinate of every copy.

    Args:
        data_size: number of noisy copies to generate.
        filename: path prefix of the two CSV files.
        data_length: approximate number of samples to keep per copy.
        freq: unused; kept for call compatibility.
        noise: standard deviation of the added Gaussian jitter.

    Returns:
        (train_x, train_y): two lists of length `data_size`. Each element is a
        list of [x, y] pairs.
    """
    # ndmin=2 keeps a single-row CSV two-dimensional.
    x = np.loadtxt(str(filename) + "_r.csv", delimiter=',', ndmin=2)
    y = np.loadtxt(str(filename) + "_l.csv", delimiter=',', ndmin=2)

    def _subsample(arr):
        # A stride of 0 (fewer rows than data_length / 2) would make the old
        # np.arange call raise, so clamp it to 1.
        step = max(1, round(arr.shape[0] / data_length))
        return arr[::step, :2]

    x_sub = _subsample(x)
    y_sub = _subsample(y)
    train_x = []
    train_y = []
    for _ in range(data_size):
        train_x.append((x_sub + np.random.normal(loc=0.0, scale=noise, size=x_sub.shape)).tolist())
        train_y.append((y_sub + np.random.normal(loc=0.0, scale=noise, size=y_sub.shape)).tolist())
    return train_x, train_y


def mkOwnRandomBatch(train_x, train_t, batch_size=10):
    """Sample a random mini-batch of sequences and return it time-major.

    Args:
        train_x: sequence of samples, each a (seq_len, features) nested list.
        train_t: unused; kept for call compatibility (only the input batch is returned).
        batch_size: number of samples drawn, with replacement.

    Returns:
        torch.Tensor of shape (seq_len, batch_size, features).

    Raises:
        ValueError: if `train_x` is empty.
    """
    if len(train_x) == 0:
        raise ValueError("train_x must contain at least one sample")
    # Sample over the full index range. The old upper bound len(train_x) - batch_size
    # never picked the last batch_size samples and raised when len(train_x) <= batch_size.
    batch_x = [train_x[np.random.randint(0, len(train_x))] for _ in range(batch_size)]
    return torch.tensor(batch_x).transpose(0,1)

def make_Ttraj(direction):
    """Build a synthetic T-shaped run as a (20, 2) array of (x, y) points.

    Four 5-point legs: up from (0.5, 0) to (0.5, 0.5); out along y = 0.5 to
    x = 0.75 (direction 1) or x = 0.25 (direction 2); back to (0.5, 0.5);
    and down to (0.5, 0).

    Raises:
        ValueError: if `direction` is not 1 or 2.
    """
    arm_ends = {1: 0.75, 2: 0.25}
    if direction not in arm_ends:
        raise ValueError("direction must be 1 or 2, got {!r}".format(direction))
    arm = arm_ends[direction]

    xs = np.concatenate([
        np.full(5, 0.5),              # stem, going up
        np.linspace(0.5, arm, 5),     # out to the arm end
        np.linspace(arm, 0.5, 5),     # back to the junction
        np.full(5, 0.5),              # stem, going down
    ])
    ys = np.concatenate([
        np.linspace(0, 0.5, 5),
        np.full(5, 0.5),
        np.full(5, 0.5),
        np.linspace(0.5, 0, 5),
    ])
    return np.stack([xs, ys], axis=1)

def make_Htraj(direction1, direction2):
    """Build a synthetic H-shaped run as a (40, 2) array of (x, y) points.

    Legs, in order:
      * 10 points up from (0.5, 0) to (0.5, 0.5);
      * 5 points along y = 0.5 out to x = 0.75 (direction1 == 1) or x = 0.25 (direction1 == 2);
      * 5 points along that side to y = 0.25 (direction2 == 1) or y = 0.75
        (direction2 == 2), then 5 points back to y = 0.5;
      * 5 points back along y = 0.5 to x = 0.5;
      * 10 points down to (0.5, 0).

    Raises:
        ValueError: if either direction is not 1 or 2.
    """
    side_x = {1: 0.75, 2: 0.25}
    turn_y = {1: 0.25, 2: 0.75}
    if direction1 not in side_x:
        raise ValueError("direction1 must be 1 or 2, got {!r}".format(direction1))
    if direction2 not in turn_y:
        raise ValueError("direction2 must be 1 or 2, got {!r}".format(direction2))
    side = side_x[direction1]
    turn = turn_y[direction2]

    xs = np.concatenate([
        np.full(10, 0.5),             # stem, going up
        np.linspace(0.5, side, 5),    # crossbar, outward
        np.full(5, side),             # side leg, outward
        np.full(5, side),             # side leg, returning
        np.linspace(side, 0.5, 5),    # crossbar, inward
        np.full(10, 0.5),             # stem, going down
    ])
    ys = np.concatenate([
        np.linspace(0, 0.5, 10),
        np.full(5, 0.5),
        np.linspace(0.5, turn, 5),
        np.linspace(turn, 0.5, 5),
        np.full(5, 0.5),
        np.linspace(0.5, 0, 10),
    ])
    return np.stack([xs, ys], axis=1)


def make_traj(delay_length, version="ver1"):
    """Build four training sequences from H-shaped runs separated by "stay" segments.

    Each sequence is a concatenation of (n, 2) segments:
      * task ids 1-4 are make_Htraj(1, 1), (1, 2), (2, 1) and (2, 2), each with
        fresh N(0, 0.005) noise added;
      * task id 0 is a stay segment of 40 rows drawn at random (with
        replacement) from the first two columns of stay1.csv.
    ``delay_length`` stay segments are placed between consecutive tasks.

    Args:
        delay_length: number of stay segments between tasks.
        version: task ordering. "ver1" (default) uses (1,4), (2,3), (3,1), (4,2).
            "ver2" uses (1,2), (2,3), (3,4), (4,1). "ver1_test" uses the
            four-task chains (1,4,2,3), (2,3,1,4), (3,1,4,2), (4,2,3,1).

    Returns:
        Tuple of four float arrays, one per sequence in the chosen ordering.

    Raises:
        ValueError: if `version` is unknown.
    """
    noise = 0.005
    data_length = 40  # rows per stay segment

    orderings = {
        "ver1": [(1, 4), (2, 3), (3, 1), (4, 2)],
        "ver2": [(1, 2), (2, 3), (3, 4), (4, 1)],
        "ver1_test": [(1, 4, 2, 3), (2, 3, 1, 4), (3, 1, 4, 2), (4, 2, 3, 1)],
    }
    if version not in orderings:
        raise ValueError("unknown version {!r}; expected one of {}".format(version, sorted(orderings)))

    templates = {
        1: make_Htraj(1, 1),
        2: make_Htraj(1, 2),
        3: make_Htraj(2, 1),
        4: make_Htraj(2, 2),
    }
    z = np.loadtxt("stay1.csv", delimiter=',', ndmin=2)

    data_list = []
    for tasks in orderings[version]:
        order = [tasks[0]]
        for task in tasks[1:]:
            order += [0] * delay_length + [task]

        segments = []
        for idx in order:
            if idx == 0:
                rows = np.random.choice(z.shape[0], data_length)
                segments.append(z[rows, :2])
            else:
                template = templates[idx]
                # Add noise to a copy. An in-place `+=` would mutate the template,
                # so every later reuse of the same run would carry accumulated noise.
                segments.append(template + np.random.normal(loc=0.0, scale=noise, size=template.shape))
        data_list.append(np.concatenate(segments, axis=0))
    return data_list[0],data_list[1],data_list[2],data_list[3]

def mkOwnDataSet_auto(data_size, delay_length, freq=60., noise=0.01):
    """Return `data_size` noisy copies of each of the four make_traj sequences.

    Each copy adds independent N(0, noise) jitter to the x and y value of every
    step. `freq` is accepted but unused.

    Returns:
        (train_x1, train_x2, train_y1, train_y2): four lists, each holding
        `data_size` copies stored as lists of [x, y] pairs.
    """
    sources = make_traj(delay_length)
    buckets = ([], [], [], [])

    def _jitter(seq):
        # Row by row, the x noise is drawn before the y noise.
        return [[row[0] + np.random.normal(loc=0.0, scale=noise),
                 row[1] + np.random.normal(loc=0.0, scale=noise)]
                for row in seq]

    for _ in range(data_size):
        # Each round fills x1, x2, y1, y2 in that order.
        for seq, bucket in zip(sources, buckets):
            bucket.append(_jitter(seq))

    train_x1, train_x2, train_y1, train_y2 = buckets
    return train_x1, train_x2, train_y1, train_y2

def plot_distance_bet2traj(traj1,traj2,linelist):
    """Plot the per-step Euclidean distance between two trajectories.

    Entry i is np.linalg.norm(traj1[i] - traj2[i]) for every i along traj1's
    first axis. Vertical lines are drawn at the positions in `linelist`,
    spanning the range of the distances.

    Returns:
        The list of per-step distances.
    """
    first = np.array(traj1)
    second = np.array(traj2)
    gaps = [np.linalg.norm(first[step] - second[step]) for step in range(first.shape[0])]

    plt.figure()
    plt.plot(gaps)
    plt.vlines(linelist, np.min(gaps), np.max(gaps))
    return gaps

def distance_bet2traj(traj1,traj2):
    """Return the per-step Euclidean distance between two trajectories.

    Both inputs are converted with np.array. Entry i is
    np.linalg.norm(traj1[i] - traj2[i]) for every i along traj1's first axis,
    so traj2 must be at least as long as traj1.
    """
    first = np.array(traj1)
    second = np.array(traj2)
    return [np.linalg.norm(first[step] - second[step]) for step in range(first.shape[0])]

def plot_activity_bet2traj(traj1,traj2,linelist, threshold=0.5, window=(20, 105)):
    """Plot per-unit absolute differences between two state sequences.

    For every step i, |traj1[i] - traj2[i]| is computed. Each unit (column)
    whose difference exceeds `threshold` anywhere in steps
    ``window[0]:window[1]`` has its index printed and its full difference
    trace plotted.

    Args:
        traj1, traj2: state sequences, indexed by step along the first axis.
        linelist: unused; kept for call compatibility.
        threshold: minimum absolute difference for a unit to be shown.
        window: (start, stop) step range in which the threshold is tested.

    Returns:
        List of per-step absolute-difference arrays.
    """
    first = np.array(traj1)
    second = np.array(traj2)
    result = [np.abs(first[i] - second[i]) for i in range(first.shape[0])]
    start, stop = window

    plt.figure()
    for i, data in enumerate(np.array(result).T):
        if np.any(data[start:stop] > threshold):
            print(i)
            plt.plot(data)

    return result

def search_delay(traj):
    """Return the steps where traj's y-coordinate crosses 0.45.

    Scanning traj[:, 1], an index is recorded when the value first rises
    strictly above 0.45 and again when it next falls strictly below 0.45, and
    so on alternately. A value exactly equal to 0.45 triggers neither.

    Returns:
        float64 numpy array of crossing indices (empty if none).
    """
    edges = []
    above = False
    for step, height in enumerate(traj[:, 1]):
        if not above and height > 0.45:
            edges.append(step)
            above = True
        if above and height < 0.45:
            edges.append(step)
            above = False
    return np.array(edges, dtype=float)

def pick_delay(traj,states):
    """Cut `states` into the stretches where traj's y-coordinate is below 0.2.

    Scanning traj[:, 1] from step 5 on (steps 0-4 are ignored), a segment
    opens at the first step whose value is below the threshold and closes at
    the next step whose value is above it. states[open:close] is collected for
    every closed segment. Finally, everything after the last crossing is appended.

    Returns:
        List of slices of `states`.
    """
    threshold = 0.2
    crossings = []
    below = False
    for i in range(5, traj.shape[0]):
        y = traj[i, 1]
        if y < threshold and not below:
            crossings.append(i)
            below = True
        if y > threshold and below:
            crossings.append(i)
            below = False

    states_list = []
    start = 0
    for n, k in enumerate(crossings):
        if n % 2 != 0:
            # Odd crossings close the segment opened at the previous crossing.
            states_list.append(states[start:k])
        start = k
    # The tail is sliced from `states` like every other segment; slicing `traj`
    # here mixed trajectory rows into a list of state segments.
    states_list.append(states[start:])
    return states_list

def pick_traj(traj,states):
    """Cut `states` into segments that end where traj's y-coordinate rises back above 0.2.

    Scanning traj[:, 1] from step 5 on (steps 0-4 are ignored), crossings
    alternate between falling below the threshold and rising above it. Each
    collected segment runs from the previous rising crossing (or step 0) to
    the next rising crossing. Finally, everything after the last rising
    crossing is appended.

    Returns:
        List of slices of `states`.
    """
    threshold = 0.2
    crossings = []
    below = False
    for i in range(5, traj.shape[0]):
        y = traj[i, 1]
        if y < threshold and not below:
            crossings.append(i)
            below = True
        if y > threshold and below:
            crossings.append(i)
            below = False

    states_list = []
    start = 0
    for n, k in enumerate(crossings):
        if n % 2 != 0:
            # Odd crossings are rising ones; they end the current segment.
            states_list.append(states[start:k])
            start = k
    # The tail is sliced from `states` like every other segment; slicing `traj`
    # here mixed trajectory rows into a list of state segments.
    states_list.append(states[start:])
    return states_list

def match_length(states_list):
    """Trim every sequence in `states_list` to the length of the shortest one.

    The last ``min_length`` items of each sequence are kept, so the sequences
    are aligned at their ends.

    Returns:
        A new list of trimmed sequences, or ``[]`` for empty input.
    """
    if len(states_list) == 0:
        return []
    # Compute the true minimum; the old seed value of 1000 silently capped
    # every sequence at 1000 items.
    min_length = min(len(state) for state in states_list)
    # Slice from len - min_length: with min_length == 0, `state[-0:]` would
    # return the whole sequence instead of an empty one.
    return [state[len(state) - min_length:] for state in states_list]

def vec_var(datas):
    """Return the mean Euclidean distance of the samples in `datas` from their centroid.

    Despite the name this is not a variance: distances are not squared. Each
    entry along the first axis is compared with the mean over that axis, using
    np.linalg.norm of the difference (the 2-norm of the flattened difference).
    """
    samples = np.array(datas)
    centroid = np.average(samples, axis=0)
    total = sum(np.linalg.norm(sample - centroid) for sample in samples)
    return total / samples.shape[0]
    
def lyapunov_exp(data):
    """Return the mean of log|x[t+1] - x[t]| over consecutive samples of `data`.

    NOTE(review): this is the average log-magnitude of first differences, not a
    standard Lyapunov-exponent estimator. Equal neighbouring samples give
    log(0) = -inf.
    """
    steps = np.abs(np.diff(data))
    return np.log(steps).mean()

class MyLSTM(nn.Module):
    """Three LSTMCell modules (PFC, HPC, Re) with a linear readout of the HPC state.

    Hidden-state layout: ``[[PFC_h, PFC_c], [HPC_h, HPC_c], [Re_h, Re_c]]``.
    """
    def __init__(self, input_size, hidden_size, output_size, batch_size):
        super(MyLSTM, self).__init__()

        self.hidden_size = hidden_size
        # Statement order here fixes the default random initialisation and the
        # order of self.parameters(); keep it unchanged.
        self.PFC = nn.LSTMCell(hidden_size*2, hidden_size)   # input: [HPC_h, Re_h]
        self.HPC = nn.LSTMCell(input_size+hidden_size, hidden_size)   # input: [external, Re_h]
        self.Re = nn.LSTMCell(hidden_size*2, hidden_size)   # input: [PFC_h, HPC_h]
        self.linear = nn.Linear(hidden_size, output_size)
        self.batch_size = batch_size
        # NOTE(review): per the torch.nn.init.sparse_ docs, the second argument is
        # the fraction of each column set to zero, so 1 zeroes these matrices
        # entirely. MyLSTM_RNN uses 0.1*int(sparse) instead — confirm 1 is intended.
        nn.init.sparse_(self.PFC.weight_ih.data,1)
        nn.init.sparse_(self.HPC.weight_ih.data,1)
        nn.init.sparse_(self.PFC.weight_hh.data,1)
        nn.init.sparse_(self.HPC.weight_hh.data,1)

    def forward(self, input, hiddens):
        """Advance all three cells one step and return (output, new_hiddens).

        Every cell reads only the previous step's states in `hiddens`, so the
        three updates are simultaneous. The output is linear(new HPC_h).
        """
        Re_input = torch.cat([hiddens[0][0],hiddens[1][0]],dim=1)
        Re_hidden = self.Re(Re_input, hiddens[2])
        HPC_input = torch.cat([input,hiddens[2][0]],dim=1)
        HPC_hidden = self.HPC(HPC_input, hiddens[1])
        PFC_input = torch.cat([hiddens[1][0],hiddens[2][0]],dim=1)
        PFC_hidden = self.PFC(PFC_input, hiddens[0])
        output = self.linear(HPC_hidden[0])
        return output, [PFC_hidden, HPC_hidden, Re_hidden]

    def initHidden(self):
        """Return all-zero states ``[[PFC_h, PFC_c], [HPC_h, HPC_c], [Re_h, Re_c]]``."""
        HPC_hidden = [torch.zeros(self.batch_size, self.hidden_size), torch.zeros(self.batch_size, self.hidden_size)]
        PFC_hidden = [torch.zeros(self.batch_size, self.hidden_size), torch.zeros(self.batch_size, self.hidden_size)]
        Re_hidden = [torch.zeros(self.batch_size, self.hidden_size), torch.zeros(self.batch_size, self.hidden_size)]
        return [PFC_hidden,HPC_hidden,Re_hidden]

class MyLSTM_RNN(nn.Module):
    """PFC and HPC LSTMCells coupled through an Re RNNCell, with a linear HPC readout.

    At each step every module reads the previous step's states:
      * Re  <- [PFC_h, HPC_h]
      * HPC <- [external input, Re_h]
      * PFC <- [HPC_h, Re_h]

    Hidden-state layout: ``[[PFC_h, PFC_c], [HPC_h, HPC_c], Re_h]``.
    """

    def __init__(self, input_size, hidden_size, output_size, batch_size, sparse=5):
        super(MyLSTM_RNN, self).__init__()

        # Per-module state widths (currently all equal to hidden_size).
        self.hidden_size_PFC = hidden_size
        self.hidden_size_HPC = hidden_size
        self.hidden_size_Re = hidden_size
        # Submodules are created in this fixed order. It determines the default
        # random initialisation and the order of self.parameters().
        self.PFC = nn.LSTMCell(self.hidden_size_HPC+self.hidden_size_Re, self.hidden_size_PFC)
        self.HPC = nn.LSTMCell(input_size+self.hidden_size_Re, self.hidden_size_HPC)
        self.Re = nn.RNNCell(self.hidden_size_HPC+self.hidden_size_PFC, self.hidden_size_Re)
        self.linear = nn.Linear(self.hidden_size_HPC, output_size)
        self.batch_size = batch_size
        # `sparse` is given in tenths: nn.init.sparse_ zeroes this fraction of
        # each column of the PFC/HPC input and recurrent weight matrices.
        sparse = 0.1*int(sparse)
        nn.init.sparse_(self.PFC.weight_ih.data,sparse)
        nn.init.sparse_(self.HPC.weight_ih.data,sparse)
        nn.init.sparse_(self.PFC.weight_hh.data,sparse)
        nn.init.sparse_(self.HPC.weight_hh.data,sparse)

    def forward(self, input, hiddens):
        """Advance all modules one step and return (output, new_hiddens).

        The output is linear(new HPC_h). `input` is cast to float32 first.
        """
        input = input.float()
        Re_input = torch.cat([hiddens[0][0],hiddens[1][0]],dim=1)
        Re_hidden = self.Re(Re_input, hiddens[2])
        HPC_input = torch.cat([input,hiddens[2]],dim=1)
        HPC_hidden = self.HPC(HPC_input, hiddens[1])
        PFC_input = torch.cat([hiddens[1][0],hiddens[2]],dim=1)
        PFC_hidden = self.PFC(PFC_input, hiddens[0])
        output = self.linear(HPC_hidden[0])
        return output, [PFC_hidden, HPC_hidden, Re_hidden]

    def initHidden(self):
        """Return all-zero states ``[[PFC_h, PFC_c], [HPC_h, HPC_c], Re_h]``."""
        # Uses the per-module widths; this class has no `self.hidden_size`, so
        # referring to it raised AttributeError.
        HPC_hidden = [torch.zeros(self.batch_size, self.hidden_size_HPC), torch.zeros(self.batch_size, self.hidden_size_HPC)]
        PFC_hidden = [torch.zeros(self.batch_size, self.hidden_size_PFC), torch.zeros(self.batch_size, self.hidden_size_PFC)]
        Re_hidden = torch.zeros(self.batch_size, self.hidden_size_Re)
        return [PFC_hidden,HPC_hidden,Re_hidden]

    def initHidden_rand(self):
        """Return uniform random states in [0, 0.1), drawn in the order HPC, PFC, Re."""
        const = 0.1
        var = 0.1
        HPC_hidden = [torch.rand(self.batch_size, self.hidden_size_HPC)*const, torch.rand(self.batch_size, self.hidden_size_HPC)*const]
        PFC_hidden = [torch.rand(self.batch_size, self.hidden_size_PFC)*var, torch.rand(self.batch_size, self.hidden_size_PFC)*var]
        Re_hidden = torch.rand(self.batch_size, self.hidden_size_Re)*const
        return [PFC_hidden,HPC_hidden,Re_hidden]

    def initHidden_test(self):
        """Return uniform random states in [0, 0.2), drawn in the order HPC, PFC, Re."""
        const = 0.2
        var = 0.2
        HPC_hidden = [torch.rand(self.batch_size, self.hidden_size_HPC)*const, torch.rand(self.batch_size, self.hidden_size_HPC)*const]
#         HPC_hidden = [torch.ones(self.batch_size, self.hidden_size_HPC)*const, torch.ones(self.batch_size, self.hidden_size_HPC)*const]
        PFC_hidden = [torch.rand(self.batch_size, self.hidden_size_PFC)*var, torch.rand(self.batch_size, self.hidden_size_PFC)*var]
#         PFC_hidden = [torch.ones(self.batch_size, self.hidden_size_PFC)*const, torch.ones(self.batch_size, self.hidden_size_PFC)*const]
#         Re_hidden = torch.ones(self.batch_size, self.hidden_size_Re)*const
        Re_hidden = torch.rand(self.batch_size, self.hidden_size_Re)*var
        return [PFC_hidden,HPC_hidden,Re_hidden]

    def noiseHidden_rand(self, hiddens):
        """Add Gaussian noise to the hidden states in place and return `hiddens`.

        PFC_h and HPC_h get std `c` (0, i.e. no change); Re_h gets std `v` = 0.1.
        """
        c = 0
        v = 0.1
        HPC_hidden = torch.randn(self.batch_size, self.hidden_size_HPC)*c
        hiddens[1][0].add_(HPC_hidden)
        PFC_hidden = torch.randn(self.batch_size, self.hidden_size_PFC)*c
        hiddens[0][0].add_(PFC_hidden)
        Re_hidden = torch.randn(self.batch_size, self.hidden_size_Re)*v
        hiddens[2].add_(Re_hidden)
        return hiddens

    def noiseHidden_select(self, hiddens):
        """Add Gaussian noise in place to units [1, 2, 8, 9, 17] only, and return `hiddens`.

        PFC_h and HPC_h get std `c` (0); Re_h gets std `v` = 0.01. Requires
        every state width to be at least 18.
        """
        c = 0
        v = 0.01
        index = np.array([1,2,8,9,17])
        HPC_hidden = torch.zeros(self.batch_size, self.hidden_size_HPC)
        HPC_hidden[:,index] += torch.randn(self.batch_size, index.size)*c
        hiddens[1][0].add_(HPC_hidden)
        PFC_hidden = torch.zeros(self.batch_size, self.hidden_size_PFC)
        PFC_hidden[:,index] += torch.randn(self.batch_size, index.size)*c
        hiddens[0][0].add_(PFC_hidden)
        Re_hidden = torch.zeros(self.batch_size, self.hidden_size_Re)*c
        Re_hidden[:,index] += torch.randn(self.batch_size, index.size)*v
        hiddens[2].add_(Re_hidden)
        return hiddens

    def noiseHidden_dis(self, hiddens, statr, statl):
        """Shift Re_h in place by ``(statl - statr) * v`` (v = -0.1) and return `hiddens`.

        `statr` and `statl` may be numpy arrays or tensors broadcastable to
        (batch_size, hidden_size_Re). PFC_h/HPC_h noise on units
        [1, 2, 8, 9, 17] uses std `c` (0).
        """
        c = 0
        v =-0.1
        index = np.array([1,2,8,9,17])
        HPC_hidden = torch.zeros(self.batch_size, self.hidden_size_HPC)
        HPC_hidden[:,index] += torch.randn(self.batch_size, index.size)*c
        hiddens[1][0].add_(HPC_hidden)
        PFC_hidden = torch.zeros(self.batch_size, self.hidden_size_PFC)
        PFC_hidden[:,index] += torch.randn(self.batch_size, index.size)*c
        hiddens[0][0].add_(PFC_hidden)
        Re_hidden = torch.randn(self.batch_size, self.hidden_size_Re)*0.0
        # Convert explicitly so numpy inputs combine with the tensor at its dtype.
        Re_hidden += torch.as_tensor((-statr+statl)*v, dtype=Re_hidden.dtype)
        hiddens[2].add_(Re_hidden)
        return hiddens


def moving_average(data, window=5, mode="full"):
    """Return the `window`-point running mean of a 1-D sequence.

    Convolves `data` with a flat kernel of `window` weights 1/window. With the
    default mode="full" the result has len(data) + window - 1 points. Its first
    and last window - 1 values average over fewer than `window` samples (the
    missing ones count as zero), so they are biased toward 0. Use
    mode="valid" to keep only fully covered windows, or "same" to keep the
    input length.

    Raises:
        ValueError: if `window` is less than 1.
    """
    if window < 1:
        raise ValueError("window must be >= 1, got {!r}".format(window))
    kernel = np.ones(window)/window
    return np.convolve(data, kernel, mode=mode)


def main(model):
    training_size = 100
    test_size = 1000
    epochs_num = 10
    hidden_size = 20
    batch_size = 10
    data_length = 100
    inputsize = 2
    outputsize = 2
    delay_length = 3
    
    model_path = model
    
    colors = ['C0','C1','C2','C3','C4','C5','C6','C7','C8','C9']
    
#     delay_length = np.random.randint(1, 3)
    delay_length = 2
    train_x1,train_x2,train_y1,train_y2 = mkOwnDataSet_auto(training_size,delay_length)

#     rnn = MyLSTM_comp(inputsize, hidden_size, outputsize, batch_size)
    rnn = MyLSTM_RNN(inputsize, hidden_size, outputsize, batch_size)
#     rnn = MyLSTM_RNN_uniPFC(inputsize, hidden_size, outputsize, batch_size)
#     rnn = MyMTRNN2(inputsize, hidden_size, outputsize, batch_size)

    rnn.load_state_dict(torch.load(model_path))
    
    
    pattern = 1
    if  pattern == 1:
        data = mkOwnRandomBatch(train_x1, train_x1, batch_size).float()
    elif pattern == 2:
        data = mkOwnRandomBatch(train_x2, train_x2, batch_size).float()
    elif pattern == 3:
        data = mkOwnRandomBatch(train_y1, train_y1, batch_size).float()
    else:
        data = mkOwnRandomBatch(train_y2, train_y2, batch_size).float()
    
    params = []
    for p in rnn.parameters():
        params.append(p.data)
    
        
#     model_path = 'model/R20_131/ReModel_L2_interRNNrand_OUT1_131_s6_100_1_epoch125.pth'
    model_path = model
    rnn.load_state_dict(torch.load(model_path))
        
    for n, p in rnn.named_parameters():
            if n == "PFC.weight_ih":
                PFC_w_pre = torch.clone(p.data)
            if n == "HPC.weight_ih":
                HPC_w_pre = torch.clone(p.data)
            if n == "Re.weight_ih":
                Re_w_pre = torch.clone(p.data)
            if n == "Re.weight_hh":
                Re_inw = torch.clone(p.data)
            if n == "linear.weight":
                OUT_w_pre = torch.clone(p.data)
                   
    traj = []
    PFCstate = []
    HPCstate = []
    Restate = []
    Gate_states = []
    hidden = rnn.initHidden_rand()
    data_limit = 120
    est_length = 0
    init_point = torch.rand(10,2)*1
    for k in range(data_limit):
            #print(data[k].shape)
            output,hidden = rnn(data[k],hidden)
#             output,hidden = rnn(init_point,hidden)
            traj.append(output.tolist())
            PFCstate.append(hidden[0][0].tolist())
            HPCstate.append(hidden[1][0].tolist())
            Restate.append(hidden[2][0].tolist())
#             Gate_states.append(Culc_gate(output,params,hidden))
#     for k in range(data.shape[0]*est_length):
#             output,hidden = rnn(output,hidden)
#             traj.append(output.tolist())
#             PFCstate.append(hidden[0][0].tolist())
#             HPCstate.append(hidden[1][0].tolist())
#             Restate.append(hidden[2][0].tolist())
#             Gate_states.append(Culc_gate(output,params,hidden))
    traj = torch.tensor(traj)
    traj = torch.squeeze(traj).numpy()
    
    pattern = 3
    if  pattern == 1:
        data = mkOwnRandomBatch(train_x1, train_x1, batch_size).float()
    elif pattern == 2:
        data = mkOwnRandomBatch(train_x2, train_x2, batch_size).float()
    elif pattern == 3:
        data = mkOwnRandomBatch(train_y1, train_y1, batch_size).float()
    else:
        data = mkOwnRandomBatch(train_y2, train_y2, batch_size).float()
    
    dividenum = int(np.array(PFCstate)[0:,0].shape[0]/2)
    traj_noise = []
    PFCstate_noise = []
    HPCstate_noise = []
    Restate_noise = []
    Gate_states_noise = []
    hidden = rnn.initHidden_rand()
#     data = mkOwnRandomBatch(train_y, batch_size)
#     init_point = torch.rand(10,2)*1
#     data_limit = 125
    for k in range(data_limit):
            #print(data[k].shape)
            output,hidden = rnn(data[k],hidden)
#             output,hidden = rnn(init_point,hidden)
            traj_noise.append(output.tolist())
            PFCstate_noise.append(hidden[0][0].tolist())
            HPCstate_noise.append(hidden[1][0].tolist())
            Restate_noise.append(hidden[2][0].tolist())
#             Gate_states_noise.append(Culc_gate(output,params,hidden))
#     for k in range(data.shape[0]*est_length):
# #             hidden = rnn.noiseHidden_rand(hidden)
#             output,hidden = rnn(output,hidden) 
# #             hidden = rnn.noiseHidden_dis(hidden, np.array(Restate)[k+data_limit], np.array(Restate)[dividenum+k+data_limit])
# #             hidden = rnn.noiseHidden_rand(hidden)
#             traj_noise.append(output.tolist())
#             PFCstate_noise.append(hidden[0][0].tolist())
#             HPCstate_noise.append(hidden[1][0].tolist())
#             Restate_noise.append(hidden[2][0].tolist())
#             Gate_states_noise.append(Culc_gate(output,params,hidden))
    traj_noise = torch.tensor(traj_noise)
    traj_noise = torch.squeeze(traj_noise).numpy()

    
    print(np.array(PFCstate)[:,0].shape)
#     MakeAnimation(pltdata[:,0,0],pltdata[:,0,1], traj[:,0,0], traj[:,0,1], data_limit)
    #MakeAnimation_img(np.array(PFCstate)[:,0],"PFC")
    #MakeAnimation_img(np.array(HPCstate)[:,0],"HPC")


#     plt.figure()
#     plt.plot(traj[:,0,1],color=colors[-1])
#     plt.plot(traj_noise[:,0,1],color=colors[-2])


#     PFC_corrlist = np.array([])
#     HPC_corrlist = np.array([])
#     cross_corrlist = np.array([])
#     for i in range(20):
#         plt.figure()
# #         plt.ylim(-1,1)
#         plt.plot(traj[:,0,0])
#         plt.plot(traj_noise[:,0,0])
# #         plt.plot(np.abs(traj[:,0,0]-traj_noise[:,0,0]))
#     #         plt.plot(np.array(PFCstate_noise)[:400,0,17],alpha=0.5)
#         for k in range(1):
# #                 plt.plot(np.array(PFCstate)[:,k,i],alpha=0.5)
# #                 plt.plot(np.array(PFCstate_noise)[:,k,i],alpha=0.5)
#                 plt.plot(np.abs(np.array(PFCstate)[:,k,i]-np.array(PFCstate_noise)[:,k,i]),alpha=0.5)
# #                 plt.plot(np.array(HPCstate)[:,k,i],alpha=0.5)
# #                 plt.plot(np.array(HPCstate_noise)[:,k,i],alpha=0.5)
#                 plt.plot(np.abs(np.array(HPCstate)[:,k,i]-np.array(HPCstate_noise)[:,k,i]),alpha=0.5)
#     #             plt.plot(np.array(PFCstate)[:120,k,i],alpha=0.5)
#     #             plt.plot(np.array(PFCstate)[120:240,k,i],alpha=0.5)
#     #             plt.plot(np.array(PFCstate)[:120,k,i]-np.array(PFCstate)[120:240,k,i],alpha=0.5)
#     #             plt.plot(np.array(PFCstate_noise)[:120,k,i],alpha=0.5)
#     #             plt.plot(np.array(PFCstate_noise)[120:240,k,i],alpha=0.5)
#     #             plt.plot(np.array(PFCstate_noise)[:120,k,i]-np.array(PFCstate_noise)[120:240,k,i],alpha=0.5)
# #             plt.plot(np.array(Restate)[:,i],alpha=0.5)
# #             plt.plot(np.array(Restate_noise)[:,i],alpha=0.5)
# #             plt.plot(np.array(Restate)[:,i]-np.array(Restate_noise)[:,i],alpha=0.5)
# #                 plt.plot(moving_average(np.array(Restate)[:,i]),alpha=0.5)
# #                 plt.plot(moving_average(np.array(Restate_noise)[:,i]),alpha=0.5)
# #                 plt.plot(moving_average(np.array(Restate)[:,i])-moving_average(np.array(Restate_noise)[:,i]),alpha=0.5)
# #                 plt.plot(np.array(Gate_states)[:,0,0+i],alpha=0.5)
# #                 plt.plot(np.array(Gate_states_noise)[:,1,0+i],alpha=0.5)
# #                 plt.plot(np.array(Gate_states_noise)[:,1,20+i],alpha=0.5)
# #                 plt.plot(np.log(np.array(Gate_states)[:,1,0+i]/np.array(Gate_states)[:,1,20+i]),alpha=0.5)
# #                 plt.plot(np.log(np.array(Gate_states_noise)[:,1,0+i]/np.array(Gate_states_noise)[:,1,20+i]),alpha=0.5)
# #                 print(np.corrcoef(np.abs(traj[:,0,0]-traj_noise[:,0,0]),np.abs(np.array(PFCstate)[:,k,i]-np.array(PFCstate_noise)[:,k,i]))[0,1])
# #                 print(np.corrcoef(np.abs(traj[:,0,0]-traj_noise[:,0,0]),np.abs(np.array(HPCstate)[:,k,i]-np.array(HPCstate_noise)[:,k,i]))[0,1])
# #                 print(np.corrcoef(np.abs(np.array(PFCstate)[:,k,i]-np.array(PFCstate_noise)[:,k,i]),np.abs(np.array(HPCstate)[:,k,i]-np.array(HPCstate_noise)[:,k,i])))
#                 cross_corrlist = np.append(cross_corrlist,np.corrcoef(np.abs(np.array(PFCstate)[:,k,i]-np.array(PFCstate_noise)[:,k,i]),np.abs(np.array(HPCstate)[:,k,i]-np.array(HPCstate_noise)[:,k,i])))
#                 PFC_corrlist = np.append(PFC_corrlist,np.corrcoef(np.abs(traj[:,0,0]-traj_noise[:,0,0]),np.abs(np.array(PFCstate)[:,k,i]-np.array(PFCstate_noise)[:,k,i]))[0,1])
#                 HPC_corrlist = np.append(HPC_corrlist,np.corrcoef(np.abs(traj[:,0,0]-traj_noise[:,0,0]),np.abs(np.array(HPCstate)[:,k,i]-np.array(HPCstate_noise)[:,k,i]))[0,1])
# #         plt.title("neuron#"+str(i+1))
#     print(np.mean(PFC_corrlist),np.mean(HPC_corrlist),np.mean(cross_corrlist))
    
    pca = PCA()
#     dfs = np.array(PFCstate)[:,0]
    dfs = np.concatenate((np.array(PFCstate)[:,0],np.array(PFCstate_noise)[:,0]),axis=0)
#     dfs = np.concatenate((np.array(HPCstate)[:,0],np.array(HPCstate_noise)[:,0]),axis=0)
    print(dfs.shape)
    pca.fit(dfs)
    feature = pca.transform(dfs)
    print("PFC correlation")
    print(np.corrcoef(np.abs(traj[:,0,0]-traj_noise[:,0,0]),np.abs(feature[:data_limit,0]-feature[data_limit:,0])))
    print(np.corrcoef(np.abs(traj[:,0,0]-traj_noise[:,0,0]),np.abs(feature[:data_limit,1]-feature[data_limit:,1])))
    print(np.corrcoef(np.abs(traj[:,0,0]-traj_noise[:,0,0]),np.abs(feature[:data_limit,2]-feature[data_limit:,2])))
    PFC_a = np.corrcoef(np.abs(traj[:,0,0]-traj_noise[:,0,0]),np.abs(feature[:data_limit,0]-feature[data_limit:,0]))[0,1]
    PFC_b = np.corrcoef(np.abs(traj[:,0,0]-traj_noise[:,0,0]),np.abs(feature[:data_limit,1]-feature[data_limit:,1]))[0,1]
    PFC_c = np.corrcoef(np.abs(traj[:,0,0]-traj_noise[:,0,0]),np.abs(feature[:data_limit,2]-feature[data_limit:,2]))[0,1]
#     PFC_a = np.corrcoef(np.abs(traj[:,0,1]-traj_noise[:,0,1]),np.abs(feature[:data_limit,0]-feature[data_limit:,0]))[0,1]
#     PFC_b = np.corrcoef(np.abs(traj[:,0,1]-traj_noise[:,0,1]),np.abs(feature[:data_limit,1]-feature[data_limit:,1]))[0,1]
#     PFC_c = np.corrcoef(np.abs(traj[:,0,1]-traj_noise[:,0,1]),np.abs(feature[:data_limit,2]-feature[data_limit:,2]))[0,1]
#     PFC_a = np.corrcoef(np.abs(traj[:,0,1]-traj_noise[:,0,1]),np.abs(feature[:data_limit,0]))[0,1]
#     PFC_b = np.corrcoef(np.abs(traj[:,0,1]-traj_noise[:,0,1]),np.abs(feature[:data_limit,1]))[0,1]
#     PFC_c = np.corrcoef(np.abs(traj[:,0,1]-traj_noise[:,0,1]),np.abs(feature[:data_limit,2]))[0,1]
    
#     plt.figure()
#     plt.plot(np.abs(traj[:,0,0]-traj_noise[:,0,0]))
#     plt.plot(np.abs(feature[:data_limit,0]-feature[data_limit:,0]))
#     plt.plot(np.abs(feature[:data_limit,1]-feature[data_limit:,1]))
#     plt.plot(np.abs(feature[:data_limit,2]-feature[data_limit:,2]))
#     plt.show()

#     plt.figure()
#     plt.plot(np.abs(feature[:100,0]),color=colors[-3])
#     plt.plot(np.abs(feature[100:,0]),color=colors[-4])
#     plt.show()

#     print(pd.DataFrame(feature, columns=["PC{}".format(x + 1) for x in range(dfs.shape[1])]).head())
    print(pd.DataFrame(pca.explained_variance_ratio_, index=["PC{}".format(x + 1) for x in range(dfs.shape[1])]))

    
    pca = PCA()
#     dfs = np.array(PFCstate)[:,0]
#     dfs = np.concatenate((np.array(PFCstate)[:,0],np.array(PFCstate_noise)[:,0]),axis=0)
    dfs = np.concatenate((np.array(HPCstate)[:,0],np.array(HPCstate_noise)[:,0]),axis=0)
    print(dfs.shape)
    pca.fit(dfs)
    feature = pca.transform(dfs)
    print("HPC correlation")
#     print(np.corrcoef(np.abs(traj[:,0,0]-traj_noise[:,0,0]),np.abs(feature[:100,0]-feature[100:,0])))
#     print(np.corrcoef(np.abs(traj[:,0,0]-traj_noise[:,0,0]),np.abs(feature[:100,1]-feature[100:,1])))
#     print(np.corrcoef(np.abs(traj[:,0,0]-traj_noise[:,0,0]),np.abs(feature[:100,2]-feature[100:,2])))
    HPC_a = np.corrcoef(np.abs(traj[:,0,0]-traj_noise[:,0,0]),np.abs(feature[:data_limit,0]-feature[data_limit:,0]))[0,1]
    HPC_b = np.corrcoef(np.abs(traj[:,0,0]-traj_noise[:,0,0]),np.abs(feature[:data_limit,1]-feature[data_limit:,1]))[0,1]
    HPC_c = np.corrcoef(np.abs(traj[:,0,0]-traj_noise[:,0,0]),np.abs(feature[:data_limit,2]-feature[data_limit:,2]))[0,1]
#     HPC_a = np.corrcoef(np.abs(traj[:,0,1]-traj_noise[:,0,1]),np.abs(feature[:data_limit,0]-feature[data_limit:,0]))[0,1]
#     HPC_b = np.corrcoef(np.abs(traj[:,0,1]-traj_noise[:,0,1]),np.abs(feature[:data_limit,1]-feature[data_limit:,1]))[0,1]
#     HPC_c = np.corrcoef(np.abs(traj[:,0,1]-traj_noise[:,0,1]),np.abs(feature[:data_limit,2]-feature[data_limit:,2]))[0,1]
    
#     plt.figure()
#     plt.plot(np.abs(traj[:,0,0]-traj_noise[:,0,0]))
#     plt.plot(np.abs(feature[:data_limit,0]-feature[data_limit:,0]))
#     plt.plot(np.abs(feature[:data_limit,1]-feature[data_limit:,1]))
#     plt.plot(np.abs(feature[:data_limit,2]-feature[data_limit:,2]))
#     plt.show()
    
#     plt.figure()
#     plt.plot(np.abs(feature[:100,0]),color=colors[-3])
#     plt.plot(np.abs(feature[100:,0]),color=colors[-4])
#     plt.show()

#     print(pd.DataFrame(feature, columns=["PC{}".format(x + 1) for x in range(dfs.shape[1])]).head())
#     print(pd.DataFrame(pca.explained_variance_ratio_, index=["PC{}".format(x + 1) for x in range(dfs.shape[1])]))


    pca = PCA()
#     dfs = np.array(PFCstate)[:,0]
#     dfs = np.concatenate((np.array(PFCstate)[:,0],np.array(PFCstate_noise)[:,0]),axis=0)
    dfs = np.concatenate((np.array(Restate)[:],np.array(Restate_noise)[:]),axis=0)
    print(dfs.shape)
    pca.fit(dfs)
    feature = pca.transform(dfs)
    print("Re correlation")
#     print(np.corrcoef(np.abs(traj[:,0,0]-traj_noise[:,0,0]),np.abs(feature[:100,0]-feature[100:,0])))
#     print(np.corrcoef(np.abs(traj[:,0,0]-traj_noise[:,0,0]),np.abs(feature[:100,1]-feature[100:,1])))
#     print(np.corrcoef(np.abs(traj[:,0,0]-traj_noise[:,0,0]),np.abs(feature[:100,2]-feature[100:,2])))
#     Re_a = np.corrcoef(np.abs(traj[:,0,0]-traj_noise[:,0,0]),np.abs(feature[:100,0]-feature[100:,0]))[0,1]
#     Re_b = np.corrcoef(np.abs(traj[:,0,0]-traj_noise[:,0,0]),np.abs(feature[:100,1]-feature[100:,1]))[0,1]
#     Re_c = np.corrcoef(np.abs(traj[:,0,0]-traj_noise[:,0,0]),np.abs(feature[:100,2]-feature[100:,2]))[0,1]
    
#     plt.figure()
# #     plt.plot(np.abs(feature[:100,0]-feature[100:,0]))
# #     plt.plot(np.abs(feature[:100,1]-feature[100:,1]))
# #     plt.plot(np.abs(feature[:100,2]-feature[100:,2]))
# #     plt.plot(np.abs(moving_average(feature[:100,0])-moving_average(feature[100:,0])))
# #     plt.plot(np.abs(moving_average(feature[:100,1])-moving_average(feature[100:,1])))
# #     plt.plot(np.abs(moving_average(feature[:100,2])-moving_average(feature[100:,2])))
#     plt.plot(np.abs(feature[:100,0]),color=colors[-3])
#     plt.plot(np.abs(feature[100:,0]),color=colors[-4])
#     plt.show()


    for n, p in rnn.named_parameters():
            if n == "PFC.weight_ih":
                PFC_w = np.array(p.data)
            if n == "PFC.weight_hh":
                PFC_inw = np.array(p.data)
            if n == "HPC.weight_ih":
                HPC_w = np.array(p.data)
            if n == "HPC.weight_hh":
                HPC_inw = np.array(p.data)
            if n == "Re.weight_ih":
                Re_w = np.array(p.data)
            if n == "Re.weight_hh":
                Re_inw = np.array(p.data)
            if n == "linear.weight":
                Out_w = np.array(p.data)
    
# #     PFC_w_hist = np.histogram(np.abs(PFC_w[PFC_w>0]),bins=80,density=True)
# #     PFC_inw_hist = np.histogram(np.abs(PFC_inw[PFC_inw>0]),bins=40,density=True)
# #     fig1.plot((make_bins(PFC_w_hist[1])),np.log(PFC_w_hist[0]+0.01),color=color)
# #     fig1.plot((make_bins(PFC_inw_hist[1])),np.log(PFC_inw_hist[0]+0.01),color=color)
# #     PFC_w_hist = np.histogram(np.abs(PFC_w[PFC_w<0]),bins=80,density=True)
# #     PFC_inw_hist = np.histogram(np.abs(PFC_inw[PFC_inw<0]),bins=40,density=True)
#     PFC_w_hist = np.histogram(np.abs(PFC_w[PFC_w.nonzero()]),bins=80,density=True)
#     PFC_inw_hist = np.histogram(np.abs(PFC_inw[PFC_inw.nonzero()]),bins=40,density=True)
# #     PFC_hist = np.histogram(np.abs(np.append(PFC_w[PFC_w.nonzero()],PFC_inw[PFC_inw.nonzero()])),bins=40,density=True)
# #     PFC_b_hist = np.histogram(np.abs(PFC_b[PFC_b.nonzero()]),bins=80,density=True)
# #     PFC_inb_hist = np.histogram(np.abs(PFC_inb[PFC_inb.nonzero()]),bins=40,density=True)
# #     fig1.plot(np.log(make_bins(PFC_w_hist[1])),np.log(PFC_w_hist[0]+0.01),color=color)
#     fig1.plot(np.log(make_bins(PFC_inw_hist[1])),np.log(PFC_inw_hist[0]+0.01),color=color)
# #     fig1.plot((make_bins(PFC_hist[1])),np.log(PFC_hist[0]))
#     a1 = regression(make_bins(PFC_w_hist[1]),np.log(PFC_w_hist[0]+0.01))
#     a2 = regression(make_bins(PFC_inw_hist[1]),np.log(PFC_inw_hist[0]+0.01))


# #     HPC_w_hist = np.histogram(HPC_w[HPC_w>0],bins=80,density=True)
# #     HPC_inw_hist = np.histogram(HPC_inw[HPC_inw>0],bins=40,density=True)
# #     fig2.plot((make_bins(HPC_w_hist[1])),np.log(HPC_w_hist[0]))
# #     fig2.plot((make_bins(HPC_inw_hist[1])),np.log(HPC_inw_hist[0]))
# #     HPC_w_hist = np.histogram(np.abs(HPC_w[HPC_w<0]),bins=80,density=True)
# #     HPC_inw_hist = np.histogram(np.abs(HPC_inw[HPC_inw<0]),bins=40,density=True)
#     HPC_w_hist = np.histogram(np.abs(HPC_w[HPC_w.nonzero()]),bins=80,density=True)
#     HPC_inw_hist = np.histogram(np.abs(HPC_inw[HPC_inw.nonzero()]),bins=40,density=True)
# #     HPC_w_hist = np.histogram(np.abs(HPC_w),bins=80,density=True)
# #     HPC_inw_hist = np.histogram(np.abs(HPC_inw),bins=40,density=True)
# #     fig2.plot((make_bins(HPC_w_hist[1])),np.log(HPC_w_hist[0]))
# #     fig2.plot((make_bins(HPC_inw_hist[1])),np.log(HPC_inw_hist[0]))
# #     fig2.plot(np.log(make_bins(HPC_w_hist[1])),np.log(HPC_w_hist[0]+0.01),color=color)
#     fig2.plot(np.log(make_bins(HPC_inw_hist[1])),np.log(HPC_inw_hist[0]+0.01),color=color)
#     b1 = regression(make_bins(HPC_w_hist[1]),np.log(HPC_w_hist[0]+0.01))
#     b2 = regression(make_bins(HPC_inw_hist[1]),np.log(HPC_inw_hist[0]+0.01))


# #     Re_w_hist = np.histogram(Re_w[Re_w>0],bins=100,density=True)
# #     Re_inw_hist = np.histogram(Re_inw[Re_inw>0],bins=100,density=True)
# #     Re_w_hist = np.histogram(np.abs(Re_w[Re_w<0]),bins=100,density=True)
# #     Re_inw_hist = np.histogram(np.abs(Re_inw[Re_inw<0]),bins=20,density=True)
#     Re_w_hist = np.histogram(np.abs(Re_w[Re_w.nonzero()]),bins=40,density=True)
#     Re_inw_hist = np.histogram(np.abs(Re_inw[Re_inw.nonzero()]),bins=40,density=True)
# #     Re_inw_hist = np.histogram(np.abs(Re_inw),bins=40,density=True)
#     Re_initw_hist = np.histogram(np.abs(Re_initw),bins=40,density=True)
#     fig3.plot(np.log(make_bins(Re_w_hist[1])),np.log(Re_w_hist[0]+0.01),color=color)
# #     fig3.plot(np.log(make_bins(Re_inw_hist[1])),np.log(Re_inw_hist[0]+0.01),color=color)
# #     fig3.plot(np.log(make_bins(Re_initw_hist[1])),np.log(Re_initw_hist[0]+0.01))
#     c1 = regression(make_bins(Re_w_hist[1]),np.log(Re_w_hist[0]+0.01))
#     c2 = regression(make_bins(Re_inw_hist[1]),np.log(Re_inw_hist[0]+0.01))
    
#     Re_uw,Re_uv = LA.eig(Re_w[:,20:])
#     fig_eigenval.plot(Re_inuw.real,Re_inuw.imag,"o",color=color)
#     print(Re_inuw)
#     print(np.count_nonzero(Re_uw.real>0))
#     print(np.count_nonzero(Re_uw.imag==0))
#     print(np.count_nonzero((Re_uw.imag==0)*(Re_uw.real>0)))


    i = 0
    PFC_inuw,PFC_inuv = LA.eig(PFC_inw[i:i+20])
    
    i = 0
    HPC_inuw,HPC_inuv = LA.eig(HPC_inw[i:i+20])
    
    HPC_uw,HPC_uv = LA.eig(np.pad(HPC_w[i:i+20,2:],[(0,0),(0,0)]))

#     HPC_allw = HPC_w[0:20,2:]*HPC_w[40:60,2:]+2*HPC_inw[0:20]*HPC_w[40:60,2:]+HPC_w[0:20,2:]*HPC_inw[40:60]
    HPC_allw = HPC_w[0:20,2:]*HPC_w[40:60,2:]+HPC_inw[0:20]*HPC_w[40:60,2:]+HPC_w[0:20,2:]*HPC_inw[40:60]+HPC_inw[0:20]*HPC_inw[40:60]
#     HPC_allw = HPC_w[0:20,2:]*HPC_w[40:60,2:]
#     HPC_allw = HPC_inw[0:20]
#     PFC_allw = PFC_w[0:20,:20]*PFC_w[40:60,:20]+2*PFC_inw[0:20]*PFC_w[40:60,:20]+PFC_w[0:20,:20]*PFC_inw[40:60]
    PFC_allw = PFC_w[0:20,:20]*PFC_w[40:60,:20]+PFC_inw[0:20]*PFC_w[40:60,:20]+PFC_w[0:20,:20]*PFC_inw[40:60]+PFC_inw[0:20]*PFC_inw[40:60]
#     PFC_allw2 = PFC_w[0:20,20:]*PFC_w[40:60,20:]+PFC_inw[0:20]*PFC_w[40:60,20:]+PFC_w[0:20,20:]*PFC_inw[40:60]+PFC_inw[0:20]*PFC_inw[40:60]
#     PFC_allw = PFC_w[0:20,:20]*PFC_w[40:60,:20]
#     PFC_allw2 = PFC_w[0:20,20:]*PFC_w[40:60,20:]
#     Re_allw  = Re_w[:,0:20] + Re_w[:,20:] + Re_inw
    Re_allw  = Re_w[:,0:20] + Re_inw
#     Re_allw2  = Re_w[:,20:] + Re_inw
#     Re_allw  = Re_w[:,0:20]
#     Re_allw2  = Re_w[:,20:]
#     all_uw,all_uv = LA.eig(HPC_allw)
#     all_uw,all_uv = LA.eig(HPC_inw[0:20]*HPC_inw[40:60])
    all_uw,all_uv = LA.eig(PFC_inw[0:20]*PFC_inw[40:60])
#     all_uw,all_uv = LA.eig(HPC_allw)
        
#     if model_path in test_list_B:
# #         fig_eigenval.plot(np.average(abs(all_uw)),np.max(abs(all_uw)),"o",color="b",alpha=0.3,zorder=2)
#         PCA_list_B = np.concatenate((PCA_list_B,[[np.count_nonzero((all_uw.real>0)), np.count_nonzero((all_uw.imag==0)*(all_uw.real>0)), np.count_nonzero((all_uw.imag==0)*(all_uw.real<0)), np.count_nonzero((all_uw.imag!=0)*(all_uw.real>0)),np.count_nonzero((all_uw.imag!=0)*(all_uw.real<0))]]),axis=0)
#     else:
#         PCA_list_A = np.concatenate((PCA_list_A,[[np.count_nonzero((all_uw.real>0)), np.count_nonzero((all_uw.imag==0)*(all_uw.real>0)), np.count_nonzero((all_uw.imag==0)*(all_uw.real<0)), np.count_nonzero((all_uw.imag!=0)*(all_uw.real>0)),np.count_nonzero((all_uw.imag!=0)*(all_uw.real<0))]]),axis=0)
       
    
    return np.mean([PFC_a,PFC_b,PFC_c]),np.mean([HPC_a,HPC_b,HPC_c]),np.min([PFC_a,PFC_b,PFC_c]),np.max([HPC_a,HPC_b,HPC_c]),np.mean(abs(all_uw))


if __name__ == '__main__':
    # Script entry: evaluate every saved checkpoint, then compare the
    # per-checkpoint PFC statistic between "good" and "failed" models.
    fig = plt.figure()
    correlation_fig = fig.subplots()
#     correlation_fig.set_xlim(-1,1)
#     correlation_fig.set_ylim(0,1)

#     model_list = glob.glob('model/R20_131/*OUT1**s8_100_2_*epoch*.pth')
#     model_list = sorted(model_list)
#     model_list = sorted(model_list,key=len,reverse=False)
#     k=0
#     for model in model_list:
#         print(model,PFC,HPC)
#         PFC,HPC = main(model)
# #         correlation_fig.plot(PFC,HPC,"o")
# #         correlation_fig.text(PFC,HPC,model.split("epoch")[-1].split(".")[0])
#         correlation_fig.plot(k*5,PFC/HPC,"o")
# #         correlation_fig.text(k*5,PFC/HPC,model.split("epoch")[-1].split(".")[0])
#         k+=1


    # Accumulators across all (num, i) runs.
    total_ratio_list = np.array([])
    total_ratio_list_max = np.array([])
    total_eigen_list = np.array([])
    allratio_list = []
    good_points = []   # statistic for checkpoints listed in good_list.txt
    failed_points = [] # statistic for every other checkpoint
    # num selects the "_<num+1>_" part of the filename and i the "s<i+1>"
    # part (presumably run id and seed - confirm against the training script).
    for num in range(3):
        for i in range(10):
            path = 'model/R20_H_bigbatch/'
#             path = 'model/R20_H_uniPFC_bigbatch/'
            model_list = glob.glob(path+'*s'+str(i+1)+'_100_'+str(num+1)+'_*epoch*.pth')
            # Lexicographic sort followed by a stable sort on length, so names
            # that differ only in the epoch number end up in numeric order.
            model_list = sorted(model_list)
            model_list = sorted(model_list,key=len,reverse=False)
            ratio_list = []
            ratio_list_max = []
            eigen_list = []
            # One checkpoint path per line; membership decides good vs failed.
            with open(path+"good_list.txt", mode="r") as f:
                good_list = f.read().splitlines()
    #         good_list = []
            first_goodmodel = [0,0]   # [epoch, statistic] of the first good checkpoint
            good_flag = False
            k=0
            for model in model_list:
                print(model)
    #             PFC,HPC = main(model)
                # NOTE(review): `main` is not visible here. If it is the function
                # ending just above, its third return value is np.min of the PFC
                # correlations, so `PFC_max` actually holds a minimum - confirm.
                PFC,HPC,PFC_max,HPC_max,eigen = main(model)
    #             ratio_list.append(PFC/HPC)
    #             ratio_list.append(np.abs(PFC-HPC))
    #             ratio_list.append(np.abs(PFC))
                ratio_list.append(np.abs(HPC_max))
    #             ratio_list_max.append(PFC_max/HPC_max)
    #             ratio_list_max.append(np.abs(PFC_max-HPC_max))
                # The statistic that is plotted and tested below.
                ratio_list_max.append(np.abs(PFC_max))
                eigen_list.append(eigen)
    #             correlation_fig.plot(PFC,HPC,"o")
    #             correlation_fig.text(PFC,HPC,model.split("epoch")[-1].split(".")[0])
    #             correlation_fig.plot(k*5,PFC/HPC,"o")
        #         correlation_fig.text(k*5,PFC/HPC,model.split("epoch")[-1].split(".")[0])
                print(PFC_max,HPC_max,eigen)
                if good_flag != True and model in good_list:
                    good_flag = True
                    # Epoch number parsed from "...epoch<N>.pth".
                    first_goodmodel[0] = int(model.split("epoch")[-1].split(".")[0])
    #                 first_goodmodel[1] = ratio_list[-1]
                    first_goodmodel[1] = ratio_list_max[-1]

                if model in good_list:
    #                 correlation_fig.plot(np.array(eigen_list[-1]),np.array(ratio_list_max[-1]),"o",color="b")
#                     correlation_fig.plot(np.array(ratio_list[-1]),np.array(ratio_list_max[-1]),"o",color="b")
    #                 correlation_fig.text(np.array(eigen_list[-1]),np.array(ratio_list_max[-1]),model.split("epoch")[-1].split(".")[0])
#                     correlation_fig.plot(int(model.split("epoch")[-1].split(".")[0]),np.array(ratio_list_max[-1]),"o",color="b",alpha=0.3)
                    # Good checkpoints: blue dots in column x=0.
                    correlation_fig.plot(0,np.array(ratio_list_max[-1]),"o",color="b",alpha=0.3)
                    good_points.append(np.array(ratio_list_max[-1]))


                else:
    #                 correlation_fig.plot(np.array(eigen_list[-1]),np.array(ratio_list_max[-1]),"o",color="r")
#                     correlation_fig.plot(np.array(ratio_list[-1]),np.array(ratio_list_max[-1]),"o",color="r")
    #                 correlation_fig.text(np.array(eigen_list[-1]),np.array(ratio_list_max[-1]),model.split("epoch")[-1].split(".")[0])
#                     correlation_fig.plot(int(model.split("epoch")[-1].split(".")[0]),np.array(ratio_list_max[-1]),"o",color="r",alpha=0.3)
                    # Failed checkpoints: red dots in column x=1.
                    correlation_fig.plot(1,np.array(ratio_list_max[-1]),"o",color="r",alpha=0.3)
                    failed_points.append(np.array(ratio_list_max[-1]))


                k+=1
    #         correlation_fig.plot(np.arange(0,200,5),np.array(ratio_list)-np.mean(ratio_list),"o")
    #         ratio_list = np.array(ratio_list).clip(-2,2)
    #         ratio_list = moving_average(ratio_list)[2:-2]
    #         correlation_fig.plot(np.arange(0,len(ratio_list)*5,5),np.array(ratio_list))

    #         ratio_list_max = np.array(ratio_list_max).clip(-2,2)
    #         ratio_list_max = moving_average(ratio_list_max)[4:-2]
    #         correlation_fig.plot(np.arange(0,len(ratio_list_max)*5,5),np.array(ratio_list_max),color="C{}".format(i))
    #         correlation_fig.plot(np.array(eigen_list),np.array(ratio_list_max),"o",color="C{}".format(i))
    #         if good_flag == True:
    #             correlation_fig.plot(first_goodmodel[0],first_goodmodel[1],"o",color="C{}".format(i))
#             print(np.corrcoef(np.array(ratio_list),np.array(ratio_list_max)))
            total_ratio_list = np.append(total_ratio_list,np.array(ratio_list))
            total_ratio_list_max = np.append(total_ratio_list_max,np.array(ratio_list_max))
            total_eigen_list = np.append(total_eigen_list,np.array(eigen_list))
            allratio_list.append(ratio_list_max)
#     np.save("UniPFC_PFCminave.npy",np.mean(np.array(allratio_list),axis=0))
#     np.save("UniPFC_PFCminvar.npy",np.var(np.array(allratio_list),axis=0))
#     print(np.corrcoef(np.array(total_eigen_list),np.array(total_ratio_list_max)))
    # Good vs failed: one-way ANOVA, then a t-test with equal_var=False
    # (`st` is presumably scipy.stats - its import is outside this view).
    print(scipy.stats.f_oneway(good_points, failed_points))
    print(st.ttest_ind(good_points, failed_points,equal_var=False))

        
        

In [None]:
##########    Eigenvalue vs. PFC correlation; interesting activity    #########################
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import time
import math
import matplotlib.pyplot as plt
import matplotlib.animation as animation
import sklearn 
import nolds
from sklearn.decomposition import PCA
from sklearn.cluster import KMeans
from mpl_toolkits.mplot3d import Axes3D
from scipy.special import expit
import pandas as pd
import numpy.linalg as LA

%matplotlib notebook

def Culc_gate(input, params, hiddens):
    """Recompute the PFC and HPC gate activations for batch element 0.

    The input to each cell is assembled exactly as in ``MyLSTM.forward``:
    HPC receives ``[input, Re h]`` and PFC receives ``[HPC h, Re h]``.

    Args:
        input: external input for this step, 2-D with the batch on axis 0
            (it is concatenated with the Re hidden state along ``dim=1``).
        params: parameter tensors where ``params[0:4]`` are the PFC cell's
            ``weight_ih, weight_hh, bias_ih, bias_hh`` and ``params[4:8]`` are
            the HPC cell's (presumably taken from ``model.parameters()`` of a
            model that registers PFC before HPC - confirm against callers).
        hiddens: ``[PFC, HPC, Re]`` states from the previous step, each an
            ``(h, c)`` pair as returned by ``nn.LSTMCell``.

    Returns:
        ``[PFC_gates, HPC_gates]`` as Python lists. Each list is ``sigmoid``
        applied to all ``4 * hidden_size`` gate pre-activations, in PyTorch
        order i, f, g, o. ``nn.LSTMCell`` itself applies ``tanh``, not
        ``sigmoid``, to the g block.
    """
    PFC_h = hiddens[0][0]
    HPC_h = hiddens[1][0]
    Re_h = hiddens[2][0]
    # Same inputs MyLSTM.forward feeds each cell (the Re state enters as h,
    # not as the whole (h, c) pair).
    HPC_input = torch.cat([input, Re_h], dim=1)
    PFC_input = torch.cat([HPC_h, Re_h], dim=1)
    # LSTMCell pre-activations: W_ih @ x + b_ih + W_hh @ h + b_hh, where h is
    # the cell's previous hidden state (index 0 of the pair), not its cell state.
    PFC_gates = params[0] @ PFC_input[0] + params[1] @ PFC_h[0] + params[2] + params[3]
    HPC_gates = params[4] @ HPC_input[0] + params[5] @ HPC_h[0] + params[6] + params[7]
    PFC_gates = torch.sigmoid(PFC_gates)
    HPC_gates = torch.sigmoid(HPC_gates)

    return [PFC_gates.tolist(), HPC_gates.tolist()]
    
    

def MakeAnimation(data_x,data_y,traj_x,traj_y, data_limit, filename='anim_v4_2_1_r.gif', interval=30):
    """Animate a generated trajectory, overlaying the target only while t < data_limit.

    Frame ``t`` draws the generated path up to ``t`` in red, with a marker on
    its latest point. For ``t < data_limit`` the target path up to ``t`` is
    also drawn in blue. The animation is saved with the pillow writer.

    Args:
        data_x, data_y: target trajectory coordinates.
        traj_x, traj_y: generated trajectory coordinates; one frame is made
            per entry of ``traj_x``.
        data_limit: first frame from which the target is no longer drawn.
        filename: output GIF path.
        interval: delay between frames in milliseconds.

    Returns:
        0 (kept for compatibility with existing callers).
    """
    fig, ax = plt.subplots()
    ims = []
    for t in range(traj_x.shape[0]):
        # Index of the most recently drawn point. It is clamped so frame 0
        # does not wrap around to traj_x[-1] (the end of the trajectory).
        head = max(t - 1, 0)
        if t < data_limit:
            traj = ax.plot(data_x[:t], data_y[:t], "b", traj_x[:t], traj_y[:t], "r", traj_x[head], traj_y[head], "or")
        else:
            traj = ax.plot(traj_x[:t], traj_y[:t], "r", traj_x[head], traj_y[head], "or")
        title = ax.text(0.5, 1.01, 'time={}'.format(t),
                     ha='center', va='bottom',
                     transform=ax.transAxes, fontsize='large')
        ims.append(traj+[title])
    ani = animation.ArtistAnimation(fig, ims, interval=interval)
    ani.save(filename, writer='pillow')
    plt.show()

    return 0

def MakeAnimation_data(data_x,data_y, data_limit):
    """Animate the target trajectory on its own and save it as 'anim_Re_0.gif'.

    Frame ``t`` draws the path up to (not including) point ``t`` in blue on
    fixed axes x in [0.2, 0.8], y in [-0.05, 0.8]. ``data_limit`` is accepted
    for signature compatibility and is not used.

    Returns:
        0.
    """
    fig, ax = plt.subplots()
    ax.set_xlim(0.2, 0.8)
    ax.set_ylim(-0.05, 0.8)
    frames = []
    for step in range(data_x.shape[0]):
        lines = ax.plot(data_x[:step], data_y[:step], "b")
        caption = ax.text(
            0.5, 1.01, 'time={}'.format(step),
            ha='center', va='bottom',
            transform=ax.transAxes, fontsize='large',
        )
        frames.append(lines + [caption])
    ani = animation.ArtistAnimation(fig, frames, interval=30)
    ani.save('anim_Re_0.gif', writer='pillow')
    plt.show()
    return 0

def MakeAnimation_attracter(data_x,data_y, filename='anim_attracter2.gif', interval=10):
    """Animate a 2-D trajectory with a marker on its latest point and save it as a GIF.

    Args:
        data_x, data_y: trajectory coordinates; one frame per entry of ``data_x``.
        filename: output GIF path (pillow writer).
        interval: delay between frames in milliseconds.

    Returns:
        0 (kept for compatibility with existing callers).
    """
    fig, ax = plt.subplots()
    ims = []
    for t in range(data_x.shape[0]):
        # Clamp so frame 0 marks the first point rather than wrapping to data_x[-1].
        head = max(t - 1, 0)
        traj = ax.plot(data_x[:t], data_y[:t], "b", data_x[head], data_y[head], "ob")
        title = ax.text(0.5, 1.01, 'time={}'.format(t),
                     ha='center', va='bottom',
                     transform=ax.transAxes, fontsize='large')
        ims.append(traj+[title])
    ani = animation.ArtistAnimation(fig, ims, interval=interval)
    ani.save(filename, writer='pillow')
    plt.show()

    return 0


    
def MakeAnimation_img(data,module):
    """Animate the rows of ``data`` as one-row images and save them as a GIF.

    Frame ``t`` shows ``data[t]`` reshaped to ``(1, data.shape[1])``, titled
    ``'<module> time=<t>'``. The result is written to
    ``'anim_fire_' + module + '_.gif'`` with the pillow writer.

    Returns:
        0.
    """
    fig, ax = plt.subplots()
    width = data.shape[1]
    frames = []
    for step in range(data.shape[0]):
        picture = ax.imshow(data[step].reshape(1, width))
        caption = ax.text(
            0.5, 1.01, module + ' time={}'.format(step),
            ha='center', va='bottom',
            transform=ax.transAxes, fontsize='large',
        )
        frames.append([picture, caption])
    ani = animation.ArtistAnimation(fig, frames)
    ani.save('anim_fire_' + module + '_.gif', writer='pillow')
    plt.show()
    return 0

def mkOwnDataSet(data_size, filename, data_length=100, freq=60., noise=0.005):
    """Build noisy, downsampled copies of a recorded right/left trajectory pair.

    Loads ``<filename>_r.csv`` and ``<filename>_l.csv`` (comma-separated),
    keeps every ``round(rows / data_length)``-th row (at least every row), and
    adds independent Gaussian noise to the first two columns of each kept row.

    Args:
        data_size: number of noisy copies to generate.
        filename: path prefix of the two CSV files.
        data_length: approximate number of samples to keep per trajectory.
        freq: unused; kept for backward compatibility.
        noise: standard deviation of the added Gaussian noise.

    Returns:
        ``(train_x, train_y)``: lists of ``data_size`` samples, each a list of
        ``[x, y]`` pairs.
    """
    x = np.loadtxt(str(filename+"_r.csv"),delimiter=',')
    y = np.loadtxt(str(filename+"_l.csv"),delimiter=',')
    # round() yields 0 when a file has fewer than about 2*data_length rows,
    # which would make np.arange fail; in that case keep every row instead.
    x_idx = np.arange(0, x.shape[0], max(1, round(x.shape[0] / data_length)))
    y_idx = np.arange(0, y.shape[0], max(1, round(y.shape[0] / data_length)))
    train_x = []
    train_y = []

    for _ in range(data_size):
        train_x.append([[x[i][0] + np.random.normal(loc=0.0, scale=noise), x[i][1] + np.random.normal(loc=0.0, scale=noise)] for i in x_idx])
        train_y.append([[y[i][0] + np.random.normal(loc=0.0, scale=noise), y[i][1] + np.random.normal(loc=0.0, scale=noise)] for i in y_idx])
    return train_x, train_y


def mkOwnRandomBatch(train_x, train_t, batch_size=10):
    """Draw ``batch_size`` samples from ``train_x`` uniformly, with replacement.

    Args:
        train_x: sequence of samples, each a nested list that ``torch.tensor``
            can stack (presumably ``(time, features)`` - confirm with callers).
        train_t: unused; kept for backward compatibility.
        batch_size: number of samples to draw.

    Returns:
        The stacked samples as a tensor with axes 0 and 1 swapped, i.e. the
        batch axis moved to position 1.

    Raises:
        ValueError: if ``train_x`` is empty.
    """
    if len(train_x) == 0:
        raise ValueError("train_x is empty")
    # Every sample is eligible. The old upper bound `len - batch_size` never
    # drew the last items and failed whenever len(train_x) <= batch_size.
    indices = np.random.randint(0, len(train_x), size=batch_size)
    batch_x = [train_x[idx] for idx in indices]
    return torch.tensor(batch_x).transpose(0,1)

def make_Ttraj(direction):
    """Return a T-shaped 20-point path as an array of shape (20, 2).

    Legs, 5 points each:
      1. up the stem from (0.5, 0) to (0.5, 0.5);
      2. along y = 0.5 to x = 0.75 (``direction == 1``) or x = 0.25 (``== 2``);
      3. back along y = 0.5 to x = 0.5;
      4. down the stem to (0.5, 0).

    Any other ``direction`` raises ``UnboundLocalError``.
    """
    if direction == 1:
        arm_end = 0.75
    elif direction == 2:
        arm_end = 0.25

    xs = np.concatenate((
        np.ones(5) * 0.5,
        np.linspace(0.5, arm_end, 5),
        np.linspace(arm_end, 0.5, 5),
        np.ones(5) * 0.5,
    ))
    ys = np.concatenate((
        np.linspace(0, 0.5, 5),
        np.linspace(0.5, 0.5, 5),
        np.linspace(0.5, 0.5, 5),
        np.linspace(0.5, 0, 5),
    ))
    return np.stack((xs, ys), axis=1)

def make_Htraj(direction1, direction2):
    """Return a 40-point path with an excursion along a side arm, shape (40, 2).

    Legs (point counts in parentheses):
      * up the stem from (0.5, 0) to (0.5, 0.5)                          (10)
      * along y = 0.5 to the arm at x = 0.75 (direction1 == 1)
        or x = 0.25 (direction1 == 2)                                     (5)
      * along the arm to y = 0.25 (direction2 == 1) or 0.75 (== 2)        (5)
      * back along the arm to y = 0.5                                     (5)
      * back along y = 0.5 to x = 0.5                                     (5)
      * down the stem to (0.5, 0)                                         (10)

    Any other direction value raises ``UnboundLocalError``.
    """
    if direction1 == 1:
        arm_x = 0.75
    elif direction1 == 2:
        arm_x = 0.25
    if direction2 == 1:
        turn_y = 0.25
    elif direction2 == 2:
        turn_y = 0.75

    xs = np.concatenate((
        np.ones(10) * 0.5,
        np.linspace(0.5, arm_x, 5),
        np.linspace(arm_x, arm_x, 5),
        np.linspace(arm_x, arm_x, 5),
        np.linspace(arm_x, 0.5, 5),
        np.ones(10) * 0.5,
    ))
    ys = np.concatenate((
        np.linspace(0, 0.5, 10),
        np.linspace(0.5, 0.5, 5),
        np.linspace(0.5, turn_y, 5),
        np.linspace(turn_y, 0.5, 5),
        np.linspace(0.5, 0.5, 5),
        np.linspace(0.5, 0, 10),
    ))
    return np.stack((xs, ys), axis=1)


def make_traj(delay_length, noise=0.005, data_length=40):
    """Build the four cue sequences (trajectory, delay, trajectory) used for training.

    Segment ids:
      * 1 = make_Htraj(1, 1), 2 = make_Htraj(1, 2),
        3 = make_Htraj(2, 1), 4 = make_Htraj(2, 2). Each gets fresh Gaussian
        noise of std ``noise`` every time it is used.
      * 0 = a "stay" segment of ``data_length`` rows sampled with replacement
        from the first two columns of ``stay1.csv``. No noise is added.

    Each sequence is its cue ids, with ``delay_length`` stay segments
    inserted between consecutive cues.

    Args:
        delay_length: number of stay segments between consecutive cues.
        noise: std of the Gaussian noise added to trajectory segments.
        data_length: number of rows in each stay segment.

    Returns:
        Four float arrays with 2 columns, one per cue order.

    Raises:
        ValueError: if an order contains an unknown segment id.
    """
    trajectories = {
        1: make_Htraj(1, 1),
        2: make_Htraj(1, 2),
        3: make_Htraj(2, 1),
        4: make_Htraj(2, 2),
    }
    z = np.loadtxt("stay1.csv",delimiter=',')

    ###  ver1  ###  (first cue, second cue)
    cue_orders = [(1, 4), (2, 3), (3, 1), (4, 2)]
    # Alternatives used in earlier experiments:
    #   ver2:        [(1, 2), (2, 3), (3, 4), (4, 1)]
    #   ver1's test: [(1, 4, 2, 3), (2, 3, 1, 4), (3, 1, 4, 2), (4, 2, 3, 1)]

    data_list = []
    for cues in cue_orders:
        order = [cues[0]]
        for cue in cues[1:]:
            order += [0] * delay_length + [cue]

        segments = []
        for idx in order:
            if idx == 0:
                picks = np.random.choice(z.shape[0], data_length)
                segments.append(z[picks, :2])
            elif idx in trajectories:
                base = trajectories[idx]
                # Build a new array. The old in-place `+=` mutated the shared
                # trajectory, so each reuse of a cue accumulated extra noise.
                segments.append(base + np.random.normal(loc=0.0, scale=noise, size=base.shape))
            else:
                raise ValueError("unknown segment id: {}".format(idx))
        data_list.append(np.concatenate(segments, axis=0))
    return data_list[0],data_list[1],data_list[2],data_list[3]

def mkOwnDataSet_auto(data_size, delay_length, freq=60., noise=0.01):
    """Generate ``data_size`` jittered copies of each of the four make_traj sequences.

    Args:
        data_size: number of noisy copies per sequence.
        delay_length: forwarded to ``make_traj``.
        freq: unused.
        noise: std of the independent Gaussian jitter added to each coordinate.

    Returns:
        ``(train_x1, train_x2, train_y1, train_y2)``: lists of ``data_size``
        samples, each a list of ``[x, y]`` pairs.
    """
    bases = make_traj(delay_length)
    outputs = ([], [], [], [])
    for _ in range(data_size):
        # Noise is drawn sequence by sequence, row by row, x before y.
        for base, out in zip(bases, outputs):
            out.append([
                [row[0] + np.random.normal(loc=0.0, scale=noise),
                 row[1] + np.random.normal(loc=0.0, scale=noise)]
                for row in base
            ])
    return outputs[0], outputs[1], outputs[2], outputs[3]

def plot_distance_bet2traj(traj1,traj2,linelist):
    """Plot the per-step Euclidean distance between two aligned trajectories.

    Vertical lines are drawn at the x positions in ``linelist``, spanning the
    range of the distances. Steps are taken from ``traj1``.

    Returns:
        The list of per-step distances.
    """
    first = np.array(traj1)
    second = np.array(traj2)
    distances = [np.linalg.norm(first[k] - second[k]) for k in range(first.shape[0])]

    plt.figure()
    plt.plot(distances)
    plt.vlines(linelist, np.min(distances), np.max(distances))
    return distances

def distance_bet2traj(traj1,traj2):
    """Return the per-step Euclidean distance between two aligned trajectories.

    The number of steps is taken from ``traj1``; ``traj2`` must have at least
    that many.
    """
    first = np.array(traj1)
    second = np.array(traj2)
    return [np.linalg.norm(first[k] - second[k]) for k in range(first.shape[0])]

def plot_activity_bet2traj(traj1,traj2,linelist):
    """Plot the per-dimension |traj1 - traj2| series that exceed 0.5 during steps 20-104.

    For 2-D inputs (presumably (steps, units) - confirm), each column of the
    stacked differences is tested. The index of every plotted column is
    printed. ``linelist`` is accepted and unused.

    Returns:
        The list of per-step absolute-difference arrays.
    """
    first = np.array(traj1)
    second = np.array(traj2)
    threshold = 0.5
    diffs = [np.abs(first[k] - second[k]) for k in range(first.shape[0])]

    plt.figure()
    for unit, series in enumerate(np.array(diffs).T):
        if not np.any(series[20:105] > threshold):
            continue
        print(unit)
        plt.plot(series)

    return diffs

def search_delay(traj):
    """Return the steps where ``traj[:, 1]`` crosses the 0.45 level.

    Crossings alternate: first a step strictly above 0.45, then the next step
    strictly below it, and so on. Values exactly equal to 0.45 never count.

    Returns:
        Float array of step indices (empty if there are no crossings).
    """
    crossings = []
    above = False
    for step in range(traj.shape[0]):
        height = traj[step, 1]
        if not above and height > 0.45:
            crossings.append(step)
            above = True
        if above and height < 0.45:
            crossings.append(step)
            above = False
    return np.array(crossings, dtype=float)

def pick_delay(traj,states):
    """Cut ``states`` into the segments where ``traj[:, 1]`` stays below 0.2.

    From step 5 onward, a step is recorded when the second coordinate first
    drops strictly below ``threshold``, and again when it next rises strictly
    above it. Each (drop, rise) pair yields ``states[drop:rise]``.

    Args:
        traj: array indexed as ``traj[i, 1]``.
        states: sequence aligned with ``traj`` on its first axis (presumably
            per-step model states - confirm against callers).

    Returns:
        The below-threshold ``states`` segments, followed by ``traj[start:]``,
        where ``start`` is the last recorded step (0 if none).
    """
    pointlist = np.array([])
    flag = False  # True while traj[:, 1] is below the threshold
    threshold = 0.2
    for i in range(traj.shape[0]):
        if i < 5:
            # The first 5 steps are never considered.
            continue
        if traj[i,1] < threshold and flag == False:
            pointlist = np.append(pointlist,i)
            flag = True
        if traj[i,1] > threshold and flag == True:
            pointlist = np.append(pointlist,i)
            flag = False
    states_list = []
    start = 0
    for i,k in enumerate(pointlist):
        end = int(k)
        if i % 2 != 0:
            # Odd entries are rises; `start` is the matching drop.
            states_list.append(states[start:end])
        start = int(k)
    # NOTE(review): this tail slices `traj`, not `states`, unlike every other
    # element of states_list - confirm this is intentional.
    states_list.append(traj[start:])
    return states_list

def pick_traj(traj,states):
    """Cut ``states`` into consecutive segments that each end where ``traj[:, 1]`` rises above 0.2.

    Crossings are detected as in ``pick_delay`` (from step 5, a strict drop
    below and then a strict rise above ``threshold``). Unlike ``pick_delay``,
    ``start`` only advances at rises, so segment j is
    ``states[previous_rise:rise_j]`` and the first one starts at 0.

    Args:
        traj: array indexed as ``traj[i, 1]``.
        states: sequence aligned with ``traj`` on its first axis (presumably
            per-step model states - confirm against callers).

    Returns:
        The ``states`` segments, followed by ``traj[start:]``, where ``start``
        is the last rise (0 if none).
    """
    pointlist = np.array([])
    flag = False  # True while traj[:, 1] is below the threshold
    threshold = 0.2
    for i in range(traj.shape[0]):
        if i < 5:
            # The first 5 steps are never considered.
            continue
        if traj[i,1] < threshold and flag == False:
            pointlist = np.append(pointlist,i)
            flag = True
        if traj[i,1] > threshold and flag == True:
            pointlist = np.append(pointlist,i)
            flag = False
    states_list = []
    start = 0
    for i,k in enumerate(pointlist):
        end = int(k)
        if i % 2 != 0:
            # Odd entries are rises: close the segment and start the next one here.
            states_list.append(states[start:end])
            start = int(k)
    # NOTE(review): this tail slices `traj`, not `states`, unlike every other
    # element of states_list - confirm this is intentional.
    states_list.append(traj[start:])
    return states_list

def match_length(states_list, max_length=1000):
    """Trim every sequence to the same length by keeping its last elements.

    The common length is the shortest sequence length, further capped at
    ``max_length`` (the historical default of 1000). Pass ``None`` for no cap.

    Args:
        states_list: iterable of sliceable sequences (lists, arrays, ...).
        max_length: upper bound on the common length, or ``None``.

    Returns:
        List of the trimmed sequences, in input order.
    """
    states_list = list(states_list)
    lengths = [len(state) for state in states_list]
    if max_length is not None:
        lengths.append(max_length)
    if not lengths:
        return []
    min_length = min(lengths)
    # Slice from an explicit start: `state[-0:]` would return the whole
    # sequence instead of an empty one when min_length is 0.
    return [state[len(state) - min_length:] for state in states_list]

def vec_var(datas):
    """Return the mean Euclidean distance of the samples in ``datas`` from their centroid.

    Samples are taken along axis 0. Despite the name, the distances are not
    squared, so this is a mean deviation rather than a variance.
    """
    samples = np.array(datas)
    centroid = np.average(samples, axis=0)
    total = sum(np.linalg.norm(sample - centroid) for sample in samples)
    return total / samples.shape[0]
    
def lyapunov_exp(data):
    """Return ``mean(log|diff(data)|)``, the mean log size of successive steps.

    Despite the name, this only averages the log step sizes of the series
    itself. Any zero step contributes ``-inf`` (and NumPy warns).
    """
    step_sizes = np.abs(np.diff(data))
    log_steps = np.log(step_sizes)
    return np.mean(log_steps)

class MyLSTM(nn.Module):
    """Three-region model in which PFC, HPC and Re are all ``nn.LSTMCell``s.

    Per step (all reads use the *previous* states):
      * Re  <- [PFC h, HPC h]
      * HPC <- [input, Re h]
      * PFC <- [HPC h, Re h]
    The output is ``linear`` applied to the new HPC hidden state. The full
    recurrent state is ``[PFC, HPC, Re]`` with each entry an ``(h, c)`` pair.
    """

    def __init__(self, input_size, hidden_size, output_size, batch_size):
        super().__init__()
        self.hidden_size = hidden_size
        # Submodules are created in this exact order so parameter order and the
        # RNG draws used by their default initialisation stay unchanged.
        self.PFC = nn.LSTMCell(2 * hidden_size, hidden_size)
        self.HPC = nn.LSTMCell(input_size + hidden_size, hidden_size)
        self.Re = nn.LSTMCell(2 * hidden_size, hidden_size)
        self.linear = nn.Linear(hidden_size, output_size)
        self.batch_size = batch_size
        # NOTE(review): per the torch docs, sparse_ zeroes the given fraction of
        # each column, so sparsity=1 leaves these four matrices all zero after
        # init — confirm this is intended.
        for weight in (self.PFC.weight_ih, self.HPC.weight_ih,
                       self.PFC.weight_hh, self.HPC.weight_hh):
            nn.init.sparse_(weight.data, 1)

    def forward(self, input, hiddens):
        """Advance all regions one step; return ``(output, [PFC, HPC, Re])``."""
        pfc_prev = hiddens[0]
        hpc_prev = hiddens[1]
        re_prev = hiddens[2]
        re_next = self.Re(torch.cat([pfc_prev[0], hpc_prev[0]], dim=1), re_prev)
        hpc_next = self.HPC(torch.cat([input, re_prev[0]], dim=1), hpc_prev)
        pfc_next = self.PFC(torch.cat([hpc_prev[0], re_prev[0]], dim=1), pfc_prev)
        return self.linear(hpc_next[0]), [pfc_next, hpc_next, re_next]

    def initHidden(self):
        """Return an all-zero state ``[PFC[h, c], HPC[h, c], Re[h, c]]``."""
        shape = (self.batch_size, self.hidden_size)

        def zero_pair():
            return [torch.zeros(shape), torch.zeros(shape)]

        hpc_state = zero_pair()
        pfc_state = zero_pair()
        re_state = zero_pair()
        return [pfc_state, hpc_state, re_state]

class MyLSTM_RNN(nn.Module):
    """Three-region model: PFC and HPC are ``nn.LSTMCell``s, Re is an ``nn.RNNCell``.

    Per step (all reads use the *previous* states):
      * Re  <- [PFC h, HPC h]
      * HPC <- [input, Re]
      * PFC <- [HPC h, Re]
    The output is ``linear`` applied to the new HPC hidden state.

    The recurrent state is the list ``[PFC, HPC, Re]``: PFC and HPC are
    ``(h, c)`` pairs of shape ``(batch_size, hidden)``, Re is a single
    ``(batch_size, hidden)`` tensor.

    Args:
        input_size: width of the external input.
        hidden_size: width of each region (PFC, HPC and Re).
        output_size: width of the linear read-out.
        batch_size: batch dimension used by the ``initHidden*`` helpers and the
            noise helpers.
        sparse: sparsity in tenths, passed to ``nn.init.sparse_`` for the PFC
            and HPC weight matrices (default 5 -> 0.5).
    """

    def __init__(self, input_size, hidden_size, output_size, batch_size, sparse=5):
        super(MyLSTM_RNN, self).__init__()

        self.hidden_size_PFC = hidden_size
        self.hidden_size_HPC = hidden_size
        self.hidden_size_Re = hidden_size
        # Submodule creation order fixes parameter/state_dict order and the RNG
        # draws consumed by default initialisation; keep it.
        self.PFC = nn.LSTMCell(self.hidden_size_HPC+self.hidden_size_Re, self.hidden_size_PFC)
        self.HPC = nn.LSTMCell(input_size+self.hidden_size_Re, self.hidden_size_HPC)
        self.Re = nn.RNNCell(self.hidden_size_HPC+self.hidden_size_PFC, self.hidden_size_Re)
        self.linear = nn.Linear(self.hidden_size_HPC, output_size)
        self.batch_size = batch_size
        sparse = 0.1*int(sparse)
        nn.init.sparse_(self.PFC.weight_ih.data,sparse)
        nn.init.sparse_(self.HPC.weight_ih.data,sparse)
        nn.init.sparse_(self.PFC.weight_hh.data,sparse)
        nn.init.sparse_(self.HPC.weight_hh.data,sparse)

    def forward(self, input, hiddens):
        """Advance all three regions by one step.

        Args:
            input: external input, cast to float32 here. Presumably shaped
                ``(batch_size, input_size)`` — it is concatenated along dim 1
                with the Re state, so the batch sizes must match.
            hiddens: ``[PFC (h, c), HPC (h, c), Re]`` from the previous step.

        Returns:
            ``(output, [PFC (h, c), HPC (h, c), Re])``.
        """
        input = input.float()
        Re_input = torch.cat([hiddens[0][0],hiddens[1][0]],dim=1)
        Re_hidden = self.Re(Re_input, hiddens[2])
        HPC_input = torch.cat([input,hiddens[2]],dim=1)
        HPC_hidden = self.HPC(HPC_input, hiddens[1])
        PFC_input = torch.cat([hiddens[1][0],hiddens[2]],dim=1)
        PFC_hidden = self.PFC(PFC_input, hiddens[0])
        output = self.linear(HPC_hidden[0])
        return output, [PFC_hidden, HPC_hidden, Re_hidden]

    def initHidden(self):
        """Return an all-zero state ``[PFC[h, c], HPC[h, c], Re]``.

        Uses the per-region widths. The previous version read
        ``self.hidden_size``, which this class never sets, so every call
        raised AttributeError.
        """
        HPC_hidden = [torch.zeros(self.batch_size, self.hidden_size_HPC), torch.zeros(self.batch_size, self.hidden_size_HPC)]
        PFC_hidden = [torch.zeros(self.batch_size, self.hidden_size_PFC), torch.zeros(self.batch_size, self.hidden_size_PFC)]
        Re_hidden = torch.zeros(self.batch_size, self.hidden_size_Re)
        return [PFC_hidden,HPC_hidden,Re_hidden]

    def initHidden_rand(self):
        """Return a state with entries drawn uniformly from ``[0, 0.1)``.

        Draw order (HPC h, HPC c, PFC h, PFC c, Re) is kept fixed so seeded
        runs stay reproducible.
        """
        const = 0.1
        var = 0.1
        HPC_hidden = [torch.rand(self.batch_size, self.hidden_size_HPC)*const, torch.rand(self.batch_size, self.hidden_size_HPC)*const]
        PFC_hidden = [torch.rand(self.batch_size, self.hidden_size_PFC)*var, torch.rand(self.batch_size, self.hidden_size_PFC)*var]
        Re_hidden = torch.rand(self.batch_size, self.hidden_size_Re)*const
        return [PFC_hidden,HPC_hidden,Re_hidden]

    def initHidden_test(self):
        """Return a state with entries drawn uniformly from ``[0, 0.2)``.

        Draw order: HPC h, HPC c, PFC h, PFC c, Re.
        """
        const = 0.2
        var = 0.2
        HPC_hidden = [torch.rand(self.batch_size, self.hidden_size_HPC)*const, torch.rand(self.batch_size, self.hidden_size_HPC)*const]
        PFC_hidden = [torch.rand(self.batch_size, self.hidden_size_PFC)*var, torch.rand(self.batch_size, self.hidden_size_PFC)*var]
        Re_hidden = torch.rand(self.batch_size, self.hidden_size_Re)*var
        return [PFC_hidden,HPC_hidden,Re_hidden]

    def noiseHidden_rand(self, hiddens):
        """Add Gaussian noise (std 0.1) to the Re state in place; return ``hiddens``.

        PFC h and HPC h also receive noise scaled by ``c = 0`` (i.e. nothing),
        but the draws still advance the torch RNG.
        """
        c = 0
        v = 0.1
        HPC_hidden = torch.randn(self.batch_size, self.hidden_size_HPC)*c
        hiddens[1][0].add_(HPC_hidden)
        PFC_hidden = torch.randn(self.batch_size, self.hidden_size_PFC)*c
        hiddens[0][0].add_(PFC_hidden)
        Re_hidden = torch.randn(self.batch_size, self.hidden_size_Re)*v
        hiddens[2].add_(Re_hidden)
        return hiddens

    def noiseHidden_select(self, hiddens):
        """Add Gaussian noise (std 0.01) to selected Re units in place; return ``hiddens``.

        Only units 1, 2, 8, 9 and 17 are perturbed, so every region must be at
        least 18 wide. PFC/HPC receive noise scaled by ``c = 0``.
        """
        c = 0
        v = 0.01
        index = np.array([1,2,8,9,17])
        HPC_hidden = torch.zeros(self.batch_size, self.hidden_size_HPC)
        HPC_hidden[:,index] += torch.randn(self.batch_size, index.size)*c
        hiddens[1][0].add_(HPC_hidden)
        PFC_hidden = torch.zeros(self.batch_size, self.hidden_size_PFC)
        PFC_hidden[:,index] += torch.randn(self.batch_size, index.size)*c
        hiddens[0][0].add_(PFC_hidden)
        Re_hidden = torch.zeros(self.batch_size, self.hidden_size_Re)*c
        Re_hidden[:,index] += torch.randn(self.batch_size, index.size)*v
        hiddens[2].add_(Re_hidden)
        return hiddens

    def noiseHidden_dis(self, hiddens, statr, statl):
        """Shift the Re state in place by ``(statr - statl) * 0.1``; return ``hiddens``.

        (``v = -0.1`` times ``-statr + statl``.) NOTE(review): ``statr`` and
        ``statl`` are added to a torch tensor, so they presumably need to be
        tensors broadcastable to ``(batch_size, hidden_size_Re)`` — confirm.
        """
        c = 0
        v =-0.1
        index = np.array([1,2,8,9,17])
        HPC_hidden = torch.zeros(self.batch_size, self.hidden_size_HPC)
        HPC_hidden[:,index] += torch.randn(self.batch_size, index.size)*c
        hiddens[1][0].add_(HPC_hidden)
        PFC_hidden = torch.zeros(self.batch_size, self.hidden_size_PFC)
        PFC_hidden[:,index] += torch.randn(self.batch_size, index.size)*c
        hiddens[0][0].add_(PFC_hidden)
        Re_hidden = torch.randn(self.batch_size, self.hidden_size_Re)*0.0
        Re_hidden += (-statr+statl)*v
        hiddens[2].add_(Re_hidden)
        return hiddens


def moving_average(data, window=5):
    """Return the running mean of ``data`` over ``window`` samples.

    Uses ``np.convolve`` in its default ``'full'`` mode. The result therefore
    has ``len(data) + window - 1`` samples, and the first and last
    ``window - 1`` values average over partial (zero-padded) windows. Callers
    trim the edges themselves (e.g. ``[2:-2]`` for the default window).

    Args:
        data: 1-D sequence of numbers.
        window: number of samples per average; defaults to the previously
            hard-coded 5.

    Raises:
        ValueError: if ``window`` is smaller than 1.
    """
    if window < 1:
        raise ValueError(f"window must be >= 1, got {window}")
    kernel = np.ones(window) / window
    return np.convolve(data, kernel)


def main(model):
    """Compare two input-driven rollouts of a trained MyLSTM_RNN.

    Loads the state dict at ``model`` into ``MyLSTM_RNN(2, 20, 2, 10)``. It then
    drives the model for ``data_limit`` (120) steps twice, feeding ``data[k]``
    at each step (outputs are not fed back):
      * rollout 1 uses a batch built by ``mkOwnRandomBatch`` from ``train_x1``
        (pattern 1);
      * rollout 2 uses a batch built from ``train_y1`` (pattern 3).
    It records the outputs, the PFC/HPC hidden states (whole batch) and the
    Re state of batch element 0.

    For PFC and HPC, a PCA is fitted on batch element 0's states from both
    rollouts stacked along time. For each of the first three PCs it computes
    the Pearson correlation over time between ``|out_1 - out_2|`` (output
    dim 0, batch 0) and ``|PC_1 - PC_2|``. Finally it takes the eigenvalues
    of the element-wise product ``PFC_inw[0:20] * PFC_inw[40:60]`` of the PFC
    recurrent weights (``PFC.weight_hh``).

    Prints intermediate results and shows one matplotlib figure (HPC).

    Args:
        model: path to a saved ``state_dict`` loadable with ``torch.load``.

    Returns:
        ``(mean PFC corr, mean HPC corr, max PFC corr, max HPC corr,
        mean |eigenvalue|)``.
    """
    # NOTE(review): test_size, epochs_num and data_length are not used below,
    # and delay_length = 3 is overwritten with 2 before its only use.
    training_size = 100
    test_size = 1000
    epochs_num = 10
    hidden_size = 20
    batch_size = 10
    data_length = 100
    inputsize = 2
    outputsize = 2
    delay_length = 3
    
    model_path = model
    
    # Only referenced by commented-out plotting code.
    colors = ['C0','C1','C2','C3','C4','C5','C6','C7','C8','C9']
    
#     delay_length = np.random.randint(1, 3)
    delay_length = 2
    # mkOwnDataSet_auto is defined elsewhere in the notebook; presumably it
    # returns four training sets (x1, x2, y1, y2) — confirm.
    train_x1,train_x2,train_y1,train_y2 = mkOwnDataSet_auto(training_size,delay_length)

#     rnn = MyLSTM_comp(inputsize, hidden_size, outputsize, batch_size)
    rnn = MyLSTM_RNN(inputsize, hidden_size, outputsize, batch_size)
#     rnn = MyLSTM_RNN_uniPFC(inputsize, hidden_size, outputsize, batch_size)
#     rnn = MyMTRNN2(inputsize, hidden_size, outputsize, batch_size)

    rnn.load_state_dict(torch.load(model_path))
    
    
    # Input batch for rollout 1. The branch is fixed by the literal: pattern 1
    # -> train_x1.
    pattern = 1
    if  pattern == 1:
        data = mkOwnRandomBatch(train_x1, train_x1, batch_size).float()
    elif pattern == 2:
        data = mkOwnRandomBatch(train_x2, train_x2, batch_size).float()
    elif pattern == 3:
        data = mkOwnRandomBatch(train_y1, train_y1, batch_size).float()
    else:
        data = mkOwnRandomBatch(train_y2, train_y2, batch_size).float()
    
    # Parameter tensors in registration order; only used by the commented-out
    # Culc_gate calls below.
    params = []
    for p in rnn.parameters():
        params.append(p.data)
    
        
#     model_path = 'model/R20_131/ReModel_L2_interRNNrand_OUT1_131_s6_100_1_epoch125.pth'
    # NOTE(review): reloads the same state dict that was already loaded above;
    # redundant.
    model_path = model
    rnn.load_state_dict(torch.load(model_path))
        
    # Snapshot of some trained weights. NOTE(review): none of these *_pre
    # copies is used later in this function, and Re_inw is overwritten by the
    # second weight-extraction loop below.
    for n, p in rnn.named_parameters():
            if n == "PFC.weight_ih":
                PFC_w_pre = torch.clone(p.data)
            if n == "HPC.weight_ih":
                HPC_w_pre = torch.clone(p.data)
            if n == "Re.weight_ih":
                Re_w_pre = torch.clone(p.data)
            if n == "Re.weight_hh":
                Re_inw = torch.clone(p.data)
            if n == "linear.weight":
                OUT_w_pre = torch.clone(p.data)
                   
    # --- Rollout 1 ---
    traj = []
    PFCstate = []
    HPCstate = []
    Restate = []
    Gate_states = []
    hidden = rnn.initHidden_rand()
    data_limit = 120
    est_length = 0
    # Unused except in the commented-out alternative below, but this draw
    # advances the torch RNG. Removing it would change the second
    # initHidden_rand() draw (and any later torch RNG use), so it stays.
    init_point = torch.rand(10,2)*1
    # Assumes data has at least data_limit steps along dim 0 — confirm.
    for k in range(data_limit):
            #print(data[k].shape)
            output,hidden = rnn(data[k],hidden)
#             output,hidden = rnn(init_point,hidden)
            traj.append(output.tolist())
            # hidden[0][0] / hidden[1][0] are the PFC / HPC h tensors for the
            # whole batch; hidden[2] is the Re RNNCell state, so hidden[2][0]
            # is batch element 0 only.
            PFCstate.append(hidden[0][0].tolist())
            HPCstate.append(hidden[1][0].tolist())
            Restate.append(hidden[2][0].tolist())
#             Gate_states.append(Culc_gate(output,params,hidden))
#     for k in range(data.shape[0]*est_length):
#             output,hidden = rnn(output,hidden)
#             traj.append(output.tolist())
#             PFCstate.append(hidden[0][0].tolist())
#             HPCstate.append(hidden[1][0].tolist())
#             Restate.append(hidden[2][0].tolist())
#             Gate_states.append(Culc_gate(output,params,hidden))
    # (data_limit, batch, output) array; squeeze only drops size-1 dims, so it
    # is a no-op for batch_size=10.
    traj = torch.tensor(traj)
    traj = torch.squeeze(traj).numpy()
    
    # --- Rollout 2 --- Input batch: pattern 3 -> train_y1. Despite the
    # *_noise names below, no noise is injected; the noiseHidden_* calls are
    # commented out.
    pattern = 3
    if  pattern == 1:
        data = mkOwnRandomBatch(train_x1, train_x1, batch_size).float()
    elif pattern == 2:
        data = mkOwnRandomBatch(train_x2, train_x2, batch_size).float()
    elif pattern == 3:
        data = mkOwnRandomBatch(train_y1, train_y1, batch_size).float()
    else:
        data = mkOwnRandomBatch(train_y2, train_y2, batch_size).float()
    
    # Only used by the commented-out noiseHidden_dis call below.
    dividenum = int(np.array(PFCstate)[0:,0].shape[0]/2)
    traj_noise = []
    PFCstate_noise = []
    HPCstate_noise = []
    Restate_noise = []
    Gate_states_noise = []
    hidden = rnn.initHidden_rand()
#     data = mkOwnRandomBatch(train_y, batch_size)
#     init_point = torch.rand(10,2)*1
#     data_limit = 125
    for k in range(data_limit):
            #print(data[k].shape)
            output,hidden = rnn(data[k],hidden)
#             output,hidden = rnn(init_point,hidden)
            traj_noise.append(output.tolist())
            PFCstate_noise.append(hidden[0][0].tolist())
            HPCstate_noise.append(hidden[1][0].tolist())
            Restate_noise.append(hidden[2][0].tolist())
#             Gate_states_noise.append(Culc_gate(output,params,hidden))
#     for k in range(data.shape[0]*est_length):
# #             hidden = rnn.noiseHidden_rand(hidden)
#             output,hidden = rnn(output,hidden) 
# #             hidden = rnn.noiseHidden_dis(hidden, np.array(Restate)[k+data_limit], np.array(Restate)[dividenum+k+data_limit])
# #             hidden = rnn.noiseHidden_rand(hidden)
#             traj_noise.append(output.tolist())
#             PFCstate_noise.append(hidden[0][0].tolist())
#             HPCstate_noise.append(hidden[1][0].tolist())
#             Restate_noise.append(hidden[2][0].tolist())
#             Gate_states_noise.append(Culc_gate(output,params,hidden))
    traj_noise = torch.tensor(traj_noise)
    traj_noise = torch.squeeze(traj_noise).numpy()

    
    print(np.array(PFCstate)[:,0].shape)
#     MakeAnimation(pltdata[:,0,0],pltdata[:,0,1], traj[:,0,0], traj[:,0,1], data_limit)
    #MakeAnimation_img(np.array(PFCstate)[:,0],"PFC")
    #MakeAnimation_img(np.array(HPCstate)[:,0],"HPC")


#     plt.figure()
#     plt.plot(traj[:,0,1],color=colors[-1])
#     plt.plot(traj_noise[:,0,1],color=colors[-2])


#     PFC_corrlist = np.array([])
#     HPC_corrlist = np.array([])
#     cross_corrlist = np.array([])
#     for i in range(20):
#         plt.figure()
# #         plt.ylim(-1,1)
#         plt.plot(traj[:,0,0])
#         plt.plot(traj_noise[:,0,0])
# #         plt.plot(np.abs(traj[:,0,0]-traj_noise[:,0,0]))
#     #         plt.plot(np.array(PFCstate_noise)[:400,0,17],alpha=0.5)
#         for k in range(1):
# #                 plt.plot(np.array(PFCstate)[:,k,i],alpha=0.5)
# #                 plt.plot(np.array(PFCstate_noise)[:,k,i],alpha=0.5)
#                 plt.plot(np.abs(np.array(PFCstate)[:,k,i]-np.array(PFCstate_noise)[:,k,i]),alpha=0.5)
# #                 plt.plot(np.array(HPCstate)[:,k,i],alpha=0.5)
# #                 plt.plot(np.array(HPCstate_noise)[:,k,i],alpha=0.5)
#                 plt.plot(np.abs(np.array(HPCstate)[:,k,i]-np.array(HPCstate_noise)[:,k,i]),alpha=0.5)
#     #             plt.plot(np.array(PFCstate)[:120,k,i],alpha=0.5)
#     #             plt.plot(np.array(PFCstate)[120:240,k,i],alpha=0.5)
#     #             plt.plot(np.array(PFCstate)[:120,k,i]-np.array(PFCstate)[120:240,k,i],alpha=0.5)
#     #             plt.plot(np.array(PFCstate_noise)[:120,k,i],alpha=0.5)
#     #             plt.plot(np.array(PFCstate_noise)[120:240,k,i],alpha=0.5)
#     #             plt.plot(np.array(PFCstate_noise)[:120,k,i]-np.array(PFCstate_noise)[120:240,k,i],alpha=0.5)
# #             plt.plot(np.array(Restate)[:,i],alpha=0.5)
# #             plt.plot(np.array(Restate_noise)[:,i],alpha=0.5)
# #             plt.plot(np.array(Restate)[:,i]-np.array(Restate_noise)[:,i],alpha=0.5)
# #                 plt.plot(moving_average(np.array(Restate)[:,i]),alpha=0.5)
# #                 plt.plot(moving_average(np.array(Restate_noise)[:,i]),alpha=0.5)
# #                 plt.plot(moving_average(np.array(Restate)[:,i])-moving_average(np.array(Restate_noise)[:,i]),alpha=0.5)
# #                 plt.plot(np.array(Gate_states)[:,0,0+i],alpha=0.5)
# #                 plt.plot(np.array(Gate_states_noise)[:,1,0+i],alpha=0.5)
# #                 plt.plot(np.array(Gate_states_noise)[:,1,20+i],alpha=0.5)
# #                 plt.plot(np.log(np.array(Gate_states)[:,1,0+i]/np.array(Gate_states)[:,1,20+i]),alpha=0.5)
# #                 plt.plot(np.log(np.array(Gate_states_noise)[:,1,0+i]/np.array(Gate_states_noise)[:,1,20+i]),alpha=0.5)
# #                 print(np.corrcoef(np.abs(traj[:,0,0]-traj_noise[:,0,0]),np.abs(np.array(PFCstate)[:,k,i]-np.array(PFCstate_noise)[:,k,i]))[0,1])
# #                 print(np.corrcoef(np.abs(traj[:,0,0]-traj_noise[:,0,0]),np.abs(np.array(HPCstate)[:,k,i]-np.array(HPCstate_noise)[:,k,i]))[0,1])
# #                 print(np.corrcoef(np.abs(np.array(PFCstate)[:,k,i]-np.array(PFCstate_noise)[:,k,i]),np.abs(np.array(HPCstate)[:,k,i]-np.array(HPCstate_noise)[:,k,i])))
#                 cross_corrlist = np.append(cross_corrlist,np.corrcoef(np.abs(np.array(PFCstate)[:,k,i]-np.array(PFCstate_noise)[:,k,i]),np.abs(np.array(HPCstate)[:,k,i]-np.array(HPCstate_noise)[:,k,i])))
#                 PFC_corrlist = np.append(PFC_corrlist,np.corrcoef(np.abs(traj[:,0,0]-traj_noise[:,0,0]),np.abs(np.array(PFCstate)[:,k,i]-np.array(PFCstate_noise)[:,k,i]))[0,1])
#                 HPC_corrlist = np.append(HPC_corrlist,np.corrcoef(np.abs(traj[:,0,0]-traj_noise[:,0,0]),np.abs(np.array(HPCstate)[:,k,i]-np.array(HPCstate_noise)[:,k,i]))[0,1])
# #         plt.title("neuron#"+str(i+1))
#     print(np.mean(PFC_corrlist),np.mean(HPC_corrlist),np.mean(cross_corrlist))
    
    # --- PFC: PCA over batch element 0 of both rollouts stacked in time.
    # Rows [:data_limit] are rollout 1 and rows [data_limit:] are rollout 2.
    pca = PCA()
#     dfs = np.array(PFCstate)[:,0]
    dfs = np.concatenate((np.array(PFCstate)[:,0],np.array(PFCstate_noise)[:,0]),axis=0)
#     dfs = np.concatenate((np.array(HPCstate)[:,0],np.array(HPCstate_noise)[:,0]),axis=0)
    print(dfs.shape)
    pca.fit(dfs)
    feature = pca.transform(dfs)
    print("PFC correlation")
    print(np.corrcoef(np.abs(traj[:,0,0]-traj_noise[:,0,0]),np.abs(feature[:data_limit,0]-feature[data_limit:,0])))
    print(np.corrcoef(np.abs(traj[:,0,0]-traj_noise[:,0,0]),np.abs(feature[:data_limit,1]-feature[data_limit:,1])))
    print(np.corrcoef(np.abs(traj[:,0,0]-traj_noise[:,0,0]),np.abs(feature[:data_limit,2]-feature[data_limit:,2])))
    # Correlation between the output divergence and the PC1..PC3 divergence.
    PFC_a = np.corrcoef(np.abs(traj[:,0,0]-traj_noise[:,0,0]),np.abs(feature[:data_limit,0]-feature[data_limit:,0]))[0,1]
    PFC_b = np.corrcoef(np.abs(traj[:,0,0]-traj_noise[:,0,0]),np.abs(feature[:data_limit,1]-feature[data_limit:,1]))[0,1]
    PFC_c = np.corrcoef(np.abs(traj[:,0,0]-traj_noise[:,0,0]),np.abs(feature[:data_limit,2]-feature[data_limit:,2]))[0,1]
#     PFC_a = np.corrcoef(np.abs(traj[:,0,1]-traj_noise[:,0,1]),np.abs(feature[:data_limit,0]-feature[data_limit:,0]))[0,1]
#     PFC_b = np.corrcoef(np.abs(traj[:,0,1]-traj_noise[:,0,1]),np.abs(feature[:data_limit,1]-feature[data_limit:,1]))[0,1]
#     PFC_c = np.corrcoef(np.abs(traj[:,0,1]-traj_noise[:,0,1]),np.abs(feature[:data_limit,2]-feature[data_limit:,2]))[0,1]
#     PFC_a = np.corrcoef(np.abs(traj[:,0,1]-traj_noise[:,0,1]),np.abs(feature[:data_limit,0]))[0,1]
#     PFC_b = np.corrcoef(np.abs(traj[:,0,1]-traj_noise[:,0,1]),np.abs(feature[:data_limit,1]))[0,1]
#     PFC_c = np.corrcoef(np.abs(traj[:,0,1]-traj_noise[:,0,1]),np.abs(feature[:data_limit,2]))[0,1]
    
#     plt.figure()
#     plt.plot(np.abs(traj[:,0,0]-traj_noise[:,0,0]))
#     plt.plot(np.abs(feature[:data_limit,0]-feature[data_limit:,0]))
#     plt.plot(np.abs(feature[:data_limit,1]-feature[data_limit:,1]))
#     plt.plot(np.abs(feature[:data_limit,2]-feature[data_limit:,2]))
#     plt.show()

#     plt.figure()
#     plt.plot(np.abs(feature[:100,0]),color=colors[-3])
#     plt.plot(np.abs(feature[100:,0]),color=colors[-4])
#     plt.show()

#     print(pd.DataFrame(feature, columns=["PC{}".format(x + 1) for x in range(dfs.shape[1])]).head())
    # NOTE(review): `pd` is not imported in this cell or in the file header;
    # this assumes pandas was imported as pd elsewhere in the notebook.
    print(pd.DataFrame(pca.explained_variance_ratio_, index=["PC{}".format(x + 1) for x in range(dfs.shape[1])]))

    
    # --- HPC: same analysis on the HPC hidden states.
    pca = PCA()
#     dfs = np.array(PFCstate)[:,0]
#     dfs = np.concatenate((np.array(PFCstate)[:,0],np.array(PFCstate_noise)[:,0]),axis=0)
    dfs = np.concatenate((np.array(HPCstate)[:,0],np.array(HPCstate_noise)[:,0]),axis=0)
    print(dfs.shape)
    pca.fit(dfs)
    feature = pca.transform(dfs)
    print("HPC correlation")
#     print(np.corrcoef(np.abs(traj[:,0,0]-traj_noise[:,0,0]),np.abs(feature[:100,0]-feature[100:,0])))
#     print(np.corrcoef(np.abs(traj[:,0,0]-traj_noise[:,0,0]),np.abs(feature[:100,1]-feature[100:,1])))
#     print(np.corrcoef(np.abs(traj[:,0,0]-traj_noise[:,0,0]),np.abs(feature[:100,2]-feature[100:,2])))
    HPC_a = np.corrcoef(np.abs(traj[:,0,0]-traj_noise[:,0,0]),np.abs(feature[:data_limit,0]-feature[data_limit:,0]))[0,1]
    HPC_b = np.corrcoef(np.abs(traj[:,0,0]-traj_noise[:,0,0]),np.abs(feature[:data_limit,1]-feature[data_limit:,1]))[0,1]
    HPC_c = np.corrcoef(np.abs(traj[:,0,0]-traj_noise[:,0,0]),np.abs(feature[:data_limit,2]-feature[data_limit:,2]))[0,1]
#     HPC_a = np.corrcoef(np.abs(traj[:,0,1]-traj_noise[:,0,1]),np.abs(feature[:data_limit,0]-feature[data_limit:,0]))[0,1]
#     HPC_b = np.corrcoef(np.abs(traj[:,0,1]-traj_noise[:,0,1]),np.abs(feature[:data_limit,1]-feature[data_limit:,1]))[0,1]
#     HPC_c = np.corrcoef(np.abs(traj[:,0,1]-traj_noise[:,0,1]),np.abs(feature[:data_limit,2]-feature[data_limit:,2]))[0,1]
    
    # Plot the output divergence together with the HPC PC1..PC3 divergences.
    plt.figure()
    plt.plot(np.abs(traj[:,0,0]-traj_noise[:,0,0]))
    plt.plot(np.abs(feature[:data_limit,0]-feature[data_limit:,0]))
    plt.plot(np.abs(feature[:data_limit,1]-feature[data_limit:,1]))
    plt.plot(np.abs(feature[:data_limit,2]-feature[data_limit:,2]))
    plt.show()
    
#     plt.figure()
#     plt.plot(np.abs(feature[:100,0]),color=colors[-3])
#     plt.plot(np.abs(feature[100:,0]),color=colors[-4])
#     plt.show()

#     print(pd.DataFrame(feature, columns=["PC{}".format(x + 1) for x in range(dfs.shape[1])]).head())
#     print(pd.DataFrame(pca.explained_variance_ratio_, index=["PC{}".format(x + 1) for x in range(dfs.shape[1])]))


    # --- Re: Restate rows already hold batch element 0 (see rollout loop),
    # hence no [:,0]. The Re features are computed but feed nothing below.
    pca = PCA()
#     dfs = np.array(PFCstate)[:,0]
#     dfs = np.concatenate((np.array(PFCstate)[:,0],np.array(PFCstate_noise)[:,0]),axis=0)
    dfs = np.concatenate((np.array(Restate)[:],np.array(Restate_noise)[:]),axis=0)
    print(dfs.shape)
    pca.fit(dfs)
    feature = pca.transform(dfs)
    print("Re correlation")
#     print(np.corrcoef(np.abs(traj[:,0,0]-traj_noise[:,0,0]),np.abs(feature[:100,0]-feature[100:,0])))
#     print(np.corrcoef(np.abs(traj[:,0,0]-traj_noise[:,0,0]),np.abs(feature[:100,1]-feature[100:,1])))
#     print(np.corrcoef(np.abs(traj[:,0,0]-traj_noise[:,0,0]),np.abs(feature[:100,2]-feature[100:,2])))
#     Re_a = np.corrcoef(np.abs(traj[:,0,0]-traj_noise[:,0,0]),np.abs(feature[:100,0]-feature[100:,0]))[0,1]
#     Re_b = np.corrcoef(np.abs(traj[:,0,0]-traj_noise[:,0,0]),np.abs(feature[:100,1]-feature[100:,1]))[0,1]
#     Re_c = np.corrcoef(np.abs(traj[:,0,0]-traj_noise[:,0,0]),np.abs(feature[:100,2]-feature[100:,2]))[0,1]
    
#     plt.figure()
# #     plt.plot(np.abs(feature[:100,0]-feature[100:,0]))
# #     plt.plot(np.abs(feature[:100,1]-feature[100:,1]))
# #     plt.plot(np.abs(feature[:100,2]-feature[100:,2]))
# #     plt.plot(np.abs(moving_average(feature[:100,0])-moving_average(feature[100:,0])))
# #     plt.plot(np.abs(moving_average(feature[:100,1])-moving_average(feature[100:,1])))
# #     plt.plot(np.abs(moving_average(feature[:100,2])-moving_average(feature[100:,2])))
#     plt.plot(np.abs(feature[:100,0]),color=colors[-3])
#     plt.plot(np.abs(feature[100:,0]),color=colors[-4])
#     plt.show()


    # Trained weights as NumPy arrays. Only PFC_inw (PFC.weight_hh) feeds the
    # return value.
    for n, p in rnn.named_parameters():
            if n == "PFC.weight_ih":
                PFC_w = np.array(p.data)
            if n == "PFC.weight_hh":
                PFC_inw = np.array(p.data)
            if n == "HPC.weight_ih":
                HPC_w = np.array(p.data)
            if n == "HPC.weight_hh":
                HPC_inw = np.array(p.data)
            if n == "Re.weight_ih":
                Re_w = np.array(p.data)
            if n == "Re.weight_hh":
                Re_inw = np.array(p.data)
            if n == "linear.weight":
                Out_w = np.array(p.data)
    
# #     PFC_w_hist = np.histogram(np.abs(PFC_w[PFC_w>0]),bins=80,density=True)
# #     PFC_inw_hist = np.histogram(np.abs(PFC_inw[PFC_inw>0]),bins=40,density=True)
# #     fig1.plot((make_bins(PFC_w_hist[1])),np.log(PFC_w_hist[0]+0.01),color=color)
# #     fig1.plot((make_bins(PFC_inw_hist[1])),np.log(PFC_inw_hist[0]+0.01),color=color)
# #     PFC_w_hist = np.histogram(np.abs(PFC_w[PFC_w<0]),bins=80,density=True)
# #     PFC_inw_hist = np.histogram(np.abs(PFC_inw[PFC_inw<0]),bins=40,density=True)
#     PFC_w_hist = np.histogram(np.abs(PFC_w[PFC_w.nonzero()]),bins=80,density=True)
#     PFC_inw_hist = np.histogram(np.abs(PFC_inw[PFC_inw.nonzero()]),bins=40,density=True)
# #     PFC_hist = np.histogram(np.abs(np.append(PFC_w[PFC_w.nonzero()],PFC_inw[PFC_inw.nonzero()])),bins=40,density=True)
# #     PFC_b_hist = np.histogram(np.abs(PFC_b[PFC_b.nonzero()]),bins=80,density=True)
# #     PFC_inb_hist = np.histogram(np.abs(PFC_inb[PFC_inb.nonzero()]),bins=40,density=True)
# #     fig1.plot(np.log(make_bins(PFC_w_hist[1])),np.log(PFC_w_hist[0]+0.01),color=color)
#     fig1.plot(np.log(make_bins(PFC_inw_hist[1])),np.log(PFC_inw_hist[0]+0.01),color=color)
# #     fig1.plot((make_bins(PFC_hist[1])),np.log(PFC_hist[0]))
#     a1 = regression(make_bins(PFC_w_hist[1]),np.log(PFC_w_hist[0]+0.01))
#     a2 = regression(make_bins(PFC_inw_hist[1]),np.log(PFC_inw_hist[0]+0.01))


# #     HPC_w_hist = np.histogram(HPC_w[HPC_w>0],bins=80,density=True)
# #     HPC_inw_hist = np.histogram(HPC_inw[HPC_inw>0],bins=40,density=True)
# #     fig2.plot((make_bins(HPC_w_hist[1])),np.log(HPC_w_hist[0]))
# #     fig2.plot((make_bins(HPC_inw_hist[1])),np.log(HPC_inw_hist[0]))
# #     HPC_w_hist = np.histogram(np.abs(HPC_w[HPC_w<0]),bins=80,density=True)
# #     HPC_inw_hist = np.histogram(np.abs(HPC_inw[HPC_inw<0]),bins=40,density=True)
#     HPC_w_hist = np.histogram(np.abs(HPC_w[HPC_w.nonzero()]),bins=80,density=True)
#     HPC_inw_hist = np.histogram(np.abs(HPC_inw[HPC_inw.nonzero()]),bins=40,density=True)
# #     HPC_w_hist = np.histogram(np.abs(HPC_w),bins=80,density=True)
# #     HPC_inw_hist = np.histogram(np.abs(HPC_inw),bins=40,density=True)
# #     fig2.plot((make_bins(HPC_w_hist[1])),np.log(HPC_w_hist[0]))
# #     fig2.plot((make_bins(HPC_inw_hist[1])),np.log(HPC_inw_hist[0]))
# #     fig2.plot(np.log(make_bins(HPC_w_hist[1])),np.log(HPC_w_hist[0]+0.01),color=color)
#     fig2.plot(np.log(make_bins(HPC_inw_hist[1])),np.log(HPC_inw_hist[0]+0.01),color=color)
#     b1 = regression(make_bins(HPC_w_hist[1]),np.log(HPC_w_hist[0]+0.01))
#     b2 = regression(make_bins(HPC_inw_hist[1]),np.log(HPC_inw_hist[0]+0.01))


# #     Re_w_hist = np.histogram(Re_w[Re_w>0],bins=100,density=True)
# #     Re_inw_hist = np.histogram(Re_inw[Re_inw>0],bins=100,density=True)
# #     Re_w_hist = np.histogram(np.abs(Re_w[Re_w<0]),bins=100,density=True)
# #     Re_inw_hist = np.histogram(np.abs(Re_inw[Re_inw<0]),bins=20,density=True)
#     Re_w_hist = np.histogram(np.abs(Re_w[Re_w.nonzero()]),bins=40,density=True)
#     Re_inw_hist = np.histogram(np.abs(Re_inw[Re_inw.nonzero()]),bins=40,density=True)
# #     Re_inw_hist = np.histogram(np.abs(Re_inw),bins=40,density=True)
#     Re_initw_hist = np.histogram(np.abs(Re_initw),bins=40,density=True)
#     fig3.plot(np.log(make_bins(Re_w_hist[1])),np.log(Re_w_hist[0]+0.01),color=color)
# #     fig3.plot(np.log(make_bins(Re_inw_hist[1])),np.log(Re_inw_hist[0]+0.01),color=color)
# #     fig3.plot(np.log(make_bins(Re_initw_hist[1])),np.log(Re_initw_hist[0]+0.01))
#     c1 = regression(make_bins(Re_w_hist[1]),np.log(Re_w_hist[0]+0.01))
#     c2 = regression(make_bins(Re_inw_hist[1]),np.log(Re_inw_hist[0]+0.01))
    
#     Re_uw,Re_uv = LA.eig(Re_w[:,20:])
#     fig_eigenval.plot(Re_inuw.real,Re_inuw.imag,"o",color=color)
#     print(Re_inuw)
#     print(np.count_nonzero(Re_uw.real>0))
#     print(np.count_nonzero(Re_uw.imag==0))
#     print(np.count_nonzero((Re_uw.imag==0)*(Re_uw.real>0)))


    # NOTE(review): `LA` (numpy.linalg) is only imported in a later notebook
    # cell in this view; confirm it is defined before main() runs.
    # NOTE(review): PFC_inuw, HPC_inuw, HPC_uw, HPC_allw, PFC_allw and Re_allw
    # below are computed but never used.
    i = 0
    PFC_inuw,PFC_inuv = LA.eig(PFC_inw[i:i+20])
    
    i = 0
    HPC_inuw,HPC_inuv = LA.eig(HPC_inw[i:i+20])
    
    # The zero-width np.pad is a no-op. Columns 2: drop the first input_size
    # (=2) columns of HPC.weight_ih, i.e. the external-input part of
    # cat([input, Re]).
    HPC_uw,HPC_uv = LA.eig(np.pad(HPC_w[i:i+20,2:],[(0,0),(0,0)]))

#     HPC_allw = HPC_w[0:20,2:]*HPC_w[40:60,2:]+2*HPC_inw[0:20]*HPC_w[40:60,2:]+HPC_w[0:20,2:]*HPC_inw[40:60]
    HPC_allw = HPC_w[0:20,2:]*HPC_w[40:60,2:]+HPC_inw[0:20]*HPC_w[40:60,2:]+HPC_w[0:20,2:]*HPC_inw[40:60]+HPC_inw[0:20]*HPC_inw[40:60]
#     HPC_allw = HPC_w[0:20,2:]*HPC_w[40:60,2:]
#     HPC_allw = HPC_inw[0:20]
#     PFC_allw = PFC_w[0:20,:20]*PFC_w[40:60,:20]+2*PFC_inw[0:20]*PFC_w[40:60,:20]+PFC_w[0:20,:20]*PFC_inw[40:60]
    PFC_allw = PFC_w[0:20,:20]*PFC_w[40:60,:20]+PFC_inw[0:20]*PFC_w[40:60,:20]+PFC_w[0:20,:20]*PFC_inw[40:60]+PFC_inw[0:20]*PFC_inw[40:60]
#     PFC_allw2 = PFC_w[0:20,20:]*PFC_w[40:60,20:]+PFC_inw[0:20]*PFC_w[40:60,20:]+PFC_w[0:20,20:]*PFC_inw[40:60]+PFC_inw[0:20]*PFC_inw[40:60]
#     PFC_allw = PFC_w[0:20,:20]*PFC_w[40:60,:20]
#     PFC_allw2 = PFC_w[0:20,20:]*PFC_w[40:60,20:]
#     Re_allw  = Re_w[:,0:20] + Re_w[:,20:] + Re_inw
    Re_allw  = Re_w[:,0:20] + Re_inw
#     Re_allw2  = Re_w[:,20:] + Re_inw
#     Re_allw  = Re_w[:,0:20]
#     Re_allw2  = Re_w[:,20:]
#     all_uw,all_uv = LA.eig(HPC_allw)
#     all_uw,all_uv = LA.eig(HPC_inw[0:20]*HPC_inw[40:60])
    # Eigenvalues of the element-wise product of two 20-row blocks of
    # PFC.weight_hh. Assuming PyTorch's LSTMCell row layout (input, forget,
    # cell, output gates; 20 rows each here), these would be the input-gate
    # and cell-candidate blocks — confirm.
    all_uw,all_uv = LA.eig(PFC_inw[0:20]*PFC_inw[40:60])
#     all_uw,all_uv = LA.eig(HPC_allw)
        
#     if model_path in test_list_B:
# #         fig_eigenval.plot(np.average(abs(all_uw)),np.max(abs(all_uw)),"o",color="b",alpha=0.3,zorder=2)
#         PCA_list_B = np.concatenate((PCA_list_B,[[np.count_nonzero((all_uw.real>0)), np.count_nonzero((all_uw.imag==0)*(all_uw.real>0)), np.count_nonzero((all_uw.imag==0)*(all_uw.real<0)), np.count_nonzero((all_uw.imag!=0)*(all_uw.real>0)),np.count_nonzero((all_uw.imag!=0)*(all_uw.real<0))]]),axis=0)
#     else:
#         PCA_list_A = np.concatenate((PCA_list_A,[[np.count_nonzero((all_uw.real>0)), np.count_nonzero((all_uw.imag==0)*(all_uw.real>0)), np.count_nonzero((all_uw.imag==0)*(all_uw.real<0)), np.count_nonzero((all_uw.imag!=0)*(all_uw.real>0)),np.count_nonzero((all_uw.imag!=0)*(all_uw.real<0))]]),axis=0)
       
    
    # (mean PFC corr, mean HPC corr, max PFC corr, max HPC corr, mean |eig|)
    return np.mean([PFC_a,PFC_b,PFC_c]),np.mean([HPC_a,HPC_b,HPC_c]),np.max([PFC_a,PFC_b,PFC_c]),np.max([HPC_a,HPC_b,HPC_c]),np.mean(abs(all_uw))


if __name__ == '__main__':
    # `glob` is otherwise only imported in the following notebook cell; import
    # it here so this cell also runs on its own (re-importing is harmless).
    import glob

    # Scatter of mean |eigenvalue| (x) against |max PFC correlation| (y),
    # one point per saved model.
    fig = plt.figure()
    correlation_fig = fig.subplots()

    # Accumulators across the seed patterns iterated below.
    total_ratio_list = np.array([])
    total_ratio_list_max = np.array([])
    total_eigen_list = np.array([])
    for i in range(1):
        path = 'model/R20_H_bigbatch/'
        # Alternative directory used in other runs: 'model/R20_H_uniPFC_bigbatch/'
        model_list = glob.glob(path+'*s'+str(i+3)+'_100_3_*epoch*.pth')
        # Lexicographic sort, then a stable sort by length, so that among
        # otherwise-similar names "epoch5" comes before "epoch10".
        model_list = sorted(model_list)
        model_list = sorted(model_list,key=len,reverse=False)
        ratio_list = []       # |max HPC correlation| per model
        ratio_list_max = []   # |max PFC correlation| per model
        eigen_list = []       # mean |eigenvalue| per model
        with open(path+"good_list.txt", mode="r") as f:
            good_list = f.read().splitlines()
        first_goodmodel = [0,0]  # [epoch, |max PFC corr|] of the first good model
        good_flag = False
        k=0
        for model in model_list:
            print(model)
            PFC,HPC,PFC_max,HPC_max,eigen = main(model)
            ratio_list.append(np.abs(HPC_max))
            ratio_list_max.append(np.abs(PFC_max))
            eigen_list.append(eigen)
            print(PFC_max,HPC_max,eigen)

            epoch_label = model.split("epoch")[-1].split(".")[0]
            is_good = model in good_list
            if is_good and not good_flag:
                good_flag = True
                first_goodmodel[0] = int(epoch_label)
                first_goodmodel[1] = ratio_list_max[-1]

            # Models listed in good_list.txt are drawn red, the rest blue;
            # every point is labelled with its epoch number.
            point_x = np.array(eigen_list[-1])
            point_y = np.array(ratio_list_max[-1])
            correlation_fig.plot(point_x, point_y, "o", color="r" if is_good else "b")
            correlation_fig.text(point_x, point_y, epoch_label)

            k+=1
        print(np.corrcoef(np.array(ratio_list),np.array(ratio_list_max)))
        total_ratio_list = np.append(total_ratio_list,np.array(ratio_list))
        total_ratio_list_max = np.append(total_ratio_list_max,np.array(ratio_list_max))
        total_eigen_list = np.append(total_eigen_list,np.array(eigen_list))
    print(np.corrcoef(np.array(total_eigen_list),np.array(total_ratio_list_max)))
        
        

In [None]:
import glob
from sklearn import linear_model
import matplotlib.cm as cm
import numpy.linalg as LA
import scipy.stats as st
import seaborn as sns

class MyLSTM_RNN(nn.Module):
    """PFC/HPC LSTM cells coupled through an Re RNN cell, read out from HPC.

    One call to ``forward`` advances all three modules by one step. Each
    module reads only the *previous* step's states in ``hiddens``.

    Args:
        input_size: feature size of the external input fed to HPC.
        hidden_size: hidden size shared by the PFC, HPC and Re cells.
        output_size: size of the linear readout from the HPC hidden state.
        batch_size: stored as ``self.batch_size``; not used by this class.
        sparse: accepted for backward compatibility; it has no effect.
    """

    def __init__(self, input_size, hidden_size, output_size, batch_size, sparse=5):
        super().__init__()

        self.hidden_size_PFC = hidden_size
        self.hidden_size_HPC = hidden_size
        self.hidden_size_Re = hidden_size
        # Registration order (PFC, HPC, Re, linear) fixes the order of
        # named_parameters()/state_dict keys, so it must not change.
        # PFC sees the HPC hidden state plus the Re state.
        self.PFC = nn.LSTMCell(self.hidden_size_HPC + self.hidden_size_Re, self.hidden_size_PFC)
        # HPC sees the external input plus the Re state.
        self.HPC = nn.LSTMCell(input_size + self.hidden_size_Re, self.hidden_size_HPC)
        # Re sees both LSTM hidden states.
        self.Re = nn.RNNCell(self.hidden_size_HPC + self.hidden_size_PFC, self.hidden_size_Re)
        self.linear = nn.Linear(self.hidden_size_HPC, output_size)
        self.batch_size = batch_size

        # Re-draw the input/recurrent weights from N(0, std). Re gets a much
        # smaller std than the LSTMs. The call order is kept so that the torch
        # RNG stream consumed here is unchanged.
        nn.init.normal_(self.PFC.weight_ih.data, 0, 1)
        nn.init.normal_(self.HPC.weight_ih.data, 0, 1)
        nn.init.normal_(self.Re.weight_ih.data, 0, 0.1)
        nn.init.normal_(self.PFC.weight_hh.data, 0, 1)
        nn.init.normal_(self.HPC.weight_hh.data, 0, 1)
        nn.init.normal_(self.Re.weight_hh.data, 0, 0.1)

    def forward(self, input, hiddens):
        """Advance the network by one step.

        Args:
            input: external input, concatenated with the Re state along dim 1.
            hiddens: ``[PFC (h, c), HPC (h, c), Re h]`` from the previous step.

        Returns:
            ``(output, [PFC_hidden, HPC_hidden, Re_hidden])``, where ``output``
            is the linear readout of the *new* HPC hidden state.
        """
        prev_PFC = hiddens[0]
        prev_HPC = hiddens[1]
        prev_Re = hiddens[2]

        # All three updates deliberately use the previous-step states, so the
        # order of these three lines does not matter.
        Re_hidden = self.Re(torch.cat([prev_PFC[0], prev_HPC[0]], dim=1), prev_Re)
        HPC_hidden = self.HPC(torch.cat([input, prev_Re], dim=1), prev_HPC)
        PFC_hidden = self.PFC(torch.cat([prev_HPC[0], prev_Re], dim=1), prev_PFC)

        output = self.linear(HPC_hidden[0])
        return output, [PFC_hidden, HPC_hidden, Re_hidden]

def make_bins(data):
    """Return the midpoints between consecutive entries of *data*.

    Typically used to turn ``np.histogram`` bin edges into bin centres.
    Inputs with fewer than two entries yield an empty float64 array.

    Args:
        data: 1-D sequence of numbers.

    Returns:
        float64 array of length ``len(data) - 1`` (or 0).
    """
    # Vectorised form of the former element-by-element np.append loop,
    # which was O(n^2) because every append copied the whole array.
    edges = np.asarray(data, dtype=float)
    return (edges[:-1] + edges[1:]) / 2

def regression(datax, datay):
    """Fit ``datay ~ a * datax + b`` by least squares and return the slope ``a``.

    Points where either coordinate is non-finite are dropped before fitting.
    This covers the ``-inf`` produced by ``np.log(0)`` on empty histogram
    bins, and also NaN/+inf. Previously only ``-inf`` in ``datay`` was
    removed, so any other non-finite value made the sklearn fit raise.
    Prints the coefficient array and the R^2 score of the fit.

    Args:
        datax: array-like of x values.
        datay: array-like of y values, same shape as ``datax``.

    Returns:
        The fitted slope as a scalar.
    """
    datax = np.asarray(datax, dtype=float)
    datay = np.asarray(datay, dtype=float)
    finite = np.isfinite(datax) & np.isfinite(datay)
    x = datax[finite].reshape(-1, 1)
    y = datay[finite].reshape(-1, 1)

    clf = linear_model.LinearRegression()
    clf.fit(x, y)
    coef = clf.coef_
    score = clf.score(x, y)  # coefficient of determination R^2
    print(coef, score)
    # y is 2-D, so coef_ has shape (1, 1).
    return coef[0][0]


# model_list = glob.glob('model/ReModel_interRNN_long_s5_200_1.pth')
# model_list = glob.glob('model/ReModel_L2_interRNNrand_AddRe_OUT5_1212121_s*_100_1_2_?.pth')
# model_list = glob.glob('model/ReModel_L2_interRNNrand_AddRe_PnoiseRe_OUT5_161_s*_100_4.pth')
# model_list = glob.glob('model/R+2/*s2*_100_1*.pth')
# model_list = glob.glob('model/P+/*s3_100_1*add.pth')
# model_list = glob.glob('model/R20/*s*_100_*.pth')
model_list = glob.glob('model/R20_H_bigbatch/ReModel_L2_interRNNrand_OUT1_121_s*_100_*_epoch*.pth')
# Shorter paths first, ties broken lexicographically, so e.g. "..._epoch9.pth"
# sorts before "..._epoch10.pth" (same result as sorting twice with a stable sort).
model_list = sorted(model_list, key=lambda path: (len(path), path))

# Checkpoint paths listed as "good", one per line.
with open("model/R20_H_bigbatch/good_list.txt", encoding="utf-8") as f:
    good_list = f.read().splitlines()

# Group A: paths read from the "contfail" list; group B: from the "contgood" list.
with open("attract_contfail_list_add_test.txt", encoding="utf-8") as f:
    test_list_A = f.read().splitlines()
with open("attract_contgood_list_add_test.txt", encoding="utf-8") as f:
    test_list_B = f.read().splitlines()

# Keep the globbed models that appear in each group file, preserving model_list order.
# Set lookups replace the previous per-model scan of each list.
test_set_A = set(test_list_A)
test_set_B = set(test_list_B)
test_list_A_selected = [path for path in model_list if path in test_set_A]
test_list_B_selected = [path for path in model_list if path in test_set_B]

# limited_number > 0 draws that many paths per group *with replacement*
# (np.random.randint), so duplicates are possible; 0 keeps every selected path.
limited_number = 0
if limited_number > 0:
    picks_A = np.random.randint(0, len(test_list_A_selected), limited_number)
    test_list_A_limited = [test_list_A_selected[i] for i in picks_A]
    picks_B = np.random.randint(0, len(test_list_B_selected), limited_number)
    test_list_B_limited = [test_list_B_selected[i] for i in picks_B]
else:
    test_list_A_limited = test_list_A_selected
    test_list_B_limited = test_list_B_selected

print(len(test_list_A_limited), len(test_list_B_limited))
    
# test_list = glob.glob('model/R20/*s1*_100_2*.pth')
# test_list = glob.glob('model/R20_cont/*OUT1*s*_100_*.pth')
# test_list = glob.glob('model_test/4_2/*.pth')
# Highlighted models; same ordering rule as model_list (length, then lexicographic).
test_list = sorted(model_list, key=lambda path: (len(path), path))

# Hyper-parameters used to rebuild each MyLSTM_RNN before loading its weights.
training_size = 100
test_size = 1000
epochs_num = 10
hidden_size = 20
batch_size = 10
data_length = 100
inputsize = 2
outputsize = 2

fig_weight = plt.figure(figsize=(10, 5))
fig1 = fig_weight.add_subplot(131)
fig2 = fig_weight.add_subplot(132)
fig3 = fig_weight.add_subplot(133)

target_list = model_list[:]

fig_eigen = plt.figure()
fig_eigenval = fig_eigen.add_subplot(111)
test_colors = []

fig3d = plt.figure()
# Axes3D(fig) no longer adds itself to the figure (deprecated in Matplotlib 3.4,
# removed in 3.6), which leaves an empty figure; add_subplot registers the axes.
ax3d = fig3d.add_subplot(111, projection="3d")

PCA_list_A = np.array([[0, 1, 2, 3, 4]])
PCA_list_B = np.array([[0, 1, 2, 3, 4]])

stat_list_A = []
stat_list_B = []

stat_list_xy_A = [np.array([0]), np.array([0])]
stat_list_xy_B = [np.array([0]), np.array([0])]

testnum = 0
num = 0
num=0
# For each checkpoint from epoch 70 on, take the element-wise product of the
# PFC recurrent-weight blocks for the input gate (rows 0:H) and the cell gate
# (rows 2H:3H). PyTorch's LSTMCell stacks weight_hh rows as (i, f, g, o).
# Record the mean |eigenvalue| of that H x H matrix, split by good_list membership.
good_paths = set(good_list)
H = hidden_size
for model_path in target_list:
    num += 1
    epoch = int(model_path.split("epoch")[-1].split(".")[0])
    # Filter before building/loading the model so skipped checkpoints cost nothing.
    if epoch < 70:
        continue

    rnn = MyLSTM_RNN(inputsize, hidden_size, outputsize, batch_size)
    rnn.load_state_dict(torch.load(model_path))
    PFC_inw = np.array(rnn.PFC.weight_hh.data)  # shape (4H, H)

    all_uw = LA.eigvals(PFC_inw[0:H] * PFC_inw[2 * H:3 * H])
    mean_abs_eig = np.average(np.abs(all_uw))

    if model_path in good_paths:
        testnum += 1
        test_colors.append(testnum / len(test_list))
        fig_eigenval.plot(epoch, mean_abs_eig, "o", color="b", alpha=0.3, zorder=2)
        stat_list_A.append(mean_abs_eig)
    else:
        # Failed checkpoints in red (previously also blue, making the two groups
        # indistinguishable); matches the Good=blue / Failed=red summary plot.
        fig_eigenval.plot(epoch, mean_abs_eig, "o", color="r", alpha=0.3, zorder=1)
        stat_list_B.append(mean_abs_eig)
        print(model_path, mean_abs_eig)
        
        
# ## Eigenvalue first100 and last100 ##

#     if int(model_path.split("epoch")[-1].split(".")[0])<100:
#         testnum+=1
# #         testcolor=cm.jet(testnum/len(test_list))
#         testcolor = "k"
#         test_colors.append(testnum/len(test_list))
#         fig_eigenval.plot(all_uw.real,all_uw.imag,"o",color=testcolor,alpha=0.1,zorder=2)
# #         fig_eigenval.plot(0,np.average(abs(all_uw)),"o",color=testcolor,alpha=0.3,zorder=1)
# #         fig_eigenval.plot(int(model_path.split("epoch")[-1].split(".")[0]),np.average(abs(all_uw)),"o",color="b",alpha=0.3,zorder=1)
#         print(model_path,testcolor,np.average(abs(all_uw)))
#         stat_list_A.append(np.average((abs(all_uw))))
# #         print(all_uw)
# #         print(all_uv.T)
# #         print(np.count_nonzero(all_uw.real>0))
# #         print(np.count_nonzero(all_uw.imag==0))
# #         print(np.count_nonzero((all_uw.imag==0)*(all_uw.real>0)))
# #         print(np.average(abs(all_uw)))
# #         for k in range(20):
# #             fig_eigenval.text(all_uw.real[k],all_uw.imag[k],str(model_path.split("epoch")[-1].split(".")[0]))
#     elif int(model_path.split("epoch")[-1].split(".")[0])>100:
# #         fig_eigenval.plot(all_uw.real,all_uw.imag,"o",color=color,alpha=0.1,zorder=1)
#         fig_eigenval.plot(all_uw.real,all_uw.imag,"o",color="c",alpha=0.1,zorder=1)
# #         fig_eigenval.plot(1,np.average(abs(all_uw)),"o",color="r",alpha=0.3,zorder=1)
# #         fig_eigenval.plot(int(model_path.split("epoch")[-1].split(".")[0]),np.average(abs(all_uw)),"o",color="b",alpha=0.3,zorder=1)

#         stat_list_B.append(np.average((abs(all_uw))))
#         print(model_path,np.average(abs(all_uw)))

        
#     if model_path in test_list_A_limited:
# #         ax3d.plot([np.count_nonzero(all_uw.real>0)], [np.count_nonzero(all_uw.imag==0)], [np.count_nonzero((all_uw.imag!=0)*(all_uw.real>0))],"o", color = "r", alpha=0.1)
# #         ax3d.plot([np.count_nonzero((all_uw.imag==0)*(all_uw.real>0))], [np.count_nonzero((all_uw.imag==0)*(all_uw.real<0))], [0],"o", color = "r", alpha=0.1)
# #         ax3d.plot([np.count_nonzero((all_uw.imag==0)*(all_uw.real>0))], [np.count_nonzero((all_uw.imag==0)*(all_uw.real<0))], [np.count_nonzero((all_uw.imag!=0)*(all_uw.real>0))],"o", color = "r", alpha=0.1)
# #         fig_eigenval.plot(all_uw.real,all_uw.imag,"o",color="r",alpha=0.3,zorder=2)
# #         fig_eigenval.plot(all_uw.real,all_uw.imag,"o",alpha=0.3,zorder=2)
# #         fig_eigenval.plot(np.zeros(20),all_uw.imag,"o",alpha=0.3,zorder=2)
#         fig_eigenval.plot(0,np.average(abs(all_uw)),"o",color="r",alpha=0.3,zorder=1)
# #         fig_eigenval.plot(all_uw.real[np.argmax(abs(all_uw))],all_uw.imag[np.argmax(abs(all_uw))],"o",alpha=0.3,zorder=2)
# #         fig_eigenval.plot(np.max(abs(all_uw)),"o",color="r",alpha=0.3,zorder=2)
# #         fig_eigenval.plot(np.average(abs(all_uw)),np.max(abs(all_uw)),"o",color="r",alpha=0.3,zorder=2)
# #         fig_eigenval.plot(np.max(np.abs(all_uw.real)),np.average(abs(all_uw)),"o",color="r",alpha=0.3,zorder=2)
# #         print(model_path,"bad")
# #         print([np.average(abs(all_uw)),np.max(all_uw.real),np.min(all_uw.real)])
# #         print([np.count_nonzero((all_uw.real>0)), np.count_nonzero((all_uw.imag==0)*(all_uw.real>0)), np.count_nonzero((all_uw.imag==0)*(all_uw.real<0)), np.count_nonzero((all_uw.imag!=0)*(all_uw.real>0)),np.count_nonzero((all_uw.imag!=0)*(all_uw.real<0))])
#         PCA_list_A = np.concatenate((PCA_list_A,[[np.count_nonzero((all_uw.real>0)), np.count_nonzero((all_uw.imag==0)*(all_uw.real>0)), np.count_nonzero((all_uw.imag==0)*(all_uw.real<0)), np.count_nonzero((all_uw.imag!=0)*(all_uw.real>0)),np.count_nonzero((all_uw.imag!=0)*(all_uw.real<0))]]),axis=0)
#         stat_list_A.append(np.sort((abs(all_uw))))
# #         stat_list_xy_A[0] = np.append(stat_list_xy_A[0],np.array(all_uw.real))
# #         stat_list_xy_A[0] = np.append(stat_list_xy_A[0],np.abs(all_uw.imag))
# #         stat_list_xy_A[1] = np.append(stat_list_xy_A[1],abs(all_uw))
# #         stat_list_xy_A[0] = np.append(stat_list_xy_A[0],all_uw.imag[np.where(all_uw.imag>0)])
# #         stat_list_xy_A[1] = np.append(stat_list_xy_A[1],abs(all_uw)[np.where(all_uw.imag>0)])
#         stat_list_xy_A[0] = np.append(stat_list_xy_A[0],np.average(abs(all_uw)))
#         stat_list_xy_A[1] = np.append(stat_list_xy_A[1],np.max(abs(all_uw)))
# #         stat_list_xy_A[1] = np.append(stat_list_xy_A[1],np.average(abs(all_uw))/np.max(abs(all_uw)))
# #         fig_eigenval.text(all_uw.real[np.argmax(abs(all_uw))],all_uw.imag[np.argmax(abs(all_uw))],str(model_path.split("s")[-1]))
# #         fig_eigenval.text(np.average(abs(all_uw)),np.max(abs(all_uw)),str(model_path.split("s")[-1]))
#         fig_eigenval.text(0,np.average(abs(all_uw)),str(model_path.split("s")[-1]))
#     if model_path in test_list_B_limited:
# #         ax3d.plot([np.count_nonzero(all_uw.real>0)], [np.count_nonzero(all_uw.imag==0)], [np.count_nonzero((all_uw.imag!=0)*(all_uw.real>0))],"o", color = "b", alpha=0.1)
# #         ax3d.plot([np.count_nonzero((all_uw.imag==0)*(all_uw.real>0))], [np.count_nonzero((all_uw.imag==0)*(all_uw.real<0))], [1],"o", color = "b", alpha=0.1)
# #         ax3d.plot([np.count_nonzero((all_uw.imag==0)*(all_uw.real>0))],[np.count_nonzero((all_uw.imag==0)*(all_uw.real>0))], [np.count_nonzero((all_uw.imag==0)*(all_uw.real<0))], [np.count_nonzero((all_uw.imag!=0)*(all_uw.real>0))],"o", color = "b", alpha=0.1)
# #         fig_eigenval.plot(all_uw.real,all_uw.imag,"o",color="b",alpha=0.3,zorder=1)
# #         fig_eigenval.plot(np.ones(20),all_uw.imag,"o",color="b",alpha=0.3,zorder=1)
#         fig_eigenval.plot(1,np.average(abs(all_uw)),"o",color="b",alpha=0.3,zorder=1)
# #         fig_eigenval.plot(all_uw.real[np.argmax(abs(all_uw))],all_uw.imag[np.argmax(abs(all_uw))],"o",color="b",alpha=0.3,zorder=2)
# #         fig_eigenval.plot(np.max(abs(all_uw)),"o",color="b",alpha=0.3,zorder=2)
# #         fig_eigenval.plot(np.average(abs(all_uw)),np.max(abs(all_uw)),"o",color="b",alpha=0.3,zorder=2)
# #         fig_eigenval.plot(np.max(np.abs(all_uw.real)),np.average(abs(all_uw)),"o",color="b",alpha=0.3,zorder=2)
# #         print(model_path,"good")
# #         print([np.average(abs(all_uw)),np.max(all_uw.real),np.min(all_uw.real)])
# #         print([np.count_nonzero((all_uw.real>0)), np.count_nonzero((all_uw.imag==0)*(all_uw.real>0)), np.count_nonzero((all_uw.imag==0)*(all_uw.real<0)), np.count_nonzero((all_uw.imag!=0)*(all_uw.real>0)),np.count_nonzero((all_uw.imag!=0)*(all_uw.real<0))])
#         PCA_list_B = np.concatenate((PCA_list_B,[[np.count_nonzero((all_uw.real>0)), np.count_nonzero((all_uw.imag==0)*(all_uw.real>0)), np.count_nonzero((all_uw.imag==0)*(all_uw.real<0)), np.count_nonzero((all_uw.imag!=0)*(all_uw.real>0)),np.count_nonzero((all_uw.imag!=0)*(all_uw.real<0))]]),axis=0)
#         stat_list_B.append(np.sort((abs(all_uw))))
# #         stat_list_xy_B[0] = np.append(stat_list_xy_B[0],all_uw.real)
# #         stat_list_xy_B[0] = np.append(stat_list_xy_B[0],np.abs(all_uw.imag))
# #         stat_list_xy_B[1] = np.append(stat_list_xy_B[1],abs(all_uw))
# #         stat_list_xy_B[0] = np.append(stat_list_xy_B[0],abs(all_uw.imag[np.where(all_uw.imag>0)]))
# #         stat_list_xy_B[1] = np.append(stat_list_xy_B[1],abs(all_uw)[np.where(all_uw.imag>0)])
#         stat_list_xy_B[0] = np.append(stat_list_xy_B[0],np.average(abs(all_uw)))
#         stat_list_xy_B[1] = np.append(stat_list_xy_B[1],np.max(abs(all_uw)))
# #         stat_list_xy_B[1] = np.append(stat_list_xy_B[1],np.average(abs(all_uw))/np.max(abs(all_uw)))
# #         fig_eigenval.text(np.average(abs(all_uw)),np.max(abs(all_uw)),str(model_path.split("s")[-1]))
#         fig_eigenval.text(1,np.average(abs(all_uw)),str(model_path.split("s")[-1]))
        
        
# #     print(LA.matrix_rank(HPC_inw[i:i+20]))
#     print(np.count_nonzero(all_uw.real>0))
#     print(np.count_nonzero(all_uw.imag==0))
#     print(np.count_nonzero((all_uw.imag==0)*(all_uw.real>0)))
    
    
#     PCA_list = np.concatenate((PCA_list,[[a1,a2,c2]]),axis=0)
#     print(a1-a2)
    
#     fig1.hist(PFC_w[PFC_w>0],bins=40,alpha=0.5,density=True)
#     fig1.hist(np.abs(PFC_w[PFC_w<0]),bins=40,alpha=0.5,density=True)
#     fig2.hist(HPC_w[HPC_w>0],bins=40,alpha=0.5,log=True,density=True)
#     fig2.hist(np.abs(HPC_w[HPC_w<0]),bins=40,alpha=0.5,log=True,density=True)
#     fig3.hist(Re_w[Re_w>0],bins=40,alpha=0.5,log=True,density=True)
#     fig3.hist(np.abs(Re_w[Re_w<0]),bins=40,alpha=0.5,log=True,density=True)

# # # print(PCA_list[1:])
# PCA_list = np.concatenate((PCA_list_A[1:],PCA_list_B[1:]),axis=0)
# length = PCA_list_A[1:].shape[0]
# pca = PCA()
# dfs = PCA_list
# pca.fit(dfs)
# feature = pca.transform(dfs)
# # print(length,feature)
# plt.figure(figsize=(6, 6))
# plt.scatter(feature[:length, 0], feature[:length, 1], color="r",alpha=0.2)
# plt.scatter(feature[length:, 0], feature[length:, 1], color="b",alpha=0.2)
# #plt.scatter(feature[100:200, 0], feature[100:200, 1], alpha=0.8)
# plt.grid()
# plt.xlabel("PC1")
# plt.ylabel("PC2")
# # for i,n in enumerate(target_list):
# #     plt.annotate(n[-12:],(feature[i, 0], feature[i, 1]))
# plt.show()

# print(pd.DataFrame(pca.explained_variance_ratio_, index=["PC{}".format(x + 1) for x in range(dfs.shape[1])]))
# print(pd.DataFrame(pca.components_, columns=["Hidden{}".format(x + 1) for x in range(dfs.shape[1])], index=["PC{}".format(x + 1) for x in range(dfs.shape[1])]))
    

# fig3d = plt.figure()
# ax3d = Axes3D(fig3d)
# ax3d.plot(feature[:length, 0], feature[:length, 1], feature[:length, 2], "o", alpha=0.1)
# ax3d.plot(feature[length:, 0], feature[length:, 1], feature[length:, 2], "o", alpha=0.1)
# plt.show()

# gradient = np.array(test_colors)
# gradient = np.vstack((gradient, gradient))

# fig, ax = plt.subplots()
# ax.imshow(gradient, cmap=plt.cm.jet)
# ax.set_axis_off()

fig_eigenval.tick_params(axis="both", labelsize=28)

# One-way ANOVA of mean |eigenvalue| between good (A) and failed (B) checkpoints.
print(st.f_oneway(np.array(stat_list_A).reshape(-1, 1), np.array(stat_list_B).reshape(-1, 1)))

plt.figure()

good_points = stat_list_A
bad_points = stat_list_B

medians_good = np.percentile(good_points, 50)
medians_bad = np.percentile(bad_points, 50)

palette = sns.color_palette("Pastel1")
my_pal = {0: palette[1], 1: palette[0]}
my_pal_strip = {0: "b", 1: "r"}
# Show at most 200 points per group, drawn WITHOUT replacement. The previous
# default (replace=True) always drew 200 points, duplicating models, and
# stripplot's jitter rendered the duplicates as distinct observations.
strip_size = 200
good_strip = np.random.choice(good_points, min(strip_size, len(good_points)), replace=False)
bad_strip = np.random.choice(bad_points, min(strip_size, len(bad_points)), replace=False)
sns.stripplot(data=[good_strip, bad_strip], palette=my_pal_strip, zorder=2)
sns.violinplot(data=[good_points, bad_points], palette=my_pal, linewidth=2)
plt.scatter([0, 1], [medians_good, medians_bad], marker='o', color='white', s=50, zorder=3)
plt.xticks([0, 1], labels=["Good", "Failed"])
plt.yticks(fontsize=28)
plt.xticks(fontsize=28)
# Switch the global seaborn palette to tab10 for subsequent plots.
sns.set_palette("tab10")


In [None]:

#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Created on Fri Feb 19 10:54:37 2021

@author: munenori
"""



import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import time
import math
import matplotlib.pyplot as plt
import matplotlib.animation as animation
import sklearn 
from sklearn.decomposition import PCA
from sklearn.cluster import KMeans
from mpl_toolkits.mplot3d import Axes3D
import pandas as pd

%matplotlib notebook

def Culc_gate(input, params, hiddens):
    """Recompute sigmoid-squashed LSTM gate values for the PFC and HPC cells.

    Only batch element 0 is used (see the ``[0]`` indexing below).

    Args:
        input: external input with batch on dim 0; concatenated with
            ``hiddens[2]`` along dim 1 to form the HPC input.
        params: 8 tensors, presumably
            [PFC.weight_ih, PFC.weight_hh, PFC.bias_ih, PFC.bias_hh,
             HPC.weight_ih, HPC.weight_hh, HPC.bias_ih, HPC.bias_hh]
            (the model's named_parameters() order) -- TODO confirm with caller.
        hiddens: assumed to be ``[(PFC h, PFC c), (HPC h, HPC c), Re h]`` as
            returned by the model's forward step.

    Returns:
        ``[PFC_gates, HPC_gates]`` as flat Python lists, one value per row of
        ``params[0]`` / ``params[4]``.

    Note:
        The sigmoid is applied to every row. In PyTorch's LSTMCell the rows are
        stacked as (input, forget, cell, output) gates and the cell block uses
        tanh, so the cell-gate entries returned here are not the model's values.
    """
    HPC_input = torch.cat([input, hiddens[2]], dim=1)
    PFC_input = torch.cat([hiddens[1][0][0], hiddens[2][0]])
    # The recurrent weights multiply the previous hidden state h (hiddens[k][0]).
    # They previously multiplied the cell state c (hiddens[k][1]), which is not
    # the LSTM gate equation.
    PFC_gates = params[0] @ PFC_input + params[1] @ hiddens[0][0][0] + params[2] + params[3]
    HPC_gates = params[4] @ HPC_input[0] + params[5] @ hiddens[1][0][0] + params[6] + params[7]
    PFC_gates = torch.sigmoid(PFC_gates)
    HPC_gates = torch.sigmoid(HPC_gates)

    return [PFC_gates.tolist(), HPC_gates.tolist()]

def correct_traj(traj):
    """Count left/right alternations in the first column of *traj*.

    An end is "entered" when x < 0.3 (left) or x > 0.7 (right) while the
    latch is open. Entering closes the latch, which re-opens only once x is
    seen strictly inside (0.3, 0.5) or (0.5, 0.7). Each entry into one end
    after the opposite end was last entered adds one to the count.

    Args:
        traj: 2-D array-like indexable as ``traj[:, 0]``.

    Returns:
        Number of alternations as an int.
    """
    seen_left = False
    seen_right = False
    latched = False
    alternations = 0

    for x in traj[:, 0]:
        # Entering the left end.
        if x < 0.3 and not latched:
            seen_left, latched = True, True
            if seen_right:
                alternations += 1
                seen_right = False
        # Back in the left half of the middle band: re-arm.
        if latched and 0.3 < x < 0.5:
            latched = False
        # Entering the right end.
        if x > 0.7 and not latched:
            seen_right, latched = True, True
            if seen_left:
                alternations += 1
                seen_left = False
        # Back in the right half of the middle band: re-arm.
        if latched and 0.5 < x < 0.7:
            latched = False

    return alternations
    

def MakeAnimation(data_x, data_y, traj_x, traj_y, data_limit,
                  filename='anim_Re_0.gif', interval=100):
    """Animate a generated trajectory against target data and save it as a GIF.

    Frame ``t`` (1-based) draws the first ``t`` generated points in red with
    a red dot on the newest one. The target is drawn in solid blue for up to
    ``data_limit`` points; after that it continues as a dashed blue line
    starting at index ``data_limit - 1``.

    Parameters
    ----------
    data_x, data_y : array-like
        Target coordinates.
    traj_x, traj_y : numpy.ndarray
        Generated coordinates; ``traj_x.shape[0]`` is the number of frames.
    data_limit : int
        Number of target points drawn solid (presumably the data-driven part
        of the rollout — confirm with caller).
    filename : str, optional
        Output GIF path, written with the Pillow writer.
    interval : int, optional
        Delay between frames in milliseconds.

    Returns
    -------
    int
        Always 0.
    """
    fig, ax = plt.subplots()
    frames = []
    # t runs 1..N so the marker index t-1 never wraps to -1 (the last point)
    # and the final frame contains every point.
    for t in range(1, traj_x.shape[0] + 1):
        cur = t - 1
        if t < data_limit:
            artists = ax.plot(data_x[:t], data_y[:t], "b",
                              traj_x[:t], traj_y[:t], "r",
                              traj_x[cur], traj_y[cur], "or")
        else:
            artists = ax.plot(data_x[:data_limit], data_y[:data_limit], "b",
                              traj_x[:t], traj_y[:t], "r",
                              traj_x[cur], traj_y[cur], "or",
                              data_x[data_limit - 1:t], data_y[data_limit - 1:t], "--b")
        title = ax.text(0.5, 1.01, 'time={}'.format(t),
                        ha='center', va='bottom',
                        transform=ax.transAxes, fontsize='large')
        frames.append(artists + [title])
    ani = animation.ArtistAnimation(fig, frames, interval=interval)
    ani.save(filename, writer='pillow')
    plt.show()

    return 0

def MakeAnimation2(data_x, data_y, traj_x, traj_y, data_limit,
                   filename='anim_Re_0.gif', interval=30):
    """Animate only the generated trajectory and save it as a GIF.

    Takes the same arguments as ``MakeAnimation``, but ``data_x``, ``data_y``
    and ``data_limit`` are not drawn (kept for signature compatibility).
    Frame ``t`` (1-based) draws the first ``t`` points in red with a red dot
    on the newest one.

    Parameters
    ----------
    traj_x, traj_y : numpy.ndarray
        Generated coordinates; ``traj_x.shape[0]`` is the number of frames.
    filename : str, optional
        Output GIF path. The default is the same as ``MakeAnimation``'s, so
        the two overwrite each other unless a name is given.
    interval : int, optional
        Delay between frames in milliseconds.

    Returns
    -------
    int
        Always 0.
    """
    fig, ax = plt.subplots()
    frames = []
    # t runs 1..N so the marker never wraps to the last point at t=0.
    for t in range(1, traj_x.shape[0] + 1):
        artists = ax.plot(traj_x[:t], traj_y[:t], "r", traj_x[t - 1], traj_y[t - 1], "or")
        title = ax.text(0.5, 1.01, 'time={}'.format(t),
                        ha='center', va='bottom',
                        transform=ax.transAxes, fontsize='large')
        frames.append(artists + [title])
    ani = animation.ArtistAnimation(fig, frames, interval=interval)
    ani.save(filename, writer='pillow')
    plt.show()

    return 0

def MakeAnimation_testdata(data_x, data_y, filename='anim_test_1.gif', interval=100):
    """Animate a single (x, y) sequence as a growing blue line and save it as a GIF.

    Frame ``t`` (1-based) draws the first ``t`` points, so the final frame
    shows the whole sequence.

    Parameters
    ----------
    data_x, data_y : numpy.ndarray
        Coordinates; ``data_x.shape[0]`` is the number of frames.
    filename : str, optional
        Output GIF path, written with the Pillow writer.
    interval : int, optional
        Delay between frames in milliseconds.

    Returns
    -------
    int
        Always 0.
    """
    fig, ax = plt.subplots()
    frames = []
    for t in range(1, data_x.shape[0] + 1):
        artists = ax.plot(data_x[:t], data_y[:t], "b")
        title = ax.text(0.5, 1.01, 'time={}'.format(t),
                        ha='center', va='bottom',
                        transform=ax.transAxes, fontsize='large')
        frames.append(artists + [title])
    ani = animation.ArtistAnimation(fig, frames, interval=interval)
    ani.save(filename, writer='pillow')
    plt.show()

    return 0

def MakeAnimation_attracter(data_x, data_y, filename='anim_attracter_gate_PFC.gif', interval=30):
    """Animate a 2-D state trajectory (e.g. two PCA components) and save it as a GIF.

    Frame ``t`` (1-based) draws the first ``t`` points as a blue line with a
    blue dot on the newest one.

    Parameters
    ----------
    data_x, data_y : numpy.ndarray
        Coordinates; ``data_x.shape[0]`` is the number of frames.
    filename : str, optional
        Output GIF path, written with the Pillow writer.
    interval : int, optional
        Delay between frames in milliseconds.

    Returns
    -------
    int
        Always 0.
    """
    fig, ax = plt.subplots()
    frames = []
    # t runs 1..N so the marker index t-1 never wraps to -1 at the first frame.
    for t in range(1, data_x.shape[0] + 1):
        artists = ax.plot(data_x[:t], data_y[:t], "b", data_x[t - 1], data_y[t - 1], "ob")
        title = ax.text(0.5, 1.01, 'time={}'.format(t),
                        ha='center', va='bottom',
                        transform=ax.transAxes, fontsize='large')
        frames.append(artists + [title])
    ani = animation.ArtistAnimation(fig, frames, interval=interval)
    ani.save(filename, writer='pillow')
    plt.show()

    return 0

def MakeAnimation_3D(data_x, data_y, filename='anim_attracter2.gif', interval=100):
    """Animate a trajectory and save it as a GIF.

    Despite the name, this draws on ordinary 2-D axes: only ``data_x`` and
    ``data_y`` are plotted. Frame ``t`` (1-based) draws the first ``t`` points
    as a blue line with a blue dot on the newest one.

    Parameters
    ----------
    data_x, data_y : numpy.ndarray
        Coordinates; ``data_x.shape[0]`` is the number of frames.
    filename : str, optional
        Output GIF path, written with the Pillow writer.
    interval : int, optional
        Delay between frames in milliseconds.

    Returns
    -------
    int
        Always 0.
    """
    fig, ax = plt.subplots()
    frames = []
    # t runs 1..N so the marker index t-1 never wraps to -1 at the first frame.
    for t in range(1, data_x.shape[0] + 1):
        artists = ax.plot(data_x[:t], data_y[:t], "b", data_x[t - 1], data_y[t - 1], "ob")
        title = ax.text(0.5, 1.01, 'time={}'.format(t),
                        ha='center', va='bottom',
                        transform=ax.transAxes, fontsize='large')
        frames.append(artists + [title])
    ani = animation.ArtistAnimation(fig, frames, interval=interval)
    ani.save(filename, writer='pillow')
    plt.show()

    return 0

def MakeAnimation_img(data, module, filename=None, interval=100):
    """Animate each row of ``data`` as a one-row heat-map and save it as a GIF.

    ``data`` is indexed as (time, units): frame ``t`` shows ``data[t]``
    reshaped to a ``1 x data.shape[1]`` image, titled ``"<module> time=<t>"``.
    The colour scale is pinned to the global min/max of ``data``; without
    that, ``imshow`` rescales every frame independently and colours are not
    comparable across time.

    Parameters
    ----------
    data : numpy.ndarray
        2-D array (time, units).
    module : str
        Label used in the frame titles and the default filename.
    filename : str, optional
        Output GIF path; defaults to ``'anim_fire_' + module + '_.gif'``.
    interval : int, optional
        Delay between frames in milliseconds.

    Returns
    -------
    int
        Always 0.
    """
    if filename is None:
        filename = 'anim_fire_' + module + '_.gif'
    # Empty input: leave the colour limits to imshow (there are no frames anyway).
    vmin = np.min(data) if data.size else None
    vmax = np.max(data) if data.size else None

    fig, ax = plt.subplots()
    frames = []
    for t in range(data.shape[0]):
        img = ax.imshow(data[t].reshape(1, data.shape[1]), vmin=vmin, vmax=vmax)
        title = ax.text(0.5, 1.01, module + ' time={}'.format(t),
                        ha='center', va='bottom',
                        transform=ax.transAxes, fontsize='large')
        frames.append([img, title])
    ani = animation.ArtistAnimation(fig, frames, interval=interval)
    ani.save(filename, writer='pillow')
    plt.show()

    return 0

def mkOwnDataSet(data_size, filename, data_length=100, freq=60., noise=0.01):
    """Load two trajectories from CSV and build noisy, down-sampled copies.

    Reads ``<filename>_r.csv`` (first output) and ``<filename>_l.csv``
    (second output). Each file is down-sampled by keeping every ``step``-th
    row, where ``step = max(1, round(n_rows / data_length))``, and only the
    first two columns are kept. Then ``data_size`` copies are made, each with
    independent Gaussian noise (std ``noise``) added to every coordinate.

    Parameters
    ----------
    data_size : int
        Number of noisy copies per file.
    filename : str
        Path prefix of the two CSV files.
    data_length : int, optional
        Approximate number of points to keep per trajectory.
    freq : float, optional
        Unused; kept for backward compatibility.
    noise : float, optional
        Standard deviation of the additive Gaussian noise.

    Returns
    -------
    tuple of list
        ``(train_x, train_y)``; each is a list of ``data_size`` samples, and
        each sample is a list of ``[x, y]`` pairs.
    """
    x = np.loadtxt(filename + "_r.csv", delimiter=',')
    y = np.loadtxt(filename + "_l.csv", delimiter=',')

    def _subsample(arr):
        # Clamp the stride to >= 1: for files shorter than data_length / 2 the
        # rounded stride was 0 and np.arange(0, n, 0) raised.
        step = max(1, round(arr.shape[0] / data_length))
        return arr[::step, :2]

    x_sub = _subsample(x)
    y_sub = _subsample(y)

    train_x = []
    train_y = []
    for _ in range(data_size):
        train_x.append((x_sub + np.random.normal(loc=0.0, scale=noise, size=x_sub.shape)).tolist())
        train_y.append((y_sub + np.random.normal(loc=0.0, scale=noise, size=y_sub.shape)).tolist())
    return train_x, train_y

def make_Ttraj(direction):
    """Build a synthetic T-shaped trajectory as a (20, 2) array of (x, y) points.

    The path has four 5-point segments:
      * stem up: x = 0.5, y 0 -> 0.5
      * cross-bar out: x 0.5 -> arm_x, y = 0.5
      * cross-bar back: x arm_x -> 0.5, y = 0.5
      * stem down: x = 0.5, y 0.5 -> 0

    Parameters
    ----------
    direction : {1, 2}
        1 -> arm_x = 0.75, 2 -> arm_x = 0.25.

    Returns
    -------
    numpy.ndarray
        Array of shape (20, 2); column 0 is x, column 1 is y.

    Raises
    ------
    ValueError
        If ``direction`` is not 1 or 2.
    """
    arm_xs = {1: 0.75, 2: 0.25}
    if direction not in arm_xs:
        raise ValueError("direction must be 1 or 2, got {!r}".format(direction))
    arm_x = arm_xs[direction]

    x = np.concatenate([
        np.full(5, 0.5),
        np.linspace(0.5, arm_x, 5),
        np.linspace(arm_x, 0.5, 5),
        np.full(5, 0.5),
    ])
    y = np.concatenate([
        np.linspace(0, 0.5, 5),
        np.full(10, 0.5),
        np.linspace(0.5, 0, 5),
    ])
    return np.column_stack([x, y])

def make_Htraj(direction1, direction2):
    """Build a synthetic H-shaped trajectory as a (40, 2) array of (x, y) points.

    Segments (point counts in brackets):
      * stem up: x = 0.5, y 0 -> 0.5 [10]
      * cross-bar out: x 0.5 -> arm_x, y = 0.5 [5]
      * arm excursion: x = arm_x, y 0.5 -> arm_y [5], then arm_y -> 0.5 [5]
      * cross-bar back: x arm_x -> 0.5, y = 0.5 [5]
      * stem down: x = 0.5, y 0.5 -> 0 [10]

    Parameters
    ----------
    direction1 : {1, 2}
        Cross-bar side: 1 -> arm_x = 0.75, 2 -> arm_x = 0.25.
    direction2 : {1, 2}
        Arm excursion: 1 -> arm_y = 0.25, 2 -> arm_y = 0.75.

    Returns
    -------
    numpy.ndarray
        Array of shape (40, 2); column 0 is x, column 1 is y.

    Raises
    ------
    ValueError
        If either direction is not 1 or 2.
    """
    arm_xs = {1: 0.75, 2: 0.25}
    arm_ys = {1: 0.25, 2: 0.75}
    if direction1 not in arm_xs:
        raise ValueError("direction1 must be 1 or 2, got {!r}".format(direction1))
    if direction2 not in arm_ys:
        raise ValueError("direction2 must be 1 or 2, got {!r}".format(direction2))
    arm_x = arm_xs[direction1]
    arm_y = arm_ys[direction2]

    x = np.concatenate([
        np.full(10, 0.5),
        np.linspace(0.5, arm_x, 5),
        np.full(10, arm_x),          # x is constant during the up/down arm excursion
        np.linspace(arm_x, 0.5, 5),
        np.full(10, 0.5),
    ])
    y = np.concatenate([
        np.linspace(0, 0.5, 10),
        np.full(5, 0.5),
        np.linspace(0.5, arm_y, 5),
        np.linspace(arm_y, 0.5, 5),
        np.full(5, 0.5),
        np.linspace(0.5, 0, 10),
    ])
    return np.column_stack([x, y])


def _make_order(sequence, delay_length):
    """Interleave trajectory ids with ``delay_length`` stay segments (id 0)."""
    order = [sequence[0]]
    for idx in sequence[1:]:
        order.extend([0] * delay_length)
        order.append(idx)
    return order


def make_traj(delay_length):
    """Build four long test sequences from H-shaped trajectories and stay segments.

    Trajectory ids map to ``make_Htraj`` templates: 1 -> (1, 1), 2 -> (1, 2),
    3 -> (2, 1), 4 -> (2, 2). Id 0 is a "stay" segment of ``data_length``
    rows sampled with replacement from ``stay1.csv`` (first two columns).
    Each sequence visits all four trajectories, separated by
    ``delay_length`` stay segments. Every trajectory segment gets fresh
    Gaussian noise (std 0.005) on a copy of its template.

    Parameters
    ----------
    delay_length : int
        Number of stay segments between consecutive trajectories.

    Returns
    -------
    tuple of numpy.ndarray
        Four arrays of shape (T, 2), one per sequence.
    """
    noise = 0.005
    data_length = 40  # rows sampled for each stay (id 0) segment

    templates = {
        1: make_Htraj(1, 1),
        2: make_Htraj(1, 2),
        3: make_Htraj(2, 1),
        4: make_Htraj(2, 2),
    }
    z = np.loadtxt("stay1.csv", delimiter=',')

    # "ver1's test" schedule. Earlier experiments used single pairs instead:
    #   ver1: (1, 4), (2, 3), (3, 1), (4, 2)
    #   ver2: (1, 2), (2, 3), (3, 4), (4, 1)
    sequences = [(1, 4, 2, 3), (2, 3, 1, 4), (3, 1, 4, 2), (4, 2, 3, 1)]

    data_list = []
    for seq in sequences:
        segments = []
        for idx in _make_order(seq, delay_length):
            if idx == 0:
                rows = np.random.choice(z.shape[0], data_length)
                segment = z[rows, :2]
            else:
                template = templates[idx]
                # Add noise to a copy: the former in-place `target += noise`
                # mutated the shared template, so noise accumulated every
                # time the same trajectory id was reused.
                segment = template + np.random.normal(loc=0.0, scale=noise, size=template.shape)
            segments.append(segment)
        data_list.append(np.concatenate(segments, axis=0))
    return data_list[0], data_list[1], data_list[2], data_list[3]

def mkOwnDataSet_auto(data_size, delay_length, freq=60., noise=0.01):
    """Create ``data_size`` noisy copies of each of the four ``make_traj`` sequences.

    Every coordinate of every copy gets its own Gaussian draw with standard
    deviation ``noise``. ``freq`` is accepted but unused.

    Returns
    -------
    tuple of list
        ``(train_x1, train_x2, train_y1, train_y2)``; each is a list of
        ``data_size`` samples, and each sample is a list of ``[x, y]`` pairs.
    """
    def _jittered(sequence):
        # One noisy copy; draws happen point by point, x before y.
        return [[row[0] + np.random.normal(loc=0.0, scale=noise),
                 row[1] + np.random.normal(loc=0.0, scale=noise)]
                for row in sequence]

    sources = make_traj(delay_length)
    outputs = ([], [], [], [])
    for _ in range(data_size):
        for sequence, bucket in zip(sources, outputs):
            bucket.append(_jittered(sequence))
    return outputs

def mkOwnRandomBatch(train_x, train_t, batch_size=10):
    """Sample a random mini-batch (with replacement) from ``train_x``.

    ``train_t`` is accepted for interface compatibility but is not used; only
    the input batch is returned.

    Parameters
    ----------
    train_x : sequence
        Samples, each convertible by ``torch.tensor`` (e.g. a list of
        ``[x, y]`` pairs of equal length).
    train_t : sequence
        Unused.
    batch_size : int, optional
        Number of samples to draw.

    Returns
    -------
    torch.Tensor
        The stacked samples with dims 0 and 1 swapped, i.e. the batch
        dimension moved to axis 1.

    Raises
    ------
    ValueError
        If ``train_x`` is empty.
    """
    if len(train_x) == 0:
        raise ValueError("train_x must contain at least one sample")
    # Draw from the whole set; the former upper bound len(train_x) - batch_size
    # skipped the last samples and raised whenever len(train_x) <= batch_size.
    indices = np.random.randint(0, len(train_x), size=batch_size)
    batch_x = [train_x[i] for i in indices]
    return torch.tensor(batch_x).transpose(0, 1)


class MyLSTM_comp(nn.Module):
    """Two stacked ``nn.LSTMCell`` layers followed by a linear read-out.

    ``forward`` advances both cells by one step and maps the second cell's
    hidden output to ``output_size`` features. Submodule names (``LSTM1``,
    ``LSTM2``, ``linear``) are part of the ``state_dict`` keys.
    """

    def __init__(self, input_size, hidden_size, output_size, batch_size):
        super().__init__()
        self.hidden_size = hidden_size
        # Creation order fixes parameter-initialisation RNG use and state_dict order.
        self.LSTM1 = nn.LSTMCell(input_size, hidden_size)
        self.LSTM2 = nn.LSTMCell(hidden_size, hidden_size)
        self.linear = nn.Linear(hidden_size, output_size)
        self.batch_size = batch_size

    def forward(self, input, hiddens):
        """One time step: returns ``(output, [state1, state2])``."""
        lower_state = self.LSTM1(input.float(), hiddens[0])
        upper_state = self.LSTM2(lower_state[0], hiddens[1])
        prediction = self.linear(upper_state[0])
        return prediction, [lower_state, upper_state]

    def initHidden(self):
        """Zero ``[h, c]`` state; both layers receive the same list object."""
        shape = (self.batch_size, self.hidden_size)
        shared = [torch.zeros(shape), torch.zeros(shape)]
        return [shared, shared]

    def initHidden_rand(self):
        """Uniform [0, 0.2) ``[h, c]`` state; both layers get the same draw and list."""
        shape = (self.batch_size, self.hidden_size)
        shared = [torch.rand(shape) * 0.2, torch.rand(shape) * 0.2]
        return [shared, shared]
    
class MyLSTM_comp_cue2(nn.Module):
    """Two stacked ``nn.LSTMCell`` layers where one input feature bypasses layer 1.

    The first ``input_size - 1`` input features go to ``LSTM1``. The next
    feature (a cue) is appended to ``LSTM1``'s hidden output and fed to
    ``LSTM2``. A linear layer maps ``LSTM2``'s hidden output to
    ``output_size`` features.
    """

    def __init__(self, input_size, hidden_size, output_size, batch_size):
        super(MyLSTM_comp_cue2, self).__init__()

        self.hidden_size = hidden_size
        self.LSTM1 = nn.LSTMCell(input_size - 1, hidden_size)
        self.LSTM2 = nn.LSTMCell(hidden_size + 1, hidden_size)
        self.linear = nn.Linear(hidden_size, output_size)
        self.batch_size = batch_size

    def forward(self, input, hiddens):
        """One time step: returns ``(output, [state1, state2])``."""
        input = input.float()
        n_main = self.LSTM1.input_size  # features consumed by LSTM1 (input_size - 1)
        hidden1 = self.LSTM1(input[:, :n_main], hiddens[0])
        # Slice keeps a (batch, 1) column for any batch size; the former
        # reshape to (10, 1) only worked when the batch size was 10.
        cue = input[:, n_main:n_main + 1]
        hidden2 = self.LSTM2(torch.cat((hidden1[0], cue), dim=1), hiddens[1])
        output = self.linear(hidden2[0])
        return output, [hidden1, hidden2]

    def initHidden(self):
        """Zero ``[h, c]`` state; both layers receive the same list object."""
        hidden = [torch.zeros(self.batch_size, self.hidden_size), torch.zeros(self.batch_size, self.hidden_size)]
        return [hidden, hidden]

    def initHidden_rand(self):
        """Uniform [0, 0.2) ``[h, c]`` state; both layers get the same draw and list."""
        hidden = [torch.rand(self.batch_size, self.hidden_size) * 0.2, torch.rand(self.batch_size, self.hidden_size) * 0.2]
        return [hidden, hidden]


def pick_traj(traj, states, threshold=0.2, warmup=5):
    """Split ``states`` into segments at the steps where ``traj[:, 1]`` rises past ``threshold``.

    Starting at index ``warmup``, the function tracks when column 1 of
    ``traj`` drops below ``threshold`` ("enter") and later rises above it
    ("exit"). ``states`` is cut at every exit index, so each segment ends
    just before an exit. The final segment runs from the last exit to the
    end. Values exactly equal to ``threshold`` trigger neither event.

    Parameters
    ----------
    traj : numpy.ndarray
        2-D array indexed as ``traj[i, 1]``.
    states : sliceable
        Sequence aligned with ``traj`` along axis 0.
    threshold : float, optional
        Level on column 1 (default 0.2, the original hard-coded value).
    warmup : int, optional
        Number of leading steps ignored (default 5).

    Returns
    -------
    list
        Consecutive slices of ``states``.
    """
    boundaries = []
    below = False
    for i in range(warmup, traj.shape[0]):
        value = traj[i, 1]
        if not below and value < threshold:
            boundaries.append(i)
            below = True
        elif below and value > threshold:
            boundaries.append(i)
            below = False

    segments = []
    start = 0
    # Odd positions in `boundaries` are the exit indices.
    for cut in boundaries[1::2]:
        segments.append(states[start:cut])
        start = cut
    segments.append(states[start:])
    return segments



def _pca_features(states):
    """Fit a fresh PCA on ``states`` (samples x features); return ``(pca, transformed)``."""
    pca = PCA()
    pca.fit(states)
    return pca, pca.transform(states)


def _window_diff_corr(traj, feature, window, component):
    """Correlation between |dx| and |dPC| across two consecutive windows.

    Compares steps ``[0, window)`` with ``[window, 2 * window)``: the absolute
    difference of ``traj[:, 0, 0]`` against the absolute difference of
    principal component ``component`` of ``feature``.
    """
    traj_diff = np.abs(traj[:window, 0, 0] - traj[window:2 * window, 0, 0])
    pc_diff = np.abs(feature[:window, component] - feature[window:2 * window, component])
    return np.corrcoef(traj_diff, pc_diff)[0, 1]


def main():
    """Load a trained PFC-HPC model, roll it out on a test batch and plot analyses.

    1. Builds a random batch from the first ``mkOwnDataSet_auto`` sequence.
    2. Runs the model for 400 steps fed by the data, then 10 steps fed with
       its own output. It records the output, ``hidden[0][0]`` (PFC) and
       ``hidden[1][0]`` (HPC) at each step.
    3. Plots data vs. generated trajectory and per-neuron HPC traces with
       PFC PC1. It plots PFC PC1 segments split by ``pick_traj``, and HPC
       PCA in 3-D and 2-D. It prints correlations between trajectory
       differences and PC differences over two consecutive 120-step windows.

    NOTE(review): ``MyLSTM_PFCHPC`` is not defined in this cell; it must
    already exist in the notebook namespace.
    """
    training_size = 100
    hidden_size = 30
    batch_size = 10
    inputsize = 2
    outputsize = 2
    # model_path = 'model/compare30_H_bigbatch/Compare30_121H_s8_100_1_epoch120.pth'
    model_path = 'model/PFCHPC30_H_bigbatch/v4_3_121_s2_100_2_epoch185.pth'

    delay_length = 2
    train_x1, train_x2, train_y1, train_y2 = mkOwnDataSet_auto(training_size, delay_length)

    # Test sequence selector (1..4); any other value falls back to train_y2.
    pattern = 1
    source = {1: train_x1, 2: train_x2, 3: train_y1}.get(pattern, train_y2)
    data = mkOwnRandomBatch(source, source, batch_size).float()

    # rnn = MyLSTM_comp(inputsize, hidden_size, outputsize, batch_size)
    rnn = MyLSTM_PFCHPC(inputsize, hidden_size, outputsize, batch_size)
    rnn.load_state_dict(torch.load(model_path))

    teacher_steps = 400  # steps driven by the recorded data
    free_steps = 10      # steps driven by the model's own previous output
    traj = []
    PFCstate = []
    HPCstate = []

    def record(output, hidden):
        traj.append(output.tolist())
        PFCstate.append(hidden[0][0].tolist())
        HPCstate.append(hidden[1][0].tolist())

    hidden = rnn.initHidden_rand()
    # Inference only: no autograd graph is needed across the rollout.
    with torch.no_grad():
        for k in range(teacher_steps):
            output, hidden = rnn(data[k], hidden)
            record(output, hidden)
        for _ in range(free_steps):
            output, hidden = rnn(output, hidden)
            record(output, hidden)

    traj = torch.squeeze(torch.tensor(traj)).numpy()
    pltdata = torch.squeeze(data).numpy()

    plt.figure()
    print(pltdata.shape)
    plt.plot(pltdata[:, 0, 0], pltdata[:, 0, 1], "--")
    plt.plot(traj[:, 0, 0], traj[:, 0, 1])
    plt.show()

    plt.figure()
    plt.plot(traj[:, 0, 0], traj[:, 0, 1])
    plt.show()

    # Index 0 on axis 1 selects the first batch element of each recorded state.
    PFC_arr = np.array(PFCstate)[:, 0]
    HPC_all = np.array(HPCstate)
    HPC_arr = HPC_all[:, 0]

    _, pfc_feature = _pca_features(PFC_arr)

    for i in range(HPC_all.shape[2]):
        plt.figure()
        plt.plot(traj[:, 0, 0], alpha=0.5)
        plt.plot(traj[:, 0, 1], alpha=0.5)
        plt.plot(HPC_all[:, 0, i], alpha=0.5)
        plt.plot(pfc_feature[:, 0], alpha=0.5)
        plt.title("neuron#" + str(i + 1))

    # PFC PC1 split into segments at the steps where the output's y rises
    # back above pick_traj's threshold.
    delays = pick_traj(traj[:, 0], pfc_feature)
    plt.figure()
    plt.plot(traj[:120, 0, 0], alpha=0.5)
    for segment in delays:
        plt.plot(np.array(segment)[:, 0], alpha=0.5)
    plt.title("neuron#1")

    hpc_pca, hpc_feature = _pca_features(HPC_arr)

    fig3d = plt.figure()
    # Newer Matplotlib versions no longer attach Axes3D(fig) to the figure
    # automatically; add_subplot(projection='3d') always does.
    ax3d = fig3d.add_subplot(projection='3d')
    ax3d.plot(hpc_feature[:, 0], hpc_feature[:, 1], hpc_feature[:, 2], alpha=0.8)
    ax3d.plot(hpc_feature[0:1, 0], hpc_feature[0:1, 1], hpc_feature[0:1, 2], "o", alpha=1)
    plt.show()

    plt.figure()
    plt.plot(hpc_feature[:240, 0], hpc_feature[:240, 1], alpha=0.8)
    plt.show()

    ratios = hpc_pca.explained_variance_ratio_
    print(pd.DataFrame(ratios, index=["PC{}".format(x + 1) for x in range(len(ratios))]))

    window = 120
    print(PFC_arr.shape)
    print("PFC correlation")
    PFC_corrs = [_window_diff_corr(traj, pfc_feature, window, pc) for pc in range(3)]

    print(HPC_arr.shape)
    print("HPC correlation")
    HPC_corrs = [_window_diff_corr(traj, hpc_feature, window, pc) for pc in range(3)]
    # Euclidean distance between the two windows in the first three HPC PCs.
    HPC_diff = np.linalg.norm(hpc_feature[:window, :3] - hpc_feature[window:2 * window, :3], axis=1)
    HPC_d = np.corrcoef(np.abs(traj[:window, 0, 0] - traj[window:2 * window, 0, 0]), HPC_diff)[0, 1]

    print(np.mean(PFC_corrs), np.max(HPC_corrs), HPC_d)
    
# Entry point. Inside a Jupyter cell __name__ is '__main__', so executing the
# cell runs main().
if __name__ == '__main__':
    main()

In [None]:
# Select the interactive notebook backend; run_line_magic replaces the
# deprecated InteractiveShell.magic.
get_ipython().run_line_magic('matplotlib', 'notebook')

In [None]:
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Created on Fri Feb 19 10:54:37 2021

@author: munenori
"""

# This statement runs before the cell's own import block below, so bind plt
# here; otherwise the cell only works if an earlier cell already imported it.
import matplotlib.pyplot as plt

plt.rcParams["figure.subplot.left"] = 0.15

#######   for phase?? like?? check   ##########

import numpy as np
import os
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import time
import math
import matplotlib.pyplot as plt
import matplotlib.animation as animation
import sklearn 
from sklearn.decomposition import PCA
from sklearn.cluster import KMeans
from mpl_toolkits.mplot3d import Axes3D
import pandas as pd

%matplotlib notebook


def Culc_gate(input, params, hiddens):
    """Recompute the PFC and HPC LSTM gate pre-activations for batch item 0 and squash them.

    The wiring mirrors ``MyLSTM_RNN.forward``: HPC receives ``[input, Re]`` and
    PFC receives ``[HPC_h, Re]``.

    Args:
        input: external input; concatenated with ``hiddens[2]`` along dim 1, so
            presumably shaped (batch, input_size) -- TODO confirm against caller.
        params: parameter list whose first eight entries are PFC ``weight_ih,
            weight_hh, bias_ih, bias_hh`` followed by the same four HPC tensors
            (the registration order of ``nn.LSTMCell`` parameters when PFC is
            registered before HPC; presumably ``list(model.parameters())``).
        hiddens: ``[PFC_(h, c), HPC_(h, c), Re_h]`` from the previous step.

    Returns:
        ``[PFC_gates, HPC_gates]`` as Python lists, each the sigmoid of all
        ``4 * hidden`` pre-activations. NOTE: ``nn.LSTMCell`` applies ``tanh``
        rather than sigmoid to the third chunk (cell candidate), so that slice
        differs from what the cell itself uses.
    """
    PFC_h = hiddens[0][0]
    HPC_h = hiddens[1][0]
    Re_h = hiddens[2]
    HPC_input = torch.cat([input, Re_h], dim=1)[0]
    PFC_input = torch.cat([HPC_h[0], Re_h[0]])
    # The recurrent term of an LSTM gate uses the previous hidden state h
    # (index 0 of the (h, c) pair), not the cell state c.
    PFC_gates = params[0] @ PFC_input + params[1] @ PFC_h[0] + params[2] + params[3]
    HPC_gates = params[4] @ HPC_input + params[5] @ HPC_h[0] + params[6] + params[7]
    PFC_gates = torch.sigmoid(PFC_gates)
    HPC_gates = torch.sigmoid(HPC_gates)

    return [PFC_gates.tolist(), HPC_gates.tolist()]

def Culc_gate_uniPFC(input, params, hiddens):
    """Recompute PFC/HPC LSTM gate activations for batch item 0 in the uniPFC wiring.

    Same as ``Culc_gate`` except that PFC receives only the HPC hidden state
    (as in ``MyLSTM_RNN_uniPFC.forward``), while HPC receives ``[input, Re]``.

    Args:
        input: external input; concatenated with ``hiddens[2]`` along dim 1, so
            presumably shaped (batch, input_size) -- TODO confirm against caller.
        params: PFC ``weight_ih, weight_hh, bias_ih, bias_hh`` followed by the
            same four HPC tensors (``nn.LSTMCell`` registration order;
            presumably ``list(model.parameters())``).
        hiddens: ``[PFC_(h, c), HPC_(h, c), Re_h]`` from the previous step.

    Returns:
        ``[PFC_gates, HPC_gates]`` as Python lists, each the sigmoid of all
        ``4 * hidden`` pre-activations. NOTE: ``nn.LSTMCell`` uses ``tanh`` for
        the third chunk (cell candidate), so that slice differs from the cell.
    """
    PFC_h = hiddens[0][0]
    HPC_h = hiddens[1][0]
    Re_h = hiddens[2]
    HPC_input = torch.cat([input, Re_h], dim=1)[0]
    PFC_input = HPC_h[0]
    # Recurrent gate terms use the previous hidden state h, not the cell state c.
    PFC_gates = params[0] @ PFC_input + params[1] @ PFC_h[0] + params[2] + params[3]
    HPC_gates = params[4] @ HPC_input + params[5] @ HPC_h[0] + params[6] + params[7]
    PFC_gates = torch.sigmoid(PFC_gates)
    HPC_gates = torch.sigmoid(HPC_gates)

    return [PFC_gates.tolist(), HPC_gates.tolist()]
    

def MakeAnimation(data_x, data_y, traj_x, traj_y, data_limit,
                  filename='anim_Re_0.gif', interval=100):
    """Animate a generated trajectory (red) over reference data (blue) and save it as a GIF.

    Frame ``t`` draws the first ``t`` samples of the trajectory with a red
    marker on its newest sample. While ``t < data_limit`` the reference data
    grows alongside it; afterwards its first ``data_limit`` samples stay solid
    and the following reference samples are drawn dashed.

    Args:
        data_x, data_y: reference coordinates.
        traj_x, traj_y: generated coordinates; ``traj_x.shape[0]`` sets the
            number of frames.
        data_limit: number of reference samples drawn as a solid line.
        filename: output GIF path.
        interval: delay between frames in milliseconds.

    Returns:
        0, kept for backward compatibility.
    """
    fig, ax = plt.subplots()
    ims = []
    # Frames run over t = 1..N so the marker at index t-1 never wraps to the
    # last sample on the first frame, and the final frame includes sample N-1.
    for t in range(1, traj_x.shape[0] + 1):
        if t < data_limit:
            traj = ax.plot(data_x[:t], data_y[:t], "b",
                           traj_x[:t], traj_y[:t], "r",
                           traj_x[t-1], traj_y[t-1], "or")
        else:
            traj = ax.plot(data_x[:data_limit], data_y[:data_limit], "b",
                           traj_x[:t], traj_y[:t], "r",
                           traj_x[t-1], traj_y[t-1], "or",
                           data_x[data_limit-1:t], data_y[data_limit-1:t], "--b")
        title = ax.text(0.5, 1.01, 'time={}'.format(t),
                        ha='center', va='bottom',
                        transform=ax.transAxes, fontsize='large')
        ims.append(traj + [title])
    ani = animation.ArtistAnimation(fig, ims, interval=interval)
    ani.save(filename, writer='pillow')
    plt.show()

    return 0

def MakeAnimation2(data_x, data_y, traj_x, traj_y, data_limit,
                   filename='anim_Re_0.gif', interval=100):
    """Animate only a generated trajectory (red line with a marker on its newest sample) and save it as a GIF.

    Args:
        data_x, data_y, data_limit: accepted for signature compatibility with
            ``MakeAnimation``; not used.
        traj_x, traj_y: generated coordinates; ``traj_x.shape[0]`` sets the
            number of frames.
        filename: output GIF path.
        interval: delay between frames in milliseconds.

    Returns:
        0, kept for backward compatibility.
    """
    fig, ax = plt.subplots()
    ims = []
    # Frames run over t = 1..N so the marker never wraps to traj_x[-1] and the
    # final sample is included.
    for t in range(1, traj_x.shape[0] + 1):
        traj = ax.plot(traj_x[:t], traj_y[:t], "r", traj_x[t-1], traj_y[t-1], "or")
        title = ax.text(0.5, 1.01, 'time={}'.format(t),
                        ha='center', va='bottom',
                        transform=ax.transAxes, fontsize='large')
        ims.append(traj + [title])
    ani = animation.ArtistAnimation(fig, ims, interval=interval)
    ani.save(filename, writer='pillow')
    plt.show()

    return 0

def MakeAnimation_data(data_x, data_y, data_limit, filename='anim_Re_0.gif', interval=30):
    """Animate a data trajectory (blue line) in a fixed view window and save it as a GIF.

    Args:
        data_x, data_y: coordinates; ``data_x.shape[0]`` sets the number of frames.
        data_limit: accepted for call compatibility; not used.
        filename: output GIF path.
        interval: delay between frames in milliseconds.

    Returns:
        0, kept for backward compatibility.
    """
    fig, ax = plt.subplots()
    # Fixed axis limits so the view does not rescale as the line grows.
    ax.set_xlim(0.2, 0.8)
    ax.set_ylim(-0.05, 0.8)
    ims = []
    # Frame t shows samples [0, t); t = 1..N so the last frame includes the final sample.
    for t in range(1, data_x.shape[0] + 1):
        traj = ax.plot(data_x[:t], data_y[:t], "b")
        title = ax.text(0.5, 1.01, 'time={}'.format(t),
                        ha='center', va='bottom',
                        transform=ax.transAxes, fontsize='large')
        ims.append(traj + [title])
    ani = animation.ArtistAnimation(fig, ims, interval=interval)
    ani.save(filename, writer='pillow')
    plt.show()

    return 0

def MakeAnimation_testdata(data_x, data_y, filename='anim_test_1.gif', interval=100):
    """Animate a data trajectory (blue line, auto-scaled axes) and save it as a GIF.

    Args:
        data_x, data_y: coordinates; ``data_x.shape[0]`` sets the number of frames.
        filename: output GIF path.
        interval: delay between frames in milliseconds.

    Returns:
        0, kept for backward compatibility.
    """
    fig, ax = plt.subplots()
    ims = []
    # Frame t shows samples [0, t); t = 1..N so the last frame includes the final sample.
    for t in range(1, data_x.shape[0] + 1):
        traj = ax.plot(data_x[:t], data_y[:t], "b")
        title = ax.text(0.5, 1.01, 'time={}'.format(t),
                        ha='center', va='bottom',
                        transform=ax.transAxes, fontsize='large')
        ims.append(traj + [title])
    ani = animation.ArtistAnimation(fig, ims, interval=interval)
    ani.save(filename, writer='pillow')
    plt.show()

    return 0

def MakeAnimation_attracter(data_x, data_y, filename='anim_attracter_gate_PFC.gif', interval=100):
    """Animate a 2-D trajectory (blue line plus a marker on its newest point) and save it as a GIF.

    Args:
        data_x, data_y: coordinates; ``data_x.shape[0]`` sets the number of frames.
        filename: output GIF path.
        interval: delay between frames in milliseconds.

    Returns:
        0, kept for backward compatibility.
    """
    fig, ax = plt.subplots()
    ims = []
    # Frames run over t = 1..N so the marker at index t-1 never wraps to the
    # last sample on the first frame, and the final sample is included.
    for t in range(1, data_x.shape[0] + 1):
        traj = ax.plot(data_x[:t], data_y[:t], "b", data_x[t-1], data_y[t-1], "ob")
        title = ax.text(0.5, 1.01, 'time={}'.format(t),
                        ha='center', va='bottom',
                        transform=ax.transAxes, fontsize='large')
        ims.append(traj + [title])
    ani = animation.ArtistAnimation(fig, ims, interval=interval)
    ani.save(filename, writer='pillow')
    plt.show()

    return 0

def MakeAnimation_3D(data_x, data_y, filename='anim_attracter2.gif', interval=100):
    """Animate a trajectory (blue line plus a marker on its newest point) and save it as a GIF.

    NOTE: despite the name, this draws a 2-D plot of ``data_y`` against ``data_x``.

    Args:
        data_x, data_y: coordinates; ``data_x.shape[0]`` sets the number of frames.
        filename: output GIF path.
        interval: delay between frames in milliseconds.

    Returns:
        0, kept for backward compatibility.
    """
    fig, ax = plt.subplots()
    ims = []
    # Frames run over t = 1..N so the marker never wraps to data_x[-1] and the
    # final sample is included.
    for t in range(1, data_x.shape[0] + 1):
        traj = ax.plot(data_x[:t], data_y[:t], "b", data_x[t-1], data_y[t-1], "ob")
        title = ax.text(0.5, 1.01, 'time={}'.format(t),
                        ha='center', va='bottom',
                        transform=ax.transAxes, fontsize='large')
        ims.append(traj + [title])
    ani = animation.ArtistAnimation(fig, ims, interval=interval)
    ani.save(filename, writer='pillow')
    plt.show()

    return 0

def MakeAnimation_img(data, module):
    """Animate each row of ``data`` as a one-row image and save it as ``anim_fire_<module>_.gif``.

    Args:
        data: 2-D array indexed as ``data[step]``; each step is reshaped to a
            ``(1, data.shape[1])`` strip for ``imshow``.
        module: label prefixed to the per-frame title and embedded in the filename.

    Returns:
        0.
    """
    fig, ax = plt.subplots()
    frames = []
    n_steps = data.shape[0]
    for step in range(n_steps):
        strip = ax.imshow(data[step].reshape(1, data.shape[1]))
        caption = ax.text(
            0.5, 1.01, module + ' time={}'.format(step),
            ha='center', va='bottom',
            transform=ax.transAxes, fontsize='large',
        )
        frames.append([strip, caption])
    movie = animation.ArtistAnimation(fig, frames, interval=100)
    movie.save('anim_fire_' + module + '_.gif', writer='pillow')
    plt.show()
    return 0


def mkOwnDataSet(data_size, data_length=100, freq=60., noise=0.01,
                 x_path="primal_long131test_r.csv", y_path="primal_long131test_l.csv"):
    """Build ``data_size`` noisy, down-sampled copies of two recorded 2-D trajectories.

    Args:
        data_size: number of noisy copies generated per trajectory.
        data_length: target samples per copy. Rows are taken with stride
            ``round(n_rows / data_length)`` (at least 1), so the actual length
            can differ from ``data_length``.
        freq: unused; kept for signature compatibility.
        noise: standard deviation of the Gaussian noise added to each coordinate.
        x_path, y_path: comma-delimited files whose first two columns are used.

    Returns:
        ``(train_x, train_y)``: lists of ``data_size`` sequences, each a list of
        ``[x, y]`` pairs.
    """
    x = np.loadtxt(x_path, delimiter=',')
    y = np.loadtxt(y_path, delimiter=',')

    def _stride_indices(n_rows):
        # round() gives 0 when n_rows < data_length / 2, and np.arange rejects
        # a zero step, so clamp the stride to at least 1.
        step = max(1, round(n_rows / data_length))
        return np.arange(0, n_rows, step)

    x_idx = _stride_indices(x.shape[0])
    y_idx = _stride_indices(y.shape[0])

    train_x = []
    train_y = []
    for _ in range(data_size):
        # Noise is drawn per coordinate: x then y within each row.
        train_x.append([[x[i][0] + np.random.normal(loc=0.0, scale=noise),
                         x[i][1] + np.random.normal(loc=0.0, scale=noise)] for i in x_idx])
        train_y.append([[y[i][0] + np.random.normal(loc=0.0, scale=noise),
                         y[i][1] + np.random.normal(loc=0.0, scale=noise)] for i in y_idx])
    return train_x, train_y

def make_Ttraj(direction):
    """Return a 20-sample (x, y) path: up a stem, out along one arm and back, then down the stem.

    Rows 0-4: stem up (x=0.5, y 0->0.5); 5-9: out to the arm end (y=0.5);
    10-14: back to x=0.5 (y=0.5); 15-19: stem down (x=0.5, y 0.5->0).

    Args:
        direction: 1 sends the arm to x=0.75, 2 sends it to x=0.25.

    Returns:
        ndarray of shape (20, 2) with columns (x, y).

    Raises:
        ValueError: if ``direction`` is neither 1 nor 2.
    """
    arm_ends = {1: 0.75, 2: 0.25}
    if direction not in arm_ends:
        raise ValueError("direction must be 1 or 2, got {!r}".format(direction))
    arm_x = arm_ends[direction]
    n = 5

    xs = np.concatenate([
        np.full(n, 0.5),              # stem, going up
        np.linspace(0.5, arm_x, n),   # out along the arm
        np.linspace(arm_x, 0.5, n),   # back to the junction
        np.full(n, 0.5),              # stem, going down
    ])
    ys = np.concatenate([
        np.linspace(0, 0.5, n),
        np.full(n, 0.5),
        np.full(n, 0.5),
        np.linspace(0.5, 0, n),
    ])
    return np.stack([xs, ys], axis=1)

def make_Htraj(direction1, direction2):
    """Return a 40-sample (x, y) path through an H-shaped maze.

    Row layout:
      0-9   stem up            (x=0.5,   y 0->0.5)
      10-14 out to the arm     (y=0.5,   x 0.5->arm_x)
      15-19 out to the tip     (x=arm_x, y 0.5->tip_y)
      20-24 back from the tip  (x=arm_x, y tip_y->0.5)
      25-29 back to the stem   (y=0.5,   x arm_x->0.5)
      30-39 stem down          (x=0.5,   y 0.5->0)

    Args:
        direction1: 1 -> arm_x = 0.75, 2 -> arm_x = 0.25.
        direction2: 1 -> tip_y = 0.25, 2 -> tip_y = 0.75.

    Returns:
        ndarray of shape (40, 2) with columns (x, y).

    Raises:
        ValueError: if either direction is not 1 or 2.
    """
    arm_xs = {1: 0.75, 2: 0.25}
    tip_ys = {1: 0.25, 2: 0.75}
    if direction1 not in arm_xs:
        raise ValueError("direction1 must be 1 or 2, got {!r}".format(direction1))
    if direction2 not in tip_ys:
        raise ValueError("direction2 must be 1 or 2, got {!r}".format(direction2))
    arm_x = arm_xs[direction1]
    tip_y = tip_ys[direction2]

    xs = np.concatenate([
        np.full(10, 0.5),
        np.linspace(0.5, arm_x, 5),
        np.full(5, arm_x),
        np.full(5, arm_x),
        np.linspace(arm_x, 0.5, 5),
        np.full(10, 0.5),
    ])
    ys = np.concatenate([
        np.linspace(0, 0.5, 10),
        np.full(5, 0.5),
        np.linspace(0.5, tip_y, 5),
        np.linspace(tip_y, 0.5, 5),
        np.full(5, 0.5),
        np.linspace(0.5, 0, 10),
    ])
    return np.stack([xs, ys], axis=1)


def make_traj(delay_length, stay_path="stay1.csv"):
    """Build four sequences of H-maze runs, each run followed by delay ("stay") segments.

    Run templates come from ``make_Htraj``: 1 = (1, 1), 2 = (1, 2),
    3 = (2, 1), 4 = (2, 2). Each sequence visits five runs. Every run gets
    fresh Gaussian jitter and is followed by ``delay_length`` delay segments of
    40 rows, sampled with replacement from the first two columns of
    ``stay_path``.

    Args:
        delay_length: number of delay segments inserted after each run.
        stay_path: comma-delimited file holding the delay samples.

    Returns:
        Four float ndarrays of shape (N, 2), one per run order below.
    """
    noise = 0.005       # std of the jitter added to each run
    data_length = 40    # rows drawn for each delay segment

    templates = {
        1: make_Htraj(1, 1),
        2: make_Htraj(1, 2),
        3: make_Htraj(2, 1),
        4: make_Htraj(2, 2),
    }
    z = np.loadtxt(stay_path, delimiter=',', ndmin=2)

    # Run orders ("ver1's test"). Earlier pairings that were tried:
    #   ver1: 1->4, 2->3, 3->1, 4->2 (delay blocks between the two runs only)
    #   ver2: 1->2, 2->3, 3->4, 4->1
    run_orders = [
        [1, 4, 2, 3, 1],
        [2, 3, 1, 4, 2],
        [3, 1, 4, 2, 3],
        [4, 2, 3, 1, 4],
    ]

    data_list = []
    for runs in run_orders:
        segments = []
        for run in runs:
            template = templates[run]
            # Add noise to a copy. An in-place ``+=`` would mutate the shared
            # template, so later uses of the same run would accumulate noise.
            segments.append(template + np.random.normal(loc=0.0, scale=noise, size=template.shape))
            for _ in range(delay_length):
                rows = np.random.choice(z.shape[0], data_length)
                segments.append(z[rows, :2])
        data_list.append(np.concatenate(segments, axis=0))
    return data_list[0], data_list[1], data_list[2], data_list[3]

def mkOwnDataSet_auto(data_size, delay_length, freq=60., noise=0.01):
    """Create ``data_size`` jittered copies of each of the four ``make_traj`` sequences.

    Args:
        data_size: number of copies generated per sequence.
        delay_length: forwarded to ``make_traj``.
        freq: unused.
        noise: standard deviation of the Gaussian noise added to each coordinate.

    Returns:
        Four lists (x1, x2, y1, y2 order), each holding ``data_size`` sequences
        of ``[x, y]`` pairs.
    """
    x1, x2, y1, y2 = make_traj(delay_length)
    sources = (x1, x2, y1, y2)

    def jitter(seq):
        # One draw per coordinate, x before y within each row.
        return [[row[0] + np.random.normal(loc=0.0, scale=noise),
                 row[1] + np.random.normal(loc=0.0, scale=noise)]
                for row in seq]

    buckets = ([], [], [], [])
    for _ in range(data_size):
        for bucket, seq in zip(buckets, sources):
            bucket.append(jitter(seq))

    return buckets[0], buckets[1], buckets[2], buckets[3]

def mkOwnRandomBatch(train_x, train_t, batch_size=10):
    """Draw ``batch_size`` samples (with replacement) from ``train_x`` and stack them time-major.

    Args:
        train_x: sequence of samples; each must be convertible by
            ``torch.tensor`` and all must share one shape (e.g. lists of
            ``[x, y]`` pairs).
        train_t: unused; kept for call compatibility.
        batch_size: number of samples to draw.

    Returns:
        Tensor whose first two axes are swapped relative to the stacked batch,
        i.e. ``(T, batch_size, ...)`` when each sample is ``(T, ...)``.

    Raises:
        ValueError: if ``train_x`` is empty.
    """
    if len(train_x) == 0:
        raise ValueError("train_x must contain at least one sample")
    # Each draw selects one whole sample, so every index is eligible. The upper
    # bound must not be reduced by batch_size: that would exclude the last
    # samples and fail whenever len(train_x) <= batch_size.
    batch_x = [train_x[np.random.randint(0, len(train_x))] for _ in range(batch_size)]
    return torch.tensor(batch_x).transpose(0, 1)

class MyLSTM(nn.Module):
    """Three ``nn.LSTMCell`` modules (PFC, HPC, Re) coupled through their previous hidden states.

    Per step, from the previous step's states:
      * Re  receives ``[PFC_h, HPC_h]``
      * HPC receives ``[input, Re_h]``
      * PFC receives ``[HPC_h, Re_h]``
    The output is a linear read-out of the new HPC hidden state.
    ``hiddens`` is ``[PFC_(h, c), HPC_(h, c), Re_(h, c)]``.

    Args:
        input_size: size of the external input.
        hidden_size: hidden size of all three cells.
        output_size: size of the linear read-out.
        batch_size: batch size used by ``initHidden``.
        sparsity: fraction of entries zeroed per column by ``nn.init.sparse_``
            on the PFC/HPC input and recurrent weights.
    """

    def __init__(self, input_size, hidden_size, output_size, batch_size, sparsity=0.1):
        super(MyLSTM, self).__init__()

        self.hidden_size = hidden_size
        self.PFC = nn.LSTMCell(hidden_size*2, hidden_size)
        self.HPC = nn.LSTMCell(input_size+hidden_size, hidden_size)
        self.Re = nn.LSTMCell(hidden_size*2, hidden_size)
        self.linear = nn.Linear(hidden_size, output_size)
        self.batch_size = batch_size
        # A sparsity of 1 would zero every entry of every column, i.e. start
        # these weights at exactly zero. 0.1 matches MyLSTM_vHPC / MyLSTM_3lay.
        nn.init.sparse_(self.PFC.weight_ih.data, sparsity)
        nn.init.sparse_(self.HPC.weight_ih.data, sparsity)
        nn.init.sparse_(self.PFC.weight_hh.data, sparsity)
        nn.init.sparse_(self.HPC.weight_hh.data, sparsity)

    def forward(self, input, hiddens):
        """Advance all three cells one step; return ``(output, [PFC, HPC, Re] states)``."""
        Re_input = torch.cat([hiddens[0][0], hiddens[1][0]], dim=1)
        Re_hidden = self.Re(Re_input, hiddens[2])
        HPC_input = torch.cat([input, hiddens[2][0]], dim=1)
        HPC_hidden = self.HPC(HPC_input, hiddens[1])
        PFC_input = torch.cat([hiddens[1][0], hiddens[2][0]], dim=1)
        PFC_hidden = self.PFC(PFC_input, hiddens[0])
        output = self.linear(HPC_hidden[0])
        return output, [PFC_hidden, HPC_hidden, Re_hidden]

    def initHidden(self):
        """Return zero ``(h, c)`` pairs for PFC, HPC and Re, each ``(batch_size, hidden_size)``."""
        HPC_hidden = [torch.zeros(self.batch_size, self.hidden_size), torch.zeros(self.batch_size, self.hidden_size)]
        PFC_hidden = [torch.zeros(self.batch_size, self.hidden_size), torch.zeros(self.batch_size, self.hidden_size)]
        Re_hidden = [torch.zeros(self.batch_size, self.hidden_size), torch.zeros(self.batch_size, self.hidden_size)]
        return [PFC_hidden, HPC_hidden, Re_hidden]

class MyLSTM_RNN(nn.Module):
    """PFC and HPC ``nn.LSTMCell``s coupled through an ``nn.RNNCell`` "Re" module.

    Per step, from the previous step's states:
      * Re  (RNNCell)  receives ``[PFC_h, HPC_h]``
      * HPC (LSTMCell) receives ``[input, Re]``
      * PFC (LSTMCell) receives ``[HPC_h, Re]``
    The output is a linear read-out of the new HPC hidden state.
    ``hiddens`` is ``[PFC_(h, c), HPC_(h, c), Re_h]``.

    Args:
        input_size: size of the external input.
        hidden_size: hidden size of PFC, HPC and Re.
        output_size: size of the linear read-out.
        batch_size: batch size used by the ``initHidden*`` / noise helpers.
        sparse: integer-like; ``nn.init.sparse_`` is applied to the PFC/HPC
            weights with sparsity ``0.1 * int(sparse)``.
    """

    def __init__(self, input_size, hidden_size, output_size, batch_size, sparse):
        super(MyLSTM_RNN, self).__init__()

        self.hidden_size_PFC = hidden_size
        self.hidden_size_HPC = hidden_size
        self.hidden_size_Re = hidden_size+0
        self.PFC = nn.LSTMCell(self.hidden_size_HPC+self.hidden_size_Re, self.hidden_size_PFC)
        self.HPC = nn.LSTMCell(input_size+self.hidden_size_Re, self.hidden_size_HPC)
        self.Re = nn.RNNCell(self.hidden_size_HPC+self.hidden_size_PFC, self.hidden_size_Re)
        self.linear = nn.Linear(self.hidden_size_HPC, output_size)
        self.batch_size = batch_size
        sparse = 0.1*int(sparse)
        nn.init.sparse_(self.PFC.weight_ih.data, sparse)
        nn.init.sparse_(self.HPC.weight_ih.data, sparse)
        nn.init.sparse_(self.PFC.weight_hh.data, sparse)
        nn.init.sparse_(self.HPC.weight_hh.data, sparse)

    def forward(self, input, hiddens):
        """Advance all modules one step; return ``(output, [PFC, HPC, Re] states)``."""
        Re_input = torch.cat([hiddens[0][0], hiddens[1][0]], dim=1)
        Re_hidden = self.Re(Re_input, hiddens[2])
        HPC_input = torch.cat([input, hiddens[2]], dim=1)
        HPC_hidden = self.HPC(HPC_input, hiddens[1])
        PFC_input = torch.cat([hiddens[1][0], hiddens[2]], dim=1)
        PFC_hidden = self.PFC(PFC_input, hiddens[0])
        output = self.linear(HPC_hidden[0])
        return output, [PFC_hidden, HPC_hidden, Re_hidden]

    def initHidden(self):
        """Return all-zero states sized per module: ``[PFC_(h, c), HPC_(h, c), Re_h]``."""
        # Use the per-module sizes; this class has no ``hidden_size`` attribute.
        HPC_hidden = [torch.zeros(self.batch_size, self.hidden_size_HPC), torch.zeros(self.batch_size, self.hidden_size_HPC)]
        PFC_hidden = [torch.zeros(self.batch_size, self.hidden_size_PFC), torch.zeros(self.batch_size, self.hidden_size_PFC)]
        Re_hidden = torch.zeros(self.batch_size, self.hidden_size_Re)
        return [PFC_hidden, HPC_hidden, Re_hidden]

    def initHidden_rand(self):
        """Return states filled with uniform [0, 0.01) values, same layout as ``initHidden``."""
        const = 0.01
        var = 0.01
        HPC_hidden = [torch.rand(self.batch_size, self.hidden_size_HPC)*const, torch.rand(self.batch_size, self.hidden_size_HPC)*const]
        PFC_hidden = [torch.rand(self.batch_size, self.hidden_size_PFC)*var, torch.rand(self.batch_size, self.hidden_size_PFC)*var]
        Re_hidden = torch.rand(self.batch_size, self.hidden_size_Re)*const
        return [PFC_hidden, HPC_hidden, Re_hidden]

    def noiseHidden_rand(self, hiddens):
        """Add Gaussian noise in place to HPC h, PFC h and Re; return the same ``hiddens`` list.

        With ``c = 0.0`` only PFC h is actually perturbed (std ``v``); the HPC
        and Re draws are scaled to zero.
        """
        c = 0.0
        v = 0.01
        HPC_hidden = torch.randn(self.batch_size, self.hidden_size_HPC)*c
        hiddens[1][0].add_(HPC_hidden)
        PFC_hidden = torch.randn(self.batch_size, self.hidden_size_PFC)*v
        hiddens[0][0].add_(PFC_hidden)
        Re_hidden = torch.randn(self.batch_size, self.hidden_size_Re)*c
        hiddens[2].add_(Re_hidden)
        return hiddens

    def hidden_size_list(self):
        """Return ``[PFC, HPC, Re]`` hidden sizes."""
        return [self.hidden_size_PFC, self.hidden_size_HPC, self.hidden_size_Re]



class MyLSTM_RNN_noise(nn.Module):
    """Variant of ``MyLSTM_RNN`` that injects Gaussian noise into the states inside ``forward``.

    The Re module is an ``nn.RNNCell`` with ``hidden_size + 10`` units.
    ``hiddens`` is ``[PFC_(h, c), HPC_(h, c), Re_h]``.

    Args:
        input_size: size of the external input.
        hidden_size: hidden size of PFC and HPC (Re uses ``hidden_size + 10``).
        output_size: size of the linear read-out.
        batch_size: batch size; also fixes the shape of the noise drawn in ``forward``.
        sparse: integer-like; PFC/HPC weights get ``nn.init.sparse_`` with
            sparsity ``0.1 * int(sparse)``.
    """

    def __init__(self, input_size, hidden_size, output_size, batch_size, sparse):
        super(MyLSTM_RNN_noise, self).__init__()

        self.hidden_size_PFC = hidden_size
        self.hidden_size_HPC = hidden_size
        self.hidden_size_Re = hidden_size+10
        self.PFC = nn.LSTMCell(self.hidden_size_HPC+self.hidden_size_Re, self.hidden_size_PFC)
        self.HPC = nn.LSTMCell(input_size+self.hidden_size_Re, self.hidden_size_HPC)
        self.Re = nn.RNNCell(self.hidden_size_HPC+self.hidden_size_PFC, self.hidden_size_Re)
        self.linear = nn.Linear(self.hidden_size_HPC, output_size)
        self.batch_size = batch_size
        sparse = 0.1*int(sparse)
        nn.init.sparse_(self.PFC.weight_ih.data, sparse)
        nn.init.sparse_(self.HPC.weight_ih.data, sparse)
        nn.init.sparse_(self.PFC.weight_hh.data, sparse)
        nn.init.sparse_(self.HPC.weight_hh.data, sparse)

    def forward(self, input, hiddens):
        """Advance one step using noisy copies (std 0.02) of PFC h, HPC h and Re.

        The noisy copies feed Re and serve as the cross-module inputs. Re's
        recurrent state is the noisy Re, while the HPC/PFC cells receive their
        un-noised ``(h, c)`` pairs as recurrent state. Noise is drawn on every
        call, regardless of train/eval mode.
        """
        c = 0.02
        v = 0.02
        HPC_noised = hiddens[1][0] + torch.randn(self.batch_size, self.hidden_size_HPC)*c
        PFC_noised = hiddens[0][0] + torch.randn(self.batch_size, self.hidden_size_PFC)*v
        Re_noised = hiddens[2] + torch.randn(self.batch_size, self.hidden_size_Re)*c
        Re_input = torch.cat([PFC_noised, HPC_noised], dim=1)
        Re_hidden = self.Re(Re_input, Re_noised)
        HPC_input = torch.cat([input, Re_noised], dim=1)
        HPC_hidden = self.HPC(HPC_input, hiddens[1])
        PFC_input = torch.cat([HPC_noised, Re_noised], dim=1)
        PFC_hidden = self.PFC(PFC_input, hiddens[0])
        output = self.linear(HPC_hidden[0])
        return output, [PFC_hidden, HPC_hidden, Re_hidden]

    def initHidden(self):
        """Return all-zero states sized per module: ``[PFC_(h, c), HPC_(h, c), Re_h]``."""
        # Per-module sizes: Re is wider than PFC/HPC, and there is no ``hidden_size`` attribute.
        HPC_hidden = [torch.zeros(self.batch_size, self.hidden_size_HPC), torch.zeros(self.batch_size, self.hidden_size_HPC)]
        PFC_hidden = [torch.zeros(self.batch_size, self.hidden_size_PFC), torch.zeros(self.batch_size, self.hidden_size_PFC)]
        Re_hidden = torch.zeros(self.batch_size, self.hidden_size_Re)
        return [PFC_hidden, HPC_hidden, Re_hidden]

    def initHidden_rand(self):
        """Return states filled with uniform [0, 0.01) values, same layout as ``initHidden``."""
        const = 0.01
        var = 0.01
        HPC_hidden = [torch.rand(self.batch_size, self.hidden_size_HPC)*const, torch.rand(self.batch_size, self.hidden_size_HPC)*const]
        PFC_hidden = [torch.rand(self.batch_size, self.hidden_size_PFC)*var, torch.rand(self.batch_size, self.hidden_size_PFC)*var]
        Re_hidden = torch.rand(self.batch_size, self.hidden_size_Re)*const
        return [PFC_hidden, HPC_hidden, Re_hidden]

    def noiseHidden_rand(self, hiddens):
        """Add Gaussian noise (std 0.02) in place to HPC h, PFC h and Re; return ``hiddens``."""
        c = 0.02
        v = 0.02
        HPC_hidden = torch.randn(self.batch_size, self.hidden_size_HPC)*c
        hiddens[1][0].add_(HPC_hidden)
        PFC_hidden = torch.randn(self.batch_size, self.hidden_size_PFC)*v
        hiddens[0][0].add_(PFC_hidden)
        Re_hidden = torch.randn(self.batch_size, self.hidden_size_Re)*c
        hiddens[2].add_(Re_hidden)
        return hiddens

    def hidden_size_list(self):
        """Return ``[PFC, HPC, Re]`` hidden sizes."""
        return [self.hidden_size_PFC, self.hidden_size_HPC, self.hidden_size_Re]


class MyLSTM_vHPC(nn.Module):
    """Four ``nn.LSTMCell`` modules: PFC, HPC, vHPC and Re.

    ``hiddens`` is ``[PFC_(h, c), HPC_(h, c), vHPC_(h, c), Re_(h, c)]``.
    Per step, from the previous step's states:
      * Re   receives ``[PFC_h, HPC_h]`` (recurrent state: Re)
      * HPC  receives ``[input, Re_h]``
      * vHPC receives ``HPC_h``
      * PFC  receives ``[vHPC_h, Re_h]``
    The output is a linear read-out of the new HPC hidden state.

    Args:
        input_size: size of the external input.
        hidden_size: hidden size of all four cells.
        output_size: size of the linear read-out.
        batch_size: batch size used by ``initHidden``.
    """

    def __init__(self, input_size, hidden_size, output_size, batch_size):
        super(MyLSTM_vHPC, self).__init__()

        self.hidden_size = hidden_size
        self.PFC = nn.LSTMCell(hidden_size*2, hidden_size)
        self.HPC = nn.LSTMCell(input_size+hidden_size, hidden_size)
        self.vHPC = nn.LSTMCell(hidden_size, hidden_size)
        self.Re = nn.LSTMCell(hidden_size*2, hidden_size)
        self.linear = nn.Linear(hidden_size, output_size)
        self.batch_size = batch_size
        sparse = 0.1
        nn.init.sparse_(self.PFC.weight_ih.data, sparse)
        nn.init.sparse_(self.HPC.weight_ih.data, sparse)
        nn.init.sparse_(self.vHPC.weight_ih.data, sparse)
        nn.init.sparse_(self.Re.weight_ih.data, sparse)
        nn.init.sparse_(self.PFC.weight_hh.data, sparse)
        nn.init.sparse_(self.HPC.weight_hh.data, sparse)
        nn.init.sparse_(self.vHPC.weight_hh.data, sparse)
        nn.init.sparse_(self.Re.weight_hh.data, sparse)

    def forward(self, input, hiddens):
        """Advance all four cells one step; return ``(output, [PFC, HPC, vHPC, Re] states)``."""
        Re_input = torch.cat([hiddens[0][0], hiddens[1][0]], dim=1)
        # Re's own recurrent state lives at index 3; index 2 is vHPC.
        Re_hidden = self.Re(Re_input, hiddens[3])
        HPC_input = torch.cat([input, hiddens[3][0]], dim=1)
        HPC_hidden = self.HPC(HPC_input, hiddens[1])
        vHPC_input = hiddens[1][0]
        vHPC_hidden = self.vHPC(vHPC_input, hiddens[2])
        PFC_input = torch.cat([hiddens[2][0], hiddens[3][0]], dim=1)
        PFC_hidden = self.PFC(PFC_input, hiddens[0])
        output = self.linear(HPC_hidden[0])
        return output, [PFC_hidden, HPC_hidden, vHPC_hidden, Re_hidden]

    def initHidden(self):
        """Return zero ``(h, c)`` pairs for PFC, HPC, vHPC and Re."""
        HPC_hidden = [torch.zeros(self.batch_size, self.hidden_size), torch.zeros(self.batch_size, self.hidden_size)]
        vHPC_hidden = [torch.zeros(self.batch_size, self.hidden_size), torch.zeros(self.batch_size, self.hidden_size)]
        PFC_hidden = [torch.zeros(self.batch_size, self.hidden_size), torch.zeros(self.batch_size, self.hidden_size)]
        Re_hidden = [torch.zeros(self.batch_size, self.hidden_size), torch.zeros(self.batch_size, self.hidden_size)]
        return [PFC_hidden, HPC_hidden, vHPC_hidden, Re_hidden]

class MyLSTM_3lay(nn.Module):
    """Three ``nn.LSTMCell`` modules in a loop: HPC <- Re <- (PFC, HPC), PFC <- Re.

    ``hiddens`` is ``[PFC_(h, c), HPC_(h, c), Re_(h, c)]``. Per step, from the
    previous step's states:
      * HPC receives ``[input, Re_h]``
      * Re  receives ``[PFC_h, HPC_h]``
      * PFC receives ``Re_h`` only
    The output is a linear read-out of the new HPC hidden state.

    Args:
        input_size: size of the external input.
        hidden_size: hidden size of all three cells.
        output_size: size of the linear read-out.
        batch_size: batch size used by ``initHidden``.
    """

    def __init__(self, input_size, hidden_size, output_size, batch_size):
        super(MyLSTM_3lay, self).__init__()

        self.hidden_size = hidden_size
        self.PFC = nn.LSTMCell(hidden_size, hidden_size)
        self.HPC = nn.LSTMCell(input_size+hidden_size, hidden_size)
        self.Re = nn.LSTMCell(hidden_size*2, hidden_size)
        self.linear = nn.Linear(hidden_size, output_size)
        self.batch_size = batch_size
        sparse = 0.1
        nn.init.sparse_(self.PFC.weight_ih.data, sparse)
        nn.init.sparse_(self.HPC.weight_ih.data, sparse)
        nn.init.sparse_(self.Re.weight_ih.data, sparse)
        nn.init.sparse_(self.PFC.weight_hh.data, sparse)
        nn.init.sparse_(self.HPC.weight_hh.data, sparse)
        nn.init.sparse_(self.Re.weight_hh.data, sparse)

    def forward(self, input, hiddens):
        """Advance all three cells one step; return ``(output, [PFC, HPC, Re] states)``."""
        HPC_input = torch.cat([input, hiddens[2][0]], dim=1)
        HPC_hidden = self.HPC(HPC_input, hiddens[1])
        Re_input = torch.cat([hiddens[0][0], hiddens[1][0]], dim=1)
        Re_hidden = self.Re(Re_input, hiddens[2])
        PFC_input = hiddens[2][0]
        PFC_hidden = self.PFC(PFC_input, hiddens[0])
        output = self.linear(HPC_hidden[0])
        return output, [PFC_hidden, HPC_hidden, Re_hidden]

    def initHidden(self):
        """Return zero ``(h, c)`` pairs for PFC, HPC and Re."""
        HPC_hidden = [torch.zeros(self.batch_size, self.hidden_size), torch.zeros(self.batch_size, self.hidden_size)]
        PFC_hidden = [torch.zeros(self.batch_size, self.hidden_size), torch.zeros(self.batch_size, self.hidden_size)]
        Re_hidden = [torch.zeros(self.batch_size, self.hidden_size), torch.zeros(self.batch_size, self.hidden_size)]
        return [PFC_hidden, HPC_hidden, Re_hidden]
    
class MyLSTM_feedforward(nn.Module):
    """Variant of ``MyLSTM_RNN`` whose Re module is a plain ``nn.Linear`` (no recurrence, no nonlinearity).

    Per step, from the previous step's states:
      * Re  = Linear(``[PFC_h, HPC_h]``)
      * HPC receives ``[input, Re]``
      * PFC receives ``[HPC_h, Re]``
    ``hiddens`` is ``[PFC_(h, c), HPC_(h, c), Re]``.

    Args:
        input_size: size of the external input.
        hidden_size: hidden size of PFC, HPC and Re.
        output_size: size of the linear read-out.
        batch_size: batch size used by the ``initHidden*`` / noise helpers.
        sparse: integer-like; PFC/HPC weights get ``nn.init.sparse_`` with
            sparsity ``0.1 * int(sparse)``.
    """

    def __init__(self, input_size, hidden_size, output_size, batch_size, sparse):
        super(MyLSTM_feedforward, self).__init__()

        self.hidden_size_PFC = hidden_size
        self.hidden_size_HPC = hidden_size
        self.hidden_size_Re = hidden_size+0
        self.PFC = nn.LSTMCell(self.hidden_size_HPC+self.hidden_size_Re, self.hidden_size_PFC)
        self.HPC = nn.LSTMCell(input_size+self.hidden_size_Re, self.hidden_size_HPC)
        self.Re = nn.Linear(self.hidden_size_HPC+self.hidden_size_PFC, self.hidden_size_Re)
        self.linear = nn.Linear(self.hidden_size_HPC, output_size)
        self.batch_size = batch_size
        sparse = 0.1*int(sparse)
        nn.init.sparse_(self.PFC.weight_ih.data, sparse)
        nn.init.sparse_(self.HPC.weight_ih.data, sparse)
        nn.init.sparse_(self.PFC.weight_hh.data, sparse)
        nn.init.sparse_(self.HPC.weight_hh.data, sparse)

    def forward(self, input, hiddens):
        """Advance one step; return ``(output, [PFC, HPC, Re] states)``."""
        Re_input = torch.cat([hiddens[0][0], hiddens[1][0]], dim=1)
        Re_hidden = self.Re(Re_input)
        HPC_input = torch.cat([input, hiddens[2]], dim=1)
        HPC_hidden = self.HPC(HPC_input, hiddens[1])
        PFC_input = torch.cat([hiddens[1][0], hiddens[2]], dim=1)
        PFC_hidden = self.PFC(PFC_input, hiddens[0])
        output = self.linear(HPC_hidden[0])
        return output, [PFC_hidden, HPC_hidden, Re_hidden]

    def initHidden(self):
        """Return all-zero states sized per module: ``[PFC_(h, c), HPC_(h, c), Re]``."""
        # Per-module sizes; this class has no ``hidden_size`` attribute.
        HPC_hidden = [torch.zeros(self.batch_size, self.hidden_size_HPC), torch.zeros(self.batch_size, self.hidden_size_HPC)]
        PFC_hidden = [torch.zeros(self.batch_size, self.hidden_size_PFC), torch.zeros(self.batch_size, self.hidden_size_PFC)]
        Re_hidden = torch.zeros(self.batch_size, self.hidden_size_Re)
        return [PFC_hidden, HPC_hidden, Re_hidden]

    def initHidden_rand(self):
        """Return states filled with uniform [0, 0.01) values, same layout as ``initHidden``."""
        const = 0.01
        var = 0.01
        HPC_hidden = [torch.rand(self.batch_size, self.hidden_size_HPC)*const, torch.rand(self.batch_size, self.hidden_size_HPC)*const]
        PFC_hidden = [torch.rand(self.batch_size, self.hidden_size_PFC)*var, torch.rand(self.batch_size, self.hidden_size_PFC)*var]
        Re_hidden = torch.rand(self.batch_size, self.hidden_size_Re)*const
        return [PFC_hidden, HPC_hidden, Re_hidden]

    def noiseHidden_rand(self, hiddens):
        """Add Gaussian noise (std 0.02) in place to HPC h, PFC h and Re; return ``hiddens``."""
        c = 0.02
        v = 0.02
        HPC_hidden = torch.randn(self.batch_size, self.hidden_size_HPC)*c
        hiddens[1][0].add_(HPC_hidden)
        PFC_hidden = torch.randn(self.batch_size, self.hidden_size_PFC)*v
        hiddens[0][0].add_(PFC_hidden)
        Re_hidden = torch.randn(self.batch_size, self.hidden_size_Re)*c
        hiddens[2].add_(Re_hidden)
        return hiddens

    def hidden_size_list(self):
        """Return ``[PFC, HPC, Re]`` hidden sizes."""
        return [self.hidden_size_PFC, self.hidden_size_HPC, self.hidden_size_Re]
    
    
class MyLSTM_RNN_uniPFC(nn.Module):
    """Variant of ``MyLSTM_RNN`` in which PFC receives only the HPC hidden state.

    Per step, from the previous step's states:
      * Re  (RNNCell)  receives ``[PFC_h, HPC_h]``
      * HPC (LSTMCell) receives ``[input, Re]``
      * PFC (LSTMCell) receives ``HPC_h``
    ``hiddens`` is ``[PFC_(h, c), HPC_(h, c), Re_h]``.

    Args:
        input_size: size of the external input.
        hidden_size: hidden size of PFC, HPC and Re.
        output_size: size of the linear read-out.
        batch_size: batch size used by the ``initHidden*`` / noise helpers.
        sparse: integer-like; PFC, HPC and Re weights get ``nn.init.sparse_``
            with sparsity ``0.1 * int(sparse)``.
    """

    def __init__(self, input_size, hidden_size, output_size, batch_size, sparse=1):
        super(MyLSTM_RNN_uniPFC, self).__init__()

        self.hidden_size_PFC = hidden_size+0
        self.hidden_size_HPC = hidden_size+0
        self.hidden_size_Re = hidden_size+0
        self.PFC = nn.LSTMCell(self.hidden_size_HPC, self.hidden_size_PFC)
        self.HPC = nn.LSTMCell(input_size+self.hidden_size_Re, self.hidden_size_HPC)
        self.Re = nn.RNNCell(self.hidden_size_HPC+self.hidden_size_PFC, self.hidden_size_Re)
        self.linear = nn.Linear(self.hidden_size_HPC, output_size)
        self.batch_size = batch_size
        sparse = 0.1*int(sparse)
        nn.init.sparse_(self.PFC.weight_ih.data, sparse)
        nn.init.sparse_(self.HPC.weight_ih.data, sparse)
        nn.init.sparse_(self.Re.weight_ih.data, sparse)
        nn.init.sparse_(self.PFC.weight_hh.data, sparse)
        nn.init.sparse_(self.HPC.weight_hh.data, sparse)
        nn.init.sparse_(self.Re.weight_hh.data, sparse)

    def forward(self, input, hiddens):
        """Advance all modules one step; return ``(output, [PFC, HPC, Re] states)``."""
        Re_input = torch.cat([hiddens[0][0], hiddens[1][0]], dim=1)
        Re_hidden = self.Re(Re_input, hiddens[2])
        HPC_input = torch.cat([input, hiddens[2]], dim=1)
        HPC_hidden = self.HPC(HPC_input, hiddens[1])
        PFC_input = hiddens[1][0]
        PFC_hidden = self.PFC(PFC_input, hiddens[0])
        output = self.linear(HPC_hidden[0])
        return output, [PFC_hidden, HPC_hidden, Re_hidden]

    def initHidden(self):
        """Return all-zero states sized per module: ``[PFC_(h, c), HPC_(h, c), Re_h]``."""
        # Per-module sizes; this class has no ``hidden_size`` attribute.
        HPC_hidden = [torch.zeros(self.batch_size, self.hidden_size_HPC), torch.zeros(self.batch_size, self.hidden_size_HPC)]
        PFC_hidden = [torch.zeros(self.batch_size, self.hidden_size_PFC), torch.zeros(self.batch_size, self.hidden_size_PFC)]
        Re_hidden = torch.zeros(self.batch_size, self.hidden_size_Re)
        return [PFC_hidden, HPC_hidden, Re_hidden]

    def initHidden_rand(self):
        """Return states filled with uniform [0, 0.01) values, same layout as ``initHidden``."""
        const = 0.01
        var = 0.01
        HPC_hidden = [torch.rand(self.batch_size, self.hidden_size_HPC)*const, torch.rand(self.batch_size, self.hidden_size_HPC)*const]
        PFC_hidden = [torch.rand(self.batch_size, self.hidden_size_PFC)*var, torch.rand(self.batch_size, self.hidden_size_PFC)*var]
        Re_hidden = torch.rand(self.batch_size, self.hidden_size_Re)*const
        return [PFC_hidden, HPC_hidden, Re_hidden]

    def noiseHidden_rand(self, hiddens):
        """Add Gaussian noise in place to HPC h, PFC h and Re; return the same ``hiddens`` list.

        Unlike ``MyLSTM_RNN``, here HPC h gets std ``v`` = 0.01 while PFC h and Re
        are scaled by ``c`` = 0.0, so only HPC h is actually perturbed.
        """
        c = 0.0
        v = 0.01
        HPC_hidden = torch.randn(self.batch_size, self.hidden_size_HPC)*v
        hiddens[1][0].add_(HPC_hidden)
        PFC_hidden = torch.randn(self.batch_size, self.hidden_size_PFC)*c
        hiddens[0][0].add_(PFC_hidden)
        Re_hidden = torch.randn(self.batch_size, self.hidden_size_Re)*c
        hiddens[2].add_(Re_hidden)
        return hiddens

    def hidden_size_list(self):
        """Return ``[PFC, HPC, Re]`` hidden sizes."""
        return [self.hidden_size_PFC, self.hidden_size_HPC, self.hidden_size_Re]

    

def moving_average(data, window=5):
    """Return the ``'full'`` convolution of *data* with a length-*window* box filter.

    The result has ``len(data) + window - 1`` samples, not ``len(data)``. Callers in
    this file slice ``[2:-2]`` to re-align the default 5-point average with the raw
    series. In general, for odd *window*, trim ``(window - 1) // 2`` samples at each
    end. Samples near either end are computed against implicit zero padding and are
    therefore biased toward zero.

    Args:
        data: 1-D sequence of numbers.
        window: box-filter length (default 5, the previously hard-coded value).

    Raises:
        ValueError: if *window* is smaller than 1.
    """
    if window < 1:
        raise ValueError("window must be a positive integer, got {}".format(window))
    kernel = np.ones(window) / window
    return np.convolve(data, kernel)

def main(num):
    """Load a trained checkpoint, roll the network out, and analyse its hidden states.

    Steps, as implemented below:
      1. Build ``MyLSTM_RNN`` and load ``model_path``. Print "Not exist" and return
         early if the file is missing.
      2. Drive the network with a batch from ``mkOwnRandomBatch`` (training set
         chosen by ``pattern``) for ``data_limit`` (600) steps, feeding ``data[k]``
         as input. Then run 10 more steps feeding the network's own output back in.
      3. Plot batch element 0's target trajectory against the generated one.
      4. Run PCA on the Re states. Print the explained-variance ratios and one
         correlation between output-y differences and Re PC1 differences.
      5. For each of 20 units, compare detrended PFC activity with detrended Re PC1
         segment-by-segment (segments come from ``pick_delay``). For each lag,
         print the distance between their sign patterns.
      6. Run PCA on the PFC and HPC states. Correlate output-y differences between
         two 120-step windows with PC differences, and print the maxima.

    Args:
        num: only referenced by commented-out ``model_path`` variants; currently unused.

    Returns:
        None. The function only prints and plots.
    """
    training_size = 100
    test_size = 1000
    epochs_num = 10  # NOTE(review): unused in this function.
    hidden_size = 20
    batch_size = 10
    data_length = 100
    inputsize = 2
    outputsize = 2
    # model_path = 'model/ReModel_L2_interRNNrand_Reinh_AddRe_OUT5_H121_s'+str(num)+'_100_2_2.pth'
    # model_path = 'model/ReModel_L2_interRNNrand_Reinh_AddRe_OUT5_H121_s2_100_1_2.pth'
#     model_path = 'model/R20_H_bigbatch/ReModel_L2_interRNNrand_OUT1_121_s6_100_3_epoch130.pth'
    model_path = 'model/R20_H_stopinit_bigbatch/ReModel_L2_interRNNrand_OUT1_121_s7_100_1_epoch195.pth'
#     model_path = 'model/R20_H_uniPFC_bigbatch/ReModel_L2_interRNNrand_OUT1_121_s10_100_2_epoch170.pth'
#     model_path = 'model/R20_H_uniHPC_bigbatch/ReModel_L2_interRNNrand_OUT1_121_s1_100_2_epoch160.pth'
#     model_path = 'model/R20_uniPFC_H/ReModel_L2_interRNNrand_OUT1_uniPFC_121H_s5_100_2_epoch85.pth'
#     model_path = 'model/R20_uniHPC_H/ReModel_L2_interRNNrand_OUT1_121H_s2_100_1_epoch100.pth'
#     model_path = 'model/R20_H_transfer2B/ReModel_L2_interRNNrand_OUT1_transfers14_121_s2_100_2_epoch60.pth'
#     model_path = 'model/R20FF_H_bigbatch/ReModel_L2_interRNNrand_OUT1_121_s4_100_3_epoch190.pth'
#     model_path = 'model/R20FF_H_bigbatch/ReModel_L2_interRNNrand_OUT1_121_s8_100_1_epoch180.pth'


    sparse = 1
    delay_length = 2
    
    # Bail out (returning None) when the checkpoint file is missing.
    if os.path.exists(model_path):
        print(model_path)
    else:
        print("Not exist")
        return

    # train_x,train_y = mkOwnDataSet(training_size,data_length)
    train_x1,train_x2,train_y1,train_y2 = mkOwnDataSet_auto(training_size,delay_length)
    # NOTE(review): test_x is never used below.
    test_x = mkOwnDataSet(test_size,data_length)

    rnn = MyLSTM_RNN(inputsize, hidden_size, outputsize, batch_size, sparse)
#     rnn = MyLSTM_RNN_uniHPC(inputsize, hidden_size, outputsize, batch_size, sparse)
#     rnn = MyLSTM_feedforward(inputsize, hidden_size, outputsize, batch_size, sparse)
    
    rnn.load_state_dict(torch.load(model_path))
    
    # Choose which of the four training sets (x1, x2, y1, y2) drives the rollout.
    pattern = 1
    if  pattern == 1:
        data = mkOwnRandomBatch(train_x1, train_x1, batch_size).float()
    elif pattern == 2:
        data = mkOwnRandomBatch(train_x2, train_x2, batch_size).float()
    elif pattern == 3:
        data = mkOwnRandomBatch(train_y1, train_y1, batch_size).float()                
    else:
        data = mkOwnRandomBatch(train_y2, train_y2, batch_size).float()
    
    # Raw parameter tensors; only consumed by the commented-out Culc_gate calls.
    params = []
    for p in rnn.parameters():
        params.append(p.data)
    
    # Per-step recordings. hidden[0][0] / hidden[1][0] (PFC / HPC element 0) are
    # stored for the whole batch, whereas hidden[2][0] keeps only batch row 0 of the
    # Re state -- hence `np.array(PFCstate)[:,0]` vs plain `np.array(Restate)` below.
    traj = []
    PFCstate = []
    HPCstate = []
    Restate = []
    Gate_states = []
    hidden = rnn.initHidden_rand()
    # Phase 1: feed data[k] as the input for the first data_limit steps.
    # NOTE(review): assumes `data` has at least 600 entries along dim 0 -- confirm.
    data_limit = 600
    for k in range(data_limit):
            #print(data[k].shape)
            output,hidden = rnn(data[k],hidden)
            traj.append(output.tolist())
            PFCstate.append(hidden[0][0].tolist())
            HPCstate.append(hidden[1][0].tolist())
            Restate.append(hidden[2][0].tolist())
#             Gate_states.append(Culc_gate(output,params,hidden))
#             Gate_states.append(Culc_gate_uniPFC(output,params,hidden))
    # Phase 2: closed loop -- the previous output becomes the next input.
    # `data.shape[0]*0+10` is always 10; the `*0` disables a data-length-dependent rollout.
    for k in range(data.shape[0]*0+10):
#             if output.tolist()[0][1]<0.05:
#                 hidden = rnn.noiseHidden_rand(hidden)
            output,hidden = rnn(output,hidden) 
            traj.append(output.tolist())
            PFCstate.append(hidden[0][0].tolist())
            HPCstate.append(hidden[1][0].tolist())
            Restate.append(hidden[2][0].tolist())
#             print(output)
#             Gate_states.append(Culc_gate(output,params,hidden))
#             Gate_states.append(Culc_gate_uniPFC(output,params,hidden))
    traj = torch.tensor(traj)
    pltdata = torch.squeeze(data).numpy()
    traj = torch.squeeze(traj).numpy()
    fig = plt.figure()
    print(pltdata.shape)
    # Batch element 0: target (dashed) vs generated trajectory in the 2-D output plane.
    plt.plot(pltdata[:,0,0],pltdata[:,0,1],"--")
    plt.plot(traj[:,0,0],traj[:,0,1])
    plt.show()
    
    # PCA of the HPC states (batch row 0).
    # NOTE(review): HPCfeature is only referenced by commented-out code.
    pca = PCA()
    # dfs = np.concatenate((np.array(PFCstate)[:,0],np.array(HPCstate)[:,0],np.array(Restate)),axis=1)
    # dfs = np.concatenate((np.array(PFCstate)[:,0],np.array(HPCstate)[:,0]),axis=1)
    dfs = np.array(HPCstate)[:,0]
#     dfs = np.array(Restate)
    pca.fit(dfs)
    HPCfeature = pca.transform(dfs)
    
    # PCA of the Re states; `feature` now holds the Re principal-component scores.
    pca = PCA()
    # dfs = np.concatenate((np.array(PFCstate)[:,0],np.array(HPCstate)[:,0],np.array(Restate)),axis=1)
    # dfs = np.concatenate((np.array(PFCstate)[:,0],np.array(HPCstate)[:,0]),axis=1)
#     dfs = np.array(HPCstate)[:,0]
    dfs = np.array(Restate)
    pca.fit(dfs)
    feature = pca.transform(dfs)
    
    # NOTE(review): pandas is not among the imports at the top of this file; this line
    # relies on `pd` being defined elsewhere in the notebook.
    print(pd.DataFrame(pca.explained_variance_ratio_, index=["PC{}".format(x + 1) for x in range(dfs.shape[1])]))
    
    # Correlation of |output-y difference| with |Re PC1 difference| between the
    # windows [0, datalen) and [2*datalen, 3*datalen).
    datalen = 120
    print(np.corrcoef(np.abs(traj[:datalen*1,0,1]-traj[datalen*2:datalen*3,0,1]),np.abs(feature[:datalen*1,0]-feature[datalen*2:datalen*3,0]))[0,1])
    
    # pick_delay is defined elsewhere; from its use below it presumably returns a list of
    # per-segment 2-D arrays (A: PFC states of batch row 0, B: Re PC scores) -- confirm.
    delaysA = pick_delay(traj[:,0], np.array(PFCstate)[:,0])
#     delaysA = pick_delay(traj[:,0], np.array(HPCfeature)[:])
    delaysB = pick_delay(traj[:,0], feature[:])

    PCAnum = 0  # Re principal component compared against every PFC unit.
    # NOTE(review): 20 presumably equals hidden_size (the PFC width) -- confirm.
    for i in range(20):
        plt.figure()
        # Skips the first and last segments. NOTE(review): `data` shadows the input batch here.
        for k,data in enumerate(delaysA[1:-1]):
#             plt.plot(data[:,i])
#             plt.plot(np.array(delaysB[1+k])[:,0])
            # Detrend both signals. moving_average returns len+4 samples, so [2:-2]
            # re-aligns the 5-point average with the raw series.
            movingA = moving_average(data[:,i])
            movingB = moving_average(np.array(delaysB[1+k])[:,PCAnum])
            fracA = data[:,i] - movingA[2:-2]
            fracB = np.array(delaysB[1+k])[:,PCAnum]- movingB[2:-2]
#             print(np.corrcoef(fracA,fracB)[0,1])
            fracA[np.abs(fracA)<0.001] = 0  # treat tiny residuals as having zero sign
            print("No"+str(k))
            # Euclidean distance between the sign patterns, comparing fracA[4+j] with
            # fracB[t+j], i.e. lags t-4 = -4..+3. Assumes both segments have equal length.
            for t in range(8):
#                 print(np.corrcoef(np.sign(fracA)[:-3],np.sign(fracB)[k:k-3])[0,1])
                print(math.dist(np.sign(fracA)[4:-4],np.sign(fracB)[t:t-8]))
#             print(data.shape)
        # NOTE(review): these sit outside the segment loop, so only the last segment's
        # residuals are plotted; fracA/fracB are undefined if delaysA has < 3 segments.
        plt.plot((fracA))
        plt.plot((fracB))
        plt.title("neuron#"+str(i+1))
    
    # data_limit is re-used from here on as the comparison-window length (120 steps).
    data_limit = 120
    pca = PCA()
    dfs = np.array(PFCstate)[:,0]
#     dfs = np.concatenate((np.array(PFCstate)[:,0],np.array(PFCstate_noise)[:,0]),axis=0)
#     dfs = np.concatenate((np.array(HPCstate)[:,0],np.array(HPCstate_noise)[:,0]),axis=0)
    print(dfs.shape)
    pca.fit(dfs)
    feature = pca.transform(dfs)
    ratios = pca.explained_variance_ratio_  # NOTE(review): only used by commented-out code.
    print("PFC correlation")
#     print(np.corrcoef(np.abs(traj[:,0,0]-traj_noise[:,0,0]),np.abs(feature[:100,0]-feature[100:,0])))
#     print(np.corrcoef(np.abs(traj[:,0,0]-traj_noise[:,0,0]),np.abs(feature[:100,1]-feature[100:,1])))
#     print(np.corrcoef(np.abs(traj[:,0,0]-traj_noise[:,0,0]),np.abs(feature[:100,2]-feature[100:,2])))
#     PFC_a = np.corrcoef(np.abs(traj[:,0,0]-traj_noise[:,0,0]),np.abs(feature[:data_limit,0]-feature[data_limit:,0]))[0,1]
#     PFC_b = np.corrcoef(np.abs(traj[:,0,0]-traj_noise[:,0,0]),np.abs(feature[:data_limit,1]-feature[data_limit:,1]))[0,1]
#     PFC_c = np.corrcoef(np.abs(traj[:,0,0]-traj_noise[:,0,0]),np.abs(feature[:data_limit,2]-feature[data_limit:,2]))[0,1]
#     PFC_a = np.corrcoef(np.abs(traj[:data_limit,0,0]-traj[data_limit:data_limit*2,0,0]),np.abs(feature[:data_limit,0]-feature[data_limit:data_limit*2,0]))[0,1]
#     PFC_b = np.corrcoef(np.abs(traj[:data_limit,0,0]-traj[data_limit:data_limit*2,0,0]),np.abs(feature[:data_limit,1]-feature[data_limit:data_limit*2,1]))[0,1]
#     PFC_c = np.corrcoef(np.abs(traj[:data_limit,0,0]-traj[data_limit:data_limit*2,0,0]),np.abs(feature[:data_limit,2]-feature[data_limit:data_limit*2,2]))[0,1]
    # Correlation of |output-y difference| between windows [0,120) and [120,240) with
    # the matching |PFC PC difference| for PCs 1, 2 and 3.
    PFC_a = np.corrcoef(np.abs(traj[:data_limit,0,1]-traj[data_limit:data_limit*2,0,1]),np.abs(feature[:data_limit,0]-feature[data_limit:data_limit*2,0]))[0,1]
    PFC_b = np.corrcoef(np.abs(traj[:data_limit,0,1]-traj[data_limit:data_limit*2,0,1]),np.abs(feature[:data_limit,1]-feature[data_limit:data_limit*2,1]))[0,1]
    PFC_c = np.corrcoef(np.abs(traj[:data_limit,0,1]-traj[data_limit:data_limit*2,0,1]),np.abs(feature[:data_limit,2]-feature[data_limit:data_limit*2,2]))[0,1]
#     PFC_a = np.corrcoef(np.abs(traj[:,0,0]-traj_noise[:,0,0]),np.abs(feature[:100,0]-feature[100:,0]))[0,1]*ratios[0]
#     PFC_b = np.corrcoef(np.abs(traj[:,0,0]-traj_noise[:,0,0]),np.abs(feature[:100,1]-feature[100:,1]))[0,1]*ratios[1]
#     PFC_c = np.corrcoef(np.abs(traj[:,0,0]-traj_noise[:,0,0]),np.abs(feature[:100,2]-feature[100:,2]))[0,1]*ratios[2]

#     plt.figure()
#     plt.plot(np.abs(feature[:100,0]-feature[100:,0]))
#     plt.plot(np.abs(feature[:100,1]-feature[100:,1]))
#     plt.plot(np.abs(feature[:100,2]-feature[100:,2]))
#     plt.show()
    
#     plt.figure()
#     plt.plot(np.abs(feature[:100,0]),color=colors[-3])
#     plt.plot(np.abs(feature[100:,0]),color=colors[-4])
#     plt.show()

#     print(pd.DataFrame(pca.explained_variance_ratio_, index=["PC{}".format(x + 1) for x in range(dfs.shape[1])]))
    
    # Same analysis on the HPC states (batch row 0).
    pca = PCA()
#     dfs = np.array(PFCstate)[:,0]
#     dfs = np.concatenate((np.array(PFCstate)[:,0],np.array(PFCstate_noise)[:,0]),axis=0)
    dfs = np.array(HPCstate)[:,0]
    print(dfs.shape)
    pca.fit(dfs)
    feature = pca.transform(dfs)
    ratios = pca.explained_variance_ratio_  # NOTE(review): only used by commented-out code.
    print("HPC correlation")
#     print(np.corrcoef(np.abs(traj[:,0,0]-traj_noise[:,0,0]),np.abs(feature[:100,0]-feature[100:,0])))
#     print(np.corrcoef(np.abs(traj[:,0,0]-traj_noise[:,0,0]),np.abs(feature[:100,1]-feature[100:,1])))
#     print(np.corrcoef(np.abs(traj[:,0,0]-traj_noise[:,0,0]),np.abs(feature[:100,2]-feature[100:,2])))
#     HPC_a = np.corrcoef(np.abs(traj[:,0,0]-traj_noise[:,0,0]),np.abs(feature[:data_limit,0]-feature[data_limit:,0]))[0,1]
#     HPC_b = np.corrcoef(np.abs(traj[:,0,0]-traj_noise[:,0,0]),np.abs(feature[:data_limit,1]-feature[data_limit:,1]))[0,1]
#     HPC_c = np.corrcoef(np.abs(traj[:,0,0]-traj_noise[:,0,0]),np.abs(feature[:data_limit,2]-feature[data_limit:,2]))[0,1]
#     HPC_a = np.corrcoef(np.abs(traj[:data_limit,0,0]-traj[data_limit:data_limit*2,0,0]),np.abs(feature[:data_limit,0]-feature[data_limit:data_limit*2,0]))[0,1]
#     HPC_b = np.corrcoef(np.abs(traj[:data_limit,0,0]-traj[data_limit:data_limit*2,0,0]),np.abs(feature[:data_limit,1]-feature[data_limit:data_limit*2,1]))[0,1]
#     HPC_c = np.corrcoef(np.abs(traj[:data_limit,0,0]-traj[data_limit:data_limit*2,0,0]),np.abs(feature[:data_limit,2]-feature[data_limit:data_limit*2,2]))[0,1]
    HPC_a = np.corrcoef(np.abs(traj[:data_limit,0,1]-traj[data_limit:data_limit*2,0,1]),np.abs(feature[:data_limit,0]-feature[data_limit:data_limit*2,0]))[0,1]
    HPC_b = np.corrcoef(np.abs(traj[:data_limit,0,1]-traj[data_limit:data_limit*2,0,1]),np.abs(feature[:data_limit,1]-feature[data_limit:data_limit*2,1]))[0,1]
    HPC_c = np.corrcoef(np.abs(traj[:data_limit,0,1]-traj[data_limit:data_limit*2,0,1]),np.abs(feature[:data_limit,2]-feature[data_limit:data_limit*2,2]))[0,1]
#     HPC_a = np.corrcoef(np.abs(traj[:,0,0]-traj_noise[:,0,0]),np.abs(feature[:100,0]-feature[100:,0]))[0,1]*ratios[0]
#     HPC_b = np.corrcoef(np.abs(traj[:,0,0]-traj_noise[:,0,0]),np.abs(feature[:100,1]-feature[100:,1]))[0,1]*ratios[1]
#     HPC_c = np.corrcoef(np.abs(traj[:,0,0]-traj_noise[:,0,0]),np.abs(feature[:100,2]-feature[100:,2]))[0,1]*ratios[2]
    # Per-step Euclidean distance between the two windows in the space of HPC PCs 1-3,
    # then its correlation with the |output-y difference|.
    HPC_diff = np.sqrt(np.power(feature[:data_limit,0]-feature[data_limit:data_limit*2,0],2) + np.power(feature[:data_limit,1]-feature[data_limit:data_limit*2,1],2) + np.power(feature[:data_limit,2]-feature[data_limit:data_limit*2,2],2))
    HPC_d = np.corrcoef(np.abs(traj[:data_limit,0,1]-traj[data_limit:data_limit*2,0,1]),HPC_diff)[0,1]
#     HPC_d = np.corrcoef(np.abs(traj[:data_limit,0,1]-traj[data_limit:data_limit*2,0,1]),HPC_diff)[0,1]
    
#     plt.figure()
#     plt.plot(np.abs(feature[:100,0]-feature[100:,0]))
#     plt.plot(np.abs(feature[:100,1]-feature[100:,1]))
#     plt.plot(np.abs(feature[:100,2]-feature[100:,2]))
#     plt.show()

#     print(pd.DataFrame(pca.explained_variance_ratio_, index=["PC{}".format(x + 1) for x in range(dfs.shape[1])]))

    # Best single-PC PFC correlation, best single-PC HPC correlation, 3-PC HPC distance correlation.
    print(np.max([PFC_a,PFC_b,PFC_c]),np.max([HPC_a,HPC_b,HPC_c]),HPC_d)
        
                              
if __name__ == '__main__':
    # Run the analysis once, with num=1, and keep whatever main() returns.
    features = [main(run_id) for run_id in range(1, 2)]

In [None]:
########## stft include ######################


#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Created on Fri Feb 19 10:54:37 2021

@author: munenori
"""

# Default left margin of new figures, as a fraction of figure width. Presumably this
# leaves room for the large tick labels used in the plots of this cell.
# `plt` comes from an earlier cell; its import is commented out below.
plt.rcParams.update({"figure.subplot.left": 0.15})

import numpy as np
# import os
# import torch
# import torch.nn as nn
# import torch.nn.functional as F
# import torch.optim as optim
# import time
# import math
# import matplotlib.pyplot as plt
# import matplotlib.animation as animation
# import sklearn 
# from sklearn.decomposition import PCA
# from sklearn.cluster import KMeans
# from mpl_toolkits.mplot3d import Axes3D
# import pandas as pd
from scipy.fft import fft 
from scipy import signal


%matplotlib notebook


    

def moving_average(data, window=5):
    """Return the ``'full'`` convolution of *data* with a length-*window* box filter.

    The result has ``len(data) + window - 1`` samples, not ``len(data)``. Callers in
    this file slice ``[2:-2]`` to re-align the default 5-point average with the raw
    series. In general, for odd *window*, trim ``(window - 1) // 2`` samples at each
    end. Samples near either end are computed against implicit zero padding and are
    therefore biased toward zero.

    Args:
        data: 1-D sequence of numbers.
        window: box-filter length (default 5, the previously hard-coded value).

    Raises:
        ValueError: if *window* is smaller than 1.
    """
    if window < 1:
        raise ValueError("window must be a positive integer, got {}".format(window))
    kernel = np.ones(window) / window
    return np.convolve(data, kernel)

def main(num):
    training_size = 100
    test_size = 1
    epochs_num = 10
    hidden_size = 20
    batch_size = 10
    data_length = 100
    inputsize = 2
    outputsize = 2
    # model_path = 'model/ReModel_L2_interRNNrand_Reinh_AddRe_OUT5_H121_s'+str(num)+'_100_2_2.pth'
    # model_path = 'model/ReModel_L2_interRNNrand_Reinh_AddRe_OUT5_H121_s2_100_1_2.pth'
    model_path = 'model/R20_H_bigbatch/ReModel_L2_interRNNrand_OUT1_121_s2_100_3_epoch150.pth'
#     model_path = 'model/R20_H_stopinit_bigbatch/ReModel_L2_interRNNrand_OUT1_121_s3_100_1_epoch295.pth'
#     model_path = 'model/R20_H_uniPFC_bigbatch/ReModel_L2_interRNNrand_OUT1_121_s10_100_2_epoch170.pth'
#     model_path = 'model/R20_H_uniHPC_bigbatch/ReModel_L2_interRNNrand_OUT1_121_s1_100_2_epoch160.pth'
#     model_path = 'model/R20_uniPFC_H/ReModel_L2_interRNNrand_OUT1_uniPFC_121H_s5_100_2_epoch85.pth'
#     model_path = 'model/R20_uniHPC_H/ReModel_L2_interRNNrand_OUT1_121H_s2_100_1_epoch100.pth'
#     model_path = 'model/R20_H_transfer2B/ReModel_L2_interRNNrand_OUT1_transfers14_121_s2_100_2_epoch60.pth'
#     model_path = 'model/R20FF_H_bigbatch/ReModel_L2_interRNNrand_OUT1_121_s4_100_3_epoch190.pth'
#     model_path = 'model/R20FF_H_bigbatch/ReModel_L2_interRNNrand_OUT1_121_s8_100_1_epoch180.pth'

    model_path_noise = 'model/R20_H_bigbatch/ReModel_L2_interRNNrand_OUT1_121_s2_100_3_epoch180.pth'


    sparse = 1
    delay_length = 2
    
    if os.path.exists(model_path):
        print(model_path)
    else:
        print("Not exist")
        return

    # train_x,train_y = mkOwnDataSet(training_size,data_length)
    train_x1,train_x2,train_y1,train_y2 = mkOwnDataSet_auto(training_size,delay_length)
#     test_x = mkOwnDataSet(test_size,data_length)

    rnn = MyLSTM_RNN(inputsize, hidden_size, outputsize, batch_size, sparse)
#     rnn = MyLSTM_RNN_uniHPC(inputsize, hidden_size, outputsize, batch_size, sparse)
#     rnn = MyLSTM_feedforward(inputsize, hidden_size, outputsize, batch_size, sparse)
    
    rnn.load_state_dict(torch.load(model_path))
    
    pattern = 4
    if  pattern == 1:
        data = mkOwnRandomBatch(train_x1, train_x1, batch_size).float()
    elif pattern == 2:
        data = mkOwnRandomBatch(train_x2, train_x2, batch_size).float()
    elif pattern == 3:
        data = mkOwnRandomBatch(train_y1, train_y1, batch_size).float()                
    else:
        data = mkOwnRandomBatch(train_y2, train_y2, batch_size).float()
    
    params = []
    for p in rnn.parameters():
        params.append(p.data)
        

    traj = []
    PFCstate = []
    HPCstate = []
    Restate = []
    Gate_states = []
    hidden = rnn.initHidden_rand()
    data_limit = 40
    est_length = 2
    for k in range(data_limit):
            #print(data[k].shape)
            output,hidden = rnn(data[k],hidden)
            traj.append(output.tolist())
            PFCstate.append(hidden[0][0].tolist())
            HPCstate.append(hidden[1][0].tolist())
            Restate.append(hidden[2][0].tolist())
#             Gate_states.append(Culc_gate(output,params,hidden))
#             Gate_states.append(Culc_gate_uniPFC(output,params,hidden))
    for k in range(data.shape[0]*est_length+10):
#             if output.tolist()[0][1]<0.05:
#                 hidden = rnn.noiseHidden_rand(hidden)
            output,hidden = rnn(output,hidden) 
            traj.append(output.tolist())
            PFCstate.append(hidden[0][0].tolist())
            HPCstate.append(hidden[1][0].tolist())
            Restate.append(hidden[2][0].tolist())
#             print(output)
#             Gate_states.append(Culc_gate(output,params,hidden))
#             Gate_states.append(Culc_gate_uniPFC(output,params,hidden))
    traj = torch.tensor(traj)
    pltdata = torch.squeeze(data).numpy()
    traj = torch.squeeze(traj).numpy()
    fig = plt.figure()
    print(pltdata.shape)
    plt.plot(pltdata[:,0,0],pltdata[:,0,1],"--")
    plt.plot(traj[:,0,0],traj[:,0,1])
    plt.show()
    
    
    rnn.load_state_dict(torch.load(model_path_noise))
    dividenum = int(np.array(PFCstate)[0:,0].shape[0]/2)
    traj_noise = []
    PFCstate_noise = []
    HPCstate_noise = []
    Restate_noise = []
    hidden = rnn.initHidden_rand()
    for k in range(data_limit):
            #print(data[k].shape)
            output,hidden = rnn(data[k],hidden)
#             output,hidden = rnn(init_point,hidden)
            traj_noise.append(output.tolist())
            PFCstate_noise.append(hidden[0][0].tolist())
            HPCstate_noise.append(hidden[1][0].tolist())
            Restate_noise.append(hidden[2][0].tolist())
            #Gate_states.append(Culc_gate(output,params,hidden))
    for k in range(data.shape[0]*est_length+10):
#             hidden = rnn.noiseHidden_rand(hidden)
            output,hidden = rnn(output,hidden) 
#             hidden = rnn.noiseHidden_dis(hidden, np.array(Restate)[k+data_limit], np.array(Restate)[dividenum+k+data_limit])
#             hidden = rnn.noiseHidden_rand(hidden)
            traj_noise.append(output.tolist())
            PFCstate_noise.append(hidden[0][0].tolist())
            HPCstate_noise.append(hidden[1][0].tolist())
            Restate_noise.append(hidden[2][0].tolist())
            #Gate_states.append(Culc_gate(output,params,hidden))
    traj_noise = torch.tensor(traj_noise)
    traj_noise = torch.squeeze(traj_noise).numpy()
    
    
    
    # MakeAnimation2(pltdata[:,0,0],pltdata[:,0,1], traj[:,0,0], traj[:,0,1], data_limit)
    # MakeAnimation_img(np.array(PFCstate)[:,0],"PFC")
    # MakeAnimation_img(np.array(HPCstate)[:,0],"HPC")
    # MakeAnimation_img(np.array(Restate),"Re")
    #MakeAnimation_testdata(pltdata[:,0,0],pltdata[:,0,1])
    # MakeAnimation_data(pltdata[:,0,0],pltdata[:,0,1],data_limit)

    # for n, p in rnn.named_parameters():
    #         if n == "PFC.weight_ih":
    #             PFC_w = np.array(p.data)
    #         if n == "HPC.weight_ih":
    #             HPC_w = np.array(p.data)
    #         if n == "Re.weight_ih":
    #             Re_w = np.array(p.data)
                
    # # for n, p in rnn.named_parameters():
    # #         if n == "PFC.weight_hh":
    # #             PFC_w = np.array(p.data)
    # #         if n == "HPC.weight_hh":
    # #             HPC_w = np.array(p.data)
    # #         if n == "Re.weight_hh":
    # #             Re_w = np.array(p.data)
                
    # fig2 = plt.figure()
    # ax1 = fig2.add_subplot(131)
    # ax2 = fig2.add_subplot(132)
    # axre = fig2.add_subplot(133)
    
    # ax1.imshow(PFC_w,cmap="coolwarm")
    # ax2.imshow(HPC_w,cmap="coolwarm")
    # axre.imshow(Re_w,cmap="coolwarm")
    # ax1.set_title("max = {:.2f},min = {:.2f}".format(np.max(PFC_w),np.min(PFC_w)))
    # ax2.set_title("max = {:.2f},min = {:.2f}".format(np.max(HPC_w),np.min(HPC_w)))
    # axre.set_title("max = {:.2f},min = {:.2f}".format(np.max(Re_w),np.min(Re_w)))

    
    # fig3 = plt.figure()
    # ax3 = fig3.add_subplot(131)
    # ax4 = fig3.add_subplot(132)
    # axre2 = fig3.add_subplot(133)
    # ax3.imshow(np.corrcoef(np.array(PFCstate)[:,0]))
    # ax4.imshow(np.corrcoef(np.array(HPCstate)[:,0]))
    # axre2.imshow(np.corrcoef(np.array(Restate)))
    # ax3.set_title("PFC")   
    # ax4.set_title("HPC")  
    # axre2.set_title("Re")  
    
    # Replot = plt.figure()
    # axre2 = Replot.add_subplot(111)
    # axre2.imshow(np.corrcoef(np.array(Restate)))
    # axre2.set_title("Re")  
    
#     pca = PCA()
#     # dfs = np.concatenate((np.array(PFCstate)[:,0],np.array(HPCstate)[:,0],np.array(Restate)),axis=1)
#     # dfs = np.concatenate((np.array(PFCstate)[:,0],np.array(HPCstate)[:,0]),axis=1)
#     dfs = np.array(HPCstate)[:,0]
# #     dfs = np.array(Restate)
#     pca.fit(dfs)
#     feature = pca.transform(dfs)
    # print(pd.DataFrame(feature, columns=["PC{}".format(x + 1) for x in range(dfs.shape[1])]).head())
    # print(pd.DataFrame(pca.explained_variance_ratio_, index=["PC{}".format(x + 1) for x in range(dfs.shape[1])]))
    
    # print(pd.DataFrame(pca.components_, columns=["Hidden{}".format(x + 1) for x in range(dfs.shape[1])], index=["PC{}".format(x + 1) for x in range(dfs.shape[1])]))
    
    # pred = KMeans(n_clusters=2).fit_predict(feature)
    
    # plt.figure(figsize=(6, 6))
    # plt.scatter(feature[:, 0], feature[:, 1], alpha=0.8)
    # #plt.scatter(feature[100:200, 0], feature[100:200, 1], alpha=0.8)
    # plt.scatter(feature[0, 0], feature[0, 1], c="r", alpha=0.8)
    # plt.grid()
    # plt.xlabel("PC1")
    # plt.ylabel("PC2")
    # for i in range(np.min([200,data_limit+k])):
    #     plt.annotate(i,(feature[i, 0], feature[i, 1]))
    # plt.show()
    
    # # MakeAnimation_attracter(feature[:, 0], feature[:, 1])
    
#     fig3d = plt.figure()
#     ax3d = Axes3D(fig3d)
#     ax3d.plot(feature[:, 0], feature[:, 1], feature[:, 2], alpha=0.3)
    
    # fig_place = plt.figure()
    # place3d = Axes3D(fig_place)
    # place3d.plot(traj[:,0,0], traj[:,0,1], np.array(HPCstate)[:,0,0])
    
    # fig4 = plt.figure(figsize=(10,5))
    # ax5 = fig4.add_subplot(311)
    # ax6 = fig4.add_subplot(312)
    # axRe3 = fig4.add_subplot(313)
    # #ax5.plot(np.array(PFCstate)[:,0,pred==1])
    # ax5.imshow(np.array(PFCstate)[:200,0,:].T)
    # # ax5.plot(np.average(np.array(PFCstate)[:,0,pred==0],axis=1))
    # # ax5.plot(np.average(np.array(PFCstate)[:,0,pred==1],axis=1))
    # # ax5.plot(np.average(np.array(HPCstate)[:,0],axis=1))
    # ax6.imshow(np.array(HPCstate)[:200,0,:].T)
    # axRe3.imshow(np.array(Restate)[:200,:].T)
    # ax5.set_title("PFC")   
    # ax6.set_title("HPC")  
    # axRe3.set_title("Re") 
    # print(pred==0)
    
    # # fig5 = plt.figure(figsize=(5,5))
    # # plt.plot(np.average(np.array(PFCstate)[:,0,pred==0],axis=1),np.average(np.array(PFCstate)[:,0,pred==1],axis=1))
    # # plt.show()
    # #MakeAnimation_attracter(np.average(np.array(PFCstate)[:,0,pred==0],axis=1),np.average(np.array(PFCstate)[:,0,pred==1],axis=1))
    
    
    # fig6 = plt.figure(figsize=(10,20))
    # plt.imshow(np.array(Gate_states)[:200,0,0:].T)
    # print(np.max(np.array(Gate_states)[:200,0,0:].T),np.min(np.array(Gate_states)[:,0,0:].T),np.average(np.array(Gate_states)[:,0,0:].T))
    # plt.show()
    # fig7 = plt.figure(figsize=(10,20))
    # plt.imshow(np.array(Gate_states)[:200,1,0:].T)
    # print(np.max(np.array(Gate_states)[:200,1,0:].T),np.min(np.array(Gate_states)[:,1,0:].T),np.average(np.array(Gate_states)[:,1,0:].T))
    # plt.show()
    
    # plt.figure()
    # plt.plot(np.array(Gate_states)[:,1,14])
    # plt.plot(np.array(Gate_states)[:,1,34])
    # plt.plot(np.array(Gate_states)[:,1,54])
    # plt.plot(np.array(Gate_states)[:,1,74])
    # plt.ylim(0,1)
    # plt.show()
    
    # plt.figure()
    # plt.plot(np.array(Gate_states)[:,1,13])
    # plt.plot(np.array(Gate_states)[:,1,33])
    # plt.plot(np.array(Gate_states)[:,1,53])
    # plt.plot(np.array(Gate_states)[:,1,73])
    # plt.ylim(0,1)
    # plt.show()
    
    
    # plt.figure(figsize=(20,20))
    # plt.plot(np.array(HPCstate)[:200,0,:10])
    # plt.show()
    
    #np.save("right_traj.npy",dfs)
    #np.save("left_traj.npy",dfs)
    
    # fig5 = plt.figure()
    # plt.hist(PFC_w[PFC_w.nonzero()],bins=400,range=(-2,2))
    # fig6 = plt.figure()
    # plt.hist(HPC_w[HPC_w.nonzero()],bins=400,range=(-2,2))    
    # fig7 = plt.figure()
    # plt.hist(Re_w[Re_w.nonzero()],bins=400,range=(-2,2))   
    
    
    # pca = PCA()
    # dfs = Re_w
    # pca.fit(dfs)
    # feature = pca.transform(dfs)
    # pred = KMeans(n_clusters=2).fit_predict(feature)
    
    # plt.figure(figsize=(6, 6))
    # plt.scatter(feature[:, 0], feature[:, 1], alpha=0.8)
    # plt.grid()
    # plt.xlabel("PC1")
    # plt.ylabel("PC2")

    # plt.show()
    
#     plt.figure()
#     plt.plot(traj[:,0,0])
#     plt.plot(traj[:,0,1])
    
    delays = pick_traj(traj[:,0], np.array(traj)[:])[3:-1]
    bifur = np.array([])
    for data in delays:
        check_bifur = np.argmax(np.abs(np.array(data)[:,0,0] - 0.5))
        bifur = np.append(bifur,data[check_bifur,:,0])
    if bifur.shape[0] == 0:
        pass
    plt.figure()
    plt.plot(traj[:,0,0])
    plt.plot(traj[:,0,1])
    print(np.var(bifur),np.abs(np.median(bifur)-np.mean(bifur)))
    print(bifur)
    
    pca = PCA()
    # dfs = np.concatenate((np.array(PFCstate)[:,0],np.array(HPCstate)[:,0],np.array(Restate)),axis=1)
    # dfs = np.concatenate((np.array(PFCstate)[:,0],np.array(HPCstate)[:,0]),axis=1)
#     dfs = np.array(HPCstate)[:,0]
    dfs = np.array(Restate)
    pca.fit(dfs)
    feature = pca.transform(dfs)
    
    print(pd.DataFrame(pca.explained_variance_ratio_, index=["PC{}".format(x + 1) for x in range(dfs.shape[1])]))
    
    datalen = 120
#     print(np.corrcoef(np.abs(traj[:datalen*1,0,1]-traj[datalen*2:datalen*3,0,1]),np.abs(feature[:datalen*1,0]-feature[datalen*2:datalen*3,0]))[0,1])
    
#     delays = pick_traj(traj_noise[:,0], np.array(PFCstate_noise)[:,0])
    delays = pick_traj(traj[:,0], np.array(PFCstate)[:,0])
#     delays = pick_traj(traj_noise[:,0], np.array(HPCstate_noise)[:,0])
#     delays = pick_traj(traj[:,0], np.array(HPCstate)[:,0])
#     delays = pick_traj(traj_noise[:,0], np.array(Restate_noise)[:])
#     delays = pick_traj(traj[:,0], np.array(Restate)[:])
#     delays = pick_delay(traj[:,0], np.array(Restate)[:])
#     delays = pick_traj(traj_noise[:,0], np.array(traj_noise)[:])[3:-1]
#     delays = pick_traj(traj[:,0], np.array(Gate_states)[:,0,0:20]-np.array(Gate_states)[:,0,20:40])
#     delays = pick_traj(traj[:,0], np.array(Gate_states)[:,0,0:20])
#     delays = pick_delay(traj[:,0], np.array(Gate_states)[:,1])
#     delays = pick_traj(traj_noise[:,0], np.array(Gate_states_noise)[:,1])
#     delays = pick_delay(traj_noise[:,0], np.array(Gate_states_noise)[:,0])
#     delays = pick_traj(traj[:,0], feature[:])
#     delays = pick_delay(traj[:,0], feature[:])

    bifur = np.array([])
    for i in range(20):
        plt.figure()
        for data in delays[1:-1]:
            avedata = moving_average(data[:,0+i])
#             plt.plot(avedata[2:-2],alpha=0.5)
            plt.plot(data[:,0+i]-avedata[2:-2],alpha=0.5)
#             plt.plot(np.array(data)[:,0+i],alpha=0.5)
#             plt.plot(np.array(data)[:,0+i]-np.array(delays[1:-1][0])[:,0+i],alpha=0.5)
#             plt.plot(np.log(data)[:,0+i],alpha=0.5)
#             print(len(data),len(avedata))
        plt.title("neuron#"+str(i+1))
#     print(np.var(bifur),np.median(bifur)-np.mean(bifur))

#     for i in range(20):
#         plt.figure()
#         delays = delays[1:-1]
#         plt.plot(np.array(delays)[,:,0+i],alpha=0.5)
#         plt.title("neuron#"+str(i+1))
# #     print(np.var(bifur),np.median(bifur)-np.mean(bifur))

    for i in range(3):
        data_len = 1000
        delays_samelen = []
        result = 0
        for data in delays[1:-1]:
            data_len = np.min([data_len,len(data)])
        for data in delays[1:-1]:
            delays_samelen.append(data[:data_len])
        delays_samelen = np.array(delays_samelen)
        for k in range(data_len):
            result += np.var(delays_samelen[:,k,i])
        result /= data_len
        print(result)
#     print(np.var(bifur),np.median(bifur)-np.mean(bifur))


    pca = PCA()
    dfs = np.array(PFCstate)[0:,0]
#     dfs = np.array(Restate)[0:]
    pca.fit(dfs)
    PFCfeature = pca.transform(dfs)
    print(pd.DataFrame(pca.explained_variance_ratio_, index=["PC{}".format(x + 1) for x in range(dfs.shape[1])]))

    pca = PCA()
#     dfs = np.array(PFCstate)[0:,0]
    dfs = np.array(Restate)[0:]
    pca.fit(dfs)
    Refeature = pca.transform(dfs)
    print(pd.DataFrame(pca.explained_variance_ratio_, index=["PC{}".format(x + 1) for x in range(dfs.shape[1])]))
    
#     plt.figure()
    for i in range(20):
        plt.figure()
#         plt.ylim(-1,1)
#         plt.plot(traj[:,0,0],alpha=0.5)
#         plt.plot(traj_noise[:,0,1],alpha=0.5)
#         plt.plot(np.array(PFCstate_noise)[:400,0,17],alpha=0.5)
        for k in range(1):
#             plt.plot(np.array(HPCstate)[:data_len,k,i],alpha=0.8)
#             plt.plot(np.array(HPCstate_noise)[:,k,i],alpha=0.5)
#             plt.plot(np.array(HPCstate)[:,k,i]-moving_average(np.array(HPCstate)[2:-2,k,i])[:],alpha=0.3)
            plt.plot(np.array(HPCstate)[:data_len,k,i]-moving_average(np.array(HPCstate)[2:-2,k,i])[:data_len],alpha=0.8)
#             plt.plot(np.array(PFCstate)[:data_len,k,i],alpha=0.8)
#             plt.plot(np.array(PFCstate_noise)[:,k,i],alpha=0.5)
#             plt.plot(np.array(PFCstate)[:120,k,i],alpha=0.5)
#             plt.plot(np.array(PFCstate)[120:240,k,i],alpha=0.5)
#             plt.plot(np.array(PFCstate)[:120,k,i]-np.array(PFCstate)[120:240,k,i],alpha=0.5)
#             plt.plot(np.array(PFCstate)[:,k,i]-moving_average(np.array(PFCstate)[2:-2,k,i])[:],alpha=0.3)
#             plt.plot(np.array(PFCstate)[:data_len,k,i]-moving_average(np.array(PFCstate)[2:-2,k,i])[:data_len],alpha=0.8)
#             plt.plot(np.array(PFCstate_noise)[:120,k,i],alpha=0.5)
#             plt.plot(np.array(PFCstate_noise)[120:240,k,i],alpha=0.5)
#             plt.plot(np.array(PFCstate_noise)[:120,k,i]-np.array(PFCstate_noise)[120:240,k,i],alpha=0.5)
#             plt.plot(np.array(Restate)[:,i],alpha=0.5)
#             plt.plot(np.array(Restate_noise)[:,i],alpha=0.5)
#             plt.plot(np.array(Restate)[:,i]-np.array(Restate_noise)[:,i],alpha=0.5)
#             plt.plot(np.array(Restate)[:,i]-moving_average(np.array(Restate)[2:-2,i])[:],alpha=0.3)
#             plt.plot(moving_average(np.array(Restate_noise)[:,i]),alpha=0.5)
#             plt.plot(moving_average(np.array(Restate)[:,i])-moving_average(np.array(Restate_noise)[:,i]),alpha=0.5)
#             plt.plot(np.array(Restate)[:,i]-moving_average(np.array(Restate)[:,i])[4:-4],alpha=0.5)
#             plt.plot(np.array(Gate_states)[:,0,0+i],"o",alpha=0.5)
#             plt.plot(np.array(Gate_states)[:,1,0+i],"o",alpha=0.5)
#             plt.plot(np.array(Gate_states_noise)[:,1,0+i],alpha=0.5)
#             plt.plot(np.abs(np.array(Gate_states)[:,1,0+i]-np.array(Gate_states_noise)[:,1,0+i]),alpha=0.5)
#             plt.plot(np.log(np.array(Gate_states)[:,1,0+i]/np.array(Gate_states)[:,1,20+i]),alpha=0.5)
#             plt.plot(np.log(np.array(Gate_states_noise)[:,1,0+i]/np.array(Gate_states_noise)[:,1,20+i]),alpha=0.5)
#             plt.plot(np.log(np.array(Gate_states)[:,0,0+i]/np.array(Gate_states)[:,0,20+i]),alpha=0.5)
#             plt.plot(np.log(np.array(Gate_states_noise)[:,0,0+i]/np.array(Gate_states_noise)[:,0,20+i]),alpha=0.5)
#             plt.plot(np.array(PFCfeature)[:data_len,1],alpha=0.8)
#             plt.plot(np.array(PFCfeature)[:,0]-moving_average(np.array(PFCfeature)[2:-2,0]),alpha=0.5)
#             plt.plot(np.array(Refeature)[:data_len,0],alpha=0.8)
#             plt.plot(np.array(Refeature)[:,0]-moving_average(np.array(Refeature)[:,0])[2:-2],alpha=0.5)
            plt.plot(np.array(Refeature)[:data_len,0]-moving_average(np.array(Refeature)[:data_len,0])[2:-2],alpha=0.8)


        plt.title("neuron#"+str(i+1))
        plt.yticks(fontsize=26)
        plt.xticks(fontsize=26)    
        
#     fig = plt.figure()
#     fig1 = fig.add_subplot(111)
#     fig1.hist(np.array(Gate_states)[:,1,0:20].ravel(),bins=30,range=(-0.1,1.1),alpha=0.5,density=True)
#     plt.show()
    
    pca = PCA()
    dfs = np.array(HPCstate)[0:,0]
#     dfs = np.array(Restate)[0:]
    pca.fit(dfs)
    feature = pca.transform(dfs)
    
#     pred = KMeans(n_clusters=2).fit_predict(feature)

    fig3d = plt.figure()
    ax3d = Axes3D(fig3d)
#     ax3d.plot(feature[:100, 0], feature[:100, 1], feature[:100, 2], alpha=0.8)
#     ax3d.plot(feature[100:, 0], feature[100:, 1], feature[100:, 2], alpha=0.8)
    ax3d.plot(feature[:, 0], feature[:, 1], feature[:, 2], alpha=0.8)
    ax3d.plot(feature[0:1, 0], feature[0:1, 1], feature[0:1, 2],"o", alpha=1)
    plt.show()
    
    plt.figure()
    plt.plot(feature[:240, 0], feature[:240, 1], alpha=0.8)
    plt.show()
    
    print(pd.DataFrame(pca.explained_variance_ratio_, index=["PC{}".format(x + 1) for x in range(dfs.shape[1])]))
    
    pca = PCA()
    dfs = np.array(Restate_noise)[0:]
    pca.fit(dfs)
    Refeature_noise = pca.transform(dfs)
#     print(pd.DataFrame(pca.explained_variance_ratio_, index=["PC{}".format(x + 1) for x in range(dfs.shape[1])]))
    
    
########### FFT Zone ######################    
    
#     for i in range(20):
#         plt.figure()
#         for k in range(1):
#             data = np.array(Restate)[:,i]
# #             data_filt = data
# #             data_filt = highpass(data, samplerate, fp, fs, gpass, gstop)
# #             data_filt = np.array(Restate)[:,i]-moving_average(np.array(Restate)[2:-2,i])[:]
#             data_filt = np.array(Refeature)[:,0]-moving_average(np.array(Refeature)[:,0])[2:-2]
#             fourier = fft(data_filt)
#             freq = np.fft.fftfreq(len(fourier))*1
#             plt.plot(freq,np.abs(fourier))
# #             angle = np.angle(fourier)
# #             angle[np.abs(fourier)<0.01] = 0
# #             plt.plot(np.degrees(angle),"o")
            
# #             data = np.array(Restate_noise)[:,i]
# # #             data_filt = highpass(data, samplerate, fp, fs, gpass, gstop)
# #             data_filt = data
# #             fourier = fft(data_filt)
# #             freq = np.fft.fftfreq(len(fourier))*15
# #             plt.plot(np.log(np.abs(freq)),np.log(np.abs(fourier) ** 2),"o")
# # #             plt.plot(np.degrees(np.angle(fourier)))

#             data = np.array(PFCstate)[:,k,i]
#             data_filt = np.array(PFCstate)[:,k,i]-moving_average(np.array(PFCstate)[2:-2,k,i])[:]
#             fourier = fft(data_filt)
#             freq = np.fft.fftfreq(len(fourier))*1
#             plt.plot(freq,np.abs(fourier))
# #             angle = np.angle(fourier)
# #             angle[np.abs(fourier)<1] = 0
# #             plt.plot(np.degrees(angle),"o")

#         plt.title("neuron#"+str(i+1))
    
#     for i in range(20):
#         plt.figure()
#         for k in range(1):
#             data = np.array(PFCstate)[:,k,i]
#             data_filt = np.array(PFCstate)[:,k,i]-moving_average(np.array(PFCstate)[2:-2,k,i])[:]
#             fourier = fft(data_filt)
#             plt.plot(np.degrees(np.angle(fourier)))
#         plt.title("neuron#"+str(i+1))

    shift = 2
    seglen = 40
        

########### scipy.spectrogram ver ######################
#     for i in range(20):
#         plt.figure()
#         for k in range(1):
# #             data = np.array(PFCstate)[:,k,i]
#             data = np.array(Refeature)[:,0]-moving_average(np.array(Refeature)[:,0])[2:-2]
# #             fourier = np.fft.fft(data)
# #             freq = np.fft.fftfreq(len(fourier))*15
# #             plt.plot(freq,fourier)
#             freqs,times,sx1 = signal.spectrogram(data,fs=1,nperseg=seglen,noverlap=seglen-shift,detrend=False,scaling='spectrum',mode="psd")
#             plt.pcolormesh(times,freqs, (sx1))
#         plt.title("neuron#"+str(i+1))
#         plt.figure()
#         for k in range(1):
# #             data = np.array(Restate)[:,i]
#             data = np.array(PFCstate)[:,k,i]-moving_average(np.array(PFCstate)[2:-2,k,i])[:]
# #             data = np.array(HPCstate_noise)[:,k,i]
# #             fourier = np.fft.fft(data)
# #             freq = np.fft.fftfreq(len(fourier))*15
# #             plt.plot(freq,fourier)
#             freqs,times,sx2 = signal.spectrogram(data,fs=1,nperseg=seglen,noverlap=seglen-shift,detrend=False,scaling='spectrum',mode="psd")
#             plt.pcolormesh(times,freqs,(sx2))
# #             plt.pcolormesh(times,freqs, np.abs(np.log10(sx1)-np.log10(sx2)))
#         plt.title("neuron#"+str(i+1))
#         plt.figure()
#         for k in range(1):
# #             data = np.array(Restate)[:,i]
#             data = np.array(HPCstate)[:,k,i]-moving_average(np.array(HPCstate)[2:-2,k,i])[:]
# #             data = np.array(HPCstate_noise)[:,k,i]
# #             fourier = np.fft.fft(data)
# #             freq = np.fft.fftfreq(len(fourier))*15
# #             plt.plot(freq,fourier)
#             freqs,times,sx3 = signal.spectrogram(data,fs=1,nperseg=seglen,noverlap=seglen-shift,detrend=False,scaling='spectrum',mode="psd")
#             plt.pcolormesh(times,freqs,(sx3))
# #             plt.pcolormesh(times,freqs, np.abs(np.log10(sx1)-np.log10(sx2)))
#         plt.title("neuron#"+str(i+1))

#         plt.figure()
#         for k in range(1):
# #             data = np.array(Restate)[:,i]
#             data1 = np.array(Refeature)[:,0]-moving_average(np.array(Refeature)[:,0])[2:-2]
#             data2 = np.array(HPCstate)[:,k,i]-moving_average(np.array(HPCstate)[2:-2,k,i])[:]
#             freqs,sx3 = signal.coherence(data1,data2,fs=1,nperseg=seglen,noverlap=seglen-shift,detrend=False)
#             plt.plot(freqs,(sx3))
# #             plt.pcolormesh(times,freqs, np.abs(np.log10(sx1)-np.log10(sx2)))
#         plt.title("neuron#"+str(i+1))
    
#         plt.figure()
#         for k in range(1):
#             xsp = sx1*np.conjugate(sx2)
#             plt.pcolormesh(times,freqs,(xsp))
# #             plt.pcolormesh(times,freqs, np.abs(np.log10(sx1)-np.log10(sx2)))
#         plt.title("neuron#"+str(i+1))

# # ########### scipy.stft ver ######################
#     for i in range(20):
# #         plt.figure()
# #         for k in range(1):
# # #             data = np.array(PFCstate)[:,k,i]
# #             data = np.array(Refeature)[:,0]-moving_average(np.array(Refeature)[:,0])[2:-2]
# # #             fourier = np.fft.fft(data)
# # #             freq = np.fft.fftfreq(len(fourier))*15
# # #             plt.plot(freq,fourier)
# #             freqs,times,sx1 = signal.stft(data,fs=1,window="tukey",nperseg=seglen,noverlap=seglen-shift,detrend=False,boundary=None)
# #             plt.pcolormesh(times,freqs, np.array(np.abs(sx1)**2))
# # #             print(np.max((np.abs(sx1)**2)))
# #         plt.title("neuron#"+str(i+1))
# #         plt.figure()
# #         for k in range(1):
# # #             data = np.array(Restate)[:,i]
# #             data = np.array(PFCstate)[:,k,i]-moving_average(np.array(PFCstate)[2:-2,k,i])[:]
# # #             data = np.array(HPCstate_noise)[:,k,i]
# # #             fourier = np.fft.fft(data)
# # #             freq = np.fft.fftfreq(len(fourier))*15
# # #             plt.plot(freq,fourier)
# #             freqs,times,sx2 = signal.stft(data,fs=1,window="tukey",nperseg=seglen,noverlap=seglen-shift,detrend=False,boundary=None)
# #             plt.pcolormesh(times,freqs,np.array(np.abs(sx2)**2))
# # #             print(np.max((np.abs(sx2)**2)))
# # #             plt.pcolormesh(times,freqs, np.abs(np.log10(sx1)-np.log10(sx2)))
# #         plt.title("neuron#"+str(i+1))
# # #         plt.figure()
# # #         for k in range(1):
# # # #             data = np.array(Restate)[:,i]
# # #             data = np.array(HPCstate)[:,k,i]-moving_average(np.array(HPCstate)[2:-2,k,i])[:]
# # # #             data = np.array(HPCstate_noise)[:,k,i]
# # # #             fourier = np.fft.fft(data)
# # # #             freq = np.fft.fftfreq(len(fourier))*15
# # # #             plt.plot(freq,fourier)
# # #             freqs,times,sx3 = signal.stft(data,fs=1,window="tukey",nperseg=seglen,noverlap=seglen-shift,detrend=False,boundary=None)
# # #             plt.pcolormesh(times,freqs,(sx3))
# # # #             plt.pcolormesh(times,freqs, np.abs(np.log10(sx1)-np.log10(sx2)))
# # #         plt.title("neuron#"+str(i+1))

# # #         plt.figure()
# # #         for k in range(1):
# # # #             data = np.array(Restate)[:,i]
# # #             data1 = np.array(Refeature)[:,0]-moving_average(np.array(Refeature)[:,0])[2:-2]
# # #             data2 = np.array(HPCstate)[:,k,i]-moving_average(np.array(HPCstate)[2:-2,k,i])[:]
# # #             freqs,sx3 = signal.coherence(data1,data2,fs=1,nperseg=seglen,noverlap=seglen-shift,detrend=False)
# # #             plt.plot(freqs,(sx3))
# # # #             plt.pcolormesh(times,freqs, np.abs(np.log10(sx1)-np.log10(sx2)))
# # #         plt.title("neuron#"+str(i+1))
    
# #         plt.figure()
# #         for k in range(1):
# #             xsp = sx1*np.conjugate(sx2)
# # #             plt.pcolormesh(times,freqs,np.abs(xsp))
# # #             plt.pcolormesh(times,freqs, np.abs(np.log10(sx1)-np.log10(sx2)))
# #             coherence = (np.abs(xsp)**2)/(np.abs(sx1)*np.abs(sx2))
# #             plt.pcolormesh(times,freqs,np.log(np.abs(xsp)**2))
# #             print(np.max(np.log10(np.abs(xsp))))
# # #             plt.pcolormesh(times,freqs,np.array(coherence))
# #         plt.title("neuron#"+str(i+1))
        
        
#         for k in range(1):
#             data = np.array(Refeature)[:,0]-moving_average(np.array(Refeature)[:,0])[2:-2]
#             freqs,times,sx1 = signal.stft(data,fs=1,window="tukey",nperseg=seglen,noverlap=seglen-shift,detrend=False,boundary=None)
#             data = np.array(PFCstate)[:,k,i]-moving_average(np.array(PFCstate)[2:-2,k,i])[:]
#             freqs,times,sx2 = signal.stft(data*1,fs=1,window="tukey",nperseg=seglen,noverlap=seglen-shift,detrend=False,boundary=None)
#             data = np.array(HPCstate)[:,k,i]-moving_average(np.array(HPCstate)[2:-2,k,i])[:]
#             freqs,times,sx3 = signal.stft(data,fs=1,window="tukey",nperseg=seglen,noverlap=seglen-shift,detrend=False,boundary=None)
            
#             xsp = sx1*np.conjugate(sx2)
#             coherence = (np.abs(xsp)**2)/(np.abs(sx1)*np.abs(sx2))
#             plt.figure()
#             plt.pcolormesh(times,freqs,coherence)
#             plt.title("Re-PFC neuron#"+str(i+1))
# #             print(np.max(coherence),np.min(coherence))
            
# #             plt.figure()
# #             plt.plot(times, np.sum(coherence.T,axis=1))
            
# #             xsp = sx1*np.conjugate(sx3)
# #             coherence = (np.abs(xsp)**2)/(np.abs(sx1)*np.abs(sx3))            
# #             plt.figure()
# #             plt.pcolormesh(times,freqs,coherence)
# #             plt.title("Re-HPC neuron#"+str(i+1))
# #             print(np.max(coherence),np.min(coherence))

            
# #             xsp = sx2*np.conjugate(sx3)
# #             coherence = (np.abs(xsp)**2)/(np.abs(sx2)*np.abs(sx3))
# #             plt.figure()
# #             plt.pcolormesh(times,freqs,coherence)
# #             plt.title("PFC-HPC neuron#"+str(i+1))
            
# #             degree = np.degrees(np.angle(xsp))
# #             print(degree.shape)
# #             plt.figure()
# #             plt.plot(degree)

#             plt.figure()
#             plt.plot(times, np.sum(coherence.T,axis=1))
#             print(np.max(np.sum(coherence.T,axis=1))-np.min(np.sum(coherence.T,axis=1)))
            


#         for k in range(1):
#             data = np.array(Refeature_noise)[:,0]-moving_average(np.array(Refeature_noise)[:,0])[2:-2]
#             freqs,times,sx1 = signal.stft(data,fs=1,window="tukey",nperseg=seglen,noverlap=seglen-shift,detrend=False,boundary=None)
#             data = np.array(PFCstate_noise)[:,k,i]-moving_average(np.array(PFCstate_noise)[2:-2,k,i])[:]
#             freqs,times,sx2 = signal.stft(data*1,fs=1,window="tukey",nperseg=seglen,noverlap=seglen-shift,detrend=False,boundary=None)
#             data = np.array(HPCstate_noise)[:,k,i]-moving_average(np.array(HPCstate_noise)[2:-2,k,i])[:]
#             freqs,times,sx3 = signal.stft(data,fs=1,window="tukey",nperseg=seglen,noverlap=seglen-shift,detrend=False,boundary=None)
            
#             xsp = sx1*np.conjugate(sx2)
#             coherence = (np.abs(xsp)**2)/(np.abs(sx1)*np.abs(sx2))
# #             plt.figure()
# #             plt.pcolormesh(times,freqs,coherence)
# #             plt.title("Re-PFC neuron#"+str(i+1))
# #             print(np.max(coherence),np.min(coherence))

# #             xsp = sx1*np.conjugate(sx3)
# #             coherence = (np.abs(xsp)**2)/(np.abs(sx1)*np.abs(sx3))            
# #             plt.figure()
# #             plt.pcolormesh(times,freqs,coherence)
# #             plt.title("Re-HPC neuron#"+str(i+1))
# #             print(np.max(coherence),np.min(coherence)) 
            
# #             plt.figure()
#             plt.plot(times, np.sum(coherence.T,axis=1))
#             print(data.shape,freqs,times)
#             print(np.max(np.sum(coherence.T,axis=1))-np.min(np.sum(coherence.T,axis=1)))
                

    
    
# Entry point for this cell: call main(i+1) for each i in range(1) and collect
# the return values in `features`.
if __name__ == '__main__':
    features = []
    # NOTE(review): main() is defined above this chunk; confirm what its integer
    # argument selects (looks like a 1-based run/trial index).
    for i in range(1):
        features.append(main(i+1))

In [None]:
##########    Test for mixture weight!!!!!!!!!!!!!   #########################
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import time
import math
import matplotlib.pyplot as plt
import matplotlib.animation as animation
import sklearn 
import nolds
from sklearn.decomposition import PCA
from sklearn.cluster import KMeans
from mpl_toolkits.mplot3d import Axes3D
from scipy.special import expit
import pandas as pd
import glob
import scipy

%matplotlib notebook

def Culc_gate(input, params, hiddens):
    """Recompute the PFC and HPC LSTM gate activations for batch element 0.

    The inputs are wired like ``MyLSTM_RNN.forward``: PFC reads (HPC hidden,
    Re hidden) and HPC reads (external input, Re hidden).

    Args:
        input: external input for this step, assumed ``(batch, input_size)``.
        params: assumed to be ``list(model.parameters())`` of the model, so that
            params[0:4] are PFC (weight_ih, weight_hh, bias_ih, bias_hh) and
            params[4:8] the same for HPC -- TODO confirm against the caller.
        hiddens: ``[(h_PFC, c_PFC), (h_HPC, c_HPC), h_Re]``, i.e. the states fed
            into the step being analysed.

    Returns:
        ``[PFC_gates, HPC_gates]`` as Python lists of length ``4 * hidden``, in
        LSTMCell chunk order. A sigmoid is applied to all four chunks, including
        the cell-candidate chunk that LSTMCell itself passes through tanh.
    """
    h_pfc = hiddens[0][0]
    h_hpc = hiddens[1][0]
    h_re = hiddens[2]
    with torch.no_grad():
        # .float() matches the cast done in MyLSTM_RNN.forward.
        hpc_input = torch.cat([input.float(), h_re], dim=1)[0]
        pfc_input = torch.cat([h_hpc[0], h_re[0]])
        # LSTMCell pre-activation: W_ih x + b_ih + W_hh h + b_hh. The recurrent
        # weight multiplies the hidden state h (index 0 of the (h, c) pair),
        # not the cell state c.
        pfc_gates = params[0] @ pfc_input + params[1] @ h_pfc[0] + params[2] + params[3]
        hpc_gates = params[4] @ hpc_input + params[5] @ h_hpc[0] + params[6] + params[7]
    return [torch.sigmoid(pfc_gates).tolist(), torch.sigmoid(hpc_gates).tolist()]

def Culc_gate_uniPFC(input, params, hiddens):
    """Recompute PFC/HPC LSTM gate activations (batch element 0) when PFC reads only HPC.

    PFC's input is the HPC hidden state alone. HPC's input is (external input,
    Re hidden).

    Args:
        input: external input for this step, assumed ``(batch, input_size)``.
        params: assumed to be ``list(model.parameters())``, so that params[0:4]
            are PFC (weight_ih, weight_hh, bias_ih, bias_hh) and params[4:8] the
            same for HPC -- TODO confirm against the caller.
        hiddens: ``[(h_PFC, c_PFC), (h_HPC, c_HPC), h_Re]`` fed into the step.

    Returns:
        ``[PFC_gates, HPC_gates]`` as Python lists of length ``4 * hidden`` in
        LSTMCell chunk order, with a sigmoid applied to every chunk.
    """
    h_pfc = hiddens[0][0]
    h_hpc = hiddens[1][0]
    h_re = hiddens[2]
    with torch.no_grad():
        hpc_input = torch.cat([input.float(), h_re], dim=1)[0]
        pfc_input = h_hpc[0]
        # The recurrent weight multiplies the hidden state h (index 0 of the
        # (h, c) pair), not the cell state c.
        pfc_gates = params[0] @ pfc_input + params[1] @ h_pfc[0] + params[2] + params[3]
        hpc_gates = params[4] @ hpc_input + params[5] @ h_hpc[0] + params[6] + params[7]
    return [torch.sigmoid(pfc_gates).tolist(), torch.sigmoid(hpc_gates).tolist()]

def Culc_gate_uniHPC(input, params, hiddens):
    """Recompute PFC/HPC LSTM gate activations for batch element 0.

    PFC reads (HPC hidden, Re hidden) and HPC reads (external input, Re hidden).
    NOTE(review): this wiring is currently identical to ``Culc_gate``; confirm
    whether a "uni-HPC" variant was meant to drop one of HPC's inputs.

    Args:
        input: external input for this step, assumed ``(batch, input_size)``.
        params: assumed to be ``list(model.parameters())``, so that params[0:4]
            are PFC (weight_ih, weight_hh, bias_ih, bias_hh) and params[4:8] the
            same for HPC -- TODO confirm against the caller.
        hiddens: ``[(h_PFC, c_PFC), (h_HPC, c_HPC), h_Re]`` fed into the step.

    Returns:
        ``[PFC_gates, HPC_gates]`` as Python lists of length ``4 * hidden`` in
        LSTMCell chunk order, with a sigmoid applied to every chunk.
    """
    h_pfc = hiddens[0][0]
    h_hpc = hiddens[1][0]
    h_re = hiddens[2]
    with torch.no_grad():
        hpc_input = torch.cat([input.float(), h_re], dim=1)[0]
        pfc_input = torch.cat([h_hpc[0], h_re[0]])
        # The recurrent weight multiplies the hidden state h (index 0 of the
        # (h, c) pair), not the cell state c.
        pfc_gates = params[0] @ pfc_input + params[1] @ h_pfc[0] + params[2] + params[3]
        hpc_gates = params[4] @ hpc_input + params[5] @ h_hpc[0] + params[6] + params[7]
    return [torch.sigmoid(pfc_gates).tolist(), torch.sigmoid(hpc_gates).tolist()]
    
    

def MakeAnimation(data_x, data_y, traj_x, traj_y, data_limit,
                  filename='anim_v4_2_1_r.gif', interval=30):
    """Animate a generated trajectory and save it as a GIF.

    Frame ``t`` draws ``traj[:t]`` in red with a red marker on its latest
    point. While ``t < data_limit`` the reference ``data[:t]`` is also drawn
    in blue; after that only the generated trajectory is shown.

    Args:
        data_x, data_y: reference trajectory coordinates.
        traj_x, traj_y: generated trajectory coordinates; ``traj_x.shape[0]``
            sets the number of frames.
        data_limit: frame index from which the reference is no longer drawn.
        filename: output GIF path (written with the 'pillow' writer).
        interval: delay between frames in milliseconds.

    Returns:
        0 (kept for existing callers).
    """
    fig, ax = plt.subplots()
    ims = []
    for t in range(traj_x.shape[0]):
        # Clamp so that frame 0 does not mark traj[-1] (the last sample) via
        # negative indexing.
        last = max(t - 1, 0)
        if t < data_limit:
            traj = ax.plot(data_x[:t], data_y[:t], "b",
                           traj_x[:t], traj_y[:t], "r",
                           traj_x[last], traj_y[last], "or")
        else:
            traj = ax.plot(traj_x[:t], traj_y[:t], "r",
                           traj_x[last], traj_y[last], "or")
        title = ax.text(0.5, 1.01, 'time={}'.format(t),
                        ha='center', va='bottom',
                        transform=ax.transAxes, fontsize='large')
        ims.append(traj + [title])
    ani = animation.ArtistAnimation(fig, ims, interval=interval)
    ani.save(filename, writer='pillow')
    plt.show()

    return 0

def MakeAnimation_data(data_x, data_y, data_limit):
    """Animate the reference trajectory growing over time and save it as 'anim_Re_0.gif'.

    Frame ``t`` draws ``(data_x[:t], data_y[:t])`` in blue on fixed axes
    (x in [0.2, 0.8], y in [-0.05, 0.8]). ``data_limit`` is accepted for
    signature compatibility but is not used.

    Returns:
        0
    """
    fig, ax = plt.subplots()
    ax.set_xlim(0.2, 0.8)
    ax.set_ylim(-0.05, 0.8)

    def build_frame(t):
        lines = ax.plot(data_x[:t], data_y[:t], "b")
        label = ax.text(0.5, 1.01, 'time={}'.format(t), ha='center', va='bottom',
                        transform=ax.transAxes, fontsize='large')
        return lines + [label]

    frames = [build_frame(t) for t in range(data_x.shape[0])]
    anim = animation.ArtistAnimation(fig, frames, interval=30)
    anim.save('anim_Re_0.gif', writer='pillow')
    plt.show()
    return 0

def MakeAnimation_attracter(data_x, data_y, filename='anim_attracter2.gif', interval=10):
    """Animate a 2-D trajectory with a marker on its latest point and save it as a GIF.

    Args:
        data_x, data_y: trajectory coordinates; ``data_x.shape[0]`` sets the
            number of frames.
        filename: output GIF path (written with the 'pillow' writer).
        interval: delay between frames in milliseconds.

    Returns:
        0 (kept for existing callers).
    """
    fig, ax = plt.subplots()
    ims = []
    for t in range(data_x.shape[0]):
        # Clamp so that frame 0 does not mark data[-1] via negative indexing.
        last = max(t - 1, 0)
        traj = ax.plot(data_x[:t], data_y[:t], "b", data_x[last], data_y[last], "ob")
        title = ax.text(0.5, 1.01, 'time={}'.format(t),
                        ha='center', va='bottom',
                        transform=ax.transAxes, fontsize='large')
        ims.append(traj + [title])
    ani = animation.ArtistAnimation(fig, ims, interval=interval)
    ani.save(filename, writer='pillow')
    plt.show()

    return 0


    
def MakeAnimation_img(data, module):
    """Animate each row of ``data`` as a one-row image and save it as a GIF.

    Frame ``t`` shows ``data[t]`` reshaped to ``(1, data.shape[1])`` with the
    title ``<module> time=<t>``. The file name is ``'anim_fire_' + module + '_.gif'``.

    Returns:
        0
    """
    fig, ax = plt.subplots()
    frames = []
    for t in range(data.shape[0]):
        row_image = ax.imshow(data[t].reshape(1, data.shape[1]))
        caption = ax.text(0.5, 1.01, module+' time={}'.format(t),
                          ha='center', va='bottom',
                          transform=ax.transAxes, fontsize='large')
        frames.append([row_image, caption])
    anim = animation.ArtistAnimation(fig, frames)
    anim.save('anim_fire_'+module+'_.gif', writer='pillow')
    plt.show()
    return 0

def mkOwnDataSet(data_size, filename, data_length=100, freq=60., noise=0.005):
    """Load a pair of recorded trajectories and return ``data_size`` noisy, subsampled copies.

    Reads ``filename + "_r.csv"`` and ``filename + "_l.csv"`` (comma-separated;
    the first two columns are used). It keeps every ``step``-th row, where
    ``step = round(n_rows / data_length)`` clamped to at least 1, and adds
    independent Gaussian noise of std ``noise`` to each coordinate.

    Args:
        data_size: number of noisy copies to generate.
        filename: path prefix of the two CSV files.
        data_length: approximate number of samples to keep per copy.
        freq: unused; kept for signature compatibility.
        noise: standard deviation of the additive Gaussian noise.

    Returns:
        ``(train_x, train_y)``: lists of ``data_size`` copies (each a list of
        two-element lists) built from the ``_r`` and ``_l`` files respectively.
    """
    x = np.loadtxt(str(filename + "_r.csv"), delimiter=',')
    y = np.loadtxt(str(filename + "_l.csv"), delimiter=',')

    def _noisy_subsample(traj):
        # round() gives 0 when the file has fewer than data_length / 2 rows,
        # and np.arange rejects a zero step, so clamp it to 1.
        step = max(1, round(traj.shape[0] / data_length))
        return [[traj[i][0] + np.random.normal(loc=0.0, scale=noise),
                 traj[i][1] + np.random.normal(loc=0.0, scale=noise)]
                for i in np.arange(0, traj.shape[0], step)]

    train_x = []
    train_y = []
    for _ in range(data_size):
        train_x.append(_noisy_subsample(x))
        train_y.append(_noisy_subsample(y))
    return train_x, train_y


def mkOwnRandomBatch(train_x, train_t, batch_size=10):
    """Draw a random mini-batch (with replacement) from ``train_x``.

    Takes train_x and train_t but returns only the input batch; ``train_t`` is
    accepted for signature compatibility and is not used.

    Args:
        train_x: sequence of samples, each assumed to be a (seq_len, dim)
            nested list / array -- TODO confirm against callers.
        train_t: unused.
        batch_size: number of samples to draw.

    Returns:
        ``torch.Tensor`` built from the drawn samples, with axes 0 and 1 swapped
        by ``transpose(0, 1)`` (so ``(seq_len, batch_size, dim)`` for the shape
        assumed above).

    Raises:
        ValueError: if ``train_x`` is empty.
    """
    # Each index is drawn independently over the whole dataset. Subtracting
    # batch_size from the upper bound only makes sense for contiguous slices;
    # for independent picks it would never select the last samples and would
    # fail whenever len(train_x) <= batch_size.
    batch_x = [train_x[np.random.randint(0, len(train_x))] for _ in range(batch_size)]
    return torch.tensor(batch_x).transpose(0, 1)

def make_Ttraj(direction):
    """Build a 20-point (x, y) trajectory along a T-shaped path and back.

    There are four legs of 5 samples each, all produced with ``np.linspace``:
    1. up the stem from (0.5, 0) to (0.5, 0.5);
    2. along y = 0.5 to x = 0.75 (direction 1) or x = 0.25 (direction 2);
    3. back to x = 0.5;
    4. down the stem to (0.5, 0).

    Args:
        direction: 1 or 2, selecting the side of the cross-bar.

    Returns:
        ``np.ndarray`` of shape (20, 2) with columns (x, y).

    Raises:
        ValueError: if ``direction`` is not 1 or 2.
    """
    arm_end_x = {1: 0.75, 2: 0.25}
    if direction not in arm_end_x:
        raise ValueError("direction must be 1 or 2, got {!r}".format(direction))
    end_x = arm_end_x[direction]

    xs = np.concatenate([
        np.ones(5) * 0.5,
        np.linspace(0.5, end_x, 5),
        np.linspace(end_x, 0.5, 5),
        np.ones(5) * 0.5,
    ])
    ys = np.concatenate([
        np.linspace(0, 0.5, 5),
        np.linspace(0.5, 0.5, 5),
        np.linspace(0.5, 0.5, 5),
        np.linspace(0.5, 0, 5),
    ])
    return np.stack([xs, ys], axis=1)

def make_Htraj(direction1, direction2):
    """Build a 40-point (x, y) trajectory along an H-shaped path and back.

    The legs (sample counts in brackets) are, all produced with ``np.linspace``:
    1. up the stem from (0.5, 0) to (0.5, 0.5) [10];
    2. across to x = 0.75 (direction1 == 1) or x = 0.25 (direction1 == 2) [5];
    3. along that column to y = 0.25 (direction2 == 1) or y = 0.75
       (direction2 == 2) [5];
    4. back to y = 0.5 [5];
    5. back across to x = 0.5 [5];
    6. down the stem to (0.5, 0) [10].

    Returns:
        ``np.ndarray`` of shape (40, 2) with columns (x, y).

    Raises:
        ValueError: if either direction is not 1 or 2.
    """
    side_x = {1: 0.75, 2: 0.25}
    turn_y = {1: 0.25, 2: 0.75}
    if direction1 not in side_x:
        raise ValueError("direction1 must be 1 or 2, got {!r}".format(direction1))
    if direction2 not in turn_y:
        raise ValueError("direction2 must be 1 or 2, got {!r}".format(direction2))
    sx = side_x[direction1]
    ty = turn_y[direction2]

    xs = np.concatenate([
        np.ones(10) * 0.5,
        np.linspace(0.5, sx, 5),
        np.linspace(sx, sx, 5),
        np.linspace(sx, sx, 5),
        np.linspace(sx, 0.5, 5),
        np.ones(10) * 0.5,
    ])
    ys = np.concatenate([
        np.linspace(0, 0.5, 10),
        np.linspace(0.5, 0.5, 5),
        np.linspace(0.5, ty, 5),
        np.linspace(ty, 0.5, 5),
        np.linspace(0.5, 0.5, 5),
        np.linspace(0.5, 0, 10),
    ])
    return np.stack([xs, ys], axis=1)


def make_traj(delay_length):
    """Build four test sequences of H-shaped trajectories separated by delays.

    Each sequence visits four trajectory codes in a fixed order. After every
    trajectory come ``delay_length`` "stay" segments of ``data_length`` samples
    each. Codes map to ``make_Htraj`` arguments as 1 -> (1, 1), 2 -> (1, 2),
    3 -> (2, 1), 4 -> (2, 2).

    Stay samples are rows of ``stay1.csv`` (read from the working directory),
    drawn at random with replacement. Each trajectory gets fresh Gaussian noise
    with std ``noise``.

    Args:
        delay_length: number of stay segments inserted after each trajectory.

    Returns:
        Tuple of four ``(N, 2)`` float arrays, one per code ordering below.
    """
    noise = 0.005
    data_length = 40  # samples per stay segment

    templates = {
        1: make_Htraj(1, 1),
        2: make_Htraj(1, 2),
        3: make_Htraj(2, 1),
        4: make_Htraj(2, 2),
    }
    z = np.loadtxt("stay1.csv", delimiter=',')

    # "ver1's test" orderings; a delay follows every code.
    # (The "ver1" training orderings were [1, 4], [2, 3], [3, 1], [4, 2], with
    # a delay only between the two codes.)
    code_orders = [[1, 4, 2, 3], [2, 3, 1, 4], [3, 1, 4, 2], [4, 2, 3, 1]]

    data_list = []
    for codes in code_orders:
        segments = []
        for code in codes:
            template = templates[code]
            # Add noise to a new array: the templates are reused by every
            # sequence, and in-place addition would accumulate noise across uses.
            segments.append(template + np.random.normal(loc=0.0, scale=noise, size=template.shape))
            for _ in range(delay_length):
                picks = np.random.choice(z.shape[0], data_length)
                segments.append(np.array([[z[i][0], z[i][1]] for i in picks]))
        data_list.append(np.concatenate(segments, axis=0))
    return data_list[0], data_list[1], data_list[2], data_list[3]

def mkOwnDataSet_auto(data_size, delay_length, freq=60., noise=0.01):
    """Generate ``data_size`` noisy copies of each of the four ``make_traj`` sequences.

    Independent Gaussian noise of std ``noise`` is added to every coordinate of
    every sample. This is on top of the noise ``make_traj`` already adds to its
    trajectory segments.

    Args:
        data_size: number of noisy copies per sequence.
        delay_length: forwarded to ``make_traj``.
        freq: unused; kept for signature compatibility.
        noise: standard deviation of the additive Gaussian noise.

    Returns:
        Tuple ``(train_x1, train_x2, train_y1, train_y2)``. Each element is a
        list of ``data_size`` copies, and each copy is a list of ``[x, y]`` pairs.
    """
    x1, x2, y1, y2 = make_traj(delay_length)
    templates = (x1, x2, y1, y2)

    def jitter(traj):
        return [[point[0] + np.random.normal(loc=0.0, scale=noise),
                 point[1] + np.random.normal(loc=0.0, scale=noise)]
                for point in traj]

    buckets = ([], [], [], [])
    # One noisy copy of each sequence per pass, in the order x1, x2, y1, y2.
    for _ in range(data_size):
        for bucket, traj in zip(buckets, templates):
            bucket.append(jitter(traj))
    return buckets

def plot_distance_bet2traj(traj1, traj2, linelist):
    """Plot the step-wise Euclidean distance between two trajectories.

    Vertical lines are drawn at the x positions in ``linelist``, spanning the
    distance range.

    Returns:
        List of ``np.linalg.norm(traj1[i] - traj2[i])`` for each step ``i`` of ``traj1``.
    """
    first = np.array(traj1)
    second = np.array(traj2)
    result = [np.linalg.norm(first[step] - second[step]) for step in range(first.shape[0])]

    plt.figure()
    plt.plot(result)
    plt.vlines(linelist, np.min(result), np.max(result))
    return result

def distance_bet2traj(traj1, traj2):
    """Return the Euclidean distance between ``traj1[i]`` and ``traj2[i]`` for every step ``i`` of ``traj1``."""
    first = np.array(traj1)
    second = np.array(traj2)
    return [np.linalg.norm(first[step] - second[step]) for step in range(first.shape[0])]

def plot_activity_bet2traj(traj1, traj2, linelist):
    """Plot per-unit absolute differences between two state trajectories.

    Computes ``|traj1[i] - traj2[i]|`` for every step ``i`` of ``traj1``. It
    then opens a figure and, for every unit whose difference exceeds 0.5
    anywhere in steps 20-104, prints the unit index and plots its differences.

    Args:
        traj1, traj2: array-likes indexed by step first.
        linelist: accepted for signature compatibility; not used.

    Returns:
        List of per-step absolute-difference arrays.
    """
    first = np.array(traj1)
    second = np.array(traj2)
    threshold = 0.5
    result = [np.abs(first[step] - second[step]) for step in range(first.shape[0])]

    plt.figure()
    for unit, diff in enumerate(np.array(result).T):
        if not np.any(diff[20:105] > threshold):
            continue
        print(unit)
        plt.plot(diff)

    return result

def search_delay(traj):
    """Return the sample indices where ``traj[:, 1]`` crosses the 0.45 level.

    Crossings alternate: first upward (value becomes > 0.45), then downward
    (value becomes < 0.45), and so on. A value exactly equal to 0.45 never
    triggers a crossing.

    Returns:
        1-D float ``np.ndarray`` of indices (empty when there are no crossings).
    """
    level = 0.45
    crossings = []
    above = False
    for idx, y in enumerate(traj[:, 1]):
        if not above and y > level:
            crossings.append(idx)
            above = True
        if above and y < level:
            crossings.append(idx)
            above = False
    return np.array(crossings, dtype=float)

def pick_delay(traj, states, threshold=0.2, skip=5):
    """Return the slices of ``states`` recorded while ``traj[:, 1]`` is below ``threshold``.

    The first ``skip`` samples are ignored when detecting crossings. Entries at
    even positions of the crossing list are "enter below" events and entries at
    odd positions are "exit above" events. Each enter/exit pair yields
    ``states[enter:exit]``.

    The slice from the last crossing to the end is always appended. It is the
    final delay segment when the trajectory ends below ``threshold``.

    Args:
        traj: ``(T, >=2)`` array; column 1 is compared with ``threshold``.
        states: array/sequence aligned with ``traj`` along axis 0.
        threshold: level on ``traj[:, 1]`` that delimits delay periods.
        skip: number of leading samples ignored when detecting crossings.

    Returns:
        List of slices of ``states``.
    """
    crossings = []
    below = False
    for i in range(skip, traj.shape[0]):
        y = traj[i, 1]
        if y < threshold and not below:
            crossings.append(i)
            below = True
        if y > threshold and below:
            crossings.append(i)
            below = False

    states_list = []
    start = 0
    for n, k in enumerate(crossings):
        if n % 2 == 1:  # exit event: close the delay segment opened at `start`
            states_list.append(states[start:k])
        start = k
    # The tail is sliced from `states`, like every other segment, so the list
    # never mixes trajectory coordinates with hidden states.
    states_list.append(states[start:])
    return states_list

def pick_traj(traj, states, threshold=0.2, skip=5):
    """Split ``states`` into consecutive segments ending where ``traj[:, 1]`` rises back above ``threshold``.

    The first ``skip`` samples are ignored when detecting crossings. Entries at
    even positions of the crossing list are "enter below" events and entries at
    odd positions are "exit above" events.

    A segment runs from the previous exit (or index 0) up to the next exit. The
    slice from the last exit to the end is always appended.

    Args:
        traj: ``(T, >=2)`` array; column 1 is compared with ``threshold``.
        states: array/sequence aligned with ``traj`` along axis 0.
        threshold: level on ``traj[:, 1]`` that delimits delay periods.
        skip: number of leading samples ignored when detecting crossings.

    Returns:
        List of slices of ``states``.
    """
    crossings = []
    below = False
    for i in range(skip, traj.shape[0]):
        y = traj[i, 1]
        if y < threshold and not below:
            crossings.append(i)
            below = True
        if y > threshold and below:
            crossings.append(i)
            below = False

    states_list = []
    start = 0
    for n, k in enumerate(crossings):
        if n % 2 == 1:  # exit event: segment ends here, next one starts here
            states_list.append(states[start:k])
            start = k
    # The tail is sliced from `states`, like every other segment, so the list
    # never mixes trajectory coordinates with hidden states.
    states_list.append(states[start:])
    return states_list

def match_length(states_list):
    """Trim every sequence in ``states_list`` to the length of the shortest one.

    The *last* elements of each sequence are kept.

    Args:
        states_list: iterable of sliceable sequences (lists / arrays).

    Returns:
        List of trimmed sequences in input order; ``[]`` for empty input.
    """
    states_list = list(states_list)
    if not states_list:
        return []
    min_length = min(len(state) for state in states_list)
    # Slice from an explicit start index: ``state[-0:]`` would return the whole
    # sequence, so a zero minimum must not be expressed as a negative index.
    return [state[len(state) - min_length:] for state in states_list]

def vec_var(datas):
    """Return the mean Euclidean distance of the entries of ``datas`` from their centroid.

    The centroid is ``np.average(datas, axis=0)``. Distances are not squared,
    so despite the name this is a mean absolute deviation, not a variance.
    """
    datas = np.array(datas)
    centroid = np.average(datas, axis=0)
    total = sum(np.linalg.norm(row - centroid) for row in datas)
    return total / datas.shape[0]
    
def lyapunov_exp(data):
    """Return the mean of ``log|x[t+1] - x[t]|`` over the series ``data``.

    NOTE(review): this is the mean log-magnitude of successive differences of a
    single series. Confirm it is the intended Lyapunov-exponent proxy. Any
    zero difference contributes ``-inf``.
    """
    increments = np.abs(np.diff(data))
    return np.mean(np.log(increments))

class MyLSTM(nn.Module):
    """Three-region recurrent model: PFC, HPC and Re LSTM cells.

    At each step every region reads the *previous* hidden states of the others:
    Re reads (PFC, HPC), HPC reads (external input, Re), and PFC reads
    (HPC, Re). The output is a linear readout of the new HPC hidden state.

    NOTE: submodules are registered in the order PFC, HPC, Re, linear. Code
    that indexes ``list(model.parameters())`` presumably relies on that order.
    """

    def __init__(self, input_size, hidden_size, output_size, batch_size):
        super(MyLSTM, self).__init__()

        self.hidden_size = hidden_size
        # Keep this registration order: it fixes parameters() ordering and the
        # order in which the default initialisers consume the RNG.
        self.PFC = nn.LSTMCell(2 * hidden_size, hidden_size)
        self.HPC = nn.LSTMCell(input_size + hidden_size, hidden_size)
        self.Re = nn.LSTMCell(2 * hidden_size, hidden_size)
        self.linear = nn.Linear(hidden_size, output_size)
        self.batch_size = batch_size
        # Re-initialise PFC/HPC weights with nn.init.sparse_ (sparsity 1), in the
        # order PFC.ih, HPC.ih, PFC.hh, HPC.hh.
        # NOTE(review): a sparsity of 1 should zero every entry; confirm intended.
        for weight_name in ("weight_ih", "weight_hh"):
            for cell in (self.PFC, self.HPC):
                nn.init.sparse_(getattr(cell, weight_name).data, 1)

    def forward(self, input, hiddens):
        """Advance all regions by one step.

        Args:
            input: external input for this step.
            hiddens: ``[(h, c)_PFC, (h, c)_HPC, (h, c)_Re]`` from the previous step.

        Returns:
            ``(output, [new PFC (h, c), new HPC (h, c), new Re (h, c)])``.
        """
        pfc_state, hpc_state, re_state = hiddens[0], hiddens[1], hiddens[2]
        pfc_h, hpc_h, re_h = pfc_state[0], hpc_state[0], re_state[0]

        new_re = self.Re(torch.cat([pfc_h, hpc_h], dim=1), re_state)
        new_hpc = self.HPC(torch.cat([input, re_h], dim=1), hpc_state)
        new_pfc = self.PFC(torch.cat([hpc_h, re_h], dim=1), pfc_state)
        return self.linear(new_hpc[0]), [new_pfc, new_hpc, new_re]

    def initHidden(self):
        """Return all-zero states ``[(h, c)_PFC, (h, c)_HPC, (h, c)_Re]`` of shape (batch_size, hidden_size)."""
        def zero_pair():
            return [torch.zeros(self.batch_size, self.hidden_size),
                    torch.zeros(self.batch_size, self.hidden_size)]

        return [zero_pair(), zero_pair(), zero_pair()]

class MyLSTM_RNN(nn.Module):
    """PFC/HPC/Re network: PFC and HPC are LSTM cells, Re is a plain RNN cell.

    Hidden-state layout shared by every method::

        [[PFC_h, PFC_c], [HPC_h, HPC_c], Re_h]

    with each tensor of shape (batch_size, hidden_size_<module>).  In
    ``forward`` every module reads only the *previous* step's states:
    Re <- (PFC_h, HPC_h), HPC <- (input, Re_h), PFC <- (HPC_h, Re_h), and the
    output is a linear readout of the new HPC_h.
    """

    def __init__(self, input_size, hidden_size, output_size, batch_size, sparse=5):
        super(MyLSTM_RNN, self).__init__()

        self.hidden_size_PFC = hidden_size
        self.hidden_size_HPC = hidden_size
        self.hidden_size_Re = hidden_size + 0
        # Registration order (PFC, HPC, Re, linear) fixes state_dict keys and
        # the parameters() order; keep it.
        self.PFC = nn.LSTMCell(self.hidden_size_HPC + self.hidden_size_Re, self.hidden_size_PFC)
        self.HPC = nn.LSTMCell(input_size + self.hidden_size_Re, self.hidden_size_HPC)
        self.Re = nn.RNNCell(self.hidden_size_HPC + self.hidden_size_PFC, self.hidden_size_Re)
        self.linear = nn.Linear(self.hidden_size_HPC, output_size)
        self.batch_size = batch_size
        # ``sparse`` is given in tenths: the default 5 gives sparsity 0.5, i.e.
        # half of each weight column is zeroed by nn.init.sparse_.
        sparse = 0.1 * int(sparse)
        nn.init.sparse_(self.PFC.weight_ih.data, sparse)
        nn.init.sparse_(self.HPC.weight_ih.data, sparse)
        nn.init.sparse_(self.PFC.weight_hh.data, sparse)
        nn.init.sparse_(self.HPC.weight_hh.data, sparse)

    def forward(self, input, hiddens):
        """Advance one step; return (output, [PFC_state, HPC_state, Re_h])."""
        input = input.float()
        Re_input = torch.cat([hiddens[0][0], hiddens[1][0]], dim=1)
        Re_hidden = self.Re(Re_input, hiddens[2])
        HPC_input = torch.cat([input, hiddens[2]], dim=1)
        HPC_hidden = self.HPC(HPC_input, hiddens[1])
        PFC_input = torch.cat([hiddens[1][0], hiddens[2]], dim=1)
        PFC_hidden = self.PFC(PFC_input, hiddens[0])
        output = self.linear(HPC_hidden[0])
        return output, [PFC_hidden, HPC_hidden, Re_hidden]

    def initHidden(self):
        """Return an all-zero hidden state.

        Uses the per-module sizes; the previous version read a non-existent
        ``self.hidden_size`` attribute and always raised AttributeError.
        """
        HPC_hidden = [torch.zeros(self.batch_size, self.hidden_size_HPC),
                      torch.zeros(self.batch_size, self.hidden_size_HPC)]
        PFC_hidden = [torch.zeros(self.batch_size, self.hidden_size_PFC),
                      torch.zeros(self.batch_size, self.hidden_size_PFC)]
        Re_hidden = torch.zeros(self.batch_size, self.hidden_size_Re)
        return [PFC_hidden, HPC_hidden, Re_hidden]

    def initHidden_rand(self):
        """Return a hidden state drawn uniformly from [0, 0.1)."""
        const = 0.1
        var = 0.1
        HPC_hidden = [torch.rand(self.batch_size, self.hidden_size_HPC) * const,
                      torch.rand(self.batch_size, self.hidden_size_HPC) * const]
        PFC_hidden = [torch.rand(self.batch_size, self.hidden_size_PFC) * var,
                      torch.rand(self.batch_size, self.hidden_size_PFC) * var]
        Re_hidden = torch.rand(self.batch_size, self.hidden_size_Re) * const
        return [PFC_hidden, HPC_hidden, Re_hidden]

    def initHidden_test(self):
        """Return a hidden state drawn uniformly from [0, 0.2)."""
        const = 0.2
        var = 0.2
        HPC_hidden = [torch.rand(self.batch_size, self.hidden_size_HPC) * const,
                      torch.rand(self.batch_size, self.hidden_size_HPC) * const]
        PFC_hidden = [torch.rand(self.batch_size, self.hidden_size_PFC) * var,
                      torch.rand(self.batch_size, self.hidden_size_PFC) * var]
        Re_hidden = torch.rand(self.batch_size, self.hidden_size_Re) * var
        return [PFC_hidden, HPC_hidden, Re_hidden]

    def noiseHidden_rand(self, hiddens):
        """Add Gaussian noise (std 0.1) to Re_h in place and return *hiddens*.

        PFC_h and HPC_h get noise scaled by c = 0, i.e. no change (the draws
        still consume the torch RNG).
        """
        c = 0
        v = 0.1
        HPC_hidden = torch.randn(self.batch_size, self.hidden_size_HPC) * c
        hiddens[1][0].add_(HPC_hidden)
        PFC_hidden = torch.randn(self.batch_size, self.hidden_size_PFC) * c
        hiddens[0][0].add_(PFC_hidden)
        Re_hidden = torch.randn(self.batch_size, self.hidden_size_Re) * v
        hiddens[2].add_(Re_hidden)
        return hiddens

    def noiseHidden_select(self, hiddens):
        """Add Gaussian noise (std 0.01) to Re units 1, 2, 8, 9, 17 in place.

        PFC/HPC get noise scaled by c = 0.  Requires every hidden size > 17.
        """
        c = 0
        v = 0.01
        index = np.array([1, 2, 8, 9, 17])
        HPC_hidden = torch.zeros(self.batch_size, self.hidden_size_HPC)
        HPC_hidden[:, index] += torch.randn(self.batch_size, index.size) * c
        hiddens[1][0].add_(HPC_hidden)
        PFC_hidden = torch.zeros(self.batch_size, self.hidden_size_PFC)
        PFC_hidden[:, index] += torch.randn(self.batch_size, index.size) * c
        hiddens[0][0].add_(PFC_hidden)
        Re_hidden = torch.zeros(self.batch_size, self.hidden_size_Re) * c
        Re_hidden[:, index] += torch.randn(self.batch_size, index.size) * v
        hiddens[2].add_(Re_hidden)
        return hiddens

    def noiseHidden_dis(self, hiddens, statr, statl):
        """Shift Re_h in place by 0.1 * (statr - statl) and return *hiddens*.

        statr/statl must broadcast to (batch_size, hidden_size_Re); presumably
        they are reference Re states (TODO confirm against callers).  The
        randn draws scaled by 0 leave PFC/HPC unchanged but consume the RNG.
        """
        c = 0
        v = -0.1
        index = np.array([1, 2, 8, 9, 17])
        HPC_hidden = torch.zeros(self.batch_size, self.hidden_size_HPC)
        HPC_hidden[:, index] += torch.randn(self.batch_size, index.size) * c
        hiddens[1][0].add_(HPC_hidden)
        PFC_hidden = torch.zeros(self.batch_size, self.hidden_size_PFC)
        PFC_hidden[:, index] += torch.randn(self.batch_size, index.size) * c
        hiddens[0][0].add_(PFC_hidden)
        Re_hidden = torch.randn(self.batch_size, self.hidden_size_Re) * 0.0
        Re_hidden += (-statr + statl) * v
        hiddens[2].add_(Re_hidden)
        return hiddens


def moving_average(data):
    """Return the 'full' convolution of *data* with a 5-tap box kernel (mean of 5).

    The result has ``len(data) + 4`` samples; slicing it with ``[2:-2]`` (as
    ``main`` does) yields a centred average aligned with *data*.
    """
    kernel = np.full(5, 1 / 5)
    return np.convolve(data, kernel)


def main(model):
    """Drive one MyLSTM_RNN checkpoint and plot Re-PFC STFT "coherence" maps.

    Steps:
    1. Build the data set with ``mkOwnDataSet_auto(100, delay_length=2)``.
    2. Load the state_dict stored at *model*.
    3. Run one random batch of pattern 1 through the network.
    4. For every PFC unit (batch element 0), plot an STFT-based similarity map
       against the first principal component of the Re states.

    Args:
        model: path of a state_dict file for ``MyLSTM_RNN(2, 20, 2, 10)``.

    Returns:
        ``(0, 0, 0, 0, total)``.  The four zeros keep the 5-tuple shape the
        caller unpacks.  *total* is the sum over PFC units of
        (max - min) of the frequency-summed map over time.
    """
    from scipy import signal  # local import: scipy.signal is not imported by this cell

    training_size = 100
    hidden_size = 20
    batch_size = 10
    inputsize = 2
    outputsize = 2
    delay_length = 2
    shift = 2    # STFT hop (samples)
    seglen = 60  # STFT segment length (samples)
    # Step indices marked on each plot; these positions assume delay_length == 2.
    event_steps = [40, 120, 160, 240, 280, 360, 400, 480]

    def mark_events():
        plt.plot(event_steps, np.zeros(len(event_steps)), "o", color="k", alpha=0.3)

    def detrend(series):
        # Remove a centred 5-point moving average: moving_average returns the
        # 'full' convolution (len + 4 samples), so trim 2 from each end.  The
        # PFC series previously trimmed the *input* instead, giving edge
        # handling inconsistent with the Re series.
        return series - moving_average(series)[2:-2]

    def stft(series):
        return signal.stft(series, fs=1, window="tukey", nperseg=seglen,
                           noverlap=seglen - shift, detrend=False, boundary=None)

    train_x1, train_x2, train_y1, train_y2 = mkOwnDataSet_auto(training_size, delay_length)

    rnn = MyLSTM_RNN(inputsize, hidden_size, outputsize, batch_size)
    rnn.load_state_dict(torch.load(model))

    pattern = 1
    if pattern == 1:
        source = train_x1
    elif pattern == 2:
        source = train_x2
    elif pattern == 3:
        source = train_y1
    else:
        source = train_y2
    data = mkOwnRandomBatch(source, source, batch_size).float()

    # PFCstate keeps PFC_h for the whole batch: (steps, batch, hidden).
    # Restate keeps Re_h of batch element 0: (steps, hidden).
    PFCstate = []
    Restate = []
    hidden = rnn.initHidden_rand()
    with torch.no_grad():  # inference only: no autograd graph over all steps
        for k in range(data.shape[0]):
            _, hidden = rnn(data[k], hidden)
            PFCstate.append(hidden[0][0].tolist())
            Restate.append(hidden[2][0].tolist())
    PFC_arr = np.array(PFCstate)

    pca = PCA()
    Re_arr = np.array(Restate)
    pca.fit(Re_arr)
    Refeature = pca.transform(Re_arr)
    re_detrended = detrend(Refeature[:, 0])

    plt.figure()
    plt.plot(detrend(PFC_arr[:, 0, 2]))
    mark_events()

    plt.figure()
    plt.plot(re_detrended)
    mark_events()

    # The Re spectrum is the same for every PFC unit, so compute it once.
    _, _, sx1 = stft(re_detrended)
    coherence_diff_list = []
    for i in range(hidden_size):
        freqs, times, sx2 = stft(detrend(PFC_arr[:, 0, i]))
        xsp = sx1 * np.conjugate(sx2)
        # NOTE(review): element-wise |xsp|**2 / (|sx1| |sx2|) == |sx1| |sx2|
        # (there is no averaging across frames), so this is a product of
        # magnitudes rather than a normalised (0..1) coherence.
        coherence = (np.abs(xsp) ** 2) / (np.abs(sx1) * np.abs(sx2))
        plt.figure()
        plt.pcolormesh(times, freqs, coherence)
        plt.title("Re-PFC neuron#" + str(i + 1))
        print(np.max(coherence), np.min(coherence))

        coherence_over_time = np.sum(coherence.T, axis=1)  # summed over frequency
        plt.figure()
        plt.plot(times, coherence_over_time)
        coherence_diff_list.append(np.max(coherence_over_time) - np.min(coherence_over_time))
        mark_events()

    return 0, 0, 0, 0, np.sum(coherence_diff_list)


if __name__ == '__main__':
    import glob  # local import: glob is not guaranteed to be imported by this cell

    fig = plt.figure()
    correlation_fig = fig.subplots()

    allratio_list = []
    for num in range(1):
        for i in range(1):
            path = 'model/R20_H_bigbatch/'
            model_pattern = path + '*s' + str(i + 3) + '_100_' + str(num + 1) + '_*epoch150.pth'
            # Alphabetical first, then (stable sort) by path length.
            model_list = sorted(sorted(glob.glob(model_pattern)), key=len)

            with open(path + "good_list.txt", mode="r") as f:
                good_models = set(f.read().splitlines())

            ratio_list_max = []
            # [epoch, frac] of the first model listed in good_list.txt (kept for inspection).
            first_goodmodel = [0, 0]
            good_flag = False
            for model in model_list:
                print(model)
                epoch = int(model.split("epoch")[-1].split(".")[0])
                if epoch > 299:
                    continue
                PFC, HPC, PFC_max, HPC_max, frac = main(model)
                ratio_list_max.append(frac)
                print(PFC_max, HPC_max)
                print("dist:" + str(frac))
                if not good_flag and model in good_models:
                    good_flag = True
                    first_goodmodel[0] = epoch
                    first_goodmodel[1] = frac
            allratio_list.append(np.array(ratio_list_max))

    # Average over all (num, i) runs; x axis is the model index times 5.
    mean_ratio = np.average(np.array(allratio_list), axis=0)
    correlation_fig.plot(np.arange(0, len(mean_ratio) * 5, 5), mean_ratio, color="b")

        

In [None]:
##########    Checker of Activity difference and grad   #########################
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import time
import math
import matplotlib.pyplot as plt
import matplotlib.animation as animation
import sklearn 
import nolds
from sklearn.decomposition import PCA
from sklearn.cluster import KMeans
from mpl_toolkits.mplot3d import Axes3D
from scipy.special import expit
import pandas as pd
import glob
import scipy

%matplotlib notebook

def Culc_gate(input, params, hiddens):
    """Recompute sigmoid gate activations of the PFC and HPC LSTM cells for batch element 0.

    Args:
        input: HPC external input; it is concatenated with ``hiddens[2]`` on
            dim=1, so presumably (batch, input_size).
        params: parameter tensors in ``MyLSTM_RNN.parameters()`` order.
            Indices 0-3 are read as PFC (weight_ih, weight_hh, bias_ih,
            bias_hh) and 4-7 as HPC (TODO confirm for other models).
        hiddens: ``[[PFC_h, PFC_c], [HPC_h, HPC_c], Re_h]``.

    Returns:
        ``[PFC_gates, HPC_gates]`` as nested Python lists.  Each is the
        sigmoid of all 4*hidden pre-activations, in PyTorch chunk order
        i, f, g, o.  NOTE: inside a real LSTMCell the g chunk uses tanh, but
        sigmoid is applied here to every chunk.
    """
    HPC_input = torch.cat([input, hiddens[2]], dim=1)
    PFC_input = torch.cat([hiddens[1][0][0], hiddens[2][0]])
    # The recurrent term of an LSTM gate uses the hidden state h (index 0),
    # not the cell state c (index 1), which the previous version used.
    PFC_gates = params[0] @ PFC_input + params[1] @ hiddens[0][0][0] + params[2] + params[3]
    HPC_gates = params[4] @ HPC_input[0] + params[5] @ hiddens[1][0][0] + params[6] + params[7]
    return [torch.sigmoid(PFC_gates).tolist(), torch.sigmoid(HPC_gates).tolist()]

def Culc_gate_uniPFC(input, params, hiddens):
    """Like Culc_gate, but PFC receives only HPC_h (no Re input).

    Args:
        input: HPC external input, concatenated with ``hiddens[2]`` on dim=1.
        params: parameter tensors; indices 0-3 are read as PFC
            (weight_ih, weight_hh, bias_ih, bias_hh) and 4-7 as HPC (TODO
            confirm against the uniPFC model's parameter order).
        hiddens: ``[[PFC_h, PFC_c], [HPC_h, HPC_c], Re_h]``.

    Returns:
        ``[PFC_gates, HPC_gates]`` for batch element 0, as nested lists of
        sigmoid pre-activations (chunk order i, f, g, o; sigmoid is applied
        to the g chunk too).
    """
    HPC_input = torch.cat([input, hiddens[2]], dim=1)
    PFC_input = hiddens[1][0][0]
    # Recurrent term uses the hidden state h (index 0), not the cell state c.
    PFC_gates = params[0] @ PFC_input + params[1] @ hiddens[0][0][0] + params[2] + params[3]
    HPC_gates = params[4] @ HPC_input[0] + params[5] @ hiddens[1][0][0] + params[6] + params[7]
    return [torch.sigmoid(PFC_gates).tolist(), torch.sigmoid(HPC_gates).tolist()]

def Culc_gate_uniHPC(input, params, hiddens):
    """Gate recomputation for the 'uniHPC' variant (currently identical to Culc_gate).

    NOTE(review): the HPC input is still built as (input, Re_h).  If the
    uniHPC model feeds HPC differently, this must be adapted; confirm.

    Args:
        input: HPC external input, concatenated with ``hiddens[2]`` on dim=1.
        params: parameter tensors; indices 0-3 are read as PFC and 4-7 as
            HPC (weight_ih, weight_hh, bias_ih, bias_hh).
        hiddens: ``[[PFC_h, PFC_c], [HPC_h, HPC_c], Re_h]``.

    Returns:
        ``[PFC_gates, HPC_gates]`` for batch element 0, as nested lists of
        sigmoid pre-activations (chunk order i, f, g, o; sigmoid is applied
        to the g chunk too).
    """
    HPC_input = torch.cat([input, hiddens[2]], dim=1)
    PFC_input = torch.cat([hiddens[1][0][0], hiddens[2][0]])
    # Recurrent term uses the hidden state h (index 0), not the cell state c.
    PFC_gates = params[0] @ PFC_input + params[1] @ hiddens[0][0][0] + params[2] + params[3]
    HPC_gates = params[4] @ HPC_input[0] + params[5] @ hiddens[1][0][0] + params[6] + params[7]
    return [torch.sigmoid(PFC_gates).tolist(), torch.sigmoid(HPC_gates).tolist()]
    
    

def MakeAnimation(data_x, data_y, traj_x, traj_y, data_limit,
                  filename='anim_v4_2_1_r.gif', interval=30):
    """Save and show an animation of a generated trajectory (red) over reference data (blue).

    Frame ``t`` draws the first ``t`` trajectory points and a marker on the
    most recent one.  While ``t < data_limit`` it also draws the first ``t``
    reference points.

    Args:
        data_x, data_y: reference coordinates (sliceable sequences).
        traj_x, traj_y: generated coordinates; ``traj_x.shape[0]`` sets the
            number of frames.
        data_limit: number of leading frames in which the reference is drawn.
        filename: output GIF path (default keeps the previous hard-coded name).
        interval: delay between frames in milliseconds.

    Returns:
        0
    """
    fig, ax = plt.subplots()
    ims = []
    for t in range(traj_x.shape[0]):
        if t < data_limit:
            artists = ax.plot(data_x[:t], data_y[:t], "b", traj_x[:t], traj_y[:t], "r")
        else:
            artists = ax.plot(traj_x[:t], traj_y[:t], "r")
        if t > 0:
            # Current-position marker.  Skipped at t == 0, where index t-1 == -1
            # would wrap around and mark the *last* trajectory point.
            artists = artists + ax.plot(traj_x[t - 1], traj_y[t - 1], "or")
        title = ax.text(0.5, 1.01, 'time={}'.format(t),
                        ha='center', va='bottom',
                        transform=ax.transAxes, fontsize='large')
        ims.append(artists + [title])
    ani = animation.ArtistAnimation(fig, ims, interval=interval)
    ani.save(filename, writer='pillow')
    plt.show()

    return 0

def MakeAnimation_data(data_x, data_y, data_limit):
    """Save and show a growing blue line through (data_x, data_y) as 'anim_Re_0.gif'.

    Axes are fixed to x in [0.2, 0.8] and y in [-0.05, 0.8].  Frame t shows
    the first t samples.  *data_limit* is accepted but not used.

    Returns:
        0
    """
    fig, ax = plt.subplots()
    ax.set_xlim(0.2, 0.8)
    ax.set_ylim(-0.05, 0.8)

    def frame_artists(t):
        line = ax.plot(data_x[:t], data_y[:t], "b")
        label = ax.text(0.5, 1.01, 'time={}'.format(t),
                        ha='center', va='bottom',
                        transform=ax.transAxes, fontsize='large')
        return line + [label]

    frames = [frame_artists(t) for t in range(data_x.shape[0])]
    movie = animation.ArtistAnimation(fig, frames, interval=30)
    movie.save('anim_Re_0.gif', writer='pillow')
    plt.show()

    return 0

def MakeAnimation_attracter(data_x, data_y, filename='anim_attracter2.gif', interval=10):
    """Save and show an animation tracing the curve (data_x, data_y) point by point.

    Frame ``t`` draws the first ``t`` points in blue plus a marker on point
    ``t-1``.

    Args:
        data_x, data_y: coordinates; ``data_x.shape[0]`` sets the frame count.
        filename: output GIF path (default keeps the previous hard-coded name).
        interval: delay between frames in milliseconds.

    Returns:
        0
    """
    fig, ax = plt.subplots()
    ims = []
    for t in range(data_x.shape[0]):
        artists = ax.plot(data_x[:t], data_y[:t], "b")
        if t > 0:
            # Skipped at t == 0: index -1 would mark the last point of the curve.
            artists = artists + ax.plot(data_x[t - 1], data_y[t - 1], "ob")
        title = ax.text(0.5, 1.01, 'time={}'.format(t),
                        ha='center', va='bottom',
                        transform=ax.transAxes, fontsize='large')
        ims.append(artists + [title])
    ani = animation.ArtistAnimation(fig, ims, interval=interval)
    ani.save(filename, writer='pillow')
    plt.show()

    return 0


    
def MakeAnimation_img(data, module):
    """Animate each row of *data* as a one-row image and save it.

    Frame t shows ``data[t].reshape(1, data.shape[1])`` with a title prefixed
    by *module*.  The GIF is written to ``'anim_fire_' + module + '_.gif'``
    with the pillow writer and then shown.

    Returns:
        0
    """
    fig, ax = plt.subplots()

    def draw(t):
        strip = ax.imshow(data[t].reshape(1, data.shape[1]))
        caption = ax.text(0.5, 1.01, module + ' time={}'.format(t),
                          ha='center', va='bottom',
                          transform=ax.transAxes, fontsize='large')
        return [strip, caption]

    frames = [draw(t) for t in range(data.shape[0])]
    movie = animation.ArtistAnimation(fig, frames)
    movie.save('anim_fire_' + module + '_.gif', writer='pillow')
    plt.show()

    return 0

def mkOwnDataSet(data_size, filename, data_length=100, freq=60., noise=0.005):
    """Load a right/left trajectory pair from CSV and build noisy, subsampled copies.

    Reads ``filename + "_r.csv"`` and ``filename + "_l.csv"`` (comma
    separated; columns 0 and 1 are used).  Every ``step``-th row is kept,
    where ``step = max(1, round(n_rows / data_length))``.  Independent
    N(0, noise**2) jitter is then added to each coordinate.

    Args:
        data_size: number of noisy copies to generate.
        filename: path prefix of the two CSV files.
        data_length: approximate number of rows to keep.
        freq: unused; kept for signature compatibility.
        noise: standard deviation of the jitter.

    Returns:
        (train_x, train_y): lists of ``data_size`` samples, each a list of
        ``[x, y]`` pairs.
    """
    # ndmin=2 keeps a single-row file two-dimensional so row[i][0] still works.
    x = np.loadtxt(filename + "_r.csv", delimiter=',', ndmin=2)
    y = np.loadtxt(filename + "_l.csv", delimiter=',', ndmin=2)

    def rows_to_keep(arr):
        # A step of 0 (fewer than data_length / 2 rows) used to crash np.arange.
        return np.arange(0, arr.shape[0], max(1, round(arr.shape[0] / data_length)))

    def jitter(arr, rows):
        return [[arr[i][0] + np.random.normal(loc=0.0, scale=noise),
                 arr[i][1] + np.random.normal(loc=0.0, scale=noise)] for i in rows]

    x_rows = rows_to_keep(x)
    y_rows = rows_to_keep(y)
    train_x = []
    train_y = []
    for _ in range(data_size):
        train_x.append(jitter(x, x_rows))
        train_y.append(jitter(y, y_rows))
    return train_x, train_y


def mkOwnRandomBatch(train_x, train_t, batch_size=10):
    """Draw a random mini-batch (with replacement) from *train_x*.

    Args:
        train_x: sequence of samples, each a (time, features) nested list/array.
        train_t: unused; only ``batch_x`` is returned (kept for signature
            compatibility).
        batch_size: number of samples to draw.

    Returns:
        Tensor of shape (time, batch_size, features).  The stacked samples
        are transposed so that time is the leading axis.

    Raises:
        ValueError: if *train_x* is empty.
    """
    if len(train_x) == 0:
        raise ValueError("train_x must contain at least one sample")
    batch_x = []
    for _ in range(batch_size):
        # Any sample may be drawn.  The previous upper bound
        # len(train_x) - batch_size excluded the last batch_size samples and
        # crashed whenever len(train_x) <= batch_size.
        idx = np.random.randint(0, len(train_x))
        batch_x.append(train_x[idx])
    return torch.tensor(batch_x).transpose(0, 1)

def make_Ttraj(direction):
    """Build a 20-point T-maze route as an (x, y) array.

    The route has four 5-point legs:
    1. up the stem at x=0.5 from y=0 to 0.5;
    2. out along y=0.5 to the arm end;
    3. back to x=0.5;
    4. back down to y=0.

    Args:
        direction: 1 -> arm end at x=0.75, 2 -> arm end at x=0.25.

    Returns:
        ndarray of shape (20, 2), columns (x, y).

    Raises:
        ValueError: if *direction* is not 1 or 2 (previously an
        UnboundLocalError).
    """
    arm_end = {1: 0.75, 2: 0.25}
    if direction not in arm_end:
        raise ValueError("direction must be 1 or 2, got {!r}".format(direction))
    end_x = arm_end[direction]

    traj_x = np.concatenate([
        np.full(5, 0.5),
        np.linspace(0.5, end_x, 5),
        np.linspace(end_x, 0.5, 5),
        np.full(5, 0.5),
    ])
    traj_y = np.concatenate([
        np.linspace(0, 0.5, 5),
        np.full(5, 0.5),
        np.full(5, 0.5),
        np.linspace(0.5, 0, 5),
    ])
    return np.column_stack([traj_x, traj_y])

def make_Htraj(direction1, direction2):
    """Build a 40-point H-maze route as an (x, y) array.

    Route segments, in order:
    1. 10 points up x=0.5 from y=0 to 0.5;
    2. 5 points out to the side arm along y=0.5;
    3. 5 points along the arm to the excursion y;
    4. 5 points back along the arm to y=0.5;
    5. 5 points back to x=0.5;
    6. 10 points down to y=0.

    Args:
        direction1: 1 -> side arm at x=0.75, 2 -> side arm at x=0.25.
        direction2: 1 -> excursion to y=0.25, 2 -> excursion to y=0.75.

    Returns:
        ndarray of shape (40, 2), columns (x, y).

    Raises:
        ValueError: if either direction is not 1 or 2 (previously an
        UnboundLocalError).
    """
    arm_x = {1: 0.75, 2: 0.25}
    arm_y = {1: 0.25, 2: 0.75}
    if direction1 not in arm_x or direction2 not in arm_y:
        raise ValueError("directions must be 1 or 2, got {!r}, {!r}".format(direction1, direction2))
    side = arm_x[direction1]
    turn = arm_y[direction2]

    traj_x = np.concatenate([
        np.full(10, 0.5),
        np.linspace(0.5, side, 5),
        np.full(5, side),
        np.full(5, side),
        np.linspace(side, 0.5, 5),
        np.full(10, 0.5),
    ])
    traj_y = np.concatenate([
        np.linspace(0, 0.5, 10),
        np.full(5, 0.5),
        np.linspace(0.5, turn, 5),
        np.linspace(turn, 0.5, 5),
        np.full(5, 0.5),
        np.linspace(0.5, 0, 10),
    ])
    return np.column_stack([traj_x, traj_y])


def make_traj(delay_length):
    """Build the four evaluation sequences of H-maze routes separated by stay periods.

    Route ids:
    - 1 = make_Htraj(1, 1)
    - 2 = make_Htraj(1, 2)
    - 3 = make_Htraj(2, 1)
    - 4 = make_Htraj(2, 2)

    Each sequence visits all four routes in a fixed cyclic order, and every
    route is followed by ``delay_length`` stay segments.  A stay segment is
    40 rows drawn (with replacement) from "stay1.csv".  Routes get
    N(0, 0.005**2) jitter; stay rows do not.

    Returns:
        Four float ndarrays of shape (n_steps, 2), for the sequences starting
        with route 1, 2, 3 and 4 respectively.
    """
    noise = 0.005
    data_length = 40  # rows per stay segment

    routes = {1: make_Htraj(1, 1), 2: make_Htraj(1, 2), 3: make_Htraj(2, 1), 4: make_Htraj(2, 2)}
    z = np.loadtxt("stay1.csv", delimiter=',')

    # "ver1's test" sequences.  Their consecutive transitions (1->4, 4->2,
    # 2->3, 3->1) are the single-step pairs of the "ver1" training orders.
    sequences = ([1, 4, 2, 3], [2, 3, 1, 4], [3, 1, 4, 2], [4, 2, 3, 1])

    data_list = []
    for sequence in sequences:
        # Route id followed by delay_length zeros (0 = stay segment).
        order = np.concatenate([np.concatenate(([route], np.zeros(delay_length))) for route in sequence])
        segments = []
        for idx in order:
            if idx == 0:
                rows = np.random.choice(z.shape[0], data_length)
                target = np.array([[z[i][0], z[i][1]] for i in rows])
            elif idx in routes:
                # Build a new array.  The previous in-place `+=` mutated the
                # shared route template, so noise accumulated over every later use.
                template = routes[idx]
                target = template + np.random.normal(loc=0.0, scale=noise, size=template.shape)
            else:
                raise ValueError("unknown segment id {!r}".format(idx))
            segments.append(target)
        data_list.append(np.concatenate(segments, axis=0))
    return data_list[0], data_list[1], data_list[2], data_list[3]

def mkOwnDataSet_auto(data_size, delay_length, freq=60., noise=0.0001):
    """Generate ``data_size`` jittered copies of each of the four make_traj sequences.

    Args:
        data_size: number of copies per sequence.
        delay_length: forwarded to make_traj.
        freq: unused; kept for signature compatibility.
        noise: standard deviation of the per-coordinate Gaussian jitter.

    Returns:
        (train_x1, train_x2, train_y1, train_y2): four lists of ``data_size``
        samples, each a list of ``[x, y]`` pairs.
    """
    templates = make_traj(delay_length)

    def jitter(template):
        return [[template[i][0] + np.random.normal(loc=0.0, scale=noise),
                 template[i][1] + np.random.normal(loc=0.0, scale=noise)]
                for i in np.arange(template.shape[0])]

    copies = ([], [], [], [])
    for _ in range(data_size):
        for bucket, template in zip(copies, templates):
            bucket.append(jitter(template))
    return copies[0], copies[1], copies[2], copies[3]

def plot_distance_bet2traj(traj1, traj2, linelist):
    """Plot the per-step Euclidean distance between two trajectories.

    Vertical lines are drawn at the x positions in *linelist*, spanning the
    distance range.  One distance is computed per row of *traj1*.

    Returns:
        list of the per-step distances.
    """
    first = np.array(traj1)
    second = np.array(traj2)
    gaps = [np.linalg.norm(first[i] - second[i]) for i in range(first.shape[0])]

    plt.figure()
    plt.plot(gaps)
    plt.vlines(linelist, np.min(gaps), np.max(gaps))
    return gaps

def distance_bet2traj(traj1, traj2):
    """Return the Euclidean distance between matching rows of two trajectories.

    Iterates over the rows of *traj1*; *traj2* must have at least as many rows.
    """
    a, b = np.array(traj1), np.array(traj2)
    return [np.linalg.norm(a[step] - b[step]) for step in range(a.shape[0])]

def plot_activity_bet2traj(traj1, traj2, linelist):
    """Plot the per-unit absolute difference between two state trajectories.

    Only units whose difference exceeds 0.5 somewhere in steps 20..104 are
    plotted, and their indices are printed.  *linelist* is accepted but
    not used.

    Returns:
        list of per-step absolute-difference arrays.
    """
    first = np.array(traj1)
    second = np.array(traj2)
    threshold = 0.5
    result = [np.abs(first[i] - second[i]) for i in range(first.shape[0])]

    plt.figure()
    for unit, series in enumerate(np.array(result).T):
        if not np.any(series[20:105] > threshold):
            continue
        print(unit)
        plt.plot(series)

    return result

def search_delay(traj):
    """Return the step indices where ``traj[:, 1]`` crosses the 0.45 level.

    Crossings alternate: first an upward one (value > 0.45 while below),
    then a downward one (value < 0.45 while above), and so on.  A value of
    exactly 0.45 never triggers a crossing.

    Returns:
        float ndarray of crossing indices (possibly empty).
    """
    level = 0.45
    crossings = []
    above = False
    for step in range(traj.shape[0]):
        value = traj[step, 1]
        if not above and value > level:
            crossings.append(step)
            above = True
        elif above and value < level:
            crossings.append(step)
            above = False
    return np.array(crossings, dtype=float)

def pick_delay(traj, states):
    """Slice out the parts of *states* recorded during delay (stay) periods.

    A step is inside a delay while ``traj[i, 1] < 0.2``.  The first 5 steps
    are ignored.  Every entry into and exit from that region is recorded as
    a boundary.

    Args:
        traj: 2-D array; column 1 is compared with the threshold.
        states: sequence sliced with the same step indices as *traj*.

    Returns:
        list: ``states[enter:exit]`` for every complete delay, followed by the
        tail ``states[last_boundary:]`` (``states[0:]`` if there is no
        boundary).  The tail is a delay segment only if the run ends inside
        a delay.
    """
    threshold = 0.2
    boundaries = []
    in_delay = False
    for i in range(5, traj.shape[0]):
        if not in_delay and traj[i, 1] < threshold:
            boundaries.append(i)
            in_delay = True
        elif in_delay and traj[i, 1] > threshold:
            boundaries.append(i)
            in_delay = False

    states_list = [states[enter:leave] for enter, leave in zip(boundaries[0::2], boundaries[1::2])]
    start = boundaries[-1] if boundaries else 0
    # The tail used to be sliced from `traj` instead of `states`, which mixed
    # output coordinates into a list of state segments.
    states_list.append(states[start:])
    return states_list

def pick_traj(traj, states):
    """Split *states* into segments that each end where a delay (stay) period ends.

    A step is inside a delay while ``traj[i, 1] < 0.2``.  The first 5 steps
    are ignored.  Segments run from the previous delay exit (or step 0) to
    the next delay exit, so each covers a movement phase plus the following
    delay.

    Args:
        traj: 2-D array; column 1 is compared with the threshold.
        states: sequence sliced with the same step indices as *traj*.

    Returns:
        list of ``states`` slices, followed by the tail after the last exit
        (``states[0:]`` if no delay was ever exited).
    """
    threshold = 0.2
    boundaries = []
    in_delay = False
    for i in range(5, traj.shape[0]):
        if not in_delay and traj[i, 1] < threshold:
            boundaries.append(i)
            in_delay = True
        elif in_delay and traj[i, 1] > threshold:
            boundaries.append(i)
            in_delay = False

    states_list = []
    start = 0
    for stop in boundaries[1::2]:  # delay exits
        states_list.append(states[start:stop])
        start = stop
    # The tail used to be sliced from `traj` instead of `states`.
    states_list.append(states[start:])
    return states_list

def match_length(states_list, max_length=1000):
    """Trim every sequence to a common length, keeping its most recent (last) elements.

    The common length is the length of the shortest sequence, capped at
    *max_length* (default 1000, matching the previous hard-coded cap).

    Args:
        states_list: iterable of sliceable sequences.
        max_length: upper bound on the common length.

    Returns:
        list of trimmed slices, in input order.
    """
    states_list = list(states_list)
    common = min([max_length] + [len(state) for state in states_list])
    # Slice from len - common rather than -common: with common == 0, `[-0:]`
    # would return the whole sequence instead of an empty one.
    return [state[len(state) - common:] for state in states_list]

def vec_var(datas):
    """Return the mean Euclidean distance between each entry of *datas* and their centroid.

    Distances are not squared, so despite the name this is a mean absolute
    deviation rather than a variance.
    """
    points = np.array(datas)
    center = np.average(points, axis=0)
    distances = [np.linalg.norm(point - center) for point in points]
    return sum(distances) / points.shape[0]
    
def lyapunov_exp(data):
    """Return the mean of ``log|data[i+1] - data[i]|`` over consecutive samples.

    This is a crude growth-rate proxy, not a full Lyapunov-exponent estimate.
    A zero step contributes ``-inf`` (NumPy emits a divide warning).
    """
    step_sizes = np.abs(np.diff(data))
    log_steps = np.log(step_sizes)
    return np.mean(log_steps)

def culc_grad(data):
    """Return the absolute step-to-step change of ``data`` as a float64 array.

    Element 0 is always 0 (the first sample is compared with itself) and
    element ``i`` is ``|data[i] - data[i-1]|``, so a 1-D input keeps its
    length. For multi-dimensional samples the per-sample differences are
    flattened into the result in order. An empty input gives an empty array.
    """
    values = np.asarray(data)
    if len(values) == 0:
        return np.array([])
    # Subtract in the input's own dtype, then widen to float64, so each
    # element equals the per-sample ``abs(data[i] - data[i-1])``.
    steps = np.abs(np.diff(values, axis=0, prepend=values[:1]))
    return steps.astype(float).ravel()

class MyLSTM(nn.Module):
    """Three-region network in which PFC, HPC and Re are all LSTM cells.

    On each step Re reads the previous ``[PFC h, HPC h]``, HPC reads
    ``[input, Re h]`` and PFC reads ``[HPC h, Re h]``; the output is a linear
    readout of the new HPC h. Hidden layout: ``[PFC (h, c), HPC (h, c), Re (h, c)]``.
    """

    def __init__(self, input_size, hidden_size, output_size, batch_size):
        super().__init__()

        self.hidden_size = hidden_size
        self.PFC = nn.LSTMCell(2 * hidden_size, hidden_size)
        self.HPC = nn.LSTMCell(input_size + hidden_size, hidden_size)
        self.Re = nn.LSTMCell(2 * hidden_size, hidden_size)
        self.linear = nn.Linear(hidden_size, output_size)
        self.batch_size = batch_size
        # Sparsity 1 makes nn.init.sparse_ zero every entry of each column of
        # these PFC/HPC matrices; Re keeps its default initialisation.
        # Order (PFC.ih, HPC.ih, PFC.hh, HPC.hh) is kept for identical RNG use.
        for weight_name in ("weight_ih", "weight_hh"):
            for cell in (self.PFC, self.HPC):
                nn.init.sparse_(getattr(cell, weight_name).data, 1)

    def forward(self, input, hiddens):
        """Advance every region one step from the previous states in ``hiddens``."""
        pfc_prev, hpc_prev, re_prev = hiddens[0], hiddens[1], hiddens[2]
        re_next = self.Re(torch.cat([pfc_prev[0], hpc_prev[0]], dim=1), re_prev)
        hpc_next = self.HPC(torch.cat([input, re_prev[0]], dim=1), hpc_prev)
        pfc_next = self.PFC(torch.cat([hpc_prev[0], re_prev[0]], dim=1), pfc_prev)
        return self.linear(hpc_next[0]), [pfc_next, hpc_next, re_next]

    def initHidden(self):
        """Return all-zero ``[h, c]`` pairs for PFC, HPC and Re, in that order."""
        def zero_pair():
            return [torch.zeros(self.batch_size, self.hidden_size) for _ in range(2)]
        return [zero_pair(), zero_pair(), zero_pair()]

class MyLSTM_RNN(nn.Module):
    """PFC/HPC/Re network: PFC and HPC are LSTM cells, Re is a plain RNN cell.

    On every step each region reads the *previous* hidden outputs of the
    others: Re reads ``[PFC h, HPC h]``, HPC reads ``[input, Re h]`` and PFC
    reads ``[HPC h, Re h]``. The output is a linear readout of the new HPC h.

    Hidden-state layout used throughout: ``[[PFC h, PFC c], [HPC h, HPC c], Re h]``.
    """

    def __init__(self, input_size, hidden_size, output_size, batch_size, sparse=5):
        """Create the three regions and sparsify the PFC/HPC LSTM weights.

        Args:
            input_size: width of the external input fed to HPC.
            hidden_size: width used for all three regions.
            output_size: width of the linear readout of HPC.
            batch_size: batch dimension used by the ``initHidden*`` helpers.
            sparse: sparsity in tenths (``5`` -> ``0.5``), i.e. the fraction of
                each weight column that ``nn.init.sparse_`` sets to zero.
        """
        super(MyLSTM_RNN, self).__init__()

        self.hidden_size_PFC = hidden_size
        self.hidden_size_HPC = hidden_size
        self.hidden_size_Re = hidden_size+0
        # Keep this construction order: it fixes the RNG draws of the default
        # initialisers and the order of the state_dict entries.
        self.PFC = nn.LSTMCell(self.hidden_size_HPC+self.hidden_size_Re, self.hidden_size_PFC)
        self.HPC = nn.LSTMCell(input_size+self.hidden_size_Re, self.hidden_size_HPC)
        self.Re = nn.RNNCell(self.hidden_size_HPC+self.hidden_size_PFC, self.hidden_size_Re)
        self.linear = nn.Linear(self.hidden_size_HPC, output_size)
        self.batch_size = batch_size
        sparse = 0.1*int(sparse)
        nn.init.sparse_(self.PFC.weight_ih.data,sparse)
        nn.init.sparse_(self.HPC.weight_ih.data,sparse)
        nn.init.sparse_(self.PFC.weight_hh.data,sparse)
        nn.init.sparse_(self.HPC.weight_hh.data,sparse)

    def forward(self, input, hiddens):
        """Advance all regions by one step.

        Args:
            input: external input for HPC; cast with ``.float()`` here.
            hiddens: ``[[PFC h, PFC c], [HPC h, HPC c], Re h]`` from the previous step.

        Returns:
            ``(output, [PFC (h, c), HPC (h, c), Re h])`` where the LSTM pairs
            are the tuples returned by ``nn.LSTMCell``.
        """
        input = input.float()
        # Every region reads only the previous states, so update order is irrelevant.
        Re_input = torch.cat([hiddens[0][0],hiddens[1][0]],dim=1)
        Re_hidden = self.Re(Re_input, hiddens[2])
        HPC_input = torch.cat([input,hiddens[2]],dim=1)
        HPC_hidden = self.HPC(HPC_input, hiddens[1])
        PFC_input = torch.cat([hiddens[1][0],hiddens[2]],dim=1)
        PFC_hidden = self.PFC(PFC_input, hiddens[0])
        output = self.linear(HPC_hidden[0])
        return output, [PFC_hidden, HPC_hidden, Re_hidden]

    def initHidden(self):
        """Return an all-zero hidden state ``[[PFC h, c], [HPC h, c], Re h]``.

        Sized from the per-region widths set in ``__init__``; the class has no
        ``hidden_size`` attribute.
        """
        HPC_hidden = [torch.zeros(self.batch_size, self.hidden_size_HPC), torch.zeros(self.batch_size, self.hidden_size_HPC)]
        PFC_hidden = [torch.zeros(self.batch_size, self.hidden_size_PFC), torch.zeros(self.batch_size, self.hidden_size_PFC)]
        Re_hidden = torch.zeros(self.batch_size, self.hidden_size_Re)
        return [PFC_hidden,HPC_hidden,Re_hidden]

    def initHidden_rand(self):
        """Return a hidden state drawn uniformly from [0, 0.1) for every region."""
        const = 0.1
        var = 0.1
        HPC_hidden = [torch.rand(self.batch_size, self.hidden_size_HPC)*const, torch.rand(self.batch_size, self.hidden_size_HPC)*const]
        PFC_hidden = [torch.rand(self.batch_size, self.hidden_size_PFC)*var, torch.rand(self.batch_size, self.hidden_size_PFC)*var]
        Re_hidden = torch.rand(self.batch_size, self.hidden_size_Re)*const
        return [PFC_hidden,HPC_hidden,Re_hidden]

    def initHidden_test(self):
        """Return a hidden state drawn uniformly from [0, 0.2) for every region."""
        const = 0.2
        var = 0.2
        HPC_hidden = [torch.rand(self.batch_size, self.hidden_size_HPC)*const, torch.rand(self.batch_size, self.hidden_size_HPC)*const]
        PFC_hidden = [torch.rand(self.batch_size, self.hidden_size_PFC)*var, torch.rand(self.batch_size, self.hidden_size_PFC)*var]
        Re_hidden = torch.rand(self.batch_size, self.hidden_size_Re)*var
        return [PFC_hidden,HPC_hidden,Re_hidden]

    def noiseHidden_rand(self, hiddens):
        """Add Gaussian noise (std 0.1) to Re h in place and return ``hiddens``.

        PFC h and HPC h receive ``randn * 0`` (no change), but those draws are
        still made so the RNG stream is consumed as before.
        """
        c = 0
        v = 0.1
        HPC_hidden = torch.randn(self.batch_size, self.hidden_size_HPC)*c
        hiddens[1][0].add_(HPC_hidden)
        PFC_hidden = torch.randn(self.batch_size, self.hidden_size_PFC)*c
        hiddens[0][0].add_(PFC_hidden)
        Re_hidden = torch.randn(self.batch_size, self.hidden_size_Re)*v
        hiddens[2].add_(Re_hidden)
        return hiddens

    def noiseHidden_select(self, hiddens):
        """Add Gaussian noise (std 0.01) to Re units 1, 2, 8, 9, 17 in place.

        PFC/HPC noise on the same units is scaled by 0 (no change). Returns
        the same ``hiddens`` object.
        """
        c = 0
        v = 0.01
        index = np.array([1,2,8,9,17])
        HPC_hidden = torch.zeros(self.batch_size, self.hidden_size_HPC)
        HPC_hidden[:,index] += torch.randn(self.batch_size, index.size)*c
        hiddens[1][0].add_(HPC_hidden)
        PFC_hidden = torch.zeros(self.batch_size, self.hidden_size_PFC)
        PFC_hidden[:,index] += torch.randn(self.batch_size, index.size)*c
        hiddens[0][0].add_(PFC_hidden)
        Re_hidden = torch.zeros(self.batch_size, self.hidden_size_Re)*c
        Re_hidden[:,index] += torch.randn(self.batch_size, index.size)*v
        hiddens[2].add_(Re_hidden)
        return hiddens

    def noiseHidden_dis(self, hiddens, statr, statl):
        """Shift Re h in place by ``0.1 * (statr - statl)`` and return ``hiddens``.

        ``statr``/``statl`` are assumed to broadcast against a
        ``(batch_size, hidden_size_Re)`` tensor -- TODO confirm. PFC/HPC noise
        is scaled by 0 (no change).
        """
        c = 0
        v =-0.1
        index = np.array([1,2,8,9,17])
        HPC_hidden = torch.zeros(self.batch_size, self.hidden_size_HPC)
        HPC_hidden[:,index] += torch.randn(self.batch_size, index.size)*c
        hiddens[1][0].add_(HPC_hidden)
        PFC_hidden = torch.zeros(self.batch_size, self.hidden_size_PFC)
        PFC_hidden[:,index] += torch.randn(self.batch_size, index.size)*c
        hiddens[0][0].add_(PFC_hidden)
        Re_hidden = torch.randn(self.batch_size, self.hidden_size_Re)*0.0
        Re_hidden += (-statr+statl)*v
        hiddens[2].add_(Re_hidden)
        return hiddens


def moving_average(data):
    """Convolve ``data`` with a length-5 box kernel in NumPy's "full" mode.

    The result is 4 samples longer than ``data``; slicing it with ``[2:-2]``
    gives a centred 5-point average aligned with the input.
    """
    box = np.full(5, 1 / 5)
    return np.convolve(data, box, mode="full")


def main(model):
    """Compare per-step region activity between two input patterns for one checkpoint.

    Builds a ``MyLSTM_RNN`` (hidden size 20, batch 10), loads the weights
    stored at ``model`` and runs it for ``data_limit`` (120) steps twice,
    each time starting from ``initHidden_rand()``: first on a batch built by
    ``mkOwnRandomBatch`` from ``train_x1``, then on one built from
    ``train_y1``. Hidden states of batch element 0 are projected with a PCA
    fitted on both runs stacked together.

    Args:
        model: path of a state-dict checkpoint passed to ``torch.load``.

    Returns:
        PFC_activity: ``|PC2(run 1) - PC2(run 2)|`` of the PFC hidden state,
            one value per step (length ``data_limit``).
        HPC_activity: absolute step-to-step change of the 5-point moving
            average of HPC PC1 for run 1 only (length ``data_limit + 4``,
            because ``moving_average`` uses full convolution).
        Re_activity: ``|PC1(run 1) - PC1(run 2)|`` of the Re hidden state
            (length ``data_limit``).
    """
    training_size = 100
    # test_size, epochs_num, data_length and colors below are unused here.
    test_size = 1000
    epochs_num = 10
    hidden_size = 20
    batch_size = 10
    data_length = 100
    inputsize = 2
    outputsize = 2
    delay_length = 3
    
    model_path = model
    
    colors = ['C0','C1','C2','C3','C4','C5','C6','C7','C8','C9']
    
#     delay_length = np.random.randint(1, 3)
    # Overrides the value assigned above.
    delay_length = 2
    train_x1,train_x2,train_y1,train_y2 = mkOwnDataSet_auto(training_size,delay_length)

#     rnn = MyLSTM_comp(inputsize, hidden_size, outputsize, batch_size)
    rnn = MyLSTM_RNN(inputsize, hidden_size, outputsize, batch_size)
#     rnn = MyLSTM_RNN_uniHPC(inputsize, hidden_size, outputsize, batch_size)
#     rnn = MyMTRNN2(inputsize, hidden_size, outputsize, batch_size)

    rnn.load_state_dict(torch.load(model_path))
    
    
    # Input pattern for the first run: 1 -> train_x1, 2 -> train_x2,
    # 3 -> train_y1, otherwise train_y2.
    pattern = 1
    if  pattern == 1:
        data = mkOwnRandomBatch(train_x1, train_x1, batch_size).float()
    elif pattern == 2:
        data = mkOwnRandomBatch(train_x2, train_x2, batch_size).float()
    elif pattern == 3:
        data = mkOwnRandomBatch(train_y1, train_y1, batch_size).float()
    else:
        data = mkOwnRandomBatch(train_y2, train_y2, batch_size).float()
    
    # Parameter tensors in registration order, handed to Culc_gate below.
    params = []
    for p in rnn.parameters():
        params.append(p.data)
    
        
#     model_path = 'model/R20_131/ReModel_L2_interRNNrand_OUT1_131_s6_100_1_epoch125.pth'
    # NOTE(review): reloads the same checkpoint already loaded above; redundant.
    model_path = model
    rnn.load_state_dict(torch.load(model_path))
        
    # The weight copies made here (PFC_w_pre, HPC_w_pre, Re_w_pre, Re_inw,
    # OUT_w_pre) are not used later in this function.
    for n, p in rnn.named_parameters():
            if n == "PFC.weight_ih":
                PFC_w_pre = torch.clone(p.data)
            if n == "HPC.weight_ih":
                HPC_w_pre = torch.clone(p.data)
            if n == "Re.weight_ih":
                Re_w_pre = torch.clone(p.data)
            if n == "Re.weight_hh":
                Re_inw = torch.clone(p.data)
            if n == "linear.weight":
                OUT_w_pre = torch.clone(p.data)
                   
    # Run 1: record outputs, hidden states and gate values at every step.
    traj = []
    PFCstate = []
    HPCstate = []
    Restate = []
    Gate_states = []
    hidden = rnn.initHidden_rand()
    data_limit = 120
    est_length = 0
    # init_point is unused, but the torch.rand call still advances the RNG.
    init_point = torch.rand(10,2)*1
    for k in range(data_limit):
            #print(data[k].shape)
            # Assumes data[k] is the batch input for step k (layout comes from
            # mkOwnRandomBatch, not visible here) -- TODO confirm.
            output,hidden = rnn(data[k],hidden)
#             output,hidden = rnn(init_point,hidden)
            traj.append(output.tolist())
            # PFC/HPC store the whole batch (indexed [:,0] later); Re stores
            # only batch element 0.
            PFCstate.append(hidden[0][0].tolist())
            HPCstate.append(hidden[1][0].tolist())
            Restate.append(hidden[2][0].tolist())
            # NOTE(review): passes the network output, not data[k], as
            # Culc_gate's first argument -- confirm this is intended.
            Gate_states.append(Culc_gate(output,params,hidden))
#             Gate_states.append(Culc_gate_uniPFC(output,params,hidden))
#     for k in range(data.shape[0]*est_length):
#             output,hidden = rnn(output,hidden)
#             traj.append(output.tolist())
#             PFCstate.append(hidden[0][0].tolist())
#             HPCstate.append(hidden[1][0].tolist())
#             Restate.append(hidden[2][0].tolist())
#             Gate_states.append(Culc_gate(output,params,hidden))
    traj = torch.tensor(traj)
    traj = torch.squeeze(traj).numpy()
    
    # Input pattern for the second run (3 -> train_y1).
    pattern = 3
    if  pattern == 1:
        data = mkOwnRandomBatch(train_x1, train_x1, batch_size).float()
    elif pattern == 2:
        data = mkOwnRandomBatch(train_x2, train_x2, batch_size).float()
    elif pattern == 3:
        data = mkOwnRandomBatch(train_y1, train_y1, batch_size).float()
    else:
        data = mkOwnRandomBatch(train_y2, train_y2, batch_size).float()
    
    # Unused (was only referenced by a commented-out noise call below).
    dividenum = int(np.array(PFCstate)[0:,0].shape[0]/2)
    # Run 2: same recording (no gates) on the second pattern from a fresh
    # random hidden state.
    traj_noise = []
    PFCstate_noise = []
    HPCstate_noise = []
    Restate_noise = []
    Gate_states_noise = []
    hidden = rnn.initHidden_rand()
#     data = mkOwnRandomBatch(train_y, batch_size)
#     init_point = torch.rand(10,2)*1
#     data_limit = 125
    for k in range(data_limit):
            #print(data[k].shape)
            output,hidden = rnn(data[k],hidden)
#             output,hidden = rnn(init_point,hidden)
            traj_noise.append(output.tolist())
            PFCstate_noise.append(hidden[0][0].tolist())
            HPCstate_noise.append(hidden[1][0].tolist())
            Restate_noise.append(hidden[2][0].tolist())
#             Gate_states_noise.append(Culc_gate(output,params,hidden))
#     for k in range(data.shape[0]*est_length):
# #             hidden = rnn.noiseHidden_rand(hidden)
#             output,hidden = rnn(output,hidden) 
# #             hidden = rnn.noiseHidden_dis(hidden, np.array(Restate)[k+data_limit], np.array(Restate)[dividenum+k+data_limit])
# #             hidden = rnn.noiseHidden_rand(hidden)
#             traj_noise.append(output.tolist())
#             PFCstate_noise.append(hidden[0][0].tolist())
#             HPCstate_noise.append(hidden[1][0].tolist())
#             Restate_noise.append(hidden[2][0].tolist())
#             Gate_states_noise.append(Culc_gate(output,params,hidden))
    traj_noise = torch.tensor(traj_noise)
    traj_noise = torch.squeeze(traj_noise).numpy()

    
    print(np.array(PFCstate)[:,0].shape)
#     MakeAnimation(pltdata[:,0,0],pltdata[:,0,1], traj[:,0,0], traj[:,0,1], data_limit)
    #MakeAnimation_img(np.array(PFCstate)[:,0],"PFC")
    #MakeAnimation_img(np.array(HPCstate)[:,0],"HPC")
    
    
    # PFC: PCA fitted on both runs stacked (rows 0..data_limit-1 are run 1,
    # the rest run 2); activity is the per-step gap on the second component.
    pca = PCA()
#     dfs = np.array(PFCstate)[:,0]
    dfs = np.concatenate((np.array(PFCstate)[:,0],np.array(PFCstate_noise)[:,0]),axis=0)
#     dfs = np.concatenate((np.array(HPCstate)[:,0],np.array(HPCstate_noise)[:,0]),axis=0)
    print(dfs.shape)
    pca.fit(dfs)
    feature = pca.transform(dfs)
    print("PFC correlation")
    PFC_activity = np.abs(feature[:data_limit,1]-feature[data_limit:,1])
#     PFC_activity = np.abs(moving_average(feature[:data_limit,0])-moving_average(feature[data_limit:,0]))
#     PFC_activity = np.abs((feature[:data_limit,0]-moving_average(feature[:data_limit,0])[2:-2])-(feature[data_limit:,0]-moving_average(feature[data_limit:,0])[2:-2]))
#     PFC_activity = culc_grad(moving_average(feature[:data_limit,0]))
    
#     plt.figure()
#     plt.plot(feature[:data_limit,0],color="b")
#     plt.plot(feature[data_limit:,0],color="r")
#     plt.plot(moving_average(feature[:data_limit,0]),":",color="b")
#     plt.plot(moving_average(feature[data_limit:,0]),":",color="r")
#     plt.show()
    
    # HPC: same stacked PCA, but activity uses only run 1 (smoothed PC1
    # step sizes; length data_limit + 4).
    pca = PCA()
#     dfs = np.array(PFCstate)[:,0]
#     dfs = np.concatenate((np.array(PFCstate)[:,0],np.array(PFCstate_noise)[:,0]),axis=0)
    dfs = np.concatenate((np.array(HPCstate)[:,0],np.array(HPCstate_noise)[:,0]),axis=0)
    print(dfs.shape)
    pca.fit(dfs)
    feature = pca.transform(dfs)
    print("HPC correlation")
#     HPC_activity = np.abs(feature[:data_limit,0]-feature[data_limit:,0])
    HPC_activity = culc_grad(moving_average(feature[:data_limit,0]))


    # Re: stacked PCA; activity is the per-step gap between runs on PC1.
    pca = PCA()
#     dfs = np.array(PFCstate)[:,0]
#     dfs = np.concatenate((np.array(PFCstate)[:,0],np.array(PFCstate_noise)[:,0]),axis=0)
    dfs = np.concatenate((np.array(Restate)[:],np.array(Restate_noise)[:]),axis=0)
    print(dfs.shape)
    pca.fit(dfs)
    feature = pca.transform(dfs)
    print("Re correlation")
#     print(np.corrcoef(np.abs(traj[:,0,0]-traj_noise[:,0,0]),np.abs(feature[:100,0]-feature[100:,0])))
#     print(np.corrcoef(np.abs(traj[:,0,0]-traj_noise[:,0,0]),np.abs(feature[:100,1]-feature[100:,1])))
#     print(np.corrcoef(np.abs(traj[:,0,0]-traj_noise[:,0,0]),np.abs(feature[:100,2]-feature[100:,2])))
    Re_activity = np.abs(feature[:data_limit,0]-feature[data_limit:,0])
#     Re_activity = np.abs(moving_average(feature[:data_limit,0])-moving_average(feature[data_limit:,0]))
#     Re_activity = np.abs((feature[:data_limit,0]-moving_average(feature[:data_limit,0])[2:-2])-(feature[data_limit:,0]-moving_average(feature[data_limit:,0])[2:-2]))
#     Re_activity = culc_grad(moving_average(feature[:data_limit,0]))
    
#     plt.figure()
#     plt.plot(feature[:data_limit,0],color="b")
#     plt.plot(feature[data_limit:,0],color="r")
#     plt.plot(moving_average(feature[:data_limit,0]),":",color="b")
#     plt.plot(moving_average(feature[data_limit:,0]),":",color="r")
#     plt.show()


    
    return PFC_activity, HPC_activity, Re_activity
#     return np.mean([PFC_a,PFC_b,PFC_c]),np.mean([HPC_a,HPC_b,HPC_c]),np.max([PFC_d]),HPC_d
#     return np.mean([PFC_a,PFC_b,PFC_c]),np.mean([HPC_a,HPC_b,HPC_c]),np.max([PFC_d]), HPC_e, frac_amp
#     return np.mean([PFC_a,PFC_b,PFC_c]),np.mean([HPC_a,HPC_b,HPC_c]),np.max([PFC_d]), HPC_e, np.sum(coherence_diff_list)
#     return np.max([PFC_a,PFC_b,PFC_c]),np.max([HPC_a,HPC_b,HPC_c])


if __name__ == '__main__':
    # Run main() on the checkpoints matching seed pattern 's1' for runs 1-3
    # (epochs up to 299), collect each checkpoint's PFC/HPC/Re activity
    # curves, and plot the mean Re curve with +/- one standard deviation.
    fig = plt.figure()
    correlation_fig = fig.subplots()
#     correlation_fig.set_xlim(-1,1)
#     correlation_fig.set_ylim(0,1)

#     model_list = glob.glob('model/R20_131/*OUT1**s8_100_2_*epoch*.pth')
#     model_list = sorted(model_list)
#     model_list = sorted(model_list,key=len,reverse=False)
#     k=0
#     for model in model_list:
#         print(model,PFC,HPC)
#         PFC,HPC = main(model)
# #         correlation_fig.plot(PFC,HPC,"o")
# #         correlation_fig.text(PFC,HPC,model.split("epoch")[-1].split(".")[0])
#         correlation_fig.plot(k*5,PFC/HPC,"o")
# #         correlation_fig.text(k*5,PFC/HPC,model.split("epoch")[-1].split(".")[0])
#         k+=1

    allRe_list = []
    allHPC_list = []
    allPFC_list = []
    for num in range(3):
        # range(1): only the 's1' checkpoints are visited.
        for i in range(1):
            path = 'model/R20_H_bigbatch/'
#             path = 'model/R20_H_uniHPC_bigbatch/'
            model_list = glob.glob(path+'*s'+str(i+1)+'_100_'+str(num+1)+'_*epoch*.pth')
            # Lexicographic sort followed by a stable sort on length: shorter
            # names (smaller epoch numbers) first.
            model_list = sorted(model_list)
            model_list = sorted(model_list,key=len,reverse=False)
            # Re_list, ratio_list_max, good_list, first_goodmodel and good_flag
            # are not used below (the good-list filter is commented out).
            Re_list = []
            ratio_list_max = []
            with open(path+"good_list.txt", mode="r") as f:
                good_list = f.read().splitlines()
#             good_list = []
#             if i+1 == 4 and num+1 == 3:
#                 continue
#             if i+1 == 5 and num+1 == 2:
#                 continue
            
            first_goodmodel = [0,0]
            good_flag = False
            k=0
            for model in model_list:
#                 if model in good_list:
#                     pass
#                 else:
#                     continue
                print(model)
                # Epoch = text after the last "epoch" up to the first "."; skip
                # checkpoints beyond epoch 299.
                if int(model.split("epoch")[-1].split(".")[0])>299:
                    continue
                PFC_activity, HPC_activity, Re_activity = main(model)
                allRe_list.append(Re_activity)
                allHPC_list.append(HPC_activity)
                allPFC_list.append(PFC_activity)

                k+=1
#             correlation_fig.plot(Re_activity,color="C{}".format(i))
    # Rows = checkpoints, columns = time steps.
    allRe_list = np.array(allRe_list)
    allHPC_list = np.array(allHPC_list)
    allPFC_list = np.array(allPFC_list)
    print(allRe_list.shape)
    # Mean Re activity per step; error bars are sqrt(var) = standard deviation.
    correlation_fig.plot(np.mean(allRe_list,axis=0).T,color="C0")
    correlation_fig.errorbar(np.arange(allRe_list.shape[1]),np.mean(allRe_list,axis=0).T,yerr=np.sqrt(np.var(allRe_list,axis=0).T), color="C0", alpha=0.3)
#     correlation_fig.plot(np.mean(allHPC_list,axis=0).T,color="C1")
#     correlation_fig.errorbar(np.arange(allHPC_list.shape[1]),np.mean(allHPC_list,axis=0).T,yerr=np.sqrt(np.var(allHPC_list,axis=0).T), color="C1", alpha=0.3)
#     correlation_fig.plot(np.mean(allPFC_list,axis=0).T,color="C2")
#     correlation_fig.errorbar(np.arange(allPFC_list.shape[1]),np.mean(allPFC_list,axis=0).T,yerr=np.sqrt(np.var(allPFC_list,axis=0).T), color="C2", alpha=0.3)

#             allratio_list.append(np.array(ratio_list_max))
        

In [None]:
##########  Coherence and Beta Anova  #########################
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import time
import math
import matplotlib.pyplot as plt
import matplotlib.animation as animation
import sklearn 
import nolds
from sklearn.decomposition import PCA
from sklearn.cluster import KMeans
from mpl_toolkits.mplot3d import Axes3D
from scipy.special import expit
import pandas as pd
import glob
import scipy

%matplotlib notebook

def main(model):
    """Fit a beta distribution to recorded gate values for one checkpoint.

    Builds a ``MyLSTM_RNN`` (hidden size 20, batch 10), loads the weights
    stored at ``model`` and runs it for ``data_limit`` (480) steps on a batch
    built by ``mkOwnRandomBatch`` from ``train_x1``, recording hidden states
    and ``Culc_gate`` outputs. A Re/HPC STFT "coherence" sweep is also
    computed into ``coherence_diff_list`` but is not part of the return value.

    Args:
        model: path of a state-dict checkpoint passed to ``torch.load``.

    Returns:
        ``(0, 0, 0, 0, beta_param)``. The first four slots are constant
        placeholders, so callers can keep unpacking five values.
        ``beta_param`` is the mean of the shape parameters ``a`` and ``b`` of
        ``scipy.stats.beta.fit`` (``loc`` fixed at 0) applied to
        ``Gate_states[:, 0, 0:20]``, or 0.5 if the fit raises.
    """
    training_size = 100
    # test_size, epochs_num, data_length and colors below are unused here.
    test_size = 1000
    epochs_num = 10
    hidden_size = 20
    batch_size = 10
    data_length = 100
    inputsize = 2
    outputsize = 2
    
    model_path = model
    
    colors = ['C0','C1','C2','C3','C4','C5','C6','C7','C8','C9']
    
#     delay_length = np.random.randint(1, 3)
    delay_length = 2
    train_x1,train_x2,train_y1,train_y2 = mkOwnDataSet_auto(training_size,delay_length)

#     rnn = MyLSTM_comp(inputsize, hidden_size, outputsize, batch_size)
    rnn = MyLSTM_RNN(inputsize, hidden_size, outputsize, batch_size)
#     rnn = MyLSTM_RNN_uniHPC(inputsize, hidden_size, outputsize, batch_size)
#     rnn = MyMTRNN2(inputsize, hidden_size, outputsize, batch_size)

    rnn.load_state_dict(torch.load(model_path))
    
    
    # Input pattern: 1 -> train_x1, 2 -> train_x2, 3 -> train_y1, else train_y2.
    pattern = 1
    if  pattern == 1:
        data = mkOwnRandomBatch(train_x1, train_x1, batch_size).float()
    elif pattern == 2:
        data = mkOwnRandomBatch(train_x2, train_x2, batch_size).float()
    elif pattern == 3:
        data = mkOwnRandomBatch(train_y1, train_y1, batch_size).float()
    else:
        data = mkOwnRandomBatch(train_y2, train_y2, batch_size).float()
    
    # Parameter tensors in registration order, handed to Culc_gate below.
    params = []
    for p in rnn.parameters():
        params.append(p.data)
    
        
#     model_path = 'model/R20_131/ReModel_L2_interRNNrand_OUT1_131_s6_100_1_epoch125.pth'
    # NOTE(review): reloads the same checkpoint already loaded above; redundant.
    model_path = model
    rnn.load_state_dict(torch.load(model_path))
        
    # The weight copies made here are not used later in this function.
    for n, p in rnn.named_parameters():
            if n == "PFC.weight_ih":
                PFC_w_pre = torch.clone(p.data)
            if n == "HPC.weight_ih":
                HPC_w_pre = torch.clone(p.data)
            if n == "Re.weight_ih":
                Re_w_pre = torch.clone(p.data)
            if n == "Re.weight_hh":
                Re_inw = torch.clone(p.data)
            if n == "linear.weight":
                OUT_w_pre = torch.clone(p.data)
                   
    traj = []
    PFCstate = []
    HPCstate = []
    Restate = []
    Gate_states = []
    hidden = rnn.initHidden_rand()
    data_limit = 480
    est_length = 0
    # init_point is unused, but the torch.rand call still advances the RNG.
    init_point = torch.rand(10,2)*1
    for k in range(data_limit):
            #print(data[k].shape)
            output,hidden = rnn(data[k],hidden)
#             output,hidden = rnn(init_point,hidden)
            traj.append(output.tolist())
            # PFC/HPC store the whole batch; Re stores only batch element 0.
            PFCstate.append(hidden[0][0].tolist())
            HPCstate.append(hidden[1][0].tolist())
            Restate.append(hidden[2][0].tolist())
            # NOTE(review): passes the network output, not data[k], as
            # Culc_gate's first argument -- confirm this is intended.
            Gate_states.append(Culc_gate(output,params,hidden))
#             Gate_states.append(Culc_gate_uniPFC(output,params,hidden))
#     for k in range(data.shape[0]*est_length):
#             output,hidden = rnn(output,hidden)
#             traj.append(output.tolist())
#             PFCstate.append(hidden[0][0].tolist())
#             HPCstate.append(hidden[1][0].tolist())
#             Restate.append(hidden[2][0].tolist())
#             Gate_states.append(Culc_gate(output,params,hidden))
    traj = torch.tensor(traj)
    traj = torch.squeeze(traj).numpy()
    
    # Principal components of the Re hidden state over time.
    pca = PCA()
#     dfs = np.array(HPCstate)[:,0]
    dfs = np.array(Restate)
    pca.fit(dfs)
    Refeature = pca.transform(dfs)

    # STFT settings: 60-sample segments advanced by 2 samples.
    shift = 2
    seglen = 60
        
    # NOTE(review): this sweep's result is not returned (the return using it
    # is commented out). It also relies on a module-level ``signal``
    # (presumably scipy.signal) that is not imported in this cell.
    coherence_diff_list = []
    for i in range(20):
        for k in range(1):
            # Each signal is detrended by subtracting a 5-point moving average.
            # Both forms below yield length data_limit: the Re line trims the
            # averaged series, the PFC/HPC lines average the trimmed series.
            data = np.array(Refeature)[:,0]-moving_average(np.array(Refeature)[:,0])[2:-2]
            freqs,times,sx1 = signal.stft(data,fs=1,window="tukey",nperseg=seglen,noverlap=seglen-shift,detrend=False,boundary=None)
            # sx2 (PFC neuron i) is computed but only used by commented-out code.
            data = np.array(PFCstate)[:,k,i]-moving_average(np.array(PFCstate)[2:-2,k,i])[:]
            freqs,times,sx2 = signal.stft(data*1,fs=1,window="tukey",nperseg=seglen,noverlap=seglen-shift,detrend=False,boundary=None)
            data = np.array(HPCstate)[:,k,i]-moving_average(np.array(HPCstate)[2:-2,k,i])[:]
            freqs,times,sx3 = signal.stft(data,fs=1,window="tukey",nperseg=seglen,noverlap=seglen-shift,detrend=False,boundary=None)
            
#             xsp = sx1*np.conjugate(sx2)
#             coherence = (np.abs(xsp)**2)/(np.abs(sx1)*np.abs(sx2))
#             plt.figure()
#             plt.pcolormesh(times,freqs,coherence)
#             plt.title("Re-PFC neuron#"+str(i+1))
#             print(np.max(coherence),np.min(coherence))
            
#             plt.figure()
#             plt.plot(times, np.sum(coherence.T,axis=1))
            
            # NOTE(review): with no averaging across segments, this per-bin
            # ratio simplifies to |sx1|*|sx3|; it is not a coherence bounded by 1.
            xsp = sx1*np.conjugate(sx3)
            coherence = (np.abs(xsp)**2)/(np.abs(sx1)*np.abs(sx3))            
#             plt.figure()
#             plt.pcolormesh(times,freqs,coherence)
#             plt.title("Re-HPC neuron#"+str(i+1))
#             print(np.max(coherence),np.min(coherence))

            
#             xsp = sx2*np.conjugate(sx3)
#             coherence = (np.abs(xsp)**2)/(np.abs(sx2)*np.abs(sx3))
#             plt.figure()
#             plt.pcolormesh(times,freqs,coherence)
#             plt.title("PFC-HPC neuron#"+str(i+1))
            
#             degree = np.degrees(np.angle(xsp))
#             print(degree.shape)
#             plt.figure()
#             plt.plot(degree)

#             plt.figure()
#             plt.plot(times, np.sum(coherence.T,axis=1))
            # Range (max - min over time) of the frequency-summed value.
            coherence_diff_list.append(np.max(np.sum(coherence.T,axis=1))-np.min(np.sum(coherence.T,axis=1)))
    
    
    ######### beta params #############
    # Assumes each Gate_states entry is [PFC gates, HPC gates] as produced by
    # Culc_gate, so [:,0,0:20] takes the first 20 PFC gate values per step --
    # TODO confirm against the Culc_gate definition in use.
    data = np.array(Gate_states)[:,0,0:20].ravel()
    try:
        # beta.fit returns (a, b, loc, scale); loc is pinned to 0.
        beta_PFC = scipy.stats.beta.fit(data, floc=0)
        print(beta_PFC)
        beta_param = np.mean([beta_PFC[0],beta_PFC[1]])
    except Exception:
        print("Error: maybe takes negative a or b")
        beta_param = 0.5
    
    
#     return np.mean([PFC_a,PFC_b,PFC_c]),np.mean([HPC_a,HPC_b,HPC_c]),np.max([PFC_a,PFC_b,PFC_c]),np.max([HPC_a,HPC_b,HPC_c]),frac_amp
#     return np.mean([PFC_a,PFC_b,PFC_c]),np.mean([HPC_a,HPC_b,HPC_c]),np.max([PFC_d]),HPC_d
#     return np.mean([PFC_a,PFC_b,PFC_c]),np.mean([HPC_a,HPC_b,HPC_c]),np.max([PFC_d]), HPC_e, frac_amp
#     return 0,0,0,0, np.sum(coherence_diff_list)
    return 0,0,0,0, beta_param
#     return np.max([PFC_a,PFC_b,PFC_c]),np.max([HPC_a,HPC_b,HPC_c])


if __name__ == '__main__':
    # Score every checkpoint matching the glob below with main(), split the
    # scores into "good" (file name listed in good_list.txt) and "bad"
    # groups, plot them at x=0 (good, blue) and x=1 (bad, red), and test
    # whether the groups differ (one-way ANOVA and Welch's t-test).
    import scipy.stats

    fig = plt.figure()
    correlation_fig = fig.subplots()

    path = 'model/R20_H_stopinit_bigbatch/'
    # The good list does not depend on the loop variables, so read it once.
    with open(path+"good_list.txt", mode="r") as f:
        good_models = set(f.read().splitlines())

    allratio_list = []
    good_points = []
    bad_points = []
    for num in range(3):
        for i in range(10):
            model_glob = path+'*s'+str(i+1)+'_100_'+str(num+1)+'_*epoch???.pth'
            # Shortest names first, ties broken lexicographically.
            model_list = sorted(glob.glob(model_glob), key=lambda name: (len(name), name))

            scores = []
            good_indices = []
            for k, model in enumerate(model_list):
                print(model)
                # Only the fifth value returned by main() is used as the score.
                _, _, _, _, frac = main(model)
                scores.append(frac)
                if model in good_models:
                    good_indices.append(k)
            print(good_indices)

            scores = np.array(scores)
            good_set = set(good_indices)
            for k, score in enumerate(scores):
                if k in good_set:
                    correlation_fig.plot(0,score,"o",color="b", alpha=0.3)
                    good_points.append(score)
                else:
                    correlation_fig.plot(1,score,"o",color="r", alpha=0.3)
                    bad_points.append(score)
            allratio_list.append(scores)

    print(scipy.stats.f_oneway(good_points, bad_points))
    # Use scipy.stats explicitly: no ``st`` alias is imported in this cell.
    print(scipy.stats.ttest_ind(good_points, bad_points,equal_var=False))
    plt.figure()
    plt.boxplot([good_points, bad_points],labels=["Good","Bad"])

        

In [None]:
##########    Test for mixture weight!!!!!!!!!!!!!   #########################
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import time
import math
import matplotlib.pyplot as plt
import matplotlib.animation as animation
import sklearn 
import nolds
from sklearn.decomposition import PCA
from sklearn.cluster import KMeans
from mpl_toolkits.mplot3d import Axes3D
from scipy.special import expit
import pandas as pd
import glob
import scipy

%matplotlib notebook


class MyLSTM(nn.Module):
    """Three-region network in which PFC, HPC and Re are all LSTM cells.

    On each step Re reads the previous ``[PFC h, HPC h]``, HPC reads
    ``[input, Re h]`` and PFC reads ``[HPC h, Re h]``; the output is a linear
    readout of the new HPC h. Hidden layout: ``[PFC (h, c), HPC (h, c), Re (h, c)]``.
    """

    def __init__(self, input_size, hidden_size, output_size, batch_size):
        super().__init__()

        self.hidden_size = hidden_size
        self.PFC = nn.LSTMCell(2 * hidden_size, hidden_size)
        self.HPC = nn.LSTMCell(input_size + hidden_size, hidden_size)
        self.Re = nn.LSTMCell(2 * hidden_size, hidden_size)
        self.linear = nn.Linear(hidden_size, output_size)
        self.batch_size = batch_size
        # Sparsity 1 makes nn.init.sparse_ zero every entry of each column of
        # these PFC/HPC matrices; Re keeps its default initialisation.
        # Order (PFC.ih, HPC.ih, PFC.hh, HPC.hh) is kept for identical RNG use.
        for weight_name in ("weight_ih", "weight_hh"):
            for cell in (self.PFC, self.HPC):
                nn.init.sparse_(getattr(cell, weight_name).data, 1)

    def forward(self, input, hiddens):
        """Advance every region one step from the previous states in ``hiddens``."""
        pfc_prev, hpc_prev, re_prev = hiddens[0], hiddens[1], hiddens[2]
        re_next = self.Re(torch.cat([pfc_prev[0], hpc_prev[0]], dim=1), re_prev)
        hpc_next = self.HPC(torch.cat([input, re_prev[0]], dim=1), hpc_prev)
        pfc_next = self.PFC(torch.cat([hpc_prev[0], re_prev[0]], dim=1), pfc_prev)
        return self.linear(hpc_next[0]), [pfc_next, hpc_next, re_next]

    def initHidden(self):
        """Return all-zero ``[h, c]`` pairs for PFC, HPC and Re, in that order."""
        def zero_pair():
            return [torch.zeros(self.batch_size, self.hidden_size) for _ in range(2)]
        return [zero_pair(), zero_pair(), zero_pair()]

class MyLSTM_RNN(nn.Module):
    """Two LSTM areas (PFC, HPC) coupled through a vanilla-RNN area (Re).

    One ``forward`` step:
      * Re  (``nn.RNNCell``)  reads the previous PFC and HPC hidden states;
      * HPC (``nn.LSTMCell``) reads the external input and the previous Re state;
      * PFC (``nn.LSTMCell``) reads the previous HPC and Re hidden states;
      * the output is a linear readout of the new HPC hidden state.

    Hidden-state layout used by every method: ``[PFC (h, c), HPC (h, c), Re h]``,
    each tensor of shape ``(batch_size, hidden_size_<area>)``.

    Args:
        input_size: width of the external input.
        hidden_size: width of every area.
        output_size: width of the linear readout.
        batch_size: batch size used by the ``initHidden*`` helpers.
        sparse: initial sparsity of the PFC/HPC weight matrices in tenths
            (``5`` -> ``nn.init.sparse_`` zeroes half of each column).
    """

    def __init__(self, input_size, hidden_size, output_size, batch_size, sparse=5):
        super(MyLSTM_RNN, self).__init__()

        self.hidden_size_PFC = hidden_size
        self.hidden_size_HPC = hidden_size
        self.hidden_size_Re = hidden_size+0
        # Kept for code that reads ``rnn.hidden_size``; initHidden() used to read
        # it although it was never assigned (AttributeError).
        self.hidden_size = hidden_size
        # Creation order fixes the parameters() order (PFC, HPC, Re, linear) that
        # index-based consumers of ``rnn.parameters()`` rely on.
        self.PFC = nn.LSTMCell(self.hidden_size_HPC+self.hidden_size_Re, self.hidden_size_PFC)
        self.HPC = nn.LSTMCell(input_size+self.hidden_size_Re, self.hidden_size_HPC)
        self.Re = nn.RNNCell(self.hidden_size_HPC+self.hidden_size_PFC, self.hidden_size_Re)
        self.linear = nn.Linear(self.hidden_size_HPC, output_size)
        self.batch_size = batch_size
        sparse = 0.1*int(sparse)
        nn.init.sparse_(self.PFC.weight_ih.data,sparse)
        nn.init.sparse_(self.HPC.weight_ih.data,sparse)
        nn.init.sparse_(self.PFC.weight_hh.data,sparse)
        nn.init.sparse_(self.HPC.weight_hh.data,sparse)

    def forward(self, input, hiddens):
        """Advance all areas one step.

        Returns ``(output, [PFC (h, c), HPC (h, c), Re h])``; ``input`` is cast to
        float32 first.
        """
        input = input.float()
        Re_input = torch.cat([hiddens[0][0],hiddens[1][0]],dim=1)
        Re_hidden = self.Re(Re_input, hiddens[2])
        HPC_input = torch.cat([input,hiddens[2]],dim=1)
        HPC_hidden = self.HPC(HPC_input, hiddens[1])
        PFC_input = torch.cat([hiddens[1][0],hiddens[2]],dim=1)
        PFC_hidden = self.PFC(PFC_input, hiddens[0])
        output = self.linear(HPC_hidden[0])
        return output, [PFC_hidden, HPC_hidden, Re_hidden]

    def initHidden(self):
        """Return an all-zero hidden state ``[PFC (h, c), HPC (h, c), Re h]``."""
        HPC_hidden = [torch.zeros(self.batch_size, self.hidden_size_HPC), torch.zeros(self.batch_size, self.hidden_size_HPC)]
        PFC_hidden = [torch.zeros(self.batch_size, self.hidden_size_PFC), torch.zeros(self.batch_size, self.hidden_size_PFC)]
        Re_hidden = torch.zeros(self.batch_size, self.hidden_size_Re)
        return [PFC_hidden,HPC_hidden,Re_hidden]

    def initHidden_rand(self):
        """Return a hidden state drawn uniformly from ``[0, 0.1)`` for every tensor."""
        const = 0.1
        var = 0.1
        HPC_hidden = [torch.rand(self.batch_size, self.hidden_size_HPC)*const, torch.rand(self.batch_size, self.hidden_size_HPC)*const]
        PFC_hidden = [torch.rand(self.batch_size, self.hidden_size_PFC)*var, torch.rand(self.batch_size, self.hidden_size_PFC)*var]
        Re_hidden = torch.rand(self.batch_size, self.hidden_size_Re)*const
        return [PFC_hidden,HPC_hidden,Re_hidden]

    def initHidden_test(self):
        """Return a hidden state drawn uniformly from ``[0, 0.2)`` for every tensor."""
        const = 0.2
        var = 0.2
        HPC_hidden = [torch.rand(self.batch_size, self.hidden_size_HPC)*const, torch.rand(self.batch_size, self.hidden_size_HPC)*const]
        PFC_hidden = [torch.rand(self.batch_size, self.hidden_size_PFC)*var, torch.rand(self.batch_size, self.hidden_size_PFC)*var]
        Re_hidden = torch.rand(self.batch_size, self.hidden_size_Re)*var
        return [PFC_hidden,HPC_hidden,Re_hidden]

    def noiseHidden_rand(self, hiddens):
        """Add Gaussian noise to the ``h`` tensors in place and return ``hiddens``.

        With the current constants only Re is perturbed (std ``v``); the PFC/HPC
        scale ``c`` is 0, but their random numbers are still drawn.
        """
        c = 0
        v = 0.1
        HPC_hidden = torch.randn(self.batch_size, self.hidden_size_HPC)*c
        hiddens[1][0].add_(HPC_hidden)
        PFC_hidden = torch.randn(self.batch_size, self.hidden_size_PFC)*c
        hiddens[0][0].add_(PFC_hidden)
        Re_hidden = torch.randn(self.batch_size, self.hidden_size_Re)*v
        hiddens[2].add_(Re_hidden)
        return hiddens

    def noiseHidden_select(self, hiddens):
        """Add Gaussian noise to units ``[1, 2, 8, 9, 17]`` in place; return ``hiddens``.

        With the current constants only the Re units receive noise (std ``v``).
        """
        c = 0
        v = 0.01
        index = np.array([1,2,8,9,17])
        HPC_hidden = torch.zeros(self.batch_size, self.hidden_size_HPC)
        HPC_hidden[:,index] += torch.randn(self.batch_size, index.size)*c
        hiddens[1][0].add_(HPC_hidden)
        PFC_hidden = torch.zeros(self.batch_size, self.hidden_size_PFC)
        PFC_hidden[:,index] += torch.randn(self.batch_size, index.size)*c
        hiddens[0][0].add_(PFC_hidden)
        Re_hidden = torch.zeros(self.batch_size, self.hidden_size_Re)*c
        Re_hidden[:,index] += torch.randn(self.batch_size, index.size)*v
        hiddens[2].add_(Re_hidden)
        return hiddens

    def noiseHidden_dis(self, hiddens, statr, statl):
        """Shift the Re state in place by ``-v * (statl - statr)``; return ``hiddens``.

        ``statr``/``statl`` are assumed to be broadcastable to
        ``(batch_size, hidden_size_Re)`` — TODO confirm against callers.
        """
        c = 0
        v =-0.1
        index = np.array([1,2,8,9,17])
        HPC_hidden = torch.zeros(self.batch_size, self.hidden_size_HPC)
        HPC_hidden[:,index] += torch.randn(self.batch_size, index.size)*c
        hiddens[1][0].add_(HPC_hidden)
        PFC_hidden = torch.zeros(self.batch_size, self.hidden_size_PFC)
        PFC_hidden[:,index] += torch.randn(self.batch_size, index.size)*c
        hiddens[0][0].add_(PFC_hidden)
        Re_hidden = torch.randn(self.batch_size, self.hidden_size_Re)*0.0
        Re_hidden += (-statr+statl)*v
        hiddens[2].add_(Re_hidden)
        return hiddens

def main(model):
    """Load a ``MyLSTM_RNN`` checkpoint and return its HPC input-to-hidden weights.

    Args:
        model: path to a ``state_dict`` file saved with ``torch.save``.

    Returns:
        A float32 ``np.ndarray`` copy of ``HPC.weight_ih``.  PyTorch stacks the
        LSTM gate blocks along rows as (input, forget, cell, output), so with
        ``hidden_size=20`` rows 0:20 are the input gate and rows 40:60 the cell
        candidate.  Columns 0:2 act on the external input and 2: on the Re state
        (HPC's input is ``cat([input, Re])`` in ``MyLSTM_RNN.forward``).
    """
    hidden_size = 20
    batch_size = 10
    inputsize = 2
    outputsize = 2

    rnn = MyLSTM_RNN(inputsize, hidden_size, outputsize, batch_size)
    # map_location lets GPU-saved checkpoints load on a CPU-only machine.
    # NOTE: torch.load unpickles the file — only load trusted checkpoints.
    rnn.load_state_dict(torch.load(model, map_location="cpu"))
    return rnn.HPC.weight_ih.detach().cpu().numpy().copy()


if __name__ == '__main__':
    # Explicit import: ``import scipy`` alone is not guaranteed to load the
    # ``stats`` submodule on older SciPy releases.
    import scipy.stats

    # Set to True to analyse only checkpoints listed in good_list.txt.
    filter_good = False
    path = 'model/R20_H_bigbatch/'
    good_list = set()
    if filter_good:
        with open(path + "good_list.txt", mode="r") as f:
            good_list = set(f.read().splitlines())

    fig = plt.figure()
    w_list = []
    w_input = []  # mean |w| of the input-gate rows (0:20) on the Re columns (2:)
    w_cell = []   # mean |w| of the cell-candidate rows (40:60) on the Re columns (2:)
    for num in range(3):
        for i in range(10):
            model_list = glob.glob(path+'*s'+str(i+1)+'_100_'+str(num+1)+'_*epoch1??.pth')
            # Lexicographic sort, then a stable sort by length -> numeric epoch order.
            model_list = sorted(sorted(model_list), key=len)
            for model in model_list:
                if filter_good:
                    if model not in good_list:
                        continue
                    print(model)
                weight = main(model)
                abs_weight = np.abs(weight)
                w_list.append(np.array(weight))
                w_input.append(np.mean(abs_weight[0:20, 2:]))
                w_cell.append(np.mean(abs_weight[40:60, 2:]))

    if not w_list:
        print("No checkpoints matched under", path)
    else:
        plt.imshow(np.mean(w_list, axis=0), cmap="coolwarm")
        print(np.mean(w_input), np.mean(w_cell))
        print(scipy.stats.f_oneway(w_input, w_cell))
        # Welch's t-test (unequal variances).
        print(scipy.stats.ttest_ind(w_input, w_cell, equal_var=False))
        plt.figure()
        plt.plot(np.zeros(len(w_input)), w_input, "o")
        plt.plot(np.ones(len(w_cell)), w_cell, "o")


In [None]:
##########  PFC recurrent-weight eigenvalues vs. PFC-state cluster count  #########################
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import time
import math
import matplotlib.pyplot as plt
import matplotlib.animation as animation
import sklearn 
import nolds
from sklearn.decomposition import PCA
from sklearn.cluster import KMeans
from mpl_toolkits.mplot3d import Axes3D
from scipy.special import expit
import pandas as pd
import glob
import scipy
from scipy.cluster import hierarchy as hier

%matplotlib notebook

def Culc_gate(input, params, hiddens):
    """Recompute sigmoid-squashed LSTM gate pre-activations of PFC and HPC (batch element 0).

    Args:
        input: tensor ``(batch, input_size)`` fed to HPC together with the Re state.
        params: ``[p.data for p in rnn.parameters()]`` of a ``MyLSTM_RNN``;
            indices 0-3 are PFC ``weight_ih, weight_hh, bias_ih, bias_hh`` and
            4-7 the same for HPC.
        hiddens: ``[PFC (h, c), HPC (h, c), Re h]`` as returned by ``MyLSTM_RNN``.

    Returns:
        ``[PFC_gates, HPC_gates]`` as Python lists of length ``4 * hidden_size``
        in PyTorch's (input, forget, cell, output) row order.  Sigmoid is applied
        to all four blocks, including the cell candidate (which the LSTM itself
        passes through tanh).
    """
    with torch.no_grad():
        HPC_input = torch.cat([input, hiddens[2]], dim=1)
        PFC_input = torch.cat([hiddens[1][0][0], hiddens[2][0]])
        # LSTMCell's gates use W_hh @ h — index 0 of the (h, c) pair — not the
        # cell state c at index 1.
        PFC_gates = params[0]@PFC_input + params[1]@hiddens[0][0][0] + params[2] + params[3]
        HPC_gates = params[4]@HPC_input[0] + params[5]@hiddens[1][0][0] + params[6] + params[7]
        PFC_gates = torch.sigmoid(PFC_gates)
        HPC_gates = torch.sigmoid(HPC_gates)

    return [PFC_gates.tolist(), HPC_gates.tolist()]


def main(model):
    """Relate a checkpoint's PFC recurrent weights to the clustering of its PFC states.

    Drives a ``MyLSTM_RNN`` loaded from ``model`` with a random batch of
    pattern-1 training sequences for ``data_limit`` steps, then hierarchically
    clusters the PFC hidden states of batch element 0 over time.

    Args:
        model: path to a ``state_dict`` checkpoint.

    Returns:
        ``(mean |eigenvalue|, n_clusters)``.  The eigenvalues are those of the
        element-wise product of the input-gate and cell-candidate row blocks of
        ``PFC.weight_hh``.  ``n_clusters`` is the number of clusters cut at
        distance 0.6 from ``scipy.cluster.hierarchy.linkage`` (default method).
    """
    training_size = 100
    hidden_size = 20
    batch_size = 10
    inputsize = 2
    outputsize = 2
    delay_length = 2
    data_limit = 120

    train_x1, train_x2, train_y1, train_y2 = mkOwnDataSet_auto(training_size, delay_length)

    rnn = MyLSTM_RNN(inputsize, hidden_size, outputsize, batch_size)
    # NOTE: torch.load unpickles the file — only load trusted checkpoints.
    rnn.load_state_dict(torch.load(model, map_location="cpu"))

    pattern = 1
    source = {1: train_x1, 2: train_x2, 3: train_y1}.get(pattern, train_y2)
    # presumably (time, batch, 2) — indexed by step below; TODO confirm.
    data = mkOwnRandomBatch(source, source, batch_size).float()

    PFC_inw = rnn.PFC.weight_hh.detach().cpu().numpy()

    PFCstate = []
    hidden = rnn.initHidden_rand()
    with torch.no_grad():  # inference only; avoids building an autograd graph
        for k in range(data_limit):
            _, hidden = rnn(data[k], hidden)
            PFCstate.append(hidden[0][0][0].tolist())  # PFC h, batch element 0
    PFCstate = np.array(PFCstate)

    dend = hier.linkage(PFCstate)
    clusters = hier.fcluster(dend, 0.6, criterion="distance")
    print(np.max(clusters))

    # Gate row blocks of weight_hh: [0:H] input gate, [2H:3H] cell candidate.
    input_gate = PFC_inw[0:hidden_size]
    cell_gate = PFC_inw[2*hidden_size:3*hidden_size]
    all_uw, _ = np.linalg.eig(input_gate * cell_gate)

    return np.mean(abs(all_uw)), np.max(clusters)


if __name__ == '__main__':
    # Explicit import: ``import scipy`` alone is not guaranteed to load
    # ``scipy.stats`` on older SciPy releases.
    import scipy.stats

    fig = plt.figure()
    correlation_fig = fig.subplots()

    path = 'model/R20_H_bigbatch/'
    model_ids = [9]  # the "s<N>" tag in checkpoint file names to analyse
    allratio_list = []
    for num in range(3):
        for model_id in model_ids:
            model_list = glob.glob(path+'*s'+str(model_id)+'_100_'+str(num+1)+'_*epoch*.pth')
            # Lexicographic sort, then a stable sort by length -> numeric epoch order.
            model_list = sorted(sorted(model_list), key=len)

            eigen_list = []
            cluster_counts = []
            for model in model_list:
                print(model)
                eigen_mean, n_clusters = main(model)
                eigen_list.append(eigen_mean)
                cluster_counts.append(n_clusters)

            cluster_counts = np.array(cluster_counts)
            allratio_list.append(cluster_counts)
            correlation_fig.plot(eigen_list, cluster_counts, "o")
            if len(eigen_list) >= 2:
                print(scipy.stats.pearsonr(eigen_list, cluster_counts))
            else:
                print("Too few checkpoints for a correlation:", len(eigen_list))
        

In [None]:
##########  Coherence and Beta Anova  #########################
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import time
import math
import matplotlib.pyplot as plt
import matplotlib.animation as animation
import sklearn 
import nolds
from sklearn.decomposition import PCA
from sklearn.cluster import KMeans
from mpl_toolkits.mplot3d import Axes3D
from scipy.special import expit
import pandas as pd
import glob
import scipy

%matplotlib notebook


def make_traj(delay_length, noise=0.005, data_length=40, sequences=None):
    """Build the four input trajectories used for testing.

    Segment codes: 1 -> ``make_Htraj(1, 1)``, 2 -> ``make_Htraj(1, 2)``,
    3 -> ``make_Htraj(2, 1)``, 4 -> ``make_Htraj(2, 2)``, 0 -> a "stay" segment of
    ``data_length`` rows sampled with replacement from ``stay1.csv``.  In every
    trajectory each moving segment (codes 1-4) is followed by ``delay_length``
    stay segments, and Gaussian noise of std ``noise`` is added to the moving
    segments.

    Args:
        delay_length: number of stay segments inserted after each moving segment.
        noise: std of the Gaussian noise added to moving segments.
        data_length: number of rows in each stay segment.
        sequences: iterable of four code sequences; defaults to the orders
            ``(1,4,2,3), (2,3,1,4), (3,1,4,2), (4,2,3,1)``.

    Returns:
        Tuple with one ``(length, 2)`` float array per sequence.
    """
    if sequences is None:
        sequences = ((1, 4, 2, 3), (2, 3, 1, 4), (3, 1, 4, 2), (4, 2, 3, 1))

    moving_segments = {
        1: make_Htraj(1, 1),
        2: make_Htraj(1, 2),
        3: make_Htraj(2, 1),
        4: make_Htraj(2, 2),
    }
    z = np.loadtxt("stay1.csv", delimiter=',')

    data_list = []
    for seq in sequences:
        codes = []
        for code in seq:
            codes.append(code)
            codes.extend([0] * delay_length)

        pieces = []
        for code in codes:
            if code == 0:
                rows = np.random.choice(z.shape[0], data_length)
                pieces.append(z[rows, :2])
            else:
                base = moving_segments[code]
                # Build a NEW array: an in-place ``+=`` would write the noise back
                # into the shared segment and accumulate it across sequences.
                pieces.append(base + np.random.normal(loc=0.0, scale=noise, size=base.shape))

        data_list.append(np.concatenate(pieces, axis=0) if pieces else np.empty((0, 2)))
    return tuple(data_list)

def mkOwnDataSet_auto(data_size, delay_length, freq=60., noise=0.0001):
    """Make ``data_size`` jittered copies of each of the four ``make_traj`` outputs.

    ``freq`` is accepted for signature compatibility and is unused.

    Returns:
        ``(train_x1, train_x2, train_y1, train_y2)``: each a list of ``data_size``
        samples, every sample a list of ``[x, y]`` points equal to the source
        trajectory plus independent Gaussian noise of std ``noise``.
    """
    x1, x2, y1, y2 = make_traj(delay_length)
    sources = (x1, x2, y1, y2)

    def jitter(traj):
        # x noise is drawn before y noise for each point, rows in order.
        return [[row[0] + np.random.normal(loc=0.0, scale=noise),
                 row[1] + np.random.normal(loc=0.0, scale=noise)]
                for row in traj]

    outputs = ([], [], [], [])
    for _ in range(data_size):
        for bucket, traj in zip(outputs, sources):
            bucket.append(jitter(traj))
    return outputs

def main(model):
    """Score a checkpoint by Re/HPC spectral co-variation and Re recurrent weights.

    Drives a ``MyLSTM_RNN`` loaded from ``model`` with a random batch of
    pattern-1 training sequences for ``data_limit`` steps, recording the HPC and
    Re hidden states of batch element 0.  The first principal component of the
    Re states and each HPC unit are detrended by subtracting ``moving_average``
    and passed through an STFT.  For each HPC unit, the per-window sum over
    frequency of the cross-spectral quantity below is taken, and its max-min
    range over windows is accumulated.

    Args:
        model: path to a ``state_dict`` checkpoint.

    Returns:
        ``(0, 0, 0, 0, coherence_sum, eigen_mean)``.  The four leading zeros are
        placeholders kept so callers that unpack six values keep working.
        ``eigen_mean`` is the mean |eigenvalue| of ``Re.weight_hh``.
    """
    # The bare name ``signal`` was not imported in this cell.
    from scipy import signal

    training_size = 100
    hidden_size = 20
    batch_size = 10
    inputsize = 2
    outputsize = 2
    delay_length = 2
    data_limit = 480
    shift = 2
    seglen = 60

    train_x1, train_x2, train_y1, train_y2 = mkOwnDataSet_auto(training_size, delay_length)

    rnn = MyLSTM_RNN(inputsize, hidden_size, outputsize, batch_size)
    # NOTE: torch.load unpickles the file — only load trusted checkpoints.
    rnn.load_state_dict(torch.load(model, map_location="cpu"))

    pattern = 1
    source = {1: train_x1, 2: train_x2, 3: train_y1}.get(pattern, train_y2)
    # presumably (time, batch, 2) — indexed by step below; TODO confirm.
    data = mkOwnRandomBatch(source, source, batch_size).float()

    Re_inw = rnn.Re.weight_hh.detach().cpu().numpy()

    HPCstate = []
    Restate = []
    hidden = rnn.initHidden_rand()
    with torch.no_grad():  # inference only; avoids building an autograd graph
        for k in range(data_limit):
            _, hidden = rnn(data[k], hidden)
            HPCstate.append(hidden[1][0][0].tolist())  # HPC h, batch element 0
            Restate.append(hidden[2][0].tolist())      # Re h, batch element 0
    HPC_arr = np.array(HPCstate)

    pca = PCA()
    dfs = np.array(Restate)
    pca.fit(dfs)
    Refeature = pca.transform(dfs)

    stft_kwargs = dict(fs=1, window="tukey", nperseg=seglen, noverlap=seglen-shift,
                       detrend=False, boundary=None)

    # The Re spectrum does not depend on the HPC unit, so compute it once.
    re_pc1 = Refeature[:, 0]
    re_detrended = re_pc1 - moving_average(re_pc1)[2:-2]
    _, _, sx1 = signal.stft(re_detrended, **stft_kwargs)

    coherence_diff_list = []
    for i in range(hidden_size):
        hpc_detrended = HPC_arr[:, i] - moving_average(HPC_arr[2:-2, i])[:]
        _, _, sx3 = signal.stft(hpc_detrended, **stft_kwargs)
        xsp = sx1*np.conjugate(sx3)
        # NOTE(review): with no averaging across windows, |xsp|**2 == (|sx1||sx3|)**2,
        # so this ratio equals |sx1|*|sx3| rather than a normalised coherence.
        coherence = (np.abs(xsp)**2)/(np.abs(sx1)*np.abs(sx3))
        per_window = np.sum(coherence.T, axis=1)
        coherence_diff_list.append(np.max(per_window) - np.min(per_window))

    all_uw, _ = np.linalg.eig(Re_inw)
    return 0, 0, 0, 0, np.sum(coherence_diff_list), np.mean(abs(all_uw))


if __name__ == '__main__':
    # ``st`` was never imported in this cell; use scipy.stats explicitly.
    import scipy.stats

    fig = plt.figure()
    correlation_fig = fig.subplots()

    path = 'model/R20_H_bigbatch/'
    # Read once (the file does not change between iterations); a set gives O(1) lookup.
    with open(path + "good_list.txt", mode="r") as f:
        good_set = set(f.read().splitlines())

    allratio_list = []
    good_points = []
    bad_points = []
    for num in range(3):
        for i in range(10):
            model_list = glob.glob(path+'*s'+str(i+1)+'_100_'+str(num+1)+'_*epoch*.pth')
            # Lexicographic sort, then a stable sort by length -> numeric epoch order.
            model_list = sorted(sorted(model_list), key=len)

            eigen_list = []
            frac_list = []
            good_flags = []
            for model in model_list:
                print(model)
                _, _, _, _, frac, eigen = main(model)
                frac_list.append(frac)
                eigen_list.append(eigen)
                good_flags.append(model in good_set)

            # Indices (within model_list) of the checkpoints listed as good.
            print([k for k, is_good in enumerate(good_flags) if is_good])
            for eigen, frac, is_good in zip(eigen_list, frac_list, good_flags):
                if is_good:
                    correlation_fig.plot(eigen, frac, "o", color="b", alpha=0.3)
                    good_points.append(frac)
                else:
                    correlation_fig.plot(eigen, frac, "o", color="r", alpha=0.3)
                    bad_points.append(frac)
            allratio_list.append(np.array(frac_list))

    if good_points and bad_points:
        print(scipy.stats.f_oneway(good_points, bad_points))
        # Welch's t-test (unequal variances).
        print(scipy.stats.ttest_ind(good_points, bad_points, equal_var=False))
    else:
        print("Need both good and bad checkpoints for the tests:",
              len(good_points), len(bad_points))
    plt.figure()
    plt.boxplot([good_points, bad_points], labels=["Good", "Bad"])

        

# Insert PCA or low-pass data

In [None]:
##########   Insert PCA or low-pass data  #########################
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import time
import math
import matplotlib.pyplot as plt
import matplotlib.animation as animation
import sklearn 
import nolds
from sklearn.decomposition import PCA
from sklearn.cluster import KMeans
from mpl_toolkits.mplot3d import Axes3D
from scipy.special import expit
import pandas as pd
import glob
import scipy

%matplotlib notebook

def Culc_gate(input, params, hiddens):
    """Recompute the PFC and HPC LSTM gate activations for batch element 0.

    The PFC input is [HPC h, Re] and the HPC input is [external input, Re],
    mirroring ``MyLSTM_RNN.forward``.

    Args:
        input: external input for this step; indexed as a 2-D (batch, features)
            tensor (assumption -- confirm against the caller).
        params: parameter tensors in ``model.parameters()`` order. Indices 0-3 are
            used as PFC weight_ih, weight_hh, bias_ih, bias_hh. Indices 4-7 are the
            same four for HPC (assumes PFC and HPC are the first registered
            submodules).
        hiddens: ``[PFC (h, c), HPC (h, c), Re]`` as returned by the model.

    Returns:
        ``[PFC_gates, HPC_gates]`` as Python lists of length ``4 * hidden_size``.
        Sigmoid is applied to every chunk. nn.LSTMCell passes the cell-candidate
        chunk through tanh instead, so only the i/f/o chunks are true gate values.
    """
    pfc_h = hiddens[0][0]
    hpc_h = hiddens[1][0]
    re_h = hiddens[2]
    HPC_input = torch.cat([input, re_h], dim=1)[0]
    PFC_input = torch.cat([hpc_h[0], re_h[0]])
    # LSTMCell gates are W_ih @ x + b_ih + W_hh @ h + b_hh. They use the hidden
    # output h (index 0 of the (h, c) tuple), not the cell state c (index 1),
    # which the previous version multiplied by weight_hh.
    PFC_gates = params[0] @ PFC_input + params[1] @ pfc_h[0] + params[2] + params[3]
    HPC_gates = params[4] @ HPC_input + params[5] @ hpc_h[0] + params[6] + params[7]
    return [torch.sigmoid(PFC_gates).tolist(), torch.sigmoid(HPC_gates).tolist()]

def Culc_gate_uniPFC(input, params, hiddens):
    """Recompute PFC and HPC LSTM gate activations for batch element 0, PFC fed by HPC only.

    The PFC input is the HPC hidden output alone (no Re term), so ``params[0]``
    needs ``hidden_size`` columns. The HPC input is [external input, Re].

    Args:
        input: external input for this step; indexed as a 2-D (batch, features)
            tensor (assumption -- confirm against the caller).
        params: parameter tensors in ``model.parameters()`` order. Indices 0-3 are
            used as PFC weight_ih, weight_hh, bias_ih, bias_hh. Indices 4-7 are the
            same four for HPC.
        hiddens: ``[PFC (h, c), HPC (h, c), Re]`` as returned by the model.

    Returns:
        ``[PFC_gates, HPC_gates]`` as Python lists. Sigmoid is applied to every
        chunk, including the cell-candidate chunk that nn.LSTMCell passes
        through tanh.
    """
    pfc_h = hiddens[0][0]
    hpc_h = hiddens[1][0]
    re_h = hiddens[2]
    HPC_input = torch.cat([input, re_h], dim=1)[0]
    PFC_input = hpc_h[0]
    # The recurrent term uses the hidden output h (index 0), not the cell
    # state c (index 1), matching the LSTMCell gate equations.
    PFC_gates = params[0] @ PFC_input + params[1] @ pfc_h[0] + params[2] + params[3]
    HPC_gates = params[4] @ HPC_input + params[5] @ hpc_h[0] + params[6] + params[7]
    return [torch.sigmoid(PFC_gates).tolist(), torch.sigmoid(HPC_gates).tolist()]

def Culc_gate_uniHPC(input, params, hiddens):
    """Recompute PFC and HPC LSTM gate activations for batch element 0 (uni-HPC variant).

    The PFC input is [HPC h, Re] and the HPC input is [external input, Re].

    Args:
        input: external input for this step; indexed as a 2-D (batch, features)
            tensor (assumption -- confirm against the caller).
        params: parameter tensors in ``model.parameters()`` order. Indices 0-3 are
            used as PFC weight_ih, weight_hh, bias_ih, bias_hh. Indices 4-7 are the
            same four for HPC.
        hiddens: ``[PFC (h, c), HPC (h, c), Re]`` as returned by the model.

    Returns:
        ``[PFC_gates, HPC_gates]`` as Python lists. Sigmoid is applied to every
        chunk, including the cell-candidate chunk that nn.LSTMCell passes
        through tanh.
    """
    pfc_h = hiddens[0][0]
    hpc_h = hiddens[1][0]
    re_h = hiddens[2]
    HPC_input = torch.cat([input, re_h], dim=1)[0]
    PFC_input = torch.cat([hpc_h[0], re_h[0]])
    # The recurrent term uses the hidden output h (index 0), not the cell
    # state c (index 1), matching the LSTMCell gate equations.
    PFC_gates = params[0] @ PFC_input + params[1] @ pfc_h[0] + params[2] + params[3]
    HPC_gates = params[4] @ HPC_input + params[5] @ hpc_h[0] + params[6] + params[7]
    return [torch.sigmoid(PFC_gates).tolist(), torch.sigmoid(HPC_gates).tolist()]
    
    

def MakeAnimation(data_x, data_y, traj_x, traj_y, data_limit,
                  filename='anim_v4_2_1_r.gif', interval=30):
    """Render a GIF of a generated 2-D trajectory, with reference data overlaid on early frames.

    Frame ``t`` draws the generated path ``traj[:t]`` in red, with a red marker
    on its newest point. While ``t < data_limit`` the reference path
    ``data[:t]`` is also drawn in blue. After that only the generated path is
    shown.

    Args:
        data_x, data_y: reference trajectory coordinates.
        traj_x, traj_y: generated trajectory coordinates; ``traj_x.shape[0]``
            sets the number of frames.
        data_limit: number of frames during which the reference path is drawn.
        filename: output GIF path (defaults to the previously hard-coded name).
        interval: delay between frames in milliseconds.

    Returns:
        0, kept for compatibility with existing callers.
    """
    fig, ax = plt.subplots()
    frames = []
    for t in range(traj_x.shape[0]):
        # The marker sits on the newest drawn point. At t == 0 nothing is drawn
        # yet, and t - 1 would wrap to the *final* point, so clamp it to 0.
        cur = max(t - 1, 0)
        if t < data_limit:
            artists = ax.plot(data_x[:t], data_y[:t], "b",
                              traj_x[:t], traj_y[:t], "r",
                              traj_x[cur], traj_y[cur], "or")
        else:
            artists = ax.plot(traj_x[:t], traj_y[:t], "r",
                              traj_x[cur], traj_y[cur], "or")
        title = ax.text(0.5, 1.01, 'time={}'.format(t),
                        ha='center', va='bottom',
                        transform=ax.transAxes, fontsize='large')
        frames.append(artists + [title])
    ani = animation.ArtistAnimation(fig, frames, interval=interval)
    ani.save(filename, writer='pillow')
    plt.show()

    return 0

def MakeAnimation_data(data_x, data_y, data_limit):
    """Animate a 2-D path being traced in blue and save it as ``anim_Re_0.gif``.

    The axes are fixed to x in [0.2, 0.8] and y in [-0.05, 0.8]. Frame ``t``
    shows ``data[:t]``. ``data_limit`` is accepted for signature compatibility
    but not used. Returns 0.
    """
    fig, ax = plt.subplots()
    ax.set_xlim(0.2, 0.8)
    ax.set_ylim(-0.05, 0.8)

    def build_frame(step):
        # One frame = the partial path plus a "time=<step>" caption.
        path = ax.plot(data_x[:step], data_y[:step], "b")
        caption = ax.text(0.5, 1.01, 'time={}'.format(step),
                          ha='center', va='bottom',
                          transform=ax.transAxes, fontsize='large')
        return path + [caption]

    frames = [build_frame(step) for step in range(data_x.shape[0])]
    ani = animation.ArtistAnimation(fig, frames, interval=30)
    ani.save('anim_Re_0.gif', writer='pillow')
    plt.show()

    return 0

def MakeAnimation_attracter(data_x, data_y, filename='anim_attracter2.gif', interval=10):
    """Animate a 2-D trajectory being traced and save it as a GIF.

    Frame ``t`` draws ``data[:t]`` as a blue line, with a blue marker on the
    newest drawn point.

    Args:
        data_x, data_y: trajectory coordinates; ``data_x.shape[0]`` sets the
            number of frames.
        filename: output GIF path (defaults to the previously hard-coded name).
        interval: delay between frames in milliseconds.

    Returns:
        0, kept for compatibility with existing callers.
    """
    fig, ax = plt.subplots()
    frames = []
    for t in range(data_x.shape[0]):
        # t - 1 would wrap to the last point at t == 0; clamp to the first point.
        cur = max(t - 1, 0)
        line = ax.plot(data_x[:t], data_y[:t], "b", data_x[cur], data_y[cur], "ob")
        title = ax.text(0.5, 1.01, 'time={}'.format(t),
                        ha='center', va='bottom',
                        transform=ax.transAxes, fontsize='large')
        frames.append(line + [title])
    ani = animation.ArtistAnimation(fig, frames, interval=interval)
    ani.save(filename, writer='pillow')
    plt.show()

    return 0


    
def MakeAnimation_img(data, module):
    """Animate each row of ``data`` as a one-pixel-high heat-map strip.

    Frame ``t`` shows ``data[t]`` reshaped to ``(1, data.shape[1])`` and is
    captioned with ``module`` and the step. The result is saved to
    ``anim_fire_<module>_.gif``. Returns 0.
    """
    fig, ax = plt.subplots()
    width = data.shape[1]

    def build_frame(step):
        strip = ax.imshow(data[step].reshape(1, width))
        caption = ax.text(0.5, 1.01, module + ' time={}'.format(step),
                          ha='center', va='bottom',
                          transform=ax.transAxes, fontsize='large')
        return [strip, caption]

    frames = [build_frame(step) for step in range(data.shape[0])]
    ani = animation.ArtistAnimation(fig, frames)
    ani.save('anim_fire_' + module + '_.gif', writer='pillow')
    plt.show()

    return 0

def mkOwnDataSet(data_size, filename, data_length=100, freq=60., noise=0.005):
    """Load a right/left trajectory pair from CSV and return noisy, down-sampled copies.

    Reads ``<filename>_r.csv`` and ``<filename>_l.csv`` (comma separated). Each
    copy keeps every ``step``-th row, where ``step = round(rows / data_length)``
    clamped to at least 1. Independent Gaussian noise (std ``noise``) is added
    to the first two columns. ``freq`` is unused.

    Returns:
        ``(train_x, train_y)``: lists of length ``data_size``; each entry is a
        list of ``[x, y]`` points.
    """
    # ndmin=2 keeps a single-row file 2-D so that row[0] / row[1] indexing works.
    right = np.loadtxt(filename + "_r.csv", delimiter=',', ndmin=2)
    left = np.loadtxt(filename + "_l.csv", delimiter=',', ndmin=2)

    def _noisy_copy(traj):
        # round() gives 0 when the file has fewer than data_length/2 rows, and
        # np.arange with step 0 raises, so never go below 1.
        step = max(1, round(traj.shape[0] / data_length))
        return [[traj[i][0] + np.random.normal(loc=0.0, scale=noise),
                 traj[i][1] + np.random.normal(loc=0.0, scale=noise)]
                for i in np.arange(0, traj.shape[0], step)]

    train_x = []
    train_y = []
    for _ in range(data_size):
        train_x.append(_noisy_copy(right))
        train_y.append(_noisy_copy(left))
    return train_x, train_y


def mkOwnRandomBatch(train_x, train_t, batch_size=10):
    """Sample a random mini-batch of sequences (with replacement) and return it time-major.

    Args:
        train_x: sequence of samples; each is presumably a (seq_len, features)
            nested list/array (assumption -- confirm against callers).
        train_t: unused. Kept for backward compatibility; only the inputs batch
            is returned.
        batch_size: number of samples to draw.

    Returns:
        torch tensor of shape ``(seq_len, batch_size, features)``, i.e. the
        stacked batch with dims 0 and 1 swapped.

    Raises:
        ValueError: if ``train_x`` is empty.
    """
    n_samples = len(train_x)
    if n_samples == 0:
        raise ValueError("train_x must contain at least one sample")
    # Sample over the whole dataset. The old upper bound (len - batch_size) never
    # picked the last batch_size samples and raised when len <= batch_size.
    batch_x = [train_x[np.random.randint(0, n_samples)] for _ in range(batch_size)]
    return torch.tensor(batch_x).transpose(0, 1)

def make_Ttraj(direction):
    """Build a noiseless 20-step T-maze run as an array of (x, y) rows.

    The path climbs the stem (x = 0.5, y 0 -> 0.5, 5 steps), goes out along
    y = 0.5 to the arm end and back (5 + 5 steps), then descends the stem
    (y 0.5 -> 0, 5 steps).

    Args:
        direction: 1 for the arm ending at x = 0.75, 2 for the arm ending at
            x = 0.25.

    Returns:
        float64 ndarray of shape (20, 2).

    Raises:
        ValueError: if ``direction`` is not 1 or 2. This previously surfaced as
            an UnboundLocalError.
    """
    if direction == 1:
        arm_end = 0.75
    elif direction == 2:
        arm_end = 0.25
    else:
        raise ValueError("direction must be 1 or 2, got {!r}".format(direction))

    n = 5
    traj_x = np.concatenate([
        np.full(n, 0.5),                   # stem up
        np.linspace(0.5, arm_end, n),      # out along the arm
        np.linspace(arm_end, 0.5, n),      # back to the junction
        np.full(n, 0.5),                   # stem down
    ])
    traj_y = np.concatenate([
        np.linspace(0, 0.5, n),
        np.full(n, 0.5),
        np.full(n, 0.5),
        np.linspace(0.5, 0, n),
    ])
    return np.stack([traj_x, traj_y], axis=1)

def make_Htraj(direction1, direction2):
    """Build a noiseless 40-step H-maze run as an array of (x, y) rows.

    Segments, in order:

    - stem up: x = 0.5, y 0 -> 0.5 (10 steps)
    - out to the side corridor: x 0.5 -> side, y = 0.5 (5 steps)
    - into the branch: y 0.5 -> turn (5 steps)
    - back out of the branch: y turn -> 0.5 (5 steps)
    - back to the stem: x side -> 0.5 (5 steps)
    - stem down: y 0.5 -> 0 (10 steps)

    Args:
        direction1: 1 -> side corridor at x = 0.75; 2 -> x = 0.25.
        direction2: 1 -> branch turns down to y = 0.25; 2 -> up to y = 0.75.

    Returns:
        float64 ndarray of shape (40, 2).

    Raises:
        ValueError: if either direction is not 1 or 2. This previously surfaced
            as an UnboundLocalError.
    """
    if direction1 == 1:
        side = 0.75
    elif direction1 == 2:
        side = 0.25
    else:
        raise ValueError("direction1 must be 1 or 2, got {!r}".format(direction1))
    if direction2 == 1:
        turn = 0.25
    elif direction2 == 2:
        turn = 0.75
    else:
        raise ValueError("direction2 must be 1 or 2, got {!r}".format(direction2))

    stem, arm = 10, 5
    traj_x = np.concatenate([
        np.full(stem, 0.5),
        np.linspace(0.5, side, arm),
        np.full(arm, side),
        np.full(arm, side),
        np.linspace(side, 0.5, arm),
        np.full(stem, 0.5),
    ])
    traj_y = np.concatenate([
        np.linspace(0, 0.5, stem),
        np.full(arm, 0.5),
        np.linspace(0.5, turn, arm),
        np.linspace(turn, 0.5, arm),
        np.full(arm, 0.5),
        np.linspace(0.5, 0, stem),
    ])
    return np.stack([traj_x, traj_y], axis=1)


def make_traj(delay_length):
    """Build four noisy H-maze episodes, each visiting four arms with delay periods between them.

    Arm ids map to ``make_Htraj`` templates:
    1 -> (1, 1), 2 -> (1, 2), 3 -> (2, 1), 4 -> (2, 2).
    Successive arms follow the transition table 1->4, 4->2, 2->3, 3->1.
    Episodes start at arms 1, 2, 3, 4 respectively, giving the sequences
    1-4-2-3, 2-3-1-4, 3-1-4-2 and 4-2-3-1.

    Each arm visit is the template plus fresh Gaussian noise (std 0.005).
    Between visits come ``delay_length`` delay slots. Each slot holds 40 rows
    sampled with replacement from ``stay1.csv`` (first two columns), read from
    the current working directory.

    Returns:
        A tuple of four ndarrays of (x, y) rows, one per episode.
    """
    noise = 0.005
    data_length = 40  # rows drawn from the "stay" recording per delay slot

    arm_templates = {
        1: make_Htraj(1, 1),
        2: make_Htraj(1, 2),
        3: make_Htraj(2, 1),
        4: make_Htraj(2, 2),
    }
    next_arm = {1: 4, 4: 2, 2: 3, 3: 1}
    z = np.loadtxt("stay1.csv", delimiter=',')

    episodes = []
    for start in (1, 2, 3, 4):
        arm = start
        segments = []
        for visit in range(4):
            if visit:
                for _ in range(delay_length):
                    rows = np.random.choice(z.shape[0], data_length)
                    segments.append(z[rows, :2])
            template = arm_templates[arm]
            # Build a new array. The previous `target += noise` modified the
            # shared template in place, so noise accumulated every time an
            # arm was reused.
            segments.append(template + np.random.normal(loc=0.0, scale=noise, size=template.shape))
            arm = next_arm[arm]
        episodes.append(np.concatenate(segments, axis=0))
    return episodes[0], episodes[1], episodes[2], episodes[3]

def mkOwnDataSet_auto(data_size, delay_length, freq=60., noise=0.01):
    """Return ``data_size`` independently jittered copies of each ``make_traj`` episode.

    ``freq`` is accepted for signature compatibility but unused. Every
    coordinate of every point gets its own Gaussian draw (std ``noise``); the x
    draw comes before the y draw.

    Returns:
        Four lists ``(train_x1, train_x2, train_y1, train_y2)``; each entry is a
        list of ``[x, y]`` points.
    """
    episodes = make_traj(delay_length)
    buckets = ([], [], [], [])

    def jitter(traj):
        noisy = []
        for point in traj:
            px = point[0] + np.random.normal(loc=0.0, scale=noise)
            py = point[1] + np.random.normal(loc=0.0, scale=noise)
            noisy.append([px, py])
        return noisy

    for _ in range(data_size):
        for episode, bucket in zip(episodes, buckets):
            bucket.append(jitter(episode))

    return buckets

def plot_distance_bet2traj(traj1, traj2, linelist):
    """Plot the per-step Euclidean distance between two trajectories, with vertical markers.

    Vertical lines are drawn at the x positions in ``linelist``, spanning the
    min-max range of the distances. Returns the list of distances.
    """
    first = np.array(traj1)
    second = np.array(traj2)
    distances = [np.linalg.norm(first[k] - second[k]) for k in range(first.shape[0])]

    plt.figure()
    plt.plot(distances)
    low, high = np.min(distances), np.max(distances)
    plt.vlines(linelist, low, high)

    return distances

def distance_bet2traj(traj1, traj2):
    """Return the per-step Euclidean distance between two trajectories as a list.

    Steps are paired by index over the length of ``traj1``.
    """
    first = np.array(traj1)
    second = np.array(traj2)
    return [np.linalg.norm(first[k] - second[k]) for k in range(first.shape[0])]

def plot_activity_bet2traj(traj1, traj2, linelist):
    """Plot per-unit absolute differences between two state trajectories.

    Only units whose difference exceeds 0.5 somewhere in steps 20-104 are
    plotted, and each such unit's index is printed. ``linelist`` is accepted
    but unused. Returns the list of per-step absolute-difference arrays.
    """
    first = np.array(traj1)
    second = np.array(traj2)
    threshold = 0.5
    diffs = [np.abs(first[k] - second[k]) for k in range(first.shape[0])]

    plt.figure()
    for unit, series in enumerate(np.array(diffs).T):
        if not np.any(series[20:105] > threshold):
            continue
        print(unit)
        plt.plot(series)

    return diffs

def search_delay(traj):
    """Return the step indices where y (column 1) crosses 0.45, alternating up/down.

    The first recorded index is the first step with y > 0.45. After that the
    function alternates between the next step with y < 0.45 and the next step
    with y > 0.45. The result is a float64 ndarray.
    """
    crossings = []
    above = False
    for step, y in enumerate(traj[:, 1]):
        if not above and y > 0.45:
            crossings.append(step)
            above = True
        if above and y < 0.45:
            crossings.append(step)
            above = False
    return np.array(crossings, dtype=float)

def pick_delay(traj, states):
    """Slice ``states`` over each period where the trajectory's y stays below 0.2.

    From step 5 onward, boundaries are recorded alternately at the first step
    with y (column 1) < 0.2 and the next step with y > 0.2. Each
    [enter, exit) pair yields ``states[enter:exit]``.

    NOTE(review): the final element appended is ``traj[last_boundary:]`` (or
    ``traj[0:]`` if there are no boundaries), a slice of the *trajectory*, not
    of ``states``. Confirm whether callers rely on that.
    """
    threshold = 0.2
    boundaries = []
    below = False
    for step in range(5, traj.shape[0]):
        y = traj[step, 1]
        if not below and y < threshold:
            boundaries.append(step)
            below = True
        if below and y > threshold:
            boundaries.append(step)
            below = False

    segments = [states[boundaries[j - 1]:boundaries[j]]
                for j in range(1, len(boundaries), 2)]
    tail_start = boundaries[-1] if boundaries else 0
    segments.append(traj[tail_start:])
    return segments

def pick_traj(traj, states):
    """Slice ``states`` into consecutive chunks that each end where y rises back above 0.2.

    From step 5 onward, boundaries are recorded alternately at the first step
    with y (column 1) < 0.2 and the next step with y > 0.2. Each chunk runs
    from the previous rise (or step 0) up to the next rise.

    NOTE(review): the final element appended is ``traj[last_rise:]`` (or
    ``traj[0:]``), a slice of the *trajectory*, not of ``states``. Confirm
    whether callers rely on that.
    """
    threshold = 0.2
    boundaries = []
    below = False
    for step in range(5, traj.shape[0]):
        y = traj[step, 1]
        if not below and y < threshold:
            boundaries.append(step)
            below = True
        if below and y > threshold:
            boundaries.append(step)
            below = False

    segments = []
    chunk_start = 0
    for rise in boundaries[1::2]:
        segments.append(states[chunk_start:rise])
        chunk_start = rise
    segments.append(traj[chunk_start:])
    return segments

def match_length(states_list):
    """Trim every sequence to the length of the shortest one, keeping each sequence's last elements.

    This aligns sequences on their ends.

    Args:
        states_list: sequence of sliceable sequences (lists, arrays, ...).

    Returns:
        New list of slices; empty if ``states_list`` is empty.
    """
    if len(states_list) == 0:
        return []
    # No artificial cap: the old initial value of 1000 silently truncated
    # sequences that were all longer than 1000 steps.
    shortest = min(len(state) for state in states_list)
    # Slice from len - shortest rather than -shortest: state[-0:] would return
    # the whole sequence when shortest == 0.
    return [state[len(state) - shortest:] for state in states_list]

def vec_var(datas):
    """Return the mean Euclidean distance of the samples from their centroid.

    Despite the name, this is not a variance: distances are not squared. Each
    ``np.linalg.norm`` is taken over a whole sample (Frobenius for 2-D samples).
    """
    samples = np.array(datas)
    centroid = np.average(samples, axis=0)
    total = sum(np.linalg.norm(sample - centroid) for sample in samples)
    return total / samples.shape[0]
    
def lyapunov_exp(data):
    """Return the mean of log|Δdata| over consecutive samples.

    This is a crude divergence-rate proxy, not a full Lyapunov-exponent
    estimate. Repeated consecutive values give log(0) = -inf.
    """
    step_sizes = np.abs(np.diff(data))
    return np.mean(np.log(step_sizes))

class MyLSTM(nn.Module):
    """Three-area model in which PFC, HPC and Re are all ``nn.LSTMCell``s.

    ``forward`` advances one step from the previous states
    ``hiddens = [PFC (h, c), HPC (h, c), Re (h, c)]``:
    Re reads [PFC h, HPC h]; HPC reads [input, previous Re h]; PFC reads
    [previous HPC h, previous Re h]. The output is a linear read-out of the
    new HPC hidden state.

    NOTE: submodules are registered in the order PFC, HPC, Re, linear, which
    fixes the order of ``self.parameters()``.
    """
    def __init__(self, input_size, hidden_size, output_size, batch_size):
        super(MyLSTM, self).__init__()

        self.hidden_size = hidden_size
        self.PFC = nn.LSTMCell(hidden_size*2, hidden_size)
        self.HPC = nn.LSTMCell(input_size+hidden_size, hidden_size)
        self.Re = nn.LSTMCell(hidden_size*2, hidden_size)
        self.linear = nn.Linear(hidden_size, output_size)
        self.batch_size = batch_size
        # NOTE(review): nn.init.sparse_ treats the second argument as the fraction
        # of each column set to zero, so 1 zeroes these weight matrices entirely.
        # The sibling classes pass 0.1*sparse (0.5 by default). Confirm intent.
        nn.init.sparse_(self.PFC.weight_ih.data,1)
        nn.init.sparse_(self.HPC.weight_ih.data,1)
        nn.init.sparse_(self.PFC.weight_hh.data,1)
        nn.init.sparse_(self.HPC.weight_hh.data,1)

    def forward(self, input, hiddens):
        """Advance one step; return ``(output, [PFC (h, c), HPC (h, c), Re (h, c)])``."""
        # Every area reads the *previous* step's states from ``hiddens``.
        Re_input = torch.cat([hiddens[0][0],hiddens[1][0]],dim=1)
        Re_hidden = self.Re(Re_input, hiddens[2])
        HPC_input = torch.cat([input,hiddens[2][0]],dim=1)
        HPC_hidden = self.HPC(HPC_input, hiddens[1])
        PFC_input = torch.cat([hiddens[1][0],hiddens[2][0]],dim=1)
        PFC_hidden = self.PFC(PFC_input, hiddens[0])
        output = self.linear(HPC_hidden[0])
        return output, [PFC_hidden, HPC_hidden, Re_hidden]

    def initHidden(self):
        """Return all-zero ``[PFC [h, c], HPC [h, c], Re [h, c]]``, each (batch_size, hidden_size)."""
        HPC_hidden = [torch.zeros(self.batch_size, self.hidden_size), torch.zeros(self.batch_size, self.hidden_size)]
        PFC_hidden = [torch.zeros(self.batch_size, self.hidden_size), torch.zeros(self.batch_size, self.hidden_size)]
        Re_hidden = [torch.zeros(self.batch_size, self.hidden_size), torch.zeros(self.batch_size, self.hidden_size)]
        return [PFC_hidden,HPC_hidden,Re_hidden]
    
    
    
class MyLSTM_feedforward(nn.Module):
    """Three-area model: PFC and HPC are LSTM cells; Re is a feed-forward ``nn.Linear`` relay.

    ``forward`` advances one step from the previous states
    ``hiddens = [PFC (h, c), HPC (h, c), Re]``:
    Re = Linear([PFC h, HPC h]); HPC reads [input, previous Re]; PFC reads
    [previous HPC h, previous Re]. The output is a linear read-out of the new
    HPC hidden state.

    NOTE: submodules are registered in the order PFC, HPC, Re, linear, which
    fixes the order of ``self.parameters()``. Positional indexing into that
    list depends on it.
    """

    def __init__(self, input_size, hidden_size, output_size, batch_size, sparse=5):
        super(MyLSTM_feedforward, self).__init__()

        self.hidden_size_PFC = hidden_size
        self.hidden_size_HPC = hidden_size
        self.hidden_size_Re = hidden_size+0
        self.PFC = nn.LSTMCell(self.hidden_size_HPC+self.hidden_size_Re, self.hidden_size_PFC)
        self.HPC = nn.LSTMCell(input_size+self.hidden_size_Re, self.hidden_size_HPC)
        self.Re = nn.Linear(self.hidden_size_HPC+self.hidden_size_PFC, self.hidden_size_Re)
        self.linear = nn.Linear(self.hidden_size_HPC, output_size)
        self.batch_size = batch_size
        # ``sparse`` is given in tenths. nn.init.sparse_ zeroes that fraction of
        # each weight column (sparse=5 -> 0.5).
        sparse = 0.1*int(sparse)
        nn.init.sparse_(self.PFC.weight_ih.data,sparse)
        nn.init.sparse_(self.HPC.weight_ih.data,sparse)
        nn.init.sparse_(self.PFC.weight_hh.data,sparse)
        nn.init.sparse_(self.HPC.weight_hh.data,sparse)

    def forward(self, input, hiddens):
        """Advance one step; return ``(output, [PFC (h, c), HPC (h, c), Re])``."""
        Re_input = torch.cat([hiddens[0][0],hiddens[1][0]],dim=1)
        Re_hidden = self.Re(Re_input)
        HPC_input = torch.cat([input,hiddens[2]],dim=1)
        HPC_hidden = self.HPC(HPC_input, hiddens[1])
        PFC_input = torch.cat([hiddens[1][0],hiddens[2]],dim=1)
        PFC_hidden = self.PFC(PFC_input, hiddens[0])
        output = self.linear(HPC_hidden[0])
        return output, [PFC_hidden, HPC_hidden, Re_hidden]

    def initHidden(self):
        """Return all-zero ``[PFC [h, c], HPC [h, c], Re]`` sized per area.

        The previous version read ``self.hidden_size``, which this class never
        sets, and so always raised AttributeError.
        """
        HPC_hidden = [torch.zeros(self.batch_size, self.hidden_size_HPC), torch.zeros(self.batch_size, self.hidden_size_HPC)]
        PFC_hidden = [torch.zeros(self.batch_size, self.hidden_size_PFC), torch.zeros(self.batch_size, self.hidden_size_PFC)]
        Re_hidden = torch.zeros(self.batch_size, self.hidden_size_Re)
        return [PFC_hidden,HPC_hidden,Re_hidden]

    def initHidden_rand(self):
        """Return uniform-random initial states scaled by 0.01."""
        const = 0.01
        var = 0.01
        HPC_hidden = [torch.rand(self.batch_size, self.hidden_size_HPC)*const, torch.rand(self.batch_size, self.hidden_size_HPC)*const]
        PFC_hidden = [torch.rand(self.batch_size, self.hidden_size_PFC)*var, torch.rand(self.batch_size, self.hidden_size_PFC)*var]
        Re_hidden = torch.rand(self.batch_size, self.hidden_size_Re)*const
        return [PFC_hidden,HPC_hidden,Re_hidden]

    def noiseHidden_rand(self, hiddens):
        """Add Gaussian noise (std 0.02) in place to PFC h, HPC h and Re; return ``hiddens``."""
        c = 0.02
        v = 0.02
        HPC_hidden = torch.randn(self.batch_size, self.hidden_size_HPC)*c
        hiddens[1][0].add_(HPC_hidden)
        PFC_hidden = torch.randn(self.batch_size, self.hidden_size_PFC)*v
        hiddens[0][0].add_(PFC_hidden)
        Re_hidden = torch.randn(self.batch_size, self.hidden_size_Re)*c
        hiddens[2].add_(Re_hidden)
        return hiddens

    def hidden_size_list(self):
        """Return ``[PFC, HPC, Re]`` hidden sizes."""
        return [self.hidden_size_PFC,self.hidden_size_HPC,self.hidden_size_Re]
    

class MyLSTM_feedforward_Thalamus2(nn.Module):
    """Four-area model: LSTM PFC and HPC, linear Re relay, and a ReLU-linear THinh population.

    ``forward`` advances one step from the previous states
    ``hiddens = [PFC (h, c), HPC (h, c), Re, THinh]``:
    Re = Linear([PFC h, HPC h, THinh]); THinh = ReLU(Linear(previous Re));
    HPC reads [input, previous Re]; PFC reads [previous HPC h, previous Re].
    The output is a linear read-out of the new HPC hidden state.

    NOTE: submodules are registered in the order PFC, HPC, Re, THinh, linear,
    which fixes the order of ``self.parameters()``.
    """

    def __init__(self, input_size, hidden_size, output_size, batch_size, sparse=5):
        super(MyLSTM_feedforward_Thalamus2, self).__init__()

        self.hidden_size_PFC = hidden_size
        self.hidden_size_HPC = hidden_size
        self.hidden_size_Re = hidden_size+0
        self.hidden_size_THinh = hidden_size
        self.PFC = nn.LSTMCell(self.hidden_size_HPC+self.hidden_size_Re, self.hidden_size_PFC)
        self.HPC = nn.LSTMCell(input_size+self.hidden_size_Re, self.hidden_size_HPC)
        self.Re = nn.Linear(self.hidden_size_PFC+self.hidden_size_HPC+self.hidden_size_THinh, self.hidden_size_Re)
        self.THinh = nn.Linear(self.hidden_size_Re, self.hidden_size_THinh)
        self.linear = nn.Linear(self.hidden_size_HPC, output_size)
        self.batch_size = batch_size
        self.ReLU = nn.ReLU()
        # ``sparse`` is given in tenths. nn.init.sparse_ zeroes that fraction of
        # each weight column (sparse=5 -> 0.5).
        sparse = 0.1*int(sparse)
        nn.init.sparse_(self.PFC.weight_ih.data,sparse)
        nn.init.sparse_(self.HPC.weight_ih.data,sparse)
        nn.init.sparse_(self.PFC.weight_hh.data,sparse)
        nn.init.sparse_(self.HPC.weight_hh.data,sparse)

    def forward(self, input, hiddens):
        """Advance one step; return ``(output, [PFC (h, c), HPC (h, c), Re, THinh])``."""
        Re_input = torch.cat([hiddens[0][0],hiddens[1][0],hiddens[3]],dim=1)
        Re_hidden = self.Re(Re_input)
        THinh_input = hiddens[2]
        THinh_hidden = self.ReLU(self.THinh(THinh_input))
        HPC_input = torch.cat([input,hiddens[2]],dim=1)
        HPC_hidden = self.HPC(HPC_input, hiddens[1])
        PFC_input = torch.cat([hiddens[1][0],hiddens[2]],dim=1)
        PFC_hidden = self.PFC(PFC_input, hiddens[0])
        output = self.linear(HPC_hidden[0])
        return output, [PFC_hidden, HPC_hidden, Re_hidden,THinh_hidden]

    def initHidden(self):
        """Return zero PFC/HPC/Re states and a uniform-random THinh state, sized per area.

        The previous version read ``self.hidden_size``, which this class never
        sets, and so always raised AttributeError.
        """
        HPC_hidden = [torch.zeros(self.batch_size, self.hidden_size_HPC), torch.zeros(self.batch_size, self.hidden_size_HPC)]
        PFC_hidden = [torch.zeros(self.batch_size, self.hidden_size_PFC), torch.zeros(self.batch_size, self.hidden_size_PFC)]
        Re_hidden = torch.zeros(self.batch_size, self.hidden_size_Re)
        THinh_hidden = self.ReLU(torch.rand(self.batch_size, self.hidden_size_THinh))
        return [PFC_hidden,HPC_hidden,Re_hidden,THinh_hidden]

    def initHidden_rand(self):
        """Return uniform-random initial states scaled by 0.01."""
        const = 0.01
        var = 0.01
        HPC_hidden = [torch.rand(self.batch_size, self.hidden_size_HPC)*const, torch.rand(self.batch_size, self.hidden_size_HPC)*const]
        PFC_hidden = [torch.rand(self.batch_size, self.hidden_size_PFC)*var, torch.rand(self.batch_size, self.hidden_size_PFC)*var]
        Re_hidden = torch.rand(self.batch_size, self.hidden_size_Re)*const
        THinh_hidden = self.ReLU(torch.rand(self.batch_size, self.hidden_size_THinh)*const)
        return [PFC_hidden,HPC_hidden,Re_hidden,THinh_hidden]

    def noiseHidden_rand(self, hiddens):
        """Add Gaussian noise (std 0.02) in place to PFC h, HPC h and Re.

        THinh is left untouched. Returns ``hiddens``.
        """
        c = 0.02
        v = 0.02
        HPC_hidden = torch.randn(self.batch_size, self.hidden_size_HPC)*c
        hiddens[1][0].add_(HPC_hidden)
        PFC_hidden = torch.randn(self.batch_size, self.hidden_size_PFC)*v
        hiddens[0][0].add_(PFC_hidden)
        Re_hidden = torch.randn(self.batch_size, self.hidden_size_Re)*c
        hiddens[2].add_(Re_hidden)
        return hiddens

    def hidden_size_list(self):
        """Return ``[PFC, HPC, Re, THinh]`` hidden sizes."""
        return [self.hidden_size_PFC,self.hidden_size_HPC,self.hidden_size_Re,self.hidden_size_THinh]

    
    

class MyLSTM_RNN(nn.Module):
    """Three-area model: PFC and HPC are LSTM cells; Re is a vanilla ``nn.RNNCell``.

    ``forward`` advances one step from the previous states
    ``hiddens = [PFC (h, c), HPC (h, c), Re]``:
    Re reads [PFC h, HPC h]; HPC reads [input, previous Re]; PFC reads
    [previous HPC h, previous Re]. The output is a linear read-out of the new
    HPC hidden state.

    NOTE: submodules are registered in the order PFC, HPC, Re, linear, which
    fixes the order of ``self.parameters()``. Positional indexing into that
    list depends on it.
    """

    def __init__(self, input_size, hidden_size, output_size, batch_size, sparse=5):
        super(MyLSTM_RNN, self).__init__()

        self.hidden_size_PFC = hidden_size
        self.hidden_size_HPC = hidden_size
        self.hidden_size_Re = hidden_size+0
        self.PFC = nn.LSTMCell(self.hidden_size_HPC+self.hidden_size_Re, self.hidden_size_PFC)
        self.HPC = nn.LSTMCell(input_size+self.hidden_size_Re, self.hidden_size_HPC)
        self.Re = nn.RNNCell(self.hidden_size_HPC+self.hidden_size_PFC, self.hidden_size_Re)
        self.linear = nn.Linear(self.hidden_size_HPC, output_size)
        self.batch_size = batch_size
        # ``sparse`` is given in tenths. nn.init.sparse_ zeroes that fraction of
        # each weight column (sparse=5 -> 0.5).
        sparse = 0.1*int(sparse)
        nn.init.sparse_(self.PFC.weight_ih.data,sparse)
        nn.init.sparse_(self.HPC.weight_ih.data,sparse)
        nn.init.sparse_(self.PFC.weight_hh.data,sparse)
        nn.init.sparse_(self.HPC.weight_hh.data,sparse)

    def forward(self, input, hiddens):
        """Advance one step; return ``(output, [PFC (h, c), HPC (h, c), Re])``."""
        input = input.float()
        Re_input = torch.cat([hiddens[0][0],hiddens[1][0]],dim=1)
        Re_hidden = self.Re(Re_input, hiddens[2])
        HPC_input = torch.cat([input,hiddens[2]],dim=1)
        HPC_hidden = self.HPC(HPC_input, hiddens[1])
        PFC_input = torch.cat([hiddens[1][0],hiddens[2]],dim=1)
        PFC_hidden = self.PFC(PFC_input, hiddens[0])
        output = self.linear(HPC_hidden[0])
        return output, [PFC_hidden, HPC_hidden, Re_hidden]

    def initHidden(self):
        """Return all-zero ``[PFC [h, c], HPC [h, c], Re]`` sized per area.

        The previous version read ``self.hidden_size``, which this class never
        sets, and so always raised AttributeError.
        """
        HPC_hidden = [torch.zeros(self.batch_size, self.hidden_size_HPC), torch.zeros(self.batch_size, self.hidden_size_HPC)]
        PFC_hidden = [torch.zeros(self.batch_size, self.hidden_size_PFC), torch.zeros(self.batch_size, self.hidden_size_PFC)]
        Re_hidden = torch.zeros(self.batch_size, self.hidden_size_Re)
        return [PFC_hidden,HPC_hidden,Re_hidden]

    def initHidden_rand(self):
        """Return uniform-random initial states scaled by 0.1."""
        const = 0.1
        var = 0.1
        HPC_hidden = [torch.rand(self.batch_size, self.hidden_size_HPC)*const, torch.rand(self.batch_size, self.hidden_size_HPC)*const]
        PFC_hidden = [torch.rand(self.batch_size, self.hidden_size_PFC)*var, torch.rand(self.batch_size, self.hidden_size_PFC)*var]
        Re_hidden = torch.rand(self.batch_size, self.hidden_size_Re)*const
        return [PFC_hidden,HPC_hidden,Re_hidden]

    def initHidden_test(self):
        """Return uniform-random initial states scaled by 0.2."""
        const = 0.2
        var = 0.2
        HPC_hidden = [torch.rand(self.batch_size, self.hidden_size_HPC)*const, torch.rand(self.batch_size, self.hidden_size_HPC)*const]
        PFC_hidden = [torch.rand(self.batch_size, self.hidden_size_PFC)*var, torch.rand(self.batch_size, self.hidden_size_PFC)*var]
        Re_hidden = torch.rand(self.batch_size, self.hidden_size_Re)*var
        return [PFC_hidden,HPC_hidden,Re_hidden]

    def noiseHidden_rand(self, hiddens):
        """Add Gaussian noise (std 0.1) in place to Re only, and return ``hiddens``.

        The HPC and PFC terms are scaled by c = 0, so they add zeros, but they
        still draw random numbers.
        """
        c = 0
        v = 0.1
        HPC_hidden = torch.randn(self.batch_size, self.hidden_size_HPC)*c
        hiddens[1][0].add_(HPC_hidden)
        PFC_hidden = torch.randn(self.batch_size, self.hidden_size_PFC)*c
        hiddens[0][0].add_(PFC_hidden)
        Re_hidden = torch.randn(self.batch_size, self.hidden_size_Re)*v
        hiddens[2].add_(Re_hidden)
        return hiddens

    def noiseHidden_select(self, hiddens):
        """Add Gaussian noise (std 0.01) in place to Re units 1, 2, 8, 9 and 17 only.

        HPC/PFC noise is scaled by c = 0. Requires hidden sizes > 17. Returns
        ``hiddens``.
        """
        c = 0
        v = 0.01
        index = np.array([1,2,8,9,17])
        HPC_hidden = torch.zeros(self.batch_size, self.hidden_size_HPC)
        HPC_hidden[:,index] += torch.randn(self.batch_size, index.size)*c
        hiddens[1][0].add_(HPC_hidden)
        PFC_hidden = torch.zeros(self.batch_size, self.hidden_size_PFC)
        PFC_hidden[:,index] += torch.randn(self.batch_size, index.size)*c
        hiddens[0][0].add_(PFC_hidden)
        Re_hidden = torch.zeros(self.batch_size, self.hidden_size_Re)*c
        Re_hidden[:,index] += torch.randn(self.batch_size, index.size)*v
        hiddens[2].add_(Re_hidden)
        return hiddens

    def noiseHidden_dis(self, hiddens, statr, statl):
        """Shift Re in place by v * (statl - statr) with v = -0.1, i.e. by 0.1 * (statr - statl).

        ``statr`` and ``statl`` must broadcast to (batch_size, hidden_size_Re)
        (assumption -- confirm against the caller). HPC/PFC noise is scaled by
        c = 0. Returns ``hiddens``.
        """
        c = 0
        v =-0.1
        index = np.array([1,2,8,9,17])
        HPC_hidden = torch.zeros(self.batch_size, self.hidden_size_HPC)
        HPC_hidden[:,index] += torch.randn(self.batch_size, index.size)*c
        hiddens[1][0].add_(HPC_hidden)
        PFC_hidden = torch.zeros(self.batch_size, self.hidden_size_PFC)
        PFC_hidden[:,index] += torch.randn(self.batch_size, index.size)*c
        hiddens[0][0].add_(PFC_hidden)
        # randn(...)*0.0 is kept (not zeros) so the RNG stream is consumed as before.
        Re_hidden = torch.randn(self.batch_size, self.hidden_size_Re)*0.0
        Re_hidden += (-statr+statl)*v
        hiddens[2].add_(Re_hidden)
        return hiddens
    
    
    
def correct_traj(traj):
    """Return the indices of turning points of excursions away from the y = 0.5 midline.

    Only rows with |x - 0.5| > 0.1 are considered. A dip below y = 0.4 (or a
    rise above y = 0.6) arms the detector and tracks the index of the running
    extreme y. The first later row whose y falls back into (0.4, 0.5)
    (respectively (0.5, 0.6)) closes the excursion: the tracked index is
    recorded and the extreme resets to 0.5.
    """
    armed = False
    extreme = 0.5
    best = None
    turning_points = []
    for i in range(traj.shape[0]):
        x, y = traj[i, 0], traj[i, 1]
        if not np.abs(x - 0.5) > 0.1:
            continue
        # The four y ranges below are disjoint, so an elif chain is equivalent
        # to checking each one independently.
        if y < 0.4:
            armed = True
            if extreme > y:
                extreme, best = y, i
        elif 0.4 < y < 0.5:
            if armed:
                armed = False
                extreme = 0.5
                turning_points.append(best)
        elif y > 0.6:
            armed = True
            if extreme < y:
                extreme, best = y, i
        elif 0.5 < y < 0.6:
            if armed:
                armed = False
                extreme = 0.5
                turning_points.append(best)
    return turning_points


def moving_average(data):
    """Return the 9-point box-filtered sequence in 'full' convolution mode (length len(data) + 8)."""
    kernel = np.full(9, 1 / 9)
    return np.convolve(data, kernel, mode="full")

def moving_average_test(data):
    """Return the 9-point box-filtered sequence in 'same' mode (length max(len(data), 9)).

    The edge values are averaged over fewer real samples, so they are biased
    toward zero.
    """
    kernel = np.full(9, 1 / 9)
    return np.convolve(data, kernel, mode="same")

def moving_average_grad(data):
    """Smooth ``data`` with a 5-point box filter ('same' mode) and append one extrapolated value.

    The appended value averages two estimates:

    - a linear extrapolation from the last smoothed points, clipped to [-1, 1];
    - the last smoothed value rescaled by 5/3, since in 'same' mode it averages
      only 3 real samples.

    Needs at least 4 smoothed values. Returns an array one longer than the
    smoothed sequence.

    NOTE(review): the 5/4 edge correction is applied to ``mean_seq[-3]``, while
    by the same reasoning it belongs to ``mean_seq[-2]`` (4 real samples).
    Confirm intent.
    """
    smoothed = np.convolve(data, np.ones(5) / 5, mode="same")
    before3, before2, before1, last = smoothed[-4], smoothed[-3], smoothed[-2], smoothed[-1]
    trend = np.clip(before2 + (before2 - before3) + (before1 - before2 * 5 / 4), -1, 1)
    next_value = (trend + last * 5 / 3) / 2
    return np.append(smoothed, next_value)


def _make_pattern_batch(pattern, train_x1, train_x2, train_y1, train_y2, batch_size):
    """Draw a fresh float batch with ``mkOwnRandomBatch`` from the set chosen by *pattern*.

    ``pattern`` 1, 2 and 3 select ``train_x1``, ``train_x2`` and ``train_y1``;
    any other value selects ``train_y2``.  Every call draws a new batch.
    """
    if pattern == 1:
        source = train_x1
    elif pattern == 2:
        source = train_x2
    elif pattern == 3:
        source = train_y1
    else:
        source = train_y2
    return mkOwnRandomBatch(source, source, batch_size).float()


def _record_state(hidden, pfc_states, hpc_states, re_states):
    """Append ``hidden[0][0]``, ``hidden[1][0]`` and ``hidden[2][0]`` (as lists) to the three histories."""
    pfc_states.append(hidden[0][0].tolist())
    hpc_states.append(hidden[1][0].tolist())
    re_states.append(hidden[2][0].tolist())


def _apply_re_correction(hidden, re_history, n_cells=20, gain=0.25):
    """Overwrite the first *n_cells* entries of ``hidden[2][0]`` in place.

    Each entry becomes its last recorded value minus ``gain`` times the
    ``moving_average_grad`` prediction computed from that entry's history.
    """
    history = np.array(re_history)  # built once per step instead of once per cell
    for cell in range(n_cells):
        hidden[2][0][cell] = history[-1, cell] - gain * moving_average_grad(history[:, cell])[-1]


def main(model):
    """Compare a clean rollout of a trained model with an Re-corrected rollout.

    The state dict at *model* is loaded into a ``MyLSTM_RNN`` and two rollouts
    are run.  Each is driven by a freshly drawn data batch for a fixed number
    of steps and then fed back its own output:

    1. a clean run, whose PFC/HPC/Re states and ``Culc_gate`` activations are
       recorded;
    2. a corrected run, in which the first 20 Re entries are overwritten by
       ``_apply_re_correction`` before every free-running step.

    Along the way it prints PCA explained-variance ratios of the clean PFC
    states, draws trajectory and per-neuron Re figures, passes the Re PCA
    features to ``pick_delay`` and prints a beta fit of PFC gate activations.

    Parameters
    ----------
    model : str
        Path to a state dict loadable with ``torch.load``.

    Returns
    -------
    tuple
        ``(0, 0, 0, 0, var_ex)``.  The first four slots are placeholders kept
        for the caller's 5-way unpacking.  ``var_ex`` is the Euclidean distance,
        for batch element 0, between the last extreme points ``correct_traj``
        finds in the clean and the corrected trajectories.  It is NaN if either
        trajectory has none.
    """
    import pandas as pd  # function-scope: pandas is not among the file's visible imports
    import scipy.stats  # likewise for scipy

    training_size = 100
    hidden_size = 20
    batch_size = 10
    inputsize = 2
    outputsize = 2
    delay_length = 2
    pattern = 1  # which training set drives the model (see _make_pattern_batch)
    est_length = 1  # free-running length, in multiples of data.shape[0]

    train_x1, train_x2, train_y1, train_y2 = mkOwnDataSet_auto(training_size, delay_length)
    train_sets = (train_x1, train_x2, train_y1, train_y2)

    # Other architectures tried with the same constructor signature: MyLSTM_comp,
    # MyLSTM_RNN_uniHPC, MyMTRNN2, MyLSTM_feedforward, MyLSTM_feedforward_Thalamus2.
    rnn = MyLSTM_RNN(inputsize, hidden_size, outputsize, batch_size)
    rnn.load_state_dict(torch.load(model))

    data = _make_pattern_batch(pattern, *train_sets, batch_size)
    # Live references to the loaded weights, consumed by Culc_gate.
    params = [p.data for p in rnn.parameters()]

    # ---- Clean rollout ----
    traj, pfc_states, hpc_states, re_states, gate_states = [], [], [], [], []
    hidden = rnn.initHidden_rand()
    clean_steps = 160
    # Result unused; the draw is kept so later torch RNG draws match earlier runs.
    torch.rand(10, 2)
    for k in range(clean_steps):
        output, hidden = rnn(data[k], hidden)
        traj.append(output.tolist())
        _record_state(hidden, pfc_states, hpc_states, re_states)
        gate_states.append(Culc_gate(output, params, hidden))
    for _ in range(data.shape[0] * est_length):
        output, hidden = rnn(output, hidden)
        traj.append(output.tolist())
        _record_state(hidden, pfc_states, hpc_states, re_states)
        gate_states.append(Culc_gate(output, params, hidden))
    traj = torch.squeeze(torch.tensor(traj)).numpy()

    # ---- PCA of the clean PFC states (explained-variance report only) ----
    pca_size = 4
    pca = PCA(n_components=pca_size)
    pca.fit(np.array(pfc_states)[:, 0])
    print(pd.DataFrame(pca.explained_variance_ratio_, index=["PC{}".format(x + 1) for x in range(pca_size)]))

    # ---- Corrected rollout on a freshly drawn batch ----
    data = _make_pattern_batch(pattern, *train_sets, batch_size)
    traj_noise, pfc_states_noise, hpc_states_noise, re_states_noise = [], [], [], []
    re_states_fed = []  # Re state as fed into each step, i.e. after any correction
    hidden = rnn.initHidden_rand()
    noise_steps = 50
    rest = 42  # extra free-running steps beyond data.shape[0] * est_length
    # NOTE: larger than noise_steps, so no correction happens while data drives the model.
    correction_start = 60
    for k in range(noise_steps):
        if k > correction_start:
            _apply_re_correction(hidden, re_states_noise)
        re_states_fed.append(hidden[2][0].tolist())
        output, hidden = rnn(data[k], hidden)
        traj_noise.append(output.tolist())
        _record_state(hidden, pfc_states_noise, hpc_states_noise, re_states_noise)
    for _ in range(data.shape[0] * est_length + rest):
        _apply_re_correction(hidden, re_states_noise)
        re_states_fed.append(hidden[2][0].tolist())
        output, hidden = rnn(output, hidden)
        traj_noise.append(output.tolist())
        _record_state(hidden, pfc_states_noise, hpc_states_noise, re_states_noise)
    traj_noise = torch.squeeze(torch.tensor(traj_noise)).numpy()

    # ---- Distance between the last extreme points of the two runs ----
    ex_points = correct_traj(traj[:, 0])
    ex_points_noise = correct_traj(traj_noise[:, 0])
    if ex_points and ex_points_noise:
        var_ex = np.linalg.norm(traj[ex_points[-1]][0] - traj_noise[ex_points_noise[-1]][0])
    else:
        # Indexing [-1] on an empty list used to raise IndexError and abort the sweep.
        print("Warning: no extreme point found in one of the trajectories; distance is NaN")
        var_ex = float("nan")

    print(np.array(pfc_states)[:, 0].shape)

    # ---- Figures ----
    plt.figure()
    plt.plot(traj[:, 0, 0], traj[:, 0, 1])
    plt.plot(traj_noise[:, 0, 0], traj_noise[:, 0, 1])
    plt.yticks(fontsize=28)
    plt.xticks(fontsize=28)

    plt.figure()
    plt.plot(traj[:, 0, 0], "--", color="C2", alpha=0.6)
    plt.plot(traj[:, 0, 1], "--", color="C3", alpha=0.6)
    plt.plot(traj_noise[:, 0, 0], color="C2", alpha=0.9)
    plt.plot(traj_noise[:, 0, 1], color="C3", alpha=0.9)
    plt.yticks(fontsize=28)
    plt.xticks(fontsize=28)

    re_clean = np.array(re_states)
    re_fed = np.array(re_states_fed)
    for i in range(20):
        plt.figure()
        plt.plot(traj[:, 0, 0])
        plt.plot(traj_noise[:, 0, 0])
        plt.plot(re_clean[:, i], alpha=0.5)
        plt.plot(re_fed[:, i], alpha=0.5)

    # ---- PCA of the clean Re states, handed to pick_delay ----
    pca = PCA()
    pca.fit(re_clean)
    pick_delay(traj[:, 0], pca.transform(re_clean))

    # ---- Beta fit of the PFC gate activations (best-effort diagnostic) ----
    pfc_gates = np.array(gate_states)[:, 0, 0:20].ravel()
    try:
        beta_PFC = scipy.stats.beta.fit(pfc_gates, floc=0)
    except Exception:  # a failed fit must not abort the model sweep
        print("Error: maybe takes negative a or b")
    else:
        print(beta_PFC)

    return 0, 0, 0, 0, var_ex


if __name__ == '__main__':
    import glob  # used below; not among the imports visible at the top of the file

    plt.rcParams["font.size"] = 16
    fig = plt.figure()
    correlation_fig = fig.subplots()

    allratio_list = []  # one array of per-model distances per (num, i) configuration
    good_points = [[], []]  # [distances, epochs] for models listed in good_list.txt
    bad_points = [[], []]  # [distances, epochs] for every other model
    for num in range(1):
        for i in range(1):
            path = 'model/R20_H_bigbatch/'
            # Alternative model directories: 'model/R20_H_uniHPC_bigbatch/',
            # 'model/R20_H_stopinit_bigbatch/', 'model/R20FF_H_bigbatch/',
            # 'model/R20_feedReinhReLU_H_bigbatch/'
            glob_pattern = path + '*s' + str(i + 6) + '_100_' + str(num + 4) + '_*epoch200.pth'
            # Lexicographic order first; the stable sort by length keeps it among equal lengths.
            model_list = sorted(sorted(glob.glob(glob_pattern)), key=len)
            ratio_list_max = []
            with open(path + "good_list.txt", mode="r", encoding="utf-8") as f:
                good_models = set(f.read().splitlines())

            for model in model_list:
                print(model)
                PFC, HPC, PFC_max, HPC_max, frac = main(model)
                ratio_list_max.append(frac)
                print(PFC_max, HPC_max)
                print("dist:" + str(frac))
                epoch = int(model.split("epoch")[-1].split(".")[0])
                bucket = good_points if model in good_models else bad_points
                bucket[0].append(frac)
                bucket[1].append(epoch)
            allratio_list.append(np.array(ratio_list_max))

    x_axis = np.arange(0, len(ratio_list_max) * 5, 5)
    ratios = np.array(allratio_list)
    mean_ratio = np.average(ratios, axis=0)
    correlation_fig.plot(x_axis, mean_ratio, color="b")
    correlation_fig.errorbar(x_axis, mean_ratio, yerr=np.std(ratios, axis=0), color="b", alpha=0.3)

        