In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
import torchvision.transforms as transforms
import torch.optim as optim
from tqdm import tqdm
import copy
from influence_functions import calc_influence_single_group_upweight, calc_influence_single_group_pert
from utils import set_attr
from argparse import Namespace
import math

import h5py
import matplotlib.pyplot as plt
import numpy as np
from PIL import Image
import os
from scipy.stats import rice
import pickle

# Experiment configuration shared by the cells below: both values are
# interpolated into the data and image directory paths.
noise = 'rician'
seed = 8


In [None]:
import pathlib
import textwrap

import numpy as np
import pickle

from IPython.display import display
from IPython.display import Markdown
import torch
import torch.nn.functional as F
import time

from sklearn.metrics import roc_auc_score
import re

def to_markdown(text):
    """Wrap ``text`` in a Markdown blockquote for rich notebook display.

    Bullet characters ('•') are rewritten as indented Markdown list markers,
    then every line (blank lines included, because the predicate always
    returns True) is prefixed with '> '.

    Args:
        text: The string to render.

    Returns:
        An ``IPython.display.Markdown`` object wrapping the quoted text.
    """
    bulleted = text.replace('•', '  *')
    quoted = textwrap.indent(bulleted, '> ', predicate=lambda _: True)
    return Markdown(quoted)

import google.generativeai as genai

In [None]:
# Load the pickled training data for the configured noise type and seed.
# NOTE(review): pickle.load can execute arbitrary code while deserialising;
# only point this at files from a trusted source.
train_data_path = f"data_for_llm/{noise}/seed_{seed}/train_data_seed_{seed}.pkl"
with open(train_data_path, "rb") as file:
    train_data = pickle.load(file)

# Training labels, looked up by training index when building the
# demonstration image file names.
labels_train = train_data['labels_train']

In [None]:
# Training-set indices of the demonstration (few-shot) images, keyed by seed.
demonstration_indices_dict = {
    8: [1663, 1731, 2073, 2112, 1806, 1965],
    1686: [617, 618, 1815, 400, 1360, 217],
    1711: [1713, 3, 1309, 1306, 151, 1925],
    111: [96, 32, 441, 1707, 716, 1409],
}

# Select the entry for the configured seed (KeyError if the seed is not listed).
demonstration_indices = demonstration_indices_dict[seed]

In [None]:
from PIL import Image
import requests

# Directory holding the training images for the configured noise type and seed.
noisy_image_train_dir = f"/home/haovan/IF_medical_image/data_for_llm/{noise}/seed_{seed}/train_images/original/"

demonstration_images = []
demonstration_base64 = []
# Each training image is stored as "<train index>_<label>.jpg"; open the ones
# selected as demonstrations, preserving the order of demonstration_indices.
for train_idx in demonstration_indices:
    image_path = noisy_image_train_dir + f"{train_idx}_{labels_train[train_idx]}.jpg"
    demonstration_images.append(Image.open(image_path))


In [None]:
# Preview the first demonstration image, if any were loaded.
# A bare expression nested inside an `if` block is not auto-rendered by the
# notebook (only a top-level trailing expression is), so display it explicitly.
if len(demonstration_images) > 0:
    display(demonstration_images[0])

In [None]:
# Prefer an API key supplied through the GOOGLE_API_KEY environment variable so
# the secret never has to be written into the notebook. The placeholder stays
# as the fallback, which preserves the previous behavior when the variable is
# not set.
genai.configure(api_key=os.environ.get('GOOGLE_API_KEY', '<your-api-key-here>'))

# Print the name of every available model that supports content generation.
for m in genai.list_models():
    if 'generateContent' in m.supported_generation_methods:
        print(m.name)

In [None]:
# One entry per harm category, each with its blocking threshold set to
# "BLOCK_NONE". The order of the categories is preserved.
safety_settings = [
    {"category": harm_category, "threshold": "BLOCK_NONE"}
    for harm_category in (
        "HARM_CATEGORY_SEXUALLY_EXPLICIT",
        "HARM_CATEGORY_HATE_SPEECH",
        "HARM_CATEGORY_HARASSMENT",
        "HARM_CATEGORY_DANGEROUS_CONTENT",
    )
]

In [None]:
# Gemini model handle used by the evaluation cells below; the safety settings
# defined above are attached to it.
model = genai.GenerativeModel(
    'models/gemini-1.5-flash-latest',
    safety_settings=safety_settings,
)

In [None]:
# Accumulators filled by the evaluation loops below:
#   actuals      - integer label parsed from each test file name
#   answers      - parsed model answer for each test image, in the same order
#   answers_dict - parsed model answer keyed by test image file name
actuals, answers = [], []
answers_dict = dict()

### Zero-shot

In [None]:
import re
import json


# Zero-shot evaluation: runs only when no demonstration examples are configured.
if len(demonstration_indices) == 0:
    image_folder = f"/home/haovan/IF_medical_image/data_for_llm/{noise}/seed_{seed}/test_images/original/"

    prompt = [
        "Predict the tumor type for a given brain MRI. The potential tumor types are 'meningioma', 'glioma', or 'pituitary'. Analyze the provided test MRI to determine the most likely diagnosis among the three tumor types.\n\nProvide the answer in the following JSON format:\n\n```json\n{\n  \"tumor_type\": \"<meningioma/glioma/pituitary>\"\n}\n```\n\nTest image:"]

    all_images = os.listdir(image_folder)

    # Resume support: images whose answers were already collected are skipped.
    # NOTE(review): this relies on os.listdir returning the same order on every call.
    n_done = len(answers)
    # Drop labels left over from an interrupted earlier run so that actuals and
    # answers stay index-aligned.
    del actuals[n_done:]

    for idx, image_file in enumerate(all_images[n_done:], start=n_done + 1):
        # idx counts from the resume point, so the progress display stays correct.
        print(f"Processing {image_file}, {idx}/{len(all_images)}")

        # The label is the first character after the last "_" in the file name.
        label = int(image_file.split("_")[-1][0])

        # Close the image file deterministically once the request has been sent.
        with Image.open(image_folder+image_file) as test_image:
            final_prompt = prompt + [test_image]
            response = model.generate_content(final_prompt, request_options={"timeout": 1000})
        # If the request or the .text access raises, nothing has been recorded
        # for this image yet, so re-running the cell resumes cleanly from it.
        res = response.text

        # Extract the fenced JSON payload from the (newline-stripped) reply.
        ans = re.search(r"```json(.*?)```", res.replace("\n", ""))
        parsed_ans = None
        if ans is not None:
            try:
                parsed_ans = json.loads(ans.group(1))
            except json.JSONDecodeError:
                # Malformed JSON is treated the same as a missing answer.
                parsed_ans = None
        if not (isinstance(parsed_ans, dict) and isinstance(parsed_ans.get("tumor_type"), str)):
            parsed_ans = json.loads('{"tumor_type": "No Answer Received"}')
            print(f"Response: {res}")
        print(parsed_ans)

        # Record label and answer together, only after the answer is known.
        actuals.append(label)
        answers.append(parsed_ans)
        # Also keep a per-file record, as the 6-shot cell does.
        answers_dict[image_file] = parsed_ans
        # Pause between requests.
        time.sleep(4)

### Few-shot

In [None]:
import re
import json


# 3-shot evaluation: runs only when exactly three demonstrations are configured.
if len(demonstration_indices) == 3:
    # Use the configured seed (previously hard-coded to seed_8), matching the
    # zero-shot cell and the training-image directory.
    image_folder = f"/home/haovan/IF_medical_image/data_for_llm/{noise}/seed_{seed}/test_images/original/"

    # NOTE(review): the tumor types attached to the demonstrations are hard-coded
    # in the prompt text; confirm they match labels_train for the chosen indices.
    prompt = [
        "Predict the tumor type for a given brain MRI. The potential tumor types are 'meningioma', 'glioma', or 'pituitary'. Analyze the provided test MRI to determine the most likely diagnosis among the three tumor types.\n\nProvide the answer in the following JSON format:\n\n```json\n{\n  \"tumor_type\": \"<meningioma/glioma/pituitary>\"\n}\n```",
        "Here are demonstration examples with MRI images and their corresponding tumor types:\n\n- Image 1:",
        demonstration_images[0],
        "Tumor type: meningioma\n\n- Image 2:",
        demonstration_images[1],
        "Tumor type: glioma\n\n- Image 3:",
        demonstration_images[2],
        "Tumor type: pituitary\n\nTest image:"]

    all_images = os.listdir(image_folder)

    # Resume support: images whose answers were already collected are skipped.
    # NOTE(review): this relies on os.listdir returning the same order on every call.
    n_done = len(answers)
    # Drop labels left over from an interrupted earlier run so that actuals and
    # answers stay index-aligned.
    del actuals[n_done:]

    for idx, image_file in enumerate(all_images[n_done:], start=n_done + 1):
        # idx counts from the resume point, so the progress display stays correct.
        print(f"Processing {image_file}, {idx}/{len(all_images)}")

        # The label is the first character after the last "_" in the file name.
        label = int(image_file.split("_")[-1][0])

        # Close the image file deterministically once the request has been sent.
        with Image.open(image_folder+image_file) as test_image:
            final_prompt = prompt + [test_image]
            response = model.generate_content(final_prompt, request_options={"timeout": 1000})
        # If the request or the .text access raises, nothing has been recorded
        # for this image yet, so re-running the cell resumes cleanly from it.
        res = response.text

        # Extract the fenced JSON payload from the (newline-stripped) reply.
        ans = re.search(r"```json(.*?)```", res.replace("\n", ""))
        parsed_ans = None
        if ans is not None:
            try:
                parsed_ans = json.loads(ans.group(1))
            except json.JSONDecodeError:
                # Malformed JSON is treated the same as a missing answer.
                parsed_ans = None
        if not (isinstance(parsed_ans, dict) and isinstance(parsed_ans.get("tumor_type"), str)):
            parsed_ans = json.loads('{"tumor_type": "No Answer Received"}')
            print(f"Response: {res}")
        print(parsed_ans)

        # Record label and answer together, only after the answer is known.
        actuals.append(label)
        answers.append(parsed_ans)
        # Also keep a per-file record, as the 6-shot cell does.
        answers_dict[image_file] = parsed_ans
        # Pause between requests.
        time.sleep(4)

In [None]:

# 6-shot evaluation: runs only when exactly six demonstrations are configured.
if len(demonstration_indices) == 6:
    # The demonstrations are interleaved as images 0, 2, 4, 1, 3, 5 so the prompt
    # cycles through the three tumor types twice.
    # NOTE(review): the labels attached to the demonstrations are hard-coded in
    # the prompt text; confirm they match labels_train for the chosen indices.
    prompt = [
        "Predict the tumor type for a given brain MRI. The potential tumor types are 'meningioma', 'glioma', or 'pituitary'. You are provided with labeled examples to help with understanding the task, which demonstrates in-context learning through these examples. Analyze the provided test MRI to determine the most likely diagnosis among the three tumor types.\n\nProvide the answer in the following JSON format:\n\n```json\n{\n  \"tumor_type\": \"<meningioma/glioma/pituitary>\"\n}\n```",
        "Here are demonstration examples with MRI images and their corresponding tumor types:\n\n- Image 1:",
        demonstration_images[0],
        "Label: meningioma\n\n- Image 2:",
        demonstration_images[2],
        "Label: glioma\n\n- Image 3:",
        demonstration_images[4],
        "Label: pituitary\n\n- Image 4:",
        demonstration_images[1],
        "Label: meningioma\n\n- Image 5:",
        demonstration_images[3],
        "Label: glioma\n\n- Image 6:",
        demonstration_images[5],
        "Label: pituitary\n\nTest image:"]

    # Use the configured seed (previously hard-coded to seed_8), matching the
    # zero-shot cell and the training-image directory.
    image_folder = f"/home/haovan/IF_medical_image/data_for_llm/{noise}/seed_{seed}/test_images/original/"
    all_images = os.listdir(image_folder)

    # Resume support: images whose answers were already collected are skipped.
    # NOTE(review): this relies on os.listdir returning the same order on every call.
    n_done = len(answers)
    # Drop labels left over from an interrupted earlier run so that actuals and
    # answers stay index-aligned.
    del actuals[n_done:]

    for idx, image_file in enumerate(all_images[n_done:], start=n_done + 1):
        # idx counts from the resume point, so the progress display stays correct.
        print(f"Processing {image_file}, {idx}/{len(all_images)}")

        # The label is the first character after the last "_" in the file name.
        label = int(image_file.split("_")[-1][0])

        # Close the image file deterministically once the request has been sent.
        with Image.open(image_folder+image_file) as test_image:
            final_prompt = prompt + [test_image]
            response = model.generate_content(final_prompt, request_options={"timeout": 1000})
        # If the request or the .text access raises, nothing has been recorded
        # for this image yet, so re-running the cell resumes cleanly from it.
        res = response.text

        # Extract the fenced JSON payload from the (newline-stripped) reply.
        ans = re.search(r"```json(.*?)```", res.replace("\n", ""))
        parsed_ans = None
        if ans is not None:
            try:
                parsed_ans = json.loads(ans.group(1))
            except json.JSONDecodeError:
                # Malformed JSON is treated the same as a missing answer.
                parsed_ans = None
        if not (isinstance(parsed_ans, dict) and isinstance(parsed_ans.get("tumor_type"), str)):
            parsed_ans = json.loads('{"tumor_type": "No Answer Received"}')
            print(f"Response: {res}")
        print(parsed_ans)

        # Record label and answer together, only after the answer is known.
        actuals.append(label)
        answers.append(parsed_ans)
        answers_dict[image_file] = parsed_ans
        # Pause between requests.
        time.sleep(5)

In [None]:
# Map every parsed answer to a class index:
#   0 = meningioma, 1 = glioma, 2 = pituitary, -1 = no recognised tumor type.
predictions = []
listed_images = None  # directory listing, fetched once and only if it is needed
for idx, ans in enumerate(answers):
    # Tolerate answers without a usable 'tumor_type' entry instead of raising.
    tumor_type = ans.get('tumor_type', '') if isinstance(ans, dict) else ''
    tumor_type = str(tumor_type).lower()

    pred = -1
    if "meningioma" in tumor_type:
        pred = 0
    elif "glioma" in tumor_type:
        pred = 1
    elif "pituitary" in tumor_type:
        pred = 2

    if pred == -1:
        # Report the unrecognised answer together with the matching file name.
        print(ans)
        if listed_images is None:
            listed_images = os.listdir(image_folder)
        print(str(idx)+" --", listed_images[idx])
    predictions.append(pred)

# A length mismatch would otherwise broadcast silently or fail with an obscure error.
if len(actuals) != len(predictions):
    raise ValueError(
        f"actuals ({len(actuals)}) and predictions ({len(predictions)}) differ in length; "
        "the evaluation loop was probably interrupted"
    )

# Number of correct predictions, followed by the accuracy.
n_correct = int(np.sum(np.array(actuals) == np.array(predictions)))
print(n_correct)
if len(actuals) > 0:
    print(n_correct/len(actuals))
else:
    # Nothing was evaluated; avoid dividing by zero.
    print("No predictions to score.")