In [3]:
import os
from pickle import FALSE

import pytorch_lightning as pl
from pytorch_lightning.accelerators import accelerator
import torch
from torch.utils.data import DataLoader, random_split
import torchvision 
from tqdm import tqdm 
from pytorch_lightning.loggers.neptune import NeptuneLogger
from torchvision.transforms.transforms import ToTensor
from pytorch_lightning.callbacks import ModelCheckpoint

from MovingMNIST import MovingMNIST

##################################################################
######################  Hyperparameters ########################
##################################################################

HPARAMS = {
    'device': torch.device("cuda:0" if torch.cuda.is_available() else "cpu"),
    'max_epochs': 10,
    'hidden_dim': 512,
    'batch_size': 128,
    'lr': 0.001,
    'num_gpus': 1
}

In [4]:
##################################################################
######################  Prepare dataset ########################
##################################################################

root = './data'
if not os.path.exists(root):
    os.mkdir(root)

data_transforms = torchvision.transforms.Compose([
    torchvision.transforms.ToTensor(),
    torchvision.transforms.ConvertImageDtype(torch.float)
])

dataset = MovingMNIST(root='.data/mnist', 
                      train=True, 
                      transform=data_transforms,
                      download=True) # 9000
train_set, val_set = random_split(dataset, [8000, 1000])
test_set = MovingMNIST(root='.data/mnist', 
                       train=False, 
                       transform=data_transforms, 
                       download=True) # 1000

# 8000
train_loader = torch.utils.data.DataLoader(
    dataset=train_set,
    batch_size=HPARAMS['batch_size'],
    num_workers = 20,
    shuffle=True)

# 1000
val_loader = torch.utils.data.DataLoader(
    dataset=val_set,
    batch_size=HPARAMS['batch_size'],
    num_workers = 20,
    shuffle=False)

# 1000
test_loader = torch.utils.data.DataLoader(
    dataset=test_set,
    batch_size=HPARAMS['batch_size'],
    num_workers = 20,
    shuffle=False)

In [5]:


##################################################################
######################  Models ########################
##################################################################


# vanilla RNN many-to-many model
class RNN(pl.LightningModule):
    
    def __init__(self, in_dim, hidden_dim, out_dim):
        super().__init__()

        self.h_act = torch.nn.Tanh()

        self.in_dim = in_dim
        self.hidden_dim = hidden_dim
        self.out_dim = out_dim

        # Manual RNN
        self.Wb_xh = torch.nn.Linear(self.in_dim, self.hidden_dim, bias=True)
        self.Wb_hh = torch.nn.Linear(self.hidden_dim, self.hidden_dim, bias=True)
        self.Wb_hy = torch.nn.Linear(self.hidden_dim, self.out_dim, bias=True)
        
        # Report epoch loss 
        self.epoch_train_loss = []
        self.criterion = torch.nn.BCEWithLogitsLoss()


    def forward(self, x):
        # function for inference 
        # x: first ten frames, Tensor of size [batch_sz, 10, 64, 64], dtype=torch.float
        x = torch.transpose(x, 0, 1)/255.
        in_len, batch_sz, h, w = x.shape
        out_len = in_len
        assert(h * w == self.in_dim)
       
        h_ = torch.zeros(batch_sz, self.hidden_dim, device=self.device)
        
        for in_idx in range(in_len-1):
            x_ = x[in_idx].view(batch_sz, -1) # [batch_sz, self.in_dim]
            h_ = self.h_act(self.Wb_xh(x_) + self.Wb_hh(h_)) 
        
        x_ = x[-1]
        Ys = []
        
        for out_idx in range(out_len):
            x_ = x_.view(batch_sz, -1)
            h_ = self.h_act(self.Wb_xh(x_) + self.Wb_hh(h_)) 
            Ys.append(self.Wb_hy(h_).view(batch_sz, h, w)) 
        
#         assert (len(Ys) == out_len)
        return torch.transpose(torch.stack(Ys, 0), 0, 1) # [batch_sz, seq_len, h, w], dtype=torch.float
        
        
    def training_step(self, batch, batch_idx):
#         print(batch_idx)
        # training_step defined the train loop. It is independent of forward
        
        x = torch.transpose(torch.cat(batch, 1), 0, 1)/255 # [20, batch_sz, h, w, dtype=torch.float]
        print("max of x", x.max())
        print("dtype ", x.dtype)
        total_len, batch_sz, h, w = x.shape
        h_ = torch.zeros(batch_sz, self.hidden_dim, device=self.device)
        p_out = []
        
        for idx in range(total_len-1):
            x_ = x[idx].view(batch_sz, h*w) # [batch_sz, self.in_dim]
            h_ = self.h_act(self.Wb_xh(x_) + self.Wb_hh(h_)) 
            p_ = self.Wb_hy(h_).view(batch_sz, h, w)
            p_out.append(p_)
            
        y_logits = torch.stack(p_out, 0)
        gt = x[1:]
        loss = self.criterion(torch.transpose(y_logits, 0, 1), torch.transpose(gt, 0, 1)) # [19, batch_sz, h, w] 
        self.log('train_loss', loss)
        return loss 
    
    
    def validation_step(self, batch, batch_idx):
         # training_step defined the train loop. It is independent of forward
        x = torch.transpose(torch.cat(batch, 1), 0, 1)/255 # [20, batch_sz, h, w, dtype=torch.float]
        total_len, batch_sz, h, w = x.shape
        h_ = torch.zeros(batch_sz, self.hidden_dim, device=self.device)
        p_out = []
        
        for idx in range(total_len-1):
            x_ = x[idx].view(batch_sz, h*w) # [batch_sz, self.in_dim]
            h_ = self.h_act(self.Wb_xh(x_) + self.Wb_hh(h_)) 
            p_ = self.Wb_hy(h_).view(batch_sz, h, w)
            p_out.append(p_)
            
        y_logits = torch.stack(p_out, 0)
        gt = x[1:]
        loss = self.criterion(torch.transpose(y_logits, 0, 1), torch.transpose(gt, 0, 1)) # [19, batch_sz, h, w] 
        self.log('val_loss', loss)
        return loss 
    
    
    def test_step(self, batch, batch_idx):
        x, y = batch # integer values need to tansform to float
        pred_frames = self.forward(x) # [batch_sz, seq_len, h, w] in [0, 1]
        _, _, h, w = y.shape
        assert (h == 64 and w == 64)
        pixel_loss = self.criterion(pred_frames, y/255.)
        frame_loss = pixel_loss * h * w
        self.log('test_bce_loss', frame_loss)
        return frame_loss 
        
    
    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.parameters(), lr=HPARAMS['lr'])
        return optimizer
        
    
        



# ##################################################################
# ###################### Setting Experiment ########################
# ##################################################################


from pytorch_lightning.loggers.neptune import NeptuneLogger

neptune_logger = NeptuneLogger(
    project_name="peterpdai/MovingMNIST-RNN",
    params=HPARAMS)

checkpoint_callback = ModelCheckpoint(
    monitor='val_loss',
    dirpath=os.getcwd(),
    filename='sample-mnist-{epoch:02d}-{val_loss:.2f}',
    save_top_k=3,
    mode='min'
)


model = RNN(64*64, HPARAMS['hidden_dim'], 64*64)




trainer = pl.Trainer(max_epochs=HPARAMS['max_epochs'],
                     gpus=HPARAMS['num_gpus'],
                     accelerator="ddp",
                     logger=neptune_logger,
                     default_root_dir= os.getcwd(),# saves checkpoints to '/your/path/to/save/checkpoints' at every epoch end
                     callbacks=[checkpoint_callback],
                     fast_dev_run=False) # for debugging

# ##################################################################
# ###################### Training ########################
# ##################################################################



# 



https://ui.neptune.ai/peterpdai/MovingMNIST-RNN/e/MOVRNN-57


NeptuneLogger will work in online mode
GPU available: True, used: True
TPU available: False, using: 0 TPU cores
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1]


In [6]:
trainer.fit(model, train_loader, val_loader)

initializing ddp: GLOBAL_RANK: 0, MEMBER: 1/1

  | Name      | Type              | Params
------------------------------------------------
0 | h_act     | Tanh              | 0     
1 | Wb_xh     | Linear            | 2 M   
2 | Wb_hh     | Linear            | 262 K 
3 | Wb_hy     | Linear            | 2 M   
4 | criterion | BCEWithLogitsLoss | 0     


HBox(children=(HTML(value='Validation sanity check'), FloatProgress(value=1.0, bar_style='info', layout=Layout…

HBox(children=(HTML(value='Training'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), max…

max of x tensor(1., device='cuda:0')
dtype  torch.float32
max of x tensor(1., device='cuda:0')
dtype  torch.float32
max of x tensor(1., device='cuda:0')
dtype  torch.float32
max of x tensor(1., device='cuda:0')
dtype  torch.float32
max of x tensor(1., device='cuda:0')
dtype  torch.float32
max of x tensor(1., device='cuda:0')
dtype  torch.float32
max of x tensor(1., device='cuda:0')
dtype  torch.float32
max of x tensor(1., device='cuda:0')
dtype  torch.float32
max of x tensor(1., device='cuda:0')
dtype  torch.float32
max of x tensor(1., device='cuda:0')
dtype  torch.float32
max of x tensor(1., device='cuda:0')
dtype  torch.float32
max of x tensor(1., device='cuda:0')
dtype  torch.float32
max of x tensor(1., device='cuda:0')
dtype  torch.float32
max of x tensor(1., device='cuda:0')
dtype  torch.float32
max of x tensor(1., device='cuda:0')
dtype  torch.float32
max of x tensor(1., device='cuda:0')
dtype  torch.float32
max of x tensor(1., device='cuda:0')
dtype  torch.float32
max of x tenso

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

max of x tensor(1., device='cuda:0')
dtype  torch.float32
max of x tensor(1., device='cuda:0')
dtype  torch.float32
max of x tensor(1., device='cuda:0')
dtype  torch.float32
max of x tensor(1., device='cuda:0')
dtype  torch.float32
max of x tensor(1., device='cuda:0')
dtype  torch.float32
max of x tensor(1., device='cuda:0')
dtype  torch.float32
max of x tensor(1., device='cuda:0')
dtype  torch.float32
max of x tensor(1., device='cuda:0')
dtype  torch.float32
max of x tensor(1., device='cuda:0')
dtype  torch.float32
max of x tensor(1., device='cuda:0')
dtype  torch.float32
max of x tensor(1., device='cuda:0')
dtype  torch.float32
max of x tensor(1., device='cuda:0')
dtype  torch.float32
max of x tensor(1., device='cuda:0')
dtype  torch.float32
max of x tensor(1., device='cuda:0')
dtype  torch.float32
max of x tensor(1., device='cuda:0')
dtype  torch.float32
max of x tensor(1., device='cuda:0')
dtype  torch.float32
max of x tensor(1., device='cuda:0')
dtype  torch.float32
max of x tenso

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

max of x tensor(1., device='cuda:0')
dtype  torch.float32
max of x tensor(1., device='cuda:0')
dtype  torch.float32
max of x tensor(1., device='cuda:0')
dtype  torch.float32
max of x tensor(1., device='cuda:0')
dtype  torch.float32
max of x tensor(1., device='cuda:0')
dtype  torch.float32
max of x tensor(1., device='cuda:0')
dtype  torch.float32
max of x tensor(1., device='cuda:0')
dtype  torch.float32
max of x tensor(1., device='cuda:0')
dtype  torch.float32
max of x tensor(1., device='cuda:0')
dtype  torch.float32
max of x tensor(1., device='cuda:0')
dtype  torch.float32
max of x tensor(1., device='cuda:0')
dtype  torch.float32
max of x tensor(1., device='cuda:0')
dtype  torch.float32
max of x tensor(1., device='cuda:0')
dtype  torch.float32
max of x tensor(1., device='cuda:0')
dtype  torch.float32
max of x tensor(1., device='cuda:0')
dtype  torch.float32
max of x tensor(1., device='cuda:0')
dtype  torch.float32
max of x tensor(1., device='cuda:0')
dtype  torch.float32
max of x tenso

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

max of x tensor(1., device='cuda:0')
dtype  torch.float32
max of x tensor(1., device='cuda:0')
dtype  torch.float32
max of x tensor(1., device='cuda:0')
dtype  torch.float32
max of x tensor(1., device='cuda:0')
dtype  torch.float32
max of x tensor(1., device='cuda:0')
dtype  torch.float32
max of x tensor(1., device='cuda:0')
dtype  torch.float32
max of x tensor(1., device='cuda:0')
dtype  torch.float32
max of x tensor(1., device='cuda:0')
dtype  torch.float32
max of x tensor(1., device='cuda:0')
dtype  torch.float32
max of x tensor(1., device='cuda:0')
dtype  torch.float32
max of x tensor(1., device='cuda:0')
dtype  torch.float32
max of x tensor(1., device='cuda:0')
dtype  torch.float32
max of x tensor(1., device='cuda:0')
dtype  torch.float32
max of x tensor(1., device='cuda:0')
dtype  torch.float32
max of x tensor(1., device='cuda:0')
dtype  torch.float32
max of x tensor(1., device='cuda:0')
dtype  torch.float32
max of x tensor(1., device='cuda:0')
dtype  torch.float32
max of x tenso

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

max of x tensor(1., device='cuda:0')
dtype  torch.float32
max of x tensor(1., device='cuda:0')
dtype  torch.float32
max of x tensor(1., device='cuda:0')
dtype  torch.float32
max of x tensor(1., device='cuda:0')
dtype  torch.float32
max of x tensor(1., device='cuda:0')
dtype  torch.float32
max of x tensor(1., device='cuda:0')
dtype  torch.float32
max of x tensor(1., device='cuda:0')
dtype  torch.float32
max of x tensor(1., device='cuda:0')
dtype  torch.float32
max of x tensor(1., device='cuda:0')
dtype  torch.float32
max of x tensor(1., device='cuda:0')
dtype  torch.float32
max of x tensor(1., device='cuda:0')
dtype  torch.float32
max of x tensor(1., device='cuda:0')
dtype  torch.float32
max of x tensor(1., device='cuda:0')
dtype  torch.float32
max of x tensor(1., device='cuda:0')
dtype  torch.float32
max of x tensor(1., device='cuda:0')
dtype  torch.float32
max of x tensor(1., device='cuda:0')
dtype  torch.float32
max of x tensor(1., device='cuda:0')
dtype  torch.float32
max of x tenso

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

max of x tensor(1., device='cuda:0')
dtype  torch.float32
max of x tensor(1., device='cuda:0')
dtype  torch.float32
max of x tensor(1., device='cuda:0')
dtype  torch.float32
max of x tensor(1., device='cuda:0')
dtype  torch.float32
max of x tensor(1., device='cuda:0')
dtype  torch.float32
max of x tensor(1., device='cuda:0')
dtype  torch.float32
max of x tensor(1., device='cuda:0')
dtype  torch.float32
max of x tensor(1., device='cuda:0')
dtype  torch.float32
max of x tensor(1., device='cuda:0')
dtype  torch.float32
max of x tensor(1., device='cuda:0')
dtype  torch.float32
max of x tensor(1., device='cuda:0')
dtype  torch.float32
max of x tensor(1., device='cuda:0')
dtype  torch.float32
max of x tensor(1., device='cuda:0')
dtype  torch.float32
max of x tensor(1., device='cuda:0')
dtype  torch.float32
max of x tensor(1., device='cuda:0')
dtype  torch.float32
max of x tensor(1., device='cuda:0')
dtype  torch.float32
max of x tensor(1., device='cuda:0')
dtype  torch.float32
max of x tenso

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

max of x tensor(1., device='cuda:0')
dtype  torch.float32
max of x tensor(1., device='cuda:0')
dtype  torch.float32
max of x tensor(1., device='cuda:0')
dtype  torch.float32
max of x tensor(1., device='cuda:0')
dtype  torch.float32
max of x tensor(1., device='cuda:0')
dtype  torch.float32
max of x tensor(1., device='cuda:0')
dtype  torch.float32
max of x tensor(1., device='cuda:0')
dtype  torch.float32
max of x tensor(1., device='cuda:0')
dtype  torch.float32
max of x tensor(1., device='cuda:0')
dtype  torch.float32
max of x tensor(1., device='cuda:0')
dtype  torch.float32
max of x tensor(1., device='cuda:0')
dtype  torch.float32
max of x tensor(1., device='cuda:0')
dtype  torch.float32
max of x tensor(1., device='cuda:0')
dtype  torch.float32
max of x tensor(1., device='cuda:0')
dtype  torch.float32
max of x tensor(1., device='cuda:0')
dtype  torch.float32
max of x tensor(1., device='cuda:0')
dtype  torch.float32
max of x tensor(1., device='cuda:0')
dtype  torch.float32
max of x tenso

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

max of x tensor(1., device='cuda:0')
dtype  torch.float32
max of x tensor(1., device='cuda:0')
dtype  torch.float32
max of x tensor(1., device='cuda:0')
dtype  torch.float32
max of x tensor(1., device='cuda:0')
dtype  torch.float32
max of x tensor(1., device='cuda:0')
dtype  torch.float32
max of x tensor(1., device='cuda:0')
dtype  torch.float32
max of x tensor(1., device='cuda:0')
dtype  torch.float32
max of x tensor(1., device='cuda:0')
dtype  torch.float32
max of x tensor(1., device='cuda:0')
dtype  torch.float32
max of x tensor(1., device='cuda:0')
dtype  torch.float32
max of x tensor(1., device='cuda:0')
dtype  torch.float32
max of x tensor(1., device='cuda:0')
dtype  torch.float32
max of x tensor(1., device='cuda:0')
dtype  torch.float32
max of x tensor(1., device='cuda:0')
dtype  torch.float32
max of x tensor(1., device='cuda:0')
dtype  torch.float32
max of x tensor(1., device='cuda:0')
dtype  torch.float32
max of x tensor(1., device='cuda:0')
dtype  torch.float32
max of x tenso

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

max of x tensor(1., device='cuda:0')
dtype  torch.float32
max of x tensor(1., device='cuda:0')
dtype  torch.float32
max of x tensor(1., device='cuda:0')
dtype  torch.float32
max of x tensor(1., device='cuda:0')
dtype  torch.float32
max of x tensor(1., device='cuda:0')
dtype  torch.float32
max of x tensor(1., device='cuda:0')
dtype  torch.float32
max of x tensor(1., device='cuda:0')
dtype  torch.float32
max of x tensor(1., device='cuda:0')
dtype  torch.float32
max of x tensor(1., device='cuda:0')
dtype  torch.float32
max of x tensor(1., device='cuda:0')
dtype  torch.float32
max of x tensor(1., device='cuda:0')
dtype  torch.float32
max of x tensor(1., device='cuda:0')
dtype  torch.float32
max of x tensor(1., device='cuda:0')
dtype  torch.float32
max of x tensor(1., device='cuda:0')
dtype  torch.float32
max of x tensor(1., device='cuda:0')
dtype  torch.float32
max of x tensor(1., device='cuda:0')
dtype  torch.float32
max of x tensor(1., device='cuda:0')
dtype  torch.float32
max of x tenso

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

max of x tensor(1., device='cuda:0')
dtype  torch.float32
max of x tensor(1., device='cuda:0')
dtype  torch.float32
max of x tensor(1., device='cuda:0')
dtype  torch.float32
max of x tensor(1., device='cuda:0')
dtype  torch.float32
max of x tensor(1., device='cuda:0')
dtype  torch.float32
max of x tensor(1., device='cuda:0')
dtype  torch.float32
max of x tensor(1., device='cuda:0')
dtype  torch.float32
max of x tensor(1., device='cuda:0')
dtype  torch.float32
max of x tensor(1., device='cuda:0')
dtype  torch.float32
max of x tensor(1., device='cuda:0')
dtype  torch.float32
max of x tensor(1., device='cuda:0')
dtype  torch.float32
max of x tensor(1., device='cuda:0')
dtype  torch.float32
max of x tensor(1., device='cuda:0')
dtype  torch.float32
max of x tensor(1., device='cuda:0')
dtype  torch.float32
max of x tensor(1., device='cuda:0')
dtype  torch.float32
max of x tensor(1., device='cuda:0')
dtype  torch.float32
max of x tensor(1., device='cuda:0')
dtype  torch.float32
max of x tenso

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…




1

In [None]:
checkpoint_callback.best_model_path

In [None]:
trainer.test(test_dataloaders=test_loader)