In [1]:
!pip install pycocotools > /dev/null
!pip install neptune-client > /dev/null



In [2]:
!wc -l ../input/flickr/*.txt

  1000 ../input/flickr/test.txt
 29783 ../input/flickr/train.txt
  1000 ../input/flickr/val.txt
 31783 total


In [3]:
import torch
torch_device = 'cpu'
if hasattr(torch._C, '_cuda_getDeviceCount'):
    torch_device = 'cuda'

In [4]:
from __future__ import division
from __future__ import print_function
from collections import namedtuple
import pytorch_lightning as pl
import nltk
#nltk.download('stopwords')
from nltk.corpus import stopwords
from scipy.io import loadmat
import os
import pickle
import numpy as np
from pycocotools.coco import COCO
from torch.nn import functional as F
import math

from pytorch_lightning.loggers.neptune import NeptuneLogger

from kaggle_secrets import UserSecretsClient
user_secrets = UserSecretsClient()
neptune_key = user_secrets.get_secret("neptune")

def load_word_embeddings(word_embedding_filename, embedding_length):
    with open(word_embedding_filename, 'r') as f:
        word_embeddings = {}
        for i, line in enumerate(f):
            if i % 10000 == 0:
                print('Reading word embedding vector {:d}'.format(i))
                    
            line = line.strip()
            if not line:
                continue

            vec = line.split()
            if len(vec) != embedding_length + 1:
                continue
            
            label = vec[0].lower()
            vec = np.array([float(x) for x in vec[1:]], np.float32)            
            assert len(vec) == embedding_length
            word_embeddings[label] = vec

    return word_embeddings

def load_flickr_captions(args, split):
    stop_words = set(stopwords.words('english'))
    split_fn = os.path.join(args.feat_path, args.dataset, split + '.txt')
    images = [im.strip() for im in open(split_fn, 'r')]
    im2idx = dict(zip(images, range(len(images))))
    images = set(images)
    caption_fn = os.path.join(args.feat_path, args.dataset, 'results_20130124.token')
    im2captions = {}
    with open(caption_fn, 'r') as f:
        for line in f:
            line = line.strip().lower().split()
            im = line[0].split('.')[0]
            if im in images:
                if im not in im2captions:
                    im2captions[im] = []

                im2captions[im].append([token for token in line[1:] if token not in stop_words])

    assert(len(im2idx) == len(im2captions))
    captions = []
    cap2im = []
    for im, idx in im2idx.items():
        im_captions = im2captions[im]
        captions += im_captions
        cap2im.append(np.ones(len(im_captions), np.int32) * idx)

    cap2im = np.hstack(cap2im)
    return captions, cap2im

class DatasetLoader:
    """ Dataset loader class that loads feature matrices from given paths and
        create shuffled batch for training, unshuffled batch for evaluation.
    """
    def __init__(self, args, split='train', val_limit=1000):
        feat_path = os.path.join(args.feat_path, args.dataset, split + '_features.npy')
        print('Loading features from', feat_path)
        self.im_feats = np.load(feat_path)
        if args.dataset == 'flickr':
            self.captions, self.cap2im = load_flickr_captions(args, split)
        else:
            self.captions, self.cap2im = load_coco_captions(args, split)

            if split == 'val':
                # let's only take the first 1K images for MSCOCO images
                num_images = val_limit
                self.im_feats = self.im_feats[:num_images]
                subset_ims = self.cap2im < num_images
                self.captions = [caption for caption, is_val in zip(self.captions, subset_ims) if is_val]
                self.cap2im = [im for im, is_val in zip(self.cap2im, subset_ims) if is_val]

        assert len(self.cap2im) == len(self.captions)
        if split != 'train':
            self.labels = np.zeros((len(self.cap2im), len(self.im_feats)), np.float)
            self.labels[(range(len(self.cap2im)), self.cap2im)] = 1
        else:
            self.im2cap = {}
            for cap, im in enumerate(self.cap2im):
                if im not in self.im2cap:
                    self.im2cap[im] = []

                self.im2cap[im].append(cap)

        print('Loading complete')
        self.split = split
        self.sample_size = args.sample_size
        self.im_feats = torch.from_numpy(self.im_feats).to(torch_device)

    def build_vocab(self, cache_filename, word_embeddings_filename=None, embedding_length=300):
        if os.path.exists(cache_filename):
            vocab_data = pickle.load(open(cache_filename, 'rb'))
            self.max_length = vocab_data['max_length']
            self.tok2idx = vocab_data['tok2idx']
            vecs = vocab_data['vecs']
        else:
            assert word_embeddings_filename is not None
            word_embeddings = load_word_embeddings(word_embeddings_filename, embedding_length)
            self.max_length = 0
            vocab = set()
            for caption in self.captions:
                tokens = [token for token in caption if token in word_embeddings]
                vocab.update(tokens)
                self.max_length = max(self.max_length, len(tokens))

            vocab = list(vocab)
            # +1 for a padding vector which *must* be the 0th index
            self.tok2idx = dict(zip(vocab, range(1, len(vocab) + 1)))
            vecs = np.zeros((len(vocab) + 1, embedding_length), np.float32)
            for i, token in enumerate(vocab):
                vecs[i + 1] = word_embeddings[token]
            
            vocab_data = {'max_length' : self.max_length,
                          'tok2idx' : self.tok2idx,
                          'vecs' : vecs}

            pickle.dump(vocab_data, open(cache_filename, 'wb'))

        self.sent_feats = np.zeros((len(self.captions), self.max_length), np.int64)
        for i, caption in enumerate(self.captions):
            tokens = [self.tok2idx[token] for token in caption if token in self.tok2idx]
            self.sent_feats[i, :len(tokens)] = tokens
        
        self.sent_feats = torch.from_numpy(self.sent_feats).to(torch_device)
        return vecs

    def __len__(self):
        return len(self.captions)

    def __getitem__(self, index):
        im_feat = self.im_feats[self.cap2im[index]]
        sent_feat = self.sent_feats[index]
        return im_feat, sent_feat

In [5]:
import pickle
import numbers
import torch
import torch.nn as nn

def make_fc_1d(f_in, f_out):
    return nn.Sequential(nn.Linear(f_in, f_out),
                         nn.BatchNorm1d(f_out),
#                          nn.Tanh(),
                         nn.Dropout(p=0.3),
                        )

def make_fc_relu_1d(f_in, f_out):
    return nn.Sequential(nn.Linear(f_in, f_out),
                         nn.BatchNorm1d(f_out),
                         nn.ReLU(inplace=True),
                         nn.Dropout(p=0.5))

class EmbedBranch(nn.Module):
    def __init__(self, args, feat_dim, embedding_dim, metric_dim):
        super(EmbedBranch, self).__init__()
        self.args = args
        self.fc1 = make_fc_relu_1d(feat_dim, embedding_dim)
        self.fc2 = nn.Linear(embedding_dim, metric_dim)

    def forward(self, x):
        x = self.fc1(x)
        x = self.fc2(x)

        # L2 normalize each feature vector
        if self.args.l2_norm:
            x = nn.functional.normalize(x)
        return {'embedding': x}

class SparseEmbedBranch(nn.Module):
    def __init__(self, args, feat_dim, embedding_dim, metric_dim, use_relu=False, norm_weight=None):
        super(SparseEmbedBranch, self).__init__()
        self.args = args
        self.shared = make_fc_1d(feat_dim, embedding_dim)
        self.embed_tower = nn.Sequential(
            nn.ReLU(),
            nn.Linear(embedding_dim, metric_dim),
        )
        self.reg_tower = nn.Sequential(
            nn.Sigmoid(),
            nn.Linear(embedding_dim, metric_dim),
        )
        if use_relu:
            self.embed_tower.add_module('relu', nn.ReLU())
        self.reg_gate = False
        self.beta = 1.0
        self.high = args.high
        self.low = args.low
        self.reg_weight = 0.0
        self.reg_max_weight = 0.0
        self.q = 0.99
        self.norm_weight = 2.0
        if norm_weight is not None:
            self.norm = nn.BatchNorm1d(metric_dim, affine=False)
        else:
            self.norm = None
        
    def forward(self, x):
        ret = {}
        shared = self.shared(x)
        embedding = self.embed_tower(shared)
        if self.reg_gate is True:
            alpha = self.reg_tower(shared)
            if self.norm:
                alpha = self.norm(alpha)
                alpha = alpha + self.norm_weight
            weight = self.sample_attention(alpha)
            ret['alpha'] = torch.mean(alpha, [0, 1])
            ret['gate'] = torch.mean(weight, [0, 1])
            pp = self.get_prob(alpha)
#             pp = weight
            ret['prob'] = pp
            pp = torch.clamp(pp, min=0.0001, max=0.9999)
            loss1 = torch.max(torch.mean(pp, [1])) * self.reg_weight
            loss2 = torch.max(torch.mean(pp, [0])) * self.reg_max_weight
            loss = loss1 + loss2
#             gp = torch.mean(pp, [0])
#             loss_raw = torch.log(gp / self.q) * gp + (1.0 - gp) * torch.log((1.0 - gp) / (1.0 - self.q))
#             loss = torch.mean(loss_raw) * self.reg_weight
            ret['loss'] = loss
#             if self.args.l2_norm:
#                 embedding = nn.functional.normalize(embedding)
#             embedding = embedding * weight
            embedding = nn.functional.normalize(weight)
        else:
            if self.args.l2_norm:
                embedding = nn.functional.normalize(embedding)
        ret['embedding'] = embedding
        return ret
    
    def get_prob(self, weights):
        return torch.sigmoid(weights - self.beta * math.log(- self.low / self.high))
        
    def sample_attention(self, weights):
        if self.training:
            eps = torch.rand_like(weights)
            s = torch.sigmoid((torch.log(eps) - torch.log(1.0 - eps) + weights)/self.beta)
        else:
            s = torch.sigmoid(weights/0.001)
        s = s * (self.high - self.low) + self.low
        return F.hardtanh(s, min_val=0, max_val=1)

class TextEncoder(nn.Module):
    def __init__(self, args, vecs):
        super(TextEncoder, self).__init__()
        self.args = args
        n_tokens, token_dim = vecs.shape
        self.words = nn.Embedding(n_tokens, token_dim)
        self.words.weight = nn.Parameter(torch.from_numpy(vecs))
        self.vecs = torch.from_numpy(vecs)
        self.word_reg = nn.MSELoss()
        if args.language_model == 'attend':
            self.word_attention = nn.Sequential(nn.Linear(vecs.shape[1] * 2, 1),
                                                nn.ReLU(inplace=True),
                                                nn.Softmax(dim=1))
        if args.worker == 'cuda':
            self.cuda()
            self.vecs = self.vecs.cuda()
        
    
    def forward(self, tokens):
        words = self.words(tokens)
        n_words = torch.sum(tokens > 0, 1).float() + 1e-10
        sum_words = words.sum(1).squeeze()
        sentences = sum_words / n_words.unsqueeze(1)
        
        if self.args.language_model == 'attend':
            max_length = tokens.size(-1)
            context_vector = sentences.unsqueeze(1).repeat(1, max_length, 1)
            attention_inputs = torch.cat((context_vector, words), 2)
            attention_weights = self.word_attention(attention_inputs)
            sentences = nn.functional.normalize(torch.sum(words * attention_weights, 1))
        
        return sentences

        
class ImageSentenceEmbeddingNetwork(pl.LightningModule):
    def __init__(self, args, vecs, image_feature_dim):
        super(ImageSentenceEmbeddingNetwork, self).__init__()
        embedding_dim = args.dim_embed
        metric_dim = args.metric_dim
        self.text_encoder = TextEncoder(args, vecs)
        _, token_dim = vecs.shape
        if args.use_sparse:
            self.text_branch = SparseEmbedBranch(args, token_dim, embedding_dim, metric_dim, use_relu=args.use_relu, norm_weight=args.text_norm_weight)
            self.image_branch = SparseEmbedBranch(args, image_feature_dim, embedding_dim, metric_dim, use_relu=args.use_relu, norm_weight=args.image_norm_weight)
        else:
            self.text_branch = EmbedBranch(args, token_dim, embedding_dim, metric_dim)
            self.image_branch = EmbedBranch(args, image_feature_dim, embedding_dim, metric_dim)
        self.args = args
        self.output = nn.Sequential(
            nn.Linear(1, 1),
            nn.Sigmoid(),
        )
        self.reg_factor = args.reg_weight_alpha
        self.beta_factor = args.reg_beta_alpha
        self.dot_loss_weight = 0.0
        self.dot_true_loss_weight = 0.0

    def forward(self, images, tokens):
        sentences = self.text_encoder(tokens)
        ret_text = self.text_branch(sentences)
        ret_image = self.image_branch(images)
        
        sentences = ret_text['embedding']
        sentences_loss = ret_text.get('loss', None)
        images = ret_image['embedding']
        image_loss = ret_image.get('loss', None)
        if 'alpha' in ret_text:
            self.log('alpha_text', ret_text['alpha'])
            self.log('gate_text', ret_text['gate'])
        if 'alpha' in ret_image:
            self.log('alpha_image', ret_image['alpha'])
            self.log('gate_image', ret_image['gate'])
        text_prob = None
        image_prob = None
        if sentences_loss is None:
            total_loss = image_loss
        elif image_loss is None:
            total_loss = sentences_loss
        else:
            text_prob = ret_text['prob']
            image_prob = ret_image['prob']
            #prob_dot = text_prob * image_prob
            #dot_loss = torch.mean(torch.max(prob_dot, 1)[0]) * self.dot_loss_weight
            total_loss = sentences_loss + image_loss
        return images, sentences, total_loss, image_prob, text_prob        
    
    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=self.args.lr)
    
    def loss(self, batch, export=False):
        losses = {}
        images, sentences = batch
        im_embeds, sent_embeds, embed_loss, image_prob, text_prob = self(images, sentences)
        embed_dot = im_embeds * sent_embeds
        embed_dot = torch.sum(embed_dot, dim=(1,), keepdim=True)
        losses['dot_true_avg'] = torch.mean(embed_dot)
        ret = self.output(embed_dot)
        ones = torch.ones_like(ret)
        zeros = torch.zeros_like(ret)
        ret = torch.clamp(ret, min=0.01, max=0.99)
        loss = F.binary_cross_entropy(ret, ones)
        losses['true_entropy_loss'] = loss
        dot_loss = 0.0
        if image_prob is not None:
            prob_dot = text_prob * image_prob
            dot_loss += torch.mean(prob_dot) * self.dot_true_loss_weight
                
        for i in range(self.args.sample_size):
            sent_embeds2 = torch.roll(sent_embeds, 1 + i, 0)
            embed_dot2 = im_embeds * sent_embeds2
            embed_dot2 = torch.sum(embed_dot2, dim=(1,), keepdim=True)
            losses['dot_false_avg'] = torch.mean(embed_dot2)
            ret = self.output(embed_dot2)
            neg_loss = F.binary_cross_entropy(ret, zeros)
            losses[f'false_entropy_loss'] = neg_loss
            loss = loss + neg_loss * self.args.neg_sample_weight

            if image_prob is not None:
                image_prob2 = torch.roll(image_prob, 1 + i, 0)
                prob_dot = text_prob * image_prob2
                dot_loss += torch.mean(torch.max(prob_dot, 1)[0]) * self.dot_loss_weight
#                 dot_loss += torch.mean(prob_dot) * self.dot_loss_weight
                  
        loss = loss / (1.0 +self.args.sample_size)
        dot_loss = dot_loss / self.args.sample_size
        losses['entropy_loss'] = loss
        losses['dot_loss'] = dot_loss
        if export:
            losses['embed_im'] = im_embeds.detach().cpu().numpy()
            losses['embed_sent'] = sent_embeds.detach().cpu().numpy()
        if embed_loss is not None:
            losses['embed_loss'] = embed_loss
            loss = loss + embed_loss
        loss = loss + dot_loss
        with torch.no_grad():
            losses['image_used_indices_avg'] = torch.mean(torch.count_nonzero(im_embeds, dim=1).double())
            losses['index_used_images_avg'] = torch.mean(torch.count_nonzero(im_embeds, dim=0).double())
            losses['index_used_images_max'] = torch.max(torch.count_nonzero(im_embeds, dim=0).double())
            losses['text_used_indices_avg'] = torch.mean(torch.count_nonzero(sent_embeds, dim=1).double())
            losses['index_used_text_avg'] = torch.mean(torch.count_nonzero(sent_embeds, dim=0).double())
            losses['index_used_text_max'] = torch.max(torch.count_nonzero(sent_embeds, dim=0).double())
        losses['loss'] = loss
        return losses
    
    def logall(self, prefix, losses):
        for k, v in losses.items():
            if k in ['embed_im', 'embed_sent']:
                continue
            self.log(f'{prefix}_{k}', v)
            
    def validation_step(self, batch, batch_idx):
        losses = self.loss(batch, export=True)
        self.logall('val', losses)
        return losses
        
    def training_step(self, batch, batch_idx):
        losses = self.loss(batch)
        self.logall('train', losses)
        return losses
            
    def validation_epoch_end(self, validation_step_outputs):
        if self.current_epoch > self.args.reg_start_epoch:
            self.reg_factor = min(1.0, self.reg_factor + self.args.reg_weight_beta)
            self.text_branch.q = max(self.text_branch.q - self.args.q_beta, self.args.q_text)
            self.image_branch.q = max(self.image_branch.q - self.args.q_beta, self.args.q_image)
            if self.args.text_norm_weight is not None:
                self.text_branch.norm_weight = max(self.text_branch.norm_weight - self.args.text_norm_step, self.args.text_norm_weight)
                self.log('norm_text', self.text_branch.norm_weight)
            self.image_branch.norm_weight = max(self.image_branch.norm_weight - self.args.image_norm_step, self.args.image_norm_weight)
            self.log('q_text', self.text_branch.q)
            self.log('q_image', self.image_branch.q)
            self.log('norm_image', self.image_branch.norm_weight)
            self.dot_loss_weight = self.args.dot_loss_weight * self.reg_factor
            self.dot_true_loss_weight = self.args.dot_true_loss_weight * self.reg_factor
            self.text_branch.reg_weight = self.args.reg_weight_text * self.reg_factor
            self.image_branch.reg_weight = self.args.reg_weight_image * self.reg_factor
            self.text_branch.reg_max_weight = self.args.reg_max_weight_text * self.reg_factor
            self.image_branch.reg_max_weight = self.args.reg_max_weight_image * self.reg_factor
            self.text_branch.reg_gate = True
            self.image_branch.reg_gate = True
            self.beta_factor = max(self.args.reg_beta_min, self.beta_factor * self.args.reg_beta_beta)
            self.text_branch.beta = self.beta_factor
            self.image_branch.beta = self.beta_factor
            self.log('beta_factor', self.beta_factor)
            self.log('reg_factor', self.reg_factor)
            
        embed_im = []
        embed_sent = []
        for out in validation_step_outputs:
            embed_im.append(out['embed_im'])
            embed_sent.append(out['embed_sent'])
        embed_im = np.vstack(embed_im)
        embed_sent = np.vstack(embed_sent)        
        prod = np.matmul(embed_sent, embed_im.T)
        N = prod.shape[0]
#         if self.current_epoch % 10 == 9:
#             embed_im_ones = (embed_im != 0).astype(int)
#             embed_sent_ones = (embed_sent != 0).astype(int)
#             dot_prod_count = 1.0 * np.sum(np.matmul(embed_sent_ones, embed_im_ones.T)) / N / N
#             self.log(f'val_dot_count', dot_prod_count)
        val_im_retrieved = np.mean(np.count_nonzero(prod, axis=1))
        self.log('val_im_retrieved', val_im_retrieved)
        self.log('val_im_retrieved_ratio', val_im_retrieved/N)
        topall = np.argsort(prod)
        keys = topall[0,-100::5]
        keys = topall[0,:100:5]
        i = np.arange(N) 
        i = np.tile(i,(N,1))
        i = i.T
        topall = topall == i
        for topk in [5, 20, 100]:
            top1 = topall[:,-topk:]
            top1 = np.sum(top1, axis=1)
            nz = np.count_nonzero(top1)
            recall1 = nz / N
            top1 = topall[:,:topk]
            top1 = np.sum(top1, axis=1)
            nz = np.count_nonzero(top1)
            recall2 = nz / N
            recall = max(recall1, recall2)
            self.log(f'val_recall@{topk}', recall)
        

In [6]:
class FLAGS:
    vocab_filename = '/kaggle/working/vocab.pkl'
    word_embedding_filename = '/kaggle/input/grovle/grovle.txt'
    encoder_filename = '/kaggle/working/encoder.pt'
    text_branch_filename = '/kaggle/working/text_branch.pt'
    ranker_filename = '/kaggle/working/ranker.pkl'
    index_filename = '/kaggle/working/index.pkl'
    batch_size = 600
    feat_path = '/kaggle/input/'
    dataset = 'flickr'
    sample_size = 10
    worker = torch_device
    use_relu = False
    dim_embed = 2048
    metric_dim = 2048
    language_model = 'avg'
    lr = 3e-2
    reg_start_epoch = 0
    dot_true_loss_weight = 0.0
    dot_loss_weight = 0.0
    q_image = 0.01
    q_text = 0.01
    q_beta = 0.1
    reg_weight_image = 0.0
    reg_weight_text = 0.0
    reg_max_weight_image = 0.0
    reg_max_weight_text = 0.0
    reg_weight_alpha = 1.0
    reg_weight_beta = 0.002
    reg_beta_alpha = 0.8
    reg_beta_beta = 1.0
    reg_beta_min = 0.8
    image_norm_weight = -3
    text_norm_weight = -5
    image_norm_step = 0.1
    text_norm_step = 0.1
    use_sparse = True
    l2_norm = True
    max_epochs = 200
    high = 1.1
    low = -0.1
    neg_sample_weight = 2

In [7]:
train_loader = DatasetLoader(FLAGS, 'train')
embedding_length = 300
print('Loading vocab')
vecs = train_loader.build_vocab(FLAGS.vocab_filename, FLAGS.word_embedding_filename, embedding_length)
val_loader = DatasetLoader(FLAGS, 'val')
val_loader.build_vocab(FLAGS.vocab_filename)
print('Loading complete')


Loading features from /kaggle/input/flickr/train_features.npy
Loading complete
Loading vocab
Reading word embedding vector 0
Reading word embedding vector 10000
Reading word embedding vector 20000
Loading features from /kaggle/input/flickr/val_features.npy
Loading complete
Loading complete


In [8]:
def to_param():
    d = {}
    for k, v in FLAGS.__dict__.items():
        if k.startswith('_'):
            continue
        d[k] = v
    return d

In [9]:
torch_train_loader = torch.utils.data.DataLoader(train_loader,batch_size=FLAGS.batch_size, shuffle=True, drop_last=True)
torch_val_loader = torch.utils.data.DataLoader(val_loader,batch_size=FLAGS.batch_size, shuffle=True)
image_feature_dim = torch_train_loader.dataset.im_feats.shape[-1]
text_feature_dim = torch_train_loader.dataset.sent_feats.shape[-1]
print(f'Image Feature dim {image_feature_dim} text feature dim {text_feature_dim}')
model = ImageSentenceEmbeddingNetwork(FLAGS, vecs, image_feature_dim)
n_parameters = sum([p.data.nelement() for p in model.parameters()])
print('  + Number of params: {}'.format(n_parameters))
FLAGS.n_parameters = n_parameters
neptune_logger = NeptuneLogger(api_key=neptune_key, project_name="pickedmelon/flickr", params=to_param())
trainer = pl.Trainer(
    max_epochs=FLAGS.max_epochs,
    gpus=1 if FLAGS.worker == 'cuda' else 0, 
    tpu_cores=8 if FLAGS.worker == 'tpu' else None, 
    logger=neptune_logger)#, callbacks=[EarlyStopping(monitor='val_loss')]
trainer.fit(model, torch_train_loader, torch_val_loader)

Image Feature dim 2048 text feature dim 45
  + Number of params: 26003802
https://ui.neptune.ai/pickedmelon/flickr/e/FLIC-240


Validation sanity check: 0it [00:00, ?it/s]

  f'Your {mode}_dataloader has `shuffle=True`, it is best practice to turn'


Training: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

  "Relying on `self.log('val_loss', ...)` to set the ModelCheckpoint monitor is deprecated in v1.2"


Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

In [10]:
class ImageDS:
    """ Dataset loader class that loads feature matrices from given paths and
        create shuffled batch for training, unshuffled batch for evaluation.
    """
    def _load(self, args, split):
        feat_path = os.path.join(args.feat_path, args.dataset, split + '_features.npy')
        print('Loading features from', feat_path)
        im_feats = np.load(feat_path)
        split_fn = os.path.join(args.feat_path, args.dataset, split + '.txt')
        images = [im.strip() for im in open(split_fn, 'r')]
        return im_feats, images
        
    def __init__(self, args, split='val'):
        feats_train, images_train = self._load(args, 'train')
        feats_val, images_val = self._load(args, 'val')
        self.im_feats = np.vstack([feats_train, feats_val])
        self.images = images_train + images_val

    def __len__(self):
        return len(self.images)

    def __getitem__(self, index):
        im_feat = self.im_feats[index]
        image_name = self.images[index]
        return im_feat, image_name

In [11]:
imgDS = ImageDS(FLAGS)
torch_predict_loader = torch.utils.data.DataLoader(imgDS,batch_size=FLAGS.batch_size, num_workers=1, drop_last=False)

Loading features from /kaggle/input/flickr/train_features.npy
Loading features from /kaggle/input/flickr/val_features.npy


In [12]:
from scipy.sparse import csc_matrix, csr_matrix
img_list = []
embeddings = []
model.image_branch.eval()
for (idx, batch) in enumerate(torch_predict_loader):
    embedding = model.image_branch(batch[0])['embedding'].detach().cpu().numpy()
    img_list.extend(batch[1])
    embeddings.append(embedding)
img_index = csc_matrix(np.vstack(embeddings))
print(f'Index memory used {img_index.data.nbytes/1024/1024} MB')

Index memory used 3.8870849609375 MB


In [13]:
print(img_index.shape)
with open('test.npy', 'wb') as f:
    np.save(f, img_index)

(30783, 2048)


In [14]:
torch.save(model.text_encoder, FLAGS.encoder_filename)
torch.save(model.text_branch, FLAGS.text_branch_filename)

In [15]:
class Ranker:
    def __init__(self, args, img_list, img_index, should_flip):
        super(Ranker, self).__init__()
        device = torch.device('cpu')
        self.text_encoder = torch.load(args.encoder_filename, map_location=device)
        self.text_encoder.eval()
        self.text_branch = torch.load(args.text_branch_filename, map_location=device)
        self.text_branch.eval()
        self.stop_words = set(stopwords.words('english'))
        vocab_data = pickle.load(open(args.vocab_filename, 'rb'))
        self.max_length = vocab_data['max_length']
        self.tok2idx = vocab_data['tok2idx']
        self.vecs = vocab_data['vecs']
        self.img_list = img_list
        self.img_index = img_index
        self.should_flip = should_flip


    def _tokenize_caption(self, caption):
        tokens = nltk.tokenize.word_tokenize(caption.lower())
        tokens = [self.tok2idx[token] for token in tokens if token not in self.stop_words and token in self.tok2idx]
        if len(tokens) == 0:
            return None, None
        text_vec = np.zeros((1, self.max_length), np.int64)
        text_vec[0, :len(tokens)] = tokens
        with torch.no_grad():
            text_vec = torch.from_numpy(text_vec)
            text_vec = self.text_encoder(text_vec)
            text_vec = self.text_branch(text_vec)['embedding']
            text_vec = text_vec.detach().numpy()
        _, pos = np.nonzero(text_vec)
        val = text_vec[0, pos]
        return pos, val

    def search(self, caption, ignored_index=None):
        if ignored_index is None:
            ignored_index = []
        pos, val = self._tokenize_caption(caption)
        if pos is None:
            return [], []
        for key in ignored_index:
            val[np.where(pos == key)] = 0
        prod = self.img_index[:, pos].dot(val)
        if self.should_flip:
            prod = -prod
        topall = np.argsort(prod)
        keys = topall[:-20:-1]
        ret = []
        for k in keys:
            ret.append(self.img_list[k])
        return ret, pos.tolist()

In [16]:
import pickle
should_flip = bool(model.output[0]._parameters['weight'].data[0][0] < 0)
ranker = Ranker(FLAGS, img_list, img_index, should_flip)
pickle.dump(ranker, open(FLAGS.ranker_filename, 'wb'))

In [17]:
ranker._tokenize_caption("a large red stop sign that is by some flowers")

(array([ 121,  184,  251,  440,  464,  476,  503,  609,  643,  684,  838,
         963,  965, 1248, 1431, 1458, 1474, 1482, 1497, 1508, 1583, 1724,
        1796, 1810, 1870, 1919, 1958, 1977, 2027]),
 array([0.18569534, 0.18569534, 0.18569534, 0.18569534, 0.18569534,
        0.18569534, 0.18569534, 0.18569534, 0.18569534, 0.18569534,
        0.18569534, 0.18569534, 0.18569534, 0.18569534, 0.18569534,
        0.18569534, 0.18569534, 0.18569534, 0.18569534, 0.18569534,
        0.18569534, 0.18569534, 0.18569534, 0.18569534, 0.18569534,
        0.18569534, 0.18569534, 0.18569534, 0.18569534], dtype=float32))

In [18]:
ranker.search("people with a cat and dog")

(['4346765529',
  '4966265765',
  '3018847610',
  '4633639059',
  '397982550',
  '4329197704',
  '2851198725',
  '2578161080',
  '583090328',
  '4594171071',
  '6030269328',
  '8171835256',
  '4803941821',
  '3458274053',
  '2858793859',
  '2164018393',
  '3417788829',
  '1420060020',
  '1362851262'],
 [29,
  43,
  47,
  64,
  84,
  97,
  127,
  143,
  145,
  180,
  187,
  234,
  248,
  295,
  356,
  367,
  384,
  385,
  397,
  430,
  494,
  533,
  558,
  589,
  596,
  607,
  632,
  699,
  714,
  741,
  749,
  768,
  783,
  794,
  799,
  813,
  818,
  845,
  875,
  889,
  892,
  903,
  907,
  941,
  976,
  991,
  1005,
  1011,
  1015,
  1040,
  1062,
  1076,
  1077,
  1093,
  1135,
  1149,
  1153,
  1168,
  1172,
  1175,
  1200,
  1201,
  1268,
  1275,
  1276,
  1284,
  1286,
  1294,
  1303,
  1340,
  1342,
  1353,
  1373,
  1384,
  1395,
  1412,
  1430,
  1434,
  1437,
  1489,
  1492,
  1518,
  1599,
  1664,
  1665,
  1684,
  1739,
  1741,
  1760,
  1806,
  1823,
  1833,
  1843,
  188

In [19]:
!rm -rf Untitled