In [None]:
import sys
sys.path.append(".")

# Autograd is switched off for the whole notebook to save memory.
import torch
torch.set_grad_enabled(False)

# Use the first CUDA device when one is visible, otherwise fall back to the CPU.
_device_name = "cuda:0" if torch.cuda.is_available() else "cpu"
DEVICE = torch.device(_device_name)

# Report the selected device, the torch version and the CUDA build version (one per line).
print(DEVICE, torch.__version__, torch.version.cuda, sep = "\n")
import yaml
import torch
from omegaconf import OmegaConf
from taming.models.vqgan import VQModel, GumbelVQ

In [3]:
import io
import os, sys
import requests
import PIL
from PIL import Image
from PIL import ImageDraw, ImageFont
import numpy as np

import torch
import torch.nn.functional as F
import torchvision.transforms as T
import torchvision.transforms.functional as TF

In [4]:
import torch
from nuwa_pytorch import VQGanVAE
from nuwa_pytorch.optimizer import get_optimizer
# Hyper-parameters of the first-stage VQGAN (presumably they have to match the
# checkpoint loaded below -- confirm against the training run).
vae_settings = dict(
    dim = 256,
    channels = 1,               # default is 3; a single channel is used here
    image_size = 256,           # image size
    num_layers = 4,             # number of downsampling layers
    num_resnet_blocks = 2,      # number of resnet blocks
    vq_codebook_size = 1024,    # codebook size
    vq_decay = 0.8,             # codebook exponential decay
    use_hinge_loss = True,
    use_vgg_and_gan = True,
)
vae = VQGanVAE(**vae_settings).to(DEVICE)

# The checkpoint is deserialised onto the CPU; load_state_dict then copies the
# values into the module that already lives on DEVICE.
# NOTE(review): torch.load unpickles the file -- only load trusted checkpoints.
checkpoint = torch.load('/bulk/junzheyin/aaa/vae_epoch80', map_location = 'cpu')
# strict=False tolerates missing / unexpected keys; the returned report is the cell output.
vae.load_state_dict(checkpoint, strict=False)

<All keys matched successfully>

In [5]:
# Define dataset
import torch
import sys
from nuwa_pytorch import VQGanVAE
import h5py
from PIL import Image
import matplotlib.pyplot as plt
from datetime import datetime, timedelta
def eventGeneration(start_time, obs_time = 3 ,lead_time = 6, time_interval = 30):
    """Build the forecast and observation timestamps around ``start_time``.

    Args:
        start_time: timestamp string laid out as 'YYYYMMDDHHMM'; only its
            first 12 characters are read.
        obs_time: number of observed steps, counted backwards from (and
            including) ``start_time``.
        lead_time: number of forecast steps that follow ``start_time``.
        time_interval: spacing between consecutive steps, in minutes.

    Returns:
        A ``(lead, obs)`` pair of lists of 'YYYYMMDDHHMM' strings, both in
        chronological order: ``lead`` covers t+1 .. t+lead_time steps and
        ``obs`` covers t-(obs_time-1) .. t steps.
    """
    stamp_format = '%Y%m%d%H%M'
    field_bounds = ((0, 4), (4, 6), (6, 8), (8, 10), (10, 12))
    year, month, day, hour, minute = (int(start_time[lo:hi]) for lo, hi in field_bounds)

    # forecast steps: start + 1*interval, ..., start + lead_time*interval
    lead = []
    for step in range(lead_time):
        moment = datetime(year, month, day, hour, minute) + timedelta(minutes = time_interval * (step + 1))
        lead.append(moment.strftime(stamp_format))

    # observed steps, generated oldest first so no reversal is needed afterwards
    obs = []
    for step in reversed(range(obs_time)):
        moment = datetime(year, month, day, hour, minute) - timedelta(minutes = time_interval * step)
        obs.append(moment.strftime(stamp_format))

    return lead, obs

from torch.utils.data import Dataset, DataLoader
import h5py
import numpy as np
from torchvision.transforms import ToTensor, Compose, CenterCrop
class radarDataset(Dataset):
    """Dataset of radar image sequences stored as one HDF5 file per timestamp.

    Each item is the stack of ``obs_number`` observed frames followed by
    ``pred_number`` future frames around one event start time, together with
    that start time as a string.

    Args:
        root_dir: directory prefix of the radar archive. It is concatenated
            directly with ``'<year>/<month>/<file>'``, so it must end with a
            path separator.
        event_times: sequence of start times (anything ``str()`` turns into
            'YYYYMMDDHHMM').
        obs_number: number of observed frames (ending at the start time).
        pred_number: number of frames after the start time.
        transform: optional callable applied to the stacked
            ``(height, width, frames)`` float32 array. When ``None`` the
            array is returned untransformed.
    """

    def __init__(self, root_dir, event_times, obs_number = 3, pred_number = 6, transform=None):
        self.root_dir = root_dir
        self.transform = transform
        self.event_times = event_times
        self.obs_number = obs_number
        self.pred_number = pred_number

    def __len__(self):
        """Number of events (one item per start time)."""
        return len(self.event_times)

    def _load_frame(self, time):
        """Read, crop and rescale the radar image of one 'YYYYMMDDHHMM' timestamp."""
        year = time[0:4]
        month = time[4:6]
        path = self.root_dir + year + '/' + month + '/' + 'RAD_NL25_RAP_5min_' + time + '.h5'
        # Open read-only and close the handle as soon as the data is copied out;
        # the previous code left one open HDF5 handle behind per frame.
        with h5py.File(path, 'r') as h5_file:
            image = np.array(h5_file['image1']['image_data'])

        image = image[264:520,242:498]      # fixed 256x256 window
        image[image == 65535] = 0           # 65535 is zeroed (looks like a no-data marker -- TODO confirm)
        image = image.astype('float32')
        # NOTE(review): /100*12 presumably converts 5-minute accumulations in
        # 0.01 mm to mm/h -- confirm against the data description.
        image = image/100*12
        image = np.clip(image, 0, 128)
        image = image/40                    # final scaling, value range [0, 3.2]
        return image

    def __getitem__(self, idx):
        """Return ``(frames, start_time)`` for event ``idx``.

        ``frames`` is what ``transform`` returns for the ``(256, 256, T)``
        float32 stack (observed frames first, then future frames), or the
        stack itself when no transform was given.
        """
        start_time = str(self.event_times[idx])
        time_list_pre, time_list_obs = eventGeneration(start_time, self.obs_number, self.pred_number)
        time_list = time_list_obs + time_list_pre

        frames = [self._load_frame(time) for time in time_list]
        # frames go last: (T, H, W) -> (H, W, T), the layout the transform receives
        output = np.stack(frames, axis = -1)

        if self.transform is not None:
            output = self.transform(output)
        return output, start_time
#root_dir = '/users/hbi/data/RAD_NL25_RAC_MFBS_EM_5min/'
#dataset = radarDataset(root_dir, ["200808031600"], transform = Compose([ToTensor(),CenterCrop(256)]))


In [6]:
# develop dataset
from torch.cuda.amp import autocast
from torch.autograd import Variable
from torch import optim
import pandas as pd
root_dir = '/home/hbi/RAD_NL25_RAP_5min/'

def _load_events(csv_path):
    """Read a header-less event CSV and wrap its first column in a radarDataset.

    Returns the raw frame, the list of start times (column 0) and the dataset
    built over those times with a ToTensor-only transform.
    """
    frame = pd.read_csv(csv_path, header = None)
    times = frame[0].to_list()
    return frame, times, radarDataset(root_dir, times, transform = Compose([ToTensor()]))

# M-L, top 5%
df_train, event_times, dataset_train = _load_events('/users/hbi/taming-transformers/training_Delfland08-14.csv')
df_test, event_times, dataset_test = _load_events('/users/hbi/taming-transformers/testing_Delfland18-20.csv')
df_vali, event_times, dataset_vali = _load_events('/users/hbi/taming-transformers/validation_Delfland15-17.csv')
df_train_aa, event_times, dataset_train_aa = _load_events('/users/hbi/taming-transformers/training_Aa08-14.csv')
df_train_dw, event_times, dataset_train_dw = _load_events('/users/hbi/taming-transformers/training_Dwar08-14.csv')
df_train_re, event_times, dataset_train_re = _load_events('/users/hbi/taming-transformers/training_Regge08-14.csv')

# extreme, top 1% -- column 1 of these files is kept per start time in dic_mfbs*
df_train_ext, event_times, dataset_train_ext = _load_events('/users/hbi/taming-transformers/training_Delfland08-14_ext.csv')
mfbs = df_train_ext[1].to_list()
dic_mfbs1 = dict(zip(event_times, mfbs))

df_test_ext, event_times, dataset_test_ext = _load_events('/users/hbi/taming-transformers/testing_Delfland18-20_ext.csv')
mfbs = df_test_ext[1].to_list()
dic_mfbs2 = dict(zip(event_times, mfbs))

df_vali_ext, event_times, dataset_vali_ext = _load_events('/users/hbi/taming-transformers/validation_Delfland15-17_ext.csv')
mfbs = df_vali_ext[1].to_list()
dic_mfbs3 = dict(zip(event_times, mfbs))

# merged lookup over the three extreme splits; later splits win on duplicate keys
dic_mfbs = dict(dic_mfbs1)
dic_mfbs.update(dic_mfbs2)
dic_mfbs.update(dic_mfbs3)
# Start times of the events served by the 'ext' loader.
# Two earlier assignments of new_valid (and a sort of the first one) were
# overwritten by this list before ever being read, so only the effective list
# is kept. Order and duplicates are preserved on purpose: the list is consumed
# as-is.
new_valid = ['201903070300', '201903070315', '201904241715', '201904241730', '201904241745', '201904241800', '201904241815', '201904241830', '201904241845', '201904241900', '201904241915', '201905280330', '201905280345', '201905280400', '201905280415', '201905280430', '201905280445', '201905280500', '201905280515', '201906052035', '201906052050', '201906052100', '201906120545', '201906120600', '201906120615', '201906120630', '201906120645', '201906120700', '201906120715', '201906120730', '201906120745', '201906120800', '201906120815', '201906120830', '201906120845', '201906120900', '201906120915', '201906120930', '201906120945', '201906121000', '201906121015', '201906121030', '201906121045', '201906121100', '201906150035', '201906150050', '201906150100', '201906150115', '201906150130', '201906150145', '201906150200', '201906150215', '201906150230', '201906150245', '201906150300', '201906150315', '201906150330', '201906190400', '201906190415', '201906190430', '201906190445', '201906190500', '201906190515', '201906190530', '201906190545', '201906190600', '201908121840', '201910210250', '201910210300', '201910210315', '201910210330', '201910210345', '201910210400', '201910210415', '201910210430', '201910210445', '201911280500', '201911280515', '201911280530', '201911280545', '201911280600', '201911280615', '202002091800', '202002091815', '202002091830', '202002091845', '202002091900', '202002161520', '202002161535', '202002161550', '202003051615', '202003051630', '202003051645', '202003051700', '202006050400', '202006050415', '202006050430', '202006050445', '202006050500', '202006050515', '202006050530', '202006050545', '202006050600', '202006050615', '202006050630', '202006121915', '202006121930', '202006121945', '202006122000', '202006122015', '202006122030', '202006122045', '202006161035', '202006161050', '202006161100', '202006161115', '202006161130', '202006161145', '202006161200', '202006161215', '202006161230', '202006161245', '202006171730', '202006171745', 
'202006171800', '202006171815', '202006171830', '202006171845', '202006171900', '202006171915', '202006171930', '202006171945', '202006172000', '202006172015', '202007252020', '202007252035', '202007252050', '202007252100', '202007252115', '202008161345', '202008161400', '202008161415', '202008161430', '202008161445', '202008161500', '202008161515', '202008161530', '202008161545', '202008161600', '202008161615', '202009232000', '202009232015', '202102031100', '202102031115', '202102031130', '202105131415', '202105131430', '202105131445', '202105131500', '202105131515', '202105131530', '202105131545', '202105131600', '202105131615', '202105131630', '202105131645', '202105131700', '202105131715', '202106180100', '202106180115', '202106180130', '202106180145', '202106180200', '202106192015', '202106192030', '202106192045', '202106192100', '202106192115', '202106192130', '202106192145', '202106192200', '202106192215', '202106192230', '202106192245', '201904241800', '201904241815', '201904241830', '201907111120', '201907111135', '201907111150', '201907111200', '201907111215', '201907111230', '201907111245', '201907111300', '201907111315', '201907111330', '201908091035', '201908091050', '201908091100', '201908091115', '201908091130', '201908091145', '201908091200', '201908121410', '201908121425', '201908121440', '201908121500', '201908121515', '201908121530', '201908121545', '201908121600', '201908121615', '201908121630', '201908121645', '201908282300', '201908282315', '201908282330', '201908282345', '201908290000', '201908290015', '201908290030', '201908290045', '201908290100', '201908290115', '201908290130', '201908290145', '201909261805', '201909261820', '201909261835', '201909261850', '201909261900', '201909261915', '201909261930', '201909261945', '201909262000', '201910011820', '201910011835', '201910011850', '201910011900', '201910011915', '201910011930', '201910011945', '201910012000', '201910012015', '201910012030', '202006141135', '202006141150', '202006141200', 
'202006141215', '202006141230', '202006141245', '202006141300', '202006141315', '202006141330', '202006141345', '202006141400', '202006141415', '202006141430', '202006141445', '202006141500', '202006141515', '202006141530', '202006141545', '202006141600', '202006141615', '202006141630', '202006141645', '202006141700', '202006141715', '202006141730', '201904241750', '201904241800', '201905191600', '201905191615', '201905191630', '201905191645', '201905191700', '201905191715', '201905191730', '201905191745', '201906052135', '201906052150', '201906052200', '201906052215', '201906150100', '201906150115', '201906150130', '201906150145', '201907261340', '201907261400', '201907261415', '201907261430', '201907261445', '201907261500', '201907261515', '201908282240', '201908282300', '201908282315', '201908282330', '201908282345', '201908290000', '201908290015', '201908290030', '201908290045', '201908290100', '201908290115', '201910010435', '201910010450', '201910010500', '201910010515', '201910010530', '201910010545', '201910010600', '201910200930', '201910200945', '202006122010', '202006122025', '202006122040', '202006122100', '202006122115', '202006122130', '202006122145', '202006122200', '202006122215', '202006122230', '202006122245', '202006171525', '202006171540', '202006171600', '202006171615', '202006171630', '202006171645', '202006171700', '202006171715', '202006171730', '202006171745', '202006171800', '202006171815', '202006171830', '202006171845', '202006261700', '202006261715', '202006261730', '202006261745', '202006261800', '202006261815', '202006261830', '202006261845', '202006261900', '202006261915', '202006261930', '202006261945', '202006262000', '202009261920', '202009261935', '202009261950', '202009262000']
#print(new_valid)
        
dataset_ext = radarDataset(root_dir, new_valid, transform = Compose([ToTensor()]))

# One row per loader: (name, dataset, shuffle?, worker processes).
# Every loader yields a single event per batch.
_loader_table = (
    ('test',       dataset_test_ext,  True,  8),
    ('valid',      dataset_vali,      False, 0),
    ('ext',        dataset_ext,       False, 8),
    ('train_ext',  dataset_train_ext, True,  8),
    ('train_aa5',  dataset_train_aa,  False, 0),
    ('train_del5', dataset_train,     True,  0),
    ('train_dw5',  dataset_train_dw,  False, 0),
    ('train_re5',  dataset_train_re,  False, 0),
)
loaders = {
    key: DataLoader(data, batch_size=1, shuffle=do_shuffle, num_workers=workers)
    for key, data, do_shuffle, workers in _loader_table
}

# Number of batches in the 'ext' loader (one per event, since batch_size is 1).
print(len(loaders['ext']))
#print(dataset_ext[0])

358


In [7]:
from torch import nn
from nuwa_pytorch.nuwa_pytorch import Sparse3DNA, Attention, SparseCross2DNA, cast_tuple, mult_reduce, calc_same_padding, unfoldNd, FeedForward, ShiftVideoTokens, SandwichNorm, StableLayerNorm
class SparseCross3DNA(nn.Module):
    """Cross-attention where each query attends to a 3D neighbourhood of the context.

    Queries come from ``x``; keys and values come from ``context``, which is
    reshaped into a (frames, height, width) grid, padded, and unfolded into
    one ``kernel_size`` neighbourhood per grid position.

    NOTE(review): this class uses AxialPositionalEmbedding, default, rearrange,
    padding_to_multiple_of, exists, einsum and stable_softmax, none of which
    are imported in this cell; they are imported in the following cell, so the
    class can only be instantiated / called after that cell has run.

    Args:
        dim: feature size of both ``x`` and ``context`` (input size of
            ``to_q`` / ``to_kv``, output size of ``to_out``).
        video_shape: ``(frames, fmap_size, fmap_size)`` of the context grid;
            only the first two entries are read individually.
        kernel_size: neighbourhood size, int or 3-tuple (frame, height, width).
        dilation: neighbourhood dilation, int or 3-tuple.
        heads: number of attention heads.
        dim_head: feature size per head.
        dropout: dropout probability applied to the attention weights.
        causal: if True, mask keys whose flat index is larger than the query
            index; otherwise only the padding positions are masked.
        query_num_frames_chunk: stored (defaulting to the number of frames)
            but not read by ``forward`` in this implementation.
        rel_pos_bias: if True, create a per-head positional bias over the
            kernel positions.
    """
    def __init__(
        self,
        dim,
        video_shape,
        kernel_size = 3,
        dilation = 1,
        heads = 8,
        dim_head = 64,
        dropout = 0.,
        causal = False,
        query_num_frames_chunk = None,
        rel_pos_bias = False
    ):
        super().__init__()
        inner_dim = dim_head * heads
        self.heads = heads
        self.scale = dim_head ** -0.5
        self.dropout = nn.Dropout(dropout)
        # queries are projected from x, keys and values (one fused projection) from the context
        self.to_q = nn.Linear(dim, inner_dim, bias = False)
        self.to_kv = nn.Linear(dim, inner_dim * 2, bias = False)
        # 1x1 convolution over the head axis: mixes attention maps across heads
        self.talking_heads = nn.Conv2d(heads, heads, 1, bias = False)
        self.to_out = nn.Linear(inner_dim, dim)
        self.dilation = cast_tuple(dilation, size = 3)
        self.kernel_size = cast_tuple(kernel_size, size = 3)
        self.kernel_numel = mult_reduce(self.kernel_size)
        # relative positional bias per head, if needed
        self.rel_pos_bias = AxialPositionalEmbedding(heads, shape = self.kernel_size) if rel_pos_bias else None
        # calculate padding
        self.padding_frame = calc_same_padding(self.kernel_size[0], self.dilation[0])
        self.padding_height = calc_same_padding(self.kernel_size[1], self.dilation[1])
        self.padding_width = calc_same_padding(self.kernel_size[2], self.dilation[2])
        # F.pad order: last dimension first -> (width, width, height, height, frame, frame)
        self.video_padding = (self.padding_width, self.padding_width, self.padding_height, self.padding_height, self.padding_frame, self.padding_frame)
        # save video shape and calculate max number of tokens
        self.video_shape = video_shape
        max_frames, fmap_size, _ = video_shape
        max_num_tokens = torch.empty(video_shape).numel()
        self.max_num_tokens = max_num_tokens
        # how many query tokens to process at once to limit peak memory usage, by multiple of frame tokens (fmap_size ** 2)
        # NOTE(review): forward() does not read this attribute; it chunks by the number of context frames instead.
        self.query_num_frames_chunk = default(query_num_frames_chunk, max_frames)
        # precalculate the mask: unfold the grid of flat token indices exactly like the keys will be unfolded
        indices = torch.arange(max_num_tokens)
        shaped_indices = rearrange(indices, '(f h w) -> 1 1 f h w', f = max_frames, h = fmap_size, w = fmap_size)
        padded_indices = F.pad(shaped_indices, self.video_padding, value = max_num_tokens) # padding has value of max tokens so to be masked out
        unfolded_indices = unfoldNd(padded_indices, kernel_size = self.kernel_size, dilation = self.dilation)
        # one row per grid position, one column per kernel offset
        unfolded_indices = rearrange(unfolded_indices, '1 k n -> n k')
        # if causal, compare query and key indices and make sure past cannot see future
        # if not causal, just mask out the padding

        if causal:
            mask = rearrange(indices, 'n -> n 1') < unfolded_indices
        else:
            mask = unfolded_indices == max_num_tokens

        # the bos column that used to be prepended to the mask is disabled (no bos token is handled in forward)
        #mask = F.pad(mask, (1, 0), value = False) # bos tokens never get masked out
        self.register_buffer('mask', mask)
        #print(self.mask.shape)

    def forward(self, x, context, **kwargs):
        """Attend from ``x`` to local 3D neighbourhoods of ``context``.

        Args:
            x: query tokens, unpacked as ``(batch, n, dim)``.
            context: key/value tokens; ``context.shape[1]`` is its token
                count. It is reshaped to a grid with ``fmap_size`` rows and
                columns, so that count must be a multiple of ``fmap_size ** 2``.
            **kwargs: accepted and ignored.

        Returns:
            Tensor produced by ``to_out`` with one row per query token
            (or ``to_out`` of the projected values when ``n == 1``).
        """
        b, n, _, h, device = *x.shape, self.heads, x.device
        n_context = context.shape[1]
        # more variables

        dilation = self.dilation
        kernel_size = self.kernel_size
        video_padding = self.video_padding
        fmap_size = self.video_shape[1]
        tokens_per_frame = fmap_size ** 2

        # NOTE(review): `n - 1` looks like a leftover from a version with a leading bos token
        # (the bos handling below is commented out) -- confirm the resulting frame counts.
        padding = padding_to_multiple_of(n - 1, tokens_per_frame)
        num_frames = (n + padding) // tokens_per_frame  # not used below
        num_frames_input = (n_context + padding) // tokens_per_frame

        # derive queries / keys / values

        q, k, v = (self.to_q(x), *self.to_kv(context).chunk(2, dim = -1))

        # single-token input: skip attention and return the projected context values
        bos_only = n == 1
        if bos_only:
            return self.to_out(v)
        # split out heads

        q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h = h), (q, k, v))


        # take care of bos (disabled)

        #q = q[:, 1:]
        #bos_value = q[:, :1]

        # scale queries
        q = q * self.scale

        # reshape keys and values to video and add appropriate padding along all dimensions (frames, height, width)
        k, v = map(lambda t: rearrange(t, 'b (f h w) d -> b d f h w',  h = fmap_size, w = fmap_size), (k, v))
        k, v = map(lambda t: F.pad(t, video_padding), (k, v))
        #print(k.shape)
        # axial relative pos bias

        rel_pos_bias = None

        if exists(self.rel_pos_bias):
            rel_pos_bias = rearrange(self.rel_pos_bias(), 'j h -> h 1 j')
            # NOTE(review): this adds one extra leading column, while the matching bos column of the
            # mask is commented out in __init__ -- the bias may not line up with `sim`; confirm
            # before enabling rel_pos_bias.
            rel_pos_bias = F.pad(rel_pos_bias, (1, 0), value = 0.)

        # put the attention processing code in a function
        # to allow for processing queries in chunks of frames

        out = []

        def attend(q, k, v, mask, kernel_size):
            """Attention of one query chunk over the unfolded key/value neighbourhoods."""
            chunk_length = q.shape[1]
            # one neighbourhood of kernel_numel keys / values per grid position
            # NOTE(review): k and v are identical for every chunk, so this unfold is repeated per chunk.
            k, v = map(lambda t: unfoldNd(t, kernel_size = kernel_size, dilation = dilation), (k, v))
            k, v = map(lambda t: rearrange(t, 'b (d j) i -> b i j d', j = self.kernel_numel), (k, v))
            # keep only as many neighbourhoods as there are queries in the chunk
            k, v = map(lambda t: t[:, :chunk_length], (k, v))

            # calculate sim

            sim = einsum('b i d, b i j d -> b i j', q, k)
            #print(q.shape, k.shape, sim.shape)
            # add rel pos bias, if needed

            if exists(rel_pos_bias):
                sim = sim + rel_pos_bias

            # apply the precomputed mask (padding positions, plus future keys when causal)


            if exists(mask):
                mask_value = -torch.finfo(sim.dtype).max
                mask = rearrange(mask, 'i j -> 1 i j')
                sim = sim.masked_fill(mask, mask_value)

            # attention

            attn = stable_softmax(sim, dim = -1)

            # expose the head axis so the 1x1 conv can mix attention maps across heads
            attn = rearrange(attn, '(b h) ... -> b h ...', h = h)
            attn = self.talking_heads(attn)
            attn = rearrange(attn, 'b h ... -> (b h) ...')

            attn = self.dropout(attn)

            # aggregate values

            return einsum('b i j, b i j d -> b i d', attn, v)

        # process queries in chunks that each span as many tokens as the context frames do

        frames_per_chunk = num_frames_input
        chunk_size = frames_per_chunk * tokens_per_frame
        q_chunks = q.split(chunk_size, dim = 1)
        mask = self.mask
        #print(q.shape, chunk_size)
        for ind, q_chunk in enumerate(q_chunks):
            #print(ind, q_chunks[ind].shape, mask.shape)
            q_chunk = q_chunks[ind]
            length = q_chunk.shape[1]
            # pad a short (last) chunk along the token axis up to chunk_size; the padded rows are cut off again below
            q_chunk = F.pad(input=q_chunk, pad=(0, 0, 0, chunk_size-length), mode='constant', value=float("-Inf"))
            mask_chunk = mask

            # every chunk attends to the full (padded) keys and values

            k_slice, v_slice = k,v
            # calculate output chunk
            out_chunk = attend(
                q = q_chunk,
                k = k_slice,
                v = v_slice,
                mask = mask_chunk,
                kernel_size = kernel_size,
            )

            # drop the rows that belong to the query padding
            out_chunk = out_chunk[:,:length,:]
            out.append(out_chunk)

        # combine all chunks
        out = torch.cat(out, dim = 1)
        # append bos value (disabled)

        #out = torch.cat((bos_value, out), dim = 1)  # bos will always adopt its own value, since it pays attention only to itself
        # merge heads

        out = rearrange(out, '(b h) n d -> b n (h d)', h = h)
        return self.to_out(out)
    
class Transformer(nn.Module):
    """Stack of (self-attention, optional cross-attention, feed-forward) layers.

    Every sub-layer is wrapped in ``SandwichNorm`` and applied residually; a
    ``StableLayerNorm`` is applied to the final output.

    Args:
        dim: model feature size.
        depth: number of layers.
        causal: passed to the self-attention modules.
        heads, dim_head: attention head count and per-head size.
        ff_mult, ff_dropout, ff_chunk_size: feed-forward settings.
        cross_attend: add a cross-attention sub-layer to every layer.
        attn_dropout: dropout for the attention modules that accept it.
        cross_2dna_attn: use ``SparseCross2DNA`` for cross-attention.
        cross_2dna_image_size, cross_2dna_kernel_size, cross_2dna_dilations:
            settings of the 2DNA cross-attention; dilations cycle per layer.
        cross_3dna_attn: use ``SparseCross3DNA`` for cross-attention (only
            consulted when ``cross_2dna_attn`` is False).
        cross_3dna_image_size: ``video_shape`` of the 3DNA cross-attention.
        cross_3dna_kernel_size: kernel size of the 3DNA cross-attention.
        cross_3dna_dilations: dilations of the 3DNA cross-attention, cycled
            per layer. This argument used to be accepted but ignored (every
            layer ran with dilation 1); the default ``(1,)`` keeps that
            behaviour.
        sparse_3dna_attn: use ``Sparse3DNA`` instead of full ``Attention``
            for self-attention.
        sparse_3dna_kernel_size, sparse_3dna_video_shape,
        sparse_3dna_query_num_frames_chunk, sparse_3dna_dilations,
        sparse_3dna_rel_pos_bias: settings of the 3DNA self-attention;
            dilations cycle per layer.
        shift_video_tokens: wrap self-attention and feed-forward in
            ``ShiftVideoTokens`` (only when ``sparse_3dna_attn`` is True).
        rotary_pos_emb: accepted for interface compatibility; not used by
            this implementation.
    """

    def __init__(
        self,
        *,
        dim,
        depth,
        causal = False,
        heads = 8,
        dim_head = 64,
        ff_mult = 4,
        cross_attend = False,
        attn_dropout = 0.,
        ff_dropout = 0.,
        ff_chunk_size = None,
        cross_2dna_attn = False,
        cross_2dna_image_size = None,
        cross_2dna_kernel_size = 3,
        cross_2dna_dilations = (1,),
        cross_3dna_attn = False,
        cross_3dna_image_size = None,
        cross_3dna_kernel_size = 3,
        cross_3dna_dilations = (1,),
        sparse_3dna_attn = False,
        sparse_3dna_kernel_size = 3,
        sparse_3dna_video_shape = None,
        sparse_3dna_query_num_frames_chunk = None,
        sparse_3dna_dilations = (1,),
        sparse_3dna_rel_pos_bias = False,
        shift_video_tokens = False,
        rotary_pos_emb = False
    ):
        super().__init__()

        self.layers = MList([])

        for ind in range(depth):
            # ---- self-attention ----
            if sparse_3dna_attn:
                # dilations are cycled over the layers
                dilation = sparse_3dna_dilations[ind % len(sparse_3dna_dilations)]

                self_attn = Sparse3DNA(
                    dim = dim,
                    heads = heads,
                    dim_head = dim_head,
                    causal = causal,
                    kernel_size = sparse_3dna_kernel_size,
                    dilation = dilation,
                    video_shape = sparse_3dna_video_shape,
                    query_num_frames_chunk = sparse_3dna_query_num_frames_chunk,
                    rel_pos_bias = sparse_3dna_rel_pos_bias,
                )
            else:
                self_attn = Attention(
                    dim = dim,
                    heads = heads,
                    dim_head = dim_head,
                    causal = causal,
                    dropout = attn_dropout
                )

            # ---- cross-attention (optional) ----
            cross_attn = None

            if cross_attend:
                if cross_2dna_attn:
                    dilation = cross_2dna_dilations[ind % len(cross_2dna_dilations)]

                    cross_attn = SparseCross2DNA(
                        dim = dim,
                        heads = heads,
                        dim_head = dim_head,
                        dropout = attn_dropout,
                        image_size = cross_2dna_image_size,
                        kernel_size = cross_2dna_kernel_size,
                        dilation = dilation
                    )

                elif cross_3dna_attn:
                    # cycle the 3DNA dilations per layer, like the 2DNA branch does;
                    # previously cross_3dna_dilations was silently ignored
                    dilation = cross_3dna_dilations[ind % len(cross_3dna_dilations)]

                    cross_attn = SparseCross3DNA(
                        dim = dim,
                        heads = heads,
                        dim_head = dim_head,
                        dropout = attn_dropout,
                        video_shape = cross_3dna_image_size,
                        kernel_size = cross_3dna_kernel_size,
                        dilation = dilation,
                    )

                else:
                    cross_attn = Attention(
                        dim = dim,
                        heads = heads,
                        dim_head = dim_head,
                        dropout = attn_dropout
                    )

            # ---- feed-forward ----
            ff = FeedForward(dim = dim, mult = ff_mult, dropout = ff_dropout, chunk_size = ff_chunk_size)

            if sparse_3dna_attn and shift_video_tokens:
                fmap_size = sparse_3dna_video_shape[-1]
                self_attn = ShiftVideoTokens(self_attn, image_size = fmap_size)
                ff        = ShiftVideoTokens(ff, image_size = fmap_size)

            # the cross-attention slot stays None when cross_attend is off
            self.layers.append(MList([
                SandwichNorm(dim = dim, fn = self_attn),
                SandwichNorm(dim = dim, fn = cross_attn) if cross_attend else None,
                SandwichNorm(dim = dim, fn = ff)
            ]))

        self.norm = StableLayerNorm(dim)

    def forward(
        self,
        x,
        mask = None,
        context = None,
        context_mask = None
    ):
        """Run all layers residually and normalise the result.

        Args:
            x: input tokens.
            mask: forwarded to self-attention and cross-attention.
            context: forwarded to cross-attention (when present).
            context_mask: forwarded to cross-attention (when present).

        Returns:
            ``self.norm`` applied to the output of the last layer.
        """
        for attn, cross_attn, ff in self.layers:
            x = attn(x, mask = mask) + x

            if exists(cross_attn):
                x = cross_attn(x, context = context, mask = mask, context_mask = context_mask) + x

            x = ff(x) + x

        return self.norm(x)

In [8]:
# transformer for second stage
from torch import nn
import nuwa_pytorch.nuwa_pytorch
from nuwa_pytorch.nuwa_pytorch import MList, Embedding, AxialPositionalEmbedding, ReversibleTransformer, default, eval_decorator, prob_mask_like, exists
from nuwa_pytorch.nuwa_pytorch import padding_to_multiple_of, einsum, top_k
from nuwa_pytorch.vqgan_vae import stable_softmax
from einops import rearrange, repeat
import torch.nn.functional as F
from main import instantiate_from_config
#from dice_loss import DiceLoss
class nuwa_v2v(nn.Module):
    """Second-stage video-to-video model working on VQ codebook indices.

    Observed frames and (teacher-forced or already generated) future frames are
    both given as flat sequences of codebook indices. Indices are embedded with
    a shared ``Embedding``, positional embeddings are added, and a causal
    ``Transformer`` decoder that cross-attends to the observed-frame embeddings
    produces logits over the codebook for every output position.
    """

    def __init__(
        self,
        *,
        dim,
        vae = None,
        image_size = 256,
        codebook_size = 1024,
        compression_rate = 16,
        video_out_seq_len = 5,
        video_in_seq_len = 5,
        video_dec_depth = 6,
        video_dec_dim_head = 64,
        video_dec_heads = 8,
        attn_dropout = 0.,
        ff_dropout = 0.,
        ff_chunk_size = None,
        embed_gradient_frac = 0.2,
        shift_video_tokens = True,
        sparse_3dna_kernel_size = 3,
        sparse_3dna_query_num_frames_chunk = None,
        sparse_3dna_dilation = 1,
    ):
        """Build embeddings, positional embeddings, the decoder and the output head.

        Args:
            dim: model / embedding width.
            vae: accepted for call-site compatibility but not used here; the
                model consumes pre-computed codebook indices.
            image_size: side length of the input frames in pixels.
            codebook_size: number of codebook entries (vocabulary size).
            compression_rate: ``image_size // compression_rate`` gives the side
                length of the token feature map of one frame.
            video_out_seq_len: number of frames to predict.
            video_in_seq_len: number of observed frames.
            video_dec_depth, video_dec_dim_head, video_dec_heads: decoder
                depth, per-head width and number of heads.
            attn_dropout, ff_dropout, ff_chunk_size: forwarded to the decoder.
            embed_gradient_frac: forwarded to ``Embedding`` as ``frac_gradient``.
            shift_video_tokens: forwarded to the decoder.
            sparse_3dna_kernel_size: kernel size of the decoder's sparse 3DNA
                self-attention.
            sparse_3dna_query_num_frames_chunk: forwarded to the decoder.
            sparse_3dna_dilation: an int ``k`` is expanded to the dilations
                ``(1, ..., k)``; a list/tuple is forwarded unchanged.
        """
        super().__init__()

        self.image_size = image_size
        num_image_tokens = codebook_size
        fmap_size = image_size // compression_rate
        self.video_fmap_size = fmap_size
        self.video_frames_in = video_in_seq_len
        self.video_frames_out = video_out_seq_len
        self.image_embedding = Embedding(num_image_tokens, dim, frac_gradient=embed_gradient_frac)

        # dilations for sparse 3d-nearby attention: expand an int k to (1, ..., k),
        # keep an explicit list/tuple as given
        if isinstance(sparse_3dna_dilation, (list, tuple)):
            sparse_3dna_dilations = sparse_3dna_dilation
        else:
            sparse_3dna_dilations = tuple(range(1, sparse_3dna_dilation + 1))

        # (frames, height, width) of the token grids for observed / predicted video
        video_shape_in = (video_in_seq_len, fmap_size, fmap_size)
        video_shape_out = (video_out_seq_len, fmap_size, fmap_size)
        self.video_in_pos_emb = AxialPositionalEmbedding(dim, shape = video_shape_in)
        self.video_out_pos_emb = AxialPositionalEmbedding(dim, shape = video_shape_out)
        self.video_bos = nn.Parameter(torch.randn(dim))

        # 3DNA decoder: causal self-attention over the output tokens,
        # cross-attention onto the observed-frame tokens
        self.video_decode_transformer = Transformer(
            dim = dim,
            depth = video_dec_depth,
            heads = video_dec_heads,
            dim_head = video_dec_dim_head,
            causal = True,
            cross_attend = True,
            attn_dropout = attn_dropout,
            ff_dropout = ff_dropout,
            ff_chunk_size = ff_chunk_size,
            cross_3dna_attn = True,
            cross_3dna_image_size = video_shape_in,
            cross_3dna_kernel_size = (5,3,3),
            shift_video_tokens = shift_video_tokens,
            sparse_3dna_video_shape = video_shape_out,
            sparse_3dna_attn = True,
            sparse_3dna_kernel_size = sparse_3dna_kernel_size,
            sparse_3dna_dilations = sparse_3dna_dilations,
            sparse_3dna_query_num_frames_chunk = sparse_3dna_query_num_frames_chunk
        )

        self.to_logits = nn.Linear(dim, num_image_tokens)

    @torch.no_grad()
    @eval_decorator
    def generate(
        self,
        *,
        video_input_latent,
        video_output_latent = None,
        filter_thres = 0.9,
        temperature = 1.,
        num_frames = 6,
        prior_weight = 1,
        alpha = 1
    ):
        """Autoregressively sample codebook indices for ``num_frames`` future frames.

        At every step the logit of the token that sits at the same spatial
        position in the previous frame (the "prior") is multiplied by
        ``prior_weight`` before top-k filtering and Gumbel sampling. The prior
        starts as the last observed frame and is replaced by each newly
        completed generated frame; after every completed frame ``prior_weight``
        becomes ``max(prior_weight * alpha, 1)``.

        Args:
            video_input_latent: codebook indices of the observed frames,
                flattened per sample (a singleton dim 1 is squeezed away).
            video_output_latent: unused; kept for call-site compatibility.
            filter_thres: threshold forwarded to ``top_k``.
            temperature: temperature forwarded to ``gumbel_sample``.
            num_frames: number of frames to generate; ``None`` falls back to
                ``self.video_frames_out``.
            prior_weight: initial multiplier applied to the prior token's logit.
            alpha: per-frame decay factor of ``prior_weight`` (floored at 1).

        Returns:
            Long tensor of shape ``(batch, num_frames * tokens_per_frame)``
            holding the sampled codebook indices.
        """
        device = video_input_latent.device
        batch = video_input_latent.shape[0]

        frame_indices_input = video_input_latent.squeeze(1)
        frame_embeddings_input = self.image_embedding(frame_indices_input)
        frame_embeddings_input = self.video_in_pos_emb().repeat(batch,1,1) + frame_embeddings_input

        bos = repeat(self.video_bos, 'd -> b 1 d', b = batch)
        video_indices = torch.empty((batch, 0), device = device, dtype = torch.long)
        num_tokens_per_frame = self.video_fmap_size ** 2

        num_frames = default(num_frames, self.video_frames_out)
        total_video_tokens = num_tokens_per_frame * num_frames

        pos_emb = self.video_out_pos_emb()
        batch_rows = torch.arange(batch, device = device)

        # Prior = tokens of the last observed frame, tiled so it can be indexed
        # with the global token position `ind`. Slicing from the end (instead of
        # a fixed offset) keeps this correct for any number of input frames and
        # any feature-map size.
        prior = frame_indices_input[:, -num_tokens_per_frame:].repeat(1, num_frames)

        for ind in range(total_video_tokens):
            print(ind, '/', total_video_tokens, end = '\r')

            frame_embeddings = self.image_embedding(video_indices)
            frame_embeddings = pos_emb[:frame_embeddings.shape[1]] + frame_embeddings
            frame_embeddings = torch.cat((bos, frame_embeddings), dim = 1)

            frame_embeddings = self.video_decode_transformer(
                frame_embeddings,
                context = frame_embeddings_input
            )

            logits = self.to_logits(frame_embeddings)
            logits = logits[:, -1, :]

            # Re-weight, per sample, the logit of that sample's prior token.
            # Row-wise indexing works for any batch size.
            # NOTE(review): a multiplicative weight > 1 lowers the probability
            # of the prior token when its logit is negative — confirm intended.
            logits[batch_rows, prior[:, ind]] *= prior_weight
            filtered_logits = top_k(logits, thres = filter_thres)

            sample = gumbel_sample(filtered_logits, temperature = temperature, dim = -1)
            sample = rearrange(sample, 'b -> b 1')
            video_indices = torch.cat((video_indices, sample), dim = 1)

            if (ind + 1) % num_tokens_per_frame == 0:
                # A frame was just completed: it becomes the prior for the next
                # frame, and the prior weight decays towards 1.
                prior = video_indices[:, -num_tokens_per_frame:].repeat(1, num_frames)
                prior_weight = max(prior_weight*alpha, 1)

        return video_indices

    def forward(
        self,
        *,
        video_input_latent,
        video_output_latent,
        return_loss = False,
        cond_dropout_prob = 0
    ):
        """Compute codebook logits for the output sequence, optionally with the loss.

        Args:
            video_input_latent: codebook indices of the observed frames.
            video_output_latent: codebook indices of the target frames.
            return_loss: when True the target sequence is shifted right behind
                a learned BOS embedding (teacher forcing) and the cross-entropy
                against the unshifted targets is returned as well.
            cond_dropout_prob: unused; kept for call-site compatibility.

        Returns:
            ``logits`` when ``return_loss`` is False, otherwise
            ``(loss, logits)``.
        """
        batch = video_input_latent.shape[0]

        # get indices; for training drop the last target token (it is never an input)
        frame_indices_input = video_input_latent.squeeze(1)
        frame_indices_output = video_output_latent.squeeze(1)
        frame_indices_output = frame_indices_output[:, :-1] if return_loss else frame_indices_output

        # indices to embedding
        frame_embeddings_input = self.image_embedding(frame_indices_input)
        frame_embeddings_prediction = self.image_embedding(frame_indices_output)

        # position encoding
        frame_embeddings_input = self.video_in_pos_emb().repeat(batch,1,1) + frame_embeddings_input
        if return_loss:
            frame_embeddings_prediction = self.video_out_pos_emb()[:-1].repeat(batch,1,1) + frame_embeddings_prediction
            bos = repeat(self.video_bos, 'd -> b 1 d', b = batch)
            frame_embeddings_prediction = torch.cat((bos, frame_embeddings_prediction), dim = 1)
        else:
            frame_embeddings_prediction = self.video_out_pos_emb().repeat(batch,1,1) + frame_embeddings_prediction

        # transformer
        frame_embeddings_prediction = self.video_decode_transformer(
            frame_embeddings_prediction,
            context = frame_embeddings_input
        )
        logits = self.to_logits(frame_embeddings_prediction)

        if not return_loss:
            return logits

        # cross_entropy expects the class dimension second: (b, n, c) -> (b, c, n)
        loss = F.cross_entropy(rearrange(logits, 'b n c -> b c n'), video_output_latent.squeeze(1))

        return loss, logits

In [9]:
from torchsummary import summary
#torch.cuda.empty_cache()
# Hyper-parameters of the second-stage model, collected in one place so they
# are easy to compare against the checkpoint that is loaded into `nuwa` later.
nuwa_kwargs = dict(
    dim = 512,
    vae = None,
    image_size = 256,
    codebook_size = 1024,
    compression_rate = 16,
    video_out_seq_len = 6,               # frames to predict
    video_in_seq_len = 3,                # observed frames
    video_dec_depth = 12,
    video_dec_dim_head = 64,
    video_dec_heads = 4,
    sparse_3dna_kernel_size = (7,3,3),
    attn_dropout = 0,
    ff_dropout = 0,
)
nuwa = nuwa_v2v(**nuwa_kwargs).to(DEVICE)
#summary(nuwa)
#print('')

In [10]:
from pysteps.verification.detcatscores import det_cat_fct
from pysteps.verification.detcontscores import det_cont_fct
from pysteps.verification.spatialscores import intensity_scale
from pysteps.visualization import plot_precip_field
from nuwa_pytorch.nuwa_pytorch import top_k, gumbel_sample
import random
device = DEVICE
nuwa.eval()

# One row per split: (key, dataset, shuffle, num_workers).
# Every loader yields single-sample batches.
_loader_specs = (
    ('test',       dataset_test_ext,  True,  8),
    ('valid',      dataset_vali,      False, 0),
    ('ext',        dataset_ext,       False, 8),
    ('train_ext',  dataset_train_ext, True,  8),
    ('train_aa5',  dataset_train_aa,  False, 0),
    ('train_del5', dataset_train,     True,  0),
    ('train_dw5',  dataset_train_dw,  False, 0),
    ('train_re5',  dataset_train_re,  False, 0),
)
loaders = {
    split: DataLoader(dataset, batch_size=1, shuffle=shuffled, num_workers=workers)
    for split, dataset, shuffled, workers in _loader_specs
}
#torch.cuda.empty_cache()
from collections import Counter
import time
from pysteps.verification.detcatscores import det_cat_fct
from pysteps.verification.detcontscores import det_cont_fct
from pysteps.verification.spatialscores import intensity_scale
from pysteps.visualization import plot_precip_field
from tqdm import tqdm
import math
import random

Pysteps configuration file found at: /users/junzheyin/anaconda3/envs/myenv/lib/python3.8/site-packages/pysteps/pystepsrc



In [13]:
# Evaluate the second-stage model on the 'ext' loader.
# For every event the forecast is either generated (and cached as .npy) or
# loaded from the cache, then scored per lead time against the observed frames
# with continuous (MSE/MAE/PCC), categorical (CSI/FAR) and spatial (FSS) scores.
n = 6                  # number of predicted frames (lead times) scored per event
num_obs_frames = 3     # leading frames of every sample that are fed to the model
PRECIP_SCALE = 40      # frames are multiplied by this before scoring/plotting
                       # (presumably undoes the dataset normalisation — TODO confirm)

# Per-lead-time running sums; converted in place to averages after the loop.
pcc_sum = [0]*n
mse_sum = [0]*n
mae_sum = [0]*n
csi1_sum = [0]*n
csi2_sum = [0]*n
csi8_sum = [0]*n
far1_sum = [0]*n
far2_sum = [0]*n
far8_sum = [0]*n
fss1_sum = [0]*n
fss10_sum = [0]*n
fss20_sum = [0]*n
fss30_sum = [0]*n
# Per-lead-time number of samples that actually contributed to the 2 mm / 8 mm
# CSI and FAR sums (samples with a NaN score are skipped there).
thr2_valid = [0]*n
thr8_valid = [0]*n
record = {}

name = '/home/hbi/predictions_new_ce_6/{}.npy'
checkpoint = torch.load( '/bulk/junzheyin/nuwa_newvae_sparse7_evl_75_epoch12',map_location="cpu")
nuwa.load_state_dict(checkpoint, strict=True)
# Make sure the cache directory exists before any expensive generation runs.
os.makedirs(os.path.dirname(name), exist_ok=True)


def _lead_time_mean(total, num_valid):
    """Return ``total / num_valid`` rounded to 5 decimals, or NaN if nothing was accumulated."""
    if num_valid <= 0:
        return float('nan')
    return round(total / num_valid, 5)


count = 0       # number of evaluated events
count_nan = 0   # number of events with at least one NaN CSI/FAR (2 mm or 8 mm) at any lead time
j = 0
# The second loader item is named `start_times` (not `time`) so the `time`
# module imported earlier in the notebook is not overwritten.
for i, (images, start_times) in enumerate(loaders['ext']):
    start_time = start_times[0]
    print(i, start_time)
    a = images[0].unsqueeze(1).unsqueeze(0).to(DEVICE)   # add channel and batch dims
    a1 = a.squeeze(0)
    count += 1

    cache_path = name.format(start_time)
    if not os.path.exists(cache_path):
        # Encode all frames to codebook indices, split into observed / target
        # frames and flatten each part to a single token sequence.
        indice = vae.get_video_indices(a)
        indice_obs = torch.flatten(indice[:, :num_obs_frames, :, :]).unsqueeze(0)
        indice_pre = torch.flatten(indice[:, num_obs_frames:, :, :]).unsqueeze(0)
        video_generate1 = nuwa.generate(video_input_latent = indice_obs,
                                        video_output_latent = indice_pre,
                                        filter_thres = 0.98,
                                        temperature = 0.5,
                                        num_frames = n,
                                        prior_weight = 2,
                                        alpha = 0.8)

        video_predict1 = vae.codebook_indices_to_video(video_generate1).squeeze(0)
        vp = video_predict1[:,0,:,:].to('cpu').detach().numpy()*PRECIP_SCALE
        np.save(cache_path, vp)
    else:
        vp = np.load(cache_path)

    # Catchment sub-window of prediction and ground truth (kept for inspection).
    vp_catchment = vp[:,168:193,54:80]
    vg_catchment = a1[:,0,168:193,54:80].to('cpu').detach().numpy()*PRECIP_SCALE

    has_nan = False
    for t in range(n):
        a1_display = a1[t+num_obs_frames,0,:,:].to('cpu').detach().numpy()*PRECIP_SCALE
        a_p1_display = vp[t,:,:]

        scores_cont1 = det_cont_fct(a_p1_display, a1_display, thr=0.1)
        scores_cat1 = det_cat_fct(a_p1_display, a1_display, 1)
        scores_cat2 = det_cat_fct(a_p1_display, a1_display, 2)
        scores_cat8 = det_cat_fct(a_p1_display, a1_display, 8)
        scores_spatial = intensity_scale(a_p1_display, a1_display, 'FSS', 1, [1,10,20,30])

        print("Start Time:", start_time, "Lead Time: {} mins".format(-90+(t+j+4)*30),'\n',
              'MSE:', np.around(scores_cont1['MSE'],3),
              'MAE:', np.around(scores_cont1['MAE'],3),
              'PCC:', np.around(scores_cont1['corr_p'],3),'\n',
              'CSI(1mm):', np.around(scores_cat1['CSI'],3), # CSI: TP/(TP+FP+FN)
              'CSI(2mm):', np.around(scores_cat2['CSI'],3),
              'CSI(8mm):', np.around(scores_cat8['CSI'],3),'\n',
              'FSS(1km):', np.around(scores_spatial[3][0],3),
              'FSS(10km):', np.around(scores_spatial[2][0],3),
              'FSS(20km):', np.around(scores_spatial[1][0],3),
              'FSS(30km):', np.around(scores_spatial[0][0],3)
             )

        pcc_sum[t] += scores_cont1['corr_p']
        mse_sum[t] += scores_cont1['MSE']
        mae_sum[t] += scores_cont1['MAE']
        csi1_sum[t] += scores_cat1['CSI']
        far1_sum[t] += scores_cat1['FAR']

        # For the higher thresholds a NaN score is skipped; the sample is then
        # not counted in that lead time's denominator either.
        if not (math.isnan(scores_cat2['CSI']) or math.isnan(scores_cat2['FAR'])):
            csi2_sum[t] += scores_cat2['CSI']
            far2_sum[t] += scores_cat2['FAR']
            thr2_valid[t] += 1
        else:
            has_nan = True
        if not (math.isnan(scores_cat8['CSI']) or math.isnan(scores_cat8['FAR'])):
            csi8_sum[t] += scores_cat8['CSI']
            far8_sum[t] += scores_cat8['FAR']
            thr8_valid[t] += 1
        else:
            has_nan = True

        fss1_sum[t] += scores_spatial[3][0]
        fss10_sum[t] += scores_spatial[2][0]
        fss20_sum[t] += scores_spatial[1][0]
        fss30_sum[t] += scores_spatial[0][0]

        # Side-by-side plot: observation (left) and prediction (middle).
        plt.figure(figsize=(12, 4))
        plt.subplot(131)
        plot_precip_field(a1_display, title="t+{}".format((t+1)*30), axis='off')
        plt.subplot(132)
        plot_precip_field(a_p1_display, title="Prediction")
        plt.tight_layout()
        plt.show()
        plt.close()

    if has_nan:
        count_nan += 1

# Turn the running sums into per-lead-time averages. The 2 mm / 8 mm CSI and
# FAR use the number of samples that actually contributed at that lead time, so
# both scores share one consistent denominator.
for t in range(n):
    pcc_sum[t] = _lead_time_mean(pcc_sum[t], count)
    mse_sum[t] = _lead_time_mean(mse_sum[t], count)
    mae_sum[t] = _lead_time_mean(mae_sum[t], count)
    csi1_sum[t] = _lead_time_mean(csi1_sum[t], count)
    csi2_sum[t] = _lead_time_mean(csi2_sum[t], thr2_valid[t])
    csi8_sum[t] = _lead_time_mean(csi8_sum[t], thr8_valid[t])
    far1_sum[t] = _lead_time_mean(far1_sum[t], count)
    far2_sum[t] = _lead_time_mean(far2_sum[t], thr2_valid[t])
    far8_sum[t] = _lead_time_mean(far8_sum[t], thr8_valid[t])
    fss1_sum[t] = _lead_time_mean(fss1_sum[t], count)
    fss10_sum[t] = _lead_time_mean(fss10_sum[t], count)
    fss20_sum[t] = _lead_time_mean(fss20_sum[t], count)
    fss30_sum[t] = _lead_time_mean(fss30_sum[t], count)
print('PCC average:', pcc_sum)
print('MSE average:', mse_sum)
print('MAE average:', mae_sum)
print('CSI_1 average:', csi1_sum)
print('CSI_2 average:', csi2_sum)
print('CSI_8 average:', csi8_sum)
print('FAR_1 average:', far1_sum)
print('FAR_2 average:', far2_sum)
print('FAR_8 average:', far8_sum)
print('FSS(1km) average:', fss1_sum)
print('FSS(10km) average:', fss10_sum)
print('FSS(20km) average:', fss20_sum)
print('FSS(30km) average:', fss30_sum)

        