In [None]:
# Install notebook dependencies (IPython shell escape, Colab).
# NOTE(review): segmentation-models-pytorch is not imported anywhere in the code
# below (the model is built from torchvision's ResNet-34); drop it if no other cell needs it.
!pip install segmentation-models-pytorch==0.3.3 torch torchvision albumentations==1.4.7


In [None]:
"""
POLYP SEGMENTATION - Attention U-Net ResNet34 (Improved + Best IoU Save)
-----------------------------------------------------------------------
Cải tiến: Attention U-Net + BCE + Dice Loss + Albumentations + Cosine LR + Save Best IoU + TTA (lật ngang).
Dữ liệu: /content/drive/MyDrive/polyp_data/
"""

# =========================================================
# 0️⃣ Mount Drive
# =========================================================
# Mount Google Drive so the dataset and results directories below are reachable.
from google.colab import drive
drive.mount('/content/drive')

# Input data root (expects train/images, train/masks and images-public/images subfolders,
# as used by the main cell).
DRIVE_PATH = '/content/drive/MyDrive/polyp_data'
# Output root: best checkpoint, predicted PNG masks and the submission zip are written here.
SAVE_PATH = '/content/drive/MyDrive/polyp_results'

import os, zipfile, numpy as np
from tqdm import tqdm
from PIL import Image
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader, random_split
from torchvision import models
import albumentations as A
from albumentations.pytorch import ToTensorV2

# Create the output directory up front so checkpoint/prediction writes never fail on a missing path.
os.makedirs(SAVE_PATH, exist_ok=True)

# =========================================================
# 1️⃣ Dataset Loader
# =========================================================
class PolypDataset(Dataset):
    """Image (and optional mask) dataset for polyp segmentation.

    Lists the files in ``image_dir`` whose names end with one of ``extensions``,
    sorted by name. When ``mask_dir`` is given, each image's mask is read from the
    file with the *same name* in ``mask_dir``, converted to grayscale and
    binarised (> 127) into a float32 array of 0/1 values.

    Args:
        image_dir: directory holding the input images.
        mask_dir: directory holding the masks, or ``None`` for unlabeled data.
        transform: optional Albumentations-style callable invoked as
            ``transform(image=...)`` or ``transform(image=..., mask=...)`` and
            returning a dict with ``"image"`` (and ``"mask"``) entries.
        extensions: filename suffix (or tuple of suffixes) treated as images.
            Defaults to ``(".jpg",)``; matching is case-sensitive.

    Items:
        labeled:   ``(image, mask)``; ``mask`` always carries a leading channel
                   axis of size 1, whether or not a transform is used.
        unlabeled: ``(image, file_name)``.
    """

    def __init__(self, image_dir, mask_dir=None, transform=None, extensions=(".jpg",)):
        self.image_dir = image_dir
        self.mask_dir = mask_dir
        self.transform = transform
        # Accept a single suffix string as well as a tuple/list of suffixes.
        suffixes = (extensions,) if isinstance(extensions, str) else tuple(extensions)
        self.image_files = sorted(f for f in os.listdir(image_dir) if f.endswith(suffixes))
        self.has_mask = mask_dir is not None

    def __len__(self):
        return len(self.image_files)

    def __getitem__(self, idx):
        img_name = self.image_files[idx]
        # Context managers release the file handles deterministically.
        with Image.open(os.path.join(self.image_dir, img_name)) as img:
            image = np.array(img.convert("RGB"))

        if not self.has_mask:
            if self.transform is not None:
                image = self.transform(image=image)["image"]
            return image, img_name

        with Image.open(os.path.join(self.mask_dir, img_name)) as mask_img:
            mask = (np.array(mask_img.convert("L")) > 127).astype(np.float32)

        if self.transform is not None:
            augmented = self.transform(image=image, mask=mask)
            image, mask = augmented["image"], augmented["mask"]

        # Add a leading channel axis so the mask lines up with the model's
        # single-channel logits. Works for tensors (after ToTensorV2) and for
        # plain numpy arrays (no transform, or a transform without ToTensorV2).
        mask = mask.unsqueeze(0) if torch.is_tensor(mask) else mask[np.newaxis]
        return image, mask

# =========================================================
# 2️⃣ Data Augmentation
# =========================================================
# Training pipeline: fixed 256x256 resize, random horizontal/vertical flips,
# brightness/contrast jitter and a mild shift/scale/rotate; then Normalize()
# with the library's default mean/std and conversion to a torch tensor.
# Geometric ops are applied to the mask too when Compose receives ``mask=``.
train_transform = A.Compose(
    [
        A.Resize(height=256, width=256),
        A.HorizontalFlip(p=0.5),
        A.VerticalFlip(p=0.5),
        A.RandomBrightnessContrast(p=0.3),
        A.ShiftScaleRotate(
            shift_limit=0.1,
            scale_limit=0.1,
            rotate_limit=30,
            p=0.5,
        ),
        A.Normalize(),
        ToTensorV2(),
    ]
)

# Deterministic pipeline for evaluation/inference: same resize and
# normalisation as training, no augmentation.
test_transform = A.Compose(
    [
        A.Resize(height=256, width=256),
        A.Normalize(),
        ToTensorV2(),
    ]
)

# =========================================================
# 3️⃣ Attention U-Net ResNet34
# =========================================================
class AttentionBlock(nn.Module):
    """Additive attention gate that yields a per-pixel weight map in [0, 1].

    The gating tensor ``g`` and the skip tensor ``x`` are each projected to
    ``F_int`` channels (1x1 conv + BatchNorm), summed, passed through ReLU and
    collapsed to a single channel (1x1 conv + BatchNorm + Sigmoid).

    Args:
        F_g: channel count of ``g``.
        F_l: channel count of ``x``.
        F_int: channel count of the intermediate projection.

    ``forward(g, x)`` returns only the attention map (one channel); the caller
    is responsible for multiplying it into ``x``. ``g`` and ``x`` must have
    compatible spatial sizes for the element-wise sum.
    """

    def __init__(self, F_g, F_l, F_int):
        super().__init__()

        def conv1x1_bn(in_channels, out_channels):
            # Pointwise projection followed by batch normalisation.
            return nn.Sequential(
                nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0, bias=True),
                nn.BatchNorm2d(out_channels),
            )

        # Attribute names and creation order are kept so state_dict keys match.
        self.W_g = conv1x1_bn(F_g, F_int)
        self.W_x = conv1x1_bn(F_l, F_int)
        self.psi = nn.Sequential(*conv1x1_bn(F_int, 1), nn.Sigmoid())
        self.relu = nn.ReLU(inplace=True)

    def forward(self, g, x):
        fused = self.relu(self.W_g(g) + self.W_x(x))
        return self.psi(fused)


class AttentionUNetResNet34(nn.Module):
    """Attention U-Net variant with a torchvision ResNet-34 encoder.

    Encoder: the ResNet-34 stem ``conv1 -> bn1 -> relu`` (``maxpool`` is not
    used) followed by ``layer1``..``layer4``, then a two-conv ``center`` block.
    Decoder: four stages, each upsampling x2 with a stride-2 transposed conv,
    gating an encoder feature map with an ``AttentionBlock``, concatenating and
    refining with ``_decoder_block``. ``final`` maps to a single logit channel.

    Args:
        pretrained: load torchvision's ``ResNet34_Weights.DEFAULT`` for the encoder.

    ``forward`` returns raw logits ``(N, 1, H', W')``; ``H' == H`` and
    ``W' == W`` when the input sides are multiples of 16.

    NOTE(review): each decoder stage pairs its upsampled tensor with the encoder
    map one level *deeper* than its own resolution (e.g. the stage at e3's
    resolution is gated with e4), bilinearly resizing that map x2 before
    concatenation, and ``e0`` is never used as a skip. This differs from the
    usual U-Net wiring; changing it alters channel counts and would invalidate
    existing checkpoints, so the wiring is kept as is.
    """

    def __init__(self, pretrained=True):
        super().__init__()
        backbone = models.resnet34(weights=models.ResNet34_Weights.DEFAULT if pretrained else None)
        self.encoder0 = nn.Sequential(backbone.conv1, backbone.bn1, backbone.relu)
        self.encoder1 = backbone.layer1
        self.encoder2 = backbone.layer2
        self.encoder3 = backbone.layer3
        self.encoder4 = backbone.layer4

        self.center = nn.Sequential(
            nn.Conv2d(512, 512, 3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(512, 512, 3, padding=1),
            nn.ReLU(inplace=True),
        )

        # Per stage: upconv(in -> out), attention(gate=out, skip=skip_ch), decoder(out + skip_ch -> out).
        self.upconv4 = nn.ConvTranspose2d(512, 256, 2, 2)
        self.att4 = AttentionBlock(256, 512, 128)
        self.dec4 = self._decoder_block(256 + 512, 256)

        self.upconv3 = nn.ConvTranspose2d(256, 128, 2, 2)
        self.att3 = AttentionBlock(128, 256, 64)
        self.dec3 = self._decoder_block(128 + 256, 128)

        self.upconv2 = nn.ConvTranspose2d(128, 64, 2, 2)
        self.att2 = AttentionBlock(64, 128, 32)
        self.dec2 = self._decoder_block(64 + 128, 64)

        self.upconv1 = nn.ConvTranspose2d(64, 64, 2, 2)
        self.att1 = AttentionBlock(64, 64, 16)
        self.dec1 = self._decoder_block(64 + 64, 64)

        self.final = nn.Conv2d(64, 1, 1)

    def _decoder_block(self, in_ch, out_ch):
        """Two 3x3 conv + BatchNorm + ReLU layers mapping ``in_ch`` to ``out_ch``."""
        return nn.Sequential(
            nn.Conv2d(in_ch, out_ch, 3, padding=1),
            nn.BatchNorm2d(out_ch),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_ch, out_ch, 3, padding=1),
            nn.BatchNorm2d(out_ch),
            nn.ReLU(inplace=True),
        )

    def _gated_stage(self, below, skip, upconv, att, dec):
        """Run one decoder stage: upsample ``below``, gate ``skip``, fuse with ``dec``.

        The attention map is computed at ``skip``'s resolution (the upsampled
        tensor is resized to it first); the gated skip is then resized to the
        upsampled tensor's resolution and concatenated on the channel axis.
        """
        up = upconv(below)
        up_at_skip = nn.functional.interpolate(up, size=skip.shape[2:], mode='bilinear', align_corners=True)
        gated = skip * att(up_at_skip, skip)
        gated_at_up = nn.functional.interpolate(gated, size=up.shape[2:], mode='bilinear', align_corners=True)
        return dec(torch.cat([up, gated_at_up], 1))

    def forward(self, x):
        e0 = self.encoder0(x)
        e1 = self.encoder1(e0)
        e2 = self.encoder2(e1)
        e3 = self.encoder3(e2)
        e4 = self.encoder4(e3)

        d = self.center(e4)
        d = self._gated_stage(d, e4, self.upconv4, self.att4, self.dec4)
        d = self._gated_stage(d, e3, self.upconv3, self.att3, self.dec3)
        d = self._gated_stage(d, e2, self.upconv2, self.att2, self.dec2)
        d = self._gated_stage(d, e1, self.upconv1, self.att1, self.dec1)
        return self.final(d)

# =========================================================
# 4️⃣ Loss & Metrics
# =========================================================
class DiceLoss(nn.Module):
    """Soft Dice loss on logits, pooled over the entire batch.

    Args:
        smooth: additive smoothing term used in both numerator and denominator.
    """

    def __init__(self, smooth=1.0):
        super().__init__()
        self.smooth = smooth

    def forward(self, logits, targets):
        """Return ``1 - (2*sum(P*T) + s) / (sum(P) + sum(T) + s)`` with ``P = sigmoid(logits)``."""
        probs = logits.sigmoid()
        overlap = (probs * targets).sum()
        total = probs.sum() + targets.sum()
        return 1 - (2 * overlap + self.smooth) / (total + self.smooth)

def compute_iou_dice(preds, masks):
    """Return batch-pooled ``(iou, dice)`` as Python floats.

    ``preds`` are raw logits: they are passed through a sigmoid and thresholded
    at 0.5. Intersection and totals are summed over every element of the batch
    (no per-image averaging). A 1e-6 term in numerator and denominator makes an
    empty prediction on an empty mask score 1.0.
    """
    eps = 1e-6
    hard = torch.sigmoid(preds).gt(0.5).float()
    overlap = (hard * masks).sum()
    pred_total = hard.sum()
    mask_total = masks.sum()
    union = pred_total + mask_total - overlap
    iou = (overlap + eps) / (union + eps)
    dice = (2 * overlap + eps) / (pred_total + mask_total + eps)
    return iou.item(), dice.item()

# =========================================================
# 5️⃣ Training (Save best IoU)
# =========================================================
def _train_one_epoch(model, loader, optimizer, bce, dice, device, desc):
    """Run one optimisation pass over ``loader``; return the mean per-batch loss."""
    model.train()
    total_loss = 0.0
    for imgs, masks in tqdm(loader, desc=desc):
        imgs, masks = imgs.to(device), masks.to(device)
        preds = model(imgs)
        # Equal-weight mix of pixel-wise BCE and soft Dice.
        loss = 0.5 * bce(preds, masks) + 0.5 * dice(preds, masks)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
    # Guard against an empty loader instead of raising ZeroDivisionError.
    return total_loss / max(len(loader), 1)


def _evaluate(model, loader, device, desc):
    """Return ``(mean IoU, mean Dice)`` over ``loader``, averaged per batch."""
    model.eval()
    iou_sum, dice_sum = 0.0, 0.0
    with torch.no_grad():
        for imgs, masks in tqdm(loader, desc=desc):
            imgs, masks = imgs.to(device), masks.to(device)
            iou, d = compute_iou_dice(model(imgs), masks)
            iou_sum += iou
            dice_sum += d
    n_batches = max(len(loader), 1)
    return iou_sum / n_batches, dice_sum / n_batches


def train_model(model, train_loader, val_loader, device, epochs=15, lr=1e-4, checkpoint_path=None):
    """Train ``model`` with BCE + Dice, AdamW and cosine LR; keep the best-IoU weights.

    After every epoch the model is evaluated on ``val_loader``; whenever the
    mean validation IoU improves, ``model.state_dict()`` is written to
    ``checkpoint_path``. The first epoch always saves, so a checkpoint exists
    for later loading even if IoU never rises above 0.

    Args:
        model: network returning one-channel logits.
        train_loader / val_loader: yield ``(images, masks)`` batches.
        device: device to train on.
        epochs: number of epochs; also the cosine schedule's ``T_max``.
        lr: initial AdamW learning rate.
        checkpoint_path: where to save the best weights; defaults to
            ``SAVE_PATH/best_model_iou.pth``.

    Returns:
        The best mean validation IoU seen (``-inf`` if ``epochs`` is 0).
    """
    if checkpoint_path is None:
        checkpoint_path = os.path.join(SAVE_PATH, "best_model_iou.pth")

    # Move the model before building the optimizer, as the torch.optim docs recommend.
    model.to(device)
    bce = nn.BCEWithLogitsLoss()
    dice = DiceLoss()
    optimizer = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=1e-4)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=epochs)

    best_iou = float("-inf")
    for epoch in range(epochs):
        tag = f"Epoch {epoch+1}/{epochs}"
        train_loss = _train_one_epoch(model, train_loader, optimizer, bce, dice, device, f"{tag} [Train]")
        scheduler.step()

        val_iou, val_dice = _evaluate(model, val_loader, device, f"{tag} [Val]")
        print(f"✅ Epoch {epoch+1}: Loss={train_loss:.4f}, IoU={val_iou:.4f}, Dice={val_dice:.4f}")

        if val_iou > best_iou:
            best_iou = val_iou
            torch.save(model.state_dict(), checkpoint_path)
            print(f"💾 Saved best IoU model (IoU={best_iou:.4f})")

    return best_iou

# =========================================================
# 6️⃣ Inference (TTA + ZIP)
# =========================================================
def predict_masks(model, loader, device, zip_path, pred_dir=None):
    """Predict binary masks with horizontal-flip TTA, save them as PNGs and zip them.

    Args:
        model: network returning one-channel logits.
        loader: yields ``(images, names)`` batches (unlabeled ``PolypDataset``).
        device: device to run inference on.
        zip_path: path of the output zip archive.
        pred_dir: directory for the individual PNG masks; defaults to
            ``SAVE_PATH/pred_masks``.

    Only the masks written by *this* call are added to the archive, so stale
    files left in ``pred_dir`` by earlier runs do not leak into the zip.
    Each mask is stored as ``<image stem>.png`` with values 0/255.
    """
    if pred_dir is None:
        pred_dir = os.path.join(SAVE_PATH, "pred_masks")
    os.makedirs(pred_dir, exist_ok=True)
    model.eval()
    written = []

    with torch.no_grad():
        for imgs, names in tqdm(loader, desc="Predicting"):
            imgs = imgs.to(device)
            # TTA: average probabilities of the input and its horizontal flip
            # (the flipped prediction is flipped back before averaging).
            probs = torch.sigmoid(model(imgs))
            probs_flipped = torch.sigmoid(model(torch.flip(imgs, dims=[3])))
            probs = (probs + torch.flip(probs_flipped, dims=[3])) / 2
            binary = (probs > 0.5).float().cpu().numpy()

            for i, name in enumerate(names):
                # NOTE(review): masks are saved at the model input resolution
                # (256x256 with test_transform), not the original image size —
                # confirm this matches what the submission expects.
                mask = (binary[i, 0] * 255).astype(np.uint8)
                out_name = os.path.splitext(name)[0] + ".png"
                Image.fromarray(mask).save(os.path.join(pred_dir, out_name))
                written.append(out_name)

    with zipfile.ZipFile(zip_path, "w") as zipf:
        for out_name in written:
            zipf.write(os.path.join(pred_dir, out_name), arcname=out_name)
    print(f"🎯 Saved masks to: {zip_path}")

# =========================================================
# 7️⃣ Main
# =========================================================
if __name__ == "__main__":
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    image_dir = os.path.join(DRIVE_PATH, "train/images")
    mask_dir = os.path.join(DRIVE_PATH, "train/masks")

    # Two datasets over the same files: augmented for training, deterministic
    # (test_transform) for validation. random_split on one dataset would make
    # the validation subset inherit the training augmentations.
    train_full = PolypDataset(image_dir, mask_dir, transform=train_transform)
    val_full = PolypDataset(image_dir, mask_dir, transform=test_transform)

    val_size = int(0.2 * len(train_full))
    train_size = len(train_full) - val_size
    # Seeded permutation so the train/val split is identical on every run.
    split_generator = torch.Generator().manual_seed(42)
    order = torch.randperm(len(train_full), generator=split_generator).tolist()
    train_ds = torch.utils.data.Subset(train_full, order[:train_size])
    val_ds = torch.utils.data.Subset(val_full, order[train_size:])
    train_loader = DataLoader(train_ds, batch_size=4, shuffle=True, num_workers=2)
    val_loader = DataLoader(val_ds, batch_size=2, shuffle=False, num_workers=2)

    model = AttentionUNetResNet34(pretrained=True)
    train_model(model, train_loader, val_loader, device, epochs=300, lr=1e-4)

    # Prediction on the public test images using the best-IoU checkpoint.
    test_ds = PolypDataset(
        os.path.join(DRIVE_PATH, "images-public/images"),
        mask_dir=None,
        transform=test_transform,
    )
    test_loader = DataLoader(test_ds, batch_size=1, shuffle=False)
    model.load_state_dict(torch.load(os.path.join(SAVE_PATH, "best_model_iou.pth"), map_location=device))

    output_zip = os.path.join(SAVE_PATH, "pred_mask.zip")
    predict_masks(model, test_loader, device, output_zip)
