### Download the trained model from Google Drive

In [1]:
# Fetch the trained checkpoint into saved_models/. The URL is quoted because it
# contains '?', which the shell would otherwise treat as a glob character.
! gdown -O saved_models/ 'https://drive.google.com/uc?id=1HBSGXbWw5Vorj82buF-gCi6S2DpF4mFL'

Downloading...
From: https://drive.google.com/uc?id=1HBSGXbWw5Vorj82buF-gCi6S2DpF4mFL
To: /home/maruf/pipeline/saved_models/Trained_model_SM.pth
100%|█████████████████████████████████████████| 113M/113M [00:01<00:00, 104MB/s]


# Imports

In [2]:
%matplotlib inline

import os
import numpy as np
import cv2
import matplotlib.pyplot as plt
from PIL import Image

from torch.utils.data import DataLoader
from torch.utils.data import Dataset as BaseDataset

import torch
import segmentation_models_pytorch as smp
import albumentations as albu

import seaborn as sns
import pylab as py
import pandas as pd
import torch
from torchvision import transforms

from scripts.helper import *

# Preprocessing

In [3]:
# Resize every .png/.jpg found in input_dir to a fixed size and write the
# result, under the same file name, to transformed_dir.
input_dir = 'input_images/'
transformed_dir = 'transformed_images/'

# Target size shared by all images (loop-invariant, so defined once).
r_width = 800
r_height = 320

# Keep only .png/.jpg files; the match is case-sensitive, same as the filter
# used by the prediction cell.
input_img_names = [
    name for name in os.listdir(input_dir)
    if name.endswith(('.png', '.jpg'))
]
if len(input_img_names) == 0:
    print('Please put the input images in input_images/ directory!')

# Image.save does not create missing directories, so make sure the target exists.
os.makedirs(transformed_dir, exist_ok=True)

resize = transforms.Resize((r_height, r_width))
for image_name in input_img_names:
    # The context manager closes the source file once the resized copy is saved.
    with Image.open(os.path.join(input_dir, image_name)) as img:
        r_img = resize(img)
        r_img.save(os.path.join(transformed_dir, image_name))

# Load the trained model

In [4]:
# load_pretrained_model is provided by scripts.helper; it yields the model,
# the input preprocessing function and the class list used by the dataset cells.
model, preprocessing_fn, CLASSES = load_pretrained_model()

# Prefer GPU index 4 (the original setting) when the machine actually has that
# many GPUs; otherwise fall back to the first GPU, or to the CPU when CUDA is
# unavailable. Selecting 'cuda:4' on a machine with fewer GPUs raises an
# invalid-device error as soon as a tensor is moved there.
if torch.cuda.is_available():
    DEVICE = 'cuda:4' if torch.cuda.device_count() > 4 else 'cuda:0'
else:
    DEVICE = 'cpu'

In [5]:
# Replace the freshly built model with the trained one stored on disk.
# NOTE(security): torch.load unpickles the file, which can execute arbitrary
# code -- only load checkpoints obtained from a trusted source.
# NOTE(review): recent PyTorch releases default to weights_only=True and may
# refuse a fully pickled model like this one -- confirm against the installed version.
# map_location remaps tensors that were saved on another device (e.g. a GPU
# that does not exist on this machine) onto DEVICE while loading.
model = torch.load('saved_models/Trained_model_SM.pth', map_location=DEVICE).to(DEVICE)

In [6]:
# Dataset without augmentation/preprocessing: the prediction loop turns its
# images into PIL images for the side-by-side visualisation.
test_dataset_viz = Dataset(
    classes=CLASSES,
    masks_dir=None,
    images_dir='transformed_images/',
)

# Dataset fed to the model: same images, passed through the validation
# augmentation and the model-specific preprocessing. No masks are available
# for these unlabeled inputs, hence masks_dir=None.
test_dataset = Dataset(
    classes=CLASSES,
    masks_dir=None,
    images_dir='transformed_images/',
    preprocessing=get_preprocessing_unlabeled(preprocessing_fn),
    augmentation=get_validation_augmentation(),
)

# Prediction

In [7]:
# Names of the resized images, used by the prediction loop to name its output
# files. The filter must be applied to this directory listing: filtering
# input_img_names instead would throw the listing away and take the names (and
# their order) from the input directory.
# NOTE(review): assumes Dataset indexes images in os.listdir('transformed_images/')
# order -- confirm against scripts.helper.
output_img_names = [
    name for name in os.listdir('transformed_images/')
    if name.endswith(('.png', '.jpg'))
]

In [8]:
# Legend picture that the prediction loop attaches to every joint image.
legend_img = Image.open('Legends.png')
w_legend = legend_img.size[0]
h_legend = legend_img.size[1]
# Width-to-height ratio, used later to rescale the legend to a given height.
ar = w_legend / h_legend

In [9]:
# Run the model on every transformed image and write three artefacts per image:
#   segmented_outputs/<name>.npy  - rounded prediction mask as a NumPy array
#   segmented_images/<name>       - colour-coded mask
#   joint_images/<name>           - input image, colour mask and legend concatenated

# The save calls below do not create missing directories.
for out_dir in ('segmented_outputs/', 'segmented_images/', 'joint_images/'):
    os.makedirs(out_dir, exist_ok=True)

for test_no, image_name in enumerate(output_img_names):
    # No masks exist for these inputs (masks_dir=None), so the second item is ignored.
    image_viz, _ = test_dataset_viz[test_no]
    image, _ = test_dataset[test_no]

    # Add a batch dimension, predict, then round the scores to a hard mask.
    img_tensor = torch.from_numpy(image).to(DEVICE).unsqueeze(0)
    pred_mask = model.predict(img_tensor)
    pred_mask = pred_mask.squeeze().cpu().numpy().round()

    # Build each picture once and reuse it for both saved outputs.
    input_img_viz = Image.fromarray(image_viz)
    color_mask = get_color_img(pred_mask, normal=False)

    # Scale the legend to the image height while keeping its aspect ratio.
    # Image.ANTIALIAS was removed in Pillow 10; LANCZOS is the same filter.
    new_h = input_img_viz.size[1]
    legend = legend_img.resize((int(new_h*ar), new_h), Image.LANCZOS)

    base_name = os.path.splitext(image_name)[0]
    np.save('segmented_outputs/' + base_name + '.npy', pred_mask)
    color_mask.save('segmented_images/' + image_name)
    get_concat_h(
        get_concat_h(input_img_viz, color_mask),
        legend
    ).save('joint_images/' + image_name)