In [8]:
#coding:utf8
from config_point import opt
import os
import torch as t
import models
from utils import box_util
from data_point.dataset import KittiPoint
from torch.utils.data import DataLoader
from torch.autograd import Variable
from utils.train_util import *
from utils.visualize import Visualizer
from utils.data_util import *
from utils.train_util import *
from tqdm import tqdm
from torchvision import transforms as T
from tensorboardX import SummaryWriter
import math
    
def train(**kwargs):
    opt.parse(kwargs)
    dev = ToDevice(opt) # init
    writer = SummaryWriter(opt.env+'/train')
    writer_val = SummaryWriter(opt.env+'/val')

    # step1: data
    train_data = KittiPoint(root=opt.root,sets_type='val',white_list=opt.white_list,
                               rotate_to_center=True, random_flip=True, random_shift=True, one_hot=True)
    val_data = KittiPoint(root=opt.root,sets_type='val',white_list=opt.white_list,
                          rotate_to_center=True, one_hot=True)
    train_dataloader = DataLoader(train_data,opt.batch_size,shuffle=True,num_workers=opt.num_workers)
    val_dataloader = DataLoader(val_data,opt.batch_size_val,shuffle=False,num_workers=opt.num_workers)
    
    # step2: configure model
    model = getattr(models, opt.model)() # test_img_only!!! experiment_no_depth ; (num_out=2)
    model = dev.trans(model)[0]
    if opt.load_model_path:
        model.load(opt.load_model_path)
    
    # step3: criterion and optimizer
    crossEp = t.nn.CrossEntropyLoss()
    smoothL1 = t.nn.SmoothL1Loss() # !!! loss need to be determined match output & label
    optimizer = t.optim.Adam(model.parameters(), lr=opt.lr, weight_decay=opt.weight_decay)
        
    # step4: meters；lm; lm_ry_cls,lm_ry_res,cm_ry； lm_size_cls,lm_size_res,cm_size; lm_loc
    previous_loss = 1e100
    lm = LossMeter(total=True,ry=True,size=True,loc=True)
    
    # train
    for epoch in range(opt.max_epoch):
        # init meter
        lm.reset(total=True,ry=True,size=True,loc=True) # init meter ; train_data)/opt.batch_size
        batch_num = len(train_dataloader)
        # each batch
        for ii,(in_pc,in_center,in_rot,in_vec, loc,ry_cls,ry_res,size_cls,size_res) \
                in tqdm(enumerate(train_dataloader),total=batch_num):
            # 记得改train的dataset 从val变回train
            [in_pc,in_center,in_rot,in_vec] = dev.trans(in_pc,in_center,in_rot,in_vec) # BCHW
            # 分类gt必须为1维  torch.Size([bsize]) 而不是torch.Size([bsize，1])
            [loc,ry_cls,ry_res,size_cls,size_res] = \
                dev.trans(loc,ry_cls.squeeze(1),ry_res,size_cls.squeeze(1),size_res)# label
            # train
            optimizer.zero_grad()
            out_loc,out_ry_cls,out_ry_res,out_size_cls,out_size_res = model(in_pc,in_vec)
            
            loss_loc = smoothL1(out_loc+in_center,loc)
            loss_ry_cls = crossEp(out_ry_cls,ry_cls) # output(BxNum)不需要加softmax, label(B)不需要为onehot
            loss_ry_res = smoothL1(out_ry_res,ry_res)
            loss_size_cls = crossEp(out_size_cls,size_cls)
            loss_size_res = smoothL1(out_size_res,size_res)
            loss = loss_loc + loss_ry_cls + loss_ry_res + loss_size_cls + loss_size_res 
            loss.backward()
            optimizer.step()

            # meters update and visualize; lm = LossMeter
            lm.lm_add(loss, loss_ry_cls,loss_ry_res, loss_size_cls,loss_size_res, loss_loc) # train_loss
            # print('out_ry_cls',out_ry_cls,ry_cls)
            out_ry_cls = t.argmax(out_ry_cls,1)
            out_size_cls = t.argmax(out_size_cls,1)
            lm.cm_add(out_ry_cls,ry_cls, out_size_cls,size_cls) #可以均为NxK维，也可均为一维真值
            # print('cfm',out_ry_cls, ry_cls, sep='\n')
            if ii%opt.print_freq==opt.print_freq-1:
                lm.print_log(loss, loss_ry_cls,loss_ry_res, loss_size_cls,loss_ry_res, loss_loc)
            if ii%opt.plot_freq==opt.plot_freq-1:
                niter = epoch * batch_num + ii # training step; batch_num is always same
                lm.plot(writer,niter, total=True,ry=True,size=True,loc=True)
                
        # validate and visualize; lossMeter_val, benchmark_val
        lm_val, bm_val = val(model,val_dataloader,device,writer_val,niter,epoch)
        # aos_val, cm_ry_val, iou_val, iou_bev_val 
        if opt.save_model and (bm_val.aos > 0.9 or bm_val.iou_3d>0.7):
            model.save()
        # benchmark & lossmeter
        lm_val.plot(writer_val,niter, total=True,ry=True,size=True,loc=True) # curve obout val
        bm_val.plot(writer_val,niter, aos=True, iou_3d=True, iou_bev=True)
        bm_val.print_log(aos=True, iou_3d=True, iou_bev=True)
        # write text; print confusion_matrix:ry & size
        print_epoch_info(writer,epoch,niter, lr, lm,lm_val, bm_val)
        # update learning rate
        if lm.value()[0].item() > previous_loss: # .item change tensor to float         
            lr = lr * opt.lr_decay
            # 第二种降低学习率的方法:不会有moment等信息的丢失
            for param_group in optimizer.param_groups:
                param_group['lr'] = lr
        previous_loss = lm.value()[0].item()
        # if epoch < 5:
            # lr = lr * 0.5
            # for param_group in optimizer.param_groups:
                # param_group['lr'] = lr
    # model.save()
    print('Finishe training & save model')

def val(model,dataloader,dev,writer_val,niter,epoch):
    '''
    计算模型在验证集上的准确率等信息
    '''
    print('Evaluation of epoch:{} with batch_num:{}'.format(epoch,len(dataloader))
    model.eval()
    lm_val = LossMeter(total=True,ry=True,size=True,loc=True)
    bm_val = Benchmark(aos=True, iou_3d=True, iou_bev=True)
    
    crossEp = t.nn.CrossEntropyLoss()
    smoothL1 = t.nn.SmoothL1Loss() # !!! loss need to be determined match output & label
          
    for ii, (in_pc,in_center,in_rot,in_vec, loc,ry_cls,ry_res,size_cls,size_res) in enumerate(dataloader):
        with t.no_grad():
            [in_pc,in_center, in_rot, in_vec] = dev.trans(in_pc,in_center, in_rot, in_vec) # BCHW
            [loc, ry_cls,ry_res, size_cls,size_res] = \
                dev.trans(loc, ry_cls.squeeze(1),ry_res, size_cls.squeeze(1),size_res)# label
            
            out_loc, out_ry_cls,out_ry_res, out_size_cls,out_size_res = model(in_pc,in_vec)
            
            loss_loc = smoothL1(out_loc+in_center,loc)
            loss_ry_cls = crossEp(out_ry_cls,ry_cls) # output(BxNum)不需要加softmax, label(B)不需要为onehot
            loss_ry_res = smoothL1(out_ry_res,ry_res)
            loss_size_cls = crossEp(out_size_cls,size_cls)
            loss_size_res = smoothL1(out_size_res,size_res)
            loss = loss_loc + loss_ry_cls + loss_ry_res + loss_size_cls + loss_size_res 
            
            lm_val.lm_add(loss, loss_ry_cls,loss_ry_res, loss_size_cls,loss_size_res, loss_loc) # train_loss
            out_ry_cls = t.argmax(out_ry_cls,1)
            out_size_cls = t.argmax(out_size_cls,1)
            lm_val.cm_add(out_ry_cls, ry_cls, out_size_cls,size_cls) #可以均为NxK维，也可均为一维真值
            
            bsize = loc.shape[0] # num
            for i in range(bsize):
                out_ry = class2angle(out_ry_cls[i].item(), out_ry_res[i].item())
                ry = class2angle(ry_cls[i].item(), ry_res[i].item())
                out_size = class2size(out_size_cls[i].itemm(), out_size_res[i].item())
                size = class2size(size_cls[i].item(), size_res[i],item())
                # size = hwl
                size_p, ry_p, loc_p = tuple(out_size[i].tolist()), out_ry, tuple(out_loc[i].tolist()) # predict
                size_g, ry_g, loc_g = tuple(size[i].tolist()), ry, tuple(loc[i].tolist()) # groundtruth
                iou_3d, iou_bev = box_util.call_iou_compute_3d((size_g, ry_g, loc_g),(size_p, ry_p, loc_p))
                
                aos = (1+math.cos(out_ry-ry)) / 2
                bm_val.add(aos, iou_3d,iou_bev) # recall of aos =1; bev=bird eye's view
    model.train()
    return lm_val, bm_val
    

def help():
    '''
    打印帮助的信息： python file.py help
    '''
    print('''
    usage : python file.py <function> [--args=value]
    <function> := train | test | help
    example: 
            python {0} train --env='env0701' --lr=0.01
            python {0} test --dataset='path/to/dataset/root/'
            python {0} help
    avaiable args:'''.format(__file__))
    # from inspect import getsource
    # source = (getsource(opt.__class__))
    # print(source)

    
if __name__=='__main__':
#     import fire
#     fire.Fire()
    pass

SyntaxError: invalid syntax (<ipython-input-8-88b83a0deb95>, line 115)