# In [None]:
import os
import math
import h5py
import numpy as np
import scipy.ndimage
import matplotlib.pyplot as plt
import matplotlib.animation as animation
import seaborn as sns
from tqdm import tqdm
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torch.optim.lr_scheduler import CosineAnnealingLR

###############################
# 1) DATA LOADING & RESIZING
###############################
def load_swept_sine_case(file_path, amplitude_key, target_x=41, target_y=45,
                         spatial_x=61, spatial_y=105):
    """Load one swept-sine case from the HDF5 file and resize its vorticity frames.

    Args:
        file_path: path to the HDF5 file (opened read-only with h5py).
        amplitude_key: group name under ``data_structure/swept_sines``.
        target_x, target_y: spatial size of the resized frames.
        spatial_x, spatial_y: native grid size; the first axis of ``wz_grid``
            is reshaped (row-major) to ``(spatial_x, spatial_y)``.

    Returns:
        dict with ``'amplitude'`` (the key), ``'u'`` (float32 forcing signal,
        squeezed) and ``'frames'`` (float32, shape ``(T, 1, target_x, target_y)``).

    Raises:
        ValueError: if ``wz_grid`` is not 2-D with ``spatial_x * spatial_y`` rows.
    """
    # Read everything we need, then release the file before the slow resizing.
    with h5py.File(file_path, 'r') as f:
        dataset = f['data_structure']['swept_sines'][amplitude_key]
        wz_grid = np.array(dataset['wz_grid'])  # expected (spatial_x * spatial_y, T)
        u = np.array(dataset['y'])

    if wz_grid.ndim != 2 or wz_grid.shape[0] != spatial_x * spatial_y:
        raise ValueError(
            f"wz_grid for {amplitude_key!r} has shape {wz_grid.shape}; expected "
            f"({spatial_x * spatial_y}, T) for a {spatial_x}x{spatial_y} grid"
        )
    T = wz_grid.shape[1]
    wz_3d = wz_grid.reshape(spatial_x, spatial_y, T)
    if u.ndim > 1:
        u = u.squeeze()

    zoom_factors = (target_x / spatial_x, target_y / spatial_y)
    # Fill the (T, 1, X, Y) layout directly so the result is contiguous.
    frames = np.empty((T, 1, target_x, target_y), dtype=np.float32)
    for t in range(T):
        frames[t, 0] = scipy.ndimage.zoom(wz_3d[:, :, t], zoom_factors, order=1)

    return {
        'amplitude': amplitude_key,
        'u': u.astype(np.float32),
        'frames': frames,
    }

def load_all_amplitudes(file_path, amplitude_map, amplitude_list):
    """Load each amplitude of ``amplitude_list`` as a 41x45-resized case.

    ``amplitude_map`` translates a numeric amplitude into the HDF5 group key.
    Cases are returned in the order of ``amplitude_list``.
    """
    return [
        load_swept_sine_case(file_path, amplitude_map[amp], target_x=41, target_y=45)
        for amp in amplitude_list
    ]

def split_data(file_path):
    """Load the train / validation / test splits, partitioned by forcing amplitude.

    Validation and test amplitudes lie between the training amplitudes.
    Returns ``(train_list, val_list, test_list)``, each a list of case dicts
    as produced by ``load_all_amplitudes``.
    """
    amplitude_map = dict(zip(
        [0.5, 0.75, 1.0, 1.25, 1.5, 1.75, 2.0, 2.25, 2.5, 2.75, 3.0],
        ['A0p05', 'A0p075', 'A0p10', 'A0p125', 'A0p15', 'A0p175',
         'A0p20', 'A0p225', 'A0p25', 'A0p275', 'A0p30'],
    ))
    split_amplitudes = (
        [0.5, 1.0, 1.5, 2.0, 2.5, 3.0],  # train
        [0.75, 1.75, 2.75],              # validation
        [1.25, 2.25],                    # test
    )
    train_list, val_list, test_list = (
        load_all_amplitudes(file_path, amplitude_map, amps) for amps in split_amplitudes
    )
    return train_list, val_list, test_list

###############################
# 1.1) NORMALISATION DES DONNÉES
###############################
def compute_normalization_stats(data_list):
    """Return ``(frame_mean, frame_std, u_mean, u_std)`` pooled over all cases.

    Each statistic is a single scalar over every pixel of every frame (or every
    sample of ``u``), not a per-pixel map. The script computes them on the
    training split only.
    """
    def pooled(key):
        return np.concatenate([case[key] for case in data_list], axis=0)

    frames = pooled('frames')
    frame_mean, frame_std = frames.mean(), frames.std()
    signal = pooled('u')
    return frame_mean, frame_std, signal.mean(), signal.std()

def normalize_data_list(data_list, frame_mean, frame_std, u_mean, u_std):
    """Standardise ``'frames'`` and ``'u'`` of every case in place; returns None."""
    for case in data_list:
        case.update(
            frames=(case['frames'] - frame_mean) / frame_std,
            u=(case['u'] - u_mean) / u_std,
        )

###############################
# 2) CREATING SEQUENCES FOR ACDM
###############################
def create_acdm_sequences(data_list, past_window=2):
    """Slice every case into (history, target) training samples.

    For each target index ``i`` in ``[past_window, T)`` of each case this emits:
      * the ``past_window`` preceding frames, shape ``(past_window, 1, H, W)``;
      * the matching past ``u`` values, shape ``(past_window, 1)``;
      * the current ``u[i]``, shape ``(1,)``;
      * the target frame ``frames[i]``, shape ``(1, H, W)``.

    Returns the four stacked float32 arrays ``(X_frames, X_u_past, X_u_curr, Y)``.
    """
    windows, u_histories, u_currents, targets = [], [], [], []
    for case in data_list:
        case_frames = case['frames']  # (T, 1, H, W), already normalised
        case_u = case['u']            # (T,), already normalised
        for end in range(past_window, case_frames.shape[0]):
            start = end - past_window
            windows.append(case_frames[start:end])
            u_histories.append(case_u[start:end].reshape(-1, 1))
            u_currents.append(np.array([case_u[end]], dtype=np.float32))
            targets.append(case_frames[end])
    return tuple(
        np.array(collected, dtype=np.float32)
        for collected in (windows, u_histories, u_currents, targets)
    )

class ACylinderDataset(Dataset):
    """Map-style dataset of ``(past frames, past u, current u, target frame)`` samples.

    The four arrays are indexed in lockstep along their first axis.
    ``DataLoader`` collates the returned tuples into batches.
    """

    def __init__(self, X_frames, X_u_past, X_u_curr, Y):
        lengths = {len(X_frames), len(X_u_past), len(X_u_curr), len(Y)}
        if len(lengths) != 1:
            raise ValueError(
                "X_frames, X_u_past, X_u_curr and Y must have the same length, got "
                f"{len(X_frames)}, {len(X_u_past)}, {len(X_u_curr)}, {len(Y)}"
            )
        self.X_frames = X_frames
        self.X_u_past = X_u_past
        self.X_u_curr = X_u_curr
        self.Y = Y

    def __len__(self):
        return len(self.X_frames)

    def __getitem__(self, idx):
        return self.X_frames[idx], self.X_u_past[idx], self.X_u_curr[idx], self.Y[idx]

###############################
# 3) TIME EMBEDDING
###############################
def get_timestep_embedding(timesteps, embedding_dim, max_period=10000):
    """Return sinusoidal embeddings of diffusion timesteps.

    Args:
        timesteps: tensor of shape (B,).
        embedding_dim: output width. The first ``embedding_dim // 2`` columns are
            sines and the next ``embedding_dim // 2`` are cosines of the same
            angles. An odd width gets one trailing zero column.
        max_period: frequencies are ``max_period ** (-k / half)`` for
            ``k = 0 .. half - 1``.

    Returns:
        Tensor of shape (B, embedding_dim).
    """
    half_dim = embedding_dim // 2
    k = torch.arange(half_dim, dtype=torch.float32, device=timesteps.device)
    log_freqs = -math.log(max_period) * k / half_dim
    angles = timesteps.float().unsqueeze(1) * log_freqs.exp().unsqueeze(0)
    embedding = torch.cat((angles.sin(), angles.cos()), dim=1)
    return F.pad(embedding, (0, 1)) if embedding_dim % 2 == 1 else embedding

class TimeEmbedding(nn.Module):
    """Sinusoidal timestep embedding followed by a Linear-ReLU-Linear MLP."""

    def __init__(self, embedding_dim):
        super().__init__()
        self.embedding_dim = embedding_dim
        self.linear1 = nn.Linear(embedding_dim, embedding_dim)
        self.linear2 = nn.Linear(embedding_dim, embedding_dim)

    def forward(self, t):
        """Map timesteps of shape (B,) to embeddings of shape (B, embedding_dim)."""
        emb = get_timestep_embedding(t, self.embedding_dim)
        return self.linear2(F.relu(self.linear1(emb)))

###############################
# 4) ATTENTION MODULES & RESIDUAL BLOCK
###############################
class SelfAttention(nn.Module):
    """SAGAN-style spatial self-attention with a learnable residual gate.

    ``gamma`` starts at zero, so a freshly built module is the identity map.
    The attention matrix is (H*W) x (H*W) per sample, so memory grows
    quadratically with the number of pixels.
    """

    def __init__(self, in_channels, reduction=8):
        """
        Args:
            in_channels: channel count of the input feature map.
            reduction: query/key channels are ``in_channels // reduction``,
                clamped to at least 1 so narrow inputs still get a usable projection.
        """
        super().__init__()
        qk_channels = max(1, in_channels // reduction)
        self.query = nn.Conv2d(in_channels, qk_channels, kernel_size=1)
        self.key = nn.Conv2d(in_channels, qk_channels, kernel_size=1)
        self.value = nn.Conv2d(in_channels, in_channels, kernel_size=1)
        self.gamma = nn.Parameter(torch.zeros(1))
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, x):
        batch_size, C, height, width = x.size()
        n = height * width
        proj_query = self.query(x).view(batch_size, -1, n).permute(0, 2, 1)  # (B, N, C')
        proj_key = self.key(x).view(batch_size, -1, n)                       # (B, C', N)
        attention = self.softmax(torch.bmm(proj_query, proj_key))          # (B, N, N)
        proj_value = self.value(x).view(batch_size, -1, n)                   # (B, C, N)
        out = torch.bmm(proj_value, attention.permute(0, 2, 1))
        out = out.view(batch_size, C, height, width)
        return self.gamma * out + x

class ResidualBlock(nn.Module):
    """Two 3x3 conv + BatchNorm layers with a ReLU, plus a residual shortcut.

    The shortcut is the identity when channel counts match, otherwise a 1x1
    conv + BatchNorm projection. Spatial size is preserved (padding=1).
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1)
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1)
        self.bn2 = nn.BatchNorm2d(out_channels)
        if in_channels != out_channels:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_channels, out_channels, kernel_size=1),
                nn.BatchNorm2d(out_channels),
            )
        else:
            self.shortcut = nn.Identity()

    def forward(self, x):
        out = self.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))
        out += self.shortcut(x)
        return self.relu(out)

###############################
# 5) IMPROVED DEEP U-NET WITH TIME EMBEDDING, ATTENTION AND RESIDUAL CONNECTIONS
###############################
class ImprovedDeepUNet(nn.Module):
    """Three-level residual U-Net with self-attention and additive timestep conditioning.

    A linear projection of the timestep embedding is broadcast over the spatial
    dims and added after each encoder stage and after the bottleneck. Each
    decoder stage concatenates the upsampled features with the matching encoder
    skip, so decoder block ``k`` sees ``2 * features[k]`` input channels.
    """

    def __init__(self, in_channels, out_channels, features=(32, 64, 128, 256), time_emb_dim=256):
        """
        Args:
            in_channels: channels of the input tensor.
            out_channels: channels of the output tensor.
            features: four channel widths (three encoder levels + bottleneck).
            time_emb_dim: width of the incoming timestep embedding.

        Raises:
            ValueError: if ``features`` does not have exactly four entries.
        """
        super().__init__()
        if len(features) != 4:
            raise ValueError(f"features must have 4 entries, got {len(features)}")
        f0, f1, f2, f3 = features
        self.time_emb_dim = time_emb_dim
        # Per-level projections of the timestep embedding
        self.time_mlp1 = nn.Linear(time_emb_dim, f0)
        self.time_mlp2 = nn.Linear(time_emb_dim, f1)
        self.time_mlp3 = nn.Linear(time_emb_dim, f2)
        self.time_mlp4 = nn.Linear(time_emb_dim, f3)
        # Encoder
        self.encoder1 = ResidualBlock(in_channels, f0)
        self.pool1 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.encoder2 = ResidualBlock(f0, f1)
        self.pool2 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.encoder3 = ResidualBlock(f1, f2)
        self.pool3 = nn.MaxPool2d(kernel_size=2, stride=2)
        # Bottleneck with attention
        self.bottleneck = nn.Sequential(
            ResidualBlock(f2, f3),
            SelfAttention(f3),
        )
        # Decoder: input = upsampled features concatenated with the encoder skip
        self.up3 = nn.ConvTranspose2d(f3, f2, kernel_size=2, stride=2)
        self.decoder3 = ResidualBlock(2 * f2, f2)
        self.attention3 = SelfAttention(f2)
        self.up2 = nn.ConvTranspose2d(f2, f1, kernel_size=2, stride=2)
        self.decoder2 = ResidualBlock(2 * f1, f1)
        self.attention2 = SelfAttention(f1)
        self.up1 = nn.ConvTranspose2d(f1, f0, kernel_size=2, stride=2)
        self.decoder1 = ResidualBlock(2 * f0, f0)
        # Final projection
        self.final_conv = nn.Conv2d(f0, out_channels, kernel_size=1)

    @staticmethod
    def _match_spatial(x, ref):
        """Bilinearly resize ``x`` to ``ref``'s spatial size when they differ.

        MaxPool2d(2) floors odd sizes, so a stride-2 transposed conv can come
        back one pixel short of the skip connection it is concatenated with.
        """
        if x.shape[-2:] != ref.shape[-2:]:
            x = F.interpolate(x, size=ref.shape[-2:], mode='bilinear', align_corners=False)
        return x

    def forward(self, x, t_emb):
        """Run the U-Net on ``x`` (B, in_channels, H, W) with ``t_emb`` (B, time_emb_dim)."""
        B = x.size(0)
        time1 = self.time_mlp1(t_emb).view(B, -1, 1, 1)
        time2 = self.time_mlp2(t_emb).view(B, -1, 1, 1)
        time3 = self.time_mlp3(t_emb).view(B, -1, 1, 1)
        time4 = self.time_mlp4(t_emb).view(B, -1, 1, 1)

        enc1 = self.encoder1(x) + time1
        enc2 = self.encoder2(self.pool1(enc1)) + time2
        enc3 = self.encoder3(self.pool2(enc2)) + time3
        bottleneck = self.bottleneck(self.pool3(enc3)) + time4

        dec3 = self._match_spatial(self.up3(bottleneck), enc3)
        dec3 = self.attention3(self.decoder3(torch.cat([dec3, enc3], dim=1)))
        dec2 = self._match_spatial(self.up2(dec3), enc2)
        dec2 = self.attention2(self.decoder2(torch.cat([dec2, enc2], dim=1)))
        dec1 = self._match_spatial(self.up1(dec2), enc1)
        dec1 = self.decoder1(torch.cat([dec1, enc1], dim=1))

        return self._match_spatial(self.final_conv(dec1), x)

###############################
# 6) CONDITIONAL DIFFUSION MODEL (ACDM) WITH TIME EMBEDDING
###############################
class DiffusionModelACDM(nn.Module):
    """Conditional DDPM (ACDM) that generates the next frame from its history.

    The U-Net input is the conditioning stack (past frames, past ``u`` and
    current ``u`` broadcast to maps; see ``_build_condition``) concatenated with
    the noisy target. The U-Net is trained to predict the injected Gaussian noise.
    The schedule tensors are registered buffers, so they are saved in and
    restored from ``state_dict``.
    """

    def __init__(self, diffusion_steps, cond_channels, data_channels, time_emb_dim=256, cond_noise_std=0.05):
        """
        Args:
            diffusion_steps: number of diffusion timesteps T.
            cond_channels: channels of the conditioning stack
                (``2 * past_window + 1`` for the layout built by ``_build_condition``).
            data_channels: channels of the generated frame.
            time_emb_dim: width of the timestep embedding.
            cond_noise_std: standard deviation of the Gaussian noise added to the
                conditioning frames, in training mode only.
        """
        super().__init__()
        self.timesteps = diffusion_steps
        self.data_channels = data_channels
        self.cond_noise_std = cond_noise_std
        # The U-Net sees the conditioning stack concatenated with the noisy target.
        self.unet = ImprovedDeepUNet(in_channels=cond_channels + data_channels,
                                     out_channels=data_channels,
                                     time_emb_dim=time_emb_dim)
        self.time_embedding = TimeEmbedding(embedding_dim=time_emb_dim)
        betas = self.linear_beta_schedule(diffusion_steps).view(-1, 1, 1, 1)  # (T,1,1,1)
        alphas_cumprod = torch.cumprod(1.0 - betas, dim=0)
        self.register_buffer("betas", betas)
        self.register_buffer("alphas_cumprod", alphas_cumprod)
        self.register_buffer("sqrt_alphas_cumprod", torch.sqrt(alphas_cumprod))
        self.register_buffer("sqrt_one_minus_alphas_cumprod", torch.sqrt(1.0 - alphas_cumprod))

    def linear_beta_schedule(self, timesteps):
        """Linear beta schedule rescaled to the number of steps.

        The canonical DDPM range 1e-4..0.02 is tuned for 1000 steps. With far
        fewer steps it leaves most of the signal in the last timestep (for T=40,
        alpha_bar_T is about 0.67), but ``sample`` starts from pure N(0, I).
        Scaling both ends by 1000/T keeps the total noise comparable and
        reproduces the canonical values exactly at T=1000. Betas are clamped
        below 1 so very short schedules stay valid.
        """
        scale = 1000.0 / timesteps
        beta_start = 0.0001 * scale
        beta_end = 0.02 * scale
        return torch.linspace(beta_start, beta_end, timesteps).clamp(max=0.999)

    def _build_condition(self, cond_frames, cond_u_past, cond_u_curr, add_noise=False):
        """Stack the conditioning into a (B, 2 * past_window + 1, H, W) tensor.

        ``cond_frames`` is (B, past_window, 1, H, W), ``cond_u_past`` is
        (B, past_window, 1) and ``cond_u_curr`` is (B, 1). The scalar forcing
        values are broadcast to constant maps. With ``add_noise``, Gaussian noise
        of std ``cond_noise_std`` is added to the frames (not to ``u``).
        """
        B, past_window, _, H, W = cond_frames.shape
        frames_map = cond_frames.reshape(B, past_window, H, W)
        if add_noise:
            frames_map = frames_map + torch.randn_like(frames_map) * self.cond_noise_std
        u_past_map = cond_u_past.reshape(B, past_window, 1, 1).expand(B, past_window, H, W)
        u_curr_map = cond_u_curr.reshape(B, 1, 1, 1).expand(B, 1, H, W)
        return torch.cat([frames_map, u_past_map, u_curr_map], dim=1)

    def forward(self, cond_frames, cond_u_past, cond_u_curr, target_frame):
        """Noise ``target_frame`` at a random timestep and predict that noise.

        Returns ``(noise, predicted_noise)``, both shaped like ``target_frame``
        (B, data_channels, H, W). Conditioning noise is applied only when
        ``self.training`` is True.
        """
        cond = self._build_condition(cond_frames, cond_u_past, cond_u_curr, add_noise=self.training)
        x0 = target_frame
        B = x0.shape[0]
        t = torch.randint(0, self.timesteps, (B,), device=x0.device).long()
        t_emb = self.time_embedding(t)
        sqrt_alpha = self.sqrt_alphas_cumprod[t].view(B, 1, 1, 1)
        sqrt_one_minus_alpha = self.sqrt_one_minus_alphas_cumprod[t].view(B, 1, 1, 1)
        noise = torch.randn_like(x0)
        x_t = sqrt_alpha * x0 + sqrt_one_minus_alpha * noise
        predicted_noise = self.unet(torch.cat([cond, x_t], dim=1), t_emb)
        return noise, predicted_noise

    def sample(self, cond_frames, cond_u_past, cond_u_curr, num_steps=None):
        """Generate a frame by ancestral DDPM sampling (no noise on the conditioning).

        Each step predicts x0 from the current sample and draws the previous
        sample from the forward-process posterior q(x_prev | x_t, x0).

        Args:
            num_steps: number of denoising iterations (default: the full schedule).
                Fewer steps use an evenly strided subset of timesteps that always
                starts at T-1 and ends at 0.

        Returns:
            Tensor of shape (B, data_channels, H, W).

        Raises:
            ValueError: if ``num_steps`` is outside ``[1, diffusion_steps]``.
        """
        if num_steps is None:
            num_steps = self.timesteps
        if not 1 <= num_steps <= self.timesteps:
            raise ValueError(f"num_steps must be in [1, {self.timesteps}], got {num_steps}")
        device = cond_frames.device
        B, _, _, H, W = cond_frames.shape
        cond = self._build_condition(cond_frames, cond_u_past, cond_u_curr)

        last = self.timesteps - 1
        if num_steps == 1:
            schedule = [last]
        else:
            # Integer stride >= 1 keeps the timesteps distinct and strictly decreasing.
            schedule = [last * k // (num_steps - 1) for k in reversed(range(num_steps))]

        x_t = torch.randn(B, self.data_channels, H, W, device=device)
        for step, prev in zip(schedule, schedule[1:] + [None]):
            t = torch.full((B,), step, device=device, dtype=torch.long)
            predicted_noise = self.unet(torch.cat([cond, x_t], dim=1), self.time_embedding(t))
            ab_t = self.alphas_cumprod[step]
            x0_pred = (x_t - torch.sqrt(1.0 - ab_t) * predicted_noise) / torch.sqrt(ab_t)
            if prev is None:
                x_t = x0_pred
                break
            ab_prev = self.alphas_cumprod[prev]
            # Posterior q(x_prev | x_t, x0); for prev = step - 1 this is the standard
            # DDPM posterior with beta_eff = beta_t.
            beta_eff = 1.0 - ab_t / ab_prev
            denom = 1.0 - ab_t
            coef_x0 = torch.sqrt(ab_prev) * beta_eff / denom
            coef_xt = torch.sqrt(ab_t / ab_prev) * (1.0 - ab_prev) / denom
            var = beta_eff * (1.0 - ab_prev) / denom
            x_t = coef_x0 * x0_pred + coef_xt * x_t + torch.sqrt(var) * torch.randn_like(x_t)
        return x_t

###############################
# 7) TRAINING LOOP WITH COSINE ANNEALING SCHEDULER AND TQDM
###############################
def _run_diffusion_epoch(model, loader, device, desc, optimizer=None):
    """Run one pass over ``loader`` and return the mean smooth-L1 noise loss.

    Takes an optimizer step per batch when ``optimizer`` is given. The caller
    sets train/eval mode and any ``no_grad`` context.
    """
    total = 0.0
    for batch in tqdm(loader, desc=desc):
        cond_frames, cond_u_past, cond_u_curr, target_frame = (x.to(device) for x in batch)
        if optimizer is not None:
            optimizer.zero_grad()
        noise, pred_noise = model(cond_frames, cond_u_past, cond_u_curr, target_frame)
        loss = F.smooth_l1_loss(pred_noise, noise)
        if optimizer is not None:
            loss.backward()
            optimizer.step()
        total += loss.item()
    return total / len(loader)


def train_diffusion_model(model, train_loader, val_loader, num_epochs, device,
                          lr=1e-4, save_path='best_diffusion_model.pth'):
    """Train ``model`` with Adam and cosine-annealed LR, checkpointing the best epoch.

    ``model(cond_frames, cond_u_past, cond_u_curr, target_frame)`` must return
    ``(noise, predicted_noise)``. The loss is smooth-L1 between the two. The LR
    anneals from ``lr`` to 1e-6 over ``num_epochs``. Whenever the validation
    loss improves, the ``state_dict`` is saved to ``save_path``.

    NOTE(review): with DiffusionModelACDM the validation loss is itself
    stochastic (random timesteps and noise on every call), so best-epoch
    selection is noisy.

    Returns:
        The best validation loss seen (``inf`` if ``num_epochs`` is 0).

    Raises:
        ValueError: if either loader yields no batches.
    """
    if len(train_loader) == 0 or len(val_loader) == 0:
        raise ValueError("train_loader and val_loader must each yield at least one batch")
    optimizer = optim.Adam(model.parameters(), lr=lr)
    scheduler = CosineAnnealingLR(optimizer, T_max=num_epochs, eta_min=1e-6)
    best_val_loss = float('inf')
    for epoch in range(num_epochs):
        model.train()
        train_loss = _run_diffusion_epoch(
            model, train_loader, device, f"Epoch {epoch+1}/{num_epochs} training", optimizer
        )
        model.eval()
        with torch.no_grad():
            val_loss = _run_diffusion_epoch(
                model, val_loader, device, f"Epoch {epoch+1}/{num_epochs} validation"
            )
        scheduler.step()
        print(f"Epoch {epoch+1}/{num_epochs}, Train Loss: {train_loss:.7f}, Val Loss: {val_loss:.7f}, LR: {scheduler.get_last_lr()[0]:.7f}")
        if val_loss < best_val_loss:
            best_val_loss = val_loss
            torch.save(model.state_dict(), save_path)
            print(f"New best model saved with validation loss: {best_val_loss:.7f}")
    print("Training complete!")
    return best_val_loss

###############################
# 8) EVALUATION ET INFÉRENCE AUTO-RÉGRESSIVE SANS RE-BRUITAGE DU CONDITIONNEMENT
###############################
def evaluate_single_test_autoregressive(model, data_case, past_window, device, num_frames, frame_mean, frame_std):
    """Roll the model out autoregressively on one test case and score it.

    Starting from the first ``past_window`` ground-truth frames, each predicted
    frame is fed back into the history unchanged (no noise is added to the
    conditioning at inference). The forcing ``u`` always comes from the data.

    Args:
        model: object with ``eval()`` and ``sample(cond_frames, cond_u_past, cond_u_curr)``.
        data_case: dict with normalised ``'frames'`` (T, 1, H, W), normalised
            ``'u'`` (T,) and ``'amplitude'`` (used in the printed summary).
        past_window: number of conditioning frames (>= 1).
        device: device the conditioning tensors are created on.
        num_frames: maximum number of frames to predict. The rollout stops early
            at the end of the sequence.
        frame_mean, frame_std: statistics used to de-normalise the returned frames.

    Returns:
        ``(pred_frames, true_frames, avg_mse)``. The frames are de-normalised
        arrays of shape (n, 1, H, W). ``avg_mse`` is the mean per-frame MSE in
        normalised space, or NaN when no frame could be predicted.

    Raises:
        ValueError: if ``past_window`` < 1.
    """
    if past_window < 1:
        raise ValueError(f"past_window must be >= 1, got {past_window}")
    model.eval()
    frames = data_case['frames']
    u = data_case['u']
    T = frames.shape[0]
    # Rolling window holding exactly the last past_window frames (1, H, W).
    history = list(frames[:past_window])
    pred_frames = []
    mse_list = []
    with torch.no_grad():
        for t_idx in range(past_window, min(past_window + num_frames, T)):
            cond_frames = torch.tensor(np.stack(history), dtype=torch.float32, device=device).unsqueeze(0)
            past_u = torch.tensor(u[t_idx - past_window:t_idx].reshape(-1, 1), dtype=torch.float32, device=device).unsqueeze(0)
            current_u = torch.tensor(np.array([u[t_idx]], dtype=np.float32), dtype=torch.float32, device=device).unsqueeze(0)

            pred = model.sample(cond_frames, past_u, current_u).squeeze(0)  # (1, H, W)
            true = torch.tensor(frames[t_idx], dtype=torch.float32, device=device)
            mse_list.append(F.mse_loss(pred, true).item())

            pred_np = pred.cpu().numpy()
            pred_frames.append(pred_np)
            history.append(pred_np)
            del history[0]

    n = len(pred_frames)
    true_frames = frames[past_window:past_window + n]
    if n:
        pred_arr = np.stack(pred_frames)
        avg_mse = sum(mse_list) / n
    else:
        pred_arr = np.empty((0,) + frames.shape[1:], dtype=np.float32)
        avg_mse = float('nan')
    true_frames_denorm = true_frames * frame_std + frame_mean
    pred_frames_denorm = pred_arr * frame_std + frame_mean
    print(f"Average MSE for amplitude {data_case['amplitude']} (autoregressive): {avg_mse:.6f}")
    return pred_frames_denorm, true_frames_denorm, avg_mse

###############################
# 9) CREATE COMPARISON VIDEO
###############################
def create_comparison_video(gt_frames, pred_frames, amplitude, save_path="comparison_video.mp4"):
    """Render side-by-side ground-truth / predicted heatmaps to a video file.

    Args:
        gt_frames, pred_frames: per-frame 2-D fields indexed along axis 0
            (the script passes arrays of shape (T, H, W)). ``pred_frames`` must
            have at least as many frames as ``gt_frames``.
        amplitude: label shown in the figure title.
        save_path: output file, written with the ffmpeg writer at 5 fps, 200 dpi.

    Returns:
        ``save_path``.

    Raises:
        ValueError: if ``gt_frames`` is empty.
    """
    T = gt_frames.shape[0]
    if T == 0:
        raise ValueError("create_comparison_video needs at least one frame")
    # Shared colour range so both panels are directly comparable.
    vmin = min(gt_frames.min(), pred_frames.min())
    vmax = max(gt_frames.max(), pred_frames.max())
    heatmap_kw = dict(cmap="magma", vmin=vmin, vmax=vmax, center=0, cbar=False, square=True)

    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 6))
    try:
        fig.suptitle(f"Amplitude: {amplitude}")

        def update(frame):
            ax1.clear()
            ax2.clear()
            sns.heatmap(gt_frames[frame], ax=ax1, **heatmap_kw)
            sns.heatmap(pred_frames[frame], ax=ax2, **heatmap_kw)
            ax1.set_title(f"Ground Truth (Frame {frame+1}/{T})")
            ax2.set_title(f"Predicted (Frame {frame+1}/{T})")

        update(0)
        ani = animation.FuncAnimation(fig, update, frames=T, interval=200)
        ani.save(save_path, writer="ffmpeg", fps=5, dpi=200)
    finally:
        # Release the figure; otherwise each call (one per test case) leaks one.
        plt.close(fig)
    print(f"Vidéo sauvegardée sous : {save_path}")
    return save_path

###############################
# 10) MAIN SCRIPT
###############################
if __name__ == "__main__":
    device = "cuda" if torch.cuda.is_available() else "cpu"
    print(f"Using device: {device}")

    file_path = "oscillating_cylinder_benchmark_dataset_v2.mat"
    train_list, val_list, test_list = split_data(file_path)
    print(f"Training: {len(train_list)} amplitudes, Validation: {len(val_list)} amplitudes, Test: {len(test_list)} amplitudes")

    # Normalisation statistics come from the training split only.
    frame_mean, frame_std, u_mean, u_std = compute_normalization_stats(train_list)
    print(f"Frame mean: {frame_mean:.6f}, Frame std: {frame_std:.6f}")
    print(f"u mean: {u_mean:.6f}, u std: {u_std:.6f}")

    normalize_data_list(train_list, frame_mean, frame_std, u_mean, u_std)
    normalize_data_list(val_list, frame_mean, frame_std, u_mean, u_std)
    normalize_data_list(test_list, frame_mean, frame_std, u_mean, u_std)

    past_window = 10
    X_train_frames, X_train_u_past, X_train_u_curr, Y_train = create_acdm_sequences(train_list, past_window)
    X_val_frames, X_val_u_past, X_val_u_curr, Y_val = create_acdm_sequences(val_list, past_window)

    train_dataset = ACylinderDataset(X_train_frames, X_train_u_past, X_train_u_curr, Y_train)
    val_dataset = ACylinderDataset(X_val_frames, X_val_u_past, X_val_u_curr, Y_val)
    train_loader = DataLoader(train_dataset, batch_size=16, shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size=16, shuffle=False)

    # Conditioning stack: past_window frame channels + past_window past-u maps + 1 current-u map.
    cond_channels = 2 * past_window + 1
    data_channels = 1  # target frame
    diffusion_steps = 40
    model = DiffusionModelACDM(diffusion_steps, cond_channels, data_channels, cond_noise_std=0.2).to(device)

    num_epochs = 10
    print("Début de l'entraînement avec bruitage du conditionnement pendant le training et scheduler cosine annealing...")
    train_diffusion_model(model, train_loader, val_loader, num_epochs, device)

    # map_location keeps loading working regardless of the device the checkpoint was saved from.
    model.load_state_dict(torch.load('best_diffusion_model.pth', map_location=device))

    total_mse = 0
    for i, test_case in enumerate(test_list):
        print(f"\nEvaluation test case {i+1}/{len(test_list)} (amplitude: {test_case['amplitude']}) avec inférence auto-régressive...")
        pred_frames, true_frames, mse = evaluate_single_test_autoregressive(
            model, test_case, past_window, device, num_frames=10, frame_mean=frame_mean, frame_std=frame_std
        )
        total_mse += mse
        gt_sequence = true_frames.squeeze(1)
        pred_sequence = pred_frames.squeeze(1)
        video_path = f"comparison_{test_case['amplitude']}_autoregressive.mp4"
        create_comparison_video(gt_sequence, pred_sequence, test_case['amplitude'], save_path=video_path)

    avg_test_mse = total_mse / len(test_list)
    print(f"\nAverage MSE across all test cases (autoregressive): {avg_test_mse:.6f}")
    print("All done!")