In [24]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader
from dataset import TrajectoryDataset

import numpy as np
import matplotlib.pyplot as plt

### 1. Getting the Dataset and Dataloader

In [25]:
dataset = TrajectoryDataset(
    data_dir = "../dataset/",
    states_filename = "states.npy",
    actions_filename = "actions.npy",
    s_transform = None,
    a_transform = None,
)

# TODO: create two dataset for train and test

dataloader = DataLoader(dataset, batch_size=1, shuffle=True)

first_datapoint = next(iter(dataloader))
state, action = first_datapoint

print(f"Number of data_points {len(dataloader)}")
print(f"Shape of state: {state.shape}")
print(f"Shape of action: {action.shape}")

Number of data_points 15000
Shape of state: torch.Size([1, 17, 2, 65, 65])
Shape of action: torch.Size([1, 16, 2])


### 2. Defining the Model

In [26]:
class Predictor(nn.Module):
    def __init__(self, input_size, hidden_size) -> None:
        super().__init__()
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.lstm_cell = nn.LSTMCell(input_size, hidden_size)

        self.h = None
        self.c = None

    def set_h(self, h):
        self.h = h
    
    def reset_hc(self, h, c):
        self.h = self.h.zero_() 
        self.c = self.c.zero_()
    """
    Forward Method
    Input: 
        action tensor (a) of shape: (B, input_size)
    Output:   
        predicted repr. tensor (s_yhat) of shape: (B, hidden_size)
    """
    def forward(self, action):
        self.h, self.c = self.lstm_cell(action, (self.h, self.c))
        return self.h

### 3. Training 

In [27]:
# NOTE: assume no batch processing.
def train(pred, enc, dataloader, criterion, optimizer, device):
    # keeping encoder in eval mode
    enc.eval()
    for batch in dataloader:
        a, s = batch
        a, s = a.to(device), s.to(device)

        ## initial observation
        pred.set_h(enc(s[0]))

        L = a.shape[0]
        # TODO: check initialization
        loss = 0

        for i in range(L):
            action = a[i]
            sy_hat = pred(action=action)
            sy = enc(s[i+1])

            loss += criterion(sy_hat, sy)
        
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        ## clearing the hidden state and cell state
        pred.reset_hc()

In [30]:
input_size = 2
hidden_size = 2048

encoder = None
predictor = Predictor(input_size=input_size, hidden_size=hidden_size)
criterion = nn.MSELoss()
optimizer = optim.Adam(predictor.parameters(), lr=0.001)
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

In [31]:
train(predictor, encoder, dataloader, criterion, optimizer, device)

AttributeError: 'NoneType' object has no attribute 'eval'