In [3]:
'''
Training Data is a N*T*n tensor, N is the number of samples, T is the interval,
n is number of neurons.
Training Label is a N*1*1 tensor, N is the number of samples, 1*1 represents the
output dimension, which is the position at the last time point.
Testing Data is a N'*T*n tensor.
Testing Label is a N'*1*1 tensor.
E.g.,
TrainingData = create_subsequences(np.transpose(X), TimeInterval)
TrainingLabel = Y[TimeInterval-1:].reshape(-1,1)
X is dFF, Y is corresponding position.
'''

"\nTraining Data is a N*T*n tensor, N is the number of samples, T is the interval,\nn is number of neurons.\nTraining Data is a N*1*1 tensor, N is the number of samples, 1*1 represents the\noutput dimension, which is the position of last time point.\nTesting Data is a N'*T*n tensor.\nTesting label is a N*1*1 tensor.\nE.g.,\nTrainingData = create_subsequences(np.transpose(X, TimeInterval))\nTrainingLabel = Y[TimeInterval-1:].reshape(-1,1)\nX is dFF, Y is corresponding position.\n"

In [1]:
import torch, os, glob
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
import pickle
import numpy as np
import matplotlib.pyplot as plt
import sys
np.set_printoptions(threshold=sys.maxsize)
import scipy, pandas as pd, random
from sklearn.preprocessing import MinMaxScaler

class LSTMModel(nn.Module):
    """Stacked LSTM followed by a linear read-out of the last time step.

    Parameters
    ----------
    input_dim : number of features per time step.
    hidden_dim : number of hidden units in each LSTM layer (network width).
    output_dim : size of the prediction for each sequence.
    num_layers : number of stacked LSTM layers (network depth).

    The LSTM is built with ``batch_first=True``, so ``forward`` takes input
    of shape (batch, time, input_dim) and returns (batch, output_dim).
    """
    def __init__(self, input_dim, hidden_dim = 256, 
        output_dim = 1, 
        num_layers = 2):
        super().__init__()
        # Kept on the instance because forward() needs them to size the initial state
        self.hidden_dim = hidden_dim
        self.num_layers = num_layers
        # Attribute names `lstm` / `fc` define the state_dict keys of saved models
        self.lstm = nn.LSTM(input_dim, hidden_dim, num_layers, 
                dropout = 0.0, batch_first=True)
        self.fc = nn.Linear(hidden_dim, output_dim)

    def forward(self, x):
        """Run the LSTM from a zero initial state and decode the last time step."""
        # Allocate the initial hidden/cell state directly on the input's device
        # and in the input's dtype, so the model also works after .double()/.half()
        # and no CPU tensor is created and copied on every call.
        state_shape = (self.num_layers, x.size(0), self.hidden_dim)
        h0 = torch.zeros(state_shape, device=x.device, dtype=x.dtype)
        c0 = torch.zeros(state_shape, device=x.device, dtype=x.dtype)

        # Forward propagate the LSTM; the final (h_n, c_n) state is not needed
        out, _ = self.lstm(x, (h0, c0))

        # Only the output at the last time step feeds the linear read-out
        out = self.fc(out[:, -1, :])

        return out

def create_subsequences(time_series, subsequence_length=20):
    """Return every overlapping window of `subsequence_length` consecutive items.

    Windows start at each index 0, 1, 2, ... and are stacked along a new
    leading axis. When `time_series` is shorter than `subsequence_length`
    no window exists and an empty array is returned.
    """
    window_count = len(time_series) - subsequence_length + 1
    windows = []
    for start in range(window_count):
        stop = start + subsequence_length
        windows.append(time_series[start:stop])
    return np.array(windows)

class CreateTimeSeriesData(Dataset):
    """Map-style dataset that pairs the i-th input with the i-th label."""

    def __init__(self, x, y):
        # Inputs and labels are stored exactly as given (no copy, no conversion)
        self.x, self.y = x, y

    def __len__(self):
        # The dataset size follows the inputs only
        n_samples = len(self.x)
        return n_samples

    def __getitem__(self, i):
        sample = self.x[i]
        label = self.y[i]
        return sample, label

In [2]:
# import raw data
# Session table used by the cells below: the training loop iterates over its
# rows and reads the `animals` and `days` columns to build each Fall.mat path,
# and `extract_data` reads `conddf.optoep` as a global. This cell therefore has
# to run before either of them.
conddf = pd.read_csv(r"Z:\condition_df\conddf_neural_modeling.csv", index_col=None)

In [7]:
# extract data
def extract_data(params_pth, TimeInterval, batch_size = 256, maxtrials = 8, maxtrialsheldout = 12):
    """
    Build training and held-out DataLoaders from one session's Fall.mat file.

    Hidden inputs: this function reads the notebook globals `conddf` and `dd`
    (row index of the current session) to look up the session's `optoep`
    value, so it must be called from inside the per-session loop. When that
    value is < 2 the epoch is picked with `random.randint`; seed `random`
    beforehand if the choice has to be reproducible.

    Parameters
    ----------
    params_pth : path to the .mat file; the variables dFF, ybinned, iscell,
        trialnum, bordercells and changeRewLoc are read from it.
    TimeInterval : number of consecutive frames in each input window.
    batch_size : batch size used for both loaders.
    maxtrials : training uses frames with 2 < trialnum <= maxtrials.
    maxtrialsheldout : held-out frames (for inference) are those with
        maxtrials < trialnum <= maxtrialsheldout.

    Returns
    -------
    input_size : size of the last axis of the windowed training data
        (the number of cells kept after filtering).
    output_size : always 1 (position at the last frame of the window).
    Train_loader : shuffled DataLoader over the training windows (drop_last=True).
    Test_loader : unshuffled DataLoader over the held-out windows (drop_last=True).
    comp : [eptest-2, eptest-1]; all data are taken from epoch comp[1].

    Raises
    ------
    ValueError
        If no cell survives the iscell / bordercells filter. (Previously this
        case fell through to the return statement with unbound locals.)
    IndexError
        If epoch comp[1] does not exist in this session.
    """
    fall = scipy.io.loadmat(params_pth, variable_names=['dFF', 'ybinned', 'iscell',
                                'trialnum', 'bordercells', 'changeRewLoc'])
    changeRewLoc = np.hstack(fall['changeRewLoc'])
    trialnum = np.hstack(fall['trialnum'])

    # Epoch boundaries: an epoch starts at each frame where changeRewLoc > 0;
    # the total frame count is appended to close the last epoch.
    eps = np.where(changeRewLoc>0)[0]
    eps = np.append(eps, len(changeRewLoc))

    # Epoch to use: taken from the condition table, or picked at random when
    # the table value is < 2.
    eptest = conddf.optoep.values[dd]
    if eptest<2:
        eptest = random.randint(2,3)
        if len(eps)<4: eptest = 2 # if no 3 epochs
    comp = [eptest-2,eptest-1] # eps to compare

    # Keep columns flagged as cells that are not border cells (indexing with a
    # boolean mask returns a copy, so zeroing below does not modify `fall`).
    keep_cell = (fall['iscell'][:,0].astype(bool)) & (~fall['bordercells'][0].astype(bool))
    dff = fall['dFF'][:, keep_cell]
    # Zero out every cell whose trace contains at least one NaN
    dff[:, np.isnan(dff).any(axis=0)] = 0
    print(dff.shape)

    # TRAIN ON OPTO EP INSTEAD: only epoch comp[1] is used, so only that epoch
    # is sliced rather than materialising per-epoch copies for every epoch.
    ep = comp[1]
    ep_frames = slice(eps[ep], eps[ep+1])
    ep_trialnum = trialnum[ep_frames]
    ep_dff = dff[ep_frames, :]
    ep_position = fall['ybinned'][0][ep_frames]

    # The first two trials of the epoch are skipped; training uses trials up to
    # `maxtrials`, the held-out set the following trials up to `maxtrialsheldout`.
    train_mask = (ep_trialnum>2) & (ep_trialnum<=maxtrials)
    test_mask = (ep_trialnum>maxtrials) & (ep_trialnum<=maxtrialsheldout)

    train = ep_dff[train_mask]
    print(train.shape)
    if train.shape[1] == 0:
        raise ValueError(f'no cells left after iscell/bordercells filtering in {params_pth}')

    test = ep_dff[test_mask]
    train_pos = ep_position[train_mask]
    test_pos = ep_position[test_mask]
    # Each window is labelled with the position at its last frame, hence the
    # first TimeInterval-1 positions have no window ending on them.
    TrainingLabel = train_pos[TimeInterval-1:].reshape(-1,1)
    TestLabel = test_pos[TimeInterval-1:].reshape(-1,1)

    TrainingData = create_subsequences(train,TimeInterval)
    input_size = TrainingData.shape[-1] # number of cells
    output_size = 1
    Train_dataset = CreateTimeSeriesData(TrainingData, TrainingLabel)
    Train_loader = DataLoader(dataset=Train_dataset, batch_size=batch_size, 
                shuffle=True, drop_last = True)
    TestData = create_subsequences(test,TimeInterval)
    Test_dataset = CreateTimeSeriesData(TestData, TestLabel)
    Test_loader = DataLoader(dataset=Test_dataset, batch_size=batch_size,
                shuffle=False, drop_last = True)

    return input_size, output_size, Train_loader, Test_loader, comp

In [8]:
# make models for all animals/days
num_epochs = 3000
savepth = r'Z:\lstm\models_lstm_all_cells_sameep'
# Loop-invariant: create the output folder once (no error if it already exists)
os.makedirs(savepth, exist_ok=True)

for dd in range(len(conddf)): 
    animal = conddf.animals.values[dd]
    day = conddf.days.values[dd]
    # Skip sessions for which a model file has already been saved
    testpth = glob.glob(os.path.join(savepth, f'model_dd{dd:03d}*'), recursive=True)
    if len(testpth)==0:             
        params_pth = rf"Y:\analysis\fmats\{animal}\days\{animal}_day{day:03d}_plane0_Fall.mat"
        print(params_pth)
        TimeInterval = 20 # frames
        # NOTE: extract_data reads the loop variable `dd` and `conddf` as globals
        input_size, output_size, Train_loader, Test_loader, comp = extract_data(params_pth, TimeInterval)            
        # define the model & train
        # TODO: add transfer learning
        model = LSTMModel(input_size, output_dim = output_size)
        device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
        print(device)
        model = model.to(device)
        criterion = nn.MSELoss()  # For regression tasks
        optimizer = torch.optim.Adam(model.parameters(),
                lr=1e-5, weight_decay = 1e-9)
        l = []      # mean training loss, one entry per epoch
        val_l = []  # mean validation loss, one entry per validation pass (every 20 epochs)
        try:
            for epoch in range(num_epochs):
                model.train()
                train_loss = 0.0
                for inputs, targets in Train_loader:
                    # Forward pass
                    inputs, targets = inputs.to(device).float(), targets.to(device).float()
                    outputs = model(inputs)
                    loss = criterion(outputs, targets)
                    # Backward pass and optimization
                    optimizer.zero_grad()
                    loss.backward()
                    optimizer.step()
                    train_loss += loss.item()
                l.append(train_loss/len(Train_loader))
                if epoch % 20 == 0:
                    print('Epoch [{}/{}], Loss: {:.4f}'.format(epoch+1, num_epochs, train_loss/len(Train_loader)))
                    # Validation: no gradients are needed, so skip graph construction
                    model.eval()
                    val_loss = 0.0
                    with torch.no_grad():
                        for inputs, targets in Test_loader:
                            inputs, targets = inputs.to(device).float(), targets.to(device).float()
                            outputs = model(inputs)
                            val_loss += criterion(outputs, targets).item()
                    # Record the mean once per validation pass (not once per batch)
                    val_l.append(val_loss/len(Test_loader))
                    print('Validation Loss: {:.4f}'.format(val_loss/len(Test_loader)))
            # save        
            torch.save(model.state_dict(), os.path.join(savepth, f'model_dd{dd:03d}_epcompare{comp[0]}-{comp[1]}_{animal}_day{day}.pt'))
        except Exception as err:
            # Best effort per session (e.g. too few frames gives an empty loader and a
            # ZeroDivisionError). `Exception` leaves KeyboardInterrupt uncaught so a
            # long run can still be stopped, and the reason is reported.
            print(f'did not run for {dd}: {err!r}')
                

Y:\analysis\fmats\e218\days\e218_day048_plane0_Fall.mat
(50000, 855)
(2621, 855)
cuda:0
Epoch [1/3000], Loss: 8039.0740
Validation Loss: 7507.9399
Epoch [21/3000], Loss: 7278.5495
Validation Loss: 6801.7275
Epoch [41/3000], Loss: 6872.2941
Validation Loss: 6436.0574
Epoch [61/3000], Loss: 6728.8284
Validation Loss: 6292.3928
Epoch [81/3000], Loss: 6626.9902
Validation Loss: 6195.0124
Epoch [101/3000], Loss: 6549.9349
Validation Loss: 6113.3231
Epoch [121/3000], Loss: 6464.9042
Validation Loss: 6043.9010
Epoch [141/3000], Loss: 6421.8572
Validation Loss: 5979.6162
Epoch [161/3000], Loss: 6358.9578
Validation Loss: 5918.4672
Epoch [181/3000], Loss: 6281.4061
Validation Loss: 5859.5742
Epoch [201/3000], Loss: 6248.1535
Validation Loss: 5802.3128
Epoch [221/3000], Loss: 6195.3025
Validation Loss: 5746.4724
Epoch [241/3000], Loss: 6118.6146
Validation Loss: 5691.7929
Epoch [261/3000], Loss: 6042.3455
Validation Loss: 5638.1377
Epoch [281/3000], Loss: 6044.0350
Validation Loss: 5585.4731
Epo