In [1]:
import os
import sys
import tarfile

# Put the local ./libs directory first on the import path; it presumably
# provides the `data` and `training` packages imported in the next cell.
# Guarded so re-running this cell does not keep prepending duplicate entries.
_LIBS_PATH = os.path.abspath("./libs")
if not sys.path or sys.path[0] != _LIBS_PATH:
    sys.path.insert(0, _LIBS_PATH)

In [2]:
# Import
import pytorch_lightning as pl
from data.data_loader import MyDataModule
from training.PL_train import Main_Loop
import torch.optim as optim
import torchio as tio
import torch
import matplotlib.pyplot as plt
from data.brats_nii_data_utils import nni_utils
import numpy as np

In [3]:
# Absolute locations of the extracted raw data and of the preprocessed output,
# resolved against the notebook's current working directory.
RAW_DATA_PATH = os.path.abspath(os.path.join("data", "raw"))
PROCESSED_DATA_PATH = os.path.abspath(os.path.join("data", "processed"))

In [4]:
# Reproducibility: the training transforms below are random, so seed Python's
# `random`, NumPy and torch before anything samples from them.
SEED = 42
pl.seed_everything(SEED)

# Experiment configuration
size = (48, 64, 48)
model = "custom"
criterion = "Focal"
batch_size = 4
type_list = ["t1"]
epochs = 50
# Per-class loss weights (5 classes); the first class is down-weighted to 0.1.
# NOTE(review): presumably class 0 is background — confirm against the label encoding.
# `.cuda()` requires a GPU, consistent with the single-GPU Trainer configured below.
weight = torch.tensor([0.1, 1.0, 1.0, 1.0, 1.0], dtype=torch.float32).cuda()
model_args = {}

# Data transforms
# Training: random artefact/intensity augmentations, z-normalisation, then a
# random flip and either an affine (p=0.8) or elastic (p=0.2) spatial warp.
# `ZNormalization.mean` restricts the normalisation statistics to voxels above
# the mean intensity (presumably to exclude the zero background).
train_transformer = tio.Compose(
    [
        tio.RandomMotion(p=0.2),
        tio.RandomBiasField(p=0.3),
        tio.ZNormalization(masking_method=tio.ZNormalization.mean),
        tio.RandomNoise(p=0.5),
        tio.RandomFlip(),
        tio.OneOf(
            {
                tio.RandomAffine(): 0.8,
                tio.RandomElasticDeformation(): 0.2,
            }
        ),
    ]
)

# Validation: deterministic normalisation only, matching the training pipeline.
val_transformer = tio.Compose(
    [
        tio.ZNormalization(masking_method=tio.ZNormalization.mean),
    ]
)

# Dataloading
# NOTE(review): `sample_list` is set to the same list as `type_list` — confirm
# this is intended and not a copy-paste of the wrong variable.
data_module = MyDataModule(
    data_dir=RAW_DATA_PATH,
    out_dir=PROCESSED_DATA_PATH,
    train_transformer=train_transformer,
    val_transformer=val_transformer,
    size=size,
    type_list=type_list,
    sample_list=type_list,
)

In [5]:
# Fetch the raw BraTS 2021 training data if it is not already extracted, then
# run the data module's preprocessing if the processed directory is empty.
# NOTE(review): "<= 1" presumably tolerates one placeholder file (e.g. .gitkeep) — confirm.
# Create both directories up front so os.listdir() works on a fresh checkout.
os.makedirs(RAW_DATA_PATH, exist_ok=True)
os.makedirs(PROCESSED_DATA_PATH, exist_ok=True)

if len(os.listdir(RAW_DATA_PATH)) <= 1:
    tarball_path = os.path.abspath(input("Path to BRATS 2021 training tarball"))

    # tarfile.is_tarfile() raises FileNotFoundError for a missing path, so check
    # that the file exists first and report both cases through one error.
    if not (os.path.isfile(tarball_path) and tarfile.is_tarfile(tarball_path)):
        raise ValueError(f"Valid tarball path not passed: {tarball_path}")

    # `with` guarantees the archive is closed even if extraction fails part-way.
    with tarfile.open(tarball_path) as archive:
        if hasattr(tarfile, "data_filter"):
            # The archive path comes from user input: reject members that would
            # land outside RAW_DATA_PATH (absolute paths, "..", unsafe links).
            archive.extractall(RAW_DATA_PATH, filter="data")
        else:
            # Older Pythons without extraction filters: unfiltered extraction.
            archive.extractall(RAW_DATA_PATH)

if len(os.listdir(PROCESSED_DATA_PATH)) <= 1:
    print("Processing images")
    data_module.preprocessing()

In [6]:
# Single-GPU training. `accelerator`/`devices` replaces the `gpus=` argument,
# which is deprecated in Lightning 1.7 and removed in 2.0.
trainer = pl.Trainer(accelerator="gpu", devices=1, max_epochs=epochs)

# LightningModule wrapping the model, loss, optimizer and LR scheduler.
# AdamW with AMSGrad; ExponentialLR multiplies the learning rate by 0.95 per
# scheduler step (step frequency is decided inside Main_Loop — not visible here).
main = Main_Loop(
    model=model,
    loss=criterion,
    type_list=type_list,
    scheduler=optim.lr_scheduler.ExponentialLR,
    scheduler_args={"gamma": 0.95},
    model_args=model_args,
    loss_args={"weight": weight},
    batch_size=batch_size,
    optimizer=optim.AdamW,
    optimizer_args={"amsgrad": True},
)
# Pass the data module by keyword rather than in the `train_dataloaders` slot.
trainer.fit(main, datamodule=data_module)

GPU available: True, used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name  | Type        | Params
--------------------------------------
0 | model | CustomModel | 1.4 M 
1 | loss  | FocalLoss   | 0     
--------------------------------------
1.4 M     Trainable params
0         Non-trainable params
1.4 M     Total params
5.573     Total estimated model params size (MB)


/home/josh/Repos/tumor-segmentation/data/processed


Sanity Checking: 0it [00:00, ?it/s]

Training: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

  rank_zero_warn("Detected KeyboardInterrupt, attempting graceful shutdown...")


In [None]:
# Evaluate the trained module using the data module's test dataloader.
trainer.test(model=main, datamodule=data_module)