In [1]:
from pathlib import Path
import requests
import json
from google.colab import drive

# Mount Google Drive
drive.mount('/content/drive')

# Change path as needed inside google drive
project_dir = Path("/content/drive/MyDrive/medgemma_playground")

# Create the project directory if missing
project_dir.mkdir(parents=True, exist_ok=True)

# Image and data dir
data_dir = project_dir / "data"
images_dir = data_dir / "images"

# Make sure folders exist
images_dir.mkdir(parents=True, exist_ok=True)

print("Project directory:", project_dir)
print("Data directory:", data_dir)
print("Images directory:", images_dir)

# Toy dataset metadata - list of dicts
SAMPLES = [
    {
        "filename": "cxr_normal_pa.png",
        "url": "https://upload.wikimedia.org/wikipedia/commons/c/c8/Chest_Xray_PA_3-8-2010.png",
        "view": "PA",
        "report": (
            "Normal posteroanterior chest radiograph with clear lungs and a "
            "normal cardiomediastinal silhouette."
        ),
    },
    {
        "filename": "cxr_pneumonia_pa.jpg",
        "url": "https://upload.wikimedia.org/wikipedia/commons/5/53/X-ray_lung_consolidation.jpg",
        "view": "PA",
        "report": (
            "Chest X-ray showing consolidation in the right mid and both lower "
            "lobes, consistent with pneumonia."
        ),
    },
    {
        "filename": "cxr_pneumonia_variant_pa.jpg",
        "url": "https://upload.wikimedia.org/wikipedia/commons/5/53/X-ray_lung_consolidation.jpg",
        "view": "PA",
        "report": (
            "Posteroanterior chest radiograph demonstrating bilateral lower "
            "lobe consolidation with airspace opacities."
        ),
    },
]

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).
Project directory: /content/drive/MyDrive/medgemma_playground
Data directory: /content/drive/MyDrive/medgemma_playground/data
Images directory: /content/drive/MyDrive/medgemma_playground/data/images


In [2]:
def download_image(sample):
    """
    Download one image from the web into images_dir if it isn't already there.
    We add a browser-like User-Agent header so the server doesn't block us (403).
    """

    # Build output path based on images_dir from previous cell
    out_path = images_dir / sample["filename"]

    # If we've already downloaded it in a previous run, just reuse it
    if out_path.exists():
        print(f"Already exists: {out_path}")
        return out_path

    print(f"Downloading {sample['url']} -> {out_path}")

    # Some sites (like Wikimedia) block default Python clients.
    # Pretend to be a normal Chrome browser via the User-Agent header.
    headers = {
        "User-Agent": (
            "Mozilla/5.0 (X11; Linux x86_64) "
            "AppleWebKit/537.36 (KHTML, like Gecko) "
            "Chrome/120.0 Safari/537.36"
        )
    }

    # Stream the response so we don't load the whole image into memory at once
    resp = requests.get(sample["url"], headers=headers, stream=True)
    # If the server returns 4xx/5xx, this will raise an HTTPError
    resp.raise_for_status()

    # Save the bytes we get from the network into a file on disk-writing image to disk
    with open(out_path, "wb") as f:
        for chunk in resp.iter_content(chunk_size=8192):
            if chunk:  # filter out "keep-alive" chunks
                f.write(chunk)

    # return path to downloaded image
    return out_path


# download each image in SAMPLES
local_samples = []
for sample in SAMPLES:
    # Download the image (or reuse if already exists) and get its local path
    local_path = download_image(sample)

    # Make a copy of the sample dict so we can add the local path to dict
    sample_with_path = dict(sample)
    sample_with_path["local_image_path"] = str(local_path)

    # Store this enriched sample in a new list
    local_samples.append(sample_with_path)

print("\nDownloaded image files:")
for s in local_samples:
    print("-", s["local_image_path"])

Already exists: /content/drive/MyDrive/medgemma_playground/data/images/cxr_normal_pa.png
Already exists: /content/drive/MyDrive/medgemma_playground/data/images/cxr_pneumonia_pa.jpg
Already exists: /content/drive/MyDrive/medgemma_playground/data/images/cxr_pneumonia_variant_pa.jpg

Downloaded image files:
- /content/drive/MyDrive/medgemma_playground/data/images/cxr_normal_pa.png
- /content/drive/MyDrive/medgemma_playground/data/images/cxr_pneumonia_pa.jpg
- /content/drive/MyDrive/medgemma_playground/data/images/cxr_pneumonia_variant_pa.jpg


In [3]:
# sets output path of JSONL
jsonl_path = data_dir / "cxr_generation_dataset.jsonl"
print("JSONL will be saved to:", jsonl_path)


def build_prompt(view: str, report_text: str) -> str:
    """
    Convert a simple radiology-style report into a text prompt for generation.

    This function is deliberately heuristic:
    it looks for certain keywords in the report and constructs a
    'generate a <view> chest x-ray with ...' style prompt.
    """
    # Start with the projection/view
    base = f"generate a {view.lower()} chest x-ray"
    lower = report_text.lower()
    findings = []

    # Naive keyword-based parsing of the report text
    if "cardiomegaly" in lower or "enlarged heart" in lower:
        findings.append("with cardiomegaly")
    if "consolidation" in lower or "pneumonia" in lower or "airspace" in lower:
        findings.append("with parenchymal consolidation")
    if "effusion" in lower:
        findings.append("with pleural effusion")
    if "atelectasis" in lower:
        findings.append("with basal atelectasis")

    # If nothing matched, assume essentially normal
    if not findings:
        findings.append("with no acute cardiopulmonary abnormality")

    # Join all parts into a single prompt string
    prompt = base + " " + " and ".join(findings)

    return prompt


num_written = 0 # counts number of samples

# Open the JSONL file for writing (this will overwrite if it already exists)
with open(jsonl_path, "w") as f_out:
    for sample in local_samples:
        # This path was added in Cell 2
        image_path = sample["local_image_path"]

        # View and report come from our toy metadata
        view = sample["view"]
        report_text = sample["report"].strip()

        # Build a generation-style prompt from the view + report
        prompt = build_prompt(view, report_text)

        # This is the final training record we'll store on disk
        example = {
            "image": image_path,       # local path to the CXR file
            "prompt": prompt,          # text input we feed into MedGemma
            "cxr_report": report_text  # original report-like description
        }

        # Convert the dict to a JSON string and write it as one line
        f_out.write(json.dumps(example) + "\n")
        num_written += 1

print(f"Wrote {num_written} examples to {jsonl_path}")

JSONL will be saved to: /content/drive/MyDrive/medgemma_playground/data/cxr_generation_dataset.jsonl
Wrote 3 examples to /content/drive/MyDrive/medgemma_playground/data/cxr_generation_dataset.jsonl


In [4]:
from huggingface_hub import login

login()

VBox(children=(HTML(value='<center> <img\nsrc=https://huggingface.co/front/assets/huggingface_logo-noborder.sv…

In [5]:
# making PyTorch Dataset
from pathlib import Path
import json
from PIL import Image
from torch.utils.data import Dataset

class CXRGenerationDataset(Dataset):
  def __init__(self, jsonl_path: Path, transform = None):
    self.jsonl_path = Path(jsonl_path)
    self.transform = transform # if we want to add torchvision transforms later
    self.samples = []

    with open(self.jsonl_path, "r") as f:
      for line in f:
        line = line.strip()
        if not line:
          continue
        record = json.loads(line)
        self.samples.append(record)

  def __len__(self):
    return len(self.samples)

  def __getitem__(self, idx: int):
    sample = self.samples[idx]

    # get fields
    image_path = sample["image"]
    prompt = sample["prompt"]
    report = sample["cxr_report"]

    # convert to RBG as model's processor expects 3 channels
    image = Image.open(image_path).convert("RGB")

    # apply transform if user provided one
    if self.transform is not None:
      image = self.transform(image)

    # return final dict
    return {
        "image": image,
        "prompt": prompt,
        "report": report
    }


In [8]:
from transformers import AutoProcessor
from torch.utils.data import DataLoader

# instantiate processor
processor = AutoProcessor.from_pretrained("google/medgemma-4b-it")

def collate_fn(batch):
    """
    Turn a list of dataset items into a single batch for MedGemma.

    We explicitly process images and text separately to avoid
    the '<image> token' requirement in Gemma3Processor.__call__.
    """

    # split fields from each sample, make a list for each field
    images  = [item["image"]  for item in batch]
    prompts = [item["prompt"] for item in batch]
    reports = [item["report"] for item in batch]

    print("len(batch):", len(batch), "len(images):", len(images), "len(prompts):", len(prompts))

    # convert image to pixel tensors (B, C, H, W)
    image_inputs = processor.image_processor(
        images=images,
        return_tensors="pt",
    )

    # convert text to token ids and attention masks
    text_inputs = processor.tokenizer(
        text=prompts,
        return_tensors="pt",
        padding=True,
        truncation=True,
    )

    # merge the dicts
    encoded = {
        "pixel_values": image_inputs["pixel_values"],
        "input_ids": text_inputs["input_ids"],
        "attention_mask": text_inputs["attention_mask"],
        "prompt": prompts,
        "report": reports,
    }

    # debug prints
    print("input_ids shape:", encoded["input_ids"].shape)
    print("pixel_values shape:", encoded["pixel_values"].shape)

    return encoded

# instantiate dataset from JSONL we made earlier
dataset = CXRGenerationDataset(jsonl_path)

# wrap dataset in a DataLoader for batching and shuffling
loader = DataLoader(
    dataset,
    batch_size=2,
    shuffle=True,
    collate_fn=collate_fn,
)


In [9]:
batch = next(iter(loader))

Asking to truncate to max_length but no maximum length is provided and the model has no predefined maximum length. Default to no truncation.


len(batch): 2 len(images): 2 len(prompts): 2
input_ids shape: torch.Size([2, 14])
pixel_values shape: torch.Size([2, 3, 896, 896])
