In [1]:
!pip install yacs

Collecting yacs
  Downloading yacs-0.1.8-py3-none-any.whl.metadata (639 bytes)
Downloading yacs-0.1.8-py3-none-any.whl (14 kB)
Installing collected packages: yacs
Successfully installed yacs-0.1.8


In [2]:
!pip install warmup_scheduler

Collecting warmup_scheduler
  Downloading warmup_scheduler-0.3.tar.gz (2.1 kB)
  Preparing metadata (setup.py) ... [?25l[?25hdone
Building wheels for collected packages: warmup_scheduler
  Building wheel for warmup_scheduler (setup.py) ... [?25l[?25hdone
  Created wheel for warmup_scheduler: filename=warmup_scheduler-0.3-py3-none-any.whl size=2971 sha256=9a981412c9bdd2fb647fda50ebab7f7cc1b3cec91ba7d1e279e3c9059dfa7476
  Stored in directory: /root/.cache/pip/wheels/cc/5c/3b/6e5033100e0e4191383dad5c4279638a37f9791d1af9e1d85c
Successfully built warmup_scheduler
Installing collected packages: warmup_scheduler
Successfully installed warmup_scheduler-0.3


In [3]:
# ================================
# IMPORTS ET CONFIGURATION
# ================================

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset
import torchvision.transforms as transforms
from PIL import Image
import os
import time
import numpy as np
from skimage.metrics import peak_signal_noise_ratio as psnr
from skimage.metrics import structural_similarity as ssim
import sys

print("‚úÖ Imports termin√©s")

‚úÖ Imports termin√©s


In [4]:
# ================================
# FONCTIONS DE PERTE
# ================================

class CharbonnierLoss(nn.Module):
    """Perte Charbonnier (L1 smooth)"""
    def __init__(self, eps=1e-3):
        super(CharbonnierLoss, self).__init__()
        self.eps = eps

    def forward(self, x, y):
        diff = x - y
        loss = torch.mean(torch.sqrt((diff * diff) + (self.eps*self.eps)))
        return loss

class EdgeLoss(nn.Module):
    """Perte bas√©e sur les contours (Laplacien)"""
    def __init__(self):
        super(EdgeLoss, self).__init__()
        k = torch.Tensor([[.05, .25, .4, .25, .05]])
        self.kernel = torch.matmul(k.t(),k).unsqueeze(0).repeat(3,1,1,1)
        if torch.cuda.is_available():
            self.kernel = self.kernel.cuda()
        self.loss = CharbonnierLoss()

    def conv_gauss(self, img):
        n_channels, _, kw, kh = self.kernel.shape
        img = F.pad(img, (kw//2, kh//2, kw//2, kh//2), mode='replicate')
        return F.conv2d(img, self.kernel, groups=n_channels)

    def laplacian_kernel(self, current):
        filtered    = self.conv_gauss(current)
        down        = filtered[:,:,::2,::2]
        new_filter  = torch.zeros_like(filtered)
        new_filter[:,:,::2,::2] = down*4
        filtered    = self.conv_gauss(new_filter)
        diff = current - filtered
        return diff

    def forward(self, x, y):
        loss1 = self.loss(x, y)
        loss2 = self.loss(self.laplacian_kernel(x), self.laplacian_kernel(y))
        return loss1 + 0.1 * loss2

print("‚úÖ Fonctions de perte d√©finies")

‚úÖ Fonctions de perte d√©finies


In [5]:
# ================================
# ARCHITECTURE MPRNet L√âG√àRE
# ================================

class ChannelAttentionBlock(nn.Module):
    """Channel Attention Block simplifi√©"""
    def __init__(self, channels, reduction=4):
        super(ChannelAttentionBlock, self).__init__()
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.max_pool = nn.AdaptiveAvgPool2d(1)
        self.fc = nn.Sequential(
            nn.Linear(channels, channels // reduction, bias=False),
            nn.ReLU(inplace=True),
            nn.Linear(channels // reduction, channels, bias=False)
        )
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        b, c, _, _ = x.size()
        avg_out = self.fc(self.avg_pool(x).view(b, c))
        max_out = self.fc(self.max_pool(x).view(b, c))
        out = avg_out + max_out
        return x * self.sigmoid(out).view(b, c, 1, 1)

class LightweightImprovedMPRNet(nn.Module):
    """MPRNet ultra-l√©ger pour Kaggle"""
    def __init__(self, n_feat=8, scale_unetfeats=4, scale_orsnetfeats=2, num_cab=1):
        super(LightweightImprovedMPRNet, self).__init__()
        
        # Param√®tres ultra-minimaux
        self.n_feat = n_feat
        self.scale_unetfeats = scale_unetfeats
        self.scale_orsnetfeats = scale_orsnetfeats
        self.num_cab = num_cab
        
        # Shallow feature extraction
        self.shallow_feat = nn.Sequential(
            nn.Conv2d(3, n_feat, kernel_size=3, padding=1),
            nn.Conv2d(n_feat, n_feat, kernel_size=3, padding=1)
        )
        
        # Stage 1 - Encoder-Decoder
        self.stage1_encoder = nn.Sequential(
            nn.Conv2d(n_feat, n_feat*2, 4, 2, 1),
            nn.LeakyReLU(0.2),
            nn.Conv2d(n_feat*2, n_feat*4, 4, 2, 1),
            nn.LeakyReLU(0.2)
        )
        
        self.stage1_decoder = nn.Sequential(
            nn.ConvTranspose2d(n_feat*4, n_feat*2, 4, 2, 1),
            nn.LeakyReLU(0.2),
            nn.ConvTranspose2d(n_feat*2, n_feat, 4, 2, 1),
            nn.LeakyReLU(0.2)
        )
        
        # Stage 2 - Original Resolution
        self.stage2 = nn.Sequential(
            nn.Conv2d(n_feat, n_feat, 3, 1, 1),
            nn.LeakyReLU(0.2),
            nn.Conv2d(n_feat, n_feat, 3, 1, 1),
            nn.LeakyReLU(0.2)
        )
        
        # Channel Attention
        self.cab = ChannelAttentionBlock(n_feat)
        
        # Output
        self.output = nn.Conv2d(n_feat, 3, 3, 1, 1)
        
    def forward(self, x):
        # Shallow features
        shallow = self.shallow_feat(x)
        
        # Stage 1
        stage1_enc = self.stage1_encoder(shallow)
        stage1_dec = self.stage1_decoder(stage1_enc)
        
        # Fusion
        fused = shallow + stage1_dec
        
        # Stage 2
        stage2_out = self.stage2(fused)
        
        # Attention
        attended = self.cab(stage2_out)
        
        # Output
        output = self.output(attended)
        
        return output

print("‚úÖ Architecture MPRNet l√©g√®re d√©finie")

‚úÖ Architecture MPRNet l√©g√®re d√©finie


In [6]:
# ================================
# FONCTIONS UTILITAIRES
# ================================

def calculate_psnr_ssim(pred, target):
    """Calculer PSNR et SSIM pour un batch d'images"""
    pred_np = pred.detach().cpu().numpy()
    target_np = target.detach().cpu().numpy()
    
    psnr_values = []
    ssim_values = []
    
    for i in range(pred_np.shape[0]):
        # Convertir de [-1, 1] √† [0, 1]
        pred_img = (pred_np[i].transpose(1, 2, 0) + 1) / 2
        target_img = (target_np[i].transpose(1, 2, 0) + 1) / 2
        
        # Clamper les valeurs
        pred_img = np.clip(pred_img, 0, 1)
        target_img = np.clip(target_img, 0, 1)
        
        # Calculer PSNR
        psnr_val = psnr(target_img, pred_img, data_range=1.0)
        psnr_values.append(psnr_val)
        
        # Calculer SSIM avec param√®tres adapt√©s
        try:
            # D√©terminer la taille de fen√™tre appropri√©e
            min_dim = min(pred_img.shape[0], pred_img.shape[1])
            win_size = min(7, min_dim) if min_dim >= 7 else min_dim
            if win_size % 2 == 0:
                win_size -= 1  # SSIM n√©cessite une taille impaire
            
            # Calculer SSIM avec channel_axis au lieu de multichannel
            ssim_val = ssim(target_img, pred_img, 
                           win_size=win_size, 
                           channel_axis=2, 
                           data_range=1.0)
        except Exception as e:
            print(f"Erreur SSIM pour image {i}: {e}")
            ssim_val = 0.0  # Valeur par d√©faut en cas d'erreur
        
        ssim_values.append(ssim_val)
    
    return np.mean(psnr_values), np.mean(ssim_values)

print("‚úÖ Fonctions utilitaires d√©finies")

‚úÖ Fonctions utilitaires d√©finies


In [7]:
# ================================
# DATASET GOPRO POUR KAGGLE
# ================================

class GoProDataset(Dataset):
    """Dataset GoPro pour Kaggle"""
    def __init__(self, root_dir, patch_size=64, is_training=True):
        self.root_dir = root_dir
        self.patch_size = patch_size
        self.is_training = is_training
        
        # Chemins des dossiers
        self.input_dir = os.path.join(root_dir, 'input')
        self.target_dir = os.path.join(root_dir, 'target')
        
        # V√©rifier que les dossiers existent
        if not os.path.exists(self.input_dir):
            raise ValueError(f"Dossier input non trouv√©: {self.input_dir}")
        if not os.path.exists(self.target_dir):
            raise ValueError(f"Dossier target non trouv√©: {self.target_dir}")
        
        # Lister les fichiers
        self.input_files = sorted([f for f in os.listdir(self.input_dir) if f.lower().endswith(('.png', '.jpg', '.jpeg'))])
        self.target_files = sorted([f for f in os.listdir(self.target_dir) if f.lower().endswith(('.png', '.jpg', '.jpeg'))])
        
        # V√©rifier que les fichiers correspondent
        if len(self.input_files) != len(self.target_files):
            print(f"‚ö†Ô∏è  Nombre de fichiers diff√©rent: input={len(self.input_files)}, target={len(self.target_files)}")
        
        # Prendre le minimum pour √©viter les erreurs
        self.num_files = min(len(self.input_files), len(self.target_files))
        
        # Transformations
        if is_training:
            self.transform = transforms.Compose([
                transforms.ToTensor(),
                transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
            ])
        else:
            self.transform = transforms.Compose([
                transforms.ToTensor(),
                transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
            ])
    
    def __len__(self):
        return self.num_files
    
    def __getitem__(self, idx):
        # Charger les images
        input_path = os.path.join(self.input_dir, self.input_files[idx])
        target_path = os.path.join(self.target_dir, self.target_files[idx])
        
        try:
            input_img = Image.open(input_path).convert('RGB')
            target_img = Image.open(target_path).convert('RGB')
            
            # Redimensionner si n√©cessaire
            if self.patch_size > 0:
                input_img = input_img.resize((self.patch_size, self.patch_size), Image.LANCZOS)
                target_img = target_img.resize((self.patch_size, self.patch_size), Image.LANCZOS)
            
            # Appliquer les transformations
            input_tensor = self.transform(input_img)
            target_tensor = self.transform(target_img)
            
            return input_tensor, target_tensor
            
        except Exception as e:
            print(f"Erreur lors du chargement des images {idx}: {e}")
            # Retourner des images vides en cas d'erreur
            empty_img = torch.zeros(3, self.patch_size, self.patch_size)
            return empty_img, empty_img

print("‚úÖ Dataset GoPro pour Kaggle d√©fini")

‚úÖ Dataset GoPro pour Kaggle d√©fini


In [8]:
# ================================
# FONCTION D'ENTRA√éNEMENT PRINCIPALE
# ================================

def train_improved_model_100_epochs():
    """Entra√Ænement MPRNet am√©lior√© avec 100 √©poques"""
    
    print("üöÄ ENTRA√éNEMENT MPRNet AM√âLIOR√â - 100 √âPOQUES")
    print("=" * 60)
    
    # Configuration
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"üñ•Ô∏è  Device: {device}")
    
    # Configuration optimis√©e pour Kaggle
    batch_size = 16
    num_epochs = 100
    lr_initial = 2e-4
    lr_min = 1e-6
    train_ps = 256
    val_ps = 256
    
    # Mod√®le optimis√© pour grandes images
    model = LightweightImprovedMPRNet(
        n_feat=16,  # L√©g√®rement augment√© pour 256x256
        scale_unetfeats=8,
        scale_orsnetfeats=4,
        num_cab=2
    ).to(device)
    
    # Multi-GPU si disponible
    if torch.cuda.device_count() > 1 and batch_size > 1:
        print(f"üöÄ Utilisation de {torch.cuda.device_count()} GPU")
        model = torch.nn.DataParallel(model)
    
    # Optimizer et scheduler
    optimizer = torch.optim.AdamW(model.parameters(), lr=lr_initial, weight_decay=1e-4)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=num_epochs, eta_min=lr_min)
    
    # Loss functions
    charb_loss = CharbonnierLoss().to(device)
    edge_loss = EdgeLoss().to(device)
    
    def combined_loss(pred, target):
        charb = charb_loss(pred, target)
        edge = edge_loss(pred, target)
        return charb + 0.05 * edge
    
    # Datasets GoPro
    train_dir = '/kaggle/input/gopro-training'
    val_dir = '/kaggle/input/gopro-training'
    
    print(f"üìÅ Chargement des datasets GoPro...")
    print(f"   Train: {train_dir}")
    print(f"   Val: {val_dir}")
    
    train_dataset = GoProDataset(train_dir, patch_size=train_ps, is_training=True)
    val_dataset = GoProDataset(val_dir, patch_size=val_ps, is_training=False)
    
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=2, pin_memory=True)
    val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, num_workers=2, pin_memory=True)
    
    print(f"üìà Dataset GoPro charg√©:")
    print(f"   Train: {len(train_dataset)} images")
    print(f"   Val: {len(val_dataset)} images")
    print(f"üîß Batch size: {batch_size}, Epochs: {num_epochs}")
    print(f"üñºÔ∏è  Patch size: {train_ps}x{train_ps}")
    
    # Variables de suivi
    best_psnr = 0.0
    best_ssim = 0.0
    best_epoch = 0
    train_losses = []
    val_psnrs = []
    val_ssims = []
    validation_epochs = []
    
    print(f"\nüèãÔ∏è D√âBUT DE L'ENTRA√éNEMENT - VALIDATION CHAQUE 10 √âPOQUES")
    print("=" * 60)
    
    start_time = time.time()
    
    for epoch in range(num_epochs):
        # ===== ENTRA√éNEMENT =====
        model.train()
        epoch_loss = 0.0
        num_batches = 0
        
        for batch_idx, (input_imgs, target_imgs) in enumerate(train_loader):
            input_imgs = input_imgs.to(device, non_blocking=True)
            target_imgs = target_imgs.to(device, non_blocking=True)
            
            # Forward pass
            optimizer.zero_grad()
            output = model(input_imgs)
            pred_imgs = output[0] if isinstance(output, list) else output
            
            # Loss
            loss = combined_loss(pred_imgs, target_imgs)
            
            # Backward pass
            loss.backward()
            optimizer.step()
            
            epoch_loss += loss.item()
            num_batches += 1
            
            # Affichage progressif
            if batch_idx % 50 == 0:
                print(f"Epoch {epoch+1:3d}/{num_epochs} | Batch {batch_idx:4d} | Loss: {loss.item():.6f}")
        
        avg_loss = epoch_loss / num_batches
        train_losses.append(avg_loss)
        
        # ===== VALIDATION (toutes les 10 √©poques) =====
        if (epoch + 1) % 10 == 0 or epoch == 0:
            model.eval()
            val_psnr_sum = 0.0
            val_ssim_sum = 0.0
            val_batches = 0
            
            print(f"\nüîç VALIDATION - √âpoque {epoch+1}")
            print("-" * 50)
            
            with torch.no_grad():
                for input_imgs, target_imgs in val_loader:
                    input_imgs = input_imgs.to(device, non_blocking=True)
                    target_imgs = target_imgs.to(device, non_blocking=True)
                    
                    output = model(input_imgs)
                    pred_imgs = output[0] if isinstance(output, list) else output
                    
                    # Calculer PSNR et SSIM
                    psnr_val, ssim_val = calculate_psnr_ssim(pred_imgs, target_imgs)
                    val_psnr_sum += psnr_val
                    val_ssim_sum += ssim_val
                    val_batches += 1
            
            avg_psnr = val_psnr_sum / val_batches
            avg_ssim = val_ssim_sum / val_batches
            
            val_psnrs.append(avg_psnr)
            val_ssims.append(avg_ssim)
            validation_epochs.append(epoch + 1)
            
            # V√©rifier si c'est le meilleur mod√®le
            is_best = avg_psnr > best_psnr
            if is_best:
                best_psnr = avg_psnr
                best_ssim = avg_ssim
                best_epoch = epoch + 1
                
                # Sauvegarder le meilleur mod√®le
                torch.save({
                    'epoch': epoch + 1,
                    'model_state_dict': model.state_dict(),
                    'optimizer_state_dict': optimizer.state_dict(),
                    'scheduler_state_dict': scheduler.state_dict(),
                    'best_psnr': best_psnr,
                    'best_ssim': best_ssim,
                    'best_epoch': best_epoch,
                    'train_losses': train_losses,
                    'val_psnrs': val_psnrs,
                    'val_ssims': val_ssims,
                    'validation_epochs': validation_epochs
                }, 'best_mprnet_model.pth')
                
                print(f"üèÜ NOUVEAU MEILLEUR MOD√àLE!")
                print(f"   PSNR: {best_psnr:.4f} dB (√âpoque {best_epoch})")
                print(f"   SSIM: {best_ssim:.4f}")
                print(f"   ÔøΩÔøΩ Sauvegard√©: best_mprnet_model.pth")
            else:
                print(f"üìä R√©sultats actuels:")
                print(f"   PSNR: {avg_psnr:.4f} dB (Meilleur: {best_psnr:.4f} @ √âpoque {best_epoch})")
                print(f"   SSIM: {avg_ssim:.4f} (Meilleur: {best_ssim:.4f})")
            
            # Affichage d√©taill√©
            current_lr = optimizer.param_groups[0]['lr']
            elapsed_time = time.time() - start_time
            
            print(f"\nüìä D√âTAILS DE L'√âPOQUE {epoch+1:3d}/{num_epochs}")
            print(f"   Loss d'entra√Ænement: {avg_loss:.6f}")
            print(f"   Learning rate: {current_lr:.2e}")
            print(f"   Temps √©coul√©: {elapsed_time/60:.1f} minutes")
            print(f"   Progr√®s: {((epoch+1)/num_epochs)*100:.1f}%")
            print("=" * 60)
        
        # Mettre √† jour le scheduler
        scheduler.step()
        
        # Nettoyer la m√©moire
        torch.cuda.empty_cache()
    
    # ===== SAUVEGARDE FINALE =====
    total_time = time.time() - start_time
    
    # Sauvegarder le mod√®le final
    final_model_path = 'mprnet_final_model.pth'
    torch.save({
        'epoch': num_epochs,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'scheduler_state_dict': scheduler.state_dict(),
        'best_psnr': best_psnr,
        'best_ssim': best_ssim,
        'best_epoch': best_epoch,
        'train_losses': train_losses,
        'val_psnrs': val_psnrs,
        'val_ssims': val_ssims,
        'validation_epochs': validation_epochs,
        'config': {
            'batch_size': batch_size,
            'num_epochs': num_epochs,
            'lr_initial': lr_initial,
            'lr_min': lr_min,
            'train_ps': train_ps,
            'val_ps': val_ps
        }
    }, final_model_path)
    
    # ===== R√âSULTATS FINAUX =====
    print("\nüéâ ENTRA√éNEMENT TERMIN√â!")
    print("=" * 60)
    print(f"‚è±Ô∏è  Temps total: {total_time/60:.1f} minutes")
    print(f"üèÜ MEILLEUR MOD√àLE (√âpoque {best_epoch}):")
    print(f"   PSNR: {best_psnr:.4f} dB")
    print(f"   SSIM: {best_ssim:.4f}")
    print(f"üíæ Fichiers sauvegard√©s:")
    print(f"   - Meilleur mod√®le: best_mprnet_model.pth")
    print(f"   - Mod√®le final: {final_model_path}")
    
    # Statistiques d'entra√Ænement
    print(f"\nüìä STATISTIQUES D'ENTRA√éNEMENT:")
    print(f"   √âpoques totales: {num_epochs}")
    print(f"   Validations effectu√©es: {len(validation_epochs)}")
    print(f"   √âpoques de validation: {validation_epochs}")
    print(f"   Batches par √©poque: {len(train_loader)}")
    print(f"   Total de batches: {num_epochs * len(train_loader)}")
    print(f"   Loss finale: {train_losses[-1]:.6f}")
    
    # √âvolution des performances
    print(f"\nüìà √âVOLUTION DES PERFORMANCES:")
    for i, (ep, psnr, ssim) in enumerate(zip(validation_epochs, val_psnrs, val_ssims)):
        marker = "üèÜ" if ep == best_epoch else "  "
        print(f"   {marker} √âpoque {ep:3d}: PSNR {psnr:.4f} dB, SSIM {ssim:.4f}")
    
    return model, best_psnr, best_ssim, best_epoch

print("‚úÖ Fonction d'entra√Ænement d√©finie")

‚úÖ Fonction d'entra√Ænement d√©finie


In [9]:
# ================================
# LANCEMENT DE L'ENTRA√éNEMENT
# ================================

print("üéØ LANCEMENT DE L'ENTRA√éNEMENT MPRNet AM√âLIOR√â")
print("=" * 60)

# Lancer l'entra√Ænement
trained_model, best_psnr, best_ssim, best_epoch = train_improved_model_100_epochs()

print(f"\n‚úÖ Entra√Ænement termin√© avec succ√®s!")
print(f"üèÜ Meilleures performances: PSNR {best_psnr:.4f} dB, SSIM {best_ssim:.4f}")
print(f"üéØ Meilleur mod√®le obtenu √† l'√©poque {best_epoch}")
print(f"üìÅ Mod√®les sauvegard√©s pour t√©l√©chargement")

üéØ LANCEMENT DE L'ENTRA√éNEMENT MPRNet AM√âLIOR√â
üöÄ ENTRA√éNEMENT MPRNet AM√âLIOR√â - 100 √âPOQUES
üñ•Ô∏è  Device: cuda
üöÄ Utilisation de 2 GPU
üìÅ Chargement des datasets GoPro...
   Train: /kaggle/input/gopro-training
   Val: /kaggle/input/gopro-training
üìà Dataset GoPro charg√©:
   Train: 2103 images
   Val: 2103 images
üîß Batch size: 16, Epochs: 100
üñºÔ∏è  Patch size: 256x256

üèãÔ∏è D√âBUT DE L'ENTRA√éNEMENT - VALIDATION CHAQUE 10 √âPOQUES
Epoch   1/100 | Batch    0 | Loss: 0.428007
Epoch   1/100 | Batch   50 | Loss: 0.159673
Epoch   1/100 | Batch  100 | Loss: 0.133924

üîç VALIDATION - √âpoque 1
--------------------------------------------------
üèÜ NOUVEAU MEILLEUR MOD√àLE!
   PSNR: 20.9043 dB (√âpoque 1)
   SSIM: 0.6524
   ÔøΩÔøΩ Sauvegard√©: best_mprnet_model.pth

üìä D√âTAILS DE L'√âPOQUE   1/100
   Loss d'entra√Ænement: 0.193564
   Learning rate: 2.00e-04
   Temps √©coul√©: 4.3 minutes
   Progr√®s: 1.0%
Epoch   2/100 | Batch    0 | Loss: 0.140871
Epoch   2