###### Imports and definitions


In [1]:
import csv
import json
import warnings
from datetime import datetime
from pathlib import Path

import pandas as pd
import torch
from PIL import Image
from rich import print
from torch.utils.data import DataLoader, Dataset
from tqdm.notebook import tqdm
from transformers import AutoModelForCausalLM, AutoProcessor, GenerationConfig


class CheXpertDataset(Dataset):
    """Map-style dataset yielding processed (image, prompt) pairs for CheXagent.

    Item ``i`` combines ``images[i]``, ``finding_labels[i]`` and the prompt text
    ``prompts_dict[prompt_keys[i]]``; the image and prompt are run through the
    processor loaded from ``processor_pretrained_model``.

    Args:
        images: Sequence of image file paths.
        finding_labels: Sequence of ground-truth label strings, parallel to ``images``.
        prompt_keys: Sequence of keys into ``prompts_dict``, parallel to ``images``.
        prompts_dict: Mapping from prompt key to the prompt text inserted into the query.
        processor_pretrained_model: Model id/path passed to ``AutoProcessor.from_pretrained``.

    Raises:
        ValueError: If ``images``, ``finding_labels`` and ``prompt_keys`` differ in length.
    """

    def __init__(
        self, images, finding_labels, prompt_keys, prompts_dict, processor_pretrained_model="StanfordAIMI/CheXagent-8b"
    ):
        # Validate before loading the (slow) processor: __len__ is based on
        # finding_labels alone, so a shorter images/prompt_keys sequence would only
        # surface as an IndexError partway through a long inference run.
        if len({len(images), len(finding_labels), len(prompt_keys)}) != 1:
            raise ValueError(
                "images, finding_labels and prompt_keys must have the same length, got "
                f"{len(images)}, {len(finding_labels)} and {len(prompt_keys)}"
            )
        self.images = images
        self.finding_labels = finding_labels
        self.prompt_keys = prompt_keys
        self.prompts_dict = prompts_dict
        self.processor = AutoProcessor.from_pretrained(processor_pretrained_model, trust_remote_code=True)

    def __len__(self):
        """Return the number of image/prompt items."""
        return len(self.finding_labels)

    def __getitem__(self, index):
        """Load and preprocess item ``index``.

        Returns:
            A tuple ``(inputs, image_path, prompt_key, finding_label)`` where ``inputs``
            is the processor output as a dict with dimension 0 squeezed from each tensor.
        """
        image_path = self.images[index]
        # The context manager closes the file deterministically, including when
        # decoding raises part-way through.
        with Image.open(image_path) as img:
            image = img.convert("RGB")
        prompt_key = self.prompt_keys[index]
        prompt = self.prompts_dict[prompt_key]
        finding_label = self.finding_labels[index]
        inputs = self.processor(images=[image], text=f" USER: <s>{prompt} ASSISTANT: <s>", return_tensors="pt")
        # One image/prompt presumably yields a leading batch dimension of size 1;
        # drop it so the DataLoader's collation can add the real batch dimension.
        inputs = {k: v.squeeze(0) for k, v in inputs.items()}

        return inputs, image_path, prompt_key, finding_label


def disp_df(dataframe):
    """Render the first rows of ``dataframe`` and print its dimensions.

    Args:
        dataframe: The pandas DataFrame to preview.
    """
    n_rows, n_cols = dataframe.shape
    display(dataframe.head())
    print(f"nRows: {n_rows:,}\tnColumns: {n_cols}")

### Prepare input data for analysis


###### NIH Chest X-ray dataset


In [2]:
# snake_case names for the columns of Data_Entry_2017.csv, in file order.
# With header=0 the file's own header row is skipped and replaced by these names.
column_names = [
    "image_index",
    "finding_labels",
    "follow_up_number",
    "patient_id",
    "patient_age",
    "patient_gender",
    "view_position",
    "original_image_width",
    "original_image_height",
    "original_image_pixel_spacing_x",
    "original_image_pixel_spacing_y",
]

data = pd.read_csv(
    "./data/NIH_Chest_X-ray_Dataset/Data_Entry_2017.csv",
    header=0,
    names=column_names,
    index_col=False,
)
disp_df(data)

Unnamed: 0,image_index,finding_labels,follow_up_number,patient_id,patient_age,patient_gender,view_position,original_image_width,original_image_height,original_image_pixel_spacing_x,original_image_pixel_spacing_y
0,00000001_000.png,Cardiomegaly,0,1,58,M,PA,2682,2749,0.143,0.143
1,00000001_001.png,Cardiomegaly|Emphysema,1,1,58,M,PA,2894,2729,0.143,0.143
2,00000001_002.png,Cardiomegaly|Effusion,2,1,58,M,PA,2500,2048,0.168,0.168
3,00000002_000.png,No Finding,0,2,81,M,PA,2500,2048,0.171,0.171
4,00000003_000.png,Hernia,0,3,81,F,PA,2582,2991,0.143,0.143


#### Single finding cases


###### Prompt dictionary


In [3]:
# # import prompts dictionary
# with open("data/prompts.json", "r") as json_file:
#     prompts = json.load(json_file)

###### NIH Chest X-ray dataset: single finding cases with prompt_key column


In [4]:
# data_single_finding = data.loc[
#     (data["finding_labels"].str.count("\\|") + 1) == 1, ["image_index", "finding_labels"]
# ].copy()
# data_single_finding_prompts = pd.DataFrame()
# for prompt_key in prompts.keys():
#     df_i = data_single_finding.copy()
#     df_i["prompt_key"] = prompt_key
#     data_single_finding_prompts = pd.concat([data_single_finding_prompts, df_i])

# disp_df(data_single_finding_prompts)

###### Load analyzed image/prompt combinations


In [5]:
# df_results_prev = pd.read_csv("output/NIH_Chest_Xray_n1_findings/disease_classification_QA.csv.gz", dtype=str)

# disp_df(df_results_prev)

###### Input image/prompt combinations that have not been analyzed


In [6]:
# df_not_analyzed = pd.merge(left=data_single_finding_prompts, right=df_results_prev, how="left").pipe(
#     lambda x: x.loc[x["response"].isna(), ["image_index", "finding_labels", "prompt_key"]]
# )

# disp_df(df_not_analyzed)

###### sort the input dataframe


In [7]:
# # show least analyzed finding_label first
# count_map = df_results_prev["finding_labels"].value_counts().to_dict()
# df_not_analyzed["label_count"] = df_not_analyzed["finding_labels"].map(count_map)

# # create int dtype of prompt key column
# df_not_analyzed["prompt_key_int"] = df_not_analyzed["prompt_key"].astype(int)

# df_not_analyzed = df_not_analyzed.sort_values(["label_count", "image_index", "prompt_key_int"], ignore_index=True).drop(
#     columns=["label_count", "prompt_key_int"]
# )
# disp_df(df_not_analyzed)

###### add image filepaths


In [8]:
# # import filepaths
# df_full_paths = (
#     pd.read_csv("./output/single_finding_datalist.csv", usecols=["image_index", "image_path"])
#     .set_index("image_index")["image_path"]
#     .to_dict()
# )

# df_not_analyzed["image_path"] = df_not_analyzed["image_index"].map(df_full_paths)

# disp_df(df_not_analyzed)

###### delete unused objects


In [9]:
# del (
#     column_names,
#     count_map,
#     data,
#     data_single_finding,
#     data_single_finding_prompts,
#     df_full_paths,
#     df_i,
#     df_results_prev,
#     filelist,
#     json_file,
#     prompt_key,
# )

###### define filepath for results


In [10]:
# csv_file_parent = Path("output/NIH_Chest_Xray_n1_findings/")

#### Two findings cases


###### Prompt dictionary


In [11]:
# # import prompts dictionary
# with open("data/prompts_two_findings.json", "r") as json_file:
#     prompts = json.load(json_file)

###### NIH Chest X-ray dataset: two findings cases with prompt_key column


In [12]:
# data_two_finding = data.loc[
#     (data["finding_labels"].str.count("\\|") + 1) == 2, ["image_index", "finding_labels"]
# ].copy()
# data_two_finding_prompts = pd.DataFrame()
# for prompt_key in prompts.keys():
#     df_i = data_two_finding.copy()
#     df_i["prompt_key"] = prompt_key
#     data_two_finding_prompts = pd.concat([data_two_finding_prompts, df_i])

# disp_df(data_two_finding_prompts)

###### Load analyzed image/prompt combinations


In [13]:
# filelist = [x.as_posix() for x in Path("output/NIH_Chest_Xray_n2_findings").glob("disease_classification_QA*.csv")]

# if filelist == []:
#     df_results_prev = pd.DataFrame(columns=["image_index", "finding_labels", "prompt_key", "response"])
# else:
#     df_results_prev = pd.DataFrame()
#     for f in filelist:
#         df_i = pd.read_csv(f, usecols=["image_index", "finding_labels", "prompt_key", "response"], dtype=str)
#         df_results_prev = pd.concat([df_results_prev, df_i])

#     df_results_prev = df_results_prev.drop_duplicates(ignore_index=True)

# disp_df(df_results_prev)

###### Input image/prompt combinations that have not been analyzed


In [14]:
# df_not_analyzed = pd.merge(left=data_two_finding_prompts, right=df_results_prev, how="left").pipe(
#     lambda x: x.loc[x["response"].isna(), ["image_index", "finding_labels", "prompt_key"]]
# )

# disp_df(df_not_analyzed)

###### sort the input dataframe


In [15]:
# # show least analyzed finding_label first
# count_map = df_results_prev["finding_labels"].value_counts().to_dict()
# df_not_analyzed["label_count"] = df_not_analyzed["finding_labels"].map(count_map)

# # create int dtype of prompt key column
# df_not_analyzed["prompt_key_int"] = df_not_analyzed["prompt_key"].astype(int)

# df_not_analyzed = df_not_analyzed.sort_values(["label_count", "image_index", "prompt_key_int"], ignore_index=True).drop(
#     columns=["label_count", "prompt_key_int"]
# )
# disp_df(df_not_analyzed)

###### add image filepaths


In [16]:
# # import filepaths
# df_full_paths = pd.read_csv("data/NIH_Chest_X-ray_image_filepaths.csv").set_index("image_index")["image_path"].to_dict()

# df_not_analyzed["image_path"] = df_not_analyzed["image_index"].map(df_full_paths)

# disp_df(df_not_analyzed)

###### delete unused objects


In [17]:
# del (
#     column_names,
#     count_map,
#     data,
#     data_two_finding,
#     data_two_finding_prompts,
#     df_full_paths,
#     df_i,
#     df_results_prev,
#     filelist,
#     json_file,
#     prompt_key,
# )

###### define filepath for results


In [18]:
# csv_file_parent = Path("output/NIH_Chest_Xray_n2_findings/")

#### Three or more findings cases


###### Prompt dictionary


In [19]:
# Prompt dictionary for the three-or-more-findings analysis (prompt key -> prompt text).
# JSON text is UTF-8; pass the encoding explicitly so loading does not depend on the
# platform's locale default. `json_file` stays bound and is removed by the cleanup cell.
with open("data/prompts_three_or_more_findings.json", "r", encoding="utf-8") as json_file:
    prompts = json.load(json_file)

###### NIH Chest X-ray dataset: three-or-more findings cases with prompt_key column


In [20]:
# Keep images with three or more findings (labels are pipe-separated), then replicate
# every image once per prompt key so each row is one image/prompt pair to analyze.
data_many_finding = data.loc[
    (data["finding_labels"].str.count("\\|") + 1) > 2, ["image_index", "finding_labels"]
].copy()

if not prompts:
    raise ValueError("prompts dictionary is empty; there are no image/prompt pairs to analyze")

# Collect the per-prompt copies and concatenate once, instead of re-copying a growing
# frame on every iteration. The original row index is kept, as before.
prompt_frames = []
for prompt_key in prompts.keys():
    df_i = data_many_finding.copy()
    df_i["prompt_key"] = prompt_key
    prompt_frames.append(df_i)
data_many_finding_prompts = pd.concat(prompt_frames)
del prompt_frames

disp_df(data_many_finding_prompts)

Unnamed: 0,image_index,finding_labels,prompt_key
42,00000013_004.png,Effusion|Emphysema|Infiltration|Pneumothorax,21
43,00000013_005.png,Emphysema|Infiltration|Pleural_Thickening|Pneu...,21
48,00000013_010.png,Effusion|Pneumonia|Pneumothorax,21
56,00000013_018.png,Effusion|Infiltration|Pneumothorax,21
61,00000013_023.png,Infiltration|Mass|Pneumothorax,21


###### Load analyzed image/prompt combinations


In [21]:
# Gather responses from every previous run (one timestamped CSV per run) so that
# image/prompt pairs that were already analyzed can be skipped. Sorted for a
# deterministic read order.
filelist = sorted(
    x.as_posix() for x in Path("output/NIH_Chest_Xray_many_findings").glob("disease_classification_QA*.csv")
)

result_columns = ["image_index", "finding_labels", "prompt_key", "response"]
if not filelist:
    df_results_prev = pd.DataFrame(columns=result_columns)
else:
    prev_frames = []
    for f in filelist:
        # keep_default_na=False: a model answer such as "None", "NA" or "" must stay a
        # string. With pandas' default NA parsing it becomes NaN, and the pair is then
        # treated as never analyzed and re-queued on every run.
        df_i = pd.read_csv(f, usecols=result_columns, dtype=str, keep_default_na=False)
        prev_frames.append(df_i)
    df_results_prev = pd.concat(prev_frames).drop_duplicates(ignore_index=True)
    del prev_frames

disp_df(df_results_prev)

Unnamed: 0,image_index,finding_labels,prompt_key,response


###### Input image/prompt combinations that have not been analyzed


In [22]:
# Left-join every candidate image/prompt pair against previous results on the
# identifying columns; pairs present only on the left have not been analyzed yet.
# The merge indicator is used instead of testing `response` for NaN, so a pair whose
# stored response is empty/NA-like is not re-queued.
merge_keys = ["image_index", "finding_labels", "prompt_key"]
df_not_analyzed = pd.merge(
    left=data_many_finding_prompts,
    right=df_results_prev[merge_keys].drop_duplicates(),
    on=merge_keys,
    how="left",
    indicator=True,
).pipe(lambda x: x.loc[x["_merge"] == "left_only", merge_keys])

disp_df(df_not_analyzed)

Unnamed: 0,image_index,finding_labels,prompt_key
0,00000013_004.png,Effusion|Emphysema|Infiltration|Pneumothorax,21
1,00000013_005.png,Emphysema|Infiltration|Pleural_Thickening|Pneu...,21
2,00000013_010.png,Effusion|Pneumonia|Pneumothorax,21
3,00000013_018.png,Effusion|Infiltration|Pneumothorax,21
4,00000013_023.png,Infiltration|Mass|Pneumothorax,21


###### sort the input dataframe


In [23]:
# Order the work queue so finding-label combinations with the fewest stored results
# come first; ties are broken by image, then by numeric prompt key.
count_map = df_results_prev["finding_labels"].value_counts().to_dict()

df_not_analyzed = (
    df_not_analyzed.assign(
        # Labels missing from count_map have zero stored results. Fill with 0 so they
        # sort FIRST; left as NaN, sort_values would place them last.
        label_count=df_not_analyzed["finding_labels"].map(count_map).fillna(0),
        # Integer copy of the prompt key so that e.g. "10" sorts after "9".
        prompt_key_int=df_not_analyzed["prompt_key"].astype(int),
    )
    .sort_values(["label_count", "image_index", "prompt_key_int"], ignore_index=True)
    .drop(columns=["label_count", "prompt_key_int"])
)
disp_df(df_not_analyzed)

Unnamed: 0,image_index,finding_labels,prompt_key
0,00000013_004.png,Effusion|Emphysema|Infiltration|Pneumothorax,21
1,00000013_004.png,Effusion|Emphysema|Infiltration|Pneumothorax,22
2,00000013_004.png,Effusion|Emphysema|Infiltration|Pneumothorax,23
3,00000013_004.png,Effusion|Emphysema|Infiltration|Pneumothorax,25
4,00000013_004.png,Effusion|Emphysema|Infiltration|Pneumothorax,26


###### add image filepaths


In [24]:
# Map each image file name to its full path on disk.
df_full_paths = pd.read_csv("data/NIH_Chest_X-ray_image_filepaths.csv").set_index("image_index")["image_path"].to_dict()

df_not_analyzed["image_path"] = df_not_analyzed["image_index"].map(df_full_paths)

# Fail fast: an image with no known path would otherwise get NaN here and only fail
# when the dataset tries to open it, deep inside the long inference loop.
missing_paths = df_not_analyzed.loc[df_not_analyzed["image_path"].isna(), "image_index"].unique()
if len(missing_paths) > 0:
    raise ValueError(
        f"{len(missing_paths)} image(s) have no entry in the filepath list, e.g. {list(missing_paths[:5])}"
    )

disp_df(df_not_analyzed)

Unnamed: 0,image_index,finding_labels,prompt_key,image_path
0,00000013_004.png,Effusion|Emphysema|Infiltration|Pneumothorax,21,/data/irf/ai/rustadmd/CheXagent/data/NIH_Chest...
1,00000013_004.png,Effusion|Emphysema|Infiltration|Pneumothorax,22,/data/irf/ai/rustadmd/CheXagent/data/NIH_Chest...
2,00000013_004.png,Effusion|Emphysema|Infiltration|Pneumothorax,23,/data/irf/ai/rustadmd/CheXagent/data/NIH_Chest...
3,00000013_004.png,Effusion|Emphysema|Infiltration|Pneumothorax,25,/data/irf/ai/rustadmd/CheXagent/data/NIH_Chest...
4,00000013_004.png,Effusion|Emphysema|Infiltration|Pneumothorax,26,/data/irf/ai/rustadmd/CheXagent/data/NIH_Chest...


###### delete unused objects


In [25]:
# Free intermediates that are no longer needed. Popping from the namespace (instead of
# `del`) keeps the cell re-runnable: a name already removed by an earlier run, or never
# bound on this run, is skipped instead of raising NameError.
for _name in (
    "column_names",
    "count_map",
    "data",
    "data_many_finding",
    "data_many_finding_prompts",
    "df_full_paths",
    "df_i",
    "df_results_prev",
    "filelist",
    "json_file",
    "prompt_key",
):
    globals().pop(_name, None)
del _name

###### define filepath for results


In [26]:
# Directory that receives this run's timestamped results CSV.
csv_file_parent = Path("output") / "NIH_Chest_Xray_many_findings"

### Using CheXagent


###### initialize `CheXpertDataset`


In [27]:
# Inference settings: CheXagent weights are loaded in half precision on the GPU.
model_name = "StanfordAIMI/CheXagent-8b"
device = "cuda"
dtype = torch.float16

# One dataset item per image/prompt pair still to be analyzed; the processor is loaded
# from the same checkpoint id as the model.
dataset = CheXpertDataset(
    df_not_analyzed["image_path"].values,
    df_not_analyzed["finding_labels"].values,
    df_not_analyzed["prompt_key"].values,
    prompts,
    processor_pretrained_model=model_name,
)
# batch_size=1: the inference loop unpacks each batch by taking element [0].
data_loader = DataLoader(dataset, batch_size=1, shuffle=False)
processor = dataset.processor

model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=dtype, trust_remote_code=True).to(device)
model.eval()  # inference only

generation_config = GenerationConfig.from_pretrained(model_name)

Loading checkpoint shards:   0%|          | 0/7 [00:00<?, ?it/s]

###### run the analysis


In [28]:
# Write one timestamped CSV per run so earlier results are never overwritten.
datetime_str = datetime.now().strftime("%Y%m%d_%H%M%S")
csv_file_parent.mkdir(parents=True, exist_ok=True)
csv_file_path = csv_file_parent.joinpath(f"disease_classification_QA_{datetime_str}.csv").as_posix()

with open(csv_file_path, mode="w", newline="") as file, warnings.catch_warnings():
    # Silence UserWarnings only while the loop runs; catch_warnings restores the
    # previous filter state on exit, including on KeyboardInterrupt.
    warnings.simplefilter("ignore", category=UserWarning)

    writer = csv.writer(file)
    writer.writerow(["image_index", "finding_labels", "prompt_key", "response"])

    for batch in tqdm(data_loader, total=len(data_loader), desc="Processing images"):
        # Move inputs to the model's device and cast floating-point tensors (e.g. pixel
        # values) to the model's dtype, as the CheXagent usage example does; non-float
        # tensors are only moved.
        inputs = {
            k: v.to(device=device, dtype=dtype) if v.is_floating_point() else v.to(device)
            for k, v in batch[0].items()
        }

        # batch_size=1: unwrap the single element of each collated metadata field.
        image_path, prompt_key, finding_label = [x[0] for x in batch[1:]]
        image_index = Path(image_path).name

        output = model.generate(
            **inputs, generation_config=generation_config, pad_token_id=processor.tokenizer.eos_token_id
        )[0]
        generated_text = processor.tokenizer.decode(output, skip_special_tokens=True)

        writer.writerow([image_index, finding_label, prompt_key, generated_text])
        # Flush every row so a crash or kernel restart loses at most the current item;
        # the previous-results cell re-reads these files to resume.
        file.flush()

Processing images:   0%|          | 0/58410 [00:00<?, ?it/s]