In [1]:
#Set to project root directory
from glob import glob

import pyreadr
import numpy as np
from sklearn.model_selection import LeaveOneGroupOut
from tqdm import tqdm
from sklearn.decomposition import PCA
from sklearn.cluster import KMeans
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import accuracy_score

sample_dirs = ['/home/blansdel/projects/schwarz/decoder/Retracked/Control Animal 1',
        '/home/blansdel/projects/schwarz/decoder/Retracked/Control Animal 10',
        '/home/blansdel/projects/schwarz/decoder/Retracked/Control Animal M3']

files = []
for d in sample_dirs:
    files.extend(glob(d + "/*.Rds"))

In [12]:
df

Unnamed: 0,frame,CellID,dff,MatchID,Time,Centre position X,Centre position Y,Trial Number,Session,Animal,Reward,event
0,1,1,0.157809,450,0.000000,375.0,352.0,1,9.0,M3,right,In Start
1,1,2,0.112819,451,0.000000,375.0,352.0,1,9.0,M3,right,In Start
2,1,3,-0.038237,625,0.000000,375.0,352.0,1,9.0,M3,right,In Start
3,1,4,-0.335028,10,0.000000,375.0,352.0,1,9.0,M3,right,In Start
4,1,5,-0.269979,626,0.000000,375.0,352.0,1,9.0,M3,right,In Start
...,...,...,...,...,...,...,...,...,...,...,...,...
11393293,24822,455,0.575220,429,1242.541049,309.0,320.0,12,9.0,M3,right,In Start
11393294,24822,456,0.831388,300,1242.541049,309.0,320.0,12,9.0,M3,right,In Start
11393295,24822,457,0.400699,400,1242.541049,309.0,320.0,12,9.0,M3,right,In Start
11393296,24822,458,0.112331,353,1242.541049,309.0,320.0,12,9.0,M3,right,In Start


### TODO

* Load files from the same animal
* Use MatchID
* Train on left reward days, test on right
* Compare this to test performance on left
* How do we know this is statistically significant?

What is our hypothesis again?

That decoder performance for reward prediction decays w time while it doesn't for position, with reversal?
The OFC activity then represents the expectation of reward, not the position, and therefore we can use it to say something about how fast or slow re-learning reward is?


In [11]:
for file in files:
    result = pyreadr.read_r(file)
    df = result[None]
    print(file)
    print(df.Reward.unique())

/home/blansdel/projects/schwarz/decoder/Retracked/Control Animal 1/final_010718.Rds
['left']
/home/blansdel/projects/schwarz/decoder/Retracked/Control Animal 1/final_010719.Rds
['right']
/home/blansdel/projects/schwarz/decoder/Retracked/Control Animal 1/final_010721.Rds
['right']
/home/blansdel/projects/schwarz/decoder/Retracked/Control Animal 1/final_010713.Rds
['left']
/home/blansdel/projects/schwarz/decoder/Retracked/Control Animal 10/final_100210.Rds
['left']
/home/blansdel/projects/schwarz/decoder/Retracked/Control Animal 10/final_100208.Rds
['left']
/home/blansdel/projects/schwarz/decoder/Retracked/Control Animal 10/final_100207.Rds
['right']
/home/blansdel/projects/schwarz/decoder/Retracked/Control Animal 10/final_100202.Rds
['right']
/home/blansdel/projects/schwarz/decoder/Retracked/Control Animal M3/final_M30821.Rds
['right']
/home/blansdel/projects/schwarz/decoder/Retracked/Control Animal M3/final_M30815.Rds
['left']
/home/blansdel/projects/schwarz/decoder/Retracked/Control A

In [5]:
results = {}

for sample_file in tqdm(files):
    
    result = pyreadr.read_r(sample_file)
    df = result[None]

    #Turn into wide table, by cell type and dff signal
    df_wide = df.pivot(index=['Trial Number', 'Centre position X', 'Centre position Y', 'frame'], columns='CellID', values='dff').reset_index()
    df_wide.columns = list(df_wide.columns[:4]) + [f'Cell_{i}' for i in df_wide.columns[4:]]
    df_wide = df_wide.dropna()
    y = df_wide[['Centre position X', 'Centre position Y']].values
    X = df_wide.iloc[:,4:].values
    groups = df_wide['Trial Number'].values
    logo = LeaveOneGroupOut()
    
    train_accs = []
    test_accs = []
    train_accs_shuffled = []
    test_accs_shuffled = []

    #For each group, do the whole split, training and evaluation 
    for train_index, test_index in logo.split(X, y, groups=groups):
        X_train, X_test = X[train_index], X[test_index]
        y_train, y_test = y[train_index], y[test_index]
        
        pca = PCA(n_components=20)
        X_train_pca = pca.fit_transform(X_train)
        X_test_pca = pca.transform(X_test)
        
        kmeans = KMeans(n_clusters=10, random_state=0).fit(y_train)
        y_train_kmeans = kmeans.predict(y_train)
        y_test_kmeans = kmeans.predict(y_test)
        
        #try logistic regression 
        clf = LogisticRegression(random_state=42)
        clf.fit(X_train_pca, y_train_kmeans)
        
        y_pred_train = clf.predict(X_train_pca)
        y_pred_test = clf.predict(X_test_pca)
        
        #Evaluate
        train_accuracy = accuracy_score(y_train_kmeans, y_pred_train)
        test_accuracy = accuracy_score(y_test_kmeans, y_pred_test)
        
        test_accs.append(test_accuracy)
        train_accs.append(train_accuracy)
 
        #try logistic regression with shuffled labels
        # clf = LogisticRegression(random_state=42)
        #Shuffle the labels
        # np.random.shuffle(y_train_kmeans)
        # clf.fit(X_train_pca, y_train_kmeans)
        
        np.random.shuffle(y_test_kmeans)
        # y_pred_train = clf.predict(X_train_pca)
        # y_pred_test = clf.predict(X_test_pca)

        #Evaluate
        # train_accuracy = accuracy_score(y_train_kmeans, y_pred_train)
        test_accuracy = accuracy_score(y_test_kmeans, y_pred_test)
        
        test_accs_shuffled.append(test_accuracy)
        # train_accs_shuffled.append(train_accuracy)
                
    print("Train accuracy: ", np.mean(train_accs))
    print("Test accuracy: ", np.mean(test_accs))
    print("Test accuracy shuffled", np.mean(test_accs_shuffled))

    results[sample_file] = {'train_accs': train_accs, 'test_accs': test_accs, 'test_accs_shuffled': test_accs_shuffled}


STOP: TOTAL NO. of ITERATIONS REACHED LIMIT.

Increase the number of iterations (max_iter) or scale the data as shown in:
    https://scikit-learn.org/stable/modules/preprocessing.html
Please also refer to the documentation for alternative solver options:
    https://scikit-learn.org/stable/modules/linear_model.html#logistic-regression
  n_iter_i = _check_optimize_result(
STOP: TOTAL NO. of ITERATIONS REACHED LIMIT.

Increase the number of iterations (max_iter) or scale the data as shown in:
    https://scikit-learn.org/stable/modules/preprocessing.html
Please also refer to the documentation for alternative solver options:
    https://scikit-learn.org/stable/modules/linear_model.html#logistic-regression
  n_iter_i = _check_optimize_result(
STOP: TOTAL NO. of ITERATIONS REACHED LIMIT.

Increase the number of iterations (max_iter) or scale the data as shown in:
    https://scikit-learn.org/stable/modules/preprocessing.html
Please also refer to the documentation for alternative solver opt

Train accuracy:  0.8160289145665686
Test accuracy:  0.6188128169631573
Test accuracy shuffled 0.2492719473302338


STOP: TOTAL NO. of ITERATIONS REACHED LIMIT.

Increase the number of iterations (max_iter) or scale the data as shown in:
    https://scikit-learn.org/stable/modules/preprocessing.html
Please also refer to the documentation for alternative solver options:
    https://scikit-learn.org/stable/modules/linear_model.html#logistic-regression
  n_iter_i = _check_optimize_result(
STOP: TOTAL NO. of ITERATIONS REACHED LIMIT.

Increase the number of iterations (max_iter) or scale the data as shown in:
    https://scikit-learn.org/stable/modules/preprocessing.html
Please also refer to the documentation for alternative solver options:
    https://scikit-learn.org/stable/modules/linear_model.html#logistic-regression
  n_iter_i = _check_optimize_result(
STOP: TOTAL NO. of ITERATIONS REACHED LIMIT.

Increase the number of iterations (max_iter) or scale the data as shown in:
    https://scikit-learn.org/stable/modules/preprocessing.html
Please also refer to the documentation for alternative solver opt

Train accuracy:  0.7631499553396961
Test accuracy:  0.4357689704871988
Test accuracy shuffled 0.17814421462282656


STOP: TOTAL NO. of ITERATIONS REACHED LIMIT.

Increase the number of iterations (max_iter) or scale the data as shown in:
    https://scikit-learn.org/stable/modules/preprocessing.html
Please also refer to the documentation for alternative solver options:
    https://scikit-learn.org/stable/modules/linear_model.html#logistic-regression
  n_iter_i = _check_optimize_result(
STOP: TOTAL NO. of ITERATIONS REACHED LIMIT.

Increase the number of iterations (max_iter) or scale the data as shown in:
    https://scikit-learn.org/stable/modules/preprocessing.html
Please also refer to the documentation for alternative solver options:
    https://scikit-learn.org/stable/modules/linear_model.html#logistic-regression
  n_iter_i = _check_optimize_result(
STOP: TOTAL NO. of ITERATIONS REACHED LIMIT.

Increase the number of iterations (max_iter) or scale the data as shown in:
    https://scikit-learn.org/stable/modules/preprocessing.html
Please also refer to the documentation for alternative solver opt

Train accuracy:  0.6931203739037959
Test accuracy:  0.33467227045105424
Test accuracy shuffled 0.16902796301605771


STOP: TOTAL NO. of ITERATIONS REACHED LIMIT.

Increase the number of iterations (max_iter) or scale the data as shown in:
    https://scikit-learn.org/stable/modules/preprocessing.html
Please also refer to the documentation for alternative solver options:
    https://scikit-learn.org/stable/modules/linear_model.html#logistic-regression
  n_iter_i = _check_optimize_result(
STOP: TOTAL NO. of ITERATIONS REACHED LIMIT.

Increase the number of iterations (max_iter) or scale the data as shown in:
    https://scikit-learn.org/stable/modules/preprocessing.html
Please also refer to the documentation for alternative solver options:
    https://scikit-learn.org/stable/modules/linear_model.html#logistic-regression
  n_iter_i = _check_optimize_result(
STOP: TOTAL NO. of ITERATIONS REACHED LIMIT.

Increase the number of iterations (max_iter) or scale the data as shown in:
    https://scikit-learn.org/stable/modules/preprocessing.html
Please also refer to the documentation for alternative solver opt

Train accuracy:  0.7295558320088181
Test accuracy:  0.455849662445553
Test accuracy shuffled 0.20243306202991015


STOP: TOTAL NO. of ITERATIONS REACHED LIMIT.

Increase the number of iterations (max_iter) or scale the data as shown in:
    https://scikit-learn.org/stable/modules/preprocessing.html
Please also refer to the documentation for alternative solver options:
    https://scikit-learn.org/stable/modules/linear_model.html#logistic-regression
  n_iter_i = _check_optimize_result(
STOP: TOTAL NO. of ITERATIONS REACHED LIMIT.

Increase the number of iterations (max_iter) or scale the data as shown in:
    https://scikit-learn.org/stable/modules/preprocessing.html
Please also refer to the documentation for alternative solver options:
    https://scikit-learn.org/stable/modules/linear_model.html#logistic-regression
  n_iter_i = _check_optimize_result(
STOP: TOTAL NO. of ITERATIONS REACHED LIMIT.

Increase the number of iterations (max_iter) or scale the data as shown in:
    https://scikit-learn.org/stable/modules/preprocessing.html
Please also refer to the documentation for alternative solver opt

Train accuracy:  0.7918731918108116
Test accuracy:  0.43478783491516454
Test accuracy shuffled 0.23730195818700092


STOP: TOTAL NO. of ITERATIONS REACHED LIMIT.

Increase the number of iterations (max_iter) or scale the data as shown in:
    https://scikit-learn.org/stable/modules/preprocessing.html
Please also refer to the documentation for alternative solver options:
    https://scikit-learn.org/stable/modules/linear_model.html#logistic-regression
  n_iter_i = _check_optimize_result(
STOP: TOTAL NO. of ITERATIONS REACHED LIMIT.

Increase the number of iterations (max_iter) or scale the data as shown in:
    https://scikit-learn.org/stable/modules/preprocessing.html
Please also refer to the documentation for alternative solver options:
    https://scikit-learn.org/stable/modules/linear_model.html#logistic-regression
  n_iter_i = _check_optimize_result(
STOP: TOTAL NO. of ITERATIONS REACHED LIMIT.

Increase the number of iterations (max_iter) or scale the data as shown in:
    https://scikit-learn.org/stable/modules/preprocessing.html
Please also refer to the documentation for alternative solver opt

Train accuracy:  0.8298437444781096
Test accuracy:  0.42273919360667983
Test accuracy shuffled 0.2452596384808755


STOP: TOTAL NO. of ITERATIONS REACHED LIMIT.

Increase the number of iterations (max_iter) or scale the data as shown in:
    https://scikit-learn.org/stable/modules/preprocessing.html
Please also refer to the documentation for alternative solver options:
    https://scikit-learn.org/stable/modules/linear_model.html#logistic-regression
  n_iter_i = _check_optimize_result(
STOP: TOTAL NO. of ITERATIONS REACHED LIMIT.

Increase the number of iterations (max_iter) or scale the data as shown in:
    https://scikit-learn.org/stable/modules/preprocessing.html
Please also refer to the documentation for alternative solver options:
    https://scikit-learn.org/stable/modules/linear_model.html#logistic-regression
  n_iter_i = _check_optimize_result(
STOP: TOTAL NO. of ITERATIONS REACHED LIMIT.

Increase the number of iterations (max_iter) or scale the data as shown in:
    https://scikit-learn.org/stable/modules/preprocessing.html
Please also refer to the documentation for alternative solver opt

Train accuracy:  0.8580068779193687
Test accuracy:  0.58031891929317
Test accuracy shuffled 0.2747993030090144


STOP: TOTAL NO. of ITERATIONS REACHED LIMIT.

Increase the number of iterations (max_iter) or scale the data as shown in:
    https://scikit-learn.org/stable/modules/preprocessing.html
Please also refer to the documentation for alternative solver options:
    https://scikit-learn.org/stable/modules/linear_model.html#logistic-regression
  n_iter_i = _check_optimize_result(
STOP: TOTAL NO. of ITERATIONS REACHED LIMIT.

Increase the number of iterations (max_iter) or scale the data as shown in:
    https://scikit-learn.org/stable/modules/preprocessing.html
Please also refer to the documentation for alternative solver options:
    https://scikit-learn.org/stable/modules/linear_model.html#logistic-regression
  n_iter_i = _check_optimize_result(
STOP: TOTAL NO. of ITERATIONS REACHED LIMIT.

Increase the number of iterations (max_iter) or scale the data as shown in:
    https://scikit-learn.org/stable/modules/preprocessing.html
Please also refer to the documentation for alternative solver opt

Train accuracy:  0.816487327562036
Test accuracy:  0.4634974568282209
Test accuracy shuffled 0.31287212389242697


STOP: TOTAL NO. of ITERATIONS REACHED LIMIT.

Increase the number of iterations (max_iter) or scale the data as shown in:
    https://scikit-learn.org/stable/modules/preprocessing.html
Please also refer to the documentation for alternative solver options:
    https://scikit-learn.org/stable/modules/linear_model.html#logistic-regression
  n_iter_i = _check_optimize_result(
STOP: TOTAL NO. of ITERATIONS REACHED LIMIT.

Increase the number of iterations (max_iter) or scale the data as shown in:
    https://scikit-learn.org/stable/modules/preprocessing.html
Please also refer to the documentation for alternative solver options:
    https://scikit-learn.org/stable/modules/linear_model.html#logistic-regression
  n_iter_i = _check_optimize_result(
STOP: TOTAL NO. of ITERATIONS REACHED LIMIT.

Increase the number of iterations (max_iter) or scale the data as shown in:
    https://scikit-learn.org/stable/modules/preprocessing.html
Please also refer to the documentation for alternative solver opt

Train accuracy:  0.9251913908207231
Test accuracy:  0.6517882379080044
Test accuracy shuffled 0.2861553788267299


STOP: TOTAL NO. of ITERATIONS REACHED LIMIT.

Increase the number of iterations (max_iter) or scale the data as shown in:
    https://scikit-learn.org/stable/modules/preprocessing.html
Please also refer to the documentation for alternative solver options:
    https://scikit-learn.org/stable/modules/linear_model.html#logistic-regression
  n_iter_i = _check_optimize_result(
STOP: TOTAL NO. of ITERATIONS REACHED LIMIT.

Increase the number of iterations (max_iter) or scale the data as shown in:
    https://scikit-learn.org/stable/modules/preprocessing.html
Please also refer to the documentation for alternative solver options:
    https://scikit-learn.org/stable/modules/linear_model.html#logistic-regression
  n_iter_i = _check_optimize_result(
STOP: TOTAL NO. of ITERATIONS REACHED LIMIT.

Increase the number of iterations (max_iter) or scale the data as shown in:
    https://scikit-learn.org/stable/modules/preprocessing.html
Please also refer to the documentation for alternative solver opt

Train accuracy:  0.8753911049458596
Test accuracy:  0.5716535825096273
Test accuracy shuffled 0.3140988246470941


STOP: TOTAL NO. of ITERATIONS REACHED LIMIT.

Increase the number of iterations (max_iter) or scale the data as shown in:
    https://scikit-learn.org/stable/modules/preprocessing.html
Please also refer to the documentation for alternative solver options:
    https://scikit-learn.org/stable/modules/linear_model.html#logistic-regression
  n_iter_i = _check_optimize_result(
STOP: TOTAL NO. of ITERATIONS REACHED LIMIT.

Increase the number of iterations (max_iter) or scale the data as shown in:
    https://scikit-learn.org/stable/modules/preprocessing.html
Please also refer to the documentation for alternative solver options:
    https://scikit-learn.org/stable/modules/linear_model.html#logistic-regression
  n_iter_i = _check_optimize_result(
STOP: TOTAL NO. of ITERATIONS REACHED LIMIT.

Increase the number of iterations (max_iter) or scale the data as shown in:
    https://scikit-learn.org/stable/modules/preprocessing.html
Please also refer to the documentation for alternative solver opt

Train accuracy:  0.8471980974542541
Test accuracy:  0.5618267206456635
Test accuracy shuffled 0.24728394862204253


STOP: TOTAL NO. of ITERATIONS REACHED LIMIT.

Increase the number of iterations (max_iter) or scale the data as shown in:
    https://scikit-learn.org/stable/modules/preprocessing.html
Please also refer to the documentation for alternative solver options:
    https://scikit-learn.org/stable/modules/linear_model.html#logistic-regression
  n_iter_i = _check_optimize_result(
STOP: TOTAL NO. of ITERATIONS REACHED LIMIT.

Increase the number of iterations (max_iter) or scale the data as shown in:
    https://scikit-learn.org/stable/modules/preprocessing.html
Please also refer to the documentation for alternative solver options:
    https://scikit-learn.org/stable/modules/linear_model.html#logistic-regression
  n_iter_i = _check_optimize_result(
STOP: TOTAL NO. of ITERATIONS REACHED LIMIT.

Increase the number of iterations (max_iter) or scale the data as shown in:
    https://scikit-learn.org/stable/modules/preprocessing.html
Please also refer to the documentation for alternative solver opt

Train accuracy:  0.8336839728782834
Test accuracy:  0.5329612193195652
Test accuracy shuffled 0.23126781665558913





In [9]:
for file in results:
    print(file)
    print("Train accuracy: ", np.mean(results[file]['train_accs']))
    print("Test accuracy: ", np.mean(results[file]['test_accs']))
    print("Test accuracy shuffled", np.mean(results[file]['test_accs_shuffled']))

/home/blansdel/projects/schwarz/decoder/Retracked/Control Animal 1/final_010718.Rds
Train accuracy:  0.8160289145665686
Test accuracy:  0.6188128169631573
Test accuracy shuffled 0.2492719473302338
/home/blansdel/projects/schwarz/decoder/Retracked/Control Animal 1/final_010719.Rds
Train accuracy:  0.7631499553396961
Test accuracy:  0.4357689704871988
Test accuracy shuffled 0.17814421462282656
/home/blansdel/projects/schwarz/decoder/Retracked/Control Animal 1/final_010721.Rds
Train accuracy:  0.6931203739037959
Test accuracy:  0.33467227045105424
Test accuracy shuffled 0.16902796301605771
/home/blansdel/projects/schwarz/decoder/Retracked/Control Animal 1/final_010713.Rds
Train accuracy:  0.7295558320088181
Test accuracy:  0.455849662445553
Test accuracy shuffled 0.20243306202991015
/home/blansdel/projects/schwarz/decoder/Retracked/Control Animal 10/final_100210.Rds
Train accuracy:  0.7918731918108116
Test accuracy:  0.43478783491516454
Test accuracy shuffled 0.23730195818700092
/home/bla

In [None]:
print([np.mean(v['test_accs']) for k,v in results.items()])

In [6]:
print([np.mean(v['test_accs']) for k,v in results.items()])

[0.6188128169631573, 0.4357689704871988, 0.33467227045105424, 0.455849662445553, 0.43478783491516454, 0.42273919360667983, 0.58031891929317, 0.4634974568282209, 0.6517882379080044, 0.5716535825096273, 0.5618267206456635, 0.5329612193195652]


In [7]:
print([np.mean(v['test_accs_shuffled']) for k,v in results.items()])

[0.2492719473302338, 0.17814421462282656, 0.16902796301605771, 0.20243306202991015, 0.23730195818700092, 0.2452596384808755, 0.2747993030090144, 0.31287212389242697, 0.2861553788267299, 0.3140988246470941, 0.24728394862204253, 0.23126781665558913]


In [8]:
train_accs_shuffled

[]

In [8]:
train_accs

[0.7123990055935363,
 0.7450850938967136,
 0.70530045769969,
 0.6769128704113224,
 0.7520205993848795,
 0.7450385071090048,
 0.7625089777352166,
 0.7517063007399957,
 0.6900115696104898,
 0.7523310859274136,
 0.7525283018867924,
 0.7144828855355171]

In [9]:
test_accs

[0.2959844559585492,
 0.7193877551020408,
 0.14942528735632185,
 0.5258823529411765,
 0.5310344827586206,
 0.7609649122807017,
 0.46419098143236076,
 0.5774647887323944,
 0.22398345968297725,
 0.6181434599156118,
 0.38078902229845624,
 0.2077226606538895]