In [2]:
import torch.nn as nn
import torch

In [14]:
class lc_CharModel(nn.Module):
    '''Character-level LSTM model that scores the next character.

    Input
    n_vocab: int - number of distinct characters (size of the output layer)
    hidden_size: int - LSTM hidden units; also the input size of the output layer (default 256)
    num_layers: int - number of stacked LSTM layers (default 2)
    dropout: float - dropout probability applied before the output layer (default 0.2)
    Output
    nn model

    forward(x) expects a float tensor of shape (batch, seq_len, 1) - one scalar
    feature per time step (input_size=1, batch_first=True); presumably the
    normalised character index, TODO confirm against the training code.
    It returns unnormalised scores of shape (batch, n_vocab) computed from the
    last time step only (no softmax is applied).
    '''
    def __init__(self, n_vocab, hidden_size=256, num_layers=2, dropout=0.2):
        super().__init__()
        self.lstm = nn.LSTM(input_size=1, hidden_size=hidden_size, num_layers=num_layers, batch_first=True)
        self.dropout = nn.Dropout(dropout)
        # The output layer must consume exactly the LSTM's hidden size; sharing
        # one parameter keeps the two from drifting apart.
        self.linear = nn.Linear(hidden_size, n_vocab)

    def forward(self, x):
        x, _ = self.lstm(x)
        # take only the last time step's output: (batch, hidden_size)
        x = x[:, -1, :]
        # project to one score per character: (batch, n_vocab)
        x = self.linear(self.dropout(x))
        return x

In [4]:
class lc_CharEmbModel(nn.Module):
    '''Character-level LSTM model that embeds input character indices.

    Still experimental (originally marked "not ok yet").

    Input
    n_vocab: int - number of distinct characters (embedding rows and output size)
    embedding_dim: int - dimension of each embedding vector (default 100)
    hidden_size: int - LSTM hidden units (default 256)
    num_layers: int - number of stacked LSTM layers (default 1)
    dropout: float - dropout probability applied before the output layer (default 0.2)
    Output
    nn model

    forward(x) accepts character indices shaped either (batch, seq_len) or
    (batch, seq_len, 1) and returns unnormalised scores of shape
    (batch, n_vocab) computed from the last time step only.
    '''
    def __init__(self, n_vocab, embedding_dim = 100, hidden_size=256, num_layers=1, dropout=0.2):
        super().__init__()
        self.embedding = nn.Embedding(n_vocab, embedding_dim)
        self.lstm = nn.LSTM(input_size=embedding_dim, hidden_size=hidden_size, num_layers=num_layers, batch_first=True)
        self.dropout = nn.Dropout(dropout)
        # output layer input must match the LSTM hidden size
        self.linear = nn.Linear(hidden_size, n_vocab)

    def forward(self, x):
        # nn.Embedding needs integer indices. NOTE(review): .long() truncates
        # floats, so normalised inputs (values in [0, 1)) would all map to
        # index 0 - feed raw integer character indices here.
        x = x.long()
        # Drop a trailing feature axis if present; plain (batch, seq_len)
        # input is already in the shape nn.Embedding expects.
        if x.dim() == 3:
            x = x.squeeze(2)
        x = self.embedding(x)
        x, _ = self.lstm(x)
        # take only the last time step's output: (batch, hidden_size)
        x = x[:, -1, :]
        # project to one score per character: (batch, n_vocab)
        x = self.linear(self.dropout(x))
        return x

In [15]:
def lc_load_model(model, model_path, map_location='cpu'):
    '''Load saved weights into `model` and return it with the char_to_int dict.

    Inputs
    model: obj - nn model with the same architecture as the saved one
    model_path: string - path to a file written as torch.save((state_dict, char_to_int), path)
    map_location: device/str/None - forwarded to torch.load. Defaults to 'cpu' so that
        checkpoints saved from a GPU also load on a CPU-only machine;
        load_state_dict copies the values into the model's existing parameters,
        so the model's own device is unaffected.
    Output
    model obj (the same object, updated in place), character to integer dict

    Raises
    ValueError if the file does not hold a (state_dict, char_to_int) pair.

    The model's train/eval mode is left unchanged; call model.eval() yourself
    if dropout should be disabled (e.g. for generation).
    '''
    # NOTE(review): torch.load unpickles the file, so only load trusted checkpoints
    # (recent torch versions support weights_only=True for safer loading).
    checkpoint = torch.load(model_path, map_location=map_location)
    # Validate explicitly: unpacking a bare state_dict with exactly two entries
    # would otherwise "succeed" and yield two key strings.
    if not (isinstance(checkpoint, (tuple, list)) and len(checkpoint) == 2):
        raise ValueError(
            f'{model_path!r} does not contain a (state_dict, char_to_int) pair; '
            f'got {type(checkpoint).__name__}'
        )
    model_par, char_to_int = checkpoint
    model.load_state_dict(model_par)
    return model, char_to_int

if __name__ == '__main__':
    import os

    # Vocabulary size the checkpoint was trained with; it sets the model's
    # output layer and must equal len(char_to_int) of the saved mapping.
    N_VOCAB = 47
    # Set LC_MODEL_PATH to point at a checkpoint instead of editing a
    # machine-specific absolute path; the original path is kept as fallback.
    MODEL_PATH = os.environ.get(
        'LC_MODEL_PATH',
        '/Users/danielboda/Text_generation/lc_CharModel_1700225116.pth',
    )

    model = lc_CharModel(N_VOCAB)
    model, char_to_int = lc_load_model(model, model_path = MODEL_PATH)
    # Output indices are only meaningful if the saved vocabulary matches the
    # output layer size.
    assert len(char_to_int) == N_VOCAB, (
        f'checkpoint vocabulary has {len(char_to_int)} characters, expected {N_VOCAB}'
    )
    print(model, char_to_int, sep = '\n')


lc_CharModel(
  (lstm): LSTM(1, 256, num_layers=2, batch_first=True)
  (dropout): Dropout(p=0.2, inplace=False)
  (linear): Linear(in_features=256, out_features=47, bias=True)
)
{'\n': 0, ' ': 1, '!': 2, '(': 3, ')': 4, '*': 5, ',': 6, '-': 7, '.': 8, ':': 9, ';': 10, '?': 11, '[': 12, ']': 13, '_': 14, 'a': 15, 'b': 16, 'c': 17, 'd': 18, 'e': 19, 'f': 20, 'g': 21, 'h': 22, 'i': 23, 'j': 24, 'k': 25, 'l': 26, 'm': 27, 'n': 28, 'o': 29, 'p': 30, 'q': 31, 'r': 32, 's': 33, 't': 34, 'u': 35, 'v': 36, 'w': 37, 'x': 38, 'y': 39, 'z': 40, 'ù': 41, '—': 42, '‘': 43, '’': 44, '“': 45, '”': 46}
