In [None]:
from collections import OrderedDict

import matplotlib.pyplot as plt
import numpy as np
import torch
from torch.nn import functional as F
from sam2.modeling.backbones.image_encoder import ImageEncoder, FpnNeck
from sam2.modeling.backbones.hieradet import Hiera
from sam2.modeling.memory_attention import MemoryAttention, MemoryAttentionLayer
from sam2.modeling.memory_encoder import MemoryEncoder, MaskDownSampler, Fuser, CXBlock
from sam2.modeling.position_encoding import PositionEmbeddingSine
from sam2.modeling.sam.transformer import RoPEAttention

from src.base import SAM2Base
from src.datasets.flare import FLAREDataset3D
from src.predictor import SAM2VideoPredictor
from src.utils import get_eig_from_probs, get_coords_of_tensor_max, get_center_by_erosion, get_labels_from_coords


# Pick the compute device. The second GPU is preferred when the host has one;
# otherwise fall back to the first GPU and finally to the CPU. Hard-coding
# "cuda:1" alone raises an invalid-device error on single-GPU machines.
if torch.cuda.is_available():
    DEVICE = torch.device("cuda:1" if torch.cuda.device_count() > 1 else "cuda:0")
else:
    DEVICE = torch.device("cpu")
# Location of the pretrained checkpoint file.
CHECKPOINT = "./model_checkpoints/sam2_hiera_large.pt"


# --- Image encoder: positional encoding -> FPN neck -> Hiera trunk ---

# Positional encoding handed to the FPN neck.
position_encoding = PositionEmbeddingSine(
    num_pos_feats=256,
    normalize=True,
    scale=None,
    temperature=10000,
)

# Neck that consumes the trunk's feature maps.
# NOTE(review): the channel list equals embed_dim 144 x (8, 4, 2, 1); presumably the
# per-stage widths of the trunk, deepest stage first - confirm against Hiera.
neck = FpnNeck(
    d_model=256,
    position_encoding=position_encoding,
    backbone_channel_list=[1152, 576, 288, 144],
    fpn_top_down_levels=[2, 3],
    fpn_interp_model="nearest",
)

# Hiera backbone configuration.
trunk = Hiera(
    embed_dim=144,
    num_heads=2,
    stages=[2, 6, 36, 4],
    window_spec=[8, 4, 16, 8],
    global_att_blocks=[23, 33, 43],
    window_pos_embed_bkg_spatial_size=[7, 7],
)

image_encoder = ImageEncoder(trunk=trunk, neck=neck, scalp=1)

# --- Memory attention: a stack of layers with RoPE self- and cross-attention ---

# Attention settings common to both attention modules.
_rope_common = dict(
    rope_theta=10000.0,
    feat_sizes=[32, 32],
    embedding_dim=256,
    num_heads=1,
    downsample_rate=1,
    dropout=0.1,
)

self_attention = RoPEAttention(**_rope_common)

# The cross-attention additionally repeats the rotary encoding for keys and reads
# 64-dimensional keys/values.
# NOTE(review): kv_in_dim=64 presumably matches the memory encoder's out_dim - confirm.
cross_attention = RoPEAttention(
    rope_k_repeat=True,
    kv_in_dim=64,
    **_rope_common,
)

layer = MemoryAttentionLayer(
    d_model=256,
    dim_feedforward=2048,
    activation="relu",
    dropout=0.1,
    self_attention=self_attention,
    cross_attention=cross_attention,
    pos_enc_at_attn=False,
    pos_enc_at_cross_attn_keys=True,
    pos_enc_at_cross_attn_queries=False,
)

memory_attention = MemoryAttention(
    d_model=256,
    num_layers=4,
    layer=layer,
    pos_enc_at_input=True,
)

# --- Memory encoder: mask down-sampler + fuser + its own positional encoding ---

# Positional encoding used by the memory encoder (64 features).
position_encoding_memory = PositionEmbeddingSine(
    num_pos_feats=64,
    temperature=10000,
    normalize=True,
    scale=None,
)

mask_downsampler = MaskDownSampler(
    kernel_size=3,
    stride=2,
    padding=1,
)

# Block that the fuser repeats `num_layers` times.
layer_fuser = CXBlock(
    dim=256,
    kernel_size=7,
    padding=3,
    use_dwconv=True,
    layer_scale_init_value=1e-6,
)
fuser = Fuser(layer=layer_fuser, num_layers=2)

memory_encoder = MemoryEncoder(
    out_dim=64,
    mask_downsampler=mask_downsampler,
    fuser=fuser,
    position_encoding=position_encoding_memory,
)

# Assemble the video predictor from the image encoder, memory attention and memory
# encoder constructed earlier in this cell.
predictor = SAM2VideoPredictor(image_encoder=image_encoder,
                               memory_attention=memory_attention,
                               memory_encoder=memory_encoder,
                               num_maskmem=7, image_size=1024,
                               # apply scaled sigmoid on mask logits for memory encoder,
                               # and directly feed input mask as output mask
                               sigmoid_scale_for_mem_enc=20.0,
                               sigmoid_bias_for_mem_enc=10.0,
                               use_mask_input_as_output_without_sam=True,
                               # Memory
                               directly_add_no_mem_embed=True,
                               # use high-resolution feature map in the SAM mask decoder
                               use_high_res_features_in_sam=True,
                               # 3 masks on the first click on initial conditioning frames
                               multimask_output_in_sam=True,
                               # SAM heads
                               iou_prediction_use_sigmoid=True,
                               # cross-attend to object pointers from other frames
                               # (based on SAM output tokens) in the encoder
                               use_obj_ptrs_in_encoder=True,
                               add_tpos_enc_to_obj_ptrs=False,
                               only_obj_ptrs_in_the_past_for_eval=True,
                               # object occlusion prediction
                               pred_obj_scores=True,
                               pred_obj_scores_mlp=True,
                               fixed_no_obj_ptr=True,
                               # multimask tracking settings
                               multimask_output_for_tracking=True,
                               use_multimask_token_for_obj_ptr=True,
                               multimask_min_pt_num=0, multimask_max_pt_num=1000,
                               use_mlp_for_obj_ptr_proj=True,
                               # Compilation flag
                               compile_image_encoder=False)

# Load the pretrained weights from the configured CHECKPOINT path instead of a second
# hard-coded copy of it. `weights_only=True` restricts unpickling to tensors and plain
# containers, so a tampered checkpoint file cannot run arbitrary code while loading.
sd = torch.load(CHECKPOINT, map_location='cpu', weights_only=True)['model']
# load_state_dict reports the keys that did not line up; print them for inspection.
missing_keys, unexpected_keys = predictor.load_state_dict(sd)
print(f"missing keys: {missing_keys}; unexpected keys: {unexpected_keys}")
# NOTE(review): the model is never switched to eval mode in this notebook. Unless
# SAM2VideoPredictor does so itself, the dropout=0.1 configured on the attention
# layers stays active during inference - consider `predictor.eval()` if deterministic
# predictions are intended.
predictor = predictor.to(DEVICE)


# Which dataset entry and which class the rest of the notebook inspects.
SAMPLE_ID = 0
CLASS_ID = 0

dataset = FLAREDataset3D(
    data_dir="data/FLARE",
    image_size=(125, 1024, 1024),
    warmstart=True,
)
sample = dataset[SAMPLE_ID]

# Sanity-check the tensors that come with the sample.
print(f"image: {sample['image'].shape}; min: {sample['image'].min()}; max: {sample['image'].max()}")
print(f"masks: {sample['masks'].shape}")
print(f"coords: {sample['point_coords'].shape}")
print(f"labels: {sample['point_labels'].shape}, {sample['point_labels']}")

In [None]:

# Turn the volume into a stack of 3-channel frames: drop the leading axis, swap the
# first two remaining axes and repeat the channel axis three times.
# NOTE(review): assumes sample['image'] is (1, 1, D, H, W) - confirm against FLAREDataset3D.
images = sample['image'].squeeze(0).permute(1, 0, 2, 3).repeat(1, 3, 1, 1)

# images = torch.flip(images, dims=[0])

num_frames, video_height, video_width = images.shape[0], images.shape[2], images.shape[3]

# Empty per-frame output containers, split into conditioning / non-conditioning frames.
empty_output_dict = {
    "cond_frame_outputs": {},
    "non_cond_frame_outputs": {},
}
empty_consolidated_inds = {
    "cond_frame_outputs": set(),
    "non_cond_frame_outputs": set()
}

# Inference state assembled by hand instead of through a predictor helper.
# NOTE(review): both offload flags are True, yet the frames and the storage device
# are placed on DEVICE - confirm which of the two is intended.
inference_state = {
    "images": images.to(DEVICE),
    "num_frames": num_frames,
    "offload_video_to_cpu": True,
    "offload_state_to_cpu": True,
    "video_height": video_height,
    "video_width": video_width,
    "device": DEVICE,
    "storage_device": DEVICE,
    "point_inputs_per_obj": {},
    "mask_inputs_per_obj": {},
    "cached_features": {},
    "constants": {},
    "obj_id_to_idx": OrderedDict(),
    "obj_idx_to_id": OrderedDict(),
    "obj_ids": [],
    "output_dict": empty_output_dict,
    "output_dict_per_obj": {},
    "temp_output_dict_per_obj": {},
    "consolidated_frame_inds": empty_consolidated_inds,
    "tracking_has_started": False,
    "frames_already_tracked": {}
}

predictor.reset_state(inference_state)

# Run the image-feature path once for frame 0 before any prompt is added.
predictor._get_image_feature(inference_state, frame_idx=0, batch_size=1)
# predictor.reset_state(inference_state)

# The object id used for prompting is the class under inspection.
ann_obj_id = CLASS_ID
# ann_frame_idx = int(sample['point_coords'][ann_obj_id, 0, 0].numpy())
# mask = sample['masks'][CLASS_ID, ann_frame_idx].numpy()
# print(f"ann_frame_idx: {ann_frame_idx}")

# points = sample['point_coords'][ann_obj_id, :, 1:].flip(-1).numpy()
# labels = sample['point_labels'][ann_obj_id].numpy()
# print(f"points: {points}; labels: {labels}")


# img = images[ann_frame_idx].permute(1, 2, 0).numpy()
# img_min, img_max = img.min(), img.max()
# img = (img - img_min) / (img_max - img_min)
# plt.imshow(img)
# plt.contour(mask, levels=[0.1], colors='b', linewidths=.5)
# plt.scatter(points[0, 0], points[0, 1], c="green" if labels[0] else "red")
# plt.show()

# Frame index of every prompt for the selected object (column 0 of `point_coords`).
indices = sample['point_coords'][ann_obj_id, :, 0].numpy().astype(int)
# print(f"indices: {indices}")

# One panel per prompted frame: the image, the reference mask outline and the prompt
# point. The number of columns follows the number of prompts instead of being fixed
# at three, so the cell keeps working when a sample carries more or fewer prompts.
plt.figure(figsize=(15, 5))
for i, idx in enumerate(indices):
    img = images[idx].permute(1, 2, 0).numpy()
    img_min, img_max = img.min(), img.max()
    # Min-max normalise for display; a constant frame is shown as all zeros rather
    # than dividing by zero.
    img_range = img_max - img_min
    img = (img - img_min) / img_range if img_range > 0 else np.zeros_like(img)
    plt.subplot(1, len(indices), i+1)
    plt.imshow(img)

    mask = sample['masks'][CLASS_ID, idx].numpy()
    plt.contour(mask, levels=[0.1], colors='b', linewidths=.5)

    # Columns 1: of `point_coords` are flipped so the point can be passed to scatter
    # as (x, y).
    point = sample['point_coords'][ann_obj_id, i:i+1, 1:].flip(-1).numpy()
    label = sample['point_labels'][ann_obj_id, i:i+1].numpy()
    # Green marks a truthy label, red a falsy one.
    plt.scatter(point[0, 0], point[0, 1], c="green" if label else "red", marker='+')

    plt.axis('off')
    plt.title(f"frame {idx}")
plt.show()


In [None]:
# EIG map per frame, keyed by frame index; filled here for the prompted frames.
eigs = dict()
# Register the prompts from the last one to the first. Every point/label has to be
# paired with the frame index stored in the *same* row of `point_coords`: enumerating
# the reversed `indices` array paired prompt i with the frame of prompt n-1-i, so the
# clicks were applied to the wrong slices.
for i in reversed(range(len(indices))):
    idx = int(indices[i])
    points = sample['point_coords'][ann_obj_id, i:i+1, 1:].flip(-1).numpy()
    labels = sample['point_labels'][ann_obj_id, i:i+1].numpy()
    _, out_obj_ids, out_mask_logits, multimask = predictor.add_new_points_or_box(
        inference_state=inference_state,
        frame_idx=idx,
        obj_id=ann_obj_id,
        points=points,
        labels=labels,
    )
    # EIG is computed from the sigmoid of the multimask output, on the CPU.
    eigs[idx] = get_eig_from_probs(torch.sigmoid(multimask).cpu())

    

# Per-frame segmentation results: frame index -> {object id: mask logits as numpy}.
video_segments = {}

# Propagate twice: first with reverse=True, then with the predictor's default
# direction. Results of the second pass overwrite frames visited by both.
for propagate_kwargs in ({"reverse": True}, {}):
    frame_results = predictor.propagate_in_video(inference_state, **propagate_kwargs)
    for out_frame_idx, out_obj_ids, out_mask_logits in frame_results:
        per_object = {}
        for obj_pos, out_obj_id in enumerate(out_obj_ids):
            per_object[out_obj_id] = (out_mask_logits[obj_pos]).cpu().numpy()
        video_segments[out_frame_idx] = per_object


# Add the EIG maps of every non-conditioning frame to the ones gathered while prompting.
non_cond_outputs = inference_state["output_dict"]["non_cond_frame_outputs"]
for frame_idx, frame_output in non_cond_outputs.items():
    multimask_probs = torch.sigmoid(frame_output["pred_multimasks"])
    eigs[frame_idx] = get_eig_from_probs(multimask_probs).cpu()

# Order the maps by frame index, stack them along a new axis 2 and drop the leading
# axis. From here on `eigs` is a tensor rather than a dict.
eigs = torch.stack([eigs[frame_idx] for frame_idx in sorted(eigs)], dim=2).squeeze(0)

In [None]:
# Gather the predicted mask logits of every frame: conditioning frames come from the
# per-object outputs, all other frames from the shared output dict.
# NOTE(review): `output_dict_per_obj` is indexed with ann_obj_id here; confirm it is
# keyed by object id rather than by an internal object index.
cond_outputs = inference_state['output_dict_per_obj'][ann_obj_id]["cond_frame_outputs"]
non_cond_outputs = inference_state["output_dict"]["non_cond_frame_outputs"]

mask_logits = {}
for frame_idx, frame_output in cond_outputs.items():
    mask_logits[frame_idx] = frame_output["pred_masks"].detach().cpu()
for frame_idx, frame_output in non_cond_outputs.items():
    mask_logits[frame_idx] = frame_output["pred_masks"].detach().cpu()

# Sigmoid per frame, concatenate in frame order along dim 1, then resize to 1024x1024.
mask_probs = torch.cat(
    [torch.sigmoid(mask_logits[frame_idx]) for frame_idx in sorted(mask_logits)],
    dim=1,
)
masks = F.interpolate(mask_probs, size=(1024, 1024), mode='bilinear', align_corners=False).numpy()[0]
print(masks.shape)

In [30]:

# _, out_obj_ids, out_mask_logits, multimask = predictor.add_new_points_or_box(
#     inference_state=inference_state,
#     frame_idx=ann_frame_idx,
#     obj_id=ann_obj_id,
#     points=points,
#     labels=labels,
# )
# eig = get_eig_from_probs(torch.sigmoid(multimask))

# plt.subplot(1, 2, 1)
# plt.imshow(torch.sigmoid(out_mask_logits)[0][0].detach().cpu().numpy(), cmap='gray', vmin=0, vmax=1)
# plt.subplot(1, 2, 2)
# plt.imshow(eig[0][0].detach().cpu().numpy())
# plt.show()

In [None]:
# Contact sheet of every predicted mask, 25 per row. The number of rows is derived
# from the number of masks (5 rows for 125 slices) so the grid no longer overflows
# when the volume has more slices than the previously hard-coded 5 x 25 layout.
grid_cols = 25
grid_rows = max(1, -(-len(masks) // grid_cols))  # ceiling division
plt.figure(figsize=(25, grid_rows))
for i, mask in enumerate(masks):
    plt.subplot(grid_rows, grid_cols, i + 1)
    plt.imshow(mask, cmap='gray', vmin=0, vmax=1)
    plt.axis('off')

plt.show()


In [None]:
# For every prompted slice, show it and its two neighbours on either side:
# one row with the predicted masks and, directly below, one row with the EIG maps.
n_neighbors = 2
n_cols = 2 * n_neighbors + 1
n_rows = 2 * len(indices)  # a mask row and an EIG row per prompt
# Slices outside the volume are skipped; previously a negative index silently wrapped
# around to the end of the volume and a too-large one raised an IndexError.
n_slices = min(len(masks), eigs.shape[1])

plt.figure(figsize=(12, 16))
for i, prompt_idx in enumerate(indices):
    # Prompt position as (x, y) for scatter: columns 2 and 1 of `point_coords`.
    prompt_x = sample["point_coords"][ann_obj_id, i, 2]
    prompt_y = sample["point_coords"][ann_obj_id, i, 1]

    for col, idx in enumerate(range(prompt_idx - n_neighbors, prompt_idx + n_neighbors + 1)):
        if not 0 <= idx < n_slices:
            continue

        panels = (
            (masks[idx], dict(cmap='gray', vmin=0, vmax=1)),
            # NOTE(review): 0.4771 is approximately log10(3); presumably the largest
            # value get_eig_from_probs can return for three masks - confirm.
            (eigs[0, idx].detach().cpu().numpy(), dict(vmin=0, vmax=.4771)),
        )
        for row_offset, (panel, imshow_kwargs) in enumerate(panels):
            plt.subplot(n_rows, n_cols, (2 * i + row_offset) * n_cols + col + 1)
            plt.imshow(panel, **imshow_kwargs)
            plt.axis('off')
            if idx == prompt_idx:
                plt.scatter(prompt_x, prompt_y, c="#FF00FF", marker='+')
                plt.title(f"Prompted Slice: {idx}")
            else:
                plt.title(f"Slice: {idx}")

In [None]:
# Location of the largest EIG value, unpacked as (slice, row, column).
max_eig_slice, coord_i, coord_j = get_coords_of_tensor_max(eigs[None]).detach().cpu().numpy()[0, 0]

# Show the 3 x 3 neighbourhood of slices around the maximum. Slices that fall outside
# the volume are skipped; previously a negative index silently wrapped around to the
# end of the volume and a too-large one raised an IndexError.
n_slices = min(eigs.shape[1], sample["masks"].shape[1])

plt.figure(figsize=(12, 12))
for idx in range(max_eig_slice-4, max_eig_slice+5):
    if not 0 <= idx < n_slices:
        continue
    # print(f"frame {idx}")
    plt.subplot(3, 3, idx - max_eig_slice+5)
    plt.imshow(eigs[0, idx].detach().cpu().numpy(), vmin=0, vmax=.4771)
    plt.axis('off')

    # Outline of the reference mask on top of the EIG map.
    mask = sample["masks"][CLASS_ID, idx].numpy()
    plt.contour(mask, levels=[0.1], colors='b', linewidths=.5)

    if idx == max_eig_slice:
        # scatter takes (x, y), i.e. (column, row).
        plt.scatter(coord_j, coord_i, c="#FF00FF", marker='+')
        plt.title(f"Max EIG Slice: {idx}")
    else:
        plt.title(f"Slice: {idx}")

In [3]:
import numpy as np
import matplotlib.pyplot as plt
import scipy as sp
# `import scipy` alone does not guarantee that the `stats` submodule is loaded on every
# SciPy version; importing it explicitly makes `sp.stats` reliable.
import scipy.stats

# --- Density of a single Beta(alpha, beta) distribution ---
alpha = .998
beta = .999
theta = np.linspace(0, 1, 100)
prob = sp.stats.beta.pdf(theta, alpha, beta)
# Plot against theta itself so the tick labels line up with the samples (plotting
# against the sample index put the label "1.0" one position past the last sample).
plt.plot(theta, prob)
plt.xticks(np.linspace(0, 1, 11).round(2))
# The original `plt.xlabel` was a bare attribute access and never set a label.
plt.xlabel("$\\theta$")
plt.show()

# --- Grid of (alpha, beta) pairs ---
# alpha grows from left to right, beta grows from bottom to top.
# NOTE: the grid starts at 0, where the Beta distribution is not defined, so the
# first column / last row may come out as NaN in the maps below.
res = 50
min_, max_ = 0, 5
values = np.linspace(min_, max_, res)
alphas, betas = np.meshgrid(values, values[::-1])

plt.subplot(1, 2, 1)
plt.imshow(alphas)
plt.xlabel("$\\alpha$")
plt.subplot(1, 2, 2)
plt.imshow(betas)
plt.ylabel("$\\beta$")
plt.show()

# Pixel k of the maps corresponds to values[k], so ticks span 0 .. res-1 (not res).
tick_positions = np.linspace(0, res - 1, 11)
alpha_tick_labels = np.linspace(int(min_), max_, 11).round(2)
beta_tick_labels = np.linspace(max_, int(min_), 11).round(2)

# --- Differential entropy of Beta(alpha, beta) over the grid ---
entropy = sp.stats.beta.entropy(alphas, betas)

plt.imshow(entropy)
plt.xlabel("$\\alpha$")
plt.ylabel("$\\beta$")
plt.xticks(tick_positions, alpha_tick_labels)
plt.yticks(tick_positions, beta_tick_labels)
plt.colorbar()
plt.show()

# --- Expected value alpha / (alpha + beta) over the grid ---
# The (0, 0) corner is 0/0; silence the warning and let it show up as NaN.
with np.errstate(invalid="ignore", divide="ignore"):
    expected_value = alphas / (alphas + betas)

# plt.imshow(abs(alphas/(alphas+betas)-.5))
plt.imshow(expected_value)
plt.colorbar()
# plt.imshow(alphas == betas, cmap='Reds', alpha=.1)
# Anti-diagonal of the image: the cells where alpha equals beta.
plt.plot([0, res-1], [res-1, 0], 'r-', alpha=.3, label="$\\mathbb{E}(\\theta)=0.5$")
plt.xlabel("$\\alpha$")
plt.ylabel("$\\beta$")
plt.xticks(tick_positions, alpha_tick_labels)
plt.yticks(tick_positions, beta_tick_labels)
plt.title("Expected Value")
plt.legend()
plt.show()

In [None]:
# Display the ratio alpha / (alpha + beta) for every cell of the grid.
np.divide(alphas, alphas + betas)

In [None]:
# Display the entropy grid computed by sp.stats.beta.entropy(alphas, betas).
entropy