In [None]:
import sys
import torch
import logging
import numpy as np
from pathlib import Path
import matplotlib.pyplot as plt
from matplotlib import colormaps as cm

from f3 import init_event_model, load_weights_ckpt
from f3.utils import (setup_torch, ev_to_frames, plot_patched_features,
                      smooth_time_weighted_rgb_encoding, BaseExtractor)
from f3.tasks.depth.utils import init_depth_model, load_depth_weights, get_disparity_image
from f3.tasks.optical_flow.utils import init_flow_model, load_flow_weights, flow_viz_np
from f3.tasks.segmentation.utils import init_segmentation_model, load_segmentation_weights, cityscapes_palette

# Route INFO-level log records to the notebook's stdout, prefixed with a timestamp.
# force=True (Python 3.8+) first removes any handlers already attached to the root logger.
# Without it, basicConfig is a silent no-op when the kernel or an imported package has
# already configured logging, and the logger.info(...) lines below may never show up.
logging.basicConfig(stream=sys.stdout, level=logging.INFO, format="%(asctime)s - %(message)s", force=True)
logger = logging.getLogger(__name__)

# Global torch setup from f3.utils, called with cuDNN benchmarking turned off.
# NOTE(review): presumably maps to torch.backends.cudnn.benchmark = False — confirm in f3.utils.setup_torch
setup_torch(cudnn_benchmark=False)

### Download Pretrained Models and a Minimal M3ED Sequence

In [None]:
# Download pretrained F3 model weights
!bash scripts/download/download_models.sh pretrained_models/f3 f3
!bash scripts/download/download_models.sh pretrained_models/seg seg
!bash scripts/download/download_models.sh pretrained_models/depth depth
!bash scripts/download/download_models.sh pretrained_models/flow flow

# Download a small M3ED sequence and generate the timestamps file
!python3 scripts/download/download_m3ed.py --vehicle car\
                                           --environment urban\
                                           --sequence car_urban_day_penno_small_loop\
                                           --output_dir data\
                                           --to_download data
!python3 scripts/generate_ts.py --data_h5 data/car_urban_day_penno_small_loop/car_urban_day_penno_small_loop_data.h5

In [None]:
# ---- Checkpoint locations (must match where the download cell writes) ----
MODELS_ROOT = "pretrained_models"
BASE_F3_MODEL_PATH = f"{MODELS_ROOT}/f3/"
BASE_SEG_MODEL_PATH = f"{MODELS_ROOT}/seg/"
BASE_DEPTH_MODEL_PATH = f"{MODELS_ROOT}/depth/"
BASE_FLOW_MODEL_PATH = f"{MODELS_ROOT}/flow/"

# ---- Which checkpoints / run directories to load ----
F3_MODEL_NAME = "patchff_fullcardaym3ed_small_20ms.pth"
F3_CONFIG_NAME = "confs/ff/modeloptions/1280x720x20_patchff_ds1_small.yml"  # path to the F3 model-options YAML
SEG_MODEL_NAME = "segformer_b3_fullm3ed_800x600x20"
DEPTH_MODEL_NAME = "dav2b_fullm3ed_pseudo_518x518x20"
FLOW_MODEL_NAME = "optflow_trainm3ed_20msff_pyr5_28k"

# ---- M3ED sequence inputs: the HDF5 recording and its timestamp index ----
M3ED_SEQUENCE = "car_urban_day_penno_small_loop"
M3ED_SEQUENCE_DIR = f"data/{M3ED_SEQUENCE}"
H5PATH = f"{M3ED_SEQUENCE_DIR}/{M3ED_SEQUENCE}_data.h5"
TSPATH = f"{M3ED_SEQUENCE_DIR}/50khz_{M3ED_SEQUENCE}_data.npy"

In [None]:
# Event-context extractor over the sequence's left event camera (1280x720 sensor),
# with context randomization disabled so the same t0 always yields the same window.
# NOTE(review): time values look like microseconds (20_000 -> 20 ms, matching the
# "20ms" F3 checkpoint) — confirm against BaseExtractor.
extractor = BaseExtractor(
    H5PATH,
    TSPATH,
    w=1280,
    h=720,
    time_ctx=20_000,
    time_pred=20_000,
    bucket=1_000,
    max_numevents_ctx=800_000,
    randomize_ctx=False,
    camera="left",
)

## Initializing and Loading Pretrained Models

### Loading a pretrained F<sup>3</sup>

In [None]:
# Build the F3 event model (configured to return both features and logits), move it to
# the GPU, and wrap it with torch.compile on the Inductor backend. Compilation is lazy,
# so the first forward pass pays the compile/autotune cost.
inductor_options = {"epilogue_fusion": True, "max_autotune": True}
eventff_model = torch.compile(
    init_event_model(F3_CONFIG_NAME, return_feat=True, return_logits=True).cuda(),
    backend="inductor",
    fullgraph=False,
    options=inductor_options,
)

# Weights are loaded into the already-compiled wrapper.
# NOTE(review): presumably load_weights_ckpt copes with the compiled module's key prefix — confirm.
epoch, loss, acc = load_weights_ckpt(eventff_model, Path(BASE_F3_MODEL_PATH, F3_MODEL_NAME))
eventff_model.eval()
logger.info(f"Loaded F3 model ckpt from {F3_MODEL_NAME} at epoch {epoch} with loss {loss} and acc {acc}")

### Loading a pretrained F<sup>3</sup> - SegFormer B3

In [None]:
# Segmentation model: config and best-mIoU checkpoint both live in the same run directory.
seg_run_dir = Path(BASE_SEG_MODEL_PATH) / SEG_MODEL_NAME
seg_model = init_segmentation_model(seg_run_dir / "segmentation_config.yml").cuda()
# Only the `eventff` submodule is compiled.
seg_model.eventff = torch.compile(seg_model.eventff, fullgraph=False)
epoch, loss, acc, miou = load_segmentation_weights(seg_model, seg_run_dir / "best_miou.pth")
# Switch to eval mode *before* logging. Module.eval() returns the module itself, so leaving
# it as the cell's last expression made Jupyter print the entire model repr.
seg_model.eval()
logger.info(f"Loaded Segmentation ckpt from {SEG_MODEL_NAME} at "
            f"Epoch: {epoch}, Loss: {loss}, Acc: {acc}, MIoU: {miou}")

### Loading a pretrained F<sup>3</sup> - DepthAnything V2 Base

In [None]:
# Monocular depth model: config and best checkpoint both live in the same run directory.
depth_run_dir = Path(BASE_DEPTH_MODEL_PATH) / DEPTH_MODEL_NAME
depth_model = init_depth_model(depth_run_dir / "depth_config.yml").cuda()
# Only the `eventff` submodule is compiled.
depth_model.eventff = torch.compile(depth_model.eventff, fullgraph=False)
epoch, results = load_depth_weights(depth_model, depth_run_dir / "best.pth")
# Switch to eval mode *before* logging. Module.eval() returns the module itself, so leaving
# it as the cell's last expression made Jupyter print the entire model repr.
depth_model.eval()
logger.info(f"Loaded Monocular Depth ckpt from {DEPTH_MODEL_NAME} at "
            f"Epoch: {epoch}, Results: {results}")

### Loading a pretrained F<sup>3</sup> - Flow Model

In [None]:
# Optical-flow model: config and last checkpoint both live in the same run directory.
flow_run_dir = Path(BASE_FLOW_MODEL_PATH) / FLOW_MODEL_NAME
optflow_model = init_flow_model(flow_run_dir / "flow_config.yaml").cuda()
# The `eventff` and `flowhead` submodules are compiled separately.
optflow_model.eventff = torch.compile(optflow_model.eventff, fullgraph=False)
optflow_model.flowhead = torch.compile(optflow_model.flowhead, fullgraph=False)
epoch, loss = load_flow_weights(optflow_model, flow_run_dir / "last.pth")
# Switch to eval mode *before* logging. Module.eval() returns the module itself, so leaving
# it as the cell's last expression made Jupyter print the entire model repr.
optflow_model.eval()
logger.info(f"Loaded Optical Flow ckpt from {FLOW_MODEL_NAME} at "
            f"Epoch: {epoch}, Loss: {loss}")

### Run all models for inference

In [None]:
IMG_IDX = 700

# Timestamp and RGB image of the selected OVC frame.
t0 = extractor.hdf5_file["ovc/ts"][IMG_IDX]
img = extractor.hdf5_file["ovc/rgb/data"][IMG_IDX]

# Fixed-duration event context selected by t0, moved to the GPU; the event count is
# wrapped into a one-element tensor to match the models' (ctx, count) inputs.
ctx, totcnt = extractor.get_ctx_fixedtime(t0)
ctx = ctx.cuda()
totcnt = torch.tensor([totcnt]).cuda()

# Rasterised event image for display.
# NOTE(review): the .T presumes ev_to_frames yields (W, H) frames — confirm.
events_frame = ev_to_frames(ctx, totcnt, 1280, 720)[0].cpu().numpy().T

cmap = cm["magma"]

# Every model consumes the same event context; no autograd bookkeeping is needed.
with torch.no_grad():
    logits, feat = eventff_model(ctx, totcnt)

    seg_pred, _ = seg_model(ctx, totcnt)
    seg_pred = seg_pred.argmax(1).cpu().numpy()  # per-pixel class ids (argmax over dim 1)

    depth_pred, _ = depth_model.infer_image(ctx, totcnt)

    flow_pred, _ = optflow_model(ctx, totcnt)
    flow_pred = flow_pred.permute(0, 2, 3, 1).cpu().numpy()  # move dim 1 to the last axis

# Turn raw outputs into displayable images.
pca, _ = plot_patched_features(feat[0], plot=False)
logits_rgb = smooth_time_weighted_rgb_encoding((torch.sigmoid(logits) > 0.5).cpu().numpy())[0]
seg_img = cityscapes_palette(seg_model.num_labels)[seg_pred[0]].astype(np.uint8)
depth_img = get_disparity_image(depth_pred, torch.ones_like(depth_pred, dtype=torch.bool), cmap)
# Flow colours are kept only where the event frame equals 255.
flow_img = flow_viz_np(flow_pred[0], norm=True) * (events_frame == 255)[..., None]

In [None]:
# Add optical flow to the visualization
fig, axes = plt.subplots(2, 4, figsize=(24, 8))

# Original RGB image
axes[0, 0].imshow(img[..., ::-1])
axes[0, 0].set_title('Original RGB Image', fontsize=20, fontweight='bold')
axes[0, 0].axis('off')

# Events frame
axes[0, 1].imshow(events_frame, cmap='hot')
axes[0, 1].set_title('Events Frame', fontsize=20, fontweight='bold')
axes[0, 1].axis('off')

# PCA features
axes[0, 2].imshow(pca.transpose(1, 0, 2))
axes[0, 2].set_title('PCA Features', fontsize=20, fontweight='bold')
axes[0, 2].axis('off')

# Logits RGB
axes[0, 3].imshow(logits_rgb.transpose(1, 0, 2))
axes[0, 3].set_title('Logits RGB', fontsize=20, fontweight='bold')
axes[0, 3].axis('off')

# Segmentation prediction
axes[1, 0].imshow(seg_img[..., ::-1])
axes[1, 0].set_title('Segmentation Prediction', fontsize=20, fontweight='bold')
axes[1, 0].axis('off')

# Depth image
axes[1, 1].imshow(depth_img)
axes[1, 1].set_title('Depth Prediction', fontsize=20, fontweight='bold')
axes[1, 1].axis('off')

# Optical flow
axes[1, 2].imshow(flow_img[..., ::-1])
axes[1, 2].set_title('Optical Flow', fontsize=20, fontweight='bold')
axes[1, 2].axis('off')

# Remove empty subplot
axes[1, 3].axis('off')

plt.tight_layout()
plt.show()