In [1]:
%%capture
!pip install --upgrade kagglehub -q
!pip install ipywidgets -q
!pip install tensorflow-cpu -q
!pip install tensorflow_datasets -q
!pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cpu -q
!pip install git+https://github.com/felafax/gemma.git -q
!pip install qax -q
!pip install jax-lorax -1

In [5]:
import os
os.environ['HF_HUB_CACHE'] = '/mnt/persistent-disk/hf/'
os.environ['HF_HOME'] = '/mnt/persistent-disk/hf/'
!export HF_HUB_CACHE="/mnt/persistent-disk/hf/"
!export HF_HOME="/mnt/persistent-disk/hf/"

In [73]:
# @title Python imports

import enum
import re
import string

# We import JAX and some related packages.
import chex
import jax
import jax.numpy as jnp
import flax.linen as nn
import optax
from functools import partial

# For LoRA
import lorax

# We will use HuggingFace's dataset, tokenizer, and model classes.
from transformers import AutoModelForCausalLM, AutoConfig, AutoTokenizer, default_data_collator
from datasets import Dataset, load_dataset, concatenate_datasets
import torch

# Finally, we import Gemma.
from gemma import params as params_lib
from gemma import sampler as sampler_lib
from gemma import transformer as transformer_lib
import sentencepiece as spm


In [8]:
# HuggingFace username and token to use when downloading.
MODEL_NAME="felafax/gemma-2-2b-it-Flax"
HUGGINGFACE_USERNAME = input("INPUT: Please provide your HUGGINGFACE_USERNAME: ")
HUGGINGFACE_TOKEN = input("INPUT: Please provide your HUGGINGFACE_TOKEN: ")

model_name=MODEL_NAME
hugging_face_token=HUGGINGFACE_TOKEN

INPUT: Please provide your HUGGINGFACE_USERNAME:  felarof01
INPUT: Please provide your HUGGINGFACE_TOKEN:  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


In [9]:
%%capture
from huggingface_hub import snapshot_download

ckpt_path = snapshot_download(repo_id=MODEL_NAME, token=HUGGINGFACE_TOKEN)
vocab_path = os.path.join(ckpt_path, 'tokenizer.model')

print(ckpt_path)
print()
print(vocab_path)

## Fine tuning the Gemma model

In [10]:
# Load parameters.
params = params_lib.load_and_format_params(os.path.join(ckpt_path, 'gemma2-2b-it'))

In [24]:
# Load model config.
config = transformer_lib.TransformerConfig.gemma2_2b(cache_size=30)
model = transformer_lib.Transformer(config=config)

# You can also infer the model config by using the number of layers in the params.
# config_2b = transformer_lib.TransformerConfig.from_params(params, cache_size=30)

In [36]:
import flax
from flax.traverse_util import flatten_dict

def print_params(params):
    flat_params = flatten_dict(params)    
    for path, param in flat_params.items():
        # Join the path components to create a string name
        name = "/".join(str(x) for x in path)
        print(f"Name: {name}")
        print(f"Shape: {param.shape}")
        print(f"dtype: {param.dtype}")
        # print(f"Value: {param}")
        print("-" * 40)

### print params before

In [38]:
print_params(params)

Name: transformer/embedder/input_embedding
Shape: (256128, 2304)
dtype: bfloat16
----------------------------------------
Name: transformer/final_norm/scale
Shape: (2304,)
dtype: bfloat16
----------------------------------------
Name: transformer/layer_0/attn/attn_vec_einsum/w
Shape: (8, 256, 2304)
dtype: bfloat16
----------------------------------------
Name: transformer/layer_0/attn/kv_einsum/w
Shape: (2, 4, 2304, 256)
dtype: bfloat16
----------------------------------------
Name: transformer/layer_0/attn/q_einsum/w
Shape: (8, 2304, 256)
dtype: bfloat16
----------------------------------------
Name: transformer/layer_0/mlp/gating_einsum
Shape: (2, 2304, 9216)
dtype: bfloat16
----------------------------------------
Name: transformer/layer_0/mlp/linear
Shape: (9216, 2304)
dtype: bfloat16
----------------------------------------
Name: transformer/layer_0/post_attention_norm/scale
Shape: (2304,)
dtype: bfloat16
----------------------------------------
Name: transformer/layer_0/post_ffw_

### print params after applying LoRA

In [144]:
class SimpleNN(nn.Module):
    hidden_dim: int
    output_dim: int

    @nn.compact
    def __call__(self, x):
        x = nn.Dense(features=self.hidden_dim)(x)
        x = nn.relu(x)
        x = nn.Dense(features=self.output_dim)(x)
        return x

In [145]:
model = SimpleNN(hidden_dim=1024, output_dim=2048)

In [146]:
params = model.init(jax.random.PRNGKey(99), jnp.ones(shape=(256)))

In [147]:
print_params(params)

Name: params/Dense_0/kernel
Shape: (256, 1024)
dtype: float32
----------------------------------------
Name: params/Dense_0/bias
Shape: (1024,)
dtype: float32
----------------------------------------
Name: params/Dense_1/kernel
Shape: (1024, 2048)
dtype: float32
----------------------------------------
Name: params/Dense_1/bias
Shape: (2048,)
dtype: float32
----------------------------------------


In [148]:
from lorax.constants import LORA_FULL, LORA_FREEZE

def decision_fn(path, param):
    if 'embedding' in path:
        print(f'Fully finetuning param {path}')
        return LORA_FULL
    dim = 9
    print(f'Using LoRA with dim={dim} for param {path}')
    return dim

In [149]:
lora_spec = lorax.simple_spec(params, decision_fn=decision_fn, tune_vectors=True)

Using LoRA with dim=9 for param (DictKey(key='params'), DictKey(key='Dense_0'), DictKey(key='kernel'))
Using LoRA with dim=9 for param (DictKey(key='params'), DictKey(key='Dense_1'), DictKey(key='kernel'))


In [150]:
lora_spec

{'params': {'Dense_0': {'bias': -1, 'kernel': 9},
  'Dense_1': {'bias': -1, 'kernel': 9}}}

In [151]:
lora_params = lorax.init_lora(params, lora_spec, jax.random.PRNGKey(0))

In [152]:
lora_params

{'params': {'Dense_0': {'bias': Array([0., 0., 0., ..., 0., 0., 0.], dtype=float32),
   'kernel': LoraWeight(shape=(256, 1024), dtype=dtype('float32'), w=Array([[ 0.02465647,  0.11869204, -0.02218285, ...,  0.0543809 ,
            0.01860113,  0.02913267],
          [-0.04722938, -0.09825108, -0.05490235, ..., -0.07258904,
            0.03278206,  0.04242936],
          [ 0.03249395,  0.0163221 , -0.1063282 , ...,  0.02276618,
            0.01482182,  0.03653511],
          ...,
          [ 0.04219165,  0.08699494,  0.06926762, ..., -0.00728992,
            0.07881317, -0.06749069],
          [-0.10112067, -0.03219467, -0.03216508, ...,  0.09708063,
           -0.0561055 ,  0.07836296],
          [-0.06671629,  0.01709651,  0.01269127, ...,  0.05410413,
            0.05015748,  0.00996237]], dtype=float32), a=Array([[ 7.6109534e-03, -1.6037667e-02,  4.7682244e-03, ...,
           -2.3265127e-02,  6.5510995e-03,  1.5844539e-02],
          [ 8.5565168e-03, -1.5911995e-02, -1.0232066e-03,

In [154]:
lora_model = lorax.lora(model.apply)

In [156]:
lora_model(params, jnp.ones(256))

Array([-1.0578917 , -0.76810735,  0.27661797, ..., -0.0036779 ,
        0.09208131, -0.6084197 ], dtype=float32)

In [157]:
model.apply(params, jnp.ones(256))

Array([-1.0578917 , -0.76810735,  0.27661797, ..., -0.0036779 ,
        0.09208131, -0.6084197 ], dtype=float32)