In [10]:
import os
import pandas as pd
import transformers 
import torch
import yaml
from shutil import copyfile
from tqdm import tqdm 
from openai import AzureOpenAI
import random
import json
from datasets import load_dataset, Dataset, DatasetDict

In [21]:
metadata_path = "data/keno_1000/Metadata_1000_only_new.csv"
output_dir = "data/keno_1000/vqa"
model_name = "jb-turbo-2024-04-09"  


client = AzureOpenAI(
    api_key="e849b8c4c4a04d3d817aa67d66189251",
    api_version="2024-02-01",
    azure_endpoint="https://jb-turbo-2024-04-09.openai.azure.com/",
)

# Replace the iteration section
total_images = 1000  # Number of images to process

In [4]:
question_types = [
    "Does the chest X-ray show cardiomegaly (yes or no)?",
    "Is there a pulmonary congestion (yes or no)?",
    "Is there a right-sided pleural effusion (yes or no)?",
    "Is there a left-sided pleural effusion (yes or no)?",
    "Are there opacities in the right lung (yes or no)?",
    "Are there opacities in the left lung (yes or no)?",
    "Is there a right-sided atelectasis (yes or no)?",
    "Is there a left-sided atelectasis (yes or no)?",
    "Is there a right-sided pneumothorax (yes or no)?",
    "Is there a left-sided pneumothorax (yes or no)?",
    "Is there a central venous catheter present in the image (yes or no)?",
    "Is there a gastric tube present in the image (yes or no)?",

    "What is the size of the heart (normal, borderline, enlarged, massively enlarged)?",
    "What is the severity of pulmonary congestion (none, questionable, mild, moderate, severe)?",
    "What is the severity of the pleural effusion on the right (none, questionable, mild, moderate, severe)?",
    "What is the severity of the pleural effusion on the left (none, questionable, mild, moderate, severe)?",
    "What is the severity of right-sided pulmonary opacities (none, questionable, mild, moderate, severe)?",
    "What is the severity of left-sided pulmonary opacities (none, questionable, mild, moderate, severe)?",
    "What is the severity of right-sided atelectasis (none, questionable, mild, moderate, severe)?",
    "What is the severity of left-sided atelectasis (none, questionable, mild, moderate, severe)?",

    "Which side has worse pleural effusion? (right, left, same severity, absent)?",
    "Which side has worse pulmonary opacities? (right, left, same severity, absent)?",
    "Which side has worse atelectasis? (right, left, same severity, absent)?",
]
print("Num question types:", len(question_types))

Num question types: 23


In [5]:
metadata_df = pd.read_csv(metadata_path)
metadata_df.set_index("UID", inplace=True)


In [22]:
processed = 100
cardio_map = {-1: "not assessable", 0: "normal", 1: "borderline", 2: "enlarged", 4: "massively enlarged"}
other_map = {0: "none", 1: "questionable", 2: "mild", 3: "moderate", 4: "severe"}
pneumo_map = {0: "no", 1: "there is a"}
def describe_row(row):
    parts = [f"Patient age {int(row['Age'])//365} years"]
    parts.append(f"{cardio_map.get(row['cardiomegaly2'], 'unknown')} heart size")
    parts.append(f"{other_map.get(row['congestion2'], 'unknown')} congestion")
    parts.append(f"{other_map.get(row['pleural_effusion_right2'], 'unknown')} right pulmonary opacities")
    parts.append(f"{other_map.get(row['pleural_effusion_left2'], 'unknown')} left pulmonary opacities")
    parts.append(f"{other_map.get(row['pneumonic_infiltrates_right2'], 'unknown')} right pneumonic infiltrates")
    parts.append(f"{other_map.get(row['pneumonic_infiltrates_left2'], 'unknown')} left pneumonic infiltrates")
    parts.append(f"{other_map.get(row['atelectasis_right2'], 'unknown')} right atelectasis")
    parts.append(f"{other_map.get(row['atelectasis_left2'], 'unknown')} left atelectasis")
    parts.append(f"{pneumo_map.get(row['pneumothorax_right'], 'unknown')} right pneumothorax")
    parts.append(f"{pneumo_map.get(row['pneumothorax_left'], 'unknown')} left pneumothorax")
    parts.append(f"{(row['Sonstiges'], 'unknown')}")
    return "Clinical data: " + ", ".join(parts) + "."

def generate_answer(question, row):
    clinical_info = describe_row(row)
    if not clinical_info:
        return "No clinical information available for this image."
    prompt = f"Given the following clinical information, answer the question. Assume support devices not mentioned as absent: {clinical_info}\nQuestion: {question}\nAnswer:"
    response = client.chat.completions.create(
        model=model_name,
        messages=[{"role": "user", "content": prompt}],
        max_tokens=50,
        temperature=0.0,
    )
    answer = response.choices[0].message.content.strip()
    return answer
# Sample annotation generation script
def generate_vqa_pairs(metadata_df, question_types, num_questions_per_image=5):
    # check if output file exists and if yes, load it
    if not os.path.exists(output_dir):
        os.makedirs(output_dir)
    if os.path.exists(os.path.join(output_dir, "vqa_eval_set.jsonl")):
        with open(os.path.join(output_dir, "vqa_eval_set.jsonl"), "r") as f:
            vqa_data = [json.loads(line) for line in f]
        print(f"Loaded {len(vqa_data)} existing VQA samples from {output_dir}")
    else:
        vqa_data = []
    for uid, row in metadata_df.iterrows():
        # check if uid already exists in vqa_data
        if any(sample['image_id'] == uid for sample in vqa_data):
            continue
        selected_questions = random.sample(question_types, num_questions_per_image)
        for question in selected_questions:
            answer = generate_answer(question, row)
            question_type = question.split(" ")[0]  # Extract the type of question (e.g., "Does", "Is", "What")
            if question_type.lower() in ["does", "is", "are"]:
                question_type = "binary"
                if answer.lower() in ["yes", "no"]:
                    answer = "yes" if "yes" in answer.lower() else "no"
                # else discard the answer if it doesn't match expected binary response and try another question
                else:
                    continue
            elif question_type.lower() in ["what"]:
                question_type = "ordinal"
                if answer.lower() in ["none", "questionable", "mild", "moderate", "severe", "normal", "borderline", "enlarged", "massively enlarged"]:
                    answer = answer.lower()
                else:
                    continue
            else:
                question_type = "comparison"
                if answer.lower() in ["right", "left", "same severity", "absent"]:
                    answer = answer.lower()
                else:
                    continue
            vqa_sample = {
                "image_id": uid,
                "Question": question,
                "Answer": answer,  
                "Type": question_type,
                "PatientID": row['PatientID'],
                "Age": row['Age'],
                "Cardiomegaly": row['cardiomegaly2'],
                "PulmonaryCongestion": row['congestion2'],
                "PleuralEffusion_Right": row['pleural_effusion_right2'],
                "PleuralEffusion_Left": row['pleural_effusion_left2'],
                "PulmonaryOpacities_Right": row['pneumonic_infiltrates_right2'],
                "PulmonaryOpacities_Left": row['pneumonic_infiltrates_left2'],
                "Atelectasis_Right": row['atelectasis_right2'],
                "Atelectasis_Left": row['atelectasis_left2'],
                "Pneumothorax_Right": row['pneumothorax_right'],
                "Pneumothorax_Left": row['pneumothorax_left'],
                "Comments": row['Sonstiges'],
            }
            vqa_data.append(vqa_sample)
        global processed
        processed += 1
        if processed % 10 == 0:
            print(f"Processed {processed}/{total_images} images")
        if processed >= total_images:
            break
    print(f"Generated {len(vqa_data)} VQA samples")

    return vqa_data

# Generate samples
vqa_samples = generate_vqa_pairs(metadata_df, question_types, num_questions_per_image=8)

# Save to JSONL
output_file = os.path.join(output_dir, "vqa_eval_set.jsonl")
with open(output_file, "w") as f:
    for entry in vqa_samples:
        f.write(json.dumps(entry) + "\n")

# Display sample
vqa_samples[:1]


Loaded 637 existing VQA samples from data/keno_1000/vqa
Processed 110/1000 images
Processed 120/1000 images
Processed 130/1000 images
Processed 140/1000 images
Processed 150/1000 images
Processed 160/1000 images
Processed 170/1000 images
Processed 180/1000 images
Processed 190/1000 images
Processed 200/1000 images
Processed 210/1000 images
Processed 220/1000 images
Processed 230/1000 images
Processed 240/1000 images
Processed 250/1000 images
Processed 260/1000 images
Processed 270/1000 images
Processed 280/1000 images
Processed 290/1000 images
Processed 300/1000 images
Processed 310/1000 images
Processed 320/1000 images
Processed 330/1000 images
Processed 340/1000 images
Processed 350/1000 images
Processed 360/1000 images
Processed 370/1000 images
Processed 380/1000 images
Processed 390/1000 images
Processed 400/1000 images
Processed 410/1000 images
Processed 420/1000 images
Processed 430/1000 images
Processed 440/1000 images
Processed 450/1000 images
Processed 460/1000 images
Processe

[{'image_id': 'd4507d8df405e051cdbbcc9d07ef5303438b26938fd35e8960ed96a0271b8bf6',
  'Question': 'Is there a pulmonary congestion (yes or no)?',
  'Answer': 'yes',
  'Type': 'binary',
  'PatientID': 'b9ad4fa9-2312-4706-8c15-ba962861eca7',
  'Age': 24124,
  'Cardiomegaly': 1,
  'PulmonaryCongestion': 2,
  'PleuralEffusion_Right': 3,
  'PleuralEffusion_Left': 2,
  'PulmonaryOpacities_Right': 3,
  'PulmonaryOpacities_Left': 0,
  'Atelectasis_Right': 2,
  'Atelectasis_Left': 1,
  'Pneumothorax_Right': 0,
  'Pneumothorax_Left': 0,
  'Comments': 'Tracheostomy tubus in proper position. Gastric tube. '}]

In [23]:
print(f"Generated {len(vqa_samples)} VQA samples and saved to {output_file}")
vqa_dict = {
        'UID': [], 'Question': [], 'Answer': [], 'Type': [],
        'PatientID': [], 'Age': [],
        'HeartSize': [], 'PulmonaryCongestion': [],
        'PleuralEffusion_Right': [], 'PleuralEffusion_Left': [],
        'PulmonaryOpacities_Right': [], 'PulmonaryOpacities_Left': [],
        'Atelectasis_Right': [], 'Atelectasis_Left': []
    }
for sample in vqa_samples:
    vqa_dict['UID'].append(sample['image_id'])
    vqa_dict['Question'].append(sample['Question'])
    vqa_dict['Answer'].append(sample['Answer'])
    vqa_dict['Type'].append(sample['Type'])
    vqa_dict['PatientID'].append(sample['PatientID'])
    vqa_dict['Age'].append(sample['Age'])
    vqa_dict['HeartSize'].append(sample['Cardiomegaly'])
    vqa_dict['PulmonaryCongestion'].append(sample['PulmonaryCongestion'])
    vqa_dict['PleuralEffusion_Right'].append(sample['PleuralEffusion_Right'])
    vqa_dict['PleuralEffusion_Left'].append(sample['PleuralEffusion_Left'])
    vqa_dict['PulmonaryOpacities_Right'].append(sample['PulmonaryOpacities_Right'])
    vqa_dict['PulmonaryOpacities_Left'].append(sample['PulmonaryOpacities_Left'])
    vqa_dict['Atelectasis_Right'].append(sample['Atelectasis_Right'])
    vqa_dict['Atelectasis_Left'].append(sample['Atelectasis_Left'])
vqa_dataset = Dataset.from_dict(vqa_dict)
print(vqa_dataset)

Generated 6364 VQA samples and saved to data/keno_1000/vqa/vqa_eval_set.jsonl
Dataset({
    features: ['UID', 'Question', 'Answer', 'Type', 'PatientID', 'Age', 'HeartSize', 'PulmonaryCongestion', 'PleuralEffusion_Right', 'PleuralEffusion_Left', 'PulmonaryOpacities_Right', 'PulmonaryOpacities_Left', 'Atelectasis_Right', 'Atelectasis_Left'],
    num_rows: 6364
})


In [24]:
image_dataset = load_dataset("jomoll/TAIX-reasoning-v3.0-expert", name="default")

# match the uids with the images
def create_merged_dataset(split: str = "train") -> Dataset:
    image_dataset_split = image_dataset[split]
    # keep all vqa samples that have a matching uid in the image dataset
    vqa_dataset_split = vqa_dataset.filter(lambda x: x['UID'] in image_dataset_split['UID'])
    # enrich vqa dataset with images and metadata
    uid2image = {row['UID']: row for row in image_dataset_split}

    merged_dataset = vqa_dataset_split.map(
        lambda x: {
            "Split": uid2image[x['UID']]['Split'],
            "PhysicianID": uid2image[x['UID']]['PhysicianID'],
            "StudyDate": uid2image[x['UID']]['StudyDate'],
            "Sex": uid2image[x['UID']]["Sex"],
            "Image": uid2image[x['UID']]['Image'],
        }
    )
    return merged_dataset

vqa_dataset_train = create_merged_dataset("train")
print(vqa_dataset_train)
vqa_dataset_val = create_merged_dataset("val")
vqa_dataset_test = create_merged_dataset("test")

Filter: 100%|██████████| 6364/6364 [00:02<00:00, 2417.25 examples/s]
Map: 100%|██████████| 4071/4071 [01:04<00:00, 62.90 examples/s] 


Dataset({
    features: ['UID', 'Question', 'Answer', 'Type', 'PatientID', 'Age', 'HeartSize', 'PulmonaryCongestion', 'PleuralEffusion_Right', 'PleuralEffusion_Left', 'PulmonaryOpacities_Right', 'PulmonaryOpacities_Left', 'Atelectasis_Right', 'Atelectasis_Left', 'Split', 'PhysicianID', 'StudyDate', 'Sex', 'Image'],
    num_rows: 4071
})


Filter: 100%|██████████| 6364/6364 [00:00<00:00, 8069.06 examples/s]
Map: 100%|██████████| 996/996 [00:08<00:00, 120.85 examples/s]
Filter: 100%|██████████| 6364/6364 [00:00<00:00, 6754.52 examples/s]
Map: 100%|██████████| 1297/1297 [00:13<00:00, 97.69 examples/s] 


In [25]:
dataset_dict = DatasetDict({
        'train': vqa_dataset_train,
        'val': vqa_dataset_val,
        'test': vqa_dataset_test
    })

dataset_dict.push_to_hub("jomoll/TAIX-VQA", private=True)

Map: 100%|██████████| 1357/1357 [00:00<00:00, 3454.35 examples/s]s]
Creating parquet from Arrow format: 100%|██████████| 14/14 [00:00<00:00, 80.34ba/s]
Map: 100%|██████████| 1357/1357 [00:00<00:00, 2515.56 examples/s] 3.83s/it]
Creating parquet from Arrow format: 100%|██████████| 14/14 [00:00<00:00, 73.95ba/s]
Map: 100%|██████████| 1357/1357 [00:00<00:00, 3423.50 examples/s] 4.27s/it]
Creating parquet from Arrow format: 100%|██████████| 14/14 [00:00<00:00, 58.13ba/s]
Uploading the dataset shards: 100%|██████████| 3/3 [00:12<00:00,  4.16s/it]
Map: 100%|██████████| 996/996 [00:00<00:00, 4548.94 examples/s]t/s]
Creating parquet from Arrow format: 100%|██████████| 10/10 [00:00<00:00, 84.79ba/s]
Uploading the dataset shards: 100%|██████████| 1/1 [00:02<00:00,  2.72s/it]
Map: 100%|██████████| 1297/1297 [00:00<00:00, 1787.79 examples/s]s]
Creating parquet from Arrow format: 100%|██████████| 13/13 [00:00<00:00, 67.09ba/s]
Uploading the dataset shards: 100%|██████████| 1/1 [00:03<00:00,  3.85s/

CommitInfo(commit_url='https://huggingface.co/datasets/jomoll/TAIX-VQA/commit/be63dd210b54ea1c99ce9ddb2c648fe0f6672307', commit_message='Upload dataset', commit_description='', oid='be63dd210b54ea1c99ce9ddb2c648fe0f6672307', pr_url=None, repo_url=RepoUrl('https://huggingface.co/datasets/jomoll/TAIX-VQA', endpoint='https://huggingface.co', repo_type='dataset', repo_id='jomoll/TAIX-VQA'), pr_revision=None, pr_num=None)