# Equalize the number of training samples across all mice

In [2]:
from typing import Tuple

import seaborn as sns
import matplotlib.pyplot as plt

from param import *
from modules.dataloader import UniformSegmentDataset
from modules.utils.util import segment

# Recording directories of every mouse, as listed by ParamDir.
_param_dir = ParamDir()
data_list = _param_dir.data_list

## Find the mouse with the fewest samples

In [2]:
from pathlib import Path

# Collect [number of training samples, mouse name] for every mouse so the
# mouse with the fewest training samples can be identified.
data_amout_all = []
data_params = ParamData()  # build the config once instead of per attribute
for data_dir in data_list:
    # Path(...).name is separator-agnostic and ignores a trailing slash,
    # unlike splitting the string on '/'.
    data_name = Path(data_dir).name
    dataset = UniformSegmentDataset(
        data_dir, data_params.mobility, data_params.shuffle, data_params.random_state
    )
    (X_train, y_train), (X_test, y_test) = dataset.load_all_data(
        data_params.window_size, data_params.K, data_params.train_ratio
    )

    data_amout_all.append([len(X_train), data_name])
    

In [3]:
# Display the [n_train_samples, mouse_name] pairs; the smallest count identifies the candidate base mouse.
data_amout_all

[[124, '091317 OF CaMKII HKO M19-n1'],
 [108, '092217 OF CaMKII HKO M30-n1'],
 [56, 'M45_042718_OF'],
 [92, '091317 OF CaMKII HKO M20-n1'],
 [96, 'M46_042718_OF'],
 [100, 'CK_KO_RN1_OF'],
 [104, 'CK_WT_RN3_OF'],
 [84, '090817 OF CaMKII HKO M22-n1'],
 [88, '092217 OF CaMKII WT M29-n1'],
 [108, 'M44_042718_OF'],
 [84, '092717 OF SERT WT M32-n1'],
 [100, '081117 OF B6J M27-n1']]

In [4]:
# Use the mouse with the fewest training samples (surveyed above) as the
# reference size for all other mice, instead of hard-coding its name.
# min() with a key returns the first such mouse on ties.
base_mouse_name = min(data_amout_all, key=lambda item: item[0])[1]
data_dir = ParamDir().DATA_ROOT / base_mouse_name
data_params = ParamData()
base_dataset = UniformSegmentDataset(
    data_dir, data_params.mobility, data_params.shuffle, data_params.random_state
)
(base_X_train, base_y_train), (_, base_y_test) = base_dataset.load_all_data(
    data_params.window_size, data_params.K, data_params.train_ratio
)


In [5]:
# Per-class sample counts (labels, counts) for the base mouse's train/test splits.
base_train_counts = np.unique(base_y_train, return_counts=True)
base_test_counts = np.unique(base_y_test, return_counts=True)
print(f"y_train: {base_train_counts}")
print(f"y_test: {base_test_counts}")

y_train: (array(['1', '2', '3', '4'], dtype='<U1'), array([14, 14, 14, 14]))
y_test: (array(['1', '2', '3', '4'], dtype='<U1'), array([2, 2, 2, 2]))


## Downsample the other mice
Only the training set is downsampled; the test set is left unchanged.

In [8]:
from sklearn.utils import resample

# Downsample one other mouse's training set so that every class has as many
# samples as the base mouse. Use `dataset` (not `base_dataset`) so the base
# mouse's dataset object from the previous cells is not overwritten.
data_dir = data_list[1]
data_params = ParamData()
dataset = UniformSegmentDataset(
    data_dir, data_params.mobility, data_params.shuffle, data_params.random_state
)
(X_train, y_train), (_, y_test) = dataset.load_all_data(
    data_params.window_size, data_params.K, data_params.train_ratio
)

print(f"y_train before: {np.unique(y_train, return_counts=True)}")

# Target per-class size, derived from the base mouse instead of hard-coding it:
# the smallest per-class count in its training set.
n_per_class = int(np.unique(base_y_train, return_counts=True)[1].min())

X, y = X_train, y_train
classes = np.unique(y_train)

X_new = []
y_new = []

for c in classes:
    mask = y == c
    # replace=False: downsampling must draw distinct samples. sklearn's default
    # (replace=True) is bootstrap sampling, which duplicates some windows and
    # drops others. It raises ValueError if a class has fewer than n_per_class.
    # random_state reuses the dataset's seed so the draw is reproducible
    # (assumes it is an int or None -- confirm in ParamData).
    X_tmp, y_tmp = resample(
        X[mask],
        y[mask],
        replace=False,
        n_samples=n_per_class,
        random_state=data_params.random_state,
    )
    X_new.append(X_tmp)
    y_new.append(y_tmp)

print(f"y_train after: {np.unique(y_new, return_counts=True)}")


y_train before: (array(['1', '2', '3', '4'], dtype='<U1'), array([27, 27, 27, 27]))
y_train after: (array(['1', '2', '3', '4'], dtype='<U1'), array([14, 14, 14, 14]))
