In [1]:
!pip install -q -U torch --index-url https://download.pytorch.org/whl/cu117
!pip install -q -U -i https://pypi.org/simple/ bitsandbytes
!pip install -q -U transformers
!pip install -q -U accelerate
!pip install -q -U datasets
!pip install -q -U trl
!pip install -q -U peft

In [81]:
import os
import warnings
import transformers
import torch
from datasets import load_dataset
from trl import SFTTrainer
from peft import LoraConfig, get_peft_model
from transformers import AutoTokenizer, AutoModelForCausalLM
from transformers import BitsAndBytesConfig, TrainingArguments
import bitsandbytes as bnb

In [3]:
warnings.filterwarnings("ignore")

In [4]:
from kaggle_secrets import UserSecretsClient
user_secrets = UserSecretsClient()
secret_value_0 = user_secrets.get_secret("HF_TOKEN")
HF_TOKEN = secret_value_0

In [80]:
model_name = "google/gemma-2-2b-it"

compute_dtype = getattr(torch, "float16")

bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_use_double_quant=False,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=compute_dtype,
)

model = AutoModelForCausalLM.from_pretrained(
    model_name,
    device_map="auto",
    quantization_config=bnb_config,
    token=HF_TOKEN
)

model.config.use_cache = False
model.config.pretraining_tp = 1

max_seq_length = 1024
tokenizer = AutoTokenizer.from_pretrained(model_name, max_seq_length=max_seq_length, token=HF_TOKEN)

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

In [69]:
dataset = load_dataset("jgyasu/medqa", split="train")

In [70]:
subset = dataset.select(range(2000))

In [51]:
subset

Dataset({
    features: ['input', 'output'],
    num_rows: 2000
})

In [71]:
def generate_prompt(data_point):

    instruction = (
    "You are an expert medical assistant with knowledge in various fields of medicine, including diagnosis, "
    "treatment, and healthcare recommendations. Please respond to the following question with accurate, "
    "evidence-based information. Provide sources or explain relevant guidelines when possible, and clearly "
    "indicate if the information is based on clinical recommendations, medical studies, or general medical knowledge. "
    "If the question involves medical advice, provide options or note when a professional consultation would be necessary."
    )

    if data_point['input']:
        text = (
            f"<start_of_turn>user {instruction} Here is the question: \n"
            f"{data_point['input']}<end_of_turn>\n"
            f"<start_of_turn>model {data_point['output']}<end_of_turn>"
        )
    return text

In [72]:
text_column = [generate_prompt(data_point) for data_point in subset]
subset = subset.add_column("prompt", text_column)

In [73]:
subset = subset.train_test_split(test_size=0.2)

In [74]:
train_data = subset["train"]
test_data = subset["test"]

In [82]:
output_dir = "MediGemma"

peft_config = LoraConfig(
    lora_alpha=16,
    lora_dropout=0,
    r=64,
    bias="none",
    task_type="CAUSAL_LM",
    target_modules=["q_proj", "k_proj", "v_proj", "o_proj",
                    "gate_proj", "up_proj", "down_proj",],
)

training_arguments = TrainingArguments(
    output_dir=output_dir,
    num_train_epochs=1,
    gradient_checkpointing=True,
    per_device_train_batch_size=1,
    gradient_accumulation_steps=8,
    optim="paged_adamw_32bit",
    save_steps=0,
    logging_steps=25,
    learning_rate=5e-4,
    weight_decay=0.001,
    fp16=True,
    bf16=False,
    max_grad_norm=0.3,
    max_steps=-1,
    warmup_ratio=0.03,
    group_by_length=False,
    evaluation_strategy='steps',
    eval_steps = 500,
    eval_accumulation_steps=1,
    lr_scheduler_type="cosine",
    report_to="none",
)

In [57]:
trainer = SFTTrainer(
    model=model,
    train_dataset=train_data,
    eval_dataset=test_data,
    peft_config=peft_config,
    dataset_text_field="prompt",
    tokenizer=tokenizer,
    max_seq_length=max_seq_length,
    args=training_arguments,
    packing=False,
)

Map:   0%|          | 0/1600 [00:00<?, ? examples/s]

Map:   0%|          | 0/400 [00:00<?, ? examples/s]

In [16]:
trainer.train()

Step,Training Loss,Validation Loss


TrainOutput(global_step=200, training_loss=1.0529220771789551, metrics={'train_runtime': 2027.3311, 'train_samples_per_second': 0.789, 'train_steps_per_second': 0.099, 'total_flos': 6469581304447488.0, 'train_loss': 1.0529220771789551, 'epoch': 1.0})

In [17]:
trainer.save_model()
tokenizer.save_pretrained(output_dir)

('MediGemma/tokenizer_config.json',
 'MediGemma/special_tokens_map.json',
 'MediGemma/tokenizer.model',
 'MediGemma/added_tokens.json',
 'MediGemma/tokenizer.json')

In [62]:
import gc

del [model, tokenizer, peft_config, trainer, train_data, bnb_config, training_arguments]
del [TrainingArguments, SFTTrainer, LoraConfig, BitsAndBytesConfig]

for _ in range(10):
    torch.cuda.empty_cache()
    gc.collect()

In [19]:
from peft import AutoPeftModelForCausalLM

finetuned_model = output_dir
compute_dtype = getattr(torch, "float16")
tokenizer = AutoTokenizer.from_pretrained(model_name)

model = AutoPeftModelForCausalLM.from_pretrained(
     finetuned_model,
     torch_dtype=compute_dtype,
     return_dict=False,
     low_cpu_mem_usage=True,
     device_map="auto",
)

merged_model = model.merge_and_unload()
merged_model.save_pretrained("./MediGemma",
                             safe_serialization=True, 
                             max_shard_size="2GB")
tokenizer.save_pretrained("./MediGemma")

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

('./MediGemma/tokenizer_config.json',
 './MediGemma/special_tokens_map.json',
 './MediGemma/tokenizer.model',
 './MediGemma/added_tokens.json',
 './MediGemma/tokenizer.json')

In [20]:
import gc

del [model, tokenizer, merged_model, AutoPeftModelForCausalLM]

for _ in range(10):
    torch.cuda.empty_cache()
    gc.collect()

In [21]:
for _ in range(10):
    torch.cuda.empty_cache()
    gc.collect()

In [66]:
from transformers import (AutoModelForCausalLM, 
                          AutoTokenizer, 
                          BitsAndBytesConfig)

model_name = "./MediGemma"

compute_dtype = getattr(torch, "float16")

bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_use_double_quant=False,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=compute_dtype,
)

model = AutoModelForCausalLM.from_pretrained(
    model_name,
    device_map="auto",
    quantization_config=bnb_config, 
)

model.config.use_cache = False
model.config.pretraining_tp = 1

max_seq_length = 1024
tokenizer = AutoTokenizer.from_pretrained(model_name, max_seq_length=max_seq_length)

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

In [35]:
def get_completion(query: str, model, tokenizer) -> str:
    device = "cuda:0"
    prompt_template = """
    <start_of_turn>user You are an expert medical assistant with knowledge in various fields of medicine, including diagnosis, 
    treatment, and healthcare recommendations. Please respond to the following question with accurate, 
    evidence-based information. Provide sources or explain relevant guidelines when possible, and clearly 
    indicate if the information is based on clinical recommendations, medical studies, or general medical knowledge.
    If the question involves medical advice, provide options or note when a professional consultation would be necessary.
    Here is the question:
    {query}
    <end_of_turn>\\n<start_of_turn>model
    """
    
    prompt = prompt_template.format(query=query)
    encodeds = tokenizer(prompt, return_tensors="pt", add_special_tokens=True)
    model_inputs = encodeds.to(device)
    
    generated_ids = model.generate(
        **model_inputs,
        max_new_tokens=1000,
        do_sample=True,
        pad_token_id=tokenizer.eos_token_id
    )
    
    decoded = tokenizer.decode(generated_ids[0], skip_special_tokens=True)

    model_response_start = "<start_of_turn>model"
    if model_response_start in decoded:

        response = decoded.split(model_response_start)[-1].strip()
        return response
    else:
        return decoded  

result = get_completion(query="Is Hepatitis B a dangerous disease?", model=model, tokenizer=tokenizer)
result_list = result.split("model")
print(result_list[1])



    HBV can cause disease and spread when a person has not been immunized or has had no previous hepatitis B infection.
    
    Hepatitis B has the potential to progress, even in a person who has not had cirrhosis, resulting in death if untreated.
    
    People with hepatitis B can spread the disease to:
    
    - their partners
    - their unborn child
    - their child, if given breast milk
    - other people, through blood transfusions, needles, or medical equipment and supplies, if infection develops
    
    Hepatitis B can become a chronic (long lasting) and dangerous form of disease, called chronic hepatitis B. This occurs when the chronic inflammation or inflammation in the liver due to the virus is severe or lasts for years. People with chronic hepatitis might eventually develop cirrhosis.
    
    A person with chronic hepatitis B will still remain infected with both hepatitis B surface antigen (HBsAg) and hepatitis B core antigen (HBcAg). It is possible they will also t

In [87]:
model.push_to_hub("MediGemma")

model.safetensors:   0%|          | 0.00/2.32G [00:00<?, ?B/s]

CommitInfo(commit_url='https://huggingface.co/jgyasu/MediGemma/commit/5ec9e8414af94477045330a42046714e1c7ff256', commit_message='Upload Gemma2ForCausalLM', commit_description='', oid='5ec9e8414af94477045330a42046714e1c7ff256', pr_url=None, repo_url=RepoUrl('https://huggingface.co/jgyasu/MediGemma', endpoint='https://huggingface.co', repo_type='model', repo_id='jgyasu/MediGemma'), pr_revision=None, pr_num=None)

In [88]:
tokenizer.push_to_hub("MediGemma")

README.md:   0%|          | 0.00/5.17k [00:00<?, ?B/s]

Upload 2 LFS files:   0%|          | 0/2 [00:00<?, ?it/s]

tokenizer.model:   0%|          | 0.00/4.24M [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/34.4M [00:00<?, ?B/s]

CommitInfo(commit_url='https://huggingface.co/jgyasu/MediGemma/commit/c78b303d18dd898a7140a876f70a3e34b0979abe', commit_message='Upload tokenizer', commit_description='', oid='c78b303d18dd898a7140a876f70a3e34b0979abe', pr_url=None, repo_url=RepoUrl('https://huggingface.co/jgyasu/MediGemma', endpoint='https://huggingface.co', repo_type='model', repo_id='jgyasu/MediGemma'), pr_revision=None, pr_num=None)