# 🧼 Image Denoising bằng Autoencoder (PyTorch)



## Quy trình
- Đọc dữ liệu ảnh theo chuẩn `ImageFolder`
- Thêm nhiễu: **Gaussian (normal)** / **Bernoulli (mask/dropout)** / **Poisson**
- Huấn luyện Autoencoder để **khử nhiễu**
- Đánh giá bằng **PSNR** và **MSE**
- Lưu ảnh minh hoạ (nhiễu | khôi phục | gốc) trong một hình: `viz_best_autoencoder.png`
- Lưu checkpoint tốt nhất theo PSNR: `best_ae_model.pth`




In [1]:
import os
from pathlib import Path
from typing import Tuple

import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, random_split
from torchvision import datasets, transforms, utils as vutils
import matplotlib.pyplot as plt

# Report the installed torch build and whether a CUDA device is visible.
for _nhan, _gia_tri in (
    ("‚úÖ Torch version:", torch.__version__),
    ("‚úÖ CUDA available:", torch.cuda.is_available()),
):
    print(_nhan, _gia_tri)
del _nhan, _gia_tri

ModuleNotFoundError: No module named 'numpy'

In [None]:
import torch
import torch.nn as nn
import torchvision.transforms as transforms
from torch.utils.data import Dataset, DataLoader
from torchvision.datasets import ImageFolder
import numpy as np
import matplotlib.pyplot as plt
from pathlib import Path
from tqdm import tqdm
import os
import time

# Kaggle kernels expose a /kaggle directory; its presence selects the Kaggle
# paths used further down in this cell.
IN_KAGGLE = os.path.exists('/kaggle')
print(f"üîç Kaggle environment detected: {IN_KAGGLE}")
print("‚úÖ Running on Kaggle Notebook" if IN_KAGGLE else "üíª Running on local machine")

# ==================== GPU INVENTORY (informational only) ====================
_vach_ngang = "=" * 60
print("\n" + _vach_ngang)
print("üñ•Ô∏è KI·ªÇM TRA GPU")
print(_vach_ngang)
_co_cuda = torch.cuda.is_available()
print(f"CUDA Available: {_co_cuda}")
if not _co_cuda:
    print("‚ö†Ô∏è GPU NOT AVAILABLE - Will use CPU (slower!)")
else:
    _so_gpu = torch.cuda.device_count()
    print(f"CUDA Device Count: {_so_gpu}")
    for _chi_so in range(_so_gpu):
        _ten = torch.cuda.get_device_name(_chi_so)
        _bo_nho_gb = torch.cuda.get_device_properties(_chi_so).total_memory / 1e9
        print(f"  GPU {_chi_so}: {_ten} - {_bo_nho_gb:.2f} GB")
    print(f"cudnn Version: {torch.backends.cudnn.version()}")
    print(f"cudnn Enabled: {torch.backends.cudnn.enabled}")
print(_vach_ngang + "\n")

# ================== PATHS (Kaggle vs. local) ====================

# 1) Data / output / checkpoint locations, chosen from the detected environment.
#    On Kaggle the dataset is read from /kaggle/input and everything this
#    notebook writes goes under /kaggle/working. If your Kaggle dataset has a
#    different name, point thu_muc_du_lieu at e.g.
#    "/kaggle/input/your-dataset-name/thumbnails".
#    Locally every path is relative to the current working directory.
if IN_KAGGLE:
    thu_muc_du_lieu = "/kaggle/input/thumbnails"
    _goc_ghi = "/kaggle/working"
else:
    thu_muc_du_lieu = "./thumbnails"
    _goc_ghi = "."

thu_muc_ket_qua = f"{_goc_ghi}/outputs_denoise"
# Best-checkpoint files for the autoencoder and the GAN generator/discriminator.
duong_dan_checkpoint_ae = f"{_goc_ghi}/best_ae_model.pth"
duong_dan_checkpoint_gan_g = f"{_goc_ghi}/best_gan_generator.pth"
duong_dan_checkpoint_gan_d = f"{_goc_ghi}/best_gan_discriminator.pth"
del _goc_ghi

# 2) Device selection and DataLoader tuning (multi-GPU aware).
# Set bat_buoc_cpu = True to force CPU even when CUDA is available.
# (This flag used to be defined further down but was never consulted.)
bat_buoc_cpu = False
use_cuda = torch.cuda.is_available() and not bat_buoc_cpu
n_gpus = torch.cuda.device_count() if use_cuda else 0

if use_cuda:
    # Per-GPU batch: 32 on a P100, 16 on other GPUs. It is multiplied by the GPU
    # count because DataParallel splits each batch across the devices.
    base_batch_size = 32 if "P100" in torch.cuda.get_device_name(0) else 16
    batch_size = base_batch_size * n_gpus
    # 4 workers per GPU, capped at the number of CPU cores: DataLoader warns
    # (and may stall) when num_workers exceeds what the machine can run.
    so_worker = min(4 * n_gpus, os.cpu_count() or 1)
    pin_memory = True  # page-locked host memory -> faster host-to-GPU copies

    # cuDNN autotuning: picks the fastest conv algorithm for fixed input sizes.
    torch.backends.cudnn.benchmark = True
    torch.backends.cudnn.enabled = True
else:
    # CPU: small batches, loading in the main process. base_batch_size is
    # defined here too, so code that reads it never hits a NameError.
    base_batch_size = 8
    batch_size = base_batch_size
    so_worker = 0
    pin_memory = False

# 3) Training hyper-parameters
so_epoch = 20
learning_rate = 1e-3

# 4) Image size and network width
kich_thuoc_anh = 128
so_kenh_co_so = 32

# 5) Noise
loai_nhieu = "normal"     # "normal" | "bernoulli" | "poisson"
do_manh_nhieu = 0.05      # normal: sigma; bernoulli: mask rate; poisson: rate scale

# 6) Target type
kieu_muc_tieu = "clean"   # "clean" = denoising | "noisy" = noisy target (identity-noise)

# 7) Loss
ten_loss = "mse"          # "mse" | "l1"

# 8) Device (honours bat_buoc_cpu through use_cuda above)
thiet_bi = torch.device("cuda" if use_cuda else "cpu")

# 9) Multi-GPU: only when more than one CUDA device is in use
use_multi_gpu = use_cuda and n_gpus > 1

# Summary of the effective training configuration.
_vach = "=" * 60
print(_vach)
print("‚öôÔ∏è C·∫§U H√åNH TRAINING T·ªêI ∆ØU")
print(_vach)
print(f"üìÅ Data dir: {thu_muc_du_lieu}")
print(f"üìÅ Output dir: {thu_muc_ket_qua}")
print(f"üñ•Ô∏è Device: {thiet_bi}")
print(f"üéÆ Number of GPUs: {n_gpus}")
if not use_multi_gpu:
    print("‚ö†Ô∏è Single GPU mode (ch·ªâ d√πng 1 GPU)")
else:
    print(f"üöÄ MULTI-GPU ENABLED: S·ª≠ d·ª•ng {n_gpus} GPUs ƒë·ªìng th·ªùi!")
    _ten_gpu = [torch.cuda.get_device_name(i) for i in range(n_gpus)]
    print(f"   GPUs: {_ten_gpu}")
# The per-GPU batch is only meaningful with CUDA; on CPU the total batch is shown.
_batch_moi_gpu = base_batch_size if use_cuda else batch_size
print(f"üìä Batch size: {batch_size} ({_batch_moi_gpu} per GPU)")
print(f"üîÑ Num workers: {so_worker}")
print(f"üìå Pin memory: {pin_memory}")
print(f"üöÄ cuDNN Benchmark: {torch.backends.cudnn.benchmark}")
print(f"‚è±Ô∏è Epochs: {so_epoch}")
print(_vach + "\n")

In [None]:
# ================== OUTPUT DIRECTORY + ENVIRONMENT SUMMARY + AMP ==================
# From here on thu_muc_ket_qua is a Path (it was a str in the config cell).
thu_muc_ket_qua = Path(thu_muc_ket_qua)
thu_muc_ket_qua.mkdir(exist_ok=True, parents=True)
duong_dan_du_lieu = Path(thu_muc_du_lieu)

# The device itself (thiet_bi) was chosen in the config cell.
_vach = "=" * 60
print(_vach)
print("üîß THI·∫æT L·∫¨P M√îI TR∆Ø·ªúNG TRAINING")
print(_vach)
print(f"üñ•Ô∏è Thi·∫øt b·ªã: {thiet_bi} {'‚úÖ GPU' if use_cuda else '‚ùå CPU'}")
print(f"üìÇ D·ªØ li·ªáu: {duong_dan_du_lieu.resolve()}")
print(f"üìÅ K·∫øt qu·∫£: {thu_muc_ket_qua.resolve()}")
print(f"üñºÔ∏è K√≠ch th∆∞·ªõc: {kich_thuoc_anh}x{kich_thuoc_anh}")
print(f"üéØ Nhi·ªÖu: {loai_nhieu} (œÉ={do_manh_nhieu})")
print(f"üìâ Loss: {ten_loss.upper()}")
print(f"‚è±Ô∏è Epochs: {so_epoch} | Batch: {batch_size} | Workers: {so_worker}")

# Warn (but do not stop) when the dataset directory is missing.
if not duong_dan_du_lieu.exists():
    print(f"\n‚ö†Ô∏è C·∫¢NH B√ÅO: {duong_dan_du_lieu} kh√¥ng t·ªìn t·∫°i!")
    print(f"üìå B·∫°n c·∫ßn upload dataset ho·∫∑c t·∫°o th∆∞ m·ª•c d·ªØ li·ªáu tr∆∞·ªõc khi ch·∫°y")
    if IN_KAGGLE:
        print("üí° Tr√™n Kaggle: H√£y upload dataset ho·∫∑c s·ª≠ d·ª•ng 'Add data' t√≠nh nƒÉng")

# Mixed precision is enabled only on GPU 0 with compute capability >= 7.
use_amp = use_cuda and torch.cuda.get_device_capability(0)[0] >= 7
if use_amp:
    print("\n‚úÖ Mixed Precision Training ENABLED (amp)")
    # NOTE(review): torch.cuda.amp is deprecated in recent PyTorch releases in
    # favour of torch.amp; kept because later cells call autocast() with no
    # device argument.
    from torch.cuda.amp import autocast, GradScaler
    scaler = GradScaler()
else:
    print("\n‚ùå Mixed Precision Training DISABLED")
    scaler = None

print(_vach + "\n")


## 2) Hàm đánh giá (PSNR, MSE) + hàm thêm nhiễu

In [None]:
# ================== H√ÄM ƒê√ÅNH GI√Å & TH√äM NHI·ªÑU ====================

def psnr_tensor(x: torch.Tensor, y: torch.Tensor, max_pixel: float = 1.0) -> torch.Tensor:
    """Peak signal-to-noise ratio (dB) between ``x`` and ``y``.

    The squared error is averaged over every element of the tensors, so for a
    batch this is the PSNR of the batch-mean MSE. Returns ``+inf`` (a tensor on
    ``x``'s device) when the tensors are identical.
    """
    squared_err = (x - y).pow(2).mean()
    if squared_err == 0:
        return torch.tensor(float('inf'), device=x.device)
    peak = torch.tensor(max_pixel, device=x.device)
    return 20 * torch.log10(peak / torch.sqrt(squared_err))

def mse_tensor(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    """Mean squared error between ``x`` and ``y`` over all elements (0-dim tensor)."""
    return (x - y).pow(2).mean()

def them_nhieu(anh: torch.Tensor, loai: str = "normal", do_manh: float = 0.05) -> torch.Tensor:
    """Return a noisy copy of ``anh`` clipped to [0, 1], computed on ``anh``'s device.

    Args:
        anh: image tensor with values assumed to be in [0, 1]. Noise is applied
            element-wise, so any shape works.
        loai: noise model:
            * ``"normal"``    - additive Gaussian noise with std ``do_manh``.
            * ``"bernoulli"`` - mask/dropout noise: each element is zeroed
              independently with probability ``do_manh``.
            * ``"poisson"``   - shot noise: ``anh * 255 * do_manh`` is the Poisson
              rate, so a LARGER ``do_manh`` gives LESS relative noise.
            * ``"none"``      - no noise (only the clamp is applied).
        do_manh: noise strength; its meaning depends on ``loai`` (see above).

    Returns:
        A new tensor with the same shape and device as ``anh``.

    Raises:
        ValueError: unknown ``loai``; bernoulli ``do_manh`` outside [0, 1];
            poisson ``do_manh`` <= 0 (it would divide by zero and yield NaNs).
    """
    if loai == "normal":
        ket_qua = anh + torch.randn_like(anh) * do_manh
    elif loai == "bernoulli":
        if not 0.0 <= do_manh <= 1.0:
            raise ValueError(f"bernoulli mask rate must be in [0, 1], got {do_manh}")
        # Keep each element with probability 1 - do_manh, zero it otherwise.
        # (The previous version added bernoulli(p) - 0.5, which shifted EVERY
        #  pixel by +/-0.5 instead of masking a fraction of them.)
        giu_lai = torch.bernoulli(torch.full_like(anh, 1.0 - do_manh))
        ket_qua = anh * giu_lai
    elif loai == "poisson":
        if do_manh <= 0:
            raise ValueError(f"poisson noise strength must be > 0, got {do_manh}")
        ty_le = 255.0 * do_manh
        # Sample counts with rate anh * ty_le and rescale back to [0, 1] range;
        # equals anh + (P - anh * ty_le) / ty_le from the previous version.
        ket_qua = torch.poisson(anh * ty_le) / ty_le
    elif loai == "none":
        ket_qua = anh
    else:
        # Previously an unknown type silently produced a clean image, which
        # would make a typo in the config train an identity mapping.
        raise ValueError(f"unknown noise type {loai!r}; expected 'normal', 'bernoulli', 'poisson' or 'none'")
    return torch.clamp(ket_qua, 0, 1)

# Confirm that the helper functions of this cell are defined.
for _thong_bao in (
    "‚úÖ ƒê√£ ƒë·ªãnh nghƒ©a c√°c h√†m: psnr_tensor(), mse_tensor(), them_nhieu()",
    "‚úÖ T·∫•t c·∫£ h√†m ƒë√£ t·ªëi ∆∞u ƒë·ªÉ ch·∫°y tr·ª±c ti·∫øp tr√™n GPU (tensor operations)",
):
    print(_thong_bao)
del _thong_bao

## 3) Kiến trúc Autoencoder

In [None]:
# ================== KI·∫æN TR√öC AUTOENCODER ====================

class Autoencoder(nn.Module):
    """Convolutional encoder/decoder used to denoise images.

    Encoder: four 3x3 stride-2 convolutions (each followed by BatchNorm + ReLU)
    that halve the spatial size while widening channels
    kenh_co_so -> 2x -> 4x -> 8x.
    Decoder: four 4x4 stride-2 transposed convolutions that double the spatial
    size back to ``so_kenh_vao`` channels. A final Sigmoid keeps outputs in (0, 1).
    The output has the input's spatial size when H and W are divisible by 16.
    """

    def __init__(self, so_kenh_vao: int = 3, kenh_co_so: int = 32):
        super().__init__()
        do_rong = [kenh_co_so * he_so for he_so in (1, 2, 4, 8)]

        # Down path: (in -> k), (k -> 2k), (2k -> 4k), (4k -> 8k)
        lop_xuong = []
        for c_vao, c_ra in zip([so_kenh_vao] + do_rong[:-1], do_rong):
            lop_xuong += [
                nn.Conv2d(c_vao, c_ra, kernel_size=3, stride=2, padding=1),
                nn.BatchNorm2d(c_ra),
                nn.ReLU(inplace=True),
            ]
        self.encoder = nn.Sequential(*lop_xuong)

        # Up path: (8k -> 4k), (4k -> 2k), (2k -> k), then k -> image channels
        nguoc = do_rong[::-1]
        lop_len = []
        for c_vao, c_ra in zip(nguoc[:-1], nguoc[1:]):
            lop_len += [
                nn.ConvTranspose2d(c_vao, c_ra, kernel_size=4, stride=2, padding=1),
                nn.BatchNorm2d(c_ra),
                nn.ReLU(inplace=True),
            ]
        lop_len += [
            nn.ConvTranspose2d(kenh_co_so, so_kenh_vao, kernel_size=4, stride=2, padding=1),
            nn.Sigmoid(),  # output in (0, 1)
        ]
        self.decoder = nn.Sequential(*lop_len)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Encode ``x`` to the bottleneck and decode it back."""
        return self.decoder(self.encoder(x))


# Build the autoencoder on the selected device.
mo_hinh = Autoencoder(kenh_co_so=so_kenh_co_so, so_kenh_vao=3).to(thiet_bi)

# With several GPUs, DataParallel splits each batch across devices; the wrapped
# network is then reachable as mo_hinh.module.
if use_multi_gpu:
    print(f"\nüöÄ Wrapping Autoencoder v·ªõi DataParallel ({n_gpus} GPUs)")
    mo_hinh = nn.DataParallel(mo_hinh)
    print(f"‚úÖ Model s·∫Ω ch·∫°y tr√™n: {mo_hinh.device_ids}")

for _dong in ("‚úÖ Autoencoder Architecture:", mo_hinh, f"\n‚úÖ Model moved to {thiet_bi}"):
    print(_dong)
del _dong
if use_multi_gpu:
    print(f"üéÆ Multi-GPU Training ENABLED: {n_gpus} GPUs ƒë·ªìng th·ªùi!")

In [None]:
# ================== PREPARE DATA ==================

# Resize every image to a fixed square; ToTensor yields float CHW tensors in [0, 1].
transform = transforms.Compose([
    transforms.Resize((kich_thuoc_anh, kich_thuoc_anh)),
    transforms.ToTensor(),
])

# Seed for the train/val split. A fixed seed keeps the split identical every time
# this cell is re-run, so validation images never drift into the training set
# between runs (the previous unseeded np.random.permutation reshuffled them).
hat_giong_chia = 42

# Load the images with ImageFolder (expects <root>/<class_name>/<image files>).
try:
    dataset_full = ImageFolder(root=thu_muc_du_lieu, transform=transform)
    print(f"‚úÖ ƒê√£ load {len(dataset_full)} ·∫£nh t·ª´ {thu_muc_du_lieu}")
except Exception as e:
    # Notebook boundary: report the problem and let later cells see None.
    print(f"‚ùå L·ªói khi load d·ªØ li·ªáu: {e}")
    print(f"üìå ƒê·∫£m b·∫£o c·∫•u tr√∫c th∆∞ m·ª•c: {thu_muc_du_lieu}/classA/ v√† classB/")
    dataset_full = None

train_loader = None
val_loader = None

if dataset_full is not None:
    # 80/20 train/val split, reproducible via hat_giong_chia.
    n = len(dataset_full)
    n_train = int(0.8 * n)
    n_val = n - n_train
    dataset_train, dataset_val = torch.utils.data.random_split(
        dataset_full,
        [n_train, n_val],
        generator=torch.Generator().manual_seed(hat_giong_chia),
    )

    print(f"üìä Train: {len(dataset_train)} | Val: {len(dataset_val)}")

    if n_train == 0 or n_val == 0:
        # A shuffled DataLoader over an empty dataset raises inside RandomSampler,
        # and an empty validation set never yields a best checkpoint.
        print("WARNING: train or validation split is empty - add more images "
              "(each split needs at least one image). DataLoaders not created.")
    else:
        # Shared DataLoader settings tuned in the config cell.
        _tham_so_loader = dict(
            batch_size=batch_size,
            num_workers=so_worker,
            pin_memory=pin_memory,                 # faster host -> GPU copies
            persistent_workers=(so_worker > 0),    # keep workers alive between epochs
        )
        train_loader = DataLoader(dataset_train, shuffle=True, **_tham_so_loader)
        val_loader = DataLoader(dataset_val, shuffle=False, **_tham_so_loader)

        print(f"‚úÖ DataLoader t·∫°o th√†nh c√¥ng")
        print(f"üí° M·ªói epoch: {len(train_loader)} batches training | {len(val_loader)} batches validation")

if train_loader is None:
    print("‚ö†Ô∏è Kh√¥ng th·ªÉ t·∫°o DataLoader")


## 4) Hàm đánh giá và lưu ảnh minh hoạ

In [None]:
# ================== H√ÄM ƒê√ÅNH GI√Å & MINH H·ªåA ==================

def psnr_tensor(x: torch.Tensor, y: torch.Tensor, max_pixel: float = 1.0) -> torch.Tensor:
    """Peak signal-to-noise ratio (dB) between ``x`` and ``y``.

    The squared error is averaged over every element of the tensors, so for a
    batch this is the PSNR of the batch-mean MSE. Returns ``+inf`` (a tensor on
    ``x``'s device) when the tensors are identical.
    """
    squared_err = (x - y).pow(2).mean()
    if squared_err == 0:
        return torch.tensor(float('inf'), device=x.device)
    peak = torch.tensor(max_pixel, device=x.device)
    return 20 * torch.log10(peak / torch.sqrt(squared_err))

def mse_tensor(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    """Mean squared error between ``x`` and ``y`` over all elements (0-dim tensor)."""
    return (x - y).pow(2).mean()

def them_nhieu(anh: torch.Tensor, loai: str = "normal", do_manh: float = 0.05) -> torch.Tensor:
    """Return a noisy copy of ``anh`` clipped to [0, 1], computed on ``anh``'s device.

    Args:
        anh: image tensor with values assumed to be in [0, 1]. Noise is applied
            element-wise, so any shape works.
        loai: noise model:
            * ``"normal"``    - additive Gaussian noise with std ``do_manh``.
            * ``"bernoulli"`` - mask/dropout noise: each element is zeroed
              independently with probability ``do_manh``.
            * ``"poisson"``   - shot noise: ``anh * 255 * do_manh`` is the Poisson
              rate, so a LARGER ``do_manh`` gives LESS relative noise.
            * ``"none"``      - no noise (only the clamp is applied).
        do_manh: noise strength; its meaning depends on ``loai`` (see above).

    Returns:
        A new tensor with the same shape and device as ``anh``.

    Raises:
        ValueError: unknown ``loai``; bernoulli ``do_manh`` outside [0, 1];
            poisson ``do_manh`` <= 0 (it would divide by zero and yield NaNs).
    """
    if loai == "normal":
        ket_qua = anh + torch.randn_like(anh) * do_manh
    elif loai == "bernoulli":
        if not 0.0 <= do_manh <= 1.0:
            raise ValueError(f"bernoulli mask rate must be in [0, 1], got {do_manh}")
        # Keep each element with probability 1 - do_manh, zero it otherwise.
        # (The previous version added bernoulli(p) - 0.5, which shifted EVERY
        #  pixel by +/-0.5 instead of masking a fraction of them.)
        giu_lai = torch.bernoulli(torch.full_like(anh, 1.0 - do_manh))
        ket_qua = anh * giu_lai
    elif loai == "poisson":
        if do_manh <= 0:
            raise ValueError(f"poisson noise strength must be > 0, got {do_manh}")
        ty_le = 255.0 * do_manh
        # Sample counts with rate anh * ty_le and rescale back to [0, 1] range;
        # equals anh + (P - anh * ty_le) / ty_le from the previous version.
        ket_qua = torch.poisson(anh * ty_le) / ty_le
    elif loai == "none":
        ket_qua = anh
    else:
        # Previously an unknown type silently produced a clean image, which
        # would make a typo in the config train an identity mapping.
        raise ValueError(f"unknown noise type {loai!r}; expected 'normal', 'bernoulli', 'poisson' or 'none'")
    return torch.clamp(ket_qua, 0, 1)

def danh_gia(mo_hinh: nn.Module, loader: DataLoader, device: torch.device,
             loai_nhieu: str, do_manh_nhieu: float, ham_loss=None) -> tuple:
    """Evaluate a denoiser on ``loader`` and return ``(psnr, mse, loss)``.

    Every clean batch from ``loader`` gets fresh noise via ``them_nhieu``. The
    model output is then compared against the clean batch. Per-batch values are
    averaged weighted by batch size, so a smaller final batch no longer counts
    as much as a full one. PSNR is computed from each batch's mean MSE.

    Args:
        mo_hinh: model to evaluate; it is switched to eval mode and left there.
        loader: yields ``(clean_images, labels)``; labels are ignored.
        device: device the batches are moved to.
        loai_nhieu, do_manh_nhieu: noise settings forwarded to ``them_nhieu``.
        ham_loss: optional loss module to report (e.g. the training loss).
            Defaults to ``nn.MSELoss()``, the previous hard-coded choice.

    Returns:
        ``(mean PSNR in dB, mean MSE, mean loss)`` as Python floats. All three
        are NaN when ``loader`` yields no batches.
    """
    loss_fn = nn.MSELoss() if ham_loss is None else ham_loss
    mo_hinh.eval()
    tong_psnr = tong_mse = tong_loss = 0.0
    so_anh = 0

    with torch.no_grad():
        for anh_sach, _ in loader:
            anh_sach = anh_sach.to(device, non_blocking=True)
            anh_nhieu = them_nhieu(anh_sach, loai_nhieu, do_manh_nhieu)
            dau_ra = mo_hinh(anh_nhieu)

            kich_thuoc_batch = anh_sach.size(0)
            tong_psnr += psnr_tensor(dau_ra, anh_sach).item() * kich_thuoc_batch
            tong_mse += mse_tensor(dau_ra, anh_sach).item() * kich_thuoc_batch
            tong_loss += loss_fn(dau_ra, anh_sach).item() * kich_thuoc_batch
            so_anh += kich_thuoc_batch

    if so_anh == 0:
        # Explicit NaNs instead of np.mean([]) warnings; NaN never beats a best score.
        nan = float("nan")
        return nan, nan, nan
    return tong_psnr / so_anh, tong_mse / so_anh, tong_loss / so_anh

def luu_minh_hoa(mo_hinh: nn.Module, anh_mau: torch.Tensor, device: torch.device,
                loai_nhieu: str, do_manh_nhieu: float, thu_muc: Path, tag: str = ""):
    """Save a 1x3 figure (noisy | restored | clean) for the first image of ``anh_mau``.

    Args:
        mo_hinh: denoising model; it is switched to eval mode.
        anh_mau: batch of clean images (N, C, H, W); only index 0 is plotted.
        device: device used for the forward pass.
        loai_nhieu, do_manh_nhieu: noise settings forwarded to ``them_nhieu``.
        thu_muc: output directory (``str`` or ``Path``); it must already exist.
        tag: optional suffix; the file is ``viz_<tag>.png`` or ``viz.png``.

    Returns:
        The ``Path`` of the written PNG.
    """
    mo_hinh.eval()
    with torch.no_grad():
        anh_sach = anh_mau.to(device)
        anh_nhieu = them_nhieu(anh_sach, loai_nhieu, do_manh_nhieu)
        anh_khuc_phuc = mo_hinh(anh_nhieu)

    duong_dan_anh = Path(thu_muc) / (f"viz_{tag}.png" if tag else "viz.png")

    fig, axes = plt.subplots(1, 3, figsize=(12, 4))
    try:
        for ax, (anh, tieu_de) in zip(axes, [
            (anh_nhieu[0], f"Nhi·ªÖu ({loai_nhieu})"),
            (anh_khuc_phuc[0], "Kh√¥i ph·ª•c"),
            (anh_sach[0], "G·ªëc"),
        ]):
            # float + clamp so imshow always receives a valid [0, 1] float image.
            hinh = anh.detach().float().clamp(0, 1).cpu()
            if hinh.shape[0] == 1:
                # Single-channel image: imshow needs (H, W), not (H, W, 1).
                ax.imshow(hinh[0].numpy(), cmap="gray", vmin=0, vmax=1)
            else:
                ax.imshow(hinh.permute(1, 2, 0).numpy())  # CHW -> HWC
            ax.set_title(tieu_de)
            ax.axis('off')

        fig.tight_layout()
        fig.savefig(duong_dan_anh, dpi=100, bbox_inches='tight')
    finally:
        # Always release the figure, even if saving fails.
        plt.close(fig)
    return duong_dan_anh


## 5) Tải dữ liệu (ImageFolder) — kiểm tra DataLoader đã tạo ở cell trước

In [None]:
# ================== DATALOADER SANITY CHECK (loaders are built in the data cell) ====================
_vach = "=" * 60
print(_vach)
print("‚úÖ KI·ªÇM TRA DATALOADER")
print(_vach)

if train_loader is None or val_loader is None:
    print("‚ùå DataLoader ch∆∞a ƒë∆∞·ª£c kh·ªüi t·∫°o! Ch·∫°y l·∫°i cell tr∆∞·ªõc.")
else:
    for _dong in (
        f"‚úÖ Train DataLoader: {len(train_loader)} batches",
        f"‚úÖ Val DataLoader: {len(val_loader)} batches",
        f"‚úÖ Batch size: {batch_size}",
        f"‚úÖ Num workers: {so_worker}",
        f"‚úÖ Pin memory: {pin_memory} (GPU t·ªëi ∆∞u)",
        "‚úÖ Non-blocking: True (GPU transfer t·ªëi ∆∞u)",
        "\nüí° DataLoader ƒë√£ ƒë∆∞·ª£c t·ªëi ∆∞u GPU ·ªü cell tr∆∞·ªõc - KH√îNG T·∫†O L·∫†I!",
    ):
        print(_dong)

print(_vach)

## 6) Huấn luyện mô hình

In [None]:
# ================== TRAIN THE AUTOENCODER (GPU + MULTI-GPU AWARE) ==================
# Trains `mo_hinh` on (noisy input -> target) pairs and validates every epoch
# with danh_gia(). It keeps the checkpoint with the best validation PSNR and
# finally renders the illustration figure from that best checkpoint.

# Fail fast with a clear message if the data cell could not build the loaders
# (iterating a None loader would otherwise raise an obscure TypeError).
if train_loader is None or val_loader is None:
    raise RuntimeError(
        "train_loader/val_loader are None - fix the dataset path and re-run the data cell."
    )

# torch.optim is referenced through `torch`: the `optim` alias is only imported
# in the first cell, which is not guaranteed to have run successfully.
bo_toi_uu = torch.optim.Adam(mo_hinh.parameters(), lr=learning_rate)

if ten_loss == "mse":
    ham_loss = nn.MSELoss()
elif ten_loss == "l1":
    ham_loss = nn.L1Loss()
else:
    raise ValueError("‚ùå ten_loss ph·∫£i l√† 'mse' ho·∫∑c 'l1'")

# Validate the target mode once, before any training work is done.
if kieu_muc_tieu not in ("clean", "noisy"):
    raise ValueError("‚ùå kieu_muc_tieu ph·∫£i l√† 'clean' ho·∫∑c 'noisy'")


def _mo_hinh_goc(m: nn.Module) -> nn.Module:
    """Return the underlying network, unwrapping nn.DataParallel if present."""
    return m.module if isinstance(m, nn.DataParallel) else m


best_psnr = -1.0
best_anh_mau = None
lich_su = {"loss_train": [], "psnr_val": [], "mse_val": []}

print("\n" + "="*60)
print("üöÄ B·∫ÆT ƒê·∫¶U HU·∫§N LUY·ªÜN AUTOENCODER")
if use_multi_gpu:
    print(f"üéÆ Training v·ªõi {n_gpus} GPUs: {[torch.cuda.get_device_name(i) for i in range(n_gpus)]}")
print("="*60)

time_start = time.time()

for epoch in range(1, so_epoch + 1):
    epoch_start = time.time()
    mo_hinh.train()
    tong_loss = 0.0
    so_anh = 0

    with tqdm(train_loader, desc=f"Epoch {epoch}/{so_epoch}", leave=False) as pbar:
        for anh_sach, _ in pbar:
            anh_sach = anh_sach.to(thiet_bi, non_blocking=True)

            # Input: noisy image. them_nhieu stays on anh_sach's device.
            anh_nhieu = them_nhieu(anh_sach, loai_nhieu, do_manh_nhieu)

            # Target: the clean image, or an independently noised copy ("noisy").
            if kieu_muc_tieu == "clean":
                muc_tieu = anh_sach
            else:
                muc_tieu = them_nhieu(anh_sach, loai_nhieu, do_manh_nhieu)

            bo_toi_uu.zero_grad()
            if use_amp:
                with autocast():
                    anh_tai_tao = mo_hinh(anh_nhieu)
                    loss = ham_loss(anh_tai_tao, muc_tieu)
                scaler.scale(loss).backward()
                # Unscale first so gradient clipping sees the true gradient norm.
                scaler.unscale_(bo_toi_uu)
                torch.nn.utils.clip_grad_norm_(mo_hinh.parameters(), 1.0)
                scaler.step(bo_toi_uu)
                scaler.update()
            else:
                # Standard training; DataParallel (if used) splits the batch across GPUs.
                anh_tai_tao = mo_hinh(anh_nhieu)
                loss = ham_loss(anh_tai_tao, muc_tieu)
                loss.backward()
                torch.nn.utils.clip_grad_norm_(mo_hinh.parameters(), 1.0)
                bo_toi_uu.step()

            gia_tri_loss = loss.item()
            tong_loss += gia_tri_loss * anh_sach.size(0)
            so_anh += anh_sach.size(0)
            pbar.set_postfix({'loss': f'{gia_tri_loss:.6f}'})

    # Sample-weighted mean training loss (guarded against an empty epoch).
    loss_train = tong_loss / max(so_anh, 1)

    # Validate
    psnr_val, mse_val, _ = danh_gia(mo_hinh, val_loader, thiet_bi, loai_nhieu, do_manh_nhieu)

    lich_su["loss_train"].append(loss_train)
    lich_su["psnr_val"].append(psnr_val)
    lich_su["mse_val"].append(mse_val)

    epoch_time = time.time() - epoch_start

    print(f"[Epoch {epoch:02d}] ({epoch_time:.1f}s) "
          f"Loss: {loss_train:.6f} | "
          f"PSNR: {psnr_val:.2f}dB | "
          f"MSE: {mse_val:.6f}")

    # Save only when validation PSNR improves.
    if psnr_val > best_psnr:
        best_psnr = psnr_val
        best_anh_mau, _ = next(iter(val_loader))

        # Save the unwrapped network's weights so the checkpoint loads without DataParallel.
        torch.save(_mo_hinh_goc(mo_hinh).state_dict(), duong_dan_checkpoint_ae)

        print(f"  ‚úÖ PSNR t·ªët nh·∫•t m·ªõi: {best_psnr:.4f}dB - Checkpoint l∆∞u th√†nh c√¥ng!")

total_time = time.time() - time_start
print("\n" + "="*60)
print("‚úÖ HO√ÄN TH√ÄNH H·ªåC AUTOENCODER")
print("="*60)
print(f"Best PSNR: {best_psnr:.4f}")
print(f"Th·ªùi gian: {total_time/60:.1f} ph√∫t ({total_time/3600:.2f} gi·ªù)")
print(f"Thi·∫øt b·ªã: {thiet_bi} {'(GPU)' if use_cuda else '(CPU)'}")
if use_multi_gpu:
    print(f"üéÆ ƒê√£ s·ª≠ d·ª•ng {n_gpus} GPUs ƒë·ªìng th·ªùi!")
print("="*60)

# Save the illustration ONLY from the best model.
if best_anh_mau is not None:
    # `mo_hinh` still holds the LAST epoch's weights, which need not be the best.
    # Reload the best checkpoint so the figure (and later cells) use it.
    _mo_hinh_goc(mo_hinh).load_state_dict(
        torch.load(duong_dan_checkpoint_ae, map_location=thiet_bi)
    )
    print("\n" + "="*60)
    print("üñºÔ∏è L∆ØU·∫¢NH MINH H·ªåA T·ª™ MODEL T·ªêT NH·∫§T")
    print("="*60)
    luu_minh_hoa(mo_hinh, best_anh_mau, thiet_bi, loai_nhieu, do_manh_nhieu, thu_muc_ket_qua, tag="best_autoencoder")
    print(f"‚úÖ ƒê√£ l∆∞u ·∫£nh ch·∫•t l∆∞·ª£ng cao nh·∫•t (PSNR={best_psnr:.4f}dB)")
    print("="*60)

# Release cached GPU memory held by the allocator.
if use_cuda:
    torch.cuda.empty_cache()
    print("\nüíæ GPU cache ƒë√£ ƒë∆∞·ª£c x√≥a")

## 7) Vẽ biểu đồ lịch sử huấn luyện

In [None]:
# Plot the per-epoch training history, one figure per metric. The x-axis uses
# 1-based epoch numbers to match the "[Epoch NN]" log lines (plotting the bare
# list started the axis at 0).
for _khoa, _tieu_de, _nhan_y in (
    ("loss_train", "Loss hu·∫•n luy·ªán (Train Loss)", "Loss"),
    ("psnr_val", "PSNR validation (dB)", "PSNR (dB)"),
    ("mse_val", "MSE validation", "MSE"),
):
    _gia_tri = lich_su[_khoa]
    plt.figure()
    plt.plot(range(1, len(_gia_tri) + 1), _gia_tri, marker="o")
    plt.title(_tieu_de)
    plt.xlabel("Epoch")
    plt.ylabel(_nhan_y)
    plt.show()

## 8) (Tuỳ chọn) Load lại checkpoint tốt nhất và lưu ảnh minh hoạ

## 9) Image Denoising với GAN (Generator + Discriminator)

GAN Architecture:
- **Generator**: Mạng sinh ảnh sạch từ ảnh nhiễu
- **Discriminator**: Phân biệt ảnh sạch thực vs ảnh sinh ra
- **Loss**: Adversarial Loss + Reconstruction Loss (L1)


In [None]:
# ==================== KI·∫æN TR√öC GAN ====================

class Generator(nn.Module):
    """GAN generator: a conv encoder/decoder with the same layout as ``Autoencoder``.

    Four 3x3 stride-2 Conv+BN+ReLU stages downsample (channels k -> 8k). Four
    4x4 stride-2 transposed convolutions upsample back to ``so_kenh_vao``
    channels, ending in a Sigmoid, so outputs lie in (0, 1). This is the same
    output activation as the autoencoder.
    """

    def __init__(self, so_kenh_vao: int = 3, kenh_co_so: int = 32):
        super().__init__()
        k = kenh_co_so

        lop_ma_hoa = []
        for c_vao, c_ra in ((so_kenh_vao, k), (k, 2 * k), (2 * k, 4 * k), (4 * k, 8 * k)):
            lop_ma_hoa.extend((
                nn.Conv2d(c_vao, c_ra, 3, stride=2, padding=1),
                nn.BatchNorm2d(c_ra),
                nn.ReLU(True),
            ))
        self.encoder = nn.Sequential(*lop_ma_hoa)

        lop_giai_ma = []
        for c_vao, c_ra in ((8 * k, 4 * k), (4 * k, 2 * k), (2 * k, k)):
            lop_giai_ma.extend((
                nn.ConvTranspose2d(c_vao, c_ra, 4, stride=2, padding=1),
                nn.BatchNorm2d(c_ra),
                nn.ReLU(True),
            ))
        lop_giai_ma.extend((
            nn.ConvTranspose2d(k, so_kenh_vao, 4, stride=2, padding=1),
            nn.Sigmoid(),
        ))
        self.decoder = nn.Sequential(*lop_giai_ma)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Map a (noisy) batch to its reconstruction."""
        dac_trung = self.encoder(x)
        return self.decoder(dac_trung)


class Discriminator(nn.Module):
    """CNN critic that scores images as real vs. generated (one raw logit per image).

    Four 4x4 stride-2 convolutions with LeakyReLU(0.2); BatchNorm on all but the
    first. Global average pooling and a linear layer give an (N, 1) logit tensor.
    """

    def __init__(self, so_kenh_vao: int = 3, kenh_co_so: int = 32):
        super().__init__()
        lop = [
            nn.Conv2d(so_kenh_vao, kenh_co_so, 4, stride=2, padding=1),
            nn.LeakyReLU(0.2, True),
        ]
        so_kenh = kenh_co_so
        for _ in range(3):
            lop += [
                nn.Conv2d(so_kenh, so_kenh * 2, 4, stride=2, padding=1),
                nn.BatchNorm2d(so_kenh * 2),
                nn.LeakyReLU(0.2, True),
            ]
            so_kenh *= 2
        lop.append(nn.AdaptiveAvgPool2d((1, 1)))
        self.net = nn.Sequential(*lop)
        self.fc = nn.Linear(so_kenh, 1)  # so_kenh == 8 * kenh_co_so here

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return unnormalised real/fake logits of shape (N, 1)."""
        gop = self.net(x)
        return self.fc(gop.view(gop.size(0), -1))


# Instantiate the GAN pair on the selected device. The generator is built before
# the discriminator, so seeded parameter initialisation is unchanged.
generator_gan = Generator(so_kenh_vao=3, kenh_co_so=so_kenh_co_so).to(thiet_bi)
discriminator_gan = Discriminator(so_kenh_vao=3, kenh_co_so=so_kenh_co_so).to(thiet_bi)

# Multi-GPU: wrap both networks; their weights are then under `.module`.
if use_multi_gpu:
    print(f"\nüöÄ Wrapping GAN v·ªõi DataParallel ({n_gpus} GPUs)")
    generator_gan, discriminator_gan = nn.DataParallel(generator_gan), nn.DataParallel(discriminator_gan)
    print(f"‚úÖ Generator device_ids: {generator_gan.device_ids}")
    print(f"‚úÖ Discriminator device_ids: {discriminator_gan.device_ids}")

for _tieu_de, _mang in (("‚úÖ Generator:", generator_gan), ("\n‚úÖ Discriminator:", discriminator_gan)):
    print(_tieu_de)
    print(_mang)
del _tieu_de, _mang

if use_multi_gpu:
    print(f"\nüéÆ Multi-GPU GAN Training ENABLED: {n_gpus} GPUs ƒë·ªìng th·ªùi!")

In [None]:
# ==================== H√ÄM H·ªñ TR·ª¢ GAN (T·ªêI ∆ØU GPU + MULTI-GPU) ====================

def train_gan(generator: nn.Module, discriminator: nn.Module, 
              train_loader: DataLoader, val_loader: DataLoader,
              epochs: int, device: torch.device, lambda_recon: float = 100.0) -> tuple:
    """Hu·∫•n luy·ªán GAN cho denoising v·ªõi Adversarial Loss + Reconstruction Loss.
    
    T·ªëi ∆∞u h√≥a cho GPU:
    - Mixed Precision Training (FP16)
    - Multi-GPU support v·ªõi DataParallel
    - GPU memory efficient training
    """
    
    optim_g = torch.optim.Adam(generator.parameters(), lr=2e-4, betas=(0.5, 0.999))
    optim_d = torch.optim.Adam(discriminator.parameters(), lr=2e-4, betas=(0.5, 0.999))
    
    loss_gan = nn.BCEWithLogitsLoss()
    loss_recon = nn.L1Loss()
    
    history = {'g_loss': [], 'd_loss': [], 'psnr_val': [], 'mse_val': [], 'best_psnr': 0}
    best_psnr = 0.0
    best_checkpoint_gan = None
    
    print("\n" + "="*60)
    print("üöÄ B·∫ÆT ƒê·∫¶U HU·∫§N LUY·ªÜN GAN")
    if use_multi_gpu:
        print(f"üéÆ Training v·ªõi {n_gpus} GPUs: {[torch.cuda.get_device_name(i) for i in range(n_gpus)]}")
    print("="*60)
    
    # Start timing
    time_start = time.time()
    
    for epoch in range(epochs):
        generator.train()
        discriminator.train()
        
        g_loss_epoch, d_loss_epoch = 0.0, 0.0
        so_batches = 0
        epoch_start = time.time()
        
        with tqdm(train_loader, desc=f"Epoch {epoch+1}/{epochs}", leave=False) as pbar:
            for batch_idx, (img_nhieu, _) in enumerate(pbar):
                img_nhieu = img_nhieu.to(device, non_blocking=True)  # Non-blocking transfer
                img_sach = img_nhieu.clone()
                B = img_nhieu.size(0)
                
                # ============ C·∫¨P NH·∫¨T DISCRIMINATOR ============
                optim_d.zero_grad()
                
                # S·ª≠ d·ª•ng autocast n·∫øu mixed precision ƒë∆∞·ª£c b·∫≠t
                if use_amp:
                    with autocast():
                        # Loss discriminator tr√™n ·∫£nh th·∫≠t
                        label_thuc = torch.ones(B, 1, device=device)
                        output_thuc = discriminator(img_sach)
                        d_loss_thuc = loss_gan(output_thuc, label_thuc)
                        
                        # Loss discriminator tr√™n ·∫£nh gi·∫£
                        with torch.no_grad():
                            img_tao = generator(img_nhieu)
                        label_gia = torch.zeros(B, 1, device=device)
                        output_gia = discriminator(img_tao.detach())
                        d_loss_gia = loss_gan(output_gia, label_gia)
                        
                        d_loss = d_loss_thuc + d_loss_gia
                    
                    scaler.scale(d_loss).backward()
                    scaler.unscale_(optim_d)
                    torch.nn.utils.clip_grad_norm_(discriminator.parameters(), 1.0)
                    scaler.step(optim_d)
                    scaler.update()
                else:
                    # Loss discriminator tr√™n ·∫£nh th·∫≠t - DataParallel t·ª± ƒë·ªông ph√¢n chia
                    label_thuc = torch.ones(B, 1, device=device)
                    output_thuc = discriminator(img_sach)
                    d_loss_thuc = loss_gan(output_thuc, label_thuc)
                    
                    # Loss discriminator tr√™n ·∫£nh gi·∫£
                    with torch.no_grad():
                        img_tao = generator(img_nhieu)
                    label_gia = torch.zeros(B, 1, device=device)
                    output_gia = discriminator(img_tao.detach())
                    d_loss_gia = loss_gan(output_gia, label_gia)
                    
                    d_loss = d_loss_thuc + d_loss_gia
                    d_loss.backward()
                    torch.nn.utils.clip_grad_norm_(discriminator.parameters(), 1.0)
                    optim_d.step()
                
                # ============ C·∫¨P NH·∫¨T GENERATOR ============
                optim_g.zero_grad()
                
                if use_amp:
                    with autocast():
                        img_tao = generator(img_nhieu)
                        
                        # Adversarial loss
                        output_gia = discriminator(img_tao)
                        g_loss_adv = loss_gan(output_gia, label_thuc)
                        
                        # Reconstruction loss
                        g_loss_recon = loss_recon(img_tao, img_sach)
                        
                        # T·ªïng loss
                        g_loss = g_loss_adv + lambda_recon * g_loss_recon
                    
                    scaler.scale(g_loss).backward()
                    scaler.unscale_(optim_g)
                    torch.nn.utils.clip_grad_norm_(generator.parameters(), 1.0)
                    scaler.step(optim_g)
                    scaler.update()
                else:
                    img_tao = generator(img_nhieu)
                    
                    # Adversarial loss
                    output_gia = discriminator(img_tao)
                    g_loss_adv = loss_gan(output_gia, label_thuc)
                    
                    # Reconstruction loss
                    g_loss_recon = loss_recon(img_tao, img_sach)
                    
                    # T·ªïng loss
                    g_loss = g_loss_adv + lambda_recon * g_loss_recon
                    g_loss.backward()
                    torch.nn.utils.clip_grad_norm_(generator.parameters(), 1.0)
                    optim_g.step()
                
                g_loss_epoch += g_loss.item()
                d_loss_epoch += d_loss.item()
                so_batches += 1
                
                pbar.set_postfix({'G': f'{g_loss.item():.3f}', 'D': f'{d_loss.item():.3f}'})
        
        # Trung b√¨nh loss
        g_loss_epoch /= so_batches
        d_loss_epoch /= so_batches
        history['g_loss'].append(g_loss_epoch)
        history['d_loss'].append(d_loss_epoch)
        
        # ============ ƒê√ÅNH GI√Å VALIDATION ============
        generator.eval()
        psnr_vals, mse_vals = [], []
        
        with torch.no_grad():
            for img_nhieu, _ in val_loader:
                img_nhieu = img_nhieu.to(device, non_blocking=True)
                img_sach = img_nhieu.clone()
                img_tao = generator(img_nhieu)
                
                psnr_vals.append(psnr_tensor(img_tao, img_sach).item())
                mse_vals.append(mse_tensor(img_tao, img_sach).item())
        
        psnr_mean = np.mean(psnr_vals)
        mse_mean = np.mean(mse_vals)
        history['psnr_val'].append(psnr_mean)
        history['mse_val'].append(mse_mean)
        
        # T√≠nh th·ªùi gian epoch
        epoch_time = time.time() - epoch_start
        
        # L∆∞u checkpoint t·ªët nh·∫•t
        if psnr_mean > best_psnr:
            best_psnr = psnr_mean
            
            # L∆∞u state_dict - n·∫øu d√πng DataParallel, l∆∞u module.state_dict()
            if use_multi_gpu:
                best_checkpoint_gan = {
                    'generator': generator.module.state_dict(),
                    'discriminator': discriminator.module.state_dict(),
                    'epoch': epoch
                }
                torch.save(generator.module.state_dict(), duong_dan_checkpoint_gan_g)
                torch.save(discriminator.module.state_dict(), duong_dan_checkpoint_gan_d)
            else:
                best_checkpoint_gan = {
                    'generator': generator.state_dict(),
                    'discriminator': discriminator.state_dict(),
                    'epoch': epoch
                }
                torch.save(generator.state_dict(), duong_dan_checkpoint_gan_g)
                torch.save(discriminator.state_dict(), duong_dan_checkpoint_gan_d)
            
            history['best_psnr'] = best_psnr
            
            print(f"‚úÖ Epoch {epoch+1} ({epoch_time:.1f}s): G={g_loss_epoch:.4f}, D={d_loss_epoch:.4f}, "
                  f"PSNR={psnr_mean:.3f}dB [SAVED]")
        else:
            print(f"   Epoch {epoch+1} ({epoch_time:.1f}s): G={g_loss_epoch:.4f}, D={d_loss_epoch:.4f}, "
                  f"PSNR={psnr_mean:.3f}dB")
    
    if best_checkpoint_gan:
        if use_multi_gpu:
            generator.module.load_state_dict(best_checkpoint_gan['generator'])
            discriminator.module.load_state_dict(best_checkpoint_gan['discriminator'])
        else:
            generator.load_state_dict(best_checkpoint_gan['generator'])
            discriminator.load_state_dict(best_checkpoint_gan['discriminator'])
    
    total_time = time.time() - time_start
    print("="*60)
    print(f"üéâ HO√ÄN TH√ÄNH H·ªåC GAN")
    print(f"   Best PSNR: {best_psnr:.4f}")
    print(f"   Th·ªùi gian: {total_time/60:.1f} ph√∫t ({total_time/3600:.2f} gi·ªù)")
    print(f"   Thi·∫øt b·ªã: {device}")
    if use_multi_gpu:
        print(f"   üéÆ ƒê√£ s·ª≠ d·ª•ng {n_gpus} GPUs ƒë·ªìng th·ªùi!")
    print("="*60)
    
    return history, generator, discriminator

In [None]:
# ==================== RUN GAN TRAINING ====================
# train_gan returns (history, generator, discriminator); when a best-PSNR
# checkpoint was recorded during training, the returned models carry those
# best weights. The results rebind the notebook-level names used by later cells.
history_gan, generator_gan, discriminator_gan = train_gan(
    generator_gan, discriminator_gan, train_loader, val_loader,
    epochs=so_epoch, device=thiet_bi, lambda_recon=100.0,
)


In [None]:
# ==================== TRAINING-HISTORY COMPARISON PLOTS ====================
# 2x2 figure: GAN generator loss and discriminator loss (top row), and
# validation PSNR / MSE of the autoencoder (`lich_su`) versus the GAN
# (`history_gan`) (bottom row). The figure is saved to
# `thu_muc_ket_qua / 'comparison_training.png'` and then displayed.
#
# The x-axis counts epochs from 1 so it matches the "Epoch {epoch+1}" numbers
# printed by train_gan (matplotlib's default x positions would start at 0).

fig, axes = plt.subplots(2, 2, figsize=(14, 10))
fig.suptitle('So s√°nh qu√° tr√¨nh hu·∫•n luy·ªán: Autoencoder vs GAN', fontsize=16, fontweight='bold')

# One entry per panel: (row, col, y-label, title, curves); each curve is
# (values, matplotlib format string, legend label, extra plot kwargs).
panel_specs = [
    (0, 0, 'Loss', 'Generator Loss qua c√°c epoch',
     [(history_gan['g_loss'], 'b-', 'Generator Loss', {})]),
    (0, 1, 'Loss', 'Discriminator Loss qua c√°c epoch',
     [(history_gan['d_loss'], 'r-', 'Discriminator Loss', {})]),
    (1, 0, 'PSNR (dB)', 'PSNR tr√™n validation set',
     [(lich_su['psnr_val'], 'g-o', 'Autoencoder', {'markersize': 5}),
      (history_gan['psnr_val'], 'b-s', 'GAN', {'markersize': 5})]),
    (1, 1, 'MSE', 'MSE tr√™n validation set',
     [(lich_su['mse_val'], 'g-o', 'Autoencoder', {'markersize': 5}),
      (history_gan['mse_val'], 'b-s', 'GAN', {'markersize': 5})]),
]

for row, col, ylabel, title, curves in panel_specs:
    ax = axes[row, col]
    for values, fmt, label, extra in curves:
        # 1-based epoch numbers on the x-axis (see note above).
        ax.plot(range(1, len(values) + 1), values, fmt, linewidth=2, label=label, **extra)
    ax.set_xlabel('Epoch')
    ax.set_ylabel(ylabel)
    ax.set_title(title)
    ax.grid(True, alpha=0.3)
    ax.legend()

plt.tight_layout()
plt.savefig(thu_muc_ket_qua / 'comparison_training.png', dpi=150, bbox_inches='tight')
plt.show()

print("‚úÖ ƒê√£ l∆∞u bi·ªÉu ƒë·ªì so s√°nh hu·∫•n luy·ªán")


In [None]:
# ==================== SAVE SAMPLE IMAGES FROM THE BEST GAN MODEL ====================
# train_gan reloads the best-PSNR weights (when one was recorded) before
# returning, so generator_gan here is that best generator.
banner = "=" * 60
print("\n" + banner)
print("üñºÔ∏è L∆ØU ·∫¢NH MINH H·ªåA T·ª™ GAN T·ªêT NH·∫§T")
print(banner)

# Use only the first validation batch as the illustration input.
val_batches = iter(val_loader)
best_gan_sample, _ = next(val_batches)
luu_minh_hoa(
    generator_gan, best_gan_sample, thiet_bi,
    loai_nhieu, do_manh_nhieu, thu_muc_ket_qua,
    tag="best_gan_generator",
)

best_psnr_gan = history_gan['best_psnr']
print(f"‚úÖ ƒê√£ l∆∞u ·∫£nh GAN ch·∫•t l∆∞·ª£ng cao nh·∫•t (PSNR={best_psnr_gan:.4f}dB)")
print(banner)

# Return cached, currently unused GPU memory held by PyTorch's allocator.
if use_cuda:
    torch.cuda.empty_cache()
    print("\nüíæ GPU cache ƒë√£ ƒë∆∞·ª£c x√≥a")


In [None]:
# ==================== BẢNG SO SÁNH CHUYÊN SÂU ====================

# T√≠nh ƒë·ªô ƒëo tr√™n to√†n b·ªô validation set
def evaluate_model(model: nn.Module, val_loader: DataLoader, device: torch.device,
                   seed: "int | None" = 0) -> dict:
    """Evaluate a denoising model on the validation set.

    Every batch yielded by ``val_loader`` is treated as the clean target. It is
    corrupted with ``them_nhieu`` using the notebook-level ``loai_nhieu`` /
    ``do_manh_nhieu`` settings, denoised by ``model``, and compared with the
    clean batch.

    Args:
        model: Denoising network; it is switched to ``eval()`` mode.
        val_loader: Loader yielding ``(images, labels)`` batches; labels are ignored.
        device: Device the images are moved to before inference.
        seed: If not ``None``, the torch (CPU and all CUDA devices) and NumPy
            global RNGs are seeded with it for the duration of this call and
            restored afterwards. Repeated calls (e.g. one for the autoencoder and
            one for the GAN) therefore see the same noisy inputs, presumably
            because ``them_nhieu`` and the loader's shuffling draw from those
            RNGs (TODO confirm for ``them_nhieu``). Pass ``None`` to use the
            current, unseeded RNG state, as in the original version.

    Returns:
        dict with ``psnr_mean``/``psnr_std``, ``mse_mean``/``mse_std`` and
        ``l1_mean``/``l1_std``; the statistics are taken over per-batch values.
    """
    model.eval()
    psnr_vals, mse_vals, l1_vals = [], [], []

    seeded = seed is not None
    # Snapshot NumPy's global RNG so the seeding below does not leak into later cells.
    np_state = np.random.get_state() if seeded else None
    # device_count() is 0 without CUDA, so only the CPU RNG is forked in that case.
    cuda_devices = list(range(torch.cuda.device_count()))
    try:
        # fork_rng restores the torch CPU/CUDA RNG states on exit; it is a
        # no-op when enabled=False.
        with torch.random.fork_rng(devices=cuda_devices, enabled=seeded), torch.no_grad():
            if seeded:
                # Seed before iterating so that a shuffled loader's order is reproducible too.
                torch.manual_seed(seed)
                np.random.seed(seed)
            for img_clean, _ in val_loader:
                img_clean = img_clean.to(device)
                img_noisy = them_nhieu(img_clean, loai_nhieu, do_manh_nhieu).to(device, non_blocking=True)
                img_denoised = model(img_noisy)

                psnr_vals.append(psnr_tensor(img_denoised, img_clean).item())
                mse_vals.append(mse_tensor(img_denoised, img_clean).item())
                l1_vals.append(torch.nn.functional.l1_loss(img_denoised, img_clean).item())
    finally:
        if np_state is not None:
            np.random.set_state(np_state)

    return {
        'psnr_mean': np.mean(psnr_vals),
        'psnr_std': np.std(psnr_vals),
        'mse_mean': np.mean(mse_vals),
        'mse_std': np.std(mse_vals),
        'l1_mean': np.mean(l1_vals),
        'l1_std': np.std(l1_vals),
    }

# Score both denoisers on the validation set.
metrics_ae = evaluate_model(mo_hinh, val_loader, thiet_bi)
metrics_gan = evaluate_model(generator_gan, val_loader, thiet_bi)

# Print the comparison table: one row per metric, "mean ± std" for each model.
wide_rule = "=" * 80
print("\n" + wide_rule)
print(" " * 20 + "üìä B·∫¢NG SO S√ÅNH AUTOENCODER vs GAN")
print(wide_rule)
print(f"{'Ch·ªâ s·ªë':25} {'Autoencoder':^25} {'GAN':^25}")
print("-" * 80)

# (row label, metrics-dict key prefix, decimal places)
table_rows = [
    ('PSNR (dB)', 'psnr', 4),
    ('MSE', 'mse', 6),
    ('L1 Loss', 'l1', 6),
]
for row_label, key, prec in table_rows:
    ae_cell = f"{metrics_ae[key + '_mean']:>10.{prec}f} ¬± {metrics_ae[key + '_std']:>6.{prec}f}"
    gan_cell = f"{metrics_gan[key + '_mean']:>10.{prec}f} ¬± {metrics_gan[key + '_std']:>6.{prec}f}"
    print(f"{row_label:25} {ae_cell}   |   {gan_cell}")
print(wide_rule)

# Declare the winner by mean PSNR; ties go to the autoencoder.
if metrics_gan['psnr_mean'] > metrics_ae['psnr_mean']:
    winner, margin = 'GAN', metrics_gan['psnr_mean'] - metrics_ae['psnr_mean']
else:
    winner, margin = 'Autoencoder', metrics_ae['psnr_mean'] - metrics_gan['psnr_mean']
print(f"üèÜ K·∫æT LU·∫¨N: {winner} v∆∞·ª£t tr·ªôi h∆°n v·ªõi PSNR cao h∆°n {margin:.4f} dB")

print("\nüìà T·∫•t c·∫£ ch·ªâ s·ªë chi ti·∫øt:")
print(f"   Autoencoder: {metrics_ae}")
print(f"   GAN:         {metrics_gan}")


In [None]:
# ==================== H√åNH ·∫¢NH SO S√ÅNH TR·ª∞C QUAN ====================

def visualize_comparison(val_loader: DataLoader, autoencoder: nn.Module,
                        generator: nn.Module, device: torch.device, n_samples: int = 4):
    """Show and save a side-by-side comparison of autoencoder and GAN denoising.

    The first ``n_samples`` images produced by ``val_loader`` are treated as
    clean. Noise is added with ``them_nhieu`` using the notebook-level
    ``loai_nhieu`` / ``do_manh_nhieu`` settings, and each image is denoised by
    both models. Each figure row shows the noisy input, the clean image, the
    autoencoder output (with its PSNR) and the GAN output (with its PSNR). The
    figure is saved to ``thu_muc_ket_qua / 'ae_vs_gan_comparison.png'`` and
    displayed.

    Args:
        val_loader: Loader yielding ``(images, labels)`` batches; labels are ignored.
        autoencoder: Autoencoder denoiser; switched to ``eval()`` mode.
        generator: GAN generator; switched to ``eval()`` mode.
        device: Device images are moved to before inference.
        n_samples: Number of images (figure rows) to show; must be >= 1.

    Raises:
        ValueError: If ``n_samples`` is smaller than 1.
    """
    if n_samples < 1:
        raise ValueError(f"n_samples must be >= 1, got {n_samples}")

    autoencoder.eval()
    generator.eval()

    # squeeze=False keeps `axes` 2-D even for a single row; with the default
    # squeezing, n_samples == 1 returns a 1-D array and axes[row, col] fails.
    fig, axes = plt.subplots(n_samples, 4, figsize=(16, 4 * n_samples), squeeze=False)
    fig.suptitle('So s√°nh k·∫øt qu·∫£ Autoencoder vs GAN tr√™n validation set',
                 fontsize=14, fontweight='bold', y=0.995)

    def to_display(tensor):
        """Turn a (1, C, H, W) tensor into an array accepted by ``imshow``."""
        # Clip to [0, 1]: imshow clips float RGB data to this range anyway (with a
        # warning). Presumably images are scaled to [0, 1] — TODO confirm.
        img = tensor.squeeze(0).clamp(0.0, 1.0).permute(1, 2, 0).cpu().numpy()
        # imshow rejects (H, W, 1), so drop the channel axis for single-channel images.
        if img.shape[-1] == 1:
            img = img[..., 0]
        return img

    def show(ax, img, title, **title_kwargs):
        """Draw ``img`` on ``ax`` with a title and hide the axis decorations."""
        if img.ndim == 2:
            # Grayscale: fixed [0, 1] gray mapping instead of an auto-scaled colormap.
            ax.imshow(img, cmap='gray', vmin=0.0, vmax=1.0)
        else:
            ax.imshow(img)
        ax.set_title(title, fontsize=10, **title_kwargs)
        ax.axis('off')

    sample_count = 0

    with torch.no_grad():
        for img_clean, _ in val_loader:
            if sample_count >= n_samples:
                break

            img_clean = img_clean.to(device)
            img_noisy = them_nhieu(img_clean, loai_nhieu, do_manh_nhieu).to(device, non_blocking=True)

            for i in range(img_clean.size(0)):
                if sample_count >= n_samples:
                    break

                # Keep a batch dimension of 1 for the forward passes.
                noisy_i = img_noisy[i:i + 1]
                clean_i = img_clean[i:i + 1]
                ae_output = autoencoder(noisy_i)
                gan_output = generator(noisy_i)

                psnr_ae = psnr_tensor(ae_output, clean_i).item()
                psnr_gan = psnr_tensor(gan_output, clean_i).item()

                row_axes = axes[sample_count]
                show(row_axes[0], to_display(noisy_i),
                     f'·∫¢nh nhi·ªÖu\n({loai_nhieu}, {do_manh_nhieu:.2f})')
                show(row_axes[1], to_display(clean_i), '·∫¢nh g·ªëc (Ground Truth)')
                show(row_axes[2], to_display(ae_output), f'Autoencoder\nPSNR={psnr_ae:.2f}dB',
                     fontweight='bold', color='green')
                show(row_axes[3], to_display(gan_output), f'GAN\nPSNR={psnr_gan:.2f}dB',
                     fontweight='bold', color='blue')

                sample_count += 1

    # If the loader ran out before n_samples images, blank the unused rows.
    for unused_row in axes[sample_count:]:
        for ax in unused_row:
            ax.axis('off')

    plt.tight_layout()
    plt.savefig(thu_muc_ket_qua / 'ae_vs_gan_comparison.png', dpi=150, bbox_inches='tight')
    plt.show()

    print(f"‚úÖ ƒê√£ l∆∞u h√¨nh ·∫£nh so s√°nh (ƒë√£ so s√°nh {sample_count} m·∫´u)")

# Draw (and save) the side-by-side AE vs GAN comparison for 4 validation images.
visualize_comparison(
    val_loader,
    autoencoder=mo_hinh,
    generator=generator_gan,
    device=thiet_bi,
    n_samples=4,
)
