In [1]:
import numpy as np
import pandas as pd
import torch
import torch.nn
import torch.nn.utils.rnn
import torch.utils.data
import matplotlib.pyplot as plt
import seaborn as sns
import opencc
import os

In [2]:
# Vocabulary lookup tables. Ids follow the position of each token in the
# sequence below: the two special tokens first, then the arithmetic symbols
# and digits, and finally the unknown-character token.
char_to_id = {
    token: index
    for index, token in enumerate(('<pad>', '<eos>', *'()*+-0123456789=', '<unk>'))
}
# Reverse mapping, derived from char_to_id so the two tables cannot drift apart.
id_to_char = {index: token for token, index in char_to_id.items()}

In [3]:
class CharRNN(torch.nn.Module):
    """Character-level sequence model: embedding -> two LSTM layers -> MLP head.

    The class reads the module-level ``char_to_id`` and ``id_to_char`` tables
    for the padding id, the end-of-sequence id, and for encoding/decoding text.

    Args:
        vocab_size: Number of token ids; sizes the embedding table and the
            output logits.
        embed_dim: Width of each token embedding.
        hidden_dim: Hidden size shared by both LSTM layers and the MLP head.
    """

    def __init__(self, vocab_size, embed_dim, hidden_dim):
        super(CharRNN, self).__init__()

        # Attribute names (and the Sequential layout) define the state_dict
        # keys, so they must stay as they are for saved checkpoints to load.
        self.embedding = torch.nn.Embedding(num_embeddings=vocab_size,
                                            embedding_dim=embed_dim,
                                            padding_idx=char_to_id['<pad>'])

        self.rnn_layer1 = torch.nn.LSTM(input_size=embed_dim,
                                        hidden_size=hidden_dim,
                                        batch_first=True)

        self.rnn_layer2 = torch.nn.LSTM(input_size=hidden_dim,
                                        hidden_size=hidden_dim,
                                        batch_first=True)

        self.linear = torch.nn.Sequential(torch.nn.Linear(in_features=hidden_dim,
                                                          out_features=hidden_dim),
                                          torch.nn.ReLU(),
                                          torch.nn.Linear(in_features=hidden_dim,
                                                          out_features=vocab_size))

    def forward(self, batch_x, batch_x_lens):
        """Return per-position logits for a padded batch; see ``encoder``."""
        return self.encoder(batch_x, batch_x_lens)

    def encoder(self, batch_x, batch_x_lens):
        """Run the forward pass of the model.

        Args:
            batch_x: Batch-first tensor of token ids, one padded sequence
                per row.
            batch_x_lens: True (unpadded) length of each row; consumed by
                ``pack_padded_sequence``.

        Returns:
            Batch-first tensor of logits over the vocabulary for every
            position, padded back to the longest sequence in the batch.
        """
        batch_x = self.embedding(batch_x)

        # Pack so the LSTMs skip padded positions; enforce_sorted=False lets
        # the batch arrive in any length order.
        batch_x = torch.nn.utils.rnn.pack_padded_sequence(batch_x,
                                                          batch_x_lens,
                                                          batch_first=True,
                                                          enforce_sorted=False)

        batch_x, _ = self.rnn_layer1(batch_x)
        batch_x, _ = self.rnn_layer2(batch_x)

        batch_x, _ = torch.nn.utils.rnn.pad_packed_sequence(batch_x,
                                                            batch_first=True)

        batch_x = self.linear(batch_x)

        return batch_x

    def generator(self, start_char, max_len=200):
        """Greedily extend a prompt one character at a time.

        Args:
            start_char: Prompt string (or iterable of characters). Every
                character must be a key of ``char_to_id``.
            max_len: Generation stops once the sequence, prompt included,
                reaches this many tokens.

        Returns:
            List of characters: the prompt followed by the generated
            characters, ending with ``'<eos>'`` when the model emitted it
            before ``max_len`` was reached.

        Raises:
            KeyError: If a prompt character is not in ``char_to_id``.

        Note:
            An empty prompt is not supported when ``max_len > 0``: the
            packing step in ``encoder`` is expected to reject a zero-length
            sequence.
        """
        char_list = [char_to_id[c] for c in start_char]
        eos_id = char_to_id['<eos>']
        device = self.embedding.weight.device

        # Inference only: disable autograd so no graph is built per step.
        with torch.no_grad():
            while len(char_list) < max_len:
                # The whole sequence so far is re-encoded as a batch of one.
                input_tensor = torch.tensor([char_list], dtype=torch.long, device=device)
                # Lengths are created without a device, i.e. on the CPU.
                input_lens = torch.tensor([len(char_list)], dtype=torch.long)
                logits = self.encoder(input_tensor, input_lens)

                # Greedy decoding: highest-scoring id at the last position.
                next_char = torch.argmax(logits[0, -1, :], dim=-1).item()
                char_list.append(next_char)

                if next_char == eos_id:
                    break

        return [id_to_char[ch_id] for ch_id in char_list]

In [7]:
# Optimisation settings (not referenced by the cells visible here).
batch_size, epochs = 64, 2
lr, grad_clip = 0.001, 1

# Model dimensions: the embedding width and the LSTM hidden size are equal.
embed_dim = hidden_dim = 256

# NOTE(review): char_to_id assigns 19 ids (0-18) but the embedding table is
# built with 18 rows, so id 18 ('<unk>') is out of range for the model. The
# value is kept because the saved checkpoint loaded successfully with it.
vocab_size = 18

In [11]:
# Rebuild the architecture, then restore the trained parameters from disk.
model = CharRNN(vocab_size,
                embed_dim,
                hidden_dim)
# Map the checkpoint onto the CPU rather than a hard-coded 'mps' device: the
# model is never moved off the CPU in this cell, so the restored parameters
# end up there either way, and the load no longer requires the Apple Metal
# backend.
# NOTE(review): torch.load unpickles the file; only load checkpoints from a
# trusted source (consider weights_only=True if the installed torch has it).
model.load_state_dict(torch.load('model_weights.pth', map_location=torch.device('cpu')))

<All keys matched successfully>

In [None]:
# Greedy-decode a continuation of the prompt; the returned list starts with
# the prompt characters themselves.
model.generator(start_char='1+1=')

['1', '+', '1', '=', '2', '<eos>']