In [None]:
import torch, detectron2
!nvcc --version
# Report the library versions the kernel is actually running against.
# Keep only the "major.minor" part of the torch version string.
TORCH_VERSION = ".".join(torch.__version__.split(".")[:2])
# The text after "+" is the build tag (e.g. "cu118" or "cpu"). Builds without a
# tag would otherwise echo the full torch version here, so fall back to the
# CUDA version torch was compiled with (None on CPU-only builds).
CUDA_VERSION = (
    torch.__version__.split("+")[-1]
    if "+" in torch.__version__
    else torch.version.cuda
)
print("torch: ", TORCH_VERSION, "; cuda: ", CUDA_VERSION)
print("detectron2:", detectron2.__version__)

In [None]:
import os
import sys
import logging
import argparse
# Set before the project imports below.
# NOTE(review): presumably read by the dataset-loading code - confirm.
os.environ["DATASET"] = "../datasets"

# Put the parent of the first sys.path entry at the front of the import path so
# the project packages imported at the bottom of this cell can be resolved.
# os.path.dirname replaces the manual '/' split/join, which returned '' for a
# root-level entry and for paths that use a different separator.
pth = os.path.dirname(sys.path[0])
sys.path.insert(0, pth)

from pprint import pprint
import numpy as np
# Fixed seed for NumPy's global random generator.
np.random.seed(1)

# Parent of the current working directory; also appended to the import path and
# reused later to locate config files.
home_dir = os.path.abspath(os.path.join(os.getcwd(), os.pardir))
sys.path.append(home_dir)
print(home_dir)

import warnings
# Silences every warning category for the rest of the session.
warnings.filterwarnings(action='ignore')
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Project-local imports: these must stay below the sys.path edits above.
from hdecoder.BaseModel import BaseModel
from hdecoder import build_model
from utils.distributed import init_distributed
from utils.arguments import load_opt_from_config_files, load_config_dict_to_opt

In [None]:
# Declare the command-line interface, then parse an empty argument list: the
# notebook fills in the namespace by hand instead of reading sys.argv.
parser = argparse.ArgumentParser(
    description='Pretrain or fine-tune models for NLP tasks.',
)
parser.add_argument(
    '--command',
    default="evaluate",
    help='Command: train/evaluate/train-and-evaluate',
)
parser.add_argument(
    '--conf_files',
    nargs='+',
    help='Path(s) to the config file(s).',
)
parser.add_argument(
    '--user_dir',
    help='Path to the user defined module for tasks (models, criteria), optimizers, and lr schedulers.',
)
parser.add_argument(
    '--config_overrides',
    nargs='*',
    help='Override parameters on config with a json style string, e.g. {"<PARAM_NAME_1>": <PARAM_VALUE_1>, "<PARAM_GROUP_2>.<PARAM_SUBGROUP_2>.<PARAM_2>": <PARAM_VALUE_2>}. A key with "." updates the object in the corresponding nested dict. Remember to escape " in command line.',
)
parser.add_argument(
    '--overrides',
    help='arguments that used to override the config file in cmdline',
    nargs=argparse.REMAINDER,
)

cmdline_args = parser.parse_args([])
# Alternative config kept for reference:
# cmdline_args.conf_files = [os.path.join(home_dir, "configs/xdecoder/svlp_focalt_lang.yaml")]
cmdline_args.conf_files = [os.path.join(home_dir, "configs/hdecoder/vcoco.yaml")]
# Flat list of alternating KEY, VALUE entries.
cmdline_args.overrides = ['WEIGHT', 'true', 'RESUME_FROM', '../checkpoints/xdecoder_focalt_best_openseg.pt']

opt = load_opt_from_config_files(cmdline_args.conf_files)
opt["base_path"] = "../"

# Split the flat override list into parallel key / value lists; a trailing key
# without a value is dropped.
num_pairs = len(cmdline_args.overrides) // 2
keys = cmdline_args.overrides[0:2 * num_pairs:2]
vals = cmdline_args.overrides[1:2 * num_pairs:2]
print(keys, vals)
# For values that are five characters long once spaces are removed, strip the
# text 'false' / 'False', so that a bool() cast of the remainder is False.
vals = [
    val.replace('false', '').replace('False', '')
    if len(val.replace(' ', '')) == 5
    else val
    for val in vals
]
# Look up the type of each overridden entry by walking its dotted key through
# the loaded config.
types = []
for dotted_key in keys:
    node = opt
    for part in dotted_key.split('.'):
        node = node[part]
    types.append(type(node))

# Cast every raw string with the type of the value it replaces.
config_dict = {name: caster(raw) for name, raw, caster in zip(keys, vals, types)}

load_config_dict_to_opt(opt, config_dict)
# Copy every parsed argument that has a value into the config.
for key, val in vars(cmdline_args).items():
    if val is None:
        continue
    opt[key] = val
# init_distributed returns the config object used from here on.
opt = init_distributed(opt)

In [None]:
# Switch on the FP16 flag in the config.
# NOTE(review): presumably enables half-precision evaluation - confirm in the trainer.
opt["FP16"] = True

# Checkpoint path taken from the RESUME_FROM override; os.path.join with a
# single argument leaves the path unchanged.
pretrained_pth = os.path.join(opt['RESUME_FROM'])
print(pretrained_pth)

In [None]:
# Build the network from the config, wrap it in BaseModel, load the checkpoint
# at pretrained_pth, switch to eval mode and move it to the GPU.
model = (
    BaseModel(opt, build_model(opt))
    .from_pretrained(pretrained_pth)
    .eval()
    .cuda()
)

In [None]:
from datasets import build_evaluator, build_eval_dataloader
# Build the evaluation dataloaders from the config and remember which test
# datasets it lists.
dataloaders = build_eval_dataloader(opt)
dataset_names = opt["DATASETS"]["TEST"]

# Make sure both entries exist in the config: an empty dict of scheduler
# parameters and an integer accumulation step (1 when unset).
opt['LR_SCHEDULER_PARAMS'] = opt.get('LR_SCHEDULER_PARAMS', {})
opt['GRADIENT_ACCUMULATE_STEP'] = int(opt.get('GRADIENT_ACCUMULATE_STEP', 1))

In [None]:
def get_dataloaders(
    dataset_label: str,
    is_evaluation: bool
):
    """Return the first evaluation dataloader and record its step count.

    Reads the notebook-level ``opt`` config and calls
    ``build_eval_dataloader(opt)`` on every invocation. As a side effect,
    stores ``len(dataloader) // opt['GRADIENT_ACCUMULATE_STEP']`` in
    ``opt["LR_SCHEDULER_PARAMS"]["steps_update_per_epoch"]``.

    Args:
        dataset_label: Currently unused; the first dataloader is always
            returned regardless of the label.
        is_evaluation: Must be truthy; only the evaluation path exists.

    Returns:
        The first element of the sequence returned by
        ``build_eval_dataloader``.

    Raises:
        NotImplementedError: If ``is_evaluation`` is falsy. Previously this
            path fell through to the ``return`` with ``dataloader`` unbound
            and failed with an ``UnboundLocalError``.
    """
    if not is_evaluation:
        raise NotImplementedError(
            "get_dataloaders only supports is_evaluation=True in this notebook"
        )

    eval_loaders = build_eval_dataloader(opt)
    dataloader = eval_loaders[0]

    # temp solution for lr scheduler
    steps_total = len(dataloader)
    steps_acc = opt['GRADIENT_ACCUMULATE_STEP']
    steps_update = steps_total // steps_acc
    opt["LR_SCHEDULER_PARAMS"]["steps_update_per_epoch"] = steps_update
    return dataloader

In [None]:
from trainer import HDecoder_Trainer as Trainer
# When a user module directory was supplied, use its absolute path as the
# config's base path.
if cmdline_args.user_dir:
    absolute_user_dir = opt['base_path'] = os.path.abspath(cmdline_args.user_dir)

# Build the trainer from the final config and run its evaluation entry point.
trainer = Trainer(opt)
trainer.eval()