In [1]:
# Enable IPython's autoreload extension in mode 2, so imported modules are
# re-imported before each cell executes and edits to the project code are
# picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [2]:
import os

import numpy as np
import pandas as pd
import pygal

from abcd.local.paths import plots_path
from abcd.data.read_data import get_subjects_events
from abcd.data.define_splits import SITES, save_restore_sex_fmri_splits
from abcd.plotting.pygal.colors import CAT_COLORS
from abcd.plotting.pygal.rendering import display_html

In [3]:
# Number of splits; the plotting cells iterate over range(k) to address them.
k = 5
# Load the subjects and events tables. Later cells read the columns
# 'src_subject_id', 'eventname', 'interview_date' and 'interview_age' from
# events_df (presumably one row per event -- TODO confirm against the loader).
subjects_df, events_df = get_subjects_events()
# Per-site splits, indexed as splits[site][str(split_ix)]; each entry is a
# sized collection whose items are matched against 'src_subject_id'.
# NOTE(review): the helper name suggests the splits are cached on disk and
# restored on later calls -- confirm in abcd.data.define_splits.
splits = save_restore_sex_fmri_splits(k=k)

In [4]:
# Plot number of subjects per split: one bar series per split, one x position per site.

# One colour per split. The palette is sized with `k` (not a literal 5) so it
# always matches the number of series added below if `k` is changed.
custom_style = pygal.style.Style(
    colors=tuple(CAT_COLORS['split'][str(split_ix)] for split_ix in range(k)),
    # background='transparent',  # optional: uncomment for a transparent background
)

bar_chart = pygal.Bar(x_label_rotation=45, style=custom_style)
bar_chart.title = '# subjects per split'
# Turn labels such as "siteXX" into "site XX" for readability.
bar_chart.x_labels = [x.replace("site", "site ") for x in SITES]
for split_ix in range(k):
    # Split keys are strings, hence str(split_ix).
    subjects_per_site = [len(splits[site][str(split_ix)]) for site in SITES]
    bar_chart.add(str(split_ix), subjects_per_site)
display_html(bar_chart)

In [5]:
# For each split, plot the mean number of events per subject (per site).

# Count the rows of events_df per subject ONCE. The previous version filtered
# the whole events table with a boolean mask for every single subject, which
# scales as (#subjects x #events).
events_per_subject = events_df['src_subject_id'].value_counts()

# `range` fixes the y-axis to 1..3 so small differences between splits are visible.
bar_chart = pygal.Bar(x_label_rotation=45, style=custom_style, range=(1, 3))
bar_chart.title = 'Mean # events per subject'
bar_chart.x_labels = [x.replace("site", "site ") for x in SITES]
for split_ix in range(k):
    per_site_values = []
    for site_id in SITES:
        subject_ids = list(splits[site_id][str(split_ix)])
        # Subjects with no rows in events_df count as 0 events, exactly as
        # len() of an empty selection did before.
        per_subject_event_count = events_per_subject.reindex(subject_ids, fill_value=0)
        per_site_values.append(per_subject_event_count.mean())
    bar_chart.add(str(split_ix), per_site_values)
display_html(bar_chart)

In [8]:
# Parse the 'interview_date' strings (month/day/year) into pandas datetimes,
# overwriting the column in events_df.
INTERVIEW_DATE_FORMAT = '%m/%d/%Y'
events_df['interview_date'] = pd.to_datetime(
    events_df['interview_date'],
    format=INTERVIEW_DATE_FORMAT,
)

In [11]:
# Scatter of interview date (x) against subject age (y), one series per event type (hue).
y = "interview_age"
hue = "eventname"

# Raw event names mapped to short legend labels; the dict order is also the
# order in which the series are added to the chart.
better_names = {
    "baseline_year_1_arm_1": "baseline",
    "2_year_follow_up_y_arm_1": "2 year",
    "4_year_follow_up_y_arm_1": "4 year",
}
eventnames = list(better_names)

# stroke=False draws unconnected points, i.e. a scatter plot on a date axis.
plot = pygal.DateTimeLine(
    x_label_rotation=35,
    truncate_label=-1,
    x_value_formatter=lambda dt: dt.strftime('%d, %b %Y'),
    stroke=False,
)  # style=custom_style
plot.y_title = "Age (in months)"
plot.title = "Date, type and subject age for all events"

# Rows of events_df for each event type, keyed by the raw event name.
dfs = {event_name: events_df.loc[events_df[hue] == event_name] for event_name in eventnames}
for event_name, event_df in dfs.items():
    # Pair each interview date with the age recorded in the same row.
    points = list(zip(event_df['interview_date'].tolist(), event_df[y].tolist()))
    plot.add(better_names[event_name], points)
display_html(plot)
