# Run Forecasts of the diffusion models

author: Randy Chase <br>
email: randy 'dot' chase 'at' colostate.edu 

This notebook runs forecasts with each of the trained models. Make sure the datasets and model checkpoints are available locally before running it. If you want to use the results from the paper, they are available on Dryad.

### 1) Load some packages and the dataset

In [1]:
#imports 
import zarr
import torch
from torch.utils.data import Dataset, DataLoader
from diffusers import UNet2DModel
from PIL import Image
import numpy as np
import matplotlib.pyplot as plt 
import gc
import os 
from tqdm import tqdm
from pathlib import Path
import os
import math
from typing import List, Optional, Tuple, Union
from diffusers.utils.torch_utils import randn_tensor
import tqdm 
import time as timer 
from diffusers import AutoencoderKL


class ZarrDataset(Dataset):
    """PyTorch ``Dataset`` backed by a zarr store of paired input/output images.

    The store is opened read-only when the dataset is created. A sample's arrays
    are read from the ``'input_images'`` and ``'output_images'`` entries only when
    that sample is indexed.
    NOTE(review): may need optimizing if reads from the store turn out to be slow.
    """

    def __init__(self, zarr_store):
        # whatever zarr.open accepts (a path in this notebook)
        self.store = zarr_store
        self.data = zarr.open(self.store, mode='r')
        # number of samples = leading dimension of the input array
        self.length = self.data['input_images'].shape[0]

    def __len__(self):
        return self.length

    def __getitem__(self, idx):
        """Return ``(output_image, input_image)`` for sample ``idx`` as float16 tensors.

        Note the order: the target (output) comes first and the condition (input) second.
        """
        store = self.data
        # read the condition first, then the target
        condition = store['input_images'][idx]
        target = store['output_images'][idx]
        return tuple(torch.tensor(arr, dtype=torch.float16) for arr in (target, condition))
    
def to_K(images):
    """Convert standardized values (mean 0, std 1) back to brightness temperature.

    Inverts the standardization using the stored mean and standard deviation,
    i.e. returns ``images * std + mean``. Accepts anything that supports ``*``
    and ``+`` with Python floats (scalars, numpy arrays, torch tensors).
    """
    # statistics used for the standardization
    mean_K, std_K = 279.0699458792467, 19.32967519050003
    return mean_K + images * std_K

# Initialize the main (full-resolution) dataset. Uncomment the split you want:
# training dataset
# zarr_store = '/mnt/data1/rchas1/TRANSITION/datasets/edm_GOES_ch13_training_dataset.zarr'
# validation dataset
# zarr_store = '/mnt/data1/rchas1/TRANSITION/datasets/edm_GOES_ch13_validation_dataset.zarr'
# test dataset
zarr_store = './datasets/edm_GOES_ch13_test_dataset.zarr'
dataset = ZarrDataset(zarr_store)

# This is the largest batch size that worked when generating 10 ensemble members on the GH200;
# it may need to change depending on the GPU.
batch_size = 10

# NOTE: despite its name, this loader iterates whichever split was selected above (the test set here).
# shuffle=False keeps the samples in store order.
train_dataloader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=False)

# Load the latent dataset (the data were pre-compressed with the radames VAE to save time).
# NOTE(review): presumably this must be the same split as above so both loaders stay aligned
# batch-for-batch (both use shuffle=False and the same batch_size) -- confirm.
# validation
# zarr_store = '/mnt/data1/rchas1/diffusion_10_4_2inputs_2024validation_latent_radames_v3.zarr/'
# test
zarr_store = './datasets/edm_GOES_ch13_test_dataset_latent.zarr'

# `dataset` is rebound here to the latent dataset.
dataset = ZarrDataset(zarr_store)
latent_train_dataloader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=False)


  from .autonotebook import tqdm as notebook_tqdm


### 2) Get the classes 
These are the standard EDM classes from NVIDIA's reference implementation (Karras et al., 2022), adapted to take our conditioning images and to match the setup used in the paper.

In [2]:
class EDMPrecond(torch.nn.Module):
    """Conditional EDM preconditioning wrapper for a diffusers model.

    Adapted from NVIDIA's reference implementation:
    https://github.com/NVlabs/edm/blob/008a4e5316c8e3bfe61a62f874bddba254295afb/training/networks.py#L519

    It applies the skip/output/input/noise preconditioning of Karras et al. (2022)
    ("EDM"). The original is unconditional. Here the network input carries the
    condition images after the noisy images along the channel axis, and only the
    noisy channels are preconditioned.

    Args:
        generation_channels: number of leading channels of ``x`` that are being generated.
        model: diffusers model, called as ``model(inputs, c_noise, return_dict=False)[0]``.
        use_fp16: run ``model`` in float16 (only applies to CUDA inputs).
        sigma_min: minimum supported noise level (read by ``edm_sampler``).
        sigma_max: maximum supported noise level (read by ``edm_sampler``).
        sigma_data: expected standard deviation of the training data.
    """

    def __init__(
        self,
        generation_channels,
        model,
        use_fp16=False,
        sigma_min=0,
        sigma_max=float('inf'),
        sigma_data=0.5,
    ):
        super().__init__()
        self.generation_channels = generation_channels
        self.model = model
        self.use_fp16 = use_fp16
        self.sigma_min = sigma_min
        self.sigma_max = sigma_max
        self.sigma_data = sigma_data

    def forward(self, x, sigma, force_fp32=False, **model_kwargs):
        """Denoise the generated channels of ``x`` at noise level ``sigma``.

        Args:
            x: [batch, generation_channels + condition_channels, nx, ny]. The
                generated (already noisy) channels come first and the condition after.
            sigma: noise level per sample (anything reshapeable to [-1, 1, 1, 1]).
            force_fp32: keep the model call in float32 even if ``use_fp16`` is set.
            **model_kwargs: accepted for API compatibility; not used.

        Returns:
            float32 tensor D(x; sigma) with the shape of ``x[:, :generation_channels]``.
        """
        x = x.to(torch.float32)
        # one sigma per sample, broadcastable over (channel, nx, ny)
        sigma = sigma.to(torch.float32).reshape(-1, 1, 1, 1)

        run_half = self.use_fp16 and not force_fp32 and x.device.type == 'cuda'
        dtype = torch.float16 if run_half else torch.float32

        # EDM preconditioning coefficients
        sd = self.sigma_data
        total_var = sigma ** 2 + sd ** 2
        c_skip = sd ** 2 / total_var
        c_out = sigma * sd / total_var.sqrt()
        c_in = 1 / total_var.sqrt()
        c_noise = sigma.log() / 4

        n_gen = self.generation_channels
        x_noisy = x[:, :n_gen].clone()
        x_condition = x[:, n_gen:].clone()

        # scale only the noisy channels; the condition is passed through unchanged
        net_in = torch.cat([x_noisy * c_in, x_condition], dim=1).to(dtype)
        F_x = self.model(net_in, c_noise.flatten(), return_dict=False)[0]
        assert F_x.dtype == dtype

        # the skip term uses the unscaled noisy channels (not x * c_in)
        return c_skip * x_noisy + c_out * F_x.to(torch.float32)

    def round_sigma(self, sigma):
        """Return ``sigma`` as a tensor (no rounding is applied)."""
        return torch.as_tensor(sigma)
    
class EDMPrecond_TF(torch.nn.Module):
    """Conditional EDM preconditioning wrapper that also passes class labels to the model.

    Adapted from NVIDIA's reference implementation:
    https://github.com/NVlabs/edm/blob/008a4e5316c8e3bfe61a62f874bddba254295afb/training/networks.py#L519

    It applies the skip/output/input/noise preconditioning of Karras et al. (2022)
    ("EDM"). The original is unconditional. Here the network input carries the
    condition images after the noisy images along the channel axis, and only the
    noisy channels are preconditioned. The wrapped model is additionally given
    an all-zero ``class_labels`` tensor (int32, one entry per sample).

    Args:
        generation_channels: number of leading channels of ``x`` that are being generated.
        model: diffusers model, called as
            ``model(inputs, c_noise, class_labels=..., return_dict=False)[0]``.
        use_fp16: run ``model`` in float16 (only applies to CUDA inputs).
        sigma_min: minimum supported noise level (read by ``edm_sampler``).
        sigma_max: maximum supported noise level (read by ``edm_sampler``).
        sigma_data: expected standard deviation of the training data.
    """

    def __init__(
        self,
        generation_channels,
        model,
        use_fp16=False,
        sigma_min=0,
        sigma_max=float('inf'),
        sigma_data=0.5,
    ):
        super().__init__()
        self.generation_channels = generation_channels
        self.model = model
        self.use_fp16 = use_fp16
        self.sigma_min = sigma_min
        self.sigma_max = sigma_max
        self.sigma_data = sigma_data

    def forward(self, x, sigma, force_fp32=False, **model_kwargs):
        """Denoise the generated channels of ``x`` at noise level ``sigma``.

        Args:
            x: [batch, generation_channels + condition_channels, nx, ny]. The
                generated (already noisy) channels come first and the condition after.
            sigma: noise level per sample (anything reshapeable to [-1, 1, 1, 1]).
            force_fp32: keep the model call in float32 even if ``use_fp16`` is set.
            **model_kwargs: accepted for API compatibility; not used.

        Returns:
            float32 tensor D(x; sigma) with the shape of ``x[:, :generation_channels]``.
        """
        x = x.to(torch.float32)
        # one sigma per sample, broadcastable over (channel, nx, ny)
        sigma = sigma.to(torch.float32).reshape(-1, 1, 1, 1)

        run_half = self.use_fp16 and not force_fp32 and x.device.type == 'cuda'
        dtype = torch.float16 if run_half else torch.float32

        # EDM preconditioning coefficients
        sd = self.sigma_data
        total_var = sigma ** 2 + sd ** 2
        c_skip = sd ** 2 / total_var
        c_out = sigma * sd / total_var.sqrt()
        c_in = 1 / total_var.sqrt()
        c_noise = sigma.log() / 4

        n_gen = self.generation_channels
        x_noisy = x[:, :n_gen].clone()
        x_condition = x[:, n_gen:].clone()

        # scale only the noisy channels; the condition is passed through unchanged
        stacked = torch.cat([x_noisy * c_in, x_condition], dim=1)
        # every sample gets class label 0
        labels = torch.zeros(stacked.shape[0], dtype=torch.int, device=stacked.device)
        F_x = self.model(stacked.to(dtype), c_noise.flatten(),
                         class_labels=labels,
                         return_dict=False)[0]
        assert F_x.dtype == dtype

        # the skip term uses the unscaled noisy channels (not x * c_in)
        return c_skip * x_noisy + c_out * F_x.to(torch.float32)

    def round_sigma(self, sigma):
        """Return ``sigma`` as a tensor (no rounding is applied)."""
        return torch.as_tensor(sigma)

class EDMLoss:
    """Conditional version of the EDM training loss (Karras et al., 2022).

    Adapted from https://github.com/NVlabs/edm/blob/008a4e5316c8e3bfe61a62f874bddba254295afb/training/loss.py.
    The only change is that ``__call__`` takes the clean images and the condition
    images separately. ``net`` is expected to be an ``EDMPrecond``-style wrapper,
    called as ``net(stacked_images, sigma)``.

    Args:
        P_mean, P_std: mean and std of the normal distribution of log(sigma) that
            training noise levels are drawn from.
        sigma_data: expected standard deviation of the training data.
    """

    def __init__(self, P_mean=-1.2, P_std=1.2, sigma_data=0.5):
        self.P_mean = P_mean
        self.P_std = P_std
        self.sigma_data = sigma_data

    def __call__(self, net, clean_images, condition_images, labels=None, augment_pipe=None):
        """Return the per-pixel weighted denoising loss (unreduced).

        Args:
            net: model wrapped with EDMPrecond.
            clean_images: images to generate, [batch, generation_channels, nx, ny].
            condition_images: conditioning images, [batch, condition_channels, nx, ny].
            labels, augment_pipe: accepted for compatibility with the EDM code; unused.

        Reducing to a scalar (mean over pixels/GPUs) is left to the training loop.
        """
        n_batch = clean_images.shape[0]

        # one log-normal noise level per sample, broadcastable over (C, nx, ny)
        log_sigma = torch.randn([n_batch, 1, 1, 1], device=clean_images.device) * self.P_std + self.P_mean
        sigma = log_sigma.exp()

        # EDM loss weighting lambda(sigma)
        weight = (sigma ** 2 + self.sigma_data ** 2) / (sigma * self.sigma_data) ** 2

        noisy = clean_images + torch.randn_like(clean_images) * sigma
        denoised = net(torch.cat([noisy, condition_images], dim=1), sigma)

        return weight * ((denoised - clean_images) ** 2)
    
class StackedRandomGenerator:  # pragma: no cover
    """A batch of independently seeded ``torch.Generator`` objects.

    Each entry along the leading (batch) dimension is drawn from its own
    generator, so a sample's random numbers depend only on its own seed.

    Args:
        device: device the generators live on (e.g. 'cuda' or 'cpu').
        seeds: iterable of integer seeds, one per sample; each is reduced modulo 2**32.
    """

    def __init__(self, device, seeds):
        super().__init__()
        self.generators = []
        for seed in seeds:
            gen = torch.Generator(device)
            gen.manual_seed(int(seed) % (1 << 32))
            self.generators.append(gen)

    def _check_batch(self, size):
        # the leading dimension must match the number of per-sample generators
        n_gen = len(self.generators)
        if size[0] != n_gen:
            raise ValueError(
                f"Expected first dimension of size {n_gen}, got {size[0]}"
            )

    def randn(self, size, **kwargs):
        """Standard-normal tensor of shape ``size``; row ``i`` comes from generator ``i``."""
        self._check_batch(size)
        per_sample = size[1:]
        draws = [torch.randn(per_sample, generator=g, **kwargs) for g in self.generators]
        return torch.stack(draws)

    def randn_like(self, input):
        """``randn`` with the shape, dtype, layout and device of ``input``."""
        return self.randn(input.shape, dtype=input.dtype, layout=input.layout, device=input.device)

    def randint(self, *args, size, **kwargs):
        """Per-sample ``torch.randint``; ``size[0]`` must equal the number of generators."""
        self._check_batch(size)
        per_sample = size[1:]
        draws = [torch.randint(*args, size=per_sample, generator=g, **kwargs) for g in self.generators]
        return torch.stack(draws)
    
def edm_sampler(
    net, latents, condition_images, class_labels=None, randn_like=torch.randn_like,
    num_steps=18, sigma_min=0.002, sigma_max=80, rho=7,
    S_churn=0, S_min=0, S_max=float('inf'), S_noise=1,
):
    """Conditional EDM sampler (Euler step + Heun 2nd-order correction).

    Adapted from https://github.com/NVlabs/edm/blob/008a4e5316c8e3bfe61a62f874bddba254295afb/generate.py.
    The only change is that ``condition_images`` is concatenated onto the current
    sample along the channel dimension before every network call.

    Args:
        net: EDMPrecond-style wrapper exposing ``sigma_min``, ``sigma_max``,
            ``round_sigma`` and ``net(stacked_images, sigma)``.
        latents: unit-variance noise; multiplied by the first noise level to start.
        condition_images: conditioning channels, concatenated on dim=1.
        class_labels: accepted for compatibility with the EDM code; unused here.
        randn_like: noise source for the stochastic "churn" step.
        num_steps, sigma_min, sigma_max, rho: noise-level schedule.
        S_churn, S_min, S_max, S_noise: stochasticity controls. The per-step churn is
            gamma = min(S_churn / num_steps, sqrt(2) - 1) for t in [S_min, S_max].
            S_churn=0 injects no extra noise.

    Returns:
        The final sample as a float64 tensor.
    """
    # Restrict the requested noise range to what the network supports.
    sigma_min = max(sigma_min, net.sigma_min)
    sigma_max = min(sigma_max, net.sigma_max)

    # Schedule: num_steps levels spaced linearly in sigma**(1/rho) from sigma_max
    # down to sigma_min, followed by a final level of 0.
    inv_rho = 1 / rho
    hi = sigma_max ** inv_rho
    lo = sigma_min ** inv_rho
    frac = torch.arange(num_steps, dtype=torch.float64, device=latents.device) / (num_steps - 1)
    sigmas = (hi + frac * (lo - hi)) ** rho
    sigmas = torch.cat([net.round_sigma(sigmas), torch.zeros_like(sigmas[:1])])

    def denoise(sample, sigma):
        # The condition is stacked after the sample on the channel axis for every call.
        stacked = torch.cat([sample, condition_images], dim=1)
        with torch.no_grad():
            return net(stacked, sigma).to(torch.float64)

    churn_cap = np.sqrt(2) - 1
    x_next = latents.to(torch.float64) * sigmas[0]
    for i in range(len(sigmas) - 1):
        t_cur = sigmas[i]
        t_next = sigmas[i + 1]
        x_cur = x_next

        # Stochastic "churn": temporarily raise the noise level from t_cur to t_hat.
        in_window = S_min <= t_cur <= S_max
        gamma = min(S_churn / num_steps, churn_cap) if in_window else 0
        t_hat = net.round_sigma(t_cur + gamma * t_cur)
        x_hat = x_cur + (t_hat ** 2 - t_cur ** 2).sqrt() * S_noise * randn_like(x_cur)

        # Euler step from t_hat to t_next.
        d_cur = (x_hat - denoise(x_hat, t_hat)) / t_hat
        x_next = x_hat + (t_next - t_hat) * d_cur

        # Heun (2nd-order) correction; skipped on the final step, whose t_next is 0.
        if i < num_steps - 1:
            d_prime = (x_next - denoise(x_next, t_next)) / t_next
            x_next = x_hat + (t_next - t_hat) * (0.5 * d_cur + 0.5 * d_prime)

    return x_next


### 3) Define wrapper functions

These functions wrap each model and run the autoregressive forecast loop for you. They share most of their logic, but differ in tensor shapes and in the pre-/post-processing each model needs (e.g., the CorrDiff UNet first guess and the latent-space VAE decoding).

In [3]:
def run_forecast_diff_ens(model, images_condition, ens_size=10, same_seed=True,
                          time_steps_forward=np.arange(0, 18), S_churn=np.sqrt(2) - 1,
                          num_steps=18, sigma_max=80, rho=7, device='cuda'):
    """Run an autoregressive ensemble forecast with the plain (pixel-space) EDM diffusion model.

    At every step, the two most recent frames of each ensemble member are the
    condition used to generate the next frame, which is appended to the sequence.

    INPUTS
    model: the EDMPrecond-wrapped pytorch model.
    images_condition: tensor [batch, 2, nx, ny] holding the two frames that start the forecast.
    ens_size: number of ensemble members (default 10).
    same_seed: if True (the default), the same seeds are reused at every forecast step, so
        each member draws identical initial noise at each step. If False, every step uses its
        own block of seeds that does not overlap any other step's or member's.
        NOTE(review): the author reported that reusing seeds broke the autoregression,
        yet True is the default -- confirm which setting is intended.
    time_steps_forward: step indices 0..T-1 (each value indexes the output sequence).
        The default of 18 ten-minute steps covers the 3 h used in the paper.
    device: device for the model inputs and the random generators (default 'cuda').

    Karras et al. parameters, passed through to edm_sampler:
    S_churn: amount of extra noise injected at each sampling step; 0 turns it off.
        edm_sampler uses gamma = min(S_churn / num_steps, sqrt(2) - 1), so the default
        sqrt(2) - 1 gives a per-step gamma well below that cap.
    num_steps: number of diffusion (denoising) steps.
    sigma_max: starting noise level.
    rho: controls how the noise levels are spaced.

    OUTPUTS
    full_image_list: CPU float32 tensor [batch, ens_size, 2 + len(time_steps_forward), nx, ny].
        The first 2 frames are the inputs; the rest are the ML forecasts.

    NOTE(review): only the initial latents come from the seeded generators. The churn noise
    inside edm_sampler is drawn from torch's global RNG (randn_like is not passed), so outputs
    are not fully reproducible from these seeds alone.
    """
    n_samples = images_condition.shape[0]
    nx, ny = images_condition.shape[2], images_condition.shape[3]
    n_members = n_samples * ens_size
    n_steps = len(time_steps_forward)

    # Full sequence per member; the first two frames are the (repeated) input condition.
    full_image_list = torch.zeros([n_samples, ens_size, n_steps + 2, nx, ny])
    full_image_list[:, :, 0:2] = images_condition.unsqueeze(1).repeat((1, ens_size, 1, 1, 1)).to('cpu')

    # One seed per (sample, member) pair.
    base_seeds = np.arange(n_members)
    for time in tqdm.tqdm(time_steps_forward):
        # Offsetting by time * n_members keeps the seed blocks of different steps disjoint.
        seeds = base_seeds if same_seed else base_seeds + int(time) * n_members
        rnd = StackedRandomGenerator(device, seeds.tolist())

        # initial noise for this step
        latents = rnd.randn([n_members, 1, nx, ny], device=device)

        # the two most recent frames, flattened to [batch*ens_size, 2, nx, ny] for the network
        condition = full_image_list[:, :, time:time + 2].clone().to(device=device, dtype=torch.float32)
        condition = condition.reshape([n_members, 2, nx, ny])

        # generate the next (10 min) frame
        images_ens = edm_sampler(model, latents, condition, num_steps=num_steps,
                                 S_churn=S_churn, sigma_max=sigma_max, rho=rho)

        # back to [batch, ens_size, 1, nx, ny] and append to the sequence
        images_ens = images_ens.reshape([n_samples, ens_size, 1, nx, ny])
        full_image_list[:, :, time + 2:time + 3] = images_ens.cpu()

    return full_image_list

def run_forecast_corrdiff_ens(model_diff, model_unet, images_condition, same_seed=True, ens_size=10,
                              time_steps_forward=np.arange(0, 18), S_churn=np.sqrt(2) - 1,
                              num_steps=18, sigma_max=80, rho=7, device='cuda'):
    """Run an autoregressive ensemble forecast with the CorrDiff (UNet + EDM residual) models.

    At every step, a deterministic UNet first predicts the next frame from the two most
    recent frames. The diffusion model then generates a residual, conditioned on those
    frames and the UNet prediction, which is added to the UNet prediction.

    INPUTS
    model_diff: the EDMPrecond-wrapped diffusion model that generates the residual.
    model_unet: the UNet that gives the 'first guess' forecast for the CorrDiff approach.
    images_condition: tensor [batch, 2, nx, ny] holding the two frames that start the forecast.
    same_seed: if True (the default), the same seeds are reused at every forecast step, so
        each member draws identical initial noise at each step. If False, every step uses its
        own block of seeds that does not overlap any other step's or member's.
        NOTE(review): the author reported that reusing seeds broke the autoregression,
        yet True is the default -- confirm which setting is intended.
    ens_size: number of ensemble members (default 10).
    time_steps_forward: step indices 0..T-1 (each value indexes the output sequence).
        The default of 18 ten-minute steps covers the 3 h used in the paper.
    device: device for the model inputs and the random generators (default 'cuda').

    Karras et al. parameters, passed through to edm_sampler:
    S_churn: amount of extra noise injected at each sampling step; 0 turns it off.
        edm_sampler uses gamma = min(S_churn / num_steps, sqrt(2) - 1), so the default
        sqrt(2) - 1 gives a per-step gamma well below that cap.
    num_steps: number of diffusion (denoising) steps.
    sigma_max: starting noise level.
    rho: controls how the noise levels are spaced.

    OUTPUTS
    full_image_list: CPU float32 tensor [batch, ens_size, 2 + len(time_steps_forward), nx, ny].
        The first 2 frames are the inputs; the rest are the ML forecasts.

    NOTE(review): only the initial latents come from the seeded generators. The churn noise
    inside edm_sampler is drawn from torch's global RNG (randn_like is not passed), so outputs
    are not fully reproducible from these seeds alone.
    """
    n_samples = images_condition.shape[0]
    nx, ny = images_condition.shape[2], images_condition.shape[3]
    n_members = n_samples * ens_size
    n_steps = len(time_steps_forward)

    # Mean/std used to standardize the diffusion residual for training; undone below.
    resid_mean = torch.tensor(-0.0009)
    resid_std = torch.tensor(0.0807)

    # Full sequence per member; the first two frames are the (repeated) input condition.
    full_image_list = torch.zeros([n_samples, ens_size, n_steps + 2, nx, ny])
    full_image_list[:, :, 0:2] = images_condition.unsqueeze(1).repeat((1, ens_size, 1, 1, 1)).to('cpu')

    # One seed per (sample, member) pair.
    base_seeds = np.arange(n_members)
    for time in tqdm.tqdm(time_steps_forward):
        # Offsetting by time * n_members keeps the seed blocks of different steps disjoint.
        seeds = base_seeds if same_seed else base_seeds + int(time) * n_members
        rnd = StackedRandomGenerator(device, seeds.tolist())

        # initial noise for this step
        latents = rnd.randn([n_members, 1, nx, ny], device=device)

        # the two most recent frames, flattened to [batch*ens_size, 2, nx, ny] for the networks
        condition = full_image_list[:, :, time:time + 2].clone().to(device=device, dtype=torch.float32)
        condition = condition.reshape([n_members, 2, nx, ny])

        # STEP 1: deterministic UNet 'first guess'. The UNet takes a timestep argument,
        # so a dummy all-zero one is passed.
        with torch.no_grad():
            images_unet = model_unet(condition, torch.zeros(n_members, device=condition.device),
                                     return_dict=False)[0]

        # STEP 2: diffusion generates the residual, conditioned on [frames, first guess]
        # -> condition shape [batch*ens_size, 3, nx, ny]
        diff_condition = torch.cat([condition, images_unet], dim=1)
        images_diff = edm_sampler(model_diff, latents, diff_condition, num_steps=num_steps,
                                  S_churn=S_churn, sigma_max=sigma_max, rho=rho)
        images_diff = images_diff * resid_std + resid_mean

        # back to [batch, ens_size, 1, nx, ny]
        images_diff = images_diff.reshape([n_samples, ens_size, 1, nx, ny])
        images_unet = images_unet.reshape([n_samples, ens_size, 1, nx, ny])

        # forecast = first guess + residual; append to the sequence
        full_image_list[:, :, time + 2:time + 3] = images_unet.cpu() + images_diff.cpu()

    return full_image_list

def run_forecast_ldm_ens(model, images_condition, images_condition_latent, vae, same_seed=True,
                         ens_size=10, time_steps_forward=np.arange(0, 18, dtype=int),
                         S_churn=np.sqrt(2) - 1, num_steps=18, sigma_max=80, rho=7, device='cuda'):
    """Run an autoregressive ensemble forecast with the latent-space EDM diffusion model.

    Forecasting happens in the VAE latent space, where each frame occupies 4 channels.
    Each generated latent frame is kept for the autoregression and is also decoded with
    the VAE to produce the full-resolution forecast frame.

    INPUTS
    model: the EDMPrecond-wrapped pytorch model.
    images_condition: tensor [batch, 2, nx, ny] holding the two frames that start the forecast.
    images_condition_latent: the same two frames pre-compressed through the VAE to save time,
        [batch, 8, lx, ly] (4 latent channels per frame).
    vae: the Hugging Face pretrained pytorch VAE (its ``decode`` output has a ``.sample``).
    same_seed: if True (the default), the same seeds are reused at every forecast step, so
        each member draws identical initial noise at each step. If False, every step uses its
        own block of seeds that does not overlap any other step's or member's.
        NOTE(review): the author reported that reusing seeds broke the autoregression,
        yet True is the default -- confirm which setting is intended.
    ens_size: number of ensemble members (default 10).
    time_steps_forward: step indices 0..T-1 (each value indexes the output sequence).
        The default of 18 ten-minute steps covers the 3 h used in the paper.
    device: device for the model inputs and the random generators (default 'cuda').

    Karras et al. parameters, passed through to edm_sampler:
    S_churn: amount of extra noise injected at each sampling step; 0 turns it off.
        edm_sampler uses gamma = min(S_churn / num_steps, sqrt(2) - 1), so the default
        sqrt(2) - 1 gives a per-step gamma well below that cap.
    num_steps: number of diffusion (denoising) steps.
    sigma_max: starting noise level.
    rho: controls how the noise levels are spaced.

    OUTPUTS
    full_image_list: CPU float32 tensor [batch, ens_size, 2 + len(time_steps_forward), nx, ny].
        The first 2 frames are the inputs; the rest are the decoded ML forecasts.

    NOTE(review): only the initial latents come from the seeded generators. The churn noise
    inside edm_sampler is drawn from torch's global RNG (randn_like is not passed), so outputs
    are not fully reproducible from these seeds alone.
    """
    n_samples = images_condition.shape[0]
    nx, ny = images_condition.shape[2], images_condition.shape[3]
    lx, ly = images_condition_latent.shape[2], images_condition_latent.shape[3]
    n_members = n_samples * ens_size
    n_steps = len(time_steps_forward)

    # Mean/std used to standardize the latents for the denoiser; undone before decoding.
    condition_std = torch.tensor(6.0271)
    condition_mean = torch.tensor(-4.2534)
    # Min/max of the 0-1 scaling used for the VAE; undone after decoding.
    data_min = -6.16015625
    data_max = 3.15234375

    # Full-resolution sequence per member; the first two frames are the input condition.
    full_image_list = torch.zeros([n_samples, ens_size, n_steps + 2, nx, ny])
    full_image_list[:, :, 0:2] = images_condition.unsqueeze(1).repeat((1, ens_size, 1, 1, 1)).to('cpu')

    # Latent sequence per member: 4 channels per frame, so the 2 condition frames take
    # channels 0:8 and each generated frame adds 4 more.
    latent_image_list = torch.zeros([n_samples, ens_size, 8 + n_steps * 4, lx, ly])
    latent_image_list[:, :, 0:8] = images_condition_latent.unsqueeze(1).repeat((1, ens_size, 1, 1, 1)).to('cpu')

    # One seed per (sample, member) pair.
    base_seeds = np.arange(n_members)
    for time in tqdm.tqdm(time_steps_forward):
        # Offsetting by time * n_members keeps the seed blocks of different steps disjoint.
        seeds = base_seeds if same_seed else base_seeds + int(time) * n_members
        rnd = StackedRandomGenerator(device, seeds.tolist())

        # initial noise for this step: 4 latent channels
        latents = rnd.randn([n_members, 4, lx, ly], device=device)

        # latent channels of the two most recent frames -> [batch*ens_size, 8, lx, ly]
        start = time * 4
        condition = latent_image_list[:, :, start:start + 8].clone().to(device)
        condition = condition.reshape([n_members, 8, lx, ly])

        # generate the next (10 min) latent frame
        images_batch = edm_sampler(model, latents, condition, num_steps=num_steps,
                                   S_churn=S_churn, sigma_max=sigma_max, rho=rho)

        # keep the (still standardized) latent for the next step's condition
        latent_image_list[:, :, start + 8:start + 12] = images_batch.reshape([n_samples, ens_size, 4, lx, ly]).cpu()

        # undo the mean-0/std-1 scaling before decoding
        images_batch = (images_batch * condition_std) + condition_mean

        # decode with the pretrained VAE; input is [batch*ens_size, 4, lx, ly]
        with torch.no_grad():
            decoded = vae.decode(images_batch.to(torch.float))
            # The VAE is built for RGB images; average its channel dim back to one channel
            # -> [batch*ens_size, 1, nx, ny]
            decoded = decoded.sample.mean(dim=1).unsqueeze(1)

        # undo the 0-1 (min/max) scaling used for the VAE
        decoded = decoded * (data_max - data_min) + data_min

        # back to [batch, ens_size, 1, nx, ny] and append to the sequence
        decoded = decoded.reshape([n_samples, ens_size, 1, nx, ny])
        full_image_list[:, :, time + 2:time + 3] = decoded.cpu()

    return full_image_list

def run_forecast_unet(model, images_condition, time_steps_forward=np.arange(0, 18), device='cuda'):
    """Run an autoregressive forecast with a plain (deterministic) UNet.

    At every step, the two most recent frames are fed to the UNet and its output
    is appended as the next frame.

    INPUTS
    model: the pytorch UNet (from diffusers here), called as
        ``model(frames, timestep, return_dict=False)[0]``.
    images_condition: tensor [batch, 2, nx, ny] holding the two frames that start the forecast.
    time_steps_forward: step indices 0..T-1 (each value indexes the output sequence).
    device: device to run the model inputs on (default 'cuda').

    OUTPUTS
    full_image_list: CPU float32 tensor [batch, 2 + len(time_steps_forward), nx, ny].
        The first 2 frames are the inputs; the rest are the ML forecasts.
    """
    n_samples = images_condition.shape[0]
    nx, ny = images_condition.shape[2], images_condition.shape[3]

    full_image_list = torch.zeros([n_samples, len(time_steps_forward) + 2, nx, ny])
    full_image_list[:, 0:2] = images_condition.to('cpu')

    for time in tqdm.tqdm(time_steps_forward):
        frames = full_image_list[:, time:time + 2].clone().to(device=device, dtype=torch.float32)

        with torch.no_grad():
            # the UNet takes a timestep argument; a dummy all-zero one is passed
            pred = model(frames, torch.zeros(frames.shape[0], device=frames.device), return_dict=False)[0]

        full_image_list[:, time + 2:time + 3] = pred.cpu()

    return full_image_list

### 4) Load in the models!

In [4]:
###################################
###### Load plain EDM model #######
###################################

def _build_edm_unet(in_channels):
    """Build the pixel-space UNet2DModel architecture shared by the plain-EDM and CorrDiff models.

    The two models differ only in ``in_channels``. Everything else (256x256 images,
    1 output channel, block layout) is identical, so it is defined once here.
    """
    return UNet2DModel(
        sample_size=256,  # the target image resolution
        in_channels=in_channels,  # the number of input channels
        out_channels=1,  # the number of output channels
        layers_per_block=2,  # how many ResNet layers to use per UNet block
        block_out_channels=(128, 128, 256, 256, 512, 512),  # the number of output channels for each UNet block
        down_block_types=(
            "DownBlock2D",  # a regular ResNet downsampling block
            "DownBlock2D",
            "DownBlock2D",
            "DownBlock2D",
            "AttnDownBlock2D",  # a ResNet downsampling block with spatial self-attention
            "DownBlock2D",
        ),
        up_block_types=(
            "UpBlock2D",  # a regular ResNet upsampling block
            "AttnUpBlock2D",  # a ResNet upsampling block with spatial self-attention
            "UpBlock2D",
            "UpBlock2D",
            "UpBlock2D",
            "UpBlock2D",
        ),
    )

#build arch: 3 input channels = noisy image + 2 condition frames; output is one 10 min forecast frame
model = _build_edm_unet(in_channels=3)
#get weights
#NOTE(review): torch.load unpickles the file -- only load checkpoints from trusted sources
checkpoint = torch.load("./models/edm_plain_diffusion/checkpoint.pth", map_location='cuda')
#load weights
model.load_state_dict(checkpoint['model_state_dict'])
#put it on the gpu
model.to('cuda')
#inference only: nn.Modules start in training mode, so switch to eval mode
model.eval()
#wrap it with the precondition
model_wrapped = EDMPrecond(1,model)

########################################
######### Load CorrDiff model ##########
########################################

#build arch: 4 input channels = noisy residual + 2 condition frames + 1 UNet forecast
model_corrdiff = _build_edm_unet(in_channels=4)
#get weights
checkpoint = torch.load("./models/edm_corrdiff/checkpoint.pth", map_location='cuda')
#load weights
model_corrdiff.load_state_dict(checkpoint['model_state_dict'])
#put it on the gpu
model_corrdiff.to('cuda')
#inference only
model_corrdiff.eval()
#wrap it with the precondition
model_corrdiff_wrapped = EDMPrecond(1,model_corrdiff)


########################################
############ Load LDM model ############
########################################

#build arch 
model_ldm = UNet2DModel(
    sample_size=64,  # the target image resolution
    in_channels=12,  # Expands alot because the latent space adds 4 channels per one input, (4*3)
    out_channels=4,  # the number of output channels
    layers_per_block=2,  # how many ResNet layers to use per UNet block
    block_out_channels=(128, 128, 256, 256, 512, 512),  # the number of output channels for each UNet block
    down_block_types=(
        "DownBlock2D",  # a regular ResNet downsampling block
        "DownBlock2D",
        "DownBlock2D",
        "DownBlock2D",
        "AttnDownBlock2D",  # a ResNet downsampling block with spatial self-attention
        "DownBlock2D",
    ),
    up_block_types=(
        "UpBlock2D",  # a regular ResNet upsampling block
        "AttnUpBlock2D",  # a ResNet upsampling block with spatial self-attention
        "UpBlock2D",
        "UpBlock2D",
        "UpBlock2D",
        "UpBlock2D",
    ),
)
#get weights
checkpoint = torch.load("./models/edm_ldm_radames/checkpoint.pth", map_location='cuda')
#load weights 
model_ldm.load_state_dict(checkpoint['model_state_dict'])
#put it on the gpu
model_ldm.to('cuda')
#wrap it with the precondition
model_ldm_wrapped = EDMPrecond(4,model_ldm)

########################################
######## Load plain unet model #########
########################################

model_unet = UNet2DModel.from_pretrained("./models/unet_vanilla/").to('cuda')

########################################
######### Load pretrained vae ##########
########################################

#this is the path on hugging face, not locally. 
vae =  AutoencoderKL.from_pretrained("radames/stable-diffusion-x4-upscaler-img2img", subfolder="vae").to('cuda')

### 5) Run the forecast 

This loop runs the forecast for every method: first the plain EDM, then CorrDiff, then the plain U-Net and the LDM. It also stores a persistence forecast and the truth.

This is the slowest part of the notebook. With the generation parameters selected below, each of the two large (pixel-space) diffusion models takes roughly 11 minutes per batch:

11 mins for the plain edm diffusion 
11 mins for the corrdiff 
 0 for the vanilla unet 
 1 for the LDM forecast 
------------------------
23 mins total per batch (batch size of 10)

1024 total images 

~103 batches 

(23*103)/60 ≈ 40 hours to run... come back later (ideally run this inside a `screen` session)

In [None]:
#3 hour forecast (18 steps of 10 min), this can be changed to your liking. 
time_steps_forward = np.arange(0,18)
#+2 because every stored forecast starts with the 2 condition frames
time_len = time_steps_forward.shape[0] + 2

# some image shapes for all the arrays (for validation and test)
nx = 256
ens_size = 10
#one slot per sample in the dataset being forecast. A hard-coded 1000 was smaller than
#the 1024-image test set, so batches past sample 1000 had nowhere to be stored.
num_forecasts = len(train_dataloader.dataset)
#the truth frames come straight from the dataset, so their count is fixed by the data
n_truth_frames = 18

#which batches to run, as the half-open range [start_batch, stop_batch).
#set start_batch = 0 and stop_batch = None to run every batch (~40 hours, see above).
start_batch = 1
stop_batch = 2

#preallocate arrays so we can store the data because we only gen a batch at a time 
#NOTE: the ensemble arrays hold num_forecasts*ens_size*time_len*nx*nx values each (float32 by default), i.e. tens of GB
full_image_list_K_concat_diff_ens_mems = torch.zeros([num_forecasts,ens_size,time_len,nx,nx])
full_image_list_K_concat_corrdiff_ens_mems = torch.zeros([num_forecasts,ens_size,time_len,nx,nx])
full_image_list_K_concat_ldm_ens_mems = torch.zeros([num_forecasts,ens_size,time_len,nx,nx])
full_image_list_K_concat_unet = torch.zeros([num_forecasts,time_len,nx,nx])
full_image_list_K_concat_pers = torch.zeros([num_forecasts,1,nx,nx])
clean_image_list_K_concat = torch.zeros([num_forecasts,n_truth_frames,nx,nx])

#best set of Karras 2022 generation parameters for the latent model, assuming it works for all of them. 
num_steps = 36
S_churn=7.2
sigma_max = 140 
rho = 4

#Karras 2022 defaults
# num_steps = 18
# S_churn=0
# sigma_max = 80
# rho = 7

#main loop, for each batch, run the forecast 
paired_batches = zip(train_dataloader, latent_train_dataloader)
for step, (batch, batch_latent) in enumerate(tqdm.tqdm(paired_batches, total=len(train_dataloader))):

    #skip batches before the requested window and stop once past it 
    if step < start_batch:
        continue
    if (stop_batch is not None) and (step >= stop_batch):
        break

    #the true data 
    clean_images_eval = batch[0]
    #the condition, throw it on the GPU cause we need it 
    condition_images_eval = batch[1].to('cuda')
    #the latent condition, throw it on the GPU cause we need it 
    latent_condition_images_eval = batch_latent[1].to('cuda')

    #run forecasts 
    full_image_list_diff_mems_ens = run_forecast_diff_ens(model_wrapped,condition_images_eval,S_churn=S_churn,
                                                          num_steps=num_steps,rho=rho,sigma_max=sigma_max,same_seed=False,time_steps_forward=time_steps_forward)
    full_image_list_corrdiff_ens_mems = run_forecast_corrdiff_ens(model_corrdiff_wrapped,model_unet,condition_images_eval,S_churn=S_churn,
                                                                  num_steps=num_steps,rho=rho,sigma_max=sigma_max,same_seed=False,time_steps_forward=time_steps_forward)
    full_image_list_unet = run_forecast_unet(model_unet,condition_images_eval,time_steps_forward=time_steps_forward)
    full_image_list_ldm_ens_mems = run_forecast_ldm_ens(model_ldm_wrapped,condition_images_eval,latent_condition_images_eval,vae,S_churn=S_churn,
                                                        num_steps=num_steps,rho=rho,sigma_max=sigma_max,same_seed=False,time_steps_forward=time_steps_forward)

    #convert back to IR brightness temperatures 
    full_image_list_diff_ens_mems_K = to_K(full_image_list_diff_mems_ens)
    full_image_list_corrdiff_ens_mems_K = to_K(full_image_list_corrdiff_ens_mems)
    full_image_list_ldm_ens_mems_K = to_K(full_image_list_ldm_ens_mems)
    full_image_list_unet_K = to_K(full_image_list_unet)
    #the dataset yields float16 tensors; cast to float32 before to_K, otherwise the ~280 K
    #result is computed in float16 and rounded to its 0.25 K spacing
    clean_images_eval_K = to_K(clean_images_eval.float())
    persistence_images_K = to_K(condition_images_eval[:,1:2].float().cpu())

    #store them; size the slice by the actual batch length so a short final batch also fits 
    start = step*batch_size
    batch_slice = slice(start, start + clean_images_eval.shape[0])
    full_image_list_K_concat_diff_ens_mems[batch_slice] = full_image_list_diff_ens_mems_K
    full_image_list_K_concat_corrdiff_ens_mems[batch_slice] = full_image_list_corrdiff_ens_mems_K
    full_image_list_K_concat_unet[batch_slice] = full_image_list_unet_K
    clean_image_list_K_concat[batch_slice] = clean_images_eval_K
    full_image_list_K_concat_pers[batch_slice] = persistence_images_K
    full_image_list_K_concat_ldm_ens_mems[batch_slice] = full_image_list_ldm_ens_mems_K


0it [00:00, ?it/s]
  0%|                                                                                                                                                     | 0/18 [00:00<?, ?it/s][A
  6%|███████▊                                                                                                                                     | 1/18 [01:19<22:39, 80.00s/it][A
 11%|███████████████▋                                                                                                                             | 2/18 [02:39<21:17, 79.86s/it][A
 17%|███████████████████████▌                                                                                                                     | 3/18 [03:59<19:57, 79.83s/it][A
 22%|███████████████████████████████▎                                                                                                             | 4/18 [05:19<18:37, 79.81s/it][A
 28%|███████████████████████████████████████▏                               

### 6) Save the forecasts to disk 

I use xarray here, but any format works. Only the forecast frames are saved (the two input frames at the start of each forecast are dropped), along with the truth and the persistence baseline.

In [None]:
import xarray as xr

#dimension names for ensemble (per-member) and deterministic variables
ens_dims = ['n_sample','n_member','forecast_time','nx','ny']
det_dims = ['n_sample','forecast_time','nx','ny']

#the [..., 2:] slices drop the two condition frames each forecast starts with
forecast_vars = {
    'corr_diff_mems': (ens_dims, full_image_list_K_concat_corrdiff_ens_mems[:,:,2:]),
    'diff_mems': (ens_dims, full_image_list_K_concat_diff_ens_mems[:,:,2:]),
    'truth': (det_dims, clean_image_list_K_concat[:,:]),
    'unet': (det_dims, full_image_list_K_concat_unet[:,2:]),
    'ldm_mems': (ens_dims, full_image_list_K_concat_ldm_ens_mems[:,:,2:]),
    'pers': (['n_sample','nx','ny'], full_image_list_K_concat_pers[:,0]),
}
ds_forecasts = xr.Dataset(forecast_vars)

#same zlib compression (level 5) for every variable
encoding = {var_name: {'zlib': True, 'complevel': 5} for var_name in forecast_vars}

ds_forecasts.to_netcdf('./datasets/test_set_forecasts.nc', encoding=encoding)