In [56]:
import itertools
import os
import sys
from pathlib import Path

import cv2
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import skimage.io

from collections import defaultdict
from tqdm.auto import tqdm
from joblib import Parallel, delayed
import re
import h5py
import napari
import seaborn as sns

In [57]:
%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [58]:
p_dir = (Path().cwd().parents[0]).absolute()

module_path = str(p_dir / "src")

if module_path not in sys.path:
    sys.path.append(module_path)

In [59]:
data_dir = (Path().cwd().parents[0] / 'data'/'13cyc').absolute() / 'graphs'


In [60]:
import PPIGraph

In [61]:
condition_mapping = {'control': 0, '100nM': 1}

# Read per cell data

In [62]:
from torch_geometric.loader import DataLoader

In [63]:
# Filter out by maximum number of counts per cell
min_count = 20
max_count = 70

dataset = PPIGraph.GraphDataset(data_dir, 'raw', 'pt',  condition_mapping=condition_mapping, n_c=2)

# Create Dataloader
loader = DataLoader(dataset, batch_size=1, shuffle=False)

# Get Indices
indices = []
for step, data in enumerate(loader):
    if len(data.x) <= min_count:
        continue 
    
    if (data.x.sum(axis=0) >= max_count).any():
        continue
    indices.append(step)
    
# Get subset dataset
dataset_filtered = dataset.index_select(indices)


In [64]:
dataset

GraphDataset(2230)

In [65]:
dataset_filtered

GraphDataset(2117)

# ML model

In [66]:
from sklearn.svm import SVC
from sklearn.ensemble import RandomForestClassifier, AdaBoostClassifier, GradientBoostingClassifier
from sklearn.tree import DecisionTreeClassifier
from sklearn.naive_bayes import GaussianNB
from sklearn import preprocessing, metrics
from sklearn.neural_network import MLPClassifier
from sklearn.linear_model import LogisticRegression

from sklearn.model_selection import train_test_split, KFold
from sklearn.utils.class_weight import compute_class_weight
from sklearn.preprocessing import StandardScaler
import wandb

In [67]:
condition = 'ML'
project_name = f'PLA_13PPI_121923_{condition}'

In [68]:
models = {
    'Adaboost': AdaBoostClassifier(),
    'DecisionTree': DecisionTreeClassifier(),
    'GradientBoosting' : GradientBoostingClassifier(),
    'NaiveBayes': GaussianNB(),
    'RandomForest': RandomForestClassifier(), 
    'SVM': SVC(probability =True),
    'LogisticRegression':  LogisticRegression(),
    'MLP': MLPClassifier(random_state=1, max_iter=100, hidden_layer_sizes=[16, 16, 16])
}



In [69]:
import torch 

data = torch.empty((0, 13))
label = []

for graph in dataset_filtered:
    new_row =graph.x.sum(axis = 0).unsqueeze(0)
    data = torch.cat((data, new_row), dim=0)
    label.append(graph.condition)


In [70]:
from sklearn.naive_bayes import GaussianNB
from sklearn import tree
from sklearn.ensemble import HistGradientBoostingClassifier


In [71]:
scaler = StandardScaler()

# Run model on cell count
X = data.numpy()
X = scaler.fit_transform(X)
y = np.array(label)
# X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.4, random_state=0)
kfold = KFold(n_splits = 5, shuffle = True, random_state = 0)

for model_name, model in models.items():
    for k, (train_index, test_index) in enumerate(kfold.split(X)):
        # Split the dataset
        X_train, X_test = X[train_index], X[test_index]
        y_train, y_test = y[train_index], y[test_index]
        
        model.fit(X_train, y_train)
        y_pred = model.predict(X_test)
        y_probas = model.predict_proba(X_test)

        run = wandb.init(project=project_name, group=model_name+'_cell', name=model_name+f'_cell_{k}')

        accuracy = metrics.accuracy_score(y_test, y_pred)
        b_accuracy = metrics.balanced_accuracy_score(y_test, y_pred)
        f1 = metrics.f1_score(y_test, y_pred)
        auc = metrics.roc_auc_score(y_test, y_pred)
        wandb.log({"accuracy": accuracy, 'b_accuracy': b_accuracy, 'f1':f1, 'auc': auc})
    run.finish()

VBox(children=(Label(value='0.009 MB of 0.015 MB uploaded\r'), FloatProgress(value=0.6345076318670703, max=1.0…

0,1
accuracy,▁
auc,▁
b_accuracy,▁
f1,▁

0,1
accuracy,0.83726
auc,0.83714
b_accuracy,0.83714
f1,0.83373


VBox(children=(Label(value='Waiting for wandb.init()...\r'), FloatProgress(value=0.011111111111111112, max=1.0…

VBox(children=(Label(value='0.002 MB of 0.015 MB uploaded\r'), FloatProgress(value=0.11056732564878614, max=1.…

0,1
accuracy,▁
auc,▁
b_accuracy,▁
f1,▁

0,1
accuracy,0.75708
auc,0.75551
b_accuracy,0.75551
f1,0.73522


VBox(children=(Label(value='Waiting for wandb.init()...\r'), FloatProgress(value=0.01128888888957186, max=1.0)…

VBox(children=(Label(value='0.002 MB of 0.015 MB uploaded\r'), FloatProgress(value=0.11050293000193187, max=1.…

0,1
accuracy,▁
auc,▁
b_accuracy,▁
f1,▁

0,1
accuracy,0.80851
auc,0.80667
b_accuracy,0.80667
f1,0.82276


VBox(children=(Label(value='Waiting for wandb.init()...\r'), FloatProgress(value=0.01128888888957186, max=1.0)…

VBox(children=(Label(value='0.009 MB of 0.015 MB uploaded\r'), FloatProgress(value=0.6344259128082942, max=1.0…

0,1
accuracy,▁
auc,▁
b_accuracy,▁
f1,▁

0,1
accuracy,0.85106
auc,0.8511
b_accuracy,0.8511
f1,0.85315


VBox(children=(Label(value='Waiting for wandb.init()...\r'), FloatProgress(value=0.01128888888957186, max=1.0)…

VBox(children=(Label(value='0.009 MB of 0.015 MB uploaded\r'), FloatProgress(value=0.6343850611719253, max=1.0…

0,1
accuracy,▁
auc,▁
b_accuracy,▁
f1,▁

0,1
accuracy,0.8227
auc,0.82275
b_accuracy,0.82275
f1,0.82185


VBox(children=(Label(value='Waiting for wandb.init()...\r'), FloatProgress(value=0.011277777777932999, max=1.0…

VBox(children=(Label(value='0.002 MB of 0.015 MB uploaded\r'), FloatProgress(value=0.11049581455247907, max=1.…

0,1
accuracy,▁
auc,▁
b_accuracy,▁
f1,▁

0,1
accuracy,0.75708
auc,0.75717
b_accuracy,0.75717
f1,0.75765


VBox(children=(Label(value='Waiting for wandb.init()...\r'), FloatProgress(value=0.011277777777932999, max=1.0…

VBox(children=(Label(value='0.009 MB of 0.015 MB uploaded\r'), FloatProgress(value=0.6343615171614398, max=1.0…

0,1
accuracy,▁
auc,▁
b_accuracy,▁
f1,▁

0,1
accuracy,0.73349
auc,0.73176
b_accuracy,0.73176
f1,0.70951


VBox(children=(Label(value='Waiting for wandb.init()...\r'), FloatProgress(value=0.011111111111111112, max=1.0…

VBox(children=(Label(value='0.009 MB of 0.015 MB uploaded\r'), FloatProgress(value=0.634524959742351, max=1.0)…

0,1
accuracy,▁
auc,▁
b_accuracy,▁
f1,▁

0,1
accuracy,0.76123
auc,0.76034
b_accuracy,0.76034
f1,0.77303


VBox(children=(Label(value='Waiting for wandb.init()...\r'), FloatProgress(value=0.01128888888957186, max=1.0)…

VBox(children=(Label(value='0.009 MB of 0.015 MB uploaded\r'), FloatProgress(value=0.6343615171614398, max=1.0…

0,1
accuracy,▁
auc,▁
b_accuracy,▁
f1,▁

0,1
accuracy,0.76359
auc,0.7635
b_accuracy,0.7635
f1,0.75369


VBox(children=(Label(value='Waiting for wandb.init()...\r'), FloatProgress(value=0.011277777777932999, max=1.0…

VBox(children=(Label(value='0.001 MB of 0.015 MB uploaded\r'), FloatProgress(value=0.07283616692426584, max=1.…

0,1
accuracy,▁
auc,▁
b_accuracy,▁
f1,▁

0,1
accuracy,0.70449
auc,0.70481
b_accuracy,0.70481
f1,0.70998


VBox(children=(Label(value='Waiting for wandb.init()...\r'), FloatProgress(value=0.01128888888957186, max=1.0)…

VBox(children=(Label(value='0.015 MB of 0.015 MB uploaded\r'), FloatProgress(value=1.0, max=1.0)))

0,1
accuracy,▁
auc,▁
b_accuracy,▁
f1,▁

0,1
accuracy,0.86085
auc,0.86095
b_accuracy,0.86095
f1,0.86118


VBox(children=(Label(value='Waiting for wandb.init()...\r'), FloatProgress(value=0.011111111111111112, max=1.0…

VBox(children=(Label(value='0.009 MB of 0.015 MB uploaded\r'), FloatProgress(value=0.6343615171614398, max=1.0…

0,1
accuracy,▁
auc,▁
b_accuracy,▁
f1,▁

0,1
accuracy,0.77123
auc,0.77098
b_accuracy,0.77098
f1,0.75443


VBox(children=(Label(value='Waiting for wandb.init()...\r'), FloatProgress(value=0.011277777777932999, max=1.0…

VBox(children=(Label(value='0.015 MB of 0.015 MB uploaded\r'), FloatProgress(value=1.0, max=1.0)))

0,1
accuracy,▁
auc,▁
b_accuracy,▁
f1,▁

0,1
accuracy,0.82033
auc,0.81843
b_accuracy,0.81843
f1,0.83406


VBox(children=(Label(value='Waiting for wandb.init()...\r'), FloatProgress(value=0.011111111111111112, max=1.0…

VBox(children=(Label(value='0.009 MB of 0.015 MB uploaded\r'), FloatProgress(value=0.6343615171614398, max=1.0…

0,1
accuracy,▁
auc,▁
b_accuracy,▁
f1,▁

0,1
accuracy,0.85106
auc,0.8511
b_accuracy,0.8511
f1,0.85315


VBox(children=(Label(value='Waiting for wandb.init()...\r'), FloatProgress(value=0.011111111111111112, max=1.0…

VBox(children=(Label(value='0.009 MB of 0.015 MB uploaded\r'), FloatProgress(value=0.634484091201855, max=1.0)…

0,1
accuracy,▁
auc,▁
b_accuracy,▁
f1,▁

0,1
accuracy,0.83452
auc,0.835
b_accuracy,0.835
f1,0.83945


VBox(children=(Label(value='Waiting for wandb.init()...\r'), FloatProgress(value=0.011277777777932999, max=1.0…

VBox(children=(Label(value='0.009 MB of 0.015 MB uploaded\r'), FloatProgress(value=0.6344023699124163, max=1.0…

0,1
accuracy,▁
auc,▁
b_accuracy,▁
f1,▁

0,1
accuracy,0.6934
auc,0.6947
b_accuracy,0.6947
f1,0.72917


VBox(children=(Label(value='Waiting for wandb.init()...\r'), FloatProgress(value=0.011111111111111112, max=1.0…

VBox(children=(Label(value='0.009 MB of 0.015 MB uploaded\r'), FloatProgress(value=0.6343615171614398, max=1.0…

0,1
accuracy,▁
auc,▁
b_accuracy,▁
f1,▁

0,1
accuracy,0.66509
auc,0.67436
b_accuracy,0.67436
f1,0.68161


VBox(children=(Label(value='Waiting for wandb.init()...\r'), FloatProgress(value=0.011111111111111112, max=1.0…

VBox(children=(Label(value='0.009 MB of 0.015 MB uploaded\r'), FloatProgress(value=0.6344432279255491, max=1.0…

0,1
accuracy,▁
auc,▁
b_accuracy,▁
f1,▁

0,1
accuracy,0.6643
auc,0.65784
b_accuracy,0.65784
f1,0.72157


VBox(children=(Label(value='Waiting for wandb.init()...\r'), FloatProgress(value=0.011111111111111112, max=1.0…

VBox(children=(Label(value='0.009 MB of 0.015 MB uploaded\r'), FloatProgress(value=0.6343615171614398, max=1.0…

0,1
accuracy,▁
auc,▁
b_accuracy,▁
f1,▁

0,1
accuracy,0.73286
auc,0.73321
b_accuracy,0.73321
f1,0.76701


VBox(children=(Label(value='Waiting for wandb.init()...\r'), FloatProgress(value=0.011111111111111112, max=1.0…

VBox(children=(Label(value='0.002 MB of 0.015 MB uploaded\r'), FloatProgress(value=0.11056020605280104, max=1.…

0,1
accuracy,▁
auc,▁
b_accuracy,▁
f1,▁

0,1
accuracy,0.71631
auc,0.71745
b_accuracy,0.71745
f1,0.73913


VBox(children=(Label(value='Waiting for wandb.init()...\r'), FloatProgress(value=0.011111111111111112, max=1.0…

VBox(children=(Label(value='0.001 MB of 0.015 MB uploaded\r'), FloatProgress(value=0.07284085786050107, max=1.…

0,1
accuracy,▁
auc,▁
b_accuracy,▁
f1,▁

0,1
accuracy,0.83255
auc,0.83269
b_accuracy,0.83269
f1,0.83372


VBox(children=(Label(value='Waiting for wandb.init()...\r'), FloatProgress(value=0.011288888888278356, max=1.0…

VBox(children=(Label(value='0.010 MB of 0.015 MB uploaded\r'), FloatProgress(value=0.686288400850132, max=1.0)…

0,1
accuracy,▁
auc,▁
b_accuracy,▁
f1,▁

0,1
accuracy,0.75708
auc,0.75713
b_accuracy,0.75713
f1,0.74055


VBox(children=(Label(value='Waiting for wandb.init()...\r'), FloatProgress(value=0.011277777777932999, max=1.0…

VBox(children=(Label(value='0.009 MB of 0.015 MB uploaded\r'), FloatProgress(value=0.6344432279255491, max=1.0…

0,1
accuracy,▁
auc,▁
b_accuracy,▁
f1,▁

0,1
accuracy,0.81797
auc,0.81681
b_accuracy,0.81681
f1,0.82851


VBox(children=(Label(value='Waiting for wandb.init()...\r'), FloatProgress(value=0.011277777777932999, max=1.0…

VBox(children=(Label(value='0.015 MB of 0.015 MB uploaded\r'), FloatProgress(value=1.0, max=1.0)))

0,1
accuracy,▁
auc,▁
b_accuracy,▁
f1,▁

0,1
accuracy,0.86761
auc,0.86765
b_accuracy,0.86765
f1,0.86916


VBox(children=(Label(value='Waiting for wandb.init()...\r'), FloatProgress(value=0.01128888888957186, max=1.0)…

VBox(children=(Label(value='0.009 MB of 0.015 MB uploaded\r'), FloatProgress(value=0.6343615171614398, max=1.0…

0,1
accuracy,▁
auc,▁
b_accuracy,▁
f1,▁

0,1
accuracy,0.84161
auc,0.84178
b_accuracy,0.84178
f1,0.84235


VBox(children=(Label(value='0.015 MB of 0.015 MB uploaded\r'), FloatProgress(value=1.0, max=1.0)))

0,1
accuracy,▁
auc,▁
b_accuracy,▁
f1,▁

0,1
accuracy,0.83255
auc,0.83224
b_accuracy,0.83224
f1,0.82555


VBox(children=(Label(value='Waiting for wandb.init()...\r'), FloatProgress(value=0.011111111111111112, max=1.0…

VBox(children=(Label(value='0.001 MB of 0.015 MB uploaded\r'), FloatProgress(value=0.07283147659218236, max=1.…

0,1
accuracy,▁
auc,▁
b_accuracy,▁
f1,▁

0,1
accuracy,0.8066
auc,0.80923
b_accuracy,0.80923
f1,0.79902


VBox(children=(Label(value='Waiting for wandb.init()...\r'), FloatProgress(value=0.011466666666739104, max=1.0…

VBox(children=(Label(value='0.001 MB of 0.015 MB uploaded\r'), FloatProgress(value=0.07289587223903664, max=1.…

0,1
accuracy,▁
auc,▁
b_accuracy,▁
f1,▁

0,1
accuracy,0.83215
auc,0.83186
b_accuracy,0.83186
f1,0.83827


VBox(children=(Label(value='Waiting for wandb.init()...\r'), FloatProgress(value=0.01128888888957186, max=1.0)…

VBox(children=(Label(value='0.015 MB of 0.015 MB uploaded\r'), FloatProgress(value=1.0, max=1.0)))

0,1
accuracy,▁
auc,▁
b_accuracy,▁
f1,▁

0,1
accuracy,0.83924
auc,0.83926
b_accuracy,0.83926
f1,0.84038


VBox(children=(Label(value='Waiting for wandb.init()...\r'), FloatProgress(value=0.011288888888278356, max=1.0…

VBox(children=(Label(value='0.009 MB of 0.015 MB uploaded\r'), FloatProgress(value=0.6343850611719253, max=1.0…

0,1
accuracy,▁
auc,▁
b_accuracy,▁
f1,▁

0,1
accuracy,0.80851
auc,0.8084
b_accuracy,0.8084
f1,0.80482


VBox(children=(Label(value='Waiting for wandb.init()...\r'), FloatProgress(value=0.011111111111111112, max=1.0…

VBox(children=(Label(value='0.009 MB of 0.015 MB uploaded\r'), FloatProgress(value=0.6345076318670703, max=1.0…

0,1
accuracy,▁
auc,▁
b_accuracy,▁
f1,▁

0,1
accuracy,0.79245
auc,0.79226
b_accuracy,0.79226
f1,0.78641


VBox(children=(Label(value='Waiting for wandb.init()...\r'), FloatProgress(value=0.01128888888957186, max=1.0)…

VBox(children=(Label(value='0.001 MB of 0.015 MB uploaded\r'), FloatProgress(value=0.0728911783644559, max=1.0…

0,1
accuracy,▁
auc,▁
b_accuracy,▁
f1,▁

0,1
accuracy,0.7783
auc,0.78032
b_accuracy,0.78032
f1,0.76847


VBox(children=(Label(value='Waiting for wandb.init()...\r'), FloatProgress(value=0.011288888888278356, max=1.0…

VBox(children=(Label(value='0.001 MB of 0.015 MB uploaded\r'), FloatProgress(value=0.0728911783644559, max=1.0…

0,1
accuracy,▁
auc,▁
b_accuracy,▁
f1,▁

0,1
accuracy,0.80378
auc,0.80328
b_accuracy,0.80328
f1,0.81179


VBox(children=(Label(value='Waiting for wandb.init()...\r'), FloatProgress(value=0.011277777777932999, max=1.0…

VBox(children=(Label(value='0.015 MB of 0.015 MB uploaded\r'), FloatProgress(value=1.0, max=1.0)))

0,1
accuracy,▁
auc,▁
b_accuracy,▁
f1,▁

0,1
accuracy,0.8156
auc,0.8156
b_accuracy,0.8156
f1,0.81517


VBox(children=(Label(value='Waiting for wandb.init()...\r'), FloatProgress(value=0.011111111111111112, max=1.0…

VBox(children=(Label(value='0.009 MB of 0.015 MB uploaded\r'), FloatProgress(value=0.634466769706337, max=1.0)…

0,1
accuracy,▁
auc,▁
b_accuracy,▁
f1,▁

0,1
accuracy,0.7896
auc,0.78965
b_accuracy,0.78965
f1,0.7886






VBox(children=(Label(value='0.001 MB of 0.001 MB uploaded\r'), FloatProgress(value=1.0, max=1.0)))

0,1
accuracy,▁
auc,▁
b_accuracy,▁
f1,▁

0,1
accuracy,0.82783
auc,0.82784
b_accuracy,0.82784
f1,0.8266


VBox(children=(Label(value='Waiting for wandb.init()...\r'), FloatProgress(value=0.011111111111111112, max=1.0…



VBox(children=(Label(value='0.001 MB of 0.001 MB uploaded\r'), FloatProgress(value=1.0, max=1.0)))

0,1
accuracy,▁
auc,▁
b_accuracy,▁
f1,▁

0,1
accuracy,0.79245
auc,0.79538
b_accuracy,0.79538
f1,0.78537


VBox(children=(Label(value='Waiting for wandb.init()...\r'), FloatProgress(value=0.011111111111111112, max=1.0…



VBox(children=(Label(value='0.001 MB of 0.001 MB uploaded\r'), FloatProgress(value=1.0, max=1.0)))

0,1
accuracy,▁
auc,▁
b_accuracy,▁
f1,▁

0,1
accuracy,0.81797
auc,0.81681
b_accuracy,0.81681
f1,0.82851


VBox(children=(Label(value='Waiting for wandb.init()...\r'), FloatProgress(value=0.011111111111111112, max=1.0…



VBox(children=(Label(value='0.001 MB of 0.001 MB uploaded\r'), FloatProgress(value=1.0, max=1.0)))

0,1
accuracy,▁
auc,▁
b_accuracy,▁
f1,▁

0,1
accuracy,0.86525
auc,0.86528
b_accuracy,0.86528
f1,0.86651


VBox(children=(Label(value='Waiting for wandb.init()...\r'), FloatProgress(value=0.011111111111111112, max=1.0…

VBox(children=(Label(value='0.009 MB of 0.015 MB uploaded\r'), FloatProgress(value=0.6344023699124163, max=1.0…

0,1
accuracy,▁
auc,▁
b_accuracy,▁
f1,▁

0,1
accuracy,0.8227
auc,0.82281
b_accuracy,0.82281
f1,0.8227
