In [1]:
%load_ext autoreload
%autoreload 2

import matplotlib.pyplot as plt
plt.style.use("seaborn-white")
import seaborn as sns
sns.set_style("white")
from torch import nn
from tqdm import tqdm
from toolz import compose
import datetime
import numpy as np
import pandas as pd
import torch
from torch.utils import data
import logging
import random
import os

In [2]:
from data import (
    test_images_path,
    load_images_as_arrays,
    TGSSaltDataset,
    prepare_test_data,
)
from model import model_path, predict_tta
from data import rle_encode
from resnet34_unet_hyper import UNetResNetSCSE
from config import load_config
from image_processing import upsample, downsample

In [3]:
# Configure logging at INFO level and create the logger that the following
# cells use for their progress messages.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

In [4]:
# Timestamp taken once up front; it is reused for the "Started" log line and
# for the submission file name written at the end of the notebook.
now = datetime.datetime.now()

In [5]:
# Select the "EvaluateModel" section of the project configuration. Its entries
# (seed, batch_size, img_target_size, threshold, checkpoint file name, ...) are
# what the rest of the notebook runs on.
config = load_config()["EvaluateModel"]
# Logged after loading, so the full set of values used for this run is recorded.
logger.info(f"Loading config {config}")

INFO:config:Loading config ../configs/config.json
INFO:__main__:Loading config {'base_channels': 32, 'batch_size': 128, 'id': '99e6d379-569f-4554-bf37-afa6d1d2f4e4', 'img_target_size': 128, 'initial_model_filename': 'best_lovasz_model.pth', 'name': 'EvaluateModel', 'num_workers': 0, 'seed': 42, 'threshold': 0.43332206261113065}


In [6]:
# Later cells read these config entries as bare notebook-level names. Check
# for them up front: a missing "id" in particular would not raise later on, it
# would silently resolve to the builtin `id` function and produce a bogus
# checkpoint directory.
_required_config_keys = (
    "id",
    "initial_model_filename",
    "img_target_size",
    "batch_size",
    "num_workers",
    "seed",
    "threshold",
)
_missing_config_keys = [key for key in _required_config_keys if key not in config]
if _missing_config_keys:
    raise KeyError(
        f"EvaluateModel config is missing required keys: {_missing_config_keys}"
    )

# Promote every config entry to a notebook-level variable. globals() is the
# namespace Python documents as writable; the dict returned by locals() is not
# guaranteed to propagate modifications.
# NOTE: this intentionally shadows the builtin `id` with the run id, because
# downstream cells refer to it by that name.
globals().update(config)

In [7]:
# Turn on cuDNN benchmark mode.
# NOTE(review): presumably chosen because every batch has the same input size;
# confirm this does not conflict with the reproducibility intent of the
# seeding cell that follows.
torch.backends.cudnn.benchmark = True
# Record when this evaluation run started (timestamp captured in an earlier cell).
logger.info(f"Started {now}")

INFO:__main__:Started 2018-10-10 13:24:45.985977


In [8]:
# Seed the torch, NumPy and stdlib random generators with the configured
# `seed` so repeated runs start from the same random state.
torch.manual_seed(seed)
np.random.seed(seed)
random.seed(seed)

In [9]:
# Instantiate the network with its default constructor arguments; the trained
# weights are loaded from a checkpoint in a later cell.
# NOTE(review): the config carries `base_channels`, but it is not passed here.
# Confirm the constructor default matches the checkpoint being evaluated.
model = UNetResNetSCSE()

In [10]:
# Evaluation runs on the first CUDA device.
device = torch.device("cuda:0")
# Wrap the network in DataParallel exactly once. Without the guard, re-running
# this cell would nest DataParallel inside DataParallel, adding a second
# "module." level to every parameter name so that the checkpoint's state dict
# no longer matches when it is loaded in the next cell.
if not isinstance(model, nn.DataParallel):
    model = nn.DataParallel(model)
# Move the parameters to the device; as the last expression of the cell this
# also displays the module structure.
model.to(device)

DataParallel(
  (module): UNetResNetSCSE(
    (resnet): ResNet(
      (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace)
      (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
      (layer1): Sequential(
        (0): BasicBlock(
          (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (relu): ReLU(inplace)
          (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        )
        (1): BasicBlock(
          (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (bn1): BatchNorm

In [11]:
# Checkpoints for a run live in a directory named after the run id. The id is
# read from `config` rather than through the bare name `id`, which resolves to
# the builtin function (and therefore to a bogus path instead of an error)
# whenever the config values have not been injected into the namespace.
model_dir = os.path.join(model_path(), str(config["id"]))
filename = os.path.join(model_dir, config["initial_model_filename"])
# map_location restores the tensors onto `device` regardless of which device
# they were saved from, so a checkpoint written on e.g. cuda:1 still loads.
# NOTE: torch.load unpickles the file; only load checkpoints you trust.
checkpoint = torch.load(filename, map_location=device)
# The model is already wrapped in DataParallel at this point, so the saved
# state dict is expected to carry the matching "module."-prefixed keys.
model.load_state_dict(checkpoint["state_dict"])

In [12]:
# Build the test-set frame; its index is used below to look up the image files
# and, at the end of the notebook, as the ids of the submission rows.
test_df = prepare_test_data()
# Read every test image listed in the index from the test image directory.
x_test = load_images_as_arrays(test_df.index, test_images_path())

100%|██████████| 18000/18000 [00:14<00:00, 1211.49it/s]


In [13]:
# Build the per-image resize callable that is mapped over the test images in
# the next cell.
# NOTE(review): 101 is presumably the native side length of the source images
# and img_target_size the network input size - confirm against
# image_processing.upsample.
upsample_to = upsample(101, img_target_size)

In [14]:
# Resize each test image with the callable built in the previous cell and
# stack the results into a single array.
# NOTE: this rebinds x_test, so the cell must be run exactly once per load.
x_test = np.asarray([upsample_to(image) for image in x_test])

In [15]:
# Give the stacked images an explicit singleton axis in position 1, i.e. the
# layout (-1, 1, img_target_size, img_target_size).
x_test = np.reshape(x_test, (-1, 1, img_target_size, img_target_size))

In [16]:
# Wrap the prepared image array in the project's dataset class.
# NOTE(review): is_test=True presumably tells the dataset there are no masks
# to return - confirm in data.TGSSaltDataset.
dataset_test = TGSSaltDataset(x_test, is_test=True)

In [17]:
# Batches are served in dataset order and the final partial batch is kept, so
# the concatenated predictions line up one-to-one with test_df.index.
test_data_loader = data.DataLoader(
    dataset=dataset_test,
    batch_size=batch_size,
    num_workers=num_workers,
    shuffle=False,
    drop_last=False,
    pin_memory=True,
)

In [18]:
# Switch layers such as BatchNorm to inference behaviour.
model.eval()
# Inference needs no gradients. Disabling autograd here guarantees that no
# graph is recorded for the forward passes, keeping GPU memory flat across the
# whole test set. predict_tta is defined elsewhere; if it already disables
# autograd itself, this context is a harmless no-op.
with torch.no_grad():
    predictions = [predict_tta(model, image) for image in tqdm(test_data_loader)]

100%|██████████| 141/141 [00:50<00:00,  1.02s/it]


In [19]:
# Join the per-batch outputs along the batch axis, then drop every length-1
# axis.
# NOTE(review): squeeze removes *all* singleton axes, so a test set holding a
# single image would lose its batch axis as well.
preds_test = np.squeeze(np.concatenate(predictions, axis=0))

In [20]:
def _binarize(probabilities):
    """Return the boolean mask of entries strictly greater than ``threshold``."""
    return probabilities > threshold


# toolz.compose applies its arguments right-to-left, so each prediction is
#   1. passed through the callable built by downsample(img_target_size, 101)
#      (presumably resizing back to the native 101-pixel side - confirm),
#   2. thresholded into a boolean mask,
#   3. passed through np.round,
#   4. run-length encoded for the submission file.
transform = compose(
    rle_encode,
    np.round,
    _binarize,
    downsample(img_target_size, 101),
)

In [21]:
# Every test id must have exactly one prediction. Check this explicitly:
# indexing predictions by position would silently ignore surplus rows, and a
# shortfall would only surface part-way through the loop.
if len(preds_test) != len(test_df.index):
    raise ValueError(
        f"Got {len(preds_test)} predictions for {len(test_df.index)} test ids"
    )

# Map each image id to the encoded form of its prediction. The lengths are
# known to match, so the ids and predictions can be walked in lockstep.
pred_dict = {
    image_id: transform(prediction)
    for image_id, prediction in zip(tqdm(test_df.index.values), preds_test)
}

100%|██████████| 18000/18000 [00:23<00:00, 758.98it/s]


In [22]:
# One row per image id, holding that image's encoded mask.
sub = pd.DataFrame.from_dict(pred_dict, orient="index").rename_axis("id")
sub.columns = ["rle_mask"]

# The submission is stored next to the checkpoint it was produced from. The
# name is stamped with the run's start time at hour resolution, so a second
# run within the same hour overwrites the first.
filename = os.path.join(model_dir, f"submission_{now:%d%b%Y_%H}.csv")
sub.to_csv(filename)