In [None]:
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from tqdm import tqdm

%matplotlib inline
import matplotlib.pyplot as plt

In [None]:
class WeatherRNN(nn.Module):
    """
    Multi-layer LSTM network followed by dropout and a linear read-out layer.

    Both the LSTM input size and the linear output size are ``len(tokens)``.
    Presumably each time step is a one-hot encoding of one token, and the output
    is a vector of unnormalised scores over the same token set.
    NOTE(review): confirm the encoding against the data-preparation code.
    """

    def __init__(self, tokens, n_hidden=256, n_layers=2, drop_prob=0.5):
        """
        Build the network.

        Parameters
        ----------
        tokens : sized collection
            Vocabulary. Only ``len(tokens)`` is used to size the layers; the
            object itself is kept on ``self.tokens`` (e.g. for checkpointing).
        n_hidden : int
            Number of features in each LSTM hidden state.
        n_layers : int
            Number of stacked LSTM layers.
        drop_prob : float
            Dropout probability applied between stacked LSTM layers (only when
            ``n_layers > 1``) and to the LSTM output before the linear layer.
        """
        super().__init__()
        self.drop_prob = drop_prob
        self.n_layers = n_layers
        self.n_hidden = n_hidden
        self.tokens = tokens

        n_tokens = len(tokens)
        # nn.LSTM applies its `dropout` only *between* stacked layers. With a
        # single layer it has no effect, and PyTorch emits a UserWarning, so pass
        # 0 in that case. The output dropout below still uses drop_prob.
        lstm_dropout = drop_prob if n_layers > 1 else 0.0

        # Submodules are registered in the same order as before, so state_dict
        # keys and parameter order stay unchanged.
        self.dropout = nn.Dropout(drop_prob)
        self.lstm = nn.LSTM(n_tokens, n_hidden, n_layers,
                            dropout=lstm_dropout, batch_first=True)
        self.fc = nn.Linear(n_hidden, n_tokens)

    def forward(self, x, hidden=None):
        """
        Run the network over a batch of sequences.

        Parameters
        ----------
        x : torch.Tensor
            Expected shape (batch, seq_len, len(tokens)). The LSTM is built with
            ``batch_first=True`` and input size ``len(tokens)``.
        hidden : tuple of torch.Tensor, optional
            Initial LSTM state ``(h_0, c_0)``, each of shape
            (n_layers, batch, n_hidden). ``None`` (the default) lets ``nn.LSTM``
            start from zeros.

        Returns
        -------
        out : torch.Tensor
            Shape (batch, seq_len, len(tokens)). These are raw linear outputs;
            no softmax is applied.
        hidden : tuple of torch.Tensor
            Final ``(h_n, c_n)`` LSTM state, to pass into the next call.
        """
        lstm_out, hidden = self.lstm(x, hidden)
        out = self.fc(self.dropout(lstm_out))
        return out, hidden

In [None]:
def save_checkpoint(net, opt, filename, train_history=None):
    """
    Save a trained model, its optimizer state and training metadata to a file.

    Parameters
    ----------
    net : nn.Module
        Model exposing ``n_hidden``, ``n_layers``, ``drop_prob``, ``tokens`` and
        ``state_dict()`` (e.g. ``WeatherRNN``).
    opt : torch.optim.Optimizer
        Optimizer whose ``state_dict()`` is stored alongside the model.
    filename : str or path-like
        Destination file; it is opened in binary write mode and overwritten.
    train_history : dict, optional
        Arbitrary training metadata (e.g. loss curves). Defaults to an empty
        dict.

    Notes
    -----
    The checkpoint is written with ``torch.save`` (pickle-based). Only load
    checkpoint files from trusted sources.
    """
    # None sentinel: a literal `{}` default would be one dict shared by every
    # call and stored into every checkpoint.
    if train_history is None:
        train_history = {}

    checkpoint = {'n_hidden': net.n_hidden,
                  'n_layers': net.n_layers,
                  # Saved so the model can be rebuilt with the same dropout setting.
                  'drop_prob': net.drop_prob,
                  'state_dict': net.state_dict(),
                  'optimizer': opt.state_dict(),
                  'tokens': net.tokens,
                  'train_history': train_history}

    with open(filename, 'wb') as f:
        torch.save(checkpoint, f)