In [12]:
# ─── 1) Imports & drive paths ───────────────────────────────────────────────
import os
import torch
import torch.nn as nn
import numpy as np
import matplotlib.pyplot as plt
from torchvision.models import densenet121
from pydicom import dcmread
from PIL import Image
from torchvision import transforms
import zipfile, io
import numpy as np
import pydicom
from PIL import Image
import random

# Point at your external drive:
DRIVE_ROOT   = "/Volumes/Extra Storage"
DATA_DIR     = os.path.join(DRIVE_ROOT, "local_ADS_data", "necessary_sufficient")
RESULTS_DIR  = os.path.join(DRIVE_ROOT, "local_ADS_data", "animations")
WEIGHTS_PATH = os.path.join(DRIVE_ROOT, "models", "reproduceable_densenet.pt")

os.makedirs(RESULTS_DIR, exist_ok=True)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [13]:
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
torch.cuda.manual_seed_all(SEED)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

In [5]:
def load_dicom_as_pil(path):
    """
    path can be either:
      - "/full/path/to/foo.dcm"   (plain file), or
      - "/full/path/to/data.zip!inner/folder/foo.dcm"
    """
    if "!" in path:
        archive_path, inner_path = path.split("!", 1)
        inner_path = inner_path.lstrip("/")  # no leading slash in the archive
        with zipfile.ZipFile(archive_path, "r") as zf:
            data = zf.read(inner_path)
            ds   = pydicom.dcmread(io.BytesIO(data))
    else:
        ds = pydicom.dcmread(path)

    arr = ds.pixel_array.astype(np.float32)
    # normalize to [0,255] so PIL can handle it nicely
    arr = (arr - arr.min()) / (arr.max() - arr.min() + 1e-8)
    arr = (arr * 255).clip(0,255).astype(np.uint8)
    return Image.fromarray(arr)


In [6]:
TRAIN_MEAN = 0.5007
TRAIN_STD  = 0.2508

In [7]:
test_transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Lambda(lambda t: t.float()),
    transforms.Normalize([TRAIN_MEAN], [TRAIN_STD])
])

In [8]:
def optimize_mask(model, x, alpha=1.0, lr=0.5, log_every=10, tol=1e-5, max_no_improve=5, init_mask=None, output_dir=None, prefix=None):
    """
    Optimize a mask with learning rate decay and safe return on interrupt.
    """
    mask = torch.nn.Parameter(init_mask.clone().detach() if init_mask is not None else torch.ones_like(x), requires_grad=True)
    optimizer = torch.optim.Adam([mask], lr=lr)

    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=200, gamma=0.9)
    #scheduler = torch.optim.lr_scheduler.ExponentialLRS(optimizer, gamma=0.9)

    mask_history = []
    suff_history = []
    nec_history  = []
    sparsity_history = []
    tv_history = []
    prev_loss = float('inf')
    no_improve_count = 0
    step = 0
    changed_class = False
    best_loss = float('inf')
    best_mask = None   

    try:
        while True:
            optimizer.zero_grad()
            loss, components = explanation_loss(model, x, mask, background_path= "mean_image.pt", alpha=alpha, return_components=True)
            
            if components['loss'] < best_loss:
                best_loss = components['loss']
                best_mask = mask.detach().clone()
            loss.backward()
            optimizer.step()
            scheduler.step()
            mask.data.clamp_(0, 1)

            if (step + 1) % log_every == 0 or step == 0:
                current_lr = scheduler.get_last_lr()[0]
                print(f"Step {step+1:4d}: "
                      f"Loss={components['loss']:.6f} | "
                      f"Suff={components['sufficiency']:.6f} | "
                      f"Nec={components['necessity']:.6f} | "
                      f"L1={components['l1']:.6f} | "
                      f"TV={components['tv']:.6f} | "
                      f"Changed (keep): {components['changed_with_keep']} | "
                      f"Changed (remove): {components['changed_with_remove']} | "
                      f"LR={current_lr:.6e}")
                mask_history.append(mask.detach().cpu().clone())
                suff_history.append(components['sufficiency'])
                nec_history.append( components['necessity'] )
                sparsity_history.append( mask.abs().mean().item() )
                tv_history.append(components['tv'])

                if alpha == 1.0 and components['changed_with_keep']:
                    changed_class = True
                elif alpha == 0.0 and components['changed_with_remove']:
                    changed_class = True

            loss_delta = abs(prev_loss - components['loss'])
            if loss_delta < tol:
                no_improve_count += 1
                if no_improve_count >= max_no_improve:
                    print(f"Converged at step {step+1} (loss change < {tol} for {max_no_improve} steps)")
                    break
            else:
                no_improve_count = 0

            prev_loss = components['loss']
            step += 1

    except KeyboardInterrupt:
        print(f"\n⏹️ Optimization manually interrupted at step {step+1}. Returning current mask.")
    
    final_mask = best_mask

    fig, ax = plt.subplots()
    ax.plot(suff_history,      label='Sufficiency', linewidth=2)
    ax.plot(nec_history,       label='Necessity',    linewidth=2)
    ax.plot(sparsity_history,  label='Sparsity',     linewidth=2)
    ax.plot(tv_history, label = 'Smoothness', linewidth = 2)

    ax.set_xlabel("Step")
    ax.set_ylabel("Metric Value")
    ax.set_title("Sufficiency, Necessity & Sparsity over Optimization")
    ax.legend(loc='best')
    plt.tight_layout()
    
    if output_dir and prefix:
        os.makedirs(output_dir, exist_ok=True)
        # 1) final mask
        torch.save(final_mask.cpu(),
                   os.path.join(output_dir, f"{prefix}_final_mask.pt"))
        # 2) mask history
        torch.save([m.cpu() for m in mask_history],
                   os.path.join(output_dir, f"{prefix}_mask_history.pt"))
        # 3) metric plot
        fig.savefig(os.path.join(output_dir, f"{prefix}_metrics.png"))
        plt.close(fig)
    else:
        plt.show()
    
    return mask.detach(), mask_history, changed_class


In [10]:
# ─── Load & patch a single-channel DenseNet ─────────────────────────────────
from torchvision.models import densenet121
import torch.nn as nn

# 1) instantiate the standard DenseNet121
model = densenet121(pretrained=True)

# 2) swap its first conv (3→1 channel)
old_conv = model.features.conv0
new_conv = nn.Conv2d(
    in_channels=1,
    out_channels=old_conv.out_channels,
    kernel_size=old_conv.kernel_size,
    stride=old_conv.stride,
    padding=old_conv.padding,
    bias=(old_conv.bias is not None)
)
with torch.no_grad():
    # initialize new_conv by averaging the RGB weights
    new_conv.weight[:] = old_conv.weight.mean(dim=1, keepdim=True)
model.features.conv0 = new_conv

# 3) move to device & load your fine-tuned weights
model = model.to(device)
model.load_state_dict(torch.load(
    "/Users/Kyra_1/Downloads/reproduceable_densenet.pt",
    map_location=device
))
model.eval()


RuntimeError: Error(s) in loading state_dict for DenseNet:
	size mismatch for classifier.weight: copying a param with shape torch.Size([4, 1024]) from checkpoint, the shape in current model is torch.Size([1000, 1024]).
	size mismatch for classifier.bias: copying a param with shape torch.Size([4]) from checkpoint, the shape in current model is torch.Size([1000]).

In [None]:
# ─── 2) Load your model & helper funcs ───────────────────────────────────────
# Replace with whatever model class & loading you use
from your_model_module import YourMaskModel, optimize_mask, load_data

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model  = YourMaskModel().to(device)
model.load_state_dict(torch.load(MODEL_PATH, map_location=device))
model.eval()


In [None]:
# ─── 3) Utility: load a DICOM into a normalized torch.Tensor ────────────────
def load_dicom_tensor(dicom_id):
    path = os.path.join(DATA_DIR, f"{dicom_id}.dcm")
    ds   = dcmread(path, force=True)
    arr  = ds.pixel_array.astype(np.float32)
    norm = (arr - arr.min()) / (arr.ptp() + 1e-6)
    # shape → (1, H, W)
    return torch.from_numpy(norm).unsqueeze(0)


In [None]:
# ─── 4) Utility: save figures & .pt ─────────────────────────────────────────
def save_overlay(orig, mask, out_path):
    plt.imshow(orig, cmap="gray")
    plt.imshow(mask, cmap="jet", alpha=0.5)
    plt.axis("off")
    plt.tight_layout()
    plt.savefig(out_path, dpi=150)
    plt.close()

def save_loss_history(history, out_path):
    plt.plot(history["sufficiency"], label="sufficiency")
    plt.plot(history["necessity"],  label="necessity")
    plt.plot(history["sparsity"],   label="sparsity")
    plt.legend()
    plt.tight_layout()
    plt.savefig(out_path, dpi=150)
    plt.close()


In [None]:
# ─── 5) Single‐ID pipeline ───────────────────────────────────────────────────
def process_dicom_id(dicom_id):
    # 1) Load image tensor
    x = load_dicom_tensor(dicom_id).unsqueeze(0).to(device)  # (1,1,H,W)
    
    # 2) Run your optimization/masking routine
    #    should return final_mask, mask_history, model_pred
    final_mask, mask_history, model_pred = optimize_mask(
        model, x,
        alpha=0.5,
        init_mask=torch.ones_like(x),
        output_dir=None,      # we’ll handle saving ourselves
        prefix=None
    )
    
    # 3) Prepare output paths
    base = os.path.join(RESULTS_DIR, dicom_id)
    os.makedirs(base, exist_ok=True)
    mask_pt_path    = os.path.join(base, f"{dicom_id}_final_mask.pt")
    overlay_png     = os.path.join(base, f"{dicom_id}_overlay.png")
    losses_png      = os.path.join(base, f"{dicom_id}_loss_history.png")
    pred_txt        = os.path.join(base, f"{dicom_id}_prediction.txt")
    
    # 4) Save final mask as .pt
    torch.save(final_mask.detach().cpu(), mask_pt_path)
    
    # 5) Save overlay figure
    orig = x.squeeze().cpu().numpy()
    m    = final_mask.squeeze().cpu().numpy()
    save_overlay(orig, m, overlay_png)
    
    # 6) Save loss history plot
    save_loss_history(mask_history, losses_png)
    
    # 7) Save model prediction
    with open(pred_txt, "w") as f:
        f.write(f"Predicted class: {model_pred}\n")
    
    print(f"[✓] Done {dicom_id}")


In [None]:
# ─── 6) Run on one (or many) IDs ─────────────────────────────────────────────
# Example: single ID
process_dicom_id("4c8be7b3-4bdfb53a-5291392a-e9387349-55fd937c")