In [1]:
from dataset import SteelTestDataset
import torch
import pandas as pd
import numpy as np
import cv2
from tqdm import tqdm
from torch.utils.data import DataLoader
import os
import glob
import json
import csv

In [2]:
def post_process(probability, threshold=0.2, min_size=3500):
    """Binarise a predicted score map and discard small connected components.

    Args:
        probability: 2-D array of per-pixel scores for a single class. It is
            handed directly to `cv2.threshold`, so its dtype must be one that
            OpenCV accepts (presumably float32 -- confirm against callers).
        threshold: pixels whose score is strictly greater than this value are
            treated as foreground.
        min_size: a connected component is kept only when it contains strictly
            more than this many pixels.

    Returns:
        A tuple `(predictions, num)` where `predictions` is a float32 mask with
        the same shape as the thresholded input (1 inside kept components,
        0 elsewhere) and `num` is the number of components that were kept.
    """
    mask = cv2.threshold(probability, threshold, 1, cv2.THRESH_BINARY)[1]
    num_component, component = cv2.connectedComponents(mask.astype(np.uint8))
    # Size the output from the label image rather than hard-coding 256x1600,
    # so the function also works for resized or cropped inputs.
    predictions = np.zeros(component.shape, np.float32)
    num = 0
    # Label 0 is the background, so real components start at 1.
    for c in range(1, num_component):
        p = component == c
        if p.sum() > min_size:
            predictions[p] = 1
            num += 1
    return predictions, num

In [3]:
from utils import mask2rle

In [4]:
from model_no_cat import *

In [5]:
# Load the pickled model object and put it in inference mode on the CPU.
# map_location='cpu' lets a checkpoint that was saved from a GPU be
# deserialised on a machine without CUDA; without it torch.load raises before
# the trailing .cpu() is ever reached.
# NOTE(review): torch.load unpickles arbitrary objects -- only load checkpoints
# from a trusted source.
mdl = torch.load('model_no_cat.pt', map_location='cpu').cpu().eval()

In [6]:
# Feed the test images one at a time and in dataset order; the inference loop
# relies on batch_size=1 (it reads f[0] and squeezes the batch dimension).
# pin_memory is off because the model runs on the CPU: page-locked buffers only
# speed up host-to-GPU copies and trigger a warning when no accelerator exists.
testset = DataLoader(
    SteelTestDataset(),
    batch_size=1,
    shuffle=False,
    pin_memory=False,
)

1801


In [7]:
# Path of the submission CSV that the inference loop writes row by row.
outfile = 'submission_no_cat_r.csv'

In [8]:
# Run the model over every test image and emit one submission row per
# (image, class) pair, both to `outfile` and to the in-memory lists below.
threshold = 0.2  # scores at or above this value become mask pixels
maxs = []        # per-image maximum of the raw model output, kept for inspection
encpx = []       # encoded masks, index-aligned with `outfs`
outfs = []       # submission keys of the form '<image_id>.jpg_<class>'
with open(outfile, 'w+') as of:
    of.write('ImageId_ClassId,EncodedPixels' + '\n')
    with torch.no_grad():
        for f, inp in tqdm(testset):
            # Strip only the final extension so that file names containing
            # additional dots are not truncated.
            image_id = os.path.splitext(f[0])[0]
            outp = mdl(inp).squeeze()
            maxs.append(outp.max())
            # Binarise in a single comparison. Thresholding in two in-place
            # steps is order dependent: with threshold > 1 the second step
            # would wipe out the pixels the first step had just set to 1.
            outp = (outp >= threshold).to(outp.dtype)
            # assumes the model emits four channels, one per class -- TODO confirm
            for i in range(4):
                k = image_id + '.jpg_' + str(i + 1)
                outfs.append(k)
                try:
                    encoded_pixels = mask2rle(outp[i])
                except Exception as exc:
                    # Best effort: keep the run going and submit an empty mask,
                    # but report which row failed and why instead of hiding it.
                    print('mask2rle failed for ' + k + ': ' + repr(exc))
                    encoded_pixels = ''
                encpx.append(encoded_pixels)
                of.write(k + ',' + encoded_pixels + '\n')

100%|██████████| 1801/1801 [31:12<00:00,  1.04s/it]


In [12]:
# Pair every submission key with its encoded mask. The pairs are materialised
# as a list because a bare zip object is a one-shot iterator: once consumed,
# any cell that reads `res` again would silently see no rows.
res = list(zip(outfs, encpx))


In [13]:
# Collect the (key, encoded mask) pairs into a list of row tuples.
# NOTE(review): if `res` is a one-shot iterator, re-running this cell on its
# own produces an empty list.
preds = list(res)

In [16]:
# Tabulate the predictions under the same two column names that the inference
# loop wrote as the CSV header.
df = pd.DataFrame(preds, columns=['ImageId_ClassId', 'EncodedPixels'])

In [19]:
# Save the table as a second submission file; index=False keeps the pandas row
# index out of the CSV.
df.to_csv('submission_final.csv', index=False)