In [None]:
!pip install keras-nlp==0.10.0
!pip install faiss-cpu
!pip install keras==2.15.0
!pip install tensorflow==2.15.0

In [1]:
import tensorflow as tf
from tensorflow import keras
from keras.layers import *
import keras_nlp

import math
import spacy
import numpy as np
import random
import json

In [None]:
from transformers import AutoTokenizer
from tokenizers import AddedToken

# T5 v1.1 tokenizer extended with two extra tokens: a literal newline and "<s>"
# (the "<s>" marker is prepended to each training document when the dataset is tokenized).
tokenizer = AutoTokenizer.from_pretrained('google/t5-v1_1-base')
for extra_token in ("\n", "<s>"):
    tokenizer.add_tokens(AddedToken(extra_token, normalized=False))
vocab_size = len(tokenizer.get_vocab())
print("vocab_size:", vocab_size)

# Round-trip a sample string to inspect how it is split into token ids.
text = "<s>Hello hello, how are you today? good, just understanding tokenizers...\n"
tokens = tokenizer.encode(text, add_special_tokens=False)
print(tokens)
for token_id in tokens:
    print(token_id, tokenizer.decode([token_id]))
text = tokenizer.decode(tokens, skip_special_tokens=True)
print(text)
print(tokenizer.pad_token)

In [None]:
DATASET_PATH = "dataset.json"

# Read the JSON dataset. The context manager closes the file even if parsing
# fails, and an explicit UTF-8 encoding avoids depending on the platform's
# default encoding (JSON text is UTF-8).
with open(DATASET_PATH, "r", encoding="utf-8") as file:
    dataset = json.load(file)

# Each dataset entry is joined with "".join (presumably a list of text chunks —
# TODO confirm the schema), prefixed with the "<s>" start marker, and tokenized
# without the tokenizer's own special tokens.
data = [
    tokenizer.encode("<s>" + "".join(entry), add_special_tokens=False)
    for entry in dataset
]

In [None]:
# Load spaCy's large English pipeline. max_length is raised because the whole
# tokenizer vocabulary is tagged below as a single document.
nlp = spacy.load("en_core_web_lg")
nlp.max_length = 2000000

# spaCy coarse POS tags, for reference:
# 'PART', 'INTJ', 'SPACE', 'AUX', 'PUNCT', 'SYM', 'X', 'SCONJ', 'NUM', 'NOUN', 'ADP', 'ADJ', 'ADV', 'PRON', 'DET', 'CCONJ', 'PROPN', 'VERB'
# Vocabulary tokens tagged with one of these content-word tags become "carry" tokens.
selected = {'NUM', 'NOUN', 'ADJ', 'ADV', 'PROPN'}

# (token string, id) pairs sorted by id. The SentencePiece "▁" marker is stripped
# and the pieces are joined one per line so spaCy can POS-tag them.
# NOTE: `all_toks` is rebound later in the notebook to a list of next-token ids.
all_toks = sorted(list(tokenizer.get_vocab().items()), key=lambda x:x[1])
all_toks_text = "\n".join([t[0].replace("▁", "") for t in all_toks])

doc = nlp(all_toks_text)

# Ids of vocabulary tokens whose POS tag is in `selected`.
carry_toks = set()

# Compare spaCy's token count with the vocabulary size (they need not match).
print(len(doc), len(all_toks))

# Heuristic alignment of spaCy tokens back to vocabulary entries: `i` moves to the
# next vocabulary entry whenever the current spaCy token is not a substring of
# entry `i`. Vocabulary positions 0-100 (`i > 100`) are never selected —
# presumably to skip special/very frequent pieces; TODO confirm intent.
# NOTE(review): spaCy may split or merge pieces (empty and "\n" entries become
# whitespace), so the alignment can drift, and `all_toks[i]` is not
# bounds-checked. Verify if the tokenizer or spaCy model changes. `ii` is unused.
i = 0
for ii, token in enumerate(doc):
    if str(token) in all_toks[i][0]: pass
    else: i += 1
    if str(token) in all_toks[i][0] and token.pos_ in selected and i > 100:
        carry_toks.add(all_toks[i][1])
print(len(carry_toks))

In [None]:
data_size = len(data)

def get_random_sample(input_size):
    """Draw a random window of token ids from ``data`` plus per-token loss weights.

    A random document is cut at a random end position, and at most the last
    ``input_size`` tokens before the cut are kept. Shorter windows are
    right-padded with ``tokenizer.pad_token_id``.

    Per-token weights:
      * 1.0 -- a token in ``carry_toks`` that already occurred earlier in the window;
      * 0.6 -- any other non-padding token (including a carry token's first occurrence);
      * 0.0 -- padding, and every position from the first padding id onward.

    Args:
        input_size: length of both returned lists.

    Returns:
        ``(x, weights)``: token ids and float weights, each of length ``input_size``.
    """
    text = data[random.randint(0, data_size - 1)]
    # random.randint is inclusive at both ends, while text[:pos] excludes pos, so
    # the upper bound must be len(text). Otherwise a document's final token could
    # never be part of a window, and so could never be a training target. The lower
    # bound keeps windows full-length whenever the document is long enough.
    pos = random.randint(min(input_size, len(text)), len(text))
    # Explicit start index (instead of text[:pos][-input_size:]) so that
    # input_size == 0 yields an empty window rather than the whole prefix.
    x = text[max(0, pos - input_size):pos]

    in_past = set()
    weights = []
    for t in x:
        if t in carry_toks:
            if t in in_past:
                weights.append(1.0)
            else:
                in_past.add(t)
                weights.append(0.6)
        elif t != tokenizer.pad_token_id:
            weights.append(0.6)
        else:
            break
    x = x + [tokenizer.pad_token_id] * (input_size - len(x))
    weights = weights + [0.0] * (input_size - len(weights))
    return x, weights


# Sanity check: decode one sample as model input (x[:-1]) and shifted target (x[1:]).
x, w = get_random_sample(256)
print(tokenizer.decode(x[:-1]), "\n>", tokenizer.decode(x[1:]))

In [None]:
def get_train_batch(batch_size, input_size):
    """Stack ``batch_size`` samples from ``get_random_sample`` into dense tensors.

    Returns:
        ``(X, W)``: an int32 tensor of token ids and a float32 tensor of loss
        weights, both shaped ``(batch_size, input_size)``.
    """
    samples = [get_random_sample(input_size) for _ in range(batch_size)]
    token_rows = [tokens for tokens, _ in samples]
    weight_rows = [weights for _, weights in samples]
    X = tf.constant(token_rows, shape=(len(token_rows), input_size), dtype=tf.int32)
    W = tf.constant(weight_rows, shape=(len(weight_rows), input_size), dtype=tf.float32)
    return X, W

In [6]:
def roll_embeddings(tensor, shift_values):
    """Circularly shift each row of a rank-3 tensor along its last axis.

    For every batch index ``b`` and row ``t``::

        out[b, t, e] = tensor[b, t, (e + shift_values[t]) mod L]     # L = last-axis size

    Args:
        tensor: rank-3 tensor ``(batch, rows, L)``. In ``Attention`` it is a
            ``(batch, time, time)`` score/weight matrix.
        shift_values: integer vector with one shift per row. Negative shifts are
            allowed because ``%`` on tensors is a floor-mod.

    Returns:
        A tensor with the same shape and dtype as ``tensor``.
    """
    # Use the dynamic shape so the roll is applied even when the static batch
    # dimension is unknown (None), e.g. while building the functional model or in
    # a tf.function traced for a partial batch. Returning the input unchanged in
    # that case would silently disable the relative-position shift.
    shape = tf.shape(tensor)
    last_size = shape[-1]
    shifts = tf.reshape(tf.cast(shift_values, tf.int32), (1, -1, 1))   # (1, rows, 1)
    columns = tf.reshape(tf.range(last_size), (1, 1, -1))              # (1, 1, L)
    new_indices = tf.broadcast_to((columns + shifts) % last_size, shape)
    return tf.gather(tensor, new_indices, batch_dims=2)

In [7]:
class Attention(keras.layers.Layer):
    """Single-head causal self-attention with an additive relative-position term.

    ``call(x, pos)`` adds two score matrices: content scores ``(xQ)(xK)^T`` and
    position scores ``(xQ) pos^T``. The position scores are rolled per query row
    by ``roll_embeddings``, so for key ``j`` at query ``t`` the score uses
    ``pos[(j - t - 1) mod T]``, which depends only on the offset ``t - j`` for
    causal keys. After the causal softmax, the weights are applied to ``xV``.
    The weights are also rolled back and applied to ``pos``, and the two results
    are summed.
    """

    def __init__(self, **kwargs):
        super().__init__(**kwargs)

    def build(self, input_shape):
        seq_len = input_shape[-2]
        self.embed_size = input_shape[-1]
        # Additive causal mask: 0 on and below the diagonal, -inf above it.
        lower = tf.linalg.band_part(tf.ones((seq_len, seq_len)), -1, 0)
        self.mask = tf.where(lower == 1.0, 0.0, float("-inf"))
        # Per-row shifts for roll_embeddings: row t is rolled by -(t + 1) before the
        # softmax and by +(t + 1) after it.
        self.range_do = -tf.range(seq_len) - 1
        self.range_undo = tf.range(seq_len) + 1
        # Square projection kernels, created in the order Q, K, V (names kernelQ/K/V).
        for attr in ("Q", "K", "V"):
            kernel = self.add_weight(name='kernel' + attr,
                                     shape=(self.embed_size, self.embed_size),
                                     initializer='uniform',
                                     trainable=True)
            setattr(self, attr, kernel)
        super().build(input_shape)

    def call(self, x, pos, ret=False):
        # `ret` is accepted for signature compatibility but is not used.
        query = x @ self.Q
        key = x @ self.K
        value = x @ self.V

        content_scores = tf.matmul(query, key, transpose_b=True)
        position_scores = roll_embeddings(tf.matmul(query, pos, transpose_b=True), self.range_do)

        scale = math.sqrt(self.embed_size)
        weights = tf.nn.softmax((content_scores + position_scores) / scale + self.mask, axis=-1)

        content_out = weights @ value
        position_out = roll_embeddings(weights, self.range_undo) @ pos
        return content_out + position_out

In [8]:
def masked_accuracy(y_true, y_pred, padding_token=tokenizer.pad_token_id):
    """Token-level accuracy that ignores padding targets.

    Args:
        y_true: integer target ids.
        y_pred: per-token scores with the vocabulary on the last axis; the
            argmax is taken as the prediction.
        padding_token: target id excluded from the accuracy (defaults to the
            tokenizer's pad id, evaluated when this function is defined).

    Returns:
        Scalar float32: the fraction of non-padding targets whose argmax
        prediction matches. Returns 0.0, not NaN, when every target is padding.
    """
    y_true = tf.cast(y_true, tf.int32)
    y_pred = tf.cast(tf.argmax(y_pred, axis=-1), tf.int32)

    mask = tf.cast(tf.not_equal(y_true, padding_token), tf.float32)
    matches = tf.cast(tf.equal(y_true, y_pred), tf.float32)

    # divide_no_nan yields 0 for an all-padding batch instead of 0/0 = NaN.
    return tf.math.divide_no_nan(tf.reduce_sum(matches * mask), tf.reduce_sum(mask))

In [None]:
input_size = 512
embed_size = 128
# Embedding-table size: tokenizer vocabulary (including the added "\n" and "<s>")
# plus one extra row. NOTE(review): the purpose of the +1 is not evident here — TODO confirm.
vocab_size = len(tokenizer.get_vocab().keys()) + 1

# Encoder: token embedding -> LayerNorm -> b residual Attention blocks.
# Its per-token outputs are also what the retrieval cells further down use.
inputs_enc = Input(shape=(input_size, ), dtype=tf.int32)
emb_layer = Embedding(vocab_size, embed_size)
pos_layer = keras_nlp.layers.PositionEmbedding(input_size)

x = LayerNormalization()(emb_layer(inputs_enc))
# Position embeddings for the sequence, passed to every Attention block below.
pos = pos_layer(x)

b = 4
for _ in range(b):
    # Residual update scaled by 1/sqrt(b).
    x += b**-0.5 * LayerNormalization()(Attention()(x, pos))

encoder = keras.Model(inputs=inputs_enc, outputs=x)

# "Decoder": b residual feed-forward blocks (Dense-GELU, Dense-GELU, LayerNorm)
# on top of the encoder, followed by an output head tied to the input embeddings.
inputs = Input(shape=(input_size, ), dtype=tf.int32)
x = encoder(inputs)
# Tied LM head: softmax(x @ E^T) over the vocabulary, where E is emb_layer's matrix.
# NOTE(review): the lambda looks up the global name `emb_layer` instead of holding
# the layer. After keras.models.load_model, the deserialized lambda presumably
# resolves that name from the notebook's globals, which may not be the reloaded
# embedding. Verify before using a reloaded model's output head.
lm_head = Lambda(lambda x: tf.nn.softmax(tf.matmul(x, emb_layer.embeddings, transpose_b=True), axis=-1))

b = 4
for _ in range(b):
    x1 = Dense(embed_size, activation="gelu")(x)
    x1 = Dense(embed_size, activation="gelu")(x1)
    # Residual update scaled by 1/sqrt(b).
    x += b**-0.5 * LayerNormalization()(x1)

x = lm_head(x)

model = keras.Model(inputs=inputs, outputs=x)
# lm_head already outputs probabilities, and the loss does not set from_logits.
# Padding targets are excluded via ignore_class / mask_token_id.
model.compile(
    loss=keras.losses.SparseCategoricalCrossentropy(ignore_class=tokenizer.pad_token_id),
    optimizer=keras.optimizers.AdamW(learning_rate=0.001),
    metrics=[masked_accuracy, keras_nlp.metrics.Perplexity(mask_token_id=tokenizer.pad_token_id)],
)

encoder.summary()

In [None]:
# Training: 70 rounds. Each round draws a fresh set of 4096*8 windows of
# input_size + 1 tokens (inputs = tokens[:-1], targets = tokens[1:], with the
# weights aligned to the targets) and checkpoints the model to HDF5 afterwards.
for i in range(70):
    x, w = get_train_batch(4096*8, input_size+1)
    if 10 < i < 25:
        # Rounds 11-24 emphasise repeated carry tokens: weight 1.0 stays 1.0, other
        # real tokens drop to 0.05, and padding keeps weight 0.0. A plain
        # `w < 0.9` test would also lift padding from 0.0 to 0.05.
        w = tf.where(w == 0.0, 0.0, tf.where(w < 0.9, 0.05, 1.0))
    model.fit(x=x[:, :-1], y=x[:, 1:], shuffle=True, epochs=1, batch_size=16, sample_weight=w[:, 1:])
    model.save("model_slm.hdf5")

In [14]:
# Reload the trained model. The custom layer and metric must be supplied by name.
# safe_mode=False permits deserializing Python code (the Lambda head), so only
# load trusted files. NOTE(review): Keras documents this flag for the .keras
# format; TODO confirm whether it matters for HDF5.
custom_objects = {
    "Attention"       : Attention,
    "masked_accuracy" : masked_accuracy,
}
model = keras.models.load_model("model_slm.hdf5", custom_objects=custom_objects, safe_mode=False)

# Extract the encoder: layers[0] is the Input layer, layers[1] the encoder sub-model.
encoder = model.layers[1]

In [11]:
def vectorize_texts(all_texts, batch_size=128):
    """Encode token-id sequences and return one encoder vector per input token.

    Each sequence is right-padded with ``tokenizer.pad_token_id`` to
    ``input_size`` and run through ``encoder`` in batches. The padded positions
    are then dropped again, so the result has ``sum(len(t) for t in all_texts)``
    rows in input order. Dropping the padding relies on the encoder's attention
    being causal, so later padding cannot affect earlier positions.

    Args:
        all_texts: list of token-id lists, each at most ``input_size`` long.
        batch_size: number of sequences per encoder call.

    Returns:
        A ``(total_tokens, embed)`` tensor (zero rows for empty input).

    Raises:
        ValueError: if a sequence is longer than ``input_size``.
    """
    too_long = [len(t) for t in all_texts if len(t) > input_size]
    if too_long:
        raise ValueError(f"sequence of length {too_long[0]} exceeds input_size={input_size}")
    if not all_texts:
        return tf.zeros((0, encoder.output_shape[-1]))

    vects = []
    # Stepping by batch_size never produces an empty trailing batch, even when
    # len(all_texts) is an exact multiple of batch_size.
    for start in range(0, len(all_texts), batch_size):
        texts = all_texts[start:start + batch_size]
        toks = [text + [tokenizer.pad_token_id] * (input_size - len(text)) for text in texts]
        toks = tf.constant(toks, shape=(len(toks), input_size))
        vect = encoder(toks)
        for v, t in zip(vect, texts):
            vects.append(v[:len(t), :])
    return tf.concat(vects, axis=0)

vectorize_texts([tokenizer.encode("Hello. How have you been?"), tokenizer.encode("hello")])

In [None]:
# Retrieval corpus. For each document (truncated to input_size + 1 tokens), the
# encoder vector at position p is stored together with the token at position
# p + 1, so row n of the stacked `prompt_embeds` pairs with next-token id `all_toks[n]`.
# NOTE: this rebinds `all_toks`, which earlier held the (token, id) vocabulary list.
all_toks = []
prompt_embeds = []

batch_size = 128
batch = []

for j, text in enumerate(data):
    text_size = min(len(text), input_size+1)
    all_toks += text[1:text_size]        # next-token targets
    batch.append(text[:text_size-1])     # contexts, same length as the targets

    if len(batch) >= batch_size:
        prompt_embeds.append(vectorize_texts(batch))
        batch = []
        print(j)

# Encode the final partial batch too. Otherwise up to batch_size - 1 trailing
# documents add targets to all_toks but no vectors to the index.
if batch:
    prompt_embeds.append(vectorize_texts(batch))
    batch = []

assert sum(len(e) for e in prompt_embeds) == len(all_toks), "embeddings and targets are misaligned"

In [20]:
# Stack the per-batch (tokens, embed_size) blocks into one (N, embed_size) matrix for faiss.
n_rows = sum(len(block) for block in prompt_embeds)
prompt_embeds = np.concatenate([np.asarray(block) for block in prompt_embeds], axis=0).reshape((n_rows, embed_size))

In [21]:
import faiss

# Exact (brute-force) nearest-neighbour index over all context vectors, using L2
# distance. An approximate index such as faiss.IndexHNSWFlat(embed_size, 32) is
# an alternative for larger corpora.
index = faiss.IndexFlatL2(embed_size)
index.add(prompt_embeds)

In [31]:
text1 = """<s>Peter: Hello there!\n"""

text2 = """<s>The dog is red and has five legs.
User: What color is the dog?
Assistant: red
User: How many legs does the dog have?
Assistant:"""

# Sampling settings: neighbours per step, softmax temperature over negated
# distances, prompt to continue, and number of tokens to generate.
k = 10
temp = 0.01
text = text1
size = 128

enc_text = tokenizer.encode(text, add_special_tokens=False)
text = tokenizer.decode(enc_text)
print(text, end="")

for _step in range(size):
    # Query with the encoder vector of the last token of the running sequence.
    query = np.array(vectorize_texts([enc_text])[-1]).reshape((1, embed_size))
    D, I = index.search(query, k)
    # Candidate next tokens: the tokens that followed each retrieved neighbour.
    candidates = [all_toks[idx] for idx in I[0]]
    # Smaller distance -> higher probability; temp controls how sharp the choice is.
    probs = tf.nn.softmax(-D[0] / temp, axis=-1)
    choice = tf.random.categorical(tf.math.log([probs]), num_samples=1)[0][0]
    enc_text += [candidates[choice]]

    # Re-decode the whole sequence and print only the newly added suffix.
    new_text = tokenizer.decode(enc_text)
    print(new_text[len(text):], end="")
    text = new_text

<s> Peter: Hello there!
 Mia: Hello there, do you follow baseball or MLB much?
 Peter: I do, although I haven't followed it much the past couple of years.
 Mia: I am the opposite, I have been following more closely these last few, especially this last season as my team almost made the world series.
 Peter: Their arrow means they cover everything from the old days.
 Mia: Nice, yeah I like it. Did you know that the cubs were the first team to win back to back World Series.
 Peter: Women were not allowed to wear a baseball uniform as to be able to play for their teams if the need arises
 Mia: Maybe it is his