<a href="https://colab.research.google.com/github/wandb/edu/blob/main/model-registry-201/Logging_Models_HuggingFace.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>
<!--- @wandbcode{modelreg201-hf} -->

# Outline
- Fine-tune a lightweight LLM (OPT-125M) with LoRA and 8-bit quantization using Launch
- Checkpoint the LoRA adapter weights as artifacts
- Link the best checkpoint in Model Registry
- Linking triggers an automation in GitHub Actions or Modal that quantizes the model for inference
  - The CI job should also profile the model and generate a W&B report
- Add the `production` alias to the registered model version
- The alias triggers an automation that deploys the model as a FastAPI inference server in Modal

Stretch: interact with the model and log inputs/outputs with W&B Prompts
<!--- @wandbcode{wandb201-hf} -->


## Fine-tune large models using 🤗 `peft` adapters, `transformers` & `bitsandbytes`

In this tutorial we cover how to fine-tune large language models using the `peft` library, with `bitsandbytes` for loading large models in 8-bit.
The fine-tuning relies on a method called Low-Rank Adaptation (LoRA): instead of fine-tuning the entire model, you train only a small set of adapter weights and load them into the frozen base model.
After fine-tuning, you can also share your adapters on the 🤗 Hub and load them back very easily. Let's get started!
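
To make the adapter idea concrete, here is a minimal pure-Python sketch of a LoRA-style forward pass. It is not how `peft` implements it internally. The helper names (`matvec`, `lora_forward`) and the tiny dimensions are made up for illustration, and the zero initialization of `B` is assumed to follow the common LoRA convention.

```python
import random

def matvec(matrix, vector):
    """Multiply a matrix (list of rows) by a vector (list of floats)."""
    return [sum(w * x for w, x in zip(row, vector)) for row in matrix]

def lora_forward(W, A, B, x, alpha, r):
    """Return W x + (alpha / r) * B (A x): frozen output plus a low-rank update."""
    base = matvec(W, x)                 # frozen pretrained projection
    update = matvec(B, matvec(A, x))    # rank-r correction, the only trainable part
    scale = alpha / r
    return [b + scale * u for b, u in zip(base, update)]

random.seed(0)
d_out, d_in, r, alpha = 4, 6, 2, 4      # illustrative sizes, not OPT's
W = [[random.gauss(0, 1) for _ in range(d_in)] for _ in range(d_out)]   # frozen
A = [[random.gauss(0, 0.01) for _ in range(d_in)] for _ in range(r)]    # trainable, r x d_in
B = [[0.0] * r for _ in range(d_out)]                                   # trainable, d_out x r
x = [1.0] * d_in

# With B at zero the adapted layer starts out identical to the frozen layer.
print(lora_forward(W, A, B, x, alpha, r) == matvec(W, x))
```

Only `A` and `B` would receive gradients during training, which is why the number of trainable parameters stays small.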

**TODO:** Turn this section of code into a launch job

### Install requirements

First, run the cells below to install the requirements:

In [None]:
!pip install -q bitsandbytes datasets accelerate loralib
!pip install -q git+https://github.com/huggingface/transformers.git@main git+https://github.com/huggingface/peft.git
!pip install -q wandb
!pip install -q ctranslate2

In [None]:
import wandb
wandb.login()

### Model loading

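Here we load the `opt-125m` model. Its half-precision (float16) weights are about 250MB, and loading them in 8-bit roughly halves that. The same recipe scales to larger checkpoints: the float16 weights of `opt-6.7b` are about 13GB on the Hub, and loading them in 8-bit requires around 7GB of memory instead.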
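
The memory figures above follow from simple arithmetic: parameter count times bytes per parameter. The sketch below is a rough estimate only. The helper `weight_memory_gb` is hypothetical, the parameter counts are rounded, and it ignores activations, optimizer state, and any layers that the 8-bit loader may keep in higher precision.

```python
def weight_memory_gb(num_params, bytes_per_param):
    """Approximate memory for the weights alone, in gigabytes (1 GB = 1e9 bytes)."""
    return num_params * bytes_per_param / 1e9

# Rounded parameter counts, used here only for illustration.
models = {"opt-125m": 125e6, "opt-6.7b": 6.7e9}

for name, num_params in models.items():
    fp16 = weight_memory_gb(num_params, 2)   # float16: 2 bytes per parameter
    int8 = weight_memory_gb(num_params, 1)   # 8-bit: 1 byte per parameter
    print(f"{name}: fp16 ~{fp16:.2f} GB, int8 ~{int8:.2f} GB")
```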

In [None]:
import os
os.environ["CUDA_VISIBLE_DEVICES"] = "0"
import torch
import torch.nn as nn
import bitsandbytes as bnb
from transformers import AutoTokenizer, AutoConfig, AutoModelForCausalLM

project_name = "model-registry-201" #@param
entity = "wandb" #@param

model = AutoModelForCausalLM.from_pretrained(
    "facebook/opt-125m",
    load_in_8bit=True,
    device_map='auto',
)

tokenizer = AutoTokenizer.from_pretrained("facebook/opt-125m")

### Post-processing on the model

Finally, we need to apply some post-processing to the 8-bit model to enable training. We freeze all layers and cast the layer norms to `float32` for stability. We also cast the output of the last layer to `float32` for the same reason.

In [None]:
for param in model.parameters():
  param.requires_grad = False  # freeze the model - train adapters later
  if param.ndim == 1:
    # cast the small parameters (e.g. layernorm) to fp32 for stability
    param.data = param.data.to(torch.float32)

model.gradient_checkpointing_enable()  # reduce number of stored activations
model.enable_input_require_grads()

class CastOutputToFloat(nn.Sequential):
  def forward(self, x): return super().forward(x).to(torch.float32)
model.lm_head = CastOutputToFloat(model.lm_head)
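
As a small illustration of why the sensitive parts are kept in `float32`, the sketch below rounds a few values to half precision using only the standard library. The helper `to_fp16` is hypothetical and is not part of the training code. It only shows the limited precision and range of float16, not the behaviour of any specific layer.

```python
import struct

def to_fp16(value):
    """Round a Python float to the nearest float16 value and return it as a float."""
    return struct.unpack("e", struct.pack("e", value))[0]

print(to_fp16(0.1))     # close to, but not exactly, 0.1
print(to_fp16(2049.0))  # above 2048, not every integer is representable
print(to_fp16(1e-8))    # too small for float16, underflows to zero
```

Small statistics such as the variance inside a layer norm, or large logits at the output, are the kind of values that can suffer from this rounding.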

### Apply LoRA

Here comes the magic with `peft`! Let's wrap the model in a `PeftModel` and specify that we are going to use low-rank adapters (LoRA), using the `get_peft_model` utility function from `peft`.

In [None]:
def print_trainable_parameters(model):
    """
    Prints the number of trainable parameters in the model.
    """
    trainable_params = 0
    all_param = 0
    for _, param in model.named_parameters():
        all_param += param.numel()
        if param.requires_grad:
            trainable_params += param.numel()
    print(
        f"trainable params: {trainable_params} || all params: {all_param} || trainable%: {100 * trainable_params / all_param}"
    )

In [None]:
from peft import LoraConfig, get_peft_model

config = LoraConfig(
    r=16,
    lora_alpha=32,
    target_modules=["q_proj", "v_proj"],
    lora_dropout=0.05,
    bias="none",
    task_type="CAUSAL_LM"
)

model = get_peft_model(model, config)
print_trainable_parameters(model)
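
As a sanity check on the printed number, the trainable parameter count can be estimated by hand. This is a sketch under stated assumptions: OPT-125M is taken to have 12 decoder layers with a hidden size of 768, so that each `q_proj` and `v_proj` is a 768 x 768 projection. The helper `lora_param_count` is hypothetical.

```python
def lora_param_count(r, in_features, out_features):
    """Parameters in one LoRA adapter: A is r x in_features, B is out_features x r."""
    return r * (in_features + out_features)

# Assumed OPT-125M dimensions (see the note above).
hidden_size, num_layers = 768, 12
target_modules = ["q_proj", "v_proj"]

per_module = lora_param_count(16, hidden_size, hidden_size)   # r=16 as in LoraConfig
total = per_module * len(target_modules) * num_layers
print(f"{per_module} per adapted projection, {total} trainable parameters in total")
```

If the assumptions hold, `total` should match the `trainable params` value reported by `print_trainable_parameters`.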

# Log Checkpoints Automatically with Hugging Face

You can log your Hugging Face model to W&B Artifacts by setting the W&B environment variable `WANDB_LOG_MODEL`:
- `WANDB_LOG_MODEL='end'` logs only the final model
- `WANDB_LOG_MODEL='checkpoint'` logs a model checkpoint every `save_steps`, as set in the `TrainingArguments`
- Optionally, use the W&B Artifacts API to implement your own checkpointing logic with Hugging Face callbacks

See more details on our Hugging Face integration [here](https://docs.wandb.ai/guides/integrations/huggingface).
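
For the last option, a custom callback could look roughly like the sketch below. It is one possible approach, not the integration's own logic. The names `checkpoint_dir`, `make_artifact_callback`, `LogCheckpointArtifact`, the artifact name `lora-adapter`, and the `step-<n>` alias scheme are all hypothetical. It also assumes the Trainer writes each checkpoint to a `checkpoint-<step>` folder under `output_dir`.

```python
import os

def checkpoint_dir(output_dir, global_step):
    """Folder assumed to hold the checkpoint written at a given step."""
    return os.path.join(output_dir, f"checkpoint-{global_step}")

def make_artifact_callback(artifact_name="lora-adapter"):
    """Build a TrainerCallback that logs each saved checkpoint folder as an artifact."""
    # Imported lazily so checkpoint_dir stays usable without these packages.
    import wandb
    from transformers import TrainerCallback

    class LogCheckpointArtifact(TrainerCallback):
        def on_save(self, args, state, control, **kwargs):
            path = checkpoint_dir(args.output_dir, state.global_step)
            # Log only from the main process, with an active run and an existing folder.
            if not state.is_world_process_zero or wandb.run is None:
                return
            if not os.path.isdir(path):
                return
            artifact = wandb.Artifact(artifact_name, type="model")
            artifact.add_dir(path)
            wandb.run.log_artifact(artifact, aliases=[f"step-{state.global_step}"])

    return LogCheckpointArtifact()
```

To use it, you would register the callback with `trainer.add_callback(make_artifact_callback())` before calling `trainer.train()`, and leave `WANDB_LOG_MODEL` unset so checkpoints are not logged twice.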

In [None]:
import transformers
from datasets import load_dataset
import wandb

os.environ["WANDB_LOG_MODEL"] = "checkpoint"

wandb.init(project=project_name,
           entity=entity,
           job_type="training")

data = load_dataset("Abirate/english_quotes")
data = data.map(lambda samples: tokenizer(samples['quote']), batched=True)

trainer = transformers.Trainer(
    model=model,
    train_dataset=data['train'],
    args=transformers.TrainingArguments(
        per_device_train_batch_size=4,
        gradient_accumulation_steps=4,
        report_to="wandb",
        warmup_steps=5,
        max_steps=25,
        learning_rate=2e-4,
        fp16=True,
        logging_steps=1,
        save_steps=5,
        output_dir='outputs'
    ),
    data_collator=transformers.DataCollatorForLanguageModeling(tokenizer, mlm=False)
)
model.config.use_cache = False  # silence the warnings. Please re-enable for inference!
trainer.train()
last_run_id = wandb.run.id
wandb.finish()
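
The training arguments above imply a small, fixed schedule, which can be worked out by hand. The sketch below assumes a single GPU and the default step-based save strategy. The helpers `effective_batch_size` and `checkpoint_steps` are hypothetical.

```python
def effective_batch_size(per_device, grad_accum, num_devices=1):
    """Sequences consumed per optimizer step."""
    return per_device * grad_accum * num_devices

def checkpoint_steps(max_steps, save_steps):
    """Steps at which a checkpoint is expected to be written."""
    return list(range(save_steps, max_steps + 1, save_steps))

batch = effective_batch_size(per_device=4, grad_accum=4)
steps = checkpoint_steps(max_steps=25, save_steps=5)
print(f"{batch} sequences per step, {batch * 25} sequences seen in total")
print(f"checkpoints expected at steps {steps}")
```

With `WANDB_LOG_MODEL="checkpoint"`, each of those steps should produce a new version of the checkpoint artifact in W&B.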

### Adding Model Weights to Model Registry

In [None]:
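# Start a short run whose only job is to link the checkpoint into the registry
wandb.init(project=project_name, entity=entity, job_type="registering_best_model")
# Fetch the most recent checkpoint artifact logged by the training run
best_model = wandb.use_artifact(f'{entity}/{project_name}/checkpoint-{last_run_id}:latest')
registered_model_name = "Review Summarization" #@param {type: "string"}
# Link it to the registered model and tag the new version as `staging`
wandb.run.link_artifact(best_model, f'{entity}/model-registry/{registered_model_name}', aliases=['staging'])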
wandb.finish()
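
Once a version is linked, downstream jobs can refer to it by its registry path and alias. The sketch below shows one way the later promotion step from the outline might be scripted with the W&B public API. The helpers `registry_path` and `promote_to_production` are hypothetical, and the sketch assumes a version with the `staging` alias already exists.

```python
def registry_path(entity, registered_model_name, alias="latest"):
    """Build the path used above to address a registered model version."""
    return f"{entity}/model-registry/{registered_model_name}:{alias}"

def promote_to_production(entity, registered_model_name, from_alias="staging"):
    """Add the `production` alias to the version currently tagged `from_alias`."""
    import wandb  # imported lazily so registry_path works without wandb installed

    api = wandb.Api()
    artifact = api.artifact(registry_path(entity, registered_model_name, from_alias))
    artifact.aliases.append("production")
    artifact.save()  # persist the updated alias list
    return artifact

print(registry_path("wandb", "Review Summarization", "staging"))
```

Calling `promote_to_production(entity, registered_model_name)` requires being logged in to W&B with write access to the registry.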