In [2]:
import numpy as np
import torch
import cv2
import matplotlib.pyplot as plt
import os
import multiprocessing as mp
import seaborn as sns

from code.models import basicunet, resnetunet
from code.datasets import TGSTTADataset
from code.configs import *
from code.train import *
from code.losses import FocalRobustLoss
from code.metrics import *
from code.augmentations import *
from code.utils import *
from torch.utils.data import DataLoader
from IPython.display import clear_output
from code.inference import *

np.warnings.filterwarnings("ignore")

%matplotlib inline

In [2]:
def test_postproc(image, mask):
    image, mask = do_resize2(image, mask, 202, 202)
    image, mask = do_center_pad_to_factor2(image, mask, factor=64)
    return image, mask

In [3]:
class UNetResNet34Wrapped(resnetunet.UNetResNet34):
    
    def __init__(self, device):
        super().__init__()
        self.device = device
        self.to(device)
        
    def __call__(self, image, **kwargs):
        logits = super().__call__(image)
        return {"logits": logits[:,0]}

In [4]:
test_ds = TGSTTADataset(postproc=test_postproc, paths=[PATH_TO_TEST], 
                        path_to_depths=PATH_TO_DEPTHS, progress_bar=True)
test_dl = DataLoader(test_ds, batch_size=16, num_workers=2)

In [5]:
device = torch.device("cuda")
model = UNetResNet34Wrapped(device)
mAP_cp = BestLastCheckpointer("mAP1")
mAP_cp.load("best", model=model)
predict_and_save(model, test_dl, ["logits"], "id", "test1", verbose=1)

In [6]:
device = torch.device("cuda")
model = UNetResNet34Wrapped(device)
mAP_cp = BestLastCheckpointer("mAP2")
mAP_cp.load("best", model=model)
predict_and_save(model, test_dl, ["logits"], "id", "test2", verbose=1)

In [7]:
device = torch.device("cuda")
model = UNetResNet34Wrapped(device)
mAP_cp = BestLastCheckpointer("mAP3")
mAP_cp.load("best", model=model)
predict_and_save(model, test_dl, ["logits"], "id", "test3", verbose=1)

In [8]:
device = torch.device("cuda")
model = UNetResNet34Wrapped(device)
mAP_cp = BestLastCheckpointer("mAP4")
mAP_cp.load("best", model=model)
predict_and_save(model, test_dl, ["logits"], "id", "test4", verbose=1)

In [9]:
device = torch.device("cuda")
model = UNetResNet34Wrapped(device)
mAP_cp = BestLastCheckpointer("mAP5")
mAP_cp.load("best", model=model)
predict_and_save(model, test_dl, ["logits"], "id", "test5", verbose=1)

In [10]:
dirs = ["test1/logits/", "test2/logits/", "test3/logits/", "test4/logits/", "test5/logits/"]
logits, ids = [], []

for name in os.listdir(dirs[0]):
    name = name.split(".")[0]
    id_ = name.split("_")[0]
    ids.append(id_)
    
ids = list(set(ids))
    
for id_ in tqdm_notebook(ids):
    logit_list = []
    for dirname in dirs:
        suffix = ""
        name = id_ + suffix + ".npy"
        path = os.path.join(dirname, name)
        logit = np.load(path)
        logit_list.append(logit)
        
        suffix = "_flipped"
        name = id_ + suffix + ".npy"
        path = os.path.join(dirname, name)
        logit = np.load(path)
        logit = logit[:,::-1]
        logit_list.append(logit)
        
    logit = np.mean(np.array(logit_list), axis=0)
    logit = logit[27:-27,27:-27]
    logit = cv2.resize(logit, dsize=(101, 101))
    logits.append(logit)
logits = np.array(logits)

HBox(children=(IntProgress(value=0, max=10), HTML(value='')))




In [11]:
import pickle
np.save("logits", logits)
with open("ids.pkl", "wb") as f:
    pickle.dump(ids, f)

In [12]:
logits.shape, len(ids)

((10, 101, 101), 18000)

In [13]:
preds = (logits > 0).astype(int)

In [14]:
preds = force_zero_empty(os.path.join(PATH_TO_TEST, "images"), ids, preds, verbose=1)

In [15]:
prepare_submit(preds, ids, "smbt41.csv")

In [16]:
!kaggle competitions submit -c tgs-salt-identification-challenge -f smbt41.csv -m "ResNet34 Lovasz 5 fold flip tta pl"

In [17]:
sum = preds.sum(axis=(1, 2))

In [18]:
preds2 = preds
preds2[sum < 10] = 0
prepare_submit(preds2, ids, "smbt42.csv")

In [19]:
!kaggle competitions submit -c tgs-salt-identification-challenge -f smbt42.csv -m "ResNet34 Lovasz 5 fold flip tta pl; less 10 to zero;"