## Stage I : SeamFormer

Goal: This notebook lets users run quick inference on their own datasets without any fine-tuning or training. Several model checkpoints are provided for you to try out!

Inputs to the notebook:

* Image folder
* Model checkpoint type (BKS / I2)
* Output folder

Pointers on choosing the right model checkpoint:

- If your input images resemble palm-leaf manuscripts (Balinese/Sundanese/Khmer), use the checkpoint: BKS.pt

- If your input images are denser and closely resemble Indic documents, you may experiment with: I2.pt


# Library Imports

In [None]:
# Requires installations
!pip install vit_pytorch==0.24.3
!pip install empatches
!pip install gdown
!pip install plantcv==3.14.1

# Library Imports
import copy
import cv2
import os
import pathlib
import sys
import numpy as np
from scipy.interpolate import interp1d
from vit_pytorch import ViT
from einops import rearrange
from empatches import EMPatches
from plantcv import plantcv as pcv
pcv.params.debug = None


# Torch Imports
import torch
from torch import nn
import torch.nn.functional as F
from vit_pytorch.vit import Transformer
from google.colab.patches import cv2_imshow




# Global Settings

In [None]:
# Inference device: the CUDA device when one is available, otherwise the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Experiment / model configuration.
# The network-instantiation cell reads imgsize, patchsize and encoder_* from here.
# NOTE(review): "encoder_freeze" and "vis_results" are the *strings* "False" / "True"
# (both truthy), unlike the real booleans train_scribble / train_binary -- compare
# against the string or convert before branching on them.
# NOTE(review): visualInference passes THRESHOLD=0.5 explicitly; "threshold" below
# is not read there.
settings = dict(
    # Experiment bookkeeping and paths (not read by the inference cells below)
    dataset_code="BKS",
    experiment_base="scribbleGen",
    wid="BKS_L2_Loss_V3_EncFreeze_MorePatches_th7",
    data_path="/scratch/nv/BKS/",
    model_weights_path="/scratch/nv/weights_BKS_Exp/",
    visualisation_folder="/scratch/nv/vis_BKS_Exp/",
    learning_rate=0.005,
    # ViT / encoder geometry
    vit_model_size="base",
    imgsize=256,
    patchsize=8,
    split_size=256,
    vit_patch_size=8,
    encoder_freeze="False",
    encoder_layers=6,
    encoder_heads=8,
    encoder_dims=768,
    # Training schedule and flags
    batch_size=8,
    num_epochs=25,
    train_scribble=False,
    train_binary=False,
    vis_results="True",
    # Checkpoint locations and output binarization threshold
    scribble_weights_path="/content/drive/MyDrive/Stable_SF/network-Public_BKS_E3_V3-9.pt",
    binary_weights_path='/content/drive/MyDrive/Stable_SF/BKS_correct_final.pt',
    threshold=0.3,
)


# Helper Functions

In [None]:
def downloadWeights(modelType):
  """Download the checkpoint ``<modelType>.pt`` from Google Drive unless it already exists.

  The file is fetched into the current working directory with the ``gdown`` CLI.

  Args:
    modelType: Checkpoint name, 'I2' or 'BKS'.

  Returns:
    The checkpoint filename (e.g. 'BKS.pt') for a known model type, or None for an
    unknown one (a message is printed in that case).

  Raises:
    subprocess.CalledProcessError: if the ``gdown`` command exits with an error.
    FileNotFoundError: if the ``gdown`` executable is not installed.
  """
  import subprocess

  # Google Drive file IDs of the published checkpoints.
  driveIds = {
      'I2': '1O_CtJToNUPrQzbMN38FsOJwEdxCDXqHh',
      'BKS': '1nro1UjYRSlMIaYUwkMTrfZzrE_kz0QDF',
  }
  if modelType not in driveIds:
    print('Invalid Model Checkpoint')
    return None

  checkpoint = '{}.pt'.format(modelType)
  if os.path.exists(checkpoint):
    print('{} is already existing .. Skipping download !'.format(checkpoint))
    return checkpoint

  # A real subprocess (rather than the IPython-only `!gdown`) keeps this plain Python
  # and surfaces a failed download as an exception instead of a printed message.
  subprocess.run(['gdown', driveIds[modelType]], check=True)
  if not os.path.exists(checkpoint):
    print('Warning : download finished but {} was not found in {}'.format(checkpoint, os.getcwd()))
  return checkpoint

# Network Instantiation

In [None]:
class SeamFormer(nn.Module):
    """ViT encoder shared by two transformer decoders: a binary branch and a scribble branch.

    The encoder's patch rearrange, patch embedding, positional embedding and transformer
    are reused directly; each decoder maps the encoded tokens back to
    ``patch_size * patch_size`` pixel values per patch.
    """

    def __init__(
        self,
        *,
        encoder,
        decoder_dim,
        decoder_depth = 1,
        decoder_heads = 8,
        decoder_dim_head = 64,
        patch_size =8):
        """
        Args:
            encoder: vit_pytorch ``ViT``; only the first two stages of its
                ``to_patch_embedding`` plus ``pos_embedding`` and ``transformer`` are used.
            decoder_dim: Token width inside both decoders.
            decoder_depth: Number of transformer layers per decoder.
            decoder_heads: Attention heads per decoder layer.
            decoder_dim_head: Width of each attention head.
            patch_size: Encoder patch size. Each decoder predicts patch_size**2 values
                per token, so this must match the encoder.
        """
        super().__init__()
        # Extract hyperparameters and sub-modules from the ViT encoder.
        self.encoder = encoder
        num_patches, encoder_dim = encoder.pos_embedding.shape[-2:]
        self.to_patch, self.patch_to_emb = encoder.to_patch_embedding[:2]
        # One predicted value per pixel of a single-channel patch.
        pixel_values_per_patch = patch_size * patch_size

        # NOTE(review): mask_token_* and decoder_pos_emb_* are registered (so checkpoints
        # load with strict=True) but are not used by forward().

        # Binary Decoder
        self.enc_to_dec_bin = nn.Linear(encoder_dim, decoder_dim) if encoder_dim != decoder_dim else nn.Identity()
        self.mask_token_bin = nn.Parameter(torch.randn(decoder_dim))
        self.decoder_bin = Transformer(dim = decoder_dim, depth = decoder_depth, heads = decoder_heads, dim_head = decoder_dim_head, mlp_dim = decoder_dim * 4)
        self.decoder_pos_emb_bin = nn.Embedding(num_patches, decoder_dim)
        self.to_pixels_bin = nn.Linear(decoder_dim, pixel_values_per_patch)

        # Scribble Decoder
        self.enc_to_dec_scr = nn.Linear(encoder_dim, decoder_dim) if encoder_dim != decoder_dim else nn.Identity()
        self.mask_token_scr = nn.Parameter(torch.randn(decoder_dim))
        self.decoder_scr = Transformer(dim = decoder_dim, depth = decoder_depth, heads = decoder_heads, dim_head = decoder_dim_head, mlp_dim = decoder_dim * 4)
        self.decoder_pos_emb_scr = nn.Embedding(num_patches, decoder_dim)
        self.to_pixels_scr = nn.Linear(decoder_dim, pixel_values_per_patch)


    def forward(self,img,gt_bin_img=None,gt_scr_img=None,criterion=None,strain=True,btrain=True,mode='train'):
        """Encode ``img`` and run the enabled decoder branch(es).

        Args:
            img: Input batch, patchified by the encoder's rearrange stage
                (presumably (B, C, H, W) -- TODO confirm against the ViT config).
            gt_bin_img: Binary ground truth; only used when mode == 'train'.
            gt_scr_img: Scribble ground truth; only used when mode == 'train'.
            criterion: Loss callable ``criterion(pred, target)``; only used in 'train'.
            strain: Run the scribble branch.
            btrain: Run the binary branch.
            mode: 'train' or 'test'.

        Returns:
            mode == 'train': ``(loss, gt_patches, pred_pixel_values)`` of the binary
                branch when ``btrain`` is set, otherwise of the scribble branch.
                Only one branch is evaluated per call.
            mode == 'test': ``(pred_pixel_values_bin, pred_pixel_values_scr)``. A
                disabled branch yields None.
            Otherwise (any other mode, or 'train' with both branches disabled): None.
        """
        scribbleloss=None
        gt_scr_patches=None
        binaryloss=None
        gt_bin_patches=None
        # Defined up front so 'test' mode with a disabled branch returns None for that
        # branch instead of raising UnboundLocalError.
        pred_pixel_values_bin=None
        pred_pixel_values_scr=None

        # get patches and their number
        patches = self.to_patch(img)
        _, num_patches, *_ = patches.shape
        # Project pixel patches to tokens and add positions. Slot 0 of pos_embedding
        # is skipped (presumably the cls-token position).
        tokens = self.patch_to_emb(patches)
        tokens = tokens + self.encoder.pos_embedding[:, 1:(num_patches + 1)]
        # encode tokens by the encoder
        encoded_tokens = self.encoder.transformer(tokens)

        if btrain:
            decoder_tokens_bin = self.enc_to_dec_bin(encoded_tokens)
            # decode tokens with decoder
            decoded_tokens_bin = self.decoder_bin(decoder_tokens_bin)
            # project tokens to pixels
            pred_pixel_values_bin = self.to_pixels_bin(decoded_tokens_bin)
            if mode == 'train':
                # Loss against the patchified ground truth, as computed by `criterion`.
                if gt_bin_img is not None:
                    gt_bin_patches = self.to_patch(gt_bin_img)
                binaryloss = criterion(pred_pixel_values_bin,gt_bin_patches)
                # Training returns here; the scribble branch is not evaluated.
                return binaryloss,gt_bin_patches,pred_pixel_values_bin

        if strain:
            decoder_tokens_scr = self.enc_to_dec_scr(encoded_tokens)
            # decode tokens with decoder
            decoded_tokens_scr = self.decoder_scr(decoder_tokens_scr)
            # project tokens to pixels
            pred_pixel_values_scr = self.to_pixels_scr(decoded_tokens_scr)
            if mode == 'train':
                # Loss against the patchified ground truth, as computed by `criterion`.
                if gt_scr_img is not None:
                    gt_scr_patches = self.to_patch(gt_scr_img)
                scribbleloss = criterion(pred_pixel_values_scr,gt_scr_patches)
                return scribbleloss,gt_scr_patches,pred_pixel_values_scr

        if mode=='test':
            return pred_pixel_values_bin,pred_pixel_values_scr



# Network Instantiation

# Encoder settings
encoder_layers = settings['encoder_layers']
encoder_heads = settings['encoder_heads']
encoder_dim = settings['encoder_dims']

# Encoder: a ViT whose patch embedding, positional embedding and transformer are
# reused by SeamFormer (its classification head is not used).
v = ViT(
    image_size = settings['imgsize'],
    patch_size =  settings['patchsize'],
    num_classes = 1000,
    dim = encoder_dim,
    depth = encoder_layers,
    heads = encoder_heads,
    mlp_dim = 2048)

# Full model. patch_size is passed explicitly so each decoder's per-patch output
# (patch_size**2 values) always matches the encoder's patch size.
network = SeamFormer(encoder = v,
    decoder_dim = encoder_dim,
    decoder_depth = encoder_layers,
    decoder_heads = encoder_heads,
    patch_size = settings['patchsize']).to(device)


# Helper Functions: Reading Inputs

In [None]:
def preprocess(deg_img):
    """Scale an 8-bit (H, W, C) image to [0, 1] and normalise its first three channels.

    The mean/std constants below are the values commonly used for ImageNet.
    NOTE(review): readFullImage feeds cv2.imread output (BGR order) here, while these
    constants are usually listed in RGB order -- confirm this matches training.

    Args:
        deg_img: Image array (anything ``np.array`` accepts) of shape (H, W, C), C >= 3.

    Returns:
        float64 array of shape (3, H, W).
    """
    scaled = (np.array(deg_img) / 255).astype('float32')
    channel_mean = [0.485, 0.456, 0.406]
    channel_std = [0.229, 0.224, 0.225]
    # Per-channel arithmetic stays in float32 (Python scalars do not upcast it);
    # the result is widened to float64 at the end.
    normalised = [
        (scaled[:, :, channel] - m) / s
        for channel, (m, s) in enumerate(zip(channel_mean, channel_std))
    ]
    return np.stack(normalised).astype(np.float64)

def deformat(listofpoints):
    """Flatten an OpenCV contour into a plain list of points.

    Input : [[[x1,y1]],[[x2,y2]],[[x3,y3]]....]  (each element a (1, 2) array)
    Output: [[x1,y1],[x2,y2],[x3,y3]....]
    """
    return list(map(lambda wrapped: wrapped[0].tolist(), listofpoints))

# Supply the raw image here
def cleanImageFindContours(patch,threshold = 0.20):
    """Find external contours in a mask and keep only the large ones.

    Args:
        patch: Single-channel mask; converted to uint8 before cv2.findContours.
        threshold: Keep contours whose area exceeds ``threshold`` times the area of
            the largest contour.

    Returns:
        ``(sortedContours, canvas)``: the kept contours as lists of [x, y] points,
        in descending area order, and a float64 canvas of ``patch.shape`` with
        those contours filled with 255. Always a 2-tuple. When no contours are
        found it is ``([], all-zero canvas)``.
    """
    patch = np.uint8(patch)
    contours, _ = cv2.findContours(patch, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_NONE)
    canvas = np.zeros(patch.shape)
    if len(contours) < 1:
        print('No contours in the raw image!')
        # Same tuple shape as the normal path so callers can always unpack.
        return [], canvas

    # Compute each area once; sort indices by area (stable, largest first).
    areas = [cv2.contourArea(c) for c in contours]
    maxArea = max(areas)
    order = sorted(range(len(contours)), key=lambda i: areas[i], reverse=True)
    sortedContours = [deformat(contours[i]) for i in order if areas[i] > threshold * maxArea]

    # Draw the kept contours on the canvas.
    for cnt in sortedContours:
        canvas = cv2.fillPoly(canvas, np.int32([cnt]), color=(255, 255, 255))
    return sortedContours, canvas



def readFullImage(path,PDIM=256,DIM=256,OVERLAP=0.25):
    """Read an image, normalise it and split it into overlapping network-ready patches.

    The normalised (3, H, W) image from ``preprocess`` is transposed to (W, H, 3)
    before patch extraction. Every patch is resized to DIM x DIM if needed, then
    transposed back to channel-first and given a leading batch axis.

    Args:
        path: Image file path.
        PDIM: Patch size passed to EMPatches.
        DIM: Side length each patch is resized to.
        OVERLAP: Fractional overlap between neighbouring patches.

    Returns:
        ``(samples, indices)``: samples are dicts ``{'img': float32 (1, 3, DIM, DIM),
        'resized': [shape[0], shape[1]] of the patch before resizing}`` and indices
        are EMPatches' patch locations. Returns None (after printing the error) if
        reading or patching fails.
    """
    samples = []
    patcher = EMPatches()
    try:
        image = np.transpose(preprocess(cv2.imread(path)))
        patches, indices = patcher.extract_patches(image, patchsize=PDIM, overlap=OVERLAP)
        for patch in patches:
            # Size before any resizing ([DIM, DIM] when no resize is needed).
            originalSize = [patch.shape[0], patch.shape[1]]
            if originalSize != [DIM, DIM]:
                patch = cv2.resize(patch, (DIM, DIM), interpolation=cv2.INTER_AREA)
            patch = np.expand_dims(np.transpose(np.asarray(patch, dtype=np.float32)), axis=0)
            samples.append({'img': patch, 'resized': originalSize})
    except Exception as exp :
        print('ImageReading Error ! :{}'.format(exp))
        return None
    return samples, indices

def stack_images_vertically(image_a, image_b,image_c):
    """Stack three images vertically: ``image_a`` on top, then ``image_b``, then ``image_c``.

    Args:
        image_a: Top image (numpy array).
        image_b: Middle image (numpy array).
        image_c: Bottom image (numpy array).

    Returns:
        The vertically stacked image (numpy array), as produced by ``np.vstack``.

    Raises:
        ValueError: If the three images do not all have the same width (shape[1]).
    """
    images = (image_a, image_b, image_c)
    width = image_a.shape[1]
    if any(img.shape[1] != width for img in images):
        raise ValueError("The images must have the same width")
    return np.vstack(images)


def reconstruct(pred_pixel_values,patch_size,target_shape,image_size):
    """Fold per-patch predictions back into an image, resize it and clamp to [0, 1].

    Args:
        pred_pixel_values: Tensor of shape (b, h*w, p1*p2*c), as produced by the decoders.
        patch_size: Patch side length (p1 == p2).
        target_shape: cv2 ``dsize`` (width, height) of the output.
        image_size: (height, width) of the folded image; height // patch_size gives h.

    Returns:
        numpy array resized to ``target_shape`` with values clamped to [0, 1].
    """
    # Un-patchify: (b, h*w, p1*p2*c) -> (b, c, h*p1, w*p2). The input is not modified.
    folded = rearrange(
        pred_pixel_values,
        'b (h w) (p1 p2 c) -> b c (h p1) (w p2)',
        p1=patch_size,
        p2=patch_size,
        h=image_size[0] // patch_size,
    )
    plane = folded.cpu().numpy().squeeze().T
    plane = cv2.resize(plane, target_shape, interpolation=cv2.INTER_AREA)
    # Clamp to [0, 1] in place on the freshly resized array.
    plane[plane > 1] = 1
    plane[plane < 0] = 0
    return plane


def inferenceNetwork(network,path,PDIM=256,DIM=256,OVERLAP=0.25,THRESHOLD=0.3,save=True,patch_size=8):
    """Run SeamFormer on a full image patch by patch and stitch the results back.

    The image is split into overlapping patches (see readFullImage). Each patch goes
    through ``network`` in 'test' mode; both decoder outputs are thresholded at
    ``THRESHOLD`` to {0, 255} and merged with a per-pixel max.

    Args:
        network: SeamFormer model (expected to already live on ``device``).
        path: Path to the input image.
        PDIM: Patch size used when splitting the image.
        DIM: Size patches are resized to before entering the network.
        OVERLAP: Fractional overlap between neighbouring patches.
        THRESHOLD: Binarization threshold applied to the clamped [0, 1] outputs.
        save: Unused; kept for backward compatibility.
        patch_size: Network patch size used to fold predictions back into images
            (must match the model).

    Returns:
        ``(binaryOutput, scribbleOutput, res)``. ``res`` stacks the input's first
        channel, the cleaned scribble map and the binary map vertically.

    Raises:
        FileNotFoundError: if ``path`` does not exist.
        ValueError: if the image could not be read or split into patches.
    """
    if not os.path.exists(path):
        raise FileNotFoundError('Invalid File Path ! Cannot run inference : {}'.format(path))

    parentImage = cv2.imread(path)
    patchData = readFullImage(path, PDIM, DIM, OVERLAP)
    if patchData is None:
        raise ValueError('Could not read or split the image : {}'.format(path))
    input_patches, indices = patchData

    emp = EMPatches()
    soutput_patches = []
    boutput_patches = []
    with torch.no_grad():
        for sample in input_patches:
            # 'resized' holds the patch's [shape[0], shape[1]]; cv2.resize wants (width, height).
            target_shape = (sample['resized'][1], sample['resized'][0])
            inputs = torch.from_numpy(sample['img']).to(device)
            # 'test' mode never evaluates a loss, so no criterion / ground truth is passed.
            pred_pixel_values_bin, pred_pixel_values_scr = network(inputs, strain=True, btrain=True, mode='test')

            bpatch = reconstruct(pred_pixel_values_bin.cpu(), patch_size, target_shape, (DIM, DIM))
            spatch = reconstruct(pred_pixel_values_scr.cpu(), patch_size, target_shape, (DIM, DIM))

            # Binarize at THRESHOLD and scale to {0, 255}.
            soutput_patches.append(255 * (spatch > THRESHOLD).astype(int))
            boutput_patches.append(255 * (bpatch > THRESHOLD).astype(int))

    assert len(boutput_patches) == len(soutput_patches) == len(input_patches), "Error in patch count!"

    # Restitch; overlapping regions keep the per-pixel maximum.
    soutput = emp.merge_patches(soutput_patches, indices, mode='max')
    boutput = emp.merge_patches(boutput_patches, indices, mode='max')

    # Patches were extracted from the transposed image (see readFullImage); undo that.
    binaryOutput = np.transpose(boutput, (1, 0))
    soutput = np.transpose(soutput, (1, 0))

    # Keep only the large scribble components. Tolerate both return forms of
    # cleanImageFindContours: a (contours, canvas) tuple, or a bare image when no
    # contours are found.
    cleaned = cleanImageFindContours(patch=soutput.astype(np.uint8), threshold=0.15)
    scribbleOutput = cleaned[1] if isinstance(cleaned, tuple) else np.zeros(soutput.shape)

    # Input (first channel) / scribble / binary, top to bottom.
    res = stack_images_vertically(
        image_a=np.asarray(parentImage[:, :, 0], dtype=np.uint8).squeeze(),
        image_b=scribbleOutput,
        image_c=binaryOutput,
    )
    return binaryOutput, scribbleOutput, res

# Utils
def imageCombiner(imgs):
  """Concatenate the given images side by side (np.hstack)."""
  return np.hstack(list(imgs))


# Polygon to Distance Mask
def polygon_to_distance_mask(polygon_mask,threshold=60):
    """Shrink a polygon mask by thresholding its min-max normalised distance transform.

    Args:
        polygon_mask: Single-channel 8-bit mask (cv2.distanceTransform needs 8-bit
            single-channel input); pixels > 100 count as inside the polygon.
        threshold: Cut-off applied to the distance map after scaling it to 0-255.

    Returns:
        uint8 mask: 255 where the scaled distance is >= threshold, 0 elsewhere.
    """
    # Force a clean 0/255 mask before measuring distances.
    _, binary_mask = cv2.threshold(polygon_mask, 100, 255, cv2.THRESH_BINARY)
    # Distance of each inside pixel to the nearest background pixel.
    distances = cv2.distanceTransform(binary_mask, cv2.DIST_L2, cv2.DIST_MASK_5)
    scaled = cv2.normalize(distances, None, alpha=0, beta=255,
                           norm_type=cv2.NORM_MINMAX, dtype=cv2.CV_8U)

    thinned = scaled.copy()
    thinned[thinned < threshold] = 0
    thinned[thinned >= threshold] = 255
    return thinned.astype(np.uint8)


# Helper Functions: Post Processing

In [None]:
'''
Post Processing Function
'''

def find_corner_points(points):
    """Return the leftmost and rightmost points (by x) of ``points``, each as a list.

    Ties on x are resolved by original order: the first such point is "leftmost"
    and the last is "rightmost". An empty input gives ``(None, None)``.
    """
    if not points:
        return None, None
    ordered_by_x = sorted(points, key=lambda pt: pt[0])
    first, last = ordered_by_x[0], ordered_by_x[-1]
    return list(first), list(last)

def hullNise(polygons):
  """Compute the convex hull of every polygon.

  Args:
    polygons: Sequence of polygons, each a list/array of [x, y] points.

  Returns:
    A list with one hull per polygon, each a list of [x, y] int points
    (empty list for empty input).
  """
  if len(polygons) == 0:
    return []
  return [
      np.asarray(cv2.convexHull(np.asarray(polygon, dtype=np.int32)), dtype=np.int32)
      .reshape((-1, 2))
      .tolist()
      for polygon in polygons
  ]

def uniformly_sampled_line(points,T=50):
    """Resample a polyline to ``min(len(points), T)`` points evenly spaced along its length.

    Args:
        points: Sequence of [x, y] points in drawing order.
        T: Maximum number of output points.

    Returns:
        List of ``[np.int32(x), np.int32(y)]`` (coordinates truncated toward zero).
        An empty input gives []. A line with a single distinct point (one point, or
        all points identical) gives copies of that point.
    """
    num_points = min(len(points), T)
    if num_points == 0:
        return []

    pts = np.asarray(points, dtype=np.float64).reshape(-1, 2)
    # Length of each segment between consecutive points.
    steps = np.sqrt(np.diff(pts[:, 0]) ** 2 + np.diff(pts[:, 1]) ** 2)
    # Drop consecutive duplicates: zero-length steps make interp1d divide by zero.
    moving = steps > 0
    pts = pts[np.concatenate(([True], moving))]
    if len(pts) < 2:
        # Nothing to interpolate along; repeat the single distinct point.
        x0, y0 = pts[0]
        return [[np.int32(x0), np.int32(y0)] for _ in range(num_points)]

    # Cumulative arc length, starting at 0 for the first point.
    distances = np.concatenate(([0.0], np.cumsum(steps[moving])))

    # Linear interpolation of x and y as functions of arc length.
    interpolate_x = interp1d(distances, pts[:, 0], kind='linear')
    interpolate_y = interp1d(distances, pts[:, 1], kind='linear')

    # Uniformly spaced arc-length positions (linspace hits both ends exactly).
    new_distances = np.linspace(0, distances[-1], num_points)
    new_x_coords = interpolate_x(new_distances)
    new_y_coords = interpolate_y(new_distances)
    return [[np.int32(x), np.int32(y)] for x, y in zip(new_x_coords, new_y_coords)]

# Scribble Generation
def generateScribble(H,W,polygon):
    """Derive a scribble polyline from the skeleton of a filled polygon.

    The polygon is filled on an H x W canvas, then skeletonized and pruned with PlantCV.
    The first pruned segment is resampled to at most 1000 points. The polygon's
    leftmost and rightmost points are appended and the result is sorted by x.

    Args:
        H: Canvas height.
        W: Canvas width.
        polygon: List of [x, y] points.

    Returns:
        List of [x, y] points sorted by x, or None when the polygon is empty or the
        skeleton yields no usable segment (postProcess skips None).
    """
    if len(polygon) == 0:
        return None

    # Generate the canvas and mark the polygon on it.
    canvas = np.zeros((H, W))
    leftmost_point, rightmost_point = find_corner_points(polygon)
    poly_arr = np.asarray(polygon, dtype=np.int32).reshape((-1, 1, 2))
    canvas = cv2.fillPoly(canvas, [poly_arr], (255, 255, 255))

    # Skeletonize and prune short branches.
    skeleton = pcv.morphology.skeletonize(canvas)
    _, _, segment_objects = pcv.morphology.prune(skel_img=skeleton, size=100)
    if len(segment_objects) == 0:
        # Pruning left nothing to trace.
        return None
    scribble = np.asarray(segment_objects[0], dtype=np.int32).reshape((-1, 2)).tolist()
    if len(scribble) < 2:
        # A single pixel cannot be resampled into a line.
        return None

    scribble = uniformly_sampled_line(scribble, 1000)
    if leftmost_point is not None and rightmost_point is not None:
        scribble.append(leftmost_point)
        scribble.append(rightmost_point)
        scribble = sorted(scribble, key=lambda point: point[0])
    return scribble

# Text Dilation
def text_dilate(image, kernel_size, iterations=1):
    """Dilate ``image`` with a square ``kernel_size`` x ``kernel_size`` kernel of ones.

    Args:
        image: Image/mask to dilate.
        kernel_size: Side length of the square structuring element.
        iterations: Number of dilation passes.

    Returns:
        The dilated image.
    """
    square_kernel = np.ones((kernel_size,) * 2, dtype=np.uint8)
    return cv2.dilate(image, square_kernel, iterations=iterations)

# Horizontal Dilation
def horizontal_dilation(image, kernel_width=5,iterations=1):
    """Dilate ``image`` with a horizontal 1 x ``kernel_width`` kernel of ones.

    Args:
        image: Image/mask to dilate.
        kernel_width: Horizontal extent of the structuring element.
        iterations: Number of dilation passes.

    Returns:
        The dilated image.
    """
    kernel = np.ones((1, kernel_width), np.uint8)
    # `iterations` must be passed by keyword: cv2.dilate's third positional parameter
    # is `dst`, so passing it positionally did not set the iteration count.
    return cv2.dilate(image, kernel, iterations=iterations)


def average_coordinates(hull):
    """Return the mean (x, y) of the points in a hull/contour.

    Accepts a flat point list ``[[x1, y1], [x2, y2], ...]`` (what hullNise returns),
    an (n, 2) array, or OpenCV's ``[[[x1, y1]], [[x2, y2]], ...]`` layout.

    Returns:
        ``(avg_x, avg_y)`` as Python floats.

    Raises:
        ValueError: if ``hull`` contains no points.
    """
    # Normalise every supported layout to an (n, 2) array.
    points = np.asarray(hull, dtype=np.float64).reshape(-1, 2)
    if len(points) == 0:
        raise ValueError('average_coordinates needs at least one point')
    avg_x, avg_y = points.mean(axis=0)
    return float(avg_x), float(avg_y)

# You send set of clean contours to this function
# you obtain a list of hulls and merge the ones on the same horizontal level.
def combine_hulls_on_same_level(contours, threshold=45):
    """Convex-hull each contour and merge hulls that sit on the same horizontal level.

    Hulls are sorted by the mean y of their points. Walking top to bottom, a hull is
    merged into the running group while its mean y is within ``threshold`` of the
    group's mean y. Otherwise the group is closed and a new one starts.

    Args:
        contours: Sequence of contours, each a list of [x, y] points.
        threshold: Maximum difference in mean y for two hulls to be merged.

    Returns:
        A list of merged convex hulls as returned by cv2.convexHull ((n, 1, 2) int32).
        Empty when ``contours`` is empty.
    """
    def meanY(points):
        # Mean y-coordinate of a point list / (n, 2) array.
        return float(np.asarray(points, dtype=np.float64).reshape(-1, 2)[:, 1].mean())

    hulls = hullNise(contours)
    if not hulls:
        return []

    # int32 throughout: cv2.convexHull rejects int64 point arrays.
    sorted_hulls = [np.asarray(h, dtype=np.int32).reshape(-1, 2) for h in sorted(hulls, key=meanY)]

    combined_hulls = []
    current_combined_hull = sorted_hulls[0]
    for hull in sorted_hulls[1:]:
        if abs(meanY(hull) - meanY(current_combined_hull)) < threshold:
            # Same horizontal level: extend the current group.
            current_combined_hull = np.vstack((current_combined_hull, hull))
        else:
            # New level: close the current group.
            combined_hulls.append(current_combined_hull)
            current_combined_hull = hull
    combined_hulls.append(current_combined_hull)
    return [cv2.convexHull(points) for points in combined_hulls]

def postProcess(scribbleImage,binaryImage,binaryThreshold=50,rectangularKernel=50):
    """Turn the network's scribble and binary maps into per-line scribble polylines.

    The steps are:
    1. Binarize both maps at ``binaryThreshold``.
    2. Thin the scribble map with a distance-transform threshold.
    3. Keep only text pixels inside it.
    4. Dilate with a 3x3 kernel, then horizontally with width ``rectangularKernel``.
    5. Extract the large contours and merge hulls lying on the same horizontal level.
    6. Turn each merged hull into a scribble with generateScribble.

    Args:
        scribbleImage: 2-D scribble map.
        binaryImage: 2-D binary map of the same shape.
        binaryThreshold: Cut-off used to binarize both maps to {0, 255}.
        rectangularKernel: Width of the horizontal dilation kernel.

    Returns:
        A list of scribbles (each a list of [x, y] points). Empty when no text region
        survives the filtering.
    """
    # astype copies, so the caller's arrays are not modified below.
    bin_ = binaryImage.astype(np.uint8)
    scr = scribbleImage.astype(np.uint8)
    H, W = bin_.shape

    # Threshold both maps to {0, 255}.
    bin_[bin_ >= binaryThreshold] = 255
    bin_[bin_ < binaryThreshold] = 0
    scr[scr >= binaryThreshold] = 255
    scr[scr < binaryThreshold] = 0

    # Distance transform thins the predicted scribble polygons.
    scr = polygon_to_distance_mask(scr, threshold=50)

    # Text pixels inside the thinned polygons, as 0/1 values.
    textRegion = cv2.bitwise_and(bin_ / 255, scr / 255)
    textRegion = text_dilate(textRegion, kernel_size=3, iterations=3)
    # Horizontal dilation fills gaps between characters on a line.
    textRegion = horizontal_dilation(textRegion, rectangularKernel, 3)

    # cleanImageFindContours returns (contours, canvas), or a bare image when it
    # finds no contours; only the contour list is needed here.
    found = cleanImageFindContours(np.uint8(textRegion), threshold=0.10)
    contours = found[0] if isinstance(found, tuple) else []
    print('Length of contours : {}'.format(len(contours)))
    if len(contours) == 0:
        return []

    # Combine the hulls that are on the same horizontal level.
    new_hulls = combine_hulls_on_same_level(contours, threshold=20)

    # Scribble generation, one per merged hull.
    predictedScribbles = []
    for hull in new_hulls:
        hullPoints = np.asarray(hull, dtype=np.int32).reshape((-1, 2)).tolist()
        scribble = generateScribble(H, W, hullPoints)
        if scribble is not None:
            predictedScribbles.append(scribble)
    return predictedScribbles



# User Inputs

In [None]:
# Please enter valid values. Surrounding whitespace is stripped, and the checkpoint
# name is case-insensitive (e.g. "bks" -> "BKS").
IMAGE_FOLDER = input("Please provide the full path to the folder : ").strip()
MODEL_CHECKPOINT_TYPE = input(" Choose the model checkpoint (BKS / I2 ) : ").strip().upper()
OUTPUT_FOLDER = input("Please provide the full path to the output folder : ").strip()
downloadWeights(modelType=MODEL_CHECKPOINT_TYPE)


Please provide the full path to the folder : /content/
 Choose the model checkpoint (BKS / I2 ) : BKS
Please provide the full path to the output folder : BKS_OPT
BKS.pt is already existing .. Skipping download !


# Visual Inference

In [None]:
def visualInference(folderPath,outputPath):
  """Run the global ``network`` on every image in ``folderPath`` and save/show the results.

  For each .jpg/.jpeg/.png file (case-insensitive, processed in sorted order):
  - the binary and scribble maps from inferenceNetwork are displayed;
  - those maps are written to ``<outputPath>/binaryImages/<name>`` and
    ``<outputPath>/scribbleImages/<name>``;
  - the post-processed scribbles are drawn in green on a copy of the input, which
    is displayed.

  Args:
    folderPath: Folder containing the input images.
    outputPath: Folder that receives the outputs (created if missing).
  """
  # Output Folder Creation
  binaryDir = os.path.join(outputPath,'binaryImages')
  scribbleDir = os.path.join(outputPath,'scribbleImages')
  os.makedirs(outputPath,exist_ok = True)
  os.makedirs(binaryDir,exist_ok = True)
  os.makedirs(scribbleDir,exist_ok = True)

  if not os.path.isdir(folderPath):
    print('Invalid image folder : {}'.format(folderPath))
    return

  # Case-insensitive extension match; sorted for a deterministic processing order.
  imageExtensions = ('.jpg', '.jpeg', '.png')
  fileNames = sorted(f for f in os.listdir(folderPath) if f.lower().endswith(imageExtensions))

  # Iterating through the samples and generating the results
  for f in fileNames:
    # Read from the folder passed in, not the IMAGE_FOLDER global.
    path = os.path.join(folderPath,f)
    img = cv2.imread(path)
    binaryOutput,scribbleOutput,res = inferenceNetwork(network,path,PDIM=256,DIM=256,OVERLAP=0.50,THRESHOLD=0.5,save=True)
    print('Scribble Output : ')
    cv2_imshow(scribbleOutput)
    print('Binary Output : ')
    cv2_imshow(binaryOutput)

    # Scribbles: post-process the maps into polylines and overlay them in green.
    binaryOutput=np.uint8(binaryOutput)
    scribbleOutput=np.uint8(scribbleOutput)
    scribbles = postProcess(scribbleOutput,binaryOutput,binaryThreshold=50,rectangularKernel=50)
    img2 = img.copy()
    for p in scribbles:
      p = np.asarray(p,dtype=np.int32).reshape((-1,1,2))
      img2 = cv2.polylines(img2, [p],False, (0,255,0),2)

    # Writing to the directory ..
    cv2.imwrite(os.path.join(binaryDir,f),binaryOutput)
    cv2.imwrite(os.path.join(scribbleDir,f),scribbleOutput)

    print('Image with Scribbles Overlaid :')
    cv2_imshow(img2)

# Network Weights Loading..
# NOTE: torch.load unpickles the checkpoint file; only load checkpoints from a trusted source.
network.load_state_dict(torch.load(MODEL_CHECKPOINT_TYPE+'.pt',map_location=device),strict=True)
# Switch the model to inference mode before running it.
network.eval()
visualInference(IMAGE_FOLDER,OUTPUT_FOLDER)

Scribble Output : 
