In [None]:
!pip install keras-nlp==0.10.0
!pip install keras==2.15.0
!pip install tensorflow==2.15.0
!pip install faiss-cpu==1.10.0

In [None]:
import tensorflow as tf
from tensorflow import keras
from keras.layers import *
import keras_nlp
from keras import backend

import re
import requests
import numpy as np
import random
import math
import string
import nltk
import json

In [None]:
from transformers import AutoTokenizer
from tokenizers import AddedToken

# Load the pretrained tokenizer and register two extra tokens: a literal newline
# and the "<s>" marker that is prepended to every text before encoding.
# NOTE(review): the interactive loop at the end of the notebook stops generating at
# id `vocab_size - 2`; that presumably is the newline added first here — keep the
# order of these two add_tokens calls unless that loop is updated too.
tokenizer = AutoTokenizer.from_pretrained('google/t5-v1_1-base')
tokenizer.add_tokens(AddedToken("\n", normalized=False))
tokenizer.add_tokens(AddedToken("<s>", normalized=False))
vocab_size = len(tokenizer.get_vocab().keys())
print("vocab_size:", vocab_size)
# Print the numeric id (not the token string) so the value matches its label and
# the `tokenizer.pad_token_id` used for padding/masking everywhere else.
print("pad token id:", tokenizer.pad_token_id)

In [None]:
import spacy

# Build `carry_toks`: the set of vocabulary ids whose part-of-speech tag is in
# `selected`. These ids are later used to weight training targets and to decide,
# at generation time, when to prefer a token already present in the context.
nlp = spacy.load("en_core_web_lg")
# Maximum text length spaCy will accept; the whole vocabulary is tagged in one call below.
nlp.max_length = 2000000

# Reference list of POS tags; only `selected` is read by the loop below.
all_pos = {'PART', 'INTJ', 'SPACE', 'AUX', 'PUNCT', 'SYM', 'X', 'SCONJ', 'NUM', 'NOUN', 'ADP', 'ADJ', 'ADV', 'PRON', 'DET', 'CCONJ', 'PROPN', 'VERB'}
#selected = {'NUM', 'NOUN', 'ADJ', 'PROPN'}  # For training
selected = {'NUM', 'PROPN'}                  # For inference

# (token_string, id) pairs sorted by id, then joined one per line (with the
# sentencepiece "▁" marker stripped) so spaCy can tag the whole vocabulary at once.
alltoks = sorted(list(tokenizer.get_vocab().items()), key=lambda x:x[1])
all_toks_text = "\n".join([t[0].replace("▁", "") for t in alltoks])

doc = nlp(all_toks_text)

carry_toks = set()

# spaCy may split one vocabulary entry into several tokens, so the two counts can differ.
print(len(doc), len(alltoks))

# `i` points at the vocabulary entry the current spaCy token is assumed to belong to.
# It advances by at most one per spaCy token, whenever the token text is not a
# substring of the current entry.
# NOTE(review): nothing bounds `i` by len(alltoks); this relies on the alignment
# never running past the end of the vocabulary.
i = 0
for ii, token in enumerate(doc):
    if str(token) in alltoks[i][0]: pass
    else: i += 1
    # Keep the entry only if the token matches it, has a selected POS tag, and the
    # entry's position is above 100 (the first 101 entries are always skipped).
    if str(token) in alltoks[i][0] and token.pos_ in selected and i > 100:
        # Proper nouns are accepted only when the vocabulary entry starts with an uppercase letter.
        if (token.pos_ != "PROPN" or alltoks[i][0].replace("▁", "")[0].isupper()):
            carry_toks.add(alltoks[i][1])
print(len(carry_toks))

In [None]:
# Read the raw corpus. The context manager guarantees the handle is closed even
# if reading or JSON parsing raises.
with open("dataset_rpc.json", "r") as file:
    data = json.load(file)

# Mirror the two-level layout of the JSON file (subset -> split -> texts) while
# replacing each text by its token ids. Every text is joined into a single string
# and prefixed with the "<s>" marker before encoding.
dataset = {}
for subset in data:
    dataset[subset] = {}
    for subsubset in data[subset]:
        dataset[subset][subsubset] = []
        for text in data[subset][subsubset]:
            text = "".join(text)
            text = tokenizer.encode("<s>" + text, add_special_tokens=False)
            dataset[subset][subsubset].append(text)

# Flatten the "train" / "test" splits of every subset into two lists of token-id lists.
train = [text for splits in dataset.values() for text in splits["train"]]
test  = [text for splits in dataset.values() for text in splits["test"]]
print(len(train), len(test))

In [None]:
# Number of token positions fed to the model; texts are cut/padded to input_size+1
# so that inputs (all but last id) and targets (all but first id) both have this length.
input_size  = 320 #512
# Width of the shared embedding table and of every hidden layer in the model.
embed_dim   = 128
# Weight given to every non-padding target that is not a repeated carry token
# (repeated carry tokens get 1.0, padding gets 0.0).
not_carry_w = 0.5

In [None]:
# Clip every token sequence to input_size+1 ids, then build right-padded copies
# of exactly that length. The clipped (unpadded) lists are kept under the
# original names because later cells iterate over them.
seq_len = input_size + 1
pad_id = tokenizer.pad_token_id

train = [ids[:seq_len] for ids in train]
train_padded = [ids + [pad_id] * (seq_len - len(ids)) for ids in train]

test = [ids[:seq_len] for ids in test]
test_padded = [ids + [pad_id] * (seq_len - len(ids)) for ids in test]

Per-token weights help the model focus, during training, on tokens such as names, numbers and nouns that should be carried over from earlier in the text.

In [None]:
# One weight per position of every (clipped) training text:
#   1.0          -> a carry token that already appeared earlier in the same text
#   not_carry_w  -> any other non-padding token (including a carry token's first occurrence)
#   0.0          -> right padding up to input_size+1 positions
weights = []

for token_ids in train:
    seen_carry = set()
    row = []
    for tok_id in token_ids:
        is_carry = tok_id in carry_toks
        # Stop at the first padding id (unless it is itself a carry token).
        if not is_carry and tok_id == tokenizer.pad_token_id:
            break
        if is_carry and tok_id in seen_carry:
            row.append(1.0)
        else:
            if is_carry:
                seen_carry.add(tok_id)
            row.append(not_carry_w)
    row.extend([0.0] * (input_size + 1 - len(row)))
    weights.append(row)

In [None]:
# Materialise the padded python lists as dense tensors with input_size+1 columns:
# X = training ids, T = test ids, W = per-position training weights.
n_cols = input_size + 1
X = tf.constant(train_padded, dtype=tf.int32,   shape=(len(train_padded), n_cols))
T = tf.constant(test_padded,  dtype=tf.int32,   shape=(len(test_padded),  n_cols))
W = tf.constant(weights,      dtype=tf.float32, shape=(len(weights),      n_cols))

## Create Model
Defining the embedding layer, differential attention layer and transformer model architecture 

In [None]:
def masked_accuracy(y_true, y_pred, padding_token=tokenizer.pad_token_id):
    """Token-level accuracy that ignores padding positions.

    Args:
        y_true: integer target ids.
        y_pred: per-class scores; the predicted id is the argmax over the last axis.
        padding_token: target id to exclude from the average (defaults to the
            tokenizer's pad id, captured when the function is defined).

    Returns:
        Scalar float tensor: matches on non-padding positions divided by the
        number of non-padding positions, or 0.0 when there are none.
    """
    y_true = tf.cast(y_true, tf.int32)
    y_pred = tf.cast(tf.argmax(y_pred, axis=-1), tf.int32)
    mask = tf.cast(tf.not_equal(y_true, padding_token), tf.float32)
    matches = tf.cast(tf.equal(y_true, y_pred), tf.float32)
    # divide_no_nan returns 0 instead of NaN when every target is padding, so a
    # single degenerate batch cannot poison the running metric average.
    return tf.math.divide_no_nan(tf.reduce_sum(matches * mask), tf.reduce_sum(mask))

In [None]:
class SharedEmbedding(tf.keras.layers.Layer):
    """Embedding table shared between the input lookup and the output classifier.

    A single (vocab_size, embed_dim) weight matrix is used in two ways, selected
    by the `mode` argument of `call`:
      * "embedding": look up rows for integer token ids.
      * "classify":  project hidden vectors onto the transposed table and apply
        a softmax, yielding a probability distribution over the vocabulary.
    """

    def __init__(self, vocab_size, embed_dim, **kwargs):
        super(SharedEmbedding, self).__init__(**kwargs)
        self.vocab_size = vocab_size
        self.embed_dim = embed_dim

    def build(self, input_shape):
        self.shared_weights = self.add_weight(
            shape=(self.vocab_size, self.embed_dim),
            initializer='random_normal',
            trainable=True,
            name='shared_weights'
        )
        super(SharedEmbedding, self).build(input_shape)

    def call(self, inputs, mode='embedding'):
        """Apply the shared table.

        Args:
            inputs: integer ids for mode "embedding"; hidden vectors whose last
                dimension is embed_dim for mode "classify".
            mode: "embedding" or "classify".

        Returns:
            Embedding vectors, or softmax probabilities over the vocabulary.

        Raises:
            ValueError: if `mode` is neither "embedding" nor "classify".
        """
        if mode == 'embedding':
            return tf.nn.embedding_lookup(self.shared_weights, inputs)
        if mode == 'classify':
            return tf.nn.softmax(tf.matmul(inputs, self.shared_weights, transpose_b=True), axis=-1)
        # Fail loudly instead of silently returning None for a mistyped mode.
        raise ValueError(f"Unknown mode {mode!r}; expected 'embedding' or 'classify'.")

    def get_config(self):
        """Return the constructor arguments so the layer can be saved and reloaded."""
        config = super(SharedEmbedding, self).get_config()
        config.update({"vocab_size": self.vocab_size, "embed_dim": self.embed_dim})
        return config

In [None]:
class DiffAttention(keras.layers.Layer):
    """Causal attention layer that subtracts one attention map from another.

    Queries and keys are split into two groups of embed_dim//2 features. Each
    group produces its own softmax attention map; the layer outputs
    (map_1 - lambda * map_2) @ V, where lambda is built from learned vectors
    plus a depth-dependent constant `lambda_init`.
    """

    def __init__(self, depth, **kwargs):
        """Args:
            depth: number used only to compute `lambda_init`
                (0.8 - 0.6 * exp(-0.3 * depth)); the model cell passes the
                1-based index of the block.
        """
        super(DiffAttention, self).__init__(**kwargs)
        self.lambda_init = 0.8 - 0.6 * math.exp(-0.3 * depth)

    def build(self, input_shape):
        """Create the mask, shift ranges and trainable weights from the input shape.

        Reads the last two entries of `input_shape` as (sequence length, feature size).
        """
        self.embed_dim = input_shape[-1]
        self.input_size = input_shape[-2]
        # Additive causal mask: 0.0 on and below the diagonal, -inf above it.
        self.mask = tf.where(tf.linalg.band_part(tf.ones((input_shape[-2], input_shape[-2])), -1, 0) == 1.0, 0.0, float("-inf"))
        # Per-row shift amounts for roll_embeddings: row i is shifted by -(i+1).
        self.range_do = -tf.range(input_shape[-2])-1
        # Opposite shifts (+(i+1)); not referenced anywhere else in this class.
        self.range_undo = tf.range(input_shape[-2])+1
        # Square projection matrices for queries, keys and values.
        self.Q = self.add_weight(name='kernelQ',
                                      shape=(input_shape[-1], input_shape[-1]),
                                      initializer='uniform',
                                      trainable=True)
        self.K = self.add_weight(name='kernelK',
                                      shape=(input_shape[-1], input_shape[-1]),
                                      initializer='uniform',
                                      trainable=True)
        self.V = self.add_weight(name='kernelV',
                                      shape=(input_shape[-1], input_shape[-1]),
                                      initializer='uniform',
                                      trainable=True)

        # Four learned vectors; their pairwise dot products parameterise lambda in call().
        initializer = tf.keras.initializers.RandomNormal(mean=0.0, stddev=0.1)
        self.lambda_q1 = self.add_weight(
            shape=(input_shape[-1],), initializer=initializer, trainable=True, name="lambda_q1"
        )
        self.lambda_k1 = self.add_weight(
            shape=(input_shape[-1],), initializer=initializer, trainable=True, name="lambda_k1"
        )
        self.lambda_q2 = self.add_weight(
            shape=(input_shape[-1],), initializer=initializer, trainable=True, name="lambda_q2"
        )
        self.lambda_k2 = self.add_weight(
            shape=(input_shape[-1],), initializer=initializer, trainable=True, name="lambda_k2"
        )
        
        super(DiffAttention, self).build(input_shape)

    def roll_embeddings(self, tensor, shift_values):
        """Circularly shift the last axis of a rank-3 tensor by a per-row amount.

        For every batch b and row i: result[b, i, j] = tensor[b, i, (j + shift_values[i]) % last_dim].

        NOTE(review): when the static batch dimension is None the input is
        returned unchanged, i.e. no shift is applied. call() reshapes with -1
        before calling this method, so the static batch size is unknown whenever
        the incoming batch size is not statically known — please confirm the
        shift is actually applied during fit/predict, not only in eager runs.
        """
        # Static shape; `embed_dim` here is simply the size of the last axis of `tensor`.
        batch_size, time_size, embed_dim = tensor.shape
        if batch_size is None: return tensor
        # Broadcast the per-row shift to every batch element and every column.
        shift_matrix   = tf.reshape(shift_values, (1, -1, 1))
        shift_matrix   = tf.tile(shift_matrix, [batch_size, 1, embed_dim])
        # Column indices 0..last_dim-1 repeated for every (batch, row).
        indices        = tf.range(embed_dim)
        indices_matrix = tf.tile(indices, [batch_size * time_size])
        indices_matrix = tf.reshape(indices_matrix, (batch_size, time_size, embed_dim))
        # Source column for each destination column, wrapped around.
        new_indices    = (indices_matrix + shift_matrix) % embed_dim     
        rolled_tensor  = tf.gather(tensor, new_indices, batch_dims=2)
        return rolled_tensor

    def call(self, x, pos):
        """Compute the differential attention output.

        Args:
            x: input whose last two dimensions are (input_size, embed_dim).
            pos: positional term multiplied with the queries; assumed to have
                trailing shape (2, input_size, embed_dim//2), as produced in the
                model-building cell — TODO confirm.

        Returns:
            Tensor with the same trailing shape as `x`.
        """
        v    = x @ self.V
        # Split projected queries/keys into 2 groups: (-1, 2, input_size, embed_dim//2).
        q    = tf.transpose(tf.reshape(x @ self.Q, (-1, self.input_size, 2, self.embed_dim//2)), perm=[0, 2, 1, 3])
        k    = tf.transpose(tf.reshape(x @ self.K, (-1, self.input_size, 2, self.embed_dim//2)), perm=[0, 2, 1, 3])
        # Content scores (query . key) and positional scores (query . pos).
        atti = tf.matmul(q, k,   transpose_b=True)
        attp = tf.matmul(q, pos, transpose_b=True)
        # Shift each row of the positional scores (see roll_embeddings for the caveat).
        attp = self.roll_embeddings(tf.reshape(attp, (-1, self.input_size, self.input_size)), self.range_do)
        attp = tf.reshape(attp, (-1, 2, self.input_size, self.input_size))
        att  = atti + attp
        # Scale by sqrt(embed_dim) (the full width, not the per-group width), add the causal mask, normalise.
        att  = tf.nn.softmax((att / math.sqrt(self.embed_dim)) + self.mask, axis=-1)
        att1 = att[:, 0]
        att2 = att[:, 1]
        
        # lambda = exp(q1 . k1) - exp(q2 . k2) + lambda_init
        lambda_1 = tf.math.exp(tf.reduce_sum(self.lambda_q1 * self.lambda_k1, axis=-1))
        lambda_2 = tf.math.exp(tf.reduce_sum(self.lambda_q2 * self.lambda_k2, axis=-1))
        lambda_full = lambda_1 - lambda_2 + self.lambda_init
        # Differential attention map: first map minus lambda times the second.
        att = att1 - lambda_full * att2

        out = att @ v
        # Constant rescaling of the output by (1 - lambda_init).
        out = out * (1 - self.lambda_init)
        return out


In [None]:
# --- Inputs and embeddings ---------------------------------------------------
inputs = Input(shape=(input_size, ), dtype=tf.int32)
emb_layer = SharedEmbedding(vocab_size, embed_dim)
pos_layer = keras_nlp.layers.PositionEmbedding(input_size)

# Layer-normalised token embeddings form the initial residual stream.
ins = LayerNormalization()(emb_layer(inputs, mode="embedding"))
x = ins
# Position embeddings split into the same two groups DiffAttention uses for q/k.
pos_src = pos_layer(x)
pos = tf.transpose(tf.reshape(pos_src, (-1, input_size, 2, embed_dim//2)), perm=[0, 2, 1, 3])

# --- Stack of residual blocks --------------------------------------------------
# Each block adds a scaled, normalised attention branch and a scaled, normalised
# two-layer GELU feed-forward branch to the residual stream.
n_blocks = 12
residual_scale = (2*n_blocks)**-0.5
for depth in range(n_blocks):
    attn_out = DiffAttention(depth+1)(x, pos)
    x = x + residual_scale * LayerNormalization()(attn_out)
    hidden = Dense(embed_dim, activation="gelu")(x)
    hidden = Dense(embed_dim, activation="gelu")(hidden)
    x = x + residual_scale * LayerNormalization()(hidden)

# --- Output: probabilities over the vocabulary via the shared table ------------
probs = emb_layer(x, mode="classify")

model = keras.Model(inputs=inputs, outputs=probs)
model.compile(
    optimizer=keras.optimizers.AdamW(learning_rate=0.001),
    loss=keras.losses.SparseCategoricalCrossentropy(ignore_class=tokenizer.pad_token_id),
    metrics=[masked_accuracy, keras_nlp.metrics.Perplexity(mask_token_id=tokenizer.pad_token_id)],
)

model.summary()

In [None]:
# Train for 30 single-epoch rounds, saving after each one.
# On odd rounds within the first 20, every position whose base weight is below
# 1.0 is down-weighted to 0.05 so the loss concentrates on repeated carry tokens;
# on all other rounds every position is weighted 1.0.
for epoch in range(30):

    focus_on_carry = epoch < 20 and epoch % 2 == 1
    if focus_on_carry:
        sample_w = tf.where(W < 1.0, 0.05, 1.0)
    else:
        sample_w = tf.ones_like(W)

    # Inputs are all ids but the last; targets (and their weights) are shifted by one.
    model.fit(
        x=X[:, :-1],
        y=X[:, 1:],
        sample_weight=sample_w[:, 1:],
        batch_size=60,
        epochs=1,
        shuffle=True,
    )

    model.save("rpc.keras")

    print(f"Epoch {epoch+1} completed")

## Load model
Loading the model from file and creating helper function to vectorize texts

In [None]:
# Custom classes/functions Keras needs in order to rebuild the saved model.
custom_objects = {
    "DiffAttention" : DiffAttention,
    "SharedEmbedding" : SharedEmbedding,
    "masked_accuracy" : masked_accuracy
}
model = keras.models.load_model("rpc.keras", custom_objects=custom_objects)

# The encoder maps the model's first layer input to the output of the last entry
# of `model.layers`.
# NOTE(review): this relies on Keras' layer ordering — presumably the last entry
# is the final residual add (hidden states), not the shared classifier; confirm
# with the summary printed below.
encoder = keras.Model(inputs=model.layers[0].input, outputs=model.layers[-1].output)
encoder.summary()

In [None]:
def vectorize_texts(all_texts, batch_size=128):
    """Encode token-id sequences into one vector per token position.

    Args:
        all_texts: list of token-id lists. Sequences longer than `input_size`
            are cut to their LAST `input_size` ids (the most recent context),
            since the encoder accepts exactly `input_size` positions.
        batch_size: number of sequences sent to the encoder per predict call.

    Returns:
        2-D numpy array with one row per (kept) token, sequences concatenated in
        input order. Padding positions are dropped. An empty input yields an
        array with zero rows.
    """
    vects = []
    for start in range(0, len(all_texts), batch_size):
        # Left-truncate so over-long sequences no longer produce negative padding
        # (which previously broke the fixed-shape tensor below).
        texts = [text[-input_size:] for text in all_texts[start:start+batch_size]]
        toks = [text + ([tokenizer.pad_token_id] * (input_size - len(text))) for text in texts]
        toks = tf.constant(toks, shape=(len(toks), input_size))
        vect = encoder.predict(toks, verbose=0)
        # Keep only the rows that correspond to real (non-padding) tokens.
        for v, t in zip(vect, texts):
            vects.append(v[:len(t), :])
    if not vects:
        # tf.concat refuses an empty list; return a well-formed empty result instead.
        return np.zeros((0, encoder.output_shape[-1]), dtype=np.float32)
    return tf.concat(vects, axis=0).numpy()

# Quick smoke test: encode two short prompts and display the resulting vectors.
sample_sentences = [
    "<s>Hello there. how are you?",
    "<s>Hello there. how have you been?",
]
vectorize_texts([tokenizer.encode(s, add_special_tokens=False) for s in sample_sentences])

## NGT Based Index

In [None]:
# Build the NGT library from source and install its Python bindings (ngtpy).
# NOTE(review): the absolute /kaggle/working paths and the ngt-2.2.4 archive name
# are hard-coded; this cell only works unchanged in that environment and with
# that version of the cloned repository.
!git clone https://github.com/jpmag7/NGT.git
%cd NGT
!mkdir build
%cd build
# Configure with the shared-memory allocator enabled, then compile and install system-wide.
!cmake -DNGT_SHARED_MEMORY_ALLOCATOR=ON ..
!make
!make install
# Refresh the dynamic linker cache so the freshly installed library is found.
!ldconfig /usr/local/lib
# Package and install the Python bindings.
%cd /kaggle/working/NGT/python
!python3 setup.py sdist
!pip3 install dist/ngt-2.2.4.tar.gz
# Return to the working directory and remove the source checkout.
%cd /kaggle/working
!rm -r NGT

In [None]:
from tqdm import tqdm
import ngtpy
import json

# Build an NGT index holding one vector per training token position, plus a
# parallel list `all_toks` with the token that FOLLOWS each position.
size = 30_000          # maximum number of training texts to index
batch_size = 2048      # texts vectorised per step

index_path = "index"
ngtpy.create(index_path, embed_dim)
index = ngtpy.Index(index_path)

all_toks = []

# Never iterate past the end of `train`: an empty batch would make
# vectorize_texts fail when the corpus has fewer than `size` texts.
n_texts = min(size, len(train))

for start in tqdm(range(0, n_texts, batch_size)):

    batch = train[start:min(n_texts, start+batch_size)]

    # Vectors for every position except the last one of each text ...
    prompt_embeds = vectorize_texts([t[:-1] for t in batch])

    # ... paired with the next-token targets (every id except the first one).
    all_toks.extend(t for text in batch for t in text[1:])

    if prompt_embeds.shape[0] > 0: index.batch_insert(prompt_embeds)

# The token list is stored inside the index directory, next to the index files.
with open(f"{index_path}/all_toks.json", "w") as f:
    f.write(json.dumps(all_toks))

print("building objects...")
index.build_index()
print("saving the index...")
index.save()

In [None]:
# Re-open the saved NGT index read-only together with its next-token list.
index_path = "/kaggle/working/index"
index = ngtpy.Index(index_path, read_only=True)

# The NGT build cell writes all_toks.json INSIDE the index directory, so read it
# from there; the bare "all_toks.json" in the working directory is only produced
# by the flat-index cell.
with open(f"{index_path}/all_toks.json", "r") as f:
    all_toks = json.loads(f.read())

## Flat Index

In [None]:
from tqdm import tqdm

# Build an exact (brute-force L2) FAISS index with one vector per training token
# position, plus a parallel list `all_toks` with the token that follows each position.
size = 30_000
# Never iterate past the end of `train`: an empty batch would make
# vectorize_texts fail when the corpus has fewer than `size` texts.
n_texts = min(size, len(train))

all_toks = [t for text in train[:n_texts] for t in text[1:]]
with open("all_toks.json", "w") as f:
    f.write(json.dumps(all_toks))

embeds = []
batch_size = 2048
for start in tqdm(range(0, n_texts, batch_size)):
    prompt_embeds = vectorize_texts([t[:-1] for t in train[start:min(n_texts, start+batch_size)]])
    embeds.append(prompt_embeds)
# FAISS expects a C-contiguous float32 numpy array, so stay in numpy rather than
# handing it a TensorFlow tensor.
embeds = np.ascontiguousarray(np.concatenate(embeds, axis=0), dtype=np.float32)

import faiss
index = faiss.IndexFlatL2(embed_dim)
index.add(embeds)
faiss.write_index(index, "index.faiss")

## Test

In [None]:
# Interactive retrieval-based generation.
# Each user line is appended to the running context; the model then emits tokens
# one at a time (nearest neighbour of the last position's vector in the index)
# until the stop id is produced.
# NOTE(review): there is no cap on the number of generated tokens; if the stop id
# is never retrieved the inner loop does not terminate.
stop_tok = vocab_size - 2   # id at which a reply ends (same value the loop compared against before)

enc_text = tokenizer.encode("<s>", add_special_tokens=False)
sents = []
while True:
    try:
        user = input(f"{len(enc_text)}>") + "\n"
    except (EOFError, KeyboardInterrupt):
        # Closing the input stream or interrupting the prompt ends the session cleanly.
        break
    user = tokenizer.encode(user, add_special_tokens=False)
    sents.append(user)
    enc_text += user
    new_text = tokenizer.decode(enc_text)
    text = new_text
    tok = 0
    sents.append([])
    while tok != stop_tok:
        # The encoder accepts at most `input_size` positions: feed only the most
        # recent ones so long conversations do not crash vectorize_texts.
        xq = vectorize_texts([enc_text[-input_size:]])[-1]

        # If using faiss index
        _id = index.search(xq.reshape((1, -1)), 1)[1][0][0]
        
        # If using ngt index
        #_id = index.search(xq, size=1, epsilon=1)[0][0]
        
        if all_toks[_id] in carry_toks:
            # The retrieved token is a carry token: prefer the model's own
            # best-scoring vocabulary entry when it already occurs in the context.
            # int(...) keeps numpy integers out of the token list.
            tmp = int(tf.argmax(tf.matmul(xq.reshape((1, -1)), encoder.layers[1].shared_weights, transpose_b=True), axis=-1).numpy()[0])
            if tmp in enc_text: tok = tmp
            else: tok = all_toks[_id]
        else:
            tok = all_toks[_id]

        sents[-1].append(tok)
        enc_text.append(tok)
        # Decode the whole context and print only the newly added suffix.
        new_text = tokenizer.decode(enc_text)
        print(new_text[len(text):], end="")
        text = new_text
    print("")
        