In [1]:
import argparse
import collections
from tqdm import tqdm
import numpy as np
from dataset import MedicalBboxDataset
import torch
import torch.optim as optim
from torchvision import transforms
from torchvision.transforms import Compose
import transform as transf
from torch.utils.data import DataLoader
from matplotlib import pyplot as plt
import yaml
import json
from PIL import Image
from torchvision.ops.boxes import box_iou
from torchvision.ops import nms
import torchvision
import time
import os
import copy
import pdb
import sys
import cv2
from torch.utils.data import Dataset
from torchvision import datasets, models
from make_dloader import make_data
from utils import *

In [2]:
config = yaml.safe_load(open('./config.yaml'))
dataset_means = json.load(open(config['dataset']['mean_file']))
#_, _, _, dataset_val, _ = make_data()
dataset_val = torch.load(f'/data/unagi0/masaoka/val_all1.pt')
dataloader_val = DataLoader(dataset_val, batch_size=1, shuffle=False, 
                                num_workers=4, collate_fn=bbox_collate)
unnormalize = transf.UnNormalize(dataset_means['mean'], dataset_means['std'])

In [186]:
import torch
import torch.nn as nn
from torchvision.ops import roi_align, nms
from utils import calc_iou
import torchvision.models as models
import yaml
from torchvision.ops.boxes import box_iou
import torch.nn.functional as F

config = yaml.safe_load(open('./config.yaml'))

class FocalLoss(nn.Module):
    def __init__(self, gamma=0, eps=1e-7,reduction='none'):
        super().__init__()
        self.gamma = gamma
        self.eps = eps
        self.reduction = reduction

    def forward(self, input, target):
        logit = F.softmax(input, dim=1)
        logit = logit.clamp(self.eps, 1. - self.eps)
        logit_ls = torch.log(logit)
        loss = F.nll_loss(logit_ls, target, reduction="none")
        view = target.size() + (1,)
        index = target.view(*view)
        loss = loss * (1 - logit.gather(1, index).squeeze(1)) ** self.gamma # focal loss
        if self.reduction=='none':
            return loss

        return loss.sum()


class ASPP(nn.Module):
    def  __init__(self,size_list=[]):
        super().__init__()
        assert len(size_list)>0
        self.avgpool_list = []
        for size in size_list:
            self.avgpool_list.append(nn.Sequential(nn.AdaptiveAvgPool2d(size),
                                                    nn.Flatten()))
        
    def forward(self,x):
        vec = []
        for avgpool in self.avgpool_list:
            vec.append(avgpool(x))
        vec = torch.cat(vec,1)
        return vec

class _ROIPool(nn.Module):
    def __init__(self):
        super().__init__()
        
    def forward(self, inputs, rois):
        #rois: bs,2000,4 
        #output: bs, 2000, ch, h, w
        rois = [r for r in rois]
        h, w = inputs.shape[2], inputs.shape[3]
        res = roi_align(inputs, rois, 7, spatial_scale=w/512)
        return res

class vector_extractor(nn.Module):
    """input: images, proposals
         output: feature vector"""
    def __init__(self):
        super().__init__()
        model = models.resnet18(pretrained=True)
        layers = list(model.children())[:-2]
        self.feature_map = nn.Sequential(*layers)
        self.roi_pool = _ROIPool()
        #self.gap = nn.AdaptiveAvgPool2d(1)
        self.feature_vector = nn.Sequential(ASPP([1,2]),
                                            #nn.AdaptiveAvgPool2d(1),
                                            nn.Flatten(),
                                            nn.Linear(512*(5), 500),
                                            nn.ReLU(inplace=True),
                                            nn.BatchNorm1d(500),
                                            nn.Linear(500, 500),
                                            nn.BatchNorm1d(500),
                                            nn.ReLU(inplace=True))
    def forward(self, inputs, rois):
        f = self.feature_map(inputs)
        f = self.roi_pool(f, rois)
        #f = self.gap(f).view(f.shape[0], f.shape[1]) #batch*proposal, ch
        f = self.feature_vector(f) 
        return f
    
class MIDN(nn.Module):
    """input: feature vector, labels
         output: scores per proposal, loss"""
    def __init__(self):
        super().__init__()
        c_in = 500
        self.layer_c = nn.Linear(c_in, 4)
        self.layer_d = nn.Linear(c_in, 4)
        self.softmax_c = nn.Softmax(dim=2)
        self.softmax_d = nn.Softmax(dim=1)
        self.loss = nn.CrossEntropyLoss()
    def forward(self, inputs, labels, num):
        bs, proposal = inputs.shape[0]//num, num
        x_c = self.layer_c(inputs).view(bs, proposal, -1) #bs, proposal, 4
        x_d = self.layer_d(inputs).view(bs, proposal, -1)
        sigma_c = self.softmax_c(x_c)
        sigma_d = self.softmax_d(x_d)
        x_r = sigma_c * sigma_d #bs, proposal, 4
        phi_c = x_r.sum(dim=1) #bs, 4
        print((torch.max(x_r,1))[0].shape,phi_c.shape,x_r.shape)
        scaled = x_r/torch.max(x_r,1)[0]*phi_c
        if not self.training:
            print(scaled)
            return scaled
        loss = self.loss(phi_c, torch.max(labels,1)[1])
        return   scaled, loss#phi_c.unsqueeze(1), loss # sigma_c,loss#x_r,loss
        

class ICR(nn.Module):
    """input: feature vector (bs*proposal, ch)
                     k-1th proposal scores(bs, proposal,3or4)
                     supervision (label) (bs)
                     ROI proposals
         output: refined proposal scores, loss"""
    def __init__(self):
        super().__init__()
        c_in = 500
        self.I_t = 0.5
        self.fc = nn.Linear(c_in, 4)
        self.softmax = nn.Softmax(dim=2)
        #self.loss = nn.CrossEntropyLoss(reduction="none")
        self.loss = FocalLoss(gamma=2)
        """self.y_k = torch.zeros(bs, proposal, 4).cuda()
        self.y_k[:, :, 3] = 1
        self.w = torch.zeros(bs, proposal).cuda()"""
        
    def forward(self, inputs, pre_score, labels, rois, num):
        bs, proposal = inputs.shape[0]//num, num
        xr_k = self.fc(inputs).view(bs, proposal, -1) #bs, proposal, 4
        xr_k = self.softmax(xr_k)
        
        _xr_k = xr_k.view(bs*proposal, -1)
        self.y_k = torch.zeros(bs, proposal, 4).cuda()
        self.y_k[:, :, 3] = 1
        self.w = torch.zeros(bs, proposal).cuda()
        I = torch.zeros(bs, proposal)
        for batch in range(bs):
            for c in range(3):
                if labels[batch][c]:
                    #print(f'label{c}')
                    #m = torch.max(pre_score[batch, :, c], 0)
                    #x = m[0].item()
                    #j = m[1].item()
                    j_list = (pre_score[batch,:,c]>0.5).nonzero().squeeze()
                    x_list = pre_score[batch,j_list,c]
                    if (j_list).size() == torch.Size([0]) or (j_list).size() == torch.Size([1]) or (j_list).size() == torch.Size([]):
                        m = torch.max(pre_score[batch, :, c], 0)
                        x = m[0].item()
                        j = m[1].item()
                        x_list = [x]
                        j_list =[j]
                    mat = box_iou(rois[batch], rois[batch][j_list])
                    for i,j in enumerate(j_list):
                        for r in range(proposal):
                            _I = mat[r][i]
                            if _I > I[batch, r]:
                                I[batch, r] = _I
                                self.w[batch, r] = x_list[i]
                                if _I > self.I_t:
                                    #print(f'next supervision index{r}')
                                    self.y_k[batch, r, c] = 1
                                    self.y_k[batch, r, 3] = 0
        self.y_k = self.y_k.view(bs*proposal, -1)
        self.w = self.w.view(bs*proposal, 1)
        loss = self.loss(_xr_k.cuda().float(), torch.max(self.y_k, 1)[1])
        loss = torch.mean(self.w*loss)
        if not self.training:
            return xr_k
        return xr_k, loss

class OICR(nn.Module):
    def __init__(self):
        super().__init__()
        self.v_extractor = vector_extractor()
        self.midn = MIDN()
        self.icr1 = ICR()
        self.icr2 = ICR()
        self.icr3 = ICR()
        
    def forward(self, inputs, labels, rois, num):
        if self.training:
            labels = labels.squeeze()
            rois = rois.squeeze()
            v = self.v_extractor(inputs, rois)
            x, midn_loss = self.midn(v, labels, num)
            x, loss1 = self.icr1(v, x, labels, rois, num)
            x, loss2 = self.icr2(v, x, labels, rois, num)
            x, loss3 = self.icr3(v, x, labels, rois, num) 
            print(midn_loss,loss1,loss2,loss3)
            loss = midn_loss + loss1 + loss2 + loss3
            return x, loss.unsqueeze(0)  
        else:
            self.three = torch.tensor([3.]).cuda()
            self.inf = torch.tensor([-1.]).cuda()
            #rois = rois.squeeze()
            v = self.v_extractor(inputs, rois)
            x = self.midn(v, labels, num)
            x = self.icr1(v, x, labels, rois, num) 
            x, rois = x[0], rois[0].cuda()
            h = 0
            while(1):
                if torch.max(x[h],0)[1]==3:
                    x = torch.cat([x[:h],x[h+1:]])
                    rois = torch.cat([rois[:h],rois[h+1:]])
                else:
                    h+1
                if len(x)==h:
                    break
            print(x.size())
            if x.size() == torch.Size([0,4]):
                return [],[],[]
            s, i = torch.max(x, 1)
            s = torch.where(i==self.three,self.inf,s)
            sort = torch.argsort(s, descending=True)
            s, i = s.view(-1,1), i.view(-1,1).cuda().float()
            #print(s.shape, i.shape, rois.shape)
            cat = torch.cat([s, i ,rois], dim=1)
            cat = cat[sort, :]
            scores = cat[:, 0]
            labels = cat[:, 1]
            bboxes = cat[:, 2:]
            return scores, labels, bboxes


In [187]:
oicr = OICR()
oicr.load_state_dict(torch.load("/data/unagi0/masaoka/wsod/model/oicr/ResNet18aspplight0.0002_1.pt"))
oicr.cuda()
oicr.eval()

OICR(
  (v_extractor): vector_extractor(
    (feature_map): Sequential(
      (0): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
      (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): ReLU(inplace=True)
      (3): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
      (4): Sequential(
        (0): BasicBlock(
          (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (relu): ReLU(inplace=True)
          (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        )
        (1): BasicBlock(
          (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (bn1): BatchNorm2d

In [188]:
def draw_caption(image, box, caption):
    b = np.array(box).astype(int)
    cv2.putText(image, caption, (b[0], b[1] - 10), cv2.FONT_HERSHEY_PLAIN, 1, (0, 0, 0), 2)
    cv2.putText(image, caption, (b[0], b[1] - 10), cv2.FONT_HERSHEY_PLAIN, 1, (255, 255, 255), 1)

In [189]:
x = 0
for idx, data in enumerate(dataloader_val):
    with torch.no_grad():
        st = time.time()
        labels, n, t, v, u= data2target(data)
        rois = data["p_bboxes"][0][:2000].unsqueeze(0).cuda().float()
        scores, classification, transformed_anchors = oicr(data['img'].cuda().float(), labels,rois,2000)
        if len(scores)==0:
            continue
        print('Elapsed time: {}'.format(time.time()-st))
        idxs = np.where(scores.cpu()>0.01)
        print(idxs)
        data['img'] = data['img'].squeeze(dim=0)
        
        print(scores)
        print(classification)
        img = unnormalize(data)['img'].copy() 
        img[img<0] = 0
        img[img>255] = 255
        for j in range(idxs[0].shape[0]):
            if int(classification[idxs[0][j]]) == 3:
                continue
            bbox = transformed_anchors[idxs[0][j], :]
            x1 = int(bbox[0])
            y1 = int(bbox[1])
            x2 = int(bbox[2])
            y2 = int(bbox[3])
            label_name_ = dataset_val.labels[int(classification[idxs[0][j]])]
            draw_caption(img, (x1, y1, x2, y2), label_name_)
            if label_name_ == "ulcer":
                cv2.rectangle(img, (x1, y1), (x2, y2), color=(255, 255, 255), thickness=2) #黄色
            elif label_name_ == "torose lesion":
                cv2.rectangle(img, (x1, y1), (x2, y2), color=(255, 255, 255), thickness=2) #青
            else:
                cv2.rectangle(img, (x1, y1), (x2, y2), color=(255, 255, 255), thickness=2) #赤
                
        iou = box_iou(rois[0], rois[0][1].unsqueeze(0))
        print(iou.shape)
        for i in range(len(data["bboxes"])):
            x1 = int(data["bboxes"][i][0][0])
            y1 = int(data["bboxes"][i][0][1])
            x2 = int(data["bboxes"][i][0][2])
            y2 = int(data["bboxes"][i][0][3])
            print(x1,y1,x2,y2)
            label_name = dataset_val.labels[int(data["labels"][i])]
            draw_caption(img, (x1, y1, x2, y2), label_name)
            if label_name == "ulcer":
                cv2.rectangle(img, (x1, y1), (x2, y2), color=(255, 255, 0), thickness=6) #黄色
            elif label_name == "torose lesion":
                cv2.rectangle(img, (x1, y1), (x2, y2), color=(0, 0, 255), thickness=6) #青
            else:
                cv2.rectangle(img, (x1, y1), (x2, y2), color=(255, 0, 0), thickness=6)  #赤
            print(label_name)
        plt.imshow(img)
        plt.show()
        x+=1
        if x == 50:
            break
        
        


torch.Size([1, 4]) torch.Size([1, 4]) torch.Size([1, 2000, 4])
tensor([[[0.0468, 0.0096, 0.0359, 0.0464],
         [0.0409, 0.0175, 0.0341, 0.1122],
         [0.0381, 0.0163, 0.0327, 0.1264],
         ...,
         [0.0541, 0.0141, 0.0416, 0.0480],
         [0.0224, 0.0063, 0.0225, 0.1533],
         [0.1102, 0.0115, 0.0413, 0.0369]]], device='cuda:0')
torch.Size([0, 4])
torch.Size([1, 4]) torch.Size([1, 4]) torch.Size([1, 2000, 4])
tensor([[[2.1748e-01, 3.9069e-03, 1.8839e-03, 5.6039e-05],
         [1.1188e-03, 6.1148e-04, 1.5149e-04, 1.2430e-03],
         [2.2029e-02, 3.7546e-02, 1.2798e-03, 2.9888e-04],
         ...,
         [2.3085e-01, 2.7855e-03, 2.0236e-03, 5.3442e-05],
         [3.4687e-01, 9.9974e-05, 8.0871e-05, 1.1474e-05],
         [3.6545e-01, 4.6365e-05, 4.6072e-05, 7.8483e-06]]], device='cuda:0')
torch.Size([0, 4])
torch.Size([1, 4]) torch.Size([1, 4]) torch.Size([1, 2000, 4])
tensor([[[0.0427, 0.0039, 0.0131, 0.0037],
         [0.0528, 0.0046, 0.0152, 0.0035],
         

KeyboardInterrupt: 

In [None]:
np.set_printoptions(threshold=3000)

In [137]:
x = torch.tensor([1,9])

In [138]:
x.size()

torch.Size([2])

In [125]:
((x>3)).nonzero().squeeze()

tensor([3, 4])

In [134]:
x[(x*(x>3)).nonzero()]

tensor([[4],
        [5]])

In [110]:
len(x)

0

In [35]:
x = torch.randn(3,5,10)

In [41]:
x[0,0,[1]]

tensor([0.7372])

In [7]:
len(torch.tensor([]))

0

In [None]:
s = torch.randn(scores.shape).cuda()
t = transformed_anchors
s

In [None]:
nms(t,s,0.5)