In [4]:
# %cd ..
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import rho_plus as rp

# Theme switch handed to rho_plus; False presumably selects the light theme — confirm.
is_dark = False
# mpl_setup returns a pair; neither `theme` nor `cs` is used in the cells shown below.
theme, cs = rp.mpl_setup(is_dark)

In [243]:
from critic.kbd_layout import QWERTY
from critic.kbd_model import KbdModel
from critic.string_alignment import all_paths
from critic.corrector import Corrections
from dataclasses import dataclass
from keras import ops
import pickle

@dataclass
class CorrectionResult:
    """One record of the pickled results list that is loaded as ``lms``.

    NOTE(review): unpickling presumably requires this class to exist under the
    same module path and name it had when the file was written — confirm
    before renaming or moving it.
    """
    # Target word; the input builder looks it up among the candidates.
    true: str
    # Word as typed; every candidate is aligned against it.
    typed: str
    # Candidate words with their probabilities (project `Corrections` type).
    corrs: Corrections
    # Empty string in the sample record displayed in this notebook;
    # presumably the surrounding text — TODO confirm.
    context: str
    # Float; presumably elapsed seconds for producing `corrs` — TODO confirm units.
    time: float

class AdjustedKbdModel(KbdModel):
    """KbdModel whose forward pass directly returns a per-sample loss value."""

    def call(self, inputs):
        """Return ``-log_softmax(scores)[true_i]`` for one sample.

        Args:
            inputs: ``(X, same_i, true_i)``. Columns of ``X`` as read here:
                0 = correction index, 1 = path index, 2 = first argument of
                ``log_prob``, 3:5 and 5:7 = its second and third arguments.
                ``same_i`` and ``true_i`` index into the per-correction scores.

        Returns:
            The negated log-softmax score at ``true_i``.
        """
        X, same_i, true_i = inputs
        corr_ids = X[:, 0]
        path_ids = X[:, 1]

        # One value per edit row (log-probabilities, going by the method name).
        edit_scores = self.log_prob(X[:, 2], X[:, 3:5], X[:, 5:7])

        # Sum the rows of each path. Capacity is fixed at 1000 paths.
        # NOTE(review): path ids >= 1000 are presumably not accumulated — confirm.
        per_path = ops.segment_sum(edit_scores, path_ids, num_segments=1000)
        # Correction index each path belongs to (all rows of a path share it).
        path_owner = ops.segment_max(corr_ids, path_ids, num_segments=1000)

        # Best path per correction. Capacity is fixed at 100 slots.
        scores = ops.segment_max(per_path, path_owner, num_segments=100)

        # `.at[...].set(...)` is the JAX array update syntax, so this relies on
        # the tensors being JAX arrays. Slot `same_i` gets score 0; the last
        # slot is then forced to -inf, which also undoes the first write when
        # `same_i` is -1.
        scores = scores.at[same_i].set(0)
        scores = scores.at[-1].set(-np.inf)
        log_post = ops.log_softmax(scores)

        return -log_post[true_i]

In [244]:
# Keyboard layout later passed to `edit.as_numerical(...)` when building inputs.
layout = QWERTY

# SECURITY: pickle.load can execute arbitrary code while loading. Only open
# result files you produced yourself. The path is relative to the kernel's
# working directory.
with open('results/lm_probs.pkl', 'rb') as f:
    lms = pickle.load(f)

# Display the first record as a sanity check (its repr is the cell output).
lms[0]

CorrectionResult(true='million', typed='million', corrs=Corrections(words=['million', 'mullion', 'millions', 'gillion', 'pillion', 'billion', 'zillion', 'mil lion', 'mil-lion', 'mill ion', 'mill-ion', 'milling', 'milliner'], probs=array([1.47925184e-01, 4.46878511e-07, 1.07528916e-01, 2.13066060e-07,
       8.39868496e-07, 1.81012676e-01, 1.13113114e-05, 3.54274769e-07,
       2.83182178e-08, 1.15190670e-04, 1.26088814e-06, 5.63057794e-01,
       3.45786135e-04])), context='', time=0.2963411490054568)

In [245]:
from critic.string_alignment import align, all_paths


# Build the training set: one `[X, same_i, true_i]` entry per usable record.
#
#   X       one row per edit, columns
#           [correction index, path index, edit kind, *wrong, *right]
#   same_i  position of the typed word among the sorted candidates, or -1 if
#           the typed word is not a candidate
#   true_i  position of the true word among the sorted candidates
#
# NOTE(review): AdjustedKbdModel.call uses fixed segment counts (1000 paths,
# 100 corrections); nothing here checks a record against those limits.
SAMPLE_STRIDE = 5  # only every 5th record of `lms` is used

inputs = []
for res in lms[::SAMPLE_STRIDE]:
    # Without the true word among the candidates there is no target index.
    if res.true not in res.corrs.words:
        continue

    corrs = res.corrs.as_series().sort_values(ascending=False)
    true_i = np.argmax(corrs.index == res.true)
    same_i = np.argmax(corrs.index == res.typed) if res.typed in corrs.index else -1

    X = []
    i = 0  # running path index, unique across all candidates of this record
    for corr_i, word in enumerate(corrs.index):
        edit_paths = align(res.typed, word)
        paths = all_paths(max(edit_paths.keys()), edit_paths)

        x0s, x1s, x2s, xis, xcs = [], [], [], [], []
        for path in paths:
            x0, x1, x2 = [], [], []
            for edit in path:
                # `None` entries carry no edit and contribute no row.
                if edit is not None:
                    kind, wrong, right = edit.as_numerical(layout)
                    x0.append(kind)
                    x1.append(list(wrong))
                    x2.append(list(right))

            # Plain list truthiness: no need to build a tensor just to test
            # for emptiness.
            if x0:
                x0s.append(ops.array(x0))
                x1s.append(ops.array(x1))
                x2s.append(ops.array(x2))
                xcs.append(ops.zeros_like(x0s[-1]) + corr_i)
                xis.append(ops.zeros_like(x0s[-1]) + i)
                i += 1

        # Testing the list itself also avoids stacking per-path arrays of
        # different lengths into one tensor, which would fail.
        if xis:
            # np.concatenate instead of np.concat: the latter only exists in
            # NumPy >= 2.0.
            X0, X1, X2, Xi, Xc = (
                np.concatenate(parts) for parts in (x0s, x1s, x2s, xis, xcs)
            )
            row = ops.concatenate((Xc[:, None], Xi[:, None], X0[:, None], X1, X2), axis=1)
            X.append(row)

    # No candidate produced any edit: there is nothing to concatenate, so the
    # record is skipped instead of failing on an empty list.
    if not X:
        continue

    X = ops.concatenate(X, axis=0)
    inputs.append([X, ops.array([same_i]), ops.array([true_i])])

In [249]:
import keras
from pprint import pprint

from keras.optimizers.schedules import PolynomialDecay

# Turn off Keras traceback filtering so errors show the full stack,
# including framework-internal frames.
keras.config.disable_traceback_filtering()


def dl():
    """Endlessly yield ``(x, dummy_target)`` pairs, cycling over the global ``inputs``.

    The target is a constant ``[0.0]`` tensor. The loss defined in ``fit``
    returns ``y_pred`` unchanged, so the target value is never used.

    Raises:
        ValueError: if ``inputs`` is empty. Without this check the loop would
            spin forever without yielding and ``next(dl())`` would hang.
    """
    while True:
        if not inputs:
            raise ValueError("`inputs` is empty; there is nothing to iterate over")
        for x in inputs:
            yield (x, ops.array([0.0]))

def fit(epochs=25):
    """Train a fresh ``AdjustedKbdModel`` on the global ``inputs`` and return it.

    One optimizer step is taken per sample and one epoch is one pass over
    ``inputs``. The learning rate follows a polynomial decay from 1e-3 to 1e-6
    over all steps, with global gradient-norm clipping at 3.0. The per-epoch
    history is printed as a DataFrame before returning.

    Args:
        epochs: number of passes over ``inputs``.

    Returns:
        The trained model.
    """

    def log_prob_loss(y_true, y_pred):
        # The model output already is the loss value; the target is ignored.
        return y_pred

    model = AdjustedKbdModel()
    # One forward pass on a real sample before compiling
    # (presumably to create the model's variables — confirm).
    first_x, _ = next(dl())
    model(first_x)

    steps_per_epoch = len(inputs)
    lr_schedule = PolynomialDecay(
        initial_learning_rate=1e-3,
        decay_steps=steps_per_epoch * epochs,
        end_learning_rate=1e-6,
    )
    optimizer = keras.optimizers.Adam(
        learning_rate=lr_schedule,
        global_clipnorm=3.0,
    )
    model.compile(optimizer=optimizer, loss=log_prob_loss)

    history = model.fit(dl(), steps_per_epoch=steps_per_epoch, epochs=epochs)

    print(pd.DataFrame(history.history))

    return model

# Train with the default 25 epochs; `mod` is what the next cell saves.
mod = fit()
# Dump the learned trainable variables for inspection.
pprint(mod.get_state_tree()['trainable_variables'], indent=2)

Epoch 1/25
[1m236/236[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m19s[0m 79ms/step - loss: 0.7591
Epoch 2/25
[1m236/236[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 306us/step - loss: 0.6869 
Epoch 3/25
[1m236/236[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 307us/step - loss: 0.6394
Epoch 4/25
[1m236/236[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 302us/step - loss: 0.6071
Epoch 5/25
[1m236/236[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 313us/step - loss: 0.5844
Epoch 6/25
[1m236/236[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 321us/step - loss: 0.5680
Epoch 7/25
[1m236/236[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 313us/step - loss: 0.5558
Epoch 8/25
[1m236/236[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 326us/step - loss: 0.5465
Epoch 9/25
[1m236/236[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 378us/step - loss: 0.5393
Epoch 10/25
[1m236/236[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0

In [250]:
# Persist the trained weights. The path is relative to the kernel's working
# directory. NOTE(review): the `models/` directory presumably has to exist
# already — confirm.
mod.save_weights("models/kbd_model_new.weights.h5")