In [None]:
import os
import numpy as np
from tqdm import tqdm
from sklearn.metrics import f1_score
from sklearn.model_selection import KFold

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, random_split, Subset
from torchvision import datasets, transforms, models

# =====================
# CONFIGURATION
# =====================
# Root folder handed to torchvision's ImageFolder. It must be laid out as one
# sub-directory per class (e.g. data/train/ containing cat/, dog/, ...).
DATA_DIR = "./data"
NUM_CLASSES = 10          # number of logits produced by the classifier head
BATCH_SIZE = 32
EPOCHS = 15               # epochs trained in every cross-validation fold
LR = 1e-3                 # Adam learning rate
KFOLD_SPLITS = 5
CHECKPOINT_DIR = "./checkpoints"

# Create the checkpoint folder up front so per-fold torch.save calls never fail.
os.makedirs(CHECKPOINT_DIR, exist_ok=True)

# Prefer the GPU when PyTorch can see one.
DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")


# =====================
# DATA AUGMENTATION
# =====================
class AddGaussianNoise(object):
    """Add i.i.d. Gaussian noise to a tensor image and clamp the result to [0, 1].

    Intended to run after ``transforms.ToTensor()`` (which yields values in
    [0, 1]) and before ``transforms.Normalize``.

    Args:
        mean: Mean of the added noise.
        std: Standard deviation of the added noise.
    """

    def __init__(self, mean=0., std=0.05):
        self.mean = mean
        self.std = std

    def __call__(self, tensor):
        # Draw the noise on the input's device, and in its dtype when it is
        # floating point. A bare torch.randn(size) always yields a CPU tensor
        # of the default dtype, which fails for GPU inputs.
        dtype = tensor.dtype if tensor.is_floating_point() else None
        noise = torch.randn(tensor.shape, dtype=dtype, device=tensor.device)
        noise = noise * self.std + self.mean
        return torch.clamp(tensor + noise, 0., 1.)

    def __repr__(self):
        return f"{type(self).__name__}(mean={self.mean}, std={self.std})"


# Training-time pipeline: resize, random rotation and light Gaussian noise.
transform = transforms.Compose([
    transforms.Resize((128, 128)),
    transforms.RandomRotation(20),
    transforms.ToTensor(),
    AddGaussianNoise(0., 0.03),
    transforms.Normalize((0.5,), (0.5,))
])

# Deterministic pipeline for held-out data: same resize and normalisation,
# but no random rotation/noise, so evaluation results are repeatable.
eval_transform = transforms.Compose([
    transforms.Resize((128, 128)),
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,))
])

# Load the dataset (images organised as one folder per class).
dataset = datasets.ImageFolder(DATA_DIR, transform=transform)
# Second view over the same files using the evaluation pipeline. ImageFolder
# orders samples deterministically, so indices line up with `dataset`.
eval_dataset = datasets.ImageFolder(DATA_DIR, transform=eval_transform)

# More class folders than classifier outputs would only surface later as an
# out-of-range target error inside the loss; fail early with a clear message.
if len(dataset.classes) > NUM_CLASSES:
    raise ValueError(
        f"Found {len(dataset.classes)} class folders in {DATA_DIR!r}, "
        f"but NUM_CLASSES is {NUM_CLASSES}"
    )

# 80/20 train/test split with a fixed seed so the held-out set is stable across runs.
train_size = int(0.8 * len(dataset))
test_size = len(dataset) - train_size
train_dataset, _test_view = random_split(
    dataset, [train_size, test_size], generator=torch.Generator().manual_seed(42)
)
# Serve the test split from the non-augmented view (same sample indices).
test_dataset = Subset(eval_dataset, _test_view.indices)

print(f"Train size: {len(train_dataset)}, Test size: {len(test_dataset)}")


# =====================
# MODEL
# =====================
class SimpleCNN(nn.Module):
    """Small two-stage CNN classifier for 3-channel square images.

    Each stage is Conv2d(3x3, padding=1) -> ReLU -> MaxPool2d(2). The spatial
    side length is therefore halved twice before the fully connected head.

    Args:
        num_classes: Number of output logits.
        input_size: Side length in pixels of the (square) input images. The
            first linear layer is sized from it. Defaults to 128, which keeps
            the original ``64 * 32 * 32`` head and its checkpoint layout.
    """

    def __init__(self, num_classes, input_size=128):
        super().__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(3, 32, 3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2),
            nn.Conv2d(32, 64, 3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2),
        )
        # padding=1 convs keep the size; each MaxPool2d(2) floors it by half.
        feat_side = input_size // 2 // 2
        self.fc = nn.Sequential(
            nn.Linear(64 * feat_side * feat_side, 256),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(256, num_classes)
        )

    def forward(self, x):
        x = self.conv(x)
        # torch.flatten also handles non-contiguous (e.g. channels_last)
        # activations, where .view would raise.
        x = torch.flatten(x, 1)
        return self.fc(x)


# =====================
# TRAINING + VALIDATION
# =====================
def train_one_epoch(model, dataloader, criterion, optimizer):
    """Run one optimisation pass over ``dataloader``.

    Args:
        model: Network to train; it is put in train mode here.
        dataloader: Iterable of ``(images, labels)`` batches; both are moved
            to ``DEVICE``.
        criterion: Loss function. Assumed to return the batch *mean* (e.g.
            the default ``nn.CrossEntropyLoss``); it is re-weighted by batch
            size to obtain a per-sample epoch mean.
        optimizer: Optimizer over ``model``'s parameters.

    Returns:
        ``(mean_loss, macro_f1)`` over all samples seen in the epoch.

    Raises:
        ValueError: If ``dataloader`` yields no samples.
    """
    model.train()
    total_loss = 0.0
    n_samples = 0
    pred_batches, label_batches = [], []
    for imgs, labels in tqdm(dataloader, desc="Train", leave=False):
        imgs, labels = imgs.to(DEVICE), labels.to(DEVICE)

        optimizer.zero_grad()
        outputs = model(imgs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        # Weight each batch mean by its size so a short final batch does not
        # skew the epoch average.
        batch_size = labels.size(0)
        total_loss += loss.item() * batch_size
        n_samples += batch_size
        pred_batches.append(outputs.argmax(dim=1).cpu())
        label_batches.append(labels.cpu())

    if n_samples == 0:
        raise ValueError("train_one_epoch received an empty dataloader")

    preds = torch.cat(pred_batches).numpy()
    targets = torch.cat(label_batches).numpy()
    # zero_division=0 matches sklearn's default result but avoids warning spam
    # when a class is absent from a batch set.
    f1 = f1_score(targets, preds, average='macro', zero_division=0)
    return total_loss / n_samples, f1


def validate(model, dataloader, criterion):
    """Evaluate ``model`` on ``dataloader`` without updating its weights.

    Args:
        model: Network to evaluate; it is put in eval mode here.
        dataloader: Iterable of ``(images, labels)`` batches; both are moved
            to ``DEVICE``.
        criterion: Loss function. Assumed to return the batch *mean* (e.g.
            the default ``nn.CrossEntropyLoss``); it is re-weighted by batch
            size to obtain a per-sample mean.

    Returns:
        ``(mean_loss, macro_f1)`` over all samples in ``dataloader``.

    Raises:
        ValueError: If ``dataloader`` yields no samples.
    """
    model.eval()
    total_loss = 0.0
    n_samples = 0
    pred_batches, label_batches = [], []
    with torch.no_grad():
        for imgs, labels in tqdm(dataloader, desc="Valid", leave=False):
            imgs, labels = imgs.to(DEVICE), labels.to(DEVICE)
            outputs = model(imgs)
            loss = criterion(outputs, labels)

            # Sample-weighted accumulation: a short final batch counts only
            # as much as the samples it contains.
            batch_size = labels.size(0)
            total_loss += loss.item() * batch_size
            n_samples += batch_size
            pred_batches.append(outputs.argmax(dim=1).cpu())
            label_batches.append(labels.cpu())

    if n_samples == 0:
        raise ValueError("validate received an empty dataloader")

    preds = torch.cat(pred_batches).numpy()
    targets = torch.cat(label_batches).numpy()
    f1 = f1_score(targets, preds, average='macro', zero_division=0)
    return total_loss / n_samples, f1


# =====================
# CROSS VALIDATION
# =====================
def run_cross_validation(dataset, k_folds=KFOLD_SPLITS, val_dataset=None):
    """Train and validate a fresh ``SimpleCNN`` on each K-fold split of ``dataset``.

    For every fold, the weights with the highest validation macro-F1 are saved
    to ``CHECKPOINT_DIR/best_model_fold{k}.pt``. A per-fold summary and the
    average F1 are printed at the end.

    Args:
        dataset: Indexable dataset of ``(image, label)`` pairs used for training.
        k_folds: Number of folds for ``KFold`` (shuffled, ``random_state=42``).
        val_dataset: Optional dataset aligned index-for-index with ``dataset``
            (same samples, e.g. with a deterministic, non-augmented transform).
            When given, each fold's validation subset is drawn from it. When
            ``None`` (default), validation samples come from ``dataset`` and so
            go through its training-time transform.

    Returns:
        List with the best validation macro-F1 of each fold.

    Raises:
        ValueError: If ``val_dataset`` is given but its length differs from
            ``dataset``.
    """
    if val_dataset is not None and len(val_dataset) != len(dataset):
        raise ValueError("val_dataset must have the same length as dataset")
    eval_source = dataset if val_dataset is None else val_dataset

    kfold = KFold(n_splits=k_folds, shuffle=True, random_state=42)
    fold_results = []

    for fold, (train_idx, val_idx) in enumerate(kfold.split(range(len(dataset)))):
        print(f"\n========== Fold {fold + 1}/{k_folds} ==========")

        train_subset = Subset(dataset, train_idx)
        val_subset = Subset(eval_source, val_idx)

        train_loader = DataLoader(train_subset, batch_size=BATCH_SIZE, shuffle=True)
        val_loader = DataLoader(val_subset, batch_size=BATCH_SIZE, shuffle=False)

        model = SimpleCNN(num_classes=NUM_CLASSES).to(DEVICE)
        criterion = nn.CrossEntropyLoss()
        optimizer = optim.Adam(model.parameters(), lr=LR)

        best_val_f1 = 0.0
        best_model_path = os.path.join(CHECKPOINT_DIR, f"best_model_fold{fold+1}.pt")

        for epoch in range(EPOCHS):
            print(f"Epoch {epoch + 1}/{EPOCHS}")
            train_loss, train_f1 = train_one_epoch(model, train_loader, criterion, optimizer)
            val_loss, val_f1 = validate(model, val_loader, criterion)

            print(f"Train Loss: {train_loss:.4f}, F1: {train_f1:.4f} | Val Loss: {val_loss:.4f}, F1: {val_f1:.4f}")

            # Save a checkpoint when validation F1 improves. The first epoch is
            # always saved: with only a strict `>` against the 0.0 start value,
            # a fold whose F1 never exceeds 0 would leave no checkpoint file.
            if epoch == 0 or val_f1 > best_val_f1:
                best_val_f1 = val_f1
                torch.save(model.state_dict(), best_model_path)
                print(f"✅ Saved new best model at epoch {epoch + 1} (F1={best_val_f1:.4f})")

        fold_results.append(best_val_f1)

    print("\n===== K-Fold Results =====")
    for i, f1_val in enumerate(fold_results):
        print(f"Fold {i+1}: Best F1 = {f1_val:.4f}")
    print(f"Average F1: {np.mean(fold_results):.4f}")
    return fold_results


# =====================
# MAIN
# =====================
def _main():
    """Script entry point: K-fold cross-validation on the training split.

    Only ``train_dataset`` is passed to the cross-validation; ``test_dataset``
    is not used here.
    """
    print("🚀 Starting Training with K-Fold Cross Validation...")
    run_cross_validation(train_dataset, k_folds=KFOLD_SPLITS)
    print("🎯 Training complete. Checkpoints saved in", CHECKPOINT_DIR)


if __name__ == "__main__":
    _main()
