In [21]:
%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [22]:
import torch as t
import torch.nn as nn
from pathlib import Path
import pickle
import numpy as np
from sklearn.model_selection import train_test_split
import pytorch_lightning as pl
from DiagnosisAI.models.AutoEncoder_UNET.AutoEncoder_BRAIN_UNET import UNet
import torchmetrics
from torchvision.transforms import Compose, ToTensor
import matplotlib.pyplot as plt
import json
from tqdm import tqdm
# import cv2 as cv

In [3]:
class BrainSlicesDataset(t.utils.data.Dataset):
    def __init__(self, file_names: list, transform: bool = True, binary_mask: bool = True):
        self.root_path = Path("../datasets/brain/train_images_max_area/")
        self.file_names = file_names
        self.image_size = (240, 240)
        self.transform = transform
        self.binary_mask = binary_mask
        self.transforms = Compose([ToTensor()])

    def __getitem__(self, idx):
        slice_pickle_path = self.root_path / self.file_names[idx]
        with open(slice_pickle_path, 'rb') as handle:
            slice_data = pickle.load(handle)
        img = slice_data['flair'].astype(np.float32)
        label = slice_data['seg']
        # convert to binary mask
        
        if self.binary_mask:
            label = np.where(label >= 1, 1, 0)

        if self.transform:
            img = self.transforms(img)
            label = self.transforms(label)
            if self.binary_mask:
                label = label.type(t.float32)

        return img, label

    def __len__(self):
        return len(self.file_names)

In [4]:
filenames_pickle = []
for dir in Path("../datasets/brain/train_images_max_area/").iterdir():
    for filename in dir.iterdir():
        rel_path = Path(*filename.parts[-2:])
        filenames_pickle.append(rel_path)

In [5]:
train_names, val_names = train_test_split(filenames_pickle, test_size=0.2, random_state=42)
val_names, test_names = train_test_split(val_names, test_size=0.5, random_state=42)

In [6]:
train_dataset = BrainSlicesDataset(train_names)
val_dataset = BrainSlicesDataset(val_names)
test_dataset = BrainSlicesDataset(test_names)

In [7]:
batch_size = 8

In [8]:
train_loader = t.utils.data.DataLoader(train_dataset, batch_size=batch_size, num_workers=4)
val_loader = t.utils.data.DataLoader(val_dataset, batch_size=batch_size, num_workers=4)
test_loader = t.utils.data.DataLoader(test_dataset, batch_size=batch_size, num_workers=4)

In [9]:
device = t.device('cuda') if t.cuda.is_available() else t.device('cpu')

In [10]:
class Segmenter(pl.LightningModule):
    def __init__(self):
        super().__init__()

        self.network = UNet(in_channels=1, out_channels=1, n_filters=16)
        # self.loss_function = nn.CrossEntropyLoss() for semantic
        self.loss_function = nn.BCEWithLogitsLoss() # for binary 
        metrics = torchmetrics.MetricCollection([torchmetrics.Precision(num_classes=1, average='macro', mdmc_average='samplewise'),
                                                 torchmetrics.Recall(num_classes=1, average='macro', mdmc_average='samplewise'),
                                                 torchmetrics.F1Score(num_classes=1, average='macro', mdmc_average='samplewise'),
                                                 torchmetrics.Accuracy(num_classes=1, average='macro', mdmc_average='samplewise')
                                                ])
        self.train_metrics = metrics.clone('train_')
        self.val_metrics = metrics.clone('val_')

    def forward(self, x):
        return self.network(x)
    
    def training_step(self, batch, batch_idx):
        inputs, labels = batch
        outputs = self(inputs)
        
        loss = self.loss_function(outputs, labels)
        self.log('train_loss', loss)
        self.log_dict(self.train_metrics(outputs.view(-1), labels.type(t.int32).view(-1)))
        
        return loss

    def validation_step(self, batch, batch_idx):
        inputs, labels = batch
        outputs = self(inputs)

        loss = self.loss_function(outputs, labels)
        self.log('val_loss', loss, prog_bar=True)
        self.log_dict(self.val_metrics(outputs.view(-1), labels.type(t.int32).view(-1)))

    def configure_optimizers(self):
        return t.optim.Adam(self.parameters(), lr=1e-3)

In [11]:
f = open('../config/secret.json')
api_key = json.load(f)

In [12]:
segmenter = Segmenter()

model_checkpoint = pl.callbacks.ModelCheckpoint(dirpath='../src/DiagnosisAI/models/AutoEncoder_UNET/checkpoints')
early_stopping = pl.callbacks.EarlyStopping(monitor='val_loss', patience=10)
logger = pl.loggers.NeptuneLogger(
    api_key=api_key['api_neptune'],
    project="jumpincrane/Binary-slices"
)

In [13]:
trainer = pl.Trainer(logger=logger, callbacks=[model_checkpoint, early_stopping], gpus=0, max_epochs=100)
trainer.fit(segmenter, train_dataloaders=train_loader, val_dataloaders=val_loader)

logger.run.stop()
t.save(segmenter.state_dict(), "../src/DiagnosisAI/AutoEncoder_UNET/binary_brats_weights")

GPU available: True, used: False
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
  rank_zero_warn(

  | Name          | Type              | Params
----------------------------------------------------
0 | network       | UNet              | 2.2 M 
1 | loss_function | BCEWithLogitsLoss | 0     
2 | train_metrics | MetricCollection  | 0     
3 | val_metrics   | MetricCollection  | 0     
----------------------------------------------------
2.2 M     Trainable params
0         Non-trainable params
2.2 M     Total params
8.640     Total estimated model params size (MB)


Epoch 0:   9%|▊         | 12/141 [01:04<11:36,  5.40s/it, loss=1.42]  

  rank_zero_warn("Detected KeyboardInterrupt, attempting graceful shutdown...")


https://app.neptune.ai/jumpincrane/Binary-slices/e/BIN-5
Remember to stop your run once you’ve finished logging your metadata (https://docs.neptune.ai/api-reference/run#.stop). It will be stopped automatically only when the notebook kernel/interactive console is terminated.
Shutting down background jobs, please wait a moment...
Done!
Waiting for the remaining 6 operations to synchronize with Neptune. Do not kill this process.
All 6 operations synced, thanks for waiting!
Explore the metadata in the Neptune app:
https://app.neptune.ai/jumpincrane/Binary-slices/e/BIN-5


In [None]:
# load best weights
segmenter = Segmenter.load_from_checkpoint(model_checkpoint.best_model_path).to(device)  # wczytanie najlepszych wag z treningu
segmenter = segmenter.eval()