In [1]:
%%capture
!pip install --upgrade kagglehub -q
!pip install ipywidgets -q
!pip install tensorflow-cpu -q
!pip install tensorflow_datasets -q
!pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cpu -q
!pip install git+https://github.com/felafax/gemma.git -q
!pip install qax -q
!pip install jax-lorax -q

In [3]:
import os

# Single definition of the Hugging Face cache location used below.
HF_CACHE_DIR = '/mnt/persistent-disk/hf/'

# Set the variables on the kernel process itself. A `!export ...` line runs in
# a throwaway subshell and changes nothing for the kernel, and child processes
# started later with `!` inherit os.environ, so no shell export is needed.
os.environ['HF_HUB_CACHE'] = HF_CACHE_DIR
os.environ['HF_HOME'] = HF_CACHE_DIR

In [7]:
# @title Python imports

import enum
import re
import string
import pdb

# We import JAX and some related packages.
import chex
import jax
import jax.numpy as jnp
import flax.linen as nn
import optax
from functools import partial

# For LoRA
import lorax

# We will use HuggingFace's dataset, tokenizer, and model classes.
from transformers import AutoModelForCausalLM, AutoConfig, AutoTokenizer, default_data_collator
from datasets import Dataset, load_dataset, concatenate_datasets
import torch

# Finally, we import Gemma.
from gemma import params as params_lib
from gemma import sampler as sampler_lib
from gemma import transformer as transformer_lib
import sentencepiece as spm


In [5]:
# List the devices JAX can see from this process.
jax.devices()

[TpuDevice(id=0, process_index=0, coords=(0,0,0), core_on_chip=0),
 TpuDevice(id=1, process_index=0, coords=(1,0,0), core_on_chip=0),
 TpuDevice(id=2, process_index=0, coords=(0,1,0), core_on_chip=0),
 TpuDevice(id=3, process_index=0, coords=(1,1,0), core_on_chip=0)]

## Fine-tuning the Gemma model

In [46]:
import flax
from flax.traverse_util import flatten_dict

def print_params(params):
    """Print the name and device-sharding layout of every array in a parameter tree.

    Args:
        params: Nested dict of arrays (for example ``state.params``). It is
            flattened with ``flatten_dict`` so each leaf is visited once.

    Returns:
        None. Everything is written to stdout.
    """
    for path, param in flatten_dict(params).items():
        # Join the path components to create a string name.
        name = "/".join(str(part) for part in path)
        print(f"Name: {name}")
        # visualize_array_sharding renders the layout itself and returns None,
        # so it is called directly; wrapping it in print() emitted a stray
        # "None" after every parameter.
        # NOTE(review): the visualizer may reject arrays that are not 1-D or
        # 2-D — confirm before running this on higher-rank parameters.
        jax.debug.visualize_array_sharding(param)
        print("-" * 40)

## Try LoRA with SimpleNN

In [8]:
# Layer widths of the toy network: input features, hidden units, outputs.
input_dim, hidden_dim, output_dim = 1, 2, 1

In [13]:

def create_sharding(pspec):
    """Build a NamedSharding on the notebook-level ``mesh`` for ``pspec``.

    Args:
        pspec: Partition spec handed to ``NamedSharding`` unchanged.

    Returns:
        The ``NamedSharding`` pairing ``mesh`` with ``pspec``.
    """
    sharding = NamedSharding(mesh, pspec)
    return sharding

In [15]:
def mesh_sharding(pspec: "PartitionSpec") -> "NamedSharding":
  """Build a NamedSharding on the notebook-level ``mesh`` for ``pspec``.

  The annotations are written as strings so that defining this function does
  not require ``PartitionSpec`` / ``NamedSharding`` to be imported yet; in this
  notebook they are imported by a later cell, and unquoted annotations would
  raise NameError when this cell runs first.

  Args:
    pspec: Partition spec handed to ``NamedSharding`` unchanged.

  Returns:
    The ``NamedSharding`` pairing ``mesh`` with ``pspec``.
  """
  return NamedSharding(mesh, pspec)

In [21]:
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, PartitionSpec, NamedSharding
from flax import linen as nn
from flax.training import train_state
import optax
from jax.sharding import Mesh, PartitionSpec, NamedSharding
from jax.lax import with_sharding_constraint
from jax.experimental import mesh_utils
import functools
from functools import partial
import jax
import jax.numpy as jnp

class SimpleNN(nn.Module):
    """Two-layer perceptron: Dense -> ReLU -> Dense."""

    # Number of units in the first (hidden) Dense layer.
    hidden_dim: int
    # Number of units in the final Dense layer.
    output_dim: int

    @nn.compact
    def __call__(self, x):
        """Apply the hidden layer with ReLU, then the output layer, to ``x``."""
        hidden = nn.relu(nn.Dense(features=self.hidden_dim)(x))
        return nn.Dense(features=self.output_dim)(hidden)

# Set up the device mesh: a single 'data' slice with every visible device laid
# out along the 'model' axis. The mesh is sized from len(devices) rather than a
# literal 4, so the cell also works on hosts with a different device count
# (on a 4-device host the result is the same 1 x 4 mesh).
devices = jax.devices()
device_mesh = mesh_utils.create_device_mesh((1, len(devices)), devices=devices)
mesh = Mesh(devices=device_mesh, axis_names=('data', 'model'))

print(mesh)

Mesh('data': 1, 'model': 4)


In [16]:
# Example input of shape (1, input_dim): the first axis is mapped to the mesh's
# 'data' axis and the second axis is left unsharded.
x_sharding = mesh_sharding(PartitionSpec('data', None))  # dimensions: (batch, length)
x = jax.device_put(jnp.ones(shape=(1, input_dim)), x_sharding)
jax.debug.visualize_array_sharding(x)

In [17]:
def init_fn(key, x, model, optimizer):
    """Initialise ``model`` and wrap its parameters in a ``TrainState``.

    Args:
        key: PRNG key forwarded to ``model.init``.
        x: Example input forwarded to ``model.init``.
        model: Module providing ``init`` and ``apply``.
        optimizer: Gradient transformation stored as the state's ``tx``.

    Returns:
        A ``train_state.TrainState`` holding ``model.apply``, the ``'params'``
        collection produced by ``model.init`` and ``optimizer``.
    """
    params = model.init(key, x)['params']
    return train_state.TrainState.create(
        apply_fn=model.apply,
        params=params,
        tx=optimizer,
    )

In [34]:
# Instantiate the toy network and an Adam optimizer with learning rate 1e-3.
model = SimpleNN(hidden_dim=hidden_dim, output_dim=output_dim)
optimizer = optax.adam(learning_rate=1e-3)

In [38]:
# Evaluate init_fn abstractly with jax.eval_shape to obtain the structure of
# the train state. The dummy input uses the same (1, input_dim) shape as the
# real initialisation call, so the abstract state matches the real one even
# when input_dim is changed from 1.
abstract_variables = jax.eval_shape(
    functools.partial(
        init_fn,
        model=model,
        optimizer=optimizer,
    ),
    jax.random.PRNGKey(99),
    jnp.ones(shape=(1, input_dim)),
)


In [40]:
# Derive the shardings for the abstract train state on `mesh`.
# NOTE(review): SimpleNN attaches no partitioning metadata to its parameters,
# so this presumably resolves to fully replicated shardings — confirm.
state_sharding = nn.get_sharding(abstract_variables, mesh)

In [41]:
# Jit-compile init_fn. Arguments 2 and 3 (model, optimizer) are static. The
# PRNG key gets a sharding built from an empty PartitionSpec (no mesh axes
# named) and the example input uses x_sharding. The empty spec is written as
# PartitionSpec() — the type mesh_sharding is annotated to take — instead of a
# bare tuple `()`.
jit_init_fn = jax.jit(
    init_fn,
    static_argnums=(2, 3),
    in_shardings=(mesh_sharding(pspec=PartitionSpec()), x_sharding),  # PRNG key and x
    out_shardings=state_sharding,
)

In [42]:
# Run the jitted initialiser to build the concrete train state; it uses the
# same PRNG seed (99) as the abstract evaluation.
initialized_state = jit_init_fn(jax.random.PRNGKey(99), jnp.ones(shape=(1, input_dim)), model, optimizer)

In [47]:
# Print the name and sharding visualisation of each initialised parameter.
print_params(initialized_state.params)

Name: Dense_0/bias


None
----------------------------------------
Name: Dense_0/kernel


None
----------------------------------------
Name: Dense_1/bias


None
----------------------------------------
Name: Dense_1/kernel


None
----------------------------------------


In [53]:
def forward_pass(params, state, batch):
  """Run the model on one batch and score it with mean squared error.

  Args:
    params: Parameter tree passed to the model as ``{"params": params}``.
    state: Object whose ``apply_fn`` performs the forward computation.
    batch: Pair ``(inputs, labels)``.

  Returns:
    Tuple ``(loss, logits)``: the mean of the element-wise squared error
    between ``logits`` and ``labels``, and the model outputs themselves.
  """
  inputs, targets = batch

  # Model outputs for this batch.
  logits = state.apply_fn({"params": params}, inputs)

  # Element-wise squared error, reduced to its mean.
  mean_loss = optax.squared_error(logits, targets).mean()
  return mean_loss, logits

In [54]:
def backward_pass(state, batch):
  """Compute gradients for one batch and apply them to ``state``.

  Args:
    state: Train state providing ``params`` and ``apply_gradients``.
    batch: Batch forwarded unchanged to ``forward_pass``.

  Returns:
    The state returned by ``state.apply_gradients``.
  """
  # Differentiate forward_pass with respect to its first argument (the params);
  # has_aux=True because forward_pass returns (loss, logits).
  loss_and_grad = jax.value_and_grad(forward_pass, argnums=0, has_aux=True)
  (_, _), param_grads = loss_and_grad(state.params, state, batch)

  return state.apply_gradients(grads=param_grads)

In [50]:
@functools.partial(
    jax.jit,
    in_shardings=(state_sharding, x_sharding),
    out_shardings=state_sharding,
)
def train_step(state, batch):
  """Jitted single training step: delegates to ``backward_pass``.

  ``state`` enters and leaves with ``state_sharding``; ``batch`` is given
  ``x_sharding``.
  """
  updated_state = backward_pass(state, batch)
  return updated_state

In [55]:
# One training example: an all-ones input paired with an all-zeros target.
batch = (
    jnp.ones((1, input_dim)),
    jnp.zeros((1, output_dim)),
)

In [57]:
# Run one training step with `mesh` as the active mesh context.
with mesh:
    new_state = train_step(initialized_state, batch)

In [58]:
# Display the updated train state (step counter, params and optimizer state).
new_state

TrainState(step=Array(1, dtype=int32, weak_type=True), apply_fn=<bound method Module.apply of SimpleNN(
    # attributes
    hidden_dim = 2
    output_dim = 1
)>, params={'Dense_0': {'bias': Array([ 0.        , -0.00099999], dtype=float32), 'kernel': Array([[-0.17805348,  0.6162018 ]], dtype=float32)}, 'Dense_1': {'bias': Array([0.00099999], dtype=float32), 'kernel': Array([[ 0.3307397 ],
       [-0.07881248]], dtype=float32)}}, tx=GradientTransformationExtraArgs(init=<function chain.<locals>.init_fn at 0x7f5e94451090>, update=<function chain.<locals>.update_fn at 0x7f5e94451240>), opt_state=(ScaleByAdamState(count=Array(1, dtype=int32), mu={'Dense_0': {'bias': Array([0.        , 0.00078632], dtype=float32), 'kernel': Array([[0.        , 0.00078632]], dtype=float32)}, 'Dense_1': {'bias': Array([-0.00985208], dtype=float32), 'kernel': Array([[ 0.        ],
       [-0.00608072]], dtype=float32)}}, nu={'Dense_0': {'bias': Array([0.0000000e+00, 6.1829745e-08], dtype=float32), 'kernel': Arr