# Fine Tune the Minimal-Edit LLM

## Imports

Import all relevant packages

In [None]:
from os import path

import bitsandbytes as bnb
import torch
from datasets import load_dataset
from transformers import (
    AutoModel,
    AutoModelForCausalLM,
    AutoTokenizer,
    BitsAndBytesConfig,
    DataCollatorForSeq2Seq,
    Trainer,
    TrainingArguments,
)

from prompts import minimal_prompt as prompt

## Model Setup

Check that a GPU is available, then load the 4-bit quantized model and its tokenizer, with the model placed on the GPU.

Raise an error if no GPU is accessible.

In [None]:
# Hugging Face Hub identifier of the checkpoint to fine-tune.
base_model_name = "LumiOpen/Viking-33B"
device = "cuda"
# Fail fast before the (large) checkpoint is loaded: training needs a GPU.
if not torch.cuda.is_available():
    raise RuntimeError("GPU is not available for training!")

# 4-bit NF4 quantization with double quantization; matmuls are computed
# in bfloat16.
nf4_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_use_double_quant=True,
    bnb_4bit_compute_dtype=torch.bfloat16,
)

# Load the checkpoint WITH its language-modeling head. The bare `AutoModel`
# backbone returns hidden states only and never produces a loss, so it
# cannot be trained with `Trainer`.
# NOTE(review): Trainer is expected to reject fine-tuning of a purely
# quantized model unless trainable adapters (e.g. LoRA via `peft`) are
# attached -- confirm and add adapters before training.
model = AutoModelForCausalLM.from_pretrained(
    base_model_name, quantization_config=nf4_config
)
tokenizer = AutoTokenizer.from_pretrained(base_model_name)
# Collator that pads inputs and labels per batch.
data_collator = DataCollatorForSeq2Seq(tokenizer, model=model)

In [None]:
# Root directory that holds the datasets.
base_dataset_dir = "datasets"

# Location of the "minimal" dataset inside the dataset root.
minimal_dataset_path = path.join(base_dataset_dir, "minimal")

# Load the dataset found at that path.
# NOTE(review): later cells index this with "train" and "validation" and read
# "source" / "target" columns -- confirm the files on disk provide them.
minimal_dataset = load_dataset(minimal_dataset_path)

## Process Input and Output

Load the input and output into a dictionary with the following structure:

- "input": [PROMPT] + input_text
- "output": output_text


In [None]:
def preprocess_function(dataset):
    """Tokenize prompt-prefixed sources together with their targets.

    Works for both calling conventions of ``Dataset.map``:

    * batched: ``dataset["source"]`` / ``dataset["target"]`` are lists of
      strings;
    * per example: they are single strings.

    Args:
        dataset: Mapping with a ``"source"`` and a ``"target"`` entry.

    Returns:
        The tokenizer output for the prompt-prefixed source text(s), with the
        target text(s) passed as ``text_target``, padded to ``max_length``.
    """
    sources = dataset["source"]
    targets = dataset["target"]

    if isinstance(sources, str):
        # Single example: iterating over a str would yield one "source" per
        # character, so prefix the prompt to the whole string instead.
        model_inputs = prompt + sources
    else:
        model_inputs = [prompt + text for text in sources]

    return tokenizer(model_inputs, text_target=targets, padding="max_length")

In [None]:
# `preprocess_function` iterates over dataset["source"] as a list of texts,
# so it must receive batches (column lists); without batched=True each call
# gets a single string and the prompt is prefixed to every character.
tokenized_minimal_dataset = minimal_dataset.map(preprocess_function, batched=True)

In [None]:
# Hyper-parameters for the fine-tuning run; checkpoints are written to
# the "tmp_model" directory.
training_arguments = TrainingArguments(
    output_dir="tmp_model",
    # Optimisation settings: 8-bit AdamW from bitsandbytes.
    optim="adamw_bnb_8bit",
    learning_rate=5e-5,
    # Schedule / batching.
    num_train_epochs=3,
    per_device_train_batch_size=128,
)

# Wire the model, tokenizer, collator and dataset splits into one Trainer.
trainer = Trainer(
    model=model,
    tokenizer=tokenizer,
    data_collator=data_collator,
    args=training_arguments,
    train_dataset=tokenized_minimal_dataset["train"],
    eval_dataset=tokenized_minimal_dataset["validation"],
)

In [None]:
# Run the fine-tuning loop with the Trainer configured in the previous cell.
trainer.train()