In [1]:
import gc
import os
import torch
import pandas as pd
import numpy as np

from tqdm import tqdm
from dataset.Dataset import TGS_Dataset, do_horizontal_flip
from model.Models import UnetSeNext50
from losses.Evaluation import do_length_decode, do_length_encode

In [2]:
def load_net_and_predict(net, test_path, load_paths, batch_size=32, tta_transform=None, threshold=0.5, min_size=0,
                         num_workers=11):
    """Run every checkpoint in ``load_paths`` over the test set and sum the raw predictions.

    The test DataLoader is built once and reused for every checkpoint; ``net`` is
    re-loaded in place from each path before predicting.

    Args:
        net: model wrapper exposing ``load_model(path)`` and
            ``do_predict(loader, threshold=..., tta_transform=...)``.
        test_path: directory handed to ``TGS_Dataset``.
        load_paths: sequence of checkpoint paths; must be non-empty.
        batch_size: batch size of the test DataLoader.
        tta_transform: optional test-time-augmentation callable forwarded to
            ``net.do_predict``.
        threshold: currently unused -- predictions are always requested with
            ``threshold=0`` and binarisation is left to the caller. Kept for
            backward compatibility.
        min_size: currently unused; kept for backward compatibility.
        num_workers: number of DataLoader worker processes (previously hard-coded to 11).

    Returns:
        ``(sum_, p)``: ``sum_`` is a float64 array holding the element-wise sum of
        ``p['pred']`` over all checkpoints (its shape follows the predictions rather
        than a hard-coded size); ``p`` is the result of the *last* ``do_predict``
        call, so any other entries (e.g. ``p['id']``) come from that run.

    Raises:
        ValueError: if ``load_paths`` is empty.
    """
    if len(load_paths) == 0:
        # Without this guard the function would fail later with an UnboundLocalError on ``p``.
        raise ValueError('load_paths must contain at least one checkpoint path')

    test_dataset = TGS_Dataset(test_path)
    test_loader, _test_ids = test_dataset.yield_dataloader(data='test', num_workers=num_workers,
                                                           batch_size=batch_size)

    sum_ = None
    for path in tqdm(load_paths):
        net.load_model(path)
        # threshold=0 presumably asks do_predict for un-binarised scores so they can be
        # averaged across checkpoints -- TODO confirm against do_predict.
        p = net.do_predict(test_loader, threshold=0, tta_transform=tta_transform)
        if sum_ is None:
            # Size the accumulator from the actual predictions instead of the former
            # hard-coded (18000, 1, 101, 101), so other test-set sizes/shapes work.
            sum_ = np.zeros(np.shape(p['pred']), dtype=np.float64)
        sum_ += p['pred']

    # Release the dataset/loader (and any worker resources they hold) before returning.
    del test_dataset, test_loader
    gc.collect()

    return sum_, p

In [3]:
def tta_transform(images, mode):
    """Horizontal-flip test-time augmentation for a batch of images.

    Args:
        images: when ``mode != 'out'``, a rank-4 batch (assumed (N, C, H, W) layout).
            When ``mode == 'out'``, a container whose first element is such a batch
            (e.g. the (1, N, C, H, W) array this function itself returns).
        mode: ``'out'`` unwraps ``images[0]`` first; any other value uses
            ``images`` as-is.

    Returns:
        A new C-contiguous array of shape (1, N, C, H, W) and the input dtype, holding
        the batch with its last (width) axis reversed. A horizontal flip is its own
        inverse, so applying it to flipped predictions restores the original orientation.

    Raises:
        ValueError: if the (unwrapped) batch is not 4-dimensional.
    """
    batch = images[0] if mode == 'out' else images
    batch = np.asarray(batch)
    if batch.ndim != 4:
        raise ValueError('tta_transform expects a 4-D batch, got shape %s' % (batch.shape,))

    # Reverse the width axis for the whole batch at once. np.flip returns a
    # negative-stride view, which some consumers (e.g. torch.from_numpy) reject,
    # so take a fresh C-contiguous copy -- this also never aliases the input.
    flipped = np.flip(batch, axis=-1).copy()
    # Keep the leading singleton axis the callers expect: (1, N, C, H, W).
    return flipped[np.newaxis]

In [4]:
# Inference configuration.
TEST_PATH = './test/'   # directory handed to load_net_and_predict / TGS_Dataset
THRESHOLD = 0.45        # cut-off applied to the averaged ensemble prediction
MIN_SIZE = 0            # forwarded to load_net_and_predict as min_size
BATCH_SIZE = 128        # batch size of the test DataLoader

# Model whose checkpoints are ensembled below.
net = UnetSeNext50()
NET_NAME = type(net).__name__  # the model's class name

In [5]:
# Checkpoints to ensemble; their raw predictions are summed by load_net_and_predict.
filelists = [
    './Saves/UnetSeNext50/2018-10-20 03:57_Fold1_Epoach12_reset0_val0.847',
    './Saves/UnetSeNext50/2018-10-20 04:24_Fold1_Epoach30_reset1_val0.864',
]
sum_, p = load_net_and_predict(
    net,
    TEST_PATH,
    filelists,
    batch_size=BATCH_SIZE,
    tta_transform=tta_transform,
    threshold=THRESHOLD,
    min_size=MIN_SIZE,
)

100%|██████████| 2/2 [09:08<00:00, 270.51s/it]


In [6]:
# Ensemble: mean over checkpoints, then keep pixels strictly above THRESHOLD.
avg = np.divide(sum_, len(filelists))
pred = np.greater(avg, THRESHOLD)

In [7]:
# Run-length encode each binary mask (one per test image, in prediction order).
rle = [do_length_encode(mask) for mask in pred]
# create sub
# NOTE(review): ids come from the last checkpoint's do_predict call; this assumes the
# test loader yields images in the same order for every checkpoint -- TODO confirm.
df = pd.DataFrame({'id': p['id'], 'rle_mask': rle})

In [8]:
# Write the submission CSV; change the file name below to write somewhere else.
submission_path = os.path.join('./results/', 'submit.csv')
# DataFrame.to_csv does not create missing directories, so make sure ./results/ exists.
os.makedirs(os.path.dirname(submission_path), exist_ok=True)
df.to_csv(submission_path, index=False)