In [18]:
import torch
from torch.utils.data import DataLoader
import sys
from torch.utils.data import Dataset
import numpy as np
sys.path.append('/home3/ebrahim/isr/isr_model_review/BP06/')
from utils import cosine_sim
from datasets import OneHotLetters, OneHotLetters_test
from run_test_trials import run_test_trials
from simulation_one import simulation_one
from RNNcell import RNN_one_layer
import wandb
device = torch.device("cpu")
import numpy as np
from matplotlib import pyplot as plt
import seaborn as sns
import statsmodels
import pandas as pd
from statsmodels.stats.anova import AnovaRM
import scipy.stats as stats
from scipy.stats import ttest_rel
import pickle
%matplotlib inline
from basic_model import load_model_basic

In [14]:
class creating_y_hat(load_model_basic):
    """Thin extension of ``load_model_basic`` that summarises test transpositions.

    The only behaviour added here is ``y_hat_transpositions``, which prints a
    normalised histogram for every row of ``self.sim_one.transpositions_test``.
    """

    def __init__(self, base):
        """Forward ``base`` unchanged to ``load_model_basic.__init__``."""
        super().__init__(base)

    def y_hat_transpositions(self):
        """Print one density-normalised histogram per row of the transposition data.

        Reads ``self.sim_one.transpositions_test`` and, with ``n`` equal to its
        first dimension, bins each row into ``n`` unit-width bins whose edges
        are ``1, 2, ..., n + 1``. Because the bins have width one,
        ``density=True`` yields the fraction of in-range values per bin.
        Each histogram is rounded to two decimals and printed; nothing is
        returned.

        NOTE(review): presumably row ``i`` holds the output positions produced
        for input position ``i + 1`` -- confirm against ``simulation_one``.
        """
        responses = self.sim_one.transpositions_test
        num_rows = responses.shape[0]

        # The edges do not depend on the row, so build them once.
        bin_edges = np.arange(1, num_rows + 2, 1)

        for row_idx in range(num_rows):
            row_density = np.histogram(responses[row_idx, :], bins=bin_edges, density=True)[0]
            print(np.round(row_density, 2))

In [15]:
# Settings used to locate and restore a trained model.

# Root directory of the BP06 project; passed to the model-loading wrapper.
base = '/home3/ebrahim/isr/isr_model_review/BP06/'

# Checkpoint filename prefixes (a suffix is appended when a model is loaded).
modelPATH_arr = [
    'stateful_True_noise_0.0_opt_Adam_no_grad_num_letters_26_nl_test_26_True_',
]

# Run identifiers, presumably Weights & Biases run ids -- one per trained model.
wandB_arr = [
    "ehrdds6f",
    "ed8798k6",
    "4ko6rcqt",
    "9s6t7l9o",
    "iwxesev8",
]

In [16]:
# Build the analysis wrapper, restore the first run's checkpoint, then run
# simulation one with argument 6 (presumably the list length -- confirm in basic_model).
create_yhat = creating_y_hat(base)
create_yhat.load_model(
    f'/ebrahimfeghhi/BP06/{wandB_arr[0]}',
    f'{modelPATH_arr[0]}1',
)
create_yhat.run_simulation_one(6)

In [17]:
# Print the per-row normalised histograms of the simulated test transpositions.
create_yhat.y_hat_transpositions()

[0.94 0.06 0.   0.   0.   0.  ]
[0.06 0.86 0.08 0.   0.   0.  ]
[0.   0.07 0.84 0.08 0.   0.  ]
[0.   0.01 0.07 0.82 0.09 0.  ]
[0.   0.   0.01 0.07 0.77 0.13]
[0.   0.   0.01 0.02 0.13 0.84]


In [28]:
class OneHotLetters_EB(Dataset):
    """Serial-recall trials built from one-hot encoded letters.

    Token indices used by this class:

    * ``0 .. num_letters - 1`` -- the letters themselves,
    * ``num_letters`` -- recall cue in the input / end-of-list marker in the target,
    * ``num_letters + 1`` -- delay token (only emitted when a delay is non-zero).

    ``num_classes`` must be larger than the biggest index actually used,
    otherwise ``torch.nn.functional.one_hot`` raises.
    """

    def __init__(self, max_length, num_cycles, test_path, num_classes, batch_size=1, num_letters=26,
                 delay_start=0, delay_middle=0, double_trial=False, storage_frac=0.0):
        """Store the trial configuration and load the held-out test lists.

        :param int max_length: maximum number of letters
        :param int num_cycles: number of cycles (1 cycle = set of lists of length 1,...,max_length)
        :param str test_path: path to a pickle mapping ``str(list_length)`` to a
            container of letter tuples (a set gives the fastest look-up)
        :param int num_classes: width of the one-hot vectors
        :param int batch_size: size of each batch
        :param int num_letters: number of letters in the vocabulary
        :param int delay_start: number of delay tokens before the letters
        :param int delay_middle: number of delay tokens between the letters and the recall cue
        :param bool double_trial: if true, one item contains two concatenated trials
        :param float storage_frac: probability that the second half of a double
            trial is a storage trial. The default of 0.0 always uses a normal
            second trial.
        :raises NotImplementedError: if storage trials are requested but no
            ``recall_list_from_storage`` method is available on the instance.
        """

        self.max_length = max_length
        self.num_letters = num_letters
        self.num_cycles = num_cycles
        self.num_classes = num_classes
        self.batch_size = batch_size
        self.storage = []
        self.delay_start = delay_start
        self.delay_middle = delay_middle
        # Starting length; __getitem__ increments it before building a trial.
        self.list_length = 6
        self.double_trial = double_trial
        # Previously never assigned, so every double trial raised AttributeError.
        self.storage_frac = storage_frac

        # Storage trials need a helper this class does not define. Fail here
        # with a clear message rather than part-way through data loading.
        needs_storage_helper = double_trial and storage_frac > 0
        if needs_storage_helper and not callable(getattr(self, 'recall_list_from_storage', None)):
            raise NotImplementedError(
                "storage_frac > 0 requires a recall_list_from_storage(X, y) method, "
                "which OneHotLetters_EB does not define")

        # NOTE(security): pickle.load can execute arbitrary code; only pass a
        # test_path that you created or otherwise trust.
        with open(test_path, 'rb') as f:
            self.test_data = pickle.load(f)

    def __len__(self):
        """Return ``num_cycles * max_length``."""

        return self.num_cycles * self.max_length

    def construct_trial(self):
        '''
        Generates a training example of the current ``self.list_length``.

        Returns ``(X, y)``, both integer one-hot tensors with ``num_classes``
        columns. ``X`` is: start delay, letters, middle delay, then
        ``list_length + 1`` recall cues. ``y`` is: start delay, letters,
        middle delay, the letters again, then one end-of-list token.
        '''

        # A fresh OS-seeded generator per trial, as in the original code.
        rng = np.random.default_rng()

        delay_token = self.num_letters + 1
        delay_start = np.ones(self.delay_start) * delay_token
        delay_middle = np.ones(self.delay_middle) * delay_token

        letters = rng.choice(self.num_letters, self.list_length, replace=False)

        # ensure train and test are not overlapping
        if self.list_length > 1:
            held_out = self.test_data[str(self.list_length)]
            while tuple(letters) in held_out:
                letters = rng.choice(self.num_letters, self.list_length, replace=False)

        recall_cue = np.ones(self.list_length + 1) * self.num_letters

        X = torch.nn.functional.one_hot(
            torch.from_numpy(np.hstack((delay_start, letters, delay_middle, recall_cue))).to(torch.long),
            num_classes=self.num_classes)

        # output is letters during letter presentation
        # letters again after recall cue
        # and finally end of list cue
        y = torch.from_numpy(np.hstack((delay_start, letters, delay_middle,
                                        letters, self.num_letters))).to(torch.long)
        self.letters_to_prob_dist(letters)
        y = torch.nn.functional.one_hot(y, num_classes=self.num_classes)

        return X, y

    def letters_to_prob_dist(self, letters):
        '''
        Placeholder for converting letters to a probability distribution
        specified by human error patterns.

        TODO: not implemented yet; for now it only prints ``letters`` and
        returns None.
        '''

        print(letters)

    def __getitem__(self, idx):
        """Return one float32 ``(X, y)`` pair; ``idx`` only drives the length schedule.

        Whenever ``idx`` is a multiple of ``batch_size`` the list length is
        incremented, wrapping to 1 once it exceeds ``max_length``. With
        ``double_trial`` a second trial is concatenated along dimension 0.
        """

        # every new batch, increment the list_length
        # once the list length exceeds the max length, return back to 1
        if idx % self.batch_size == 0:
            self.list_length += 1
            if self.list_length > self.max_length:
                self.list_length = 1

        X, y = self.construct_trial()

        if self.double_trial:

            rng = np.random.default_rng()
            uniform_0_1 = rng.random()

            if uniform_0_1 < self.storage_frac:
                # the second list is a fixed delay period followed by recalling the previous list
                # these are termed storage trials
                list_recall = self.delay_start + self.list_length + self.delay_middle
                X2, y2 = self.recall_list_from_storage(X[list_recall:], y[list_recall:])
            else:
                # the second list is a normal list, i.e. presented letters followed by recalling those letters
                X2, y2 = self.construct_trial()

            X = torch.cat((X, X2), axis=0)
            y = torch.cat((y, y2), axis=0)

        return X.to(torch.float32), y.to(torch.float32)


In [29]:
# Build the dataset against the cleaned 26-letter test lists and draw a single
# trial; the starting list length inside the class is 6.
eb = OneHotLetters_EB(
    max_length=9,
    num_cycles=1,
    test_path='/home3/ebrahim/isr/isr_model_review/BP06/test_set/test_lists_cleaned_26.pkl',
    num_classes=27,
)
X, y = eb.construct_trial()

[22  3  5 11  2 15]
