In [None]:
import pickle
import torch.utils.data
from LSUV_pytorch.LSUV import LSUVinit
from config import cfg
from model.model import PPModel,PPScatter
from model.loss import PPLoss
from data.dataset import PPDataset
from tqdm import tqdm_notebook
import pdb
import pathlib
import os.path as osp
from evaluate import evaluate_single,box_nms,make_pred_boxes
import numpy as np
import matplotlib.pyplot as plt
from matplotlib import patches, patheffects
from pyquaternion import Quaternion
from utils.box_utils import boxes_to_image_space
from importlib import reload

In [None]:
# Render matplotlib figures inline in the notebook output.
%matplotlib inline
# Load the autoreload extension and reload all imported modules before each
# cell runs (mode 2), so edits to the project's .py files are picked up
# without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [None]:
from lyft_dataset_sdk.lyftdataset import LyftDataset
from lyft_dataset_sdk.utils.data_classes import LidarPointCloud,Box

In [None]:
# Feature-net input/output sizes, read from the project config.
fn_in, fn_out = cfg.NET.FEATURE_NET_IN, cfg.NET.FEATURE_NET_OUT

# Detection-head channel counts: the number of entries in ANCHOR_DIMS times
# the per-anchor class count and the per-anchor regression size respectively.
cls_channels, reg_channels = (
    len(cfg.DATA.ANCHOR_DIMS) * per_anchor
    for per_anchor in (cfg.DATA.NUM_CLASSES, cfg.DATA.REG_DIMS)
)

In [None]:
# Pickles that live in the training directories: the data dictionary and the token list.
ddfp = osp.join(cfg.DATA.LIDAR_TRAIN_DIR, 'data_dict.pkl')
token_fp = osp.join(cfg.DATA.TOKEN_TRAIN_DIR, 'token_list.pkl')

# The four anchor pickles all sit in the same anchor directory.
boxfp, crnfp, cenfp, xyfp = (
    osp.join(cfg.DATA.ANCHOR_DIR, fname)
    for fname in (
        'anchor_boxes.pkl',
        'anchor_corners.pkl',
        'anchor_centers.pkl',
        'anchor_xy.pkl',
    )
)

In [None]:
def _load_pickle(path):
    """Unpickle and return the object stored at ``path``.

    The file is opened in a ``with`` block so the handle is closed as soon as
    loading finishes, instead of being left for the garbage collector.
    """
    # NOTE: unpickling can execute arbitrary code; only load files you trust.
    with open(path, 'rb') as fh:
        return pickle.load(fh)


# Load order is unchanged from the original cell.
data_dict = _load_pickle(ddfp)
anchor_boxes = _load_pickle(boxfp)
anchor_corners = _load_pickle(crnfp)
anchor_centers = _load_pickle(cenfp)
anchor_xy = _load_pickle(xyfp)
# This file is resolved relative to the notebook's working directory.
data_mean = _load_pickle('pillar_means.pkl')
token_list = _load_pickle(token_fp)

In [None]:
# Build the dataset in training mode from the objects unpickled above.
pp_dataset = PPDataset(
    token_list,
    data_dict,
    anchor_boxes,
    anchor_corners,
    anchor_centers,
    data_mean,
    training=True,
)

In [None]:
# NOTE(review): these two assignments recompute the same values that the
# earlier configuration cell already bound to cls_channels / reg_channels;
# on a top-to-bottom run this cell is redundant and could be dropped.
cls_channels = len(cfg.DATA.ANCHOR_DIMS)*cfg.DATA.NUM_CLASSES
reg_channels = len(cfg.DATA.ANCHOR_DIMS)*cfg.DATA.REG_DIMS

In [None]:
# Instantiate the network and its loss. The trailing 'cpu' argument is
# presumably the target device for each -- confirm against PPModel / PPLoss.
pp_model = PPModel(
    fn_in,
    fn_out,
    cls_channels,
    reg_channels,
    'cpu',
)
pp_loss = PPLoss(
    cfg.NET.B_ORT,
    cfg.NET.B_REG,
    cfg.NET.B_CLS,
    cfg.NET.GAMMA,
    'cpu',
)

In [None]:
# Pull the first sample; the training loop unpacks the same four items as
# (pillars, indices, classification target, regression target).
p0, i0, c0, r0 = pp_dataset[0]

# Add a leading batch dimension. This is done out-of-place (rather than with
# unsqueeze_) so tensors the dataset may still hold are not mutated, and via
# assignment so the cell does not end on an expression that dumps a whole
# tensor into the notebook output.
p0 = p0.unsqueeze(dim=0)
i0 = i0.unsqueeze(dim=0)

In [None]:
def _lsuv(module, sample):
    """Run LSUV initialisation on ``module`` with ``sample`` as its input.

    Uses the same settings for every stage (needed_std=1, std_tol=0.1,
    max_attempts=10, no orthonormal pre-init) and returns whatever
    ``LSUVinit`` returns.
    """
    return LSUVinit(module, sample, needed_std=1, std_tol=0.1,
                    max_attempts=10, do_orthonorm=False)


# Initialise the network stage by stage: each stage is initialised on the
# output of the already-initialised stage before it.
pp_model.feature_net = _lsuv(pp_model.feature_net, p0)
feature_out = pp_model.feature_net(p0)
scatter_out = pp_model.scatter(feature_out, i0)
pp_model.backbone = _lsuv(pp_model.backbone, scatter_out)
backbone_out = pp_model.backbone(scatter_out)
pp_model.det_head = _lsuv(pp_model.det_head, backbone_out)


In [None]:
# Prior probability used to initialise the classification bias.
pi = 0.01
# Fill every classification bias with -log((1-pi)/pi), i.e. logit(pi), so a
# sigmoid applied to the bias alone gives pi.
# NOTE(review): this looks like the focal-loss style prior initialisation --
# confirm that the classification outputs are passed through a sigmoid.
pp_model.det_head.cls.bias.data.fill_(-np.log((1-pi)/pi))

In [None]:
# Learning rate and weight decay come from the project config.
lr, wd = cfg.NET.LEARNING_RATE, cfg.NET.WEIGHT_DECAY

# Adam over every parameter of the model.
params = list(pp_model.parameters())
optim = torch.optim.Adam(
    params,
    weight_decay=wd,
    lr=lr,
)

In [None]:
# Batches of 3 samples, served in dataset order by the main process.
dataloader = torch.utils.data.DataLoader(
    dataset=pp_dataset,
    batch_size=3,
    shuffle=False,
    num_workers=0,
)

In [None]:
i=0
num_epochs = 10
for e in range(num_epochs):
    # Put the model (and all of its sub-modules) into training mode before
    # optimising. Nothing earlier in the notebook does this, and the LSUV
    # initialisation step may have left the sub-modules in eval mode --
    # which would freeze BatchNorm statistics and disable dropout.
    pp_model.train()
    for i,(pill,ind,c_targ,r_targ) in enumerate(dataloader):
        # Progress / sanity output: batch index and mean of the input batch.
        print(i)
        print(torch.mean(pill))
        # Forward pass and loss; the loss returns the classification term,
        # the regression term and the total that is back-propagated.
        cls_tensor,reg_tensor = pp_model(pill,ind)
        c_loss,r_loss,batch_loss = pp_loss(cls_tensor,reg_tensor,c_targ,r_targ)
        # Clear stale gradients, back-propagate, and take one optimiser step.
        optim.zero_grad()
        batch_loss.backward()
        optim.step()
        print('total: ',batch_loss)
        print('cls: ',c_loss)
        print('reg: ',r_loss)

In [None]:
def draw_outline(o, lw):
    """Attach a black stroke of width ``lw`` plus the normal rendering to artist ``o``."""
    black_stroke = patheffects.Stroke(linewidth=lw, foreground='black')
    o.set_path_effects([black_stroke, patheffects.Normal()])

def draw_rect(ax, b, color='black'):
    """Add the closed, unfilled polygon ``b`` to ``ax`` and outline it.

    The polygon edge uses ``color`` with line width 2; the added patch is
    then passed to ``draw_outline`` with width 4.
    """
    polygon = patches.Polygon(b, closed=True, fill=False, edgecolor=color, lw=2)
    draw_outline(ax.add_patch(polygon), 4)


def plot_box(box):
    """Scatter-plot the bottom corners of ``box`` on a new figure.

    Only the first two coordinates of each corner are used (presumably x
    and y -- confirm against ``bottom_corners``).
    """
    _, axes = plt.subplots()
    corner_xy = box.bottom_corners().transpose([1, 0])[:, :2]
    first, second = corner_xy[:, 0], corner_xy[:, 1]
    axes.scatter(first, second)

def vis_pillars_corners(p, corners):
    """Show image ``p`` on a new figure and draw every polygon in ``corners`` over it."""
    _, axes = plt.subplots()
    axes.imshow(p)
    for polygon in corners:
        draw_rect(axes, polygon)

def vis_pillars(p, boxes):
    """Show image ``p`` on a new figure and outline the bottom face of each box.

    For every box, the first two coordinate rows of ``bottom_corners()`` are
    transposed into one row per corner and passed to ``draw_rect``.
    """
    _, axes = plt.subplots()
    axes.imshow(p)
    for item in boxes:
        footprint = item.bottom_corners()[:2, :].transpose([1, 0])
        draw_rect(axes, footprint)