In [1]:
from torch.utils.data import Dataset
import torch
import numpy as np
import matplotlib.pyplot as plt
import torch
from torch.utils.data import DataLoader
import random
from matern_kernel import matern_kernel_noise_batch
import torch.nn.functional as F
import torchda
import sys
import os

sys.path.append(os.path.abspath(os.path.join(os.getcwd(), "..")))

In [2]:
DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
torch.manual_seed(42)
random.seed(42)
g = torch.Generator()
g.manual_seed(42)

<torch._C.Generator at 0x7e98a8203fb0>

In [3]:
def add_matern_kernel_noise_batch(sample, nu=2.5, lengthscale=1.0, sigma=1.0):
    B, C, H, W = sample.shape
    x = torch.linspace(0, 1, W, device=sample.device)
    y = torch.linspace(0, 1, H, device=sample.device)
    xx, yy = torch.meshgrid(x, y, indexing="xy")
    coords = torch.stack([xx.flatten(), yy.flatten()], dim=1)
    noise = matern_kernel_noise_batch(
        sample=sample,
        coords=coords,
        nu=nu,
        lengthscale=lengthscale,
        sigma=sigma,
    )
    return sample + noise

In [4]:
def resize_encoder(sample):
    sample = F.interpolate(
        sample, size=(144, 72), mode="bicubic", align_corners=False
    )
    return sample


def random_mask(sample, mask_prob_low=0.7, mask_prob_high=0.7):
    if mask_prob_low == mask_prob_high:
        mask_prob = mask_prob_low
    else:
        mask_prob = random.uniform(mask_prob_low, mask_prob_high)
    random_tensor = torch.rand(sample.shape, device=sample.device)
    mask = (random_tensor > mask_prob).float()
    masked_image = sample * mask
    return masked_image


class WeatherBenchDatasetWindow(Dataset):
    def __init__(
        self,
        data,
        context_length,
        target_length,
        stride=1,
        mask_prob_low=0.7,
        mask_prob_high=0.7,
    ):
        self.data = data
        self.context_length = context_length
        self.target_length = target_length
        self.stride = stride
        self.mask_prob_low = mask_prob_low
        self.mask_prob_high = mask_prob_high

    def __len__(self):
        return (
            self.data.shape[0] - (self.context_length + self.target_length)
        ) // self.stride + 1

    def __getitem__(self, idx):
        x = resize_encoder(self.data[idx : idx + self.context_length])
        y = self.data[
            idx
            + self.context_length : idx
            + self.context_length
            + self.target_length
        ]
        y_masked = random_mask(
            resize_encoder(y),
            mask_prob_low=self.mask_prob_low,
            mask_prob_high=self.mask_prob_high,
        )
        return x, y_masked, y

In [5]:
def prepare_inputs(input_data, encoder_model):
    B, T, C, H, W = input_data.shape
    input_data = input_data.reshape(B * T, C, H, W)
    input_encoded_data, _ = encoder_model(input_data)
    input_encoded_data = input_encoded_data.reshape(B, T, -1)
    return input_encoded_data

In [6]:
def get_background_obs(input_data, seq2seq_model):
    model_pred = seq2seq_model(input_data)
    if model_pred.ndim == 2:
        model_pred = model_pred.unsqueeze(0)
    B, T, L = model_pred.shape
    model_pred = model_pred.reshape(B * T, L)
    return model_pred

In [7]:
seq2seq_model = torch.load(
    "downstream_model_no_decoder.pth", weights_only=False, map_location=DEVICE
)
model = torch.load(
    "det_autoencoder.pth", weights_only=False, map_location=DEVICE
)
model.eval()
seq2seq_model.eval()

encoder_model = model.encoder
decoder_model = model.decoder

data = torch.load("/vol/bitbucket/nb324/ERA5_64x32_daily_850.pt")
n_samples = data.shape[0]
n_train = int(n_samples * 0.6)
n_valid = int(n_samples * 0.2)
data = data[n_train + n_valid :]
BATCH_SIZE = 128
n_samples = data.shape[0]

n_train = int(n_samples * 0.6)
n_valid = int(n_samples * 0.2)

train_data = data[:n_train]
valid_data = data[n_train : n_train + n_valid]
test_data = data[n_train + n_valid :]

mean = train_data.mean(dim=(0, 2, 3), keepdim=True)
std = train_data.std(dim=(0, 2, 3), keepdim=True)

test_data = (test_data - mean) / std

In [8]:
test_dataset = WeatherBenchDatasetWindow(
    data=test_data,
    context_length=30,
    target_length=1,
    stride=1,
    mask_prob_low=0.2,
    mask_prob_high=0.9,
)
testloader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=True)

In [9]:
mse_loss = torch.nn.MSELoss()

In [10]:
# decode_loss = []
# assimlated_loss = []

# for batch in testloader:
#     data = batch[2].to(DEVICE)
#     B, T, C, H, W = data.shape
#     data = data.reshape(B*T, C, H, W)
#     clean_data = batch[0].to(DEVICE)
#     masked_data = batch[1].to(DEVICE)
#     B, T, C, H, W = masked_data.shape
#     masked_data = masked_data.reshape(B*T, C, H, W)
#     masked_data, R = add_matern_kernel_noise_batch(masked_data)
#     print(R.shape)
#     masked_latent, _ = encoder_model(masked_data)
#     clean_data_prep = prepare_inputs(clean_data, encoder_model)
#     clean_latent = get_background_obs(clean_data_prep, seq2seq_model)
#     xb = clean_latent
#     y = masked_latent
#     B = torch.eye(y.shape[1], device=DEVICE)
#     R = 0.5 * torch.eye(y.shape[1], device=DEVICE)
#     results_3dvar = torchda.CaseBuilder().set_background_covariance_matrix(
#         B
#     ).set_observation_covariance_matrix(R).set_observations(
#         y
#     ).set_background_state(
#         xb
#     ).set_algorithm(
#         torchda.Algorithms.Var3D
#     ).set_device(
#         DEVICE
#     ).set_observation_model(
#         # H should be identity function as background and observations live in the same space
#         lambda x: x
#     ).set_optimizer_cls(
#         torch.optim.Adam
#     ).set_optimizer_args(
#         {"lr": 0.01}
#     ).set_max_iterations(
#         500
#     ).execute()
#     results = results_3dvar['assimilated_state']
#     decoded_results = model.decoder(results)
#     loss = mse_loss(decoded_results, data)

#     assimlated_loss.append(loss.item())
#     loss = mse_loss(model.decoder(masked_latent), data)
#     decode_loss.append(loss.item())


# print(f"Assimlated MSE Loss: {np.mean(assimlated_loss):.5f}")
# print(f"Decoded Masked MSE Loss: {np.mean(decode_loss):.5f}")

In [11]:
batch = next(iter(testloader))
data = batch[2].to(DEVICE)
B, T, C, H, W = data.shape
data = data.reshape(B * T, C, H, W)
clean_data = batch[0].to(DEVICE)
masked_data = batch[1].to(DEVICE)
B, T, C, H, W = masked_data.shape
masked_data = masked_data.reshape(B * T, C, H, W)
masked_data = add_matern_kernel_noise_batch(masked_data)

for batch in testloader:
    data = batch[2].to(DEVICE)
    B, T, C, H, W = data.shape
    data = data.reshape(B * T, C, H, W)
    clean_data = batch[0].to(DEVICE)
    masked_data = batch[1].to(DEVICE)
    B, T, C, H, W = masked_data.shape
    masked_data = masked_data.reshape(B * T, C, H, W)
    masked_data = add_matern_kernel_noise_batch(masked_data)
    masked_latent = model.encoder(masked_data)[0]
    loss = mse_loss(model.decoder(masked_latent), data)
    print(loss)


decode_loss = []
assimlated_loss = []

for idx, (d, c_d, m_d) in enumerate(zip(data, clean_data, masked_data)):
    if idx == 0:
        c_d = c_d.unsqueeze(0)
        clean_data_prep = prepare_inputs(c_d, encoder_model)
    else:
        results.unsqueeze_(0)
        clean_data_prep = torch.cat((clean_data_prep, results), dim=1)

    d = d.unsqueeze(0)
    m_d = m_d.unsqueeze(0)

    masked_latent, _ = encoder_model(m_d)

    clean_latent = get_background_obs(clean_data_prep, seq2seq_model)
    xb = clean_latent
    y = masked_latent
    B = torch.eye(y.shape[1], device=DEVICE)
    R = 0.1 * torch.eye(y.shape[1], device=DEVICE)
    results_3dvar = (
        torchda.CaseBuilder()
        .set_background_covariance_matrix(B)
        .set_observation_covariance_matrix(R)
        .set_observations(y)
        .set_background_state(xb)
        .set_algorithm(torchda.Algorithms.Var3D)
        .set_device(DEVICE)
        .set_observation_model(
            # H should be identity function as background and observations live in the same space
            lambda x: x
        )
        .set_optimizer_cls(torch.optim.Adam)
        .set_optimizer_args({"lr": 0.01})
        .set_max_iterations(500)
        .execute()
    )

    results = results_3dvar["assimilated_state"]
    decoded_results = model.decoder(results)
    loss = mse_loss(decoded_results, d)

    assimlated_loss.append(loss.item())
    loss = mse_loss(model.decoder(masked_latent), d)
    decode_loss.append(loss.item())

print(f"Assimlated MSE Loss: {np.mean(assimlated_loss):.5f}")
print(f"Decoded Masked MSE Loss: {np.mean(decode_loss):.5f}")

tensor(0.1259, grad_fn=<MseLossBackward0>)
tensor(0.1252, grad_fn=<MseLossBackward0>)
tensor(0.1231, grad_fn=<MseLossBackward0>)
tensor(0.1217, grad_fn=<MseLossBackward0>)
tensor(0.1244, grad_fn=<MseLossBackward0>)
tensor(0.1271, grad_fn=<MseLossBackward0>)
tensor(0.1247, grad_fn=<MseLossBackward0>)
tensor(0.1188, grad_fn=<MseLossBackward0>)
Timestamp: 2025-05-12 15:24:02.424431, Iterations: 0, J: 534.2424926757812, Norm of J gradient: 146.1837615966797
Timestamp: 2025-05-12 15:24:02.426490, Iterations: 1, J: 499.01019287109375, Norm of J gradient: 140.78167724609375
Timestamp: 2025-05-12 15:24:02.426894, Iterations: 2, J: 465.9312744140625, Norm of J gradient: 135.5138397216797
Timestamp: 2025-05-12 15:24:02.427246, Iterations: 3, J: 434.95660400390625, Norm of J gradient: 130.3883819580078
Timestamp: 2025-05-12 15:24:02.427598, Iterations: 4, J: 405.9986877441406, Norm of J gradient: 125.40725708007812
Timestamp: 2025-05-12 15:24:02.427957, Iterations: 5, J: 378.9573974609375, Norm o