In [None]:
from functools import reduce
from pathlib import Path

import numpy as np
import pandas as pd
import seaborn as sns
from matplotlib import pyplot as plt
from sklearn.preprocessing import LabelEncoder, StandardScaler
from sklearn.impute import SimpleImputer
from sklearn.model_selection import StratifiedKFold, cross_val_predict
from sklearn.linear_model import LogisticRegression
from sklearn.ensemble import RandomForestClassifier
from sklearn.svm import SVC
from imblearn.under_sampling import RandomUnderSampler

# 1) load and prepare data

In [None]:
# Feature table that is loaded into `df` below. Exactly one `csv_file` assignment should be
# active; the commented-out lines point to alternative versions of the table
# (presumably different normalization / blur settings, judging by the file names).
csv_file = '/scratch/hoerl/auto_sir_dna_comp/20220829_glcm_good95_imagenorm.csv'
# csv_file = '/scratch/hoerl/auto_sir_dna_comp/20220829_glcm_good95_imagenorm_confocalblur.csv'
# csv_file = '/scratch/hoerl/auto_sir_dna_comp/20220829_glcm_good95_replicatenorm.csv'
# csv_file = '/scratch/hoerl/auto_sir_dna_comp/20220829_glcm_good95_replicatenorm_confocalblur.csv'

## CSV with experimental conditions (semicolon-separated; merged onto the feature table via the file stem)
exp_overview_csv = '/scratch/hoerl/auto_sir_experiment_overview.csv'

In [None]:
df = pd.read_csv(csv_file)

# Ignore very small technical replicates (acquisition crashed/stopped because the sample was bad).
# NOTE: the comparison is strict, i.e. a replicate needs MORE than `min_replicate_size` rows to be kept.
min_replicate_size = 10
# pd.concat instead of reduce(pd.DataFrame.append, ...): DataFrame.append was deprecated in
# pandas 1.4 and removed in pandas 2.0. Row order (sorted by group key) stays the same.
df = pd.concat([dfi for _, dfi in df.groupby(['cell_class', 'replicate']) if len(dfi) > min_replicate_size])

# Remove one outlying replicate (looks blurry -> air bubble / dirty objective?).
# DataFrame.isin with a list tests every cell individually for membership in that list, so a
# list holding a single tuple does not express "this (cell_class, replicate) pair".
# Select the replicate with an explicit per-column comparison instead.
is_outlier_replicate = (df['cell_class'] == 'IMR90_young_untreated') & (df['replicate'] == '20200705_rep2')
# .copy() so that the column assignments below operate on an independent frame
# (avoids SettingWithCopyWarning on the filtered selection)
df = df[~is_outlier_replicate].copy()

# get file stem for merge with standardized replicate names
df['file_stem'] = df.filename.apply(lambda f: Path(f).stem)

# read and add standardized replicate/condition info
df_exp_overview = pd.read_csv(exp_overview_csv, sep=';')[['file', 'treatment', 'replicate_technical', 'replicate_biological', 'overlapping_tiles']]
# replicate identifiers are used as string labels below (concatenation, tick labels)
df_exp_overview['replicate_technical'] = df_exp_overview['replicate_technical'].apply(str)
df_exp_overview['replicate_biological'] = df_exp_overview['replicate_biological'].apply(str)

# default (inner) merge: rows without a matching entry in the experiment overview are dropped
df = df.merge(df_exp_overview, left_on='file_stem', right_on='file', suffixes=(None, '_duplicate') )

# auxiliary columns for grouping
df['treatment_icm_grouped'] = df.treatment.str.split('_').str[-1]
df['replicate_biological_with_treat'] = df['treatment'] + '_' + df['replicate_biological']

# cell_class for verification
df.cell_class.unique()

In [None]:
# Optional: restrict to a date range.
# The first six characters of `replicate` are read as a YYYYMM number;
# only replicates from before 2020-08 are kept.
df = df.loc[df['replicate'].str.slice(stop=6).astype(int).lt(202008)]

In [None]:
# display the merged and filtered table for a visual sanity check
df

## split old & young / ICM treated

In [None]:
# reference classes, in the order [young, old]; this order is relied upon further down
oldyoung_classes = [
    'IMR90_young_untreated',
    'IMR90_untreated_old',
]

# rows of the two reference classes vs. everything else (the treated samples)
df_oldyoung = df.loc[df['cell_class'].isin(oldyoung_classes)]
df_treated = df.loc[~df['cell_class'].isin(oldyoung_classes)]

## feature preprocessing / numerical labels

In [None]:
# Non-feature columns that must not reach the classifier:
# file paths, classes, good/bad classification & auxiliary features.
# Commented-out entries are optional toggles (uncomment to exclude them as well).
columns_to_drop = [
    # merge key and columns taken from the experiment overview
    'file', 'file_stem', 'treatment', 'replicate_technical', 'replicate_biological', 'overlapping_tiles',
    # file / dataset bookkeeping and classification results
    'dataset_name', 'filename', 'classification_manual', 'classification_auto', 'replicate',
    # auxiliary grouping columns
    'treatment_icm_grouped', 'replicate_biological_with_treat',
    # class labels
    'cell_class', 'condition',
    # image geometry
    'img_height', 'img_width',
    # 'mask_area',
    'num_blank_rows', 'num_blank_cols',
    # 'intensity_mu', 'intensity_sigma',
    # intensity statistics
    'perc_high', 'perc_low', 'fg_mean',
    'perc_high_image', 'perc_low_image',
]

# plain feature matrices (numpy arrays) for both subsets
feats_oldyoung, feats_treated = (
    frame.drop(columns=columns_to_drop).values
    for frame in (df_oldyoung, df_treated)
)

In [None]:
# numerical class labels for the old/young reference samples
le_oldyoung = LabelEncoder()
y_oldyoung = le_oldyoung.fit_transform(df_oldyoung.cell_class)

# Fit the imputer on the old/young samples only and reuse it for the treated samples,
# exactly like the scaler below. A separately fitted imputer would fill gaps with
# statistics of the treated data and - because SimpleImputer discards columns that are
# entirely NaN during fit - could even produce a different set of columns.
imputer = SimpleImputer()
X_oldyoung = imputer.fit_transform(feats_oldyoung)
X_treated = imputer.transform(feats_treated)

# fit scaler on old/young samples and apply the same scaling to treated samples
scaler = StandardScaler()
X_oldyoung = scaler.fit_transform(X_oldyoung)
X_treated = scaler.transform(X_treated)

# the label encoder might switch old/young order, get indices
# (oldyoung_classes is [young, old] -> young_old_indices[0] is young, [1] is old)
young_old_indices = le_oldyoung.transform(oldyoung_classes)

## randomly undersample the young images to match size of old
## did not change classifier preference for 'young' much
# X_oldyoung, y_oldyoung = RandomUnderSampler().fit_resample(X_oldyoung, y_oldyoung)
# np.unique(y_oldyoung, return_counts=True)

# 2) fit classifier on old & young

In [None]:
# Fixed seed: both the shuffled CV split and the internal cross-validation that SVC runs
# for probability=True are random; without a seed, accuracy and predicted probabilities
# change on every Restart & Run All.
RANDOM_SEED = 0

# SVC with higher regularization than default (C parameter<1)
cls = SVC(C=0.1, probability=True, class_weight='balanced', random_state=RANDOM_SEED)

# CV scoring to assess classifier performance on old/young
cv = StratifiedKFold(5, shuffle=True, random_state=RANDOM_SEED)
y_oldyoung_pred = cross_val_predict(cls, X_oldyoung, y_oldyoung, n_jobs=-1, cv=cv)

# fraction of correctly classified old/young samples (cross-validated accuracy)
np.mean(y_oldyoung_pred == y_oldyoung)

In [None]:
# fit again on whole old/young dataset; this fitted model is the one applied to the
# treated samples afterwards (trailing ';' suppresses the estimator repr in the cell output)
cls.fit(X_oldyoung, y_oldyoung);

# 3) apply to ICM treated

In [None]:
# per-sample class probabilities of the old/young classifier for the treated samples
y_treated_pred = cls.predict_proba(X_treated)

# One row per treated sample: condition and replicate labels plus both class probabilities.
# The probability column of each class is picked via young_old_indices ([0] = young, [1] = old).
df_pred = pd.DataFrame(
    {
        'cell_class': df_treated.treatment,
        'replicate_bio': df_treated.replicate_biological,
        'replicate_tech': df_treated.replicate_technical,
        'prob_young': y_treated_pred[:, young_old_indices[0]],
        'prob_old': y_treated_pred[:, young_old_indices[1]],
    },
    index=df_treated.index,
)

In [None]:
## group by technical or biological replicates
# df_grouped = df_pred.groupby(['cell_class', 'replicate_tech'])[['prob_young', 'prob_old']]
df_grouped = df_pred.groupby(['cell_class', 'replicate_bio'])[['prob_young', 'prob_old']]

# average prediction per group
df_confmat = df_grouped.mean()

# add replicate size to index -> for labelling in plot
df_confmat['count'] = [f'N={c}' for c in df_grouped.count().iloc[:, 0].values]
df_confmat = df_confmat.set_index('count', append=True)

plt.figure(figsize=(6,8))
sns.heatmap(df_confmat, cmap='Blues', annot=True, vmin=0, vmax=1)

# seaborn draws heatmap cell i over the interval [i, i+1], so its centre is at i + 0.5;
# ticks at plain integer positions would put every label on a cell edge instead of next to its cell
row_centers = np.arange(len(df_confmat)) + 0.5
col_centers = np.arange(len(oldyoung_classes)) + 0.5
# one label per row: all index levels (condition, replicate, N=...) joined into a single string
plt.yticks(ticks=row_centers, labels=[', '.join(levels) for levels in df_confmat.index])
plt.xticks(ticks=col_centers, labels=['young', 'old'], rotation='vertical');

# save (pdf.fonttype 42 -> TrueType fonts are embedded, text stays editable in the PDF)
plt.rc('pdf', fonttype='42')
plt.savefig('/home/hoerl/ageing_dna_texture_figure_parts/confusionmatrix_oldyoung-classification_sted.pdf')