In [1]:
import sys
sys.path.append(".")

# also disable grad to save memory
import torch
torch.set_grad_enabled(False)

DEVICE = torch.device("cuda:1" if torch.cuda.is_available() else "cpu")

import yaml
import torch
from omegaconf import OmegaConf
from taming.models.vqgan import VQModel, GumbelVQ

def load_config(config_path, display=False):
    config = OmegaConf.load(config_path)
    if display:
        print(yaml.dump(OmegaConf.to_container(config)))
    return config

def load_vqgan(config, ckpt_path=None, is_gumbel=False):
    if is_gumbel:
        model = GumbelVQ(**config.model.params)
    else:
        model = VQModel(**config.model.params)
    if ckpt_path is not None:
        sd = torch.load(ckpt_path, map_location="cpu")["state_dict"]
        missing, unexpected = model.load_state_dict(sd, strict=False)
    return model.eval()

def reconstruct_with_vqgan(x, model):
    # could also use model(x) for reconstruction but use explicit encoding and decoding here
    z, _, [_, _, indices], _ = model.encode(x)
    #print(f"VQGAN --- {model.__class__.__name__}: latent shape: {z.shape[2:]}")
    xrec = model.decode(z)
    return xrec

In [2]:
config1024 = load_config("logs/vqgan_imagenet_f16_1024/configs/model.yaml", display=False)
#config16384 = load_config("logs/vqgan_imagenet_f16_16384/configs/model.yaml", display=True)

model1024 = load_vqgan(config1024, ckpt_path="logs/vqgan_imagenet_f16_1024/checkpoints/last.ckpt").to(DEVICE)
#model16384 = load_vqgan(config16384, ckpt_path="logs/vqgan_imagenet_f16_16384/checkpoints/last.ckpt").to(DEVICE)

Working with z of shape (1, 256, 16, 16) = 65536 dimensions.
loaded pretrained LPIPS loss from taming/modules/autoencoder/lpips/vgg.pth
VQLPIPSWithDiscriminator running with hinge loss.


In [3]:
import io
import os, sys
import requests
import PIL
from PIL import Image
from PIL import ImageDraw, ImageFont
import numpy as np

import torch
import torch.nn.functional as F
import torchvision.transforms as T
import torchvision.transforms.functional as TF

In [4]:
# Define dataset
import torch
import sys
from nuwa_pytorch import VQGanVAE
import h5py
from PIL import Image
import matplotlib.pyplot as plt
from datetime import datetime, timedelta
def eventGeneration(start_time, obs_time = 3 ,lead_time = 6, time_interval = 30):
    # Generate event based on starting time point, return a list: [[t-4,...,t-1,t], [t+1,...,t+72]]
    # Get the start year, month, day, hour, minute
    year = int(start_time[0:4])
    month = int(start_time[4:6])
    day = int(start_time[6:8])
    hour = int(start_time[8:10])
    minute = int(start_time[10:12])
    #print(datetime(year=year, month=month, day=day, hour=hour, minute=minute))
    times = [(datetime(year, month, day, hour, minute) + timedelta(minutes=time_interval * (x+1))) for x in range(lead_time)]
    lead = [dt.strftime('%Y%m%d%H%M') for dt in times]
    times = [(datetime(year, month, day, hour, minute) - timedelta(minutes=time_interval * x)) for x in range(obs_time)]
    obs = [dt.strftime('%Y%m%d%H%M') for dt in times]
    obs.reverse()
    return lead, obs

from torch.utils.data import Dataset, DataLoader
import h5py
import numpy as np
from torchvision.transforms import ToTensor, Compose, CenterCrop
class radarDataset(Dataset):
    def __init__(self, root_dir, event_times, obs_number = 3, pred_number = 6, transform=None):
        # event_times is an array of starting time t(string)
        # transform is the preprocessing functions
        self.root_dir = root_dir
        self.transform = transform
        self.event_times = event_times
        self.obs_number = obs_number
        self.pred_number = pred_number
    def __len__(self):
        return len(self.event_times)
    def __getitem__(self, idx):
        start_time = str(self.event_times[idx])
        time_list_pre, time_list_obs = eventGeneration(start_time, self.obs_number, self.pred_number)
        output = []
        time_list = time_list_obs + time_list_pre
        #print(time_list)
        for time in time_list:
            year = time[0:4]
            month = time[4:6]
            #path = self.root_dir + year + '/' + month + '/' + 'RAD_NL25_RAC_MFBS_EM_5min_' + time + '_NL.h5'
            path = self.root_dir + year + '/' + month + '/' + 'RAD_NL25_RAP_5min_' + time + '.h5'
            image = np.array(h5py.File(path)['image1']['image_data'])
            #image = np.ma.masked_where(image == 65535, image)

            image = image[264:520,242:498]
            image[image == 65535] = 0
            image = image.astype('float32')
            image = image/100*12
            image = np.clip(image, 0, 128)
            image = image/40
            output.append(image)
        output = torch.permute(torch.tensor(np.array(output)), (1, 2, 0))
        output = self.transform(np.array(output))
        return output, start_time
#root_dir = '/users/hbi/data/RAD_NL25_RAC_MFBS_EM_5min/'
#dataset = radarDataset(root_dir, ["200808031600"], transform = Compose([ToTensor(),CenterCrop(256)]))


In [5]:
# develop dataset
from torch.cuda.amp import autocast
from torch.autograd import Variable
from torch import optim
import pandas as pd
root_dir = '/home/hbi/RAD_NL25_RAP_5min/' 

# M-L, top 5%
df_train = pd.read_csv('/users/hbi/taming-transformers/training_Delfland08-14.csv', header = None)
event_times = df_train[0].to_list()
dataset_train = radarDataset(root_dir, event_times, transform = Compose([ToTensor()]))           

df_test = pd.read_csv('/users/hbi/taming-transformers/testing_Delfland18-20.csv', header = None)
event_times = df_test[0].to_list()
dataset_test = radarDataset(root_dir, event_times, transform = Compose([ToTensor()]))

df_vali = pd.read_csv('/users/hbi/taming-transformers/validation_Delfland15-17.csv', header = None)
event_times = df_vali[0].to_list()
dataset_vali = radarDataset(root_dir, event_times, transform = Compose([ToTensor()]))

df_train_aa = pd.read_csv('/users/hbi/taming-transformers/training_Aa08-14.csv', header = None)
event_times = df_train_aa[0].to_list()
dataset_train_aa = radarDataset(root_dir, event_times, transform = Compose([ToTensor()]))    

df_train_dw = pd.read_csv('/users/hbi/taming-transformers/training_Dwar08-14.csv', header = None)
event_times = df_train_dw[0].to_list()
dataset_train_dw = radarDataset(root_dir, event_times, transform = Compose([ToTensor()]))   

df_train_re = pd.read_csv('/users/hbi/taming-transformers/training_Regge08-14.csv', header = None)
event_times = df_train_re[0].to_list()
dataset_train_re = radarDataset(root_dir, event_times, transform = Compose([ToTensor()]))    

# S-M, top 20%
df_train_full = pd.read_csv('/users/hbi/taming-transformers/training_Delfland08-14_20.csv', header = None)
event_times = df_train_full[0].to_list()
dataset_train_full = radarDataset(root_dir, event_times, transform = Compose([ToTensor()]))           

df_test_full = pd.read_csv('/users/hbi/taming-transformers/testing_Delfland18-20_20.csv', header = None)
event_times = df_test_full[0].to_list()
dataset_test_full = radarDataset(root_dir, event_times, transform = Compose([ToTensor()]))

df_vali_full = pd.read_csv('/users/hbi/taming-transformers/validation_Delfland15-17_20.csv', header = None)
event_times = df_vali_full[0].to_list()
dataset_vali_full = radarDataset(root_dir, event_times, transform = Compose([ToTensor()]))

# extreme, top 1%
df_train_ext = pd.read_csv('/users/hbi/taming-transformers/training_Delfland08-14_ext.csv', header = None)
event_times = df_train_ext[0].to_list()
dataset_train_ext = radarDataset(root_dir, event_times, transform = Compose([ToTensor()]))           

df_test_ext = pd.read_csv('/users/hbi/taming-transformers/testing_Delfland18-20_ext.csv', header = None)
event_times = df_test_ext[0].to_list()
dataset_test_ext = radarDataset(root_dir, event_times, transform = Compose([ToTensor()]))

df_vali_ext = pd.read_csv('/users/hbi/taming-transformers/validation_Delfland15-17_ext.csv', header = None)
event_times = df_vali_ext[0].to_list()
dataset_vali_ext = radarDataset(root_dir, event_times, transform = Compose([ToTensor()]))

print("Extreme:", len(dataset_train_ext), len(dataset_test_ext), len(dataset_vali_ext))
print("Full Dataset:", len(dataset_train_full), len(dataset_test_full), len(dataset_vali_full))
loaders = { 'train_de5' :DataLoader(dataset_train, batch_size=1, shuffle=False, num_workers=0),
            'test' :DataLoader(dataset_test, batch_size=1, shuffle=True, num_workers=8), 
           'valid' :DataLoader(dataset_vali, batch_size=1, shuffle=False, num_workers=0),
           'train_full' :DataLoader(dataset_train_full, batch_size=1, shuffle=False, num_workers=0),
            'test_full' :DataLoader(dataset_test_full, batch_size=1, shuffle=False, num_workers=0), 
           'valid_full' :DataLoader(dataset_vali_full, batch_size=1, shuffle=False, num_workers=0),
          'train_ext' :DataLoader(dataset_train_ext, batch_size=2, shuffle=True, num_workers=0),
            'test_ext' :DataLoader(dataset_test_ext, batch_size=2, shuffle=True, num_workers=0), 
           'valid_ext' :DataLoader(dataset_vali_ext, batch_size=2, shuffle=True, num_workers=0), 
           'train_aa5' :DataLoader(dataset_train_aa, batch_size=1, shuffle=False, num_workers=0),
           'train_dw5' :DataLoader(dataset_train_dw, batch_size=1, shuffle=False, num_workers=0),
           'train_re5' :DataLoader(dataset_train_re, batch_size=1, shuffle=False, num_workers=0)}
print(len(loaders['train_de5']),len(loaders['train_aa5']),len(loaders['train_dw5']))
print(len(loaders['test'])*2)

Extreme: 1535 823 627
Full Dataset: 32183 13352 14170
7873 7575 7557
6986


In [7]:
checkpoint = torch.load('/users/hbi/taming-transformers/vae1024_256km_epoch2', map_location="cpu")
model1024.load_state_dict(checkpoint)

<All keys matched successfully>

In [8]:
from torch import nn
from nuwa_pytorch.nuwa_pytorch import Sparse3DNA, Attention, SparseCross2DNA, cast_tuple, mult_reduce, calc_same_padding, unfoldNd, FeedForward, ShiftVideoTokens, SandwichNorm, StableLayerNorm
from nuwa_pytorch.nuwa_pytorch import AxialPositionalEmbedding, default, padding_to_multiple_of, exists, MList, einsum
from einops import rearrange
from nuwa_pytorch.vqgan_vae import stable_softmax
class SparseCross3DNA(nn.Module):
    def __init__(
        self,
        dim,
        video_shape,
        kernel_size = 3,
        dilation = 1,
        heads = 8,
        dim_head = 64,
        dropout = 0.,
        causal = False,
        query_num_frames_chunk = None,
        rel_pos_bias = False
    ):
        super().__init__()
        inner_dim = dim_head * heads
        self.heads = heads
        self.scale = dim_head ** -0.5
        self.dropout = nn.Dropout(dropout)
        self.to_q = nn.Linear(dim, inner_dim, bias = False)
        self.to_kv = nn.Linear(dim, inner_dim * 2, bias = False)
        self.talking_heads = nn.Conv2d(heads, heads, 1, bias = False)
        self.to_out = nn.Linear(inner_dim, dim)
        self.dilation = cast_tuple(dilation, size = 3)
        self.kernel_size = cast_tuple(kernel_size, size = 3)
        self.kernel_numel = mult_reduce(self.kernel_size)
       # relative positional bias per head, if needed
        self.rel_pos_bias = AxialPositionalEmbedding(heads, shape = self.kernel_size) if rel_pos_bias else None
        # calculate padding
        self.padding_frame = calc_same_padding(self.kernel_size[0], self.dilation[0])
        self.padding_height = calc_same_padding(self.kernel_size[1], self.dilation[1])
        self.padding_width = calc_same_padding(self.kernel_size[2], self.dilation[2])
        self.video_padding = (self.padding_width, self.padding_width, self.padding_height, self.padding_height, self.padding_frame, self.padding_frame)
        # save video shape and calculate max number of tokens
        self.video_shape = video_shape
        max_frames, fmap_size, _ = video_shape
        max_num_tokens = torch.empty(video_shape).numel()
        self.max_num_tokens = max_num_tokens
        # how many query tokens to process at once to limit peak memory usage, by multiple of frame tokens (fmap_size ** 2)
        self.query_num_frames_chunk = default(query_num_frames_chunk, max_frames)
        # precalculate causal mask
        indices = torch.arange(max_num_tokens)
        shaped_indices = rearrange(indices, '(f h w) -> 1 1 f h w', f = max_frames, h = fmap_size, w = fmap_size)
        padded_indices = F.pad(shaped_indices, self.video_padding, value = max_num_tokens) # padding has value of max tokens so to be masked out
        unfolded_indices = unfoldNd(padded_indices, kernel_size = self.kernel_size, dilation = self.dilation)
        unfolded_indices = rearrange(unfolded_indices, '1 k n -> n k')
        # if causal, compare query and key indices and make sure past cannot see future
        # if not causal, just mask out the padding

        if causal:
            mask = rearrange(indices, 'n -> n 1') < unfolded_indices
        else:
            mask = unfolded_indices == max_num_tokens

        #mask = F.pad(mask, (1, 0), value = False) # bos tokens never get masked out
        self.register_buffer('mask', mask)
        #print(self.mask.shape)

    def forward(self, x, context, **kwargs):
        b, n, _, h, device = *x.shape, self.heads, x.device
        n_context = context.shape[1]
        # more variables

        dilation = self.dilation
        kernel_size = self.kernel_size
        video_padding = self.video_padding
        fmap_size = self.video_shape[1]
        tokens_per_frame = fmap_size ** 2

        padding = padding_to_multiple_of(n - 1, tokens_per_frame)
        num_frames = (n + padding) // tokens_per_frame
        num_frames_input = (n_context + padding) // tokens_per_frame
        
        # derive queries / keys / values

        q, k, v = (self.to_q(x), *self.to_kv(context).chunk(2, dim = -1))
        
        bos_only = n == 1
        if bos_only:
            return self.to_out(v)
        # split out heads

        q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h = h), (q, k, v))
        
       
        # take care of bos

        #q = q[:, 1:]
        #bos_value = q[:, :1]
        
        # scale queries
        q = q * self.scale

        # reshape keys and values to video and add appropriate padding along all dimensions (frames, height, width)
        k, v = map(lambda t: rearrange(t, 'b (f h w) d -> b d f h w',  h = fmap_size, w = fmap_size), (k, v))
        k, v = map(lambda t: F.pad(t, video_padding), (k, v))
        #print(k.shape)
        # axial relative pos bias

        rel_pos_bias = None

        if exists(self.rel_pos_bias):
            rel_pos_bias = rearrange(self.rel_pos_bias(), 'j h -> h 1 j')
            rel_pos_bias = F.pad(rel_pos_bias, (1, 0), value = 0.)

        # put the attention processing code in a function
        # to allow for processing queries in chunks of frames

        out = []

        def attend(q, k, v, mask, kernel_size):
            chunk_length = q.shape[1]
            k, v = map(lambda t: unfoldNd(t, kernel_size = kernel_size, dilation = dilation), (k, v))
            k, v = map(lambda t: rearrange(t, 'b (d j) i -> b i j d', j = self.kernel_numel), (k, v))
            k, v = map(lambda t: t[:, :chunk_length], (k, v))

            # calculate sim

            sim = einsum('b i d, b i j d -> b i j', q, k)
            #print(q.shape, k.shape, sim.shape)
            # add rel pos bias, if needed

            if exists(rel_pos_bias):
                sim = sim + rel_pos_bias

            # causal mask

            
            if exists(mask):
                mask_value = -torch.finfo(sim.dtype).max               
                mask = rearrange(mask, 'i j -> 1 i j')
                sim = sim.masked_fill(mask, mask_value)
            
            # attention

            attn = stable_softmax(sim, dim = -1)

            attn = rearrange(attn, '(b h) ... -> b h ...', h = h)
            attn = self.talking_heads(attn)
            attn = rearrange(attn, 'b h ... -> (b h) ...')

            attn = self.dropout(attn)

            # aggregate values

            return einsum('b i j, b i j d -> b i d', attn, v)

        # process queries in chunks

        frames_per_chunk = num_frames_input
        chunk_size = frames_per_chunk * tokens_per_frame
        q_chunks = q.split(chunk_size, dim = 1)
        mask = self.mask
        #print(q.shape, chunk_size)
        for ind, q_chunk in enumerate(q_chunks):
            #print(ind, q_chunks[ind].shape, mask.shape)
            q_chunk = q_chunks[ind]
            length = q_chunk.shape[1]
            q_chunk = F.pad(input=q_chunk, pad=(0, 0, 0, chunk_size-length), mode='constant', value=float("-Inf"))
            mask_chunk = mask

            # slice the keys and values to the appropriate frames, accounting for padding along frames dimension

            k_slice, v_slice = k,v
            # calculate output chunk
            out_chunk = attend(
                q = q_chunk,
                k = k_slice,
                v = v_slice,
                mask = mask_chunk,
                kernel_size = kernel_size,
            )

            out_chunk = out_chunk[:,:length,:]
            out.append(out_chunk)
            
        # combine all chunks
        out = torch.cat(out, dim = 1)
        # append bos value

        #out = torch.cat((bos_value, out), dim = 1)  # bos will always adopt its own value, since it pays attention only to itself
        # merge heads

        out = rearrange(out, '(b h) n d -> b n (h d)', h = h)
        return self.to_out(out)
    
class Transformer(nn.Module):
    def __init__(
        self,
        *,
        dim,
        depth,
        causal = False,
        heads = 8,
        dim_head = 64,
        ff_mult = 4,
        cross_attend = False,
        attn_dropout = 0.,
        ff_dropout = 0.,
        ff_chunk_size = None,
        cross_2dna_attn = False,
        cross_2dna_image_size = None,
        cross_2dna_kernel_size = 3,
        cross_2dna_dilations = (1,),
        cross_3dna_attn = False,
        cross_3dna_image_size = None,
        cross_3dna_kernel_size = 3,
        cross_3dna_dilations = (1,),
        sparse_3dna_attn = False,
        sparse_3dna_kernel_size = 3,
        sparse_3dna_video_shape = None,
        sparse_3dna_query_num_frames_chunk = None,
        sparse_3dna_dilations = (1,),
        sparse_3dna_rel_pos_bias = False,
        shift_video_tokens = False,
        rotary_pos_emb = False
    ):
        super().__init__()

        self.layers = MList([])

        for ind in range(depth):
            if sparse_3dna_attn:
                dilation = sparse_3dna_dilations[ind % len(sparse_3dna_dilations)]

                self_attn = Sparse3DNA(
                    dim = dim,
                    heads = heads,
                    dim_head = dim_head,
                    causal = causal,
                    kernel_size = sparse_3dna_kernel_size,
                    dilation = dilation,
                    video_shape = sparse_3dna_video_shape,
                    query_num_frames_chunk = sparse_3dna_query_num_frames_chunk,
                    rel_pos_bias = sparse_3dna_rel_pos_bias,
                )
            else:
                self_attn = Attention(
                    dim = dim,
                    heads = heads,
                    dim_head = dim_head,
                    causal = causal,
                    dropout = attn_dropout
                )

            cross_attn = None
            
            if cross_attend:
                if cross_2dna_attn:
                    dilation = cross_2dna_dilations[ind % len(cross_2dna_dilations)]

                    cross_attn = SparseCross2DNA(
                        dim = dim,
                        heads = heads,
                        dim_head = dim_head,
                        dropout = attn_dropout,
                        image_size = cross_2dna_image_size,
                        kernel_size = cross_2dna_kernel_size,
                        dilation = dilation
                    )
                
                elif cross_3dna_attn:
                    
                    cross_attn = SparseCross3DNA(
                        dim = dim,
                        heads = heads,
                        dim_head = dim_head,
                        dropout = attn_dropout,
                        video_shape = cross_3dna_image_size,
                        kernel_size = cross_3dna_kernel_size,
                    )
                
                else:
                    cross_attn = Attention(
                        dim = dim,
                        heads = heads,
                        dim_head = dim_head,
                        dropout = attn_dropout
                    )

            ff = FeedForward(dim = dim, mult = ff_mult, dropout = ff_dropout, chunk_size = ff_chunk_size)

            if sparse_3dna_attn and shift_video_tokens:
                fmap_size = sparse_3dna_video_shape[-1]
                self_attn = ShiftVideoTokens(self_attn, image_size = fmap_size)
                ff        = ShiftVideoTokens(ff, image_size = fmap_size)

            self.layers.append(MList([
                SandwichNorm(dim = dim, fn = self_attn),
                SandwichNorm(dim = dim, fn = cross_attn) if cross_attend else None,
                SandwichNorm(dim = dim, fn = ff)
            ]))

        self.norm = StableLayerNorm(dim)

    def forward(
        self,
        x,
        mask = None,
        context = None,
        context_mask = None
    ):
        for attn, cross_attn, ff in self.layers:
            x = attn(x, mask = mask) + x

            if exists(cross_attn):
                x = cross_attn(x, context = context, mask = mask, context_mask = context_mask) + x

            x = ff(x) + x

        return self.norm(x)

In [9]:
# transformer for second stage
# transformer for second stage
from torch import nn
import nuwa_pytorch.nuwa_pytorch
from nuwa_pytorch.nuwa_pytorch import MList, Embedding, AxialPositionalEmbedding, ReversibleTransformer, default, eval_decorator, prob_mask_like, exists
from nuwa_pytorch.nuwa_pytorch import padding_to_multiple_of, einsum
from nuwa_pytorch.vqgan_vae import stable_softmax
from einops import rearrange, repeat
import torch.nn.functional as F
from main import instantiate_from_config
#from dice_loss import DiceLoss
class nuwa_v2v(nn.Module):
    def __init__(
        self,
        *,
        dim,
        vae = None,
        image_size = 256,
        codebook_size = 1024,
        compression_rate = 16,
        video_out_seq_len = 5,
        video_in_seq_len = 5,
        video_dec_depth = 6,
        video_dec_dim_head = 64,
        video_dec_heads = 8,
        attn_dropout = 0.,
        ff_dropout = 0.,
        ff_chunk_size = None,
        embed_gradient_frac = 0.2,
        shift_video_tokens = True,
        sparse_3dna_kernel_size = 3,
        sparse_3dna_query_num_frames_chunk = None,
        sparse_3dna_dilation = 1,
    ):
        super().__init__()
        
        # VAE
        #self.first_stage_model = vae.copy_for_eval()
        self.image_size = image_size
        num_image_tokens = codebook_size
        fmap_size = image_size // compression_rate
        self.video_fmap_size = fmap_size
        self.video_frames_in = video_in_seq_len
        self.video_frames_out = video_out_seq_len
        self.image_embedding = Embedding(num_image_tokens, dim, frac_gradient=embed_gradient_frac)
        # cycle dilation for sparse 3d-nearby attention 
        sparse_3dna_dilations = tuple(range(1, sparse_3dna_dilation + 1)) if not isinstance(sparse_3dna_dilation, (list, tuple)) else sparse_3dna_dilation  # ???
        
        video_shape_in = (video_in_seq_len, fmap_size, fmap_size)
        video_shape_out = (video_out_seq_len, fmap_size, fmap_size)
        self.video_in_pos_emb = AxialPositionalEmbedding(dim, shape = video_shape_in)
        self.video_out_pos_emb = AxialPositionalEmbedding(dim, shape = video_shape_out)
        self.video_bos = nn.Parameter(torch.randn(dim))
        
        # 3DNA Encoder
        video_shape_input = (video_in_seq_len, fmap_size, fmap_size)
    
        # 3DNA Decoder
        #self.video_bos = nn.Parameter(torch.randn(dim))#
        #self.image_embedding = Embedding(num_image_tokens, dim, frac_gradient = embed_gradient_frac)#???
        # self.video_pos_emb = AxialPositionalEmbedding(dim, shape = video_shape)#???

        video_shape_output = (video_out_seq_len, fmap_size, fmap_size)
        self.video_decode_transformer = Transformer(
            dim = dim,
            depth = video_dec_depth,
            heads = video_dec_heads,
            dim_head = video_dec_dim_head,
            causal = True,
            cross_attend = True,
            attn_dropout = attn_dropout,
            ff_dropout = ff_dropout,
            ff_chunk_size = ff_chunk_size,
            cross_3dna_attn = True,
            cross_3dna_image_size = video_shape_input ,
            cross_3dna_kernel_size = (5,3,3),
            shift_video_tokens = shift_video_tokens,
            sparse_3dna_video_shape = video_shape_output,
            sparse_3dna_attn = True,
            sparse_3dna_kernel_size = sparse_3dna_kernel_size,
            sparse_3dna_dilations = sparse_3dna_dilations,
            sparse_3dna_query_num_frames_chunk = sparse_3dna_query_num_frames_chunk
        )

        self.to_logits = nn.Linear(dim, num_image_tokens)
    
    @torch.no_grad()
    @eval_decorator
    def generate(
        self,
        *,
        video_input_latent,
        video_output_latent = None,
        filter_thres = 0.9,
        temperature = 1.,
        num_frames = 6,
        prior_weight = 1,
        alpha = 1
    ):
        #video_indice_last_frame = video_input_latent[0,:,512:768]
        device = video_input_latent.device
        batch = video_input_latent.shape[0]
        
        frame_indices_input = video_input_latent.squeeze(1)
        frame_embeddings_input = self.image_embedding(frame_indices_input)
        frame_embeddings_input = self.video_in_pos_emb().repeat(batch,1,1) + frame_embeddings_input
        
        bos = repeat(self.video_bos, 'd -> b 1 d', b = batch)
        video_indices = torch.empty((batch, 0), device = device, dtype = torch.long)
        num_tokens_per_frame = self.video_fmap_size ** 2
        
        num_frames = default(num_frames, self.video_frames_out)
        total_video_tokens =  num_tokens_per_frame * num_frames
        
        pos_emb = self.video_out_pos_emb()
        
        prior = frame_indices_input[:, 512:].repeat(1,num_frames)
        
        for ind in range(total_video_tokens):
            print(ind, '/', total_video_tokens, end = '\r')
            
            #if video_output_latent != None and ind<256:
            #    #print(video_indice_last_frame.shape, video_indice_last_frame[:,ind])
            #    video_indices = torch.cat((video_indices, video_output_latent[:,:,ind]), dim = 1)
            #    continue
            
            video_indices_input = video_indices
            num_video_tokens = video_indices.shape[1]
            #video_indices_input = video_output_latent[:,: num_video_tokens]
            
            frame_embeddings = self.image_embedding(video_indices_input)
            frame_embeddings = pos_emb[:frame_embeddings.shape[1]] + frame_embeddings
            frame_embeddings = torch.cat((bos, frame_embeddings), dim = 1)
            
            frame_embeddings = self.video_decode_transformer(
                frame_embeddings,
                context = frame_embeddings_input
            )
        
            logits = self.to_logits(frame_embeddings)
            logits = logits[:, -1, :]
            logits[:,int(prior[:,ind])] *= prior_weight
            
            #w = torch.load('weight_post.pt').to(DEVICE) 
            #logits[logits<0] = 0
            #filtered_logits = top_k(logits, thres = filter_thres)
            #filtered_logits = filtered_logits*(w.unsqueeze(0))
            
            filtered_logits = top_k(logits, thres = filter_thres)
            probs = F.softmax(filtered_logits, dim=-1)
            sample = torch.multinomial(probs, 1).squeeze(-1)
            
            #sample = gumbel_sample(filtered_logits, temperature = temperature, dim = -1)
            sample = rearrange(sample, 'b -> b 1')
            video_indices = torch.cat((video_indices, sample), dim = 1)
            
            if (ind+1) % num_tokens_per_frame == 0:
                n =  int((ind+1) / num_tokens_per_frame)
                prior = video_indices[:, (n-1)*int(num_tokens_per_frame):].repeat(1,num_frames)
                prior_weight = max(prior_weight*alpha, 1)
        
        return video_indices

    def forward(
        self,
        *,
        #video_input, # raw observation video here
        #video_output, # raw prediction video here
        video_input_latent,
        video_output_latent,
        return_loss = False,
        cond_dropout_prob = 0
    ):
        device = video_input_latent.device
        batch = video_input_latent.shape[0]
        seq_out_len = self.video_frames_out
        seq_in_len = self.video_frames_in
        
        #get indices
        frame_indices_input = video_input_latent.squeeze(1)
        frame_indices_output = video_output_latent.squeeze(1)
        frame_indices_output = frame_indices_output[:, :-1] if return_loss else frame_indices_output
        
        #indices to embedding
        frame_embeddings_input = self.image_embedding(frame_indices_input)
        frame_embeddings_prediction = self.image_embedding(frame_indices_output)

        #position encoding
        frame_embeddings_input = self.video_in_pos_emb().repeat(batch,1,1) + frame_embeddings_input
        if return_loss:
            frame_embeddings_prediction = self.video_out_pos_emb()[:-1].repeat(batch,1,1) + frame_embeddings_prediction
            # shift right
            bos = repeat(self.video_bos, 'd -> b 1 d', b = batch)
            frame_embeddings_prediction = torch.cat((bos, frame_embeddings_prediction), dim = 1)
        else:
            frame_embeddings_prediction = self.video_out_pos_emb().repeat(batch,1,1) + frame_embeddings_prediction
        
        #transformer 
        frame_embeddings_prediction = self.video_decode_transformer(
            frame_embeddings_prediction,
            context = frame_embeddings_input
        )
        #print(frame_embeddings_prediction.shape)
        logits = self.to_logits(frame_embeddings_prediction)
        
        #print(t2-t1, t3-t2, t4-t3, t5-t4)
        if not return_loss:
            return logits
        #weight = torch.load('weight_clip_10.pt').to(device)
        loss = F.cross_entropy(rearrange(logits, 'b n c -> b c n'), video_output_latent.squeeze(1))
        
        #dice_loss = DiceLoss(smooth = 1, with_logits = True, ohem_ratio = 0, 
        #                     alpha = 0.01, reduction = "mean",  index_label_position=True,  square_denominator = True)
        #print(rearrange(logits, 'b n c -> (b n) c').shape, rearrange(video_output_latent.squeeze(1), 'b n -> (b n)').shape)
        #loss = dice_loss(rearrange(logits, 'b n c -> (b n) c'), rearrange(video_output_latent.squeeze(1), 'b n -> (b n)'))
        
        return loss, logits

In [10]:
from torchsummary import summary
#torch.cuda.empty_cache()
nuwa = nuwa_v2v(dim=512 , vae=None,
                image_size = 256,
                codebook_size = 1024,
                compression_rate = 16,
                video_out_seq_len = 6,
                video_in_seq_len = 3,
                video_dec_depth = 12,
                video_dec_dim_head = 64,
                video_dec_heads = 4,
                sparse_3dna_kernel_size = (7,3,3),
                attn_dropout = 0,
                ff_dropout = 0).to(DEVICE)

for layer in nuwa.video_decode_transformer.layers[:12]:
    for param in layer.parameters():
        param.requires_grad = True


In [11]:
from pysteps.verification.detcatscores import det_cat_fct
from pysteps.verification.detcontscores import det_cont_fct
from pysteps.verification.spatialscores import intensity_scale
from pysteps.visualization import plot_precip_field
from nuwa_pytorch.nuwa_pytorch import top_k, gumbel_sample
import random

device = DEVICE
def decode_to_img(index, num_frame = 6):
        model1024.to(device)
        b = index.shape[0]
        #index = self.permuter(index, reverse=True)
        for t in range(num_frame):
            index_frame = index[:,t*nuwa.video_fmap_size*nuwa.video_fmap_size:(t+1)*nuwa.video_fmap_size*nuwa.video_fmap_size]
            bhwc = (b,nuwa.video_fmap_size,nuwa.video_fmap_size,3)
            quant_z = model1024.quantize.get_codebook_entry(
                index_frame.reshape(-1), shape=bhwc)
            quant_z = rearrange(quant_z, '(b h w) d -> b d h w',b = b, h = nuwa.video_fmap_size, w = nuwa.video_fmap_size)
            #print(quant_z.shape)
            if t==0:
                x = model1024.decode(quant_z)
                x = x.unsqueeze(1)
            else:
                x = torch.cat((x,model1024.decode(quant_z).unsqueeze(1)),1)
        return x

Pysteps configuration file found at: /users/junzheyin/anaconda3/envs/myenv/lib/python3.8/site-packages/pysteps/pystepsrc



In [13]:
nuwa.eval()
checkpoint = torch.load( '/bulk/junzheyin/nuwa_newvae_sparse7_evl_75_epoch12',map_location="cpu")
nuwa.load_state_dict(checkpoint, strict=True)

<All keys matched successfully>

In [18]:
import torch
from nuwa_pytorch import VQGanVAE
from nuwa_pytorch.optimizer import get_optimizer
vae = VQGanVAE(
    dim = 256,
    channels = 1,               # default is 3, but can be changed to any value for the training of the segmentation masks (sketches)
    image_size = 256,           # image size
    num_layers = 4,             # number of downsampling layers
    num_resnet_blocks = 2,      # number of resnet blocks
    vq_codebook_size = 1024,    # codebook size
    vq_decay = 0.8 ,             # codebook exponential decay
    use_hinge_loss = True,
    use_vgg_and_gan = True
).to(DEVICE)
checkpoint = torch.load('/bulk/junzheyin/aaa/vae_epoch80', map_location = 'cpu')
vae.load_state_dict(checkpoint, strict=False)

<All keys matched successfully>

In [19]:
#torch.cuda.empty_cache()
from collections import Counter
import time
from pysteps.verification.detcatscores import det_cat_fct
from pysteps.verification.detcontscores import det_cont_fct
from pysteps.verification.spatialscores import intensity_scale
from pysteps.visualization import plot_precip_field
from tqdm import tqdm
pcc_average = 0
counter = Counter()
counter2 = Counter()
import random
index = 0
time_list = ['201808101910','202002091900','201906120945','201801022335','201808101805']
time_list = ['201808101910']
for k in range(1):
    for i, (images,time) in enumerate(loaders['test']):
        #print("{}/{}".format(i, len(loaders['test'])), end='\r')
        #if i<index:continue
        #if i>=index+80:break
        if time[0] not in time_list:continue
        image1 = images[0].unsqueeze(1)
        #a1 = Variable(image1.repeat(1,3,1,1)).to(DEVICE)   # batch x
        #a_r1 = reconstruct_with_vqgan(a1, model1024.to(DEVICE))
        #quant1, emb_loss1, info1, h1 = model1024.encode(a1)
        a1 = Variable(image1.repeat(1, 3, 1, 1)).to(DEVICE)
        quant1, emb_loss1, info1 = model1024.encode(a1)
        #indice1 = info1[2].view(-1, 16, 16)
        indice_obs = info1[2][:768].unsqueeze(0)
        indice_pre = info1[2][768:].unsqueeze(0)
        #print(indice_obs.shape, indice_pre.shape)
        #print(video_predict1.shape)
        # Generate
        n = 6
        video_generate1 = nuwa.generate(video_input_latent = indice_obs, 
                                        video_output_latent = indice_pre,
                                        filter_thres = 0.98,
                                        temperature = 0.5,
                                        num_frames = n,
                                        prior_weight = 2,
                                        alpha = 0.8)
        #video_generate1 = nuwa(video_input_latent= indice_obs,video_output_latent= indice_pre)
        #video_generate1 = torch.argmax(video_generate1, dim = 2)
        video_predict1 = decode_to_img(video_generate1, num_frame = n).squeeze(0)
        video_predict1 = vae.codebook_indices_to_video(video_generate1).squeeze(0)
        for t in range(n):
            a1_display = a1[t+3,0,:,:].to('cpu').detach().numpy()*40
            #a_r1_display = a_r1[t+3,0,:,:].to('cpu').detach().numpy()*40
            a_p1_display = video_predict1[t,0,:,:].to('cpu').detach().numpy()*40
            
            #np.save('prediction_vqgan2_{} min'.format(-90+(t+4)*30),a_p1_display)
            scores_cont = det_cont_fct(a_p1_display, a1_display, thr=0.1)
            if True:
                print("Start Time:", time[0], "Lead Time: {} mins".format(-90+(t+4)*30), "PCC:", np.around(scores_cont['corr_p'],3))
                plt.figure(figsize=(12, 4))
                plt.subplot(131)
                plot_precip_field(a1_display, title="Original")
                plt.subplot(132)
                #plot_precip_field(a_r1_display, title="Reconstruction")
                #plt.subplot(133)
                plot_precip_field(a_p1_display, title="Prediction")
                plt.tight_layout()
                plt.show()

        

1535 / 1536

RuntimeError: shape '[1, 16, 16, 3]' is invalid for input of size 65536