# Segment-wise attribution

In [1]:

def _gain_density(mask1, attr, mask2=None):
  """Return the mean of attr over mask1, optionally excluding mask2.

  Args:
      mask1: Boolean mask used to select values of attr.
      attr: Attribution array indexed by the selected mask.
      mask2: Optional mask; when given, only elements in mask1 that are not
        in mask2 are considered.

  Returns:
      Mean of the selected attr values, or -np.inf when nothing is selected.
  """
  region = mask1 if mask2 is None else _get_diff_mask(mask1, mask2)
  if np.any(region):
    return attr[region].mean()
  # Empty selection: report the worst possible density.
  return -np.inf


def _get_diff_mask(add_mask, base_mask):
  """Return the elements that are set in add_mask but not in base_mask."""
  outside_base = np.logical_not(base_mask)
  return np.logical_and(add_mask, outside_base)


def _get_diff_cnt(add_mask, base_mask):
  """Count the elements that are set in add_mask but not in base_mask."""
  newly_covered = _get_diff_mask(add_mask, base_mask)
  return np.sum(newly_covered)

def _unpack_segs_to_masks(segs):
  """Expand label arrays into a flat list of per-label boolean masks.

  For every array in segs, each integer from its minimum to its maximum
  (inclusive) produces the mask ``seg == label``. Labels inside that range
  that do not occur in the array therefore produce all-False masks.

  Args:
      segs: Iterable of integer label arrays.

  Returns:
      List of boolean masks, ordered by array and then by label value.
  """
  return [
      seg == label
      for seg in segs
      for label in range(seg.min(), seg.max() + 1)
  ]

def _normalize_image(im, value_range, resize_shape=None):
  """Normalize an image by rescaling its values and optionally resizing it.

  The values are first mapped linearly so that the image minimum becomes
  value_range[0] and the image maximum becomes value_range[1]. A constant
  image (maximum equal to minimum) has no dynamic range; it is mapped to
  value_range[0] everywhere instead of dividing by zero and producing NaN.

  Args:
      im: Input image.
      value_range: [min_value, max_value]
      resize_shape: New image shape. Defaults to None (no resizing).

  Returns:
      Rescaled and, if requested, resized image.
  """
  im_max = np.max(im)
  im_min = np.min(im)
  span = im_max - im_min
  if span == 0:
    # Constant image: avoid 0/0, place every pixel at the lower bound.
    im = np.zeros(np.shape(im), dtype=float)
  else:
    im = (im - im_min) / span
  im = im * (value_range[1] - value_range[0]) + value_range[0]
  if resize_shape is not None:
    im = resize(im,
                resize_shape,
                order=3,
                mode='constant',
                preserve_range=True,
                anti_aliasing=True)
  return im

In [2]:
def calculate_XRAI_attr(attr, segs, area_perc_th=1.0, min_pixel_diff=5):
    """Greedily aggregate per-pixel attributions into per-segment values.

    Starting from an empty covered region, each iteration picks the remaining
    mask whose not-yet-covered pixels have the highest mean attribution,
    writes that mean into those pixels of the output, and merges the mask
    into the covered region.

    Args:
        attr: Array of per-pixel attributions.
        segs: Sequence of boolean masks that are combined element-wise with a
            boolean array of shape ``attr.shape`` (presumably the output of
            ``_unpack_segs_to_masks`` -- confirm against the caller).
        area_perc_th: The loop continues while the covered fraction of the
            image is <= this value. With the default of 1.0 the threshold
            never stops the loop, so it runs until no usable mask is left.
        min_pixel_diff: Masks that would add fewer than this many new pixels
            to the covered region are discarded.

    Returns:
        Tuple ``(output_attr, list_of_masks_for_prediction,
        list_of_segment_avg_for_prediction)``:
        ``output_attr`` has the shape of ``attr`` and holds, per pixel, the
        mean attribution of the region it was added with; the second item
        lists the indices into ``segs`` in the order the masks were chosen;
        the third lists the corresponding mean attributions.
    """
    output_attr = np.full(attr.shape, -np.inf, dtype=float)
    list_of_masks_for_prediction = []
    list_of_segment_avg_for_prediction = []
    current_area_perc = 0.0
    current_mask = np.zeros(attr.shape, dtype=bool)
    remaining_masks = dict(enumerate(segs))

    while current_area_perc <= area_perc_th:
        best_gain = -np.inf
        best_key = None
        remove_key_queue = []
        for mask_key, mask in remaining_masks.items():
            # Masks that add too few new pixels are dropped for good; they are
            # queued because the dict cannot be mutated while iterating it.
            if _get_diff_cnt(mask, current_mask) < min_pixel_diff:
                remove_key_queue.append(mask_key)
                continue
            gain = _gain_density(mask, attr, mask2=current_mask)
            if gain > best_gain:
                best_gain = gain
                best_key = mask_key
        for key in remove_key_queue:
            del remaining_masks[key]
        # best_key stays None when no candidate beats -inf (e.g. the gain is
        # NaN); stop instead of failing on remaining_masks[None].
        if not remaining_masks or best_key is None:
            break

        added_mask = remaining_masks.pop(best_key)  # a mask is used only once
        mask_diff = _get_diff_mask(added_mask, current_mask)

        list_of_masks_for_prediction.append(best_key)
        list_of_segment_avg_for_prediction.append(best_gain)
        current_mask = np.logical_or(current_mask, added_mask)
        current_area_perc = np.mean(current_mask)
        output_attr[mask_diff] = best_gain

    # Pixels never covered by a selected mask still hold the -inf sentinel;
    # give them the mean attribution of that uncovered region.
    uncomputed_mask = output_attr == -np.inf
    output_attr[uncomputed_mask] = _gain_density(uncomputed_mask, attr)
    return output_attr, list_of_masks_for_prediction, list_of_segment_avg_for_prediction

In [1]:
def check_segmentwise_attributions_and_load(pixelwise_attributions, topk, top10list, pth, segs, redo=False):
    """Load cached segment-wise attributions, or compute and cache them.

    The cache lives at ``pth + '_swa.pkl'``. If that file exists and ``redo``
    is falsy, its contents are returned. Otherwise the attributions are
    computed with ``calculate_XRAI_attr`` for the first ``topk`` entries of
    ``top10list`` and written to the cache file.

    Args:
        pixelwise_attributions: Mapping indexed by the entries of
            ``top10list``; each value is passed to ``calculate_XRAI_attr``.
        topk: Number of leading entries of ``top10list`` to process.
        top10list: Indexable sequence of keys (presumably class indices --
            confirm against the caller).
        pth: Path prefix of the cache file.
        segs: Masks forwarded unchanged to ``calculate_XRAI_attr``.
        redo: When true, ignore an existing cache file and recompute.

    Returns:
        Tuple ``(segmentwise_attributions, list_of_masks,
        list_of_segment_avg)`` of dicts keyed by the processed entries of
        ``top10list``.
    """
    cache_path = pth + '_swa' + '.pkl'

    if os.path.exists(cache_path) and not redo:
        # NOTE(review): pickle.load can execute arbitrary code; the cache
        # file must come from a trusted location.
        with open(cache_path, 'rb') as file:
            (segmentwise_attributions, list_of_masks, list_of_segment_avg) = pickle.load(file)
        print('loaded segment wise attributions from file')
        return (segmentwise_attributions, list_of_masks, list_of_segment_avg)

    print('calculating segment wise attributions')
    segmentwise_attributions = {}
    list_of_masks = {}
    list_of_segment_avg = {}
    for i in range(topk):
        class_idx = top10list[i]
        print('for ' + str(class_idx))
        seg_attr, chosen_masks, segment_avgs = calculate_XRAI_attr(
            pixelwise_attributions[class_idx], segs)
        segmentwise_attributions[class_idx] = seg_attr
        list_of_masks[class_idx] = chosen_masks
        list_of_segment_avg[class_idx] = segment_avgs

    # The with-statement closes the file; no explicit close() is needed.
    with open(cache_path, 'wb') as file:
        pickle.dump((segmentwise_attributions, list_of_masks, list_of_segment_avg), file)
    return (segmentwise_attributions, list_of_masks, list_of_segment_avg)

        # Don't save them into a separate pkl file; we can save them along with the XRAI attributions and the segmentation.