In [1]:
import itertools
import os
import sys
from pathlib import Path

import cv2
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import skimage.io

from collections import defaultdict
from tqdm.auto import tqdm
from joblib import Parallel, delayed
import re
import h5py
import napari
import seaborn as sns

In [2]:
%load_ext autoreload
%autoreload 2

In [3]:
p_dir = (Path().cwd().parents[0]).absolute()

module_path = str(p_dir / "src")

if module_path not in sys.path:
    sys.path.append(module_path)

In [4]:
data_dir = (Path().cwd().parents[0] / 'data').absolute()

In [5]:
# import PPIGraph

# Read per cell data

In [23]:
# Read PPI dfs
PPI_save_path =  data_dir / '9PPI' / 'PPI'

dfs = []
for path in os.listdir(PPI_save_path):
    if 'csv' in path:
        df = pd.read_csv(PPI_save_path / path)
        dfs.append(df)

df = pd.concat(dfs)

# Group by location (Cell vs Nuclei)
g = df.groupby(['Condition', 'FOV', 'PPI', 'Cyto']).size()
df_cell = pd.DataFrame({'Count Cyto': g}).reset_index()
df_cell = df_cell[df_cell.Cyto != 0]
df_cell.columns = ['Condition', 'FOV', 'PPI', 'Id', 'Count_cell']

g = df.groupby(['Condition', 'FOV', 'PPI', 'Nuclei']).size()
df_nuclei = pd.DataFrame({'Count Nuclei': g}).reset_index()
df_nuclei= df_nuclei[df_nuclei.Nuclei != 0]
df_nuclei.columns = ['Condition', 'FOV', 'PPI', 'Id', 'Count_nuclei']

# Concat 
df_all = df_cell.merge(df_nuclei, how='left', on=['Condition', 'FOV', 'PPI', 'Id']).fillna(0)
df_all['Count_nuclei'] = df_all['Count_nuclei'].astype(int)
df_all['Count_cyto'] = df_all['Count_cell'] - df_all['Count_nuclei'] # Cyto count

# Ranme columns
df_all.columns = ['Condition', 'FOV', 'PPI', 'Id', 'Cell', 'Nuclei',
       'Cyto']

# Plot descrition
display(df_all.describe())
display(df_all.groupby(['Condition', 'FOV', 'Id']).sum().describe())

Unnamed: 0,Id,Cell,Nuclei,Cyto
count,10947.0,10947.0,10947.0,10947.0
mean,205.209829,20.432995,6.469078,13.963917
std,124.17242,31.579238,9.234348,26.709337
min,5.0,1.0,0.0,-47.0
25%,99.0,3.0,0.0,1.0
50%,198.0,10.0,3.0,5.0
75%,298.0,23.0,9.0,15.0
max,519.0,706.0,96.0,691.0


  display(df_all.groupby(['Condition', 'FOV', 'Id']).sum().describe())


Unnamed: 0,Cell,Nuclei,Cyto
count,1491.0,1491.0,1491.0
mean,150.020121,47.496311,102.52381
std,85.268889,26.813884,85.771626
min,15.0,0.0,-131.0
25%,91.0,32.0,45.5
50%,134.0,47.0,87.0
75%,187.0,63.0,138.0
max,1132.0,234.0,1112.0


In [24]:
# Filter out by maximum number of counts per cell
min_count = 100
max_count = 400

df_all = df_all.groupby(['Condition', 'FOV', 'Id']).filter(lambda x: x['Cell'].sum() > min_count)
df_all = df_all.groupby(['Condition', 'FOV', 'Id']).filter(lambda x: (x['Cell'] < max_count).all())
# df_all = df_all.groupby(['Condition', 'FOV', 'Id']).filter(lambda x: (x['Cyto'] >= 0).all())

# Plot descrition
display(df_all.describe())
display(df_all.groupby(['Condition', 'FOV', 'Id']).sum().describe())

Unnamed: 0,Id,Cell,Nuclei,Cyto
count,7943.0,7943.0,7943.0,7943.0
mean,205.982752,23.991313,6.577112,17.414201
std,127.745461,34.790571,9.539739,29.076617
min,8.0,1.0,0.0,-40.0
25%,93.0,3.0,0.0,2.0
50%,195.0,13.0,3.0,8.0
75%,300.0,26.0,9.0,19.0
max,519.0,395.0,96.0,380.0


  display(df_all.groupby(['Condition', 'FOV', 'Id']).sum().describe())


Unnamed: 0,Cell,Nuclei,Cyto
count,1047.0,1047.0,1047.0
mean,182.008596,49.896848,132.111748
std,75.522975,27.789169,77.818659
min,101.0,0.0,-130.0
25%,128.0,34.0,80.5
50%,161.0,49.0,117.0
75%,211.0,66.0,163.0
max,703.0,234.0,703.0


# ML model

In [25]:
from sklearn.svm import SVC
from sklearn.ensemble import RandomForestClassifier, AdaBoostClassifier, GradientBoostingClassifier
from sklearn.tree import DecisionTreeClassifier
from sklearn.naive_bayes import GaussianNB
from sklearn import preprocessing, metrics
from sklearn.neural_network import MLPClassifier
from sklearn.linear_model import LogisticRegression

from sklearn.model_selection import train_test_split, KFold
from sklearn.utils.class_weight import compute_class_weight
from sklearn.preprocessing import StandardScaler
import wandb

In [39]:
scaler = StandardScaler()

In [40]:
condition = 'ML'
project_name = f'PLA_9PPI_batch'

In [41]:
models = {
    'Adaboost': AdaBoostClassifier(),
    'DecisionTree': DecisionTreeClassifier(),
    'GradientBoosting' : GradientBoostingClassifier(),
    'NaiveBayes': GaussianNB(),
    'RandomForest': RandomForestClassifier(), 
    'SVM': SVC(probability =True),
    'LogisticRegression':  LogisticRegression(),
    'MLP': MLPClassifier(random_state=1, max_iter=100, hidden_layer_sizes=[16, 16, 16])
}



In [42]:
# Get dataframe seperate into count per cell 
df_count = df_all.pivot(index=['Condition', 'FOV', 'Id'], columns=['PPI']).fillna(0)
df_count_cell = df_count.iloc[:, df_count.columns.get_level_values(0)=='Cell']
df_count_cell = pd.DataFrame(scaler.fit_transform(df_count_cell), index=df_count_cell.index, columns=df_count_cell.columns, )

# Get dataframe seperate into count per nuclei and cyto
df_count_cyto_nuclei = df_count.iloc[:, df_count.columns.get_level_values(0)!='Cell']

# Get Condition into numerical label
le = preprocessing.LabelEncoder()
y = le.fit_transform(df_count.index.get_level_values(0))

In [43]:
df_count_cell

Unnamed: 0_level_0,Unnamed: 1_level_0,Unnamed: 2_level_0,Cell,Cell,Cell,Cell,Cell,Cell,Cell,Cell,Cell,Cell,Cell
Unnamed: 0_level_1,Unnamed: 1_level_1,PPI,Bim & Tom20,Cyclin D1 & CDK2,Cyclin D1 & CDK2 - re,Cyclin E & CDK4,Mcl-1 & BAK,NF-Kb & p-P90rsk,NF-Kb & p-P90rsk - re,P-AKT & mTOR,Sox2 & Oct4,TEAD1 & YAP,p-ERK & c-MYC
Condition,FOV,Id,Unnamed: 3_level_2,Unnamed: 4_level_2,Unnamed: 5_level_2,Unnamed: 6_level_2,Unnamed: 7_level_2,Unnamed: 8_level_2,Unnamed: 9_level_2,Unnamed: 10_level_2,Unnamed: 11_level_2,Unnamed: 12_level_2,Unnamed: 13_level_2
HCC827Ctrl,FW1,13,1.824241,1.951270,1.144715,-0.327841,2.181185,2.520916,-0.521190,-0.270094,2.182217,-0.338654,2.785298
HCC827Ctrl,FW1,14,3.096289,3.672417,0.040091,-0.327841,0.029457,3.384905,-0.521190,-0.270094,3.569752,0.759087,0.481068
HCC827Ctrl,FW1,15,-0.338241,-0.586319,-0.180833,0.075034,0.029457,-0.762241,-0.521190,0.117287,-0.592854,0.210217,1.294325
HCC827Ctrl,FW1,17,-0.465446,-1.049705,-0.180833,-0.327841,-0.687786,-1.107836,-0.521190,-0.270094,-1.082572,0.210217,0.481068
HCC827Ctrl,FW1,18,0.806602,0.296321,-0.622683,-0.327841,-0.687786,0.447343,-0.521190,-0.270094,-0.021516,-0.338654,-0.738819
...,...,...,...,...,...,...,...,...,...,...,...,...,...
HCC827Osim,FW2,336,-0.592650,0.252189,0.040091,-0.327841,-0.687786,-0.416646,-0.521190,0.117287,-0.266375,0.210217,-0.874362
HCC827Osim,FW2,339,1.315422,-0.056735,1.365640,0.477909,0.746700,0.447343,1.029055,-0.270094,-0.021516,-0.338654,-0.061104
HCC827Osim,FW2,343,-0.465446,-0.299461,-0.180833,-0.327841,-0.687786,-0.330247,1.029055,-0.270094,0.957921,-0.338654,-1.009905
HCC827Osim,FW2,344,-0.465446,-0.652517,2.470263,0.075034,-0.687786,-0.330247,-0.521190,-0.270094,-0.756093,-0.338654,-0.467733


In [44]:
## Train on P10 an'P10d evaluate on P21
train = 'FW1'
held = 'FW2'

# Run model on cell count
df_train = df_count_cell[df_count.index.get_level_values('FOV') == train]
df_held = df_count_cell[df_count.index.get_level_values('FOV') == held]
X = df_train.values
y = le.transform(df_train.index.get_level_values(0))
X_held = df_held.values
y_held = le.transform(df_held.index.get_level_values(0))

print(len(X), len(y), len(X_held), len(y_held))
print(np.unique(y, return_counts=True), np.unique(y_held, return_counts=True))

460 460 587 587
(array([0, 1]), array([275, 185], dtype=int64)) (array([0, 1]), array([405, 182], dtype=int64))


In [46]:
## Train on P10 an'P10d evaluate on P21
train = 'FW1'
held = 'FW2'

# Run model on cell count
df_train = df_count_cell[df_count_cell.index.get_level_values('FOV') == train]
df_held = df_count_cell[df_count_cell.index.get_level_values('FOV') == held]

X = df_train.values
# X = scaler.fit_transform(X)
y = le.transform(df_train.index.get_level_values(0))

X_held = df_held.values
# X_held = scaler.fit_transform(X_held)
y_held = le.transform(df_held.index.get_level_values(0))

print(len(X), len(y), len(X_held), len(y_held))

# K fold training
kfold = KFold(n_splits = 3, shuffle = True, random_state = 0)
for model_name, model in models.items():
    for k, (train_index, test_index) in enumerate(kfold.split(X)):
        # Split the dataset
        X_train, X_test = X[train_index], X[test_index]
        y_train, y_test = y[train_index], y[test_index] 

        model.fit(X_train, y_train)
        y_pred = model.predict(X_test)
        y_probas = model.predict_proba(X_test)
        y_pred_held = model.predict(X_held)
        y_probas_held = model.predict_proba(X_held)

        run = wandb.init(project=project_name, group=model_name+'_'+train, name=model_name+f'_{train}_{k}')

        accuracy = metrics.accuracy_score(y_test, y_pred)
        accuracy_held = metrics.accuracy_score(y_held, y_pred_held)
        f1 = metrics.f1_score(y_test, y_pred)
        f1_held = metrics.f1_score(y_held, y_pred_held)
        # auc = metrics.roc_auc_score(y_test, y_probas[:, 1])
        # auc_held = metrics.roc_auc_score(y_held, y_probas_held[:, 1])
        auc = metrics.roc_auc_score(y_test, y_pred)
        auc_held = metrics.roc_auc_score(y_held, y_pred_held)
        wandb.log({"accuracy": accuracy, 'accuracy_held': accuracy_held, 'f1':f1, 'f1_held': f1_held, 'auc': auc, 'auc_held': auc_held})
    run.finish()

460 460 587 587


Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Currently logged in as: [33mthoomas[0m. Use [1m`wandb login --relogin`[0m to force relogin


VBox(children=(Label(value='Waiting for wandb.init()...\r'), FloatProgress(value=0.016666666666666666, max=1.0…

VBox(children=(Label(value='0.001 MB of 0.011 MB uploaded (0.000 MB deduped)\r'), FloatProgress(value=0.135943…

0,1
accuracy,▁
accuracy_held,▁
auc,▁
auc_held,▁
f1,▁
f1_held,▁

0,1
accuracy,0.74675
accuracy_held,0.74617
auc,0.73142
auc_held,0.72378
f1,0.67769
f1_held,0.61893


VBox(children=(Label(value='Waiting for wandb.init()...\r'), FloatProgress(value=0.016916666666899498, max=1.0…

VBox(children=(Label(value='0.001 MB of 0.011 MB uploaded (0.000 MB deduped)\r'), FloatProgress(value=0.135919…

0,1
accuracy,▁
accuracy_held,▁
auc,▁
auc_held,▁
f1,▁
f1_held,▁

0,1
accuracy,0.7451
accuracy_held,0.70187
auc,0.73766
auc_held,0.68109
f1,0.67769
f1_held,0.56576


VBox(children=(Label(value='Waiting for wandb.init()...\r'), FloatProgress(value=0.01693333333338766, max=1.0)…

VBox(children=(Label(value='0.001 MB of 0.011 MB uploaded (0.000 MB deduped)\r'), FloatProgress(value=0.095308…

0,1
accuracy,▁
accuracy_held,▁
auc,▁
auc_held,▁
f1,▁
f1_held,▁

0,1
accuracy,0.72549
accuracy_held,0.74617
auc,0.70952
auc_held,0.74042
f1,0.65
f1_held,0.63923


VBox(children=(Label(value='Waiting for wandb.init()...\r'), FloatProgress(value=0.016666666666666666, max=1.0…

VBox(children=(Label(value='0.001 MB of 0.011 MB uploaded (0.000 MB deduped)\r'), FloatProgress(value=0.095300…

0,1
accuracy,▁
accuracy_held,▁
auc,▁
auc_held,▁
f1,▁
f1_held,▁

0,1
accuracy,0.67532
accuracy_held,0.6678
auc,0.66354
auc_held,0.63824
f1,0.60317
f1_held,0.51128


VBox(children=(Label(value='Waiting for wandb.init()...\r'), FloatProgress(value=0.016666666666666666, max=1.0…

VBox(children=(Label(value='0.001 MB of 0.011 MB uploaded (0.000 MB deduped)\r'), FloatProgress(value=0.095308…

0,1
accuracy,▁
accuracy_held,▁
auc,▁
auc_held,▁
f1,▁
f1_held,▁

0,1
accuracy,0.65359
accuracy_held,0.67973
auc,0.65054
auc_held,0.66958
f1,0.58268
f1_held,0.5545


VBox(children=(Label(value='Waiting for wandb.init()...\r'), FloatProgress(value=0.016666666666666666, max=1.0…

VBox(children=(Label(value='0.001 MB of 0.011 MB uploaded (0.000 MB deduped)\r'), FloatProgress(value=0.095383…

0,1
accuracy,▁
accuracy_held,▁
auc,▁
auc_held,▁
f1,▁
f1_held,▁

0,1
accuracy,0.71242
accuracy_held,0.63884
auc,0.70556
auc_held,0.6339
f1,0.65625
f1_held,0.51598


VBox(children=(Label(value='Waiting for wandb.init()...\r'), FloatProgress(value=0.016666666666666666, max=1.0…

VBox(children=(Label(value='0.001 MB of 0.011 MB uploaded (0.000 MB deduped)\r'), FloatProgress(value=0.135931…

0,1
accuracy,▁
accuracy_held,▁
auc,▁
auc_held,▁
f1,▁
f1_held,▁

0,1
accuracy,0.75325
accuracy_held,0.75468
auc,0.72569
auc_held,0.72541
f1,0.65455
f1_held,0.62105


VBox(children=(Label(value='Waiting for wandb.init()...\r'), FloatProgress(value=0.01693333333338766, max=1.0)…

VBox(children=(Label(value='0.003 MB of 0.011 MB uploaded (0.000 MB deduped)\r'), FloatProgress(value=0.299651…

0,1
accuracy,▁
accuracy_held,▁
auc,▁
auc_held,▁
f1,▁
f1_held,▁

0,1
accuracy,0.76471
accuracy_held,0.74276
auc,0.75009
auc_held,0.72887
f1,0.68966
f1_held,0.62531


VBox(children=(Label(value='Waiting for wandb.init()...\r'), FloatProgress(value=0.01693333333338766, max=1.0)…

VBox(children=(Label(value='0.001 MB of 0.011 MB uploaded (0.000 MB deduped)\r'), FloatProgress(value=0.095316…

0,1
accuracy,▁
accuracy_held,▁
auc,▁
auc_held,▁
f1,▁
f1_held,▁

0,1
accuracy,0.79085
accuracy_held,0.73595
auc,0.78889
auc_held,0.71939
f1,0.75385
f1_held,0.61347


VBox(children=(Label(value='Waiting for wandb.init()...\r'), FloatProgress(value=0.01693333333338766, max=1.0)…

VBox(children=(Label(value='0.001 MB of 0.011 MB uploaded (0.000 MB deduped)\r'), FloatProgress(value=0.135919…

0,1
accuracy,▁
accuracy_held,▁
auc,▁
auc_held,▁
f1,▁
f1_held,▁

0,1
accuracy,0.7013
accuracy_held,0.73254
auc,0.6474
auc_held,0.63373
f1,0.47727
f1_held,0.46416


VBox(children=(Label(value='Waiting for wandb.init()...\r'), FloatProgress(value=0.01693333333338766, max=1.0)…

VBox(children=(Label(value='0.001 MB of 0.001 MB uploaded (0.000 MB deduped)\r'), FloatProgress(value=1.0, max…

0,1
accuracy,▁
accuracy_held,▁
auc,▁
auc_held,▁
f1,▁
f1_held,▁

0,1
accuracy,0.69281
accuracy_held,0.73765
auc,0.63512
auc_held,0.65861
f1,0.49462
f1_held,0.51572


VBox(children=(Label(value='Waiting for wandb.init()...\r'), FloatProgress(value=0.016916666666899498, max=1.0…

VBox(children=(Label(value='0.001 MB of 0.011 MB uploaded (0.000 MB deduped)\r'), FloatProgress(value=0.095300…

0,1
accuracy,▁
accuracy_held,▁
auc,▁
auc_held,▁
f1,▁
f1_held,▁

0,1
accuracy,0.69281
accuracy_held,0.73595
auc,0.66508
auc_held,0.67553
f1,0.57658
f1_held,0.5481


VBox(children=(Label(value='Waiting for wandb.init()...\r'), FloatProgress(value=0.01693333333338766, max=1.0)…

VBox(children=(Label(value='0.003 MB of 0.011 MB uploaded (0.000 MB deduped)\r'), FloatProgress(value=0.299651…

0,1
accuracy,▁
accuracy_held,▁
auc,▁
auc_held,▁
f1,▁
f1_held,▁

0,1
accuracy,0.77273
accuracy_held,0.77342
auc,0.74913
auc_held,0.74353
f1,0.69027
f1_held,0.64533


VBox(children=(Label(value='Waiting for wandb.init()...\r'), FloatProgress(value=0.01693333333338766, max=1.0)…

VBox(children=(Label(value='0.003 MB of 0.011 MB uploaded (0.000 MB deduped)\r'), FloatProgress(value=0.299677…

0,1
accuracy,▁
accuracy_held,▁
auc,▁
auc_held,▁
f1,▁
f1_held,▁

0,1
accuracy,0.77778
accuracy_held,0.74617
auc,0.76397
auc_held,0.72075
f1,0.7069
f1_held,0.61499


VBox(children=(Label(value='Waiting for wandb.init()...\r'), FloatProgress(value=0.016666666666666666, max=1.0…

VBox(children=(Label(value='0.001 MB of 0.011 MB uploaded (0.000 MB deduped)\r'), FloatProgress(value=0.095408…

0,1
accuracy,▁
accuracy_held,▁
auc,▁
auc_held,▁
f1,▁
f1_held,▁

0,1
accuracy,0.77124
accuracy_held,0.76491
auc,0.76032
auc_held,0.72828
f1,0.71545
f1_held,0.625


VBox(children=(Label(value='Waiting for wandb.init()...\r'), FloatProgress(value=0.01691666666592937, max=1.0)…

0,1
accuracy,▁
accuracy_held,▁
auc,▁
auc_held,▁
f1,▁
f1_held,▁

0,1
accuracy,0.81818
accuracy_held,0.78365
auc,0.79028
auc_held,0.75245
f1,0.74074
f1_held,0.65768


VBox(children=(Label(value='Waiting for wandb.init()...\r'), FloatProgress(value=0.016666666666666666, max=1.0…

VBox(children=(Label(value='0.001 MB of 0.011 MB uploaded (0.000 MB deduped)\r'), FloatProgress(value=0.095308…

0,1
accuracy,▁
accuracy_held,▁
auc,▁
auc_held,▁
f1,▁
f1_held,▁

0,1
accuracy,0.78431
accuracy_held,0.7615
auc,0.76252
auc_held,0.74094
f1,0.7027
f1_held,0.64103


VBox(children=(Label(value='Waiting for wandb.init()...\r'), FloatProgress(value=0.016666666666666666, max=1.0…

VBox(children=(Label(value='0.001 MB of 0.011 MB uploaded (0.000 MB deduped)\r'), FloatProgress(value=0.095292…

0,1
accuracy,▁
accuracy_held,▁
auc,▁
auc_held,▁
f1,▁
f1_held,▁

0,1
accuracy,0.76471
accuracy_held,0.75298
auc,0.74048
auc_held,0.72418
f1,0.67857
f1_held,0.61942


VBox(children=(Label(value='Waiting for wandb.init()...\r'), FloatProgress(value=0.01693333333338766, max=1.0)…

VBox(children=(Label(value='0.001 MB of 0.011 MB uploaded (0.000 MB deduped)\r'), FloatProgress(value=0.095300…

0,1
accuracy,▁
accuracy_held,▁
auc,▁
auc_held,▁
f1,▁
f1_held,▁

0,1
accuracy,0.78571
accuracy_held,0.7632
auc,0.76476
auc_held,0.74066
f1,0.71304
f1_held,0.64083


VBox(children=(Label(value='Waiting for wandb.init()...\r'), FloatProgress(value=0.01693333333338766, max=1.0)…

0,1
accuracy,▁
accuracy_held,▁
auc,▁
auc_held,▁
f1,▁
f1_held,▁

0,1
accuracy,0.76471
accuracy_held,0.71891
auc,0.74673
auc_held,0.71461
f1,0.68421
f1_held,0.60808


VBox(children=(Label(value='Waiting for wandb.init()...\r'), FloatProgress(value=0.016666666666666666, max=1.0…

VBox(children=(Label(value='0.001 MB of 0.011 MB uploaded (0.000 MB deduped)\r'), FloatProgress(value=0.095300…

0,1
accuracy,▁
accuracy_held,▁
auc,▁
auc_held,▁
f1,▁
f1_held,▁

0,1
accuracy,0.75163
accuracy_held,0.74617
auc,0.74127
auc_held,0.73739
f1,0.69355
f1_held,0.6357




VBox(children=(Label(value='Waiting for wandb.init()...\r'), FloatProgress(value=0.016666666666666666, max=1.0…



0,1
accuracy,▁
accuracy_held,▁
auc,▁
auc_held,▁
f1,▁
f1_held,▁

0,1
accuracy,0.83117
accuracy_held,0.75809
auc,0.81042
auc_held,0.73242
f1,0.77193
f1_held,0.63021


VBox(children=(Label(value='Waiting for wandb.init()...\r'), FloatProgress(value=0.016666666666666666, max=1.0…



0,1
accuracy,▁
accuracy_held,▁
auc,▁
auc_held,▁
f1,▁
f1_held,▁

0,1
accuracy,0.77124
accuracy_held,0.73083
auc,0.75535
auc_held,0.7293
f1,0.69565
f1_held,0.62559


VBox(children=(Label(value='Waiting for wandb.init()...\r'), FloatProgress(value=0.016666666666666666, max=1.0…

VBox(children=(Label(value='0.001 MB of 0.011 MB uploaded (0.000 MB deduped)\r'), FloatProgress(value=0.095221…

0,1
accuracy,▁
accuracy_held,▁
auc,▁
auc_held,▁
f1,▁
f1_held,▁

0,1
accuracy,0.73203
accuracy_held,0.74787
auc,0.70794
auc_held,0.73257
f1,0.63717
f1_held,0.63


In [47]:
## Train on P10 an'P10d evaluate on P21
train = 'FW2'
held = 'FW1'

# Run model on cell count
df_train = df_count_cell[df_count_cell.index.get_level_values('FOV') == train]
df_held = df_count_cell[df_count_cell.index.get_level_values('FOV') == held]

X = df_train.values
# X = scaler.fit_transform(X)
y = le.transform(df_train.index.get_level_values(0))

X_held = df_held.values
# X_held = scaler.fit_transform(X_held)
y_held = le.transform(df_held.index.get_level_values(0))

print(len(X), len(y), len(X_held), len(y_held))

# K fold training
kfold = KFold(n_splits = 3, shuffle = True, random_state = 0)
for model_name, model in models.items():
    for k, (train_index, test_index) in enumerate(kfold.split(X)):
        # Split the dataset
        X_train, X_test = X[train_index], X[test_index]
        y_train, y_test = y[train_index], y[test_index] 

        model.fit(X_train, y_train)
        y_pred = model.predict(X_test)
        y_probas = model.predict_proba(X_test)
        y_pred_held = model.predict(X_held)
        y_probas_held = model.predict_proba(X_held)

        run = wandb.init(project=project_name, group=model_name+'_'+train, name=model_name+f'_{train}_{k}')

        accuracy = metrics.accuracy_score(y_test, y_pred)
        accuracy_held = metrics.accuracy_score(y_held, y_pred_held)
        f1 = metrics.f1_score(y_test, y_pred)
        f1_held = metrics.f1_score(y_held, y_pred_held)
        # auc = metrics.roc_auc_score(y_test, y_probas[:, 1])
        # auc_held = metrics.roc_auc_score(y_held, y_probas_held[:, 1])
        auc = metrics.roc_auc_score(y_test, y_pred)
        auc_held = metrics.roc_auc_score(y_held, y_pred_held)
        wandb.log({"accuracy": accuracy, 'accuracy_held': accuracy_held, 'f1':f1, 'f1_held': f1_held, 'auc': auc, 'auc_held': auc_held})
    run.finish()

587 587 460 460


VBox(children=(Label(value='0.001 MB of 0.011 MB uploaded (0.000 MB deduped)\r'), FloatProgress(value=0.095229…

0,1
accuracy,▁
accuracy_held,▁
auc,▁
auc_held,▁
f1,▁
f1_held,▁

0,1
accuracy,0.7602
accuracy_held,0.7413
auc,0.6903
auc_held,0.69784
f1,0.56881
f1_held,0.59661


VBox(children=(Label(value='Waiting for wandb.init()...\r'), FloatProgress(value=0.016666666666666666, max=1.0…

VBox(children=(Label(value='0.001 MB of 0.011 MB uploaded (0.000 MB deduped)\r'), FloatProgress(value=0.135919…

0,1
accuracy,▁
accuracy_held,▁
auc,▁
auc_held,▁
f1,▁
f1_held,▁

0,1
accuracy,0.77041
accuracy_held,0.73478
auc,0.72157
auc_held,0.69061
f1,0.6281
f1_held,0.58503


VBox(children=(Label(value='Waiting for wandb.init()...\r'), FloatProgress(value=0.016916666666899498, max=1.0…

VBox(children=(Label(value='0.001 MB of 0.011 MB uploaded (0.000 MB deduped)\r'), FloatProgress(value=0.095325…

0,1
accuracy,▁
accuracy_held,▁
auc,▁
auc_held,▁
f1,▁
f1_held,▁

0,1
accuracy,0.7641
accuracy_held,0.7413
auc,0.67838
auc_held,0.70138
f1,0.53061
f1_held,0.60726


VBox(children=(Label(value='Waiting for wandb.init()...\r'), FloatProgress(value=0.016666666666666666, max=1.0…

VBox(children=(Label(value='0.001 MB of 0.011 MB uploaded (0.000 MB deduped)\r'), FloatProgress(value=0.095300…

0,1
accuracy,▁
accuracy_held,▁
auc,▁
auc_held,▁
f1,▁
f1_held,▁

0,1
accuracy,0.70408
accuracy_held,0.68478
auc,0.66659
auc_held,0.65499
f1,0.54688
f1_held,0.56193


VBox(children=(Label(value='Waiting for wandb.init()...\r'), FloatProgress(value=0.01693333333338766, max=1.0)…

VBox(children=(Label(value='0.001 MB of 0.001 MB uploaded (0.000 MB deduped)\r'), FloatProgress(value=1.0, max…

0,1
accuracy,▁
accuracy_held,▁
auc,▁
auc_held,▁
f1,▁
f1_held,▁

0,1
accuracy,0.69388
accuracy_held,0.70217
auc,0.64191
auc_held,0.66069
f1,0.51613
f1_held,0.54785


VBox(children=(Label(value='Waiting for wandb.init()...\r'), FloatProgress(value=0.01693333333338766, max=1.0)…

VBox(children=(Label(value='0.001 MB of 0.001 MB uploaded (0.000 MB deduped)\r'), FloatProgress(value=1.0, max…

0,1
accuracy,▁
accuracy_held,▁
auc,▁
auc_held,▁
f1,▁
f1_held,▁

0,1
accuracy,0.73333
accuracy_held,0.7
auc,0.68682
auc_held,0.67125
f1,0.54386
f1_held,0.58434


VBox(children=(Label(value='Waiting for wandb.init()...\r'), FloatProgress(value=0.016666666666666666, max=1.0…

VBox(children=(Label(value='0.001 MB of 0.011 MB uploaded (0.000 MB deduped)\r'), FloatProgress(value=0.135844…

0,1
accuracy,▁
accuracy_held,▁
auc,▁
auc_held,▁
f1,▁
f1_held,▁

0,1
accuracy,0.77041
accuracy_held,0.73696
auc,0.70643
auc_held,0.69774
f1,0.59459
f1_held,0.60328


VBox(children=(Label(value='Waiting for wandb.init()...\r'), FloatProgress(value=0.01691666666592937, max=1.0)…

VBox(children=(Label(value='0.003 MB of 0.011 MB uploaded (0.000 MB deduped)\r'), FloatProgress(value=0.299703…

0,1
accuracy,▁
accuracy_held,▁
auc,▁
auc_held,▁
f1,▁
f1_held,▁

0,1
accuracy,0.77551
accuracy_held,0.73043
auc,0.73262
auc_held,0.68875
f1,0.64516
f1_held,0.58667


VBox(children=(Label(value='Waiting for wandb.init()...\r'), FloatProgress(value=0.01693333333338766, max=1.0)…

VBox(children=(Label(value='0.001 MB of 0.011 MB uploaded (0.000 MB deduped)\r'), FloatProgress(value=0.095308…

0,1
accuracy,▁
accuracy_held,▁
auc,▁
auc_held,▁
f1,▁
f1_held,▁

0,1
accuracy,0.80513
accuracy_held,0.7413
auc,0.7302
auc_held,0.69872
f1,0.61224
f1_held,0.59933


VBox(children=(Label(value='Waiting for wandb.init()...\r'), FloatProgress(value=0.016916666666899498, max=1.0…

VBox(children=(Label(value='0.001 MB of 0.011 MB uploaded (0.000 MB deduped)\r'), FloatProgress(value=0.095292…

0,1
accuracy,▁
accuracy_held,▁
auc,▁
auc_held,▁
f1,▁
f1_held,▁

0,1
accuracy,0.71939
accuracy_held,0.67826
auc,0.72545
auc_held,0.67961
f1,0.62585
f1_held,0.63184


VBox(children=(Label(value='Waiting for wandb.init()...\r'), FloatProgress(value=0.01691666666592937, max=1.0)…

VBox(children=(Label(value='0.001 MB of 0.011 MB uploaded (0.000 MB deduped)\r'), FloatProgress(value=0.095308…

0,1
accuracy,▁
accuracy_held,▁
auc,▁
auc_held,▁
f1,▁
f1_held,▁

0,1
accuracy,0.68878
accuracy_held,0.74348
auc,0.70259
auc_held,0.73504
f1,0.62112
f1_held,0.68449


VBox(children=(Label(value='Waiting for wandb.init()...\r'), FloatProgress(value=0.016666666666666666, max=1.0…

VBox(children=(Label(value='0.001 MB of 0.001 MB uploaded (0.000 MB deduped)\r'), FloatProgress(value=1.0, max…

0,1
accuracy,▁
accuracy_held,▁
auc,▁
auc_held,▁
f1,▁
f1_held,▁

0,1
accuracy,0.7641
accuracy_held,0.75217
auc,0.70795
auc_held,0.71931
f1,0.57407
f1_held,0.64151


VBox(children=(Label(value='Waiting for wandb.init()...\r'), FloatProgress(value=0.01693333333338766, max=1.0)…

VBox(children=(Label(value='0.003 MB of 0.011 MB uploaded (0.000 MB deduped)\r'), FloatProgress(value=0.299677…

0,1
accuracy,▁
accuracy_held,▁
auc,▁
auc_held,▁
f1,▁
f1_held,▁

0,1
accuracy,0.76531
accuracy_held,0.72826
auc,0.66803
auc_held,0.67631
f1,0.52083
f1_held,0.54874


VBox(children=(Label(value='Waiting for wandb.init()...\r'), FloatProgress(value=0.016666666666666666, max=1.0…

VBox(children=(Label(value='0.003 MB of 0.011 MB uploaded (0.000 MB deduped)\r'), FloatProgress(value=0.299677…

0,1
accuracy,▁
accuracy_held,▁
auc,▁
auc_held,▁
f1,▁
f1_held,▁

0,1
accuracy,0.76531
accuracy_held,0.71087
auc,0.72128
auc_held,0.66088
f1,0.62903
f1_held,0.53004


VBox(children=(Label(value='Waiting for wandb.init()...\r'), FloatProgress(value=0.01691666666592937, max=1.0)…

VBox(children=(Label(value='0.001 MB of 0.011 MB uploaded (0.000 MB deduped)\r'), FloatProgress(value=0.095292…

0,1
accuracy,▁
accuracy_held,▁
auc,▁
auc_held,▁
f1,▁
f1_held,▁

0,1
accuracy,0.8359
accuracy_held,0.73913
auc,0.78089
auc_held,0.69248
f1,0.68627
f1_held,0.58333


VBox(children=(Label(value='Waiting for wandb.init()...\r'), FloatProgress(value=0.01693333333338766, max=1.0)…

VBox(children=(Label(value='0.001 MB of 0.011 MB uploaded (0.000 MB deduped)\r'), FloatProgress(value=0.095466…

0,1
accuracy,▁
accuracy_held,▁
auc,▁
auc_held,▁
f1,▁
f1_held,▁

0,1
accuracy,0.77041
accuracy_held,0.75
auc,0.6761
auc_held,0.69803
f1,0.53608
f1_held,0.58182


VBox(children=(Label(value='Waiting for wandb.init()...\r'), FloatProgress(value=0.016666666666666666, max=1.0…

VBox(children=(Label(value='0.001 MB of 0.011 MB uploaded (0.000 MB deduped)\r'), FloatProgress(value=0.095308…

0,1
accuracy,▁
accuracy_held,▁
auc,▁
auc_held,▁
f1,▁
f1_held,▁

0,1
accuracy,0.80102
accuracy_held,0.74348
auc,0.73765
auc_held,0.6917
f1,0.64865
f1_held,0.57246


VBox(children=(Label(value='Waiting for wandb.init()...\r'), FloatProgress(value=0.01693333333338766, max=1.0)…

VBox(children=(Label(value='0.001 MB of 0.001 MB uploaded (0.000 MB deduped)\r'), FloatProgress(value=1.0, max…

0,1
accuracy,▁
accuracy_held,▁
auc,▁
auc_held,▁
f1,▁
f1_held,▁

0,1
accuracy,0.81538
accuracy_held,0.73696
auc,0.70768
auc_held,0.68624
f1,0.5814
f1_held,0.56631


VBox(children=(Label(value='Waiting for wandb.init()...\r'), FloatProgress(value=0.016666666666666666, max=1.0…

VBox(children=(Label(value='0.001 MB of 0.011 MB uploaded (0.000 MB deduped)\r'), FloatProgress(value=0.095300…

0,1
accuracy,▁
accuracy_held,▁
auc,▁
auc_held,▁
f1,▁
f1_held,▁

0,1
accuracy,0.79082
accuracy_held,0.75652
auc,0.72135
auc_held,0.71145
f1,0.61682
f1_held,0.61379


VBox(children=(Label(value='Waiting for wandb.init()...\r'), FloatProgress(value=0.016666666666666666, max=1.0…

VBox(children=(Label(value='0.001 MB of 0.001 MB uploaded (0.000 MB deduped)\r'), FloatProgress(value=1.0, max…

0,1
accuracy,▁
accuracy_held,▁
auc,▁
auc_held,▁
f1,▁
f1_held,▁

0,1
accuracy,0.79592
accuracy_held,0.75
auc,0.75888
auc_held,0.70688
f1,0.68254
f1_held,0.61017


VBox(children=(Label(value='Waiting for wandb.init()...\r'), FloatProgress(value=0.016666666666666666, max=1.0…

VBox(children=(Label(value='0.001 MB of 0.011 MB uploaded (0.000 MB deduped)\r'), FloatProgress(value=0.095416…

0,1
accuracy,▁
accuracy_held,▁
auc,▁
auc_held,▁
f1,▁
f1_held,▁

0,1
accuracy,0.81538
accuracy_held,0.74565
auc,0.74907
auc_held,0.70059
f1,0.64
f1_held,0.59794




VBox(children=(Label(value='Waiting for wandb.init()...\r'), FloatProgress(value=0.016666666666666666, max=1.0…



0,1
accuracy,▁
accuracy_held,▁
auc,▁
auc_held,▁
f1,▁
f1_held,▁

0,1
accuracy,0.79592
accuracy_held,0.7587
auc,0.73375
auc_held,0.71592
f1,0.63636
f1_held,0.62373


VBox(children=(Label(value='Waiting for wandb.init()...\r'), FloatProgress(value=0.016916666666899498, max=1.0…



0,1
accuracy,▁
accuracy_held,▁
auc,▁
auc_held,▁
f1,▁
f1_held,▁

0,1
accuracy,0.81122
accuracy_held,0.73913
auc,0.75975
auc_held,0.69514
f1,0.68376
f1_held,0.59184


VBox(children=(Label(value='Waiting for wandb.init()...\r'), FloatProgress(value=0.01691666666592937, max=1.0)…

VBox(children=(Label(value='0.001 MB of 0.011 MB uploaded (0.000 MB deduped)\r'), FloatProgress(value=0.095300…

0,1
accuracy,▁
accuracy_held,▁
auc,▁
auc_held,▁
f1,▁
f1_held,▁

0,1
accuracy,0.86667
accuracy_held,0.76087
auc,0.80793
auc_held,0.71862
f1,0.73469
f1_held,0.62838
