## Imports

In [1]:
import numpy as np
import torch
import torch.nn.functional as F
from tqdm import tqdm
import matplotlib.pyplot as plt
from sklearn.decomposition import PCA

## Data utils

In [2]:
data = "AB#"
vocab_size = len(set(data))
char2int = {char: i for i, char in enumerate(data)}
int2char = {i: char for i, char in enumerate(data)}
PATTERN = "ABBA"

In [3]:
def make_seq():
    ix_pattern = [char2int[c] for c in PATTERN]
    seq = []
    while seq[-len(PATTERN):] != ix_pattern:
        seq.append(np.random.randint(2))
    return seq + [char2int["#"]]

for _ in range(5):
    print(make_seq())

[0, 1, 0, 1, 0, 1, 1, 0, 2]
[0, 0, 1, 1, 1, 0, 0, 0, 1, 0, 0, 1, 0, 0, 1, 1, 0, 2]
[0, 0, 1, 1, 0, 2]
[1, 0, 1, 1, 0, 2]
[0, 1, 1, 1, 1, 1, 1, 0, 1, 0, 0, 0, 1, 1, 0, 2]


In [4]:
def one_hot(seq):
    return torch.stack([torch.eye(vocab_size)[i] for i in seq])

def token2char(token):
    if token.numpy().item(0) == 1.:
        return "A"
    if token.numpy().item(1) == 1.:
        return "B"
    if token.numpy().item(2) == 1.:
        return "#"

for row in one_hot(make_seq()):
    print(f"{row} -> {token2char(row)}")

tensor([0., 1., 0.]) -> B
tensor([0., 1., 0.]) -> B
tensor([0., 1., 0.]) -> B
tensor([0., 1., 0.]) -> B
tensor([1., 0., 0.]) -> A
tensor([1., 0., 0.]) -> A
tensor([0., 1., 0.]) -> B
tensor([0., 1., 0.]) -> B
tensor([0., 1., 0.]) -> B
tensor([1., 0., 0.]) -> A
tensor([1., 0., 0.]) -> A
tensor([1., 0., 0.]) -> A
tensor([1., 0., 0.]) -> A
tensor([0., 1., 0.]) -> B
tensor([0., 1., 0.]) -> B
tensor([1., 0., 0.]) -> A
tensor([0., 0., 1.]) -> #


## Visualization utils

In [5]:
def viz(hidden_states, states, hidden_size, path):
    last_hidden_states = hidden_states[-1000:]
    last_states = states[-1000:]
    states_np = np.asarray(last_hidden_states).reshape(-1, hidden_size)

    pca = PCA(n_components=2)
    X_embedded = pca.fit_transform(states_np)


    fig, ax = plt.subplots()
    scatter = ax.scatter(X_embedded[:,0], X_embedded[:,1], c=last_states)
    legend1 = ax.legend(handles=scatter.legend_elements()[0],
                        labels=[PATTERN[:i] for i in range(len(PATTERN+"#"))],
                        title="STATES", loc="best")

    ax.add_artist(legend1)

    plt.savefig(path)
    plt.show()
    plt.close()

In [6]:
class DFA:
    def __init__(self, pattern):
        self.pattern = pattern + "#"
        self.state = ""

    def transition(self, char):
        tentative_state = self.state + char
        N = len(tentative_state)
        if tentative_state == self.pattern[:N]:
            self.state = tentative_state
        else:
            for i in range(N+1):
                if tentative_state[i:] == self.pattern[:N-i]:
                    self.state = tentative_state[i:]
                    break
        return self.state

    def hsh(self):
        return len(self.state)

## Language model

In [7]:
# network parameters
hidden_size = 6

Wxh = torch.randn(vocab_size, hidden_size, requires_grad=True)
Whh = torch.randn(hidden_size, hidden_size, requires_grad=True)
bh = torch.zeros(hidden_size, requires_grad=True)

Why = torch.randn(hidden_size, vocab_size, requires_grad=True)
by = torch.zeros(vocab_size, requires_grad=True)
optimizer = torch.optim.Adam([Wxh, Whh, bh, Why, by], lr=0.005)
loss_fn = torch.nn.CrossEntropyLoss()

# containers for viz
acc_hidden_states = list()
states = list()
automata = DFA(pattern="ABBA#")

In [12]:
def min_rnn(x, h):
    def rnn_step(xt, ht):
        return torch.tanh(xt @ Wxh + ht @ Whh + bh)

    def linear(x):
        return x @ Why + by

    outputs = []
    for t in range(x.shape[0]):
        h = rnn_step(x[t,:], h)
        acc_hidden_states.append(h.detach().clone().numpy())
        automata.transition(token2char(x[t,:]))
        states.append(automata.hsh())
        outputs.append(linear(h))

    return torch.stack(outputs, dim=1).squeeze(0), h


In [13]:
def sample():
    seq = ""
    indices = list(char2int.values())
    with torch.no_grad():
        i = np.random.randint(2)
        h = torch.zeros(1, hidden_size)
        char = int2char[i]
        seq += char
        while char != "#":
            logits, h = min_rnn(one_hot([i]), h)
            ps = F.softmax(logits, dim=-1)
            i = np.random.choice(indices, p=ps.squeeze(0).numpy())
            char = int2char[i]
            seq += char
    return seq

In [14]:
# training
total_loss = 0
for epoch in tqdm(range(30000)):
    h = torch.zeros(1, hidden_size)
    seq = make_seq()
    inputs, targets = seq[:-1], seq[1:]
    preds, _ = min_rnn(one_hot(inputs), h)
    loss = loss_fn(preds, one_hot(targets))

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    total_loss += loss.item()

    if (epoch) % 2000 == 50:
        print(f"Epoch: {epoch} Loss:{(total_loss/epoch):.4f}")
        print(sample()[-5:])
        viz(acc_hidden_states, states, hidden_size, f"{epoch}.png")

  0%|          | 0/30000 [00:00<?, ?it/s]


TypeError: min_rnn.<locals>.rnn_step() missing 1 required positional argument: 'ht'

In [11]:
sum(sample()[-5:] == "ABBA#" for _ in range(100)) # accuracy

100