In [81]:
import numpy as np
import tensorflow as tf
from itertools import chain
from time import time
rng = np.random.default_rng()

<h3>Kinetics-i3d video action classifier</h3>

See https://github.com/deepmind/kinetics-i3d

In [2]:
# after i3d.py

pad_same = "SAME"
pad_valid = "VALID"

class Unit3D(tf.Module):
    #Basic unit containing Conv3D + BatchNorm + non-linearity.

    def __init__(self, output_channels,
                   kernel_shape = [1,1,1],
                   stride=[1, 1, 1],
                   activation_fn=tf.nn.relu,
                   use_batch_norm=True,
                   use_bias=False,
                   name='unit_3d'):
        super(Unit3D, self).__init__()
        self._output_channels = output_channels
        self._kernel_shape = kernel_shape
        self._padding = pad_same
        self._stride = [1] + stride + [1]
        self._use_batch_norm = use_batch_norm
        self._activation_fn = activation_fn
        self._use_bias = use_bias
        self._name = name

        # vardict refers to a global dictionary of tf.Variables loaded from the file containing the weights
        if self._use_batch_norm:
            self.bn_beta = vardict[self._name + "/batch_norm/beta"]
            self.bn_moving_mean = vardict[self._name + "/batch_norm/moving_mean"]
            self.bn_moving_variance = vardict[self._name + "/batch_norm/moving_variance"]
        self.conv_w = vardict[self._name + "/conv_3d/w"]
        if self._use_bias:
            self.conv_b = vardict[self._name+"/conv_3d/b"]
        
    def __call__(self, inputs, is_training):
         # input shape is [batch, depth, height, width, channels]
        net = tf.nn.conv3d(inputs, filters=self.conv_w, strides=self._stride, padding=self._padding)
        if self._use_bias:
            net = tf.nn.bias_add(net, self.conv_b)
        if self._use_batch_norm:
            net = tf.nn.batch_normalization(net, 
                                            self.bn_moving_mean, 
                                            self.bn_moving_variance, 
                                            self.bn_beta, 
                                            scale=1, 
                                            variance_epsilon=0.01)
        if self._activation_fn is not None:
            net = self._activation_fn(net)
        return net

class InceptionI3d(tf.Module):
#  """Inception-v1 I3D architecture.

#  The model is introduced in:

#    Quo Vadis, Action Recognition? A New Model and the Kinetics Dataset
#    Joao Carreira, Andrew Zisserman
#    https://arxiv.org/pdf/1705.07750v1.pdf.

#  See also the Inception architecture, introduced in:

#    Going deeper with convolutions
#    Christian Szegedy, Wei Liu, Yangqing Jia, Pierre Sermanet, Scott Reed,
#    Dragomir Anguelov, Dumitru Erhan, Vincent Vanhoucke, Andrew Rabinovich.
#    http://arxiv.org/pdf/1409.4842v1.pdf.
#  """

  # Endpoints of the model in order. During construction, all the endpoints up
  # to a designated `final_endpoint` are returned in a dictionary as the
  # second return value.
    VALID_ENDPOINTS = (
      'Conv3d_1a_7x7',
      'MaxPool3d_2a_3x3',
      'Conv3d_2b_1x1',
      'Conv3d_2c_3x3',
      'MaxPool3d_3a_3x3',
      'Mixed_3b',
      'Mixed_3c',
      'MaxPool3d_4a_3x3',
      'Mixed_4b',
      'Mixed_4c',
      'Mixed_4d',
      'Mixed_4e',
      'Mixed_4f',
      'MaxPool3d_5a_2x2',
      'Mixed_5b',
      'Mixed_5c',
      'Logits',
      'Predictions',
  )
    
    # In the paper referenced above, notations are made of the receptive field after each pooling layer, i.e. the
    # size of the input data that each of its outputs depends on. They are (time is the first dim listed):
    #   MaxPool3d_2a_3x3   7x11x11
    #   MaxPool3d_3a_3x3   11x27x27
    #   MaxPool3d_4a_3x3   23x75x75
    #   MaxPool3d_5a_2x2   59x219x219
    #   AvgPool3d_2x7x7    99x539x539
    # The AvgPool layer is immediately prior to the logits (which are linear combinations of its outputs).
    # Since the net uses only convolutional layers the dimensions of its input are not fixed (they can even vary
    # between calls, since the net calls tf.nn.conv3d directly rather than keras.layers.Conv3D).
    

    def __init__(self, var_prefix='RGB', num_classes=400, spatial_squeeze=True,
               final_endpoint='Logits', name='inception_i3d'):
#    """Initializes I3D model instance.

#    Args:
#      num_classes: The number of outputs in the logit layer (default 400, which
#          matches the Kinetics dataset).
#      spatial_squeeze: Whether to squeeze the spatial dimensions for the logits
#          before returning (default True).
#      final_endpoint: The model contains many possible endpoints.
#          `final_endpoint` specifies the last endpoint for the model to be built
#          up to. In addition to the output at `final_endpoint`, all the outputs
#          at endpoints up to `final_endpoint` will also be returned, in a
#          dictionary. `final_endpoint` must be one of
#          InceptionI3d.VALID_ENDPOINTS (default 'Logits').
#      name: A string (optional). The name of this module.
#    Raises:
#      ValueError: if `final_endpoint` is not recognized.
#    """

        if final_endpoint not in self.VALID_ENDPOINTS:
            raise ValueError('Unknown final endpoint %s' % final_endpoint)

        super(InceptionI3d, self).__init__(name=name)
        self._num_classes = num_classes
        self._spatial_squeeze = spatial_squeeze
        self._final_endpoint = final_endpoint
        self._var_prefix = var_prefix

        # except for the first and last entries here all this (output channels and kernel shape) is already implicit 
        # in the weights passed to the modules. the important part is the correspondence between modules and the 
        # names of variables contained in the checkpoint data
        arg_dict = {'Conv3d_1a_7x7' : {'output_channels': 64, 'kernel_shape' : [7,7,7], 'stride' : [2,2,2]}, 
                     'Conv3d_2b_1x1' : {'output_channels' : 64, 'kernel_shape' : [1, 1, 1]},
                     'Conv3d_2c_3x3' : {'output_channels' : 192, 'kernel_shape' : [3, 3, 3]},
                     'Mixed_3b/Branch_0/Conv3d_0a_1x1' : {'output_channels' : 64, 'kernel_shape' : [1, 1, 1]},
                     'Mixed_3b/Branch_1/Conv3d_0a_1x1' : {'output_channels' : 96, 'kernel_shape' : [1, 1, 1]},
                     'Mixed_3b/Branch_1/Conv3d_0b_3x3' : {'output_channels' : 128, 'kernel_shape' : [3, 3, 3]},
                     'Mixed_3b/Branch_2/Conv3d_0a_1x1' : {'output_channels' : 16, 'kernel_shape' : [1, 1, 1]},
                     'Mixed_3b/Branch_2/Conv3d_0b_3x3' : {'output_channels' : 32, 'kernel_shape' : [3, 3, 3]},
                     'Mixed_3b/Branch_3/Conv3d_0b_1x1' : {'output_channels' : 32, 'kernel_shape' : [1, 1, 1]},
                     'Mixed_3c/Branch_0/Conv3d_0a_1x1' : {'output_channels' : 128, 'kernel_shape' : [1, 1, 1]},
                     'Mixed_3c/Branch_1/Conv3d_0a_1x1' : {'output_channels' : 128, 'kernel_shape' : [1, 1, 1]},
                     'Mixed_3c/Branch_1/Conv3d_0b_3x3' : {'output_channels' : 192, 'kernel_shape' : [3, 3, 3]},
                     'Mixed_3c/Branch_2/Conv3d_0a_1x1' : {'output_channels' : 32, 'kernel_shape' : [1, 1, 1]},
                     'Mixed_3c/Branch_2/Conv3d_0b_3x3' : {'output_channels' : 96, 'kernel_shape' : [3, 3, 3]},
                     'Mixed_3c/Branch_3/Conv3d_0b_1x1' : {'output_channels' : 64, 'kernel_shape' : [1, 1, 1]},
                     'Mixed_4b/Branch_0/Conv3d_0a_1x1' : {'output_channels' : 192, 'kernel_shape' : [1, 1, 1]},
                     'Mixed_4b/Branch_1/Conv3d_0a_1x1' : {'output_channels' : 96, 'kernel_shape' : [1, 1, 1]},
                     'Mixed_4b/Branch_1/Conv3d_0b_3x3' : {'output_channels' : 208, 'kernel_shape' : [3, 3, 3]},
                     'Mixed_4b/Branch_2/Conv3d_0a_1x1' : {'output_channels' : 16, 'kernel_shape' : [1, 1, 1]},
                     'Mixed_4b/Branch_2/Conv3d_0b_3x3' : {'output_channels' : 48, 'kernel_shape' : [3, 3, 3]},
                     'Mixed_4b/Branch_3/Conv3d_0b_1x1' : {'output_channels' : 64, 'kernel_shape' : [1, 1, 1]},
                     'Mixed_4c/Branch_0/Conv3d_0a_1x1' : {'output_channels' : 160, 'kernel_shape' : [1, 1, 1]},
                     'Mixed_4c/Branch_1/Conv3d_0a_1x1' : {'output_channels' : 112, 'kernel_shape' : [1, 1, 1]},
                     'Mixed_4c/Branch_1/Conv3d_0b_3x3' : {'output_channels' : 224, 'kernel_shape' : [3, 3, 3]},
                     'Mixed_4c/Branch_2/Conv3d_0a_1x1' : {'output_channels' : 24, 'kernel_shape' : [1, 1, 1]},
                     'Mixed_4c/Branch_2/Conv3d_0b_3x3' : {'output_channels' : 64, 'kernel_shape' : [3, 3, 3]},
                     'Mixed_4c/Branch_3/Conv3d_0b_1x1' : {'output_channels' : 64, 'kernel_shape' : [1, 1, 1]},
                     'Mixed_4d/Branch_0/Conv3d_0a_1x1' : {'output_channels' : 128, 'kernel_shape' : [1, 1, 1]},
                     'Mixed_4d/Branch_1/Conv3d_0a_1x1' : {'output_channels' : 128, 'kernel_shape' : [1, 1, 1]},
                     'Mixed_4d/Branch_1/Conv3d_0b_3x3' : {'output_channels' : 256, 'kernel_shape' : [3, 3, 3]},
                     'Mixed_4d/Branch_2/Conv3d_0a_1x1' : {'output_channels' : 24, 'kernel_shape' : [1, 1, 1]},
                     'Mixed_4d/Branch_2/Conv3d_0b_3x3' : {'output_channels' : 64, 'kernel_shape' : [3, 3, 3]},
                     'Mixed_4d/Branch_3/Conv3d_0b_1x1' : {'output_channels' : 64, 'kernel_shape' : [1, 1, 1]},
                     'Mixed_4e/Branch_0/Conv3d_0a_1x1' : {'output_channels' : 112, 'kernel_shape' : [1, 1, 1]},
                     'Mixed_4e/Branch_1/Conv3d_0a_1x1' : {'output_channels' : 144, 'kernel_shape' : [1, 1, 1]},
                     'Mixed_4e/Branch_1/Conv3d_0b_3x3' : {'output_channels' : 288, 'kernel_shape' : [3, 3, 3]},
                     'Mixed_4e/Branch_2/Conv3d_0a_1x1' : {'output_channels' : 32, 'kernel_shape' : [1, 1, 1]},
                     'Mixed_4e/Branch_2/Conv3d_0b_3x3' : {'output_channels' : 64, 'kernel_shape' : [3, 3, 3]},
                     'Mixed_4e/Branch_3/Conv3d_0b_1x1' : {'output_channels' : 64, 'kernel_shape' : [1, 1, 1]},
                     'Mixed_4f/Branch_0/Conv3d_0a_1x1' : {'output_channels' : 256, 'kernel_shape' : [1, 1, 1]},
                     'Mixed_4f/Branch_1/Conv3d_0a_1x1' : {'output_channels' : 160, 'kernel_shape' : [1, 1, 1]},
                     'Mixed_4f/Branch_1/Conv3d_0b_3x3' : {'output_channels' : 320, 'kernel_shape' : [3, 3, 3]},
                     'Mixed_4f/Branch_2/Conv3d_0a_1x1' : {'output_channels' : 32, 'kernel_shape' : [1, 1, 1]},
                     'Mixed_4f/Branch_2/Conv3d_0b_3x3' : {'output_channels' : 128, 'kernel_shape' : [3, 3, 3]},
                     'Mixed_4f/Branch_3/Conv3d_0b_1x1' : {'output_channels' : 128, 'kernel_shape' : [1, 1, 1]},
                     'Mixed_5b/Branch_0/Conv3d_0a_1x1' : {'output_channels' : 256, 'kernel_shape' : [1, 1, 1]},
                     'Mixed_5b/Branch_1/Conv3d_0a_1x1' : {'output_channels' : 160, 'kernel_shape' : [1, 1, 1]},
                     'Mixed_5b/Branch_1/Conv3d_0b_3x3' : {'output_channels' : 320, 'kernel_shape' : [3, 3, 3]},
                     'Mixed_5b/Branch_2/Conv3d_0a_1x1' : {'output_channels' : 32, 'kernel_shape' : [1, 1, 1]},
                    # typo here: in other modules the name is Branch2/Conv3d_0b_3x3 !
                     'Mixed_5b/Branch_2/Conv3d_0a_3x3' : {'output_channels' : 128, 'kernel_shape' : [3, 3, 3]},
                     'Mixed_5b/Branch_3/Conv3d_0b_1x1' : {'output_channels' : 128, 'kernel_shape' : [1, 1, 1]},
                     'Mixed_5c/Branch_0/Conv3d_0a_1x1' : {'output_channels' : 384, 'kernel_shape' : [1, 1, 1]},
                     'Mixed_5c/Branch_1/Conv3d_0a_1x1' : {'output_channels' : 192, 'kernel_shape' : [1, 1, 1]},
                     'Mixed_5c/Branch_1/Conv3d_0b_3x3' : {'output_channels' : 384, 'kernel_shape' : [3, 3, 3]},
                     'Mixed_5c/Branch_2/Conv3d_0a_1x1' : {'output_channels' : 48, 'kernel_shape' : [1, 1, 1]},
                     'Mixed_5c/Branch_2/Conv3d_0b_3x3' : {'output_channels' : 128, 'kernel_shape' : [3, 3, 3]},
                     'Mixed_5c/Branch_3/Conv3d_0b_1x1' : {'output_channels' : 128, 'kernel_shape' : [1, 1, 1]},
                     'Logits/Conv3d_0c_1x1' : {'output_channels' : self._num_classes, 'kernel_shape' : [1, 1, 1],
                                              'activation_fn' : None, 'use_batch_norm' : False, 'use_bias' : True}
                    }
        
        self.module_dict = {}
        model_prefix = self._var_prefix + "/" + self.name + "/" 
        for module_name in list(arg_dict.keys()):
            self.module_dict[module_name] = Unit3D(**arg_dict[module_name],
                                              name = model_prefix+module_name)

    def __call__(self, inputs, is_training = False, dropout_prob=0.0):
#    """Connects the model to inputs.

#    Args:
#      inputs: Inputs to the model, which should have dimensions
#          `batch_size` x `num_frames` x 224 x 224 x `num_channels`.
#      is_training: whether to use training mode for snt.BatchNorm (boolean).
#      dropout_prob: Probability for the tf.nn.dropout layer (float in
#          [0, 1)).

#    Returns:
#      A tuple consisting of:
#        1. Network output at location `self._final_endpoint`.
#        2. Dictionary containing all endpoints up to `self._final_endpoint`,
#           indexed by endpoint name.

#    Raises:
#      ValueError: if `self._final_endpoint` is not recognized.
#    """
        if self._final_endpoint not in self.VALID_ENDPOINTS:
            raise ValueError('Unknown final endpoint %s' % self._final_endpoint)

        net = inputs
        end_points = {}
        end_point = 'Conv3d_1a_7x7'        
        net = self.module_dict[end_point](net, is_training=is_training)
        end_points[end_point] = net
        if self._final_endpoint == end_point: return net, end_points
        
        end_point = 'MaxPool3d_2a_3x3'
        net = tf.nn.max_pool3d(net, ksize=[1, 1, 3, 3, 1], strides=[1, 1, 2, 2, 1],
                               padding=pad_same, name=end_point)
        end_points[end_point] = net
        if self._final_endpoint == end_point: return net, end_points
        end_point = 'Conv3d_2b_1x1'
        net = self.module_dict[end_point](net, is_training=is_training)
        end_points[end_point] = net
        if self._final_endpoint == end_point: return net, end_points
        end_point = 'Conv3d_2c_3x3'
        net = self.module_dict[end_point](net, is_training=is_training)
        end_points[end_point] = net
        if self._final_endpoint == end_point: return net, end_points
        end_point = 'MaxPool3d_3a_3x3'
        net = tf.nn.max_pool3d(net, ksize=[1, 1, 3, 3, 1], strides=[1, 1, 2, 2, 1],
                               padding=pad_same, name=end_point)
        end_points[end_point] = net
        if self._final_endpoint == end_point: return net, end_points

        end_point = 'Mixed_3b'
        branch_0 = self.module_dict[end_point+'/Branch_0/Conv3d_0a_1x1'](net, is_training=is_training)
        branch_1 = self.module_dict[end_point+'/Branch_1/Conv3d_0a_1x1'](net, is_training=is_training)
        branch_1 = self.module_dict[end_point+'/Branch_1/Conv3d_0b_3x3'](branch_1, is_training=is_training)
        branch_2 = self.module_dict[end_point+'/Branch_2/Conv3d_0a_1x1'](net, is_training=is_training)
        branch_2 = self.module_dict[end_point+'/Branch_2/Conv3d_0b_3x3'](branch_2, is_training=is_training)
        branch_3 = tf.nn.max_pool3d(net, ksize=[1, 3, 3, 3, 1],
                                            strides=[1, 1, 1, 1, 1], padding=pad_same)
        branch_3 = self.module_dict[end_point+'/Branch_3/Conv3d_0b_1x1'](branch_3, is_training=is_training)

        net = tf.concat([branch_0, branch_1, branch_2, branch_3], 4)
        end_points[end_point] = net
        if self._final_endpoint == end_point: return net, end_points

        end_point = 'Mixed_3c'
        branch_0 = self.module_dict[end_point+'/Branch_0/Conv3d_0a_1x1'](net, is_training=is_training)
        branch_1 = self.module_dict[end_point+'/Branch_1/Conv3d_0a_1x1'](net, is_training=is_training)
        branch_1 = self.module_dict[end_point+'/Branch_1/Conv3d_0b_3x3'](branch_1, is_training=is_training)
        branch_2 = self.module_dict[end_point+'/Branch_2/Conv3d_0a_1x1'](net, is_training=is_training)
        branch_2 = self.module_dict[end_point+'/Branch_2/Conv3d_0b_3x3'](branch_2, is_training=is_training)
        branch_3 = tf.nn.max_pool3d(net, ksize=[1, 3, 3, 3, 1],
                                            strides=[1, 1, 1, 1, 1], padding=pad_same)
        branch_3 = self.module_dict[end_point+'/Branch_3/Conv3d_0b_1x1'](branch_3, is_training=is_training)
        net = tf.concat([branch_0, branch_1, branch_2, branch_3], 4)
        end_points[end_point] = net
        if self._final_endpoint == end_point: return net, end_points

        end_point = 'MaxPool3d_4a_3x3'
        # modified: was time stride = 2
        net = tf.nn.max_pool3d(net, ksize=[1, 3, 3, 3, 1], strides=[1, 1, 2, 2, 1],
                               padding=pad_same, name=end_point)
        end_points[end_point] = net
        if self._final_endpoint == end_point: return net, end_points

        end_point = 'Mixed_4b'
        branch_0 = self.module_dict[end_point+'/Branch_0/Conv3d_0a_1x1'](net, is_training=is_training)
        branch_1 = self.module_dict[end_point+'/Branch_1/Conv3d_0a_1x1'](net, is_training=is_training)
        branch_1 = self.module_dict[end_point+'/Branch_1/Conv3d_0b_3x3'](branch_1, is_training=is_training)
        branch_2 = self.module_dict[end_point+'/Branch_2/Conv3d_0a_1x1'](net, is_training=is_training)
        branch_2 = self.module_dict[end_point+'/Branch_2/Conv3d_0b_3x3'](branch_2, is_training=is_training)
        branch_3 = tf.nn.max_pool3d(net, ksize=[1, 3, 3, 3, 1],
                                            strides=[1, 1, 1, 1, 1], padding=pad_same)
        branch_3 = self.module_dict[end_point+'/Branch_3/Conv3d_0b_1x1'](branch_3, is_training=is_training)
        net = tf.concat([branch_0, branch_1, branch_2, branch_3], 4)
        end_points[end_point] = net
        if self._final_endpoint == end_point: return net, end_points

        end_point = 'Mixed_4c'
        branch_0 = self.module_dict[end_point+'/Branch_0/Conv3d_0a_1x1'](net, is_training=is_training)
        branch_1 = self.module_dict[end_point+'/Branch_1/Conv3d_0a_1x1'](net, is_training=is_training)
        branch_1 = self.module_dict[end_point+'/Branch_1/Conv3d_0b_3x3'](branch_1, is_training=is_training)
        branch_2 = self.module_dict[end_point+'/Branch_2/Conv3d_0a_1x1'](net, is_training=is_training)
        branch_2 = self.module_dict[end_point+'/Branch_2/Conv3d_0b_3x3'](branch_2, is_training=is_training)
        branch_3 = tf.nn.max_pool3d(net, ksize=[1, 3, 3, 3, 1],
                                            strides=[1, 1, 1, 1, 1], padding=pad_same)
        branch_3 = self.module_dict[end_point+'/Branch_3/Conv3d_0b_1x1'](branch_3, is_training=is_training)
        net = tf.concat([branch_0, branch_1, branch_2, branch_3], 4)
        end_points[end_point] = net
        if self._final_endpoint == end_point: return net, end_points

        end_point = 'Mixed_4d'
        branch_0 = self.module_dict[end_point+'/Branch_0/Conv3d_0a_1x1'](net, is_training=is_training)
        branch_1 = self.module_dict[end_point+'/Branch_1/Conv3d_0a_1x1'](net, is_training=is_training)
        branch_1 = self.module_dict[end_point+'/Branch_1/Conv3d_0b_3x3'](branch_1, is_training=is_training)
        branch_2 = self.module_dict[end_point+'/Branch_2/Conv3d_0a_1x1'](net, is_training=is_training)
        branch_2 = self.module_dict[end_point+'/Branch_2/Conv3d_0b_3x3'](branch_2, is_training=is_training)
        branch_3 = tf.nn.max_pool3d(net, ksize=[1, 3, 3, 3, 1],
                                            strides=[1, 1, 1, 1, 1], padding=pad_same)
        branch_3 = self.module_dict[end_point+'/Branch_3/Conv3d_0b_1x1'](branch_3, is_training=is_training)
        net = tf.concat([branch_0, branch_1, branch_2, branch_3], 4)
        end_points[end_point] = net
        if self._final_endpoint == end_point: return net, end_points

        end_point = 'Mixed_4e'
        branch_0 = self.module_dict[end_point+'/Branch_0/Conv3d_0a_1x1'](net, is_training=is_training)
        branch_1 = self.module_dict[end_point+'/Branch_1/Conv3d_0a_1x1'](net, is_training=is_training)
        branch_1 = self.module_dict[end_point+'/Branch_1/Conv3d_0b_3x3'](branch_1, is_training=is_training)
        branch_2 = self.module_dict[end_point+'/Branch_2/Conv3d_0a_1x1'](net, is_training=is_training)
        branch_2 = self.module_dict[end_point+'/Branch_2/Conv3d_0b_3x3'](branch_2, is_training=is_training)
        branch_3 = tf.nn.max_pool3d(net, ksize=[1, 3, 3, 3, 1],
                                            strides=[1, 1, 1, 1, 1], padding=pad_same)
        branch_3 = self.module_dict[end_point+'/Branch_3/Conv3d_0b_1x1'](branch_3, is_training=is_training)
        net = tf.concat([branch_0, branch_1, branch_2, branch_3], 4)
        end_points[end_point] = net
        if self._final_endpoint == end_point: return net, end_points

        end_point = 'Mixed_4f'
        branch_0 = self.module_dict[end_point+'/Branch_0/Conv3d_0a_1x1'](net, is_training=is_training)
        branch_1 = self.module_dict[end_point+'/Branch_1/Conv3d_0a_1x1'](net, is_training=is_training)
        branch_1 = self.module_dict[end_point+'/Branch_1/Conv3d_0b_3x3'](branch_1, is_training=is_training)
        branch_2 = self.module_dict[end_point+'/Branch_2/Conv3d_0a_1x1'](net, is_training=is_training)
        branch_2 = self.module_dict[end_point+'/Branch_2/Conv3d_0b_3x3'](branch_2, is_training=is_training)
        branch_3 = tf.nn.max_pool3d(net, ksize=[1, 3, 3, 3, 1],
                                            strides=[1, 1, 1, 1, 1], padding=pad_same)
        branch_3 = self.module_dict[end_point+'/Branch_3/Conv3d_0b_1x1'](branch_3, is_training=is_training)
        net = tf.concat([branch_0, branch_1, branch_2, branch_3], 4)
        end_points[end_point] = net
        if self._final_endpoint == end_point: return net, end_points

        end_point = 'MaxPool3d_5a_2x2'
        net = tf.nn.max_pool3d(net, ksize=[1, 2, 2, 2, 1], strides=[1, 1, 2, 2, 1],
                               padding=pad_same, name=end_point)
        end_points[end_point] = net
        if self._final_endpoint == end_point: return net, end_points

        end_point = 'Mixed_5b'
        branch_0 = self.module_dict[end_point+'/Branch_0/Conv3d_0a_1x1'](net, is_training=is_training)
        branch_1 = self.module_dict[end_point+'/Branch_1/Conv3d_0a_1x1'](net, is_training=is_training)
        branch_1 = self.module_dict[end_point+'/Branch_1/Conv3d_0b_3x3'](branch_1, is_training=is_training)
        branch_2 = self.module_dict[end_point+'/Branch_2/Conv3d_0a_1x1'](net, is_training=is_training)
        # typo here: in other modules the name is Branch2/Conv3d_0b_3x3 !
        branch_2 = self.module_dict[end_point+'/Branch_2/Conv3d_0a_3x3'](branch_2, is_training=is_training)
        branch_3 = tf.nn.max_pool3d(net, ksize=[1, 3, 3, 3, 1],
                                            strides=[1, 1, 1, 1, 1], padding=pad_same)
        branch_3 = self.module_dict[end_point+'/Branch_3/Conv3d_0b_1x1'](branch_3, is_training=is_training)
        net = tf.concat([branch_0, branch_1, branch_2, branch_3], 4)
        end_points[end_point] = net
        if self._final_endpoint == end_point: return net, end_points

        end_point = 'Mixed_5c'
        branch_0 = self.module_dict[end_point+'/Branch_0/Conv3d_0a_1x1'](net, is_training=is_training)
        branch_1 = self.module_dict[end_point+'/Branch_1/Conv3d_0a_1x1'](net, is_training=is_training)
        branch_1 = self.module_dict[end_point+'/Branch_1/Conv3d_0b_3x3'](branch_1, is_training=is_training)
        branch_2 = self.module_dict[end_point+'/Branch_2/Conv3d_0a_1x1'](net, is_training=is_training)
        branch_2 = self.module_dict[end_point+'/Branch_2/Conv3d_0b_3x3'](branch_2, is_training=is_training)
        branch_3 = tf.nn.max_pool3d(net, ksize=[1, 3, 3, 3, 1],
                                            strides=[1, 1, 1, 1, 1], padding=pad_same)
        branch_3 = self.module_dict[end_point+'/Branch_3/Conv3d_0b_1x1'](branch_3, is_training=is_training)
        net = tf.concat([branch_0, branch_1, branch_2, branch_3], 4)
        end_points[end_point] = net
        if self._final_endpoint == end_point: return net, end_points

        end_point = 'Logits'
        net = tf.nn.avg_pool3d(net, ksize=[1, 2, 7, 7, 1],
                                 strides=[1, 1, 1, 1, 1], padding=pad_valid)
        net = tf.nn.dropout(net, dropout_prob)
        logits = self.module_dict[end_point+'/Conv3d_0c_1x1'](net, is_training=is_training)
        if self._spatial_squeeze:
            logits = tf.squeeze(logits, [2, 3], name='SpatialSqueeze')
        averaged_logits = tf.reduce_mean(logits, axis=1)
        end_points[end_point] = averaged_logits
        if self._final_endpoint == end_point: return averaged_logits, end_points

        end_point = 'Predictions'
        predictions = tf.nn.softmax(averaged_logits)
        end_points[end_point] = predictions
        return predictions, end_points

In [3]:
fine_tuning=False

# paths to pre-trained i3d models provided by the paper's authors
_CHECKPOINT_PATHS = {
    'rgb': 'i3d/data/checkpoints/rgb_scratch/model.ckpt',
    'rgb600': 'i3d/data/checkpoints/rgb_scratch_kin600/model.ckpt',
    'flow': 'i3d/data/checkpoints/flow_scratch/model.ckpt',
    'rgb_imagenet': 'i3d/data/checkpoints/rgb_imagenet/model.ckpt',
    'flow_imagenet': 'i3d/data/checkpoints/flow_imagenet/model.ckpt',
}

# this list has the names of all the weights in the Flow net and their shapes
flow_varlist = tf.train.list_variables(_CHECKPOINT_PATHS['flow_imagenet'])
flow_vardict = {}
# make variables to load the saved weights into
for variable in flow_varlist:
    flow_vardict[variable[0]] = tf.Variable(initial_value = np.zeros(variable[1], dtype=np.float32),
                                            shape=tf.TensorShape(variable[1]),
                                            trainable=fine_tuning,
                                            name=variable[0])
    
flow_saver = tf.compat.v1.train.Saver(var_list=flow_vardict)
flow_saver.restore(sess=None, save_path=_CHECKPOINT_PATHS['flow_imagenet'])

rgb_varlist = tf.train.list_variables(_CHECKPOINT_PATHS['rgb_imagenet'])
rgb_vardict = {}
for variable in rgb_varlist:
    rgb_vardict[variable[0]] = tf.Variable(initial_value = np.zeros(variable[1], dtype=np.float32),
                                           shape=tf.TensorShape(variable[1]), 
                                           trainable=fine_tuning,
                                           name=variable[0])
rgb_saver = tf.compat.v1.train.Saver(var_list=rgb_vardict, reshape=True)
rgb_saver.restore(sess=None, save_path=_CHECKPOINT_PATHS['rgb_imagenet'])

# now vardict will contain all the weights
vardict = {}
vardict.update(rgb_vardict)
vardict.update(flow_vardict)

# some warning messages about deprecated functions appear here, they're irrelevant

INFO:tensorflow:Restoring parameters from i3d/data/checkpoints/flow_imagenet/model.ckpt
INFO:tensorflow:Restoring parameters from i3d/data/checkpoints/rgb_imagenet/model.ckpt


<h3>Video FS recog</h3>

The Encoder module runs the i3d network on image (RGB) frames of fingerspelling sequences together with optical flow frames computed from them, up to a specified endpoint within the net architecture. All the i3d layers use padding to keep shape, so the reductions in time depth come only at the layers with time stride larger than 1. These are the first layer (conv 7x7x7 stride 2x2x2), and the pooling layers after the second and seventh inception modules,

    MaxPool3d_4a_3x3, MaxPool3d_5a_3x3
    
Consequently after layer 4a the time depth has been reduced by a factor of 4, and after layer 5a the time depth has been reduced by a factor of 8. This means that the network can output potentially one character per 4 or 8 frames, respectively. In our data set the average number of frames per character is about 5.7, but a quarter of sequences have fewer than 4 frames per character. Since we are using the i3d net for its feature abstraction capabilities (rather than the classification task for which it was trained) it appears reasonable to decrease the time strides of the pooling layers, in which case the network can emit up to one character per 2 frames.

The output of the Encoder module is fed into the Cogitator module, a sequence of LSTM cells. The time depth is unchanged here. Finally, the Decoder module assigns a probability distribution (over the alphabet) to each output frame. (More precisely, these modules are run on two streams of hand-detected inputs from the frame sequence. The output depths of the two streams are the same.) The alphabet includes the special 'blank' symbol which is meant to indicate a break in the input sequence or just a lack of result. Most often in this dataset only one hand in the frame sequence is actually signing, and the output of the other stream should tend to be just a sequence of 'blanks'. From the decoder output a CTC loss is computed, and the network weights are updated. (At least for now, the Encoder module doesn't contain any trainable weights.) 

Unfortunately for our purposes, in the basic CTC algorithm the 'blank' character performs two functions. Generally the network's output characters are emitted faster than letters are formed. The intended behaviour in this case is that the network should continue emitting the character it sees in the frame sequence until something changes; a later step in the calculation of the CTC loss collapses these repeats, but repeated characters separated by a blank are not merged. This means that interleaving the outputs of the two input streams is not a neutral operation: the stream watching a hand that does not sign may emit only blanks, but putting these into the output sequence of the stream watching a signing hand causes the repetitions it emits not to be merged as intended. Elaborations of the CTC set-up have long allowed for multiple 'blank' symbols and other special punctuation, which would remove this problem, but TensorFlow does not implement this. (The change is not trivial: CTC computes the probability of a given label from the probabilities of all the output sequences which reduce to that label, of which there are many, by a dynamic programming 'forward-backward' method. Adding another 'blank' symbol requires updating this calculation to recognise the new outputs which collapse to a given label. Changing the processing to remove the 'separator' functionality of blanks, and merging repeated characters even across them, presents the same difficulty.) 

In [88]:
l2_coef = 1/2048

def random_flip(seq):
    if tf.random.categorical(tf.math.log([[0.5, 0.5]]), 1).numpy()[0][0]:
        #return tf.raw_ops.Reverse(tensor=seq, dims=[False,False,True,False])
        return tf.raw_ops.Reverse(tensor=seq, dims=[False,False,False,True,False])
    return seq

def random_augment(seq):
    seq = random_flip(seq)
    seq = tf.image.random_brightness(seq, 0.15)
    seq = tf.image.random_saturation(seq, 0.85, 1.15)
    seq = tf.image.random_contrast(seq, 0.85, 1.15)
    #seq = tf.image.random_hue(seq, 0.01)
    return tf.raw_ops.ClipByValue(t=seq, clip_value_min=0, clip_value_max=1)

In [97]:
def softsign_shift(x):
    return tf.nn.softsign(x) + 0.5

class testSpeller3dEncoder(tf.Module):
    def __init__(self, i3d_endpoint='MaxPool3d_4a_3x3', fine_tuning=False):
        super(testSpeller3dEncoder, self).__init__()
        self.fine_tuning = fine_tuning
        self.rgb = InceptionI3d(var_prefix='RGB', final_endpoint=i3d_endpoint)
        self.flow = InceptionI3d(var_prefix='Flow',final_endpoint=i3d_endpoint)
        
    def __call__(self, inputs, is_training=False, return_dict=False):
        # input shape [batch, time, height, width, channels]
        rgb_scaled = random_augment(tf.image.convert_image_dtype(inputs[0].to_tensor(), dtype=tf.float32))
        rgb_results = self.rgb(rgb_scaled, self.fine_tuning)
        # i3d expects flow in range [-1,1]
        flow_results = self.flow(2*random_flip(inputs[1].to_tensor())-1, self.fine_tuning)
        if return_dict:
            return rgb_results[1], flow_results[1]
        return rgb_results[0], flow_results[0]
    
class testSpellerAttender(tf.Module):
    def __init__(self, units=256, encoder_size = 480):
        self.dense1 = tf.keras.layers.Dense(units, activation='relu')
        self.dense2 = tf.keras.layers.Dense(units, activation='relu')
        self.dense_rgb = tf.keras.layers.Dense(encoder_size, activation='softmax')
        self.dense_flow = tf.keras.layers.Dense(encoder_size, activation='softmax')
        
    def __call__(self, geom, is_training=True):
        # pool w/ stride 2 in time dim to agree with depth reduction in Encoder
        out = self.dense2(self.dense1(geom.to_tensor()))
        return tf.nn.avg_pool(self.dense_rgb(out), ksize=[1,2,1], strides=[1,2,1], padding='SAME'), tf.nn.avg_pool(self.dense_flow(out), ksize=[1,2,1], strides=[1,2,1], padding='SAME')
    
class testSpellerCogitator(tf.Module):
    def __init__(self, units=256):
        super(testSpellerCogitator, self).__init__()
        # i/o shape [batch, timesteps, channels] with return_sequences on
        # out shape [batch, units] with return_sequences off
        self.lstm1 = tf.keras.layers.LSTM(units, return_sequences=True, recurrent_activation=softsign_shift)
        self.lstm2 = tf.keras.layers.LSTM(units, return_sequences=True, recurrent_activation=softsign_shift)
    
    def __call__(self, inputs, is_training=True):
        # an option here is to propagate the LSTM state as well; this amounts to making it a recurrent layer
        return self.lstm2(self.lstm1(inputs))
    
class testSpellerDecoder(tf.Module):
    def __init__(self, units=256, labels=32):
        super(testSpellerDecoder,self).__init__()
        self.class1 = tf.keras.layers.Dense(units, activation = 'relu')
        self.class2 = tf.keras.layers.Dense(units, activation = 'relu')
        self.class3 = tf.keras.layers.Dense(labels)
        
    def __call__(self, input_, is_training=True):
        return self.class3(self.class2(self.class1(input_)))
    
learning_rate=0.0001
lr_schedule = tf.keras.optimizers.schedules.ExponentialDecay(
                    learning_rate,
                    decay_steps=500,
                    decay_rate=0.9,
                    staircase=True)

#opt=tf.keras.optimizers.Adam(learning_rate=lr_schedule)
opt=tf.keras.optimizers.Nadam(learning_rate=0.0001)
# Nadam, adam with nesterov momentum!

In [78]:
batch_size = 8 # or 32, 64, 128, 256; we should increase this depending on how many processors/threads will be running
buffer_size = 32 # should be at least batch_size; only used for shuffling the dataset, so not a critical parameter

alphabet_size = 32 # 26 letters, 5 punctuation symbols, and the 'blank' symbol

attender_units = 64
cogitator_units = 128
decoder_units = 128

i3d_endpoint='MaxPool3d_4a_3x3'
encoder_output = 480
# later endpoint options include:
#i3d_endpoint = 'MaxPool3d_5a_2x2'
#i3d_endpoint = 'Mixed_5c'
timings = np.zeros(6)

#testSpell2dEncode = testSpeller2dEncoder((64,64,128,128))
testSpell3dEncode = testSpeller3dEncoder(i3d_endpoint = i3d_endpoint)
testSpellAttend = testSpellerAttender(attender_units, encoder_output)
testSpellCogitate = testSpellerCogitator(units = cogitator_units)
testSpellDecode = testSpellerDecoder(units = decoder_units, labels=alphabet_size)

In [14]:
base_cfsw = "h:/project/cfsw/"
base_cfswp = "h:/project/cfswp/"

# this makes a small difference to the data shuffling; it should be at least batch_size

#buffer_size = batch_size

def load_fs_batches(start_batch, endpoint, 
                   image_in, geom_in, label_in, image_tag,
                   batch_size = batch_size, image_size=128, file_size=256, buffer_size=32):
    for batch_number in range(start_batch, endpoint):
        suffix = str(batch_number) + ".npy"
        image_suffix = str(batch_number) + image_tag + ".npy"
        
        rgb_a = np.load(image_in + "train_img_a_" + image_suffix)
        flow_a = np.load(image_in + "train_flow_a_" + image_suffix)
        geom_a = np.concatenate([np.load(geom_in + "train_geom_img_a_" + suffix), # 21 * 3 floats per frame per file
                                 np.load(geom_in + "train_geom_wrl_a_" + suffix)], axis=1)
        rgb_b = np.load(image_in + "train_img_b_" + image_suffix)
        flow_b = np.load(image_in + "train_flow_b_" + image_suffix)
        geom_b = np.concatenate([np.load(geom_in + "train_geom_img_b_" + suffix),
                                 np.load(geom_in + "train_geom_wrl_b_" + suffix)], axis=1)
        label_seqs = np.load(label_in + "train_label_seqs_" + suffix)
        label_lengths = np.load(label_in + "train_seq_lengths_" + suffix)
        frame_counts = np.load(label_in + "train_frame_counts_" + suffix)

        seq_count = frame_counts.shape[0]
        seq_heights=[image_size]*file_size
        seq_widths = [image_size]*file_size
        # for input that hasn't been resized to a fixed section, these lines compute dimensions
        #seq_heights = [crnrs[1][1] - crnrs[0][1] for crnrs in corners[0]]
        #seq_widths = [crnrs[1][0] - crnrs[0][0] for crnrs in corners[0]]

        total_frames = frame_counts.sum()
        frame_heights = list(chain(*[[seq_heights[j]]*frame_counts[j] for j in range(seq_count)]))
        frame_widths = list(chain(*[[seq_widths[j]]*frame_counts[j] for j in range(seq_count)]))
        by_widths = list(chain(*[[frame_widths[j]]*frame_heights[j] for j in range(total_frames)]))
        geom_lengths = list(chain(*[[21]*frame_counts[j] for j in range(seq_count)]))

        rgb_a_tensor = tf.RaggedTensor.from_row_lengths(tf.RaggedTensor.from_row_lengths(tf.RaggedTensor.from_row_lengths(rgb_a, 
                                                                                                                      by_widths),
                                                                                        frame_heights),
                                                        frame_counts)
        del rgb_a 
        flow_a_tensor = tf.RaggedTensor.from_row_lengths(tf.RaggedTensor.from_row_lengths(tf.RaggedTensor.from_row_lengths(flow_a, 
                                                                                                                        by_widths),
                                                                                          frame_heights),
                                                          frame_counts)
        del flow_a
        flow_a_tensor = tf.cast(flow_a_tensor, tf.float32)

        geom_a_tensor = tf.RaggedTensor.from_row_lengths(tf.RaggedTensor.from_row_lengths(geom_a, geom_lengths),
                                                          frame_counts).merge_dims(-2,-1)
        del geom_a


        rgb_b_tensor = tf.RaggedTensor.from_row_lengths(tf.RaggedTensor.from_row_lengths(tf.RaggedTensor.from_row_lengths(rgb_b, 
                                                                                                                        by_widths),
                                                                                          frame_heights),
                                                          frame_counts)
        del rgb_b

        flow_b_tensor = tf.RaggedTensor.from_row_lengths(tf.RaggedTensor.from_row_lengths(tf.RaggedTensor.from_row_lengths(flow_b, 
                                                                                                                        by_widths),
                                                                                          frame_heights),
                                                          frame_counts)
        del flow_b
        flow_b_tensor = tf.cast(flow_b_tensor, tf.float32)

        geom_b_tensor = tf.RaggedTensor.from_row_lengths(tf.RaggedTensor.from_row_lengths(geom_b, geom_lengths),
                                                          frame_counts).merge_dims(-2,-1)
        del geom_b

        # bundle the data into a tf Dataset with set batch size
        if batch_number==start_batch:
            dataset_ = tf.data.Dataset.from_tensor_slices((((rgb_a_tensor, flow_a_tensor, geom_a_tensor), 
                                                    (rgb_b_tensor, flow_b_tensor, geom_b_tensor)), 
                                                   (label_seqs, label_lengths)))
        else:
            dataset_ = dataset_.concatenate(tf.data.Dataset.from_tensor_slices((((rgb_a_tensor, flow_a_tensor, geom_a_tensor), 
                                                    (rgb_b_tensor, flow_b_tensor, geom_b_tensor)), 
                                                   (label_seqs, label_lengths))))
        del rgb_a_tensor, flow_a_tensor, geom_a_tensor
        del rgb_b_tensor, flow_b_tensor, geom_b_tensor
    return dataset_.shuffle(buffer_size=buffer_size, reshuffle_each_iteration=True).batch(batch_size)

def load_csfw_batches(start_batch, endpoint, batch_size = batch_size):
    return load_fs_batches(start_batch, endpoint,
                           image_in = base_cfsw + 'cfsw_128/',
                           geom_in = base_cfsw + 'cfsw_geom/',
                           label_in = base_cfsw + 'cfsw_labels/',
                           batch_size = batch_size, image_tag = "_128", 
                           file_size=256, buffer_size = buffer_size)

def load_csfwp_batches(start_batch, endpoint, batch_size = batch_size):
    return load_fs_batches(start_batch, endpoint,
                          image_in = base_cfswp + 'cfswp_128/',
                          geom_in = base_cfswp + 'cfswp_geom/',
                          label_in = base_cfswp + 'cfswp_labels/',
                          batch_size = batch_size, image_tag = '',
                          file_size=512, buffer_size=buffer_size)

In [93]:
def process_fs_batch(input_batch, label_batch, return_lists=False, batch_size=batch_size, is_training=True):
    # input_batch has frame sequences [batch_size, frame_count, height, width, channels]
    #   organised as ((rgb_a, flow_a), (rgb_b, flow_b))
    # label_batch has seq labels plus their lengths
    out_lengths = np.zeros(batch_size, dtype=np.int32) # records time depth of encoder output
    timings_ = np.zeros(5) # informational statistics on how long various computations take
    
    timings_[0]=time()    
    decoder_results_list_a = []
    decoder_results_list_b = []
    for k in range(batch_size):
        # since this length depends only on the depth of the input and the i3d net endpoint,
        # both A and B have the same output length, even if the frame sizes are different
        out_lengths[k] = np.int32(np.ceil(input_batch[0][0][k].get_shape()[0]/2))
    encoder_results_a = testSpell3dEncode(input_batch[0], is_training=is_training)
    encoder_results_b = testSpell3dEncode(input_batch[1], is_training=is_training)
    results_depth = encoder_results_a[0].shape[1]
    timings_[1] = time()
    
    # attender gives weights to the filters of the encoder; spatial axes of size 1x1 are added
    # to get tf to broadcast multiplication correctly
    attender_results_a = tf.expand_dims(tf.expand_dims(testSpellAttend(input_batch[0][2], is_training=is_training), -2), -2)
    attender_results_b = tf.expand_dims(tf.expand_dims(testSpellAttend(input_batch[1][2], is_training=is_training), -2), -2)
    
    encoder_reshaped_a = tf.concat([tf.reshape(tf.math.multiply(encoder_results_a[0],
                                                                attender_results_a[0]), 
                                               (batch_size, results_depth, -1)),
                                    tf.reshape(tf.math.multiply(encoder_results_a[1],
                                                                attender_results_a[1]),
                                               (batch_size, results_depth, -1))],
                                   axis=2)
    encoder_reshaped_b = tf.concat([tf.reshape(tf.math.multiply(encoder_results_b[0],
                                                                attender_results_b[0]), 
                                               (batch_size, results_depth, -1)),
                                    tf.reshape(tf.math.multiply(encoder_results_b[1],
                                                                attender_results_b[1]),
                                               (batch_size, results_depth, -1))],
                                   axis=2)
    timings_[2] = time()

    cogitator_results_a = testSpellCogitate(encoder_reshaped_a, is_training=is_training)
    cogitator_results_b = testSpellCogitate(encoder_reshaped_b, is_training=is_training)
    timings_[3] = time()

    for j in range(cogitator_results_a.get_shape()[1]):
        decoder_results_list_a.append(testSpellDecode(cogitator_results_a[:,j], is_training=is_training))
        decoder_results_list_b.append(testSpellDecode(cogitator_results_b[:,j], is_training=is_training))
    decoder_results_list_a = tf.convert_to_tensor(decoder_results_list_a)
    decoder_results_list_b = tf.convert_to_tensor(decoder_results_list_b)

    if return_lists:
        timings_[4] = time()
        timings[0:4] = np.diff(timings_)
        return decoder_results_list_a, decoder_results_list_b, out_lengths
    
    loss_a = tf.nn.ctc_loss(label_batch[0], # tensor of shape [batch_size, max_label_seq_length] or SparseTensor
                         decoder_results_list_a, # tensor of shape [frames, batch_size, num_labels] : prob. of each character at each time step 
                         label_batch[1],  # tensor of shape [batch_size], or None if labels=SparseTensor
                         out_lengths,  # tensor of shape [batch_size] "Length of input sequence in logits."
                         logits_time_major=True) # flip to swap first two axes of logits tensor
    loss_b = tf.nn.ctc_loss(label_batch[0],
                         decoder_results_list_b, 
                         label_batch[1],  
                         out_lengths,  
                         logits_time_major=True)
    timings_[4] = time()
    timings[0:4] = np.diff(timings_)
    # tf.nn.l2_loss(tensor) # regulariser
    
    return tf.math.minimum(loss_a, loss_b)

In [98]:
testSpell3dEncode2 = testSpeller3dEncoder(i3d_endpoint = i3d_endpoint)
test_batch = test_dataset.take(1)
for elt in test_batch:
    input_batch, label_batch = elt

In [100]:
encoder_results_a = testSpell3dEncode2(input_batch[0], return_dict=True)
encoder_results_a = encoder_results_a

In [105]:
pool_args = {'ksize' : [1,1,2,2,1], 'strides' : [1,1,2,2,1], 'padding' : pad_same}

def pool_to_4x4(filter_tensor):
    filter_space = filter_tensor.get_shape()[-2]
    parity = True
    while filter_space > 4:
        if parity:
            filter_tensor = tf.nn.avg_pool3d(filter_tensor, **pool_args)
            filter_space = filter_space // 2
            parity = not parity
        else:
            filter_tensor = tf.nn.max_pool3d(filter_tensor, **pool_args)
            filter_space == filter_space // 2
            parity = not parity
    return filter_tensor

def pool_and_concatenate(tensor_dict, keys):
    return tf.concat([pool_to_4x4(t[key])] for key in keys], axis=-1)
  
i3d_endpoints = ['MaxPool3d_5a_2x2', 'MaxPool3d_4a_3x3', 'MaxPool3d_3a_3x3', 'MaxPool3d_2a_3x3']
encoder_results_a[0]['conv_2dim'] = testSpell2dEncode(input_batch[0][0])
encoder_results_b[0]['conv_2dim'] = testSpell2dEncode(input_batch[1][0])
encoder_out_a = tf.concat([pool_and_concatenate(encoder_results_a[0], i3d_endpoints + ['conv_2dim']),
                           pool_and_concatenate(encoder_results_a[1], i3d_endpoints)]).reshape(batch_size, 
                                                                                               results_depth, -1)
encoder_out_b = encoder_out_a = tf.concat([pool_and_concatenate(encoder_results_a[0], i3d_endpoints + ['conv_2dim']),
                           pool_and_concatenate(encoder_results_a[1], i3d_endpoints)]).reshape(batch_size, 
                                                                                               results_depth, -1)
        

print(encoder_results_a(['MaxPool3d_2a_3x3'].shape) # 32x32x64
print(encoder_results_a(['MaxPool3d_3a_3x3'].shape) # 16x16x192
print(encoder_results_a(['MaxPool3d_4a_3x3'].shape) # 8x8x480
#encoder_results_a[0]['MaxPool3d_5a_2x2'].shape # 4x4x832

encoder_results_a[0]['MaxPool3d_4a_3x3'],
tf.nn.avg_pool3d(encoder_results_a[0]['MaxPool3d_4a_3x3'], **pool_args)
tf.nn.avg_pool3d(encoder_results_a[1]['MaxPool3d_4a_3x3'], **pool_args)
tf.nn.avg_pool3d(encoder_results_b[0]['MaxPool3d_4a_3x3'], **pool_args)
tf.nn.max_pool3d

(8, 31, 32, 32, 64)
(8, 31, 16, 16, 192)
(8, 31, 8, 8, 480)


True

In [None]:
# this loop version loads the entire dataset into memory
epochs = 1
timings_temp = time()
cfsw_train = load_cfsw_batches(0, 21, batch_size = batch_size)
timings[5] = time()-timings_temp

for epoch in range(epochs):
    print("\nStart of epoch " + str(epoch))
    epoch_loss = 0       
    for input_batch, label_batch in cfsw_train:
        with tf.GradientTape(persistent=True) as tape:
            loss_ = process_fs_batch(input_batch, label_batch)
            print(loss_.numpy())
            loss_sum = tf.math.reduce_sum(loss_).numpy()
            epoch_loss += loss_sum
            
        timings_temp = time()
        grad_decode = tape.gradient(loss_, testSpellDecode.trainable_variables)
        grad_cogitate = tape.gradient(loss_, testSpellCogitate.trainable_variables)
        grad_attend = tape.gradient(loss_, testSpellAttend.trainable_variables)
        timings[4] = time() - timings_temp

        opt.apply_gradients(zip(grad_decode, testSpellDecode.trainable_variables))
        opt.apply_gradients(zip(grad_cogitate, testSpellCogitate.trainable_variables))
        opt.apply_gradients(zip(grad_attend, testSpellAttend.trainable_variables))
    print('\nAverage loss this epoch: ', epoch_loss/(21*256)) 
print('Timings: ', timings)

In [118]:
def get_var_list_stats(var_list):
    norm_acc, number_acc = 0, 0
    for j in range(len(var_list)):
        number_acc += tf.math.reduce_prod(var_list[j].shape)
        norm_acc += tf.math.reduce_euclidean_norm(var_list[j])
    return norm_acc.numpy(), number_acc.numpy(), norm_acc.numpy()/number_acc.numpy()

def var_list_decay(var_list, coef):
    for j in range(len(var_list)):
        var_list[j].assign_sub(var_list[j]*coef)

In [106]:
# this loop version loads parts of the dataset into memory at a time, cycling through them
batches_at_once = 1 # or 3 or 7 or 21
epochs = 1
cfsw_train_subset = None

for epoch in range(epochs):
    print("\nStart of epoch %d" % (epoch,))
    epoch_loss = 0
    # really want to take these in random order
    for k in rng.permutation(21//batches_at_once):
        batch_loss = 0
        del cfsw_train_subset
        timings_temp = time()
        cfsw_train_subset = load_cfsw_batches(k*batches_at_once, (k+1)*batches_at_once, batch_size = batch_size)
        timings[5] = time()-timings_temp
        
        for input_batch, label_batch in cfsw_train_subset:
            with tf.GradientTape(persistent=True) as tape:
                loss_ = process_fs_batch(input_batch, label_batch)
                print(loss_.numpy())
                loss_sum = tf.math.reduce_sum(loss_).numpy()
                batch_loss += loss_sum
                epoch_loss += loss_sum
                
            timings_temp = time()
            grad_decode = tape.gradient(loss_, testSpellDecode.trainable_variables)
            grad_cogitate = tape.gradient(loss_, testSpellCogitate.trainable_variables)
            grad_attend = tape.gradient(loss_, testSpellAttend.trainable_variables)
            timings[4] = time() - timings_temp
            print(timings)
            print(get_var_list_stats(grad_decode))
            print(get_var_list_stats(grad_cogitate))
            print(get_var_list_stats(grad_attend))
            
            var_list_decay(testSpellDecode.trainable_variables, l2_coef)
            var_list_decay(testSpellCogitate.trainable_variables, l2_coef)
            var_list_decay(testSpellAttend.trainable_variables, l2_coef)
            opt.apply_gradients(zip(grad_decode, testSpellDecode.trainable_variables))
            opt.apply_gradients(zip(grad_cogitate, testSpellCogitate.trainable_variables))
            opt.apply_gradients(zip(grad_attend, testSpellAttend.trainable_variables))
                     
        print('\nAverage loss for batch ' + str(epoch)+"/"+str(k) +': ', batch_loss/(256*batches_at_once))
    print('\nAverage loss this epoch: ', epoch_loss/(21*256))
print('Timings: ', timings)


Start of epoch 0
[ 9.796211   8.473819  15.416117  43.322746   6.1899147 17.462254
 17.666729  27.439209 ]
[19.32058263  0.0626924   0.67225838  4.71258354  9.34029102 15.78667021]
(191.85931, 37152, 0.0051641718875119445)
(38.530037, 31654912, 1.2171898290625333e-06)
(11.735307, 74688, 0.00015712439400984265)
[ 6.6294107  9.096701   7.926929   6.747944   6.7425213 12.820393
 17.470823  26.094831 ]
[19.06116343  0.05864835  0.68686438  0.42482233  9.44190145 15.78667021]
(137.9225, 37152, 0.0037123842756877573)
(25.09863, 31654912, 7.928826624174904e-07)
(15.16831, 74688, 0.00020308898571932938)
[ 9.199769 16.891727 10.904272 13.731448 14.006483 25.982197 20.295616
 19.662216]
[12.8787353   0.05043149  0.47526312  0.33152246  6.14788961 15.78667021]
(174.96635, 37152, 0.00470947336267542)
(37.756226, 31654912, 1.192744607406822e-06)
(30.541744, 74688, 0.0004089243818575639)
[ 6.310892  26.204674   6.0927315 12.650826  12.319907   9.847191
  8.310272  24.978683 ]
[18.45156169  0.049292

[11.39529967  0.04029703  0.41394711  0.29171443  5.55115962 15.78667021]
(89.71384, 37152, 0.0024147781188071133)
(15.513202, 31654912, 4.900724953385438e-07)
(2.044028, 74688, 2.736755628410117e-05)
[ 4.7671175 30.332767   5.2156305  7.1049824 21.338531   8.341793
 21.964289  14.425127 ]
[24.48677611  0.07066035  0.87712955  0.50647902 12.65562797 15.78667021]
(162.56107, 37152, 0.004375567013184435)
(37.12052, 31654912, 1.1726622884123058e-06)
(19.24214, 74688, 0.00025763362007664124)
[  5.2342625   4.495678  709.584       4.5542274  12.60906    10.731551
  13.322425   13.420117 ]
[16.20144868  0.05041909  0.60287571  0.34342933  8.11578465 15.78667021]
(187.64091, 37152, 0.005050627554828601)
(41.651386, 31654912, 1.3157953577942763e-06)
(23.366919, 74688, 0.00031286041350475006)
[10.154652  25.88444    7.9416957 39.37274    9.152397  10.118042
  5.744207  11.838406 ]
[20.39018512  0.05131483  0.74883938  0.43888712 10.35318661 15.78667021]
(265.33127, 37152, 0.007141776171149517)


[ 5.903958  24.67627   35.188717   6.5146546  9.743471  21.686518
 33.19203   14.773273 ]
[18.10789704  0.05943871  0.77101588  0.50407767  9.36475563 19.14966393]
(213.20892, 37152, 0.005738827609276587)
(29.526693, 31654912, 9.327681386104062e-07)
(3.561154, 74688, 4.7680402323029036e-05)
[13.29659    9.643696  10.831479  64.7794    28.45706    9.3889675
  8.39043   25.280788 ]
[43.37036848  0.1218853   1.68970585  1.04150891 22.95608497 19.14966393]
(141.41093, 37152, 0.003806280535320903)
(22.973694, 31654912, 7.257544689322229e-07)
(4.8948026, 74688, 6.553666680514966e-05)
[23.253944  11.187419  23.090137  26.295534  50.399536  31.284601
 10.816544   7.0026345]
[28.59828401  0.06939864  1.01843262  0.55608058 14.48482251 19.14966393]
(143.83136, 37152, 0.0038714297982149347)
(42.96406, 31654912, 1.3572636605990422e-06)
(18.904274, 74688, 0.00025310992377378436)
[20.094215  38.589325  20.361782   8.948833  10.643777   6.6752253
 36.619293  25.041903 ]
[24.61879468  0.07500005  0.93

[23.05817127  0.06492019  0.81700253  0.50412607 11.46232677 24.69034958]
(225.89236, 37152, 0.00608022083607755)
(45.615593, 31654912, 1.4410273184946137e-06)
(5.6449313, 74688, 7.55801643687839e-05)
[13.952521  21.338547  30.810425  15.9802475 44.626152  48.11917
  5.40077   18.883888 ]
[60.88966751  0.13372111  2.02498102  1.31369638 32.25278115 24.69034958]
(543.1516, 37152, 0.01461971391387072)
(353.33896, 31654912, 1.1162215795773951e-05)
(16.018112, 74688, 0.00021446701187094564)
[37.26317   26.735891   4.1092644 12.928933   8.823449   6.603446
 26.393204   5.7706757]
[41.12660503  0.09060001  1.23466969  0.76433778 20.28187895 24.69034958]
(358.73828, 37152, 0.009655961489287252)
(57.795948, 31654912, 1.825812942666353e-06)
(58.65199, 74688, 0.0007852933400700822)
[ 4.784301  40.841698   7.3750267 10.952774   5.1428213  5.31944
 18.31645    5.2132215]
[26.27144885  0.0707171   0.92642641  0.60258365 12.83506751 24.69034958]
(268.41345, 37152, 0.007224737622427797)
(39.738895, 3

[ 9.90473   18.43777   15.439573  22.506157  16.7518    15.0858135
 22.269163  23.189175 ]
[16.93858147  0.05049586  0.5674181   0.30081296  8.05937719 18.43240762]
(149.8244, 37152, 0.004032741221346597)
(39.85401, 31654912, 1.2590150791019268e-06)
(8.904934, 74688, 0.00011922844271426949)
[25.422583 15.186783 11.860341 27.048359 31.924894  7.862547 17.425945
 50.922287]
[28.00999904  0.07071924  0.93607187  0.581141   13.76459479 18.43240762]
(154.45316, 37152, 0.004157330844034726)
(52.183884, 31654912, 1.6485240479263435e-06)
(10.222727, 74688, 0.0001368724135322865)
[25.921692  28.812794  13.164911  25.124147   9.48151    8.490883
  7.9666986  8.016422 ]
[27.52539539  0.07057095  0.9484148   0.57236362 13.13229489 18.43240762]
(181.62727, 37152, 0.0048887616698850755)
(30.295643, 31654912, 9.570597717277876e-07)
(1.3752387, 74688, 1.8413113980795852e-05)
[31.47601   11.454204  38.72441   22.896767  20.227074   6.4823713
 22.321598   8.806608 ]
[26.13938498  0.07064533  0.84381557 

[ 19.30627    17.226875  710.47577     6.8518558   8.494934   48.951912
  32.21505    11.95833  ]
[30.02944136  0.06492376  0.93377519  0.56044364 14.67345119 18.22829747]
(186.57555, 37152, 0.005021951611343897)
(31.055801, 31654912, 9.810736921840618e-07)
(13.07141, 74688, 0.0001750135253205091)
[10.140709  15.833483  12.702584  11.262122  15.72148   11.210411
  7.7037835 18.566324 ]
[13.39638925  0.04032803  0.50333285  0.32061267  6.60397267 18.22829747]
(110.374306, 37152, 0.0029708846286901824)
(23.191113, 31654912, 7.326228712406639e-07)
(21.283619, 74688, 0.00028496704861560027)
[ 9.294263 27.810467 19.887009 23.610058  8.999241 66.81937  25.725727
  6.767634]
[29.93950772  0.08094478  1.09622216  0.55767131 14.82489014 18.22829747]
(191.44333, 37152, 0.0051529750446119155)
(24.924133, 31654912, 7.873701655143205e-07)
(6.6531777, 74688, 8.907960767713283e-05)
[ 6.9150705 24.557585   8.620681  13.143844  30.46113    8.729795
  6.986642   9.540995 ]
[39.68931651  0.08505177  1.41

[137.76787567   0.40756392   4.63434315   2.74955797  61.27723503
  15.78768158]
(237.06874, 37152, 0.006381049226010082)
(62.65947, 31654912, 1.9794548664198525e-06)
(96.59891, 74688, 0.001293365834815541)
[ 32.50599   10.709186   7.96778   35.452736  39.242973  11.437173
  11.976629 709.00977 ]
[36.82563996  0.08357239  5.51655078  0.74968219 17.2286644  15.78768158]
(227.19733, 37152, 0.0061153457865029135)
(42.680534, 31654912, 1.3483068397976581e-06)
(7.9777884, 74688, 0.00010681486247233479)
[ 10.013348  10.582167  50.949593  29.783215   7.587714 709.69604
  11.528336  24.84115 ]
[31.2634275   0.07506704  1.00143719  0.57652497 15.25194216 15.78768158]
(132.07034, 37152, 0.003554864960636793)
(37.832897, 31654912, 1.195166714924979e-06)
(26.6622, 74688, 0.00035698105355257034)
[41.182377  9.834907  9.966293 10.240526  9.48793  22.397526 23.30185
 11.964304]
[23.70387173  0.06406283  0.81518626  0.51194096 11.12368822 15.78768158]
(115.44814, 37152, 0.003107454322926655)
(137.3555

[19.49155855  0.05246377  0.67255998  0.40091705  8.92896581 18.46889377]
(128.0685, 37152, 0.0034471494590897277)
(18.730366, 31654912, 5.91704875160412e-07)
(4.4579124, 74688, 5.968713106614663e-05)
[ 7.347961 12.288617  9.419021  6.041459 21.384033 11.755468 11.672304
 12.222061]
[16.63149643  0.04238439  0.56363654  0.35571742  8.03808355 18.46889377]
(146.7171, 37152, 0.003949103737370297)
(39.906105, 31654912, 1.2606607480540115e-06)
(6.1796274, 74688, 8.273922743302895e-05)
[32.549812   9.687704   9.997201  13.198854   9.7294855 58.893806
 17.490993  16.614304 ]
[48.50975418  0.15657353  1.7082634   1.94102383 24.049227   18.46889377]
(246.52979, 37152, 0.006635706964800011)
(36.504627, 31654912, 1.1532057719125299e-06)
(1.2568104, 74688, 1.6827474650707424e-05)
[15.174288  15.126748   7.8383484 16.496807  20.184029   8.504979
 38.281044  16.35395  ]
[20.75635433  0.06065774  0.82936454  0.49075007 10.2320087  18.46889377]
(152.80327, 37152, 0.004112921738604037)
(21.510712, 316

[10.43528    7.627506   5.8433647 55.036297  21.755272   8.0596695
  9.608751  30.725142 ]
[48.705585    0.1116097   1.50124025  0.86652088 23.0729301  19.95447159]
(267.22018, 37152, 0.007192619087160096)
(43.31566, 31654912, 1.3683708414459007e-06)
(2.9535325, 74688, 3.9544939713898946e-05)
[20.161863  23.062447  10.150178   6.935674  13.387985  60.357365
  7.0347514 23.251982 ]
[30.84256816  0.07750058  1.99832582  0.58873725 14.13129735 19.95447159]
(298.17136, 37152, 0.008025714798696487)
(36.047394, 31654912, 1.138761459795817e-06)
(4.8550863, 74688, 6.50049047584501e-05)
[ 6.3508654 20.575188  37.873962  11.897501  20.72468   10.245124
  6.6351347 13.4135275]
[25.29577637  0.08808064  0.73260093  0.41552711 11.58589721 19.95447159]
(134.1307, 37152, 0.003610322641771893)
(24.686895, 31654912, 7.798756594389964e-07)
(2.375534, 74688, 3.180610081428325e-05)
[ 7.466657   7.771383  12.761274  24.611755   6.91741   11.1484585
 19.111572   8.5270605]
[16.12767172  0.05438137  0.551768

KeyboardInterrupt: 

In [109]:
print(get_var_list_stats(testSpellDecode.trainable_variables))
print(get_var_list_stats(testSpellCogitate.trainable_variables))
print(get_var_list_stats(testSpellAttend.trainable_variables))

(32.718094, 37152, 0.0008806549814833741)
(187.1872, 31654912, 5.913369616197914e-06)
(46.870586, 74688, 0.0006275517672887702)


In [123]:
testSpellDecode.trainable_variables[0]

<tf.Variable 'dense_16/kernel:0' shape=(128, 128) dtype=float32, numpy=
array([[ 0.08459155,  0.07163593, -0.08155863, ..., -0.09446944,
         0.13766198,  0.12962008],
       [-0.05576618,  0.06116143, -0.01730708, ..., -0.06724232,
         0.07287413,  0.00297883],
       [ 0.06309021,  0.11109361, -0.11438391, ...,  0.04126996,
         0.15530567, -0.14080565],
       ...,
       [ 0.11576623,  0.02597345,  0.08837017, ..., -0.17141892,
         0.13526745, -0.04232301],
       [ 0.09379844, -0.04267445, -0.19085574, ...,  0.13805334,
         0.04832689, -0.12917371],
       [-0.03968348, -0.06075937,  0.04498382, ..., -0.15817237,
        -0.09669452, -0.02687849]], dtype=float32)>

In [124]:
var_list_decay(testSpellDecode.trainable_variables, 1/2048)

In [125]:
testSpellDecode.trainable_variables[0]

<tf.Variable 'dense_16/kernel:0' shape=(128, 128) dtype=float32, numpy=
array([[ 0.08455025,  0.07160095, -0.08151881, ..., -0.09442332,
         0.13759476,  0.12955679],
       [-0.05573895,  0.06113156, -0.01729863, ..., -0.06720948,
         0.07283854,  0.00297738],
       [ 0.0630594 ,  0.11103936, -0.11432806, ...,  0.04124981,
         0.15522984, -0.1407369 ],
       ...,
       [ 0.1157097 ,  0.02596077,  0.08832703, ..., -0.17133522,
         0.13520141, -0.04230235],
       [ 0.09375264, -0.04265362, -0.19076255, ...,  0.13798593,
         0.04830329, -0.12911063],
       [-0.03966411, -0.0607297 ,  0.04496186, ..., -0.15809514,
        -0.09664731, -0.02686536]], dtype=float32)>

In [None]:
# the rest is scratchwork, ignore this

In [90]:
# alphabet_encoder associates an integer to each character, encoding text strings as numerical sequences
alphabet_encoder = {'blank': 0, ' ': 1, '&': 2, "'": 3, '.': 4, '@': 5, 'a': 6, 'b': 7, 'c': 8, 'd': 9, 'e': 10,
                    'f': 11, 'g': 12, 'h': 13, 'i': 14, 'j': 15, 'k': 16, 'l': 17, 'm': 18, 'n': 19, 'o': 20,
                    'p': 21, 'q': 22, 'r': 23, 's': 24, 't': 25, 'u': 26, 'v': 27, 'w': 28, 'x': 29, 'y': 30, 'z': 31}

# alphabet_decoder conversely replaces numerical sequences with character strings
alphabet_decoder = {alphabet_encoder[char] : char for char in alphabet_encoder.keys()}
alphabet_decoder[0] = '_' # representation for the blank character

def decode_label_seq(seq, tensor_chars = False, decoder=alphabet_decoder):
    if tensor_chars:
        return str().join([decoder.get(index.numpy(),'*') for index in seq])
    return str().join([decoder.get(index,'*') for index in seq])

# beam search returns for 
def logits_to_string_beam(logits, label_length, **kwargs):
    # logits is a tensor of shape [max_time, batch_size, alphabet_size]
    # label_lengths records the number of characters given for each logit set in the batch
    
    # top_paths = 1  controls how many search results are returned
    # beam_width = 100  controls how many probabilities are maintained at each step of computation
    
    results, probs = tf.nn.ctc_beam_search_decoder(tf.expand_dims(logits, axis=1), 
                                                   [label_length], **kwargs)
    
    # An individual return is a SparseTensor, with bounding shape [batch_size, max_output_length], with:
    # .indices pairs (batch_number, character_position) and .values the output characters.
    ret = []
    #print(results[0].values)
    for result in results:
        #ret.append(str().join([alphabet_decoder.get(char.numpy(), '*')for char in result.values]))
        ret.append(decode_label_seq(result.values, True))
    return ret, tf.squeeze(probs).numpy()

def logits_to_string_greedy(logits, label_length, **kwargs):
    results, probs = tf.nn.ctc_greedy_decoder(tf.expand_dims(logits, axis=1), 
                                                   [label_length],
                                                   blank_index=0, **kwargs)
    ret = []
    #print(results[0].values)
    for result in results:
        #ret.append(str().join([alphabet_decoder.get(char.numpy(), '*')for char in result.values]))
        ret.append(decode_label_seq(result.values, True))
    return ret, tf.squeeze(probs).numpy()

In [None]:
test_dataset = load_cfsw_batches(0,1)

In [96]:
for input_batch, label_batch test_dataset:
    decoder_results_list_a, decoder_results_list_b, out_lengths = process_fs_batch(input_batch, label_batch, 
                                                                      return_lists=True, is_training=False)
    for k in range(batch_size):
        #print(decoder_results_list)
        print('label:', decode_label_seq(label_batch[0][k][0:label_batch[1][k].numpy()].numpy()),
              '\nbeam A:', logits_to_string_beam(decoder_results_list_a[:,k,:], 
                                                 np.int32(out_lengths[k]), top_paths=3),
              '\nbeam B:', logits_to_string_beam(decoder_results_list_b[:,k,:], 
                                                 np.int32(out_lengths[k]), top_paths=3),)
              #', greedy A:', logits_to_string_greedy(decoder_results_list_a[:,k,:], np.int32(out_lengths[k])),
              #', greedy B:', logits_to_string_greedy(decoder_results_list_b[:,k,:], np.int32(out_lengths[k])))
          
    #loss = tf.nn.ctc_loss(label_batch[0],
    #                 decoder_results_list,
    #                 label_batch[1], 
    #                 2*out_lengths,  
    #                 logits_time_major=True)  
    #total_loss += tf.keras.metrics.Sum()(loss)
    
    break
#print('total loss = ', total_loss)

label: it 
beam A: (['_', 'i_', 's_'], array([-1.9809374, -3.7772012, -3.81764  ], dtype=float32)) 
beam B: (['_', 'i_', 's_'], array([-4.8712807, -5.3748426, -5.5574226], dtype=float32))
label: asl literature 
beam A: (['_', '_s_', '_u_'], array([-3.7691383, -4.4382334, -4.7812605], dtype=float32)) 
beam B: (['_e_e_', '_e_e_e_', '_e_e_l_'], array([-18.223228, -18.256298, -18.472479], dtype=float32))
label: all 
beam A: (['_', '_s_', '_e_'], array([-2.691038 , -4.0412016, -4.2800803], dtype=float32)) 
beam B: (['_', '_s_', '_e_'], array([-2.4200158, -3.9401357, -4.296234 ], dtype=float32))
label: fotgs 
beam A: (['_e_', '_l_', '_i_'], array([-10.457163, -10.60331 , -10.723205], dtype=float32)) 
beam B: (['_', '_s_', '_e_'], array([-3.2630477, -4.228913 , -4.577375 ], dtype=float32))
label: ok 
beam A: (['_', 'i_', '_e_'], array([-5.32949  , -5.9381557, -5.940513 ], dtype=float32)) 
beam B: (['_', '_e_', '_l_'], array([-3.7440312, -4.594789 , -4.6847715], dtype=float32))
label: asl lit 

In [None]:
# following the idea of presenting the LSTM middle of the network with information at variable time-scales,
# we consider also a 2d convolutional network that operates only on individual frames

## check this makes sense for 5d input
class testSpeller2dEncoder(tf.Module):
    def __init__(self, conv_filters, reg_coef=0, spatial_dropout_prob=0.01):
        super(testSpeller2dEncoder, self).__init__()
        filters_1, filters_2, filters_3, filters_4 = conv_filters
        self.reg = tf.keras.regularizers.L2(reg_coef)
        self.spatial_dropout_prob = spatial_dropout_prob 
        
        self.conv_1a_rgb = tf.keras.layers.Convolution2D(filters_1, 5, padding='same', use_bias=True, activation='relu')
        self.conv_1b_rgb = tf.keras.layers.Convolution2D(filters_1, 3, padding='same', use_bias=True, activation='relu')
        self.conv_2a_rgb = tf.keras.layers.Convolution2D(filters_2, 3, padding='same', use_bias=True, activation='relu')
        self.conv_2b_rgb = tf.keras.layers.Convolution2D(filters_2, 3, padding='same', use_bias=True, activation='relu')
        self.conv_3a_rgb = tf.keras.layers.Convolution2D(filters_3, 3, padding='same', use_bias=True, activation='relu')
        self.conv_3b_rgb = tf.keras.layers.Convolution2D(filters_3, 3, padding='same', use_bias=True, activation='relu') 
        self.conv_4a_rgb = tf.keras.layers.Convolution2D(filters_4, 3, padding='same', use_bias=True, activation='relu')
        self.conv_4b_rgb = tf.keras.layers.Convolution2D(filters_4, 3, padding='same', use_bias=True, activation='relu')
        
        self.conv_1a_flow = tf.keras.layers.Convolution2D(filters_1, 5, padding='same', use_bias=True, activation='relu')
        self.conv_1b_flow = tf.keras.layers.Convolution2D(filters_1, 3, padding='same', use_bias=True, activation='relu')
        self.conv_2a_flow = tf.keras.layers.Convolution2D(filters_2, 3, padding='same', use_bias=True, activation='relu')
        self.conv_2b_flow = tf.keras.layers.Convolution2D(filters_2, 3, padding='same', use_bias=True, activation='relu')
        self.conv_3a_flow = tf.keras.layers.Convolution2D(filters_3, 3, padding='same', use_bias=True, activation='relu')
        self.conv_3b_flow = tf.keras.layers.Convolution2D(filters_3, 3, padding='same', use_bias=True, activation='relu') 
        self.conv_4a_flow = tf.keras.layers.Convolution2D(filters_4, 3, padding='same', use_bias=True, activation='relu')
        self.conv_4b_flow = tf.keras.layers.Convolution2D(filters_4, 3, padding='same', use_bias=True, activation='relu')

    def __call__(self, inputs, is_training=True):
        # 128x128xch input
        rgb_out = tf.keras.layers.GaussianNoise(0.03)(tf.image.convert_image_dtype(inputs[0].to_tensor(), 
                                                                                   dtype=tf.float32),
                                                      training=is_training)
        rgb_out = tf.keras.layers.SpatialDropout2D(self.spatial_dropout_prob)(self.conv_1a(rgb_out),
                                                                            training=is_training)
        rgb_out = tf.keras.layers.MaxPool2D(strides=(2,2))(self.conv_1b(rgb_out)) # 64x64xch        
        rgb_out = tf.keras.layers.SpatialDropout2D(self.spatial_dropout_prob)(self.conv_2a(rgb_out),
                                                                            training=is_training)
        rgb_out = tf.keras.layers.MaxPool2D(strides=(2,2))(self.conv_2b(rgb_out)) # 32x32xch        
        rgb_out = tf.keras.layers.SpatialDropout2D(self.spatial_dropout_prob) (self.conv_3a(rgb_out),
                                                                            training=is_training)
        rgb_out = tf.keras.layers.MaxPool2D(strides=(2,2))(self.conv_3b(rgb_out)) # 16x16xch        
        rgb_out = self.conv_4a(rgb_out) 
        rgb_out = tf.keras.layers.MaxPool2D(strides=(2,2))(self.conv_4b(rgb_out)) # out: 8x8xch
        
        # 128x128xch input
        flow_out = tf.keras.layers.GaussianNoise(0.03)(tf.image.convert_image_dtype(inputs[1].to_tensor(), 
                                                                                    dtype=tf.float32),
                                                       training=is_training)
        flow_out = tf.keras.layers.SpatialDropout2D(self.spatial_dropout_prob)(self.conv_1a(flow_out),
                                                                            training=is_training)
        flow_out = tf.keras.layers.MaxPool2D(strides=(2,2))(self.conv_1b(flow_out)) # 64x64xch        
        flow_out = tf.keras.layers.SpatialDropout2D(self.spatial_dropout_prob)(self.conv_2a(flow_out),
                                                                            training=is_training)
        flow_out = tf.keras.layers.MaxPool2D(strides=(2,2))(self.conv_2b(flow_out)) # 32x32xch
        flow_out = tf.keras.layers.SpatialDropout2D(self.spatial_dropout_prob) (self.conv_3a(flow_out),
                                                                            training=is_training)
        flow_out = tf.keras.layers.MaxPool2D(strides=(2,2))(self.conv_3b(flow_out)) # 16x16xch
        flow_out = self.conv_4a(flow_out)
        flow_out = tf.keras.layers.MaxPool2D(strides=(2,2))(self.conv_4b(flow_out)) # out: 8x8xch
        
        return rgb_out, flow_out

In [50]:
results[1].keys()

dict_keys(['Conv3d_1a_7x7', 'MaxPool3d_2a_3x3', 'Conv3d_2b_1x1', 'Conv3d_2c_3x3', 'MaxPool3d_3a_3x3', 'Mixed_3b', 'Mixed_3c', 'MaxPool3d_4a_3x3', 'Mixed_4b', 'Mixed_4c', 'Mixed_4d', 'Mixed_4e', 'Mixed_4f', 'MaxPool3d_5a_2x2'])

In [52]:
for key in results[1].keys():
    print(key, results[1][key].shape)

Conv3d_1a_7x7 (8, 24, 64, 64, 64)
MaxPool3d_2a_3x3 (8, 24, 32, 32, 64)
Conv3d_2b_1x1 (8, 24, 32, 32, 64)
Conv3d_2c_3x3 (8, 24, 32, 32, 192)
MaxPool3d_3a_3x3 (8, 24, 16, 16, 192)
Mixed_3b (8, 24, 16, 16, 256)
Mixed_3c (8, 24, 16, 16, 480)
MaxPool3d_4a_3x3 (8, 24, 8, 8, 480)
Mixed_4b (8, 24, 8, 8, 512)
Mixed_4c (8, 24, 8, 8, 512)
Mixed_4d (8, 24, 8, 8, 512)
Mixed_4e (8, 24, 8, 8, 528)
Mixed_4f (8, 24, 8, 8, 832)
MaxPool3d_5a_2x2 (8, 12, 4, 4, 832)
