In [None]:
# Mount Google Drive into the Colab VM at /content/drive; later cells in this
# notebook read the dataset archive from, and write results to, paths under it.
from google.colab import drive
drive.mount('/content/drive')

In [None]:

# Install the CRPS package into the kernel's environment.
# NOTE(review): the version is not pinned, and the direct `CRPS` import in the
# import cell is commented out - presumably a project module needs it; confirm.
%pip install CRPS

In [None]:
 
from pathlib import Path
import sys
!sudo apt-get install unzip
base = Path('/content/drive/MyDrive/ML/2023/conv_strat_dataset')
sys.path.append(str(base))

zip_path = base/"lwe_dataset_010322.zip"


!cp "{zip_path}" .

!unzip -q lwe_dataset_010322.zip -d "/content"

!rm lwe_dataset_010322.zip




In [None]:
import torch
import torchvision
import matplotlib.pyplot as plt
import random
import json
from tqdm import tqdm
import os
from os import listdir
from PIL import Image

import torch
from skimage import io, transform
import numpy as np
import matplotlib.pyplot as plt
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, utils
import torch.nn.functional as F
from torch.optim import Adam

# import wandb
# import CRPS.CRPS as pscore


from prec_dataset import RadarPrecipitationSequence
from radar_transforms import radar_transform
from plotting_funcs import show_sequence
from radar_transforms import radar_transform, reverse_transform
from fdp import fdp_sample, get_named_beta_schedule
from unet import UNet
from loss import loss_fn
from sampler import sample_plot_image
from helper_module import get_random_test_seq, get_CRPS_sequence
from calc_metrics import get_CRPS

# Run configuration. Every later cell reads its settings from this dict, and a
# JSON copy is stored next to the results so a run can be traced back to them.
config = dict(
    img_out_size=64,
    rgb_grayscale=1,
    sequence_length=5,
    max_prec_val=3.4199221045419974,
    prediction_time_step_ahead=1,
    frames_to_predict=1,
    num_cond_frames=4,
    epochs=150,
    batch_size=24,
    lr=0.001,
    T=1000,
    schedule="linear",
    root_dir=r"/content/lwe_dataset_010322",
    validate_on_convective = False,
    plot_folder = "/content/drive/MyDrive/ML/results/plots/0503_concat_t_1000"
)

plot_folder = config["plot_folder"]
# open(..., 'w') does not create missing parent directories, so make sure the
# results folder exists before anything is written into it.
os.makedirs(plot_folder, exist_ok=True)
with open(f'{plot_folder}/config.json', 'w') as fp:
    json.dump(config, fp)

# Noise schedule for the diffusion process, one beta per timestep.
betas = get_named_beta_schedule(
    schedule_name=config["schedule"], num_diffusion_timesteps=config["T"]
)
T = config["T"]
# NOTE(review): the derived terms below are only defined for the "linear"
# schedule; with any other schedule name the training cell fails with a
# NameError on sqrt_alphas_cumprod. Confirm whether the guard is intentional.
if config["schedule"] == "linear":
    alphas = 1.0 - betas
    # Running product of the alphas over the timesteps.
    alphas_cumprod = torch.cumprod(alphas, axis=0)
    # Same sequence shifted right by one step, with 1.0 in front.
    alphas_cumprod_prev = F.pad(alphas_cumprod[:-1], (1, 0), value=1.0)
    sqrt_recip_alphas = torch.sqrt(1.0 / alphas)
    sqrt_alphas_cumprod = torch.sqrt(alphas_cumprod)
    sqrt_one_minus_alphas_cumprod = torch.sqrt(1.0 - alphas_cumprod)
    posterior_variance = betas * (1.0 - alphas_cumprod_prev) / (1.0 - alphas_cumprod)


# Use the GPU when one is available, otherwise fall back to the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"
model = UNet(
    rgb_grayscale=config["rgb_grayscale"], num_cond_frames=config["num_cond_frames"], device=device
)
model.to(device)
print("Num params in Unet: ", sum(p.numel() for p in model.parameters()))


### Define datasets and dataloaders:

In [None]:
def _make_radar_dataset(split, **extra_kwargs):
    """Build a RadarPrecipitationSequence for one split from the shared config.

    Args:
        split: value passed through as `train_test_val` ("train", "val" or "test").
        **extra_kwargs: additional constructor arguments, e.g. `center_crop=True`.

    Returns:
        The constructed dataset. A fresh transform object is created per call,
        exactly as when each dataset is built by hand.
    """
    return RadarPrecipitationSequence(
        root_dir=config["root_dir"],
        transform=radar_transform(max_prec_val=config["max_prec_val"]),
        num_cond_frames=config["num_cond_frames"],
        frames_to_predict=config["frames_to_predict"],
        img_out_size=config["img_out_size"],
        prediction_time_step_ahead=config["prediction_time_step_ahead"],
        train_test_val=split,
        **extra_kwargs,
    )


def _make_loader(dataset, shuffle):
    """Wrap `dataset` in a DataLoader with the configured batch size.

    Incomplete final batches are always dropped (`drop_last=True`).
    """
    return DataLoader(
        dataset,
        batch_size=config["batch_size"],
        shuffle=shuffle,
        drop_last=True,
    )


# Datasets: train, validation, validation with centre crop, and test.
dataset_radar_sequence = _make_radar_dataset("train")
dataset_radar_sequence_val = _make_radar_dataset("val")
dataset_radar_sequence_val_CRPS = _make_radar_dataset("val", center_crop=True)
dataset_radar_sequence_test = _make_radar_dataset("test")

# Loaders: the train and validation loaders shuffle; the centre-cropped
# validation loader keeps dataset order. The test split gets no loader here.
dataloader = _make_loader(dataset_radar_sequence, shuffle=True)
validation_dataloader = _make_loader(dataset_radar_sequence_val, shuffle=True)
validation_dataloader_CRPS = _make_loader(dataset_radar_sequence_val_CRPS, shuffle=False)



### Random samples from dataset: 

In [None]:
# for i in range(3):
#     idx = np.random.randint(low=0, high=250)
#     test_sample = dataset_radar_sequence_test.__getitem__(idx)

#     show_sequence(test_sample, config["sequence_length"], pred_ahead= config["prediction_time_step_ahead"])

### Simulate forward diffusion process: 

In [None]:
# simulate_fdp(forward_diffusion_sample=forward_diffusion_sample, dataloader=dataloader,sqrt_one_minus_alphas_cumprod = sqrt_one_minus_alphas_cumprod, sqrt_alphas_cumprod=sqrt_alphas_cumprod, T=config["T"])

### Train Model:

In [None]:
from datetime import date


optimizer = Adam(model.parameters(), lr=config["lr"])

# Per-epoch average losses, appended to by the loop below.
train_loss_list = []
validation_loss_list = []

# Containers for CRPS statistics. Nothing in this cell fills them.
avg_4_crps_mean = []
avg_16_crps_mean = []
max_4_crps_mean = []
max_16_crps_mean = []

avg_4_crps_std = []
avg_16_crps_std = []
max_4_crps_std = []
max_16_crps_std = []

# Fixed list of indices; not used in this cell.
# NOTE(review): presumably validation samples for a CRPS evaluation - confirm.
crps_idx_list = [
    57, 460, 63, 203, 357, 164, 327, 470, 260, 161,
    140, 404, 379, 451, 289, 79, 141, 76, 42, 47,
]

# Start at +inf so the first epoch always yields a checkpoint; a fixed starting
# value of 1 would never save anything if the validation loss stays above 1.
best_val_score = float("inf")


def _diffusion_loss(batch):
    """Compute the loss for one batch at randomly drawn timesteps.

    Args:
        batch: item yielded by the dataloader; `batch[0]` is indexed as a
            4-D tensor whose second axis holds the frames of a sequence.

    Returns:
        Whatever `loss_fn` returns for this batch (the training loop calls
        `.backward()` and `.item()` on it).
    """
    batch_lwe = batch[0]

    # One timestep per sample. The count follows the actual batch length rather
    # than config["batch_size"], so it stays correct for any loader settings.
    t = torch.randint(0, config["T"], (batch_lwe.shape[0],), device=device).long()

    # All frames except the last one are passed as the condition ...
    conditional_imgs = batch_lwe[:, 0:-1, :, :].to(device)
    # ... and the last frame, with the frame axis restored, is the target.
    # It is not moved to `device` here; loss_fn receives `device` and
    # presumably handles that - TODO confirm.
    imgs_to_model_training = batch_lwe[:, -1, :, :][:, None, :, :]

    return loss_fn(
        model=model,
        x=imgs_to_model_training,
        t=t,
        device=device,
        sqrt_one_minus_alphas_cumprod=sqrt_one_minus_alphas_cumprod,
        sqrt_alphas_cumprod=sqrt_alphas_cumprod,
        condition=conditional_imgs,
    )


for epoch in range(config["epochs"]):
    # Set train mode explicitly: if this cell is re-run after a cell that left
    # the model in eval mode, the first epoch would otherwise train in eval mode.
    model.train()
    avg_epoch_train_loss = 0

    for step, batch in enumerate(tqdm(dataloader)):
        optimizer.zero_grad()
        loss = _diffusion_loss(batch)
        loss.backward()
        optimizer.step()
        avg_epoch_train_loss += loss.item()

    avg_epoch_train_loss = avg_epoch_train_loss/(step+1)

    # Validation loss for the same epoch, without gradient tracking.
    avg_epoch_val_loss = 0
    model.eval()
    with torch.no_grad():
        for vstep, vbatch in enumerate(tqdm(validation_dataloader)):
            loss = _diffusion_loss(vbatch)
            avg_epoch_val_loss += loss.item()

    avg_epoch_val_loss = avg_epoch_val_loss/(vstep+1)

    # Save the weights whenever the validation loss improves.
    if avg_epoch_val_loss < best_val_score:
        best_val_score = avg_epoch_val_loss
        today = date.today()
        torch.save(model.state_dict(), f"{plot_folder}/{today}_epoch_{epoch}")

    print(f"Epoch {epoch} | Avg Train Loss: {avg_epoch_train_loss}, Avg Validation Loss: {avg_epoch_val_loss} ")
    train_loss_list.append(avg_epoch_train_loss)
    validation_loss_list.append(avg_epoch_val_loss)

    # TODO: CRPS evaluation (optionally on convective sequences) is not implemented.

    # Leave the model in training mode, also after the final epoch.
    model.train()



In [None]:
# Convert the per-epoch loss histories to arrays and write each one as a CSV
# file in the results folder.
train_loss_arr = np.asarray(train_loss_list)
val_loss_arr = np.asarray(validation_loss_list)

for csv_path, loss_history in (
    (f"{plot_folder}/val_loss_arr.csv", val_loss_arr),
    (f"{plot_folder}/train_loss_arr.csv", train_loss_arr),
):
    np.savetxt(csv_path, loss_history, delimiter=",")
