In [6]:
import pandas as pd
import numpy as np
import os
from plotly.offline import init_notebook_mode, iplot
import plotly.graph_objs as go
import scipy.stats as ss

# Enable inline rendering of plotly figures produced by iplot() in this notebook.
init_notebook_mode(connected=True)

## Data Loading

In [31]:
# Location of the UKBB healthy-subjects table. Set the UKBB_HEALTHY_CSV
# environment variable to point elsewhere instead of editing this
# machine-specific absolute default.
DATA_PATH = os.environ.get(
    "UKBB_HEALTHY_CSV",
    "/Users/Dstrip/PycharmProjects/ProjectMetis-FederatedNeuroImaging/projectmetis/resources/config/experiments_configs/brainage/ukbbdata/UKBB_healthy_10k.csv",
)
data = pd.read_csv(DATA_PATH)



In [32]:
# Display the loaded table; the rich display shows only the head/tail rows.
data

Unnamed: 0,eid,age_at_scan,9dof_2mm_vol
0,4894384,65.250000,/lfs1/stripeli/neuroimaging_data/ukbb/9DOF/489...
1,4969666,73.916667,/lfs1/stripeli/neuroimaging_data/ukbb/9DOF/496...
2,4615131,67.083333,/lfs1/stripeli/neuroimaging_data/ukbb/9DOF/461...
3,2992424,48.166667,/lfs1/stripeli/neuroimaging_data/ukbb/9DOF/299...
4,4889041,71.083333,/lfs1/stripeli/neuroimaging_data/ukbb/9DOF/488...
...,...,...,...
10441,3124371,60.833333,/lfs1/stripeli/neuroimaging_data/ukbb/9DOF/312...
10442,1772723,51.750000,/lfs1/stripeli/neuroimaging_data/ukbb/9DOF/177...
10443,1723950,61.500000,/lfs1/stripeli/neuroimaging_data/ukbb/9DOF/172...
10444,4696181,60.916667,/lfs1/stripeli/neuroimaging_data/ukbb/9DOF/469...


In [8]:
### PARAM ###
N_BINS = 20  # number of age bins passed to pd.cut in the binning cell

In [33]:
# Raw age distribution over the whole loaded table.
trace = go.Histogram(opacity=0.5, x=data["age_at_scan"])
d = [trace]
layout = {
    "title": "Age Dist",
    "xaxis": {"title": "Age"},
    "yaxis": {"title": "Freq"},
    "barmode": 'overlay',
}
fig = {"data": d, "layout": layout}
iplot(fig, filename='Age Dist')

## Age Binning

In [9]:
# Assign every subject to one of N_BINS age bins, labelled 1..N_BINS
# (right bin edge inclusive).
data["bin"] = pd.cut(
    data["age_at_scan"],
    N_BINS,
    right=True,
    labels=np.arange(1, N_BINS + 1),
)

In [11]:
# How many subjects fall into each age bin.
trace = go.Histogram(
    x=data["bin"],
    opacity=0.5,
)

d = [trace]

# The x values are bin labels 1..N_BINS (not ages), so label the axis accordingly.
layout = dict(title="Age Bin Dist", xaxis=dict(title="Age bin (1 = youngest)"), yaxis=dict(title="Freq"), barmode='overlay')

fig = dict(data=d, layout=layout)

iplot(fig, filename='Age Dist')

In [10]:
# Inspect the table again, now including the "bin" column added by pd.cut.
data

Unnamed: 0,eid,age_at_scan,9dof_2mm_vol,bin
0,4894384,65.250000,/lfs1/stripeli/neuroimaging_data/ukbb/9DOF/489...,12
1,4969666,73.916667,/lfs1/stripeli/neuroimaging_data/ukbb/9DOF/496...,17
2,4615131,67.083333,/lfs1/stripeli/neuroimaging_data/ukbb/9DOF/461...,13
3,2992424,48.166667,/lfs1/stripeli/neuroimaging_data/ukbb/9DOF/299...,2
4,4889041,71.083333,/lfs1/stripeli/neuroimaging_data/ukbb/9DOF/488...,15
...,...,...,...,...
10441,3124371,60.833333,/lfs1/stripeli/neuroimaging_data/ukbb/9DOF/312...,10
10442,1772723,51.750000,/lfs1/stripeli/neuroimaging_data/ukbb/9DOF/177...,4
10443,1723950,61.500000,/lfs1/stripeli/neuroimaging_data/ukbb/9DOF/172...,10
10444,4696181,60.916667,/lfs1/stripeli/neuroimaging_data/ukbb/9DOF/469...,10


## Data Partitioning Functions

In [None]:
def homog_sample(data, ratio=0.2, random_state=42):
    """Sample the same fraction of rows from every age bin.

    Parameters
    ----------
    data : pd.DataFrame
        Table with a ``"bin"`` column.
    ratio : float
        Fraction of each bin's rows drawn into the sampled sub-table.
    random_state : int
        Seed passed to ``DataFrame.sample`` for every bin.

    Returns
    -------
    (sub_data, resid_data) : tuple of pd.DataFrame
        ``sub_data`` holds the sampled rows, grouped bin by bin in the order
        the bins first appear in ``data``; ``resid_data`` holds every other
        row in its original order. Both carry a fresh 0..n-1 index. An empty
        ``data`` yields two empty tables with the same columns.
    """
    # Use a unique positional index so the residual is computed by row
    # identity rather than by value; value-based de-duplication would also
    # discard unsampled rows that happen to be identical to another row.
    indexed = data.reset_index(drop=True)

    per_bin_samples = [
        indexed[indexed["bin"] == b].sample(frac=ratio, random_state=random_state)
        for b in indexed["bin"].unique()
    ]
    # pd.concat fails on an empty list, so fall back to an empty frame
    # that keeps the original columns.
    sub_data = pd.concat(per_bin_samples) if per_bin_samples else indexed.iloc[0:0]

    # Residual = (data - sampled rows), preserving the original row order.
    resid_data = indexed.drop(index=sub_data.index)

    return sub_data.reset_index(drop=True), resid_data.reset_index(drop=True)

In [None]:
def homog_split(data, num_locs=2, random_state=42):
    """Split ``data`` into ``num_locs`` site tables with matching age-bin mixes.

    Each bin is shuffled independently and cut into ``num_locs`` consecutive
    chunks of ``len(bin) // num_locs`` rows; the final site additionally
    receives the ``len(bin) % num_locs`` leftover rows of every bin.

    Parameters
    ----------
    data : pd.DataFrame
        Table with a ``"bin"`` column.
    num_locs : int
        Number of sites (learners) to produce.
    random_state : int
        Seed used for the per-bin shuffle.

    Returns
    -------
    list of pd.DataFrame
        One table per site, each with a fresh 0..n-1 index. When ``data`` is
        empty every site gets an empty table with the same columns.
    """
    # Collect slices per site and concatenate once at the end, instead of the
    # quadratic repeated DataFrame.append.
    site_chunks = [[] for _ in range(num_locs)]

    for b in data["bin"].unique():
        shuffled = data[data["bin"] == b].sample(frac=1, random_state=random_state)
        n_samples = len(shuffled) // num_locs  # rows from this bin per site

        for i, chunks in enumerate(site_chunks):
            start = i * n_samples
            # Extra samples of this bin are placed in the final site.
            stop = len(shuffled) if i == num_locs - 1 else start + n_samples
            chunks.append(shuffled.iloc[start:stop])

    return [
        (pd.concat(chunks) if chunks else data.iloc[0:0]).reset_index(drop=True)
        for chunks in site_chunks
    ]

In [None]:
def iid_split(data, num_locs=2, random_state=42):
    """Split ``data`` into ``num_locs`` random site tables, ignoring age bins.

    Rows are shuffled once and cut into ``num_locs`` consecutive chunks of
    ``len(data) // num_locs`` rows; the final site also receives the
    ``len(data) % num_locs`` leftover rows.

    Parameters
    ----------
    data : pd.DataFrame
        Table to split.
    num_locs : int
        Number of sites (learners) to produce.
    random_state : int
        Seed used for the shuffle.

    Returns
    -------
    list of pd.DataFrame
        One table per site, each with a fresh 0..n-1 index.
    """
    shuffled = data.sample(frac=1, random_state=random_state)
    n_samples = len(shuffled) // num_locs

    sites_data = []
    for i in range(num_locs):
        start = i * n_samples
        # Extra samples are placed in the final site.
        stop = len(shuffled) if i == num_locs - 1 else start + n_samples
        # reset_index returns a new frame rather than mutating a slice in place.
        sites_data.append(shuffled.iloc[start:stop].reset_index(drop=True))

    return sites_data

## Create Test and Training (Centralized) Dataset

In [16]:
### PARAM ###
TEST_RATIO = 0.2   # fraction (not percent) of rows held out from **each age bin**
RANDOM_SEED = 42   # seed for the per-bin test sampling

In [17]:
# Hold out TEST_RATIO of each age bin as the test set; the remaining rows
# form the centralized training set.
test_data, training_data = homog_sample(data, ratio=TEST_RATIO, random_state=RANDOM_SEED)

# Sanity check: the two partitions together must account for every row exactly once.
assert len(test_data) + len(training_data) == len(data), "test/train split lost or duplicated rows"

In [18]:
# Age distribution of the held-out test set (should mirror the full-data plot).
trace = go.Histogram(opacity=0.5, x=test_data["age_at_scan"])
d = [trace]
layout = {
    "title": "Age Dist",
    "xaxis": {"title": "Age"},
    "yaxis": {"title": "Freq"},
    "barmode": 'overlay',
}
fig = {"data": d, "layout": layout}
iplot(fig, filename='Age Dist')

In [382]:
# Persist the held-out test set without writing the DataFrame index.
test_data.to_csv(
    "./test_data.csv",
    index=False,
)

## Create per-site Train/Validation Datasets

In [19]:
### PARAM ###
NUM_LOCS = 8               # number of locations (learners)
VALID_RATIO = 0.05         # fraction of each age bin, per site, used for validation
RANDOM_SEED = 42
SAVE_DIR = "./split_gen/"  # output directory for the per-site CSV files

In [21]:
# Split the training set into NUM_LOCS sites that share the same age-bin mix.
sites_data = homog_split(
    training_data,
    num_locs=NUM_LOCS,
    random_state=RANDOM_SEED,
)

In [23]:
# Overlay the per-site age-bin histograms to compare site distributions.
# Sites are named 1..NUM_LOCS to match the train_<k>/valid_<k> file names.
plots = [
    go.Histogram(x=site["bin"], name="Site {}".format(site_idx), opacity=0.5)
    for site_idx, site in enumerate(sites_data, start=1)
]

# x values are bin labels, not ages.
layout = dict(title="Age Dist", xaxis=dict(title="Age bin"), yaxis=dict(title="Freq"), barmode="overlay")
fig = dict(data=plots, layout=layout)
iplot(fig, filename="Age Dist")

In [24]:
# Create train/valid sets per site and save them as CSV files. Validation rows
# are VALID_RATIO of each age bin within each site's data.
os.makedirs(SAVE_DIR, exist_ok=True)

for site_idx, site_data in enumerate(sites_data, start=1):
    site_valid_data, site_train_data = homog_sample(site_data, ratio=VALID_RATIO, random_state=RANDOM_SEED)

    # Write the splits computed for *this* site; index=False matches test_data.csv.
    site_train_data.to_csv(os.path.join(SAVE_DIR, "train_{}.csv".format(site_idx)), index=False)
    site_valid_data.to_csv(os.path.join(SAVE_DIR, "valid_{}.csv".format(site_idx)), index=False)

In [25]:
### PARAM ###
NUM_LOCS = 4            # number of locations (learners/sites)
VALID_RATIO = 0.05      # fraction of each age bin, per site, used for validation
RANDOM_SEED = 42
SAVE_DIR = "./iid_4/"   # output directory for the per-site CSV files

In [27]:
# Split the training set into NUM_LOCS sites by a single random shuffle.
sites_data = iid_split(
    training_data,
    num_locs=NUM_LOCS,
    random_state=RANDOM_SEED,
)

In [28]:
# Overlay the per-site age-bin histograms to compare site distributions.
# Sites are named 1..NUM_LOCS to match the train_<k>/valid_<k> file names.
plots = [
    go.Histogram(x=site["bin"], name="Site {}".format(site_idx), opacity=0.5)
    for site_idx, site in enumerate(sites_data, start=1)
]

# x values are bin labels, not ages.
layout = dict(title="Age Dist", xaxis=dict(title="Age bin"), yaxis=dict(title="Freq"), barmode="overlay")
fig = dict(data=plots, layout=layout)
iplot(fig, filename="Age Dist")

In [402]:
# Create train/valid sets per site and save them as CSV files. Validation rows
# are VALID_RATIO of each age bin within each site's data.
os.makedirs(SAVE_DIR, exist_ok=True)

for site_idx, site_data in enumerate(sites_data, start=1):
    site_valid_data, site_train_data = homog_sample(site_data, ratio=VALID_RATIO, random_state=RANDOM_SEED)

    # index=False matches test_data.csv; os.path.join avoids "dir//file" paths.
    site_valid_data.to_csv(os.path.join(SAVE_DIR, "valid_{}.csv".format(site_idx)), index=False)
    site_train_data.to_csv(os.path.join(SAVE_DIR, "train_{}.csv".format(site_idx)), index=False)