In [1]:
import os
import ast
import torch
import numpy as np
import pandas as pd

from PIL import Image
from timm.models import create_model
from torchvision import transforms
from torch.utils.data import Dataset, DataLoader

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Use the first CUDA device when one is available; otherwise fall back to the CPU.
device = torch.device('cuda:0') if torch.cuda.is_available() else torch.device('cpu')

# Architecture options for the two-class DeiT-Base classifier. The dropout and
# stochastic-depth rates configured here are what the MC-Dropout passes rely on.
model_config = dict(
    pretrained=False,
    num_classes=2,
    drop_rate=0.1,
    drop_path_rate=0.2,
    drop_block_rate=None,
    img_size=224,
)
model = create_model("deit_base_patch16_224", **model_config).to(device)

# NOTE(review): absolute, machine-specific path. torch.load unpickles the file,
# so only point this at checkpoints from a trusted source.
checkpoint_path = r"F:\code\ssl_mammo_599_base\best_acc_checkpoint.pth"
checkpoint = torch.load(checkpoint_path, map_location=device)

# The fine-tuned weights are stored under the "model" key of the checkpoint.
model.load_state_dict(checkpoint["model"])
# Put every layer into evaluation mode.
model.eval()

def enable_dropout(m):
    """Switch ``m`` to training mode if it is a ``torch.nn.Dropout`` layer.

    Written to be handed to ``Module.apply`` so that dropout stays stochastic
    while the rest of the network remains in evaluation mode. Modules of any
    other type are left untouched. Returns ``None``.
    """
    if not isinstance(m, torch.nn.Dropout):
        return
    m.train()

def enable_droppath(m):
    """Switch ``m`` to training mode if its class is named ``DropPath``.

    The match is on the class name only, so no import of the layer type is
    needed. Intended for use with ``Module.apply``; modules whose class has a
    different name are left untouched. Returns ``None``.
    """
    is_droppath = type(m).__name__ == "DropPath"
    if is_droppath:
        m.train()

# Re-enable only the stochastic layers (Dropout, then DropPath) on the otherwise
# eval-mode network. Module.apply returns the module itself, so the calls chain
# and the cell still ends by displaying the model.
model.apply(enable_dropout).apply(enable_droppath)

VisionTransformer(
  (patch_embed): PatchEmbed(
    (proj): Conv2d(3, 768, kernel_size=(16, 16), stride=(16, 16))
    (norm): Identity()
  )
  (pos_drop): Dropout(p=0.0, inplace=False)
  (patch_drop): Identity()
  (norm_pre): Identity()
  (blocks): Sequential(
    (0): Block(
      (norm1): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
      (attn): Attention(
        (qkv): Linear(in_features=768, out_features=2304, bias=True)
        (q_norm): Identity()
        (k_norm): Identity()
        (attn_drop): Dropout(p=0.0, inplace=False)
        (proj): Linear(in_features=768, out_features=768, bias=True)
        (proj_drop): Dropout(p=0.0, inplace=False)
      )
      (ls1): Identity()
      (drop_path1): Identity()
      (norm2): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
      (mlp): Mlp(
        (fc1): Linear(in_features=768, out_features=3072, bias=True)
        (act): GELU(approximate='none')
        (drop1): Dropout(p=0.0, inplace=False)
        (norm): Identity(

In [3]:
class InferenceDataset(Dataset):
    """Dataset that returns all patch/view images of one CSV row as one tensor.

    Only two CSV columns are read:

    * ``grade`` -- must be a key of ``label_map`` ("grade 3" or "grade 4").
    * ``files`` -- the string form of a Python list of image paths. Each file
      name must contain ``_patch_<int>`` followed by a ``.``; the integer is
      the patch index used for grouping.

    Args:
        csv_file: Path (or buffer) accepted by ``pandas.read_csv``.
        transform: Optional callable applied to each RGB ``PIL`` image. It has
            to return equally shaped tensors, because the images are combined
            with ``torch.stack``.
    """

    def __init__(self, csv_file, transform=None):
        self.data = pd.read_csv(csv_file)
        self.transform = transform

        # Maps the textual grade in the CSV to the class index of the model.
        self.label_map = {"grade 3": 0, "grade 4": 1}

    def __len__(self):
        """Return the number of CSV rows (one sample per row)."""
        return len(self.data)

    @staticmethod
    def _patch_index(img_path):
        """Extract the integer that follows ``_patch_`` in the file name."""
        filename = os.path.basename(img_path)
        return int(filename.split("_patch_")[1].split(".")[0])

    def _load_image(self, img_path):
        """Open ``img_path`` as RGB, apply the transform, and close the file."""
        # convert() returns a new image, so the file handle opened by
        # Image.open can be released as soon as the conversion is done.
        with Image.open(img_path) as handle:
            image = handle.convert('RGB')

        if self.transform:
            image = self.transform(image)
        return image

    def __getitem__(self, idx):
        """Return ``(images, label)`` for row ``idx``.

        ``images`` is the result of stacking, per patch index (ascending), the
        images of that patch in sorted path order, and then stacking the
        patches; with tensor-producing transforms this gives a tensor of shape
        ``(num_patches, num_views, *image_shape)``. ``label`` is the integer
        from ``label_map``.

        Raises:
            KeyError: If the row's grade is not in ``label_map``. This is
                checked before any image is read.
            RuntimeError: From ``torch.stack`` when the file list is empty or
                the patches do not all have the same number of views.
        """
        row = self.data.iloc[idx]

        # Resolve the label first so an unexpected grade fails before any
        # image decoding work is done.
        label = self.label_map[row["grade"]]
        image_paths = ast.literal_eval(row["files"])

        # Group the paths by patch index.
        patch_dict = {}
        for img_path in image_paths:
            patch_dict.setdefault(self._patch_index(img_path), []).append(img_path)

        grouped_images = [
            torch.stack([self._load_image(p) for p in sorted(patch_dict[patch_num])])
            for patch_num in sorted(patch_dict)
        ]

        return torch.stack(grouped_images), label

In [None]:
# Normalisation statistics of the SSL pre-training dataset; the same value is
# applied to each of the three channels.
norm_mean = [0.2492] * 3
norm_std = [0.1920] * 3

# Evaluation-time preprocessing: resize to 224x224, convert to a tensor,
# then normalise.
val_transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=norm_mean, std=norm_std),
])

# NOTE(review): absolute, machine-specific path.
csv_file = "/hdchoi00/data/CNUH_data/test_inference_example.csv"
test_dataset = InferenceDataset(csv_file, transform=val_transform)
# One CSV row per batch, loaded in the main process.
test_loader = DataLoader(dataset=test_dataset, batch_size=1, num_workers=0)

In [5]:
optimal_T = 0.626  # Optimal temperature scaling factor (ImageNet : 0.460, SSL : 0.626)
optimal_threshold = 0.81  # Uncertainty threshold
N = 100  # Number of MC-Dropout forward passes
SEED = 0  # Fixed RNG seed so the stochastic forward passes are reproducible

correct = 0
total = 0
uncertain_count = 0

all_preds = []
all_true = []

# Class index -> grade name (inverse of the mapping used by the dataset).
label_map = {
    0: "grade 3",
    1: "grade 4"
}


def _mc_dropout_mean_prob(net, img_tensor, n_passes, temperature):
    """Return the mean class-1 probability over ``n_passes`` forward passes.

    Each pass divides the logits by ``temperature`` before the softmax.
    Gradients are disabled; autocast is enabled only when the input lives on a
    CUDA device.
    """
    device_type = img_tensor.device.type
    probs = []
    with torch.no_grad(), torch.autocast(device_type=device_type, enabled=(device_type == "cuda")):
        for _ in range(n_passes):
            # Calibrate logits using the optimal temperature
            calibrated_logits = net(img_tensor) / temperature
            probs.append(torch.softmax(calibrated_logits, dim=1)[:, 1].item())
    return np.mean(probs)


def _decide_grade(patch_probs):
    """Combine per-patch probabilities (``None`` = uncertain patch) into a result.

    * any confident patch with probability >= 0.5 -> "grade 4"
    * otherwise, any uncertain patch               -> "uncertain"
    * otherwise                                    -> "grade 3"
    """
    if any(p is not None and p >= 0.5 for p in patch_probs):
        return "grade 4"
    if any(p is None for p in patch_probs):
        return "uncertain"
    return "grade 3"


# Seed once before the loop so re-running the cell reproduces the same samples.
torch.manual_seed(SEED)

for i, (images, targets) in enumerate(test_loader):
    print("\n## sample " + str(i) + " ##")
    # Unpacking all six dimensions also checks that the batch has the expected rank.
    batch_size, num_patches, num_views, C, H, W = images.shape

    for b_idx in range(batch_size):
        patch_probs_list = []
        for p_idx in range(num_patches):
            print("\npatch " + str(p_idx))
            view_probs = []
            for v_idx in range(num_views):
                img_tensor = images[b_idx, p_idx, v_idx].unsqueeze(0).to(device)
                mean_prob = _mc_dropout_mean_prob(model, img_tensor, N, optimal_T)

                # Confidence is the probability of whichever class is predicted.
                confidence = max(mean_prob, 1 - mean_prob)

                # Views below the confidence threshold are marked as uncertain.
                view_probs.append(mean_prob if confidence >= optimal_threshold else None)

            print("view_probs:", view_probs)
            # A patch is the average of its confident views, or uncertain if it has none.
            valid_view_probs = [vp for vp in view_probs if vp is not None]
            patch_avg_prob = np.mean(valid_view_probs) if valid_view_probs else None

            patch_probs_list.append(patch_avg_prob)

        print("\n-> patch_probs:", patch_probs_list)

        result = _decide_grade(patch_probs_list)

        true_label = label_map.get(targets[b_idx].item(), "unknown")

        all_preds.append(result)
        all_true.append(true_label)

        # An "uncertain" result never equals a true label, so it is not counted as correct.
        if result == true_label:
            correct += 1

        if result == "uncertain":
            uncertain_count += 1

        total += 1

        print("-------------------------------------")
        print("prediction:", result)
        print("label:", true_label)


## sample 0 ##

patch 0
view_probs: [0.824573603272438, 0.8727135056257248]

patch 1
view_probs: [0.8763566786050796, None]

patch 2
view_probs: [0.8805757057666779, None]

-> patch_probs: [0.8486435544490813, 0.8763566786050796, 0.8805757057666779]
-------------------------------------
prediction: grade 4
label: grade 4
