# In [None]:
# This Python 3 environment comes with many helpful analytics libraries installed
# It is defined by the kaggle/python Docker image: https://github.com/kaggle/docker-python
# For example, here are several helpful packages to load

import numpy as np # linear algebra
import pandas as pd # data processing, CSV file I/O (e.g. pd.read_csv)

# Input data files are available in the read-only "../input/" directory
# For example, running this (by clicking run or pressing Shift+Enter) will list all files under the input directory

import os
# Recursively walk the Kaggle input directory and print the full path of
# every file found, so the available dataset paths show up in the cell output.
for dirname, _, filenames in os.walk('/kaggle/input'):
    for filename in filenames:
        print(os.path.join(dirname, filename))

# You can write up to 20GB to the current directory (/kaggle/working/) that gets preserved as output when you create a version using "Save & Run All" 
# You can also write temporary files to /kaggle/temp/, but they won't be saved outside of the current session

# In [None]:
import os
import math
import h5py
import numpy as np
import scipy.ndimage
import matplotlib.pyplot as plt
import matplotlib.animation as animation
import seaborn as sns
from tqdm import tqdm
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torch.optim.lr_scheduler import CosineAnnealingLR

###############################
# 1) DATA LOADING & RESIZING
###############################
def load_swept_sine_case(file_path, amplitude_key, target_x=256, target_y=256):
    """Load one swept-sine case and resample every snapshot to a fixed size.

    The group ``data_structure/swept_sines/<amplitude_key>`` is read from the
    HDF5 file at ``file_path``. Its ``wz_grid`` array is reshaped to a
    hard-coded 61 x 105 spatial grid per time step, and each snapshot is
    resized to ``(target_x, target_y)`` with linear interpolation. The
    defaults give 256 x 256 so the image is divisible by a patch size of 16.

    Returns:
        dict with keys
        ``'amplitude'`` -- the ``amplitude_key`` that was passed in,
        ``'u'`` -- the ``y`` dataset as float32 (squeezed when it has more
        than one dimension),
        ``'frames'`` -- float32 array of shape ``(T, 1, target_x, target_y)``.
    """
    native_x, native_y = 61, 105
    zoom_factors = (target_x / native_x, target_y / native_y)

    with h5py.File(file_path, 'r') as h5_file:
        case_group = h5_file['data_structure']['swept_sines'][amplitude_key]
        flat_field = np.array(case_group['wz_grid'])  # (6405, T)
        control = np.array(case_group['y'])

    n_steps = flat_field.shape[1]
    field = flat_field.reshape(native_x, native_y, n_steps)
    if control.ndim > 1:
        control = control.squeeze()

    # Time-major buffer; each assignment also checks that the zoomed
    # snapshot has exactly the requested spatial shape.
    resized = np.zeros((n_steps, target_x, target_y), dtype=np.float32)
    for step in range(n_steps):
        resized[step] = scipy.ndimage.zoom(field[:, :, step], zoom_factors, order=1)

    return {
        'amplitude': amplitude_key,
        'u': control.astype(np.float32),
        'frames': resized[:, np.newaxis, :, :],  # (T, 1, target_x, target_y)
    }

def load_all_amplitudes(file_path, amplitude_map, amplitude_list, target_x=48, target_y=48):
    """Load several swept-sine cases from the same file.

    Args:
        file_path: path handed to ``load_swept_sine_case`` for every case.
        amplitude_map: mapping from amplitude value to the dataset key.
        amplitude_list: amplitudes to load, in the order they should be returned.
        target_x: height every frame is resized to (previously hard-coded to 48).
        target_y: width every frame is resized to (previously hard-coded to 48).

    Returns:
        list of the dicts produced by ``load_swept_sine_case``, one per
        entry of ``amplitude_list``.

    Raises:
        KeyError: if an amplitude is missing from ``amplitude_map``.
    """
    return [
        load_swept_sine_case(
            file_path, amplitude_map[amp], target_x=target_x, target_y=target_y
        )
        for amp in amplitude_list
    ]

def split_data(file_path):
    """Load the train / validation / test cases using a fixed amplitude split.

    Returns:
        ``(train_list, val_list, test_list)`` -- three lists of case dicts as
        produced by ``load_all_amplitudes``.
    """
    amplitude_map = {
        0.5: 'A0p05',
        0.75: 'A0p075',
        1.0: 'A0p10',
        1.25: 'A0p125',
        1.5: 'A0p15',
        1.75: 'A0p175',
        2.0: 'A0p20',
        2.25: 'A0p225',
        2.5: 'A0p25',
        2.75: 'A0p275',
        3.0: 'A0p30',
    }
    # Order matters: train first, then validation, then test.
    amplitude_splits = (
        [0.5, 1.0, 1.5, 2.0, 2.5, 3.0],
        [0.75, 1.75, 2.75],
        [1.25, 2.25],
    )
    train_list, val_list, test_list = (
        load_all_amplitudes(file_path, amplitude_map, amps)
        for amps in amplitude_splits
    )
    return train_list, val_list, test_list

###############################
# 1.1) DATA NORMALIZATION
###############################
def compute_normalization_stats(data_list):
    """Compute global mean/std of the frames and of the control signal.

    All cases in ``data_list`` are pooled (concatenated along axis 0) before
    the statistics are taken.

    Returns:
        ``(frame_mean, frame_std, u_mean, u_std)``.
    """
    def _pooled_mean_std(key):
        pooled = np.concatenate([case[key] for case in data_list], axis=0)
        return pooled.mean(), pooled.std()

    frame_mean, frame_std = _pooled_mean_std('frames')
    u_mean, u_std = _pooled_mean_std('u')
    return frame_mean, frame_std, u_mean, u_std

def normalize_data_list(data_list, frame_mean, frame_std, u_mean, u_std):
    """Standardise ``'frames'`` and ``'u'`` of every case in place.

    Each entry of ``data_list`` is mutated: both arrays are replaced by
    ``(value - mean) / std``. Nothing is returned.
    """
    stats_by_key = {
        'frames': (frame_mean, frame_std),
        'u': (u_mean, u_std),
    }
    for case in data_list:
        for key, (mean, std) in stats_by_key.items():
            case[key] = (case[key] - mean) / std



###############################
# 2) CREATING SEQUENCES FOR JiT
###############################
def create_jit_sequences(data_list, past_window=2, delta_t=0.1):
    """Build sliding-window training samples for the JiT model.

    For every case and every index ``i >= past_window`` one sample is made
    from the ``past_window`` frames and control values preceding ``i``, the
    control value at ``i``, and the frame at ``i`` as the target.
    ``delta_t`` is accepted but not used here.

    Returns:
        ``(X_frames, X_u_past, X_u_curr, Y)`` as float32 arrays; for N samples
        the shapes are ``(N, past_window, 1, H, W)``, ``(N, past_window, 1)``,
        ``(N, 1)`` and ``(N, 1, H, W)``.
    """
    frame_windows, control_windows, current_controls, targets = [], [], [], []

    for case in data_list:
        frames = case['frames']  # (T, 1, H, W)
        u = case['u']            # (T,)
        for stop in range(past_window, frames.shape[0]):
            start = stop - past_window
            frame_windows.append(frames[start:stop])               # (past_window, 1, H, W)
            control_windows.append(u[start:stop].reshape(-1, 1))   # (past_window, 1)
            current_controls.append(np.array([u[stop]], dtype=np.float32))  # (1,)
            targets.append(frames[stop])                           # (1, H, W)

    X_frames, X_u_past, X_u_curr, Y = (
        np.array(samples, dtype=np.float32)
        for samples in (frame_windows, control_windows, current_controls, targets)
    )
    return X_frames, X_u_past, X_u_curr, Y

class FluidDynamicsDataset(Dataset):
    """Dataset over pre-built sequence arrays.

    Item ``idx`` is the tuple
    ``(X_frames[idx], X_u_past[idx], X_u_curr[idx], Y[idx])``.
    """

    def __init__(self, X_frames, X_u_past, X_u_curr, Y):
        self.X_frames, self.X_u_past = X_frames, X_u_past
        self.X_u_curr, self.Y = X_u_curr, Y

    def __len__(self):
        # The number of samples is taken from the frame windows.
        return len(self.X_frames)

    def __getitem__(self, idx):
        columns = (self.X_frames, self.X_u_past, self.X_u_curr, self.Y)
        return tuple(column[idx] for column in columns)
    

###############################
# 3) JiT ARCHITECTURE COMPONENTS (from model_jit.py)
###############################
def get_2d_sincos_pos_embed(embed_dim, grid_size):
    """Return the 2D sin-cos positional embedding for a square grid.

    Builds the coordinate mesh of a ``grid_size`` x ``grid_size`` grid and
    hands it to ``get_2d_sincos_pos_embed_from_grid``.
    """
    coords = np.arange(grid_size, dtype=np.float32)
    # meshgrid(w, h): both axes use the same coordinates for a square grid.
    mesh = np.stack(np.meshgrid(coords, coords), axis=0)
    mesh = mesh.reshape([2, 1, grid_size, grid_size])
    return get_2d_sincos_pos_embed_from_grid(embed_dim, mesh)

def get_2d_sincos_pos_embed_from_grid(embed_dim, grid):
    """Embed each of the two grid axes with half the channels and concatenate.

    ``embed_dim`` must be even; ``grid[0]`` and ``grid[1]`` are encoded
    separately by ``get_1d_sincos_pos_embed_from_grid`` and joined along
    axis 1.
    """
    assert embed_dim % 2 == 0
    per_axis_dim = embed_dim // 2
    per_axis = [
        get_1d_sincos_pos_embed_from_grid(per_axis_dim, grid[axis])
        for axis in (0, 1)
    ]
    return np.concatenate(per_axis, axis=1)

def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
    """Sin-cos embedding of a set of positions.

    ``pos`` is flattened to M positions; the result has shape
    ``(M, embed_dim)`` with the sine half first and the cosine half second.
    ``embed_dim`` must be even.
    """
    assert embed_dim % 2 == 0
    exponents = np.arange(embed_dim // 2, dtype=np.float32)
    exponents /= embed_dim / 2.
    inv_freq = 1. / 10000**exponents

    flat_pos = pos.reshape(-1)
    # Outer product: one row per position, one column per frequency.
    angles = flat_pos[:, None] * inv_freq[None, :]
    return np.concatenate([np.sin(angles), np.cos(angles)], axis=1)

class RMSNorm(nn.Module):
    """Root-mean-square normalisation over the last dimension with a learned scale."""

    def __init__(self, dim, eps=1e-6):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))

    def forward(self, x):
        # eps is added to the mean square before the square root.
        mean_square = x.pow(2).mean(dim=-1, keepdim=True)
        rms = (mean_square + self.eps).sqrt()
        return x / rms * self.weight

def modulate(x, shift, scale):
    """Apply a scale-and-shift modulation; ``shift``/``scale`` gain an axis at dim 1 to broadcast over it."""
    gain = 1 + scale[:, None]
    offset = shift[:, None]
    return x * gain + offset

class BottleneckPatchEmbed(nn.Module):
    """Patch embedding through a low-dimensional bottleneck.

    A strided convolution (kernel = stride = ``patch_size``, no bias) maps
    each patch to ``pca_dim`` channels; a 1x1 convolution then lifts it to
    ``embed_dim``. The output is a token sequence of shape
    ``(B, num_patches, embed_dim)``.
    """

    def __init__(self, img_size=48, patch_size=16, in_chans=1, pca_dim=128, embed_dim=384, bias=True):
        super().__init__()
        self.img_size = (img_size,) * 2
        self.patch_size = (patch_size,) * 2
        patches_per_side = img_size // patch_size
        self.num_patches = patches_per_side ** 2

        self.proj1 = nn.Conv2d(in_chans, pca_dim, kernel_size=patch_size, stride=patch_size, bias=False)
        self.proj2 = nn.Conv2d(pca_dim, embed_dim, kernel_size=1, stride=1, bias=bias)

    def forward(self, x):
        bottleneck = self.proj1(x)
        embedded = self.proj2(bottleneck)
        # (B, C, h, w) -> (B, h*w, C)
        return embedded.flatten(2).transpose(1, 2)

class TimestepEmbedder(nn.Module):
    """Embed scalar timesteps: sinusoidal features followed by a two-layer MLP."""

    def __init__(self, hidden_size, frequency_embedding_size=256):
        super().__init__()
        self.mlp = nn.Sequential(
            nn.Linear(frequency_embedding_size, hidden_size, bias=True),
            nn.SiLU(),
            nn.Linear(hidden_size, hidden_size, bias=True),
        )
        self.frequency_embedding_size = frequency_embedding_size

    @staticmethod
    def timestep_embedding(t, dim, max_period=10000):
        """Return ``(N, dim)`` sinusoidal features (cosines first, then sines) for the 1-D tensor ``t``."""
        half = dim // 2
        exponent = -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32, device=t.device) / half
        freqs = torch.exp(exponent)
        phase = t[:, None].float() * freqs[None]
        pieces = [torch.cos(phase), torch.sin(phase)]
        if dim % 2:
            # Odd width: pad with one zero column.
            pieces.append(torch.zeros_like(phase[:, :1]))
        return torch.cat(pieces, dim=-1)

    def forward(self, t):
        features = self.timestep_embedding(t, self.frequency_embedding_size)
        return self.mlp(features)

class ControlEmbedder(nn.Module):
    """Map the past and current control values to a ``hidden_size`` vector."""

    def __init__(self, past_window, hidden_size):
        super().__init__()
        # The MLP input holds past_window past values plus the current one.
        n_controls = past_window + 1
        self.mlp = nn.Sequential(
            nn.Linear(n_controls, hidden_size),
            nn.SiLU(),
            nn.Linear(hidden_size, hidden_size),
        )

    def forward(self, u_past, u_curr):
        # u_past: (B, past_window, 1), u_curr: (B, 1)
        past_flat = u_past.squeeze(-1)
        controls = torch.cat((past_flat, u_curr), dim=1)  # (B, past_window + 1)
        return self.mlp(controls)

def scaled_dot_product_attention(query, key, value, dropout_p=0.0):
    """Softmax attention with the score matrix computed in float32.

    The query-key product runs with CUDA autocast disabled. Dropout with
    probability ``dropout_p`` is always applied in training mode, so callers
    must pass ``0.0`` to switch it off.
    """
    num_queries, num_keys = query.size(-2), key.size(-2)
    inv_sqrt_dim = 1 / math.sqrt(query.size(-1))
    # All-zero additive bias, broadcast over the head dimension.
    zero_bias = torch.zeros(
        query.size(0), 1, num_queries, num_keys,
        dtype=query.dtype, device=query.device,
    )

    with torch.cuda.amp.autocast(enabled=False):
        scores = torch.matmul(query.float(), key.float().transpose(-2, -1)) * inv_sqrt_dim
    scores += zero_bias

    probs = scores.softmax(dim=-1)
    probs = torch.dropout(probs, dropout_p, train=True)
    return torch.matmul(probs, value)

class Attention(nn.Module):
    """Multi-head self-attention with optional RMS normalisation of queries and keys."""

    def __init__(self, dim, num_heads=8, qkv_bias=True, qk_norm=True, attn_drop=0., proj_drop=0.):
        super().__init__()
        self.num_heads = num_heads
        per_head_dim = dim // num_heads

        def _make_qk_norm():
            return RMSNorm(per_head_dim) if qk_norm else nn.Identity()

        self.q_norm = _make_qk_norm()
        self.k_norm = _make_qk_norm()

        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

    def forward(self, x):
        batch, seq_len, channels = x.shape
        packed = self.qkv(x).reshape(batch, seq_len, 3, self.num_heads, channels // self.num_heads)
        # -> (3, B, heads, N, head_dim), then split into q, k, v
        q, k, v = packed.permute(2, 0, 3, 1, 4).unbind(0)

        # Attention dropout is only active while the module is in training mode.
        drop_p = self.attn_drop.p if self.training else 0.
        attended = scaled_dot_product_attention(
            self.q_norm(q), self.k_norm(k), v, dropout_p=drop_p
        )

        merged = attended.transpose(1, 2).reshape(batch, seq_len, channels)
        return self.proj_drop(self.proj(merged))

class SwiGLUFFN(nn.Module):
    """SwiGLU feed-forward layer.

    The requested ``hidden_dim`` is shrunk to ``int(hidden_dim * 2 / 3)``;
    one fused linear layer produces both the gate and the value branch.
    """

    def __init__(self, dim, hidden_dim, drop=0.0, bias=True):
        super().__init__()
        inner_dim = int(hidden_dim * 2 / 3)
        self.w12 = nn.Linear(dim, 2 * inner_dim, bias=bias)
        self.w3 = nn.Linear(inner_dim, dim, bias=bias)
        self.ffn_dropout = nn.Dropout(drop)

    def forward(self, x):
        gate, value = self.w12(x).chunk(2, dim=-1)
        gated = F.silu(gate) * value
        gated = self.ffn_dropout(gated)
        return self.w3(gated)

class FinalLayer(nn.Module):
    """Output head: modulated RMSNorm followed by a linear map to patch pixels."""

    def __init__(self, hidden_size, patch_size, out_channels):
        super().__init__()
        pixels_per_patch = patch_size * patch_size * out_channels
        self.norm_final = RMSNorm(hidden_size)
        self.linear = nn.Linear(hidden_size, pixels_per_patch, bias=True)
        self.adaLN_modulation = nn.Sequential(
            nn.SiLU(),
            nn.Linear(hidden_size, 2 * hidden_size, bias=True),
        )

    def forward(self, x, c):
        # The conditioning vector supplies the shift and scale of the final norm.
        shift, scale = self.adaLN_modulation(c).chunk(2, dim=1)
        normed = modulate(self.norm_final(x), shift, scale)
        return self.linear(normed)

class JiTBlock(nn.Module):
    """Transformer block whose attention and MLP branches are modulated and gated by a conditioning vector."""

    def __init__(self, hidden_size, num_heads, mlp_ratio=4.0, attn_drop=0.0, proj_drop=0.0):
        super().__init__()
        self.norm1 = RMSNorm(hidden_size, eps=1e-6)
        self.attn = Attention(
            hidden_size,
            num_heads=num_heads,
            qkv_bias=True,
            qk_norm=True,
            attn_drop=attn_drop,
            proj_drop=proj_drop,
        )
        self.norm2 = RMSNorm(hidden_size, eps=1e-6)
        self.mlp = SwiGLUFFN(hidden_size, int(hidden_size * mlp_ratio), drop=proj_drop)
        # Produces shift / scale / gate for each of the two branches.
        self.adaLN_modulation = nn.Sequential(
            nn.SiLU(),
            nn.Linear(hidden_size, 6 * hidden_size, bias=True),
        )

    def forward(self, x, c):
        modulation = self.adaLN_modulation(c).chunk(6, dim=-1)
        shift_attn, scale_attn, gate_attn, shift_ffn, scale_ffn, gate_ffn = modulation

        attn_out = self.attn(modulate(self.norm1(x), shift_attn, scale_attn))
        x = x + gate_attn.unsqueeze(1) * attn_out

        ffn_out = self.mlp(modulate(self.norm2(x), shift_ffn, scale_ffn))
        x = x + gate_ffn.unsqueeze(1) * ffn_out
        return x



###############################
# 4) JiT-BASED CONDITIONAL DIFFUSION MODEL
###############################
class JiTFluidDiffusion(nn.Module):
    """JiT transformer predicting the next fluid frame from past frames and control.

    Two training objectives are available:

    * ``forward`` -- solver-in-the-loop objective. The network output is
      treated as the time derivative of the current frame and integrated
      over ``delta_t`` with Heun's method; the loss is the MSE between the
      integrated frame and the target frame. ``predict_x_dot`` exposes the
      same derivative network for inference.
    * ``diffusion_loss`` -- x-prediction / v-loss objective on a noised
      target frame; ``sample`` is the matching sampler.

    A class body keeps only the last definition of a name, so there is
    exactly one ``forward`` here.

    Args:
        img_size: side length of the (square) input frames.
        patch_size: side length of one patch.
        in_channels: channels of a frame (also the number of output channels).
        past_window: number of past frames / control values used as conditioning.
        hidden_size: token width of the transformer.
        depth: number of ``JiTBlock`` layers.
        num_heads: attention heads per block.
        mlp_ratio: feed-forward expansion ratio handed to ``JiTBlock``.
        bottleneck_dim: bottleneck width of the patch embedders.
        diffusion_steps: stored as ``self.timesteps``.
        P_mean, P_std: mean / std of the normal variable behind ``sample_t``.
        t_eps: lower clamp for ``1 - t`` when converting x to velocity.
        noise_scale: multiplier of the Gaussian noise.
    """

    def __init__(
        self,
        img_size=256,
        patch_size=16,
        in_channels=1,
        past_window=10,
        hidden_size=384,
        depth=4,
        num_heads=6,
        mlp_ratio=4.0,
        bottleneck_dim=64,
        diffusion_steps=50,
        P_mean=-0.8,
        P_std=0.8,
        t_eps=0.05,
        noise_scale=1.0
    ):
        super().__init__()
        self.img_size = img_size
        self.patch_size = patch_size
        self.in_channels = in_channels
        self.out_channels = in_channels
        self.hidden_size = hidden_size
        self.past_window = past_window
        self.timesteps = diffusion_steps
        self.P_mean = P_mean
        self.P_std = P_std
        self.t_eps = t_eps
        self.noise_scale = noise_scale

        # Embedders
        self.t_embedder = TimestepEmbedder(hidden_size)
        self.control_embedder = ControlEmbedder(past_window, hidden_size)

        # Patch embedding for the state frame (noised target or current frame)
        self.x_embedder = BottleneckPatchEmbed(
            img_size, patch_size, in_channels, bottleneck_dim, hidden_size, bias=True
        )

        # Patch embedding for the conditioning frames: the past_window frames
        # are stacked along the channel axis.
        self.cond_embedder = BottleneckPatchEmbed(
            img_size, patch_size, past_window, bottleneck_dim, hidden_size, bias=True
        )

        # Fixed (non-trainable) positional embedding, filled in initialize_weights
        num_patches = self.x_embedder.num_patches
        self.pos_embed = nn.Parameter(torch.zeros(1, num_patches, hidden_size), requires_grad=False)

        # Transformer blocks
        self.blocks = nn.ModuleList([
            JiTBlock(hidden_size, num_heads, mlp_ratio=mlp_ratio, attn_drop=0.0, proj_drop=0.0)
            for _ in range(depth)
        ])

        # Final prediction layer
        self.final_layer = FinalLayer(hidden_size, patch_size, self.out_channels)

        self.initialize_weights()

    def initialize_weights(self):
        """Xavier-init linear/patch layers, fill ``pos_embed``, zero the adaLN and output layers."""
        def _basic_init(module):
            if isinstance(module, nn.Linear):
                torch.nn.init.xavier_uniform_(module.weight)
                if module.bias is not None:
                    nn.init.constant_(module.bias, 0)
        self.apply(_basic_init)

        # Sin-cos positional embedding on the square patch grid
        pos_embed = get_2d_sincos_pos_embed(
            self.pos_embed.shape[-1],
            int(self.x_embedder.num_patches ** 0.5)
        )
        # (1, num_patches, hidden_size)
        self.pos_embed.data.copy_(torch.from_numpy(pos_embed).float().unsqueeze(0))

        # Patch embeddings: Xavier on the conv kernels viewed as 2-D matrices
        for embedder in [self.x_embedder, self.cond_embedder]:
            w1 = embedder.proj1.weight.data
            nn.init.xavier_uniform_(w1.view([w1.shape[0], -1]))
            w2 = embedder.proj2.weight.data
            nn.init.xavier_uniform_(w2.view([w2.shape[0], -1]))
            nn.init.constant_(embedder.proj2.bias, 0)

        # Zero-out adaLN modulation layers
        for block in self.blocks:
            nn.init.constant_(block.adaLN_modulation[-1].weight, 0)
            nn.init.constant_(block.adaLN_modulation[-1].bias, 0)

        # Zero-out output layers
        nn.init.constant_(self.final_layer.adaLN_modulation[-1].weight, 0)
        nn.init.constant_(self.final_layer.adaLN_modulation[-1].bias, 0)
        nn.init.constant_(self.final_layer.linear.weight, 0)
        nn.init.constant_(self.final_layer.linear.bias, 0)

    def unpatchify(self, x):
        """x: (N, num_patches, patch_size**2 * C) -> imgs: (N, C, H, W)"""
        p = self.patch_size
        c = self.out_channels
        h = w = int(x.shape[1] ** 0.5)
        assert h * w == x.shape[1]

        x = x.reshape(shape=(x.shape[0], h, w, p, p, c))
        x = torch.einsum('nhwpqc->nchpwq', x)
        imgs = x.reshape(shape=(x.shape[0], c, h * p, w * p))
        return imgs

    def sample_t(self, n, device):
        """Sample ``n`` timesteps in (0, 1) from a logit-normal distribution."""
        z = torch.randn(n, device=device) * self.P_std + self.P_mean
        return torch.sigmoid(z)

    def _run_backbone(self, tokens, c):
        """Run the transformer blocks and the final layer, then fold tokens back into an image."""
        for block in self.blocks:
            tokens = block(tokens, c)
        return self.unpatchify(self.final_layer(tokens, c))

    def diffusion_loss(self, cond_frames, cond_u_past, cond_u_curr, target_frame):
        """
        Diffusion training loss with x-prediction and v-loss.

        cond_frames: (B, past_window, 1, H, W)
        cond_u_past: (B, past_window, 1)
        cond_u_curr: (B, 1)
        target_frame: (B, 1, H, W)

        Returns the scalar smooth-L1 loss between predicted and true velocity.
        """
        B = target_frame.size(0)
        device = target_frame.device

        # Sample one timestep per batch element, broadcastable over the frame
        t = self.sample_t(B, device).view(-1, *([1] * (target_frame.ndim - 1)))
        t_flat = t.flatten()

        # Interpolate between noise (t = 0) and the clean target (t = 1)
        e = torch.randn_like(target_frame) * self.noise_scale
        z_t = t * target_frame + (1 - t) * e

        # Ground truth velocity
        v = (target_frame - z_t) / (1 - t).clamp_min(self.t_eps)

        # (B, past_window, 1, H, W) -> (B, past_window, H, W) -> tokens
        cond_tokens = self.cond_embedder(cond_frames.squeeze(2))

        # Conditioning and noisy-target tokens are combined by addition
        tokens = self.x_embedder(z_t) + cond_tokens + self.pos_embed

        # Time and control embedding
        c = self.t_embedder(t_flat) + self.control_embedder(cond_u_past, cond_u_curr)

        # The network predicts x directly; convert to velocity for the loss
        x_pred = self._run_backbone(tokens, c)
        v_pred = (x_pred - z_t) / (1 - t).clamp_min(self.t_eps)

        return F.smooth_l1_loss(v_pred, v)

    def _control_conditioning(self, cond_u_past, cond_u_curr):
        """Conditioning vector for the derivative network: control embedding plus time embedding at t = 1."""
        t = torch.ones(cond_u_curr.size(0), device=cond_u_curr.device)
        return self.t_embedder(t) + self.control_embedder(cond_u_past, cond_u_curr)

    def _x_dot(self, cond_frames, c):
        """Predicted time derivative of the last frame in ``cond_frames`` given conditioning ``c``."""
        x_curr = cond_frames[:, -1]   # (B, 1, H, W)
        cond_tokens = self.cond_embedder(cond_frames.squeeze(2))
        tokens = self.x_embedder(x_curr) + cond_tokens + self.pos_embed
        return self._run_backbone(tokens, c)

    def forward(self, cond_frames, cond_u_past, cond_u_curr, target_frame, delta_t=0.1):
        """
        Training forward pass with solver-in-the-loop (Heun).

        cond_frames: (B, past_window, 1, H, W)
        cond_u_past: (B, past_window, 1)
        cond_u_curr: (B, 1)
        target_frame: (B, 1, H, W)
        delta_t: integration step; defaults to 0.1 so the 4-argument call
            used for validation works as well as the 5-argument training call.

        Returns the scalar MSE between the integrated frame and ``target_frame``.
        """
        # Current state = last conditioning frame
        x_curr = cond_frames[:, -1]   # (B, 1, H, W)

        # The control/time conditioning is shared by both stages
        c = self._control_conditioning(cond_u_past, cond_u_curr)

        # k1: derivative at the current state
        x_dot_1 = self._x_dot(cond_frames, c)

        # Heun intermediate state; the window slides so that the Euler
        # estimate becomes the newest conditioning frame.
        x_tilde = x_curr + delta_t * x_dot_1
        cond_frames_tilde = torch.cat(
            [cond_frames[:, 1:], x_tilde.unsqueeze(1)], dim=1
        )

        # k2: derivative at the intermediate state
        x_dot_2 = self._x_dot(cond_frames_tilde, c)

        # Heun integration
        x_hat = x_curr + 0.5 * delta_t * (x_dot_1 + x_dot_2)

        # Loss on frame
        return F.mse_loss(x_hat, target_frame)

    @torch.no_grad()
    def sample(self, cond_frames, cond_u_past, cond_u_curr, num_steps=50, method='heun'):
        """
        Generate a frame by integrating the predicted velocity from noise.

        ``method`` is ``'euler'`` or ``'heun'``; any other value raises
        ``ValueError``. Returns a tensor of shape (B, 1, img_size, img_size).
        """
        B = cond_frames.size(0)
        device = cond_frames.device
        H, W = self.img_size, self.img_size

        # Start from noise
        z = self.noise_scale * torch.randn(B, 1, H, W, device=device)

        # Timesteps
        timesteps = torch.linspace(0.0, 1.0, num_steps + 1, device=device)

        # Embed conditioning (constant throughout sampling); the positional
        # embedding is folded in here once.
        cond_frames_reshaped = cond_frames.squeeze(2)
        cond_tokens = self.cond_embedder(cond_frames_reshaped)
        cond_tokens = cond_tokens + self.pos_embed

        # Control embedding (constant)
        control_emb = self.control_embedder(cond_u_past, cond_u_curr)

        for i in range(num_steps):
            t_curr = timesteps[i]
            t_next = timesteps[i + 1]

            if method == 'euler':
                z = self._euler_step(z, t_curr, t_next, cond_tokens, control_emb)
            elif method == 'heun':
                z = self._heun_step(z, t_curr, t_next, cond_tokens, control_emb)
            else:
                raise ValueError(f"Unknown method: {method}")
        return z

    def _forward_sample(self, z, t, cond_tokens, control_emb):
        """Single network evaluation during sampling; returns the predicted velocity."""
        B = z.size(0)
        t_scalar = t.expand(B)

        # Embed noisy sample (cond_tokens already contain the positional embedding)
        tokens = self.x_embedder(z) + cond_tokens

        # Time embedding
        c = self.t_embedder(t_scalar) + control_emb

        # Predict x
        x_pred = self._run_backbone(tokens, c)

        # Convert to velocity
        t_broadcast = t.view(-1, *([1] * (z.ndim - 1)))
        v_pred = (x_pred - z) / (1.0 - t_broadcast).clamp_min(self.t_eps)

        return v_pred

    def _euler_step(self, z, t_curr, t_next, cond_tokens, control_emb):
        """One explicit Euler step from ``t_curr`` to ``t_next``."""
        v_pred = self._forward_sample(z, t_curr, cond_tokens, control_emb)
        z_next = z + (t_next - t_curr) * v_pred
        return z_next

    def _heun_step(self, z, t_curr, t_next, cond_tokens, control_emb):
        """One Heun step: average the velocity at the start and at the Euler estimate."""
        # First Euler step
        v_pred_curr = self._forward_sample(z, t_curr, cond_tokens, control_emb)
        z_euler = z + (t_next - t_curr) * v_pred_curr

        # Second evaluation
        v_pred_next = self._forward_sample(z_euler, t_next, cond_tokens, control_emb)

        # Heun average
        v_pred = 0.5 * (v_pred_curr + v_pred_next)
        z_next = z + (t_next - t_curr) * v_pred
        return z_next

    @torch.no_grad()
    def predict_x_dot(self, cond_frames, cond_u_past, cond_u_curr):
        """Gradient-free prediction of the time derivative of the last conditioning frame.

        Uses the same network path as the two stages of ``forward``.
        """
        c = self._control_conditioning(cond_u_past, cond_u_curr)
        return self._x_dot(cond_frames, c)

    
    
###############################
# 5) TRAINING LOOP
###############################
def train_jit_diffusion(model, train_loader, val_loader, num_epochs, device, learning_rate=1e-4, training_delta_t=0.1):
    """Train ``model`` and keep the checkpoint with the lowest validation loss.

    Args:
        model: module called as
            ``model(cond_frames, cond_u_past, cond_u_curr, target_frame, delta_t)``
            and returning a scalar loss.
        train_loader, val_loader: iterables of
            ``(cond_frames, cond_u_past, cond_u_curr, target_frame)`` batches.
        num_epochs: number of epochs; also the period of the cosine schedule.
        device: device the batches are moved to.
        learning_rate: initial AdamW learning rate.
        training_delta_t: integration step passed to the model for both the
            training and the validation loss.

    Side effects:
        Writes ``best_jit_fluid_model.pth`` whenever the validation loss
        improves, and prints one progress line per epoch.
    """
    optimizer = optim.AdamW(model.parameters(), lr=learning_rate, betas=(0.9, 0.95), weight_decay=0.0)
    scheduler = CosineAnnealingLR(optimizer, T_max=num_epochs, eta_min=1e-6)
    best_val_loss = float('inf')

    for epoch in range(num_epochs):
        model.train()
        train_loss = 0.0

        for cond_frames, cond_u_past, cond_u_curr, target_frame in tqdm(
            train_loader, desc=f"Epoch {epoch+1}/{num_epochs} [Train]"
        ):
            cond_frames = cond_frames.to(device)
            cond_u_past = cond_u_past.to(device)
            cond_u_curr = cond_u_curr.to(device)
            target_frame = target_frame.to(device)

            optimizer.zero_grad()
            loss = model(cond_frames, cond_u_past, cond_u_curr, target_frame, training_delta_t)
            loss.backward()
            optimizer.step()

            train_loss += loss.item()

        train_loss /= len(train_loader)

        # Validation
        model.eval()
        val_loss = 0.0
        with torch.no_grad():
            for cond_frames, cond_u_past, cond_u_curr, target_frame in tqdm(
                val_loader, desc=f"Epoch {epoch+1}/{num_epochs} [Val]"
            ):
                cond_frames = cond_frames.to(device)
                cond_u_past = cond_u_past.to(device)
                cond_u_curr = cond_u_curr.to(device)
                target_frame = target_frame.to(device)

                # Same call signature and step size as the training loss, so
                # the two losses are comparable.
                loss = model(cond_frames, cond_u_past, cond_u_curr, target_frame, training_delta_t)
                val_loss += loss.item()

        val_loss /= len(val_loader)
        scheduler.step()

        print(f"Epoch {epoch+1}/{num_epochs} | Train Loss: {train_loss:.6f} | Val Loss: {val_loss:.6f} | LR: {scheduler.get_last_lr()[0]:.7f}")

        if val_loss < best_val_loss:
            best_val_loss = val_loss
            torch.save(model.state_dict(), 'best_jit_fluid_model.pth')
            print(f"âœ“ New best model saved! Val Loss: {best_val_loss:.6f}")

    print("\nâœ… Training complete!")







###############################
# 6) AUTOREGRESSIVE EVALUATION (Heun / RK2)
###############################
@torch.no_grad()
def evaluate_autoregressive(
    model,
    data_case,
    past_window,
    device,
    num_frames,
    frame_mean,
    frame_std,
    delta_t=0.1
):
    """Roll the model out autoregressively with Heun (RK2) steps.

    The first ``past_window`` frames of ``data_case`` seed the history; each
    predicted frame is then appended to the history and fed back in. At most
    ``num_frames`` frames are predicted, stopping at the end of the case.

    Returns:
        ``(pred_frames_denorm, true_frames_denorm, avg_mse)`` -- the
        predictions and matching ground truth mapped back with
        ``frame_std`` / ``frame_mean``, and the mean per-frame MSE computed
        on the normalised frames. The average is also printed.
    """
    model.eval()

    frames = data_case['frames']  # (T, 1, H, W)
    u = data_case['u']
    total_steps = frames.shape[0]

    def _as_batch(array):
        # numpy -> float32 tensor on the target device with a leading batch axis
        return torch.tensor(array, dtype=torch.float32, device=device).unsqueeze(0)

    history = frames[:past_window].copy()
    predictions = []
    per_frame_mse = []

    last_index = min(past_window + num_frames, total_steps)
    for t_idx in range(past_window, last_index):
        # Conditioning: most recent frames, matching past controls, current control
        window = _as_batch(history[-past_window:])
        u_past = _as_batch(u[t_idx - past_window:t_idx].reshape(-1, 1))
        u_curr = _as_batch(np.array([u[t_idx]], dtype=np.float32))

        state = window[:, -1]  # (B=1, 1, H, W)

        # Heun stage 1: slope at the current state
        slope_1 = model.predict_x_dot(window, u_past, u_curr)
        state_euler = state + delta_t * slope_1

        # Heun stage 2: slope at the Euler estimate, with the window slid forward
        window_euler = torch.cat([window[:, 1:], state_euler.unsqueeze(1)], dim=1)
        slope_2 = model.predict_x_dot(window_euler, u_past, u_curr)

        next_state = state + 0.5 * delta_t * (slope_1 + slope_2)
        next_frame = next_state.squeeze(0).cpu().numpy()
        predictions.append(next_frame)

        # Error against the ground-truth frame (normalised space)
        per_frame_mse.append(np.mean((next_frame - frames[t_idx]) ** 2))

        # Feed the prediction back into the history
        history = np.concatenate([history, next_frame[np.newaxis]], axis=0)[-past_window:]

    # Stack and denormalise
    predictions = np.array(predictions)
    ground_truth = frames[past_window:past_window + len(predictions)]

    pred_frames_denorm = predictions * frame_std + frame_mean
    true_frames_denorm = ground_truth * frame_std + frame_mean

    avg_mse = float(np.mean(per_frame_mse))
    print(f"Average MSE for amplitude {data_case['amplitude']}: {avg_mse:.6f}")

    return pred_frames_denorm, true_frames_denorm, avg_mse



###############################
# 7) VISUALIZATION
###############################
def create_comparison_video(gt_frames, pred_frames, amplitude, save_path="comparison_video.mp4"):
    """Render ground truth and prediction side by side and save them as a video.

    Each animation frame draws ``gt_frames[i]`` and ``pred_frames[i]`` as
    heatmaps sharing one color scale, so the two panels are directly
    comparable.

    Args:
        gt_frames: Array of ground-truth frames indexed along axis 0. Must
            expose ``.shape``, ``.min()`` and ``.max()``; each ``gt_frames[i]``
            is handed to ``sns.heatmap``.
        pred_frames: Array of predicted frames, indexed the same way. It is
            indexed with the same frame numbers as ``gt_frames``, so it needs
            at least ``gt_frames.shape[0]`` frames.
        amplitude: Value shown in the figure title.
        save_path: Output file path given to the ffmpeg writer.

    Returns:
        ``save_path``, unchanged.
    """
    T = gt_frames.shape[0]
    fig, axes = plt.subplots(1, 2, figsize=(12, 6))

    # Close the figure even when rendering or encoding fails (e.g. ffmpeg is
    # not installed); otherwise it stays registered with pyplot and leaks.
    try:
        fig.suptitle(f"Amplitude: {amplitude}")
        ax1, ax2 = axes
        ax1.set_title("Ground Truth")
        ax2.set_title("Predicted (JiT)")

        # Shared color limits across both sequences and all frames.
        vmin = min(gt_frames.min(), pred_frames.min())
        vmax = max(gt_frames.max(), pred_frames.max())

        def update(frame):
            # Heatmaps are redrawn from scratch; clear() also drops the
            # titles, so they are set again below.
            ax1.clear()
            ax2.clear()
            sns.heatmap(gt_frames[frame], cmap="magma", vmin=vmin, vmax=vmax,
                        center=0, ax=ax1, cbar=False, square=True)
            sns.heatmap(pred_frames[frame], cmap="magma", vmin=vmin, vmax=vmax,
                        center=0, ax=ax2, cbar=False, square=True)
            ax1.set_title(f"Ground Truth (Frame {frame+1}/{T})")
            ax2.set_title(f"Predicted JiT (Frame {frame+1}/{T})")

        ani = animation.FuncAnimation(fig, update, frames=T, interval=200)
        ani.save(save_path, writer="ffmpeg", fps=5, dpi=200)
    finally:
        plt.close(fig)

    print(f"Video saved: {save_path}")
    return save_path


###############################
# 8) MAIN
###############################
if __name__ == "__main__":
    # End-to-end script: load and normalize the dataset, build training
    # sequences, train the JiT model, then roll it out autoregressively on
    # every test case and write one comparison video per case.

    device = "cuda" if torch.cuda.is_available() else "cpu"
    print(f"🚀 Using device: {device}")

    # =============================
    # PATH
    # =============================
    file_path = "/kaggle/input/oscillating-cylinder-benchmark-dataset-v2/oscillating_cylinder_benchmark_dataset_v2.mat"

    if not os.path.exists(file_path):
        raise FileNotFoundError(f"Dataset not found at {file_path}")

    print(f"✓ File found: {file_path}")

    # =============================
    # LOAD DATA
    # =============================
    print("\n=== Loading data ===")
    train_list, val_list, test_list = split_data(file_path)
    print(f"Train: {len(train_list)}, Val: {len(val_list)}, Test: {len(test_list)}")

    # =============================
    # NORMALIZATION
    # =============================
    # Statistics come from the training split only and are then applied to
    # all three splits, so validation/test data never influence the scaling.
    print("\n=== Normalization ===")
    frame_mean, frame_std, u_mean, u_std = compute_normalization_stats(train_list)
    print(f"Frame: μ={frame_mean:.6f}, σ={frame_std:.6f}")
    print(f"Control u: μ={u_mean:.6f}, σ={u_std:.6f}")

    # Return values are not used here; presumably these calls normalize the
    # lists in place -- TODO confirm against normalize_data_list.
    normalize_data_list(train_list, frame_mean, frame_std, u_mean, u_std)
    normalize_data_list(val_list, frame_mean, frame_std, u_mean, u_std)
    normalize_data_list(test_list, frame_mean, frame_std, u_mean, u_std)

    # =============================
    # SEQUENCES
    # =============================
    print("\n=== Creating sequences ===")
    past_window = 10
    X_train_frames, X_train_u_past, X_train_u_curr, Y_train = create_jit_sequences(train_list, past_window)
    X_val_frames, X_val_u_past, X_val_u_curr, Y_val = create_jit_sequences(val_list, past_window)

    print(f"Train sequences: {X_train_frames.shape[0]}")
    print(f"Val sequences: {X_val_frames.shape[0]}")

    train_dataset = FluidDynamicsDataset(X_train_frames, X_train_u_past, X_train_u_curr, Y_train)
    val_dataset = FluidDynamicsDataset(X_val_frames, X_val_u_past, X_val_u_curr, Y_val)

    train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True, num_workers=2, pin_memory=True)
    val_loader = DataLoader(val_dataset, batch_size=32, shuffle=False, num_workers=2, pin_memory=True)

    # =============================
    # MODEL
    # =============================
    print("\n=== Creating JiT model ===")
    model = JiTFluidDiffusion(
        img_size=256,
        patch_size=16,
        in_channels=1,
        past_window=past_window,
        hidden_size=384,
        depth=12,
        num_heads=6,
        mlp_ratio=4.0,
        bottleneck_dim=64
    ).to(device)

    total_params = sum(p.numel() for p in model.parameters())
    print(f"Total parameters: {total_params:,}")

    # =============================
    # TRAIN
    # =============================
    print("\n=== Training ===")
    NUM_EPOCHS = 20
    # NOTE(review): DELTA_T_TRAIN is defined but not passed to
    # train_jit_diffusion below -- confirm whether training should use it.
    DELTA_T_TRAIN = 1.0

    train_jit_diffusion(
        model,
        train_loader,
        val_loader,
        num_epochs=NUM_EPOCHS,
        device=device,
        learning_rate=2e-4
    )

    # =============================
    # EVALUATION
    # =============================
    print("\n=== Evaluation on test set ===")
    # Reload the checkpoint from disk; presumably written by
    # train_jit_diffusion for its best epoch -- TODO confirm.
    model.load_state_dict(torch.load("best_jit_fluid_model.pth", map_location=device))

    DELTA_T_EVAL = 1.0
    total_mse = 0.0

    for i, test_case in enumerate(test_list):
        print(f"\n📊 Test case {i+1}/{len(test_list)} (amplitude: {test_case['amplitude']})")

        pred_frames, true_frames, mse = evaluate_autoregressive(
            model=model,
            data_case=test_case,
            past_window=past_window,
            device=device,
            num_frames=20,
            frame_mean=frame_mean,
            frame_std=frame_std,
            delta_t=DELTA_T_EVAL
        )

        total_mse += mse

        # Drop the singleton axis 1 so each frame is 2-D for the heatmaps.
        gt_seq = true_frames.squeeze(1)
        pred_seq = pred_frames.squeeze(1)

        video_path = f"jit_comparison_amp_{test_case['amplitude']}.mp4"
        create_comparison_video(gt_seq, pred_seq, test_case['amplitude'], video_path)

    # An empty test split would otherwise end the run with ZeroDivisionError.
    if test_list:
        avg_test_mse = total_mse / len(test_list)
        print("\n" + "=" * 60)
        print(f"✅ Average test MSE (normalized): {avg_test_mse:.6f}")
        print("=" * 60)
    else:
        print("\nNo test cases available; skipping average test MSE.")