In [None]:
import sys
import torch
from nowcasting.config import cfg
from nowcasting.models.forecaster import Forecaster
from nowcasting.models.encoder import Encoder
from nowcasting.models.model import EF
from torch.optim import lr_scheduler
from nowcasting.models.loss import Weighted_mse_mae
import os, shutil
from nowcasting.models.trajGRU import TrajGRU
from experiments.net_params import encoder_params, forecaster_params
import torchvision
import numpy as np
from tqdm import tqdm
from torch.utils.tensorboard import SummaryWriter

KeyboardInterrupt: 

In [2]:
data = np.load('./rainy-nexrad-normed.npz')
x_data = data['x_data']
x_mask = data['x_mask']
x_max = data['x_max']
x_min = data['x_min']
x = np.ma.MaskedArray(x_data, x_mask)

In [3]:
x.shape

(6, 1459, 1, 60, 30)

In [4]:
train_loss = 0.0
save_dir = os.path.join(cfg.GLOBAL.MODEL_SAVE_DIR, 'g')
if not os.path.exists(save_dir):
    os.mkdir(save_dir)
model_save_dir = os.path.join(save_dir, 'models')
log_dir = os.path.join(save_dir, 'logs')
all_scalars_file_name = os.path.join(save_dir, "all_scalars.json")
pkl_save_dir = os.path.join(save_dir, 'pkl')
if os.path.exists(all_scalars_file_name):
    os.remove(all_scalars_file_name)
if os.path.exists(log_dir):
    shutil.rmtree(log_dir)
if os.path.exists(model_save_dir):
    shutil.rmtree(model_save_dir)
os.mkdir(model_save_dir)

In [5]:
encoder = Encoder(encoder_params[0], encoder_params[1]).to(cfg.GLOBAL.DEVICE)
forecaster = Forecaster(forecaster_params[0], forecaster_params[1]).to(cfg.GLOBAL.DEVICE)
encoder_forecaster = EF(encoder, forecaster).to(cfg.GLOBAL.DEVICE)

LR_step_size = 20000
gamma = 0.7
max_iterations = 5000
LR = 1e-4
batch_size = cfg.GLOBAL.BATCH_SZIE
mse_loss = torch.nn.MSELoss()
optimizer = torch.optim.Adam(encoder_forecaster.parameters(), lr=LR)
exp_lr_scheduler = lr_scheduler.StepLR(optimizer, step_size=LR_step_size, gamma=gamma)

t_end = 729
t_each = 146
eval_every = 1

In [None]:
train_loss = 0.0
writer = SummaryWriter()
all_itera = 1
dataset = torch.from_numpy(x.astype(np.float32)).to(cfg.GLOBAL.DEVICE)
for t_train in range(5):
    x_train = dataset[:,:t_end]
    x_val = dataset[:,t_end:t_end+t_each]
    t_end += t_each
    
    for itera in tqdm(range(1, max_iterations+1)):
        
        idx = np.arange(x_train.shape[0])
        np.random.shuffle(idx)
        for b in range(int(np.ceil(x_train.shape[0] / batch_size))):
            cur_idx = idx[b*batch_size:(b+1)*batch_size]
            train_batch = x_train[:,cur_idx,:]
            train_data = train_batch[:5, ...]
            train_label = train_batch[5:6, ...]

            encoder_forecaster.train()
            optimizer.zero_grad()
            output = encoder_forecaster(train_data)
            loss = mse_loss(output, train_label)
            loss.backward()
            torch.nn.utils.clip_grad_value_(encoder_forecaster.parameters(), clip_value=50.0)
            optimizer.step()
            train_loss += loss.item()
            exp_lr_scheduler.step()

        valid_loss = 0.0
        valid_time = 0
        with torch.no_grad():
            encoder_forecaster.eval()
            for b in range(int(np.ceil(x_val.shape[0]/batch_size))):
                val_batch = x_val[:,b*batch_size: (b+1)*batch_size]
                val_data = train_batch[:5, ...]
                val_label = train_batch[5:6, ...]
                output = encoder_forecaster(val_data)
                loss = mse_loss(output, val_label)
                valid_loss += loss.item()
                valid_time += 1

        writer.add_scalars("mse", {
            "train": train_loss/eval_every,
            "valid": valid_loss/valid_time,
        }, all_itera)
        train_loss = 0.0

        all_itera += 1

    torch.save(encoder_forecaster.state_dict(), os.path.join(model_save_dir, 'traj_{}_{}.pth'.format(t_train,itera)))
writer.close()

100%|██████████| 5000/5000 [29:03<00:00,  2.87it/s]
 51%|█████     | 2540/5000 [14:42<14:12,  2.89it/s]