In [None]:
from collections import defaultdict
from functools import reduce
from pathlib2 import Path

import numpy as np
from numpy import interp
import pandas as pd
from matplotlib import pyplot as plt
from matplotlib.cm import get_cmap
import seaborn as sns

from sklearn.impute import SimpleImputer
from sklearn.preprocessing import LabelEncoder
from sklearn.preprocessing import StandardScaler
from sklearn.preprocessing import LabelBinarizer, label_binarize
from sklearn.pipeline import Pipeline
from sklearn.cluster import FeatureAgglomeration
from sklearn.model_selection import cross_val_score, cross_val_predict
from sklearn.model_selection import KFold, StratifiedKFold, LeaveOneGroupOut
from sklearn.model_selection import GridSearchCV
from sklearn.linear_model import LogisticRegression
from sklearn.ensemble import RandomForestClassifier
from sklearn.svm import SVC
from sklearn.metrics import roc_curve, precision_recall_curve
from sklearn.metrics import precision_score, recall_score
from sklearn.metrics import auc, average_precision_score
from sklearn.multiclass import OneVsRestClassifier
from sklearn.manifold import TSNE
from sklearn.decomposition import PCA

# 1) Load Features from skimage GLCM or CellProfiler output

## 1a) From skimage pipeline output

In [None]:
# Feature table to analyse. Uncomment one of the alternatives to switch variant
# (judging by the filenames they differ in normalization and confocal blur).
# csv_file = '/scratch/hoerl/auto_sir_dna_comp/20220829_glcm_good95_imagenorm.csv'
# csv_file = '/scratch/hoerl/auto_sir_dna_comp/20220829_glcm_good95_imagenorm_confocalblur.csv'
csv_file = '/scratch/hoerl/auto_sir_dna_comp/20220829_glcm_good95_replicatenorm.csv'
# csv_file = '/scratch/hoerl/auto_sir_dna_comp/20220829_glcm_good95_replicatenorm_confocalblur.csv'

# set to True to write the figures below as PDFs
save_plots = False

# the leading '<date>_' token of the CSV filename is reused as prefix for saved plots
csv_name = Path(csv_file).name
date_str, *_ = csv_name.split('_')

In [None]:
df_skimage = pd.read_csv(csv_file)

### Optional 1: filter by foreground brightness
# df_skimage = df_skimage[df_skimage.fg_mean > 100]

### Optional 2: keep only replicates above a certain mean foreground intensity
# NOTE: DataFrame.append was removed in pandas 2.0, and chaining it via reduce copied
# the growing frame once per group -> concatenate all kept groups in one go instead.
only_bright = False
if only_bright:
    df_skimage = pd.concat([dfi for _, dfi in df_skimage.groupby(['cell_class', 'replicate'])
                            if dfi.fg_mean.mean() > 100])

### Optional 3: only use technical replicates with a minimum amount of images
# NOTE: a replicate is kept only if it has strictly MORE than min_replicate_size images
min_replicate_size = 10
df_skimage = pd.concat([dfi for _, dfi in df_skimage.groupby(['cell_class', 'replicate'])
                        if len(dfi) > min_replicate_size])

### Optional 4: Group all ICM treated cells into one class
# df_skimage.cell_class.replace(['IMR90_3d_ICM_young', 'IMR90_6d_ICM_young', 'IMR90_9d_ICM_young'], 'IMR90_ICM_treated', inplace=True)

### 5: select only some cell_classes
# selected_cell_classes = ['IMR90_untreated_old', 'IMR90_9d_ICM_young', 'IMR90_3d_ICM_young', 'IMR90_6d_ICM_young', 'IMR90_young_untreated']
selected_cell_classes = ['IMR90_untreated_old', 'IMR90_6d_ICM_young', 'IMR90_young_untreated']
# selected_cell_classes = ['IMR90_untreated_old', 'IMR90_young_untreated']

# NOTE: only works with grouping of the treated cells (Optional 4)
# selected_cell_classes = ['IMR90_untreated_old', 'IMR90_young_untreated', 'IMR90_ICM_treated']

keep_mask = df_skimage.cell_class.isin(selected_cell_classes)
df_skimage = df_skimage[keep_mask]

# Get data 'batch' from the preparation date (first '_'-separated token of the replicate name)
# NOTE: scaling per batch (below) was not really helpful
dates = df_skimage.replicate.str.split('_', expand=True)[0]

# preparation dates grouped by batch; the position in this list is the batch id
dates_per_batch = [
    ['20200622', '20200625', '20200629', '20200702', '20200705'],
    ['20201208', '20201214'],
    ['20210326', '20210402'],
    ['20210826'],
    ['20211006'],
    ['20220107', '20220111'],
]
date_to_batch = {date: batch_id
                 for batch_id, batch_dates in enumerate(dates_per_batch)
                 for date in batch_dates}

batches = dates.replace(date_to_batch)

# non-feature columns: file paths, class labels, good/bad classification & auxiliary features
columns_to_drop = [
    'dataset_name', 'filename', 'classification_manual', 'classification_auto', 'replicate',
    'cell_class', 'condition',
    'img_height', 'img_width', 'mask_area',
    'num_blank_rows', 'num_blank_cols',
    # 'intensity_mu', 'intensity_sigma',
    'perc_high', 'perc_low', 'fg_mean',
    'perc_high_image', 'perc_low_image',
]

# every remaining column is treated as a feature
df_feats = df_skimage.drop(columns=columns_to_drop, errors='ignore')
feat_names = df_feats.columns
tex_values = df_feats.values

# per-image metadata used as labels / grouping further below
conditions = df_skimage.cell_class
replicates = df_skimage.replicate
bio_replicates = df_skimage.condition

In [None]:
# per (cell_class, condition): number of images, mean and median foreground intensity,
# plus whether the mean exceeds the threshold of 100 used in "Optional 2"
r = (df_skimage
     .groupby(['cell_class', 'condition'])
     .fg_mean
     .describe()[['count', 'mean', '50%']])
r['bigger_than_100'] = r['mean'].gt(100)
r['count'] = r['count'].astype(int)
r

## 1b) Old version from CellProfiler

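NOTE: Not tested with recent changes. Running these cells overwrites `conditions`, `replicates`, `feat_names` and `tex_values` from 1a). `bio_replicates` and `df_skimage`, which later cells use, are not redefined here.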

In [None]:
# load CellProfiler output: per-object measurements and per-image metadata
obj_file = '/Volumes/davidh-ssd/examples_tiff_n_200_optimal_intensity_8bit_repl/MyExpt_IdentifyPrimaryObjects.csv'
img_file = '/Volumes/davidh-ssd/examples_tiff_n_200_optimal_intensity_8bit_repl/MyExpt_Image.csv'

obj_df = pd.read_csv(obj_file)
img_df = pd.read_csv(img_file)

# fix Windows path names: replace backslashes literally.
# regex=False makes the literal replacement explicit; pandas 1.x defaulted to regex=True
# and warned about this pattern.
img_df.PathName_DNA = img_df.PathName_DNA.str.replace('\\', '/', regex=False)

In [None]:
# look up the source-image folder of every object
# (ImageNumber is treated as 1-based, hence the -1)
image_rows = obj_df.ImageNumber.values - 1
paths = img_df.PathName_DNA.values[image_rows]
folders = [path.rsplit('/', 1)[-1] for path in paths]

# folder name '<first>_<condition parts...>_<last>':
# condition = middle parts, replicate = first + last part
conditions, replicates = [], []
for folder in folders:
    parts = folder.split('_')
    conditions.append('_'.join(parts[1:-1]))
    replicates.append(parts[0] + '_' + parts[-1])

# fix two different namings for "old_untreated"
conditions = [c if c != 'IMR90_old' else 'IMR90_untreated_old' for c in conditions]

In [None]:
# sanity check: unique (condition, replicate) pairs (columns) and number of objects per pair
np.unique(np.vstack([conditions, replicates]), axis=1, return_counts=True)

In [None]:
# get features from object df
# RadialDistribution & AreaShape are left out, as we are not really interested in
# nucleus shape changes or changes to global chromatin localization
feature_prefixes = ('Texture', 'Granularity', 'Intensity')  # optional: 'RadialDistribution', 'AreaShape'
excluded_substrings = ('NormalizedMoment', 'EulerNumber')

# NOTE: the previous condition `a or b or c and not x and not y` only applied the
# exclusions to 'Intensity' columns ('and' binds tighter than 'or'); apply them to all.
feat_names = [c for c in obj_df.columns
              if c.startswith(feature_prefixes)
              and not any(s in c for s in excluded_substrings)]

# drop unwanted features
tex_values = obj_df[feat_names].values

## Preprocessing

In [None]:
# for verification: selected feature names and the (n_samples, n_features) shape
(feat_names, tex_values.shape)

### Numeric labels, scale features

In [None]:
# integer class labels (LabelEncoder assigns them in sorted order of the class names),
# plus a one-hot version for the one-vs-rest PR/ROC evaluation.
# NB: with only 2 classes LabelBinarizer returns a single column.
le, binarizer = LabelEncoder(), LabelBinarizer()
ys = le.fit_transform(conditions)
onehot_y = binarizer.fit_transform(ys)

In [None]:
# the features contain NaNs -> impute (SimpleImputer default: column mean), then standardize
# NOTE(review): imputer and scaler are fit on ALL samples before cross-validation,
# so statistics of held-out samples leak slightly into training folds
imputed = SimpleImputer().fit_transform(tex_values)
scaler = StandardScaler()
tex_values = scaler.fit_transform(imputed)

# normalize per 'batch'
# for b in np.unique(batches):
#     mask = [bi == b for bi in batches]
#     tex_values[mask] = StandardScaler().fit_transform(tex_values[mask])


# 2) Feature visualization

## 2a) Feature heatmap

In [None]:
# choose which replicate level labels the heatmap's y-axis
# repl_for_plot = replicates
repl_for_plot = bio_replicates

# sample order grouped by (condition, replicate); stable sort keeps original order within a group
keys = list(zip(conditions, repl_for_plot))
idxs = sorted(range(len(keys)), key=keys.__getitem__)
sorted_keys = [keys[i] for i in idxs]

# y-tick position = first row of each (condition, replicate) group.
# (Previously the full key list was re-sorted once per group; it is now sorted once.)
first_pos = {}
for pos, key in enumerate(sorted_keys):
    first_pos.setdefault(key, pos)

ylabs = {}
for key, pos in first_pos.items():
    ylabs[', '.join(key) + ' (↓)'] = pos

In [None]:
figsize = (15, 15)

# aspect ratio (cols/rows) so that the image roughly fills the desired figure size
n_rows, n_cols = tex_values.shape
aspect_to_fit_figsize = (figsize[1] / n_rows) / (figsize[0] / n_cols)

# optional: manual row stretch
row_stretch = 1.0
aspect = aspect_to_fit_figsize * row_stretch

# clip to the 1st/99th percentile of all values so outliers do not dominate the color scale
clip_low, clip_high = np.quantile(tex_values, [0.01, 0.99])

plt.figure(figsize=figsize)
plt.imshow(tex_values[idxs].clip(clip_low, clip_high),
           aspect=aspect, cmap='coolwarm', interpolation='nearest')

plt.yticks(list(ylabs.values()), list(ylabs.keys()))
plt.xticks(np.arange(len(feat_names)), feat_names, rotation='vertical')
plt.title('Features')

# embed fonts as TrueType (Type 42) in saved PDFs
plt.rcParams['pdf.fonttype'] = 42

plt.colorbar(shrink=.5)
plt.tight_layout()

if save_plots:
    plt.savefig(f'/scratch/hoerl/auto_sir_dna_comp/{date_str}_heatmap_{"bright" if only_bright else "all"}replicates.pdf', transparent=True)

### Use feature agglomeration to produce a heatmap of aggregated features

NOTE: not used in classification below

In [None]:
figsize = (15, 15)
n_aggregated_features = 20

# cluster the features into n_aggregated_features groups; each aggregated
# feature pools its cluster (sklearn default pooling: mean)
agg = FeatureAgglomeration(n_clusters=n_aggregated_features)
tex_values_agg = agg.fit_transform(tex_values)

# aspect ratio (cols/rows) so that the image roughly fills the desired figure size
agg_rows, agg_cols = tex_values_agg.shape
aspect_to_fit_figsize = (figsize[1] / agg_rows) / (figsize[0] / agg_cols)

# optional: manual row stretch
row_stretch = 1.0
aspect = aspect_to_fit_figsize * row_stretch

# same 1st/99th percentile clipping as the full feature heatmap
clip_low, clip_high = np.quantile(tex_values_agg, [0.01, 0.99])

plt.figure(figsize=figsize)
plt.imshow(tex_values_agg[idxs].clip(clip_low, clip_high),
           aspect=aspect, cmap='coolwarm', interpolation='nearest')

plt.yticks(list(ylabs.values()), list(ylabs.keys()))
plt.title('Aggregated Features');

## 2b) tSNE embedding

In [None]:
# calculate the tSNE embedding of the scaled features.
# tSNE is stochastic: a fixed random_state makes the embedding reproducible, which the
# hard-coded tSNE coordinate windows used to pick example images further below rely on.
# (Those windows may need to be re-chosen once for this seed.)
tsne_seed = 0
ts = TSNE(perplexity=100, init='pca', learning_rate='auto',
          random_state=tsne_seed).fit_transform(tex_values)

# alternative: just do PCA
# ts = PCA().fit_transform(tex_values)

In [None]:
sns.set_palette('bright')
plt.figure(figsize=(10, 10))

# one color per cell class; small, semi-transparent points to limit overplotting
tsne_x, tsne_y = ts[:, 0], ts[:, 1]
sns.scatterplot(x=tsne_x, y=tsne_y, hue=conditions, alpha=0.35, s=25)
# sns.scatterplot(x=tsne_x, y=tsne_y, hue=conditions, alpha=0.65, size=df_skimage.fg_mean)

plt.xlabel('tSNE comp. 1')
plt.ylabel('tSNE comp. 2')
plt.legend(fontsize='large')
plt.rcParams['pdf.fonttype'] = 42

if save_plots:
    plt.savefig(f'/scratch/hoerl/auto_sir_dna_comp/{date_str}_tsne_all_{"bright" if only_bright else "all"}replicates.pdf', transparent=True)

In [None]:
# one figure per condition: its replicates in color, all other points gray
repl_for_plot = replicates
# repl_for_plot = bio_replicates

for condition in np.unique(conditions):

    # label: replicate for the current class, "others" for all other points
    repl_label = [repl if cond == condition else 'others'
                  for repl, cond in zip(repl_for_plot, conditions)]
    n_repl_colors = len(np.unique(repl_label)) - 1

    # palette: gray first (for "others"), then one rainbow color per replicate
    rainbow = get_cmap('rainbow', n_repl_colors)
    replicate_colors = [rainbow(k) for k in range(n_repl_colors)]
    sns.set_palette(sns.color_palette(['lightgray'] + replicate_colors))

    # reverse-sort labels (and points accordingly) so "others" is plotted first;
    # assumes replicate names sort below 'others' (e.g. they start with a date)
    order = np.argsort(repl_label)[::-1]
    ts_ = ts[order]
    repl_label = sorted(repl_label, reverse=True)

    plt.figure(figsize=(10, 10))
    sns.scatterplot(x=ts_[:, 0], y=ts_[:, 1], hue=repl_label, alpha=0.65, s=25)

    plt.title(condition)
    plt.xlabel('tSNE comp. 1')
    plt.ylabel('tSNE comp. 2')
    plt.legend(fontsize='large')

    plt.rcParams['pdf.fonttype'] = 42

    if save_plots:
        plt.savefig(f'/scratch/hoerl/auto_sir_dna_comp/{date_str}_tsne_{condition}_{"bright" if only_bright else "all"}replicates.pdf', transparent=True)

## Plot example images from a region of the tSNE embedding

In [None]:
import h5py as h5
from skimage.exposure import rescale_intensity

def load_single_image_from_h5(filename, dataset_name, norm_quantiles):
    """Load a single image from an HDF5 file and rescale its intensities.

    Reads the dataset ``/experiment/<dataset_name>/0/0`` from ``filename``, removes
    singleton dimensions, converts to float32 and rescales it with
    ``skimage.exposure.rescale_intensity`` using ``norm_quantiles`` = (low, high)
    as input range (values outside that range are clipped).
    """
    key = f'/experiment/{dataset_name}/0/0'
    with h5.File(filename, 'r') as h5_file:
        raw = h5_file[key][...]
    image = raw.squeeze().astype(np.float32)
    return rescale_intensity(image, in_range=norm_quantiles)

# tSNE window (x, y lower corner min_t, upper corner max_t) to draw example images from
# min_t = np.array([-40, 0])
# max_t = np.array([-20, 20])

# min_t = np.array([-20, -20])
# max_t = np.array([0, -10])

# min_t = np.array([30, -10])
# max_t = np.array([40, 0])

min_t = np.array([45, 0])
max_t = np.array([55, 20])

n_images_per_class = 4

# points strictly inside the window (rows of ts follow the rows of df_skimage)
selection = np.all(ts > min_t, axis=1) & np.all(ts < max_t, axis=1)

df_tsne_range = df_skimage[selection]

# up to n_images_per_class random images per cell class inside the window
imgs = defaultdict(list)
for cell_cls, dfi in df_tsne_range.groupby('cell_class'):
    sample = dfi[['filename', 'dataset_name', 'perc_low', 'perc_high']].sample(min(n_images_per_class, len(dfi)))
    for _, (filename, dataset_name, perc_low, perc_high) in sample.iterrows():
        imgs[cell_cls].append(load_single_image_from_h5(filename, dataset_name, (perc_low, perc_high)))

if imgs:
    # squeeze=False keeps axs 2-D even if only one class falls into the window;
    # otherwise each "row" would be a single Axes and the inner zip would fail
    fig, axs = plt.subplots(ncols=n_images_per_class, nrows=len(imgs), figsize=(16, 12), squeeze=False)
    # hide all axes first so unused panels (classes with fewer images) stay blank
    for ax in axs.flat:
        ax.axis('off')
    for (cl, imgs_cls), axs_cls in zip(imgs.items(), axs):
        for img, ax in zip(imgs_cls, axs_cls):
            ax.imshow(img, cmap='gray')
            ax.set_title(cl)
    fig.tight_layout()
else:
    print(f'no images inside tSNE window {min_t} - {max_t}')

## 3) Cross-validation and PR curves

In [None]:
figsize = (8, 8)
fig, ax = plt.subplots(figsize=figsize)

# cls = SVC(C=100.0, probability=True, class_weight='balanced')
cls = SVC(C=0.01, probability=True, class_weight='balanced')
# cls = RandomForestClassifier(300, class_weight='balanced')

# (cv splitter, groups, description); for leave-one-group-out, a group is one
# replicate of one condition (condition + '_' + replicate)
cv_strategies = (
    (StratifiedKFold(10, shuffle=True), None, '10-Fold Stratified CV'),
    (LeaveOneGroupOut(), [c + '_' + r for c,r in zip(conditions, replicates)], 'Leave-One-Replicate-Out (Technical)'),
    (LeaveOneGroupOut(), [c + '_' + r for c,r in zip(conditions, bio_replicates)], 'Leave-One-Replicate-Out (Biological)')
)

for cv, groups, desc in cv_strategies:
    # out-of-fold class probabilities for every sample
    # (the former bare `cv.split(...)` call created an unused generator and did nothing)
    prob_pred = cross_val_predict(cls, tex_values, ys, cv=cv, groups=groups, n_jobs=-1, method='predict_proba')
    pred = np.argmax(prob_pred, axis=1)

    # we have only 2 classes -> use probabilities for class 1 (matches the single-column onehot_y)
    if prob_pred.shape[1] == 2:
        prob_pred = prob_pred[:, 1:2]  # 1-column selection

    # get accuracy, micro-averaged AP & PR curve
    overall_acc = (pred == ys).sum() / len(ys)
    pr, re, _ = precision_recall_curve(onehot_y.ravel(), prob_pred.ravel())
    ap = average_precision_score(onehot_y.ravel(), prob_pred.ravel())

    ax.plot(re, pr, label=f'{desc}\naccuracy={overall_acc:.3f}\nAP={ap:.3f}', lw=2)

# "random guess" baseline: a random classifier's precision equals the fraction of
# positives in the flattened one-hot labels. That is 1/n_classes for >2 classes,
# and the prevalence of class 1 in the 2-class (single column) case.
# The old mean of class fractions was always 0.5 there.
baseline_pr = onehot_y.mean()

ax.plot([0, 1], [baseline_pr, baseline_pr], linestyle='dashed', color='firebrick', lw=2, label='random guess')

ax.set_xlim(0, 1)
ax.set_xlabel('Recall')
ax.set_ylim(0, 1)
ax.set_ylabel('Precision')
ax.legend()

### Confusion matrix

In [None]:
# out-of-fold predictions with the leave-one-biological-replicate-out strategy,
# converted back to condition names
cv, groups, _ = cv_strategies[2]
pred = cross_val_predict(cls, tex_values, ys, groups=groups, cv=cv, n_jobs=-1)
labs_pred = le.inverse_transform(pred)

In [None]:
n_classes = np.max(ys) + 1

# one row of prediction counts per (condition, replicate)
conf_mat = defaultdict(lambda: np.zeros(n_classes))

# replicates_ = replicates
replicates_ = bio_replicates

# encode all predicted labels in one vectorized call (instead of one LabelEncoder
# call per sample), then increment the predicted class in the sample's row
pred_codes = le.transform(labs_pred)
for cond, repl, code in zip(conditions, replicates_, pred_codes):
    conf_mat[(cond, repl)][code] += 1

# sorted rows; each label gets the number of samples appended
rows = sorted(conf_mat.items())
input_cls = [key + (f'N: {int(counts.sum())}',) for key, counts in rows]

# make matrix from rows, normalize per-row
mat = np.array([counts for _, counts in rows])
mat = mat / np.sum(mat, axis=1).reshape((-1, 1))

# plot as heatmap
plt.figure(figsize=(12, 8))
plt.imshow(mat, cmap='Blues', aspect=0.2)
plt.yticks(ticks=np.arange(len(input_cls)), labels=[', '.join(c) for c in input_cls])
plt.xticks(ticks=np.arange(n_classes), labels=le.inverse_transform(np.arange(n_classes)), rotation='vertical')

plt.rcParams['pdf.fonttype'] = 42

# add the colorbar BEFORE tight_layout so the layout accounts for it
# (same order as in the feature heatmap)
plt.colorbar(shrink=.8)
plt.tight_layout()

if save_plots:
    plt.savefig(f'/scratch/hoerl/auto_sir_dna_comp/{date_str}_confusionmatrix_{"bright" if only_bright else "all"}replicates.pdf', transparent=True)

# Old Cross-Validation Code and plots per fold

In [None]:
# build classifier with optional feature agglomeration
# cls = Pipeline([('cluster_feat', FeatureAgglomeration(n_clusters=50)), ('cls', RandomForestClassifier(n_estimators=400, class_weight='balanced'))])
# cls = Pipeline([('cls', LogisticRegression(max_iter=1000))])
# cls = Pipeline([('cls', SVC(probability=True))])
# cls = Pipeline([('cls', RandomForestClassifier(n_estimators=300))])

# set up cross-val split strategy (no groups for plain stratified k-fold)
cv, groups = StratifiedKFold(5, shuffle=True), None

# cv = LeaveOneGroupOut()
# groups = LabelEncoder().fit_transform(bio_replicates)

# cls = RandomForestClassifier(n_estimators=300)
cls = SVC(C=100.0, probability=True)

# NB: for cross_val_score, use the non-onehot ys
# cross_val_score(ovr_cls, tex_values, ys, cv=cv, n_jobs=-1)

In [None]:
res_perclass = defaultdict(list)
res_avg = defaultdict(list)

# Materialize the folds once and use them for both prediction and evaluation.
# With shuffle=True and no random_state, every cv.split() call reshuffles, so the
# folds used inside cross_val_predict did not match the folds iterated below.
splits = list(cv.split(tex_values, ys, groups=groups))

prob_pred = cross_val_predict(cls, tex_values, ys, cv=splits, n_jobs=-1, method='predict_proba')
pred = np.argmax(prob_pred, axis=1)

# we have only 2 classes -> use probabilities for class 1 for further analysis
if prob_pred.shape[1] == 2:
    prob_pred = prob_pred[:, 1:2]  # 1-column selection

for _, test in splits:

    # overall multiclass accuracy on this fold
    overall_acc = (pred[test] == ys[test]).sum() / len(test)

    # PR / ROC metrics for each class (column)
    # NB: in the 2-class case, onehot_y has a single column (class 1)
    for j in range(prob_pred.shape[1]):
        y_true_j = onehot_y[test, j]
        y_score_j = prob_pred[test, j]

        fpr, tpr, _ = roc_curve(y_true_j, y_score_j)
        pr, re, _ = precision_recall_curve(y_true_j, y_score_j)

        res_perclass[(j, 'pr')].append(pr)
        res_perclass[(j, 're')].append(re)
        res_perclass[(j, 'fpr')].append(fpr)
        res_perclass[(j, 'tpr')].append(tpr)
        res_perclass[(j, 'auc')].append(auc(fpr, tpr))
        res_perclass[(j, 'ap')].append(average_precision_score(y_true_j, y_score_j))

    # PR / ROC micro-average over all (flattened) one-vs-rest columns
    y_true_flat = onehot_y[test].ravel()
    y_score_flat = prob_pred[test].ravel()

    fpr, tpr, _ = roc_curve(y_true_flat, y_score_flat)
    pr, re, _ = precision_recall_curve(y_true_flat, y_score_flat)

    res_avg['acc'].append(overall_acc)
    res_avg['pr'].append(pr)
    res_avg['re'].append(re)
    res_avg['fpr'].append(fpr)
    res_avg['tpr'].append(tpr)
    res_avg['auc'].append(auc(fpr, tpr))
    res_avg['ap'].append(average_precision_score(y_true_flat, y_score_flat))

## Plot ROC + PR curves per class / CV-fold

In [None]:
# set up plots, cmap
n_classes = np.max(ys) + 1
cmap = get_cmap('Set3', n_classes)
fig_pr, axs_pr = plt.subplots(ncols=2, nrows=1, figsize=(20, 8))
fig_roc, axs_roc = plt.subplots(ncols=2, nrows=1, figsize=(20, 8))


# "random guess" baseline: a random classifier's precision equals the fraction of
# positive entries in the flattened one-hot labels. That is 1/n_classes for >2 classes,
# and the prevalence of class 1 in the 2-class (single column) case.
# The old mean of class fractions was always 0.5 there.
baseline_pr = onehot_y.mean()

for ax in axs_pr:
    ax.plot([0, 1], [baseline_pr, baseline_pr], linestyle='dashed', color='gray', lw=3, label='random guess')
    ax.plot([0, 1, 1], [1, 1, 0], linestyle='dashed', color='darkgreen', lw=3, label='perfect classifier')

for ax in axs_roc:
    ax.plot([0, 1], [0, 1], linestyle='dashed', color='gray', lw=3, label='random guess')
    ax.plot([0, 0, 1], [0, 1, 1], linestyle='dashed', color='darkgreen', lw=3, label='perfect classifier')


def _mean_pm_std(values):
    """Format values as 'mean +- std', both rounded to 2 decimals."""
    return f'{np.round(np.mean(values), 2)} +- {np.round(np.std(values), 2)}'


# curves for micro-averages
x_range = np.linspace(0, 1, 200)
tprs_interp = []
prs_interp = []
for tpr, fpr, pr, re in zip(res_avg['tpr'], res_avg['fpr'], res_avg['pr'], res_avg['re']):

    # plot raw curves for split
    axs_pr[1].plot(re, pr, color='steelblue', alpha=0.5)
    axs_roc[1].plot(fpr, tpr, color='steelblue', alpha=0.5)

    # interpolate curves at defined x locations
    tprs_interp.append(interp(x_range, fpr, tpr))
    # NB: invert re, pr as they start with recall 1, but we go from 0 to 1
    prs_interp.append(interp(x_range, re[::-1], pr[::-1]))

# average interpolated curves
tpr_mean = np.mean(np.array(tprs_interp), axis=0)
pr_mean = np.mean(np.array(prs_interp), axis=0)

# plot interpolated curves + some info
# (fix: the accuracy spread was previously computed from the AUC values)
label_roc = f'micro-average AUC: {_mean_pm_std(res_avg["auc"])}\naccuracy: {_mean_pm_std(res_avg["acc"])}'
label_pr = f'micro-average AP: {_mean_pm_std(res_avg["ap"])}\naccuracy: {_mean_pm_std(res_avg["acc"])}'
axs_roc[1].plot(x_range, tpr_mean, color='steelblue', lw=3, label=label_roc)
axs_pr[1].plot(x_range, pr_mean, color='steelblue', lw=3, label=label_pr)


# curves per class, same procedure as for the micro-average above.
# With 2 classes only one column (the class-1 probabilities) was evaluated,
# so that single curve belongs to class 1 and is labelled accordingly.
n_plotted_classes = n_classes if n_classes > 2 else 1
for j in range(n_plotted_classes):
    tprs_interp = []
    prs_interp = []
    for tpr, fpr, pr, re in zip(res_perclass[(j, 'tpr')], res_perclass[(j, 'fpr')],
                                res_perclass[(j, 'pr')], res_perclass[(j, 're')]):
        tprs_interp.append(interp(x_range, fpr, tpr))
        prs_interp.append(interp(x_range, re[::-1], pr[::-1]))

        axs_roc[0].plot(fpr, tpr, color=cmap(j), alpha=0.5)
        axs_pr[0].plot(re, pr, color=cmap(j), alpha=0.5)

    tpr_mean = np.mean(np.array(tprs_interp), axis=0)
    pr_mean = np.mean(np.array(prs_interp), axis=0)

    class_idx = j if n_classes > 2 else 1
    class_name = le.inverse_transform([class_idx])[0]
    axs_roc[0].plot(x_range, tpr_mean, lw=3, color=cmap(j),
                    label=f'{class_name} AUC: {_mean_pm_std(res_perclass[(j, "auc")])}')
    axs_pr[0].plot(x_range, pr_mean, lw=3, color=cmap(j),
                   label=f'{class_name} AP: {_mean_pm_std(res_perclass[(j, "ap")])}')


# labels, legends, etc.
for ax, title in zip(axs_pr, ['PR curve (per class)', 'PR curve (micro-average)']):
    ax.set_title(title)
    ax.set_xlabel('Recall')
    ax.set_ylabel('Precision')
    ax.legend(fontsize='large')

for ax, title in zip(axs_roc, ['ROC curve (per class)', 'ROC curve (micro-average)']):
    ax.set_title(title)
    ax.set_xlabel('False Positive Rate')
    ax.set_ylabel('True Positive Rate')
    ax.legend(fontsize='large')

plt.rcParams['pdf.fonttype'] = 42
if save_plots:
    fig_pr.savefig(f'/scratch/hoerl/auto_sir_dna_comp/{date_str}_prcurves_{"bright" if only_bright else "all"}replicates.pdf', transparent=True)
    fig_roc.savefig(f'/scratch/hoerl/auto_sir_dna_comp/{date_str}_roccurves_{"bright" if only_bright else "all"}replicates.pdf', transparent=True)

# OLD: Grid search hyperparameters

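Requires `cls` to be a `Pipeline` with a `'cluster_feat'` step (`FeatureAgglomeration`) followed by a `'cls'` classifier step that has an `n_estimators` parameter, e.g. the commented-out `RandomForestClassifier` pipeline in the old cross-validation setup. With the plain `SVC` currently assigned to `cls`, the grid search will not work.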

In [None]:
# candidate values: number of feature clusters (10, 20, ... below the number of
# features) and number of trees (100 to 400 in steps of 50)
n_cluster_grid = np.arange(10, len(feat_names), 10)
n_estimators_grid = np.arange(100, 400 + 1, 50)
n_cluster_grid

In [None]:
# grid search over agglomeration size and forest size; the parameter names require
# `cls` to be a Pipeline with 'cluster_feat' and 'cls' steps
param_grid = {
    'cluster_feat__n_clusters': n_cluster_grid,
    'cls__n_estimators': n_estimators_grid,
}
gs = GridSearchCV(cls, param_grid, cv=cv, n_jobs=-1)
gs.fit(tex_values, ys)

In [None]:
# inspect the best estimator found by the grid search
(gs.best_estimator_)

In [None]:
# cross-validated score (default scorer: accuracy) of the best estimator: (mean, per-fold)
cv_scores = cross_val_score(gs.best_estimator_, tex_values, ys, cv=cv)
(np.mean(cv_scores), cv_scores)

In [None]:
# from here on, use the best estimator from the grid search as classifier
cls = (gs.best_estimator_)