In [15]:
import torch
import torch.nn as nn
import torch.nn.functional as F

import numpy as np
from collections import Counter
import os

In [16]:
# Global batch size. Note: the data-loading call visible in this notebook
# passes its own literal batch size instead of reading this variable.
batch_size = 1
# Prefer the GPU when CUDA is usable; fall back to the CPU otherwise.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

In [17]:
class PreProcessing():
    """Convert a whitespace-tokenised text file into integer id matrices
    and slice those matrices into fixed-length training batches."""

    def get_data_from_file(self, train_file, batch_size, seq_size):
        """Read a corpus and build the vocabulary plus input/target arrays.

        The file is split on whitespace. Every distinct token receives an
        integer id, ordered by descending frequency (the most frequent token
        gets id 0). The id sequence is truncated to the largest multiple of
        ``batch_size * seq_size`` and reshaped to ``batch_size`` rows.

        Parameters
        ----------
        train_file : str
            Path of the UTF-8 text file to read.
        batch_size : int
            Number of rows of the returned matrices. Must be positive.
        seq_size : int
            Sequence length used to decide how many tokens are kept.
            Must be positive.

        Returns
        -------
        int_to_vocab : dict
            Maps id -> token.
        vocab_to_int : dict
            Maps token -> id.
        n_vocab : int
            Number of distinct tokens.
        in_text : numpy.ndarray
            Token ids, shape ``(batch_size, -1)``.
        out_text : numpy.ndarray
            ``in_text`` shifted left by one token (the next-token targets),
            same shape. The very last target wraps around to the first
            input token.

        Raises
        ------
        ValueError
            If ``batch_size`` or ``seq_size`` is not positive, or if the
            corpus holds fewer than ``batch_size * seq_size`` tokens.
        """
        if batch_size <= 0 or seq_size <= 0:
            raise ValueError('batch_size and seq_size must be positive')

        with open(train_file, 'r', encoding='utf-8') as f:
            text = f.read().split()

        word_counts = Counter(text)
        sorted_vocab = sorted(word_counts, key=word_counts.get, reverse=True)
        int_to_vocab = dict(enumerate(sorted_vocab))
        vocab_to_int = {w: k for k, w in int_to_vocab.items()}
        n_vocab = len(int_to_vocab)

        print('Vocabulary size', n_vocab)

        int_text = [vocab_to_int[w] for w in text]
        tokens_per_batch = seq_size * batch_size
        # Integer division: avoids the float round-trip of int(a / b).
        num_batches = len(int_text) // tokens_per_batch
        if num_batches == 0:
            # Without this guard the wrap-around step below would fail with
            # an opaque IndexError on an empty sequence.
            raise ValueError(
                'corpus has %d tokens but at least %d are required '
                '(batch_size * seq_size)' % (len(int_text), tokens_per_batch))

        in_text = np.array(int_text[:num_batches * tokens_per_batch])
        # Targets are the inputs shifted left by one; np.roll also performs
        # the wrap-around of the first token into the last position.
        out_text = np.roll(in_text, -1)
        in_text = np.reshape(in_text, (batch_size, -1))
        out_text = np.reshape(out_text, (batch_size, -1))
        return int_to_vocab, vocab_to_int, n_vocab, in_text, out_text


    def get_batches(self, in_text, out_text, batch_size, seq_size):
        """Yield consecutive ``(inputs, targets)`` column windows.

        ``in_text`` and ``out_text`` are indexed as ``[:, i:i+seq_size]``,
        so they must be (at least) two-dimensional arrays. The number of
        windows is ``in_text.size // (seq_size * batch_size)`` and each
        window is ``seq_size`` columns wide; trailing columns that do not
        fill a whole window are not yielded.
        """
        num_batches = np.prod(in_text.shape) // (seq_size * batch_size)
        for i in range(0, num_batches * seq_size, seq_size):
            yield in_text[:, i:i+seq_size], out_text[:, i:i+seq_size]
            
            
# Build the vocabulary and the input/target id matrices from data.txt,
# using 4 rows (batch size) and a sequence length of 4.
preprocess_obj = PreProcessing()
(int_to_vocab,
 vocab_to_int,
 n_vocab,
 in_text,
 out_text) = preprocess_obj.get_data_from_file(train_file="data.txt",
                                               batch_size=4,
                                               seq_size=4)

Vocabulary size 22


In [18]:
# Display the input id matrix returned by get_data_from_file.
in_text

array([[ 0,  1,  2,  3],
       [ 4,  5,  6,  7],
       [ 8,  9, 10, 11],
       [12,  0, 13, 14]])

In [19]:
# Display the target id matrix (inputs shifted left by one token).
out_text

array([[ 1,  2,  3,  4],
       [ 5,  6,  7,  8],
       [ 9, 10, 11, 12],
       [ 0, 13, 14,  0]])

In [22]:
class RNNModule(nn.Module):
    """Embedding -> single-layer LSTM -> linear projection onto the vocabulary."""

    def __init__(self, n_vocab, seq_size, embedding_size, lstm_size):
        """Create the layers.

        Parameters
        ----------
        n_vocab : int
            Vocabulary size; number of embedding rows and of output logits.
        seq_size : int
            Stored on the instance; not read by any method of this class.
        embedding_size : int
            Width of the embedding vectors fed to the LSTM.
        lstm_size : int
            Hidden size of the LSTM.
        """
        super(RNNModule, self).__init__()
        self.seq_size = seq_size
        self.lstm_size = lstm_size
        self.embedding = nn.Embedding(n_vocab, embedding_size)
        self.lstm = nn.LSTM(embedding_size,
                            lstm_size,
                            batch_first=True)
        self.dense = nn.Linear(lstm_size, n_vocab)

    def forward(self, x, prev_state):
        """Run one forward pass.

        ``x`` holds token ids and is batch-first (the LSTM is built with
        ``batch_first=True``); ``prev_state`` is the ``(h, c)`` pair handed
        to the LSTM. Returns the per-position vocabulary logits and the new
        ``(h, c)`` state.
        """
        embed = self.embedding(x)
        output, state = self.lstm(embed, prev_state)
        logits = self.dense(output)

        return logits, state

    def zero_state(self, batch_size):
        """Return an all-zero ``(h, c)`` pair of shape ``(1, batch_size, lstm_size)``.

        The tensors are allocated from a model parameter so they share the
        model's device and dtype. Plain ``torch.zeros`` would always give
        float32 CPU tensors, which the LSTM rejects once the model has been
        moved to a GPU or cast to another dtype.
        """
        reference = self.dense.weight
        return (reference.new_zeros(1, batch_size, self.lstm_size),
                reference.new_zeros(1, batch_size, self.lstm_size))

    def get_loss_and_train_op(self, net, lr=0.001):
        """Return a ``CrossEntropyLoss`` and an Adam optimizer over
        ``net.parameters()`` with learning rate ``lr``.

        ``net`` is the module whose parameters are optimised; it is passed
        explicitly and is not required to be ``self``.
        """
        criterion = nn.CrossEntropyLoss()
        optimizer = torch.optim.Adam(net.parameters(), lr=lr)

        return criterion, optimizer
# Model hyper-parameters.
seq_size, embedding_size, lstm_size = 4, 22, 64

# Build the network on the selected device, then create its loss function
# and Adam optimizer (learning rate 0.01).
net = RNNModule(n_vocab=n_vocab,
                seq_size=seq_size,
                embedding_size=embedding_size,
                lstm_size=lstm_size).to(device)
criterion, optimizer = net.get_loss_and_train_op(net, lr=0.01)