In [1]:
%%capture
!pip install --upgrade kagglehub -q
!pip install ipywidgets -q
!pip install tensorflow-cpu -q
!pip install tensorflow_datasets -q
!pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cpu -q
!pip install git+https://github.com/felafax/gemma.git -q
!pip install qax -q
!pip install jax-lorax -q

In [3]:
import os

# Point the Hugging Face caches at the persistent disk. This must run before
# `transformers`/`datasets` are imported so they see the values.
# The previous `!export ...` lines were dropped: each `!` command runs in its
# own short-lived subshell, so an `export` there never reaches the kernel
# process; `os.environ` is what takes effect.
HF_CACHE_DIR = '/mnt/persistent-disk/hf/'
os.environ['HF_HUB_CACHE'] = HF_CACHE_DIR
os.environ['HF_HOME'] = HF_CACHE_DIR

In [7]:
# @title Python imports

import enum
import re
import string
import pdb

# We import JAX and some related packages.
import chex
import jax
import jax.numpy as jnp
import flax.linen as nn
import optax
from functools import partial

# For LoRA
import lorax

# We will use HuggingFace's dataset, tokenizer, and model classes.
from transformers import AutoModelForCausalLM, AutoConfig, AutoTokenizer, default_data_collator
from datasets import Dataset, load_dataset, concatenate_datasets
import torch

# Finally, we import Gemma.
from gemma import params as params_lib
from gemma import sampler as sampler_lib
from gemma import transformer as transformer_lib
import sentencepiece as spm


In [5]:
# Show the devices JAX can see; the mesh built later is laid out over these.
jax.devices()

[TpuDevice(id=0, process_index=0, coords=(0,0,0), core_on_chip=0),
 TpuDevice(id=1, process_index=0, coords=(1,0,0), core_on_chip=0),
 TpuDevice(id=2, process_index=0, coords=(0,1,0), core_on_chip=0),
 TpuDevice(id=3, process_index=0, coords=(1,1,0), core_on_chip=0)]

## Fine-tuning the Gemma model

In [180]:
import flax
from flax.traverse_util import flatten_dict
from flax.core.meta import unbox


def print_params(params):
    """Print each parameter's name followed by a picture of its sharding.

    Args:
        params: nested dict of parameters. Leaves may be plain arrays or
            `flax.core.meta.Partitioned` boxes, which are unboxed before
            being visualized.

    Returns:
        None. Everything is written to the cell output.
    """
    flat_params = flatten_dict(params)
    for path, param in flat_params.items():
        # Join the path components to create a string name
        name = "/".join(str(x) for x in path)
        print(f"Name: {name}")
        if isinstance(param, flax.core.meta.Partitioned):
            array = unbox(param)
        else:
            array = param
        # visualize_array_sharding renders the diagram itself and returns
        # None, so it must not be wrapped in print() (that printed "None").
        jax.debug.visualize_array_sharding(array)
        print("-" * 40)

## Try LoRA with SimpleNN

In [127]:
# Sizes for the toy network used in this section.
input_dim = 32  # width of each input row (second axis of the sample batch)
hidden_dim = 8  # passed to SimpleNN as the hidden layer's feature count
output_dim = 1  # passed to SimpleNN as the output feature count; also the target width

In [128]:

def create_sharding(pspec):
    """Build a `NamedSharding` for `pspec` on the notebook-level `mesh`.

    `mesh` is looked up at call time, so it has to be defined before this
    helper is first called.
    """
    named_sharding = NamedSharding(mesh, pspec)
    return named_sharding

In [129]:
# The annotations below are evaluated when the `def` runs, so these names must
# be imported in (or before) this cell for a fresh-kernel run to succeed.
from jax.sharding import NamedSharding, PartitionSpec


def mesh_sharding(pspec: PartitionSpec) -> NamedSharding:
  """Return a `NamedSharding` that applies `pspec` to the notebook-level `mesh`.

  Args:
    pspec: partition spec naming the mesh axes each array dimension is split over.

  Returns:
    A `NamedSharding` bound to the global `mesh`, which is looked up at call
    time and therefore must be defined before the first call.
  """
  return NamedSharding(mesh, pspec)

In [163]:
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, PartitionSpec, NamedSharding
from flax import linen as nn
from flax.training import train_state
import optax
from jax.sharding import Mesh, PartitionSpec, NamedSharding
from jax.lax import with_sharding_constraint
from jax.experimental import mesh_utils
import functools
from functools import partial
import jax
import jax.numpy as jnp

class SimpleNN(nn.Module):
    """Bias-free two-layer MLP: Dense -> ReLU -> Dense.

    The input is constrained to the ('data', 'model') layout and the first
    layer's kernel carries ('data', 'model') partitioning metadata; the second
    layer has no partitioning annotation.
    """
    # Number of features produced by the first Dense layer.
    hidden_dim: int
    # Number of features produced by the second Dense layer.
    output_dim: int

    @nn.compact
    def __call__(self, x):
        # Pin the input's layout across the mesh before any compute happens.
        input_sharding = mesh_sharding(PartitionSpec('data', 'model'))
        x = with_sharding_constraint(x, input_sharding)

        # Xavier-normal init wrapped so the kernel is boxed with axis names.
        partitioned_init = nn.with_partitioning(
            nn.initializers.xavier_normal(), ('data', 'model'))
        hidden = nn.Dense(
            features=self.hidden_dim,
            kernel_init=partitioned_init,
            use_bias=False,
        )(x)
        activated = nn.relu(hidden)

        return nn.Dense(features=self.output_dim, use_bias=False)(activated)

# Set up the device mesh: a single 'data' row with every visible device laid
# out along the 'model' axis. On the 4-chip TPU above this is the same (1, 4)
# mesh as before, but it no longer fails on hosts with a different device count.
# NOTE: arrays sharded along 'model' (input_dim, hidden_dim) must be divisible
# by the number of devices.
devices = jax.devices()
device_mesh = mesh_utils.create_device_mesh((1, len(devices)), devices=devices)
mesh = Mesh(devices=device_mesh, axis_names=('data', 'model'))

print(mesh)

Mesh('data': 1, 'model': 4)


In [164]:
# Build a dummy input, place it on the mesh with the input sharding, and show
# how it is laid out across devices.
sample_batch = jnp.ones(shape=(1, input_dim))
x_sharding = mesh_sharding(PartitionSpec('data', 'model')) # dimensions: (batch, length)
sample_batch = jax.device_put(sample_batch, x_sharding)
# Visualize the array that was just placed; the previous `x` was never defined
# in this notebook and only worked through leftover kernel state.
jax.debug.visualize_array_sharding(sample_batch)

In [165]:
def init_fn(key, x, model, optimizer):
    """Initialize `model` on `x` and wrap the result in a `TrainState`.

    Args:
        key: PRNG key passed to `model.init`.
        x: example input passed to `model.init`.
        model: module providing `init` and `apply`.
        optimizer: gradient transformation stored as the state's `tx`.

    Returns:
        A `train_state.TrainState` holding the 'params' collection from
        `model.init`, with `model.apply` as its `apply_fn`.
    """
    variables = model.init(key, x)
    return train_state.TrainState.create(
        apply_fn=model.apply,
        params=variables['params'],
        tx=optimizer,
    )

In [166]:
# Instantiate the toy network and an Adam optimizer with a fixed learning rate.
model = SimpleNN(hidden_dim=hidden_dim, output_dim=output_dim)
optimizer = optax.adam(learning_rate=1e-3)

In [167]:
# Trace init_fn with jax.eval_shape to obtain the TrainState's structure.
# `model` and `optimizer` are bound up front so that only the PRNG key and
# the sample input are passed through as traced arguments.
abstract_variables = jax.eval_shape(
    functools.partial(init_fn, model=model, optimizer=optimizer),
    jax.random.PRNGKey(99),
    sample_batch,
)


In [135]:
# Inspect the abstract TrainState returned by jax.eval_shape.
abstract_variables

TrainState(step=ShapeDtypeStruct(shape=(), dtype=int32), apply_fn=<bound method Module.apply of SimpleNN(
    # attributes
    hidden_dim = 8
    output_dim = 1
)>, params={'Dense_0': {'kernel': Partitioned(value=ShapeDtypeStruct(shape=(32, 8), dtype=float32), names=('data', 'model'), mesh=None)}, 'Dense_1': {'bias': ShapeDtypeStruct(shape=(1,), dtype=float32), 'kernel': ShapeDtypeStruct(shape=(8, 1), dtype=float32)}}, tx=GradientTransformationExtraArgs(init=<function chain.<locals>.init_fn at 0x7f5e0c4291b0>, update=<function chain.<locals>.update_fn at 0x7f5e0c4292d0>), opt_state=(ScaleByAdamState(count=ShapeDtypeStruct(shape=(), dtype=int32), mu={'Dense_0': {'kernel': Partitioned(value=ShapeDtypeStruct(shape=(32, 8), dtype=float32), names=('data', 'model'), mesh=None)}, 'Dense_1': {'bias': ShapeDtypeStruct(shape=(1,), dtype=float32), 'kernel': ShapeDtypeStruct(shape=(8, 1), dtype=float32)}}, nu={'Dense_0': {'kernel': Partitioned(value=ShapeDtypeStruct(shape=(32, 8), dtype=float32), 

In [168]:
# Derive a sharding for every entry of the abstract TrainState on `mesh`.
# The result is reused below as out_shardings for init and as in/out shardings
# for the train step.
state_sharding = nn.get_sharding(abstract_variables, mesh)

In [138]:
# Inspect the per-entry shardings computed for the TrainState.
state_sharding

TrainState(step=NamedSharding(mesh=Mesh('data': 1, 'model': 4), spec=PartitionSpec()), apply_fn=<bound method Module.apply of SimpleNN(
    # attributes
    hidden_dim = 8
    output_dim = 1
)>, params={'Dense_0': {'kernel': NamedSharding(mesh=Mesh('data': 1, 'model': 4), spec=PartitionSpec('data', 'model'))}, 'Dense_1': {'bias': NamedSharding(mesh=Mesh('data': 1, 'model': 4), spec=PartitionSpec()), 'kernel': NamedSharding(mesh=Mesh('data': 1, 'model': 4), spec=PartitionSpec())}}, tx=GradientTransformationExtraArgs(init=<function chain.<locals>.init_fn at 0x7f5e0c4291b0>, update=<function chain.<locals>.update_fn at 0x7f5e0c4292d0>), opt_state=(ScaleByAdamState(count=NamedSharding(mesh=Mesh('data': 1, 'model': 4), spec=PartitionSpec()), mu={'Dense_0': {'kernel': NamedSharding(mesh=Mesh('data': 1, 'model': 4), spec=PartitionSpec('data', 'model'))}, 'Dense_1': {'bias': NamedSharding(mesh=Mesh('data': 1, 'model': 4), spec=PartitionSpec()), 'kernel': NamedSharding(mesh=Mesh('data': 1, 'mod

In [169]:
# Compile init_fn so the TrainState comes out already laid out according to
# `state_sharding`. `model` and `optimizer` (positions 2 and 3) are static, so
# in_shardings only covers the PRNG key and x.
# The key's sharding uses an explicit empty PartitionSpec() (fully replicated)
# instead of a bare tuple, which is not the type NamedSharding expects.
jit_init_fn = jax.jit(
    init_fn,
    static_argnums=(2, 3),
    in_shardings=(mesh_sharding(PartitionSpec()), x_sharding),  # PRNG key and x
    out_shardings=state_sharding,
)

In [170]:
# Run the jitted initializer with the same PRNG key (99) that was used for the
# abstract trace above.
initialized_state = jit_init_fn(jax.random.PRNGKey(99), sample_batch, model, optimizer)

In [171]:
# Look at the second layer's kernel, which has no partitioning annotation.
initialized_state.params['Dense_1']['kernel']

Array([[-0.14665894],
       [-0.22765416],
       [-0.41170552],
       [-0.47610608],
       [ 0.09983327],
       [-0.55441546],
       [-0.32387885],
       [ 0.44042173]], dtype=float32)

In [181]:
# Print each initialized parameter's name and its sharding visualization.
print_params(initialized_state.params)

Name: Dense_0/kernel


None
----------------------------------------
Name: Dense_1/kernel


None
----------------------------------------


In [182]:
def forward_pass(params, state, batch):
  """Apply the model to one batch and return its mean squared error.

  Args:
    params: parameters passed to `state.apply_fn` under the "params" key.
    state: object exposing `apply_fn`.
    batch: pair `(inputs, targets)`.

  Returns:
    `(loss, predictions)`: the mean of the element-wise squared error, and
    the raw model outputs.
  """
  features, targets = batch
  predictions = state.apply_fn({"params": params}, features)
  mean_loss = optax.squared_error(predictions, targets).mean()
  return mean_loss, predictions

In [183]:
def backward_pass(state, batch):
  """Compute gradients of the loss for `batch` and apply them to `state`.

  Args:
    state: train state providing `params` and `apply_gradients`.
    batch: pair `(inputs, targets)` forwarded to `forward_pass`.

  Returns:
    The state returned by `state.apply_gradients`.
  """
  # Differentiate forward_pass with respect to its first argument (params);
  # has_aux=True because it returns (loss, predictions).
  loss_and_grad = jax.value_and_grad(forward_pass, argnums=0, has_aux=True)
  _, param_grads = loss_and_grad(state.params, state, batch)
  return state.apply_gradients(grads=param_grads)

In [207]:
@functools.partial(
    jax.jit,
    in_shardings=(state_sharding, x_sharding, None),
    out_shardings=state_sharding,
)
def train_step(state, inputs, targets):
  """Jitted single optimization step.

  The state goes in and comes out with `state_sharding`, `inputs` uses
  `x_sharding`, and no sharding is specified for `targets`.
  """
  batch = (inputs, targets)
  updated_state = backward_pass(state, batch)
  return updated_state

In [208]:
# Dummy (inputs, targets) pair: all-ones inputs of shape (1, input_dim) and
# all-zeros targets of shape (1, output_dim).
batch = (jnp.ones(shape=(1, input_dim)), jnp.zeros(shape=(1, output_dim)))

In [209]:
# Place the inputs on the mesh with the same sharding the model input uses.
batch0 = jax.device_put(batch[0], x_sharding)

In [210]:
# Place the targets with the first axis on 'data' and the second axis unsharded.
batch1 = jax.device_put(batch[1], mesh_sharding(PartitionSpec('data', None)))

In [211]:
# Run one training step inside the mesh context.
with mesh:
    new_state = train_step(initialized_state, batch0, batch1)

In [212]:
# Display the TrainState returned by the training step.
new_state

TrainState(step=Array(1, dtype=int32, weak_type=True), apply_fn=<bound method Module.apply of SimpleNN(
    # attributes
    hidden_dim = 8
    output_dim = 1
)>, params={'Dense_0': {'kernel': Partitioned(value=Array([[ 0.29087412, -0.05503109, -0.20790817, -0.29750723,  0.07447191,
         0.0784103 ,  0.01066407,  0.49764064],
       [ 0.3668062 , -0.34450766,  0.1546367 ,  0.32535422,  0.12399623,
        -0.29323256,  0.0124924 ,  0.04483453],
       [ 0.03978511,  0.13564567, -0.21708795, -0.20757158, -0.11483077,
         0.06651969,  0.30071765,  0.21988122],
       [-0.14387387,  0.13240269, -0.21598163,  0.18907267,  0.15003328,
         0.383029  ,  0.27404705,  0.34015822],
       [-0.14029394, -0.42128232, -0.0436018 ,  0.3238122 , -0.07844601,
        -0.01554115,  0.06825367, -0.01932324],
       [-0.23955524, -0.01933642,  0.0330421 , -0.0531561 ,  0.1021446 ,
         0.14570239, -0.33155596,  0.1250879 ],
       [ 0.02744705, -0.21270357,  0.10063621, -0.37119505,  0.