In [2]:
import pandas as pd
import os
import glob
import pydicom
import numpy as np
import cv2
import torch
import matplotlib.pyplot as plt
from PIL import Image

import torch
import torch.nn as nn
from torchvision.models import densenet121
from torchvision import transforms
import zipfile

In [3]:
TRAIN_MEAN = 0.5007
TRAIN_STD  = 0.2508

In [6]:
val_transform = transforms.Compose([
    transforms.Resize((224,224)),
    transforms.ToTensor(),
    transforms.Normalize([TRAIN_MEAN], [TRAIN_STD]),
])

In [4]:
zip_path   = "/Users/Kyra_1/Desktop/test_data.zip"
extract_to = "/Users/Kyra_1/Desktop/test_data"

# only extract if the dir doesn't already exist
if not os.path.isdir(extract_to):
    with zipfile.ZipFile(zip_path, 'r') as z:
        z.extractall(extract_to)
    print(f"Unzipped into {extract_to}")
else:
    print(f"{extract_to} already exists, skipping unzip")


/Users/Kyra_1/Desktop/test_data already exists, skipping unzip


In [8]:
# ─── 1) Pick your device ─────────────────────────────────────────────────────
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# ─── 2) Load the state dict ─────────────────────────────────────────────────
state_dict = torch.load("reproduceable_densenet.pt", map_location="cpu")

# ─── 3) Instantiate DenseNet121 (no pretrained weights) ────────────────────
model = densenet121(pretrained=False)

# ─── 4) Patch the first conv to accept 1‐channel input ───────────────────────
old_conv = model.features.conv0
new_conv = nn.Conv2d(
    in_channels=1,
    out_channels=old_conv.out_channels,
    kernel_size=old_conv.kernel_size,
    stride=old_conv.stride,
    padding=old_conv.padding,
    bias=(old_conv.bias is not None)
)
with torch.no_grad():
    new_conv.weight[:] = old_conv.weight.mean(dim=1, keepdim=True)
    if old_conv.bias is not None:
        new_conv.bias[:] = old_conv.bias
model.features.conv0 = new_conv

# ─── 5) Rebuild the classifier to match your num_classes ────────────────────
num_classes       = state_dict["classifier.weight"].shape[0]
in_feats          = model.classifier.in_features
model.classifier  = nn.Linear(in_feats, num_classes)

# ─── 6) Load weights & move to device ───────────────────────────────────────
model.load_state_dict(state_dict)
model = model.to(device).eval()

# ─── 7) Load & filter your CSV ───────────────────────────────────────────────
df = pd.read_csv("best_model_pred_final.csv")
# build your unique_labels list (must match your classifier’s ordering)
all_labels    = df["true_label"]
unique_labels = sorted(all_labels.unique().tolist())

# keep only correct predictions, and exclude certain classes
df = df[df.true_label == df.predicted]
exclude = {"ASIAN", "HISPANIC/LATINO"}
df = df[~df.true_label.isin(exclude)].reset_index(drop=True)

# ─── 8) Set up Grad-CAM hooks ────────────────────────────────────────────────
activations = {}
gradients   = {}

def forward_hook(module, inp, out):
    activations["feat"] = out

def backward_hook(module, grad_in, grad_out):
    gradients["grad"] = grad_out[0]

target_layer = model.features.norm5
fh = target_layer.register_forward_hook(forward_hook)
bh = target_layer.register_backward_hook(backward_hook)

# ─── 9) Prepare input/output paths ─────────────────────────────────────────
test_root  = "/Users/Kyra_1/Desktop/test_data"
output_dir = "/Users/Kyra_1/Desktop/local_ADS_data/gradcam_results"
os.makedirs(output_dir, exist_ok=True)

# ─── 10) Loop & generate Grad-CAM, skipping existing outputs ────────────────
for _, row in df.iterrows():
#for idx, row in df.iloc[203:].iterrows():
    fname = os.path.basename(row.dicom_path)
    matches = glob.glob(os.path.join(test_root, "**", fname), recursive=True)
    if not matches:
        print(f"⚠️  File not found: {fname}")
        continue
    dicom_path = matches[0]
    base, _   = os.path.splitext(fname)

    # If all three outputs already exist, skip:
    out_cam     = os.path.join(output_dir, f"{base}_cam.pt")
    out_overlay = os.path.join(output_dir, f"{base}_overlay.pt")
    out_png     = os.path.join(output_dir, f"{base}.png")
    if all(os.path.exists(p) for p in (out_cam, out_overlay, out_png)):
        print(f"🔹 Skipping {base}: already done")
        continue

    # ---- load & normalize DICOM ----
    ds  = pydicom.dcmread(dicom_path, force=True)
    arr = ds.pixel_array.astype("float32")
    arr_min   = arr.min()
    arr_range = np.ptp(arr)        # instead of arr.ptp()

    arr = (arr - arr_min) / (arr_range + 1e-6)  
    img = Image.fromarray((arr * 255).astype("uint8"))

    # ---- forward + backward on true class ----
    x       = val_transform(img).unsqueeze(0).to(device)
    model.zero_grad()
    out     = model(x)
    cls_idx = unique_labels.index(row.true_label)
    out[0, cls_idx].backward()

    # ---- build raw CAM ----
    feat  = activations["feat"][0]    # C×h×w
    grad  = gradients["grad"][0]      # C×h×w
    wts   = grad.mean(dim=(1,2))      # C
    cam   = (wts[:,None,None] * feat).sum(dim=0).cpu().detach().numpy()
    cam   = np.maximum(cam, 0)
    cam   = (cam - cam.min()) / (cam.max() - cam.min() + 1e-8)
    cam   = cv2.resize(cam, (arr.shape[1], arr.shape[0]),
                       interpolation=cv2.INTER_LINEAR)

    # ---- build overlay ----
    heat    = cv2.applyColorMap((cam*255).astype("uint8"),
                                cv2.COLORMAP_JET)[...,::-1] / 255.0
    overlay = 0.6 * np.dstack([arr]*3) + 0.4 * heat

    # ---- save outputs ----
    torch.save(torch.from_numpy(cam).float(), out_cam)
    torch.save(torch.from_numpy(overlay)
               .permute(2,0,1).float(), out_overlay)

    plt.figure(figsize=(5,5))
    plt.imshow(overlay)
    plt.axis("off")
    plt.savefig(out_png, bbox_inches="tight", pad_inches=0)
    plt.close()

# ─── 11) Clean up hooks ─────────────────────────────────────────────────────
fh.remove()
bh.remove()


  self._maybe_warn_non_full_backward_hook(args, result, grad_fn)
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers). Got range [0.0..1.0000000238418578].
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers). Got range [0.0..1.0000000238418578].
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers). Got range [0.0..1.0000000238418578].
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers). Got range [0.0..1.0000000238418578].
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers). Got range [0.0..1.0000000238418578].
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers). Got range [0.0..1.0000000238418578].
Clipping input data to the valid range for imshow wit

RuntimeError: [enforce fail at inline_container.cc:659] . unexpected pos 832 vs 726

In [9]:
import os
import glob
import numpy as np
import torch
import torch.nn as nn
from torchvision.models import densenet121
import pandas as pd
import pydicom
import cv2
from PIL import Image
import matplotlib.pyplot as plt

# ─── 1) Pick your device ─────────────────────────────────────────────────────
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# ─── 2) Load the state dict ─────────────────────────────────────────────────
state_dict = torch.load("reproduceable_densenet.pt", map_location="cpu")

# ─── 3) Instantiate DenseNet121 (no pretrained weights) ────────────────────
model = densenet121(pretrained=False)

# ─── 4) Patch the first conv to accept 1‐channel input ───────────────────────
old_conv = model.features.conv0
new_conv = nn.Conv2d(
    in_channels=1,
    out_channels=old_conv.out_channels,
    kernel_size=old_conv.kernel_size,
    stride=old_conv.stride,
    padding=old_conv.padding,
    bias=(old_conv.bias is not None)
)
with torch.no_grad():
    new_conv.weight[:] = old_conv.weight.mean(dim=1, keepdim=True)
    if old_conv.bias is not None:
        new_conv.bias[:] = old_conv.bias
model.features.conv0 = new_conv

# ─── 5) Rebuild the classifier to match your num_classes ────────────────────
num_classes      = state_dict["classifier.weight"].shape[0]
in_feats         = model.classifier.in_features
model.classifier = nn.Linear(in_feats, num_classes)

# ─── 6) Load weights & move to device ───────────────────────────────────────
model.load_state_dict(state_dict)
model = model.to(device).eval()

# ─── 7) Load & filter your CSV ───────────────────────────────────────────────
df = pd.read_csv("best_model_pred_final.csv")
all_labels    = df["true_label"]
unique_labels = sorted(all_labels.unique().tolist())

df = df[df.true_label == df.predicted]
exclude = {"ASIAN", "HISPANIC/LATINO"}
df = df[~df.true_label.isin(exclude)].reset_index(drop=True)

# ─── 8) Set up Grad‐CAM hooks ────────────────────────────────────────────────
activations = {}
gradients   = {}

def forward_hook(module, inp, out):
    activations["feat"] = out

def backward_hook(module, grad_in, grad_out):
    gradients["grad"] = grad_out[0]

target_layer = model.features.norm5
fh = target_layer.register_forward_hook(forward_hook)
bh = target_layer.register_backward_hook(backward_hook)

# ─── 9) Prepare input/output paths on Extra Storage ─────────────────────────
DRIVE_ROOT = "/Volumes/Extra Storage"
test_root  = os.path.join(DRIVE_ROOT, "test_data")
output_dir = os.path.join(DRIVE_ROOT, "local_ADS_data", "gradcam_results")
os.makedirs(output_dir, exist_ok=True)

# ─── 10) Loop & generate Grad‐CAM, skipping existing outputs ────────────────
for idx, row in df.iterrows():
    fname = os.path.basename(row["dicom_path"])  # or row["full_path"] if that’s your column
    matches = glob.glob(os.path.join(test_root, "**", fname), recursive=True)
    if not matches:
        print(f"⚠️  File not found: {fname}")
        continue
    dicom_path = matches[0]
    base, _    = os.path.splitext(fname)

    out_cam     = os.path.join(output_dir, f"{base}_cam.pt")
    out_overlay = os.path.join(output_dir, f"{base}_overlay.pt")
    out_png     = os.path.join(output_dir, f"{base}.png")
    if all(os.path.exists(p) for p in (out_cam, out_overlay, out_png)):
        print(f"🔹 Skipping {base}: already done")
        continue

    # ---- load & normalize DICOM ----
    ds        = pydicom.dcmread(dicom_path, force=True)
    arr       = ds.pixel_array.astype("float32")
    arr_min   = arr.min()
    arr_range = np.ptp(arr)
    arr       = (arr - arr_min) / (arr_range + 1e-6)
    img       = Image.fromarray((arr * 255).astype("uint8"))

    # ---- forward + backward on true class ----
    x       = val_transform(img).unsqueeze(0).to(device)
    model.zero_grad()
    out     = model(x)
    cls_idx = unique_labels.index(row.true_label)
    out[0, cls_idx].backward()

    # ---- build raw CAM ----
    feat = activations["feat"][0]    # C×H×W
    grad = gradients["grad"][0]      # C×H×W
    wts  = grad.mean(dim=(1,2))      # C
    cam  = (wts[:,None,None] * feat).sum(dim=0).cpu().detach().numpy()
    cam  = np.maximum(cam, 0)
    cam  = (cam - cam.min()) / (cam.max() - cam.min() + 1e-8)
    cam  = cv2.resize(cam, (arr.shape[1], arr.shape[0]), interpolation=cv2.INTER_LINEAR)

    # ---- build overlay ----
    heat    = cv2.applyColorMap((cam*255).astype("uint8"),
                                cv2.COLORMAP_JET)[...,::-1] / 255.0
    overlay = 0.6 * np.dstack([arr]*3) + 0.4 * heat

    # ---- save outputs ----
    torch.save(torch.from_numpy(cam).float(),    out_cam)
    torch.save(torch.from_numpy(overlay)
               .permute(2,0,1).float(),        out_overlay)

    # ---- dump a quick PNG ----
    plt.figure(figsize=(5,5))
    plt.imshow(overlay)
    plt.axis("off")
    plt.savefig(out_png, bbox_inches="tight", pad_inches=0)
    plt.close()

# ─── 11) Clean up hooks ─────────────────────────────────────────────────────
fh.remove()
bh.remove()




🔹 Skipping 8159799c-7615c0ba-9676dd65-8b0cd6ed-96872c8f: already done
🔹 Skipping b9a08a39-c53ad784-99673387-d9140a2f-cbc1dbde: already done
🔹 Skipping 163e7408-e7e88bfd-ae448fe2-484a43ec-23ebcf71: already done
🔹 Skipping fbc0acfa-ae0bbb10-37a0c81e-bff2aced-678b58b7: already done
🔹 Skipping 4c329d77-162e3abb-df1731fc-a0f2354f-4777a58e: already done
🔹 Skipping a4ed7ed0-c2305148-b7b09a2e-ec63d023-ef9fd8df: already done
🔹 Skipping 6a8f19a4-2030fcda-b0f13ba9-b050a6a1-aa07e72a: already done
🔹 Skipping 7ff22806-8d18e0c8-5d2e1bcc-638b22a0-70654bb1: already done
🔹 Skipping ed231cb9-58b5647e-672e03e3-d43be791-c485128e: already done
🔹 Skipping f24ba3b1-8a4cc77f-23ad8f8c-5c3dca7d-77e2c0da: already done
🔹 Skipping 12fd2ed8-5a501563-a86d9388-5ba1a246-2ac9104b: already done
🔹 Skipping 5d2545e0-ea3ad600-6a2fa53d-e9336b30-cf8d3179: already done
🔹 Skipping bbedc806-4228a38a-e077c922-bcb355a7-f7a6d785: already done
🔹 Skipping cb900258-b426b740-fecaaf9d-c43940af-de57c019: already done
🔹 Skipping 3cf27c2d-

  self._maybe_warn_non_full_backward_hook(args, result, grad_fn)
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers). Got range [0.0..1.0000000238418578].
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers). Got range [0.0..1.0000000238418578].
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers). Got range [0.0..1.0000000238418578].
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers). Got range [0.0..1.0000000238418578].
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers). Got range [0.0..1.0000000238418578].
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers). Got range [0.0..1.0000000238418578].
Clipping input data to the valid range for imshow wit

🔹 Skipping 07e9a0b4-52559819-ede26703-11403ead-a9b761d5: already done
🔹 Skipping 5d93739f-badcdac7-b7414fdb-8db69418-9838f396: already done
🔹 Skipping 7776726a-5d5ceb10-22fa9917-3b7d11de-d413472a: already done
🔹 Skipping 336f2d98-81dbf659-f12c3003-665a59e4-7c040148: already done
🔹 Skipping cf8e40e4-13142780-e760eaf4-54a31990-8b748ee4: already done
🔹 Skipping 843946d3-fd44acc6-6a826d5c-c7242336-cf599b15: already done
🔹 Skipping 321bf6a2-75f7bbe3-b7515fac-a892f6cb-a3c07862: already done
🔹 Skipping 58055944-a4cda095-067c7c2e-8d6a4214-3c50f0c2: already done
🔹 Skipping 8a3efe56-73f935de-bfce35d7-d056e503-48b984ec: already done
🔹 Skipping 36b93225-8b7956ca-80a09685-70e55196-39dc2a14: already done
🔹 Skipping 54823b86-6b79c371-fa400bea-ff0dfa32-2ff9a43a: already done
🔹 Skipping d2c06203-af9fc488-12a138f1-9a6b4ece-ae383690: already done
🔹 Skipping 4e16d481-6e8f54bb-05790943-e19ccf55-cd17dd06: already done
🔹 Skipping 608e7747-d0e18df3-0293e862-f05411be-c027d53b: already done
🔹 Skipping e52db541-