In [2]:
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
import torch
from torch import nn
from torch.utils.data import Dataset, DataLoader

! pip install -q langdetect
from langdetect import detect
from tqdm import tqdm
tqdm.pandas()

! pip install -q nltk
import nltk
from nltk import word_tokenize, sent_tokenize
from nltk.probability import FreqDist

nltk.download('punkt')

from itertools import chain

[?25l     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/981.5 kB[0m [31m?[0m eta [36m-:--:--[0m[2K     [91m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m[91m╸[0m [32m972.8/981.5 kB[0m [31m31.3 MB/s[0m eta [36m0:00:01[0m[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m981.5/981.5 kB[0m [31m20.5 MB/s[0m eta [36m0:00:00[0m
[?25h  Preparing metadata (setup.py) ... [?25l[?25hdone
  Building wheel for langdetect (setup.py) ... [?25l[?25hdone


[nltk_data] Downloading package punkt to /root/nltk_data...
[nltk_data]   Unzipping tokenizers/punkt.zip.


In [3]:
! pip install -q kaggle
! kaggle datasets download 'smagnan/1-million-reddit-comments-from-40-subreddits'
! unzip -q 1-million-reddit-comments-from-40-subreddits.zip

Dataset URL: https://www.kaggle.com/datasets/smagnan/1-million-reddit-comments-from-40-subreddits
License(s): CC0-1.0
Downloading 1-million-reddit-comments-from-40-subreddits.zip to /content
 91% 65.0M/71.2M [00:01<00:00, 59.9MB/s]
100% 71.2M/71.2M [00:01<00:00, 50.0MB/s]


In [4]:
config = {
    'comments_loaded': 30000,
    'sent_length': 20,
    'truncation_length': 10,
    'padding_token': 0,
    'unknown_token': 1,
    'embedding_dim': 1024,
    'hidden_dim': 1024,
    'vocab_size': 30002,
    'batch_size': 256,
    'learning_rate': 1e-4,
    'num_epochs': 30,
}

In [5]:
df = pd.read_csv('kaggle_RC_2019-05.csv')
df = df[:config['comments_loaded']]
df.drop(columns=['subreddit', 'controversiality', 'score'], inplace=True)

def isEnglish_text(text: str) -> bool:
    try:
        return detect(text) == 'en'
    except:
        return False

df['IsEnglish'] = df['body'].progress_apply(isEnglish_text)
print(df)

100%|██████████| 30000/30000 [02:45<00:00, 181.44it/s]

                                                    body  IsEnglish
0      Your submission has been automatically removed...       True
1      Dont squeeze her with you massive hand, you me...       True
2      It's pretty well known and it was a paid produ...       True
3      You know we have laws against that currently c...       True
4      Yes, there is a difference between gentle supp...       True
...                                                  ...        ...
29995  Is this april fools? Did I get too high? Am I ...       True
29996         Alister Black's disembodied head is spooky      False
29997  I used to eat vienna sausages straight from th...       True
29998  My teachers don’t care about me using that as ...       True
29999  Don’t start with this “kids movie for kids” rh...       True

[30000 rows x 2 columns]





In [6]:
df = df[df['IsEnglish']]
df.drop(columns=['IsEnglish'], inplace=True)

A value is trying to be set on a copy of a slice from a DataFrame

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  df.drop(columns=['IsEnglish'], inplace=True)


In [7]:
df['body'] = df['body'].progress_apply(lambda s: s.lower())
df['body'] = df['body'].progress_apply(sent_tokenize)

df = df.explode(column='body')

df['body'] = df['body'].progress_apply(word_tokenize)
df['body'] = df['body'].progress_apply(lambda ws: ['<s>'] + ws + ['</s>'])
print(df)

100%|██████████| 28853/28853 [00:00<00:00, 641693.46it/s]
A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  df['body'] = df['body'].progress_apply(lambda s: s.lower())
100%|██████████| 28853/28853 [00:02<00:00, 11410.79it/s]
100%|██████████| 73377/73377 [00:15<00:00, 4603.75it/s]
100%|██████████| 73377/73377 [00:00<00:00, 214489.83it/s]

                                                    body
0      [<s>, your, submission, has, been, automatical...
0      [<s>, please, review, the, options, posted, in...
0      [<s>, *, i, am, a, bot, ,, and, this, action, ...
0      [<s>, please, [, contact, the, moderators, of,...
0                                         [<s>, *, </s>]
...                                                  ...
29998    [<s>, i, ’, m, stronger, then, this, boi, </s>]
29999  [<s>, don, ’, t, start, with, this, “, kids, m...
29999  [<s>, if, that, were, the, case, ,, this, woul...
29999                              [<s>, movie, ., </s>]
29999  [<s>, this, is, a, movie, that, is, banking, o...

[73377 rows x 1 columns]





In [8]:
word_freq = nltk.FreqDist(chain(*df['body'].to_list()))
vocab = word_freq.most_common(config['vocab_size'] - 2)

word_to_token_map = {word: i + 2 for i, (word, _) in enumerate(vocab)}
token_to_word_map = {t: w for w, t in word_to_token_map.items()}

word_to_token_map['<pad>'] = 0
word_to_token_map['<unk>'] = 1
token_to_word_map[0] = '<pad>'
token_to_word_map[1] = '<unk>'

print(word_to_token_map)
print(token_to_word_map)



In [9]:
def map_words_to_tokens(unk_token):
    def apply(sent_tok: list[str]) -> list[int]:
        return [word_to_token_map.get(w, unk_token) for w in sent_tok]

    return apply

df['body'] = df['body'].progress_apply(map_words_to_tokens(config['unknown_token']))

def pad_truncate_sent(pad_token, sent_length):
    def apply(sent_tok: list[int]) -> list[int]:
        len_ = len(sent_tok)
        if len_ > sent_length:
            return sent_tok[:sent_length]
        else:
            return sent_tok + [pad_token] * (sent_length - len_)

    return apply

df['body'] = df['body'].progress_apply(pad_truncate_sent(config['padding_token'], config['sent_length']))
df.reset_index(drop=True, inplace=True)
print(df)

100%|██████████| 73377/73377 [00:00<00:00, 102225.89it/s]
100%|██████████| 73377/73377 [00:00<00:00, 667712.63it/s]

                                                    body
0      [2, 39, 290, 79, 90, 180, 198, 71, 53, 119, 45...
1      [2, 102, 1058, 5, 1155, 560, 17, 5, 57, 475, 9...
2      [2, 11, 9, 139, 8, 183, 6, 10, 18, 186, 22, 23...
3      [2, 102, 57, 197, 5, 185, 14, 18, 162, 58, 25,...
4      [2, 11, 3, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,...
...                                                  ...
73372  [2, 9, 19, 170, 2607, 101, 18, 4878, 3, 0, 0, ...
73373  [2, 168, 19, 63, 369, 32, 18, 144, 346, 310, 2...
73374  [2, 34, 15, 105, 5, 445, 6, 18, 54, 30, 603, 4...
73375  [2, 310, 4, 3, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0...
73376  [2, 18, 16, 8, 310, 15, 16, 11551, 35, 4947, 6...

[73377 rows x 1 columns]





In [10]:
class CustomDataset(Dataset):
    def __init__(self, df: pd.DataFrame):
        self.df = df

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        return torch.tensor(self.df['body'][idx][:-1]), \
               torch.tensor(self.df['body'][idx][1:])

dataset = CustomDataset(df)
train_data, test_data = torch.utils.data.random_split(dataset, [0.8, 0.2])
train_loader = DataLoader(train_data, batch_size=config['batch_size'], shuffle=True, drop_last=True)
test_loader = DataLoader(test_data, batch_size=config['batch_size'], shuffle=True, drop_last=True)

In [11]:
class RNN(nn.Module):
    def __init__(self, vocab_size, embedding_dim, hidden_dim):
        super(RNN, self).__init__()

        self.embedding_layer = nn.Embedding(vocab_size, embedding_dim)
        self.rnn = nn.RNN(embedding_dim, hidden_dim, batch_first=True)
        self.fc = nn.Linear(hidden_dim, vocab_size)

    def forward(self, x, h):
        x = self.embedding_layer(x)
        x, h = self.rnn(x, h)
        x = self.fc(x)
        return x, h

In [12]:
device = 'cuda' if torch.cuda.is_available() else 'cpu'

model = RNN(config['vocab_size'], config['embedding_dim'], config['hidden_dim']).to(device)

criterion = nn.CrossEntropyLoss()
adam = torch.optim.Adam(model.parameters(), lr=config['learning_rate'])
lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(adam, config['num_epochs'])

In [13]:
best_error = float('inf')

for epoch in range(config['num_epochs']):
    model.train()

    losses = []
    for batch_data, batch_target in tqdm(train_loader):
        batch_data = batch_data.to(device)
        batch_target = batch_target.to(device)

        hidden = torch.zeros(1, config['batch_size'], config['hidden_dim']).to(device)

        adam.zero_grad()

        total_loss = 0.
        for i in range(0, config['sent_length'], config['truncation_length']):
            data_chunk = batch_data[:, i:i + config['truncation_length']]
            target_chunk = batch_target[:, i:i + config['truncation_length']]

            output, hidden = model(data_chunk, hidden)
            loss = criterion(output.view(-1, config['vocab_size']), target_chunk.reshape(-1))
            loss.backward()
            total_loss += loss.item()

            hidden = hidden.detach()

        adam.step()
        losses.append(total_loss)

    lr_scheduler.step()
    print(f'Train error on epoch {epoch}: {np.mean(losses)}')

    model.eval()
    with torch.no_grad():
        losses = []
        for batch_data, batch_target in test_loader:
            batch_data = batch_data.to(device)
            batch_target = batch_target.to(device)

            hidden = torch.zeros(1, config['batch_size'], config['hidden_dim']).to(device)

            total_loss = 0.
            for i in range(0, config['sent_length'], config['truncation_length']):
                data_chunk = batch_data[:, i:i + config['truncation_length']]
                target_chunk = batch_target[:, i:i + config['truncation_length']]

                output, hidden = model(data_chunk, hidden)
                loss = criterion(output.view(-1, config['vocab_size']), target_chunk.reshape(-1))
                total_loss += loss.item()

            losses.append(total_loss)

        test_error = np.mean(losses)
        print(f'Test error on epoch {epoch}: {test_error}')

        if test_error < best_error:
            best_error = test_error
            torch.save(model.state_dict(), 'best_model.pt')

100%|██████████| 229/229 [01:06<00:00,  3.43it/s]


Train error on epoch 0: 9.299296881954742
Test error on epoch 0: 7.709931315037242


100%|██████████| 229/229 [01:11<00:00,  3.21it/s]


Train error on epoch 1: 7.35284509929507
Test error on epoch 1: 7.209664599937305


100%|██████████| 229/229 [01:11<00:00,  3.21it/s]


Train error on epoch 2: 6.935429655828851
Test error on epoch 2: 6.967574496018259


100%|██████████| 229/229 [01:11<00:00,  3.22it/s]


Train error on epoch 3: 6.674080877845464
Test error on epoch 3: 6.81246234659563


100%|██████████| 229/229 [01:11<00:00,  3.21it/s]


Train error on epoch 4: 6.476601041040046
Test error on epoch 4: 6.711249600376999


100%|██████████| 229/229 [01:11<00:00,  3.20it/s]


Train error on epoch 5: 6.315819585687729
Test error on epoch 5: 6.630589577189663


100%|██████████| 229/229 [01:11<00:00,  3.21it/s]


Train error on epoch 6: 6.176538642837491
Test error on epoch 6: 6.562792240527639


100%|██████████| 229/229 [01:11<00:00,  3.20it/s]


Train error on epoch 7: 6.053296071993732
Test error on epoch 7: 6.519871711730957


100%|██████████| 229/229 [01:11<00:00,  3.19it/s]


Train error on epoch 8: 5.945725630985077
Test error on epoch 8: 6.477865606023554


100%|██████████| 229/229 [01:11<00:00,  3.19it/s]


Train error on epoch 9: 5.847547202131113
Test error on epoch 9: 6.452772029659204


100%|██████████| 229/229 [01:12<00:00,  3.18it/s]


Train error on epoch 10: 5.7585893848577445
Test error on epoch 10: 6.427966442024498


100%|██████████| 229/229 [01:12<00:00,  3.17it/s]


Train error on epoch 11: 5.678497735069309
Test error on epoch 11: 6.406990885734558


100%|██████████| 229/229 [01:11<00:00,  3.18it/s]


Train error on epoch 12: 5.606426906898032
Test error on epoch 12: 6.394401391347249


100%|██████████| 229/229 [01:12<00:00,  3.18it/s]


Train error on epoch 13: 5.540453440237253
Test error on epoch 13: 6.378845490907368


100%|██████████| 229/229 [01:11<00:00,  3.20it/s]


Train error on epoch 14: 5.480452646334619
Test error on epoch 14: 6.370638602658322


100%|██████████| 229/229 [01:12<00:00,  3.18it/s]


Train error on epoch 15: 5.426909870455879
Test error on epoch 15: 6.364448787873251


100%|██████████| 229/229 [01:12<00:00,  3.18it/s]


Train error on epoch 16: 5.380628707627542
Test error on epoch 16: 6.35625313875968


100%|██████████| 229/229 [01:12<00:00,  3.16it/s]


Train error on epoch 17: 5.337591216033203
Test error on epoch 17: 6.351148561427467


100%|██████████| 229/229 [01:12<00:00,  3.18it/s]


Train error on epoch 18: 5.300375671365896
Test error on epoch 18: 6.347860003772535


100%|██████████| 229/229 [01:12<00:00,  3.17it/s]


Train error on epoch 19: 5.267566264977101
Test error on epoch 19: 6.347182838540328


100%|██████████| 229/229 [01:12<00:00,  3.17it/s]


Train error on epoch 20: 5.239722834924423
Test error on epoch 20: 6.340796234314902


100%|██████████| 229/229 [01:12<00:00,  3.16it/s]


Train error on epoch 21: 5.216145491495924
Test error on epoch 21: 6.341577500627752


100%|██████████| 229/229 [01:11<00:00,  3.19it/s]


Train error on epoch 22: 5.196062451366775
Test error on epoch 22: 6.337120993095532


100%|██████████| 229/229 [01:12<00:00,  3.16it/s]


Train error on epoch 23: 5.181204782823288
Test error on epoch 23: 6.341091827342384


100%|██████████| 229/229 [01:11<00:00,  3.18it/s]


Train error on epoch 24: 5.168183191894965
Test error on epoch 24: 6.338316133147792


100%|██████████| 229/229 [01:11<00:00,  3.19it/s]


Train error on epoch 25: 5.158739235203339
Test error on epoch 25: 6.3390007876513295


100%|██████████| 229/229 [01:11<00:00,  3.19it/s]


Train error on epoch 26: 5.15244205237476
Test error on epoch 26: 6.338727836023297


100%|██████████| 229/229 [01:12<00:00,  3.18it/s]


Train error on epoch 27: 5.14706193567884
Test error on epoch 27: 6.338678468737686


100%|██████████| 229/229 [01:11<00:00,  3.18it/s]


Train error on epoch 28: 5.145146853018015
Test error on epoch 28: 6.3395274145561356


100%|██████████| 229/229 [01:12<00:00,  3.18it/s]


Train error on epoch 29: 5.144154113453028
Test error on epoch 29: 6.339496520527622


In [14]:
checkpoint = torch.load('best_model.pt', weights_only=True)
model.load_state_dict(checkpoint)

prompt = ['<s>', 'there']
tok_prompt = [word_to_token_map[w] for w in prompt]

model.eval()
with torch.no_grad():
    hidden = torch.zeros(1, 1, config['hidden_dim']).to(device)

    for _ in range(config['sent_length']):
        data_chunk = torch.tensor(tok_prompt).unsqueeze(0).to(device)
        output, hidden = model(data_chunk, hidden)

        pred = output[0, -1].argmax().item()
        tok_prompt.append(pred)
        prompt.append(token_to_word_map[pred])

        if token_to_word_map[pred] == '</s>':
            break

print(' '.join(prompt))

<s> there 's a lot of people who are n't going to be the same thing . </s>
