In [58]:
import os
import re
import sys
from pathlib import Path

import h5py
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import scanpy as sc
import skimage
from skimage import io
from sklearn import preprocessing
from tqdm.notebook import tqdm, trange
import anndata as ad
import cv2
import scanorama
from sklearn.model_selection import train_test_split
import seaborn as sns

In [59]:
# Import spatial omics library
import athena as ath
from spatialOmics import SpatialOmics

# import default graph builder parameters
from athena.graph_builder.constants import GRAPH_BUILDER_DEFAULT_PARAMS

In [60]:
# Directory two levels above the notebook's working directory; the shared
# datasets folder "09_datasets" lives directly inside it.
# (Path.cwd() is already absolute, so no extra .absolute() call is needed.)
d_dir = Path.cwd().parents[1]
data_dir = d_dir.joinpath("09_datasets")

# Directory one level above the notebook's working directory.
p_dir = Path.cwd().parents[0]

In [61]:
%load_ext autoreload
%autoreload 2

module_path = str(p_dir / "src")

if module_path not in sys.path:
    sys.path.append(module_path)


The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [62]:
import graph
import torch
import torch_geometric.utils
import networkx as nx
import lightning.pytorch as pl
import torch.utils.data as data

# Both folders sit under "data" in the parent of the working directory.
# (Path.cwd() is already absolute, so no extra .absolute() call is needed.)
spatial_omics_folder = Path.cwd().parents[0].joinpath('data', 'spatial_omics_graph')
process_path = Path.cwd().parents[0].joinpath('data', 'torch_graph_data')

# Create data loader

In [63]:
from torch_geometric.loader import DataLoader

# Seeded generator; it is handed to the shuffling training loader below so
# the order of training batches is reproducible from run to run.
seed = torch.Generator().manual_seed(42)

name = 'All_roi'

# Create dataset
# NOTE(review): the trailing 6 and 3 presumably are the feature and class
# counts (they match what the summary cell prints) — confirm against
# graph.GraphDatasetPos.
dataset = graph.GraphDatasetPos(process_path / name, process_path / name / 'info.csv', 6, 3)

train_set, val_set, test_set = graph.train_test_val_split(dataset)

# Create Dataloader
# Only the training loader shuffles; validation and test batches are kept
# in a fixed order so evaluation is repeatable.
train_loader = DataLoader(train_set, batch_size=32, shuffle=True, generator=seed)
val_loader = DataLoader(val_set, batch_size=32, shuffle=False)
test_loader = DataLoader(test_set, batch_size=32, shuffle=False)


In [64]:
# Summarise the dataset: its repr, number of graphs, features and classes.
print(
    f'Dataset: {dataset}:',
    '======================',
    f'Number of graphs: {len(dataset)}',
    f'Number of features: {dataset.num_features}',
    f'Number of classes: {dataset.num_classes}',
    sep='\n',
)

Dataset: GraphDatasetPos(932):
Number of graphs: 932
Number of features: 6
Number of classes: 3


In [65]:
# Report the size of each split; each label names the set it counts
# (the test-set count was previously mislabelled "val set").
print(f'Train set: {len(train_set)}, test set: {len(test_set)}, val set: {len(val_set)}')

Train set: 448, val set: 372, val set: 112


In [66]:
# Inspect only the first batch produced by the test loader.
# The loop variable is called `batch` so it does not overwrite the `data`
# alias bound by `import torch.utils.data as data`.
for step, batch in enumerate(test_loader):
    print(f'Step {step + 1}:')
    print('=======')
    print(f'Number of graphs in the current batch: {batch.num_graphs}')
    print(batch)
    print()
    # A bare expression inside a loop body is not displayed by the notebook,
    # so the labels are printed explicitly.
    print(batch.label)
    break

Step 1:
Number of graphs in the current batch: 32
DataBatch(edge_index=[2, 1510773], num_nodes=218275, x=[218275, 6], pos=[218275, 2], node_types=[218275], label=[32], train_mask=[218275], test_mask=[218275], batch=[218275], ptr=[33])



# Graph learning

In [67]:
from lightning.pytorch.accelerators import find_usable_cuda_devices
import wandb

In [68]:
# Ask Lightning which CUDA devices are usable; as the last expression of the
# cell, the returned value is displayed as the cell output.
find_usable_cuda_devices()

[0, 1]

In [69]:
# GPU indices passed to the training helper in the model-training cell.
AVAIL_GPUS = [1]
BATCH_SIZE = 64 if AVAIL_GPUS else 32
# Path to the folder where the pretrained models are saved
CHECKPOINT_PATH = (Path().cwd().parents[0]).absolute() / 'data' / "saved_models" / f"Pos_GNNs_{name}"
CHECKPOINT_PATH.mkdir(parents=True, exist_ok=True)
# Half of the available CPUs. os.cpu_count() returns None when the count
# cannot be determined, so fall back to 1 instead of raising a TypeError.
NUM_WORKERS = (os.cpu_count() or 1) // 2

# Setting the seed
pl.seed_everything(42)

INFO: Global seed set to 42
INFO:lightning.fabric.utilities.seed:Global seed set to 42


42

In [None]:
# Train every architecture in turn, logging each one as its own wandb run.
models = ['MLP', 'GCN', 'GraphConv', 'GAT', 'GINConv', 'SAGEConv']

# Keep each model's result; `model`, `result` and `trainer` themselves are
# rebound on every iteration and only hold the last architecture afterwards.
results_by_model = {}

for model_name in models:
    # Using the run as a context manager finishes it even when training
    # raises, so a failure does not leave a wandb run open.
    with wandb.init(project='snowflake_pos_032823', name=model_name) as run:
        model, result, trainer = graph.train_node_classifier(
            model_name, train_set, val_set, test_set,
            dataset, CHECKPOINT_PATH, AVAIL_GPUS,
            hidden_channels=16, num_layers=3,
        )
    results_by_model[model_name] = result

VBox(children=(Label(value='Waiting for wandb.init()...\r'), FloatProgress(value=0.016666666666666666, max=1.0…

INFO: Global seed set to 42
INFO:lightning.fabric.utilities.seed:Global seed set to 42
  rank_zero_warn(
INFO: GPU available: True (cuda), used: True
INFO:lightning.pytorch.utilities.rank_zero:GPU available: True (cuda), used: True
INFO: TPU available: False, using: 0 TPU cores
INFO:lightning.pytorch.utilities.rank_zero:TPU available: False, using: 0 TPU cores
INFO: IPU available: False, using: 0 IPUs
INFO:lightning.pytorch.utilities.rank_zero:IPU available: False, using: 0 IPUs
INFO: HPU available: False, using: 0 HPUs
INFO:lightning.pytorch.utilities.rank_zero:HPU available: False, using: 0 HPUs
INFO: Global seed set to 42
INFO:lightning.fabric.utilities.seed:Global seed set to 42
INFO: LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1]
INFO:lightning.pytorch.accelerators.cuda:LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1]
INFO: 
  | Name        | Type             | Params
-------------------------------------------------
0 | model       | MLPModel         | 435   
1 | loss_module | CrossEntro

In [None]:
# Disabled cell, kept for reference: it would read metrics.csv from
# trainer.logger.log_dir, drop the "step" column, index by "epoch", show the
# first rows and draw a seaborn line plot of the metrics.
# metrics = pd.read_csv(f"{trainer.logger.log_dir}/metrics.csv")
# del metrics["step"]
# metrics.set_index("epoch", inplace=True)
# display(metrics.dropna(axis=1, how="all").head())
# sns.relplot(data=metrics, kind="line")