# Reproducing TIRG

> Reference: *Composing Text and Image for Image Retrieval - An Empirical Odyssey*

In this work, we consider the case where queries are formulated as an input image plus a text string that describes some desired modification to the image.

The steps in the process are as follows:

1. Data Loading
2. Model Building
3. Model Training
4. Model Evaluating
5. Image Retrieving

The cell below collects the configuration of the environment.  
Our environment is as follows:

* Python version: 3.7.10  
* PyTorch version: 1.8.1  
* CUDA used to build PyTorch: 10.2  
* pytorch-lightning: 1.2.5


In [1]:
from torch.utils import collect_env
collect_env.main()

Collecting environment information...
PyTorch version: 1.8.1
Is debug build: False
CUDA used to build PyTorch: 10.2
ROCM used to build PyTorch: N/A

OS: Microsoft Windows 10 Home (Chinese edition)
GCC version: Could not collect
Clang version: 10.0.0 
CMake version: version 3.19.6

Python version: 3.7 (64-bit runtime)
Is CUDA available: True
CUDA runtime version: Could not collect
GPU models and configuration: GPU 0: GeForce RTX 2070 Super
Nvidia driver version: 461.72
cuDNN version: Could not collect
HIP runtime version: N/A
MIOpen runtime version: N/A

Versions of relevant libraries:
[pip3] numpy==1.20.2
[pip3] pytorch-lightning==1.2.5
[pip3] torch==1.8.1
[pip3] torchmetrics==0.2.0
[pip3] torchvision==0.9.1
[conda] blas                      2.108                       mkl    conda-forge
[conda] blas-devel                3.9.0                     8_mkl    conda-forge
[conda] cudatoolkit               10.2.89              hb195166_8    conda-forge
[conda] libblas                   3.9.0                 

## DATASET
### Fashion IQ
The dataset contains diverse fashion images (**dresses, shirts, and tops & tees**), side information in the form of textual descriptions and product metadata, attribute labels, and most importantly, large-scale annotations of high-quality relative captions collected from human annotators.  

Refer to *The Fashion IQ Dataset: Retrieving Images by Combining Side Information and Relative Natural Language Feedback* and [Fashion IQ 数据集介绍及处理 (Introduction to and processing of the Fashion IQ dataset)](https://invisprints.vercel.app/fashion-iq)

### Import relevant modules

In [1]:
import os
import json
import torch
import torchvision.transforms as T
from PIL import Image
from pathlib import Path
from tqdm.notebook import tqdm
from transformers import AutoTokenizer, AutoModelForSequenceClassification

>torch.utils.data.Dataset  

An abstract class representing a Dataset.  

All datasets that represent a map from keys to data samples should subclass it. All subclasses should overwrite **\_\_getitem\_\_()**, supporting fetching a data sample for a given key. Subclasses could also optionally overwrite **\_\_len\_\_()**, which is expected to return the size of the dataset by many Sampler implementations and the default options of DataLoader.  
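The protocol described above can be sketched without any framework. The class below is a hypothetical toy (`SquaresDataset` is not part of the notebook); in the notebook the class would additionally subclass `torch.utils.data.Dataset`, which is left out here to keep the sketch dependency-free.

```python
class SquaresDataset:
    """Toy map-style dataset: index i maps to the sample {'x': i, 'y': i * i}."""

    def __init__(self, size):
        self.size = size

    def __getitem__(self, index):
        # Fetch one sample for the given key; raise IndexError past the end
        if not 0 <= index < self.size:
            raise IndexError(index)
        return {"x": index, "y": index * index}

    def __len__(self):
        # Samplers and the DataLoader defaults rely on this to know the dataset size
        return self.size


dataset = SquaresDataset(4)
print(len(dataset), dataset[3])
```

Nothing is computed until a key is requested, which is the same property the Fashion IQ class relies on when it defers image loading to `__getitem__()`.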

### Create a class for the FashionIQ dataset

Read the captions in \_\_init\_\_() and load the images in \_\_getitem\_\_(). This saves memory: images are loaded only when they are needed instead of all at once up front. Each sample is returned as a dict of the form {'c_img': c_img, 't_img': t_img, 'encoded_caption': encoded_caption}.
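As a rough illustration of what is kept in memory per sample, the sketch below turns one caption entry into a record of two image paths plus the joined caption. The helper name `build_record` and the ids `img_a` / `img_b` are made up for this example; the notebook itself stores the tokenized caption rather than the raw string.

```python
import os


def build_record(entry, data_path, category):
    """Turn one caption entry into the path-and-text record kept in memory."""
    # The two relative captions are joined with the [SEP] token
    caption = " [SEP] ".join(text.strip() for text in entry["captions"])
    image_dir = os.path.join(data_path, "resized_images", category)
    return {
        "c_path": os.path.join(image_dir, entry["candidate"] + ".jpg"),
        "t_path": os.path.join(image_dir, entry["target"] + ".jpg"),
        "caption": caption,
    }


# Hypothetical entry; the ids are placeholders, not real Fashion IQ ids
entry = {
    "candidate": "img_a",
    "target": "img_b",
    "captions": ["is solid black with no sleeves ", " is black with straps"],
}
record = build_record(entry, "data", "dress")
print(record["caption"])
```

Only strings are stored at this stage, so building the full list of records stays cheap even for the 18000 training pairs.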

In [2]:
class FashionIQDataset(torch.utils.data.Dataset):
    
    def __init__(self, data_path, split, **kwargs):
        super().__init__()

        self.data_path = data_path
        self.transform = kwargs.get("transform", None)
        self.test_targets = kwargs.get("test_targets", False)
        self.split = split
        self.data_name = kwargs.get('data_name', ['dress', 'shirt', 'toptee'])    #self.data_name=['dress', 'shirt', 'toptee']
        
        tokenizer = AutoTokenizer.from_pretrained('bert-base-cased')   # load the pretrained tokenizer
        self.tokenizer = tokenizer
        self.data = []
        
        # use split as a flag to choose between train and test
        if split == 'train': # train dataset data = [dress + shirt + toptee]
            assert isinstance(self.data_name, list), 'data_name must be list, get {}'.format(type(self.data_name))
            for data_name in self.data_name:
                cap_file = '{}/captions/cap.{}.{}.json'.format(data_path, data_name, split)
                with open(cap_file, 'r') as f: data = json.load(f)
                for d in tqdm(data, desc='loading {} {} data'.format(data_name, split)):
                    c_name = d['candidate']
                    t_name = d['target']
                    text = [x.strip() for x in d['captions']]
                    text = ' [SEP] '.join(text)
                    #e.g. is solid black with no sleeves [SEP] is black with straps
                    self.data.append({'c_path': os.path.join(data_path, 'resized_images', data_name, c_name+'.jpg'),
                                      't_path': os.path.join(data_path, 'resized_images', data_name, t_name+'.jpg'),
                                      'encoded_caption': tokenizer(text)}) 
                    # Don't pass return_tensors='pt' here: it returns 2-D tensors, while tokenizer.pad in collate_fn expects one unpadded 1-D sequence per sample
                              
        else: # test part, one of 'dress', 'shirt' and 'toptee'
            assert isinstance(self.data_name, str), 'data_name must be string, get {}'.format(type(self.data_name))
            split_file = '{}/image_splits/split.{}.{}.json'.format(data_path, self.data_name, split)
            with open(split_file, 'r') as f: self.names = json.load(f)
                              
            if self.test_targets: # test_targets
                for i, name in enumerate(tqdm(self.names, desc='loading {} {} pool data'.format(self.data_name, split))):
                    self.data.append({'t_id': i,
                                      't_path': os.path.join(data_path, 'resized_images', self.data_name, name+'.jpg')})
            else: # test queries
                cap_file = '{}/captions/cap.{}.{}.json'.format(data_path, self.data_name, split)
                with open(cap_file, 'r') as f: data = json.load(f)
                for d in tqdm(data, desc='generating {} {} data'.format(self.data_name, split)):
                    c_name = d['candidate']
                    t_name = d['target']
                    text = [x.strip() for x in d['captions']]
                    text = ' [SEP] '.join(text)
                    self.data.append({'c_path': os.path.join(data_path, 'resized_images', self.data_name, c_name+'.jpg'),
                                     'c_id': self.names.index(c_name),
                                     't_path': os.path.join(data_path, 'resized_images', self.data_name, t_name+'.jpg'),
                                     't_id': self.names.index(t_name),
                                     'encoded_caption': tokenizer(text)})
                
        print('Statistics: collected in {} {}, {} datas found'.format(self.data_name, self.split, len(self.data)))
        
    def __getitem__(self, index):
        data = self.data[index]
        if self.test_targets:
            test_img_path = data['t_path']
            test_img = Image.open(test_img_path).convert("RGB")
            if self.transform is not None:
                test_img = self.transform(test_img)
            return {'t_img': test_img,
                    't_id': data['t_id']}
        else:
            c_img_path = data['c_path']
            t_img_path = data['t_path']
            encoded_caption = data['encoded_caption']

            c_img = Image.open(c_img_path).convert("RGB")
            if self.transform is not None:
                c_img = self.transform(c_img)

            t_img = Image.open(t_img_path).convert("RGB")
            if self.transform is not None:
                t_img = self.transform(t_img)

            out = {'c_img': c_img,
                   't_img': t_img,
                   'encoded_caption': encoded_caption}

            if self.split != 'train':
                out['c_id'] = data['c_id']
                out['t_id'] = data['t_id']

        return out
    
    def __len__(self):
        return len(self.data)
    
    def get_loader(self, batch_size):
        # Shuffle and drop the last partial batch only when training
        is_train = self.split == 'train'
        return torch.utils.data.DataLoader(self, batch_size=batch_size,
                                           shuffle=is_train,
                                           num_workers=0, pin_memory=True,
                                           drop_last=is_train,
                                           collate_fn=None if self.test_targets else self.collate_fn)  # num_workers=8
    
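    def get_img(self, idx, raw_img=False):
        # Train/query records carry 'c_path'; pool records (test_targets=True) only carry 't_path'
        record = self.data[idx]
        img_path = record.get('c_path', record.get('t_path'))
        img = Image.open(img_path).convert("RGB")
        if raw_img:
            return img
        if self.transform is not None:
            img = self.transform(img)
        return img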
    
    def collate_fn(self, batch):
        # Sort longest first by token count (length of 'input_ids'), not by len() of the encoding object
        batch.sort(key=lambda x: len(x['encoded_caption']['input_ids']), reverse=True)
        elem = batch[0]

        return_batch = {}
        for key in elem:
            if key == 'encoded_caption':
                return_batch[key] = self.tokenizer.pad([d[key] for d in batch], return_tensors="pt")
            else:
                return_batch[key] = torch.utils.data._utils.collate.default_collate([d[key] for d in batch])

        return return_batch
    
    def indices_to_string(self, input_ids, skip_special_tokens=False):
        """
        Convert word indices (torch.Tensor) to sentence (string).
        """
        
        text = self.tokenizer.decode(input_ids, skip_special_tokens=skip_special_tokens)
        
        return text
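To make the role of `tokenizer.pad` in `collate_fn` concrete, here is a minimal framework-free sketch of right-padding a batch of token-id lists and building the matching attention masks. The function name `pad_batch`, the token ids, and the pad id of 0 are assumptions for illustration; the real tokenizer supplies its own pad id and returns tensors.

```python
def pad_batch(sequences, pad_id=0):
    """Right-pad lists of token ids to the longest one and build attention masks."""
    max_len = max(len(seq) for seq in sequences)
    input_ids, attention_mask = [], []
    for seq in sequences:
        padding = max_len - len(seq)
        input_ids.append(list(seq) + [pad_id] * padding)
        # 1 marks a real token, 0 marks padding
        attention_mask.append([1] * len(seq) + [0] * padding)
    return {"input_ids": input_ids, "attention_mask": attention_mask}


batch = pad_batch([[101, 7, 8, 102], [101, 9, 102]])
print(batch)
```

Padding per batch rather than per dataset keeps each batch only as long as its longest caption.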

>torchvision.transforms

Transforms are common image transformations. They can be chained together using *Compose*. Additionally, there is the *torchvision.transforms.functional* module. Functional transforms give fine-grained control over the transformations. This is useful if you have to build a more complex transformation pipeline.
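The chaining idea can be sketched in a few lines. `ComposeSketch` is a hypothetical stand-in written for this example, not torchvision's class; it only shows that each transform receives the output of the previous one.

```python
class ComposeSketch:
    """Apply a list of callables in order, feeding each output to the next."""

    def __init__(self, transforms):
        self.transforms = transforms

    def __call__(self, x):
        for transform in self.transforms:
            x = transform(x)
        return x


# Two toy "transforms" on a number: add one, then double
pipeline = ComposeSketch([lambda v: v + 1, lambda v: v * 2])
result = pipeline(3)
print(result)
```

Order matters: swapping the two toy transforms would give a different result, just as `Normalize` must come after `ToTensor` in an image pipeline.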

>torchvision.transforms.RandomHorizontalFlip(p=0.5)

Horizontally flip the given image randomly with a given probability.

>torchvision.transforms.RandomAffine

Random affine transformation of the image keeping center invariant. 

* degrees – Range of degrees to select from. If degrees is a number instead of sequence like (min, max), the range of degrees will be (-degrees, +degrees). 
* translate – tuple of maximum absolute fraction for horizontal and vertical translations. For example translate=(a, b), then horizontal shift is randomly sampled in the range -img_width * a < dx < img_width * a and vertical shift is randomly sampled in the range -img_height * b < dy < img_height * b. Will not translate by default.
* scale – scaling factor interval, e.g. (a, b), then scale is randomly sampled from the range a <= scale <= b. Will keep original scale by default.

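The parameter ranges listed above can be illustrated by sampling them directly. This is a sketch of the documented ranges only, not torchvision's internal implementation; `sample_affine_params` and the 200 x 300 image size are assumptions made for the example.

```python
import random


def sample_affine_params(degrees, translate, scale, img_size, rng=random):
    """Sample rotation angle, pixel shift and zoom from the documented ranges."""
    if isinstance(degrees, (int, float)):
        degrees = (-degrees, degrees)  # a single number means (-degrees, +degrees)
    width, height = img_size
    angle = rng.uniform(degrees[0], degrees[1])
    max_dx, max_dy = translate[0] * width, translate[1] * height
    dx = rng.uniform(-max_dx, max_dx)
    dy = rng.uniform(-max_dy, max_dy)
    zoom = rng.uniform(scale[0], scale[1])
    return angle, (dx, dy), zoom


rng = random.Random(0)  # fixed seed so the draw is repeatable
angle, (dx, dy), zoom = sample_affine_params(45, (0.15, 0.15), (0.9, 1.1), (200, 300), rng)
print(angle, dx, dy, zoom)
```

With the values used later in this notebook (`degrees=45`, `translate=(0.15, 0.15)`, `scale=(0.9, 1.1)`), a 200-pixel-wide image is shifted horizontally by at most 30 pixels.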
>torchvision.transforms.ToTensor

Convert a PIL Image or numpy.ndarray to tensor. This transform does not support torchscript.

Converts a PIL Image or numpy.ndarray (H x W x C) in the range [0, 255] to a torch.FloatTensor of shape (C x H x W) in the range [0.0, 1.0] if the PIL Image belongs to one of the modes (L, LA, P, I, F, RGB, YCbCr, RGBA, CMYK, 1) or if the numpy.ndarray has dtype = np.uint8. In the other cases, tensors are returned without scaling.

>torchvision.transforms.Normalize(mean, std, inplace=False)

Normalize a tensor image with mean and standard deviation. This transform does not support PIL Image. Given mean: (mean[1],...,mean[n]) and std: (std[1],..,std[n]) for n channels, this transform will normalize each channel of the input torch.*Tensor i.e., output[channel] = (input[channel] - mean[channel]) / std[channel]

* mean – Sequence of means for each channel.
* std – Sequence of standard deviations for each channel.
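A small worked example of the arithmetic may help: scale an 8-bit pixel to [0.0, 1.0] as `ToTensor` does, then apply the per-channel formula of `Normalize`. The helper names and the sample pixel are assumptions for illustration; the mean and std values are the ones used later in this notebook.

```python
def to_unit_range(pixel):
    """Map an 8-bit intensity in [0, 255] to [0.0, 1.0]."""
    return pixel / 255.0


def normalize(value, mean, std):
    """output = (input - mean) / std, applied to one channel value."""
    return (value - mean) / std


norm_mean = [0.485, 0.456, 0.406]
norm_std = [0.229, 0.224, 0.225]

rgb_pixel = (255, 128, 0)  # one hypothetical pixel, one value per channel
normalized = [
    normalize(to_unit_range(p), m, s)
    for p, m, s in zip(rgb_pixel, norm_mean, norm_std)
]
print([round(v, 3) for v in normalized])
```

After normalization the values are no longer confined to [0, 1]: a fully saturated channel ends up above 2 and an empty one below -1.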

### Use *Compose* to perform data pre-processing (augmentation for training, normalization for both splits).

In [5]:
data_path = '/mnt/data/fashion_iq/data'
val_data_names = ['dress', 'shirt', 'toptee']
batch_size = 32
norm_mean = [0.485, 0.456, 0.406]
norm_std = [0.229, 0.224, 0.225]

train_transform = T.Compose([
        T.RandomHorizontalFlip(p=0.5),
        T.RandomAffine(degrees=45, translate=(0.15, 0.15), scale=(0.9, 1.1)),
        T.ToTensor(),
        T.Normalize(mean=norm_mean, std=norm_std),
    ])

val_transform = T.Compose([
        T.ToTensor(),
        T.Normalize(mean=norm_mean, std=norm_std),
    ])

### Instantiate the class.

In [6]:
train_dataset = FashionIQDataset(data_path, 'train', transform=train_transform)
# `split` selects the training or the evaluation branch of the dataset
# for evaluation splits, data_name must be one of 'dress', 'shirt' and 'toptee'; for training it defaults to all three
train_loader = train_dataset.get_loader(batch_size=batch_size)

val_datasets = {}
val_pool_datasets = {}
for data_name in val_data_names:
    val_datasets[data_name] =  FashionIQDataset(data_path, 'val', data_name=data_name, transform=val_transform)
    val_pool_datasets[data_name] = FashionIQDataset(data_path, 'val', data_name=data_name, transform=val_transform, test_targets=True)

Statistics: collected in ['dress', 'shirt', 'toptee'] train, 18000 datas found
Statistics: collected in dress val, 2017 datas found
Statistics: collected in dress val, 3817 datas found
Statistics: collected in shirt val, 2038 datas found
Statistics: collected in shirt val, 6346 datas found
Statistics: collected in toptee val, 1961 datas found
Statistics: collected in toptee val, 5373 datas found


In [7]:
val_dataloaders = []
for data_name in val_data_names:
    val_dataloaders.append(val_datasets[data_name].get_loader(batch_size))
    val_dataloaders.append(val_pool_datasets[data_name].get_loader(batch_size))
print(len(val_dataloaders))

6
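The six loaders are interleaved: for each category the query loader is appended first and the pool loader second. The sketch below spells out that index convention, which the validation and test steps rely on; `describe_loader` is a hypothetical helper, not part of the notebook.

```python
val_data_names = ['dress', 'shirt', 'toptee']


def describe_loader(dataloader_idx, names=val_data_names):
    """Map a dataloader index to its (category, role) under the interleaved layout."""
    role = 'queries' if dataloader_idx % 2 == 0 else 'pool'
    return names[dataloader_idx // 2], role


layout = [describe_loader(i) for i in range(2 * len(val_data_names))]
for idx, (name, role) in enumerate(layout):
    print(idx, name, role)
```

Even indices (0, 2, 4) are therefore the query loaders, which is why those numbers appear as prefixes in the logged recall metrics.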


## MODEL
### 2.Model Building

In [8]:
from argparse import ArgumentParser

parser = ArgumentParser()
parser.add_argument('--img_model', default='resnet34', choices=['resnet18', 'resnet34','resnet50'])
parser.add_argument('--text_model', default='lstm', choices=['lstm', 'gru', 'transformer', 'bert-base-uncased'])
parser.add_argument('--embed_dim', type=int, default=512)
parser.add_argument('--hidden_dim', type=int, default=512)
parser.add_argument('--head_num', type=int, default=4)
parser.add_argument('--layers_num', type=int, default=6)

parser.add_argument('--epochs', type=int, default=80)

args = parser.parse_args(args=[])

args.vocab_size = train_dataset.tokenizer.vocab_size
if 'bert' in args.text_model:  # the BERT choice offered above is 'bert-base-uncased'
    args.hidden_dim = 768
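Passing `args=[]` makes `argparse` ignore the notebook kernel's own command line and fall back to the defaults. To try a different configuration from inside a notebook, the same list can carry overrides. The sketch below uses a reduced parser with only two of the options defined above.

```python
from argparse import ArgumentParser

parser = ArgumentParser()
parser.add_argument('--text_model', default='lstm',
                    choices=['lstm', 'gru', 'transformer', 'bert-base-uncased'])
parser.add_argument('--embed_dim', type=int, default=512)

# Empty list: every option takes its default
defaults = parser.parse_args(args=[])

# Overrides are given exactly as they would be typed on a command line
custom = parser.parse_args(args=['--text_model', 'transformer', '--embed_dim', '256'])
print(defaults, custom)
```

Values outside `choices` are rejected at parse time, so a typo in the model name fails early rather than during model construction.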

### Import relevant modules.

In [10]:
import os
import torch
from torch import nn
from torch.nn import functional as F
from torchvision import transforms
import pytorch_lightning as pl
import torchvision

>torch.nn.Module

Base class for all neural network modules. Your models should also subclass this class. Modules can also contain other Modules, allowing them to be nested in a tree structure.

### Create some classes and optional text models.

In [11]:
class ConCatModule(torch.nn.Module):

    def __init__(self):
        super(ConCatModule, self).__init__()

    def forward(self, x):
        x = torch.cat(x, dim=1)
        return x
    
class NormalizationLayer(torch.nn.Module):
    """Class for normalization layer."""
    def __init__(self, normalize_scale=1.0, learn_scale=True):
        super(NormalizationLayer, self).__init__()
        self.norm_s = float(normalize_scale)
        if learn_scale:
            self.norm_s = torch.nn.Parameter(torch.FloatTensor((self.norm_s,)))

    def forward(self, x):
        features = self.norm_s * x / torch.norm(x, dim=1, keepdim=True).expand_as(x)
        return features
    
class BertPooler(nn.Module):
    def __init__(self, args):
        super().__init__()
        self.net = nn.Sequential(nn.Linear(args.hidden_dim, args.embed_dim),
                                 nn.Dropout(p=0.1),)

    def forward(self, hidden_states):
        first_token_tensor = hidden_states[:, 0]
        return self.net(first_token_tensor)
    
class TextModel(torch.nn.Module):

    def __init__(self,args):

        super().__init__()
        self.text_model = args.text_model
        if args.text_model == 'lstm':
            self.embedding_layer = nn.Embedding(args.vocab_size, args.embed_dim)
            self.net = nn.LSTM(args.embed_dim, args.hidden_dim)
        elif args.text_model == 'gru':
            self.embedding_layer = nn.Embedding(args.vocab_size, args.embed_dim)
            self.net = nn.GRU(args.embed_dim, args.hidden_dim)
        elif args.text_model == 'transformer':
            self.embedding_layer = nn.Embedding(args.vocab_size, args.embed_dim)
            encoder_layer = nn.TransformerEncoderLayer(d_model=args.embed_dim, 
                                                       nhead=args.head_num, 
                                                       dim_feedforward=args.hidden_dim)
            self.net = nn.TransformerEncoder(encoder_layer, num_layers=args.layers_num)
        elif 'bert' in args.text_model:
            self.net = AutoModelForSequenceClassification.from_pretrained(args.text_model,
                                                                          num_labels=args.embed_dim)
        else:
            raise NotImplementedError
        self.fc_output = nn.Sequential(nn.Dropout(p=0.1),
                                       nn.Linear(args.hidden_dim, args.embed_dim))

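    def forward(self, x):
        # Compare strings with ==; `is` tests object identity and is unreliable for string values
        if self.text_model == 'lstm':
            x = self.embedding_layer(x)
            output, (h_n, c_n) = self.net(x)
            text_features = self.fc_output(h_n.squeeze(dim=0))
        elif self.text_model == 'transformer':
            x = self.embedding_layer(x)
            output = self.net(x)
            text_features = self.fc_output(output[:, -1, :])
        elif 'bert' in self.text_model:  # use the module's own setting, not the global `args`
            output = self.net(x)
            text_features = output.logits
#             text_features = self.fc_output(output.last_hidden_state[:, 0])
        else:
            # 'gru' is accepted by __init__ but has no forward branch in this notebook
            raise NotImplementedError(self.text_model)
        return text_features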
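The `NormalizationLayer` defined in this cell rescales each feature vector to a fixed L2 norm. A minimal sketch of that arithmetic on plain lists follows; `scaled_l2_normalize` is a hypothetical helper, and the scale is a constant here whereas the layer can learn it.

```python
import math


def scaled_l2_normalize(vector, scale=4.0):
    """Return scale * v / ||v||, so the result always has L2 norm equal to scale."""
    norm = math.sqrt(sum(v * v for v in vector))
    return [scale * v / norm for v in vector]


features = scaled_l2_normalize([3.0, 4.0], scale=4.0)
new_norm = math.sqrt(sum(v * v for v in features))
print(features, new_norm)
```

Because all vectors share the same norm after this step, their dot products are cosine similarities multiplied by the square of the scale, and the scale acts like an inverse temperature in the softmax loss.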


>PyTorch Lightning 

Lightning is just organized PyTorch. It forces the following structure on your code, which makes it reusable and shareable:

1. Research code (the LightningModule).
2. Engineering code (you delete it; the Trainer handles it).
3. Non-essential research code (logging, etc... this goes in Callbacks).
4. Data (use PyTorch Dataloaders or organize them into a LightningDataModule).

### Build the TIRG model

First, encode the query (or reference) image x with a ResNet CNN to get a feature vector φx. The source text names a ResNet-17, while the code in this notebook defaults to torchvision's `resnet34` (see `--img_model`). Next, encode the query text t with a standard LSTM and define φt as the hidden state at the final time step, whose size d is 512. The text encoder is kept as simple as possible. Finally, combine the two features to compute φxt.
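Read off the `compose_img_text` method in the cell below, the combination takes a gated residual form. The formula is a transcription of this notebook's code rather than the paper's exact notation; g and r are labels introduced here for the outputs of `gated_feature_composer` and `res_info_composer`.

$$
\phi_{xt} = a_0 \,\sigma\big(g([\phi_x, \phi_t])\big) \odot \phi_x + a_1 \, r([\phi_x, \phi_t])
$$

Here $[\phi_x, \phi_t]$ is the concatenation of the two features, $\sigma$ is the element-wise sigmoid, $\odot$ is the element-wise product, and $a_0$, $a_1$ are the first two entries of the learnable vector `self.a`, initialised to 1.0 and 10.0.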

In [12]:
class TIRG(pl.LightningModule):

    def __init__(self, args):
        super().__init__()
        self.save_hyperparameters(args)
        self.normalization_layer = NormalizationLayer(
            normalize_scale=4.0, learn_scale=True)
        # img model
        img_model = getattr(torchvision.models, args.img_model)(pretrained=True)  # look the constructor up by name instead of eval()

        img_model.fc = torch.nn.Sequential(torch.nn.Linear(img_model.fc.in_features, args.embed_dim))
        self.img_model = img_model

        # text model
        self.text_model = TextModel(args)  
        
        self.a = nn.Parameter(torch.tensor([1.0, 10.0, 1.0, 1.0]))
        self.gated_feature_composer = nn.Sequential(
            ConCatModule(), nn.BatchNorm1d(2 * args.embed_dim), nn.ReLU(),
            nn.Linear(2 * args.embed_dim, args.embed_dim))
        self.res_info_composer = nn.Sequential(
            ConCatModule(), nn.BatchNorm1d(2 * args.embed_dim), nn.ReLU(),
            nn.Linear(2 * args.embed_dim, 2 * args.embed_dim), nn.ReLU(),
            nn.Linear(2 * args.embed_dim, args.embed_dim))
    
    def compose_img_text(self, imgs, texts):
        img_features = self.img_model(imgs)
        if self.hparams.text_model == 'lstm' and texts.shape[0] == imgs.shape[0]:
            texts.transpose_(1, 0) # seq, batch
        text_features = self.text_model(texts)
        
        assert img_features.shape == text_features.shape, \
        'img feat {}, text feat {}'.format(img_features.shape, text_features.shape)
        
        f1 = self.gated_feature_composer((img_features, text_features))
        f2 = self.res_info_composer((img_features, text_features))
        f = torch.sigmoid(f1) * img_features * self.a[0] + f2 * self.a[1]
        return f
    
    def compute_batch_based_classification_loss_(self, mod_img1, img2):
        x = torch.mm(mod_img1, img2.transpose(0, 1))
        labels = torch.arange(x.shape[0], device=x.device)  # query i should match target i
        return F.cross_entropy(x, labels)


    def training_step(self, batch, batch_idx):
    
        c_img, t_img, encoded_caption = batch['c_img'], batch['t_img'], batch['encoded_caption']
        mod_img1 = self.compose_img_text(c_img, encoded_caption["input_ids"])
        mod_img1 = self.normalization_layer(mod_img1)
        
        img2 = self.img_model(t_img)
        img2 = self.normalization_layer(img2)
        
        assert (mod_img1.shape[0] == img2.shape[0] and
                mod_img1.shape[1] == img2.shape[1])
        
        loss = self.compute_batch_based_classification_loss_(mod_img1, img2)
        self.log('train_loss', loss)
        return loss


    def validation_step(self, batch, batch_idx, dataloader_idx):
        if dataloader_idx & 1 == 0: # queries
            c_img, encoded_caption = batch['c_img'], batch['encoded_caption']
            f = self.compose_img_text(c_img, encoded_caption["input_ids"])
            return {'features': f.cpu(), 'c_ids': batch['c_id'].cpu(), 't_ids': batch['t_id'].cpu()}
        else: # targets
            t_img = batch['t_img']
            f = self.img_model(t_img)       
            return f.cpu()
        
        
    def validation_epoch_end(self, step_outputs):
        for idx in range(0, len(step_outputs), 2):
            queries, all_imgs = step_outputs[idx], step_outputs[idx+1]
            all_queries = [item['features'] for item in queries]
            all_queries = torch.cat(all_queries)
            all_imgs = torch.cat(all_imgs)

            all_queries_id = [item['c_ids'] for item in queries]
            all_queries_id = torch.cat(all_queries_id)
            all_targets_id = [item['t_ids'] for item in queries]
            all_targets_id = torch.cat(all_targets_id)


            # feature normalization
            all_queries /= torch.norm(all_queries, dim=1)[:, None]
            all_imgs /= torch.norm(all_imgs, dim=1)[:, None]

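            # Saved for the retrieval demo; overwritten on each iteration, so the file ends up holding the last pool
            torch.save(all_imgs, "all_imgs.pt")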
        
            # match test queries to target images, get nearest neighbors
            sims = all_queries.mm(all_imgs.T)
            for i, t in enumerate(all_queries_id):
                sims[i, t] = -10e10  # remove query image
            nn_result = [torch.argsort(-sims[i, :])[:50] for i in range(sims.shape[0])]

            # compute recalls
            for k in [1, 5, 10, 50]:
                r = 0.0
                for i, nns in enumerate(nn_result):
                    if all_targets_id[i] in nns[:k]:
                        r += 1
                r /= len(nn_result)
                self.log('{} val_recall_top {}'.format(idx, k), r)

            
    def test_step(self, batch, batch_idx, dataloader_idx):
        if dataloader_idx & 1 == 0: # queries
            c_img, encoded_caption = batch['c_img'], batch['encoded_caption']
            f = self.compose_img_text(c_img, encoded_caption["input_ids"])
            return {'features': f.cpu(), 'c_ids': batch['c_id'].cpu(), 't_ids': batch['t_id'].cpu()}
        else: # targets
            t_img = batch['t_img']
            f = self.img_model(t_img)       
            return f.cpu()
        
        
    def test_epoch_end(self, step_outputs):
        for idx in range(0, len(step_outputs), 2):
            queries, all_imgs = step_outputs[idx], step_outputs[idx+1]
            all_queries = [item['features'] for item in queries]
            all_queries = torch.cat(all_queries)
            all_imgs = torch.cat(all_imgs)

            all_queries_id = [item['c_ids'] for item in queries]
            all_queries_id = torch.cat(all_queries_id)
            all_targets_id = [item['t_ids'] for item in queries]
            all_targets_id = torch.cat(all_targets_id)


            # feature normalization
            all_queries /= torch.norm(all_queries, dim=1)[:, None]
            all_imgs /= torch.norm(all_imgs, dim=1)[:, None]

            # match test queries to target images, get nearest neighbors
            sims = all_queries.mm(all_imgs.T)
            for i, t in enumerate(all_queries_id):
                sims[i, t] = -10e10  # remove query image
            nn_result = [torch.argsort(-sims[i, :])[:50] for i in range(sims.shape[0])]

            # compute recalls
            for k in [1, 5, 10, 50]:
                r = 0.0
                for i, nns in enumerate(nn_result):
                    if all_targets_id[i] in nns[:k]:
                        r += 1
                r /= len(nn_result)
                self.log('{} test_recall_top {}'.format(idx, k), r)

    def configure_optimizers(self):
#         optimizer = torch.optim.Adam(self.parameters(), lr=1e-3)
        optimizer = torch.optim.SGD(self.parameters(), lr=1e-2, momentum=0.9, weight_decay=1e-6)
        scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=50, gamma=0.1)
        return [optimizer], [scheduler]
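The batch-based classification loss in `compute_batch_based_classification_loss_` treats every other target in the batch as a negative: row i of the similarity matrix is a set of logits whose correct class is i. A minimal pure-Python sketch of the same computation, with made-up 2-D features and a hypothetical function name, could look like this:

```python
import math


def batch_classification_loss(queries, targets):
    """Mean softmax cross-entropy where query i must pick target i within the batch."""
    total = 0.0
    for i, query in enumerate(queries):
        # Dot product of this query with every target in the batch
        logits = [sum(q * t for q, t in zip(query, target)) for target in targets]
        # Numerically stable log-sum-exp
        max_logit = max(logits)
        log_sum_exp = max_logit + math.log(sum(math.exp(l - max_logit) for l in logits))
        total += log_sum_exp - logits[i]
    return total / len(queries)


feats = [[4.0, 0.0], [0.0, 4.0]]
aligned = batch_classification_loss(feats, feats)        # each query matches its own target
swapped = batch_classification_loss(feats, feats[::-1])  # targets deliberately mismatched
print(aligned, swapped)
```

The loss is close to zero when each composed query is most similar to its own target and grows large when it is more similar to another sample's target.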

### 3.Model Training

In [13]:
from pytorch_lightning.callbacks import LearningRateMonitor

lr_monitor = LearningRateMonitor(logging_interval='epoch')

# init model
model = TIRG(args)

# Initialize a trainer
trainer = pl.Trainer(gpus=1, max_epochs=args.epochs, fast_dev_run=False, num_sanity_val_steps=0, callbacks=[lr_monitor], check_val_every_n_epoch=3)

# Train the model ⚡
trainer.fit(model, train_loader, val_dataloaders=val_dataloaders)

GPU available: True, used: True
TPU available: None, using: 0 TPU cores
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name                   | Type               | Params
--------------------------------------------------------------
0 | normalization_layer    | NormalizationLayer | 1     
1 | img_model              | ResNet             | 21.5 M
2 | text_model             | TextModel          | 17.2 M
3 | gated_feature_composer | Sequential         | 526 K 
4 | res_info_composer      | Sequential         | 1.6 M 
--------------------------------------------------------------
40.9 M    Trainable params
0         Non-trainable params
40.9 M    Total params
163.442   Total estimated model params size (MB)




1
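The run above trains for 80 epochs with `StepLR(step_size=50, gamma=0.1)` on a base learning rate of 1e-2, so the rate drops exactly once. The closed form of that schedule can be sketched as follows; `step_lr` is a hypothetical helper that mirrors the formula, not the scheduler object itself.

```python
def step_lr(epoch, base_lr=1e-2, step_size=50, gamma=0.1):
    """Learning rate after `epoch` completed epochs under a step schedule."""
    return base_lr * gamma ** (epoch // step_size)


schedule = {epoch: step_lr(epoch) for epoch in (0, 49, 50, 79)}
print(schedule)
```

Epochs 0 to 49 use the base rate and epochs 50 to 79 use one tenth of it.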

### 4.Model Evaluating

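The retrieval metric is **recall at rank K (R@K)**: the fraction of test queries for which the target (or a correctly labeled) image appears among the top K retrieved images. It is logged as a value between 0 and 1. In the full evaluation below, the key prefixes `0`, `2` and `4` are the indices of the query dataloaders for dress, shirt and toptee, respectively.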
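As a worked example of the metric, the sketch below computes R@K from ranked result lists. The function name and the toy ids are made up for illustration.

```python
def recall_at_k(ranked_ids, target_ids, k):
    """Fraction of queries whose target id appears in the first k retrieved ids."""
    hits = sum(
        1 for ranking, target in zip(ranked_ids, target_ids) if target in ranking[:k]
    )
    return hits / len(target_ids)


# Three toy queries, each with its three best-ranked gallery ids
ranked_ids = [[7, 2, 9], [4, 1, 3], [5, 8, 6]]
target_ids = [7, 3, 0]

recalls = {k: recall_at_k(ranked_ids, target_ids, k) for k in (1, 3)}
print(recalls)
```

Only the first query has its target at rank 1; the second query's target appears at rank 3, and the third query's target is never retrieved, so R@1 is 1/3 and R@3 is 2/3.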

In [14]:
# model = TIRG.load_from_checkpoint('lightning_logs/version_2/checkpoints/epoch=79.ckpt', vocab_size=train_dataset.tokenizer.vocab_size, embed_dim=512)
# trainer = pl.Trainer(gpus=1)

trainer.test(model, test_dataloaders=val_dataloaders)

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


Testing: 0it [00:00, ?it/s]

--------------------------------------------------------------------------------
DATALOADER:0 TEST RESULTS
{'0 test_recall_top 1': 0.030738720872583045,
 '0 test_recall_top 10': 0.14724838869608328,
 '0 test_recall_top 5': 0.0961824491819534,
 '0 test_recall_top 50': 0.3653941497273178,
 '2 test_recall_top 1': 0.022571148184494603,
 '2 test_recall_top 10': 0.1280667320902846,
 '2 test_recall_top 5': 0.0775269872423945,
 '2 test_recall_top 50': 0.3267909715407262,
 '4 test_recall_top 1': 0.029576746557878633,
 '4 test_recall_top 10': 0.1560428352881183,
 '4 test_recall_top 5': 0.1009688934217236,
 '4 test_recall_top 50': 0.36970933197348294}
--------------------------------------------------------------------------------

In [15]:
# data_name = 'toptee'
# trainer.test(model, test_dataloaders=[val_datasets[data_name].get_loader(batch_size), val_pool_datasets[data_name].get_loader(batch_size)])

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


Testing: 0it [00:00, ?it/s]

--------------------------------------------------------------------------------
DATALOADER:0 TEST RESULTS
{'0 test_recall_top 1': 0.029576746557878633,
 '0 test_recall_top 10': 0.1560428352881183,
 '0 test_recall_top 5': 0.1009688934217236,
 '0 test_recall_top 50': 0.36970933197348294,
 '2 test_recall_top 1': 0.022571148184494603,
 '2 test_recall_top 10': 0.1280667320902846,
 '2 test_recall_top 5': 0.0775269872423945,
 '2 test_recall_top 50': 0.3267909715407262,
 '4 test_recall_top 1': 0.029576746557878633,
 '4 test_recall_top 10': 0.1560428352881183,
 '4 test_recall_top 5': 0.1009688934217236,
 '4 test_recall_top 50': 0.36970933197348294}
--------------------------------------------------------------------------------

### 5.Image Retrieving  


In [None]:
import PIL

link = "try.jpg"
with open(link, "rb") as f:
    query_img_raw = PIL.Image.open(f).convert("RGB")
# Use the deterministic validation transform; random augmentation is meant for training only
query_img = val_transform(query_img_raw).unsqueeze(0)  # add a batch dimension
# Change the description here, e.g. "replace A with B"
query_text_raw = ["black toptee"]
tokenizer = AutoTokenizer.from_pretrained('bert-base-cased')
query_text = torch.tensor(tokenizer(query_text_raw)["input_ids"], dtype=torch.long)

# eval() disables dropout and makes BatchNorm use its running statistics (needed for a batch of one)
model = model.cuda().eval()
with torch.no_grad():
    query_feature = model.compose_img_text(query_img.cuda(), query_text.cuda()).cpu()

In [None]:
import matplotlib.pyplot as plt

all_imgs= torch.load( "all_imgs.pt")
sims = query_feature.mm(all_imgs.T)
nn_result = [torch.argsort(-sims[i, :])[:50] for i in range(sims.shape[0])]

c = 5
r = 4
fig = plt.figure(figsize=(20, 20))
# Show query
fig.add_subplot(r, c, 3)
plt.imshow(query_img_raw)
plt.title(query_text_raw[0])
plt.axis("off")
# Show output
k = 15

for i in range(k):
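    # all_imgs.pt holds the features of the last validation pool, so look images up there rather than in train_dataset
    pool = val_pool_datasets[val_data_names[-1]]
    img = Image.open(pool.data[int(nn_result[0][i])]['t_path']).convert("RGB")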
    fig.add_subplot(r, c, i+6)
    plt.imshow(img)
    plt.axis('off')

plt.show()
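The ranking step used above (dot products against a gallery of normalized features, then a sort) can be sketched without tensors. The function `top_k`, the 2-D vectors, and the `exclude` argument are assumptions for illustration; `exclude` plays the role of masking out the query image during evaluation.

```python
import math


def top_k(query, gallery, k, exclude=None):
    """Return the indices of the k gallery vectors most similar to the query (cosine)."""
    def unit(v):
        norm = math.sqrt(sum(x * x for x in v))
        return [x / norm for x in v]

    q = unit(query)
    scores = []
    for idx, item in enumerate(gallery):
        if idx == exclude:
            continue  # skip the query's own image
        scores.append((sum(a * b for a, b in zip(q, unit(item))), idx))
    scores.sort(key=lambda pair: -pair[0])  # highest similarity first
    return [idx for _, idx in scores[:k]]


gallery = [[1.0, 0.0], [0.9, 0.1], [0.0, 1.0], [-1.0, 0.0]]
nearest = top_k([1.0, 0.05], gallery, k=2)
without_first = top_k([1.0, 0.05], gallery, k=2, exclude=0)
print(nearest, without_first)
```

Scaling the query vector does not change the ranking, which is why only the gallery features need to be normalized before the dot product.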
