In [113]:
import equinox as eqx 
import jax 
import jax.numpy as jnp
from typing import List

In [142]:
ALPHABET = 'abcdefghijklmnopqrstuvwxyz'
SPECIAL = ['<S>', '<E>']

# helper functions 
def most_freq(dct): 
  return sorted(dct.items(), key=lambda kv: -kv[1])

def unique(s: str): 
  match s:
    case str():
      return sorted(list(set(s)))
    case list():
      return unique(''.join(s))
    case _:
      raise ValueError(f"unsupported type {type(s)}")

In [62]:
words = open('../names.txt', 'r').read().splitlines()
words2 = [line for line in open('../nietzsche_cleaned.txt', 'r', encoding='utf-8').read().splitlines() if line.strip()]
words3 = [line for line in open('../shakespeare.txt', 'r', encoding='utf-8').read().splitlines() if line.strip()]

In [63]:
words[:10]

['emma',
 'olivia',
 'ava',
 'isabella',
 'sophia',
 'charlotte',
 'mia',
 'amelia',
 'harper',
 'evelyn']

In [64]:
len(words)

32033

In [65]:
min(len(w) for w in words)

2

In [66]:
max(len(w) for w in words)

15

In [106]:
N = jnp.zeros((len(ALPHABET) + len(SPECIAL), len(ALPHABET) + len(SPECIAL)), dtype=jnp.int32) # N.shape = (28, 28)

In [111]:
sorted(list(set(''.join(words)))

['l',
 'a',
 'x',
 'n',
 'p',
 'i',
 'd',
 'r',
 'h',
 'o',
 'm',
 's',
 'e',
 'u',
 'f',
 'z',
 'v',
 'b',
 'j',
 'q',
 'c',
 't',
 'g',
 'w',
 'y',
 'k']

In [149]:
chars = unique(words)
stoi = {s: i for i, s in enumerate(chars)}
stoi['<S>'] = 26
stoi['<E>'] = 27
stoi

{'a': 0,
 'b': 1,
 'c': 2,
 'd': 3,
 'e': 4,
 'f': 5,
 'g': 6,
 'h': 7,
 'i': 8,
 'j': 9,
 'k': 10,
 'l': 11,
 'm': 12,
 'n': 13,
 'o': 14,
 'p': 15,
 'q': 16,
 'r': 17,
 's': 18,
 't': 19,
 'u': 20,
 'v': 21,
 'w': 22,
 'x': 23,
 'y': 24,
 'z': 25,
 '<S>': 26,
 '<E>': 27}

In [154]:
for w in words:
  cs = ['<S>'] + list(w) + ['<E>']
  for c1, c2 in zip(cs, cs[1:]):
    ix1 = stoi[c1]
    ix2 = stoi[c2]
    print(f"{c1} -> {ix1}, {c2} -> {ix2}")
    N = N.at[ix1, ix2].add(1)

<S> -> 26, e -> 4
e -> 4, m -> 12
m -> 12, m -> 12
m -> 12, a -> 0
a -> 0, <E> -> 27
<S> -> 26, o -> 14
o -> 14, l -> 11
l -> 11, i -> 8
i -> 8, v -> 21
v -> 21, i -> 8
i -> 8, a -> 0
a -> 0, <E> -> 27
<S> -> 26, a -> 0
a -> 0, v -> 21
v -> 21, a -> 0
a -> 0, <E> -> 27
<S> -> 26, i -> 8
i -> 8, s -> 18
s -> 18, a -> 0
a -> 0, b -> 1
b -> 1, e -> 4
e -> 4, l -> 11
l -> 11, l -> 11
l -> 11, a -> 0
a -> 0, <E> -> 27
<S> -> 26, s -> 18
s -> 18, o -> 14
o -> 14, p -> 15
p -> 15, h -> 7
h -> 7, i -> 8
i -> 8, a -> 0
a -> 0, <E> -> 27
<S> -> 26, c -> 2
c -> 2, h -> 7
h -> 7, a -> 0
a -> 0, r -> 17
r -> 17, l -> 11
l -> 11, o -> 14
o -> 14, t -> 19
t -> 19, t -> 19
t -> 19, e -> 4
e -> 4, <E> -> 27
<S> -> 26, m -> 12
m -> 12, i -> 8
i -> 8, a -> 0
a -> 0, <E> -> 27
<S> -> 26, a -> 0
a -> 0, m -> 12
m -> 12, e -> 4
e -> 4, l -> 11
l -> 11, i -> 8
i -> 8, a -> 0
a -> 0, <E> -> 27
<S> -> 26, h -> 7
h -> 7, a -> 0
a -> 0, r -> 17
r -> 17, p -> 15
p -> 15, e -> 4
e -> 4, r -> 17
r -> 17, <E> -> 27
