In [1]:
!pip install -q transformers accelerate bitsandbytes sentencepiece


[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m59.1/59.1 MB[0m [31m20.7 MB/s[0m eta [36m0:00:00[0m
[?25h

In [2]:
# Authenticate this session with the Hugging Face Hub.
# NOTE(review): presumably required because the checkpoint is gated — confirm on the model page.
from huggingface_hub import login
from getpass import getpass

# getpass masks the token while it is typed, so it does not end up in the cell output.
hf_token = getpass("Enter HF token: ")
login(token=hf_token)
print("✓ Logged in")

Enter HF token: ··········
✓ Logged in


In [3]:
import os
# Allocator option for PyTorch's CUDA caching allocator. It is written to the
# environment before `import torch` so it is already in place when PyTorch reads it.
# NOTE(review): expandable segments are meant to reduce fragmentation — confirm against the PyTorch docs.
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"

import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig


In [4]:
# Report whether a CUDA device is visible and, if so, its name and total memory.
cuda_ready = torch.cuda.is_available()
print("CUDA Available:", cuda_ready)
if cuda_ready:
    device_props = torch.cuda.get_device_properties(0)
    print("GPU:", torch.cuda.get_device_name(0))
    # total_memory is in bytes; divide by 1e9 to report decimal gigabytes.
    print("Total VRAM (GB):", device_props.total_memory / 1e9)


CUDA Available: True
GPU: Tesla T4
Total VRAM (GB): 15.828320256


In [8]:
# NOTE(review): this cell repeats the earlier setup cell (same env var, same imports).
# Re-importing already-loaded modules is a no-op; the env var assignment presumably has
# no effect if PyTorch has already initialised CUDA in this kernel — TODO confirm.
import os
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"

import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig

# Release cached-but-unused GPU memory held by PyTorch's allocator.
torch.cuda.empty_cache()


In [5]:
# Hugging Face Hub repository id of the checkpoint. The tokenizer and model cells
# both load from it, and it is recorded alongside every batch result.
MODEL_NAME = "google/txgemma-9b-chat"


In [6]:
# bitsandbytes settings used when the model is loaded.
quant_config = BitsAndBytesConfig(
    # Load the weights in 4-bit form.
    load_in_4bit=True,
    # Enable double quantization.
    bnb_4bit_use_double_quant=True,
    # Run the 4-bit layers' computation in float16.
    bnb_4bit_compute_dtype=torch.float16,
)


In [7]:
# Load the tokenizer published with MODEL_NAME (its files are downloaded from the Hub).
print("Loading tokenizer...")
# use_fast=True asks for the fast tokenizer implementation.
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, use_fast=True)
print("✅ Tokenizer Loaded")


Loading tokenizer...


The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


config.json:   0%|          | 0.00/852 [00:00<?, ?B/s]

tokenizer_config.json:   0%|          | 0.00/47.0k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/17.5M [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/636 [00:00<?, ?B/s]

✅ Tokenizer Loaded


In [9]:
print("Loading model (this may take few minutes)...")

# Load the checkpoint with the 4-bit settings from quant_config and let
# device_map="auto" decide where the weights are placed.
model = AutoModelForCausalLM.from_pretrained(
    MODEL_NAME,
    quantization_config=quant_config,
    device_map="auto",
    # `torch_dtype` is deprecated in the installed transformers release (it emits
    # "`torch_dtype` is deprecated! Use `dtype` instead!"), so pass `dtype`.
    dtype=torch.float16,
    low_cpu_mem_usage=True
)

print("✅ Model Loaded Successfully")


`torch_dtype` is deprecated! Use `dtype` instead!


Loading model (this may take few minutes)...


model.safetensors.index.json:   0%|          | 0.00/39.1k [00:00<?, ?B/s]

Downloading (incomplete total...): 0.00B [00:00, ?B/s]

Fetching 4 files:   0%|          | 0/4 [00:00<?, ?it/s]

Loading weights:   0%|          | 0/464 [00:00<?, ?it/s]

generation_config.json:   0%|          | 0.00/168 [00:00<?, ?B/s]

✅ Model Loaded Successfully


In [10]:
def generate_text(prompt, max_tokens=200):
    """Sample a continuation of ``prompt`` from the loaded model.

    Args:
        prompt: Text to feed to the model.
        max_tokens: Maximum number of new tokens to generate.

    Returns:
        The decoded output sequence with special tokens removed. The sequence
        returned by ``generate`` starts with the prompt tokens, so the returned
        text begins with the prompt followed by the model's continuation.
    """
    # Send the inputs to wherever the model actually lives instead of hard-coding
    # "cuda": the model is loaded with device_map="auto", which picks the placement.
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)

    # Inference only: no gradients are needed.
    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            max_new_tokens=max_tokens,
            temperature=0.7,
            top_p=0.9,
            do_sample=True,
        )

    return tokenizer.decode(outputs[0], skip_special_tokens=True)


In [11]:
# Smoke-test prompt. The leading and trailing newlines are part of the prompt text.
test_prompt = (
    "\n"
    "You are a medical assistant.\n"
    "Explain early symptoms of iron deficiency anemia in simple language.\n"
)

# generate_text returns the prompt followed by the completion, so both are printed.
response = generate_text(test_prompt)
print(response)



You are a medical assistant.
Explain early symptoms of iron deficiency anemia in simple language.
* **Fatigue:** You might feel tired all the time, even after a good night's sleep.
* **Shortness of breath:** You might find it hard to catch your breath, even when doing simple tasks like walking up stairs.
* **Pale skin:** Your skin might look paler than usual, especially on your face, lips, and nail beds.
* **Dizziness:** You might feel lightheaded or dizzy, especially when standing up quickly.
* **Cold hands and feet:** Your extremities might feel cold more often than usual.

**Important note:** I am not a doctor. If you are experiencing any of these symptoms, please consult a healthcare professional for a proper diagnosis and treatment plan. 



In [None]:
def chat(messages, max_tokens=200):
    """Flatten a list of chat messages into one prompt and generate a reply.

    Each message is a dict with ``role`` and ``content`` keys and is rendered as
    one ``ROLE: content`` line (role upper-cased). The joined text is passed to
    ``generate_text`` and its result is returned unchanged.
    """
    transcript = "".join(
        f"{turn['role'].upper()}: {turn['content']}\n" for turn in messages
    )
    return generate_text(transcript, max_tokens=max_tokens)


In [None]:
# Two-turn example conversation for chat(): a system instruction and one user question.
messages = [
    dict(role="system", content="You are a helpful medical assistant."),
    dict(role="user", content="What are early symptoms of dehydration?"),
]

print(chat(messages))


# Batch Processing for Pipeline Integration
Process multiple queries, each paired with its own list of retrieved contexts.

In [None]:
import json
from datetime import datetime

# Text that ends the prompt; the model's answer is whatever follows it.
_ANSWER_MARKER = "Answer (cite sources):"


def _build_synthesis_prompt(query, contexts):
    """Format the instruction, the numbered contexts and the question into one prompt."""
    context_text = "\n\n".join(f"[{i}] {ctx}" for i, ctx in enumerate(contexts, 1))
    return (
        "You are a medical expert assistant. Use the following research contexts "
        "to answer the question accurately. Cite sources using [1], [2], [3] format.\n"
        "\n"
        "Research Contexts:\n"
        f"{context_text}\n"
        "\n"
        f"Question: {query}\n"
        "\n"
        f"{_ANSWER_MARKER}"
    )


def synthesize_medical_answer(query, contexts, max_tokens=512, max_input_tokens=2048):
    """
    Generate a medical answer from the retrieved contexts.

    Args:
        query: Medical question
        contexts: List of retrieved context strings (``None`` is treated as empty)
        max_tokens: Max tokens to generate
        max_input_tokens: Max prompt length in tokens. When the prompt is longer,
            contexts are dropped from the end of the list until it fits, so the
            question and the answer marker are never the part that gets cut.

    Returns:
        The generated answer text, stripped of surrounding whitespace. The prompt
        itself is never included.
    """
    all_contexts = list(contexts or [])
    kept = list(all_contexts)

    # Shrink the context list rather than letting the tokenizer truncate the tail
    # of the prompt, which is where the question and the answer marker live.
    prompt = _build_synthesis_prompt(query, kept)
    while kept and len(tokenizer(prompt)["input_ids"]) > max_input_tokens:
        kept.pop()
        prompt = _build_synthesis_prompt(query, kept)

    if len(kept) < len(all_contexts):
        print(f"  ! Prompt too long: using {len(kept)} of {len(all_contexts)} contexts")

    # truncation stays on as a last resort for a prompt that is too long even with
    # no contexts at all.
    inputs = tokenizer(
        prompt,
        return_tensors="pt",
        truncation=True,
        max_length=max_input_tokens,
    ).to(model.device)  # follow the model's placement instead of hard-coding "cuda"
    prompt_len = inputs["input_ids"].shape[1]

    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            max_new_tokens=max_tokens,
            temperature=0.1,
            top_p=0.9,
            do_sample=True,
            pad_token_id=tokenizer.eos_token_id
        )

    # Decode only the newly generated tokens. Splitting the full decoded text on the
    # marker returned the entire prompt whenever the marker had been truncated away.
    generated = tokenizer.decode(outputs[0][prompt_len:], skip_special_tokens=True)

    # If the model repeats the marker itself, keep only what follows the last one.
    if _ANSWER_MARKER in generated:
        generated = generated.split(_ANSWER_MARKER)[-1]

    return generated.strip()

# Smoke test: one question with three short contexts, run through the synthesizer.
test_query = "What are the effects of metformin on type 2 diabetes?"
test_contexts = [
    "Metformin improves glycemic control by reducing hepatic glucose production and increasing insulin sensitivity.",
    "Studies show metformin reduces HbA1c by 1-2% and is associated with weight loss.",
    "Metformin is first-line therapy for type 2 diabetes with proven cardiovascular benefits.",
]

# Announce the query before generation starts, then show the answer afterwards.
print("Testing synthesis with contexts...\n", f"Query: {test_query}\n", sep="\n")
answer = synthesize_medical_answer(test_query, test_contexts)
print(f"Answer:\n{answer}", "\n✓ Synthesis working!", sep="\n")

In [None]:
def process_batch_queries(queries, contexts_list, output_file="txgemma_results.json"):
    """
    Process multiple queries from the pipeline and save the answers as JSON.

    Args:
        queries: List of medical questions
        contexts_list: List of context lists (one per query, same order)
        output_file: JSON file to save results

    Returns:
        List of result dicts with keys ``query``, ``answer``, ``num_contexts``,
        ``model`` and ``timestamp`` (local time, ISO 8601).

    Raises:
        ValueError: If ``queries`` and ``contexts_list`` differ in length.
    """
    queries = list(queries)
    contexts_list = list(contexts_list)

    # zip() would silently drop the unmatched tail; fail loudly instead.
    if len(queries) != len(contexts_list):
        raise ValueError(
            "queries and contexts_list must have the same length "
            f"(got {len(queries)} and {len(contexts_list)})"
        )

    results = []

    def save_results():
        with open(output_file, 'w', encoding="utf-8") as f:
            json.dump(results, f, indent=2)

    print(f"\n{'='*80}")
    print(f"BATCH PROCESSING: {len(queries)} QUERIES")
    print(f"{'='*80}\n")

    for i, (query, contexts) in enumerate(zip(queries, contexts_list), 1):
        print(f"[{i}/{len(queries)}] {query[:60]}...")

        # Generate answer
        answer = synthesize_medical_answer(query, contexts, max_tokens=512)

        # Store result
        results.append({
            "query": query,
            "answer": answer,
            "num_contexts": len(contexts),
            "model": MODEL_NAME,
            "timestamp": datetime.now().isoformat()
        })

        # Checkpoint after every query so a crash or disconnect part-way through
        # the batch does not throw away the answers already generated.
        save_results()

        print(f"  ✓ Generated ({len(answer)} chars)\n")

    # Final write; also covers the empty-batch case, where the loop never runs.
    save_results()

    print(f"\n{'='*80}")
    print(f"✓ Results saved to {output_file}")
    print(f"{'='*80}\n")

    return results

# Ready to process your pipeline queries!
# This only confirms that the cell ran and the function is defined; nothing is processed yet.
print("✓ Batch processing function ready")

# Upload Pipeline Data
Upload the pickle file (`pipeline_data.pkl`) exported by your pipeline. It should contain a list of records, each with a `query` and, optionally, its `contexts`.

In [None]:
# Option 1: Upload from local file
from google.colab import files
import pickle

print("Upload your pipeline_data.pkl file:")
uploaded = files.upload()

# A cancelled upload returns an empty mapping; stop with a clear message instead
# of an IndexError from indexing an empty list.
if not uploaded:
    raise RuntimeError("No file was uploaded - re-run this cell and choose pipeline_data.pkl")

# Use the first uploaded file; it is opened by name from the working directory.
uploaded_name = next(iter(uploaded))

# SECURITY: unpickling can execute arbitrary code. Only upload a pickle that you
# produced yourself with your own pipeline.
with open(uploaded_name, 'rb') as f:
    pipeline_data = pickle.load(f)

print(f"\n✓ Loaded {len(pipeline_data)} queries from pipeline")

# Extract queries and contexts. A missing or None `contexts` entry becomes an
# empty list so that later len() calls and prompt building do not fail.
queries = [item['query'] for item in pipeline_data]
contexts_list = [item.get('contexts') or [] for item in pipeline_data]

# Preview the first few items with their context counts.
print(f"\nFirst 3 queries:")
for i, (q, ctxs) in enumerate(zip(queries[:3], contexts_list), 1):
    print(f"{i}. {q}")
    print(f"   Contexts: {len(ctxs)}")

In [None]:
# Run the whole batch through TXGemma-9B, then hand the JSON file to the browser.
results_path = "txgemma_9b_results.json"
results = process_batch_queries(queries, contexts_list, output_file=results_path)

# Download the file that was just written.
print("\nDownloading results file...")
files.download(results_path)

print("\n✓ DONE! Import this file back to your local pipeline database")