Skip to content
Permalink
Branch: master
Find file Copy path
Find file Copy path
Fetching contributors…
Cannot retrieve contributors at this time
199 lines (142 sloc) 6.56 KB
from utils.common import AverageMeter
import numpy as np
import time
import torch
from utils.common import quaternion_angular_error
# train function
def train(train_loader, model, criterion, optimizer, epoch, max_epoch, log_freq=1, print_sum=True,
          poses_mean=None, poses_std=None, device=None, stereo=True):
    """Run one training epoch, logging per-batch and per-epoch statistics.

    Args:
        train_loader: iterable yielding ``(batch_images, batch_poses)`` pairs.
            With ``stereo=True`` both items are sequences of tensors; otherwise
            each is a single tensor.
        model: network, called as ``model(batch_images)``.
        criterion: loss, called as ``criterion(out, batch_poses)``. When
            ``print_sum`` is true it must also expose single-element tensors
            ``sx`` and ``sq``, which are printed in the epoch summary.
        optimizer: optimizer stepped once per batch.
        epoch: current epoch index (logging only).
        max_epoch: total number of epochs (logging only).
        log_freq: print batch statistics every ``log_freq`` batches; ``0``
            disables per-batch logging.
        print_sum: print the end-of-epoch summary line.
        poses_mean: if given together with ``poses_std``, the first three pose
            columns are mapped back with ``x * poses_std + poses_mean`` before
            the errors are computed.
        poses_std: see ``poses_mean``.
        device: target passed to ``Tensor.to`` for every batch.
        stereo: whether batches are sequences of tensors (see ``train_loader``).

    Returns:
        None. Results are only printed.
    """
    # switch model to training
    model.train()

    losses = AverageMeter()
    epoch_time = time.time()

    # Collect per-batch arrays and stack once at the end instead of re-copying
    # the whole history with np.vstack on every batch. The empty float64
    # (0, 7) seed keeps the stacked dtype/shape the same as before, including
    # for an empty loader.
    gt_chunks = [np.empty((0, 7))]
    pred_chunks = [np.empty((0, 7))]

    end = time.time()
    for idx, (batch_images, batch_poses) in enumerate(train_loader):
        data_time = time.time() - end

        if stereo:
            batch_images = [x.to(device) for x in batch_images]
            batch_poses = [x.to(device) for x in batch_poses]
        else:
            batch_images = batch_images.to(device)
            batch_poses = batch_poses.to(device)

        out = model(batch_images)
        loss = criterion(out, batch_poses)

        # Make an optimization step
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        if stereo:
            n_samples = len(batch_images) * batch_images[0].size(0)
        else:
            n_samples = batch_images.size(0)
        # loss is a 0-dim tensor: `.data[0]` raises IndexError on
        # PyTorch >= 0.5, `.item()` is the supported way to get the scalar.
        losses.update(loss.item(), n_samples)

        # move data to cpu & numpy
        if stereo:
            gt_chunks.extend(x.detach().cpu().numpy() for x in batch_poses)
            pred_chunks.extend(x.detach().cpu().numpy() for x in out)
        else:
            gt_chunks.append(batch_poses.detach().cpu().numpy())
            pred_chunks.append(out.detach().cpu().numpy())

        batch_time = time.time() - end
        end = time.time()

        if log_freq != 0 and idx % log_freq == 0:
            print('Epoch: [{}/{}]\tBatch: [{}/{}]\t'
                  'Time: {batch_time:.3f}\t'
                  'Data Time: {data_time:.3f}\t'
                  'Loss: {losses.val:.3f}\t'
                  'Avg Loss: {losses.avg:.3f}\t'.format(
                      epoch, max_epoch - 1, idx, len(train_loader) - 1,
                      batch_time=batch_time, data_time=data_time, losses=losses))

    gt_poses = np.vstack(gt_chunks)
    pred_poses = np.vstack(pred_chunks)

    # un-normalize translation
    if poses_mean is not None and poses_std is not None:
        gt_poses[:, :3] = gt_poses[:, :3] * poses_std + poses_mean
        pred_poses[:, :3] = pred_poses[:, :3] * poses_std + poses_mean

    # Per-sample translation error (Euclidean norm over the first 3 columns)
    # and rotation error (helper applied to the remaining 4 columns).
    t_loss = np.linalg.norm(pred_poses[:, :3] - gt_poses[:, :3], axis=1)
    q_loss = np.asarray([quaternion_angular_error(p, t)
                         for p, t in zip(pred_poses[:, 3:], gt_poses[:, 3:])])

    if print_sum:
        print('Ep: [{}/{}]\tTrain Loss: {:.3f}\tTe: {:.3f}\tRe: {:.3f}\t Et: {:.2f}s\t'
              '{criterion_sx:.5f}:{criterion_sq:.5f}'.format(
                  epoch, max_epoch - 1, losses.avg, np.mean(t_loss), np.mean(q_loss),
                  (time.time() - epoch_time),
                  criterion_sx=criterion.sx.item(),
                  criterion_sq=criterion.sq.item()))
def validate(val_loader, model, criterion, epoch, log_freq=1, print_sum=True, device=None, stereo=True):
    """Evaluate ``model`` on ``val_loader`` without gradients and print the loss.

    Args:
        val_loader: iterable yielding ``(batch_images, batch_poses)`` pairs.
            With ``stereo=True`` both items are sequences of tensors; otherwise
            each is a single tensor.
        model: network, called as ``model(batch_images)``.
        criterion: loss, called as ``criterion(out, batch_poses)``.
        epoch: current epoch index (logging only).
        log_freq: print batch statistics every ``log_freq`` batches; ``0``
            disables per-batch logging.
        print_sum: print the end-of-epoch summary line.
        device: target passed to ``Tensor.to`` for every batch.
        stereo: whether batches are sequences of tensors (see ``val_loader``).

    Returns:
        None. Results are only printed.
    """
    losses = AverageMeter()

    # set model to evaluation
    model.eval()

    with torch.no_grad():
        epoch_time = time.time()
        end = time.time()
        for idx, (batch_images, batch_poses) in enumerate(val_loader):
            data_time = time.time() - end

            if stereo:
                batch_images = [x.to(device) for x in batch_images]
                batch_poses = [x.to(device) for x in batch_poses]
            else:
                batch_images = batch_images.to(device)
                batch_poses = batch_poses.to(device)

            # compute model output
            out = model(batch_images)
            loss = criterion(out, batch_poses)

            if stereo:
                n_samples = len(batch_images) * batch_images[0].size(0)
            else:
                n_samples = batch_images.size(0)
            # loss is a 0-dim tensor: `.data[0]` raises IndexError on
            # PyTorch >= 0.5, `.item()` is the supported way to get the scalar.
            losses.update(loss.item(), n_samples)

            batch_time = time.time() - end
            end = time.time()

            if log_freq != 0 and idx % log_freq == 0:
                print('Val Epoch: {}\t'
                      'Time: {batch_time:.3f}\t'
                      'Data Time: {data_time:.3f}\t'
                      'Loss: {losses.val:.3f}\t'
                      'Avg Loss: {losses.avg:.3f}'.format(
                          epoch, batch_time=batch_time, data_time=data_time, losses=losses))

    if print_sum:
        print('Epoch: [{}]\tValidation Loss: {:.3f}\tEpoch time: {:.3f}'.format(
            epoch, losses.avg, (time.time() - epoch_time)))
def model_results_pred_gt(model, dataloader, poses_mean=None, poses_std=None, device=None, stereo=True):
    """Run ``model`` over ``dataloader`` and return predicted and ground-truth poses.

    Args:
        model: network, called as ``model(batch_images)``; put into eval mode.
        dataloader: iterable yielding ``(batch_images, batch_poses)`` pairs.
            With ``stereo=True`` both items (and the model output) are
            sequences of tensors whose rows are stacked in order; otherwise
            each is a single tensor.
        poses_mean: if given together with ``poses_std``, the first three pose
            columns of both results are mapped back with
            ``x * poses_std + poses_mean``.
        poses_std: see ``poses_mean``.
        device: target passed to ``Tensor.to`` for every batch.
        stereo: whether batches are sequences of tensors (see ``dataloader``).

    Returns:
        Tuple ``(pred_poses, gt_poses)`` of float64 NumPy arrays with 7
        columns and one row per sample; both are ``(0, 7)`` for an empty
        loader.
    """
    model.eval()

    # Collect per-batch arrays and stack once at the end instead of re-copying
    # the whole history with np.vstack on every batch. The empty float64
    # (0, 7) seed keeps the result dtype/shape the same as before.
    gt_chunks = [np.empty((0, 7))]
    pred_chunks = [np.empty((0, 7))]

    # Pure inference: disable autograd so no graph is built or kept per batch.
    with torch.no_grad():
        for batch_images, batch_poses in dataloader:
            if stereo:
                batch_images = [x.to(device) for x in batch_images]
                batch_poses = [x.to(device) for x in batch_poses]
            else:
                batch_images = batch_images.to(device)
                batch_poses = batch_poses.to(device)

            out = model(batch_images)

            # move data to cpu & numpy
            if stereo:
                gt_chunks.extend(x.detach().cpu().numpy() for x in batch_poses)
                pred_chunks.extend(x.detach().cpu().numpy() for x in out)
            else:
                gt_chunks.append(batch_poses.detach().cpu().numpy())
                pred_chunks.append(out.detach().cpu().numpy())

    gt_poses = np.vstack(gt_chunks)
    pred_poses = np.vstack(pred_chunks)

    # un-normalize translation
    if poses_mean is not None and poses_std is not None:
        gt_poses[:, :3] = gt_poses[:, :3] * poses_std + poses_mean
        pred_poses[:, :3] = pred_poses[:, :3] * poses_std + poses_mean

    return pred_poses, gt_poses
You can’t perform that action at this time.