In [None]:
import os
import re
import torch
from pathlib import Path
from transformers import LlamaTokenizer, LlamaForCausalLM, BitsAndBytesConfig
import transformers

from datasets import load_dataset

In [None]:
# Directory layout, all relative to the notebook's working directory.
BASE_DIR = Path(".")
MODEL_DIR = BASE_DIR / "MODELS" / "HF"
DATA_DIR = BASE_DIR / "data"
DATASET_DIR = DATA_DIR / "model_input" / "dataset"

# Base checkpoint to fine-tune, and where the adapter is written: the last two
# path components of the checkpoint are joined with "_" and upper-cased to
# name the output folder next to MODEL_DIR.
LLAMA_MODEL = MODEL_DIR / "Llama" / "7B"
adapter_dir_name = "_".join(LLAMA_MODEL.parts[-2:]).upper()
OUT_DIR = MODEL_DIR.parent / "qLORA" / adapter_dir_name

In [None]:
# from_pretrained is given the checkpoint location as a plain string.
model_name = str(LLAMA_MODEL)

# bitsandbytes settings: load weights in 4-bit NF4 with double quantization,
# and run computations in float16.
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_use_double_quant=True,
    bnb_4bit_compute_dtype=torch.float16,
)

In [None]:
# Tokenizer and base model are both read from the same checkpoint path.
tokenizer = LlamaTokenizer.from_pretrained(model_name)

# The model is loaded with the quantization settings from bnb_config;
# device placement is left to the library (device_map="auto").
model = LlamaForCausalLM.from_pretrained(
    model_name,
    device_map="auto",
    quantization_config=bnb_config,
)

In [None]:
from peft import prepare_model_for_kbit_training

# Turn on gradient checkpointing on the loaded model, then pass it through
# PEFT's k-bit training preparation helper. `model` is rebound to whatever
# the helper returns, so later cells operate on the prepared model.
model.gradient_checkpointing_enable()
model = prepare_model_for_kbit_training(model)

In [None]:
from peft import LoraConfig, get_peft_model

# Names of the modules that receive LoRA adapters (the q/k/v/o projections).
lora_target_modules = ["q_proj", "k_proj", "v_proj", "o_proj"]

# LoRA hyperparameters for a causal language-modelling task.
config = LoraConfig(
    task_type="CAUSAL_LM",
    r=8,
    lora_alpha=16,
    lora_dropout=0.05,
    bias="none",
    target_modules=lora_target_modules,
)

# Wrap the prepared model with the adapters and report how many parameters
# will actually be trained.
model = get_peft_model(model, config)
model.print_trainable_parameters()

In [None]:
# Candidate dataset files: the "v4" JSON files directly inside DATASET_DIR.
# The substring checks look at the file name only, so a parent directory whose
# name happens to contain "v4", "train" or "test" cannot cause a false match.
# iterdir() yields entries in arbitrary order, so the list is sorted to make
# the file selection below reproducible across runs and machines.
data_files = sorted(
    path
    for path in DATASET_DIR.iterdir()
    if path.suffix == ".json" and "v4" in path.name
)


def _first_split_file(split):
    """Return the first file in ``data_files`` whose name contains ``split``.

    Args:
        split: Substring identifying the split in the file name
            (e.g. ``"train"`` or ``"test"``).

    Returns:
        The matching ``Path`` (first one in sorted order).

    Raises:
        FileNotFoundError: If no candidate file name contains ``split``.
    """
    for path in data_files:
        if split in path.name:
            return path
    raise FileNotFoundError(
        f"No v4 '{split}' .json file found in {DATASET_DIR}"
    )


train_file = _first_split_file("train")
# The held-out file is named "test" on disk but exposed as the "eval" split.
eval_file = _first_split_file("test")

dataset = load_dataset(
    "json",
    data_files={
        "train": str(train_file),
        "eval": str(eval_file),
    },
)

In [None]:
def annotstrflist(example):
    """Render an example's annotations as the three-line target response.

    The result has the form::

        terms: [t1 | t2 | ...]
        GO concepts: [c1 | c2 | ...]
        parents: [p1 | p2 | ...]

    where each ``p`` is ``none`` (no parents), the single parent concept, or
    ``[a | b | ...]`` when an annotation has several parents. When the example
    has no annotations, every field is the literal ``none``.

    Args:
        example: Mapping with an ``"annot"`` entry holding a list of
            annotation dicts. Each annotation provides ``"spanned_text"``,
            ``"go_concept"`` and ``"parents"`` (a list of dicts with a
            ``"GO Concept"`` key). ``"annot"`` and ``"parents"`` may be
            ``None`` or empty; both are treated as "nothing there".

    Returns:
        The formatted response string.
    """
    annotations = example["annot"]
    # Covers both an empty list and a null field from the JSON source.
    if not annotations:
        return "terms: none\nGO concepts: none\nparents: none"

    terms, concepts, parents = [], [], []
    for annotation in annotations:
        terms.append(annotation["spanned_text"])
        concepts.append(annotation["go_concept"])

        # A null "parents" field is handled like an empty parent list.
        parent_concepts = [
            parent["GO Concept"] for parent in (annotation["parents"] or [])
        ]
        if not parent_concepts:
            parents.append("none")
        elif len(parent_concepts) == 1:
            parents.append(parent_concepts[0])
        else:
            # Several parents are grouped in their own bracketed list so the
            # outer list keeps one entry per term.
            parents.append("[" + " | ".join(parent_concepts) + "]")

    joined_terms = " | ".join(terms)
    joined_concepts = " | ".join(concepts)
    joined_parents = " | ".join(parents)
    return (
        f"terms: [{joined_terms}]\n"
        f"GO concepts: [{joined_concepts}]\n"
        f"parents: [{joined_parents}]"
    )

In [None]:
def generate_prompt(example):
    """Build the full training text for one example and tokenize it.

    The text is a fixed preamble, an ``### Instruction:`` section, an
    ``### Input:`` section holding ``example["pre"]``, and a ``### Response:``
    section holding the output of ``annotstrflist(example)``.

    Args:
        example: Mapping with a ``"pre"`` entry (the input sentence) plus
            whatever ``annotstrflist`` reads from it.

    Returns:
        The result of calling the module-level ``tokenizer`` on the text.
    """
    def collapse_whitespace(text):
        # The literals below are wrapped for readability; fold every run of
        # whitespace (including the newlines and indentation) into one space.
        return re.sub(r"\s+", " ", text)

    preamble = collapse_whitespace(
        """Gene Ontology (GO) is a widely used bioinformatics resource that provides a structured
    vocabulary for annotating and categorizing genes and gene products based on their biological functions,
    cellular locations, and molecular activities. You are a gene ontology expert and your objective is to use
    your knowledge of the biological domain and the details provided below to write a response that appropriately
    completes the instruction."""
    )
    task = collapse_whitespace(
        """Use the input sentence below to label the tokens: terms, GO concepts and parents.
    A term is a word or a phrase (phrase is a sequence of words) that represents a GO concept. Each term
    MUST be present in the provided input sentence. A GO concept refers to a specific term or category with
    GO hierarchy. Each GO concept can have zero or more parents. A parent represents immediate predecessor
    of a GO concept. The response SHOULD have equal number of terms, GO concepts and parents."""
    )

    sentence = example["pre"]
    prompt = f"{preamble}\n\n### Instruction:\n{task}\n\n### Input:\n{sentence}\n\n### Response:\n"

    # Prompt and gold response are tokenized together as one sequence.
    # NOTE(review): nothing here appends an end-of-sequence token; whether the
    # tokenizer adds one itself depends on its configuration - confirm.
    return tokenizer(prompt + annotstrflist(example))

In [None]:
# Use as many worker processes as the OS reports CPUs.
n_workers = os.cpu_count()


def _is_short_enough(row):
    """Keep rows whose tokenized sequence has fewer than 400 token ids."""
    return len(row["input_ids"]) < 400


# Tokenize every split, then drop the sequences that are too long.
tokenized_dataset = dataset.map(generate_prompt, num_proc=n_workers)
new_dataset = tokenized_dataset.filter(_is_short_enough, num_proc=n_workers)
new_dataset

In [None]:
# Hyperparameters for the fine-tuning run.
training_args = transformers.TrainingArguments(
    # Passed as a string, like the other paths handed to the Hugging Face
    # APIs in this notebook (output_dir is documented as a str).
    output_dir=str(OUT_DIR),
    # Batch size of 16 per device with no gradient accumulation.
    per_device_train_batch_size=16,
    gradient_accumulation_steps=1,
    # Schedule and optimizer settings.
    warmup_steps=5,
    num_train_epochs=2,
    learning_rate=1e-5,
    weight_decay=0.01,
    optim="paged_adamw_8bit",
    # Mixed-precision training in float16.
    fp16=True,
    # Log every 10 steps; no external reporting integration.
    logging_steps=10,
    report_to="none",
    do_train=True,
    # Checkpoint every 200 steps, keeping at most 3 checkpoints.
    save_steps=200,
    save_total_limit=3,
)

In [None]:
# Reuse the end-of-sequence token as the padding token (presumably because the
# checkpoint's tokenizer ships without a pad token - confirm).
tokenizer.pad_token = tokenizer.eos_token

# Collator configured for causal language modelling (mlm=False), returning
# PyTorch tensors.
# NOTE(review): this collator is understood to mask pad-token positions in the
# labels; with pad == eos any EOS token in the data would be masked as well -
# confirm that this is intended.
causal_lm_collator = transformers.DataCollatorForLanguageModeling(
    tokenizer, mlm=False, return_tensors="pt"
)

# Only the "train" split is handed to the trainer.
trainer = transformers.Trainer(
    model=model,
    args=training_args,
    train_dataset=new_dataset["train"],
    data_collator=causal_lm_collator,
)

In [None]:
# Switch off the use_cache flag on the model config before training starts
# (presumably because caching conflicts with the gradient checkpointing
# enabled earlier - confirm).
model.config.use_cache = False
# Run the fine-tuning loop; as the last expression, its return value is shown
# as the cell output.
trainer.train()

In [None]:
# Write the trained model held by the trainer to the adapter output directory.
trainer.save_model(output_dir=OUT_DIR)