In [52]:
import torch
import torch.nn as nn
import torch.optim as optim
from gensim.models import KeyedVectors
import re
from tqdm import tqdm
import numpy as np
from collections import Counter
import random
import torch.nn.functional as F
import os
import json
import math

In [9]:
# Download the pretrained fastText English vectors (wiki-news-300d-1M) whose rows seed the embedding layer.
# NOTE(review): re-running this cell re-downloads the ~650 MB archive every time; consider skipping it when the file exists.
!curl 'https://dl.fbaipublicfiles.com/fasttext/vectors-english/wiki-news-300d-1M.vec.zip' --output 'wiki-news-300d-1M.vec.zip'

  % Total    % Received % Xferd  Average Speed   Time    Time     Time  Current
                                 Dload  Upload   Total   Spent    Left  Speed
100  650M  100  650M    0     0   251M      0  0:00:02  0:00:02 --:--:--  251M


In [10]:
# Extract wiki-news-300d-1M.vec from the downloaded archive.
# NOTE(review): if the .vec file already exists, unzip may prompt before overwriting; `unzip -o` / `unzip -n` avoid that on re-runs.
!unzip 'wiki-news-300d-1M.vec.zip'

Archive:  wiki-news-300d-1M.vec.zip
  inflating: wiki-news-300d-1M.vec   


In [11]:
# Load the text-format (binary=False) vectors into a gensim KeyedVectors lookup used by build_embedding_matrix.
# NOTE(review): this reads every vector in the file; gensim's `limit=` argument can cap it if load time/memory is a concern.
fasttext = KeyedVectors.load_word2vec_format('wiki-news-300d-1M.vec', binary=False)

In [12]:
# Spot-check two lookups: a show-specific word and a punctuation character
# (tokenize() emits punctuation as standalone tokens, so those lookups matter too).
word_vec, punct_vec = fasttext['spongebob'], fasttext['!']

print("Vector for 'spongebob':", word_vec[:5])
print("Vector for '!':", punct_vec[:5])

Vector for 'spongebob': [-0.0727 -0.0882 -0.2449 -0.0302 -0.1174]
Vector for '!': [-0.1894 -0.002  -0.0817  0.0334  0.1775]


In [13]:
def load_text(root='/kaggle/input', encoding='utf-8'):
    """Concatenate the contents of every file found (recursively) under ``root``.

    Directories and files are visited in sorted order so the resulting text,
    and therefore the vocabulary built from it, is the same on every run;
    plain ``os.walk`` yields entries in arbitrary filesystem order.

    Args:
        root: Directory to walk. Defaults to the Kaggle input mount.
        encoding: Text encoding used to decode each file.

    Returns:
        One string with all file contents joined back-to-back (no separator).
        A missing or empty ``root`` yields ``''``.

    NOTE(review): because no separator is inserted, the last word of one file
    and the first word of the next can fuse into a single token.
    """
    chunks = []
    for dirname, subdirs, filenames in os.walk(root):
        subdirs.sort()  # in-place sort makes os.walk (topdown) descend in sorted order
        for filename in sorted(filenames):
            with open(os.path.join(dirname, filename), encoding=encoding) as f:
                chunks.append(f.read())
    # A single join avoids re-allocating the growing string on every file.
    return ''.join(chunks)

In [14]:
def tokenize(text):
    """Split text into word, whitespace and punctuation tokens.

    A maximal run of alphabetic characters (``str.isalpha``) forms one token.
    Every other character is a token on its own: a newline stays ``'\\n'``,
    any other whitespace (space, tab, carriage return, ...) is normalised to
    ``' '``, and everything else (punctuation, digits, symbols) is kept as-is.
    """
    result = []
    letters = []  # characters of the word currently being accumulated

    for ch in text:
        if ch.isalpha():
            letters.append(ch)
            continue

        # Any non-letter terminates the word in progress.
        if letters:
            result.append(''.join(letters))
            letters = []

        if ch == '\n':
            result.append('\n')
        elif not ch.strip():
            result.append(' ')
        else:
            result.append(ch)

    if letters:
        result.append(''.join(letters))
    return result

In [15]:
# Smoke-test the tokenizer: words stay whole, while every space, newline and punctuation
# character becomes its own token (so "it's" splits into 'it', "'", 's').
print(tokenize('Hello, this is a test\'n to see if the \ntokenize function works! buuuut we will have to see... so, i hope it works :) but it\'s going to depend on it for sure! Anyway. We will see, I\'ll see, haha!'))

['Hello', ',', ' ', 'this', ' ', 'is', ' ', 'a', ' ', 'test', "'", 'n', ' ', 'to', ' ', 'see', ' ', 'if', ' ', 'the', ' ', '\n', 'tokenize', ' ', 'function', ' ', 'works', '!', ' ', 'buuuut', ' ', 'we', ' ', 'will', ' ', 'have', ' ', 'to', ' ', 'see', '.', '.', '.', ' ', 'so', ',', ' ', 'i', ' ', 'hope', ' ', 'it', ' ', 'works', ' ', ':', ')', ' ', 'but', ' ', 'it', "'", 's', ' ', 'going', ' ', 'to', ' ', 'depend', ' ', 'on', ' ', 'it', ' ', 'for', ' ', 'sure', '!', ' ', 'Anyway', '.', ' ', 'We', ' ', 'will', ' ', 'see', ',', ' ', 'I', "'", 'll', ' ', 'see', ',', ' ', 'haha', '!']


In [16]:
def build_vocab(tokens, max_vocab_size=1024):
    """Assign integer ids to the most frequent tokens.

    Id 0 is reserved for '<UNK>'. The ``max_vocab_size - 1`` most common
    tokens get ids 1, 2, ... in descending frequency (``Counter.most_common``
    orders equal counts by first occurrence). '<UNK>' is inserted last, so it
    is the final key when the dict is serialised.
    """
    frequencies = Counter(tokens)
    vocab = {}
    for rank, (token, _count) in enumerate(frequencies.most_common(max_vocab_size - 1), start=1):
        vocab[token] = rank
    vocab['<UNK>'] = 0
    return vocab

In [17]:
def tokens_to_indices(tokens, vocab):
    """Convert tokens to vocabulary ids; tokens missing from ``vocab`` become 0 (the id build_vocab gives '<UNK>')."""
    unk_id = 0
    lookup = vocab.get
    return [lookup(tok, unk_id) for tok in tokens]

In [18]:
# Read every file under the input directory into one string via load_text()
# and report the corpus size in characters.
print('Loading Spongebob Transcript...')
text = load_text()
print('Finished Loading!')
print(len(text))

Loading Spongebob Transcript...
Finished Loading!
5082524


In [19]:
def clean_ascii(text):
    """Drop every character outside 7-bit ASCII (code point >= 128), keeping the rest in order.

    NOTE(review): typographic characters such as curly quotes or ellipses are
    removed, not transliterated, so e.g. "it’s" becomes "its".
    """
    return text.encode('ascii', errors='ignore').decode('ascii')

In [20]:
# Normalise the corpus (lowercase first, then strip non-ASCII) and split it into tokens.
tokens = tokenize(clean_ascii(text.lower()))
print(len(tokens))

2127632


In [21]:
# Keep the 4095 most frequent tokens plus '<UNK>' at id 0.
vocab = build_vocab(tokens, max_vocab_size=4096)

In [22]:
# Persist the token -> id mapping (presumably for reuse at inference time without re-tokenising the corpus).
with open('vocab.json', 'w') as f:
    json.dump(vocab, f)

In [23]:
# Encode the whole corpus as ids; get_batch() samples its training windows from this global list.
indices = tokens_to_indices(tokens, vocab)
print(indices[100:200])

[519, 1, 6, 1, 4040, 2, 4, 7, 382, 1, 27, 1, 314, 1, 3308, 1, 25, 1, 591, 1, 12, 1, 2133, 1, 2084, 1, 27, 1, 589, 1, 13, 1, 0, 1, 857, 1, 14, 1, 118, 1, 3127, 2, 8, 4, 2804, 1, 143, 5, 1, 7, 1076, 1, 6, 1, 4040, 1, 1681, 1, 12, 1, 1149, 1, 56, 1, 13, 1, 389, 8, 1, 658, 31, 650, 10, 1, 7, 230, 1, 28, 1, 6, 1, 4040, 1, 389, 1, 12, 1, 138, 1, 14, 1, 491, 8, 4, 7, 6, 1, 87, 1, 138]


In [24]:
def coverage(tokens, vocab, ignore=()):
    """Return the fraction of ``tokens`` that are present in ``vocab``.

    Args:
        tokens: Iterable of token strings.
        vocab: Any container supporting ``in`` (e.g. the dict from build_vocab).
        ignore: Tokens excluded from both numerator and denominator. Passing
            e.g. ``(' ', '\\n')`` measures coverage of words/punctuation only;
            whitespace tokens are very frequent and would otherwise inflate
            the ratio. The default ``()`` reproduces the plain ratio.

    Returns:
        A float in ``[0, 1]``.

    Raises:
        ValueError: if no tokens remain to measure (the ratio is undefined).
    """
    ignored = set(ignore)
    total = known = 0
    for t in tokens:
        if t in ignored:
            continue
        total += 1
        if t in vocab:
            known += 1
    if total == 0:
        raise ValueError('coverage() needs at least one token to measure')
    return known / total

In [25]:
# Fraction of corpus tokens (whitespace tokens included) that map to a real id rather than <UNK>.
coverage(tokens, vocab)

0.9779882047271332

In [41]:
class LSTMTextGen(nn.Module):
    """Token-level LSTM language model: embedding -> stacked LSTM -> LayerNorm -> vocab logits.

    The embedding table is initialised from ``embedding_matrix`` (unpacked as
    ``(vocab_size, embedding_dim)``) and left trainable. ``dropout`` is handed
    to ``nn.LSTM``, which applies it between stacked layers.

    NOTE(review): the class is defined again in a later cell (In [69]) with
    different defaults and no dropout; that definition shadows this one, so
    the model constructed afterwards does not use this version.
    """

    def __init__(self, embedding_matrix, hidden_size=1024, num_layers=4, dropout=0.3):
        super().__init__()
        n_tokens, emb_dim = embedding_matrix.shape
        pretrained = torch.tensor(embedding_matrix, dtype=torch.float32)

        # Attribute names and creation order are kept fixed: they determine the
        # state_dict keys and the order in which parameters are initialised.
        self.embedding = nn.Embedding.from_pretrained(pretrained, freeze=False)
        self.lstm = nn.LSTM(emb_dim, hidden_size, num_layers, batch_first=True, dropout=dropout)
        self.layer_norm = nn.LayerNorm(hidden_size)
        self.fc = nn.Linear(hidden_size, n_tokens)

    def forward(self, x, hidden=None):
        """Return ``(logits, hidden)``; ``hidden`` is the LSTM state to pass into the next call."""
        seq_out, next_state = self.lstm(self.embedding(x), hidden)
        return self.fc(self.layer_norm(seq_out)), next_state

In [69]:
class LSTMTextGen(nn.Module):
    """Token-level LSTM language model: embedding -> stacked LSTM -> LayerNorm -> (dropout) -> vocab logits.

    Args:
        embedding_matrix: Array-like unpacked as ``(vocab_size, embedding_dim)``;
            copied into a trainable embedding table.
        hidden_size: LSTM hidden/cell state size.
        num_layers: Number of stacked LSTM layers.
        dropout: Dropout probability applied between stacked LSTM layers and
            to the normalised LSTM output before the projection. The default
            0.0 reproduces the dropout-free model exactly (no extra parameters,
            unchanged state_dict keys).
    """

    def __init__(self, embedding_matrix, hidden_size=2048, num_layers=2, dropout=0.0):
        super().__init__()
        vocab_size, embedding_dim = embedding_matrix.shape

        # torch.tensor always copies, so training never writes back into the caller's array.
        self.embedding = nn.Embedding.from_pretrained(
            torch.tensor(embedding_matrix, dtype=torch.float32),
            freeze=False  # allow training embeddings
        )

        self.lstm = nn.LSTM(
            input_size=embedding_dim,
            hidden_size=hidden_size,
            num_layers=num_layers,
            # nn.LSTM only applies dropout between layers and warns when
            # num_layers == 1 and dropout > 0, so pass 0 in that case.
            dropout=dropout if num_layers > 1 else 0.0,
            batch_first=True
        )

        self.layer_norm = nn.LayerNorm(hidden_size)

        self.fc = nn.Linear(hidden_size, vocab_size)

        # Parameter-free, registered last so state_dict keys/init order are unchanged.
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, hidden=None):
        """Run the model over a batch of token ids.

        Args:
            x: LongTensor of ids, batch-first (``batch_first=True``).
            hidden: Optional LSTM state returned by a previous call.

        Returns:
            ``(logits, hidden)``; logits have a trailing ``vocab_size`` dimension.
        """
        x = self.embedding(x)
        output, hidden = self.lstm(x, hidden)
        output = self.dropout(self.layer_norm(output))
        logits = self.fc(output)
        return logits, hidden


In [70]:
# Prefer the GPU when PyTorch can see one; later cells move tensors to `device`.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'
print(device)

cuda


In [71]:
def build_embedding_matrix(vocab, fasttext, embedding_dim=300, scale=0.6, rng=None):
    """Build an embedding table whose row ``i`` initialises the token with id ``i``.

    Tokens found in ``fasttext`` copy their pretrained vector; every other
    token (e.g. '<UNK>' unless the vectors happen to contain it) gets a random
    normal vector with standard deviation ``scale``.

    Args:
        vocab: Mapping token -> integer id. The table has ``max(id) + 1`` rows,
            which equals ``len(vocab)`` for the contiguous ids from build_vocab.
        fasttext: Lookup supporting ``token in fasttext`` and ``fasttext[token]``
            (e.g. gensim KeyedVectors).
        embedding_dim: Row width; must equal the pretrained vector length.
        scale: Std-dev of the random init for tokens without a pretrained vector.
        rng: Optional ``numpy.random.Generator`` for reproducible random rows.
            ``None`` draws from the global ``np.random`` state as before.

    Returns:
        float32 ndarray of shape ``(num_rows, embedding_dim)``.

    Raises:
        ValueError: if a pretrained vector's shape is not ``(embedding_dim,)``.
    """
    sampler = np.random if rng is None else rng
    num_rows = max(vocab.values(), default=-1) + 1
    matrix = np.zeros((num_rows, embedding_dim), dtype=np.float32)

    for token, idx in vocab.items():
        if token in fasttext:
            vec = np.asarray(fasttext[token], dtype=np.float32)
            if vec.shape != (embedding_dim,):
                raise ValueError(
                    f'pretrained vector for {token!r} has shape {vec.shape}, '
                    f'expected ({embedding_dim},)'
                )
            matrix[idx] = vec
        else:
            # Random init for <UNK> or tokens the pretrained vectors lack.
            matrix[idx] = sampler.normal(scale=scale, size=embedding_dim)

    return matrix

In [57]:
# Row i initialises the embedding of token id i: the pretrained fastText vector when available, random otherwise.
# NOTE(review): this cell's execution count (In [57]) is lower than that of the cell defining
# build_embedding_matrix (In [71]), so the in-memory matrix came from an earlier definition; confirm via Restart & Run All.
embedding_matrix = build_embedding_matrix(vocab, fasttext, embedding_dim=300)

In [72]:
model = LSTMTextGen(embedding_matrix).to(device)

# Number of trainable parameters (the embedding counts too: it is built with freeze=False).
print(sum(p.numel() for p in model.parameters() if p.requires_grad))

62447616


In [59]:
def get_batch(seq_len=32, batch_size=64, data=None):
    """Sample random contiguous windows for next-token prediction.

    Each row is ``seq_len`` ids starting at a random offset; ``targets`` is
    the same window shifted one position right, so ``targets[b, t]`` is the
    id that follows ``inputs[b, t]`` in ``data``.

    Args:
        seq_len: Window length.
        batch_size: Number of windows.
        data: Sequence of token ids to sample from. Defaults to the notebook's
            global ``indices`` (the encoded corpus).

    Returns:
        ``(inputs, targets)``: contiguous LongTensors of shape ``(batch_size, seq_len)``.

    Raises:
        ValueError: if ``data`` holds fewer than ``seq_len + 1`` ids.
    """
    if data is None:
        data = indices  # global encoded corpus built in an earlier cell
    window = seq_len + 1  # one extra id so the shifted target fits
    max_start = len(data) - window
    if max_start < 0:
        raise ValueError(f'need at least {window} ids to sample a window, got {len(data)}')

    # Same `random` calls as before, so seeded runs draw the same windows.
    starts = [random.randint(0, max_start) for _ in range(batch_size)]
    # One tensor construction for the whole batch instead of two per row.
    chunk = torch.tensor([data[s:s + window] for s in starts], dtype=torch.long)
    chunk = chunk.reshape(batch_size, window)  # keeps 2-D shape when batch_size == 0
    # Column slices are non-contiguous views; copy so callers may use .view().
    inputs = chunk[:, :-1].contiguous()
    targets = chunk[:, 1:].contiguous()
    return inputs, targets

In [73]:
def train(model, num_iters=1000, seq_len=64, batch_size=256, lr=1e-3,
          optimizer=None, clip_grad=None, log_every=100):
    """Train ``model`` for ``num_iters`` steps of next-token prediction.

    Batches come from the global ``get_batch`` and are moved to the global
    ``device``. The loss is token-level cross-entropy; the printed value is
    the mean over the steps since the previous log line.

    Args:
        model: Module mapping a batch of ids to ``(logits, hidden)``.
        num_iters: Number of optimisation steps.
        seq_len: Window length passed to ``get_batch``.
        batch_size: Windows per step.
        lr: Learning rate used when a new Adam optimizer is created.
        optimizer: Optimizer to continue with. When ``None`` a fresh Adam is
            built, which discards Adam's moment estimates; pass the returned
            optimizer back in to resume training without that reset.
        clip_grad: If set, clip the global gradient norm to this value before
            each step (a common guard against exploding LSTM gradients).
        log_every: Print the running mean loss every this many steps.

    Returns:
        The optimizer that was used.
    """
    model.train()

    criterion = nn.CrossEntropyLoss()
    if optimizer is None:
        optimizer = optim.Adam(model.parameters(), lr=lr)

    window_losses = []

    for step in range(1, num_iters + 1):
        x, y = get_batch(seq_len, batch_size)
        x, y = x.to(device), y.to(device)

        optimizer.zero_grad()
        logits, _ = model(x)

        # reshape (unlike view) also accepts non-contiguous targets.
        loss = criterion(logits.reshape(-1, logits.size(-1)), y.reshape(-1))
        loss.backward()
        if clip_grad is not None:
            nn.utils.clip_grad_norm_(model.parameters(), clip_grad)
        optimizer.step()

        window_losses.append(loss.item())

        # Also log the final step so a run whose length is not a multiple of
        # log_every still reports its last window.
        if step % log_every == 0 or step == 1 or step == num_iters:
            avg_loss = sum(window_losses) / len(window_losses)
            window_losses = []
            print(f"Step {step}/{num_iters} | Loss: {avg_loss:.4f}")

    return optimizer

In [None]:
# Ten rounds of train() with its default 1000 steps each.
# NOTE(review): each train(model) call here constructs a new Adam optimizer inside train(), so Adam's
# moment estimates restart every round. The recorded output stops partway through the second round (cell shows In [None]).
for _ in range(10):
    train(model)

Step 1/1000 | Loss: 8.3848
Step 100/1000 | Loss: 3.7935
Step 200/1000 | Loss: 3.0937
Step 300/1000 | Loss: 2.7840
Step 400/1000 | Loss: 2.5832
Step 500/1000 | Loss: 2.4700
Step 600/1000 | Loss: 2.3865
Step 700/1000 | Loss: 2.3157
Step 800/1000 | Loss: 2.2538
Step 900/1000 | Loss: 2.2042
Step 1000/1000 | Loss: 2.1564
Step 1/1000 | Loss: 2.1574
Step 100/1000 | Loss: 2.1562
Step 200/1000 | Loss: 2.0782
Step 300/1000 | Loss: 2.0414
Step 400/1000 | Loss: 2.0029


In [35]:
# Inverse mapping id -> token, used by generate() to turn sampled ids back into text.
index_to_token = {idx: token for token, idx in vocab.items()}

In [36]:
# NOTE(review): JSON object keys are always strings, so the integer ids are written as "0", "1", ...;
# code that reloads this file must convert keys back with int() before looking tokens up by id.
with open('index_to_token.json', 'w') as f:
    json.dump(index_to_token, f)

In [38]:
def generate(model, input_sequence, length, temperature=1.0):
    """Continue ``input_sequence`` by sampling ``length`` tokens from ``model``.

    The prompt is lowercased and tokenized like the training corpus and run
    through the model once to build the LSTM state. Tokens are then sampled
    one at a time, each fed back as the next input. '<UNK>' is never sampled.
    Uses the global ``vocab`` and ``index_to_token`` mappings.

    Args:
        model: Module returning ``(logits, hidden)`` for batch-first id input.
        input_sequence: Prompt text; must yield at least one token.
        length: Number of tokens to generate.
        temperature: Softmax temperature; values <= 0 pick the argmax (greedy).

    Returns:
        ``input_sequence`` followed by the concatenated generated tokens.

    Raises:
        ValueError: if the prompt tokenizes to nothing.
    """
    prompt_ids = tokens_to_indices(tokenize(input_sequence.lower()), vocab)
    if not prompt_ids:
        raise ValueError('input_sequence must contain at least one token')

    # Follow the model's own device so generation still works after e.g. model.cpu().
    model_device = next(model.parameters()).device
    unk_id = vocab['<UNK>']
    was_training = model.training
    model.eval()

    pieces = []
    try:
        with torch.no_grad():  # inference only: don't build an autograd graph across steps
            # Feed the whole prompt once; the logits at its last position
            # already predict the first new token, so the last prompt token is
            # not fed a second time.
            prompt = torch.tensor([prompt_ids], dtype=torch.long, device=model_device)
            logits, hidden = model(prompt)
            next_logits = logits[:, -1, :]

            for step in range(length):
                scores = next_logits.clone()
                scores[0, unk_id] = -float('inf')  # make <UNK> impossible to pick
                if temperature > 0:
                    probs = F.softmax(scores / temperature, dim=-1)
                    next_idx = torch.multinomial(probs, num_samples=1).item()
                else:
                    next_idx = scores.argmax(dim=-1).item()
                pieces.append(index_to_token[next_idx])

                if step + 1 < length:  # no forward pass needed after the final sample
                    inp = torch.tensor([[next_idx]], dtype=torch.long, device=model_device)
                    logits, hidden = model(inp, hidden)
                    next_logits = logits[:, -1, :]
    finally:
        model.train(was_training)  # restore the caller's train/eval mode

    return input_sequence + ''.join(pieces)
    

In [44]:
# Sample 250 tokens continuing a short dialogue prompt.
# NOTE(review): the recorded output below is from In [44], executed before the cells that built and trained
# the current model (In [69]-[73]); re-run this cell to see output from the trained model.
print(generate(model, 'Patrick: Hi Spongebob, How are you doing?\nSpongebob: ', 250))

Patrick: Hi Spongebob, How are you doing?
Spongebob: spongebobwalterknocksatollantennaflashmetpersonstoppingthendecidecoloryupgrease6butterflytrulyimaginesecondssureinstantgroundtowardkisshoplovestightduerushes4ahoyfranticallycarolgangboredchargescarfishshapemassivemmraindoocompletelywhateverhistorycoveredhockeyallmissingupsidemayorslugfishcopseriouslypetersonnewestsetstrainsnancystrugglingscratchesmealssetbillysalesmanweekminiebellgoessamplewhitecrashesmiserygreekspatulassisterestablishmentoutprivateworsedisappointedpoplongtibutterflypaintyboardsshakesomebodypaintedmicrowavecakepipeinterruptsroundrexboltsthankturningpunchuponsnapsimpressedspendinglellyacornspeanutmagirlsbeatentestnowhereshutwrittenkelpshakekelpshakegreen2angrilywonderfulfoolsfoolingcardeelssmashmisterfergusonascranebookstreedomelageniussmittyshuddershandyayearstraightunnamedpreciousbooroyalcongratulationsyewalleatenswabsuburbanslitherfuzzydownwarddrumsimpressedpopprizedrumsugheyeposterspringarrivecominpecosperkinsquit

In [40]:
# Save a CPU copy of the weights so the checkpoint loads without a GPU, while the live
# model stays on `device` (moving it with model.cpu() broke later train()/generate()
# calls, which send their inputs to the global `device`).
checkpoint = model.state_dict()
for name, tensor in checkpoint.items():
    checkpoint[name] = tensor.cpu()  # replace values in place, keeping state_dict metadata
# NOTE(review): the "27M" tag does not match the ~62M trainable parameters printed when the
# model was built; the filename is kept unchanged for any consumer that loads it by name.
torch.save(checkpoint, 'spongebob_lstm_27M.pth')