In [1]:
%%capture
!pip install --upgrade kagglehub -q
!pip install ipywidgets -q
!pip install tensorflow-cpu -q
!pip install tensorflow_datasets -q
!pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cpu -q
!pip install git+https://github.com/felafax/gemma.git -q
!pip install qax -q
!pip install jax-lorax -q

In [2]:
!pip install jax-lorax -q

[0m

In [3]:
import os
os.environ['HF_HUB_CACHE'] = '/mnt/persistent-disk/hf/'
os.environ['HF_HOME'] = '/mnt/persistent-disk/hf/'
!export HF_HUB_CACHE="/mnt/persistent-disk/hf/"
!export HF_HOME="/mnt/persistent-disk/hf/"

In [4]:
# @title Python imports

import enum
import re
import string

# We import JAX and some related packages.
import chex
import jax
import jax.numpy as jnp
import flax.linen as nn
import optax
from functools import partial

# For LoRA
import lorax

# We will use HuggingFace's dataset, tokenizer, and model classes.
from transformers import AutoModelForCausalLM, AutoConfig, AutoTokenizer, default_data_collator
from datasets import Dataset, load_dataset, concatenate_datasets
import torch

# Finally, we import Gemma.
from gemma import params as params_lib
from gemma import sampler as sampler_lib
from gemma import transformer as transformer_lib
import sentencepiece as spm


In [5]:
jax.devices()

[TpuDevice(id=0, process_index=0, coords=(0,0,0), core_on_chip=0),
 TpuDevice(id=1, process_index=0, coords=(0,0,0), core_on_chip=1),
 TpuDevice(id=2, process_index=0, coords=(1,0,0), core_on_chip=0),
 TpuDevice(id=3, process_index=0, coords=(1,0,0), core_on_chip=1),
 TpuDevice(id=4, process_index=0, coords=(0,1,0), core_on_chip=0),
 TpuDevice(id=5, process_index=0, coords=(0,1,0), core_on_chip=1),
 TpuDevice(id=6, process_index=0, coords=(1,1,0), core_on_chip=0),
 TpuDevice(id=7, process_index=0, coords=(1,1,0), core_on_chip=1)]

## Fine tuning the Gemma model

In [76]:
import flax
from flax.traverse_util import flatten_dict

def print_params(params):
    flat_params = flatten_dict(params)    
    for path, param in flat_params.items():
        # Join the path components to create a string name
        name = "/".join(str(x) for x in path)
        print(f"Name: {name}")
        print(f"Shape: {param.shape}")
        print(f"dtype: {param.dtype}")
        print(f"Value: {param}")
        print("-" * 40)

## Try LoRA with simpleNN

In [145]:
import pdb

In [146]:
input_dim = 1 
hidden_dim = 2
output_dim = 1

In [147]:
class SimpleNN(nn.Module):
    hidden_dim: int
    output_dim: int

    @nn.compact
    def __call__(self, x):
        x = nn.Dense(features=self.hidden_dim)(x)
        x = nn.relu(x)
        x = nn.Dense(features=self.output_dim)(x)
        return x

In [148]:
model = SimpleNN(hidden_dim=hidden_dim, output_dim=output_dim)

In [149]:
params = model.init(jax.random.PRNGKey(99), jnp.ones(shape=(1, input_dim)))

In [150]:
params = params["params"]

In [151]:
print_params(params)

Name: Dense_0/kernel
Shape: (1, 2)
dtype: float32
Value: [[-0.17805344  0.61719763]]
----------------------------------------
Name: Dense_0/bias
Shape: (2,)
dtype: float32
Value: [0. 0.]
----------------------------------------
Name: Dense_1/kernel
Shape: (2, 1)
dtype: float32
Value: [[ 0.330739  ]
 [-0.07981248]]
----------------------------------------
Name: Dense_1/bias
Shape: (1,)
dtype: float32
Value: [0.]
----------------------------------------


In [152]:
x = model.apply({"params": params}, jnp.ones(input_dim))
x

Array([-0.04926008], dtype=float32)

In [153]:
import jax
import jax.numpy as jnp
import flax.linen as nn
from flax.training import train_state  # Useful dataclass to keep train state
import optax

In [154]:
state = train_state.TrainState.create(
    apply_fn=model.apply,  # forward pass func
    params=params,   # model weights
    tx=optax.sgd(learning_rate=0.1)  # optimizer func
)

In [155]:
def forward_pass(params, state, batch):
  input_images, labels = batch
    
  # call forward pass function.
  logits = state.apply_fn({"params": params}, input_images)

  # compute loss
  loss = optax.squared_error(logits, labels)
  loss = loss.mean()
  return loss, logits

In [156]:
def backward_pass(state, batch):
  # create a function to compute gradients wrt to loss
  # returned by our `forward_pass` function.
  grad_fn = jax.value_and_grad(forward_pass, argnums=(0), has_aux=True)

  # compute gradients.
  (loss, _), grads = grad_fn(state.params, state, batch)
  pdb.set_trace()
  print(grads)


  # apply gradients.
  state = state.apply_gradients(grads=grads)

  return state, loss

In [157]:
def train_step(state, batch):
  return backward_pass(state, batch)

In [158]:
batch = (jnp.ones(shape=(1, input_dim)), jnp.zeros(shape=(1, output_dim)))

In [159]:
state, loss = train_step(state, batch)

> [0;32m/tmp/ipykernel_7199/1853935344.py[0m(9)[0;36mbackward_pass[0;34m()[0m
[0;32m      7 [0;31m  [0;34m([0m[0mloss[0m[0;34m,[0m [0m_[0m[0;34m)[0m[0;34m,[0m [0mgrads[0m [0;34m=[0m [0mgrad_fn[0m[0;34m([0m[0mstate[0m[0;34m.[0m[0mparams[0m[0;34m,[0m [0mstate[0m[0;34m,[0m [0mbatch[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m      8 [0;31m  [0mpdb[0m[0;34m.[0m[0mset_trace[0m[0;34m([0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m----> 9 [0;31m  [0mprint[0m[0;34m([0m[0mgrads[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     10 [0;31m[0;34m[0m[0m
[0m[0;32m     11 [0;31m[0;34m[0m[0m
[0m


ipdb>  grads


{'Dense_0': {'bias': Array([0.        , 0.00786314], dtype=float32), 'kernel': Array([[0.        , 0.00786314]], dtype=float32)}, 'Dense_1': {'bias': Array([-0.09852015], dtype=float32), 'kernel': Array([[ 0.       ],
       [-0.0608064]], dtype=float32)}}


ipdb>  c


{'Dense_0': {'bias': Array([0.        , 0.00786314], dtype=float32), 'kernel': Array([[0.        , 0.00786314]], dtype=float32)}, 'Dense_1': {'bias': Array([-0.09852015], dtype=float32), 'kernel': Array([[ 0.       ],
       [-0.0608064]], dtype=float32)}}


In [160]:
print_params(state.params)

Name: Dense_0/bias
Shape: (2,)
dtype: float32
Value: [ 0.         -0.00078631]
----------------------------------------
Name: Dense_0/kernel
Shape: (1, 2)
dtype: float32
Value: [[-0.17805344  0.6164113 ]]
----------------------------------------
Name: Dense_1/bias
Shape: (1,)
dtype: float32
Value: [0.00985202]
----------------------------------------
Name: Dense_1/kernel
Shape: (2, 1)
dtype: float32
Value: [[ 0.330739  ]
 [-0.07373184]]
----------------------------------------


## LoRA

In [163]:
from lorax.constants import LORA_FULL, LORA_FREEZE

def decision_fn(path, param):
    if 'embedding' in path:
        print(f'Fully finetuning param {path}')
        return LORA_FULL
    dim = 3
    print(f'Using LoRA with dim={dim} for param {path}')
    return dim

In [164]:
lora_spec = lorax.simple_spec(params, decision_fn=decision_fn, tune_vectors=True)

Using LoRA with dim=3 for param (DictKey(key='Dense_0'), DictKey(key='kernel'))
Using LoRA with dim=3 for param (DictKey(key='Dense_1'), DictKey(key='kernel'))


In [165]:
lora_spec

{'Dense_0': {'bias': -1, 'kernel': 3}, 'Dense_1': {'bias': -1, 'kernel': 3}}

In [172]:
lora_params = lorax.init_lora(params, lora_spec, jax.random.PRNGKey(1))

In [173]:
print_params(lora_params)

Name: Dense_0/bias
Shape: (2,)
dtype: float32
Value: [0. 0.]
----------------------------------------
Name: Dense_0/kernel
Shape: (1, 2)
dtype: float32
Value: LoraWeight(shape=(1, 2), dtype=dtype('float32'), w=Array([[-0.17805344,  0.61719763]], dtype=float32), a=Array([[ 0.0021586 , -0.00557593],
       [ 0.00328551,  0.00336264],
       [ 0.02129472, -0.000558  ]], dtype=float32), b=Array([[0., 0., 0.]], dtype=float32), alpha=1.0)
----------------------------------------
Name: Dense_1/bias
Shape: (1,)
dtype: float32
Value: [0.]
----------------------------------------
Name: Dense_1/kernel
Shape: (2, 1)
dtype: float32
Value: LoraWeight(shape=(2, 1), dtype=dtype('float32'), w=Array([[ 0.330739  ],
       [-0.07981248]], dtype=float32), a=Array([[ 0.00025438],
       [-0.00417447],
       [-0.01784344]], dtype=float32), b=Array([[0., 0., 0.],
       [0., 0., 0.]], dtype=float32), alpha=1.0)
----------------------------------------


In [174]:
lora_model = lorax.lora(model.apply)

In [175]:
lora_model({"params": lora_params}, jnp.ones(shape=(1, input_dim)))

Array([[-0.04926008]], dtype=float32)

In [176]:
state = train_state.TrainState.create(
    apply_fn=lora_model,  # forward pass func
    params=lora_params,   # model weights
    tx=lorax.wrap_optimizer(optax.sgd(learning_rate=0.1), lora_spec)  # optimizer func
)

In [177]:
def forward_pass(params, state, batch):
  input_images, labels = batch
    
  # call forward pass function.
  logits = state.apply_fn({"params": params}, input_images)

  # compute loss
  loss = optax.squared_error(logits, labels)
  loss = loss.mean()
  return loss, logits

In [183]:
def backward_pass(state, batch):
  # create a function to compute gradients wrt to loss
  # returned by our `forward_pass` function.
  grad_fn = jax.value_and_grad(forward_pass, argnums=(0), has_aux=True)

  # compute gradients.
  (loss, _), grads = grad_fn(state.params, state, batch)
  print(grads)

  # apply gradients.
  state = state.apply_gradients(grads=grads)

  return state, loss

In [179]:
def train_step(state, batch):
  return backward_pass(state, batch)

In [180]:
batch = (jnp.ones(shape=(1, input_dim)), jnp.zeros(shape=(1, output_dim)))

In [197]:
%%capture
for _ in range(1000):
    state, loss = train_step(state, batch)

In [198]:
print_params(state.params)

Name: Dense_0/bias
Shape: (2,)
dtype: float32
Value: [ 0.        -0.0039066]
----------------------------------------
Name: Dense_0/kernel
Shape: (1, 2)
dtype: float32
Value: LoraWeight(shape=(1, 2), dtype=dtype('float32'), w=Array([[-0.17805344,  0.61719763]], dtype=float32), a=Array([[ 0.0021586 , -0.00557593],
       [ 0.00328551,  0.00336265],
       [ 0.02129472, -0.000558  ]], dtype=float32), b=Array([[ 7.2609791e-06, -4.3788386e-06,  7.2662425e-07]], dtype=float32), alpha=1.0)
----------------------------------------
Name: Dense_1/bias
Shape: (1,)
dtype: float32
Value: [0.04894758]
----------------------------------------
Name: Dense_1/kernel
Shape: (2, 1)
dtype: float32
Value: LoraWeight(shape=(2, 1), dtype=dtype('float32'), w=Array([[ 0.330739  ],
       [-0.07981248]], dtype=float32), a=Array([[ 0.00025439],
       [-0.00417465],
       [-0.01784424]], dtype=float32), b=Array([[ 0.0000000e+00,  0.0000000e+00,  0.0000000e+00],
       [ 2.5544762e-06, -4.1919662e-05, -1.7918246

In [132]:
print_params(state.params)

Name: Dense_0/bias
Shape: (2,)
dtype: float32
Value: [ 0.         -0.00141436]
----------------------------------------
Name: Dense_0/kernel
Shape: (1, 2)
dtype: float32
Value: LoraWeight(shape=(1, 2), dtype=dtype('float32'), w=Array([[-0.17805344,  0.61719763]], dtype=float32), a=Array([[-0.01458195, -0.0204706 ],
       [-0.01424288,  0.011684  ],
       [-0.00975838, -0.01271841]], dtype=float32), b=Array([[ 9.6509393e-06, -5.5084647e-06,  5.9961412e-06]], dtype=float32), alpha=1.0)
----------------------------------------
Name: Dense_1/bias
Shape: (1,)
dtype: float32
Value: [0.01772106]
----------------------------------------
Name: Dense_1/kernel
Shape: (2, 1)
dtype: float32
Value: LoraWeight(shape=(2, 1), dtype=dtype('float32'), w=Array([[ 0.330739  ],
       [-0.07981248]], dtype=float32), a=Array([[-0.00066073],
       [ 0.00166766],
       [ 0.01177997]], dtype=float32), b=Array([[ 0.0000000e+00,  0.0000000e+00,  0.0000000e+00],
       [-2.4075107e-06,  6.0765069e-06,  4.29229

In [138]:
merged_params = lorax.merge_params(lora_params)


ValueError: INVALID_ARGUMENT: Invalid buffer passed to Execute() as argument 0 to replica 0: INVALID_ARGUMENT: Donation requested for invalid buffer

In [None]:
orig_model_output = model.apply({"params": merged_params}, jnp.ones(shape=(1, input_dim)))