# Benchmark Construction

This notebook builds the standard benchmark splits (train, validation, and test) of our `CAUEEG` dataset from the previously generated signal, annotation, and event files.

-----

## Configurations

In [1]:
# for auto-reloading external modules
# see http://stackoverflow.com/questions/1907993/autoreload-of-modules-in-ipython
%load_ext autoreload
%autoreload 2
# Move to the parent directory, presumably so relative paths such as `local/dataset/...`
# and the local `datasets` package resolve from the project root — TODO confirm.
# NOTE: `%cd ..` is relative, so re-running this cell climbs one more directory level.
%cd ..

C:\Users\Minjae\Desktop\EEG_Project


In [2]:
# Load some packages
import os
import glob
import json
import pprint

import numpy as np
import random

from tqdm.auto import tqdm

# custom package
from datasets.caueeg_dataset import *
from datasets.pipeline import *

In [None]:
# Root folder of the curated dataset, relative to the working directory chosen by `%cd ..` above.
data_path = 'local/dataset/02_Curated_Data_220720_seg_10s/'

In [None]:
# Load the master annotation file that the benchmark splits are built from.
anno_path = os.path.join(data_path, 'annotation.json')
with open(anno_path, 'r') as json_file:
    annotation = json.load(json_file)

# Preview every top-level field, but only the first five records of 'data'.
pprint.pprint(
    {key: value[:5] if key == 'data' else value for key, value in annotation.items()},
    width=250,
)

## Helper functions

In [None]:
def shuffle_splitted_metadata(splitted_metadata, class_label_to_name, ratios, seed=None, verbose=False):
    """Split each class group into train/validation/test subsets (stratified split).

    Every group in ``splitted_metadata`` is shuffled independently and cut according to
    ``ratios``, so each class keeps roughly the same proportion in every subset. The three
    resulting subsets are then shuffled once more to mix the classes.

    Args:
        splitted_metadata: list of groups, one per class; each group is a list of metadata
            dicts (``class_label`` is read when ``verbose``). The input lists are not modified.
        class_label_to_name: list of class names; only its length is used, for the
            per-class counts printed when ``verbose``.
        ratios: sequence ``(train, validation, ...)`` of fractions; the test subset
            receives whatever remains after the train and validation cuts.
        seed: seed for a private random generator. ``None`` draws fresh entropy, so the
            split is not reproducible.
        verbose: if True, print the per-class label distribution of each subset.

    Returns:
        Tuple ``(metadata_train, metadata_val, metadata_test)`` of lists.
    """
    # A private generator yields the same sequence as ``random.seed(seed)`` followed by
    # module-level calls, but leaves the caller's global ``random`` state untouched.
    rng = random.Random(seed)

    metadata_train = []
    metadata_val = []
    metadata_test = []

    for split in splitted_metadata:
        # Shuffle a copy so the caller's group lists keep their original order.
        shuffled = list(split)
        rng.shuffle(shuffled)

        n1 = round(len(shuffled) * ratios[0])
        n2 = n1 + round(len(shuffled) * ratios[1])

        metadata_train.extend(shuffled[:n1])
        metadata_val.extend(shuffled[n1:n2])
        metadata_test.extend(shuffled[n2:])

    rng.shuffle(metadata_train)
    rng.shuffle(metadata_val)
    rng.shuffle(metadata_test)

    if verbose:
        for title, metadata in (('Train', metadata_train),
                                ('Validation', metadata_val),
                                ('Test', metadata_test)):
            # Plain ints (not numpy scalars) so the printed list stays readable.
            class_dist = [sum(1 for m in metadata if m['class_label'] == c)
                          for c in range(len(class_label_to_name))]
            print(f'<{title:^15}> data label distribution\t:', class_dist, '=', sum(class_dist))

    return metadata_train, metadata_val, metadata_test

In [None]:
def shuffle_splitted_metadata_with_initial_train(splitted_metadata, class_label_to_name, ratios, initial_train, seed=None, verbose=False):
    """Stratified train/validation/test split that forces given records into train.

    For each class group ``splitted_metadata[i]``, the records whose ``serial`` appears in
    ``initial_train[i]`` go to the training subset first. The remaining records are
    shuffled and used to top the training subset up to ``round(len(group) * ratios[0])``
    records (never fewer than zero extra). The next ``round(len(group) * ratios[1])``
    records go to validation and the rest go to test.

    Args:
        splitted_metadata: list of groups, one per class; each group is a list of metadata
            dicts with a ``serial`` key (and ``class_label`` when ``verbose``).
            The input lists are not modified.
        class_label_to_name: list of class names; only its length is used (verbose output).
        ratios: sequence ``(train, validation, ...)`` of fractions; test gets the remainder.
        initial_train: list parallel to ``splitted_metadata``; ``initial_train[i]`` holds
            dicts (only ``serial`` is read) whose records must land in the training subset.
            Serials that do not occur in group ``i`` are ignored.
        seed: seed for a private random generator; ``None`` draws fresh entropy.
        verbose: if True, print the per-class label distribution of each subset.

    Returns:
        Tuple ``(metadata_train, metadata_val, metadata_test)`` of lists.
    """
    # Private generator: same sequence as ``random.seed(seed)``, global state untouched.
    rng = random.Random(seed)

    metadata_train = []
    metadata_val = []
    metadata_test = []
    # Serials already placed in train across all groups processed so far.
    train_serials = set()

    for i, split in enumerate(splitted_metadata):
        forced_serials = {m['serial'] for m in initial_train[i]}
        forced = [s for s in split if s['serial'] in forced_serials]
        metadata_train.extend(forced)
        train_serials.update(s['serial'] for s in forced)

        split_rest = [s for s in split if s['serial'] not in train_serials]
        rng.shuffle(split_rest)

        # Count only the forced records actually found in this group, and never let the
        # remaining train quota go negative: a negative slice bound would move almost all
        # of `split_rest` into train while `split_rest[n2:]` also sends them to test.
        n1 = max(0, round(len(split) * ratios[0]) - len(forced))
        n2 = n1 + round(len(split) * ratios[1])

        metadata_train.extend(split_rest[:n1])
        metadata_val.extend(split_rest[n1:n2])
        metadata_test.extend(split_rest[n2:])
        train_serials.update(s['serial'] for s in split_rest[:n1])

    rng.shuffle(metadata_train)
    rng.shuffle(metadata_val)
    rng.shuffle(metadata_test)

    if verbose:
        for title, metadata in (('Train', metadata_train),
                                ('Validation', metadata_val),
                                ('Test', metadata_test)):
            class_dist = [sum(1 for m in metadata if m['class_label'] == c)
                          for c in range(len(class_label_to_name))]
            print(f'<{title:^15}> data label distribution\t:', class_dist, '=', sum(class_dist))

    return metadata_train, metadata_val, metadata_test

-----

## Main Task 2: Classification of Three Symptoms (Normal, MCI, and Dementia)

Task 2 is constructed first because the Task 1 split below reuses its training split.

#### Define the target diagnoses and split them by their symptoms

In [None]:
# Task 2 class definitions: one dict per class, and the list index is the class label.
# The splitting cell that follows assigns a record to the first filter whose 'include'
# symptoms are all present in its `symptom` list and whose 'exclude' symptoms are all absent.
diagnosis_filter = [
    # Normal
    {'name': 'Normal',
     'include': ['normal'],
     'exclude': []},
    # MCI (no exclusion filter is applied)
    {'name': 'MCI',
     'include': ['mci'],
     'exclude': []},
    # Dementia (no exclusion filter is applied)
    {'name': 'Dementia',
     'include': ['dementia'],
     'exclude': []},
]

# Label -> name (list index is the label) and name -> label lookups.
class_label_to_name = [d_f['name'] for d_f in diagnosis_filter]
print('class_label_to_name:', class_label_to_name)

class_name_to_label = {d_f['name']: i for i, d_f in enumerate(diagnosis_filter)}
print('class_name_to_label:', class_name_to_label)

In [None]:
# Split the filtered dataset: each annotated record goes to the first class whose filter it
# satisfies; records matching no filter are dropped.
# Each record is copied before labelling, so `annotation['data']` is left untouched. The
# Task 2 splits then cannot be relabelled by later cells that reuse the same annotation records.
splitted_metadata = [[] for _ in diagnosis_filter]

for m in annotation['data']:
    symptom = set(m['symptom'])
    for c, f in enumerate(diagnosis_filter):
        # every 'include' symptom present, no 'exclude' symptom present
        if set(f['include']) <= symptom and symptom.isdisjoint(f['exclude']):
            splitted_metadata[c].append({**m, 'class_name': f['name'], 'class_label': c})
            break

for i, split in enumerate(splitted_metadata):
    if len(split) == 0:
        raise ValueError(f'Split group {i} ({diagnosis_filter[i]["name"]}) has no data.')
    print(f'- There are {len(split)} data belonging to {split[0]["class_name"]}')

#### Shuffle the divided data

In [None]:
# Train : validation : test = 8 : 1 : 1, normalised to fractions that sum to one.
ratios = np.array([8, 1, 1], dtype=float)
ratios /= ratios.sum()
print('Train, validation, test sets ratios:', ratios)

In [None]:
# Stratified split of the Task 2 class groups according to `ratios`.
# NOTE(review): seed=None draws fresh entropy, so every run produces a different benchmark
# split. Pin an integer seed if the split must be reproducible.
metadata_train, metadata_val, metadata_test = shuffle_splitted_metadata(
    splitted_metadata, class_label_to_name, ratios, seed=None, verbose=True)

#### Save the dataset as JSON file

In [None]:
# Assemble the Task 2 benchmark description; key order is preserved in the JSON file.
task_dict = {
    'task_name': 'CAUEEG-task2 benchmark',
    'task_description': 'Classification of [Normal], [MCI], and [Dementia] symptoms.',
    'class_label_to_name': class_label_to_name,
    'class_name_to_label': class_name_to_label,
    'train_split': metadata_train,
    'validation_split': metadata_val,
    'test_split': metadata_test,
}

# Abbreviated preview: long lists show their first three items, an ellipsis, and the last item.
print('{')
for key, value in task_dict.items():
    print(f'\t{key}:')
    if isinstance(value, list) and len(value) > 3:
        preview_lines = [*value[:3], '.', '.', '.', value[-1]]
    else:
        preview_lines = [value]
    for line in preview_lines:
        print(f'\t\t{line}')
    print()
print('}')

with open(os.path.join(data_path, 'task2.json'), 'w') as json_file:
    json.dump(task_dict, json_file, indent=4)
    print('task2.json file is saved.')

---

## Task 1: Classification of Normal and Abnormal Symptoms

#### Define the target diagnoses and split them by their symptoms

In [None]:
# Task 1 class definitions (list index = class label). 'Normal' requires 'normal' in the
# symptom list. 'Abnormal' has an empty include list, which is always satisfied, and only
# excludes 'normal', so it takes every other record the splitting cell considers.
diagnosis_filter = [
    {'name': 'Normal', 'include': ['normal'], 'exclude': []},
    {'name': 'Abnormal', 'include': [], 'exclude': ['normal']},
]

class_label_to_name = [flt['name'] for flt in diagnosis_filter]
print('class_label_to_name:', class_label_to_name)

class_name_to_label = {name: label for label, name in enumerate(class_label_to_name)}
print('class_name_to_label:', class_name_to_label)

In [None]:
# Split the filtered dataset: each record with a non-empty symptom list goes to the first
# matching Task 1 class. Records are copied before labelling, so the shared
# `annotation['data']` records (and any splits already holding them) are not relabelled.
splitted_metadata = [[] for _ in diagnosis_filter]

for m in annotation['data']:
    symptom = set(m['symptom'])

    # ignore data with the unknown (empty) label
    if not symptom:
        continue

    for c, f in enumerate(diagnosis_filter):
        # every 'include' symptom present, no 'exclude' symptom present
        if set(f['include']) <= symptom and symptom.isdisjoint(f['exclude']):
            splitted_metadata[c].append({**m, 'class_name': f['name'], 'class_label': c})
            break

for i, split in enumerate(splitted_metadata):
    if len(split) == 0:
        raise ValueError(f'Split group {i} ({diagnosis_filter[i]["name"]}) has no data.')
    print(f'- There are {len(split)} data belonging to {split[0]["class_name"]}')

#### Shuffle the divided data

In [None]:
# Train : validation : test = 8 : 1 : 1, normalised to fractions that sum to one.
ratios = np.array([8, 1, 1], dtype=float)
ratios /= ratios.sum()
print('Train, validation, test sets ratios:', ratios)

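#### Reuse the `Task 2` training split as part of the `Task 1` training split

Every `Task 2` training record is placed in the `Task 1` training split first. The remaining records then fill the train, validation, and test quotas. This keeps `Task 2` training data out of the `Task 1` validation and test splits.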

In [None]:
# Load the saved Task 2 benchmark and use its training split as the seed of the Task 1
# training split. Task 2 'Normal' training records seed Task 1 'Normal', and all other
# Task 2 training records seed Task 1 'Abnormal'.
with open(os.path.join(data_path, 'task2.json')) as json_file:
    task2_dict = json.load(json_file)

# Read the Normal label from the saved mapping instead of assuming it is 0.
task2_normal_label = task2_dict['class_name_to_label']['Normal']
task2_normals = [m for m in task2_dict['train_split'] if m['class_label'] == task2_normal_label]
task2_abnormals = [m for m in task2_dict['train_split'] if m['class_label'] != task2_normal_label]

print('Task2  -  Normal:', len(task2_normals), ' / Abnormal:', len(task2_abnormals))
print()

# `initial_train` is matched to `splitted_metadata` by position, so the Task 1 class order
# must be [Normal, Abnormal].
assert class_label_to_name == ['Normal', 'Abnormal'], class_label_to_name
initial_train = [task2_normals, task2_abnormals]

metadata_train, metadata_val, metadata_test = shuffle_splitted_metadata_with_initial_train(
    splitted_metadata, class_label_to_name, ratios, initial_train, seed=None, verbose=True)

#### Save the dataset as JSON file

In [None]:
# Assemble the Task 1 benchmark description; key order is preserved in the JSON file.
task_dict = {
    'task_name': 'CAUEEG-task1 benchmark',
    'task_description': 'Classification of [Normal] and [Abnormal] symptoms',
    'class_label_to_name': class_label_to_name,
    'class_name_to_label': class_name_to_label,
    'train_split': metadata_train,
    'validation_split': metadata_val,
    'test_split': metadata_test,
}

# Abbreviated preview: long lists show their first three items, an ellipsis, and the last item.
print('{')
for key, value in task_dict.items():
    print(f'\t{key}:')
    if isinstance(value, list) and len(value) > 3:
        preview_lines = [*value[:3], '.', '.', '.', value[-1]]
    else:
        preview_lines = [value]
    for line in preview_lines:
        print(f'\t\t{line}')
    print()
print('}')

with open(os.path.join(data_path, 'task1.json'), 'w') as json_file:
    json.dump(task_dict, json_file, indent=4)
    print('task1.json file is saved.')

In [None]:
# Reload both saved benchmarks from disk so the checks run on exactly what was written.
with open(os.path.join(data_path, 'task1.json')) as json_file:
    task1_dict = json.load(json_file)

with open(os.path.join(data_path, 'task2.json')) as json_file:
    task2_dict = json.load(json_file)

split_names = ['train_split', 'validation_split', 'test_split']

# sanity check 1: every Task 2 training record must also be a Task 1 training record
task1_train_serials = [m1['serial'] for m1 in task1_dict['train_split']]
task2_train_serials = [m2['serial'] for m2 in task2_dict['train_split']]
print(len(task1_train_serials), len(task2_train_serials))

missing_serials = set(task2_train_serials) - set(task1_train_serials)
assert not missing_serials, \
    f'{len(missing_serials)} Task 2 training serials are missing from the Task 1 training split'

# sanity check 1b: the Task 1 train/validation/test splits must not share any serial
task1_serial_sets = {name: {m['serial'] for m in task1_dict[name]} for name in split_names}
for idx, name_a in enumerate(split_names):
    for name_b in split_names[idx + 1:]:
        overlap = task1_serial_sets[name_a] & task1_serial_sets[name_b]
        assert not overlap, f'Task 1 {name_a} and {name_b} share {len(overlap)} serials'

# sanity check 2: per-split class counts and the union of symptoms seen for each class
for split in split_names:
    temp_dict = {'set': {}, 'counter': {}}
    for m1 in task1_dict[split]:
        cl = m1['class_label']
        temp_dict['counter'][cl] = temp_dict['counter'].get(cl, 0) + 1
        temp_dict['set'].setdefault(cl, set()).update(m1['symptom'])

    print(split, ':')
    pprint.pprint(temp_dict)
    print()
    print()