In [3]:
!pip install torch transformers accelerate peft trl datasets bitsandbytes --quiet

In [4]:
from huggingface_hub import login
login(new_session=False)

In [5]:
from google.colab import drive
drive.mount('/content/drive')

dataset_path = "/content/drive/MyDrive/Dataset/merged_medical_dataset_clean.jsonl"

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [6]:
import pandas as pd
df = pd.read_json("/content/drive/MyDrive/Dataset/medical_qa_combined_shuffled.jsonl", lines=True)

df = df.head(1000)

df['instruction'] = df['instruction'].astype(str)
df['input'] = df['input'].astype(str)
df['output'] = df['output'].astype(str)

df = df[df['output'].str.strip() != ""]

df.to_json("/content/drive/MyDrive/Dataset/merged_medical_dataset_clean.jsonl", orient='records', lines=True)

In [7]:
from datasets import load_dataset

dataset = load_dataset('json', data_files={'train': dataset_path, 'validation': dataset_path})

Generating train split: 0 examples [00:00, ? examples/s]

Generating validation split: 0 examples [00:00, ? examples/s]

In [8]:
from transformers import AutoTokenizer

model_id = "google/gemma-3-1b-it"
tokenizer = AutoTokenizer.from_pretrained(model_id)

def preprocess_batch(examples):
    prompts = [f"{inst}\n{inp}" for inst, inp in zip(examples['instruction'], examples['input'])]
    outputs = examples['output']

    tokenized_inputs = tokenizer(prompts, truncation=True, padding="max_length", max_length=512)
    tokenized_outputs = tokenizer(outputs, truncation=True, padding="max_length", max_length=512)

    tokenized_inputs["labels"] = tokenized_outputs["input_ids"]
    return tokenized_inputs

# batched map
dataset = dataset.map(
    preprocess_batch,
    batched=True,
    batch_size=512,  # adjust based on RAM
    remove_columns=dataset["train"].column_names
)


Map:   0%|          | 0/1000 [00:00<?, ? examples/s]

Map:   0%|          | 0/1000 [00:00<?, ? examples/s]

In [9]:
import torch
from transformers import AutoModelForCausalLM, BitsAndBytesConfig

bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_use_double_quant=True,
    bnb_4bit_quant_type='nf4',
    bnb_4bit_compute_dtype=torch.float16,
    bnb_4bit_quant_storage=torch.float16
)

model = AutoModelForCausalLM.from_pretrained(
    model_id,
    device_map="auto",
    quantization_config=bnb_config
)


In [38]:
from peft import PeftModel

if isinstance(model, PeftModel):
    model = model.base_model  # get the underlying model


In [39]:
from peft import LoraConfig, get_peft_model

peft_config = LoraConfig(
    r=16,
    lora_alpha=16,
    lora_dropout=0.05,
    bias="none",
    target_modules=["q_proj", "v_proj", "k_proj", "o_proj"],
    task_type="CAUSAL_LM"
)

model = get_peft_model(model, peft_config)


In [42]:
from trl import SFTConfig, SFTTrainer

sft_args = SFTConfig(
    output_dir="gemma-medical-sft",
    max_length=512,
    packing=True,
    num_train_epochs=1,
    per_device_train_batch_size=1,
    gradient_accumulation_steps=4,
    gradient_checkpointing=False,
    learning_rate=2e-4,
    fp16=True,
    bf16=False,
    warmup_ratio=0.03,
    lr_scheduler_type="constant",
    report_to="tensorboard"
)

trainer = SFTTrainer(
    model=model,
    args=sft_args,
    train_dataset=dataset["train"],
    eval_dataset=dataset["validation"],
    peft_config=peft_config,
    processing_class=tokenizer
)



In [43]:
trainer.train()

Step,Training Loss
10,7.8862
20,1.9126
30,0.862
40,0.6805
50,0.5601
60,0.7313
70,0.7649
80,0.7737
90,0.6644
100,0.6413


TrainOutput(global_step=250, training_loss=0.9952766666412354, metrics={'train_runtime': 474.575, 'train_samples_per_second': 2.107, 'train_steps_per_second': 0.527, 'total_flos': 2153097068544000.0, 'train_loss': 0.9952766666412354, 'epoch': 1.0})

In [48]:
trainer.save_model()
tokenizer.save_pretrained("gemma-medical-sft")

('gemma-medical-sft/tokenizer_config.json',
 'gemma-medical-sft/special_tokens_map.json',
 'gemma-medical-sft/chat_template.jinja',
 'gemma-medical-sft/tokenizer.model',
 'gemma-medical-sft/added_tokens.json',
 'gemma-medical-sft/tokenizer.json')

In [10]:
from peft import PeftModel

base_model = AutoModelForCausalLM.from_pretrained(model_id, low_cpu_mem_usage=True)
peft_model = PeftModel.from_pretrained(base_model, "gemma-medical-sft")
merged_model = peft_model.merge_and_unload()
merged_model.save_pretrained("merged_gemma_medical_model", safe_serialization=True)
tokenizer.save_pretrained("merged_gemma_medical_model")

('merged_gemma_medical_model/tokenizer_config.json',
 'merged_gemma_medical_model/special_tokens_map.json',
 'merged_gemma_medical_model/chat_template.jinja',
 'merged_gemma_medical_model/tokenizer.model',
 'merged_gemma_medical_model/added_tokens.json',
 'merged_gemma_medical_model/tokenizer.json')

In [11]:
from huggingface_hub import login

login()

VBox(children=(HTML(value='<center> <img\nsrc=https://huggingface.co/front/assets/huggingface_logo-noborder.sv…

In [12]:
from huggingface_hub import HfApi

model_repo_name = "gemma-medical-sft"
hub_username = "galang006"

merged_model.push_to_hub(f"{hub_username}/{model_repo_name}", use_auth_token=True)
tokenizer.push_to_hub(f"{hub_username}/{model_repo_name}", use_auth_token=True)




Processing Files (0 / 0)                : |          |  0.00B /  0.00B            

New Data Upload                         : |          |  0.00B /  0.00B            

  gemma-medical-sft/model.safetensors   :   0%|          | 16.7MB / 4.00GB            

README.md: 0.00B [00:00, ?B/s]

Processing Files (0 / 0)                : |          |  0.00B /  0.00B            

New Data Upload                         : |          |  0.00B /  0.00B            

  gemma-medical-sft/tokenizer.model     : 100%|##########| 4.69MB / 4.69MB            

  gemma-medical-sft/tokenizer.json      : 100%|##########| 33.4MB / 33.4MB            

CommitInfo(commit_url='https://huggingface.co/galang006/gemma-medical-sft/commit/c010614b7799ecd2354b0083247384ea91495207', commit_message='Upload tokenizer', commit_description='', oid='c010614b7799ecd2354b0083247384ea91495207', pr_url=None, repo_url=RepoUrl('https://huggingface.co/galang006/gemma-medical-sft', endpoint='https://huggingface.co', repo_type='model', repo_id='galang006/gemma-medical-sft'), pr_revision=None, pr_num=None)