In [None]:
'''Author: Masoud Rahimi, June 2023'''

In [1]:
import pandas as pd 
import numpy as np 
import matplotlib.pyplot as plt 
import random

## Late fusion

In [82]:
# load parrot data
# NOTE(review): allow_pickle=True unpickles arbitrary Python objects; only load a
# 'parrot_data.npy' that comes from a trusted source.
dict_parrot = np.load('parrot_data.npy', allow_pickle=True).item()
# The dict is keyed by subject id stored as a string (a later cell indexes it with '121'),
# while the ids read from the lightwheel csv are integers. Cast the id column to int so
# that id-based filtering against sbjcts_id (e.g. Series.isin) actually matches rows.
df_parrot = (
    pd.DataFrame.from_dict(dict_parrot)
    .T
    .reset_index()
    .rename(columns={'index': 'id'})
    .assign(id=lambda frame: frame['id'].astype(int))
)


# load lightwheel data
df_lightwheel = pd.read_csv('lightwheel_data.csv')


# load car data
# After the transpose the original column headers become the rows; the column named 7
# is used as the label.
# NOTE(review): the dtype of the car 'id' column depends on the Excel headers - confirm it
# is comparable with sbjcts_id before filtering on it.
df_car = (
    pd.read_excel('car_data.xlsx')
    .T
    .reset_index()
    .rename(columns={'index': 'id', 7: 'label'})
)


# get subjects id
sbjcts_id = list(df_lightwheel['id'])
df_parrot

Unnamed: 0,id,id_vector,main_vector,move_vector,label
0,100,"[0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, ...","[0.0, 929.0, 704.0, 10.0, 5.0, 384.0, 414.0, 4...","[0.0, 157.0, 161.0, 8.0, 3.0, 78.0, 100.0, 13....",1
1,102,"[0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 1.0, 1.0, 1.0, ...","[0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 245.0, 9.0, 3.0...","[0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 91.0, 4.0, 3.0,...",0
2,106,"[1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, ...","[1.0, 1177.0, 358.0, 50.0, 16.0, 255.0, 591.0,...","[1.0, 157.0, 115.0, 14.0, 5.0, 111.0, 195.0, 2...",1
3,108,"[0.0, 1.0, 1.0, 0.0, 0.0, 1.0, 1.0, 1.0, 0.0, ...","[0.0, 2415.0, 211.0, 0.0, 0.0, 47.0, 76.0, 9.0...","[0.0, 156.0, 84.0, 0.0, 0.0, 42.0, 48.0, 9.0, ...",0
4,110,"[0.0, 1.0, 1.0, 0.0, 0.0, 1.0, 1.0, 0.0, 0.0, ...","[0.0, 51.0, 198.0, 0.0, 0.0, 1685.0, 144.0, 0....","[0.0, 24.0, 46.0, 0.0, 0.0, 112.0, 47.0, 0.0, ...",1
5,112,"[0.0, 0.0, 1.0, 1.0, 0.0, 1.0, 1.0, 1.0, 0.0, ...","[0.0, 0.0, 596.0, 52.0, 0.0, 169.0, 420.0, 35....","[0.0, 0.0, 121.0, 26.0, 0.0, 93.0, 93.0, 18.0,...",1
6,114,"[0.0, 1.0, 0.0, 0.0, 0.0, 1.0, 1.0, 1.0, 0.0, ...","[0.0, 33.0, 0.0, 0.0, 0.0, 1816.0, 925.0, 10.0...","[0.0, 15.0, 0.0, 0.0, 0.0, 52.0, 23.0, 5.0, 0....",1
7,116,"[0.0, 1.0, 0.0, 0.0, 0.0, 1.0, 1.0, 0.0, 0.0, ...","[0.0, 54.0, 0.0, 0.0, 0.0, 1169.0, 106.0, 0.0,...","[0.0, 4.0, 0.0, 0.0, 0.0, 30.0, 15.0, 0.0, 0.0...",1
8,118,"[0.0, 1.0, 1.0, 0.0, 0.0, 1.0, 1.0, 0.0, 1.0, ...","[0.0, 10.0, 36.0, 0.0, 0.0, 23.0, 46.0, 0.0, 4...","[0.0, 8.0, 26.0, 0.0, 0.0, 6.0, 31.0, 0.0, 55....",1
9,120,"[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, ...","[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 645.0...","[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 152.0...",0


In [15]:
class ParrotClassifier:
    ''' It classifies ASD and non-ASD subjects from three per-subject feature vectors
    ('id_vector', 'main_vector', 'move_vector').

    Rows whose 'label' is '1' are treated as ASD and rows whose 'label' is '0' as non-ASD.
    '''

    def __init__(self, df):
        ''' It stores the full dataset.
        input:
            df: dataframe containing all data. It must have the columns 'id', 'label',
                'id_vector', 'main_vector' and 'move_vector'.'''
        self.df = df

    def _rows_with_ids(self, ids):
        ''' It returns the rows of self.df whose 'id' is in ids.

        Both sides are compared as strings so that integer ids (e.g. 100) match ids that
        are stored as text (e.g. '100').'''
        wanted = [str(one_id) for one_id in ids]
        return self.df[self.df['id'].astype(str).isin(wanted)]

    @staticmethod
    def _class_mean(rows, column):
        ''' It returns the element-wise mean of the vectors stored in rows[column].

        Every cell is converted to a float array first, so list-valued and array-valued
        cells are both supported. np.nan is returned when rows is empty.'''
        if rows.empty:
            return np.nan
        stacked = np.stack([np.asarray(vector, dtype=float) for vector in rows[column]])
        return stacked.mean(axis=0)

    def train(self, id_train):
        ''' each subject has three vectors. It calculates the element-wise mean of each vector for ASD and non-ASD subjects regarding the training data which is determined by id_train.
        input:
            id_train[list or tuple]: it determines the training data. Each row of data has an id.
        side effects:
            sets self.df_train and the six attributes
            self.<id|main|move>_vector_mean_<ASD|non_ASD>.
        '''
        # tuples are accepted as well because k_fold returns its id sets as tuples
        assert isinstance(id_train, (list, tuple)), 'id_train should be a list or a tuple'

        # get the training data
        self.df_train = self._rows_with_ids(id_train)

        # split the training rows by class; labels are compared as text so that both
        # '1'/'0' and 1/0 are recognised
        labels = self.df_train['label'].astype(str)
        asd_rows = self.df_train[labels == '1']
        non_asd_rows = self.df_train[labels == '0']

        # calculate the mean value of each feature in each class for the training data
        self.id_vector_mean_ASD = self._class_mean(asd_rows, 'id_vector')
        self.main_vector_mean_ASD = self._class_mean(asd_rows, 'main_vector')
        self.move_vector_mean_ASD = self._class_mean(asd_rows, 'move_vector')

        self.id_vector_mean_non_ASD = self._class_mean(non_asd_rows, 'id_vector')
        self.main_vector_mean_non_ASD = self._class_mean(non_asd_rows, 'main_vector')
        self.move_vector_mean_non_ASD = self._class_mean(non_asd_rows, 'move_vector')

    def predict(self, id_test):
        ''' It selects the test rows determined by id_test and stores them in self.df_test.
        The classification itself is not implemented yet, so nothing is returned.
        input:
            id_test: ids of the subjects in the test set.
        '''
        self.df_test = self._rows_with_ids(id_test)

        #TODO: calculate the similarity between 3 matrices of each data point with the mean of ASD and non-ASD
        #TODO: decision based on max pooling


In [None]:
class CarClassifier:
    ''' Placeholder classifier (presumably for the car data); nothing is implemented yet. '''

    def __init__(self):
        ''' It takes no arguments and sets no attributes. '''
        return None

In [None]:
class LightWheelClassifier:
    ''' Placeholder classifier (presumably for the lightwheel data); nothing is implemented yet. '''

    def __init__(self):
        ''' It takes no arguments and sets no attributes. '''
        return None

### k-Fold cross-validation

In [84]:
def k_fold(sbjcts_id, k):
    ''' It gets a list of subjects' id and provides k sets of (id_train, id_test) data.

    The subjects are split, in their given order, into k consecutive test folds. When the
    number of subjects is not divisible by k, the first len(sbjcts_id) % k folds get one
    extra subject, so every subject appears in exactly one test fold.

    input
        sbjcts_id[list]: it contains id of the subjects
        k: number of folds in Kfold. It must satisfy 1 <= k <= len(sbjcts_id).
    output
        k_dataset[list]: it contains k tuples like this: (id_train, id_test), where
            id_train and id_test are tuples of ids
    raises
        ValueError: if k is smaller than 1 or larger than the number of subjects'''

    n_subjects = len(sbjcts_id)
    if k < 1:
        raise ValueError(f'k should be at least 1, got {k}')
    if k > n_subjects:
        raise ValueError(f'k={k} is larger than the number of subjects ({n_subjects})')

    # determining the number of data points in the test sets: every fold gets base_size
    # subjects and the first n_larger folds get one more
    base_size, n_larger = divmod(n_subjects, k)
    k_dataset = []
    lower = 0
    for fold in range(k):
        upper = lower + base_size + (1 if fold < n_larger else 0)
        id_test = sbjcts_id[lower:upper]
        # membership is checked by value so that a repeated id never ends up in both sets
        id_train = [o for o in sbjcts_id if o not in id_test]
        k_dataset.append((tuple(id_train), tuple(id_test)))
        lower = upper
    return k_dataset

# show the 4 folds produced for the lightwheel subject ids
sbjcts_id = list(df_lightwheel['id'])
folds = k_fold(sbjcts_id, 4)
for cnt in range(len(folds)):
    i = folds[cnt]
    print(f'fold {cnt} of data: (id_train, id_test) -> {i}')

fold 0 of data: (id_train, id_test) -> ((100, 102, 106, 108, 110, 112, 120, 121, 114, 116, 118, 124), (87, 89, 94, 95))
fold 1 of data: (id_train, id_test) -> ((87, 89, 94, 95, 110, 112, 120, 121, 114, 116, 118, 124), (100, 102, 106, 108))
fold 2 of data: (id_train, id_test) -> ((87, 89, 94, 95, 100, 102, 106, 108, 114, 116, 118, 124), (110, 112, 120, 121))
fold 3 of data: (id_train, id_test) -> ((87, 89, 94, 95, 100, 102, 106, 108, 110, 112, 120, 121), (114, 116, 118, 124))


In [120]:
# Manually set the label of subject '121' to '0' (the value ParrotClassifier.train treats
# as non-ASD) in the in-memory dict.
# NOTE(review): this only edits dict_parrot; a df_parrot built before this cell ran is
# not updated and has to be rebuilt to see the change.
dict_parrot['121']['label'] = '0'

In [122]:
# Write the edited dict back to 'parrot_data.npy', the same path the loading cell reads
# from, so the file on disk is overwritten with the patched label.
# NOTE(review): re-running the notebook from the top will therefore load the patched data,
# not the original file - keep a backup if the original is needed.
np.save('parrot_data.npy', dict_parrot)

In [121]:
# Rebuild a table from dict_parrot (one row per dict key after the transpose) to inspect
# the data after the manual label edit.
pd.DataFrame.from_dict(dict_parrot).T

Unnamed: 0,id_vector,main_vector,move_vector,label
100,"[0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, ...","[0.0, 929.0, 704.0, 10.0, 5.0, 384.0, 414.0, 4...","[0.0, 157.0, 161.0, 8.0, 3.0, 78.0, 100.0, 13....",1
102,"[0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 1.0, 1.0, 1.0, ...","[0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 245.0, 9.0, 3.0...","[0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 91.0, 4.0, 3.0,...",0
106,"[1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, ...","[1.0, 1177.0, 358.0, 50.0, 16.0, 255.0, 591.0,...","[1.0, 157.0, 115.0, 14.0, 5.0, 111.0, 195.0, 2...",1
108,"[0.0, 1.0, 1.0, 0.0, 0.0, 1.0, 1.0, 1.0, 0.0, ...","[0.0, 2415.0, 211.0, 0.0, 0.0, 47.0, 76.0, 9.0...","[0.0, 156.0, 84.0, 0.0, 0.0, 42.0, 48.0, 9.0, ...",0
110,"[0.0, 1.0, 1.0, 0.0, 0.0, 1.0, 1.0, 0.0, 0.0, ...","[0.0, 51.0, 198.0, 0.0, 0.0, 1685.0, 144.0, 0....","[0.0, 24.0, 46.0, 0.0, 0.0, 112.0, 47.0, 0.0, ...",1
112,"[0.0, 0.0, 1.0, 1.0, 0.0, 1.0, 1.0, 1.0, 0.0, ...","[0.0, 0.0, 596.0, 52.0, 0.0, 169.0, 420.0, 35....","[0.0, 0.0, 121.0, 26.0, 0.0, 93.0, 93.0, 18.0,...",1
114,"[0.0, 1.0, 0.0, 0.0, 0.0, 1.0, 1.0, 1.0, 0.0, ...","[0.0, 33.0, 0.0, 0.0, 0.0, 1816.0, 925.0, 10.0...","[0.0, 15.0, 0.0, 0.0, 0.0, 52.0, 23.0, 5.0, 0....",1
116,"[0.0, 1.0, 0.0, 0.0, 0.0, 1.0, 1.0, 0.0, 0.0, ...","[0.0, 54.0, 0.0, 0.0, 0.0, 1169.0, 106.0, 0.0,...","[0.0, 4.0, 0.0, 0.0, 0.0, 30.0, 15.0, 0.0, 0.0...",1
118,"[0.0, 1.0, 1.0, 0.0, 0.0, 1.0, 1.0, 0.0, 1.0, ...","[0.0, 10.0, 36.0, 0.0, 0.0, 23.0, 46.0, 0.0, 4...","[0.0, 8.0, 26.0, 0.0, 0.0, 6.0, 31.0, 0.0, 55....",1
120,"[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, ...","[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 645.0...","[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 152.0...",0


### Stratified k-Fold cross-validation

### Leave-p-out cross-validation