# Deep Learning Part

## TODO
* Losses should be weighted per class: there are 8 to 9 times as many healthy images as bleeding images, so an unweighted model will be inclined to predict healthy
* hyperparameters for deep learning runs should be saved

## Imports

In [1]:
import os
import cv2
import numpy as np
import matplotlib.pyplot as plt
from tqdm import tqdm
from datetime import datetime
import itertools

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader, random_split, ConcatDataset
from torch.optim import Adam, lr_scheduler

## Dataset

In [2]:
class BleedDataset(Dataset):
    """Binary image dataset read from ``<root_dir>/bleeding`` and ``<root_dir>/healthy``.

    Every entry of the ``bleeding`` folder is labelled ``1`` and every entry of
    the ``healthy`` folder is labelled ``0``.  Items are ``(image, label)``
    pairs where ``image`` is the channel-first ``numpy`` array produced by
    :meth:`_preprocess_image`.

    Args:
        root_dir: Directory that contains the ``bleeding`` and ``healthy`` folders.
        mode: ``"RGB"`` to load 3-channel images or ``"gray"`` to load
            single-channel images (case-insensitive).

    Raises:
        ValueError: If ``mode`` is neither ``"RGB"`` nor ``"gray"``.
    """

    def __init__(self, root_dir, mode="RGB"):
        # Validate the cheap argument first so a typo fails before touching the disk.
        self.mode = mode.lower()
        if self.mode not in {"rgb", "gray"}:
            raise ValueError("Invalid mode. Use 'RGB' or 'gray'.")

        self.root_dir = root_dir
        self.bleeding_dir = os.path.join(root_dir, "bleeding")
        self.healthy_dir = os.path.join(root_dir, "healthy")

        # Combine images and labels into tuples.
        # os.listdir() returns entries in arbitrary order; sorting keeps the
        # index -> sample mapping stable, so a seeded split selects the same
        # files on every machine.
        self.data = [
            (os.path.join(self.bleeding_dir, name), 1)
            for name in sorted(os.listdir(self.bleeding_dir))
        ] + [
            (os.path.join(self.healthy_dir, name), 0)
            for name in sorted(os.listdir(self.healthy_dir))
        ]

    def __len__(self):
        """Return the total number of (path, label) entries."""
        return len(self.data)

    @staticmethod
    def _preprocess_image(image):
        """Crop, mask and convert an image to channel-first layout.

        Args:
            image: ``(H, W)`` or ``(H, W, C)`` array.  Rows and columns
                32..543 are kept.

        Returns:
            A new array of shape ``(C, h, w)`` for 3-D input or ``(1, h, w)``
            for 2-D input.  The input array is left untouched.
        """
        # Copy after cropping: a plain slice is a view, and the masking below
        # would otherwise write zeros into the caller's array.
        image = image[32:544, 32:544].copy()  # Crop black borders
        image[:48, :48] = 0  # Remove artifacts
        image[:31, 452:] = 0
        if image.ndim == 3:
            image = np.transpose(image, (2, 0, 1))  # Convert to PyTorch format
        else:
            image = image[np.newaxis, ...]
        return image

    def __getitem__(self, idx):
        """Load, preprocess and return the ``idx``-th ``(image, label)`` pair.

        Raises:
            OSError: If the file cannot be decoded as an image.
        """
        image_path, label = self.data[idx]
        read_flag = cv2.IMREAD_GRAYSCALE if self.mode == "gray" else cv2.IMREAD_COLOR
        # NOTE(review): OpenCV decodes colour images in BGR channel order, so
        # "RGB" mode currently yields BGR data. Left unchanged so existing
        # checkpoints stay valid - confirm this is intended.
        image = cv2.imread(image_path, read_flag)
        if image is None:
            # cv2.imread signals failure by returning None instead of raising.
            raise OSError(f"Could not read image file: {image_path}")
        return self._preprocess_image(image), label

## Model

In [3]:
class DummyModel(nn.Module):
    """Small fully-convolutional binary classifier with a sigmoid output.

    Six 3x3 convolutions (the first two with stride 2) are interleaved with
    batch normalisation, 2x2 max pooling and 2D dropout; a final 10x10
    convolution reduces the 16-channel feature map to one channel, which is
    passed through a sigmoid.

    Args:
        in_channels: Number of channels of the input images. Defaults to 3;
            pass 1 for single-channel (grayscale) input.
    """

    def __init__(self, in_channels=3):
        super().__init__()
        # Attribute names and creation order are kept stable so state_dict
        # keys (and seeded initialisation) match earlier checkpoints.
        self.conv1 = nn.Conv2d(in_channels, 64, kernel_size=3, stride=2)
        self.conv2 = nn.Conv2d(64, 128, kernel_size=3, stride=2)
        self.bn2 = nn.BatchNorm2d(128)
        self.conv3 = nn.Conv2d(128, 256, kernel_size=3)
        self.bn3 = nn.BatchNorm2d(256)
        self.conv4 = nn.Conv2d(256, 128, kernel_size=3)
        self.bn4 = nn.BatchNorm2d(128)
        self.conv5 = nn.Conv2d(128, 64, kernel_size=3)
        self.drop5 = nn.Dropout2d(p=0.2)
        self.conv6 = nn.Conv2d(64, 16, kernel_size=3)
        self.drop6 = nn.Dropout2d(p=0.2)
        self.relu = F.relu
        self.max_pool = nn.MaxPool2d(kernel_size=(2, 2), stride=2)
        self.sigmoid = nn.Sigmoid()

        self.fully_conv = nn.Conv2d(16, 1, kernel_size=10)

    def forward(self, x):
        """Return sigmoid scores for a batch of images.

        Args:
            x: Float tensor of shape ``(batch, in_channels, H, W)``.

        Returns:
            ``out[:, 0, 0]`` of the final ``(batch, 1, h, w)`` map, i.e. a
            tensor of shape ``(batch, w)``. For 512x512 inputs the final map
            is 1x1, so the result has shape ``(batch, 1)``.
        """
        x = self.relu(self.conv1(x))
        x = self.max_pool(self.relu(self.bn2(self.conv2(x))))
        x = self.max_pool(self.relu(self.bn3(self.conv3(x))))
        x = self.max_pool(self.relu(self.bn4(self.conv4(x))))

        x = self.drop5(self.relu(self.conv5(x)))
        x = self.drop6(self.relu(self.conv6(x)))
        x = self.sigmoid(self.fully_conv(x))

        # Select channel 0 and the first row of the output map.
        return x[:, 0, 0]

## Training Space

### Hyperparameters

In [4]:
# --- Paths ---
DIRECTORY_PATH = "../project_capsule_dataset"  # dataset root handed to the dataset class
SAVE_PATH = "./"  # parent directory under which each run's output folder is created

# --- Data ---
# (train fraction, test fraction); the samples left over form the validation split.
TRAIN_TEST_SPLIT = (0.8, 0.1)
BATCH_SIZE = 8

# --- Optimisation ---
LR = 1e-3  # learning rate
NUM_OF_EPOCHS = 1
EARLY_STOP_LIMIT = 3  # early-stopping counter value at which training is aborted

# --- Evaluation ---
# Model outputs >= THRESHOLD are counted as a bleeding prediction, lower ones as healthy.
THRESHOLD = 0.5

### Dataset, Model etc. Initialization

In [5]:
### ---|---|---|---|---|---|---|---|---|---|--- MODEL & DATASET ---|---|---|---|---|---|---|---|---|---|--- ###
# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

# Hand cached, currently unused CUDA memory back before the model is built.
torch.cuda.empty_cache()

# Model Initialization
def initialize_model(model_class, save_path):
    """Instantiate a model on the active device and create its run directory.

    The directory is named
    ``training_with_<ClassName>_on_<MM.DD.>_at_<HH-MM-SS>`` and is created
    under ``save_path`` (existing directories are reused).

    Args:
        model_class: Zero-argument callable returning an ``nn.Module``.
        save_path: Parent directory for the run directory.

    Returns:
        Tuple ``(model, model_serial_path)``: the model moved to the
        module-level ``device`` and the path of the run directory.
    """
    model = model_class().to(device)
    # '-' separates the time fields because ':' is not allowed in Windows
    # file names and would make os.makedirs fail there.
    timestamp = datetime.now().strftime('on_%m.%d._at_%H-%M-%S')
    model_serial_number = f"training_with_{model.__class__.__name__}_{timestamp}"
    model_serial_path = os.path.join(save_path, model_serial_number)
    os.makedirs(model_serial_path, exist_ok=True)
    return model, model_serial_path

# Instantiate DummyModel on `device` and create this run's output directory under SAVE_PATH.
model, model_serial_path = initialize_model(DummyModel, SAVE_PATH)

# Dataset Preparation
def prepare_datasets(dataset_class, directory, split_ratios, batch_size, image_mode="RGB", seed=0):
    """Build train / validation / test ``DataLoader`` objects from one dataset.

    Args:
        dataset_class: Callable invoked as ``dataset_class(directory, mode=image_mode)``.
        directory: Dataset location forwarded to ``dataset_class``.
        split_ratios: ``(train_fraction, test_fraction)``. Sizes are truncated
            to integers; all remaining samples form the validation split.
        batch_size: Batch size used by all three loaders.
        image_mode: Forwarded to ``dataset_class`` as ``mode``.
        seed: Passed to ``torch.manual_seed`` right before splitting. Note
            that this reseeds PyTorch's global random number generator.

    Returns:
        Dict with keys ``"train"`` (shuffled), ``"validation"`` and ``"test"``.

    Raises:
        ValueError: If the fractions yield a negative split size.
    """
    dataset = dataset_class(directory, mode=image_mode)
    total_size = len(dataset)
    train_size = int(split_ratios[0] * total_size)
    test_size = int(split_ratios[1] * total_size)
    validation_size = total_size - train_size - test_size

    # Fail with an explicit message instead of the unrelated length error
    # random_split would raise for the same input.
    if min(train_size, test_size, validation_size) < 0:
        raise ValueError(
            f"split_ratios={split_ratios!r} does not fit a dataset of {total_size} samples: "
            "fractions must be non-negative and sum to at most 1."
        )

    torch.manual_seed(seed)
    train_dataset, validation_dataset, test_dataset = random_split(
        dataset, [train_size, validation_size, test_size]
    )

    # Page-locked host memory only speeds up host -> GPU copies; requesting
    # it without CUDA just triggers a warning.
    pin_memory = torch.cuda.is_available()
    return {
        "train": DataLoader(train_dataset, batch_size=batch_size, shuffle=True, pin_memory=pin_memory),
        "validation": DataLoader(validation_dataset, batch_size=batch_size, shuffle=False, pin_memory=pin_memory),
        "test": DataLoader(test_dataset, batch_size=batch_size, shuffle=False, pin_memory=pin_memory),
    }

# One DataLoader per split, keyed "train" / "validation" / "test".
data_loaders = prepare_datasets(
    dataset_class=BleedDataset,
    directory=DIRECTORY_PATH,
    split_ratios=TRAIN_TEST_SPLIT,
    batch_size=BATCH_SIZE,
    image_mode="RGB",
    seed=0,
)

# Binary cross-entropy applied to the model's sigmoid outputs.
criterion = nn.BCELoss()
# Adam over all model parameters; StepLR scales the learning rate by 0.1
# after every 10 calls to scheduler.step().
optimizer = Adam(model.parameters(), lr=LR)
scheduler = lr_scheduler.StepLR(optimizer, gamma=0.1, step_size=10)

### Training Loop

In [6]:
### ---|---|---|---|---|---|---|---|---|---|--- TRAINING ---|---|---|---|---|---|---|---|---|---|--- ###
train_losses, validation_losses = [], []
min_validation_loss = None  # lowest validation loss seen so far; drives checkpointing and early stopping
min_validation_path = None  # path of the checkpoint saved for that loss
early_stop_step = 0  # consecutive epochs without a new validation minimum

for epoch in range(NUM_OF_EPOCHS):
    # ---- training pass ----
    model.train()
    running_training_loss = 0.0
    # Wrapping the loader itself (not enumerate(...)) lets tqdm show the batch total.
    for images, labels in tqdm(data_loaders['train'], leave=False):
        images, labels = images.to(device), labels.to(device)

        outputs = model(images.float())
        train_loss = criterion(outputs[:, 0].float(), labels.float())

        optimizer.zero_grad()
        train_loss.backward()
        optimizer.step()

        # Accumulate a plain Python float: summing the loss tensors themselves
        # would keep a reference to every batch's autograd history for the
        # whole epoch.
        running_training_loss += train_loss.item()

    # average loss over this epoch's training batches
    averaged_training_loss = running_training_loss / len(data_loaders['train'])

    # ---- validation pass ----
    model.eval()
    validation_loss = 0.0
    with torch.no_grad():
        for validation_images, validation_labels in data_loaders['validation']:
            validation_images, validation_labels = validation_images.to(device), validation_labels.to(device)

            validation_outputs = model(validation_images.float())
            validation_loss += criterion(
                validation_outputs[:, 0].float(), validation_labels.float()
            ).item()

    # average loss over the validation batches
    validation_loss /= len(data_loaders['validation'])

    # ---- checkpointing / early-stop bookkeeping ----
    if min_validation_loss is None or validation_loss < min_validation_loss:
        early_stop_step = 0
        min_validation_loss = validation_loss

        # Write the new checkpoint before deleting the previous one, so a
        # failed save never leaves the run without a checkpoint.
        # '_' is used instead of ':' in the file name because ':' is not
        # allowed in Windows file names.
        previous_checkpoint_path = min_validation_path
        min_validation_path = os.path.join(
            model_serial_path,
            f"min_validation_loss_{min_validation_loss}_epoch_{epoch}.pth",
        )
        torch.save(model.state_dict(), min_validation_path)

        if (
            previous_checkpoint_path is not None
            and previous_checkpoint_path != min_validation_path
            and os.path.exists(previous_checkpoint_path)
        ):
            os.remove(previous_checkpoint_path)
    else:
        # Only epochs without improvement count towards the limit; counting
        # the improving epoch too would stop one stale epoch early.
        early_stop_step += 1

    # log the losses, and append to the lists
    print(f"Epoch: {epoch+1} | training loss: {averaged_training_loss} | min validation loss: {min_validation_loss}", flush=True)
    train_losses.append(averaged_training_loss)
    validation_losses.append(validation_loss)
    scheduler.step()

    # check for early stopping
    if early_stop_step >= EARLY_STOP_LIMIT:
        print("early stopping...")
        break

                        

KeyboardInterrupt: 

In [None]:
# Plot the per-epoch training and validation losses and save the figure
# inside this run's output directory.
fig, ax = plt.subplots()
ax.plot(train_losses, color='blue', label='Train Loss')
ax.plot(validation_losses, color='orange', label='Validation Loss')
# Label the axes so the saved image is readable on its own.
ax.set_xlabel('Epoch')
ax.set_ylabel('Loss')
ax.set_title('Training vs. validation loss')
ax.legend()
fig.savefig(os.path.join(model_serial_path, "losses.png"))

## Testing Space

In [None]:
### ---|---|---|---|---|---|---|---|---|---|--- TESTING ---|---|---|---|---|---|---|---|---|---|--- ###
loaded_model = DummyModel().to(device)
# Evaluate the best checkpoint written during training. Without it the
# metrics below would describe randomly initialised weights, so say so
# loudly when no checkpoint is available (e.g. training was interrupted
# before the first epoch finished).
if min_validation_path is not None and os.path.isfile(min_validation_path):
    loaded_model.load_state_dict(torch.load(min_validation_path, map_location=device))
else:
    print("WARNING: no checkpoint available; evaluating a randomly initialised model.", flush=True)
loaded_model.eval()

# class_correct[label]: number of correctly classified test samples whose true label is `label`
# class_total[label]:   number of test samples whose true label is `label`
class_correct, class_total = [0, 0], [0, 0]
with torch.no_grad():
    for test_images, test_labels in tqdm(data_loaders['test']):
        test_images, test_labels = test_images.to(device), test_labels.to(device)

        test_scores = loaded_model(test_images.float())[:, 0]
        # Scores at or above the threshold are bleeding (1), the rest healthy (0).
        test_predictions = (test_scores >= THRESHOLD).long()

        # Keep the batch dimension (no squeeze) so a final batch holding a
        # single sample is handled like any other.
        correct = test_predictions == test_labels

        for label in (0, 1):
            has_label = test_labels == label
            class_total[label] += int(has_label.sum().item())
            class_correct[label] += int((correct & has_label).sum().item())

In [None]:
def _accuracy(correct_count, total_count):
    """Return ``correct_count / total_count``, or NaN when ``total_count`` is 0."""
    # A split can end up without any sample of a class; report NaN instead
    # of failing with ZeroDivisionError.
    return correct_count / total_count if total_count else float("nan")


# Build the report once so the printed text and accuracy.txt cannot drift apart.
# Total: accuracy for whole dataset
report_lines = [
    f"Total accuracy: {_accuracy(sum(class_correct), sum(class_total))} on threshold: {THRESHOLD}",
    f"Healthy detection: {class_correct[0]}/{class_total[0]} | accuracy: {_accuracy(class_correct[0], class_total[0])}",
    f"Bleeding detection: {class_correct[1]}/{class_total[1]} | accuracy: {_accuracy(class_correct[1], class_total[1])}",
]

for report_line in report_lines:
    print(report_line)

# Persist the same lines next to the run's other outputs.
with open(os.path.join(model_serial_path, "accuracy.txt"), 'w') as txt:
    txt.writelines(report_line + "\n" for report_line in report_lines)

torch.cuda.empty_cache()