In [1]:
import torch
import torch.nn as nn
import torchvision
from torchvision import models
import numpy as np
from torchvision import transforms
from torchvision.transforms import Compose
from torch.utils.data import DataLoader
from utils import bbox_collate
import yaml
import os
import json
import copy
from PIL import Image
import numpy as np
from pycocotools.coco import COCO
from torch.utils.data import Dataset
from utils import data2target
from torchvision.ops import roi_align
import copy
import numpy as np
import cv2
from imgaug import BoundingBoxesOnImage
from imgaug import augmenters as iaa
from imgaug import parameters as iap

In [2]:
config = yaml.safe_load(open('./config.yaml'))
dataset_means = json.load(open(config['dataset']['mean_file']))

In [3]:
class ROIPool(nn.Module):
    def __init__(self):
        super().__init__()
        self.max_pool = nn.AdaptiveMaxPool2d(7)
        
    def forward(self, inputs, rois):
        #rois: [torch(2000,4) x bs] 
        #output: bs, 2000, ch, h, w
        n = min(list(map(lambda x: x.shape[0], rois)))
        for i, tensor in enumerate(rois):
            rois[i] = rois[i][:n,:]
            rois[i].cuda()
        rois = torch.stack(rois, dim=0) #tensor (bs, n, 4)
        bs = rois.shape[0]
        x1 = rois[:,:,0]
        x2 = rois[:,:,2]
        y1 = rois[:,:,1]
        y2 = rois[:,:,3]
        h, w = inputs.shape[2], inputs.shape[3]
        x1 = np.floor(x1/512*w).type(torch.uint8)
        x2 = np.ceil(x2/512*w).type(torch.uint8)
        y1 = np.floor(y1/512*h).type(torch.uint8)
        y2 = np.ceil(y2/512*h).type(torch.uint8)
        
        res = []
        for batch in range(bs):
            for i in range(n):
                inp = inputs[batch, :, y1[batch,i]:y2[batch,i], x1[batch,i]:x2[batch,i]].unsqueeze(0) # 1, ch, h',w'
                inp = self.max_pool(inp)
                res.append(inp)
        res = torch.cat(res, dim=0) #batch*dim, ch, 7, 7
        return res

In [4]:
class _ROIPool(nn.Module):
    def __init__(self):
        super().__init__()
        
    def forward(self, inputs, rois):
        #rois: [torch(2000,4) x bs] 
        #output: bs, 2000, ch, h, w
        n = min(list(map(lambda x: x.shape[0], rois)))
        r = []
        for i, tensor in enumerate(rois):
            rois[i] = rois[i][:n,:]
            tmp = torch.zeros(n, 1)
            tmp[:] = i
            tmp = torch.cat([tmp, rois[i]], dim=1)
            r.append(tmp)
        r = torch.cat(r, dim=0)
        print(r.shape)
        r.cuda()
            
            
        h, w = inputs.shape[2], inputs.shape[3]
        res = roi_align(inputs, r, 7, spatial_scale=w/512)
        
        print(res.shape) #batch*dim, ch, 7, 7
        return res

In [5]:
class vector_extractor(nn.Module):
    """input: images, proposals
         output: feature vector"""
    def __init__(self):
        super().__init__()
        self.feature_map = nn.Sequential(*list(models.resnet50(pretrained=True).children())[:-2])
        self.roi_pool = ROIPool()
        self.gap = nn.AdaptiveAvgPool2d(1)
        self.feature_vector = nn.Sequential(nn.Linear(2048, 1000),
                                                                       nn.ReLU(inplace=True),
                                                                       nn.Linear(1000, 500),
                                                                       nn.ReLU(inplace=True))
    def forward(self, inputs, rois):
        f = self.feature_map(inputs)
        f = self.roi_pool(f, rois)
        f = self.gap(f).view(f.shape[0], f.shape[1]) #batch*proposal, ch
        f = self.feature_vector(f) 
        return f
    

In [6]:
class MIDN(nn.Module):
    """input: feature vector, labels
         output: scores per proposal, loss"""
    def __init__(self, bs):
        super().__init__()
        c_in = 500
        self.bs = bs
        self.layer_c = nn.Linear(c_in, 3)
        self.layer_d = nn.Linear(c_in, 3)
        self.softmax_c = nn.Softmax(dim=2)
        self.softmax_d = nn.Softmax(dim=1)
        self.loss = nn.BCELoss()
        
    def forward(self, inputs, labels):
        bs, proposal = self.bs, inputs.shape[0]//self.bs
        x_c = self.layer_c(inputs).view(bs, proposal, -1) #bs, proposal, 3
        x_d = self.layer_d(inputs).view(bs, proposal, -1)
        sigma_c = self.softmax_c(x_c)
        sigma_d = self.softmax_d(x_d)
        x_r = sigma_c * sigma_d #bs, proposal, 3
        phi_c = x_r.sum(dim=1) #bs, 3
        loss = self.loss(phi_c, labels)
        
        return x_r, loss
        
        
        

In [7]:
class ICR(nn.Module):
    """input: feature vector (bs*proposal, ch)
                     k-1th proposal scores(bs, proposal,3or4)
                     supervision (label) (bs)
                     ROI proposals
         output: refined proposal scores, loss"""
    def __init__(self, bs):
        super().__init__()
        c_in = 500
        self.I_t = 0.5
        self.bs = bs
        self.fc = nn.Linear(c_in, 4)
        self.softmax = nn.Softmax(dim=2)
        self.loss = nn.CrossEntropyLoss(reduction="none")
        """self.y_k = torch.zeros(bs, proposal, 4).cuda()
        self.y_k[:, :, 3] = 1
        self.w = torch.zeros(bs, proposal).cuda()"""
        
    def forward(self, inputs, pre_score, labels, rois):
        bs, proposal = self.bs, inputs.shape[0]//self.bs
        xr_k = self.fc(inputs).view(bs, proposal, -1) #bs, proposal, 4
        xr_k = self.softmax(xr_k)
        if pre_score == None:
            return xr_k
        _xr_k = xr_k.view(bs*proposal, -1)
        self.y_k = torch.zeros(bs, proposal, 4).cuda()
        self.y_k[:, :, 3] = 1
        self.w = torch.zeros(bs, proposal).cuda()
        I = torch.zeros(bs, proposal)
        for batch in range(bs):
            for c in range(3):
                if labels[batch][c]:
                    m = torch.max(pre_score[batch, :, c], 0)
                    x = m[0].item()
                    j = m[1].item()
                    for r in range(proposal):
                        _I = calc_iou(rois[batch][r], rois[batch][j])
                        if _I > I[batch, r]:
                            I[batch, r] = _I
                            self.w[batch, r] = x
                            if _I > self.I_t:
                                self.y_k[batch, r, c] = 1
                                self.y_k[batch, r, 3] = 0
        self.y_k = self.y_k.view(bs*proposal, -1)
        self.w = self.w.view(bs*proposal, 1)
        loss = self.loss(_xr_k.cuda().float(), torch.max(self.y_k, 1)[1])
        loss = torch.sum(self.w*loss)
        return xr_k, loss

In [8]:
def calc_iou(a, b):
    area = (b[:, 2] - b[:, 0]) * (b[:, 3] - b[:, 1])

    iw = torch.min(torch.unsqueeze(a[:, 2], dim=1), b[:, 2]) - torch.max(torch.unsqueeze(a[:, 0], 1), b[:, 0])
    ih = torch.min(torch.unsqueeze(a[:, 3], dim=1), b[:, 3]) - torch.max(torch.unsqueeze(a[:, 1], 1), b[:, 1])

    iw = torch.clamp(iw, min=0)
    ih = torch.clamp(ih, min=0)

    ua = torch.unsqueeze((a[:, 2] - a[:, 0]) * (a[:, 3] - a[:, 1]), dim=1) + area - iw * ih

    ua = torch.clamp(ua, min=1e-8)

    intersection = iw * ih

    IoU = intersection / ua

    return IoU

In [9]:
#TODO pseudo label を読み込むようにdataset と　dataloaderを改良
#TODO 各モジュールが正常に動くか逐次チェック
#TODO main()を作成してtrain

In [10]:
class MedicalBboxDataset(Dataset):
    def __init__(self, annotation, data_path, pseudo_path=None, transform=None):
        if isinstance(annotation, dict):
            self.coco = COCO()
            self.coco.dataset = annotation
            self.coco.createIndex()
        else:
            self.coco = COCO(annotation)
        self.data_path = data_path
        self.imgids = self.coco.getImgIds()
        self.set_transform(transform)
        self.load_classes()
        self.p_path = pseudo_path
        if pseudo_path != None:
            with open(pseudo_path, "r") as json_open:
                self.p_file = json.load(json_open)
            

    def load_classes(self):
        # load class names (name -> label)
        categories = self.coco.loadCats(self.coco.getCatIds())
        categories.sort(key=lambda x: x['id'])

        self.classes             = {}
        self.coco_labels         = {}
        self.coco_labels_inverse = {}
        for c in categories:
            self.coco_labels[len(self.classes)] = c['id']
            self.coco_labels_inverse[c['id']] = len(self.classes)
            self.classes[c['name']] = len(self.classes)

        # also load the reverse (label -> name)
        self.labels = {}
        for key, value in self.classes.items():
            self.labels[value] = key

    def __len__(self):
        """
        :return: Number of images
        :rtype: int
        """
        return len(self.imgids)
    
    def __getitem__(self, i):
        
        if isinstance(i, slice):
            imgids = self.imgids[i]
            return self.split_by_imgids(imgids)
        else:
            return self.transform({
                'img': self.load_image(i),
                **self.load_annotations(i)
            })
    
    def load_image(self, i):
        '''
        Args:
            i (int): Index of image
        
        Returns:
            numpy.ndarray: Selected image
        '''
        imgid = self.imgids[i]
        img_info = self.coco.loadImgs(imgid)[0]
        img_path = os.path.join(self.data_path, img_info['file_name'])
        img = Image.open(img_path)
        return np.array(img)
    
    def load_annotations(self, i):
        '''
        Args:
            i (int): Index of image
        
        Returns:
            dict: Annotation of the selected image
        '''
        imgid = self.imgids[i]
        annids = self.coco.getAnnIds(imgIds=imgid)
        anno_info = self.coco.loadAnns(annids)
        annotations     = np.zeros((0, 5))
        bboxes, labels = [], []
        
        for anno in anno_info:
            bboxes.append(anno['bbox'])
            label = self.coco.getCatIds().index(anno['category_id'])
            labels.append(label)

            if anno['bbox'][2] < 1 or anno['bbox'][3] < 1:
                continue

            annotation        = np.zeros((1, 5))
            annotation[0, :4] = anno['bbox']
            annotation[0, 4]  = self.coco_label_to_label(anno['category_id'])
            annotations       = np.append(annotations, annotation, axis=0)

        # transform from [x, y, w, h] to [x1, y1, x2, y2]
        annotations[:, 2] = annotations[:, 0] + annotations[:, 2]
        annotations[:, 3] = annotations[:, 1] + annotations[:, 3]
        
        bboxes = np.array(bboxes, dtype=np.float32).reshape(-1, 4)
        bboxes[:, 2:] += bboxes[:, :2]  # xywh -> xyxy
        labels = np.array(labels, dtype=np.int)
        p_bboxes = []
        if self.p_path != None:
            p_bboxes = self.p_file["pseudo_annotations"][0]["p_bbox2"] #[f"p_bbox{imgid}"]
        
        return {
            'annot' : annotations,
            'bboxes': bboxes,
            'labels': labels,
            'p_bboxes':p_bboxes
        }
    
    def set_transform(self, transform):
        '''
        Args:
            transform (function): Function to transform
        '''
        self.transform = transform if transform else lambda x: x
    
    def split(self, split, split_path):
        if not isinstance(split, (tuple, list, set)):
            split = split,
        split_data = json.load(open(split_path))

        imgids = []
        for s in split:
            imgids += split_data['image_id'][s]
        
        return self.split_by_imgids(imgids)
        
    def split_by_imgids(self, imgids):
        coco_format = {
            'info': self.coco.dataset['info'],
            'categories': self.coco.dataset['categories'],
            'images': self.coco.loadImgs(imgids),
            'annotations': self.coco.loadAnns(self.coco.getAnnIds(imgIds=imgids))
        }
        return MedicalBboxDataset(coco_format, self.data_path, self.p_path, self.transform)

    def integrate_classes(self, new_cats, idmap):
        annotations = copy.deepcopy(self.coco.dataset['annotations'])
        for anno in annotations:
            anno['category_id'] = idmap[anno['category_id']]

        coco_format = {
            'info': self.coco.dataset['info'],
            'categories': new_cats,
            'images': self.coco.dataset['images'],
            'annotations': annotations
        }
        return MedicalBboxDataset(coco_format, self.data_path, self.p_path, self.transform)

    def with_annotation_imgids(self):
        imgids = []
        for catid in self.coco.getCatIds():
            imgids += self.coco.getImgIds(catIds=catid)
        return imgids
    
    def with_annotation(self):
        imgids = self.with_annotation_imgids()
        return self.split_by_imgids(imgids)
    
    def without_annotation(self):
        imgids = list(set(self.imgids) - set(self.with_annotation_imgids()))
        return self.split_by_imgids(imgids)
    
    def get_coco(self):
        return self.coco

    def get_category_names(self):
        catids = self.coco.getCatIds()
        categories = self.coco.loadCats(catids)
        return [cat['name'] for cat in categories]
    def coco_label_to_label(self, coco_label):
        return self.coco_labels_inverse[coco_label]


    def label_to_coco_label(self, label):
        return self.coco_labels[label]

    def image_aspect_ratio(self, image_index):
        image = self.coco.loadImgs(self.imgids[image_index])[0]
        return float(image['width']) / float(image['height'])

In [11]:
class ToFixedSize:
    def __init__(self, size):
        self.size = size
    
    def __call__(self, data):
        data = copy.copy(data)
        
        raw_h, raw_w = data['img'].shape[:2]
        mag = min(self.size[0]/raw_h, self.size[1]/raw_w)
        h = round(raw_h * mag)
        w = round(raw_w * mag)
        
        image = data['img']
        data['img'] = np.zeros([*self.size, data['img'].shape[2]])
        data['img'][:h, :w] = cv2.resize(image, (w, h), interpolation=cv2.INTER_CUBIC)

        if data["p_bboxes"].size > 0:
            data['p_bboxes'][:, 0::2] *= w / raw_w
            data['p_bboxes'][:, 1::2] *= h / raw_h
            data['p_bboxes'][:, 0:2] = np.floor(data['p_bboxes'][:, 0:2])
            data['p_bboxes'][:, 2:4] = np.ceil(data['p_bboxes'][:, 2:4])
        else:
            data['bboxes'][:, 0::2] *= w / raw_w
            data['bboxes'][:, 1::2] *= h / raw_h
        
        
        return data


class Augmentation:
    def __init__(self, settings):
        seq = []
        
        def active(t):
            return t in settings and settings[t] != False
        
        # 回転, 左右反転, 上下反転で8パターン
        if active('flip'):
            seq += [
                iaa.Affine(rotate=iap.Binomial(0.5)*90),
                iaa.Fliplr(0.5),
                iaa.Flipud(0.5)
            ]
        if active('rotate_flip_shear'):
            seq += [
                iaa.Affine(rotate=iap.DiscreteUniform(-179,180),shear=(-10, 10)),
                iaa.Fliplr(0.5),
                iaa.Flipud(0.5)
            ]
        
        if active('gamma_per_channel'):
            low, high = settings['gamma_per_channel']
            seq.append(iaa.GammaContrast([low, high], per_channel=True))
            
        if active('gamma'):
            low, high = settings['gamma']
            seq.append(iaa.GammaContrast([low, high]))
        
        if active('gaussnoise'):
            intensity = settings['gaussnoise']
            seq.append(
                iaa.AdditiveGaussianNoise(loc=0, scale=(0.0, intensity), per_channel=1)
            )
        
        self.seq = iaa.Sequential(seq)
    
    def __call__(self, data):
        data = copy.copy(data)
        seq = self.seq.to_deterministic()
        image = data['img']
        data["p_bboxes"] = np.array(data["p_bboxes"])
        if data["p_bboxes"].size > 0:
            bboxes = BoundingBoxesOnImage.from_xyxy_array(data['p_bboxes'], shape=image.shape)
            image, bboxes = seq(image=image, bounding_boxes=bboxes)
            bboxes = bboxes.clip_out_of_image()
            data['p_bboxes'] = BoundingBoxesOnImage.to_xyxy_array(bboxes)
        else:
            bboxes = BoundingBoxesOnImage.from_xyxy_array(data['bboxes'], shape=image.shape)
            image, bboxes = seq(image=image, bounding_boxes=bboxes)
            bboxes = bboxes.clip_out_of_image()
            data['bboxes'] = BoundingBoxesOnImage.to_xyxy_array(bboxes)
        data['img'] = image
        

        return data


class Normalize:
    def __init__(self, mean, std):
        self.mean = mean
        self.std = std
    
    def __call__(self, data):
        data = copy.copy(data)
        data['img'] = (data['img'] - self.mean) / self.std
        return data

class UnNormalize:
    def __init__(self, mean, std):
        self.mean = mean
        self.std = std
    
    def __call__(self, data):
        data = copy.copy(data)
        data['img'] = data['img'].squeeze().permute(1,2,0).numpy()
        data['img'] = (data['img'] * self.std) + self.mean
        data['img'] = data['img'].astype(np.uint8)
        #data['img'] = data['img'].transpose(2,0,1)
        return data


class HWCToCHW:
    def __call__(self, data):
        data = copy.copy(data)
        data['img'] = data['img'].transpose(2, 0, 1)
        return data


In [12]:
class InfiniteSampler:
    '''
    与えられたLength内に収まる数値を返すIterator
    '''
    def __init__(self, length, random=True, generator=None):
        self.length = length
        self.random = random
        if random:
            self.generator = torch.Generator() if generator is None else generator
        self.stock = []
        
    def __iter__(self):
        while True:
            yield self.get(1)[0]
    
    def get(self, n):
        while len(self.stock) < n:
            self.extend_stock()
        
        indices = self.stock[:n]
        self.stock = self.stock[n:]
        
        return indices
        
    def extend_stock(self):
        if self.random:
            self.stock += torch.randperm(self.length, generator=self.generator).numpy().tolist()
        else:
            self.stock += list(range(self.length))


class MixedRandomSampler(torch.utils.data.sampler.Sampler):
    '''
    複数のデータセットを一定の比で混ぜながら、指定した長さだけIterationするSampler
    '''
    def __init__(self, datasets, length, ratio=None, generator=None):
        self.catdataset = torch.utils.data.ConcatDataset(datasets)
        self.length = length
        
        self.generator = torch.Generator() if generator is None else generator
        
        self.dataset_lengths = [len(dataset) for dataset in datasets]
        if ratio is None:
            self.ratio = torch.tensor(self.dataset_lengths, dtype=torch.float)
        else:
            self.ratio = torch.tensor(ratio, dtype=torch.float)
            
        self.samplers = [InfiniteSampler(l, generator=self.generator) for l in self.dataset_lengths]
    
    def __iter__(self):
        start_with = torch.cumsum(torch.tensor([0] + self.dataset_lengths), dim=0)
        selected = self.random_choice(self.ratio, self.length)
        
        indices = torch.empty(self.length, dtype=torch.int)
        
        for i in range(len(self.ratio)):
            mask = selected == i
            n_selected = mask.sum().item()
            indices[mask] = torch.tensor(self.samplers[i].get(n_selected), dtype=torch.int) + start_with[i]
        
        indices = indices.numpy().tolist()[0::1]
        
        return iter(indices)
    
    def __len__(self):
        return int(self.length)
    
    def get_concatenated_dataset(self):
        return self.catdataset
    
    def random_choice(self, p, size):
        random = torch.rand(size, generator=self.generator)
        bins = torch.cumsum(p / p.sum(), dim=0)
        choice = torch.zeros(size, dtype=torch.int)

        for i in range(len(p) - 1):
            choice[random > bins[i]] = i + 1

        return choice

In [13]:
p_path = "/data/unagi0/masaoka/endoscopy/annotations/pseudo_annotations.json"
config["batchsize"]=2

In [14]:
transform = Compose([
        Augmentation(config['augmentation']),
        ToFixedSize([config['inputsize']] * 2),  # inputsize x inputsizeの画像に変換
        Normalize(dataset_means['mean'], dataset_means['std']),
        HWCToCHW()
        ])

In [15]:
dataset_all = MedicalBboxDataset(
    config['dataset']['annotation_file'],
    config['dataset']['image_root'],
    pseudo_path=p_path)
if 'class_integration' in config['dataset']:
    dataset_all = dataset_all.integrate_classes(
        config['dataset']['class_integration']['new'],
        config['dataset']['class_integration']['map'])

loading annotations into memory...
Done (t=0.48s)
creating index...
index created!
creating index...
index created!


In [16]:
train_all = dataset_all.split(config['dataset']['train'], config['dataset']['split_file'])
train_all.set_transform(transform)
train_normal = train_all.without_annotation()
train_anomaly = train_all.with_annotation()
n_fg_class = len(dataset_all.get_category_names()) 

generator = torch.Generator()
generator.manual_seed(0)
sampler = MixedRandomSampler(
    [train_normal, train_anomaly],
    config["n_iteration"]*config["batchsize"] ,
    ratio=[config['negative_ratio'], 1],
    generator=generator)
batch_sampler = torch.utils.data.sampler.BatchSampler(sampler, config["batchsize"] , drop_last=False)

dataloader_train = DataLoader(
    sampler.get_concatenated_dataset(),
    num_workers=8,
    batch_sampler=batch_sampler,
    collate_fn=bbox_collate)

creating index...
index created!
creating index...
index created!
creating index...
index created!


In [17]:
for i in dataloader_train:
    print(i["p_bboxes"])
    break

[tensor([[  0.,   0., 512., 512.],
        [318., 474., 363., 512.],
        [126.,   1., 512., 491.],
        ...,
        [189., 168., 209., 185.],
        [203.,   0., 226.,  10.],
        [334., 294., 372., 335.]]), tensor([[  0.,   0., 512., 512.],
        [351.,   0., 392.,  51.],
        [138.,  45., 512., 512.],
        ...,
        [180., 320., 200., 336.],
        [171., 496., 193., 512.],
        [342., 191., 377., 230.]])]


In [18]:
i["p_bboxes"][0].shape

torch.Size([1988, 4])

In [19]:
i["p_bboxes"][1].shape

torch.Size([1970, 4])

In [20]:
v = vector_extractor()
v.cuda()
print("clear")

clear


In [21]:
f = v(i["img"].cuda().float(),i["p_bboxes"])

In [20]:
labels, n, t, v, u= data2target(i)

In [24]:
m = MIDN(config["batchsize"])
m.cuda()
print("clear")

clear


In [25]:
g = m(f,labels.cuda().float())

In [27]:
icr1 = ICR(config["batchsize"])
icr1.cuda()
print("clear")

clear


In [28]:
h = icr1(f, g[0], labels.cuda(), i["p_bboxes"])

In [29]:
icr2 = ICR(config["batchsize"])
icr2.cuda()
print("clear")

clear


In [30]:
j = icr2(f, h[0], labels.cuda(), i["p_bboxes"])

In [18]:
class OICR(nn.Module):
    def __init__(self,bs):
        super().__init__()
        self.v_extractor = vector_extractor()
        self.midn = MIDN(bs)
        self.icr1 = ICR(bs)
        self.icr2 = ICR(bs)
        self.icr3 = ICR(bs)
    
    def forward(self, inputs, labels, rois):
        if self.training:
            v = self.v_extractor(inputs, rois)
            x, midn_loss = self.midn(v, labels)
            x, loss1 = self.icr1(v, x, labels, rois)
            x, loss2 = self.icr2(v, x, labels, rois)
            x, loss3 = self.icr3(v, x, labels, rois) 
            loss = midn_loss + loss1 + loss2 + loss3
            return x, loss  
        else:
            v = self.v_extractor(inputs, rois)
            x = self.icr3(v, None, labels, rois) 
            x, rois = x[0], rois[0].cuda()
            s, i = torch.max(x, 1)
            sort = torch.argsort(s, descending=True)
            s, i = s.view(-1,1), i.view(-1,1).cuda().float()
            print(s.is_cuda, i.is_cuda, rois.is_cuda)
            cat = torch.cat([s, i ,rois], dim=1)
            cat = cat[sort, :]
            scores = cat[:, 0]
            labels = cat[:, 1]
            bboxes = cat[:, 2:]
            
            return scores, labels, bboxes

In [19]:
oicr = OICR(config["batchsize"])
oicr.cuda()
labels, n, t, v, u= data2target(i)

In [23]:
output = oicr(i["img"].cuda().float(),labels.cuda().float(), i["p_bboxes"])

In [20]:
oicr.eval()
print("clear")

clear


In [21]:
score, label, box = oicr(i["img"].cuda().float(),labels.cuda().float(), i["p_bboxes"])

True True True


In [25]:
box.shape

torch.Size([1964, 4])