# Dataset Exploration

## 0. Set Up Environment

### 0.1. Import Necessary Libraries

In [1]:
import base64
import html
from collections import Counter
from io import BytesIO
from pathlib import Path
from typing import Optional

import pandas as pd
import plotly.express as px
from IPython.display import display, HTML
from PIL import Image
from datasets import Dataset, load_dataset, disable_progress_bars

In [2]:
# Load the ipyform IPython extension (presumably the provider of the
# %form_config magic used on the next line)
%load_ext ipyform
# NOTE(review): presumably turns on automatic detection of cells annotated with
# `# @title` / `# @param` so they render as forms -- confirm against the ipyform docs
%form_config --auto-detect 1

### 0.2. Define Helper Functions (TODO: extract to Python classes later on)

In [3]:
def resize_image(
    image: Image.Image,
    width: Optional[int] = None,
    height: Optional[int] = None
) -> Image.Image:
    """Resize ``image``, preserving the aspect ratio when only one side is given.

    Args:
        image: Source image.
        width: Target width in pixels. ``None`` or ``0`` means "not specified".
        height: Target height in pixels. ``None`` or ``0`` means "not specified".

    Returns:
        ``image`` itself, untouched, when neither dimension is specified;
        otherwise the result of ``image.resize`` with Lanczos resampling.
        When both dimensions are given they are used as-is, so the aspect
        ratio is not preserved in that case.

    Raises:
        ValueError: If ``width`` or ``height`` is negative.
    """
    if not (width or height):
        return image

    if (width or 0) < 0 or (height or 0) < 0:
        raise ValueError(
            f"width and height must be positive, got width={width}, height={height}"
        )

    original_width, original_height = image.size

    # int() truncates, so for very elongated images the derived side could
    # become 0, which Image.resize rejects; clamp it to at least one pixel.
    if not height:
        height = max(1, int((width / original_width) * original_height))
    elif not width:
        width = max(1, int((height / original_height) * original_width))

    return image.resize((width, height), Image.Resampling.LANCZOS)

In [4]:
def display_base64_image(
    base64_image: str,
    width: Optional[int] = None,
    height: Optional[int] = None
) -> None:
    """Decode a base64-encoded image and show it in the notebook output.

    Args:
        base64_image: Image file contents encoded as a base64 string.
        width: Optional target width forwarded to ``resize_image``.
        height: Optional target height forwarded to ``resize_image``.
    """
    raw_bytes = base64.b64decode(base64_image)
    decoded = Image.open(BytesIO(raw_bytes))
    display(resize_image(decoded, width, height))

In [5]:
def display_formatted_section(
    section_name: str,
    section_style: str,
    section_content: str | int
) -> None:
    section_text = f"""
    <div style='{section_style}'>
        <b>{section_name}:</b> {section_content}
    </div>
    """
    display(HTML(section_text))

In [6]:
def visualize_qa_pair_by_id(
    dataset: Dataset,
    id: int,
    image_width: Optional[int] = None,
    image_height: Optional[int] = None
) -> None:
    """Render one multiple-choice question of ``dataset`` in the notebook output.

    Shows, in order: the row id, the question text, the context image and the
    four answer options ``A``-``D``, with the option named by the row's
    ``correct_option`` field highlighted in bold green.

    Args:
        dataset: Dataset whose rows expose the fields ``index``, ``question``,
            ``image`` (base64-encoded), ``A``, ``B``, ``C``, ``D`` and
            ``correct_option``.
        id: Value of the ``index`` field of the row to show. If several rows
            match, the first one is used.
        image_width: Optional display width of the context image.
        image_height: Optional display height of the context image.

    Raises:
        ValueError: If no row has ``index == id``.
    """
    # Obtain row by index
    filtered_dataset = dataset.filter(lambda row: row['index'] == id)
    if len(filtered_dataset) == 0:
        raise ValueError(f"No row found with index {id}")
    row = filtered_dataset[0]

    # Display row id
    display_formatted_section(
        section_name="ID",
        section_style="margin: 20px 0;",
        section_content=row['index']
    )

    # Display question. The text comes straight from the dataset, so escape it:
    # characters such as '<', '>' or '&' would otherwise be parsed as markup.
    display_formatted_section(
        section_name="Question",
        section_style="margin-bottom: 20px;",
        section_content=html.escape(str(row['question']))
    )

    # Display context image (empty section acts as a heading for the image)
    display_formatted_section(
        section_name="Context Image",
        section_style="margin-bottom: 20px;",
        section_content=""
    )
    display_base64_image(
        base64_image=row['image'],
        width=image_width,
        height=image_height
    )

    # Display possible answers, highlighting the correct one
    formatted_options = []
    for option in ('A', 'B', 'C', 'D'):
        # Escape only the dataset text; the surrounding tags are our own markup.
        option_text = html.escape(str(row[option]))
        if option == row['correct_option']:
            formatted_options.append(
                f"<p style='color: rgb(0, 255, 0);'><b>{option}) {option_text}</b>"
            )
        else:
            formatted_options.append(f"<p>{option}) {option_text}")
    answer = "<br><br>" + "<br>".join(formatted_options)

    display_formatted_section(
        section_name="Possible Answers",
        section_style="margin-top: 30px;",
        section_content=answer
    )

## 1. WorldMedQA-V Dataset

### 1.1. Define Constants

In [9]:
# Directory holding the processed WorldMedQA-V TSV files (relative path)
DATASET_DIR = Path("data/WorldMedQA-V")
# Country and file-type parts of the dataset file name, which is built as
# "{COUNTRY}_{FILE_TYPE}_processed.tsv" when the dataset is loaded
COUNTRY = "spain"
FILE_TYPE = "english"

### 1.2. Load Dataset

In [10]:
# Build the path of the TSV file selected by COUNTRY and FILE_TYPE
dataset_filename = f"{COUNTRY}_{FILE_TYPE}_processed.tsv"
data_filepath = str(DATASET_DIR.joinpath(dataset_filename))

# Read the tab-separated file with the "csv" loader, then keep the "train"
# entry of the returned mapping
world_med_qa_v_splits = load_dataset("csv", data_files=[data_filepath], sep="\t")
world_med_qa_v_dataset = world_med_qa_v_splits["train"]
world_med_qa_v_dataset

Dataset({
    features: ['index', 'image', 'question', 'A', 'B', 'C', 'D', 'answer', 'correct_option', 'split'],
    num_rows: 125
})

### 1.3. Visualize QA Pair by ID

In [11]:
# Avoid the progress bar that would otherwise be shown each time
# visualize_qa_pair_by_id applies `dataset.filter` to look up a row
disable_progress_bars()

In [12]:
# @title QA Pair Visualizer Form
question_id = 33 # @param {"type":"integer"}
image_width = 600 # @param {"type":"integer"}

# question_id is matched against the dataset's 'index' field; only the image
# width is passed, so the height is derived from the image's aspect ratio
visualize_qa_pair_by_id(
    dataset=world_med_qa_v_dataset,
    id=question_id,
    image_width=image_width
)

FormWidget(children=(VBox(children=(HTML(value=''), HTML(value='<h2>QA Pair Visualizer Form</h2>'), Box(childr…

### 1.4. Correct Answer Distribution

In [15]:
# Count how many questions have each option as the correct answer and put the
# tallies in a frame ordered by option
correct_answer_distribution = Counter(world_med_qa_v_dataset['correct_option'])
correct_answer_distribution_df = pd.DataFrame({
    "correct_option": correct_answer_distribution.keys(),
    "count": correct_answer_distribution.values()
}).sort_values('correct_option')

# Pie chart with a hole in the middle, one colour per option
correct_answer_distribution_pie_chart = px.pie(
    data_frame=correct_answer_distribution_df,
    names='correct_option',
    values="count",
    title="Correct Answer Distribution",
    hole=0.45,
    category_orders={"correct_option": sorted(correct_answer_distribution.keys())},
    color='correct_option',
    color_discrete_sequence=px.colors.qualitative.Pastel
)

# Slice labels: percentage, option and raw count drawn inside each slice,
# with the same pull value applied to every slice
slice_pull = [0.05] * len(correct_answer_distribution_df)
slice_textfont = {"size": 18, "color": 'black', "weight": 'bold'}
correct_answer_distribution_pie_chart.update_traces(
    textposition='inside',
    textinfo='percent+label+value',
    pull=slice_pull,
    textfont=slice_textfont
)

# Figure layout: horizontal legend centred below the chart, centred title
legend_layout = {
    "title": 'Possible Answers',
    "orientation": 'h',
    "yanchor": 'bottom',
    "y": -0.2,
    "xanchor": 'center',
    "x": 0.5,
    "font": {"size": 14},
}
title_layout = {
    "x": 0.5,
    "font": {"size": 24, "color": "black"},
}
correct_answer_distribution_pie_chart.update_layout(
    legend=legend_layout,
    width=850,
    height=650,
    title=title_layout,
)

correct_answer_distribution_pie_chart.show()

## 2. WikiMed Dataset