In [1]:
%load_ext autoreload
%autoreload 2

import copy
import os
import pickle
from collections import OrderedDict

import numpy as np
import torch
# from torch.utils.tensorboard import SummaryWriter
from tensorboardX import SummaryWriter
from torch.optim import lr_scheduler
from torchsummary import summary
from torchviz import make_dot, make_dot_from_trace
from tqdm.notebook import tqdm

from ConvGRU import ConvGRU
from Encoder_Decoder import EF, Encoder, Forecaster

In [2]:
# Pick the GPU when one is visible, otherwise fall back to the CPU, and report memory usage.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print('Using device:', device)
if device.type == 'cuda':
    # `memory_cached` is deprecated in favour of `memory_reserved` (PyTorch >= 1.4) and emits a
    # FutureWarning; only fall back to the old name on PyTorch builds that lack the new one.
    _reserved_bytes = getattr(torch.cuda, 'memory_reserved', None) or torch.cuda.memory_cached
    print(torch.cuda.get_device_name(0))
    print('Memory Usage:')
    print('Allocated:', round(torch.cuda.memory_allocated(0)/1024**3,1), 'GB')
    print('Cached:   ', round(_reserved_bytes(0)/1024**3,1), 'GB')

Using device: cuda
Tesla P100-PCIE-16GB
Memory Usage:
Allocated: 0.0 GB
Cached:    0.0 GB


In [3]:
# Load the normalised radar frames and their mask. The archive is opened in a `with` block so the
# NpzFile's underlying file handle is released once both arrays have been read into memory
# (indexing an NpzFile materialises the array, so they stay valid after the file is closed).
with np.load('./rainy-nexrad-normed.npz') as data:
    x_data = data['x_data']
    x_mask = data['x_mask']
x = np.ma.MaskedArray(x_data, x_mask)

In [4]:
# Metadata pickles used to locate breaks in the frame sequence.
# NOTE(review): pickle.load can execute arbitrary code -- only load files you produced yourself.
def _unpickle(path):
    """Return the single object stored in the pickle file at `path`."""
    with open(path, "rb") as handle:
        return pickle.load(handle)


files = _unpickle("files.pkl")
dts = _unpickle("dts.pkl")
lost_mark = _unpickle("lost_mark.pkl")

In [5]:
# Number of consecutive frames in one sliding-window sample (input frames + target frames).
window_size: int = 12

In [6]:
# Build the usable sliding-window start positions: a window of `window_size` frames is kept only
# if it does not straddle a break in the frame sequence.
max_gap_minutes = 15  # consecutive frames further apart than this start a new sequence

# Minutes between consecutive frames. `total_seconds()` is used because `timedelta.seconds`
# ignores the `days` component, which would silently hide gaps of a day or more.
time_delta = np.array(
    [int((later - earlier).total_seconds() // 60) for earlier, later in zip(dts[:-1], dts[1:])],
    dtype=int,
)

# Break points (index of the first frame after a break): frames following a long gap, the end of
# the data, and any indices listed in `lost_mark`. union1d returns them sorted and de-duplicated.
gap_starts = np.flatnonzero(time_delta > max_gap_minutes) + 1
mark = np.union1d(np.append(gap_starts, len(files)), lost_mark).astype(int)

# Plain int dtypes throughout: `np.int` was removed in NumPy 1.24.
sliding_idx = np.arange(x.shape[0] - window_size + 1)
# A window starting at s covers frames s .. s+window_size-1, so it crosses break m exactly when
# m-window_size+1 <= s <= m-1; drop every such start.
remove_idx = (mark[:, None] + np.arange(1 - window_size, 0)).ravel()
use_idx = np.setdiff1d(sliding_idx, remove_idx)

In [7]:
use_idx.size  # number of usable window start positions

4404

In [8]:
# Model recipe: each side is [layer-spec list, ConvGRU list], one entry per stage.
#
# The layer-spec lists appear to be [in_channels, out_channels, kernel_size, stride, padding]:
# the channel counts chain into the ConvGRU inputs, and the spatial sizes work out to
# 480 -> 96 -> 32 -> 16 on the encoder side and 16 -> 32 -> 96 -> 480 on the forecaster side.
# TODO confirm against the layer builder in Encoder_Decoder.
#
# ConvGRU(input_channel, num_filter, b_h_w=(batch, height, width), kernel_size, stride, padding).
# The hidden state is sized with a fixed batch, so presumably every forward pass must use exactly
# `batch_size` samples.
#
# NOTE: the ConvGRU objects below are module instances created once, here. Any model built
# directly from these lists shares them (and their weights) with every other model built from them.
batch_size = 16


def _conv_gru(in_channels, filters, side):
    """3x3 / stride-1 / padding-1 ConvGRU over a `side` x `side` map with batch `batch_size`."""
    return ConvGRU(input_channel=in_channels, num_filter=filters,
                   b_h_w=(batch_size, side, side),
                   kernel_size=3, stride=1, padding=1)


encoder_params = [
    [
        OrderedDict([('conv1_leaky_1', [1, 8, 7, 5, 1])]),
        OrderedDict([('conv2_leaky_1', [64, 192, 5, 3, 1])]),
        OrderedDict([('conv3_leaky_1', [192, 192, 3, 2, 1])]),
    ],
    [
        _conv_gru(8, 64, 96),
        _conv_gru(192, 192, 32),
        _conv_gru(192, 192, 16),
    ],
]
decoder_params = [
    [
        OrderedDict([('deconv1_leaky_1', [192, 192, 4, 2, 1])]),
        OrderedDict([('deconv2_leaky_1', [192, 64, 5, 3, 1])]),
        OrderedDict([
            ('deconv3_leaky_1', [64, 8, 7, 5, 1]),
            ('conv3_leaky_2', [8, 8, 3, 1, 1]),
            ('conv3_3', [8, 1, 1, 1, 0]),
        ]),
    ],
    [
        _conv_gru(192, 192, 16),
        _conv_gru(192, 192, 32),
        _conv_gru(64, 64, 96),
    ],
]

In [9]:
# encoder = Encoder(encoder_params[0], encoder_params[1]).to(device)
# decoder = Forecaster(decoder_params[0], decoder_params[1]).to(device)
# encoder_decoder = EF(encoder, decoder).to(device)
# LR_step_size = 10
# gamma = 0.8
# max_iterations = 20
# LR = 1e-3
# mse_loss = torch.nn.MSELoss()
# optimizer = torch.optim.Adam(encoder_decoder.parameters(), lr=LR)
# exp_lr_scheduler = lr_scheduler.StepLR(optimizer, step_size=LR_step_size, gamma=gamma)


In [10]:
# summary(encoder_decoder, (window_size, 1, 480, 480), batch_size, device='cuda')

In [11]:
# make_dot(encoder_decoder(torch.randn(1, window_size, 1, 480, 480).cuda()),
#          params=None)

In [12]:
# Expanding-window cross-validation: the first fold trains on use_idx[:t_end] and validates on the
# next t_each windows; each later fold extends the training range by t_each.
t_end, t_each = 4000, 327
eval_every = 1  # NOTE(review): not referenced by the visible training loop

In [13]:
# View x_data as overlapping windows without copying: window k covers frames k .. k+window_size-1.
# The strides are taken from the array itself rather than recomputed from its shape, so the view
# is correct even if x_data is not C-contiguous. The view is read-only because its windows alias
# the same memory (a write through one window would silently change its neighbours).
x_data_slided = np.lib.stride_tricks.as_strided(
    x_data,
    shape=(x_data.shape[0] - window_size + 1, window_size) + x_data.shape[1:],
    strides=(x_data.strides[0],) + x_data.strides,
    writeable=False,
)
# Reorder to (window_size, n_windows, ...) and add a singleton channel axis after n_windows.
x_data_slided = x_data_slided.swapaxes(0, 1)[:, :, None]

In [14]:
np.shape(x_data_slided)  # (window_size, n_windows, 1, height, width)

(12, 4483, 1, 480, 480)

In [15]:
# Same overlapping-window view as for the data, applied to the mask: window k covers frames
# k .. k+window_size-1. Strides come from x_mask itself (correct for any memory layout), and the
# view is read-only because the windows alias each other.
x_mask_slided = np.lib.stride_tricks.as_strided(
    x_mask,
    shape=(x_mask.shape[0] - window_size + 1, window_size) + x_mask.shape[1:],
    strides=(x_mask.strides[0],) + x_mask.strides,
    writeable=False,
)
# Reorder to (window_size, n_windows, ...) and add a singleton channel axis after n_windows.
x_mask_slided = x_mask_slided.swapaxes(0, 1)[:, :, None]

In [16]:
# Running state read and updated by the training cell.
all_itera = 1      # global epoch counter, used as the TensorBoard step
history = []       # one [step, train metric, valid metric] row appended per epoch
train_loss = 0.0   # running training-loss accumulator
writer = SummaryWriter('./convGRU_tb', flush_secs=10)

In [17]:
# Train one fresh ConvGRU encoder-forecaster per expanding-window fold and log per-epoch metrics.
#
# Fold k trains on use_idx[:t_end + k*t_each] and validates on the next t_each windows. The fold
# boundary is computed from t_end instead of incrementing t_end, so re-running this cell repeats
# the same folds. Each sample is a window of `window_size` frames: the first half is the model
# input and the second half is the prediction target.
n_folds = 3
input_len = window_size // 2          # 6 input / 6 target frames for window_size = 12
LR_step_size = 10                     # epochs between learning-rate decays
gamma = 0.8                           # multiplicative LR decay factor
max_iterations = 20                   # epochs per fold
LR = 1e-4
SEED = 0
model_save_dir = './convGRU_models'   # checkpoints are written as convGRU_<fold>_<step>.pth
os.makedirs(model_save_dir, exist_ok=True)
np.random.seed(SEED)
torch.manual_seed(SEED)
mse_loss = torch.nn.MSELoss()


def _full_batches(positions, size):
    """Split `positions` into consecutive batches of exactly `size`; a trailing partial batch is dropped."""
    return [positions[start:start + size] for start in range(0, len(positions) - size + 1, size)]


def _load_batch(window_ids):
    """Gather sliding windows `window_ids` as a float32 tensor on `device`, split into (inputs, targets)."""
    batch = torch.from_numpy(x_data_slided[:, window_ids].astype(np.float32)).to(device)
    return batch[:input_len], batch[input_len:]


try:
    for t_train in range(n_folds):
        train_end = t_end + t_train * t_each
        x_train_idx = use_idx[:train_end]
        x_val_idx = use_idx[train_end:train_end + t_each]
        if len(x_val_idx) < batch_size:
            # Only full batches are evaluated (the ConvGRU hidden state appears to be sized for a
            # fixed batch_size), so a fold with fewer validation windows cannot be scored and
            # would otherwise divide by zero. Later folds only have fewer windows left.
            print('Skipping fold {}: only {} validation windows left'.format(t_train, len(x_val_idx)))
            break

        # encoder_params[1] / decoder_params[1] hold ConvGRU module *instances* built once. Deep-copy
        # them so each fold starts from untrained GRU weights instead of the previous fold's.
        encoder = Encoder(encoder_params[0], copy.deepcopy(encoder_params[1])).to(device)
        decoder = Forecaster(decoder_params[0], copy.deepcopy(decoder_params[1])).to(device)
        encoder_decoder = EF(encoder, decoder).to(device)
        optimizer = torch.optim.Adam(encoder_decoder.parameters(), lr=LR)
        exp_lr_scheduler = lr_scheduler.StepLR(optimizer, step_size=LR_step_size, gamma=gamma)

        idx = np.arange(len(x_train_idx))
        for itera in tqdm(range(1, max_iterations + 1), desc='epoch: '):
            np.random.shuffle(idx)
            encoder_decoder.train()
            train_loss, n_train_batches = 0.0, 0
            for batch_pos in tqdm(_full_batches(idx, batch_size), desc='batch: ', leave=False):
                train_data, train_label = _load_batch(x_train_idx[batch_pos])
                optimizer.zero_grad()
                output = encoder_decoder(train_data)
                loss = mse_loss(output, train_label)
                loss.backward()
                torch.nn.utils.clip_grad_value_(encoder_decoder.parameters(), clip_value=50.0)
                optimizer.step()
                train_loss += loss.item()
                n_train_batches += 1
            # StepLR counts epochs here: the LR decays by `gamma` every LR_step_size epochs.
            exp_lr_scheduler.step()

            valid_loss, n_val_batches = 0.0, 0
            encoder_decoder.eval()
            with torch.no_grad():
                for val_ids in _full_batches(x_val_idx, batch_size):
                    val_data, val_label = _load_batch(val_ids)
                    valid_loss += mse_loss(encoder_decoder(val_data), val_label).item()
                    n_val_batches += 1

            # Both metrics are the mean of the per-batch MSE values.
            train_mse = train_loss / max(n_train_batches, 1)
            valid_mse = valid_loss / n_val_batches
            writer.add_scalars("mse", {"train": train_mse, "valid": valid_mse}, all_itera)
            history.append([all_itera, train_mse, valid_mse])
            all_itera += 1

        torch.save(encoder_decoder.state_dict(),
                   os.path.join(model_save_dir, 'convGRU_{}_{}.pth'.format(t_train, all_itera)))
finally:
    # Flush logs and persist the metric history even if training is interrupted.
    writer.close()
    with open('convGRU_train.pickle', 'wb') as handle:
        pickle.dump(history, handle, protocol=pickle.HIGHEST_PROTOCOL)
    
    

HBox(children=(IntProgress(value=0, description='epoch: ', max=20, style=ProgressStyle(description_width='init…

HBox(children=(IntProgress(value=0, description='batch: ', max=250, style=ProgressStyle(description_width='ini…




HBox(children=(IntProgress(value=0, description='batch: ', max=250, style=ProgressStyle(description_width='ini…




HBox(children=(IntProgress(value=0, description='batch: ', max=250, style=ProgressStyle(description_width='ini…




HBox(children=(IntProgress(value=0, description='batch: ', max=250, style=ProgressStyle(description_width='ini…

KeyboardInterrupt: 