In [1]:
import sys
import torch
from nowcasting.config import cfg
from nowcasting.models.forecaster import Forecaster
from nowcasting.models.encoder import Encoder
from nowcasting.models.model import EF
from torch.optim import lr_scheduler
from nowcasting.models.loss import Weighted_mse_mae
import os, shutil
from experiments.net_params import convlstm_encoder_params, convlstm_forecaster_params
import torchvision
import numpy as np
from tqdm.notebook import tqdm
from torch.utils.tensorboard import SummaryWriter
import pickle

In [2]:
# Load the radar frame stack and its mask. The context manager closes the .npz
# archive once every member has been read into memory (indexing an NpzFile reads
# the member eagerly).
with np.load('./rainy-nexrad-normed.npz') as data:
    x_data = data['x_data']
    x_mask = data['x_mask']
    # NOTE(review): presumably the normalisation bounds ("normed" file) — not used in this section.
    x_max = data['x_max']
    x_min = data['x_min']
# In numpy.ma, a True mask entry marks the corresponding element as invalid.
x = np.ma.MaskedArray(x_data, x_mask)

In [3]:
# Dimensions of the masked frame stack (displayed as the cell output).
np.shape(x)

(4494, 480, 480)

In [4]:
def _unpickle(path):
    """Return the single object pickled in the file at ``path``.

    Only use on trusted files: unpickling can execute arbitrary code.
    """
    with open(path, "rb") as handle:
        return pickle.load(handle)


# Inputs for the segmentation cell below: the file list (its length is used as the
# end sentinel), the per-frame timestamps, and extra break indices (lost_mark).
files = _unpickle("files.pkl")
dts = _unpickle("dts.pkl")
lost_mark = _unpickle("lost_mark.pkl")

In [5]:
# Number of consecutive frames in one sample window.
window_size: int = 12

In [6]:
# Split the frame sequence into contiguous segments and keep only the sliding
# windows (of window_size frames) that do not straddle a segment boundary.
dts_arr = np.array(dts)
# Whole minutes between consecutive timestamps. total_seconds() keeps the days
# component (timedelta.seconds silently drops it, so a 1-day-5-min gap read as 5 min).
# A list comprehension also copes with fewer than two timestamps, where
# np.vectorize raises on empty input.
time_delta = np.array(
    [int(td.total_seconds() // 60) for td in dts_arr[1:] - dts_arr[:-1]],
    dtype=np.int64,
)
# Segment starts: any gap above 15 minutes; a negative gap (out-of-order
# timestamps) is also treated as a break.
mark = np.argwhere((time_delta > 15) | (time_delta < 0)).reshape(-1) + 1
# len(files) acts as the end-of-sequence sentinel.
mark = np.append(mark, len(files))
# Merge in the externally supplied break indices; union1d returns sorted, unique
# values. The explicit int dtype keeps mark integral even if lost_mark is empty.
mark = np.union1d(mark, np.asarray(lost_mark, dtype=np.int64).reshape(-1))
# Candidate start positions of every full window.
sliding_idx = np.arange(x.shape[0] - window_size + 1, dtype=np.int64)
# A window starting at s covers s .. s + window_size - 1, so it crosses boundary m
# iff m - (window_size - 1) <= s <= m - 1. Build all such starts for every mark at once.
remove_idx = (mark[:, None] + np.arange(-(window_size - 1), 0)).reshape(-1)
use_idx = np.setdiff1d(sliding_idx, remove_idx)

In [7]:
# How many window start positions survived the boundary filter.
use_idx.size

4404

In [8]:
train_loss = 0.0
# Output layout: <MODEL_SAVE_DIR>/f/{models/, logs/, pkl/, all_scalars.json}.
save_dir = os.path.join(cfg.GLOBAL.MODEL_SAVE_DIR, 'f')
# makedirs also creates MODEL_SAVE_DIR itself if missing (os.mkdir would raise).
os.makedirs(save_dir, exist_ok=True)
model_save_dir = os.path.join(save_dir, 'models')
log_dir = os.path.join(save_dir, 'logs')
all_scalars_file_name = os.path.join(save_dir, "all_scalars.json")
pkl_save_dir = os.path.join(save_dir, 'pkl')  # path only; not created in this cell
# WARNING: this starts from a clean slate and deletes the scalars file, logs and
# saved models of any previous run in save_dir.
if os.path.exists(all_scalars_file_name):
    os.remove(all_scalars_file_name)
for stale_dir in (log_dir, model_save_dir):
    if os.path.exists(stale_dir):
        shutil.rmtree(stale_dir)
os.makedirs(model_save_dir)

In [9]:
# Seed torch (weight initialisation) and NumPy (np.random.shuffle in the training
# loop) so runs are reproducible.
SEED = 0
torch.manual_seed(SEED)
np.random.seed(SEED)

# Encoder-forecaster network built from the ConvLSTM parameter sets, on the configured device.
encoder = Encoder(convlstm_encoder_params[0], convlstm_encoder_params[1]).to(cfg.GLOBAL.DEVICE)
forecaster = Forecaster(convlstm_forecaster_params[0], convlstm_forecaster_params[1]).to(cfg.GLOBAL.DEVICE)
encoder_forecaster = EF(encoder, forecaster).to(cfg.GLOBAL.DEVICE)

LR_step_size = 50  # StepLR period, counted in exp_lr_scheduler.step() calls
gamma = 0.8  # LR multiplier applied every LR_step_size scheduler steps
max_iterations = 20  # full passes (epochs) over the training windows per round
LR = 1e-4
batch_size = 8
mse_loss = torch.nn.MSELoss()
optimizer = torch.optim.Adam(encoder_forecaster.parameters(), lr=LR)
# NOTE(review): the training cell calls exp_lr_scheduler.step() once per batch, so
# LR_step_size is in batches, not epochs — confirm this is intended.
exp_lr_scheduler = lr_scheduler.StepLR(optimizer, step_size=LR_step_size, gamma=gamma)

t_end = 4000  # number of use_idx windows in the first training split
t_each = 327  # validation windows per round; the split advances by this much each round
eval_every = 1  # validate every N epochs
# Frames per sample. Must equal window_size: use_idx was computed for windows of
# that length, so a different value would index windows crossing segment breaks.
n_window = window_size

In [10]:
train_loss = 0.0
# Write TensorBoard events into log_dir (the directory the setup cell clears)
# instead of the default ./runs/<timestamp> location.
writer = SummaryWriter(log_dir=log_dir)
all_itera = 1  # global epoch counter, used as the TensorBoard step
# np.ma.getdata drops the mask: masked pixels keep their raw stored values, exactly
# as torch.from_numpy on the MaskedArray did. NOTE(review): confirm masked pixels
# are meant to enter training unmodified.
# The whole float32 stack is moved to the device in one go (~4.1 GB for the
# (4494, 480, 480) shape shown above).
dataset = torch.from_numpy(np.ma.getdata(x).astype(np.float32)).to(cfg.GLOBAL.DEVICE)
# unfold(0, n_window, 1): (T, H, W) -> (T - n_window + 1, H, W, n_window) sliding views;
# permute + unsqueeze -> (n_window, n_samples, 1, H, W): time-major with a channel axis.
dataset = dataset.unfold(0, n_window, 1).permute(3, 0, 1, 2).unsqueeze(2)

In [11]:
# (n_window, n_samples, 1, H, W) after the unfold/permute in the previous cell.
dataset.size()

torch.Size([12, 4483, 1, 480, 480])

In [12]:
# Expanding-window training: round t_train trains on use_idx[:t_end], validates on
# the next t_each windows, then advances t_end by t_each. This mutates the global
# t_end, so re-run the configuration cell before re-running this one.
# With len(use_idx) == 4404 the last round's validation slice is empty; no "valid"
# scalar is logged for such rounds instead of a misleading 0.0.
# NOTE(review): exp_lr_scheduler.step() runs once per *batch*. With LR_step_size=50
# and ~500 batches per epoch the LR shrinks ~10x per epoch (about 1e-4 * 0.8**200
# after 20 epochs). Confirm whether it should step once per epoch instead.
n_input = 6  # frames [:6] of each window are model inputs, frames [6:] are targets
train_loss = 0.0
n_train_batches = 0
try:
    for t_train in range(3):
        x_train_idx = use_idx[:t_end]
        x_val_idx = use_idx[t_end:t_end + t_each]
        t_end += t_each
        idx = np.arange(x_train_idx.shape[0])
        for itera in tqdm(range(1, max_iterations + 1)):
            np.random.shuffle(idx)
            # Validation below switches to eval(); switch back once per epoch.
            encoder_forecaster.train()
            for start in range(0, idx.shape[0], batch_size):
                cur_idx = x_train_idx[idx[start:start + batch_size]]
                train_batch = dataset[:, cur_idx]
                train_data = train_batch[:n_input]
                train_label = train_batch[n_input:]
                optimizer.zero_grad()
                output = encoder_forecaster(train_data)
                loss = mse_loss(output, train_label)
                loss.backward()
                torch.nn.utils.clip_grad_value_(encoder_forecaster.parameters(), clip_value=50.0)
                optimizer.step()
                train_loss += loss.item()
                n_train_batches += 1
                exp_lr_scheduler.step()
                del train_batch, train_data, train_label
                torch.cuda.empty_cache()

            if all_itera % eval_every == 0:
                valid_loss = 0.0
                n_valid_batches = 0  # was initialised to 1, under-reporting the mean
                encoder_forecaster.eval()
                with torch.no_grad():
                    for start in range(0, x_val_idx.shape[0], batch_size):
                        # Gather each validation batch on demand instead of copying
                        # all validation windows out of `dataset` up front.
                        val_batch = dataset[:, x_val_idx[start:start + batch_size]]
                        output = encoder_forecaster(val_batch[:n_input])
                        loss = mse_loss(output, val_batch[n_input:])
                        valid_loss += loss.item()
                        n_valid_batches += 1
                        del val_batch
                        torch.cuda.empty_cache()

                # Both curves are per-batch means, so they are directly comparable.
                scalars = {"train": train_loss / max(n_train_batches, 1)}
                if n_valid_batches:
                    scalars["valid"] = valid_loss / n_valid_batches
                writer.add_scalars("mse", scalars, all_itera)
                train_loss = 0.0
                n_train_batches = 0

            all_itera += 1
        torch.cuda.empty_cache()
        torch.save(encoder_forecaster.state_dict(),
                   os.path.join(model_save_dir, 'conv_{}_{}.pth'.format(t_train, max_iterations)))
finally:
    # Flush TensorBoard events even when training is interrupted (e.g. KeyboardInterrupt).
    writer.close()

HBox(children=(IntProgress(value=0, max=20), HTML(value='')))




HBox(children=(IntProgress(value=0, max=20), HTML(value='')))

KeyboardInterrupt: 