In [None]:
from sys import path
path.insert(0, '..')
import matplotlib.pyplot as plt
import os
import itertools
import numpy as np
import pandas as pd
import trainer.data as D
import trainer.plotting as tp

In [None]:
# Patient metadata table; patient ids are lower-cased so they have one canonical form.
meta = pd.read_csv('../datasets/meta/patient_info.csv')
meta['pid'] = meta['pid'].map(lambda pid: pid.lower())

In [None]:
def ssum(X):
    """Add together a non-empty sequence of items with ``+`` (e.g. flatten a list of lists).

    Unlike the builtin ``sum`` this needs no numeric start value, so it works for lists.

    Parameters
    ----------
    X : sequence
        Non-empty, sliceable sequence of mutually ``+``-compatible items.

    Returns
    -------
    The left fold ``X[0] + X[1] + ...``. Neither ``X`` nor its elements are modified.

    Raises
    ------
    IndexError
        If ``X`` is empty.
    """
    total = X[0]
    for item in X[1:]:
        # `total = total + item` (not `+=`) builds a new object, so when the items
        # are lists X[0] is never extended in place.
        total = total + item
    return total

# All distinct bin edges, ascending: every 'lo-hi' label in D.age_group_bins is
# split on '-' and each piece converted to int.
age_group_boundaries = np.sort(list({
    int(edge)
    for bin_range in D.age_group_bins.values()
    for edge in bin_range.split('-')
}))

# Histogram of patient ages with every age-group edge drawn as a dashed line.
plt.hist(meta.age, bins=15)
for edge in age_group_boundaries:
    plt.axvline(x=edge, color='k', ls='--')

# Label each bin once, just right of its lower edge (bins numbered in ascending
# order). Labelling every edge instead would put a label past the upper edge
# of the last bin.
lower_edges = sorted(int(bin_range.split('-')[0]) for bin_range in D.age_group_bins.values())
for i, lower in enumerate(lower_edges):
    plt.text(lower + 4, 0.5, 'bin %i' % (i + 1), color='w', fontsize=15)

plt.ylabel('Number in bin')
plt.xlabel('age (yrs)')
plt.tight_layout()
plt.show()

In [None]:
# Position of each age-group key within D.age_group_bins (0, 1, 2, ...).
# Age groups missing from the mapping end up as None.
ag_conv = dict(zip(D.age_group_bins.keys(), range(len(D.age_group_bins))))
meta['age_group_num'] = meta['age_group'].apply(ag_conv.get)

In [None]:
def read_target(filename):
    """Load the ``target`` array stored in a record file.

    Parameters
    ----------
    filename : str
        Path to the record file.

    Returns
    -------
    numpy.ndarray
        ``np.array`` of the ``target`` attribute exposed by ``D.TFRecordFile``.
        It is consumed below as one stage number per epoch (presumably 1-D — TODO confirm).

    Raises
    ------
    FileNotFoundError
        If ``filename`` does not exist.
    """
    # Explicit raise rather than `assert`, so the check survives `python -O`.
    if not os.path.exists(filename):
        raise FileNotFoundError(filename)
    record = D.TFRecordFile(filename)
    return np.array(record.target)

In [None]:
# One row per (patient, epoch): the stored stage number, its decoded stage
# label, and the patient's metadata columns.
dirname = '../datasets/tfrecords'

# The concat below is keyed by pid, so duplicate pids would silently keep only
# the last file (and the merge would then duplicate rows); fail loudly instead.
assert meta.pid.is_unique, meta.pid[meta.pid.duplicated()].tolist()

df = pd.concat({
    entry.pid: pd.DataFrame(read_target(os.path.join(dirname, entry.file)), columns=['stage_number'])
    for entry in meta.itertuples(index=False)
})
df.index.names = ['pid', 'epoch']
df = df.reset_index()
df['stage'] = df.stage_number.apply(D.decode)
df = pd.merge(df, meta, on='pid')

In [None]:
# Count all events in each age bin
counts = df.groupby('age_group').apply(lambda x: x.stage.value_counts())
counts.index.names = ['age_group', 'stage']
counts = counts.unstack('age_group')[list(D.age_group_bins.keys())].loc[D.events]

# Add joined counts for all stages
counts['all'] = df.stage.value_counts()
counts = counts.transpose()
counts['total'] = counts.transpose().sum()

# normalize each stage per bin, and total by all
for stage in D.events:
    counts[stage] = counts[stage] / counts['total']
counts['total'] = counts['total'] / counts['total']['all']

# convert to percent
counts = 100.0 * counts
print("make sure that's in the right order!")
counts

In [None]:
# Heatmap matrix: rows are stages (+ 'total'), columns are age groups (+ 'all').
# Copy explicitly: `.values` can be a view, so blanking the corner cell would
# otherwise write back into `counts`, or fail on a read-only view under pandas
# copy-on-write.
cm = counts.to_numpy(copy=True).T
# The total/all cell is 100% by construction; zero it so it doesn't dominate the colour scale.
cm[-1, -1] = 0.0

In [None]:
# Heatmap of stage percentages per age group. The 'Total' row and the 'all'
# column are separated from the body by thick lines.
ylabels = ['Wake', 'S1', 'S2', 'S3', 'S4', 'REM']
ylabels.append('Total')
xlabels = list(D.age_group_bins.values()) + ['all']
# The hard-coded labels must line up with the rows/columns of `cm`.
assert cm.shape == (len(ylabels), len(xlabels)), cm.shape

fig, ax = plt.subplots(figsize=(9, 8))
ax.imshow(cm, interpolation='nearest', cmap=tp.colorscheme['frequency'])
ax.set_yticks(np.arange(len(ylabels)))
ax.set_yticklabels(ylabels, rotation=45, fontsize=13)
ax.set_xticks(np.arange(len(xlabels)))
ax.set_xticklabels(xlabels, fontsize=13)

# White text on dark cells, black on light ones; nanmax tolerates empty groups.
thresh = np.nanmax(cm) / 2.
n_rows, n_cols = cm.shape
for i, j in itertools.product(range(n_rows), range(n_cols)):
    if (i, j) == (n_rows - 1, n_cols - 1):
        continue  # blanked total/all corner cell
    # Round (not truncate) to whole percent; '{:.0f}' also renders NaN instead of raising.
    ax.text(j, i, '{:.0f}%'.format(cm[i, j]),
            color="white" if cm[i, j] > thresh else "black",
            ha='center', va='center', fontsize=12)

ax.axhline(y=n_rows - 1.5, color='k', lw=2.)
ax.axvline(x=n_cols - 1.5, color='k', lw=2.)

ax.set_xlabel('Age group  (yrs)', fontsize=13)
fig.tight_layout()
plt.show()