In [1]:
import torch
import random
import numpy as np

from parsing.config import cfg
from parsing.utils.comm import to_device
from parsing.dataset import build_train_dataset
from parsing.detector import WireframeDetector
from parsing.solver import make_lr_scheduler, make_optimizer
from parsing.utils.logger import setup_logger
from parsing.utils.metric_logger import MetricLogger
from parsing.utils.miscellaneous import save_config
from parsing.utils.checkpoint import DetectronCheckpointer
import os
import time
import datetime
import argparse
import logging

class LossReducer(object):
    """Collapse a dictionary of individual loss terms into one weighted sum.

    The per-term weights are copied from ``cfg.MODEL.LOSS_WEIGHTS`` when the
    reducer is constructed.
    """

    def __init__(self,cfg):
        # dict(...) takes a private copy of the configured name -> weight mapping.
        self.loss_weights = dict(cfg.MODEL.LOSS_WEIGHTS)

    def __call__(self, loss_dict):
        """Return the sum of ``weight * loss_dict[name]`` over every configured weight.

        Entries of ``loss_dict`` without a configured weight are ignored; a
        configured name that is missing from ``loss_dict`` raises ``KeyError``.
        """
        weighted_terms = (
            weight * loss_dict[name]
            for name, weight in self.loss_weights.items()
        )
        return sum(weighted_terms)

In [2]:
# --- Configuration -----------------------------------------------------------
# Merge the experiment YAML into the project config and lock it so later cells
# cannot modify it by accident.
config_file_loc = '../config-files/hawp.yaml'
cfg.merge_from_file(config_file_loc)
cfg.freeze()

# --- Output directory and logging -------------------------------------------
# The directory is only created when missing; existing contents are kept
# (the former "wipe the directory" branch was permanently disabled).
output_dir = cfg.OUTPUT_DIR
if output_dir:
    os.makedirs(output_dir, exist_ok=True)

logger = setup_logger('hawp', output_dir, out_file='train.log')
logger.info("Loaded configuration file {}".format(config_file_loc))
logger.info("Running with config:\n{}".format(cfg))

# Store the effective configuration next to the training log.
output_config_path = os.path.join(output_dir, 'config.yml')
save_config(cfg, output_config_path)

# --- Reproducibility ---------------------------------------------------------
# One seed for every random number generator used by the notebook.
SEED = 2
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
if torch.cuda.is_available():
    # Seed every visible GPU, not only the current device.
    torch.cuda.manual_seed_all(SEED)

2023-07-03 14:32:09,045 hawp INFO: Loaded configuration file ../config-files/hawp.yaml
2023-07-03 14:32:09,047 hawp INFO: Running with config:
DATALOADER:
  NUM_WORKERS: 8
DATASETS:
  DISTANCE_TH: 0.02
  IMAGE:
    HEIGHT: 512
    PIXEL_MEAN: [109.73, 103.832, 98.681]
    PIXEL_STD: [22.275, 22.124, 23.229]
    TO_255: True
    WIDTH: 512
  NUM_STATIC_NEGATIVE_LINES: 40
  NUM_STATIC_POSITIVE_LINES: 300
  TARGET:
    HEIGHT: 128
    WIDTH: 128
  TEST: ('wireframe_test',)
  TRAIN: ('wireframe_train',)
  VAL: ('wireframe_test',)
ENCODER:
  ANG_TH: 0.1
  DIS_TH: 5
  NUM_STATIC_NEG_LINES: 40
  NUM_STATIC_POS_LINES: 300
MODEL:
  DEVICE: cpu
  HEAD_SIZE: [[3], [1], [1], [2], [2]]
  HGNETS:
    DEPTH: 4
    INPLANES: 64
    NUM_BLOCKS: 1
    NUM_FEATS: 128
    NUM_STACKS: 2
  LOSS_WEIGHTS:
    loss_dis: 1.0
    loss_jloc: 8.0
    loss_joff: 0.25
    loss_md: 1.0
    loss_neg: 1.0
    loss_pos: 1.0
    loss_res: 1.0
  NAME: Hourglass
  OUT_FEATURE_CHANNELS: 256
  PARSING_HEAD:
    DIM_FC: 1024


In [3]:
from parsing.detector import get_hawp_model

# Everything defined in this cell is module-level state that the training-loop
# cell reads: logger, device, model, train_dataset, optimizer, scheduler,
# loss_reducer, arguments, max_epoch, checkpointer, the timers, start_epoch,
# epoch_size and global_iteration.
logger = logging.getLogger("hawp.trainer")
device = cfg.MODEL.DEVICE

# The pretrained HAWP model is used here instead of a fresh WireframeDetector(cfg).
model = get_hawp_model(pretrained=True).to(device)

train_dataset = build_train_dataset(cfg)
optimizer = make_optimizer(cfg, model)
scheduler = make_lr_scheduler(cfg, optimizer)
loss_reducer = LossReducer(cfg)

# Bookkeeping: the current epoch always starts at 0 in this notebook.
max_epoch = cfg.SOLVER.MAX_EPOCH
arguments = {"epoch": 0, "max_epoch": max_epoch}

checkpointer = DetectronCheckpointer(
    cfg,
    model,
    optimizer,
    save_dir=cfg.OUTPUT_DIR,
    save_to_disk=True,
    logger=logger,
)
# The value returned by load() is not used.
_ = checkpointer.load()

# Wall-clock references for the timing meters of the training loop.
start_training_time = time.time()
end = time.time()

start_epoch = arguments['epoch']
# Number of items yielded per pass over train_dataset
# (presumably batches, if it is a DataLoader — TODO confirm).
epoch_size = len(train_dataset)
global_iteration = epoch_size * start_epoch

wireframe_train Hi ('wireframe_train',)
Hi 2
Hi 3 8
2023-07-03 14:32:19,867 hawp.trainer INFO: No checkpoint found. Initializing model from scratch


In [12]:
# Restart the wall-clock references here: they were first set in the setup
# cell, so any idle time between running the two cells would otherwise be
# counted as data-loading time for the first batch and as training time.
start_training_time = time.time()
end = time.time()

# CUDA memory statistics are only queried when a GPU is actually present
# (the configured device may be the CPU).
cuda_available = torch.cuda.is_available()

for epoch in range(start_epoch+1, arguments['max_epoch']+1):
    meters = MetricLogger(" ")
    model.train()
    arguments['epoch'] = epoch

    for it, (images, annotations) in enumerate(train_dataset):
        # Keep a handle on the raw batch annotations: the inspection cells
        # further down read `ann`.
        ann = annotations
        data_time = time.time() - end
        images = images.to(device)
        annotations = to_device(annotations, device)
        loss_dict, _ = model(images, annotations)
        total_loss = loss_reducer(loss_dict)

        # .item() yields plain Python floats for the meters.
        loss_dict_reduced = {k: v.item() for k, v in loss_dict.items()}
        loss_reduced = total_loss.item()
        meters.update(loss=loss_reduced, **loss_dict_reduced)

        optimizer.zero_grad()
        total_loss.backward()
        optimizer.step()
        global_iteration += 1

        batch_time = time.time() - end
        end = time.time()
        meters.update(time=batch_time, data=data_time)

        # Estimate of the batches still to process (formula kept unchanged),
        # multiplied by the average batch time.
        eta_batch = epoch_size*(max_epoch-epoch+1) - it + 1
        eta_seconds = meters.time.global_avg*eta_batch
        eta_string = str(datetime.timedelta(seconds=int(eta_seconds)))

        # Log every 20th batch and the last batch of each epoch.
        if it % 20 == 0 or it+1 == len(train_dataset):
            if cuda_available:
                max_memory_mb = torch.cuda.max_memory_allocated() / 1024.0 / 1024.0
            else:
                max_memory_mb = 0.0
            logger.info(
                meters.delimiter.join(
                    [
                        "eta: {eta}",
                        "epoch: {epoch}",
                        "iter: {iter}",
                        "{meters}",
                        "lr: {lr:.6f}",
                        "max mem: {memory:.0f}\n",
                    ]
                ).format(
                    eta=eta_string,
                    epoch=epoch,
                    iter=it,
                    meters=str(meters),
                    lr=optimizer.param_groups[0]["lr"],
                    memory=max_memory_mb,
                )
            )

    # One checkpoint and one scheduler step per epoch.
    checkpointer.save('model_{:05d}'.format(epoch))
    scheduler.step()

total_training_time = time.time() - start_training_time
total_time_str = str(datetime.timedelta(seconds=total_training_time))

# Average over the epochs that were actually run; never divide by zero.
epochs_run = max(arguments['max_epoch'] - start_epoch, 1)
logger.info(
    "Total training time: {} ({:.4f} s / epoch)".format(
        total_time_str, total_training_time / epochs_run
    )
)

Hi


NameError: name 'lmap' is not defined

In [15]:
# `ann` is the raw annotation batch captured inside the training-loop cell;
# list the fields stored for its first element.
ann[0].keys()

dict_keys(['junctions', 'height', 'filename', 'width', 'edges_negative', 'edges_positive', 'reminder'])

In [24]:
# Index pairs into `junctions`; `_process_per_image` turns them into the label-0 line samples.
ann[0]['edges_negative'] # 4000 rows (author's note)

tensor([[ 61,  85],
        [ 61,  99],
        [ 64,  99],
        ...,
        [ 99, 136],
        [ 79,  91],
        [117, 148]])

In [25]:
# Index pairs into `junctions`; `_process_per_image` turns them into the label-1 line samples.
ann[0]['edges_positive'] # 132 rows (author's note)

tensor([[  0,   1],
        [  2,   3],
        [  3,   4],
        [  5,   6],
        [  6,   7],
        [  8,   9],
        [  8,  10],
        [ 10,   7],
        [ 11,  12],
        [ 12,  13],
        [ 13,  14],
        [ 15,  16],
        [ 16,  17],
        [ 17,  18],
        [ 15,  18],
        [ 19,  20],
        [ 20,  21],
        [ 21,  22],
        [ 22,  23],
        [ 24,  25],
        [ 25,  26],
        [ 27,  28],
        [ 29,  30],
        [ 30,  31],
        [ 31,  32],
        [ 32,  33],
        [ 33,  34],
        [ 34,  35],
        [ 32,  36],
        [ 37,  38],
        [ 38,  39],
        [ 40,  41],
        [ 42,  40],
        [ 43,  41],
        [ 44,  45],
        [ 46,  47],
        [ 48,  49],
        [ 50,  51],
        [ 51,  52],
        [ 52,  53],
        [ 50,  53],
        [ 54,  55],
        [ 56,  57],
        [ 58,  59],
        [ 60,  61],
        [ 61,  62],
        [ 62,  63],
        [ 61,  64],
        [ 64,  65],
        [ 66,  67],


In [38]:
# Junction coordinates, one row per junction; `_process_per_image` reads column 0 as x and column 1 as y.
ann[0]['junctions'] # 149 rows (author's note)

tensor([[1.0691e+02, 1.6537e-01],
        [1.0493e+02, 1.1249e+02],
        [1.1320e+02, 6.6590e-01],
        [1.0900e+02, 7.8388e+00],
        [1.0572e+02, 1.2731e+02],
        [1.2208e+02, 6.6590e-01],
        [1.1237e+02, 1.1819e+01],
        [1.1259e+02, 4.0264e+01],
        [1.2785e+02, 4.2173e+00],
        [1.1251e+02, 1.8195e+01],
        [1.2667e+02, 3.8351e+01],
        [1.1259e+02, 4.3697e+01],
        [1.2622e+02, 4.2297e+01],
        [1.2519e+02, 7.4654e+01],
        [1.1251e+02, 6.5864e+01],
        [1.1202e+02, 6.8538e+01],
        [1.2489e+02, 7.8798e+01],
        [1.2384e+02, 1.0922e+02],
        [1.1112e+02, 8.9502e+01],
        [1.2282e+02, 1.2733e+02],
        [1.2356e+02, 1.1372e+02],
        [1.1128e+02, 9.3201e+01],
        [1.1039e+02, 1.1234e+02],
        [1.1675e+02, 1.2773e+02],
        [1.1483e+02, 1.2793e+02],
        [1.1039e+02, 1.1569e+02],
        [1.0965e+02, 1.2733e+02],
        [9.5149e+01, 7.3988e-02],
        [1.0156e-04, 1.0646e+01],
        [2.582

In [43]:
# Number of negative edge pairs stored for the first annotation.
ann[0]['edges_negative'].shape

torch.Size([4000, 2])

In [44]:
# Scratch check: ten uniform random samples (torch.rand is what `_process_per_image`
# thresholds at 0.5 to pick which lines get their endpoints swapped).
torch.rand(10)

tensor([0.7244, 0.0467, 0.8237, 0.3383, 0.5662, 0.7242, 0.8872, 0.2052, 0.3333,
        0.4157])

In [41]:
# One (x1, y1, x2, y2) row per positive edge: the coordinates of the edge's two
# junctions concatenated along the last dim — the same construction
# `_process_per_image` uses for `lines`.
torch.cat((ann[0]['junctions'][ann[0]['edges_positive'][:,0]], ann[0]['junctions'][ann[0]['edges_positive'][:,1]]),dim=-1)

tensor([[1.0691e+02, 1.6537e-01, 1.0493e+02, 1.1249e+02],
        [1.1320e+02, 6.6590e-01, 1.0900e+02, 7.8388e+00],
        [1.0900e+02, 7.8388e+00, 1.0572e+02, 1.2731e+02],
        [1.2208e+02, 6.6590e-01, 1.1237e+02, 1.1819e+01],
        [1.1237e+02, 1.1819e+01, 1.1259e+02, 4.0264e+01],
        [1.2785e+02, 4.2173e+00, 1.1251e+02, 1.8195e+01],
        [1.2785e+02, 4.2173e+00, 1.2667e+02, 3.8351e+01],
        [1.2667e+02, 3.8351e+01, 1.1259e+02, 4.0264e+01],
        [1.1259e+02, 4.3697e+01, 1.2622e+02, 4.2297e+01],
        [1.2622e+02, 4.2297e+01, 1.2519e+02, 7.4654e+01],
        [1.2519e+02, 7.4654e+01, 1.1251e+02, 6.5864e+01],
        [1.1202e+02, 6.8538e+01, 1.2489e+02, 7.8798e+01],
        [1.2489e+02, 7.8798e+01, 1.2384e+02, 1.0922e+02],
        [1.2384e+02, 1.0922e+02, 1.1112e+02, 8.9502e+01],
        [1.1202e+02, 6.8538e+01, 1.1112e+02, 8.9502e+01],
        [1.2282e+02, 1.2733e+02, 1.2356e+02, 1.1372e+02],
        [1.2356e+02, 1.1372e+02, 1.1128e+02, 9.3201e+01],
        [1.112

In [31]:
# Number of negative edge pairs stored for the third annotation of the batch.
ann[2]['edges_negative'].shape

torch.Size([4000, 2])

In [32]:
# Number of positive edge pairs stored for the third annotation of the batch.
ann[2]['edges_positive'].shape

torch.Size([98, 2])

In [33]:
# Largest junction index referenced by any negative edge; it must stay below
# the junction count for the indexing in `_process_per_image` to be valid.
ann[2]['edges_negative'].max()

tensor(122)

In [34]:
# Junction count of the third annotation (first dim), to compare against the edge indices.
ann[2]['junctions'].shape

torch.Size([123, 2])

In [35]:
def _process_per_image(self,ann):
        junctions = ann['junctions']
        device = junctions.device
        height, width = ann['height'], ann['width']
        jmap = torch.zeros((height,width),device=device)
        joff = torch.zeros((2,height,width),device=device,dtype=torch.float32)
        # junctions[:,0] = junctions[:,0].clamp(min=0,max=width-1)
        # junctions[:,1] = junctions[:,1].clamp(min=0,max=height-1)
        xint,yint = junctions[:,0].long(), junctions[:,1].long()
        off_x = junctions[:,0] - xint.float()-0.5
        off_y = junctions[:,1] - yint.float()-0.5

        jmap[yint,xint] = 1
        joff[0,yint,xint] = off_x
        joff[1,yint,xint] = off_y
        edges_positive = ann['edges_positive']
        edges_negative = ann['edges_negative']
        pos_mat = self.adjacent_matrix(junctions.size(0),edges_positive,device)
        neg_mat = self.adjacent_matrix(junctions.size(0),edges_negative,device)        
        lines = torch.cat((junctions[edges_positive[:,0]], junctions[edges_positive[:,1]]),dim=-1)
        lines_neg = torch.cat((junctions[edges_negative[:2000,0]],junctions[edges_negative[:2000,1]]),dim=-1)
#         lmap, _, _ = _C.encodels(lines,height,width,height,width,lines.size(0))
        # lmap, _, _ = encodels(lines,height,width,height,width,lines.size(0))


        lpos = np.random.permutation(lines.cpu().numpy())[:self.num_static_pos_lines]
        lneg = np.random.permutation(lines_neg.cpu().numpy())[:self.num_static_neg_lines]
        # lpos = lines[torch.randperm(lines.size(0),device=device)][:self.num_static_pos_lines]
        # lneg = lines_neg[torch.randperm(lines_neg.size(0),device=device)][:self.num_static_neg_lines]
        lpos = torch.from_numpy(lpos).to(device)
        lneg = torch.from_numpy(lneg).to(device)
        
        lpre = torch.cat((lpos,lneg),dim=0)
        _swap = (torch.rand(lpre.size(0))>0.5).to(device)
        lpre[_swap] = lpre[_swap][:,[2,3,0,1]]
        lpre_label = torch.cat(
            [
                torch.ones(lpos.size(0),device=device),
                torch.zeros(lneg.size(0),device=device)
             ])

        meta = {
            'junc': junctions,
            'Lpos':   pos_mat,
            'Lneg':   neg_mat,
            'lpre':      lpre,
            'lpre_label': lpre_label,
            'lines':     lines,
        }


#         dismap = torch.sqrt(lmap[0]**2+lmap[1]**2)[None]
        def _normalize(inp):
            mag = torch.sqrt(inp[0]*inp[0]+inp[1]*inp[1])
            return inp/(mag+1e-6)
#         md_map = _normalize(lmap[:2])
#         st_map = _normalize(lmap[2:4])
#         ed_map = _normalize(lmap[4:])

#         md_ = md_map.reshape(2,-1).t()
#         st_ = st_map.reshape(2,-1).t()
#         ed_ = ed_map.reshape(2,-1).t()
#         Rt = torch.cat(
#                 (torch.cat((md_[:,None,None,0],md_[:,None,None,1]),dim=2),
#                  torch.cat((-md_[:,None,None,1], md_[:,None,None,0]),dim=2)),dim=1)
#         R = torch.cat(
#                 (torch.cat((md_[:,None,None,0], -md_[:,None,None,1]),dim=2),
#                  torch.cat((md_[:,None,None,1], md_[:,None,None,0]),dim=2)),dim=1)

#         Rtst_ = torch.matmul(Rt, st_[:,:,None]).squeeze(-1).t()
#         Rted_ = torch.matmul(Rt, ed_[:,:,None]).squeeze(-1).t()
#         swap_mask = (Rtst_[1]<0)*(Rted_[1]>0)
#         pos_ = Rtst_.clone()
#         neg_ = Rted_.clone()
#         temp = pos_[:,swap_mask]
#         pos_[:,swap_mask] = neg_[:,swap_mask]
#         neg_[:,swap_mask] = temp

#         pos_[0] = pos_[0].clamp(min=1e-9)
#         pos_[1] = pos_[1].clamp(min=1e-9)
#         neg_[0] = neg_[0].clamp(min=1e-9)
#         neg_[1] = neg_[1].clamp(max=-1e-9)

#         mask = ((pos_[1]>self.ang_th)*(neg_[1]<-self.ang_th)*(dismap.view(-1)<=self.dis_th)).float()

#         pos_map = pos_.reshape(-1,height,width)
#         neg_map = neg_.reshape(-1,height,width)

#         md_angle  = torch.atan2(md_map[1], md_map[0])
#         pos_angle = torch.atan2(pos_map[1],pos_map[0])
#         neg_angle = torch.atan2(neg_map[1],neg_map[0])

#         pos_angle_n = pos_angle/(np.pi/2)
#         neg_angle_n = -neg_angle/(np.pi/2)
#         md_angle_n  = md_angle/(np.pi*2) + 0.5
#         mask    = mask.reshape(height,width)

#         hafm_ang = torch.cat((md_angle_n[None],pos_angle_n[None],neg_angle_n[None],),dim=0)
#         hafm_dis   = dismap.clamp(max=self.dis_th)/self.dis_th
#         mask = mask[None]
#         target = {'jloc':jmap[None],
#                 'joff':joff,
#                 'md': hafm_ang,
#                 'dis': hafm_dis,
#                 'mask': mask
#                }
        return meta

In [None]:
# Bare name: evaluates to (and displays) the function object defined in the cell above.
_process_per_image

In [None]:
    def forward_train(self, images, annotations = None):
        """Run one training forward pass and return the accumulated loss terms.

        Steps, as implemented below:
          1. ``self.hafm_encoder(annotations)`` yields ``targets`` and one
             ``meta`` dict per image.
          2. ``self.backbone(images)`` yields a sequence of ``outputs`` and
             ``features``. For every element of ``outputs`` the map losses
             (``loss_md``, ``loss_dis``, ``loss_res``, ``loss_jloc``,
             ``loss_joff``) are accumulated. Channel layout read from each
             output: ``[:3]`` md, ``[3:4]`` dis, ``[4:5]`` res, ``[5:7]``
             junction logits, ``[7:9]`` junction offsets.
          3. From ``outputs[0]`` only, line proposals are generated per image,
             snapped to predicted junctions, labelled via the ground-truth
             junction matching, sub-sampled, merged with the static samples in
             ``meta['lpre']`` and classified; ``loss_pos`` / ``loss_neg`` are
             the mean classification losses of the label-1 / label-0 samples,
             averaged over the batch.

        Args:
            images: input batch; its ``device`` is used for the random sampling.
            annotations: forwarded unchanged to ``self.hafm_encoder``.

        Returns:
            ``(loss_dict, extra_info)`` where ``extra_info`` is an empty dict.

        Side effects:
            Increments ``self.train_step``.
        """
        device = images.device

        targets , metas = self.hafm_encoder(annotations)

        self.train_step += 1

        outputs, features = self.backbone(images)

        loss_dict = {
            'loss_md': 0.0,
            'loss_dis': 0.0,
            'loss_res': 0.0,
            'loss_jloc': 0.0,
            'loss_joff': 0.0,
            'loss_pos': 0.0,
            'loss_neg': 0.0,
        }

        if targets is not None:
            # Read the mask only inside the guard: the previous code
            # dereferenced `targets` before checking it for None.
            mask = targets['mask']
            for output in outputs:
                loss_map = torch.mean(F.l1_loss(output[:,:3].sigmoid(), targets['md'],reduction='none'),dim=1,keepdim=True)
                loss_dict['loss_md']  += torch.mean(loss_map*mask) / torch.mean(mask)
                loss_map = F.l1_loss(output[:,3:4].sigmoid(), targets['dis'], reduction='none')
                loss_dict['loss_dis'] += torch.mean(loss_map*mask) /torch.mean(mask)
                # The residual head regresses the per-pixel distance error computed just above.
                loss_residual_map = F.l1_loss(output[:,4:5].sigmoid(), loss_map, reduction='none')
                loss_dict['loss_res'] += torch.mean(loss_residual_map*mask)/torch.mean(mask)
                loss_dict['loss_jloc'] += cross_entropy_loss_for_junction(output[:,5:7], targets['jloc'])
                loss_dict['loss_joff'] += sigmoid_l1_loss(output[:,7:9], targets['joff'], -0.5, targets['jloc'])

        loi_features = self.fc1(features)
        # Proposals are generated from the first backbone output only.
        output = outputs[0]
        md_pred = output[:,:3].sigmoid()
        dis_pred = output[:,3:4].sigmoid()
        res_pred = output[:,4:5].sigmoid()
        jloc_pred= output[:,5:7].softmax(1)[:,1:]
        joff_pred= output[:,7:9].sigmoid() - 0.5

        extra_info = {}

        batch_size = md_pred.size(0)

        for i, (md_pred_per_im, dis_pred_per_im,res_pred_per_im,meta) in enumerate(zip(md_pred, dis_pred,res_pred,metas)):
            # --- line proposals (optionally at distance -/0/+ residual) ---
            lines_pred = []
            if self.use_residual:
                for scale in [-1.0,0.0,1.0]:
                    proposals = self.proposal_lines(md_pred_per_im, dis_pred_per_im+scale*res_pred_per_im).view(-1, 4)
                    lines_pred.append(proposals)
            else:
                lines_pred.append(self.proposal_lines(md_pred_per_im, dis_pred_per_im).view(-1, 4))
            lines_pred = torch.cat(lines_pred)
            junction_gt = meta['junc']
            N = junction_gt.size(0)

            # --- snap both endpoints of every proposal to its nearest predicted junction ---
            juncs_pred, _ = get_junctions(non_maximum_suppression(jloc_pred[i]),joff_pred[i], topk=min(N*2+2,self.n_dyn_junc))
            dis_junc_to_end1, idx_junc_to_end1 = torch.sum((lines_pred[:,:2]-juncs_pred[:,None])**2,dim=-1).min(0)
            dis_junc_to_end2, idx_junc_to_end2 = torch.sum((lines_pred[:, 2:] - juncs_pred[:, None]) ** 2, dim=-1).min(0)

            # Keep proposals whose endpoints snap to two different junctions,
            # de-duplicate the junction pairs and add the reversed pairs.
            idx_junc_to_end_min = torch.min(idx_junc_to_end1,idx_junc_to_end2)
            idx_junc_to_end_max = torch.max(idx_junc_to_end1,idx_junc_to_end2)
            iskeep = idx_junc_to_end_min<idx_junc_to_end_max
            idx_lines_for_junctions = torch.cat((idx_junc_to_end_min[iskeep,None],idx_junc_to_end_max[iskeep,None]),dim=1).unique(dim=0)
            idx_lines_for_junctions_mirror = torch.cat((idx_lines_for_junctions[:,1,None],idx_lines_for_junctions[:,0,None]),dim=1)
            idx_lines_for_junctions = torch.cat((idx_lines_for_junctions, idx_lines_for_junctions_mirror))
            lines_adjusted = torch.cat((juncs_pred[idx_lines_for_junctions[:,0]], juncs_pred[idx_lines_for_junctions[:,1]]),dim=1)

            # --- label the adjusted lines through the ground-truth junction matching ---
            cost_, match_ = torch.sum((juncs_pred-junction_gt[:,None])**2,dim=-1).min(0)
            # Predicted junctions farther than 1.5 (squared: 2.25) from every
            # ground-truth junction get index N.
            # NOTE(review): presumably Lpos/Lneg carry an extra row/column N for
            # "unmatched" — confirm against adjacent_matrix.
            match_[cost_>1.5*1.5] = N
            Lpos = meta['Lpos']
            Lneg = meta['Lneg']
            end1_gt = match_[idx_lines_for_junctions[:,0]]
            end2_gt = match_[idx_lines_for_junctions[:,1]]
            labels = Lpos[end1_gt, end2_gt]

            # --- sub-sample which adjusted lines take part in training ---
            iskeep = torch.zeros_like(labels, dtype= torch.bool)
            cdx = labels.nonzero().flatten()

            if len(cdx) > self.n_dyn_posl:
                perm = torch.randperm(len(cdx),device=device)[:self.n_dyn_posl]
                cdx = cdx[perm]

            iskeep[cdx] = True

            if self.n_dyn_negl > 0:
                cdx = Lneg[end1_gt, end2_gt].nonzero().flatten()

                if len(cdx) > self.n_dyn_negl:
                    perm = torch.randperm(len(cdx), device=device)[:self.n_dyn_negl]
                    cdx = cdx[perm]

                iskeep[cdx] = True

            if self.n_dyn_othr > 0:
                # Random extra lines, drawn with replacement from all adjusted lines.
                cdx = torch.randint(len(iskeep), (self.n_dyn_othr,), device=device)
                iskeep[cdx] = True

            if self.n_dyn_othr2 >0 :
                cdx = (labels==0).nonzero().flatten()
                if len(cdx) > self.n_dyn_othr2:
                    perm = torch.randperm(len(cdx), device=device)[:self.n_dyn_othr2]
                    cdx = cdx[perm]
                iskeep[cdx] = True

            lines_selected = lines_adjusted[iskeep]
            labels_selected = labels[iskeep]

            # Dynamic samples are followed by the static samples prepared by the encoder.
            lines_for_train = torch.cat((lines_selected,meta['lpre']))
            labels_for_train = torch.cat((labels_selected.float(),meta['lpre_label']))

            logits = self.pooling(loi_features[i],lines_for_train)

            loss_ = self.loss(logits, labels_for_train)

            # NOTE(review): .mean() over an empty selection is NaN, so an image
            # without any label-1 (or label-0) sample would poison the total
            # loss — confirm the static samples always provide both.
            loss_positive = loss_[labels_for_train==1].mean()
            loss_negative = loss_[labels_for_train==0].mean()

            loss_dict['loss_pos'] += loss_positive/batch_size
            loss_dict['loss_neg'] += loss_negative/batch_size

        return loss_dict, extra_info


In [25]:
# Forward-only pass over the training batches (no backward pass, no optimizer
# step) to check that the model and the loss reduction run end to end.
for it, (images, annotations) in enumerate(train_dataset):
    # Print the real batch index; the former separate counter was never
    # incremented and therefore always printed 0.
    print(it)
    images = images.to(device)
    annotations = to_device(annotations, device)
    loss_dict, _ = model(images, annotations)
    total_loss = loss_reducer(loss_dict)


KeyboardInterrupt: 