# Recurrent text generation

In [None]:
import tqdm
import kagglehub
import pandas as pd
from pathlib import Path

import torch
from torch import nn

# Download version 205 of the Kaggle arXiv dataset; `path` ends up pointing at
# the metadata snapshot file inside the directory returned by kagglehub.
path = (
    Path(kagglehub.dataset_download("Cornell-University/arxiv/versions/205"))
    / "arxiv-metadata-oai-snapshot.json"
)

In [None]:
# Down-sampled copy of the arXiv metadata snapshot (every 10th line), cached on
# disk so the full snapshot only has to be scanned when the cache is missing.
data_path = Path("recurent_layer_files")/"arxiv_small.json"

if not data_path.exists():
    # The cache directory may not exist yet; without it, opening the file for
    # writing below raises FileNotFoundError.
    data_path.parent.mkdir(parents=True, exist_ok=True)

    # Iterate the file object lazily instead of calling readlines(), which
    # would hold every line of the snapshot in memory just to keep a tenth.
    # The encoding is explicit so the result does not depend on the locale.
    with open(path, "r", encoding="utf-8") as f:
        lines = [
            one_line
            for i, one_line in enumerate(tqdm.tqdm(f))
            if i % 10 == 0
        ]

    # Written only after the scan finished, so an interrupted scan does not
    # leave a partial cache file behind. UTF-8 matches what read_json expects.
    with open(data_path, mode="w", encoding="utf-8") as f:
        f.writelines(lines)

# The file holds one JSON object per line, hence lines=True.
data = pd.read_json(data_path, lines=True)

In [None]:
# Sequence delimiters: a single space opens every sample, a newline closes it.
BOS, EOS = " ", "\n"

# One string per row of `data`: "<title> ; <abstract>" cut to 512 characters,
# with any newline inside replaced by a space so that EOS only occurs as the
# final character, then wrapped in BOS ... EOS.
lines = [
    BOS + (title + " ; " + abstract)[:512].replace(EOS, " ") + EOS
    for title, abstract in zip(data["title"], data["abstract"])
]

lines[:3]

[' Calculation of prompt diphoton production cross sections at Tevatron and   LHC energies ;   A fully differential calculation in perturbative quantum chromodynamics is presented for the production of massive photon pairs at hadron colliders. All next-to-leading order perturbative contributions from quark-antiquark, gluon-(anti)quark, and gluon-gluon subprocesses are included, as well as all-orders resummation of initial-state gluon radiation valid at next-to-next-to-leading logarithmic accuracy. The region o\n',
 ' Computing genus 2 Hilbert-Siegel modular forms over $\\Q(\\sqrt{5})$ via   the Jacquet-Langlands correspondence ;   In this paper we present an algorithm for computing Hecke eigensystems of Hilbert-Siegel cusp forms over real quadratic fields of narrow class number one. We give some illustrative examples using the quadratic field $\\Q(\\sqrt{5})$. In those examples, we identify Hilbert-Siegel eigenforms that are possible lifts from Hilbert eigenforms. \n',
 ' Molecular Syn

In [None]:
# Character-level vocabulary: every distinct character occurring in `lines`,
# as a list sorted by code point. Each string is unioned into the set as an
# iterable of its characters.
tokens = sorted(set().union(*lines))
"".join(tokens)

'\n !"#$%&\'()*+,-./0123456789:;<=>?@ABCDEFGHIJKLMNOPQRSTUVWXYZ[\\]^_`abcdefghijklmnopqrstuvwxyz{|}~\x7f\x80\x99â'

In [None]:
# Map every vocabulary character to its integer id (its index in `tokens`).
token_to_id = {x: i for i, x in enumerate(tokens)}


def to_tensor(
    lines: list[str],
    max_len: int | None = None,
    pad: int = token_to_id[EOS],
    dtype=torch.int64,
):
    """Encode strings as a right-padded matrix of token ids.

    Args:
        lines: Strings to encode. Every character must be a key of
            ``token_to_id``; an unknown character raises ``KeyError``.
        max_len: Number of columns of the result. Longer lines are truncated
            to this length. When ``None``, the longest line sets the width
            (0 if ``lines`` is empty).
        pad: Token id written to positions past the end of a line. Defaults
            to the id of ``EOS``, looked up once when the function is defined.
        dtype: dtype of the returned tensor.

    Returns:
        Tensor of shape ``[len(lines), max_len]`` holding token ids.
    """
    if max_len is None:
        # default=0 keeps an empty batch from raising inside max().
        max_len = max(map(len, lines), default=0)
    lines_ix = torch.full([len(lines), max_len], pad, dtype=dtype)
    for i, line in enumerate(lines):
        line_ix = [token_to_id[x] for x in line[:max_len]]
        # Only the leading len(line_ix) cells are overwritten; the remainder
        # of the row keeps the pad id.
        lines_ix[i, : len(line_ix)] = torch.tensor(line_ix, dtype=dtype)
    return lines_ix


print(to_tensor([" abc\n", " abacaba\n", " abc1234567890\n"]))

tensor([[ 1, 66, 67, 68,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0],
        [ 1, 66, 67, 66, 68, 66, 67, 66,  0,  0,  0,  0,  0,  0,  0],
        [ 1, 66, 67, 68, 18, 19, 20, 21, 22, 23, 24, 25, 26, 17,  0]])


In [None]:
# Smoke test of MyRnnLayer (defined outside this excerpt) on random data.
# The constructor arguments 10 and 32 line up with the last input dimension
# and the last dimension of both recorded output shapes.
my_rnn = MyRnnLayer(10, 32)
# Presumably (batch, sequence, features) — confirm against MyRnnLayer.
test_input = torch.randn((5, 12, 10))
test_output = my_rnn(test_input)
# The result is indexed as a pair; display the shape of each element.
test_output[0].shape, test_output[1].shape

(torch.Size([5, 12, 32]), torch.Size([1, 5, 32]))

In [None]:
# Encode one sample string as token ids; with a single line and no max_len
# the result has one row and one column per character.
inp = to_tensor(["hello world"])

# Embedding table with one 16-dimensional vector per vocabulary token.
emb = nn.Embedding(num_embeddings=len(tokens), embedding_dim=16)
ans = emb(inp)
# Display the shape of the embedded sample.
ans.shape

torch.Size([1, 11, 16])

In [None]:
# Feed the embedded sample through a MyRnnLayer built with arguments (16, 256);
# 16 matches the embedding_dim used to produce `ans`.
rnn = MyRnnLayer(16, 256)
# NOTE: `ans` is rebound to the layer's first return value here, so re-running
# this cell without re-running the embedding cell passes that value back into
# the layer instead of the embeddings.
ans, hidden = rnn(ans)

In [None]:
# Shape of the first value returned by rnn(...).
ans.shape

torch.Size([1, 11, 256])

In [None]:
# Shape of the second value returned by rnn(...).
hidden.shape

torch.Size([1, 1, 256])