In [None]:
# Install required libraries
!pip install torch transformers datasets peft bitsandbytes accelerate

In [None]:
import torch

In [None]:



from peft import LoraConfig, get_peft_model
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig

# Announce the start of the fine-tuning phase.
print("=" * 60)
print("PHASE 2: FINE-TUNING WITH LORA")
print("=" * 60)

# Check GPU
# The per-device queries raise when no CUDA device is present, so they are
# only issued after availability has been confirmed.
gpu_available = torch.cuda.is_available()
print(f"\nGPU available: {gpu_available}")
if gpu_available:
    gpu_properties = torch.cuda.get_device_properties(0)
    print(f"GPU name: {torch.cuda.get_device_name(0)}")
    print(f"GPU memory: {gpu_properties.total_memory / 1e9:.2f} GB")
else:
    print("No CUDA GPU detected; skipping GPU name and memory report.")

# Model selection
model_name = "meta-llama/Llama-3.2-1B"  # Start with 1B (smaller, faster)
# Later: "meta-llama/Meta-Llama-3.1-8B" or "mistralai/Mistral-7B-v0.1"

print(f"\nLoading model: {model_name}")
print("This may take a few minutes...")

# Load tokenizer
tokenizer = AutoTokenizer.from_pretrained(model_name)
# Reuse the end-of-sequence token as the padding token.
tokenizer.pad_token = tokenizer.eos_token

# Load model in 4-bit (QLoRA)
# The quantization settings go through an explicit BitsAndBytesConfig; the
# bare `load_in_4bit=True` keyword on from_pretrained is deprecated.
quantization_config = BitsAndBytesConfig(load_in_4bit=True)
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    device_map="auto",
    torch_dtype=torch.float16,
    quantization_config=quantization_config,
)

print("\n✓ Model loaded successfully!")
print(f"Model size: {sum(p.numel() for p in model.parameters()) / 1e9:.2f}B parameters")

# Banner introducing the check of the base model before any fine-tuning
section_rule = "=" * 60
print(f"\n{section_rule}")
print("TESTING BASE MODEL (Before Fine-tuning)")
print(section_rule)

def generate_text(prompt, max_length=100):
    """Sample one completion for ``prompt`` from the module-level ``model``.

    Uses the module-level ``tokenizer`` to encode the prompt and decode the
    result, and moves the encoded inputs to ``model.device`` before
    generating.

    Args:
        prompt: Text to tokenize and pass to the model.
        max_length: Forwarded to ``model.generate`` as ``max_length``. Per the
            Transformers generation docs this bounds prompt tokens plus newly
            generated tokens, not the new tokens alone.

    Returns:
        The first returned sequence decoded to a string, with special tokens
        skipped.
    """
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
    outputs = model.generate(
        **inputs,
        max_length=max_length,
        do_sample=True,
        temperature=0.7,
        top_p=0.9,
        # Passed explicitly so generate() does not have to fall back to the
        # EOS id (and emit a warning) on every call.
        pad_token_id=tokenizer.pad_token_id,
    )
    return tokenizer.decode(outputs[0], skip_special_tokens=True)

# Prompts used to probe the base model before fine-tuning
test_prompts = [
    "What are the symptoms of diabetes?",
    "Explain how the heart works:",
    "A patient has chest pain. What should I do?",
]

# Show each prompt, then its sampled completion, followed by a divider
for prompt in test_prompts:
    print(f"\nPrompt: {prompt}")
    response = generate_text(prompt, max_length=150)
    print(f"Response: {response}", "-" * 60, sep="\n")

# Closing banner and fixed observations; these lines are printed as-is and do
# not depend on the completions generated above
summary_lines = (
    "\n" + "=" * 60,
    "BASE MODEL TESTED",
    "=" * 60,
    "\nObservations:",
    "- Base model gives general responses",
    "- Not specialized for medical domain",
    "- Next: Fine-tune with LoRA on medical data!",
)
print(*summary_lines, sep="\n")