# Prepare proxy ground truth for training data

In [None]:
# Estimate the size of the excavator for the training set

import yaml
from easydict import EasyDict
from pathlib import Path
import json
import numpy as np
import torch
from torch.utils.data import DataLoader
from core.model import PoseEstimator
from core.dataset import build_dataloader, RealDataset
from utils.utils import setup_logger, filtered_mean, serialize_data
from utils.train_utils import load_checkpoint
from utils.eval_utils import estimate_size


def main(cfg_path, data_path, ckpt, out_path):
    """Run size estimation on the 'train' split and save the result as JSON.

    Loads the YAML config, builds a ``RealDataset`` for the ``'train'`` split
    without labels, restores ``PoseEstimator`` weights from ``ckpt`` and calls
    ``estimate_size``. The first value returned by ``estimate_size`` is
    serialized to ``<out_path>/estimated_size_train.json``.

    Args:
        cfg_path: Path to the YAML config file. The config must provide
            ``model`` and ``model_name`` entries.
        data_path: Dataset root passed to ``RealDataset`` as ``root_path``.
        ckpt: Checkpoint location passed to ``load_checkpoint``; the loaded
            object must contain a ``'state_dict'`` entry.
        out_path: Directory that receives ``estimated_size_train.json``
            (``str`` or ``pathlib.Path``). It is created if missing.

    Returns:
        The second value returned by ``estimate_size`` (``all_dict``).
    """
    # yaml.safe_load() accepts only the stream. The former call passed
    # ``Loader=`` as well, which always raised TypeError and was hidden by a
    # bare ``except`` before retrying with this exact call.
    with open(cfg_path, 'r') as f:
        config = EasyDict(yaml.safe_load(f))

    logger = setup_logger()
    dataset = RealDataset('train', root_path=data_path, need_label=False, logger=logger)
    # shuffle=False keeps the iteration order deterministic.
    dataloader = DataLoader(dataset, batch_size=32, shuffle=False, num_workers=2, collate_fn=dataset.collate_batch)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = PoseEstimator(in_channels=3, cfg=config.model, device=device).to(device)

    checkpoint = load_checkpoint(ckpt)
    model.load_state_dict(checkpoint['state_dict'])

    # NOTE(review): the model is never switched to eval mode here; presumably
    # estimate_size() takes care of eval()/no_grad() -- confirm.
    logger.info(f'Start running model: {config.model_name}')
    size_dict, all_dict = estimate_size(model, dataloader, device, logger)

    # Make sure the target directory exists so the finished inference run is
    # not lost to a missing folder.
    output_dir = Path(out_path)
    output_dir.mkdir(parents=True, exist_ok=True)
    with open(output_dir / 'estimated_size_train.json', 'w') as f:
        json.dump(serialize_data(size_dict), f)

    logger.info('Save finished')
    return all_dict


# Model config and checkpoint file used for inference.
cfg_path = 'cfg/model.yaml'
ckpt = 'output/base/ckpt/best_model.pth.tar'

# The size estimates are written next to the data, in the dataset root.
data_path = Path('data/realsite')
out_path = data_path

all_dict = main(
    cfg_path=cfg_path,
    data_path=data_path,
    ckpt=ckpt,
    out_path=out_path,
)


In [None]:
# Estimate the labels (rotation, segmentation) for the training set

import yaml
from easydict import EasyDict
from pathlib import Path
import json
import os
import torch
from torch.utils.data import DataLoader
from core.model import PoseEstimator
from core.dataset import RealDataset
from utils.utils import setup_logger, serialize_data
from utils.train_utils import load_checkpoint
from utils.eval_utils import estimate_labels


def estiamte_and_save_labels(cfg_path, data_path, ckpt, out_path):
    """Run label estimation on the 'train' split and save one JSON per token.

    Loads the YAML config, builds a ``RealDataset`` for the ``'train'`` split
    without labels, restores ``PoseEstimator`` weights from ``ckpt`` and calls
    ``estimate_labels``. Every ``(token, label)`` pair of the returned mapping
    is serialized to ``<out_path>/<token>.json``.

    Note: the misspelling in the function name is kept on purpose so that
    existing callers continue to work.

    Args:
        cfg_path: Path to the YAML config file. The config must provide
            ``model`` and ``model_name`` entries.
        data_path: Dataset root passed to ``RealDataset`` as ``root_path``.
        ckpt: Checkpoint location passed to ``load_checkpoint``; the loaded
            object must contain a ``'state_dict'`` entry.
        out_path: Directory that receives the per-token JSON files
            (``str`` or ``pathlib.Path``). It is created if missing.

    Returns:
        The mapping returned by ``estimate_labels`` (token -> label).
    """
    # yaml.safe_load() accepts only the stream. The former call passed
    # ``Loader=`` as well, which always raised TypeError and was hidden by a
    # bare ``except`` before retrying with this exact call.
    with open(cfg_path, 'r') as f:
        config = EasyDict(yaml.safe_load(f))

    # mkdir(exist_ok=True) already tolerates an existing directory, so no
    # separate exists() check is needed.
    output_dir = Path(out_path)
    output_dir.mkdir(parents=True, exist_ok=True)

    logger = setup_logger()
    dataset = RealDataset('train', root_path=data_path, need_label=False, logger=logger)
    # shuffle=False keeps the iteration order deterministic.
    dataloader = DataLoader(dataset, batch_size=32, shuffle=False, num_workers=2, collate_fn=dataset.collate_batch)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = PoseEstimator(in_channels=3, cfg=config.model, device=device).to(device)

    checkpoint = load_checkpoint(ckpt)
    model.load_state_dict(checkpoint['state_dict'])

    # NOTE(review): the model is never switched to eval mode here; presumably
    # estimate_labels() takes care of eval()/no_grad() -- confirm.
    logger.info(f'Start running model: {config.model_name}')

    label_dict = estimate_labels(model, dataloader, device, logger)
    for token, label in label_dict.items():
        with open(output_dir / f'{token}.json', 'w') as f:
            json.dump(serialize_data(label), f)

    logger.info('Save finished')
    return label_dict


# Model config and checkpoint file used for inference.
cfg_path = 'cfg/model.yaml'
ckpt = 'output/base/ckpt/best_model.pth.tar'

# Per-token label files go into a 'train_labels' folder under the dataset root.
data_path = Path('data/realsite')
out_path = data_path / 'train_labels'

label_dict = estiamte_and_save_labels(
    cfg_path=cfg_path,
    data_path=data_path,
    ckpt=ckpt,
    out_path=out_path,
)