In [10]:
import os
import json
import numpy as np
import lightning.pytorch as pl

import torch
from torch import optim, nn, Tensor
from torch.utils.data import Dataset, DataLoader
from torchmetrics import Accuracy 
from torchmetrics.classification import MulticlassF1Score

from IPython.display import display

from lightning.pytorch.loggers import TensorBoardLogger

from ray import tune
from ray.tune.schedulers import ASHAScheduler
from ray.train import RunConfig, ScalingConfig, CheckpointConfig
from ray.train.torch import TorchTrainer
from ray.train.lightning import (
    RayDDPStrategy,
    RayLightningEnvironment,
    RayTrainReportCallback,
    prepare_trainer,
)

### Define Model

In [11]:
class ClassificationModel(pl.LightningModule):
    def __init__(self, input_dim, output_dim, loss_weight, config):
        super().__init__()

        device = "cuda" if torch.cuda.is_available() else "cpu"
        self.loss_weight = Tensor(loss_weight).to(device)
        self.accuracy = Accuracy(task="multiclass", num_classes=output_dim, top_k=1)
        self.f1_score = MulticlassF1Score(num_classes=output_dim, average="macro")

        self.weight_decay = config["weight_decay"]
        self.lr = config["lr"]
        # self.num_layers = config["num_layers"]
        self.hidden_dim = config["hidden_dim"]
        self.last_layer_dim = config["last_layer_dim"]

        input_layer = nn.Linear(input_dim, self.hidden_dim)
        # hidden_layers = []
        # for i in range(self.num_layers):
        #     hidden_layers.append(nn.Linear(self.hidden_dim, self.hidden_dim))
        #     hidden_layers.append(nn.ReLU())
        last_layer = nn.Linear(self.hidden_dim, self.last_layer_dim)
        output_layer = nn.Linear(self.last_layer_dim, output_dim)
        # self.layers = nn.Sequential(input_layer, nn.ReLU(), *hidden_layers, last_layer, nn.ReLU(), output_layer)
        self.layers = nn.Sequential(input_layer, nn.BatchNorm1d(self.hidden_dim), nn.ReLU(), last_layer, nn.BatchNorm1d(self.last_layer_dim), nn.ReLU(), output_layer)

        self.eval_loss = []
        self.eval_accuracy = []

        self.epoch_logits = []
        self.epoch_labels = []

    def forward(self, x):
        return self.layers(x)
    
    def training_step(self, batch, batch_idx):
        x, y = batch
        logits = self(x)
        loss = nn.functional.cross_entropy(logits, y, weight=self.loss_weight)
        accuracy = self.accuracy(logits, y)
        self.log("train/train_loss", loss)
        self.log("train/train_accuracy", accuracy)
        return loss
    
    def validation_step(self, batch, batch_idx):
        x, y = batch
        logits = self(x)
        loss = nn.functional.cross_entropy(logits, y, weight=self.loss_weight)
        accuracy = self.accuracy(logits, y)
        self.eval_loss.append(loss)
        self.eval_accuracy.append(accuracy)

        self.epoch_logits.append(logits)
        self.epoch_labels.append(y)

        return {"val_loss": loss, "val_accuracy": accuracy}
    
    def test_step(self, batch, batch_idx):
        x, y = batch
        logits = self(x)
        loss = nn.functional.cross_entropy(logits, y, weight=self.loss_weight)
        accuracy = self.accuracy(logits, y)
        self.log("test/test_loss", loss)
        self.log("test/test_accuracy", accuracy)
        self.eval_loss.append(loss)
        self.eval_accuracy.append(accuracy)

        self.epoch_logits.append(logits)
        self.epoch_labels.append(y)

        return {"test_loss": loss, "test_accuracy": accuracy}
    
    def on_validation_epoch_end(self):
        display(len(self.eval_loss), torch.stack(self.eval_loss).shape)
        avg_loss = torch.stack(self.eval_loss).mean()
        avg_acc = torch.stack(self.eval_accuracy).mean()
        self.log("val/val_loss", avg_loss, sync_dist=True)
        self.log("val/val_accuracy", avg_acc, sync_dist=True)
        self.eval_loss.clear()
        self.eval_accuracy.clear()

        epoch_logits = torch.cat(self.epoch_logits, dim=0)
        epoch_labels = torch.cat(self.epoch_labels, dim=0)
        f1_score = self.f1_score(epoch_logits, epoch_labels)
        self.log("val/f1_score", f1_score, sync_dist=True)
        self.epoch_logits.clear()
        self.epoch_labels.clear()

    def on_test_epoch_end(self):
        avg_loss = torch.stack(self.eval_loss).mean()
        avg_acc = torch.stack(self.eval_accuracy).mean()
        self.log("test/test_loss", avg_loss, sync_dist=True)
        self.log("test/test_accuracy", avg_acc, sync_dist=True)
        self.eval_loss.clear()
        self.eval_accuracy.clear()

        epoch_logits = torch.cat(self.epoch_logits, dim=0)
        epoch_labels = torch.cat(self.epoch_labels, dim=0)
        f1_score = self.f1_score(epoch_logits, epoch_labels)
        self.log("test/f1_score", f1_score, sync_dist=True)
        self.epoch_logits.clear()
        self.epoch_labels.clear()

    def on_train_epoch_end(self):
        i = 0
        for layer in self.layers:
            if isinstance(layer, nn.Linear):
                i += 1
                self.logger.experiment.add_histogram(f"layer_{i}/weight", layer.weight, self.current_epoch)
                self.logger.experiment.add_histogram(f"layer_{i}/bias", layer.bias, self.current_epoch)
    
    def configure_optimizers(self):
        optimizer = optim.Adam(self.parameters(), self.lr, weight_decay=self.weight_decay)
        return optimizer

In [3]:
class helicoid_Dataset(Dataset):
    def __init__(self, patient_folders, files, transform=None, mode='labeled'):
        self.data = []
        self.labels = []
        self.transform = transform
        for patient_folder in patient_folders:
            print(f"loading image {patient_folder}")
            patient_folder = os.path.join('/home/martin_ivan/code/own_labels/npj_database/', patient_folder)
            img_data = []
            img_labels = np.load(os.path.join(patient_folder, 'gtMap.npy')).astype(int)
            for file in files:
                img_data_all = np.load(os.path.join(patient_folder, file))
                if mode == 'labeled':
                    # img_data.append(img_data_all[(img_labels !=0) & (img_labels != 4)])
                    img_data.append(img_data_all[(img_labels !=0)])
                elif mode == 'all':
                    img_data.append(img_data_all.reshape(-1, img_data_all.shape[-1]))
                else:
                    raise ValueError("Unknown mode")
            img_data = np.concatenate(img_data, axis=1)

            self.data.append(img_data)
            if mode == 'labeled':
                # self.labels.append(img_labels[(img_labels !=0) & (img_labels != 4)])
                self.labels.append(img_labels[(img_labels !=0)])
            elif mode == 'all':
                self.labels.append(img_labels.reshape(-1))
            else:
                raise ValueError("Unknown mode")
                    
        self.data = np.concatenate(self.data, axis=0)
        self.labels = np.concatenate(self.labels, axis=0)
        print(f"------------- label counts: {np.unique(self.labels, return_counts=True)} -------------")

    def __len__(self):
        return len(self.data)
    
    def __getitem__(self, idx):
        x = self.data[idx]
        y = self.labels[idx] - 1
        if self.transform:
            x, y = self.transform((x, y))
        return x, y

In [4]:
class ToTensor(object):
    def __call__(self, sample):
        x, y = sample
        return torch.tensor(x, dtype=torch.float32), torch.tensor(y, dtype=torch.long)

class HelicoidDataModule(pl.LightningDataModule):
    def __init__(self, files, batch_size=64, fold="fold1"):
        super().__init__()
        self.batch_size = batch_size
        self.fold = fold
        self.files = files
        self.transform = ToTensor()

        self.setup()

    def setup(self, stage=None):
        with open('/home/martin_ivan/code/own_labels/folds.json') as f:
            folds = json.load(f)

        if stage=="fit":
            self.dataset_train = helicoid_Dataset(folds[self.fold]["train"], self.files, transform=self.transform)
            self.dataset_val = helicoid_Dataset(folds[self.fold]["val"], self.files, transform=self.transform)
        if stage=="test":
            self.dataset_test = helicoid_Dataset(folds[self.fold]["test"], self.files, transform=self.transform)
        if stage=="predict":
            self.dataset_predict = helicoid_Dataset(folds[self.fold]["test"], self.files, transform=self.transform, mode='all')

    def train_dataloader(self):
        return DataLoader(self.dataset_train, batch_size=self.batch_size, shuffle=True, num_workers=16)
    
    def val_dataloader(self):
        return DataLoader(self.dataset_val, batch_size=self.batch_size, shuffle=False, num_workers=16)
    
    def test_dataloader(self): 
        return DataLoader(self.dataset_test, batch_size=self.batch_size, shuffle=False, num_workers=16)
    
    def predict_dataloader(self):
        return DataLoader(self.dataset_predict, batch_size=self.batch_size, shuffle=False, num_workers=16)
    
    def sample_size(self):
        return self.dataset_train.data.shape[1]
    
    def class_distribution(self):
        return np.unique(self.dataset_train.labels, return_counts=True)[1]
    
    def num_classes(self):
        return len(np.unique(self.dataset_train.labels))

### Load Data

In [5]:
files = ["preprocessed.npy"]#, "heatmaps_osp.npy", "heatmaps_osp_diff.npy", "heatmaps_osp_diff_mc.npy", "heatmaps_icem.npy", "heatmaps_icem_diff.npy", "heatmaps_icem_diff_mc.npy"]
dm = HelicoidDataModule(files=files, fold="fold1")
dm.setup("fit")

loading image 005-01
loading image 007-01
loading image 012-01
loading image 012-02
loading image 018-01
loading image 018-02
loading image 019-01
loading image 020-01
loading image 013-01
loading image 015-01
loading image 017-01
loading image 034-02
loading image 035-01
loading image 036-01
loading image 036-02
loading image 038-01
loading image 040-01
loading image 040-02
loading image 043-01
loading image 043-02
loading image 043-04
loading image 050-01
loading image 051-01
loading image 053-01
loading image 055-01
loading image 056-01
loading image 056-02
loading image 057-01
loading image 058-02
------------- label counts: (array([1, 2, 3, 4]), array([179536,  29262,  70340, 277811])) -------------
loading image 004-02
loading image 008-01
loading image 008-02
loading image 022-01
loading image 022-02
loading image 039-01
loading image 039-02
loading image 041-01
loading image 041-02
------------- label counts: (array([1, 2, 3, 4]), array([39410,  3459,  9548, 43357])) ----------

### Setup training function for ray tuning

In [13]:
def train_func(config):
    model = ClassificationModel(input_dim=dm.sample_size(), output_dim=dm.num_classes(), loss_weight=dm.class_distribution(), config=config)

    trainer = pl.Trainer(
        devices="auto",
        accelerator="auto",
        strategy=RayDDPStrategy(),
        callbacks=[RayTrainReportCallback()],
        plugins=[RayLightningEnvironment()],
        enable_progress_bar=False,
        
    )
    trainer = prepare_trainer(trainer)
    trainer.fit(model, dm.train_dataloader(), dm.val_dataloader())

In [15]:
def tune_hyperparameters():

    search_space = {
    "hidden_dim": tune.choice([9]),
    # "num_layers": tune.choice([0,2,4,8,16]),
    "last_layer_dim": tune.choice([9]),
    "lr": tune.loguniform(1e-5, 1e-2),
    "weight_decay": tune.loguniform(1e-5, 1e-2),
    }   

    ray_num_workers = 1
    num_epochs = 50
    num_samples = 20

    scheduler = ASHAScheduler(max_t=num_epochs, grace_period=5, reduction_factor=4)

    run_config = RunConfig(
        checkpoint_config=CheckpointConfig(
            num_to_keep=1,
            checkpoint_score_attribute="val/val_loss",
            checkpoint_score_order="min",
        ),
    )

    ray_trainer = TorchTrainer(
        train_func,
        scaling_config=ScalingConfig(num_workers=1, use_gpu=True, resources_per_worker={"CPU": 16, "GPU": 1}),
        run_config=run_config,)

    tuner = tune.Tuner(
        ray_trainer,
        param_space={"train_loop_config": search_space},
        tune_config=tune.TuneConfig(
            metric="val/val_loss",
            mode="min",
            num_samples=num_samples,
            scheduler=scheduler,
            max_concurrent_trials=ray_num_workers
        ),
    )

    return tuner.fit()

result = tune_hyperparameters()

0,1
Current time:,2024-03-30 20:54:23
Running for:,00:27:18.53
Memory:,360.9/503.7 GiB

Trial name,status,loc,train_loop_config/hi dden_dim,train_loop_config/la st_layer_dim,train_loop_config/lr,train_loop_config/we ight_decay,iter,total time (s),train/train_loss,train/train_accuracy,val/val_loss
TorchTrainer_7dc19_00000,RUNNING,131.159.10.130:647956,9,9,0.00130116,0.000934396,11,1512.67,0.289957,0.857143,5.97782


[36m(RayTrainWorker pid=648046)[0m Setting up process group for: env:// [rank=0, world_size=1]
[36m(TorchTrainer pid=647956)[0m Started distributed worker processes: 
[36m(TorchTrainer pid=647956)[0m - (ip=131.159.10.130, pid=648046) world_rank=0, local_rank=0, node_rank=0
[36m(RayTrainWorker pid=648046)[0m GPU available: True (cuda), used: True
[36m(RayTrainWorker pid=648046)[0m TPU available: False, using: 0 TPU cores
[36m(RayTrainWorker pid=648046)[0m IPU available: False, using: 0 IPUs
[36m(RayTrainWorker pid=648046)[0m HPU available: False, using: 0 HPUs
[36m(RayTrainWorker pid=648046)[0m /home/martin_ivan/anaconda3/envs/thesis/lib/python3.10/site-packages/lightning/pytorch/loops/utilities.py:73: `max_epochs` was not set. Setting it to 1000 epochs. To train without an epoch limit, set `max_epochs=-1`.
[36m(RayTrainWorker pid=648046)[0m Missing logger folder: /tmp/ray/session_2024-03-30_20-03-52_482813_614347/artifacts/2024-03-30_20-27-05/TorchTrainer_2024-03-30_2

In [12]:
from importlib import reload
import model
from model import ClassificationModel
reload(model)

def train(config):
    model = ClassificationModel(input_dim=dm.sample_size(), output_dim=dm.num_classes(), loss_weight=dm.class_distribution(), config=config)
    logger = TensorBoardLogger(config["log_dir"], name="my_model")
    trainer = pl.Trainer(logger=logger, max_epochs=config["num_epochs"], devices=1)
    trainer.fit(model, dm.train_dataloader(), dm.val_dataloader())
    trainer.save_checkpoint("./classification/model.ckpt")
    return model

config = {
    "hidden_dim": 9,
    "num_layers": 0,
    "last_layer_dim": 9,
    "lr": 1e-5,
    "weight_decay": 0,
    "num_epochs": 20,
    "log_dir": "./classification/tb_logs/"
}

model = train(config)

GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1,2,3,4,5,6,7]

  | Name     | Type               | Params
------------------------------------------------
0 | accuracy | MulticlassAccuracy | 0     
1 | f1_score | MulticlassF1Score  | 0     
2 | layers   | Sequential         | 3.6 K 
------------------------------------------------
3.6 K     Trainable params
0         Non-trainable params
3.6 K     Total params
0.014     Total estimated model params size (MB)


Sanity Checking: |          | 0/? [00:00<?, ?it/s]

Training: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7f980efbf010>
Traceback (most recent call last):
  File "/home/martin_ivan/anaconda3/envs/thesis/lib/python3.10/site-packages/torch/utils/data/dataloader.py", line 1479, in __del__
    self._shutdown_workers()
  File "/home/martin_ivan/anaconda3/envs/thesis/lib/python3.10/site-packages/torch/utils/data/dataloader.py", line 1443, in _shutdown_workers
    w.join(timeout=_utils.MP_STATUS_CHECK_INTERVAL)
  File "/home/martin_ivan/anaconda3/envs/thesis/lib/python3.10/multiprocessing/process.py", line 149, in join
    res = self._popen.wait(timeout)
  File "/home/martin_ivan/anaconda3/envs/thesis/lib/python3.10/multiprocessing/popen_fork.py", line 40, in wait
    if not wait([self.sentinel], timeout):
  File "/home/martin_ivan/anaconda3/envs/thesis/lib/python3.10/multiprocessing/connection.py", line 931, in wait
    ready = selector.select(timeout)
  File "/home/martin_ivan/anaconda3/envs/thesis/lib/python3.10/select

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

`Trainer.fit` stopped: `max_epochs=20` reached.
