### Inputs

(**Please modify these paths to match your local setup**)

In [1]:
# Image file to caption.
img_path = "COCO_val2014_000000000772.jpg"
# Vocabulary JSON handed to TextField when the decoding utilities are prepared.
vocab_path = "path_to_annotations/vocab.json"
# Checkpoint file; injected into the Hydra config as `exp.checkpoint`.
checkpoint = "path_to_checkpoint/grit_checkpoint_4ds.pth"

### Initialize a Hydra Config

In [2]:
import sys
sys.path.append("..")

import os
from omegaconf import OmegaConf
from hydra.core.global_hydra import GlobalHydra
from hydra import initialize, initialize_config_module, initialize_config_dir, compose

# Clear any previously initialized global Hydra state, then initialize Hydra
# against the captioning config directory (path is relative to this notebook).
GlobalHydra.instance().clear()
initialize(version_base=None, config_path="../configs/caption")

# Compose the COCO config, overriding `exp.checkpoint` with the path chosen
# in the inputs cell.
checkpoint_override = f"exp.checkpoint={checkpoint}"
config = compose(config_name='coco_config.yaml', overrides=[checkpoint_override])

In [4]:
import torch

# model
from models.common.attention import MemoryAttention
from models.caption.detector import build_detector
from models.caption import Transformer, GridFeatureNetwork, CaptionGenerator

# dataset
from PIL import Image
from datasets.caption.field import TextField
from datasets.caption.transforms import get_transform
from engine.utils import nested_tensor_from_tensor_list

# Everything below (model and input batch) is moved to the first CUDA device.
device = torch.device("cuda:0")

### Build a model

In [5]:
# Shorthand for the model section of the composed config.
model_cfg = config.model

# Detector built from the full config and moved to the target device.
detector = build_detector(config).to(device)

# Grid feature network; extra keyword arguments come from `model.grit_net`.
grit_net = GridFeatureNetwork(
    pad_idx=model_cfg.pad_idx,
    d_in=model_cfg.grid_feat_dim,
    dropout=model_cfg.dropout,
    attn_dropout=model_cfg.attn_dropout,
    attention_module=MemoryAttention,
    **model_cfg.grit_net,
)

# Caption generator; `model.cap_generator` is passed both as `cfg` and
# unpacked as keyword arguments.
cap_generator = CaptionGenerator(
    vocab_size=model_cfg.vocab_size,
    max_len=model_cfg.max_len,
    pad_idx=model_cfg.pad_idx,
    cfg=model_cfg.cap_generator,
    dropout=model_cfg.dropout,
    attn_dropout=model_cfg.attn_dropout,
    **model_cfg.cap_generator,
)

# Assemble the full captioning model and place it on the target device.
model = Transformer(
    grit_net,
    cap_generator,
    detector=detector,
    use_gri_feat=model_cfg.use_gri_feat,
    use_reg_feat=model_cfg.use_reg_feat,
    config=config,
).to(device)

# Load the checkpoint weights into the model.
# The loaded dict is bound to `ckpt_state`, NOT `checkpoint`, so the checkpoint
# *path* defined in the inputs cell is not overwritten; the Hydra cell
# interpolates that path into an override and must stay re-runnable.
# NOTE: torch.load unpickles the file -- only load checkpoints you trust.
if os.path.exists(config.exp.checkpoint):
    ckpt_state = torch.load(config.exp.checkpoint, map_location='cpu')
    # strict=False tolerates key mismatches; the counts below report them.
    missing, unexpected = model.load_state_dict(ckpt_state['state_dict'], strict=False)
    print("model missing:", len(missing))
    print("model unexpected:", len(unexpected))
else:
    # Do not continue silently: without the checkpoint no weights are loaded.
    print(f"WARNING: checkpoint not found at {config.exp.checkpoint!r}; no weights were loaded")

model.cached_features = False

# Switch to inference mode so dropout (and any other train-only behavior)
# is disabled while captioning.
model.eval()

# Preprocessing and decoding helpers.
# get_transform returns a mapping; the 'valid' entry is used for inference.
transform_set = get_transform(config.dataset.transform_cfg)
transform = transform_set['valid']
# Text field built from the vocabulary file chosen in the inputs cell.
text_field = TextField(vocab_path=vocab_path)

  return _VF.meshgrid(tensors, **kwargs)  # type: ignore[attr-defined]


model missing: 0
model unexpected: 0


### Load and Transform An Image

In [6]:
# Open the input file and convert it to RGB.
pil_image = Image.open(img_path)
rgb_image = pil_image.convert('RGB')

# Apply the preprocessing transform, then wrap the single image into a
# nested-tensor batch on the target device.
image = transform(rgb_image)
image_batch = nested_tensor_from_tensor_list([image])
images = image_batch.to(device)
# Tip: evaluate `rgb_image` as the last expression of a cell to preview the input.

### Inference and Decode

In [11]:
# Beam-search decoding options; lengths, EOS index and beam width come from
# the composed config.
beam_search_kwargs = dict(
    seq=None,
    use_beam_search=True,
    max_len=config.model.beam_len,
    eos_idx=config.model.eos_idx,
    beam_size=config.model.beam_size,
    out_size=1,
    return_probs=False,
)

# Run the model without tracking gradients; the second return value is unused.
with torch.no_grad():
    out, _ = model(images, **beam_search_kwargs)
    # Decode the output into strings and keep the first (only) caption.
    captions = text_field.decode(out, join_words=True)
    caption = captions[0]
    print(caption)

three sheep standing in the grass near a fence
