In [1]:
import os
import cv2
import torch
import timeit
import numpy as np
import pandas as pd
from tqdm import tqdm_notebook as tqdm

from utils.generator import claim_generator
from utils.iou import compute_mask_iou
from utils.losses import Loss

from model.resnet import UNet_ResNet
from model.senet import UNet_SENet
from model.resnext import UNet_SeResnext
from model.deep_lab.deeplab import *

In [2]:
abs_path = '/home/kaichou/ssd/course'
test_path = os.path.join(abs_path, 'test')

In [3]:
test_loader = claim_generator(test_path, batch_size = 50, workers = 10, side_size=513, mode = 'test')

iou_nets = pd.DataFrame(columns = ['IoU@0.3', 'IoU@0.5', 'IoU@0.7', 'IoU@0.9', 'Dice_loss', 'Inf_Time'])

In [4]:
iou_nets = pd.read_csv('nets_test_info.csv')

In [5]:
model = DeepLab()

In [6]:
model.load_state_dict(torch.load('weights/deeplab/model/DeepLab_model129.pth', map_location = 'cpu'))

IncompatibleKeys(missing_keys=[], unexpected_keys=[])

In [7]:
device = torch.device('cuda')

In [8]:
model = model.to(device)

In [9]:
criterion = Loss(0, 1)

In [10]:
times = []

In [11]:
ious = {
    0.3: 0,
    0.5: 0,
    0.7: 0,
    0.9: 0
}

In [12]:
loss = 0
schet = 0
whole_loss = 0
with torch.no_grad():
    for x, y in tqdm(test_loader):
        schet += 1
        st = timeit.default_timer()
        x = x.to(device)
        y = y.to(device)
        mask_pred = model(x)
        times.append(timeit.default_timer() - st)
        loss = criterion(mask_pred, y)
        whole_loss += loss.item()
        for t in [0.3, 0.5, 0.7, 0.9]:
            ious[t] += compute_mask_iou(y.cpu().squeeze(1).numpy(), (torch.sigmoid(mask_pred) > t).cpu().squeeze(1).numpy().astype(np.float))

HBox(children=(IntProgress(value=0, max=44), HTML(value='')))




In [13]:
iou_nets.loc[3] = [ious[0.3] / schet, ious[0.5] / schet, ious[0.7] / schet, ious[0.9] / schet, whole_loss / schet, np.mean(times)]

In [14]:
iou_nets.to_csv('nets_test_info.csv', index = False)

In [15]:
iou_nets

Unnamed: 0,IoU@0.3,IoU@0.5,IoU@0.7,IoU@0.9,Dice_loss,Inf_Time
0,0.863049,0.867387,0.866115,0.857214,0.088416,0.023123
1,0.873832,0.874391,0.874118,0.8682,0.077424,0.131036
2,0.912831,0.913522,0.912954,0.906816,0.041356,0.767237
3,0.871474,0.872698,0.870205,0.856783,0.063681,0.08708
