In [1]:
from data.module import DataLightningModule
from models.module import ModelLightningModule
from utils.objects.utils import EmptyObj
from utils.files.utils import pkl_load
import os
from results.utils import save_objects
import SDE.sampler as sampler
import torch
from tqdm import tqdm

In [2]:
IDs = ['y2025_m04_d02_09h_25m_13s', ...]

num_samples = 1

for _ in range(num_samples):
    for run_number in [0,1,2]:
        for ID in tqdm(IDs):
            storage_dir = "./storage"
            models_dir = os.path.join(storage_dir, "models")
            model_dir = os.path.join(models_dir, ID)
            checkpoint_path = os.path.join(model_dir, "checkpoint.ckpt")
            Cfg_path = os.path.join(model_dir, "Cfg.pkl")

            sampleCfg = pkl_load(Cfg_path)
            sampleCfg.cluster = False
            sampleCfg.work_dir = os.getcwd()

            dataModule = DataLightningModule(sampleCfg)
            dataModule.prepare_data(force_call=True, export=False)  # We still need the trainSet for ease of checkpoint loading
            dataModule.setup(stage="test")  # Prepare the trainSet

            print(sampleCfg.sdeCfg.sde_name)

            model_module = ModelLightningModule.load_from_checkpoint(checkpoint_path=checkpoint_path, Cfg=sampleCfg, data_set=dataModule.trainSet)

            model_module.eval()
            device = model_module.device

            sampleCfg.samplingCfg = EmptyObj()

            sampleCfg.samplingCfg.method = "pc"
            sampleCfg.samplingCfg.noise_removal = True
            sampleCfg.samplingCfg.predictor = "euler_maruyama"  # none euler_maruyama reverse_diffusion ancestral_sampling
            sampleCfg.samplingCfg.corrector = "none" # none langevin ald
            sampleCfg.samplingCfg.snr = .16
            sampleCfg.samplingCfg.n_steps_each = 1
            sampleCfg.samplingCfg.probability_flow = False

            dataset_name = sampleCfg.globalCfg.dataset_name
            print(dataset_name)
            print("Regularization:", sampleCfg.trainingCfg.energy_loss)


            _, sample = dataModule.trainSet[0]
            if not dataset_name == "MHD_64":
                _, num_fields, lx, ly = sample.shape
            else:
                _, num_fields, lx, ly, lz = sample.shape


            sampling_fn = sampler.get_sampling_fn(sampleCfg.samplingCfg, model_module.SDE, tuple(sample.shape), eps=1e-3)
            num_frames = 1

            GT, _ = dataModule.testSet[run_number]
            print(len(dataModule.testSet))
            pred = torch.zeros_like(GT)
            pred[:num_frames, ...] = GT[:num_frames, ...]  # Set the same first frames

            GT = GT.to(device)
            pred = pred.to(device)

            resh_fn = model_module.model.cond_format
            mask = dataModule.testSet.mask.to(model_module.device) if dataset_name == "JHTDB" else None
            c=0

            for i in range(num_frames, 61 if not dataset_name == "MHD_64" else 21):

                cond_frames = pred[i-num_frames:i, ...]
                cond = resh_fn(cond_frames.unsqueeze(0))
                sample, n_tot_iter = sampling_fn(model_module, cond=cond.to(model_module.device))
                pred[i, ...] = sample

                if dataset_name == "JHTDB":
                    pred[i, -1, ...] = GT[i, -1, ...]  # Mach is known, JHTDB only

                c+=1

            sample = dict(GT=GT, pred=pred, fields_names=dataModule.testSet.fields_names, mask=mask)
            save_objects(storage_path=storage_dir, sampleCfg=sampleCfg, data_dict=sample, test_case_number=run_number+1)

        print('--- Done ---')

  0%|          | 0/2 [00:00<?, ?it/s]

subvpsde
turbulent_radiative_layer_2D
Regularization: False
9


  0%|          | 0/2 [00:13<?, ?it/s]


KeyboardInterrupt: 