In [1]:
!wget https://ftp.ncbi.nlm.nih.gov/pub/lu/MedCPT/pubmed_embeddings/embeds_chunk_36.npy 
!wget https://ftp.ncbi.nlm.nih.gov/pub/lu/MedCPT/pubmed_embeddings/pmids_chunk_36.json 
!wget https://ftp.ncbi.nlm.nih.gov/pub/lu/MedCPT/pubmed_embeddings/pubmed_chunk_36.json 

--2025-01-09 18:31:56--  https://ftp.ncbi.nlm.nih.gov/pub/lu/MedCPT/pubmed_embeddings/embeds_chunk_36.npy
Resolving ftp.ncbi.nlm.nih.gov (ftp.ncbi.nlm.nih.gov)... 130.14.250.31, 130.14.250.13, 130.14.250.12, ...
Connecting to ftp.ncbi.nlm.nih.gov (ftp.ncbi.nlm.nih.gov)|130.14.250.31|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 3051946112 (2.8G)
Saving to: ‘embeds_chunk_36.npy.1’

embeds_chunk_36.npy   8%[>                   ] 242.20M  92.7MB/s               ^C
--2025-01-09 18:31:59--  https://ftp.ncbi.nlm.nih.gov/pub/lu/MedCPT/pubmed_embeddings/pmids_chunk_36.json
Resolving ftp.ncbi.nlm.nih.gov (ftp.ncbi.nlm.nih.gov)... 130.14.250.7, 130.14.250.11, 130.14.250.10, ...
Connecting to ftp.ncbi.nlm.nih.gov (ftp.ncbi.nlm.nih.gov)|130.14.250.7|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 11921664 (11M) [application/json]
Saving to: ‘pmids_chunk_36.json.1’

pmids_chunk_36.json   0%[                    ]       0  --.-KB/s               ^

In [None]:
# List the working directory to check which downloaded files are present.
!ls


# INSTALLS

In [1]:
!pip install faiss-gpu
!pip install faiss-cpu
!pip install git+https://github.com/huggingface/transformers.git
!pip install Bio
!pip install gradio


Collecting faiss-gpu
  Downloading faiss_gpu-1.7.2-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (1.4 kB)
Downloading faiss_gpu-1.7.2-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (85.5 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m85.5/85.5 MB[0m [31m19.4 MB/s[0m eta [36m0:00:00[0m:00:01[0m00:01[0m
[?25hInstalling collected packages: faiss-gpu
Successfully installed faiss-gpu-1.7.2
Collecting faiss-cpu
  Downloading faiss_cpu-1.9.0.post1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (4.4 kB)
Downloading faiss_cpu-1.9.0.post1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (27.5 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m27.5/27.5 MB[0m [31m59.8 MB/s[0m eta [36m0:00:00[0m:00:01[0m00:01[0m
[?25hInstalling collected packages: faiss-cpu
Successfully installed faiss-cpu-1.9.0.post1
Collecting git+https://github.com/huggingface/transformers.git
  Cloning https://git

In [None]:
# Print the installed package versions so the environment of this run is recorded.
!pip freeze


# HF LOGIN

In [2]:
# Authenticate with the Hugging Face Hub using a token stored as a Kaggle secret,
# so the token never appears in the notebook itself.
# NOTE(review): `HfApi` is imported but not used in the visible cells — confirm
# whether it is needed elsewhere before removing it.
from huggingface_hub import HfApi, login
from kaggle_secrets import UserSecretsClient
user_secrets = UserSecretsClient()
# "hface_read" is the name of the secret to fetch (presumably a read-scoped token — verify).
secret0 = user_secrets.get_secret("hface_read")
login(token=secret0)  

# MAIN

# EMBEDDING VISUAL

In [3]:
import faiss
import torch
import numpy as np
from transformers import AutoTokenizer, AutoModel, pipeline
import json
import gradio as gr
import matplotlib.pyplot as plt
import tempfile
import os

class MedicalRAG:
    """Retrieval-augmented question answering over one chunk of PubMed embeddings.

    The pipeline is: encode the question with the MedCPT query encoder, find the
    nearest document embeddings with an inner-product FAISS index, look up the
    matching articles, and ask a small instruction-tuned language model to answer
    from the top abstracts. A heat-map of the query and retrieved embeddings is
    produced alongside the answer.

    Attributes:
        device: torch device used for the encoder and generator.
        embeddings: document embedding matrix loaded from ``embed_path``.
        index: FAISS inner-product index built over ``embeddings``.
        pmids: sequence mapping an index row to its PMID.
        content: mapping from PMID to an article record with keys
            ``'t'`` (title), ``'d'`` (date) and ``'a'`` (abstract).
        encoder, tokenizer: query encoder model and its tokenizer.
        generator: text-generation pipeline used to write the answer.
    """

    def __init__(self, embed_path, pmids_path, content_path):
        """Load the embedding chunk, build the index and load both models.

        Args:
            embed_path: path to the ``.npy`` embedding matrix.
            pmids_path: path to the JSON file listing PMIDs in row order.
            content_path: path to the JSON file mapping PMID to article record.
        """
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        # Load data
        self.embeddings = np.load(embed_path)
        self.index = self.create_faiss_index(self.embeddings)
        self.pmids, self.content = self.load_json_files(pmids_path, content_path)
        # Setup models
        self.encoder, self.tokenizer = self.setup_encoder()
        self.generator = self.setup_generator()

    def create_faiss_index(self, embeddings):
        """Build an exact inner-product index over ``embeddings``.

        The index dimension is taken from the matrix itself instead of being
        hard-coded, and the data is passed to FAISS as C-contiguous float32
        (no copy is made when it already has that layout).

        Raises:
            ValueError: if ``embeddings`` is not a 2-D matrix.
        """
        matrix = np.ascontiguousarray(embeddings, dtype=np.float32)
        if matrix.ndim != 2:
            raise ValueError(f"Expected a 2-D embedding matrix, got shape {matrix.shape}")
        index = faiss.IndexFlatIP(matrix.shape[1])
        index.add(matrix)
        return index

    def load_json_files(self, pmids_path, content_path):
        """Read the PMID list and the article records; returns ``(pmids, content)``."""
        # JSON is UTF-8 by specification; do not depend on the platform default.
        with open(pmids_path, encoding='utf-8') as f:
            pmids = json.load(f)
        with open(content_path, encoding='utf-8') as f:
            content = json.load(f)
        return pmids, content

    def setup_encoder(self):
        """Load the MedCPT query encoder and tokenizer; returns ``(model, tokenizer)``."""
        model = AutoModel.from_pretrained("ncbi/MedCPT-Query-Encoder").to(self.device)
        tokenizer = AutoTokenizer.from_pretrained("ncbi/MedCPT-Query-Encoder")
        return model, tokenizer

    def setup_generator(self):
        """Create the text-generation pipeline (half precision when running on CUDA)."""
        return pipeline(
            "text-generation",
            model="HuggingFaceTB/SmolLM2-1.7B-Instruct",
            device=self.device,
            torch_dtype=torch.float16 if self.device.type == 'cuda' else torch.float32
        )

    def encode_query(self, query):
        """Embed ``query`` and return a ``(1, dim)`` float32 NumPy array.

        The embedding is the hidden state of the first token of the encoder output.
        """
        with torch.no_grad():
            inputs = self.tokenizer([query], truncation=True, padding=True,
                                    return_tensors='pt', max_length=64).to(self.device)
            embeddings = self.encoder(**inputs).last_hidden_state[:, 0, :]
            # FAISS expects float32 input on the host.
            return embeddings.float().cpu().numpy()

    def search_documents(self, query_embedding, k=8):
        """Return the ``k`` best matches for ``query_embedding``.

        Returns:
            A pair ``(matches, indices)`` where ``matches`` is a list of
            ``(pmid, score)`` tuples and ``indices`` is an int64 array of the
            corresponding rows of the index, in the same order.
        """
        scores, indices = self.index.search(query_embedding, k=k)
        # FAISS pads with -1 when fewer than k results exist; a negative index
        # would otherwise silently select an entry from the end of self.pmids.
        hits = [(int(idx), float(score))
                for idx, score in zip(indices[0], scores[0]) if idx >= 0]
        matches = [(self.pmids[idx], score) for idx, score in hits]
        rows = np.array([idx for idx, _ in hits], dtype=np.int64)
        return matches, rows

    def get_document_content(self, pmid):
        """Return ``{'title', 'date', 'abstract'}`` for ``pmid``.

        Unknown PMIDs and missing or null fields yield empty strings.
        """
        doc = self.content.get(pmid)
        if doc is None:
            # Keys loaded from JSON are always strings; tolerate integer PMIDs.
            doc = self.content.get(str(pmid)) or {}
        return {
            'title': (doc.get('t') or '').strip(),
            'date': (doc.get('d') or '').strip(),
            'abstract': (doc.get('a') or '').strip()
        }

    def visualize_embeddings(self, query_embed, relevant_indices, labels):
        """Render the query and retrieved embeddings as stacked colour strips.

        Args:
            query_embed: array of shape ``(1, dim)``; row 0 is plotted first (top).
            relevant_indices: rows of ``self.embeddings`` to plot below the query.
            labels: one label per plotted row (query first).

        Returns:
            Path of the PNG written to the system temp directory.

        Raises:
            ValueError: if the number of labels differs from the number of rows.
        """
        rows = np.asarray(relevant_indices, dtype=np.int64).reshape(-1)

        # Stack and normalize embeddings to [-1, 1]
        embeddings = np.vstack([query_embed[0], self.embeddings[rows]])
        max_abs_val = np.max(np.abs(embeddings))
        # Guard against an all-zero stack, which would otherwise produce NaNs.
        normalized_embeddings = embeddings / (max_abs_val if max_abs_val > 0 else 1.0)

        total_height = len(normalized_embeddings)
        if total_height != len(labels):
            raise ValueError(f"Mismatch: {total_height} embeddings vs {len(labels)} labels")
        dim = normalized_embeddings.shape[1]

        fig, ax = plt.subplots(figsize=(20, len(rows) + 1))
        try:
            image = None
            # One strip per embedding, drawn top to bottom; each strip is 0.8
            # units tall so a small gap separates neighbouring rows.
            for idx in range(total_height):
                y_pos = total_height - idx - 1
                image = ax.imshow(normalized_embeddings[idx].reshape(1, -1), aspect='auto',
                                  extent=[0, dim, y_pos, y_pos + 0.8],
                                  cmap='viridis', vmin=-1, vmax=1)

            # Every imshow call touches the axis limits, so fix them explicitly
            # to keep all strips in view.
            ax.set_xlim(0, dim)
            ax.set_ylim(0, total_height)
            # Ticks sit at the vertical centre of each strip. Computing them per
            # row gives exactly one tick per label (a float arange may not).
            ax.set_yticks([total_height - idx - 1 + 0.4 for idx in range(total_height)])
            ax.set_yticklabels(list(labels))
            ax.set_xlabel('Embedding Dimensions')
            fig.colorbar(image, ax=ax, label='Normalized Value')
            ax.set_title('Query and Retrieved Document Embeddings')

            # mkstemp guarantees a unique file name; hashing the array's
            # (truncated) string form could collide between different queries.
            fd, temp_path = tempfile.mkstemp(prefix='embeddings_', suffix='.png')
            os.close(fd)
            fig.savefig(temp_path, bbox_inches='tight', dpi=150)
        finally:
            # Always release the figure, even if plotting or saving fails.
            plt.close(fig)
        return temp_path

    def generate_answer(self, query, contexts):
        """Ask the generator to answer ``query`` from the ``contexts`` text.

        The prompt is written in the chat-markup form used below; the returned
        string is whatever follows the last assistant marker in the output.
        """
        prompt = (
            "<|im_start|>system\n"
            "You are a helpful medical assistant. Answer questions based on the provided literature."
            "<|im_end|>\n<|im_start|>user\n"
            f"Based on these medical articles, answer this question:\n\n"
            f"Question: {query}\n\n"
            f"Relevant Literature:\n{contexts}\n"
            "<|im_end|>\n<|im_start|>assistant"
        )

        response = self.generator(
            prompt,
            max_new_tokens=200,
            temperature=0.3,
            top_p=0.95,
            do_sample=True
        )
        # The pipeline returns prompt + completion; keep only the completion.
        return response[0]['generated_text'].split("<|im_start|>assistant")[-1].strip()

    def process_query(self, query):
        """Run the full pipeline for one question.

        Returns:
            ``(answer, sources_text, context, visualization_path)``. On failure
            the error message is returned in place of the answer and the
            visualization is ``None`` so the UI can still render a response.
        """
        if not query or not query.strip():
            return "Please enter a question.", "", "", None

        try:
            # Encode and search
            query_embed = self.encode_query(query)
            doc_matches, indices = self.search_documents(query_embed)

            # Prepare documents and labels
            documents = []
            sources = []
            labels = ["Query"]
            # Rows actually shown; must stay aligned one-to-one with labels[1:].
            kept_indices = []

            for (pmid, score), row in zip(doc_matches, indices):
                doc = self.get_document_content(pmid)
                if not doc['abstract']:
                    # Skipped documents must not be plotted either, otherwise
                    # the plot has more rows than labels.
                    continue
                documents.append(f"Title: {doc['title']}\nAbstract: {doc['abstract']}")
                sources.append(f"PMID: {pmid}, Score: {score:.3f}, Link: https://pubmed.ncbi.nlm.nih.gov/{pmid}/")
                labels.append(f"Doc {len(labels)}: {doc['title'][:30]}...")
                kept_indices.append(int(row))

            # Generate outputs
            visualization = self.visualize_embeddings(query_embed, kept_indices, labels)
            # Only the three best abstracts are given to the generator.
            answer = self.generate_answer(query, "\n\n".join(documents[:3]))
            sources_text = "\n".join(sources)
            context = "\n\n".join(documents)

            return answer, sources_text, context, visualization

        except Exception as e:
            # Deliberate best-effort: surface the error in the UI instead of crashing it.
            print(f"Error: {str(e)}")
            return str(e), "Error retrieving sources", "", None

In [4]:
def create_interface():
    """Assemble the Gradio front end for the medical literature QA demo.

    A MedicalRAG engine is built over the chunk-36 files in the working
    directory, then wired to a two-column layout: question, submit button,
    sources and embedding plot on the left; answer and context on the right.
    A row of example questions is placed underneath.

    Returns:
        The configured ``gr.Blocks`` application (not yet launched).
    """
    engine = MedicalRAG(
        embed_path="embeds_chunk_36.npy",
        pmids_path="pmids_chunk_36.json",
        content_path="pubmed_chunk_36.json",
    )

    # Each example is a one-element list because the examples feed a single input.
    sample_questions = [
        ["What are the latest treatments for diabetes?"],
        ["How effective are COVID-19 vaccines?"],
        ["What are common symptoms of the flu?"],
        ["How can I maintain good heart health?"],
    ]

    with gr.Blocks(title="Medical Literature QA") as ui:
        gr.Markdown("# Medical Literature Question Answering")

        with gr.Row():
            # Left column: input controls, sources and the embedding plot.
            with gr.Column():
                question_box = gr.Textbox(
                    lines=2,
                    placeholder="Enter your medical question...",
                    label="Question",
                )
                ask_button = gr.Button("Submit", variant="primary")
                sources_box = gr.Textbox(label="Sources", lines=3)
                embedding_image = gr.Image(label="Embedding Visualization")
            # Right column: generated answer and the retrieved context.
            with gr.Column():
                answer_box = gr.Textbox(label="Answer", lines=5)
                context_box = gr.Textbox(label="Context", lines=6)

        with gr.Row():
            gr.Examples(examples=sample_questions, inputs=question_box)

        # Output order must match the tuple returned by MedicalRAG.process_query.
        result_components = [answer_box, sources_box, context_box, embedding_image]
        ask_button.click(
            fn=engine.process_query,
            inputs=question_box,
            outputs=result_components,
        )

    return ui

# Build the interface and start serving it when this cell/script is run directly.
if __name__ == "__main__":
    demo = create_interface()
    # share=True also requests a temporary public URL in addition to the local one.
    demo.launch(share=True)

config.json:   0%|          | 0.00/608 [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/438M [00:00<?, ?B/s]

tokenizer_config.json:   0%|          | 0.00/1.49k [00:00<?, ?B/s]

vocab.txt:   0%|          | 0.00/226k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/706k [00:00<?, ?B/s]

added_tokens.json:   0%|          | 0.00/74.0 [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/125 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/792 [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/3.42G [00:00<?, ?B/s]

generation_config.json:   0%|          | 0.00/132 [00:00<?, ?B/s]

tokenizer_config.json:   0%|          | 0.00/3.76k [00:00<?, ?B/s]

vocab.json:   0%|          | 0.00/801k [00:00<?, ?B/s]

merges.txt:   0%|          | 0.00/466k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/2.10M [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/655 [00:00<?, ?B/s]

Device set to use cuda


* Running on local URL:  http://127.0.0.1:7860
* Running on public URL: https://cb781ee6ea1b0f2b69.gradio.live

This share link expires in 72 hours. For free permanent hosting and GPU upgrades, run `gradio deploy` from the terminal in the working directory to deploy to Hugging Face Spaces (https://huggingface.co/spaces)
