In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
import torchvision.transforms as transforms
import torch.optim as optim
from tqdm import tqdm
import copy
from influence_functions import calc_influence_single_group_upweight, calc_influence_single_group_pert
from utils import set_attr
from argparse import Namespace
import math

import h5py
import matplotlib.pyplot as plt
import numpy as np
from PIL import Image
import os
from scipy.stats import rice
import pickle

seed = 111          # picks the demonstration_indices_dict entry and the seed_<seed>/ data folders
noise = 'gaussian'  # noise-type sub-folder under data_for_llm/
# Test-set switch read by the 6-shot cell: a value > 0 selects test_images/noisy/,
# 0 selects test_images/original/ (the folder the 3-shot cell always uses).
p_test = 0

In [None]:
# Training-set metadata for the chosen noise type / seed; only its 'labels_train' entry is used below.
# NOTE: pickle.load can execute arbitrary code — only load pickles produced by this project.
train_pkl_path = f"data_for_llm/{noise}/seed_{seed}/train_data_seed_{seed}.pkl"
with open(train_pkl_path, "rb") as pkl_file:
    train_data = pickle.load(pkl_file)

labels_train = train_data['labels_train']

In [None]:
# Training-set indices of the in-context demonstration images, keyed by seed.
# The few-shot prompts below label these images in class order
# (meningioma, ..., glioma, ..., pituitary, ...), so each list is expected to be grouped that way.
demonstration_indices_dict = {
    8: [1663, 1731, 2073, 2112, 1806, 1965],
    1686: [617, 618, 1815, 400, 1360, 217],
    1711: [1713, 3, 1309, 1306, 151, 1925],
    111: [96, 32, 441, 1707, 716, 1409],
}

demonstration_indices = demonstration_indices_dict[seed]

In [None]:
# Alternative demonstration indices, selected using the model trained on clean data:
# demonstration_indices = [47, 119, 1399]

In [None]:
import base64
def encode_image(image_path):
    """Read the file at ``image_path`` and return its raw bytes base64-encoded as a ``str``.

    The result is used as the payload of ``data:<mime>;base64,...`` image URLs.
    """
    with open(image_path, "rb") as fh:
        raw_bytes = fh.read()
    encoded = base64.b64encode(raw_bytes)
    return encoded.decode('utf-8')

In [None]:
from PIL import Image
import requests

# Folder of training JPEGs named "<index>_<label>.jpg". Despite the variable name,
# it points at the train_images/original/ sub-folder.
# NOTE(review): absolute, user-specific path, while the pickle above is loaded relative
# to the working directory — consider deriving both from one configurable base dir.
noisy_image_train_dir = f"/home/haovan/IF_medical_image/data_for_llm/{noise}/seed_{seed}/train_images/original/"

demonstration_images = []
demonstration_base64 = []
for demo_idx in demonstration_indices:
    # Build the path once so the preview image and the encoded payload come from the same file.
    demo_path = noisy_image_train_dir + f"{demo_idx}_{labels_train[demo_idx]}.jpg"
    # Image.open is lazy and keeps the file handle open; copy() loads the pixels into a
    # standalone image so the handle is released when the with-block exits.
    with Image.open(demo_path) as image:
        demonstration_images.append(image.copy())
    demonstration_base64.append(encode_image(demo_path))


In [None]:
# Preview the first demonstration image; as the cell's last expression it is rendered as the output.
demonstration_images[0]

### Few-shot (in-context) classification with GPT-4o

In [None]:
from openai import OpenAI
import re
import json

# OpenAI() reads the key from the OPENAI_API_KEY environment variable, so no credential
# has to be pasted into (and later committed with) the notebook.
client = OpenAI()

if len(demonstration_indices) == 3:
    # Tumor type the prompt assigns to each demonstration image, in prompt order.
    demo_labels = ["meningioma", "glioma", "pituitary"]
    # Integer class per tumor type, matching the scoring cell and the "<index>_<label>" file names.
    label_ids = {"meningioma": 0, "glioma": 1, "pituitary": 2}
    stored_labels = [int(labels_train[i]) for i in demonstration_indices]
    if stored_labels != [label_ids[name] for name in demo_labels]:
        print(f"WARNING: stored demonstration labels {stored_labels} do not match the prompt order {demo_labels}")

    system_prompt = (
        "Predict the tumor type for a given brain MRI (all MRIs contain tumors). "
        "The potential tumor types are 'meningioma', 'glioma', or 'pituitary'. "
        "You are provided with labeled examples to help with understanding the task, "
        "which demonstrates in-context learning through these examples. "
        "Analyze the provided test MRI to determine the most likely diagnosis among the three tumor types.\n\n"
        "# Input Format\n\n"
        "Here are demonstration examples with MRI images and their corresponding tumor types:\n\n"
        "- Example 1:\n  <image-1>  \n  Tumor Type: meningioma\n\n"
        "- Example 2:\n  <image-2>  \n  Tumor Type: glioma \n\n"
        "- Example 3:\n  <image-3>  \n  Tumor Type: pituitary \n\n"
        "Test MRI:\n<image-test>\n\n"
        "# Steps \n\n"
        "1. Review the provided demonstration examples, analyzing the features that correspond to each tumor type. \n"
        "2. Consider the given test MRI and compare its features with examples to determine the most similar tumor type.\n"
        "3. Make a prediction for the test MRI based on observed patterns.\n\n"
        "# Output Format\n\n"
        "Provide the answer in the following JSON format:\n\n"
        "```json\n{\n  \"tumor_type\": \"<meningioma/glioma/pituitary>\"\n}\n```\n\n"
        "# Notes\n\n"
        "- Ensure your prediction is based solely on the features of the given test MRI.\n"
        "- Consider all three possible tumor types before arriving at a conclusion.\n"
        "- Use the provided examples effectively to enhance your understanding before making a prediction."
    )

    # The demonstration part of the user message is the same for every test image, so build it once.
    # Demonstration files are *.jpg, so they are sent as image/jpeg (like the test image).
    demo_content = [
        {
            "type": "text",
            "text": "Here are demonstration examples with MRI images and their corresponding tumor types:\n\n- Example 1:",
        }
    ]
    for k, (demo_b64, demo_label) in enumerate(zip(demonstration_base64, demo_labels)):
        next_header = f"- Example {k + 2}:" if k + 1 < len(demo_labels) else "Test MRI:"
        demo_content.append({"type": "image_url", "image_url": {"url": f"data:image/jpeg;base64,{demo_b64}"}})
        demo_content.append({"type": "text", "text": f"Tumor type: {demo_label}\n\n{next_header}"})

    actuals = []
    answers = []
    image_folder = f"/home/haovan/IF_medical_image/data_for_llm/{noise}/seed_{seed}/test_images/original/"
    all_test = os.listdir(image_folder)

    for idx, image_file in enumerate(all_test):
        print(f"Processing {image_file}, {(idx+1)}/{len(all_test)}")
        base64_image = encode_image(image_folder + image_file)
        test_part = {"type": "image_url", "image_url": {"url": f"data:image/jpeg;base64,{base64_image}"}}

        response = client.chat.completions.create(
            model="gpt-4o",
            messages=[
                {"role": "system", "content": [{"type": "text", "text": system_prompt}]},
                {"role": "user", "content": demo_content + [test_part]},
            ],
            # NOTE(review): temperature 1.2 makes predictions vary between runs.
            temperature=1.2,
            max_tokens=64,
            top_p=1,
            frequency_penalty=0,
            presence_penalty=0,
            response_format={"type": "text"},
        )

        # message.content is Optional in the SDK; treat a missing reply as empty text.
        res = response.choices[0].message.content or ""
        parsed_ans = None
        fenced = re.search(r"```json(.*?)```", res.replace("\n", ""))
        if fenced is not None:
            try:
                parsed_ans = json.loads(fenced.group(1))
            except json.JSONDecodeError:
                parsed_ans = None
        # Fall back to a placeholder answer for a missing/invalid JSON block or a missing
        # "tumor_type" key instead of aborting the whole run on one bad reply.
        if not isinstance(parsed_ans, dict) or "tumor_type" not in parsed_ans:
            parsed_ans = {"tumor_type": "No Answer Received"}
            print(f"Response: {res}")
        print(parsed_ans)
        answers.append(parsed_ans)
        # Record the ground truth only after the answer so both lists always have equal length.
        actuals.append(int(image_file.split("_")[-1][0]))

In [None]:
if len(demonstration_indices) == 6:
    # Tumor type the prompt assigns to each demonstration image, in prompt order.
    demo_labels = ["meningioma", "meningioma", "glioma", "glioma", "pituitary", "pituitary"]
    # Integer class per tumor type, matching the scoring cell and the "<index>_<label>" file names.
    label_ids = {"meningioma": 0, "glioma": 1, "pituitary": 2}
    stored_labels = [int(labels_train[i]) for i in demonstration_indices]
    if stored_labels != [label_ids[name] for name in demo_labels]:
        print(f"WARNING: stored demonstration labels {stored_labels} do not match the prompt order {demo_labels}")

    system_prompt = (
        "Predict the tumor type for a given brain MRI (all MRIs contain tumors). "
        "The potential tumor types are 'meningioma', 'glioma', or 'pituitary'. "
        "You are provided with labeled examples to help with understanding the task, "
        "which demonstrates in-context learning through these examples. "
        "Analyze the provided test MRI to determine the most likely diagnosis among the three tumor types.\n\n"
        "# Input Format\n\n"
        "Here are demonstration examples with MRI images and their corresponding tumor types:\n\n"
        "- Example 1:\n  <image-1>  \n  Tumor Type: meningioma\n\n"
        "- Example 2:\n  <image-2>  \n  Tumor Type: meningioma\n\n"
        "- Example 3:\n  <image-3>  \n  Tumor Type: glioma \n\n"
        "- Example 4:\n  <image-4>  \n  Tumor Type: glioma \n\n"
        "- Example 5:\n  <image-5>  \n  Tumor Type: pituitary \n\n"
        "- Example 6:\n  <image-6>  \n  Tumor Type: pituitary \n\n"
        "Test MRI:\n<image-test>\n\n"
        "# Steps \n\n"
        "1. Review the provided demonstration examples, analyzing the features that correspond to each tumor type. \n"
        "2. Consider the given test MRI and compare its features with examples to determine the most similar tumor type.\n"
        "3. Make a prediction for the test MRI based on observed patterns.\n\n"
        # Section break is "\n\n" (previously a garbled "\tun\n"), matching the 3-shot prompt.
        "# Output Format\n\n"
        "Provide the answer in the following JSON format:\n\n"
        "```json\n{\n  \"tumor_type\": \"<meningioma/glioma/pituitary>\"\n}\n```\n\n"
        "# Notes\n\n"
        "- Ensure your prediction is based solely on the features of the given test MRI.\n"
        "- Consider all three possible tumor types before arriving at a conclusion.\n"
        "- Use the provided examples effectively to enhance your understanding before making a prediction."
    )

    # The demonstration part of the user message is the same for every test image, so build it once.
    # Demonstration files are *.jpg, so they are sent as image/jpeg (like the test image).
    demo_content = [
        {
            "type": "text",
            "text": "Here are demonstration examples with MRI images and their corresponding tumor types:\n\n- Example 1:",
        }
    ]
    for k, (demo_b64, demo_label) in enumerate(zip(demonstration_base64, demo_labels)):
        next_header = f"- Example {k + 2}:" if k + 1 < len(demo_labels) else "Test MRI:"
        demo_content.append({"type": "image_url", "image_url": {"url": f"data:image/jpeg;base64,{demo_b64}"}})
        demo_content.append({"type": "text", "text": f"Tumor type: {demo_label}\n\n{next_header}"})

    actuals = []
    answers = []
    # p_test may not be defined earlier in the notebook; default to 0 (the clean
    # "original" test images, as the 3-shot cell uses) instead of raising NameError.
    ori_or_noisy = 'noisy' if globals().get("p_test", 0) > 0 else 'original'
    image_folder = f"/home/haovan/IF_medical_image/data_for_llm/{noise}/seed_{seed}/test_images/{ori_or_noisy}/"

    all_test = os.listdir(image_folder)

    for idx, image_file in enumerate(all_test):
        print(f"Processing {image_file}, {idx+1}/{len(all_test)}")
        base64_image = encode_image(image_folder + image_file)
        test_part = {"type": "image_url", "image_url": {"url": f"data:image/jpeg;base64,{base64_image}"}}

        response = client.chat.completions.create(
            model="gpt-4o",
            messages=[
                {"role": "system", "content": [{"type": "text", "text": system_prompt}]},
                {"role": "user", "content": demo_content + [test_part]},
            ],
            # NOTE(review): temperature 1.2 makes predictions vary between runs.
            temperature=1.2,
            max_tokens=64,
            top_p=1,
            frequency_penalty=0,
            presence_penalty=0,
            response_format={"type": "text"},
        )

        # message.content is Optional in the SDK; treat a missing reply as empty text.
        res = response.choices[0].message.content or ""
        parsed_ans = None
        fenced = re.search(r"```json(.*?)```", res.replace("\n", ""))
        if fenced is not None:
            try:
                parsed_ans = json.loads(fenced.group(1))
            except json.JSONDecodeError:
                parsed_ans = None
        # Fall back to a placeholder answer for a missing/invalid JSON block or a missing
        # "tumor_type" key instead of aborting the whole run on one bad reply.
        if not isinstance(parsed_ans, dict) or "tumor_type" not in parsed_ans:
            parsed_ans = {"tumor_type": "No Answer Received"}
            print(f"Response: {res}")
        print(parsed_ans)
        answers.append(parsed_ans)
        # Record the ground truth only after the answer so both lists always have equal length.
        actuals.append(int(image_file.split("_")[-1][0]))

In [None]:
# Map each model answer to the integer class encoded in the test file names
# (0 = meningioma, 1 = glioma, 2 = pituitary); -1 marks an unusable answer.
tumor_type_ids = {"meningioma": 0, "glioma": 1, "pituitary": 2}

# Answers and ground-truth labels must line up one-to-one with the processed test files.
assert len(answers) == len(actuals), f"{len(answers)} answers vs {len(actuals)} ground-truth labels"

predictions = []
for idx, ans in enumerate(answers):
    raw_type = ans.get('tumor_type', '') if isinstance(ans, dict) else ans
    answer_text = str(raw_type).lower()
    mentioned = [label for name, label in tumor_type_ids.items() if name in answer_text]
    # An answer naming several types (e.g. an echoed "<meningioma/glioma/pituitary>" template)
    # is not a prediction; a first-match rule would silently score it as meningioma.
    pred = mentioned[0] if len(mentioned) == 1 else -1
    if pred == -1:
        print(ans)
        # all_test is the file list the inference loop enumerated, so idx indexes it directly.
        print(str(idx)+" --", all_test[idx])
    predictions.append(pred)

n_correct = int(np.sum(np.array(actuals) == np.array(predictions)))
print(n_correct)
print(n_correct / len(actuals) if actuals else float("nan"))