In [7]:
import sys
import os
sys.path.append('./dof')
import torch
from dof.datasets_lmdb import LMDBDataLoaderAugmenter

In [9]:
import os

import numpy as np
import torch
from easydict import EasyDict
from torch.nn import MSELoss


class Config(EasyDict):
    """Training configuration built from an ``args`` namespace.

    Every setting is stored as an attribute on this ``EasyDict``. Besides
    copying values from ``args``, construction has two side effects:

    * the ``models`` and ``log`` directories are created under
      ``<args.workspace>/<args.prefix>`` (best effort: an ``OSError`` is
      printed, not raised);
    * ``args.pose_mean`` and ``args.pose_stddev`` are loaded from disk with
      ``np.load``.

    Args:
        args: object exposing the attributes read below (``prefix``,
            ``workspace``, ``train_source``, ``val_source``, ``pose_mean``,
            ``pose_stddev``, ``depth``, ``lr``, ``lr_plateau``,
            ``early_stop``, ``batch_size``, ``workers``, ``epochs``,
            ``min_size``, ``max_size``, ``pretrained_path``, ``resume_path``,
            the four augmentation flags, the two 3D reference point paths,
            ``distributed`` and ``world_size``). ``args.gpu`` is only read
            when ``args.distributed`` is truthy.
    """

    def __init__(self, args):
        # workspace configuration
        self.prefix = args.prefix
        self.work_path = os.path.join(args.workspace, self.prefix)
        self.model_path = os.path.join(self.work_path, "models")
        self.log_path = os.path.join(self.work_path, "log")

        # Directory creation is best effort: a failure is reported but does
        # not abort configuration, so read-only workflows can still proceed.
        for directory in (self.model_path, self.log_path):
            try:
                self.create_path(directory)
            except OSError as err:
                print(err)

        self.frequency_log = 20

        # training/validation configuration
        self.train_source = args.train_source
        self.val_source = args.val_source

        # network and training parameters
        self.pose_loss = MSELoss(reduction="sum")
        self.pose_mean = np.load(args.pose_mean)
        self.pose_stddev = np.load(args.pose_stddev)
        self.depth = args.depth
        self.lr = args.lr
        self.lr_plateau = args.lr_plateau
        self.early_stop = args.early_stop
        self.batch_size = args.batch_size
        self.workers = args.workers
        self.epochs = args.epochs
        self.min_size = args.min_size
        self.max_size = args.max_size
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        # fixed values, not taken from args
        self.weight_decay = 5e-4
        self.momentum = 0.9
        self.pin_memory = True

        # resume from or load pretrained weights
        self.pretrained_path = args.pretrained_path
        self.resume_path = args.resume_path

        # online augmentation
        self.noise_augmentation = args.noise_augmentation
        self.contrast_augmentation = args.contrast_augmentation
        self.random_flip = args.random_flip
        self.random_crop = args.random_crop

        # 3d reference points to compute pose
        self.threed_5_points = args.threed_5_points
        self.threed_68_points = args.threed_68_points

        # distributed: args.gpu is only consulted in distributed mode, so a
        # single-process args object does not need to define it.
        self.distributed = args.distributed
        self.gpu = args.gpu if args.distributed else 0

        self.num_gpus = args.world_size

    def create_path(self, file_path):
        """Create directory ``file_path``, including missing parents.

        Does nothing if the directory already exists. ``exist_ok=True`` avoids
        the check-then-create race when several processes start at once.

        Raises:
            OSError: if the directory cannot be created, including when
                ``file_path`` exists but is not a directory.
        """
        os.makedirs(file_path, exist_ok=True)


In [40]:
def parse_args():
    """Return the hard-coded run arguments as an ``EasyDict``.

    Stands in for a command-line parser inside the notebook: every value is a
    literal defined here. ``min_size`` is given as a comma-separated string
    and returned as a list of ints; ``max_size`` is returned as an int so both
    size bounds share a numeric type.

    Returns:
        EasyDict: argument namespace with dataset paths, augmentation flags,
        distributed settings and training hyperparameters.
    """
    args = EasyDict()

    # image size bounds; min_size accepts several comma-separated values
    min_size_spec = '640'
    args.min_size = [int(item) for item in min_size_spec.split(",")]
    # NOTE(review): previously the string '1400'; presumably consumers expect
    # a number comparable with pixel dimensions — confirm against the model.
    args.max_size = 1400

    args.epochs = 100
    args.batch_size = 2

    # dataset statistics and locations
    args.pose_mean = './dataset/wider_lmdb/WIDER_train_annotations_pose_mean.npy'
    args.pose_stddev = './dataset/wider_lmdb/WIDER_train_annotations_pose_stddev.npy'
    args.workspace = './workspace'
    args.train_source = './dataset/wider_lmdb/WIDER_train_annotations.lmdb'
    args.val_source = './dataset/wider_lmdb/WIDER_val_annotations.lmdb'
    args.prefix = 'trial_1'

    # online augmentation switches
    args.noise_augmentation = True
    args.contrast_augmentation = True
    args.random_flip = True
    args.random_crop = True

    # single-process run: no `gpu` entry is defined here
    args.world_size = 1
    args.dist_url = "env://"
    args.distributed = False

    # 3d reference points
    args.threed_5_points = './dof/pose_references/reference_3d_5_points_trans.npy'
    args.threed_68_points = './dof/pose_references/reference_3d_68_points_trans.npy'

    # network and optimisation settings
    args.depth = 18
    args.lr = 0.01
    args.lr_plateau = False
    args.early_stop = False
    args.workers = 4

    # False disables loading pretrained weights / resuming from a checkpoint
    args.pretrained_path = False
    args.resume_path = False
    return args


In [41]:
# Build the hard-coded argument namespace (no command line is parsed).
args = parse_args()

In [42]:
# Side effects: creates the models/ and log/ directories under
# <workspace>/<prefix> and loads the pose mean/stddev .npy files from disk.
config = Config(args)

In [43]:
# Loader over the training LMDB; it receives the full config plus the path.
lmdbloader = LMDBDataLoaderAugmenter(config, lmdb_path=config.train_source)

In [44]:
# Sanity check: show the loader's repr to confirm it was constructed.
lmdbloader

<dof.datasets_lmdb.LMDBDataLoaderAugmenter at 0x7f97a03b3550>

In [45]:
# Fetch one (data, target) pair from the loader and stop. `data` and `target`
# stay bound afterwards for inspection; print(1) only marks that one
# iteration ran. If the loader yields nothing, neither name is assigned.
for data, target in lmdbloader:
    print(1)
    break

1


In [48]:
# Shape of the element at index 1 of `data`; the recorded output is a 3-D
# tensor, presumably (channels, height, width) — TODO confirm.
data[1].shape

torch.Size([3, 682, 909])

In [49]:
# Inspect the raw values of the same element.
# NOTE(review): this displays the full tensor repr; a small slice or summary
# statistics (min/max/mean) would keep the notebook output compact.
data[1]

tensor([[[0.7412, 0.7412, 0.7412,  ..., 0.2510, 0.2706, 0.2627],
         [0.7451, 0.7451, 0.7451,  ..., 0.2235, 0.2235, 0.2314],
         [0.7451, 0.7490, 0.7490,  ..., 0.1922, 0.1804, 0.1882],
         ...,
         [0.2196, 0.2196, 0.2196,  ..., 0.5216, 0.5294, 0.5333],
         [0.2157, 0.2157, 0.2157,  ..., 0.5333, 0.5137, 0.5176],
         [0.2039, 0.2078, 0.2078,  ..., 0.5176, 0.5176, 0.5216]],

        [[0.8784, 0.8784, 0.8784,  ..., 0.3686, 0.3765, 0.3686],
         [0.8824, 0.8824, 0.8824,  ..., 0.3412, 0.3373, 0.3373],
         [0.8824, 0.8863, 0.8863,  ..., 0.3098, 0.2980, 0.2941],
         ...,
         [0.2549, 0.2549, 0.2549,  ..., 0.4118, 0.4196, 0.4235],
         [0.2588, 0.2588, 0.2588,  ..., 0.4235, 0.4039, 0.4078],
         [0.2471, 0.2510, 0.2510,  ..., 0.4078, 0.4078, 0.4118]],

        [[0.9020, 0.9020, 0.9020,  ..., 0.2353, 0.2549, 0.2471],
         [0.9059, 0.9059, 0.9059,  ..., 0.2078, 0.2157, 0.2157],
         [0.9059, 0.9098, 0.9098,  ..., 0.1765, 0.1647, 0.