In [1]:
!pip install faiss-cpu

Collecting faiss-cpu
  Downloading faiss_cpu-1.11.0-cp311-cp311-manylinux_2_28_x86_64.whl.metadata (4.8 kB)
Downloading faiss_cpu-1.11.0-cp311-cp311-manylinux_2_28_x86_64.whl (31.3 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m31.3/31.3 MB[0m [31m71.4 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: faiss-cpu
Successfully installed faiss-cpu-1.11.0


In [2]:
!pip install datasets

Collecting datasets
  Downloading datasets-3.5.1-py3-none-any.whl.metadata (19 kB)
Collecting dill<0.3.9,>=0.3.0 (from datasets)
  Downloading dill-0.3.8-py3-none-any.whl.metadata (10 kB)
Collecting xxhash (from datasets)
  Downloading xxhash-3.5.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (12 kB)
Collecting multiprocess<0.70.17 (from datasets)
  Downloading multiprocess-0.70.16-py311-none-any.whl.metadata (7.2 kB)
Collecting fsspec<=2025.3.0,>=2023.1.0 (from fsspec[http]<=2025.3.0,>=2023.1.0->datasets)
  Downloading fsspec-2025.3.0-py3-none-any.whl.metadata (11 kB)
Downloading datasets-3.5.1-py3-none-any.whl (491 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m491.4/491.4 kB[0m [31m12.3 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading dill-0.3.8-py3-none-any.whl (116 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m116.3/116.3 kB[0m [31m10.3 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading fsspec-2025.3.0-py3-none-any.whl 

In [3]:
pip install -U bitsandbytes

Collecting bitsandbytes
  Downloading bitsandbytes-0.45.5-py3-none-manylinux_2_24_x86_64.whl.metadata (5.0 kB)
Collecting nvidia-cuda-nvrtc-cu12==12.4.127 (from torch<3,>=2.0->bitsandbytes)
  Downloading nvidia_cuda_nvrtc_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cuda-runtime-cu12==12.4.127 (from torch<3,>=2.0->bitsandbytes)
  Downloading nvidia_cuda_runtime_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cuda-cupti-cu12==12.4.127 (from torch<3,>=2.0->bitsandbytes)
  Downloading nvidia_cuda_cupti_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.6 kB)
Collecting nvidia-cudnn-cu12==9.1.0.70 (from torch<3,>=2.0->bitsandbytes)
  Downloading nvidia_cudnn_cu12-9.1.0.70-py3-none-manylinux2014_x86_64.whl.metadata (1.6 kB)
Collecting nvidia-cublas-cu12==12.4.5.8 (from torch<3,>=2.0->bitsandbytes)
  Downloading nvidia_cublas_cu12-12.4.5.8-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-

In [None]:
import os
import torch
import numpy as np
import faiss
from PIL import Image
import pandas as pd
from tqdm.auto import tqdm

# 1) Models for embeddings
from transformers import CLIPModel, CLIPProcessor
from sentence_transformers import SentenceTransformer

In [4]:
from google.colab import drive
drive.mount('/content/drive/', force_remount=True)
#

Mounted at /content/drive/


In [None]:
df = pd.read_pickle("/content/drive/MyDrive/patchcamelyon_captions_df_final_final_new_new.pkl")

In [None]:
clip_model     = CLIPModel.from_pretrained("openai/clip-vit-large-patch14").eval().to("cuda")
clip_processor = CLIPProcessor.from_pretrained("openai/clip-vit-large-patch14")

In [None]:
text_embedder = SentenceTransformer("all-MiniLM-L6-v2")

In [None]:
img_embs = []
for img in tqdm(df["image"], desc="Embedding images"):
    inputs = clip_processor(images=img, return_tensors="pt").to("cuda")
    with torch.no_grad():
        feat = clip_model.get_image_features(**inputs)  # (1, 768)
    feat = feat / feat.norm(dim=-1, keepdim=True)
    img_embs.append(feat.cpu().numpy())

img_embs = np.vstack(img_embs)  # (N, 768)

txt_embs = text_embedder.encode(
    df["caption"].tolist(),
    batch_size=64,
    show_progress_bar=True,
    convert_to_numpy=True,
)

Embedding images:   0%|          | 0/14311 [00:00<?, ?it/s]

Batches:   0%|          | 0/224 [00:00<?, ?it/s]

In [None]:
faiss.normalize_L2(txt_embs)

In [None]:
combined = np.concatenate([img_embs, txt_embs], axis=1)

In [None]:
faiss.normalize_L2(combined)

In [None]:
d = combined.shape[1]                  # 1152
index = faiss.IndexFlatIP(d)          # inner-product similarity
index.add(combined)                   # add all N vectors

os.makedirs("rag_index", exist_ok=True)
faiss.write_index(index, "/content/drive/MyDrive/pcam_rag_index_withlables_new/caption_image_index.faiss")

# Save a mapping from row‐idx → metadata
# so you can look up the original caption/image_id later:
df[["caption"]].to_json(
    "rag_index/metadata.jsonl",
    orient="records",
    lines=True
)

In [None]:
import os, json

# assume df has columns: "caption", "label"
BASE_PATH = "/content/drive/MyDrive/pcam_rag_index_withlables_new"
os.makedirs(BASE_PATH, exist_ok=True)

# Build the metadata with labels
docs = [
    {
        "id": idx,
        "caption": cap,
        "label": int(lab)
    }
    for idx, (cap, lab) in enumerate(zip(df["caption"], df["label"]))
]

# Save it out as JSON‐lines
meta_path = os.path.join(BASE_PATH, "metadata.jsonl")
with open(meta_path, "w") as f:
    for doc in docs:
        f.write(json.dumps(doc) + "\n")
print("Saved metadata with labels to", meta_path)


Saved metadata with labels to /content/drive/MyDrive/pcam_rag_index_withlables_new/metadata.jsonl


Reading faiss

In [None]:
from itertools import islice
from datasets import load_dataset

In [None]:
streamed_dataset = load_dataset("1aurent/PatchCamelyon", split="test", streaming=True)
test_samples = list(islice(streamed_dataset, 200))

In [None]:
test_images = [item['image'] for item in test_samples]  # PIL Images
test_labels = [int(item['label']) for item in test_samples]  # convert bool to int (0 or 1)

In [None]:
import os
import json
import torch
import torch.nn as nn
import pandas as pd
from itertools import islice
from datasets import load_dataset
from PIL import Image
from transformers import (
    GPT2Tokenizer,
    GPT2LMHeadModel,
    BitsAndBytesConfig,
    CLIPProcessor,
    CLIPModel,
)
from peft import PeftModel

# ─── 0.  Configuration ─────────────────────────────────────────
OUT_DIR     = "/content/drive/MyDrive/gpt2_clip_histopath_lora_new"
DEVICE      = "cuda"
PREFIX_LEN  = 10
CLIP_DIM    = 768
EMBED_DIM   = 768
NUM_SAMPLES = 500    # set to None to process the whole split

# ─── 1.  Reload tokenizer (must have been saved after training) ─
tokenizer = GPT2Tokenizer.from_pretrained(OUT_DIR)
tokenizer.pad_token = tokenizer.eos_token

# ─── 2.  Load fine-tuned GPT-2 + LoRA adapter in 8-bit ─────────
bnb8 = BitsAndBytesConfig(load_in_8bit=True)
base_gpt2 = GPT2LMHeadModel.from_pretrained(
    "gpt2",
    quantization_config=bnb8,
    device_map="auto",
    torch_dtype=torch.float16,
)
gpt2_lora = (
    PeftModel.from_pretrained(base_gpt2, OUT_DIR)
    .to(DEVICE)
    .eval()
)

# ─── 3.  Rebuild & load the mapper MLP ────────────────────────
class Clip2Prefix(nn.Module):
    def __init__(self, clip_dim, embed_dim, prefix_len):
        super().__init__()
        hidden = embed_dim * prefix_len // 2
        self.fc = nn.Sequential(
            nn.Linear(clip_dim, hidden),
            nn.Tanh(),
            nn.Linear(hidden, embed_dim * prefix_len),
        )
        self.prefix_len = prefix_len
        self.embed_dim  = embed_dim

    def forward(self, clip):
        return self.fc(clip).view(-1, self.prefix_len, self.embed_dim)

mapper = Clip2Prefix(CLIP_DIM, EMBED_DIM, PREFIX_LEN).to(DEVICE)
mapper.load_state_dict(
    torch.load(os.path.join(OUT_DIR, "mapper.pt"), map_location=DEVICE)
)
mapper.eval()

# ─── 4.  Define inference wrapper with correct dtype & mask ────
class ClipCaptionModel(nn.Module):
    def __init__(self, gpt2: PeftModel, mapper: nn.Module, prefix_len: int, pad_token_id: int):
        super().__init__()
        self.gpt2        = gpt2
        self.mapper      = mapper
        self.prefix_len  = prefix_len
        self.pad_token_id = pad_token_id

    @torch.no_grad()
    def generate_caption(self, clip_emb: torch.Tensor, **gen_kwargs):
        # clip_emb: (1, CLIP_DIM) float32
        # 1) map to prefix embeddings, cast to gpt2 dtype
        prefix = self.mapper(clip_emb.to(self.gpt2.device))
        prefix = prefix.to(self.gpt2.dtype)  # e.g. float16

        # 2) build attention mask for the prefix
        batch_size = prefix.size(0)
        prefix_mask = torch.ones(
            batch_size, self.prefix_len,
            dtype=torch.long,
            device=self.gpt2.device
        )

        # 3) generate with mask and pad_token_id
        return self.gpt2.generate(
            inputs_embeds=prefix,
            attention_mask=prefix_mask,
            pad_token_id=self.pad_token_id,
            **gen_kwargs
        )

caption_model = ClipCaptionModel(
    gpt2_lora,
    mapper,
    prefix_len=PREFIX_LEN,
    pad_token_id=tokenizer.pad_token_id
)

# ─── 5.  Load CLIP model & processor ───────────────────────────
clip_model     = CLIPModel.from_pretrained("openai/clip-vit-large-patch14")\
                        .to(DEVICE).eval()
clip_processor = CLIPProcessor.from_pretrained("openai/clip-vit-large-patch14")

# ─── 6.  Load PCam dataset ─────────────────────────────────────
ds = load_dataset("1aurent/PatchCamelyon", split="test", streaming=False)

# ─── 7.  Inference loop ────────────────────────────────────────
results = []
for ex in islice(ds, NUM_SAMPLES):
    img, label = ex["image"], ex["label"]

    # a) CLIP embed
    clip_inputs = clip_processor(images=img, return_tensors="pt").to(DEVICE)
    with torch.no_grad():
        emb = clip_model.get_image_features(**clip_inputs)
        emb = emb / emb.norm(p=2, dim=-1, keepdim=True)

    # b) Generate caption
    gen_ids = caption_model.generate_caption(
        emb,
        max_new_tokens=64,
        do_sample=True,
        temperature=0.9,
        top_p=0.9,
    )
    caption = tokenizer.decode(gen_ids[0], skip_special_tokens=True)

    # c) Print & store
    print(f"Label={label} → {caption}")
    results.append({"image": img, "label": label, "caption": caption})

# ─── 8.  Build DataFrame ────────────────────────────────────────
df = pd.DataFrame(results)
print(df.head())


The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.
Xet Storage is enabled for this repo, but the 'hf_xet' package is not installed. Falling back to regular HTTP download. For better performance, install the package with: `pip install huggingface_hub[hf_xet]` or `pip install hf_xet`


model.safetensors:   0%|          | 0.00/548M [00:00<?, ?B/s]

generation_config.json:   0%|          | 0.00/124 [00:00<?, ?B/s]

Using a slow image processor as `use_fast` is unset and a slow processor was saved with this model. `use_fast=True` will be the default behavior in v4.52, even if the model was saved with a slow processor. This will result in minor differences in outputs. You'll still be able to use a slow processor with `use_fast=False`.
Xet Storage is enabled for this repo, but the 'hf_xet' package is not installed. Falling back to regular HTTP download. For better performance, install the package with: `pip install huggingface_hub[hf_xet]` or `pip install hf_xet`


(…)-00000-of-00013-4717c3cf92578c96.parquet:   0%|          | 0.00/471M [00:00<?, ?B/s]

Xet Storage is enabled for this repo, but the 'hf_xet' package is not installed. Falling back to regular HTTP download. For better performance, install the package with: `pip install huggingface_hub[hf_xet]` or `pip install hf_xet`


(…)-00001-of-00013-549914845b4273b1.parquet:   0%|          | 0.00/471M [00:00<?, ?B/s]

Xet Storage is enabled for this repo, but the 'hf_xet' package is not installed. Falling back to regular HTTP download. For better performance, install the package with: `pip install huggingface_hub[hf_xet]` or `pip install hf_xet`


(…)-00002-of-00013-a859720d3cfcebdf.parquet:   0%|          | 0.00/470M [00:00<?, ?B/s]

Xet Storage is enabled for this repo, but the 'hf_xet' package is not installed. Falling back to regular HTTP download. For better performance, install the package with: `pip install huggingface_hub[hf_xet]` or `pip install hf_xet`


(…)-00003-of-00013-a70975735603ee91.parquet:   0%|          | 0.00/470M [00:00<?, ?B/s]

Xet Storage is enabled for this repo, but the 'hf_xet' package is not installed. Falling back to regular HTTP download. For better performance, install the package with: `pip install huggingface_hub[hf_xet]` or `pip install hf_xet`


(…)-00004-of-00013-f3cb3678324a5346.parquet:   0%|          | 0.00/470M [00:00<?, ?B/s]

Xet Storage is enabled for this repo, but the 'hf_xet' package is not installed. Falling back to regular HTTP download. For better performance, install the package with: `pip install huggingface_hub[hf_xet]` or `pip install hf_xet`


(…)-00005-of-00013-959ba247c1881dc0.parquet:   0%|          | 0.00/470M [00:00<?, ?B/s]

Xet Storage is enabled for this repo, but the 'hf_xet' package is not installed. Falling back to regular HTTP download. For better performance, install the package with: `pip install huggingface_hub[hf_xet]` or `pip install hf_xet`


(…)-00006-of-00013-318f5c6d89fc04ef.parquet:   0%|          | 0.00/471M [00:00<?, ?B/s]

Xet Storage is enabled for this repo, but the 'hf_xet' package is not installed. Falling back to regular HTTP download. For better performance, install the package with: `pip install huggingface_hub[hf_xet]` or `pip install hf_xet`


(…)-00007-of-00013-c8a1a9cf7273420c.parquet:   0%|          | 0.00/471M [00:00<?, ?B/s]

Xet Storage is enabled for this repo, but the 'hf_xet' package is not installed. Falling back to regular HTTP download. For better performance, install the package with: `pip install huggingface_hub[hf_xet]` or `pip install hf_xet`


(…)-00008-of-00013-3d4f66c19471ed0a.parquet:   0%|          | 0.00/472M [00:00<?, ?B/s]

Xet Storage is enabled for this repo, but the 'hf_xet' package is not installed. Falling back to regular HTTP download. For better performance, install the package with: `pip install huggingface_hub[hf_xet]` or `pip install hf_xet`


(…)-00009-of-00013-867b6df30133f28e.parquet:   0%|          | 0.00/471M [00:00<?, ?B/s]

Xet Storage is enabled for this repo, but the 'hf_xet' package is not installed. Falling back to regular HTTP download. For better performance, install the package with: `pip install huggingface_hub[hf_xet]` or `pip install hf_xet`


(…)-00010-of-00013-abf99d3df1f77818.parquet:   0%|          | 0.00/471M [00:00<?, ?B/s]

Xet Storage is enabled for this repo, but the 'hf_xet' package is not installed. Falling back to regular HTTP download. For better performance, install the package with: `pip install huggingface_hub[hf_xet]` or `pip install hf_xet`


(…)-00011-of-00013-e929006353f3ae95.parquet:   0%|          | 0.00/471M [00:00<?, ?B/s]

Xet Storage is enabled for this repo, but the 'hf_xet' package is not installed. Falling back to regular HTTP download. For better performance, install the package with: `pip install huggingface_hub[hf_xet]` or `pip install hf_xet`


(…)-00012-of-00013-73b855ce7d233beb.parquet:   0%|          | 0.00/471M [00:00<?, ?B/s]

Xet Storage is enabled for this repo, but the 'hf_xet' package is not installed. Falling back to regular HTTP download. For better performance, install the package with: `pip install huggingface_hub[hf_xet]` or `pip install hf_xet`


(…)-00000-of-00002-0e1a29e0620125c6.parquet:   0%|          | 0.00/383M [00:00<?, ?B/s]

Xet Storage is enabled for this repo, but the 'hf_xet' package is not installed. Falling back to regular HTTP download. For better performance, install the package with: `pip install huggingface_hub[hf_xet]` or `pip install hf_xet`


(…)-00001-of-00002-aad8011eb887c9d9.parquet:   0%|          | 0.00/385M [00:00<?, ?B/s]

Xet Storage is enabled for this repo, but the 'hf_xet' package is not installed. Falling back to regular HTTP download. For better performance, install the package with: `pip install huggingface_hub[hf_xet]` or `pip install hf_xet`


(…)-00000-of-00002-bb04e6313f58efa0.parquet:   0%|          | 0.00/376M [00:00<?, ?B/s]

Xet Storage is enabled for this repo, but the 'hf_xet' package is not installed. Falling back to regular HTTP download. For better performance, install the package with: `pip install huggingface_hub[hf_xet]` or `pip install hf_xet`


(…)-00001-of-00002-3bfa172e8818685a.parquet:   0%|          | 0.00/375M [00:00<?, ?B/s]

Generating train split:   0%|          | 0/262144 [00:00<?, ? examples/s]

Generating valid split:   0%|          | 0/32768 [00:00<?, ? examples/s]

Generating test split:   0%|          | 0/32768 [00:00<?, ? examples/s]

Label=False → enoma - Tumor tissue consisting of medium-sized and irregular glandular ducts fused and infiltrated is observed in the superficial epithelium. Tumor cells are highly columnar, with nuclei aligned basolaterally and polarized.
Label=True → enocarcinoma - In the superficial epithelium, tumor tissue consisting of large irregular glandular ducts partially fused and infiltrated is observed. Some tumor cells exhibit marked fusion or clustering. tumor cells exhibit a fusion-like shape and exhibit hyperchromatism.
Label=False → enocarcinoma - Tumor tissue consisting of cord-like or small irregular glandular ducts fused and infiltrated is observed in the superficial epithelium. Tumor cells are highly columnar and show nucleus.
Label=True → enocarcinoma - On the superficial epithelium, tumor tissue consisting of medium-sized and irregular glandular ducts infiltrating is observed. Well differentiated tubular adenocarcinoma
Label=True →  solid type - Tumor tissue consisting of cord-li

In [None]:
df.to_pickle("/content/drive/MyDrive/patchcamelyon_captions_df_testdata500.pkl")

In [None]:
testdf = pd.read_pickle("/content/drive/MyDrive/patchcamelyon_captions_df_testdata500.pkl")

In [None]:
testdf.head()

Unnamed: 0,image,label,caption
0,<PIL.PngImagePlugin.PngImageFile image mode=RG...,False,enoma - Tumor tissue consisting of medium-size...
1,<PIL.PngImagePlugin.PngImageFile image mode=RG...,True,"enocarcinoma - In the superficial epithelium, ..."
2,<PIL.PngImagePlugin.PngImageFile image mode=RG...,False,enocarcinoma - Tumor tissue consisting of cord...
3,<PIL.PngImagePlugin.PngImageFile image mode=RG...,True,"enocarcinoma - On the superficial epithelium, ..."
4,<PIL.PngImagePlugin.PngImageFile image mode=RG...,True,solid type - Tumor tissue consisting of cord-...


In [None]:
from transformers import LlavaForConditionalGeneration, LlavaProcessor, CLIPProcessor, CLIPModel, BlipProcessor, BlipForConditionalGeneration

In [19]:
hf_token = "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"

In [None]:
img_embs = []
for img in tqdm(testdf["image"], desc="Embedding images"):
    inputs = clip_processor(images=img, return_tensors="pt").to("cuda")
    with torch.no_grad():
        feat = clip_model.get_image_features(**inputs)  # (1, 768)
    feat = feat / feat.norm(dim=-1, keepdim=True)
    img_embs.append(feat.cpu().numpy())

img_embs = np.vstack(img_embs)  # (N, 768)

txt_embs = text_embedder.encode(
    testdf["caption"].tolist(),
    batch_size=64,
    show_progress_bar=True,
    convert_to_numpy=True,
)

Embedding images:   0%|          | 0/500 [00:00<?, ?it/s]

Batches:   0%|          | 0/8 [00:00<?, ?it/s]

In [None]:
import pickle

images = testdf["image"].tolist()
# Extract relevant parts from DataFrame
captions = testdf["caption"].tolist()
labels = testdf["label"].tolist()  # Make sure 'label' column exists

# Bundle everything into a dictionary
data_to_save = {
    "images": images,
    "captions": captions,
    "labels": labels,
    "image_embeddings": img_embs,
    "text_embeddings": txt_embs
}

# Save to pickle file
with open("embeddings_and_labels.pkl", "wb") as f:
    pickle.dump(data_to_save, f)


In [None]:

with open("/content/drive/MyDrive/patchcamelyon_embeddings_testdata500.pkl", "wb") as f:
  pickle.dump(data_to_save, f)

### Start from here inference

In [5]:
import os, json, pickle
import numpy as np
import faiss

In [6]:
# 1) Paths for your new RAG store
BASE_PATH  = "/content/drive/MyDrive/pcam_rag_index_withlables_new"
INDEX_PATH = os.path.join(BASE_PATH, "caption_image_index.faiss")
META_PATH  = os.path.join(BASE_PATH, "metadata.jsonl")

In [7]:
index = faiss.read_index(INDEX_PATH)
docs  = [json.loads(line) for line in open(META_PATH, "r")]
print(f"Index: {index.ntotal} vectors; metadata entries: {len(docs)}")

Index: 14311 vectors; metadata entries: 14311


In [8]:
PICKLE_IN = "/content/drive/MyDrive/patchcamelyon_embeddings_testdata500.pkl"
with open(PICKLE_IN, "rb") as f:
    test_data = pickle.load(f)

In [9]:
test_images   = test_data["images"]              # list of PIL.Image
test_captions = test_data["captions"]            # list of str
test_labels   = test_data["labels"]              # list of int
test_img_embs = np.array(test_data["image_embeddings"], dtype=np.float32)
test_txt_embs = np.array(test_data["text_embeddings"], dtype=np.float32)

In [10]:
faiss.normalize_L2(test_img_embs)
faiss.normalize_L2(test_txt_embs)

In [15]:
alpha = 1.0          # weight for image
beta  = 1.0 - alpha  # weight for text
k     = 10

In [16]:
img_dim = test_img_embs.shape[1]

In [17]:
rag_results = []
for i in range(len(test_img_embs)):
    # weighted fusion
    q_img = test_img_embs[i] * alpha
    q_txt = test_txt_embs[i] * beta
    q = np.concatenate([q_img, q_txt], axis=0).reshape(1, -1)
    faiss.normalize_L2(q)

    # search
    D, I = index.search(q, k)
    neighs = []
    for score, idx in zip(D[0], I[0]):
        entry = docs[idx].copy()
        neighs.append({
            "id":      entry.get("id"),
            "caption": entry.get("caption"),
            "label":   entry.get("label"),
            "score":   float(score),
            # If you stored the combined embeddings you could also
            # reconstruct them here with index.reconstruct(idx)
        })
    rag_results.append(neighs)

print(f"Retrieved top-{k} neighbors for each of {len(rag_results)} test samples")

Retrieved top-10 neighbors for each of 500 test samples


In [14]:
import json
import torch
from PIL import Image
from transformers import (
    LlavaForConditionalGeneration,
    LlavaProcessor,
    BitsAndBytesConfig,
)
from sentence_transformers import SentenceTransformer
import faiss, numpy as np, pickle, os

In [20]:
DEVICE   = "cuda"

bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_compute_dtype=torch.float16,
    llm_int8_threshold=6.0,
    llm_int8_has_fp16_weight=False,
)

model = LlavaForConditionalGeneration.from_pretrained(
    "llava-hf/llava-1.5-7b-hf",
    quantization_config=bnb_config,
    torch_dtype=torch.float16,
    device_map="auto",
    token=hf_token,
).eval().to(DEVICE)

processor = LlavaProcessor.from_pretrained(
    "llava-hf/llava-1.5-7b-hf",
    token=hf_token
)


config.json:   0%|          | 0.00/950 [00:00<?, ?B/s]

model.safetensors.index.json:   0%|          | 0.00/70.1k [00:00<?, ?B/s]

Fetching 3 files:   0%|          | 0/3 [00:00<?, ?it/s]

model-00003-of-00003.safetensors:   0%|          | 0.00/4.18G [00:00<?, ?B/s]

model-00002-of-00003.safetensors:   0%|          | 0.00/4.96G [00:00<?, ?B/s]

model-00001-of-00003.safetensors:   0%|          | 0.00/4.99G [00:00<?, ?B/s]

Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

generation_config.json:   0%|          | 0.00/141 [00:00<?, ?B/s]

preprocessor_config.json:   0%|          | 0.00/505 [00:00<?, ?B/s]

Using a slow image processor as `use_fast` is unset and a slow processor was saved with this model. `use_fast=True` will be the default behavior in v4.52, even if the model was saved with a slow processor. This will result in minor differences in outputs. You'll still be able to use a slow processor with `use_fast=False`.


tokenizer_config.json:   0%|          | 0.00/1.45k [00:00<?, ?B/s]

tokenizer.model:   0%|          | 0.00/500k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/3.62M [00:00<?, ?B/s]

added_tokens.json:   0%|          | 0.00/41.0 [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/552 [00:00<?, ?B/s]

processor_config.json:   0%|          | 0.00/173 [00:00<?, ?B/s]

chat_template.json:   0%|          | 0.00/701 [00:00<?, ?B/s]

In [None]:
system_msg = """you are a medical assistant trained to classify histopathologic images as tumor (label 1) or normal (label 0).
Just give 0 or 1 as response. 0 if there is no cancer, 1 if cancer.
Dont give any explanations.
Think step-by-step based on cellular structure and pathology cues.
Then output your final answer in the format:<label>"""
user_step1 = "Step 1: Study the following similar examples and their labels."


In [21]:
def make_prompt(test_caption, neighbors):
    # 1) Instruction
    instr = (
        "<image>\n"
        "### Instruction:\n"
        "You are a medical assistant trained to classify histopathologic images as tumor (1) or normal (0).\n"
        "Only respond with a single digit (0 or 1), no explanations.\n"
        "Think step-by-step based on cellular structure and pathology cues.\n\n"
    )

    # 2) Few‐shot examples
    ex_str = "### Examples:\n"
    for n in neighbors:
        ex_str += (
            f"Caption: {n['caption']}\n"
            f"Answer: {n['label']}\n\n"
        )

    # 3) Query
    query = (
        "### Query:\n"
        f"Caption: {test_caption}\n"
        "Answer:"
    )

    return instr + ex_str + query


In [22]:
@torch.no_grad()
def predict_with_few_shot(idx):
    # retrieve neighbors as before…
    neighbors = rag_results[idx]  # list of top-10 dicts

    prompt = make_prompt(test_captions[idx], neighbors)

    inputs = processor(
        images=test_images[idx],
        text=prompt,
        return_tensors="pt",
        padding=True,
        truncation=False,   # ensure the <image> token is never dropped
    ).to(DEVICE)

    gen_ids = model.generate(
    input_ids=inputs.input_ids,
    attention_mask=inputs.attention_mask,
    pixel_values=inputs.pixel_values,
    max_new_tokens=2,                # allow up to “<digit>” + EOS
    num_beams=3,                     # beam search for a cleaner single‐token answer
    do_sample=False,
    pad_token_id=processor.tokenizer.pad_token_id,
    eos_token_id=processor.tokenizer.eos_token_id,
)
    input_len  = inputs.input_ids.shape[-1]
    new_tokens = gen_ids[0, input_len:].cpu().tolist()
    out_text   = processor.tokenizer.decode(new_tokens, skip_special_tokens=True).strip()
    # take the first “word” in case there’s a trailing newline
    pred       = out_text.split()[0] if out_text else None
    return pred


In [None]:
preds = []
for i in range(len(test_images)):
    p = predict_with_few_shot(i)
    print(p)
    preds.append(int(p) if p.isdigit() else None)
    print(f"[{i}] True={test_labels[i]} → Pred={preds[-1]}")

acc = sum(p==t for p,t in zip(preds, test_labels)) / len(test_labels)
print(f"\n10-shot RAG→LLaVA accuracy: {acc:.1%}")


1
[0] True=False → Pred=1
1
[1] True=True → Pred=1
1
[2] True=False → Pred=1
1
[3] True=True → Pred=1
1
[4] True=True → Pred=1
1
[5] True=False → Pred=1
1
[6] True=True → Pred=1
1
[7] True=True → Pred=1
1
[8] True=True → Pred=1
0
[9] True=False → Pred=0
1
[10] True=True → Pred=1
0
[11] True=False → Pred=0
1
[12] True=True → Pred=1
1
[13] True=True → Pred=1
1
[14] True=False → Pred=1
1
[15] True=True → Pred=1
1
[16] True=False → Pred=1
1
[17] True=True → Pred=1
0
[18] True=True → Pred=0
1
[19] True=True → Pred=1
1
[20] True=True → Pred=1
1
[21] True=False → Pred=1
0
[22] True=False → Pred=0
0
[23] True=False → Pred=0
0
[24] True=False → Pred=0
0
[25] True=True → Pred=0
1
[26] True=True → Pred=1
1
[27] True=False → Pred=1
1
[28] True=True → Pred=1
1
[29] True=True → Pred=1
1
[30] True=True → Pred=1
1
[31] True=True → Pred=1
1
[32] True=True → Pred=1
1
[33] True=True → Pred=1
1
[34] True=True → Pred=1
1
[35] True=True → Pred=1
0
[36] True=True → Pred=0
1
[37] True=False → Pred=1
0
[38] Tr

In [None]:
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score

# If you have any None‐predictions, filter them out
valid_idx = [i for i, p in enumerate(preds) if p is not None]
y_true = [test_labels[i] for i in valid_idx]
y_pred = [preds[i]      for i in valid_idx]

# Compute metrics
acc  = accuracy_score(y_true, y_pred)
prec = precision_score(y_true, y_pred, zero_division=0)
rec  = recall_score(y_true, y_pred, zero_division=0)
f1   = f1_score(y_true, y_pred, zero_division=0)

print(f"Accuracy : {acc:.4f}")
print(f"Precision: {prec:.4f}")
print(f"Recall   : {rec:.4f}")
print(f"F1-score : {f1:.4f}")


Accuracy : 0.6940
Precision: 0.6863
Recall   : 0.7865
F1-score : 0.7330


In [23]:
preds = []
for i in range(len(test_images)):
    p = predict_with_few_shot(i)
    print(p)
    preds.append(int(p) if p.isdigit() else None)
    print(f"[{i}] True={test_labels[i]} → Pred={preds[-1]}")

acc = sum(p==t for p,t in zip(preds, test_labels)) / len(test_labels)
print(f"\n10-shot RAG→LLaVA accuracy: {acc:.1%}")


0
[0] True=False → Pred=0
1
[1] True=True → Pred=1
0
[2] True=False → Pred=0
1
[3] True=True → Pred=1
0
[4] True=True → Pred=0
1
[5] True=False → Pred=1
1
[6] True=True → Pred=1
1
[7] True=True → Pred=1
1
[8] True=True → Pred=1
0
[9] True=False → Pred=0
0
[10] True=True → Pred=0
1
[11] True=False → Pred=1
1
[12] True=True → Pred=1
1
[13] True=True → Pred=1
0
[14] True=False → Pred=0
0
[15] True=True → Pred=0
1
[16] True=False → Pred=1
1
[17] True=True → Pred=1
1
[18] True=True → Pred=1
0
[19] True=True → Pred=0
1
[20] True=True → Pred=1
0
[21] True=False → Pred=0
0
[22] True=False → Pred=0
0
[23] True=False → Pred=0
0
[24] True=False → Pred=0
0
[25] True=True → Pred=0
1
[26] True=True → Pred=1
1
[27] True=False → Pred=1
1
[28] True=True → Pred=1
1
[29] True=True → Pred=1
0
[30] True=True → Pred=0
0
[31] True=True → Pred=0
1
[32] True=True → Pred=1
1
[33] True=True → Pred=1
1
[34] True=True → Pred=1
1
[35] True=True → Pred=1
0
[36] True=True → Pred=0
1
[37] True=False → Pred=1
1
[38] Tr

In [24]:
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score

# If you have any None‐predictions, filter them out
valid_idx = [i for i, p in enumerate(preds) if p is not None]
y_true = [test_labels[i] for i in valid_idx]
y_pred = [preds[i]      for i in valid_idx]

# Compute metrics
acc  = accuracy_score(y_true, y_pred)
prec = precision_score(y_true, y_pred, zero_division=0)
rec  = recall_score(y_true, y_pred, zero_division=0)
f1   = f1_score(y_true, y_pred, zero_division=0)

print(f"Accuracy : {acc:.4f}")
print(f"Precision: {prec:.4f}")
print(f"Recall   : {rec:.4f}")
print(f"F1-score : {f1:.4f}")


Accuracy : 0.7120
Precision: 0.6916
Recall   : 0.8315
F1-score : 0.7551


In [31]:
alpha = 5.0          # weight for image
beta  = 1.0 - alpha  # weight for text
k     = 10

In [32]:
rag_results = []
for i in range(len(test_img_embs)):
    # weighted fusion
    q_img = test_img_embs[i] * alpha
    q_txt = test_txt_embs[i] * beta
    q = np.concatenate([q_img, q_txt], axis=0).reshape(1, -1)
    faiss.normalize_L2(q)

    # search
    D, I = index.search(q, k)
    neighs = []
    for score, idx in zip(D[0], I[0]):
        entry = docs[idx].copy()
        neighs.append({
            "id":      entry.get("id"),
            "caption": entry.get("caption"),
            "label":   entry.get("label"),
            "score":   float(score),
            # If you stored the combined embeddings you could also
            # reconstruct them here with index.reconstruct(idx)
        })
    rag_results.append(neighs)

print(f"Retrieved top-{k} neighbors for each of {len(rag_results)} test samples")

Retrieved top-10 neighbors for each of 500 test samples


In [33]:
preds = []
for i in range(len(test_images)):
    p = predict_with_few_shot(i)
    print(p)
    preds.append(int(p) if p.isdigit() else None)
    print(f"[{i}] True={test_labels[i]} → Pred={preds[-1]}")

acc = sum(p==t for p,t in zip(preds, test_labels)) / len(test_labels)
print(f"\n10-shot RAG→LLaVA accuracy: {acc:.1%}")


1
[0] True=False → Pred=1
1
[1] True=True → Pred=1
1
[2] True=False → Pred=1
1
[3] True=True → Pred=1
1
[4] True=True → Pred=1
1
[5] True=False → Pred=1
1
[6] True=True → Pred=1
1
[7] True=True → Pred=1
1
[8] True=True → Pred=1
1
[9] True=False → Pred=1
1
[10] True=True → Pred=1
1
[11] True=False → Pred=1
1
[12] True=True → Pred=1
1
[13] True=True → Pred=1
1
[14] True=False → Pred=1
1
[15] True=True → Pred=1
1
[16] True=False → Pred=1
1
[17] True=True → Pred=1
1
[18] True=True → Pred=1
1
[19] True=True → Pred=1
1
[20] True=True → Pred=1
0
[21] True=False → Pred=0
1
[22] True=False → Pred=1
1
[23] True=False → Pred=1
1
[24] True=False → Pred=1
1
[25] True=True → Pred=1
1
[26] True=True → Pred=1
1
[27] True=False → Pred=1
1
[28] True=True → Pred=1
1
[29] True=True → Pred=1
1
[30] True=True → Pred=1
1
[31] True=True → Pred=1
1
[32] True=True → Pred=1
1
[33] True=True → Pred=1
1
[34] True=True → Pred=1
1
[35] True=True → Pred=1
1
[36] True=True → Pred=1
1
[37] True=False → Pred=1
1
[38] Tr

In [34]:
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score

# If you have any None‐predictions, filter them out
valid_idx = [i for i, p in enumerate(preds) if p is not None]
y_true = [test_labels[i] for i in valid_idx]
y_pred = [preds[i]      for i in valid_idx]

# Compute metrics
acc  = accuracy_score(y_true, y_pred)
prec = precision_score(y_true, y_pred, zero_division=0)
rec  = recall_score(y_true, y_pred, zero_division=0)
f1   = f1_score(y_true, y_pred, zero_division=0)

print(f"Accuracy : {acc:.4f}")
print(f"Precision: {prec:.4f}")
print(f"Recall   : {rec:.4f}")
print(f"F1-score : {f1:.4f}")


Accuracy : 0.5520
Precision: 0.5440
Recall   : 0.9963
F1-score : 0.7037


In [35]:
alpha = 1.0          # weight for image
beta  = 1.0 - alpha  # weight for text
k     = 10

In [36]:
rag_results = []
for i in range(len(test_img_embs)):
    # weighted fusion
    q_img = test_img_embs[i] * alpha
    q_txt = test_txt_embs[i] * beta
    q = np.concatenate([q_img, q_txt], axis=0).reshape(1, -1)
    faiss.normalize_L2(q)

    # search
    D, I = index.search(q, k)
    neighs = []
    for score, idx in zip(D[0], I[0]):
        entry = docs[idx].copy()
        neighs.append({
            "id":      entry.get("id"),
            "caption": entry.get("caption"),
            "label":   entry.get("label"),
            "score":   float(score),
            # If you stored the combined embeddings you could also
            # reconstruct them here with index.reconstruct(idx)
        })
    rag_results.append(neighs)

print(f"Retrieved top-{k} neighbors for each of {len(rag_results)} test samples")

Retrieved top-10 neighbors for each of 500 test samples


In [37]:
preds = []
for i in range(len(test_images)):
    p = predict_with_few_shot(i)
    print(p)
    preds.append(int(p) if p.isdigit() else None)
    print(f"[{i}] True={test_labels[i]} → Pred={preds[-1]}")

acc = sum(p==t for p,t in zip(preds, test_labels)) / len(test_labels)
print(f"\n10-shot RAG→LLaVA accuracy: {acc:.1%}")


0
[0] True=False → Pred=0
1
[1] True=True → Pred=1
0
[2] True=False → Pred=0
1
[3] True=True → Pred=1
0
[4] True=True → Pred=0
1
[5] True=False → Pred=1
1
[6] True=True → Pred=1
1
[7] True=True → Pred=1
1
[8] True=True → Pred=1
0
[9] True=False → Pred=0
0
[10] True=True → Pred=0
1
[11] True=False → Pred=1
1
[12] True=True → Pred=1
1
[13] True=True → Pred=1
0
[14] True=False → Pred=0
0
[15] True=True → Pred=0
1
[16] True=False → Pred=1
1
[17] True=True → Pred=1
1
[18] True=True → Pred=1
0
[19] True=True → Pred=0
1
[20] True=True → Pred=1
0
[21] True=False → Pred=0
0
[22] True=False → Pred=0
0
[23] True=False → Pred=0
0
[24] True=False → Pred=0
0
[25] True=True → Pred=0
1
[26] True=True → Pred=1
1
[27] True=False → Pred=1
1
[28] True=True → Pred=1
1
[29] True=True → Pred=1
0
[30] True=True → Pred=0
0
[31] True=True → Pred=0
1
[32] True=True → Pred=1
1
[33] True=True → Pred=1
1
[34] True=True → Pred=1
1
[35] True=True → Pred=1
0
[36] True=True → Pred=0
1
[37] True=False → Pred=1
1
[38] Tr

In [38]:
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score

# If you have any None‐predictions, filter them out
valid_idx = [i for i, p in enumerate(preds) if p is not None]
y_true = [test_labels[i] for i in valid_idx]
y_pred = [preds[i]      for i in valid_idx]

# Compute metrics
acc  = accuracy_score(y_true, y_pred)
prec = precision_score(y_true, y_pred, zero_division=0)
rec  = recall_score(y_true, y_pred, zero_division=0)
f1   = f1_score(y_true, y_pred, zero_division=0)

print(f"Accuracy : {acc:.4f}")
print(f"Precision: {prec:.4f}")
print(f"Recall   : {rec:.4f}")
print(f"F1-score : {f1:.4f}")


Accuracy : 0.7120
Precision: 0.6916
Recall   : 0.8315
F1-score : 0.7551
