In [1]:
import os
from datetime import datetime

import neptune
import torch
from torch import nn
from torch.nn.modules.loss import BCEWithLogitsLoss

import trainer
import utils

%load_ext autoreload
%autoreload 2

In [2]:
    # train_loader, val_loader, test_loader = utils.prepare_movingmnist(
    #     hparams['batch_size'], data_dir=data_dir)
train_loader, val_loader, test_loader = utils.prepare_mm(128, data_dir='ConvLSTM/data/')

10000
1000
1000


In [45]:
class RNN(nn.Module):
    """Hand-unrolled vanilla RNN over flattened frames, trained with local hidden targets.

    The forward recurrence is ``h_t = tanh(W_xh x_t + Wb_hh h_{t-1})`` (``F_hh``) and the
    readout is ``y_t = Wb_hy h_t`` (logits). A second mapping ``G_hh`` (using ``Vc_hh``)
    is applied in ``train_epoch`` to carry hidden-state targets backwards through time.

    Args:
        in_dim: size of a flattened input frame (must equal ``h * w`` of the frames).
        hidden_dim: size of the hidden state.
        out_dim: size of a flattened output frame (``forward`` feeds outputs back as
            inputs, so it must equal ``in_dim``).
        train_criterion: loss applied to logits in ``train_epoch`` and ``validation``.
        test_criterion: loss applied to logits in ``test``.
        i_lr: step size used to build the local hidden target
            ``h - i_lr * dL/dh`` in ``train_epoch``.
    """

    def __init__(self, in_dim, hidden_dim, out_dim, train_criterion, test_criterion, i_lr=0.001):

        super().__init__()

        # Set hyperparameters
        self.h_act = torch.nn.Tanh()
        self.in_dim = in_dim
        self.hidden_dim = hidden_dim
        self.out_dim = out_dim

        # Manual RNN (creation order kept stable so seeded initialisation is unchanged)
        self.W_xh = torch.nn.Linear(self.in_dim, self.hidden_dim, bias=False)
        self.Wb_hh = torch.nn.Linear(self.hidden_dim, self.hidden_dim, bias=True)
        self.Vc_hh = torch.nn.Linear(self.hidden_dim, self.hidden_dim, bias=True)
        self.Wb_hy = torch.nn.Linear(self.hidden_dim, self.out_dim, bias=True)

        # For managing training process
        self.device = None
        self.training_criterion = train_criterion
        self.test_criterion = test_criterion
        self.hidden_criterion = torch.nn.MSELoss()
        self.i_lr = i_lr  # to set local target

    def F_hh(self, x_t, h_tm1):
        """Forward transition: hidden state at t from input ``x_t`` and previous hidden ``h_tm1``."""
        return self.h_act(self.W_xh(x_t) + self.Wb_hh(h_tm1))

    def G_hh(self, x_t, h_t):
        """Feedback mapping: same input projection as ``F_hh`` but with ``Vc_hh`` applied to ``h_t``."""
        return self.h_act(self.W_xh(x_t) + self.Vc_hh(h_t))

    def set_device(self, device):
        """Record the device on which this module allocates and moves tensors."""
        self.device = device

    def get_device(self):
        """Return the device stored by ``set_device`` (``None`` until it is called)."""
        return self.device

    def forward(self, x):
        """Encode an input clip and roll out as many predicted frames as were given.

        Args:
            x: tensor of shape ``[batch_sz, seq_len, h, w]``.

        Returns:
            Tensor of shape ``[batch_sz, seq_len, h, w]`` holding *logits*; callers apply
            a sigmoid (``validation``) or a with-logits loss (``test``) themselves.
        """
        x = torch.transpose(x, 0, 1)/1.  # seq_len first; /1. promotes to float
        in_len, batch_sz, h, w = x.shape
        out_len = in_len

        h_ = torch.zeros(batch_sz, self.hidden_dim, device=self.get_device())

        # Consume every input frame except the last one.
        for in_idx in range(in_len-1):
            x_ = x[in_idx].view(batch_sz, -1)  # [batch_sz, self.in_dim]
            h_ = self.F_hh(x_, h_)

        Ys = []
        # The last input frame seeds the roll-out.
        y_ = x[in_len-1].view(batch_sz, -1)

        for _ in range(out_len):
            h_ = self.F_hh(y_, h_)
            y_ = self.Wb_hy(h_)
            Ys.append(y_.view(batch_sz, h, w))  # store logits
            y_ = y_.sigmoid()  # feed the squashed prediction back in as the next input

        # [batch_sz, out_len, h, w], logits (not yet passed through a sigmoid)
        return torch.transpose(torch.stack(Ys, 0), 0, 1)

    # Train epoch
    def train_epoch(self, train_loader, optimizers, logger, epoch, verbose=True):
        """Run one epoch of local-target training.

        For each batch, frames from ``batch[2]`` and ``batch[1]`` are concatenated along
        the time axis. A forward sweep computes hidden activations, per-step prediction
        losses (Y) and local hidden targets ``h - i_lr * dL_y/dh``. A backward sweep then
        blends each target with one propagated from the following step through ``G_hh``
        and collects the G (feedback reconstruction) and F (hidden vs. target) losses.

        Args:
            train_loader: iterable of batches; ``batch[1]`` and ``batch[2]`` must squeeze
                to ``[batch_sz, seq_len, h, w]``. The concatenated sequence needs at least
                3 frames, otherwise there are no G/F terms to stack.
            optimizers: sequence ``(opt_g, opt_f, opt_y)``.
            logger: object with ``log_metric(name, x, y)``, or ``None`` to disable logging.
            epoch: current epoch index, used for the logging x-coordinate.
            verbose: unused; kept for interface compatibility.

        Returns:
            ``(G_losses_list, F_losses_list, Y_losses_list)``: one detached scalar tensor
            per batch in each list.
        """
        G_losses_list = []
        F_losses_list = []
        Y_losses_list = []
        opt_g, opt_f, opt_y = optimizers
        for batch_idx, batch in enumerate(train_loader):
            batch = [batch[2].squeeze(), batch[1].squeeze()]
            x = torch.transpose(torch.cat(batch, 1), 0, 1)/1.
            x = x.to(self.get_device())
            total_len, batch_sz, h, w = x.shape
            h_ = torch.zeros(batch_sz, self.hidden_dim,
                             device=self.get_device())

            Y_losses = []
            h_local_tar = []
            h_acts = []

            # Forward pass
            for idx in range(total_len - 1):
                # e.g. idx = 0 -> 18 for a 20-frame sequence
                x_i = x[idx].view(batch_sz, h*w)  # [batch_sz, self.in_dim]
                x_ip1 = x[idx+1].view(batch_sz, h*w)

                # Forward activation; inputs are detached so each step is trained locally.
                h_activation = self.F_hh(x_i.detach(), h_.detach())
                h_acts.append(h_activation)

                # Step y loss, computed on a detached copy of the hidden state so the
                # prediction loss only reaches Wb_hy (and this copy), not the recurrence.
                h_copy_w_grad = h_activation.detach().requires_grad_(True)
                y_pred = self.Wb_hy(h_copy_w_grad)
                step_loss = self.training_criterion(y_pred, x_ip1)
                Y_losses.append(step_loss)

                # Local target: one gradient step on the hidden state. autograd.grad
                # returns dL/dh directly without touching any parameter .grad, and
                # retain_graph keeps the graph alive for Y_loss_batch.backward() below.
                (h_grad,) = torch.autograd.grad(
                    step_loss, h_copy_w_grad, retain_graph=True)
                h_local_tar.append((h_copy_w_grad - self.i_lr * h_grad).detach())

                # Recursive step
                h_ = h_activation.detach()

            # Backward pass
            G_losses = []
            F_losses = []

            for idx in range(total_len-3, -1, -1):
                # e.g. idx = 17 -> 0 for a 20-frame sequence
                x_next = x[idx+1].view(batch_sz, h*w)
                G_h = self.G_hh(x_next, h_acts[idx+1].detach())
                G_h_hat = self.G_hh(x_next, h_local_tar[idx+1].detach())

                total = (total_len - 1) - idx
                beta = 1./total  # weight for current loss
                alpha = 1 - beta
                # Blend this step's own target with the one propagated back from idx+1.
                h_local_tar[idx] = (beta * h_local_tar[idx] + alpha * (h_acts[idx].detach() - G_h + G_h_hat)).detach()
                G_losses.append(self.hidden_criterion(G_h, h_acts[idx].detach()))
                F_losses.append(self.hidden_criterion(h_acts[idx+1], h_local_tar[idx+1]))

            G_loss_batch = torch.stack(G_losses).mean()
            F_loss_batch = torch.stack(F_losses).mean()
            Y_loss_batch = torch.stack(Y_losses).mean()

            opt_g.zero_grad()
            opt_f.zero_grad()
            opt_y.zero_grad()
            G_loss_batch.backward()
            F_loss_batch.backward()
            Y_loss_batch.backward()
            # NOTE(review): only opt_g is stepped. With configure_optimizers() every
            # optimizer is built over *all* parameters, so this single step applies the
            # accumulated G + F + Y gradients to the whole model; stepping the other two
            # as well would apply the same gradients three times.
            opt_g.step()

            if logger:
                epoch_float = epoch + (batch_idx+1)/len(train_loader)
                logger.log_metric('train_g_loss', epoch_float, G_loss_batch)
                logger.log_metric('train_f_loss', epoch_float, F_loss_batch)
                logger.log_metric('train_y_loss', epoch_float, Y_loss_batch)
            # Store detached values so the per-batch graphs are not kept alive all epoch.
            F_losses_list.append(F_loss_batch.detach())
            Y_losses_list.append(Y_loss_batch.detach())
            G_losses_list.append(G_loss_batch.detach())

        return G_losses_list, F_losses_list, Y_losses_list

    def validation(self, val_loader, logger, epoch, verbose=True):
        """Compute the mean teacher-forced next-frame loss over ``val_loader``.

        Each frame of the concatenated sequence is fed in turn and the logits are scored
        against the following frame with ``training_criterion``. When ``logger`` is
        truthy, one randomly chosen batch also logs an input / ground-truth / prediction
        triple through ``utils``.

        Args:
            val_loader: sized iterable of batches, same layout as in ``train_epoch``.
            logger: object with ``log_metric``, or ``None`` to disable logging.
            epoch: epoch index used for logging and printing.
            verbose: print the epoch and mean loss when true.

        Returns:
            Tensor of shape ``[1]`` with the loss averaged over batches.
        """
        self.eval()
        with torch.no_grad():
            total_loss = torch.zeros(1, device=self.get_device())
            log_image_batch_idx = torch.randint(0, len(val_loader), [1])
            for batch_idx, batch in enumerate(val_loader):
                batch = [batch[2].squeeze(), batch[1].squeeze()]
                x = torch.transpose(torch.cat(batch, 1), 0, 1)/1.
                x = x.to(self.get_device())
                total_len, batch_sz, h, w = x.shape
                h_ = torch.zeros(batch_sz, self.hidden_dim,
                                 device=self.get_device())
                p_out = []

                for idx in range(total_len-1):
                    x_ = x[idx].view(batch_sz, h*w)  # [batch_sz, self.in_dim]
                    # Same recurrence as forward(); the input projection is W_xh.
                    h_ = self.F_hh(x_, h_)
                    p_ = self.Wb_hy(h_).view(batch_sz, h, w)
                    p_out.append(p_)

                y_logits = torch.stack(p_out, 0)  # [total_len-1, batch_sz, h, w]
                gt = x[1:]
                loss = self.training_criterion(torch.transpose(
                    y_logits, 0, 1), torch.transpose(gt, 0, 1))
                total_loss += loss

                # log images
                if logger and (batch_idx == log_image_batch_idx):

                    sample_seq_idx = torch.randint(0, batch_sz, [1])

                    x_gt = batch[1][sample_seq_idx, :, :, :].squeeze().to(self.get_device())  # seq, h, w
                    assert(x_gt.max() <= 1.)

                    x_in = batch[0][sample_seq_idx:sample_seq_idx+1, :, :, :].to(self.get_device())  # 1, seq, h, w

                    x_pred = self.forward(x_in).squeeze().sigmoid_()
                    assert(x_pred.max() <= 1.)

                    x_in_ = x_in.squeeze()/1.
                    assert(x_in_.max() <= 1.)

                    # NOTE(review): helper name is spelled "meptune" here; confirm it
                    # matches the definition in utils.
                    utils.log_images_to_meptune(
                        "Predicted frames", [x_in_, x_gt, x_pred], logger)

            mean_loss = total_loss/len(val_loader)
            if logger:
                logger.log_metric('val_loss', epoch, mean_loss)
        self.train()
        if verbose:
            print(f"{epoch:d} \t\t {mean_loss.item():.04f} \t\t")
        return mean_loss

    def test(self, test_loader, logger=None, verbose=True):
        """Evaluate free-running prediction: ``batch[2]`` is the input clip, ``batch[1]`` the target.

        The per-pixel ``test_criterion`` value is multiplied by ``h * w`` to give a
        per-frame loss, which is averaged over batches.

        Args:
            test_loader: sized iterable of batches, same layout as in ``train_epoch``.
            logger: object with ``log_metric``, or ``None`` to disable logging.
            verbose: print the mean loss when true.

        Returns:
            Tensor of shape ``[1]`` with the mean per-frame loss.
        """
        self.eval()
        with torch.no_grad():
            total_loss = torch.zeros(1, device=self.get_device())
            for batch in test_loader:
                batch = [batch[2].squeeze(), batch[1].squeeze()]
                x, y = [item.to(self.get_device()) for item in batch]
                y = y/1.
                assert (y.max() <= 1.)
                # [batch_sz, seq_len, h, w] logits
                pred_frames = self.forward(x)
                _, _, h, w = y.shape
                pixel_loss = self.test_criterion(pred_frames, y)
                frame_loss = pixel_loss * h * w
                total_loss += frame_loss
            mean_loss = total_loss/len(test_loader)
            if logger:
                logger.log_metric('test_loss', mean_loss)
        self.train()
        if verbose:
            print(f"Testing loss: {mean_loss.item():.04f}")
        return mean_loss

    def configure_optimizers(self, lr=0.001):
        """Return ``[optimizer_g, optimizer_f, optimizer_y]``, each an Adam over all parameters.

        ``train_epoch`` zeroes all three but only steps the first one.
        """
        optimizer_g = torch.optim.Adam(self.parameters(), lr=lr)
        optimizer_f = torch.optim.Adam(self.parameters(), lr=lr)
        optimizer_y = torch.optim.Adam(self.parameters(), lr=lr)
        return [optimizer_g, optimizer_f, optimizer_y]


In [46]:
# Experiment configuration.
exp_dir = './exp_vid_rnn/'
data_dir = './data'

hparams = {
    'max_epochs': 100,
    'in_dim': 64*64,
    'hidden_dim': 1000,
    'out_dim': 64*64,
    'batch_size': 20,
    'lr': 0.001,
    'BCE_pos_weight': 1.
}

# Neptune experiment used as the training logger.
nep_project = neptune.init(project_qualified_name="peterpdai/test")
nep_experiment = nep_project.create_experiment(
    name='RNN video prediction',
    params=hparams,
    upload_source_files=['*.py'],
    tags=["vanilla-rnn", "video-prediction"],
)

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
# Pass the tensor by keyword: the first positional parameter of BCEWithLogitsLoss is
# `weight` (a per-element rescaling), not `pos_weight` (the positive-class weight that
# the 'BCE_pos_weight' hyperparameter is meant to control).
train_criterion = BCEWithLogitsLoss(
    pos_weight=torch.tensor([hparams['BCE_pos_weight']], device=device))
test_criterion = BCEWithLogitsLoss()

model = RNN(hparams['in_dim'], hparams['hidden_dim'],
            hparams['out_dim'], train_criterion, test_criterion)

my_trainer = trainer.Trainer(hparams, device)

# Starting training
my_trainer.train(model, train_loader,
              logger=nep_experiment, root_dir=exp_dir)
my_trainer.test(model, test_loader, logger=None)

https://ui.neptune.ai/peterpdai/test/e/TEST-171
Starting a training run...
Model will be saved at  ./exp_vid_rnn/run-starting-at-11-18-06:39:32-TEST-171  after training.
Epoch 		 val_loss 		


KeyboardInterrupt: 