In [26]:
import torch
from torch.utils.data import DataLoader
from torchvision import transforms

import numpy as np
import time
import sys
import os
import math
import tqdm

from nltk.tokenize import RegexpTokenizer
from transformers import BertTokenizer, AutoTokenizer
from PIL import Image
import argparse

from catr.models import caption
from catr.models import utils as mtils
from catr.datasets import coco, utils
from catr.cfg_damsm_bert import Config
# from catr.configuration import Config

import json, pickle
from pycocotools.coco import COCO as CC
import matplotlib.pyplot as plt


from model import CNN_ENCODER
from miscc.config import cfg, cfg_from_file
from datasets import prepare_data

In [27]:
cfg_from_file('cfg/coco_multimodal.yml')
config = Config() # initialize catr config here
tokenizer = BertTokenizer.from_pretrained(config.vocab, do_lower=True)
retokenizer = BertTokenizer.from_pretrained("catr/damsm_vocab.txt", do_lower=True)
frozen_list_image_encoder = ['Conv2d_1a_3x3','Conv2d_2a_3x3','Conv2d_2b_3x3','Conv2d_3b_1x1','Conv2d_4a_3x3']

Calling BertTokenizer.from_pretrained() with the path to a single file or url is deprecated
Special tokens have been added in the vocabulary, make sure the associated word embedding are fine-tuned or trained.


In [3]:
# image_encoder initialization in trainer_s3
image_encoder = CNN_ENCODER(cfg.TEXT.EMBEDDING_DIM)
img_encoder_path = cfg.TRAIN.NET_E.replace('text_encoder', 'image_encoder')
state_dict = torch.load(img_encoder_path, map_location=lambda storage, loc: storage)
image_encoder.load_state_dict(state_dict)
for p in image_encoder.parameters(): # make image encoder grad on
    p.requires_grad = True
for k,v in image_encoder.named_children(): # freeze the layer1-5 (set eval for BNlayer)
    if k in frozen_list_image_encoder:
        v.train(False)
        v.requires_grad_(False)
print('Load image encoder from:', img_encoder_path)



Load pretrained model from  https://download.pytorch.org/models/inception_v3_google-1a9a5a14.pth
Load image encoder from: ../DAMSMencoders/coco/image_encoder240.pth


In [4]:
cnnd = torch.load('../DAMSMencoders/coco/image_encoder240.pth')

In [5]:
cnnd.keys()

odict_keys(['Conv2d_1a_3x3.conv.weight', 'Conv2d_1a_3x3.bn.weight', 'Conv2d_1a_3x3.bn.bias', 'Conv2d_1a_3x3.bn.running_mean', 'Conv2d_1a_3x3.bn.running_var', 'Conv2d_1a_3x3.bn.num_batches_tracked', 'Conv2d_2a_3x3.conv.weight', 'Conv2d_2a_3x3.bn.weight', 'Conv2d_2a_3x3.bn.bias', 'Conv2d_2a_3x3.bn.running_mean', 'Conv2d_2a_3x3.bn.running_var', 'Conv2d_2a_3x3.bn.num_batches_tracked', 'Conv2d_2b_3x3.conv.weight', 'Conv2d_2b_3x3.bn.weight', 'Conv2d_2b_3x3.bn.bias', 'Conv2d_2b_3x3.bn.running_mean', 'Conv2d_2b_3x3.bn.running_var', 'Conv2d_2b_3x3.bn.num_batches_tracked', 'Conv2d_3b_1x1.conv.weight', 'Conv2d_3b_1x1.bn.weight', 'Conv2d_3b_1x1.bn.bias', 'Conv2d_3b_1x1.bn.running_mean', 'Conv2d_3b_1x1.bn.running_var', 'Conv2d_3b_1x1.bn.num_batches_tracked', 'Conv2d_4a_3x3.conv.weight', 'Conv2d_4a_3x3.bn.weight', 'Conv2d_4a_3x3.bn.bias', 'Conv2d_4a_3x3.bn.running_mean', 'Conv2d_4a_3x3.bn.running_var', 'Conv2d_4a_3x3.bn.num_batches_tracked', 'Mixed_5b.branch1x1.conv.weight', 'Mixed_5b.branch1x1.bn.w

In [18]:
print(image_encoder.Mixed_6e.branch_pool.bn.running_mean[:20]) # initialized params 
print(cnnd['Mixed_6e.branch_pool.bn.running_mean'][:20])

tensor([-0.4605, -0.5155, -0.3081, -0.5765, -0.2425, -0.5707, -0.4499, -0.4905,
        -0.4998, -0.5094, -0.5335, -0.4113, -0.4537, -0.4594, -0.3256, -0.4197,
        -0.6261, -0.4322, -0.5136, -0.4649])
tensor([-0.4605, -0.5155, -0.3081, -0.5765, -0.2425, -0.5707, -0.4499, -0.4905,
        -0.4998, -0.5094, -0.5335, -0.4113, -0.4537, -0.4594, -0.3256, -0.4197,
        -0.6261, -0.4322, -0.5136, -0.4649], device='cuda:0')


In [14]:
model, _ = caption.build_model_v2(config) # build the full caption model like CATR
# model.to(device)
print("Loading Checkpoint...")
checkpoint = torch.load('catr/checkpoints/catr_damsm256_proj_coco2014_ep02.pth', map_location='cpu')
model.load_state_dict(checkpoint['model'])

In [8]:
checkpoint['model'].keys()

odict_keys(['backbone.0.cnn_enc.Conv2d_1a_3x3.conv.weight', 'backbone.0.cnn_enc.Conv2d_1a_3x3.bn.weight', 'backbone.0.cnn_enc.Conv2d_1a_3x3.bn.bias', 'backbone.0.cnn_enc.Conv2d_1a_3x3.bn.running_mean', 'backbone.0.cnn_enc.Conv2d_1a_3x3.bn.running_var', 'backbone.0.cnn_enc.Conv2d_1a_3x3.bn.num_batches_tracked', 'backbone.0.cnn_enc.Conv2d_2a_3x3.conv.weight', 'backbone.0.cnn_enc.Conv2d_2a_3x3.bn.weight', 'backbone.0.cnn_enc.Conv2d_2a_3x3.bn.bias', 'backbone.0.cnn_enc.Conv2d_2a_3x3.bn.running_mean', 'backbone.0.cnn_enc.Conv2d_2a_3x3.bn.running_var', 'backbone.0.cnn_enc.Conv2d_2a_3x3.bn.num_batches_tracked', 'backbone.0.cnn_enc.Conv2d_2b_3x3.conv.weight', 'backbone.0.cnn_enc.Conv2d_2b_3x3.bn.weight', 'backbone.0.cnn_enc.Conv2d_2b_3x3.bn.bias', 'backbone.0.cnn_enc.Conv2d_2b_3x3.bn.running_mean', 'backbone.0.cnn_enc.Conv2d_2b_3x3.bn.running_var', 'backbone.0.cnn_enc.Conv2d_2b_3x3.bn.num_batches_tracked', 'backbone.0.cnn_enc.Conv2d_3b_1x1.conv.weight', 'backbone.0.cnn_enc.Conv2d_3b_1x1.bn.wei

In [19]:
print(checkpoint['model']['backbone.0.cnn_enc.Mixed_6e.branch_pool.bn.running_mean'][:20])
print(model.backbone[0].cnn_enc.Mixed_6e.branch_pool.bn.running_mean[:20])

tensor([ 0.0749, -0.0109,  0.0081, -0.0231, -0.0089,  0.0211, -0.0219, -0.0052,
        -0.0149,  0.0397,  0.0272, -0.0468,  0.0276,  0.0150,  0.0170, -0.0284,
        -0.0249, -0.0120,  0.0025,  0.0248])
tensor([ 0.0749, -0.0109,  0.0081, -0.0231, -0.0089,  0.0211, -0.0219, -0.0052,
        -0.0149,  0.0397,  0.0272, -0.0468,  0.0276,  0.0150,  0.0170, -0.0284,
        -0.0249, -0.0120,  0.0025,  0.0248], grad_fn=<SelectBackward>)


### Load image encoder checkpoint in CATR model successfully $\uparrow$

In [21]:
# initialize catr model in trainer_s3
cap_model = caption.build_model_v3(config)
print("Initializing from Checkpoint...")
base_line_path = 'catr/checkpoints/catr_damsm256_proj_coco2014_ep02.pth'
print('Load C from: {0}'.format(base_line_path))
checkv3 = torch.load(base_line_path, map_location='cpu')
cap_model.load_state_dict(checkv3['model'], strict=False)
print(cap_model.state_dict().keys())

Initializing from Checkpoint...
Load C from: catr/checkpoints/catr_damsm256_proj_coco2014_ep02.pth
odict_keys(['input_proj_v2.weight', 'input_proj_v2.bias', 'transformer.encoder.layers.0.self_attn.in_proj_weight', 'transformer.encoder.layers.0.self_attn.in_proj_bias', 'transformer.encoder.layers.0.self_attn.out_proj.weight', 'transformer.encoder.layers.0.self_attn.out_proj.bias', 'transformer.encoder.layers.0.linear1.weight', 'transformer.encoder.layers.0.linear1.bias', 'transformer.encoder.layers.0.linear2.weight', 'transformer.encoder.layers.0.linear2.bias', 'transformer.encoder.layers.0.norm1.weight', 'transformer.encoder.layers.0.norm1.bias', 'transformer.encoder.layers.0.norm2.weight', 'transformer.encoder.layers.0.norm2.bias', 'transformer.encoder.layers.1.self_attn.in_proj_weight', 'transformer.encoder.layers.1.self_attn.in_proj_bias', 'transformer.encoder.layers.1.self_attn.out_proj.weight', 'transformer.encoder.layers.1.self_attn.out_proj.bias', 'transformer.encoder.layers.1.l

In [24]:
# check the transformer part of full CATR model 
print(checkpoint['model']['input_proj_v2.weight'][:20,0,0,0])
print(model.input_proj_v2.weight[:20,0,0,0])
# check the caption model in trainer_v3
print(cap_model.input_proj_v2.weight[:20,0,0,0])

print(checkpoint['model']['mlp.layers.2.weight'][:20,0])
print(model.mlp.layers[2].weight[:20,0])
# check the caption model in trainer_v3
print(cap_model.mlp.layers[2].weight[:20,0])

tensor([ 0.0749, -0.0109,  0.0081, -0.0231, -0.0089,  0.0211, -0.0219, -0.0052,
        -0.0149,  0.0397,  0.0272, -0.0468,  0.0276,  0.0150,  0.0170, -0.0284,
        -0.0249, -0.0120,  0.0025,  0.0248])
tensor([ 0.0749, -0.0109,  0.0081, -0.0231, -0.0089,  0.0211, -0.0219, -0.0052,
        -0.0149,  0.0397,  0.0272, -0.0468,  0.0276,  0.0150,  0.0170, -0.0284,
        -0.0249, -0.0120,  0.0025,  0.0248], grad_fn=<SelectBackward>)
tensor([ 0.0749, -0.0109,  0.0081, -0.0231, -0.0089,  0.0211, -0.0219, -0.0052,
        -0.0149,  0.0397,  0.0272, -0.0468,  0.0276,  0.0150,  0.0170, -0.0284,
        -0.0249, -0.0120,  0.0025,  0.0248], grad_fn=<SelectBackward>)
tensor([-0.6451, -0.1374, -0.2115, -0.1234, -0.1399, -0.1700, -0.1499, -0.1727,
        -0.1674, -0.1452, -0.1300, -0.1602, -0.1683, -0.1396, -0.1389, -0.1459,
        -0.1970, -0.1329, -0.1731, -0.1814])
tensor([-0.6451, -0.1374, -0.2115, -0.1234, -0.1399, -0.1700, -0.1499, -0.1727,
        -0.1674, -0.1452, -0.1300, -0.1602, -0.1

### Load cap_model from checkpoint in trainer_v3 successfully $\uparrow$

In [31]:
checkv3['lr_scheduler']

{'step_size': 2,
 'gamma': 0.1,
 'base_lrs': [0.0001, 0],
 'last_epoch': 3,
 '_step_count': 4,
 '_get_lr_called_within_step': False,
 '_last_lr': [1e-05, 0.0]}

In [36]:
checkv3['optimizer']['param_groups'][0]

{'lr': 1e-05,
 'betas': (0.9, 0.999),
 'eps': 1e-08,
 'weight_decay': 0.0001,
 'amsgrad': False,
 'initial_lr': 0.0001,
 'params': [0,
  1,
  2,
  3,
  4,
  5,
  6,
  7,
  8,
  9,
  10,
  11,
  12,
  13,
  14,
  15,
  16,
  17,
  18,
  19,
  20,
  21,
  22,
  23,
  24,
  25,
  26,
  27,
  28,
  29,
  30,
  31,
  32,
  33,
  34,
  35,
  36,
  37,
  38,
  39,
  40,
  41,
  42,
  43,
  44,
  45,
  46,
  47,
  48,
  49,
  50,
  51,
  52,
  53,
  54,
  55,
  56,
  57,
  58,
  59,
  60,
  61,
  62,
  63,
  64,
  65,
  66,
  67,
  68,
  69,
  70,
  71,
  72,
  73,
  74,
  75,
  76,
  77,
  78,
  79,
  80,
  81,
  82,
  83,
  84,
  85,
  86,
  87,
  88,
  89,
  90,
  91,
  92,
  93,
  94,
  95,
  96,
  97,
  98,
  99,
  100,
  101,
  102,
  103,
  104,
  105,
  106,
  107,
  108,
  109,
  110,
  111,
  112,
  113,
  114,
  115,
  116,
  117,
  118,
  119,
  120,
  121,
  122,
  123,
  124,
  125,
  126,
  127,
  128,
  129,
  130,
  131,
  132,
  133,
  134,
  135,
  136,
  137,
  138,
  139,


In [28]:
## in CATR, image_encoder is joint with transformer
## when trainer_s3, separate image_encoder and cap_model is initialized
###### use catr dataloader here for testing
dataset_train = coco.build_dataset14(config, mode='training')
dataset_val = coco.build_dataset14(config, mode='validation')
print(f"Train: {len(dataset_train)}")
print(f"Valid: {len(dataset_val)}")

sampler_train = torch.utils.data.RandomSampler(dataset_train)
sampler_val = torch.utils.data.SequentialSampler(dataset_val)

batch_sampler_train = torch.utils.data.BatchSampler(
    sampler_train, config.batch_size, drop_last=True
)

data_loader_train = DataLoader(
    dataset_train, batch_sampler=batch_sampler_train, num_workers=0)
data_loader_val = DataLoader(dataset_val, config.batch_size,
                             sampler=sampler_val, drop_last=False, num_workers=0)

Train: 414113
Valid: 202654


In [29]:
model = model.cuda()
image_encoder = image_encoder.cuda()
cap_model = cap_model.cuda()

In [30]:
@torch.no_grad()
def evaluate(data_loader):
    model.eval()
    image_encoder.eval()
    cap_model.eval()

    for images, masks, caps, cap_masks in data_loader:
        images = images.cuda()
        masks = masks.cuda()
        samples = mtils.NestedTensor(images, masks)
        caps = caps.cuda()
        cap_masks = cap_masks.cuda()
        outputs = model(samples, caps[:, :-1], cap_masks[:, :-1])
        words_features, sent_code = image_encoder(images)
        cap_preds = cap_model(words_features, masks, caps[:, :-1], cap_masks[:, :-1])
        break   # only used for debug, inspect 1st iter
                
    return outputs, cap_preds

In [37]:
outputs, cap_preds = evaluate(data_loader_val)

  "See the documentation of nn.Upsample for details.".format(mode))


In [38]:
print(outputs.shape, cap_preds.shape)

torch.Size([32, 128, 30522]) torch.Size([32, 128, 30522])


In [39]:
print(outputs[0,1,1001:1050])
print(cap_preds[0,1,1001:1050])
### the prediction is exactly the same for the full catr model and the separate model, success ###

tensor([-43.6141, -43.2130, -69.7260, -44.7634, -38.0057, -50.1113, -47.0646,
        -52.8086, -64.2075, -35.6812, -36.8284, -38.3564, -40.8463, -44.9847,
        -42.4597, -38.7589, -37.0208, -37.8280, -41.2827, -41.9383, -42.1363,
        -40.7763, -40.1552, -44.1159, -40.8252, -72.9402, -40.7454, -43.7276,
        -46.0839, -43.5175, -42.3380, -42.1800, -47.0214, -73.6014, -70.8860,
        -46.1343, -34.3810, -38.8425, -38.2300, -42.8508, -40.5383, -43.8769,
        -39.9417, -43.6237, -39.0572, -43.5023, -41.8012, -42.1579, -46.1298],
       device='cuda:0')
tensor([-43.6141, -43.2130, -69.7260, -44.7634, -38.0057, -50.1113, -47.0646,
        -52.8086, -64.2075, -35.6812, -36.8284, -38.3564, -40.8463, -44.9847,
        -42.4597, -38.7589, -37.0208, -37.8280, -41.2827, -41.9383, -42.1363,
        -40.7763, -40.1552, -44.1159, -40.8252, -72.9402, -40.7454, -43.7276,
        -46.0839, -43.5175, -42.3380, -42.1800, -47.0214, -73.6014, -70.8860,
        -46.1343, -34.3810, -38.8425, -

In [42]:
pred_f = torch.argmax(outputs, axis=-1).cpu().detach().numpy()
pred_c = torch.argmax(cap_preds, axis=-1).cpu().detach().numpy()
cp_f = tokenizer.batch_decode(pred_f.tolist(), skip_special_tokens=True)
cp_c = tokenizer.batch_decode(pred_c.tolist(), skip_special_tokens=True)
for i in range(config.batch_size):
    print('%02d: ' % i + cp_f[i], sep='\n')
    print('%02d: ' % i + cp_c[i], sep='\n')

00: a clock with clock a clock on roman numbers of.
00: a clock with clock a clock on roman numbers of.
01: a motorcycle motorcycle motorcycle parked in a of a house.
01: a motorcycle motorcycle motorcycle parked in a of a house.
02: a bathroom with a walls and a white toilet. a.
02: a bathroom with a walls and a white toilet. a.
03: a bicycle parked is to be parked on on a bicycle. on.
03: a bicycle parked is to be parked on on a bicycle. on.
04: a large jet jet flying through the sky.
04: a large jet jet flying through the sky.
05: a is a planeur plane that off from the clear cloudy sky
05: a is a planeur plane that off from the clear cloudy sky
06: a walls white bathroom bathroom with a home home.
06: a walls white bathroom bathroom with a home home.
07: a bathroom a bathroom bathroom white bathroom with a white mirror and toilet toilet sizede. the wall.
07: a bathroom a bathroom bathroom white bathroom with a white mirror and toilet toilet sizede. the wall.
08: a bathroom bathroom 

### The output predictions from CATR and image_encoder+cap_model are the same, success $\uparrow$