-
Notifications
You must be signed in to change notification settings - Fork 111
/
iphone.py
70 lines (62 loc) · 2.71 KB
/
iphone.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
import numpy as np
import os,sys,time
import torch
import torch.nn.functional as torch_F
import torchvision
import torchvision.transforms.functional as torchvision_F
import PIL
import imageio
from easydict import EasyDict as edict
import json
import pickle
from . import base
import camera
from util import log,debug
class Dataset(base.Dataset):
    """Frames of a single iPhone-captured scene whose camera poses are unknown.

    Frames are read from ``{opt.data.root or "data/iphone"}/{opt.data.scene}/images``
    and ordered by the integer stem of their file names. The trailing
    ``int(len(self)*opt.data.val_ratio)`` frames form the non-train split; the
    remaining leading frames form the "train" split. Every frame shares one
    pinhole intrinsic matrix and receives a dummy zero-translation pose.
    """

    def __init__(self,opt,split="train",subset=None):
        """Index the scene's image files and select the frames of ``split``.

        Args:
            opt: config; reads ``opt.data.root``, ``opt.data.scene``,
                ``opt.data.val_ratio`` and ``opt.data.preload``.
            split: "train" keeps the leading frames; any other value keeps the
                trailing held-out frames.
            subset: if truthy, keep only the first ``subset`` frames of the split.
        """
        # raw resolution is assigned before the base initializer runs
        # NOTE(review): presumably base.Dataset.__init__ reads raw_H/raw_W — confirm
        self.raw_H,self.raw_W = 1080,1920
        super().__init__(opt,split)
        self.root = opt.data.root or "data/iphone"
        self.path = "{}/{}".format(self.root,opt.data.scene)
        self.path_image = "{}/images".format(self.path)
        # file names must look like "<int>.<ext>" so frames sort numerically
        # (os.listdir order is arbitrary, and lexical order would put "10" before "2")
        self.list = sorted(os.listdir(self.path_image),key=lambda f: int(f.split(".")[0]))
        # manually split train/val subsets: the trailing frames are held out
        num_val_split = int(len(self)*opt.data.val_ratio)
        # Slice at an explicit index rather than with [:-k] / [-k:]: when k == 0
        # those would yield an EMPTY train split and send every frame to val.
        # max(..., 0) keeps the old result (empty train, all-val) when k exceeds the frame count.
        num_train = max(len(self.list)-num_val_split,0)
        self.list = self.list[:num_train] if split=="train" else self.list[num_train:]
        if subset: self.list = self.list[:subset]
        # preload dataset
        if opt.data.preload:
            self.images = self.preload_threading(opt,self.get_image)
            self.cameras = self.preload_threading(opt,self.get_camera,data_str="cameras")

    def prefetch_all_data(self,opt):
        """Collate every sample of this split into a single batch stored in ``self.all``."""
        # presumably disallowed because a cached batch would freeze one random augmentation draw
        assert(not opt.data.augment)
        # pre-iterate through all samples and group together
        self.all = torch.utils.data._utils.collate.default_collate([s for s in self])

    def get_all_camera_poses(self,opt):
        """Return one placeholder pose per frame; real poses are not available here."""
        # poses are unknown, so just return some dummy poses (zero translation)
        return camera.pose(t=torch.zeros(len(self),3))

    def __getitem__(self,idx):
        """Load frame ``idx`` and return a dict with keys ``idx``, ``image``, ``intr``, ``pose``."""
        opt = self.opt
        sample = dict(idx=idx)
        # the same augmentation (or None) is passed to both image and camera preprocessing
        aug = self.generate_augmentation(opt) if self.augment else None
        image = self.images[idx] if opt.data.preload else self.get_image(opt,idx)
        image = self.preprocess_image(opt,image,aug=aug)
        intr,pose = self.cameras[idx] if opt.data.preload else self.get_camera(opt,idx)
        intr,pose = self.preprocess_camera(opt,intr,pose,aug=aug)
        sample.update(
            image=image,
            intr=intr,
            pose=pose,
        )
        return sample

    def get_image(self,opt,idx):
        """Read frame ``idx`` from disk and return it as a ``PIL.Image``."""
        image_fname = "{}/{}".format(self.path_image,self.list[idx])
        image = PIL.Image.fromarray(imageio.imread(image_fname)) # directly using PIL.Image.open() leads to weird corruption....
        return image

    def get_camera(self,opt,idx):
        """Return ``(intr, pose)`` for frame ``idx``; every frame shares the same camera.

        ``intr`` is a 3x3 float tensor with ``self.focal`` on both diagonal focal
        entries and the principal point at the centre of the raw image.
        ``pose`` is a zero-translation placeholder.
        Side effect: (re)sets ``self.focal``.
        """
        # NOTE(review): 4.2 and 12.8/2.55 look like a lens focal length and sensor
        # width in mm, which would make this a focal length in pixels — confirm
        self.focal = self.raw_W*4.2/(12.8/2.55)
        intr = torch.tensor([[self.focal,0,self.raw_W/2],
                             [0,self.focal,self.raw_H/2],
                             [0,0,1]]).float()
        pose = camera.pose(t=torch.zeros(3)) # dummy pose, won't be used
        return intr,pose