In [1]:
import os
import pickle
import sys

import torch
import torchvision

# The TIRG model and the Fashion200k dataset live in TirgEncoder/, so cd into it
# to import them as top-level modules. The `finally` always cds back, even if an
# import fails; otherwise the cwd stays inside TirgEncoder/ and re-running this
# cell would try to enter TirgEncoder/TirgEncoder.
os.chdir('TirgEncoder')
try:
    from img_text_composition_models import TIRG
    from datasets import Fashion200k
finally:
    os.chdir('..')

In [2]:
# Run configuration: TIRG embedding width and the Fashion200k root directory.
embed_dim, dataset_path = 512, 'data-all/fashion-200k/'

In [3]:
# Preprocessing:
#   - resize the shorter side to 224 and center-crop to 224x224;
#   - convert to a tensor;
#   - normalize with the usual ImageNet per-channel mean/std.
trainset = Fashion200k(
    path=dataset_path,
    split='train',
    transform=torchvision.transforms.Compose([
        torchvision.transforms.Resize(224),
        torchvision.transforms.CenterCrop(224),
        torchvision.transforms.ToTensor(),
        torchvision.transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                         std=[0.229, 0.224, 0.225]),
    ]),
)

read pants_train_detect_all.txt
read dress_train_detect_all.txt
read jacket_train_detect_all.txt
read skirt_train_detect_all.txt
read top_train_detect_all.txt
Fashion200k: 172049 images
53099 unique cations
Modifiable images 106464


In [4]:
# texts = [t for t in trainset.get_all_texts()]

In [5]:
# with open('texts.pkl', 'wb') as fp:
#     pickle.dump(texts, fp, protocol=pickle.HIGHEST_PROTOCOL)
    
# len(texts)

In [6]:
# Texts passed to TIRG when building the model. They are reused from texts.pkl
# when present. Otherwise they are rebuilt from the training split (as the
# commented-out cells above did) and saved, so Restart & Run All also works on a
# fresh checkout.
# NOTE: pickle.load can execute arbitrary code — only load a texts.pkl you produced yourself.
TEXTS_CACHE = 'texts.pkl'
if os.path.exists(TEXTS_CACHE):
    with open(TEXTS_CACHE, 'rb') as fp:
        texts = pickle.load(fp)
else:
    texts = list(trainset.get_all_texts())
    with open(TEXTS_CACHE, 'wb') as fp:
        pickle.dump(texts, fp, protocol=pickle.HIGHEST_PROTOCOL)

In [7]:
# Build TIRG from the texts (presumably used for its text vocabulary) with the configured
# embedding width; trained weights are loaded in the following cells.
model = TIRG(
    texts,
    embed_dim,
)

In [8]:
# Map every stored tensor onto the CPU, whatever device the checkpoint was saved from.
model_sd = torch.load(
    'checkpoint_fashion200k.pth',
    map_location='cpu',
)

In [9]:
# Restore the trained weights stored under 'model_state_dict'.
model.load_state_dict(
    model_sd['model_state_dict']
)
# Switch to inference mode; binding the returned module to `_` keeps the cell
# from echoing the whole model repr.
_ = model.eval()

In [10]:
# import json

# with open('vocab.json', 'w') as fp:
#     json.dump(model.text_model.vocab.wordcount, fp)

In [11]:
def get_img(self, idx, raw_img=False):
    """Load image ``idx`` of a dataset object as RGB, optionally applying its transform.

    Standalone copy of a dataset image loader. The other cells call
    ``trainset.get_img``, not this function.

    :param self: object providing ``img_path`` (a prefix joined to each file path
        by plain string concatenation, so it should end with a separator),
        ``imgs`` (records with a ``'file_path'`` key) and ``transform`` (a
        callable, or a falsy value for none).
    :param idx: index into ``self.imgs``.
    :param raw_img: if True, return the RGB image and skip ``self.transform``.
    :return: the RGB image when ``raw_img`` is set or no transform is configured,
        otherwise ``self.transform(img)``.
    """
    # PIL was referenced but never imported in this notebook, so any call raised NameError.
    import PIL.Image

    img_path = self.img_path + self.imgs[idx]['file_path']
    with open(img_path, 'rb') as f:
        # PIL opens lazily; convert() reads the pixels while the file is still open.
        img = PIL.Image.open(f).convert('RGB')
    if raw_img:
        return img
    if self.transform:
        img = self.transform(img)
    return img

In [12]:
# Batch a single preprocessed training image and run it through the image branch.
imgs = torch.stack([trainset.get_img(0)]).float()
print(imgs.shape)
img_feature = model.extract_img_feature(imgs)
print(img_feature.shape)

torch.Size([1, 3, 224, 224])
torch.Size([1, 512])


In [13]:
# Encode a toy caption; the text branch is given a batch (list) of strings.
text_feature = model.extract_text_feature(
    ['hello']
)
print(text_feature.shape)

torch.Size([1, 512])


In [14]:
# Fuse the image and text features with TIRG's compose_img_text_features.
img_text_features = model.compose_img_text_features(
    img_feature,
    text_feature,
)
img_text_features.shape

torch.Size([1, 512])

## Image encoder

In [15]:
__copyright__ = "Copyright (c) 2020 Jina AI Limited. All rights reserved."
__license__ = "Apache-2.0"

import os
import numpy as np
import pickle

from jina.executors.decorators import batching, as_ndarray
from jina.executors.encoders.frameworks import BaseTorchEncoder
from jina.excepts import PretrainedModelFileDoesNotExist

from img_text_composition_models import TIRG


class TirgEncoder(BaseTorchEncoder):
    """Jina image encoder backed by the image branch of a pretrained TIRG model.

    The model is built from the texts pickled at ``texts_path`` and the checkpoint
    at ``model_path``, which must hold a ``'model_state_dict'`` entry.
    """

    def __init__(self, model_path: str,
                 texts_path: str,
                 channel_axis: int = -1,
                 *args, **kwargs):
        """
        :param model_path: the path where the model checkpoint is stored.
        :param texts_path: path to a pickled list of texts passed to ``TIRG`` when
            building the model (presumably they must match the training texts).
        :param channel_axis: axis of the channel dimension in the images given to
            ``encode``; it is moved to axis 1 before the model sees the batch.
        """
        super().__init__(*args, **kwargs)
        self.model_path = model_path
        self.texts_path = texts_path
        self.channel_axis = channel_axis
        # axis 0 is the batch
        self._default_channel_axis = 1

    def post_init(self):
        """Build TIRG, load the checkpoint on CPU, set eval mode and move it to the device.

        :raises PretrainedModelFileDoesNotExist: if ``model_path`` is empty or missing.
        """
        super().post_init()
        import torch
        if not (self.model_path and os.path.exists(self.model_path)):
            raise PretrainedModelFileDoesNotExist(f'model {self.model_path} does not exist')
        # NOTE: pickle.load can execute arbitrary code — texts_path must be trusted.
        with open(self.texts_path, 'rb') as fp:
            texts = pickle.load(fp)
        self.model = TIRG(texts, 512)
        model_sd = torch.load(self.model_path, map_location=torch.device('cpu'))
        self.model.load_state_dict(model_sd['model_state_dict'])
        self.model.eval()
        self.to_device(self.model)

    def _get_features(self, data):
        """Run the TIRG image branch on a tensor batch laid out with channels on axis 1."""
        return self.model.extract_img_feature(data)

    @batching
    @as_ndarray
    def encode(self, data: 'np.ndarray', *args, **kwargs) -> 'np.ndarray':
        """Embed a batch of images.

        :param data: image batch with the batch on axis 0 and channels on ``self.channel_axis``.
        :return: image features as a numpy array (e.g. shape (1, 512) for one image).
        """
        if self.channel_axis != self._default_channel_axis:
            data = np.moveaxis(data, self.channel_axis, self._default_channel_axis)
        import torch
        _input = torch.from_numpy(data.astype('float32'))
        if self.on_gpu:
            _input = _input.cuda()
        # Inference only: no_grad stops autograd from recording a graph for every batch.
        with torch.no_grad():
            _feature = self._get_features(_input)
        # .cpu() is a no-op for tensors already on the host.
        return _feature.cpu().numpy()

In [18]:
encoder = TirgEncoder(
    model_path='checkpoint_fashion200k.pth',
    texts_path='texts.pkl',
    channel_axis=1,  # the image batches here are (N, C, H, W): channels already on axis 1
)

    TirgEncoder@9822[I]:[37mpost initiating, this may take some time...[0m
    TirgEncoder@9822[I]:[37mpost initiating, this may take some time takes 1 second (1.10s)[0m


In [19]:
# Encode the single-image batch built earlier through the Jina wrapper.
encoder.encode(
    imgs.numpy()
).shape

(1, 512)

## Multimodal encoder (image + text composition)

In [23]:
__copyright__ = "Copyright (c) 2020 Jina AI Limited. All rights reserved."
__license__ = "Apache-2.0"

import os
import sys
import numpy as np
import pickle
from typing import List

from jina.executors.decorators import batching, as_ndarray
from jina.executors.encoders.multimodal import BaseMultiModalEncoder
from jina.excepts import PretrainedModelFileDoesNotExist
from jina.executors.devices import TorchDevice

# sys.path.append(".")
from img_text_composition_models import TIRG

class TirgMultiModalEncoder(TorchDevice, BaseMultiModalEncoder):
    """Jina multimodal encoder producing TIRG image+text composition features.

    The image batch goes through ``extract_img_feature`` and the text batch
    through ``extract_text_feature``. The two are then fused with
    ``compose_img_text_features`` and returned as a numpy array.
    """

    def __init__(self, model_path: str,
                 texts_path: str,
                 positional_modality: List[str] = ['visual', 'textual'],
                 channel_axis: int = -1,
                 *args, **kwargs):
        """
        :param model_path: the path where the model checkpoint is stored.
        :param texts_path: path to a pickled list of texts passed to ``TIRG`` when
            building the model.
        :param positional_modality: order of the modalities in the input given to
            ``encode``; must contain both ``'visual'`` and ``'textual'``.
        :param channel_axis: axis of the channel dimension in the image batch; it is
            moved to axis 1 before the model sees the images.
        :raises ValueError: if ``positional_modality`` lacks ``'visual'`` or ``'textual'``.
        """
        # NOTE(review): if the installed jina's BaseMultiModalEncoder.__init__ takes
        # positional_modality itself, it should be forwarded here — confirm.
        super().__init__(*args, **kwargs)
        self.model_path = model_path
        self.texts_path = texts_path
        # Copy so instances never share the mutable default list.
        self.positional_modality = list(positional_modality)
        missing = {'visual', 'textual'}.difference(self.positional_modality)
        if missing:
            raise ValueError(f'positional_modality must contain {sorted(missing)}, '
                             f'got {self.positional_modality}')
        self.channel_axis = channel_axis
        # axis 0 is the batch
        self._default_channel_axis = 1

    def post_init(self):
        """Build TIRG, load the checkpoint on CPU, set eval mode and move it to the device.

        :raises PretrainedModelFileDoesNotExist: if ``model_path`` is empty or missing.
        """
        super().post_init()
        import torch
        if not (self.model_path and os.path.exists(self.model_path)):
            raise PretrainedModelFileDoesNotExist(f'model {self.model_path} does not exist')
        # NOTE: pickle.load can execute arbitrary code — texts_path must be trusted.
        with open(self.texts_path, 'rb') as fp:
            texts = pickle.load(fp)
        self.model = TIRG(texts, 512)
        model_sd = torch.load(self.model_path, map_location=torch.device('cpu'))
        self.model.load_state_dict(model_sd['model_state_dict'])
        self.model.eval()
        self.to_device(self.model)

    def _prepare_visual(self, visual_data):
        """Return the image batch as a float tensor with channels on axis 1, on the encoder's device.

        Accepts a torch tensor, a list/tuple of per-image tensors, or any
        numpy-convertible batch; the batch is on axis 0 in every case.
        """
        import torch
        if isinstance(visual_data, torch.Tensor):
            images = visual_data
        elif (isinstance(visual_data, (list, tuple)) and len(visual_data) > 0
              and all(isinstance(img, torch.Tensor) for img in visual_data)):
            images = torch.stack(list(visual_data))
        else:
            images = torch.from_numpy(np.asarray(visual_data, dtype='float32'))
        images = images.float()
        if self.channel_axis != self._default_channel_axis:
            # Same reordering as np.moveaxis(images, channel_axis, 1); pop() also
            # accepts negative axes such as the default -1.
            order = list(range(images.dim()))
            source = order.pop(self.channel_axis)
            order.insert(self._default_channel_axis, source)
            images = images.permute(*order)
        if self.on_gpu:
            images = images.cuda()
        return images

    @staticmethod
    def _prepare_textual(textual_data):
        """Return the texts as a plain list of Python strings.

        Strings are not tensors, so they are never moved to the GPU.
        """
        return [str(text) for text in textual_data]

    def _get_features(self, data):
        """Compose TIRG image+text features for one batch.

        :param data: sequence with one entry per modality, ordered as ``self.positional_modality``.
        :return: composed features as a torch tensor.
        """
        visual_data = self._prepare_visual(data[self.positional_modality.index('visual')])
        textual_data = self._prepare_textual(data[self.positional_modality.index('textual')])
        img_features = self.model.extract_img_feature(visual_data)
        text_features = self.model.extract_text_feature(textual_data)
        return self.model.compose_img_text_features(img_features, text_features)

    @batching
    @as_ndarray
    def encode(self, *data: 'np.ndarray', **kwargs) -> 'np.ndarray':
        """Embed image+text pairs.

        Accepts either one positional argument per modality, ``encode(visual, textual)``
        (presumably how Jina's multimodal driver calls encoders — confirm), or a single
        sequence holding all modalities, ``encode([visual, textual])``, as used in this notebook.

        :return: composed features as a numpy array, one row per pair.
        """
        # NOTE(review): @batching slices only its first positional argument; with a
        # batch_size configured, per-modality inputs would be sliced inconsistently.
        import torch
        modalities = data[0] if len(data) == 1 else data
        # Inference only: no_grad stops autograd from recording a graph for every batch.
        with torch.no_grad():
            feature = self._get_features(modalities)
        return feature.cpu().numpy()

In [None]:
# Same checkpoint and texts as the image encoder; inputs are ordered [visual, textual].
mutimodal_encoder = TirgMultiModalEncoder(
    model_path='checkpoint_fashion200k.pth',
    texts_path='texts.pkl',
    positional_modality=['visual', 'textual'],
    channel_axis=1,
)

In [None]:
# One entry per modality, ordered like positional_modality: a one-image batch, then a one-caption batch.
data = [
    [trainset.get_img(0)],  # visual at position 0
    [texts[0]],  # textual at position 1
]

In [None]:
# Compose the image+text feature through the Jina multimodal wrapper.
encoded_multimodal = mutimodal_encoder.encode(
    data
)
print(encoded_multimodal.shape)

### Test: TIRG image encoder and TIRG multimodal encoder against the raw model

In [None]:
# Check that both Jina encoders reproduce the raw TIRG model's outputs on a few training samples.
N_CHECKS = 10
# Compare element-wise: `a.all() == b.all()` only compares two booleans, not the arrays.
RTOL, ATOL = 1e-5, 1e-6

encoder = TirgEncoder('checkpoint_fashion200k.pth', 'texts.pkl', channel_axis=1)
for i in range(N_CHECKS):
    print(f"testing image encoder with img {i}")
    img = torch.stack([trainset.get_img(i)]).float()
    # extract feature via jina encoder
    encoded = encoder.encode(img.numpy())
    # extract feature from model directly
    with torch.no_grad():
        extracted = model.extract_img_feature(img).cpu().numpy()
    assert encoded.shape == extracted.shape
    np.testing.assert_allclose(encoded, extracted, rtol=RTOL, atol=ATOL)

mutimodal_encoder = TirgMultiModalEncoder('checkpoint_fashion200k.pth', 'texts.pkl',
                                          positional_modality=['visual', 'textual'], channel_axis=1)
for i in range(N_CHECKS):
    print(f"testing multimodal encoder with img {i}")
    img_tensor = trainset.get_img(i)
    # extract feature via jina encoder
    data = [[img_tensor], [texts[i]]]  # visual at position 0, textual at position 1
    encoded = mutimodal_encoder.encode(data)
    # extract composed image-text feature from the model directly
    with torch.no_grad():
        text_feature = model.extract_text_feature([texts[i]])
        img_feature = model.extract_img_feature(torch.stack([img_tensor]).float())
        extracted = model.compose_img_text_features(img_feature, text_feature).cpu().numpy()
    assert encoded.shape == extracted.shape
    np.testing.assert_allclose(encoded, extracted, rtol=RTOL, atol=ATOL)