In [1]:
# Main driver script to run train/test

# clone https://github.com/google/tirg in ./..
# copy *.npy from ./..

In [2]:
import sys
import json
sys.path.append('./../')
sys.path.append('./../tirg/')
from main import *

In [3]:
# Start from the project's default options, then override what this
# notebook needs.
opt = parse_opt()

# Dataset locations, relative to the notebook's working directory.
opt.sic112_path = '../../../datasets/SIC112/'
opt.coco_path = '../../../datasets/coco'

# Mini-batch size, read by both the training and the evaluation code below.
opt.batch_size = 32

# Summary writer that receives the per-epoch losses and metrics.
logger = SummaryWriter(comment=opt.comment)

In [4]:
# datasets

In [5]:
trainset, _, sic112 = load_datasets(opt)

# add subject, verb and location annotations to SIC112
# The first caption of every SIC112 image is split on whitespace: the first
# word is the subject; if the second word ends in "ing" it is the verb and the
# remainder is the location, otherwise there is no verb and everything after
# the subject is the location.
for img in sic112.imgs:
    words = img['captions'][0].split()
    img['subjects'] = [words[0]]
    # len() guard: a one-word caption has no second word to inspect.
    if len(words) > 1 and words[1].endswith("ing"):
        img['verbs'] = [words[1]]
        img['locations'] = [' '.join(words[2:])]
    else:
        img['verbs'] = []
        img['locations'] = [' '.join(words[1:])]

# add subject, verb and location annotations to coco train 2014
# (requires 'coco_splitted_captions_train2014.json'; run preprocess_coco first)
id2img = {}
for img in trainset.imgs:
    id2img[img['id']] = img
    img['subjects'] = []
    img['verbs'] = []
    img['locations'] = []

# Read the pre-split captions inside a `with` block so the file handle is
# closed as soon as the JSON has been parsed.
with open('coco_splitted_captions_train2014.json', 'rt') as captions_file:
    coco_split_captions = json.load(captions_file)['annotations']

# Each annotation carries optional subject/verb/location phrases; phrases that
# are None are skipped, the rest are appended to the owning image's lists.
for caption in tqdm(coco_split_captions):
    img = id2img[caption['image_id']]
    if caption['subject_phrase'] is not None:
        img['subjects'] += [caption['subject_phrase']]
    if caption['verb_phrase'] is not None:
        img['verbs'] += [caption['verb_phrase']]
    if caption['location_phrase'] is not None:
        img['locations'] += [caption['location_phrase']]

17919 745
sucessfully loaded features
sucessfully loaded features
sucessfully loaded features


100%|██████████| 414113/414113 [00:01<00:00, 240719.56it/s]


In [6]:
# model

In [7]:
class One2OneTransformation(torch.nn.Module):
    """MLP that maps one embedding to another embedding of the same size.

    The input is first passed through a `torch_functions.NormalizationLayer`
    (with `learn_scale=False`) and then through a three-layer perceptron whose
    hidden width is twice `opt.embed_dim`. The embedding size is read from the
    notebook-level `opt` when the module is constructed.
    """

    def __init__(self):
        super(One2OneTransformation, self).__init__()
        dim = opt.embed_dim
        hidden_dim = 2 * dim
        self.m = torch.nn.Sequential(
            torch.nn.Linear(dim, hidden_dim),
            torch.nn.ReLU(),
            torch.nn.Linear(hidden_dim, hidden_dim),
            torch.nn.BatchNorm1d(hidden_dim),
            torch.nn.ReLU(),
            torch.nn.Linear(hidden_dim, dim),
        )
        self.norm = torch_functions.NormalizationLayer(learn_scale=False)

    def forward(self, x):
        """Normalize `x`, then apply the MLP and return its output."""
        return self.m(self.norm(x))
    
class Three2OneTransformation(torch.nn.Module):
    """MLP that fuses three embeddings into a single embedding.

    Every input embedding is passed through the same
    `torch_functions.NormalizationLayer` (with `learn_scale=False`), the
    results are concatenated along dim 1, and a three-layer perceptron with
    hidden width `5 * opt.embed_dim` maps them to `opt.embed_dim` features.
    The embedding size is read from the notebook-level `opt` when the module
    is constructed.
    """

    def __init__(self):
        super(Three2OneTransformation, self).__init__()
        dim = opt.embed_dim
        in_dim = 3 * dim
        hidden_dim = 5 * dim
        self.m = torch.nn.Sequential(
            torch.nn.Linear(in_dim, hidden_dim),
            torch.nn.ReLU(),
            torch.nn.Linear(hidden_dim, hidden_dim),
            torch.nn.BatchNorm1d(hidden_dim),
            torch.nn.ReLU(),
            torch.nn.Linear(hidden_dim, dim),
        )
        self.norm = torch_functions.NormalizationLayer(learn_scale=False)

    def forward(self, x):
        """Fuse the embeddings in the iterable `x` into one embedding."""
        normalized_parts = []
        for part in x:
            normalized_parts.append(self.norm(part))
        return self.m(torch.cat(normalized_parts, dim=1))

# Build the base model, then attach the extra heads used by this notebook.
# The optimizer returned here is replaced by the one built in the next cell.
model, optimizer = create_model(opt, trainset)

# One extractor head per caption component, attached in the order
# subject -> verb -> location.
for head_name in ('subject_extractor', 'verb_extractor', 'location_extractor'):
    setattr(model, head_name, One2OneTransformation())

# Head that fuses the three component embeddings back into one embedding.
model.svl_combine = Three2OneTransformation()

# Move the model, including the newly attached heads, to the GPU.
model = model.cuda()

In [8]:
# create optimizer
# Per-module parameter groups. Several of the lists below overlap (for example
# `model.parameters()` contains everything); the de-duplication pass further
# down keeps each parameter only in the FIRST group that lists it, so earlier
# groups take precedence. Groups without an explicit 'lr' / 'weight_decay'
# fall back to the defaults passed to the optimizer constructor.
params = [
    {'params': list(model.img_encoder.fc.parameters())},
    {'params': list(model.img_encoder.parameters()), 'lr': 0.1 * opt.learning_rate},
    {'params': list(model.text_encoder.parameters()), 'lr': opt.learning_rate},
    {'params': list(model.transformer.parameters()), 'weight_decay': opt.weight_decay * 0.1},
    {'params': list(model.parameters())},
]

# remove dup params (keep the first one)
# A parameter already claimed by an earlier group is replaced, in place, by a
# fresh scalar tensor that is not part of the model, so group sizes and slot
# positions stay the same. Ids are recorded only after a whole group has been
# scanned, so only cross-group duplicates are replaced.
claimed_ids = set()
for group in params:
    members = group['params']
    for slot, p in enumerate(members):
        if id(p) in claimed_ids:
            members[slot] = torch.tensor(0.0, requires_grad=True)
    claimed_ids.update(id(p) for p in members)

# Build only the optimizer that was asked for: Adam when opt.optim == 'adam',
# SGD with momentum otherwise.
if opt.optim == 'adam':
    optimizer = torch.optim.Adam(
        params,
        lr=opt.learning_rate,
        weight_decay=opt.weight_decay
    )
else:
    optimizer = torch.optim.SGD(
        params,
        lr=opt.learning_rate,
        momentum=opt.momentum,
        weight_decay=opt.weight_decay
    )

In [9]:
# test function

In [10]:
def test_svl(model, testset, opt):
    """Evaluate subject/verb/location (SVL) compositional retrieval.

    For every image in `testset` that has a verb annotation, five queries are
    built. Each query pairs the image's subject, verb and location texts with
    three randomly drawn *other* images that share the same subject, verb and
    location respectively. Every query is then encoded under all eight
    text/image settings ('t' = encode the text with `model.text_encoder`,
    'i' = run the partner image's feature through the matching extractor head),
    fused with `model.svl_combine`, and used to rank all test images by dot
    product. A query counts as a hit at k when one of the k best-ranked images
    has the same first caption as the query's source image.

    Args:
        model: model exposing `img_encoder`, `snorm`, `text_encoder`,
            `subject_extractor`, `verb_extractor`, `location_extractor` and
            `svl_combine`. It is switched to eval mode.
        testset: dataset exposing `get_loader(...)` and `imgs`, where every
            entry of `imgs` has 'captions', 'subjects', 'verbs', 'locations'.
        opt: options object; only `opt.batch_size` is read.

    Returns:
        A list of `(metric_name, value)` tuples named
        'svl_<setting>_recall_top<k>' for k in 1, 5, 10 and each of the eight
        settings. The list is empty when no query could be built.
    """
    model = model.eval()
    top_ks = [1, 5, 10]

    # all img features (no autograd bookkeeping is needed during evaluation)
    img_features = []
    with torch.no_grad():
        for data in testset.get_loader(batch_size = opt.batch_size, shuffle = False, drop_last= False):
            # extract image features
            imgs = np.stack([d['image'] for d in data])
            imgs = torch.from_numpy(imgs).float()
            if len(imgs.shape) == 2:
                # 2-D input is fed to the encoder's final fc layer only
                imgs = model.img_encoder.fc(imgs.cuda())
            else:
                imgs = model.img_encoder(imgs.cuda())
            imgs = model.snorm(imgs).cpu().detach().numpy()
            img_features += [imgs]

    img_features = np.concatenate(img_features, axis=0)
    img_labels = [img['captions'][0] for img in testset.imgs]
    num_imgs = len(testset.imgs)

    # construct random queries
    # A private RandomState keeps the queries reproducible without re-seeding
    # the process-wide NumPy generator on every evaluation; seeded identically,
    # it draws the same sequence as np.random.seed(123) + np.random.randint.
    rng = np.random.RandomState(123)

    def first_value_counts(key):
        # How many images share each leading annotation value for `key`.
        counts = {}
        for img in testset.imgs:
            if len(img[key]) > 0:
                counts[img[key][0]] = counts.get(img[key][0], 0) + 1
        return counts

    value_counts = dict((key, first_value_counts(key))
                        for key in ('subjects', 'verbs', 'locations'))

    def has_partner(img, key):
        # True when some other image shares this image's leading `key` value.
        return len(img[key]) > 0 and value_counts[key].get(img[key][0], 0) > 1

    def sample_partner(img, key):
        # Rejection-sample the index of a different image with the same value.
        while True:
            idx = rng.randint(0, num_imgs)
            candidate = testset.imgs[idx]
            if len(candidate[key]) == 0:
                continue
            if img[key][0] == candidate[key][0] and img is not candidate:
                return idx

    queries = []
    for _ in range(5):
        for img in testset.imgs:
            if len(img['verbs']) == 0:
                continue
            # Without a partner for every component the rejection sampling
            # below could never terminate, so such images yield no query.
            if not (has_partner(img, 'subjects') and has_partner(img, 'verbs')
                    and has_partner(img, 'locations')):
                continue
            i = sample_partner(img, 'subjects')
            j = sample_partner(img, 'verbs')
            k = sample_partner(img, 'locations')
            queries += [{
                'subject_img_id': i,
                'verb_img_id': j,
                'location_img_id': k,
                'subject': img['subjects'][0],
                'verb': testset.imgs[j]['verbs'][0],
                'location': img['locations'][0],
                'label': img['captions'][0]
            }]

    if len(queries) == 0:
        return []

    def encode_component(mode, extractor, batch, img_id_key, text_key):
        # 'i': extractor head on the partner image feature; otherwise: text.
        if mode == 'i':
            feats = img_features[[q[img_id_key] for q in batch], :]
            return extractor(torch.from_numpy(feats).cuda())
        return model.text_encoder([q[text_key] for q in batch])

    def measure_retrieval_performance(query_features, query_labels, name = 'X'):
        # Rank all images once per query, then reuse the ranking for every k.
        sims = query_features.dot(img_features.T)
        ranked_labels = []
        for row in range(sims.shape[0]):
            order = np.argsort(-sims[row, :])[:max(top_ks)]
            ranked_labels.append([img_labels[idx] for idx in order])
        results = []
        for top_k in top_ks:
            hits = 0.0
            for row, labels in enumerate(ranked_labels):
                if query_labels[row] in labels[:top_k]:
                    hits += 1
            results.append(('svl_' + name + '_recall_top' + str(top_k),
                            hits / sims.shape[0]))
        return results

    #----
    r = []
    # all 2^3 text/image combinations, ordered as (subject, verb, location)
    query_setting_combinations = [(s, v, l) for s in ['t', 'i']
                                  for v in ['t', 'i']
                                  for l in ['t', 'i']]
    with torch.no_grad():
        for s, v, l in query_setting_combinations:
            # compute query features
            query_features = []
            query_labels = []
            for start in range(0, len(queries), opt.batch_size):
                batch = queries[start:(start + opt.batch_size)]
                subjects = encode_component(
                    s, model.subject_extractor, batch, 'subject_img_id', 'subject')
                verbs = encode_component(
                    v, model.verb_extractor, batch, 'verb_img_id', 'verb')
                locations = encode_component(
                    l, model.location_extractor, batch, 'location_img_id', 'location')
                svl = model.svl_combine([subjects, verbs, locations])
                query_features += [svl.cpu().detach().numpy()]
                query_labels += [q['label'] for q in batch]

            query_features = np.concatenate(query_features, axis=0)

            # compute recall
            r += measure_retrieval_performance(query_features, query_labels, name = s + v + l)
    return r

def test(model, testset, opt):
    """Run every evaluation that applies to `testset`.

    Text-to-image retrieval is always measured. The subject/verb/location
    retrieval test is added only for datasets whose `name()` contains '112'.

    Returns:
        The combined list of metrics produced by the individual tests.
    """
    metrics = test_text_to_image_retrieval(model, testset, opt)
    if '112' not in testset.name():
        return metrics
    metrics += test_svl(model, testset, opt)
    return metrics

In [11]:
# TRAIN

In [12]:
def compute_losses(model, data, losses_tracking, add_extract_compose_losses = True):
    """Compute the weighted training loss for one batch and record its parts.

    The joint-embedding loss between a random caption of every image and the
    image itself is always computed. When `add_extract_compose_losses` is true
    and every image in the batch has at least one subject, verb and location
    annotation, three more losses are added:

    * 'extract'  - extractor-head outputs vs. the encoded component phrases;
    * 'compose1' - fusion of the encoded phrases vs. the image or the encoded
      "subject verb location" sentence (picked at random);
    * 'compose2' - fusion of independently shuffled extractor outputs vs. the
      encoded sentence built from the equally shuffled phrases.

    Args:
        model: model exposing `img_encoder`, `text_encoder`, `pair_loss`, the
            three extractor heads and `svl_combine`.
        data: one batch: a sequence of dicts with 'image', 'captions', 'index'.
        losses_tracking: dict mapping loss name -> list of float values; the
            value of every loss computed here is appended to it.
        add_extract_compose_losses: whether to add the three extra losses.

    Returns:
        The total loss tensor (weighted sum of the individual losses).

    Note:
        The annotations are looked up in the notebook-level `trainset`, not in
        an argument, so `data` is expected to come from that dataset.
    """
    losses = []

    # joint embedding loss
    imgs = np.stack([d['image'] for d in data])
    imgs = torch.from_numpy(imgs).float()
    if len(imgs.shape) == 2:
        # 2-D input is fed to the encoder's final fc layer only
        imgs = model.img_encoder.fc(imgs.cuda())
    else:
        imgs = model.img_encoder(imgs.cuda())
    texts = [random.choice(d['captions']) for d in data]
    texts = model.text_encoder(texts)
    loss_name = 'joint_embedding'
    loss_weight = 1
    loss_value = model.pair_loss(texts, imgs).cuda()
    losses += [(loss_name, loss_weight, loss_value)]

    def do_add_extract_compose_losses():
        try:
            subjects = [random.choice(trainset.imgs[d['index']]['subjects']) for d in data]
            verbs = [random.choice(trainset.imgs[d['index']]['verbs']) for d in data]
            locations = [random.choice(trainset.imgs[d['index']]['locations']) for d in data]
        except (IndexError, KeyError):
            # random.choice raises IndexError on an empty annotation list and a
            # missing annotation key raises KeyError: the batch is skipped for
            # these losses. Anything else is a real error and must propagate.
            return
        # The text-encoder targets are detached, so these losses do not
        # back-propagate into the text encoder through the targets.
        encoded_subjects = model.text_encoder(subjects).detach()
        encoded_verbs = model.text_encoder(verbs).detach()
        encoded_locations = model.text_encoder(locations).detach()
        # Every head extracts from either the caption or the image embedding.
        extracted_subjects = model.subject_extractor(random.choice([texts, imgs]).detach())
        extracted_verbs = model.verb_extractor(random.choice([texts, imgs]).detach())
        extracted_location = model.location_extractor(random.choice([texts, imgs]).detach())

        # extract
        loss_value = model.pair_loss(
            torch.cat([extracted_subjects, extracted_verbs, extracted_location]),
            torch.cat([encoded_subjects, encoded_verbs, encoded_locations])
        ).cuda()
        loss_name = 'extract'
        loss_weight = 1
        losses.append((loss_name, loss_weight, loss_value))

        # compose with encoded
        loss_value = model.pair_loss(
            model.svl_combine([encoded_subjects, encoded_verbs, encoded_locations]),
            random.choice([imgs, model.text_encoder([s + ' ' + v + ' ' + l for s, v, l in zip(subjects, verbs, locations)])]).detach()
        ).cuda()
        loss_name = 'compose1'
        loss_weight = 0.5
        losses.append((loss_name, loss_weight, loss_value))

        # shuffle
        # list(...) makes the indices a mutable list on Python 3 as well,
        # where range() cannot be shuffled in place.
        shuffled_subjects_indices = list(range(len(data)))
        shuffled_verbs_indices = list(range(len(data)))
        shuffled_locations_indices = list(range(len(data)))
        random.shuffle(shuffled_subjects_indices)
        random.shuffle(shuffled_verbs_indices)
        random.shuffle(shuffled_locations_indices)
        encoded_subjects = encoded_subjects[shuffled_subjects_indices,:]
        encoded_verbs = encoded_verbs[shuffled_verbs_indices,:]
        encoded_locations = encoded_locations[shuffled_locations_indices,:]
        extracted_subjects = extracted_subjects[shuffled_subjects_indices,:]
        extracted_verbs = extracted_verbs[shuffled_verbs_indices,:]
        extracted_location = extracted_location[shuffled_locations_indices,:]
        subjects = np.array(subjects)[shuffled_subjects_indices]
        verbs = np.array(verbs)[shuffled_verbs_indices]
        locations = np.array(locations)[shuffled_locations_indices]

        # compose with extracted
        loss_value = model.pair_loss(
            model.svl_combine([extracted_subjects, extracted_verbs, extracted_location]),
            model.text_encoder([s + ' ' + v + ' ' + l for s, v, l in zip(subjects, verbs, locations)]).detach()
        ).cuda()
        loss_name = 'compose2'
        loss_weight = 0.5
        losses.append((loss_name, loss_weight, loss_value))
    if add_extract_compose_losses:
        do_add_extract_compose_losses()

    # total loss
    total_loss = sum([loss_weight * loss_value for loss_name, loss_weight, loss_value in losses])
    assert not torch.isnan(total_loss)
    # Appended after summing, so the total is tracked but not counted twice.
    losses += [('total training loss', None, total_loss)]

    # save losses
    for loss_name, loss_weight, loss_value in losses:
        if loss_name not in losses_tracking:
            losses_tracking[loss_name] = []
        losses_tracking[loss_name].append(float(loss_value.data.item()))
    return total_loss

def train_1_epoch(model, optimizer, trainset, opt, losses_tracking, add_extract_compose_losses = True):
    """Train `model` for one pass over `trainset`.

    The model is put in training mode and batches are drawn shuffled, with the
    last incomplete batch dropped. For every batch the total loss from
    `compute_losses` is back-propagated and one optimizer step is taken.

    Args:
        model: the model to train.
        optimizer: optimizer holding the model's parameters.
        trainset: dataset exposing `get_loader(...)`.
        opt: options; `batch_size` and `loader_num_workers` are read.
        losses_tracking: dict forwarded to `compute_losses` to record losses.
        add_extract_compose_losses: forwarded to `compute_losses`.
    """
    model.train()
    batches = trainset.get_loader(
        batch_size=opt.batch_size,
        shuffle=True,
        drop_last=True,
        num_workers=opt.loader_num_workers,
    )
    for batch in tqdm(batches, desc = 'training 1 epoch'):
        batch_loss = compute_losses(
            model, batch, losses_tracking, add_extract_compose_losses)
        optimizer.zero_grad()
        batch_loss.backward()
        optimizer.step()

In [13]:
# train loop
# Each iteration reports the running losses, evaluates on both datasets, and
# then, unless opt.num_epochs has been reached, trains for one more epoch.
# Every print below passes a single pre-formatted string, so the output is
# the same whether print is a statement (Python 2) or a function (Python 3).
losses_tracking = {}
epoch = 0
tic = time.time()
while True:

    # show stat, training losses
    print('Epoch %s Elapsed time %s' % (epoch, round(time.time() - tic, 4)))
    tic = time.time()
    for loss_name in losses_tracking:
        # average over the most recent 250 recorded values of this loss
        avg_loss = np.mean(losses_tracking[loss_name][-250:])
        print('    %s %s' % (loss_name, round(avg_loss, 4)))
        logger.add_scalar(loss_name, avg_loss, epoch)

    # test
    tests = []
    for dataset in [trainset, sic112]:
        t = test(model, dataset, opt)
        tests += [(dataset.name() + ' ' + metric_name, metric_value) for metric_name, metric_value in t]
    for metric_name, metric_value in tests:
        print('  %s %s' % (metric_name, round(metric_value, 4)))
        logger.add_scalar(metric_name, metric_value, epoch)

    # train
    if epoch >= opt.num_epochs:
        break
    # The first epoch (epoch 0) trains the joint-embedding loss only; the
    # extract/compose losses are switched on from epoch 1 onwards.
    train_1_epoch(model, optimizer, trainset, opt, losses_tracking,
                  add_extract_compose_losses = epoch>=1)
    epoch += 1

    # learning rate scheduling: divide every group's lr by 10 every
    # opt.learning_rate_decay_frequency epochs
    if epoch % opt.learning_rate_decay_frequency == 0:
        for g in optimizer.param_groups:
            g['lr'] *= 0.1

Epoch 0 Elapsed time 0.0008


training 1 epoch:   0%|          | 0/2586 [00:00<?, ?it/s]

  CocoCapTrain text2image_recall_top1 0.0005
  SimpleImageCaptions112 text2image_recall_top1 0.0
  SimpleImageCaptions112 svl_ttt_recall_top1 0.0123
  SimpleImageCaptions112 svl_ttt_recall_top5 0.0467
  SimpleImageCaptions112 svl_ttt_recall_top10 0.0712
  SimpleImageCaptions112 svl_tti_recall_top1 0.0123
  SimpleImageCaptions112 svl_tti_recall_top5 0.0459
  SimpleImageCaptions112 svl_tti_recall_top10 0.0816
  SimpleImageCaptions112 svl_tit_recall_top1 0.0123
  SimpleImageCaptions112 svl_tit_recall_top5 0.0341
  SimpleImageCaptions112 svl_tit_recall_top10 0.0712
  SimpleImageCaptions112 svl_tii_recall_top1 0.0123
  SimpleImageCaptions112 svl_tii_recall_top5 0.037
  SimpleImageCaptions112 svl_tii_recall_top10 0.0713
  SimpleImageCaptions112 svl_itt_recall_top1 0.0123
  SimpleImageCaptions112 svl_itt_recall_top5 0.0344
  SimpleImageCaptions112 svl_itt_recall_top10 0.0816
  SimpleImageCaptions112 svl_iti_recall_top1 0.0123
  SimpleImageCaptions112 svl_iti_recall_top5 0.0467
  SimpleImageCa

training 1 epoch: 100%|██████████| 2586/2586 [01:16<00:00, 35.70it/s]


Epoch 1 Elapsed time 113.1016
    total training loss 0.7649
    joint_embedding 0.7649


training 1 epoch:   0%|          | 0/2586 [00:00<?, ?it/s]

  CocoCapTrain text2image_recall_top1 0.1326
  SimpleImageCaptions112 text2image_recall_top1 0.2045
  SimpleImageCaptions112 svl_ttt_recall_top1 0.0
  SimpleImageCaptions112 svl_ttt_recall_top5 0.0123
  SimpleImageCaptions112 svl_ttt_recall_top10 0.0123
  SimpleImageCaptions112 svl_tti_recall_top1 0.0
  SimpleImageCaptions112 svl_tti_recall_top5 0.0021
  SimpleImageCaptions112 svl_tti_recall_top10 0.0246
  SimpleImageCaptions112 svl_tit_recall_top1 0.0
  SimpleImageCaptions112 svl_tit_recall_top5 0.0
  SimpleImageCaptions112 svl_tit_recall_top10 0.0246
  SimpleImageCaptions112 svl_tii_recall_top1 0.0
  SimpleImageCaptions112 svl_tii_recall_top5 0.0123
  SimpleImageCaptions112 svl_tii_recall_top10 0.0237
  SimpleImageCaptions112 svl_itt_recall_top1 0.0
  SimpleImageCaptions112 svl_itt_recall_top5 0.0123
  SimpleImageCaptions112 svl_itt_recall_top10 0.0269
  SimpleImageCaptions112 svl_iti_recall_top1 0.0
  SimpleImageCaptions112 svl_iti_recall_top5 0.0246
  SimpleImageCaptions112 svl_iti

training 1 epoch: 100%|██████████| 2586/2586 [02:55<00:00, 14.99it/s]


Epoch 2 Elapsed time 211.9801
    total training loss 3.2935
    extract 2.7807
    joint_embedding 0.6419
    compose2 1.4407
    compose1 0.4638


training 1 epoch:   0%|          | 0/2586 [00:00<?, ?it/s]

  CocoCapTrain text2image_recall_top1 0.1551
  SimpleImageCaptions112 text2image_recall_top1 0.1893
  SimpleImageCaptions112 svl_ttt_recall_top1 0.2192
  SimpleImageCaptions112 svl_ttt_recall_top5 0.5224
  SimpleImageCaptions112 svl_ttt_recall_top10 0.7772
  SimpleImageCaptions112 svl_tti_recall_top1 0.1684
  SimpleImageCaptions112 svl_tti_recall_top5 0.5192
  SimpleImageCaptions112 svl_tti_recall_top10 0.6916
  SimpleImageCaptions112 svl_tit_recall_top1 0.1817
  SimpleImageCaptions112 svl_tit_recall_top5 0.4927
  SimpleImageCaptions112 svl_tit_recall_top10 0.6748
  SimpleImageCaptions112 svl_tii_recall_top1 0.1324
  SimpleImageCaptions112 svl_tii_recall_top5 0.3876
  SimpleImageCaptions112 svl_tii_recall_top10 0.5439
  SimpleImageCaptions112 svl_itt_recall_top1 0.1714
  SimpleImageCaptions112 svl_itt_recall_top5 0.5176
  SimpleImageCaptions112 svl_itt_recall_top10 0.7014
  SimpleImageCaptions112 svl_iti_recall_top1 0.1338
  SimpleImageCaptions112 svl_iti_recall_top5 0.4277
  SimpleIma

training 1 epoch: 100%|██████████| 2586/2586 [02:56<00:00, 14.69it/s]


Epoch 3 Elapsed time 211.3612
    total training loss 2.9445
    extract 2.6656
    joint_embedding 0.5478
    compose2 1.3172
    compose1 0.4592


training 1 epoch:   0%|          | 0/2586 [00:00<?, ?it/s]

  CocoCapTrain text2image_recall_top1 0.1959
  SimpleImageCaptions112 text2image_recall_top1 0.208
  SimpleImageCaptions112 svl_ttt_recall_top1 0.2044
  SimpleImageCaptions112 svl_ttt_recall_top5 0.5727
  SimpleImageCaptions112 svl_ttt_recall_top10 0.763
  SimpleImageCaptions112 svl_tti_recall_top1 0.1686
  SimpleImageCaptions112 svl_tti_recall_top5 0.5193
  SimpleImageCaptions112 svl_tti_recall_top10 0.675
  SimpleImageCaptions112 svl_tit_recall_top1 0.191
  SimpleImageCaptions112 svl_tit_recall_top5 0.4956
  SimpleImageCaptions112 svl_tit_recall_top10 0.662
  SimpleImageCaptions112 svl_tii_recall_top1 0.1389
  SimpleImageCaptions112 svl_tii_recall_top5 0.4076
  SimpleImageCaptions112 svl_tii_recall_top10 0.5616
  SimpleImageCaptions112 svl_itt_recall_top1 0.1899
  SimpleImageCaptions112 svl_itt_recall_top5 0.5429
  SimpleImageCaptions112 svl_itt_recall_top10 0.7131
  SimpleImageCaptions112 svl_iti_recall_top1 0.1585
  SimpleImageCaptions112 svl_iti_recall_top5 0.4419
  SimpleImageCap

training 1 epoch: 100%|██████████| 2586/2586 [02:56<00:00, 14.65it/s]


Epoch 4 Elapsed time 211.9835
    total training loss 2.9333
    extract 2.6095
    joint_embedding 0.5068
    compose2 1.2512
    compose1 0.4309


training 1 epoch:   0%|          | 0/2586 [00:00<?, ?it/s]

  CocoCapTrain text2image_recall_top1 0.2099
  SimpleImageCaptions112 text2image_recall_top1 0.217
  SimpleImageCaptions112 svl_ttt_recall_top1 0.143
  SimpleImageCaptions112 svl_ttt_recall_top5 0.5077
  SimpleImageCaptions112 svl_ttt_recall_top10 0.6765
  SimpleImageCaptions112 svl_tti_recall_top1 0.1812
  SimpleImageCaptions112 svl_tti_recall_top5 0.5223
  SimpleImageCaptions112 svl_tti_recall_top10 0.6786
  SimpleImageCaptions112 svl_tit_recall_top1 0.1786
  SimpleImageCaptions112 svl_tit_recall_top5 0.5077
  SimpleImageCaptions112 svl_tit_recall_top10 0.6953
  SimpleImageCaptions112 svl_tii_recall_top1 0.1655
  SimpleImageCaptions112 svl_tii_recall_top5 0.4485
  SimpleImageCaptions112 svl_tii_recall_top10 0.6021
  SimpleImageCaptions112 svl_itt_recall_top1 0.1824
  SimpleImageCaptions112 svl_itt_recall_top5 0.5281
  SimpleImageCaptions112 svl_itt_recall_top10 0.7041
  SimpleImageCaptions112 svl_iti_recall_top1 0.1416
  SimpleImageCaptions112 svl_iti_recall_top5 0.4291
  SimpleImage

training 1 epoch: 100%|██████████| 2586/2586 [02:56<00:00, 15.41it/s]


Epoch 5 Elapsed time 211.8101
    total training loss 2.6842
    extract 2.5427
    joint_embedding 0.4746
    compose2 1.1771
    compose1 0.381


training 1 epoch:   0%|          | 0/2586 [00:00<?, ?it/s]

  CocoCapTrain text2image_recall_top1 0.2277
  SimpleImageCaptions112 text2image_recall_top1 0.2259
  SimpleImageCaptions112 svl_ttt_recall_top1 0.2056
  SimpleImageCaptions112 svl_ttt_recall_top5 0.655
  SimpleImageCaptions112 svl_ttt_recall_top10 0.7993
  SimpleImageCaptions112 svl_tti_recall_top1 0.1921
  SimpleImageCaptions112 svl_tti_recall_top5 0.5434
  SimpleImageCaptions112 svl_tti_recall_top10 0.7087
  SimpleImageCaptions112 svl_tit_recall_top1 0.2101
  SimpleImageCaptions112 svl_tit_recall_top5 0.524
  SimpleImageCaptions112 svl_tit_recall_top10 0.6728
  SimpleImageCaptions112 svl_tii_recall_top1 0.1613
  SimpleImageCaptions112 svl_tii_recall_top5 0.4481
  SimpleImageCaptions112 svl_tii_recall_top10 0.5998
  SimpleImageCaptions112 svl_itt_recall_top1 0.1699
  SimpleImageCaptions112 svl_itt_recall_top5 0.5087
  SimpleImageCaptions112 svl_itt_recall_top10 0.6586
  SimpleImageCaptions112 svl_iti_recall_top1 0.1595
  SimpleImageCaptions112 svl_iti_recall_top5 0.4403
  SimpleImage

training 1 epoch: 100%|██████████| 2586/2586 [02:56<00:00, 14.40it/s]


Epoch 6 Elapsed time 211.9347
    total training loss 2.6923
    extract 2.5058
    joint_embedding 0.4551
    compose2 1.1243
    compose1 0.3732


training 1 epoch:   0%|          | 0/2586 [00:00<?, ?it/s]

  CocoCapTrain text2image_recall_top1 0.2326
  SimpleImageCaptions112 text2image_recall_top1 0.2634
  SimpleImageCaptions112 svl_ttt_recall_top1 0.2597
  SimpleImageCaptions112 svl_ttt_recall_top5 0.6986
  SimpleImageCaptions112 svl_ttt_recall_top10 0.8564
  SimpleImageCaptions112 svl_tti_recall_top1 0.1964
  SimpleImageCaptions112 svl_tti_recall_top5 0.561
  SimpleImageCaptions112 svl_tti_recall_top10 0.7408
  SimpleImageCaptions112 svl_tit_recall_top1 0.1834
  SimpleImageCaptions112 svl_tit_recall_top5 0.5115
  SimpleImageCaptions112 svl_tit_recall_top10 0.6855
  SimpleImageCaptions112 svl_tii_recall_top1 0.1546
  SimpleImageCaptions112 svl_tii_recall_top5 0.4421
  SimpleImageCaptions112 svl_tii_recall_top10 0.6043
  SimpleImageCaptions112 svl_itt_recall_top1 0.2212
  SimpleImageCaptions112 svl_itt_recall_top5 0.5743
  SimpleImageCaptions112 svl_itt_recall_top10 0.732
  SimpleImageCaptions112 svl_iti_recall_top1 0.1811
  SimpleImageCaptions112 svl_iti_recall_top5 0.5133
  SimpleImage

training 1 epoch: 100%|██████████| 2586/2586 [02:56<00:00, 14.62it/s]


Epoch 7 Elapsed time 212.4345
    total training loss 2.6842
    extract 2.4965
    joint_embedding 0.447
    compose2 1.0986
    compose1 0.3732


training 1 epoch:   0%|          | 0/2586 [00:00<?, ?it/s]

  CocoCapTrain text2image_recall_top1 0.2457
  SimpleImageCaptions112 text2image_recall_top1 0.2589
  SimpleImageCaptions112 svl_ttt_recall_top1 0.2676
  SimpleImageCaptions112 svl_ttt_recall_top5 0.6869
  SimpleImageCaptions112 svl_ttt_recall_top10 0.8042
  SimpleImageCaptions112 svl_tti_recall_top1 0.1876
  SimpleImageCaptions112 svl_tti_recall_top5 0.5535
  SimpleImageCaptions112 svl_tti_recall_top10 0.7187
  SimpleImageCaptions112 svl_tit_recall_top1 0.216
  SimpleImageCaptions112 svl_tit_recall_top5 0.5535
  SimpleImageCaptions112 svl_tit_recall_top10 0.715
  SimpleImageCaptions112 svl_tii_recall_top1 0.1406
  SimpleImageCaptions112 svl_tii_recall_top5 0.4194
  SimpleImageCaptions112 svl_tii_recall_top10 0.5915
  SimpleImageCaptions112 svl_itt_recall_top1 0.2371
  SimpleImageCaptions112 svl_itt_recall_top5 0.5918
  SimpleImageCaptions112 svl_itt_recall_top10 0.7651
  SimpleImageCaptions112 svl_iti_recall_top1 0.16
  SimpleImageCaptions112 svl_iti_recall_top5 0.4646
  SimpleImageCa

training 1 epoch:  25%|██▍       | 636/2586 [00:43<02:03, 15.80it/s]

KeyboardInterrupt: 