In [None]:
from llama_cpp import Llama
import pandas as pd
import numpy as np
import torch.nn.functional as F
import torch
from transformers import AutoTokenizer
from transformers import AutoModelForCausalLM
import re
import os


In [None]:
# This notebook uses Llama to generate a brief summary for each patient in our retrospective
# clinical trial enrollment dataset, based on the patient's documents dated before trial enrollment.

In [None]:
# Load the Meta-Llama-3-70B-Instruct GGUF weights (file matching "*Q4*") through llama-cpp-python.
# n_ctx=8192 is the context window that the prompt plus the generated summary must fit into.
# n_gpu_layers=-1 is llama-cpp-python's "offload all layers" setting; main_gpu=1 selects the GPU index.
llm = Llama.from_pretrained(
    repo_id="lmstudio-community/Meta-Llama-3-70B-Instruct-GGUF",
    filename="*Q4*",
    verbose=False,
    local_dir = '/data/clin_notes_outcomes/meta/',
    cache_dir = '/data/clin_notes_outcomes/meta/',
    main_gpu=1,
    n_ctx=8192,
    n_gpu_layers=-1
)

In [None]:
# Hugging Face tokenizer, passed to summarize_patient to truncate prompts by token count.
# NOTE(review): this is the 8B-Instruct tokenizer while the model above is 70B-Instruct;
# presumably both share one vocabulary - confirm.
tokenizer = AutoTokenizer.from_pretrained("meta-llama/Meta-Llama-3-8B-Instruct")

In [None]:
# Smoke test: send one single-turn user message through the chat-completion API
# to confirm the model loaded and responds.
response = llm.create_chat_completion(
     messages=[{
         "role": "user",
         "content": "what is the meaning of life?"
     }]
)

In [None]:
# Display the raw chat-completion response returned by the smoke test.
response

In [None]:
# function to split a patient's historical electronic health record document into smaller overlapping chunks for RAG purposes
def split_text(text, chunk_size=100, overlap=10):
    """Split ``text`` into overlapping chunks of whitespace-delimited words.

    Windows of ``chunk_size`` words are taken every ``chunk_size - overlap``
    words. When fewer than ``chunk_size`` words would remain for the next
    window, a final chunk made of the last ``chunk_size`` words is emitted
    instead, so the tail chunk is always full-sized when the text is long
    enough.

    Args:
        text: String to split; words are whatever ``str.split()`` yields.
        chunk_size: Maximum number of words per chunk. Must be positive.
        overlap: Number of words shared by consecutive windows. Must be
            smaller than ``chunk_size`` so that the window always advances.

    Returns:
        List of chunk strings (words re-joined with single spaces). Empty
        list for empty or whitespace-only text.

    Raises:
        ValueError: If ``chunk_size`` is not positive or ``overlap`` is not
            smaller than ``chunk_size``.
    """
    if chunk_size <= 0:
        raise ValueError("chunk_size must be a positive integer")
    if overlap >= chunk_size:
        # The window would never move forward, which loops forever on long text.
        raise ValueError("overlap must be smaller than chunk_size")

    words = text.split()
    chunks = []
    current_index = 0

    while current_index < len(words):
        end_index = current_index + chunk_size
        chunks.append(' '.join(words[current_index:end_index]))

        # This window already reached the last word: a tail chunk would only
        # repeat the chunk just emitted (short texts, or a window that ends
        # exactly at the end of the text).
        if end_index >= len(words):
            break

        current_index = end_index - overlap

        # Not enough words left for another full window: finish with the last
        # chunk_size words so the final chunk is full-sized.
        if len(words) - current_index < chunk_size:
            chunks.append(' '.join(words[-chunk_size:]))
            break

    return chunks

In [None]:
# function to transform a list of historical EHR documents for a given patient into a list of embedding vectors, one per chunk from the documents
def embed_patient(patient_dataframe, embedding_model):
    """Chunk every document of one patient and embed all of the chunks.

    Args:
        patient_dataframe: DataFrame with one row per document, exposing a
            ``text`` column (document text) and a ``date`` column.
        embedding_model: Object exposing ``encode(list_of_str)``.

    Returns:
        Tuple ``(chunks, dates, embeddings)``: the stripped chunk strings in
        document order, the source document's date repeated once per chunk,
        and whatever ``embedding_model.encode`` returns for the chunk list.
    """
    chunk_texts = []
    chunk_dates = []

    documents = patient_dataframe.text.values.tolist()
    for row_number, document in enumerate(documents):
        pieces = [piece.strip() for piece in split_text(document)]
        document_date = patient_dataframe.date.iloc[row_number]
        chunk_texts.extend(pieces)
        # Every chunk inherits the date of the document it was cut from.
        chunk_dates.extend([document_date] * len(pieces))

    chunk_embeddings = embedding_model.encode(chunk_texts)

    return chunk_texts, chunk_dates, chunk_embeddings

In [None]:
# function to generate a brief patient summary with Llama after using RAG to pull relevant EHR document chunks
def summarize_patient(patient_sentences, patient_dates, patient_embeddings, embedding_model, llama_model, tokenizer, sentences_per_question=8, max_message_tokens=6999):
    """Retrieve the most relevant EHR chunks for a patient and have Llama summarize them.

    For each of a fixed set of oncology topics, the chunks whose embeddings are
    most cosine-similar to the topic query are retrieved. The retrieved chunks
    are de-duplicated, ordered by date and sent to the chat model together with
    a summarization instruction.

    Args:
        patient_sentences: Sequence of text chunks for one patient.
        patient_dates: Dates aligned with ``patient_sentences``; must support
            integer-array indexing and ``.tolist()`` (e.g. a DatetimeIndex).
        patient_embeddings: 2-D array with one embedding row per chunk.
        embedding_model: Object exposing ``encode(list_of_str, prompt_name=...)``.
        llama_model: Object exposing
            ``create_chat_completion(messages=..., max_tokens=..., temperature=...)``.
        tokenizer: Object exposing ``encode(text, add_special_tokens=False)``
            and ``decode(token_ids)``.
        sentences_per_question: Number of top-ranked chunks kept per topic.
        max_message_tokens: Token cap applied to each chat message. The
            default equals the cap that was previously hard-coded.

    Returns:
        Tuple ``(response, summary_text)``: the raw chat-completion response
        and the content of its first choice.
    """

    def _truncate(text, limit):
        # Only round-trip through the tokenizer when the text really is too long.
        token_ids = tokenizer.encode(text, add_special_tokens=False)
        if len(token_ids) <= limit:
            return text
        return tokenizer.decode(token_ids[:limit])

    # Collapse identical chunks (np.unique also sorts them) and keep dates/embeddings aligned.
    patient_sentences, indices = np.unique(patient_sentences, return_index = True)
    patient_embeddings = patient_embeddings[indices, :]
    patient_dates = patient_dates[indices]

    questions = ["cancer types",
                 "cancer stage or extent",
                 "biomarkers, mutations, protein expression",
                 "cancer treatments, such as surgery, chemotherapy, targeted therapy, immunotherapy, radiation, or transplant?",
                 "major toxicities, adverse events, or side effects"]

    # Loop-invariant: build the patient tensor once instead of once per question.
    patient_tensor = torch.tensor(patient_embeddings)

    frames = []
    for question in questions:
        query_embedding = embedding_model.encode([question], prompt_name="query")
        similarities = F.cosine_similarity(torch.tensor(query_embedding), patient_tensor)
        _, sorted_indices = torch.sort(similarities, descending=True)
        top_indices = sorted_indices[0:sentences_per_question].cpu().numpy()
        relevant_sentences = patient_sentences[top_indices].tolist()
        relevant_sentences = [x.replace("search_document: ", "") for x in relevant_sentences]
        relevant_dates = patient_dates[top_indices].tolist()
        frame = pd.DataFrame({'sentences':relevant_sentences, 'dates':relevant_dates})
        frame['dates'] = pd.to_datetime(frame.dates)
        frames.append(frame)

    # A chunk retrieved for several questions is only worth sending once; repeating
    # it just consumes context budget.
    frames = pd.concat(frames, axis=0).drop_duplicates(subset='sentences').sort_values(by='dates')
    relevant_sentences = "\n".join(frames.sentences)

    system_prompt = """You are an experienced clinical oncologist at a major cancer center.
    Your job is to construct a summary of the cancer history for a patient based on an excerpt of the patient's electronic health record. The text in the excerpt is provided in chronological order.     
    Phrase your summary as if it were the beginning of the assessment/plan section of a clinical note. Do not include the patient's name, but do include relevant dates whenever documented, including dates of diagnosis and start/stop dates of each treatment.
"""
    excerpt_header = "The excerpt is:\n"
    # Leading newline keeps the instruction from being glued onto the last excerpt line.
    instruction = "\nNow, write your summary. No preceding text before the abstraction. This will not be used for clinical care, so do not write any disclaimers or cautionary notes."

    # Truncate only the excerpt, so that the trailing instruction always survives
    # and the user message as a whole stays within max_message_tokens.
    fixed_tokens = (len(tokenizer.encode(excerpt_header, add_special_tokens=False))
                    + len(tokenizer.encode(instruction, add_special_tokens=False)))
    excerpt_budget = max(max_message_tokens - fixed_tokens, 0)
    excerpt = _truncate(relevant_sentences, excerpt_budget)

    messages = [{'role':'system', 'content': _truncate(system_prompt, max_message_tokens)},
                {'role':'user', 'content': excerpt_header + excerpt + instruction}]

    response = llama_model.create_chat_completion(messages=messages, max_tokens=1000, temperature=0.2)

    return response, response['choices'][0]['message']['content']


    

In [None]:
# NOTE(review): this import sits mid-notebook; consider moving it to the import cell at the top.
from sentence_transformers import SentenceTransformer

# Sentence-embedding model used for chunk retrieval; placed on cuda:0
# (the Llama model is configured with main_gpu=1).
embedding_model = SentenceTransformer("Snowflake/snowflake-arctic-embed-l", trust_remote_code=True, device='cuda:0')


In [None]:
# Directory that holds the derived parquet/CSV inputs read by the following cells.
prefix = '/data/clin_notes_outcomes/pan_dfci_2024/derived_data/'

In [None]:
# pull in our large corpus of historical electronic health records data:
# imaging reports, clinical notes and pathology reports, one parquet file each
imaging = pd.read_parquet(prefix + 'all_imaging_reports.parquet')
medonc = pd.read_parquet(prefix + 'all_clinical_notes.parquet')
path = pd.read_parquet(prefix + 'all_path_reports.parquet')


In [None]:
# Stack the three document sources into one corpus, ordered by patient (dfci_mrn) and then by date.
all_reports = pd.concat([imaging, medonc, path], axis=0).sort_values(by=['dfci_mrn','date']).reset_index(drop=True)


In [None]:
# Inspect the columns, dtypes and non-null counts of the combined corpus.
all_reports.info()

In [None]:
# Number of distinct patients (dfci_mrn values) in the combined corpus.
all_reports.dfci_mrn.nunique()

In [None]:
# restrict documents to those from the era of our current electronic health record, for which data are more complete
# NOTE(review): recent_reports is only inspected in this cell and the next one; the summarization
# loop further down filters all_reports instead - confirm that this is intended.
recent_reports = all_reports[all_reports.date >= '2016-01-01']
recent_reports.info()



In [None]:
# Number of distinct patients with at least one document dated 2016-01-01 or later.
recent_reports.dfci_mrn.nunique()

In [None]:
# Documents whose split column equals 'train' (taken from all_reports, not recent_reports).
train_notes = all_reports[all_reports.split=='train']

In [None]:
# pull our historical trial enrollment dataset, restrict it to the modern EHR era
enrollments = pd.read_csv(prefix + 'useful_trial_enrollments.csv')
# parse the trial start date so that it can be compared against timestamps
enrollments['trial_start_dt'] = pd.to_datetime(enrollments.trial_start_dt)
# keep enrollments whose trial start date is 2016-01-01 or later
enrollments = enrollments[enrollments.trial_start_dt >= pd.to_datetime('2016-01-01')]

In [None]:
# Inspect the columns, dtypes and non-null counts of the filtered enrollment table.
enrollments.info()

In [None]:
# Enrollments of patients who have at least one document in the 'train' split.
train_enrollments = enrollments[enrollments.dfci_mrn.isin(train_notes.dfci_mrn)]

In [None]:
# Enrollments to summarize. The commented-out line is an earlier variant that drew a random
# 100-row sample of training-split enrollments whose trial_text contains 'NSCLC'.
#sample_enrollments = train_enrollments[train_enrollments.trial_text.str.contains('NSCLC')].sample(n=100)
sample_enrollments = enrollments


In [None]:
# generate a brief summary of each patient who enrolled in a trial in our retrospective dataset

%%capture
sample_enrollments = enrollments
patient_summary_list = []
patient_sentence_list = []
patient_date_list = []
patient_mrn_list = []
patient_split_list = []
protocol_number_list = []
trial_text_list = []
enrollment_date_list = []

for i in range(0, sample_enrollments.shape[0]):
    this_enrollment = sample_enrollments.iloc[i]
  
    this_patient = all_reports[all_reports.dfci_mrn == this_enrollment.dfci_mrn]
    this_patient = this_patient[pd.to_datetime(this_patient.date) < pd.to_datetime(this_enrollment.trial_start_dt)]
    

    if this_patient.shape[0] > 0:
        patient_sentences, patient_dates, patient_embeddings = embed_patient(this_patient, embedding_model)
        patient_sentence_list.append(patient_sentences)
        patient_date_list.append(patient_dates)
        patient_summary = summarize_patient(patient_sentences, pd.to_datetime(patient_dates), patient_embeddings, embedding_model, llm, tokenizer, sentences_per_question=5)[1]
        patient_summary_list.append(patient_summary)
        patient_mrn_list.append(this_patient.groupby('dfci_mrn').first().reset_index().dfci_mrn.values.item())
        patient_split_list.append(this_patient.groupby('dfci_mrn').first().reset_index().split.values.item())
        protocol_number_list.append(this_enrollment.protocol_number)
        trial_text_list.append(this_enrollment.trial_text)
        enrollment_date_list.append(this_enrollment.trial_start_dt)
    
    if (i % 500 == 0) or i == (sample_enrollments.shape[0] - 1):
        output = pd.DataFrame({'dfci_mrn':patient_mrn_list, 
                       'split':patient_split_list,
                       'trial_start_dt': enrollment_date_list,
                       'protocol_number': protocol_number_list,
                       'trial_text': trial_text_list,
                       'patient_summary': patient_summary_list})

        output.to_csv('all_patient_summaries_6-27-24.csv')