In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import models, transforms, datasets
from torch.utils.data import DataLoader
from tqdm import tqdm
import os

# --- CONFIGURATION ---
# Module-level settings read by the model/data helpers and the training script below.
# Use the GPU when CUDA is available, otherwise fall back to the CPU.
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
NUM_CLASSES = 10       # e.g. for the CIFAR-10 dataset
BATCH_SIZE = 16        # ConvNeXt Base is fairly heavy; lower the batch size on OOM (out of memory)
LEARNING_RATE = 1e-4
EPOCHS = 5
# NOTE(review): get_model() in this file loads convnext_tiny, so the "Base" wording in the
# BATCH_SIZE comment and in this filename looks stale -- confirm before renaming the file.
SAVE_PATH = "best_convnext_base.pth"

def get_model():
    """Build an ImageNet-pretrained ConvNeXt Tiny with a NUM_CLASSES-way head.

    Returns:
        The model, moved to DEVICE.
    """
    print("-> Đang tải ConvNeXt Tiny pretrained...")
    # 'DEFAULT' selects torchvision's default pretrained weights for this architecture.
    net = models.convnext_tiny(weights='DEFAULT')

    # The classifier is (LayerNorm -> Flatten -> Linear); swap only the final
    # Linear layer so its output width matches NUM_CLASSES.
    old_head = net.classifier[2]
    net.classifier[2] = nn.Linear(old_head.in_features, NUM_CLASSES)

    return net.to(DEVICE)

from torch.utils.data import Subset
import numpy as np

def get_dataloaders(train_size=5000, val_size=1000, num_workers=6, seed=None):
    """Build CIFAR-10 train/validation loaders over random subsets.

    Both splits are downloaded to ``./data`` if missing and share one
    transform: resize to 224x224, convert to tensor, and normalise with the
    standard ImageNet channel statistics expected by the pretrained backbone.

    Args:
        train_size: Number of training images to sample (default 5000).
        val_size: Number of validation images to sample (default 1000).
        num_workers: Worker processes per DataLoader (default 6).
        seed: ``None`` (default) samples from NumPy's global RNG, as before.
            An integer seeds a private generator so the same subsets are
            drawn on every run, which keeps validation accuracy comparable
            across runs.

    Returns:
        ``(train_loader, val_loader)``; only the training loader shuffles.
    """
    transform = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                             std=[0.229, 0.224, 0.225]),
    ])

    print("-> Đang tải CIFAR-10...")
    train_full = datasets.CIFAR10(root='./data', train=True, download=True, transform=transform)
    val_full = datasets.CIFAR10(root='./data', train=False, download=True, transform=transform)

    # Sampling without replacement raises if more items are requested than
    # exist, so clamp the requested sizes to the dataset length.
    n_train = min(train_size, len(train_full))
    n_val = min(val_size, len(val_full))

    # Keep the legacy global-RNG path when no seed is given so existing
    # callers (and any external np.random.seed call) behave exactly as before.
    sampler = np.random if seed is None else np.random.default_rng(seed)
    train_idx = sampler.choice(len(train_full), n_train, replace=False)
    val_idx = sampler.choice(len(val_full), n_val, replace=False)

    train_data = Subset(train_full, train_idx)
    val_data = Subset(val_full, val_idx)

    train_loader = DataLoader(train_data, batch_size=BATCH_SIZE, shuffle=True, num_workers=num_workers)
    val_loader = DataLoader(val_data, batch_size=BATCH_SIZE, shuffle=False, num_workers=num_workers)
    return train_loader, val_loader
    

def train_one_epoch(model, loader, criterion, optimizer):
    """Run one optimisation pass over ``loader`` and return epoch metrics.

    Args:
        model: Network to train; switched to training mode here.
        loader: Iterable of ``(images, labels)`` batches.
        criterion: Loss callable. Expected to return the mean loss over the
            batch (the default reduction of ``nn.CrossEntropyLoss``).
        optimizer: Optimizer that updates the model's parameters.

    Returns:
        ``(avg_loss, acc)``: the per-sample mean loss and the top-1 accuracy
        in percent. Returns ``(0.0, 0.0)`` if the loader yields no samples
        instead of raising ``ZeroDivisionError``.
    """
    model.train()
    loss_sum = 0.0
    correct = 0
    total = 0

    # Wrap the loader in tqdm to display a progress bar
    loop = tqdm(loader, desc="Training", leave=False)

    for images, labels in loop:
        images, labels = images.to(DEVICE), labels.to(DEVICE)

        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, labels)

        loss.backward()
        optimizer.step()

        # Read the scalar once: every .item() call forces a device-to-host sync.
        batch_loss = loss.item()
        batch_size = labels.size(0)

        # Weight the batch mean by its size so a short final batch does not
        # count as much as a full one in the epoch average.
        loss_sum += batch_loss * batch_size
        _, predicted = outputs.max(1)
        total += batch_size
        correct += predicted.eq(labels).sum().item()

        # Update the progress bar
        loop.set_postfix(loss=batch_loss)

    if total == 0:
        return 0.0, 0.0

    avg_loss = loss_sum / total
    acc = 100. * correct / total
    return avg_loss, acc

def validate(model, loader, criterion):
    """Evaluate ``model`` on ``loader`` without updating any weights.

    Args:
        model: Network to evaluate; switched to eval mode here.
        loader: Iterable of ``(images, labels)`` batches.
        criterion: Loss callable. Expected to return the mean loss over the
            batch (the default reduction of ``nn.CrossEntropyLoss``).

    Returns:
        ``(avg_loss, acc)``: the per-sample mean loss and the top-1 accuracy
        in percent. Returns ``(0.0, 0.0)`` if the loader yields no samples
        instead of raising ``ZeroDivisionError``.
    """
    model.eval()
    loss_sum = 0.0
    correct = 0
    total = 0

    # Gradients are not needed for evaluation
    with torch.no_grad():
        loop = tqdm(loader, desc="Validating", leave=False)
        for images, labels in loop:
            images, labels = images.to(DEVICE), labels.to(DEVICE)

            outputs = model(images)
            loss = criterion(outputs, labels)

            batch_size = labels.size(0)
            # Weight the batch mean by its size so a short final batch does
            # not count as much as a full one in the overall average.
            loss_sum += loss.item() * batch_size
            _, predicted = outputs.max(1)
            total += batch_size
            correct += predicted.eq(labels).sum().item()

    if total == 0:
        return 0.0, 0.0

    avg_loss = loss_sum / total
    acc = 100. * correct / total
    return avg_loss, acc

# --- MAIN RUN ---
# Script entry point: builds the data loaders and the model, trains for EPOCHS
# epochs, and writes the model's state_dict to SAVE_PATH whenever validation
# accuracy improves on the best value seen so far.
if __name__ == "__main__":
    print(f"Sử dụng thiết bị: {DEVICE}")
    
    # 1. Setup: data loaders, model, loss function and optimizer
    train_loader, val_loader = get_dataloaders()
    model = get_model()
    
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.AdamW(model.parameters(), lr=LEARNING_RATE) # AdamW is the usual choice for ConvNeXt
    
    # Best validation accuracy (in percent) observed so far
    best_val_acc = 0.0

    print(f"\nBắt đầu train {EPOCHS} epochs...")
    print("-" * 50)

    # 2. Train/validation loop
    for epoch in range(EPOCHS):
        # Train
        train_loss, train_acc = train_one_epoch(model, train_loader, criterion, optimizer)
        
        # Validation
        val_loss, val_acc = validate(model, val_loader, criterion)
        
        print(f"Epoch [{epoch+1}/{EPOCHS}]")
        print(f"Train Loss: {train_loss:.4f} | Train Acc: {train_acc:.2f}%")
        print(f"Val Loss:   {val_loss:.4f} | Val Acc:   {val_acc:.2f}%")

        # 3. Save the best model
        # The strict '>' keeps the earliest checkpoint on ties; since the baseline
        # starts at 0.0, an epoch with exactly 0.0 accuracy never triggers a save.
        if val_acc > best_val_acc:
            best_val_acc = val_acc
            # Only the weights (state_dict) are saved, not the full model object.
            torch.save(model.state_dict(), SAVE_PATH)
            print(f"--> Đã lưu model tốt nhất (Acc: {best_val_acc:.2f}%)")
        
        print("-" * 50)

    print("Hoàn tất huấn luyện!") 