In [1]:
%%capture
%pip install -U transformers 
%pip install -U datasets 
%pip install -U accelerate 
%pip install -U peft 
%pip install -U trl 
%pip install -U bitsandbytes 
%pip install -U wandb

In [2]:
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    BitsAndBytesConfig,
    HfArgumentParser,
    pipeline,
    logging,
)
from transformers import TrainingArguments
from peft import (
    LoraConfig,
    PeftModel,
    prepare_model_for_kbit_training,
    get_peft_model,
)
import os, torch, wandb
from datasets import load_dataset
from trl import SFTTrainer, setup_chat_format
from trl import SFTConfig

In [5]:
from huggingface_hub import login
from kaggle_secrets import UserSecretsClient
user_secrets = UserSecretsClient()

hf_token = user_secrets.get_secret("HUGGINGFACE_TOKEN")
login(token = hf_token)

In [6]:
wb_token = user_secrets.get_secret("wandb")

wandb.login(key=wb_token)
run = wandb.init(
    project='Fine-tune Llama 3 8B on Medical Dataset', 
    job_type="training", 
    anonymous="allow"
)

[34m[1mwandb[0m: No netrc file found, creating one.
[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: /root/.netrc
[34m[1mwandb[0m: Currently logged in as: [33mdbarde75[0m ([33mdbarde75-betasys-ai[0m) to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin


In [7]:
#base_model = "/kaggle/input/llama-3/transformers/8b-chat-hf/1"
#base_model =  "kingabzpro/llama-3-8b-chat-doctor"#"mathewhe/Llama-3.1-8B-Chat"# "baffo32/decapoda-research-llama-7B-hf"#"decapoda-research/llama-7b-hf"
base_model =  "gemma/transformers/1.1-2b-it/1"
dataset_name = "ruslanmv/ai-medical-chatbot"
new_model = "gemma-2b-chat-doctor"

In [8]:
# Set torch dtype and attention implementation
if torch.cuda.get_device_capability()[0] >= 8:
    !pip install -qqq flash-attn
    torch_dtype = torch.bfloat16
    attn_implementation = "flash_attention_2"
else:
    torch_dtype = torch.float16
    attn_implementation = "eager"

In [9]:
#base_model =  "kingabzpro/llama-3-8b-chat-doctor"
base_model =  "/kaggle/input/gemma/transformers/1.1-2b-it/1"

In [10]:
# QLoRA config
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch_dtype,
    bnb_4bit_use_double_quant=True,
)

# Load model
model = AutoModelForCausalLM.from_pretrained(
    base_model,
    quantization_config=bnb_config,
    device_map="auto",
    attn_implementation=attn_implementation
)

# Load tokenizer
tokenizer = AutoTokenizer.from_pretrained(base_model, trust_remote_code=True)

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

In [11]:
# LoRA config
peft_config = LoraConfig(
    r=16,
    lora_alpha=32,
    lora_dropout=0.05,
    bias="none",
    task_type="CAUSAL_LM",
    target_modules=['up_proj', 'down_proj', 'gate_proj', 'k_proj', 'q_proj', 'v_proj', 'o_proj']
)
#model, tokenizer = setup_chat_format(model, tokenizer)
model = get_peft_model(model, peft_config)

In [19]:
#Importing the dataset
dataset = load_dataset(dataset_name, split="all")

dataset = dataset.shuffle(seed=65).select(range(50000)) # Only use 1000 samples for quick demo

def format_chat_template(row):
    row_json = [{"role": "user", "content": row["Patient"]},
               {"role": "assistant", "content": row["Doctor"]}]
    row["text"] = tokenizer.apply_chat_template(row_json, tokenize=False)
    return row

dataset = dataset.map(
    format_chat_template,
    num_proc= 4,
)

type(dataset)


Map (num_proc=4):   0%|          | 0/50000 [00:00<?, ? examples/s]

datasets.arrow_dataset.Dataset

In [18]:
dataset[55]['Patient']

'Hello doctor, My fiancee and I had unprotected sex a few days back, but I did not ejaculate inside her. Just to be on the safer side, we wanted to use the emergency contraceptive pill. But due to some restriction in the country where we live, Plan B or emergency contraceptive pills are not available. I read that Yasmin, which is used as a regular contraceptive pill can be used as an emergency contraceptive pill at a higher dosage. Can Yasmin be used as an emergency contraceptive pill? And at what dosage?'

In [25]:
dataset['text'][3]

'<bos><start_of_turn>user\nFell on sidewalk face first about 8 hrs ago. Swollen, cut lip bruised and cut knee, and hurt pride initially. Now have muscle and shoulder pain, stiff jaw(think this is from the really swollen lip),pain in wrist, and headache. I assume this is all normal but are there specific things I should look for or will I just be in pain for a while given the hard fall?<end_of_turn>\n<start_of_turn>model\nHello and welcome to HCM,The injuries caused on various body parts have to be managed.The cut and swollen lip has to be managed by sterile dressing.The body pains, pain on injured site and jaw pain should be managed by pain killer and muscle relaxant.I suggest you to consult your primary healthcare provider for clinical assessment.In case there is evidence of infection in any of the injured sites, a course of antibiotics may have to be started to control the infection.Thanks and take careDr Shailja P Wahal<end_of_turn>\n'

In [20]:
dataset = dataset.train_test_split(test_size=0.1)

In [21]:
#Hyperparamter
training_arguments = TrainingArguments(
    output_dir=new_model,
    per_device_train_batch_size=1,
    per_device_eval_batch_size=1,
    gradient_accumulation_steps=2,
    optim="paged_adamw_32bit",
    num_train_epochs=1,
    #evaluation_strategy="steps",
    eval_steps=0.2,
    logging_steps=1,
    warmup_steps=10,
    logging_strategy="steps",
    learning_rate=2e-4,
    fp16=False,
    bf16=False,
    group_by_length=True,
    report_to="wandb"
)


In [None]:
import transformers
print(transformers.__version__)
help(TrainingArguments)


In [22]:
#Hyperparamter
training_arguments = TrainingArguments(
    output_dir=new_model,
    per_device_train_batch_size=1,
    per_device_eval_batch_size=1,
    gradient_accumulation_steps=2,
    optim= "adamw_8bit",#"paged_adamw_32bit",
    num_train_epochs=1,
    eval_strategy ="steps",
    eval_steps=100,  # Changed from 0.2 to 100
    logging_steps=1,
    warmup_steps=10,
    logging_strategy="steps",
    learning_rate=2e-4,
    fp16=True,      # Changed from False to True
    bf16=False,
    group_by_length=True,
    report_to="wandb",
    save_total_limit=2,
    save_steps=500,
    load_best_model_at_end=True,
    metric_for_best_model="eval_loss",
    greater_is_better=True
)

In [23]:
# Setting sft parameters
trainer = SFTTrainer(
    model=model,
    train_dataset=dataset["train"],
    eval_dataset=dataset["test"],
    peft_config=peft_config,
    args=training_arguments
)

Adding EOS to train dataset:   0%|          | 0/45000 [00:00<?, ? examples/s]

Tokenizing train dataset:   0%|          | 0/45000 [00:00<?, ? examples/s]

Truncating train dataset:   0%|          | 0/45000 [00:00<?, ? examples/s]

Adding EOS to eval dataset:   0%|          | 0/5000 [00:00<?, ? examples/s]

Tokenizing eval dataset:   0%|          | 0/5000 [00:00<?, ? examples/s]

Truncating eval dataset:   0%|          | 0/5000 [00:00<?, ? examples/s]

No label_names provided for model class `PeftModelForCausalLM`. Since `PeftModel` hides base models input arguments, if label_names is not given, label_names can't be set automatically within `Trainer`. Note that empty label_names list will be used instead.


In [24]:
# Set environment variable to help with memory fragmentation
import os
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"

In [None]:

# Train the model
trainer.train()

In [25]:
trainer.train()

Step,Training Loss,Validation Loss
100,3.3316,2.873424
200,2.932,2.811611
300,2.421,2.804138
400,2.0634,2.750689
500,2.175,2.737368
600,3.5114,2.713837
700,1.9637,2.722624
800,2.9266,2.692765
900,2.165,2.712364
1000,2.3908,2.669214




KeyboardInterrupt: 

In [23]:
# Save the fine-tuned model
trainer.model.save_pretrained(new_model)
wandb.finish()
model.config.use_cache = True

0,1
eval/loss,▇█▃▂▁
eval/mean_token_accuracy,▁▁▅▇█
eval/num_tokens,▁▁▃▆█
eval/runtime,▅▃▁█▇
eval/samples_per_second,▄▆█▁▃
eval/steps_per_second,▄▆█▁▃
train/epoch,▁▁▁▁▂▂▂▂▂▂▂▂▃▁▁▁▁▁▁▂▂▂▃▄▄▄▅▆▆▆▇▇▇▇▇▇▇▇██
train/global_step,▁▁▂▂▂▂▂▂▁▁▂▂▂▃▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▆▆▆▆▆▇▇▇██
train/grad_norm,▅█▃▂▃▂▂▂▃▃▆▃▃▃▃▄▄▂▂▂▂▂▃▃▂▂▃▃▁▁▂▂▂▂▂▂▃▂▂▂
train/learning_rate,▄█▇▇▇█████▇▇▇▇▇▇▆▆▆▆▅▅▅▅▅▅▄▄▄▄▃▃▂▂▂▂▁▁▁▁

0,1
eval/loss,2.74201
eval/mean_token_accuracy,0.44427
eval/num_tokens,184606.0
eval/runtime,26.0975
eval/samples_per_second,3.832
eval/steps_per_second,3.832
total_flos,2503475279843328.0
train/epoch,1.0
train/global_step,450.0
train/grad_norm,5.02494


In [None]:
# Save the fine-tuned model
trainer.model.save_pretrained(new_model)
trainer.model.push_to_hub(new_model, use_temp_dir=False)

In [21]:
messages = [{"role": "user", "content": "tajmahal"}]

prompt = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
    
inputs = tokenizer(prompt, return_tensors='pt', padding=True, truncation=True).to("cuda")

outputs = model.generate(**inputs, max_length=150, num_return_sequences=1)

text = tokenizer.decode(outputs[0], skip_special_tokens=True)

print(text.split("assistant")[0])

user
tajmahal
model
Hi, I have gone through your query and understand your concern. Taj Mahal is a historical monument and a symbol of love. It is a UNESCO World Heritage Site. It is a symbol of love and devotion. It is a symbol of the rich history of India. It is a symbol of the love between the king and his queen. It is a symbol of the beauty of India. It is a symbol of the power of love. It is a symbol of the unity of India. It is a symbol of the unity of the world. It is a symbol of the peace of the world. It is a symbol of the hope of the world. It is a symbol of the future of


In [None]:
messages = [{"role": "user", "content": "Hello doctor, I always feel weak, can you help me with that?"}]

prompt = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
    
inputs = tokenizer(prompt, return_tensors='pt', padding=True, truncation=True).to("cuda")

outputs = model.generate(**inputs, max_length=150, num_return_sequences=1)

text = tokenizer.decode(outputs[0], skip_special_tokens=True)

print(text.split("assistant")[1])