<a href="https://colab.research.google.com/github/manastahir/Video-Classification/blob/trunk/video_classification.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
!pip -q install av
!pip -q install einops

In [2]:
!nvidia-smi -L

GPU 0: Tesla V100-SXM2-16GB (UUID: GPU-d1af7f21-6e62-8135-6f1e-57d4d97a5dad)


#### Moving data

In [3]:
# !tar -xzf 20bn-something-something-v2.tar.gz

#### Importing libraries

In [1]:
base_dir = '/content/drive/My Drive/Projects/Temporal Conv Network'
experiment_dir = f'{base_dir}/experiments/smth-smth/experiment 0'

In [2]:
import cv2
import os
import av
import json
import random
import numbers
import collections
import numpy as np

from matplotlib import pylab as plt
from collections import namedtuple

import tensorflow as tf
import tensorflow.keras.metrics as Metrics
from einops.layers.tensorflow import Rearrange

#### Utils

In [3]:
def save_images_for_debug(dir_img, imgs):
    """
    2x3x12x224x224 --> [BS, C, seq_len, H, W]
    """
    print("Saving images to {}".format(dir_img))
    imgs = imgs.permute(0, 2, 3, 4, 1)  # [BS, seq_len, H, W, C]
    imgs = imgs.mul(255).numpy()
    if not os.path.exists(dir_img):
        os.makedirs(dir_img)
    print(imgs.shape)
    for batch_id, batch in enumerate(imgs):
        batch_dir = os.path.join(dir_img, "batch{}".format(batch_id + 1))
        if not os.path.exists(batch_dir):
            os.makedirs(batch_dir)
        for j, img in enumerate(batch):
            plt.imsave(os.path.join(batch_dir, "frame{%04d}.png" % (j + 1)),
                       img.astype("uint8"))

#### Augmenter

In [4]:
class ComposeMix(object):
    r"""Composes several transforms together. It takes a list of
    transformations, where each element odf transform is a list with 2
    elements. First being the transform function itself, second being a string
    indicating whether it's an "img" or "vid" transform
    Args:
        transforms (List[Transform, "<type>"]): list of transforms to compose.
                                                <type> = "img" | "vid"
    Example:
        >>> transforms.ComposeMix([
        [RandomCropVideo(84), "vid"],
        [torchvision.transforms.ToTensor(), "img"],
        [torchvision.transforms.Normalize(
                   mean=[0.485, 0.456, 0.406],  # default values for imagenet
                   std=[0.229, 0.224, 0.225]), "img"]
    ])
    """

    def __init__(self, transforms):
        self.transforms = transforms

    def __call__(self, imgs):
        for t in self.transforms:
            if t[1] == "img":
                for idx, img in enumerate(imgs):
                    imgs[idx] = t[0](img)
            elif t[1] == "vid":
                imgs = t[0](imgs)
            else:
                print("Please specify the transform type")
                raise ValueError
        return imgs


class RandomCropVideo(object):
    r"""Crop the given video frames at a random location. Crop location is the
    same for all the frames.
    Args:
        size (sequence or int): Desired output size of the crop. If size is an
            int instead of sequence like (w, h), a square crop (size, size) is
            made.
        padding (int or sequence, optional): Optional padding on each border
            of the image. Default is 0, i.e no padding. If a sequence of length
            4 is provided, it is used to pad left, top, right, bottom borders
            respectively.
        pad_method (cv2 constant): Method to be used for padding.
    """

    def __init__(self, size, padding=0, pad_method=cv2.BORDER_CONSTANT):
        if isinstance(size, numbers.Number):
            self.size = (int(size), int(size))
        else:
            self.size = size
        self.padding = padding
        self.pad_method = pad_method

    def __call__(self, imgs):
        """
        Args:
            img (numpy.array): Video to be cropped.
        Returns:
            numpy.array: Cropped video.
        """
        th, tw = self.size
        h, w = imgs[0].shape[:2]
        x1 = np.random.randint(0, w - tw)
        y1 = np.random.randint(0, h - th)
        for idx, img in enumerate(imgs):
            if self.padding > 0:
                img = cv2.copyMakeBorder(img, self.padding, self.padding,
                                         self.padding, self.padding,
                                         self.pad_method)
            # sample crop locations if not given
            # it is necessary to keep cropping same in a video
            img_crop = img[y1:y1 + th, x1:x1 + tw]
            imgs[idx] = img_crop
        return imgs


class RandomHorizontalFlipVideo(object):
    """Horizontally flip the given video frames randomly with a given probability.
    Args:
        p (float): probability of the image being flipped. Default value is 0.5
    """

    def __init__(self, p=0.5):
        self.p = p

    def __call__(self, imgs):
        """
        Args:
            imgs (numpy.array): Video to be flipped.
        Returns:
            numpy.array: Randomly flipped video.
        """
        if random.random() < self.p:
            for idx, img in enumerate(imgs):
                imgs[idx] = cv2.flip(img, 1)
        return imgs

    def __repr__(self):
        return self.__class__.__name__ + '(p={})'.format(self.p)


class RandomReverseTimeVideo(object):
    """Reverse the given video frames in time randomly with a given probability.
    Args:
        p (float): probability of the image being flipped. Default value is 0.5
    """

    def __init__(self, p=0.5):
        self.p = p

    def __call__(self, imgs):
        """
        Args:
            imgs (numpy.array): Video to be flipped.
        Returns:
            numpy.array: Randomly flipped video.
        """
        if random.random() < self.p:
            imgs = imgs[::-1]
        return imgs

    def __repr__(self):
        return self.__class__.__name__ + '(p={})'.format(self.p)


class RandomRotationVideo(object):
    """Rotate the given video frames randomly with a given degree.
    Args:
        degree (float): degrees used to rotate the video
    """

    def __init__(self, degree=10):
        self.degree = degree

    def __call__(self, imgs):
        """
        Args:
            imgs (numpy.array): Video to be rotated.
        Returns:
            numpy.array: Randomly rotated video.
        """
        h, w = imgs[0].shape[:2]
        degree_sampled = np.random.choice(
                            np.arange(-self.degree, self.degree, 0.5))
        M = cv2.getRotationMatrix2D((w / 2, h / 2), degree_sampled, 1)

        for idx, img in enumerate(imgs):
            imgs[idx] = cv2.warpAffine(img, M, (w, h))

        return imgs

    def __repr__(self):
        return self.__class__.__name__ + '(degree={})'.format(self.degree_sampled)


class IdentityTransform(object):
    """
    Returns same video back
    """
    def __init__(self,):
        pass

    def __call__(self, imgs):
        return imgs


class Scale(object):
    r"""Rescale the input image to the given size.
    Args:
        size (sequence or int): Desired output size. If size is a sequence like
            (w, h), output size will be matched to this. If size is an int,
            smaller edge of the image will be matched to this number.
            i.e, if height > width, then image will be rescaled to
            (size * height / width, size)
        interpolation (int, optional): Desired interpolation. Default is
            ``cv2.INTER_LINEAR``
    """

    def __init__(self, size, interpolation=cv2.INTER_LINEAR):
        assert isinstance(size, int) or (isinstance(
            size, collections.Iterable) and len(size) == 2)
        self.size = size
        self.interpolation = interpolation

    def __call__(self, img):
        """
        Args:
            img (numpy.array): Image to be scaled.
        Returns:
            numpy.array: Rescaled image.
        """
        if isinstance(self.size, int):
            h, w = img.shape[:2]
            if (w <= h and w == self.size) or (h <= w and h == self.size):
                return img
            if w < h:
                ow = self.size
                oh = int(self.size * h / w)
                if ow < w:
                    return cv2.resize(img, (ow, oh), cv2.INTER_AREA)
                else:
                    return cv2.resize(img, (ow, oh))
            else:
                oh = self.size
                ow = int(self.size * w / h)
                if oh < h:
                    return cv2.resize(img, (ow, oh), cv2.INTER_AREA)
                else:
                    return cv2.resize(img, (ow, oh))
        else:
            return cv2.resize(img, tuple(self.size))


class Normalize(object):
    """normalize an tensor image with mean and standard deviation.
    Given mean: (R, G, B) and std: (R, G, B),
    will normalize each channel of the torch.*Tensor, i.e.
    channel = (channel x std) + mean
    Args:
        mean (sequence): Sequence of means for R, G, B channels respecitvely.
        std (sequence): Sequence of standard deviations for R, G, B channels
            respecitvely.
    """

    def __init__(self, mean, std):
        self.mean = np.array(mean).astype('float32')
        self.std = np.array(std).astype('float32')

    def __call__(self, tensor):
        """
        Args:
            tensor (Tensor): Tensor image of size (C, H, W) to be normalized.
        Returns:
            Tensor: Normalized image.
        """
        tensor = (tensor-self.mean)/ self.std
        return tensor

In [5]:
class Augmentor(object):
    def __init__(self, augmentation_mappings_json=None,
                 augmentation_types_todo=None,
                 fps_jitter_factors=[1, 0.75, 0.5]):
        self.augmentation_mappings_json = augmentation_mappings_json
        self.augmentation_types_todo = augmentation_types_todo
        self.fps_jitter_factors = fps_jitter_factors

        # read json to get the mapping dict
        self.augmentation_mapping = self.read_augmentation_mapping(
                                        self.augmentation_mappings_json)
        self.augmentation_transforms = self.define_augmentation_transforms()

    def __call__(self, imgs, label):
        if not self.augmentation_mapping:
            return imgs, label
        else:
            candidate_augmentations = {"same": label}
            for candidate in self.augmentation_types_todo:
                if candidate == "jitter_fps":
                    continue
                if label in self.augmentation_mapping[candidate]:
                    if isinstance(self.augmentation_mapping[candidate], list):
                        candidate_augmentations[candidate] = label
                    elif isinstance(self.augmentation_mapping[candidate], dict):
                        candidate_augmentations[candidate] = self.augmentation_mapping[candidate][label]
                    else:
                        print("Something wrong with data type specified in "
                              "augmentation file. Please check!")
            augmentation_chosen = np.random.choice(list(candidate_augmentations.keys()))
            imgs = self.augmentation_transforms[augmentation_chosen](imgs)
            label = candidate_augmentations[augmentation_chosen]

            return imgs, label

    def read_augmentation_mapping(self, path):
        if path:
            with open(path, "rb") as fp:
                mapping = json.load(fp)
        else:
            mapping = None
        return mapping

    def define_augmentation_transforms(self, ):
        augmentation_transforms = {}
        augmentation_transforms["same"] = IdentityTransform()
        augmentation_transforms["left/right"] = RandomHorizontalFlipVideo(1)
        augmentation_transforms["left/right agnostic"] = RandomHorizontalFlipVideo(1)
        augmentation_transforms["reverse time"] = RandomReverseTimeVideo(1)
        augmentation_transforms["reverse time agnostic"] = RandomReverseTimeVideo(0.5)

        return augmentation_transforms

    def jitter_fps(self, framerate):
        if self.augmentation_types_todo and "jitter_fps" in self.augmentation_types_todo:
            jitter_factor = np.random.choice(self.fps_jitter_factors)
            return int(jitter_factor * framerate)
        else:
            return framerate

#### Data Generator

In [6]:
ListData = namedtuple('ListData', ['id', 'label', 'path'])


class DatasetBase(object):
    """
    To read json data and construct a list containing video sample `ids`,
    `label` and `path`
    """
    def __init__(self, json_path_input, json_path_labels, data_root,
                 extension, is_test=False):
        self.json_path_input = json_path_input
        self.json_path_labels = json_path_labels
        self.data_root = data_root
        self.extension = extension
        self.is_test = is_test

        # preparing data and class dictionary
        self.classes = self.read_json_labels()
        self.classes_dict = self.get_two_way_dict(self.classes)
        self.json_data = self.read_json_input()

    def read_json_input(self):
        json_data = []
        if not self.is_test:
            with open(self.json_path_input, 'rb') as jsonfile:
                json_reader = json.load(jsonfile)
                for elem in json_reader:
                    label = self.clean_template(elem['template'])
                    if label not in self.classes:
                        raise ValueError("Label mismatch! Please correct")
                    item = ListData(elem['id'],
                                    label,
                                    os.path.join(self.data_root,
                                                 elem['id'] + self.extension)
                                    )
                    json_data.append(item)
        else:
            with open(self.json_path_input, 'rb') as jsonfile:
                json_reader = json.load(jsonfile)
                for elem in json_reader:
                    # add a dummy label for all test samples
                    item = ListData(elem['id'],
                                    "Holding something",
                                    os.path.join(self.data_root,
                                                 elem['id'] + self.extension)
                                    )
                    json_data.append(item)
        return json_data

    def read_json_labels(self):
        classes = []
        with open(self.json_path_labels, 'rb') as jsonfile:
            json_reader = json.load(jsonfile)
            for elem in json_reader:
                classes.append(elem)
        return sorted(classes)

    def get_two_way_dict(self, classes):
        classes_dict = {}
        for i, item in enumerate(classes):
            classes_dict[item] = i
            classes_dict[i] = item
        return classes_dict

    def clean_template(self, template):
        """ Replaces instances of `[something]` --> `something`"""
        template = template.replace("[", "")
        template = template.replace("]", "")
        return template


class WebmDataset(DatasetBase):
    def __init__(self, json_path_input, json_path_labels, data_root,
                 is_test=False):
        EXTENSION = ".webm"
        super().__init__(json_path_input, json_path_labels, data_root,
                         EXTENSION, is_test)


class I3DFeatures(DatasetBase):
    def __init__(self, json_path_input, json_path_labels, data_root,
                 is_test=False):
        EXTENSION = ".npy"
        super().__init__(json_path_input, json_path_labels, data_root,
                         EXTENSION, is_test)


class ImageNetFeatures(DatasetBase):
    def __init__(self, json_path_input, json_path_labels, data_root,
                 is_test=False):
        EXTENSION = ".npy"
        super().__init__(json_path_input, json_path_labels, data_root,
                         EXTENSION, is_test)

In [7]:
FRAMERATE = 12  # default value
class VideoFolder(tf.keras.utils.Sequence):

    def __init__(self, root, json_file_input, json_file_labels, clip_size,
                 nclips, step_size, is_val, batch_size=32, transform_pre=None, 
                 transform_post=None, augmentation_mappings_json=None, 
                 augmentation_types_todo=None, get_item_id=False, is_test=False):
      
        self.dataset_object = WebmDataset(json_file_input, json_file_labels,
                                          root, is_test=is_test)
        self.json_data = self.dataset_object.json_data
        random.shuffle(self.json_data)
        self.classes = self.dataset_object.classes
        self.classes_dict = self.dataset_object.classes_dict
        self.root = root
        self.transform_pre = transform_pre
        self.transform_post = transform_post
        self.batch_size = batch_size

        self.augmentor = Augmentor(augmentation_mappings_json,
                                   augmentation_types_todo)

        self.clip_size = clip_size
        self.nclips = nclips
        self.step_size = step_size
        self.is_val = is_val
        self.get_item_id = get_item_id

    def __getitem__(self, index):
        """
        [!] FPS jittering doesn't work with AV dataloader as of now
        """
        batch = self.json_data[index*self.batch_size: (index+1)*self.batch_size] 
        X, Y, item_id = [], [], []

        for item in batch:
            # Open video file
            reader = av.open(item.path)

            try:
                imgs = []
                imgs = [f.to_rgb().to_ndarray() for f in reader.decode(video=0)]
            except (RuntimeError, ZeroDivisionError) as exception:
                print('{}: WEBM reader cannot open {}. Empty '
                      'list returned.'.format(type(exception).__name__, item.path))

            if(self.transform_pre is not None):
                imgs = self.transform_pre(imgs)
            
            imgs, label = self.augmentor(imgs, item.label)
            
            if(self.transform_post is not None):
                imgs = self.transform_post(imgs)

            num_frames = len(imgs)
            target_idx = self.classes_dict[label]

            if self.nclips > -1:
                num_frames_necessary = self.clip_size * self.nclips * self.step_size
            else:
                num_frames_necessary = num_frames
            offset = 0
            if num_frames_necessary < num_frames:
                # If there are more frames, then sample starting offset.
                diff = (num_frames - num_frames_necessary)
                # temporal augmentation
                if not self.is_val:
                    offset = np.random.randint(0, diff)

            imgs = imgs[offset: num_frames_necessary + offset: self.step_size]

            if len(imgs) < (self.clip_size * self.nclips):
                imgs.extend([imgs[-1]] *
                            ((self.clip_size * self.nclips) - len(imgs)))

            X.append(imgs)
            Y.append(target_idx)
            item_id.append(item.id)

        if self.get_item_id:
            return (np.array(X), np.array(Y), np.array(item_id))
        else:
            return (np.array(X), np.array(Y))

    def __len__(self):
        return int(len(self.json_data)/self.batch_size)

#### Model

In [10]:
def scaled_dot_product_attention(q, k, v):
    matmul_qk = tf.matmul(q, k, transpose_b=True)  # (..., seq_len_q, seq_len_k)

    # scale matmul_qk
    dk = tf.cast(tf.shape(k)[-1], tf.float32)
    scaled_attention_logits = matmul_qk / tf.math.sqrt(dk)

    # softmax is normalized on the last axis (seq_len_k) so that the scores
    # add up to 1.
    attention_weights = tf.nn.softmax(scaled_attention_logits, axis=-1)  # (..., seq_len_q, seq_len_k)

    output = tf.matmul(attention_weights, v)  # (..., seq_len_q, depth_v)

    return output


class Attention(tf.keras.layers.Layer):
    def __init__(self, d_model):
        super(Attention, self).__init__()
        self.d_model = d_model

        self.wq = tf.keras.layers.Dense(d_model)
        self.wk = tf.keras.layers.Dense(d_model)
        self.wv = tf.keras.layers.Dense(d_model)

        self.dense = tf.keras.layers.Dense(d_model)


    def call(self, k, v, q, training=True):
        batch_size = tf.shape(q)[0]

        q = self.wq(q)  # (batch_size, seq_len, d_model)
        k = self.wk(k)  # (batch_size, seq_len, d_model)
        v = self.wv(v)  # (batch_size, seq_len, d_model)

        # scaled_attention.shape == (batch_size, seq_len_q, d_model)
        scaled_attention = scaled_dot_product_attention(q, k, v)

        output = self.dense(scaled_attention)  # (batch_size, seq_len_q, d_model)

        return output

In [12]:
class CONVAttentionCell(tf.keras.layers.Layer):
    def __init__(self, d_model, filters,):
        super(CONVAttentionCell, self).__init__()
        
        self.f_conv = tf.keras.layers.Conv2D(filters, 11, padding='same', activation='tanh')
        self.c_conv = tf.keras.layers.Conv2D(filters, 11, padding='same', activation='tanh')
        self.attention = Attention(d_model)

    @tf.function
    def call(self, input, hidden, training=True):
        #input -> (b, h, w, c)
        shapes = tf.shape(input)
        h, c = hidden

        vec = input + h

        f = self.f_conv(vec)
        c = f+c
        c = self.c_conv(c)
        
        q = tf.reshape(input, (shapes[0], -1, shapes[-1]))
        kv = tf.reshape(c, (shapes[0], -1, shapes[-1]))

        o = self.attention(kv, kv, q)
        o = tf.reshape(o, shapes)

        return o, [o, c]

class CONVAttention(tf.keras.layers.Layer):
    def __init__(self, d_model, filters, return_state=False, return_sequences=False, **kwargs):
        super(CONVAttention, self).__init__(**kwargs)
        
        self.filters = filters
        self.return_sequences=return_sequences
        self.return_state=return_state

        self.cell = CONVAttentionCell(d_model, filters)
        
    @tf.function
    def call(self, input, training=True):
        # input -> (b,s,h,w,c)

        shapes = tf.shape(input)
        h = tf.zeros(shape=(shapes[0], shapes[2], shapes[3], self.filters))
        c = tf.zeros(shape=(shapes[0], shapes[2], shapes[3], self.filters))
        hidden = [h,c]

        input = tf.transpose(input, (1, 0, 2, 3, 4))
        # input -> (s,b,h,w,c)

        y = tf.TensorArray(dtype=tf.float32, size=shapes[1]) #(s)
        for i in range(shapes[1]):
            out, hidden = self.cell(input[i], hidden, training)
            y = y.write(i, out)

        y = tf.transpose(y.stack(), (1, 0, 2, 3, 4)) #(b,s,h,w,c)
        if(not self.return_sequences):
            y = y[:, -1] 

        if(self.return_state):
            return hidden, y
        else:
            return y

class IdentityBlock(tf.keras.layers.Layer):
    def __init__(self, kernel_size, filters, stage, block, **kwargs):
        super(IdentityBlock, self).__init__(**kwargs)
        
        filters1, filters2, filters3 = filters
        bn_axis = 3
        conv_name_base = 'res' + str(stage) + block + '_branch'
        bn_name_base = 'bn' + str(stage) + block + '_branch'

        self.conv1 = tf.keras.layers.Conv2D(filters1, (1, 1), 
                                  kernel_initializer='he_normal',
                                  name=conv_name_base + '2a')
        self.bn1 = tf.keras.layers.BatchNormalization(axis=bn_axis, name=bn_name_base + '2a')
        self.activ1 = tf.keras.layers.Activation('sigmoid')

        self.conv2 = tf.keras.layers.Conv2D(filters2, kernel_size,
                                   padding='same',
                                   kernel_initializer='he_normal',
                                   name=conv_name_base + '2b')
        self.bn2 = tf.keras.layers.BatchNormalization(axis=bn_axis, name=bn_name_base + '2b')
        self.activ2 = tf.keras.layers.Activation('sigmoid')

        self.conv3 = tf.keras.layers.Conv2D(filters3, (1, 1),
                                  kernel_initializer='he_normal',
                                  name=conv_name_base + '2c')
        self.bn3 = tf.keras.layers.BatchNormalization(axis=bn_axis, name=bn_name_base + '2c')

        self.activ3 = tf.keras.layers.Activation('sigmoid')

    @tf.function
    def call(self, x):
        shortcut = x

        x = self.conv1(x)
        x = self.bn1(x)
        x = self.activ1(x)

        x = self.conv2(x)
        x = self.bn2(x)
        x = self.activ2(x)

        x = self.conv3(x)
        x = self.bn3(x)

        x = tf.keras.layers.add([shortcut, x])
        x = self.activ3(x)

        return x

class ResidualBlock(tf.keras.layers.Layer):
    def __init__(self,  kernel_size, filters, stage, block, strides=(2, 2), **kwargs):
        super(ResidualBlock, self).__init__(**kwargs)
        
        filters1, filters2, filters3 = filters
        bn_axis = 3
        conv_name_base = 'res' + str(stage) + block + '_branch'
        bn_name_base = 'bn' + str(stage) + block + '_branch'

        self.conv1 = tf.keras.layers.Conv2D(filters1, (1, 1), 
                                  strides=strides,
                                  kernel_initializer='he_normal',
                                  name=conv_name_base + '2a')
        self.bn1 = tf.keras.layers.BatchNormalization(axis=bn_axis, name=bn_name_base + '2a')
        self.activ1 = tf.keras.layers.Activation('sigmoid')

        self.conv2 = tf.keras.layers.Conv2D(filters2, kernel_size,
                                   padding='same',
                                   kernel_initializer='he_normal',
                                  name=conv_name_base + '2b')
        self.bn2 = tf.keras.layers.BatchNormalization(axis=bn_axis, name=bn_name_base + '2b')
        self.activ2 = tf.keras.layers.Activation('sigmoid')

        self.conv3 = tf.keras.layers.Conv2D(filters3, (1, 1),
                                  kernel_initializer='he_normal',
                                  name=conv_name_base + '2c')
        self.bn3 = tf.keras.layers.BatchNormalization(axis=bn_axis, name=bn_name_base + '2c')

        self.short_conv = tf.keras.layers.Conv2D(filters3, (1, 1), 
                                        strides=strides,
                                        kernel_initializer='he_normal',
                                        name=conv_name_base + '1')
        self.short_bn = tf.keras.layers.BatchNormalization(axis=bn_axis, name=bn_name_base + '1')
        self.activ3 = tf.keras.layers.Activation('sigmoid')

    @tf.function
    def call(self, x):
        shortcut = x

        x = self.conv1(x)
        x = self.bn1(x)
        x = self.activ1(x)

        x = self.conv2(x)
        x = self.bn2(x)
        x = self.activ2(x)

        x = self.conv3(x)
        x = self.bn3(x)

        shortcut = self.short_conv(shortcut)
        shortcut = self.short_bn(shortcut)
        
        x = tf.keras.layers.add([shortcut, x])
        x = self.activ3(x)

        return x

In [17]:
class TimeCNN(tf.keras.Model):
    def __init__(self, input_size, seq_len, **kwargs):
        super(TimeCNN, self).__init__(**kwargs)
        
        self.seq_len = seq_len

        self.conv_lstm_1 = CONVAttention(64*64, 64,  return_sequences=True, return_state=True)
        self.conv_lstm_2 = CONVAttention(32*32, 128, return_sequences=True, return_state=True)
        self.conv_lstm_3 = CONVAttention(16*16, 256, return_sequences=True, return_state=True)
        self.conv_lstm_4 = CONVAttention(8*8,   512, return_state=True)

        self.block1 = tf.keras.Sequential([
                                        tf.keras.layers.Conv2D(64, (7, 7),
                                                      strides=(2, 2),
                                                      padding='same',
                                                      kernel_initializer='he_normal',
                                                      name='conv1'),
                                        tf.keras.layers.BatchNormalization(axis=3, name='bn_conv1'),
                                        tf.keras.layers.Activation('relu'),
        ])

        self.block2 = tf.keras.Sequential([
                                          ResidualBlock(3, [64, 64, 128], stage=2, block='a', strides=(1, 1)),
                                          IdentityBlock(3, [64, 64, 128], stage=2, block='b'),
                                          IdentityBlock(3, [64, 64, 128], stage=2, block='c')
        ])

        self.block3 = tf.keras.Sequential([
                                          ResidualBlock(3, [128, 128, 256], stage=3, block='a'),
                                          IdentityBlock(3, [128, 128, 256], stage=3, block='b'),
                                          IdentityBlock(3, [128, 128, 256], stage=3, block='c'),
                                          IdentityBlock(3, [128, 128, 256], stage=3, block='d')
        ])

        self.block4 = tf.keras.Sequential([
                                          ResidualBlock(3, [256, 256, 512], stage=4, block='a'),
                                          IdentityBlock(3, [256, 256, 512], stage=4, block='b'),
                                          IdentityBlock(3, [256, 256, 512], stage=4, block='c'),
                                          IdentityBlock(3, [256, 256, 512], stage=4, block='d'),
                                          IdentityBlock(3, [256, 256, 512], stage=4, block='e'),
                                          IdentityBlock(3, [256, 256, 512], stage=4, block='f')
        ])

        # self.block5 = tf.keras.Sequential([
        #                                   ResidualBlock(3, [512, 512, 1024], stage=5, block='a'),
        #                                   IdentityBlock(3, [512, 512, 1024], stage=5, block='b'),
        #                                   IdentityBlock(3, [512, 512, 1024], stage=5, block='c')
        # ])

                                           
        self.flatten = tf.keras.layers.Flatten()
        self.avg_pool = tf.keras.layers.GlobalAveragePooling2D()
        self.linear1 = tf.keras.layers.Dense(2048)
        self.out = tf.keras.layers.Dense(174, activation='softmax')

        self.batch_to_time = Rearrange('(b t) h w c -> b t h w c', t=self.seq_len)
        self.time_to_batch = Rearrange('b t h w c -> (b t) h w c', t=self.seq_len)

    @tf.function
    def call(self, x, training=True):
        
        # Block1
        x = self.time_to_batch(x)
        x = self.block1(x)
        x = self.batch_to_time(x)
        print(x.shape)

        h1, x = self.conv_lstm_1(x)
        h1 = [self.flatten(self.avg_pool(h1[1]))]

        # Block2
        x = self.time_to_batch(x)
        x = self.block2(x)
        x = self.batch_to_time(x)
        print(x.shape)

        h2, x = self.conv_lstm_2(x)
        h2 = [self.flatten(self.avg_pool(h2[1]))]

        # Block3
        x = self.time_to_batch(x)
        x = self.block3(x)
        x = self.batch_to_time(x)
        
        h3, x = self.conv_lstm_3(x)
        h3 = [self.flatten(self.avg_pool(h3[1]))]

        # Block4
        x = self.time_to_batch(x)
        x = self.block4(x)
        x = self.batch_to_time(x)

        h4, x = self.conv_lstm_4(x)
        h4 = [self.flatten(self.avg_pool(h4[1]))]
        
        # x = self.block5(x)
        # x = [self.flatten(self.avg_pool(x))]

        x = tf.concat(h1 + h2 + h3 + h4, axis=-1)

        x = self.linear1(x)
        x = self.out(x)

        return x

In [18]:
model = TimeCNN(128, 36)
yhat = model(np.random.randn(2, 36, 128, 128, 3))
model.summary()

(2, 36, 64, 64, 64)


ValueError: ignored

In [14]:
yhat.shape

TensorShape([2, 174])

#### Callbacks

In [10]:
callbacks = [tf.keras.callbacks.ModelCheckpoint(f'{experiment_dir}/best_model.h5', 
                                                save_best_only=True, save_weights_only=True, verbose=1, 
                                                monitor='loss', mode='min'),
             tf.keras.callbacks.ReduceLROnPlateau(monitor='loss', patience=2),
             tf.keras.callbacks.EarlyStopping(monitor='top_5', patience=10, mode='max'),
             tf.keras.callbacks.TensorBoard(log_dir=f'{experiment_dir}/logs')]

#### Init Data Generators

In [22]:
upscale_size = int(128* 1.1)
transform_pre = ComposeMix([
        [Scale(upscale_size), "img"],
        [RandomCropVideo(128), "vid"],
        [tf.keras.applications.resnet.preprocess_input,"img"],
        [Normalize((0,0,0), (255,255,255)), "img"]
          ])

# identity transform

train_loader = VideoFolder(root="/content/20bn-something-something-v2",
                      json_file_input="/content/meta/something-something-v2-train.json",
                      json_file_labels="/content/meta/something-something-v2-labels.json",
                      clip_size=36,
                      nclips=1,
                      step_size=2,
                      is_val=False,
                      transform_pre=transform_pre,
                      batch_size=4,
                      )

val_loader = VideoFolder(root="/content/20bn-something-something-v2",
                      json_file_input="/content/meta/something-something-v2-validation.json",
                      json_file_labels="/content/meta/something-something-v2-labels.json",
                      clip_size=36,
                      nclips=1,
                      step_size=2,
                      is_val=True,
                      transform_pre=transform_pre,
                      batch_size=2,
                      )

In [23]:
x, y = train_loader[0]

In [24]:
x.shape, y

((4, 128, 4608, 3), array([119, 167, 129,   0]))

In [25]:
x.min(), x.max()

(-0.4850196, 0.5923961)

#### Training

In [20]:
model.compile(optimizer=tf.keras.optimizers.SGD(lr=1e-3), loss=tf.keras.losses.SparseCategoricalCrossentropy(), 
              metrics=[Metrics.SparseTopKCategoricalAccuracy(k=1, name='top_1'),
                       Metrics.SparseTopKCategoricalAccuracy(k=5, name='top_5'),
                       ])

In [None]:
history = model.fit(train_loader, 
                    epochs = 100, 
                    callbacks = callbacks,
                    workers = 6
                    ) 