In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import numpy as np
from PIL import Image
import albumentations as A
from albumentations.pytorch import ToTensorV2
import os
from typing import Tuple, List
import random
import segmentation_models_pytorch as smp
from sklearn.metrics import jaccard_score, f1_score, precision_score, recall_score, accuracy_score


In [None]:
# Set random seeds for reproducibility
def set_seed(seed):
    """Seed every random number generator this notebook relies on.

    Seeds Python's ``random`` module, NumPy's global generator and PyTorch
    (``torch.manual_seed`` plus ``torch.cuda.manual_seed_all``), then sets
    the cuDNN backend flags ``deterministic=True`` and ``benchmark=False``.

    Args:
        seed: Value handed unchanged to each seeding function.
    """
    seeding_functions = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for seed_fn in seeding_functions:
        seed_fn(seed)

    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False

set_seed(42)

# Training-time augmentation pipeline, written as named stages.
# The stages are concatenated in exactly the order they are listed below,
# so the sequence of transforms applied to each (image, mask) pair is:
# flips/rotations -> brightness/contrast -> warps -> noise/CLAHE/blur ->
# random resized crop -> normalisation -> tensor conversion.
_flip_and_rotate_stage = [
    A.HorizontalFlip(p=0.5),
    A.VerticalFlip(p=0.5),
    A.RandomRotate90(p=0.5),
    A.ShiftScaleRotate(shift_limit=0.0625, scale_limit=0.1, rotate_limit=45, p=0.5),
]
_brightness_stage = [
    A.RandomBrightnessContrast(p=0.5),
]
_warp_stage = [
    A.ElasticTransform(p=0.5),
    A.GridDistortion(p=0.5),
]
_noise_and_blur_stage = [
    A.GaussNoise(p=0.2),
    A.CLAHE(p=0.2),
    A.Blur(p=0.2),
]
_crop_stage = [
    A.RandomResizedCrop(height=128, width=128, p=0.5),
]
_train_output_stage = [
    A.Normalize(mean=(0.5,), std=(0.5,), max_pixel_value=255.0),
    ToTensorV2(),
]

transform = A.Compose(
    _flip_and_rotate_stage
    + _brightness_stage
    + _warp_stage
    + _noise_and_blur_stage
    + _crop_stage
    + _train_output_stage
)

# Validation / test pipeline: no augmentation, only the same normalisation
# and tensor conversion used at the end of the training pipeline.
_eval_output_stage = [
    A.Normalize(mean=(0.5,), std=(0.5,), max_pixel_value=255.0),
    ToTensorV2(),
]
val_transform = A.Compose(_eval_output_stage)


In [None]:
class SegmentationDataset(Dataset):
    """Dataset of grayscale images paired with class-index masks.

    File names are read from ``image_dir`` and ``mask_dir`` and sorted by an
    integer embedded in each name, so that position ``k`` in both listings is
    treated as the same sample:

    * images: the integer is the last ``_``-separated field, up to the first
      ``.`` in that field;
    * masks: the integer is the second-to-last ``_``-separated field.

    Args:
        image_dir: Directory holding the input images.
        mask_dir: Directory holding the masks.
        transform: Optional callable invoked as ``transform(image=..., mask=...)``
            that returns a mapping with ``'image'`` and ``'mask'`` entries.
        size: Size passed to ``PIL.Image.resize`` for both image and mask.

    Raises:
        AssertionError: If the two directories contain a different number of
            entries.
    """

    def __init__(self, image_dir: str, mask_dir: str, transform=None, size: Tuple[int, int] = (128, 128)):
        self.image_dir = image_dir
        self.mask_dir = mask_dir
        self.transform = transform
        self.size = size

        def image_number(filename):
            # Last '_' field, truncated at its first '.'.
            return int(filename.split('_')[-1].split('.')[0])

        def mask_number(filename):
            # Second-to-last '_' field.
            return int(filename.split('_')[-2])

        self.image_paths = sorted(os.listdir(image_dir), key=image_number)
        self.mask_paths = sorted(os.listdir(mask_dir), key=mask_number)

        assert len(self.image_paths) == len(self.mask_paths), "Number of images and masks should be the same"

    def __len__(self):
        """Return the number of image/mask pairs."""
        return len(self.image_paths)

    def __getitem__(self, idx):
        """Load, resize and label-encode the sample at position ``idx``.

        Both files are converted to 8-bit grayscale. The image is resized with
        bilinear resampling, the mask with nearest-neighbour resampling. Mask
        gray levels 0, 150 and 76 become class indices 0, 1 and 2; any other
        gray level is left as it is.

        Returns:
            ``(image, mask)``: the transform's ``'image'`` and ``'mask'``
            outputs when a transform is set, otherwise the NumPy arrays.
        """
        image_file = os.path.join(self.image_dir, self.image_paths[idx])
        mask_file = os.path.join(self.mask_dir, self.mask_paths[idx])

        gray_image = Image.open(image_file).convert("L")
        gray_mask = Image.open(mask_file).convert("L")

        gray_image = gray_image.resize(self.size, resample=Image.BILINEAR)
        gray_mask = gray_mask.resize(self.size, resample=Image.NEAREST)

        image = np.array(gray_image)
        mask = np.array(gray_mask).astype(np.int64)

        # (gray level in the mask file, class index), applied in this order.
        gray_to_class = (
            (0, 0),    # background
            (150, 1),  # head
            (76, 2),   # symp
        )
        for gray_level, class_index in gray_to_class:
            mask = np.where(mask == gray_level, class_index, mask)

        if not self.transform:
            return image, mask

        augmented = self.transform(image=image, mask=mask)
        return augmented['image'], augmented['mask']

# Function to calculate metrics
def calculate_metrics(y_true, y_pred):
    """Compute the segmentation quality metrics reported by this notebook.

    Jaccard, Dice (computed as the F1 score), precision and recall are all
    macro-averaged; accuracy is computed without an averaging argument.
    Sensitivity is the same value as recall and is returned a second time
    under that name.

    Args:
        y_true: Ground-truth labels.
        y_pred: Predicted labels, aligned with ``y_true``.

    Returns:
        Tuple ``(jaccard, dice, precision, recall, accuracy, sensitivity)``.
    """
    macro = {'average': 'macro'}

    jaccard = jaccard_score(y_true, y_pred, **macro)
    dice = f1_score(y_true, y_pred, **macro)
    precision = precision_score(y_true, y_pred, **macro)
    recall = recall_score(y_true, y_pred, **macro)
    accuracy = accuracy_score(y_true, y_pred)

    # Sensitivity is reported as an alias of recall.
    return jaccard, dice, precision, recall, accuracy, recall

# Function to evaluate model on validation or test set
def evaluate_model(model, dataloader, device):
    """Score ``model`` on every batch of ``dataloader``.

    The model is switched to evaluation mode and run without gradient
    tracking. For each batch the predicted class is the argmax over
    dimension 1 of the model output (the first element is used when the
    model returns a tuple). All labels and predictions are flattened and
    passed to ``calculate_metrics``.

    Args:
        model: Network to evaluate; left in evaluation mode on return.
        dataloader: Iterable yielding ``(inputs, labels)`` tensor pairs.
        device: Device the inputs are moved to before the forward pass.

    Returns:
        Whatever ``calculate_metrics`` returns for the flattened labels and
        predictions.
    """
    model.eval()
    true_batches = []
    pred_batches = []

    with torch.no_grad():
        for inputs, labels in dataloader:
            outputs = model(inputs.to(device))
            if isinstance(outputs, tuple):
                outputs = outputs[0]

            preds = torch.argmax(outputs, dim=1)

            # Keep one flat array per batch. Extending a Python list with
            # individual pixels boxes every pixel as a separate object, which
            # is slow and memory-hungry. Labels are only used as CPU ground
            # truth, so they are not moved to the device.
            true_batches.append(labels.cpu().numpy().ravel())
            pred_batches.append(preds.cpu().numpy().ravel())

    if true_batches:
        y_true_all = np.concatenate(true_batches)
        y_pred_all = np.concatenate(pred_batches)
    else:
        # Same empty arrays an empty dataloader produced before.
        y_true_all = np.array([])
        y_pred_all = np.array([])

    return calculate_metrics(y_true_all, y_pred_all)



In [None]:
# Root folder of the partitioned dataset. It is expected to contain
# images/{train,val,test} and masks/{train,val,test}. Set the
# SEGMENTATION_DATA_ROOT environment variable to run the notebook on another
# machine; the default is the original author's local Windows path.
DATA_ROOT = os.environ.get(
    'SEGMENTATION_DATA_ROOT',
    r'C:\Users\cinth\Documentos\ams\data_science\actual_thesis\codes\MedSAM_Universeg_2024\datasets\data\dataset_complete_1\partitioned_dataset_original',
)

# Batch size shared by the train, validation and test loaders.
BATCH_SIZE = 8

train_image_dir = os.path.join(DATA_ROOT, 'images', 'train')
train_mask_dir = os.path.join(DATA_ROOT, 'masks', 'train')
val_image_dir = os.path.join(DATA_ROOT, 'images', 'val')
val_mask_dir = os.path.join(DATA_ROOT, 'masks', 'val')
test_image_dir = os.path.join(DATA_ROOT, 'images', 'test')
test_mask_dir = os.path.join(DATA_ROOT, 'masks', 'test')

# Only the training split is augmented and shuffled. Validation and test use
# the deterministic val_transform and keep their on-disk order.
train_dataset = SegmentationDataset(train_image_dir, train_mask_dir, transform=transform)
val_dataset = SegmentationDataset(val_image_dir, val_mask_dir, transform=val_transform)
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False)
test_dataset = SegmentationDataset(test_image_dir, test_mask_dir, transform=val_transform)
test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False)

In [None]:
# Function to train and save multiple models
def _run_epoch(model, loader, criterion, device, optimizer=None):
    """Run one full pass over ``loader`` and return ``(mean_loss, metrics)``.

    With an ``optimizer`` the model is put in training mode and updated after
    every batch; without one it is put in evaluation mode and gradients are
    disabled. ``metrics`` is the tuple returned by ``calculate_metrics`` for
    all flattened labels and argmax predictions seen during the pass.
    """
    training = optimizer is not None
    model.train(training)

    total_loss = 0.0
    true_batches = []
    pred_batches = []

    with torch.set_grad_enabled(training):
        for inputs, labels in loader:
            inputs = inputs.to(device)
            labels = labels.to(device).long()

            if training:
                optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            if training:
                loss.backward()
                optimizer.step()
            total_loss += loss.item()

            preds = torch.argmax(outputs, dim=1)
            # One flat array per batch; concatenated once at the end instead
            # of growing a Python list one pixel at a time.
            true_batches.append(labels.cpu().numpy().ravel())
            pred_batches.append(preds.cpu().numpy().ravel())

    mean_loss = total_loss / len(loader)
    metrics = calculate_metrics(np.concatenate(true_batches), np.concatenate(pred_batches))
    return mean_loss, metrics


def train_and_save_models(num_models, train_loader, val_loader, device, num_epochs=50):
    """Train ``num_models`` U-Nets and save the best checkpoint of each.

    Every model is an ``smp.Unet`` with an ``efficientnet-b0`` encoder
    (ImageNet weights), one input channel and three output classes. It is
    trained with Adam and a class-weighted cross-entropy loss, and a
    ``ReduceLROnPlateau`` scheduler is stepped on the validation loss. Whenever
    the validation loss improves, the weights are written to
    ``best_model_{i}.pth``. Training of a model stops early after 5
    consecutive epochs without improvement.

    Args:
        num_models: Number of independent models to train.
        train_loader: Loader yielding ``(inputs, labels)`` training batches.
        val_loader: Loader yielding ``(inputs, labels)`` validation batches.
        device: Device the models and batches are moved to.
        num_epochs: Maximum number of epochs per model.

    Returns:
        List with the checkpoint path of each trained model, in training order.

    Raises:
        RuntimeError: If a model never produced a checkpoint, i.e. its
            validation loss never dropped below infinity (for example because
            it was NaN, or ``num_epochs`` was 0). Previously this surfaced as
            an ``UnboundLocalError`` for the first model and as a silently
            duplicated, stale path for later models.
    """
    model_paths = []
    for i in range(num_models):
        print(f"Training model {i+1}/{num_models}")
        model = smp.Unet(encoder_name="efficientnet-b0", encoder_weights="imagenet", in_channels=1, classes=3)
        model = model.to(device)

        optimizer = optim.Adam(model.parameters(), lr=5e-5, weight_decay=1e-6)
        scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min', patience=2, factor=0.5)
        # Hard-coded per-class loss weights for classes 0, 1 and 2.
        weights = torch.tensor([0.5270, 0.8379, 0.98], device=device)
        criterion = nn.CrossEntropyLoss(weight=weights)

        # Bound before the epoch loop so it always refers to THIS model.
        model_path = f'best_model_{i}.pth'
        checkpoint_saved = False

        best_loss = float('inf')
        patience = 5
        epochs_without_improvement = 0

        for epoch in range(num_epochs):
            print(f"Starting epoch {epoch+1}/{num_epochs}")

            epoch_loss, train_metrics = _run_epoch(model, train_loader, criterion, device, optimizer)
            train_jaccard, train_dice, train_precision, train_recall, train_accuracy, train_sensitivity = train_metrics
            print(f"Epoch {epoch+1}/{num_epochs}, Training Loss: {epoch_loss}, Jaccard Score: {train_jaccard}, Dice Coefficient: {train_dice}, Precision: {train_precision}, Recall: {train_recall}, Accuracy: {train_accuracy}, Sensitivity: {train_sensitivity}")

            val_loss, val_metrics = _run_epoch(model, val_loader, criterion, device)
            val_jaccard, val_dice, val_precision, val_recall, val_accuracy, val_sensitivity = val_metrics
            print(f"Epoch {epoch+1}/{num_epochs}, Validation Loss: {val_loss}, Jaccard Score: {val_jaccard}, Dice Coefficient: {val_dice}, Precision: {val_precision}, Recall: {val_recall}, Accuracy: {val_accuracy}, Sensitivity: {val_sensitivity}")
            scheduler.step(val_loss)

            if val_loss < best_loss:
                best_loss = val_loss
                torch.save(model.state_dict(), model_path)
                checkpoint_saved = True
                print(f"Epoch {epoch+1}: New best model saved with validation loss {val_loss}")
                epochs_without_improvement = 0
            else:
                epochs_without_improvement += 1
                print(f"Epoch {epoch+1}: No improvement in validation loss")
                if epochs_without_improvement >= patience:
                    print("Early stopping due to no improvement in validation loss")
                    break

        if not checkpoint_saved:
            raise RuntimeError(
                f"Model {i+1}/{num_models} never improved on the initial validation loss; "
                f"no checkpoint was written to '{model_path}'"
            )
        model_paths.append(model_path)
    return model_paths

# Compute device used for training and inference. It must be defined before
# the calls below; without it a fresh-kernel run fails with a NameError.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Train and save 2 models
num_models = 2
model_paths = train_and_save_models(num_models, train_loader, val_loader, device)

# Load the models
models = []
for model_path in model_paths:
    # encoder_weights=None: the checkpoint supplies every weight, so no
    # pretrained encoder needs to be downloaded here.
    model = smp.Unet(encoder_name="efficientnet-b0", encoder_weights=None, in_channels=1, classes=3)
    # map_location lets a checkpoint written on a GPU be loaded on a
    # CPU-only machine (and vice versa).
    model.load_state_dict(torch.load(model_path, map_location=device))
    model = model.to(device)
    model.eval()
    models.append(model)

# Function to make ensemble predictions
def ensemble_predict(models, dataloader, device):
    all_preds = []

    with torch.no_grad():
        for inputs, _ in dataloader:
            inputs = inputs.to(device)
            preds = []

            for model in models:
                outputs = model(inputs)
                softmax_outputs = torch.softmax(outputs, dim=1)
                preds.append(softmax_outputs)

            # Average the softmax outputs from all models
            avg_preds = torch.mean(torch.stack(preds), dim=0)
            final_preds = torch.argmax(avg_preds, dim=1)

            all_preds.extend(final_preds.cpu().numpy())

    return np.array(all_preds)

# Evaluate the ensemble model
def evaluate_ensemble(ensemble_preds, dataloader):
    """Score precomputed ensemble predictions against the loader's labels.

    The dataloader is iterated once to collect its labels. The predictions
    line up with them only if the loader yields samples in the same order as
    when ``ensemble_preds`` was produced, i.e. it must not shuffle.

    Args:
        ensemble_preds: NumPy array of predicted class indices.
        dataloader: Iterable yielding ``(_, labels)`` pairs, where ``labels``
            is a CPU tensor.

    Returns:
        Whatever ``calculate_metrics`` returns for the flattened labels and
        flattened predictions.
    """
    # One flat array per batch, concatenated once. Extending a Python list
    # pixel by pixel boxes every label as a separate object.
    label_batches = [labels.numpy().ravel() for _, labels in dataloader]
    if label_batches:
        y_true_all = np.concatenate(label_batches)
    else:
        # Same empty array an empty dataloader produced before.
        y_true_all = np.array([])

    return calculate_metrics(y_true_all, ensemble_preds.flatten())

# Run the ensemble on the test split and score its predictions
ensemble_preds = ensemble_predict(models, test_loader, device)
test_metrics = evaluate_ensemble(ensemble_preds, test_loader)

(
    test_jaccard,
    test_dice,
    test_precision,
    test_recall,
    test_accuracy,
    test_sensitivity,
) = test_metrics
print(f"Ensemble Test Metrics - Jaccard Score: {test_jaccard}, Dice Coefficient: {test_dice}, Precision: {test_precision}, Recall: {test_recall}, Accuracy: {test_accuracy}, Sensitivity: {test_sensitivity}")

# Write the same metrics to a text report, one metric per line
report_lines = [
    f"Ensemble Test Metrics:\n",
    f"Jaccard Score: {test_jaccard}\n",
    f"Dice Coefficient: {test_dice}\n",
    f"Precision: {test_precision}\n",
    f"Recall: {test_recall}\n",
    f"Accuracy: {test_accuracy}\n",
    f"Sensitivity: {test_sensitivity}\n",
]
with open('ensemble_test_metrics.txt', 'w') as f:
    f.writelines(report_lines)

# Save the final ensemble model weights
# Averaging the weights of the models 
def average_weights(models):
    """Return a new model whose parameters are the mean of ``models``' parameters.

    Every entry of the models' ``state_dict`` is averaged element-wise across
    all models. Floating-point (and complex) tensors are averaged in their own
    dtype. Other tensors, such as integer counters, are averaged in float and
    cast back to their original dtype.

    The input models are not modified: the result is a deep copy of
    ``models[0]`` loaded with the averaged state. The previous implementation
    accumulated in place into ``models[0]`` and wrote the division into a
    temporary ``state_dict()``, so it returned the sum of the weights and
    corrupted the first model.

    Args:
        models: Non-empty sequence of models with identical ``state_dict`` keys
            and tensor shapes.

    Returns:
        A new model of the same type as ``models[0]`` holding the averaged
        weights.

    Raises:
        IndexError: If ``models`` is empty.
    """
    import copy  # local import keeps this block self-contained

    # Deep copy so the ensemble members stay usable after averaging.
    avg_model = copy.deepcopy(models[0])
    state_dicts = [model.state_dict() for model in models]

    averaged_state = {}
    with torch.no_grad():
        for key, reference in state_dicts[0].items():
            stacked = torch.stack([state[key].to(reference.device) for state in state_dicts])
            if reference.is_floating_point() or reference.is_complex():
                averaged_state[key] = stacked.mean(dim=0)
            else:
                # mean() is undefined for integer/bool tensors: average in
                # float, then restore the original dtype.
                averaged_state[key] = stacked.float().mean(dim=0).to(reference.dtype)

    avg_model.load_state_dict(averaged_state)
    return avg_model

# Collapse the ensemble into one set of weights and persist it to disk
ensemble_model = average_weights(models)
ensemble_state = ensemble_model.state_dict()
torch.save(ensemble_state, 'ensemble_model_final.pth')
print("Final ensemble model weights saved as 'ensemble_model_final.pth'")

