In [None]:
import numpy as np
import pandas as pd
import json
from llama_cpp import Llama
from transformers import AutoTokenizer
import torch

In [None]:
# This notebook uses Llama 3 (70B Instruct) to generate a list of target cohorts for each trial,
# based on the trial text from ClinicalTrials.gov, for the trials in our retrospective enrollment dataset.

In [None]:
# Directory holding the derived enrollment data.
prefix = '/data/clin_notes_outcomes/pan_dfci_2024/derived_data/'

# Read the enrollment table, parse the trial start date, and keep only the
# enrollments whose trial started on or after 2016-01-01.
enrollments = pd.read_csv(prefix + 'useful_trial_enrollments.csv')
enrollments = enrollments.assign(trial_start_dt=pd.to_datetime(enrollments['trial_start_dt']))
enrollments = enrollments.loc[enrollments['trial_start_dt'] >= pd.Timestamp('2016-01-01')]

In [None]:
# Print a summary (columns, dtypes, non-null counts) of the date-filtered enrollments.
enrollments.info()

In [None]:
# Load the Llama 3 70B Instruct GGUF model through llama-cpp-python.
llm = Llama.from_pretrained(
    repo_id="lmstudio-community/Meta-Llama-3-70B-Instruct-GGUF",
    filename="*Q4*",  # pattern selecting which file of the repo to load
    local_dir='/data/clin_notes_outcomes/meta/',
    cache_dir='/data/clin_notes_outcomes/meta/',
    n_ctx=8192,  # context size handed to the model
    n_gpu_layers=-1,  # NOTE(review): presumably "offload all layers to the GPU" -- confirm in llama-cpp-python docs
    verbose=False,
)

In [None]:
# Hugging Face tokenizer for Llama 3 Instruct.
# NOTE(review): `tokenizer` is not referenced by any other cell visible here -- confirm
# whether it is still needed (e.g. to check prompt length against n_ctx).
tokenizer = AutoTokenizer.from_pretrained("meta-llama/Meta-Llama-3-8B-Instruct")

In [None]:
def summarize_trial_multi_cohort(eligibility_text, llama_model, **completion_kwargs):
    """Ask the LLM for a numbered list of the cohorts targeted by one trial.

    Parameters
    ----------
    eligibility_text : str
        Text of the clinical trial document embedded in the user prompt.
    llama_model
        Object exposing ``create_chat_completion(messages=...)`` whose return
        value is indexable as ``response['choices'][0]['message']['content']``.
    **completion_kwargs
        Optional extra keyword arguments forwarded unchanged to
        ``create_chat_completion`` (e.g. sampling settings). When none are
        given the call is exactly ``create_chat_completion(messages=messages)``.

    Returns
    -------
    tuple
        ``(response, response_text)``: the raw completion response and the
        message content of its first choice.

    Raises
    ------
    TypeError
        If ``eligibility_text`` is not a string (for example a NaN coming from
        a table row with missing trial text).
    """
    # Fail with a clear message instead of an opaque str-concatenation error.
    if not isinstance(eligibility_text, str):
        raise TypeError(
            "eligibility_text must be a str, got " + type(eligibility_text).__name__
        )

    messages = [
        {'role':'system', 'content': """You are an expert clinical oncologist with an encyclopedic knowledge of cancer and its treatments.
    Your job is to review a clinical trial document and generate a concise summary of the objectives of the trial and its target cohort(s).
    A cohort is defined as a unique combination of cancer type, tumor biomarkers (such as germline or somatic gene mutations or alterations, or protein expression on tumor), which treatments a patient has received, and presence of metastatic disease.
    Some trials have only one cohort, while others have several. Generate a numbered list of such cohorts, where each cohort is described in one concise sentence. Cohorts should be separated by newlines.
    When describing prior treatments, if a drug name is mentioned in the trial criteria, add the drug class in parentheses in your cohort definition.
    Output should be formatted like this example:
    1. Metastatic non-small cell lung cancer, EGFR L858R mutant, previously treated with osimertinib (third-generation EGFR TKI), no prior immunotherapy.
    2. Metastatic solid tumors, no available standard therapies, prior immunotherapy required
    """},

        {'role':'user', 'content': "Here is a clinical trial document: \n" + eligibility_text + "\n" + """Now, generate your list of the trial cohort(s), formatted as above.
        Do not provide any introductory, explanatory, concluding, or disclaimer text."""
        }
    ]

    response = llama_model.create_chat_completion(messages=messages, **completion_kwargs)

    # Only the first choice is used as the cohort list.
    response_text = response['choices'][0]['message']['content']

    return response, response_text

In [None]:
# Collapse enrollments to one row per protocol_number.
# NOTE(review): per the pandas docs GroupBy.first() takes the first non-null value of
# each column within a group, so a row may combine values from several enrollments.
trials = enrollments.groupby('protocol_number', as_index=False).first()
trials.info()

In [None]:
# Preview the first rows of the per-trial table.
trials.head()

In [None]:
%%capture
# Run the LLM over every trial, collecting one single-row frame per trial that
# carries the model's cohort list in a new 'cohorts' column.
frames = []
for i in range(trials.shape[0]):
    answer = summarize_trial_multi_cohort(trials.trial_text.iloc[i], llm)
    # .assign() builds a new frame, so nothing is written into a slice of
    # `trials` (in-place column assignment on such a slice is the pattern
    # pandas flags with SettingWithCopyWarning).
    frame = trials.iloc[[i]].assign(cohorts=answer[1])
    frames.append(frame)

    # Checkpoint: rewrite the CSV with everything gathered so far at every
    # 500th index (including index 0) and after the final trial.
    if (i % 500 == 0) or i == (trials.shape[0] - 1):
        output = pd.concat(frames, axis=0)
        output.to_csv('unique_trial_cohorts_6-27-24.csv')

In [None]:
# Restart point: presumably these imports are repeated so the post-processing
# cells below can run in a fresh kernel without the LLM cells above.
import numpy as np
import pandas as pd

# Reload the per-trial cohort lists from the checkpoint CSV.
output = pd.read_csv('unique_trial_cohorts_6-27-24.csv')

In [None]:
# Expand each trial into one row per non-empty line of its 'cohorts' text.
frames = []
for i in range(output.shape[0]):
    raw_cohorts = output['cohorts'].iloc[i]
    # An empty model response comes back from the CSV as NaN rather than a
    # string; treat anything that is not a string as "no cohort lines".
    if isinstance(raw_cohorts, str):
        cohort_lines = [line for line in raw_cohorts.split("\n") if line != '']
    else:
        cohort_lines = []

    # Repeat the trial's row once per cohort line. Positional indexing keeps
    # the column dtypes, unlike a round trip through a NumPy array.
    frame = output.iloc[[i] * len(cohort_lines)].reset_index(drop=True)
    frame['this_cohort'] = cohort_lines
    # 0-based position of the line within this trial's response (not the
    # number written in the line itself).
    frame['cohort_number'] = frame.index
    frames.append(frame)
    

In [None]:
# Stack the per-trial frames into one table with one row per cohort line.
# Each frame carries its own 0-based index, so index values repeat across trials.
cohort_level_trials = pd.concat(frames, axis=0)

In [None]:
# Print a summary (columns, dtypes, non-null counts) of the cohort-level table.
cohort_level_trials.info()

In [None]:
# Keep only the trial identifiers, the trial text fields and the cohort columns,
# in this order; every other column is dropped.
cohort_level_trials = cohort_level_trials.loc[:, [
    'protocol_number',
    'study_nm',
    'protocol_nbr',
    'nct_id',
    'title',
    'brief_summary',
    'detailed_summary',
    'eligibility_criteria',
    'trial_text',
    'cohorts',
    'this_cohort',
    'cohort_number',
]]

In [None]:
# Keep only lines whose very first character is a digit 1-9, i.e. lines that look
# like numbered list items. NOTE(review): a line with leading whitespace before
# its number is dropped by this test -- confirm that is intended.
cohort_level_trials = cohort_level_trials.loc[
    cohort_level_trials['this_cohort'].str[0].isin(list('123456789'))
]

In [None]:
# Write the filtered cohort-level table (one row per numbered cohort line) to disk.
cohort_level_trials.to_csv('trial_cohort_lineitems_6-27-24.csv')