In [1]:
import torch
import torch.nn as nn
import torchvision
from torchvision import models
import numpy as np
from torchvision import transforms
from torchvision.transforms import Compose
from torch.utils.data import DataLoader
from utils import bbox_collate
import yaml
import os
import json
import copy
from PIL import Image
import numpy as np
from pycocotools.coco import COCO
from torch.utils.data import Dataset
from utils import *
from torchvision.ops import roi_align
from torchvision.ops.boxes import box_iou
import copy
import numpy as np
import cv2
from make_dloader import make_data
import time
from matplotlib import pyplot as plt
import math

In [2]:
config = yaml.safe_load(open('./config.yaml'))
dataset_means = json.load(open(config['dataset']['mean_file']))

In [3]:
dataset_val = torch.load(f'/data/unagi0/masaoka/val_all1.pt')

In [4]:
dataloader_val = DataLoader(dataset_val, batch_size=16, shuffle=False, 
                                num_workers=4, collate_fn=bbox_collate)

In [27]:
import torch
import torch.nn as nn
from torchvision.ops import roi_align, nms
from utils import calc_iou
import torchvision.models as models
import yaml
from torchvision.ops.boxes import box_iou
import torch.nn.functional as F
import time

config = yaml.safe_load(open('./config.yaml'))

class focalloss(nn.Module):
    def forward(self, classifications, regressions, anchors, annotations):
        alpha = 0.25
        gamma = 2.0
        batch_size = classifications.shape[0]
        classification_losses = []
        regression_losses = []

        anchor = anchors[0, :, :]

        anchor_widths  = anchor[:, 2] - anchor[:, 0]
        anchor_heights = anchor[:, 3] - anchor[:, 1]
        anchor_ctr_x   = anchor[:, 0] + 0.5 * anchor_widths
        anchor_ctr_y   = anchor[:, 1] + 0.5 * anchor_heights

        for j in range(batch_size):

            classification = classifications[j, :, :]
            regression = regressions[j, :, :]

            bbox_annotation = annotations[j][ :, :]
            bbox_annotation = bbox_annotation[bbox_annotation[:, 4] != -1]

            if bbox_annotation.shape[0] == 0:
                if torch.cuda.is_available():
                    regression_losses.append(torch.tensor(0).float().cuda())
                    classification_losses.append(torch.tensor(0).float().cuda())
                else:
                    regression_losses.append(torch.tensor(0).float())
                    classification_losses.append(torch.tensor(0).float())

                continue

            classification = torch.clamp(classification, 1e-4, 1.0 - 1e-4)

            IoU = calc_iou(anchors[0, :, :], bbox_annotation[:, :4]) # num_anchors x num_annotations

            IoU_max, IoU_argmax = torch.max(IoU, dim=1) # num_anchors x 1

            #import pdb
            #pdb.set_trace()

            # compute the loss for classification
            targets = torch.ones(classification.shape) * -1

            if torch.cuda.is_available():
                targets = targets.cuda()

            targets[torch.lt(IoU_max, 0.4), :] = 0

            positive_indices = torch.ge(IoU_max, 0.5)

            num_positive_anchors = positive_indices.sum()

            assigned_annotations = bbox_annotation[IoU_argmax, :]

            targets[positive_indices, :] = 0
            targets[positive_indices, assigned_annotations[positive_indices, 4].long()] = 1

            if torch.cuda.is_available():
                alpha_factor = torch.ones(targets.shape).cuda() * alpha
            else:
                alpha_factor = torch.ones(targets.shape) * alpha

            alpha_factor = torch.where(torch.eq(targets, 1.), alpha_factor, 1. - alpha_factor)
            focal_weight = torch.where(torch.eq(targets, 1.), 1. - classification, classification)
            focal_weight = alpha_factor * torch.pow(focal_weight, gamma)

            bce = -(targets * torch.log(classification) + (1.0 - targets) * torch.log(1.0 - classification))

            # cls_loss = focal_weight * torch.pow(bce, gamma)
            cls_loss = focal_weight * bce

            if torch.cuda.is_available():
                cls_loss = torch.where(torch.ne(targets, -1.0), cls_loss, torch.zeros(cls_loss.shape).cuda())
            else:
                cls_loss = torch.where(torch.ne(targets, -1.0), cls_loss, torch.zeros(cls_loss.shape))

            classification_losses.append(cls_loss.sum()/torch.clamp(num_positive_anchors.float(), min=1.0))

            # compute the loss for regression

            if positive_indices.sum() > 0:
                assigned_annotations = assigned_annotations[positive_indices, :]

                anchor_widths_pi = anchor_widths[positive_indices]
                anchor_heights_pi = anchor_heights[positive_indices]
                anchor_ctr_x_pi = anchor_ctr_x[positive_indices]
                anchor_ctr_y_pi = anchor_ctr_y[positive_indices]

                gt_widths  = assigned_annotations[:, 2] - assigned_annotations[:, 0]
                gt_heights = assigned_annotations[:, 3] - assigned_annotations[:, 1]
                gt_ctr_x   = assigned_annotations[:, 0] + 0.5 * gt_widths
                gt_ctr_y   = assigned_annotations[:, 1] + 0.5 * gt_heights

                # clip widths to 1
                gt_widths  = torch.clamp(gt_widths, min=1)
                gt_heights = torch.clamp(gt_heights, min=1)

                targets_dx = (gt_ctr_x - anchor_ctr_x_pi) / anchor_widths_pi
                targets_dy = (gt_ctr_y - anchor_ctr_y_pi) / anchor_heights_pi
                targets_dw = torch.log(gt_widths / anchor_widths_pi)
                targets_dh = torch.log(gt_heights / anchor_heights_pi)

                targets = torch.stack((targets_dx, targets_dy, targets_dw, targets_dh))
                targets = targets.t()

                if torch.cuda.is_available():
                    targets = targets/torch.Tensor([[0.1, 0.1, 0.2, 0.2]]).cuda()
                else:
                    targets = targets/torch.Tensor([[0.1, 0.1, 0.2, 0.2]])

                negative_indices = 1 + (~positive_indices)

                regression_diff = torch.abs(targets - regression[positive_indices, :])

                regression_loss = torch.where(
                    torch.le(regression_diff, 1.0 / 9.0),
                    0.5 * 9.0 * torch.pow(regression_diff, 2),
                    regression_diff - 0.5 / 9.0
                )
                regression_losses.append(regression_loss.mean().float())
            else:
                if torch.cuda.is_available():
                    regression_losses.append(torch.tensor(0).float().cuda())
                else:
                    regression_losses.append(torch.tensor(0).float())
            

        return torch.stack(classification_losses).mean(dim=0, keepdim=True), torch.stack(regression_losses).mean(dim=0, keepdim=True)


class PyramidFeatures(nn.Module):
    def __init__(self, C3_size, C4_size, C5_size, feature_size=256):
        super(PyramidFeatures, self).__init__()

        # upsample C5 to get P5 from the FPN paper
        self.P5_1 = nn.Conv2d(C5_size, feature_size, kernel_size=1, stride=1, padding=0)
        self.P5_upsampled = nn.Upsample(scale_factor=2, mode='nearest')
        self.P5_2 = nn.Conv2d(feature_size, feature_size, kernel_size=3, stride=1, padding=1)

        # add P5 elementwise to C4
        self.P4_1 = nn.Conv2d(C4_size, feature_size, kernel_size=1, stride=1, padding=0)
        self.P4_upsampled = nn.Upsample(scale_factor=2, mode='nearest')
        self.P4_2 = nn.Conv2d(feature_size, feature_size, kernel_size=3, stride=1, padding=1)

        # add P4 elementwise to C3
        self.P3_1 = nn.Conv2d(C3_size, feature_size, kernel_size=1, stride=1, padding=0)
        self.P3_2 = nn.Conv2d(feature_size, feature_size, kernel_size=3, stride=1, padding=1)

        # "P6 is obtained via a 3x3 stride-2 conv on C5"
        self.P6 = nn.Conv2d(C5_size, feature_size, kernel_size=3, stride=2, padding=1)

        # "P7 is computed by applying ReLU followed by a 3x3 stride-2 conv on P6"
        self.P7_1 = nn.ReLU()
        self.P7_2 = nn.Conv2d(feature_size, feature_size, kernel_size=3, stride=2, padding=1)

    def forward(self, inputs):
        C3, C4, C5 = inputs

        P5_x = self.P5_1(C5)
        P5_upsampled_x = self.P5_upsampled(P5_x)
        P5_x = self.P5_2(P5_x)

        P4_x = self.P4_1(C4)
        P4_x = P5_upsampled_x + P4_x
        P4_upsampled_x = self.P4_upsampled(P4_x)
        P4_x = self.P4_2(P4_x)

        P3_x = self.P3_1(C3)
        P3_x = P3_x + P4_upsampled_x
        P3_x = self.P3_2(P3_x)

        P6_x = self.P6(C5)

        P7_x = self.P7_1(P6_x)
        P7_x = self.P7_2(P7_x)

        return [P3_x, P4_x, P5_x, P6_x, P7_x]


class RegressionModel(nn.Module):
    def __init__(self, num_features_in, num_anchors=9, feature_size=256):
        super(RegressionModel, self).__init__()

        self.conv1 = nn.Conv2d(num_features_in, feature_size, kernel_size=3, padding=1)
        self.act1 = nn.ReLU()

        self.conv2 = nn.Conv2d(feature_size, feature_size, kernel_size=3, padding=1)
        self.act2 = nn.ReLU()

        self.conv3 = nn.Conv2d(feature_size, feature_size, kernel_size=3, padding=1)
        self.act3 = nn.ReLU()

        self.conv4 = nn.Conv2d(feature_size, feature_size, kernel_size=3, padding=1)
        self.act4 = nn.ReLU()

        self.output = nn.Conv2d(feature_size, num_anchors * 4, kernel_size=3, padding=1)

    def forward(self, x):
        out = self.conv1(x)
        out = self.act1(out)

        out = self.conv2(out)
        out = self.act2(out)

        out = self.conv3(out)
        out = self.act3(out)

        out = self.conv4(out)
        out = self.act4(out)

        out = self.output(out)

        # out is B x C x W x H, with C = 4*num_anchors
        out = out.permute(0, 2, 3, 1)

        return out.contiguous().view(out.shape[0], -1, 4)
    
class ClassificationModel(nn.Module):
    def __init__(self, num_features_in, num_anchors=9, num_classes=80, prior=0.01, feature_size=256):
        super(ClassificationModel, self).__init__()

        self.num_classes = num_classes
        self.num_anchors = num_anchors

        self.conv1 = nn.Conv2d(num_features_in, feature_size, kernel_size=3, padding=1)
        self.act1 = nn.ReLU()

        self.conv2 = nn.Conv2d(feature_size, feature_size, kernel_size=3, padding=1)
        self.act2 = nn.ReLU()

        self.conv3 = nn.Conv2d(feature_size, feature_size, kernel_size=3, padding=1)
        self.act3 = nn.ReLU()

        self.conv4 = nn.Conv2d(feature_size, feature_size, kernel_size=3, padding=1)
        self.act4 = nn.ReLU()

        self.output = nn.Conv2d(feature_size, num_anchors * num_classes, kernel_size=3, padding=1)
        self.output_act = nn.Sigmoid()

    def forward(self, x):
        out = self.conv1(x)
        out = self.act1(out)

        out = self.conv2(out)
        out = self.act2(out)

        out = self.conv3(out)
        out = self.act3(out)

        out = self.conv4(out)
        out = self.act4(out)

        out = self.output(out)
        out = self.output_act(out)

        # out is B x C x W x H, with C = n_classes + n_anchors
        out1 = out.permute(0, 2, 3, 1)

        batch_size, width, height, channels = out1.shape

        out2 = out1.view(batch_size, width, height, self.num_anchors, self.num_classes)

        return out2.contiguous().view(x.shape[0], -1, self.num_classes)
    
class FocalLoss(nn.Module):
    def __init__(self, gamma=2, eps=1e-7):
        super().__init__()
        self.gamma = gamma
        self.eps = eps

    def forward(self, input, target):
        y = target #bs*proposal,4
        
        logit = F.softmax(input,dim=-1) #bs*proposal,4
        logit = logit.clamp(self.eps, 1. - self.eps)
        loss = -1 * y * torch.log(logit) # cross entropy
        loss = loss * ((1 - logit) ** self.gamma) # focal loss

        return loss


class ASPP(nn.Module):
    def  __init__(self,size_list=[]):
        super().__init__()
        assert len(size_list)>0
        self.avgpool_list = []
        for size in size_list:
            self.avgpool_list.append(nn.Sequential(nn.AdaptiveAvgPool2d(size),
                                                    nn.Flatten()))
        
    def forward(self,x):
        vec = []
        for avgpool in self.avgpool_list:
            vec.append(avgpool(x))
        vec = torch.cat(vec,1)
        return vec

class _ROIPool(nn.Module):
    def __init__(self):
        super().__init__()
        
    def forward(self, inputs, rois):
        #rois: bs,2000,4 
        #output: bs, 2000, ch, h, w
        rois = [r for r in rois]
        h, w = inputs.shape[2], inputs.shape[3]
        res = roi_align(inputs, rois, 7, spatial_scale=w/512)
        return res

class vector_extractor(nn.Module):
    """input: images, proposals
         output: feature vector"""
    def __init__(self):
        super().__init__()
        model = models.resnet18(pretrained=True)
        layers = list(model.children())[:-2]
        self.feature_map = nn.Sequential(*layers)
        self.roi_pool = _ROIPool()
        #self.gap = nn.AdaptiveAvgPool2d(1)
        self.feature_vector = nn.Sequential(ASPP([1,2]),
                                            #nn.AdaptiveAvgPool2d(1),
                                            nn.Flatten(),
                                            nn.Linear(512*(5), 2048),
                                            nn.ReLU(inplace=True),
                                            nn.BatchNorm1d(2048),
                                            nn.Linear(2048, 512),
                                            nn.BatchNorm1d(512),
                                            nn.ReLU(inplace=True))
    def forward(self, inputs, rois):
        f = self.feature_map(inputs)
        f = self.roi_pool(f, rois)
        #f = self.gap(f).view(f.shape[0], f.shape[1]) #batch*proposal, ch
        f = self.feature_vector(f) 
        return f
class vector_extractor_pyramid(nn.Module):
    """input: images, proposals
         output: feature vector"""
    def __init__(self):
        super().__init__()
        model = models.resnet18(pretrained=True)
        self.base = nn.Sequential(*list(model.children())[:4])
        self.layer1 = list(model.children())[4]
        self.layer2 = list(model.children())[5]
        self.layer3 = list(model.children())[6]
        self.layer4 = list(model.children())[7]
        self.roi_pool = _ROIPool()
        #self.gap = nn.AdaptiveAvgPool2d(1)
        self.feature_vector = nn.Sequential(ASPP([1,2]),
                                            #nn.AdaptiveAvgPool2d(1),
                                            nn.Flatten(),
                                            nn.Linear(512*(5), 2048),
                                            nn.ReLU(inplace=True),
                                            nn.BatchNorm1d(2048),
                                            nn.Linear(2048, 512),
                                            nn.BatchNorm1d(512),
                                            nn.ReLU(inplace=True))
    def forward(self, inputs, rois):
        x = self.base(inputs)
        x1 = self.layer1(x)
        x2 = self.layer2(x1)
        x3 = self.layer3(x2)
        x4 = self.layer4(x3)
        f = self.roi_pool(x4, rois)
        #f = self.gap(f).view(f.shape[0], f.shape[1]) #batch*proposal, ch
        f = self.feature_vector(f) 
        return f,x2,x3,x4
    
class MIDN(nn.Module):
    """input: feature vector, labels
         output: scores per proposal, loss"""
    def __init__(self):
        super().__init__()
        c_in = 512
        self.layer_c = nn.Linear(c_in, 4)
        self.layer_d = nn.Linear(c_in, 4)
        self.softmax_c = nn.Softmax(dim=2)
        self.softmax_d = nn.Softmax(dim=1)
        self.loss = nn.BCELoss()
        self.upper = 1-1e-8
        self.lower = 1e-8
    def forward(self, inputs, labels, num):
        bs, proposal = inputs.shape[0]//num, num
        x_c = self.layer_c(inputs).view(bs, proposal, -1) #bs, proposal, 4
        x_d = self.layer_d(inputs).view(bs, proposal, -1)
        sigma_c = self.softmax_c(x_c)
        sigma_d = self.softmax_d(x_d)
        x_r = sigma_c * sigma_d #bs, proposal, 4
        phi_c = x_r.sum(dim=1) #bs, 4
        phi_c = torch.clamp(phi_c, self.lower,self.upper)
        scaled = x_r/(torch.max(x_r,1)[0].unsqueeze(1)+1e-8)*phi_c.unsqueeze(1)
        loss = self.loss(phi_c, labels)
        print(loss)
        if not self.training:
            #print(f'scaled = {scaled.data.cpu().numpy()}')
            return x_r#scaled#
        return  scaled, loss#x_r,loss#phi_c.unsqueeze(1), loss #sigma_c,loss#
        

class ICR(nn.Module):
    """input: feature vector (bs*proposal, ch)
                     k-1th proposal scores(bs, proposal,3or4)
                     supervision (label) (bs)
                     ROI proposals
         output: refined proposal scores, loss"""
    def __init__(self):
        super().__init__()
        c_in = 512
        self.I_t = 0.1
        self.fc = nn.Linear(c_in, 4)
        self.softmax = nn.Softmax(dim=2)
        self.fl = FocalLoss(gamma=2)
        self.one = nn.Parameter(torch.tensor([1.])).requires_grad_(False)
        self.zero = nn.Parameter(torch.tensor([0.])).requires_grad_(False)
        self.d = self.zero.device
        
    def forward(self, inputs, pre_score, labels, rois, num):
        bs, proposal = inputs.shape[0]//num, num
        pre_score = pre_score.clone().detach()
        xr_k = self.fc(inputs).view(bs, proposal, -1) #bs, proposal, 4
        logit = self.softmax(xr_k)
        _xr_k = xr_k.view(bs*proposal, -1)
        y_k = torch.zeros(bs, proposal, 4).cuda()
        y_k[:, :, 3] = 1
        
        w = torch.stack([torch.cat([self.zero]*proposal)]*bs,0)
        I = torch.stack([torch.cat([self.zero]*proposal)]*bs,0)
        for batch in range(bs):
            for c in range(4):
                if labels[batch][c]:
                    label = c
                    j_list = (pre_score[batch,:,c]>0.5).nonzero().squeeze()
                    x_list = pre_score[batch,j_list,c]
                    if (j_list).size() == torch.Size([0]) or (j_list).size() == torch.Size([1]) or (j_list).size() == torch.Size([]):
                        m = torch.max(pre_score[batch, :, c], 0)
                        x = m[0].item()
                        j = m[1].item()
                        x_list = [x]
                        j_list =[j]
                    mat = box_iou(rois[batch], rois[batch][j_list])
                    for i,j in enumerate(j_list):
                        _I = mat[:,i]
                        old = I[batch].clone()
                        I[batch] = torch.where(_I > old,_I,old)
                        pre = w[batch].clone()
                        w[batch] = torch.where(_I > old,self.one*x_list[i],pre)
                        p = y_k[batch, :, 3].clone()
                        y_k[batch, :, 3] = torch.where(_I > self.I_t,self.zero,p)
                        q = y_k[batch, :, c].clone()
                        y_k[batch, :, c] = torch.where(_I > self.I_t,self.one,q)
                        
                
        y_k = y_k.view(bs*proposal, -1)
        w = w.view(bs*proposal, 1)
        
        imbalance_list = y_k.float().sum(dim=0)
        bg_num = (w!=0.).sum()-imbalance_list[:-1].sum()
        imbalance_list[-1] = bg_num
        loss = self.fl(_xr_k.float(), y_k)
        loss = (w*loss)/(imbalance_list+1e-7)#bs*proposal,4
        loss = loss.view(bs,proposal,-1)
        w_ = torch.exp(logit)#bs,proposal,4
        lab = labels.unsqueeze(1)#bs,1,4
        
        mask = 1-lab
        w_ = w_*mask+lab
        loss = loss*w_
        loss = loss.sum()/bs
        print(loss)
        if not self.training:
            return logit
        return logit, loss

class OICR(nn.Module):
    def __init__(self):
        super().__init__()
        self.v_extractor = vector_extractor()
        self.midn = MIDN()
        self.icr1 = ICR()
        self.icr2 = ICR()
        self.icr3 = ICR()
        
    def forward(self, inputs, labels, rois, num):
        if self.training:
            labels = labels.squeeze()
            rois = rois.squeeze()
            v = self.v_extractor(inputs, rois)
            x, midn_loss = self.midn(v, labels, num)
            x, loss1 = self.icr1(v, x, labels, rois, num)
            x, loss2 = self.icr2(v, x, labels, rois, num)
            x, loss3 = self.icr3(v, x, labels, rois, num) 
            loss = midn_loss + (loss1 + loss2 + loss3)
            return x, loss.unsqueeze(0),midn_loss.unsqueeze(0),loss1.unsqueeze(0),loss2.unsqueeze(0),loss3.unsqueeze(0)
        else:
            self.three = torch.tensor([3.]).cuda()
            self.inf = torch.tensor([-1.]).cuda()
            #rois = rois.squeeze()
            v = self.v_extractor(inputs, rois)
            x = self.midn(v, labels, num)
            print('---------------------------------------------------')
            print('icr1')
            x = self.icr1(v, x, labels, rois, num) 
            print('---------------------------------------------------')
            print('icr2')
            x = self.icr2(v, x, labels, rois, num) 
            print('---------------------------------------------------')
            print('icr3')
            x = self.icr3(v, x, labels, rois, num) 
            x, rois = x[0], rois[0].cuda()
            rois = rois[torch.max(x,1)[1]!=3]
            x = x[torch.max(x,1)[1]!=3]
            print(f'x={x}')
            if x.size() == torch.Size([0,4]):
                return torch.tensor([]),torch.tensor([]),torch.tensor([])
            #label = torch.max(labels,1)[1]
            #score = x[:,label].squeeze()
            #index = nms(rois,score,0.5)
            #scores = score[index]
           # bboxes = rois[index]
            #labels = torch.cat([label]*(len(index)))
            classes = torch.max(x,1)[1] #proposals
            scores = torch.max(x,1)[0] #proposals
            print(rois.shape,scores.shape,classes.shape)
            index = nms(rois,scores,0.5)
            scores = scores[index]
            labels = classes[index]
            bboxes = rois[index]
            return scores, labels, bboxes

def generate_gt(scores_list,rois_list,labels):
    #scores:bs,proposal,4; rois_list:bs,n,4; labels:bs,4
    gt_list = []
    bs,proposal,_ = scores_list.shape
    for batch in range(bs):
        label = torch.max(labels[batch],0)[1]
        if label == 3:
            gt_list.append(torch.tensor([]))
            continue
        rois = rois_list[batch]
        scores = scores_list[batch]
        rois = rois[torch.max(scores,1)[1]!=3]
        scores = scores[torch.max(scores,1)[1]!=3]
        scores = scores[:,label]
        if len(scores)==0:
            gt_list.append(torch.tensor([]))
            continue
        index = nms(rois,scores,0.5)
        rois = rois[index]
        scores = scores[index]
        for i in range(len(scores)):
            if scores[i]>0.5:
                continue
            else:
                break
        rois = rois[:i]
        scores = scores[:i]
        gt_list.append(rois)
    return gt_list

class Detection(nn.Module):
    def __init__(self):
        super().__init__()
        self.cls = nn.Linear(512,4)
        self.reg = nn.Linear(512,4)
        self.ce = nn.CrossEntropyLoss(reduction='none')
        self.l1 = nn.SmoothL1Loss(reduction='none')
        self.zero = nn.Parameter(torch.tensor(0.)).requires_grad_(False)
    def forward(self,v,gt_list,rois_list,labels,pre_score):
        if not self.training:
            cls = F.softmax(self.cls(v),dim=-1)#proposal,4
            reg = self.reg(v)#proposal,4
            tx,ty,tw,th = reg[:,0],reg[:,1],reg[:,2],reg[:,3]
            r = rois_list[0]
            rx = (r[:,2]+r[:,0])/2
            ry = (r[:,3]+r[:,1])/2
            rw = (r[:,2]-r[:,0])/2
            rh = (r[:,3]-r[:,1])/2
            gx = tx*rw+rx
            gy = ty*rh+ry
            gw = rw*torch.exp(tw)
            gh = rh*torch.exp(th)
            x0 = gx-gw/2
            x1 = gx+gw/2
            y0 = gy-gh/2
            y1 = gy+gh/2
            boxes = torch.stack([x0,y0,x1,y1],1)
            return boxes, cls
        bs = len(gt_list)
        pre_score = pre_score.clone().detach()
        classify = self.cls(v).view(bs,-1,4) #bs,proposal,4
        proposal = classify.shape[1]
        reg = self.reg(v).view(bs,-1,4)
        
        c_list = self.zero.clone()
        l1_list= self.zero.clone()
        i = 0
        for batch in range(bs):
            label = torch.max(labels[batch],0)[1]
            if gt_list[batch].size()==torch.Size([0,4]) or gt_list[batch].size()==torch.Size([0]) :
                target = torch.stack([label]*proposal,0)
                c_loss = self.ce(classify[batch],target).mean()
                c_list += c_loss
                continue
            gt_list[batch] = gt_list[batch].float()
            i+=1
            target = torch.stack([label]*proposal,0)
            c_loss = self.ce(classify[batch],target)*pre_score[batch,:,label] #proposal,4
            ious = box_iou(gt_list[batch],rois_list[batch])
            _ious = ious>0.5
            mask = _ious.any(dim=0)
            c_loss = (c_loss*mask).sum()/(mask.sum()+1e-7)
            index = torch.max(ious,0)[1]
            g = gt_list[batch][index]
            gx = (g[:,2]+g[:,0])/2
            gy = (g[:,3]+g[:,1])/2
            gw = (g[:,2]-g[:,0])/2
            gh = (g[:,3]-g[:,1])/2
            r = rois_list[batch]
            rx = (r[:,2]+r[:,0])/2
            ry = (r[:,3]+r[:,1])/2
            rw = (r[:,2]-r[:,0])/2
            rh = (r[:,3]-r[:,1])/2
            tx = (gx-rx)/(rw+1e-8)
            ty = (gy-ry)/(rh+1e-8)
            tw = torch.log(gw/(rw+1e-8))
            th = torch.log(gh/(rh+1e-8))
            t = torch.stack([tx,ty,tw,th],1)
            l1_loss = (self.l1(reg[batch],t).sum(dim=1)*mask*pre_score[batch,:,label]).sum()/(bs*proposal)#mask.sum()
            c_list+=c_loss
            l1_list+=l1_loss
        if i == 0:
            return c_list+l1_list
        else:
            return c_list/bs+l1_list/i
                
                
class OICRe2e(nn.Module):
    def __init__(self):
        super().__init__()
        self.v_extractor = vector_extractor()
        self.midn = MIDN()
        self.icr1 = ICR()
        self.icr2 = ICR()
        self.icr3 = ICR()
        self.detection = Detection()
        
    def forward(self, inputs, labels, rois, num):
        if self.training:
            labels = labels.squeeze()
            rois = rois.squeeze()
            v = self.v_extractor(inputs, rois)
            x, midn_loss = self.midn(v, labels, num)
            x, loss1 = self.icr1(v, x, labels, rois, num)
            x, loss2 = self.icr2(v, x, labels, rois, num)
            x, loss3 = self.icr3(v, x, labels, rois, num) 
            gt_list = generate_gt(x,rois,labels) #n x 4がbs個 gt_list[0]: n,4            
            loss_detection = self.detection(v,gt_list,rois,labels,x)
            loss = midn_loss + (loss1 + loss2 + loss3)+loss_detection
            return x, loss.unsqueeze(0),midn_loss.unsqueeze(0),loss1.unsqueeze(0),loss2.unsqueeze(0),loss3.unsqueeze(0),loss_detection.unsqueeze(0)
        else:
            self.three = torch.tensor([3.]).cuda()
            self.inf = torch.tensor([-1.]).cuda()
            #rois = rois.squeeze()
            v = self.v_extractor(inputs, rois)
            rois,x = self.detection(v,None,rois,None,None)
            '''x = self.midn(v, labels, num)
            print('---------------------------------------------------')
            print('icr1')
            x = self.icr1(v, x, labels, rois, num) 
            print('---------------------------------------------------')
            print('icr2')
            x = self.icr2(v, x, labels, rois, num) 
            print('---------------------------------------------------')
            print('icr3')
            x = self.icr3(v, x, labels, rois, num) 
            x, rois = x[0], rois[0].cuda()
            rois = rois[torch.max(x,1)[1]!=3]
            x = x[torch.max(x,1)[1]!=3]
            print(f'x={x}')'''
            
            rois = rois[torch.max(x,1)[1]!=3]
            x = x[torch.max(x,1)[1]!=3]
            print(f'x={x}')
            if x.size() == torch.Size([0,4]):
                return torch.tensor([]),torch.tensor([]),torch.tensor([])
            classes = torch.max(x,1)[1] #proposals
            scores = torch.max(x,1)[0] #proposals
            print(rois.shape,scores.shape,classes.shape)
            index = nms(rois,scores,0.5)
            scores = scores[index]
            labels = classes[index]
            bboxes = rois[index]
            return scores, labels, bboxes


class SAV(nn.Module):
    def __init__(self):
        super().__init__()
        self.v_extractor = vector_extractor()
        self.midn = MIDN()
        self.icr1 = ICR()
        self.icr2 = ICR()
        self.icr3 = ICR()
        self.detection = Detection()
        self.mat = nn.Parameter(torch.zeros((512,512))).requires_grad_(False)
        self.zero = nn.Parameter(torch.tensor([0.])).requires_grad_(False)
        
    def forward(self, inputs, labels, rois, num):
        if self.training:
            labels = labels.squeeze()
            rois = rois.squeeze()
            v = self.v_extractor(inputs, rois)
            x, midn_loss = self.midn(v, labels, num)
            x1, loss1 = self.icr1(v, x, labels, rois, num)
            x2, loss2 = self.icr2(v, x1, labels, rois, num)
            x3, loss3 = self.icr3(v, x2, labels, rois, num) 
            #return x1,x2,x3,labels,rois
            gt_list = self.generate_gt_sav(x1,x2,x3,rois,labels) #n x 4がbs個 gt_list[0]: n,4            
            loss_detection = self.detection(v,gt_list,rois,labels,x)
            loss = midn_loss + (loss1 + loss2 + loss3)+loss_detection
            return x, loss.unsqueeze(0),midn_loss.unsqueeze(0),loss1.unsqueeze(0),loss2.unsqueeze(0),loss3.unsqueeze(0),loss_detection.unsqueeze(0)
        else:
            self.three = torch.tensor([3.]).cuda()
            self.inf = torch.tensor([-1.]).cuda()
            #rois = rois.squeeze()
            v = self.v_extractor(inputs, rois)
            rois,x = self.detection(v,None,rois,None,None)
            rois = rois[torch.max(x,1)[1]!=3]
            x = x[torch.max(x,1)[1]!=3]
            print(f'x={x}')
            if x.size() == torch.Size([0,4]):
                return torch.tensor([]),torch.tensor([]),torch.tensor([])
            classes = torch.max(x,1)[1] #proposals
            scores = torch.max(x,1)[0] #proposals
            print(rois.shape,scores.shape,classes.shape)
            index = nms(rois,scores,0.5)
            scores = scores[index]
            labels = classes[index]
            bboxes = rois[index]
            return scores, labels, bboxes

    def generate_gt_sav(self,x1,x2,x3,rois_list,labels_list):
        bs,proposal,_ = x1.shape
        gt_list = []
        for batch in range(bs):
            mat = self.mat.clone()
            label = torch.max(labels_list[batch],0)[1]
            if label == 3:
                gt_list.append(torch.tensor([]))
                continue
            rois = rois_list[batch]
            score_list = []
            
            for score in [x1[batch],x2[batch],x3[batch]]:
                score_list.append(score[:,label])
            score = (score_list[0]+score_list[1]+score_list[2])/3
            index = nms(rois,score,0.5)
            rois = rois[index]
            score = score[index]
            for r in range(len(rois)):
                mat[int(rois[r,1]):int(rois[r,3]),int(rois[r,0]):int(rois[r,2])] += score[r]
            mat = mat/(mat.max()+1e-8)
            mat = torch.where(mat>0.5,mat,self.zero).cpu().detach().numpy()
            heatmap = np.uint8(255*mat)
            LAB = cv2.connectedComponentsWithStats(heatmap)
            n = LAB[0] - 1
            data = np.delete(LAB[2], 0, 0)
            boxes = torch.tensor([])
            for i in range(n):
                X0 = data[i][0]
                Y0 = data[i][1]
                X1 = data[i][0] + data[i][2]
                Y1 = data[i][1] + data[i][3]
                if boxes.shape[0] == 0:
                    boxes = torch.tensor([[X0,Y0,X1,Y1]]).cuda()
                else:
                    boxes = torch.cat((boxes, torch.tensor([[X0,Y0,X1,Y1]]).cuda()), dim=0)
            gt_list.append(boxes)
        return gt_list
    
class SAV_Retina(nn.Module):
    def __init__(self):
        super().__init__()
        self.v_extractor = vector_extractor_pyramid()
        self.midn = MIDN()
        self.icr1 = ICR()
        self.icr2 = ICR()
        self.icr3 = ICR()
        num_classes = 3
        self.detection = Detection()
        self.mat = nn.Parameter(torch.zeros((512,512))).requires_grad_(False)
        self.zero = nn.Parameter(torch.tensor([0.])).requires_grad_(False)
        self.fpn = PyramidFeatures(128, 256, 512)
        self.regressionModel = RegressionModel(256)
        self.classificationModel = ClassificationModel(256, num_classes=num_classes)
        self.anchors = Anchors()

        self.regressBoxes = BBoxTransform()

        self.clipBoxes = ClipBoxes()

        self.focalLoss = focalloss()

        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
                m.weight.data.normal_(0, math.sqrt(2. / n))
            elif isinstance(m, nn.BatchNorm2d):
                m.weight.data.fill_(1)
                m.bias.data.zero_()

        prior = 0.01

        self.classificationModel.output.weight.data.fill_(0)
        self.classificationModel.output.bias.data.fill_(-math.log((1.0 - prior) / prior))

        self.regressionModel.output.weight.data.fill_(0)
        self.regressionModel.output.bias.data.fill_(0)
        
    def forward(self, inputs, labels, rois, num):
        if self.training:
            labels = labels.squeeze()
            rois = rois.squeeze()
            v,f2,f3,f4 = self.v_extractor(inputs, rois)
            x, midn_loss = self.midn(v, labels, num)
            x1, loss1 = self.icr1(v, x, labels, rois, num)
            x2, loss2 = self.icr2(v, x1, labels, rois, num)
            x3, loss3 = self.icr3(v, x2, labels, rois, num) 
            #return x1,x2,x3,labels,rois
            gt_list = self.generate_gt_sav(x1,x2,x3,rois,labels) #n x 4がbs個 gt_list[0]: n,4 
            #return gt_list
            features = self.fpn([f2, f3, f4])

            regression = torch.cat([self.regressionModel(feature) for feature in features], dim=1)

            classification = torch.cat([self.classificationModel(feature) for feature in features], dim=1)

            anchors = self.anchors(inputs)
            c_loss,r_loss = self.focalLoss(classification, regression, anchors, gt_list)
            c_loss = c_loss.mean()
            r_loss = r_loss.mean()
            loss = midn_loss + (loss1 + loss2 + loss3)+c_loss+r_loss
            return x, loss.unsqueeze(0),midn_loss.unsqueeze(0),loss1.unsqueeze(0),loss2.unsqueeze(0),loss3.unsqueeze(0),(c_loss+r_loss).unsqueeze(0)
        else:
            v,f2,f3,f4 = self.v_extractor(inputs, rois)
            features = self.fpn([f2, f3, f4])

            regression = torch.cat([self.regressionModel(feature) for feature in features], dim=1)

            classification = torch.cat([self.classificationModel(feature) for feature in features], dim=1)

            anchors = self.anchors(inputs)
            transformed_anchors = self.regressBoxes(anchors, regression)
            transformed_anchors = self.clipBoxes(transformed_anchors, inputs)

            scores = torch.max(classification, dim=2, keepdim=True)[0]

            scores_over_thresh = (scores > 0.05)[0, :, 0]

            if scores_over_thresh.sum() == 0:
                # no boxes to NMS, just return
                return [torch.zeros(0), torch.zeros(0), torch.zeros(0, 4)]

            classification = classification[:, scores_over_thresh, :]
            transformed_anchors = transformed_anchors[:, scores_over_thresh, :]
            scores = scores[:, scores_over_thresh, :]

            anchors_nms_idx = nms(transformed_anchors[0,:,:], scores[0,:,0], 0.5)

            nms_scores, nms_class = classification[0, anchors_nms_idx, :].max(dim=1)

            return [nms_scores, nms_class, transformed_anchors[0, anchors_nms_idx, :]]
        

    def generate_gt_sav(self,x1,x2,x3,rois_list,labels_list):
        bs,proposal,_ = x1.shape
        gt_list = []
        for batch in range(bs):
            mat = self.mat.clone()
            label = torch.max(labels_list[batch],0)[1]
            if label == 3:
                gt_list.append(torch.empty(0,5))
                continue
            rois = rois_list[batch]
            score_list = []
            
            for score in [x1[batch],x2[batch],x3[batch]]:
                score_list.append(score[:,label])
            score = (score_list[0]+score_list[1]+score_list[2])/3
            index = nms(rois,score,0.5)
            rois = rois[index]
            score = score[index]
            for r in range(len(rois)):
                mat[int(rois[r,1]):int(rois[r,3]),int(rois[r,0]):int(rois[r,2])] += score[r]
            mat = mat/(mat.max()+1e-8)
            mat = torch.where(mat>0.5,mat,self.zero).cpu().detach().numpy()
            heatmap = np.uint8(255*mat)
            LAB = cv2.connectedComponentsWithStats(heatmap)
            n = LAB[0] - 1
            data = np.delete(LAB[2], 0, 0)
            boxes = torch.empty(0,5)
            for i in range(n):
                X0 = data[i][0]
                Y0 = data[i][1]
                X1 = data[i][0] + data[i][2]
                Y1 = data[i][1] + data[i][3]
                if boxes.shape[0] == 0:
                    boxes = torch.tensor([[X0,Y0,X1,Y1,label]]).cuda()
                else:
                    boxes = torch.cat((boxes, torch.tensor([[X0,Y0,X1,Y1,label]]).cuda()), dim=0)
            gt_list.append(boxes)
        return gt_list

    


In [28]:
oicr = SAV_Retina()
#oicr.load_state_dict(torch.load("/data/unagi0/masaoka/wsod/model/oicr/OICRe2eflx0.0001_1.pt"))
oicr.cuda()
opt = torch.optim.Adam(oicr.parameters(), lr = 1e-5, weight_decay=1e-5)

In [30]:
for i, data in enumerate(dataloader_val):
    opt.zero_grad()
    labels, n, t, v, u= data2target(data)
    labels = labels.unsqueeze(1).unsqueeze(2).cuda().float() # bs, 1, 1, num_class
    rois = [r.cuda().float() for r in data["p_bboxes"]]
    n = min(list(map(lambda x: x.shape[0], rois)))
    n = min(n,2000)
    print(f'proposal = {n}')
    for ind, tensor in enumerate(rois):
        rois[ind] = rois[ind][:n,:]
    rois = torch.stack(rois, dim=0) 
    rois = rois.unsqueeze(1) #bs, 1, n, 4
    #g = oicr(data["img"].cuda().float(), labels, rois, n)
    #break
    output, loss,m,l1,l2,l3,ld = oicr(data["img"].cuda().float(), labels, rois, n)
    print(f'losses = {loss},{m},{l1},{l2},{l3},{ld}')
    loss = m+l1+l2+l3+ld*i/len(dataloader_val) 
    loss = loss.mean()
    loss.backward()
    opt.step()
    print(loss)
    print(f'{i}/{len(dataloader_val)}, {loss}', end='\r')
    print('------------------------------------------------------------------------------------------------------------------')


proposal = 1452
tensor(0.5547, device='cuda:0', grad_fn=<BinaryCrossEntropyBackward>)
tensor(0.0333, device='cuda:0', grad_fn=<DivBackward0>)
tensor(0.0787, device='cuda:0', grad_fn=<DivBackward0>)
tensor(0.0475, device='cuda:0', grad_fn=<DivBackward0>)
losses = tensor([0.8485], device='cuda:0', grad_fn=<UnsqueezeBackward0>),tensor([0.5547], device='cuda:0', grad_fn=<UnsqueezeBackward0>),tensor([0.0333], device='cuda:0', grad_fn=<UnsqueezeBackward0>),tensor([0.0787], device='cuda:0', grad_fn=<UnsqueezeBackward0>),tensor([0.0475], device='cuda:0', grad_fn=<UnsqueezeBackward0>),tensor([0.1343], device='cuda:0', grad_fn=<UnsqueezeBackward0>)
tensor(0.7142, device='cuda:0', grad_fn=<MeanBackward0>)
------------------------------------------------------------------------------------------------------------------
proposal = 1413
tensor(0.5491, device='cuda:0', grad_fn=<BinaryCrossEntropyBackward>)
tensor(0.0499, device='cuda:0', grad_fn=<DivBackward0>)
tensor(0.0795, device='cuda:0', grad_fn

KeyboardInterrupt: 

In [16]:
g[0][:,:]

tensor([[403,  43, 407,  71],
        [  6,  48, 478, 501],
        [347,  49, 348,  54],
        [202,  52, 203,  59],
        [331,  64, 332,  65],
        [339,  64, 340,  65],
        [347,  64, 353,  66],
        [ 61,  77,  72,  81],
        [160,  78, 162,  85],
        [347,  78, 357,  98],
        [270,  80, 272,  81],
        [339,  80, 340,  81],
        [331,  87, 332,  90],
        [338,  87, 340,  90],
        [160,  91, 162,  92],
        [102,  97, 120,  98],
        [338,  96, 340,  98],
        [347,  99, 353, 102],
        [455, 100, 467, 102],
        [330, 105, 332, 108],
        [337, 105, 340, 108],
        [347, 105, 353, 108],
        [ 13, 111,  22, 119],
        [ 23, 113,  30, 119],
        [338, 113, 340, 114],
        [228, 121, 251, 130],
        [270, 121, 273, 124],
        [485, 120, 497, 125],
        [257, 125, 260, 130],
        [485, 129, 491, 130],
        [ 32, 137,  33, 146],
        [239, 147, 243, 150],
        [469, 178, 470, 179],
        [4

In [19]:
data['annot'][1][:,:]

tensor([], size=(0, 5), dtype=torch.float64)

In [23]:
x = torch.empty(0,5).float()

In [24]:
x

tensor([], size=(0, 5))

In [18]:
dataloader_val = DataLoader(dataset_val, batch_size=16, shuffle=False, 
                                num_workers=4, collate_fn=bbox_collate)

In [19]:
for data in dataloader_val:
    #data['bboxes'][0]=torch.empty(0,5)
    print(data['bboxes'])
    break

[tensor([[183.1463, 202.3716, 412.8380, 377.4229]]), tensor([[  1.0119,   6.0711, 220.5850, 235.7628]]), tensor([[  1.0119,   6.0711, 162.9091, 190.2293]]), tensor([[  1.0119,   1.0119, 173.0277, 222.6087]]), tensor([[  1.0119,   1.0119, 240.8221, 266.1186]]), tensor([[ 63.7470, 190.2293, 267.1304, 356.1739]]), tensor([[112.3162, 235.7628, 306.5929, 410.8142]]), tensor([[100.1739, 254.9882, 302.5455, 414.8617]]), tensor([[ 94.1028, 229.6917, 340.9961, 419.9210]]), tensor([[217.5494, 238.7984, 390.5771, 409.8024]]), tensor([[291.4150, 278.2609, 437.1226, 422.9565]]), tensor([[299.5099, 318.7352, 455.3360, 456.3478]]), tensor([[207.4308, 242.8459, 302.5455, 343.0198]]), tensor([[233.7391, 251.9526, 333.9131, 339.9842]]), tensor([[344.0316, 168.9802, 509.9763, 402.7194]]), tensor([[322.7826, 134.5771, 512.0000, 436.1107]])]


In [27]:
x = [copy.copy(dataset_val[i]) for i in range(100)]

In [28]:
dir(dataset_val[0])

['__class__',
 '__contains__',
 '__delattr__',
 '__delitem__',
 '__dir__',
 '__doc__',
 '__eq__',
 '__format__',
 '__ge__',
 '__getattribute__',
 '__getitem__',
 '__gt__',
 '__hash__',
 '__init__',
 '__init_subclass__',
 '__iter__',
 '__le__',
 '__len__',
 '__lt__',
 '__ne__',
 '__new__',
 '__reduce__',
 '__reduce_ex__',
 '__repr__',
 '__reversed__',
 '__setattr__',
 '__setitem__',
 '__sizeof__',
 '__str__',
 '__subclasshook__',
 'clear',
 'copy',
 'fromkeys',
 'get',
 'items',
 'keys',
 'pop',
 'popitem',
 'setdefault',
 'update',
 'values']

In [29]:
x = torch.empty(0,8)

In [30]:
x

tensor([], size=(0, 8))

In [31]:
x.numpy()

array([], shape=(0, 8), dtype=float32)