In [1]:
import torch
import mlflow
import os
import uuid
import yaml
from tqdm import tqdm
import mlflow.pytorch
import numpy as np
import pandas as pd
import shutil
import argparse
from sklearn.metrics import accuracy_score, f1_score, classification_report

from histocartography.ml import CellGraphModel, TissueGraphModel, HACTModel

from dataloader import make_data_loader

# Pick the compute device: the first CUDA GPU when one is available, else the CPU.
IS_CUDA = torch.cuda.is_available()
if IS_CUDA:
    DEVICE = 'cuda:0'
else:
    DEVICE = 'cpu'

# Node feature dimension passed to the model as both cg_node_dim and tg_node_dim.
NODE_DIM = 514

Using backend: pytorch


In [2]:
# Read the HACT model configuration (YAML) into the `config` dictionary.
config_fpath = '/nadeem_lab/Eliram/repos/hact-net/core/config/bracs_hact_7_classes_pna.yml'
with open(config_fpath, mode='r') as config_file:
    config = yaml.safe_load(config_file)

In [21]:
# Read the CG-GNN configuration (YAML) into the `config` dictionary.
# NOTE(review): this reassigns `config`, replacing the HACT configuration loaded
# by the previous cell, while the model cell further down reads
# config['tg_gnn_params']. Confirm which configuration is intended before
# running the notebook top to bottom.
config_fpath = '/nadeem_lab/Eliram/repos/hact-net/core/config/bracs_cggnn_7_classes_pna.yml'
with open(config_fpath, mode='r') as config_file:
    config = yaml.safe_load(config_file)

In [3]:
# Create a uniquely named directory (random UUID) to hold this run's checkpoints.
# The base directory is the empty string, so the folder lands in the current
# working directory.
checkpoint_root = ''
run_id = str(uuid.uuid4())
model_path = os.path.join(checkpoint_root, run_id)
os.makedirs(model_path, exist_ok=True)

In [4]:
# Run arguments: root directories of the cell graphs, tissue graphs and
# assignment matrices (each is joined with a split name by the loader cell).
cg_path = '/nadeem_lab/Eliram/repos/hact-net/data/hact-net-data/cell_graphs'
tg_path = '/nadeem_lab/Eliram/repos/hact-net/data/hact-net-data/tissue_graphs'
assign_mat_path = '/nadeem_lab/Eliram/repos/hact-net/data/hact-net-data/assignment_matrices'

# Optimisation hyper-parameters.
batch_size, learning_rate, epochs = 8, 0.0005, 60

In [23]:
# Run arguments for a cell-graph-only run.
# The tissue-graph and assignment-matrix inputs are disabled with None rather
# than '': the loader cell only skips an input when its path `is None`, and
# os.path.join('', 'train') would otherwise yield the bogus relative path 'train'.
cg_path = '/nadeem_lab/Eliram/repos/hact-net/data/hact-net-data/cell_graphs'
tg_path = None
assign_mat_path = None

# Optimisation hyper-parameters.
batch_size = 8
learning_rate = 0.0005
epochs = 60

In [5]:
def _split_dir(root, split):
    """Return ``root/split``, or None when ``root`` is unset (None or empty string).

    An empty root must map to None: os.path.join('', split) would otherwise
    produce the relative path ``split`` and be treated as a real input.
    """
    return os.path.join(root, split) if root else None


def _build_loader(split):
    """Build the data loader for one dataset split from the run arguments."""
    return make_data_loader(
        cg_path=_split_dir(cg_path, split),
        tg_path=_split_dir(tg_path, split),
        assign_mat_path=_split_dir(assign_mat_path, split),
        batch_size=batch_size,
        # NOTE(review): a non-empty string is passed here; presumably it is only
        # used as a truthy flag -- confirm against make_data_loader.
        load_in_ram='in_ram',
    )


train_dataloader = _build_loader('train')

# The training loop iterates over `val_dataloader`, so it is built here too.
# NOTE(review): assumes the validation split is stored in a 'val' subdirectory
# next to 'train' -- TODO confirm against the data layout.
val_dataloader = _build_loader('val')

In [6]:
# Build the HACT model from the loaded configuration, then move it to DEVICE.
hact_kwargs = {
    'cg_gnn_params': config['cg_gnn_params'],
    'tg_gnn_params': config['tg_gnn_params'],
    'classification_params': config['classification_params'],
    'cg_node_dim': NODE_DIM,
    'tg_node_dim': NODE_DIM,
    'num_classes': 7,
}
model = HACTModel(**hact_kwargs).to(DEVICE)

In [7]:
# Adam optimizer over all model parameters, using the configured learning rate
# and a fixed weight decay.
trainable_params = model.parameters()
optimizer = torch.optim.Adam(trainable_params, lr=learning_rate, weight_decay=0.0005)

In [8]:
# Classification loss applied to the model logits (mean over the batch, which
# is the library default, spelled out here).
loss_fn = torch.nn.CrossEntropyLoss(reduction='mean')

In [9]:
# Metric logging backend: 'mlflow' sends metrics to MLflow; any other value
# disables metric logging.
logger = 'none'


def _log_metric(name, value, step):
    """Log one metric to MLflow when `logger` is 'mlflow'; otherwise do nothing."""
    if logger == 'mlflow':
        mlflow.log_metric(name, value, step=step)


def _save_checkpoint(fname):
    """Save the current model weights as `fname` inside `model_path`."""
    torch.save(model.state_dict(), os.path.join(model_path, fname))


# Training loop. Requires `train_dataloader` and `val_dataloader` to have been
# created by an earlier cell. Each batch is a sequence whose last element holds
# the labels and whose remaining elements are the model inputs.
step = 0
# Start from +inf so the first epoch always yields a best-loss checkpoint,
# whatever the magnitude of its validation loss.
best_val_loss = float('inf')
best_val_accuracy = 0.
best_val_weighted_f1_score = 0.

# Nothing inside the loop moves the model, so placing it once is enough.
model = model.to(DEVICE)

for epoch in range(epochs):
    # A.) train for 1 epoch
    model.train()
    for batch in tqdm(train_dataloader, desc='Epoch training {}'.format(epoch), unit='batch'):

        # 1. forward pass
        labels = batch[-1]
        data = batch[:-1]
        logits = model(*data)

        # 2. backward pass
        loss = loss_fn(logits, labels)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # 3. log training loss
        _log_metric('train_loss', loss.item(), step)

        # 4. increment step
        step += 1

    # B.) validate
    model.eval()
    all_val_logits = []
    all_val_labels = []
    with torch.no_grad():
        for batch in tqdm(val_dataloader, desc='Epoch validation {}'.format(epoch), unit='batch'):
            labels = batch[-1]
            data = batch[:-1]
            # Move each batch's outputs to the CPU right away so the whole
            # validation set is not kept on the compute device.
            all_val_logits.append(model(*data).cpu())
            all_val_labels.append(labels.cpu())

        all_val_logits = torch.cat(all_val_logits)
        all_val_labels = torch.cat(all_val_labels)
        all_val_preds = torch.argmax(all_val_logits, dim=1)

        # Validation loss over the full validation set at once.
        loss = loss_fn(all_val_logits, all_val_labels).item()

    # compute & store loss + model
    _log_metric('val_loss', loss, step)
    if loss < best_val_loss:
        best_val_loss = loss
        _save_checkpoint('model_best_val_loss.pt')

    # compute & store accuracy + model
    all_val_preds = all_val_preds.numpy()
    all_val_labels = all_val_labels.numpy()
    accuracy = accuracy_score(all_val_labels, all_val_preds)
    _log_metric('val_accuracy', accuracy, step)
    if accuracy > best_val_accuracy:
        best_val_accuracy = accuracy
        _save_checkpoint('model_best_val_accuracy.pt')

    # compute & store weighted f1-score + model
    weighted_f1_score = f1_score(all_val_labels, all_val_preds, average='weighted')
    _log_metric('val_weighted_f1_score', weighted_f1_score, step)
    if weighted_f1_score > best_val_weighted_f1_score:
        best_val_weighted_f1_score = weighted_f1_score
        _save_checkpoint('model_best_val_weighted_f1_score.pt')

    print('Val loss {}'.format(loss))
    print('Val weighted F1 score {}'.format(weighted_f1_score))
    print('Val accuracy {}'.format(accuracy))

Epoch training 0:   0%|          | 0/119 [00:42<?, ?batch/s]


AssertionError: The CG and TG are not the same. There was an issue while creating HACT.

In [None]:
# Inspection aid: print the docstring exposed by the train_dataloader object.
print(train_dataloader.__doc__)


    Data loader. Combines a dataset and a sampler, and provides an iterable over
    the given dataset.

    The :class:`~torch.utils.data.DataLoader` supports both map-style and
    iterable-style datasets with single- or multi-process loading, customizing
    loading order and optional automatic batching (collation) and memory pinning.

    See :py:mod:`torch.utils.data` documentation page for more details.

    Args:
        dataset (Dataset): dataset from which to load the data.
        batch_size (int, optional): how many samples per batch to load
            (default: ``1``).
        shuffle (bool, optional): set to ``True`` to have the data reshuffled
            at every epoch (default: ``False``).
        sampler (Sampler or Iterable, optional): defines the strategy to draw
            samples from the dataset. Can be any ``Iterable`` with ``__len__``
            implemented. If specified, :attr:`shuffle` must not be specified.
        batch_sampler (Sampler or Iterable, opti

In [None]:
# Inspection aid: list the attribute names available on train_dataloader.
dir(train_dataloader)

['_DataLoader__initialized',
 '_DataLoader__multiprocessing_context',
 '_IterableDataset_len_called',
 '__annotations__',
 '__class__',
 '__class_getitem__',
 '__delattr__',
 '__dict__',
 '__dir__',
 '__doc__',
 '__eq__',
 '__format__',
 '__ge__',
 '__getattribute__',
 '__gt__',
 '__hash__',
 '__init__',
 '__init_subclass__',
 '__iter__',
 '__le__',
 '__len__',
 '__lt__',
 '__module__',
 '__ne__',
 '__new__',
 '__orig_bases__',
 '__parameters__',
 '__reduce__',
 '__reduce_ex__',
 '__repr__',
 '__setattr__',
 '__sizeof__',
 '__slots__',
 '__str__',
 '__subclasshook__',
 '__weakref__',
 '_auto_collation',
 '_dataset_kind',
 '_get_iterator',
 '_index_sampler',
 '_iterator',
 'batch_sampler',
 'batch_size',
 'check_worker_number_rationality',
 'collate_fn',
 'dataset',
 'drop_last',
 'generator',
 'multiprocessing_context',
 'num_workers',
 'persistent_workers',
 'pin_memory',
 'prefetch_factor',
 'sampler',
 'timeout',
 'worker_init_fn']