In [None]:
import torch
from torch import nn, optim
from torch.utils.data import Dataset, DataLoader, TensorDataset
from torch.nn.utils.rnn import pad_sequence
from sklearn.model_selection import train_test_split
from nltk.corpus import reuters
import nltk
from tqdm import tqdm
nltk.download('reuters')
nltk.download('punkt')

In [None]:
# Preparing the data
sentences = reuters.sents()
sentences = [["<BOS>"] + [word.lower() for word in sentence] + ["<EOS>"] for sentence in sentences]
# print(sentences[1]) ['<BOS>', 'they', 'told', 'reuter', 'correspondents', 'in', 'asian', 'capitals', 'a', 'u', '.', 's', '.', 'move', 'against', 'japan', 'might', 'boost', 'protectionist', 'sentiment', 'in', 'the', 'u', '.', 's', '.', 'and', 'lead', 'to', 'curbs', 'on', 'american', 'imports', 'of', 'their', 'products', '.', '<EOS>']
# print(sentences[2]) ['<BOS>', 'but', 'some', 'exporters', 'said', 'that', 'while', 'the', 'conflict', 'would', 'hurt', 'them', 'in', 'the', 'long', '-', 'run', ',', 'in', 'the', 'short', '-', 'term', 'tokyo', "'", 's', 'loss', 'might', 'be', 'their', 'gain', '.', '<EOS>']
# print(sentences[3]) ['<BOS>', 'the', 'u', '.', 's', '.', 'has', 'said', 'it', 'will', 'impose', '300', 'mln', 'dlrs', 'of', 'tariffs', 'on', 'imports', 'of', 'japanese', 'electronics', 'goods', 'on', 'april', '17', ',', 'in', 'retaliation', 'for', 'japan', "'", 's', 'alleged', 'failure', 'to', 'stick', 'to', 'a', 'pact', 'not', 'to', 'sell', 'semiconductors', 'on', 'world', 'markets', 'at', 'below', 'cost', '.', '<EOS>']

In [None]:
# Creating the vocabulary based on the whole text
PAD_TOKEN = "_PAD"
UNK_TOKEN = "_UNK"
word2idx = {PAD_TOKEN: 0, UNK_TOKEN: 1}
for sentence in sentences:
    for word in sentence:
        if word not in word2idx:
            word2idx[word] = len(word2idx)
vocab = word2idx
vocab_size = len(vocab)
idx2word = {idx: word for word, idx in vocab.items()}
# count = 0
# for key, value in idx2word.items():
#     print(key, value)
#     count += 1
#     # Break the loop after printing the first 10 elements
#     if count == 20:
#         break


In [None]:
# Prepare the datasets
data = [[vocab.get(word, vocab[UNK_TOKEN]) for word in sentence[:-1]] for sentence in sentences]
targets = [[vocab.get(word, vocab[UNK_TOKEN]) for word in sentence[1:]] for sentence in sentences]

data_train, data_test, targets_train, targets_test = train_test_split(data, targets, test_size=0.2)

In [None]:
# Define Dataset
class SentenceDataset(Dataset):
    def __init__(self, sentences, targets):
        self.sentences = sentences
        self.targets = targets

    def __len__(self):
        return len(self.sentences)

    def __getitem__(self, idx):
        return torch.tensor(self.sentences[idx]), torch.tensor(self.targets[idx])

In [None]:
# Prepare data loaders
def pad_collate(batch):
    (xx, yy) = zip(*batch)
    xx_pad = pad_sequence(xx, batch_first=True, padding_value=vocab[PAD_TOKEN])
    yy_pad = pad_sequence(yy, batch_first=True, padding_value=vocab[PAD_TOKEN])
    return xx_pad, yy_pad

In [None]:
train_data = SentenceDataset(data_train, targets_train)
test_data = SentenceDataset(data_test, targets_test)
batch_size = 24
train_loader = DataLoader(train_data, batch_size=batch_size, collate_fn=pad_collate)
test_loader = DataLoader(test_data, batch_size=batch_size, collate_fn=pad_collate)

In [None]:
# Define the model
class RNNModel(nn.Module):
    def __init__(self, vocab_size, embedding_dim, hidden_size):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, embedding_dim, padding_idx=vocab[PAD_TOKEN])
        self.rnn = nn.RNN(embedding_dim, hidden_size, batch_first=True)
        self.fc = nn.Linear(hidden_size, vocab_size)

    def forward(self, x):
        x = self.embedding(x)
        output, _ = self.rnn(x)
        x = self.fc(output)
        return x

# Hyperparameters
embedding_dim = 64
hidden_size = 64

In [None]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

model = RNNModel(vocab_size, embedding_dim, hidden_size).to(device)

criterion = nn.CrossEntropyLoss(ignore_index=vocab[PAD_TOKEN])
optimizer = optim.Adam(model.parameters(), lr=0.01)

### Train The Model ###
You can skip this part and move for the other cell if you want to use the already trained weights

In [None]:
# Training
for epoch in range(5):  # For simplicity, we just use 5 epochs
    model.train()
    for data, targets in tqdm(train_loader):
        data, targets = data.to(device), targets.to(device)
        optimizer.zero_grad()
        output = model(data)
        output = output.view(-1, vocab_size)
        targets = targets.view(-1)
        loss = criterion(output, targets)
        loss.backward()
        optimizer.step()
    print(f"Epoch: {epoch+1}, Loss: {loss.item()}")
    
# Save the model weights
torch.save(model.state_dict(), 'model_weights.pth')


In [36]:
# Load the model weights
model.load_state_dict(torch.load('model_weights.pth', map_location=torch.device('cpu')))


<All keys matched successfully>

In [37]:
model.eval()
sentence = "<BOS> You".split()
x = torch.tensor([[vocab.get(word, vocab[UNK_TOKEN]) for word in sentence]], device=device)
output = model(x)
prediction = output[0, -1].argmax(dim=0).item()  # Get the predicted next word for the last word in the sentence
predicted_word = idx2word.get(prediction, UNK_TOKEN)
print(f"Predicted next word: {predicted_word}")

Predicted next word: ,


In [48]:
model.eval()
length = 30
sentence = "<BOS> I thought about it a lot and I got the right decision regarding to the".split()
generated_sentence = sentence.copy()  # This will be used to print the generated sentence
for _ in range(length):  # Generate 30 words
    x = torch.tensor([[vocab.get(word, vocab[UNK_TOKEN]) for word in sentence]], device=device)
    output = model(x)
    prediction = output[0, -1].argmax(dim=0).item()  # Get the predicted next word for the last word in the sentence
    predicted_word = idx2word.get(prediction, UNK_TOKEN)
    sentence.append(predicted_word)  # Always add the predicted word to the sentence
    if predicted_word != "<EOS>":
        generated_sentence.append(predicted_word)  # Only add the predicted word to the generated sentence if it's not "<EOS>"

print(f"Generated sentence: {' '.join(generated_sentence)}")


Generated sentence: <BOS> I thought about it a lot and I got the right decision regarding to the u . s . , which has been expected to be reached by the market , which bodes , dashing by the u . s . the u
