In [1]:
from torch.utils.data import DataLoader, SequentialSampler, RandomSampler, SubsetRandomSampler
import copy

import os
import math
import numpy as np
import scipy.ndimage
import matplotlib.pyplot as plt
import matplotlib.animation as animation
import seaborn as sns
from tqdm import tqdm
from glob import glob
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torch.optim.lr_scheduler import CosineAnnealingLR

In [3]:
class DataParams(object):
    """Settings of the data pipeline: batching, augmentations, sequence layout and loaded fields.

    List-valued arguments default to ``None`` and are replaced by a fresh list per instance, so
    mutating one instance's list (e.g. ``p.augmentations.append(...)``) can no longer leak into
    other instances through a shared default object. Explicitly passed lists are stored as given.
    """
    def __init__(self, batch=4, augmentations=None, sequenceLength=None, randSeqOffset=False,
                dataSize=None, dimension=2, simFields=None, simParams=None, normalizeMode=""):
        self.batch          = batch                                                  # batch size
        self.augmentations  = [] if augmentations is None else augmentations        # used data augmentations
        self.sequenceLength = [1, 1] if sequenceLength is None else sequenceLength  # [frames in one sequence, frame skip]
        self.randSeqOffset  = randSeqOffset                                          # randomize sequence starting frame
        self.dataSize       = [128, 64] if dataSize is None else dataSize           # target data size for scale/crop/cropRandom transformation
        self.dimension      = dimension                                              # number of data dimension
        self.simFields      = [] if simFields is None else simFields                # which simulation fields are added (vel is always used) from ["dens", "pres"]
        self.simParams      = [] if simParams is None else simParams                # which simulation parameters are added from ["rey", "mach"]
        self.normalizeMode  = normalizeMode                                          # which mean and std values from different data sets are used in normalization transformation

    @classmethod
    def fromDict(cls, d:dict):
        """Build an instance from a dict; missing keys fall back to sentinel values (-1 / empty)."""
        p = cls()
        p.batch          = d.get("batch",            -1)
        p.augmentations  = d.get("augmentations",    [])
        p.sequenceLength = d.get("sequenceLength",   [])
        p.randSeqOffset  = d.get("randSeqOffset",    False)
        p.dataSize       = d.get("dataSize",         -1)
        p.dimension      = d.get("dimension",        -1)
        p.simFields      = d.get("simFields",        [])
        p.simParams      = d.get("simParams",        [])
        p.normalizeMode  = d.get("normalizeMode",    "")
        return p

    def asDict(self) -> dict:
        """Serialize all settings into a plain dict (inverse of fromDict)."""
        return {
            "batch"          : self.batch,
            "augmentations"  : self.augmentations,
            "sequenceLength" : self.sequenceLength,
            "randSeqOffset"  : self.randSeqOffset,
            "dataSize"       : self.dataSize,
            "dimension"      : self.dimension,
            "simFields"      : self.simFields,
            "simParams"      : self.simParams,
            "normalizeMode"  : self.normalizeMode,
        }



class TrainingParams(object):
    """Optimization schedule settings.

    The list-valued fade-in arguments default to ``None`` and are replaced by a fresh
    ``[-1, 0]`` per instance, so instances never share (and mutate) one default list object.
    """
    def __init__(self, epochs=20, lr=0.0001, expLrGamma=1.0, weightDecay=0.0, fadeInPredLoss=None, fadeInSeqLen=None, fadeInSeqLenLin=False):
        self.epochs            = epochs                                                   # number of training epochs
        self.lr                = lr                                                       # learning rate
        self.expLrGamma        = expLrGamma                                               # factor for exponential learning rate decay
        self.weightDecay       = weightDecay                                              # weight decay factor to regularize the net by penalizing large weights
        self.fadeInPredLoss    = [-1, 0] if fadeInPredLoss is None else fadeInPredLoss    # start and end epoch of fading in the prediction loss terms
        self.fadeInSeqLen      = [-1, 0] if fadeInSeqLen is None else fadeInSeqLen        # start and end epoch of fading in the sequence length
        self.fadeInSeqLenLin   = fadeInSeqLenLin                                          # exponential or linear scaling of fading in the sequence length

    @classmethod
    def fromDict(cls, d:dict):
        """Build an instance from a dict; missing keys fall back to sentinel values."""
        p = cls()
        p.epochs            = d.get("epochs",           -1)
        p.lr                = d.get("lr",               -1)
        p.expLrGamma        = d.get("expLrGamma",        1)
        p.weightDecay       = d.get("weightDecay",      -1)
        p.fadeInPredLoss    = d.get("fadeInPredLoss",   [])
        p.fadeInSeqLen      = d.get("fadeInSeqLen",     [])
        p.fadeInSeqLenLin   = d.get("fadeInSeqLenLin",  False)
        return p

    def asDict(self) -> dict:
        """Serialize all settings into a plain dict (inverse of fromDict)."""
        return {
            "epochs"            : self.epochs,
            "lr"                : self.lr,
            "expLrGamma"        : self.expLrGamma,
            "weightDecay"       : self.weightDecay,
            "fadeInPredLoss"    : self.fadeInPredLoss,
            "fadeInSeqLen"      : self.fadeInSeqLen,
            "fadeInSeqLenLin"   : self.fadeInSeqLenLin,
        }



class LossParams(object):
    """Weights of the individual loss and regularization terms.

    ``extraMSEvelZ`` used to be accepted by the constructor but silently discarded; it is now
    stored and round-tripped through fromDict/asDict like every other weight.
    """
    def __init__(self, recMSE=1.0, recLSIM=0, predMSE=1.0, predLSIM=0, extraMSEvelZ=0, regMeanStd=0, regDiv=0, regVae=0, regLatStep=0):
        self.recMSE       = recMSE       # mse loss reconstruction weight
        self.recLSIM      = recLSIM      # lsim loss reconstruction weight
        self.predMSE      = predMSE      # mse loss prediction weight
        self.predLSIM     = predLSIM     # lsim loss prediction weight
        self.extraMSEvelZ = extraMSEvelZ # presumably weight of an extra mse term on the z-velocity channel (consumer not visible here)
        self.regMeanStd   = regMeanStd   # mean and standard deviation regularization weight
        self.regDiv       = regDiv       # divergence regularization weight
        self.regVae       = regVae       # regularization weight for VAE KL divergence
        self.regLatStep   = regLatStep   # latent space step regularization weight

    @classmethod
    def fromDict(cls, d:dict):
        """Build an instance from a dict; missing weights fall back to the -1 sentinel."""
        p = cls()
        p.recMSE       = d.get("recMSE", -1)
        p.recLSIM      = d.get("recLSIM", -1)
        p.predMSE      = d.get("predMSE", -1)
        p.predLSIM     = d.get("predLSIM", -1)
        # dicts serialized before this field was stored never contain the key, so its
        # absence means "term disabled" (0) rather than the -1 sentinel
        p.extraMSEvelZ = d.get("extraMSEvelZ", 0)
        p.regMeanStd   = d.get("regMeanStd", -1)
        p.regDiv       = d.get("regDiv", -1)
        p.regVae       = d.get("regVae", -1)
        p.regLatStep   = d.get("regLatStep", -1)
        return p

    def asDict(self) -> dict:
        """Serialize all weights into a plain dict (inverse of fromDict)."""
        return {
            "recMSE"       : self.recMSE,
            "recLSIM"      : self.recLSIM,
            "predMSE"      : self.predMSE,
            "predLSIM"     : self.predLSIM,
            "extraMSEvelZ" : self.extraMSEvelZ,
            "regMeanStd"   : self.regMeanStd,
            "regDiv"       : self.regDiv,
            "regVae"       : self.regVae,
            "regLatStep"   : self.regLatStep,
        }



class ModelParamsEncoder(object):
    """Hyperparameters of the encoder network.

    Attributes:
        arch:       architecture variant
        pretrained: load pretrained weight initialization
        frozen:     freeze weights after initialization
        encWidth:   width of encoder network
        latentSize: size of latent space vector
    """
    def __init__(self, arch="skip", pretrained=False, frozen=False, encWidth=16, latentSize=16):
        self.arch, self.pretrained, self.frozen = arch, pretrained, frozen
        self.encWidth, self.latentSize = encWidth, latentSize

    @classmethod
    def fromDict(cls, d:dict):
        """Create an instance from a dict, using sentinel fallbacks for absent keys."""
        fallbacks = (
            ("arch",       ""),
            ("pretrained", False),
            ("frozen",     False),
            ("encWidth",   -1),
            ("latentSize", -1),
        )
        params = cls()
        for key, fallback in fallbacks:
            setattr(params, key, d.get(key, fallback))
        return params

    def asDict(self) -> dict:
        """Return the hyperparameters as a plain dict in declaration order."""
        keys = ("arch", "pretrained", "frozen", "encWidth", "latentSize")
        return {key: getattr(self, key) for key in keys}



class ModelParamsDecoder(object):
    """Hyperparameters of the decoder network (including diffusion / FNO / refiner options).

    Attributes:
        arch:                architecture variant
        pretrained:          load pretrained weight initialization
        frozen:              freeze weights after initialization
        decWidth:            width of decoder network
        vae:                 use a variational AE setup
        trainingNoise:       amount of training noise added to inputs
        diffSteps:           diffusion model diffusion time steps
        diffSchedule:        diffusion model variance schedule
        diffCondIntegration: integration of conditioning during diffusion training
        fnoModes:            number of fourier modes for FNO setup
        refinerStd:          noise standard dev. in pde refiner setup
    """
    def __init__(self, arch="skip", pretrained=False, frozen=False, decWidth=48, vae=False, trainingNoise=0.0,
                 diffSteps=500, diffSchedule="linear", diffCondIntegration="noisy", fnoModes=(16,16), refinerStd=0.0):
        self.arch, self.pretrained, self.frozen = arch, pretrained, frozen
        self.decWidth, self.vae, self.trainingNoise = decWidth, vae, trainingNoise
        self.diffSteps, self.diffSchedule, self.diffCondIntegration = diffSteps, diffSchedule, diffCondIntegration
        self.fnoModes, self.refinerStd = fnoModes, refinerStd

    @classmethod
    def fromDict(cls, d:dict):
        """Create an instance from a dict, using fallback values for absent keys."""
        fallbacks = (
            ("arch",                ""),
            ("pretrained",          False),
            ("frozen",              False),
            ("decWidth",            -1),
            ("vae",                 False),
            ("trainingNoise",       0.0),
            ("diffSteps",           500),
            ("diffSchedule",        "linear"),
            ("diffCondIntegration", "noisy"),
            ("fnoModes",            ()),
            ("refinerStd",          0.0),
        )
        params = cls()
        for key, fallback in fallbacks:
            setattr(params, key, d.get(key, fallback))
        return params

    def asDict(self) -> dict:
        """Return the hyperparameters as a plain dict in declaration order."""
        keys = ("arch", "pretrained", "frozen", "decWidth", "vae", "trainingNoise",
                "diffSteps", "diffSchedule", "diffCondIntegration", "fnoModes", "refinerStd")
        return {key: getattr(self, key) for key in keys}



class ModelParamsLatent(object):
    """Hyperparameters of the latent (time-stepping) network.

    Attributes:
        arch:             architecture variant
        pretrained:       load pretrained weight initialization
        frozen:           freeze weights after initialization
        width:            latent network width
        layers:           number of latent network layers
        heads:            number of attention heads in transformer
        dropout:          dropout rate in latent network
        transTrainUnroll: unrolled training for transformer latent models, FALSE for one step
                          predictions TRUE for full rollouts
        transTargetFull:  full target data for transformer and transformer decoder latent models,
                          FALSE for only the previous step as a target TRUE for every previous step
        maxInputLen:      how many steps of the input sequence are processed at once for models
                          that predict full sequences (-1 for no limit)
    """
    def __init__(self, arch="fc", pretrained=False, frozen=False, width=512, layers=6, heads=4, dropout=0.0,
               transTrainUnroll=False, transTargetFull=False, maxInputLen=-1):
        self.arch, self.pretrained, self.frozen = arch, pretrained, frozen
        self.width, self.layers, self.heads, self.dropout = width, layers, heads, dropout
        self.transTrainUnroll, self.transTargetFull = transTrainUnroll, transTargetFull
        self.maxInputLen = maxInputLen

    @classmethod
    def fromDict(cls, d:dict):
        """Create an instance from a dict, using fallback values for absent keys."""
        # NOTE(review): the numeric fields fall back to "" here, whereas the sibling
        # parameter classes use -1 as their "missing" sentinel
        fallbacks = (
            ("arch",             ""),
            ("pretrained",       False),
            ("frozen",           False),
            ("width",            ""),
            ("layers",           ""),
            ("heads",            ""),
            ("dropout",          ""),
            ("transTrainUnroll", False),
            ("transTargetFull",  False),
            ("maxInputLen",      -1),
        )
        params = cls()
        for key, fallback in fallbacks:
            setattr(params, key, d.get(key, fallback))
        return params

    def asDict(self) -> dict:
        """Return the hyperparameters as a plain dict in declaration order."""
        keys = ("arch", "pretrained", "frozen", "width", "layers", "heads", "dropout",
                "transTrainUnroll", "transTargetFull", "maxInputLen")
        return {key: getattr(self, key) for key in keys}



import torch
import torch.nn.functional as F
import numpy as np


class Transforms(object):
    """Per-sample preprocessing applied by TurbulenceDataset.__getitem__.

    Which steps run is selected by ``p_d.augmentations`` (any of "normalize", "flip", "crop",
    "resize"; "crop" and "resize" are mutually exclusive). Steps run in the order
    normalize -> flip -> crop -> conversion to torch tensor -> resize. Flip and crop are only
    implemented for 2D data (``p_d.dimension == 2``) and raise NotImplementedError for 3D.
    When an obstacle mask is present and its trailing two axes match the data's spatial axes,
    it receives the same flip/crop as the data so both stay aligned.

    Raises:
        AssertionError: on unknown augmentations, crop+resize together, or an empty normalizeMode.
        ValueError: if "normalize" is requested but no dataset statistics exist for the given
            normalizeMode / dimension (previously this only surfaced later as an AttributeError).
    """
    p_d: "DataParams"

    def __init__(self, p_d:"DataParams"):

        assert all(aug in ["normalize", "flip", "crop", "resize"]
                        for aug in p_d.augmentations), "Invalid augmentation provided!"
        assert not ("crop" in p_d.augmentations and "resize" in p_d.augmentations
                        ), "Crop and resize augmentation not allowed at the same time!"
        assert (p_d.normalizeMode != ""), "Invalid normalization mode!"

        self.p_d = p_d
        self.normalize = "normalize" in p_d.augmentations
        self.flip = "flip" in p_d.augmentations
        self.crop = "crop" in p_d.augmentations
        self.resize = "resize" in p_d.augmentations
        self.outputSize = p_d.dataSize
        self.dim = p_d.dimension
        self.simFields = p_d.simFields
        self.simParams = p_d.simParams

        # mean and std statistics from whole dataset for normalization
        self.normMean = None
        self.normStd = None
        if self.dim == 2:
            l = self.p_d.normalizeMode.lower()
            if ("inc" in l and "mixed" in l) or ("karman" in l and "mixed" in l):
                # ORDER (fields): velocity (x,y), --, pressure, ORDER (params): rey, --, --
                self.normMean = np.array([0.444969, 0.000299, 0, 0.000586, 550.000000, 0, 0], dtype=np.float32)
                self.normStd =  np.array([0.206128, 0.206128, 1, 0.003942, 262.678467, 1, 1], dtype=np.float32)

            if ("tra" in l and "mixed" in l) or ("mach" in l and "mixed" in l):
                # ORDER (fields): velocity (x,y), density, pressure, ORDER (params): rey, mach, --
                self.normMean = np.array([0.560642, -0.000129, 0.903352, 0.637941, 10000.000000, 0.700000, 0], dtype=np.float32)
                self.normStd =  np.array([0.216987, 0.216987, 0.145391, 0.119944, 1, 0.118322, 1], dtype=np.float32)

            if "iso" in l and "single" in l:
                # ORDER (fields): velocity (x,y,z), pressure, ORDER (params): --, --, --
                self.normMean = np.array([-0.054618, -0.385225, -0.255757, 0.033446, 0, 0, 0], dtype=np.float32)
                self.normStd =  np.array([0.539194, 0.710318, 0.510352, 0.258235, 1, 1, 1], dtype=np.float32)

        if self.normalize and self.normMean is None:
            raise ValueError("No normalization statistics for normalizeMode '%s' with dimension %s!"
                             % (p_d.normalizeMode, self.dim))

        # seeding once for single-process data loading
        self.randGen = np.random.RandomState(torch.random.initial_seed() % 4294967295)
        # DataLoader worker seed that randGen was last seeded from (None = main process)
        self._workerSeed = None


    def _reseedForWorker(self):
        """Seed randGen once per DataLoader worker.

        The worker seed is fixed for a worker's lifetime, so reseeding on every call (as before)
        restarted the same random stream for every sample and made all flips/crops within a
        worker identical. Each worker holds its own copy of this object, so remembering the
        seed it was seeded from is enough.
        """
        workerInfo = torch.utils.data.get_worker_info()
        if workerInfo is not None and workerInfo.seed != getattr(self, "_workerSeed", None):
            self.randGen = np.random.RandomState(workerInfo.seed % 4294967295)
            self._workerSeed = workerInfo.seed


    @staticmethod
    def _maskAligned(data, obsMask) -> bool:
        """True if obsMask exists and its trailing two axes match data's trailing two (spatial) axes."""
        return (obsMask is not None and obsMask.ndim >= 2
                and tuple(obsMask.shape[-2:]) == tuple(data.shape[-2:]))


    def _normalize(self, data, simParameters):
        """Standardize data channels and simulation parameters with the dataset-wide mean/std."""
        # Indices into normMean/normStd. The statistics are laid out as
        # (vel x, vel y, velZ/dens, pres, rey, mach, zslice); 3D data carries one more velocity
        # channel, which shifts every optional entry by one.
        shift = 0 if self.dim == 2 else 1
        filterList = [0, 1] if self.dim == 2 else [0, 1, 2]
        optionalEntries = [
            ("dens" in self.simFields or "velZ" in self.simFields, 2),
            ("pres" in self.simFields, 3),
            ("rey" in self.simParams, 4),
            ("mach" in self.simParams, 5),
            ("zslice" in self.simParams, 6),
        ]
        filterList += [index + shift for enabled, index in optionalEntries if enabled]
        filterArr = np.array(filterList)

        if self.simParams:
            # the parameter entries are the trailing part of filterArr
            filterArrParam = filterArr[-len(self.simParams):]
            meanParam = self.normMean[filterArrParam].reshape((1, -1))
            stdParam = self.normStd[filterArrParam].reshape((1, -1))
            simParameters = (simParameters - meanParam) / stdParam

        # broadcast over the leading (frame) axis and the spatial axes
        broadcastShape = (1, -1) + (1,) * self.dim
        meanData = self.normMean[filterArr].reshape(broadcastShape)
        stdData = self.normStd[filterArr].reshape(broadcastShape)
        data = (data - meanData) / stdData
        return data, simParameters


    def _flip(self, data, obsMask):
        """Randomly mirror data (and an aligned obsMask) along each of the two spatial axes."""
        if self.dim == 2:
            flipAxes = tuple(axis for axis, doFlip in zip((2, 3), self.randGen.rand(2) > 0.5) if doFlip)
            if flipAxes:
                maskAligned = self._maskAligned(data, obsMask)
                # copy to prevent negative strides, which torch.from_numpy cannot handle
                data = np.flip(data, axis=flipAxes).copy()
                if maskAligned:
                    obsMask = np.flip(obsMask, axis=tuple(axis - data.ndim for axis in flipAxes)).copy()
        if self.dim == 3:
            raise NotImplementedError("Flip augmentation not supported for 3D yet!")
        return data, obsMask


    def _crop(self, data, obsMask):
        """Randomly crop data (and an aligned obsMask) to outputSize on the spatial axes."""
        if self.dim == 2:
            s = self.outputSize
            if data.shape[2] > s[0] or data.shape[3] > s[1]:
                maskAligned = self._maskAligned(data, obsMask)
                # clamp at 0: an axis already smaller than the target is kept whole instead of
                # making randint fail with low >= high
                c1 = self.randGen.randint(0, max(data.shape[2] - s[0], 0) + 1)
                c2 = self.randGen.randint(0, max(data.shape[3] - s[1], 0) + 1)
                data = data[..., c1:c1+s[0], c2:c2+s[1]]
                if maskAligned:
                    obsMask = obsMask[..., c1:c1+s[0], c2:c2+s[1]]
        if self.dim == 3:
            raise NotImplementedError("Crop augmentation not supported for 3D yet!")
        return data, obsMask


    def __call__(self, sample:dict):
        """Apply the configured steps to one sample dict and return a new sample dict.

        Expects the keys "data", "simParameters", "allParameters", "path" and optionally
        "obsMask"; "data" (and "obsMask" if present) are returned as torch tensors.
        """
        self._reseedForWorker()

        data = sample["data"]
        simParameters = sample["simParameters"]
        allParameters = sample["allParameters"]
        obsMask = sample.get("obsMask", None)
        path = sample["path"]

        # ORDER (fields): velocity (x,y), velocity z / density, pressure, ORDER (params): rey, mach, zslice
        if self.normalize:
            data, simParameters = self._normalize(data, simParameters)
        if self.flip:
            data, obsMask = self._flip(data, obsMask)
        if self.crop:
            data, obsMask = self._crop(data, obsMask)

        # toTensor
        result = torch.from_numpy(data)
        if obsMask is not None:
            obsMask = torch.from_numpy(obsMask)

        # resize
        if self.resize:
            result = F.interpolate(result, self.outputSize, mode="bilinear", align_corners=True)
            if obsMask is not None:
                # align_corners may only be passed for interpolating modes; together with
                # mode="nearest" F.interpolate raises a ValueError
                obsMask = F.interpolate(obsMask, self.outputSize, mode="nearest")

        outDict = {"data": result, "simParameters": simParameters, "allParameters": allParameters, "path": path}
        if obsMask is not None:
            outDict["obsMask"] = obsMask
        return outDict
    
    
import torch
from torch.utils.data import Dataset

import numpy as np
import os, json
import logging
from typing import List,Tuple

# from turbpred.data_transformations import Transforms


class TurbulenceDataset(Dataset):
    """Data set for turbulence and wavelet noise data

    Args:
        name: name of the dataset
        dataDirs: list of paths to data directories
        filterTop: filter for top level folder names (e.g. different types of data)
        excludeFilterTop: mode for filterTop (exclude or include)
        filterSim: filter simulations by min and max (min inclusive, max exclusive) given as a tuple,
            or by an explicit list of simulation numbers
        excludefilterSim: mode for filterSim (exclude or include)
        filterFrame: mandatory filter for simulation frames by min and max (min inclusive, max exclusive)
        sequenceLength: number of frames to group into a sequence and number of frames to omit in between
        randSeqOffset: randomizes the starting frame of each sequence
        simFields: list of simulation fields to include (vel is always included) ["dens", "pres"]
        simParams: list of simulation parameters to include ["rey", "mach", "zslice"]
        printLevel: print mode for contents of the dataset ["none", "top", "sim", "full"]
        logLevel: log mode for contents of the dataset ["none", "top", "sim", "full"]

    Each item is a dict with "data" (frames stacked along axis 0; fields followed by the spatially
    expanded simulation parameters concatenated along axis 1), "simParameters", "allParameters",
    "path" and, if the simulation folder contains obstacle_mask.npz, "obsMask".
    """
    transform: "Transforms"
    name:str
    dataDirs:List[str]
    filterTop:List[str]
    excludeFilterTop:bool
    filterSim:List[Tuple[int, int]]
    excludefilterSim:bool
    filterFrame:List[Tuple[int, int]]
    sequenceLength:List[Tuple[int, int]]
    randSeqOffset:bool
    simFields:List[str]
    simParams:List[str]
    printLevel:str="none"
    logLevel:str="sim"

    def __init__(self, name:str, dataDirs:List[str], filterTop:List[str], excludeFilterTop:bool=False, filterSim:List[Tuple[int, int]]=[],
                excludefilterSim:bool=False, filterFrame:List[Tuple[int, int]]=[], sequenceLength:List[Tuple[int, int]]=[],
                randSeqOffset:bool=False, simFields:List[str]=[], simParams:List[str]=[], printLevel:str="none", logLevel:str="sim"):

        assert (len(filterSim) in [0,1,len(filterTop)]), "Sim filter is not set up correctly. Use len=0 for all; len=1 for the same everywhere, len=len(filterTop) to adjust for each top filter"
        assert (len(filterFrame) in [1,len(filterTop)]), "Frame filter is not set up correctly. Use len=1 for the same everywhere, len=len(filterTop) to adjust for each top filter"
        assert (len(sequenceLength) == len(filterFrame)), "Sequence length is not set up correctly, it should match the frame filter."
        if excludeFilterTop:
            assert (len(filterSim) <= 1), "Excluded top filter and adjust sim filtering is not supported!"
            assert (len(filterFrame) <= 1), "Excluded top filter and adjust frame filtering is not supported!"
        assert (printLevel in ["none", "top", "sim", "full"]), "Invalid print level!"

        self.transform = None
        self.name = name
        self.dataDirs = dataDirs
        self.filterTop = filterTop
        self.excludeFilterTop = excludeFilterTop
        self.filterSim = filterSim
        self.excludefilterSim = excludefilterSim
        self.filterFrame = filterFrame
        self.sequenceLength = sequenceLength
        self.randSeqOffset = randSeqOffset
        self.simFields = ["velocity"]
        if "velZ" in simFields:
            self.simFields += ["velocityZ"]
        if "dens" in simFields:
            self.simFields += ["density"]
        if "pres" in simFields:
            self.simFields += ["pressure"]

        self.simParams = simParams
        self.printLevel = printLevel
        self.logLevel = logLevel

        header = "Dataset " + name + " at " + str(dataDirs)
        filterInfo = self.getFilterInfoString()
        self.summaryPrint = [header, filterInfo]
        self.summaryLog   = [header, filterInfo]

        # BUILD FULL FILE LIST
        # each entry: (simulation folder, first frame, end frame (exclusive), frame skip)
        self.dataPaths = []
        self.dataPathModes = []  # NOTE(review): never populated in this file; kept for compatibility
        fieldNames = "-".join(self.simFields)

        for dataDir in dataDirs:
            # top level folders
            for topDir in sorted(os.listdir(dataDir)):
                # continue when excluding or including according to filter
                if filterTop and excludeFilterTop == any(item in topDir for item in filterTop):
                    continue
                simDir = os.path.join(dataDir, topDir)
                if not os.path.isdir(simDir):
                    continue  # stray files next to the data folders

                match = self._matchTopFilter(topDir)
                minFrame, maxFrame = self._frameRange(match)
                seqConfig = sequenceLength[0] if len(sequenceLength) == 1 else sequenceLength[match]
                seqLength, seqSkip = seqConfig[0], seqConfig[1]
                seqSpan = seqLength * seqSkip

                self._addSummary("top", "Top folder loaded: " + simDir.replace(dataDir + "/", ""))

                # sim_000001 folders
                for sim in sorted(os.listdir(simDir)):
                    currentDir = os.path.join(simDir, sim)
                    if not os.path.isdir(currentDir):
                        continue
                    # continue when excluding or including according to filter
                    if len(filterSim) > 0 and self._simInsideFilter(sim, match) == excludefilterSim:
                        continue

                    self._addSummary("sim", "Sim loaded: " + currentDir.replace(dataDir + "/", ""))

                    # individual simulation frames
                    for seqStart in range(minFrame, maxFrame, seqSpan):
                        # an incomplete sequence at the simulation end means there are no more frames left
                        if seqStart + seqSpan > maxFrame:
                            break

                        for frame in range(seqStart, seqStart + seqSpan, seqSkip):
                            for field in self.simFields:
                                currentField = os.path.join(currentDir, "%s_%06d.npz" % (field, frame))
                                if not os.path.isfile(currentField):
                                    raise FileNotFoundError("Could not load %s file: %s" % (field, currentField))

                        # the last frame of the sequence is seqStart + (seqLength-1)*seqSkip
                        self._addSummary("full", "Frames %s loaded: %s/%s_%06d-%06d(%03d).npz" % (fieldNames,
                                    currentDir.replace(dataDir + "/", ""), fieldNames, seqStart, seqStart + seqSpan - seqSkip, seqSkip))

                        self.dataPaths.append((currentDir, seqStart, seqStart + seqSpan, seqSkip))

        self.summaryPrint += ["Dataset Length: %d\n" % len(self.dataPaths)]
        self.summaryLog   += ["Dataset Length: %d\n" % len(self.dataPaths)]


    def _addSummary(self, level:str, line:str):
        """Append line to the print/log summaries whose configured level equals level."""
        if self.printLevel == level:
            self.summaryPrint.append(line)
        if self.logLevel == level:
            self.summaryLog.append(line)


    def _matchTopFilter(self, topDir:str) -> int:
        """Index of the first filterTop entry contained in topDir; -1 when no per-top-filter settings are used."""
        if len(self.filterSim) > 1 or len(self.filterFrame) > 1:
            match = next((i for i, item in enumerate(self.filterTop) if item in topDir), -1)
            assert (match >= 0), "Match computation error"
            return match
        return -1


    def _frameRange(self, match:int) -> Tuple[int, int]:
        """(min inclusive, max exclusive) frame range that applies to the given top filter match."""
        frameFilter = self.filterFrame[0] if len(self.filterFrame) == 1 else self.filterFrame[match]
        return frameFilter[0], frameFilter[1]


    def _simInsideFilter(self, sim:str, match:int) -> bool:
        """Whether the number of a sim_XXXXXX folder lies inside the applicable sim filter.

        Raises:
            ValueError: if the filter entry is neither a (min, max) tuple nor a list of numbers.
                Previously such an entry left the decision undefined (or stale from the
                previous simulation).
        """
        simNum = int(sim.split("_")[1])
        simFilter = self.filterSim[0] if len(self.filterSim) == 1 else self.filterSim[match]
        if isinstance(simFilter, tuple):
            return simFilter[0] <= simNum < simFilter[1]
        if isinstance(simFilter, list):
            return simNum in simFilter
        raise ValueError("Invalid sim filter entry %r: use a (min, max) tuple or a list of sim numbers" % (simFilter,))


    def __len__(self) -> int:
        return len(self.dataPaths)


    def _loadParameters(self, basePath:str, seqStart:int, seqEnd:int, seqSkip:int, seqLen:int):
        """Load per-frame simulation parameters from src/description.json.

        Returns:
            (simParameters, loadedParams): simParameters is an array of shape (seqLen, P) with the
            requested parameters in the order rey, mach, zslice, or {} if none are requested;
            loadedParams maps every known JSON parameter name to a (seqLen,) float32 array
            (zeros when the name is missing from the JSON).
        """
        with open(os.path.join(basePath, "src", "description.json")) as f:
            loadedJSON = json.load(f)

        loadNames = ["Reynolds Number", "Mach Number", "Drag Coefficient", "Lift Coefficient", "Z Slice"]
        loadedParams = {}
        for loadName in loadNames:
            loadedParam = np.zeros(seqLen, dtype=np.float32)
            if loadName in loadedJSON:
                temp = loadedJSON[loadName]
                if isinstance(temp, (int, float)):
                    loadedParam[:] = np.float32(temp)  # constant over the whole simulation
                elif isinstance(temp, list):
                    loadedParam[:] = temp[seqStart:seqEnd:seqSkip]  # one value per frame
                else:
                    raise ValueError("Invalid simulation parameter data type")
            loadedParams[loadName] = loadedParam

        # every requested parameter becomes a column (previously e.g. rey+zslice silently
        # dropped the zslice column, which then no longer matched the normalization)
        paramSources = [("rey", "Reynolds Number"), ("mach", "Mach Number"), ("zslice", "Z Slice")]
        columns = [loadedParams[jsonName] for key, jsonName in paramSources if key in self.simParams]
        if columns:
            simParameters = np.stack(columns, axis=1)
        elif not self.simParams:
            simParameters = {}
        else:
            raise ValueError("Invalid specification of simulation parameters")
        return simParameters, loadedParams


    def __getitem__(self, idx:int) -> dict:
        # sequence indexing
        basePath, seqStart, seqEnd, seqSkip = self.dataPaths[idx]
        seqLen = (seqEnd - seqStart) // seqSkip
        if self.randSeqOffset:
            halfSeq = (seqEnd - seqStart) // 2
            offset = torch.randint(-halfSeq, halfSeq+1, (1,)).item()
            # use the frame range of this sample's own top filter; maxFrame is exclusive like seqEnd
            topDir = os.path.basename(os.path.dirname(basePath))
            minFrame, maxFrame = self._frameRange(self._matchTopFilter(topDir))
            if seqStart + offset >= minFrame and seqEnd + offset <= maxFrame:
                seqStart = seqStart + offset
                seqEnd = seqEnd + offset

        # loading simulation parameters
        simParameters, loadedParams = self._loadParameters(basePath, seqStart, seqEnd, seqSkip, seqLen)

        # loading obstacle mask (context manager closes the .npz file handle)
        obsMask = None
        maskPath = os.path.join(basePath, "obstacle_mask.npz")
        if os.path.isfile(maskPath):
            with np.load(maskPath) as npz:
                obsMask = npz["arr_0"]

        # loading fields and combining them with simulation parameters
        loaded = {field: [] for field in self.simFields}
        for frame in range(seqStart, seqEnd, seqSkip):
            for field in self.simFields:
                with np.load(os.path.join(basePath, "%s_%06d.npz" % (field, frame))) as npz:
                    loaded[field].append(npz["arr_0"].astype(np.float32))

        loadedFields = [np.stack(loaded[field], axis=0) for field in self.simFields]

        if not isinstance(simParameters, dict):
            vel = loadedFields[0]
            if vel.ndim not in (4, 5):
                raise ValueError("Invalid input shape when loading samples!")
            # repeat each (frame, parameter) value over all spatial positions of the velocity field
            spatialShape = vel.shape[2:]
            expandedShape = simParameters.shape + (1,) * len(spatialShape)
            simParExpanded = np.broadcast_to(simParameters.reshape(expandedShape), simParameters.shape + spatialShape)
            loadedFields.append(simParExpanded)

        data = np.concatenate(loadedFields, axis=1) # ORDER (fields): velocity (x,y), velocity z / density, pressure, ORDER (params): rey, mach, zslice

        # output
        dataPath = "%s/%s_%06d-%06d(%03d).npz" % (basePath, "-".join(self.simFields), seqStart, seqEnd - seqSkip, seqSkip)
        sample = {"data" : data, "simParameters" : simParameters, "allParameters" : loadedParams, "path" : dataPath}
        if obsMask is not None:
            sample["obsMask"] = obsMask

        if self.transform:
            sample = self.transform(sample)
        else:
            print("WARNING: no data transformations are employed!")

        return sample


    def printDatasetInfo(self):
        """Append transform details to the summaries, then print and log them."""
        if self.transform:
            s  = "%s - Data Augmentations: %s\n" % (self.name, len(self.transform.p_d.augmentations))
            s += "\tactivate augmentations: [%s]\n" % ", ".join(self.transform.p_d.augmentations)
            if self.transform.crop:
                s += "\tcrop settings: outputSize (%d, %d)\n" % (self.transform.outputSize[0], self.transform.outputSize[1])
            if self.transform.resize:
                s += "\tresize settings: outputSize (%d, %d)\n" % (self.transform.outputSize[0], self.transform.outputSize[1])

            self.summaryPrint += [s]
            self.summaryLog   += [s]

        print('\n'.join(self.summaryPrint))
        logging.info('\n'.join(self.summaryLog))

    def getFilterInfoString(self) -> str:
        """Human-readable summary of the filter configuration."""
        s  = "%s - Data Filter Setup: \n" % (self.name)
        s += "\tdataDirs: %s\n" % (str(self.dataDirs))
        s += "\tfilterTop: %s  exlude: %s\n" % (str(self.filterTop), self.excludeFilterTop)
        s += "\tfilterSim: %s  exlude: %s\n" % (str(self.filterSim), self.excludefilterSim)
        s += "\tfilterFrame: %s\n" % (str(self.filterFrame))
        s += "\tsequenceLength: %s\n" % (str(self.sequenceLength))
        return s

In [8]:
# Root folder of the 128_inc data. Set the TURB_DATA_PATH environment variable to run the
# notebook on another machine instead of editing this absolute path.
data_path = os.environ.get("TURB_DATA_PATH", "/home/nkhaous/myLLF/JiT_SysId/data/128_inc")
p_d = DataParams(batch=64, augmentations=["normalize"], sequenceLength=[10,2], randSeqOffset=True,
            dataSize=[128,64], dimension=2, simFields=["pres"], simParams=["rey"], normalizeMode="incMixed")

# training sims 10-80 (max exclusive), frames 800-1299, sequences of 10 frames taking every 2nd frame
trainSet = TurbulenceDataset("Training", [data_path], filterTop=["128_inc"], filterSim=[(10,81)], filterFrame=[(800,1300)],
                sequenceLength=[p_d.sequenceLength], randSeqOffset=p_d.randSeqOffset, simFields=p_d.simFields, simParams=p_d.simParams, printLevel="sim")

transTrain = Transforms(p_d)
trainSet.transform = transTrain
# trainSet.printDatasetInfo()

trainSampler = RandomSampler(trainSet)
#trainSampler = SubsetRandomSampler(range(2))
trainLoader = DataLoader(trainSet, sampler=trainSampler,
                batch_size=p_d.batch, drop_last=True, num_workers=4)

In [9]:
sample_batch = next(iter(trainLoader))
sample_batch.keys()

dict_keys(['data', 'simParameters', 'allParameters', 'path', 'obsMask'])

In [10]:
sample_batch = next(iter(trainLoader))
sample_batch['data'].shape

torch.Size([64, 10, 4, 128, 64])

In [12]:
def _make_test_set(name, top, sims, frames, seq_len):
    """Build one evaluation TurbulenceDataset using the training fields and parameters."""
    return TurbulenceDataset(name, [data_path], filterTop=[top], filterSim=[sims],
                             filterFrame=[frames], sequenceLength=[seq_len],
                             simFields=p_d.simFields, simParams=p_d.simParams, printLevel="sim")


# Evaluation sets: interpolation / extrapolation in Reynolds number, plus a run with varying Reynolds.
testSets = {
    "lowRey":   _make_test_set("Test Low Reynolds 100-200", "128_inc", [82, 84, 86, 88, 90], (1000, 1150), [60, 2]),
    "highRey":  _make_test_set("Test High Reynolds 900-1000", "128_inc", [0, 2, 4, 6, 8], (1000, 1150), [60, 2]),
    "varReyIn": _make_test_set("Test Varying Reynolds Number (200-900)", "128_reyVar", [0], (300, 800), [250, 2]),
}


# One sequential, non-shuffled loader per test set, in the order of testSets.
test_loaders = []
for shortName, testSet in testSets.items():
    p_d_test = copy.deepcopy(p_d)
    p_d_test.augmentations = ["normalize"]
    p_d_test.sequenceLength = testSet.sequenceLength
    p_d_test.randSeqOffset = False
    # NOTE(review): if testSet.sequenceLength keeps the nested [[60, 2]] form passed above,
    # this compares 10 with [60, 2] and is always true, i.e. batch size 1 for every test set.
    if p_d.sequenceLength[0] != p_d_test.sequenceLength[0]:
        p_d_test.batch = 1

    transTest = Transforms(p_d_test)
    testSet.transform = transTest
    testSampler = SequentialSampler(testSet)
    testLoader = DataLoader(testSet, sampler=testSampler,
                            batch_size=p_d_test.batch, drop_last=False, num_workers=4)
    test_loaders.append(testLoader)

## Model

In [13]:
import numpy as np

def get_2d_sincos_pos_embed(embed_dim, grid_size_h, grid_size_w):
    """Build a 2D sin-cos positional embedding for a rectangular patch grid.

    Args:
        embed_dim: embedding width per position (must be even).
        grid_size_h: number of patches along the height (e.g. 128/16 = 8).
        grid_size_w: number of patches along the width (e.g. 64/16 = 4).

    Returns:
        Array of shape (grid_size_h * grid_size_w, embed_dim). Positions are
        enumerated row-major: height is the outer index, width the inner one.
    """
    rows = np.arange(grid_size_h, dtype=np.float32)
    cols = np.arange(grid_size_w, dtype=np.float32)
    # Default 'xy' indexing: both outputs have shape (H, W); the first varies along the width.
    col_coords, row_coords = np.meshgrid(cols, rows)
    grid = np.stack([col_coords, row_coords], axis=0).reshape(2, 1, grid_size_h, grid_size_w)
    return get_2d_sincos_pos_embed_from_grid(embed_dim, grid)

def get_2d_sincos_pos_embed_from_grid(embed_dim, grid):
    """Combine two 1D sin-cos embeddings into one 2D embedding.

    ``grid[0]`` holds the horizontal (width) coordinates and ``grid[1]`` the
    vertical (height) ones. Each axis gets ``embed_dim // 2`` features. The
    vertical features come first in the concatenated output of width ``embed_dim``.
    """
    assert embed_dim % 2 == 0
    half = embed_dim // 2
    axis_embeddings = [get_1d_sincos_pos_embed_from_grid(half, grid[axis]) for axis in (1, 0)]
    return np.concatenate(axis_embeddings, axis=1)

def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
    """Standard 1D sine/cosine positional encoding.

    Args:
        embed_dim: output width; must be even (half sines, half cosines).
        pos: array of positions of any shape; it is flattened to length M.

    Returns:
        Array of shape (M, embed_dim). The first ``embed_dim // 2`` columns hold
        sin(pos * w_k) and the remaining ones cos(pos * w_k), where
        w_k = 1 / 10000 ** (k / (embed_dim / 2)).
    """
    assert embed_dim % 2 == 0
    half = embed_dim // 2
    exponents = np.arange(half, dtype=np.float32) / (embed_dim / 2.)
    inv_freq = 1. / 10000 ** exponents
    angles = np.outer(pos.reshape(-1), inv_freq)  # (M, half)
    return np.hstack([np.sin(angles), np.cos(angles)])

class RMSNorm(nn.Module):
    """Root-mean-square normalization over the last dimension with a learnable gain.

    Computes ``x / sqrt(mean(x**2, dim=-1) + eps) * weight``, with ``weight``
    initialised to ones.
    """
    def __init__(self, dim, eps=1e-6):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))

    def forward(self, x):
        rms = x.pow(2).mean(dim=-1, keepdim=True).add(self.eps).sqrt()
        return x / rms * self.weight

def modulate(x, shift, scale):
    """Apply adaLN-style modulation ``x * (1 + scale) + shift``.

    ``shift`` and ``scale`` get a singleton dimension inserted at position 1 so
    they broadcast over dim 1 of ``x`` (the token axis when x is (B, N, D)).
    """
    scale_b = scale.unsqueeze(1)
    shift_b = shift.unsqueeze(1)
    return x * (1 + scale_b) + shift_b


class BottleneckPatchEmbed(nn.Module):
    """Patchify an image through a low-rank (bottleneck) linear projection.

    Each non-overlapping patch is projected to ``pca_dim`` features by a strided
    convolution without bias, then to ``embed_dim`` by a 1x1 convolution. The
    spatial patch grid is then flattened into a token sequence.

    Args:
        img_size: expected input size (H, W); a single int means a square image.
        patch_size: patch edge length (int) or a (patch_h, patch_w) pair.
        in_chans: number of input channels.
        pca_dim: bottleneck width.
        embed_dim: output token width.
        bias: whether the second projection has a bias.
    """
    def __init__(self, img_size=(128, 64), patch_size=16, in_chans=1, pca_dim=128, embed_dim=384, bias=True):
        super().__init__()
        # Accept a single int (square) as well as an (H, W) pair; indexing an int used to crash.
        if isinstance(img_size, int):
            img_size = (img_size, img_size)
        if isinstance(patch_size, int):
            patch_size = (patch_size, patch_size)
        self.img_size = tuple(img_size)
        self.patch_size = tuple(patch_size)

        # Rectangular patch grid, e.g. 128x64 with 16x16 patches -> (8, 4).
        self.grid_size = (self.img_size[0] // self.patch_size[0], self.img_size[1] // self.patch_size[1])
        self.num_patches = self.grid_size[0] * self.grid_size[1]

        # Each patch of pixels -> a pca_dim vector ...
        self.proj1 = nn.Conv2d(in_chans, pca_dim, kernel_size=self.patch_size, stride=self.patch_size, bias=False)
        # ... then pca_dim -> embed_dim (transformer width).
        self.proj2 = nn.Conv2d(pca_dim, embed_dim, kernel_size=1, stride=1, bias=bias)

    def forward(self, x):
        """Map (B, C, H, W) images to (B, num_patches, embed_dim) tokens."""
        B, C, H, W = x.shape
        assert H == self.img_size[0] and W == self.img_size[1], \
            f"La taille de l'image en entrée ({H}x{W}) ne correspond pas à la taille configurée ({self.img_size[0]}x{self.img_size[1]})"

        # [B, in_chans, 128, 64] -> [B, embed_dim, 8, 4]
        x = self.proj1(x)
        x = self.proj2(x)

        # Flatten the spatial grid: [B, embed_dim, 8, 4] -> [B, embed_dim, 32]
        x = x.flatten(2)

        # Put the embedding dimension last: [B, embed_dim, 32] -> [B, 32, embed_dim]
        x = x.transpose(1, 2)
        return x
    
    
class TimestepEmbedder(nn.Module):
    """Map scalar diffusion times to ``hidden_size`` vectors.

    Times are first expanded into ``frequency_embedding_size`` sinusoidal
    features and then passed through a two-layer SiLU MLP.
    """
    def __init__(self, hidden_size, frequency_embedding_size=256):
        super().__init__()
        self.mlp = nn.Sequential(
            nn.Linear(frequency_embedding_size, hidden_size, bias=True),
            nn.SiLU(),
            nn.Linear(hidden_size, hidden_size, bias=True),
        )
        self.frequency_embedding_size = frequency_embedding_size

    @staticmethod
    def timestep_embedding(t, dim, max_period=10000):
        """Sinusoidal features of a 1-D tensor ``t``.

        Returns a (len(t), dim) tensor with cosines first and sines second. A zero
        column is appended when ``dim`` is odd.
        """
        half = dim // 2
        steps = torch.arange(start=0, end=half, dtype=torch.float32, device=t.device)
        freqs = torch.exp(-math.log(max_period) * steps / half)
        angles = t.float()[:, None] * freqs[None, :]
        embedding = torch.cat((angles.cos(), angles.sin()), dim=-1)
        if dim % 2 == 1:
            embedding = torch.cat((embedding, embedding.new_zeros(embedding.shape[0], 1)), dim=-1)
        return embedding

    def forward(self, t):
        return self.mlp(self.timestep_embedding(t, self.frequency_embedding_size))
    

class ControlEmbedder(nn.Module):
    """Embed the past and current control values into the hidden dimension.

    The ``past_window`` past values and the current one are concatenated into a
    (B, past_window + 1) vector and passed through a two-layer SiLU MLP.
    """
    def __init__(self, past_window, hidden_size):
        super().__init__()
        self.mlp = nn.Sequential(
            nn.Linear(past_window + 1, hidden_size),
            nn.SiLU(),
            nn.Linear(hidden_size, hidden_size),
        )

    def forward(self, u_past, u_curr):
        # Expected: u_past (B, past_window, 1), u_curr (B, 1).
        past_flat = u_past.squeeze(-1)
        features = torch.cat((past_flat, u_curr), dim=1)
        return self.mlp(features)
    
    
def scaled_dot_product_attention(query, key, value, dropout_p=0.0):
    """Softmax attention whose scores are always computed in float32.

    Args:
        query: (..., L, D) tensor.
        key: (..., S, D) tensor.
        value: (..., S, Dv) tensor.
        dropout_p: dropout probability applied to the attention weights.

    Returns:
        (..., L, Dv) tensor: softmax(Q K^T / sqrt(D)) @ V.
    """
    scale_factor = 1 / math.sqrt(query.size(-1))
    # Disable CUDA autocast so Q K^T runs in float32 for stability. torch.autocast("cuda", ...)
    # replaces the deprecated torch.cuda.amp.autocast, which emitted a FutureWarning.
    # The former all-zero additive bias was removed: it changed nothing, cost an allocation,
    # and its fixed 4-D shape broke inputs without a head axis.
    with torch.autocast("cuda", enabled=False):
        attn_weight = query.float() @ key.float().transpose(-2, -1) * scale_factor
    attn_weight = torch.softmax(attn_weight, dim=-1)
    attn_weight = torch.dropout(attn_weight, dropout_p, train=True)
    return attn_weight @ value
    
    
class Attention(nn.Module):
    """Multi-head self-attention with optional RMS normalization of queries and keys."""
    def __init__(self, dim, num_heads=8, qkv_bias=True, qk_norm=True, attn_drop=0., proj_drop=0.):
        super().__init__()
        self.num_heads = num_heads
        per_head = dim // num_heads

        self.q_norm = RMSNorm(per_head) if qk_norm else nn.Identity()
        self.k_norm = RMSNorm(per_head) if qk_norm else nn.Identity()

        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

    def forward(self, x):
        batch, tokens, channels = x.shape
        head_dim = channels // self.num_heads
        # (B, N, 3C) -> (3, B, heads, N, head_dim), then split into q, k, v.
        packed = self.qkv(x).reshape(batch, tokens, 3, self.num_heads, head_dim)
        q, k, v = packed.permute(2, 0, 3, 1, 4).unbind(0)

        q, k = self.q_norm(q), self.k_norm(k)

        drop_prob = self.attn_drop.p if self.training else 0.
        attended = scaled_dot_product_attention(q, k, v, dropout_p=drop_prob)

        merged = attended.transpose(1, 2).reshape(batch, tokens, channels)
        return self.proj_drop(self.proj(merged))

class SwiGLUFFN(nn.Module):
    """SwiGLU feed-forward block: w3(dropout(silu(a) * b)), where (a, b) = split(w12(x)).

    The requested hidden width is scaled by 2/3 (truncated to int). This keeps the
    parameter count close to that of a plain MLP with ``hidden_dim`` units.
    """
    def __init__(self, dim, hidden_dim, drop=0.0, bias=True):
        super().__init__()
        inner = int(hidden_dim * 2 / 3)
        self.w12 = nn.Linear(dim, 2 * inner, bias=bias)
        self.w3 = nn.Linear(inner, dim, bias=bias)
        self.ffn_dropout = nn.Dropout(drop)

    def forward(self, x):
        gate, value = torch.chunk(self.w12(x), 2, dim=-1)
        return self.w3(self.ffn_dropout(F.silu(gate) * value))
    

class FinalLayer(nn.Module):
    """Output head: adaLN-modulated RMSNorm followed by a linear map to patch pixels.

    Each ``hidden_size`` token becomes ``patch_size * patch_size * out_channels`` values.
    """
    def __init__(self, hidden_size, patch_size, out_channels):
        super().__init__()
        self.norm_final = RMSNorm(hidden_size)
        patch_values = patch_size * patch_size * out_channels
        self.linear = nn.Linear(hidden_size, patch_values, bias=True)
        # Produces (shift, scale) from the conditioning vector.
        self.adaLN_modulation = nn.Sequential(
            nn.SiLU(),
            nn.Linear(hidden_size, 2 * hidden_size, bias=True),
        )

    def forward(self, x, c):
        shift, scale = torch.chunk(self.adaLN_modulation(c), 2, dim=1)
        normed = modulate(self.norm_final(x), shift, scale)
        return self.linear(normed)
    

class JiTBlock(nn.Module):
    """Transformer block (attention + SwiGLU MLP) with adaLN-style conditioning.

    The conditioning vector ``c`` yields, per sub-layer, a shift, a scale and a
    gate. The gate multiplies that sub-layer's residual update.
    """
    def __init__(self, hidden_size, num_heads, mlp_ratio=4.0, attn_drop=0.0, proj_drop=0.0):
        super().__init__()
        self.norm1 = RMSNorm(hidden_size, eps=1e-6)
        self.attn = Attention(hidden_size, num_heads=num_heads, qkv_bias=True, qk_norm=True,
                              attn_drop=attn_drop, proj_drop=proj_drop)
        self.norm2 = RMSNorm(hidden_size, eps=1e-6)
        self.mlp = SwiGLUFFN(hidden_size, int(hidden_size * mlp_ratio), drop=proj_drop)
        self.adaLN_modulation = nn.Sequential(
            nn.SiLU(),
            nn.Linear(hidden_size, 6 * hidden_size, bias=True),
        )

    def forward(self, x, c):
        modulation = self.adaLN_modulation(c).chunk(6, dim=-1)
        shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = modulation

        attn_in = modulate(self.norm1(x), shift_msa, scale_msa)
        x = x + gate_msa.unsqueeze(1) * self.attn(attn_in)

        mlp_in = modulate(self.norm2(x), shift_mlp, scale_mlp)
        x = x + gate_mlp.unsqueeze(1) * self.mlp(mlp_in)
        return x
    
    
    
    
class JitFluid(nn.Module):
    """Patch-transformer (JiT-style) conditional diffusion model for 2D fluid frames.

    The model predicts the clean next frame of ``out_channels`` physical fields. It is
    conditioned on ``past_window`` previous frames of ``in_channels`` channels each and
    on scalar control values (the Reynolds number) for the past and current steps.
    ``forward`` returns the training loss. ``sample`` integrates the learned velocity
    from noise (t=0) to data (t=1).
    """
    def __init__(
        self,
        img_size=(128, 64),
        patch_size=16,
        in_channels=4,        # input channels per past frame (VelX, VelY, Pres, Rey)
        out_channels=3,       # predicted channels (physical fields only)
        past_window=9,
        hidden_size=384,
        depth=12,
        num_heads=6,
        mlp_ratio=4.0,
        bottleneck_dim=64,
        diffusion_steps=50,
        P_mean=-0.8,
        P_std=0.8,
        t_eps=0.05,
        noise_scale=1.0
    ):
        super().__init__()
        self.img_size = img_size
        self.patch_size = patch_size
        self.in_channels = in_channels   # 4
        self.out_channels = out_channels # 3
        self.hidden_size = hidden_size
        self.past_window = past_window
        self.timesteps = diffusion_steps  # stored only; sample() takes its own num_steps
        self.P_mean = P_mean
        self.P_std = P_std
        self.t_eps = t_eps
        self.noise_scale = noise_scale

        # 1. Embedders
        self.t_embedder = TimestepEmbedder(hidden_size)
        # Embeds the scalar control value (Reynolds number) for the past and current steps.
        self.control_embedder = ControlEmbedder(past_window, hidden_size)

        # 2. Patch embedding of the noisy target frame (out_channels physical fields).
        #    Only the physics is noised, not the Reynolds channel.
        self.x_embedder = BottleneckPatchEmbed(
            img_size, patch_size, out_channels, bottleneck_dim, hidden_size, bias=True
        )

        # 3. Patch embedding of the past frames stacked along channels
        #    (9 frames * 4 channels = 36 input channels with the defaults).
        self.cond_embedder = BottleneckPatchEmbed(
            img_size, patch_size, past_window * in_channels, bottleneck_dim, hidden_size, bias=True
        )

        # Fixed positional embedding (8x4 grid by default), filled in initialize_weights.
        num_patches = self.x_embedder.num_patches
        self.pos_embed = nn.Parameter(torch.zeros(1, num_patches, hidden_size), requires_grad=False)

        # Transformer blocks
        self.blocks = nn.ModuleList([
            JiTBlock(hidden_size, num_heads, mlp_ratio=mlp_ratio)
            for _ in range(depth)
        ])

        # Final layer: tokens -> out_channels patch pixels.
        self.final_layer = FinalLayer(hidden_size, patch_size, self.out_channels)

        self.initialize_weights()

    def initialize_weights(self):
        """Set up the initial weights.

        - Xavier-uniform for all Linear layers.
        - Sin-cos positional embedding for the rectangular patch grid.
        - Xavier for the patch projections.
        - Zeros for every adaLN modulation and for the output projection.
        """
        def _basic_init(module):
            if isinstance(module, nn.Linear):
                torch.nn.init.xavier_uniform_(module.weight)
                if module.bias is not None:
                    nn.init.constant_(module.bias, 0)
        self.apply(_basic_init)

        # Rectangular grid for the positional embedding.
        grid_h = self.img_size[0] // self.patch_size # 128 / 16 = 8
        grid_w = self.img_size[1] // self.patch_size # 64 / 16 = 4

        pos_embed = get_2d_sincos_pos_embed(
            self.pos_embed.shape[-1],
            grid_h,
            grid_w
        )
        self.pos_embed.data.copy_(torch.from_numpy(pos_embed).float().unsqueeze(0))

        # Patch-embedding projections: Xavier on the weights viewed as 2-D matrices.
        for embedder in [self.x_embedder, self.cond_embedder]:
            w1 = embedder.proj1.weight.data
            nn.init.xavier_uniform_(w1.view([w1.shape[0], -1]))
            w2 = embedder.proj2.weight.data
            nn.init.xavier_uniform_(w2.view([w2.shape[0], -1]))
            nn.init.constant_(embedder.proj2.bias, 0)

        # Zero the adaLN modulations and the output layer so that, at init, every
        # block's residual update is gated to zero and the prediction is zero.
        for block in self.blocks:
            nn.init.constant_(block.adaLN_modulation[-1].weight, 0)
            nn.init.constant_(block.adaLN_modulation[-1].bias, 0)

        nn.init.constant_(self.final_layer.adaLN_modulation[-1].weight, 0)
        nn.init.constant_(self.final_layer.adaLN_modulation[-1].bias, 0)
        nn.init.constant_(self.final_layer.linear.weight, 0)
        nn.init.constant_(self.final_layer.linear.bias, 0)

    def unpatchify(self, x):
        """Reassemble patch tokens into images.

        x: (N, num_patches, patch_size**2 * C) -> imgs: (N, C, H, W), where
        num_patches is the grid size (8x4 = 32 by default) and C = out_channels.
        """
        p = self.patch_size
        c = self.out_channels

        # Grid size from image and patch size (the grid is not necessarily square).
        h = self.img_size[0] // p  # 128 // 16 = 8 patches along the height
        w = self.img_size[1] // p  # 64 // 16 = 4 patches along the width

        assert h * w == x.shape[1], f"Le nombre de patchs ({x.shape[1]}) ne correspond pas à la grille {h}x{w}"

        # [N, 32, 768] -> [N, 8, 4, 16, 16, 3]
        x = x.reshape(shape=(x.shape[0], h, w, p, p, c))

        # Move channels first and interleave grid and in-patch indices:
        # n: batch, h/w: grid indices, p/q: in-patch pixel indices, c: channels.
        x = torch.einsum('nhwpqc->nchpwq', x)

        # [N, 3, 8, 16, 4, 16] -> [N, 3, 128, 64]
        imgs = x.reshape(shape=(x.shape[0], c, h * p, w * p))

        return imgs

    def sample_t(self, n, device):
        """Sample n timesteps in (0, 1) from a logit-normal(P_mean, P_std) distribution."""
        z = torch.randn(n, device=device) * self.P_std + self.P_mean
        return torch.sigmoid(z)

    def forward(self, cond_frames, cond_u_past, cond_u_curr, target_frame):
        """Compute the training loss for one batch.

        Args:
            cond_frames: past frames, e.g. (B, 9, 4, 128, 64); dims 1 and 2 are
                flattened into channels before patch embedding.
            cond_u_past: past control values (passed to ControlEmbedder).
            cond_u_curr: current control value (passed to ControlEmbedder).
            target_frame: clean next frame (B, out_channels, H, W).

        Returns:
            Scalar smooth-L1 loss between the velocity implied by the predicted
            clean frame and the true velocity.
        """
        B = target_frame.size(0)
        device = target_frame.device

        # 1. One time per example; interpolate between noise (t=0) and data (t=1).
        t = self.sample_t(B, device).view(-1, *([1] * (target_frame.ndim - 1)))
        t_flat = t.flatten()

        e = torch.randn_like(target_frame) * self.noise_scale
        z_t = t * target_frame + (1 - t) * e

        # Target velocity for the v-loss (equals target_frame - e while 1 - t >= t_eps).
        v = (target_frame - z_t) / (1 - t).clamp_min(self.t_eps)

        # 2. Past-frame conditioning: [B, 9, 4, 128, 64] -> [B, 36, 128, 64] -> tokens.
        cond_frames_reshaped = cond_frames.flatten(1, 2)
        cond_tokens = self.cond_embedder(cond_frames_reshaped)

        # 3. Noisy target (out_channels channels).
        x_tokens = self.x_embedder(z_t)

        # Merge tokens + positions.
        tokens = x_tokens + cond_tokens + self.pos_embed

        # 4. Conditioning vector: time embedding + control (Reynolds) embedding.
        t_emb = self.t_embedder(t_flat)
        control_emb = self.control_embedder(cond_u_past, cond_u_curr)
        c = t_emb + control_emb

        # 5. Transformer blocks
        for block in self.blocks:
            tokens = block(tokens, c)

        # 6. Predict the clean frame and map it back to image space.
        x_pred_tokens = self.final_layer(tokens, c)
        x_pred = self.unpatchify(x_pred_tokens)

        # Velocity implied by the prediction (same formula as the target).
        v_pred = (x_pred - z_t) / (1 - t).clamp_min(self.t_eps)
        loss = F.smooth_l1_loss(v_pred, v)

        return loss

    @torch.no_grad()
    def sample(self, cond_frames, cond_u_past, cond_u_curr, num_steps=50, method='heun'):
        """Generate the next frame by integrating the predicted velocity from t=0 to t=1.

        Args:
            cond_frames: past frames (same layout as in ``forward``).
            cond_u_past, cond_u_curr: control values (same as in ``forward``).
            num_steps: number of steps on a uniform time grid over [0, 1].
            method: 'euler' or 'heun'.

        Returns:
            Tensor of shape (B, out_channels, H, W).

        Raises:
            ValueError: if ``method`` is not 'euler' or 'heun' (previously an
                unknown method silently returned the initial noise).
        """
        steppers = {'euler': self._euler_step, 'heun': self._heun_step}
        if method not in steppers:
            raise ValueError(f"Unknown sampling method {method!r}; expected 'euler' or 'heun'")
        step = steppers[method]

        B = cond_frames.size(0)
        device = cond_frames.device
        H, W = self.img_size # (128, 64)

        # Start from noise with as many channels as the model predicts
        # (was hard-coded to 3, which only matched the default out_channels).
        z = self.noise_scale * torch.randn(B, self.out_channels, H, W, device=device)

        timesteps = torch.linspace(0.0, 1.0, num_steps + 1, device=device)

        # The past-frame embedding is fixed while sampling one frame: compute it once.
        cond_frames_reshaped = cond_frames.flatten(1, 2)
        cond_tokens = self.cond_embedder(cond_frames_reshaped)
        cond_tokens = cond_tokens + self.pos_embed

        control_emb = self.control_embedder(cond_u_past, cond_u_curr)

        for i in range(num_steps):
            z = step(z, timesteps[i], timesteps[i + 1], cond_tokens, control_emb)

        return z

    def _forward_sample(self, z, t, cond_tokens, control_emb):
        """Predicted velocity at state ``z`` and scalar time ``t`` (a 0-d tensor)."""
        B = z.size(0)
        t_scalar = t.expand(B)

        # Embed the current noisy sample.
        x_tokens = self.x_embedder(z)
        tokens = x_tokens + cond_tokens # cond_tokens already include pos_embed

        t_emb = self.t_embedder(t_scalar)
        c = t_emb + control_emb

        for block in self.blocks:
            tokens = block(tokens, c)

        # Prediction of the clean frame x.
        x_pred_tokens = self.final_layer(tokens, c)
        x_pred = self.unpatchify(x_pred_tokens)

        # Convert to a velocity for the integration step.
        t_broadcast = t.view(-1, *([1] * (z.ndim - 1)))
        v_pred = (x_pred - z) / (1.0 - t_broadcast).clamp_min(self.t_eps)

        return v_pred

    def _euler_step(self, z, t_curr, t_next, cond_tokens, control_emb):
        """One explicit Euler step from t_curr to t_next."""
        v_pred = self._forward_sample(z, t_curr, cond_tokens, control_emb)
        z_next = z + (t_next - t_curr) * v_pred
        return z_next

    def _heun_step(self, z, t_curr, t_next, cond_tokens, control_emb):
        """One Heun (predictor-corrector) step from t_curr to t_next."""
        # Predictor: Euler step
        v_pred_curr = self._forward_sample(z, t_curr, cond_tokens, control_emb)
        z_euler = z + (t_next - t_curr) * v_pred_curr

        # Corrector: velocity at the predicted point
        v_pred_next = self._forward_sample(z_euler, t_next, cond_tokens, control_emb)

        # Average of both velocities
        v_pred = 0.5 * (v_pred_curr + v_pred_next)
        z_next = z + (t_next - t_curr) * v_pred
        return z_next

In [14]:
import torch
import torch.optim as optim
from torch.optim.lr_scheduler import CosineAnnealingLR
from tqdm import tqdm

def _split_jit_batch(batch, model, device):
    """Split a dataset batch into the four inputs of ``JitFluid.forward``.

    The first ``model.past_window`` frames (all channels) are the conditioning. The
    frame right after them, restricted to its first ``model.out_channels`` channels,
    is the target. Control values are split the same way along the time axis.
    Returns (cond_frames, cond_u_past, cond_u_curr, target_frame).
    """
    past_window = getattr(model, "past_window", 9)
    out_channels = getattr(model, "out_channels", 3)
    data = batch['data'].to(device)              # e.g. [64, 10, 4, 128, 64]
    params = batch['simParameters'].to(device)   # e.g. [64, 10, 1]
    cond_frames = data[:, :past_window]
    target_frame = data[:, past_window, :out_channels]
    cond_u_past = params[:, :past_window, :]
    cond_u_curr = params[:, past_window, :]
    return cond_frames, cond_u_past, cond_u_curr, target_frame


def train_jit_diffusion(model, train_loader, val_loader, num_epochs, device, lr=1e-4,
                        save_path="best_jit_fluid.pth"):
    """Train ``model`` with AdamW + cosine LR decay, optionally validating each epoch.

    Args:
        model: module whose forward(cond_frames, cond_u_past, cond_u_curr, target_frame)
            returns a scalar loss. ``past_window`` / ``out_channels`` attributes set the
            batch split; they default to 9 / 3 when missing.
        train_loader: iterable of dicts with 'data' and 'simParameters' tensors.
        val_loader: same kind of loader, or a falsy value to skip validation.
        num_epochs: number of epochs (also the cosine schedule period).
        device: device the batches are moved to.
        lr: initial learning rate.
        save_path: file the best state_dict (lowest validation loss) is saved to.
    """
    # AdamW is the usual optimizer for transformers; the LR decays with a cosine schedule.
    optimizer = optim.AdamW(model.parameters(), lr=lr, betas=(0.9, 0.95), weight_decay=0.01)
    scheduler = CosineAnnealingLR(optimizer, T_max=num_epochs, eta_min=1e-6)

    best_val_loss = float('inf')

    for epoch in range(num_epochs):
        model.train()
        train_loss = 0.0

        pbar = tqdm(train_loader, desc=f"Époque {epoch+1}/{num_epochs}")
        for batch in pbar:
            inputs = _split_jit_batch(batch, model, device)

            optimizer.zero_grad()
            loss = model(*inputs)
            loss.backward()
            # Gradient clipping for transformer training stability.
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
            optimizer.step()

            train_loss += loss.item()
            pbar.set_postfix({"Loss": f"{loss.item():.6f}"})

        avg_train_loss = train_loss / len(train_loader)
        scheduler.step()

        if not val_loader:
            print(f"\nÉpoque {epoch+1} terminée | Train Loss: {avg_train_loss:.6f}")
            continue

        # Validation (the loss is stochastic: random diffusion time and noise per call).
        model.eval()
        val_loss = 0.0
        with torch.no_grad():
            for batch in val_loader:
                val_loss += model(*_split_jit_batch(batch, model, device)).item()

        avg_val_loss = val_loss / len(val_loader)
        print(f"\nÉpoque {epoch+1} terminée | Train Loss: {avg_train_loss:.6f} | Val Loss: {avg_val_loss:.6f}")

        # Keep the best model seen so far.
        if avg_val_loss < best_val_loss:
            best_val_loss = avg_val_loss
            torch.save(model.state_dict(), save_path)
            print("✓ Modèle sauvegardé !")

    print("\n✅ Entraînement terminé !")


In [17]:
# --- Model hyper-parameters ---
# img_size: (height, width) of the simulations
# patch_size: 16 gives an 8x4 patch grid
# in_channels: 4 (VelX, VelY, Pres, Reynolds)
# out_channels: 3 (only VelX, VelY, Pres are predicted)
# past_window: 9 (the first 9 frames are used to predict the 10th)

# Use the first GPU when available; fall back to the CPU so the notebook still
# runs (slowly) on machines without CUDA instead of failing here.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

model = JitFluid(
    img_size=(128, 64),
    patch_size=16,
    in_channels=4,
    out_channels=3,
    past_window=9,
    hidden_size=384,
    depth=12,
    num_heads=6,
    diffusion_steps=50
).to(device)

# --- Sanity check ---
print(f"Modèle JitFluid initialisé sur {device}")
print(f"Nombre de paramètres : {sum(p.numel() for p in model.parameters() if p.requires_grad):,}")

Modèle JitFluid initialisé sur cuda:0
Nombre de paramètres : 33,615,744


In [20]:
# Train for 20 epochs on the training loader only (no validation loader is passed).
train_jit_diffusion(model, trainLoader, val_loader=None, num_epochs=20, device=device)

  with torch.cuda.amp.autocast(enabled=False):
Époque 1/20: 100%|██████████| 27/27 [00:07<00:00,  3.65it/s, Loss=0.046643]



Époque 1 terminée | Train Loss: 0.082989


Époque 2/20: 100%|██████████| 27/27 [00:07<00:00,  3.67it/s, Loss=0.021637]



Époque 2 terminée | Train Loss: 0.034581


Époque 3/20: 100%|██████████| 27/27 [00:07<00:00,  3.81it/s, Loss=0.009531]



Époque 3 terminée | Train Loss: 0.014956


Époque 4/20: 100%|██████████| 27/27 [00:07<00:00,  3.76it/s, Loss=0.007716]



Époque 4 terminée | Train Loss: 0.009037


Époque 5/20: 100%|██████████| 27/27 [00:07<00:00,  3.74it/s, Loss=0.006709]



Époque 5 terminée | Train Loss: 0.006890


Époque 6/20: 100%|██████████| 27/27 [00:07<00:00,  3.69it/s, Loss=0.004607]



Époque 6 terminée | Train Loss: 0.005103


Époque 7/20: 100%|██████████| 27/27 [00:07<00:00,  3.69it/s, Loss=0.003983]



Époque 7 terminée | Train Loss: 0.004309


Époque 8/20: 100%|██████████| 27/27 [00:07<00:00,  3.62it/s, Loss=0.002761]



Époque 8 terminée | Train Loss: 0.003571


Époque 9/20: 100%|██████████| 27/27 [00:07<00:00,  3.75it/s, Loss=0.002864]



Époque 9 terminée | Train Loss: 0.003020


Époque 10/20: 100%|██████████| 27/27 [00:07<00:00,  3.78it/s, Loss=0.002437]



Époque 10 terminée | Train Loss: 0.002567


Époque 11/20: 100%|██████████| 27/27 [00:07<00:00,  3.71it/s, Loss=0.002117]



Époque 11 terminée | Train Loss: 0.002221


Époque 12/20: 100%|██████████| 27/27 [00:07<00:00,  3.81it/s, Loss=0.001981]



Époque 12 terminée | Train Loss: 0.002158


Époque 13/20: 100%|██████████| 27/27 [00:07<00:00,  3.74it/s, Loss=0.001552]



Époque 13 terminée | Train Loss: 0.001861


Époque 14/20: 100%|██████████| 27/27 [00:07<00:00,  3.65it/s, Loss=0.001558]



Époque 14 terminée | Train Loss: 0.001714


Époque 15/20: 100%|██████████| 27/27 [00:07<00:00,  3.59it/s, Loss=0.001351]



Époque 15 terminée | Train Loss: 0.001516


Époque 16/20: 100%|██████████| 27/27 [00:07<00:00,  3.68it/s, Loss=0.001706]



Époque 16 terminée | Train Loss: 0.001475


Époque 17/20: 100%|██████████| 27/27 [00:07<00:00,  3.67it/s, Loss=0.001304]



Époque 17 terminée | Train Loss: 0.001355


Époque 18/20: 100%|██████████| 27/27 [00:07<00:00,  3.61it/s, Loss=0.001241]



Époque 18 terminée | Train Loss: 0.001329


Époque 19/20: 100%|██████████| 27/27 [00:07<00:00,  3.72it/s, Loss=0.001228]



Époque 19 terminée | Train Loss: 0.001325


Époque 20/20: 100%|██████████| 27/27 [00:07<00:00,  3.66it/s, Loss=0.001742]



Époque 20 terminée | Train Loss: 0.001245

✅ Entraînement terminé !


In [None]:
# Alternative plain-Adam training loop (no LR schedule, no gradient clipping).
# It previously called an older noise-prediction API (model(conditioning=..., data=...))
# and relied on undefined `lr` / `epochs`; it now uses the JitFluid.forward signature,
# which returns the loss directly.
lr = 1e-4
epochs = 20

model.train()
model.to(device)

# print model info and trainable weights
paramsTrainable = sum([np.prod(p.size()) for p in filter(lambda p: p.requires_grad, model.parameters())])
params = sum([np.prod(p.size()) for p in model.parameters()])
print("Trainable Weights (All Weights): %d (%d)" % (paramsTrainable, params))

pastWindow = model.past_window
outChannels = model.out_channels

# training loop
print("\nStarting training...")
optimizer = torch.optim.Adam(model.parameters(), lr=lr)
for epoch in range(epochs):
    losses = []
    for s, sample in enumerate(trainLoader, 0):
        optimizer.zero_grad()

        d = sample["data"].to(device)
        simParams = sample["simParameters"].to(device)

        # First pastWindow frames condition the prediction of the next frame's physical fields.
        condFrames = d[:, :pastWindow]
        target = d[:, pastWindow, :outChannels]
        loss = model(condFrames, simParams[:, :pastWindow, :], simParams[:, pastWindow, :], target)

        print("    [Epoch %2d, Batch %4d]: %1.7f" % (epoch, s, loss.detach().cpu().item()))
        loss.backward()

        losses += [loss.detach().cpu().item()]

        optimizer.step()
    print("[Epoch %2d, FULL]: %1.7f" % (epoch, sum(losses)/len(losses)))

print("Training complete!")

### Eval

In [49]:
import torch
import numpy as np
from tqdm import tqdm

def run_rollout(model, testLoader, device, num_steps_to_predict=None, diffusion_steps=20, method='heun'):
    """Autoregressively roll the model out over every sequence of ``testLoader``.

    For each batch, the first ``model.past_window`` frames are copied from the ground
    truth. Every later frame is sampled with ``model.sample``, conditioned on the
    previously *predicted* physical frames. The Reynolds channel and the scalar
    parameters are taken from the dataset.

    Args:
        model: model exposing ``past_window``, ``out_channels`` (defaults to 3) and ``sample``.
        testLoader: loader yielding dicts with 'data' and 'simParameters'. Its dataset's
            transform must provide ``normMean`` / ``normStd`` for de-normalization.
        device: device used for inference.
        num_steps_to_predict: number of frames to produce per sequence (context included);
            None or a value larger than the sequence length means the whole sequence.
        diffusion_steps: number of integration steps per sampled frame.
        method: sampler passed to ``model.sample`` ('euler' or 'heun').

    Returns:
        (gt_final, pred_final): de-normalized numpy arrays of identical shape
        (num_sequences, steps, out_channels, H, W).
    """
    model.eval()
    all_gt = []
    all_pred = []

    past_window = model.past_window  # typically 9
    n_out = getattr(model, "out_channels", 3)

    pbar_sequences = tqdm(testLoader, desc="Total Séquences", unit="seq")

    with torch.no_grad():
        for s, batch in enumerate(pbar_sequences):
            # data: [B, T, 4, 128, 64] | params: [B, T, 1]
            data = batch['data'].to(device)
            params = batch['simParameters'].to(device)

            B, T, C, H, W = data.shape
            # Never roll out past the frames available in the sequence (used to raise IndexError).
            steps = T if num_steps_to_predict is None else min(num_steps_to_predict, T)
            n_ctx = min(past_window, steps)

            # 1. Initialization: ground-truth context frames.
            prediction_phys = torch.zeros((B, steps, n_out, H, W), device=device)
            prediction_phys[:, :n_ctx] = data[:, :n_ctx, :n_out]

            # 2. Autoregressive loop over time.
            pbar_steps = tqdm(range(past_window, steps), desc=f"Seq {s+1} steps", leave=False, unit="step")
            for i in pbar_steps:
                # a. Conditioning: the past_window previous (predicted) frames ...
                history_phys = prediction_phys[:, i-past_window:i]
                # ... plus the dataset's Reynolds channel for the same frames.
                history_reynolds = data[:, i-past_window:i, n_out:n_out + 1]
                cond_frames = torch.cat([history_phys, history_reynolds], dim=2)

                # b. Scalar parameters (Reynolds).
                u_past = params[:, i-past_window:i, :]
                u_curr = params[:, i, :]

                # c. Diffusion sampling of frame i.
                prediction_phys[:, i] = model.sample(cond_frames, u_past, u_curr,
                                                     num_steps=diffusion_steps, method=method)

            # Keep ground truth and prediction the same length (they differed when
            # num_steps_to_predict < T).
            all_gt.append(data[:, :steps, :n_out].cpu().numpy())
            all_pred.append(prediction_phys.cpu().numpy())

    gt_phys = np.concatenate(all_gt, axis=0)
    pred_phys = np.concatenate(all_pred, axis=0)

    # 3. De-normalization of the physical channels: x * std + mean.
    transform = testLoader.dataset.transform
    normMean = transform.normMean[:n_out].reshape(1, 1, n_out, 1, 1)
    normStd = transform.normStd[:n_out].reshape(1, 1, n_out, 1, 1)

    gt_final = (gt_phys * normStd) + normMean
    pred_final = (pred_phys * normStd) + normMean

    print("\n✅ Rollout complet terminé !")
    return gt_final, pred_final

def calculate_mse(model, loader, device):
    """
    Roll the model out over every sequence of ``loader`` and report physical-space MSEs.

    The rollout (and its de-normalisation) is delegated to ``run_rollout``; this
    function only scores the returned arrays channel by channel.

    Args:
        model: Model forwarded unchanged to ``run_rollout``.
        loader: Test DataLoader forwarded unchanged to ``run_rollout``.
        device: Torch device forwarded unchanged to ``run_rollout``.

    Returns:
        dict with keys ``"global"``, ``"vx"``, ``"vy"``, ``"p"``: the MSE of
        channels 0 (velocity x), 1 (velocity y) and 2 (pressure), and the
        unweighted mean of those three values.
    """
    # Ground truth and prediction come back de-normalised (physical units).
    gt_phys, pred_phys = run_rollout(model, loader, device)

    # Channel index -> result key; each MSE averages over batch, time and pixels.
    per_channel = {
        key: np.mean((gt_phys[:, :, idx] - pred_phys[:, :, idx]) ** 2)
        for idx, key in enumerate(("vx", "vy", "p"))
    }
    mse_global = (per_channel["vx"] + per_channel["vy"] + per_channel["p"]) / 3

    report_lines = (
        ("MSE Globale :", mse_global),
        ("MSE Vel X   :", per_channel["vx"]),
        ("MSE Vel Y   :", per_channel["vy"]),
        ("MSE Pression:", per_channel["p"]),
    )
    print("\n--- Résultats ---")
    for label, value in report_lines:
        print(f"{label} {value}")

    return {"global": mse_global, **per_channel}

In [None]:
# Evaluate on test_loaders[0]; echo the metric dict like the sibling cells do.
res_low = calculate_mse(model, test_loaders[0], device)
print(res_low)

  with torch.cuda.amp.autocast(enabled=False):
Total Séquences: 100%|██████████| 5/5 [02:08<00:00, 25.74s/seq]


✅ Rollout complet terminé !

--- Résultats ---
MSE Globale : 0.07315181195735931
MSE Vel X   : 0.006337182596325874
MSE Vel Y   : 0.004584290087223053
MSE Pression: 0.20853395760059357





In [50]:
# Metrics on test_loaders[1] (rollout + per-channel MSE), then echo the dict.
res_high = calculate_mse(model, loader=test_loaders[1], device=device)
print(res_high)

  with torch.cuda.amp.autocast(enabled=False):
Total Séquences: 100%|██████████| 5/5 [02:11<00:00, 26.28s/seq]


✅ Rollout complet terminé !

--- Résultats ---
MSE Globale : 0.004852836020290852
MSE Vel X   : 0.0004129851295147091
MSE Vel Y   : 0.00035153795033693314
MSE Pression: 0.013793985359370708
{'global': np.float32(0.004852836), 'vx': np.float32(0.00041298513), 'vy': np.float32(0.00035153795), 'p': np.float32(0.013793985)}





In [51]:
# Metrics on test_loaders[2] (rollout + per-channel MSE), then echo the dict.
res_vac = calculate_mse(model, loader=test_loaders[2], device=device)
print(res_vac)

  with torch.cuda.amp.autocast(enabled=False):
Total Séquences: 100%|██████████| 1/1 [02:04<00:00, 124.17s/seq]


✅ Rollout complet terminé !

--- Résultats ---
MSE Globale : 0.11326124519109726
MSE Vel X   : 0.010605995543301105
MSE Vel Y   : 0.0082953330129385
MSE Pression: 0.3208824098110199
{'global': np.float32(0.113261245), 'vx': np.float32(0.010605996), 'vy': np.float32(0.008295333), 'p': np.float32(0.3208824)}





In [None]:
import matplotlib.pyplot as plt
import numpy as np

# --- CONFIGURATION ---
# NOTE(review): `gt` and `pred` are not defined in this section; they are assumed
# to be de-normalised rollout arrays shaped (B, T, C, H, W), e.g. the
# (gt_final, pred_final) pair returned by run_rollout -- TODO confirm.
sequence_idx = 0                 # which simulation (batch index) to display
field = 2                        # channel to display: 0 = VelX, 1 = VelY, 2 = pressure
FIELD_NAMES = ("Vel X", "Vel Y", "Pressure")
ROW_LABELS = ("Ground Truth", "Jit Pred")
# Time steps to display (each must be < T). For a full ~241-step rollout,
# e.g. [0, 79, 159, 239] spans the whole trajectory.
timeSteps = [1, 20, 40, 60]

# 1. Extract the selected field for the chosen sequence -> (len(timeSteps), H, W)
gt_vis = gt[sequence_idx, timeSteps, field]
pred_vis = pred[sequence_idx, timeSteps, field]

# Row 0 = ground truth, row 1 = prediction -> (2, len(timeSteps), H, W)
gt_pred_combined = np.stack([gt_vis, pred_vis], axis=0)

# One colour scale for every panel: without explicit limits imshow normalises
# each panel independently, so GT and prediction colours would not be comparable.
vmin = float(np.nanmin(gt_pred_combined))
vmax = float(np.nanmax(gt_pred_combined))

# 2. Figure
nrows, ncols = gt_pred_combined.shape[:2]
fig, axs = plt.subplots(
    nrows=nrows,
    ncols=ncols,
    figsize=(ncols * 3, nrows * 3.5),  # tuned for 128x64 frames
    dpi=120,
    squeeze=False,
    constrained_layout=True,  # also leaves room for the shared colorbar
)

# 3. Draw panels
for i in range(nrows):
    for j in range(ncols):
        ax = axs[i, j]
        frame = gt_pred_combined[i, j]  # named `frame` so no global `data` is clobbered

        # Transpose so the 128-long axis is horizontal; origin="lower" keeps the
        # bottom of the domain at the bottom of the image.
        im = ax.imshow(frame.T, interpolation="catrom", cmap="viridis",
                       origin="lower", vmin=vmin, vmax=vmax)

        # Column titles: time step
        if i == 0:
            ax.set_title(f"Step {timeSteps[j]}", fontsize=10, pad=10)

        # Row labels: ground truth vs prediction
        if j == 0:
            ax.set_ylabel(ROW_LABELS[i], fontsize=11, fontweight="bold")

        ax.set_xticks([])
        ax.set_yticks([])

# All panels share vmin/vmax, so a single colorbar describes every image.
fig.colorbar(im, ax=axs, shrink=0.8, label=FIELD_NAMES[field])
fig.suptitle(f"{FIELD_NAMES[field]} - sequence {sequence_idx}")
plt.show()