In [1]:
import cv2
import numpy as np
import torch
import torch.nn as nn
from torch.autograd import Function
from torchvision import models, transforms
from make_dloader import make_data
from torch.utils.data import DataLoader
from utils import bbox_collate, MixedRandomSampler
import transform as transf
import yaml
import json
from matplotlib import pyplot as plt
import numbers
import torchvision
import torch.optim as optim
import imgaug as ia
import imgaug.augmenters as iaa
from imgaug import parameters as iap
from model import ResNet50
import time
from scipy.ndimage import gaussian_filter
import random

# Load the experiment configuration and the dataset normalisation statistics
# (the JSON file named by config['dataset']['mean_file']).
# Context managers close both file handles as soon as parsing is done instead
# of leaving them open for the lifetime of the kernel.
with open('./config.yaml') as config_file:
    config = yaml.safe_load(config_file)
with open(config['dataset']['mean_file']) as mean_file:
    dataset_means = json.load(mean_file)

In [2]:
class FeatureExtractor():
    """Run a container module child by child and capture selected outputs.

    The direct children of ``model`` are applied in registration order. For
    every child whose name is listed in ``target_layers`` the output tensor is
    kept and a backward hook is attached so that its gradient is appended to
    ``self.gradients`` during a later backward pass.
    """

    def __init__(self, model, target_layers):
        self.model = model
        self.target_layers = target_layers
        self.gradients = []

    def save_gradient(self, grad):
        """Backward hook: store the gradient of a captured output."""
        self.gradients.append(grad)

    def __call__(self, x):
        """Return ``(captured_outputs, final_output)`` for input ``x``."""
        # Gradients from a previous forward pass are discarded.
        self.gradients = []
        captured = []
        for child_name, child in self.model._modules.items():
            x = child(x)
            if child_name not in self.target_layers:
                continue
            x.register_hook(self.save_gradient)
            captured.append(x)
        return captured, x

In [3]:
class ModelOutputs():
    """Forward pass through ``model`` that also captures target activations.

    The direct children of ``model`` are applied in order. The child equal to
    ``feature_module`` is delegated to a ``FeatureExtractor`` so that the
    activations (and later the gradients) of ``target_layers`` are recorded.
    The output of any child whose name contains "avgpool" is flattened to
    ``(batch, -1)`` before being passed on.
    """

    def __init__(self, model, feature_module, target_layers):
        self.model = model
        self.feature_module = feature_module
        self.feature_extractor = FeatureExtractor(self.feature_module, target_layers)

    def get_gradients(self):
        """Gradients recorded by the backward hooks since the last forward pass."""
        return self.feature_extractor.gradients

    def __call__(self, x):
        """Return ``(target_activations, model_output)`` for input ``x``."""
        target_activations = []
        for name, module in self.model._modules.items():
            if module == self.feature_module:
                target_activations, x = self.feature_extractor(x)
                continue
            x = module(x)
            if "avgpool" in name.lower():
                x = x.view(x.size(0),-1)

        return target_activations, x


def cam_on_image(img, mask):
    """Overlay a JET-coloured heat map on an image.

    ``mask`` is inverted (``1 - mask``) and scaled to 0..255 before the colour
    map is applied; the coloured map and ``img / 255`` are summed and the
    result is divided by its maximum.
    """
    coloured = cv2.applyColorMap(np.uint8(255 * (1-mask)), cv2.COLORMAP_JET)
    overlay = np.float32(coloured) / 255 + np.float32(img)/255
    return overlay / np.max(overlay)
   
def thresh(narray, threshold = 0.15, binary = False):
    """Suppress entries that are not above ``threshold * max(narray)``.

    With ``binary=False`` surviving entries keep their value and the rest
    become 0; with ``binary=True`` the result is a 0/1 indicator array.
    """
    keep = narray > threshold*np.max(narray)
    if not binary:
        return np.where(keep, narray, 0)
    return np.where(keep, 1, 0)

In [4]:
class GradCam:
    """Grad-CAM heat maps for a multi-label (sigmoid) classifier.

    Parameters
    ----------
    model : network whose direct children are applied sequentially.
    feature_module : child of ``model`` that contains the target layer(s).
    target_layer_names : names of the children of ``feature_module`` whose
        activations and gradients are used for the maps.
    use_cuda : move the model and the inputs to the GPU when true.
    """

    def __init__(self, model, feature_module, target_layer_names, use_cuda):
        self.model = model
        self.feature_module = feature_module
        self.model.eval()
        self.cuda = use_cuda
        if self.cuda:
            self.model = model.cuda()

        self.extractor = ModelOutputs(self.model, self.feature_module, target_layer_names)

    def forward(self, input):
        return self.model(input)

    def __call__(self, input, index=None):
        """Compute a batch-level label vector and one heat-map stack per class.

        ``input`` is indexed as (batch, channel, height, width). With
        ``index=None`` every sample votes with its thresholded sigmoid output
        (> 0.5); otherwise the raw probabilities are summed. A class is
        positive when its sum exceeds half the batch size.

        Returns ``(label, cam_list)``: ``label`` is a 0/1 float array with one
        entry per class; ``cam_list[i]`` is the integer 0 for a negative class
        and a ``(batch, height, width)`` array scaled to [0, 1] otherwise.
        """
        if self.cuda:
            features, output = self.extractor(input.cuda())
        else:
            features, output = self.extractor(input)

        o = torch.sigmoid(output).detach().cpu().numpy()
        if index is None:
            o = np.where(o>0.5, 1., 0.)
        label = o.sum(axis = 0)
        label = np.where(label>o.shape[0]/2, 1., 0.)

        batch = input.shape[0]
        height, width = int(input.shape[2]), int(input.shape[3])
        # The activations do not depend on the class, so they are copied to the
        # host once, the first time a positive class needs them.
        activations = None

        cam_list = []
        for i in range(len(label)):
            if label[i] == 0:
                cam_list.append(0)
                continue

            # Scalar score of class i summed over the batch.
            one_hot = np.zeros_like(label)
            one_hot[i] = 1.
            one_hot = torch.from_numpy(one_hot).requires_grad_(True)
            if self.cuda:
                one_hot = torch.sum(one_hot.cuda() * output)
            else:
                one_hot = torch.sum(one_hot * output)

            self.feature_module.zero_grad()
            self.model.zero_grad()
            # retain_graph: the same forward graph is reused for every class.
            one_hot.backward(retain_graph=True)

            # The hooks append on every backward pass; the newest entry belongs
            # to the backward call just made.
            grads_val = self.extractor.get_gradients()[-1].detach().cpu().numpy()
            if activations is None:
                activations = features[-1].detach().cpu().numpy()

            # Channel weights are the spatial mean of the gradients.
            weights = np.mean(grads_val, axis=(2, 3))[:, :, np.newaxis, np.newaxis]
            weighted = np.maximum((weights * activations).sum(axis=1), 0)

            T = np.zeros((batch, height, width))
            for b in range(batch):
                # cv2.resize takes dsize as (width, height), not (rows, cols).
                cam = cv2.resize(weighted[b], (width, height))
                cam = cam - np.min(cam)
                peak = np.max(cam)
                if peak > 0:
                    # A flat map stays all-zero instead of becoming 0/0 = NaN.
                    cam = cam / peak
                T[b] = cam
            cam_list.append(T)

        return label, cam_list


In [5]:
# Earlier torchvision-based setup, kept for reference:
#model = models.resnet50(pretrained=False)
#model.fc = nn.Sequential( nn.Linear(in_features=2048, out_features=4, bias=True) )
# Build the project's ResNet50 wrapper, move it to the GPU and restore its
# weights from the checkpoint file.
# NOTE(review): the checkpoint path is absolute and machine-specific, and
# torch.load unpickles the file, so only load checkpoints you trust.
model = ResNet50()
model.cuda()
model.load_state_dict(torch.load("/data/unagi0/masaoka/wsod/model/resnet50_classify1.pt"))

<All keys matched successfully>

In [6]:
# Grad-CAM helper on the wrapped backbone: activations and gradients are taken
# from the child named "2" of its layer4 block.
grad_cam = GradCam(model=model.resnet50, feature_module=model.resnet50.layer4, target_layer_names=["2"], use_cuda=True)

In [7]:
# Only the second of the three objects returned by make_data() is kept and used
# as the validation dataset (presumably train / val / test order, TODO confirm).
_, dataset_val, _ = make_data()

loading annotations into memory...
Done (t=0.20s)
creating index...
index created!
creating index...
index created!
creating index...
index created!
creating index...
index created!
creating index...
index created!
creating index...
index created!


In [15]:
# Evaluation settings derived from the loaded config.
size = config["inputsize"]
# Tag used in the metrics file name: the configured validation entries, sorted
# and concatenated. sorted() builds a new list, so the loaded config itself is
# not reordered as a side effect of running this cell.
val = ''.join(map(str, sorted(config["dataset"]["val"])))

In [16]:
# Subset returned by with_annotation() (presumably only the images that carry
# box annotations, TODO confirm). It is not referenced again in the cells
# visible here unless the line below is re-enabled.
val_anomaly = dataset_val.with_annotation()
#dataset_val = val_anomaly
# No batch_size is passed, so the DataLoader default of one sample per batch
# applies; the evaluation cells index the batch with [0] accordingly.
dataloader_val = DataLoader(dataset_val, num_workers=4, collate_fn=bbox_collate)
# Inverse and forward normalisation built from the dataset mean / std.
unnormalize = transf.UnNormalize(dataset_means['mean'], dataset_means['std'])
normalize = transf.Normalize(dataset_means['mean'], dataset_means['std'])

creating index...
index created!


In [17]:
class loss_r(nn.Module):
    """Scaled squared-error term: ``0.5 * rho * sum((target - i) ** 2)``."""

    def __init__(self):
        super().__init__()

    def forward(self, i, target, rho = 1e-2):
        residual = target-i
        squared_error = (residual**2).sum()
        return 0.5*squared_error*rho
    
class loss_l2(nn.Module):
    """Scaled squared-magnitude penalty: ``0.5 * rho * sum(i ** 2)``."""

    def __init__(self):
        super().__init__()

    def forward(self, i, rho = 10):
        energy = (i**2).sum()
        return 0.5*rho*energy
    
class loss_tv(nn.Module):
    """Absolute-difference term between shifted ``i`` and a cropped ``target``.

    ``i`` is indexed as a 2-D map and ``target`` as a stack of such maps. The
    map shifted by one along each axis is compared with the top-left crop of
    every target map; the absolute differences are summed and scaled by ``rho``.
    """

    def __init__(self):
        super().__init__()

    def forward(self, i, target, rho = 20):
        n0, n1 = i.shape[0], i.shape[1]
        reference = target[:, :n0-1, :n1-1]
        shifted_axis0 = i[1:, :n1-1]
        shifted_axis1 = i[:n0-1, 1:]
        total = abs(shifted_axis0 - reference).sum()+abs(shifted_axis1 - reference).sum()
        return rho*total

In [18]:
class Loss(nn.Module):
    """Bundle of the three loss terms used when fusing Grad-CAM maps.

    ``forward`` returns the terms separately, in the order
    ``(TV term, reconstruction term, L2 term)``; no combined sum is returned.
    NOTE(review): confirm that callers really mean to use the terms
    individually rather than their sum.
    """

    def __init__(self):
        super().__init__()
        self.l2loss = loss_l2()
        self.TVloss = loss_tv()
        self.r_loss = loss_r()

    def forward(self, i, t):
        l2_term = self.l2loss(i)
        tv_term = self.TVloss(i,t)
        reconstruction_term = self.r_loss(i,t)
        return tv_term, reconstruction_term, l2_term

In [19]:
def converge_map(masks):
    """Fuse a stack of Grad-CAM maps into a single map by gradient descent.

    ``masks`` is indexed as a ``(b, H, W)`` ndarray. The stack is passed through
    an affine rotation whose angle is the negated ``DiscreteUniform(-180, 179)``
    draw, with imgaug seeded to 0 (presumably to undo rotations that were drawn
    with the same seed when the maps were produced, TODO confirm). A map
    initialised from ``masks[0]`` is then optimised with Adam for 200 steps on
    the first term returned by ``Loss``.

    Reads the notebook-level ``size`` for normalisation and needs a CUDA device.

    Returns ``(fused_map, 1 / conf_loss)``: the optimised map as an ndarray
    and the reciprocal of the normalised second ``Loss`` term from the last
    step (a 0-dim tensor).
    """
    # Copy: torch.from_numpy shares memory with its argument, so optimising a
    # view of masks[0] would overwrite the caller's array in place.
    init = masks[0].copy()
    seq1 = iaa.Sequential([
                    iaa.Affine(
        rotate=iap.DiscreteUniform(-180,179)*(-1)
                    )])
    ia.seed(0)
    masks = seq1(images=masks)
    # The target stack is constant, so it is uploaded to the GPU only once.
    target = torch.from_numpy(masks).cuda()
    mp = torch.from_numpy(init).requires_grad_(True)
    calc_loss = Loss()
    calc_loss.cuda()
    optimizer = optim.Adam([mp], lr = 5e-3)
    losses = []
    for _ in range(200):
        optimizer.zero_grad()
        loss, conf_loss, aux = calc_loss(mp.cuda(), target)
        loss = loss/masks.shape[0]/(size**2)
        conf_loss = conf_loss/masks.shape[0]/(size**2)
        loss.backward()
        optimizer.step()
        # Keep plain floats: storing the tensors would keep every step's
        # autograd graph alive. `losses` is only useful for plotting the curve.
        losses.append(loss.item())
    mp = mp.detach().numpy()
    print(conf_loss)
    return mp, 1/conf_loss

In [20]:
def augmented_grad_cam(gcam, image):
    """Augmented Grad-CAM: fuse the maps of rotated copies of one image.

    ``gcam`` is the callable that returns ``(labels, heat maps)`` for a batch;
    ``image`` is indexed as a single-image ``(1, C, H, W)`` tensor. Ten copies
    are rotated by random integer angles (imgaug seeded to 0), passed through
    ``gcam``, and the maps of every positive class are fused by
    ``converge_map``.

    Returns a list holding 0 for a negative class and a ``(1, H, W)`` ndarray
    for a positive one. NOTE(review): the last class index is skipped without
    appending anything, so the list can be shorter than ``labels``; confirm
    this is still intended.
    """
    img = image.clone()
    img = img.squeeze().numpy().transpose(1,2,0)  # (H, W, C) for imgaug
    b = 10
    img = np.tile(img,(b,1,1,1))  # (b, H, W, C)
    seq = iaa.Sequential([
                    iaa.Affine(
        rotate=iap.DiscreteUniform(-180, 179)
                    )])
    ia.seed(0)
    img = seq(images=img)
    img = torch.from_numpy(img)
    # Use the extractor that was passed in rather than the notebook global.
    labels, masks = gcam(img.permute(0,3,1,2), None)
    maps = []
    for i, label in enumerate(labels):
        if label == 0:
            maps.append(0)
        elif i == len(labels)-1:
            continue
        else:
            mp, _ = converge_map(masks[i])
            maps.append(mp[np.newaxis,:,:])

    return maps
    

In [21]:
class Conf(nn.Module):
    """Fuse a stack of maps and attach a confidence score to the result.

    ``forward`` runs ``converge_map`` on the stack and rescales the value it
    returns by ``mean / sum`` of the fused map.
    """

    def __init__(self):
        super().__init__()

    def forward(self, x):
        fused, raw_score = converge_map(x.numpy())
        score = raw_score/fused.sum()*fused.mean()
        return fused, score

In [22]:
def calc_conf(masks):
    """Return ``(fused_map, confidence)`` for a stack of maps given as ndarray."""
    # Conf is the module that computes the confidence score.
    scorer = Conf()
    scorer.cuda()
    return scorer(torch.from_numpy(masks))

In [23]:
def high_conf_maps(gcam, image):
    """Return fused Grad-CAM maps only for classes whose confidence is high.

    ``gcam`` is the callable that returns ``(labels, heat maps)`` for a batch;
    ``image`` is indexed as a single-image ``(1, C, H, W)`` tensor. Ten copies
    are rotated by random integer angles (imgaug seeded to 0) and passed
    through ``gcam``. For every positive class the maps are fused by
    ``calc_conf``; the fused map is kept when its confidence exceeds 0.18.

    Returns one entry per class: the fused map, or 0 when the class is
    negative or its confidence is too low.
    """
    img = image.clone()
    img = img.squeeze().numpy().transpose(1,2,0)  # (H, W, C) for imgaug
    b = 10
    img = np.tile(img,(b,1,1,1))  # (b, H, W, C)
    seq = iaa.Sequential([
                    iaa.Affine(
        rotate=iap.DiscreteUniform(-180, 179)
                    )])
    ia.seed(0)
    img = seq(images=img)
    img = torch.from_numpy(img)
    # Use the extractor that was passed in rather than the notebook global.
    labels, masks = gcam(img.permute(0,3,1,2), None)

    maps = []
    eps = 0.18
    for i, label in enumerate(labels):
        if label == 0:
            maps.append(0)
            continue
        pseudo_map, conf = calc_conf(masks[i])
        print(f"conf:{conf}")
        maps.append(pseudo_map if conf > eps else 0)
    return maps

In [24]:
def heatmap2box(heatmap, img=0, threshold = 0.5):
    """Turn a heat map into bounding boxes of its connected components.

    ``heatmap`` is indexed as ``(1, H, W)``; only ``heatmap[0]`` is used. Values
    not above ``threshold * max`` are zeroed and every connected component of
    what remains yields one ``[x0, y0, x1, y1]`` box.

    When ``img`` is an array (anything that is not a plain number) the boxes
    are drawn on a private copy of it; the copy is not returned, so the
    caller's image is left untouched.

    Returns ``(boxes, score)``: ``boxes`` has one row per component (an empty
    tensor when there is none) and ``score`` is simply ``threshold``.
    """
    draw = not isinstance(img, numbers.Number)
    if draw:
        image = img.copy()
    binary = thresh(heatmap, threshold = threshold)[0]
    binary = np.uint8(255*binary)
    # Stats rows hold (left, top, width, height, area); row 0 is the background.
    stats = cv2.connectedComponentsWithStats(binary)[2]
    # Set up front so the return value exists even when no component is found.
    score = threshold
    rows = []
    for left, top, width, height in stats[1:, :4]:
        x0, y0 = int(left), int(top)
        x1, y1 = int(left + width), int(top + height)
        # Collect every box; concatenating without keeping the result would
        # silently drop all boxes after the first.
        rows.append([x0, y0, x1, y1])
        if draw:
            cv2.rectangle(image, (x0, y0), (x1, y1), (255, 255, 0), thickness = 4)
    boxes = torch.tensor(rows) if rows else torch.tensor([])
    return boxes, score

In [None]:
# Standard evaluation, no masking
def draw_caption(image, box, caption):
    """Write ``caption`` 10 px above the top-left corner of ``box`` on ``image``.

    The text is drawn twice, a thick black pass under a thin white one, so
    that it has an outline.
    """
    corner = np.array(box).astype(int)
    origin = (corner[0], corner[1] - 10)
    for colour, weight in (((0, 0, 0), 2), ((255, 255, 255), 1)):
        cv2.putText(image, caption, origin, cv2.FONT_HERSHEY_PLAIN, 1, colour, weight)
# Evaluation on the whole validation set with unmasked images. For every image:
#   1. search for a confident pseudo-label map and record the image index in
#      `ids` when at least one class yields one;
#   2. run plain Grad-CAM and accumulate per-class TP / FP / TN counts against
#      the class of the first annotated box;
#   3. build the optional visualisations (the cv2.imwrite calls are commented
#      out, so nothing is written to disk unless they are re-enabled).
# `threshold`, `ids` and the counters stay in the kernel for later cells.
threshold= .5
gt, tpa, fpa, tna = np.zeros(3),np.zeros(3),np.zeros(3),np.zeros(3)
ids = []
class_names = ("torose lesion", "vascular lesion", "ulcer")
for idx, data in enumerate(dataloader_val):
    print(f'{idx}/{len(dataloader_val)}', end = '\n')

    # 1. pseudo-label search
    maps = high_conf_maps(grad_cam, data["img"].clone())
    if any(not isinstance(m, numbers.Number) for m in maps):
        ids.append(idx)

    # 2. classification metrics from the batch-level Grad-CAM labels
    image = data["img"].clone()
    masks = grad_cam(image, None)
    target = np.zeros(3)
    if len(data["bboxes"][0]) != 0:
        target[int(data["labels"][0][0])] = 1
    output = masks[0]

    gt += target
    tpa += output * target
    fpa += output * (1 - target)
    tna += (1 - output) * (1 - target)

    # 3. visualisation
    masks = masks[1]
    img = unnormalize(data)['img'].copy()
    img[img<0] = 0
    img[img>255] = 255
    flag = 0
    for num, mask in enumerate(masks):
        if isinstance(mask,numbers.Number):
            # A plain 0 marks a class predicted negative.
            flag += 1
            if flag == 3:
                print("Normal")
            continue
        print("----------------------------------------")
        print("Grad-CAM")
        if num < len(class_names):
            print(class_names[num])
        # mask is indexed as (1, H, W).
        # NOTE(review): heatmap2box draws on a private copy and its return
        # value is discarded here, so these calls do not change img / grad1.
        heatmap2box(mask, img=img, threshold = threshold)
        mask_thresh = thresh(mask, threshold = threshold)
        grad1 = cam_on_image(img, mask_thresh[0])
        heatmap2box(mask, img=grad1, threshold = threshold)
        grad1 = cv2.cvtColor(grad1, cv2.COLOR_BGR2RGB)
        #cv2.imwrite(f"/data/unagi0/masaoka/w_o_mask_grad_cam/{idx}_{num}.png", np.uint8(255*grad1))

    # Ground-truth overlay; label stays 3 when the image has no annotation.
    label = 3
    for i in range(len(data["bboxes"][0])):
        x1 = int(data["bboxes"][0][i][0])
        y1 = int(data["bboxes"][0][i][1])
        x2 = int(data["bboxes"][0][i][2])
        y2 = int(data["bboxes"][0][i][3])
        label_name = dataset_val.labels[int(data["labels"][0][i])]
        draw_caption(img, (x1, y1, x2, y2), label_name)
        if label_name == "ulcer":
            cv2.rectangle(img, (x1, y1), (x2, y2), color=(0, 255, 0), thickness=2) # green
        else:
            cv2.rectangle(img, (x1, y1), (x2, y2), color=(255, 0, 0), thickness=2) # red
        label = int(data["labels"][0][i])

    img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
    #cv2.imwrite(f"/data/unagi0/masaoka/w_o_mask_normal/{idx}_{label}.png", img)

# Mode 'w' starts a fresh metrics file for this validation split.
path = f'/data/unagi0/masaoka/wsod/text/val{val}.txt'
with open(path, mode='w') as f:
    f.write('w/o masks, searching confidential pseudo labels\n')
    f.write(f'id for pseudo: {ids}\n')
    f.write(f'TP: {tpa}, FN: {gt-tpa}, FP: {fpa},TN: {tna}\n')
    f.write(f'Precision: {tpa/(tpa+fpa+1e-10)}, Recall: {tpa/(gt+1e-10)}, Spec: {tna/(tna+fpa+1e-10)}\n')

0/32260
tensor(2.8307e-05, device='cuda:0', dtype=torch.float64,
       grad_fn=<DivBackward0>)
conf:0.13476060342641108
----------------------------------------
Grad-CAM
torose lesion
1/32260
tensor(2.7425e-05, device='cuda:0', dtype=torch.float64,
       grad_fn=<DivBackward0>)
conf:0.13909619480321816
----------------------------------------
Grad-CAM
torose lesion
2/32260
tensor(3.2993e-05, device='cuda:0', dtype=torch.float64,
       grad_fn=<DivBackward0>)
conf:0.11562079450270218
----------------------------------------
Grad-CAM
torose lesion
3/32260
tensor(5.1678e-05, device='cuda:0', dtype=torch.float64,
       grad_fn=<DivBackward0>)
conf:0.0738167385935206
----------------------------------------
Grad-CAM
torose lesion
4/32260
tensor(5.1572e-05, device='cuda:0', dtype=torch.float64,
       grad_fn=<DivBackward0>)
conf:0.07396879175131293
----------------------------------------
Grad-CAM
torose lesion
5/32260
tensor(5.0275e-05, device='cuda:0', dtype=torch.float64,
       grad

In [None]:
# Per-class confusion counts: TP, FN (= GT - TP), FP, TN.
print(tpa,gt-tpa,fpa,tna)
# Recall and specificity. The small epsilon (the same one used when the metrics
# are written to file) avoids 0/0 = nan for a class that never occurs.
print(tpa/(gt+1e-10), tna/(tna+fpa+1e-10))

In [None]:
# Standard evaluation (no masking) restricted to the images whose index is in
# `ids`, i.e. the images for which a confident pseudo-label map was found.
# Relies on `ids` having been filled by an earlier cell.

gt, tpa, fpa, tna = np.zeros(3),np.zeros(3),np.zeros(3),np.zeros(3)
# Set for O(1) membership tests; guard so an empty `ids` skips the loop
# instead of raising IndexError on ids[-1].
pseudo_ids = set(ids)
last_id = ids[-1] if ids else -1
class_names = ("torose lesion", "vascular lesion", "ulcer")
for idx, data in enumerate(dataloader_val):
    if idx > last_id:
        break
    if idx not in pseudo_ids:
        continue
    print(f'{idx}/{len(dataloader_val)}', end = '\r')

    # Classification metrics from the batch-level Grad-CAM labels.
    image = data["img"].clone()
    masks = grad_cam(image, None)
    target = np.zeros(3)
    if len(data["bboxes"][0]) != 0:
        target[int(data["labels"][0][0])] = 1
    output = masks[0]

    gt += target
    tpa += output * target
    fpa += output * (1 - target)
    tna += (1 - output) * (1 - target)

    masks = masks[1]
    img = unnormalize(data)['img'].copy()
    img[img<0] = 0
    img[img>255] = 255
    for num, mask in enumerate(masks):
        if isinstance(mask,numbers.Number):
            # A plain 0 marks a class predicted negative.
            continue
        print("----------------------------------------")
        print("Grad-CAM")
        if num < len(class_names):
            print(class_names[num])
        #cv2.imwrite(f"/data/unagi0/masaoka/w_o_mask_grad_cam/{idx}_{num}.png", np.uint8(255*grad1))

    # Ground-truth overlay; label stays 3 when the image has no annotation.
    label = 3
    for i in range(len(data["bboxes"][0])):
        x1 = int(data["bboxes"][0][i][0])
        y1 = int(data["bboxes"][0][i][1])
        x2 = int(data["bboxes"][0][i][2])
        y2 = int(data["bboxes"][0][i][3])
        label_name = dataset_val.labels[int(data["labels"][0][i])]
        draw_caption(img, (x1, y1, x2, y2), label_name)
        if label_name == "ulcer":
            cv2.rectangle(img, (x1, y1), (x2, y2), color=(0, 255, 0), thickness=2) # green
        else:
            cv2.rectangle(img, (x1, y1), (x2, y2), color=(255, 0, 0), thickness=2) # red
        label = int(data["labels"][0][i])

    img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
    #cv2.imwrite(f"/data/unagi0/masaoka/w_o_mask_normal/{idx}_{label}.png", img)

print(tpa,gt-tpa,fpa,tna)
# Epsilon avoids 0/0 = nan for classes that never occur, as in the file output.
print(tpa/(gt+1e-10), tna/(tna+fpa+1e-10))
# Mode 'a' appends to the metrics file of this validation split.
path = f'/data/unagi0/masaoka/wsod/text/val{val}.txt'
with open(path, mode='a') as f:
    f.write('only pseudo labels, evaluating\n')
    f.write(f'TP: {tpa},FN: {gt-tpa},FP: {fpa},TN: {tna}\n')
    f.write(f'Precision: {tpa/(tpa+fpa+1e-10)}, Recall: {tpa/(gt+1e-10)}, Spec: {tna/(tna+fpa+1e-10)}\n')

In [None]:
# Per-class confusion counts: TP, FN (= GT - TP), FP, TN.
print(tpa,gt-tpa,fpa,tna)
# Recall and specificity. The small epsilon (the same one used when the metrics
# are written to file) avoids 0/0 = nan for a class that never occurs.
print(tpa/(gt+1e-10), tna/(tna+fpa+1e-10))

In [None]:
# Evaluation with the lesion masked out: every annotated box is replaced by a
# blurred version of what it covers; an image without annotation gets a mask
# drawn at random from the lesion masks collected so far.
# Relies on `threshold` having been set by an earlier cell.
gt, tpa, fpa, tna = np.zeros(3),np.zeros(3),np.zeros(3),np.zeros(3)
all_masks = []
size = config["inputsize"]
# Dedicated seeded generator so the random masks are the same on every run.
mask_rng = random.Random(0)
class_names = ("torose lesion", "vascular lesion", "ulcer")
for idx, data in enumerate(dataloader_val):
    image = data["img"].clone()
    print(f'{idx}/{len(dataloader_val)}', end = '\r')

    # mask is 1 where the image is kept and 0 inside every annotated box.
    mask = torch.ones((size,size))
    has_lesion = len(data["bboxes"][0]) != 0
    for i in range(len(data["bboxes"][0])):
        x1 = int(data["bboxes"][0][i][0])
        y1 = int(data["bboxes"][0][i][1])
        x2 = int(data["bboxes"][0][i][2])
        y2 = int(data["bboxes"][0][i][3])
        mask[y1:y2,x1:x2] = 0
        # Only boxes smaller than a quarter of the image feed the random pool.
        if (y2-y1)*(x2-x1) < size*size/4:
            all_masks.append(mask)
    if not has_lesion and all_masks:
        mask = mask_rng.choice(all_masks)
    # If the pool is still empty the all-ones mask is kept and the image is
    # evaluated unchanged, instead of random.choice raising IndexError.
    inv_mask = 1 - mask

    # Blur only along the two spatial axes: a scalar sigma would also smear
    # across the leading (batch / channel) axes. gaussian_filter returns an
    # ndarray, so convert back before mixing it with tensors.
    hidden = (image*inv_mask).numpy()
    sigma = [0] * (hidden.ndim - 2) + [10, 10]
    blur = torch.from_numpy(gaussian_filter(hidden, sigma))
    image = image*mask + blur

    masks = grad_cam(image, None)
    target = np.zeros(3)
    if len(data["bboxes"][0]) != 0:
        target[int(data["labels"][0][0])] = 1
    output = masks[0]

    gt += target
    tpa += output * target
    fpa += output * (1 - target)
    tna += (1 - output) * (1 - target)

    masks = masks[1]
    img = unnormalize(data)['img'].copy()
    img[img<0] = 0
    img[img>255] = 255
    for num, cam_mask in enumerate(masks):
        if isinstance(cam_mask,numbers.Number):
            # A plain 0 marks a class predicted negative.
            continue
        if num < len(class_names):
            print(class_names[num])
        # cam_mask is indexed as (1, H, W).
        mask_thresh = thresh(cam_mask, threshold = threshold)
        grad1 = cam_on_image(img, mask_thresh[0])
        heatmap2box(cam_mask, img=grad1, threshold = threshold)
        grad1 = cv2.cvtColor(grad1, cv2.COLOR_BGR2RGB)
        #cv2.imwrite(f"/data/unagi0/masaoka/w_mask_grad_cam/{idx}_{num}.png", np.uint8(255*grad1))

    # Ground-truth overlay; label stays 3 when the image has no annotation.
    label = 3
    for i in range(len(data["bboxes"][0])):
        x1 = int(data["bboxes"][0][i][0])
        y1 = int(data["bboxes"][0][i][1])
        x2 = int(data["bboxes"][0][i][2])
        y2 = int(data["bboxes"][0][i][3])
        label_name = dataset_val.labels[int(data["labels"][0][i])]
        draw_caption(img, (x1, y1, x2, y2), label_name)
        if label_name == "ulcer":
            cv2.rectangle(img, (x1, y1), (x2, y2), color=(0, 255, 0), thickness=2) # green
        else:
            cv2.rectangle(img, (x1, y1), (x2, y2), color=(255, 0, 0), thickness=2) # red
        label = int(data["labels"][0][i])

    img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
    #cv2.imwrite(f"/data/unagi0/masaoka/w_mask_normal/{idx}_{label}.png", img)

print(tpa,gt-tpa,fpa,tna)
# Mode 'a' appends to the metrics file of this validation split.
path = f'/data/unagi0/masaoka/wsod/text/val{val}.txt'
with open(path, mode='a') as f:
    f.write('all images, masking anomaly\n')
    f.write(f'TP: {tpa},FN: {gt-tpa},FP: {fpa},TN: {tna}\n')
    f.write(f'Precision: {tpa/(tpa+fpa+1e-10)}, Recall: {tpa/(gt+1e-10)}, Spec: {tna/(tna+fpa+1e-10)}\n')

In [None]:
# Evaluation with the lesion masked out, restricted to the images whose index
# is in `ids` (the images for which a confident pseudo-label map was found).
# Every annotated box is replaced by a blurred version of what it covers; an
# image without annotation gets a mask drawn at random from the lesion masks
# collected so far. Relies on `ids` and `threshold` from earlier cells.
gt, tpa, fpa, tna = np.zeros(3),np.zeros(3),np.zeros(3),np.zeros(3)
all_masks = []
size = config["inputsize"]
# Dedicated seeded generator so the random masks are the same on every run.
mask_rng = random.Random(0)
# Set for O(1) membership tests; guard so an empty `ids` skips the loop
# instead of raising IndexError on ids[-1].
pseudo_ids = set(ids)
last_id = ids[-1] if ids else -1
class_names = ("torose lesion", "vascular lesion", "ulcer")
for idx, data in enumerate(dataloader_val):
    if idx > last_id:
        break
    if idx not in pseudo_ids:
        continue
    image = data["img"].clone()
    print(f'{idx}/{len(dataloader_val)}', end = '\r')

    # mask is 1 where the image is kept and 0 inside every annotated box.
    mask = torch.ones((size,size))
    has_lesion = len(data["bboxes"][0]) != 0
    for i in range(len(data["bboxes"][0])):
        x1 = int(data["bboxes"][0][i][0])
        y1 = int(data["bboxes"][0][i][1])
        x2 = int(data["bboxes"][0][i][2])
        y2 = int(data["bboxes"][0][i][3])
        mask[y1:y2,x1:x2] = 0
        # Only boxes smaller than a quarter of the image feed the random pool.
        if (y2-y1)*(x2-x1) < size*size/4:
            all_masks.append(mask)
    if not has_lesion and all_masks:
        mask = mask_rng.choice(all_masks)
    # If the pool is still empty the all-ones mask is kept and the image is
    # evaluated unchanged, instead of random.choice raising IndexError.
    inv_mask = 1 - mask

    # Blur only along the two spatial axes: a scalar sigma would also smear
    # across the leading (batch / channel) axes. gaussian_filter returns an
    # ndarray, so convert back before mixing it with tensors.
    hidden = (image*inv_mask).numpy()
    sigma = [0] * (hidden.ndim - 2) + [10, 10]
    blur = torch.from_numpy(gaussian_filter(hidden, sigma))
    image = image*mask + blur

    masks = grad_cam(image, None)
    target = np.zeros(3)
    if len(data["bboxes"][0]) != 0:
        target[int(data["labels"][0][0])] = 1
    output = masks[0]

    gt += target
    tpa += output * target
    fpa += output * (1 - target)
    tna += (1 - output) * (1 - target)

    masks = masks[1]
    img = unnormalize(data)['img'].copy()
    img[img<0] = 0
    img[img>255] = 255
    for num, cam_mask in enumerate(masks):
        if isinstance(cam_mask,numbers.Number):
            # A plain 0 marks a class predicted negative.
            continue
        if num < len(class_names):
            print(class_names[num])
        # cam_mask is indexed as (1, H, W).
        mask_thresh = thresh(cam_mask, threshold = threshold)
        grad1 = cam_on_image(img, mask_thresh[0])
        heatmap2box(cam_mask, img=grad1, threshold = threshold)
        grad1 = cv2.cvtColor(grad1, cv2.COLOR_BGR2RGB)
        #cv2.imwrite(f"/data/unagi0/masaoka/w_mask_grad_cam/{idx}_{num}.png", np.uint8(255*grad1))

    # Ground-truth overlay; label stays 3 when the image has no annotation.
    label = 3
    for i in range(len(data["bboxes"][0])):
        x1 = int(data["bboxes"][0][i][0])
        y1 = int(data["bboxes"][0][i][1])
        x2 = int(data["bboxes"][0][i][2])
        y2 = int(data["bboxes"][0][i][3])
        label_name = dataset_val.labels[int(data["labels"][0][i])]
        draw_caption(img, (x1, y1, x2, y2), label_name)
        if label_name == "ulcer":
            cv2.rectangle(img, (x1, y1), (x2, y2), color=(0, 255, 0), thickness=2) # green
        else:
            cv2.rectangle(img, (x1, y1), (x2, y2), color=(255, 0, 0), thickness=2) # red
        label = int(data["labels"][0][i])

    img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
    #cv2.imwrite(f"/data/unagi0/masaoka/w_mask_normal/{idx}_{label}.png", img)

print(tpa,gt-tpa,fpa,tna)
# Mode 'a' appends to the metrics file of this validation split.
path = f'/data/unagi0/masaoka/wsod/text/val{val}.txt'
with open(path, mode='a') as f:
    f.write('only confidential images, masking anomaly\n')
    f.write(f'TP: {tpa},FN: {gt-tpa},FP: {fpa},TN: {tna}\n')
    f.write(f'Precision: {tpa/(tpa+fpa+1e-10)}, Recall: {tpa/(gt+1e-10)}, Spec: {tna/(tna+fpa+1e-10)}\n')

In [None]:
# Emphasise the lesion: keep the annotated boxes sharp and Gaussian-blur the
# rest of the frame. Normal images (no boxes) borrow a lesion mask picked at
# random from the masks collected on earlier images.
#
# Accumulates into gt / tpa / fpa / tna (3-element arrays, one slot per lesion
# class) and appends a summary to the report file of this validation split.

gt, tpa, fpa, tna = np.zeros(3),np.zeros(3),np.zeros(3),np.zeros(3)
all_masks = []  # pool of lesion masks that normal images can reuse
for idx, data in enumerate(dataloader_val):
    image = data["img"].clone()
    print(f'{idx}/{len(dataloader_val)}', end = '\n')
    gt_boxes = data["bboxes"][0]
    gt_labels = data["labels"][0]

    # Build the "keep sharp" mask from every annotated box of this image.
    mask = torch.zeros((size, size))
    for i in range(len(gt_boxes)):
        x1, y1, x2, y2 = (int(v) for v in gt_boxes[i][:4])
        mask[y1:y2, x1:x2] = 1
        # Only boxes smaller than a quarter of the frame feed the reuse pool.
        # NOTE(review): the same tensor object is appended once per qualifying
        # box and keeps receiving later boxes of this image, so the pool holds
        # the cumulative mask, possibly several times - confirm this is intended.
        if (y2 - y1) * (x2 - x1) < size * size / 4:
            all_masks.append(mask)

    if len(gt_boxes) == 0:
        if all_masks:
            mask = random.choice(all_masks)
        else:
            # No lesion mask seen yet: random.choice([]) would raise IndexError.
            # Fall back to keeping the whole frame, i.e. the image is evaluated
            # unmodified (the blurred complement below is all zeros).
            mask = torch.ones((size, size))
    inv_mask = 1 - mask

    # NOTE(review): a scalar sigma is applied along every axis of the input, so
    # any batch/channel axes of `image` are blurred as well - confirm intended.
    blur = gaussian_filter(image*inv_mask,10)
    image = image*mask + blur

    masks = grad_cam(image, None)
    target = np.zeros(3)
    if len(gt_boxes) != 0:
        # Only the label of the first box defines the image-level target.
        target[int(gt_labels[0])] = 1
    output = masks[0]  # image-level prediction, multiplied element-wise with target

    gt += target
    tpa += output * target
    fpa += output * (1 - target)
    tna += (1 - output) * (1 - target)

    # ---- visualisation (results are only persisted if the imwrite lines are enabled) ----
    masks = masks[1]  # one CAM entry per class
    img = unnormalize(data)['img'].copy()
    img[img<0] = 0
    img[img>255] = 255
    for num, class_cam in enumerate(masks):
        # Entries that are plain numbers carry no heat map and are skipped
        # (presumably classes that were not predicted - TODO confirm).
        if isinstance(class_cam, numbers.Number):
            continue
        mask_thresh = thresh(class_cam, threshold = threshold)
        grad1 = cam_on_image(img, mask_thresh[0])
        heatmap2box(class_cam, img=grad1, threshold = threshold)
        grad1 = cv2.cvtColor(grad1, cv2.COLOR_BGR2RGB)
        #cv2.imwrite(f"/data/unagi0/masaoka/tumor_only_grad_cam/{idx}_{num}.png", np.uint8(255*grad1))

    # Draw the ground-truth boxes; `label` stays 3 when the image has no box
    # (presumably the "Normal" class index - TODO confirm).
    label = 3
    for i in range(len(gt_boxes)):
        x1, y1, x2, y2 = (int(v) for v in gt_boxes[i][:4])
        label = int(gt_labels[i])
        label_name = dataset_val.labels[label]
        draw_caption(img, (x1, y1, x2, y2), label_name)
        if label_name == "ulcer":
            box_color = (0, 255, 0)  # green
        else:
            box_color = (255, 0, 0)  # red
        cv2.rectangle(img, (x1, y1), (x2, y2), color=box_color, thickness=2)

    img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
    #cv2.imwrite(f"/data/unagi0/masaoka/tumor_only_normal/{idx}_{label}.png", img)

# Report TP, FN (= gt - TP), FP, TN per class; the 1e-10 terms keep the ratios
# finite when a denominator is zero.
print(tpa,gt-tpa,fpa,tna)
path = f'/data/unagi0/masaoka/wsod/text/val{val}.txt'
with open(path, mode='a') as f:
    f.write('all images, emphasizing anomaly\n')
    f.write(f'TP: {tpa},FN: {gt-tpa},FP: {fpa},TN: {tna}\n')
    f.write(f'Precision: {tpa/(tpa+fpa+1e-10)}, Recall: {tpa/(gt+1e-10)}, Spec: {tna/(tna+fpa+1e-10)}\n')

In [None]:
# Emphasise the lesion: keep the annotated boxes sharp and Gaussian-blur the
# rest of the frame. Normal images (no boxes) borrow a lesion mask picked at
# random from the masks collected on earlier images.
# Only the samples whose index is listed in `ids` are evaluated in this cell.
#
# Accumulates into gt / tpa / fpa / tna (3-element arrays, one slot per lesion
# class) and appends a summary to the report file of this validation split.

gt, tpa, fpa, tna = np.zeros(3),np.zeros(3),np.zeros(3),np.zeros(3)
all_masks = []  # pool of lesion masks that normal images can reuse
for idx, data in enumerate(dataloader_val):
    if idx not in ids:
        continue
    image = data["img"].clone()
    print(f'{idx}/{len(dataloader_val)}', end = '\n')
    gt_boxes = data["bboxes"][0]
    gt_labels = data["labels"][0]

    # Build the "keep sharp" mask from every annotated box of this image.
    mask = torch.zeros((size, size))
    for i in range(len(gt_boxes)):
        x1, y1, x2, y2 = (int(v) for v in gt_boxes[i][:4])
        mask[y1:y2, x1:x2] = 1
        # Only boxes smaller than a quarter of the frame feed the reuse pool.
        # NOTE(review): the same tensor object is appended once per qualifying
        # box and keeps receiving later boxes of this image, so the pool holds
        # the cumulative mask, possibly several times - confirm this is intended.
        if (y2 - y1) * (x2 - x1) < size * size / 4:
            all_masks.append(mask)

    if len(gt_boxes) == 0:
        if all_masks:
            mask = random.choice(all_masks)
        else:
            # No lesion mask seen yet (the `ids` filter may have skipped every
            # lesion image so far): random.choice([]) would raise IndexError.
            # Fall back to keeping the whole frame, i.e. the image is evaluated
            # unmodified (the blurred complement below is all zeros).
            mask = torch.ones((size, size))
    inv_mask = 1 - mask

    # NOTE(review): a scalar sigma is applied along every axis of the input, so
    # any batch/channel axes of `image` are blurred as well - confirm intended.
    blur = gaussian_filter(image*inv_mask,10)
    image = image*mask + blur

    masks = grad_cam(image, None)
    target = np.zeros(3)
    if len(gt_boxes) != 0:
        # Only the label of the first box defines the image-level target.
        target[int(gt_labels[0])] = 1
    output = masks[0]  # image-level prediction, multiplied element-wise with target

    gt += target
    tpa += output * target
    fpa += output * (1 - target)
    tna += (1 - output) * (1 - target)

    # ---- visualisation (results are only persisted if the imwrite lines are enabled) ----
    masks = masks[1]  # one CAM entry per class
    img = unnormalize(data)['img'].copy()
    img[img<0] = 0
    img[img>255] = 255
    for num, class_cam in enumerate(masks):
        # Entries that are plain numbers carry no heat map and are skipped
        # (presumably classes that were not predicted - TODO confirm).
        if isinstance(class_cam, numbers.Number):
            continue
        mask_thresh = thresh(class_cam, threshold = threshold)
        grad1 = cam_on_image(img, mask_thresh[0])
        heatmap2box(class_cam, img=grad1, threshold = threshold)
        grad1 = cv2.cvtColor(grad1, cv2.COLOR_BGR2RGB)
        #cv2.imwrite(f"/data/unagi0/masaoka/tumor_only_grad_cam/{idx}_{num}.png", np.uint8(255*grad1))

    # Draw the ground-truth boxes; `label` stays 3 when the image has no box
    # (presumably the "Normal" class index - TODO confirm).
    label = 3
    for i in range(len(gt_boxes)):
        x1, y1, x2, y2 = (int(v) for v in gt_boxes[i][:4])
        label = int(gt_labels[i])
        label_name = dataset_val.labels[label]
        draw_caption(img, (x1, y1, x2, y2), label_name)
        if label_name == "ulcer":
            box_color = (0, 255, 0)  # green
        else:
            box_color = (255, 0, 0)  # red
        cv2.rectangle(img, (x1, y1), (x2, y2), color=box_color, thickness=2)

    img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
    #cv2.imwrite(f"/data/unagi0/masaoka/tumor_only_normal/{idx}_{label}.png", img)

# Report TP, FN (= gt - TP), FP, TN per class; the 1e-10 terms keep the ratios
# finite when a denominator is zero.
print(tpa,gt-tpa,fpa,tna)
path = f'/data/unagi0/masaoka/wsod/text/val{val}.txt'
with open(path, mode='a') as f:
    f.write('only confidential images, emphasizing anomaly\n')
    f.write(f'TP: {tpa},FN: {gt-tpa},FP: {fpa},TN: {tna}\n')
    f.write(f'Precision: {tpa/(tpa+fpa+1e-10)}, Recall: {tpa/(gt+1e-10)}, Spec: {tna/(tna+fpa+1e-10)}\n')