In [1]:
import itertools
import os
import sys
from pathlib import Path

import cv2
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import skimage.io

from collections import defaultdict
from tqdm.auto import tqdm
from joblib import Parallel, delayed
import re
import h5py
import napari
import seaborn as sns

In [2]:
%load_ext autoreload
%autoreload 2

In [3]:
p_dir = (Path().cwd().parents[0]).absolute()

module_path = str(p_dir / "src")

if module_path not in sys.path:
    sys.path.append(module_path)

In [8]:
data_dir = (Path().cwd().parents[0] / 'data'/ 'pixelgen').absolute() / 'Tcell' / 'graphs'


In [9]:
import PPIGraph

In [10]:
condition_mapping = {'Control': 0, 'Stimulated': 1}

# Read per cell data

In [11]:
from torch_geometric.loader import DataLoader

In [12]:
# Dataset
dataset = PPIGraph.GraphDatasetPixelgen(data_dir, 'raw', 'pt',  condition_mapping=condition_mapping, n_c=2)

# Create Dataloader
loader = DataLoader(dataset, batch_size=1, shuffle=False)



In [13]:
dataset

GraphDatasetPixelgen(932)

In [14]:
dataset[0]

Data(edge_index=[2, 7188], pos=[1200, 3], labels=[1200, 80], weight=[7188], condition=0, id=0, train_mask=[1200], test_mask=[1200], x=[1200, 80], y=[1], edge_weight=[7188], name='0_0.gpt')

# ML model

In [15]:
from sklearn.svm import SVC
from sklearn.ensemble import RandomForestClassifier, AdaBoostClassifier, GradientBoostingClassifier
from sklearn.tree import DecisionTreeClassifier
from sklearn.naive_bayes import GaussianNB
from sklearn import preprocessing, metrics
from sklearn.neural_network import MLPClassifier
from sklearn.linear_model import LogisticRegression

from sklearn.model_selection import train_test_split, KFold
from sklearn.utils.class_weight import compute_class_weight
from sklearn.preprocessing import StandardScaler
import wandb

In [16]:
condition = 'ML'
project_name = f'Pixelgen_{condition}'

In [17]:
models = {
    'Adaboost': AdaBoostClassifier(),
    'DecisionTree': DecisionTreeClassifier(),
    'GradientBoosting' : GradientBoostingClassifier(),
    'NaiveBayes': GaussianNB(),
    'RandomForest': RandomForestClassifier(), 
    'SVM': SVC(probability =True),
    'LogisticRegression':  LogisticRegression(),
    'MLP': MLPClassifier(random_state=1, max_iter=100, hidden_layer_sizes=[16, 16, 16])
}



In [18]:
import torch 

data = torch.empty((0, 80))
label = []

for graph in dataset:
    new_row =graph.x.sum(axis = 0).unsqueeze(0)
    data = torch.cat((data, new_row), dim=0)
    label.append(graph.condition)


In [22]:
df = pd.DataFrame(data)
df['Condition'] = label

In [18]:
from sklearn.naive_bayes import GaussianNB
from sklearn import tree
from sklearn.ensemble import HistGradientBoostingClassifier


In [19]:
scaler = StandardScaler()

# Run model on cell count
X = data.numpy()
X = scaler.fit_transform(X)
y = np.array(label)
# X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.4, random_state=0)
kfold = KFold(n_splits = 5, shuffle = True, random_state = 0)

for model_name, model in models.items():
    for k, (train_index, test_index) in enumerate(kfold.split(X)):
        # Split the dataset
        X_train, X_test = X[train_index], X[test_index]
        y_train, y_test = y[train_index], y[test_index]
        
        model.fit(X_train, y_train)
        y_pred = model.predict(X_test)
        y_probas = model.predict_proba(X_test)

        run = wandb.init(project=project_name, group=model_name+'_cell', name=model_name+f'_cell_{k}')

        accuracy = metrics.accuracy_score(y_test, y_pred)
        b_accuracy = metrics.balanced_accuracy_score(y_test, y_pred)
        f1 = metrics.f1_score(y_test, y_pred)
        auc = metrics.roc_auc_score(y_test, y_pred)
        wandb.log({"accuracy": accuracy, 'b_accuracy': b_accuracy, 'f1':f1, 'auc': auc})
    run.finish()

Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Currently logged in as: [33mthoomas[0m. Use [1m`wandb login --relogin`[0m to force relogin


VBox(children=(Label(value='Waiting for wandb.init()...\r'), FloatProgress(value=0.016666666666666666, max=1.0…

VBox(children=(Label(value='0.001 MB of 0.001 MB uploaded (0.000 MB deduped)\r'), FloatProgress(value=1.0, max…

0,1
accuracy,▁
auc,▁
b_accuracy,▁
f1,▁

0,1
accuracy,0.97861
auc,0.97776
b_accuracy,0.97776
f1,0.97674


VBox(children=(Label(value='Waiting for wandb.init()...\r'), FloatProgress(value=0.016666666666666666, max=1.0…

VBox(children=(Label(value='0.001 MB of 0.001 MB uploaded (0.000 MB deduped)\r'), FloatProgress(value=1.0, max…

0,1
accuracy,▁
auc,▁
b_accuracy,▁
f1,▁

0,1
accuracy,0.98396
auc,0.98611
b_accuracy,0.98611
f1,0.98137


VBox(children=(Label(value='Waiting for wandb.init()...\r'), FloatProgress(value=0.016666666666666666, max=1.0…

VBox(children=(Label(value='0.001 MB of 0.001 MB uploaded (0.000 MB deduped)\r'), FloatProgress(value=1.0, max…

0,1
accuracy,▁
auc,▁
b_accuracy,▁
f1,▁

0,1
accuracy,0.96774
auc,0.96534
b_accuracy,0.96534
f1,0.96341


VBox(children=(Label(value='Waiting for wandb.init()...\r'), FloatProgress(value=0.01691666666883975, max=1.0)…

VBox(children=(Label(value='0.001 MB of 0.001 MB uploaded (0.000 MB deduped)\r'), FloatProgress(value=1.0, max…

0,1
accuracy,▁
auc,▁
b_accuracy,▁
f1,▁

0,1
accuracy,0.96774
auc,0.96736
b_accuracy,0.96736
f1,0.96907


VBox(children=(Label(value='Waiting for wandb.init()...\r'), FloatProgress(value=0.0169333333382383, max=1.0))…

VBox(children=(Label(value='0.001 MB of 0.001 MB uploaded (0.000 MB deduped)\r'), FloatProgress(value=1.0, max…

0,1
accuracy,▁
auc,▁
b_accuracy,▁
f1,▁

0,1
accuracy,0.98387
auc,0.98214
b_accuracy,0.98214
f1,0.98182


VBox(children=(Label(value='Waiting for wandb.init()...\r'), FloatProgress(value=0.0169333333382383, max=1.0))…

VBox(children=(Label(value='0.001 MB of 0.001 MB uploaded (0.000 MB deduped)\r'), FloatProgress(value=1.0, max…

0,1
accuracy,▁
auc,▁
b_accuracy,▁
f1,▁

0,1
accuracy,0.94652
auc,0.94851
b_accuracy,0.94851
f1,0.94444


VBox(children=(Label(value='Waiting for wandb.init()...\r'), FloatProgress(value=0.0169333333382383, max=1.0))…

VBox(children=(Label(value='0.001 MB of 0.001 MB uploaded (0.000 MB deduped)\r'), FloatProgress(value=1.0, max…

0,1
accuracy,▁
auc,▁
b_accuracy,▁
f1,▁

0,1
accuracy,0.96791
auc,0.96712
b_accuracy,0.96712
f1,0.96203


VBox(children=(Label(value='Waiting for wandb.init()...\r'), FloatProgress(value=0.016666666666666666, max=1.0…

VBox(children=(Label(value='0.001 MB of 0.001 MB uploaded (0.000 MB deduped)\r'), FloatProgress(value=1.0, max…

0,1
accuracy,▁
auc,▁
b_accuracy,▁
f1,▁

0,1
accuracy,0.97312
auc,0.97234
b_accuracy,0.97234
f1,0.97006


VBox(children=(Label(value='Waiting for wandb.init()...\r'), FloatProgress(value=0.016666666666666666, max=1.0…

VBox(children=(Label(value='0.001 MB of 0.001 MB uploaded (0.000 MB deduped)\r'), FloatProgress(value=1.0, max…

0,1
accuracy,▁
auc,▁
b_accuracy,▁
f1,▁

0,1
accuracy,0.96774
auc,0.96736
b_accuracy,0.96736
f1,0.96907


VBox(children=(Label(value='Waiting for wandb.init()...\r'), FloatProgress(value=0.016666666666666666, max=1.0…

VBox(children=(Label(value='0.001 MB of 0.001 MB uploaded (0.000 MB deduped)\r'), FloatProgress(value=1.0, max…

0,1
accuracy,▁
auc,▁
b_accuracy,▁
f1,▁

0,1
accuracy,0.95161
auc,0.94853
b_accuracy,0.94853
f1,0.94479


VBox(children=(Label(value='Waiting for wandb.init()...\r'), FloatProgress(value=0.016666666666666666, max=1.0…

VBox(children=(Label(value='0.001 MB of 0.001 MB uploaded (0.000 MB deduped)\r'), FloatProgress(value=1.0, max…

0,1
accuracy,▁
auc,▁
b_accuracy,▁
f1,▁

0,1
accuracy,0.98396
auc,0.98425
b_accuracy,0.98425
f1,0.98286


VBox(children=(Label(value='Waiting for wandb.init()...\r'), FloatProgress(value=0.01691666666107873, max=1.0)…

VBox(children=(Label(value='0.001 MB of 0.001 MB uploaded (0.000 MB deduped)\r'), FloatProgress(value=1.0, max…

0,1
accuracy,▁
auc,▁
b_accuracy,▁
f1,▁

0,1
accuracy,0.9893
auc,0.99074
b_accuracy,0.99074
f1,0.9875


VBox(children=(Label(value='Waiting for wandb.init()...\r'), FloatProgress(value=0.016666666666666666, max=1.0…

VBox(children=(Label(value='0.001 MB of 0.001 MB uploaded (0.000 MB deduped)\r'), FloatProgress(value=1.0, max…

0,1
accuracy,▁
auc,▁
b_accuracy,▁
f1,▁

0,1
accuracy,0.97312
auc,0.97339
b_accuracy,0.97339
f1,0.97041


VBox(children=(Label(value='Waiting for wandb.init()...\r'), FloatProgress(value=0.016666666666666666, max=1.0…

VBox(children=(Label(value='0.001 MB of 0.001 MB uploaded (0.000 MB deduped)\r'), FloatProgress(value=1.0, max…

0,1
accuracy,▁
auc,▁
b_accuracy,▁
f1,▁

0,1
accuracy,0.97849
auc,0.97813
b_accuracy,0.97813
f1,0.97938


VBox(children=(Label(value='Waiting for wandb.init()...\r'), FloatProgress(value=0.01691666666107873, max=1.0)…

VBox(children=(Label(value='0.001 MB of 0.001 MB uploaded (0.000 MB deduped)\r'), FloatProgress(value=1.0, max…

0,1
accuracy,▁
auc,▁
b_accuracy,▁
f1,▁

0,1
accuracy,0.98387
auc,0.98214
b_accuracy,0.98214
f1,0.98182


VBox(children=(Label(value='Waiting for wandb.init()...\r'), FloatProgress(value=0.0169333333382383, max=1.0))…

VBox(children=(Label(value='0.001 MB of 0.001 MB uploaded (0.000 MB deduped)\r'), FloatProgress(value=1.0, max…

0,1
accuracy,▁
auc,▁
b_accuracy,▁
f1,▁

0,1
accuracy,0.64706
auc,0.62443
b_accuracy,0.62443
f1,0.44068


VBox(children=(Label(value='Waiting for wandb.init()...\r'), FloatProgress(value=0.016666666666666666, max=1.0…

VBox(children=(Label(value='0.000 MB of 0.001 MB uploaded (0.000 MB deduped)\r'), FloatProgress(value=0.0, max…

0,1
accuracy,▁
auc,▁
b_accuracy,▁
f1,▁

0,1
accuracy,0.69519
auc,0.64434
b_accuracy,0.64434
f1,0.46729


VBox(children=(Label(value='Waiting for wandb.init()...\r'), FloatProgress(value=0.016666666666666666, max=1.0…

VBox(children=(Label(value='0.001 MB of 0.001 MB uploaded (0.000 MB deduped)\r'), FloatProgress(value=1.0, max…

0,1
accuracy,▁
auc,▁
b_accuracy,▁
f1,▁

0,1
accuracy,0.71505
auc,0.68768
b_accuracy,0.68768
f1,0.56198


VBox(children=(Label(value='Waiting for wandb.init()...\r'), FloatProgress(value=0.0169333333382383, max=1.0))…

VBox(children=(Label(value='0.001 MB of 0.001 MB uploaded (0.000 MB deduped)\r'), FloatProgress(value=1.0, max…

0,1
accuracy,▁
auc,▁
b_accuracy,▁
f1,▁

0,1
accuracy,0.65591
auc,0.66632
b_accuracy,0.66632
f1,0.50769


VBox(children=(Label(value='Waiting for wandb.init()...\r'), FloatProgress(value=0.01691666666883975, max=1.0)…

VBox(children=(Label(value='0.001 MB of 0.001 MB uploaded (0.000 MB deduped)\r'), FloatProgress(value=1.0, max…

0,1
accuracy,▁
auc,▁
b_accuracy,▁
f1,▁

0,1
accuracy,0.68817
auc,0.65791
b_accuracy,0.65791
f1,0.5


VBox(children=(Label(value='Waiting for wandb.init()...\r'), FloatProgress(value=0.016933333330477276, max=1.0…

VBox(children=(Label(value='0.001 MB of 0.001 MB uploaded (0.000 MB deduped)\r'), FloatProgress(value=1.0, max…

0,1
accuracy,▁
auc,▁
b_accuracy,▁
f1,▁

0,1
accuracy,0.9893
auc,0.99
b_accuracy,0.99
f1,0.98864


VBox(children=(Label(value='Waiting for wandb.init()...\r'), FloatProgress(value=0.016933333330477276, max=1.0…

VBox(children=(Label(value='0.001 MB of 0.001 MB uploaded (0.000 MB deduped)\r'), FloatProgress(value=1.0, max…

0,1
accuracy,▁
auc,▁
b_accuracy,▁
f1,▁

0,1
accuracy,0.97326
auc,0.97515
b_accuracy,0.97515
f1,0.96894


VBox(children=(Label(value='Waiting for wandb.init()...\r'), FloatProgress(value=0.016666666666666666, max=1.0…

VBox(children=(Label(value='0.001 MB of 0.001 MB uploaded (0.000 MB deduped)\r'), FloatProgress(value=1.0, max…

0,1
accuracy,▁
auc,▁
b_accuracy,▁
f1,▁

0,1
accuracy,0.97312
auc,0.97444
b_accuracy,0.97444
f1,0.97076


VBox(children=(Label(value='Waiting for wandb.init()...\r'), FloatProgress(value=0.016666666666666666, max=1.0…

VBox(children=(Label(value='0.001 MB of 0.001 MB uploaded (0.000 MB deduped)\r'), FloatProgress(value=1.0, max…

0,1
accuracy,▁
auc,▁
b_accuracy,▁
f1,▁

0,1
accuracy,0.98387
auc,0.98403
b_accuracy,0.98403
f1,0.98429


VBox(children=(Label(value='Waiting for wandb.init()...\r'), FloatProgress(value=0.016666666666666666, max=1.0…

VBox(children=(Label(value='0.001 MB of 0.001 MB uploaded (0.000 MB deduped)\r'), FloatProgress(value=1.0, max…

0,1
accuracy,▁
auc,▁
b_accuracy,▁
f1,▁

0,1
accuracy,0.98925
auc,0.9881
b_accuracy,0.9881
f1,0.98795


VBox(children=(Label(value='Waiting for wandb.init()...\r'), FloatProgress(value=0.016933333330477276, max=1.0…

VBox(children=(Label(value='0.001 MB of 0.001 MB uploaded (0.000 MB deduped)\r'), FloatProgress(value=1.0, max…

0,1
accuracy,▁
auc,▁
b_accuracy,▁
f1,▁

0,1
accuracy,0.97326
auc,0.975
b_accuracy,0.975
f1,0.97207


VBox(children=(Label(value='Waiting for wandb.init()...\r'), FloatProgress(value=0.016933333330477276, max=1.0…

VBox(children=(Label(value='0.000 MB of 0.001 MB uploaded (0.000 MB deduped)\r'), FloatProgress(value=0.0, max…

0,1
accuracy,▁
auc,▁
b_accuracy,▁
f1,▁

0,1
accuracy,0.95722
auc,0.96296
b_accuracy,0.96296
f1,0.95181


VBox(children=(Label(value='Waiting for wandb.init()...\r'), FloatProgress(value=0.01691666666883975, max=1.0)…

VBox(children=(Label(value='0.001 MB of 0.001 MB uploaded (0.000 MB deduped)\r'), FloatProgress(value=1.0, max…

0,1
accuracy,▁
auc,▁
b_accuracy,▁
f1,▁

0,1
accuracy,0.9086
auc,0.91246
b_accuracy,0.91246
f1,0.90395


VBox(children=(Label(value='Waiting for wandb.init()...\r'), FloatProgress(value=0.016666666666666666, max=1.0…

VBox(children=(Label(value='0.001 MB of 0.001 MB uploaded (0.000 MB deduped)\r'), FloatProgress(value=1.0, max…

0,1
accuracy,▁
auc,▁
b_accuracy,▁
f1,▁

0,1
accuracy,0.98387
auc,0.98403
b_accuracy,0.98403
f1,0.98429


VBox(children=(Label(value='Waiting for wandb.init()...\r'), FloatProgress(value=0.0169333333382383, max=1.0))…

VBox(children=(Label(value='0.001 MB of 0.001 MB uploaded (0.000 MB deduped)\r'), FloatProgress(value=1.0, max…

0,1
accuracy,▁
auc,▁
b_accuracy,▁
f1,▁

0,1
accuracy,0.95699
auc,0.95553
b_accuracy,0.95553
f1,0.95181


VBox(children=(Label(value='Waiting for wandb.init()...\r'), FloatProgress(value=0.016933333330477276, max=1.0…

VBox(children=(Label(value='0.001 MB of 0.001 MB uploaded (0.000 MB deduped)\r'), FloatProgress(value=1.0, max…

0,1
accuracy,▁
auc,▁
b_accuracy,▁
f1,▁

0,1
accuracy,0.97861
auc,0.97925
b_accuracy,0.97925
f1,0.97727


VBox(children=(Label(value='Waiting for wandb.init()...\r'), FloatProgress(value=0.016933333330477276, max=1.0…

0,1
accuracy,▁
auc,▁
b_accuracy,▁
f1,▁

0,1
accuracy,0.97861
auc,0.97978
b_accuracy,0.97978
f1,0.975


VBox(children=(Label(value='Waiting for wandb.init()...\r'), FloatProgress(value=0.016666666666666666, max=1.0…

VBox(children=(Label(value='0.001 MB of 0.001 MB uploaded (0.000 MB deduped)\r'), FloatProgress(value=1.0, max…

0,1
accuracy,▁
auc,▁
b_accuracy,▁
f1,▁

0,1
accuracy,0.96774
auc,0.96744
b_accuracy,0.96744
f1,0.96429


VBox(children=(Label(value='Waiting for wandb.init()...\r'), FloatProgress(value=0.016666666666666666, max=1.0…

VBox(children=(Label(value='0.001 MB of 0.001 MB uploaded (0.000 MB deduped)\r'), FloatProgress(value=1.0, max…

0,1
accuracy,▁
auc,▁
b_accuracy,▁
f1,▁

0,1
accuracy,0.98387
auc,0.98403
b_accuracy,0.98403
f1,0.98429


VBox(children=(Label(value='Waiting for wandb.init()...\r'), FloatProgress(value=0.01691666666107873, max=1.0)…

VBox(children=(Label(value='0.001 MB of 0.001 MB uploaded (0.000 MB deduped)\r'), FloatProgress(value=1.0, max…

0,1
accuracy,▁
auc,▁
b_accuracy,▁
f1,▁

0,1
accuracy,0.95699
auc,0.95343
b_accuracy,0.95343
f1,0.95062




VBox(children=(Label(value='Waiting for wandb.init()...\r'), FloatProgress(value=0.01691666666883975, max=1.0)…



VBox(children=(Label(value='0.001 MB of 0.001 MB uploaded (0.000 MB deduped)\r'), FloatProgress(value=1.0, max…

0,1
accuracy,▁
auc,▁
b_accuracy,▁
f1,▁

0,1
accuracy,0.97861
auc,0.97851
b_accuracy,0.97851
f1,0.97701


VBox(children=(Label(value='Waiting for wandb.init()...\r'), FloatProgress(value=0.01691666666883975, max=1.0)…



VBox(children=(Label(value='0.001 MB of 0.001 MB uploaded (0.000 MB deduped)\r'), FloatProgress(value=1.0, max…

0,1
accuracy,▁
auc,▁
b_accuracy,▁
f1,▁

0,1
accuracy,0.99465
auc,0.99537
b_accuracy,0.99537
f1,0.99371


VBox(children=(Label(value='Waiting for wandb.init()...\r'), FloatProgress(value=0.016666666666666666, max=1.0…



VBox(children=(Label(value='0.001 MB of 0.001 MB uploaded (0.000 MB deduped)\r'), FloatProgress(value=1.0, max…

0,1
accuracy,▁
auc,▁
b_accuracy,▁
f1,▁

0,1
accuracy,0.96774
auc,0.96849
b_accuracy,0.96849
f1,0.96471


VBox(children=(Label(value='Waiting for wandb.init()...\r'), FloatProgress(value=0.016666666666666666, max=1.0…



VBox(children=(Label(value='0.001 MB of 0.001 MB uploaded (0.000 MB deduped)\r'), FloatProgress(value=1.0, max…

0,1
accuracy,▁
auc,▁
b_accuracy,▁
f1,▁

0,1
accuracy,0.97849
auc,0.97882
b_accuracy,0.97882
f1,0.97895


VBox(children=(Label(value='Waiting for wandb.init()...\r'), FloatProgress(value=0.016666666666666666, max=1.0…

0,1
accuracy,▁
auc,▁
b_accuracy,▁
f1,▁

0,1
accuracy,0.97849
auc,0.97724
b_accuracy,0.97724
f1,0.9759
