<a href="https://colab.research.google.com/github/moetezwiw/projet-PNEUMONIA/blob/main/ganlast.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
from google.colab import drive
drive.mount('/content/drive')


Mounted at /content/drive


In [None]:
# Sauvegarder votre script
%%writefile cyclegan.py

"""
CycleGAN pour Navigation C√©leste : Real (Classe A) ‚Üî Stellarium (Classe B)
Version optimis√©e pour Colab
"""

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.optim import Adam, lr_scheduler
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from torchvision.utils import save_image, make_grid
import os
from itertools import chain
from PIL import Image
import numpy as np
from pathlib import Path
from tqdm import tqdm
import matplotlib.pyplot as plt
import json
from datetime import datetime

# ============================================================================
# CBAM: Convolutional Block Attention Module
# ============================================================================

class ChannelAttention(nn.Module):
    """Attention sur les canaux (QUELS features sont importants)"""
    def __init__(self, in_channels, reduction=16):
        super().__init__()
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.max_pool = nn.AdaptiveMaxPool2d(1)
        self.fc = nn.Sequential(
            nn.Conv2d(in_channels, in_channels // reduction, 1, bias=False),
            nn.ReLU(inplace=True),
            nn.Conv2d(in_channels // reduction, in_channels, 1, bias=False)
        )
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        avg_out = self.fc(self.avg_pool(x))
        max_out = self.fc(self.max_pool(x))
        return x * self.sigmoid(avg_out + max_out)


class SpatialAttention(nn.Module):
    """Attention spatiale (O√ô sont les features - localisation des √©toiles)"""
    def __init__(self, kernel_size=7):
        super().__init__()
        padding = (kernel_size - 1) // 2
        self.conv = nn.Conv2d(2, 1, kernel_size, padding=padding, bias=False)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        avg_out = torch.mean(x, dim=1, keepdim=True)
        max_out, _ = torch.max(x, dim=1, keepdim=True)
        out = torch.cat([avg_out, max_out], dim=1)
        return x * self.sigmoid(self.conv(out))


class CBAM(nn.Module):
    """Module CBAM complet"""
    def __init__(self, channels, reduction=16, kernel_size=7):
        super().__init__()
        self.channel_att = ChannelAttention(channels, reduction)
        self.spatial_att = SpatialAttention(kernel_size)

    def forward(self, x):
        x = self.channel_att(x)
        x = self.spatial_att(x)
        return x


# ============================================================================
# ResNet Block avec CBAM
# ============================================================================

class ResnetBlockCBAM(nn.Module):
    """Bloc r√©siduel avec attention CBAM"""
    def __init__(self, dim, use_cbam=True, use_dropout=False):
        super().__init__()
        conv_block = [
            nn.ReflectionPad2d(1),
            nn.Conv2d(dim, dim, kernel_size=3, bias=False),
            nn.InstanceNorm2d(dim),
            nn.ReLU(inplace=True)
        ]

        if use_dropout:
            conv_block += [nn.Dropout(0.5)]

        conv_block += [
            nn.ReflectionPad2d(1),
            nn.Conv2d(dim, dim, kernel_size=3, bias=False),
            nn.InstanceNorm2d(dim)
        ]

        self.conv_block = nn.Sequential(*conv_block)
        self.use_cbam = use_cbam
        if use_cbam:
            self.cbam = CBAM(dim)

    def forward(self, x):
        residual = x
        out = self.conv_block(x)
        if self.use_cbam:
            out = self.cbam(out)
        return residual + out


# ============================================================================
# G√©n√©rateur ResNet avec CBAM
# ============================================================================

class GeneratorResNetCBAM(nn.Module):
    """
    G√©n√©rateur : Encoder -> ResBlocks avec CBAM -> Decoder
    Real (Classe A) -> Stellarium (Classe B) et vice-versa
    """
    def __init__(self, input_nc=3, output_nc=3, ngf=64, n_blocks=9, use_cbam=True):
        super().__init__()

        # Encoder initial
        model = [
            nn.ReflectionPad2d(3),
            nn.Conv2d(input_nc, ngf, kernel_size=7, bias=False),
            nn.InstanceNorm2d(ngf),
            nn.ReLU(inplace=True)
        ]

        # Downsampling
        n_downsampling = 2
        for i in range(n_downsampling):
            mult = 2 ** i
            model += [
                nn.Conv2d(ngf * mult, ngf * mult * 2, kernel_size=3,
                         stride=2, padding=1, bias=False),
                nn.InstanceNorm2d(ngf * mult * 2),
                nn.ReLU(inplace=True)
            ]

        # ResNet blocks avec CBAM
        mult = 2 ** n_downsampling
        for i in range(n_blocks):
            model += [ResnetBlockCBAM(ngf * mult, use_cbam=use_cbam)]

        # Upsampling
        for i in range(n_downsampling):
            mult = 2 ** (n_downsampling - i)
            model += [
                nn.ConvTranspose2d(ngf * mult, ngf * mult // 2,
                                  kernel_size=3, stride=2,
                                  padding=1, output_padding=1, bias=False),
                nn.InstanceNorm2d(ngf * mult // 2),
                nn.ReLU(inplace=True)
            ]

        # Couche finale
        model += [
            nn.ReflectionPad2d(3),
            nn.Conv2d(ngf, output_nc, kernel_size=7),
            nn.Tanh()
        ]

        self.model = nn.Sequential(*model)

    def forward(self, x):
        return self.model(x)


# ============================================================================
# Discriminateur PatchGAN
# ============================================================================

class PatchGANDiscriminator(nn.Module):
    """
    Discriminateur PatchGAN (70x70 receptive field)
    Classe des patches comme r√©els/faux pour un feedback local
    """
    def __init__(self, input_nc=3, ndf=64, n_layers=3):
        super().__init__()

        model = [
            nn.Conv2d(input_nc, ndf, kernel_size=4, stride=2, padding=1),
            nn.LeakyReLU(0.2, True)
        ]

        nf_mult = 1
        for n in range(1, n_layers):
            nf_mult_prev = nf_mult
            nf_mult = min(2 ** n, 8)
            model += [
                nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult,
                         kernel_size=4, stride=2, padding=1, bias=False),
                nn.InstanceNorm2d(ndf * nf_mult),
                nn.LeakyReLU(0.2, True)
            ]

        nf_mult_prev = nf_mult
        nf_mult = min(2 ** n_layers, 8)
        model += [
            nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult,
                     kernel_size=4, stride=1, padding=1, bias=False),
            nn.InstanceNorm2d(ndf * nf_mult),
            nn.LeakyReLU(0.2, True)
        ]

        model += [nn.Conv2d(ndf * nf_mult, 1, kernel_size=4, stride=1, padding=1)]

        self.model = nn.Sequential(*model)

    def forward(self, x):
        return self.model(x)


# ============================================================================
# Dataset pour Classe A (Real) et Classe B (Stellarium)
# ============================================================================

class CelestialDataset(Dataset):
    """
    Dataset pour navigation c√©leste
    Classe A : Images r√©elles du ciel nocturne
    Classe B : Images synth√©tiques Stellarium
    """
    def __init__(self, root_classA, root_classB, transform=None, max_images=None):
        self.root_classA = Path(root_classA)
        self.root_classB = Path(root_classB)
        self.transform = transform

        # Extensions support√©es
        extensions = ['*.jpg', '*.jpeg', '*.png', '*.JPG', '*.JPEG', '*.PNG']

        # Lister les fichiers Classe A (Real)
        self.classA_images = []
        for ext in extensions:
            self.classA_images.extend(list(self.root_classA.glob(ext)))
        self.classA_images = sorted(self.classA_images)

        # Lister les fichiers Classe B (Stellarium)
        self.classB_images = []
        for ext in extensions:
            self.classB_images.extend(list(self.root_classB.glob(ext)))
        self.classB_images = sorted(self.classB_images)

        # Limiter le nombre d'images si sp√©cifi√©
        if max_images:
            self.classA_images = self.classA_images[:max_images]
            self.classB_images = self.classB_images[:max_images]

        print(f"üìÇ Dataset charg√©:")
        print(f"   Classe A (Real): {len(self.classA_images)} images")
        print(f"   Classe B (Stellarium): {len(self.classB_images)} images")

        if len(self.classA_images) == 0 or len(self.classB_images) == 0:
            raise ValueError("‚ö†Ô∏è  Aucune image trouv√©e! V√©rifiez les chemins des dossiers.")

    def __len__(self):
        return max(len(self.classA_images), len(self.classB_images))

    def __getitem__(self, idx):
        # Lecture cyclique si les tailles sont diff√©rentes
        classA_img = Image.open(
            self.classA_images[idx % len(self.classA_images)]
        ).convert('RGB')

        classB_img = Image.open(
            self.classB_images[idx % len(self.classB_images)]
        ).convert('RGB')

        if self.transform:
            classA_img = self.transform(classA_img)
            classB_img = self.transform(classB_img)

        return classA_img, classB_img


# ============================================================================
# Utilitaires
# ============================================================================

def init_weights(net, init_type='normal', init_gain=0.02):
    """Initialisation des poids"""
    def init_func(m):
        classname = m.__class__.__name__
        if hasattr(m, 'weight') and (classname.find('Conv') != -1 or
                                     classname.find('Linear') != -1):
            if init_type == 'normal':
                nn.init.normal_(m.weight.data, 0.0, init_gain)
            elif init_type == 'xavier':
                nn.init.xavier_normal_(m.weight.data, gain=init_gain)
            elif init_type == 'kaiming':
                nn.init.kaiming_normal_(m.weight.data, a=0, mode='fan_in')
            if hasattr(m, 'bias') and m.bias is not None:
                nn.init.constant_(m.bias.data, 0.0)
        elif classname.find('BatchNorm') != -1 or classname.find('InstanceNorm') != -1:
            if hasattr(m, 'weight') and m.weight is not None:
                nn.init.normal_(m.weight.data, 1.0, init_gain)
            if hasattr(m, 'bias') and m.bias is not None:
                nn.init.constant_(m.bias.data, 0.0)

    net.apply(init_func)


class ImagePool:
    """Buffer d'images pour stabiliser l'entra√Ænement du discriminateur"""
    def __init__(self, pool_size=50):
        self.pool_size = pool_size
        self.images = []

    def query(self, images):
        if self.pool_size == 0:
            return images

        return_images = []
        for image in images:
            image = image.unsqueeze(0)
            if len(self.images) < self.pool_size:
                self.images.append(image)
                return_images.append(image)
            else:
                p = np.random.uniform(0, 1)
                if p > 0.5:
                    random_id = np.random.randint(0, self.pool_size)
                    tmp = self.images[random_id].clone()
                    self.images[random_id] = image
                    return_images.append(tmp)
                else:
                    return_images.append(image)

        return torch.cat(return_images, 0)


# ============================================================================
# Trainer CycleGAN
# ============================================================================

class CycleGANTrainer:
    """Entra√Ænement CycleGAN pour navigation c√©leste"""
    def __init__(self, config):
        self.config = config
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        print(f"üñ•Ô∏è  Device: {self.device}")

        # Cr√©er r√©pertoires
        os.makedirs(config['output_dir'], exist_ok=True)
        os.makedirs(f"{config['output_dir']}/samples", exist_ok=True)
        os.makedirs(f"{config['output_dir']}/checkpoints", exist_ok=True)

        # Historique des pertes
        self.history = {
            'G_loss': [], 'D_loss': [], 'cycle_loss': [],
            'identity_loss': [], 'GAN_loss': []
        }

        # Initialiser mod√®les
        self._build_models()
        self._setup_optimizers()

        # Crit√®res de perte
        self.criterion_GAN = nn.MSELoss()  # LSGAN
        self.criterion_cycle = nn.L1Loss()
        self.criterion_identity = nn.L1Loss()

        # Pools d'images
        self.fake_real_pool = ImagePool(config['pool_size'])
        self.fake_sim_pool = ImagePool(config['pool_size'])

    def _build_models(self):
        """Construire les g√©n√©rateurs et discriminateurs"""
        config = self.config

        print("üî® Construction des mod√®les...")

        # G√©n√©rateurs
        self.G_A2B = GeneratorResNetCBAM(
            input_nc=3, output_nc=3, ngf=config['ngf'],
            n_blocks=config['n_blocks'], use_cbam=config['use_cbam']
        ).to(self.device)

        self.G_B2A = GeneratorResNetCBAM(
            input_nc=3, output_nc=3, ngf=config['ngf'],
            n_blocks=config['n_blocks'], use_cbam=config['use_cbam']
        ).to(self.device)

        # Discriminateurs
        self.D_A = PatchGANDiscriminator(
            input_nc=3, ndf=config['ndf'], n_layers=3
        ).to(self.device)

        self.D_B = PatchGANDiscriminator(
            input_nc=3, ndf=config['ndf'], n_layers=3
        ).to(self.device)

        # Initialiser poids
        init_weights(self.G_A2B, config['init_type'])
        init_weights(self.G_B2A, config['init_type'])
        init_weights(self.D_A, config['init_type'])
        init_weights(self.D_B, config['init_type'])

        # Compter param√®tres
        g_params = sum(p.numel() for p in self.G_A2B.parameters())
        d_params = sum(p.numel() for p in self.D_A.parameters())
        print(f"   G_A2B: {g_params:,} param√®tres")
        print(f"   D_A: {d_params:,} param√®tres")

    def _setup_optimizers(self):
        """Configurer optimiseurs et schedulers"""
        config = self.config

        self.optimizer_G = Adam(
            chain(self.G_A2B.parameters(), self.G_B2A.parameters()),
            lr=config['lr'], betas=(config['beta1'], 0.999)
        )

        self.optimizer_D = Adam(
            chain(self.D_A.parameters(), self.D_B.parameters()),
            lr=config['lr'], betas=(config['beta1'], 0.999)
        )

        # Learning rate scheduler (decay lin√©aire)
        def lambda_rule(epoch):
            lr_l = 1.0 - max(0, epoch - config['n_epochs_decay']) / (
                config['n_epochs'] - config['n_epochs_decay'] + 1
            )
            return lr_l

        self.scheduler_G = lr_scheduler.LambdaLR(
            self.optimizer_G, lr_lambda=lambda_rule
        )
        self.scheduler_D = lr_scheduler.LambdaLR(
            self.optimizer_D, lr_lambda=lambda_rule
        )

    def train_epoch(self, dataloader, epoch):
        """Entra√Æner une √©poque"""
        self.G_A2B.train()
        self.G_B2A.train()
        self.D_A.train()
        self.D_B.train()

        losses = {'G': [], 'D': [], 'cycle': [], 'identity': [], 'GAN': []}

        pbar = tqdm(dataloader, desc=f"Epoch {epoch}/{self.config['n_epochs']}")

        for i, (real_A, real_B) in enumerate(pbar):
            real_A = real_A.to(self.device)
            real_B = real_B.to(self.device)
            batch_size = real_A.size(0)

            # Labels adversariaux
            valid = torch.ones((batch_size, 1, 30, 30), device=self.device)
            fake = torch.zeros((batch_size, 1, 30, 30), device=self.device)

            # ========================================
            # Entra√Æner G√©n√©rateurs
            # ========================================
            self.optimizer_G.zero_grad()

            # Identity loss (pr√©server identit√© si m√™me domaine)
            loss_id = 0
            if self.config['lambda_identity'] > 0:
                id_A = self.G_B2A(real_A)  # Devrait retourner real_A
                id_B = self.G_A2B(real_B)  # Devrait retourner real_B
                loss_id_A = self.criterion_identity(id_A, real_A)
                loss_id_B = self.criterion_identity(id_B, real_B)
                loss_id = (loss_id_A + loss_id_B) * 0.5

            # GAN loss
            fake_B = self.G_A2B(real_A)  # Real -> Stellarium
            pred_fake_B = self.D_B(fake_B)
            loss_GAN_A2B = self.criterion_GAN(pred_fake_B, valid)

            fake_A = self.G_B2A(real_B)  # Stellarium -> Real
            pred_fake_A = self.D_A(fake_A)
            loss_GAN_B2A = self.criterion_GAN(pred_fake_A, valid)

            loss_GAN = (loss_GAN_A2B + loss_GAN_B2A) * 0.5

            # Cycle consistency loss (A -> B -> A et B -> A -> B)
            recovered_A = self.G_B2A(fake_B)
            loss_cycle_A = self.criterion_cycle(recovered_A, real_A)

            recovered_B = self.G_A2B(fake_A)
            loss_cycle_B = self.criterion_cycle(recovered_B, real_B)

            loss_cycle = (loss_cycle_A + loss_cycle_B) * 0.5

            # Total generator loss
            loss_G = (loss_GAN +
                     self.config['lambda_cycle'] * loss_cycle +
                     self.config['lambda_identity'] * loss_id)

            loss_G.backward()
            self.optimizer_G.step()

            # ========================================
            # Entra√Æner Discriminateurs
            # ========================================
            self.optimizer_D.zero_grad()

            # D_A (discrimine classe A)
            pred_real_A = self.D_A(real_A)
            loss_D_A_real = self.criterion_GAN(pred_real_A, valid)

            fake_A_ = self.fake_real_pool.query(fake_A.detach())
            pred_fake_A = self.D_A(fake_A_)
            loss_D_A_fake = self.criterion_GAN(pred_fake_A, fake)

            loss_D_A = (loss_D_A_real + loss_D_A_fake) * 0.5

            # D_B (discrimine classe B)
            pred_real_B = self.D_B(real_B)
            loss_D_B_real = self.criterion_GAN(pred_real_B, valid)

            fake_B_ = self.fake_sim_pool.query(fake_B.detach())
            pred_fake_B = self.D_B(fake_B_)
            loss_D_B_fake = self.criterion_GAN(pred_fake_B, fake)

            loss_D_B = (loss_D_B_real + loss_D_B_fake) * 0.5

            loss_D = (loss_D_A + loss_D_B) * 0.5

            loss_D.backward()
            self.optimizer_D.step()

            # Sauvegarder pertes
            losses['G'].append(loss_G.item())
            losses['D'].append(loss_D.item())
            losses['cycle'].append(loss_cycle.item())
            losses['identity'].append(loss_id.item() if loss_id != 0 else 0)
            losses['GAN'].append(loss_GAN.item())

            # Mise √† jour barre de progression
            pbar.set_postfix({
                'G': f"{loss_G.item():.3f}",
                'D': f"{loss_D.item():.3f}",
                'Cycle': f"{loss_cycle.item():.3f}"
            })

            # Sauvegarder √©chantillons
            if i % self.config['sample_interval'] == 0:
                self.save_samples(epoch, i, real_A, real_B, fake_B, fake_A)

        # Moyennes
        avg_losses = {k: np.mean(v) for k, v in losses.items()}

        return avg_losses

    def save_samples(self, epoch, batch, real_A, real_B, fake_B, fake_A):
        """Sauvegarder √©chantillons de traduction"""
        samples_dir = f"{self.config['output_dir']}/samples"

        n_samples = min(4, real_A.size(0))

        # Cr√©er grille: Real A | Fake B | Real B | Fake A
        comparison = torch.cat([
            real_A[:n_samples],
            fake_B[:n_samples],
            real_B[:n_samples],
            fake_A[:n_samples]
        ], dim=0)

        save_image(comparison,
                  f"{samples_dir}/epoch_{epoch:03d}_batch_{batch:04d}.png",
                  nrow=n_samples, normalize=True, value_range=(-1, 1))

    def save_checkpoint(self, epoch, is_best=False):
        """Sauvegarder checkpoint"""
        checkpoint_dir = f"{self.config['output_dir']}/checkpoints"

        checkpoint = {
            'epoch': epoch,
            'G_A2B': self.G_A2B.state_dict(),
            'G_B2A': self.G_B2A.state_dict(),
            'D_A': self.D_A.state_dict(),
            'D_B': self.D_B.state_dict(),
            'optimizer_G': self.optimizer_G.state_dict(),
            'optimizer_D': self.optimizer_D.state_dict(),
            'history': self.history,
            'config': self.config
        }

        # Sauvegarder dernier checkpoint
        torch.save(checkpoint, f"{checkpoint_dir}/latest.pth")

        # Sauvegarder checkpoint p√©riodique
        if epoch % self.config['checkpoint_interval'] == 0:
            torch.save(checkpoint, f"{checkpoint_dir}/epoch_{epoch:03d}.pth")

        # Sauvegarder meilleur mod√®le
        if is_best:
            torch.save(checkpoint, f"{checkpoint_dir}/best.pth")

        print(f"   üíæ Checkpoint sauvegard√© (epoch {epoch})")

    def plot_losses(self):
        """Tracer courbes de perte"""
        fig, axes = plt.subplots(2, 2, figsize=(12, 8))
        fig.suptitle('Historique d\'entra√Ænement CycleGAN')

        axes[0, 0].plot(self.history['G_loss'])
        axes[0, 0].set_title('Generator Loss')
        axes[0, 0].set_xlabel('Epoch')

        axes[0, 1].plot(self.history['D_loss'])
        axes[0, 1].set_title('Discriminator Loss')
        axes[0, 1].set_xlabel('Epoch')

        axes[1, 0].plot(self.history['cycle_loss'])
        axes[1, 0].set_title('Cycle Consistency Loss')
        axes[1, 0].set_xlabel('Epoch')

        axes[1, 1].plot(self.history['GAN_loss'])
        axes[1, 1].set_title('GAN Loss')
        axes[1, 1].set_xlabel('Epoch')

        plt.tight_layout()
        plt.savefig(f"{self.config['output_dir']}/training_curves.png", dpi=150)
        plt.close()

    def train(self, dataloader):
        """Boucle d'entra√Ænement principale"""
        print("\n" + "="*70)
        print(f"üöÄ D√©marrage entra√Ænement CycleGAN Navigation C√©leste")
        print(f"   Epochs: {self.config['n_epochs']}")
        print(f"   Batch size: {self.config['batch_size']}")
        print(f"   CBAM: {self.config['use_cbam']}")
        print(f"   Device: {self.device}")
        print("="*70 + "\n")

        best_cycle_loss = float('inf')

        for epoch in range(1, self.config['n_epochs'] + 1):
            avg_losses = self.train_epoch(dataloader, epoch)

            # Sauvegarder dans historique
            self.history['G_loss'].append(avg_losses['G'])
            self.history['D_loss'].append(avg_losses['D'])
            self.history['cycle_loss'].append(avg_losses['cycle'])
            self.history['identity_loss'].append(avg_losses['identity'])
            self.history['GAN_loss'].append(avg_losses['GAN'])

            # Afficher r√©sum√©
            print(f"\nüìä Epoch {epoch} - Pertes moyennes:")
            print(f"   G: {avg_losses['G']:.4f} | D: {avg_losses['D']:.4f}")
            print(f"   Cycle: {avg_losses['cycle']:.4f} | GAN: {avg_losses['GAN']:.4f}")

            # Mettre √† jour learning rates
            self.scheduler_G.step()
            self.scheduler_D.step()

            # Sauvegarder checkpoint
            is_best = avg_losses['cycle'] < best_cycle_loss
            if is_best:
                best_cycle_loss = avg_losses['cycle']

            self.save_checkpoint(epoch, is_best)

            # Tracer courbes
            if epoch % 5 == 0:
                self.plot_losses()

        # Sauvegarder historique
        with open(f"{self.config['output_dir']}/history.json", 'w') as f:
            json.dump(self.history, f, indent=2)

        print("\n‚úÖ Entra√Ænement termin√©!")
        print(f"üìÅ R√©sultats sauvegard√©s dans: {self.config['output_dir']}")


# ============================================================================
# Configuration
# ============================================================================

def get_config(classA_dir='/content/drive/MyDrive/entainement/classeA', classB_dir='/content/drive/MyDrive/entainement/classeB'):
    """Configuration d'entra√Ænement"""
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")

    return {
        # Chemins des donn√©es
        'classA_dir': classA_dir,  # Images r√©elles
        'classB_dir': classB_dir,  # Images Stellarium
        'output_dir': f'./outputs/cyclegan_{timestamp}',

        # Architecture
        'use_cbam': True,
        'ngf': 64,  # Filtres g√©n√©rateur
        'ndf': 64,  # Filtres discriminateur
        'n_blocks': 9,  # Nombre de ResBlocks
        'init_type': 'normal',

        # Hyperparam√®tres
        'n_epochs': 100,
        'n_epochs_decay': 50,  # Decay LR apr√®s cette √©poque
        'batch_size': 1,
        'lr': 0.0002,
        'beta1': 0.5,

        # Poids des pertes
        'lambda_cycle': 10.0,
        'lambda_identity': 0.5,

        # Entra√Ænement
        'pool_size': 50,
        'num_workers': 4,
        'sample_interval': 100,
        'checkpoint_interval': 10,

        # Images
        'img_size': 256,
        'max_images': None,  # None = toutes les images
    }


# ============================================================================
# Script principal
# ============================================================================

def main():
    """Script d'entra√Ænement principal"""
    print("="*70)
    print("   CycleGAN - Navigation C√©leste")
    print("   Real (Classe A) ‚Üî Stellarium (Classe B)")
    print("="*70 + "\n")

    # Configuration
    config = get_config(
        classA_dir='./classeA',  # üìÅ Dossier images r√©elles
        classB_dir='./classeB'   # üìÅ Dossier images Stellarium
    )

    # V√©rifier dossiers
    if not os.path.exists(config['classA_dir']):
        print(f"‚ùå Erreur: Dossier '{config['classA_dir']}' introuvable!")
        print("   Cr√©ez le dossier et placez-y vos images r√©elles.")
        return

    if not os.path.exists(config['classB_dir']):
        print(f"‚ùå Erreur: Dossier '{config['classB_dir']}' introuvable!")
        print("   Cr√©ez le dossier et placez-y vos images Stellarium.")
        return

    # Transforms avec augmentation
    transform = transforms.Compose([
        transforms.Resize((config['img_size'], config['img_size'])),
        transforms.RandomHorizontalFlip(p=0.5),
        transforms.ColorJitter(brightness=0.1, contrast=0.1, saturation=0.1),
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
    ])

    # Cr√©er dataset
    print("üìÇ Chargement du dataset...")
    try:
        dataset = CelestialDataset(
            root_classA=config['classA_dir'],
            root_classB=config['classB_dir'],
            transform=transform,
            max_images=config['max_images']
        )
    except ValueError as e:
        print(f"‚ùå {e}")
        return

    # DataLoader
    dataloader = DataLoader(
        dataset,
        batch_size=config['batch_size'],
        shuffle=True,
        num_workers=config['num_workers'],
        pin_memory=True,
        drop_last=True
    )

    print(f"‚úÖ DataLoader cr√©√©: {len(dataloader)} batches\n")

    # Cr√©er trainer et entra√Æner
    trainer = CycleGANTrainer(config)
    trainer.train(dataloader)


if __name__ == "__main__":
    main()

Writing cyclegan.py


In [None]:
%%writefile script.py
"""
CycleGAN pour Navigation C√©leste : Real (Classe A) ‚Üî Stellarium (Classe B)
Version optimis√©e et pr√™te √† l'ex√©cution
"""

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.optim import Adam, lr_scheduler
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from torchvision.utils import save_image, make_grid
import os
from itertools import chain
from PIL import Image
import numpy as np
from pathlib import Path
from tqdm import tqdm
import matplotlib.pyplot as plt
import json
from datetime import datetime

# ============================================================================
# CBAM: Convolutional Block Attention Module
# ============================================================================

class ChannelAttention(nn.Module):
    """Attention sur les canaux (QUELS features sont importants)"""
    def __init__(self, in_channels, reduction=16):
        super().__init__()
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.max_pool = nn.AdaptiveMaxPool2d(1)
        self.fc = nn.Sequential(
            nn.Conv2d(in_channels, in_channels // reduction, 1, bias=False),
            nn.ReLU(inplace=True),
            nn.Conv2d(in_channels // reduction, in_channels, 1, bias=False)
        )
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        avg_out = self.fc(self.avg_pool(x))
        max_out = self.fc(self.max_pool(x))
        return x * self.sigmoid(avg_out + max_out)


class SpatialAttention(nn.Module):
    """Attention spatiale (O√ô sont les features - localisation des √©toiles)"""
    def __init__(self, kernel_size=7):
        super().__init__()
        padding = (kernel_size - 1) // 2
        self.conv = nn.Conv2d(2, 1, kernel_size, padding=padding, bias=False)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        avg_out = torch.mean(x, dim=1, keepdim=True)
        max_out, _ = torch.max(x, dim=1, keepdim=True)
        out = torch.cat([avg_out, max_out], dim=1)
        return x * self.sigmoid(self.conv(out))


class CBAM(nn.Module):
    """Module CBAM complet"""
    def __init__(self, channels, reduction=16, kernel_size=7):
        super().__init__()
        self.channel_att = ChannelAttention(channels, reduction)
        self.spatial_att = SpatialAttention(kernel_size)

    def forward(self, x):
        x = self.channel_att(x)
        x = self.spatial_att(x)
        return x


# ============================================================================
# ResNet Block avec CBAM
# ============================================================================

class ResnetBlockCBAM(nn.Module):
    """Bloc r√©siduel avec attention CBAM"""
    def __init__(self, dim, use_cbam=True, use_dropout=False):
        super().__init__()
        conv_block = [
            nn.ReflectionPad2d(1),
            nn.Conv2d(dim, dim, kernel_size=3, bias=False),
            nn.InstanceNorm2d(dim),
            nn.ReLU(inplace=True)
        ]

        if use_dropout:
            conv_block += [nn.Dropout(0.5)]

        conv_block += [
            nn.ReflectionPad2d(1),
            nn.Conv2d(dim, dim, kernel_size=3, bias=False),
            nn.InstanceNorm2d(dim)
        ]

        self.conv_block = nn.Sequential(*conv_block)
        self.use_cbam = use_cbam
        if use_cbam:
            self.cbam = CBAM(dim)

    def forward(self, x):
        residual = x
        out = self.conv_block(x)
        if self.use_cbam:
            out = self.cbam(out)
        return residual + out


# ============================================================================
# G√©n√©rateur ResNet avec CBAM
# ============================================================================

class GeneratorResNetCBAM(nn.Module):
    """
    G√©n√©rateur : Encoder -> ResBlocks avec CBAM -> Decoder
    Real (Classe A) -> Stellarium (Classe B) et vice-versa
    """
    def __init__(self, input_nc=3, output_nc=3, ngf=64, n_blocks=9, use_cbam=True):
        super().__init__()

        # Encoder initial
        model = [
            nn.ReflectionPad2d(3),
            nn.Conv2d(input_nc, ngf, kernel_size=7, bias=False),
            nn.InstanceNorm2d(ngf),
            nn.ReLU(inplace=True)
        ]

        # Downsampling
        n_downsampling = 2
        for i in range(n_downsampling):
            mult = 2 ** i
            model += [
                nn.Conv2d(ngf * mult, ngf * mult * 2, kernel_size=3,
                         stride=2, padding=1, bias=False),
                nn.InstanceNorm2d(ngf * mult * 2),
                nn.ReLU(inplace=True)
            ]

        # ResNet blocks avec CBAM
        mult = 2 ** n_downsampling
        for i in range(n_blocks):
            model += [ResnetBlockCBAM(ngf * mult, use_cbam=use_cbam)]

        # Upsampling
        for i in range(n_downsampling):
            mult = 2 ** (n_downsampling - i)
            model += [
                nn.ConvTranspose2d(ngf * mult, ngf * mult // 2,
                                  kernel_size=3, stride=2,
                                  padding=1, output_padding=1, bias=False),
                nn.InstanceNorm2d(ngf * mult // 2),
                nn.ReLU(inplace=True)
            ]

        # Couche finale
        model += [
            nn.ReflectionPad2d(3),
            nn.Conv2d(ngf, output_nc, kernel_size=7),
            nn.Tanh()
        ]

        self.model = nn.Sequential(*model)

    def forward(self, x):
        return self.model(x)


# ============================================================================
# Discriminateur PatchGAN
# ============================================================================

class PatchGANDiscriminator(nn.Module):
    """
    Discriminateur PatchGAN (70x70 receptive field)
    Classe des patches comme r√©els/faux pour un feedback local
    """
    def __init__(self, input_nc=3, ndf=64, n_layers=3):
        super().__init__()

        model = [
            nn.Conv2d(input_nc, ndf, kernel_size=4, stride=2, padding=1),
            nn.LeakyReLU(0.2, True)
        ]

        nf_mult = 1
        for n in range(1, n_layers):
            nf_mult_prev = nf_mult
            nf_mult = min(2 ** n, 8)
            model += [
                nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult,
                         kernel_size=4, stride=2, padding=1, bias=False),
                nn.InstanceNorm2d(ndf * nf_mult),
                nn.LeakyReLU(0.2, True)
            ]

        nf_mult_prev = nf_mult
        nf_mult = min(2 ** n_layers, 8)
        model += [
            nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult,
                     kernel_size=4, stride=1, padding=1, bias=False),
            nn.InstanceNorm2d(ndf * nf_mult),
            nn.LeakyReLU(0.2, True)
        ]

        model += [nn.Conv2d(ndf * nf_mult, 1, kernel_size=4, stride=1, padding=1)]

        self.model = nn.Sequential(*model)

    def forward(self, x):
        return self.model(x)


# ============================================================================
# Dataset pour Classe A (Real) et Classe B (Stellarium)
# ============================================================================

class CelestialDataset(Dataset):
    """
    Dataset pour navigation c√©leste
    Classe A : Images r√©elles du ciel nocturne
    Classe B : Images synth√©tiques Stellarium
    """
    def __init__(self, root_classA, root_classB, transform=None, max_images=None):
        self.root_classA = Path(root_classA)
        self.root_classB = Path(root_classB)
        self.transform = transform

        # Extensions support√©es
        extensions = ['*.jpg', '*.jpeg', '*.png', '*.JPG', '*.JPEG', '*.PNG']

        # Lister les fichiers Classe A (Real)
        self.classA_images = []
        for ext in extensions:
            self.classA_images.extend(list(self.root_classA.glob(ext)))
        self.classA_images = sorted(self.classA_images)

        # Lister les fichiers Classe B (Stellarium)
        self.classB_images = []
        for ext in extensions:
            self.classB_images.extend(list(self.root_classB.glob(ext)))
        self.classB_images = sorted(self.classB_images)

        # Limiter le nombre d'images si sp√©cifi√©
        if max_images:
            self.classA_images = self.classA_images[:max_images]
            self.classB_images = self.classB_images[:max_images]

        print(f"üìÇ Dataset charg√©:")
        print(f"   Classe A (Real): {len(self.classA_images)} images")
        print(f"   Classe B (Stellarium): {len(self.classB_images)} images")

        if len(self.classA_images) == 0 or len(self.classB_images) == 0:
            raise ValueError("‚ö†Ô∏è  Aucune image trouv√©e! V√©rifiez les chemins des dossiers.")

    def __len__(self):
        return max(len(self.classA_images), len(self.classB_images))

    def __getitem__(self, idx):
        # Lecture cyclique si les tailles sont diff√©rentes
        classA_img = Image.open(
            self.classA_images[idx % len(self.classA_images)]
        ).convert('RGB')

        classB_img = Image.open(
            self.classB_images[idx % len(self.classB_images)]
        ).convert('RGB')

        if self.transform:
            classA_img = self.transform(classA_img)
            classB_img = self.transform(classB_img)

        return classA_img, classB_img


# ============================================================================
# Utilitaires
# ============================================================================

def init_weights(net, init_type='normal', init_gain=0.02):
    """Initialisation des poids"""
    def init_func(m):
        classname = m.__class__.__name__
        if hasattr(m, 'weight') and (classname.find('Conv') != -1 or
                                     classname.find('Linear') != -1):
            if init_type == 'normal':
                nn.init.normal_(m.weight.data, 0.0, init_gain)
            elif init_type == 'xavier':
                nn.init.xavier_normal_(m.weight.data, gain=init_gain)
            elif init_type == 'kaiming':
                nn.init.kaiming_normal_(m.weight.data, a=0, mode='fan_in')
            if hasattr(m, 'bias') and m.bias is not None:
                nn.init.constant_(m.bias.data, 0.0)
        elif classname.find('BatchNorm') != -1 or classname.find('InstanceNorm') != -1:
            if hasattr(m, 'weight') and m.weight is not None:
                nn.init.normal_(m.weight.data, 1.0, init_gain)
            if hasattr(m, 'bias') and m.bias is not None:
                nn.init.constant_(m.bias.data, 0.0)

    net.apply(init_func)


class ImagePool:
    """Buffer d'images pour stabiliser l'entra√Ænement du discriminateur"""
    def __init__(self, pool_size=50):
        self.pool_size = pool_size
        self.images = []

    def query(self, images):
        if self.pool_size == 0:
            return images

        return_images = []
        for image in images:
            image = image.unsqueeze(0)
            if len(self.images) < self.pool_size:
                self.images.append(image)
                return_images.append(image)
            else:
                p = np.random.uniform(0, 1)
                if p > 0.5:
                    random_id = np.random.randint(0, self.pool_size)
                    tmp = self.images[random_id].clone()
                    self.images[random_id] = image
                    return_images.append(tmp)
                else:
                    return_images.append(image)

        return torch.cat(return_images, 0)


# ============================================================================
# Trainer CycleGAN
# ============================================================================

class CycleGANTrainer:
    """Entra√Ænement CycleGAN pour navigation c√©leste"""
    def __init__(self, config):
        self.config = config
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        print(f"üñ•Ô∏è  Device: {self.device}")

        # Cr√©er r√©pertoires
        os.makedirs(config['output_dir'], exist_ok=True)
        os.makedirs(f"{config['output_dir']}/samples", exist_ok=True)
        os.makedirs(f"{config['output_dir']}/checkpoints", exist_ok=True)

        # Historique des pertes
        self.history = {
            'G_loss': [], 'D_loss': [], 'cycle_loss': [],
            'identity_loss': [], 'GAN_loss': []
        }

        # Initialiser mod√®les
        self._build_models()
        self._setup_optimizers()

        # Crit√®res de perte
        self.criterion_GAN = nn.MSELoss()  # LSGAN
        self.criterion_cycle = nn.L1Loss()
        self.criterion_identity = nn.L1Loss()

        # Pools d'images
        self.fake_real_pool = ImagePool(config['pool_size'])
        self.fake_sim_pool = ImagePool(config['pool_size'])

    def _build_models(self):
        """Construire les g√©n√©rateurs et discriminateurs"""
        config = self.config

        print("üî® Construction des mod√®les...")

        # G√©n√©rateurs
        self.G_A2B = GeneratorResNetCBAM(
            input_nc=3, output_nc=3, ngf=config['ngf'],
            n_blocks=config['n_blocks'], use_cbam=config['use_cbam']
        ).to(self.device)

        self.G_B2A = GeneratorResNetCBAM(
            input_nc=3, output_nc=3, ngf=config['ngf'],
            n_blocks=config['n_blocks'], use_cbam=config['use_cbam']
        ).to(self.device)

        # Discriminateurs
        self.D_A = PatchGANDiscriminator(
            input_nc=3, ndf=config['ndf'], n_layers=3
        ).to(self.device)

        self.D_B = PatchGANDiscriminator(
            input_nc=3, ndf=config['ndf'], n_layers=3
        ).to(self.device)

        # Initialiser poids
        init_weights(self.G_A2B, config['init_type'])
        init_weights(self.G_B2A, config['init_type'])
        init_weights(self.D_A, config['init_type'])
        init_weights(self.D_B, config['init_type'])

        # Compter param√®tres
        g_params = sum(p.numel() for p in self.G_A2B.parameters())
        d_params = sum(p.numel() for p in self.D_A.parameters())
        print(f"   G_A2B: {g_params:,} param√®tres")
        print(f"   D_A: {d_params:,} param√®tres")

    def _setup_optimizers(self):
        """Configurer optimiseurs et schedulers"""
        config = self.config

        self.optimizer_G = Adam(
            chain(self.G_A2B.parameters(), self.G_B2A.parameters()),
            lr=config['lr'], betas=(config['beta1'], 0.999)
        )

        self.optimizer_D = Adam(
            chain(self.D_A.parameters(), self.D_B.parameters()),
            lr=config['lr'], betas=(config['beta1'], 0.999)
        )

        # Learning rate scheduler (decay lin√©aire)
        def lambda_rule(epoch):
            lr_l = 1.0 - max(0, epoch - config['n_epochs_decay']) / (
                config['n_epochs'] - config['n_epochs_decay'] + 1
            )
            return lr_l

        self.scheduler_G = lr_scheduler.LambdaLR(
            self.optimizer_G, lr_lambda=lambda_rule
        )
        self.scheduler_D = lr_scheduler.LambdaLR(
            self.optimizer_D, lr_lambda=lambda_rule
        )

    def train_epoch(self, dataloader, epoch):
        """Entra√Æner une √©poque"""
        self.G_A2B.train()
        self.G_B2A.train()
        self.D_A.train()
        self.D_B.train()

        losses = {'G': [], 'D': [], 'cycle': [], 'identity': [], 'GAN': []}

        pbar = tqdm(dataloader, desc=f"Epoch {epoch}/{self.config['n_epochs']}")

        for i, (real_A, real_B) in enumerate(pbar):
            real_A = real_A.to(self.device)
            real_B = real_B.to(self.device)
            batch_size = real_A.size(0)

            # Labels adversariaux
            valid = torch.ones((batch_size, 1, 30, 30), device=self.device)
            fake = torch.zeros((batch_size, 1, 30, 30), device=self.device)

            # ========================================
            # Entra√Æner G√©n√©rateurs
            # ========================================
            self.optimizer_G.zero_grad()

            # Identity loss (pr√©server identit√© si m√™me domaine)
            loss_id = 0
            if self.config['lambda_identity'] > 0:
                id_A = self.G_B2A(real_A)  # Devrait retourner real_A
                id_B = self.G_A2B(real_B)  # Devrait retourner real_B
                loss_id_A = self.criterion_identity(id_A, real_A)
                loss_id_B = self.criterion_identity(id_B, real_B)
                loss_id = (loss_id_A + loss_id_B) * 0.5

            # GAN loss
            fake_B = self.G_A2B(real_A)  # Real -> Stellarium
            pred_fake_B = self.D_B(fake_B)
            loss_GAN_A2B = self.criterion_GAN(pred_fake_B, valid)

            fake_A = self.G_B2A(real_B)  # Stellarium -> Real
            pred_fake_A = self.D_A(fake_A)
            loss_GAN_B2A = self.criterion_GAN(pred_fake_A, valid)

            loss_GAN = (loss_GAN_A2B + loss_GAN_B2A) * 0.5

            # Cycle consistency loss (A -> B -> A et B -> A -> B)
            recovered_A = self.G_B2A(fake_B)
            loss_cycle_A = self.criterion_cycle(recovered_A, real_A)

            recovered_B = self.G_A2B(fake_A)
            loss_cycle_B = self.criterion_cycle(recovered_B, real_B)

            loss_cycle = (loss_cycle_A + loss_cycle_B) * 0.5

            # Total generator loss
            loss_G = (loss_GAN +
                     self.config['lambda_cycle'] * loss_cycle +
                     self.config['lambda_identity'] * loss_id)

            loss_G.backward()
            self.optimizer_G.step()

            # ========================================
            # Entra√Æner Discriminateurs
            # ========================================
            self.optimizer_D.zero_grad()

            # D_A (discrimine classe A)
            pred_real_A = self.D_A(real_A)
            loss_D_A_real = self.criterion_GAN(pred_real_A, valid)

            fake_A_ = self.fake_real_pool.query(fake_A.detach())
            pred_fake_A = self.D_A(fake_A_)
            loss_D_A_fake = self.criterion_GAN(pred_fake_A, fake)

            loss_D_A = (loss_D_A_real + loss_D_A_fake) * 0.5

            # D_B (discrimine classe B)
            pred_real_B = self.D_B(real_B)
            loss_D_B_real = self.criterion_GAN(pred_real_B, valid)

            fake_B_ = self.fake_sim_pool.query(fake_B.detach())
            pred_fake_B = self.D_B(fake_B_)
            loss_D_B_fake = self.criterion_GAN(pred_fake_B, fake)

            loss_D_B = (loss_D_B_real + loss_D_B_fake) * 0.5

            loss_D = (loss_D_A + loss_D_B) * 0.5

            loss_D.backward()
            self.optimizer_D.step()

            # Sauvegarder pertes
            losses['G'].append(loss_G.item())
            losses['D'].append(loss_D.item())
            losses['cycle'].append(loss_cycle.item())
            losses['identity'].append(loss_id.item() if loss_id != 0 else 0)
            losses['GAN'].append(loss_GAN.item())

            # Mise √† jour barre de progression
            pbar.set_postfix({
                'G': f"{loss_G.item():.3f}",
                'D': f"{loss_D.item():.3f}",
                'Cycle': f"{loss_cycle.item():.3f}"
            })

            # Sauvegarder √©chantillons
            if i % self.config['sample_interval'] == 0:
                self.save_samples(epoch, i, real_A, real_B, fake_B, fake_A)

        # Moyennes
        avg_losses = {k: np.mean(v) for k, v in losses.items()}

        return avg_losses

    def save_samples(self, epoch, batch, real_A, real_B, fake_B, fake_A):
        """Sauvegarder √©chantillons de traduction"""
        samples_dir = f"{self.config['output_dir']}/samples"

        n_samples = min(4, real_A.size(0))

        # Cr√©er grille: Real A | Fake B | Real B | Fake A
        comparison = torch.cat([
            real_A[:n_samples],
            fake_B[:n_samples],
            real_B[:n_samples],
            fake_A[:n_samples]
        ], dim=0)

        save_image(comparison,
                  f"{samples_dir}/epoch_{epoch:03d}_batch_{batch:04d}.png",
                  nrow=n_samples, normalize=True, value_range=(-1, 1))

    def save_checkpoint(self, epoch, is_best=False):
        """Sauvegarder checkpoint"""
        checkpoint_dir = f"{self.config['output_dir']}/checkpoints"

        checkpoint = {
            'epoch': epoch,
            'G_A2B': self.G_A2B.state_dict(),
            'G_B2A': self.G_B2A.state_dict(),
            'D_A': self.D_A.state_dict(),
            'D_B': self.D_B.state_dict(),
            'optimizer_G': self.optimizer_G.state_dict(),
            'optimizer_D': self.optimizer_D.state_dict(),
            'history': self.history,
            'config': self.config
        }

        # Sauvegarder dernier checkpoint
        torch.save(checkpoint, f"{checkpoint_dir}/latest.pth")

        # Sauvegarder checkpoint p√©riodique
        if epoch % self.config['checkpoint_interval'] == 0:
            torch.save(checkpoint, f"{checkpoint_dir}/epoch_{epoch:03d}.pth")

        # Sauvegarder meilleur mod√®le
        if is_best:
            torch.save(checkpoint, f"{checkpoint_dir}/best.pth")

        print(f"   üíæ Checkpoint sauvegard√© (epoch {epoch})")

    def plot_losses(self):
        """Tracer courbes de perte"""
        fig, axes = plt.subplots(2, 2, figsize=(12, 8))
        fig.suptitle('Historique d\'entra√Ænement CycleGAN')

        axes[0, 0].plot(self.history['G_loss'])
        axes[0, 0].set_title('Generator Loss')
        axes[0, 0].set_xlabel('Epoch')

        axes[0, 1].plot(self.history['D_loss'])
        axes[0, 1].set_title('Discriminator Loss')
        axes[0, 1].set_xlabel('Epoch')

        axes[1, 0].plot(self.history['cycle_loss'])
        axes[1, 0].set_title('Cycle Consistency Loss')
        axes[1, 0].set_xlabel('Epoch')

        axes[1, 1].plot(self.history['GAN_loss'])
        axes[1, 1].set_title('GAN Loss')
        axes[1, 1].set_xlabel('Epoch')

        plt.tight_layout()
        plt.savefig(f"{self.config['output_dir']}/training_curves.png", dpi=150)
        plt.close()

    def train(self, dataloader):
        """Boucle d'entra√Ænement principale"""
        print("\n" + "="*70)
        print(f"üöÄ D√©marrage entra√Ænement CycleGAN Navigation C√©leste")
        print(f"   Epochs: {self.config['n_epochs']}")
        print(f"   Batch size: {self.config['batch_size']}")
        print(f"   CBAM: {self.config['use_cbam']}")
        print(f"   Device: {self.device}")
        print("="*70 + "\n")

        best_cycle_loss = float('inf')

        for epoch in range(1, self.config['n_epochs'] + 1):
            avg_losses = self.train_epoch(dataloader, epoch)

            # Sauvegarder dans historique
            self.history['G_loss'].append(avg_losses['G'])
            self.history['D_loss'].append(avg_losses['D'])
            self.history['cycle_loss'].append(avg_losses['cycle'])
            self.history['identity_loss'].append(avg_losses['identity'])
            self.history['GAN_loss'].append(avg_losses['GAN'])

            # Afficher r√©sum√©
            print(f"\nüìä Epoch {epoch} - Pertes moyennes:")
            print(f"   G: {avg_losses['G']:.4f} | D: {avg_losses['D']:.4f}")
            print(f"   Cycle: {avg_losses['cycle']:.4f} | GAN: {avg_losses['GAN']:.4f}")

            # Mettre √† jour learning rates
            self.scheduler_G.step()
            self.scheduler_D.step()

            # Sauvegarder checkpoint
            is_best = avg_losses['cycle'] < best_cycle_loss
            if is_best:
                best_cycle_loss = avg_losses['cycle']

            self.save_checkpoint(epoch, is_best)

            # Tracer courbes
            if epoch % 5 == 0:
                self.plot_losses()

        # Sauvegarder historique
        with open(f"{self.config['output_dir']}/history.json", 'w') as f:
            json.dump(self.history, f, indent=2)

        print("\n‚úÖ Entra√Ænement termin√©!")
        print(f"üìÅ R√©sultats sauvegard√©s dans: {self.config['output_dir']}")


# ============================================================================
# Configuration
# ============================================================================

def get_config(classA_dir='/content/drive/MyDrive/entainement/classeA', classB_dir='/content/drive/MyDrive/entainement/classeB'):
    """Configuration d'entra√Ænement"""
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")

    return {
        # Chemins des donn√©es
        'classA_dir': classA_dir,  # Images r√©elles
        'classB_dir': classB_dir,  # Images Stellarium
        'output_dir': f'./outputs/cyclegan_{timestamp}',

        # Architecture
        'use_cbam': True,
        'ngf': 64,  # Filtres g√©n√©rateur
        'ndf': 64,  # Filtres discriminateur
        'n_blocks': 9,  # Nombre de ResBlocks
        'init_type': 'normal',

        # Hyperparam√®tres
        'n_epochs': 100,
        'n_epochs_decay': 50,  # Decay LR apr√®s cette √©poque
        'batch_size': 1,
        'lr': 0.0002,
        'beta1': 0.5,

        # Poids des pertes
        'lambda_cycle': 10.0,
        'lambda_identity': 0.5,

        # Entra√Ænement
        'pool_size': 50,
        'num_workers': 4,
        'sample_interval': 100,
        'checkpoint_interval': 10,

        # Images
        'img_size': 256,
        'max_images': None,  # None = toutes les images
    }


# ============================================================================
# Inf√©rence
# ============================================================================

class CycleGANInference:
    """Inf√©rence : traduire nouvelles images"""
    def __init__(self, checkpoint_path, device='cuda'):
        self.device = torch.device(device if torch.cuda.is_available() else 'cpu')
        print(f"üñ•Ô∏è  Chargement mod√®le sur {self.device}...")

        # Charger checkpoint
        checkpoint = torch.load(checkpoint_path, map_location=self.device)

        # Cr√©er g√©n√©rateurs
        self.G_A2B = GeneratorResNetCBAM(
            input_nc=3, output_nc=3, ngf=64, n_blocks=9, use_cbam=True
        ).to(self.device)

        self.G_B2A = GeneratorResNetCBAM(
            input_nc=3, output_nc=3, ngf=64, n_blocks=9, use_cbam=True
        ).to(self.device)

        # Charger poids
        self.G_A2B.load_state_dict(checkpoint['G_A2B'])
        self.G_B2A.load_state_dict(checkpoint['G_B2A'])

        self.G_A2B.eval()
        self.G_B2A.eval()

        # Transforms
        self.transform = transforms.Compose([
            transforms.Resize((256, 256)),
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
        ])

        print("‚úÖ Mod√®le charg√© avec succ√®s!")

    def translate_A2B(self, image_path, output_path):
        """Traduire Real (A) -> Stellarium (B)"""
        img = Image.open(image_path).convert('RGB')
        img_tensor = self.transform(img).unsqueeze(0).to(self.device)

        with torch.no_grad():
            fake_B = self.G_A2B(img_tensor)

        save_image(fake_B, output_path, normalize=True, value_range=(-1, 1))
        print(f"‚úÖ Traduit: {output_path}")

    def translate_B2A(self, image_path, output_path):
        """Traduire Stellarium (B) -> Real (A)"""
        img = Image.open(image_path).convert('RGB')
        img_tensor = self.transform(img).unsqueeze(0).to(self.device)

        with torch.no_grad():
            fake_A = self.G_B2A(img_tensor)

        save_image(fake_A, output_path, normalize=True, value_range=(-1, 1))
        print(f"‚úÖ Traduit: {output_path}")

    def batch_translate(self, input_dir, output_dir, direction='A2B'):
        """Traduire un dossier d'images"""
        os.makedirs(output_dir, exist_ok=True)

        extensions = ['*.jpg', '*.jpeg', '*.png', '*.JPG', '*.JPEG', '*.PNG']
        image_files = []
        for ext in extensions:
            image_files.extend(list(Path(input_dir).glob(ext)))

        print(f"üîÑ Traduction de {len(image_files)} images ({direction})...")

        for img_path in tqdm(image_files):
            output_path = Path(output_dir) / f"translated_{img_path.name}"

            if direction == 'A2B':
                self.translate_A2B(str(img_path), str(output_path))
            else:
                self.translate_B2A(str(img_path), str(output_path))

    def create_comparison(self, image_path, output_path, direction='A2B'):
        """Cr√©er image de comparaison c√¥te √† c√¥te"""
        img = Image.open(image_path).convert('RGB')
        img_tensor = self.transform(img).unsqueeze(0).to(self.device)

        with torch.no_grad():
            if direction == 'A2B':
                fake = self.G_A2B(img_tensor)
            else:
                fake = self.G_B2A(img_tensor)

        # Cr√©er comparaison
        comparison = torch.cat([img_tensor, fake], dim=3)
        save_image(comparison, output_path, normalize=True, value_range=(-1, 1))
        print(f"‚úÖ Comparaison cr√©√©e: {output_path}")


# ============================================================================
# Script principal
# ============================================================================

def main():
    """Script d'entra√Ænement principal"""
    print("="*70)
    print("   CycleGAN - Navigation C√©leste")
    print("   Real (Classe A) ‚Üî Stellarium (Classe B)")
    print("="*70 + "\n")

    # Configuration
    config = get_config(
        classA_dir='./classeA',  # üìÅ Dossier images r√©elles
        classB_dir='./classeB'   # üìÅ Dossier images Stellarium
    )

    # V√©rifier dossiers
    if not os.path.exists(config['classA_dir']):
        print(f"‚ùå Erreur: Dossier '{config['classA_dir']}' introuvable!")
        print("   Cr√©ez le dossier et placez-y vos images r√©elles.")
        return

    if not os.path.exists(config['classB_dir']):
        print(f"‚ùå Erreur: Dossier '{config['classB_dir']}' introuvable!")
        print("   Cr√©ez le dossier et placez-y vos images Stellarium.")
        return

    # Transforms avec augmentation
    transform = transforms.Compose([
        transforms.Resize((config['img_size'], config['img_size'])),
        transforms.RandomHorizontalFlip(p=0.5),
        transforms.ColorJitter(brightness=0.1, contrast=0.1, saturation=0.1),
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
    ])

    # Cr√©er dataset
    print("üìÇ Chargement du dataset...")
    try:
        dataset = CelestialDataset(
            root_classA=config['classA_dir'],
            root_classB=config['classB_dir'],
            transform=transform,
            max_images=config['max_images']
        )
    except ValueError as e:
        print(f"‚ùå {e}")
        return

    # DataLoader
    dataloader = DataLoader(
        dataset,
        batch_size=config['batch_size'],
        shuffle=True,
        num_workers=config['num_workers'],
        pin_memory=True,
        drop_last=True
    )

    print(f"‚úÖ DataLoader cr√©√©: {len(dataloader)} batches\n")

    # Cr√©er trainer et entra√Æner
    trainer = CycleGANTrainer(config)
    trainer.train(dataloader)


# ============================================================================
# Fonctions utilitaires
# ============================================================================

def test_inference():
    """Tester l'inf√©rence avec un mod√®le entra√Æn√©"""
    print("üß™ Test d'inf√©rence...")

    # Chemin vers checkpoint (√† adapter)
    checkpoint = './outputs/cyclegan_XXXXXX/checkpoints/best.pth'

    if not os.path.exists(checkpoint):
        print(f"‚ùå Checkpoint introuvable: {checkpoint}")
        print("   Entra√Ænez d'abord le mod√®le avec main()")
        return

    # Cr√©er inf√©renceur
    inferencer = CycleGANInference(checkpoint)

    # Test sur une image
    test_image = './classeA/test_image.jpg'
    if os.path.exists(test_image):
        inferencer.create_comparison(
            test_image,
            './test_comparison.png',
            direction='A2B'
        )

    # Traduire un dossier complet
    inferencer.batch_translate(
        input_dir='./classeA',
        output_dir='./translated_to_stellarium',
        direction='A2B'
    )


def quick_test():
    """Test rapide avec donn√©es synth√©tiques"""
    print("üß™ Test rapide avec donn√©es synth√©tiques...\n")

    config = get_config()
    config['n_epochs'] = 2
    config['batch_size'] = 2
    config['checkpoint_interval'] = 1

    # Dataset synth√©tique
    class DummyDataset(Dataset):
        def __len__(self):
            return 20

        def __getitem__(self, idx):
            return torch.randn(3, 256, 256), torch.randn(3, 256, 256)

    dataloader = DataLoader(DummyDataset(), batch_size=2, shuffle=True)

    trainer = CycleGANTrainer(config)
    trainer.train(dataloader)

    print("\n‚úÖ Test rapide r√©ussi!")


def create_sample_structure():
    """Cr√©er structure de dossiers exemple"""
    print("üìÅ Cr√©ation de la structure de dossiers...\n")

    folders = [
        './classeA',  # Images r√©elles
        './classeB',  # Images Stellarium
        './outputs',
    ]

    for folder in folders:
        os.makedirs(folder, exist_ok=True)
        print(f"   ‚úì {folder}")

    print("\nüìã Instructions:")
    print("   1. Placez vos images R√âELLES du ciel dans: ./classeA/")
    print("   2. Placez vos images STELLARIUM dans: ./classeB/")
    print("   3. Lancez l'entra√Ænement avec: python script.py --mode train")
    print("\n   Formats support√©s: .jpg, .jpeg, .png")


# ============================================================================
# Point d'entr√©e
# ============================================================================

if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser(
        description='CycleGAN pour Navigation C√©leste (Real ‚Üî Stellarium)'
    )
    parser.add_argument('--mode', type=str, default='train',
                       choices=['train', 'inference', 'test', 'setup'],
                       help='Mode d\'ex√©cution')
    parser.add_argument('--classA', type=str, default='./classeA',
                       help='Dossier images r√©elles (Classe A)')
    parser.add_argument('--classB', type=str, default='./classeB',
                       help='Dossier images Stellarium (Classe B)')
    parser.add_argument('--checkpoint', type=str, default=None,
                       help='Chemin checkpoint pour inf√©rence')
    parser.add_argument('--input', type=str, default=None,
                       help='Image/dossier d\'entr√©e pour inf√©rence')
    parser.add_argument('--output', type=str, default='./translated',
                       help='Dossier de sortie pour inf√©rence')
    parser.add_argument('--direction', type=str, default='A2B',
                       choices=['A2B', 'B2A'],
                       help='Direction traduction: A2B (Real->Stellarium) ou B2A')
    parser.add_argument('--epochs', type=int, default=100,
                       help='Nombre d\'√©poques')
    parser.add_argument('--batch_size', type=int, default=1,
                       help='Taille du batch')

    args = parser.parse_args()

    if args.mode == 'setup':
        # Cr√©er structure de dossiers
        create_sample_structure()

    elif args.mode == 'train':
        # Entra√Ænement
        config = get_config(args.classA, args.classB)
        config['n_epochs'] = args.epochs
        config['batch_size'] = args.batch_size

        # Transforms
        transform = transforms.Compose([
            transforms.Resize((config['img_size'], config['img_size'])),
            transforms.RandomHorizontalFlip(p=0.5),
            transforms.ColorJitter(brightness=0.1, contrast=0.1, saturation=0.1),
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
        ])

        try:
            dataset = CelestialDataset(
                root_classA=config['classA_dir'],
                root_classB=config['classB_dir'],
                transform=transform,
                max_images=config['max_images']
            )

            dataloader = DataLoader(
                dataset,
                batch_size=config['batch_size'],
                shuffle=True,
                num_workers=config['num_workers'],
                pin_memory=True,
                drop_last=True
            )

            trainer = CycleGANTrainer(config)
            trainer.train(dataloader)

        except Exception as e:
            print(f"‚ùå Erreur: {e}")

    elif args.mode == 'inference':
        # Inf√©rence
        if not args.checkpoint:
            print("‚ùå Erreur: --checkpoint requis pour l'inf√©rence")
        elif not args.input:
            print("‚ùå Erreur: --input requis pour l'inf√©rence")
        else:
            inferencer = CycleGANInference(args.checkpoint)

            if os.path.isdir(args.input):
                # Traduire dossier
                inferencer.batch_translate(
                    args.input,
                    args.output,
                    direction=args.direction
                )
            else:
                # Traduire image unique
                os.makedirs(args.output, exist_ok=True)
                output_path = os.path.join(args.output, 'translated.png')

                if args.direction == 'A2B':
                    inferencer.translate_A2B(args.input, output_path)
                else:
                    inferencer.translate_B2A(args.input, output_path)

    elif args.mode == 'test':
        # Test rapide
        quick_test()

Writing script.py


In [6]:
import shutil
import os

source = "/content/drive/MyDrive/entainement/classeB"
destination = "/content/classeB"

# Cr√©er le dossier destination s'il n'existe pas
os.makedirs(destination, exist_ok=True)

# Copier tous les fichiers
for filename in os.listdir(source):
    source_path = os.path.join(source, filename)
    destination_path = os.path.join(destination, filename)

    if os.path.isfile(source_path):  # V√©rifier que c'est un fichier
        shutil.copy(source_path, destination_path)
        print(f"Copi√©: {filename}")

Copi√©: image_00001.jpg
Copi√©: image_00002.jpg
Copi√©: image_00003.jpg
Copi√©: image_00004.jpg
Copi√©: image_00005.jpg
Copi√©: image_00006.jpg
Copi√©: image_00007.jpg
Copi√©: image_00008.jpg
Copi√©: image_00009.jpg
Copi√©: image_00010.jpg
Copi√©: image_00011.jpg
Copi√©: image_00012.jpg
Copi√©: image_00013.jpg
Copi√©: image_00014.jpg
Copi√©: image_00015.jpg
Copi√©: image_00016.jpg
Copi√©: image_00017.jpg
Copi√©: image_00018.jpg
Copi√©: image_00019.jpg
Copi√©: image_00020.jpg
Copi√©: image_00021.jpg
Copi√©: image_00022.jpg
Copi√©: image_00023.jpg
Copi√©: image_00024.jpg
Copi√©: image_00025.jpg
Copi√©: image_00026.jpg
Copi√©: image_00027.jpg
Copi√©: image_00028.jpg
Copi√©: image_00029.jpg
Copi√©: image_00030.jpg
Copi√©: image_00031.jpg
Copi√©: image_00032.jpg
Copi√©: image_00033.jpg
Copi√©: image_00034.jpg
Copi√©: image_00035.jpg
Copi√©: image_00036.jpg
Copi√©: image_00037.jpg
Copi√©: image_00038.jpg
Copi√©: image_00039.jpg
Copi√©: image_00040.jpg
Copi√©: image_00041.jpg
Copi√©: image_00

In [19]:
%%writefile script.py

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.optim import AdamW
from torch.cuda.amp import autocast, GradScaler
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from torchvision.utils import save_image
import os
from pathlib import Path
from tqdm import tqdm
import numpy as np
from PIL import Image
from datetime import datetime

# ============================================================================
# ARCHITECTURE L√âG√àRE - MobileNet Style (CORRIG√âE)
# ============================================================================

class DepthwiseSeparableConv(nn.Module):
    """Convolution s√©parable en profondeur (9x moins de param√®tres)"""
    def __init__(self, in_ch, out_ch, stride=1):
        super().__init__()
        self.depthwise = nn.Conv2d(in_ch, in_ch, 3, stride, 1, groups=in_ch, bias=False)
        self.pointwise = nn.Conv2d(in_ch, out_ch, 1, bias=False)
        self.norm = nn.InstanceNorm2d(out_ch)
        # ‚úÖ CORRECTION: Retirer inplace=True
        self.relu = nn.ReLU(inplace=False)

    def forward(self, x):
        x = self.depthwise(x)
        x = self.pointwise(x)
        x = self.norm(x)
        return self.relu(x)


class LightweightResBlock(nn.Module):
    """Bloc r√©siduel ultra-l√©ger avec Squeeze-Excitation"""
    def __init__(self, channels, reduction=4):
        super().__init__()
        self.conv1 = DepthwiseSeparableConv(channels, channels)
        self.conv2 = DepthwiseSeparableConv(channels, channels)

        # Squeeze-Excitation (attention canal simple)
        self.se = nn.Sequential(
            nn.AdaptiveAvgPool2d(1),
            nn.Conv2d(channels, channels // reduction, 1),
            nn.ReLU(inplace=False),  # ‚úÖ CORRECTION
            nn.Conv2d(channels // reduction, channels, 1),
            nn.Sigmoid()
        )

    def forward(self, x):
        residual = x
        out = self.conv1(x)
        out = self.conv2(out)
        out = out * self.se(out)  # Attention
        return residual + out * 0.1  # Scaled residual


class FastGenerator(nn.Module):
    """G√©n√©rateur ultra-rapide pour drone"""
    def __init__(self, ngf=32, n_blocks=6):
        super().__init__()

        # Encoder (downsampling)
        self.encoder = nn.Sequential(
            nn.ReflectionPad2d(3),
            nn.Conv2d(3, ngf, 7, bias=False),
            nn.InstanceNorm2d(ngf),
            nn.ReLU(inplace=False),  # ‚úÖ CORRECTION

            # Down 1: 128x128
            DepthwiseSeparableConv(ngf, ngf * 2, stride=2),
            # Down 2: 64x64
            DepthwiseSeparableConv(ngf * 2, ngf * 4, stride=2),
        )

        # Transformation (ResBlocks l√©gers)
        transform = []
        for _ in range(n_blocks):
            transform.append(LightweightResBlock(ngf * 4))
        self.transform = nn.Sequential(*transform)

        # Decoder (upsampling)
        self.decoder = nn.Sequential(
            # Up 1: 128x128
            nn.ConvTranspose2d(ngf * 4, ngf * 2, 3, 2, 1, 1, bias=False),
            nn.InstanceNorm2d(ngf * 2),
            nn.ReLU(inplace=False),  # ‚úÖ CORRECTION

            # Up 2: 256x256
            nn.ConvTranspose2d(ngf * 2, ngf, 3, 2, 1, 1, bias=False),
            nn.InstanceNorm2d(ngf),
            nn.ReLU(inplace=False),  # ‚úÖ CORRECTION

            # Output
            nn.ReflectionPad2d(3),
            nn.Conv2d(ngf, 3, 7),
            nn.Tanh()
        )

    def forward(self, x):
        x = self.encoder(x)
        x = self.transform(x)
        return self.decoder(x)


class FastDiscriminator(nn.Module):
    """Discriminateur l√©ger (70x70 PatchGAN)"""
    def __init__(self, ndf=32):
        super().__init__()

        self.model = nn.Sequential(
            # 128x128
            nn.Conv2d(3, ndf, 4, 2, 1),
            nn.LeakyReLU(0.2, inplace=False),  # ‚úÖ CORRECTION

            # 64x64
            nn.Conv2d(ndf, ndf * 2, 4, 2, 1, bias=False),
            nn.InstanceNorm2d(ndf * 2),
            nn.LeakyReLU(0.2, inplace=False),  # ‚úÖ CORRECTION

            # 32x32
            nn.Conv2d(ndf * 2, ndf * 4, 4, 2, 1, bias=False),
            nn.InstanceNorm2d(ndf * 4),
            nn.LeakyReLU(0.2, inplace=False),  # ‚úÖ CORRECTION

            # 31x31
            nn.Conv2d(ndf * 4, 1, 4, 1, 1)
        )

    def forward(self, x):
        return self.model(x)


# ============================================================================
# DATASET OPTIMIS√â
# ============================================================================

class FastCelestialDataset(Dataset):
    """Dataset avec cache en m√©moire et augmentation rapide"""
    def __init__(self, root_A, root_B, img_size=128, cache_size=500):
        self.img_size = img_size
        self.cache = {}
        self.cache_size = cache_size

        # Charger chemins
        extensions = ['*.jpg', '*.jpeg', '*.png', '*.JPG', '*.JPEG', '*.PNG']
        self.files_A = []
        self.files_B = []

        for ext in extensions:
            self.files_A.extend(list(Path(root_A).glob(ext)))
            self.files_B.extend(list(Path(root_B).glob(ext)))

        self.files_A = sorted(self.files_A)
        self.files_B = sorted(self.files_B)

        print(f"üìÇ Dataset: {len(self.files_A)} Real | {len(self.files_B)} Stellarium")

        if not self.files_A or not self.files_B:
            raise ValueError("‚ùå Aucune image trouv√©e!")

        # Transforms optimis√©s
        self.transform = transforms.Compose([
            transforms.Resize((img_size, img_size), transforms.InterpolationMode.BILINEAR),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            transforms.Normalize([0.5]*3, [0.5]*3)
        ])

    def __len__(self):
        return max(len(self.files_A), len(self.files_B))

    def _load_image(self, path, cache_key):
        """Chargement avec cache"""
        if cache_key in self.cache:
            return self.cache[cache_key]

        img = Image.open(path).convert('RGB')
        img = self.transform(img)

        # Cache si pas trop grand
        if len(self.cache) < self.cache_size:
            self.cache[cache_key] = img

        return img

    def __getitem__(self, idx):
        file_A = self.files_A[idx % len(self.files_A)]
        file_B = self.files_B[idx % len(self.files_B)]

        img_A = self._load_image(file_A, f"A_{idx % len(self.files_A)}")
        img_B = self._load_image(file_B, f"B_{idx % len(self.files_B)}")

        return img_A, img_B


# ============================================================================
# TRAINER OPTIMIS√â (CORRIG√â)
# ============================================================================

class FastCycleGANTrainer:
    """Entra√Ænement rapide avec mixed precision et optimisations"""
    def __init__(self, config):
        self.config = config
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

        print(f"üöÄ Device: {self.device}")
        if torch.cuda.is_available():
            print(f"   GPU: {torch.cuda.get_device_name()}")

        os.makedirs(config['output_dir'], exist_ok=True)
        os.makedirs(f"{config['output_dir']}/samples", exist_ok=True)
        os.makedirs(f"{config['output_dir']}/checkpoints", exist_ok=True)

        self.history = {'G': [], 'D': [], 'cycle': []}
        self._build_models()
        self._setup_optimizers()

        # Mixed precision scaler
        self.scaler = GradScaler() if config['use_amp'] else None

        # Crit√®res
        self.criterion_GAN = nn.MSELoss()
        self.criterion_cycle = nn.L1Loss()

    def _build_models(self):
        """Construire mod√®les l√©gers"""
        config = self.config

        self.G_A2B = FastGenerator(config['ngf'], config['n_blocks']).to(self.device)
        self.G_B2A = FastGenerator(config['ngf'], config['n_blocks']).to(self.device)
        self.D_A = FastDiscriminator(config['ndf']).to(self.device)
        self.D_B = FastDiscriminator(config['ndf']).to(self.device)

        # Compter param√®tres
        g_params = sum(p.numel() for p in self.G_A2B.parameters()) / 1e6
        d_params = sum(p.numel() for p in self.D_A.parameters()) / 1e6
        print(f"üìä G√©n√©rateur: {g_params:.2f}M params")
        print(f"üìä Discriminateur: {d_params:.2f}M params")

    def _setup_optimizers(self):
        """AdamW avec cosine annealing"""
        config = self.config

        self.opt_G = AdamW(
            list(self.G_A2B.parameters()) + list(self.G_B2A.parameters()),
            lr=config['lr'],
            betas=(0.5, 0.999),
            weight_decay=1e-4
        )

        self.opt_D = AdamW(
            list(self.D_A.parameters()) + list(self.D_B.parameters()),
            lr=config['lr'],
            betas=(0.5, 0.999),
            weight_decay=1e-4
        )

        # Cosine annealing pour convergence rapide
        self.sched_G = torch.optim.lr_scheduler.CosineAnnealingLR(
            self.opt_G, T_max=config['n_epochs']
        )
        self.sched_D = torch.optim.lr_scheduler.CosineAnnealingLR(
            self.opt_D, T_max=config['n_epochs']
        )

    def train_step(self, real_A, real_B):
        """Un pas d'entra√Ænement optimis√© - CORRIG√â"""
        batch_size = real_A.size(0)

        # ‚úÖ CORRECTION: Cr√©er les labels AVANT l'autocast
        # Obtenir la taille de sortie du discriminateur
        with torch.no_grad():
            sample_output = self.D_A(real_A[:1])
            patch_size = sample_output.shape[-2:]  # (H, W)

        # Cr√©er les labels avec la bonne taille
        valid = torch.ones(batch_size, 1, *patch_size, device=self.device, dtype=torch.float32)
        fake_label = torch.zeros(batch_size, 1, *patch_size, device=self.device, dtype=torch.float32)

        # ==================== G√âN√âRATEURS ====================
        self.opt_G.zero_grad(set_to_none=True)

        with autocast(enabled=self.config['use_amp']):
            # G√©n√©ration
            fake_B = self.G_A2B(real_A)
            fake_A = self.G_B2A(real_B)

            # GAN loss
            pred_fake_B = self.D_B(fake_B)
            loss_GAN_A2B = self.criterion_GAN(pred_fake_B, valid)

            pred_fake_A = self.D_A(fake_A)
            loss_GAN_B2A = self.criterion_GAN(pred_fake_A, valid)

            # Cycle consistency
            recovered_A = self.G_B2A(fake_B)
            recovered_B = self.G_A2B(fake_A)
            loss_cycle = (
                self.criterion_cycle(recovered_A, real_A) +
                self.criterion_cycle(recovered_B, real_B)
            ) * self.config['lambda_cycle']

            # Total G loss
            loss_G = loss_GAN_A2B + loss_GAN_B2A + loss_cycle

        if self.scaler:
            self.scaler.scale(loss_G).backward()
            self.scaler.step(self.opt_G)
        else:
            loss_G.backward()
            self.opt_G.step()

        # ==================== DISCRIMINATEURS ====================
        self.opt_D.zero_grad(set_to_none=True)

        with autocast(enabled=self.config['use_amp']):
            # D_A
            pred_real_A = self.D_A(real_A)
            loss_D_real_A = self.criterion_GAN(pred_real_A, valid)

            pred_fake_A = self.D_A(fake_A.detach())
            loss_D_fake_A = self.criterion_GAN(pred_fake_A, fake_label)

            # D_B
            pred_real_B = self.D_B(real_B)
            loss_D_real_B = self.criterion_GAN(pred_real_B, valid)

            pred_fake_B = self.D_B(fake_B.detach())
            loss_D_fake_B = self.criterion_GAN(pred_fake_B, fake_label)

            loss_D = (loss_D_real_A + loss_D_fake_A +
                     loss_D_real_B + loss_D_fake_B) * 0.5

        if self.scaler:
            self.scaler.scale(loss_D).backward()
            self.scaler.step(self.opt_D)
            self.scaler.update()
        else:
            loss_D.backward()
            self.opt_D.step()

        return {
            'G': loss_G.item(),
            'D': loss_D.item(),
            'cycle': loss_cycle.item()
        }

    def train_epoch(self, dataloader, epoch):
        """Entra√Æner une √©poque"""
        self.G_A2B.train()
        self.G_B2A.train()
        self.D_A.train()
        self.D_B.train()

        losses = {'G': [], 'D': [], 'cycle': []}

        pbar = tqdm(dataloader, desc=f"Epoch {epoch}/{self.config['n_epochs']}")

        for i, (real_A, real_B) in enumerate(pbar):
            real_A = real_A.to(self.device, non_blocking=True)
            real_B = real_B.to(self.device, non_blocking=True)

            loss_dict = self.train_step(real_A, real_B)

            for k, v in loss_dict.items():
                losses[k].append(v)

            pbar.set_postfix({
                'G': f"{loss_dict['G']:.3f}",
                'D': f"{loss_dict['D']:.3f}",
                'Cyc': f"{loss_dict['cycle']:.3f}"
            })

            # √âchantillons
            if i == 0:
                self.save_samples(epoch, real_A, real_B)

        return {k: np.mean(v) for k, v in losses.items()}

    def save_samples(self, epoch, real_A, real_B):
        """Sauvegarder √©chantillons"""
        self.G_A2B.eval()
        self.G_B2A.eval()

        with torch.no_grad():
            fake_B = self.G_A2B(real_A[:4])
            fake_A = self.G_B2A(real_B[:4])

        comparison = torch.cat([
            real_A[:4], fake_B, real_B[:4], fake_A
        ], dim=0)

        save_image(
            comparison,
            f"{self.config['output_dir']}/samples/epoch_{epoch:03d}.png",
            nrow=4,
            normalize=True,
            value_range=(-1, 1)
        )

        self.G_A2B.train()
        self.G_B2A.train()

    def save_checkpoint(self, epoch):
        """Sauvegarder checkpoint"""
        checkpoint = {
            'epoch': epoch,
            'G_A2B': self.G_A2B.state_dict(),
            'G_B2A': self.G_B2A.state_dict(),
            'D_A': self.D_A.state_dict(),
            'D_B': self.D_B.state_dict(),
            'opt_G': self.opt_G.state_dict(),
            'opt_D': self.opt_D.state_dict(),
            'config': self.config
        }

        path = f"{self.config['output_dir']}/checkpoints/checkpoint_{epoch:03d}.pth"
        torch.save(checkpoint, path)

        # Garder seulement les 3 derniers
        checkpoints = sorted(Path(f"{self.config['output_dir']}/checkpoints").glob("checkpoint_*.pth"))
        if len(checkpoints) > 3:
            checkpoints[0].unlink()

    def train(self, dataloader):
        """Boucle principale"""
        print("\n" + "="*70)
        print("üöÄ ENTRA√éNEMENT CYCLEGAN RAPIDE")
        print(f"   Epochs: {self.config['n_epochs']}")
        print(f"   Batch: {self.config['batch_size']}")
        print(f"   Mixed Precision: {self.config['use_amp']}")
        print("="*70 + "\n")

        for epoch in range(1, self.config['n_epochs'] + 1):
            avg_losses = self.train_epoch(dataloader, epoch)

            for k, v in avg_losses.items():
                self.history[k].append(v)

            print(f"\nüìä Epoch {epoch}: G={avg_losses['G']:.4f} | "
                  f"D={avg_losses['D']:.4f} | Cycle={avg_losses['cycle']:.4f}")

            self.sched_G.step()
            self.sched_D.step()

            if epoch % self.config['save_interval'] == 0:
                self.save_checkpoint(epoch)

        # Sauvegarder final
        torch.save(
            self.G_A2B.state_dict(),
            f"{self.config['output_dir']}/G_A2B_final.pth"
        )
        torch.save(
            self.G_B2A.state_dict(),
            f"{self.config['output_dir']}/G_B2A_final.pth"
        )

        print("\n‚úÖ Entra√Ænement termin√©!")
        print(f"üìÅ Mod√®les sauvegard√©s dans: {self.config['output_dir']}")


# ============================================================================
# CONFIGURATION
# ============================================================================

def get_fast_config():
    """Configuration optimis√©e pour entra√Ænement rapide"""
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")

    return {
        # Donn√©es
        'classA_dir': './classeA',
        'classB_dir': './classeB',
        'output_dir': f'./outputs/fast_cyclegan_{timestamp}',

        # Architecture l√©g√®re
        'ngf': 32,
        'ndf': 32,
        'n_blocks': 6,

        # Entra√Ænement rapide
        'n_epochs': 50,
        'batch_size': 8,
        'lr': 0.0002,

        # Optimisations
        'use_amp': True,
        'img_size': 128,
        'num_workers': 4,

        # Poids
        'lambda_cycle': 10.0,

        # Sauvegarde
        'save_interval': 10,
    }


# ============================================================================
# MAIN
# ============================================================================

def main():
    """Script principal optimis√©"""
    print("="*70)
    print("   CYCLEGAN RAPIDE - Navigation C√©leste pour Drone")
    print("="*70 + "\n")

    config = get_fast_config()

    # V√©rifications
    if not os.path.exists(config['classA_dir']):
        print(f"‚ùå Dossier '{config['classA_dir']}' introuvable!")
        return

    if not os.path.exists(config['classB_dir']):
        print(f"‚ùå Dossier '{config['classB_dir']}' introuvable!")
        return

    # Dataset
    try:
        dataset = FastCelestialDataset(
            config['classA_dir'],
            config['classB_dir'],
            img_size=config['img_size']
        )
    except ValueError as e:
        print(f"‚ùå {e}")
        return

    # DataLoader optimis√©
    dataloader = DataLoader(
        dataset,
        batch_size=config['batch_size'],
        shuffle=True,
        num_workers=config['num_workers'],
        pin_memory=True,
        drop_last=True,
        persistent_workers=True
    )

    print(f"‚úÖ {len(dataloader)} batches charg√©s\n")

    # Entra√Ænement
    trainer = FastCycleGANTrainer(config)
    trainer.train(dataloader)


if __name__ == "__main__":
    main()

Overwriting script.py


In [20]:
!python script.py train


   CYCLEGAN RAPIDE - Navigation C√©leste pour Drone

üìÇ Dataset: 1000 Real | 1000 Stellarium
‚úÖ 125 batches charg√©s

üöÄ Device: cuda
   GPU: Tesla T4
üìä G√©n√©rateur: 0.37M params
üìä Discriminateur: 0.17M params
  self.scaler = GradScaler() if config['use_amp'] else None

üöÄ ENTRA√éNEMENT CYCLEGAN RAPIDE
   Epochs: 50
   Batch: 8
   Mixed Precision: True

  with autocast(enabled=self.config['use_amp']):
  with autocast(enabled=self.config['use_amp']):
Epoch 1/50: 100% 125/125 [00:22<00:00,  5.63it/s, G=0.747, D=0.475, Cyc=0.127]

üìä Epoch 1: G=2.3827 | D=0.4302 | Cycle=1.6073
Epoch 2/50: 100% 125/125 [00:15<00:00,  7.99it/s, G=0.583, D=0.581, Cyc=0.141]

üìä Epoch 2: G=0.9757 | D=0.3545 | Cycle=0.1171
Epoch 3/50: 100% 125/125 [00:15<00:00,  7.93it/s, G=0.752, D=0.404, Cyc=0.066]

üìä Epoch 3: G=0.8814 | D=0.3740 | Cycle=0.0914
Epoch 4/50: 100% 125/125 [00:16<00:00,  7.64it/s, G=1.105, D=0.253, Cyc=0.078]

üìä Epoch 4: G=0.9109 | D=0.3658 | Cycle=0.0818
Epoch 5/50: 100%

In [22]:
%%writefile improved_cyclegan.py
"""
CycleGAN AM√âLIOR√â pour Navigation C√©leste
Optimisations : √âtoiles brillantes + Meilleure qualit√© visuelle + Visualisation
"""

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.optim import AdamW
from torch.cuda.amp import autocast, GradScaler
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from torchvision.utils import save_image, make_grid
import os
from pathlib import Path
from tqdm import tqdm
import numpy as np
from PIL import Image
import matplotlib.pyplot as plt
from datetime import datetime
import json

# ============================================================================
# ARCHITECTURE AM√âLIOR√âE avec Attention sur les √âtoiles
# ============================================================================

class StarAttentionModule(nn.Module):
    """
    Attention module sp√©cialis√© pour les √©toiles
    Renforce les points lumineux (√©toiles) dans l'image
    """
    def __init__(self, channels):
        super().__init__()

        # D√©tection des points lumineux (√©toiles)
        self.star_detector = nn.Sequential(
            nn.Conv2d(channels, channels // 4, 1),
            nn.ReLU(inplace=False),
            nn.Conv2d(channels // 4, 1, 1),
            nn.Sigmoid()
        )

        # Amplification s√©lective
        self.amplifier = nn.Sequential(
            nn.Conv2d(channels, channels, 1),
            nn.InstanceNorm2d(channels),
            nn.ReLU(inplace=False)
        )

    def forward(self, x):
        # D√©tecter les √©toiles (points lumineux)
        star_mask = self.star_detector(x)

        # Amplifier les features aux positions des √©toiles
        amplified = self.amplifier(x)

        # Combiner : zone √©toile = amplifi√©, reste = original
        return x + star_mask * amplified * 0.5


class ChannelAttention(nn.Module):
    """Attention sur les canaux"""
    def __init__(self, channels, reduction=8):
        super().__init__()
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.max_pool = nn.AdaptiveMaxPool2d(1)
        self.fc = nn.Sequential(
            nn.Conv2d(channels, channels // reduction, 1, bias=False),
            nn.ReLU(inplace=False),
            nn.Conv2d(channels // reduction, channels, 1, bias=False)
        )
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        avg_out = self.fc(self.avg_pool(x))
        max_out = self.fc(self.max_pool(x))
        return x * self.sigmoid(avg_out + max_out)


class SpatialAttention(nn.Module):
    """Attention spatiale pour localiser les √©toiles"""
    def __init__(self, kernel_size=7):
        super().__init__()
        self.conv = nn.Conv2d(2, 1, kernel_size, padding=kernel_size//2, bias=False)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        avg_out = torch.mean(x, dim=1, keepdim=True)
        max_out, _ = torch.max(x, dim=1, keepdim=True)
        out = torch.cat([avg_out, max_out], dim=1)
        return x * self.sigmoid(self.conv(out))


class EnhancedResBlock(nn.Module):
    """
    Bloc r√©siduel am√©lior√© avec attention multi-niveaux
    Sp√©cialement con√ßu pour pr√©server et am√©liorer les √©toiles
    """
    def __init__(self, channels):
        super().__init__()

        self.conv1 = nn.Sequential(
            nn.ReflectionPad2d(1),
            nn.Conv2d(channels, channels, 3, bias=False),
            nn.InstanceNorm2d(channels),
            nn.ReLU(inplace=False)
        )

        self.conv2 = nn.Sequential(
            nn.ReflectionPad2d(1),
            nn.Conv2d(channels, channels, 3, bias=False),
            nn.InstanceNorm2d(channels)
        )

        # Triple attention
        self.star_attention = StarAttentionModule(channels)
        self.channel_attention = ChannelAttention(channels)
        self.spatial_attention = SpatialAttention()

    def forward(self, x):
        residual = x

        # Convolutions
        out = self.conv1(x)
        out = self.conv2(out)

        # Attention sur les √©toiles
        out = self.star_attention(out)

        # Attention canal et spatiale
        out = self.channel_attention(out)
        out = self.spatial_attention(out)

        return residual + out * 0.2  # Scaled residual connection


class ImprovedGenerator(nn.Module):
    """
    G√©n√©rateur am√©lior√© pour conversion R√©el‚ÜíStellarium
    Focus : Rendre les √©toiles brillantes et claires
    """
    def __init__(self, ngf=64, n_blocks=9):
        super().__init__()

        # === ENCODER ===
        self.encoder_init = nn.Sequential(
            nn.ReflectionPad2d(3),
            nn.Conv2d(3, ngf, 7, bias=False),
            nn.InstanceNorm2d(ngf),
            nn.ReLU(inplace=False)
        )

        # Downsampling
        self.down1 = nn.Sequential(
            nn.Conv2d(ngf, ngf * 2, 3, stride=2, padding=1, bias=False),
            nn.InstanceNorm2d(ngf * 2),
            nn.ReLU(inplace=False)
        )

        self.down2 = nn.Sequential(
            nn.Conv2d(ngf * 2, ngf * 4, 3, stride=2, padding=1, bias=False),
            nn.InstanceNorm2d(ngf * 4),
            nn.ReLU(inplace=False)
        )

        # === TRANSFORMATION avec attention √©toiles ===
        transform_blocks = []
        for _ in range(n_blocks):
            transform_blocks.append(EnhancedResBlock(ngf * 4))
        self.transform = nn.Sequential(*transform_blocks)

        # === DECODER ===
        self.up1 = nn.Sequential(
            nn.ConvTranspose2d(ngf * 4, ngf * 2, 3, stride=2, padding=1,
                              output_padding=1, bias=False),
            nn.InstanceNorm2d(ngf * 2),
            nn.ReLU(inplace=False)
        )

        self.up2 = nn.Sequential(
            nn.ConvTranspose2d(ngf * 2, ngf, 3, stride=2, padding=1,
                              output_padding=1, bias=False),
            nn.InstanceNorm2d(ngf),
            nn.ReLU(inplace=False)
        )

        # Couche finale
        self.output = nn.Sequential(
            nn.ReflectionPad2d(3),
            nn.Conv2d(ngf, 3, 7),
            nn.Tanh()
        )

    def forward(self, x):
        # Encoder
        x = self.encoder_init(x)
        x = self.down1(x)
        x = self.down2(x)

        # Transformation avec attention
        x = self.transform(x)

        # Decoder
        x = self.up1(x)
        x = self.up2(x)

        return self.output(x)


class ImprovedDiscriminator(nn.Module):
    """Discriminateur am√©lior√© (PatchGAN)"""
    def __init__(self, ndf=64):
        super().__init__()

        self.model = nn.Sequential(
            # 128x128 ‚Üí 64x64
            nn.Conv2d(3, ndf, 4, stride=2, padding=1),
            nn.LeakyReLU(0.2, inplace=False),

            # 64x64 ‚Üí 32x32
            nn.Conv2d(ndf, ndf * 2, 4, stride=2, padding=1, bias=False),
            nn.InstanceNorm2d(ndf * 2),
            nn.LeakyReLU(0.2, inplace=False),

            # 32x32 ‚Üí 16x16
            nn.Conv2d(ndf * 2, ndf * 4, 4, stride=2, padding=1, bias=False),
            nn.InstanceNorm2d(ndf * 4),
            nn.LeakyReLU(0.2, inplace=False),

            # 16x16 ‚Üí 15x15
            nn.Conv2d(ndf * 4, ndf * 8, 4, stride=1, padding=1, bias=False),
            nn.InstanceNorm2d(ndf * 8),
            nn.LeakyReLU(0.2, inplace=False),

            # Output
            nn.Conv2d(ndf * 8, 1, 4, stride=1, padding=1)
        )

    def forward(self, x):
        return self.model(x)


# ============================================================================
# PERCEPTUAL LOSS pour am√©liorer la qualit√© visuelle
# ============================================================================

class PerceptualLoss(nn.Module):
    """
    Perte perceptuelle bas√©e sur VGG
    Am√©liore la qualit√© visuelle des √©toiles
    """
    def __init__(self):
        super().__init__()
        # Utiliser les premi√®res couches de VGG pour extraire features
        try:
            from torchvision.models import vgg19, VGG19_Weights
            vgg = vgg19(weights=VGG19_Weights.DEFAULT).features
        except:
            from torchvision.models import vgg19
            vgg = vgg19(pretrained=True).features

        # Extraire jusqu'√† relu3_4 (indices 0-17)
        self.feature_extractor = nn.Sequential(*list(vgg.children())[:18]).eval()

        # Geler les poids
        for param in self.feature_extractor.parameters():
            param.requires_grad = False

        self.criterion = nn.L1Loss()

    def forward(self, fake, real):
        # Normalisation ImageNet
        mean = torch.tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1).to(fake.device)
        std = torch.tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1).to(fake.device)

        fake_normalized = (fake * 0.5 + 0.5 - mean) / std
        real_normalized = (real * 0.5 + 0.5 - mean) / std

        # Extraire features
        fake_features = self.feature_extractor(fake_normalized)
        real_features = self.feature_extractor(real_normalized)

        return self.criterion(fake_features, real_features)


class StarBrightnessLoss(nn.Module):
    """
    Perte personnalis√©e pour encourager les √©toiles brillantes
    P√©nalise si les √©toiles sont trop sombres
    """
    def __init__(self):
        super().__init__()

    def forward(self, fake_stellarium, real_stellarium):
        # Calculer la luminosit√© moyenne des pixels brillants
        # (on consid√®re les 10% de pixels les plus lumineux comme des √©toiles)

        fake_flat = fake_stellarium.view(fake_stellarium.size(0), -1)
        real_flat = real_stellarium.view(real_stellarium.size(0), -1)

        # Top 10% pixels
        k = int(fake_flat.size(1) * 0.1)

        fake_bright = torch.topk(fake_flat, k, dim=1)[0].mean()
        real_bright = torch.topk(real_flat, k, dim=1)[0].mean()

        # Encourager fake √† avoir des √©toiles aussi brillantes que real
        return F.l1_loss(fake_bright, real_bright)


# ============================================================================
# DATASET
# ============================================================================

class CelestialDataset(Dataset):
    """Dataset avec augmentation pour navigation c√©leste"""
    def __init__(self, root_A, root_B, img_size=128):
        self.img_size = img_size

        # Charger chemins
        extensions = ['*.jpg', '*.jpeg', '*.png', '*.JPG', '*.JPEG', '*.PNG']
        self.files_A = []
        self.files_B = []

        for ext in extensions:
            self.files_A.extend(list(Path(root_A).glob(ext)))
            self.files_B.extend(list(Path(root_B).glob(ext)))

        self.files_A = sorted(self.files_A)
        self.files_B = sorted(self.files_B)

        print(f"üìÇ Dataset: {len(self.files_A)} R√©el | {len(self.files_B)} Stellarium")

        if not self.files_A or not self.files_B:
            raise ValueError("‚ùå Aucune image trouv√©e!")

        # Transforms avec augmentation
        self.transform = transforms.Compose([
            transforms.Resize((img_size, img_size), transforms.InterpolationMode.BILINEAR),
            transforms.RandomHorizontalFlip(p=0.5),
            # Augmentation sp√©cifique pour le ciel nocturne
            transforms.ColorJitter(brightness=0.15, contrast=0.15),
            transforms.ToTensor(),
            transforms.Normalize([0.5]*3, [0.5]*3)
        ])

    def __len__(self):
        return max(len(self.files_A), len(self.files_B))

    def __getitem__(self, idx):
        file_A = self.files_A[idx % len(self.files_A)]
        file_B = self.files_B[idx % len(self.files_B)]

        img_A = Image.open(file_A).convert('RGB')
        img_B = Image.open(file_B).convert('RGB')

        img_A = self.transform(img_A)
        img_B = self.transform(img_B)

        return img_A, img_B


# ============================================================================
# TRAINER AM√âLIOR√â
# ============================================================================

class ImprovedCycleGANTrainer:
    """Entra√Ænement am√©lior√© avec pertes suppl√©mentaires"""
    def __init__(self, config):
        self.config = config
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

        print(f"üöÄ Device: {self.device}")
        if torch.cuda.is_available():
            print(f"   GPU: {torch.cuda.get_device_name()}")

        os.makedirs(config['output_dir'], exist_ok=True)
        os.makedirs(f"{config['output_dir']}/samples", exist_ok=True)
        os.makedirs(f"{config['output_dir']}/checkpoints", exist_ok=True)
        os.makedirs(f"{config['output_dir']}/comparisons", exist_ok=True)

        self.history = {
            'G': [], 'D': [], 'cycle': [], 'perceptual': [],
            'star_brightness': [], 'total': []
        }

        self._build_models()
        self._setup_optimizers()

        # Scaler pour mixed precision
        self.scaler = GradScaler() if config['use_amp'] else None

        # Crit√®res
        self.criterion_GAN = nn.MSELoss()
        self.criterion_cycle = nn.L1Loss()
        self.criterion_identity = nn.L1Loss()

        # Pertes am√©lior√©es
        if config['use_perceptual']:
            self.criterion_perceptual = PerceptualLoss().to(self.device)

        if config['use_star_loss']:
            self.criterion_star = StarBrightnessLoss().to(self.device)

    def _build_models(self):
        """Construire mod√®les"""
        config = self.config

        self.G_A2B = ImprovedGenerator(config['ngf'], config['n_blocks']).to(self.device)
        self.G_B2A = ImprovedGenerator(config['ngf'], config['n_blocks']).to(self.device)
        self.D_A = ImprovedDiscriminator(config['ndf']).to(self.device)
        self.D_B = ImprovedDiscriminator(config['ndf']).to(self.device)

        # Initialisation
        for net in [self.G_A2B, self.G_B2A, self.D_A, self.D_B]:
            self._init_weights(net)

        # Stats
        g_params = sum(p.numel() for p in self.G_A2B.parameters()) / 1e6
        d_params = sum(p.numel() for p in self.D_A.parameters()) / 1e6
        print(f"üìä G√©n√©rateur: {g_params:.2f}M params")
        print(f"üìä Discriminateur: {d_params:.2f}M params")

    def _init_weights(self, net, init_gain=0.02):
        """Initialisation des poids"""
        def init_func(m):
            classname = m.__class__.__name__
            if hasattr(m, 'weight') and 'Conv' in classname:
                nn.init.normal_(m.weight.data, 0.0, init_gain)
                if hasattr(m, 'bias') and m.bias is not None:
                    nn.init.constant_(m.bias.data, 0.0)
            elif 'Norm' in classname:
                if hasattr(m, 'weight') and m.weight is not None:
                    nn.init.normal_(m.weight.data, 1.0, init_gain)
                if hasattr(m, 'bias') and m.bias is not None:
                    nn.init.constant_(m.bias.data, 0.0)

        net.apply(init_func)

    def _setup_optimizers(self):
        """Optimiseurs"""
        config = self.config

        self.opt_G = AdamW(
            list(self.G_A2B.parameters()) + list(self.G_B2A.parameters()),
            lr=config['lr'],
            betas=(0.5, 0.999),
            weight_decay=1e-4
        )

        self.opt_D = AdamW(
            list(self.D_A.parameters()) + list(self.D_B.parameters()),
            lr=config['lr'],
            betas=(0.5, 0.999),
            weight_decay=1e-4
        )

        # Schedulers
        self.sched_G = torch.optim.lr_scheduler.CosineAnnealingLR(
            self.opt_G, T_max=config['n_epochs']
        )
        self.sched_D = torch.optim.lr_scheduler.CosineAnnealingLR(
            self.opt_D, T_max=config['n_epochs']
        )

    def train_step(self, real_A, real_B):
        """Un pas d'entra√Ænement avec toutes les pertes"""
        batch_size = real_A.size(0)

        # Labels
        with torch.no_grad():
            sample_out = self.D_A(real_A[:1])
            patch_size = sample_out.shape[-2:]

        valid = torch.ones(batch_size, 1, *patch_size, device=self.device)
        fake_label = torch.zeros(batch_size, 1, *patch_size, device=self.device)

        # ==================== G√âN√âRATEURS ====================
        self.opt_G.zero_grad(set_to_none=True)

        with autocast(enabled=self.config['use_amp']):
            # G√©n√©ration
            fake_B = self.G_A2B(real_A)  # R√©el ‚Üí Stellarium
            fake_A = self.G_B2A(real_B)

            # GAN loss
            pred_fake_B = self.D_B(fake_B)
            loss_GAN_A2B = self.criterion_GAN(pred_fake_B, valid)

            pred_fake_A = self.D_A(fake_A)
            loss_GAN_B2A = self.criterion_GAN(pred_fake_A, valid)

            loss_GAN = (loss_GAN_A2B + loss_GAN_B2A) * 0.5

            # Cycle consistency
            recovered_A = self.G_B2A(fake_B)
            recovered_B = self.G_A2B(fake_A)
            loss_cycle = (
                self.criterion_cycle(recovered_A, real_A) +
                self.criterion_cycle(recovered_B, real_B)
            ) * self.config['lambda_cycle']

            # Identity loss (optionnel)
            loss_identity = 0
            if self.config['lambda_identity'] > 0:
                id_B = self.G_A2B(real_B)
                id_A = self.G_B2A(real_A)
                loss_identity = (
                    self.criterion_identity(id_B, real_B) +
                    self.criterion_identity(id_A, real_A)
                ) * self.config['lambda_identity']

            # Perceptual loss (am√©liore qualit√© visuelle)
            loss_perceptual = 0
            if self.config['use_perceptual']:
                loss_perceptual = (
                    self.criterion_perceptual(fake_B, real_B) +
                    self.criterion_perceptual(fake_A, real_A)
                ) * self.config['lambda_perceptual']

            # Star brightness loss (√©toiles brillantes!)
            loss_star = 0
            if self.config['use_star_loss']:
                loss_star = self.criterion_star(fake_B, real_B) * self.config['lambda_star']

            # Total G loss
            loss_G = loss_GAN + loss_cycle + loss_identity + loss_perceptual + loss_star

        if self.scaler:
            self.scaler.scale(loss_G).backward()
            self.scaler.step(self.opt_G)
        else:
            loss_G.backward()
            self.opt_G.step()

        # ==================== DISCRIMINATEURS ====================
        self.opt_D.zero_grad(set_to_none=True)

        with autocast(enabled=self.config['use_amp']):
            # D_A
            pred_real_A = self.D_A(real_A)
            loss_D_real_A = self.criterion_GAN(pred_real_A, valid)

            pred_fake_A = self.D_A(fake_A.detach())
            loss_D_fake_A = self.criterion_GAN(pred_fake_A, fake_label)

            # D_B
            pred_real_B = self.D_B(real_B)
            loss_D_real_B = self.criterion_GAN(pred_real_B, valid)

            pred_fake_B = self.D_B(fake_B.detach())
            loss_D_fake_B = self.criterion_GAN(pred_fake_B, fake_label)

            loss_D = (loss_D_real_A + loss_D_fake_A +
                     loss_D_real_B + loss_D_fake_B) * 0.5

        if self.scaler:
            self.scaler.scale(loss_D).backward()
            self.scaler.step(self.opt_D)
            self.scaler.update()
        else:
            loss_D.backward()
            self.opt_D.step()

        return {
            'G': loss_G.item(),
            'D': loss_D.item(),
            'cycle': loss_cycle.item(),
            'perceptual': loss_perceptual.item() if loss_perceptual != 0 else 0,
            'star': loss_star.item() if loss_star != 0 else 0
        }

    def train_epoch(self, dataloader, epoch):
        """Entra√Æner une √©poque"""
        self.G_A2B.train()
        self.G_B2A.train()
        self.D_A.train()
        self.D_B.train()

        losses = {'G': [], 'D': [], 'cycle': [], 'perceptual': [], 'star': []}

        pbar = tqdm(dataloader, desc=f"Epoch {epoch}/{self.config['n_epochs']}")

        for i, (real_A, real_B) in enumerate(pbar):
            real_A = real_A.to(self.device, non_blocking=True)
            real_B = real_B.to(self.device, non_blocking=True)

            loss_dict = self.train_step(real_A, real_B)

            for k, v in loss_dict.items():
                losses[k].append(v)

            pbar.set_postfix({
                'G': f"{loss_dict['G']:.3f}",
                'D': f"{loss_dict['D']:.3f}",
                'Cyc': f"{loss_dict['cycle']:.3f}",
                'Star': f"{loss_dict['star']:.3f}"
            })

            # Sauvegarder √©chantillons
            if i == 0:
                self.save_comparison(epoch, real_A, real_B)

        return {k: np.mean(v) for k, v in losses.items()}

    def save_comparison(self, epoch, real_A, real_B):
        """Cr√©er image de comparaison d√©taill√©e"""
        self.G_A2B.eval()
        self.G_B2A.eval()

        with torch.no_grad():
            fake_B = self.G_A2B(real_A[:4])
            fake_A = self.G_B2A(real_B[:4])
            recovered_A = self.G_B2A(fake_B)
            recovered_B = self.G_A2B(fake_A)

        # Grille de comparaison
        comparison = torch.cat([
            real_A[:4],        # Images r√©elles
            fake_B,            # Converti en Stellarium
            real_B[:4],        # Stellarium r√©el
            fake_A,            # Converti en R√©el
            recovered_A,       # Cycle A‚ÜíB‚ÜíA
            recovered_B        # Cycle B‚ÜíA‚ÜíB
        ], dim=0)

        save_image(
            comparison,
            f"{self.config['output_dir']}/comparisons/epoch_{epoch:03d}_full.png",
            nrow=4,
            normalize=True,
            value_range=(-1, 1)
        )

        # Comparaison c√¥te √† c√¥te (R√©el vs Stellarium)
        side_by_side = torch.cat([real_A[:4], fake_B], dim=3)
        save_image(
            side_by_side,
            f"{self.config['output_dir']}/comparisons/epoch_{epoch:03d}_sidebyside.png",
            nrow=1,
            normalize=True,
            value_range=(-1, 1)
        )

        self.G_A2B.train()
        self.G_B2A.train()

    def plot_training_curves(self):
        """Graphiques d'entra√Ænement"""
        fig, axes = plt.subplots(2, 3, figsize=(15, 8))
        fig.suptitle('Entra√Ænement CycleGAN - Navigation C√©leste', fontsize=14)

        # G loss
        axes[0, 0].plot(self.history['G'], color='blue', alpha=0.7)
        axes[0, 0].set_title('Generator Loss')
        axes[0, 0].set_xlabel('Epoch')
        axes[0, 0].grid(True, alpha=0.3)

        # D loss
        axes[0, 1].plot(self.history['D'], color='red', alpha=0.7)
        axes[0, 1].set_title('Discriminator Loss')
        axes[0, 1].set_xlabel('Epoch')
        axes[0, 1].grid(True, alpha=0.3)

        # Cycle loss
        axes[0, 2].plot(self.history['cycle'], color='green', alpha=0.7)
        axes[0, 2].set_title('Cycle Consistency Loss')
        axes[0, 2].set_xlabel('Epoch')
        axes[0, 2].grid(True, alpha=0.3)

        # Perceptual loss
        if len(self.history['perceptual']) > 0 and max(self.history['perceptual']) > 0:
            axes[1, 0].plot(self.history['perceptual'], color='purple', alpha=0.7)
            axes[1, 0].set_title('Perceptual Loss')
            axes[1, 0].set_xlabel('Epoch')
            axes[1, 0].grid(True, alpha=0.3)
        else:
            axes[1, 0].text(0.5, 0.5, 'Perceptual Loss\nDisabled',
                          ha='center', va='center', transform=axes[1, 0].transAxes)

        # Star brightness loss
        if len(self.history['star_brightness']) > 0 and max(self.history['star_brightness']) > 0:
            axes[1, 1].plot(self.history['star_brightness'], color='orange', alpha=0.7)
            axes[1, 1].set_title('Star Brightness Loss')
            axes[1, 1].set_xlabel('Epoch')
            axes[1, 1].grid(True, alpha=0.3)
        else:
            axes[1, 1].text(0.5, 0.5, 'Star Brightness\nLoss Disabled',
                          ha='center', va='center', transform=axes[1, 1].transAxes)

        # Total loss
        axes[1, 2].plot(self.history['total'], color='black', alpha=0.7)
        axes[1, 2].set_title('Total Loss')
        axes[1, 2].set_xlabel('Epoch')
        axes[1, 2].grid(True, alpha=0.3)

        plt.tight_layout()
        plt.savefig(f"{self.config['output_dir']}/training_curves.png", dpi=150, bbox_inches='tight')
        plt.close()
        print("   üìà Graphiques sauvegard√©s")

    def save_checkpoint(self, epoch):
        """Sauvegarder checkpoint"""
        checkpoint = {
            'epoch': epoch,
            'G_A2B': self.G_A2B.state_dict(),
            'G_B2A': self.G_B2A.state_dict(),
            'D_A': self.D_A.state_dict(),
            'D_B': self.D_B.state_dict(),
            'opt_G': self.opt_G.state_dict(),
            'opt_D': self.opt_D.state_dict(),
            'config': self.config,
            'history': self.history
        }

        path = f"{self.config['output_dir']}/checkpoints/checkpoint_{epoch:03d}.pth"
        torch.save(checkpoint, path)

        # Garder seulement les 3 derniers + best
        checkpoints = sorted(Path(f"{self.config['output_dir']}/checkpoints").glob("checkpoint_*.pth"))
        if len(checkpoints) > 3:
            for ckpt in checkpoints[:-3]:
                if 'best' not in str(ckpt):
                    ckpt.unlink()

    def train(self, dataloader):
        """Boucle principale d'entra√Ænement"""
        print("\n" + "="*70)
        print("üöÄ ENTRA√éNEMENT CYCLEGAN AM√âLIOR√â")
        print(f"   Epochs: {self.config['n_epochs']}")
        print(f"   Batch: {self.config['batch_size']}")
        print(f"   Perceptual Loss: {self.config['use_perceptual']}")
        print(f"   Star Brightness Loss: {self.config['use_star_loss']}")
        print("="*70 + "\n")

        best_cycle_loss = float('inf')

        for epoch in range(1, self.config['n_epochs'] + 1):
            avg_losses = self.train_epoch(dataloader, epoch)

            # Historique
            self.history['G'].append(avg_losses['G'])
            self.history['D'].append(avg_losses['D'])
            self.history['cycle'].append(avg_losses['cycle'])
            self.history['perceptual'].append(avg_losses['perceptual'])
            self.history['star_brightness'].append(avg_losses['star'])
            self.history['total'].append(avg_losses['G'] + avg_losses['D'])

            print(f"\nüìä Epoch {epoch}: G={avg_losses['G']:.4f} | "
                  f"D={avg_losses['D']:.4f} | Cycle={avg_losses['cycle']:.4f} | "
                  f"Star={avg_losses['star']:.4f}")

            # Update schedulers
            self.sched_G.step()
            self.sched_D.step()

            # Sauvegarder
            if epoch % self.config['save_interval'] == 0:
                self.save_checkpoint(epoch)

            # Sauvegarder meilleur
            if avg_losses['cycle'] < best_cycle_loss:
                best_cycle_loss = avg_losses['cycle']
                torch.save(
                    self.G_A2B.state_dict(),
                    f"{self.config['output_dir']}/checkpoints/best_G_A2B.pth"
                )
                print(f"   ‚≠ê Nouveau meilleur mod√®le! Cycle loss: {best_cycle_loss:.4f}")

            # Graphiques
            if epoch % 5 == 0:
                self.plot_training_curves()

        # Final
        torch.save(
            self.G_A2B.state_dict(),
            f"{self.config['output_dir']}/G_A2B_final.pth"
        )
        torch.save(
            self.G_B2A.state_dict(),
            f"{self.config['output_dir']}/G_B2A_final.pth"
        )

        # Sauvegarder historique
        with open(f"{self.config['output_dir']}/history.json", 'w') as f:
            json.dump(self.history, f, indent=2)

        # Graphiques finaux
        self.plot_training_curves()

        print("\n‚úÖ Entra√Ænement termin√©!")
        print(f"üìÅ Mod√®les sauvegard√©s dans: {self.config['output_dir']}")


# ============================================================================
# VISUALISATION ET ANALYSE
# ============================================================================

class ResultAnalyzer:
    """Analyser et visualiser les r√©sultats"""
    def __init__(self, model_path, device='cuda'):
        self.device = torch.device(device if torch.cuda.is_available() else 'cpu')

        # Charger mod√®le
        self.G_A2B = ImprovedGenerator(ngf=64, n_blocks=9).to(self.device)
        self.G_A2B.load_state_dict(torch.load(model_path, map_location=self.device))
        self.G_A2B.eval()

        self.transform = transforms.Compose([
            transforms.Resize((128, 128)),
            transforms.ToTensor(),
            transforms.Normalize([0.5]*3, [0.5]*3)
        ])

        print(f"‚úÖ Mod√®le charg√© sur {self.device}")

    def create_detailed_comparison(self, image_path, output_path):
        """
        Cr√©er une comparaison d√©taill√©e avec:
        - Image originale (r√©elle)
        - Image convertie (Stellarium)
        - Zoom sur les √©toiles
        - Histogrammes de luminosit√©
        """
        # Charger image
        img = Image.open(image_path).convert('RGB')
        img_tensor = self.transform(img).unsqueeze(0).to(self.device)

        # G√©n√©rer version Stellarium
        with torch.no_grad():
            fake_stellarium = self.G_A2B(img_tensor)

        # Convertir en numpy pour matplotlib
        real_np = img_tensor.squeeze().cpu().numpy().transpose(1, 2, 0)
        real_np = (real_np * 0.5 + 0.5).clip(0, 1)

        fake_np = fake_stellarium.squeeze().cpu().numpy().transpose(1, 2, 0)
        fake_np = (fake_np * 0.5 + 0.5).clip(0, 1)

        # Cr√©er figure
        fig = plt.figure(figsize=(16, 10))
        gs = fig.add_gridspec(3, 3, hspace=0.3, wspace=0.3)

        # Image r√©elle compl√®te
        ax1 = fig.add_subplot(gs[0, 0])
        ax1.imshow(real_np)
        ax1.set_title('Image R√©elle', fontsize=14, fontweight='bold')
        ax1.axis('off')

        # Image Stellarium g√©n√©r√©e
        ax2 = fig.add_subplot(gs[0, 1])
        ax2.imshow(fake_np)
        ax2.set_title('G√©n√©r√© (Style Stellarium)', fontsize=14, fontweight='bold')
        ax2.axis('off')

        # Diff√©rence
        ax3 = fig.add_subplot(gs[0, 2])
        diff = np.abs(fake_np - real_np)
        ax3.imshow(diff)
        ax3.set_title('Diff√©rence (Changements)', fontsize=14, fontweight='bold')
        ax3.axis('off')

        # Zoom sur une r√©gion (d√©tails √©toiles)
        h, w = real_np.shape[:2]
        crop_h, crop_w = h//3, w//3
        start_h, start_w = h//3, w//3

        real_zoom = real_np[start_h:start_h+crop_h, start_w:start_w+crop_w]
        fake_zoom = fake_np[start_h:start_h+crop_h, start_w:start_w+crop_w]

        ax4 = fig.add_subplot(gs[1, 0])
        ax4.imshow(real_zoom)
        ax4.set_title('Zoom - R√©el', fontsize=12)
        ax4.axis('off')

        ax5 = fig.add_subplot(gs[1, 1])
        ax5.imshow(fake_zoom)
        ax5.set_title('Zoom - Stellarium', fontsize=12)
        ax5.axis('off')

        # Histogramme luminosit√©
        ax6 = fig.add_subplot(gs[1, 2])
        real_gray = np.mean(real_np, axis=2)
        fake_gray = np.mean(fake_np, axis=2)

        ax6.hist(real_gray.flatten(), bins=50, alpha=0.6, label='R√©el', color='blue')
        ax6.hist(fake_gray.flatten(), bins=50, alpha=0.6, label='Stellarium', color='orange')
        ax6.set_title('Distribution Luminosit√©', fontsize=12)
        ax6.set_xlabel('Intensit√©')
        ax6.set_ylabel('Fr√©quence')
        ax6.legend()
        ax6.grid(True, alpha=0.3)

        # Statistiques sur les √©toiles (pixels brillants)
        ax7 = fig.add_subplot(gs[2, :])

        # Top 10% pixels les plus brillants
        real_bright = np.sort(real_gray.flatten())[-int(len(real_gray.flatten())*0.1):]
        fake_bright = np.sort(fake_gray.flatten())[-int(len(fake_gray.flatten())*0.1):]

        stats_text = f"""
        üìä STATISTIQUES DES √âTOILES (Top 10% pixels brillants)

        R√©el:
        - Luminosit√© moyenne: {real_bright.mean():.3f}
        - Luminosit√© max: {real_gray.max():.3f}
        - √âcart-type: {real_bright.std():.3f}

        Stellarium (G√©n√©r√©):
        - Luminosit√© moyenne: {fake_bright.mean():.3f}
        - Luminosit√© max: {fake_gray.max():.3f}
        - √âcart-type: {fake_bright.std():.3f}

        Am√©lioration:
        - Gain luminosit√©: {(fake_bright.mean() / real_bright.mean() - 1) * 100:+.1f}%
        - Nombre pixels tr√®s brillants (>0.8): {(fake_gray > 0.8).sum()} vs {(real_gray > 0.8).sum()}
        """

        ax7.text(0.1, 0.5, stats_text, fontsize=11, family='monospace',
                verticalalignment='center', bbox=dict(boxstyle='round', facecolor='wheat', alpha=0.5))
        ax7.axis('off')

        plt.suptitle('Analyse D√©taill√©e - Conversion R√©el ‚Üí Stellarium',
                    fontsize=16, fontweight='bold', y=0.98)

        plt.savefig(output_path, dpi=150, bbox_inches='tight')
        plt.close()

        print(f"‚úÖ Analyse d√©taill√©e sauvegard√©e: {output_path}")

    def batch_analyze(self, input_dir, output_dir):
        """Analyser un dossier d'images"""
        os.makedirs(output_dir, exist_ok=True)

        extensions = ['*.jpg', '*.jpeg', '*.png']
        image_files = []
        for ext in extensions:
            image_files.extend(list(Path(input_dir).glob(ext)))

        print(f"üîç Analyse de {len(image_files)} images...")

        for img_path in tqdm(image_files[:10]):  # Limiter √† 10 pour l'exemple
            output_path = Path(output_dir) / f"analysis_{img_path.stem}.png"
            self.create_detailed_comparison(str(img_path), str(output_path))

    def create_grid_comparison(self, image_paths, output_path, max_images=16):
        """Cr√©er grille de comparaisons multiples"""
        n_images = min(len(image_paths), max_images)

        fig, axes = plt.subplots(n_images, 2, figsize=(10, 5*n_images))
        if n_images == 1:
            axes = [axes]

        for idx, img_path in enumerate(image_paths[:n_images]):
            img = Image.open(img_path).convert('RGB')
            img_tensor = self.transform(img).unsqueeze(0).to(self.device)

            with torch.no_grad():
                fake = self.G_A2B(img_tensor)

            real_np = img_tensor.squeeze().cpu().numpy().transpose(1, 2, 0)
            real_np = (real_np * 0.5 + 0.5).clip(0, 1)

            fake_np = fake.squeeze().cpu().numpy().transpose(1, 2, 0)
            fake_np = (fake_np * 0.5 + 0.5).clip(0, 1)

            axes[idx][0].imshow(real_np)
            axes[idx][0].set_title(f'R√©el - {Path(img_path).name}', fontsize=10)
            axes[idx][0].axis('off')

            axes[idx][1].imshow(fake_np)
            axes[idx][1].set_title(f'Stellarium G√©n√©r√©', fontsize=10)
            axes[idx][1].axis('off')

        plt.tight_layout()
        plt.savefig(output_path, dpi=150, bbox_inches='tight')
        plt.close()

        print(f"‚úÖ Grille de comparaison sauvegard√©e: {output_path}")


# ============================================================================
# CONFIGURATION
# ============================================================================

def get_improved_config():
    """Configuration am√©lior√©e"""
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")

    return {
        # Donn√©es
        'classA_dir': './classeA',
        'classB_dir': './classeB',
        'output_dir': f'./outputs/improved_cyclegan_{timestamp}',

        # Architecture
        'ngf': 64,  # Augment√© pour meilleure qualit√©
        'ndf': 64,
        'n_blocks': 9,  # Plus de ResBlocks pour transformer

        # Entra√Ænement
        'n_epochs': 100,
        'batch_size': 4,  # Augment√© si GPU le permet
        'lr': 0.0002,

        # Optimisations
        'use_amp': True,
        'img_size': 128,
        'num_workers': 4,

        # Poids des pertes
        'lambda_cycle': 10.0,
        'lambda_identity': 0.5,
        'lambda_perceptual': 1.0,  # Nouvelle: qualit√© visuelle
        'lambda_star': 2.0,  # Nouvelle: √©toiles brillantes

        # Nouvelles pertes
        'use_perceptual': True,  # ‚úÖ Am√©liore qualit√©
        'use_star_loss': True,   # ‚úÖ √âtoiles brillantes

        # Sauvegarde
        'save_interval': 10,
    }


# ============================================================================
# MAIN
# ============================================================================

def main():
    """Script principal am√©lior√©"""
    print("="*70)
    print("   CYCLEGAN AM√âLIOR√â - Navigation C√©leste")
    print("   Conversion: R√©el ‚Üí Stellarium (√âtoiles Brillantes)")
    print("="*70 + "\n")

    config = get_improved_config()

    # V√©rifications
    if not os.path.exists(config['classA_dir']):
        print(f"‚ùå Dossier '{config['classA_dir']}' introuvable!")
        return

    if not os.path.exists(config['classB_dir']):
        print(f"‚ùå Dossier '{config['classB_dir']}' introuvable!")
        return

    # Dataset
    try:
        dataset = CelestialDataset(
            config['classA_dir'],
            config['classB_dir'],
            img_size=config['img_size']
        )
    except ValueError as e:
        print(f"‚ùå {e}")
        return

    # DataLoader
    dataloader = DataLoader(
        dataset,
        batch_size=config['batch_size'],
        shuffle=True,
        num_workers=config['num_workers'],
        pin_memory=True,
        drop_last=True,
        persistent_workers=True
    )

    print(f"‚úÖ {len(dataloader)} batches charg√©s\n")

    # Entra√Ænement
    trainer = ImprovedCycleGANTrainer(config)
    trainer.train(dataloader)


def analyze_results(model_path, test_dir, output_dir):
    """Analyser les r√©sultats apr√®s entra√Ænement"""
    print("üîç Analyse des r√©sultats...\n")

    analyzer = ResultAnalyzer(model_path)

    # Cr√©er analyses d√©taill√©es
    analyzer.batch_analyze(test_dir, f"{output_dir}/detailed_analysis")

    # Cr√©er grille de comparaison
    test_images = list(Path(test_dir).glob("*.jpg"))[:16]
    if test_images:
        analyzer.create_grid_comparison(
            test_images,
            f"{output_dir}/comparison_grid.png"
        )


if __name__ == "__main__":
    import sys

    if len(sys.argv) > 1:
        if sys.argv[1] == 'train':
            main()
        elif sys.argv[1] == 'analyze':
            if len(sys.argv) < 4:
                print("Usage: python script.py analyze <model_path> <test_dir>")
            else:
                analyze_results(sys.argv[2], sys.argv[3], './analysis_output')
    else:
        main()


"""
============================================================================
AM√âLIORATIONS PRINCIPALES PAR RAPPORT AU MOD√àLE RAPIDE
============================================================================

1. ‚ú® ARCHITECTURE AM√âLIOR√âE
   - StarAttentionModule: Focus sp√©cial sur les √©toiles
   - Triple attention (Star + Channel + Spatial)
   - ResBlocks plus profonds avec attention
   ‚ûú √âtoiles plus brillantes et mieux d√©finies

2. üé® PERCEPTUAL LOSS
   - Utilise VGG19 pour comparer features
   - Am√©liore la qualit√© visuelle globale
   - Pr√©serve mieux les d√©tails
   ‚ûú Images plus r√©alistes et naturelles

3. ‚≠ê STAR BRIGHTNESS LOSS
   - Perte personnalis√©e pour √©toiles brillantes
   - Force le g√©n√©rateur √† rendre les √©toiles lumineuses
   - Compare top 10% pixels (√©toiles)
   ‚ûú √âtoiles √©clatantes comme dans Stellarium!

4. üìä VISUALISATION AVANC√âE
   - Comparaisons d√©taill√©es avec zoom
   - Histogrammes de luminosit√©
   - Statistiques sur les √©toiles
   - Grilles de comparaison multiple
   ‚ûú Analyse compl√®te des r√©sultats

5. üéØ MEILLEURE STABILIT√â
   - Initialisation des poids optimis√©e
   - Schedulers cosine annealing
   - Sauvegarde du meilleur mod√®le
   ‚ûú Convergence plus stable

============================================================================
UTILISATION
============================================================================

# Entra√Ænement am√©lior√©
python script.py train

# Analyse des r√©sultats
python script.py analyze ./outputs/*/checkpoints/best_G_A2B.pth ./classeA

# R√©sultats g√©n√©r√©s:
# - /comparisons/: Comparaisons c√¥te-√†-c√¥te
# - /samples/: √âchantillons par √©poque
# - /training_curves.png: Graphiques d'entra√Ænement
# - /detailed_analysis/: Analyses d√©taill√©es avec stats

============================================================================
DIFF√âRENCES VS MOD√àLE LOURD
============================================================================

CONSERV√â:
‚úì Architecture ResNet avec attention
‚úì PatchGAN discriminateur
‚úì Cycle consistency
‚úì Mixed precision training

AJOUT√â/AM√âLIOR√â:
‚úì Star attention module (nouveau)
‚úì Perceptual loss VGG (qualit√©++)
‚úì Star brightness loss (√©toiles brillantes++)
‚úì Visualisations compl√®tes (analyse++)
‚úì Taille optimis√©e (128px pour vitesse)

ALL√âG√â:
‚úì 64 filtres au lieu de plus
‚úì Images 128x128 (4x plus rapide que 256)
‚úì Pas de pools d'images (simplification)

R√âSULTAT: 10x plus rapide, qualit√© comparable, √©toiles BRILLANTES! ‚≠ê
============================================================================
"""

Writing improved_cyclegan.py


In [23]:
!python improved_cyclegan.py train

   CYCLEGAN AM√âLIOR√â - Navigation C√©leste
   Conversion: R√©el ‚Üí Stellarium (√âtoiles Brillantes)

üìÇ Dataset: 1000 R√©el | 1000 Stellarium
‚úÖ 250 batches charg√©s

üöÄ Device: cuda
   GPU: Tesla T4
üìä G√©n√©rateur: 12.26M params
üìä Discriminateur: 2.76M params
  self.scaler = GradScaler() if config['use_amp'] else None
Downloading: "https://download.pytorch.org/models/vgg19-dcbb9e9d.pth" to /root/.cache/torch/hub/checkpoints/vgg19-dcbb9e9d.pth
100% 548M/548M [00:05<00:00, 100MB/s] 

üöÄ ENTRA√éNEMENT CYCLEGAN AM√âLIOR√â
   Epochs: 100
   Batch: 4
   Perceptual Loss: True
   Star Brightness Loss: True

  with autocast(enabled=self.config['use_amp']):
  with autocast(enabled=self.config['use_amp']):
Epoch 1/100: 100% 250/250 [01:24<00:00,  2.95it/s, G=1.445, D=0.517, Cyc=0.064, Star=0.015]

üìä Epoch 1: G=2.1083 | D=0.6378 | Cycle=0.5207 | Star=0.1304
   ‚≠ê Nouveau meilleur mod√®le! Cycle loss: 0.5207
Epoch 2/100: 100% 250/250 [01:24<00:00,  2.95it/s, G=1.449, D=0.471, C

In [27]:
# M√©thode 3: Commande shell directe
!cd /content && zip -r mon_modele_cyclegan.zip outputs/ -i "*"

# Puis t√©l√©charger
from google.colab import files
files.download('/content/mon_modele_cyclegan.zip')

  adding: outputs/ (stored 0%)
  adding: outputs/fast_cyclegan_20251117_161346/ (stored 0%)
  adding: outputs/fast_cyclegan_20251117_161346/samples/ (stored 0%)
  adding: outputs/fast_cyclegan_20251117_161346/checkpoints/ (stored 0%)
  adding: outputs/fast_cyclegan_20251117_160740/ (stored 0%)
  adding: outputs/fast_cyclegan_20251117_160740/samples/ (stored 0%)
  adding: outputs/fast_cyclegan_20251117_160740/checkpoints/ (stored 0%)
  adding: outputs/improved_cyclegan_20251117_165354/ (stored 0%)
  adding: outputs/improved_cyclegan_20251117_165354/samples/ (stored 0%)
  adding: outputs/improved_cyclegan_20251117_165354/history.json (deflated 61%)
  adding: outputs/improved_cyclegan_20251117_165354/G_A2B_final.pth (deflated 7%)
  adding: outputs/improved_cyclegan_20251117_165354/G_B2A_final.pth (deflated 7%)
  adding: outputs/improved_cyclegan_20251117_165354/training_curves.png (deflated 13%)
  adding: outputs/improved_cyclegan_20251117_165354/comparisons/ (stored 0%)
  adding: outputs

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

In [30]:
import shutil
import os
from pathlib import Path

def copy_to_drive_simple():
    """Copier un fichier sp√©cifique vers Drive"""

    # Chemin du fichier dans Colab
    source_file = "/content/outputs.zip"  # Remplacez par votre fichier

    if not os.path.exists(source_file):
        print(f"‚ùå Fichier non trouv√©: {source_file}")
        print("üìÅ Fichiers disponibles dans /content:")
        !ls -la /content/
        return

    # Chemin de destination dans Drive
    drive_destination = "/content/drive/MyDrive/outouts.zip"

    # Copier le fichier
    shutil.copy2(source_file, drive_destination)

    print(f"‚úÖ Fichier copi√©: {source_file} ‚Üí {drive_destination}")

copy_to_drive_simple()

‚úÖ Fichier copi√©: /content/outputs.zip ‚Üí /content/drive/MyDrive/outouts.zip


In [None]:
# Test 1 - Batch size 1 (actuel)
!python script.py --mode train --batch_size 1 --epochs 5

# Test 2 - Batch size 2
!python script.py --mode train --batch_size 2 --epochs 5

# Comparez les courbes de loss !

üìÇ Dataset charg√©:
   Classe A (Real): 1000 images
   Classe B (Stellarium): 1000 images
üñ•Ô∏è  Device: cuda
üî® Construction des mod√®les...
   G_A2B: 11,447,541 param√®tres
   D_A: 2,763,841 param√®tres

üöÄ D√©marrage entra√Ænement CycleGAN Navigation C√©leste
   Epochs: 5
   Batch size: 1
   CBAM: True
   Device: cuda

Epoch 1/5: 100% 1000/1000 [09:15<00:00,  1.80it/s, G=0.392, D=0.190, Cycle=0.001]

üìä Epoch 1 - Pertes moyennes:
   G: 0.4070 | D: 0.2924
   Cycle: 0.0071 | GAN: 0.3308
   üíæ Checkpoint sauvegard√© (epoch 1)
Epoch 2/5: 100% 1000/1000 [09:16<00:00,  1.80it/s, G=0.253, D=0.274, Cycle=0.003]

üìä Epoch 2 - Pertes moyennes:
   G: 0.4104 | D: 0.2200
   Cycle: 0.0023 | GAN: 0.3853
   üíæ Checkpoint sauvegard√© (epoch 2)
Epoch 3/5: 100% 1000/1000 [09:16<00:00,  1.80it/s, G=0.210, D=0.227, Cycle=0.001]

üìä Epoch 3 - Pertes moyennes:
   G: 0.2720 | D: 0.2469
   Cycle: 0.0020 | GAN: 0.2503
   üíæ Checkpoint sauvegard√© (epoch 3)
Epoch 4/5: 100% 1000/1000 [09:16

In [None]:
# Pour utiliser le mod√®le entra√Æn√©
# Remplacez le chemin par votre checkpoint r√©el
checkpoint_path = "./outputs/cyclegan_YYYYMMDD_HHMMSS/checkpoints/best.pth"

# Inf√©rence sur une image
!python script.py --mode inference \
    --checkpoint {checkpoint_path} \
    --input ./classeA/test_image.jpg \
    --output ./results \
    --direction A2B

In [4]:
# Cr√©er la structure de dossiers
!python script.py --mode setup

In [32]:
import shutil
shutil.copy("/content/improved_modele_cyclegan.zip", "/content/drive/MyDrive/improved_cyclegan.zip")

print("‚úÖ ZIP copi√© vers Drive: /content/drive/MyDrive/mon_modele_cyclegan.zip")

‚úÖ ZIP copi√© vers Drive: /content/drive/MyDrive/mon_modele_cyclegan.zip
