Skip to content

Commit

Permalink
Training scannet, script
Browse files Browse the repository at this point in the history
  • Loading branch information
chrischoy committed Oct 11, 2019
1 parent f166209 commit 4085407
Show file tree
Hide file tree
Showing 10 changed files with 86 additions and 105 deletions.
9 changes: 4 additions & 5 deletions config.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,8 +107,6 @@ def add_argument_group(name):
'--threads', type=int, default=1, help='num threads for train/test dataloader')
data_arg.add_argument('--val_threads', type=int, default=1, help='num threads for val dataloader')
data_arg.add_argument('--ignore_label', type=int, default=255)
data_arg.add_argument('--train_elastic_distortion', type=str2bool, default=True)
data_arg.add_argument('--test_elastic_distortion', type=str2bool, default=False)
data_arg.add_argument('--return_transformation', type=str2bool, default=False)
data_arg.add_argument('--ignore_duplicate_class', type=str2bool, default=False)
data_arg.add_argument('--partial_crop', type=float, default=0.)
Expand Down Expand Up @@ -198,15 +196,16 @@ def add_argument_group(name):
data_aug_arg.add_argument('--normalize_color', type=str2bool, default=True)
data_aug_arg.add_argument('--data_aug_scale_min', type=float, default=0.9)
data_aug_arg.add_argument('--data_aug_scale_max', type=float, default=1.1)
data_aug_arg.add_argument(
'--data_aug_hue_max', type=float, default=0.5, help='Hue translation range. [0, 1]')
data_aug_arg.add_argument(
'--data_aug_saturation_max', type=float, default=0.20, help='Saturation translation range, [0, 1]')

# Test
test_arg = add_argument_group('Test')
test_arg.add_argument('--visualize', type=str2bool, default=False)
test_arg.add_argument('--test_temporal_average', type=str2bool, default=False)
test_arg.add_argument('--visualize_path', type=str, default='outputs/visualize')
test_arg.add_argument('--test_rotation', type=int, default=-1)
test_arg.add_argument('--test_rotation_save', type=str2bool, default=False)
test_arg.add_argument('--test_rotation_save_dir', type=str, default='outputs/rotation_fulleval')
test_arg.add_argument('--save_prediction', type=str2bool, default=False)
test_arg.add_argument('--save_pred_dir', type=str, default='outputs/pred')
test_arg.add_argument('--test_phase', type=str, default='test', help='Dataset for test')
Expand Down
50 changes: 14 additions & 36 deletions lib/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,7 @@ class DictDataset(Dataset, ABC):

def __init__(self,
data_paths,
prevoxel_transform=None,
input_transform=None,
target_transform=None,
cache=False,
Expand All @@ -91,6 +92,8 @@ def __init__(self,

self.data_root = data_root
self.data_paths = sorted(data_paths)

self.prevoxel_transform = prevoxel_transform
self.input_transform = input_transform
self.target_transform = target_transform

Expand Down Expand Up @@ -141,28 +144,27 @@ class VoxelizationDatasetBase(DictDataset, ABC):

def __init__(self,
data_paths,
prevoxel_transform=None,
input_transform=None,
target_transform=None,
cache=False,
data_root='/',
explicit_rotation=-1,
ignore_mask=255,
return_transformation=False,
**kwargs):
"""
ignore_mask: label value for ignore class. It will not be used as a class in the loss or evaluation.
explicit_rotation: number of discrete rotation steps that 360 degrees is divided into. The dataset length becomes num_data * explicit_rotation.
"""
DictDataset.__init__(
self,
data_paths,
prevoxel_transform=prevoxel_transform,
input_transform=input_transform,
target_transform=target_transform,
cache=cache,
data_root=data_root)

self.ignore_mask = ignore_mask
self.explicit_rotation = explicit_rotation
self.return_transformation = return_transformation

def __getitem__(self, index):
Expand All @@ -174,8 +176,6 @@ def load_ply(self, index):

def __len__(self):
num_data = len(self.data_paths)
if self.explicit_rotation > 1:
return num_data * self.explicit_rotation
return num_data


Expand All @@ -202,7 +202,6 @@ def __init__(self,
input_transform=None,
target_transform=None,
data_root='/',
explicit_rotation=-1,
ignore_label=255,
return_transformation=False,
augment_data=False,
Expand All @@ -214,11 +213,11 @@ def __init__(self,
VoxelizationDatasetBase.__init__(
self,
data_paths,
prevoxel_transform=prevoxel_transform,
input_transform=input_transform,
target_transform=target_transform,
cache=cache,
data_root=data_root,
explicit_rotation=config.test_rotation,
ignore_mask=ignore_label,
return_transformation=return_transformation)

Expand Down Expand Up @@ -250,13 +249,6 @@ def convert_mat2cfl(self, mat):
return mat[:, :3], mat[:, 3:-1], mat[:, -1]

def __getitem__(self, index):
if self.explicit_rotation > 1:
rotation_space = np.linspace(-np.pi, np.pi, self.explicit_rotation + 1)
rotation_angle = rotation_space[index % self.explicit_rotation]
index //= self.explicit_rotation
else:
rotation_angle = None

pointcloud, center = self.load_ply(index)

# Downsample the pointcloud with finer voxel size before transformation for memory and speed
Expand All @@ -269,19 +261,8 @@ def __getitem__(self, index):
pointcloud = self.prevoxel_transform(pointcloud)

coords, feats, labels = self.convert_mat2cfl(pointcloud)
outs = self.voxelizer.voxelize(
coords,
feats,
labels,
center=center,
rotation_angle=rotation_angle,
return_transformation=self.return_transformation)

if self.return_transformation:
coords, feats, labels, transformation = outs
transformation = np.expand_dims(transformation, 0)
else:
coords, feats, labels = outs
coords, feats, labels, transformation = self.voxelizer.voxelize(
coords, feats, labels, center=center)

# map labels not used for evaluation to ignore_label
if self.input_transform is not None:
Expand All @@ -296,12 +277,6 @@ def __getitem__(self, index):
return_args.extend([pointcloud.astype(np.float32), transformation.astype(np.float32)])
return tuple(return_args)

def __len__(self):
num_data = sum(self.numels)
if self.explicit_rotation > 1:
return num_data * self.explicit_rotation
return num_data


class TemporalVoxelizationDataset(VoxelizationDataset):

Expand All @@ -313,7 +288,6 @@ def __init__(self,
input_transform=None,
target_transform=None,
data_root='/',
explicit_rotation=-1,
ignore_label=255,
temporal_dilation=1,
temporal_numseq=3,
Expand All @@ -322,8 +296,8 @@ def __init__(self,
config=None,
**kwargs):
VoxelizationDataset.__init__(self, data_paths, input_transform, target_transform, data_root,
explicit_rotation, ignore_label, return_transformation,
augment_data, config, **kwargs)
ignore_label, return_transformation, augment_data, config,
**kwargs)
self.temporal_dilation = temporal_dilation
self.temporal_numseq = temporal_numseq
temporal_window = temporal_dilation * (temporal_numseq - 1) + 1
Expand Down Expand Up @@ -406,6 +380,10 @@ def __getitem__(self, index):
return_args.extend([pointclouds.astype(np.float32), transformations.astype(np.float32)])
return tuple(return_args)

def __len__(self):
num_data = sum(self.numels)
return num_data


def initialize_data_loader(DatasetClass,
config,
Expand Down
1 change: 1 addition & 0 deletions lib/datasets/scannet.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,7 @@ def __init__(self,
super().__init__(
data_paths,
data_root=data_root,
prevoxel_transform=prevoxel_transform,
input_transform=input_transform,
target_transform=target_transform,
ignore_label=config.ignore_label,
Expand Down
4 changes: 1 addition & 3 deletions lib/datasets/stanford.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@
from lib.utils import read_txt, fast_hist, per_class_iu
from lib.dataset import VoxelizationDataset, DatasetPhase, str2datasetphase_type
import lib.transforms as t
from lib.datasets.preprocessing.stanford_3d import Stanford3DDatasetConverter


class StanfordVoxelizationDatasetBase:
Expand Down Expand Up @@ -69,8 +68,7 @@ def test_pointcloud(self, pred_dir):
ious = []
print('Per class IoU:')
for i, iou in enumerate(per_class_iu(hist) * 100):
unmasked_idx = self.label2masked.tolist().index(i)
result_str = f'\t{Stanford3DDatasetConverter.CLASSES[unmasked_idx]}:\t'
result_str = ''
if hist.sum(1)[i]:
result_str += f'{iou}'
ious.append(iou)
Expand Down
43 changes: 2 additions & 41 deletions lib/test.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,11 +70,6 @@ def test(model, data_loader, config, transform_data_fn=None, has_gt=True):
data_iter = data_loader.__iter__()
max_iter = len(data_loader)
max_iter_unique = max_iter
if config.test_rotation > 1:
if config.test_rotation_save:
logging.info('Saving rotation pointcloud prediction at ' + config.test_rotation_save_dir)
os.makedirs(config.test_rotation_save_dir, exist_ok=True)
max_iter_unique //= config.test_rotation

# Fix batch normalization running mean and std
model.eval()
Expand Down Expand Up @@ -106,7 +101,7 @@ def test(model, data_loader, config, transform_data_fn=None, has_gt=True):
iter_timer.tic()

if config.wrapper_type != 'None':
color = input[:, :3].int()
color = input[:, :3].int()
if config.normalize_color:
input[:, :3] = input[:, :3] / 255. - 0.5
sinput = SparseTensor(input, coords).to(device)
Expand All @@ -119,45 +114,13 @@ def test(model, data_loader, config, transform_data_fn=None, has_gt=True):
pred = get_prediction(dataset, output, target).int()
iter_time = iter_timer.toc(False)

# Get mapping between input and output space
if np.prod(np.array(model.OUT_PIXEL_DIST)) > 1:
permutation = model.get_permutation(model.OUT_PIXEL_DIST, 1).long()
upsampled_pred = pred[permutation].cpu().numpy()
else:
upsampled_pred = pred.cpu().numpy()

if config.save_prediction or config.test_original_pointcloud:
save_predictions(coords, upsampled_pred, transformation, dataset, config, iteration,
save_pred_dir)

# Visualize prediction
if config.visualize:
# Do not save all predictions in rotation-augmented test.
if config.test_rotation < 1 or iteration % config.test_rotation == 0:
visualize_results(coords, input, target, upsampled_pred, config, iteration)
save_predictions(coords, pred, transformation, dataset, config, iteration, save_pred_dir)

if has_gt:
if config.eval_upsample:
# Upscale the target and prediction to the original voxel space
output = output[permutation]
pred = get_prediction(dataset, output, target).int()

if config.evaluate_original_pointcloud:
output, pred, target = permute_pointcloud(coords, pointcloud, transformation,
dataset.label_map, output, pred)
if config.test_rotation > 1:
if iteration % config.test_rotation == 0:
output_rotation = output
else:
output_rotation += output
if iteration % config.test_rotation != config.test_rotation - 1:
continue
iteration //= config.test_rotation
output = output_rotation
pred = get_prediction(dataset, output, target).int()
if config.test_rotation_save:
save_rotation_pred(iteration,
pred.cpu().numpy(), dataset, config.test_rotation_save_dir)

target_np = target.numpy()

Expand Down Expand Up @@ -221,8 +184,6 @@ def test(model, data_loader, config, transform_data_fn=None, has_gt=True):
if config.test_original_pointcloud:
logging.info('===> Start testing on original pointcloud space.')
dataset.test_pointcloud(save_pred_dir)
if not config.save_prediction:
shutil.rmtree(save_pred_dir)

logging.info("Finished test. Elapsed time: {:.4f}".format(global_time))

Expand Down
4 changes: 2 additions & 2 deletions lib/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -195,9 +195,9 @@ def elastic_distortion(self, pointcloud, granularity, magnitude):
return pointcloud

def __call__(self, pointcloud):
if self.distortion_param is not None:
if self.distortion_params is not None:
if random.random() < 0.95:
for granularity, magnitude in self.distortion_param:
for granularity, magnitude in self.distortion_params:
pointcloud = self.elastic_distortion(pointcloud, granularity, magnitude)
return pointcloud

Expand Down
5 changes: 1 addition & 4 deletions lib/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -197,10 +197,7 @@ def wrapper(*args, **kwargs):


def get_prediction(dataset, output, target):
if dataset.NEED_PRED_POSTPROCESSING:
return dataset.get_prediction(output, target)
else:
return output.max(1)[1]
return output.max(1)[1]


def count_parameters(model):
Expand Down
14 changes: 3 additions & 11 deletions main.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,14 +58,9 @@ def main():
if not config.return_transformation:
raise ValueError('Pointcloud evaluation requires config.return_transformation=true.')

if config.test_rotation > 1:
if config.is_train:
raise ValueError('Rotation evaluation should not be used for training.')
if not (config.return_transformation and config.evaluate_original_pointcloud):
raise ValueError('Rotation evaluation requires config.evaluate_original_pointcloud=true and '
'config.return_transformation=true.')
if config.test_original_pointcloud:
raise ValueError('Cannot run rotation evaluation and KD-tree evaluation together.')
if (config.return_transformation ^ config.evaluate_original_pointcloud):
raise ValueError('Rotation evaluation requires config.evaluate_original_pointcloud=true and '
'config.return_transformation=true.')

logging.info('===> Initializing dataloader')
if config.is_train:
Expand All @@ -75,7 +70,6 @@ def main():
phase=config.train_phase,
threads=config.threads,
augment_data=True,
elastic_distortion=config.train_elastic_distortion,
shuffle=True,
repeat=True,
batch_size=config.batch_size,
Expand All @@ -87,7 +81,6 @@ def main():
threads=config.val_threads,
phase=config.val_phase,
augment_data=False,
elastic_distortion=config.test_elastic_distortion,
shuffle=True,
repeat=False,
batch_size=config.val_batch_size,
Expand All @@ -105,7 +98,6 @@ def main():
threads=config.threads,
phase=config.test_phase,
augment_data=False,
elastic_distortion=config.test_elastic_distortion,
shuffle=False,
repeat=False,
batch_size=config.test_batch_size,
Expand Down
12 changes: 9 additions & 3 deletions models/resnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,9 +77,15 @@ def space_n_time_m(n, m):

def weight_initialization(self):
for m in self.modules():
if isinstance(m, nn.BatchNorm1d):
nn.init.constant_(m.weight, 1)
nn.init.constant_(m.bias, 0)
if isinstance(m, ME.MinkowskiConvolution):
ME.utils.kaiming_normal_(m.kernel, mode='fan_out', nonlinearity='relu')

if isinstance(m, ME.MinkowskiConvolutionTranspose):
ME.utils.kaiming_normal_(m.kernel, mode='fan_in', nonlinearity='relu')

if isinstance(m, ME.MinkowskiBatchNorm):
nn.init.constant_(m.bn.weight, 1)
nn.init.constant_(m.bn.bias, 0)

def _make_layer(self,
block,
Expand Down
Loading

0 comments on commit 4085407

Please sign in to comment.