# Recurrent text generation

In [2]:
import tqdm
import kagglehub
import pandas as pd
from pathlib import Path

import torch
from torch import nn

# Download (or reuse the local copy of) the pinned arXiv dataset version and
# point `path` at the metadata snapshot file inside it.
dataset_dir = Path(kagglehub.dataset_download("Cornell-University/arxiv/versions/205"))
path = dataset_dir / "arxiv-metadata-oai-snapshot.json"

In [3]:
# Cache a subsample (every 10th record) of the arXiv metadata snapshot so that
# later runs load the small file instead of re-reading the full snapshot.
data_dir = Path("recurent_layer_files")
data_dir.mkdir(exist_ok=True)
data_path = data_dir / "arxiv_small.json"

if not data_path.exists():
    # Write to a temporary file and rename it once complete: an interrupted run
    # must not leave a truncated cache that the `exists()` guard would accept.
    tmp_path = data_path.with_name(data_path.name + ".tmp")

    # Stream the snapshot line by line instead of `readlines()`, which would
    # hold the whole file in memory. The encoding is pinned to UTF-8 so the
    # cache matches what `pd.read_json` reads, regardless of the platform default.
    with open(path, "r", encoding="utf-8") as src, open(
        tmp_path, mode="w", encoding="utf-8"
    ) as dst:
        for i, one_line in enumerate(tqdm.tqdm(src)):
            if i % 10 == 0:
                dst.write(one_line)

    tmp_path.replace(data_path)

data = pd.read_json(data_path, lines=True)

100%|██████████| 2601564/2601564 [00:02<00:00, 1125297.77it/s]


In [4]:
# Sequence boundary markers: a space opens every line, a newline closes it.
BOS, EOS = " ", "\n"


def _title_with_abstract(row) -> str:
    """Join title and abstract with " ; " and keep at most the first 512 characters."""
    joined = row["title"] + " ; " + row["abstract"]
    return joined[:512]


def _add_boundaries(text: str) -> str:
    """Replace inner EOS characters with spaces, then wrap the text in BOS ... EOS."""
    return BOS + text.replace(EOS, " ") + EOS


truncated_texts = data.apply(_title_with_abstract, axis=1)
lines = truncated_texts.apply(_add_boundaries).tolist()

lines[:3]

[' Calculation of prompt diphoton production cross sections at Tevatron and   LHC energies ;   A fully differential calculation in perturbative quantum chromodynamics is presented for the production of massive photon pairs at hadron colliders. All next-to-leading order perturbative contributions from quark-antiquark, gluon-(anti)quark, and gluon-gluon subprocesses are included, as well as all-orders resummation of initial-state gluon radiation valid at next-to-next-to-leading logarithmic accuracy. The region o\n',
 ' Computing genus 2 Hilbert-Siegel modular forms over $\\Q(\\sqrt{5})$ via   the Jacquet-Langlands correspondence ;   In this paper we present an algorithm for computing Hecke eigensystems of Hilbert-Siegel cusp forms over real quadratic fields of narrow class number one. We give some illustrative examples using the quadratic field $\\Q(\\sqrt{5})$. In those examples, we identify Hilbert-Siegel eigenforms that are possible lifts from Hilbert eigenforms. \n',
 ' Molecular Syn

In [5]:
# Vocabulary: every distinct character that occurs in the corpus, in sorted order.
unique_chars = set()
for one_line in lines:
    unique_chars.update(one_line)

TOKENS = sorted(unique_chars)
"".join(TOKENS)

'\n !"#$%&\'()*+,-./0123456789:;<=>?@ABCDEFGHIJKLMNOPQRSTUVWXYZ[\\]^_`abcdefghijklmnopqrstuvwxyz{|}~\x7f\x80\x99â'

In [6]:
# Map each character to its position in the sorted vocabulary.
token_to_id = dict(zip(TOKENS, range(len(TOKENS))))

def to_tensor(
    lines: list[str],
    max_len: int | None = None,
    pad: int = token_to_id[EOS],
    dtype=torch.int64,
):
    """Encode strings as a right-padded matrix of token ids.

    Args:
        lines: strings to encode, one per output row.
        max_len: number of columns in the result. Longer lines are truncated
            to this length. When ``None``, the length of the longest line is
            used (0 if ``lines`` is empty).
        pad: token id used to fill positions past the end of a line
            (defaults to the id of ``EOS``).
        dtype: dtype of the returned tensor.

    Returns:
        Tensor of shape ``[len(lines), max_len]`` holding token ids.

    Raises:
        KeyError: if a line contains a character missing from ``token_to_id``.
    """
    if max_len is None:
        # `default=0` keeps an empty batch from raising in `max()`.
        max_len = max(map(len, lines), default=0)

    lines_ix = torch.full([len(lines), max_len], pad, dtype=dtype)
    for i, line in enumerate(lines):
        line_ix = [token_to_id[x] for x in line[:max_len]]
        if line_ix:  # nothing to copy for an empty (or fully truncated) line
            lines_ix[i, : len(line_ix)] = torch.tensor(line_ix, dtype=dtype)
    return lines_ix


# Sanity check: shorter lines are right-padded up to the longest one.
demo_lines = [" abc\n", " abacaba\n", " abc1234567890\n"]
print(to_tensor(demo_lines))

tensor([[ 1, 66, 67, 68,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0],
        [ 1, 66, 67, 66, 68, 66, 67, 66,  0,  0,  0,  0,  0,  0,  0],
        [ 1, 66, 67, 68, 18, 19, 20, 21, 22, 23, 24, 25, 26, 17,  0]])


## Model form

In [6]:
# Small layer sizes and a two-sentence batch for the layer-by-layer walkthrough.
emb_size, hidden_size = 16, 10
inp = to_tensor(["sentence 1 hi", "sentence 2 wow"])

In [8]:
# Step 1: look up a learned vector of size `emb_size` for every token id.
emb = nn.Embedding(len(TOKENS), emb_size)
embeded = emb(inp)
embeded.size()

torch.Size([2, 14, 16])

In [9]:
# Step 2: run the recurrent layer over the embedded sequence.
# `embeded` has the batch on axis 0 (2 sentences x 14 steps x 16 features), so
# `batch_first=True` is required. With the default layout nn.RNN would treat
# axis 0 as time and unroll across the batch; the output shape would be the
# same, which is why the mistake is easy to miss. This also matches the layout
# used by `RNNLanguageModel`.
rnn = nn.RNN(input_size=emb_size, hidden_size=hidden_size, batch_first=True)
states = rnn(embeded)[0]  # [0]: per-step outputs; [1] would be the final hidden state
states.shape

torch.Size([2, 14, 10])

In [10]:
# Step 3: project every hidden state to one logit per vocabulary token.
linear = nn.Linear(hidden_size, len(TOKENS))
linear(states).size()

torch.Size([2, 14, 100])

In [9]:
class RNNLanguageModel(nn.Module):
    """Language model built from an embedding, a single RNN layer and a linear head.

    Args:
        n_tokens: vocabulary size; sets both the number of embeddings and the
            number of output logits per position.
        emb_size: dimensionality of the token embeddings.
        hid_size: size of the RNN hidden state.
    """

    def __init__(self, n_tokens: int, emb_size: int = 16, hid_size: int = 256):
        super().__init__()
        self.emb = nn.Embedding(n_tokens, emb_size)
        # batch_first=True: inputs and outputs carry the batch on axis 0.
        self.rnn = nn.RNN(input_size=emb_size, hidden_size=hid_size, batch_first=True)
        self.linear = nn.Linear(hid_size, n_tokens)

    def forward(self, input_ix):
        """Map token ids to logits, adding a trailing axis of size ``n_tokens``."""
        embedded = self.emb(input_ix)
        hidden_states, _ = self.rnn(embedded)  # final hidden state is not needed
        return self.linear(hidden_states)
    
# Instantiate with the default layer sizes and check the logits shape on the demo batch.
model = RNNLanguageModel(n_tokens=len(TOKENS))
model(inp).size()

torch.Size([2, 14, 100])

## Loss

The loss formula is as follows:

$$
L = - \cfrac{1}{N} \sum_{i=1}^N \ln p(x_t^{(i)} | x_{t-1}^{(i)}, \dots, x_1^{(i)})
$$

We pass a sequence of tokens to the network: to predict the $t$-th token, we feed it all $t-1$ previous tokens. The output is a probability for every token in the vocabulary to be the $t$-th one.


$p(x_t^{(i)} | x_{t-1}^{(i)}, \dots, x_1^{(i)})$: is a probability that $t$-th token follows previou

In [26]:
# Two inputs of very different length, plus a freshly initialised model,
# for walking through the loss computation.
loss_demo_texts = ["Some long input to the model", "short"]
inp = to_tensor(loss_demo_texts)
model = RNNLanguageModel(n_tokens=len(TOKENS))

In [34]:
# Feed every token except the last one; normalise the logits over axis 2.
model_input = inp[:, :-1]
logits = model(model_input)
probas = logits.softmax(dim=2)
probas.size()

torch.Size([2, 27, 100])

In [42]:
# Targets are the inputs shifted one step to the left (the first token is dropped).
from_second_token = slice(1, None)
reference_answers = inp[:, from_second_token]

In [44]:
# For every position, pick the probability at the index of the reference token.
probas.gather(dim=2, index=reference_answers.unsqueeze(-1))

tensor([[[0.0102],
         [0.0114],
         [0.0103],
         [0.0101],
         [0.0096],
         [0.0096],
         [0.0104],
         [0.0099],
         [0.0115],
         [0.0096],
         [0.0097],
         [0.0098],
         [0.0100],
         [0.0097],
         [0.0101],
         [0.0086],
         [0.0095],
         [0.0105],
         [0.0081],
         [0.0109],
         [0.0093],
         [0.0098],
         [0.0099],
         [0.0101],
         [0.0106],
         [0.0107],
         [0.0127]],

        [[0.0118],
         [0.0099],
         [0.0089],
         [0.0084],
         [0.0092],
         [0.0089],
         [0.0092],
         [0.0087],
         [0.0090],
         [0.0090],
         [0.0090],
         [0.0090],
         [0.0090],
         [0.0090],
         [0.0090],
         [0.0090],
         [0.0090],
         [0.0090],
         [0.0090],
         [0.0090],
         [0.0090],
         [0.0090],
         [0.0090],
         [0.0090],
         [0.0090],
         [

tensor([[80, 78, 70,  1, 77, 80, 79, 72,  1, 74, 79, 81, 86, 85,  1, 85, 80,  1,
         85, 73, 70,  1, 78, 80, 69, 70, 77],
        [73, 80, 83, 85,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,
          0,  0,  0,  0,  0,  0,  0,  0,  0]])