In [12]:
import sys
import os
import torch
import torch.nn as nn
import torch.functional as f
import numpy as np
import scipy.ndimage as nd
from scipy.ndimage import distance_transform_edt

In [13]:
def compute_csf_distance_map(csf_mask: torch.Tensor, spacing=None) -> torch.Tensor:
  device = csf_mask.device
  csf_np = csf_mask.detach().cpu().numpy().astype(np.float32)

  B = csf_np.shape[0]
  dist_maps = []

  for b in range(B):
    csf_slice = csf_np[b,0]
    dt = distance_transform_edt(2.0 - csf_slice, sampling=spacing)
    dist_maps.append(dt)

  dist_maps = np.stack(dist_maps, axis=0)
  dist_maps = torch.from_numpy(dist_maps).to(device=device, dtype=torch.float32)
  dist_maps = dist_maps.unsqueeze(1)
  print()
  print("=== CSF Maps ===")
  print(csf_mask)

  return dist_maps

def clustering(loss_map):
  numpy_arr = loss_map.squeeze().numpy()
  labeled, num = nd.label(numpy_arr)
  max_values = []
  for i in range(1, num + 1):
    region = numpy_arr[labeled == i]
    # print(region)
    # print()
    max_values.append(region.max())
  return max_values


In [14]:
class CSFDistanceLoss(nn.Module):
  def __init__(self, sigma: float= 5.0, reduction: str="sum"):
    super().__init__()
    self.sigma = sigma
    # self.max_distance= max_distance
    assert reduction in ["mean", "sum", "none"]
    self.reduction = reduction

  def forward(self, cmb_logits: torch.Tensor, gt_cmb: torch.Tensor, gt_csf: torch.Tensor, spacing=None) -> torch.Tensor:
    device = cmb_logits.device
    cmb_prob = torch.sigmoid(cmb_logits)
    print("=== CMB Probability Maps ===")
    print(cmb_prob)
    cmb_prob = cmb_prob * (cmb_prob > 0.6).float()      
    print(cmb_prob)

    dist_map = compute_csf_distance_map(gt_csf, spacing).to(device) # Distance Map for CSF
    print("=== Distance Maps ===")
    print(dist_map)
    print()
       
    weight = torch.exp(-dist_map / self.sigma)  # Distance map based weight
    print("=== Weight Maps ===")
    print(weight)
    print()
    # if self.max_distance is not None:
    #   far_mask = dist_map > self.max_distance # boolean numpy 2d ndarray
    #   weight = weight * (~far_mask)

    '''
    filter for removing the true cmb regions.
    It needs to be larger to fully avoid the true cmb predicted region.
    '''
    bg_mask = (gt_cmb==0).float()
    # bg_mask = bg_mask * (gt_csf!=1.0).float()
    print("=== Background Mask (to filter out true CMB regions) ===")
    print(bg_mask)
    print()
    fp_soft = cmb_prob * bg_mask # selecting only fp cases.

    loss_map = fp_soft * weight
    print("=== Loss Map ===")
    print(loss_map)
    print()
    max_values = clustering(loss_map)

    print("=== Max values for each cluster ===")
    print(max_values)
    print()
    
    print("=== Sum of max values for all clusters: FINAL DISTANCE LOSS ===")
    print(sum(max_values))
    if self.reduction == "mean":
      denom = (bg_mask * (weight > 0).float()).sum()
      denom = torch.clamp(denom, min=1.0)
      return loss_map.sum() / denom
    elif self.reduction == "sum":
      return sum(max_values)
      # return loss_map.sum()
    else:
      return loss_map

In [15]:
cmb_logits = torch.tensor([[[
                    [0.1,0.2,0.3,0.4,0.5,0.64,0.4],
                    [0.1,0.9,0.3,0.4,0.5,0.8,0.6],
                    [0.1,0.7,0.3,0.4,0.5,0.6,0.7],
                    [0.1,0.2,0.3,0.1,0.1,0.1,0.1],
                    [0.1,0.2,0.3,0.1,0.1,0.1,0.1],
                    [0.9,0.7,0.3,0.1,0.1,0.1,0.1],
                    [0.8,0.6,0.3,0.1,0.1,0.1,0.1]]]])
cmb_logits = cmb_logits.float()
# gt_cmb = torch.zeros(1, 1, 7, 7)
gt_cmb = torch.tensor([[[                
                    [0,0,0,0,0,0,0],
                    [0,0,0,0,0,0,0],
                    [1,0,0,0,0,0,0],
                    [1,1,0,0,0,0,0],
                    [1,0,0,0,0,0,0],
                    [0,0,0,0,0,0,0],
                    [0,0,0,0,0,0,0]]]])

gt_csf = torch.tensor([[[
                    [0,0,0,0,0,0,0],
                    [0,0,0,0,0,0,0],
                    [0,0,0,0,0,2,2],
                    [0,0,0,0,0,2,2],
                    [0,0,0,0,2,2,2],
                    [0,0,0,0,0,0,2],
                    [0,0,0,0,0,2,2]]]])
# gt_csf = torch.zeros(1, 1, 7, 7)

bce_loss_fn = nn.BCEWithLogitsLoss()
distance_loss_fn = CSFDistanceLoss(sigma=5.0, reduction="mean")

bce_loss = bce_loss_fn(cmb_logits, gt_cmb.float())
distance_loss = distance_loss_fn(cmb_logits, gt_cmb, gt_csf, spacing=None)

# Combine (lambda is a hyperparameter)
lambda_dist = 0.5
total_loss = bce_loss + lambda_dist * distance_loss
# print(total_loss)

# total_loss.backward()
# optimizer.step()

=== CMB Probability Maps ===
tensor([[[[0.5250, 0.5498, 0.5744, 0.5987, 0.6225, 0.6548, 0.5987],
          [0.5250, 0.7109, 0.5744, 0.5987, 0.6225, 0.6900, 0.6457],
          [0.5250, 0.6682, 0.5744, 0.5987, 0.6225, 0.6457, 0.6682],
          [0.5250, 0.5498, 0.5744, 0.5250, 0.5250, 0.5250, 0.5250],
          [0.5250, 0.5498, 0.5744, 0.5250, 0.5250, 0.5250, 0.5250],
          [0.7109, 0.6682, 0.5744, 0.5250, 0.5250, 0.5250, 0.5250],
          [0.6900, 0.6457, 0.5744, 0.5250, 0.5250, 0.5250, 0.5250]]]])
tensor([[[[0.0000, 0.0000, 0.0000, 0.0000, 0.6225, 0.6548, 0.0000],
          [0.0000, 0.7109, 0.0000, 0.0000, 0.6225, 0.6900, 0.6457],
          [0.0000, 0.6682, 0.0000, 0.0000, 0.6225, 0.6457, 0.6682],
          [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
          [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
          [0.7109, 0.6682, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
          [0.6900, 0.6457, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000]]]])

=== CSF Maps

### Implementation Note V2

In [16]:
class Loss_Distance(nn.Module):
    def __init__(self, batched_inputs, root_path, sigma_param):
        super(Loss_Distance, self).__init__()
        # self.outputs = outputs['batched_masks']
        self.batched_inputs = batched_inputs
        self.root_path = "/media/Datacenter_storage/Ji/BiomedParse"
        self.csf_mask_path = os.path.join(self.root_path, batched_inputs[0]["grounding_info"][0]["mask_file"])
        self.cmb_mask_path = self.csf_mask_path.replace("cerebrospinal+fluid", "brain_microbleeds")
        self.sigma = sigma_param

    def forward(self, cmb_prob: torch.Tensor, csf_mask_path) -> torch.Tensor:
        cmb_mask_path = csf_mask_path.replace("cerebrospinal+fluid", "brain+microbleeds")
        gt_csf = Image.open(csf_mask_path).convert('L')
        gt_cmb = Image.open(cmb_mask_path).convert('L')
        gt_csf = torch.from_numpy(np.array(gt_csf)).to(cmb_prob.device).unsqueeze(0).unsqueeze(0).float()
        gt_cmb = torch.from_numpy(np.array(gt_cmb)).to(cmb_prob.device).unsqueeze(0).unsqueeze(0).float()

        # distance map from CSF (fixed wrt model => fine)
        dist_map = compute_csf_distance_map(gt_csf).to(cmb_prob.device)

        weight = torch.exp(-dist_map / self.sigma)   # Bx1xHxW

        # background mask outside CMB
        bg_mask = (gt_cmb == 0).float().to(cmb_prob.device)

        # model prediction
        cmb_prob = torch.sigmoid(cmb_prob)
        # (optional) REMOVE hard threshold for better gradients:
        # cmb_prob = cmb_prob * (cmb_prob > 0.6).float()

        # resize prediction to match mask
        cmb_prob = F.interpolate(cmb_prob, size=bg_mask.shape[2:], mode='bilinear', align_corners=False)

        # false-positive "soft" map
        fp_soft = cmb_prob * bg_mask  # only region where GT is not CMB

        # distance-weighted loss
        loss_map = fp_soft * weight

        # final scalar loss (differentiable)
        # loss_distance = loss_map.mean()
        loss_distance = loss_map.sum() / torch.clamp((weight > 0).sum(), min=1.0)

        # loss_distance = loss_map.sum()
        return loss_distance

In [17]:
cmb_logits = torch.tensor([[[
                    [0.1,0.2,0.3,0.4,0.5,0.64,0.4],
                    [0.1,0.9,0.3,0.4,0.5,0.8,0.6],
                    [0.1,0.7,0.3,0.4,0.5,0.6,0.7],
                    [0.1,0.2,0.3,0.1,0.1,0.1,0.1],
                    [0.1,0.2,0.3,0.1,0.1,0.1,0.1],
                    [0.9,0.7,0.3,0.1,0.1,0.1,0.1],
                    [0.8,0.6,0.3,0.1,0.1,0.1,0.1]]]])
cmb_logits = cmb_logits.float()
# gt_cmb = torch.zeros(1, 1, 7, 7)
gt_cmb = torch.tensor([[[                
                    [0,0,0,0,0,0,0],
                    [0,0,0,0,0,0,0],
                    [1,0,0,0,0,0,0],
                    [1,1,0,0,0,0,0],
                    [1,0,0,0,0,0,0],
                    [0,0,0,0,0,0,0],
                    [0,0,0,0,0,0,0]]]])

gt_csf = torch.tensor([[[
                    [0,0,0,0,0,0,0],
                    [0,0,0,0,0,0,0],
                    [0,0,0,0,0,1,1],
                    [0,0,0,0,0,1,1],
                    [0,0,0,0,1,1,1],
                    [0,0,0,0,0,0,1],
                    [0,0,0,0,0,1,1]]]])
# gt_csf = torch.zeros(1, 1, 7, 7)

bce_loss_fn = nn.BCEWithLogitsLoss()
distance_loss_func = Loss_Distance(batched_inputs, "/media/Datacenter_storage/Ji/BiomedParse", sigma_param=10.0)

bce_loss = bce_loss_fn(cmb_logits, gt_cmb.float())
distance_loss = distance_loss_fn(cmb_logits, gt_cmb, gt_csf, spacing=None)

# Combine (lambda is a hyperparameter)
lambda_dist = 0.5
total_loss = bce_loss + lambda_dist * distance_loss
# print(total_loss)

# total_loss.backward()
# optimizer.step()

NameError: name 'batched_inputs' is not defined

### Implementation Note V1

In [None]:
import sys
import torch
import torch.nn as nn
import torch.functional as f
import numpy as np
import scipy.ndimage as nd
from scipy.ndimage import distance_transform_edt
from PIL import Image

In [None]:
def compute_csf_distance_map(csf_mask: torch.Tensor, spacing=None) -> torch.Tensor:
  device = csf_mask.device
  csf_np = csf_mask.detach().cpu().numpy().astype(np.float32)

  B = csf_np.shape[0]
  dist_maps = []

  for b in range(B):
    csf_slice = csf_np[b,0]
    dt = distance_transform_edt(1.0 - csf_slice, sampling=spacing)
    dist_maps.append(dt)

  dist_maps = np.stack(dist_maps, axis=0)
  dist_maps = torch.from_numpy(dist_maps).to(device=device, dtype=torch.float32)
  dist_maps = dist_maps.unsqueeze(1)
  print("=== CSF Maps ===")
  print(csf_mask)
  return dist_maps

In [None]:
csf_mask_path = "/media/Datacenter_storage/Ji/BiomedParse/biomedparse_datasets/loss_dev/train_mask/sub-207-slice-108_MRI_Brain_cerebrospinal+fluid.png"
cmb_mask_path = csf_mask_path.replace("cerebrospinal+fluid", "brain+microbleeds")
gt_csf = Image.open(csf_mask_path).convert('L')
gt_cmb = Image.open(cmb_mask_path).convert('L')
gt_csf = np.array(gt_csf)
gt_cmb = np.array(gt_cmb)

gt_csf = torch.from_numpy(gt_csf)
gt_cmb = torch.from_numpy(gt_cmb)

gt_csf = gt_csf.unsqueeze(0).unsqueeze(0).float()
gt_cmb = gt_cmb.unsqueeze(0).unsqueeze(0).float()

print(gt_csf.shape)
print(gt_cmb.shape)

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
dist_map = compute_csf_distance_map(gt_csf).to(device) # Distance Map for CSF
print("=== Distance Maps ===")
print(dist_map)
print()

sigma = 5.0
weight = torch.exp(-dist_map / sigma)  # Distance map based weight
print("=== Weight Maps ===")
print(weight)
print()

# bg_mask = bg_mask * (gt_csf!=1.0).float()
print("=== Background Mask (to filter out true CMB regions) ===")
bg_mask = (gt_cmb==0).float()


import torch.nn.functional as F
bg_mask = F.interpolate(bg_mask, size=(256, 256), mode='nearest')
print("bg_mask")
print(bg_mask)

print()



cmb_prob = torch.rand([1,101,256, 256])
cmb_prob = torch.sigmoid(cmb_prob)
cmb_prob = cmb_prob * (cmb_prob > 0.6).float()
fp_soft = cmb_prob * bg_mask # selecting only fp cases.

loss_map = fp_soft * weight
print("=== Loss Map ===")
print(loss_map)
print()
max_values = clustering(loss_map)

print("=== Max values for each cluster ===")
print(max_values)
print()

print("=== Sum of max values for all clusters: FINAL DISTANCE LOSS ===")
print(sum(max_values))

reduction = "sum"
if reduction == "mean":
    denom = (bg_mask * (weight > 0).float()).sum()
    denom = torch.clamp(denom, min=1.0)
    # return loss_map.sum() / denom
elif reduction == "sum":
    return sum(max_values)
# return loss_map.sum()
else:
    return loss_map


torch.Size([1, 1, 512, 512])
torch.Size([1, 1, 512, 512])
=== CSF Maps ===
tensor([[[[0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.],
          ...,
          [0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.]]]])
=== Distance Maps ===
tensor([[[[222.9125, 222.1463, 221.3820,  ..., 224.0893, 224.8488, 225.6103],
          [222.2724, 221.5040, 220.7374,  ..., 223.4390, 224.2008, 224.9644],
          [221.6348, 220.8642, 220.0954,  ..., 222.7914, 223.5554, 224.3212],
          ...,
          [187.4887, 186.6574, 185.8279,  ..., 180.6239, 181.4718, 182.2581],
          [188.0452, 187.2164, 186.3867,  ..., 181.1574, 182.0027, 182.8497],
          [188.6054, 187.7711, 186.9358,  ..., 181.6948, 182.5377, 183.3821]]]],
       device='cuda:0')

=== Weight Maps ===
tensor([[[[4.3457e-20, 5.0654e-20, 5.9020e-20,  ..., 3.4344e-20,
           2.9504e-20, 2.5336e-20],
 

RuntimeError: The size of tensor a (256) must match the size of tensor b (512) at non-singleton dimension 3