In [1]:
from pathlib import Path

import torch

from dynamic_fusion.network_trainer.configuration import TrainerConfiguration
from dynamic_fusion.network_trainer.network_loader import NetworkLoader

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
import pandas as pd

from dynamic_fusion.scripts.test_e2vid_data import get_events_from_txt

# Folder of the sequence under evaluation; later cells read images.txt and the
# image files relative to this directory.
directory = Path("../data/raw/event_camera_dataset/dynamic_6dof")

# get_events_from_txt returns three values; only the first (the events) is kept.
events_file = directory / "events.txt"
events, _, _ = get_events_from_txt(events_file, first_row_is_image_shape=False, max_t=5)


In [3]:
import numpy as np

# images.txt is read as a space-separated, header-less table with two columns:
# a float64 timestamp and the image path as a string.
image_metadata = pd.read_csv(
    directory / "images.txt",
    sep=" ",
    header=None,
    names=["timestamp", "path"],
    dtype={"timestamp": np.float64, "path": str},
)

# Use consecutive image timestamps as left-closed bin edges and tag every event
# with the index of the interval it falls in (NaN when it falls in none).
frame_edges = image_metadata.timestamp
events["frame_bin"] = pd.cut(events.timestamp, frame_edges, labels=False, right=False)

In [4]:
from dynamic_fusion.data_generator.configuration import EventDiscretizerConfiguration
from dynamic_fusion.data_generator.event_discretizer import EventDiscretizer

# Threshold passed to the discretizer for every frame.
THRESHOLD = 1

# Named `discretizer_config` rather than `config`: the notebook later binds
# `config` to the trainer configuration, and sharing one name for two different
# objects makes the result depend on the order in which cells are (re-)run.
discretizer_config = EventDiscretizerConfiguration(number_of_temporal_bins=1, number_of_temporal_sub_bins_per_bin=2)

# Event timestamps are rescaled to [0, 1) per frame before discretization,
# hence max_timestamp=1.
discretizer = EventDiscretizer(discretizer_config, max_timestamp=1.)

In [5]:
from dynamic_fusion.utils.discretized_events import DiscretizedEvents


# One discretized event volume per inter-image interval that contains events.
discretized_frames = []

for frame_bin, events_in_frame in events.groupby("frame_bin"):
    # pd.cut yields float labels when some events fall outside every bin (NaN),
    # so make the index an explicit integer before using it as a row label.
    frame_index = int(frame_bin)
    frame_start = image_metadata.timestamp[frame_index]
    frame_end = image_metadata.timestamp[frame_index + 1]

    # Sanity check: every event of the group lies in [frame_start, frame_end).
    in_window = (events_in_frame.timestamp >= frame_start) & (events_in_frame.timestamp < frame_end)
    assert np.all(in_window), f"events outside [{frame_start}, {frame_end}) in frame bin {frame_index}"

    # Rescale timestamps to [0, 1) within the frame. Build a new frame instead of
    # mutating the groupby slice in place, which avoids SettingWithCopyWarning and
    # any dependence on whether the group is a view or a copy.
    normalized_events = events_in_frame.assign(
        timestamp=(events_in_frame.timestamp - frame_start) / (frame_end - frame_start)
    )

    # (180, 240) is the image shape handed to the discretizer.
    discretized_frame = discretizer._discretize_events(normalized_events, THRESHOLD, (180, 240))
    discretized_frames.append(discretized_frame)

In [7]:
# Model to evaluate; must be one of the keys of CHECKPOINT_DIRS.
MODEL = "e2vid_exp_uncertainty"

# Run directory for each supported model.
CHECKPOINT_DIRS = {
    "e2vid_exp": Path("../runs/0323-new-dataset/01_st-un_st-interp_st-up/subrun_00"),
    "e2vid_exp_uncertainty": Path("../runs/0323-new-dataset/00_st-un_st-interp_st-up_uncertainty-lpips/subrun_00"),
}
if MODEL not in CHECKPOINT_DIRS:
    # Fail here with a clear message instead of a NameError further down.
    raise ValueError(f"Unknown MODEL {MODEL!r}; expected one of {sorted(CHECKPOINT_DIRS)}")
CHECKPOINT_DIR = CHECKPOINT_DIRS[MODEL]

# Prefer the GPU but keep the notebook runnable on CPU-only machines.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

CHECKPOINT_NAME = "latest_checkpoint.pt"


config_path = CHECKPOINT_DIR / "config.json"
json_config = config_path.read_text(encoding="utf8")
# Parse the JSON string back into a Configuration instance
config = TrainerConfiguration.parse_raw(json_config)
# Load network: both encoder and decoder weights come from the same checkpoint file.
config.network_loader.decoding_checkpoint_path = CHECKPOINT_DIR / CHECKPOINT_NAME
config.network_loader.encoding_checkpoint_path = CHECKPOINT_DIR / CHECKPOINT_NAME
encoder, decoder = NetworkLoader(config.network_loader, config.shared).run()
encoder = encoder.to(device)
decoder = decoder.to(device)

shared=SharedConfiguration(sequence_length=16, resume=True, use_events=True, use_mean=True, use_std=True, use_count=True, implicit=True, spatial_unfolding=True, temporal_unfolding=True, temporal_interpolation=True, spatial_upscaling=True, predict_uncertainty=True, use_aps_for_all_frames=False, use_initial_aps_frame=False, min_allowed_max_of_mean_polarities_over_times=0.05) data_handler=DataHandlerConfiguration(augmentation=AugmentationConfiguration(network_image_size=(56, 56)), dataset=DatasetConfiguration(dataset_directory=PosixPath('/mnt/train/2subbins'), threshold=1.35, augmentation_tries=2, video_tries=5, max_upscaling=4.0), test_dataset_directory=PosixPath('/mnt/test/2subbins'), test_scale_range=(1, 4), batch_size=2, num_workers=1) network_loader=NetworkLoaderConfiguration(encoding=EncodingNetworkConfiguration(input_size=2, hidden_size=24, output_size=16, kernel_size=3), encoding_checkpoint_path=None, decoding=DecodingNetworkConfiguration(hidden_size=128, hidden_layers=4), decodin

In [9]:
from dynamic_fusion.scripts.test_e2vid_data import run_reconstruction
from dynamic_fusion.utils.discretized_events import DiscretizedEvents


# Join the per-frame discretized events along time and run the loaded networks.
discretized_events = DiscretizedEvents.stack_temporally(discretized_frames)
reconstruction = run_reconstruction(encoder, decoder, discretized_events, device, config.shared)

# Per-frame minimum and maximum of channel 0, kept broadcastable against a frame.
# NOTE(review): channel 0 is presumably the intensity prediction and channel 1
# the uncertainty - confirm against run_reconstruction.
mins = reconstruction.min(axis=(2,3), keepdims=True)[:,0]
maxs = reconstruction.max(axis=(2,3), keepdims=True)[:,0]

# A constant frame has max == min; divide by 1 there so it maps to 0 rather
# than producing NaN/inf.
value_range = maxs - mins
value_range[value_range == 0] = 1

# Min-max normalize channel 0 of a copy; `reconstruction` itself stays untouched.
reconstruction_norm = reconstruction.copy()

reconstruction_norm[:,0] = (reconstruction_norm[:,0] - mins) / value_range


107 / 108

In [10]:
# Display the shape of the reconstruction array as a quick sanity check.
# NOTE(review): presumably (frames, channels, height, width) - confirm against run_reconstruction.
reconstruction.shape

(106, 2, 180, 240)

In [19]:
# Per-frame minimum and maximum of channel 0, kept broadcastable against a frame.
mins = reconstruction.min(axis=(2,3), keepdims=True)[:,0]
maxs = reconstruction.max(axis=(2,3), keepdims=True)[:,0]

# A constant frame has max == min; divide by 1 there so it maps to 0 rather
# than producing NaN/inf.
value_range = maxs - mins
value_range[value_range == 0] = 1

# Min-max normalize channel 0 of a copy so the cell can be re-run safely.
reconstruction_norm = reconstruction.copy()

reconstruction_norm[:,0] = (reconstruction_norm[:,0] - mins) / value_range

In [None]:
import cv2
from skimage.color import rgb2gray
import numpy as np
from skimage.metrics import structural_similarity as ssim
from dynamic_fusion.utils.loss import LPIPS
from tqdm import tqdm
from matplotlib import pyplot as plt

# Load every image listed in images.txt as a float grayscale frame.
images = []
for path in image_metadata.path:
    image_path = directory / path
    image = cv2.imread(str(image_path))
    if image is None:
        # cv2.imread signals failure by returning None instead of raising.
        raise FileNotFoundError(f"Could not read image: {image_path}")
    # OpenCV loads channels as BGR while rgb2gray weights them as RGB, so
    # reorder first to weight the right channels.
    gray = rgb2gray(cv2.cvtColor(image, cv2.COLOR_BGR2RGB))
    images.append(gray)

images_np = np.stack(images, axis=0)

# NOTE(review): reconstruction frame i is compared with image i. This assumes
# every frame bin contained events and that bin i lines up with image i -
# confirm the intended alignment (image i vs. i + 1).
ssim_vals = [ssim(reconstruction_norm[i,0], images[i], data_range=1) for i in range(len(reconstruction_norm))]
print(sum(ssim_vals) / len(ssim_vals))

fig, ax = plt.subplots()
ax.plot(ssim_vals)
ax.set(title="SSIM per frame", xlabel="frame", ylabel="SSIM")
plt.show()


lpips_vals = []
lpips = LPIPS().to(device)

# Evaluation only: skip autograd bookkeeping to save memory and time.
with torch.no_grad():
    for i in tqdm(range(len(reconstruction_norm))):
        # Both tensors get leading batch and channel axes of size 1.
        recon_tensor = torch.tensor(reconstruction_norm[i, 0:1][None]).to(device).float()
        image_tensor = torch.tensor(images[i][None, None]).to(device).float()
        lpips_vals.append(lpips(recon_tensor, image_tensor).item())

fig, ax = plt.subplots()
ax.plot(lpips_vals)
ax.set(title="LPIPS per frame", xlabel="frame", ylabel="LPIPS")
plt.show()

In [24]:
from dynamic_fusion.utils.network import to_numpy
from dynamic_fusion.utils.plotting import discretized_events_to_cv2_frame, add_text_at_row
from dynamic_fusion.utils.visualization import create_red_blue_cmap, img_to_colormap

# Maximum number of frames written to the video.
FRAMES = 40
output_dir = Path('../results/event_camera_dataset_test/')
output_dir.mkdir(parents=True,exist_ok=True)
# Playback speed relative to real time (0.5 = half speed).
SPEED = 0.5
size = discretized_events.event_polarity_sum.shape

# Frames per second of the recording scaled by SPEED; never below 1 because a
# frame rate truncated to 0 is not a valid video frame rate.
fps = max(1, int(len(discretized_events.event_polarity_sum)/events.timestamp.max()*SPEED))

# Never index past the data that is actually available.
n_frames = min(FRAMES, len(discretized_events.event_polarity_sum), len(reconstruction), len(images), len(lpips_vals), len(ssim_vals))

colored_event_polarity_sums = img_to_colormap(to_numpy(discretized_events.event_polarity_sum.sum(dim=1)), create_red_blue_cmap(501))

# Three panels side by side (events | reconstruction | ground truth).
out = cv2.VideoWriter(f"{str(output_dir)}/{directory.name}.mp4", cv2.VideoWriter.fourcc(*"mp4v"), fps, (size[-1]*3, size[-2]))
try:
    for frame_idx in range(n_frames):
        event_frame = discretized_events_to_cv2_frame(colored_event_polarity_sums[frame_idx], discretized_events.event_count[frame_idx])
        # Clip to [0, 1] before the uint8 cast; out-of-range values would
        # otherwise wrap around and show up as speckle artifacts.
        # NOTE(review): the metrics were computed on reconstruction_norm while
        # the raw reconstruction is shown here - confirm that is intended.
        recon_gray = np.clip(reconstruction[frame_idx, 0], 0, 1)
        recon_frame = cv2.cvtColor((recon_gray*255).astype(np.uint8), cv2.COLOR_GRAY2RGB)
        gt_frame = cv2.cvtColor((images[frame_idx]*255).astype(np.uint8), cv2.COLOR_GRAY2RGB)

        add_text_at_row(recon_frame, f"LPIPS={lpips_vals[frame_idx]:.2f}", 0)
        add_text_at_row(recon_frame, f"SSIM={ssim_vals[frame_idx]:.2f}", 1)

        frame = np.concatenate([event_frame, recon_frame, gt_frame], axis=1)
        out.write(frame)
finally:
    # Always finalize the file, even if a frame fails to render.
    out.release()