In [1]:
# Environment setup commands, currently disabled (commented out).
# NOTE(review): `pip uninstall` asks for confirmation unless `-y` is passed, so the
# uninstall line may block when run from a notebook cell. If these are re-enabled,
# `%pip` is preferable to `!pip` because it targets the running kernel's environment.
#!pip install segmentation_models_pytorch
#!pip install --upgrade albumentations
#!pip uninstall opencv-python-headless==4.5.5.62
#!pip install opencv-python-headless==4.5.2.52

In [2]:
import json, os

import pandas as pd
import seaborn as sns

import numpy as np
import cv2
import matplotlib.pyplot as plt
import matplotlib

from torch.utils.data import DataLoader
from torch.utils.data import Dataset as BaseDataset

import albumentations as albu

import torch
import numpy as np
from scipy.stats import gaussian_kde
import segmentation_models_pytorch as smp

In [3]:
def get_validation_augmentation(height=640, width=640):
    """Build the deterministic transform used for validation / inference.

    The transform resizes its input to ``height`` x ``width``. It does NOT pad:
    the aspect ratio is not preserved. The default 640x640 is divisible by 32
    (presumably the constraint the segmentation network imposes - confirm for
    any other size passed in).

    Args:
        height: Output height in pixels. Defaults to 640.
        width: Output width in pixels. Defaults to 640.

    Returns:
        An ``albu.Compose`` wrapping a single ``albu.Resize`` that is applied
        on every call.
    """
    test_transform = [
        # p=1.0 applies the resize on every call. It replaces `always_apply=True`,
        # which newer Albumentations releases deprecate.
        albu.Resize(height, width, p=1.0),
    ]
    return albu.Compose(test_transform)

def to_tensor(x, **kwargs):
    """Return ``x`` with its third axis moved to the front, cast to float32.

    For the images and masks this notebook feeds it, that turns a
    (height, width, channels) array into (channels, height, width).
    Extra keyword arguments are accepted and ignored.
    """
    channels_first = x.transpose(2, 0, 1)
    return channels_first.astype('float32')

def get_preprocessing(preprocessing_fn):
    """Build the preprocessing pipeline applied after augmentation.

    Args:
        preprocessing_fn: Callable applied to the image only; the mask is not
            passed through it.

    Returns:
        An ``albu.Compose`` that first runs ``preprocessing_fn`` on the image
        and then runs ``to_tensor`` on both the image and the mask.
    """
    apply_preprocessing = albu.Lambda(image=preprocessing_fn)
    apply_to_tensor = albu.Lambda(image=to_tensor, mask=to_tensor)
    return albu.Compose([apply_preprocessing, apply_to_tensor])

In [4]:
class SegmDataset(BaseDataset):
    """Image / mask dataset driven by per-image JSON annotation files.

    Files are looked up under ``data_dir``:

    * ``<ann_folder>/<id>.*``   - JSON annotation with an ``objects`` entry
    * ``<img_folder>/<id>.png`` - input image
    * ``<mask_folder>/<id>.png`` - mask, read as a single-channel image

    Only ids whose annotation has a non-empty ``objects`` entry are included.

    Args:
        data_dir: Root directory of the dataset.
        ann_folder: Sub-folder holding the JSON annotations.
        img_folder: Sub-folder holding the images.
        mask_folder: Sub-folder holding the masks.
        augmentation: Optional callable invoked as
            ``augmentation(image=..., mask=...)`` returning a dict with
            ``'image'`` and ``'mask'``.
        preprocessing: Optional callable with the same calling convention,
            applied after ``augmentation``.
    """

    def __init__(self, data_dir, ann_folder='ann', img_folder='img', mask_folder='masks_machine',
                 augmentation=None, preprocessing=None):
        self.data_dir = data_dir
        self.ann_folder = ann_folder
        self.img_folder = img_folder
        self.mask_folder = mask_folder
        self.photo_ids = self.extract_ids()

        # Image and mask share the same id and are both expected to be PNG files.
        self.images = [os.path.join(data_dir, img_folder, img + '.png') for img in self.photo_ids]
        self.masks = [os.path.join(data_dir, mask_folder, img + '.png') for img in self.photo_ids]

        self.augmentation = augmentation
        self.preprocessing = preprocessing

    def extract_ids(self):
        """Return the ids of all annotation files that list at least one object.

        The id is the part of the annotation file name before its first dot
        (``'a.png.json'`` -> ``'a'``), so image names that themselves contain a
        dot are not supported.

        Returns:
            List of ids, ordered by annotation file name.

        Raises:
            FileNotFoundError: If the annotation folder does not exist.
            KeyError: If an annotation has no ``objects`` entry.
        """
        ann_dir = os.path.join(self.data_dir, self.ann_folder)
        marked_files = []
        # os.listdir returns entries in arbitrary order; sort so that sample
        # indices are reproducible between runs and machines.
        for fname in sorted(os.listdir(ann_dir)):
            ann_path = os.path.join(ann_dir, fname)
            # Skip sub-directories (e.g. Jupyter's .ipynb_checkpoints), which
            # cannot be opened as annotation files.
            if not os.path.isfile(ann_path):
                continue
            with open(ann_path, 'r', encoding='utf-8') as json_file:
                data = json.load(json_file)
            if data['objects']:
                marked_files.append(fname.split('.')[0])
        return marked_files

    def __len__(self):
        """Number of annotated samples."""
        return len(self.photo_ids)

    def __getitem__(self, i):
        """Load sample ``i`` and run it through augmentation and preprocessing.

        Returns:
            ``(image, mask)``. Before any transform the image is an RGB array
            and the mask has a trailing channel axis of size 1.

        Raises:
            FileNotFoundError: If the image or mask file cannot be read.
        """
        # cv2.imread reports failure by returning None instead of raising, so
        # check explicitly to get an error that names the offending file.
        image = cv2.imread(self.images[i])
        if image is None:
            raise FileNotFoundError('Cannot read image: {}'.format(self.images[i]))
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

        # Flag 0 reads the mask as a single-channel (grayscale) image.
        mask = cv2.imread(self.masks[i], 0)
        if mask is None:
            raise FileNotFoundError('Cannot read mask: {}'.format(self.masks[i]))
        mask = mask[:, :, np.newaxis]

        # apply augmentations
        if self.augmentation:
            sample = self.augmentation(image=image, mask=mask)
            image, mask = sample['image'], sample['mask']

        # apply preprocessing
        if self.preprocessing:
            sample = self.preprocessing(image=image, mask=mask)
            image, mask = sample['image'], sample['mask']

        return image, mask

In [5]:
# Annotation sub-folder names; the test split uses the same folder name as training.
train_ann_folder = 'ann'
test_ann_folder = train_ann_folder
ENCODER = 'resnet34' # resnet18
ENCODER_WEIGHTS = 'imagenet'
# Run on the GPU when one is available, otherwise fall back to the CPU.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
# The constant is DEVICE; the lower-case `device` used here before was never defined.
print("Using {} device".format(DEVICE))

# Image preprocessing function matching the chosen encoder and its pretrained weights.
preprocessing_fn = smp.encoders.get_preprocessing_fn(ENCODER, ENCODER_WEIGHTS)

In [11]:
# Location of the serialized model used for prediction.
MODEL_PATH = './td_model/_ocr_model.pth'
# NOTE(review): torch.load unpickles the file, which can execute arbitrary code -
# only load checkpoints that come from a trusted source.
# NOTE(review): the loaded object is used directly via `model.predict(...)`, so the
# file presumably holds a whole pickled model rather than a state_dict. PyTorch
# releases that default to `weights_only=True` may refuse to load it - confirm
# against the installed version.
# map_location places the loaded tensors on DEVICE regardless of where they were saved.
model = torch.load(MODEL_PATH, map_location=torch.device(DEVICE))

In [10]:
INPUT_DATASET = './cropped/good'
PREDICTED = './cropped/good/pred'

# Threshold applied to the model output to binarize the predicted mask.
# NOTE(review): assumes the model output is already a probability - confirm.
p = 0.5

# The dataset is used here only to enumerate the image paths to predict.
# It requires `<INPUT_DATASET>/<test_ann_folder>` to exist and keeps only the
# images whose annotation lists at least one object.
dataset_predict = SegmDataset(
    INPUT_DATASET,
    test_ann_folder,
    augmentation=get_validation_augmentation(),
    preprocessing=get_preprocessing(preprocessing_fn)
)

# cv2.imwrite does not create missing directories - it just returns False -
# so make sure the output folder exists before predicting.
os.makedirs(PREDICTED, exist_ok=True)

# The resize transform is loop-invariant; build it once.
predict_transform = get_validation_augmentation()

for img_path in dataset_predict.images:
    # Read the image the same way SegmDataset.__getitem__ does, so the input is
    # always a 3-channel RGB array (alpha channels / grayscale files included).
    bgr_image = cv2.imread(img_path)
    if bgr_image is None:
        raise FileNotFoundError('Cannot read image: {}'.format(img_path))
    img = preprocessing_fn(cv2.cvtColor(bgr_image, cv2.COLOR_BGR2RGB))
    image = predict_transform(image=img)['image']

    # (H, W, C) -> (1, C, H, W) float32 batch on the target device.
    x_tensor = torch.from_numpy(image).to(DEVICE).unsqueeze(0)
    x_tensor = x_tensor.permute(0, 3, 1, 2)
    x_tensor = x_tensor.to(torch.float32)

    pr_mask = model.predict(x_tensor)
    pr_mask = (pr_mask > p).squeeze().cpu().numpy()

    # os.path.basename works with either path separator, unlike split('/').
    file_name = os.path.join(PREDICTED, os.path.basename(img_path))
    # Store the mask explicitly as 8-bit 0/1 values instead of relying on
    # OpenCV's implicit conversion of float32 data when encoding a PNG.
    if not cv2.imwrite(file_name, pr_mask.astype(np.uint8)):
        raise OSError('Could not write predicted mask: {}'.format(file_name))

FileNotFoundError: [Errno 2] No such file or directory: './cropped/good/ann'