In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset
from tqdm import tqdm
import matplotlib.pyplot as plt
import numpy as np
import os
import cv2
import glob
from model import MultiTaskUNet  # Import model m·ªõi

# --- 1. SMART DATASET (Oversampling) ---
class SmartUnifiedDataset(Dataset):
    """Combine the Chiu (segmentation) and Srinivasan/Duke (classification) sets.

    Each item is a tuple ``(img, mask, label, has_mask)``:

    * ``img`` -- float32 tensor ``(3, img_size, img_size)`` scaled by 1/255;
      the grayscale image is repeated on 3 channels for a 3-channel encoder.
    * ``mask`` -- float32 tensor ``(1, img_size, img_size)`` with values {0, 1};
      all zeros for samples that have no annotation.
    * ``label`` -- int64 class index following ``class_map``
      (AMD=0, DME=1, NORMAL=2).
    * ``has_mask`` -- float32 scalar, 1.0 when ``mask`` is a real annotation.

    Args:
        chiu_dir: directory containing ``images/`` and ``masks/`` subfolders
            whose files share the same names. Every Chiu sample gets label 1
            (DME). Skipped silently if the directory does not exist.
        srinivasan_dir: directory with one subfolder per ``class_map`` key,
            each holding ``*.png`` images. Skipped silently if missing.
        img_size: side length that images and masks are resized to.
        oversample_factor: number of times each Chiu sample is repeated to
            balance it against the larger Duke set.

    Raises (from ``__getitem__``):
        FileNotFoundError: if an image or mask file cannot be read by OpenCV.
    """

    # Only these files are taken from the Chiu images folder; anything else
    # (Thumbs.db, desktop.ini, .DS_Store, ...) would make cv2.imread fail.
    IMG_EXTENSIONS = ('.png', '.jpg', '.jpeg', '.bmp', '.tif', '.tiff')

    def __init__(self, chiu_dir, srinivasan_dir, img_size=256, oversample_factor=5):
        self.img_size = img_size
        self.samples = []
        self.class_map = {'AMD': 0, 'DME': 1, 'NORMAL': 2}

        # Load CHIU, duplicated `oversample_factor` times to rebalance the classes.
        if os.path.exists(chiu_dir):
            chiu_img_dir = os.path.join(chiu_dir, 'images')
            chiu_mask_dir = os.path.join(chiu_dir, 'masks')
            chiu_files = sorted(
                f for f in os.listdir(chiu_img_dir)
                if f.lower().endswith(self.IMG_EXTENSIONS)
            )
            for _ in range(oversample_factor):
                for f in chiu_files:
                    self.samples.append({
                        'img_path': os.path.join(chiu_img_dir, f),
                        # Masks are expected under the same file name as the image.
                        'mask_path': os.path.join(chiu_mask_dir, f),
                        'label': 1, 'source': 'chiu'
                    })
            print(f"--> Chiu (x{oversample_factor}): {len(chiu_files)*oversample_factor} ·∫£nh.")

        # Load DUKE (classification only, no masks).
        if os.path.exists(srinivasan_dir):
            count_duke = 0
            for cls_name, label_idx in self.class_map.items():
                cls_folder = os.path.join(srinivasan_dir, cls_name)
                if not os.path.exists(cls_folder):
                    continue
                # glob order is filesystem-dependent; sort for reproducibility.
                for f_path in sorted(glob.glob(os.path.join(cls_folder, "*.png"))):
                    self.samples.append({
                        'img_path': f_path,
                        'mask_path': None,
                        'label': label_idx, 'source': 'duke'
                    })
                    count_duke += 1
            print(f"--> Duke: {count_duke} ·∫£nh.")

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        item = self.samples[idx]
        img = cv2.imread(item['img_path'], cv2.IMREAD_GRAYSCALE)
        if img is None:
            # cv2.imread signals failure by returning None rather than raising.
            raise FileNotFoundError(f"Could not read image: {item['img_path']}")
        img = cv2.resize(img, (self.img_size, self.img_size)) / 255.0
        # The encoder expects 3 channels, so repeat the grayscale image 3 times.
        img = np.stack([img, img, img], axis=0)

        if item['mask_path'] is not None:
            mask = cv2.imread(item['mask_path'], cv2.IMREAD_GRAYSCALE)
            if mask is None:
                raise FileNotFoundError(f"Could not read mask: {item['mask_path']}")
            # Nearest-neighbour keeps the mask binary after resizing.
            mask = cv2.resize(mask, (self.img_size, self.img_size), interpolation=cv2.INTER_NEAREST)
            mask = (mask > 0).astype(np.float32)
            has_mask = 1.0
        else:
            mask = np.zeros((self.img_size, self.img_size), dtype=np.float32)
            has_mask = 0.0

        return torch.tensor(img, dtype=torch.float32), torch.tensor(mask).unsqueeze(0), torch.tensor(item['label'], dtype=torch.long), torch.tensor(has_mask, dtype=torch.float32)

# --- 2. CONFIGURATION ---
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
BATCH_SIZE = 8    # Reduced to 8 because the EfficientNet backbone is heavier
EPOCHS = 50       # 50 epochs over the oversampled data is already a lot of steps
LR = 1e-4

CHIU_DIR = r"D:\project\processed_data"
SRINIVASAN_DIR = r"D:\project\processed_srinivasan"

# Load data
print("‚è≥ ƒêang n·∫°p d·ªØ li·ªáu (Ch·∫ø ƒë·ªô kh√¥ m√°u)...")
full_dataset = SmartUnifiedDataset(chiu_dir=CHIU_DIR, srinivasan_dir=SRINIVASAN_DIR, img_size=256, oversample_factor=5)
if len(full_dataset) == 0:
    # The dataset skips missing directories without error; fail loudly here
    # instead of letting DataLoader's RandomSampler raise an opaque ValueError.
    raise FileNotFoundError(
        f"No training images found under {CHIU_DIR!r} or {SRINIVASAN_DIR!r}"
    )
train_loader = DataLoader(full_dataset, batch_size=BATCH_SIZE, shuffle=True)

# Initialize the EfficientNet-based multi-task model (3 input channels).
model = MultiTaskUNet(n_channels=3, n_classes_seg=1, n_classes_cls=3).to(DEVICE)
optimizer = optim.AdamW(model.parameters(), lr=LR)
# Halve the LR after 3 epochs without improvement of the monitored loss.
# `verbose` is deprecated for torch LR schedulers (PyTorch >= 2.2), so it is
# not passed.
scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.5, patience=3)

# Loss
class DiceLoss(nn.Module):
    """Soft Dice loss on sigmoid probabilities of raw logits.

    All elements of the inputs are pooled into a single Dice coefficient
    ("batch Dice"); it is not averaged per sample.

    Args:
        smooth: additive smoothing on numerator and denominator so the ratio
            stays finite when prediction and target are both (near) empty.
    """

    def __init__(self, smooth=1.0):
        super().__init__()
        self.smooth = smooth

    def forward(self, pred, target):
        """Return ``1 - Dice(sigmoid(pred), target)`` as a scalar tensor.

        ``pred`` holds logits; ``target`` holds values in [0, 1] (binary masks
        in this script) with the same number of elements. Any memory layout
        is accepted.
        """
        # reshape instead of view: view(-1) raises on non-contiguous tensors
        # (e.g. after transposes or some indexing ops).
        pred = torch.sigmoid(pred).reshape(-1)
        target = target.reshape(-1)
        intersection = (pred * target).sum()
        dice = (2. * intersection + self.smooth) / (pred.sum() + target.sum() + self.smooth)
        return 1 - dice

# Loss objects: cross-entropy on the classification logits, and for the
# segmentation head a soft Dice term plus pixel-wise BCE on logits.
criterion_cls, criterion_dice, criterion_bce = (
    nn.CrossEntropyLoss(),
    DiceLoss(),
    nn.BCEWithLogitsLoss(),
)

# --- 3. TRAINING ---
# Each step combines a classification loss (all samples) with a segmentation
# loss restricted to samples that carry a ground-truth mask (has_mask == 1).
print(f"\nüöÄ START TRAINING EFFICIENTNET-B3 ...")
model.train()

for epoch in range(EPOCHS):
    loop = tqdm(train_loader, desc=f"Epoch {epoch+1}/{EPOCHS}")
    epoch_loss = 0.0

    for images, masks, labels, has_masks in loop:
        images, masks, labels, has_masks = images.to(DEVICE), masks.to(DEVICE), labels.to(DEVICE), has_masks.to(DEVICE)

        pred_seg, pred_cls = model(images)
        loss_cls = criterion_cls(pred_cls, labels)

        # Boolean-mask indexing keeps the batch dimension even when only one
        # sample in the batch has a mask.
        seg_sel = has_masks > 0
        if seg_sel.any():
            l_dice = criterion_dice(pred_seg[seg_sel], masks[seg_sel])
            l_bce = criterion_bce(pred_seg[seg_sel], masks[seg_sel])
            loss_seg = 0.7 * l_dice + 0.3 * l_bce
        else:
            loss_seg = torch.zeros((), device=DEVICE)

        # Segmentation is weighted very heavily relative to classification.
        loss = 50.0 * loss_seg + 1.0 * loss_cls

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        epoch_loss += loss.item()
        loop.set_postfix(loss=loss.item(), lr=optimizer.param_groups[0]['lr'])

    # The plateau scheduler monitors the mean training loss of the epoch;
    # max(..., 1) avoids a ZeroDivisionError on an empty loader.
    scheduler.step(epoch_loss / max(len(train_loader), 1))

    # Checkpoint every 5 epochs, and always after the last epoch.
    if (epoch + 1) % 5 == 0 or epoch + 1 == EPOCHS:
        torch.save(model.state_dict(), f"effnet_epoch_{epoch+1}.pth")

print("‚úÖ ƒê√£ xong! Ki·ªÉm tra k·∫øt qu·∫£...")

# --- 4. COMPUTE DICE ---
# NOTE(review): every annotated image was also used for training (there is no
# held-out split), so this is a training-set Dice, not a generalization estimate.
model.eval()

# Evaluate each annotated image exactly once. train_loader repeats the Chiu
# images `oversample_factor` times and mixes in mask-less samples, and
# averaging pooled per-batch Dice gave a batch with one masked image the same
# weight as a batch with eight.
seen_paths = set()
eval_indices = []
for sample_idx, sample in enumerate(full_dataset.samples):
    if sample['mask_path'] is not None and sample['img_path'] not in seen_paths:
        seen_paths.add(sample['img_path'])
        eval_indices.append(sample_idx)

eval_loader = DataLoader(
    torch.utils.data.Subset(full_dataset, eval_indices),
    batch_size=BATCH_SIZE,
    shuffle=False,
)

dice_sum = 0.0
num_images = 0
with torch.no_grad():
    for x, y, _labels, _has_masks in eval_loader:
        x, y = x.to(DEVICE), y.to(DEVICE)
        pred = (torch.sigmoid(model(x)[0]) > 0.5).float()
        # Per-image Dice; the epsilon makes an empty prediction on an empty
        # mask score 1.0 instead of 0/0.
        inter = (pred * y).flatten(1).sum(dim=1)
        union = pred.flatten(1).sum(dim=1) + y.flatten(1).sum(dim=1)
        dice_sum += ((2. * inter + 1e-6) / (union + 1e-6)).sum().item()
        num_images += x.size(0)

if num_images > 0:
    print(f"\nüèÜ FINAL DICE (EFFICIENTNET): {(dice_sum/num_images):.4f}")
else:
    print("\nNo annotated images available; Dice score not computed.")

⏳ Đang nạp dữ liệu gốc...
--> Đã nạp 610 ảnh từ bộ Chiu (Có Mask).
--> Đã nạp 3231 ảnh từ bộ Srinivasan (Duke).
✅ Tổng cộng: 3841 ảnh.

🚀 Bắt đầu Training Ổn định (100 Epochs)...


Epoch 1/100: 100%|██████████| 241/241 [00:50<00:00,  4.78it/s, cls=3.31, seg=0]
Epoch 2/100: 100%|██████████| 241/241 [00:50<00:00,  4.81it/s, cls=0.393, seg=0]
Epoch 3/100: 100%|██████████| 241/241 [00:50<00:00,  4.80it/s, cls=1.29, seg=0]
Epoch 4/100: 100%|██████████| 241/241 [00:50<00:00,  4.77it/s, cls=1.4, seg=0]
Epoch 5/100: 100%|██████████| 241/241 [00:50<00:00,  4.78it/s, cls=2.8, seg=0.956]
Epoch 6/100: 100%|██████████| 241/241 [00:50<00:00,  4.78it/s, cls=1.58, seg=0]
Epoch 7/100: 100%|██████████| 241/241 [00:50<00:00,  4.79it/s, cls=0.557, seg=0]
Epoch 8/100: 100%|██████████| 241/241 [00:50<00:00,  4.77it/s, cls=3.07, seg=0]
Epoch 9/100: 100%|██████████| 241/241 [00:50<00:00,  4.81it/s, cls=0.558, seg=0]
Epoch 10/100: 100%|███ [output truncated]

✅ Training hoàn tất!

📊 Đang tính toán kết quả...
🏆 FINAL DICE SCORE: 0.6889
