In [7]:
import itertools
import os
import sys
from pathlib import Path

import cv2
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import skimage.io

from collections import defaultdict
from tqdm.auto import tqdm
from joblib import Parallel, delayed
import re
import h5py
import napari
import seaborn as sns

In [8]:
%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [9]:
p_dir = (Path().cwd().parents[0]).absolute()

module_path = str(p_dir / "src")

if module_path not in sys.path:
    sys.path.append(module_path)

In [10]:
data_dir = (Path().cwd().parents[0] / 'data').absolute()

In [11]:
# import PPIGraph

# Read per cell data

In [13]:
# Read PPI dfs
PPI_save_path =  Path(r'Y:\coskun-lab\Thomas\15_PLA\data\OCT Cell Culture\Whole\PPI')

dfs = []
for path in os.listdir(PPI_save_path):
    if 'csv' in path:
        df = pd.read_csv(PPI_save_path / path)
        dfs.append(df)

df = pd.concat(dfs)

# Group by location (Cell vs Nuclei)
g = df.groupby(['Condition', 'FOV', 'PPI', 'Cyto']).size()
df_cell = pd.DataFrame({'Count Cyto': g}).reset_index()
df_cell = df_cell[df_cell.Cyto != 0]
df_cell.columns = ['Condition', 'FOV', 'PPI', 'Id', 'Count_cell']

g = df.groupby(['Condition', 'FOV', 'PPI', 'Nuclei']).size()
df_nuclei = pd.DataFrame({'Count Nuclei': g}).reset_index()
df_nuclei= df_nuclei[df_nuclei.Nuclei != 0]
df_nuclei.columns = ['Condition', 'FOV', 'PPI', 'Id', 'Count_nuclei']

# Concat 
df_all = df_cell.merge(df_nuclei, how='left', on=['Condition', 'FOV', 'PPI', 'Id']).fillna(0)
df_all['Count_nuclei'] = df_all['Count_nuclei'].astype(int)
df_all['Count_cyto'] = df_all['Count_cell'] - df_all['Count_nuclei'] # Cyto count

# Ranme columns
df_all.columns = ['Condition', 'FOV', 'PPI', 'Id', 'Cell', 'Nuclei',
       'Cyto']

# Plot descrition
display(df_all.describe())
display(df_all.groupby(['Condition', 'FOV', 'Id']).sum().describe())

Unnamed: 0,Id,Cell,Nuclei,Cyto
count,9655.0,9655.0,9655.0,9655.0
mean,332.755256,18.820404,8.390989,10.429415
std,180.615431,29.040642,10.916339,21.016527
min,2.0,1.0,0.0,-7.0
25%,176.0,4.0,2.0,2.0
50%,326.0,11.0,5.0,5.0
75%,475.0,22.0,11.0,11.0
max,778.0,874.0,202.0,774.0


  display(df_all.groupby(['Condition', 'FOV', 'Id']).sum().describe())


Unnamed: 0,Cell,Nuclei,Cyto
count,2081.0,2081.0,2081.0
mean,87.319077,38.930802,48.388275
std,78.104479,29.733528,57.558508
min,1.0,0.0,-13.0
25%,41.0,21.0,17.0
50%,68.0,33.0,33.0
75%,112.0,51.0,61.0
max,997.0,336.0,842.0


In [14]:
# Filter out by maximum number of counts per cell
min_count = 20
max_count = 70

df_all = df_all.groupby(['Condition', 'FOV', 'Id']).filter(lambda x: x['Cell'].sum() > min_count)
df_all = df_all.groupby(['Condition', 'FOV', 'Id']).filter(lambda x: (x['Cell'] < max_count).all())
# df_all = df_all.groupby(['Condition', 'FOV', 'Id']).filter(lambda x: (x['Cyto'] >= 0).all())

# Plot descrition
display(df_all.describe())
display(df_all.groupby(['Condition', 'FOV', 'Id']).sum().describe())

Unnamed: 0,Id,Cell,Nuclei,Cyto
count,7129.0,7129.0,7129.0,7129.0
mean,339.626876,14.332866,7.021321,7.311544
std,181.049865,13.471075,6.836728,8.711061
min,20.0,1.0,0.0,-5.0
25%,185.0,5.0,2.0,2.0
50%,329.0,10.0,5.0,4.0
75%,481.0,19.0,10.0,10.0
max,752.0,69.0,46.0,64.0


  display(df_all.groupby(['Condition', 'FOV', 'Id']).sum().describe())


Unnamed: 0,Cell,Nuclei,Cyto
count,1510.0,1510.0,1510.0
mean,67.668212,33.149007,34.519205
std,31.996032,16.639663,24.150085
min,21.0,0.0,1.0
25%,43.0,22.0,17.0
50%,63.0,31.0,29.0
75%,87.0,43.0,45.0
max,292.0,143.0,246.0


# ML model

In [15]:
from sklearn.svm import SVC
from sklearn.ensemble import RandomForestClassifier, AdaBoostClassifier, GradientBoostingClassifier
from sklearn.tree import DecisionTreeClassifier
from sklearn.naive_bayes import GaussianNB
from sklearn import preprocessing, metrics
from sklearn.neural_network import MLPClassifier
from sklearn.linear_model import LogisticRegression

from sklearn.model_selection import train_test_split, KFold
from sklearn.utils.class_weight import compute_class_weight
from sklearn.preprocessing import StandardScaler
import wandb

In [16]:
scaler = StandardScaler()

In [17]:
condition = 'ML'
project_name = f'PLA_5PPI_batch'

In [18]:
models = {
    'Adaboost': AdaBoostClassifier(),
    'DecisionTree': DecisionTreeClassifier(),
    'GradientBoosting' : GradientBoostingClassifier(),
    'NaiveBayes': GaussianNB(),
    'RandomForest': RandomForestClassifier(), 
    'SVM': SVC(probability =True),
    'LogisticRegression':  LogisticRegression(),
    'MLP': MLPClassifier(random_state=1, max_iter=100, hidden_layer_sizes=[16, 16, 16])
}



In [19]:
# Get dataframe seperate into count per cell 
df_count = df_all.pivot(index=['Condition', 'FOV', 'Id'], columns=['PPI']).fillna(0)
df_count_cell = df_count.iloc[:, df_count.columns.get_level_values(0)=='Cell']
df_count_cell = pd.DataFrame(scaler.fit_transform(df_count_cell), index=df_count_cell.index, columns=df_count_cell.columns, )

# Get dataframe seperate into count per nuclei and cyto
df_count_cyto_nuclei = df_count.iloc[:, df_count.columns.get_level_values(0)!='Cell']

# Get Condition into numerical label
le = preprocessing.LabelEncoder()
y = le.fit_transform(df_count.index.get_level_values(0))

In [20]:
df_count_cell

Unnamed: 0_level_0,Unnamed: 1_level_0,Unnamed: 2_level_0,Cell,Cell,Cell,Cell,Cell
Unnamed: 0_level_1,Unnamed: 1_level_1,PPI,CylinE & CDK2,Mcl-1 & BAK,P-ERK & c-MYC,TEAD1 & YAP1,p-AKT & mTOR
Condition,FOV,Id,Unnamed: 3_level_2,Unnamed: 4_level_2,Unnamed: 5_level_2,Unnamed: 6_level_2,Unnamed: 7_level_2
HCC827Ctrl,FW1,52,-0.984185,-0.372357,-0.034931,0.002731,-0.295713
HCC827Ctrl,FW1,53,0.556125,-0.544249,0.970306,-0.924964,-0.295713
HCC827Ctrl,FW1,54,1.396295,-0.372357,1.206832,0.002731,-0.450595
HCC827Ctrl,FW1,55,-0.144016,-0.544249,-0.921904,-0.821887,-0.605478
HCC827Ctrl,FW1,56,-0.704129,-0.200465,1.265964,0.002731,0.478700
...,...,...,...,...,...,...,...
HCC827Osim,FW2,519,-0.844157,-0.544249,-0.567115,0.002731,0.633582
HCC827Osim,FW2,520,-0.564100,-0.028573,0.556385,-0.409578,-0.915243
HCC827Osim,FW2,523,-1.404270,-0.200465,-1.099299,-1.234196,-0.450595
HCC827Osim,FW2,524,-1.264241,-0.028573,-0.034931,-1.646504,0.323817


In [22]:
## Train on P10 an'P10d evaluate on P21
train = 'FW1'
held = 'FW2'

# Run model on cell count
df_train = df_count_cell[df_count.index.get_level_values('FOV') == train]
df_held = df_count_cell[df_count.index.get_level_values('FOV') == held]
X = df_train.values
y = le.transform(df_train.index.get_level_values(0))
X_held = df_held.values
y_held = le.transform(df_held.index.get_level_values(0))

print(len(X), len(y), len(X_held), len(y_held))
print(np.unique(y, return_counts=True), np.unique(y_held, return_counts=True))

832 832 678 678
(array([0, 1]), array([522, 310], dtype=int64)) (array([0, 1]), array([428, 250], dtype=int64))


In [23]:
## Train on P10 an'P10d evaluate on P21
train = 'FW1'
held = 'FW2'

# Run model on cell count
df_train = df_count_cell[df_count_cell.index.get_level_values('FOV') == train]
df_held = df_count_cell[df_count_cell.index.get_level_values('FOV') == held]

X = df_train.values
# X = scaler.fit_transform(X)
y = le.transform(df_train.index.get_level_values(0))

X_held = df_held.values
# X_held = scaler.fit_transform(X_held)
y_held = le.transform(df_held.index.get_level_values(0))

print(len(X), len(y), len(X_held), len(y_held))

# K fold training
kfold = KFold(n_splits = 3, shuffle = True, random_state = 0)
for model_name, model in models.items():
    for k, (train_index, test_index) in enumerate(kfold.split(X)):
        # Split the dataset
        X_train, X_test = X[train_index], X[test_index]
        y_train, y_test = y[train_index], y[test_index] 

        model.fit(X_train, y_train)
        y_pred = model.predict(X_test)
        y_probas = model.predict_proba(X_test)
        y_pred_held = model.predict(X_held)
        y_probas_held = model.predict_proba(X_held)

        run = wandb.init(project=project_name, group=model_name+'_'+train, name=model_name+f'_{train}_{k}')

        accuracy = metrics.accuracy_score(y_test, y_pred)
        accuracy_held = metrics.accuracy_score(y_held, y_pred_held)
        f1 = metrics.f1_score(y_test, y_pred)
        f1_held = metrics.f1_score(y_held, y_pred_held)
        # auc = metrics.roc_auc_score(y_test, y_probas[:, 1])
        # auc_held = metrics.roc_auc_score(y_held, y_probas_held[:, 1])
        auc = metrics.roc_auc_score(y_test, y_pred)
        auc_held = metrics.roc_auc_score(y_held, y_pred_held)
        wandb.log({"accuracy": accuracy, 'accuracy_held': accuracy_held, 'f1':f1, 'f1_held': f1_held, 'auc': auc, 'auc_held': auc_held})
    run.finish()

832 832 678 678


Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Currently logged in as: [33mthoomas[0m. Use [1m`wandb login --relogin`[0m to force relogin


VBox(children=(Label(value='0.001 MB of 0.011 MB uploaded (0.000 MB deduped)\r'), FloatProgress(value=0.095221…

0,1
accuracy,▁
accuracy_held,▁
auc,▁
auc_held,▁
f1,▁
f1_held,▁

0,1
accuracy,0.76259
accuracy_held,0.69764
auc,0.72096
auc_held,0.68649
f1,0.625
f1_held,0.61101


VBox(children=(Label(value='Waiting for wandb.init()...\r'), FloatProgress(value=0.01693333333338766, max=1.0)…

VBox(children=(Label(value='0.001 MB of 0.011 MB uploaded (0.000 MB deduped)\r'), FloatProgress(value=0.135868…

0,1
accuracy,▁
accuracy_held,▁
auc,▁
auc_held,▁
f1,▁
f1_held,▁

0,1
accuracy,0.68953
accuracy_held,0.71681
auc,0.669
auc_held,0.7025
f1,0.59434
f1_held,0.62791


VBox(children=(Label(value='Waiting for wandb.init()...\r'), FloatProgress(value=0.01693333333338766, max=1.0)…

VBox(children=(Label(value='0.001 MB of 0.011 MB uploaded (0.000 MB deduped)\r'), FloatProgress(value=0.095213…

0,1
accuracy,▁
accuracy_held,▁
auc,▁
auc_held,▁
f1,▁
f1_held,▁

0,1
accuracy,0.72924
accuracy_held,0.71091
auc,0.70778
auc_held,0.6945
f1,0.63415
f1_held,0.61719


VBox(children=(Label(value='Waiting for wandb.init()...\r'), FloatProgress(value=0.016916666666899498, max=1.0…

VBox(children=(Label(value='0.001 MB of 0.011 MB uploaded (0.000 MB deduped)\r'), FloatProgress(value=0.135879…

0,1
accuracy,▁
accuracy_held,▁
auc,▁
auc_held,▁
f1,▁
f1_held,▁

0,1
accuracy,0.69065
accuracy_held,0.66519
auc,0.68367
auc_held,0.66494
f1,0.58654
f1_held,0.59392


VBox(children=(Label(value='Waiting for wandb.init()...\r'), FloatProgress(value=0.016666666666666666, max=1.0…

VBox(children=(Label(value='0.003 MB of 0.011 MB uploaded (0.000 MB deduped)\r'), FloatProgress(value=0.299642…

0,1
accuracy,▁
accuracy_held,▁
auc,▁
auc_held,▁
f1,▁
f1_held,▁

0,1
accuracy,0.68953
accuracy_held,0.64749
auc,0.67726
auc_held,0.64677
f1,0.61607
f1_held,0.57398


VBox(children=(Label(value='Waiting for wandb.init()...\r'), FloatProgress(value=0.016666666666666666, max=1.0…

VBox(children=(Label(value='0.001 MB of 0.011 MB uploaded (0.000 MB deduped)\r'), FloatProgress(value=0.095221…

0,1
accuracy,▁
accuracy_held,▁
auc,▁
auc_held,▁
f1,▁
f1_held,▁

0,1
accuracy,0.61011
accuracy_held,0.60767
auc,0.58588
auc_held,0.5936
f1,0.48571
f1_held,0.50373


VBox(children=(Label(value='Waiting for wandb.init()...\r'), FloatProgress(value=0.01693333333338766, max=1.0)…

VBox(children=(Label(value='0.001 MB of 0.011 MB uploaded (0.000 MB deduped)\r'), FloatProgress(value=0.135856…

0,1
accuracy,▁
accuracy_held,▁
auc,▁
auc_held,▁
f1,▁
f1_held,▁

0,1
accuracy,0.72302
accuracy_held,0.70206
auc,0.68864
auc_held,0.69581
f1,0.58378
f1_held,0.62454


VBox(children=(Label(value='Waiting for wandb.init()...\r'), FloatProgress(value=0.01691666666592937, max=1.0)…

VBox(children=(Label(value='0.001 MB of 0.011 MB uploaded (0.000 MB deduped)\r'), FloatProgress(value=0.135844…

0,1
accuracy,▁
accuracy_held,▁
auc,▁
auc_held,▁
f1,▁
f1_held,▁

0,1
accuracy,0.7148
accuracy_held,0.72124
auc,0.6931
auc_held,0.70518
f1,0.62201
f1_held,0.63014


VBox(children=(Label(value='Waiting for wandb.init()...\r'), FloatProgress(value=0.016916666666899498, max=1.0…

VBox(children=(Label(value='0.003 MB of 0.011 MB uploaded (0.000 MB deduped)\r'), FloatProgress(value=0.299590…

0,1
accuracy,▁
accuracy_held,▁
auc,▁
auc_held,▁
f1,▁
f1_held,▁

0,1
accuracy,0.70036
accuracy_held,0.70649
auc,0.66783
auc_held,0.68684
f1,0.57436
f1_held,0.60594


VBox(children=(Label(value='Waiting for wandb.init()...\r'), FloatProgress(value=0.01693333333338766, max=1.0)…

VBox(children=(Label(value='0.001 MB of 0.011 MB uploaded (0.000 MB deduped)\r'), FloatProgress(value=0.095221…

0,1
accuracy,▁
accuracy_held,▁
auc,▁
auc_held,▁
f1,▁
f1_held,▁

0,1
accuracy,0.7518
accuracy_held,0.69174
auc,0.7074
auc_held,0.65436
f1,0.60571
f1_held,0.55054


VBox(children=(Label(value='Waiting for wandb.init()...\r'), FloatProgress(value=0.016916666666899498, max=1.0…

VBox(children=(Label(value='0.001 MB of 0.011 MB uploaded (0.000 MB deduped)\r'), FloatProgress(value=0.095221…

0,1
accuracy,▁
accuracy_held,▁
auc,▁
auc_held,▁
f1,▁
f1_held,▁

0,1
accuracy,0.68592
accuracy_held,0.69764
auc,0.65357
auc_held,0.66153
f1,0.55385
f1_held,0.56103


VBox(children=(Label(value='Waiting for wandb.init()...\r'), FloatProgress(value=0.016666666666666666, max=1.0…

VBox(children=(Label(value='0.001 MB of 0.011 MB uploaded (0.000 MB deduped)\r'), FloatProgress(value=0.095300…

0,1
accuracy,▁
accuracy_held,▁
auc,▁
auc_held,▁
f1,▁
f1_held,▁

0,1
accuracy,0.69675
accuracy_held,0.66372
auc,0.67049
auc_held,0.62385
f1,0.58416
f1_held,0.50862


VBox(children=(Label(value='Waiting for wandb.init()...\r'), FloatProgress(value=0.01693333333338766, max=1.0)…

VBox(children=(Label(value='0.003 MB of 0.011 MB uploaded (0.000 MB deduped)\r'), FloatProgress(value=0.299616…

0,1
accuracy,▁
accuracy_held,▁
auc,▁
auc_held,▁
f1,▁
f1_held,▁

0,1
accuracy,0.71942
accuracy_held,0.72124
auc,0.69144
auc_held,0.71599
f1,0.58947
f1_held,0.64804


VBox(children=(Label(value='Waiting for wandb.init()...\r'), FloatProgress(value=0.016666666666666666, max=1.0…

VBox(children=(Label(value='0.003 MB of 0.011 MB uploaded (0.000 MB deduped)\r'), FloatProgress(value=0.299590…

0,1
accuracy,▁
accuracy_held,▁
auc,▁
auc_held,▁
f1,▁
f1_held,▁

0,1
accuracy,0.70758
accuracy_held,0.72271
auc,0.68562
auc_held,0.70884
f1,0.61244
f1_held,0.63566


VBox(children=(Label(value='Waiting for wandb.init()...\r'), FloatProgress(value=0.01693333333338766, max=1.0)…

VBox(children=(Label(value='0.001 MB of 0.011 MB uploaded (0.000 MB deduped)\r'), FloatProgress(value=0.095221…

0,1
accuracy,▁
accuracy_held,▁
auc,▁
auc_held,▁
f1,▁
f1_held,▁

0,1
accuracy,0.71119
accuracy_held,0.69764
auc,0.67099
auc_held,0.67235
f1,0.56989
f1_held,0.58418


VBox(children=(Label(value='Waiting for wandb.init()...\r'), FloatProgress(value=0.016666666666666666, max=1.0…

VBox(children=(Label(value='0.001 MB of 0.011 MB uploaded (0.000 MB deduped)\r'), FloatProgress(value=0.095213…

0,1
accuracy,▁
accuracy_held,▁
auc,▁
auc_held,▁
f1,▁
f1_held,▁

0,1
accuracy,0.77338
accuracy_held,0.75516
auc,0.72627
auc_held,0.73787
f1,0.63158
f1_held,0.66932


VBox(children=(Label(value='Waiting for wandb.init()...\r'), FloatProgress(value=0.01693333333338766, max=1.0)…

VBox(children=(Label(value='0.001 MB of 0.011 MB uploaded (0.000 MB deduped)\r'), FloatProgress(value=0.095229…

0,1
accuracy,▁
accuracy_held,▁
auc,▁
auc_held,▁
f1,▁
f1_held,▁

0,1
accuracy,0.71119
accuracy_held,0.73599
auc,0.67767
auc_held,0.71104
f1,0.58333
f1_held,0.63244


VBox(children=(Label(value='Waiting for wandb.init()...\r'), FloatProgress(value=0.01693333333338766, max=1.0)…

VBox(children=(Label(value='0.001 MB of 0.011 MB uploaded (0.000 MB deduped)\r'), FloatProgress(value=0.095213…

0,1
accuracy,▁
accuracy_held,▁
auc,▁
auc_held,▁
f1,▁
f1_held,▁

0,1
accuracy,0.75451
accuracy_held,0.72566
auc,0.70587
auc_held,0.6829
f1,0.6092
f1_held,0.58296


VBox(children=(Label(value='Waiting for wandb.init()...\r'), FloatProgress(value=0.016666666666666666, max=1.0…

VBox(children=(Label(value='0.001 MB of 0.011 MB uploaded (0.000 MB deduped)\r'), FloatProgress(value=0.095221…

0,1
accuracy,▁
accuracy_held,▁
auc,▁
auc_held,▁
f1,▁
f1_held,▁

0,1
accuracy,0.77698
accuracy_held,0.74041
auc,0.73995
auc_held,0.72868
f1,0.65169
f1_held,0.66023


VBox(children=(Label(value='Waiting for wandb.init()...\r'), FloatProgress(value=0.016666666666666666, max=1.0…

VBox(children=(Label(value='0.001 MB of 0.011 MB uploaded (0.000 MB deduped)\r'), FloatProgress(value=0.095213…

0,1
accuracy,▁
accuracy_held,▁
auc,▁
auc_held,▁
f1,▁
f1_held,▁

0,1
accuracy,0.73285
accuracy_held,0.74484
auc,0.70834
auc_held,0.72969
f1,0.63725
f1_held,0.66012


VBox(children=(Label(value='Waiting for wandb.init()...\r'), FloatProgress(value=0.016666666666666666, max=1.0…

VBox(children=(Label(value='0.001 MB of 0.011 MB uploaded (0.000 MB deduped)\r'), FloatProgress(value=0.095229…

0,1
accuracy,▁
accuracy_held,▁
auc,▁
auc_held,▁
f1,▁
f1_held,▁

0,1
accuracy,0.77978
accuracy_held,0.74779
auc,0.74477
auc_held,0.72953
f1,0.6738
f1_held,0.65868




VBox(children=(Label(value='Waiting for wandb.init()...\r'), FloatProgress(value=0.016666666666666666, max=1.0…



0,1
accuracy,▁
accuracy_held,▁
auc,▁
auc_held,▁
f1,▁
f1_held,▁

0,1
accuracy,0.76978
accuracy_held,0.74336
auc,0.74007
auc_held,0.74183
f1,0.65217
f1_held,0.67897


VBox(children=(Label(value='Waiting for wandb.init()...\r'), FloatProgress(value=0.01693333333338766, max=1.0)…



0,1
accuracy,▁
accuracy_held,▁
auc,▁
auc_held,▁
f1,▁
f1_held,▁

0,1
accuracy,0.71119
accuracy_held,0.72714
auc,0.68867
auc_held,0.71983
f1,0.61538
f1_held,0.6516


VBox(children=(Label(value='Waiting for wandb.init()...\r'), FloatProgress(value=0.016666666666666666, max=1.0…

VBox(children=(Label(value='0.001 MB of 0.011 MB uploaded (0.000 MB deduped)\r'), FloatProgress(value=0.095229…

0,1
accuracy,▁
accuracy_held,▁
auc,▁
auc_held,▁
f1,▁
f1_held,▁

0,1
accuracy,0.77978
accuracy_held,0.72566
auc,0.7559
auc_held,0.7145
f1,0.69347
f1_held,0.64368


In [24]:
## Train on P10 an'P10d evaluate on P21
train = 'FW2'
held = 'FW1'

# Run model on cell count
df_train = df_count_cell[df_count_cell.index.get_level_values('FOV') == train]
df_held = df_count_cell[df_count_cell.index.get_level_values('FOV') == held]

X = df_train.values
# X = scaler.fit_transform(X)
y = le.transform(df_train.index.get_level_values(0))

X_held = df_held.values
# X_held = scaler.fit_transform(X_held)
y_held = le.transform(df_held.index.get_level_values(0))

print(len(X), len(y), len(X_held), len(y_held))

# K fold training
kfold = KFold(n_splits = 3, shuffle = True, random_state = 0)
for model_name, model in models.items():
    for k, (train_index, test_index) in enumerate(kfold.split(X)):
        # Split the dataset
        X_train, X_test = X[train_index], X[test_index]
        y_train, y_test = y[train_index], y[test_index] 

        model.fit(X_train, y_train)
        y_pred = model.predict(X_test)
        y_probas = model.predict_proba(X_test)
        y_pred_held = model.predict(X_held)
        y_probas_held = model.predict_proba(X_held)

        run = wandb.init(project=project_name, group=model_name+'_'+train, name=model_name+f'_{train}_{k}')

        accuracy = metrics.accuracy_score(y_test, y_pred)
        accuracy_held = metrics.accuracy_score(y_held, y_pred_held)
        f1 = metrics.f1_score(y_test, y_pred)
        f1_held = metrics.f1_score(y_held, y_pred_held)
        # auc = metrics.roc_auc_score(y_test, y_probas[:, 1])
        # auc_held = metrics.roc_auc_score(y_held, y_probas_held[:, 1])
        auc = metrics.roc_auc_score(y_test, y_pred)
        auc_held = metrics.roc_auc_score(y_held, y_pred_held)
        wandb.log({"accuracy": accuracy, 'accuracy_held': accuracy_held, 'f1':f1, 'f1_held': f1_held, 'auc': auc, 'auc_held': auc_held})
    run.finish()

678 678 832 832


VBox(children=(Label(value='0.001 MB of 0.011 MB uploaded (0.000 MB deduped)\r'), FloatProgress(value=0.135868…

0,1
accuracy,▁
accuracy_held,▁
auc,▁
auc_held,▁
f1,▁
f1_held,▁

0,1
accuracy,0.69027
accuracy_held,0.73197
auc,0.66533
auc_held,0.67504
f1,0.57831
f1_held,0.55666


VBox(children=(Label(value='Waiting for wandb.init()...\r'), FloatProgress(value=0.01693333333338766, max=1.0)…

VBox(children=(Label(value='0.001 MB of 0.011 MB uploaded (0.000 MB deduped)\r'), FloatProgress(value=0.135919…

0,1
accuracy,▁
accuracy_held,▁
auc,▁
auc_held,▁
f1,▁
f1_held,▁

0,1
accuracy,0.75221
accuracy_held,0.73197
auc,0.72988
auc_held,0.67897
f1,0.65854
f1_held,0.56699


VBox(children=(Label(value='Waiting for wandb.init()...\r'), FloatProgress(value=0.016666666666666666, max=1.0…

VBox(children=(Label(value='0.001 MB of 0.011 MB uploaded (0.000 MB deduped)\r'), FloatProgress(value=0.095292…

0,1
accuracy,▁
accuracy_held,▁
auc,▁
auc_held,▁
f1,▁
f1_held,▁

0,1
accuracy,0.74336
accuracy_held,0.69952
auc,0.69374
auc_held,0.63608
f1,0.59155
f1_held,0.4898


VBox(children=(Label(value='Waiting for wandb.init()...\r'), FloatProgress(value=0.01693333333338766, max=1.0)…

VBox(children=(Label(value='0.001 MB of 0.011 MB uploaded (0.000 MB deduped)\r'), FloatProgress(value=0.095300…

0,1
accuracy,▁
accuracy_held,▁
auc,▁
auc_held,▁
f1,▁
f1_held,▁

0,1
accuracy,0.69469
accuracy_held,0.70433
auc,0.67121
auc_held,0.6648
f1,0.58683
f1_held,0.56228


VBox(children=(Label(value='Waiting for wandb.init()...\r'), FloatProgress(value=0.01691666666592937, max=1.0)…

VBox(children=(Label(value='0.001 MB of 0.011 MB uploaded (0.000 MB deduped)\r'), FloatProgress(value=0.095308…

0,1
accuracy,▁
accuracy_held,▁
auc,▁
auc_held,▁
f1,▁
f1_held,▁

0,1
accuracy,0.61947
accuracy_held,0.67668
auc,0.60236
auc_held,0.63818
f1,0.51136
f1_held,0.5289


VBox(children=(Label(value='Waiting for wandb.init()...\r'), FloatProgress(value=0.016916666666899498, max=1.0…

VBox(children=(Label(value='0.001 MB of 0.011 MB uploaded (0.000 MB deduped)\r'), FloatProgress(value=0.095292…

0,1
accuracy,▁
accuracy_held,▁
auc,▁
auc_held,▁
f1,▁
f1_held,▁

0,1
accuracy,0.65929
accuracy_held,0.63822
auc,0.62005
auc_held,0.60819
f1,0.50323
f1_held,0.50248


VBox(children=(Label(value='Waiting for wandb.init()...\r'), FloatProgress(value=0.016666666666666666, max=1.0…

VBox(children=(Label(value='0.001 MB of 0.011 MB uploaded (0.000 MB deduped)\r'), FloatProgress(value=0.135844…

0,1
accuracy,▁
accuracy_held,▁
auc,▁
auc_held,▁
f1,▁
f1_held,▁

0,1
accuracy,0.68584
accuracy_held,0.71394
auc,0.6501
auc_held,0.65674
f1,0.54777
f1_held,0.52964


VBox(children=(Label(value='Waiting for wandb.init()...\r'), FloatProgress(value=0.01693333333338766, max=1.0)…

VBox(children=(Label(value='0.003 MB of 0.011 MB uploaded (0.000 MB deduped)\r'), FloatProgress(value=0.299677…

0,1
accuracy,▁
accuracy_held,▁
auc,▁
auc_held,▁
f1,▁
f1_held,▁

0,1
accuracy,0.76549
accuracy_held,0.71635
auc,0.73558
auc_held,0.65931
f1,0.66242
f1_held,0.5336


VBox(children=(Label(value='Waiting for wandb.init()...\r'), FloatProgress(value=0.016666666666666666, max=1.0…

VBox(children=(Label(value='0.001 MB of 0.011 MB uploaded (0.000 MB deduped)\r'), FloatProgress(value=0.095292…

0,1
accuracy,▁
accuracy_held,▁
auc,▁
auc_held,▁
f1,▁
f1_held,▁

0,1
accuracy,0.74779
accuracy_held,0.70192
auc,0.70264
auc_held,0.64389
f1,0.6069
f1_held,0.50988


VBox(children=(Label(value='Waiting for wandb.init()...\r'), FloatProgress(value=0.01693333333338766, max=1.0)…

VBox(children=(Label(value='0.001 MB of 0.011 MB uploaded (0.000 MB deduped)\r'), FloatProgress(value=0.095292…

0,1
accuracy,▁
accuracy_held,▁
auc,▁
auc_held,▁
f1,▁
f1_held,▁

0,1
accuracy,0.69027
accuracy_held,0.74038
auc,0.66533
auc_held,0.69812
f1,0.57831
f1_held,0.6044


VBox(children=(Label(value='Waiting for wandb.init()...\r'), FloatProgress(value=0.016916666666899498, max=1.0…

VBox(children=(Label(value='0.003 MB of 0.011 MB uploaded (0.000 MB deduped)\r'), FloatProgress(value=0.299651…

0,1
accuracy,▁
accuracy_held,▁
auc,▁
auc_held,▁
f1,▁
f1_held,▁

0,1
accuracy,0.73894
accuracy_held,0.74399
auc,0.72904
auc_held,0.70951
f1,0.66286
f1_held,0.62566


VBox(children=(Label(value='Waiting for wandb.init()...\r'), FloatProgress(value=0.01693333333338766, max=1.0)…

VBox(children=(Label(value='0.001 MB of 0.011 MB uploaded (0.000 MB deduped)\r'), FloatProgress(value=0.095292…

0,1
accuracy,▁
accuracy_held,▁
auc,▁
auc_held,▁
f1,▁
f1_held,▁

0,1
accuracy,0.71681
accuracy_held,0.7476
auc,0.67305
auc_held,0.71173
f1,0.56757
f1_held,0.62766


VBox(children=(Label(value='Waiting for wandb.init()...\r'), FloatProgress(value=0.016666666666666666, max=1.0…

VBox(children=(Label(value='0.003 MB of 0.011 MB uploaded (0.000 MB deduped)\r'), FloatProgress(value=0.299677…

0,1
accuracy,▁
accuracy_held,▁
auc,▁
auc_held,▁
f1,▁
f1_held,▁

0,1
accuracy,0.71239
accuracy_held,0.71755
auc,0.68306
auc_held,0.66355
f1,0.59627
f1_held,0.54369


VBox(children=(Label(value='Waiting for wandb.init()...\r'), FloatProgress(value=0.016666666666666666, max=1.0…

VBox(children=(Label(value='0.001 MB of 0.011 MB uploaded (0.000 MB deduped)\r'), FloatProgress(value=0.135919…

0,1
accuracy,▁
accuracy_held,▁
auc,▁
auc_held,▁
f1,▁
f1_held,▁

0,1
accuracy,0.79646
accuracy_held,0.73197
auc,0.76752
auc_held,0.67701
f1,0.70513
f1_held,0.56189


VBox(children=(Label(value='Waiting for wandb.init()...\r'), FloatProgress(value=0.01693333333338766, max=1.0)…

VBox(children=(Label(value='0.001 MB of 0.011 MB uploaded (0.000 MB deduped)\r'), FloatProgress(value=0.095300…

0,1
accuracy,▁
accuracy_held,▁
auc,▁
auc_held,▁
f1,▁
f1_held,▁

0,1
accuracy,0.73451
accuracy_held,0.6899
auc,0.68412
auc_held,0.63562
f1,0.57746
f1_held,0.50385


VBox(children=(Label(value='Waiting for wandb.init()...\r'), FloatProgress(value=0.01693333333338766, max=1.0)…

VBox(children=(Label(value='0.001 MB of 0.001 MB uploaded (0.000 MB deduped)\r'), FloatProgress(value=1.0, max…

0,1
accuracy,▁
accuracy_held,▁
auc,▁
auc_held,▁
f1,▁
f1_held,▁

0,1
accuracy,0.74779
accuracy_held,0.73558
auc,0.70442
auc_held,0.67071
f1,0.61224
f1_held,0.53975


VBox(children=(Label(value='Waiting for wandb.init()...\r'), FloatProgress(value=0.01693333333338766, max=1.0)…

VBox(children=(Label(value='0.001 MB of 0.011 MB uploaded (0.000 MB deduped)\r'), FloatProgress(value=0.095316…

0,1
accuracy,▁
accuracy_held,▁
auc,▁
auc_held,▁
f1,▁
f1_held,▁

0,1
accuracy,0.79204
accuracy_held,0.73678
auc,0.74698
auc_held,0.66839
f1,0.67133
f1_held,0.53105


VBox(children=(Label(value='Waiting for wandb.init()...\r'), FloatProgress(value=0.016916666666899498, max=1.0…

VBox(children=(Label(value='0.001 MB of 0.011 MB uploaded (0.000 MB deduped)\r'), FloatProgress(value=0.135868…

0,1
accuracy,▁
accuracy_held,▁
auc,▁
auc_held,▁
f1,▁
f1_held,▁

0,1
accuracy,0.75221
accuracy_held,0.73798
auc,0.69519
auc_held,0.67655
f1,0.58824
f1_held,0.55328


VBox(children=(Label(value='Waiting for wandb.init()...\r'), FloatProgress(value=0.016666666666666666, max=1.0…

VBox(children=(Label(value='0.001 MB of 0.011 MB uploaded (0.000 MB deduped)\r'), FloatProgress(value=0.095229…

0,1
accuracy,▁
accuracy_held,▁
auc,▁
auc_held,▁
f1,▁
f1_held,▁

0,1
accuracy,0.75221
accuracy_held,0.73317
auc,0.71965
auc_held,0.67076
f1,0.64103
f1_held,0.54321


VBox(children=(Label(value='Waiting for wandb.init()...\r'), FloatProgress(value=0.016666666666666666, max=1.0…

VBox(children=(Label(value='0.001 MB of 0.011 MB uploaded (0.000 MB deduped)\r'), FloatProgress(value=0.095221…

0,1
accuracy,▁
accuracy_held,▁
auc,▁
auc_held,▁
f1,▁
f1_held,▁

0,1
accuracy,0.80531
accuracy_held,0.73077
auc,0.77456
auc_held,0.66622
f1,0.71429
f1_held,0.53333


VBox(children=(Label(value='Waiting for wandb.init()...\r'), FloatProgress(value=0.016666666666666666, max=1.0…

VBox(children=(Label(value='0.001 MB of 0.011 MB uploaded (0.000 MB deduped)\r'), FloatProgress(value=0.095213…

0,1
accuracy,▁
accuracy_held,▁
auc,▁
auc_held,▁
f1,▁
f1_held,▁

0,1
accuracy,0.76991
accuracy_held,0.73077
auc,0.72261
auc_held,0.66753
f1,0.6338
f1_held,0.53719




VBox(children=(Label(value='Waiting for wandb.init()...\r'), FloatProgress(value=0.016666666666666666, max=1.0…



0,1
accuracy,▁
accuracy_held,▁
auc,▁
auc_held,▁
f1,▁
f1_held,▁

0,1
accuracy,0.74779
accuracy_held,0.73558
auc,0.71143
auc_held,0.67791
f1,0.62745
f1_held,0.56


VBox(children=(Label(value='Waiting for wandb.init()...\r'), FloatProgress(value=0.01693333333338766, max=1.0)…



0,1
accuracy,▁
accuracy_held,▁
auc,▁
auc_held,▁
f1,▁
f1_held,▁

0,1
accuracy,0.81416
accuracy_held,0.74399
auc,0.78161
auc_held,0.68527
f1,0.72368
f1_held,0.5697


VBox(children=(Label(value='Waiting for wandb.init()...\r'), FloatProgress(value=0.016666666666666666, max=1.0…

VBox(children=(Label(value='0.001 MB of 0.011 MB uploaded (0.000 MB deduped)\r'), FloatProgress(value=0.095229…

0,1
accuracy,▁
accuracy_held,▁
auc,▁
auc_held,▁
f1,▁
f1_held,▁

0,1
accuracy,0.76549
accuracy_held,0.73197
auc,0.72461
auc_held,0.67439
f1,0.63946
f1_held,0.55489
