<a href="https://colab.research.google.com/github/olaviinha/NeuralBackgroundRemoval/blob/main/u2net_public.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

#<font face="Trebuchet MS" size="6">U²-Net<font color="#999" size="4">&nbsp;&nbsp;&nbsp;|&nbsp;&nbsp;&nbsp;</font><font color="#999" size="4">Background removal</font><font color="#999" size="4">&nbsp;&nbsp;&nbsp;|&nbsp;&nbsp;&nbsp;</font><a href="https://github.com/olaviinha/NeuralBackgroundRemoval" target="_blank"><font color="#999" size="4">Github</font></a>

Colab for [U²-Net: Going Deeper with Nested U-Structure for Salient Object Detection](https://arxiv.org/pdf/2005.09007.pdf) by Xuebin Qin et al.

<font color="salmon">Note!</font> In light of Google Colab's recent price hike circus, please note that running this notebook is probably entirely tolerable on CPU only runtime (without GPU). Images are processed within seconds on CPU.

### Tips

- All directory and file paths should be relative to your Google Drive root: e.g. if you have a directory called _images_ in your Drive, containing a subdirectory called _churchboats_, then the field value in this notebook should be `images/churchboats`.

- `input` may be a path to an image file or to a directory containing image files. All images found in the directory will be processed individually (images in subdirectories are not included).

- If you provide a `local_models_dir` path, all models will be fetched from there instead of being downloaded. If the models are not found in the given path, they will be downloaded there first. This may be useful in the future if the models are no longer available in their current locations.

- `scalers` (~0-1) are factors for image resizing prior to processing. You may think of it as the _background removal rate_ or _sensitivity_ of the background removal, where a higher number typically removes more stuff from the image. You may give a single value or a comma-separated list of values (e.g. `0.5, 0.75`). Each given value will produce a new image.

- `disconnect_runtime_when_done` disconnects and deletes the runtime upon cell completion. Mostly useful if you are processing a large number of images on a GPU runtime.

In [None]:
#@title #Setup
#@markdown This cell needs to be run only once. It will mount your Google Drive and setup prerequisites.<br>
#@markdown <small>Mounting Drive will enable this notebook to save outputs directly to your Drive. Otherwise you will need to copy/download them manually from this notebook.</small>

force_setup = False
repositories = [
  'https://github.com/xuebinqin/U-2-Net',
  # 'https://github.com/xuebinqin/DIS'
]
pip_packages = ''
apt_packages = ''
mount_drive = True #@param {type:"boolean"}
local_models_dir = "" #@param {type:"string"}
skip_setup = False #@ param {type:"boolean"}

# Download the repo from Github
import os
from google.colab import output
import warnings
warnings.filterwarnings('ignore')
%cd /content/

# inhagcutils
if not os.path.isfile('/content/inhagcutils.ipynb') and force_setup == False:
  !pip -q install import-ipynb {pip_packages}
  if apt_packages != '':
    !apt-get update && apt-get install {apt_packages}
  !curl -s -O https://raw.githubusercontent.com/olaviinha/inhagcutils/master/inhagcutils.ipynb
import import_ipynb
from inhagcutils import *

# Mount Drive
# Resolve `drive_root`, the prefix every user-supplied path is appended to.
# It always ends with a slash because later code concatenates paths onto it.
if mount_drive is True:
  if not os.path.isdir('/content/drive'):
    from google.colab import drive
    drive.mount('/content/drive')
  # Space-free alias for "My Drive" so paths can be passed to shell commands.
  # lexists() also detects a dangling link, which would make symlink() raise.
  if not os.path.lexists('/content/mydrive'):
    os.symlink('/content/drive/My Drive', '/content/mydrive')
  # Assigned unconditionally: previously drive_root was only set when the
  # mount/symlink had to be created, leaving it undefined (or without a
  # trailing slash) when the cell ran against an already prepared runtime.
  drive_root = '/content/mydrive/'
  drive_root_set = True
else:
  # No Drive: use a local directory as a stand-in root.
  create_dirs(['/content/faux_drive'])
  drive_root = '/content/faux_drive/'

if len(repositories) > 0:
  if skip_setup == False:
    for repo in repositories:
      %cd /content/
      install_dir = fix_path('/content/'+path_leaf(repo).replace('.git', ''))
      repo_name = basename(install_dir)
      repo = repo if '.git' in repo else repo+'.git'
      !git clone {repo}
      if os.path.isfile(install_dir+'setup.py') or os.path.isfile(install_dir+'setup.cfg'):
        !pip install -e ./{install_dir}
      if os.path.isfile(install_dir+'requirements.txt'):
        !pip install -r {install_dir}/requirements.txt
  else:
    install_dir = fix_path('/content/'+path_leaf(repositories[0]).replace('.git', ''))
    repo_name = path_leaf(install_dir)

dir_tmp = '/content/tmp/'
create_dirs([dir_tmp])

if len(repositories) == 1:
  %cd {install_dir}

import time, sys
from datetime import timedelta

# # DO stuff
# if repo_name == 'DIS':
#   !gdown 1jOC2zK0GowBvEt03B7dGugCRDVAoeIqq

def prep_model(model, gid):
  global install_dir, models_dir
  filename = model+'.pth'
  if not os.path.isfile(models_dir+filename):
    !gdown {gid}
  if os.path.isfile(models_dir+filename):
    mdir = install_dir+'saved_models/'+fix_path(basename(filename))
    if not os.path.isdir(mdir): os.mkdir(mdir)
    shutil.copy(models_dir+filename, mdir+filename)
  else:
    op(c.fail, 'Failed', filename)

if local_models_dir != '':
  models_dir = drive_root+fix_path(local_models_dir)
  if not os.path.isdir(models_dir): os.mkdir(models_dir)
  %cd {models_dir}
  prep_model('u2net', '1ao1ovG1Qtx4b7EoskHXmi2E9rp5CHLcZ')
  prep_model('u2netp', '1rbSTGKAE-MTxBYHd-51l2hMOQPT_7EPy')
  prep_model('u2net_human_seg', '1-Yg0cxgrNhHP-016FPdp902BR-kSsA4P')
  prep_model('u2net_portrait', '1IG3HdpcRiDoWNookbncQjeaPN28t90yW')
  %cd {install_dir}
else:
  # u2net
  d = install_dir+'saved_models/u2net'
  os.mkdir(d)
  %cd {d}
  !gdown 1ao1ovG1Qtx4b7EoskHXmi2E9rp5CHLcZ
  # u2netp
  d = install_dir+'saved_models/u2netp'
  os.mkdir(d)
  %cd {d}
  !gdown 1rbSTGKAE-MTxBYHd-51l2hMOQPT_7EPy
  # human seg
  d = install_dir+'saved_models/u2net_human_seg'
  os.mkdir(d)
  %cd {d}
  !gdown 1-Yg0cxgrNhHP-016FPdp902BR-kSsA4P
  # portrait
  d = install_dir+'saved_models/u2net_portrait'
  os.mkdir(d)
  %cd {d}
  !gdown 1IG3HdpcRiDoWNookbncQjeaPN28t90yW

import os
from skimage import io, transform
import torch
import torchvision
from torch.autograd import Variable
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms#, utils
# import torch.optim as optim

import numpy as np
from PIL import Image, ImageEnhance
from glob import glob

from data_loader import RescaleT
from data_loader import ToTensor
from data_loader import ToTensorLab
from data_loader import SalObjDataset

from model import U2NET # full size version 173.6 MB
from model import U2NETP # small version u2net 4.7 MB

# normalize the predicted SOD probability map
def normPRED(d):
  """Min-max normalize a prediction tensor into the range [0, 1].

  Args:
    d: torch tensor holding the raw prediction.

  Returns:
    A tensor of the same shape, `(d - min) / (max - min)`. A constant input
    (max == min) returns all zeros instead of the NaNs a division by zero
    would produce.
  """
  ma = torch.max(d)
  mi = torch.min(d)
  span = ma - mi
  if span == 0:
    # Flat map: there is nothing to stretch, and 0/0 would poison the mask.
    return torch.zeros_like(d)
  return (d - mi) / span

def save_output(og_img, pred, output_path, contrast=1, save_mask=False):
  """Apply a predicted mask as the alpha channel of an image and save it as PNG.

  Args:
    og_img: path of the original image.
    pred: prediction tensor; it is squeezed and multiplied by 255, so values
      are expected in [0, 1] (e.g. the output of normPRED).
    output_path: destination path of the resulting PNG.
    contrast: contrast factor applied to the mask; 1 leaves it untouched.
    save_mask: when True, also write the mask next to the output with a
      `_mask` suffix before the file extension.
  """
  predict_np = pred.squeeze().detach().cpu().numpy()
  im = Image.fromarray(predict_np*255).convert('L')
  # Open the original once: its size drives the mask resize and the same
  # image receives the alpha channel. The context manager closes the file.
  with Image.open(og_img) as im2:
    mask = im.resize(im2.size, resample=Image.BILINEAR)
    if contrast != 1:
      mask = ImageEnhance.Contrast(mask).enhance(contrast)
    if save_mask is True:
      # Only touch the extension; str.replace would also rewrite a '.png'
      # occurring elsewhere in the path.
      root, ext = os.path.splitext(output_path)
      mask.save(root+'_mask'+ext, 'PNG')
    im2.putalpha(mask)
    im2.save(output_path, 'PNG')

# Finish setup inside the repository directory (the model download steps
# above change the working directory).
%cd {install_dir}

# Clear this cell's accumulated output and print a completion message.
# NOTE(review): `op` and `c` are not defined in this cell; presumably they
# come from the inhagcutils star import — confirm.
output.clear()
op(c.ok, 'Setup finished.')

In [None]:
#@title # Run
include_subdirs = False #@ param {type: "boolean"}
input = "" #@param {type:"string"}
output_dir = "" #@param {type:"string"}

filename_detail = "" # @param {type:"string"}

scalers = "0.5" #@param {type:"string"}
model = "u2net" #@param ["u2net", "u2netp", "portrait", "human"]
disconnect_runtime_when_done = False #@param {type: "boolean"}

# Runs background removal on `input` (file, directory or glob pattern, relative
# to drive_root) once per value in `scalers`, writing RGBA PNGs to `output_dir`
# (or next to the inputs when it is empty).
# NOTE: the form field `input` shadows the builtin of the same name; it is kept
# because the name is part of the notebook's form/UI.

import shutil  # explicit: no shutil import is visible in the Setup cell

contrast = 1
save_mask = False

# Form choice -> checkpoint name as stored by the Setup cell
# (saved_models/<name>/<name>.pth). "portrait" and "human" previously pointed
# at non-existent paths and left `net` undefined.
model_files = {
  'u2net': 'u2net',
  'u2netp': 'u2netp',
  'portrait': 'u2net_portrait',
  'human': 'u2net_human_seg',
}
model_id = model_files.get(model, model)
model_path = install_dir+'saved_models/'+model_id+'/'+model_id+'.pth'
model_name = basename(model_path)
uniq_id = gen_id()
trunc = 40        # max characters of the source file name kept in the output name
est_per_img = 10  # rough seconds per image, only used for the time estimate

copy_files = False
# Skip files whose name marks them as outputs of a previous run.
exclude_images_containing = ['u2net', 'isnet']

# Resolve the input into a list of image paths.
if os.path.isfile(drive_root+input):
  inputs = [drive_root+input]
  dir_in = path_dir(drive_root+input)
  copy_files = True
elif input != '' and os.path.isdir(drive_root+input):
  dir_in = drive_root+fix_path(input)
  inputs = list_images(dir_in, exclude_pattern=exclude_images_containing)
  if include_subdirs is True:
    dirlist = glob(dir_in+'/*')
    for item in dirlist:
      if os.path.isdir(item):
        inputs.extend(list_images(item, exclude_pattern=exclude_images_containing))
elif '*' in input:
  # Glob pattern. (The former `os.path.isdir(...) and '*' in input` test could
  # never be true, so patterns always ended in the failure branch.)
  dir_in = path_dir(drive_root+input)
  inputs = [p for p in glob(drive_root+input) if os.path.isfile(p)]
  copy_files = True
else:
  op(c.fail, 'FAIL!', 'Input should be a path to a file or a directory.')
  sys.exit('Input not understood.')

# Output
if output_dir == '':
  dir_out = dir_in
else:
  # makedirs: supports nested output paths and existing directories.
  os.makedirs(drive_root+output_dir, exist_ok=True)
  dir_out = drive_root+fix_path(output_dir)

timer_start = time.time()

if copy_files == True:
  # Empty the scratch dir, then process copies of the selected files.
  # glob(dir_tmp) returned the directory itself, which os.remove cannot delete.
  for stale in glob(dir_tmp+'*'):
    if os.path.isfile(stale):
      os.remove(stale)
  for src in inputs:
    shutil.copy(src, dir_tmp)
  inputs = list_images(dir_tmp)
  dir_in = dir_tmp

# Ignore empty entries such as the one produced by a trailing comma.
scales = [float(part) for part in (raw.strip() for raw in scalers.split(',')) if part]

# -- DO THINGS --

if not os.path.isfile(model_path):
  op(c.fail, 'FAIL!', 'Model file not found: '+model_path)
  sys.exit('Model not found.')

# Only u2netp uses the small architecture; the other checkpoints are full U2NET.
net = U2NETP(3,1) if model_id == 'u2netp' else U2NET(3,1)

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
using = 'GPU' if device.type == 'cuda' else 'CPU'
net.load_state_dict(torch.load(model_path, map_location=device))
net.to(device)
net.eval()

total = len(inputs) * len(scales)
count = 1
est_time = est_per_img * total

op(c.title, 'RUN ID:', uniq_id, time=True)
op(c.okb, 'Using '+using+' to process '+str(total)+' images out of '+str(len(inputs))+' images', time=True)
print()

# Prefix added to every output file name when filename_detail is set.
fd = filename_detail+'__' if filename_detail != '' else ''

# `img_path` (not `input`) so the loop does not overwrite the form field.
for img_path in inputs:
  # NOTE(review): with include_subdirs this keeps writing to the root input
  # dir, not to the image's own subdirectory — confirm the intended target.
  if output_dir == '' and include_subdirs is True:
    dir_out = dir_in
  with Image.open(img_path) as inp:
    w, h = inp.size
  for scale in scales:
    ndx_info = str(count)+'/'+str(total)+' '
    op(c.title, ndx_info+'Processing', path_leaf(img_path), time=True)
    # The network input is rescaled to this size; never let it reach 0.
    max_side = max(1, int(max(w, h)*scale))
    salobj_dataset = SalObjDataset(img_name_list=[img_path], lbl_name_list=[], transform=transforms.Compose([RescaleT(max_side), ToTensorLab(flag=0)]) )
    salobj_dataloader = DataLoader(salobj_dataset, batch_size=1, shuffle=False, num_workers=1)
    for data in salobj_dataloader:
      file_out = dir_out+fd+slug(basename(img_path)[:trunc])+'x_'+uniq_id+'_'+model_name+'_'+str(scale).replace('.','')+'.png'
      inputs_test = data['image'].type(torch.FloatTensor).to(device)
      # Inference only: skip autograd bookkeeping to save memory and time.
      with torch.no_grad():
        d1,d2,d3,d4,d5,d6,d7 = net(inputs_test)
      # d1 is the output used as the mask; take its single channel and normalize.
      pred = normPRED(d1[:,0,:,:])
      save_output(img_path, pred, file_out, contrast, save_mask)
      if os.path.isfile(file_out):
        op(c.ok, 'Saved as', path_leaf(dir_out)+'/'+path_leaf(file_out), time=True)
      else:
        op(c.fail, 'ERROR saving', file_out, time=True)
      del d1,d2,d3,d4,d5,d6,d7,pred
      timer_pc = time.time()
      elapsed = timer_pc-timer_start
      est_remaining = est_time-elapsed
    del salobj_dataset, salobj_dataloader
    count += 1
    print()

# -- END THINGS --

timer_end = time.time()

print()
op(c.okb, 'Elapsed', timedelta(seconds=timer_end-timer_start), time=True)
op(c.ok, 'FIN.')

if disconnect_runtime_when_done is True: end_session()