In [None]:
# Imports
import math
import re
import unicodedata
from collections import Counter

import ipywidgets as widgets
import numpy as np
import pandas as pd
import torch
from IPython.display import display, clear_output
from sklearn.metrics.pairwise import cosine_similarity
from transformers import AutoTokenizer, AutoModel

# CONFIG -- every value a reader might want to tune lives here.

# Input CSV read by the data-loading cell.
CSV_FILENAME = 'drug_data.csv'

# Columns whose text is joined, cleaned and embedded for the similarity search.
TEXT_COLUMNS_FOR_EMBEDDING = ['Cancer Type']

# Outcome columns displayed next to each search hit.
OUTCOME_COLUMNS = [
    'Treatment_OS',
    'Control_OS',
    'OS_Improvement (%)',
    'Treatment_PFS',
    'Control_PFS',
    'PFS_Improvement (%)',
]

# NOTE(review): not referenced by any visible cell -- presumably intended as a
# minimum cosine similarity for search hits; confirm or remove.
RELEVANCE_SCORE_THRESHOLD = 0.5

# Hugging Face hub id of the encoder used for both data rows and user queries.
HF_MODEL_NAME = 'dmis-lab/biobert-v1.1'

# Prefer the GPU when one is visible to torch; otherwise fall back to CPU.
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'



In [None]:
# Load model and tokenizer once; the embedding and search cells below reuse them.
print("\nLoading Hugging Face model...")
try:
    tokenizer = AutoTokenizer.from_pretrained(HF_MODEL_NAME)
    model = AutoModel.from_pretrained(HF_MODEL_NAME)
    model = model.to(DEVICE)
    # Inference only: eval() disables dropout so repeated embeddings agree.
    model.eval()
except Exception as e:
    # Report the cause, then re-raise so "Run All" stops here rather than
    # failing later on an undefined `model` / `tokenizer`.
    print(f"Error loading model: {e}")
    raise
else:
    print("Model loaded successfully.")



In [None]:
# Helper Functions
def clean_text(text):
    """Normalise free text before embedding or matching.

    Steps:
      1. Unicode NFKD-fold and drop non-ASCII code points, so accented letters
         keep their base letter ("Sjögren" -> "sjogren").
      2. Lower-case.
      3. Delete apostrophes ("hodgkin's" -> "hodgkins").
      4. Replace every remaining character that is not a-z, 0-9, whitespace or
         '-' with a space, so separators such as '/' or ',' split words instead
         of gluing them together ("breast/ovarian" -> "breast ovarian").
      5. Collapse whitespace runs and strip the ends.

    Args:
        text: value to clean; anything that is not a ``str`` (None, NaN from an
            empty CSV cell, numbers) is treated as empty.

    Returns:
        The cleaned string, or '' for non-string input.
    """
    if not isinstance(text, str):
        return ''
    folded = unicodedata.normalize('NFKD', text).encode('ascii', 'ignore').decode('ascii')
    lowered = folded.lower().replace("'", '')
    spaced = re.sub(r'[^a-z0-9\s-]', ' ', lowered)
    return re.sub(r'\s+', ' ', spaced).strip()

def parse_time_to_months(time_str):
    """Convert a survival-time cell into a number of months.

    Accepted inputs:
      * numbers (Python or NumPy ints/floats), taken to already be in months;
      * strings starting with "<number> month(s)" or "<number> year(s)"
        (case-insensitive; years are multiplied by 12);
      * bare numeric strings such as "14.2", treated like numeric cells.

    Args:
        time_str: the raw cell value.

    Returns:
        The duration in months as a float, or None when the value is missing
        (None, NaN), a placeholder ("n/a", "not applicable", "not reported",
        "not reached", "nr"), or not parseable.
    """
    if isinstance(time_str, (int, float, np.integer, np.floating)):
        value = float(time_str)
        # Empty CSV cells arrive as float NaN; report them as missing like the
        # textual placeholders instead of leaking NaN downstream.
        return None if math.isnan(value) else value
    if not isinstance(time_str, str):
        return None
    time_str = time_str.strip().lower()
    if time_str in ['n/a', 'not applicable', 'not reported', 'not reached', 'nr']:
        return None
    match = re.match(r'(\d+(?:\.\d+)?)\s*(month|year)s?', time_str)
    if match:
        value = float(match.group(1))
        return value * 12 if match.group(2) == 'year' else value
    # A bare number stored as text is interpreted the same way as a numeric cell.
    if re.fullmatch(r'\d+(?:\.\d+)?', time_str):
        return float(time_str)
    return None

def parse_improvement_percentage(perc_str):
    """Convert an improvement-percentage cell into a float.

    Accepted inputs:
      * numbers (Python or NumPy ints/floats), returned as float;
      * strings starting with an optionally signed number followed by '%',
        e.g. "25%", "-12.5 %", "+8%", "25% (HR 0.7)"; the Unicode minus
        sign (U+2212) is accepted as '-';
      * bare, optionally signed numeric strings such as "40".

    Args:
        perc_str: the raw cell value.

    Returns:
        The percentage as a float (negative values mean the treatment arm did
        worse), or None when the value is missing (None, NaN), a placeholder
        ("n/a", "not statistically significant", "not reported"), or not
        parseable.
    """
    if isinstance(perc_str, (int, float, np.integer, np.floating)):
        value = float(perc_str)
        # Empty CSV cells arrive as float NaN; report them as missing.
        return None if math.isnan(value) else value
    if not isinstance(perc_str, str):
        return None
    perc_str = perc_str.replace('\u2212', '-').strip().lower()
    if perc_str in ['n/a', 'not statistically significant', 'not reported']:
        return None
    # The optional sign keeps a worse outcome such as "-12%" instead of
    # silently discarding it.
    match = re.match(r'([+-]?\d+(?:\.\d+)?)\s*%', perc_str)
    if match:
        return float(match.group(1))
    if re.fullmatch(r'[+-]?\d+(?:\.\d+)?', perc_str):
        return float(perc_str)
    return None

def get_mean_pooling_embedding(text, tokenizer, model, device='cpu'):
    """Embed one string as the attention-masked mean of the last hidden layer.

    Args:
        text: the string to embed; tokenized with truncation to 512 tokens.
        tokenizer: Hugging Face-style tokenizer, called with return_tensors='pt'.
        model: encoder whose output exposes ``last_hidden_state``.
        device: torch device the encoded inputs are moved to before the forward pass.

    Returns:
        1-D NumPy array: the pooled vector of the first (only) input in the batch.
    """
    encoded = tokenizer(text, return_tensors='pt', truncation=True, padding=True, max_length=512)
    batch = {name: tensor.to(device) for name, tensor in encoded.items()}
    with torch.no_grad():
        hidden = model(**batch).last_hidden_state
    # Broadcast the attention mask over the hidden dimension so padding
    # positions contribute nothing to the sum.
    weights = batch['attention_mask'].unsqueeze(-1).expand(hidden.size()).float()
    token_total = (hidden * weights).sum(dim=1)
    # Guard against division by zero for an all-masked row.
    token_count = weights.sum(1).clamp(min=1e-9)
    pooled = token_total / token_count
    return pooled.cpu().numpy()[0]

def embed_texts(text_list, tokenizer, model, device='cpu', batch_size=32):
    """Embed many strings, one attention-masked mean-pooled vector per input.

    Texts are tokenized and encoded ``batch_size`` at a time, so the encoder
    runs one forward pass per batch instead of one per string. Padding added
    inside a batch is excluded from each row's mean via the attention mask.
    For encoders that honour ``attention_mask`` (e.g. BERT), each row therefore
    matches a single-string encoding up to floating-point rounding.

    Args:
        text_list: iterable of strings to embed; each is truncated to 512 tokens.
        tokenizer: Hugging Face-style tokenizer accepting a list of strings.
        model: encoder whose output exposes ``last_hidden_state``.
        device: torch device the encoded inputs are moved to.
        batch_size: number of texts per forward pass (must be >= 1).

    Returns:
        2-D NumPy array with one row per input text; shape (0, 0) when
        ``text_list`` is empty.

    Raises:
        ValueError: if ``batch_size`` is less than 1.
    """
    if batch_size < 1:
        raise ValueError(f"batch_size must be >= 1, got {batch_size}")
    texts = list(text_list)
    if not texts:
        return np.empty((0, 0), dtype=np.float32)

    pooled_batches = []
    for start in range(0, len(texts), batch_size):
        chunk = texts[start:start + batch_size]
        inputs = tokenizer(chunk, return_tensors='pt', truncation=True, padding=True, max_length=512)
        inputs = {k: v.to(device) for k, v in inputs.items()}
        with torch.no_grad():
            last_hidden = model(**inputs).last_hidden_state
        # Zero out padding positions so shorter texts in the batch are unaffected.
        mask = inputs['attention_mask'].unsqueeze(-1).expand(last_hidden.size()).float()
        summed = (last_hidden * mask).sum(dim=1)
        counts = mask.sum(dim=1).clamp(min=1e-9)
        pooled_batches.append((summed / counts).cpu().numpy())
    return np.concatenate(pooled_batches, axis=0)


In [None]:
# Load and prepare data
print("\nLoading and cleaning data...")
# Reuse the device chosen in the config cell instead of re-deriving it; the
# lowercase name is kept because get_top_matches reads `device`.
device = DEVICE
df = pd.read_csv(CSV_FILENAME)

# One cleaned string per row built from the configured text columns.
# astype(str) keeps ' '.join working if a configured column is numeric.
df['combined_text_cleaned_for_embedding'] = (
    df[TEXT_COLUMNS_FOR_EMBEDDING]
    .fillna('')
    .astype(str)
    .agg(' '.join, axis=1)
    .map(clean_text)
)

# Generate embeddings
print("\nGenerating embeddings...")
drug_embeddings = embed_texts(df['combined_text_cleaned_for_embedding'].tolist(), tokenizer, model, device)
# The similarity lookup assumes row i of drug_embeddings describes row i of df.
assert len(drug_embeddings) == len(df), "embedding count does not match row count"
print("Embeddings generated.")

# User query + similarity
def get_top_matches(user_query, top_k=5):
    """Return the ``top_k`` rows of ``df`` most similar to ``user_query``.

    The query is cleaned with ``clean_text``, embedded with the same
    mean-pooling encoder used for the data rows, and compared against
    ``drug_embeddings`` by cosine similarity.

    Args:
        user_query: free-text search string.
        top_k: number of rows to return.

    Returns:
        A new DataFrame holding the best-matching rows of ``df`` plus a
        'semantic_similarity' column, sorted from most to least similar.
        The global ``df`` itself is left unchanged, so successive queries do
        not overwrite each other's scores.
    """
    user_cleaned = clean_text(user_query)
    user_embedding = get_mean_pooling_embedding(user_cleaned, tokenizer, model, device)
    sims = cosine_similarity([user_embedding], drug_embeddings)[0]
    # Score a copy: writing into the global df made each query mutate shared
    # notebook state that other cells may already have displayed.
    scored = df.assign(semantic_similarity=sims)
    return scored.sort_values(by='semantic_similarity', ascending=False).head(top_k)

# Example interactive search
def on_query_submit(query):
    """Clear the current output and display the best matches for ``query``.

    Shows the configured text columns, the similarity score, 'Drug Name' and
    the configured outcome columns for the rows returned by ``get_top_matches``.
    A blank (or non-string) query prints a prompt instead of searching:
    embedding an empty string ranks rows against an essentially meaningless
    vector.

    Args:
        query: the user's search text.
    """
    clear_output(wait=True)
    if not isinstance(query, str) or not query.strip():
        print("Please enter a search query.")
        return
    print(f"\nTop matches for: \"{query}\"")
    results = get_top_matches(query)
    display(results[[*TEXT_COLUMNS_FOR_EMBEDDING, 'semantic_similarity', 'Drug Name', *OUTCOME_COLUMNS]])

In [None]:

# Search UI: a text box, a button, and a dedicated area for the results.
search_box = widgets.Text(description='Search:')
button = widgets.Button(description="Search")
# Output produced inside widget callbacks is not reliably shown: JupyterLab
# drops it, and in the classic notebook it lands in this cell, where the
# clear_output() in on_query_submit can also remove the search box. Running
# the callback inside an Output widget gives the results their own clearable area.
results_area = widgets.Output()


def on_button_click(b):
    """Run the search for the current text-box contents (``b``, the Button, is unused)."""
    with results_area:
        on_query_submit(search_box.value)


button.on_click(on_button_click)
display(widgets.VBox([widgets.HBox([search_box, button]), results_area]))