In [None]:
import os
import torch
from dotenv import load_dotenv
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig

# Load variables from a local .env file (HF_TOKEN is read from the environment below).
load_dotenv()

# DEFINE CONSTANTS
LLAMA3 = "meta-llama/Meta-Llama-3-8B-Instruct"
MISTRAL = "mistralai/Mistral-7B-Instruct-v0.2"

# Define settings
model_name = LLAMA3 # LLAMA3 or MISTRAL

# Fail fast on an unsupported model, before anything is downloaded.
if model_name not in (LLAMA3, MISTRAL):
    raise ValueError("model_name must be defined at top of file.")

output_dir = "llama-ft" if model_name == LLAMA3 else "mistral-ft" # Change output dir if desired

# 4-bit NF4 quantization with double quantization; compute is done in bfloat16.
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_use_double_quant=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16
)

# Pass the same token as the model load below so the tokenizer and the weights
# are fetched with identical credentials.
tokenizer = AutoTokenizer.from_pretrained(model_name, token=os.getenv("HF_TOKEN"))

# Need different pad_token for each model
if model_name == MISTRAL:
    # Mistral bugs out if we set pad_token to eos_token
    # See this thread: https://discuss.huggingface.co/t/mistral-trouble-when-fine-tuning-dont-set-pad-token-id-eos-token-id/77928/4
    tokenizer.pad_token = tokenizer.unk_token
else:
    # model_name == LLAMA3 (validated above)
    # unk_token is not defined for Llama3
    # eos_token seems to work fine for fine-tuning
    # NOTE(review): the collator used later is DataCollatorForLanguageModeling(mlm=False);
    # it presumably masks pad-token positions in the labels, which with pad == eos would
    # also mask genuine eos tokens - confirm this is acceptable for the tokenizer revision in use.
    tokenizer.pad_token = tokenizer.eos_token

model = AutoModelForCausalLM.from_pretrained(
    model_name,
    quantization_config=bnb_config,
    device_map="auto",
    token=os.getenv("HF_TOKEN"),
)

In [None]:
from peft import prepare_model_for_kbit_training, LoraConfig, get_peft_model

# Enable gradient checkpointing on the quantized base model, then let PEFT
# prepare it for k-bit training. The order of the statements in this cell matters:
# each one operates on the model produced by the previous one.
model.gradient_checkpointing_enable()
model = prepare_model_for_kbit_training(model)

# LoRA config based on QLoRA paper https://arxiv.org/pdf/2305.14314
# NOTE(review): r and lora_dropout below may differ from the values reported in
# the paper for 7B/13B models - confirm these are intentional choices.
config = LoraConfig(
    r=8,
    lora_alpha=16,
    target_modules="all-linear", # target all linear layers
    lora_dropout=0.05, # NOTE(review): paper appears to use 0.1 for models up to 13B and 0.05 for 33B/65B - confirm
    bias="none",
    task_type="CAUSAL_LM"
)

# Get LoRA model
model = get_peft_model(model, config)
# use_cache is switched off for training (it is set again right before trainer.train()).
model.config.use_cache = False

# Report how many parameters the LoRA wrapper leaves trainable.
model.print_trainable_parameters()

In [None]:
import pandas as pd
from datasets import Dataset
from transformers import DataCollatorForLanguageModeling

# Instruction text shared by both chat formats. It is written as adjacent string
# literals (one per instruction line) that concatenate into a single string.
prompt = (
    "You are to assign an ICD-10 code to a cause of death using the following instructions:\n"
    "- Use standard ICD-10 codes, not ICD-10-CM billing codes.\n"
    "- Each ICD-10 code should be 3 or 5 characters long, for example: 'X01.0' or 'C15'.\n"
    "- If the cause of death is 'unknown' or 'blank', use code 'R99'.\n"
    "- If you lack sufficient information to assign a code, do not try to guess. Instead, use code 'Æ99.9'.\n"
    "- Your response should only contain a single ICD-10 code using this format: '<ICD-10 CODE>'.\n"
    "- Do not explain your answer."
)


def _load_tsv_dataset(path):
    """Read a tab-separated file with pandas and wrap the frame in a `Dataset`."""
    frame = pd.read_csv(path, sep='\t')
    return Dataset.from_pandas(frame)


# Training and validation splits (the tokenize functions read their 'input' and 'output' columns).
train = _load_tsv_dataset("doc/train.txt")
validate = _load_tsv_dataset("doc/validate.txt")

def tokenize_mistral(entry):
    """Render one example as a Mistral chat and tokenize it.

    The instruction prompt is placed at the start of the user turn (no system
    message is used), followed by the cause of death; the expected code is the
    assistant turn.

    Args:
        entry: mapping with an 'input' (cause of death) and an 'output'
            (target code) field.

    Returns:
        The tokenizer output for the rendered chat, truncated to 512 tokens.

    Side effects:
        entry["input"] is overwritten with the rendered chat text.
    """
    user = {"role": "user", "content": f"{prompt}\n\nCause of death: {entry['input']}"}
    assistant = {"role": "assistant", "content": f"{entry['output']}"}
    msg = [user, assistant]

    # tokenize=False returns the rendered string, so no return_tensors is needed here.
    chat_text = tokenizer.apply_chat_template(msg, tokenize=False, add_generation_prompt=False)
    entry["input"] = chat_text

    # The chat template already renders the special tokens (including BOS).
    # Tokenizing with the default add_special_tokens=True would prepend a
    # second BOS, so special-token insertion is disabled for this pass.
    return tokenizer(
        chat_text,
        truncation=True,
        max_length=512,
        add_special_tokens=False,
    )

def tokenize_llama(entry):
    """Render one example as a Llama 3 chat and tokenize it.

    The instruction prompt is sent as the system message, the cause of death
    as the user turn, and the expected code as the assistant turn.

    Args:
        entry: mapping with an 'input' (cause of death) and an 'output'
            (target code) field.

    Returns:
        The tokenizer output for the rendered chat, truncated to 512 tokens.

    Side effects:
        entry["input"] is overwritten with the rendered chat text.
    """
    system = {"role": "system", "content": prompt}
    user = {"role": "user", "content": f"Cause of death: {entry['input']}"}
    assistant = {"role": "assistant", "content": f"{entry['output']}"}
    msg = [system, user, assistant]

    # tokenize=False returns the rendered string, so no return_tensors is needed here.
    chat_text = tokenizer.apply_chat_template(msg, tokenize=False, add_generation_prompt=False)
    entry["input"] = chat_text

    # The chat template already renders the special tokens (including BOS).
    # Tokenizing with the default add_special_tokens=True would prepend a
    # second BOS, so special-token insertion is disabled for this pass.
    return tokenizer(
        chat_text,
        truncation=True,
        max_length=512,
        add_special_tokens=False,
    )

# Select the tokenization routine that matches the configured model.
if model_name == LLAMA3:
    _tokenize_entry = tokenize_llama
elif model_name == MISTRAL:
    _tokenize_entry = tokenize_mistral
else:
    raise Exception("model_name must be defined at top of file.")

# Tokenize both splits, then drop the raw target column we won't use.
train = train.map(_tokenize_entry).remove_columns("output")
validate = validate.map(_tokenize_entry).remove_columns("output")

# Collator configured for causal language modeling (mlm disabled).
data_collator = DataCollatorForLanguageModeling(tokenizer, mlm=False)

# Show the first example of each split and the split sizes.
print(f"train[0]: {train[0]}")
print(f"validate[0]: {validate[0]}")
print(f"len(train): {len(train)}, len(validate): {len(validate)}")

In [None]:
from transformers import TrainingArguments

# Hyperparameters based on 7B setups from QLoRA paper
lr = 2e-4
batch_size = 2  # per device; 16 was used on 7b models
ga_steps = 8  # gradient accumulation steps
num_epochs = 3

# Query bf16 support once: bf16 is used when available, fp16 otherwise.
_bf16_supported = torch.cuda.is_bf16_supported()

# define training arguments
training_args = TrainingArguments(
    output_dir=output_dir,
    # Schedule
    num_train_epochs=num_epochs,
    learning_rate=lr,
    warmup_ratio=0.1,
    weight_decay=0.01,  # PyTorch default is 0.01
    # Batching
    per_device_train_batch_size=batch_size,
    per_device_eval_batch_size=batch_size,
    gradient_accumulation_steps=ga_steps,
    # Memory
    gradient_checkpointing=True,  # Saves memory
    optim="paged_adamw_8bit",  # 8bit for less memory, paged for memory optimization between cpu and gpu
    # Precision (exactly one of the two flags is True)
    fp16=not _bf16_supported,
    bf16=_bf16_supported,
    # Evaluation, logging and checkpointing
    evaluation_strategy="epoch",  # no, steps, epoch
    save_strategy="epoch",
    logging_steps=1,  # Log loss for every step
)

In [None]:
# tokenized_data_train["example"]

In [None]:
from transformers import Trainer

# Build the Trainer from the LoRA model, both dataset splits, the collator
# and the training arguments defined in the earlier cells.
trainer = Trainer(
    args=training_args,
    model=model,
    data_collator=data_collator,
    train_dataset=train,  # Dataset type
    eval_dataset=validate,  # Dataset type
)

# use_cache is turned off for training to silence the warnings.
# Set `model.config.use_cache = True` again before using the model for inference.
model.config.use_cache = False

# Run the fine-tuning loop.
trainer.train()

# Write the trained model to output_dir.
model.save_pretrained(output_dir)

In [None]:
# Split the Trainer's log history into a training-loss series and an eval-loss
# series. Each value is stored together with the epoch it was logged at:
# training loss is logged every step while eval loss is logged once per epoch,
# so the two series have different lengths and cannot share a plain list index
# as their x-axis.
train_loss = []
train_epochs = []
eval_loss = []
eval_epochs = []
epochs = []  # every logged epoch value, regardless of entry type
for elem in trainer.state.log_history:
    # NaN keeps x and y the same length if an entry has no epoch; matplotlib skips NaN points.
    logged_epoch = elem.get('epoch', float('nan'))

    if 'loss' in elem:
        train_loss.append(elem['loss'])
        train_epochs.append(logged_epoch)

    if 'eval_loss' in elem:
        eval_loss.append(elem['eval_loss'])
        eval_epochs.append(logged_epoch)

    if 'epoch' in elem:
        epochs.append(elem['epoch'])

print(train_loss)

# Keep a tabular view of the full log history for inspection
# (previously this DataFrame was built and immediately discarded).
log_history_df = pd.DataFrame(trainer.state.log_history)
print(trainer.state.log_history)

import matplotlib.pyplot as plt

# Plot both curves against the epoch they were logged at so they line up.
fig, ax = plt.subplots()
ax.plot(eval_epochs, eval_loss, marker="o", label="Validation loss")
ax.plot(train_epochs, train_loss, label="Training loss")
ax.set_xlabel("Epoch")
ax.set_ylabel("Loss")
ax.legend()
plt.show()

In [None]:
# Report how many log entries carried an epoch value, then show the values
# themselves as the cell's output.
epoch_entry_count = len(epochs)
print(epoch_entry_count)
epochs