### Inputs

(**Please modify these paths to match your environment.**)

In [9]:
# Paths used by the rest of the notebook. Edit the defaults below, or set the
# GRIT_IMG_PATH / GRIT_VOCAB_PATH / GRIT_CHECKPOINT environment variables to
# override them without editing the notebook.
import os

img_path = os.environ.get('GRIT_IMG_PATH', 'COCO_val2014_000000000772.jpg')
vocab_path = os.environ.get('GRIT_VOCAB_PATH', '/home/quang/datasets/coco_caption/annotations/vocab.json')
checkpoint = os.environ.get('GRIT_CHECKPOINT', "/home/quang/checkpoints/ecaptioner/coco/exp34b/grit_checkpoint_4ds.pth")

### Initialize a Hydra Config

In [10]:
import sys
# Make the parent directory importable (presumably the repo root, where the
# `models`, `datasets` and `engine` packages imported below live). Guard the
# append so re-running this cell does not keep adding duplicate entries.
if ".." not in sys.path:
    sys.path.append("..")

import os
from omegaconf import OmegaConf
from hydra.core.global_hydra import GlobalHydra
from hydra import initialize, initialize_config_module, initialize_config_dir, compose

# Hydra keeps process-wide global state; clear it first so this cell can be
# re-run in the same kernel.
GlobalHydra.instance().clear()
initialize(config_path="../configs/caption")
# Override the config's checkpoint path with the one chosen in the Inputs cell.
config = compose(config_name='coco_config.yaml', overrides=[f"exp.checkpoint={checkpoint}"])

In [11]:
import torch

# model
from models.caption.detector import build_detector
from models.caption import Transformer

# dataset
from PIL import Image
from datasets.caption.field import TextField
from datasets.caption.transforms import get_transform
from engine.utils import nested_tensor_from_tensor_list

# Run on the first GPU when one is available; otherwise fall back to CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

### Build a model

In [12]:
# Build the detector and the captioning Transformer that wraps it.
detector = build_detector(config).to(device)
model = Transformer(detector, config=config)
model = model.to(device)

# Load trained weights. The loaded dict is bound to `ckpt`, not `checkpoint`,
# so the `checkpoint` path string from the Inputs cell is not clobbered
# (the Hydra cell interpolates it into an override on re-run).
if os.path.exists(config.exp.checkpoint):
    ckpt = torch.load(config.exp.checkpoint, map_location='cpu')
    # strict=False tolerates key mismatches; report the counts so they are visible.
    missing, unexpected = model.load_state_dict(ckpt['state_dict'], strict=False)
    print(f"model missing: {len(missing)}")
    print(f"model unexpected: {len(unexpected)}")
else:
    print(f"WARNING: checkpoint not found at {config.exp.checkpoint!r}; "
          "the model keeps the weights it was constructed with.")

# Inference only: switch train-mode layers (e.g. dropout) to eval behaviour.
model.eval()
model.cached_features = False  # NOTE(review): presumably "compute features from raw images, not a cache" — confirm

# prepare utils
transform = get_transform(config.dataset.transform_cfg)['valid']
text_field = TextField(vocab_path=vocab_path)

model missing: 0
model unexpected: 0


### Load and Transform an Image

In [13]:
# Read the image and force 3 channels; the context manager closes the file
# handle once the RGB copy has been made.
with Image.open(img_path) as raw_image:
    rgb_image = raw_image.convert('RGB')

# Apply the validation-time transform, then wrap the single image into the
# batched container the model consumes (presumably a padded nested tensor).
image = transform(rgb_image)
images = nested_tensor_from_tensor_list([image]).to(device)
# Tip: evaluate `rgb_image` on its own line to preview the input image.

### Inference and Decode

In [16]:
# Beam-search decoding settings, read from the Hydra model config.
decode_kwargs = dict(
    seq=None,
    use_beam_search=True,
    max_len=config.model.beam_len,
    eos_idx=config.model.eos_idx,
    beam_size=config.model.beam_size,
    out_size=1,
    return_probs=False,
)

# Gradients are not needed for captioning.
with torch.no_grad():
    pred_tokens, _ = model(images, **decode_kwargs)
    # Only one image was batched, so keep the first decoded caption.
    caption = text_field.decode(pred_tokens, join_words=True)[0]
    print(caption)

two sheep standing next to a fence in the grass
