# BigTransfer (BiT) for Video Semantic Search in Large Scale

By Han Xiao

In this colab, we will show you how to use one of the BiT models (a ResNet50 trained on ImageNet-21k) together with [Jina](https://get.jina.ai) to build a video semantic search system.

## Dataset

The data we are using is the [Tumblr GIF (TGIF) dataset](http://raingo.github.io/TGIF-Release/), which contains 100K animated GIFs and 120K sentences describing their visual content. Our problem is the following: **given a video database and a query video, find the top-k semantically related videos from the database.**

|  |   | |
|:---:|:---:|:---:|
| ![](https://hanxiao.io/2019/11/22/Video-Semantic-Search-in-Large-Scale-using-GNES-and-TF-2-0/tumblr_njqj3bMKQF1unc0x7o1_250.gif)| ![](https://hanxiao.io/2019/11/22/Video-Semantic-Search-in-Large-Scale-using-GNES-and-TF-2-0/tumblr_ni35trgNe41tmk5mfo1_400.gif) | ![](https://hanxiao.io/2019/11/22/Video-Semantic-Search-in-Large-Scale-using-GNES-and-TF-2-0/tumblr_nb2mucKMeU1tkz79uo1_250.gif) |
|A well-dressed young guy with gelled red hair glides across a room and scans it with his eyes. | a woman in a car is singing. | a man wearing a suit smiles at something in the distance. |

## Problem Formulation 

“Semantic” is a casual and ambiguous word, I know. Depending on your application and scenario, it could mean similar in motion (e.g. sports videos), similar in emotion (e.g. memes), and so on. For now I will simply treat semantically related as visually similar.

The text descriptions of the videos, though potentially very useful, are ignored for now. We are not building a cross-modality search solution (e.g. from text to video or vice versa), nor do we use any textual information when building the video search solution. Nonetheless, the text descriptions can be used to evaluate and compare the effectiveness of the system quantitatively.

Putting the problem into the neural search framework, this breaks down into the following steps:

> **Index time**
1. segment each video into workable semantic units (aka ["Chunk"](https://github.com/jina-ai/jina/tree/master/docs/chapters/101#document--chunk));
2. encode each chunk as a fixed-length vector;
3. store all vector representations in a vector database.

> **Query time**
1. repeat steps `1` and `2` of the index time for each incoming query;
2. retrieve relevant chunks from the database;
3. aggregate the chunk-level scores back to the document level;
4. return the top-k results to users.
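
Step 3 of the query time is probably the least obvious one, so here is a minimal sketch of it. It assumes that every retrieved chunk comes back as a `(doc_id, score)` pair where a higher score means more similar, and it simply sums the scores per document. The function `aggregate_chunk_scores` and the sample ids are made up for illustration; the ranker used later in the query flow may score documents differently.

```python
from collections import defaultdict
from typing import Dict, List, Tuple


def aggregate_chunk_scores(chunk_hits: List[Tuple[str, float]], top_k: int) -> List[Tuple[str, float]]:
    """Sum chunk-level similarity scores per document and return the top-k documents."""
    doc_scores: Dict[str, float] = defaultdict(float)
    for doc_id, score in chunk_hits:
        doc_scores[doc_id] += score
    # sort by score (descending); break ties by doc id to get a stable order
    ranked = sorted(doc_scores.items(), key=lambda kv: (-kv[1], kv[0]))
    return ranked[:top_k]


# hypothetical hits: three retrieved chunks belong to a.gif, one to b.gif
hits = [('a.gif', 0.5), ('b.gif', 0.75), ('a.gif', 0.25), ('a.gif', 0.25)]
top_docs = aggregate_chunk_scores(hits, top_k=2)
print(top_docs)  # [('a.gif', 1.0), ('b.gif', 0.75)]
```

Summing rewards documents in which many frames match the query; averaging or taking the maximum are equally valid choices, depending on what "related" should mean.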


In [0]:
#@title Imports
import tensorflow as tf
import tensorflow_hub as hub

import tensorflow_datasets as tfds

import time

from PIL import Image
import requests
from io import BytesIO

import matplotlib.pyplot as plt
import numpy as np

import os
import pathlib

## Preprocessing Videos

A good neural search is only possible when the document and the query are comparable semantic units. The preprocessor serves exactly this purpose: it segments a document into a list of semantic units, each of which is called a "chunk" in Jina. For video, a meaningful chunk could be a *frame* or a *shot* (i.e. a series of frames that runs for an uninterrupted period of time). In the Tumblr GIF dataset, most of the animations have fewer than three shots, so I will simply use frames as the chunks that represent a document.



In [0]:
import io
from typing import Dict

import numpy as np
from PIL import Image
from gif_reader import get_frames
from jina.executors.crafters import BaseDocCrafter
from jina.executors.crafters import BaseSegmenter


class GifNameRawSplit(BaseDocCrafter):

    def craft(self, raw_bytes, *args, **kwargs) -> Dict:
        # the input is `<file name>JINA_DELIM<gif bytes>`; split only once so the payload stays intact
        file_name, raw_bytes = raw_bytes.split(b'JINA_DELIM', 1)
        return dict(raw_bytes=raw_bytes, meta_info=file_name)


class GifPreprocessor(BaseSegmenter):
    def __init__(self, img_shape: int = 96, every_k_frame: int = 1, max_frame: int = None, from_bytes: bool = False,
                 *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.img_shape = img_shape
        self.every_k_frame = every_k_frame
        self.max_frame = max_frame
        self.from_bytes = from_bytes

    def craft(self, raw_bytes, doc_id):
        result = []
        try:
            if self.from_bytes:
                # `raw_bytes` is the GIF content itself
                im = Image.open(io.BytesIO(raw_bytes))
            else:
                # `raw_bytes` is the file path, encoded as bytes
                im = Image.open(raw_bytes.decode())
            idx = 0
            for frame in get_frames(im):
                try:
                    # keep every k-th frame whose index is below `max_frame` (if set)
                    if idx % self.every_k_frame == 0 and (
                            self.max_frame is None or idx < self.max_frame):
                        new_frame = frame.convert('RGB').resize([self.img_shape, ] * 2)
                        # scale pixel values to [0, 1]
                        img = (np.array(new_frame) / 255).astype(np.float32)
                        # build the chunk last: if anything above fails, no chunk is added
                        result.append(dict(doc_id=doc_id, offset=idx,
                                           weight=1., blob=img))
                except Exception as ex:
                    self.logger.error(ex)
                finally:
                    idx = idx + 1
        except Exception as ex:
            self.logger.error(ex)
        # return whatever was collected, so a failure yields a (possibly empty) list instead of None
        return result


This preprocessor loads the animation, converts its frames to RGB, resizes each of them to 96x96, scales the pixel values to [0, 1] and stores the result in `doc.chunks.blob` as a `numpy.ndarray`. At the moment we don't implement any keyframe detection in the preprocessor, so every chunk has a uniform weight, i.e. `c.weight=1`.

![](https://hanxiao.io/2019/11/22/Video-Semantic-Search-in-Large-Scale-using-GNES-and-TF-2-0/tumblr_njqj3bMKQF1unc0x7o1_250.gif.jpg)
![](https://hanxiao.io/2019/11/22/Video-Semantic-Search-in-Large-Scale-using-GNES-and-TF-2-0/tumblr_ni35trgNe41tmk5mfo1_400.gif.jpg)
![](https://hanxiao.io/2019/11/22/Video-Semantic-Search-in-Large-Scale-using-GNES-and-TF-2-0/tumblr_nb2mucKMeU1tkz79uo1_250.gif.jpg)
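
To make the effect of `every_k_frame` and `max_frame` concrete, the selection rule used in `craft` can be replayed on plain frame indices. `selected_offsets` is a hypothetical helper written only for this illustration; it is not part of Jina.

```python
from typing import List, Optional


def selected_offsets(num_frames: int, every_k_frame: int = 1,
                     max_frame: Optional[int] = None) -> List[int]:
    """Return the frame indices that the preprocessor would turn into chunks."""
    return [idx for idx in range(num_frames)
            if idx % every_k_frame == 0 and (max_frame is None or idx < max_frame)]


print(selected_offsets(10))                                # [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]
print(selected_offsets(10, every_k_frame=3))               # [0, 3, 6, 9]
print(selected_offsets(10, every_k_frame=3, max_frame=7))  # [0, 3, 6]
```

Note that `max_frame` bounds the frame index, not the number of chunks: with `every_k_frame=3` and `max_frame=7` only three chunks are produced.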

One may think of more sophisticated preprocessors: for example, smart sub-sampling to reduce the number of near-duplicate frames, [seam carving](http://en.wikipedia.org/wiki/Seam_carving) for better cropping and resizing, or image effects and enhancements. Everything is possible, and I will leave these possibilities to the reader.
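
As one possible starting point for the sub-sampling idea, here is a rough sketch that drops a frame when it is almost identical to the last frame that was kept. The function name `drop_near_duplicates`, the mean-absolute-difference criterion and the threshold of `0.05` are all my assumptions for illustration; none of this is used in the example itself.

```python
import numpy as np


def drop_near_duplicates(frames: np.ndarray, threshold: float = 0.05) -> list:
    """Keep a frame only if it differs enough from the last kept frame.

    frames: array of shape (num_frames, height, width, 3) with values in [0, 1].
    Returns the indices of the frames that were kept.
    """
    kept = []
    last = None
    for idx, frame in enumerate(frames):
        # mean absolute pixel difference to the most recently kept frame
        if last is None or np.abs(frame - last).mean() > threshold:
            kept.append(idx)
            last = frame
    return kept


# toy clip: two identical dark frames followed by two identical bright frames
dark = np.zeros((96, 96, 3), dtype=np.float32)
bright = np.ones((96, 96, 3), dtype=np.float32)
clip = np.stack([dark, dark, bright, bright])
print(drop_near_duplicates(clip))  # [0, 2]
```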

## Using BiT to Encode Chunks into Vectors

In the encoding step, we want to represent each chunk by a fixed-length vector. This can be easily done with the [pretrained models provided by BiT](https://github.com/google-research/big_transfer).

Jina already supports BiT; its implementation is as simple as the following:

In [0]:
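import numpy as np
from jina.executors.decorators import batching, as_ndarray
# assumed import path for the base class; adjust it to your Jina version
from jina.executors.encoders.frameworks import BaseCVTFEncoder


class BiTImageEncoder(BaseCVTFEncoder):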
    def __init__(self, model_path: str, channel_axis: int = -1, *args, **kwargs):
        """
        :param model_path: the path of the model in the `SavedModel` format. `model_path` should be a directory
            with the structure shown below. The pretrained model files can be downloaded with

            wget https://storage.googleapis.com/bit_models/Imagenet21k/[model_name]/feature_vectors/saved_model.pb
            wget https://storage.googleapis.com/bit_models/Imagenet21k/[model_name]/feature_vectors/variables/variables.data-00000-of-00001
            wget https://storage.googleapis.com/bit_models/Imagenet21k/[model_name]/feature_vectors/variables/variables.index

            where ``[model_name]`` is one of `R50x1`, `R101x1`, `R50x3`, `R101x3`, `R152x4`.

            .. highlight:: bash
            .. code-block:: bash

                .
                ├── saved_model.pb
                └── variables
                    ├── variables.data-00000-of-00001
                    └── variables.index

        :param channel_axis: the axis id of the color channel; -1 indicates that the channel is the last axis.
                If another value is given, ``np.moveaxis(data, channel_axis, -1)`` is performed before :meth:`encode`.
        """
        super().__init__(*args, **kwargs)
        self.channel_axis = channel_axis
        self.model_path = model_path

    def post_init(self):
        self.to_device()
        import tensorflow as tf
        _model = tf.saved_model.load(self.model_path)
        self.model = _model.signatures['serving_default']
        self._get_input = tf.convert_to_tensor

    @batching
    @as_ndarray
    def encode(self, data: 'np.ndarray', *args, **kwargs) -> 'np.ndarray':
        if self.channel_axis != -1:
            data = np.moveaxis(data, self.channel_axis, -1)
        _output = self.model(self._get_input(data.astype(np.float32)))
        return _output['output_1'].numpy()


The key function `encode()` simply calls the model to extract features. The `batching` decorator is a handy helper that controls the size of the data flowing into the encoder. After all, an OOM error is the last thing you want to see.
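
To illustrate the idea behind batching (and explicitly not how Jina's decorator is implemented), here is a small sketch. `encode_in_batches` and `fake_encoder` are hypothetical names; the fake encoder just averages each frame so that the example runs without TensorFlow.

```python
import numpy as np


def encode_in_batches(encode_fn, data: np.ndarray, batch_size: int) -> np.ndarray:
    """Call `encode_fn` on slices of at most `batch_size` rows and stack the results."""
    outputs = [encode_fn(data[start:start + batch_size])
               for start in range(0, len(data), batch_size)]
    return np.concatenate(outputs, axis=0)


seen_batch_sizes = []


def fake_encoder(batch: np.ndarray) -> np.ndarray:
    """Stand-in for the model: average each 96x96x3 frame down to 3 numbers."""
    seen_batch_sizes.append(len(batch))
    return batch.mean(axis=(1, 2))


frames = np.random.rand(10, 96, 96, 3).astype(np.float32)
vectors = encode_in_batches(fake_encoder, frames, batch_size=4)
print(vectors.shape)     # (10, 3)
print(seen_batch_sizes)  # [4, 4, 2]
```

The encoder never sees more than four frames at a time, so the peak memory use is bounded by the batch size rather than by the length of the GIF.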

### Download Pre-trained BiT model

In this tutorial, we will use the feature-extractor version of a ResNet50x1 model trained on ImageNet-21k.

```bash
#!/usr/bin/env bash

MODEL_NAME="R50x1"
MODEL_DIR="pretrained"
MODEL_VAR_DIR=$MODEL_DIR/variables
mkdir -p ${MODEL_DIR}
mkdir -p ${MODEL_VAR_DIR}

curl https://storage.googleapis.com/bit_models/Imagenet21k/${MODEL_NAME}/feature_vectors/saved_model.pb --output ${MODEL_DIR}/saved_model.pb

curl https://storage.googleapis.com/bit_models/Imagenet21k/${MODEL_NAME}/feature_vectors/variables/variables.data-00000-of-00001 --output ${MODEL_VAR_DIR}/variables.data-00000-of-00001

curl https://storage.googleapis.com/bit_models/Imagenet21k/${MODEL_NAME}/feature_vectors/variables/variables.index --output ${MODEL_VAR_DIR}/variables.index
```

All the necessary pieces are now in `pretrained`. Using BiT as the encoder is easy: you only need to write an `encode.yml` file as follows:
```yaml
!BiTImageEncoder
with:
  model_path: pretrained
  pool_strategy: avg
```

## Indexing Chunks and Documents

For indexing, I will use the built-in chunk indexers and document indexers of Jina. Chunk indexing is essentially vector indexing: we need to store a map from chunk ids to their corresponding vector representations. Simply write a YAML config `chunk.yml` as follows:
```yaml
!ChunkIndexer
components:
  - !NumpyIndexer
    with:
      index_filename: vec.gz
    metas:
      name: vecidx  # a customized name
      workspace: $TEST_WORKDIR
  - !BasePbIndexer
    with:
      index_filename: chunk.gz
    metas:
      name: chunkidx
      workspace: $TEST_WORKDIR
metas:
  name: chunk_compound_indexer
  workspace: $TEST_WORKDIR
```
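
Conceptually, the vector-indexing half of this config boils down to something like the sketch below: chunk ids are kept next to their vectors, and a query returns the ids of the closest vectors. `ToyVectorIndex` is made up for illustration, uses brute-force cosine similarity and assumes non-zero vectors; it says nothing about how `NumpyIndexer` works internally.

```python
import numpy as np


class ToyVectorIndex:
    """Keeps chunk ids next to their vectors and answers top-k queries by brute force."""

    def __init__(self, dim: int):
        self.keys = np.empty(0, dtype=np.int64)
        self.vectors = np.empty((0, dim), dtype=np.float32)

    def add(self, keys: np.ndarray, vectors: np.ndarray) -> None:
        self.keys = np.concatenate([self.keys, keys.astype(np.int64)])
        self.vectors = np.concatenate([self.vectors, vectors.astype(np.float32)])

    def query(self, queries: np.ndarray, top_k: int):
        """Return (chunk ids, cosine similarities), each of shape (num_queries, top_k)."""
        def normalize(x):
            return x / np.linalg.norm(x, axis=1, keepdims=True)

        sims = normalize(queries) @ normalize(self.vectors).T
        # sort each row by descending similarity and keep the first top_k columns
        order = np.argsort(-sims, axis=1)[:, :top_k]
        return self.keys[order], np.take_along_axis(sims, order, axis=1)


index = ToyVectorIndex(dim=2)
index.add(np.array([10, 11, 12]), np.array([[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]))
ids, scores = index.query(np.array([[2.0, 0.1]]), top_k=2)
print(ids)  # [[10 12]]
```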

At query time we are ultimately interested in documents, not chunks, so the map between doc ids and chunk ids should also be stored. This is essentially a key-value database, and a simple Python `Dict` structure will do the job. Again, only a YAML config `doc.yml` is required:

```yaml
!DocPbIndexer
with:
  index_filename: doc.gzip
metas:
  name: doc_indexer  # a customized name
  workspace: $TEST_WORKDIR
```
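
As a rough picture of what such a key-value store holds, here is the plain-`dict` version of the doc-to-chunk map. The ids and the `register` helper are invented for this sketch, and the real indexer persists its data to `index_filename` rather than keeping it only in memory.

```python
from collections import defaultdict
from typing import Dict, List

# doc id -> ids of the chunks (frames) that belong to it
doc_to_chunks: Dict[int, List[int]] = defaultdict(list)
# inverse lookup, handy at query time: chunk id -> doc id
chunk_to_doc: Dict[int, int] = {}


def register(doc_id: int, chunk_ids: List[int]) -> None:
    """Record which chunks were produced from which document."""
    for chunk_id in chunk_ids:
        doc_to_chunks[doc_id].append(chunk_id)
        chunk_to_doc[chunk_id] = doc_id


register(doc_id=1, chunk_ids=[10, 11, 12])
register(doc_id=2, chunk_ids=[20, 21])

# a retrieved chunk can now be traced back to its source GIF
print(chunk_to_doc[21])  # 2
print(doc_to_chunks[1])  # [10, 11, 12]
```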

Note that the doc indexer does not require the encoding step, so document indexing can run in parallel with chunk encoding and indexing.

## Putting Everything Together

### Index Flow

```yaml
!Flow
with:
  logserver: true
pods:
  chunk_seg:
    yaml_path: craft/index-craft.yml
    replicas: $REPLICAS
    read_only: true
  doc_idx:
    yaml_path: index/doc.yml
  tf_encode:
    yaml_path: encode/encode.yml
    needs: chunk_seg
    replicas: $REPLICAS
    read_only: true
  chunk_idx:
    yaml_path: index/chunk.yml
    replicas: $SHARDS
    separated_workspace: true
  join_all:
    yaml_path: _merge
    needs: [doc_idx, chunk_idx]
    read_only: true
```

### Query Flow

```yaml
!Flow
with:
  logserver: true
  read_only: true  # better add this in the query time
pods:
  chunk_seg:
    yaml_path: craft/index-craft.yml
    replicas: $REPLICAS
  tf_encode:
    yaml_path: encode/encode.yml
    replicas: $REPLICAS
  chunk_idx:
    yaml_path: index/chunk.yml
    replicas: $SHARDS
    separated_workspace: true
    polling: all
    reducing_yaml_path: _merge_topk_chunks
    timeout_ready: 100000 # larger timeout, as all the index data is loaded at query time
  ranker:
    yaml_path: BiMatchRanker
  doc_idx:
    yaml_path: index/doc.yml
```
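
With `replicas: $SHARDS` and `polling: all`, the chunk index is split across shards and, as I read this config, every shard answers each query with its own top-k list before the lists are merged. The merging idea itself is simple and can be sketched as follows; `merge_topk` is a hypothetical helper and not the code behind `_merge_topk_chunks`.

```python
import heapq
from typing import List, Tuple

Hit = Tuple[int, float]  # (chunk id, similarity score)


def merge_topk(shard_results: List[List[Hit]], top_k: int) -> List[Hit]:
    """Merge per-shard hit lists into one global top-k list, best score first."""
    all_hits = (hit for hits in shard_results for hit in hits)
    return heapq.nlargest(top_k, all_hits, key=lambda hit: hit[1])


shard_a = [(10, 0.9), (11, 0.4)]
shard_b = [(20, 0.8), (21, 0.7)]
print(merge_topk([shard_a, shard_b], top_k=3))  # [(10, 0.9), (20, 0.8), (21, 0.7)]
```

Because each shard already returns its own best `top_k` hits, the global top-k is guaranteed to be contained in the union of the shard results.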

## Full Example

The full example and results can be [found here](https://github.com/jina-ai/examples/tree/master/tumblr-gif-search).