In [1]:
import torch
import torchvision
import matplotlib.pyplot as plt
import random
import json
from tqdm import tqdm
import os
from os import listdir
from PIL import Image

import torch
from skimage import io, transform
import numpy as np
import matplotlib.pyplot as plt
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, utils
import torch.nn.functional as F
from torch.optim import Adam
import datetime
import torch.nn as nn

# import wandb
# import CRPS.CRPS as pscore


from prec_dataset import RadarPrecipitationSequence
from radar_transforms import radar_transform
from plotting_funcs import show_sequence
from radar_transforms import radar_transform, reverse_transform
from fdp import fdp_sample, get_named_beta_schedule
from unet_refr_emb import UNet_embedding
# from unet_embedded import UNet_embedding
from loss import loss_fn
from helper_module import get_20min_forecast_sequence
from metrics import csi

# Experiment settings shared by the cells below.
config = {
    "img_out_size": 64,
    "rgb_grayscale": 1,
    "sequence_length": 8,
    "max_prec_val": 3.4199221045419974,
    "prediction_time_step_ahead": 1,
    "frames_to_predict": 1,
    "num_cond_frames": 4,
    "tot_pred_ahead": 4,
    # Name of the beta schedule and number of diffusion timesteps; both are
    # passed to get_named_beta_schedule.
    "schedule": "linear",
    "T": 300,
    "validate_on_convective": False,
}

# Noise schedule: one beta per diffusion timestep.
betas = get_named_beta_schedule(
    schedule_name=config["schedule"], num_diffusion_timesteps=config["T"]
)

T = config["T"]

# Quantities derived from the betas. They depend only on `betas`, not on which
# named schedule produced them, so they are computed unconditionally; guarding
# them behind `schedule == "linear"` left every name below undefined (and the
# later cells failing with NameError) for any other schedule.
alphas = 1.0 - betas
alphas_cumprod = torch.cumprod(alphas, dim=0)
# Cumulative product shifted right by one step, with 1.0 in the first slot.
alphas_cumprod_prev = F.pad(alphas_cumprod[:-1], (1, 0), value=1.0)
sqrt_recip_alphas = torch.sqrt(1.0 / alphas)
sqrt_alphas_cumprod = torch.sqrt(alphas_cumprod)
sqrt_one_minus_alphas_cumprod = torch.sqrt(1.0 - alphas_cumprod)
posterior_variance = betas * (1.0 - alphas_cumprod_prev) / (1.0 - alphas_cumprod)

# Use the GPU when one is available, otherwise fall back to the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"



In [2]:
import torch
from torchgeo.models import resnet18, ResNet18_Weights

# Pretrained weights shared by both ResNet-18 instances below.
weights = ResNet18_Weights.SENTINEL2_RGB_MOCO

# Conditioning-embedding model, placed on the compute device.
cond_emb_model = resnet18(weights=weights).to(device)

# Second, independent instance built from the same weights; it is not moved
# to `device` here.
cond_emb_model_copy = resnet18(weights=weights)

# Switch to evaluation mode; as the last expression this also displays the
# module structure.
cond_emb_model.eval()




ResNet(
  (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (act1): ReLU(inplace=True)
  (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (layer1): Sequential(
    (0): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (drop_block): Identity()
      (act1): ReLU(inplace=True)
      (aa): Identity()
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (act2): ReLU(inplace=True)
    )
    (1): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, m

In [3]:
# Build the U-Net and place it on the compute device.
# NOTE(review): config["num_cond_frames"] is 4 while 3 is passed here --
# confirm which value is intended.
model_UNet = UNet_embedding(
    rgb_grayscale=1,
    num_cond_frames=3,
    device=device,
).to(device)

# Evaluation mode; as the last expression this also displays the module tree.
model_UNet.eval()


UNet_embedding(
  (time_mlp): Sequential(
    (0): SinusoidalPositionEmbeddings()
    (1): Linear(in_features=32, out_features=32, bias=True)
    (2): ReLU()
  )
  (down1): Encoder(
    (time_mlp): Linear(in_features=32, out_features=128, bias=True)
    (cond_emb): Linear(in_features=512, out_features=128, bias=True)
    (conv1): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (conv2): Conv2d(128, 128, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
    (GELU): GELU(approximate='none')
    (GN): GroupNorm(1, 128, eps=1e-05, affine=True)
  )
  (down2): Encoder(
    (time_mlp): Linear(in_features=32, out_features=256, bias=True)
    (cond_emb): Linear(in_features=512, out_features=256, bias=True)
    (conv1): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (conv2): Conv2d(256, 256, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
    (GELU): GELU(approximate='none')
    (GN): GroupNorm(1, 256, eps=1e-05, affine=True)
  )
  (down3): Enco

In [4]:
# Number of samples in the random smoke-test batch.
batch_size = 6

# Random stand-in tensors used to exercise the models below.
# input_imgs: (batch, 1, 64, 64); cond_imgs: (batch, 3, 224, 224).
input_imgs = torch.randn((batch_size, 1, 64, 64))
cond_imgs = torch.randn((batch_size, 3, 224, 224))

In [5]:
# Draw one diffusion timestep per sample. The upper bound is T (taken from
# config["T"]) rather than a literal, so the sampled indices always stay
# within the length requested for the schedule tensors.
t = torch.randint(0, T, (batch_size,), device=device).long()

# loss_fn is already imported in the notebook's first cell.
loss = loss_fn(
    model=model_UNet,
    x=input_imgs,
    t=t,
    device=device,
    cond_emb_model=cond_emb_model,
    sqrt_one_minus_alphas_cumprod=sqrt_one_minus_alphas_cumprod,
    sqrt_alphas_cumprod=sqrt_alphas_cumprod,
    condition=cond_imgs,
)
loss


torch.Size([6, 128, 32, 32])
torch.Size([6, 256, 16, 16])
torch.Size([6, 512, 8, 8])
torch.Size([6, 1024, 4, 4])


tensor(1.0372, grad_fn=<MseLossBackward0>)

In [6]:
cond_emb_model_copy.eval()

# Buffer filled by the forward hook; reset every time this cell runs.
emb_out = []


def hook(module, input, output):
    """Forward hook that stores a detached copy of the layer output.

    Args:
        module: The module the hook is attached to (unused).
        input: Positional inputs given to the module (unused).
        output: Output produced by the module's forward pass.

    Returns:
        None, so the module output is passed on unchanged.
    """
    emb_out.append(output.detach())


# get_submodule raises a descriptive AttributeError when the layer is missing,
# instead of returning None the way a lookup in the private `_modules` dict does.
_embedding_layer = cond_emb_model_copy.get_submodule("global_pool")

# Remove the hook installed by a previous run of this cell. Without this,
# every re-run stacks another hook and each forward pass appends to emb_out
# once per stacked hook.
_previous_hook_handle = globals().get("_embedding_hook_handle")
if _previous_hook_handle is not None:
    _previous_hook_handle.remove()

# Keep the handle so the hook can be removed again later.
_embedding_hook_handle = _embedding_layer.register_forward_hook(hook)

In [7]:
# Start from an empty buffer so that re-running this cell does not keep
# appending further embedding batches to emb_out.
emb_out.clear()

# Inference only: skip building an autograd graph for this forward pass
# (the hook stores detached tensors anyway).
with torch.no_grad():
    out_ = cond_emb_model_copy(cond_imgs)

In [8]:
# Display the outputs captured by the forward hook registered on the
# "global_pool" layer; the hook appends one tensor each time it is called.
emb_out

[tensor([[0.0000, 0.0000, 0.0430,  ..., 0.0000, 0.0000, 0.0000],
         [0.0000, 0.0000, 0.0531,  ..., 0.0000, 0.0000, 0.0000],
         [0.0000, 0.0000, 0.0261,  ..., 0.0000, 0.0000, 0.0000],
         [0.0000, 0.0000, 0.0413,  ..., 0.0000, 0.0000, 0.0000],
         [0.0000, 0.0000, 0.0341,  ..., 0.0000, 0.0000, 0.0000],
         [0.0000, 0.0000, 0.0370,  ..., 0.0000, 0.0000, 0.0000]])]