In [1]:
import cv2
import sys
import torch
from nowcasting.config import cfg
from nowcasting.models.forecaster import Forecaster
from nowcasting.models.encoder import Encoder
from nowcasting.models.model import EF
from torch.optim import lr_scheduler
from nowcasting.models.loss import Weighted_mse_mae
import os, shutil
from experiments.net_params import convlstm_encoder_params, convlstm_forecaster_params
from nowcasting.models.trajGRU import TrajGRU
from experiments.net_params import encoder_params, forecaster_params
import torchvision
import numpy as np
from tqdm import tqdm
from torch.utils.tensorboard import SummaryWriter
import matplotlib.pyplot as plt
import pickle

In [2]:
# Load the normalised NEXRAD frames. Using the NpzFile as a context manager
# closes the underlying archive once the arrays have been read out; a bare
# np.load on an .npz keeps the file handle open until garbage collection.
with np.load('./rainy-nexrad-normed.npz') as data:
    x_data = data['x_data']
    x_mask = data['x_mask']
    # Bounds used later to undo the normalisation: value * (x_max - x_min) + x_min.
    x_max = data['x_max']
    x_min = data['x_min']
# Masked view over the same buffer; presumably True in x_mask marks invalid pixels -- TODO confirm.
x = np.ma.MaskedArray(x_data, x_mask)

In [3]:
# Dataset dimensions; axis 0 indexes frames (it is used as the frame count when building windows).
np.shape(x)

(4494, 480, 480)

In [4]:
# Build the ConvLSTM encoder-forecaster and restore its trained weights.
encoder = Encoder(convlstm_encoder_params[0], convlstm_encoder_params[1]).to(cfg.GLOBAL.DEVICE)
forecaster = Forecaster(convlstm_forecaster_params[0], convlstm_forecaster_params[1]).to(cfg.GLOBAL.DEVICE)
# This notebook only runs inference. torch.no_grad() (used in later cells) disables
# autograd but does not switch train-mode-dependent layers (dropout, batch-norm, if
# any exist) to evaluation behaviour, so put the model in eval mode explicitly.
conv = EF(encoder, forecaster).to(cfg.GLOBAL.DEVICE).eval()
# map_location lets a checkpoint saved on a different device load onto cfg.GLOBAL.DEVICE.
conv.load_state_dict(torch.load('../model_test/f/models/conv_0_20.pth', map_location=cfg.GLOBAL.DEVICE))


<All keys matched successfully>

In [5]:
# Per-frame metadata stored next to the frames:
#   files     -- its length is used as the final segment boundary
#   dts       -- per-frame timestamps (their differences are timedelta-like)
#   lost_mark -- extra segment-boundary indices merged into `mark`
# NOTE(review): pickle.load can execute arbitrary code; only load these files
# if they were produced locally / come from a trusted source.
loaded = []
for path in ("files.pkl", "dts.pkl", "lost_mark.pkl"):
    with open(path, "rb") as f:
        loaded.append(pickle.load(f))
files, dts, lost_mark = loaded

In [6]:
# Frames per sample: 6 model-input frames followed by 6 target frames.
window_size = 2 * 6

In [7]:
# Valid window start indices: a window of `window_size` consecutive frames is
# kept only if it does not straddle a segment boundary.
max_gap_minutes = 15  # a gap larger than this between consecutive timestamps starts a new segment

# Whole minutes between consecutive timestamps. total_seconds() is used because
# timedelta.seconds ignores the days component: a gap of 1 day + 5 min would
# read as 5 min and not be treated as a break. A plain list comprehension also
# stays valid with fewer than two timestamps, where np.vectorize raises.
time_delta = np.array(
    [int((later - earlier).total_seconds() // 60) for earlier, later in zip(dts[:-1], dts[1:])],
    dtype=np.int64,
)

# Segment boundaries (index of the first frame of a new segment): frames after a
# large gap, the extra indices from lost_mark, and the end of the data.
# np.union1d returns them sorted and de-duplicated.
mark = np.union1d(
    np.flatnonzero(time_delta > max_gap_minutes) + 1,
    np.append(np.asarray(lost_mark, dtype=np.int64).reshape(-1), len(files)),
)

# Every start index is a candidate, except the window_size - 1 starts just before
# each boundary m (m - window_size + 1 .. m - 1), whose windows would cross m.
# Plain int dtype: np.int was removed in NumPy 1.24.
sliding_idx = np.arange(x.shape[0] - window_size + 1)
remove_idx = (mark[:, None] + np.arange(1 - window_size, 0)).reshape(-1)
use_idx = np.setdiff1d(sliding_idx, remove_idx)

In [8]:
# Number of usable window start indices (use_idx is a 1-D array).
use_idx.size

4404

In [9]:
# Move a slice of frames to the device and build overlapping windows with a
# zero-copy sliding view (unfold along the frame axis, step 1).
# NOTE(review): use_idx holds valid window *start* indices. Gathering frames at
# those indices and then unfolding assumes use_idx[4000:4308] is a contiguous run
# of frame numbers; if it skips a removed stretch, some windows cross a gap.
frames = x_data[use_idx[4000:4308]].astype(np.float32)
dataset = torch.from_numpy(frames).to(cfg.GLOBAL.DEVICE)
windows = dataset.unfold(0, window_size, 1)          # (n_windows, H, W, window_size)
x_torch = windows.permute(3, 0, 1, 2).unsqueeze(2)   # (window_size, n_windows, 1, H, W)

In [10]:
# (window_size, n_windows, 1, H, W)
x_torch.shape

torch.Size([12, 297, 1, 480, 480])

In [11]:
# Split every window along its time axis (axis 0): the first 6 frames are the
# model input and the last 6 frames are the target.
data, label = x_torch[:6], x_torch[-6:]

In [12]:
# Three chained forecasts: each pass feeds the model's previous output back in
# as the next input sequence. Batches are gathered on the CPU and concatenated
# once at the end; calling np.concatenate inside the loop re-copied the whole
# growing array on every batch.
# NOTE(review): the single-pass cell below also assigns `image_conv`; the video
# and metric cells use whichever of the two ran last.
bs = 8
n_passes = 3
n_windows = x_torch.size(1)
chunks = []
with torch.no_grad():
    for start in range(0, n_windows, bs):
        output_conv = data[:6, start:start + bs]  # slicing clamps at the last window
        for _ in range(n_passes):
            output_conv = conv(output_conv)
        chunks.append(output_conv.cpu().numpy())
image_conv = np.concatenate(chunks, axis=1)
# Undo the normalisation back to the original value range.
image_conv = image_conv * (x_max - x_min) + x_min
print(image_conv.shape)

(6, 285, 1, 480, 480)


In [12]:
# Single forward pass: the 6 input frames of every window are mapped to the
# model's prediction batch by batch on the device, then gathered on the CPU.
# Batches are concatenated once at the end; calling np.concatenate inside the
# loop re-copied the whole growing array on every batch.
# NOTE(review): the three-pass cell above also assigns `image_conv`; the video
# and metric cells use whichever of the two ran last.
bs = 8
n_windows = x_torch.size(1)
chunks = []
with torch.no_grad():
    for start in range(0, n_windows, bs):
        # slicing clamps at the last window, so no explicit min() is needed
        chunks.append(conv(data[:6, start:start + bs]).cpu().numpy())
image_conv = np.concatenate(chunks, axis=1)
# Undo the normalisation back to the original value range.
image_conv = image_conv * (x_max - x_min) + x_min
print(image_conv.shape)

(6, 297, 1, 480, 480)


In [13]:
# Ground-truth target frames on the CPU, mapped back from the normalised range.
image_label = x_min + (x_max - x_min) * label.cpu().numpy()
print(image_label.shape)

(6, 297, 1, 480, 480)


In [14]:
# Width and height (pixels) of one frame panel in the output images/videos.
w, h = 480, 480

In [16]:
# Write the last target frame of every window as an 8-bit grayscale PNG, scaled
# with the same global range as the comparison videos.
# NOTE(review): frames come from image_label (the targets of the input/target
# split); confirm they really correspond to the 3 h horizon the folder name suggests.
os.makedirs('./3hr_truth', exist_ok=True)  # os.mkdir raised FileExistsError on re-run
mn = min(image_conv.min(), image_label.min())
mx = max(image_conv.max(), image_label.max())
for i in range(image_label.shape[1]):
    # Dedicated names: the original loop rebound the module-level `x` (the masked
    # dataset), silently breaking later cells that read it.
    frame = image_label[-1, i, 0]
    gray = np.array(((frame - mn) / (mx - mn)) * 255, dtype=np.uint8)
    rgb = cv2.cvtColor(gray, cv2.COLOR_GRAY2RGB)
    png_path = './3hr_truth/3hr_truth_' + str(i).zfill(3) + '.png'
    # cv2.imwrite reports failure via its return value instead of raising.
    if not cv2.imwrite(png_path, rgb):
        raise IOError('failed to write ' + png_path)

In [15]:
# Side-by-side video, one frame per window at the last lead time:
# prediction (left) | ground truth (right), with 3-pixel borders.
pad = 3
frame_size = (w * 2 + pad * 3, h + pad * 2)  # (width, height), as cv2.VideoWriter expects
if image_conv.shape[1] != image_label.shape[1]:
    raise ValueError('prediction/label window counts differ: {} vs {}'.format(
        image_conv.shape[1], image_label.shape[1]))
mn = min(image_conv.min(), image_label.min())
mx = max(image_conv.max(), image_label.max())
# Loop-invariant borders: vertical strips at the data minimum (black after
# scaling) and black horizontal strips above and below.
left = np.ones((h, pad)) * mn
top = np.zeros((pad, frame_size[0], 3), dtype=np.uint8)
out = cv2.VideoWriter('compare_1hr.mp4', cv2.VideoWriter_fourcc(*'MP4V'), 10, frame_size)
if not out.isOpened():
    raise RuntimeError('could not open compare_1hr.mp4 for writing')
try:
    for i in range(image_label.shape[1]):
        f = np.concatenate([left, image_conv[-1, i, 0], left, image_label[-1, i, 0], left], axis=1)
        f = cv2.cvtColor(np.array((f - mn) / (mx - mn) * 255, dtype=np.uint8), cv2.COLOR_GRAY2RGB)
        f = np.concatenate([top, f, top], axis=0)
        out.write(f)
finally:
    # Always finalise the container, even if a frame fails.
    out.release()

In [17]:
# Side-by-side video, one frame per window at the last lead time:
# prediction (left) | ground truth (right), with 3-pixel borders.
# NOTE(review): intended for `image_conv` from the three-pass cell. Its recorded
# window count (285) differs from image_label's (297); the check below reports
# that up front instead of failing with IndexError mid-write. Also confirm that
# image_label is the right target for a three-pass forecast.
pad = 3
frame_size = (w * 2 + pad * 3, h + pad * 2)  # (width, height), as cv2.VideoWriter expects
if image_conv.shape[1] != image_label.shape[1]:
    raise ValueError('prediction/label window counts differ: {} vs {}'.format(
        image_conv.shape[1], image_label.shape[1]))
mn = min(image_conv.min(), image_label.min())
mx = max(image_conv.max(), image_label.max())
# Loop-invariant borders: vertical strips at the data minimum (black after
# scaling) and black horizontal strips above and below.
left = np.ones((h, pad)) * mn
top = np.zeros((pad, frame_size[0], 3), dtype=np.uint8)
out = cv2.VideoWriter('compare_3hr.mp4', cv2.VideoWriter_fourcc(*'MP4V'), 10, frame_size)
if not out.isOpened():
    raise RuntimeError('could not open compare_3hr.mp4 for writing')
try:
    for i in range(image_label.shape[1]):
        f = np.concatenate([left, image_conv[-1, i, 0], left, image_label[-1, i, 0], left], axis=1)
        f = cv2.cvtColor(np.array((f - mn) / (mx - mn) * 255, dtype=np.uint8), cv2.COLOR_GRAY2RGB)
        f = np.concatenate([top, f, top], axis=0)
        out.write(f)
finally:
    # Always finalise the container, even if a frame fails.
    out.release()

In [12]:
# Four-panel video per index i >= 1: ground truth | persistence (truth at i - 1)
# | ConvLSTM | TrajGRU, separated by 50-px white borders, then inverted (255 - frame).
# NOTE(review): this cell looks stale relative to the rest of the notebook:
#   * `image_traj` is never defined here (no TrajGRU model is built), so it
#     raises NameError on a fresh kernel;
#   * image_label[i] / image_conv[i] are 4-D (windows, 1, H, W) and de-normalised
#     at this point, while the code treats each as one 2-D frame in [0, 1]
#     (`* 255`, `.T`, cv2.resize) -- confirm the intended input.
# Local names no longer rebind the module-level `label` tensor.
out = cv2.VideoWriter('all_model.mp4', cv2.VideoWriter_fourcc(*'MP4V'), 10, (w*4+50*5, h))
if not out.isOpened():
    raise RuntimeError('could not open all_model.mp4 for writing')
border = np.ones((h, 50, 3), dtype=np.uint8) * 255
try:
    for i in range(1, image_label.shape[0]):
        panels = []
        for source in (image_label[i], image_label[i - 1], image_conv[i], image_traj[i]):
            gray = cv2.resize(np.array(source * 255, dtype=np.uint8).T, dsize=(w, h),
                              interpolation=cv2.INTER_CUBIC)
            panels += [border, cv2.cvtColor(gray, cv2.COLOR_GRAY2RGB)]
        panels.append(border)
        out.write(255 - np.concatenate(panels, axis=1))
finally:
    # Always finalise the container, even if a frame fails.
    out.release()

In [16]:
def rmse(x, y):
    """Root-mean-square error between two arrays of the same (broadcastable) shape.

    The mean is taken over every element.
    """
    squared_error = np.square(x - y)
    return np.sqrt(squared_error.mean())

In [27]:
# Rain / no-rain cut-off applied to reference values in the metric helpers
# (used on de-normalised values and on dtr()-converted rates alike).
thres = 1 / 2

In [28]:
def rmse_rain(x, y, threshold=None):
    """RMSE restricted to pixels where the reference `y` exceeds the threshold.

    Parameters
    ----------
    x, y : array-like
        Prediction and reference with the same number of elements; both are flattened.
    threshold : float, optional
        Cut-off applied to `y`. Defaults to the notebook-level ``thres`` so
        existing two-argument calls behave as before.

    Returns
    -------
    float
        RMSE over pixels with ``y > threshold``; NaN (without a RuntimeWarning)
        when no pixel qualifies.
    """
    if threshold is None:
        threshold = thres
    x_flat = np.reshape(x, -1)
    y_flat = np.reshape(y, -1)
    rainy = y_flat > threshold
    if not rainy.any():
        return np.nan
    return np.sqrt(np.mean(np.square(x_flat[rainy] - y_flat[rainy])))

In [29]:
def rmse_not_rain(x, y, threshold=None):
    """RMSE restricted to pixels where the reference `y` is at or below the threshold.

    The comparison is ``<=`` (not ``<``) so that, together with the ``y > thres``
    selection in rmse_rain, every pixel falls into exactly one group; previously
    pixels equal to the threshold were dropped from both.

    Parameters
    ----------
    x, y : array-like
        Prediction and reference with the same number of elements; both are flattened.
    threshold : float, optional
        Cut-off applied to `y`. Defaults to the notebook-level ``thres``.

    Returns
    -------
    float
        RMSE over pixels with ``y <= threshold``; NaN (without a RuntimeWarning)
        when no pixel qualifies.
    """
    if threshold is None:
        threshold = thres
    x_flat = np.reshape(x, -1)
    y_flat = np.reshape(y, -1)
    dry = y_flat <= threshold
    if not dry.any():
        return np.nan
    return np.sqrt(np.mean(np.square(x_flat[dry] - y_flat[dry])))

In [30]:
def csi(x, y, threshold=None):
    """Critical Success Index of prediction `x` against reference `y`.

    A pixel counts as an event when its value is strictly greater than the
    threshold. With hits = event in both, misses = event only in `y`, and
    false alarms = event only in `x`::

        CSI = hits / (hits + misses + false_alarms)

    Parameters
    ----------
    x, y : array-like
        Prediction and reference with the same number of elements; both are flattened.
    threshold : float, optional
        Event threshold. Defaults to the notebook-level ``thres`` so existing
        two-argument calls behave as before.

    Returns
    -------
    float
        CSI in [0, 1], or NaN when neither array contains an event (0/0).
    """
    if threshold is None:
        threshold = thres
    # Boolean masks replace the old int-encoding trick, which used np.int
    # (removed in NumPy 1.24).
    predicted = np.reshape(x, -1) > threshold
    observed = np.reshape(y, -1) > threshold
    hits = np.sum(predicted & observed)
    misses = np.sum(~predicted & observed)
    false_alarms = np.sum(predicted & ~observed)
    denominator = hits + misses + false_alarms
    if denominator == 0:
        return np.nan
    return hits / denominator

In [21]:
# RMSE over every pixel of the last of the 6 predicted frames, de-normalised values.
rmse(x=image_conv[-1], y=image_label[-1])

8.977473

In [22]:
# RMSE on pixels whose target exceeds thres (last predicted frame).
rmse_rain(x=image_conv[-1], y=image_label[-1])

28.546831

In [23]:
# RMSE on pixels whose target is below the rain threshold (last predicted frame).
rmse_not_rain(x=image_conv[-1], y=image_label[-1])

8.609939

In [24]:
# Critical Success Index at the rain threshold (last predicted frame).
csi(x=image_conv[-1], y=image_label[-1])

0.056461245440165286

In [34]:
def dtr(dBZ):
    """Convert radar reflectivity in dBZ to a rain rate (Marshall-Palmer Z-R relation).

    Linear reflectivity is Z = 10 ** (dBZ / 10), and Z = 200 * R ** 1.6 gives
    R = (Z / 200) ** (5 / 8). Presumably R is in mm/h, the conventional unit for
    this relation -- confirm against the data source.
    """
    linear_z = 10**(dBZ/10)
    return (linear_z / 200) ** (5/8)

In [35]:
# Whole-field RMSE after converting both fields from dBZ to rain rate.
rmse(x=dtr(dBZ=image_conv[-1]), y=dtr(dBZ=image_label[-1]))

1.6527108

In [36]:
# Rain-pixel RMSE in rain-rate space (threshold applied to the converted target).
rmse_rain(x=dtr(dBZ=image_conv[-1]), y=dtr(dBZ=image_label[-1]))

17.684595

In [37]:
# Non-rain-pixel RMSE in rain-rate space (threshold applied to the converted target).
rmse_not_rain(x=dtr(dBZ=image_conv[-1]), y=dtr(dBZ=image_label[-1]))

0.05471131

In [19]:
# Expanding-window evaluation over 5 folds: fold k takes x[:, :t_end] as the
# "train" slice and the next t_each columns as the "val" slice, then records the
# RMSE of a shift-by-one baseline (presumably persistence), the ConvLSTM model
# (`conv`) and the TrajGRU model (`traj`).
# NOTE(review): this cell appears to predate the current data layout:
#   * `traj` is never constructed in this notebook (NameError on a fresh kernel);
#   * with `x` as loaded above, shape (4494, 480, 480), x[:, t_end:...] slices a
#     480-long axis with t_end >= 729, i.e. an empty slice.
#   Confirm the intended input before trusting the numbers printed below.
t_end = 729
t_each = 146
t1_train_rmse, conv_train_rmse, traj_train_rmse = [], [], []
t1_val_rmse, conv_val_rmse, traj_val_rmse = [], [], []
for fold in range(5):
    x_train = x[:, :t_end]
    x_val = x[:, t_end:t_end + t_each]
    t_end += t_each

    x_torch_train = torch.from_numpy(x_train.astype(np.float32)).to(cfg.GLOBAL.DEVICE)
    train_data = x_torch_train[:5, ...]
    train_label = x_torch_train[5:6, ...].cpu().numpy()

    x_torch_val = torch.from_numpy(x_val.astype(np.float32)).to(cfg.GLOBAL.DEVICE)
    val_data = x_torch_val[:5, ...]
    val_label = x_torch_val[5:6, ...].cpu().numpy()

    # Inference only; call order (traj, conv on train, then traj, conv on val) is unchanged.
    with torch.no_grad():
        train_traj = traj(train_data).cpu().numpy()
        train_conv = conv(train_data).cpu().numpy()
        val_traj = traj(val_data).cpu().numpy()
        val_conv = conv(val_data).cpu().numpy()

    # Baseline: compare each step along axis 1 with the next one.
    t1_train_rmse.append(rmse(train_label[:, :-1], train_label[:, 1:]))
    conv_train_rmse.append(rmse(train_conv, train_label))
    traj_train_rmse.append(rmse(train_traj, train_label))
    t1_val_rmse.append(rmse(val_label[:, :-1], val_label[:, 1:]))
    conv_val_rmse.append(rmse(val_conv, val_label))
    traj_val_rmse.append(rmse(val_traj, val_label))
    

In [20]:
# Per-fold RMSE lists: baseline / ConvLSTM / TrajGRU on train, then on val.
for series in (t1_train_rmse, conv_train_rmse, traj_train_rmse,
               t1_val_rmse, conv_val_rmse, traj_val_rmse):
    print(series)

[0.00023013682, 0.00023665791, 0.00025148623, 0.00025377452, 0.00024126697]
[0.00057227106, 0.0005798239, 0.00059455057, 0.00059644185, 0.00056816125]
[0.0005609737, 0.0005636566, 0.0005862547, 0.0005862842, 0.00055647606]
[0.00026763562, 0.0003275521, 0.00026833592, 9.369576e-05, 0.00016235112]
[0.0006161524, 0.0006761207, 0.0006095043, 0.00024402342, 0.00039725448]
[0.0005768662, 0.0007067046, 0.00058649, 0.00019336567, 0.0004099274]


In [36]:
# Same report as the earlier print cell, recorded after a later re-run of the
# fold evaluation: baseline / ConvLSTM / TrajGRU on train, then on val.
print(t1_train_rmse, conv_train_rmse, traj_train_rmse,
      t1_val_rmse, conv_val_rmse, traj_val_rmse, sep='\n')

[0.000225227, 0.0002308476, 0.0002459908, 0.00024819077, 0.00023555584]
[0.0005405624, 0.0005410925, 0.0005606859, 0.0005631597, 0.0005365473]
[0.0005112838, 0.00051146996, 0.00053044566, 0.0005357328, 0.00050764246]
[0.00025787603, 0.0003231838, 0.00026215983, 8.18461e-05, 0.00015611586]
[0.0005437319, 0.0006661426, 0.00058016414, 0.00023228738, 0.0003940223]
[0.0005123988, 0.00063234754, 0.0005713402, 0.00015307272, 0.00037681145]
