In [1]:
import sys
sys.path.append(".")
# also disable grad to save memory
import torch
DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
import yaml
import torch
from omegaconf import OmegaConf
from taming.models.vqgan import VQModel, GumbelVQ
import io
import os, sys
import requests
import PIL
from PIL import Image
from PIL import ImageDraw, ImageFont
import numpy as np
import torch
import torch.nn.functional as F
import torchvision.transforms as T
import torchvision.transforms.functional as TF

In [2]:
# Define dataset
import torch
import sys
from nuwa_pytorch import VQGanVAE
import h5py
from PIL import Image
import matplotlib.pyplot as plt
from datetime import datetime, timedelta
def eventGeneration(start_time, obs_time = 3 ,lead_time = 6, time_interval = 30):
    # Generate event based on starting time point, return a list: [[t-4,...,t-1,t], [t+1,...,t+72]]
    # Get the start year, month, day, hour, minute
    year = int(start_time[0:4])
    month = int(start_time[4:6])
    day = int(start_time[6:8])
    hour = int(start_time[8:10])
    minute = int(start_time[10:12])
    #print(datetime(year=year, month=month, day=day, hour=hour, minute=minute))
    times = [(datetime(year, month, day, hour, minute) + timedelta(minutes=time_interval * (x+1))) for x in range(lead_time)]
    lead = [dt.strftime('%Y%m%d%H%M') for dt in times]
    times = [(datetime(year, month, day, hour, minute) - timedelta(minutes=time_interval * x)) for x in range(obs_time)]
    obs = [dt.strftime('%Y%m%d%H%M') for dt in times]
    obs.reverse()
    return lead, obs

from torch.utils.data import Dataset, DataLoader
import h5py
import numpy as np
from torchvision.transforms import ToTensor, Compose, CenterCrop
class radarDataset(Dataset):
    def __init__(self, root_dir, event_times, obs_number = 3, pred_number = 18, transform=None):
        # event_times is an array of starting time t(string)
        # transform is the preprocessing functions
        self.root_dir = root_dir
        self.transform = transform
        self.event_times = event_times
        self.obs_number = obs_number
        self.pred_number = pred_number
    def __len__(self):
        return len(self.event_times)
    def __getitem__(self, idx):
        start_time = str(self.event_times[idx])
        time_list_pre, time_list_obs = eventGeneration(start_time, self.obs_number, self.pred_number)
        output = []
        time_list = time_list_obs + time_list_pre
        #print(time_list)
        for time in time_list:
            year = time[0:4]
            month = time[4:6]
            #path = self.root_dir + year + '/' + month + '/' + 'RAD_NL25_RAC_MFBS_EM_5min_' + time + '_NL.h5'
            path = self.root_dir + year + '/' + month + '/' + 'RAD_NL25_RAP_5min_' + time + '.h5'
            image = np.array(h5py.File(path)['image1']['image_data'])
            #image = np.ma.masked_where(image == 65535, image)
            image = image[264:520,242:498]
            image[image == 65535] = 0
            image = image.astype('float32')
            image = image/100*12
            image = np.clip(image, 0, 128)
            image = image/128
            output.append(image)
        output = torch.permute(torch.tensor(np.array(output)), (1, 2, 0))
        output = self.transform(np.array(output))
        return output
#root_dir = '/users/hbi/data/RAD_NL25_RAC_MFBS_EM_5min/'
#dataset = radarDataset(root_dir, ["200808031600"], transform = Compose([ToTensor(),CenterCrop(256)]))


In [3]:
# develop dataset
from torch.cuda.amp import autocast
from torch.autograd import Variable
import pandas as pd
root_dir = '/home/hbi/RAD_NL25_RAP_5min/' 

df_train = pd.read_csv('training_Delfland08-14_20.csv', header = None)
event_times = df_train[0].to_list()
dataset_train = radarDataset(root_dir, event_times, transform = Compose([ToTensor()]))  

df_train_s = pd.read_csv('training_Delfland08-14.csv', header = None)
event_times = df_train_s[0].to_list()
dataset_train_del = radarDataset(root_dir, event_times, transform = Compose([ToTensor()]))  

df_test = pd.read_csv('testing_Delfland18-20.csv', header = None)
event_times = df_test[0].to_list()
dataset_test = radarDataset(root_dir, event_times, transform = Compose([ToTensor()]))

df_vali = pd.read_csv('validation_Delfland15-17.csv', header = None)
event_times = df_vali[0].to_list()
dataset_vali = radarDataset(root_dir, event_times, transform = Compose([ToTensor()]))

df_train_aa = pd.read_csv('training_Aa08-14.csv', header = None)
event_times = df_train_aa[0].to_list()
dataset_train_aa = radarDataset(root_dir, event_times, transform = Compose([ToTensor()]))  

df_train_dw = pd.read_csv('training_Dwar08-14.csv', header = None)
event_times = df_train_dw[0].to_list()
dataset_train_dw = radarDataset(root_dir, event_times, transform = Compose([ToTensor()]))    

df_train_re = pd.read_csv('training_Regge08-14.csv', header = None)
event_times = df_train_re[0].to_list()
dataset_train_re = radarDataset(root_dir, event_times, transform = Compose([ToTensor()]))   

print(len(dataset_train), len(dataset_test), len(dataset_vali))
loaders = { 'train' :DataLoader(dataset_train, batch_size=1, shuffle=True, num_workers=8),
            'test' :DataLoader(dataset_test, batch_size=1, shuffle=True, num_workers=8), 
           'valid' :DataLoader(dataset_vali, batch_size=1, shuffle=False, num_workers=8),
          
          'train_aa5' :DataLoader(dataset_train_aa, batch_size=1, shuffle=False, num_workers=8),
          'train_dw5' :DataLoader(dataset_train_dw, batch_size=1, shuffle=False, num_workers=8),
          'train_del5' :DataLoader(dataset_train_del, batch_size=1, shuffle=False, num_workers=8),
          'train_re5' :DataLoader(dataset_train_re, batch_size=1, shuffle=False, num_workers=8),
          }

32183 3493 3560


In [4]:
from torch import nn
from nuwa_pytorch.nuwa_pytorch import Sparse3DNA, Attention, SparseCross2DNA, cast_tuple, mult_reduce, calc_same_padding, unfoldNd, FeedForward, ShiftVideoTokens, SandwichNorm, StableLayerNorm
class SparseCross3DNA(nn.Module):
    def __init__(
        self,
        dim,
        video_shape,
        kernel_size = 3,
        dilation = 1,
        heads = 8,
        dim_head = 64,
        dropout = 0.,
        causal = False,
        query_num_frames_chunk = None,
        rel_pos_bias = False
    ):
        super().__init__()
        inner_dim = dim_head * heads
        self.heads = heads
        self.scale = dim_head ** -0.5
        self.dropout = nn.Dropout(dropout)
        self.to_q = nn.Linear(dim, inner_dim, bias = False)
        self.to_kv = nn.Linear(dim, inner_dim * 2, bias = False)
        self.talking_heads = nn.Conv2d(heads, heads, 1, bias = False)
        self.to_out = nn.Linear(inner_dim, dim)
        self.dilation = cast_tuple(dilation, size = 3)
        self.kernel_size = cast_tuple(kernel_size, size = 3)
        self.kernel_numel = mult_reduce(self.kernel_size)
       # relative positional bias per head, if needed
        self.rel_pos_bias = AxialPositionalEmbedding(heads, shape = self.kernel_size) if rel_pos_bias else None
        # calculate padding
        self.padding_frame = calc_same_padding(self.kernel_size[0], self.dilation[0])
        self.padding_height = calc_same_padding(self.kernel_size[1], self.dilation[1])
        self.padding_width = calc_same_padding(self.kernel_size[2], self.dilation[2])
        self.video_padding = (self.padding_width, self.padding_width, self.padding_height, self.padding_height, self.padding_frame, self.padding_frame)
        # save video shape and calculate max number of tokens
        self.video_shape = video_shape
        max_frames, fmap_size, _ = video_shape
        max_num_tokens = torch.empty(video_shape).numel()
        self.max_num_tokens = max_num_tokens
        # how many query tokens to process at once to limit peak memory usage, by multiple of frame tokens (fmap_size ** 2)
        self.query_num_frames_chunk = default(query_num_frames_chunk, max_frames)
        # precalculate causal mask
        indices = torch.arange(max_num_tokens)
        shaped_indices = rearrange(indices, '(f h w) -> 1 1 f h w', f = max_frames, h = fmap_size, w = fmap_size)
        padded_indices = F.pad(shaped_indices, self.video_padding, value = max_num_tokens) # padding has value of max tokens so to be masked out
        unfolded_indices = unfoldNd(padded_indices, kernel_size = self.kernel_size, dilation = self.dilation)
        unfolded_indices = rearrange(unfolded_indices, '1 k n -> n k')
        # if causal, compare query and key indices and make sure past cannot see future
        # if not causal, just mask out the padding

        if causal:
            mask = rearrange(indices, 'n -> n 1') < unfolded_indices
        else:
            mask = unfolded_indices == max_num_tokens

        #mask = F.pad(mask, (1, 0), value = False) # bos tokens never get masked out
        self.register_buffer('mask', mask)

    def forward(self, x, context, **kwargs):
        b, n, _, h, device = *x.shape, self.heads, x.device
        n_context = context.shape[1]
        # more variables

        dilation = self.dilation
        kernel_size = self.kernel_size
        video_padding = self.video_padding
        fmap_size = self.video_shape[1]
        tokens_per_frame = fmap_size ** 2

        padding = padding_to_multiple_of(n - 1, tokens_per_frame)
        num_frames = (n + padding) // tokens_per_frame
        num_frames_input = (n_context + padding) // tokens_per_frame
        
        # derive queries / keys / values

        q, k, v = (self.to_q(x), *self.to_kv(context).chunk(2, dim = -1))

        # split out heads

        q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h = h), (q, k, v))
        
        # take care of bos

        #q = q[:, 1:]
        #bos_value = q[:, :1]
        
        # scale queries
        q = q * self.scale

        # reshape keys and values to video and add appropriate padding along all dimensions (frames, height, width)
        k, v = map(lambda t: rearrange(t, 'b (f h w) d -> b d f h w',  h = fmap_size, w = fmap_size), (k, v))
        k, v = map(lambda t: F.pad(t, video_padding), (k, v))
        #print(k.shape)
        # axial relative pos bias

        rel_pos_bias = None

        if exists(self.rel_pos_bias):
            rel_pos_bias = rearrange(self.rel_pos_bias(), 'j h -> h 1 j')
            rel_pos_bias = F.pad(rel_pos_bias, (1, 0), value = 0.)

        # put the attention processing code in a function
        # to allow for processing queries in chunks of frames

        out = []

        def attend(q, k, v, mask, kernel_size):
            chunk_length = q.shape[1]

            k, v = map(lambda t: unfoldNd(t, kernel_size = kernel_size, dilation = dilation), (k, v))
            k, v = map(lambda t: rearrange(t, 'b (d j) i -> b i j d', j = self.kernel_numel), (k, v))
            k, v = map(lambda t: t[:, :chunk_length], (k, v))

            # calculate sim

            sim = einsum('b i d, b i j d -> b i j', q, k)

            # add rel pos bias, if needed

            if exists(rel_pos_bias):
                sim = sim + rel_pos_bias

            # causal mask

            if exists(mask):
                mask_value = -torch.finfo(sim.dtype).max
                mask = rearrange(mask, 'i j -> 1 i j')
                sim = sim.masked_fill(mask, mask_value)

            # attention

            attn = stable_softmax(sim, dim = -1)

            attn = rearrange(attn, '(b h) ... -> b h ...', h = h)
            attn = self.talking_heads(attn)
            attn = rearrange(attn, 'b h ... -> (b h) ...')

            attn = self.dropout(attn)

            # aggregate values

            return einsum('b i j, b i j d -> b i d', attn, v)

        # process queries in chunks

        frames_per_chunk = min(self.query_num_frames_chunk, num_frames)
        chunk_size = frames_per_chunk * tokens_per_frame
        q_chunks = q.split(chunk_size, dim = 1)
        
        mask = self.mask

        for ind, q_chunk in enumerate(q_chunks):
            #print(ind, q_chunks[ind].shape, mask.shape)
            q_chunk = q_chunks[ind]
            mask_chunk = mask
            #print(q_chunk.shape, mask.shape, k.shape, v.shape)
            # slice the keys and values to the appropriate frames, accounting for padding along frames dimension

            kv_start_pos = ind * frames_per_chunk
            kv_end_pos = kv_start_pos + (ind + frames_per_chunk + self.padding_frame * 2)
            kv_frame_range = slice(kv_start_pos, kv_end_pos)

            k_slice, v_slice = map(lambda t: t[:, :, kv_frame_range], (k, v))
            k_slice, v_slice = k,v
            # calculate output chunk
            out_chunk = attend(
                q = q_chunk,
                k = k_slice,
                v = v_slice,
                mask = mask_chunk,
                kernel_size = kernel_size,
            )
            out.append(out_chunk)

        # combine all chunks
        out = torch.cat(out, dim = 1)
        # append bos value

        #out = torch.cat((bos_value, out), dim = 1)  # bos will always adopt its own value, since it pays attention only to itself
        # merge heads

        out = rearrange(out, '(b h) n d -> b n (h d)', h = h)
        return self.to_out(out)
    
class Transformer(nn.Module):
    def __init__(
        self,
        *,
        dim,
        depth,
        causal = False,
        heads = 8,
        dim_head = 64,
        ff_mult = 4,
        cross_attend = False,
        attn_dropout = 0.,
        ff_dropout = 0.,
        ff_chunk_size = None,
        cross_2dna_attn = False,
        cross_2dna_image_size = None,
        cross_2dna_kernel_size = 3,
        cross_2dna_dilations = (1,),
        cross_3dna_attn = False,
        cross_3dna_image_size = None,
        cross_3dna_kernel_size = 5,
        cross_3dna_dilations = (1,),
        sparse_3dna_attn = False,
        sparse_3dna_kernel_size = 3,
        sparse_3dna_video_shape = None,
        sparse_3dna_query_num_frames_chunk = None,
        sparse_3dna_dilations = (1,),
        sparse_3dna_rel_pos_bias = False,
        shift_video_tokens = False,
        rotary_pos_emb = False
    ):
        super().__init__()

        self.layers = MList([])

        for ind in range(depth):
            if sparse_3dna_attn:
                dilation = sparse_3dna_dilations[ind % len(sparse_3dna_dilations)]

                self_attn = Sparse3DNA(
                    dim = dim,
                    heads = heads,
                    dim_head = dim_head,
                    causal = causal,
                    kernel_size = sparse_3dna_kernel_size,
                    dilation = dilation,
                    video_shape = sparse_3dna_video_shape,
                    query_num_frames_chunk = sparse_3dna_query_num_frames_chunk,
                    rel_pos_bias = sparse_3dna_rel_pos_bias,
                )
            else:
                self_attn = Attention(
                    dim = dim,
                    heads = heads,
                    dim_head = dim_head,
                    causal = causal,
                    dropout = attn_dropout
                )

            cross_attn = None

            if cross_attend:
                if cross_2dna_attn:
                    dilation = cross_2dna_dilations[ind % len(cross_2dna_dilations)]

                    cross_attn = SparseCross2DNA(
                        dim = dim,
                        heads = heads,
                        dim_head = dim_head,
                        dropout = attn_dropout,
                        image_size = cross_2dna_image_size,
                        kernel_size = cross_2dna_kernel_size,
                        dilation = dilation
                    )
                
                elif cross_3dna_attn:
                    
                    cross_attn = SparseCross3DNA(
                        dim = dim,
                        heads = heads,
                        dim_head = dim_head,
                        dropout = attn_dropout,
                        video_shape = cross_3dna_image_size,
                        kernel_size = cross_3dna_kernel_size,
                    )
                
                else:
                    cross_attn = Attention(
                        dim = dim,
                        heads = heads,
                        dim_head = dim_head,
                        dropout = attn_dropout
                    )

            ff = FeedForward(dim = dim, mult = ff_mult, dropout = ff_dropout, chunk_size = ff_chunk_size)

            if sparse_3dna_attn and shift_video_tokens:
                fmap_size = sparse_3dna_video_shape[-1]
                self_attn = ShiftVideoTokens(self_attn, image_size = fmap_size)
                ff        = ShiftVideoTokens(ff, image_size = fmap_size)

            self.layers.append(MList([
                SandwichNorm(dim = dim, fn = self_attn),
                SandwichNorm(dim = dim, fn = cross_attn) if cross_attend else None,
                SandwichNorm(dim = dim, fn = ff)
            ]))

        self.norm = StableLayerNorm(dim)

    def forward(
        self,
        x,
        mask = None,
        context = None,
        context_mask = None
    ):
        for attn, cross_attn, ff in self.layers:
            x = attn(x, mask = mask) + x

            if exists(cross_attn):
                x = cross_attn(x, context = context, mask = mask, context_mask = context_mask) + x

            x = ff(x) + x

        return self.norm(x)


In [5]:
weight = {676: 1286569, 674: 706515, 129: 367104, 178: 289731, 267: 252592, 33: 244407, 105: 186288, 657: 148341, 232: 127132, 148: 125003, 409: 106906, 577: 105078, 197: 103030, 408: 98176, 115: 94823, 302: 90022, 896: 85156, 752: 81477, 23: 74606, 450: 74520, 553: 72955, 82: 72498, 584: 71854, 484: 69222, 855: 60575, 45: 58398, 135: 56177, 708: 53631, 406: 52498, 124: 52310, 503: 51673, 652: 50737, 564: 50727, 654: 50581, 798: 50264, 587: 49564, 184: 48388, 505: 47029, 880: 46506, 474: 46405, 970: 46203, 428: 46016, 159: 46003, 380: 44474, 715: 44128, 369: 43548, 728: 42507, 303: 42143, 25: 41416, 994: 41078, 261: 41035, 336: 40827, 663: 40308, 968: 40199, 181: 39830, 283: 39522, 128: 39213, 354: 38859, 819: 37764, 44: 37715, 417: 37550, 540: 37264, 595: 37020, 392: 37009, 582: 36752, 222: 36533, 501: 36164, 217: 35185, 101: 34958, 800: 34712, 653: 34639, 114: 34391, 956: 33259, 607: 32589, 216: 32503, 27: 32237, 1007: 32015, 744: 31891, 109: 31707, 499: 31690, 1013: 31499, 658: 31273, 358: 31153, 784: 30926, 103: 29705, 659: 29129, 562: 28409, 988: 27714, 198: 27196, 457: 27012, 231: 26399, 816: 25858, 603: 23635, 452: 23135, 160: 22420, 171: 21585, 891: 20388, 723: 19264, 844: 17305, 610: 17243, 797: 16665, 460: 16509, 290: 14949, 704: 95, 13: 2}
total = 0
for key in weight:
    total += weight[key]
for key in weight:
    weight[key] =(weight[key]/total*100)
w = [weight[min(weight)]]*1024
for key in weight:
    w[key] = weight[key]
for i in range(1024):
    w[i] = min(1/w[i], 1000)
w = torch.tensor(w).unsqueeze(0).to(DEVICE)
print(w.shape)

torch.Size([1, 1024])


In [6]:
# evt modules
from torch import nn
from nuwa_pytorch.nuwa_pytorch import Embedding
from einops import rearrange
# gt_indicator: 
# indicator: [n, 1024] logits + [1, 1024] masks 
# gamma: hyperparameter
# beta0: propotion of non-extreme tokens
# beta1: propotion of extreme tokens 

def cal_evt_loss(indicator, gt_indicator, gamma = 1, beta0 = 0.95, beta1 = 0.05):
    loss1 = -1 * beta0 * torch.pow((1-indicator/gamma),gamma) * gt_indicator * torch.log(indicator)
    loss2 = -1 * beta1 * torch.pow((1-(1-indicator)/gamma),gamma) * (1-gt_indicator) * torch.log(1-indicator)
    loss = loss1 + loss2 
    return loss

# logits b * 1024 * 1536
# pre_token b * 1 * 1536
# ext_tokens 1*15
def evt_loss(logits, gt_tokens, ext_tokens, gamma = 2, beta0 = 0.95, beta1 = 0.05):
    batch = logits.shape[0]
    channel = logits.shape[1]
    number = logits.shape[2]
    device = logits.device
    softmax = nn.Softmax(dim = -1)
    loss = 0
    count = 0
    for batch_id in range(batch):
        for n in range(number):
            gt = gt_tokens[batch_id,n]
            if gt not in ext_tokens: continue
            gt_indicator = 1
            prob = softmax(logits[batch_id,:,n])
            indicator = 0
            for token in ext_tokens:
                indicator += prob[token]
            loss += cal_evt_loss(indicator, gt_indicator, gamma, beta0, beta1)
            count += 1
    return loss/count

In [7]:
# transformer for second stage
# transformer for second stage
from torch import nn
import nuwa_pytorch.nuwa_pytorch
from nuwa_pytorch.nuwa_pytorch import MList, Embedding, AxialPositionalEmbedding, ReversibleTransformer, default, eval_decorator, prob_mask_like, exists
from nuwa_pytorch.nuwa_pytorch import padding_to_multiple_of, einsum, stable_softmax
from einops import rearrange, repeat
import torch.nn.functional as F
from main import instantiate_from_config
from dice_loss import DiceLoss
class nuwa_v2v(nn.Module):
    def __init__(
        self,
        *,
        dim,
        vae = None,
        image_size = 256,
        codebook_size = 1024,
        compression_rate = 16,
        video_out_seq_len = 5,
        video_in_seq_len = 5,
        video_dec_depth = 6,
        video_dec_dim_head = 64,
        video_dec_heads = 8,
        attn_dropout = 0.,
        ff_dropout = 0.,
        ff_chunk_size = None,
        embed_gradient_frac = 0.2,
        shift_video_tokens = True,
        sparse_3dna_kernel_size = 3,
        sparse_3dna_query_num_frames_chunk = None,
        sparse_3dna_dilation = 1,
    ):
        super().__init__()
        
        # VAE
        #self.first_stage_model = vae.copy_for_eval()
        self.image_size = image_size
        num_image_tokens = codebook_size
        fmap_size = image_size // compression_rate
        self.video_fmap_size = fmap_size
        self.video_frames_in = video_in_seq_len
        self.video_frames_out = video_out_seq_len
        self.image_embedding = Embedding(num_image_tokens, dim, frac_gradient=embed_gradient_frac)
        # cycle dilation for sparse 3d-nearby attention 
        sparse_3dna_dilations = tuple(range(1, sparse_3dna_dilation + 1)) if not isinstance(sparse_3dna_dilation, (list, tuple)) else sparse_3dna_dilation  # ???
        
        video_shape_in = (video_in_seq_len, fmap_size, fmap_size)
        video_shape_out = (video_out_seq_len, fmap_size, fmap_size)
        self.video_in_pos_emb = AxialPositionalEmbedding(dim, shape = video_shape_in)
        self.video_out_pos_emb = AxialPositionalEmbedding(dim, shape = video_shape_out)
        self.video_bos = nn.Parameter(torch.randn(dim))
        
        # 3DNA Encoder
        video_shape_input = (video_in_seq_len, fmap_size, fmap_size)
    
        # 3DNA Decoder
        #self.video_bos = nn.Parameter(torch.randn(dim))#
        #self.image_embedding = Embedding(num_image_tokens, dim, frac_gradient = embed_gradient_frac)#???
        # self.video_pos_emb = AxialPositionalEmbedding(dim, shape = video_shape)#???

        video_shape_output = (video_out_seq_len, fmap_size, fmap_size)
        self.video_decode_transformer = Transformer(
            dim = dim,
            depth = video_dec_depth,
            heads = video_dec_heads,
            dim_head = video_dec_dim_head,
            causal = True,
            cross_attend = True,
            attn_dropout = attn_dropout,
            ff_dropout = ff_dropout,
            ff_chunk_size = ff_chunk_size,
            cross_3dna_attn = True,
            cross_3dna_image_size = video_shape_input ,
            cross_3dna_kernel_size = (5,3,3),
            shift_video_tokens = shift_video_tokens,
            sparse_3dna_video_shape = video_shape_output,
            sparse_3dna_attn = True,
            sparse_3dna_kernel_size = sparse_3dna_kernel_size,
            sparse_3dna_dilations = sparse_3dna_dilations,
            sparse_3dna_query_num_frames_chunk = sparse_3dna_query_num_frames_chunk
        )

        self.to_logits = nn.Linear(dim, num_image_tokens)
    
    @torch.no_grad()
    @eval_decorator
    def generate(
        self,
        *,
        video_input_latent,
        filter_thres = 0.9,
        temperature = 1.,
        num_frames = 6
    ):
        device = video_input_latent.device
        batch = video_input_latent.shape[0]
        
        frame_indices_input = video_input_latent.squeeze(1)
        frame_embeddings_input = self.image_embedding(frame_indices_input)
        frame_embeddings_input = self.video_in_pos_emb().repeat(batch,1,1) + frame_embeddings_input
        
        bos = repeat(self.video_bos, 'd -> b 1 d', b = batch)
        video_indices = torch.empty((batch, 0), device = device, dtype = torch.long)
        num_tokens_per_frame = self.video_fmap_size ** 2
        
        num_frames = default(num_frames, self.video_frames_out)
        total_video_tokens =  num_tokens_per_frame * num_frames

        pos_emb = self.video_out_pos_emb()

        for ind in range(total_video_tokens):
            print(ind, '/', total_video_tokens, end = '\r')
            video_indices_input = video_indices
            num_video_tokens = video_indices.shape[1]
            frame_embeddings = self.image_embedding(video_indices_input)
            frame_embeddings = pos_emb[:frame_embeddings.shape[1]] + frame_embeddings
            frame_embeddings = torch.cat((bos, frame_embeddings), dim = 1)

            frame_embeddings = self.video_decode_transformer(
                frame_embeddings,
                context = frame_embeddings_input
            )

            logits = self.to_logits(frame_embeddings)
            logits = logits[:, -1, :]
            filtered_logits = top_k(logits, thres = filter_thres)
            sample = gumbel_sample(filtered_logits, temperature = temperature, dim = -1)
            sample = rearrange(sample, 'b -> b 1')
            video_indices = torch.cat((video_indices, sample), dim = 1)

        return video_indices

    def forward(
        self,
        *,
        #video_input, # raw observation video here
        #video_output, # raw prediction video here
        video_input_latent,
        video_output_latent,
        return_loss = False,
        cond_dropout_prob = 0,
        use_evt_loss = False,
        evt_loss_weight = 0.5
    ):
        device = video_input_latent.device
        batch = video_input_latent.shape[0]
        seq_out_len = self.video_frames_out
        seq_in_len = self.video_frames_in
        
        #get indices
        frame_indices_input = video_input_latent.squeeze(1)
        frame_indices_output = video_output_latent.squeeze(1)
        frame_indices_output = frame_indices_output[:, :-1] if return_loss else frame_indices_output
        
        #indices to embedding
        frame_embeddings_input = self.image_embedding(frame_indices_input)
        frame_embeddings_prediction = self.image_embedding(frame_indices_output)

        #position encoding
        frame_embeddings_input = self.video_in_pos_emb().repeat(batch,1,1) + frame_embeddings_input
        if return_loss:
            frame_embeddings_prediction = self.video_out_pos_emb()[:-1].repeat(batch,1,1) + frame_embeddings_prediction
            # shift right
            bos = repeat(self.video_bos, 'd -> b 1 d', b = batch)
            frame_embeddings_prediction = torch.cat((bos, frame_embeddings_prediction), dim = 1)
        else:
            frame_embeddings_prediction = self.video_out_pos_emb().repeat(batch,1,1) + frame_embeddings_prediction
        
        #transformer 
        frame_embeddings_prediction = self.video_decode_transformer(
            frame_embeddings_prediction,
            context = frame_embeddings_input
        )
        logits = self.to_logits(frame_embeddings_prediction)
        
        if not return_loss:
            return logits
        loss = F.cross_entropy(rearrange(logits, 'b n c -> b c n'), video_output_latent.squeeze(1))
        total_loss = loss
        evt_loss_term = 0 
        if use_evt_loss:
            ext_tokens = torch.tensor([452, 303, 171, 653, 891, 603, 499, 160, 216, 109, 290, 797, 956, 607, 816]).to(DEVICE)
            evt_loss_term = evt_loss_weight * evt_loss(rearrange(logits, 'b n c -> b c n'),  
                                               video_output_latent.squeeze(1), 
                                               ext_tokens)  
            if not torch.isnan(evt_loss_term):
                total_loss = loss + evt_loss_weight *  evt_loss_term    
        
        return total_loss, loss, logits, evt_loss_term

In [8]:
class latentDataset(Dataset):
    def __init__(self, root_dir, radarset):
        # event_times is an array of starting time t(string)
        # transform is the preprocessing functions
        self.root_dir = root_dir
        self.radarset = radarset
    def __len__(self):
        if self.radarset: return len(self.radarset)
        else: return 30000
    def __getitem__(self, idx):
        dir_file = self.root_dir + str(idx) + '.pt'
        if os.path.exists(dir_file):
            video_in_latent, video_out_latent = torch.load(dir_file, map_location='cpu')
            return video_in_latent, video_out_latent
        else:
            print("File not found")
            return None

latent_valid = latentDataset('/home/hbi/vali1517_newvae/validset', dataset_vali)
latent_test = latentDataset('/home/hbi/test1820/testset', dataset_test)
latent_train_de5  = latentDataset('/home/hbi/train0814_Delf_newvae/trainset', dataset_train_del)
latent_train_aa5  = latentDataset('/home/hbi/train0814_Aa_newvae/trainset', dataset_train_aa)
latent_train_dw5  = latentDataset('/home/hbi/train0814_Dwar_newvae/trainset', dataset_train_dw)
latent_train_re5 = latentDataset('/home/hbi/train0814_Regge_newvae/trainset', dataset_train_re)
latent_list = [latent_train_de5, latent_train_aa5, latent_train_dw5, latent_train_re5]
latent_train_aadedwre = torch.utils.data.ConcatDataset(latent_list)

loaders_latent = { 'test' :DataLoader(latent_test, batch_size=1, shuffle=True, num_workers=0),
                  'train_aadedwre' :DataLoader(latent_train_aadedwre, batch_size=1, shuffle=True, num_workers=0),
                  #'train_del5' :DataLoader(latent_train_de5, batch_size=1, shuffle=True, num_workers=0),
                  'train_aa5' :DataLoader(latent_train_aa5, batch_size=1, shuffle=True, num_workers=0)}
print(len(loaders_latent['train_aa5']))

7575


In [9]:
from torchsummary import summary
from torch import optim
#torch.cuda.empty_cache()
nuwa = nuwa_v2v(dim=512 , vae=None,
                image_size = 256,
                codebook_size = 1024,
                compression_rate = 16,
                video_out_seq_len = 6,
                video_in_seq_len = 3,
                video_dec_depth = 12,
                video_dec_dim_head = 64,
                video_dec_heads = 4,
                sparse_3dna_kernel_size = (7,3,3),
                attn_dropout = 0.2,
                ff_dropout = 0.2).to(DEVICE)
#optimizer = optim.Adam(nuwa.parameters(),lr = 1e-3, weight_decay = 1e-6)
optimizer = optim.AdamW(nuwa.parameters(), lr = 1e-3, weight_decay = 0.1)
summary(nuwa)
print('')

Layer (type:depth-idx)                   Param #
├─Embedding: 1-1                         --
|    └─Embedding: 2-1                    524,288
├─AxialPositionalEmbedding: 1-2          17,920
├─AxialPositionalEmbedding: 1-3          19,456
├─Transformer: 1-4                       --
|    └─ModuleList: 2-2                   --
|    |    └─ModuleList: 3-1              3,155,658
|    |    └─ModuleList: 3-2              3,155,658
|    |    └─ModuleList: 3-3              3,155,658
|    |    └─ModuleList: 3-4              3,155,658
|    |    └─ModuleList: 3-5              3,155,658
|    |    └─ModuleList: 3-6              3,155,658
|    |    └─ModuleList: 3-7              3,155,658
|    |    └─ModuleList: 3-8              3,155,658
|    |    └─ModuleList: 3-9              3,155,658
|    |    └─ModuleList: 3-10             3,155,658
|    |    └─ModuleList: 3-11             3,155,658
|    |    └─ModuleList: 3-12             3,155,658
|    └─StableLayerNorm: 2-3              --
|    |    └─LayerN

In [15]:
import time
from tqdm import tqdm
#from pysteps.visualization import plot_precip_field
nuwa.to(DEVICE)
#torch.cuda.empty_cache()
num_epochs = 10
total_step = len(loaders_latent['train_aadedwre'])
loss_sum = 0
evl_sum = 0
count = 0
ga = 64
number = 12 

for g in optimizer.param_groups:
    g['lr'] = 5e-4
    g['weight_decay'] = 0.1
for epoch in range(num_epochs):
    for i, latent in enumerate(loaders_latent['train_aadedwre']):
        nuwa.train()
        torch.set_grad_enabled(True)
        #if i<0:continue
        #if i>=1:break
        latent_in = latent[0].to(DEVICE)
        latent_out = latent[1].to(DEVICE)
        
        total_loss, loss, _, evt_loss_term = nuwa(video_input_latent = latent_in,
            video_output_latent = latent_out, 
            return_loss = True,
            cond_dropout_prob = 0.1,
            use_evt_loss = True,
            evt_loss_weight = 0.75)
        (total_loss / ga).backward()
        
        if (i+1)%ga == 0:
            torch.nn.utils.clip_grad_norm_(nuwa.parameters(), 0.5)
            optimizer.step()
            optimizer.zero_grad()
        
        loss_sum += float(total_loss.item())
        evl_sum += float(evt_loss_term)
        count += 1
        
        # backpropagation, compute gradients   
        print( (i+1) % total_step, '/{}, loss: {:.4f}, evl: {:.4f}'.format(total_step, loss_sum/count, evt_loss_term), end='\r')
        
        if (i+1) % total_step == 0:
            print ('Epoch [{}/{}], Step [{}/{}], Average Loss: {:.4f}' 
                   .format(epoch + 1, num_epochs, i + 1, total_step, loss_sum/count))
            loss_sum = 0
            count = 0 
            # save check point
            torch.save(nuwa.state_dict(),'nuwa_newvae_sparse7_evl_75_epoch{}'.format(number))
            torch.save(optimizer.state_dict(), 'nuwa_newvae_sparse7_evl_75_epoch{}_optim'.format(number))
            for g in optimizer.param_groups:
                g['lr'] = g['lr']*0.98
            number += 1

1012 /37685, loss: 1.5290, evl: 0.4733


KeyboardInterrupt

