In [3]:
# %cd ..
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import rho_plus as rp

# Matplotlib theming via rho_plus; `is_dark` presumably selects a dark-background
# theme when True (confirm against rho_plus docs).
is_dark = False
# NOTE(review): presumably returns the theme name and a color palette — TODO confirm.
theme, cs = rp.mpl_setup(is_dark)

In [4]:
from critic.kbd_layout import QWERTY
from critic.kbd_model import KbdModel
from critic.string_alignment import all_paths
from critic.corrector import Corrections
from dataclasses import dataclass
from keras import ops
import pickle

@dataclass
class CorrectionResult:
    """One evaluated correction example, as loaded from results/lm_probs.pkl.

    The pickled records are instances of this class, so it must be defined
    before unpickling and its field names must match the stored data.
    """
    true: str  # word the user intended; used as the training label
    typed: str  # word as actually typed
    corrs: Corrections  # candidate corrections (`words`) with their `probs`
    context: str  # text associated with the example (e.g. '1.5 ') — presumably preceding context, TODO confirm
    time: float  # a duration — presumably seconds spent producing `corrs`; TODO confirm units

class AdjustedKbdModel(KbdModel):
    """KbdModel variant whose forward pass returns the per-example training loss.

    Combines the keyboard model's edit log-probabilities with language-model
    candidate probabilities, and returns the negative log-probability assigned
    to the true correction.
    """

    # `num_segments` must be static. Paths with an id at or above
    # `num_path_segments` fall outside the segments. The data-prep cell
    # zero-pads lm_probs to 100 candidates, so keep `num_correction_segments`
    # in sync with that padding.
    num_path_segments = 1000
    num_correction_segments = 100

    def call(self, inputs):
        """Return ``-log p(true correction)`` for one encoded example.

        Args:
            inputs: ``(X, same_i, true_i, lm_probs)``.
                ``X``: one row per edit, with columns ``[correction index,
                path index, edit kind, wrong (cols 3:5), right (cols 5:7)]``.
                ``same_i``: index of the typed word among the candidates,
                or ``-1`` when it is not a candidate.
                ``true_i``: index of the true word.
                ``lm_probs``: candidate probabilities, zero-padded.

        Returns:
            The negated log-softmax score at ``true_i``.
        """
        X, same_i, true_i, lm_probs = inputs
        # Log-probability of each individual edit under the keyboard model.
        probs = self.log_prob(X[:, 2], X[:, 3:5], X[:, 5:7])
        # Total log-prob of each alignment path (column 1 is the path id).
        path_log_probs = ops.segment_sum(probs, X[:, 1], num_segments=self.num_path_segments)
        # All rows of one path share a correction index, so the max recovers it.
        path_corr_is = ops.segment_max(X[:, 0], X[:, 1], num_segments=self.num_path_segments)

        # Score each candidate by its most likely alignment path.
        yhat = ops.segment_max(
            path_log_probs, path_corr_is, num_segments=self.num_correction_segments
        )
        # The typed word needs no edits, so it gets keyboard log-prob 0.
        # A mask is used instead of `.at[same_i].set(0).at[-1].set(-inf)`:
        # the "absent" sentinel same_i == -1 matches no slot, so a real
        # candidate in the last slot is no longer forced to -inf.
        slots = ops.arange(self.num_correction_segments)
        yhat = ops.where(slots == same_i, 0.0, yhat)
        # Zero-padded lm_probs give log(0) = -inf, which excludes padding slots.
        yhat = ops.log_softmax(yhat + ops.log(lm_probs))

        return -yhat[true_i]

An NVIDIA GPU may be present on this machine, but a CUDA-enabled jaxlib is not installed. Falling back to cpu.


In [5]:
layout = QWERTY  # keyboard layout used to featurize edits in the next cell

# Precomputed correction results. pickle.load can run arbitrary code, so only
# open files this project generated itself.
with open('results/lm_probs.pkl', mode='rb') as pickle_file:
    lms = pickle.load(pickle_file)

# Peek at one record to check its structure.
lms[0]

CorrectionResult(true='million', typed='million', corrs=Corrections(words=['million', 'mullion', 'millions', 'gillion', 'pillion', 'billion', 'zillion', 'mil lion', 'mil-lion', 'mill ion', 'mill-ion', 'milling', 'milliner'], probs=array([8.77737983e-01, 5.09121721e-07, 1.06132202e-03, 1.50707586e-06,
       1.10553320e-05, 1.20752210e-01, 2.17134659e-04, 1.20400661e-07,
       1.74372439e-08, 1.29610905e-06, 1.37476171e-07, 2.05762160e-04,
       1.09456924e-05])), context='1.5 ', time=0.5459755449555814)

In [6]:
from critic.string_alignment import align, all_paths


def _encode_path(path, layout):
    """Turn one alignment path into numeric edit features.

    ``None`` steps are skipped. Returns ``(kinds, wrong, right)`` numpy arrays
    built from ``edit.as_numerical(layout)``: ``kinds`` has one entry per
    edit, and ``wrong``/``right`` have one row per edit. All three are empty
    when the path contains no edits.
    """
    kinds, wrong, right = [], [], []
    for edit in path:
        if edit is not None:
            kind, wrong_vec, right_vec = edit.as_numerical(layout)
            kinds.append(kind)
            wrong.append(list(wrong_vec))
            right.append(list(right_vec))
    return np.array(kinds), np.array(wrong), np.array(right)


def _encode_example(res, layout, max_corrections=100):
    """Encode one CorrectionResult as ``[X, same_i, true_i, lm_probs]``.

    Returns ``None`` when the example is unusable: the true word is not among
    the top ``max_corrections`` candidates, or no candidate has any edit path.
    """
    if res.true not in res.corrs.words:
        return None

    corrs = res.corrs.as_series().sort_values(ascending=False).iloc[:max_corrections]
    # Check membership *after* truncation: if the true word ranked below the
    # cut, argmax would return 0 and silently mislabel the example.
    if res.true not in corrs.index:
        return None
    true_i = np.argmax(corrs.index == res.true)
    same_i = np.argmax(corrs.index == res.typed) if res.typed in corrs.index else -1

    rows = []
    # Running path index across all candidates (column 1 of X).
    # NOTE(review): AdjustedKbdModel segments over a fixed number of path ids
    # (1000); consider asserting path_id stays below that.
    path_id = 0
    for corr_i, word in enumerate(corrs.index):
        edit_paths = align(res.typed, word)
        for path in all_paths(max(edit_paths.keys()), edit_paths):
            kinds, wrong, right = _encode_path(path, layout)
            if len(kinds) == 0:
                continue  # paths without edits contribute no rows
            n_edits = len(kinds)
            # Columns: [correction index, path index, edit kind, wrong..., right...]
            rows.append(np.column_stack((
                np.full(n_edits, corr_i),
                np.full(n_edits, path_id),
                kinds,
                wrong,
                right,
            )))
            path_id += 1

    if not rows:
        return None  # nothing for the keyboard model to score

    X = ops.concatenate(rows, axis=0)
    lm_probs = np.pad(corrs.values, (0, max_corrections - len(corrs)))
    return [X, ops.array([same_i]), ops.array([true_i]), ops.array(lm_probs)]


# Use every 10th result to keep the training set small.
inputs = [
    encoded
    for encoded in (_encode_example(res, layout) for res in lms[::10])
    if encoded is not None
]

In [7]:
import keras
from pprint import pprint

from keras.optimizers.schedules import PolynomialDecay

# Debugging aid: show full Keras-internal frames in tracebacks instead of the
# filtered view. Consider removing once training is stable.
keras.config.disable_traceback_filtering()


def dl():
    """Infinite data loader: cycle through ``inputs`` in order, forever.

    Each item is ``(example, target)``. The target is a constant dummy
    ``[0.0]`` because the model's output is itself the loss.
    """
    while True:
        yield from ((example, ops.array([0.0])) for example in inputs)

def fit(epochs=50, learning_rate=4e-4, end_learning_rate=1e-6, clipnorm=1.0):
    """Train a fresh AdjustedKbdModel on the global ``inputs`` list.

    Args:
        epochs: number of passes over ``inputs``.
        learning_rate: initial RMSprop learning rate.
        end_learning_rate: learning rate reached at the end of polynomial decay.
        clipnorm: global gradient-norm clipping threshold.

    Returns:
        The trained model. The per-epoch history is printed as a DataFrame.

    Raises:
        ValueError: if ``inputs`` is empty. Without this check, ``dl()`` would
            spin forever looking for a first example.
    """
    if not inputs:
        raise ValueError("fit() needs at least one example in `inputs`")

    mod = AdjustedKbdModel()
    # Call once on a real example to build the model's variables.
    mod(inputs[0])

    steps_in_epoch = len(inputs)
    # Decay across the whole run, so the last step reaches end_learning_rate.
    decay_steps = steps_in_epoch * epochs

    def log_prob_loss(y_true, y_pred):
        # The model already returns -log p(true); the dummy target is ignored.
        return y_pred

    mod.compile(
        optimizer=keras.optimizers.RMSprop(
            learning_rate=PolynomialDecay(
                learning_rate, decay_steps, end_learning_rate=end_learning_rate
            ),
            global_clipnorm=clipnorm,
        ),
        loss=log_prob_loss,
    )

    history = mod.fit(
        dl(),
        steps_per_epoch=steps_in_epoch,
        epochs=epochs,
    )

    print(pd.DataFrame(history.history))

    return mod

# Train with default settings, then show the learned trainable variables.
mod = fit()
pprint(mod.get_state_tree()['trainable_variables'], indent=2)

Epoch 1/50
[1m243/243[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m21s[0m 87ms/step - loss: 0.2544
Epoch 2/50
[1m243/243[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 310us/step - loss: 0.2474    
Epoch 3/50
[1m243/243[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 316us/step - loss: 0.2409  
Epoch 4/50
[1m243/243[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 307us/step - loss: 0.2347  
Epoch 5/50
[1m243/243[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 318us/step - loss: 0.2290  
Epoch 6/50
[1m243/243[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 323us/step - loss: 0.2236  
Epoch 7/50
[1m243/243[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 322us/step - loss: 0.2185   
Epoch 8/50
[1m243/243[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 339us/step - loss: 0.2138  
Epoch 9/50
[1m243/243[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 342us/step - loss: 0.2094   
Epoch 10/50
[1m243/243[0m [32m━━━━━━━━━━━━━━━━━━━━

In [11]:
# Persist the trained model's weights (Keras `.weights.h5` format).
mod.save_weights("models/kbd_model_new.weights.h5")

In [9]:
# Scratch check: softmax of a hard-coded 4-vector.
# NOTE(review): these values look copied from the printed trainable variables;
# TODO read them from `mod` directly so this stays in sync after retraining.
ops.softmax(ops.array([-0.66441685,  0.63354456,  0.50821155,  0.526713  ]))

Array([0.08942068, 0.32744282, 0.28887108, 0.29426536], dtype=float32)

In [10]:
# Scratch check: sigmoid of a hard-coded scalar.
# NOTE(review): presumably a value copied from the trained variables; TODO read it from `mod`.
ops.sigmoid(0.1083)

Array(0.5270485, dtype=float32, weak_type=True)