# Imports

In [None]:
%matplotlib inline
import os, sys
# Expose only GPU 0 to this process; this must be set before torch initialises CUDA,
# which is why it comes before `import torch` below.
os.environ['CUDA_VISIBLE_DEVICES'] = "0"
# os.environ['CUDA_LAUNCH_BLOCKING'] = "1" # for debugging GPU stuff
import time, random

import torch
import torch.optim as optim
from tensorboardX import SummaryWriter
import numpy as np

# Used in the dataset cell to re-import data.TOD after edits without restarting the kernel.
from importlib import reload

In [None]:
# Config file
# Path to the experiment YAML. Set the POINTGROUP_CFG_FILE environment variable to use a
# different config instead of editing this machine-specific default path.
cfg_file = os.environ.get(
    'POINTGROUP_CFG_FILE',
    '/home/chrisxie/local_installations/PointGroup/config/pointgroup_TOD.yaml',
)
from util.config import get_parser_notebook
# Parse the config before importing `cfg` (and the logger/utils, which may read it) below.
# NOTE(review): presumably get_parser_notebook populates util.config.cfg — keep this order.
get_parser_notebook(cfg_file=cfg_file, pretrain_path=None)

from util.config import cfg
from util.log import logger
import util.utils as utils

# Training function definitions

In [None]:
def init():
    """Prepare an experiment run.

    * Backs up ``train.py`` (relative to the current working directory),
      ``cfg.model_dir``, ``cfg.dataset_dir`` and ``cfg.config`` into
      ``<cfg.exp_path>/backup_files``. Copying is best effort: a file that
      cannot be copied is reported through ``logger.info`` and skipped.
    * Logs the full config.
    * Binds the module-level ``writer`` to a ``SummaryWriter`` writing to
      ``cfg.exp_path`` (``train_epoch`` logs its scalars through it).
    * Seeds ``random``, NumPy and torch (CPU and all CUDA devices) with
      ``cfg.manual_seed``.
    """
    import shutil

    # copy important files to backup. shutil is used instead of shelling out to `cp`
    # so paths containing spaces or shell metacharacters are handled correctly and
    # failures are reported rather than silently ignored.
    backup_dir = os.path.join(cfg.exp_path, 'backup_files')
    os.makedirs(backup_dir, exist_ok=True)
    for src in ('train.py', cfg.model_dir, cfg.dataset_dir, cfg.config):
        try:
            shutil.copy(src, backup_dir)
        except (OSError, TypeError) as err:  # TypeError: e.g. a path option left as None
            logger.info('WARNING: could not back up {!r}: {}'.format(src, err))

    # log the config
    logger.info(cfg)

    # summary writer
    global writer
    writer = SummaryWriter(cfg.exp_path)

    # random seed
    random.seed(cfg.manual_seed)
    np.random.seed(cfg.manual_seed)
    torch.manual_seed(cfg.manual_seed)
    torch.cuda.manual_seed_all(cfg.manual_seed)

In [None]:
def train_epoch(train_loader, model, model_fn, optimizer, epoch, use_cuda=None):
    """Train ``model`` for one epoch and save a checkpoint at the end of it.

    Batches are drawn from ``train_loader`` until it is exhausted or the global
    iteration counter reaches ``cfg.max_iters``. Every ``cfg.i_log`` iterations a
    progress line is written to stdout, and the running-average meters whose
    names also appear in ``visual_dict`` are logged to TensorBoard under
    ``Train_Loss/<name>``.

    Args:
        train_loader: iterable of batches that supports ``len()``; the global
            iteration counter starts at ``(epoch - 1) * len(train_loader)``.
        model: network to train; switched to train mode here.
        model_fn: callable ``(batch, model, current_iter)`` returning
            ``(loss, preds, visual_dict, meter_dict)``. Each ``meter_dict`` value
            is a pair passed as ``AverageMeter.update(v[0], v[1])``; a
            ``'total_loss'`` entry is required for the progress/summary logs.
        optimizer: torch optimizer; its learning rate is set each iteration via
            ``utils.step_learning_rate`` and it is stepped once per batch.
        epoch: 1-based epoch index.
        use_cuda: forwarded to ``utils.checkpoint_save``. Defaults to
            ``torch.cuda.is_available()`` (the value the notebook stores in its
            ``use_cuda`` global), so the function does not rely on a global that
            is only defined in a later cell.

    Uses the notebook globals ``cfg``, ``logger``, ``utils`` and ``writer``
    (``writer`` is created by ``init()``).

    If no training iteration runs (empty loader, or ``cfg.max_iters`` already
    reached) the epoch summary and checkpoint are skipped: there is no loss to
    report and the weights are unchanged.
    """
    if use_cuda is None:
        use_cuda = torch.cuda.is_available()

    iter_time = utils.AverageMeter()
    data_time = utils.AverageMeter()
    am_dict = {}

    current_iter = (epoch - 1) * len(train_loader)
    first_iter = current_iter

    model.train()
    epoch_start = time.time()
    end = time.time()
    for i, batch in enumerate(train_loader):
        data_time.update(time.time() - end)
        torch.cuda.empty_cache()

        if current_iter >= cfg.max_iters:
            break

        ##### adjust learning rate
        utils.step_learning_rate(optimizer, cfg.lr, epoch - 1, cfg.step_epoch, cfg.multiplier)

        ##### prepare input and forward
        loss, _, visual_dict, meter_dict = model_fn(batch, model, current_iter)

        ##### meter_dict
        for k, v in meter_dict.items():
            if k not in am_dict:
                am_dict[k] = utils.AverageMeter()
            am_dict[k].update(v[0], v[1])

        ##### backward
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        ##### time and print
        current_iter += 1
        iter_time.update(time.time() - end)
        end = time.time()

        if current_iter % cfg.i_log == 0:
            # ETA for the iterations left until cfg.max_iters, from the mean iteration time so far.
            remain_time = (cfg.max_iters - current_iter) * iter_time.avg
            t_m, t_s = divmod(remain_time, 60)
            t_h, t_m = divmod(t_m, 60)
            remain_time = '{:02d}:{:02d}:{:02d}'.format(int(t_h), int(t_m), int(t_s))

            sys.stdout.write(
                "epoch: {}/{} iter: {}/{} loss: {:.4f}({:.4f}) data_time: {:.2f}({:.2f}) iter_time: {:.2f}({:.2f}) remain_time: {remain_time}\n".format(
                    epoch, cfg.max_epochs, i + 1, len(train_loader), am_dict['total_loss'].val, am_dict['total_loss'].avg,
                    data_time.val, data_time.avg, iter_time.val, iter_time.avg, remain_time=remain_time))

            for k in am_dict:
                if k in visual_dict:
                    writer.add_scalar('Train_Loss/' + k, am_dict[k].avg, current_iter)

    if current_iter == first_iter:
        # No optimisation step ran, so there are no meters: without this guard
        # am_dict['total_loss'] below would raise KeyError.
        logger.info("epoch: {}/{}, no training iterations run (iter {}, max_iters {}), skipping summary and checkpoint".format(
            epoch, cfg.max_epochs, current_iter, cfg.max_iters))
        return

    logger.info("epoch: {}/{}, train loss: {:.4f}, time: {}s".format(epoch, cfg.max_epochs, am_dict['total_loss'].avg, time.time() - epoch_start))

    utils.checkpoint_save(model, cfg.exp_path, cfg.config.split('/')[-1][:-5], epoch, cfg.save_freq, use_cuda)
    

# Training script

In [None]:
##### init
# Back up source/config files, log the config, create the TensorBoard `writer`
# (used by train_epoch) and seed the RNGs.
init()

In [None]:
##### get model version and data version
# The config file name is expected to look like `<model>_..._<data>.<ext>`,
# e.g. `pointgroup_TOD.yaml` -> model 'pointgroup', data 'TOD'.
# splitext strips the real extension instead of assuming a 5-character `.yaml`.
exp_name = os.path.splitext(os.path.basename(cfg.config))[0]
print(exp_name)
model_name = exp_name.split('_')[0]
print(model_name)
data_name = exp_name.split('_')[-1]
print(data_name)

In [None]:
##### model
logger.info('=> creating model ...')
if model_name == 'pointgroup':
    from model.pointgroup.pointgroup import PointGroup as Network
    from model.pointgroup.pointgroup import model_fn_decorator
else:
    # Raise rather than exit(0): in Jupyter exit() shuts the kernel down and, as a
    # script, status 0 would report success; an exception stops "Run All" cleanly.
    raise ValueError("Error: no model - " + model_name)
model = Network(cfg)

use_cuda = torch.cuda.is_available()
logger.info('cuda available: {}'.format(use_cuda))
# Explicit check instead of `assert`, which is stripped when Python runs with -O.
if not use_cuda:
    raise RuntimeError('CUDA is required for training, but torch.cuda.is_available() is False')
model = model.cuda()

logger.info('#classifier parameters: {}'.format(sum(x.nelement() for x in model.parameters())))

In [None]:
##### optimizer
logger.info('=> creating optimizer ...')
# Only parameters that require gradients are handed to the optimizer.
trainable_params = [p for p in model.parameters() if p.requires_grad]
if cfg.optim == 'Adam':
    optimizer = optim.Adam(trainable_params, lr=cfg.lr)
elif cfg.optim == 'SGD':
    optimizer = optim.SGD(trainable_params, lr=cfg.lr, momentum=cfg.momentum, weight_decay=cfg.weight_decay)
else:
    # Without this, `optimizer` would be undefined (or a stale object from an earlier
    # run of this cell) and training would fail later or silently use the wrong one.
    raise ValueError('Unsupported optimizer in config: {!r} (expected "Adam" or "SGD")'.format(cfg.optim))

In [None]:
##### model_fn (criterion)
# Build the per-batch forward + loss callable (model_fn_decorator is imported in the model cell).
# train_epoch calls it as `loss, _, visual_dict, meter_dict = model_fn(batch, model, current_iter)`.
model_fn = model_fn_decorator()

In [None]:
##### dataset
# Only the TOD dataset / data loader is wired up in this notebook. Fail loudly for
# anything else instead of leaving `dataset` undefined (or stale from an earlier run),
# and raise rather than call exit(), which in Jupyter shuts the kernel down.
if cfg.dataset != 'TOD':
    raise ValueError("Error: unsupported dataset - {}".format(cfg.dataset))
if data_name != 'TOD':
    raise ValueError("Error: no data loader - " + data_name)

import data.TOD
# Reload so edits to the data.TOD module are picked up without restarting the kernel.
data.TOD = reload(data.TOD)
dataset = data.TOD.Dataset()
# Presumably builds `dataset.train_data_loader`, which the training loop iterates.
dataset.trainLoader()

In [None]:
##### resume
# Restore model weights through utils.checkpoint_restore and get the epoch to continue from.
# The checkpoint name is the config file name minus its last 5 characters (the `.yaml`
# suffix), matching the name train_epoch saves under.
# NOTE(review): a specific epoch can reportedly be restored instead of the latest —
# presumably via an extra checkpoint_restore argument; confirm in util.utils.
start_epoch = utils.checkpoint_restore(model, cfg.exp_path, cfg.config.split('/')[-1][:-5], use_cuda)
print('Start epoch: {}'.format(start_epoch))

In [None]:
##### train (this notebook runs no validation pass)
# Epochs are 1-based and cfg.max_epochs is inclusive.
for epoch in range(start_epoch, cfg.max_epochs + 1):
    train_epoch(dataset.train_data_loader, model, model_fn, optimizer, epoch)

# Random Testing: data-loader sanity checks on individual samples

In [None]:
# Spot-check the collate function on up to 100 randomly drawn samples.
for i in np.random.permutation(len(dataset.label_filenames))[:100]:
    # Collate the drawn sample itself so each printed line describes sample `i`
    # (a hard-coded index would print the same sample 100 times under different ids).
    temp = dataset.train_collate_fn([int(i)])
    print(i, temp['v2p_map'].shape, temp['v2p_map'].numpy().size)

In [None]:
# Collate one specific sample for the inspection cells below.
# NOTE(review): index 178780 is unexplained — presumably a sample that triggered an issue; document why.
temp = dataset.train_collate_fn([178780])
print(temp['v2p_map'].shape, temp['v2p_map'].numpy().size, temp['spatial_shape'])

In [None]:
# Field names available in the collated sample.
temp.keys()

In [None]:
# First entry of `v2p_map` (presumably the voxel-to-point index map — confirm in data.TOD).
temp['v2p_map'][0]

In [None]:
# Full shape of `v2p_map`.
temp['v2p_map'].shape

In [None]:
# Largest entry of `p2v_map` (presumably point-to-voxel indices, i.e. the highest voxel id — confirm).
temp['p2v_map'].max()

In [None]:
# Shape and total element count of `v2p_map`.
print(temp['v2p_map'].shape, temp['v2p_map'].numpy().size)

In [None]:
# Channel 2 of `locs_float` (presumably z) over rows 10-12, cols 100-105, viewing the points as a 480x640 grid.
# NOTE(review): assumes exactly 480*640 points, one per pixel — confirm against the dataset.
temp['locs_float'].reshape(480,640,3)[10:13, 100:106,2]

In [None]:
# Pixel location(s) holding the largest channel-2 value of `locs_float` (presumably z),
# viewing the points as a 480x640 grid; returns one index array per axis.
temp_ = temp['locs_float'].reshape(480, 640, 3)[..., 2].numpy().copy()
(temp_ == temp_.max()).nonzero()

In [None]:
# Distinct semantic labels in the sample (-100 presumably marks ignored points).
np.unique(temp['labels'].numpy())

In [None]:
# Distinct instance ids in the sample (-100 presumably marks points without an instance).
np.unique(temp['instance_labels'].numpy())

In [None]:
import matplotlib.pyplot as plt

# Visual sanity check of one collated sample: every per-point field is viewed as an image.
# NOTE(review): assumes the sample holds exactly one point per pixel of a 480x640 frame — confirm.
IMG_H, IMG_W = 480, 640
IMAGENET_MEAN = np.array([0.485, 0.456, 0.406])
IMAGENET_STD = np.array([0.229, 0.224, 0.225])
IGNORE_LABEL = -100

fig, (ax_seg, ax_inst, ax_rgb) = plt.subplots(1, 3, figsize=(15, 5))

# Raw semantic labels (ignored points keep their IGNORE_LABEL value).
seg_labels = temp['labels'].numpy().reshape(IMG_H, IMG_W)
ax_seg.imshow(seg_labels)
ax_seg.set_title('Seg Labels')

# Shift instance ids by one (building a new array, so `temp` is untouched) and show
# ignored points as 0.
inst_labels = temp['instance_labels'].numpy().reshape(IMG_H, IMG_W) + 1
inst_labels[inst_labels == IGNORE_LABEL + 1] = 0
ax_inst.imshow(inst_labels)
ax_inst.set_title('Instance Labels')

# Undo the ImageNet mean/std normalisation of `feats`, then clip to [0, 1]:
# imshow expects RGB floats in that range and otherwise warns and clips itself.
rgb = temp['feats'].numpy().reshape(IMG_H, IMG_W, 3) * IMAGENET_STD + IMAGENET_MEAN
ax_rgb.imshow(np.clip(rgb, 0, 1))
ax_rgb.set_title('RGB');

In [None]:
# Shape of `locs`.
temp['locs'].shape

In [None]:
# Shape of `voxel_locs`.
temp['voxel_locs'].shape

In [None]:
# Shape of `locs_float`.
temp['locs_float'].shape

In [None]:
# Shape of `feats`.
temp['feats'].shape

In [None]:
# Check whether `feats` is a torch tensor.
torch.is_tensor(temp['feats'])

In [None]:
# Summary of every collated field: its shape for tensors, otherwise its Python type.
for key, value in temp.items():
    print(key, value.shape if torch.is_tensor(value) else type(value))