In [1]:
%%capture
!pip install --upgrade kagglehub -q
!pip install ipywidgets -q
!pip install tensorflow-cpu -q
!pip install tensorflow_datasets -q
!pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cpu -q
!pip install git+https://github.com/felafax/gemma.git -q
!pip install qax -q
!pip install jax-lorax -q

In [2]:
!pip install jax-lorax -q

[0m

In [3]:
import os
os.environ['HF_HUB_CACHE'] = '/mnt/persistent-disk/hf/'
os.environ['HF_HOME'] = '/mnt/persistent-disk/hf/'
!export HF_HUB_CACHE="/mnt/persistent-disk/hf/"
!export HF_HOME="/mnt/persistent-disk/hf/"

In [4]:
# @title Python imports

import enum
import re
import string

# We import JAX and some related packages.
import chex
import jax
import jax.numpy as jnp
import flax.linen as nn
import optax
from functools import partial

# For LoRA
import lorax

# We will use HuggingFace's dataset, tokenizer, and model classes.
from transformers import AutoModelForCausalLM, AutoConfig, AutoTokenizer, default_data_collator
from datasets import Dataset, load_dataset, concatenate_datasets
import torch

# Finally, we import Gemma.
from gemma import params as params_lib
from gemma import sampler as sampler_lib
from gemma import transformer as transformer_lib
import sentencepiece as spm


In [5]:
jax.devices()

[TpuDevice(id=0, process_index=0, coords=(0,0,0), core_on_chip=0),
 TpuDevice(id=1, process_index=0, coords=(0,0,0), core_on_chip=1),
 TpuDevice(id=2, process_index=0, coords=(1,0,0), core_on_chip=0),
 TpuDevice(id=3, process_index=0, coords=(1,0,0), core_on_chip=1),
 TpuDevice(id=4, process_index=0, coords=(0,1,0), core_on_chip=0),
 TpuDevice(id=5, process_index=0, coords=(0,1,0), core_on_chip=1),
 TpuDevice(id=6, process_index=0, coords=(1,1,0), core_on_chip=0),
 TpuDevice(id=7, process_index=0, coords=(1,1,0), core_on_chip=1)]

In [6]:
# HuggingFace username and token to use when downloading.
MODEL_NAME="felafax/gemma-2-2b-it-Flax"
HUGGINGFACE_USERNAME = input("INPUT: Please provide your HUGGINGFACE_USERNAME: ")
HUGGINGFACE_TOKEN = input("INPUT: Please provide your HUGGINGFACE_TOKEN: ")

model_name=MODEL_NAME
hugging_face_token=HUGGINGFACE_TOKEN

INPUT: Please provide your HUGGINGFACE_USERNAME:  felarof01
INPUT: Please provide your HUGGINGFACE_TOKEN:  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


In [7]:
%%capture
from huggingface_hub import snapshot_download

ckpt_path = snapshot_download(repo_id=MODEL_NAME, token=HUGGINGFACE_TOKEN)
vocab_path = os.path.join(ckpt_path, 'tokenizer.model')

print(ckpt_path)
print()
print(vocab_path)

## Fine tuning the Gemma model

In [26]:
import flax
from flax.traverse_util import flatten_dict

def print_params(params):
    flat_params = flatten_dict(params)    
    for path, param in flat_params.items():
        # Join the path components to create a string name
        name = "/".join(str(x) for x in path)
        print(f"Name: {name}")
        print(f"Shape: {param.shape}")
        print(f"dtype: {param.dtype}")
        # print(f"Value: {param}")
        print("-" * 40)

## Try LoRA with simpleNN

In [12]:
class SimpleNN(nn.Module):
    hidden_dim: int
    output_dim: int

    @nn.compact
    def __call__(self, x):
        x = nn.Dense(features=self.hidden_dim)(x)
        x = nn.relu(x)
        x = nn.Dense(features=self.output_dim)(x)
        return x

In [13]:
model = SimpleNN(hidden_dim=1024, output_dim=2048)

In [14]:
params = model.init(jax.random.PRNGKey(99), jnp.ones(shape=(256)))

In [15]:
print_params(params)

Name: params/Dense_0/kernel
Shape: (256, 1024)
dtype: float32
----------------------------------------
Name: params/Dense_0/bias
Shape: (1024,)
dtype: float32
----------------------------------------
Name: params/Dense_1/kernel
Shape: (1024, 2048)
dtype: float32
----------------------------------------
Name: params/Dense_1/bias
Shape: (2048,)
dtype: float32
----------------------------------------


In [16]:
from lorax.constants import LORA_FULL, LORA_FREEZE

def decision_fn(path, param):
    if 'embedding' in path:
        print(f'Fully finetuning param {path}')
        return LORA_FULL
    dim = 9
    print(f'Using LoRA with dim={dim} for param {path}')
    return dim

In [17]:
lora_spec = lorax.simple_spec(params, decision_fn=decision_fn, tune_vectors=True)

Using LoRA with dim=9 for param (DictKey(key='params'), DictKey(key='Dense_0'), DictKey(key='kernel'))
Using LoRA with dim=9 for param (DictKey(key='params'), DictKey(key='Dense_1'), DictKey(key='kernel'))


In [18]:
lora_spec

{'params': {'Dense_0': {'bias': -1, 'kernel': 9},
  'Dense_1': {'bias': -1, 'kernel': 9}}}

In [19]:
lora_params = lorax.init_lora(params, lora_spec, jax.random.PRNGKey(0))

In [20]:
lora_params

{'params': {'Dense_0': {'bias': Array([0., 0., 0., ..., 0., 0., 0.], dtype=float32),
   'kernel': LoraWeight(shape=(256, 1024), dtype=dtype('float32'), w=Array([[ 0.02465642,  0.1186908 , -0.02218285, ...,  0.05438057,
            0.01860112,  0.02913262],
          [-0.04722906, -0.09825071, -0.05490223, ..., -0.07258905,
            0.03278216,  0.04242963],
          [ 0.03249393,  0.01632209, -0.1063282 , ...,  0.0227662 ,
            0.01482183,  0.03653471],
          ...,
          [ 0.0421919 ,  0.08699562,  0.06926769, ..., -0.00728992,
            0.07881301, -0.0674908 ],
          [-0.10112084, -0.03219466, -0.03216507, ...,  0.09708065,
           -0.05610586,  0.07836296],
          [-0.06671638,  0.0170965 ,  0.01269128, ...,  0.05410368,
            0.05015702,  0.00996237]], dtype=float32), a=Array([[ 7.6108896e-03, -1.6037613e-02,  4.7682053e-03, ...,
           -2.3265226e-02,  6.5510985e-03,  1.5844416e-02],
          [ 8.5565494e-03, -1.5911866e-02, -1.0232066e-03,

In [21]:
lora_model = lorax.lora(model.apply)

In [22]:
lora_model(params, jnp.ones(256))

Array([-1.0578973 , -0.7680949 ,  0.27661672, ..., -0.00367761,
        0.09207283, -0.60841644], dtype=float32)

In [23]:
model.apply(params, jnp.ones(256))

Array([-1.0578973 , -0.7680949 ,  0.27661672, ..., -0.00367761,
        0.09207283, -0.60841644], dtype=float32)