Copyright 2024 DeepMind Technologies Limited.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

   http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or  implied.
See the License for the specific language governing permissions and
limitations under the License.

---

# Fine-tuning the 2B Gemma model with flax

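In this tutorial you will learn how to fine-tune the 2B Gemma model on a simple question-answering task: given a person's name, answer with the city where that person lives. The colab targets a TPU v4 runtime, although the cell outputs shown below come from a run on a CUDA GPU.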

## Setup

In [1]:
# @title Installation
! pip install git+https://github.com/google-deepmind/gemma.git
! pip install --user kaggle
! pip install --user kagglehub

Collecting git+https://github.com/google-deepmind/gemma.git
  Cloning https://github.com/google-deepmind/gemma.git to /tmp/pip-req-build-hxirukrg
  Running command git clone --filter=blob:none --quiet https://github.com/google-deepmind/gemma.git /tmp/pip-req-build-hxirukrg
  Resolved https://github.com/google-deepmind/gemma.git to commit 2ea41628173cd88de9ab6963e628889faec86ff5
  Installing build dependencies ... done
  Getting requirements to build wheel ... done
  Preparing metadata (pyproject.toml) ... done


## Downloading the checkpoint

To use Gemma's checkpoints, you'll need a Kaggle account and API key. Here's how to get them:

1. Visit https://www.kaggle.com/ and create an account.
2. Go to your account settings, then the 'API' section.
3. Click 'Create new token' to download your key.

Then run the cell below.

In [2]:
import kagglehub
kagglehub.login()

If everything went well, you should see:
```
Kaggle credentials set.
Kaggle credentials successfully validated.
```

Now select and download the checkpoint you want to try. On a single host, only the 2b model can fit in memory for fine-tuning.

In [3]:
import os

VARIANT = '2b' # @param ['2b', '2b-it'] {type:"string"}
weights_dir = kagglehub.model_download(f'google/gemma/Flax/{VARIANT}')
ckpt_path = os.path.join(weights_dir, VARIANT)
vocab_path = os.path.join(weights_dir, 'tokenizer.model')

In [4]:
# @title Python imports

import os
os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = "1.0"
os.environ["CUDA_VISIBLE_DEVICES"] = "4"
# os.environ["XLA_FLAGS"] = "--xla_gpu_enable_command_buffer="

import enum
import re
import string

# We import JAX and some related packages.
import chex
import jax
import jax.numpy as jnp
import optax

import numpy as np

# We will use tensorflow to handle the dataset
import tensorflow as tf
import tensorflow_datasets as tfds

# Finally, we import Gemma.
from gemma import params as params_lib
from gemma import sampler as sampler_lib
from gemma import transformer as transformer_lib
import sentencepiece as spm

2024-09-30 16:23:51.119108: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:485] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
2024-09-30 16:23:51.137025: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:8454] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
2024-09-30 16:23:51.142515: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1452] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered


## Step 1: prepare the dataset
Each sample in the dataset contains two entries:
- 'src': An index into `src_map`, the list of names used to build the question.
- 'dst': The same index into `dst_map`, the list of cities used as the answer.

### Tokenizer

Let's start by loading our vocabulary-based tokenizer, which we'll construct using the [SentencePiece](https://github.com/google/sentencepiece) library.

In [5]:
vocab = spm.SentencePieceProcessor()
vocab.Load(vocab_path)

True

Let's customize `SentencePieceProcessor` for our task. Since we're fine-tuning the English-only Gemma 2B model, we need a few adjustments:

- **Input Prefix**: Adding a common prefix to each input signals the task. In this notebook the prefix is `Where does `, placed before the person's name.

- **Answer Start Suffix**: A suffix added at the end of each prompt tells the model exactly when to begin its answer. Here the suffix is ` live? Only give the name of the city.`, which completes the question.

- **LM Tokens**: Gemma models expect a *beginning of sequence* token at the beginning of each sequence. Similarly, we need to add an *end of sequence* token at the end of each training example.
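
The token layout described above can be sketched with a toy stand-in for the tokenizer. The whitespace tokenizer, the `BOS_ID` and `EOS_ID` values, and the helper names are assumptions made for illustration only; the real ids come from the SentencePiece model.

```python
# Hypothetical special-token ids, chosen only for this sketch.
BOS_ID = 2
EOS_ID = 1


def toy_encode(text: str, toy_vocab: dict[str, int]) -> list[int]:
    """Stand-in for SentencePiece: one id per whitespace-separated word."""
    return [toy_vocab.setdefault(word, len(toy_vocab) + 10)
            for word in text.split()]


def build_sequence(prompt: str, answer: str,
                   toy_vocab: dict[str, int]) -> list[int]:
    """Returns <BOS> + prompt tokens + answer tokens + <EOS>."""
    prompt_ids = [BOS_ID] + toy_encode(prompt, toy_vocab)  # BOS only at the start
    answer_ids = toy_encode(answer, toy_vocab) + [EOS_ID]  # EOS only at the end
    return prompt_ids + answer_ids


toy_vocab: dict[str, int] = {}
sequence = build_sequence("Where does Liz live?", "London", toy_vocab)
print(sequence)  # → [2, 10, 11, 12, 13, 14, 1]
```

The prompt carries the only `<BOS>` and the answer carries the only `<EOS>`, so the concatenation forms one well-delimited training sequence.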

In [6]:
class GemmaTokenizer:
  """Custom wrapper around a SentencePieceProcessor for tensorflow."""

  def __init__(self,
               spm_processor: spm.SentencePieceProcessor):
    self._spm_processor = spm_processor

  @property
  def pad_id(self) -> int:
    """Fast access to the pad id."""
    return self._spm_processor.pad_id()

  def tokenize(self,
               example: str | bytes,
               prefix: str = '',
               suffix: str = '',
               add_bos: bool = True,
               add_eos: bool = True) -> jax.Array:
    """
    Tokenization function.

    Args:
      example: input string to tokenize.
      prefix:  prefix to add to the input string.
      suffix:  suffix to add to the input string.
      add_bos: if True, add a beginning of sequence token at the start of the
               tokenized sequence.
      add_eos: if True, add an end of sequence token at the end of the tokenized
               sequence.
    Returns:
      Tokens corresponding to the input string.
    """
    int_list = []
    if add_bos:
      int_list.append(self._spm_processor.bos_id())
    int_list.extend(self._spm_processor.EncodeAsIds(prefix + example + suffix))
    if add_eos:
      int_list.append(self._spm_processor.eos_id())

    return jnp.array(int_list, dtype=jnp.int32)

  def tokenize_tf_op(self,
                     str_tensor: tf.Tensor,
                     prefix: str = '',
                     suffix: str = '',
                     add_bos: bool = True,
                     add_eos: bool = True) -> tf.Tensor:
    """TensorFlow operator for the tokenize function."""
    encoded = tf.numpy_function(
        self.tokenize,
        [str_tensor, prefix, suffix, add_bos, add_eos],
        tf.int32)
    encoded.set_shape([None])
    return encoded

  def to_string(self, tokens: jax.Array) -> str:
    """Convert an array of tokens to a string."""
    return self._spm_processor.DecodeIds(tokens.tolist())

We will use this custom tokenizer on a small synthetic dataset of names and cities.

### Data loader

We can now wrap everything and build our data loader.

In [7]:
@chex.dataclass(frozen=True)
class TrainingInput:
  # Input tokens given to the model
  input_tokens: jax.Array

  # A mask that determines which tokens contribute to the target loss
  # calculation.
  target_mask: jax.Array

class DatasetSplit(enum.Enum):
  TRAIN = 'train'
  VALIDATION = 'valid'

src_map = [
    "Cxx1h",
    "Vzbiik",
    "Mmojkr",
    "Trrrrqe",
    "Benjamin",
    "Liz",
    "Kaitlyn",
    "Wiesa",
]

dst_map = [
    "Lkkl",
    "Plooqujhd",
    "Nwops",
    "Qtbnaaa",
    "Buenos",
    "London",
    "Kingston",
    "Warsaw",
]

def seq_generator():
    rng = np.random.RandomState(42)
    while True:
        idx = rng.randint(len(src_map))
        yield {
            "src": idx,
            "dst": idx,
        }


class FTDatasetBuilder:
  """Data loader for the FT dataset."""

  N_ITEMS = {DatasetSplit.TRAIN: 10,
             DatasetSplit.VALIDATION: 10}

  BUFFER_SIZE_SHUFFLE = 10
  Q_PREFIX = 'Where does '
  Q_SUFFIX = ' live? Only give the name of the city.'

  def __init__(self,
               tokenizer : GemmaTokenizer,
               max_seq_len: int):
    """Constructor.

    Args:
      tokenizer: Gemma tokenizer to use.
      max_seq_len: size of each sequence in a given batch.
    """
    self._tokenizer = tokenizer
    self._base_data = {
        DatasetSplit.TRAIN: tf.data.Dataset.from_generator(
            seq_generator,
            output_signature={
                "src": tf.TensorSpec(
                    shape=(),
                    dtype=tf.dtypes.int32,
                ),
                "dst": tf.TensorSpec(
                    shape=(),
                    dtype=tf.dtypes.int32,
                ),
            },
        ),
        DatasetSplit.VALIDATION: tf.data.Dataset.from_generator(
            seq_generator,
            output_signature={
                "src": tf.TensorSpec(
                    shape=(),
                    dtype=tf.dtypes.int32,
                ),
                "dst": tf.TensorSpec(
                    shape=(),
                    dtype=tf.dtypes.int32,
                ),
            },
        ),
    }
    self._max_seq_len = max_seq_len

  def _tokenize_source(self, example: tf.Tensor):
    """Tokenization function for the source."""
    # We add <BOS> as these tokens are the start of our sequence.
    return self._tokenizer.tokenize_tf_op(tf.gather(src_map, example),
                                          prefix=self.Q_PREFIX,
                                          suffix=self.Q_SUFFIX,
                                          add_bos=True,
                                          add_eos=False)

  def _tokenize_destination(self, example: tf.Tensor):
    """Tokenization function for the destination (the expected answer)."""
    # We do not add <BOS> as these tokens get appended to the source tokens.
    return self._tokenizer.tokenize_tf_op(tf.gather(dst_map, example),
                                          add_bos=False,
                                          add_eos=True)

  def _pad_up_to_max_len(self,
                         input_tensor: tf.Tensor,
                         pad_value: int | bool,
                         ) -> tf.Tensor:
    """Pad the given tensor up to sequence length of a batch."""
    seq_len = tf.shape(input_tensor)[0]
    to_pad = tf.maximum(self._max_seq_len - seq_len, 0)
    return tf.pad(input_tensor,
                  [[0, to_pad]],
                  mode='CONSTANT',
                  constant_values=pad_value,
                  )

  def _to_training_input(self,
                         src_tokens: jax.Array,
                         dst_tokens: jax.Array,
                         ) -> TrainingInput:
    """Build a training input from a tuple of source and destination tokens."""

    # The input sequence fed to the model is simply the concatenation of the
    # source and the destination.
    tokens = tf.concat([src_tokens, dst_tokens], axis=0)

    # We want to prevent the model from updating based on the source (input)
    # tokens. To achieve this, we add a target mask to each input.
    q_mask = tf.zeros_like(src_tokens, dtype=tf.bool)
    a_mask = tf.ones_like(dst_tokens, dtype=tf.bool)
    mask = tf.concat([q_mask, a_mask], axis=0)

    # If the output tokens sequence is smaller than the target sequence size,
    # then we pad it with pad tokens.
    tokens = self._pad_up_to_max_len(tokens, self._tokenizer.pad_id)

    # We don't want to perform the backward on the pad tokens.
    mask = self._pad_up_to_max_len(mask, False)

    return TrainingInput(input_tokens=tokens, target_mask=mask)


  def get_train_dataset(self, batch_size: int, num_epochs: int):
    """Build the training dataset."""

    # Tokenize each sample
    ds = self._base_data[DatasetSplit.TRAIN].map(lambda x : (self._tokenize_source(x['src']),
                                                             self._tokenize_destination(x['dst'])))

    # Convert them to training inputs
    ds = ds.map(lambda x, y: self._to_training_input(x, y))

    # Remove the samples which are too long
    ds = ds.filter(lambda x: tf.shape(x.input_tokens)[0] <= self._max_seq_len)

    # Shuffle the dataset
    ds = ds.shuffle(buffer_size=self.BUFFER_SIZE_SHUFFLE)

    # Repeat if necessary
    ds = ds.repeat(num_epochs)

    # Build batches
    ds = ds.batch(batch_size, drop_remainder=True)
    return ds

  def get_validation_dataset(self, batch_size: int):
    """Build the validation dataset."""

    # Same as the training dataset, but no shuffling and no repetition
    ds = self._base_data[DatasetSplit.VALIDATION].map(lambda x : (self._tokenize_source(x['src']),
                                                                  self._tokenize_destination(x['dst'])))
    ds = ds.map(lambda x, y: self._to_training_input(x, y))
    ds = ds.filter(lambda x: tf.shape(x.input_tokens)[0] <= self._max_seq_len)
    ds = ds.batch(batch_size, drop_remainder=True)
    return ds
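
What `_to_training_input` produces can be checked on a tiny example. The following is a NumPy sketch rather than the TensorFlow code used above; the token ids and the `pad_id=0` default are assumptions for illustration.

```python
import numpy as np


def to_training_input(src_tokens, dst_tokens, max_seq_len, pad_id=0):
    """Concatenates prompt and answer tokens, then pads tokens and mask."""
    tokens = np.concatenate([src_tokens, dst_tokens])

    # Only the answer tokens contribute to the loss.
    mask = np.concatenate([np.zeros(len(src_tokens), dtype=bool),
                           np.ones(len(dst_tokens), dtype=bool)])

    # Pad both arrays on the right; padded positions are masked out.
    to_pad = max(max_seq_len - len(tokens), 0)
    tokens = np.pad(tokens, (0, to_pad), constant_values=pad_id)
    mask = np.pad(mask, (0, to_pad), constant_values=False)
    return tokens, mask


toy_tokens, toy_mask = to_training_input(
    np.array([2, 10, 11]), np.array([12, 1]), max_seq_len=8)
print(toy_tokens)             # → [ 2 10 11 12  1  0  0  0]
print(toy_mask.astype(int))   # → [0 0 0 1 1 0 0 0]
```

The mask is zero on both the prompt and the padding, so only the two answer positions drive the gradient.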

## Fine tuning the Gemma model

### Getting started

First let's load the model

In [8]:
ckpt_path

'/home/bryanpu1/.cache/kagglehub/models/google/gemma/Flax/2b/2/2b'

In [9]:
# Load parameters

# TODO: change once the downloading url is known
params = params_lib.load_and_format_params(ckpt_path)

# We use the `transformer_lib.TransformerConfig.from_params` function to
# automatically load the correct configuration from a checkpoint. Note that the
# vocabulary size is smaller than the number of input embeddings due to unused
# tokens in this release.
config_2b = transformer_lib.TransformerConfig.from_params(
    params,
    cache_size=30  # Number of time steps in the transformer's cache
)
model_2b = transformer_lib.Transformer(config=config_2b)

2024-09-30 16:23:59.525232: W external/xla/xla/service/gpu/nvptx_compiler.cc:893] The NVIDIA driver's CUDA version is 12.0 which is older than the PTX compiler version 12.6.68. Because the driver is older than the PTX compiler version, XLA is disabling parallel compilation, which may slow down compilation. You should update your NVIDIA driver or use the NVIDIA-provided CUDA forward compatibility packages.


In [10]:
Q_PREFIX = 'Where does '
Q_SUFFIX = ' live? Only give the name of the city.'

src_map = [
    "Cxx1h",
    "Vzbiik",
    "Mmojkr",
    "Trrrrqe",
    "Benjamin",
    "Liz",
    "Kaitlyn",
    "Wiesa",
]

dst_map = [
    "Lkkl",
    "Plooqujhd",
    "Nwops",
    "Qtbnaaa",
    "Buenos",
    "London",
    "Kingston",
    "Warsaw",
]

idx = 6

In [11]:
sampler_old = sampler_lib.Sampler(
    transformer=model_2b,
    vocab=vocab,
    params=params['transformer'],
)


In [12]:
test_input = "{}{}{}".format(Q_PREFIX, src_map[idx], Q_SUFFIX)
print(test_input)
print(sampler_old(
    [test_input],
    # number of steps performed when generating
    total_generation_steps=30,
  ).text)


Where does Kaitlyn live? Only give the name of the city.
['\n\nAnswer:\n\nStep 1/3\n1. Kaitlyn is a person.\n\nStep 3/3\n2. Kaitlyn is']


In [13]:
test_input = "{} lives in the city of Toronto. Where does {} live? Only give the name of the city.".format(src_map[idx], src_map[idx])
print(test_input)
print(sampler_old(
    [test_input],
    # number of steps performed when generating
    total_generation_steps=30,
  ).text)


Kaitlyn lives in the city of Toronto. Where does Kaitlyn live? Only give the name of the city.
['\n\nAnswer:\n\nStep 1/2\nKaitlyn lives in Toronto.\n\nStep 2/2\n\nStep 2: Kaitlyn lives']


In [14]:

del sampler_old

### Model forward and loss function

The Gemma `Transformer` class inherits from [`flax.linen.Module`](https://flax.readthedocs.io/en/latest/guides/flax_fundamentals/flax_basics.html). It offers two essential methods:

- `init`: Initializes the model's parameters.

- `apply`: Executes the model's `__call__` function using a given set of parameters.

Since we are working with pre-trained weights, we won't use the `init` function.

We define a `forward_and_loss_fn` as follows:

In [15]:
def forward_and_loss_fn(params,
                        *,
                        model: transformer_lib.Transformer,
                        input_tokens: jax.Array,            # Shape [B, L]
                        input_mask: jax.Array,              # Shape [B, L]
                        positions: jax.Array,               # Shape [B, L]
                        attention_mask: jax.Array,          # [B, L, L]
                        ) -> jax.Array:
  """Forward pass and loss function.

  Args:
    params: model's input parameters.
    model: gemma transformer model to call.
    input_tokens: input tokens sequence, shape [B, L].
    input_mask: mask of the tokens that contribute to the loss, shape [B, L].
    positions: relative position of each token, shape [B, L].
    attention_mask: input attention mask, shape [B, L, L].

  Returns:
    Softmax cross-entropy loss for the next-token prediction task.
  """

  # Forward pass on the input data.
  # No attention cache is needed here.
  logits, _ = model.apply(
        params,
        input_tokens,
        positions,
        None,              # Attention cache is None.
        attention_mask,
    )

  # Exclude the last step as it does not appear in the targets. We slice along
  # the sequence axis so that every element of the batch is used.
  logits = logits[:, :-1]

  # Similarly, the first token cannot be predicted.
  target_tokens = input_tokens[:, 1:]
  target_mask = input_mask[:, 1:]

  # Convert the target labels into one-hot encoded vectors.
  one_hot = jax.nn.one_hot(target_tokens, logits.shape[-1])

  # Don't update on unwanted tokens.
  one_hot = one_hot * target_mask.astype(one_hot.dtype)[...,None]

  # Normalisation factor.
  norm_factor = 1 / (jnp.sum(target_mask) + 1e-8)

  # Return the nll loss.
  return -jnp.sum(jax.nn.log_softmax(logits) * one_hot) * norm_factor
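
To make the shift-by-one and the masking concrete, here is a NumPy sketch of the same loss for a single unbatched sequence. It is an illustrative re-derivation, not the function above; the function name and the toy inputs are assumptions. With uniform logits every token has probability 1/V, so the loss should equal log(V).

```python
import numpy as np


def masked_next_token_loss(logits, tokens, target_mask):
    """Mean negative log-likelihood over the masked target positions.

    logits: [L, V], tokens: [L], target_mask: [L] (True where the loss applies).
    """
    logits = logits[:-1]            # the last step has no target
    targets = tokens[1:]            # the first token is never predicted
    mask = target_mask[1:].astype(logits.dtype)

    # Numerically stable log-softmax over the vocabulary axis.
    shifted = logits - logits.max(axis=-1, keepdims=True)
    log_probs = shifted - np.log(np.exp(shifted).sum(axis=-1, keepdims=True))

    # Log-probability assigned to each target token.
    picked = log_probs[np.arange(len(targets)), targets]
    return -(picked * mask).sum() / (mask.sum() + 1e-8)


vocab_size = 5
toy_tokens = np.array([2, 3, 4, 1])
toy_target_mask = np.array([False, False, True, True])
uniform_logits = np.zeros((len(toy_tokens), vocab_size))
toy_loss = masked_next_token_loss(uniform_logits, toy_tokens, toy_target_mask)
print(toy_loss, np.log(vocab_size))
```

Indexing the log-probabilities directly is equivalent to multiplying by a one-hot matrix and summing, which is what the JAX version does.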

The Gemma transformer requires an attention mask and position vector alongside each input. We can conveniently generate these using the following function:

In [16]:
def get_attention_mask_and_positions(example: jax.Array,
                                     pad_id : int,
                                     )-> tuple[jax.Array, jax.Array]:
  """Builds the position and attention mask vectors from the given tokens."""
  pad_mask = example != pad_id
  current_token_position = transformer_lib.build_positions_from_mask(pad_mask)
  attention_mask = transformer_lib.make_causal_attn_mask(pad_mask)
  return current_token_position, attention_mask
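
The two helpers can be pictured with a NumPy sketch. This is written from scratch for illustration and is only assumed to mirror what `build_positions_from_mask` and `make_causal_attn_mask` compute; the pad id of 0 is also an assumption.

```python
import numpy as np


def positions_from_mask(pad_mask):
    """Position of each non-pad token; pad positions repeat the last index."""
    positions = np.cumsum(pad_mask, axis=-1)
    return positions - (positions >= 1)


def causal_attention_mask(pad_mask):
    """[B, L, L] mask: token i may attend to non-pad tokens j <= i."""
    seq_len = pad_mask.shape[-1]
    causal = np.tril(np.ones((seq_len, seq_len), dtype=bool))
    return pad_mask[:, None, :] & causal[None, :, :]


toy_batch = np.array([[2, 10, 11, 0, 0]])   # 0 is the assumed pad id
toy_pad_mask = toy_batch != 0
print(positions_from_mask(toy_pad_mask))    # → [[0 1 2 2 2]]
print(causal_attention_mask(toy_pad_mask)[0].astype(int))
```

Each row of the printed matrix is one query position: it can see itself and earlier real tokens, never padding and never the future.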

We can now build the `train_step` function, which performs the backward pass and updates the model's parameters accordingly.

In [17]:
def train_step(model: transformer_lib.Transformer,
               params,
               optimizer: optax.GradientTransformation,
               opt_state: optax.OptState,
               pad_id: int,
               example: TrainingInput):
  """Train step.

  Args:
    model: gemma transformer model.
    params: model's input parameters.
    optimizer: optax optimizer to use.
    opt_state: input optimizer's state.
    pad_id: id of the pad token.
    example: input batch.

  Returns:
    Training loss, updated parameters, updated optimizer state.
  """

  # Build the position and attention mask vectors.
  positions, attention_mask = get_attention_mask_and_positions(example.input_tokens, pad_id)

  # Forward and backward passes
  train_loss, grads = jax.value_and_grad(forward_and_loss_fn)(params,
                                                             model=model,
                                                             input_tokens=example.input_tokens,
                                                             input_mask=example.target_mask,
                                                             positions=positions,
                                                             attention_mask=attention_mask)
  # Update the parameters
  updates, opt_state = optimizer.update(grads, opt_state)
  params = optax.apply_updates(params, updates)

  return train_loss, params, opt_state
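
The update performed at the end of `train_step` is, for plain SGD, a subtraction of the scaled gradient from each parameter. The sketch below shows that rule on a toy linear model with hand-written gradients; it uses NumPy instead of JAX and optax, and all names in it are hypothetical.

```python
import numpy as np


def sgd_step(toy_params, grads, learning_rate):
    """Plain SGD on a dict of values: p <- p - lr * g, with no optimizer state."""
    return {name: toy_params[name] - learning_rate * grads[name]
            for name in toy_params}


def loss_and_grads(toy_params, x, y):
    """Squared error of a linear model, with hand-written gradients."""
    error = toy_params["w"] * x + toy_params["b"] - y
    loss = float(np.mean(error ** 2))
    grads = {"w": np.mean(2 * error * x), "b": np.mean(2 * error)}
    return loss, grads


x = np.array([0.0, 1.0, 2.0])
y = 3.0 * x + 1.0
toy_params = {"w": 0.0, "b": 0.0}
first_loss, _ = loss_and_grads(toy_params, x, y)
for _ in range(200):
    last_loss, grads = loss_and_grads(toy_params, x, y)
    toy_params = sgd_step(toy_params, grads, learning_rate=0.1)
print(first_loss, last_loss, toy_params)
```

Because the rule keeps no per-parameter statistics, the optimizer state stays tiny, which is why SGD is attractive when memory is tight.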

Similarly, we build a `validation_step` function without a backward pass.

In [18]:
def validation_step(model: transformer_lib.Transformer,
                    params,
                    pad_id: int,
                    example: TrainingInput,
                    ):
  positions, attention_mask = get_attention_mask_and_positions(example.input_tokens, pad_id)
  val_loss = forward_and_loss_fn(params,
                                 model=model,
                                 input_tokens=example.input_tokens,
                                 input_mask=example.target_mask,
                                 positions=positions,
                                 attention_mask=attention_mask)
  return val_loss

And now the training loop itself.

In [19]:
@chex.dataclass(frozen=True)
class TrainingConfig:
  learning_rate: float
  num_epochs: int
  eval_every_n: int
  batch_size: int
  max_steps: int | None = None


def train_loop(
    model: transformer_lib.Transformer,
    params,
    dataset_builder: FTDatasetBuilder,
    training_cfg: TrainingConfig):


  # We jit the train step, making the whole loop much more efficient
  compiled_train_step = jax.jit(train_step, static_argnames=['model', 'optimizer'])

  # We do the same with the validation step
  compiled_validation_step = jax.jit(validation_step, static_argnames=['model'])

  # To save memory, we use an SGD optimizer instead of the usual Adam. Note that
  # for this specific example SGD is more than enough.
  optimizer = optax.sgd(training_cfg.learning_rate)
  opt_state = optimizer.init(params)

  # Build the training dataset
  train_ds = dataset_builder.get_train_dataset(batch_size=training_cfg.batch_size,
                                               num_epochs=training_cfg.num_epochs)
  train_ds = train_ds.as_numpy_iterator()

  # Build the validation dataset, with a limited number of samples for this demo
  validation_ds = dataset_builder.get_validation_dataset(batch_size=training_cfg.batch_size)
  validation_ds = validation_ds.take(50)

  n_steps = 0
  avg_loss=0

  # A first round of validation loss
  n_steps_eval = 0
  eval_loss = 0
  val_iterator = validation_ds.as_numpy_iterator()
  for val_example in val_iterator:
    eval_loss += compiled_validation_step(model,
                                          params,
                                          dataset_builder._tokenizer.pad_id,
                                          val_example)
    n_steps_eval += 1
  print(f"Start, validation loss: {eval_loss/n_steps_eval}")

  for train_example in train_ds:
    train_loss, params, opt_state = compiled_train_step(model=model,
                                                        params=params,
                                                        optimizer=optimizer,
                                                        opt_state=opt_state,
                                                        pad_id=dataset_builder._tokenizer.pad_id,
                                                        example=train_example)
    n_steps += 1
    avg_loss += train_loss
    if n_steps % training_cfg.eval_every_n == 0:
      eval_loss = 0

      n_steps_eval = 0
      val_iterator = validation_ds.as_numpy_iterator()
      for val_example in val_iterator:
        eval_loss += compiled_validation_step(model,
                                              params,
                                              dataset_builder._tokenizer.pad_id,
                                              val_example)
        n_steps_eval +=1
      avg_loss /= training_cfg.eval_every_n
      eval_loss /= n_steps_eval
      print(f"STEP {n_steps} training loss: {avg_loss} - eval loss: {eval_loss}")
      avg_loss=0
    if training_cfg.max_steps is not None and n_steps >= training_cfg.max_steps:
      break
  return params
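
The bookkeeping in the loop (a running sum that is averaged and reset every `eval_every_n` steps, plus an optional step limit) can be isolated in a few lines. This is a simplified sketch with a hypothetical helper name and made-up loss values; in this sketch the loop stops once exactly `max_steps` steps have run.

```python
def run_with_periodic_logging(losses, eval_every_n, max_steps=None):
    """Averages the training loss over windows of `eval_every_n` steps."""
    window_means = []
    running = 0.0
    n_steps = 0
    for n_steps, loss in enumerate(losses, start=1):
        running += loss
        if n_steps % eval_every_n == 0:
            window_means.append(running / eval_every_n)
            running = 0.0           # reset so that windows do not overlap
        if max_steps is not None and n_steps >= max_steps:
            break
    return n_steps, window_means


steps_run, means = run_with_periodic_logging(
    losses=[4.0, 2.0, 3.0, 1.0, 9.0, 9.0], eval_every_n=2, max_steps=4)
print(steps_run, means)  # → 4 [3.0, 2.0]
```

Resetting the running sum after each report keeps every printed training loss an average over the most recent window only.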

We can fine-tune our model for a limited number of steps.

In [20]:
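import dill

if os.path.isfile("finetuned_params.dill"):
    # Resume from previously saved fine-tuned parameters.
    with open("finetuned_params.dill", "rb") as f:
        params = dill.load(f)
else:
    # Start from the pre-trained checkpoint.
    params = {'params': params['transformer']}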

In [None]:
# Small seq size so that everything fits in memory
SEQ_SIZE = 25
tokenizer = GemmaTokenizer(vocab)
dataset_builder= FTDatasetBuilder(tokenizer, SEQ_SIZE)
training_cfg = TrainingConfig(learning_rate=1e-4,
                              num_epochs=1,
                              eval_every_n=1000,
                              batch_size=1,
                              max_steps=10000)

params = train_loop(model=model_2b,
                    # params={'params': params['transformer']},
                    params=params,
                    dataset_builder=dataset_builder,
                    training_cfg=training_cfg)

Start, validation loss: 0.0033676070161163807
STEP 1000 training loss: 0.0031938038300722837 - eval loss: 0.00285249762237072


In [None]:
assert 0

In [None]:
import dill

# Save the fine-tuned parameters; the `with` block closes the file for us.
with open("finetuned_params.dill", "wb") as f:
    dill.dump(params, f)
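
A save-and-restore round trip is easy to verify on a small nested dictionary. The sketch below uses the standard-library `pickle` as a stand-in for `dill` (both accept `dump(obj, file)` and `load(file)`), a temporary directory, and a made-up parameter tree; none of these names come from the notebook.

```python
import os
import pickle
import tempfile

import numpy as np

toy_checkpoint = {"params": {"layer_0": {"w": np.arange(6.0).reshape(2, 3)}}}

with tempfile.TemporaryDirectory() as tmp_dir:
    path = os.path.join(tmp_dir, "toy_params.pkl")

    # The `with` block closes the file even if serialization raises.
    with open(path, "wb") as f:
        pickle.dump(toy_checkpoint, f)

    with open(path, "rb") as f:
        restored = pickle.load(f)

same = np.array_equal(restored["params"]["layer_0"]["w"],
                      toy_checkpoint["params"]["layer_0"]["w"])
print(same)  # → True
```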

Both the training and validation losses are going down. But is it working? Let's try again with our previous example:

In [None]:
Q_PREFIX = 'Where does '
Q_SUFFIX = ' live? Only give the name of the city.'

src_map = [
    "Cxx1h",
    "Vzbiik",
    "Mmojkr",
    "Trrrrqe",
    "Benjamin",
    "Liz",
    "Kaitlyn",
    "Wiesa",
]

dst_map = [
    "Lkkl",
    "Plooqujhd",
    "Nwops",
    "Qtbnaaa",
    "Buenos",
    "London",
    "Kingston",
    "Warsaw",
]

In [None]:
old_params = params_lib.load_and_format_params(ckpt_path)

## In-weights learning (IWL)

In [None]:
for idx, (src, dst) in enumerate(zip(src_map, dst_map)):
    old_sampler = sampler_lib.Sampler(
        transformer=model_2b,
        vocab=vocab,
        params=old_params['transformer'],
    )
    sampler = sampler_lib.Sampler(
        transformer=model_2b,
        vocab=vocab,
        params=params['params'],
    )

    test_input = "{}{}{}".format(Q_PREFIX, src, Q_SUFFIX)
    old_res = old_sampler(
        [test_input],
        total_generation_steps=30,
    ).text[0]
    new_res = sampler(
        [test_input],
        total_generation_steps=30,
    ).text[0]

    print("=" * 10)
    print("Query: {} - Answer: {}".format(test_input, dst))
    print("Base: {}".format(old_res))
    print("Finetune: {}".format(new_res))

    del old_res
    del new_res
    del old_sampler
    del sampler

## In-context learning (ICL)

In [None]:
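# The samplers created in the previous cell were deleted there, so we rebuild
# them once here instead of relying on names that no longer exist.
old_sampler = sampler_lib.Sampler(
    transformer=model_2b,
    vocab=vocab,
    params=old_params['transformer'],
)
sampler = sampler_lib.Sampler(
    transformer=model_2b,
    vocab=vocab,
    params=params['params'],
)

for idx, src in enumerate(src_map):
    for dst in dst_map:
        # The context states a city that may differ from the one seen in training.
        test_input = "{} lives in the city of {}. Where does {} live? Only give the name of the city.".format(src, dst, src)
        old_res = old_sampler(
            [test_input],
            total_generation_steps=30,
        ).text[0]
        new_res = sampler(
            [test_input],
            total_generation_steps=30,
        ).text[0]

        print("=" * 10)
        print("Query: {} - Answer: {} - Orig: {}".format(test_input, dst, dst_map[idx]))
        print("Base: {}".format(old_res))
        print("Finetune: {}".format(new_res))

# Free the samplers only after all queries have run.
del old_sampler
del sampler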
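
Reading 64 generations by eye is tedious, so one option is to label each generation by the city it mentions first. The helpers below are a hypothetical sketch, not part of the Gemma library, and the example strings are invented rather than real model outputs.

```python
def build_icl_prompt(name: str, city: str) -> str:
    """Prompt that states a (possibly conflicting) city in the context."""
    return ("{} lives in the city of {}. Where does {} live? "
            "Only give the name of the city.").format(name, city, name)


def classify_answer(generated: str, context_city: str, trained_city: str) -> str:
    """Labels a generation by which city it mentions first, if any."""
    text = generated.lower()
    hits = {"context": text.find(context_city.lower()),
            "weights": text.find(trained_city.lower())}
    found = {label: pos for label, pos in hits.items() if pos >= 0}
    if not found:
        return "neither"
    # When both cities are identical, the tie resolves to "context".
    return min(found, key=found.get)


print(build_icl_prompt("Liz", "Warsaw"))
print(classify_answer("\n\nWarsaw", context_city="Warsaw", trained_city="London"))
print(classify_answer("London. Not Warsaw.", context_city="Warsaw", trained_city="London"))
```

Counting the `context` and `weights` labels separately for the base and fine-tuned models would give a rough measure of how often each model follows the prompt versus its training data. Substring matching is crude, so ambiguous generations are still worth inspecting manually.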