In [None]:
# Enable IPython's autoreload extension in mode 2, so that edits to imported
# modules are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [None]:
import os
from os.path import join, exists
import numpy as np
import PIL
import matplotlib.pyplot as plt
from glob import glob
import cv2
from natsort import natsorted
# from tqdm import tqdm_notebook
import tqdm
import torch
import torchvision
import pandas as pd
import decord

import shared_utils as su

from transformers import SamModel
from finetuning_sam.lightning_models.sam import SAMLightningModule

from finetuning_sam.datasets.liquid_segmentation import load_dataset

In [None]:
# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

In [None]:
# Locate the fine-tuned checkpoint inside the W&B run directory.
wandb_run_id = "3y2x11qo"
wandb_run_dir = f"../../audio-visual-test-of-time/{wandb_run_id}"
ckpt_name = "epoch=19-step=8920.ckpt"
ckpt_path = join(wandb_run_dir, "checkpoints", ckpt_name)
assert exists(ckpt_path), f"Checkpoint not found: {ckpt_path}"

# Deserialize onto the CPU: without `map_location`, tensors that were saved
# from a GPU are restored to that GPU, which fails on a CPU-only machine.
# The module that receives these weights is moved to `device` afterwards.
ckpt = torch.load(ckpt_path, map_location="cpu")
ckpt.keys()

In [None]:
# Start from the pretrained SAM ViT-B weights
sam_model = SamModel.from_pretrained("facebook/sam-vit-base")

# Freeze the vision encoder and the prompt encoder: their parameters stop
# tracking gradients
for name, param in sam_model.named_parameters():
    if name.startswith(("vision_encoder", "prompt_encoder")):
        param.requires_grad_(False)

# Wrap the model in the Lightning module
module = SAMLightningModule(sam_model)

# Restore the fine-tuned weights from the checkpoint
module.load_state_dict(ckpt["state_dict"])

In [None]:
# Move the module to the target device and switch it to evaluation mode
module = module.to(device)
module.eval();

### Test on validation samples

In [None]:
# Load the validation split together with the processor used to prepare
# model inputs, then report the number of samples
ds, processor = load_dataset(
    "val",
    preload=True,
    return_processor=True,
)
len(ds)

In [None]:
def visualize_inference(image, prompt, threshold=0.5):
    """Runs box-prompted SAM inference on one image and visualizes the result.

    Relies on the notebook-level `processor`, `module` and `device` objects,
    which are looked up when the function is called.

    Args:
        image: Input image; it is passed to `processor` and to the `su.viz`
            drawing helpers.
        prompt: Box prompt for the model; it is also drawn on the input panel.
            NOTE(review): presumably `[x_min, y_min, x_max, y_max]` in pixels
            -- confirm against `su.viz.mask_to_bounding_box`.
        threshold: Probability above which a pixel is marked as foreground.
            Defaults to 0.5, the value that used to be hard-coded.

    Returns:
        Whatever `su.viz.concat_images` returns for the pair
        [image with box prompt, image with predicted mask].
    """

    # Input panel: the image with the box prompt drawn on it
    show_input = su.viz.add_bbox_on_image(
        image, prompt, color="yellow",
    )

    # Prepare image + box prompt for the model
    inputs = processor(image, input_boxes=[[prompt]], return_tensors="pt").to(device)

    # Forward pass (inference only, no gradient tracking)
    with torch.no_grad():
        outputs = module.sam_model(**inputs, multimask_output=False)

    # Logits -> probabilities
    seg_prob = torch.sigmoid(outputs.pred_masks.squeeze(1))

    # Soft mask -> hard binary mask
    seg_prob = seg_prob.cpu().numpy().squeeze()
    seg = (seg_prob > threshold).astype(np.uint8)

    # Output panel: the image with the predicted mask overlaid
    show_output = su.viz.add_mask_on_image(
        image,
        mask=su.viz.alpha_mask_to_pil_image(seg)
    )

    return su.viz.concat_images([show_input, show_output])

In [None]:
# Pick a reproducible set of distinct validation indices.
# The fixed seed keeps the figure stable across "Restart & Run All"; sampling
# without replacement avoids showing the same example twice. Change the seed
# to look at a different subset.
indices = np.random.default_rng(0).choice(
    len(ds), size=min(12, len(ds)), replace=False,
)

# Run inference for each index
outputs = []
for i in su.log.tqdm_iterator(indices, desc="Running inference"):

    # Load item
    item = ds.data[i]

    # The box prompt is derived from the cup mask of the sample
    image = item["image"]
    prompt = su.viz.mask_to_bounding_box(item["cup_mask"], perturbation=0)

    # Get output
    output = visualize_inference(image=image, prompt=prompt)
    outputs.append(output)

In [None]:
# Plot results
su.viz.show_grid_of_images(
    outputs, n_cols=4, figsize=(4 * 4, 2 * 3), subtitles=indices,
)

### Test on samples from `PouringIROS2019`

In [None]:
# Locations of the PouringIROS2019 videos and annotations
data_root = "/ssd/pbagad/datasets/"
data_dir = os.path.join(data_root, "PouringIROS2019")
video_dir = os.path.join(data_dir, "resized_data")
annot_dir = os.path.join(data_dir, "annotations")

# Bounding box annotations (torch-serialized file)
annot_path = os.path.join(annot_dir, "water_container_detections-v1.pt")
assert os.path.exists(annot_path)
annotations = torch.load(annot_path)

# Video metadata table
csv_path = os.path.join(data_dir, "metadata/all_liquids_in_transparent_containers.csv")
assert os.path.exists(csv_path)
df = pd.read_csv(csv_path)

# Attach the path of the video file to every row
row_id = "video_id"
def _get_video_path(row):
    """Builds `<video_dir>/<row[row_id]>.mp4` and asserts that the file exists."""
    path = os.path.join(video_dir, row[row_id] + ".mp4")
    assert os.path.exists(path), "video_path does not exist."
    return path
df["video_path"] = df.apply(_get_video_path, axis=1)
# NOTE(review): with assertions enabled, `_get_video_path` already fails on a
# missing file, so this filter cannot drop any row.
df = df[df["video_path"].apply(exists)]
df.shape

In [None]:
def load_frames_from_a_video(i, n_frames=12, imsize=256, crop=True, box_scale=1.0):
    """Samples evenly spaced frames from one video and builds box prompts.

    Reads the notebook-level `df` (row lookup) and `annotations` (crop boxes)
    as they are bound when the function is called.

    Args:
        i: Positional row index into `df`.
        n_frames: Number of frames to sample between the first and the last
            frame of the video (both included).
        imsize: Side length of the square frames that are returned.
        crop: If True and `annotations` has an entry for this video, every
            frame is cropped to that box before resizing.
        box_scale: Factor applied to the annotated box coordinates before
            cropping. The default of 1.0 uses the box as stored.

    Returns:
        A tuple `(frames, prompts)`: a list of `imsize` x `imsize` PIL images
        and, for each of them, the prompt `[0, 0, imsize, imsize]` that covers
        the whole frame.
    """
    row = df.iloc[i].to_dict()
    video_id = row["video_id"]
    video_path = row["video_path"]

    # Decode `n_frames` evenly spaced frames
    vr = decord.VideoReader(video_path)
    frame_indices = np.linspace(0, len(vr) - 1, n_frames, dtype=int)
    frames = [PIL.Image.fromarray(f) for f in vr.get_batch(frame_indices).asnumpy()]

    # Crop frames (only if annotations are available)
    if crop and (video_id in annotations):
        # Plain Python floats keep `PIL.Image.crop` independent of the
        # container/element type stored in the annotation file.
        box = [float(v) * box_scale for v in annotations[video_id]]
        frames = [f.crop(box) for f in frames]

    # Resize for SAM compatibility
    frames = [f.resize((imsize, imsize)) for f in frames]

    # Define prompts (entire size of the image since we already cropped)
    prompts = [[0, 0, imsize, imsize] for _ in frames]

    return frames, prompts

**Samples from the same video**

In [None]:
# Sample frames from the video at row 130 and preview them in a single row
frames, prompts = load_frames_from_a_video(130)
su.viz.show_grid_of_images(
    frames,
    n_cols=len(frames),
    figsize=(2 * len(frames), 2),
)

In [None]:
# Run inference for each index
outputs = []
for image, prompt in zip(frames, prompts):
    outputs.append(visualize_inference(image=image, prompt=prompt))

In [None]:
# Plot results
su.viz.show_grid_of_images(
    outputs, n_cols=4, figsize=(4 * 4, 2 * 3), 
)

Can we use our idea to get better liquid segmentations? For example, if the output is right at some time $t$, that should ideally translate into excellent outputs throughout the rest of the video.

### Test on YouTube videos

In [None]:
# Locations of the YouTube clips and annotations (reuses `data_root`)
data_dir = join(data_root, "Viscaural/v25")
video_dir = join(data_dir, "clips")
annot_dir = join(data_dir, "annotations")

# Clip metadata table
csv_path = join(
    data_dir,
    "splits/download_2023-05-21_11-20-59-sliding_predictions_nms_top2000-clean315-clean55.csv",
)
assert os.path.exists(csv_path)
df = pd.read_csv(csv_path)

# Detection annotations (torch-serialized file)
annot_path = join(
    data_dir,
    "annotations/water_glass_detections-first_frame-v1-0-311.pt",
)
assert os.path.exists(annot_path)
annotations = torch.load(annot_path)

# Attach the path of the video file to every row
row_id = "item_id"
def _get_video_path(row):
    """Builds `<video_dir>/<row[row_id]>.mp4` and asserts that the file exists."""
    path = os.path.join(video_dir, row[row_id] + ".mp4")
    assert os.path.exists(path), "video_path does not exist."
    return path
df["video_path"] = df.apply(_get_video_path, axis=1)
# NOTE(review): with assertions enabled, `_get_video_path` already fails on a
# missing file, so this filter cannot drop any row.
df = df[df["video_path"].apply(exists)]
df.shape

In [None]:
def load_frames_from_a_video(i, n_frames=12, imsize=256, crop=True, box_scale=0.5):
    """Samples evenly spaced frames from one clip and builds box prompts.

    Reads the notebook-level `df` (row lookup) and `annotations` (crop boxes)
    as they are bound when the function is called.

    Args:
        i: Positional row index into `df`.
        n_frames: Number of frames to sample between the first and the last
            frame of the clip (both included).
        imsize: Side length of the square frames that are returned.
        crop: If True and `annotations` has an entry for this clip, every
            frame is cropped to that box before resizing.
        box_scale: Factor applied to the annotated box coordinates before
            cropping. The default of 0.5 reproduces the halving that used to
            be hard-coded.
            NOTE(review): presumably the detections were computed at twice
            the resolution of the clips -- confirm.

    Returns:
        A tuple `(frames, prompts)`: a list of `imsize` x `imsize` PIL images
        and, for each of them, the prompt `[0, 0, imsize, imsize]` that covers
        the whole frame.
    """
    row = df.iloc[i].to_dict()
    video_id = row["item_id"]
    video_path = row["video_path"]

    # Decode `n_frames` evenly spaced frames
    vr = decord.VideoReader(video_path)
    frame_indices = np.linspace(0, len(vr) - 1, n_frames, dtype=int)
    frames = [PIL.Image.fromarray(f) for f in vr.get_batch(frame_indices).asnumpy()]

    # Crop frames (only if annotations are available)
    if crop and (video_id in annotations):
        # Plain Python floats keep `PIL.Image.crop` independent of the
        # container/element type stored in the annotation file.
        box = [float(v) * box_scale for v in annotations[video_id]]
        frames = [f.crop(box) for f in frames]

    # Resize for SAM compatibility
    frames = [f.resize((imsize, imsize)) for f in frames]

    # Define prompts (entire size of the image since we already cropped)
    prompts = [[0, 0, imsize, imsize] for _ in frames]

    return frames, prompts

**Samples from the same video**

In [None]:
# Sample frames from the clip at row 8 and preview them in a single row
frames, prompts = load_frames_from_a_video(8)
su.viz.show_grid_of_images(
    frames,
    n_cols=len(frames),
    figsize=(2 * len(frames), 2),
)

In [None]:
# Run inference for each index
outputs = []
for image, prompt in zip(frames, prompts):
    outputs.append(visualize_inference(image=image, prompt=prompt))

In [None]:
# Plot results
su.viz.show_grid_of_images(
    outputs, n_cols=4, figsize=(4 * 4, 2 * 3), 
)

### Legacy code

In [None]:
# Take the first validation example
i = 0
item = ds.data[i]

# Derive the box prompt from the cup mask and draw it on the image
image = item["image"]
prompt = su.viz.mask_to_bounding_box(item["cup_mask"], perturbation=0)
show_input = su.viz.add_bbox_on_image(image, prompt, color="yellow")
show_input

In [None]:
# prepare image + box prompt for the model
# Preprocess the image together with its box prompt and move the result to
# the target device
inputs = processor(
    image,
    input_boxes=[[prompt]],
    return_tensors="pt",
).to(device)

In [None]:
# forward pass
# Inference only: run the SAM forward pass without tracking gradients
with torch.no_grad():
    outputs = module.sam_model(
        **inputs,
        multimask_output=False,
    )

In [None]:
# apply sigmoid
seg_prob = torch.sigmoid(outputs.pred_masks.squeeze(1))

# convert soft mask to hard mask
seg_prob = seg_prob.cpu().numpy().squeeze()
seg = (seg_prob > 0.5).astype(np.uint8)

In [None]:
# Overlay the predicted mask on the image and show it next to the input panel
show_output = su.viz.add_mask_on_image(image, mask=su.viz.alpha_mask_to_pil_image(seg))
su.viz.concat_images(
    [show_input, show_output]
)