# Character-level RNNs for generating contract clauses

# To Do & Questions

- [ ] Mechanics of model training
    - Inspect weights to make sure they're changing.
    - Does the prediction get more correct on each iteration?

- [ ] Expand from dev vocab to full ASCII alphabet
- [ ] Move from single examples to batches
- [ ] How do I reset the hidden state on a new "run" (whatever that means) without also resetting the learned weights?
- [ ] How do I process a whole sequence in a single call to `model.forward`, instead of one character at a time?
- [ ] How is the PyTorch RNN layer implemented?
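On the last question: the `torch.nn.RNN` documentation gives the per-step update (default `tanh` nonlinearity) as `h_t = tanh(x_t W_ih^T + b_ih + h_{t-1} W_hh^T + b_hh)`. A minimal sketch, assuming a recent PyTorch, that checks this by replaying a one-layer `nn.RNN` by hand with its own `weight_ih_l0`, `weight_hh_l0`, `bias_ih_l0`, and `bias_hh_l0` tensors. It also shows the batched, whole-sequence interface (`batch_first=True`, input of shape `(batch, seq_len, input_size)`) that the batching and sequence questions point toward. The sizes are arbitrary.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
batch, seq_len, input_size, hidden_size = 2, 5, 10, 4

rnn = nn.RNN(input_size, hidden_size, batch_first=True)  # tanh nonlinearity by default
x = torch.randn(batch, seq_len, input_size)

# One call processes every sequence in the batch; the initial hidden state defaults to zeros.
out, h_n = rnn(x)  # out: (batch, seq_len, hidden_size), h_n: (num_layers, batch, hidden_size)

# Replay the same recurrence one time step at a time using the layer's own weights.
h = torch.zeros(batch, hidden_size)
manual = []
for t in range(seq_len):
    h = torch.tanh(
        x[:, t] @ rnn.weight_ih_l0.T + rnn.bias_ih_l0
        + h @ rnn.weight_hh_l0.T + rnn.bias_hh_l0
    )
    manual.append(h)
manual = torch.stack(manual, dim=1)  # (batch, seq_len, hidden_size)

print(torch.allclose(out, manual, atol=1e-5))        # all hidden states match
print(torch.allclose(h_n[0], manual[:, -1], atol=1e-5))  # final state matches
```

Unlike `MyRNN`, `nn.RNN` has bias terms and multiplies row vectors by transposed weights, but the recurrence is the same tanh update.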

# Notes

* The model seems to get stuck in absorbing states: repeating the call to `generate_text` yields the same output. Why would that be? My first guess is a bug. If not, though, then what?
    - Suggests a deeper thing about UI: model should have a public `reset_state` method that sets state back to 0 (but doesn't change the weights at all).
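A minimal sketch of that `reset_state` idea, assuming we keep `MyRNN`'s math: register the hidden state as a non-trainable buffer (so it is saved in `state_dict` and moved by `.to(device)`, but excluded from `parameters()`), and have `reset_state` replace it with zeros. `StatefulRNN` is a hypothetical name. Because the weights are `nn.Parameter`s and the state is not, resetting the state cannot touch the weights.

```python
import torch
import torch.nn as nn

class StatefulRNN(nn.Module):
    """Hypothetical variant of MyRNN with an explicit reset_state method."""

    def __init__(self, vocab_length: int, hidden_units: int):
        super().__init__()
        self.embedding = nn.Embedding.from_pretrained(torch.eye(vocab_length))  # frozen one-hot
        self.W_hh = nn.Parameter(torch.randn(hidden_units, hidden_units))
        self.W_hx = nn.Parameter(torch.randn(hidden_units, vocab_length))
        self.W_out = nn.Parameter(torch.randn(vocab_length, hidden_units))
        # A buffer is part of the module's state, but it is not a learnable parameter.
        self.register_buffer("h", torch.zeros(hidden_units))

    def reset_state(self) -> None:
        """Zero the hidden state; the learned weights are left untouched."""
        self.h = torch.zeros_like(self.h)  # assigning to a buffer name updates the buffer

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Same update as MyRNN: x is a LongTensor of shape (1,).
        self.h = torch.tanh(self.W_hh @ self.h + self.W_hx @ self.embedding(x)[0])
        return self.W_out @ self.h

model = StatefulRNN(vocab_length=10, hidden_units=4)
weights_before = [p.detach().clone() for p in model.parameters()]

with torch.no_grad():
    model(torch.tensor([5]))  # one step moves h away from zero
h_after_step = model.h.clone()

model.reset_state()
print(model.h)  # tensor([0., 0., 0., 0.])
```

Calling `model.reset_state()` at the top of `generate_text` would make repeated calls start from the same state instead of whatever the previous call left behind.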

# Intro

**Goal:** reproduce Karpathy's blog post on character RNNs for Paul Graham essays and Shakespeare, to really grok how RNNs work.

https://karpathy.github.io/2015/05/21/rnn-effectiveness/


## Pieces from the blog post

* Train vs. evaluate
* Use validation set during training
* Watch how quality evolves over training epochs
* RNN state updates vs. LSTM state updates
* Stacking RNNs
* Dropout
* Backprop through time (BPTT)?
* Visualize next-char distribution, given input sequence
* Visualize key active neurons, given an input sequence. Look for any that are immediately interpretable.
* Optimizer: RMSProp or Adam
* Temperature of sampling
* Sampling logic (beam search vs. one character at a time?)
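For the "RNN state updates vs. LSTM state updates" item, here are the standard equations side by side, using this notebook's symbol names where they exist ($W_{hh}$, $W_{hx}$, $W_{out}$; biases omitted, as in `MyRNN`). The LSTM form is the common textbook version without peephole connections; $\sigma$ is the logistic sigmoid and $\odot$ is the elementwise product. The gate weight names ($W_{hi}$, $W_{xi}$, and so on) are chosen here for illustration.

$$
\begin{aligned}
\textbf{RNN:}\quad & h_t = \tanh\left(W_{hh}\, h_{t-1} + W_{hx}\, x_t\right), \qquad y_t = W_{out}\, h_t \\[6pt]
\textbf{LSTM:}\quad & i_t = \sigma\left(W_{hi}\, h_{t-1} + W_{xi}\, x_t\right) \\
& f_t = \sigma\left(W_{hf}\, h_{t-1} + W_{xf}\, x_t\right) \\
& o_t = \sigma\left(W_{ho}\, h_{t-1} + W_{xo}\, x_t\right) \\
& g_t = \tanh\left(W_{hg}\, h_{t-1} + W_{xg}\, x_t\right) \\
& c_t = f_t \odot c_{t-1} + i_t \odot g_t \\
& h_t = o_t \odot \tanh(c_t), \qquad y_t = W_{out}\, h_t
\end{aligned}
$$

The key difference is that the LSTM's cell state $c_t$ is updated additively, gated by $f_t$ and $i_t$, rather than being pushed through a tanh at every step as the vanilla RNN's $h_t$ is.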

# Data 

In [1]:
from datasets import load_dataset

ds = load_dataset("lex_glue", "ledgar")
ds

DatasetDict({
    train: Dataset({
        features: ['text', 'label'],
        num_rows: 60000
    })
    test: Dataset({
        features: ['text', 'label'],
        num_rows: 10000
    })
    validation: Dataset({
        features: ['text', 'label'],
        num_rows: 10000
    })
})

In [2]:
def count_chars(corpus: list[str]):
    return sum(map(len, corpus))

def pprint(x: int):
    print(f"{x:_}")

In [3]:
pprint(count_chars(ds['train']['text']))

42_460_532


In [4]:
pprint(count_chars(ds['validation']['text']))

6_995_963


In [5]:
pprint(count_chars(ds['test']['text']))

6_698_900


# Tiny data for dev

In [6]:
data = ds['train']['text'][:3]
data

['Except as otherwise set forth in this Debenture, the Company, for itself and its legal representatives, successors and assigns, expressly waives presentment, protest, demand, notice of dishonor, notice of nonpayment, notice of maturity, notice of protest, presentment for the purpose of accelerating maturity, and diligence in collection.',
 'No ERISA Event has occurred or is reasonably expected to occur that, when taken together with all other such ERISA Events for which liability is reasonably expected to occur, could reasonably be expected to result in a Material Adverse Effect. Neither Borrower nor any ERISA Affiliate maintains or contributes to or has any obligation to maintain or contribute to any Multiemployer Plan or Plan, nor otherwise has any liability under Title IV of ERISA.',
 'This Amendment may be executed by one or more of the parties hereto on any number of separate counterparts, and all of said counterparts taken together shall be deemed to constitute one and the same

In [7]:
pprint(count_chars(data))

1_118


# Vocabulary

In [8]:
# import string

In [9]:
vocab = "rstlnea" + "." + " "
vocab = list(vocab) + ['<UNK>']
vocab

['r', 's', 't', 'l', 'n', 'e', 'a', '.', ' ', '<UNK>']

In [10]:
char2ix = {c: i for i, c in enumerate(vocab)}
# char2ix['<UNK>'] = len(char2ix)
char2ix

{'r': 0,
 's': 1,
 't': 2,
 'l': 3,
 'n': 4,
 'e': 5,
 'a': 6,
 '.': 7,
 ' ': 8,
 '<UNK>': 9}
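Toward the To Do item about expanding to the full ASCII alphabet: a minimal sketch that builds the vocabulary from any string of characters (here `string.printable`, which holds 100 ASCII characters) and maps out-of-vocabulary characters to `<UNK>` instead of raising `KeyError`. `build_vocab` and `encode` are hypothetical helper names; with the dev characters they reproduce the `vocab` and `char2ix` above.

```python
import string

def build_vocab(chars: str) -> tuple[list[str], dict[str, int]]:
    """Return (vocab, char2ix), with '<UNK>' as the last index."""
    vocab = list(dict.fromkeys(chars)) + ['<UNK>']  # de-duplicate, keep first-seen order
    char2ix = {c: i for i, c in enumerate(vocab)}
    return vocab, char2ix

def encode(text: str, char2ix: dict[str, int]) -> list[int]:
    """Map each character to its index, falling back to '<UNK>'."""
    unk = char2ix['<UNK>']
    return [char2ix.get(c, unk) for c in text]

dev_vocab, dev_char2ix = build_vocab("rstlnea" + "." + " ")
full_vocab, full_char2ix = build_vocab(string.printable)

print(len(dev_vocab), len(full_vocab))  # 10 101
print(encode("rest!", dev_char2ix))     # [0, 5, 1, 2, 9]
```

Any non-ASCII characters in the corpus would map to `<UNK>` under the full vocabulary as well.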

# Model

In [11]:
import torch
import torch.nn as nn

In [12]:
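class MyRNN(nn.Module):
    def __init__(self, vocab_length, hidden_units):
        super().__init__()
        # One-hot "embeddings": row i of the identity matrix encodes character i.
        self.embedding_dim = vocab_length
        self.embedding = nn.Embedding.from_pretrained(torch.eye(vocab_length))  # frozen by default
        # Hidden state: a plain tensor attribute, so it persists (and changes) across forward calls.
        # Outside torch.no_grad(), it also carries the autograd graph from earlier calls.
        self.h = torch.zeros(hidden_units)
        self.W_hh = nn.Parameter(torch.randn(hidden_units, hidden_units))        # hidden -> hidden
        self.W_hx = nn.Parameter(torch.randn(hidden_units, self.embedding_dim))  # input -> hidden
        self.W_out = nn.Parameter(torch.randn(vocab_length, hidden_units))       # hidden -> logits
    
    def forward(self, x):
        # x: LongTensor of shape (1,) holding one character index.
        # self.embedding(x) has shape (1, vocab_length); [0] takes its single one-hot row.
        self.h = torch.tanh(self.W_hh @ self.h + self.W_hx @ self.embedding(x)[0])
        y = self.W_out @ self.h  # unnormalized logits over the vocabulary
        return y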

In [13]:
model = MyRNN(vocab_length=len(vocab), hidden_units=4)
model

MyRNN(
  (embedding): Embedding(10, 10)
)

In [14]:
list(model.parameters())

[Parameter containing:
 tensor([[-1.1697,  0.1895, -0.3283,  1.8748],
         [-0.7445, -0.6746, -1.1363, -0.3717],
         [ 0.7787, -1.6314, -0.4668, -0.9429],
         [ 1.1840, -2.0334,  0.5015, -1.1130]], requires_grad=True),
 Parameter containing:
 tensor([[ 1.5139,  0.8695,  1.1172, -0.4834,  2.2298,  1.6107, -0.1146,  1.0463,
           3.0711,  0.4728],
         [ 0.3356,  1.0531, -0.3292,  0.6214, -0.3036,  0.3481, -1.5734, -2.1714,
          -1.7093, -0.6690],
         [ 0.4458, -1.8069, -1.1456, -1.1808,  0.5067, -0.1222,  0.2413, -0.7415,
          -0.7576, -0.8985],
         [-1.2563, -0.9955,  1.3200, -0.1624, -1.9447, -0.4463,  1.8253,  0.4621,
           0.1607,  0.6085]], requires_grad=True),
 Parameter containing:
 tensor([[-0.6856, -0.2899,  0.6719,  0.0141],
         [-0.0730, -0.0859, -1.0558,  0.3905],
         [ 0.2735, -1.0138,  0.6939,  0.6311],
         [-0.8066, -0.7271, -0.0401,  0.3221],
         [ 1.4747, -0.5686,  0.3175, -0.1307],
         [-3.2276,  

# Demo model `step` interface with changing internal state

In [None]:
x2 = torch.tensor([5])
x2

In [None]:
model(x2)

In [None]:
model.h

In [None]:
model(x2)

In [None]:
model.h

# Cross-entropy loss

In [None]:
objective = nn.CrossEntropyLoss()

In [None]:
y = model(x2)
y

In [None]:
len(y)

In [None]:
len(char2ix)

In [None]:
objective(input=y, target=torch.tensor(9))
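Toward the "Mechanics of model training" To Do: a minimal sketch of full backprop through time on one short string, using Adam (one of the optimizers listed above) and a hypothetical `SeqRNN` that applies `MyRNN`'s update but starts from a fresh zero hidden state inside each `forward` call and returns logits for every position. Starting fresh per call sidesteps `MyRNN`'s `self.h`, which carries the autograd graph across calls and would raise a "backward through the graph a second time" error on the second `backward()`. The sketch records the loss at each step and checks that `W_hh` actually changed. The training string, learning rate, hidden size, step count, and the smaller `0.1 * randn` initialization are arbitrary choices.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

vocab = list("rstlnea. ") + ['<UNK>']
char2ix = {c: i for i, c in enumerate(vocab)}
unk = char2ix['<UNK>']

class SeqRNN(nn.Module):
    """Hypothetical: MyRNN's update rule, applied over a whole sequence per call."""

    def __init__(self, vocab_length: int, hidden_units: int):
        super().__init__()
        self.hidden_units = hidden_units
        self.register_buffer("eye", torch.eye(vocab_length))  # one-hot lookup table
        self.W_hh = nn.Parameter(0.1 * torch.randn(hidden_units, hidden_units))
        self.W_hx = nn.Parameter(0.1 * torch.randn(hidden_units, vocab_length))
        self.W_out = nn.Parameter(0.1 * torch.randn(vocab_length, hidden_units))

    def forward(self, ixs: torch.Tensor) -> torch.Tensor:
        h = torch.zeros(self.hidden_units)  # fresh state for every sequence
        logits = []
        for ix in ixs:  # ixs: LongTensor of shape (seq_len,)
            h = torch.tanh(self.W_hh @ h + self.W_hx @ self.eye[ix])
            logits.append(self.W_out @ h)
        return torch.stack(logits)  # (seq_len, vocab_length)

text = "the rest is salt. "
ixs = torch.tensor([char2ix.get(c, unk) for c in text])
inputs, targets = ixs[:-1], ixs[1:]  # predict each next character

model = SeqRNN(len(vocab), hidden_units=16)
objective = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.05)
W_hh_before = model.W_hh.detach().clone()

losses = []
model.train()
for step in range(100):
    optimizer.zero_grad()
    loss = objective(model(inputs), targets)  # (seq_len, V) logits vs (seq_len,) targets
    loss.backward()                           # gradients flow back through every time step
    optimizer.step()
    losses.append(loss.item())

print(f"loss: {losses[0]:.3f} -> {losses[-1]:.3f}")
print("max |change| in W_hh:", (model.W_hh.detach() - W_hh_before).abs().max().item())
```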

# Sample from the (untrained) model

For now, just take the most probable next character (greedy decoding). Sampling, temperature, beam search, etc. can come later.
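A minimal sketch of temperature sampling, as an alternative to the `argmax` used in the cells below: divide the logits by a temperature T before the softmax, then draw from the resulting distribution with `torch.multinomial`. T < 1 sharpens the distribution toward greedy decoding; T > 1 flattens it toward uniform. `sample_next` is a hypothetical helper. Unlike `argmax`, it can give different outputs on repeated calls, which bears on the note about `generate_text` repeating itself.

```python
import torch

def sample_next(logits: torch.Tensor, temperature: float = 1.0, generator=None) -> torch.Tensor:
    """Sample one index from 1-D logits; returns a LongTensor of shape (1,)."""
    if temperature <= 0:
        raise ValueError("temperature must be positive")
    probs = torch.softmax(logits / temperature, dim=-1)  # temperature-scaled distribution
    return torch.multinomial(probs, num_samples=1, generator=generator)

logits = torch.tensor([2.0, 1.0, 0.1])
for T in (0.5, 1.0, 2.0):
    probs = torch.softmax(logits / T, dim=-1)
    print(T, [round(p, 3) for p in probs.tolist()])  # lower T puts more mass on index 0

g = torch.Generator().manual_seed(0)  # seeded so the draws are reproducible
print([sample_next(logits, temperature=1.0, generator=g).item() for _ in range(10)])
```

Because `sample_next` returns a shape-(1,) index, it could stand in for the `torch.argmax(y, keepdim=True)` call in `autoregress` without other changes.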

In [41]:
def set_prompt(model, prompt: str):
    """Update the model's hidden state based on a user prompt. The model is mutated, not returned."""
    model.eval()
    unk = char2ix['<UNK>']
    
    with torch.no_grad():
        for c in prompt:
            # Map characters outside the vocab to <UNK> instead of raising KeyError.
            x = torch.tensor([char2ix.get(c, unk)])
            model(x)  # called only for its side effect on model.h

In [43]:
model = MyRNN(vocab_length=len(vocab), hidden_units=4)
model

MyRNN(
  (embedding): Embedding(10, 10)
)

In [44]:
model.h

tensor([0., 0., 0., 0.])

In [45]:
set_prompt(model, "nearest ")

In [46]:
model.h

tensor([-0.6279,  0.5244, -0.4616,  0.4919])

In [51]:
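def autoregress(model, start_char: str, num_chars: int) -> str:
    """Greedily generate num_chars characters, feeding each prediction back in as the next input."""
    model.eval()
    out = ''
    
    ix = torch.tensor([char2ix.get(start_char, char2ix['<UNK>'])])
    
    for _ in range(num_chars):
        with torch.no_grad():
            y = model(ix)
        ix = torch.argmax(y, keepdim=True)  # index of the most probable next character
        out += vocab[ix.item()]
        
    return out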

In [72]:
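def generate_text(model, prompt, num_chars):
    # Run model forward on the prompt to set hidden state.
    # The hidden state is not reset first, so it carries over from any earlier call.
    set_prompt(model, prompt)
    
    # Generate new text, auto-regressive style.
    # Note: set_prompt has already fed prompt[-1] to the model, so passing it as
    # start_char feeds that character a second time.
    out = autoregress(model, start_char=prompt[-1], num_chars=num_chars)
    
    return out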

In [92]:
model = MyRNN(vocab_length=len(vocab), hidden_units=4)
model

MyRNN(
  (embedding): Embedding(10, 10)
)

In [95]:
generate_text(model, prompt="nearest ", num_chars=50)

'sllenelenelenelenelenelenelenelenelenelenelenelene'