# Imports and initialization of general parameters
***

In [17]:
from config.info import AGES, RACES, GENDERS, COMBS_BASELINE
from visualization.subgroup_distribution import plot_dist
from dataprocess.dataloader import load_data
from dataprocess.dataclass import Data
from config.get_args import get_args
import plotly.express as px
import plotly.graph_objects as go
from plotly.subplots import make_subplots
import numpy as np
import argparse
import pandas as pd


# Auto reload: before each cell executes, reload imported modules whose source
# changed (including the project's config / dataprocess / visualization packages),
# so edits there are picked up without restarting the kernel.
# NOTE: re-running this cell prints "already loaded"; %reload_ext is the re-run-safe form.
%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


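# Load the dataset
***
The reference table `data/info.pkl` lists, for each row, the task, cancer cohort, label, subject and site IDs, and the protected attributes (age, race, gender). The `coad_read` FS/PM cohorts are excluded from the whole analysis.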

In [18]:
# Reference table: task, cancer cohort, label, protected attributes and subject/site IDs.
# NOTE: pd.read_pickle can execute arbitrary code - only load trusted files.
ref = pd.read_pickle('data/info.pkl')
# Drop both coad_read cohorts (FS and PM variants).
ref = ref.loc[~ref['cancer'].isin(['coad_read_FS', 'coad_read_PM'])]
ref.head(100)

Unnamed: 0,task,cancer,label,age_,race_,gender_,subj,siteID,age_str,race_str,gender_str
1024,cancer_classification,kich_kirc_FS,0,3,1,1,TCGA-KL-8340,KL,3,WHITE,MALE
1025,cancer_classification,kich_kirc_FS,0,3,1,1,TCGA-KL-8340,KL,3,WHITE,MALE
1026,cancer_classification,kich_kirc_FS,0,5,1,1,TCGA-KN-8426,KN,5,WHITE,MALE
1027,cancer_classification,kich_kirc_FS,0,5,1,1,TCGA-KN-8426,KN,5,WHITE,MALE
1028,cancer_classification,kich_kirc_FS,0,7,1,1,TCGA-KL-8343,KL,7,WHITE,MALE
...,...,...,...,...,...,...,...,...,...,...,...
1119,cancer_classification,kich_kirc_FS,0,4,1,1,TCGA-NP-A5H6,NP,4,WHITE,MALE
1120,cancer_classification,kich_kirc_FS,0,6,1,0,TCGA-KN-8434,KN,6,WHITE,FEMALE
1121,cancer_classification,kich_kirc_FS,0,6,1,0,TCGA-KN-8434,KN,6,WHITE,FEMALE
1122,cancer_classification,kich_kirc_FS,0,7,1,1,TCGA-KO-8409,KO,7,WHITE,MALE


In [19]:
# Human-readable versions of the protected attributes, used as plot / pivot labels.
# `assign` returns a new frame. `ref` is a boolean-filtered slice of the pickled table,
# so assigning columns into it directly triggers pandas' SettingWithCopyWarning.
# The code -> label lookups run column-wise with `Series.map`. The previous row-wise
# `apply(axis=1)` was much slower and, on an empty frame, returned a DataFrame that
# cannot be assigned to a single column.
ref = ref.assign(
    age_str = ref['age_'].astype(str),
    race_str = ref['race_'].map(lambda code: RACES['str'][code]),
    gender_str = ref['gender_'].map(lambda code: GENDERS['str'][code]),
).reset_index(drop = True)

In [20]:
# Number of rows per cancer cohort, coloured (and stacked) by task.
# The figure is exported to EPS, then shown inline.
fig = px.histogram(
    ref,
    x="cancer",
    color='task',
    text_auto=True,
    color_discrete_sequence=['rgb(50, 100, 170)', 'rgb(21, 21, 45)'],
    template='none',
    width=1000,
    height=400,
)
fig.update_xaxes(tickangle=40)
# The legend is anchored by its top-right corner; 'relative' bar mode stacks the task bars.
fig.update_layout(
    legend=dict(yanchor="top", xanchor="right"),
    barmode='relative',
    bargap=0.3,
)
fig.write_image("images/eda_1.eps")
fig.show()

In [21]:
# Re-exports the figure of the previous cell to the same path. This is redundant on a
# clean run. It may be a workaround for an incomplete first kaleido export.
# TODO confirm, and remove this cell if it is not needed.
fig.write_image("images/eda_1.eps")

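# Protected attributes
***
Distribution of age, race and gender in two cohorts: `c1` (cancer classification) and `c2` (tumor detection), followed by age × race counts per gender and per label.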

In [22]:
# Cohorts compared in this section (see the subplot titles below):
#   c1 - cancer-classification cohort ('kirc_kirp_PM' is an alternative used before)
#   c2 - tumor-detection cohort
c1 = 'luad_lusc_PM'
c2 = 'lusc'
# ref1 is sorted by gender_str so FEMALE rows come first. This presumably fixes which
# gender gets the first colour in the plotly-express histograms below.
# NOTE(review): ref2 is kept in file order - confirm its colours match ref1's.
ref1 = ref[ref.cancer == c1].sort_values(by = ['gender_str'])
ref2 = ref[ref.cancer == c2]

In [23]:
# Initialization of the plot: one row per cohort (c1 on top, c2 below) and one column
# per protected attribute (age, race, gender).
fig = make_subplots(rows = 2, cols = 3,
                    subplot_titles=('',  f'Cancer classification for {c1}', '', '', f'Tumor detection for {c2}', ''))

# Fixed gender -> colour mapping. With only `color_discrete_sequence`, plotly express
# hands out colours in order of first appearance. The unsorted ref2 could therefore get
# FEMALE/MALE swapped relative to ref1 and to the single legend (row 1, col 3).
gender_colors = {'FEMALE': 'rgb(180, 40, 40)', 'MALE': 'rgb(40, 40, 140)'}

# (row, col, source frame, x column). Every panel is coloured by gender_str so all
# panels share the FEMALE/MALE legend groups.
# NOTE(review): x='age_' is numeric, so plotly chooses the bin width itself - confirm
# each bar is one age code. Row 2 uses the categorical age_str instead.
histogram_panels = [
    (1, 1, ref1, 'age_'),
    (1, 2, ref1, 'race_str'),
    (1, 3, ref1, 'gender_str'),
    (2, 1, ref2, 'age_str'),
    (2, 2, ref2, 'race_str'),
    (2, 3, ref2, 'gender_str'),
]
for row, col, frame, x_col in histogram_panels:
    panel = px.histogram(frame, x = x_col, color = 'gender_str',
                         color_discrete_map = gender_colors)
    # Only the gender panel of the first row keeps its legend, to avoid duplicate entries.
    if (row, col) != (1, 3):
        panel.update_traces(showlegend = False)
    for trace in panel.data:
        fig.add_trace(trace, row = row, col = col)

# Axis titles. Tick values list the age codes shown on each age axis.
fig.update_layout(
    height = 700, width = 1100, template = 'none',
    barmode = 'group', bargap = 0.3, bargroupgap = 0.1,
    xaxis = dict(title = 'age', tickvals = [2, 3, 4, 5, 6, 7, 8, 9]),
    yaxis = dict(title = 'count'),
    xaxis2 = dict(title = 'race'),
    xaxis3 = dict(title = 'gender'),
    xaxis4 = dict(title = 'age', tickvals = [4, 5, 6, 7, 8]),
    yaxis4 = dict(title = 'count'),
    xaxis5 = dict(title = 'race'),
    xaxis6 = dict(title = 'gender'),
)
fig.update_coloraxes(showscale=False)
fig.write_image("images/eda_2.eps")
fig.show()

***

In [24]:
# Age x race count matrices (index: age_str, columns: race_str) per cohort and gender.
# pivot_table needs a value column to count, hence the constant column 'a'.
p1 = ref1.copy()
p1['a'] = 1
p2 = ref2.copy()
p2['a'] = 1

def _age_race_counts(frame):
    """Count the rows of `frame` in every (age_str, race_str) cell; empty cells stay NaN."""
    counts = pd.pivot_table(frame[['age_str', 'race_str', 'a']], index=['age_str'],
                            columns=['race_str'], aggfunc='count')
    return counts['a']

# Female / male split for each cohort.
p1_f, p1_m = (_age_race_counts(p1[p1.gender_str == gender]) for gender in ('FEMALE', 'MALE'))
p2_f, p2_m = (_age_race_counts(p2[p2.gender_str == gender]) for gender in ('FEMALE', 'MALE'))

In [25]:
# Age x race count heatmaps: top row = c1, bottom row = c2; left column = female, right = male.
fig = make_subplots(2, 2, horizontal_spacing = 0.14,
                    subplot_titles=(f'Female patients ({c1})', f'Male patients ({c1})', f'Female patients ({c2})', f'Male patients ({c2})'))

# (counts, row, col, colour axis, colour scale, colorbar x, colorbar y).
# Every heatmap is bound to its own colour axis, so each panel gets an independent
# colorbar beside it. Female panels use 'matter_r', male panels 'ice'.
heatmap_panels = [
    (p1_f, 1, 1, 'coloraxis',  'matter_r', 0.43, 0.8),
    (p1_m, 1, 2, 'coloraxis2', 'ice',      1,    0.8),
    (p2_f, 2, 1, 'coloraxis3', 'matter_r', 0.43, 0.2),
    (p2_m, 2, 2, 'coloraxis4', 'ice',      1,    0.2),
]
heatmap_layout = dict(height = 600, width = 1000, template = 'none')
for counts, row, col, axis_name, scale, bar_x, bar_y in heatmap_panels:
    fig.add_trace(go.Heatmap(x = counts.columns, y = counts.index, z = counts,
                             coloraxis = axis_name, texttemplate = '%{z}'), row, col)
    heatmap_layout[axis_name] = dict(colorscale = scale, colorbar_x = bar_x, colorbar_y = bar_y,
                                     colorbar_len = 0.5, colorbar_thickness = 23)

# Same axis titles and age tick values on all four panels (axis ids '', '2', '3', '4').
for suffix in ('', '2', '3', '4'):
    heatmap_layout['xaxis' + suffix] = {"title": 'race'}
    heatmap_layout['yaxis' + suffix] = dict(title = 'age', tickvals = [2, 3, 4, 5, 6, 7, 8, 9])

fig.update_layout(**heatmap_layout)
fig.write_image("images/eda_3.eps")
fig.show()

In [26]:
# Label balance within c1: split c1 by gender, then count age x race for each label.
# The names p1/p2 and *_f/*_m are reused because the next plotting cell reads them:
#   p1, p2   -> c1 rows with gender_ == 0 (FEMALE) / gender_ == 1 (MALE)
#   *_f, *_m -> rows with label == 0 / label == 1 (here the suffixes are NOT genders)
# NOTE(review): both frames come from ref1, so c2 is not used in this cell - confirm intended.
p1 = ref1[ref1.gender_ == 0].copy()
p1['a'] = 1  # constant column for pivot_table to count
p2 = ref1[ref1.gender_ == 1].copy()
p2['a'] = 1

def _label_age_race_counts(frame, label):
    """Count rows of `frame` with `label` in every (age_str, race_str) cell; empty cells stay NaN."""
    subset = frame[frame.label == label][['age_str', 'race_str', 'a']]
    return pd.pivot_table(subset, index=['age_str'], columns=['race_str'], aggfunc='count')['a']

p1_f = _label_age_race_counts(p1, 0)
p1_m = _label_age_race_counts(p1, 1)
p2_f = _label_age_race_counts(p2, 0)
p2_m = _label_age_race_counts(p2, 1)

In [27]:
# Age x race heatmaps for c1 with gender in rows and label in columns. The previous
# cell builds p1 = female, p2 = male, *_f = label 0 and *_m = label 1; the titles
# reflect that split.
fig = make_subplots(2, 2, horizontal_spacing = 0.14,
                    subplot_titles=(f'Female patients, label 0 ({c1})', f'Female patients, label 1 ({c1})',
                                    f'Male patients, label 0 ({c1})', f'Male patients, label 1 ({c1})'))

# (counts, row, col, colour axis, colour scale, colorbar x, colorbar y). Each heatmap has
# its own colour axis, so every panel gets an independent colorbar next to it.
label_panels = [
    (p1_f, 1, 1, 'coloraxis',  'matter_r', 0.43, 0.8),
    (p1_m, 1, 2, 'coloraxis2', 'deep_r',   1,    0.8),
    (p2_f, 2, 1, 'coloraxis3', 'matter_r', 0.43, 0.2),
    (p2_m, 2, 2, 'coloraxis4', 'deep_r',   1,    0.2),
]
label_layout = dict(height = 600, width = 1000, template = 'none')
for counts, row, col, axis_name, scale, bar_x, bar_y in label_panels:
    fig.add_trace(go.Heatmap(x = counts.columns, y = counts.index, z = counts,
                             coloraxis = axis_name, texttemplate = '%{z}'), row, col)
    label_layout[axis_name] = dict(colorscale = scale, colorbar_x = bar_x, colorbar_y = bar_y,
                                   colorbar_len = 0.5, colorbar_thickness = 23)

# Same axis titles and age tick values on all four panels (axis ids '', '2', '3', '4').
for suffix in ('', '2', '3', '4'):
    label_layout['xaxis' + suffix] = dict(title = 'race')
    label_layout['yaxis' + suffix] = dict(title = 'age', tickvals = [2, 3, 4, 5, 6, 7, 8, 9])

fig.update_layout(**label_layout)
# Own file: writing to images/eda_3.eps would overwrite the gender heatmap exported earlier.
fig.write_image("images/eda_3_by_label.eps")
fig.show()

In [29]:
# Joint (age, race, gender) code counts for luad_lusc_PM (currently also `c1`), most frequent first.
ref.loc[ref['cancer'] == 'luad_lusc_PM', ['age_', 'race_', 'gender_']].value_counts()

age_  race_  gender_
7     1      1          164
6     1      1          160
7     1      0          130
6     1      0          114
5     1      1           66
             0           46
8     1      1           41
             0           22
4     1      1           20
6     3      1           18
7     3      0           17
4     1      0           17
6     3      0           15
5     3      0           14
7     3      1           11
5     3      1            8
7     2      1            3
4     3      0            3
6     2      1            3
3     1      0            3
7     2      0            2
6     2      0            2
4     3      1            2
      2      1            2
3     1      1            1
6     0      0            1
5     2      1            1
4     2      0            1
8     2      0            1
             1            1
      3      0            1
             1            1
Name: count, dtype: int64