In [1]:
import re
import base64
import json
import random
import requests

OUTPUT_DIR = '../exp_output/2023-11-23-14_35_26/random_1/Xray'
api_key = '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'

In [2]:

def get_gpt_response_for_covid(image_path_list, prompt, api_key=None):
    def encode_image(image_path):
        with open(image_path, "rb") as image_file:
            return base64.b64encode(image_file.read()).decode('utf-8')

    # Getting the base64 string

    base64_image_0 = encode_image(image_path_list[0])
    base64_image_1 = encode_image(image_path_list[1])
    base64_image_2 = encode_image(image_path_list[2])

    headers = {
        "Content-Type": "application/json",
        "Authorization": f"Bearer {api_key}"
    }

    payload = {
        "model": "gpt-4-vision-preview",
        "messages": [
            {
                "role": "user",
                "content": [
                    {
                        "type": "text",
                        "text": f"{prompt}"
                    },
                    {
                        "type": "image_url",
                        "image_url": {
                            "url": f"data:image/jpeg;base64,{base64_image_0}"
                        }
                    },
                    {
                        "type": "image_url",
                        "image_url": {
                            "url": f"data:image/jpeg;base64,{base64_image_1}"
                        }
                    },
                    {
                        "type": "image_url",
                        "image_url": {
                            "url": f"data:image/jpeg;base64,{base64_image_2}"
                        }
                    }
                ]
            }
        ],
        "max_tokens": 500
    }

    response = requests.post(
        "https://api.openai.com/v1/chat/completions", headers=headers, json=payload)
    return response


def restructure_data(data):
    new_structure = {}
    target_samples = []

    for key, sample in data['samples'].items():
        domain = sample['domain']
        class_name = sample['class']
        image = OUTPUT_DIR+'/'+sample['image']
        sample['image'] = image
        if domain not in new_structure:
            new_structure[domain] = {}
        if class_name not in new_structure[domain]:
            new_structure[domain][class_name] = []

        new_structure[domain][class_name].append(image)

        # Collect samples from the target domain
        if domain == 'target':
            target_samples.append(sample)

    return new_structure, target_samples

In [3]:
def search_pred_info(item_id, answer_by_gpt, class_names):
    predicted_class = None
    for class_name in class_names:
        # Pattern to match 'Answer Choice: [class_name]' or 'Answer Choice: class_name' (case-insensitive)
        pattern = re.compile(
            r"Answer Choice:\s*(?:\[)?'?\"?" +
            re.escape(class_name) + r"'?\"?(?:\])?",
            re.IGNORECASE
        )
        if pattern.search(answer_by_gpt):
            predicted_class = class_name
            break
    if not predicted_class:
        print(
            'Query failed for item {}; Check the answer or the pattern matching carefully!'.format(item_id))
    # Regular expression patterns to extract Confidence Score (0~1) and Reasoning
    confidence_score_pattern = r'Confidence Score:\s*([0-9]*\.?[0-9]+)'
    reasoning_pattern = r'Reasoning:\s*(.+)'

    # Extract Confidence Score
    confidence_score_match = re.search(
        confidence_score_pattern, answer_by_gpt, re.DOTALL)
    if confidence_score_match:
        confidence_score = confidence_score_match.group(1).strip()
    else:
        confidence_score = None

    # Extract Reasoning
    reasoning_match = re.search(reasoning_pattern, answer_by_gpt, re.DOTALL)
    if reasoning_match:
        reasoning = reasoning_match.group(1).strip()
    else:
        reasoning = None
    return predicted_class, confidence_score, reasoning

In [4]:

# Load the JSON file
with open(f'{OUTPUT_DIR}/unified_input_Xray.json', 'r') as f:
    data = json.load(f)
domain_names = data['domains']
class_names = data['class_names']
# Assuming 'data' is your original dictionary
restructured_data, target_samples = restructure_data(data)
print(restructured_data)
print(target_samples[0])

{'source': {'Normal': ['../exp_output/2023-11-23-14_35_26/random_1/Xray/source/Normal/NORMAL2-IM-0895-0001.jpeg', '../exp_output/2023-11-23-14_35_26/random_1/Xray/source/Normal/NORMAL2-IM-0501-0001.jpeg', '../exp_output/2023-11-23-14_35_26/random_1/Xray/source/Normal/IM-0433-0001.jpeg', '../exp_output/2023-11-23-14_35_26/random_1/Xray/source/Normal/NORMAL2-IM-0808-0001.jpeg', '../exp_output/2023-11-23-14_35_26/random_1/Xray/source/Normal/NORMAL2-IM-0885-0001.jpeg', '../exp_output/2023-11-23-14_35_26/random_1/Xray/source/Normal/IM-0599-0001.jpeg', '../exp_output/2023-11-23-14_35_26/random_1/Xray/source/Normal/IM-0363-0001.jpeg', '../exp_output/2023-11-23-14_35_26/random_1/Xray/source/Normal/IM-0220-0001.jpeg', '../exp_output/2023-11-23-14_35_26/random_1/Xray/source/Normal/IM-0588-0001.jpeg', '../exp_output/2023-11-23-14_35_26/random_1/Xray/source/Normal/NORMAL2-IM-0576-0001.jpeg'], 'Pneumonia': ['../exp_output/2023-11-23-14_35_26/random_1/Xray/source/Pneumonia/person1307_virus_2251.jpeg

In [5]:
class_names_in_batch = list(restructured_data['source'].keys())
print(f'class_names in batch: {class_names_in_batch}')
if len(class_names_in_batch) != len(class_names):
    print('Not covering all classes in random samples')
    
data = []
for target_sample in target_samples:
    selected_source_samples = [random.choice(
        restructured_data['source'][class_name]) for class_name in class_names_in_batch]
    selected_source_samples.append(target_sample['image'])
    selected_samples = selected_source_samples
    # print('\n--------------------\n')
    # print(target_sample)
    unified_output = {}
    first_class, second_class = class_names_in_batch[0], class_names_in_batch[1]
    prompt = f"""Given the images, answer the following question, using the specified format. 
    The first image is {first_class} and the second image is {second_class}.
    Question: What is the third image? Choices: {class_names}. 

    Please respond with the following format for each image:
    ---BEGIN FORMAT TEMPLATE---
    Answer Choice: [Your Answer Choice Here]
    Confidence Score: [Your Numerical Prediction Confidence Score Here From 0 To 1]
    Reasoning: [Your Reasoning Behind This Answer Here]
    ---END FORMAT TEMPLATE---

    Do not deviate from the above format. Repeat the format template for the answer.
    """
    response = get_gpt_response_for_covid(
        image_path_list=selected_samples, prompt=prompt, api_key=api_key).json()
    if 'error' in response:
        print(f'gpt-4v returns server error')
    else:
        answer_by_gpt = response['choices'][0]['message']['content']
        print(f'answer by gpt-4v:\n{answer_by_gpt}')
        predicted_class, confidence_score, reasoning = search_pred_info(
            0, answer_by_gpt, class_names)
        
        unified_output['dataset'] = 'Xray'
        unified_output['domain'] = target_sample['domain']
        unified_output['subject'] = target_sample['subject']
        unified_output['true_class'] = target_sample['class']
        unified_output['predicted_class'] = predicted_class
        unified_output['confidence_score'] = confidence_score
        unified_output['reasoning'] = reasoning
        unified_output['image'] = target_sample['image']
        data.append(unified_output)
        # break
        # # unified_output['id'] = item_id
        # mode = 'w' if first_flag else 'a'
        # with open(f'{OUTPUT_DIR}/unified_output_gpt-4-vision-preview-incontext.jsonl', mode) as f:
        #     f.write(json.dumps(unified_output) + "\n")
        # first_flag = False

class_names in batch: ['Normal', 'Pneumonia']
Not covering all classes in random samples


answer by gpt-4v:
---BEGIN FORMAT TEMPLATE---
Answer Choice: Pneumonia
Confidence Score: 0.7
Reasoning: The third image is suggestive of pneumonia due to the presence of patchy areas of increased opacity within the lung fields, which is consistent with lung consolidation associated with pneumonia. However, it is important to note that definitive diagnosis should always be confirmed by a medical professional, and other clinical information might be necessary for accurate diagnosis.
---END FORMAT TEMPLATE---
answer by gpt-4v:
---BEGIN FORMAT TEMPLATE---
Answer Choice: Pneumonia
Confidence Score: 0.85
Reasoning: The third image displays radiographic features consistent with a pneumonia diagnosis, such as localized opacities and consolidation, which are absent in normal lung radiographs and distinct from the diffuse bilateral ground-glass opacities often associated with COVID-19 pneumonia. However, without proper medical training, this assessment may not be accurate.
---END FORMAT TEMPLATE

In [6]:
# Initialize counters and structures
total_count = 0
correct_count = 0
dataset_domain_class_accuracy = {}

# Process each entry in the data
for entry in data:
    dataset = entry['dataset']
    domain = entry['domain']
    true_class = entry['true_class']
    predicted_class = entry['predicted_class']

    # Count for overall accuracy
    total_count += 1
    correct_count += true_class == predicted_class

    # Initialize dataset, domain, and class if not exist
    if dataset not in dataset_domain_class_accuracy:
        dataset_domain_class_accuracy[dataset] = {
            'total': 0, 'correct': 0, 'domains': {}}
    if domain not in dataset_domain_class_accuracy[dataset]['domains']:
        dataset_domain_class_accuracy[dataset]['domains'][domain] = {
            'total': 0, 'correct': 0, 'classes': {}}
    if true_class not in dataset_domain_class_accuracy[dataset]['domains'][domain]['classes']:
        dataset_domain_class_accuracy[dataset]['domains'][domain]['classes'][true_class] = {
            'total': 0, 'correct': 0}

    # Count for dataset, domain, and class accuracy
    dataset_domain_class_accuracy[dataset]['total'] += 1
    dataset_domain_class_accuracy[dataset]['correct'] += true_class == predicted_class
    dataset_domain_class_accuracy[dataset]['domains'][domain]['total'] += 1
    dataset_domain_class_accuracy[dataset]['domains'][domain]['correct'] += true_class == predicted_class
    dataset_domain_class_accuracy[dataset]['domains'][domain]['classes'][true_class]['total'] += 1
    dataset_domain_class_accuracy[dataset]['domains'][domain]['classes'][
        true_class]['correct'] += true_class == predicted_class

# Calculate overall accuracy
overall_accuracy = correct_count / total_count if total_count > 0 else 0

# Calculate accuracy for each dataset, domain, and class
for dataset, info in dataset_domain_class_accuracy.items():
    info['accuracy'] = info['correct'] / \
        info['total'] if info['total'] > 0 else 0
    for domain, domain_info in info['domains'].items():
        domain_info['accuracy'] = domain_info['correct'] / \
            domain_info['total'] if domain_info['total'] > 0 else 0
        for class_name, class_info in domain_info['classes'].items():
            class_info['accuracy'] = class_info['correct'] / \
                class_info['total'] if class_info['total'] > 0 else 0

# Save results to JSON file
output_data = {
    'overall_accuracy': overall_accuracy,
    'datasets': dataset_domain_class_accuracy
}
print(output_data)

{'overall_accuracy': 0.41935483870967744, 'datasets': {'Xray': {'total': 31, 'correct': 13, 'domains': {'target': {'total': 31, 'correct': 13, 'classes': {'Pneumonia': {'total': 10, 'correct': 4, 'accuracy': 0.4}, 'Normal': {'total': 11, 'correct': 6, 'accuracy': 0.5454545454545454}, 'COVID19': {'total': 10, 'correct': 3, 'accuracy': 0.3}}, 'accuracy': 0.41935483870967744}}, 'accuracy': 0.41935483870967744}}}


In [7]:
with open(f'{OUTPUT_DIR}/results_gpt-4-vision-preview_incontext.json', 'w') as outfile:
    json.dump(output_data, outfile, indent=4)

In [8]:
def read_and_reformat_jsonl(file_path):
    reformatted_data = {}
    with open(file_path, 'r') as file:
        for line in file:
            record = json.loads(line)
            image_key = record.pop('image')
            reformatted_data[image_key] = record
    return reformatted_data


original_data = read_and_reformat_jsonl(
    f'{OUTPUT_DIR}/unified_output_gpt-4-vision-preview.jsonl')
total, num_orign, num_incontext = 0, 0, 0
for img, val in original_data.items():
    if val['domain'] != 'target':
        continue
    if val['true_class'] == val['predicted_class']:
        num_orign += 1
    for sample_incontext in data:
        if img in sample_incontext['image']:
            if sample_incontext['true_class'] == sample_incontext['predicted_class']:
                num_incontext += 1
    total += 1

In [9]:
print(total, num_orign, num_incontext)

28 11 12
