In [1]:
import pickle, sys
from tqdm import tqdm
import torch, torchvision

from datasets import Fashion200k
from img_text_composition_models import TIRG

In [2]:
embed_dim = 512
dataset_path = '/Users/bo/Downloads/200k/'
model_path = '/Users/bo/Downloads/checkpoint_fashion200k.pth'
text_path = '/Users/bo/Downloads/texts.pkl'

In [3]:
trainset = Fashion200k(path=dataset_path, 
                       split='train',        
                        transform=torchvision.transforms.Compose([
                        torchvision.transforms.Resize(224),
                        torchvision.transforms.CenterCrop(224),
                        torchvision.transforms.ToTensor(),
                        torchvision.transforms.Normalize([0.485, 0.456, 0.406],
                                                         [0.229, 0.224, 0.225])
                    ]))

testset = Fashion200k(path=dataset_path, 
                       split='test',        
                        transform=torchvision.transforms.Compose([
                        torchvision.transforms.Resize(224),
                        torchvision.transforms.CenterCrop(224),
                        torchvision.transforms.ToTensor(),
                        torchvision.transforms.Normalize([0.485, 0.456, 0.406],
                                                         [0.229, 0.224, 0.225])
                    ]))

texts =  trainset.get_all_texts()

read pants_train_detect_all.txt
read dress_train_detect_all.txt
read jacket_train_detect_all.txt
read skirt_train_detect_all.txt
read top_train_detect_all.txt
Fashion200k: 172049 images
53099 unique cations
Modifiable images 106464
read top_test_detect_all.txt
read pants_test_detect_all.txt
read dress_test_detect_all.txt
read skirt_test_detect_all.txt
read jacket_test_detect_all.txt
Fashion200k: 29789 images


In [4]:
__copyright__ = "Copyright (c) 2020 Jina AI Limited. All rights reserved."
__license__ = "Apache-2.0"

import os
import sys
import pickle
from typing import List

import numpy as np
from jina.executors.devices import TorchDevice
from jina.excepts import PretrainedModelFileDoesNotExist
from jina.executors.decorators import batching_multi_input, as_ndarray
from jina.executors.encoders.multimodal import BaseMultiModalEncoder

# sys.path.append(".")
from img_text_composition_models import TIRG

class TirgMultiModalEncoder(TorchDevice, BaseMultiModalEncoder):

    def __init__(self, model_path: str,
                 texts_path: str,
                 positional_modality: List[str] = ['visual', 'textual'],
                 channel_axis: int = -1, 
                 *args, **kwargs):
        """
        :param model_path: the path where the model is stored.
        """
        super().__init__(*args, **kwargs)
        self.model_path = model_path
        self.texts_path = texts_path
        self.positional_modality = positional_modality
        self.channel_axis = channel_axis
        # axis 0 is the batch
        self._default_channel_axis = 1

    def post_init(self):
        super().post_init()
        import torch
        if self.model_path and os.path.exists(self.model_path):
            with open (self.texts_path, 'rb') as fp:
                texts = pickle.load(fp)
            self.model = TIRG(texts, 512)
            model_sd = torch.load(self.model_path, map_location=torch.device('cpu'))
            self.model.load_state_dict(model_sd['model_state_dict'])
            self.model.eval()
            self.to_device(self.model)
        else:
            raise PretrainedModelFileDoesNotExist(f'model {self.model_path} does not exist')

    def _get_features(self, data):
        visual_data = data[(self.positional_modality.index('image'))]
        if self.channel_axis != self._default_channel_axis:
            visual_data = np.moveaxis(visual_data, self.channel_axis, self._default_channel_axis)
        textual_data = data[(self.positional_modality.index('text'))]
        
        visual_data = torch.stack(visual_data).float()

        if self.on_gpu:
            visual_data = visual_data.cuda()
            textual_data = textual_data.cuda()
            
        img_features = self.model.extract_img_feature(visual_data)
        text_features = self.model.extract_text_feature(textual_data)
        
        return self.model.compose_img_text_features(img_features, text_features)

    @batching_multi_input
    @as_ndarray
    def encode(self, *data: 'np.ndarray', **kwargs) -> 'np.ndarray':
        import torch
        feature = self._get_features(*data).detach()
        if self.on_gpu:
            feature = feature.cpu()
        feature = feature.numpy()
        return feature



In [5]:
encoder = TirgMultiModalEncoder(model_path, text_path, positional_modality = ['image', 'text'], channel_axis=1)

TirgMultiModalEncoder@43638[I]:[37mpost initiating, this may take some time...[0m
TirgMultiModalEncoder@43638[I]:[37mpost initiating, this may take some time takes 1 second (1.87s)[0m


## Example, encode one single image

In [6]:
data = []
data.append([trainset.get_img(0)]) # image at position 0
data.append([texts[0]]) # text at position 1
encoded_multimodal = encoder.encode(data)

In [7]:
encoded_multimodal.shape

(1, 512)

## Encode a batch of images

In [8]:
batch_size=32
fashion_200k_loader = trainset.get_loader(batch_size=batch_size)
for batch in fashion_200k_loader:
    # use multimodal encoder
    data = []
    assert len(batch) == batch_size
    batch_of_imgs = [item['source_img_data'] for item in batch]
    batch_of_text = [item['source_caption'] for item in batch]
    data.append(batch_of_imgs)
    data.append(batch_of_text)
    assert len(data) == 2
    encoded_batch = encoder.encode(data)
    assert len(encoded_batch) == batch_size
    break

## Test

1. Ensure encoder works the same as the original model at instance level
2. Ensure encoder works the same as the original model at batch level

In [9]:
# Initialize the model
model = TIRG(texts, embed_dim)
model_sd = torch.load(model_path, map_location=torch.device('cpu'))
model.load_state_dict(model_sd['model_state_dict'])
_ = model.eval()

In [12]:
# Ensure encoded result is correct at instance level
for i in range(10):
    print(f"testing multimodal encoder with img {i}")
    # extract feature via jina encoder
    data = []
    data.append([trainset.get_img(i)]) # visual at position 0
    data.append([texts[i]]) # textual at position 1
    encoded = encoder.encode(data)
    # extract image text feature
    text_feature = model.extract_text_feature([texts[i]])
    img = [trainset.get_img(i)]
    img = torch.stack(img).float()
    img_feature =  model.extract_img_feature(img)
    extracted = model.compose_img_text_features(img_feature, text_feature).cpu().detach().numpy()
    assert encoded.shape == extracted.shape
    assert encoded.all() == extracted.all()

testing multimodal encoder with img 0
testing multimodal encoder with img 1
testing multimodal encoder with img 2
testing multimodal encoder with img 3
testing multimodal encoder with img 4
testing multimodal encoder with img 5
testing multimodal encoder with img 6
testing multimodal encoder with img 7
testing multimodal encoder with img 8
testing multimodal encoder with img 9


In [22]:
# Ensure encoded result is correct at batch level
batch_size=64
fashion_200k_loader = trainset.get_loader(batch_size=batch_size)
for batch in tqdm(fashion_200k_loader):
    print(f"testing multimodal encoder with batch size {batch_size}")
    # use multimodal encoder
    data = []
    assert len(batch) == batch_size
    batch_of_imgs = [item['source_img_data'] for item in batch]
    batch_of_text = [item['source_caption'] for item in batch]
    data.append(batch_of_imgs)
    data.append(batch_of_text)
    assert len(data) == 2
    encoded_batch = encoder.encode(data)
    # use the original model
    batch_of_text_features = model.extract_text_feature(batch_of_text)
    batch_of_imgs = torch.stack(batch_of_imgs).float()
    batch_of_img_features =  model.extract_img_feature(batch_of_imgs)
    extracted_batch = model.compose_img_text_features(batch_of_img_features, batch_of_text_features).cpu().detach().numpy()
    assert len(extracted_batch) == batch_size
    assert extracted_batch.all() == encoded_batch.all()
    break

  0%|          | 0/2689 [00:00<?, ?it/s]

testing multimodal encoder with batch size 64


  0%|          | 0/2689 [00:11<?, ?it/s]
