<a href="https://colab.research.google.com/github/olaviinha/NeuralBackgroundRemoval/blob/main/DIS_public.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

#<font face="Trebuchet MS" size="6">DIS IS-Net<font color="#999" size="4">&nbsp;&nbsp;&nbsp;|&nbsp;&nbsp;&nbsp;</font><font color="#999" size="4">Background removal</font><font color="#999" size="4">&nbsp;&nbsp;&nbsp;|&nbsp;&nbsp;&nbsp;</font><a href="https://github.com/olaviinha/NeuralBackgroundRemoval" target="_blank"><font color="#999" size="4">Github</font></a>

Colab for [Highly Accurate Dichotomous Image Segmentation](https://arxiv.org/pdf/2203.03041.pdf) by Xuebin Qin et al.

<font color="salmon">Note!</font> In light of Google Colab's recent price hike circus, please note that running this notebook on a CPU-only runtime (without a GPU) is most likely entirely tolerable: images are processed within seconds on CPU.

### Tips

- All directory and file paths should be relative to your Google Drive root: e.g. if you have a directory called _images_ in your Drive, containing a subdirectory called _churchboats_, then the field value in this notebook should be `images/churchboats`.

- `input` may be a path to an image file or a directory containing image files. All images found in the directory will be processed individually (images in subdirectories are not included).

- If you provide a `local_models_dir` path, all models will be fetched from there instead of being downloaded. If the models are not found in the given path, they will be downloaded there first. This may be useful in the future if the models are no longer available at their current locations.

- `disconnect_runtime_when_done` disconnects and deletes the runtime upon cell completion. Mostly useful if you are processing a large number of images on a GPU runtime.

In [None]:
#@title #Setup
#@markdown This cell needs to be run only once. It will mount your Google Drive and setup prerequisites.<br>
#@markdown <small>Mounting Drive will enable this notebook to save outputs directly to your Drive. Otherwise you will need to copy/download them manually from this notebook.</small>

force_setup = False
repositories = [
  # 'https://github.com/xuebinqin/U-2-Net',
  'https://github.com/xuebinqin/DIS'
  # 'https://github.com/xuebinqin/DIS'
]
pip_packages = ''
apt_packages = ''
mount_drive = True #@param {type:"boolean"}
local_models_dir = "" #@param {type:"string"}
skip_setup = False #@ param {type:"boolean"}

# Download the repo from Github
import os
from google.colab import output
import warnings
warnings.filterwarnings('ignore')
%cd /content/

# inhagcutils
if not os.path.isfile('/content/inhagcutils.ipynb') and force_setup == False:
  !pip -q install import-ipynb {pip_packages}
  if apt_packages != '':
    !apt-get update && apt-get install {apt_packages}
  !curl -s -O https://raw.githubusercontent.com/olaviinha/inhagcutils/master/inhagcutils.ipynb
import import_ipynb
from inhagcutils import *

# Mount Drive. User-supplied paths in this notebook are resolved relative to drive_root.
if mount_drive is True:
  if not os.path.isdir('/content/drive'):
    from google.colab import drive
    drive.mount('/content/drive')
  # Symlink gives a path to "My Drive" without the space in its name.
  if not os.path.isdir('/content/mydrive'):
    os.symlink('/content/drive/My Drive', '/content/mydrive')
  # Assigned unconditionally so drive_root is defined even when both paths
  # already exist (e.g. when this cell is run a second time).
  drive_root = '/content/mydrive/'
  drive_root_set = True
else:
  create_dirs(['/content/faux_drive'])
  drive_root = '/content/faux_drive/'

if len(repositories) > 0:
  if skip_setup == False:
    for repo in repositories:
      %cd /content/
      install_dir = fix_path('/content/'+path_leaf(repo).replace('.git', ''))
      repo_name = basename(install_dir)
      repo = repo if '.git' in repo else repo+'.git'
      !git clone {repo}
      if os.path.isfile(install_dir+'setup.py') or os.path.isfile(install_dir+'setup.cfg'):
        !pip install -e ./{install_dir}
      if os.path.isfile(install_dir+'requirements.txt'):
        !pip install -r {install_dir}/requirements.txt
  else:
    install_dir = fix_path('/content/'+path_leaf(repositories[0]).replace('.git', ''))
    repo_name = path_leaf(install_dir)

dir_tmp = '/content/tmp/'
create_dirs([dir_tmp])

if len(repositories) == 1:
  %cd {install_dir}

import time, sys
from datetime import timedelta

# # DO stuff
# if repo_name == 'DIS':
#   !gdown 1jOC2zK0GowBvEt03B7dGugCRDVAoeIqq

def prep_model(model, gid):
  global install_dir, models_dir
  filename = model+'.pth'
  if not os.path.isfile(models_dir+filename):
    !gdown {gid}
  if os.path.isfile(models_dir+filename):
    mdir = install_dir+'saved_models/'
    if not os.path.isdir(mdir): os.mkdir(mdir)
    print( 'From', models_dir+filename ) 
    print(' To', mdir+filename )
    shutil.copy(models_dir+filename, mdir+filename)
  else:
    op(c.fail, 'Failed', filename)

if not os.path.isdir(install_dir+'saved_models/IS-Net'):
  os.mkdir(install_dir+'saved_models/IS-Net')

if local_models_dir != '':
  models_dir = drive_root+fix_path(local_models_dir)
  if not os.path.isdir(models_dir): os.mkdir(models_dir)
  %cd {models_dir}
  prep_model('isnet', '1KyMpRjewZdyYfxHPYcd-ZbanIXtin0Sn')
  prep_model('isnet-general-use', '1nV57qKuy--d5u1yvkng9aXW1KS4sOpOi')

  %cd {install_dir}
else:
  d = install_dir+'saved_models/IS-Net'
  os.mkdir(d)
  %cd {d}
  # isnet.pth
  !gdown 1KyMpRjewZdyYfxHPYcd-ZbanIXtin0Sn
  # isnet-general-use.pth
  !gdown 1nV57qKuy--d5u1yvkng9aXW1KS4sOpOi
  

%cd {install_dir}/IS-Net


import numpy as np
from PIL import Image, ImageEnhance
import torch
from torch.autograd import Variable
from torchvision import transforms
import torch.nn.functional as F
import gdown
import os

import requests
import matplotlib.pyplot as plt
from io import BytesIO

# project imports
from data_loader_cache import normalize, im_reader, im_preprocess 
from models import *

device = 'cuda' if torch.cuda.is_available() else 'cpu'

%cd {install_dir}


class GOSNormalize(object):
    '''
    Callable that normalizes an image with per-channel mean and std.

    A thin wrapper around `normalize` (from data_loader_cache) so it can be
    placed in a torchvision.transforms.Compose pipeline.

    mean, std: per-channel values passed straight to `normalize`. When
    omitted, [0.485, 0.456, 0.406] / [0.229, 0.224, 0.225] are used (these
    look like the usual ImageNet statistics). Each instance gets its own
    default list, so mutating one instance's mean/std cannot leak into
    other instances.
    '''
    def __init__(self, mean=None, std=None):
        self.mean = [0.485, 0.456, 0.406] if mean is None else mean
        self.std = [0.229, 0.224, 0.225] if std is None else std

    def __call__(self, image):
        image = normalize(image, self.mean, self.std)
        return image


# Input pipeline applied by load_image after pixels are divided by 255:
# presumably (x - 0.5) / 1.0 per channel, i.e. centering values around zero.
transform = transforms.Compose([
    GOSNormalize(mean=[0.5, 0.5, 0.5], std=[1.0, 1.0, 1.0]),
])

def load_image(im_path, hypar, timeout=60):
    '''
    Load an image from a local path or an http(s) URL and prepare it for predict().

    im_path: local file path, or a string starting with "http", which is
             fetched with requests.
    hypar:   inference parameters; only hypar["cache_size"] is read here.
    timeout: seconds to wait for a remote image before requests gives up.

    Returns (image, shape): the transformed image tensor and the size tensor
    returned by im_preprocess, each with a leading batch dimension of 1.
    Raises requests.HTTPError if a URL answers with an error status.
    '''
    if im_path.startswith("http"):
        response = requests.get(im_path, timeout=timeout)
        # Fail loudly instead of handing an HTML error page to the image reader.
        response.raise_for_status()
        im_path = BytesIO(response.content)

    im = im_reader(im_path)
    im, im_shp = im_preprocess(im, hypar["cache_size"])
    im = torch.divide(im, 255.0)
    shape = torch.from_numpy(np.array(im_shp))
    # unsqueeze(0): make a batch of one image and one shape.
    return transform(im).unsqueeze(0), shape.unsqueeze(0)


def build_model(hypar, device):
    '''
    Prepare the network stored in hypar["model"] for inference.

    - hypar["model_digit"] == "half": cast the network to float16 while
      keeping every BatchNorm2d layer in float32.
    - hypar["restore_model"] != "": load the state dict from
      hypar["model_path"] + "/" + hypar["restore_model"], mapped onto device.

    The network is moved to `device`, switched to eval mode and returned
    (the same module object, not a copy).
    '''
    model = hypar["model"]

    if hypar["model_digit"] == "half":
        model.half()
        # BatchNorm layers are kept in full precision.
        batchnorms = [m for m in model.modules() if isinstance(m, nn.BatchNorm2d)]
        for bn in batchnorms:
            bn.float()

    model.to(device)

    weights_file = hypar["restore_model"]
    if weights_file != "":
        checkpoint = hypar["model_path"] + "/" + weights_file
        model.load_state_dict(torch.load(checkpoint, map_location=device))
        model.to(device)

    model.eval()
    return model

    
def predict(net,  inputs_val, shapes_val, hypar, device):
    '''
    Predict the foreground mask for one batched image.

    net:        model such that net(x)[0][0] is a (B, 1, h, w) tensor; only
                the first image of the batch is used.
    inputs_val: image batch as returned by load_image.
    shapes_val: batched size tensor as returned by load_image; the mask is
                resized to (shapes_val[0][0], shapes_val[0][1]).
    hypar:      reads hypar["model_digit"]: "full" feeds float32 input,
                anything else float16.
    device:     device the input is moved to ('cpu', 'cuda', 'cuda:0', ...).

    Returns the mask as a 2-D uint8 numpy array min-max scaled to 0..255.
    A constant prediction (max == min) yields an all-zero mask instead of
    dividing by zero.
    '''
    net.eval()

    if hypar["model_digit"] == "full":
        inputs_val = inputs_val.type(torch.FloatTensor)
    else:
        inputs_val = inputs_val.type(torch.HalfTensor)

    out_h, out_w = int(shapes_val[0][0]), int(shapes_val[0][1])

    # Inference only: no autograd graph is needed, which saves memory.
    with torch.no_grad():
        ds_val = net(inputs_val.to(device))[0]  # list of side outputs
        # The first side output is used as the prediction (the most accurate
        # one according to the original code). Slicing 0:1 keeps it 4-D.
        pred_val = ds_val[0][0:1]

        # Recover the prediction's spatial size to the original image size.
        # Indexing [0, 0] (rather than squeeze) keeps 1-pixel dimensions intact.
        pred_val = F.interpolate(pred_val, (out_h, out_w), mode='bilinear')[0, 0]

        mi = torch.min(pred_val)
        value_range = torch.max(pred_val) - mi
        if value_range > 0:
            pred_val = (pred_val - mi) / value_range  # max = 1
        else:
            pred_val = torch.zeros_like(pred_val)

    mask = (pred_val.cpu().numpy() * 255).astype(np.uint8)
    if str(device).startswith('cuda'):
        torch.cuda.empty_cache()
    return mask


hypar = {} # parameters for inference

hypar["restore_model"] = "isnet.pth" ## name of the to-be-loaded weights
# Directory to load the trained weights from. Depending on which setup branch
# ran, the weights live in saved_models/IS-Net/ (direct download) or in
# saved_models/ (copied by prep_model), so use whichever contains the file,
# falling back to ./saved_models relative to the current directory.
_weight_dirs = [install_dir+'saved_models/IS-Net', install_dir+'saved_models']
hypar["model_path"] = next(
  (d for d in _weight_dirs if os.path.isfile(d+'/'+hypar["restore_model"])),
  "./saved_models")
hypar["interm_sup"] = False ## whether to activate intermediate feature supervision

##  choose floating point accuracy --
hypar["model_digit"] = "full" ## "half" or "full" precision floats
hypar["seed"] = 0

hypar["cache_size"] = [1024, 1024] ## cached input spatial resolution, can be configured into different size

## data augmentation parameters ---
hypar["input_size"] = [1024, 1024] ## model input spatial size, usually the same value as hypar["cache_size"], which means we don't further resize the images
hypar["crop_size"] = [1024, 1024] ## random crop size from the input, usually set smaller than hypar["cache_size"], e.g., [920,920] for data augmentation

hypar["model"] = ISNetDIS()



net = build_model(hypar, device)




def save_output(og_img, pred, output_path, contrast=1, save_mask=False):
  '''
  Save og_img as a PNG whose alpha channel is the predicted mask.

  og_img:      path of the original image (opened with PIL).
  pred:        2-D uint8 mask array, as returned by predict().
  output_path: destination PNG path.
  contrast:    factor for ImageEnhance.Contrast applied to the mask; 1 = unchanged.
  save_mask:   when True, also save the grayscale mask as
               "<output_path without extension>_mask.png".
  '''
  mask = Image.fromarray(pred).convert('L')
  # Open the original once; the context manager closes its file handle.
  with Image.open(og_img) as og_image:
    # Match the mask to the original image size in case they differ.
    mask = mask.resize(og_image.size, resample=Image.BILINEAR)
    if contrast != 1:
      mask = ImageEnhance.Contrast(mask).enhance(contrast)
    if save_mask is True:
      # splitext touches only the extension, unlike str.replace('.png', ...),
      # which would also rewrite a '.png' elsewhere in the path.
      mask.save(os.path.splitext(output_path)[0] + '_mask.png', 'PNG')
    og_image.putalpha(mask)
    og_image.save(output_path, 'PNG')


op(c.ok, 'Setup finished.')

In [None]:
#@title # Run
include_subdirs = False #@ param {type: "boolean"}
input = "" #@param {type:"string"}
output_dir = "" #@param {type:"string"}
filename_detail = "" #@ param {type:"string"}
model = "isnet-general-use" #@param ["isnet", "isnet-general-use"]
disconnect_runtime_when_done = False #@param {type: "boolean"}

import shutil

contrast = 1      # contrast factor applied to the mask (1 = unchanged)
save_mask = False # also save the bare mask next to each output

uniq_id = gen_id()
trunc = 40        # max characters of the input file name kept in output names

copy_files = False
exclude_images_containing = ['isnet', 'u2net'] # ignore images that contain these strings

# -- Model --
# Load the weights selected above. Setup leaves them either in
# saved_models/IS-Net/ (direct download) or in saved_models/ (copied from
# local_models_dir), so look in both.
model_file = model+'.pth'
weights_dir = None
for d in [install_dir+'saved_models/IS-Net', install_dir+'saved_models']:
  if os.path.isfile(d+'/'+model_file):
    weights_dir = d
    break
if weights_dir is None:
  op(c.fail, 'FAIL!', 'Model weights not found: '+model_file)
  sys.exit('Model weights not found.')
model_path = weights_dir+'/'+model_file
model_name = basename(model_path)
# Reload only if these weights are not the ones already loaded into net.
loaded_path = os.path.abspath(hypar["model_path"]+'/'+hypar["restore_model"])
if loaded_path != os.path.abspath(model_path):
  hypar["model_path"] = weights_dir
  hypar["restore_model"] = model_file
  net = build_model(hypar, device)

# -- Input --
if os.path.isfile(drive_root+input):
  inputs = [drive_root+input]
  dir_in = path_dir(drive_root+input)
  copy_files = True
elif input != '' and os.path.isdir(drive_root+input):
  dir_in = drive_root+fix_path(input)
  inputs = list_images(dir_in, exclude_pattern=exclude_images_containing)
  if include_subdirs is True:
    for item in glob(dir_in+'/*'):
      if os.path.isdir(item):
        inputs.extend(list_images(item, exclude_pattern=exclude_images_containing))
elif '*' in input:
  # Glob pattern, e.g. "images/*.jpg". A pattern is never itself a directory,
  # so this branch must not require os.path.isdir().
  dir_in = path_dir(drive_root+input)
  inputs = glob(drive_root+input)
  copy_files = True
else:
  op(c.fail, 'FAIL!', 'Input should be a path to a file or a directory.')
  sys.exit('Input not understood.')

# -- Output --
if output_dir == '':
  dir_out = dir_in
else:
  dir_out = drive_root+fix_path(output_dir)
  # makedirs: output_dir may be a nested path.
  os.makedirs(dir_out, exist_ok=True)

timer_start = time.time()

if copy_files == True:
  # Empty the scratch dir (its files, not the dir itself), then stage the inputs there.
  for f in glob(dir_tmp+'*'):
    if os.path.isfile(f):
      os.remove(f)
  for src in inputs:
    shutil.copy(src, dir_tmp)
  inputs = list_images(dir_tmp)
  dir_in = dir_tmp

using = 'GPU' if device == 'cuda' else 'CPU'
total = len(inputs)
count = 1

op(c.title, 'RUN ID:', uniq_id, time=True)
op(c.okb, 'Using '+using+' to process '+str(total)+' images', time=True)

print()

# img_path rather than `input`, so the form field above is not overwritten.
for img_path in inputs:
  if output_dir == '' and include_subdirs is True:
    # Write each result next to its source image, which may be in a subdirectory.
    dir_out = path_dir(img_path)

  ndx_info = str(count)+'/'+str(total)+' '
  op(c.title, ndx_info+'Processing', path_leaf(img_path), time=True)

  image_tensor, orig_size = load_image(img_path, hypar)
  mask = predict(net, image_tensor, orig_size, hypar, device)

  fd = filename_detail+'__' if filename_detail != '' else ''
  file_out = dir_out+fd+slug(basename(img_path)[:trunc])+'x_'+uniq_id+'_'+model_name+'.png'

  save_output(img_path, mask, file_out, contrast, save_mask)
  if os.path.isfile(file_out):
    op(c.ok, 'Saved as', path_leaf(dir_out)+'/'+path_leaf(file_out), time=True)
  else:
    op(c.fail, 'ERROR saving', file_out, time=True)

  count += 1
  print()

# -- END THINGS --

timer_end = time.time()

print()
op(c.okb, 'Elapsed', timedelta(seconds=timer_end-timer_start), time=True)
op(c.ok, 'FIN.')

if disconnect_runtime_when_done is True: end_session()