In [48]:
import itertools
import os
import sys
from pathlib import Path

import cv2
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import skimage.io

from collections import defaultdict
from tqdm.auto import tqdm
from joblib import Parallel, delayed
import re
import h5py
import napari
import seaborn as sns
from sklearn.preprocessing import StandardScaler


In [49]:
%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [50]:
p_dir = (Path().cwd().parents[0]).absolute()

module_path = str(p_dir / "src")

if module_path not in sys.path:
    sys.path.append(module_path)

In [51]:
data_dir = (Path().cwd().parents[0] / 'data').absolute()

In [52]:
import PPIGraph

# Read per cell data

In [53]:
from torch_geometric.loader import DataLoader

In [54]:
# Define condition mapping
condition_mapping = {'HCC827Ctrl': 0, 'HCC827Osim': 1}

# Load graph dataset and process if neede
graph_path = data_dir / '9PPI' / 'graphs' 

In [55]:
# Filter out by maximum number of counts per cell
min_count = 100
max_count = 400

dataset = PPIGraph.GraphDataset(graph_path, 'raw', 'pt',  condition_mapping=condition_mapping, n_c=2)

# Create Dataloader
loader = DataLoader(dataset, batch_size=1, shuffle=False)

# Get Indices
indices = []
for step, data in enumerate(loader):
    if len(data.x) <= min_count:
        continue 
    
    if (data.x.sum(axis=0) >= max_count).any():
        continue
    indices.append(step)
    
# Get subset dataset
dataset_filtered = dataset.index_select(indices)


In [56]:
dataset

GraphDataset(1491)

In [57]:
dataset_filtered

GraphDataset(1368)

# ML model

In [58]:
from sklearn.svm import SVC
from sklearn.ensemble import RandomForestClassifier, AdaBoostClassifier, GradientBoostingClassifier
from sklearn.tree import DecisionTreeClassifier
from sklearn.naive_bayes import GaussianNB
from sklearn import preprocessing, metrics
from sklearn.neural_network import MLPClassifier
from sklearn.linear_model import LogisticRegression

from sklearn.model_selection import train_test_split, KFold
from sklearn.utils.class_weight import compute_class_weight
import wandb

In [59]:
condition = 'ML'
project_name = f'PLA_9PPI_121923_{condition}'

In [60]:
models = {
    'Adaboost': AdaBoostClassifier(),
    'DecisionTree': DecisionTreeClassifier(),
    'GradientBoosting' : GradientBoostingClassifier(),
    'NaiveBayes': GaussianNB(),
    'RandomForest': RandomForestClassifier(), 
    'SVM': SVC(probability =True),
    'LogisticRegression':  LogisticRegression(),
    'MLP': MLPClassifier(random_state=1, max_iter=100, hidden_layer_sizes=[16, 16, 16])
}



In [61]:
import torch 

data = torch.empty((0, 9))
label = []

for graph in dataset_filtered:
    new_row =graph.x.sum(axis = 0).unsqueeze(0)
    data = torch.cat((data, new_row), dim=0)
    label.append(graph.condition)


In [62]:
scaler = StandardScaler()

# Run model on cell count
X = data.numpy()
X = scaler.fit_transform(X)
y = np.array(label)

# X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.4, random_state=0)
kfold = KFold(n_splits = 5, shuffle = True, random_state = 0)

for model_name, model in models.items():
    for k, (train_index, test_index) in enumerate(kfold.split(X)):
        # Split the dataset
        X_train, X_test = X[train_index], X[test_index]
        y_train, y_test = y[train_index], y[test_index]
        
        model.fit(X_train, y_train)
        y_pred = model.predict(X_test)
        y_probas = model.predict_proba(X_test)

        run = wandb.init(project=project_name, group=model_name+'_cell', name=model_name+f'_cell_{k}')

        accuracy = metrics.accuracy_score(y_test, y_pred)
        b_accuracy = metrics.balanced_accuracy_score(y_test, y_pred)
        f1 = metrics.f1_score(y_test, y_pred)
        auc = metrics.roc_auc_score(y_test, y_pred)
        wandb.log({"accuracy": accuracy, 'b_accuracy': b_accuracy, 'f1':f1, 'auc': auc})
    run.finish()

Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Currently logged in as: [33mthoomas[0m. Use [1m`wandb login --relogin`[0m to force relogin


VBox(children=(Label(value='0.009 MB of 0.015 MB uploaded\r'), FloatProgress(value=0.6344432279255491, max=1.0…

0,1
accuracy,▁
auc,▁
b_accuracy,▁
f1,▁

0,1
accuracy,0.70438
auc,0.69836
b_accuracy,0.69836
f1,0.64629


VBox(children=(Label(value='Waiting for wandb.init()...\r'), FloatProgress(value=0.011111111111111112, max=1.0…

VBox(children=(Label(value='0.009 MB of 0.015 MB uploaded\r'), FloatProgress(value=0.634484091201855, max=1.0)…

0,1
accuracy,▁
auc,▁
b_accuracy,▁
f1,▁

0,1
accuracy,0.72628
auc,0.71936
b_accuracy,0.71936
f1,0.67811


VBox(children=(Label(value='Waiting for wandb.init()...\r'), FloatProgress(value=0.011111111111111112, max=1.0…

VBox(children=(Label(value='0.009 MB of 0.015 MB uploaded\r'), FloatProgress(value=0.6344023699124163, max=1.0…

0,1
accuracy,▁
auc,▁
b_accuracy,▁
f1,▁

0,1
accuracy,0.74088
auc,0.7138
b_accuracy,0.7138
f1,0.63212


VBox(children=(Label(value='Waiting for wandb.init()...\r'), FloatProgress(value=0.011111111111111112, max=1.0…

VBox(children=(Label(value='0.015 MB of 0.015 MB uploaded\r'), FloatProgress(value=1.0, max=1.0)))

0,1
accuracy,▁
auc,▁
b_accuracy,▁
f1,▁

0,1
accuracy,0.73993
auc,0.72843
b_accuracy,0.72843
f1,0.6758


VBox(children=(Label(value='Waiting for wandb.init()...\r'), FloatProgress(value=0.011111111111111112, max=1.0…

VBox(children=(Label(value='0.002 MB of 0.015 MB uploaded\r'), FloatProgress(value=0.11051004636785162, max=1.…

0,1
accuracy,▁
auc,▁
b_accuracy,▁
f1,▁

0,1
accuracy,0.76557
auc,0.75748
b_accuracy,0.75748
f1,0.7193


VBox(children=(Label(value='Waiting for wandb.init()...\r'), FloatProgress(value=0.011277777777932999, max=1.0…

VBox(children=(Label(value='0.009 MB of 0.015 MB uploaded\r'), FloatProgress(value=0.6344023699124163, max=1.0…

0,1
accuracy,▁
auc,▁
b_accuracy,▁
f1,▁

0,1
accuracy,0.67883
auc,0.67258
b_accuracy,0.67258
f1,0.61739


VBox(children=(Label(value='Waiting for wandb.init()...\r'), FloatProgress(value=0.011111111111111112, max=1.0…

VBox(children=(Label(value='0.009 MB of 0.015 MB uploaded\r'), FloatProgress(value=0.6343615171614398, max=1.0…

0,1
accuracy,▁
auc,▁
b_accuracy,▁
f1,▁

0,1
accuracy,0.65328
auc,0.64184
b_accuracy,0.64184
f1,0.5815


VBox(children=(Label(value='Waiting for wandb.init()...\r'), FloatProgress(value=0.011111111111111112, max=1.0…

VBox(children=(Label(value='0.001 MB of 0.015 MB uploaded\r'), FloatProgress(value=0.07289587223903664, max=1.…

0,1
accuracy,▁
auc,▁
b_accuracy,▁
f1,▁

0,1
accuracy,0.67883
auc,0.66303
b_accuracy,0.66303
f1,0.57692


VBox(children=(Label(value='Waiting for wandb.init()...\r'), FloatProgress(value=0.011111111111111112, max=1.0…

VBox(children=(Label(value='0.001 MB of 0.015 MB uploaded\r'), FloatProgress(value=0.07289587223903664, max=1.…

0,1
accuracy,▁
auc,▁
b_accuracy,▁
f1,▁

0,1
accuracy,0.67033
auc,0.66424
b_accuracy,0.66424
f1,0.62185


VBox(children=(Label(value='Waiting for wandb.init()...\r'), FloatProgress(value=0.011111111111111112, max=1.0…

VBox(children=(Label(value='0.009 MB of 0.015 MB uploaded\r'), FloatProgress(value=0.634466769706337, max=1.0)…

0,1
accuracy,▁
auc,▁
b_accuracy,▁
f1,▁

0,1
accuracy,0.67033
auc,0.66774
b_accuracy,0.66774
f1,0.6281


VBox(children=(Label(value='Waiting for wandb.init()...\r'), FloatProgress(value=0.011288888888278356, max=1.0…

VBox(children=(Label(value='0.015 MB of 0.015 MB uploaded\r'), FloatProgress(value=1.0, max=1.0)))

0,1
accuracy,▁
auc,▁
b_accuracy,▁
f1,▁

0,1
accuracy,0.73358
auc,0.72003
b_accuracy,0.72003
f1,0.66359


VBox(children=(Label(value='Waiting for wandb.init()...\r'), FloatProgress(value=0.011288888888278356, max=1.0…

VBox(children=(Label(value='0.002 MB of 0.015 MB uploaded\r'), FloatProgress(value=0.11056020605280104, max=1.…

0,1
accuracy,▁
auc,▁
b_accuracy,▁
f1,▁

0,1
accuracy,0.70438
auc,0.69807
b_accuracy,0.69807
f1,0.65532


VBox(children=(Label(value='Waiting for wandb.init()...\r'), FloatProgress(value=0.011111111111111112, max=1.0…

VBox(children=(Label(value='0.009 MB of 0.015 MB uploaded\r'), FloatProgress(value=0.6344259128082942, max=1.0…

0,1
accuracy,▁
auc,▁
b_accuracy,▁
f1,▁

0,1
accuracy,0.77737
auc,0.76211
b_accuracy,0.76211
f1,0.69652


VBox(children=(Label(value='Waiting for wandb.init()...\r'), FloatProgress(value=0.011111111111111112, max=1.0…

VBox(children=(Label(value='0.009 MB of 0.015 MB uploaded\r'), FloatProgress(value=0.6344259128082942, max=1.0…

0,1
accuracy,▁
auc,▁
b_accuracy,▁
f1,▁

0,1
accuracy,0.71795
auc,0.70568
b_accuracy,0.70568
f1,0.64516


VBox(children=(Label(value='Waiting for wandb.init()...\r'), FloatProgress(value=0.011455555556393746, max=1.0…

VBox(children=(Label(value='0.009 MB of 0.015 MB uploaded\r'), FloatProgress(value=0.6343850611719253, max=1.0…

0,1
accuracy,▁
auc,▁
b_accuracy,▁
f1,▁

0,1
accuracy,0.75458
auc,0.74252
b_accuracy,0.74252
f1,0.69683


VBox(children=(Label(value='Waiting for wandb.init()...\r'), FloatProgress(value=0.011277777777932999, max=1.0…

VBox(children=(Label(value='0.009 MB of 0.015 MB uploaded\r'), FloatProgress(value=0.634466769706337, max=1.0)…

0,1
accuracy,▁
auc,▁
b_accuracy,▁
f1,▁

0,1
accuracy,0.59854
auc,0.63384
b_accuracy,0.63384
f1,0.62329


VBox(children=(Label(value='Waiting for wandb.init()...\r'), FloatProgress(value=0.011277777777932999, max=1.0…

VBox(children=(Label(value='0.009 MB of 0.015 MB uploaded\r'), FloatProgress(value=0.6344259128082942, max=1.0…

0,1
accuracy,▁
auc,▁
b_accuracy,▁
f1,▁

0,1
accuracy,0.65328
auc,0.63152
b_accuracy,0.63152
f1,0.54106


VBox(children=(Label(value='Waiting for wandb.init()...\r'), FloatProgress(value=0.011111111111111112, max=1.0…

VBox(children=(Label(value='0.009 MB of 0.015 MB uploaded\r'), FloatProgress(value=0.6345076318670703, max=1.0…

0,1
accuracy,▁
auc,▁
b_accuracy,▁
f1,▁

0,1
accuracy,0.64599
auc,0.6439
b_accuracy,0.6439
f1,0.56502


VBox(children=(Label(value='Waiting for wandb.init()...\r'), FloatProgress(value=0.011288888888278356, max=1.0…

VBox(children=(Label(value='0.009 MB of 0.015 MB uploaded\r'), FloatProgress(value=0.6345076318670703, max=1.0…

0,1
accuracy,▁
auc,▁
b_accuracy,▁
f1,▁

0,1
accuracy,0.63004
auc,0.64386
b_accuracy,0.64386
f1,0.66667


VBox(children=(Label(value='Waiting for wandb.init()...\r'), FloatProgress(value=0.01128888888957186, max=1.0)…

VBox(children=(Label(value='0.009 MB of 0.015 MB uploaded\r'), FloatProgress(value=0.6343850611719253, max=1.0…

0,1
accuracy,▁
auc,▁
b_accuracy,▁
f1,▁

0,1
accuracy,0.6044
auc,0.63248
b_accuracy,0.63248
f1,0.64238


VBox(children=(Label(value='Waiting for wandb.init()...\r'), FloatProgress(value=0.011277777777932999, max=1.0…

VBox(children=(Label(value='0.015 MB of 0.015 MB uploaded\r'), FloatProgress(value=1.0, max=1.0)))

0,1
accuracy,▁
auc,▁
b_accuracy,▁
f1,▁

0,1
accuracy,0.74453
auc,0.73354
b_accuracy,0.73354
f1,0.68182


VBox(children=(Label(value='Waiting for wandb.init()...\r'), FloatProgress(value=0.01128888888957186, max=1.0)…

VBox(children=(Label(value='0.015 MB of 0.015 MB uploaded\r'), FloatProgress(value=1.0, max=1.0)))

0,1
accuracy,▁
auc,▁
b_accuracy,▁
f1,▁

0,1
accuracy,0.70803
auc,0.69921
b_accuracy,0.69921
f1,0.65217


VBox(children=(Label(value='Waiting for wandb.init()...\r'), FloatProgress(value=0.011288888888278356, max=1.0…

VBox(children=(Label(value='0.015 MB of 0.015 MB uploaded\r'), FloatProgress(value=1.0, max=1.0)))

0,1
accuracy,▁
auc,▁
b_accuracy,▁
f1,▁

0,1
accuracy,0.72993
auc,0.70961
b_accuracy,0.70961
f1,0.63


VBox(children=(Label(value='Waiting for wandb.init()...\r'), FloatProgress(value=0.011288888888278356, max=1.0…

VBox(children=(Label(value='0.015 MB of 0.015 MB uploaded\r'), FloatProgress(value=1.0, max=1.0)))

0,1
accuracy,▁
auc,▁
b_accuracy,▁
f1,▁

0,1
accuracy,0.71795
auc,0.70754
b_accuracy,0.70754
f1,0.65471


VBox(children=(Label(value='Waiting for wandb.init()...\r'), FloatProgress(value=0.011111111111111112, max=1.0…

VBox(children=(Label(value='0.002 MB of 0.015 MB uploaded\r'), FloatProgress(value=0.11056020605280104, max=1.…

0,1
accuracy,▁
auc,▁
b_accuracy,▁
f1,▁

0,1
accuracy,0.7619
auc,0.75107
b_accuracy,0.75107
f1,0.70852


VBox(children=(Label(value='Waiting for wandb.init()...\r'), FloatProgress(value=0.011288888888278356, max=1.0…

VBox(children=(Label(value='0.002 MB of 0.015 MB uploaded\r'), FloatProgress(value=0.11051004636785162, max=1.…

0,1
accuracy,▁
auc,▁
b_accuracy,▁
f1,▁

0,1
accuracy,0.71533
auc,0.6932
b_accuracy,0.6932
f1,0.62136


VBox(children=(Label(value='Waiting for wandb.init()...\r'), FloatProgress(value=0.011288888888278356, max=1.0…

VBox(children=(Label(value='0.001 MB of 0.015 MB uploaded\r'), FloatProgress(value=0.07283147659218236, max=1.…

0,1
accuracy,▁
auc,▁
b_accuracy,▁
f1,▁

0,1
accuracy,0.70438
auc,0.68981
b_accuracy,0.68981
f1,0.63014


VBox(children=(Label(value='Waiting for wandb.init()...\r'), FloatProgress(value=0.011277777777932999, max=1.0…

VBox(children=(Label(value='0.015 MB of 0.015 MB uploaded\r'), FloatProgress(value=1.0, max=1.0)))

0,1
accuracy,▁
auc,▁
b_accuracy,▁
f1,▁

0,1
accuracy,0.72263
auc,0.69512
b_accuracy,0.69512
f1,0.60825


VBox(children=(Label(value='Waiting for wandb.init()...\r'), FloatProgress(value=0.011111111111111112, max=1.0…

VBox(children=(Label(value='0.009 MB of 0.015 MB uploaded\r'), FloatProgress(value=0.6343615171614398, max=1.0…

0,1
accuracy,▁
auc,▁
b_accuracy,▁
f1,▁

0,1
accuracy,0.68864
auc,0.67368
b_accuracy,0.67368
f1,0.5933


VBox(children=(Label(value='Waiting for wandb.init()...\r'), FloatProgress(value=0.011111111111111112, max=1.0…

VBox(children=(Label(value='0.009 MB of 0.015 MB uploaded\r'), FloatProgress(value=0.6344023699124163, max=1.0…

0,1
accuracy,▁
auc,▁
b_accuracy,▁
f1,▁

0,1
accuracy,0.73626
auc,0.72329
b_accuracy,0.72329
f1,0.67273


VBox(children=(Label(value='Waiting for wandb.init()...\r'), FloatProgress(value=0.011111111111111112, max=1.0…

VBox(children=(Label(value='0.009 MB of 0.015 MB uploaded\r'), FloatProgress(value=0.6344432279255491, max=1.0…

0,1
accuracy,▁
auc,▁
b_accuracy,▁
f1,▁

0,1
accuracy,0.72628
auc,0.71102
b_accuracy,0.71102
f1,0.65116


VBox(children=(Label(value='Waiting for wandb.init()...\r'), FloatProgress(value=0.011277777777932999, max=1.0…

VBox(children=(Label(value='0.009 MB of 0.015 MB uploaded\r'), FloatProgress(value=0.6343615171614398, max=1.0…

0,1
accuracy,▁
auc,▁
b_accuracy,▁
f1,▁

0,1
accuracy,0.68978
auc,0.67802
b_accuracy,0.67802
f1,0.62222


VBox(children=(Label(value='Waiting for wandb.init()...\r'), FloatProgress(value=0.011111111111111112, max=1.0…

VBox(children=(Label(value='0.009 MB of 0.015 MB uploaded\r'), FloatProgress(value=0.6344432279255491, max=1.0…

0,1
accuracy,▁
auc,▁
b_accuracy,▁
f1,▁

0,1
accuracy,0.70803
auc,0.69466
b_accuracy,0.69466
f1,0.61538


VBox(children=(Label(value='Waiting for wandb.init()...\r'), FloatProgress(value=0.011111111111111112, max=1.0…

VBox(children=(Label(value='0.009 MB of 0.015 MB uploaded\r'), FloatProgress(value=0.6344023699124163, max=1.0…

0,1
accuracy,▁
auc,▁
b_accuracy,▁
f1,▁

0,1
accuracy,0.69963
auc,0.6863
b_accuracy,0.6863
f1,0.61682


VBox(children=(Label(value='Waiting for wandb.init()...\r'), FloatProgress(value=0.011288888888278356, max=1.0…

VBox(children=(Label(value='0.009 MB of 0.015 MB uploaded\r'), FloatProgress(value=0.6343615171614398, max=1.0…

0,1
accuracy,▁
auc,▁
b_accuracy,▁
f1,▁

0,1
accuracy,0.72161
auc,0.71688
b_accuracy,0.71688
f1,0.67797




VBox(children=(Label(value='Waiting for wandb.init()...\r'), FloatProgress(value=0.011111111111111112, max=1.0…



VBox(children=(Label(value='0.001 MB of 0.001 MB uploaded\r'), FloatProgress(value=1.0, max=1.0)))

0,1
accuracy,▁
auc,▁
b_accuracy,▁
f1,▁

0,1
accuracy,0.73358
auc,0.71716
b_accuracy,0.71716
f1,0.65728


VBox(children=(Label(value='Waiting for wandb.init()...\r'), FloatProgress(value=0.011111111111111112, max=1.0…



VBox(children=(Label(value='0.001 MB of 0.001 MB uploaded\r'), FloatProgress(value=1.0, max=1.0)))

0,1
accuracy,▁
auc,▁
b_accuracy,▁
f1,▁

0,1
accuracy,0.69708
auc,0.68753
b_accuracy,0.68753
f1,0.63755


VBox(children=(Label(value='Waiting for wandb.init()...\r'), FloatProgress(value=0.011277777777932999, max=1.0…



VBox(children=(Label(value='0.001 MB of 0.001 MB uploaded\r'), FloatProgress(value=1.0, max=1.0)))

0,1
accuracy,▁
auc,▁
b_accuracy,▁
f1,▁

0,1
accuracy,0.73723
auc,0.71752
b_accuracy,0.71752
f1,0.64


VBox(children=(Label(value='Waiting for wandb.init()...\r'), FloatProgress(value=0.01128888888957186, max=1.0)…



VBox(children=(Label(value='0.001 MB of 0.001 MB uploaded\r'), FloatProgress(value=1.0, max=1.0)))

0,1
accuracy,▁
auc,▁
b_accuracy,▁
f1,▁

0,1
accuracy,0.7033
auc,0.69154
b_accuracy,0.69154
f1,0.63014


VBox(children=(Label(value='Waiting for wandb.init()...\r'), FloatProgress(value=0.011288888888278356, max=1.0…

VBox(children=(Label(value='0.001 MB of 0.015 MB uploaded\r'), FloatProgress(value=0.07284554940100477, max=1.…

0,1
accuracy,▁
auc,▁
b_accuracy,▁
f1,▁

0,1
accuracy,0.7619
auc,0.75214
b_accuracy,0.75214
f1,0.71111


In [63]:
# Run model on cell count
X = df_count_cyto_nuclei
features = ['_'.join(col) for col in X.columns.values]
X = X.values
# X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.4, random_state=0)
kfold = KFold(n_splits = 5, shuffle = True, random_state = 0)

for model_name, model in models.items():
    for k, (train_index, test_index) in enumerate(kfold.split(X)):
        # Split the dataset
        X_train, X_test = X[train_index], X[test_index]
        y_train, y_test = y[train_index], y[test_index]
        
        model.fit(X_train, y_train)
        y_pred = model.predict(X_test)
        y_probas = model.predict_proba(X_test)

        run = wandb.init(project=project_name, group=model_name+'_cyto_nuclei', name=model_name+f'_cyto_nuclei_{k}')
        wandb.sklearn.plot_classifier(model, 
                                  X_train, X_test, 
                                  y_train, y_test, 
                                  y_pred, y_probas, 
                                  le.classes_, 
                                  is_binary=True, 
                                  model_name=model_name+'_cyto_nuclei', 
                                  feature_names=features)
        wandb.log({'roc': wandb.plots.ROC(y_test, y_probas, le.classes_)})
        wandb.log({'pr': wandb.plots.precision_recall(y_test, y_probas, le.classes_)})

        accuracy = metrics.accuracy_score(y_test, y_pred)
        b_accuracy = metrics.balanced_accuracy_score(y_test, y_pred)
        f1 = metrics.f1_score(y_test, y_pred)
        auc = metrics.roc_auc_score(y_test, y_probas[:, 1])
        wandb.log({"accuracy": accuracy, 'b_accuracy': b_accuracy, 'f1':f1, 'auc': auc})
    run.finish()


NameError: name 'df_count_cyto_nuclei' is not defined