In [None]:
# Install deps
!pip -q install "transformers>=4.44.0" datasets peft accelerate sentencepiece \
                 "chromadb>=0.5.0" "sentence-transformers>=3.0.1"

import os, json, random, zipfile
from typing import Dict, List
import torch
from datasets import Dataset
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, DataCollatorForSeq2Seq, Trainer, TrainingArguments
from peft import LoraConfig, get_peft_model
import chromadb
from chromadb.config import Settings
from chromadb.utils import embedding_functions

# Configuration: model id, output locations, seed and sequence-length limits
BASE_MODEL = "google/flan-t5-small"
ADAPTER_DIR = "/content/flan_t5_dosha_lora"
CHROMA_DIR = "/content/chroma_db"
SEED = 42
MAX_SOURCE_LEN = 512
MAX_TARGET_LEN = 256

# Seed Python's and PyTorch's RNGs with the same fixed value.
random.seed(SEED)
torch.manual_seed(SEED)

# Use CUDA when PyTorch reports it as available, otherwise stay on CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"
_gpu_label = torch.cuda.get_device_name(0) if device == "cuda" else ""
print("Device:", device, _gpu_label)

# GPU dtype (bf16 if supported, else fp16 on CUDA, else float32)
def pick_dtype():
    """Choose a torch dtype based on the module-level ``device``.

    Returns ``torch.float32`` when ``device`` is not ``"cuda"``. On CUDA it
    returns ``torch.bfloat16`` when ``torch.cuda.is_bf16_supported()`` is
    truthy and ``torch.float16`` otherwise, including when the capability
    query itself raises.
    """
    if device != "cuda":
        return torch.float32
    try:
        bf16_ok = torch.cuda.is_bf16_supported()
    except Exception:
        # Capability probe failed; fall back to half precision.
        return torch.float16
    return torch.bfloat16 if bf16_ok else torch.float16
# Resolve the dtype once; the model load and the Trainer bf16/fp16 flags both read DTYPE.
DTYPE = pick_dtype()
print("torch_dtype:", DTYPE)

# Ayurvedic knowledge base: one record per document.
# Each record carries an id, a category, a title, the text content, and a
# metadata dict ("type", plus "dosha" where the entry is dosha-specific).
AYURVEDA_KNOWLEDGE: List[Dict] = [
    {
        "id": "dosha_vata",
        "category": "Doshas",
        "title": "Vata",
        "content": "Vata governs movement; imbalance → anxiety, dry skin, constipation.",
        "metadata": {"type": "constitution", "dosha": "vata"},
    },
    {
        "id": "dosha_pitta",
        "category": "Doshas",
        "title": "Pitta",
        "content": "Pitta governs digestion; imbalance → inflammation, anger, heartburn, skin issues.",
        "metadata": {"type": "constitution", "dosha": "pitta"},
    },
    {
        "id": "dosha_kapha",
        "category": "Doshas",
        "title": "Kapha",
        "content": "Kapha governs structure; imbalance → lethargy, congestion, weight gain.",
        "metadata": {"type": "constitution", "dosha": "kapha"},
    },
    {
        "id": "herb_ashwagandha",
        "category": "Herbs",
        "title": "Ashwagandha",
        "content": "Adaptogen: stress reduction, sleep, cognition. Balances Vata/Kapha.",
        "metadata": {"type": "herb"},
    },
    {
        "id": "herb_turmeric",
        "category": "Herbs",
        "title": "Turmeric",
        "content": "Anti-inflammatory/antioxidant; supports skin, joints, liver.",
        "metadata": {"type": "herb"},
    },
    {
        "id": "diet_vata",
        "category": "Diet",
        "title": "Vata Diet",
        "content": "Warm, moist, grounding foods; ghee; avoid cold/dry foods.",
        "metadata": {"type": "diet", "dosha": "vata"},
    },
    {
        "id": "diet_pitta",
        "category": "Diet",
        "title": "Pitta Diet",
        "content": "Cooling, hydrating foods; avoid hot/fried/acidic foods.",
        "metadata": {"type": "diet", "dosha": "pitta"},
    },
    {
        "id": "diet_kapha",
        "category": "Diet",
        "title": "Kapha Diet",
        "content": "Light, warm, pungent foods; reduce heavy/oily/sweet.",
        "metadata": {"type": "diet", "dosha": "kapha"},
    },
    {
        "id": "practice_abhyanga",
        "category": "Practices",
        "title": "Abhyanga",
        "content": "Daily oil massage: sesame(Vata), coconut(Pitta), mustard/sunflower(Kapha).",
        "metadata": {"type": "practice"},
    },
]

# Init ChromaDB
def init_chromadb(persist=CHROMA_DIR):
    """Open (or create) the on-disk Chroma store and its knowledge collection.

    Args:
        persist: Directory where Chroma keeps its data; created if missing.

    Returns:
        A ``(client, collection)`` tuple. The collection is named
        ``"ayurveda_knowledge"`` and is seeded from ``AYURVEDA_KNOWLEDGE``
        whenever it is found empty.
    """
    os.makedirs(persist, exist_ok=True)
    # chromadb.Client(Settings(persist_directory=...)) is in-memory only in
    # Chroma >= 0.4; PersistentClient is what actually writes to `persist`.
    client = chromadb.PersistentClient(
        path=persist,
        settings=Settings(anonymized_telemetry=False),
    )
    # Attach the same embedding function whether the collection is new or
    # reopened, so queries embed text the same way the stored documents were.
    ef = embedding_functions.SentenceTransformerEmbeddingFunction(model_name="all-MiniLM-L6-v2")
    coll = client.get_or_create_collection(
        name="ayurveda_knowledge",
        embedding_function=ef,
        metadata={"description": "Ayurveda KB for RAG"},
    )
    # Seed only an empty collection; re-adding existing ids on every run is
    # wasted embedding work and produces duplicate-id complaints.
    if coll.count() == 0:
        coll.add(
            documents=[d["content"] for d in AYURVEDA_KNOWLEDGE],
            metadatas=[d["metadata"] for d in AYURVEDA_KNOWLEDGE],
            ids=[d["id"] for d in AYURVEDA_KNOWLEDGE],
        )
    return client, coll

client, collection = init_chromadb()
print("Chroma docs:", collection.count())

# Data: built-in demo examples (fields: input, output); an uploaded JSONL
# with the same two fields can be used instead.
DEMO = [
    {
        "input": (
            "Dosha: Vata-Pitta\n"
            "Symptoms: bloating, anxiety\n"
            "Context: adult, moderate"
        ),
        "output": (
            "Health:\n- Regular routine; manage stress.\n"
            "Diet:\n- Cooling + warm cooked foods, ghee.\n"
            "Lifestyle:\n- Abhyanga (sesame), breathwork.\n"
            "Warnings:\n- If pain or bleeding, consult doctor."
        ),
    },
    {
        "input": (
            "Dosha: Kapha\n"
            "Symptoms: fatigue, congestion\n"
            "Context: adult, mild"
        ),
        "output": (
            "Health:\n- Daily vigorous exercise.\n"
            "Diet:\n- Light, warm, pungent foods.\n"
            "Lifestyle:\n- Wake before 6 AM; avoid naps.\n"
            "Warnings:\n- If breathlessness/edema, seek care."
        ),
    },
    {
        "input": (
            "Dosha: Pitta\n"
            "Symptoms: rash, burning\n"
            "Context: adult, mild"
        ),
        "output": (
            "Health:\n- Reduce heat; aloe/coconut oil topicals.\n"
            "Diet:\n- Cooling foods; avoid spicy/acidic.\n"
            "Lifestyle:\n- Avoid midday sun.\n"
            "Warnings:\n- If rash spreads, consult clinician."
        ),
    },
]

def _load_jsonl_rows(path):
    """Read ``{"input", "output"}`` records from a JSONL file.

    Blank lines are skipped. Each record is returned as
    ``{"input": ..., "target": ...}``.

    Raises:
        ValueError: If a non-blank line is not valid JSON or lacks the
            ``input``/``output`` fields; the message names the line number.
    """
    rows = []
    with open(path, "r", encoding="utf-8") as fh:
        for lineno, raw in enumerate(fh, start=1):
            raw = raw.strip()
            if not raw:
                # Tolerate blank lines (e.g. a trailing newline at end of file).
                continue
            try:
                rec = json.loads(raw)
                rows.append({"input": rec["input"], "target": rec["output"]})
            except (json.JSONDecodeError, KeyError, TypeError) as err:
                raise ValueError(
                    f"{path}:{lineno}: expected a JSON object with 'input' and 'output' fields"
                ) from err
    return rows


def _demo_dataset():
    """Build the fallback dataset from DEMO, copying ``output`` into ``target``."""
    return Dataset.from_list(DEMO).map(lambda ex: {"target": ex["output"]})


# Best-effort Colab detection: any failure to import means "not on Colab".
try:
    import google.colab as _gc  # type: ignore
except Exception:
    on_colab = False
else:
    on_colab = True

ds = None
if on_colab:
    from google.colab import files
    print("OPTIONAL: Upload your JSONL (fields: input, output). If you skip, demo data will be used.")
    up = files.upload()
    if up:
        # Only the first uploaded file is used.
        fname = next(iter(up))
        rows = _load_jsonl_rows(fname)
        if rows:
            ds = Dataset.from_list(rows)
        else:
            # An empty dataset would only fail later at the train/test split.
            print("Uploaded file contained no records; using demo data instead.")

if ds is None:
    ds = _demo_dataset()

def to_prompt(ex):
    """Build the instruction prompt for one example.

    Reads ``ex["input"]`` and ``ex["target"]``. Returns a dict whose
    ``input_text`` is a fixed instruction header, the input, and a trailing
    "Answer:" marker, and whose ``labels`` is the target text unchanged.
    """
    header = (
        "You are an Ayurvedic assistant. Generate concise, specific, non-repetitive recommendations.\n"
        "Return sections: Health, Diet, Lifestyle, Warnings.\n\n"
    )
    prompt = "".join([header, ex["input"], "\n\nAnswer:"])
    return {"input_text": prompt, "labels": ex["target"]}

# Add the prompt columns, then hold out 10% of the rows for evaluation (split seeded with SEED).
ds = ds.map(to_prompt)
split = ds.train_test_split(test_size=0.1, seed=SEED)
train_ds, val_ds = split["train"], split["test"]

# Tokenizer and base seq2seq model (weights loaded with torch_dtype=DTYPE)
tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL)
base = AutoModelForSeq2SeqLM.from_pretrained(BASE_MODEL, torch_dtype=DTYPE)
base = base.to(device) if device == "cuda" else base

# LoRA adapters on the modules named q, k, v and o
lora_cfg = LoraConfig(
    task_type="SEQ_2_SEQ_LM",
    r=16,
    lora_alpha=32,
    lora_dropout=0.05,
    bias="none",
    target_modules=["q", "k", "v", "o"],
)
model = get_peft_model(base, lora_cfg)
model = model.to(device) if device == "cuda" else model

# Tokenization
def tok(batch):
    """Tokenize a batch of prompts and targets for seq2seq training.

    Args:
        batch: Mapping with ``input_text`` and ``labels`` text columns, as
            passed by ``Dataset.map(..., batched=True)``.

    Returns:
        The tokenizer output for the prompts, with ``labels`` set to the
        token ids of the targets. Both sides are truncated (to
        ``MAX_SOURCE_LEN`` / ``MAX_TARGET_LEN``); no padding is requested here.
    """
    mi = tokenizer(batch["input_text"], max_length=MAX_SOURCE_LEN, truncation=True)
    # `text_target=` replaces the deprecated `as_target_tokenizer()` context
    # manager, which is slated for removal in transformers v5.
    lab = tokenizer(text_target=batch["labels"], max_length=MAX_TARGET_LEN, truncation=True)
    mi["labels"] = lab["input_ids"]
    return mi

collator = DataCollatorForSeq2Seq(tokenizer=tokenizer, model=model)
train_tok = train_ds.map(tok, batched=True, remove_columns=train_ds.column_names)
val_tok   = val_ds.map(tok,   batched=True, remove_columns=val_ds.column_names)

# Train
args = TrainingArguments(
    output_dir=ADAPTER_DIR,
    per_device_train_batch_size=8,
    per_device_eval_batch_size=8,
    gradient_accumulation_steps=2,
    learning_rate=2e-4,
    num_train_epochs=3,
    eval_strategy="steps",
    logging_steps=50,
    eval_steps=200,
    save_steps=200,
    save_total_limit=2,
    lr_scheduler_type="cosine",
    warmup_ratio=0.03,
    weight_decay=0.01,
    # Mixed-precision flags follow the dtype the model was loaded in.
    bf16=(DTYPE==torch.bfloat16),
    fp16=(DTYPE==torch.float16),
    report_to="none",
)

# Trainer's `tokenizer=` keyword was renamed `processing_class=` in
# transformers 4.46 and the old name is scheduled for removal in v5. The
# install line only pins ">=4.44.0", so pick whichever name this version accepts.
import inspect
_tok_kwarg = (
    "processing_class"
    if "processing_class" in inspect.signature(Trainer.__init__).parameters
    else "tokenizer"
)
trainer = Trainer(model=model, args=args, train_dataset=train_tok, eval_dataset=val_tok,
                  data_collator=collator, **{_tok_kwarg: tokenizer})
print("Training LoRA...")
trainer.train()

# Save the LoRA adapter and the tokenizer side by side in ADAPTER_DIR.
os.makedirs(ADAPTER_DIR, exist_ok=True)
model.save_pretrained(ADAPTER_DIR)
tokenizer.save_pretrained(ADAPTER_DIR)
print("Saved adapter at:", ADAPTER_DIR)

# RAG inference (retrieve context from the Chroma collection, then generate)
def retrieve(query: str, n=5) -> List[str]:
    """Query the collection with ``query`` and return the matching documents.

    Requests ``n`` results and returns the document list for the single
    query text, or ``[]`` when that list is missing or empty.
    """
    res = collection.query(query_texts=[query], n_results=n)
    docs = res.get("documents", [[]])[0]
    return docs if docs else []

@torch.no_grad()
def generate_with_rag(user_input: str, max_len: int = 280) -> str:
    """Retrieve supporting documents for ``user_input`` and generate an answer.

    Args:
        user_input: Case description; used both as the retrieval query and
            as the "Input" section of the prompt.
        max_len: Passed to ``model.generate`` as ``max_length``.

    Returns:
        The first generated sequence, decoded with special tokens removed.
    """
    # trainer.train() can leave the model in training mode; without eval()
    # dropout (including LoRA dropout) would stay active during generation.
    model.eval()

    ctx_docs = retrieve(user_input, n=5)
    context = "\n\n".join(f"Source {i}: {d}" for i, d in enumerate(ctx_docs, start=1))
    prompt = ("Based on the following Ayurvedic knowledge, generate concise, specific recommendations.\n"
              "Return sections: Health, Diet, Lifestyle, Warnings.\n\n"
              f"Context:\n{context}\n\nInput:\n{user_input}\n\nAnswer:")
    # The prompt is truncated to MAX_SOURCE_LEN tokens before generation.
    inputs = tokenizer(prompt, return_tensors="pt", max_length=MAX_SOURCE_LEN, truncation=True)
    if device == "cuda":
        inputs = {k: v.to(device) for k, v in inputs.items()}
    out = model.generate(**inputs, max_length=max_len, min_length=80,
                         num_beams=5, early_stopping=True,
                         temperature=0.9, do_sample=True, top_k=60, top_p=0.92,
                         repetition_penalty=1.25, no_repeat_ngram_size=3)
    return tokenizer.decode(out[0], skip_special_tokens=True)

print("\nDemo RAG output:")
print(generate_with_rag("Dosha: Vata-Pitta\nSymptoms: bloating, anxiety\nContext: adult, moderate severity"))

# Zip adapter for download
zip_path = "/content/flan_t5_dosha_lora.zip"
with zipfile.ZipFile(zip_path, "w", zipfile.ZIP_DEFLATED) as zf:
    # Loop names avoid `files`, which would shadow the google.colab `files`
    # module imported earlier on Colab.
    for root, dirnames, filenames in os.walk(ADAPTER_DIR):
        # ADAPTER_DIR is also the Trainer output_dir, so intermediate
        # checkpoint-* folders can sit next to the adapter; prune them in
        # place so os.walk does not descend into them.
        dirnames[:] = [d for d in dirnames if not d.startswith("checkpoint-")]
        for file_name in filenames:
            full = os.path.join(root, file_name)
            rel  = os.path.relpath(full, ADAPTER_DIR)
            # Store everything under a single top-level folder in the archive.
            zf.write(full, arcname=os.path.join("flan_t5_dosha_lora", rel))
print("Download from left pane:", zip_path)
