## Initial implementation (kept for backup)

In [None]:
import os
os.environ["TOKENIZERS_PARALLELISM"] = "false"

import csv
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader, Subset
from torchvision import transforms
from PIL import Image
import numpy as np
from sklearn.model_selection import train_test_split
import matplotlib.pyplot as plt
from transformers import AutoTokenizer, AutoModel
import random
import torchvision.transforms.functional as TF

# Import the UNet++ model
from unetplus import NestedUNet

class CataractDataset(Dataset):
    """Image / mask / question dataset backed by a CSV index file.

    Every CSV row must provide the columns ``Image_Paths``, ``Mask_Paths``,
    ``Questions`` and ``Labels``. Rows are kept in memory as dicts; image and
    mask files are only opened when a sample is requested.

    Args:
        csv_file: Path of the CSV index.
        image_transform: Optional callable applied to the RGB PIL image.
        mask_transform: Optional callable applied to the single-channel mask.
        augment: When true, paired random augmentation runs before the
            transforms.
    """

    def __init__(self, csv_file, image_transform=None, mask_transform=None, augment=False):
        self.data = []
        self.questions = set()
        # newline='' is what the csv module asks for: the reader then does its
        # own line splitting, so quoted fields containing line breaks are
        # returned unchanged.
        with open(csv_file, 'r', newline='') as f:
            for row in csv.DictReader(f):
                self.data.append(row)
                self.questions.add(row['Questions'])
        self.image_transform = image_transform
        self.mask_transform = mask_transform
        self.augment = augment

    def __len__(self):
        """Return the number of CSV rows."""
        return len(self.data)

    def __getitem__(self, idx):
        """Load one sample.

        Returns:
            Tuple ``(image, mask, question, label)``. ``question`` and
            ``label`` are the raw CSV strings.
        """
        item = self.data[idx]
        image = Image.open(item['Image_Paths']).convert('RGB')
        mask = Image.open(item['Mask_Paths']).convert('L')
        question = item['Questions']
        label = item['Labels']

        # Augment on the PIL images so image and mask receive the same geometry.
        if self.augment:
            image, mask = self.apply_augmentation(image, mask)

        if self.image_transform:
            image = self.image_transform(image)
        if self.mask_transform:
            mask = self.mask_transform(mask)

        return image, mask, question, label

    def apply_augmentation(self, image, mask):
        """Apply random flips and a small rotation to both inputs, plus a
        brightness change to the image only."""
        if random.random() > 0.5:
            image = TF.hflip(image)
            mask = TF.hflip(mask)
        if random.random() > 0.5:
            image = TF.vflip(image)
            mask = TF.vflip(mask)
        # One angle is drawn and reused so the mask stays aligned with the image.
        angle = random.uniform(-10, 10)
        image = TF.rotate(image, angle)
        mask = TF.rotate(mask, angle)
        # Photometric change: never applied to the mask.
        brightness_factor = random.uniform(0.8, 1.2)
        image = TF.adjust_brightness(image, brightness_factor)
        return image, mask

class ImageEncoder(nn.Module):
    """Image branch: a thin wrapper around ``NestedUNet`` (UNet++).

    The constructor arguments are forwarded unchanged to ``NestedUNet``;
    ``forward`` returns whatever the wrapped network returns.
    """

    def __init__(self, num_classes=1, input_channels=3, deep_supervision=False):
        super().__init__()
        unet_kwargs = {
            "num_classes": num_classes,
            "input_channels": input_channels,
            "deep_supervision": deep_supervision,
        }
        self.unetpp = NestedUNet(**unet_kwargs)
        self.deep_supervision = deep_supervision

    def forward(self, x):
        """Run the wrapped UNet++ on ``x``."""
        unet_output = self.unetpp(x)
        return unet_output

class LLMPromptEncoder(nn.Module):
    """Encode text prompts with DistilBERT and project them to 1024 features."""

    def __init__(self):
        super().__init__()
        checkpoint = "distilbert-base-uncased"
        self.tokenizer = AutoTokenizer.from_pretrained(checkpoint)
        self.model = AutoModel.from_pretrained(checkpoint)
        # Linear projection from the 768-wide hidden state to 1024 features.
        self.fc = nn.Linear(768, 1024)

    def forward(self, text):
        """Tokenise ``text`` and return one 1024-feature row per prompt."""
        encoded = self.tokenizer(
            text,
            return_tensors="pt",
            padding=True,
            truncation=True,
            max_length=512,
        )
        # Token tensors are created on CPU; move them next to the language model.
        model_inputs = {}
        for name, tensor in encoded.items():
            model_inputs[name] = tensor.to(self.model.device)
        hidden_states = self.model(**model_inputs).last_hidden_state
        # The hidden state at sequence position 0 is used as the pooled prompt.
        first_token = hidden_states[:, 0, :]
        return self.fc(first_token)

class MaskDecoder(nn.Module):
    """Convolutional head that maps fused features to a one-channel mask.

    Three double-convolution stages narrow the channels to 256, 128 and 64;
    a 1x1 convolution followed by a sigmoid produces the final mask.
    """

    def __init__(self, in_channels):
        super().__init__()
        widths = (in_channels, 256, 128, 64)
        stages = [
            self.upconv_block(src, dst)
            for src, dst in zip(widths[:-1], widths[1:])
        ]
        stages.append(nn.Conv2d(64, 1, kernel_size=1))
        stages.append(nn.Sigmoid())
        self.decoder = nn.Sequential(*stages)

    def upconv_block(self, in_channels, out_channels):
        """Build two 3x3 Conv -> BatchNorm -> ReLU steps (spatial size kept)."""
        layers = []
        for source_channels in (in_channels, out_channels):
            layers.append(nn.Conv2d(source_channels, out_channels, kernel_size=3, padding=1))
            layers.append(nn.BatchNorm2d(out_channels))
            layers.append(nn.ReLU(inplace=True))
        return nn.Sequential(*layers)

    def forward(self, x):
        """Apply the decoder stack to ``x``."""
        decoded = self.decoder(x)
        return decoded

class LLMSupervisedSAM(nn.Module):
    """Prompt-conditioned segmentation model.

    The image branch output is concatenated channel-wise with the prompt
    embedding (tiled over every pixel) and decoded into a mask.

    Args:
        num_classes: Number of channels produced by the image branch.
        input_channels: Number of channels of the input image.
        deep_supervision: When true the image branch returns several outputs
            and ``forward`` returns one mask per output.
    """

    def __init__(self, num_classes=1, input_channels=3, deep_supervision=False):
        super().__init__()
        self.image_encoder = ImageEncoder(num_classes=num_classes, input_channels=input_channels, deep_supervision=deep_supervision)
        self.prompt_encoder = LLMPromptEncoder()
        # The decoder sees the image-branch channels plus the 1024 prompt features.
        self.mask_decoder = MaskDecoder(in_channels=num_classes + 1024)
        self.deep_supervision = deep_supervision

    def forward(self, image, questions):
        """Return a mask (or a list of masks under deep supervision).

        Args:
            image: Batched image tensor; dimension 0 is the batch.
            questions: Text prompts, one per image.

        Raises:
            ValueError: If the number of encoded prompts differs from the
                image batch size (for example when a wrapper splits the image
                batch but hands every prompt to each replica).
        """
        image_features = self.image_encoder(image)

        batch_size = image.size(0)
        prompt_features = self.prompt_encoder(questions)
        # Fail early with a readable message instead of a shape error deep
        # inside view()/conv when prompts and images are out of step.
        if prompt_features.size(0) != batch_size:
            raise ValueError(
                f"Got {prompt_features.size(0)} prompt embeddings for a batch of "
                f"{batch_size} images; each image needs exactly one prompt."
            )
        prompt_map = prompt_features.view(batch_size, -1, 1, 1)

        if self.deep_supervision:
            return [self._decode(feature, prompt_map) for feature in image_features]
        return self._decode(image_features, prompt_map)

    def _decode(self, feature, prompt_map):
        """Tile the prompt embedding over ``feature``'s grid and decode."""
        # Expanding a 1x1 map yields the same constant map as bilinear
        # upsampling of a single pixel, without the interpolation kernel.
        tiled_prompt = prompt_map.expand(-1, -1, feature.shape[2], feature.shape[3])
        combined_features = torch.cat([feature, tiled_prompt], dim=1)
        return self.mask_decoder(combined_features)

def dice_coefficient(pred, target):
    """Return the smoothed Dice coefficient of two same-sized tensors.

    Both tensors are flattened, so the score is computed over all elements at
    once (the whole batch when batched tensors are passed).

    Args:
        pred: Predicted mask values.
        target: Ground-truth mask values with the same number of elements.

    Returns:
        Scalar tensor ``(2 * sum(pred * target) + 1) / (sum(pred) + sum(target) + 1)``.
    """
    smooth = 1.0
    # reshape() instead of view(): view() raises on non-contiguous inputs
    # (e.g. transposed or strided slices), reshape() copies only when needed.
    pred_flat = pred.reshape(-1)
    target_flat = target.reshape(-1)
    intersection = (pred_flat * target_flat).sum()
    return (2. * intersection + smooth) / (pred_flat.sum() + target_flat.sum() + smooth)

def visualize_results(image, mask, prediction, question, output_path,
                      mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)):
    """Save a three-panel figure: image, ground-truth overlay, prediction overlay.

    Args:
        image: Normalised image tensor laid out as (C, H, W).
        mask: Ground-truth mask; squeezed before plotting.
        prediction: Predicted mask; squeezed before plotting.
        question: Prompt text shown in the figure title.
        output_path: File the figure is written to.
        mean: Per-channel mean that was subtracted when normalising the image.
        std: Per-channel std the image was divided by when normalising.
    """
    # The images handed in here were normalised with the mean/std above, so the
    # normalisation is undone and the result clamped to the [0, 1] range that
    # imshow expects for float RGB data.
    mean_t = torch.as_tensor(mean, dtype=image.dtype, device=image.device).view(-1, 1, 1)
    std_t = torch.as_tensor(std, dtype=image.dtype, device=image.device).view(-1, 1, 1)
    rgb = (image * std_t + mean_t).clamp(0, 1).permute(1, 2, 0)  # (C, H, W) -> (H, W, C)

    fig, axes = plt.subplots(1, 3, figsize=(15, 5))
    try:
        panels = (
            ("Original Image", None),
            ("Ground Truth", mask),
            ("Predicted Mask", prediction),
        )
        for ax, (title, overlay) in zip(axes, panels):
            ax.imshow(rgb)
            if overlay is not None:
                ax.imshow(overlay.squeeze(), alpha=0.5, cmap='jet')
            ax.set_title(title)
            ax.axis('off')

        fig.suptitle(f"Question: {question}")
        fig.tight_layout()
        fig.savefig(output_path)
    finally:
        # Always release the figure, even if plotting or saving fails.
        plt.close(fig)

def _batch_metrics(outputs, masks, criterion):
    """Return ``(loss, dice, prediction_for_visualisation)`` for one batch.

    A list of outputs (deep supervision) is scored by averaging loss and Dice
    over the list; its last entry is the one used for visualisation.
    """
    if isinstance(outputs, list):
        loss = sum(criterion(output, masks) for output in outputs) / len(outputs)
        dice = sum(dice_coefficient(output, masks) for output in outputs) / len(outputs)
        return loss, dice, outputs[-1]
    return criterion(outputs, masks), dice_coefficient(outputs, masks), outputs


def _evaluate_split(model, loader, criterion, device, vis_every, vis_path):
    """Return ``(mean_loss, mean_dice)`` of ``model`` over ``loader``.

    Runs in eval mode without gradient tracking. Every ``vis_every``-th batch
    has its first sample written to ``vis_path(batch_index)``.
    """
    model.eval()
    total_loss = 0.0
    total_dice = 0.0
    with torch.no_grad():
        for i, (images, masks, questions, _) in enumerate(loader):
            images, masks = images.to(device), masks.to(device)
            outputs = model(images, questions)
            loss, dice, prediction = _batch_metrics(outputs, masks, criterion)
            total_loss += loss.item()
            total_dice += dice.item()

            if i % vis_every == 0:
                visualize_results(
                    images[0].cpu(),
                    masks[0].cpu(),
                    prediction[0].detach().cpu(),
                    questions[0],
                    vis_path(i)
                )
    return total_loss / len(loader), total_dice / len(loader)


def train_and_evaluate(model, train_loader, val_loader, test_loader, criterion, optimizer, num_epochs, device, output_folder):
    """Train ``model``, validate after every epoch, then test the best weights.

    Sample figures are written to the ``train``, ``validation`` and ``test``
    sub-folders of ``output_folder``. The weights with the highest validation
    Dice are stored as ``best_model.pth`` and reloaded before testing.

    Args:
        model: Network called as ``model(images, questions)``.
        train_loader, val_loader, test_loader: Loaders yielding
            ``(images, masks, questions, labels)`` batches.
        criterion: Loss called as ``criterion(output, masks)``.
        optimizer: Optimizer updating ``model``.
        num_epochs: Number of training epochs.
        device: Device that images and masks are moved to.
        output_folder: Root folder for figures and the checkpoint.
    """
    train_folder = os.path.join(output_folder, "train")
    val_folder = os.path.join(output_folder, "validation")
    test_folder = os.path.join(output_folder, "test")
    os.makedirs(train_folder, exist_ok=True)
    os.makedirs(val_folder, exist_ok=True)
    os.makedirs(test_folder, exist_ok=True)

    best_model_path = os.path.join(output_folder, "best_model.pth")
    checkpoint_saved = False

    # 'max' mode: the learning rate is halved when validation Dice stops rising.
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'max', patience=5, factor=0.5, verbose=True)
    best_val_dice = 0
    for epoch in range(num_epochs):
        model.train()
        total_loss = 0
        total_dice = 0
        for i, (images, masks, questions, _) in enumerate(train_loader):
            images, masks = images.to(device), masks.to(device)

            optimizer.zero_grad()
            outputs = model(images, questions)
            loss, dice, output_for_vis = _batch_metrics(outputs, masks, criterion)

            loss.backward()
            optimizer.step()

            total_loss += loss.item()
            total_dice += dice.item()

            if i % 10 == 0:
                visualize_results(
                    images[0].cpu(),  # Single image
                    masks[0].cpu(),   # Single mask
                    output_for_vis[0].detach().cpu(),  # Single prediction
                    questions[0],
                    os.path.join(train_folder, f"epoch_{epoch+1}_batch_{i}.png")
                )

        avg_loss = total_loss / len(train_loader)
        avg_dice = total_dice / len(train_loader)
        print(f"Epoch [{epoch+1}/{num_epochs}], Loss: {avg_loss:.4f}, Dice: {avg_dice:.4f}")

        avg_val_loss, avg_val_dice = _evaluate_split(
            model, val_loader, criterion, device, 5,
            lambda i, epoch=epoch: os.path.join(val_folder, f"epoch_{epoch+1}_batch_{i}.png"),
        )
        print(f"Validation Loss: {avg_val_loss:.4f}, Validation Dice: {avg_val_dice:.4f}")

        scheduler.step(avg_val_dice)

        if avg_val_dice > best_val_dice:
            best_val_dice = avg_val_dice
            torch.save(model.state_dict(), best_model_path)
            checkpoint_saved = True

    # Only reload a checkpoint written by this run; otherwise a missing file
    # would crash here and a stale file from an earlier run would be tested.
    if checkpoint_saved:
        model.load_state_dict(torch.load(best_model_path, map_location=device))
    else:
        print("No checkpoint was saved during this run; testing the current weights.")

    # Test set evaluation: every batch gets a figure.
    avg_test_loss, avg_test_dice = _evaluate_split(
        model, test_loader, criterion, device, 1,
        lambda i: os.path.join(test_folder, f"test_sample_{i}.png"),
    )
    print(f"Test Loss: {avg_test_loss:.4f}, Test Dice: {avg_test_dice:.4f}")

def main(csv_file):
    """Build the data pipeline and model, then run training and evaluation.

    Args:
        csv_file: Path of the CSV index consumed by ``CataractDataset``.
    """
    output_folder = "segmentation_results"
    os.makedirs(output_folder, exist_ok=True)

    # Separate transforms for images and masks
    image_transform = transforms.Compose([
        transforms.Resize((256, 256)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    ])

    # Masks hold labels, not intensities: nearest-neighbour resizing keeps the
    # stored mask values instead of blending them along region borders.
    mask_transform = transforms.Compose([
        transforms.Resize((256, 256), interpolation=transforms.InterpolationMode.NEAREST),
        transforms.ToTensor()
    ])

    # Three views of the same CSV; only the training view is augmented.
    train_dataset = CataractDataset(csv_file, image_transform=image_transform, mask_transform=mask_transform, augment=True)
    val_dataset = CataractDataset(csv_file, image_transform=image_transform, mask_transform=mask_transform, augment=False)
    test_dataset = CataractDataset(csv_file, image_transform=image_transform, mask_transform=mask_transform, augment=False)

    # 80/20 train+val / test split, then 80/20 train / val (fixed seed).
    train_indices, test_indices = train_test_split(range(len(train_dataset)), test_size=0.2, random_state=42)
    train_indices, val_indices = train_test_split(train_indices, test_size=0.2, random_state=42)

    train_data = Subset(train_dataset, train_indices)
    val_data = Subset(val_dataset, val_indices)
    test_data = Subset(test_dataset, test_indices)

    batch_size = 16
    train_loader = DataLoader(train_data, batch_size=batch_size, shuffle=True, num_workers=8)
    test_loader = DataLoader(test_data, batch_size=batch_size, shuffle=False, num_workers=8)
    val_loader = DataLoader(val_data, batch_size=batch_size, shuffle=False, num_workers=8)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = LLMSupervisedSAM(num_classes=1, input_channels=3, deep_supervision=False)
    if torch.cuda.device_count() > 1:
        print(f"Using {torch.cuda.device_count()} GPUs!")
        model = nn.DataParallel(model)
    model = model.to(device)

    criterion = nn.BCELoss()
    optimizer = optim.AdamW(model.parameters(), lr=0.001, weight_decay=1e-4)

    num_epochs = 100
    train_and_evaluate(model, train_loader, val_loader, test_loader, criterion, optimizer, num_epochs, device, output_folder)

    print(f"Training and evaluation completed. Results saved in the '{output_folder}' folder.")

# Script entry point: run the full pipeline on the dataset index at the
# relative path below.
if __name__ == "__main__":
    csv_file = "../retinal_segmentation/segmentation/final_data_for_segmentation/final_dataset.csv"
    main(csv_file)

## Tried increasing the model complexity here

In [1]:
import os
import warnings
import sys
from contextlib import contextmanager
from tqdm import tqdm

os.environ["TOKENIZERS_PARALLELISM"] = "false"

import csv
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader, Subset
from torchvision import transforms
from PIL import Image
import numpy as np
from sklearn.model_selection import train_test_split
from transformers import AutoTokenizer, AutoModel
import random
import torchvision.transforms.functional as TF
import matplotlib.pyplot as plt
from torch.cuda.amp import autocast, GradScaler
from torch.optim.lr_scheduler import CosineAnnealingWarmRestarts

# Import the UNet++ model
from unetplus import NestedUNet

# Suppress all warnings
warnings.filterwarnings("ignore")

@contextmanager
def suppress_output():
    """Send everything written to ``sys.stdout`` to the null device.

    Only ``sys.stdout`` is replaced; the previous stream is put back when the
    block exits, including when it exits through an exception.
    """
    previous_stdout = sys.stdout
    sink = open(os.devnull, "w")
    try:
        sys.stdout = sink
        yield
    finally:
        sys.stdout = previous_stdout
        sink.close()

class CataractDataset(Dataset):
    """Image / mask / question dataset backed by a CSV index file.

    Every CSV row must provide the columns ``Image_Paths``, ``Mask_Paths``,
    ``Questions`` and ``Labels``. Rows are kept in memory as dicts; image and
    mask files are only opened when a sample is requested.

    Args:
        csv_file: Path of the CSV index.
        image_transform: Optional callable applied to the RGB PIL image.
        mask_transform: Optional callable applied to the single-channel mask.
        augment: When true, paired random augmentation runs before the
            transforms.
    """

    def __init__(self, csv_file, image_transform=None, mask_transform=None, augment=False):
        self.data = []
        self.questions = set()
        # newline='' is what the csv module asks for: the reader then does its
        # own line splitting, so quoted fields containing line breaks are
        # returned unchanged.
        with open(csv_file, 'r', newline='') as f:
            for row in csv.DictReader(f):
                self.data.append(row)
                self.questions.add(row['Questions'])
        self.image_transform = image_transform
        self.mask_transform = mask_transform
        self.augment = augment

    def __len__(self):
        """Return the number of CSV rows."""
        return len(self.data)

    def __getitem__(self, idx):
        """Load one sample.

        Returns:
            Tuple ``(image, mask, question, label)``. ``question`` and
            ``label`` are the raw CSV strings.
        """
        item = self.data[idx]
        image = Image.open(item['Image_Paths']).convert('RGB')
        mask = Image.open(item['Mask_Paths']).convert('L')
        question = item['Questions']
        label = item['Labels']

        # Augment on the PIL images so image and mask receive the same geometry.
        if self.augment:
            image, mask = self.apply_augmentation(image, mask)

        if self.image_transform:
            image = self.image_transform(image)
        if self.mask_transform:
            mask = self.mask_transform(mask)

        return image, mask, question, label

    def apply_augmentation(self, image, mask):
        """Apply the same random flips and small rotation to image and mask."""
        if random.random() > 0.5:
            image = TF.hflip(image)
            mask = TF.hflip(mask)
        if random.random() > 0.5:
            image = TF.vflip(image)
            mask = TF.vflip(mask)

        # One angle is drawn and reused so the mask stays aligned with the image.
        angle = random.uniform(-10, 10)
        image = TF.rotate(image, angle)
        mask = TF.rotate(mask, angle)

        return image, mask

class ImageEncoder(nn.Module):
    """Image branch: a thin wrapper around ``NestedUNet`` (UNet++).

    The constructor arguments are forwarded unchanged to ``NestedUNet``;
    ``forward`` returns whatever the wrapped network returns.
    """

    def __init__(self, num_classes=1, input_channels=3, deep_supervision=False):
        super().__init__()
        # NOTE(review): NestedUNet is built with its own defaults for depth
        # and filter counts; no complexity setting is passed from here.
        unet_kwargs = {
            "num_classes": num_classes,
            "input_channels": input_channels,
            "deep_supervision": deep_supervision,
        }
        self.unetpp = NestedUNet(**unet_kwargs)
        self.deep_supervision = deep_supervision

    def forward(self, x):
        """Run the wrapped UNet++ on ``x``."""
        unet_output = self.unetpp(x)
        return unet_output

class LLMPromptEncoder(nn.Module):
    """Encode text prompts with DistilBERT and project them to 1024 features."""

    def __init__(self):
        super().__init__()
        checkpoint = "distilbert-base-uncased"
        self.tokenizer = AutoTokenizer.from_pretrained(checkpoint)
        self.model = AutoModel.from_pretrained(checkpoint)
        # Linear projection from the 768-wide hidden state to 1024 features.
        self.fc = nn.Linear(768, 1024)

    def forward(self, text):
        """Tokenise ``text`` and return one 1024-feature row per prompt."""
        encoded = self.tokenizer(
            text,
            return_tensors="pt",
            padding=True,
            truncation=True,
            max_length=512,
        )
        # Token tensors are created on CPU; move them next to the language model.
        model_inputs = {}
        for name, tensor in encoded.items():
            model_inputs[name] = tensor.to(self.model.device)
        hidden_states = self.model(**model_inputs).last_hidden_state
        # The hidden state at sequence position 0 is used as the pooled prompt.
        first_token = hidden_states[:, 0, :]
        return self.fc(first_token)

class MaskDecoder(nn.Module):
    """Convolutional head that maps fused features to one-channel mask logits.

    Three double-convolution stages narrow the channels to 256, 128 and 64 and
    a 1x1 convolution produces the output. The output is raw logits: the loss
    defined in this file is built on ``BCEWithLogitsLoss`` and the training
    loop applies ``torch.sigmoid`` before computing Dice, so a sigmoid inside
    the decoder would squash every prediction twice. Callers that need
    probabilities must apply ``torch.sigmoid`` to the result.

    Args:
        in_channels: Number of channels of the fused input.
    """

    def __init__(self, in_channels):
        super().__init__()
        self.decoder = nn.Sequential(
            self.upconv_block(in_channels, 256),
            self.upconv_block(256, 128),
            self.upconv_block(128, 64),
            # Final projection to logits. No Sigmoid follows; it had no
            # parameters, so existing state-dict keys are unaffected.
            nn.Conv2d(64, 1, kernel_size=1),
        )

    def upconv_block(self, in_channels, out_channels):
        """Build Conv-BN-ReLU, Dropout(0.3), Conv-BN-ReLU (spatial size kept)."""
        return nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
            nn.Dropout(0.3),
            nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True)
        )

    def forward(self, x):
        """Return mask logits for the fused feature map ``x``."""
        return self.decoder(x)

class LLMSupervisedSAM(nn.Module):
    """Prompt-conditioned segmentation model.

    The image branch output is concatenated channel-wise with the prompt
    embedding (tiled over every pixel) and decoded into a mask.

    Args:
        num_classes: Number of channels produced by the image branch.
        input_channels: Number of channels of the input image.
        deep_supervision: When true the image branch returns several outputs
            and ``forward`` returns one mask per output.
    """

    def __init__(self, num_classes=1, input_channels=3, deep_supervision=False):
        super().__init__()
        self.image_encoder = ImageEncoder(num_classes=num_classes, input_channels=input_channels, deep_supervision=deep_supervision)
        self.prompt_encoder = LLMPromptEncoder()
        # The decoder sees the image-branch channels plus the 1024 prompt features.
        self.mask_decoder = MaskDecoder(in_channels=num_classes + 1024)
        self.deep_supervision = deep_supervision

    def forward(self, image, questions):
        """Return a mask (or a list of masks under deep supervision).

        Args:
            image: Batched image tensor; dimension 0 is the batch.
            questions: Sequence of text prompts; only the first
                ``image.size(0)`` entries are used.
        """
        image_features = self.image_encoder(image)

        batch_size = image.size(0)

        # NOTE(review): this slice keeps the FIRST batch_size prompts. When the
        # model is wrapped in nn.DataParallel the image tensor is split across
        # replicas but the prompt sequence presumably is not, so replicas other
        # than the first would pair their images with the wrong prompts -
        # confirm, and consider tokenising outside the wrapped module.
        sliced_questions = questions[:batch_size]
        prompt_features = self.prompt_encoder(sliced_questions)
        prompt_map = prompt_features.view(batch_size, 1024, 1, 1)

        if self.deep_supervision:
            return [self._decode(feature, prompt_map) for feature in image_features]
        return self._decode(image_features, prompt_map)

    def _decode(self, feature, prompt_map):
        """Tile the prompt embedding over ``feature``'s grid and decode."""
        # Expanding a 1x1 map yields the same constant map as bilinear
        # upsampling of a single pixel, without the interpolation kernel.
        tiled_prompt = prompt_map.expand(-1, -1, feature.shape[2], feature.shape[3])
        combined_features = torch.cat([feature, tiled_prompt], dim=1)
        return self.mask_decoder(combined_features)

def dice_coefficient(pred, target):
    """Return the smoothed Dice coefficient of two same-sized tensors.

    Both tensors are flattened, so the score is computed over all elements at
    once (the whole batch when batched tensors are passed).

    Args:
        pred: Predicted mask values.
        target: Ground-truth mask values with the same number of elements.

    Returns:
        Scalar tensor ``(2 * sum(pred * target) + 1) / (sum(pred) + sum(target) + 1)``.
    """
    smooth = 1.0
    # reshape() instead of view(): view() raises on non-contiguous inputs
    # (e.g. transposed or strided slices), reshape() copies only when needed.
    pred_flat = pred.reshape(-1)
    target_flat = target.reshape(-1)
    intersection = (pred_flat * target_flat).sum()
    return (2. * intersection + smooth) / (pred_flat.sum() + target_flat.sum() + smooth)

def denormalize(tensor, mean, std):
    """Undo a per-channel normalisation in place and return ``tensor``.

    Each slice along dimension 0 is multiplied by its ``std`` entry and then
    shifted by its ``mean`` entry. The input tensor itself is modified; pass a
    clone to keep the original.
    """
    for channel, shift, scale in zip(tensor, mean, std):
        channel.mul_(scale)
        channel.add_(shift)
    return tensor

def visualize_results(image, mask, prediction, question, output_path):
    """Save a three-panel figure: image, ground-truth overlay, prediction overlay.

    The image is copied, de-normalised with the mean/std constants below and
    clipped to [0, 1] before plotting; masks are drawn on top with 50% alpha.
    """
    mean = [0.485, 0.456, 0.406]
    std = [0.229, 0.224, 0.225]
    image = denormalize(image.clone(), mean, std)
    # (C, H, W) -> (H, W, C), limited to the displayable range.
    rgb = image.permute(1, 2, 0).clip(0, 1)

    panels = (
        ("Original Image", None),
        ("Ground Truth", mask),
        ("Predicted Mask", prediction),
    )

    plt.figure(figsize=(15, 5))
    for position, (title, overlay) in enumerate(panels, start=1):
        plt.subplot(1, 3, position)
        plt.imshow(rgb)
        if overlay is not None:
            plt.imshow(overlay.squeeze(), alpha=0.5, cmap='jet')
        plt.title(title)
        plt.axis('off')

    plt.suptitle(f"Question: {question}")
    plt.tight_layout()
    plt.savefig(output_path)
    plt.close()

class CombinedLoss(nn.Module):
    """Weighted sum of BCE-with-logits and Dice loss.

    ``forward`` expects raw logits: the BCE term consumes them directly and
    the Dice term receives their sigmoid.

    Args:
        alpha: Weight of the BCE term.
        beta: Weight of the Dice term.
    """

    def __init__(self, alpha=0.7, beta=0.3):
        super().__init__()
        self.alpha, self.beta = alpha, beta
        self.bce_loss = nn.BCEWithLogitsLoss()
        self.dice_loss = DiceLoss()

    def forward(self, logits, targets):
        """Return ``alpha * BCE(logits) + beta * Dice(sigmoid(logits))``."""
        bce_term = self.bce_loss(logits, targets)
        probabilities = torch.sigmoid(logits)
        dice_term = self.dice_loss(probabilities, targets)
        return self.alpha * bce_term + self.beta * dice_term

class DiceLoss(nn.Module):
    """Soft Dice loss, ``1 - dice``, computed over all elements at once.

    Args:
        smooth: Constant added to numerator and denominator.
    """

    def __init__(self, smooth=1.0):
        super(DiceLoss, self).__init__()
        self.smooth = smooth

    def forward(self, logits, targets):
        """Return ``1 - (2 * sum(p * t) + smooth) / (sum(p) + sum(t) + smooth)``.

        Args:
            logits: Despite the name, no sigmoid is applied here; the value is
                used as given, so pass probabilities.
            targets: Ground-truth mask with the same number of elements.
        """
        # reshape() instead of view(): view() raises on non-contiguous inputs,
        # reshape() copies only when needed.
        logits = logits.reshape(-1)
        targets = targets.reshape(-1)
        intersection = (logits * targets).sum()
        return 1 - ((2. * intersection + self.smooth) / (logits.sum() + targets.sum() + self.smooth))

def train_and_evaluate(model, train_loader, val_loader, test_loader, criterion, optimizer, scheduler, num_epochs, device, output_folder):
    """Train with mixed precision, validate every epoch, then test the best weights.

    ``criterion`` receives the raw model outputs; ``torch.sigmoid`` is applied
    to them before Dice is computed and before they are visualised, in all
    three phases. Sample figures go to the ``train``, ``validation`` and
    ``test`` sub-folders of ``output_folder``; the weights with the highest
    validation Dice are stored as ``best_model.pth`` and reloaded for testing.

    Args:
        model: Network called as ``model(images, questions)``.
        train_loader, val_loader, test_loader: Loaders yielding
            ``(images, masks, questions, labels)`` batches.
        criterion: Loss called as ``criterion(output, masks)``.
        optimizer: Optimizer updating ``model``.
        scheduler: LR scheduler stepped once per epoch without arguments.
        num_epochs: Number of training epochs.
        device: Device that images and masks are moved to.
        output_folder: Root folder for figures and the checkpoint.
    """
    train_folder = os.path.join(output_folder, "train")
    val_folder = os.path.join(output_folder, "validation")
    test_folder = os.path.join(output_folder, "test")
    os.makedirs(train_folder, exist_ok=True)
    os.makedirs(val_folder, exist_ok=True)
    os.makedirs(test_folder, exist_ok=True)

    best_model_path = os.path.join(output_folder, "best_model.pth")
    checkpoint_saved = False

    scaler = GradScaler()
    best_val_dice = 0
    for epoch in range(num_epochs):
        model.train()
        total_loss = 0
        total_dice = 0
        for i, (images, masks, questions, _) in enumerate(tqdm(train_loader, desc=f"Training Epoch {epoch+1}/{num_epochs}", unit="batch")):
            images, masks = images.to(device), masks.to(device)

            optimizer.zero_grad()

            with autocast():
                outputs = model(images, questions)
                if isinstance(outputs, list):
                    # Deep supervision: average over all outputs.
                    loss = sum([criterion(output, masks) for output in outputs]) / len(outputs)
                    dice = sum([dice_coefficient(torch.sigmoid(output), masks) for output in outputs]) / len(outputs)
                    output_for_vis = outputs[-1]
                else:
                    loss = criterion(outputs, masks)
                    dice = dice_coefficient(torch.sigmoid(outputs), masks)
                    output_for_vis = outputs

            # Scaled backward/step pair required by mixed-precision training.
            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()

            total_loss += loss.item()
            total_dice += dice.item()

            # Figures for the first, middle and last batch of the epoch.
            if i in [0, len(train_loader) // 2, len(train_loader) - 1]:
                visualize_results(
                    images[0].cpu(),
                    masks[0].cpu(),
                    # .float(): outputs produced under autocast may be half precision.
                    torch.sigmoid(output_for_vis[0]).detach().float().cpu(),
                    questions[0],
                    os.path.join(train_folder, f"epoch_{epoch+1}_batch_{i}.png")
                )

        avg_loss = total_loss / len(train_loader)
        avg_dice = total_dice / len(train_loader)
        print(f"Epoch [{epoch+1}/{num_epochs}], Loss: {avg_loss:.4f}, Dice: {avg_dice:.4f}")

        model.eval()
        val_loss = 0
        val_dice = 0
        with torch.no_grad():
            for i, (images, masks, questions, _) in enumerate(tqdm(val_loader, desc="Validation", unit="batch")):
                images, masks = images.to(device), masks.to(device)
                outputs = model(images, questions)

                if isinstance(outputs, list):
                    val_loss += sum([criterion(output, masks).item() for output in outputs]) / len(outputs)
                    val_dice += sum([dice_coefficient(torch.sigmoid(output), masks).item() for output in outputs]) / len(outputs)
                    output_for_vis = outputs[-1]
                else:
                    val_loss += criterion(outputs, masks).item()
                    val_dice += dice_coefficient(torch.sigmoid(outputs), masks).item()
                    output_for_vis = outputs

                if i in [0, len(val_loader) // 2, len(val_loader) - 1]:
                    visualize_results(
                        images[0].cpu(),
                        masks[0].cpu(),
                        torch.sigmoid(output_for_vis[0]).detach().cpu(),
                        questions[0],
                        os.path.join(val_folder, f"epoch_{epoch+1}_batch_{i}.png")
                    )

        avg_val_loss = val_loss / len(val_loader)
        avg_val_dice = val_dice / len(val_loader)
        print(f"Validation Loss: {avg_val_loss:.4f}, Validation Dice: {avg_val_dice:.4f}")

        scheduler.step()

        if avg_val_dice > best_val_dice:
            best_val_dice = avg_val_dice
            torch.save(model.state_dict(), best_model_path)
            checkpoint_saved = True

    # Only reload a checkpoint written by this run; otherwise a missing file
    # would crash here and a stale file from an earlier run would be tested.
    if checkpoint_saved:
        model.load_state_dict(torch.load(best_model_path, map_location=device))
    else:
        print("No checkpoint was saved during this run; testing the current weights.")

    # Test set evaluation. Mirrors the validation phase: no gradient tracking,
    # and sigmoid applied before Dice and visualisation so the reported test
    # Dice is comparable with the training and validation Dice.
    model.eval()
    test_loss = 0
    test_dice = 0
    with torch.no_grad():
        for i, (images, masks, questions, _) in enumerate(tqdm(test_loader, desc="Testing", unit="batch")):
            images, masks = images.to(device), masks.to(device)
            outputs = model(images, questions)

            if isinstance(outputs, list):
                test_loss += sum([criterion(output, masks).item() for output in outputs]) / len(outputs)
                test_dice += sum([dice_coefficient(torch.sigmoid(output), masks).item() for output in outputs]) / len(outputs)
                output_for_vis = outputs[-1]
            else:
                test_loss += criterion(outputs, masks).item()
                test_dice += dice_coefficient(torch.sigmoid(outputs), masks).item()
                output_for_vis = outputs

            visualize_results(
                images[0].cpu(),
                masks[0].cpu(),
                torch.sigmoid(output_for_vis[0]).detach().cpu(),
                questions[0],
                os.path.join(test_folder, f"test_sample_{i}.png")
            )

    avg_test_loss = test_loss / len(test_loader)
    avg_test_dice = test_dice / len(test_loader)
    print(f"Test Loss: {avg_test_loss:.4f}, Test Dice: {avg_test_dice:.4f}")

def main(csv_file):
    """Build the data pipeline and model, then run training and evaluation.

    Args:
        csv_file: Path of the CSV index consumed by ``CataractDataset``.
    """
    output_folder = "segmentation_results"
    os.makedirs(output_folder, exist_ok=True)

    # Separate transforms for images and masks
    image_transform = transforms.Compose([
        transforms.Resize((256, 256)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    ])

    # Masks hold labels, not intensities: nearest-neighbour resizing keeps the
    # stored mask values instead of blending them along region borders.
    mask_transform = transforms.Compose([
        transforms.Resize((256, 256), interpolation=transforms.InterpolationMode.NEAREST),
        transforms.ToTensor()
    ])

    # Three views of the same CSV; only the training view is augmented.
    train_dataset = CataractDataset(csv_file, image_transform=image_transform, mask_transform=mask_transform, augment=True)
    val_dataset = CataractDataset(csv_file, image_transform=image_transform, mask_transform=mask_transform, augment=False)
    test_dataset = CataractDataset(csv_file, image_transform=image_transform, mask_transform=mask_transform, augment=False)

    # 80/20 train+val / test split, then 80/20 train / val (fixed seed).
    train_indices, test_indices = train_test_split(range(len(train_dataset)), test_size=0.2, random_state=42)
    train_indices, val_indices = train_test_split(train_indices, test_size=0.2, random_state=42)

    train_data = Subset(train_dataset, train_indices)
    val_data = Subset(val_dataset, val_indices)
    test_data = Subset(test_dataset, test_indices)

    batch_size = 16
    train_loader = DataLoader(train_data, batch_size=batch_size, shuffle=True, num_workers=8, pin_memory=True, prefetch_factor=2)
    test_loader = DataLoader(test_data, batch_size=batch_size, shuffle=False, num_workers=8, pin_memory=True, prefetch_factor=2)
    val_loader = DataLoader(val_data, batch_size=batch_size, shuffle=False, num_workers=8, pin_memory=True, prefetch_factor=2)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = LLMSupervisedSAM(num_classes=1, input_channels=3, deep_supervision=True)
    if torch.cuda.device_count() > 1:
        print(f"Using {torch.cuda.device_count()} GPUs!")
        model = nn.DataParallel(model)
    model = model.to(device)

    criterion = CombinedLoss(alpha=0.7, beta=0.3)
    optimizer = optim.AdamW(model.parameters(), lr=0.001, weight_decay=1e-4)
    # Cosine schedule with warm restarts; the first cycle lasts 10 epochs and
    # each following cycle is twice as long.
    scheduler = CosineAnnealingWarmRestarts(optimizer, T_0=10, T_mult=2, eta_min=1e-6)

    num_epochs = 50  # Reduced number of epochs
    train_and_evaluate(model, train_loader, val_loader, test_loader, criterion, optimizer, scheduler, num_epochs, device, output_folder)

    print(f"Training and evaluation completed. Results saved in the '{output_folder}' folder.")

# Script entry point: run the full pipeline on the dataset index at the
# relative path below.
if __name__ == "__main__":
    csv_file = "../retinal_segmentation/segmentation/final_data_for_segmentation/final_dataset.csv"
    main(csv_file)

Using 6 GPUs!


Training Epoch 1/50:  44%|████▍     | 133/304 [13:05<16:49,  5.91s/batch]


KeyboardInterrupt: 

## Added early stopping, loss functions, data augmentation and a learning rate scheduler

In [2]:
import os
import warnings
import sys
from contextlib import contextmanager
from tqdm import tqdm

os.environ["TOKENIZERS_PARALLELISM"] = "false"

import csv
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader, Subset
from torchvision import transforms
from PIL import Image
import numpy as np
from sklearn.model_selection import train_test_split
from transformers import AutoTokenizer, AutoModel
import random
import torchvision.transforms.functional as TF
import matplotlib.pyplot as plt
from torch.cuda.amp import autocast, GradScaler
from torch.optim.lr_scheduler import ReduceLROnPlateau
from torchvision.transforms import RandomAffine, ColorJitter

# Import the NestedUNet model
from unetplus import NestedUNet

# Suppress all warnings
warnings.filterwarnings("ignore")

@contextmanager
def suppress_output():
    """Send everything written to ``sys.stdout`` to the null device.

    Only ``sys.stdout`` is replaced; the previous stream is put back when the
    block exits, including when it exits through an exception.
    """
    previous_stdout = sys.stdout
    sink = open(os.devnull, "w")
    try:
        sys.stdout = sink
        yield
    finally:
        sys.stdout = previous_stdout
        sink.close()

class CataractDataset(Dataset):
    """Image / mask / question dataset backed by a CSV manifest.

    Every CSV row must provide the columns ``Image_Paths``, ``Mask_Paths``,
    ``Questions`` and ``Labels``. Rows are kept in memory as dicts; image
    files are only opened when an item is requested.

    Args:
        csv_file: Path to the CSV manifest.
        image_transform: Optional callable applied to the RGB PIL image.
        mask_transform: Optional callable applied to the grayscale PIL mask.
        augment: When true, random paired augmentation is applied to the
            image and mask before the transforms.
    """

    def __init__(self, csv_file, image_transform=None, mask_transform=None, augment=False):
        self.data = []
        self.questions = set()  # distinct question strings seen in the CSV
        # newline='' is what the csv module expects for file objects; without
        # it, line endings embedded in quoted fields are silently rewritten.
        with open(csv_file, 'r', newline='') as f:
            for row in csv.DictReader(f):
                self.data.append(row)
                self.questions.add(row['Questions'])
        self.image_transform = image_transform
        self.mask_transform = mask_transform
        self.augment = augment

    def __len__(self):
        """Return the number of CSV rows."""
        return len(self.data)

    def __getitem__(self, idx):
        """Load one sample.

        Returns:
            ``(image, mask, question, label)`` where ``question`` and
            ``label`` are the raw CSV strings and ``image`` / ``mask`` are
            whatever the configured transforms produce (PIL images when no
            transform is set).
        """
        item = self.data[idx]
        image = Image.open(item['Image_Paths']).convert('RGB')
        mask = Image.open(item['Mask_Paths']).convert('L')
        question = item['Questions']
        label = item['Labels']

        # Augment first, while both are still PIL images, so the geometric
        # operations hit image and mask identically.
        if self.augment:
            image, mask = self.apply_augmentation(image, mask)

        if self.image_transform:
            image = self.image_transform(image)
        if self.mask_transform:
            mask = self.mask_transform(mask)

        return image, mask, question, label

    def apply_augmentation(self, image, mask):
        """Apply the same random geometric changes to image and mask.

        Horizontal and vertical flips are each applied with probability 0.5,
        followed by a rotation drawn from [-15, 15] degrees and a random
        affine transform. Color jitter is applied to the image only.
        """
        if random.random() > 0.5:
            image = TF.hflip(image)
            mask = TF.hflip(mask)
        if random.random() > 0.5:
            image = TF.vflip(image)
            mask = TF.vflip(mask)

        angle = random.uniform(-15, 15)
        image = TF.rotate(image, angle)
        mask = TF.rotate(mask, angle)

        # Sample the affine parameters once and reuse them for both inputs.
        affine_params = RandomAffine.get_params(degrees=(-10, 10), translate=(0.1, 0.1), scale_ranges=(0.9, 1.1), shears=(-5, 5), img_size=image.size)
        image = TF.affine(image, *affine_params)
        mask = TF.affine(mask, *affine_params)

        # Photometric change only: the mask must not be color-jittered.
        color_jitter = ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.1)
        image = color_jitter(image)

        return image, mask

class LLMSupervisedSAM(nn.Module):
    """NestedUNet feature extractor whose output is conditioned on a prompt vector.

    The encoder emits 256-channel feature maps; a 768-dim prompt embedding is
    projected to 256 dims and added to every spatial position before batch
    norm, dropout and a 1x1 convolution produce ``num_classes`` logit maps.
    """

    def __init__(self, num_classes=1, input_channels=3, deep_supervision=False):
        super().__init__()
        self.image_encoder = NestedUNet(num_classes=256, input_channels=input_channels, deep_supervision=deep_supervision)
        self.prompt_encoder = nn.Linear(768, 256)
        self.final_conv = nn.Conv2d(256, num_classes, kernel_size=1)
        self.deep_supervision = deep_supervision

        # Regularisation applied to the fused image + prompt features.
        self.bn = nn.BatchNorm2d(256)
        self.dropout = nn.Dropout(0.3)

    def _segment(self, feature, prompt_map):
        """Fuse one feature map with the prompt and run the shared head."""
        fused = feature + prompt_map.expand_as(feature)
        fused = self.bn(fused)
        fused = self.dropout(fused)
        return self.final_conv(fused)

    def forward(self, image, prompt_embedding):
        """Return logits: a list (one per encoder output) under deep supervision, else one tensor."""
        encoded = self.image_encoder(image)
        # Two trailing singleton axes let the prompt broadcast over H and W.
        prompt_map = self.prompt_encoder(prompt_embedding).unsqueeze(2).unsqueeze(3)

        if not self.deep_supervision:
            return self._segment(encoded, prompt_map)
        # The same bn / dropout / final_conv are shared by every level.
        return [self._segment(level, prompt_map) for level in encoded]
        
class FocalLoss(nn.Module):
    """Binary focal loss on raw logits, averaged over all elements.

    Args:
        alpha: Constant scale applied to every element's loss.
        gamma: Exponent of the ``(1 - pt)`` modulating factor.
    """

    def __init__(self, alpha=0.8, gamma=2):
        super().__init__()
        self.alpha = alpha
        self.gamma = gamma

    def forward(self, inputs, targets):
        """Return the mean focal loss for logits ``inputs`` against ``targets``."""
        per_element_bce = F.binary_cross_entropy_with_logits(inputs, targets, reduction='none')
        # exp(-BCE) recovers the probability the model gave to the true class.
        prob_true = torch.exp(-per_element_bce)
        weight = self.alpha * (1 - prob_true) ** self.gamma
        return (weight * per_element_bce).mean()

class CombinedLoss(nn.Module):
    """Weighted sum of focal loss and Dice loss.

    Args:
        alpha: Weight of the focal term.
        beta: Weight of the Dice term.
    """

    def __init__(self, alpha=0.5, beta=0.5):
        super().__init__()
        self.alpha = alpha
        self.beta = beta
        self.focal_loss = FocalLoss()
        self.dice_loss = DiceLoss()

    def forward(self, logits, targets):
        """Return ``alpha * focal(logits) + beta * dice(sigmoid(logits))``."""
        # The focal term consumes logits; the Dice term is given probabilities.
        probabilities = torch.sigmoid(logits)
        focal_term = self.alpha * self.focal_loss(logits, targets)
        dice_term = self.beta * self.dice_loss(probabilities, targets)
        return focal_term + dice_term

class DiceLoss(nn.Module):
    """Soft Dice loss (``1 - Dice``) computed over the whole flattened batch.

    No activation is applied here: despite its name, the ``logits`` argument
    is multiplied with the targets as-is.

    Args:
        smooth: Additive smoothing for numerator and denominator.
    """

    def __init__(self, smooth=1.0):
        super().__init__()
        self.smooth = smooth

    def forward(self, logits, targets):
        """Return one minus the smoothed Dice score of ``logits`` vs ``targets``."""
        flat_pred = logits.view(-1)
        flat_true = targets.view(-1)
        overlap = (flat_pred * flat_true).sum()
        numerator = 2. * overlap + self.smooth
        denominator = flat_pred.sum() + flat_true.sum() + self.smooth
        return 1 - (numerator / denominator)

def train_and_evaluate(model, train_loader, val_loader, test_loader, criterion, optimizer, scheduler, num_epochs, device, output_folder):
    """Train with mixed precision, validate every epoch, then test the best checkpoint.

    Creates ``train``, ``validation`` and ``test`` sub-folders under
    ``output_folder`` for visualisations. After each epoch the model is
    validated, ``scheduler.step`` receives the validation loss, and the
    state dict is saved to ``best_model.pth`` whenever validation Dice
    improves. Training stops early after 20 epochs without improvement.
    Finally the best checkpoint written by this run is reloaded and
    evaluated on the test loader; if no checkpoint was written, the final
    weights are evaluated instead.

    Args:
        model: Module called as ``model(images, prompt_embeddings)``; may
            return a single tensor or a list of tensors.
        train_loader, val_loader, test_loader: Loaders yielding
            ``(images, masks, questions, labels)`` batches.
        criterion: Loss called as ``criterion(output, masks)``.
        optimizer: Optimizer over the model parameters.
        scheduler: Scheduler whose ``step`` accepts the validation loss.
        num_epochs: Maximum number of epochs.
        device: Device the batches are moved to.
        output_folder: Directory for the checkpoint and visualisations.
    """
    train_folder = os.path.join(output_folder, "train")
    val_folder = os.path.join(output_folder, "validation")
    test_folder = os.path.join(output_folder, "test")
    os.makedirs(train_folder, exist_ok=True)
    os.makedirs(val_folder, exist_ok=True)
    os.makedirs(test_folder, exist_ok=True)
    checkpoint_path = os.path.join(output_folder, "best_model.pth")

    scaler = GradScaler()
    best_val_dice = 0
    patience = 20  # epochs without validation-Dice improvement before stopping
    no_improve = 0
    # Tracks whether THIS run wrote the checkpoint, so a stale file from an
    # earlier run is never loaded by mistake.
    checkpoint_saved = False

    for epoch in range(num_epochs):
        model.train()
        total_loss = 0
        total_dice = 0
        for i, (images, masks, questions, _) in enumerate(tqdm(train_loader, desc=f"Training Epoch {epoch+1}/{num_epochs}", unit="batch")):
            images, masks = images.to(device), masks.to(device)
            batch_size = images.size(0)

            # NOTE(review): the prompt is random noise; the question text is
            # only used as a figure title below.
            prompt_embeddings = torch.randn(batch_size, 768).to(device)

            optimizer.zero_grad()

            with autocast():
                outputs = model(images, prompt_embeddings)
                if isinstance(outputs, list):
                    # Deep supervision: average loss and Dice over all outputs.
                    loss = sum([criterion(output, masks) for output in outputs]) / len(outputs)
                    dice = sum([dice_coefficient(torch.sigmoid(output), masks) for output in outputs]) / len(outputs)
                    output_for_vis = outputs[-1]
                else:
                    loss = criterion(outputs, masks)
                    dice = dice_coefficient(torch.sigmoid(outputs), masks)
                    output_for_vis = outputs

            scaler.scale(loss).backward()
            # Unscale before clipping so the norm is measured on real gradients.
            scaler.unscale_(optimizer)
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
            scaler.step(optimizer)
            scaler.update()

            total_loss += loss.item()
            total_dice += dice.item()

            if i % 50 == 0:
                visualize_results(
                    images[0].cpu(),
                    masks[0].cpu(),
                    torch.sigmoid(output_for_vis[0]).detach().cpu(),
                    questions[0],
                    os.path.join(train_folder, f"epoch_{epoch+1}_batch_{i}.png")
                )

        avg_loss = total_loss / len(train_loader)
        avg_dice = total_dice / len(train_loader)
        print(f"Epoch [{epoch+1}/{num_epochs}], Loss: {avg_loss:.4f}, Dice: {avg_dice:.4f}")

        val_loss, val_dice = evaluate(model, val_loader, criterion, device, val_folder, epoch)
        print(f"Validation Loss: {val_loss:.4f}, Validation Dice: {val_dice:.4f}")

        scheduler.step(val_loss)

        if val_dice > best_val_dice:
            best_val_dice = val_dice
            torch.save(model.state_dict(), checkpoint_path)
            checkpoint_saved = True
            no_improve = 0
        else:
            no_improve += 1

        if no_improve >= patience:
            print(f"Early stopping triggered after {epoch+1} epochs")
            break

    # Load best model and evaluate on test set. Validation Dice may never
    # have improved (e.g. NaN metrics or zero epochs); fall back to the
    # current weights in that case.
    if checkpoint_saved:
        model.load_state_dict(torch.load(checkpoint_path, map_location=device))
    else:
        print("No checkpoint was saved in this run; evaluating the final weights on the test set.")
    test_loss, test_dice = evaluate(model, test_loader, criterion, device, test_folder, "test")
    print(f"Test Loss: {test_loss:.4f}, Test Dice: {test_dice:.4f}")

# Evaluation pass used for both the validation and the test split.
def evaluate(model, data_loader, criterion, device, output_folder, epoch):
    """Run ``model`` over ``data_loader`` without gradients.

    Every tenth batch, a visualisation of the first sample is written to
    ``output_folder`` as ``epoch_{epoch}_batch_{i}.png``.

    Returns:
        ``(mean_loss, mean_dice)`` averaged over the number of batches.
    """
    model.eval()
    loss_sum = 0
    dice_sum = 0
    with torch.no_grad():
        progress = tqdm(data_loader, desc="Evaluation", unit="batch")
        for batch_idx, (images, masks, questions, _) in enumerate(progress):
            images = images.to(device)
            masks = masks.to(device)

            # Random prompt vector, one per sample (same as in training).
            prompt_embeddings = torch.randn(images.size(0), 768).to(device)

            outputs = model(images, prompt_embeddings)

            if not isinstance(outputs, list):
                loss = criterion(outputs, masks)
                dice = dice_coefficient(torch.sigmoid(outputs), masks)
                output_for_vis = outputs
            else:
                # Deep supervision: average the metrics over all outputs.
                loss = sum([criterion(output, masks) for output in outputs]) / len(outputs)
                dice = sum([dice_coefficient(torch.sigmoid(output), masks) for output in outputs]) / len(outputs)
                output_for_vis = outputs[-1]

            loss_sum += loss.item()
            dice_sum += dice.item()

            if batch_idx % 10 != 0:
                continue
            visualize_results(
                images[0].cpu(),
                masks[0].cpu(),
                torch.sigmoid(output_for_vis[0]).detach().cpu(),
                questions[0],
                os.path.join(output_folder, f"epoch_{epoch}_batch_{batch_idx}.png")
            )

    num_batches = len(data_loader)
    return loss_sum / num_batches, dice_sum / num_batches

def dice_coefficient(pred, target):
    """Return the smoothed Dice score of ``pred`` vs ``target`` over all elements.

    Both tensors are flattened; no activation or thresholding is applied.
    """
    smooth = 1.0
    flat_pred = pred.view(-1)
    flat_true = target.view(-1)
    overlap = (flat_pred * flat_true).sum()
    numerator = 2. * overlap + smooth
    denominator = flat_pred.sum() + flat_true.sum() + smooth
    return numerator / denominator

def visualize_results(image, mask, prediction, question, output_path):
    """Save a three-panel figure: image, image + ground truth, image + prediction.

    Args:
        image: Normalized image tensor, channels first.
        mask: Ground-truth mask, drawn as a half-transparent overlay.
        prediction: Predicted mask, drawn as a half-transparent overlay.
        question: Text used in the figure title.
        output_path: Destination file passed to ``plt.savefig``.
    """
    mean = [0.485, 0.456, 0.406]
    std = [0.229, 0.224, 0.225]
    # Clone first: denormalize() works in place and must not touch the caller's tensor.
    rgb = denormalize(image.clone(), mean, std).permute(1, 2, 0).clip(0, 1)

    plt.figure(figsize=(15, 5))

    # (panel title, optional overlay drawn on top of the image)
    panels = (
        ("Original Image", None),
        ("Ground Truth", mask),
        ("Predicted Mask", prediction),
    )
    for position, (title, overlay) in enumerate(panels, start=1):
        plt.subplot(1, 3, position)
        plt.imshow(rgb)
        if overlay is not None:
            plt.imshow(overlay.squeeze(), alpha=0.5, cmap='jet')
        plt.title(title)
        plt.axis('off')

    plt.suptitle(f"Question: {question}")
    plt.tight_layout()
    plt.savefig(output_path)
    plt.close()

def denormalize(tensor, mean, std):
    """Undo per-channel normalization in place and return the same tensor.

    Each slice along the first axis is multiplied by its ``std`` entry and
    then shifted by its ``mean`` entry.
    """
    for channel, offset, scale in zip(tensor, mean, std):
        channel.mul_(scale)
        channel.add_(offset)
    return tensor

def main(csv_file):
    """Build the data pipeline and model, then train and evaluate.

    The CSV rows are split 64% / 16% / 20% into train / validation / test
    with a fixed seed. Only the training split is augmented. Results go to
    the ``segmentation_results`` folder.
    """
    output_folder = "segmentation_results"
    os.makedirs(output_folder, exist_ok=True)

    target_size = (256, 256)
    image_transform = transforms.Compose([
        transforms.Resize(target_size),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])
    mask_transform = transforms.Compose([
        transforms.Resize(target_size),
        transforms.ToTensor(),
    ])

    def build_dataset(augment):
        # Same CSV and transforms every time; only the augmentation flag varies.
        return CataractDataset(csv_file, image_transform=image_transform, mask_transform=mask_transform, augment=augment)

    train_dataset = build_dataset(True)
    val_dataset = build_dataset(False)
    test_dataset = build_dataset(False)

    # Hold out 20% for testing, then 20% of the remainder for validation.
    all_indices = range(len(train_dataset))
    train_indices, test_indices = train_test_split(all_indices, test_size=0.2, random_state=42)
    train_indices, val_indices = train_test_split(train_indices, test_size=0.2, random_state=42)

    loader_options = {"batch_size": 32, "num_workers": 8, "pin_memory": True}
    train_loader = DataLoader(Subset(train_dataset, train_indices), shuffle=True, **loader_options)
    test_loader = DataLoader(Subset(test_dataset, test_indices), shuffle=False, **loader_options)
    val_loader = DataLoader(Subset(val_dataset, val_indices), shuffle=False, **loader_options)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    model = LLMSupervisedSAM(num_classes=1, input_channels=3, deep_supervision=True)
    if torch.cuda.device_count() > 1:
        print(f"Using {torch.cuda.device_count()} GPUs!")
        model = nn.DataParallel(model)
    model = model.to(device)

    # Dice-heavy loss, AdamW at a low learning rate, halve the LR on plateau.
    criterion = CombinedLoss(alpha=0.3, beta=0.7)
    optimizer = optim.AdamW(model.parameters(), lr=0.0001, weight_decay=1e-4)
    scheduler = ReduceLROnPlateau(optimizer, mode='min', factor=0.5, patience=5, verbose=True)

    # Upper bound only; early stopping normally ends training sooner.
    num_epochs = 500
    train_and_evaluate(model, train_loader, val_loader, test_loader, criterion, optimizer, scheduler, num_epochs, device, output_folder)

# Script entry point: run main() on the segmentation dataset CSV when this
# cell/module is executed directly.
if __name__ == "__main__":
    csv_file = "../retinal_segmentation/segmentation/final_data_for_segmentation/final_dataset.csv"
    main(csv_file)

Using 6 GPUs!


Training Epoch 1/500: 100%|██████████| 152/152 [00:39<00:00,  3.84batch/s]


Epoch [1/500], Loss: 0.5975, Dice: 0.2375


Evaluation: 100%|██████████| 38/38 [00:06<00:00,  6.26batch/s]


Validation Loss: 0.5864, Validation Dice: 0.2809


Training Epoch 2/500: 100%|██████████| 152/152 [00:39<00:00,  3.85batch/s]


Epoch [2/500], Loss: 0.5655, Dice: 0.2899


Evaluation: 100%|██████████| 38/38 [00:05<00:00,  6.45batch/s]


Validation Loss: 0.5634, Validation Dice: 0.3291


Training Epoch 3/500: 100%|██████████| 152/152 [00:39<00:00,  3.84batch/s]


Epoch [3/500], Loss: 0.5344, Dice: 0.3536


Evaluation: 100%|██████████| 38/38 [00:06<00:00,  6.16batch/s]


Validation Loss: 0.4884, Validation Dice: 0.3684


Training Epoch 4/500: 100%|██████████| 152/152 [00:39<00:00,  3.81batch/s]


Epoch [4/500], Loss: 0.4967, Dice: 0.4234


Evaluation: 100%|██████████| 38/38 [00:06<00:00,  6.10batch/s]


Validation Loss: 0.5076, Validation Dice: 0.4400


Training Epoch 5/500: 100%|██████████| 152/152 [00:39<00:00,  3.82batch/s]


Epoch [5/500], Loss: 0.4620, Dice: 0.4671


Evaluation: 100%|██████████| 38/38 [00:06<00:00,  6.18batch/s]


Validation Loss: 0.4464, Validation Dice: 0.4655


Training Epoch 6/500: 100%|██████████| 152/152 [00:39<00:00,  3.85batch/s]


Epoch [6/500], Loss: 0.4251, Dice: 0.5038


Evaluation: 100%|██████████| 38/38 [00:05<00:00,  6.63batch/s]


Validation Loss: 0.4694, Validation Dice: 0.5002


Training Epoch 7/500: 100%|██████████| 152/152 [00:40<00:00,  3.79batch/s]


Epoch [7/500], Loss: 0.3952, Dice: 0.5307


Evaluation: 100%|██████████| 38/38 [00:06<00:00,  6.31batch/s]


Validation Loss: 0.3933, Validation Dice: 0.5203


Training Epoch 8/500: 100%|██████████| 152/152 [00:40<00:00,  3.80batch/s]


Epoch [8/500], Loss: 0.3691, Dice: 0.5558


Evaluation: 100%|██████████| 38/38 [00:05<00:00,  6.54batch/s]


Validation Loss: 0.3693, Validation Dice: 0.5350


Training Epoch 9/500: 100%|██████████| 152/152 [00:40<00:00,  3.76batch/s]


Epoch [9/500], Loss: 0.3511, Dice: 0.5738


Evaluation: 100%|██████████| 38/38 [00:05<00:00,  6.66batch/s]


Validation Loss: 0.3676, Validation Dice: 0.5728


Training Epoch 10/500: 100%|██████████| 152/152 [00:39<00:00,  3.81batch/s]


Epoch [10/500], Loss: 0.3411, Dice: 0.5860


Evaluation: 100%|██████████| 38/38 [00:05<00:00,  6.54batch/s]


Validation Loss: 0.3619, Validation Dice: 0.5722


Training Epoch 11/500: 100%|██████████| 152/152 [00:39<00:00,  3.82batch/s]


Epoch [11/500], Loss: 0.3337, Dice: 0.5936


Evaluation: 100%|██████████| 38/38 [00:05<00:00,  6.48batch/s]


Validation Loss: 0.3457, Validation Dice: 0.5928


Training Epoch 12/500: 100%|██████████| 152/152 [00:39<00:00,  3.81batch/s]


Epoch [12/500], Loss: 0.3268, Dice: 0.6009


Evaluation: 100%|██████████| 38/38 [00:06<00:00,  6.17batch/s]


Validation Loss: 0.3331, Validation Dice: 0.5839


Training Epoch 13/500: 100%|██████████| 152/152 [00:39<00:00,  3.81batch/s]


Epoch [13/500], Loss: 0.3228, Dice: 0.6055


Evaluation: 100%|██████████| 38/38 [00:05<00:00,  6.53batch/s]


Validation Loss: 0.3341, Validation Dice: 0.5891


Training Epoch 14/500: 100%|██████████| 152/152 [00:41<00:00,  3.69batch/s]


Epoch [14/500], Loss: 0.3196, Dice: 0.6089


Evaluation: 100%|██████████| 38/38 [00:05<00:00,  6.74batch/s]


Validation Loss: 0.3270, Validation Dice: 0.6070


Training Epoch 15/500: 100%|██████████| 152/152 [00:40<00:00,  3.78batch/s]


Epoch [15/500], Loss: 0.3155, Dice: 0.6126


Evaluation: 100%|██████████| 38/38 [00:05<00:00,  6.55batch/s]


Validation Loss: 0.3213, Validation Dice: 0.5950


Training Epoch 16/500: 100%|██████████| 152/152 [00:40<00:00,  3.77batch/s]


Epoch [16/500], Loss: 0.3149, Dice: 0.6133


Evaluation: 100%|██████████| 38/38 [00:06<00:00,  5.68batch/s]


Validation Loss: 0.3235, Validation Dice: 0.6006


Training Epoch 17/500: 100%|██████████| 152/152 [00:40<00:00,  3.76batch/s]


Epoch [17/500], Loss: 0.3128, Dice: 0.6153


Evaluation: 100%|██████████| 38/38 [00:05<00:00,  6.49batch/s]


Validation Loss: 0.3248, Validation Dice: 0.6135


Training Epoch 18/500: 100%|██████████| 152/152 [00:40<00:00,  3.79batch/s]


Epoch [18/500], Loss: 0.3105, Dice: 0.6175


Evaluation: 100%|██████████| 38/38 [00:05<00:00,  6.47batch/s]


Validation Loss: 0.3213, Validation Dice: 0.6006


Training Epoch 19/500: 100%|██████████| 152/152 [00:40<00:00,  3.78batch/s]


Epoch [19/500], Loss: 0.3095, Dice: 0.6185


Evaluation: 100%|██████████| 38/38 [00:06<00:00,  5.64batch/s]


Validation Loss: 0.3215, Validation Dice: 0.5967


Training Epoch 20/500: 100%|██████████| 152/152 [00:40<00:00,  3.78batch/s]


Epoch [20/500], Loss: 0.3075, Dice: 0.6216


Evaluation: 100%|██████████| 38/38 [00:05<00:00,  6.36batch/s]


Validation Loss: 0.3194, Validation Dice: 0.5935


Training Epoch 21/500: 100%|██████████| 152/152 [00:40<00:00,  3.78batch/s]


Epoch [21/500], Loss: 0.3055, Dice: 0.6221


Evaluation: 100%|██████████| 38/38 [00:06<00:00,  6.27batch/s]


Validation Loss: 0.3307, Validation Dice: 0.5834


Training Epoch 22/500: 100%|██████████| 152/152 [00:40<00:00,  3.78batch/s]


Epoch [22/500], Loss: 0.3069, Dice: 0.6213


Evaluation: 100%|██████████| 38/38 [00:05<00:00,  6.58batch/s]


Validation Loss: 0.3299, Validation Dice: 0.6006


Training Epoch 23/500: 100%|██████████| 152/152 [00:39<00:00,  3.81batch/s]


Epoch [23/500], Loss: 0.3037, Dice: 0.6244


Evaluation: 100%|██████████| 38/38 [00:05<00:00,  6.38batch/s]


Validation Loss: 0.3248, Validation Dice: 0.6108


Training Epoch 24/500: 100%|██████████| 152/152 [00:40<00:00,  3.77batch/s]


Epoch [24/500], Loss: 0.3046, Dice: 0.6237


Evaluation: 100%|██████████| 38/38 [00:05<00:00,  6.73batch/s]


Validation Loss: 0.3277, Validation Dice: 0.6023


Training Epoch 25/500: 100%|██████████| 152/152 [00:39<00:00,  3.82batch/s]


Epoch [25/500], Loss: 0.3041, Dice: 0.6237


Evaluation: 100%|██████████| 38/38 [00:05<00:00,  6.75batch/s]


Validation Loss: 0.3272, Validation Dice: 0.5929


Training Epoch 26/500: 100%|██████████| 152/152 [00:39<00:00,  3.81batch/s]


Epoch [26/500], Loss: 0.3019, Dice: 0.6258


Evaluation: 100%|██████████| 38/38 [00:05<00:00,  6.70batch/s]


Validation Loss: 0.3338, Validation Dice: 0.5776


Training Epoch 27/500: 100%|██████████| 152/152 [00:39<00:00,  3.81batch/s]


Epoch [27/500], Loss: 0.2997, Dice: 0.6275


Evaluation: 100%|██████████| 38/38 [00:05<00:00,  6.57batch/s]


Validation Loss: 0.3240, Validation Dice: 0.5989


Training Epoch 28/500: 100%|██████████| 152/152 [00:40<00:00,  3.77batch/s]


Epoch [28/500], Loss: 0.2985, Dice: 0.6275


Evaluation: 100%|██████████| 38/38 [00:05<00:00,  6.68batch/s]


Validation Loss: 0.3174, Validation Dice: 0.6049


Training Epoch 29/500: 100%|██████████| 152/152 [00:39<00:00,  3.81batch/s]


Epoch [29/500], Loss: 0.2979, Dice: 0.6288


Evaluation: 100%|██████████| 38/38 [00:05<00:00,  6.40batch/s]


Validation Loss: 0.3212, Validation Dice: 0.6041


Training Epoch 30/500: 100%|██████████| 152/152 [00:40<00:00,  3.74batch/s]


Epoch [30/500], Loss: 0.2974, Dice: 0.6280


Evaluation: 100%|██████████| 38/38 [00:05<00:00,  6.62batch/s]


Validation Loss: 0.3189, Validation Dice: 0.6000


Training Epoch 31/500: 100%|██████████| 152/152 [00:40<00:00,  3.80batch/s]


Epoch [31/500], Loss: 0.2960, Dice: 0.6303


Evaluation: 100%|██████████| 38/38 [00:06<00:00,  5.46batch/s]


Validation Loss: 0.3203, Validation Dice: 0.6013


Training Epoch 32/500: 100%|██████████| 152/152 [00:40<00:00,  3.79batch/s]


Epoch [32/500], Loss: 0.2947, Dice: 0.6312


Evaluation: 100%|██████████| 38/38 [00:05<00:00,  6.36batch/s]


Validation Loss: 0.3194, Validation Dice: 0.6034


Training Epoch 33/500: 100%|██████████| 152/152 [00:40<00:00,  3.80batch/s]


Epoch [33/500], Loss: 0.2958, Dice: 0.6305


Evaluation: 100%|██████████| 38/38 [00:06<00:00,  6.26batch/s]


Validation Loss: 0.3201, Validation Dice: 0.6026


Training Epoch 34/500: 100%|██████████| 152/152 [00:40<00:00,  3.75batch/s]


Epoch [34/500], Loss: 0.2966, Dice: 0.6289


Evaluation: 100%|██████████| 38/38 [00:05<00:00,  6.61batch/s]


Validation Loss: 0.3282, Validation Dice: 0.5858


Training Epoch 35/500: 100%|██████████| 152/152 [00:40<00:00,  3.77batch/s]


Epoch [35/500], Loss: 0.2950, Dice: 0.6308


Evaluation: 100%|██████████| 38/38 [00:07<00:00,  5.16batch/s]


Validation Loss: 0.3231, Validation Dice: 0.5986


Training Epoch 36/500: 100%|██████████| 152/152 [00:39<00:00,  3.80batch/s]


Epoch [36/500], Loss: 0.2932, Dice: 0.6323


Evaluation: 100%|██████████| 38/38 [00:05<00:00,  6.34batch/s]


Validation Loss: 0.3246, Validation Dice: 0.5926


Training Epoch 37/500: 100%|██████████| 152/152 [00:40<00:00,  3.77batch/s]


Epoch [37/500], Loss: 0.2939, Dice: 0.6314


Evaluation: 100%|██████████| 38/38 [00:05<00:00,  6.37batch/s]


Validation Loss: 0.3185, Validation Dice: 0.6029
Early stopping triggered after 37 epochs


Evaluation: 100%|██████████| 48/48 [00:07<00:00,  6.82batch/s]

Test Loss: 0.3275, Test Dice: 0.6103





## channels modified 

Learning Rate Scheduling:
Implement a learning rate scheduler that reduces the learning rate when performance plateaus.
Data Augmentation:
Implement more aggressive data augmentation techniques to increase the diversity of your training data.
Initially, the image had 3 channels and the mask had 1 channel (since it was not in RGB form); modify the mask to have 3 channels.
Try using layer normalization instead of batch normalization.

In [None]:
import os
import warnings
from tqdm import tqdm

os.environ["TOKENIZERS_PARALLELISM"] = "false"

import csv
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader, Subset
from PIL import Image
import numpy as np
from sklearn.model_selection import train_test_split
from torch.cuda.amp import autocast, GradScaler
from torch.optim.lr_scheduler import ReduceLROnPlateau
import albumentations as A
from albumentations.pytorch import ToTensorV2

# Import the NestedUNet model
from unetplus import NestedUNet

class CataractDataset(Dataset):
    """Image / RGB-mask / question dataset backed by a CSV manifest.

    Every CSV row must provide the columns ``Image_Paths``, ``Mask_Paths``,
    ``Questions`` and ``Labels``.

    Args:
        csv_file: Path to the CSV manifest.
        transform: Optional callable invoked as
            ``transform(image=ndarray, mask=ndarray)`` that returns a mapping
            with ``'image'`` and ``'mask'`` entries. The returned mask is
            expected in (H, W, C) layout.
        augment: Stored on the instance but not read by this class.
    """

    def __init__(self, csv_file, transform=None, augment=False):
        self.data = []
        self.questions = set()  # distinct question strings seen in the CSV
        # newline='' is what the csv module expects for file objects; without
        # it, line endings embedded in quoted fields are silently rewritten.
        with open(csv_file, 'r', newline='') as f:
            for row in csv.DictReader(f):
                self.data.append(row)
                self.questions.add(row['Questions'])
        self.transform = transform
        self.augment = augment

    def __len__(self):
        """Return the number of CSV rows."""
        return len(self.data)

    def __getitem__(self, idx):
        """Load one sample.

        Returns:
            ``(image, mask, question, label)``. ``mask`` is a float tensor in
            (C, H, W) layout scaled to [0, 1]; ``question`` and ``label`` are
            the raw CSV strings.

        Raises:
            FileNotFoundError: If the image or mask path does not exist.
            IOError: If either file cannot be opened or decoded.
        """
        item = self.data[idx]

        # Check if the image and mask paths exist
        if not os.path.exists(item['Image_Paths']) or not os.path.exists(item['Mask_Paths']):
            raise FileNotFoundError(f"Image or mask not found: {item['Image_Paths']} or {item['Mask_Paths']}")

        try:
            image = Image.open(item['Image_Paths']).convert('RGB')
            mask = Image.open(item['Mask_Paths']).convert('RGB')
        except Exception as e:
            # Chain the cause so the original traceback is preserved.
            raise IOError(f"Error opening image or mask: {e}") from e

        question = item['Questions']
        label = item['Labels']

        image = np.array(image)
        mask = np.array(mask)
        if self.transform:
            transformed = self.transform(image=image, mask=mask)
            image = transformed['image']
            mask = transformed['mask']
        else:
            # Without a transform, still return a channels-first tensor
            # (unnormalized uint8) so the sample can be batched.
            image = torch.from_numpy(image).permute(2, 0, 1)

        # The mask may still be an ndarray (no transform, or a transform that
        # does not convert to tensors); as_tensor leaves tensors untouched.
        mask = torch.as_tensor(mask)

        # Convert mask to float and normalize to [0, 1]
        mask = mask.float() / 255.0

        # Change from (H, W, C) to (C, H, W)
        mask = mask.permute(2, 0, 1)

        return image, mask, question, label

def get_transforms(train=True):
    """Return the albumentations pipeline for training or evaluation.

    Both pipelines end with a 256x256 resize, normalization with the
    mean/std constants below, and conversion to tensors. The training
    pipeline puts random geometric and color augmentations in front of
    those steps.
    """
    steps = []
    if train:
        steps.extend([
            A.RandomRotate90(p=0.5),
            A.Flip(p=0.5),
            A.RandomBrightnessContrast(brightness_limit=0.2, contrast_limit=0.2, p=0.5),
            A.HueSaturationValue(hue_shift_limit=20, sat_shift_limit=30, val_shift_limit=20, p=0.5),
            A.ShiftScaleRotate(shift_limit=0.1, scale_limit=0.1, rotate_limit=45, p=0.5),
        ])
    # Shared tail: identical for training and evaluation.
    steps.extend([
        A.Resize(256, 256),
        A.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        ToTensorV2(),
    ])
    return A.Compose(steps)

class LLMSupervisedSAM(nn.Module):
    """NestedUNet feature extractor whose output is conditioned on a prompt vector.

    A 768-dim prompt embedding is projected to 256 dims and added to every
    spatial position of the 256-channel encoder features, followed by layer
    norm, dropout, a per-output 1x1 convolution and a sigmoid.

    NOTE(review): ``forward`` returns probabilities (sigmoid already applied),
    while the losses in this file call ``binary_cross_entropy_with_logits``
    and apply sigmoid again. Confirm which side should own the activation.
    """

    def __init__(self, num_classes=3, input_channels=3, deep_supervision=True):
        super().__init__()
        self.image_encoder = NestedUNet(num_classes=256, input_channels=input_channels, deep_supervision=deep_supervision)
        self.prompt_encoder = nn.Linear(768, 256)
        # One 1x1 head per supervised output; five heads are allocated up front.
        self.final_convs = nn.ModuleList([nn.Conv2d(256, num_classes, kernel_size=1) for _ in range(5)])
        self.deep_supervision = deep_supervision

        # Normalizes over the last three dims, so features must be 256 x 256 x 256.
        self.ln = nn.LayerNorm([256, 256, 256])
        self.dropout = nn.Dropout(0.3)

    @staticmethod
    def _fuse(feature, prompt_features):
        """Add the prompt vector to every spatial position of ``feature``."""
        return feature + prompt_features.unsqueeze(2).unsqueeze(3).expand_as(feature)

    @staticmethod
    def _final_feature(image_features):
        """Return the last feature map for a list/tuple, or the tensor itself.

        Indexing a plain tensor with ``[-1]`` would select the last batch
        sample, not the last encoder output, so the two cases are separated.
        """
        if isinstance(image_features, (list, tuple)):
            return image_features[-1]
        return image_features

    def _predict(self, feature, prompt_features, head_index):
        """Fuse, normalize, apply dropout, then head ``head_index`` and a sigmoid."""
        combined_features = self._fuse(feature, prompt_features)
        combined_features = self.ln(combined_features)
        combined_features = self.dropout(combined_features)
        output = self.final_convs[head_index](combined_features)
        return torch.sigmoid(output)  # Apply sigmoid to ensure output is in [0, 1] range

    def forward(self, image, prompt_embedding):
        """Return probabilities: a list under deep supervision, else one tensor."""
        image_features = self.image_encoder(image)
        prompt_features = self.prompt_encoder(prompt_embedding)

        if self.deep_supervision:
            return [
                self._predict(feature, prompt_features, i)
                for i, feature in enumerate(image_features)
            ]
        return self._predict(self._final_feature(image_features), prompt_features, -1)

    def get_attention_map(self, image, prompt_embedding):
        """Return softmax (over the last axis) of the channel-summed fused features.

        Mirrors ``forward``: a list of maps under deep supervision, otherwise
        a single map.
        """
        image_features = self.image_encoder(image)
        prompt_features = self.prompt_encoder(prompt_embedding)

        if self.deep_supervision:
            return [
                F.softmax(self._fuse(feature, prompt_features).sum(dim=1), dim=-1)
                for feature in image_features
            ]
        combined_features = self._fuse(self._final_feature(image_features), prompt_features)
        return F.softmax(combined_features.sum(dim=1), dim=-1)

class FocalLoss(nn.Module):
    """Binary focal loss on raw logits, averaged over all elements.

    Args:
        alpha: Constant scale applied to every element's loss.
        gamma: Exponent of the ``(1 - pt)`` modulating factor.
    """

    def __init__(self, alpha=0.8, gamma=2):
        super().__init__()
        self.alpha, self.gamma = alpha, gamma

    def forward(self, inputs, targets):
        """Return the mean focal loss for logits ``inputs`` against ``targets``."""
        elementwise_bce = F.binary_cross_entropy_with_logits(inputs, targets, reduction='none')
        # exp(-BCE) is the probability assigned to the correct class.
        pt = torch.exp(-elementwise_bce)
        modulation = self.alpha * (1 - pt) ** self.gamma
        weighted = modulation * elementwise_bce
        return weighted.mean()

class CombinedLoss(nn.Module):
    """Weighted sum of focal loss and Dice loss, both computed from logits.

    Args:
        alpha: Weight of the focal term.
        beta: Weight of the Dice term.
    """

    def __init__(self, alpha=0.5, beta=0.5):
        super(CombinedLoss, self).__init__()
        self.alpha = alpha
        self.beta = beta
        self.focal_loss = FocalLoss()
        self.dice_loss = DiceLoss()

    def forward(self, logits, targets):
        """Return ``alpha * focal(logits, targets) + beta * dice(logits, targets)``."""
        focal = self.focal_loss(logits, targets)
        # DiceLoss applies the sigmoid itself, so it gets the raw logits;
        # passing sigmoid(logits) would squash the predictions twice.
        dice = self.dice_loss(logits, targets)
        return self.alpha * focal + self.beta * dice

class DiceLoss(nn.Module):
    """Soft Dice loss (``1 - Dice``) on logits, over the whole flattened batch.

    Args:
        smooth: Additive smoothing for numerator and denominator.
    """

    def __init__(self, smooth=1.0):
        super(DiceLoss, self).__init__()
        self.smooth = smooth

    def forward(self, inputs, targets):
        """Apply a sigmoid to ``inputs`` and return one minus the smoothed Dice score."""
        inputs = torch.sigmoid(inputs)
        # reshape (not view) so non-contiguous tensors, such as permuted
        # masks, are flattened without raising.
        inputs = inputs.reshape(-1)
        targets = targets.reshape(-1)
        intersection = (inputs * targets).sum()
        dice = (2. * intersection + self.smooth) / (inputs.sum() + targets.sum() + self.smooth)
        return 1 - dice

def dice_coefficient(pred, target):
    """Return the smoothed Dice score of ``sigmoid(pred)`` vs ``target``.

    ``pred`` is treated as logits: the sigmoid is applied here, so callers
    must not apply it beforehand. Both tensors are flattened over all
    elements.
    """
    smooth = 1.0
    pred = torch.sigmoid(pred)
    # reshape (not view) so non-contiguous tensors, such as permuted masks,
    # are flattened without raising.
    pred = pred.reshape(-1)
    target = target.reshape(-1)
    intersection = (pred * target).sum()
    return (2. * intersection + smooth) / (pred.sum() + target.sum() + smooth)

def train_and_evaluate(model, train_loader, val_loader, test_loader, criterion, optimizer, scheduler, num_epochs, device, output_folder):
    """Train with mixed precision, validate every epoch, then test the best checkpoint.

    After each epoch the model is validated, ``scheduler.step`` receives the
    validation loss, and the state dict is saved to
    ``output_folder/best_model.pth`` whenever validation Dice improves.
    Training stops early after 20 epochs without improvement. Finally the
    best checkpoint written by this run is reloaded and evaluated on the
    test loader; if no checkpoint was written, the final weights are
    evaluated instead.

    Args:
        model: Module called as ``model(images, prompt_embeddings)``; may
            return a single tensor or a list of tensors.
        train_loader, val_loader, test_loader: Loaders yielding
            ``(images, masks, questions, labels)`` batches.
        criterion: Loss called as ``criterion(output, masks)``.
        optimizer: Optimizer over the model parameters.
        scheduler: Scheduler whose ``step`` accepts the validation loss.
        num_epochs: Maximum number of epochs.
        device: Device the batches are moved to.
        output_folder: Existing directory that receives the checkpoint.
    """
    checkpoint_path = os.path.join(output_folder, "best_model.pth")
    scaler = GradScaler()
    best_val_dice = 0
    patience = 20  # epochs without validation-Dice improvement before stopping
    no_improve = 0
    # Tracks whether THIS run wrote the checkpoint, so a stale file from an
    # earlier run is never loaded by mistake.
    checkpoint_saved = False

    for epoch in range(num_epochs):
        model.train()
        total_loss = 0
        total_dice = 0
        for i, (images, masks, questions, _) in enumerate(tqdm(train_loader, desc=f"Training Epoch {epoch+1}/{num_epochs}", unit="batch")):
            images, masks = images.to(device), masks.to(device)
            batch_size = images.size(0)

            # NOTE(review): the prompt is random noise; the question text is
            # not used by this loop.
            prompt_embeddings = torch.randn(batch_size, 768).to(device)

            optimizer.zero_grad()

            with autocast():
                outputs = model(images, prompt_embeddings)

                if isinstance(outputs, list):
                    # Deep supervision: average loss and Dice over all outputs.
                    loss = sum([criterion(output, masks) for output in outputs]) / len(outputs)
                    dice = sum([dice_coefficient(output, masks) for output in outputs]) / len(outputs)
                else:
                    loss = criterion(outputs, masks)
                    dice = dice_coefficient(outputs, masks)

            scaler.scale(loss).backward()
            # Unscale before clipping so the norm is measured on real gradients.
            scaler.unscale_(optimizer)
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
            scaler.step(optimizer)
            scaler.update()

            total_loss += loss.item()
            total_dice += dice.item()

        avg_loss = total_loss / len(train_loader)
        avg_dice = total_dice / len(train_loader)
        print(f"Epoch [{epoch+1}/{num_epochs}], Loss: {avg_loss:.4f}, Dice: {avg_dice:.4f}")

        # Validation step
        val_loss, val_dice = evaluate(model, val_loader, criterion, device)
        print(f"Validation Loss: {val_loss:.4f}, Validation Dice: {val_dice:.4f}")

        scheduler.step(val_loss)

        if val_dice > best_val_dice:
            best_val_dice = val_dice
            torch.save(model.state_dict(), checkpoint_path)
            checkpoint_saved = True
            no_improve = 0
        else:
            no_improve += 1

        if no_improve >= patience:
            print(f"Early stopping triggered after {epoch+1} epochs")
            break

    # Load the best model for final evaluation. Validation Dice may never
    # have improved (e.g. NaN metrics or zero epochs); fall back to the
    # current weights in that case.
    if checkpoint_saved:
        model.load_state_dict(torch.load(checkpoint_path, map_location=device))
    else:
        print("No checkpoint was saved in this run; evaluating the final weights on the test set.")
    test_loss, test_dice = evaluate(model, test_loader, criterion, device)
    print(f"Test Loss: {test_loss:.4f}, Test Dice: {test_dice:.4f}")

def evaluate(model, data_loader, criterion, device):
    """Return the mean loss and mean Dice score of ``model`` over ``data_loader``.

    The model is put in eval mode and run without gradient tracking. Loss and
    Dice are computed once per batch and then averaged over the number of
    batches (not over the number of samples). When the model returns a list
    of outputs, the per-output losses and Dice scores are averaged within the
    batch; Dice is computed on ``torch.sigmoid`` of each output.

    Args:
        model: Module called as ``model(images, prompt_embeddings)``; returns
            either a single tensor or a list of tensors.
        data_loader: Sized iterable yielding 4-tuples whose first two items
            are the image batch and the mask batch; the last two are ignored.
        criterion: Callable ``criterion(output, masks)`` returning a scalar
            tensor.
        device: Device the batches and prompt embeddings are moved to.

    Returns:
        A tuple ``(mean_loss, mean_dice)``.

    Raises:
        ValueError: If ``data_loader`` yields no batches (previously this
            surfaced as a bare ``ZeroDivisionError``).
    """
    model.eval()

    num_batches = len(data_loader)
    if num_batches == 0:
        raise ValueError("evaluate() received an empty data_loader; nothing to average over")

    total_loss = 0.0
    total_dice = 0.0
    with torch.no_grad():
        for images, masks, _, _ in tqdm(data_loader, desc="Evaluation", unit="batch"):
            images, masks = images.to(device), masks.to(device)
            batch_size = images.size(0)

            # NOTE(review): the prompt embeddings are random noise, so results
            # vary between runs unless the RNG is seeded — presumably a
            # placeholder for real text embeddings; confirm.
            prompt_embeddings = torch.randn(batch_size, 768).to(device)

            outputs = model(images, prompt_embeddings)

            # Treat a single tensor as a one-element list so both cases share
            # the same averaging code.
            heads = outputs if isinstance(outputs, list) else [outputs]
            loss = sum(criterion(head, masks) for head in heads) / len(heads)
            dice = sum(dice_coefficient(torch.sigmoid(head), masks) for head in heads) / len(heads)

            total_loss += loss.item()
            total_dice += dice.item()

    return total_loss / num_batches, total_dice / num_batches

def main(
    csv_file,
    *,
    output_folder="segmentation_results",
    batch_size=32,
    num_workers=8,
    num_epochs=500,
    learning_rate=0.001,
    weight_decay=1e-4,
):
    """Build the datasets, model and optimizer, then run training and evaluation.

    The rows of ``csv_file`` are split by index into train / validation / test
    subsets: 20% is held out for testing, and 20% of the remainder is used for
    validation (both splits use ``random_state=42``). Three dataset objects
    are created from the same CSV so that the training subset can use the
    training transform with augmentation while validation and test use the
    evaluation transform without augmentation.

    Args:
        csv_file: Path of the CSV file handed to ``CataractDataset``.
        output_folder: Directory created (if missing) and passed on to
            ``train_and_evaluate``.
        batch_size: Batch size used by all three data loaders.
        num_workers: Worker process count used by all three data loaders.
        num_epochs: Maximum number of epochs passed to ``train_and_evaluate``.
        learning_rate: Learning rate for the AdamW optimizer.
        weight_decay: Weight decay for the AdamW optimizer.

    All keyword parameters default to the values that were previously
    hard-coded, so ``main(csv_file)`` behaves as before.
    """
    os.makedirs(output_folder, exist_ok=True)

    train_transform = get_transforms(train=True)
    val_transform = get_transforms(train=False)

    # Same CSV, different transform/augmentation per split; the index subsets
    # below decide which rows each split actually sees.
    train_dataset = CataractDataset(csv_file, transform=train_transform, augment=True)
    val_dataset = CataractDataset(csv_file, transform=val_transform, augment=False)
    test_dataset = CataractDataset(csv_file, transform=val_transform, augment=False)

    # First carve out the test indices, then split the rest into train / val.
    train_indices, test_indices = train_test_split(range(len(train_dataset)), test_size=0.2, random_state=42)
    train_indices, val_indices = train_test_split(train_indices, test_size=0.2, random_state=42)

    train_data = Subset(train_dataset, train_indices)
    val_data = Subset(val_dataset, val_indices)
    test_data = Subset(test_dataset, test_indices)

    # Only the training loader shuffles; evaluation order stays fixed.
    train_loader = DataLoader(train_data, batch_size=batch_size, shuffle=True, num_workers=num_workers, pin_memory=True)
    test_loader = DataLoader(test_data, batch_size=batch_size, shuffle=False, num_workers=num_workers, pin_memory=True)
    val_loader = DataLoader(val_data, batch_size=batch_size, shuffle=False, num_workers=num_workers, pin_memory=True)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    model = LLMSupervisedSAM(num_classes=3, input_channels=3, deep_supervision=True)
    if torch.cuda.device_count() > 1:
        print(f"Using {torch.cuda.device_count()} GPUs!")
        model = nn.DataParallel(model)
    model = model.to(device)

    criterion = CombinedLoss(alpha=0.5, beta=0.5)
    optimizer = optim.AdamW(model.parameters(), lr=learning_rate, weight_decay=weight_decay)
    # NOTE(review): `verbose` is deprecated in newer PyTorch releases — confirm
    # against the installed version before upgrading.
    scheduler = ReduceLROnPlateau(optimizer, mode='min', factor=0.5, patience=5, verbose=True)

    train_and_evaluate(model, train_loader, val_loader, test_loader, criterion, optimizer, scheduler, num_epochs, device, output_folder)

# Script entry point: train and evaluate on the default segmentation dataset.
if __name__ == "__main__":
    csv_file = "../retinal_segmentation/segmentation/final_data_for_segmentation/final_dataset.csv"
    main(csv_file=csv_file)
