In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import os
import pandas as pd
from collections import OrderedDict
import torch
torch.manual_seed(0)
from torch import nn
from torch.utils.data import DataLoader
from torchvision import datasets
from torchvision.transforms import ToTensor
from tqdm import tqdm
from abcd.local.paths import output_path
from abcd.data.read_data import get_subjects_events_sf, subject_cols_to_events, add_event_vars
import abcd.data.VARS as VARS
from abcd.data.define_splits import SITES, save_restore_sex_fmri_splits
from abcd.data.divide_with_splits import divide_events_by_splits
from abcd.data.var_tailoring.normalization import normalize_var
from abcd.data.pytorch.get_dataset import PandasDataset
from abcd.training.ClassifierTrainer import ClassifierTrainer
from abcd.local.paths import core_path, output_path
import abcd.data.VARS as VARS
from abcd.exp.Experiment import Experiment
import abcd.utils.io as io
import importlib
import sys
from matplotlib import pyplot as plt
sys.stdout = sys.__stdout__

In [3]:
# Experiment configuration: prediction target, feature groups to include,
# model location ([module path, class name]) and optimisation settings.
config = {
    'target_col': 'nihtbx_fluidcomp_fc',
    'features': ['fmri', 'smri'],
    'model': ['abcd.models.classification.FullyConnected', 'FullyConnected3'],
    'lr': 1e-3,
    'batch_size': 64,
    'nr_epochs': 150,
}

# TODO: make new experiment when all data works
exp = Experiment(
    name='nc_classification_experiment',
    config=config,
)

In [4]:
# Load the per-subject table and the per-visit (events) table, then report
# how many rows each one has.
subjects_df, events_df = get_subjects_events_sf()
summary_template = "There are {} subjects and {} visits with imaging"
print(summary_template.format(len(subjects_df), len(events_df)))

In [5]:
# Preview the per-visit table (rendered by the notebook as the cell output).
events_df

Unnamed: 0,src_subject_id,interview_date,eventname,interview_age,rsfmri_c_ngd_ad_ngd_ad,rsfmri_c_ngd_ad_ngd_cgc,rsfmri_c_ngd_ad_ngd_ca,rsfmri_c_ngd_ad_ngd_dt,rsfmri_c_ngd_ad_ngd_dla,rsfmri_c_ngd_ad_ngd_fo,...,smri_vol_cdk_suplrh,smri_vol_cdk_sutmrh,smri_vol_cdk_smrh,smri_vol_cdk_frpolerh,smri_vol_cdk_tmpolerh,smri_vol_cdk_trvtmrh,smri_vol_cdk_insularh,smri_vol_cdk_meanlh,smri_vol_cdk_meanrh,smri_vol_cdk_mean
0,NDAR_INV003RTV85,10/01/2018,baseline_year_1_arm_1,131.0,0.471330,0.256267,-0.076960,-0.116451,0.022202,-0.036302,...,14700.0,15562.0,12023.0,1281.0,2371.0,1318.0,7404.0,280168.0,279541.0,559709.0
1,NDAR_INV005V6D2C,04/22/2018,baseline_year_1_arm_1,121.0,0.279435,0.116256,0.063664,-0.024781,-0.000840,-0.023421,...,14581.0,11454.0,11180.0,1400.0,2202.0,1072.0,6572.0,256557.0,258378.0,514935.0
2,NDAR_INV00CY2MDM,08/22/2017,baseline_year_1_arm_1,130.0,0.395300,0.180940,-0.142808,-0.073424,-0.049414,-0.144481,...,16278.0,14056.0,12663.0,1486.0,2850.0,1217.0,7373.0,274090.0,278845.0,552935.0
3,NDAR_INV00CY2MDM,06/15/2019,2_year_follow_up_y_arm_1,152.0,0.343300,0.192000,-0.083796,-0.082260,0.007287,-0.082282,...,16495.0,14395.0,12959.0,1398.0,2820.0,1218.0,7647.0,277827.0,280937.0,558764.0
4,NDAR_INV00CY2MDM,11/21/2021,4_year_follow_up_y_arm_1,181.0,0.491863,0.231676,-0.133458,-0.175459,-0.017222,-0.189228,...,15917.0,14025.0,12351.0,1405.0,2707.0,1202.0,7451.0,270956.0,274105.0,545061.0
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
19322,NDAR_INVZZZ2ALR6,06/29/2019,2_year_follow_up_y_arm_1,145.0,0.332423,0.207600,0.050303,-0.051760,-0.061305,-0.086048,...,15256.0,15773.0,14387.0,1613.0,2624.0,1244.0,6635.0,311508.0,312506.0,624014.0
19323,NDAR_INVZZZNB0XC,01/03/2017,baseline_year_1_arm_1,108.0,0.270290,0.147113,-0.111270,-0.093982,-0.005676,-0.083171,...,16489.0,15759.0,12061.0,1358.0,2809.0,1080.0,7303.0,290799.0,292564.0,583363.0
19324,NDAR_INVZZZNB0XC,02/01/2021,4_year_follow_up_y_arm_1,157.0,0.363400,0.090317,-0.231313,-0.150431,-0.010488,-0.140455,...,15751.0,15053.0,11213.0,1208.0,2782.0,1027.0,6965.0,280478.0,282328.0,562806.0
19325,NDAR_INVZZZP87KR,08/02/2019,2_year_follow_up_y_arm_1,150.0,0.411610,0.180035,-0.008353,-0.053541,-0.019124,-0.126567,...,13723.0,16167.0,11140.0,1641.0,2536.0,1364.0,7166.0,283042.0,283994.0,567036.0


In [6]:
# Preview the per-subject table (rendered by the notebook as the cell output).
subjects_df

Unnamed: 0,src_subject_id,site_id_l,rel_family_id,kbi_sex_assigned_at_birth,race_ethnicity
0,NDAR_INV003RTV85,site06,8781.0,2.0,1.0
1,NDAR_INV005V6D2C,site10,10210.0,2.0,3.0
2,NDAR_INV00CY2MDM,site20,5355.0,1.0,1.0
3,NDAR_INV00HEV6HB,site12,2257.0,1.0,2.0
4,NDAR_INV00J52GPG,site17,4151.0,1.0,5.0
...,...,...,...,...,...
9744,NDAR_INVZZNX6W2P,site14,3797.0,1.0,1.0
9745,NDAR_INVZZPKBDAC,site12,2445.0,2.0,1.0
9746,NDAR_INVZZZ2ALR6,site08,7032.0,2.0,5.0
9747,NDAR_INVZZZNB0XC,site03,6676.0,2.0,3.0


In [7]:
# Assemble the modelling table: make sure events_df carries the target, the
# scanner description and the family/environment covariates, then drop every
# row that still has a missing value.
#
# Every merge is keyed on (subject, event) and only requests columns that are
# not in events_df yet, so re-running this cell does not create pandas-suffixed
# duplicates such as `fes_y_ss_fc_x` / `fes_y_ss_fc_y`.
# pd.merge defaults to an inner join: visits without a match are dropped.
key_cols = ["src_subject_id", "eventname"]

# NOTE(review): these two files are addressed by absolute paths on one machine,
# whereas the other tables below are resolved through core_path. Consider
# relocating them so the notebook runs elsewhere.
nihtb_csv_path = "/Users/brentju/Desktop/abcd/nc_y_nihtb.csv"
mri_info_csv_path = "/Users/brentju/Desktop/abcd/mri_y_adm_info.csv"

# Add the target to the events df, if not there
target_col = config['target_col']
if target_col not in events_df.columns:
    events_df = add_event_vars(events_df, nihtb_csv_path, [target_col])

# Add scanner data to table if not exists. Only the missing columns are loaded,
# so a scanner column that is already present is never merged a second time.
scanner_cols = ['mri_info_manufacturer', 'mri_info_manufacturersmn']
missing_scanner_cols = [col for col in scanner_cols if col not in events_df.columns]
if missing_scanner_cols:
    mri_df = io.load_df(mri_info_csv_path, sep=',', cols=key_cols + missing_scanner_cols)
    events_df = pd.merge(events_df, mri_df, on=key_cols)

# Add data from family, environment
family_income_col = ['demo_comb_income_v2']
# Conflict Subscale from the Family Environment Scale Sum of Youth Report
environment_cols = ['fes_y_ss_fc']

missing_environment_cols = [col for col in environment_cols if col not in events_df.columns]
if missing_environment_cols:
    env_df = io.load_df(os.path.join(core_path, 'culture-environment', 'ce_y_fes.csv'),
                        sep=',', cols=key_cols + missing_environment_cols)
    events_df = pd.merge(events_df, env_df, on=key_cols)

missing_income_cols = [col for col in family_income_col if col not in events_df.columns]
if missing_income_cols:
    income_df = io.load_df(os.path.join(core_path, 'abcd-general', 'abcd_p_demo.csv'),
                           sep=',', cols=key_cols + missing_income_cols)
    events_df = pd.merge(events_df, income_df, on=key_cols)

# NOTE(review): the rendered output of this cell shows 999.0 in
# demo_comb_income_v2, which looks like a questionnaire sentinel rather than an
# income bracket. Confirm whether such codes should be mapped to NaN before
# dropna() and before the column is normalized as a feature.
events_df = events_df.dropna()
events_df

Unnamed: 0,src_subject_id,interview_date,eventname,interview_age,rsfmri_c_ngd_ad_ngd_ad,rsfmri_c_ngd_ad_ngd_cgc,rsfmri_c_ngd_ad_ngd_ca,rsfmri_c_ngd_ad_ngd_dt,rsfmri_c_ngd_ad_ngd_dla,rsfmri_c_ngd_ad_ngd_fo,...,smri_vol_cdk_trvtmrh,smri_vol_cdk_insularh,smri_vol_cdk_meanlh,smri_vol_cdk_meanrh,smri_vol_cdk_mean,nihtbx_fluidcomp_fc,mri_info_manufacturer,mri_info_manufacturersmn,fes_y_ss_fc,demo_comb_income_v2
0,NDAR_INV003RTV85,10/01/2018,baseline_year_1_arm_1,131.0,0.471330,0.256267,-0.076960,-0.116451,0.022202,-0.036302,...,1318.0,7404.0,280168.0,279541.0,559709.0,47.0,SIEMENS,Prisma_fit,3.0,8.0
1,NDAR_INV005V6D2C,04/22/2018,baseline_year_1_arm_1,121.0,0.279435,0.116256,0.063664,-0.024781,-0.000840,-0.023421,...,1072.0,6572.0,256557.0,258378.0,514935.0,48.0,GE MEDICAL SYSTEMS,DISCOVERY MR750,0.0,999.0
2,NDAR_INV00CY2MDM,08/22/2017,baseline_year_1_arm_1,130.0,0.395300,0.180940,-0.142808,-0.073424,-0.049414,-0.144481,...,1217.0,7373.0,274090.0,278845.0,552935.0,37.0,SIEMENS,Prisma,1.0,6.0
5,NDAR_INV00HEV6HB,07/08/2017,baseline_year_1_arm_1,124.0,0.219848,0.185615,-0.079140,-0.041995,0.058579,0.027804,...,1040.0,6783.0,263307.0,266505.0,529812.0,47.0,SIEMENS,Prisma_fit,1.0,999.0
8,NDAR_INV00J52GPG,09/05/2018,baseline_year_1_arm_1,110.0,0.284362,0.172945,-0.159164,-0.078633,-0.072509,-0.071459,...,1191.0,9093.0,325216.0,325876.0,651092.0,41.0,Philips Medical Systems,Achieva dStream,0.0,6.0
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
19089,NDAR_INVZZFG6J5U,04/21/2018,baseline_year_1_arm_1,129.0,0.263748,0.162639,-0.172385,-0.064067,0.019411,-0.075842,...,959.0,7613.0,302993.0,299070.0,602063.0,33.0,GE MEDICAL SYSTEMS,DISCOVERY MR750,3.0,7.0
19090,NDAR_INVZZJ3A7BK,10/09/2017,baseline_year_1_arm_1,122.0,0.567350,0.233855,-0.275115,-0.113414,-0.033699,-0.121024,...,1114.0,7471.0,294806.0,295557.0,590363.0,66.0,SIEMENS,Prisma_fit,3.0,3.0
19092,NDAR_INVZZLZCKAY,08/26/2017,baseline_year_1_arm_1,110.0,0.315367,0.176642,-0.010926,-0.019208,-0.030765,-0.062264,...,1324.0,7711.0,284024.0,280617.0,564641.0,35.0,SIEMENS,Prisma_fit,2.0,9.0
19097,NDAR_INVZZPKBDAC,01/20/2018,baseline_year_1_arm_1,113.0,0.404070,0.300714,-0.167735,-0.093816,-0.044867,0.019305,...,1158.0,6582.0,284249.0,285551.0,569800.0,40.0,SIEMENS,Prisma_fit,0.0,10.0


In [8]:
def discretize_column_inplace(events_df, column_name, q=None):
    """Replace a numeric column by the index of the quantile bin each value falls in.

    The bin edges are estimated from the column itself with ``pd.qcut`` and the
    values are then mapped to integer bin indices with ``pd.cut``. Edges that
    coincide are dropped, so fewer bins than requested may result.

    Args:
        events_df: DataFrame holding the column. It is modified in place.
        column_name: Name of the numeric column to discretize.
        q: Quantiles forwarded to ``pd.qcut``: either a number of equal-sized
            quantile bins or a list of quantile edges. Defaults to quartiles,
            ``[0, 0.25, 0.5, 0.75, 1]``.

    Returns:
        The same ``events_df`` object, with ``column_name`` overwritten by the
        bin indices (0 for the lowest bin).
    """
    if q is None:
        q = [0, 0.25, 0.5, 0.75, 1]
    # Only the bin edges are needed from qcut; the codes are recomputed by cut.
    _, bins = pd.qcut(events_df[column_name], q=q, labels=False, retbins=True, duplicates='drop')
    # include_lowest keeps the column minimum inside the first bin.
    events_df[column_name] = pd.cut(events_df[column_name], bins=bins, labels=False, include_lowest=True)
    return events_df

# Bin the target column; the resulting bin indices are used as class labels.
events_df = discretize_column_inplace(events_df, target_col)
# sorted() pins the label order: iterating a bare set has no guaranteed order,
# and this list is later handed to the model and the trainer.
labels = sorted(set(events_df[target_col]))
labels

[0, 1, 2, 3]

In [9]:
# Split the visits by scanner manufacturer and report the size of each group.
# NOTE(review): boolean-mask selection returns copies, and this cell runs before
# the "Normalize features" cell, so these three tables keep the un-normalized
# feature values. Confirm that is intended wherever they are consumed.
manufacturer = events_df['mri_info_manufacturer']
siemens_table = events_df[manufacturer == "SIEMENS"]
ge_table = events_df[manufacturer == "GE MEDICAL SYSTEMS"]
philips_table = events_df[manufacturer == "Philips Medical Systems"]
msg = f'{len(siemens_table)} entries for Siemens, {len(ge_table)} entries for GE Medical, and {len(philips_table)} for Philips'
# The summary was previously built but never shown.
print(msg)

In [10]:
# Define features
# fMRI feature names are the keys of VARS.NAMED_CONNECTIONS.
features_fmri = [connection for connection in VARS.NAMED_CONNECTIONS.keys()]
# sMRI feature names are "<variable>_<parcel>", over each variable's own
# parcels followed by the shared mean parcels.
features_smri = [
    '_'.join((var_name, parcel))
    for var_name in VARS.DESIKAN_STRUCT_FEATURES.keys()
    for parcel in VARS.DESIKAN_PARCELS[var_name] + VARS.DESIKAN_MEANS
]

# The two covariates are always included; imaging groups follow the config.
feature_cols = ['demo_comb_income_v2', 'fes_y_ss_fc']
for modality, modality_features in (('fmri', features_fmri), ('smri', features_smri)):
    if modality in config['features']:
        feature_cols.extend(modality_features)

In [11]:
# Normalize features
# Each feature column is handed to normalize_var in turn; the column name is
# passed for both of its name arguments and events_df is rebound to the result.
for feature_name in feature_cols:
    events_df = normalize_var(events_df, feature_name, feature_name)

In [12]:
# Per-manufacturer train/test split (first 70% of rows train, the rest test).
#
# Two defects of the previous version are fixed here:
#   * the Philips splits were sliced out of ge_table instead of the Philips rows;
#   * the manufacturer tables were selected before the "Normalize features"
#     cell, so the splits carried un-normalized features. The rows are now
#     re-selected from the current (normalized) events_df.
train_fraction = 0.7


def _split_train_test(df, manufacturer, fraction):
    """Return (train, test) for one manufacturer: leading `fraction` of rows vs. the rest."""
    rows = df[df['mri_info_manufacturer'] == manufacturer]
    cut = int(fraction * len(rows))
    return rows.iloc[:cut], rows.iloc[cut:]


siemens_train, siemens_test = _split_train_test(events_df, "SIEMENS", train_fraction)
ge_train, ge_test = _split_train_test(events_df, "GE MEDICAL SYSTEMS", train_fraction)
philips_train, philips_test = _split_train_test(events_df, "Philips Medical Systems", train_fraction)

In [13]:
# Define PyTorch datasets: one train/test pair per scanner manufacturer.
# Each test dataset is built from its own manufacturer's test split (the GE and
# Philips entries previously wrapped the Siemens test split).
s_datasets = OrderedDict([
    ('SiemensTrain', PandasDataset(siemens_train, feature_cols, target_col)),
    ('SiemensTest', PandasDataset(siemens_test, feature_cols, target_col)),
])

g_datasets = OrderedDict([
    ('GETrain', PandasDataset(ge_train, feature_cols, target_col)),
    ('GETest', PandasDataset(ge_test, feature_cols, target_col)),
])

p_datasets = OrderedDict([
    ('PhilipsTrain', PandasDataset(philips_train, feature_cols, target_col)),
    ('PhilipsTest', PandasDataset(philips_test, feature_cols, target_col)),
])

In [14]:
# Create dataloaders
batch_size = config['batch_size']


def _as_dataloaders(named_datasets):
    """Wrap each dataset of a name -> dataset mapping in a shuffling DataLoader, keeping the order."""
    loaders = OrderedDict()
    for dataset_name, dataset in named_datasets.items():
        loaders[dataset_name] = DataLoader(dataset, batch_size=batch_size, shuffle=True)
    return loaders


s_dataloaders = _as_dataloaders(s_datasets)
g_dataloaders = _as_dataloaders(g_datasets)
p_dataloaders = _as_dataloaders(p_datasets)

# Collect only the training loaders, in Siemens -> GE -> Philips order.
train_dataloaders = OrderedDict()
train_dataloaders['SiemensTrain'] = s_dataloaders['SiemensTrain']
train_dataloaders['GETrain'] = g_dataloaders['GETrain']
train_dataloaders['PhilipsTrain'] = p_dataloaders['PhilipsTrain']

train_dataloaders

OrderedDict([('SiemensTrain',
              <torch.utils.data.dataloader.DataLoader at 0x7fa7b1eda070>),
             ('GETrain',
              <torch.utils.data.dataloader.DataLoader at 0x7fa7b1eda0d0>),
             ('PhilipsTrain',
              <torch.utils.data.dataloader.DataLoader at 0x7fa7b5d6cc40>)])

In [15]:
# Compute device used for the run; the model is moved to it and the trainer
# receives it as an argument in the cells below.
device = "cpu"
device

'cpu'

In [16]:
# Define model
# config['model'] is [module path, class name]: the module is imported
# dynamically and the class looked up on it by name.
models_path = os.path.join(exp.path, 'models')
module = importlib.import_module(config['model'][0])
model_class = getattr(module, config['model'][1])
model = model_class(
    save_path=models_path,
    labels=labels,
    input_size=len(feature_cols),
).to(device)
labels

[0, 1, 2, 3]

In [17]:
# Define optimizer and trainer
# Cross-entropy loss over the class labels, optimised with Adam at the
# configured learning rate; trainer artefacts live under the experiment path.
trainer_path = os.path.join(exp.path, 'trainer')
learning_rate = config['lr']
loss_f = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
trainer = ClassifierTrainer(
    trainer_path,
    device,
    optimizer,
    loss_f,
    labels=labels,
)

# Methodology:
* Train the model on the Siemens training set, then evaluate it on the test sets of all 3 scanner manufacturers (Siemens, GE and Philips).
* Continue training the same model on the GE and Philips training sets, again evaluating on all 3 test sets.

In [18]:
# Test datasets, one per scanner manufacturer. Each is built from its own
# manufacturer's test split (all three previously wrapped the Siemens split,
# so the "GETest" and "PhilipsTest" metrics were Siemens metrics).
test_datasets = OrderedDict([
    ('SiemensTest', PandasDataset(siemens_test, feature_cols, target_col)),
    ('GETest', PandasDataset(ge_test, feature_cols, target_col)),
    ('PhilipsTest', PandasDataset(philips_test, feature_cols, target_col)),
])
test_dataloaders = OrderedDict([
    (dataset_name, DataLoader(dataset, batch_size=batch_size, shuffle=True))
    for dataset_name, dataset in test_datasets.items()
])

# Training loaders first, then test loaders, in one ordered mapping.
all_dataloaders = OrderedDict(list(train_dataloaders.items()) + list(test_dataloaders.items()))
all_dataloaders

OrderedDict([('SiemensTrain',
              <torch.utils.data.dataloader.DataLoader at 0x7fa7b1eda070>),
             ('GETrain',
              <torch.utils.data.dataloader.DataLoader at 0x7fa7b1eda0d0>),
             ('PhilipsTrain',
              <torch.utils.data.dataloader.DataLoader at 0x7fa7b5d6cc40>),
             ('SiemensTest',
              <torch.utils.data.dataloader.DataLoader at 0x7fa7b5d6cdc0>),
             ('GETest',
              <torch.utils.data.dataloader.DataLoader at 0x7fa7b5d6c070>),
             ('PhilipsTest',
              <torch.utils.data.dataloader.DataLoader at 0x7fa7b1edaf40>)])

In [20]:
nr_epochs = config['nr_epochs']

# Reporting cadence is derived from the run length: loss is printed every
# tenth of the run, evaluation runs every third, exports every fifth.
trainer.train(
    model,
    train_dataloaders,
    test_dataloaders,
    nr_epochs=nr_epochs,
    starting_from_epoch=0,
    print_loss_every=nr_epochs // 10,
    eval_every=nr_epochs // 3,
    export_every=nr_epochs // 5,
    verbose=True,
)

100%|█████████████████████████████████████████| 150/150 [00:37<00:00,  3.99it/s]
100%|█████████████████████████████████████████| 150/150 [00:15<00:00,  9.56it/s]
100%|█████████████████████████████████████████| 150/150 [00:08<00:00, 18.06it/s]
