# Imports

In [None]:
import torch
import torch.nn as nn
import torchvision
from torchvision import models
from torch.utils.data import Dataset, DataLoader
from torch.optim.optimizer import Optimizer
from torchvision.models import DenseNet121_Weights
from tqdm import tqdm

import numpy as np
import matplotlib.pyplot as plt
import cv2
import os
import datetime
from sklearn.metrics import f1_score
from PIL import Image
import nibabel as nib

# Use the CUDA GPU when PyTorch can see one; otherwise run everything on the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

# Visualisation

In [None]:
def view_nii_pic(nii_data: np.ndarray) -> None:
    for i in range(nii_data.shape[2]):
        cv2.imshow('slice', nii_data[:, :, i])
        cv2.waitKey(0)
    cv2.destroyAllWindows()
    
def visualize_photo(img: np.ndarray, photo_title: str, *slices: int) -> None:
    """Plot the requested slices of a volume side by side in grayscale.

    Args:
        img: Volume indexed as ``img[:, :, k]`` to obtain slice ``k``.
        photo_title: Label printed to stdout before plotting.
        *slices: Indices along the last axis; each gets its own 5x5-inch subplot
            in a single row.
    """
    print(f"Visualizing {photo_title}")
    count = len(slices)
    plt.figure(figsize=(5 * count, 5))

    for position, slice_num in enumerate(slices, start=1):
        plt.subplot(1, count, position)
        plt.title(f"photo Slice {slice_num}")
        plt.imshow(img[:, :, slice_num], cmap="gray")

    plt.tight_layout()
    plt.show()
  
    
def visualize_photos(original: np.ndarray, segmented: np.ndarray, reference: np.ndarray, *slices: int) -> None:
    """Plot original / segmented / reference views of the same slices in a grid.

    Each requested slice becomes one row of three grayscale panels, in the
    column order original, segmented, reference. Slice ``k`` of each volume is
    taken as ``volume[:, :, k]``.
    """
    num_slices = len(slices)
    # Fixed width for three columns; height grows by 5 inches per row.
    plt.figure(figsize=(15, 5 * num_slices))

    def _draw(position: int, title: str, volume: np.ndarray, slice_num: int) -> None:
        # Draw one panel at the given 1-based grid position.
        plt.subplot(num_slices, 3, position)
        plt.title(title)
        plt.imshow(volume[:, :, slice_num], cmap="gray")

    for row, slice_num in enumerate(slices):
        base = 3 * row
        _draw(base + 1, f"Original Slice {slice_num}", original, slice_num)
        _draw(base + 2, f"Segmented Slice {slice_num}", segmented, slice_num)
        _draw(base + 3, f"Reference Slice {slice_num}", reference, slice_num)

    plt.tight_layout()
    plt.show()
    
def plot_histogram(img: np.ndarray, threshold: float = -320, bins: int = 256) -> None:
    """Plot a black intensity histogram of ``img`` with a red dashed marker line.

    Values within one unit of the minimum or maximum fall outside the plotted
    range. This keeps spikes at the extremes (e.g. padding or clipped voxels)
    from dominating the y-axis.

    Args:
        img: Array of intensities; it is flattened, so any shape works.
        threshold: x-position of the dashed reference line. The default of -320
            is the value previously hard-coded here (NOTE(review): presumably a
            Hounsfield-unit cut-off; confirm).
        bins: Number of histogram bins.
    """
    values = np.asarray(img).ravel()
    # Convert to Python floats so +/-1 cannot wrap around for unsigned dtypes.
    lo = float(values.min()) + 1
    hi = float(values.max()) - 1
    # When the data span is under 2 the trimmed range is inverted and plt.hist
    # would raise; fall back to the untrimmed range instead.
    if lo > hi:
        lo, hi = float(values.min()), float(values.max())
    plt.hist(values, bins=bins, range=(lo, hi), fc='k', ec='k')
    plt.axvline(x=threshold, color='red', linestyle='--', linewidth=1.5)
    plt.show()

# Dataset

In [None]:
class AbdomenDataset(Dataset):
    """Per-slice binary classification dataset built from ``.nii.gz`` volumes.

    Every slice ``img[:, :, i]`` of every volume becomes one sample. Its label
    is 1 when ``begin <= i <= end`` for that volume's ``(end, begin)`` pair from
    the label file, otherwise 0.

    Volumes are paired with label-file lines by position: the n-th volume (in
    sorted file-name order) gets the n-th non-blank label line. Pairing uses
    ``zip``, which stops at the shorter of the two lists.

    NOTE(review): the first column of each label line is read but ignored. If it
    identifies the volume, pairing by that identifier would be more robust;
    confirm the label file's format.
    """

    def __init__(self, filepath: str, label_filepath: str, transform=None):
        self.transform = transform
        self.data: list = self._load_nii_gz_files(filepath)
        self.labels: list = self._load_labels_file(label_filepath)

        self.frames_list: list = []
        self.labels_list: list = []

        for (end, begin), img in zip(self.labels, self.data):
            for i in range(img.shape[2]):
                self.frames_list.append(img[:, :, i])
                self.labels_list.append(1 if begin <= i <= end else 0)

    def __len__(self) -> int:
        return len(self.frames_list)

    def __getitem__(self, idx: int) -> tuple[torch.Tensor, float]:
        """Return ``(image, label)`` for slice ``idx``.

        The slice is converted to a float32 tensor of shape ``(1, H, W)`` and
        passed through ``transform`` when one was given. It is then repeated to
        3 channels (presumably to match the 3-channel pretrained backbone;
        confirm). The label is returned as a Python float (0.0 or 1.0).
        """
        img = torch.tensor(self.frames_list[idx]).float().unsqueeze(0)
        # ``transform`` defaults to None, so only call it when one was provided.
        if self.transform is not None:
            img = self.transform(img)
        label = float(self.labels_list[idx])
        return img.repeat(3, 1, 1), label

    def _load_labels_file(self, label_filepath: str) -> list:
        """Read ``(end, begin)`` integer pairs, one per non-blank line.

        Each line must hold three whitespace-separated fields. The first is
        discarded and the other two are parsed as ``end`` then ``begin``.
        """
        labels: list = []
        with open(label_filepath, 'r') as file:
            for line in file:
                fields = line.split()
                if not fields:  # tolerate empty lines, e.g. extra newlines at EOF
                    continue
                _, end, begin = fields
                labels.append((int(end), int(begin)))
        return labels

    def _load_nii_gz_files(self, filepath) -> list:
        """Load every ``*.nii.gz`` file in ``filepath`` as a numpy array.

        Files are visited in sorted name order. ``os.listdir`` returns entries
        in arbitrary order, and volumes are later paired with label lines by
        position.
        """
        data: list = []
        for file_name in sorted(os.listdir(filepath)):
            if not file_name.endswith('.nii.gz'):
                continue
            nii_img = nib.load(os.path.join(filepath, file_name))
            data.append(nii_img.get_fdata())  # image as a numpy array
        return data

# Model

In [None]:
class AbdomenModel(nn.Module):
    """DenseNet-121 binary classifier producing one probability per image.

    The backbone is loaded with ``DenseNet121_Weights.DEFAULT`` and frozen at
    construction. Its classifier is then replaced by dropout (p=0.25), a
    single-output linear layer and a sigmoid; being created after the freeze,
    that head starts out trainable. Call ``unfreeze_densenet`` to fine-tune the
    whole network.
    """

    def __init__(self) -> None:
        super().__init__()
        backbone = models.densenet121(weights=DenseNet121_Weights.DEFAULT)
        self.densenet = backbone
        # Freeze *before* swapping the head so the new head keeps requires_grad=True.
        self.freeze_densenet()
        feature_dim = backbone.classifier.in_features
        backbone.classifier = nn.Sequential(
            nn.Dropout(0.25),
            nn.Linear(feature_dim, 1),
            nn.Sigmoid(),
        )

    def forward(self, x):
        return self.densenet(x)

    def unfreeze_densenet(self) -> None:
        """Make every parameter of ``self.densenet`` trainable."""
        self._set_densenet_trainable(True)

    def freeze_densenet(self) -> None:
        """Disable gradients for every parameter currently in ``self.densenet``."""
        self._set_densenet_trainable(False)

    def _set_densenet_trainable(self, trainable: bool) -> None:
        # Toggle requires_grad on all DenseNet parameters, head included.
        for param in self.densenet.parameters():
            param.requires_grad = trainable

# Training loop

In [None]:
def _run_epoch(
    model: nn.Module,
    criterion: nn.Module,
    loader: DataLoader,
    optimizer: "Optimizer | None" = None,
) -> tuple[float, float]:
    """Run one pass over ``loader`` and return ``(mean_loss, accuracy)``.

    With an ``optimizer``, the model is put in training mode and updated after
    every batch. Without one, it is put in eval mode and run with gradients
    disabled. Outputs are treated as probabilities and thresholded at 0.5 for
    accuracy.

    The loss is averaged per sample: each batch's loss is weighted by its size.
    This assumes ``criterion`` returns the batch mean, as ``nn.BCELoss()`` does
    by default, and stops a smaller final batch from being over-weighted.

    Raises:
        ValueError: if ``loader`` yields no samples.
    """
    training = optimizer is not None
    model.train(training)
    total_loss = 0.0
    correct = 0
    total = 0
    with torch.set_grad_enabled(training):
        for inputs, labels in tqdm(loader):
            inputs, labels = inputs.to(device), labels.to(device).float()

            if training:
                optimizer.zero_grad()
            # reshape(-1) rather than squeeze(): squeeze() turns a final batch of
            # size 1 into a 0-d tensor, which no longer matches labels' shape (1,)
            # and makes BCELoss raise.
            outputs = model(inputs).reshape(-1)
            loss = criterion(outputs, labels)
            if training:
                loss.backward()
                optimizer.step()

            batch_size = labels.size(0)
            total_loss += loss.item() * batch_size
            preds = (outputs > 0.5).float()
            correct += (preds == labels).sum().item()
            total += batch_size

    if total == 0:
        raise ValueError("dataloader yielded no samples")
    return total_loss / total, correct / total


def training_loop(model: AbdomenModel, criterion: torch.nn.Module, optimizer: Optimizer, dataloader: dict, EPOCHS: int = 10):
    """Train ``model`` for ``EPOCHS`` epochs, validating after each one.

    Args:
        model: Network to train; its outputs are treated as probabilities.
        criterion: Loss applied to ``(outputs, labels)``.
        optimizer: Optimizer stepped after every training batch.
        dataloader: Mapping with ``'train'`` and ``'val'`` loaders yielding
            ``(inputs, labels)`` batches.
        EPOCHS: Number of epochs to run.

    Returns:
        ``(model, [train_acc, train_loss, val_acc, val_loss])``. Each history
        is a list with one value per epoch.
    """
    accuracy_history: list = []
    loss_history: list = []
    val_accuracy_history: list = []
    val_loss_history: list = []

    for epoch in range(EPOCHS):
        train_loss, train_accuracy = _run_epoch(model, criterion, dataloader['train'], optimizer)
        loss_history.append(train_loss)
        accuracy_history.append(train_accuracy)

        val_loss, val_accuracy = _run_epoch(model, criterion, dataloader['val'])
        val_loss_history.append(val_loss)
        val_accuracy_history.append(val_accuracy)

        print(f"Epoch {epoch + 1}, Train Loss: {train_loss:.4f}, Train Acc: {train_accuracy:.4f}, Val Loss: {val_loss:.4f}, Val Acc: {val_accuracy:.4f}")

    return model, [accuracy_history, loss_history, val_accuracy_history, val_loss_history]

# Algorithm to select the window of abdomen slices

In [None]:
def select_window_of_abdomen(list_of_probabilities: list, eps: float = 1e-7) -> tuple[int, int]:
    """Return the most likely contiguous window of abdomen slices.

    Each slice is treated independently. The likelihood of the window
    ``[begin, end]`` (both inclusive) is::

        prod(p[begin:end + 1]) * prod(1 - p[:begin]) * prod(1 - p[end + 1:])

    That is, every slice inside the window is abdomen and every slice outside
    is not. Taking logs and dropping the window-independent term
    ``sum(log(1 - p))`` reduces this to maximising
    ``sum(log(p_i) - log(1 - p_i))`` over the window. That is a maximum-subarray
    problem, solved here in O(n) with Kadane's algorithm. Working in log space
    also avoids the underflow to 0 that multiplying hundreds of probabilities
    causes.

    Args:
        list_of_probabilities: Per-slice probabilities in [0, 1], in slice order.
        eps: Probabilities are clipped to ``[eps, 1 - eps]`` so exact 0/1 values
            do not produce infinite log-odds.

    Returns:
        ``(begin, end)`` inclusive slice indices, or ``(0, 0)`` for empty input.
    """
    p = np.asarray(list_of_probabilities, dtype=np.float64).ravel()
    if p.size == 0:
        return 0, 0
    p = np.clip(p, eps, 1.0 - eps)

    # Per-slice log-odds: positive where a slice is more likely abdomen than not.
    scores = np.log(p) - np.log1p(-p)

    best_begin, best_end = 0, 0
    best_score = float("-inf")
    run_begin, run_score = 0, 0.0
    for i, score in enumerate(scores.tolist()):
        # A non-positive running sum can only lower any window that extends it,
        # so start a fresh window at this slice.
        if run_score <= 0.0:
            run_begin, run_score = i, score
        else:
            run_score += score
        if run_score > best_score:
            best_begin, best_end, best_score = run_begin, i, run_score
    return best_begin, best_end

# Train

In [None]:
# Evaluation preprocessing: resize, then normalise (no augmentation).
transform_val = torchvision.transforms.Compose([
    torchvision.transforms.Resize((224, 224)),
    torchvision.transforms.Normalize(mean=[0.5], std=[0.5])
])

# Training preprocessing: the same resize/normalise plus random flips, rotation
# and translation for augmentation.
transform = torchvision.transforms.Compose([
    torchvision.transforms.Resize((224, 224)),
    torchvision.transforms.RandomHorizontalFlip(),
    torchvision.transforms.RandomVerticalFlip(),
    torchvision.transforms.RandomRotation(40),
    torchvision.transforms.RandomAffine(0, translate=(0.1, 0.1)),
    torchvision.transforms.Normalize(mean=[0.5], std=[0.5])
])
# NOTE(review): Normalize(mean=0.5, std=0.5) is presumably meant for inputs
# already scaled to [0, 1], but AbdomenDataset feeds raw get_fdata() values.
# Confirm whether an intensity window/rescale step is missing.

# NOTE(review): both splits read the same label file, and AbdomenDataset pairs
# volumes with label lines by position. The val volumes therefore receive the
# first lines of that file; confirm this is intended.
datasets = {
    "train": AbdomenDataset('data/train', 'data/oznaczenia.txt', transform=transform),
    "val": AbdomenDataset('data/val', 'data/oznaczenia.txt', transform=transform_val),
}

dataloaders = {
    "train": DataLoader(datasets["train"], batch_size=8, shuffle=True),
    # Shuffling the validation set brings no benefit; a fixed order keeps
    # evaluation reproducible.
    "val": DataLoader(datasets["val"], batch_size=8, shuffle=False),
}

In [None]:
# Build the classifier on the selected device, with a binary cross-entropy loss
# and an Adam optimizer over all model parameters.
model = AbdomenModel()
model = model.to(device)
criterion = nn.BCELoss()
optimizer = torch.optim.Adam(params=model.parameters(), lr=1e-3)

In [None]:
# Phase 1: train only the new classifier head while the DenseNet backbone is frozen.
model, history_frozen = training_loop(model, criterion, optimizer, dataloaders, EPOCHS=5)

# Phase 2: unfreeze the backbone and fine-tune the whole network. The same
# optimizer works because it was built from model.parameters(), which already
# includes the backbone.
model.unfreeze_densenet()
model, history = training_loop(model, criterion, optimizer, dataloaders, EPOCHS=70)

# Label the checkpoint with the validation accuracy of the weights actually saved
# (the last epoch). max(history[2]) could come from an earlier epoch whose
# weights are no longer in `model`.
os.makedirs("results", exist_ok=True)
timestamp = datetime.datetime.now().strftime('%Y-%m-%d_%H-%M-%S')
final_val_accuracy = history[2][-1]
torch.save(model.state_dict(), f"results/model_{timestamp}_{final_val_accuracy:.4f}.pth")

# Test

In [None]:
# model_path: str = "model_2024-12-08_20-00-00_0.9.pth"
# model.load_state_dict(torch.load(model_path))