In [None]:
import torch
from torch import nn, optim
from torch.utils.data import Dataset, DataLoader, TensorDataset
from torch.nn.utils.rnn import pad_sequence
from sklearn.model_selection import train_test_split
from nltk.corpus import reuters
import nltk
from tqdm import tqdm
nltk.download('reuters')
nltk.download('punkt')

[nltk_data] Downloading package reuters to /root/nltk_data...
[nltk_data]   Package reuters is already up-to-date!
[nltk_data] Downloading package punkt to /root/nltk_data...
[nltk_data]   Package punkt is already up-to-date!


True

In [None]:
import torch
from torch import nn, optim
from torch.utils.data import Dataset, DataLoader, TensorDataset
from torch.nn.utils.rnn import pad_sequence
from sklearn.model_selection import train_test_split
from nltk.corpus import reuters
import nltk
from tqdm import tqdm
nltk.download('reuters')
nltk.download('punkt')

# Preparing the data
# Every corpus token is lower-cased and each sentence is wrapped in the
# "<BOS>" / "<EOS>" boundary markers (the markers themselves keep their case).
sentences = reuters.sents()
sentences = [["<BOS>"] + [word.lower() for word in sentence] + ["<EOS>"] for sentence in sentences]

# Prepare the vocabulary
# Ids 0 and 1 are reserved for padding and unknown words; every other token
# gets the next free id in order of first appearance.
PAD_TOKEN = "_PAD"
UNK_TOKEN = "_UNK"
word2idx = {PAD_TOKEN: 0, UNK_TOKEN: 1}
for sentence in sentences:
    for word in sentence:
        if word not in word2idx:
            word2idx[word] = len(word2idx)
vocab = word2idx
vocab_size = len(vocab)

# Reverse mapping, used to turn predicted ids back into words.
idx2word = {idx: word for word, idx in vocab.items()}

# Prepare the datasets
# Encode each sentence once, then derive the next-word pairs by slicing:
# inputs drop the final token, targets drop the first one, so targets[i][t]
# is the token that follows data[i][t].
unk_id = vocab[UNK_TOKEN]
encoded_sentences = [[vocab.get(word, unk_id) for word in sentence] for sentence in sentences]
data = [token_ids[:-1] for token_ids in encoded_sentences]
targets = [token_ids[1:] for token_ids in encoded_sentences]

# Fixed seed so that a kernel restart reproduces the same train/test split.
SPLIT_SEED = 42
data_train, data_test, targets_train, targets_test = train_test_split(
    data, targets, test_size=0.2, random_state=SPLIT_SEED
)

# Define Dataset
class SentenceDataset(Dataset):
    """Pairs of token-id sequences exposed as a map-style dataset.

    Args:
        sentences: indexable collection of input sequences (lists of int ids).
        targets: indexable collection of target sequences; item ``i`` is
            returned together with ``sentences[i]``.
    """

    def __init__(self, sentences, targets):
        self.sentences = sentences
        self.targets = targets

    def __len__(self):
        """Return the number of input sequences."""
        return len(self.sentences)

    def __getitem__(self, idx):
        """Return ``(input_ids, target_ids)`` for item ``idx`` as int64 tensors."""
        # The dtype is pinned explicitly: torch.tensor() infers it from the
        # data, and an empty list would otherwise come back as float32.
        inputs = torch.tensor(self.sentences[idx], dtype=torch.long)
        labels = torch.tensor(self.targets[idx], dtype=torch.long)
        return inputs, labels

# Prepare data loaders
def pad_collate(batch):
    """Collate ``(inputs, targets)`` pairs into two padded batch tensors.

    Both sequence groups are padded (batch dimension first) with the id of
    the notebook-level ``PAD_TOKEN``.

    Args:
        batch: iterable of ``(input_tensor, target_tensor)`` pairs.

    Returns:
        Tuple ``(padded_inputs, padded_targets)``.
    """
    input_seqs, target_seqs = zip(*batch)
    padded_inputs = pad_sequence(
        input_seqs, batch_first=True, padding_value=vocab[PAD_TOKEN]
    )
    padded_targets = pad_sequence(
        target_seqs, batch_first=True, padding_value=vocab[PAD_TOKEN]
    )
    return padded_inputs, padded_targets

train_data = SentenceDataset(data_train, targets_train)
test_data = SentenceDataset(data_test, targets_test)
batch_size = 24
# Reshuffle the training sentences at the start of every epoch so the model
# does not see the same batches in the same order each time. The test loader
# keeps a fixed order.
train_loader = DataLoader(train_data, batch_size=batch_size, shuffle=True, collate_fn=pad_collate)
test_loader = DataLoader(test_data, batch_size=batch_size, collate_fn=pad_collate)

# Define the model
class RNNModel(nn.Module):
    """Embedding -> single ``nn.RNN`` -> linear projection onto the vocabulary.

    Args:
        vocab_size: number of embedding rows and of output scores per position.
        embedding_dim: size of each token embedding (the RNN input size).
        hidden_size: size of the RNN hidden state.

    The embedding's padding index is read from the notebook-level ``vocab``.
    """

    def __init__(self, vocab_size, embedding_dim, hidden_size):
        super().__init__()
        pad_id = vocab[PAD_TOKEN]
        self.embedding = nn.Embedding(vocab_size, embedding_dim, padding_idx=pad_id)
        self.rnn = nn.RNN(embedding_dim, hidden_size, batch_first=True)
        self.fc = nn.Linear(hidden_size, vocab_size)

    def forward(self, x):
        """Return vocabulary scores for every position of ``x``.

        ``x`` holds token ids with the batch dimension first, as produced by
        the loaders and the inference cells in this notebook.
        """
        embedded = self.embedding(x)
        # Only the per-step outputs are needed; the final hidden state is dropped.
        step_outputs, _ = self.rnn(embedded)
        return self.fc(step_outputs)

# Hyperparameters
embedding_dim = 64
hidden_size = 64
num_epochs = 5  # For simplicity, we just use 5 epochs

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

model = RNNModel(vocab_size, embedding_dim, hidden_size).to(device)

# Padded target positions are excluded from the loss.
criterion = nn.CrossEntropyLoss(ignore_index=vocab[PAD_TOKEN])
optimizer = optim.Adam(model.parameters(), lr=0.01)

# Training
for epoch in range(num_epochs):
    model.train()
    running_loss = 0.0
    # The loop variables are deliberately NOT called `data` / `targets`:
    # those names hold the encoded corpus lists built earlier in the notebook
    # and must not be overwritten by a batch tensor.
    for batch_inputs, batch_targets in tqdm(train_loader):
        batch_inputs, batch_targets = batch_inputs.to(device), batch_targets.to(device)
        optimizer.zero_grad()
        output = model(batch_inputs)
        # Flatten the batch and time dimensions so each position becomes one
        # classification example for the loss.
        output = output.view(-1, vocab_size)
        flat_targets = batch_targets.view(-1)
        loss = criterion(output, flat_targets)
        loss.backward()
        optimizer.step()
        running_loss += loss.item()
    # Report the mean of the per-batch losses rather than the loss of whichever
    # batch happened to come last, which is a noisy indicator of progress.
    epoch_loss = running_loss / max(len(train_loader), 1)
    print(f"Epoch: {epoch+1}, Loss: {epoch_loss}")

# Save the model weights
torch.save(model.state_dict(), 'model_weights.pth')


[nltk_data] Downloading package reuters to /root/nltk_data...
[nltk_data]   Package reuters is already up-to-date!
[nltk_data] Downloading package punkt to /root/nltk_data...
[nltk_data]   Package punkt is already up-to-date!
100%|██████████| 1824/1824 [00:42<00:00, 43.36it/s]


Epoch: 1, Loss: 4.56048059463501


100%|██████████| 1824/1824 [00:41<00:00, 43.63it/s]


Epoch: 2, Loss: 4.333864688873291


100%|██████████| 1824/1824 [00:41<00:00, 43.46it/s]


Epoch: 3, Loss: 4.291331768035889


100%|██████████| 1824/1824 [00:42<00:00, 43.12it/s]


Epoch: 4, Loss: 4.3650803565979


100%|██████████| 1824/1824 [00:41<00:00, 43.71it/s]

Epoch: 5, Loss: 4.353855133056641





In [None]:
# Predict the single most likely word to follow a short prompt.
model.eval()
sentence = "<BOS> You".split()
# The vocabulary was built from lower-cased text, while the special markers
# such as "<BOS>" were stored verbatim. Try the exact token first, then its
# lower-cased form, and only then fall back to the unknown-word id.
token_ids = [vocab.get(word, vocab.get(word.lower(), vocab[UNK_TOKEN])) for word in sentence]
x = torch.tensor([token_ids], device=device)
with torch.no_grad():  # inference only: no autograd graph is needed
    output = model(x)
prediction = output[0, -1].argmax(dim=0).item()  # Get the predicted next word for the last word in the sentence
predicted_word = idx2word.get(prediction, UNK_TOKEN)
print(f"Predicted next word: {predicted_word}")

Predicted next word: ,


In [None]:
# Greedy generation: repeatedly append the most likely next word to the prompt.
model.eval()
length = 30  # maximum number of words to generate
sentence = "<BOS> What was".split()
with torch.no_grad():  # inference only: no autograd graph is needed
    for _ in range(length):  # Generate at most `length` words
        # The vocabulary was built from lower-cased text, while the special
        # markers such as "<BOS>" were stored verbatim. Try the exact token
        # first, then its lower-cased form, then the unknown-word id.
        token_ids = [vocab.get(word, vocab.get(word.lower(), vocab[UNK_TOKEN])) for word in sentence]
        x = torch.tensor([token_ids], device=device)
        output = model(x)
        prediction = output[0, -1].argmax(dim=0).item()  # Get the predicted next word for the last word in the sentence
        predicted_word = idx2word.get(prediction, UNK_TOKEN)
        if predicted_word == "<EOS>":
            # "<EOS>" is never appended, so the input would stay the same and
            # greedy decoding would predict it again on every remaining step.
            break
        sentence.append(predicted_word)

print(f"Generated sentence: {' '.join(sentence)}")


Generated sentence: <BOS> What was expected to be a little comment on the agreement , which is expected to be taken .
