In [1]:
import dotenv
dotenv.load_dotenv()
import os
import weave
import wandb

In [2]:
wandb.login(key = os.getenv("WANDB_API_KEY"))

[34m[1mwandb[0m: Appending key for zscalersre.wandb.io to your netrc file: /home/kilnaar/.netrc
[34m[1mwandb[0m: Currently logged in as: [33mcumbel[0m ([33mjune-pov[0m) to [32mhttps://zscalersre.wandb.io[0m. Use [1m`wandb login --relogin`[0m to force relogin


True

In [3]:
wandb.init(entity = "june-pov", project = "gemma-tuner")
weave.init("gemma-tuner")

  from .autonotebook import tqdm as notebook_tqdm
[36m[1mweave[0m: Logged in as Weights & Biases user: cumbel.
[36m[1mweave[0m: View Weave data at https://zscalersre.wandb.io/june-pov/gemma-tuner/weave


<weave.trace.weave_client.WeaveClient at 0x76710a6b5390>

In [4]:
from datasets import load_dataset
from transformers import (
    AutoModelForCausalLM, AutoTokenizer, TrainingArguments
)
from peft import LoraConfig, get_peft_model, TaskType, AutoPeftModelForCausalLM
from trl import SFTTrainer
import torch

## Base Model

In [5]:
model_id = "google/gemma-2b"
tokenizer = AutoTokenizer.from_pretrained(model_id, use_fast = True)
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    load_in_8bit = True,
    device_map = "auto",
    torch_dtype = torch.bfloat16
)

The `load_in_4bit` and `load_in_8bit` arguments are deprecated and will be removed in the future versions. Please, pass a `BitsAndBytesConfig` object in `quantization_config` argument instead.
Loading checkpoint shards: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████| 2/2 [00:00<00:00,  2.05it/s]
We've detected an older driver with an RTX 4000 series GPU. These drivers have issues with P2P. This can affect the multi-gpu inference when using accelerate device_map.Please make sure to update your driver to the latest version which resolves this.


## Adapter Training Setup

In [6]:
adapt_to_ds = load_dataset("Abirate/english_quotes", split = "train[:2%]")
target_adapter_name = "test-gemma-lora-adapter"

In [7]:
def format_fn(ex):
    return f"Quote: {ex['quote']}\nAuthor: {ex['author']}"

lora_cfg = LoraConfig(
    r = 8,
    lora_alpha = 16,
    lora_dropout = 0.05,
    bias = "none",
    target_modules = ["q_proj", "k_proj", "v_proj", "o_proj"],
    task_type = TaskType.CAUSAL_LM,
)

model = get_peft_model(model, lora_cfg)

# --- Training -----------------------------------------------------
args = TrainingArguments(
    output_dir = target_adapter_name,
    per_device_train_batch_size = 4,
    num_train_epochs = 3,
    gradient_accumulation_steps = 4,
    learning_rate = 2e-4,
    fp16 = True,
    logging_steps = 10,
    save_steps = 100,
    save_total_limit = 2,
)

trainer = SFTTrainer(
    model = model,
    args = args,
    train_dataset = adapt_to_ds,
    formatting_func = format_fn,
    processing_class = tokenizer,
)


No label_names provided for model class `PeftModelForCausalLM`. Since `PeftModel` hides base models input arguments, if label_names is not given, label_names can't be set automatically within `Trainer`. Note that empty label_names list will be used instead.


## 

## Training

In [8]:
trainer.train()



Step,Training Loss
10,1.3902




TrainOutput(global_step=12, training_loss=1.3668849070866902, metrics={'train_runtime': 7.907, 'train_samples_per_second': 18.971, 'train_steps_per_second': 1.518, 'total_flos': 164181232410624.0, 'train_loss': 1.3668849070866902})

In [9]:
model.save_pretrained(target_adapter_name)
tokenizer.save_pretrained(target_adapter_name)

('test-gemma-lora-adapter/tokenizer_config.json',
 'test-gemma-lora-adapter/special_tokens_map.json',
 'test-gemma-lora-adapter/tokenizer.model',
 'test-gemma-lora-adapter/added_tokens.json',
 'test-gemma-lora-adapter/tokenizer.json')

## Inference

In [10]:
ft_model = AutoPeftModelForCausalLM.from_pretrained(
    target_adapter_name,
    torch_dtype=torch.bfloat16,
    device_map="auto",
)

@weave.op()
def analyze_quote(quote):
    prompt = f"Quote: {quote}\nAuthor:"
    return tokenizer.decode(ft_model.generate(
        **tokenizer(prompt, return_tensors="pt").to(ft_model.device),
        max_new_tokens=16)[0], skip_special_tokens=True)

Loading checkpoint shards: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████| 2/2 [00:00<00:00,  4.01it/s]


In [11]:
print(analyze_quote("We have nothing to fear but fear itself"))

[36m[1mweave[0m: 🍩 https://zscalersre.wandb.io/june-pov/gemma-tuner/r/call/0197e74d-0e32-73df-9e55-2289b6549bb6


Quote: We have nothing to fear but fear itself
Author: Franklin D. Roosevelt
Source: The New York Times, 193


[36m[1mweave[0m: 🍩 https://zscalersre.wandb.io/june-pov/gemma-tuner/r/call/0197e74d-0fb6-7f88-a273-709da37f1015
[36m[1mweave[0m: 🍩 https://zscalersre.wandb.io/june-pov/gemma-tuner/r/call/0197e74d-10f5-74eb-b199-5cef714485af


In [12]:
print(analyze_quote("We choose to go to the moon"))

Quote: We choose to go to the moon
Author: John F. Kennedy
Source: NASA

The moon is a symbol of the


In [13]:
print(analyze_quote("In the end, it’s not the years in your life that count. It’s the life in your years"))

Quote: In the end, it’s not the years in your life that count. It’s the life in your years
Author: Abraham Lincoln
Source: https://www.goodreads.com/quotes/
