# In [5]:
import torch
import torchvision as tv
import torch.nn as nn
import os
import numpy as np
import pickle

# In [6]:
class LayoutModel(torch.nn.Module):
    """Extract straight-edge features from frames and run a pickled tree model on them.

    Three fixed regions of every frame are scanned for their strongest edge:
    a vertical edge in regions 1 and 2, a horizontal edge in region 3.
    The per-region edge strengths are log-transformed and passed to
    ``self.tree.predict``. The edge positions are returned unchanged next to
    the prediction.
    """

    # A pixel counts towards an edge in one direction only when its distance in
    # that direction exceeds this multiple of its distance in the other direction.
    EDGE_DOMINANCE_RATIO = 10
    # Divisor applied to the maximal region distances (presumably the 8-bit
    # pixel range -- TODO confirm inputs are 0..255).
    DIST_NORMALIZER = 255
    # Added before log() so that zero distances do not become -inf.
    LOG_EPS = 0.00001

    def __init__(self, tree_path):
        """
        :param tree_path: path to a pickle file holding the tree model. The
            unpickled object must provide ``predict(features)``. ``forward``
            calls it with a numpy array of shape (time, 3).
        """
        super().__init__()

        # SECURITY: pickle.load can execute arbitrary code stored in the file;
        # only load tree files from trusted sources.
        with open(tree_path, 'rb') as f:
            self.tree = pickle.load(f)

    def extract_features(self, image_tensor):
        """Find the strongest edge in each of three fixed frame regions.

        Regions (as fractions of the frame):
          1. columns 3/12 .. 5/12, full height   -> strongest vertical edge
          2. columns 4/12 .. 8/12, full height   -> strongest vertical edge
          3. columns 0 .. 1/3, rows 4/12 .. 8/12 -> strongest horizontal edge

        Frames so small that a region is under 3 pixels across its scanned
        axis leave nothing to reduce, so the max() calls below fail.

        :param image_tensor: shape = (time, number_of_channels=3, height, width)
        :return: ``(max_dists, positions)``, each of shape (time, 3).
            ``max_dists`` holds every region's maximal weighted mean distance
            divided by ``DIST_NORMALIZER``. ``positions`` holds where that
            maximum lies, mapped approximately back to a fraction of the full
            frame width (regions 1, 2) or height (region 3).
        """
        _, _, height, width = image_tensor.shape

        # Region 1: columns (0.25 - 0.41) of the width, full height.
        region_1 = image_tensor[:, :, :, 3 * width // 12:5 * width // 12].float()
        region_1_col_dist = self.get_mean_column_wise_diff_for_columns(region_1)

        # Region 2: columns (0.33 - 0.66) of the width, full height.
        region_2 = image_tensor[:, :, :, 4 * width // 12:8 * width // 12].float()
        region_2_col_dist = self.get_mean_column_wise_diff_for_columns(region_2)

        # Region 3: columns (0 - 0.33) of the width, rows (0.33 - 0.66) of the height.
        region_3 = image_tensor[:, :, 4 * height // 12:8 * height // 12, :width // 3].float()
        region_3_row_dist = self.get_mean_row_wise_diff_for_rows(region_3)

        # Undo the cropping: region 1 is 1/6 of the width starting at 1/4,
        # region 2 is 1/3 of the width starting at 1/3, and region 3 is 1/3 of
        # the height starting at 1/3.
        dist_1, pos_1 = self._max_and_position(region_1_col_dist, 6, 1 / 4)
        dist_2, pos_2 = self._max_and_position(region_2_col_dist, 3, 1 / 3)
        dist_3, pos_3 = self._max_and_position(region_3_row_dist, 3, 1 / 3)

        max_dists = torch.cat([dist_1, dist_2, dist_3], dim=1)
        positions = torch.cat([pos_1, pos_2, pos_3], dim=1)
        return max_dists, positions

    def _max_and_position(self, dist, span_divisor, offset):
        """Reduce a (time, n) score tensor to its maximum and that maximum's frame position.

        The position is ``argmax / n / span_divisor + offset``. This is
        approximate: index i is really column/row i + 1 of the cut, because
        the first and last ones are dropped.

        :return: ``(max_dist / DIST_NORMALIZER, position)``, each (time, 1).
        """
        max_dist, max_index = dist.max(1, keepdim=True)
        max_dist = max_dist / self.DIST_NORMALIZER
        position = max_index.float() / dist.shape[1] / span_divisor + offset
        return max_dist, position

    @staticmethod
    def _center_weights(length, device):
        """Return ``1 - t**2`` for ``length`` points t evenly spaced in [-1, 1].

        The weights are 0 at both ends and peak at the middle. For 5 points
        they are [0, 0.75, 1, 0.75, 0]. They are created on ``device`` so they
        can multiply tensors that do not live on the CPU.
        """
        return 1 - torch.linspace(start=-1, end=1, steps=length, device=device) ** 2

    def get_mean_column_wise_diff_for_columns(self, frame):
        """Score every inner column of ``frame`` by how strongly it forms a vertical edge.

        Each pixel's column-wise distance (left vs right neighbour, see
        ``calculate_column_wise_dist``) is kept only where it exceeds
        ``EDGE_DOMINANCE_RATIO`` times its row-wise distance. It is zeroed
        elsewhere. The kept distances are averaged over the height and then
        weighted by ``_center_weights`` so that columns near the middle of the
        frame score higher.

        Value range: a single pixel distance lies in [0, sqrt(3 * 255**2) ~= 441.7]
        for 0..255 inputs. The weights are in [0, 1], so the scores stay in
        that range.

        :param frame: shape = (time, number_of_channels=3, height, width)
        :return: tensor (time, width - 2). The first and last columns are
            dropped because they lack a neighbour on one side.
        """
        frame_column_d = self.calculate_column_wise_dist(frame)
        frame_row_d = self.calculate_row_wise_dist(frame)

        # Keep the column distance only where it clearly dominates the row distance.
        dominant = frame_column_d > self.EDGE_DOMINANCE_RATIO * frame_row_d
        vertical_edge_dist = frame_column_d * dominant
        # Average over the height: (time, height - 2, width - 2) -> (time, width - 2).
        mean_column_dist = vertical_edge_dist.mean(1)

        weights = self._center_weights(frame.shape[3] - 2, frame.device)
        return mean_column_dist * weights

    def get_mean_row_wise_diff_for_rows(self, frame):
        """Score every inner row of ``frame`` by how strongly it forms a horizontal edge.

        This mirrors ``get_mean_column_wise_diff_for_columns`` with the roles
        of rows and columns swapped.
        The row-wise distance is kept only where it exceeds
        ``EDGE_DOMINANCE_RATIO`` times the column-wise distance. It is averaged
        over the width and weighted so that rows near the middle score higher.

        :param frame: shape = (time, number_of_channels=3, height, width)
        :return: tensor (time, height - 2). The first and last rows are
            dropped because they lack a neighbour on one side.
        """
        frame_column_d = self.calculate_column_wise_dist(frame)
        frame_row_d = self.calculate_row_wise_dist(frame)

        # Keep the row distance only where it clearly dominates the column distance.
        dominant = frame_row_d > self.EDGE_DOMINANCE_RATIO * frame_column_d
        horizontal_edge_dist = frame_row_d * dominant
        # Average over the width: (time, height - 2, width - 2) -> (time, height - 2).
        mean_row_dist = horizontal_edge_dist.mean(2)

        weights = self._center_weights(frame.shape[2] - 2, frame.device)
        return mean_row_dist * weights

    @staticmethod
    def calculate_column_wise_dist(frame_tensor):
        """Euclidean RGB distance between each pixel's left and right neighbours.

        For the pixel in column N, with a = pixel N-1 and b = pixel N+1:
        d(a, b) = sqrt((a_r - b_r)^2 + (a_g - b_g)^2 + (a_b - b_b)^2)

        The first and last columns are skipped because they lack a neighbour
        on one side. The first and last rows are also skipped so that the
        output has the same shape as ``calculate_row_wise_dist`` and the two
        can be compared element-wise.

        :param frame_tensor: shape = (time, number_of_channels=3, height, width)
        :return: tensor of shape (time, height - 2, width - 2)
        """
        left_minus_right = frame_tensor[:, :, 1:-1, :-2] - frame_tensor[:, :, 1:-1, 2:]
        return left_minus_right.pow(2).sum(1).sqrt()

    @staticmethod
    def calculate_row_wise_dist(frame_tensor):
        """Euclidean RGB distance between each pixel's upper and lower neighbours.

        The first and last rows and columns are skipped so that the output
        matches ``calculate_column_wise_dist`` in shape.

        :param frame_tensor: shape = (time, number_of_channels=3, height, width)
        :return: tensor of shape (time, height - 2, width - 2)
        """
        top_minus_bottom = frame_tensor[:, :, :-2, 1:-1] - frame_tensor[:, :, 2:, 1:-1]
        return top_minus_bottom.pow(2).sum(1).sqrt()

    def forward(self, X):
        """Run the tree model on the log edge strengths of ``X``.

        This is invoked through ``nn.Module.__call__`` (``model(X)``), so
        module hooks apply.

        :param X: shape = (time, number_of_channels=3, height, width)
        :return: ``(predictions, positions)``. ``predictions`` is a CPU float
            tensor built from ``self.tree.predict``. ``positions`` is the
            (time, 3) tensor from ``extract_features``.
        """
        max_dists, positions = self.extract_features(X)

        log_features = (max_dists + self.LOG_EPS).log()
        # The tree model works on numpy; detach and move to CPU so that
        # inputs which require grad or live on a GPU still work.
        predictions = self.tree.predict(log_features.detach().cpu().numpy())
        return torch.Tensor(predictions), positions