### Create temporally downscaled temperature data using some UKESM-CMIP6 data

In [66]:
import os
import math
from tqdm import tqdm

import numpy as np
import matplotlib.pyplot as plt
import pandas as pd 
import iris 
from sklearn.preprocessing import StandardScaler
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import TensorDataset, DataLoader
from einops import rearrange
from typing import List
from timm.utils import ModelEmaV3

mps_device = torch.device("mps")

  from .autonotebook import tqdm as notebook_tqdm


In [118]:
# Load in the data
tas_monthly = iris.load_cube('../data/tas_Amon_UKESM1-0-LL_ssp370_r1i1p1f2_gn_201501-204912.nc')[:1, 30:62, 40:72]
tas_daily = iris.load_cube('../data/tas_day_UKESM1-0-LL_ssp370_r1i1p1f2_gn_20150101-20491230.nc')[:30, 30:62, 40:72]



In [119]:
def prepare_data(X, y):
    X_flattened = X.reshape(X.shape[0], (X.shape[1] * X.shape[2]))
    y = y.reshape(-1, 30, y.shape[1], y.shape[2])

    scaler = StandardScaler()
    X_transformed = scaler.fit_transform(X_flattened)

    X_transformed = X_transformed.reshape(X.shape[0], X.shape[1], X.shape[2])

    dataset = TensorDataset(torch.from_numpy(X_transformed), torch.from_numpy(y))
    loader = DataLoader(dataset, batch_size=1, shuffle=True)
    return loader

loader = prepare_data(tas_monthly.data, tas_daily.data)


tas_monthly shape --> (420, 144, 192)
tas_daily shape --> (12600, 144, 192)

Each month corresponds to 30 days.

I want the model to take two monthly fields (month before $X^{t-1}$, current month $X^{t}$), and generate 30 days of temperature fields. We want to use the month before, since I want the current temperature field to be temporally correlated with the previous month. 


At first I'll just do a single deterministic output, then I'll move onto generating an ensemble. 



#### Incorporate previous predictions into 'memory' using an LSTM

I'm going to penalise the model for generating daily fields that are not temporally consistent with the previous month's fields.


Diffusion model algorithm in a nut shell:

1. Take a randomly sampled data point from our training dataset
2. Select a random timestep on our noise schedule
3. Add the noise from that time step to our data, simulating the forward diffusion process through the diffusion kernel
4. Pass the defused impage into our model to predict the noise we added
5. Compute the mean squared error between the predicted noise and the actial noise, then optimise the parametes through that loss function
6. Repeat


Sampling algorithm:

1. Generate random noise from a standard normal distribution

Then for each timestep starting from our last timestep and moving backwards:

2. Update Z (latent noise related variable)by estimating the reverse process distribution with mean parameterized by Z from the previous step and variance parameterized by the noise our model estimates from that timestep
3. Add a small amount of the noise back for stability
4. And repeat until we arrive at time step 0, our recovered image

In [120]:
# We'll start with a naive Diffusion model 
# We need to create Sinusoidal Embeddings so the model knows where it is in the noise schedule

class SinusoidalEmbeddings(nn.Module):
    def __init__(self, time_steps: int, embed_dim: int):
        super().__init__()
        position = torch.arange(time_steps).unsqueeze(1).float()
        div = torch.exp(torch.arange(0, embed_dim, 2).float() * -(math.log(10000.0) / embed_dim))
        embeddings = torch.zeros(time_steps, embed_dim, requires_grad=False)
        embeddings[:, 0::2] = torch.sin(position * div)
        embeddings[:, 1::2] = torch.cos(position * div)
        self.embeddings = embeddings

    def forward(self, x, t):
        embeds = self.embeddings[t].to(x.device)
        return embeds[:, :, None, None]

In [121]:
# From DDPM paper
class ResBlock(nn.Module):
    def __init__(self, C:int, num_groups: int, dropout_prob: float):
        super().__init__()
        self.relu = nn.ReLU(inplace=True)
        self.gnorm1 = nn.GroupNorm(num_groups=num_groups, num_channels=C)
        self.gnorm2 = nn.GroupNorm(num_groups=num_groups, num_channels=C)
        self.conv1 = nn.Conv2d(C, C, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(C, C, kernel_size=3, padding=1)
        self.dropout = nn.Dropout(p=dropout_prob, inplace=True)

    def forward(self, x, embeddings):
        x = x + embeddings[:, :x.shape[1], :, :]
        r = self.conv1(self.relu(self.gnorm1(x)))
        r = self.dropout(r)
        r = self.conv2(self.relu(self.gnorm2(r)))
        return r + x

class Attention(nn.Module):
    def __init__(self, C:int, num_heads:int, dropout_prob: float):
        super().__init__()
        self.proj1 = nn.Linear(C, C*3)
        self.proj2 = nn.Linear(C, C)
        self.num_heads = num_heads
        self.dropout_prob = dropout_prob

    def forward(self, x):
        h, w = x.shape[2:]
        x = rearrange(x, 'b c h w -> b (h w) c')
        x = self.proj1(x)
        x = rearrange(x, 'b L (C H K) -> K b H L C', K=3, H=self.num_heads)
        q,k,v = x[0], x[1], x[2]
        x = F.scaled_dot_product_attention(q,k,v, is_causal=False, dropout_p=self.dropout_prob)
        x = rearrange(x, 'b H (h w) C -> b h w (C H)', h=h, w=w)
        x = self.proj2(x)
        return rearrange(x, 'b h w C -> b C h w')  

class UNetLayer(nn.Module):
    def __init__(self,
            upscale: bool,
            attention: bool,
            num_groups: int,
            dropout_prob: float,
            num_heads: int,
            C: int):
        super().__init__()
        self.ResBlock1 = ResBlock(C=C, num_groups=num_groups, dropout_prob=dropout_prob)
        self.ResBlock2 = ResBlock(C=C, num_groups=num_groups, dropout_prob=dropout_prob)

        if upscale: 
            self.conv = nn.ConvTranspose2d(C, C//2, kernel_size=4, stride=2, padding=1)
        else:
            self.conv = nn.Conv2d(C, C*2, kernel_size=3, stride=2, padding=1)
        
        if attention:
            self.attention_layer = Attention(C, num_heads=num_heads, dropout_prob=dropout_prob)

    def forward(self, x, embeddings):
        x = self.ResBlock1(x, embeddings)
        if hasattr(self, 'attention_layer'):
            x = self.attention_layer(x)
        
        x = self.ResBlock2(x, embeddings)
        return self.conv(x), x
            

In [126]:
class UNet(nn.Module):
    def __init__(self,
            Channels: List = [64, 128, 256, 512, 512, 384],
            Attentions: List = [False, True, False, False, False, True],
            Upscales: List = [False, False, False, True, True, True],
            num_groups: int = 32,
            dropout_prob: float = 0.1,
            num_heads: int = 8,
            input_channels: int = 1,
            output_channels: int = 30,
            time_steps: int = 1000):
        super().__init__()

        self.num_layers = len(Channels)
        self.shallow_conv = nn.Conv2d(input_channels, Channels[0], kernel_size=3, padding=1)
        out_channels = (Channels[-1]//2) + Channels[0]
        self.late_conv = nn.Conv2d(out_channels, out_channels//2, kernel_size=3,  padding=1)
        self.output_conv = nn.Conv2d(out_channels//2, output_channels, kernel_size=1)
        self.relu = nn.ReLU(inplace=True)
        self.embeddings = SinusoidalEmbeddings(time_steps=time_steps, embed_dim=max(Channels))

        for i in range(self.num_layers):
            layer = UNetLayer(
                upscale=Upscales[i],
                attention=Attentions[i],
                num_groups=num_groups,
                dropout_prob=dropout_prob,
                C=Channels[i],
                num_heads=num_heads
            )
            setattr(self, f'Layer{i+1}', layer)

    def forward(self, x, t):
        x = self.shallow_conv(x)
        residuals = []
        for i in range(self.num_layers//2):
            layer = getattr(self, f'Layer{i+1}')
            embeddings = self.embeddings(x, t)
            x, r = layer(x, embeddings)
            residuals.append(r)

        for i in range(self.num_layers//2, self.num_layers):
            layer = getattr(self, f'Layer{i+1}')
            print(f"Residual size: {residuals[self.num_layers-i-1].shape}, Layer output size: {layer(x, embeddings)[0].shape}")

            x = torch.concat((layer(x, embeddings)[0], residuals[self.num_layers-i-1]), dim=1)
            return self.output_conv(self.relu(self.late_conv(x)))


# Now we make the noise scheduler
class DDPM_Scheduler(nn.Module):
    def __init__(self, num_time_steps: int=1000):
        super().__init__()
        self.beta = torch.linspace(1e-4, 0.02, num_time_steps, requires_grad=False)
        alpha = 1 - self.beta
        self.alpha = torch.cumprod(alpha, dim=0).requires_grad_(False)

    def forward(self, t):
        return self.beta[t], self.alpha[t]




In [127]:
def train(train_loader,
          batch_size: int=64,
          num_time_steps: int=1000,
          num_epochs: int=15,
          seed: int=-1,
          ema_decay: float=0.9999,  
          lr=2e-5,
          checkpoint_path: str=None):


    scheduler = DDPM_Scheduler(num_time_steps=num_time_steps)
    model = UNet().to(mps_device)
    optimizer = optim.Adam(model.parameters(), lr=lr)
    ema = ModelEmaV3(model, decay=ema_decay)
    if checkpoint_path is not None:
        checkpoint = torch.load(checkpoint_path)
        model.load_state_dict(checkpoint['weights'])
        ema.load_state_dict(checkpoint['ema'])
        optimizer.load_state_dict(checkpoint['optimizer'])
    criterion = nn.MSELoss(reduction='mean')

    for i in range(num_epochs):
        total_loss = 0
        for bidx, (x,_) in enumerate(tqdm(train_loader, desc=f"Epoch {i+1}/{num_epochs}")):
            x = x.to(mps_device)
            t = torch.randint(0,num_time_steps,(batch_size,))
            e = torch.randn_like(x, requires_grad=False)
            a = scheduler.alpha[t].view(batch_size,1,1,1).to(mps_device)
            x = (torch.sqrt(a)*x) + (torch.sqrt(1-a)*e)
            output = model(x, t)
            optimizer.zero_grad()
            loss = criterion(output, e)
            total_loss += loss.item()
            loss.backward()
            optimizer.step()
            ema.update(model)
        print(f'Epoch {i+1} | Loss {total_loss / (60000/batch_size):.5f}')

    checkpoint = {
        'weights': model.state_dict(),
        'optimizer': optimizer.state_dict(),
        'ema': ema.state_dict()
    }
    torch.save(checkpoint, 'checkpoints/ddpm_checkpoint')



train(loader)

Epoch 1/15:   0%|          | 0/1 [00:03<?, ?it/s]

Residual size: torch.Size([64, 256, 8, 8]), Layer output size: torch.Size([64, 256, 8, 8])





RuntimeError: Given groups=1, weight of size [128, 256, 3, 3], expected input[64, 512, 8, 8] to have 256 channels, but got 512 channels instead