In [None]:
# Attach Google Drive to this Colab runtime; the dataset folders and the
# pretrained model config/weights read later in the notebook live under it.
from google.colab import drive

drive.mount(mountpoint='/content/drive')


In [None]:
!pip install segmentation-models-pytorch --no-cache-dir
!pip install safetensors albumentations


In [None]:
# Copy the image and mask folders from Drive onto the local disk, presumably so
# training reads local files rather than going through the Drive mount.
# Done in Python rather than `!cp`: a failed shell command does not raise, so a
# missing Drive folder would only surface later as an obscure loading error.
import shutil
from pathlib import Path

DRIVE_DATA_ROOT = Path("/content/drive/MyDrive/cc")
LOCAL_DATA_ROOT = Path("/content/dataset")

LOCAL_DATA_ROOT.mkdir(parents=True, exist_ok=True)
for subdir in ("Uncropped", "Cropped"):
    # dirs_exist_ok=True merges into an existing copy, so re-running is safe.
    shutil.copytree(DRIVE_DATA_ROOT / subdir, LOCAL_DATA_ROOT / subdir, dirs_exist_ok=True)


In [None]:
#Linknet
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
import os
import json
import cv2
import segmentation_models_pytorch as smp
from safetensors.torch import load_file
from torch.utils.data import Dataset, DataLoader
import albumentations as A
from albumentations.pytorch import ToTensorV2
from sklearn.model_selection import train_test_split
from sklearn.metrics import jaccard_score

# Load Model Config: LinkNet hyper-parameters stored next to the weights.
config_path = "/content/drive/MyDrive/path_to_model/ln/config.json"
with open(config_path, "r") as f:
    model_config = json.load(f)

# Initialize LinkNet. The "tu-" prefix selects smp's timm-universal encoder
# wrapper (timm's resnet18), not the torchvision implementation.
model = smp.Linknet(
    encoder_name="tu-resnet18",
    encoder_depth=model_config["encoder_depth"],
    # Encoder init weights named in the config; any key also present in the
    # checkpoint below is overwritten by it.
    encoder_weights=model_config["encoder_weights"],
    decoder_use_batchnorm=model_config["decoder_use_batchnorm"],
    in_channels=model_config["in_channels"],
    classes=model_config["classes"],
    activation=None,  # raw logits: the training loss applies sigmoid itself
).cuda()

# Load the pretrained weights BEFORE torch.compile(): the compiled wrapper
# prefixes every state_dict key with "_orig_mod.", so loading an un-prefixed
# checkpoint into it with strict=False silently matches nothing.
weights_path = "/content/drive/MyDrive/path_to_model/ln/model.safetensors"
state_dict = load_file(weights_path)
load_result = model.load_state_dict(state_dict, strict=False)
n_loaded = len(state_dict) - len(load_result.unexpected_keys)
if n_loaded == 0:
    raise RuntimeError(
        f"No tensor in {weights_path} matched a model parameter; "
        "check that the checkpoint matches this architecture/config."
    )
if load_result.missing_keys or load_result.unexpected_keys:
    print(
        f"Partial weight load: {len(load_result.missing_keys)} missing, "
        f"{len(load_result.unexpected_keys)} unexpected keys."
    )

# Apply TorchDynamo (`torch.compile()`) for graph-level speedups; the original
# module remains reachable as `model._orig_mod`.
model = torch.compile(model)
model.train()


# Dataset Class
class ConjunctivaDataset(Dataset):
    """Paired (image, mask) segmentation dataset read from two directories.

    ``image_filenames[i]`` is paired with ``mask_filenames[i]``; the caller is
    responsible for the two lists being aligned.

    Args:
        image_filenames: file names (relative to ``image_dir``) of the colour inputs.
        mask_filenames: file names (relative to ``mask_dir``) of the grayscale masks.
        image_dir: directory holding the images.
        mask_dir: directory holding the masks.
        transform: optional callable using the albumentations convention
            (``transform(image=..., mask=...)`` returning a dict with tensor
            ``'image'`` and ``'mask'`` entries, e.g. a Compose ending in
            ToTensorV2). When falsy, the arrays are converted to tensors directly.

    Raises:
        ValueError: if the two filename lists have different lengths.
    """

    def __init__(self, image_filenames, mask_filenames, image_dir, mask_dir, transform=None):
        if len(image_filenames) != len(mask_filenames):
            raise ValueError(
                f"Got {len(image_filenames)} images but {len(mask_filenames)} masks; "
                "they must be paired one-to-one."
            )
        self.image_filenames = image_filenames
        self.mask_filenames = mask_filenames
        self.image_dir = image_dir
        self.mask_dir = mask_dir
        self.transform = transform

    def __len__(self):
        return len(self.image_filenames)

    def __getitem__(self, idx):
        """Return ``(image, mask)`` with the mask given a leading channel dim.

        Raises:
            FileNotFoundError: if either file is missing or cannot be decoded
                (cv2.imread signals both by returning None).
        """
        image_path = os.path.join(self.image_dir, self.image_filenames[idx])
        mask_path = os.path.join(self.mask_dir, self.mask_filenames[idx])

        image = cv2.imread(image_path)
        if image is None:
            raise FileNotFoundError(f"Could not read image: {image_path}")
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

        mask = cv2.imread(mask_path, cv2.IMREAD_GRAYSCALE)
        if mask is None:
            raise FileNotFoundError(f"Could not read mask: {mask_path}")
        # Scale to [0, 1] float32; assumes an 8-bit (0..255) mask file.
        mask = (mask / 255.0).astype(np.float32)

        if self.transform:
            augmented = self.transform(image=image, mask=mask)
            image, mask = augmented['image'], augmented['mask']
        else:
            # Mirror ToTensorV2's layout: image HWC -> CHW tensor, mask HW tensor.
            image = torch.from_numpy(np.ascontiguousarray(image.transpose(2, 0, 1)))
            mask = torch.from_numpy(mask)

        return image, mask.unsqueeze(0)  # Add channel dim

# Data Augmentation
IMAGE_SIZE = 256
IMAGENET_MEAN = (0.485, 0.456, 0.406)
IMAGENET_STD = (0.229, 0.224, 0.225)

# Training pipeline: random geometric/photometric augmentation, then
# normalization and HWC->CHW tensor conversion.
# NOTE(review): is_check_shapes=False disables albumentations' check that image
# and mask have the same size; presumably some pairs differ and rely on Resize
# to align them — confirm those pairs are still spatially registered.
transform = A.Compose([
    A.Resize(IMAGE_SIZE, IMAGE_SIZE),
    A.HorizontalFlip(p=0.5),
    A.VerticalFlip(p=0.5),
    A.RandomBrightnessContrast(p=0.2),
    A.Affine(scale=(0.95, 1.05), translate_percent=(0.05, 0.05), rotate=(-15, 15), p=0.5),
    A.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
    ToTensorV2(),
], is_check_shapes=False)

# Validation pipeline: deterministic (no random augmentation) so the validation
# metrics that drive the LR scheduler, checkpointing and early stopping are not
# perturbed by random flips/affine jitter from epoch to epoch.
val_transform = A.Compose([
    A.Resize(IMAGE_SIZE, IMAGE_SIZE),
    A.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
    ToTensorV2(),
], is_check_shapes=False)

# Data Loading
image_dir = "/content/dataset/Uncropped"
mask_dir = "/content/dataset/Cropped"

# Images and masks are paired purely by sorted file-name order.
image_filenames = sorted(os.listdir(image_dir))
mask_filenames = sorted(os.listdir(mask_dir))
if len(image_filenames) != len(mask_filenames):
    raise ValueError(
        f"{image_dir} has {len(image_filenames)} files but {mask_dir} has "
        f"{len(mask_filenames)}; images and masks must pair one-to-one."
    )
_unmatched_pairs = [
    (img_name, mask_name)
    for img_name, mask_name in zip(image_filenames, mask_filenames)
    if os.path.splitext(img_name)[0] != os.path.splitext(mask_name)[0]
]
if _unmatched_pairs:
    print(
        f"WARNING: {len(_unmatched_pairs)} image/mask pairs have different base names "
        f"(first: {_unmatched_pairs[0]}); verify the sorted-order pairing."
    )

train_images, val_images, train_masks, val_masks = train_test_split(
    image_filenames, mask_filenames, test_size=0.2, random_state=42
)

# 🔹 Set num_workers dynamically (os.cpu_count() may return None).
num_workers = min(4, (os.cpu_count() or 1) // 2)

train_dataset = ConjunctivaDataset(train_images, train_masks, image_dir, mask_dir, transform)
val_dataset = ConjunctivaDataset(val_images, val_masks, image_dir, mask_dir, val_transform)

train_loader = DataLoader(train_dataset, batch_size=16, shuffle=True, num_workers=num_workers, pin_memory=True)
val_loader = DataLoader(val_dataset, batch_size=16, shuffle=False, num_workers=num_workers, pin_memory=True)

# Loss Function
class DiceLoss(nn.Module):
    """Soft Dice loss on logits, computed over the whole batch at once.

    Args:
        smooth: additive constant keeping the ratio finite when both the
            prediction and the target are empty.
    """

    def __init__(self, smooth=1e-6):
        super().__init__()
        self.smooth = smooth

    def forward(self, pred, target):
        """Return ``1 - soft_dice(sigmoid(pred), target)`` as a scalar tensor."""
        # Reduce in float32 explicitly so the result does not depend on
        # autocast's per-op dtype policy when called under mixed precision.
        probs = torch.sigmoid(pred.float())
        target = target.float()
        intersection = (probs * target).sum()
        return 1 - (2. * intersection + self.smooth) / (probs.sum() + target.sum() + self.smooth)


# Instantiated once instead of on every call.
_bce_loss = nn.BCEWithLogitsLoss()
_dice_loss = DiceLoss()


def criterion(pred, target):
    """Equal-weight BCE-with-logits + soft Dice; ``pred`` are raw logits."""
    return 0.5 * _bce_loss(pred, target) + 0.5 * _dice_loss(pred, target)

# 🔹 Optimizer + LR schedule.
# AdamW with decoupled weight decay. ReduceLROnPlateau runs in "max" mode
# because the metric passed to scheduler.step() is higher-is-better; the LR is
# halved once that metric has not improved for more than 3 steps.
optimizer = optim.AdamW(
    params=model.parameters(),
    lr=1e-4,
    weight_decay=1e-4,
)
scheduler = optim.lr_scheduler.ReduceLROnPlateau(
    optimizer=optimizer,
    mode="max",
    factor=0.5,
    patience=3,
)

# Evaluation Metrics
def dice_score(pred, target, threshold=0.5):
    pred = (pred > threshold).astype(np.uint8)
    target = (target > 0.5).astype(np.uint8)
    intersection = np.sum(pred * target)
    return (2. * intersection) / (np.sum(pred) + np.sum(target) + 1e-8)

def iou_score(pred, target, threshold=0.5):
    pred = (pred > threshold).astype(np.uint8)
    target = (target > 0.5).astype(np.uint8)
    return jaccard_score(target.flatten(), pred.flatten(), average="binary")

def accuracy(pred, target, threshold=0.5):
    return np.mean((pred > threshold) == (target > 0.5))

# 🔹 TTA (Test-Time Augmentation)
def apply_tta(model, images):
    """Average sigmoid predictions over the identity and a horizontal flip.

    The batch is moved to the GPU and run through ``model`` twice: once as-is,
    and once mirrored along dim 3, with that prediction mirrored back before
    averaging. Returns the averaged probabilities as a CPU numpy array.
    """
    batch = images.cuda()
    width_dim = [3]

    def _probs(x):
        return torch.sigmoid(model(x))

    plain = _probs(batch).cpu().numpy()
    mirrored = torch.flip(_probs(torch.flip(batch, width_dim)), width_dim).cpu().numpy()
    return (plain + mirrored) / 2

# Training Loop
num_epochs = 20
patience = 5   # stop after this many *consecutive* epochs without a new best Dice
best_dice = 0
counter = 0    # consecutive epochs since the last improvement
checkpoint_path = "best_model.pth"
scaler = torch.amp.GradScaler('cuda')  # 🔹 loss scaling for mixed-precision (AMP) training


def _train_one_epoch(net, loader, opt, grad_scaler, loss_fn):
    """Run one AMP training pass over ``loader``; return the mean per-batch loss."""
    net.train()
    total_loss = 0.0
    for images, masks in loader:
        images = images.cuda(non_blocking=True)
        masks = masks.cuda(non_blocking=True)

        opt.zero_grad(set_to_none=True)
        with torch.autocast(device_type='cuda'):  # 🔹 AMP for faster training
            outputs = net(images)
            loss = loss_fn(outputs, masks)

        grad_scaler.scale(loss).backward()
        grad_scaler.step(opt)
        grad_scaler.update()

        total_loss += loss.item()
    return total_loss / len(loader)


def _evaluate(net, loader, num_samples):
    """Return mean per-sample (Dice, IoU, accuracy) over ``loader`` using flip TTA."""
    net.eval()
    total_dice = total_iou = total_acc = 0.0
    with torch.no_grad():
        for images, masks in loader:
            # apply_tta moves the images to the GPU itself and returns CPU numpy
            # probabilities, so the masks never need to leave the CPU.
            probs = apply_tta(net, images)
            masks = masks.numpy()
            for prob, mask in zip(probs, masks):
                total_dice += dice_score(prob, mask)
                total_iou += iou_score(prob, mask)
                total_acc += accuracy(prob, mask)
    return total_dice / num_samples, total_iou / num_samples, total_acc / num_samples


for epoch in range(num_epochs):
    avg_loss = _train_one_epoch(model, train_loader, optimizer, scaler, criterion)
    avg_dice, avg_iou, avg_acc = _evaluate(model, val_loader, len(val_dataset))

    scheduler.step(avg_dice)

    print(f"Epoch {epoch+1}/{num_epochs} | Loss: {avg_loss:.4f} | Dice: {avg_dice:.4f} | IoU: {avg_iou:.4f} | Acc: {avg_acc:.4f}")

    if avg_dice > best_dice:
        best_dice = avg_dice
        counter = 0  # patience counts consecutive non-improving epochs only
        # Save the un-compiled module's weights so the keys carry no
        # "_orig_mod." prefix and load into a plain smp.Linknet
        # (or model._orig_mod) directly.
        torch.save(getattr(model, "_orig_mod", model).state_dict(), checkpoint_path)
    else:
        counter += 1
        if counter >= patience:
            print(f"Early stopping at epoch {epoch+1}.")
            break
