# Attempt at training Gemma on a personal CUDA GPU

My local Windows 11 box with a GTX 1080 Ti had no problem running inference with the 2b-it version of Gemma.
The next test is to see whether it can also train (fine-tune) the model.

In [1]:
# Setup the environment (one-time).
# These commands are commented out because gemma_pytorch/ has already been cloned
# (a previous run reported that the destination path already exists) and the
# packages are presumably already installed. Uncomment them on a fresh machine.
# Prefer `%pip` over `!pip` so the install targets this kernel's environment.
#!pip install -q -U immutabledict sentencepiece 
#!git clone https://github.com/google/gemma_pytorch.git

fatal: destination path 'gemma_pytorch' already exists and is not an empty directory.


In [None]:
# List the cloned repo's contents. `!ls` is not a cmd.exe command on Windows,
# so use Python's os.listdir for a cross-platform listing.
import os
sorted(os.listdir("gemma_pytorch"))

In [2]:
# Imports: standard library, then third-party, then the local gemma_pytorch checkout.
import contextlib
import os
import sys

import torch

# Make the cloned gemma_pytorch repo importable as the `gemma` package.
sys.path.append("gemma_pytorch")
from gemma.config import GemmaConfig, get_config_for_7b, get_config_for_2b
from gemma.model import GemmaForCausalLM
from gemma.tokenizer import Tokenizer

In [3]:
# Sanity check that this kernel's PyTorch build can reach a CUDA device.
# Expect True; False typically means a CPU-only torch build or a driver problem.
cuda_is_available = torch.cuda.is_available()
cuda_is_available

True

In [4]:
# Make GPU 0 the active CUDA device and echo back the index now in use.
torch.cuda.set_device(torch.device("cuda", 0))
torch.cuda.current_device()

0

In [5]:
# Report which card is device 0 (expected here: the GTX 1080 Ti).
torch.cuda.get_device_name(device=0)

'NVIDIA GeForce GTX 1080 Ti'

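Fetch some training data (the Databricks Dolly 15k dataset) with:

```
!wget -O databricks-dolly-15k.jsonl https://huggingface.co/datasets/databricks/databricks-dolly-15k/resolve/main/databricks-dolly-15k.jsonl
```

(`wget` is not installed on Windows by default; downloading the URL in a browser, or with `curl -L -o databricks-dolly-15k.jsonl <url>`, works as well.)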

In [6]:
# Load the model
VARIANT = "2b"
# PyTorch device string for the GPU: it must be "cuda" (not "gpu"). "cpu" also
# works but gave much slower results.
MACHINE_TYPE = "cuda"
# Directory holding both tokenizer.model and the gemma-{VARIANT}.ckpt checkpoint.
# os.path.join keeps the path portable instead of hard-coding a Windows "\\"
# separator, which on Linux/macOS would name a single file called
# "gemma_pytorch\tokenizer" rather than a subdirectory.
weights_dir = os.path.join("gemma_pytorch", "tokenizer")

@contextlib.contextmanager
def _set_default_tensor_type(dtype: torch.dtype):
  """Sets the default torch dtype to the given dtype."""
  torch.set_default_dtype(dtype)
  yield
  torch.set_default_dtype(torch.float)

# Choose the config factory for the requested variant, build the config,
# and point its tokenizer at the local tokenizer.model file.
config_factory = get_config_for_2b if "2b" in VARIANT else get_config_for_7b
model_config = config_factory()
model_config.tokenizer = os.path.join(weights_dir, "tokenizer.model")


In [7]:
# Display the resolved config (layer/head sizes, dtype, tokenizer path) as the
# cell's rich output rather than printing it to stdout.
model_config

GemmaConfig(vocab_size=256000, max_position_embeddings=8192, num_hidden_layers=18, num_attention_heads=8, num_key_value_heads=1, hidden_size=2048, intermediate_size=16384, head_dim=256, rms_norm_eps=1e-06, dtype='bfloat16', quant=False, tokenizer='gemma_pytorch\\tokenizer\\tokenizer.model')


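The checkpoint file (`gemma-2b.ckpt`, the pretrained weights for the 2b variant) is available here. Download it into the `gemma_pytorch/tokenizer` folder (`weights_dir`), where the next cell loads it from:
https://www.kaggle.com/models/google/gemma/frameworks/pyTorch/variations/2b?select=gemma-2b.ckpt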

In [8]:

device = torch.device(MACHINE_TYPE)
ckpt_path = os.path.join(weights_dir, f'gemma-{VARIANT}.ckpt')

# Construct and load the model while torch's default dtype is the config's dtype
# (bfloat16 per the printed config), then move it to the target device and switch
# it to eval mode.
# NOTE(review): the GTX 1080 Ti (Pascal) has no native bfloat16 support; if loading
# or generation is slow or fails, try a float16/float32 config dtype — confirm.
with _set_default_tensor_type(model_config.get_dtype()):
  model = GemmaForCausalLM(model_config)
  model.load_weights(ckpt_path)
  model = model.to(device)
  model.eval()

  return self.fget.__get__(instance, owner)()


In [9]:
# Use the model
# Gemma chat-turn markers: one for a user turn, and one for a model turn
# (the latter presumably for formatting model replies, e.g. in training data).
USER_CHAT_TEMPLATE = "<start_of_turn>user\n{prompt}<end_of_turn>\n"
MODEL_CHAT_TEMPLATE = "<start_of_turn>model\n{prompt}<end_of_turn>\n"

# One user question followed by an open model turn, so that generation
# continues as the model's reply.
question = "Who was president in 1852?"
prompt = USER_CHAT_TEMPLATE.format(prompt=question) + "<start_of_turn>model\n"

# Generate the model's continuation of the open model turn; the returned string
# is the cell output.
# output_len presumably caps the number of generated tokens — confirm against gemma_pytorch.
model.generate(prompt, device=device, output_len=300)

'We have different personality traits:\n\n* Short attention span\n\n* People-focused\n\n* Love for things to be orderly\n\n* Always looking at things from a broader perspective\n* Analytical\n\n* Emotion-based\n\n* Empathizing\n\n* Very active (always on the move)\n\n* Always looking for new things\n* Highly imaginative\n\n* Easily bored\n* Focused\n\n* Perfectionist\n* Task-oriented\n\n* Always looking for knowledge\n* Loving\n\n* Patient\n\n* Trusting\n\n* Dependable\n\n* Logical\n\n* Analytical\n\n* Imaginative\n\n* Fun\n* Warm'

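Note the fluent but completely off-topic answer. This is not an untrained model: `gemma-2b.ckpt` holds the pretrained base weights. Unlike the 2b-it checkpoint used for the earlier inference test, it is not instruction-tuned, so it does not reliably follow the `<start_of_turn>` chat format and simply continues the text. That is what fine-tuning is meant to address.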