### Imports and definitions


In [None]:
import io
import textwrap
from pathlib import Path

import matplotlib.pyplot as plt
import pandas as pd
import requests
import seaborn as sns
import torch
from PIL import Image
from rich import print
from transformers import AutoModelForCausalLM, AutoProcessor, GenerationConfig

In [None]:
def download_image(url, timeout=30):
    """Download an image over HTTP(S) and return it as an RGB PIL image.

    Args:
        url: address of the image to fetch.
        timeout: seconds to wait for the server, forwarded to ``requests.get``.
            Without a timeout, requests waits indefinitely on a stalled
            connection. Pass ``None`` to wait without limit.

    Returns:
        The decoded image converted to RGB mode.

    Raises:
        requests.HTTPError: the server answered with a 4xx/5xx status.
        requests.Timeout: the server did not respond within ``timeout`` seconds.
        PIL.UnidentifiedImageError: the downloaded bytes are not a readable image.
    """
    resp = requests.get(url, timeout=timeout)
    resp.raise_for_status()
    # Decode from memory; convert() forces a full load, so the BytesIO can be discarded.
    return Image.open(io.BytesIO(resp.content)).convert("RGB")


def generate(images, prompt, processor, model, device, dtype, generation_config):
    """Ask the model one question about the given images and return the decoded answer.

    Args:
        images: sequence of images; only the first two are handed to the processor
            (presumably the model's maximum number of views — confirm).
        prompt: the user question, embedded in the " USER: <s>... ASSISTANT: <s>" template.
        processor: the model's processor; builds the inputs and exposes ``.tokenizer``.
        model: the loaded model whose ``generate`` method is called.
        device: passed as ``device=`` to ``.to()`` on the processed inputs.
        dtype: passed as ``dtype=`` to ``.to()`` on the processed inputs.
        generation_config: forwarded unchanged to ``model.generate``.

    Returns:
        The first generated sequence decoded to text, with special tokens skipped.
    """
    chat_text = f" USER: <s>{prompt} ASSISTANT: <s>"
    batch = processor(images=images[:2], text=chat_text, return_tensors="pt")
    batch = batch.to(device=device, dtype=dtype)
    sequences = model.generate(**batch, generation_config=generation_config)
    first_sequence = sequences[0]
    return processor.tokenizer.decode(first_sequence, skip_special_tokens=True)


def main(anatomies=None):
    """Run CheXagent on one example chest X-ray and print a description per region.

    Loads the "StanfordAIMI/CheXagent-8b" processor, generation config and model
    (float16, moved to "cuda", so a CUDA device is required), downloads one
    example image, and prints the model's answer to 'Describe "<region>"' for
    each region.

    Args:
        anatomies: region names to describe, one prompt each. Defaults to the
            five-region list (Airway, Breathing, Cardiac, Diaphragm, Everything
            else). Taking it as an argument means the function no longer depends
            on a notebook-level global being defined before it is called.
    """
    if anatomies is None:
        anatomies = [
            "Airway",
            "Breathing",
            "Cardiac",
            "Diaphragm",
            "Everything else (e.g., mediastinal contours, bones, soft tissues, tubes, valves, and pacemakers)",
        ]

    # step 1: Set up constants
    device = "cuda"
    dtype = torch.float16

    # step 2: Load processor, generation config and model from the same Hub checkpoint
    checkpoint = "StanfordAIMI/CheXagent-8b"
    processor = AutoProcessor.from_pretrained(checkpoint, trust_remote_code=True)
    generation_config = GenerationConfig.from_pretrained(checkpoint)
    model = AutoModelForCausalLM.from_pretrained(checkpoint, torch_dtype=dtype, trust_remote_code=True).to(device)

    # step 3: Fetch the images
    image_path = "https://upload.wikimedia.org/wikipedia/commons/3/3b/Pleural_effusion-Metastatic_breast_carcinoma_Case_166_%285477628658%29.jpg"
    images = [download_image(image_path)]

    # step 4: Generate the Findings section, one prompt per region
    for anatomy in anatomies:
        prompt = f'Describe "{anatomy}"'
        response = generate(images, prompt, processor, model, device, dtype, generation_config)
        print(f"Generating the Findings for [{anatomy}]:")
        print(response)

#### `main()` call


In [None]:
# Anatomical regions to describe, one prompt each, in A-B-C-D-E order
# (Airway, Breathing, Cardiac, Diaphragm, Everything else).
anatomies = [
    "Airway", "Breathing", "Cardiac", "Diaphragm",
    "Everything else (e.g., mediastinal contours, bones, soft tissues, tubes, valves, and pacemakers)",
]

main()

### [NIH Chest X-ray dataset](https://www.kaggle.com/datasets/nih-chest-xrays/sample/data)


In [None]:
# snake_case names for the metadata columns; header=0 makes these replace the file's own
# header row, and index_col=False keeps pandas from using the first column as the index.
column_names = ["image_index", "finding_labels", "follow_up_number", "patient_id", "patient_age", "patient_gender", "view_position", "original_image_width", "original_image_height", "original_image_pixel_spacing_x", "original_image_pixel_spacing_y"]  # fmt: skip # nopep8

data = pd.read_csv(
    "/data/irf/ai/tmp_MDR/github/MDR_CheXagent/data/NIH_Chest_X-ray_Dataset/Data_Entry_2017.csv",
    index_col=False,
    header=0,
    names=column_names,
)

# Quick sanity check of the table size, then a peek at the first rows.
print(f"Rows: {len(data):,}\tColumns: {len(data.columns)}")
display(data.head())

In [None]:
# Bar chart: number of rows for each patient_gender value.
ax = sns.countplot(x=data["patient_gender"])
ax.set(title="Distribution of Patient Gender", xlabel=None, ylabel="Counts")
plt.tight_layout()
plt.show()

In [None]:
# Age histogram (20 bins, with KDE overlay). Ages >= 130 are excluded, presumably
# because they are data-entry errors.
ax = sns.histplot(
    data=data[data["patient_age"] < 130],
    x="patient_age",
    bins=20,
    kde=True,
)
ax.set(title="Distribution of Patient Age", xlabel="Age", ylabel="Counts")
plt.tight_layout()
plt.show()

In [None]:
# Create 0/1 indicator columns: one per pathology (in list order), followed by No_Findings.
pathology_list = ["Atelectasis", "Cardiomegaly", "Consolidation", "Edema", "Effusion", "Emphysema", "Fibrosis", "Hernia", "Infiltration", "Mass", "Nodule", "Pleural_Thickening", "Pneumonia", "Pneumothorax"]  # fmt: skip # nopep8

# NOTE: finding_labels is assumed to be "|"-separated (e.g. "Cardiomegaly|Effusion"), as in the
# NIH Data_Entry_2017.csv. Matching whole labels rather than substrings means one label can never
# accidentally match inside another. The result is computed in one vectorised pass.
# reindex guarantees every expected column exists (all zeros if a label never occurs).
label_indicators = (
    data["finding_labels"]
    .str.get_dummies(sep="|")
    .reindex(columns=[*pathology_list, "No Finding"], fill_value=0)
    .astype("int64")
)

for pathology in pathology_list:
    data[pathology] = label_indicators[pathology]

data["No_Findings"] = label_indicators["No Finding"]

In [None]:
# Total number of images flagged for each label column. Columns are selected by name
# rather than by position, so extra columns added to `data` cannot leak into the plot.
# reset_index turns the Series into the long (Feature, Total) table seaborn expects.
label_columns = [*pathology_list, "No_Findings"]
sum_data = data[label_columns].sum().reset_index()
sum_data.columns = ["Feature", "Total"]

sns.barplot(x="Total", y="Feature", data=sum_data)
plt.gca().set(title="Disease Classes", xlabel="Counts", ylabel=None)
plt.tight_layout()
plt.show()

In [None]:
# Same totals as the previous chart, but without the No_Findings bar, so the
# disease classes can be compared on their own scale.
sns.barplot(x="Total", y="Feature", data=sum_data.loc[sum_data["Feature"] != "No_Findings", :])
plt.gca().set(title="Disease Classes (excluding no disease class)", xlabel="Counts", ylabel=None)
plt.tight_layout()
plt.show()

In [None]:
# subset: cases with exactly one finding label (counting "No Finding"), first case of each label.
# Label columns are selected by name and summed row-wise in one vectorised call.
label_columns = [*pathology_list, "No_Findings"]
subset = data.loc[data[label_columns].sum(axis=1) == 1, :].groupby("finding_labels").head(1).copy()

image_dir = "/data/irf/ai/tmp_MDR/github/MDR_CheXagent/data/NIH_Chest_X-ray_Dataset"


def _locate_image(file_name):
    """Return the POSIX path of the first file named ``file_name`` found anywhere under ``image_dir``.

    Raises:
        FileNotFoundError: no such file exists under ``image_dir``. Without this
            check, ``next()`` inside a lambda would surface only as a bare StopIteration.
    """
    match = next(Path(image_dir).rglob(file_name), None)
    if match is None:
        raise FileNotFoundError(f"{file_name!r} not found under {image_dir}")
    return match.as_posix()


subset["image_path"] = subset["image_index"].map(_locate_image)
display(subset)

### `main()` code sections


###### step 1: Set up constants


In [None]:
# Run on the GPU in half precision; both values are reused by the loading and generation cells.
device, dtype = "cuda", torch.float16

###### step 2: Load Processor and Model


In [None]:
# Load the processor, the generation defaults and the model weights, all from the same
# Hub checkpoint. The model is loaded in `dtype` and then moved to `device`.
# trust_remote_code allows the checkpoint's own processor/model classes to be used.
processor = AutoProcessor.from_pretrained("StanfordAIMI/CheXagent-8b", trust_remote_code=True)
generation_config = GenerationConfig.from_pretrained("StanfordAIMI/CheXagent-8b")
model = AutoModelForCausalLM.from_pretrained(
    "StanfordAIMI/CheXagent-8b",
    trust_remote_code=True,
    torch_dtype=dtype,
).to(device)

###### step 3: Fetch the images


In [None]:
# Example radiograph hosted on Wikimedia Commons. It is wrapped in a one-element list
# because generate() takes a sequence of images.
image_path = "https://upload.wikimedia.org/wikipedia/commons/3/3b/Pleural_effusion-Metastatic_breast_carcinoma_Case_166_%285477628658%29.jpg"
images = list(map(download_image, [image_path]))

###### step 4: Generate the Findings section


In [None]:
# Ask the model to describe each anatomical region of the example image, in
# A-B-C-D-E order, printing a header line followed by the answer.
anatomies = [
    "Airway", "Breathing", "Cardiac", "Diaphragm",
    "Everything else (e.g., mediastinal contours, bones, soft tissues, tubes, valves, and pacemakers)",
]

for region in anatomies:
    answer = generate(
        images,
        f'Describe "{region}"',
        processor,
        model,
        device,
        dtype,
        generation_config,
    )
    print(f"Generating the Findings for [{region}]:")
    print(answer)

In [None]:
# Questions asked for every image in `subset`; the commented-out entries are alternative
# prompts kept for experimentation.
prompts = [
    "Identify all abnormalities in the given chest X-ray.",
    "What abnormalities are notable for this patient?",
    "Write a structured Findings section for the given images as if you are a radiologist.",
    # "Assess the chest X-ray, identify key findings in the CXR and write a structured findings section.",
    "Are there any ground glass opacities?",
    # "Put a bounding box around regions showing pleural effusion?"
]

for _, row in subset.iterrows():
    # Convert to RGB (as download_image does for the web example); the context manager
    # closes the underlying file once the converted copy exists.
    with Image.open(row["image_path"]) as raw_image:
        image = raw_image.convert("RGB")
    img_name, label = row["image_index"], row["finding_labels"]

    # Show the radiograph with its ground-truth label above the model's answers.
    plt.imshow(image)
    plt.axis("off")
    plt.title(f"{img_name} Label: {label}")
    plt.show()

    for prompt in prompts:
        # Single-image call through the shared generate() helper (same template and decoding).
        response = generate([image], prompt, processor, model, device, dtype, generation_config)
        formatted_response = "\n".join(
            textwrap.wrap(response, width=100, break_long_words=False, replace_whitespace=False)
        )
        print(f"\n\nQ: {prompt}\n\nA: {formatted_response}")
    print("-" * 100, "\n")