### Setup and Example Query

In [None]:
# To install the Vertex AI SDK first, uncomment the line below; the restart
# that follows makes the newly installed package importable.
# %pip install google-cloud-aiplatform

import IPython

# Restart the kernel. Colab will then show a pop-up saying the session crashed
# for an unknown reason; that is expected and safe to dismiss. Continue with
# the following cells once the message appears.
app = IPython.Application.instance()
kernel = app.kernel
kernel.do_shutdown(True)

{'status': 'ok', 'restart': True}

In [None]:
# Mount Google Drive; the evaluation dataset is read from "My Drive" below.
from google.colab import drive

DRIVE_ROOT = '/content/drive'
drive.mount(DRIVE_ROOT)

Mounted at /content/drive


### Imports

In [None]:
!pip install openai
!pip install bert_score
!pip install datasets

In [None]:
from google.colab import auth as google_auth
import vertexai
from vertexai.preview.language_models import TextGenerationModel

# Authenticate the Colab user so Vertex AI calls run under their credentials.
google_auth.authenticate_user()

# TODO: Replace with project ID from Cloud Console
# (https://support.google.com/googleapi/answer/7014113)
PROJECT_ID = 'PROJECT_ID'

# MedLM models are only available in us-central1.
vertexai.init(location='us-central1', project=PROJECT_ID)

# Decoding settings for a single example query; temperature 0 keeps the
# output deterministic.
parameters = dict(
    candidate_count=1,
    max_output_tokens=256,
    temperature=0.0,
    top_k=40,
    top_p=0.80,
)

model_instance = TextGenerationModel.from_pretrained("medlm-large")

example_prompt = "Question: What causes you to get ringworm?"
response = model_instance.predict(example_prompt, **parameters)

print(f"Response from Model: {response.text}")

Response from Model:  Ringworm is a common fungal infection of the skin. It can affect people of all ages, and it is spread by contact with an infected person or animal. The fungus that causes ringworm can live on the skin, hair, and nails. Ringworm is most commonly found on the scalp, feet, and groin area. In most cases, ringworm is not a serious infection and can be treated with over-the-counter antifungal medications. However, in some cases, ringworm can cause severe itching and discomfort, and it may require treatment with prescription medications. If you are experiencing symptoms of ringworm, it is important to see a doctor for diagnosis and treatment.


In [None]:
# All imports used by the evaluation cells below, grouped stdlib -> third-party.
# `time` and `tqdm` are used by the evaluation loops and must be in scope on a
# fresh kernel.
import os
import random
import time

import bert_score
import openai
import pandas as pd
import vertexai
from datasets import load_dataset
from google.colab import auth as google_auth
from sklearn.metrics import accuracy_score
from tqdm.auto import tqdm
from vertexai.preview.language_models import TextGenerationModel

google_auth.authenticate_user()

# TODO: Replace with project ID from Cloud Console.
PROJECT_ID = 'PROJECT_ID'
# MedLM models are only available in us-central1.
vertexai.init(project=PROJECT_ID, location='us-central1')

# Read the OpenAI API key from the environment instead of hard-coding it.
openai.api_key = os.getenv("OPENAI_API_KEY")

# Arrow file on the mounted Drive; presumably the CUPCase test split — confirm.
DATA_FILE = '/content/drive/My Drive/data-00000-of-00001.arrow'
ds = load_dataset('arrow', data_files={'test': DATA_FILE})['test']

### CUPCase Evaluation: MedLM-Large


#### Free Text

In [None]:
import time

# Free-text evaluation: ask MedLM-Large for a one-sentence diagnosis of each
# case and score it against the reference diagnosis with BERTScore F1.
parameters = {
    "candidate_count": 1,
    "temperature": 0.0,
}

model_instance = TextGenerationModel.from_pretrained("medlm-large")

batch_size = 250
num_batches = 4
MAX_RETRIES = 5      # attempts per prompt before the error is re-raised
RETRY_DELAY_S = 5    # seconds to wait between attempts

results = []

# Fixed seed so the same cases are selected on every run.
shuffled_ds = ds.shuffle(seed=42)

for i in range(num_batches):
    start_idx = i * batch_size
    # Clamp so a dataset smaller than batch_size * num_batches does not raise.
    end_idx = min(start_idx + batch_size, len(shuffled_ds))
    if start_idx >= end_idx:
        break
    batch_ds = shuffled_ds.select(range(start_idx, end_idx))

    for idx, example in enumerate(batch_ds):
        case_presentation = example['clean_case_presentation']
        true_diagnosis = example['correct_diagnosis']
        prompt = (f"Predict the diagnosis of this case presentation of a patient. Return the final diagnosis in one concise sentence "
                  f"without any further elaboration.\nFor example: <diagnosis name here>\nCase presentation: {case_presentation}\nDiagnosis:")

        # Bounded retry: transient API errors are retried, persistent ones
        # surface instead of silently aborting or looping forever.
        for attempt in range(1, MAX_RETRIES + 1):
            try:
                response = model_instance.predict(prompt, **parameters)
                break
            except Exception as e:
                if attempt == MAX_RETRIES:
                    raise
                print(f"An error occurred. Retrying in {RETRY_DELAY_S} seconds...{e}")
                time.sleep(RETRY_DELAY_S)

        generated_diagnosis = response.text.strip()
        results.append({
            'Case presentation': case_presentation,
            'True diagnosis': true_diagnosis,
            'Generated diagnosis': generated_diagnosis
        })

        # Pause every 50 samples to stay under the API rate limit.
        if (idx + 1) % 50 == 0:
            time.sleep(3)

results_df = pd.DataFrame(results)

model_type = "microsoft/deberta-xlarge-mnli"
predictions = results_df['Generated diagnosis'].tolist()
references = results_df['True diagnosis'].tolist()

P, R, F1 = bert_score.score(predictions, references, lang="en", model_type=model_type)

results_df['BERTScore F1'] = F1.tolist()

results_df.to_csv('free_text_medlm_large.csv', index=False)

print(f"Mean BERTScore F1: {results_df['BERTScore F1'].mean():.4f}")
results_df.head()


#### Multiple Choice

In [None]:

# Imported here so this cell runs on a fresh kernel (tqdm was previously used
# without ever being imported).
import time

from tqdm.auto import tqdm

# Multiple-choice evaluation: present the true diagnosis mixed with three
# distractors and ask the model for the 1-based index of the correct option.
parameters = {
    "candidate_count": 1,
    "temperature": 0.0,
}

model_instance = TextGenerationModel.from_pretrained("medlm-large")

batch_size = 250
num_batches = 4
checkpoint_every = 50   # write a partial CSV (and pause) after this many samples
MAX_RETRIES = 5         # attempts per prompt before the error is re-raised
RETRY_DELAY_S = 5       # seconds to wait between attempts

results = []

# Same dataset seed as the free-text run, so both evaluations use the same cases.
shuffled_ds = ds.shuffle(seed=42)
# Dedicated seeded RNG so the order of the presented options is reproducible.
option_rng = random.Random(42)

for i in range(num_batches):
    start_idx = i * batch_size
    # Clamp so a dataset smaller than batch_size * num_batches does not raise.
    end_idx = min(start_idx + batch_size, len(shuffled_ds))
    if start_idx >= end_idx:
        break
    batch_ds = shuffled_ds.select(range(start_idx, end_idx))

    for example in tqdm(batch_ds, total=len(batch_ds), desc=f"Batch {i + 1}/{num_batches}"):
        case_presentation = example['clean_case_presentation']
        true_diagnosis = example['correct_diagnosis']

        options = [true_diagnosis, example['distractor2'], example['distractor3'], example['distractor4']]
        option_rng.shuffle(options)
        options_text = "\n".join(f"{n}. {option}" for n, option in enumerate(options, start=1))
        prompt = (f"Predict the diagnosis of this case presentation of a patient. Return only the correct index from the following list, for example: 3\n"
                  f"{options_text}\nCase presentation: {case_presentation}")

        # Bounded retry: transient API errors are retried, persistent ones
        # (bad project, quota, invalid request) surface instead of looping forever.
        for attempt in range(1, MAX_RETRIES + 1):
            try:
                response = model_instance.predict(prompt, **parameters)
                break
            except Exception as e:
                if attempt == MAX_RETRIES:
                    raise
                print(f"An error occurred. Retrying in {RETRY_DELAY_S} seconds...{e}")
                time.sleep(RETRY_DELAY_S)

        generated_diagnosis = response.text.strip()
        # The answer is expected as a leading digit. Anything else, including
        # an empty reply, is scored as -1, which never matches a real index.
        try:
            predicted_index = int(generated_diagnosis[:1]) - 1
        except ValueError:
            predicted_index = -1

        correct_index = options.index(true_diagnosis)
        results.append({
            'Case presentation': case_presentation,
            'True diagnosis': true_diagnosis,
            'Generated diagnosis': generated_diagnosis,
            'Correct index': correct_index,
            'Predicted index': predicted_index,
            'Correct': correct_index == predicted_index
        })

        # Checkpoint on the running total (not the per-batch index) so files
        # from different batches never overwrite each other.
        if len(results) % checkpoint_every == 0:
            pd.DataFrame(results).to_csv(f'checkpoint_{len(results)}.csv', index=False)
            time.sleep(3)

# Save results to a DataFrame and CSV
results_df = pd.DataFrame(results)
results_df.to_csv('qa_medlm_large.csv', index=False)

# Accuracy is the fraction of cases where the predicted index was correct.
accuracy = results_df['Correct'].mean()
print(f"Accuracy: {accuracy:.2f}")
