In [None]:
from google.colab import drive
# Mount Google Drive at /content/drive so the '/content/drive/MyDrive/...'
# data and model paths used later in the notebook resolve.
drive.mount(mountpoint='/content/drive')

In [None]:
%%time
# NOTE(review): zipfile is not used by any visible cell.
import zipfile

# Copy the test images from the mounted Drive into the working directory.
# The relative 'drive/...' path assumes the cwd is /content (Colab default) — confirm.
!cp -r drive/MyDrive/Signate-OffroadSegmentation/data/precision_test_images precision_test_images

# NOTE(review): none of these installs are version-pinned, so a fresh run gets
# whatever is current (two come straight from git HEAD). In particular, the
# imports cell does `from pytorch_lightning.metrics.functional import accuracy`,
# and `pytorch_lightning.metrics` was removed in pytorch-lightning 1.5. Pin the
# versions this notebook was developed with, e.g. `pytorch-lightning<1.5`.
# Prefer `%pip` over `!pip` so installs target the running kernel's environment.
!pip install git+https://github.com/rwightman/pytorch-image-models.git 
!pip install -U git+https://github.com/albu/albumentations --no-cache-dir
!pip install pytorch-lightning
!pip install segmentation-models-pytorch 

In [None]:
import os
import pandas as pd
import numpy as np
import cv2
import matplotlib.pyplot as plt
import glob

from torch.utils.data import DataLoader
from torch.utils.data import Dataset as BaseDataset

import torch
import numpy as np

import albumentations as albu
from tqdm import tqdm_notebook as tqdm


import pytorch_lightning as pl
from pytorch_lightning.callbacks import ModelCheckpoint, EarlyStopping
from pytorch_lightning.metrics.functional import accuracy
from sklearn.model_selection import StratifiedKFold

In [None]:
import segmentation_models_pytorch as smp

class TestDataset(BaseDataset):
    """Inference-time dataset that yields one preprocessed image per row (no masks).

    Args:
        df (pandas.DataFrame): must contain a ``file_name`` column holding image
            paths readable by ``cv2.imread``.
        augmentation (callable, optional): albumentations-style transform, called
            as ``augmentation(image=...)`` and expected to return a dict with an
            ``'image'`` entry.
        preprocessing (callable, optional): same calling convention; applied
            after ``augmentation``.
    """

    def __init__(
            self,
            df,
            augmentation=None,
            preprocessing=None,
    ):
        self.images_fps = df['file_name'].values
        self.augmentation = augmentation
        self.preprocessing = preprocessing

    def __getitem__(self, i):
        """Read image ``i`` as RGB, then apply augmentation and preprocessing.

        Raises:
            FileNotFoundError: if OpenCV cannot read the file at that path.
        """
        path = self.images_fps[i]
        image = cv2.imread(path)
        # cv2.imread does not raise on a missing/unreadable file; it returns
        # None, which would otherwise fail later inside cvtColor with an
        # unhelpful assertion message that does not mention the path.
        if image is None:
            raise FileNotFoundError(f'Could not read image: {path}')
        # OpenCV decodes to BGR channel order; convert to RGB.
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

        if self.augmentation:
            image = self.augmentation(image=image)['image']

        if self.preprocessing:
            image = self.preprocessing(image=image)['image']

        return image

    def __len__(self):
        """Number of images, i.e. rows of the DataFrame passed to ``__init__``."""
        return len(self.images_fps)


def get_validation_augmentation(height=1056, width=1920):
    """Build the inference-time geometric transform.

    Each image is resized to ``height`` x ``width``. The defaults are multiples
    of 32, so the encoder/decoder downsampling divides evenly. The raw frames
    are resized back to 1920x1080 after prediction. ``PadIfNeeded`` only pads
    images smaller than the target. Because it is given the same size as
    ``Resize``, it is a no-op safeguard here.

    Args:
        height (int): target height in pixels (default 1056).
        width (int): target width in pixels (default 1920).

    Returns:
        albumentations.Compose: the resize (+ pad) pipeline.
    """
    return albu.Compose([
        albu.Resize(height, width),
        albu.PadIfNeeded(height, width),
    ])


def to_tensor(x, **kwargs):
    """Reorder an (H, W, C) image array to (C, H, W) and cast to float32.

    Extra keyword arguments are accepted and ignored, so the function can be
    used as an albumentations ``Lambda`` image callback (see get_preprocessing).
    """
    channels_first = np.transpose(x, (2, 0, 1))
    return channels_first.astype(np.float32)


def get_preprocessing(preprocessing_fn):
    """Build the normalisation + tensor-conversion transform.

    Args:
        preprocessing_fn (callable or None): data normalisation function, which
            can be specific to each pretrained encoder (e.g. the result of
            ``smp.encoders.get_preprocessing_fn``). Skipped when falsy.

    Returns:
        albumentations.Compose: ``preprocessing_fn`` (if given) followed by
        ``to_tensor`` (HWC -> CHW float32).
    """
    steps = [albu.Lambda(image=preprocessing_fn)] if preprocessing_fn else []
    steps.append(albu.Lambda(image=to_tensor))
    return albu.Compose(steps)

# Default encoder configuration and its matching input-normalisation function.
# NOTE(review): the datasets below build their own per-encoder preprocessing,
# so `preprocessing_fn` and `ENCODER` are not used by any visible cell.
ENCODER, ENCODER_WEIGHTS = 'resnet34', 'imagenet'
preprocessing_fn = smp.encoders.get_preprocessing_fn(ENCODER, ENCODER_WEIGHTS)

In [None]:
# Test-set index. The inference cell reads its 'png_name' and 'pred_has_road'
# columns; TestDataset reads 'file_name'.
df_test = pd.read_csv('/content//drive/MyDrive/Signate-OffroadSegmentation/data/test_2stage_binary.csv')


def _make_test_dataset(encoder_name):
    """Build a TestDataset over ``df_test`` normalised for ``encoder_name``."""
    encoder_norm = smp.encoders.get_preprocessing_fn(encoder_name, ENCODER_WEIGHTS)
    return TestDataset(
        df_test,
        augmentation=get_validation_augmentation(),
        preprocessing=get_preprocessing(encoder_norm),
    )


test_dataset_resnet18 = _make_test_dataset('resnet18')
test_dataset_resnet34 = _make_test_dataset('resnet34')

In [None]:
MODEL_DIR = '/content/drive/MyDrive/Signate-OffroadSegmentation/model'
N_FOLDS = 5


def _load_fold_models(name_prefix, n_folds=N_FOLDS):
    """Load ``<MODEL_DIR>/<name_prefix>_fold{k}.pth`` for k = 0..n_folds-1.

    Each checkpoint is a whole pickled model object. It is switched to eval
    mode before being returned.

    NOTE(review): ``torch.load`` unpickles arbitrary code, so only load trusted
    files. Newer PyTorch releases default to ``weights_only=True`` and will
    refuse whole-model pickles like these; pass ``weights_only=False`` there.

    Args:
        name_prefix (str): checkpoint file name up to (not including) ``_fold``.
        n_folds (int): number of cross-validation folds to load.

    Returns:
        list: the loaded models, ordered by fold index.
    """
    return [
        torch.load(os.path.join(MODEL_DIR, f'{name_prefix}_fold{fold}.pth')).eval()
        for fold in range(n_folds)
    ]


# Multi-class segmentation ensembles (one model per CV fold).
best_models_resnet34 = _load_fold_models('2_resnet34_Resizemix_FocalLoss')
best_models_resnet18 = _load_fold_models('3_resnet18_Resizemix_FocalLoss')

# Per-category binary ensembles. The road / dirt-road variants are chosen per
# image in the inference cell based on the first-stage 'pred_has_road' score.
model_other = _load_fold_models('4_2_OtherObstacle_resnet34_Resizemix_BCELoss')
model_road_hasRoad = _load_fold_models('4_3_road_hasroad_resnet34_Resizemix_BCELoss')
model_road_NothasRoad = _load_fold_models('4_4_road_NotHasroad_resnet34_Resizemix_BCELoss')
model_dirtroad_hasRoad = _load_fold_models('4_5_dirtroad_hasroad_resnet34_Resizemix_BCELoss')
model_dirtroad_NothasRoad = _load_fold_models('4_6_dirtroad_NotHasroad_resnet34_Resizemix_BCELoss')



In [None]:
%%time

import json

# Probability threshold above which a pixel is assigned to a category.
th = 0.5
DEVICE = 'cuda'
# Blend weights for the resnet18 ensemble, the resnet34 ensemble and the
# per-category binary ensemble (they sum to 1).
W_RESNET18, W_RESNET34, W_BINARY = 0.4, 0.4, 0.2
# Above this first-stage 'pred_has_road' score, the "has road" binary models
# are used for the road / dirt-road channels.
HAS_ROAD_TH = 0.9
# Output size as (width, height), the argument order cv2.resize expects.
OUTPUT_WH = (1920, 1080)
# assumes the multi-class models emit channels in this order — TODO confirm
CATEGORIES = ['road', 'dirt road', 'other obstacle']

# NOTE(review): `filename` is never defined in this notebook; it only existed in
# some earlier kernel session. Resolve the output path up front, with a fallback,
# so a fresh kernel does not crash only after the whole inference loop.
# NOTE(review): this directory differs from the 'Signate-OffroadSegmentation'
# root used everywhere else — confirm it is intended.
result_stem = globals().get('filename', 'submission')
result_json_name = f'/content//drive/MyDrive/OffroadSegmentation/result/{result_stem}_weight2.json'


def _average_multiclass(models, x_tensor):
    """Mean (H, W, C) probability map of ``models`` on ``x_tensor``, resized to OUTPUT_WH."""
    total = 0
    for model in models:
        pred = model.predict(x_tensor).squeeze().cpu().numpy().transpose(1, 2, 0)
        total += cv2.resize(pred, OUTPUT_WH) / len(models)
    return total


def _average_binary(models, x_tensor):
    """Mean single-channel output of ``models`` on ``x_tensor``, as an (H, W) array at OUTPUT_WH."""
    total = 0
    for model in models:
        total += model.predict(x_tensor) / len(models)
    return cv2.resize(total.squeeze().cpu().numpy(), OUTPUT_WH)


def _binary_models_for(category, has_road_score):
    """Pick the binary ensemble for ``category`` given the first-stage has-road score."""
    has_road = has_road_score > HAS_ROAD_TH
    if category == 'road':
        return model_road_hasRoad if has_road else model_road_NothasRoad
    if category == 'dirt road':
        return model_dirtroad_hasRoad if has_road else model_dirtroad_NothasRoad
    return model_other


def _mask_to_segments(mask):
    """Run-length encode a 2-D boolean mask row by row.

    Returns ``{row: [[start_col, end_col], ...]}`` with inclusive ends. Rows
    appear in ascending order, and only rows containing a True pixel are
    included. Example: the row ``[0, 1, 1, 0, 1]`` gives ``[[1, 2], [4, 4]]``.
    """
    segments_by_row = {}
    for row_idx in np.flatnonzero(mask.any(axis=1)):
        cols = np.flatnonzero(mask[row_idx])
        # Start a new run wherever consecutive True columns are not adjacent.
        runs = np.split(cols, np.flatnonzero(np.diff(cols) != 1) + 1)
        segments_by_row[int(row_idx)] = [[int(run[0]), int(run[-1])] for run in runs]
    return segments_by_row


json_data = {}
# Index the datasets by position: they were built from df_test in the same order.
for row_pos, (_, row) in enumerate(tqdm(df_test.iterrows(), total=len(df_test))):
    png_name = row['png_name']
    json_data[png_name] = {}

    x_tensor = torch.from_numpy(test_dataset_resnet18[row_pos]).to(DEVICE).unsqueeze(0)
    pr_mask_18 = _average_multiclass(best_models_resnet18, x_tensor)

    x_tensor_34 = torch.from_numpy(test_dataset_resnet34[row_pos]).to(DEVICE).unsqueeze(0)
    pr_mask_34 = _average_multiclass(best_models_resnet34, x_tensor_34)

    # Blended per-category probabilities for this image (kept for pseudo-labelling).
    pr_mask_pseudo = np.zeros_like(pr_mask_18)
    for channel, category in enumerate(CATEGORIES):
        binary_models = _binary_models_for(category, row['pred_has_road'])
        pr_mask_1c = _average_binary(binary_models, x_tensor_34)

        pr_mask_em = (pr_mask_18[:, :, channel] * W_RESNET18
                      + pr_mask_34[:, :, channel] * W_RESNET34
                      + pr_mask_1c * W_BINARY)
        pr_mask_pseudo[:, :, channel] = pr_mask_em

        category_segments = _mask_to_segments(pr_mask_em > th)
        if category_segments:
            json_data[png_name][category] = category_segments

with open(result_json_name, 'w') as f:
    json.dump(json_data, f)