In [14]:
import numpy as np
import torch
import torch.nn as nn
import pickle
from torch.utils.data import Dataset, DataLoader

In [4]:
class SIRSimulatedData(Dataset):
    """Map-style dataset over simulated SIR samples stored in a pickle file.

    The pickle must contain a 2-tuple ``(data, labels)`` of sliceable,
    index-aligned sequences. Samples are split by position into contiguous
    partitions:

    * ``'train'`` -> ``[0, train_end)``
    * ``'dev'``   -> ``[train_end, dev_end)``
    * ``'test'``  -> ``[dev_end, end)``

    Parameters
    ----------
    path : str
        Path to the pickle file. Only load files you trust:
        ``pickle.load`` can execute arbitrary code.
    partition : {'train', 'dev', 'test'}
        Which contiguous slice of the data to expose.
    train_end, dev_end : int
        Split boundaries. The defaults reproduce the 16000 / 18000 split.

    Raises
    ------
    ValueError
        If ``partition`` is not a known name, or ``train_end > dev_end``.
    """

    def __init__(self, path='data/data.pkl', partition='train',
                 train_end=16000, dev_end=18000):
        # Zero-argument super() actually reaches Dataset.__init__; the
        # one-argument `super(SIRSimulatedData)` form builds an unbound proxy.
        super().__init__()

        if train_end > dev_end:
            raise ValueError(
                f"train_end ({train_end}) must not exceed dev_end ({dev_end})")

        splits = {
            'train': slice(None, train_end),
            'dev': slice(train_end, dev_end),
            'test': slice(dev_end, None),
        }
        # Fail fast (before touching the file) instead of leaving the object
        # without `data`/`labels` and erroring later in __len__/__getitem__.
        if partition not in splits:
            raise ValueError(
                f"partition must be one of {sorted(splits)}, got {partition!r}")

        # NOTE(security): pickle.load executes arbitrary code embedded in the
        # file; never point `path` at untrusted data.
        with open(path, 'rb') as f:
            data, labels = pickle.load(f)

        part = splits[partition]
        self.data, self.labels = data[part], labels[part]

    def __len__(self):
        """Number of samples in this partition (length of the label slice)."""
        return len(self.labels)

    def __getitem__(self, index):
        """Return the ``(sample, label)`` pair at ``index`` within the partition."""
        return self.data[index], self.labels[index]

In [12]:
# Seed torch's global RNG: the shuffling sampler draws its seed from it when
# the loader is iterated, so the batch order becomes reproducible on re-run.
SEED = 0
BATCH_SIZE = 10

torch.manual_seed(SEED)
train_data = SIRSimulatedData(partition='train')
train_loader = DataLoader(train_data, batch_size=BATCH_SIZE, shuffle=True)


In [13]:
# Peek at one shuffled training batch to check the tensor layout.
X, y = next(iter(train_loader))
# Network feeds X straight into an nn.RNN built with input_size=3, so each
# batch must be 3-D with 3 features on the last axis.
assert X.ndim == 3 and X.shape[-1] == 3, X.shape
X.shape

torch.Size([10, 366, 3])

In [28]:
class Network(nn.Module):
    """Multi-layer ``nn.RNN`` followed by a per-timestep MLP regression head.

    ``forward`` expects a float tensor laid out as
    ``(batch, seq_len, input_size)`` by default (the layout of the DataLoader
    batches in this notebook, e.g. ``(10, 366, 3)``). It returns one scalar
    per timestep with the same leading layout, i.e. ``(batch, seq_len, 1)``.

    Parameters
    ----------
    num_layers : int
        Number of stacked RNN layers.
    num_hidden : int
        Hidden size of each RNN direction.
    bidirectional : bool
        If True, run the RNN in both directions; the head then receives
        ``2 * num_hidden`` features per timestep.
    input_size : int
        Features per timestep (default 3).
    head_hidden : int
        Width of the hidden layer in the regression head (default 2048).
    dropout : float
        Dropout applied between stacked RNN layers (default 0).
    batch_first : bool
        Input layout. Defaults to True to match batch-first DataLoader
        output; pass False to feed ``(seq_len, batch, input_size)`` tensors.
    """

    def __init__(self, num_layers, num_hidden, bidirectional=False,
                 input_size=3, head_hidden=2048, dropout=0, batch_first=True):
        super().__init__()

        self.num_layers = num_layers
        self.num_hidden = num_hidden
        self.bidirectional_multiplier = 2 if bidirectional else 1

        # batch_first=True: nn.RNN defaults to (seq, batch, feature). With the
        # default, the batch axis of the loader's (batch, seq, feature) tensors
        # would be treated as time and the 366-step axis as the batch.
        self.rnn = nn.RNN(input_size=input_size,
                          hidden_size=self.num_hidden,
                          num_layers=self.num_layers,
                          dropout=dropout,
                          bidirectional=bidirectional,
                          batch_first=batch_first)

        # ReLU between the two Linear layers: without a nonlinearity they
        # compose into a single affine map, so the wide hidden layer would add
        # parameters but no capacity. NOTE: the output Linear now sits at
        # index 2 of `reg` (state_dict keys `reg.2.*`).
        self.reg = nn.Sequential(
            nn.Linear(self.num_hidden * self.bidirectional_multiplier, head_hidden),
            nn.ReLU(),
            nn.Linear(head_hidden, 1),
        )

    def forward(self, X):
        """Return per-timestep predictions shaped like ``X`` with last dim 1."""
        # out: X's leading dims + (num_hidden * num_directions,)
        out, _ = self.rnn(X)
        return self.reg(out)

In [32]:
model = Network(num_layers=4, num_hidden=128, bidirectional=True)
# Squeeze only the trailing size-1 regression dim; a bare .squeeze() would
# also drop the batch dim whenever a batch holds a single sample.
out = model(X.float()).squeeze(-1)

In [33]:
out.size()  # same torch.Size as out.shape

torch.Size([10, 366])