In [1]:
%%capture
!pip install --upgrade kagglehub -q
!pip install ipywidgets -q
!pip install tensorflow-cpu -q
!pip install tensorflow_datasets -q
!pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cpu -q
!pip install git+https://github.com/felafax/gemma.git -q
!pip install qax -q
!pip install jax-lorax -q

In [2]:
import os
os.environ['HF_HUB_CACHE'] = '/mnt/persistent-disk/hf/'
os.environ['HF_HOME'] = '/mnt/persistent-disk/hf/'
!export HF_HUB_CACHE="/mnt/persistent-disk/hf/"
!export HF_HOME="/mnt/persistent-disk/hf/"

In [3]:
# @title Python imports

import enum
import re
import string
import pdb

# We import JAX and some related packages.
import chex
import jax
import jax.numpy as jnp
import flax.linen as nn
import optax
from functools import partial

# For LoRA
import lorax

# We will use HuggingFace's dataset, tokenizer, and model classes.
from transformers import AutoModelForCausalLM, AutoConfig, AutoTokenizer, default_data_collator
from datasets import Dataset, load_dataset, concatenate_datasets
import torch

# Finally, we import Gemma.
from gemma import params as params_lib
from gemma import sampler as sampler_lib
from gemma import transformer as transformer_lib
import sentencepiece as spm


In [4]:
jax.devices()

[TpuDevice(id=0, process_index=0, coords=(0,0,0), core_on_chip=0),
 TpuDevice(id=1, process_index=0, coords=(0,0,0), core_on_chip=1),
 TpuDevice(id=2, process_index=0, coords=(1,0,0), core_on_chip=0),
 TpuDevice(id=3, process_index=0, coords=(1,0,0), core_on_chip=1),
 TpuDevice(id=4, process_index=0, coords=(0,1,0), core_on_chip=0),
 TpuDevice(id=5, process_index=0, coords=(0,1,0), core_on_chip=1),
 TpuDevice(id=6, process_index=0, coords=(1,1,0), core_on_chip=0),
 TpuDevice(id=7, process_index=0, coords=(1,1,0), core_on_chip=1)]

## Fine tuning the Gemma model

In [8]:
import flax
from flax.traverse_util import flatten_dict
from flax.core.meta import unbox


def print_params(params):
    flat_params = flatten_dict(params)    
    for path, param in flat_params.items():
        # Join the path components to create a string name
        name = "/".join(str(x) for x in path)
        print(f"Name: {name}")
        # print(f"Shape: {param.shape}")
        # print(f"dtype: {param.dtype}")
        # print(f"Value: {param}")
        if isinstance(param, flax.core.meta.Partitioned):
            array = unbox(param)
        else:
            array = param
        print(jax.debug.visualize_array_sharding(array))
        print("-" * 40)

In [122]:
import jax
import jax.numpy as jnp
from jax import lax
from flax.linen import Module, compact
from flax.linen.initializers import zeros_init
from flax.linen.dtypes import promote_dtype
from typing import Any, Callable

default_kernel_init = jax.nn.initializers.lecun_normal()

class LoRADense(Module):
    features: int
    use_bias: bool = True
    dtype: Any = None
    param_dtype: Any = jnp.float32
    precision: Any = None
    kernel_init: Callable = default_kernel_init
    bias_init: Callable = zeros_init()
    lora_rank: int = 8
    lora_alpha: float = 16

    @compact
    def __call__(self, inputs: Any) -> Any:
        kernel = self.variable(
            'original_params', 'kernel',
            self.kernel_init,
            jax.random.PRNGKey(0),  # You might want to pass a proper key
            (jnp.shape(inputs)[-1], self.features),
            self.param_dtype
        )
        
        if self.use_bias:
            bias = self.variable(
                'original_params', 'bias',
                self.bias_init,
                jax.random.PRNGKey(1),  # You might want to pass a proper key
                (self.features,),
                self.param_dtype
            )
        else:
            bias = None

        # LoRA weights (these remain as trainable parameters)
        lora_a = self.param(
            'lora_a',
            default_kernel_init,
            (jnp.shape(inputs)[-1], self.lora_rank),
            self.param_dtype
        )
        lora_b = self.param(
            'lora_b',
            zeros_init(),
            (self.lora_rank, self.features),
            self.param_dtype
        )

        inputs, kernel_value, lora_a, lora_b, bias_value = promote_dtype(
            inputs, kernel.value, lora_a, lora_b, 
            None if bias is None else bias.value, 
            dtype=self.dtype
        )

        y = lax.dot_general(
            inputs,
            jax.lax.stop_gradient(kernel_value),
            (((inputs.ndim - 1,), (0,)), ((), ())),
            precision=self.precision,
        )

        # LoRA computation
        lora_output = lax.dot_general(
            inputs,
            lora_a,
            (((inputs.ndim - 1,), (0,)), ((), ())),
            precision=self.precision,
        )
        lora_output = lax.dot_general(
            lora_output,
            lora_b,
            (((lora_output.ndim - 1,), (0,)), ((), ())),
            precision=self.precision,
        )
        y += (self.lora_alpha / self.lora_rank) * lora_output

        if bias_value is not None:
            y += jnp.reshape(bias_value, (1,) * (y.ndim - 1) + (-1,))
        return y

In [112]:
input_dim = 4
hidden_dim = 2
output_dim = 1

In [123]:
class Model(Module):
    hidden_dim: int
    output_dim: int

    @compact
    def __call__(self, x):
        x = LoRADense(features=self.hidden_dim, 
                      kernel_init=jax.nn.initializers.xavier_normal(),
                      use_bias=False)(x)
        x = jax.nn.relu(x)
        x = LoRADense(features=self.output_dim, 
                      use_bias=False)(x)
        return x

# model = Model(hidden_dim=hidden_dim, output_dim=output_dim)
# params = model.init({
#     'params': jax.random.PRNGKey(0), 
#     'original_params': jax.random.PRNGKey(1)
# }, jnp.zeros((1, 8)))

In [124]:
# Define dimensions
input_dim = 4
hidden_dim = 2
output_dim = 1

# Initialize model
model = Model(hidden_dim=hidden_dim, output_dim=output_dim)
key = jax.random.PRNGKey(0)
params = model.init(key, jnp.zeros((1, input_dim)))

# Define batch
batch = (jnp.ones(shape=(1, input_dim)), jnp.zeros(shape=(1, output_dim)))

# Define loss function (Mean Squared Error)
def loss_fn(params, batch):
    inputs, targets = batch
    outputs = model.apply(params, inputs)
    return jnp.mean((outputs - targets) ** 2)

# Define forward pass and gradient function
forward_pass = lambda params, batch: loss_fn(params, batch)
grad_fn = jax.grad(forward_pass)

# Compute gradients
grads = grad_fn(params, batch)

# Print gradients
print("Gradients:")
jax.tree_util.tree_map(lambda g: print(g.shape, jnp.mean(g)), grads)

Gradients:
(4, 2) 0.0
(2, 1) 0.0
(4, 8) 0.0
(8, 2) -0.99083626
(2, 8) 0.0
(8, 1) 2.5092876


{'original_params': {'LoRADense_0': {'kernel': None},
  'LoRADense_1': {'kernel': None}},
 'params': {'LoRADense_0': {'lora_a': None, 'lora_b': None},
  'LoRADense_1': {'lora_a': None, 'lora_b': None}}}

In [125]:
print(grads)


{'original_params': {'LoRADense_0': {'kernel': Array([[0., 0.],
       [0., 0.],
       [0., 0.],
       [0., 0.]], dtype=float32)}, 'LoRADense_1': {'kernel': Array([[0.],
       [0.]], dtype=float32)}}, 'params': {'LoRADense_0': {'lora_a': Array([[0., 0., 0., 0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 0., 0., 0., 0.]], dtype=float32), 'lora_b': Array([[ 0.       ,  4.9100347],
       [-0.       , -5.7454195],
       [ 0.       ,  1.7669095],
       [-0.       , -4.0578656],
       [-0.       , -1.7626989],
       [-0.       , -7.924409 ],
       [ 0.       ,  4.2498674],
       [-0.       , -7.2897983]], dtype=float32)}, 'LoRADense_1': {'lora_a': Array([[0., 0., 0., 0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 0., 0., 0., 0.]], dtype=float32), 'lora_b': Array([[  8.193271 ],
       [ 16.42541  ],
       [ -3.104372 ],
       [-18.169437 ],
       [ 15.23544  ],
       [  9.356257 ],
       [ -8.72578  

In [114]:
params

{'original_params': {'LoRADense_0': {'kernel': Array([[ 0.04115987,  0.8384647 ],
          [ 0.14326866,  0.53145474],
          [ 0.16196674, -0.39632457],
          [ 0.31242397,  0.4965637 ],
          [-0.3758965 , -0.21959446],
          [ 0.30370873,  0.39126053],
          [-0.15885408, -0.53257084],
          [-0.75438243,  0.13228753]], dtype=float32)},
  'LoRADense_1': {'kernel': Array([[-0.59603935],
          [ 0.6490678 ]], dtype=float32)}},
 'params': {'LoRADense_0': {'lora_a': Array([[-0.77251935, -0.34922582, -0.16781904,  0.45006755,  0.5829724 ,
           -0.20926498,  0.09819026,  0.41416237],
          [-0.15905392, -0.56037694,  0.00484613,  0.04173131, -0.00559456,
           -0.4811671 , -0.03955451,  0.02943526],
          [-0.54997915,  0.07164625, -0.37911326,  0.311233  , -0.47716334,
            0.16550043, -0.4135211 , -0.6538287 ],
          [ 0.17607015,  0.15953197, -0.01787516,  0.01488424, -0.33595267,
           -0.21633081, -0.0459558 ,  0.34950352

In [116]:
model.apply(params, jnp.ones((1, 8)))

Array([[0.8058445]], dtype=float32)

In [117]:
batch = (jnp.ones(shape=(1, input_dim)), jnp.zeros(shape=(1, output_dim)))

In [118]:
# Initialize function
def init_fn(key, x, model, optimizer):
    params = model.init(key, x)  # Initialize the model
    state = train_state.TrainState.create(  # Create a `TrainState`
        apply_fn=model.apply,
        params=params['params'],
        tx=optimizer)
    return state

def forward_pass(params, state, batch):
  input_images, labels = batch
    
  # call forward pass function.
  logits = state.apply_fn({"params": params}, input_images)

  # compute loss
  loss = optax.squared_error(logits, labels)
  loss = loss.mean()
  return loss, logits

In [None]:

grad_fn = jax.value_and_grad(forward_pass, argnums=(0), has_aux=True)

# compute gradients.
(loss, _), grads = grad_fn(state.params, state, batch)


In [66]:
params = model.init(jax.random.PRNGKey(99), jnp.zeros((1, 8)))['params']

In [67]:
params

{'Dense_0': {'kernel': Array([[-0.49519178, -0.2765769 ,  0.29829058,  0.1825179 ,  0.12744386,
           0.01904512,  0.5438577 , -0.5497272 ],
         [-0.2905944 , -0.2480456 , -0.20327911, -0.75905126,  0.14114662,
           0.4145826 , -0.19347848,  0.47221926],
         [ 0.668468  , -0.2039723 ,  0.38946563,  0.13857874, -0.45398667,
          -0.24390857, -0.32775238,  0.3971401 ],
         [-0.29118022,  0.05225108,  0.40288976, -0.13084824, -0.25388727,
          -0.36863217, -0.638132  , -0.15610406],
         [ 0.18570937,  0.26377958, -0.02895478,  0.05723229, -0.39722154,
          -0.2310003 ,  0.7511915 , -0.2714732 ],
         [-0.3982755 ,  0.36451694,  0.20695516, -0.01715776, -0.44657674,
           0.21720676,  0.5451954 ,  0.56244195],
         [-0.2991815 ,  0.4356464 , -0.6718789 ,  0.16581155,  0.05301353,
           0.32873207,  0.16510576, -0.34548888],
         [-0.57636905, -0.15388115, -0.69777644,  0.33658162,  0.36329   ,
          -0.34397233, -0.134

In [20]:
params

{'Dense_0': {'kernel': Array([[-0.49519178, -0.2765769 ,  0.29829058,  0.1825179 ,  0.12744386,
           0.01904512,  0.5438577 , -0.5497272 ],
         [-0.2905944 , -0.2480456 , -0.20327911, -0.75905126,  0.14114662,
           0.4145826 , -0.19347848,  0.47221926],
         [ 0.668468  , -0.2039723 ,  0.38946563,  0.13857874, -0.45398667,
          -0.24390857, -0.32775238,  0.3971401 ],
         [-0.29118022,  0.05225108,  0.40288976, -0.13084824, -0.25388727,
          -0.36863217, -0.638132  , -0.15610406],
         [ 0.18570937,  0.26377958, -0.02895478,  0.05723229, -0.39722154,
          -0.2310003 ,  0.7511915 , -0.2714732 ],
         [-0.3982755 ,  0.36451694,  0.20695516, -0.01715776, -0.44657674,
           0.21720676,  0.5451954 ,  0.56244195],
         [-0.2991815 ,  0.4356464 , -0.6718789 ,  0.16581155,  0.05301353,
           0.32873207,  0.16510576, -0.34548888],
         [-0.57636905, -0.15388115, -0.69777644,  0.33658162,  0.36329   ,
          -0.34397233, -0.134

In [22]:
model.apply({"params": params}, jnp.ones((1, 8)))

Array([[-0.28349656]], dtype=float32)

## lora parms are in different namespace

In [19]:
import jax
import jax.numpy as jnp
from jax import lax
from flax import linen as nn
from flax.core import freeze, unfreeze
from flax.linen.initializers import zeros_init
from flax.linen.dtypes import promote_dtype
from typing import Any, Callable
import optax

default_kernel_init = jax.nn.initializers.lecun_normal()

import jax
import jax.numpy as jnp
from jax import lax
from flax.linen import Module, compact
from flax.linen.initializers import zeros_init
from flax.linen.dtypes import promote_dtype
from typing import Any, Callable

default_kernel_init = jax.nn.initializers.lecun_normal()

class LoRADense(Module):
    features: int
    use_bias: bool = True
    dtype: Any = None
    param_dtype: Any = jnp.float32
    precision: Any = None
    kernel_init: Callable = default_kernel_init
    bias_init: Callable = zeros_init()
    lora_rank: int = 8
    lora_alpha: float = 16

    @compact
    def __call__(self, inputs: Any) -> Any:
        kernel = self.variable(
            'params', 'kernel',
            self.kernel_init,
            jax.random.PRNGKey(0),  # You might want to pass a proper key
            (jnp.shape(inputs)[-1], self.features),
            self.param_dtype
        )
        
        if self.use_bias:
            bias = self.variable(
                'params', 'bias',
                self.bias_init,
                jax.random.PRNGKey(1),  # You might want to pass a proper key
                (self.features,),
                self.param_dtype
            )
        else:
            bias = None

        # LoRA weights (these remain as trainable parameters)

        # LoRA weights are moved to "lora_params" scope but are still trainable
        lora_a = self.variable(
                'lora_params', 'lora_a',
                self.bias_init,
                jax.random.PRNGKey(1),  # You might want to pass a proper key
                (jnp.shape(inputs)[-1], self.lora_rank),
                self.param_dtype
            )
        lora_b = self.variable(
                'lora_params', 'lora_b',
                self.bias_init,
                jax.random.PRNGKey(1),  # You might want to pass a proper key
                (self.lora_rank, self.features),
                self.param_dtype
            )

        inputs, kernel_value, lora_a, lora_b, bias_value = promote_dtype(
            inputs, kernel.value, lora_a.value, lora_b.value, 
            None if bias is None else bias.value, 
            dtype=self.dtype
        )

        y = lax.dot_general(
            inputs,
            jax.lax.stop_gradient(kernel_value),
            (((inputs.ndim - 1,), (0,)), ((), ())),
            precision=self.precision,
        )

        # LoRA computation
        lora_output = lax.dot_general(
            inputs,
            lora_a,
            (((inputs.ndim - 1,), (0,)), ((), ())),
            precision=self.precision,
        )
        lora_output = lax.dot_general(
            lora_output,
            lora_b,
            (((lora_output.ndim - 1,), (0,)), ((), ())),
            precision=self.precision,
        )
        y += (self.lora_alpha / self.lora_rank) * lora_output

        if bias_value is not None:
            y += jnp.reshape(bias_value, (1,) * (y.ndim - 1) + (-1,))
        return y
class Model(nn.Module):
    hidden_dim: int
    output_dim: int

    @nn.compact
    def __call__(self, x):
        x = LoRADense(features=self.hidden_dim,
                      kernel_init=jax.nn.initializers.xavier_normal(),
                      use_bias=False)(x)
        x = jax.nn.relu(x)
        x = LoRADense(features=self.output_dim,
                      use_bias=False)(x)
        return x

# Define dimensions
input_dim = 4
hidden_dim = 2
output_dim = 1

# Initialize model
model = Model(hidden_dim=hidden_dim, output_dim=output_dim)
key = jax.random.PRNGKey(0)
variables = model.init(key, jnp.zeros((1, input_dim)))


In [16]:
print(variables)

{'params': {'LoRADense_0': {'kernel': Array([[ 0.0506583 , -0.24143301],
       [-0.23484214,  0.977799  ],
       [-0.77396363,  1.156682  ],
       [-0.5310172 ,  0.69166714]], dtype=float32)}, 'LoRADense_1': {'kernel': Array([[-0.596048 ],
       [ 0.6490589]], dtype=float32)}}, 'lora_params': {'LoRADense_0': {'lora_a': Array([[0., 0., 0., 0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 0., 0., 0., 0.]], dtype=float32), 'lora_b': Array([[0., 0.],
       [0., 0.],
       [0., 0.],
       [0., 0.],
       [0., 0.],
       [0., 0.],
       [0., 0.],
       [0., 0.]], dtype=float32)}, 'LoRADense_1': {'lora_a': Array([[0., 0., 0., 0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 0., 0., 0., 0.]], dtype=float32), 'lora_b': Array([[0.],
       [0.],
       [0.],
       [0.],
       [0.],
       [0.],
       [0.],
       [0.]], dtype=float32)}}}


In [20]:

# Separate constants and trainable parameters
constants, params = variables.pop('params'), variables.pop('lora_params')
variables = freeze({'lora_params': params, 'params': constants})

# Define batch
batch = (jnp.ones(shape=(1, input_dim)), jnp.zeros(shape=(1, output_dim)))

# Define loss function (Mean Squared Error)
def loss_fn(params, constants, batch):
    inputs, targets = batch
    variables = {'lora_params': params, 'params': constants}
    outputs = model.apply(variables, inputs)
    return jnp.mean((outputs - targets) ** 2)

# Define forward pass and gradient function
forward_pass = lambda params, constants, batch: loss_fn(params, constants, batch)
grad_fn = jax.grad(forward_pass, argnums=0)  # Only compute gradients w.r.t. params

# Compute gradients
grads = grad_fn(variables['lora_params'], variables['params'], batch)

# Print gradients
print("Gradients:", grads)
# jax.tree_util.tree_map(lambda g: print(g.shape, jnp.mean(g)), grads)

# Create optimizer
optimizer = optax.adam(learning_rate=1e-3)
opt_state = optimizer.init(params)

# Print optimizer state
print("\nOptimizer State:", opt_state)
# jax.tree_util.tree_map(lambda x: print(type(x), x.shape if hasattr(x, 'shape') else x), opt_state)

Gradients: FrozenDict({
    LoRADense_0: {
        lora_a: Array([[0., 0., 0., 0., 0., 0., 0., 0.],
               [0., 0., 0., 0., 0., 0., 0., 0.],
               [0., 0., 0., 0., 0., 0., 0., 0.],
               [0., 0., 0., 0., 0., 0., 0., 0.]], dtype=float32),
        lora_b: Array([[0., 0.],
               [0., 0.],
               [0., 0.],
               [0., 0.],
               [0., 0.],
               [0., 0.],
               [0., 0.],
               [0., 0.]], dtype=float32),
    },
    LoRADense_1: {
        lora_a: Array([[0., 0., 0., 0., 0., 0., 0., 0.],
               [0., 0., 0., 0., 0., 0., 0., 0.]], dtype=float32),
        lora_b: Array([[0.],
               [0.],
               [0.],
               [0.],
               [0.],
               [0.],
               [0.],
               [0.]], dtype=float32),
    },
})

Optimizer State: (ScaleByAdamState(count=Array(0, dtype=int32), mu={'LoRADense_0': {'lora_a': Array([[0., 0., 0., 0., 0., 0., 0., 0.],
       [0., 0., 0., 0.,

## original

In [1]:
import jax
import jax.numpy as jnp
from jax import lax
from flax import linen as nn
from flax.core import freeze, unfreeze
from flax.linen.initializers import zeros_init
from flax.linen.dtypes import promote_dtype
from typing import Any, Callable
import optax

default_kernel_init = jax.nn.initializers.lecun_normal()

import jax
import jax.numpy as jnp
from jax import lax
from flax.linen import Module, compact
from flax.linen.initializers import zeros_init
from flax.linen.dtypes import promote_dtype
from typing import Any, Callable

default_kernel_init = jax.nn.initializers.lecun_normal()

class LoRADense(Module):
    features: int
    use_bias: bool = True
    dtype: Any = None
    param_dtype: Any = jnp.float32
    precision: Any = None
    kernel_init: Callable = default_kernel_init
    bias_init: Callable = zeros_init()
    lora_rank: int = 8
    lora_alpha: float = 16

    @compact
    def __call__(self, inputs: Any) -> Any:
        kernel = self.variable(
            'constants', 'kernel',
            self.kernel_init,
            jax.random.PRNGKey(0),  # You might want to pass a proper key
            (jnp.shape(inputs)[-1], self.features),
            self.param_dtype
        )
        
        if self.use_bias:
            bias = self.variable(
                'constants', 'bias',
                self.bias_init,
                jax.random.PRNGKey(1),  # You might want to pass a proper key
                (self.features,),
                self.param_dtype
            )
        else:
            bias = None

        # LoRA weights (these remain as trainable parameters)
        lora_a = self.param(
            'lora_a',
            default_kernel_init,
            (jnp.shape(inputs)[-1], self.lora_rank),
            self.param_dtype
        )
        lora_b = self.param(
            'lora_b',
            zeros_init(),
            (self.lora_rank, self.features),
            self.param_dtype
        )

        inputs, kernel_value, lora_a, lora_b, bias_value = promote_dtype(
            inputs, kernel.value, lora_a, lora_b, 
            None if bias is None else bias.value, 
            dtype=self.dtype
        )

        y = lax.dot_general(
            inputs,
            jax.lax.stop_gradient(kernel_value),
            (((inputs.ndim - 1,), (0,)), ((), ())),
            precision=self.precision,
        )

        # LoRA computation
        lora_output = lax.dot_general(
            inputs,
            lora_a,
            (((inputs.ndim - 1,), (0,)), ((), ())),
            precision=self.precision,
        )
        lora_output = lax.dot_general(
            lora_output,
            lora_b,
            (((lora_output.ndim - 1,), (0,)), ((), ())),
            precision=self.precision,
        )
        y += (self.lora_alpha / self.lora_rank) * lora_output

        if bias_value is not None:
            y += jnp.reshape(bias_value, (1,) * (y.ndim - 1) + (-1,))
        return y
class Model(nn.Module):
    hidden_dim: int
    output_dim: int

    @nn.compact
    def __call__(self, x):
        x = LoRADense(features=self.hidden_dim,
                      kernel_init=jax.nn.initializers.xavier_normal(),
                      use_bias=False)(x)
        x = jax.nn.relu(x)
        x = LoRADense(features=self.output_dim,
                      use_bias=False)(x)
        return x

# Define dimensions
input_dim = 4
hidden_dim = 2
output_dim = 1

# Initialize model
model = Model(hidden_dim=hidden_dim, output_dim=output_dim)
key = jax.random.PRNGKey(0)
variables = model.init(key, jnp.zeros((1, input_dim)))

# Separate constants and trainable parameters
constants, params = variables.pop('constants'), variables.pop('params')
variables = freeze({'params': params, 'constants': constants})

# Define batch
batch = (jnp.ones(shape=(1, input_dim)), jnp.zeros(shape=(1, output_dim)))

# Define loss function (Mean Squared Error)
def loss_fn(params, constants, batch):
    inputs, targets = batch
    variables = {'params': params, 'constants': constants}
    outputs = model.apply(variables, inputs)
    return jnp.mean((outputs - targets) ** 2)

# Define forward pass and gradient function
forward_pass = lambda params, constants, batch: loss_fn(params, constants, batch)
grad_fn = jax.grad(forward_pass, argnums=0)  # Only compute gradients w.r.t. params

# Compute gradients
grads = grad_fn(variables['params'], variables['constants'], batch)

# Print gradients
print("Gradients:", grads)
# jax.tree_util.tree_map(lambda g: print(g.shape, jnp.mean(g)), grads)

# Create optimizer
optimizer = optax.adam(learning_rate=1e-3)
opt_state = optimizer.init(params)

# Print optimizer state
print("\nOptimizer State:", opt_state)
# jax.tree_util.tree_map(lambda x: print(type(x), x.shape if hasattr(x, 'shape') else x), opt_state)

Gradients: FrozenDict({
    LoRADense_0: {
        lora_a: Array([[0., 0., 0., 0., 0., 0., 0., 0.],
               [0., 0., 0., 0., 0., 0., 0., 0.],
               [0., 0., 0., 0., 0., 0., 0., 0.],
               [0., 0., 0., 0., 0., 0., 0., 0.]], dtype=float32),
        lora_b: Array([[ 0.       ,  4.9099245],
               [-0.       , -5.745265 ],
               [ 0.       ,  1.7668908],
               [-0.       , -4.0577426],
               [-0.       , -1.7626575],
               [-0.       , -7.9242153],
               [ 0.       ,  4.249761 ],
               [-0.       , -7.2896028]], dtype=float32),
    },
    LoRADense_1: {
        lora_a: Array([[0., 0., 0., 0., 0., 0., 0., 0.],
               [0., 0., 0., 0., 0., 0., 0., 0.]], dtype=float32),
        lora_b: Array([[  8.193145  ],
               [ 16.425367  ],
               [ -3.1043384 ],
               [-18.169392  ],
               [ 15.235281  ],
               [  9.35627   ],
               [ -8.725613  ],
         

In [136]:
import jax
import jax.numpy as jnp
from flax import linen as nn
from flax.core import freeze, unfreeze
import optax

class Model(nn.Module):
    hidden_dim: int
    output_dim: int

    @nn.compact
    def __call__(self, x):
        x = nn.Dense(features=self.hidden_dim,
                     kernel_init=jax.nn.initializers.xavier_normal(),
                     use_bias=False)(x)
        x = jax.nn.relu(x)
        x = nn.Dense(features=self.output_dim,
                     use_bias=False)(x)
        return x

# Define dimensions
input_dim = 4
hidden_dim = 2
output_dim = 1

# Initialize model
model = Model(hidden_dim=hidden_dim, output_dim=output_dim)
key = jax.random.PRNGKey(0)
params = model.init(key, jnp.zeros((1, input_dim)))

# Define batch
batch = (jnp.ones(shape=(1, input_dim)), jnp.zeros(shape=(1, output_dim)))

# Define loss function (Mean Squared Error)
def loss_fn(params, batch):
    inputs, targets = batch
    outputs = model.apply(params, inputs)
    return jnp.mean((outputs - targets) ** 2)

# Define forward pass and gradient function
forward_pass = lambda params, batch: loss_fn(params, batch)
grad_fn = jax.grad(forward_pass)

# Compute gradients
grads = grad_fn(params, batch)

# Print gradients
print("Gradients:", grads)
# jax.tree_util.tree_map(lambda g: print(g.shape, jnp.mean(g)), grads)

# Create optimizer
optimizer = optax.adam(learning_rate=1e-3)
opt_state = optimizer.init(params)

# Print optimizer state
print("\nOptimizer State:", opt_state)
# jax.tree_util.tree_map(lambda x: print(type(x), x.shape if hasattr(x, 'shape') else x), opt_state)

Gradients: {'params': {'Dense_0': {'kernel': Array([[0.       , 0.0030845],
       [0.       , 0.0030845],
       [0.       , 0.0030845],
       [0.       , 0.0030845]], dtype=float32)}, 'Dense_1': {'kernel': Array([[ 0.        ],
       [-0.02614744]], dtype=float32)}}}

Optimizer State: (ScaleByAdamState(count=Array(0, dtype=int32), mu={'params': {'Dense_0': {'kernel': Array([[0., 0.],
       [0., 0.],
       [0., 0.],
       [0., 0.]], dtype=float32)}, 'Dense_1': {'kernel': Array([[0.],
       [0.]], dtype=float32)}}}, nu={'params': {'Dense_0': {'kernel': Array([[0., 0.],
       [0., 0.],
       [0., 0.],
       [0., 0.]], dtype=float32)}, 'Dense_1': {'kernel': Array([[0.],
       [0.]], dtype=float32)}}}), EmptyState())


In [12]:
def train_step(state, inputs, targets):
  return backward_pass(state, (inputs, targets))

## Try LoRA with simpleNN

In [5]:
input_dim = 32 
hidden_dim = 8
output_dim = 1

In [6]:

# Helper function for creating NamedSharding
def create_sharding(pspec):
    return NamedSharding(mesh, pspec)

In [None]:
def mesh_sharding(pspec: PartitionSpec) -> NamedSharding:
  return NamedSharding(mesh, pspec)

In [None]:
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, PartitionSpec, NamedSharding
from flax import linen as nn
from flax.training import train_state
import optax
from jax.sharding import Mesh, PartitionSpec, NamedSharding
from jax.lax import with_sharding_constraint
from jax.experimental import mesh_utils
import functools
from functools import partial
import jax
import jax.numpy as jnp

class SimpleNN(nn.Module):
    hidden_dim: int
    output_dim: int
    
    @nn.compact
    def __call__(self, x):
        x = with_sharding_constraint(x, mesh_sharding(PartitionSpec('data', 'model')))

        x = nn.Dense(features=self.hidden_dim, 
                     kernel_init=nn.with_partitioning(nn.initializers.xavier_normal(), ('data', 'model')),
                     use_bias=False)(x)
        x = nn.relu(x)
        x = nn.Dense(features=self.output_dim, 
                     use_bias=False)(x)
        return x

# Set up the device mesh
devices = jax.devices()
device_mesh = mesh_utils.create_device_mesh((1, 4))
mesh = Mesh(devices=device_mesh, axis_names=('data', 'model'))

print(mesh)

In [None]:
sample_batch = jnp.ones(shape=(1, input_dim))
x_sharding = mesh_sharding(PartitionSpec('data', 'model')) # dimensions: (batch, length)
sample_batch = jax.device_put(sample_batch, x_sharding)
jax.debug.visualize_array_sharding(x)

In [None]:
# Initialize function
def init_fn(key, x, model, optimizer):
    params = model.init(key, x)  # Initialize the model
    state = train_state.TrainState.create(  # Create a `TrainState`
        apply_fn=model.apply,
        params=params['params'],
        tx=optimizer)
    return state

In [None]:
model = SimpleNN(hidden_dim, output_dim)
optimizer = optax.adam(learning_rate=0.001)

In [None]:
abstract_variables = jax.eval_shape(
    functools.partial(
        init_fn, 
        model=model, 
        optimizer=optimizer
    ),
    jax.random.PRNGKey(99),
    sample_batch
)


In [None]:
abstract_variables

In [None]:
state_sharding = nn.get_sharding(abstract_variables, mesh)

In [None]:
state_sharding

In [None]:
jit_init_fn = jax.jit(init_fn, 
                          static_argnums=(2, 3),
                          in_shardings=(mesh_sharding(pspec=()), x_sharding),  # PRNG key and x
                          out_shardings=state_sharding)

In [None]:
initialized_state = jit_init_fn(jax.random.PRNGKey(99), sample_batch, model, optimizer)

In [None]:
initialized_state.params['Dense_1']['kernel']

In [None]:
print_params(initialized_state.params)

In [None]:
def forward_pass(params, state, batch):
  input_images, labels = batch
    
  # call forward pass function.
  logits = state.apply_fn({"params": params}, input_images)

  # compute loss
  loss = optax.squared_error(logits, labels)
  loss = loss.mean()
  return loss, logits

In [None]:
def backward_pass(state, batch):
  # create a function to compute gradients wrt to loss
  # returned by our `forward_pass` function.
  grad_fn = jax.value_and_grad(forward_pass, argnums=(0), has_aux=True)

  # compute gradients.
  (loss, _), grads = grad_fn(state.params, state, batch)

  # apply gradients.
  state = state.apply_gradients(grads=grads)

  return state

In [None]:
@functools.partial(jax.jit, in_shardings=(state_sharding, x_sharding, None),
                   out_shardings=state_sharding)
def train_step(state, inputs, targets):
  return backward_pass(state, (inputs, targets))

In [None]:
batch = (jnp.ones(shape=(1, input_dim)), jnp.zeros(shape=(1, output_dim)))

In [None]:
batch0 = jax.device_put(batch[0], x_sharding)

In [None]:
batch1 = jax.device_put(batch[1], mesh_sharding(PartitionSpec('data', None)))

In [None]:
with mesh:
    new_state = train_step(initialized_state, batch0, batch1)

In [None]:
new_state

In [None]:
def print_params(params):
    flat_params = flatten_dict(params)    
    for path, param in flat_params.items():
        # Join the path components to create a string name
        name = "/".join(str(x) for x in path)
        print(f"Name: {name}")
        print(f"Shape: {param.shape}")
        print(f"dtype: {param.dtype}")
        print(f"Value: {param}")
        # if isinstance(param, flax.core.meta.Partitioned):
        #     array = unbox(param)
        # else:
        #     array = param
        # print(jax.debug.visualize_array_sharding(array))
        print("-" * 40)

In [9]:
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, PartitionSpec, NamedSharding
from flax import linen as nn
from flax.training import train_state
import optax
from jax.sharding import Mesh, PartitionSpec, NamedSharding
from jax.lax import with_sharding_constraint
from jax.experimental import mesh_utils
import functools
from functools import partial
import jax
import jax.numpy as jnp

In [10]:
import jax
import jax.numpy as jnp
from jax.experimental.pjit import pjit
from jax.sharding import Mesh, PartitionSpec as PS
from flax import linen as nn
from flax.training.train_state import TrainState
import optax
from functools import partial
from jax.experimental import mesh_utils
import re


In [11]:
class SimpleNN(nn.Module):
    hidden_dim: int
    output_dim: int
    
    @nn.compact
    def __call__(self, x):
        x = with_sharding_constraint(x, mesh_sharding(PartitionSpec('data', 'model')))

        x = nn.Dense(features=self.hidden_dim, 
                     kernel_init=nn.with_partitioning(nn.initializers.xavier_normal(), ('data', 'model')),
                     use_bias=False)(x)
        x = nn.relu(x)
        x = nn.Dense(features=self.output_dim, 
                     use_bias=False)(x)
        return x

In [12]:
# Set up the device mesh
devices = jax.devices()
device_mesh = mesh_utils.create_device_mesh((1, 4))
mesh = Mesh(devices=device_mesh, axis_names=('data', 'model'))

In [13]:
model = SimpleNN(hidden_dim, output_dim)
optimizer = optax.adam(learning_rate=0.001)

NameError: name 'hidden_dim' is not defined

In [27]:
# Define the model
class SimpleNN(nn.Module):
    hidden_dim: int
    output_dim: int
    
    @nn.compact
    def __call__(self, x):
        x = nn.Dense(features=self.hidden_dim, use_bias=False)(x)
        x = nn.relu(x)
        x = nn.Dense(features=self.output_dim, use_bias=False)(x)
        return x

In [28]:
class JaxRNG(object):
    """ A convenient stateful Jax RNG wrapper. Can be used to wrap RNG inside
        pure function.
    """

    @classmethod
    def from_seed(cls, seed):
        return cls(jax.random.PRNGKey(seed))

    def __init__(self, rng):
        self.rng = rng

    def __call__(self, keys=None):
        if keys is None:
            self.rng, split_rng = jax.random.split(self.rng)
            return split_rng
        elif isinstance(keys, int):
            split_rngs = jax.random.split(self.rng, num=keys + 1)
            self.rng = split_rngs[0]
            return tuple(split_rngs[1:])
        else:
            split_rngs = jax.random.split(self.rng, num=len(keys) + 1)
            self.rng = split_rngs[0]
            return {key: val for key, val in zip(keys, split_rngs[1:])}


In [29]:
def tree_path_to_string(path, sep=None):
    keys = []
    for key in path:
        if isinstance(key, jax.tree_util.SequenceKey):
            keys.append(str(key.idx))
        elif isinstance(key, jax.tree_util.DictKey):
            keys.append(str(key.key))
        elif isinstance(key, jax.tree_util.GetAttrKey):
            keys.append(str(key.name))
        elif isinstance(key, jax.tree_util.FlattenedIndexKey):
            keys.append(str(key.key))
        else:
            keys.append(str(key))
    if sep is None:
        return tuple(keys)
    return sep.join(keys)


In [30]:
def named_tree_map(f, tree, *rest, is_leaf=None, sep=None):
    """ An extended version of jax.tree_util.tree_map, where the mapped function
        f takes both the name (path) and the tree leaf as input.
    """
    return jax.tree_util.tree_map_with_path(
        lambda path, x, *r: f(tree_path_to_string(path, sep=sep), x, *r),
        tree, *rest,
        is_leaf=is_leaf
    )

In [31]:
import numpy as np
def match_partition_rules(rules, params):
    """ Returns a pytree of PartitionSpec according to rules. Supports handling
        Flax TrainState and Optax optimizer state.
    """
    def get_partition_spec(name, leaf):
        if len(leaf.shape) == 0 or np.prod(leaf.shape) == 1:
            """ Don't partition scalar values. """
            return PS()
        for rule, ps in rules:
            if re.search(rule, name) is not None:
                return ps
        raise ValueError(f'Partition rule not found for param: {name}')
    return named_tree_map(get_partition_spec, params, sep='/')

In [32]:
def get_jax_mesh(axis_dims, names):
    dims = [int(x) for x in axis_dims.split(',')]
    mesh_shape = jax.numpy.arange(jax.device_count()).reshape(dims).shape
    return 

In [33]:
def with_sharding_constraint(x, partition_specs):
    return jax.lax.with_sharding_constraint(x, partition_specs)

In [34]:
# Configuration
input_dim = 32 
hidden_dim = 8
output_dim = 1

In [47]:
class SimpleNNConfigurator:
    @staticmethod
    def get_partition_rules():
        return (
            # ("params/params/Dense_0/kernel", PS("data", "model")),
            ('.*', PS(None)),
        )

    @staticmethod
    def rng_keys():
        return ('params', 'dropout', 'other')

In [36]:
def create_trainstate_from_params(params):
    return TrainState.create(params=params, tx=optimizer, apply_fn=None)

In [37]:
def init_fn(rng, input_dim=32, hidden_dim=8, output_dim=1):
    rng_generator = JaxRNG(rng)
    model = SimpleNN(hidden_dim=hidden_dim, output_dim=output_dim)
    params = model.init(
        rng_generator(SimpleNNConfigurator.rng_keys()),
        jnp.zeros((4, input_dim)),
    )
    return TrainState.create(params=params, tx=optimizer, apply_fn=None)

In [38]:
def train_step(train_state, rng, batch):
    rng_generator = JaxRNG(rng)
    batch = jax.lax.with_sharding_constraint(batch, PS(('data', 'model')))
    
    def loss_and_accuracy(params):
        pred = model.apply(params, batch['input'],
            rngs=rng_generator(SimpleNNConfigurator.rng_keys()),
        )
        loss = optax.squared_error(pred, batch['target'])
        loss = loss.mean()
        return loss
        
    grad_fn = jax.grad(loss_and_accuracy)
    grads = grad_fn(train_state.params)
    train_state = train_state.apply_gradients(grads=grads)
    return train_state


In [39]:
rng = JaxRNG(0)
model = SimpleNN(hidden_dim=hidden_dim, output_dim=output_dim)

In [40]:
optimizer = optax.adam(1e-3)

In [51]:
# eval_shape(fun: Callable, *args, **kwargs):
# evaluating ``fun(*args, **kwargs)``.

train_state_shapes = jax.eval_shape(init_fn, jax.random.PRNGKey(99))

In [49]:
train_state_partition = match_partition_rules(
    SimpleNNConfigurator.get_partition_rules(), train_state_shapes
)

In [52]:
def create_named_sharding(partition_spec):
    return NamedSharding(mesh, partition_spec)

In [53]:
train_state_named_sharding = jax.tree.map(create_named_sharding, train_state_partition)

In [54]:
train_state_named_sharding.params

{'params': {'Dense_0': {'kernel': NamedSharding(mesh=Mesh('data': 1, 'model': 4), spec=PartitionSpec(None,))},
  'Dense_1': {'kernel': NamedSharding(mesh=Mesh('data': 1, 'model': 4), spec=PartitionSpec(None,))}}}

In [55]:
sharded_init_fn = jax.jit(
    init_fn,
    in_shardings=NamedSharding(mesh, PS()),
    # out_shardings=train_state_partition, out_shardings is optional in jax.jit, GSPMD will figure it out.
    static_argnums=(1, 2, 3)
    
)

In [56]:
sharded_create_trainstate_from_params = jax.jit(
    create_trainstate_from_params,
    in_shardings=(train_state_partition.params, ),
    # out_shardings=train_state_partition,
    donate_argnums=(0, ),
)

In [57]:
sharded_train_step = jax.jit(
    train_step,
    in_shardings=(train_state_named_sharding, NamedSharding(mesh, PS()), NamedSharding(mesh, PS())),
    # out_shardings=(train_state_partition),
    donate_argnums=(0, 1),
)

In [58]:
device_mesh = mesh_utils.create_device_mesh((1, 4))
mesh = Mesh(devices=device_mesh, axis_names=('data', 'model'))

In [59]:
with mesh:
    train_state = sharded_init_fn(jax.random.PRNGKey(99))
    batch = {
        'input': jnp.ones((32, input_dim)),
        'target': jnp.zeros((32, 1), dtype=jnp.int32),
    }
    sharded_rng = jax.random.PRNGKey(99)
    train_state = sharded_train_step(
        train_state, sharded_rng, batch
    )
    print(train_state)

TrainState(step=Array(1, dtype=int32, weak_type=True), apply_fn=None, params={'params': {'Dense_0': {'kernel': Array([[ 9.58651379e-02, -5.77238090e-02,  8.71751755e-02,
         2.06664905e-01, -1.54080182e-01,  3.74663025e-01,
         1.21506892e-01, -1.27760872e-01],
       [-1.65383801e-01,  9.92371961e-02, -6.61521703e-02,
         2.75267810e-01,  3.34123582e-01,  1.94328249e-01,
        -2.11523082e-02,  7.46464133e-02],
       [ 6.08712770e-02, -1.41109407e-01, -8.76355544e-02,
         2.75042266e-01, -2.81154457e-02, -1.16207547e-01,
        -2.79863566e-01,  2.71764964e-01],
       [ 1.61056504e-01, -9.62803438e-02,  3.30342233e-01,
        -3.67344916e-01,  4.07360569e-02, -1.63870811e-01,
        -1.81345776e-01,  4.84794006e-03],
       [-9.79540274e-02,  1.76957369e-01,  3.66404593e-01,
        -1.52017111e-02,  5.93522601e-02,  9.57616698e-03,
         9.57029611e-02, -1.66373268e-01],
       [ 7.82506913e-02,  2.19656542e-01,  4.44784090e-02,
         3.93143147e-02, 

In [None]:
train_state