In [18]:
from google.colab import drive
drive.mount('/content/drive')

project_dir = '/content/drive/My Drive/Colab Notebooks/matsuoo/dl/event_camera_repo'
%cd {project_dir}

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).
/content/drive/My Drive/Colab Notebooks/matsuoo/dl/event_camera_repo


In [2]:
!pip install hydra-core omegaconf hdf5plugin h5py numba imageio imageio-ffmpeg tqdm torchvision --quiet

[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m154.5/154.5 kB[0m [31m3.5 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m79.5/79.5 kB[0m [31m8.4 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m41.8/41.8 MB[0m [31m9.8 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m117.0/117.0 kB[0m [31m9.3 MB/s[0m eta [36m0:00:00[0m
[?25h  Preparing metadata (setup.py) ... [?25l[?25hdone
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m21.3/21.3 MB[0m [31m63.5 MB/s[0m eta [36m0:00:00[0m
[?25h  Building wheel for antlr4-python3-runtime (setup.py) ... [?25l[?25hdone


In [3]:
import torch
import hydra
from omegaconf import DictConfig
from torch.utils.data import DataLoader
import random
import numpy as np
from enum import Enum, auto
from tqdm import tqdm
from pathlib import Path
from typing import Dict, Any
import os
import time

import math
from pathlib import PurePath
from typing import Tuple
import cv2
import hdf5plugin
import h5py
from numba import jit
import imageio
imageio.plugins.freeimage.download()
import imageio.v3 as iio
from torchvision.transforms import RandomCrop
from torchvision import transforms as tf
from torch.utils.checkpoint import checkpoint
from torch.utils.data import Dataset
import torchvision.transforms.functional as F
import torch.nn.functional as nn_F

from torch import nn

Imageio: 'libfreeimage-3.16.0-linux64.so' was not found on your computer; downloading it now.
Try 1. Download from https://github.com/imageio/imageio-binaries/raw/master/freeimage/libfreeimage-3.16.0-linux64.so (4.6 MB)
Downloading: 8192/4830080 bytes (0.2%)90112/4830080 bytes (1.9%)630784/4830080 bytes (13.1%)4136960/4830080 bytes (85.6%)4830080/4830080 bytes (100.0%)
  Done
File saved as /root/.imageio/freeimage/libfreeimage-3.16.0-linux64.so.


In [4]:
!pip install einops
from einops.layers.torch import Rearrange
from einops import rearrange

Collecting einops
  Downloading einops-0.8.0-py3-none-any.whl (43 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m43.2/43.2 kB[0m [31m798.9 kB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: einops
Successfully installed einops-0.8.0


In [5]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

## `utils.py`

In [6]:
def set_seed(seed: int = 0) -> None:
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)

class RepresentationType(Enum):
    VOXEL = auto()
    STEPAN = auto()


class EventRepresentation:
    def __init__(self):
        pass

    def convert(self, events):
        raise NotImplementedError


class VoxelGrid(EventRepresentation):
    def __init__(self, input_size: tuple, normalize: bool):
        assert len(input_size) == 3
        self.voxel_grid = torch.zeros(
            (input_size), dtype=torch.float, requires_grad=False)
        self.nb_channels = input_size[0]
        self.normalize = normalize

    def convert(self, events):
        C, H, W = self.voxel_grid.shape
        with torch.no_grad():
            self.voxel_grid = self.voxel_grid.to(events['p'].device)
            voxel_grid = self.voxel_grid.clone()

            t_norm = events['t']
            t_norm = (C - 1) * (t_norm-t_norm[0]) / (t_norm[-1]-t_norm[0])

            x0 = events['x'].int()
            y0 = events['y'].int()
            t0 = t_norm.int()

            value = 2*events['p']-1
            #start_t = time()
            for xlim in [x0, x0+1]:
                for ylim in [y0, y0+1]:
                    for tlim in [t0, t0+1]:

                        mask = (xlim < W) & (xlim >= 0) & (ylim < H) & (
                            ylim >= 0) & (tlim >= 0) & (tlim < self.nb_channels)
                        interp_weights = value * (1 - (xlim-events['x']).abs()) * (
                            1 - (ylim-events['y']).abs()) * (1 - (tlim - t_norm).abs())
                        index = H * W * tlim.long() + \
                            W * ylim.long() + \
                            xlim.long()

                        voxel_grid.put_(
                            index[mask], interp_weights[mask], accumulate=True)

            if self.normalize:
                mask = torch.nonzero(voxel_grid, as_tuple=True)
                if mask[0].size()[0] > 0:
                    mean = voxel_grid[mask].mean()
                    std = voxel_grid[mask].std()
                    if std > 0:
                        voxel_grid[mask] = (voxel_grid[mask] - mean) / std
                    else:
                        voxel_grid[mask] = voxel_grid[mask] - mean

        return voxel_grid


class PolarityCount(EventRepresentation):
    def __init__(self, input_size: tuple):
        assert len(input_size) == 3
        self.voxel_grid = torch.zeros(
            (input_size), dtype=torch.float, requires_grad=False)
        self.nb_channels = input_size[0]

    def convert(self, events):
        C, H, W = self.voxel_grid.shape
        with torch.no_grad():
            self.voxel_grid = self.voxel_grid.to(events['p'].device)
            voxel_grid = self.voxel_grid.clone()

            x0 = events['x'].int()
            y0 = events['y'].int()

            #start_t = time()
            for xlim in [x0, x0+1]:
                for ylim in [y0, y0+1]:
                    mask = (xlim < W) & (xlim >= 0) & (ylim < H) & (
                        ylim >= 0)
                    interp_weights = (1 - (xlim-events['x']).abs()) * (
                        1 - (ylim-events['y']).abs())
                    index = H * W * events['p'].long() + \
                        W * ylim.long() + \
                        xlim.long()

                    voxel_grid.put_(
                        index[mask], interp_weights[mask], accumulate=True)

        return voxel_grid


def flow_16bit_to_float(flow_16bit: np.ndarray):
    assert flow_16bit.dtype == np.uint16
    assert flow_16bit.ndim == 3
    h, w, c = flow_16bit.shape
    assert c == 3

    valid2D = flow_16bit[..., 2] == 1
    assert valid2D.shape == (h, w)
    assert np.all(flow_16bit[~valid2D, -1] == 0)
    valid_map = np.where(valid2D)

    # to actually compute something useful:
    flow_16bit = flow_16bit.astype('float')

    flow_map = np.zeros((h, w, 2))
    flow_map[valid_map[0], valid_map[1], 0] = (
        flow_16bit[valid_map[0], valid_map[1], 0] - 2 ** 15) / 128
    flow_map[valid_map[0], valid_map[1], 1] = (
        flow_16bit[valid_map[0], valid_map[1], 1] - 2 ** 15) / 128
    return flow_map, valid2D

In [7]:
def warp_images_with_flow(images, flow):
    dim3 = 0
    if images.dim() == 3:
        dim3 = 1
        images = images.unsqueeze(0)
        flow = flow.unsqueeze(0)
    height = images.shape[2]
    width = images.shape[3]
    flow_x,flow_y = flow[:,0,...],flow[:,1,...]
    coord_x, coord_y = torch.meshgrid(torch.arange(height), torch.arange(width))

    if torch.cuda.is_available():
        pos_x = coord_x.reshape(height,width).type(torch.float32).cuda() + flow_x
        pos_y = coord_y.reshape(height,width).type(torch.float32).cuda() + flow_y
    else: # Troubleshoot without cuda
        pos_x = coord_x.reshape(height,width).type(torch.float32) + flow_x
        pos_y = coord_y.reshape(height,width).type(torch.float32) + flow_y
    pos_x = (pos_x-(height-1)/2)/((height-1)/2)
    pos_y = (pos_y-(width-1)/2)/((width-1)/2)

    pos = torch.stack((pos_y,pos_x),3).type(torch.float32)
    result = torch.nn.functional.grid_sample(images, pos, mode='bilinear', padding_mode='zeros')
    if dim3 == 1:
        result = result.squeeze()

    return result

def charbonnier_loss(delta, alpha=0.45, epsilon=1e-3):
        loss = torch.mean(torch.pow((delta ** 2 + epsilon ** 2), alpha))
        return loss

def compute_smoothness_loss(flow):

    flow_ucrop = flow[..., 1:]
    flow_dcrop = flow[..., :-1]
    flow_lcrop = flow[..., 1:, :]
    flow_rcrop = flow[..., :-1, :]

    flow_ulcrop = flow[..., 1:, 1:]
    flow_drcrop = flow[..., :-1, :-1]
    flow_dlcrop = flow[..., :-1, 1:]
    flow_urcrop = flow[..., 1:, :-1]

    smoothness_loss = charbonnier_loss(flow_lcrop - flow_rcrop) +\
                      charbonnier_loss(flow_ucrop - flow_dcrop) +\
                      charbonnier_loss(flow_ulcrop - flow_drcrop) +\
                      charbonnier_loss(flow_dlcrop - flow_urcrop)
    smoothness_loss /= 4.

    return smoothness_loss

def compute_photometric_loss(prev_images, next_images, flow_dict):
    total_photometric_loss = 0.
    loss_weight_sum = 0.
    for i in range(len(flow_dict)):
        for image_num in range(prev_images.shape[0]):
            flow = flow_dict["flow{}".format(i)][image_num]
            height = flow.shape[1]
            width = flow.shape[2]

            prev_images_resize = F.to_tensor(F.resize(F.to_pil_image(prev_images[image_num].cpu()),
                                                    [height, width]))
            next_images_resize = F.to_tensor(F.resize(F.to_pil_image(next_images[image_num].cpu()),
                                                    [height, width]))

            if torch.cuda.is_available():
                prev_images_resize = prev_images_resize.cuda()
                next_images_resize = next_images_resize.cuda()

            next_images_warped = warp_images_with_flow(next_images_resize, flow)

            distance = next_images_warped - prev_images_resize
            photometric_loss = charbonnier_loss(distance)
            total_photometric_loss += photometric_loss
        loss_weight_sum += 1.
    total_photometric_loss /= loss_weight_sum

    return total_photometric_loss


class TotalLoss(torch.nn.Module):
    def __init__(self, smoothness_weight, weight_decay_weight=1e-4):
        super(TotalLoss, self).__init__()
        self._smoothness_weight = smoothness_weight
        self._weight_decay_weight = weight_decay_weight

    def forward(self, flow_dict, prev_image, next_image, EVFlowNet_model):
        # weight decay loss
        weight_decay_loss = 0
        for i in EVFlowNet_model.parameters():
            weight_decay_loss += torch.sum(i ** 2) / 2 * self._weight_decay_weight

        # smoothness loss
        smoothness_loss = 0
        for i in range(len(flow_dict)):
            smoothness_loss += compute_smoothness_loss(flow_dict["flow{}".format(i)])
        smoothness_loss *= self._smoothness_weight / 4.

        # Photometric loss.
        photometric_loss = compute_photometric_loss(prev_image,
                                                    next_image,
                                                    flow_dict)

        # Warped next image for debugging.
        #next_image_warped = warp_images_with_flow(next_image,
        #                                          flow_dict['flow3'])

        loss = weight_decay_loss + photometric_loss + smoothness_loss

        return loss

## Logging

```python
import sys
import logging
import datetime
import time

class StreamToLogger:
    def __init__(self, logger, log_level=logging.INFO):
        self.logger = logger
        self.log_level = log_level
        self.linebuf = ''

    def write(self, buf):
        for line in buf.rstrip().splitlines():
            self.logger.log(self.log_level, line.rstrip())
            sys.__stdout__.write(line + '\n')

    def flush(self):
        pass

# Create logger
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)

# Create file handler which logs even debug messages
current_time = time.strftime("%Y%m%d-%H%M%S")
logfile = f'../../logs/log_{current_time}.log'
fh = logging.FileHandler(logfile)
fh.setLevel(logging.DEBUG)

# Create console handler with a higher log level
ch = logging.StreamHandler(sys.__stdout__)
ch.setLevel(logging.INFO)

# Create formatter and add it to the handlers
formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
fh.setFormatter(formatter)
ch.setFormatter(formatter)

# Add the handlers to the logger
logger.addHandler(fh)
logger.addHandler(ch)

# Ensure the log file is created
with open(logfile, 'w') as file:
    file.write('')  # Create an empty log file

# Redirect stdout and stderr to the logger
sys.stdout = StreamToLogger(logger, logging.INFO)
sys.stderr = StreamToLogger(logger, logging.ERROR)
```

## `datasets.py`

In [8]:
VISU_INDEX = 1


class EventSlicer:
    def __init__(self, h5f: h5py.File):
        self.h5f = h5f

        self.events = dict()
        for dset_str in ['p', 'x', 'y', 't']:
            self.events[dset_str] = self.h5f['events/{}'.format(dset_str)]

        # This is the mapping from milliseconds to event index:
        # It is defined such that
        # (1) t[ms_to_idx[ms]] >= ms*1000
        # (2) t[ms_to_idx[ms] - 1] < ms*1000
        # ,where 'ms' is the time in milliseconds and 't' the event timestamps in microseconds.
        #
        # As an example, given 't' and 'ms':
        # t:    0     500    2100    5000    5000    7100    7200    7200    8100    9000
        # ms:   0       1       2       3       4       5       6       7       8       9
        #
        # we get
        #
        # ms_to_idx:
        #       0       2       2       3       3       3       5       5       8       9
        self.ms_to_idx = np.asarray(self.h5f['ms_to_idx'], dtype='int64')

        self.t_offset = int(h5f['t_offset'][()])
        self.t_final = int(self.events['t'][-1]) + self.t_offset

    def get_final_time_us(self):
        return self.t_final

    def get_events(self, t_start_us: int, t_end_us: int) -> Dict[str, np.ndarray]:
        """Get events (p, x, y, t) within the specified time window
        Parameters
        ----------
        t_start_us: start time in microseconds
        t_end_us: end time in microseconds
        Returns
        -------
        events: dictionary of (p, x, y, t) or None if the time window cannot be retrieved
        """
        assert t_start_us < t_end_us

        # We assume that the times are top-off-day, hence subtract offset:
        t_start_us -= self.t_offset
        t_end_us -= self.t_offset

        t_start_ms, t_end_ms = self.get_conservative_window_ms(
            t_start_us, t_end_us)
        t_start_ms_idx = self.ms2idx(t_start_ms)
        t_end_ms_idx = self.ms2idx(t_end_ms)
        if t_start_ms_idx is None or t_end_ms_idx is None:
            print('Error', 'start', t_start_us, 'end', t_end_us)
            # Cannot guarantee window size anymore
            return None

        events = dict()
        time_array_conservative = np.asarray(
            self.events['t'][t_start_ms_idx:t_end_ms_idx])
        idx_start_offset, idx_end_offset = self.get_time_indices_offsets(
            time_array_conservative, t_start_us, t_end_us)
        t_start_us_idx = t_start_ms_idx + idx_start_offset
        t_end_us_idx = t_start_ms_idx + idx_end_offset
        # Again add t_offset to get gps time
        events['t'] = time_array_conservative[idx_start_offset:idx_end_offset] + self.t_offset
        for dset_str in ['p', 'x', 'y']:
            events[dset_str] = np.asarray(
                self.events[dset_str][t_start_us_idx:t_end_us_idx])
            assert events[dset_str].size == events['t'].size
        return events

    @staticmethod
    def get_conservative_window_ms(ts_start_us: int, ts_end_us) -> Tuple[int, int]:
        """Compute a conservative time window of time with millisecond resolution.
        We have a time to index mapping for each millisecond. Hence, we need
        to compute the lower and upper millisecond to retrieve events.
        Parameters
        ----------
        ts_start_us:    start time in microseconds
        ts_end_us:      end time in microseconds
        Returns
        -------
        window_start_ms:    conservative start time in milliseconds
        window_end_ms:      conservative end time in milliseconds
        """
        assert ts_end_us > ts_start_us
        window_start_ms = math.floor(ts_start_us/1000)
        window_end_ms = math.ceil(ts_end_us/1000)
        return window_start_ms, window_end_ms

    @staticmethod
    @jit(nopython=True)
    def get_time_indices_offsets(
            time_array: np.ndarray,
            time_start_us: int,
            time_end_us: int) -> Tuple[int, int]:
        """Compute index offset of start and end timestamps in microseconds
        Parameters
        ----------
        time_array:     timestamps (in us) of the events
        time_start_us:  start timestamp (in us)
        time_end_us:    end timestamp (in us)
        Returns
        -------
        idx_start:  Index within this array corresponding to time_start_us
        idx_end:    Index within this array corresponding to time_end_us
        such that (in non-edge cases)
        time_array[idx_start] >= time_start_us
        time_array[idx_end] >= time_end_us
        time_array[idx_start - 1] < time_start_us
        time_array[idx_end - 1] < time_end_us
        this means that
        time_start_us <= time_array[idx_start:idx_end] < time_end_us
        """

        assert time_array.ndim == 1

        idx_start = -1
        if time_array[-1] < time_start_us:

            # Return same index twice: array[x:x] is empty.
            return time_array.size, time_array.size
        else:
            for idx_from_start in range(0, time_array.size, 1):
                if time_array[idx_from_start] >= time_start_us:
                    idx_start = idx_from_start
                    break
        assert idx_start >= 0

        idx_end = time_array.size
        for idx_from_end in range(time_array.size - 1, -1, -1):
            if time_array[idx_from_end] >= time_end_us:
                idx_end = idx_from_end
            else:
                break

        assert time_array[idx_start] >= time_start_us
        if idx_end < time_array.size:
            assert time_array[idx_end] >= time_end_us
        if idx_start > 0:
            assert time_array[idx_start - 1] < time_start_us
        if idx_end > 0:
            assert time_array[idx_end - 1] < time_end_us
        return idx_start, idx_end

    def ms2idx(self, time_ms: int) -> int:
        assert time_ms >= 0
        if time_ms >= self.ms_to_idx.size:
            return None
        return self.ms_to_idx[time_ms]


class Sequence(Dataset):
    def __init__(self, seq_path: Path, representation_type: RepresentationType, mode: str = 'test', delta_t_ms: int = 100,
                 num_bins: int = 4, transforms=[], name_idx=0, visualize=False, load_gt=False):
        assert num_bins >= 1
        assert delta_t_ms == 100
        assert seq_path.is_dir(), seq_path
        assert mode in {'train', 'test'}
        assert representation_type is not None
        '''
        ディレクトリ構造:

        data
        ├─test
        |  ├─seq_1
        |  |    ├─events_left
        |  |    |   ├─events.h5
        |  |    |   └─rectify_map.h5
        |  |    └─forward_timestamps.txt
        └─train
            ├─seq_1
            |    ├─events_left
            |    |       ├─ events.h5
            |    |       └─ rectify_map.h5
            |    ├─flow_forward
            |    |       ├─ 000134.png
            |    |       |.....
            |    └─forward_timestamps.txt
            ├─seq_2
            └─seq_3
        '''
        self.seq_name = PurePath(seq_path).name
        self.mode = mode
        self.name_idx = name_idx
        self.visualize_samples = visualize
        self.load_gt = load_gt
        self.transforms = transforms
        if self.mode == "test":
            assert load_gt == False
            # Get Test Timestamp File
            ev_dir_location = seq_path / 'events_left'
            timestamp_file = seq_path / 'forward_timestamps.txt'
            flow_path = seq_path / 'flow_forward'
            timestamps_flow = np.loadtxt(
                seq_path / 'forward_timestamps.txt', delimiter=',', dtype='int64')
            self.indices = np.arange(len(timestamps_flow))
            self.timestamps_flow = timestamps_flow[:, 0]

        elif self.mode == "train":
            ev_dir_location = seq_path / 'events_left'
            flow_path = seq_path / 'flow_forward'
            timestamp_file = seq_path / 'forward_timestamps.txt'
            self.flow_png = [Path(os.path.join(flow_path, img)) for img in sorted(
                os.listdir(flow_path))]
            timestamps_flow = np.loadtxt(
                seq_path / 'forward_timestamps.txt', delimiter=',', dtype='int64')
            self.indices = np.arange(len(timestamps_flow))
            self.timestamps_flow = timestamps_flow[:, 0]
        else:
            pass
        assert timestamp_file.is_file()

        file = np.genfromtxt(
            timestamp_file,
            delimiter=','
        )

        self.idx_to_visualize = file[:, 2] if file.shape[1] == 3 else []

        # Save output dimensions
        self.height = 480
        self.width = 640
        self.num_bins = num_bins


        # Set event representation
        self.voxel_grid = VoxelGrid(
                (self.num_bins, self.height, self.width), normalize=True)
        self.delta_t_us = delta_t_ms * 1000

        # Left events only
        ev_data_file = ev_dir_location / 'events.h5'
        ev_rect_file = ev_dir_location / 'rectify_map.h5'

        h5f_location = h5py.File(str(ev_data_file), 'r')
        self.h5f = h5f_location
        self.event_slicer = EventSlicer(h5f_location)

        self.h5rect = h5py.File(str(ev_rect_file), 'r')
        self.rectify_ev_map = self.h5rect['rectify_map'][()]


    def events_to_voxel_grid(self, p, t, x, y, device: str = 'cpu'):
        t = (t - t[0]).astype('float32')
        t = (t/t[-1])
        x = x.astype('float32')
        y = y.astype('float32')
        pol = p.astype('float32')
        event_data_torch = {
            'p': torch.from_numpy(pol),
            't': torch.from_numpy(t),
            'x': torch.from_numpy(x),
            'y': torch.from_numpy(y),
        }
        return self.voxel_grid.convert(event_data_torch)

    def getHeightAndWidth(self):
        return self.height, self.width

    @staticmethod
    def get_disparity_map(filepath: Path):
        assert filepath.is_file()
        disp_16bit = cv2.imread(str(filepath), cv2.IMREAD_ANYDEPTH)
        return disp_16bit.astype('float32')/256

    @staticmethod
    def load_flow(flowfile: Path):
        assert flowfile.exists()
        assert flowfile.suffix == '.png'
        flow_16bit = iio.imread(str(flowfile), plugin='PNG-FI')
        flow, valid2D = flow_16bit_to_float(flow_16bit)
        return flow, valid2D

    @staticmethod
    def close_callback(h5f):
        h5f.close()

    def get_image_width_height(self):
        return self.height, self.width

    def __len__(self):
        # Ignore the first and last images as their own
        return len(self.timestamps_flow) # - 2

    def rectify_events(self, x: np.ndarray, y: np.ndarray):
        # assert location in self.locations
        # From distorted to undistorted
        rectify_map = self.rectify_ev_map
        assert rectify_map.shape == (
            self.height, self.width, 2), rectify_map.shape
        assert x.max() < self.width
        assert y.max() < self.height
        return rectify_map[y, x]

    def get_data(self, index) -> Dict[str, any]:
        # Adjust index to skip the first element
#         index += 1

        ts_start: int = self.timestamps_flow[index] - self.delta_t_us
        ts_end: int = self.timestamps_flow[index]

        file_index = self.indices[index]

        output = {
            'file_index': file_index,
            'timestamp': self.timestamps_flow[index],
            'seq_name': self.seq_name
        }
        # Save sample for benchmark submission
        output['save_submission'] = file_index in self.idx_to_visualize
        output['visualize'] = self.visualize_samples
        event_data = self.event_slicer.get_events(
            ts_start, ts_end)
        p = event_data['p']
        t = event_data['t']
        x = event_data['x']
        y = event_data['y']

        xy_rect = self.rectify_events(x, y)
        x_rect = xy_rect[:, 0]
        y_rect = xy_rect[:, 1]

        if self.voxel_grid is None:
            raise NotImplementedError
        else:
            event_representation = self.events_to_voxel_grid(
                p, t, x_rect, y_rect)
            output['event_volume'] = event_representation
        output['name_map'] = self.name_idx

        if self.load_gt:
            output['flow_gt'] = [torch.tensor(x) for x in self.load_flow(self.flow_png[index])]
            output['flow_gt'][0] = torch.moveaxis(output['flow_gt'][0], -1, 0)
            output['flow_gt'][1] = torch.unsqueeze(output['flow_gt'][1], 0)

            flow_gt_shape = [tensor.shape for tensor in output['flow_gt']]
            zero_flow_gt = [torch.zeros_like(tensor) for tensor in output['flow_gt']]

            # Load previous image
            if index > 0:
                output['prev_flow_gt'] = [torch.tensor(x) for x in self.load_flow(self.flow_png[index - 1])]
                output['prev_flow_gt'][0] = torch.moveaxis(output['prev_flow_gt'][0], -1, 0)
                output['prev_flow_gt'][1] = torch.unsqueeze(output['prev_flow_gt'][1], 0)
            else:
                output['prev_flow_gt'] = zero_flow_gt

            # Load next image
            if index < len(self.timestamps_flow) - 1:
                output['next_flow_gt'] = [torch.tensor(x) for x in self.load_flow(self.flow_png[index + 1])]
                output['next_flow_gt'][0] = torch.moveaxis(output['next_flow_gt'][0], -1, 0)
                output['next_flow_gt'][1] = torch.unsqueeze(output['next_flow_gt'][1], 0)
            else:
                output['next_flow_gt'] = zero_flow_gt

        return output

    def __getitem__(self, idx):
        # Adjust index to skip the first element
        sample = self.get_data(idx) # idx + 1

        if self.transforms:
            sample = self.transforms(sample)

        return sample

    def get_voxel_grid(self, idx):

        if idx == 0:
            event_data = self.event_slicer.get_events(
                self.timestamps_flow[0] - self.delta_t_us, self.timestamps_flow[0])
        elif idx > 0 and idx <= self.__len__():
            event_data = self.event_slicer.get_events(
                self.timestamps_flow[idx-1], self.timestamps_flow[idx-1] + self.delta_t_us)
        else:
            raise IndexError

        p = event_data['p']
        t = event_data['t']
        x = event_data['x']
        y = event_data['y']

        xy_rect = self.rectify_events(x, y)
        x_rect = xy_rect[:, 0]
        y_rect = xy_rect[:, 1]
        return self.events_to_voxel_grid(p, t, x_rect, y_rect)

    def get_event_count_image(self, ts_start, ts_end, num_bins, normalize=True):
        assert ts_end > ts_start
        delta_t_bin = (ts_end - ts_start) / num_bins
        ts_start_bin = np.linspace(
            ts_start, ts_end, num=num_bins, endpoint=False)
        ts_end_bin = ts_start_bin + delta_t_bin
        assert abs(ts_end_bin[-1] - ts_end) < 10.
        ts_end_bin[-1] = ts_end

        event_count = torch.zeros(
            (num_bins, self.height, self.width), dtype=torch.float, requires_grad=False)

        for i in range(num_bins):
            event_data = self.event_slicer.get_events(
                ts_start_bin[i], ts_end_bin[i])
            p = event_data['p']
            t = event_data['t']
            x = event_data['x']
            y = event_data['y']

            t = (t - t[0]).astype('float32')
            t = (t/t[-1])
            x = x.astype('float32')
            y = y.astype('float32')
            pol = p.astype('float32')
            event_data_torch = {
                'p': torch.from_numpy(pol),
                't': torch.from_numpy(t),
                'x': torch.from_numpy(x),
                'y': torch.from_numpy(y),
            }
            x = event_data_torch['x']
            y = event_data_torch['y']
            xy_rect = self.rectify_events(x.int(), y.int())
            x_rect = torch.from_numpy(xy_rect[:, 0]).long()
            y_rect = torch.from_numpy(xy_rect[:, 1]).long()
            value = 2*event_data_torch['p']-1
            index = self.width*y_rect + x_rect
            mask = (x_rect < self.width) & (y_rect < self.height)
            event_count[i].put_(index[mask], value[mask], accumulate=True)

        return event_count

    @staticmethod
    def normalize_tensor(event_count):
        mask = torch.nonzero(event_count, as_tuple=True)
        if mask[0].size()[0] > 0:
            mean = event_count[mask].mean()
            std = event_count[mask].std()
            if std > 0:
                event_count[mask] = (event_count[mask] - mean) / std
            else:
                event_count[mask] = event_count[mask] - mean
        return event_count


class SequenceRecurrent(Sequence):
    def __init__(self, seq_path: Path, representation_type: RepresentationType, mode: str = 'test', delta_t_ms: int = 100,
                 num_bins: int = 15, transforms=None, sequence_length=1, name_idx=0, visualize=False, load_gt=False):
        super(SequenceRecurrent, self).__init__(seq_path, representation_type, mode, delta_t_ms, transforms=transforms,
                                                name_idx=name_idx, visualize=visualize, load_gt=load_gt)
        self.crop_size = self.transforms['randomcrop'] if 'randomcrop' in self.transforms else None
        self.sequence_length = sequence_length
        self.valid_indices = self.get_continuous_sequences()

    def get_continuous_sequences(self):
        continuous_seq_idcs = []
        if self.sequence_length > 1:
            for i in range(len(self.timestamps_flow)-self.sequence_length+1):
                diff = self.timestamps_flow[i + self.sequence_length-1] - self.timestamps_flow[i]
                if diff < np.max([100000 * (self.sequence_length-1) + 1000, 101000]):
                    continuous_seq_idcs.append(i)
        else:
            for i in range(len(self.timestamps_flow)-1):
                diff = self.timestamps_flow[i+1] - self.timestamps_flow[i]
                if diff < np.max([100000 * (self.sequence_length-1) + 1000, 101000]):
                    continuous_seq_idcs.append(i)
        return continuous_seq_idcs

    def __len__(self):
        return len(self.valid_indices)

    def __getitem__(self, idx):
        assert idx >= 0
        assert idx < len(self)

        # Valid index is the actual index we want to load, which guarantees a continuous sequence length
        valid_idx = self.valid_indices[idx]

        sequence = []
        j = valid_idx

        ts_cur = self.timestamps_flow[j]
        # Add first sample
        sample = self.get_data_sample(j)
        sequence.append(sample)

        # Data augmentation according to first sample
        crop_window = None
        flip = None
        if 'crop_window' in sample.keys():
            crop_window = sample['crop_window']
        if 'flipped' in sample.keys():
            flip = sample['flipped']

        for i in range(self.sequence_length-1):
            j += 1
            ts_old = ts_cur
            ts_cur = self.timestamps_flow[j]
            assert(ts_cur-ts_old < 100000 + 1000)
            sample = self.get_data_sample(
                j, crop_window=crop_window, flip=flip)
            sequence.append(sample)

        # Check if the current sample is the first sample of a continuous sequence
        if idx == 0 or self.valid_indices[idx]-self.valid_indices[idx-1] != 1:
            sequence[0]['new_sequence'] = 1
            print("Timestamp {} is the first one of the next seq!".format(
                self.timestamps_flow[self.valid_indices[idx]]))
        else:
            sequence[0]['new_sequence'] = 0

        # random crop
        if self.crop_size is not None:
            i, j, h, w = RandomCrop.get_params(
                sample["event_volume_old"], output_size=self.crop_size)
            keys_to_crop = ["event_volume_old", "event_volume_new",
                            "flow_gt_event_volume_old", "flow_gt_event_volume_new",
                            "flow_gt_next",]

            for sample in sequence:
                for key, value in sample.items():
                    if key in keys_to_crop:
                        if isinstance(value, torch.Tensor):
                            sample[key] = tf.functional.crop(value, i, j, h, w)
                        elif isinstance(value, list) or isinstance(value, tuple):
                            sample[key] = [tf.functional.crop(v, i, j, h, w) for v in value]
        return sequence


class DatasetProvider:
    def __init__(self, dataset_path: Path, representation_type: RepresentationType, delta_t_ms: int = 100, num_bins=4, config=None, visualize=False, transforms=None):
        test_path = Path(os.path.join(dataset_path, 'test'))
        train_path = Path(os.path.join(dataset_path, 'train'))
        assert dataset_path.is_dir(), str(dataset_path)
        assert test_path.is_dir(), str(test_path)
        assert delta_t_ms == 100
        self.config = config
        self.name_mapper_test = []

        if transforms:
            self.transforms = transforms
        else:
            self.transforms = tf.Compose([
                transforms.ToTensor(),  # Convert image to PyTorch tensor
            ])

        # Assemble test sequences
        test_sequences = list()
        for child in test_path.iterdir():
            self.name_mapper_test.append(str(child).split("/")[-1])
            test_sequences.append(Sequence(child, representation_type, 'test', delta_t_ms, num_bins,
                                               name_idx=len(self.name_mapper_test) - 1,
                                               visualize=visualize,
                                               transforms=self.transforms))

        self.test_dataset = torch.utils.data.ConcatDataset(test_sequences)

        # Assemble train sequences
        available_seqs = os.listdir(train_path)

        seqs = available_seqs

        train_sequences: list[Sequence] = []
        for seq in seqs:
            extra_arg = dict()
            train_sequences.append(Sequence(Path(train_path) / seq,
                                   representation_type=representation_type, mode="train",
                                   load_gt=True, **extra_arg, transforms=self.transforms))
            self.train_dataset: torch.utils.data.ConcatDataset[Sequence] = torch.utils.data.ConcatDataset(train_sequences)

    def get_test_dataset(self):
        return self.test_dataset

    def get_train_dataset(self):
        return self.train_dataset

    def get_name_mapping_test(self):
        return self.name_mapper_test

    def summary(self, logger):
        logger.write_line(
            "================================== Dataloader Summary ====================================", True)
        logger.write_line("Loader Type:\t\t" + self.__class__.__name__, True)
        logger.write_line("Number of Voxel Bins: {}".format(
            self.test_dataset.datasets[0].num_bins), True)
        logger.write_line("Number of Train Sequences: {}".format(
            len(self.train_dataset)), True)

def train_collate(sample_list):
    batch = dict()
    for field_name in sample_list[0]:
        if field_name == 'timestamp':
            batch['timestamp'] = [sample[field_name] for sample in sample_list]
        if field_name == 'seq_name':
            batch['seq_name'] = [sample[field_name] for sample in sample_list]
        if field_name == 'new_sequence':
            batch['new_sequence'] = [sample[field_name]
                                     for sample in sample_list]
        if field_name.startswith("event_volume"):
            batch[field_name] = torch.stack(
                [sample[field_name] for sample in sample_list])
        if field_name.startswith("flow_gt") or field_name.startswith('prev_flow_gt') or field_name.startswith('next_flow_gt'):
            if all(field_name in x for x in sample_list):
                batch[field_name] = torch.stack(
                    [sample[field_name][0] for sample in sample_list])
                batch[field_name + '_valid_mask'] = torch.stack(
                    [sample[field_name][1] for sample in sample_list])

    return batch


def rec_train_collate(sample_list):
    seq_length = len(sample_list[0])
    seq_of_batch = []
    for i in range(seq_length):
        seq_of_batch.append(train_collate(
            [sample[i] for sample in sample_list]))
    return seq_of_batch

## PCLNet

## From the github

In [9]:
import torch.utils.model_zoo as model_zoo

__all__ = ['ResNet', 'resnet18', 'resnet34', 'resnet50', 'resnet101',
           'resnet152']

model_urls = {
    'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth',
    'resnet34': 'https://download.pytorch.org/models/resnet34-333f7ec4.pth',
    'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth',
    'resnet101': 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth',
    'resnet152': 'https://download.pytorch.org/models/resnet152-b121ed2d.pth',
}


def conv3x3(in_planes, out_planes, stride=1):
    """3x3 convolution with padding"""
    return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
                     padding=1, bias=False)


def conv1x1(in_planes, out_planes, stride=1):
    """1x1 convolution"""
    return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)


class BasicBlock(nn.Module):
    expansion = 1

    def __init__(self, inplanes, planes, stride=1, downsample=None):
        super(BasicBlock, self).__init__()
        self.conv1 = conv3x3(inplanes, planes, stride)
        self.bn1 = nn.BatchNorm2d(planes)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = conv3x3(planes, planes)
        self.bn2 = nn.BatchNorm2d(planes)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        residual = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)

        if self.downsample is not None:
            residual = self.downsample(x)

        out += residual
        out = self.relu(out)

        return out


class Bottleneck(nn.Module):
    expansion = 4

    def __init__(self, inplanes, planes, stride=1, downsample=None):
        super(Bottleneck, self).__init__()
        self.conv1 = conv1x1(inplanes, planes)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = conv3x3(planes, planes, stride)
        self.bn2 = nn.BatchNorm2d(planes)
        self.conv3 = conv1x1(planes, planes * self.expansion)
        self.bn3 = nn.BatchNorm2d(planes * self.expansion)
        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        residual = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)
        out = self.relu(out)

        out = self.conv3(out)
        out = self.bn3(out)

        if self.downsample is not None:
            residual = self.downsample(x)

        out += residual
        out = self.relu(out)

        return out


class ResNet(nn.Module):

    def __init__(self, block, layers, num_classes=1000):
        super(ResNet, self).__init__()
        self.inplanes = 64
        self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3,
                               bias=False)
        self.bn1 = nn.BatchNorm2d(64)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        self.layer1 = self._make_layer(block, 64, layers[0])
        self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
        self.layer3 = self._make_layer(block, 256, layers[2], stride=2)
        self.layer4 = self._make_layer(block, 512, layers[3], stride=2)
        # self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        # self.fc = nn.Linear(512 * block.expansion, num_classes)

        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

    def _make_layer(self, block, planes, blocks, stride=1):
        downsample = None
        if stride != 1 or self.inplanes != planes * block.expansion:
            downsample = nn.Sequential(
                conv1x1(self.inplanes, planes * block.expansion, stride),
                nn.BatchNorm2d(planes * block.expansion),
            )

        layers = []
        layers.append(block(self.inplanes, planes, stride, downsample))
        self.inplanes = planes * block.expansion
        for _ in range(1, blocks):
            layers.append(block(self.inplanes, planes))

        return nn.Sequential(*layers)

    def forward(self, x):
        endpoint = []
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        endpoint.append(x) # output here

        x = self.maxpool(x)
        x = self.layer1(x)
        endpoint.append(x)
        x = self.layer2(x)
        endpoint.append(x)
        x = self.layer3(x)
        endpoint.append(x)
        x = self.layer4(x)
        endpoint.append(x)

        return endpoint


def resnet18(pretrained=False, **kwargs):
    """Constructs a ResNet-18 model.

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
    """
    model = ResNet(BasicBlock, [2, 2, 2, 2], **kwargs)
    if pretrained:
        state = model.state_dict()
        state_ckp = model_zoo.load_url(model_urls['resnet18'])
        cnt = 0
        for k, val in state_ckp.items():
            if k in state.keys():
                state[k] = val
                cnt += 1
        model.load_state_dict(state)
        print ("RestNet checkpoint loaded: %d" % cnt)
    return model


def resnet34(pretrained=False, **kwargs):
    """Constructs a ResNet-34 model.

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
    """
    model = ResNet(BasicBlock, [3, 4, 6, 3], **kwargs)
    if pretrained:
        state = model.state_dict()
        state_ckp = model_zoo.load_url(model_urls['resnet34'])
        cnt = 0
        for k, val in state_ckp.items():
            if k in state.keys():
                state[k] = val
                cnt += 1
        model.load_state_dict(state)
        print ("RestNet checkpoint loaded: %d" % cnt)
    return model


def resnet50(pretrained=False, **kwargs):
    """Constructs a ResNet-50 model.

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
    """
    model = ResNet(Bottleneck, [3, 4, 6, 3], **kwargs)
    if pretrained:
        model.load_state_dict(model_zoo.load_url(model_urls['resnet50']))
    return model


def resnet101(pretrained=False, **kwargs):
    """Constructs a ResNet-101 model.

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
    """
    model = ResNet(Bottleneck, [3, 4, 23, 3], **kwargs)
    if pretrained:
        model.load_state_dict(model_zoo.load_url(model_urls['resnet101']))
    return model


def resnet152(pretrained=False, **kwargs):
    """Constructs a ResNet-152 model.

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
    """
    model = ResNet(Bottleneck, [3, 8, 36, 3], **kwargs)
    if pretrained:
        model.load_state_dict(model_zoo.load_url(model_urls['resnet152']))
    return model

In [10]:
from torch.autograd import Variable

class ConvLSTMCell(nn.Module):
    def __init__(self, input_channels, hidden_channels, kernel_size, bias=True):
        super(ConvLSTMCell, self).__init__()

        assert hidden_channels % 2 == 0

        self.input_channels = input_channels
        self.hidden_channels = hidden_channels
        self.bias = bias
        self.kernel_size = kernel_size

        self.padding = int((kernel_size - 1) / 2)
        self.Gates = nn.Conv2d(self.input_channels + self.hidden_channels , 4*self.hidden_channels,
                self.kernel_size, 1, self.padding, bias=True)

    def forward(self, x, h, c):

        stacked_inputs = torch.cat((x, h), 1)
        gates = self.Gates(stacked_inputs)

        # chunk across the channel dimension
        xi, xf, xo, xg = gates.chunk(4, 1)

        # apply sigmoid non linearity
        xi = torch.sigmoid(xi)
        xf = torch.sigmoid(xf)
        xo = torch.sigmoid(xo)
        xg = torch.tanh(xg)

        # compute current cell and hidden state
        c = (xf * c) + (xi * xg)
        h = xo * torch.tanh(c)

        return h, c

    def init_hidden(self, batch_size, hidden, shape):
        return (Variable(torch.zeros(batch_size, hidden, shape[0], shape[1])),
                Variable(torch.zeros(batch_size, hidden, shape[0], shape[1])))


class ConvLSTM(nn.Module):
    def __init__(self, input_channels, hidden_channels, kernel_size, step=1, effective_step=[1], bias=True):
        super(ConvLSTM, self).__init__()
        self.input_channels = [input_channels] + hidden_channels
        self.hidden_channels = hidden_channels
        self.kernel_size = kernel_size
        self.num_layers = len(hidden_channels)
        self.step = step
        self.bias = bias
        self.effective_step = effective_step
        self._all_layers = []
        for i in range(self.num_layers):
            name = 'cell{}'.format(i)
            cell = ConvLSTMCell(self.input_channels[i], self.hidden_channels[i], self.kernel_size, self.bias)
            setattr(self, name, cell)
            self._all_layers.append(cell)

    def forward(self, input):
        #input : (num, seq_len, channel, H,W)
        internal_state = []
        outputs = []
        for step in range(self.step):
            x = input[:, step, :,:,:]
            for i in range(self.num_layers):
                # all cells are initialized in the first step
                name = 'cell{}'.format(i)
                if step == 0:
                    bsize, _, height, width = x.size()
                    (h, c) = getattr(self, name).init_hidden(batch_size=bsize, hidden=self.hidden_channels[i],
                            shape=(height, width))
                    internal_state.append((h, c))

                # do forward
                (h, c) = internal_state[i]
                x, new_c = getattr(self, name)(x, h, c)
                internal_state[i] = (x, new_c)
            # only record effective steps
            if step in self.effective_step:
                outputs.append(x)
        return outputs, (x, new_c)

In [11]:
def conv(in_planes, out_planes, kernel_size=3, stride=1, padding=1, dilation=1):
    if type(in_planes) == np.int64:
        in_planes = np.asscalar(in_planes)
    return nn.Sequential(
        nn.Conv2d(in_planes, out_planes, kernel_size=kernel_size, stride=stride,
                  padding=padding, dilation=dilation, bias=True),
        nn.LeakyReLU(0.1))


def predict_flow(in_planes):
    if type(in_planes) == np.int64:
        in_planes = np.asscalar(in_planes)
    return nn.Conv2d(in_planes, 2, kernel_size=3, stride=1, padding=1, bias=True)


def deconv(in_planes, out_planes, kernel_size=4, stride=2, padding=1):
    if type(in_planes) == np.int64:
        in_planes = np.asscalar(in_planes)
    return nn.ConvTranspose2d(in_planes, out_planes, kernel_size, stride, padding, bias=True)

def in_f(flow):
    return nn_F.interpolate(flow, size=(480, 640), mode='bilinear', align_corners=False)


class PCLNet(nn.Module):
    """
    PCLNet: Unsupervised Learning for Optical Flow Estimation Using Pyramid Convolution LSTM
    Author: Shuosen Guan
    """

    def __init__(self, args):

        super(PCLNet, self).__init__()
        self.args = args

        snippet_len = args.snippet_len
        self.feature_net = eval(args.backbone)(pretrained=True, num_classes=args.class_num)
        if args.freeze_vgg:
            for p in self.feature_net.parameters():
                p.required_grad = False
            print("[>>>> Feature head frozen.<<<<]")

        # Motion Encoding
        # in_size: 1/2
        self.clstm_encoder_1 = ConvLSTM(input_channels=64, hidden_channels=[64],
                                        kernel_size=3, step=snippet_len, effective_step=list(range(snippet_len)))
        # in_size: 1/4
        self.clstm_encoder_2 = ConvLSTM(input_channels=64, hidden_channels=[64],
                                        kernel_size=3, step=snippet_len, effective_step=list(range(snippet_len)))
        # in_size: 1/8
        self.clstm_encoder_3 = ConvLSTM(input_channels=128, hidden_channels=[128],
                                        kernel_size=3, step=snippet_len, effective_step=list(range(snippet_len)))
        # in_size: 1/16
        self.clstm_encoder_4 = ConvLSTM(input_channels=256, hidden_channels=[256],
                                        kernel_size=3, step=snippet_len, effective_step=list(range(snippet_len)))

        self.conv_B1    = conv(64, 64, stride=1, kernel_size=3, padding=1)
        self.conv_S1_1  = conv(64, 64, stride=1, kernel_size=3, padding=1)
        self.conv_S1_2  = conv(64, 64, stride=1, kernel_size=3, padding=1)
        self.conv_D1    = conv(64, 64, stride=2)
        self.Pool1      = nn.MaxPool2d(8, 8)

        self.conv_B2    = conv(64, 64, stride=1, kernel_size=3, padding=1)
        self.conv_S2_1  = conv(64 + 64, 128, stride=1, kernel_size=3, padding=1)
        self.conv_S2_2  = conv(128, 128, stride=1, kernel_size=3, padding=1)
        self.conv_D2    = conv(128, 64, stride=2)
        self.Pool2      = nn.MaxPool2d(4, 4)

        self.conv_B3    = conv(128, 128, stride=1, kernel_size=3, padding=1)
        self.conv_S3_1  = conv(128 + 64, 128, stride=1, kernel_size=3, padding=1)
        self.conv_S3_2  = conv(128, 128, stride=1, kernel_size=3, padding=1)
        self.conv_D3    = conv(128, 64, stride=2)
        self.Pool3      = nn.MaxPool2d(2, 2)

        self.conv_B4    = conv(256, 128, stride=1, kernel_size=3, padding=1)
        self.conv_S4_1  = conv(128 + 64, 128, stride=1, kernel_size=3, padding=1)
        self.conv_S4_2  = conv(128, 128, stride=1, kernel_size=3, padding=1)

        # Motion feature
        self.conv_M = conv((64 + 128 + 128 + 128), 256, stride=1, kernel_size=3, padding=1)

        # Motion reconstruction
        if self.args.couple:
            rec_in_size = [0, 64 + 64 + 2, 128 + 128 + 2, 128 + 196 + 2, 128 + 256]
        else:
            rec_in_size = [0, 64 + 2, 128 + 2, 196 + 2, 256]

        self.conv_4     = conv(rec_in_size[4], 256)
        self.pred_flow4 = predict_flow(256)
        self.up_flow4   = deconv(2, 2)
        self.up_feat4   = deconv(256, 196)

        self.conv_3     = conv(rec_in_size[3], 196)
        self.pred_flow3 = predict_flow(196)
        self.up_flow3   = deconv(2, 2)
        self.up_feat3   = deconv(196, 128)

        self.conv_2     = conv(rec_in_size[2], 96)
        self.pred_flow2 = predict_flow(96)
        self.up_flow2   = conv(2, 2)
        self.up_feat2   = conv(96, 64)

        self.conv_1     = conv(rec_in_size[1], 64)
        self.pred_flow1 = predict_flow(64)

        self.dc_conv1 = conv(64, 64, kernel_size=3, stride=1, padding=1, dilation=1)
        self.dc_conv2 = conv(64, 64, kernel_size=3, stride=1, padding=2, dilation=2)
        self.dc_conv3 = conv(64, 64, kernel_size=3, stride=1, padding=4, dilation=4)
        self.dc_conv4 = conv(64, 64, kernel_size=3, stride=1, padding=8, dilation=8)
        self.dc_conv5 = conv(64, 64, kernel_size=3, stride=1, padding=16, dilation=16)
        self.dc_conv6 = conv(64, 32, kernel_size=3, stride=1, padding=1, dilation=1)
        self.dc_conv7 = predict_flow(32)

        # for some reason, conv returns shape torch.Size([B*T, 3, 482, 642])
        self.reduce_channels = nn.Conv2d(4, 3, kernel_size=1)


    def forward(self, x):

        if x.dim() == 6:    # (batch_size, K, snippet_len, channel, H, W)
            batch_size, K, snippet_len, channel, H, W = x.size()
        elif x.dim() == 5:  # (batch_size, snippet_len, channel, H, W)
            batch_size, snippet_len, channel, H, W = x.size()
            K = 1
        elif x.dim() == 4:  # (batch_size, channel * snippet_len, H, W)
            batch_size, _channels, H, W = x.size()
            K, channel = 1, 3
            snippet_len = _channels // channel
        else:
            raise RuntimeError('Input format not suppored!')

        x = x.contiguous().view(-1, channel, H, W)
        if channel > 3:
            x = self.reduce_channels(x)

        la1, la2, la3, la4, _ = self.feature_net(x)

        la1 = la1.view((-1, snippet_len) + la1.size()[1:])
        la2 = la2.view((-1, snippet_len) + la2.size()[1:])
        la3 = la3.view((-1, snippet_len) + la3.size()[1:])
        la4 = la4.view((-1, snippet_len) + la4.size()[1:])
        # la5 = la5.view((-1, snippet_len) + la5.size()[1:])

        h1, _ = self.clstm_encoder_1(la1)
        h2, _ = self.clstm_encoder_2(la2)
        h3, _ = self.clstm_encoder_3(la3)
        h4, _ = self.clstm_encoder_4(la4)
        # list for each step (batch_size * K, channel, H, W)

        # (batch_size * K*(snippet_len -1), channel, H, W)
        h1 = torch.stack(h1[1:], 1).view((-1,) + h1[0].size()[-3:])
        h2 = torch.stack(h2[1:], 1).view((-1,) + h2[0].size()[-3:])
        h3 = torch.stack(h3[1:], 1).view((-1,) + h3[0].size()[-3:])
        h4 = torch.stack(h4[1:], 1).view((-1,) + h4[0].size()[-3:])

        x1 = self.conv_B1(h1)
        x1 = self.conv_S1_2(self.conv_S1_1(x1))

        x2 = torch.cat((self.conv_B2(h2), self.conv_D1(x1)), 1)
        x2 = self.conv_S2_2(self.conv_S2_1(x2))

        x3 = torch.cat((self.conv_B3(h3), self.conv_D2(x2)), 1)
        x3 = self.conv_S3_2(self.conv_S3_1(x3))

        x4 = torch.cat((self.conv_B4(h4), self.conv_D3(x3)), 1)
        x4 = self.conv_S4_2(self.conv_S4_1(x4))

        xm = self.conv_M(torch.cat((self.Pool1(x1), self.Pool2(x2), self.Pool3(x3), x4), 1))

        rec_x4 = torch.cat((x4, xm), 1) if self.args.couple else xm
        x = self.conv_4(rec_x4)
        flow4 = self.pred_flow4(x)
        up_flow4 = self.up_flow4(flow4)
        up_feat4 = self.up_feat4(x)

        rec_x3 = torch.cat((x3, up_feat4, up_flow4), 1) if self.args.couple else torch.cat((up_feat4, up_flow4), 1)
        x = self.conv_3(rec_x3)
        flow3 = self.pred_flow3(x)
        up_flow3 = self.up_flow3(flow3)
        up_feat3 = self.up_feat3(x)

        rec_x2 = torch.cat((x2, up_feat3, up_flow3), 1) if self.args.couple else torch.cat((up_feat3, up_flow3), 1)
        x = self.conv_2(rec_x2)
        flow2 = self.pred_flow2(x)
        up_flow2 = self.up_flow2(flow2)
        up_feat2 = self.up_feat2(x)

        rec_x1 = torch.cat((x1, up_feat2, up_flow2), 1) if self.args.couple else torch.cat((up_feat2, up_flow2), 1)
        x = self.conv_1(rec_x1)
        flow1 = self.pred_flow1(x)

        x = self.dc_conv4(self.dc_conv3(self.dc_conv2(self.dc_conv1(x))))
        flow1 += self.dc_conv7(self.dc_conv6(self.dc_conv5(x)))

        re_dict = {
            'flow0': flow4,
            'flow1': flow3,
            'flow2': flow2,
            'flow3': flow1
        }

        # output size: (batch_size, K, snippet_len -1 , C,H,W)

#         flow_pyramid = [flo.view((batch_size, K, snippet_len - 1,) + flo.size()[-3:])
#                         for flo in [flow1, flow2, flow3, flow4]]
#         re_dict = {}
#         re_dict['flow_pyramid'] = flow_pyramid

#         return re_dict

        flow1 = flow1.view((batch_size, K, snippet_len - 1,) + flow1.size()[-3:])
        flow2 = flow2.view((batch_size, K, snippet_len - 1,) + flow2.size()[-3:])
        flow3 = flow3.view((batch_size, K, snippet_len - 1,) + flow3.size()[-3:])
        flow4 = flow4.view((batch_size, K, snippet_len - 1,) + flow4.size()[-3:])

        flow1_arr = [in_f(flow1[:, :, i, :, :, :].squeeze(1).squeeze(1)) for i in range(flow1.size(2))]
        flow2_arr = [in_f(flow2[:, :, i, :, :, :].squeeze(1).squeeze(1)) for i in range(flow2.size(2))]
        flow3_arr = [in_f(flow3[:, :, i, :, :, :].squeeze(1).squeeze(1)) for i in range(flow3.size(2))]
        flow4_arr = [in_f(flow4[:, :, i, :, :, :].squeeze(1).squeeze(1)) for i in range(flow4.size(2))]

        combined_flow = [torch.mean(torch.stack([f1, f2, f3, f4], dim=0), dim=0)
                             for f1, f2, f3, f4 in zip(flow1_arr, flow2_arr, flow3_arr, flow4_arr)]
        return re_dict, combined_flow

## Arguments & Definition

In [12]:
import argparse
args = argparse.Namespace(
    name='pclnet',
    snippet_len=3,
    backbone='resnet18',
    class_num=101,
    freeze_vgg=True,
    couple=False,

    lr=0.01, # 2e-5
    num_steps=100000,
    batch_size=32, # default: 16
    image_size=[480, 640],
    mixed_precision=False,
    iters=12,
    wdecay=0.00005,
    epsilon=1e-8,
    clip=1.0,
    dropout=0.0,
    gamma=0.8,
    add_noise=False,
    seed=27,
    dataset_path='data/',
)

In [32]:
del model

In [33]:
model = torch.nn.DataParallel(PCLNet(args).cpu())

RestNet checkpoint loaded: 100
[>>>> Feature head frozen.<<<<]


In [34]:
model.load_state_dict(torch.load('models/pclnet_batch50.pth'))

<All keys matched successfully>

## Image Preprocessing

In [20]:
from sklearn.decomposition import PCA, SparsePCA, TruncatedSVD
from sklearn.preprocessing import StandardScaler, MinMaxScaler

class DimensionalityReduction:
    def __init__(self, n_components, scaler=MinMaxScaler(), red_technique='pca'):
        self.scaler = scaler # if None, then doesn't perform any scaling
        self.n_components = n_components

        match red_technique:
            case 'pca':
                self.technique = PCA(n_components=n_components)
            case 'sparsepca':
                self.technique = SparsePCA(n_components=n_components)
            case 'tsvd':
                self.technique = TruncatedSVD(n_components=n_components)
            case _:
                raise NotImplementedError

    def __call__(self, image_tensor: torch.Tensor):
        # this is for a single image, not a batch of images
        C, H, W = image_tensor.shape

        tensor_reshaped = image_tensor.cpu().numpy().reshape(C, -1).T
        if self.scaler is not None:
            tensor_reshaped = self.scaler.fit_transform(tensor_reshaped)

        reduced_tensor_reshaped = self.technique.fit_transform(tensor_reshaped)
        reduced_tensor_reshaped = reduced_tensor_reshaped.T.reshape(self.n_components, H, W)

        reduced_tensor = torch.tensor(reduced_tensor_reshaped, device=image_tensor.device)
        return reduced_tensor

```python
a = torch.rand((4, 480, 640))
b = DimensionalityReduction(n_components=3, scaler=MinMaxScaler(), red_technique='tsvd')(a)
b.shape
```

In [21]:
# Needs to have 3 channels or color passed in
# runs into the problem of loss of information in dimensionality reduction
class HistogramEqualization:
    def __call__(self, img):
        return tf.functional.equalize(img)

In [22]:
class RandomCrop(object):
    def __init__(self, size):
        self.size = size

    def __call__(self, clip):
        i, j, h, w = tf.RandomCrop.get_params(clip[0], output_size=self.size)
        return [F.crop(img, i, j, h, w) for img in clip]

class RandomHorizontalFlip(object):
    def __call__(self, clip):
        if random.random() > 0.5:
            return [F.hflip(img) for img in clip]
        return clip

class RandomVerticalFlip(object):
    def __call__(self, clip):
        if random.random() > 0.5:
            return [F.vflip(img) for img in clip]
        return clip

class RandomRescale(object):
    def __init__(self, scale_range):
        self.scale_range = scale_range

    def __call__(self, clip):
        scale_factor = random.uniform(*self.scale_range)
        return [F.resize(img, [int(img.shape[0] * scale_factor), int(img.shape[1] * scale_factor)]) for img in clip]

class Resize(object):
    def __init__(self, size):
        self.size = size

    def __call__(self, clip):
        return [F.resize(img, self.size) for img in clip]

clip_transform = tf.Compose([
    RandomCrop((256, 256)),
    RandomHorizontalFlip(),
    RandomVerticalFlip(),
    RandomRescale((0.8, 1.2)),
    Resize((480, 640)),
])

In [23]:
class CombinedTransform:
    def __init__(self, transform=tf.Compose([ tf.ToTensor() ])):
        self.transform = transform

    def __call__(self, flow_dict):
        seed = np.random.randint(2147483647)

        # Just in case I decide to add more feature columns
        feaure_flow_columns = [c for c in flow_dict.keys() if 'event_volume' in c]

        for col in feaure_flow_columns:
            torch.manual_seed(seed)

            if type(flow_dict[col]) == list:
                flow_dict[col] = [ self.transform(img) for img in flow_dict[col] ]
            else:
                flow_dict[col] = self.transform(flow_dict[col])

        return flow_dict

combined_transform = CombinedTransform(
    transform=tf.Compose([
        tf.GaussianBlur(kernel_size=(5, 5)),
    ])
#     transform=clip_transform
)

## `main.py`

Instead of changing the `Sequence`, just change how the data are loaded. Or iterate over the dataloader so that you use many images at once.

## Data Formatting

In [24]:
class RepresentationType(Enum):
    VOXEL = auto()
    STEPAN = auto()

def set_seed(seed):
    random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
    np.random.seed(seed)

def compute_epe_error(pred_flow: torch.Tensor, gt_flow: torch.Tensor):
    '''
    end-point-error (ground truthと予測値の二乗誤差)を計算
    pred_flow: torch.Tensor, Shape: torch.Size([B, 2, 480, 640]) => 予測したオプティカルフローデータ
    gt_flow: torch.Tensor, Shape: torch.Size([B, 2, 480, 640]) => 正解のオプティカルフローデータ
    '''
    epe = torch.mean(torch.mean(torch.norm(pred_flow - gt_flow, p=2, dim=1), dim=(1, 2)), dim=0)
    return epe

def save_optical_flow_to_npy(flow: torch.Tensor, file_name: str):
    '''
    optical flowをnpyファイルに保存
    flow: torch.Tensor, Shape: torch.Size([2, 480, 640]) => オプティカルフローデータ
    file_name: str => ファイル名
    '''
    np.save(f"{file_name}.npy", flow.cpu().numpy())

In [25]:
set_seed(args.seed)

'''
    ディレクトリ構造:

    data
    ├─test
    |  ├─test_city
    |  |    ├─events_left
    |  |    |   ├─events.h5
    |  |    |   └─rectify_map.h5
    |  |    └─forward_timestamps.txt
    └─train
        ├─zurich_city_11_a
        |    ├─events_left
        |    |       ├─ events.h5
        |    |       └─ rectify_map.h5
        |    ├─ flow_forward
        |    |       ├─ 000134.png
        |    |       |.....
        |    └─ forward_timestamps.txt
        ├─zurich_city_11_b
        └─zurich_city_11_c
    '''

# ------------------
#    Dataloader
# ------------------

loader = DatasetProvider(
    dataset_path=Path(args.dataset_path),
    representation_type=RepresentationType.VOXEL,
    delta_t_ms=100,
    num_bins=4,
    transforms=combined_transform # Custom class
)
train_set = loader.get_train_dataset()
test_set = loader.get_test_dataset()

# def split_train_valid(dataset):
#     train_indices = []
#     valid_indices = []
#     for idx in range(len(dataset)):
#         sample = dataset[idx]
#         if 'flow_gt_valid_mask' in sample and sample['flow_gt_valid_mask'].all():
#             valid_indices.append(idx)
#         else:
#             train_indices.append(idx)
#     train_subset = torch.utils.data.Subset(dataset, train_indices)
#     valid_subset = torch.utils.data.Subset(dataset, valid_indices)
#     return train_subset, valid_subset

# train_set_split, valid_set_split = split_train_valid(train_set)

collate_fn = train_collate
train_data = DataLoader(train_set, # train_set_split
                        batch_size=args.batch_size, #
                        shuffle=False,
                        collate_fn=collate_fn,
                        drop_last=False,
                        num_workers=os.cpu_count(),
                        pin_memory=True)
test_data = DataLoader(test_set,
                       batch_size=1,
                       shuffle=False,
                       collate_fn=collate_fn,
                       drop_last=False,
                       num_workers=os.cpu_count(),
                       pin_memory=True)

'''
train data:
    Type of batch: Dict
    Key: seq_name, Type: list
    Key: event_volume, Type: torch.Tensor, Shape: torch.Size([Batch, 4, 480, 640]) => イベントデータのバッチ
    Key: flow_gt, Type: torch.Tensor, Shape: torch.Size([Batch, 2, 480, 640]) => オプティカルフローデータのバッチ
    Key: flow_gt_valid_mask, Type: torch.Tensor, Shape: torch.Size([Batch, 1, 480, 640]) => オプティカルフローデータのvalid. ベースラインでは使わない

test data:
    Type of batch: Dict
    Key: seq_name, Type: list
    Key: event_volume, Type: torch.Tensor, Shape: torch.Size([Batch, 4, 480, 640]) => イベントデータのバッチ
'''

'\ntrain data:\n    Type of batch: Dict\n    Key: seq_name, Type: list\n    Key: event_volume, Type: torch.Tensor, Shape: torch.Size([Batch, 4, 480, 640]) => イベントデータのバッチ\n    Key: flow_gt, Type: torch.Tensor, Shape: torch.Size([Batch, 2, 480, 640]) => オプティカルフローデータのバッチ\n    Key: flow_gt_valid_mask, Type: torch.Tensor, Shape: torch.Size([Batch, 1, 480, 640]) => オプティカルフローデータのvalid. ベースラインでは使わない\n\ntest data:\n    Type of batch: Dict\n    Key: seq_name, Type: list\n    Key: event_volume, Type: torch.Tensor, Shape: torch.Size([Batch, 4, 480, 640]) => イベントデータのバッチ\n'

## Training

In [None]:
# ------------------
#   optimizer
# ------------------
optimizer = torch.optim.AdamW(model.parameters(), lr=0.00002, weight_decay=args.wdecay, eps=args.epsilon)
scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer=optimizer, gamma=0.9)

# scheduler = torch.optim.lr_scheduler.OneCycleLR(optimizer, args.lr, args.num_steps + 100,
#                                                 pct_start=0.05, cycle_momentum=False, anneal_strategy='linear')
# scheduler = torch.optim.lr_scheduler.OneCycleLR(optimizer, args.lr, args.num_steps + 100,
#                                                 pct_start=0.05, cycle_momentum=False, anneal_strategy='linear')

loss_fn = TotalLoss(smoothness_weight=0.5)

In [None]:
# num_epochs = args.train.epochs
num_epochs = 1

epe_losses = [[] for _ in range(num_epochs)]
overall_losses = [[] for _ in range(num_epochs)]

In [27]:
BATCH_CONCAT = args.snippet_len

In [None]:
def save_model(model, additional_string: str):
    current_time = time.strftime("%Y%m%d-%H%M%S")
    model_path = f"model_{current_time}_{additional_string}.pth"
    torch.save(model.state_dict(), model_path)
    print(f"Model saved to {model_path}")

In [None]:
# ------------------
#   Start training
# ------------------
model.train()

for epoch in range(num_epochs):

    total_loss = 0
    prev_event_volumes = [torch.zeros([args.batch_size, 4, 480, 640])] * BATCH_CONCAT # Acts as a queue

    print("on epoch: {}".format(epoch + 1))
    for i, batch in enumerate(tqdm(train_data)):

        try:
            batch: Dict[str, Any]

            event_image = batch["event_volume"].to(device) # [B, 3, 480, 640]
            ground_truth_flow = batch["flow_gt"].to(device) # [B, 2, 480, 640]

            prev_event_volumes.append(event_image)

            prev_ground_truth_flow = batch['prev_flow_gt'].to(device) # [B, 2, 480, 640]
            next_ground_truth_flow = batch['next_flow_gt'].to(device) # [B, 2, 480, 640]

            input_tensor = torch.stack(prev_event_volumes[-BATCH_CONCAT:], dim=1)
            _, flows = model(input_tensor) # [B, 3, 480, 640]

            # Overall loss requires flow0, ..., flow3 so we don't implement it here
            # What if you created flow_dict from flow0, ..., flow11 (n=12) to and then use overall loss?

            for j, flow in enumerate(flows):
                print(f'batch {i} | flow #{j + 1} | EPE LOSS:', compute_epe_error(flow, ground_truth_flow).item())

            avg_flow = torch.mean(torch.stack(flows, dim=0), dim=0)
            epe_loss: torch.Tensor = compute_epe_error(avg_flow, ground_truth_flow)
#             overall_loss: loss_fn(flow_dict,
#                                  prev_ground_truth_flow,
#                                  next_ground_truth_flow,
#                                  model)

            print(f"batch {i} average EPE LOSS: {epe_loss.item()}")
            epe_losses[epoch].append(epe_loss.item())

#             print(f'batch {i} OVERALL LOSS: {loss_fn()}')

            optimizer.zero_grad()

            epe_loss.backward() # Change this to which loss function is to be updated
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
            optimizer.step()

            total_loss += epe_loss.item() # This too

            if len(prev_event_volumes) >= BATCH_CONCAT:
                prev_event_volumes.pop(0) # Remove first element

            if (i + 1) % 10 == 0:
                save_model(model, f'epoch{i + 1}')

        except KeyboardInterrupt:
            save_model(model)
            raise SystemExit("KeyboardInterrupt")

    scheduler.step()

    print(f'Epoch {epoch+1}, Loss: {total_loss / len(train_data)}')

on epoch: 1


  0%|          | 0/63 [00:00<?, ?it/s]

batch 0 | flow #1 | EPE LOSS: 1.8396062090700058
batch 0 | flow #2 | EPE LOSS: 1.8395508744573315
batch 0 average EPE LOSS: 1.8395785293693376


  2%|▏         | 1/63 [02:30<2:35:34, 150.56s/it]

batch 1 | flow #1 | EPE LOSS: 1.3619467749821583
batch 1 | flow #2 | EPE LOSS: 1.3619275544724172
batch 1 average EPE LOSS: 1.3619371590125076


  3%|▎         | 2/63 [03:58<1:55:38, 113.75s/it]

batch 2 | flow #1 | EPE LOSS: 7.162377360312281
batch 2 | flow #2 | EPE LOSS: 7.162346349042931
batch 2 average EPE LOSS: 7.162361847920057


  5%|▍         | 3/63 [05:27<1:42:38, 102.64s/it]

batch 3 | flow #1 | EPE LOSS: 6.344814803053367
batch 3 | flow #2 | EPE LOSS: 6.344752112039278
batch 3 average EPE LOSS: 6.3447834534162215


  6%|▋         | 4/63 [06:56<1:35:26, 97.06s/it] 

batch 4 | flow #1 | EPE LOSS: 3.7186164226350416
batch 4 | flow #2 | EPE LOSS: 3.7185795124431653
batch 4 average EPE LOSS: 3.7185979638303523


  8%|▊         | 5/63 [08:24<1:30:39, 93.79s/it]

batch 5 | flow #1 | EPE LOSS: 2.825004593217984
batch 5 | flow #2 | EPE LOSS: 2.8249310917934443
batch 5 average EPE LOSS: 2.824967837594634


 10%|▉         | 6/63 [09:53<1:27:28, 92.09s/it]

batch 6 | flow #1 | EPE LOSS: 2.29672946530145
batch 6 | flow #2 | EPE LOSS: 2.296680076328348
batch 6 average EPE LOSS: 2.2967047661321702


 11%|█         | 7/63 [11:22<1:25:02, 91.12s/it]

batch 7 | flow #1 | EPE LOSS: 2.873589348851873
batch 7 | flow #2 | EPE LOSS: 2.8735631131565
batch 7 average EPE LOSS: 2.8735762265105746


 13%|█▎        | 8/63 [12:50<1:22:36, 90.12s/it]

batch 8 | flow #1 | EPE LOSS: 3.0382382613548553
batch 8 | flow #2 | EPE LOSS: 3.038146791772329
batch 8 average EPE LOSS: 3.0381925215022534


 14%|█▍        | 9/63 [14:19<1:20:44, 89.71s/it]

batch 9 | flow #1 | EPE LOSS: 2.7668836037184334
batch 9 | flow #2 | EPE LOSS: 2.7667894609405552
batch 9 average EPE LOSS: 2.7668365219706255


 16%|█▌        | 10/63 [15:48<1:19:16, 89.74s/it]

Model saved to model_20240716-124615_epoch10.pth
batch 10 | flow #1 | EPE LOSS: 2.4285279740776127
batch 10 | flow #2 | EPE LOSS: 2.428396687670258
batch 10 average EPE LOSS: 2.428462301353687


 17%|█▋        | 11/63 [17:17<1:17:35, 89.52s/it]

batch 11 | flow #1 | EPE LOSS: 2.4646249365988258
batch 11 | flow #2 | EPE LOSS: 2.464461195391217
batch 11 average EPE LOSS: 2.4645429721152787


 19%|█▉        | 12/63 [18:47<1:16:03, 89.49s/it]

batch 12 | flow #1 | EPE LOSS: 2.6449084530256943
batch 12 | flow #2 | EPE LOSS: 2.6447026475947277
batch 12 average EPE LOSS: 2.6448049849686837


 21%|██        | 13/63 [20:15<1:14:17, 89.15s/it]

batch 13 | flow #1 | EPE LOSS: 2.5406600661570344
batch 13 | flow #2 | EPE LOSS: 2.540498659477098
batch 13 average EPE LOSS: 2.5405734615211246


 22%|██▏       | 14/63 [21:44<1:12:49, 89.17s/it]

batch 14 | flow #1 | EPE LOSS: 2.845440700293924
batch 14 | flow #2 | EPE LOSS: 2.845453338795262
batch 14 average EPE LOSS: 2.8454354812949845


 24%|██▍       | 15/63 [23:13<1:11:11, 88.99s/it]

batch 15 | flow #1 | EPE LOSS: 2.4299786496068334
batch 15 | flow #2 | EPE LOSS: 2.4303375249572765
batch 15 average EPE LOSS: 2.4301547064369675


 25%|██▌       | 16/63 [24:42<1:09:36, 88.86s/it]

batch 16 | flow #1 | EPE LOSS: 2.246387872758973
batch 16 | flow #2 | EPE LOSS: 2.2467333557478177
batch 16 average EPE LOSS: 2.246559123212028


 27%|██▋       | 17/63 [26:10<1:08:05, 88.82s/it]

batch 17 | flow #1 | EPE LOSS: 2.5308735030063403
batch 17 | flow #2 | EPE LOSS: 2.5312205530276684
batch 17 average EPE LOSS: 2.5310459425780794


 29%|██▊       | 18/63 [27:40<1:06:54, 89.22s/it]

batch 18 | flow #1 | EPE LOSS: 2.6085754483270445
batch 18 | flow #2 | EPE LOSS: 2.608953130093144
batch 18 average EPE LOSS: 2.608763362078948


 30%|███       | 19/63 [29:11<1:05:41, 89.59s/it]

batch 19 | flow #1 | EPE LOSS: 2.7964723043875646
batch 19 | flow #2 | EPE LOSS: 2.796720692607935
batch 19 average EPE LOSS: 2.796595274667605


 32%|███▏      | 20/63 [30:41<1:04:17, 89.72s/it]

Model saved to model_20240716-130107_epoch20.pth
batch 20 | flow #1 | EPE LOSS: 2.658587000239665
batch 20 | flow #2 | EPE LOSS: 2.658736168698733
batch 20 average EPE LOSS: 2.6586607164649547


 33%|███▎      | 21/63 [32:10<1:02:40, 89.54s/it]

batch 21 | flow #1 | EPE LOSS: 2.1000249757144136
batch 21 | flow #2 | EPE LOSS: 2.1002681021195277
batch 21 average EPE LOSS: 2.1001455203687085


 35%|███▍      | 22/63 [33:39<1:01:02, 89.33s/it]

batch 22 | flow #1 | EPE LOSS: 2.487792712297309
batch 22 | flow #2 | EPE LOSS: 2.488061061618193
batch 22 average EPE LOSS: 2.4879255761767896


 37%|███▋      | 23/63 [35:07<59:14, 88.86s/it]  

batch 23 | flow #1 | EPE LOSS: 2.459311224791461
batch 23 | flow #2 | EPE LOSS: 2.4594535534925623
batch 23 average EPE LOSS: 2.4593815602586395


 38%|███▊      | 24/63 [36:37<57:58, 89.19s/it]

batch 24 | flow #1 | EPE LOSS: 2.236986798157563
batch 24 | flow #2 | EPE LOSS: 2.2371097897588332
batch 24 average EPE LOSS: 2.2370468408859074


 40%|███▉      | 25/63 [38:05<56:24, 89.06s/it]

batch 25 | flow #1 | EPE LOSS: 2.663109831764366
batch 25 | flow #2 | EPE LOSS: 2.6632135171499245
batch 25 average EPE LOSS: 2.6631576587533967


 41%|████▏     | 26/63 [39:35<55:00, 89.19s/it]

batch 26 | flow #1 | EPE LOSS: 2.1505382074575206
batch 26 | flow #2 | EPE LOSS: 2.150443526728732
batch 26 average EPE LOSS: 2.1504873017924155


 43%|████▎     | 27/63 [41:04<53:31, 89.21s/it]

batch 27 | flow #1 | EPE LOSS: 3.2348084167526894
batch 27 | flow #2 | EPE LOSS: 3.2346694565968526
batch 27 average EPE LOSS: 3.234738607624671


 44%|████▍     | 28/63 [42:34<52:06, 89.32s/it]

batch 28 | flow #1 | EPE LOSS: 6.961979225804903
batch 28 | flow #2 | EPE LOSS: 6.961799806928618
batch 28 average EPE LOSS: 6.961889392112606


 46%|████▌     | 29/63 [44:03<50:39, 89.40s/it]

batch 29 | flow #1 | EPE LOSS: 2.4598365694186803
batch 29 | flow #2 | EPE LOSS: 2.4597832064928298
batch 29 average EPE LOSS: 2.459809769515961


 48%|████▊     | 30/63 [45:34<49:19, 89.70s/it]

Model saved to model_20240716-131600_epoch30.pth
batch 30 | flow #1 | EPE LOSS: 2.2468038624075533
batch 30 | flow #2 | EPE LOSS: 2.246761215456463
batch 30 average EPE LOSS: 2.2467823458239353


 49%|████▉     | 31/63 [47:03<47:47, 89.61s/it]

batch 31 | flow #1 | EPE LOSS: 2.4344783179527894
batch 31 | flow #2 | EPE LOSS: 2.4344390852540503
batch 31 average EPE LOSS: 2.434458374238545


In [None]:
save_model(model)

In [None]:
import matplotlib.pyplot as plt

# epe_losses = list(map(lambda x: x.item(), epe_losses[0]))
# overall_losses = list(map(lambda x: x.item(), overall_losses[0]))

plt.figure(figsize=(16, 9))
len_x = min(len(epe_losses), len(overall_losses))

plt.plot(epe_losses[:len_x])
plt.plot(overall_losses[:len_x])

plt.xlabel('Batch Number')
plt.ylabel('Loss')

plt.grid()
plt.legend()

plt.show()

In [None]:
import time
# Create the directory if it doesn't exist
# if not os.path.exists('checkpoints'):
#     os.makedirs('checkpoints')

current_time = time.strftime("%Y%m%d-%H%M%S")
model_path = f"models/pclnet_model_{current_time}_epoch1.pth"
torch.save(model.state_dict(), model_path)
print(f"Model saved to {model_path}")

In [35]:
model.eval()
flow: torch.Tensor = torch.tensor([]).to(device)

prev_event_volumes = [torch.zeros([1, 4, 480, 640])] * BATCH_CONCAT # Acts as a queue

with torch.no_grad():
    print("start test")
    for batch in tqdm(test_data):
        batch: Dict[str, Any]

        event_image = batch["event_volume"].to(device)
        prev_event_volumes.append(event_image)

        input_tensor = torch.stack(prev_event_volumes[-BATCH_CONCAT:], dim=1)
        _, flows = model(input_tensor)

        batch_flow = torch.mean(torch.stack(flows, dim=0), dim=0)
        flow = torch.cat((flow, batch_flow), dim=0)  # [N, 2, 480, 640]

        if len(prev_event_volumes) >= BATCH_CONCAT:
            prev_event_volumes.pop(0)

    print("test done")

start test


100%|██████████| 97/97 [17:41<00:00, 10.95s/it]

test done





In [36]:
# ------------------
#  save submission
# ------------------
current_time = time.strftime("%Y%m%d-%H%M%S")
save_optical_flow_to_npy(flow, f'submissions/submission_pclnet_{current_time}')