In [1]:
import numpy as np
import torch
from src.utilities import vocabulary, CONTEXT_LEN

In [151]:
def get_words(file_path: str='data/names.txt') -> list:
    """Function that reads the raw data and outputs list of words"""
    return open(file_path, 'r').read().splitlines()

def word2vec(word: str, vocabulary: list=vocabulary) -> list:
    """Function that transforms passed word into a vector of indicies using input vocabulary"""
    return [vocabulary.index(let) for let in word]

def sample_train(words: list):
    """Function beaks down a random word from passed list into train and target samples"""
    word = word2vec(words[np.random.choice(range(len(words)))] + '.')
    X = []; y = []
    for n, ch in enumerate(word[:-1]):
        X.append(ch); y.append(word[n+1])
    return torch.tensor(X), torch.tensor(y).float()

In [155]:
class RNN(torch.nn.Module):
    def __init__(self, vocab_len: int, embedding_dim: int, hidden_size: int):
        super().__init__()
        self.hidden_size = hidden_size
        self.emb = torch.nn.Embedding(num_embeddings=vocab_len, embedding_dim=embedding_dim)
        self.lstm = torch.nn.LSTMCell(embedding_dim, self.hidden_size)
        self.lin = torch.nn.Linear(hidden_size, vocab_len)

    def forward(self, char: torch.Tensor, hidden_state: torch.Tensor, cell_state: torch.Tensor):
        """Applies all the network layers to the passed character encoded as a number"""
        embedding = self.emb(char)
        hidden_state, cell_state = self.lstm(embedding, (hidden_state, cell_state))
        output = self.lin(hidden_state)
        return output, hidden_state, cell_state
    
    def init_zero_state(self) -> tuple[torch.Tensor, torch.Tensor]:
        """Initiates dummy hidden and cell states for an lstm cell"""
        zero_hidden_state = torch.zeros(self.hidden_size)
        zero_cell_state = torch.zeros(self.hidden_size)
        return zero_hidden_state, zero_cell_state

In [156]:
rnn = RNN(vocab_len=len(vocabulary),
          embedding_dim=20,
          hidden_size=128)

In [278]:
words = get_words()
rnn_input, target = sample_train(words)

def train_epoch (rnn_input, target, rnn, optimiser):
    loss = 0
    hidden_state, cell_state = rnn.init_zero_state()
    for char, tar in zip(rnn_input, target):
        output, hidden_state, cell_state = rnn.forward(char=char,
                                                       hidden_state=hidden_state,
                                                       cell_state=cell_state)
        loss += torch.nn.functional.cross_entropy(output, tar.long())
    
    epoch_loss = loss / target.shape[0] 
    optimiser.zero_grad()
    epoch_loss.backward()
    optimiser.step()
    return rnn, epoch_loss

In [279]:
opt = torch.optim.Adam(rnn.parameters())
train_epoch(rnn_input, target, rnn, opt)

(RNN(
   (emb): Embedding(27, 20)
   (lstm): LSTMCell(20, 128)
   (lin): Linear(in_features=128, out_features=27, bias=True)
 ),
 tensor(2.6938, grad_fn=<DivBackward0>))

In [178]:
opt

Adam (
Parameter Group 0
    amsgrad: False
    betas: (0.9, 0.999)
    capturable: False
    differentiable: False
    eps: 1e-08
    foreach: None
    fused: False
    lr: 0.001
    maximize: False
    weight_decay: 0
)

In [None]:
# def build_train(words: list):
#     """Function beaks down every word from passed list into train and target samples"""
#     X = []; y=[]
#     for word in words:
#         context = '.' * CONTEXT_LEN
#         for ch in word + '.':
#             X.append(word2vec(context)); y.append(word2vec(ch))
#             context = context[1:] + ch
#     return torch.tensor(X), torch.tensor(y).float()