In [1]:
import os, json

# Export the Kaggle API credentials stored in kaggle.json as environment
# variables (presumably consumed when presets are fetched from Kaggle).
with open('kaggle.json') as f:
    kaggle = json.load(f)
os.environ["KAGGLE_USERNAME"] = kaggle["username"]
os.environ["KAGGLE_KEY"] = kaggle["key"]

# The Keras backend must be chosen before `keras` / `keras_hub` are imported
# (next cell).
os.environ["KERAS_BACKEND"] = "jax"
# Let XLA claim the full GPU memory fraction (1.00); per the original note this
# avoids memory fragmentation on the JAX backend.
os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = "1.00"

In [2]:
import keras_hub
import keras
from keras import ops

# Half precision is currently disabled; uncomment the next line to run the
# model in bfloat16.
#keras.config.set_floatx("bfloat16")

# Training configuration.
# `token_limit` is the `max_length` ChatState passes to `generate`.
token_limit = 4096
# LoRA / optimizer settings — not read by any visible cell (presumably kept for
# a fine-tuning run; TODO confirm or remove).
lora_name, lora_rank = "cm_qna", 4
lr_value, train_epoch = 1e-4, 20

2024-11-26 06:48:26.245478: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:485] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
2024-11-26 06:48:26.253535: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:8454] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
2024-11-26 06:48:26.255950: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1452] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered


In [3]:
import time
# Start timestamp written by `tick()` and read by `tock()`.
tick_start = 0


def tick():
    """Start (or restart) the stopwatch that `tock()` reports on."""
    global tick_start
    # perf_counter is monotonic and high-resolution, so the measured interval
    # is not distorted by system clock adjustments (time.time() can be).
    tick_start = time.perf_counter()


def tock():
    """Print the seconds elapsed since the most recent `tick()` call.

    If `tick()` was never called, `tick_start` is still 0 and the printed
    value is not a meaningful duration.
    """
    print(f"TOTAL TIME ELAPSED: {time.perf_counter() - tick_start:.2f}s")

# formatting utility
from IPython.display import Markdown
import textwrap

def display_chat(prompt, text):
  """Render one prompt/response exchange as a Jupyter Markdown object.

  The prompt is wrapped in a blue HTML <blockquote>. The response is shown in
  green as a Markdown quote. '•' bullets become Markdown list items, and '$'
  is escaped so Jupyter Markdown does not treat it as a math delimiter.

  Args:
    prompt: user prompt text, inserted into the HTML as-is (not escaped).
    text: model response text.

  Returns:
    An IPython.display.Markdown object with the combined HTML/Markdown.
  """
  formatted_prompt = "<font size='+1' color='#1E90FF'>🧑‍💻<blockquote>" + prompt + "</blockquote></font>"
  text = text.replace('•', '  *')
  # Raw string: '\$' in a normal literal is an invalid escape sequence
  # (SyntaxWarning on Python 3.12+). The value is the same two characters.
  text = text.replace('$', r'\$')  # necessary escaping in Jupyter markdown
  # Prefix every line, blank ones included, with '> ' to form a Markdown quote.
  text = textwrap.indent(text, '> ', predicate=lambda _: True)
  formatted_text = "<font size='+1' color='#32CD32'>🤖\n\n" + text + "\n\n</font>"
  return Markdown(formatted_prompt+formatted_text)


def rewire_for_cleaner_plot(model):
  """Return a functional copy of `model` without the `padding_mask` input.

  Layers are reused, not cloned (`clone_function` is the identity). Every
  layer whose class name ends in 'DecoderBlock' is re-invoked without its
  `padding_mask` keyword argument, and the new model takes all of the original
  inputs except `padding_mask`.

  Judging by the name, this is meant only to simplify model plots. The
  returned model presumably runs its decoder blocks without a padding mask,
  so do not use it on padded batches without checking.
  """
  def strip_padding_mask(layer, *args, **kwargs):
    # Raises KeyError if a decoder block was called without padding_mask.
    if layer.__class__.__name__.endswith('DecoderBlock'):
      del kwargs["padding_mask"]
    return layer(*args, **kwargs)

  rewired = keras.models.clone_model(
      model,
      call_function=strip_padding_mask,
      clone_function=lambda lyr: lyr,
  )
  inputs_without_mask = rewired.input.copy()
  del inputs_without_mask["padding_mask"]
  return keras.Model(inputs_without_mask, rewired.output)

In [4]:
# Gemma chat-template markers that delimit conversation turns.
__START_TURN_USER__ = "<start_of_turn>user\n"
__START_TURN_MODEL__ = "<start_of_turn>model\n"
__END_TURN__ = "<end_of_turn>\n"
# Persona prompt (Traditional Chinese). It casts the assistant as "財經小博士",
# a finance enthusiast with rich financial knowledge who helps users of any
# experience level, and ends with "please help the user answer the following
# question:".
system_prompt = '你是財經小博士。財經小博士是一位對財經領域非常熱衷的人，你擁有豐富的財經知識和經驗。你的使命是通過寫作和分享知識，幫助人們更好地了解和應對財經問題。無論用戶是新手還是老手，只要他有任何關於財經領域的問題，財經小博士都能幫助用戶解答。請你幫助用戶解答以下問題:'


class ChatState:
    """Multi-turn conversation formatted with Gemma turn markers.

    `history` is a list of already-formatted turn strings. When a non-empty
    `system` text is given, it opens the first user turn. That turn stays
    unterminated until the first `send_message` call appends the user's
    message and closes it.

    Args:
        model: object exposing `generate(prompt, max_length=...)` that returns
            text; every occurrence of the prompt is stripped from that text to
            obtain the reply.
        system: optional system text; "" disables it.
    """

    def __init__(self, model, system=""):
        self.model = model
        self.system = system
        self.history = (
            [__START_TURN_USER__ + self.system + "\n"] if len(self.system) > 0 else []
        )

    def _append_turn(self, marker, message):
        # One complete turn: marker, message, end-of-turn marker.
        self.history.append(marker + message + __END_TURN__)

    def add_to_history_as_user(self, message):
        """Append `message` as a complete user turn."""
        self._append_turn(__START_TURN_USER__, message)

    def add_to_history_as_model(self, message):
        """Append `message` as a complete model turn."""
        self._append_turn(__START_TURN_MODEL__, message)

    def get_history(self):
        """Return the whole conversation as a single string."""
        return "".join(self.history)

    def get_full_prompt(self):
        """Return the history followed by an open model turn."""
        return self.get_history() + __START_TURN_MODEL__

    def send_message(self, message):
        """Send `message`, record the reply in history, and return the reply.

        Elapsed time is printed via `tick()` / `tock()`, and generation uses
        `max_length=token_limit`.
        """
        tick()
        if len(self.system) > 0 and len(self.history) == 1:
            # The system turn is still open: this message completes it.
            self.history[0] += message + __END_TURN__
        else:
            self.add_to_history_as_user(message)
        prompt = self.get_full_prompt()
        reply = self.model.generate(prompt, max_length=token_limit).replace(prompt, "")
        self.add_to_history_as_model(reply)
        tock()
        return reply

## 載入模型

In [None]:
# Which Gemma generation to load: 1 selects the Gemma 1 2B instruct preset;
# any other value selects the Gemma 2 2B instruct preset.
gemma_version = 1
if gemma_version == 1:
    model_id = "gemma_instruct_2b_en"
else:
    model_id = "gemma2_instruct_2b_en"

# Instantiate the causal LM from the Keras Hub preset and show its layers.
gemma_lm = keras_hub.models.GemmaCausalLM.from_preset(model_id)
gemma_lm.summary()

## 載入資料

In [6]:
# Needle-in-a-haystack inputs: `passkey` is the fact the model must recall, and
# `heystack` (sic, "haystack") is filler text built from the ETF records.
passkey = "There is an important info hidden inside a lot of irrelevant text. Find it and memorize it. I will quiz you about the important information there. 密碼是: 1123."

# JSON is UTF-8 by specification. Decode it explicitly so the (presumably
# Chinese) ETF text loads the same regardless of the platform's default
# locale encoding.
with open('etf_data.json', encoding='utf-8') as f:
    etfs = json.load(f)

# One numbered line per record: "1. <etf>\n2. <etf>\n...". join avoids
# repeated string concatenation.
heystack = "".join(f"{idx}. {etf}\n" for idx, etf in enumerate(etfs, start=1))

In [7]:
def generate(passkey, heystack, num_stack = 0):
    """Run one passkey-retrieval probe against the global `gemma_lm`.

    Builds a single user turn containing `passkey` followed by `num_stack`
    copies of `heystack`, asks for the passkey, and pre-fills the model turn
    with "The passkey is:". Prints the prompt's token count, then the text
    after the first model-turn marker, cut at its first '.'.

    Args:
        passkey: text holding the fact to recall (placed first).
        heystack: filler ("haystack") text appended after the passkey.
        num_stack: number of copies of `heystack` to append.

    Returns:
        None; results are printed.
    """
    question_object = "passkey"
    # Chinese variant of the question, kept from the experiment:
    # prompt_postfix = f"\n密碼是多少?" + __END_TURN__ + __START_TURN_MODEL__ + f"密碼是:"
    prompt = "".join((
        __START_TURN_USER__,
        passkey,
        heystack * num_stack,
        f"\nWhat is the {question_object}?",
        __END_TURN__,
        __START_TURN_MODEL__,
        f"The {question_object} is:",
    ))
    n_tokens = len(gemma_lm.preprocessor.tokenizer.tokenize(prompt))
    print(f"Prompt has {n_tokens} tokens")
    # max_length = prompt token count + 50.
    completion = gemma_lm.generate(prompt, max_length=n_tokens + 50)
    answer = completion.split(__START_TURN_MODEL__)[1].split('.')[0]
    print("=" * 20)
    print("Gemma output:", answer)

In [9]:
generate(passkey, heystack, 2)

Prompt has 3986 tokens
Gemma output: The passkey is: 1123


## Self Extend

In [20]:
# Self-Extend hyper-parameters read by the patched attention `call`.
#   GROUP_SIZE  - divisor applied to RoPE start indices for grouped attention.
#   WINDOW_SIZE - neighbor window; keys farther than this use grouped logits.
GROUP_SIZE, WINDOW_SIZE = 2, 4096

In [11]:
def build_cache(self, token_ids):
    """Allocate a three-slot KV cache and seed it with `token_ids`.

    Replacement for the CausalLM's `_build_cache`. The cache has shape
    (batch, num_layers, 3, max_length, num_key_value_heads, head_dim). Slots
    0, 1 and 2 on axis 2 hold keys, values and Self-Extend grouped keys, in the
    order the patched attention `call` reads and writes them. The cache is
    zero-filled, then filled by one `call_with_cache` pass starting at index 0.

    Args:
        token_ids: token tensor; batch size and max length are read from its
            first two dimensions.

    Returns:
        (hidden_states, cache) from the seeding pass.
    """
    backbone = self.backbone
    token_shape = ops.shape(token_ids)
    cache_shape = [
        token_shape[0],                 # batch
        backbone.num_layers,
        3,                              # key / value / grouped key
        token_shape[1],                 # max sequence length
        backbone.num_key_value_heads,
        backbone.head_dim,
    ]
    empty_cache = ops.zeros(cache_shape, dtype=self.compute_dtype)
    _, hidden_states, seeded_cache = self.call_with_cache(token_ids, empty_cache, 0)
    return hidden_states, seeded_cache

In [12]:
import numpy as np
def compute_attention(
    self,
    q,
    k,
    attention_mask,
    training=False,
    cache_update_index=0,
    use_sliding_window_attention = False
):
    """Return raw (pre-softmax) grouped-query attention logits.

    Self-Extend variant of the layer's `_compute_attention`. It only scales the
    queries and computes q·k logits. Masking, soft-capping, softmax and the
    value product are left to the caller, so two logit sets can be merged first.

    `attention_mask`, `training`, `cache_update_index` and
    `use_sliding_window_attention` are accepted for signature compatibility
    but not used. In particular, no sliding-window mask is applied here.

    Args:
        q: queries of rank 4, (b, q_len, num_query_heads, h).
        k: keys laid out as (b, s, num_key_value_heads, h), per the einsum.

    Returns:
        (attention_logits, b, q_len, h). attention_logits has layout
        (b, num_key_value_heads, queries_per_kv_head, q_len, s).
    """
    if self.query_head_dim_normalize:
        scale_dim = self.head_dim
    else:
        scale_dim = self.hidden_dim // self.num_query_heads
    q *= ops.cast(1 / np.sqrt(scale_dim), dtype=q.dtype)

    # Split the query-head axis into (kv_head, queries-per-kv-head) so each
    # group of query heads shares one key head.
    queries_per_kv = self.num_query_heads // self.num_key_value_heads
    *lead_dims, _, head_size = ops.shape(q)
    q = ops.reshape(
        q, (*lead_dims, self.num_key_value_heads, queries_per_kv, head_size)
    )
    b, q_len, _, _, h = ops.shape(q)

    attention_logits = ops.einsum("btkgh,bskh->bkgts", q, k)
    return attention_logits, b, q_len, h

In [21]:
def call(
    self,
    x,
    attention_mask=None,
    cache=None,
    cache_update_index=0,
    training=False,
):
    """Self-Extend replacement for a Gemma attention layer's `call`.

    Two sets of attention logits are computed and merged per (query, key)
    pair before a single softmax:

    * neighbor logits: query and key rotated by `self._apply_rope` with start
      index `cache_update_index`;
    * grouped logits: key rotated with start index
      `cache_update_index // GROUP_SIZE`, and query with that value plus
      `WINDOW_SIZE - WINDOW_SIZE // GROUP_SIZE`.

    For each (query, key), the code counts the attended keys from that key to
    the end of the mask row. If the count is <= WINDOW_SIZE the neighbor logit
    is used, otherwise the grouped logit.

    NOTE(review): this assumes `_apply_rope(t, start)` rotates token i of `t`
    at position `start + i` (TODO confirm). Under that assumption only the
    start index is grouped, so a multi-token call such as the prompt prefill
    gives grouped keys and queries consecutive positions rather than
    `(start + i) // GROUP_SIZE`. That differs from the Self-Extend formulation
    and may explain the failed passkey run.

    Args:
        x: layer input, fed to `query_dense` / `key_dense` / `value_dense`.
        attention_mask: required. Indexed as `[:, None, None, :, :]`, so
            presumably (batch, q_len, key_len) — TODO confirm.
        cache: optional cache with three slots on axis 1 (key, value, grouped
            key). Each slot is updated at `cache_update_index` along its
            axis 1, and the re-stacked cache is returned.
        cache_update_index: offset of `x` in the cache; also the RoPE start
            index.
        training: forwarded to the dropout layer.

    Returns:
        `attention_output`, or `(attention_output, cache)` when `cache` is
        given.
    """
    def apply_softmax(attention_logits, attention_mask, v, b, q_len, h):
        # Masked softmax over keys, optional dropout, then weight the values.
        attention_mask = attention_mask[:, None, None, :, :]
        orig_dtype = attention_logits.dtype
        attention_softmax = self.softmax(attention_logits, mask=attention_mask)
        attention_softmax = ops.cast(attention_softmax, orig_dtype)

        if self.dropout:
            attention_softmax = self.dropout_layer(
                attention_softmax, training=training
            )

        results = ops.einsum("bkgts,bskh->btkgh", attention_softmax, v)
        return ops.reshape(results, (b, q_len, self.num_query_heads, h))

    query = self.query_dense(x)
    key = self.key_dense(x)
    value = self.value_dense(x)

    # Neighbor attention: normal RoPE start index.
    query_update = self._apply_rope(query, cache_update_index)
    key_update = self._apply_rope(key, cache_update_index)

    # Grouped attention: start index divided by GROUP_SIZE; the query is
    # additionally shifted by WINDOW_SIZE - WINDOW_SIZE // GROUP_SIZE.
    grouped_key_index = ops.floor_divide(cache_update_index, GROUP_SIZE)
    shift = WINDOW_SIZE - WINDOW_SIZE // GROUP_SIZE
    grouped_query_index = grouped_key_index + shift

    grouped_query = self._apply_rope(query, grouped_query_index)
    grouped_key = self._apply_rope(key, grouped_key_index)

    # Without a cache, attend over this call's own values / grouped keys.
    # These must be bound before the branch; otherwise cache=None raises
    # UnboundLocalError further down.
    value_update = value
    grouped_key_update = grouped_key

    if cache is not None:
        key_cache = cache[:, 0, ...]
        value_cache = cache[:, 1, ...]
        grouped_key_cache = cache[:, 2, ...]
        start = [0, cache_update_index, 0, 0]
        key_update = ops.slice_update(key_cache, start, key_update)
        value_update = ops.slice_update(value_cache, start, value)
        grouped_key_update = ops.slice_update(
            grouped_key_cache, start, grouped_key
        )
        cache = ops.stack((key_update, value_update, grouped_key_update), axis=1)

    # Broadcast the mask over the (kv_head, group) logit axes and use it as an
    # additive bias: 0 where attended, -1e9 where masked.
    attn_mask = attention_mask[:, None, None, :, :]

    def add_mask_bias(logits):
        return logits + (1.0 - ops.cast(attn_mask, logits.dtype)) * -1e9

    attention_logits, b, q_len, h = self._compute_attention(
        query_update,
        key_update,
        attention_mask,
        training=training,
        cache_update_index=cache_update_index,
        use_sliding_window_attention=True,
    )
    attention_logits = add_mask_bias(attention_logits)

    grouped_attention_logits, b, q_len, h = self._compute_attention(
        grouped_query,
        grouped_key_update,
        attention_mask,
        training=training,
        cache_update_index=cache_update_index,
    )
    grouped_attention_logits = add_mask_bias(grouped_attention_logits)

    # Reverse cumulative sum over keys: for each (query, key), the number of
    # attended keys at or after that key. Counts <= WINDOW_SIZE keep the
    # neighbor logit; larger counts take the grouped logit.
    group_mask = ops.flip(ops.cumsum(ops.flip(attn_mask, -1), -1), -1)
    local_mask = group_mask <= WINDOW_SIZE
    attention_logits = ops.where(
        local_mask, attention_logits, grouped_attention_logits
    )

    if self.logit_soft_cap is not None:
        attention_logits = ops.divide(attention_logits, self.logit_soft_cap)
        attention_logits = ops.multiply(
            ops.tanh(attention_logits), self.logit_soft_cap
        )

    attention_vec = apply_softmax(
        attention_logits, attention_mask, value_update, b, q_len, h
    )

    # Wipe attn vec if there are no attended tokens.
    no_attended_tokens = ops.all(
        ops.equal(attention_mask, 0), axis=-1, keepdims=True
    )[..., None]
    attention_vec = ops.where(
        no_attended_tokens, ops.zeros_like(attention_vec), attention_vec
    )

    attention_output = self.output_dense(attention_vec)

    if cache is not None:
        return attention_output, cache
    return attention_output

In [22]:
from types import MethodType

# Self-Extend patch: bind the replacement functions as instance methods.
# 1) Cache builder that allocates an extra slot for grouped keys.
gemma_lm._build_cache = MethodType(build_cache, gemma_lm)

# 2) Each decoder layer's attention: `_compute_attention` now stops before the
#    softmax, and `call` merges neighbor and grouped logits (Self-Extend).
for layer in gemma_lm.backbone.transformer_layers:
    attention = layer.attention
    attention._compute_attention = MethodType(compute_attention, attention)
    attention.call = MethodType(call, attention)

In [23]:
# Show full tracebacks (including Keras internals) while debugging the patch.
keras.config.disable_traceback_filtering()
# Reset build state and the cached generate function so the patched methods
# are (presumably) picked up by the next `generate` call.
gemma_lm.built = False
gemma_lm.generate_function = None
# Greedy sampling for generation; add run_eagerly=True to debug eagerly.
gemma_lm.compile(sampler=keras_hub.samplers.GreedySampler())

## Self Extend 生成

In [25]:
generate(passkey, heystack, 3)

Prompt has 5950 tokens
Gemma output: The passkey is:

Please provide the key differences between the two index and the index performance


## 原始實作

In [47]:
def call(
    self,
    x,
    attention_mask=None,
    cache=None,
    cache_update_index=0,
    training=False,
):
    """Stock (non-Self-Extend) attention forward pass, redefined to undo the patch.

    Projects `x` to query/key/value, applies RoPE from `cache_update_index`,
    optionally writes key/value into the two-slot `cache` at that index, and
    runs `self._compute_attention`. Rows whose mask attends to no key produce
    zeros.

    Returns:
        `attention_output`, or `(attention_output, cache)` when `cache` is
        given.
    """
    query = self._apply_rope(self.query_dense(x), cache_update_index)
    key = self._apply_rope(self.key_dense(x), cache_update_index)
    value = self.value_dense(x)

    if cache is not None:
        # Write this step's key/value into the cache and attend over all of it.
        write_at = [0, cache_update_index, 0, 0]
        key = ops.slice_update(cache[:, 0, ...], write_at, key)
        value = ops.slice_update(cache[:, 1, ...], write_at, value)
        cache = ops.stack((key, value), axis=1)

    attention_vec = self._compute_attention(
        query,
        key,
        value,
        attention_mask,
        training=training,
        cache_update_index=cache_update_index,
    )

    # Zero the output for queries that attend to no key at all.
    fully_masked = ops.all(
        ops.equal(attention_mask, 0), axis=-1, keepdims=True
    )[..., None]
    attention_vec = ops.where(
        fully_masked, ops.zeros_like(attention_vec), attention_vec
    )

    attention_output = self.output_dense(attention_vec)
    return (attention_output, cache) if cache is not None else attention_output

In [48]:
def compute_attention(
    self,
    q,
    k,
    v,
    attention_mask,
    training=False,
    cache_update_index=0,
):
    """Stock grouped-query attention (softmax included), redefined to undo the patch.

    Scales the queries, computes q·k logits, optionally soft-caps them with
    tanh, applies the (optionally sliding-window) mask through `self.softmax`,
    applies dropout, and weights `v`.

    Returns:
        Tensor reshaped to (b, q_len, num_query_heads, h).
    """
    if self.query_head_dim_normalize:
        scale_dim = self.head_dim
    else:
        scale_dim = self.hidden_dim // self.num_query_heads
    q *= ops.cast(1 / np.sqrt(scale_dim), dtype=q.dtype)

    # Split the query-head axis into (kv_head, queries-per-kv-head).
    queries_per_kv = self.num_query_heads // self.num_key_value_heads
    *lead_dims, _, head_size = ops.shape(q)
    q = ops.reshape(
        q, (*lead_dims, self.num_key_value_heads, queries_per_kv, head_size)
    )
    b, q_len, _, _, h = ops.shape(q)

    logits = ops.einsum("btkgh,bskh->bkgts", q, k)

    cap = self.logit_soft_cap
    if cap is not None:
        # Soft-cap: squash logits into (-cap, cap).
        logits = ops.multiply(ops.tanh(ops.divide(logits, cap)), cap)

    if self.use_sliding_window_attention:
        attention_mask = self._mask_sliding_window(
            attention_mask,
            cache_update_index=cache_update_index,
        )

    mask_5d = attention_mask[:, None, None, :, :]
    weights = ops.cast(self.softmax(logits, mask=mask_5d), logits.dtype)

    if self.dropout:
        weights = self.dropout_layer(weights, training=training)

    weighted = ops.einsum("bkgts,bskh->btkgh", weights, v)
    return ops.reshape(weighted, (b, q_len, self.num_query_heads, h))

In [49]:
from types import MethodType

# Undo the Self-Extend patch.
# Cache builder: `build_cache` in this notebook is the Self-Extend version
# (3 cache slots), so re-binding it here would leave the patch half-applied.
# Instead, remove the instance override so the class's own `_build_cache`
# is used again (assumes GemmaCausalLM defines `_build_cache`, which the
# patch shadowed — TODO confirm).
if "_build_cache" in vars(gemma_lm):
    del gemma_lm._build_cache

# Attention: bind the stock `call` / `_compute_attention`. These names must
# refer to the stock definitions when this cell runs.
for layer in gemma_lm.backbone.transformer_layers:
    layer.attention.call = MethodType(call, layer.attention)
    layer.attention._compute_attention = MethodType(compute_attention, layer.attention)

# Reset build state and the cached generate function so the restored methods
# are (presumably) picked up by the next `generate` call.
gemma_lm.built = False
gemma_lm.generate_function = None
keras.config.disable_traceback_filtering()
gemma_lm.compile(sampler=keras_hub.samplers.GreedySampler())  # add run_eagerly=True to debug eagerly