In [63]:
import json
from pathlib import Path

import pandas as pd
import torch
from PIL import Image
from torch.utils.data import DataLoader, Dataset
from transformers import AutoModelForCausalLM, AutoProcessor, GenerationConfig


class CheXpertDataset(Dataset):
    """Map-style dataset that pairs chest X-ray images with text prompts.

    Each item is one (image, prompt) pair run through the Hugging Face
    processor so that it can be passed straight to ``model.generate``.

    Args:
        images: Sequence of image file paths, one per item.
        finding_labels: Sequence of ground-truth labels, one per item. Its
            length defines the dataset length.
        prompt_keys: Sequence of keys into ``prompts_dict``, one per item.
        prompts_dict: Mapping from prompt key to prompt text.
        processor_pretrained_model: Name or path handed to
            ``AutoProcessor.from_pretrained``. Note that the processor is
            loaded with ``trust_remote_code=True``.
    """

    def __init__(self, images, finding_labels, prompt_keys, prompts_dict,
                 processor_pretrained_model="StanfordAIMI/CheXagent-8b"):
        self.images = images
        self.finding_labels = finding_labels
        self.prompt_keys = prompt_keys
        self.prompts_dict = prompts_dict
        self.processor = AutoProcessor.from_pretrained(processor_pretrained_model, trust_remote_code=True)

    @staticmethod
    def _item_at(values, position):
        """Return the element at integer ``position`` of ``values``.

        pandas objects are indexed by *label* with ``[]``, which breaks (or
        silently returns the wrong row) when the index is not 0..n-1, so use
        ``.iloc`` whenever it is available and plain ``[]`` otherwise.
        """
        positional = getattr(values, "iloc", None)
        return values[position] if positional is None else positional[position]

    def __len__(self):
        """Return the number of (image, prompt) items."""
        return len(self.finding_labels)

    def __getitem__(self, index):
        """Load and preprocess the item at integer position ``index``.

        Returns:
            Tuple ``(inputs, image_path, prompt_key, finding_label)`` where
            ``inputs`` is a dict of processor outputs with the leading
            dimension of size 1 squeezed away.
        """
        image_path = self._item_at(self.images, index)
        prompt_key = self._item_at(self.prompt_keys, index)
        finding_label = self._item_at(self.finding_labels, index)
        prompt = self.prompts_dict[prompt_key]

        # ``convert`` returns a new in-memory image, so the file handle can be
        # released as soon as the conversion is done instead of leaking it.
        with Image.open(image_path) as image_file:
            image = image_file.convert("RGB")

        inputs = self.processor(images=[image], text=f" USER: <s>{prompt} ASSISTANT: <s>", return_tensors="pt")
        # Drop the leading dimension added by the processor; the DataLoader's
        # collate step adds the batch dimension back.
        inputs = {k: v.squeeze(0) for k, v in inputs.items()}

        return inputs, image_path, prompt_key, finding_label

In [64]:
column_names = ["image_index", "finding_labels", "follow_up_number", "patient_id", "patient_age", "patient_gender", "view_position", "original_image_width", "original_image_height", "original_image_pixel_spacing_x", "original_image_pixel_spacing_y"]  # fmt: skip # nopep8

N_SAMPLES = 5  # number of images to draw for this run
SAMPLE_SEED = 42  # fixed so that Restart & Run All selects the same images

# Metadata table; the file's own header row is replaced by ``column_names``.
data = pd.read_csv(
    "./data/NIH_Chest_X-ray_Dataset/Data_Entry_2017.csv",
    names=column_names,
    header=0,
    index_col=False,
)

# import prompts dictionary
with open("output/prompts.json", "r") as json_file:
    prompts = json.load(json_file)

# Responses collected by earlier runs (all columns read as strings).
results_prev = pd.read_csv(
    "output/disease_classification_QA.csv",
    usecols=["image_index", "finding_labels", "prompt_key", "response"],
    dtype=str,
)

# images with missing prompt cases: the number of distinct prompt keys
# recorded for the image differs from the number of prompts
prompts_per_image = results_prev.groupby("image_index", sort=False)["prompt_key"].nunique()
images_incomplete = prompts_per_image[prompts_per_image != len(prompts)].index.to_list()

# images that have been analyzed with every prompt
images_analyzed = set(results_prev["image_index"]) - set(images_incomplete)

# images to input into CheXagent: not fully analyzed yet, exactly one finding
# label (no "|" separator) and that label is not "No Finding"
images_not_analyzed = set(data["image_index"]) - images_analyzed
candidates = data[
    data["image_index"].isin(images_not_analyzed)
    & data["finding_labels"].str.count(r"\|").eq(0)
    & (data["finding_labels"] != "No Finding")
]
# Cap the sample size so the cell still runs when few candidates remain.
subset = candidates.sample(n=min(N_SAMPLES, len(candidates)), random_state=SAMPLE_SEED).copy()

# Locate each image file somewhere below the dataset root.
image_root = Path("data/NIH_Chest_X-ray_Dataset")
image_paths = []
for image_name in subset["image_index"]:
    match = next(image_root.rglob(image_name), None)
    if match is None:
        # Fail with a readable message instead of a bare StopIteration.
        raise FileNotFoundError(f"{image_name} not found under {image_root.as_posix()}")
    image_paths.append(match.as_posix())
subset["image_path"] = image_paths

display(subset.head())
print(f"nRows: {subset.shape[0]:,}\tnColumns: {subset.shape[1]}")

Unnamed: 0,image_index,finding_labels,follow_up_number,patient_id,patient_age,patient_gender,view_position,original_image_width,original_image_height,original_image_pixel_spacing_x,original_image_pixel_spacing_y,image_path
9559,00002484_000.png,Atelectasis,0,2484,45,F,AP,2500,2048,0.171,0.171,data/NIH_Chest_X-ray_Dataset/images_002/images...
76049,00018657_010.png,Atelectasis,10,18657,76,F,AP,3056,2544,0.139,0.139,data/NIH_Chest_X-ray_Dataset/images_009/images...
41354,00010698_012.png,Fibrosis,12,10698,74,F,AP,2500,2048,0.168,0.168,data/NIH_Chest_X-ray_Dataset/images_005/images...
74207,00018237_007.png,Consolidation,7,18237,34,F,AP,3056,2544,0.139,0.139,data/NIH_Chest_X-ray_Dataset/images_008/images...
22121,00005855_000.png,Pneumothorax,0,5855,20,M,PA,2048,2500,0.171,0.171,data/NIH_Chest_X-ray_Dataset/images_003/images...


nRows: 5	nColumns: 12


In [65]:
df_img = subset[["image_index", "finding_labels", "image_path"]].copy()

# One copy of the image table per prompt key, stacked prompt by prompt.
# ``assign`` returns a new frame, so ``df_img`` itself is left untouched, and
# the frames are concatenated once rather than growing a frame inside a loop.
frames_per_prompt = [df_img.assign(prompt_key=key) for key in prompts]
if frames_per_prompt:
    df_img_prompts = pd.concat(frames_per_prompt, ignore_index=True)
else:
    # No prompts: keep the expected columns so later attribute access works.
    df_img_prompts = pd.DataFrame(columns=[*df_img.columns, "prompt_key"])
display(df_img_prompts.head())

Unnamed: 0,image_index,finding_labels,image_path,prompt_key
0,00002484_000.png,Atelectasis,data/NIH_Chest_X-ray_Dataset/images_002/images...,1
1,00018657_010.png,Atelectasis,data/NIH_Chest_X-ray_Dataset/images_009/images...,1
2,00010698_012.png,Fibrosis,data/NIH_Chest_X-ray_Dataset/images_005/images...,1
3,00018237_007.png,Consolidation,data/NIH_Chest_X-ray_Dataset/images_008/images...,1
4,00005855_000.png,Pneumothorax,data/NIH_Chest_X-ray_Dataset/images_003/images...,1


In [70]:
# One dataset item per (image, prompt) row of ``df_img_prompts``.
dataset = CheXpertDataset(df_img_prompts.image_path.values, df_img_prompts.finding_labels,
                          df_img_prompts.prompt_key, prompts)
data_loader = DataLoader(dataset, batch_size=1, shuffle=False)  # Batch size set to 1 for simplicity
processor = dataset.processor

# Load the model and set it to evaluation mode
device = "cuda"
dtype = torch.float16

model_name = "StanfordAIMI/CheXagent-8b"
model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=dtype, trust_remote_code=True).to(device)
model.eval()

# Load generation config if needed
generation_config = GenerationConfig.from_pretrained(model_name)

# Perform text generation
for batch in data_loader:
    inputs, image_path, prompt_key, finding_label = batch
    # Move every tensor to the model's device; floating-point tensors are also
    # cast to the model dtype, while integer tensors keep their dtype.
    inputs = {
        k: v.to(device=device, dtype=dtype) if v.is_floating_point() else v.to(device)
        for k, v in inputs.items()
    }

    # Generate text; ``[0]`` selects the single sequence of this batch of one
    output = model.generate(**inputs, generation_config=generation_config,
                            pad_token_id=processor.tokenizer.eos_token_id)[0]

    # Decode and print the generated text
    generated_text = processor.tokenizer.decode(output, skip_special_tokens=True)
    print(f"{image_path[0]=}, {prompt_key[0]=}, {finding_label[0]=}")
    print(generated_text)

  [torch.tensor(pixel_values) for pixel_values in encoding_image_processor["pixel_values"]]


image_path[0]='data/NIH_Chest_X-ray_Dataset/images_002/images/00002484_000.png', prompt_key[0]='1', finding_label[0]='Atelectasis'
No Finding


  [torch.tensor(pixel_values) for pixel_values in encoding_image_processor["pixel_values"]]
