# setup directories 

In [1]:
import pandas as pd 
import numpy as np 
from sklearn.model_selection import train_test_split
import nilearn.image as niimage

import gc 
import os 
import shutil
pj = os.path.join
from collections import Counter
import pickle 

# reproducibility

In [2]:
np.random.seed(92)

# paths 

In [3]:
project_path = '/Users/ericchase/OneDrive/NORTHWESTERN/498_cap'
data_path = pj(project_path, 'data')
img_path = pj(data_path, 'images')
download_path = pj(img_path, 'download')
by_site_path = pj(img_path, 'by_site')
clinical_path = pj(data_path, 'clinical')

# load clinical

In [5]:
cdata = pd.read_csv(pj(clinical_path, 'Phenotypic_V1_0b_preprocessed1.csv'))
cdata.columns = cdata.columns.str.lower() 
# cdata.info(verbose=True)

In [6]:
sample_count = len(os.listdir(download_path))
sample_count

884

# fname ls

In [7]:
# create list of all image file names 
fname_ls = [s for s in os.listdir(download_path) if s.endswith('.nii')]
fname_ls[:5]

['OHSU_0050157_reho.nii',
 'OHSU_0050156_reho.nii',
 'UM_1_0050339_reho.nii',
 'UM_1_0050338_reho.nii',
 'UCLA_2_0051317_reho.nii']

# dx_ls

In [8]:
dx_ls = []
site_ls = []

for i,fname in enumerate(fname_ls):
    site_id = ''.join(fname.split('_')[:-2])
    sub_id = fname.split('_')[-2]
    dx_group = cdata[cdata.sub_id.eq(int(sub_id))].dx_group.values[0] - 1
    dx_ls.append(dx_group)
    site_ls.append(site_id)

# class represenation

In [9]:
Counter(dx_ls)

Counter({1: 476, 0: 408})

# populate by_site directories

In [11]:
!find ../images -name '.DS_Store' -delete 

In [12]:
def populate_directories(sites, fnames, labels):
    for site, fname, label in list(zip(sites, fnames, labels)):
        dst_path = pj(by_site_path, site, str(label))
        dst_fname = fname.split('.')[0] + '.pkl'
        if not os.path.exists(dst_path):
            os.makedirs(dst_path) 
        nii_img_obj = niimage.load_img(pj(download_path, fname))
        img_data = nii_img_obj.get_fdata() 
        img_data = np.expand_dims(img_data, axis=-1)    
        # pickle to local dir #
        pickle.dump(img_data, open(pj(dst_path, dst_fname), 'wb'))

In [31]:
populate_directories(site_ls, fname_ls, dx_ls)

In [13]:
class_counts = {'0':0, '1':0}

for site in os.listdir(by_site_path):
    for group in os.listdir(pj(by_site_path, site)):
        for file in os.listdir(pj(by_site_path, site, group)):
#             print(file)
            class_counts[group] += 1
            
class_counts

{'0': 408, '1': 476}