<a href="https://colab.research.google.com/github/boheling/healthAI/blob/main/SFT.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# Install necessary libraries
!pip install transformers datasets trl --quiet

# Import required libraries
import math

import torch
from datasets import load_dataset
from transformers import AutoModelForCausalLM, AutoTokenizer, TrainingArguments
from trl import SFTTrainer

# Load a small pretrained causal language model and its tokenizer.
model_name = "distilgpt2"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name)

# Load the first 40% of the Wikitext-2 (raw) training split to keep
# experiments quick; change the slice to trade speed for more data.
dataset = load_dataset("wikitext", "wikitext-2-raw-v1", split="train[:40%]")

# Preprocessing: Tokenize the text with truncation to a max length (e.g., 128 tokens)
def tokenize_function(example):
    """Tokenize the ``text`` field of *example*, truncating to 128 tokens.

    Intended for ``Dataset.map(..., batched=True)``: ``example["text"]`` is then
    a list of strings, and the returned tokenizer encoding (which includes
    ``input_ids``) holds one entry per string.
    """
    texts = example["text"]
    return tokenizer(texts, max_length=128, truncation=True)

# Tokenize the training rows in batches, dropping the raw text column so only
# the tokenizer outputs remain.
tokenized_dataset = dataset.map(
    tokenize_function,
    batched=True,
    remove_columns=["text"],
)

# Discard rows whose tokenization came back empty (e.g. empty text lines).
tokenized_dataset = tokenized_dataset.filter(
    lambda row: len(row["input_ids"]) != 0
)

# Training arguments sized for a single Colab T4 GPU (16 GB).
# fp16 mixed precision is enabled only when CUDA is available: TrainingArguments
# raises an error for fp16=True on a CPU-only runtime.
training_args = TrainingArguments(
    output_dir="./sft_output",
    per_device_train_batch_size=4,   # modest batch size; adjust if necessary
    num_train_epochs=3,              # increase for more training
    logging_steps=10,
    # Every evaluation pass runs over the whole eval set, and every checkpoint
    # stores model + optimizer state, so do both sparingly and keep only the
    # most recent checkpoints on disk instead of accumulating one per save.
    save_steps=500,
    save_total_limit=2,
    eval_strategy="steps",           # formerly `evaluation_strategy` (removed in recent transformers)
    eval_steps=500,
    fp16=torch.cuda.is_available(),  # mixed precision for faster training on a T4
    dataloader_num_workers=2,        # adjust depending on your CPU cores
)

# Held-out evaluation data: the Wikitext-2 validation split, preprocessed the
# same way as the training rows. Evaluating on the training data itself would
# make the before/after comparison reward memorisation rather than learning.
eval_dataset = load_dataset("wikitext", "wikitext-2-raw-v1", split="validation")
eval_dataset = eval_dataset.map(tokenize_function, batched=True, remove_columns=["text"])
eval_dataset = eval_dataset.filter(lambda x: len(x["input_ids"]) > 0)

# Initialize the SFTTrainer: fine-tune on the training slice, evaluate on the
# validation split.
trainer = SFTTrainer(
    model=model,
    args=training_args,
    train_dataset=tokenized_dataset,
    eval_dataset=eval_dataset,
)

def _sample_answer(prompt_text):
    """Return one sampled completion of *prompt_text* from the current ``model``.

    Sampling uses temperature 0.7 and top_p 0.9, so repeated calls give
    different text. ``max_length=100`` counts the prompt tokens as well as the
    generated ones.
    """
    encoded = tokenizer(prompt_text, return_tensors="pt").to(model.device)
    generated = model.generate(
        **encoded,
        max_length=100,
        do_sample=True,
        temperature=0.7,
        top_p=0.9,
        # Pass EOS as the pad id explicitly; generate() otherwise falls back to
        # it anyway, but emits a warning on every call.
        pad_token_id=tokenizer.eos_token_id,
    )
    return tokenizer.decode(generated[0], skip_special_tokens=True)


def _report_metrics(label, metrics):
    """Print the metrics dict from ``trainer.evaluate()`` and, when present, its perplexity."""
    print(f"{label} evaluation metrics:", metrics)
    eval_loss = metrics.get("eval_loss")
    if eval_loss is not None:
        # Perplexity is the exponential of the evaluation cross-entropy loss.
        print(f"{label} perplexity: {math.exp(eval_loss):.2f}")


prompt = "Q: Who is the current president of the United States?\nA:"

# Baseline: evaluate and sample from the pre-trained model before fine-tuning.
print("Evaluating pre-trained model...")
pretrain_metrics = trainer.evaluate()
_report_metrics("Pre-training", pretrain_metrics)
print(_sample_answer(prompt))

# Start fine-tuning the model
print("Starting fine-tuning...")
trainer.train()

# Evaluate and sample again with the same prompt to compare against the baseline.
print("Evaluating fine-tuned model...")
posttrain_metrics = trainer.evaluate()
_report_metrics("Post-training", posttrain_metrics)
print(_sample_answer(prompt))


[?25l   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/485.4 kB[0m [31m?[0m eta [36m-:--:--[0m[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m485.4/485.4 kB[0m [31m15.1 MB/s[0m eta [36m0:00:00[0m
[?25h[?25l   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/318.9 kB[0m [31m?[0m eta [36m-:--:--[0m[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m318.9/318.9 kB[0m [31m32.6 MB/s[0m eta [36m0:00:00[0m
[?25h[?25l   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/116.3 kB[0m [31m?[0m eta [36m-:--:--[0m[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m116.3/116.3 kB[0m [31m14.1 MB/s[0m eta [36m0:00:00[0m
[?25h[?25l   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/143.5 kB[0m [31m?[0m eta [36m-:--:--[0m[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m143.5/143.5 kB[0m [31m17.0 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━

In [None]:
# Sample one completion from the current model for an instruction-style prompt.
prompt = "You are very knowledgable and try to answer the questions. Question: Who is the current president of the United States?\nAnswer:"
sampling_kwargs = dict(max_length=100, do_sample=True, temperature=0.7, top_p=0.9)
inputs = tokenizer(prompt, return_tensors="pt")
inputs = inputs.to(model.device)
outputs = model.generate(**inputs, **sampling_kwargs)
completion = tokenizer.decode(outputs[0], skip_special_tokens=True)
print(completion)