In [1]:
import sys,os,time

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

import scipy.ndimage
from skimage.io import imread, imsave
from skimage.transform import rotate

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader, IterableDataset

from utils import LARGE_CHIP_SIZE, CHIP_SIZE,CROP_POINT, NUM_WORKERS,MixedLoss, joint_transform, mixed_loss, get_mask
from tqdm import tqdm

from dataloader import AirbusShipPatchDataset, AirbusShipDataset
from streaming_dataloader import StreamingShipDataset, StreamingShipValTestDataset
import joblib

import rasterio
import fiona
import shapely.geometry
import cv2
import rasterio.features
from PIL import Image
import segmentation_models_pytorch as smp
import glob

In [2]:
# helper function for data visualization
def visualize(**images):
    """Show the given images side by side in a single row.

    Every keyword argument becomes one panel: the keyword, with underscores
    replaced by spaces and title-cased, is used as the panel title, and the
    value is handed to ``plt.imshow``. Axis ticks are hidden on every panel.
    """
    panel_count = len(images)
    plt.figure(figsize=(16, 5))
    for position, label in enumerate(images, start=1):
        plt.subplot(1, panel_count, position)
        # Hide tick marks; only the picture itself is of interest.
        plt.xticks([])
        plt.yticks([])
        plt.title(label.replace('_', ' ').title())
        plt.imshow(images[label])
    plt.show()

In [3]:
# Encoder name and pretrained-weight set; used here only to look up the
# matching preprocessing function.
ENCODER = 'resnet34'
ENCODER_WEIGHTS = 'imagenet'

# Preprocessing callable for the encoder/weights pair above. It is passed as
# `transform` to ShipTestDataset when the test datasets are built.
# NOTE(review): presumably this applies the input normalisation the pretrained
# encoder expects -- confirm against the segmentation_models_pytorch docs.
preprocessing_fn = smp.encoders.get_preprocessing_fn(ENCODER, ENCODER_WEIGHTS)
# preprocessing_fn = None

# #  Test Loader

# streaming_test_dataset = StreamingShipValTestDataset("./data/test_df.csv", "./data/train_v2/", 
#     large_chip_size=LARGE_CHIP_SIZE, chip_size=CHIP_SIZE, transform=joint_transform, preprocessing_fn=preprocessing_fn,
#     rotation_augmentation=False, only_ships=True)

# streaming_test_aug_dataset = StreamingShipValTestDataset("./data/test_df.csv", "./data/train_v2/", 
#     large_chip_size=LARGE_CHIP_SIZE, chip_size=CHIP_SIZE, transform=joint_transform, preprocessing_fn=preprocessing_fn,
#     rotation_augmentation=True, only_ships=True)

# test_loader = DataLoader(dataset=streaming_test_dataset, batch_size = 1, num_workers=1)

# test_aug_loader = DataLoader(dataset=streaming_test_aug_dataset, batch_size = 1, num_workers=1)

In [4]:
# Use the first GPU when CUDA is available, otherwise fall back to the CPU so
# the notebook can still be executed on a machine without a GPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# NOTE(security): each checkpoint is a whole pickled model object and
# torch.load unpickles it -- only load files from a trusted source.
# NOTE(review): recent PyTorch releases default to weights_only=True, which
# rejects whole-model pickles; confirm the installed version before upgrading.
# map_location remaps the stored tensors onto `device`, so a checkpoint saved
# on a different GPU index (or on a GPU at all) still loads here.
aug_model = torch.load('./best_model_aug.pth', map_location=device)
aug_model = aug_model.to(device)

non_aug_model = torch.load('./best_model_non_aug.pth', map_location=device)
non_aug_model = non_aug_model.to(device)

In [5]:
# Third model under comparison, loaded from './best_model_aug_np.pth'.
# map_location remaps the stored tensors onto `device` so the checkpoint loads
# even if it was saved from a different GPU index.
# NOTE(security): torch.load unpickles the file -- only load trusted checkpoints.
np_aug_model = torch.load('./best_model_aug_np.pth', map_location=device)
np_aug_model = np_aug_model.to(device)

In [6]:
# Switch all three loaded networks to evaluation mode.
for _net in (aug_model, non_aug_model, np_aug_model):
    _net.eval()

# Loss reported during evaluation. The explicit `__name__` is the label that
# shows up next to the loss value in the recorded progress-bar output.
# NOTE(review): the meaning of the two MixedLoss arguments is defined in
# utils.py, which is not visible here -- confirm before changing them.
loss = MixedLoss(10.0, 2.0)
setattr(loss, "__name__", "MixedLoss")

# Single evaluation metric: IoU with predictions thresholded at 0.5.
metrics = [smp.utils.metrics.IoU(threshold=0.5)]

In [7]:
class ShipTestDataset(Dataset):
    """Image/mask pairs read from ``<file_path>/img`` and ``<file_path>/mask``.

    Every file found in the ``img`` sub-directory is one sample. Its mask is
    looked up in the ``mask`` sub-directory under the same file name, with a
    ``.jpg`` extension swapped for ``.png``.

    Args:
        file_path: Root directory of the test set; a trailing path separator
            is optional.
        transform: Optional callable applied to the raw image array. When it
            is ``None`` the image is divided by 255.0 instead.

    Each item is a tuple ``(image, mask)``: ``image`` is a float32 tensor with
    the last axis of the loaded array moved to the front, and ``mask`` is an
    int64 tensor with an extra leading dimension.
    """

    def __init__(self, file_path, transform=None):
        self.file_path = file_path
        # os.path.join works with or without a trailing separator on
        # file_path; sorting makes the sample order deterministic, which
        # glob alone does not guarantee.
        self.image_fns = sorted(glob.glob(os.path.join(self.file_path, "img", "*")))
        self.transform = transform

    def __len__(self):
        return len(self.image_fns)

    def _mask_path(self, image_path):
        """Return the path of the mask file that belongs to ``image_path``."""
        fn = os.path.basename(image_path)
        stem, ext = os.path.splitext(fn)
        # Swap only the extension: a plain str.replace would also rewrite a
        # "jpg" that happens to occur inside the file name itself.
        if ext == ".jpg":
            fn = stem + ".png"
        return os.path.join(self.file_path, "mask", fn)

    def __getitem__(self, idx):
        image_path = self.image_fns[idx]

        # Read image and its mask
        img = imread(image_path)
        mask = imread(self._mask_path(image_path))

        if self.transform is not None:
            img = self.transform(img)
        else:
            img = img / 255.0

        # Move the last axis to the front (assumes an H x W x C image array).
        p_img = np.rollaxis(img, 2, 0).astype(np.float32)
        p_img = torch.from_numpy(p_img).squeeze()

        # Add a leading dimension to the mask.
        # NOTE(review): assumes the mask file is single-channel -- confirm.
        p_mask = mask.astype(np.int64)
        p_mask = torch.from_numpy(p_mask).unsqueeze(0)

        return p_img, p_mask

In [20]:
# Keyword arguments shared by all three evaluation runners below.
_valid_epoch_kwargs = dict(loss=loss, metrics=metrics, device=device, verbose=True)

# One ValidEpoch runner per model under comparison.
test_epoch_aug_Unet = smp.utils.train.ValidEpoch(aug_model, **_valid_epoch_kwargs)
test_epoch_non_aug_Unet = smp.utils.train.ValidEpoch(non_aug_model, **_valid_epoch_kwargs)
test_epoch_aug_np_Unet = smp.utils.train.ValidEpoch(np_aug_model, **_valid_epoch_kwargs)

In [34]:
# Evaluate `aug_model` on the rotation-augmented test loader.
# NOTE(review): `aug_test_loader_pl` is only created in a cell further down
# this notebook, so this cell fails on Restart & Run All -- move the loader
# cell above this one.
test_epoch_aug_Unet.run(aug_test_loader_pl)

valid:  21%|██        | 145899/687960 [28:57<1:47:34, 83.98it/s, MixedLoss - 10.64, iou_score - 0.5149] 


KeyboardInterrupt: 

In [None]:
# Evaluate `aug_model` on the loader built from './data/test_set/'.
# NOTE(review): `test_loader_pl` is only created in a cell further down this
# notebook, so this cell fails on Restart & Run All -- move the loader cell
# above this one.
test_epoch_aug_Unet.run(test_loader_pl)

# Iterative

In [25]:
def IoU(pred, targs):
    """Smoothed intersection-over-union of a prediction against its target.

    ``pred`` is binarised first (every value above 0 counts as foreground);
    ``targs`` is used as given. A constant 1.0 is added to the denominator,
    so the result is well defined when both inputs are empty and a perfect
    match scores slightly below 1.
    """
    binary_pred = pred.gt(0).float()
    overlap = (binary_pred * targs).sum()
    # |A| + |B| - |A intersect B| is the size of the union.
    union = (binary_pred + targs).sum() - overlap
    return overlap / (union + 1.0)

In [48]:
# Loader settings shared by both test sets: one image per batch, 8 workers.
_loader_kwargs = dict(batch_size=1, num_workers=8)

# Rotation-augmented test set.
aug_test_ds = ShipTestDataset('./data/test_set_rotation_aug/', transform=preprocessing_fn)
aug_test_loader_pl = DataLoader(aug_test_ds, **_loader_kwargs)

# Plain test set.
test_ds = ShipTestDataset('./data/test_set/', transform=preprocessing_fn)
test_loader_pl = DataLoader(test_ds, **_loader_kwargs)

In [25]:
# Rebind `np_aug_model` to the checkpoint './aug_models_np/model_aug_9.pth'.
# map_location remaps the stored tensors onto `device` so the file loads even
# if it was saved from a different GPU index.
# NOTE(security): torch.load unpickles the file -- only load trusted checkpoints.
np_aug_model = torch.load('./aug_models_np/model_aug_9.pth', map_location=device)
# This is a new object: the .eval() call made on the previously loaded model
# does not carry over, so switch to evaluation mode again (eval() returns the
# module itself).
np_aug_model = np_aug_model.to(device).eval()

In [51]:
# Use the first GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# NOTE(security): torch.load unpickles whole model objects -- only load
# checkpoints from a trusted source.
# The original cell first loaded './aug_models/model_aug_9.pth' into
# `aug_model` and overwrote it on the very next line; that dead load is kept
# only as a commented-out alternative.
# aug_model = torch.load('./aug_models/model_aug_9.pth', map_location=device)
aug_model = torch.load('./best_model_aug.pth', map_location=device)
# Freshly loaded objects are not affected by earlier .eval() calls, so put
# them into evaluation mode here (eval() returns the module itself).
aug_model = aug_model.to(device).eval()

non_aug_model = torch.load('./best_model_non_aug.pth', map_location=device)
non_aug_model = non_aug_model.to(device).eval()

In [None]:
# Accumulate the IoU of `aug_model` over every test image whose mask is not
# all zeros; `sum_iou / count` is the mean afterwards.
sum_iou = 0
count = 0

# Make sure the model is in evaluation mode, and send each batch to whichever
# device the model's parameters actually live on.
aug_model.eval()
model_device = next(aug_model.parameters()).device

# Pure inference: no_grad avoids building an autograd graph for every batch.
with torch.no_grad():
    # Wrapping the loader (not the enumerate) lets tqdm show the total.
    for i, (img, mask) in enumerate(tqdm(aug_test_loader_pl)):
        # if i == 1500: break

        if mask.sum() != 0:
            pred = aug_model(img.to(model_device)).cpu().double()

            # Binarise the prediction at 0.5.
            # NOTE(review): assumes the model outputs probabilities rather
            # than raw logits -- confirm against how the model was built.
            pred = (pred >= 0.5).double()

            sum_iou += IoU(pred.squeeze(), mask.squeeze())
            count += 1

        # if i != 0 and i % 500 == 0:
        #     print(sum_iou / count)

        #     visualize(
        #         image=img.squeeze().permute(1,2,0),
        #         mask=mask.squeeze(),
        #         pred = pred.squeeze()
        #     )

0it [00:00, ?it/s]Exception ignored in: Exception ignored in: Exception ignored in: Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7fb167122ee0><function _MultiProcessingDataLoaderIter.__del__ at 0x7fb167122ee0>Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7fb167122ee0><function _MultiProcessingDataLoaderIter.__del__ at 0x7fb167122ee0>
<function _MultiProcessingDataLoaderIter.__del__ at 0x7fb167122ee0>


Traceback (most recent call last):
Traceback (most recent call last):

Traceback (most recent call last):
Traceback (most recent call last):
  File "/home/jason/anaconda3/envs/ai4e/lib/python3.8/site-packages/torch/utils/data/dataloader.py", line 1324, in __del__
  File "/home/jason/anaconda3/envs/ai4e/lib/python3.8/site-packages/torch/utils/data/dataloader.py", line 1324, in __del__
Exception ignored in: Traceback (most recent call last):
  File "/home/jason/anaconda3/envs/ai4e/lib/python3.8/site-packages/torch/utils/data/data

In [41]:
# Mean IoU over the evaluated images (those whose mask was not all zeros).
# NOTE(review): raises ZeroDivisionError when no image was evaluated, because
# both `sum_iou` and `count` are then still the integer 0.
sum_iou/count

tensor(0.2867)