In [None]:
# Mount Google Drive: the project folder (dataset zip) and the plot/output folder
# configured below both live under /content/drive/MyDrive.
from google.colab import drive
drive.mount(mountpoint='/content/drive')

In [None]:

# Extra packages not preinstalled on Colab: CRPS (see the commented-out
# `import CRPS.CRPS` below) and torchgeo (provides the pretrained ResNet-18).
# NOTE(review): versions are unpinned — pin them (e.g. `%pip install -q pkg==X.Y.Z`) for reproducible runs.
%pip install CRPS
%pip install torchgeo

In [None]:
 
from pathlib import Path
import sys
!sudo apt-get install unzip
base = Path('/content/drive/MyDrive/ML/2023/conv_strat_dataset')
sys.path.append(str(base))

zip_path = base/"lwe_dataset_010322.zip"


!cp "{zip_path}" .

!unzip -q lwe_dataset_010322.zip -d "/content"

!rm lwe_dataset_010322.zip




In [None]:
import torch
import torchvision
import matplotlib.pyplot as plt
import random
import json
from tqdm import tqdm
import os
from os import listdir
from PIL import Image
from datetime import date
import torch.nn as nn
import torch.optim as optim
from torch.optim import lr_scheduler
import torch
from skimage import io, transform
import numpy as np
import matplotlib.pyplot as plt
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, utils
import torch.nn.functional as F
from torch.optim import Adam

# import wandb
# import CRPS.CRPS as pscore


from prec_dataset import RadarPrecipitationSequence
from radar_transforms import radar_transform
from plotting_funcs import show_sequence
from radar_transforms import radar_transform, reverse_transform, conditional_embedding_transform
from fdp import fdp_sample, get_named_beta_schedule
from unet_refr_emb import UNet_embedding
from loss import loss_fn
from sampler import sample_plot_image
from helper_module import get_random_test_seq, get_CRPS_sequence
from calc_metrics import get_CRPS

# Run configuration: everything a reader might tune. It is also written to
# config.json in the output folder so each run's settings are recorded.
config = dict(
    img_out_size=64,
    rgb_grayscale=1,
    sequence_length=4,
    max_prec_val=3.4199221045419974,  # passed to the radar transforms (presumably a normalisation scale)
    prediction_time_step_ahead=1,
    frames_to_predict=1,
    num_cond_frames=3,
    epochs=51,
    batch_size=24,
    lr=0.001,
    T=300,  # number of diffusion timesteps
    schedule="linear",
    root_dir=r"/content/lwe_dataset",
    validate_on_convective=False,
    plot_folder="/content/drive/MyDrive/ML/results/plots/2702_embedding_concat",
)

plot_folder = config["plot_folder"]
# open(..., 'w') does not create missing directories, so make the folder first.
os.makedirs(plot_folder, exist_ok=True)
with open(f'{plot_folder}/config.json', 'w') as fp:
    json.dump(config, fp)

# Diffusion noise schedule and the DDPM quantities derived from it.
T = config["T"]
# as_tensor is a no-op for a tensor and converts array-like schedules, so the
# torch ops below work whatever type the schedule helper returns.
betas = torch.as_tensor(
    get_named_beta_schedule(
        schedule_name=config["schedule"], num_diffusion_timesteps=config["T"]
    )
)
# These identities hold for any beta schedule. They are computed unconditionally
# so later cells never see undefined names when a non-linear schedule is chosen.
alphas = 1.0 - betas
alphas_cumprod = torch.cumprod(alphas, dim=0)
# Cumulative product shifted right by one, with 1.0 for the first step.
alphas_cumprod_prev = F.pad(alphas_cumprod[:-1], (1, 0), value=1.0)
sqrt_recip_alphas = torch.sqrt(1.0 / alphas)
sqrt_alphas_cumprod = torch.sqrt(alphas_cumprod)
sqrt_one_minus_alphas_cumprod = torch.sqrt(1.0 - alphas_cumprod)
# Variance of the DDPM posterior q(x_{t-1} | x_t, x_0).
posterior_variance = betas * (1.0 - alphas_cumprod_prev) / (1.0 - alphas_cumprod)


device = "cuda" if torch.cuda.is_available() else "cpu"



In [None]:
import torch
from torchgeo.models import resnet18, ResNet18_Weights

# Pretrained ResNet-18 (torchgeo Sentinel-2 RGB MoCo weights). It is handed to
# loss_fn as `cond_emb_model`, presumably to embed the conditioning frames.
weights = ResNet18_Weights.SENTINEL2_RGB_MOCO

cond_emb_model = resnet18(weights=weights)
cond_emb_model.to(device)
cond_emb_model.eval()
# The optimizer only receives model.parameters(), so this network is never updated.
# Turning off its gradients avoids building and backpropagating an autograd graph
# through it on every training step.
cond_emb_model.requires_grad_(False)


In [None]:
# Conditional U-Net that is trained below. Channel count and number of
# conditioning frames come from `config` so they cannot drift from the data pipeline.
model = UNet_embedding(
    rgb_grayscale=config["rgb_grayscale"],
    num_cond_frames=config["num_cond_frames"],
    device=device,
)
model = model.to(device)
# The next use of `model` is the training loop. Leaving it in eval mode would run
# the first epoch with eval-mode behaviour of any dropout / batch-norm layers.
model.train()


### Define datasets and dataloaders

In [None]:
def _make_radar_dataset(split, **extra_kwargs):
    """Build a RadarPrecipitationSequence for one split ("train", "val" or "test").

    All shared settings are read from `config`. Extra keyword arguments
    (e.g. ``center_crop=True``) are forwarded to the dataset unchanged.
    """
    return RadarPrecipitationSequence(
        root_dir=config["root_dir"],
        transform=radar_transform(max_prec_val=config["max_prec_val"]),
        emb_transform=conditional_embedding_transform(max_prec_val=config["max_prec_val"]),
        num_cond_frames=config["num_cond_frames"],
        frames_to_predict=config["frames_to_predict"],
        img_out_size=config["img_out_size"],
        prediction_time_step_ahead=config["prediction_time_step_ahead"],
        train_test_val=split,
        **extra_kwargs,
    )


dataset_radar_sequence = _make_radar_dataset("train")
dataset_radar_sequence_val = _make_radar_dataset("val")
# The same validation split, center-cropped, for CRPS evaluation.
dataset_radar_sequence_val_CRPS = _make_radar_dataset("val", center_crop=True)
# Test split, not used in this notebook: _make_radar_dataset("test")

# drop_last=True keeps every batch at exactly config["batch_size"] samples.
dataloader = DataLoader(
    dataset_radar_sequence, batch_size=config["batch_size"], shuffle=True, drop_last=True
)
validation_dataloader = DataLoader(
    dataset_radar_sequence_val, batch_size=config["batch_size"], shuffle=True, drop_last=True
)
# No shuffling here, so the CRPS batches always come in the same order.
validation_dataloader_CRPS = DataLoader(
    dataset_radar_sequence_val_CRPS, batch_size=config["batch_size"], shuffle=False, drop_last=True
)



### Random samples from the training dataset

In [None]:
# Visual sanity check of the data pipeline: plot three random training sequences.
for _ in range(3):
    # Sample from the whole training set. Bounding by len() rather than a fixed
    # number avoids indexing past the end of a smaller dataset.
    idx = np.random.randint(low=0, high=len(dataset_radar_sequence))
    train_sample = dataset_radar_sequence[idx]

    show_sequence(train_sample, 3, pred_ahead=config["prediction_time_step_ahead"])

### Train the model

In [None]:
optimizer = Adam(model.parameters(), lr=config["lr"])

# Mean loss per epoch (one entry appended per epoch).
train_loss_list = []
validation_loss_list = []

# CRPS history (mean / std per aggregation). Only filled by the disabled CRPS
# block at the end of the epoch loop.
avg_4_crps_mean = []
avg_16_crps_mean = []
max_4_crps_mean = []
max_16_crps_mean = []

avg_4_crps_std = []
avg_16_crps_std = []
max_4_crps_std = []
max_16_crps_std = []

# Fixed validation indices for CRPS evaluation / sample plots, so results are
# comparable from epoch to epoch.
crps_idx_list = [57, 460, 63, 203, 357, 164, 327, 470, 260, 161,
                 140, 404, 379, 451, 289, 79, 141, 76, 42, 47]


def _batch_loss(batch):
    """Return the diffusion loss for one (target, condition) batch from the dataloaders.

    A fresh random timestep is drawn for every sample, so the result is stochastic.
    The same code serves training and validation; the caller decides about
    gradients (no_grad) and model mode.
    """
    imgs_to_model = batch[0].float()
    conditional_imgs = batch[1].float().to(device)
    # Size t from the actual batch, not config["batch_size"], so a short batch still works.
    t = torch.randint(0, config["T"], (imgs_to_model.shape[0],), device=device).long()
    return loss_fn(
        model=model,
        x=imgs_to_model,
        t=t,
        device=device,
        cond_emb_model=cond_emb_model,
        sqrt_one_minus_alphas_cumprod=sqrt_one_minus_alphas_cumprod,
        sqrt_alphas_cumprod=sqrt_alphas_cumprod,
        condition=conditional_imgs,
    )


# Start at +inf so the first finite validation loss always produces a checkpoint.
best_val_score = float("inf")
for epoch in range(config["epochs"]):
    # ---- training pass ----
    # Switch to training mode at the START of every epoch. The model arrives here
    # in eval mode, both from earlier cells and from the previous epoch's validation.
    model.train()
    train_loss_sum, n_train_batches = 0.0, 0
    for batch in tqdm(dataloader):
        optimizer.zero_grad()
        loss = _batch_loss(batch)
        loss.backward()
        optimizer.step()
        train_loss_sum += loss.item()
        n_train_batches += 1
    # NaN (rather than a NameError / ZeroDivisionError) if the loader yielded nothing.
    avg_epoch_train_loss = train_loss_sum / n_train_batches if n_train_batches else float("nan")

    # ---- validation pass (same loss, no gradients) ----
    model.eval()
    val_loss_sum, n_val_batches = 0.0, 0
    with torch.no_grad():
        for vbatch in tqdm(validation_dataloader):
            val_loss_sum += _batch_loss(vbatch).item()
            n_val_batches += 1
    avg_epoch_val_loss = val_loss_sum / n_val_batches if n_val_batches else float("nan")

    # Checkpoint whenever the validation loss improves (NaN never compares smaller).
    if avg_epoch_val_loss < best_val_score:
        best_val_score = avg_epoch_val_loss
        today = date.today()
        torch.save(model.state_dict(), f"{plot_folder}/{today}_epoch_{epoch}")

    print(f"Epoch {epoch} | Avg Train Loss: {avg_epoch_train_loss}, Avg Validation Loss: {avg_epoch_val_loss} ")
    train_loss_list.append(avg_epoch_train_loss)
    validation_loss_list.append(avg_epoch_val_loss)

    # ---- periodic CRPS evaluation and sample plots (currently disabled) ----
    # NOTE(review): if re-enabled, wrap this in `with torch.no_grad():`. The model is
    # already in eval mode at this point.
    # if convective_crps:
    #     get_convective_test_seq()

    # if epoch % 3 == 0:  # and epoch > 0:
    #     # get CRPS
    #     seq_list_crps = get_CRPS_sequence(dataset=dataset_radar_sequence_val_CRPS, idx_list=crps_idx_list)
    #
    #     avg4_mean, avg16_mean, max4_mean, max16_mean, avg4_std, avg16_std, max4_std, max16_std = get_CRPS(
    #         test_list=seq_list_crps,
    #         rgb_grayscale=config["rgb_grayscale"],
    #         img_out_size=config["img_out_size"],
    #         sequence_length=config["sequence_length"],
    #         device=device,
    #         sqrt_one_minus_alphas_cumprod=sqrt_one_minus_alphas_cumprod,
    #         sqrt_recip_alphas=sqrt_recip_alphas,
    #         posterior_variance=posterior_variance,
    #         model=model,
    #         T=config["T"],
    #         numb_cond=config["num_cond_frames"],
    #         betas=betas,
    #         max_prec_val=config["max_prec_val"],
    #         cond_emb_model=cond_emb_model,
    #     )
    #     avg_4_crps_mean.append(avg4_mean)
    #     avg_16_crps_mean.append(avg16_mean)
    #     max_4_crps_mean.append(max4_mean)
    #     max_16_crps_mean.append(max16_mean)
    #
    #     avg_4_crps_std.append(avg4_std)
    #     avg_16_crps_std.append(avg16_std)
    #     max_4_crps_std.append(max4_std)
    #     max_16_crps_std.append(max16_std)
    #
    #     # Three side-by-side panels: loss curves, average-rate CRPS, maximum-rate CRPS.
    #     fig, (ax1, ax2, ax3) = plt.subplots(1, 3, figsize=(30, 5))
    #
    #     epoch_list = np.arange(0, len(validation_loss_list))
    #     ax1.plot(epoch_list, validation_loss_list, label='Validation loss')
    #     ax1.plot(epoch_list, train_loss_list, label='Train loss')
    #     ax1.set_title("MSE")
    #     ax1.legend()
    #
    #     crps_epochs = np.arange(0, len(avg_4_crps_mean))
    #     ax2.errorbar(crps_epochs, avg_4_crps_mean, avg_4_crps_std, marker='^', label='4-km aggregations')
    #     ax2.errorbar(crps_epochs, avg_16_crps_mean, avg_16_crps_std, marker='*', label='16-km aggregations')
    #     ax2.set_yscale('log')
    #     ax2.set_title("Pooled CRPS using the average rain rate")
    #     ax2.legend()
    #
    #     ax3.errorbar(crps_epochs, max_4_crps_mean, max_4_crps_std, marker='^', label='4-km aggregations')
    #     ax3.errorbar(crps_epochs, max_16_crps_mean, max_16_crps_std, marker='^', label='16-km aggregations')
    #     ax3.set_title("Pooled CRPS using the maximum rain rate")
    #     ax3.set_yscale('log')
    #     ax3.legend()
    #
    #     folder_path = config["plot_folder"]
    #     fig.savefig(f"{folder_path}/MSE_epoch_{epoch}.png")
    #     plt.show()
    #     plt.close(fig)

    # if epoch % 5 == 0:
    #     if config["validate_on_convective"]:
    #         pass
    #
    #     seq_list_plot = get_CRPS_sequence(dataset=dataset_radar_sequence_val_CRPS, idx_list=crps_idx_list)
    #
    #     sample_plot_image(
    #         test_list=seq_list_plot,
    #         rgb_grayscale=config["rgb_grayscale"],
    #         img_out_size=config["img_out_size"],
    #         sequence_length=config["sequence_length"],
    #         device=device,
    #         sqrt_one_minus_alphas_cumprod=sqrt_one_minus_alphas_cumprod,
    #         sqrt_recip_alphas=sqrt_recip_alphas,
    #         posterior_variance=posterior_variance,
    #         model=model,
    #         T=config["T"],
    #         pred_ahead=config["prediction_time_step_ahead"],
    #         numb_cond=config["num_cond_frames"],
    #         betas=betas,
    #         max_prec_val=config["max_prec_val"],
    #         epoch=epoch,
    #         out_folder=config["plot_folder"],
    #     )



In [None]:
# Save the weights as they are after the last epoch. The best-validation
# checkpoints are written separately, inside the training loop.
PATH = config["plot_folder"]
final_weights = model.state_dict()
torch.save(final_weights, f"{PATH}/model.pt")

In [None]:
# Mean validation loss per epoch, as recorded by the training loop (shown as the cell output).
validation_loss_list

In [None]:
# Mean training loss per epoch, as recorded by the training loop (shown as the cell output).
train_loss_list