In [1]:
# Reload imported modules automatically before each cell executes, so edits to
# the project sources are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [2]:
#from src.mlcast.models.ldcast.ldcast import LDCast, LDCastLightningModule

In [3]:
#LDCastLightningModule(nn.Module(), nn.Module())

In [4]:
from torch import nn
import torch

In [5]:
from src.mlcast.models.ldcast.autoenc.autoenc import AutoencoderKL
from src.mlcast.models.ldcast.autoenc.encoder import SimpleConvEncoder, SimpleConvDecoder
from src.mlcast.models.ldcast.context.context import AFNONowcastNetCascade
from src.mlcast.models.ldcast.diffusion.diffusion import DiffusionModel

**TODO:** take care of the EMA scope.


In [6]:
# Number of future timesteps the model is configured to produce.
future_timesteps = 20
# Number of timesteps encoded in the autoencoder.
autoenc_time_ratio = 4

# Downstream cells size the networks with `future_timesteps // autoenc_time_ratio`.
# Floor division would silently drop the remainder, so fail fast here if the
# horizon is not a whole multiple of the autoencoder's time ratio.
assert future_timesteps % autoenc_time_ratio == 0, (
    "future_timesteps must be a multiple of autoenc_time_ratio"
)

In [7]:
# Build the individual LDCast components and move each one to the GPU.
# Construction order is kept as conditioner -> autoencoder -> denoiser -> diffuser.
from src.mlcast.models.ldcast.diffusion.unet import UNetModel

# Number of timesteps handled in latent space (20 // 4 with the values above).
latent_steps = future_timesteps // autoenc_time_ratio

# --- forecaster / conditioner ---
conditioner = AFNONowcastNetCascade(
    32,
    train_autoenc=False,
    output_patches=latent_steps,
    cascade_depth=3,
    embed_dim=128,
    analysis_depth=4,
)
conditioner = conditioner.to('cuda')

# --- autoencoder ---
enc = SimpleConvEncoder()
dec = SimpleConvDecoder()
autoencoder = AutoencoderKL(enc, dec)
autoencoder = autoencoder.to('cuda')

# --- denoiser ---
# The UNet reads and writes tensors with the autoencoder's hidden width.
latent_width = autoencoder.hidden_width
denoiser = UNetModel(
    in_channels=latent_width,
    model_channels=256,
    out_channels=latent_width,
    num_res_blocks=2,
    attention_resolutions=(1, 2),
    dims=3,
    channel_mult=(1, 2, 4),
    num_heads=8,
    num_timesteps=latent_steps,
    # context channels (= analysis_net.cascade_dims)
    context_ch=[128, 256, 512],
)
denoiser = denoiser.to('cuda')

# --- diffusion wrapper around the denoiser ---
diffuser = DiffusionModel(denoiser)
diffuser = diffuser.to('cuda')

In [8]:
class LDCastLightningModule(nn.ModuleDict):
    def __init__(self, autoencoder, conditioner, diffuser):
        super().__init__({'autoencoder': autoencoder, 'conditioner': conditioner, 'diffuser': diffuser})

    def forward(self, x, timesteps):
        
        # encoded is tuple of 3 tensors, but only the first one is used !!
        encoded = self.autoencoder.encode(x) 

        # condition is a dict of tensors
        condition = conditioner(encoded[0], timesteps)

        latent_diffused = diffuser(condition) # tensor

        prediction = self.autoencoder.decode(latent_diffused) # tensor
        
        return prediction

In [9]:
ldcast = LDCastLightningModule(autoencoder, conditioner, denoiser)

In [10]:
# Synthetic inputs for a smoke test of the forward pass.
# `timesteps` is built directly with a leading axis of size 1, giving shape (1, 4);
# AFNONowcastNetBase.add_pos_enc needs that extra axis (reason not yet understood).
timesteps = torch.tensor([[-3, -2, -1, 0]], device='cuda', dtype=torch.float32)
# Random input of shape (1, 1, 4, 256, 256).
x = torch.randn(1, 1, 4, 256, 256, device='cuda')

In [11]:
# Run the assembled model once on the synthetic inputs; the "PLMS Sampler"
# progress bar in this cell's output is printed during this call.
prediction = ldcast(x, timesteps)

PLMS Sampler: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 50/50 [00:10<00:00,  4.84it/s]


In [12]:
# Sanity check: the recorded output is (1, 1, 20, 256, 256); the third axis
# matches future_timesteps (20), the other axes match the input x.
prediction.shape

torch.Size([1, 1, 20, 256, 256])