In [3]:
import os
import pickle
import numpy as np
import albumentations as A
import torch, torch.nn as nn
import segmentation_models_pytorch as sm
import cv2
import matplotlib.pyplot as plt
from torch.autograd import Variable
from tqdm.notebook import tqdm
from pathlib import Path

In [4]:
# Expose only physical GPU 1 to this process (it then appears as cuda:0).
# This only takes effect if CUDA has not been initialised yet in this process.
os.environ["CUDA_VISIBLE_DEVICES"] = "1"

# Inference preprocessing: resize every image to 384x384 (bilinear) and then
# normalise each channel with the standard ImageNet mean/std on a 0-255 scale.
infer_aug = A.Compose(
    [
        A.Resize(
            height=384,
            width=384,
            interpolation=cv2.INTER_LINEAR,  # == 1, the value used originally
            always_apply=False,
            p=1.0,
        ),
        A.Normalize(
            mean=(0.485, 0.456, 0.406),
            std=(0.229, 0.224, 0.225),
            max_pixel_value=255.0,
            always_apply=False,
            p=1.0,
        ),
    ]
)


class UnetSm(nn.Module):
    """``nn.Module`` wrapper that owns a ``segmentation_models_pytorch`` U-Net.

    The network is stored as the ``unet`` attribute, so every key of this
    module's ``state_dict`` is prefixed with ``unet.``.

    Args:
        in_channels: number of input image channels (default 3).
        out_channels: number of output channels, forwarded to ``sm.Unet`` as
            ``classes`` (default 1).
        **kwargs: any other ``sm.Unet`` keyword argument (e.g. ``encoder_name``).
    """

    def __init__(self, in_channels=3, out_channels=1, **kwargs):
        super(UnetSm, self).__init__()
        self.unet = sm.Unet(
            in_channels=in_channels,
            classes=out_channels,
            **kwargs
        )

    def forward(self, inputs):
        """Return the wrapped U-Net's output for *inputs* unchanged."""
        output = self.unet(inputs)
        return output

def inverse_infer_aug(height, width):
    """Build a transform that resizes a predicted mask to ``height`` x ``width``.

    Nearest-neighbour interpolation copies label values instead of blending
    them, so the resized mask contains only class indices that were present.

    Args:
        height: target height in pixels.
        width: target width in pixels.

    Returns:
        An ``A.Compose`` holding a single, always-applied ``A.Resize``.
    """
    back_to_original = A.Resize(
        height=height,
        width=width,
        interpolation=cv2.INTER_NEAREST,
        always_apply=False,
        p=1.0,
    )
    return A.Compose([back_to_original])

    
def create_model(model_file):
    """Build the 2-class EfficientNet-B0 U-Net and load its weights for inference.

    Args:
        model_file: path to a pickled ``state_dict`` for ``UnetSm(out_channels=2,
            encoder_name='efficientnet-b0')``.

    Returns:
        The loaded ``UnetSm`` in eval mode, on CUDA if available, else on the CPU.
    """
    # encoder_weights=None: skip downloading ImageNet encoder weights; every
    # parameter/buffer is replaced by the strict load_state_dict below anyway.
    model = UnetSm(out_channels=2,
                   encoder_name='efficientnet-b0',
                   encoder_weights=None)

    # SECURITY: pickle.load can execute arbitrary code embedded in the file;
    # only load checkpoints from a trusted source.
    with open(model_file, 'rb') as f:
        state = pickle.load(f)

    model.load_state_dict(state)
    # Use the same CUDA-if-available rule used for the input batch, so the
    # model and its inputs end up on the same device.
    model.to(torch.device('cuda' if torch.cuda.is_available() else 'cpu'))
    # Inference only: switch off training-mode behaviour (batch statistics in
    # normalisation layers, dropout).
    model.eval()

    return model


def torch_float(data, device):
    """Convert array-like *data* to a ``float32`` tensor placed on *device*.

    Args:
        data: numpy array or nested sequence of numbers.
        device: ``torch.device`` or device string (e.g. ``'cpu'``, ``'cuda'``).

    Returns:
        A ``torch.float32`` tensor that does not require grad.
    """
    # ``torch.autograd.Variable`` is a deprecated no-op wrapper and
    # ``torch.FloatTensor(...)`` a legacy constructor; ``as_tensor`` converts
    # the dtype and places the result on the device in one step (avoiding a
    # copy when the data is already float32 on that device).
    return torch.as_tensor(data, dtype=torch.float32, device=device)

def augmented_load(img, aug):
    auged = aug(image=img)
    aug_img = auged['image']

    aug_img = aug_img.transpose(2, 0, 1)

    return np.array([aug_img])

def _model_device(model):
    """Return the device holding *model*'s parameters.

    Falls back to CUDA-if-available-else-CPU when *model* has no
    ``parameters()`` method or no parameters at all.
    """
    parameters = getattr(model, "parameters", None)
    first = next(iter(parameters()), None) if callable(parameters) else None
    if first is not None:
        return first.device
    return torch.device('cuda' if torch.cuda.is_available() else 'cpu')


def infer(model, img):
    """Predict a segmentation mask for one image, resized back to its original size.

    The model should already be in eval mode; this function does not change
    the model's train/eval state.

    Args:
        model: network mapping a ``(1, 3, H', W')`` float batch to per-class
            logits with classes on dim 1.
        img: channels-last image array, e.g. as returned by ``cv2.imread``
            (assumed 3-channel; TODO confirm BGR vs RGB matches training).

    Returns:
        ``uint8`` mask at ``img``'s height/width where class 1 is 255 and
        class 0 is 0. Assumes a binary model; class indices >= 2 would
        overflow uint8 when scaled by 255.
    """
    imgs_batch_ = augmented_load(img, infer_aug)
    # Put the batch wherever the model lives; hard-coding CUDA-if-available
    # causes a device mismatch when the model is on the CPU.
    imgs_batch = torch_float(imgs_batch_, _model_device(model))

    # Pure inference: do not build an autograd graph.
    with torch.no_grad():
        logits_batch = model(imgs_batch)

    # Per-pixel class index for the single image in the batch; the spatial size
    # comes from the network output instead of a hard-coded 384.
    pred_mask = logits_batch.argmax(dim=1)[0].cpu().numpy()
    # argmax yields int64, which OpenCV's resize does not accept; uint8 holds
    # class indices. The trailing axis keeps the (H, W, 1) layout.
    pred_mask = pred_mask.astype(np.uint8)[..., np.newaxis]

    height, width = img.shape[:2]
    original_pred_mask = inverse_infer_aug(height, width)(image=pred_mask)['image']
    # Map class 1 -> 255 so the saved mask is viewable as black/white.
    original_pred_mask = original_pred_mask.astype(np.uint8) * 255

    return original_pred_mask

model_checker_path = 'model_checker.bin'

# JPEG inputs are read from ``orig_images``; one mask per image is written to
# ``orig_masks`` under the same file stem.
input_dir = Path("orig_images")
output_dir = Path("orig_masks")
# cv2.imwrite does not create missing directories (it just returns False).
output_dir.mkdir(parents=True, exist_ok=True)

# Load the model ONCE instead of re-reading and unpickling the checkpoint for
# every image.
model_checker = create_model(model_checker_path)
# Same CUDA-if-available device choice as for the input batch, so forward()
# never sees a model/input device mismatch.
model_checker.to(torch.device('cuda' if torch.cuda.is_available() else 'cpu'))
# Inference only: disable training-mode layer behaviour.
model_checker.eval()

for file in tqdm(sorted(input_dir.glob('*.jpg'))):
    img = cv2.imread(str(file))
    if img is None:
        # cv2.imread returns None for unreadable/corrupt files instead of raising.
        tqdm.write(f"skipping unreadable image: {file}")
        continue
    with torch.no_grad():
        img_checker = infer(model_checker, img)
    # NOTE(review): JPEG is lossy, so the saved 0/255 mask gets compression
    # artefacts along edges; PNG would keep it exact. Kept as .jpg to preserve
    # the existing output filenames.
    out_path = output_dir / (file.stem + ".jpg")
    if not cv2.imwrite(str(out_path), img_checker):
        raise OSError(f"failed to write mask: {out_path}")

  0%|          | 0/1445 [00:00<?, ?it/s]