In [1]:
"""
Jason Stranne
"""
import numpy as np
import os
import sys
import gc
from RP_Downstream_Trainer import DownstreamNet, Downstream_Dataset, print_class_counts, num_correct, reduce_dataset_size
from RP_Downstream_Trainer import smallest_class_len, restrict_training_size_per_class
from RP_train_all_at_once import train_end_to_end_RP_combined
sys.path.insert(0, '..')
from Stager_net_pratice import StagerNet
import torch
import torch.nn as nn
import torch.nn.functional as F
import itertools
from RP_data_loader import Custom_RP_Dataset
from sklearn.model_selection import LeaveOneGroupOut, GroupKFold
from torch.utils.data import DataLoader, TensorDataset
from tqdm import tqdm




In [2]:
torch.cuda.empty_cache()

In [3]:
def train_different_classes(RP_train_gen, RP_val_gen, train_set, test_set, epochs, sample_list):
    outList=[]
    balanced_acc_out=[]
    for i in sample_list:
        # print(i)
        acc, balanced_acc = train_end_to_end_RP_combined(RP_train_gen, RP_val_gen, train_set, test_set, i, epochs, 3)
        outList.append(acc)
        balanced_acc_out.append(balanced_acc)
    return outList, balanced_acc_out

In [4]:
root = os.path.join("Mouse_Training_Data", "Windowed_Data", "")

datasets_list=[]
print('Loading Data')
f=open(os.path.join("training_names.txt"),'r')
lines = f.readlines()
for line in lines:
    recordName=line.strip()
    print('Processing', recordName)
    data_file=root+recordName+os.sep+recordName
    datasets_list.append(Custom_RP_Dataset(path=data_file, total_points=2000, tpos=120, tneg=300, windowSize=3, sfreq=1000))
f.close()


training_set = torch.utils.data.ConcatDataset(datasets_list)

data_len = len(training_set)
print("dataset len is", len(training_set))

train_len = int(data_len*0.8)
val_len = data_len - train_len

training_set, validation_set = torch.utils.data.random_split(training_set, [train_len, val_len])

print("one dataset is", len(datasets_list[0]))

params = {'batch_size': 16,
          'shuffle': True,
          'num_workers': 4}
max_epochs = 40
training_generator = torch.utils.data.DataLoader(training_set, **params)
validation_generator = torch.utils.data.DataLoader(validation_set, **params)

print("len of the dataloader is:",len(training_generator))

# cuda setup if allowed
device = torch.device("cuda" if torch.cuda.is_available() else "cpu") # PyTorch v0.4.0


Loading Data
Processing MouseCKA1_030515_HCOFTS
19404000
(588, 3000, 11)
(2000,)
Processing MouseCKA1_030615_HCOFTS
18909000
(573, 3000, 11)
(2000,)
Processing MouseCKL1_062514_HCOFTS
20196000
(612, 3000, 11)
(2000,)
Processing MouseCKB9_022715_HCOFTS
20031000
(607, 3000, 11)
(2000,)
Processing MouseCKB9_022815_HCOFTS
20955000
(635, 3000, 11)
(2000,)
Processing MouseCKL7_063014_HCOFTS
19140000
(580, 3000, 11)
(2000,)
Processing MouseCKL5_063014_HCOFTS
18810000
(570, 3000, 11)
(2000,)
Processing MouseCKL5_070114_HCOFTS
18546000
(562, 3000, 11)
(2000,)
Processing MouseCKL7_070114_HCOFTS
18546000
(562, 3000, 11)
(2000,)
Processing MouseCKN1_063014_HCOFTS
18546000
(562, 3000, 11)
(2000,)
Processing MouseCKN2_070214_HCOFTS
18414000
(558, 3000, 11)
(2000,)
Processing MouseCKN3_070214_HCOFTS
18447000
(559, 3000, 11)
(2000,)
Processing MouseCKN3_070314_HCOFTS
18777000
(569, 3000, 11)
(2000,)
Processing MouseCKO1_070214_HCOFTS
18381000
(557, 3000, 11)
(2000,)
Processing MouseCKN2_070314_HCOFTS


In [None]:
root = os.path.join("Mouse_Training_Data", "Windowed_Data", "")
datasets_list=[]
print('Loading Data')
f=open(os.path.join("training_names.txt"),'r')
lines = f.readlines()
x_vals = []
y_vals = []
groups = []
index = 0
for line in lines:
    recordName=line.strip()
    print('Processing', recordName)
    data_file=root+recordName+os.sep+recordName
    d = Downstream_Dataset(path=data_file)
    x_vals.append(d.data)
    y_vals.append(d.labels)
    groups.append(np.ones(len(d.labels))*index)
    index+=1
    
f.close()


x_vals = np.vstack(x_vals)
y_vals = np.concatenate(y_vals, axis=0)
groups = np.concatenate(groups, axis=0)
print(x_vals.shape)
print(y_vals.shape)
print(groups.shape)


# logo = LeaveOneGroupOut()
kfold = GroupKFold(n_splits=5)
# logo.get_n_splits(x_vals, y_vals, groups)
kfold.get_n_splits(x_vals, y_vals, groups)

result_dict = {}
# dtype=torch.int32
for train_index, test_index in kfold.split(x_vals, y_vals, groups):
    unique = np.unique(groups[test_index])
    # group_num = groups[test_index][0]
    print("Leaving out mouse number:", unique)
    training_set = TensorDataset(torch.tensor(x_vals[train_index], dtype=torch.float), torch.tensor(y_vals[train_index], dtype=torch.long))
    test_set = TensorDataset(torch.tensor(x_vals[test_index], dtype=torch.float), torch.tensor(y_vals[test_index], dtype=torch.long))
    print(len(training_set))
    smallest_class = smallest_class_len(training_set, 3)
    num_samples=[]
    temp=1
    while temp < smallest_class:
        num_samples.append(temp)
        temp*=10
    num_samples.append(None)
    # def train_different_classes(RP_train_gen, RP_val_gen, train_set, test_set, epochs, sample_list):

    accuracy, balanced_accuracy = train_different_classes(training_generator, validation_generator, training_set, test_set, 50, num_samples)
    for num_pos, acc in zip(num_samples, accuracy):
        if num_pos not in result_dict:
            result_dict[num_pos] = []
        result_dict[num_pos].append(acc)
    print(result_dict)
    
for k in result_dict:
    result_dict[k] = np.mean(result_dict[k])
print(result_dict)



Loading Data
Processing MouseCKA1_030515_HCOFTS
removed 0 unknown entries
Processing MouseCKA1_030615_HCOFTS
removed 0 unknown entries
Processing MouseCKL1_062514_HCOFTS
removed 0 unknown entries
Processing MouseCKB9_022715_HCOFTS
removed 0 unknown entries
Processing MouseCKB9_022815_HCOFTS
removed 0 unknown entries
Processing MouseCKL7_063014_HCOFTS
removed 0 unknown entries
Processing MouseCKL5_063014_HCOFTS
removed 0 unknown entries
Processing MouseCKL5_070114_HCOFTS
removed 0 unknown entries
Processing MouseCKL7_070114_HCOFTS
removed 0 unknown entries
Processing MouseCKN1_063014_HCOFTS
removed 0 unknown entries
Processing MouseCKN2_070214_HCOFTS
removed 0 unknown entries
Processing MouseCKN3_070214_HCOFTS
removed 0 unknown entries
Processing MouseCKN3_070314_HCOFTS
removed 0 unknown entries
Processing MouseCKO1_070214_HCOFTS
removed 0 unknown entries
Processing MouseCKN2_070314_HCOFTS
removed 0 unknown entries
Processing MouseCKO1_070314_HCOFTS
removed 0 unknown entries
Processing 

5600it [05:07, 18.19it/s]
5600it [04:58, 18.79it/s]
5600it [04:53, 19.05it/s]
5600it [05:03, 18.47it/s]
5600it [04:58, 18.73it/s]
5600it [05:06, 18.30it/s]
5600it [05:06, 18.28it/s]
5600it [04:57, 18.83it/s]
5600it [04:53, 19.11it/s]
5600it [04:56, 18.89it/s]
5600it [05:00, 18.64it/s]
5600it [05:04, 18.37it/s]
5600it [05:06, 18.28it/s]


EARLY STOPPING




downstream batches: 1
pretext batches: 5600
Start Training Full


5600it [07:44, 12.06it/s]
5600it [07:52, 11.84it/s]
5600it [07:53, 11.83it/s]
5600it [07:49, 11.92it/s]
5600it [07:43, 12.08it/s]
5600it [08:11, 11.39it/s]
5600it [07:57, 11.72it/s]
5600it [07:57, 11.72it/s]
5600it [08:21, 11.17it/s]
5600it [08:17, 11.27it/s]
5600it [08:20, 11.19it/s]


EARLY STOPPING
downstream batches: 3
pretext batches: 5600
Start Training Full


5600it [12:25,  7.51it/s]
5600it [12:42,  7.35it/s]
5600it [12:42,  7.34it/s]
3504it [07:46,  8.38it/s]IOPub message rate exceeded.
The notebook server will temporarily stop sending output
to the client in order to avoid crashing it.
To change this limit, set the config variable
`--NotebookApp.iopub_msg_rate_limit`.

Current values:
NotebookApp.iopub_msg_rate_limit=1000.0 (msgs/sec)
NotebookApp.rate_limit_window=3.0 (secs)

3898it [08:49,  6.38it/s]IOPub message rate exceeded.
The notebook server will temporarily stop sending output
to the client in order to avoid crashing it.
To change this limit, set the config variable
`--NotebookApp.iopub_msg_rate_limit`.

Current values:
NotebookApp.iopub_msg_rate_limit=1000.0 (msgs/sec)
NotebookApp.rate_limit_window=3.0 (secs)

4119it [09:26,  7.56it/s]IOPub message rate exceeded.
The notebook server will temporarily stop sending output
to the client in order to avoid crashing it.
To change this limit, set the config variable
`--NotebookApp.iopub

EARLY STOPPING
downstream batches: 110
pretext batches: 5600
Start Training Full


5600it [12:02,  7.75it/s]
5600it [11:44,  7.95it/s]
5600it [11:47,  7.91it/s]
5600it [11:52,  7.86it/s]
5600it [11:46,  7.93it/s]
5600it [11:40,  8.00it/s]
5600it [12:02,  7.75it/s]
5600it [11:50,  7.88it/s]
5600it [11:52,  7.85it/s]
5600it [11:47,  7.92it/s]
5600it [11:53,  7.85it/s]
5600it [11:50,  7.88it/s]
5600it [11:49,  7.89it/s]
5600it [11:44,  7.95it/s]
5600it [11:54,  7.84it/s]


EARLY STOPPING
{1: [0.7989583333333333], 10: [0.91375], 100: [0.9329166666666666], 1000: [0.9333333333333333], None: [0.9470833333333334]}
Leaving out mouse number: [ 1.  6. 11. 16. 21. 33. 38. 43. 48. 53. 54.]
18000
downstream batches: 1
pretext batches: 5600
Start Training Full


5600it [04:54, 18.99it/s]
5600it [04:27, 20.95it/s]
5600it [04:47, 19.50it/s]
5600it [04:37, 20.22it/s]
5600it [04:45, 19.60it/s]
5600it [04:50, 19.25it/s]
5600it [04:59, 18.67it/s]
5600it [04:51, 19.22it/s]
5600it [04:45, 19.64it/s]
5600it [04:44, 19.67it/s]
5600it [04:53, 19.08it/s]
5600it [04:44, 19.68it/s]
5600it [04:51, 19.20it/s]
5600it [04:48, 19.43it/s]
5600it [04:51, 19.19it/s]


EARLY STOPPING
downstream batches: 1
pretext batches: 5600
Start Training Full


5600it [07:14, 12.90it/s]
5600it [07:05, 13.16it/s]
5600it [07:11, 12.97it/s]
5600it [07:17, 12.80it/s]
5600it [07:12, 12.95it/s]
5600it [07:13, 12.91it/s]
5600it [07:11, 12.99it/s]
5600it [07:08, 13.07it/s]
5600it [07:14, 12.89it/s]
5600it [07:01, 13.27it/s]
5600it [07:03, 13.23it/s]
5600it [07:34, 12.33it/s]
5199it [07:27,  8.74it/s]

 The first holdout group gave {1: [0.7989583333333333], 10: [0.91375], 100: [0.9329166666666666], 1000: [0.9333333333333333], None: [0.9470833333333334]}

In [5]:
smallest_class = smallest_class_len(training_set)
num_samples=[]
temp=1
while temp < smallest_class:
    num_samples.append(temp)
    temp*=10
num_samples.append(None)
print(num_samples)

[1, 10, 100, 1000, None]


In [7]:
num_samples=[1000, None]
RP_2loss_vals = train_different_classes(training_generator, training_generator, validation_set, 100, num_samples)

downstream batches: 20
pretext batches: 250
Start Training Full
[0.6312960554833117]
downstream batches: 68
pretext batches: 250
Start Training Full
[0.6312960554833117, 0.693194625054183]


In [8]:
print(RP_vals)

[0.3046380580840919, 0.4250541829215431, 0.5237104464672735, 0.59913307325531, 0.6680537494581708]
