In [1]:
import argparse
from time import perf_counter

import cv2
import torch
from torchvision.transforms import v2

from pipeline_utils import create_convnextv2_base, remove_pect_muscle, prep_convnextv2_for_gradcam,\
    predict, grad_cam_plusplus, get_org_cc_mlo_maps
from UNet3Plus import ResNet101UNet3Plus

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Run on the GPU when PyTorch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

In [3]:
# Machine-specific (Windows, absolute) checkpoint directory: edit this one value
# when running the notebook on another machine.
CHECKPOINT_DIR = r"D:\Study\Proposal\Breast cancer(v2)\MammoClassifier\checkpoints"

# Weights for the ResNet101UNet3Plus segmentor.
SEGMENTOR_PATH = rf"{CHECKPOINT_DIR}\last_resnet101_unet3-inbreast_mias-breast_roi-adam-no_cls_guide-no_mixup-elastic_flip-output_resized-bs8_e100.pt"
# Checkpoint for the ConvNeXtV2-base classifier.
CLASSIFIER_PATH = rf"{CHECKPOINT_DIR}\convnextv2_base-AdamW-up_sample-pos_smooth-mixup-cmmd_vindr-VOILUT_Flipped_pect_imgs-bs8x8-s0_e36_seed0.pt"

In [4]:
# Build the ResNet-101 UNet3+ segmentor and restore its trained weights.
# .eval() is applied before loading so that BatchNorm (is_batchnorm=True) uses its
# running statistics during inference instead of per-batch statistics.
# NOTE(review): harmless no-op if remove_pect_muscle also switches to eval mode.
segmentor_model = ResNet101UNet3Plus(num_classes = 2,
                   resnet_weights = None,
                   class_guided = False,
                   is_batchnorm = True,
                   output_size = (512, 512)).to(device).eval()
# weights_only=True restricts unpickling to tensors and plain containers (enough for
# a state_dict), so a tampered checkpoint file cannot execute arbitrary code.
segmentor_model.load_state_dict(torch.load(SEGMENTOR_PATH, map_location=device, weights_only=True))

  segmentor_model.load_state_dict(torch.load(SEGMENTOR_PATH, map_location=device))


<All keys matched successfully>

In [5]:
# Build ConvNeXtV2-base from CLASSIFIER_PATH and prepare it for Grad-CAM++.
classifier_model = create_convnextv2_base(device, CLASSIFIER_PATH)
classifier_model = prep_convnextv2_for_gradcam(classifier_model)
# Inference mode: disables train-only layers such as dropout / stochastic depth, if any.
# Autograd still tracks gradients in eval mode, so the Grad-CAM backward passes work.
# NOTE(review): no-op if the helpers above already call .eval() — confirm.
classifier_model.eval();

  checkpoint = torch.load(checkpoint_path, map_location=device)


Loading D:\Study\Proposal\Breast cancer(v2)\MammoClassifier\checkpoints\convnextv2_base-AdamW-up_sample-pos_smooth-mixup-cmmd_vindr-VOILUT_Flipped_pect_imgs-bs8x8-s0_e36_seed0.pt
<All keys matched successfully>


In [6]:
# Segmentor input: image -> tensor, resized to 512x512 (the segmentor's output_size),
# converted to float32 with value rescaling (ToDtype scale=True).
segmentor_transform = v2.Compose(
    [
        v2.ToImage(),
        v2.Resize(size=(512, 512), antialias=True),
        v2.ToDtype(torch.float32, scale=True),
    ]
)

# Resize passed to remove_pect_muscle; presumably maps masks back to (1024, 512).
resize_to_org = v2.Resize((1024, 512), antialias=True)

# Classifier normalisation: a single (mean, std) pair replicated across the three
# channels (presumably dataset statistics — TODO confirm provenance).
_CLS_MEAN = [0.20275] * 3
_CLS_STD = [0.19875] * 3
classifier_transform = v2.Compose(
    [
        v2.ToImage(),
        v2.ToDtype(torch.float32, scale=True),
        v2.Normalize(_CLS_MEAN, _CLS_STD),
    ]
)

In [24]:
# Make every parameter of the classifier's stem, stages, norm_pre and head trainable
# (requires_grad=True) so that autograd records a graph through the whole network.
# NOTE(review): this cell was executed out of order (In [24]); the timing cell below
# reports a very slow backward pass with all parameters requiring grad — check whether
# Grad-CAM++ really needs parameter gradients here.
_grad_submodules = (
    classifier_model.stem,
    classifier_model.stages,
    classifier_model.norm_pre,
    classifier_model.head,
)
for _submodule in _grad_submodules:
    for param in _submodule.parameters():
        param.requires_grad = True

In [8]:
# One study folder holding the CC and MLO DICOM views to analyse (machine-specific path).
STUDY_DIR = "D:/Study/Proposal/Breast cancer(v2)/MammoClassifier/MammoClassifier.Web/wwwroot/studies/images/8684388"
cc_path = f"{STUDY_DIR}/I0000000.dcm"
mlo_path = f"{STUDY_DIR}/I0000002.dcm"

# Where the per-view heatmaps are written (relative to the working directory).
output_cc_path, output_mlo_path = "cc_map.png", "mlo_map.png"

In [9]:
# Segment the MLO view and remove the pectoral muscle. The bounding box and original
# shape are kept because get_org_cc_mlo_maps needs them to place the heatmap back.
_mlo_result = remove_pect_muscle(
    mlo_path, segmentor_transform, resize_to_org, segmentor_model, device
)
mlo_roi, mlo_bbox, mlo_org_shape = _mlo_result


In [34]:
# Classify the CC view together with the pectoral-muscle-free MLO ROI. Besides the
# logits/probabilities, predict() returns the combined CC+MLO images consumed by
# grad_cam_plusplus and the CC bounding box / original shape used by
# get_org_cc_mlo_maps.
(
    logits,
    probs,
    cc_mlo_img_org,
    cc_mlo_img,
    cc_bbox,
    cc_org_shape,
) = predict(cc_path, mlo_roi, classifier_model, device, classifier_transform)


In [15]:
from time import time

In [26]:
# Time one backward pass from the highest logit.
# retain_graph=True keeps the autograd graph alive so that grad_cam_plusplus (next
# timing cell) can backpropagate through the same `logits` again; without it a linear
# Restart & Run All would likely fail there with "Trying to backward through the
# graph a second time" (assumes grad_cam_plusplus backprops from `logits` — TODO confirm).
# NOTE(review): logits.argmax() is a flat index, so this assumes a batch size of 1.
t1 = perf_counter()
logits[:, logits.argmax()].backward(retain_graph=True)
if device.type == "cuda":
    # CUDA kernels run asynchronously; wait for them before stopping the clock.
    torch.cuda.synchronize()
t2 = perf_counter()

print(f"backward pass with all requires grad True elapsed time: {t2-t1}s")

backward pass with all requires grad True elapsed time: 189.2937150001526s


In [35]:
# Time Grad-CAM++; it returns the image with the map applied and the raw heatmap.
# perf_counter is monotonic and high-resolution, unlike wall-clock time().
t1 = perf_counter()
image_with_map, heatmap = grad_cam_plusplus(logits, classifier_model, cc_mlo_img_org, cc_mlo_img, device)
if device.type == "cuda":
    # Make sure any queued CUDA work is finished before stopping the clock.
    torch.cuda.synchronize()
t2 = perf_counter()

print(f"GradCAM elapsed time: {t2-t1}")

GradCAM elapsed time: 485.30234694480896


In [None]:
# Map the Grad-CAM++ heatmap back onto the original CC and MLO views (using the
# bounding boxes and original shapes recorded earlier), then save both images.
cc_map, mlo_map = get_org_cc_mlo_maps(heatmap, cc_bbox, mlo_bbox, cc_org_shape, mlo_org_shape)

for out_path, view_map in ((output_cc_path, cc_map), (output_mlo_path, mlo_map)):
    # The maps are RGB, OpenCV writes BGR. cv2.imwrite reports failure (bad directory,
    # unsupported extension) by returning False rather than raising, so check it.
    if not cv2.imwrite(out_path, cv2.cvtColor(view_map, cv2.COLOR_RGB2BGR)):
        raise OSError(f"Failed to write image to {out_path}")

print(f"Logits: {logits}, Probs: {probs}")
print(f"Output images saved to {output_cc_path}, {output_mlo_path}")
