# Colab-deep-watermark

Original repo: [vinthony/deep-blind-watermark-removal/](https://github.com/vinthony/deep-blind-watermark-removal/)

Original colab: [here](https://colab.research.google.com/drive/1pYY7byBjM-7aFIWk8HcF9nK_s6pqGwww?usp=sharing)

My fork: [styler00dollar/Colab-deep-watermark](https://github.com/styler00dollar/Colab-deep-watermark)

A more user-friendly version of the official Colab notebook.

In [None]:
# Show the GPU assigned to this runtime; the inference and training cells below move the model and tensors to CUDA.
!nvidia-smi

In [None]:
#@title install
# download the necessary componments
! rm -rf *
! git clone https://github.com/vinthony/deep-blind-watermark-removal.git # get code from github
! gdown https://drive.google.com/uc?id=1KpSJ6385CHN6WlAINqB3CYrJdleQTJBc # get pretrained model
#! gdown https://drive.google.com/uc?id=18HaWfYYZCD34VttSjd2at8b9BKdhgVgU && unzip -q val.zip # get validation dataset (2.31G) of 27kpng
#! gdown https://drive.google.com/uc?id=1it5oQDRqRzBVieX6jKNmOxj1992f63yM && unzip -q natural.zip # get natural images (0.4G) of 27kpng

# rename natural images
from os import listdir
from os.path import isfile, join
import shutil
#filenames = [ shutil.copy(join('./natural', f), join('./natural', f).split('-')[0]+'.jpg') for f in listdir('./natural') if isfile(join('./natural', f)) ]

Upload your input image to `/content/image.jpg`. Re-run the following two cells whenever you change the image.

In [None]:
#@title remove folders and copy file
!sudo rm -rf /content/val_images/image/
!mkdir /content/val_images/image
!sudo rm -rf /content/val_images/mask
!mkdir /content/val_images/mask
!sudo rm -rf /content/val_images/wm
!mkdir /content/val_images/wm
!sudo rm -rf /content/natural
!mkdir /content/natural

!sudo rm -rf "/content/natural/"
!mkdir "/content/natural/"
!sudo rm -rf "/content/val_images/"
!mkdir "/content/val_images/"
!sudo rm -rf "/content/val_images/image/"
!mkdir "/content/val_images/image/"
!sudo rm -rf "/content/val_images/mask/"
!mkdir "/content/val_images/mask/"
!sudo rm -rf "/content/val_images/wm/"
!mkdir "/content/val_images/wm/"

!cp "/content/image.jpg" "/content/natural/image.jpg.jpg"
!cp "/content/image.jpg" "/content/val_images/image/image.jpg"
!cp "/content/image.jpg" "/content/val_images/mask/image.jpg"
!cp "/content/image.jpg" "/content/val_images/wm/image.jpg"

In [None]:
#@title apply
import os, sys, torch,random
import numpy as np
from matplotlib import pyplot as plt
from matplotlib import cm

sys.path.append('deep-blind-watermark-removal')
sys.path.insert(0,'deep-blind-watermark-removal')

from scripts.utils.imutils import im_to_numpy
import scripts.models as models
import scripts.datasets as datasets
%matplotlib inline
from PIL import Image, ImageChops

def get_jet():
    """Return matplotlib's 'jet' colormap as a (256, 3) uint8 RGB lookup table.

    Row k holds the colour for intensity index k; the alpha channel is dropped.
    """
    # Integer inputs index the colormap's lookup table directly, one RGBA row per index.
    rgba_rows = cm.jet(np.arange(256))
    rgb_scaled = np.round(rgba_rows[:, :3] * 255.0)
    return rgb_scaled.astype(np.uint8)

def clamp(num, min_value, max_value):
    """Limit ``num`` to the closed range [min_value, max_value].

    The upper bound is applied first, so if the bounds cross, ``min_value`` wins.
    """
    capped = max_value if max_value < num else num
    return min_value if min_value > capped else capped

def gray2color(gray_array, color_map):
    """Map an intensity array to RGB through a 256-entry colormap lookup table.

    Each value is made absolute, scaled by 10 so that small differences become visible,
    truncated to an integer and clipped into [0, 255] to index ``color_map``.

    Args:
        gray_array: array of intensities, typically a 2-D uint8 difference image.
        color_map: (256, 3) lookup table, e.g. the result of ``get_jet()``.

    Returns:
        uint8 array of shape ``gray_array.shape + (3,)``.
    """
    # Work in float: uint8 * 10 wraps around under NumPy 2 scalar promotion
    # (e.g. 26 * 10 -> 4), which picked the wrong colour in the old per-pixel loop.
    scaled = np.abs(np.asarray(gray_array, dtype=np.float64)) * 10
    # trunc equals int() for non-negative values; clip before casting so huge values cannot overflow.
    indices = np.clip(np.trunc(scaled), 0, 255).astype(np.intp)
    return np.asarray(color_map)[indices].astype(np.uint8)

class objectview(object):
    """Expose the entries of a mapping (built like ``dict(*args, **kwargs)``) as attributes."""

    def __init__(self, *args, **kwargs):
        self.__dict__.update(dict(*args, **kwargs))

# Lookup table used to colour the |prediction - target| difference maps.
jet_map = get_jet()

resume_path = '27kpng_model_best.pth.tar' # path of pretrained model (downloaded by the install cell's gdown)
# NOTE(review): `samples` is not used anywhere in this cell (leftover from the original Colab).
samples = [320,1364,1868] #random.sample(range(4000), 1) # show random sample 

# Loader options; with base_dir '.' and data '_images' the 'val' split presumably resolves to
# ./val_images (filled by the "remove folders and copy file" cell) — confirm against the repo's COCO.
data_config  = objectview({'input_size':256,
                            'limited_dataset':0,
                            'normalized_input':False,
                            'data_augumentation':False,
                            'base_dir':'.',
                            'data':'_images'})

val_loader = torch.utils.data.DataLoader(datasets.COCO('val',config=data_config))

# Legend for the 2x4 grid drawn below (top row / bottom row).
print('input          | target              | coarser            | final')
print('----------------------------------------------------------------------------')
print('predicted mask | predicted watermark | coarser difference | final difference')

with torch.no_grad():

      # Build the 'vvv4n' architecture on the GPU and load the pretrained weights.
      model = models.__dict__['vvv4n']().cuda()
      model.load_state_dict(torch.load(resume_path)['state_dict'])
      model.eval()
      
      for i, batches in enumerate(val_loader):
          
          # NOTE(review): one new figure per image and none are closed; fine for a single image.
          plt.figure(figsize=(48,12))

          im,mask,target = batches['image'].cuda(),batches['mask'].cuda(),batches['target'].cuda()
              
          # The model returns (list of image outputs, predicted mask, predicted watermark).
          imoutput,immask,imwatermark = model(im)
        
          # Composite: use the model output inside the predicted mask and keep the input elsewhere.
          # imoutput[1] is shown as "coarser" and imoutput[0] as "final".
          imcoarser,imrefine,imwatermark = imoutput[1]*immask + im*(1-immask),imoutput[0]*immask + im*(1-immask),imwatermark*immask

          # Top row: input | target | coarser | final, concatenated along width (dim=3), as uint8 HxWx3.
          ims1 = im_to_numpy(torch.clamp(torch.cat([im,target,imcoarser,imrefine],dim=3)[0]*255,min=0.0,max=255.0)).astype(np.uint8)
          
          imcoarser, imrefine, target  = im_to_numpy((imcoarser[0]*255)).astype(np.uint8), im_to_numpy((imrefine[0]*255)).astype(np.uint8), im_to_numpy((target[0]*255)).astype(np.uint8)
          # The single-channel mask is repeated to 3 channels so it can sit in the RGB grid.
          immask, imwatermark = im_to_numpy((immask.repeat(1,3,1,1)[0]*255)).astype(np.uint8),im_to_numpy((imwatermark[0]*255)).astype(np.uint8)

          # Per-pixel |output - target| as grayscale, coloured with the jet lookup table.
          coarsenp = gray2color(np.array(ImageChops.difference(Image.fromarray(imcoarser),Image.fromarray(target)).convert('L')),jet_map)
          finenp = gray2color(np.array(ImageChops.difference(Image.fromarray(imrefine),Image.fromarray(target)).convert('L')),jet_map)
          
          # Bottom row: mask | watermark | coarser diff | final diff, stacked under the top row.
          imfinal = np.concatenate([ims1,np.concatenate([immask,imwatermark,coarsenp,finenp],axis=1)],axis=0)

          plt.imshow(imfinal,vmin=0.0,vmax=255.0)


# Training

In [None]:
#@title install
!git clone https://github.com/vinthony/deep-blind-watermark-removal
!pip install progress
!pip install tensorboardX
!pip install pytorch_ssim

In [None]:
#@title main.py (forcing coco)
%%writefile /content/deep-blind-watermark-removal/main.py
from __future__ import print_function, absolute_import

import argparse
import torch,time,os

torch.backends.cudnn.benchmark = True

from scripts.utils.misc import save_checkpoint, adjust_learning_rate

import scripts.datasets as datasets
import scripts.machines as machines
from options import Options

def main(args):
    """Build the train/val loaders, create the training machine and run the epoch loop.

    The dataset class is forced to ``datasets.COCO``. Upstream chose SR / BIH / COCO from
    ``args.base_dir``; that selection is disabled here.

    Args:
        args: parsed command-line namespace produced by ``Options``.
    """
    # forcing coco
    dataset_func = datasets.COCO

    train_loader = torch.utils.data.DataLoader(dataset_func('train',args),batch_size=args.train_batch, shuffle=True,
        num_workers=args.workers, pin_memory=True)
    
    val_loader = torch.utils.data.DataLoader(dataset_func('val',args),batch_size=args.test_batch, shuffle=False,
        num_workers=args.workers, pin_memory=True)

    lr = args.lr
    data_loaders = (train_loader,val_loader)

    Machine = machines.__dict__[args.machine](datasets=data_loaders, args=args)
    print('============================ Initization Finish && Training Start =============================================')

    for epoch in range(Machine.args.start_epoch, Machine.args.epochs):

        # Adjust first so the printed LR is the same value that is recorded and used this epoch.
        lr = adjust_learning_rate(data_loaders, Machine.optimizer, epoch, lr, args)
        print('\nEpoch: %d | LR: %.8f' % (epoch + 1, lr))

        Machine.record('lr',lr, epoch)        
        Machine.train(epoch)

        # freq < 0: validate/checkpoint once per epoch (otherwise the machine does it every `freq` iterations).
        if args.freq < 0:
            Machine.validate(epoch)
            Machine.flush()
            Machine.save_checkpoint()

def _arg_str(value):
    """Render an argument value for the banner: sequences comma-joined, everything else via str()."""
    if isinstance(value, (list, tuple)):
        return ','.join(str(i) for i in value)
    # str() first: '{:<}'.format(None) raises TypeError and bools would otherwise print as 1/0.
    return str(value)


def _is_default(value, default):
    """True when a parsed value equals its parser default (lists compared by their joined text)."""
    if isinstance(value, list):
        # Also safe when the default is None (previously `join` over None crashed).
        return _arg_str(value) == _arg_str(default)
    return value == default


def _print_args(args, parser, changed):
    """Print the arguments that differ from (changed=True) or match (changed=False) their defaults."""
    for arg in vars(args):
        value, default = getattr(args, arg), parser.get_default(arg)
        if _is_default(value, default) != changed:
            print('==> {:50}: {:<}({:<})'.format(arg, _arg_str(value), _arg_str(default)))


if __name__ == '__main__':
    parser=Options().init(argparse.ArgumentParser(description='WaterMark Removal'))
    args = parser.parse_args()
    print('==================================== WaterMark Removal =============================================')
    print('==> {:50}: {:<}'.format("Start Time",time.ctime(time.time())))
    # .get: do not crash with KeyError when the variable is not exported.
    print('==> {:50}: {:<}'.format("USE GPU",os.environ.get('CUDA_VISIBLE_DEVICES', 'not set')))
    print('==================================== Stable Parameters =============================================')
    _print_args(args, parser, changed=False)
    print('==================================== Changed Parameters =============================================')
    _print_args(args, parser, changed=True)
    print('==================================== Start Init Model  ===============================================')
    main(args)
    print('==================================== FINISH WITHOUT ERROR =============================================')

In [None]:
#@title COCO.py (setting my_path)
%%writefile /content/deep-blind-watermark-removal/scripts/datasets/COCO.py
from __future__ import print_function, absolute_import

import os
import csv
import numpy as np
import json
import random
import math
import matplotlib.pyplot as plt
from collections import namedtuple
from os import listdir
from os.path import isfile, join

import torch
import torch.utils.data as data

from scripts.utils.osutils import *
from scripts.utils.imutils import *
from scripts.utils.transforms import *
import torchvision.transforms as transforms
from PIL import Image
from PIL import ImageEnhance
from PIL import ImageFilter
from PIL import ImageFile
ImageFile.LOAD_TRUNCATED_IMAGES = True

class COCO(data.Dataset):
    """Paired watermark-removal dataset: watermarked image, mask, watermark and clean target.

    Notebook patch: files are always read from ``/content/data/{image,mask,wm,natural}/`` and
    paired by identical file name. ``config.base_dir`` / ``config.data`` do not change the
    folder; ``train + config.data`` only sets ``self.dataset`` (train/val flag and log line).

    Args:
        train: split prefix, e.g. ``'train'`` or ``'val'``.
        config: options object; reads ``input_size``, ``normalized_input``, ``base_dir``,
            ``data``, ``data_augumentation`` and ``limited_dataset``.
        sample: optional list of indices (into the sorted file list) to keep.
        gan_norm: stored on the instance; not used inside this class.

    Raises:
        ValueError: if ``config`` is None.
    """
    def __init__(self,train,config=None, sample=None,gan_norm=False):
        if config is None:
            # Every setting below comes from config; fail clearly instead of with an AttributeError.
            raise ValueError('COCO dataset requires a config object')

        self.train = []
        self.anno = []
        self.mask = []
        self.wm = []
        self.input_size = config.input_size
        self.normalized_input = config.normalized_input
        self.base_folder = config.base_dir
        self.dataset = train+config.data
        self.data_augumentation = config.data_augumentation

        self.istrain = self.dataset.find('train') != -1
        # Copy so the default/caller list is never shared with the instance.
        self.sample = list(sample) if sample is not None else []
        self.gan_norm = gan_norm
        #mypath = join(self.base_folder,self.dataset)
        mypath = '/content/data/'
        file_names = sorted([f for f in listdir(join(mypath,'image')) if isfile(join(mypath,'image', f)) ])

        if config.limited_dataset > 0:
            # Keep one file per identifier (the text before the first '-'). file_names is sorted,
            # so setdefault keeps the first file of each identifier. Exact match on the
            # identifier: a substring test could pick another identifier's file.
            first_by_id = {}
            for name in file_names:
                first_by_id.setdefault(name.split('-')[0], name)
            file_names = [first_by_id[key] for key in sorted(first_by_id)]

        for file_name in file_names:
            self.train.append(os.path.join(mypath,'image',file_name)) # watermarked
            self.mask.append(os.path.join(mypath,'mask',file_name))
            self.wm.append(os.path.join(mypath,'wm',file_name))
            #self.anno.append(os.path.join(self.base_folder,'natural',file_name.split('-')[0]+'.jpg'))
            self.anno.append(os.path.join(mypath,'natural',file_name)) # not watermarked

        if self.sample:
            # Subsample every per-index list (wm included) so index i stays aligned across them.
            self.train = [ self.train[i] for i in self.sample ]
            self.mask = [ self.mask[i] for i in self.sample ]
            self.wm = [ self.wm[i] for i in self.sample ]
            self.anno = [ self.anno[i] for i in self.sample ]

        self.trans = transforms.Compose([
                transforms.Resize((self.input_size,self.input_size)),
                transforms.ToTensor()
            ])

        print('total Dataset of '+self.dataset+' is : ', len(self.train))


    def __getitem__(self, index):
        """Load sample ``index``: resized tensors for image/target/mask/wm plus the source paths."""
        img = Image.open(self.train[index]).convert('RGB')
        mask = Image.open(self.mask[index]).convert('L')
        anno = Image.open(self.anno[index]).convert('RGB')
        wm = Image.open(self.wm[index]).convert('RGB')


        return {"image": self.trans(img),
                "target": self.trans(anno), 
                "mask": self.trans(mask), 
                "wm": self.trans(wm),
                "name": self.train[index].split('/')[-1],
                "imgurl":self.train[index],
                "maskurl":self.mask[index],
                "targeturl":self.anno[index],
                "wmurl":self.wm[index]
                }

    def __len__(self):
        """Number of samples (after optional limiting/subsampling)."""
        return len(self.train)

In [None]:
#@title VX.py (deleting ssim, png eval)
%%writefile /content/deep-blind-watermark-removal/scripts/machines/VX.py
import torch
import torch.nn as nn
from progress.bar import Bar
from tqdm import tqdm
import pytorch_ssim
import json
import sys,time,os
import torchvision
from math import log10
import numpy as np
from .BasicMachine import BasicMachine
from scripts.utils.evaluation import accuracy, AverageMeter, final_preds
from scripts.utils.misc import resize_to_match
from torch.autograd import Variable
import torch.nn.functional as F
from scripts.utils.parallel import DataParallelModel, DataParallelCriterion
from scripts.utils.losses import VGGLoss, l1_relative,is_dic
from scripts.utils.imutils import im_to_numpy
import skimage.io
try:
    # Current scikit-image exposes the metrics in skimage.metrics; the
    # skimage.measure.compare_* aliases were removed in newer releases.
    from skimage.metrics import peak_signal_noise_ratio as compare_psnr
    from skimage.metrics import structural_similarity as compare_ssim
except ImportError:  # old scikit-image without skimage.metrics
    from skimage.measure import compare_psnr,compare_ssim


class Losses(nn.Module):
    """Composite training objective: image reconstruction, mask, watermark and optional VGG terms.

    ``forward`` returns ``(pixel_loss, att_loss, wm_loss, vgg_loss, ssim_loss)``. The SSIM term
    is disabled in this notebook, so ``ssim_loss`` is always returned as the scalar 0.

    Args:
        argx: parsed options; reads ``loss_type``, ``style_loss``, ``sltype``, ``ssim_loss``
            and ``masked``.
        device: device the VGG / SSIM helper modules are moved to.
        norm_func: callable applied to images before mixing (``self.norm``).
        denorm_func: inverse of ``norm_func`` (``self.denorm``).
    """
    def __init__(self, argx, device, norm_func=None, denorm_func=None):
        super(Losses, self).__init__()
        self.args = argx

        # (image loss, mask/attention loss, watermark loss) per --loss-type.
        if self.args.loss_type == 'l1bl2':
            self.outputLoss, self.attLoss, self.wrloss = nn.L1Loss(), nn.BCELoss(), nn.MSELoss()
        elif self.args.loss_type == 'l2xbl2':
            self.outputLoss, self.attLoss, self.wrloss = nn.MSELoss(), nn.BCEWithLogitsLoss(), nn.MSELoss()
        elif self.args.loss_type == 'relative' or self.args.loss_type == 'hybrid':
            self.outputLoss, self.attLoss, self.wrloss = l1_relative, nn.BCELoss(), l1_relative
        else: # l2bl2
            self.outputLoss, self.attLoss, self.wrloss = nn.MSELoss(), nn.BCELoss(), nn.MSELoss()

        self.default = nn.L1Loss()

        if self.args.style_loss > 0:
            self.vggloss = VGGLoss(self.args.sltype).to(device)
        
        if self.args.ssim_loss > 0:
            self.ssimloss =  pytorch_ssim.SSIM().to(device)
        
        self.norm = norm_func
        self.denorm = denorm_func


    def forward(self,pred_ims,target,pred_ms,mask,pred_wms,wm):
        """Compute the loss terms for one batch.

        ``pred_ims`` may be a single prediction or a collection of them (checked with
        ``is_dic``); the image terms are summed over all predictions.
        """
        pixel_loss,att_loss,wm_loss,vgg_loss,ssim_loss = [0]*5
        pred_ims = pred_ims if is_dic(pred_ims) else [pred_ims]

        # NOTE(review): VX.train already passes self.norm(target), and the branches below apply
        # self.norm(target) again — harmless only if norm is the identity; TODO confirm.

        # try the loss in the masked region
        if self.args.masked and 'hybrid' in self.args.loss_type: # masked loss
            pixel_loss += sum([self.outputLoss(pred_im, target, mask) for pred_im in pred_ims])
            pixel_loss += sum([self.default(pred_im*pred_ms,target*mask) for pred_im in pred_ims])
            recov_imgs = [ self.denorm(pred_im*mask + (1-mask)*self.norm(target)) for pred_im in pred_ims ]
            wm_loss += self.wrloss(pred_wms, wm, mask)
            wm_loss += self.default(pred_wms*pred_ms, wm*mask)

        elif self.args.masked and 'relative' in self.args.loss_type: # masked loss
            pixel_loss += sum([self.outputLoss(pred_im, target, mask) for pred_im in pred_ims])
            recov_imgs = [ self.denorm(pred_im*mask + (1-mask)*self.norm(target)) for pred_im in pred_ims ]
            wm_loss = self.wrloss(pred_wms, wm, mask)
        elif self.args.masked:
            pixel_loss += sum([self.outputLoss(pred_im*mask, target*mask) for pred_im in pred_ims])
            recov_imgs = [ self.denorm(pred_im*pred_ms + (1-pred_ms)*self.norm(target)) for pred_im in pred_ims ]
            wm_loss = self.wrloss(pred_wms*mask, wm*mask)
        else:
            pixel_loss += sum([self.outputLoss(pred_im*pred_ms, target*mask) for pred_im in pred_ims])
            recov_imgs = [ self.denorm(pred_im*pred_ms + (1-pred_ms)*self.norm(target)) for pred_im in pred_ims ]
            wm_loss = self.wrloss(pred_wms*pred_ms,wm*mask)

        # L1 between each composited reconstruction and the target.
        pixel_loss += sum([self.default(im,target) for im in recov_imgs])

        if self.args.style_loss > 0:
            vgg_loss = sum([self.vggloss(im,target,mask) for im in recov_imgs])

        # SSIM term is disabled in this notebook: ssim_loss keeps its scalar 0 initial value
        # (a list here would break `self.args.ssim_loss * ssim_loss` if the term were re-enabled).

        att_loss =  self.attLoss(pred_ms, mask)

        return pixel_loss,att_loss,wm_loss,vgg_loss,ssim_loss


class VX(BasicMachine):
    """Train / validate / test loop for the watermark-removal model using ``Losses``.

    Model, loaders, device, ``norm``/``denorm``, ``record``/``flush``/``save_checkpoint`` are
    presumably provided by ``BasicMachine`` (not shown here). The model manages its own
    optimizers (``set_optimizers`` / ``zero_grad_all`` / ``step_all``), so ``self.optimizer`` is None.
    """
    def __init__(self,**kwargs):
        BasicMachine.__init__(self,**kwargs)
        self.loss = Losses(self.args, self.device, self.norm, self.denorm)
        self.model.set_optimizers()
        self.optimizer = None

    @staticmethod
    def _psnr(mse):
        """PSNR in dB with peak value 1 given a mean squared error.

        Returns ``inf`` for a perfect reconstruction instead of raising ZeroDivisionError.
        """
        if mse <= 0:
            return float('inf')
        return 10 * log10(1 / mse)

    @staticmethod
    def _ssim_rgb(target, pred):
        """SSIM of two HxWx3 images across scikit-image versions.

        New releases take ``channel_axis`` (``multichannel`` was removed). Old ones only
        know ``multichannel`` and silently swallow unknown kwargs, hence the signature check.
        """
        import inspect
        if 'channel_axis' in inspect.signature(compare_ssim).parameters:
            return compare_ssim(target, pred, channel_axis=-1)
        return compare_ssim(target, pred, multichannel=True)

    def train(self,epoch):
        """Run one training epoch; validates and checkpoints every ``args.freq`` iterations when freq > 0."""
        self.current_epoch = epoch

        batch_time = AverageMeter()
        data_time = AverageMeter()
        losses = AverageMeter()
        lossMask = AverageMeter()
        lossWM = AverageMeter()
        lossMX = AverageMeter()
        lossvgg = AverageMeter()
        lossssim = AverageMeter()

        # switch to train mode
        self.model.train()

        end = time.time()
        bar = Bar('Processing {} '.format(self.args.arch), max=len(self.train_loader))

        for i, batches in enumerate(self.train_loader):

            current_index = len(self.train_loader) * epoch + i

            inputs = batches['image'].to(self.device)
            target = batches['target'].to(self.device)
            mask = batches['mask'].to(self.device)
            wm =  batches['wm'].to(self.device)

            outputs = self.model(self.norm(inputs))
            
            self.model.zero_grad_all()

            l2_loss,att_loss,wm_loss,style_loss,ssim_loss = self.loss(outputs[0],self.norm(target),outputs[1],mask,outputs[2],self.norm(wm))
            # The SSIM term is disabled in this notebook, so it is not part of the total.
            total_loss = 2*l2_loss + self.args.att_loss * att_loss + wm_loss + self.args.style_loss * style_loss

            # compute gradient and do SGD step
            total_loss.backward()
            self.model.step_all()

            # measure accuracy and record loss
            losses.update(l2_loss.item(), inputs.size(0))
            lossMask.update(att_loss.item(), inputs.size(0))
            lossWM.update(wm_loss.item(), inputs.size(0))

            if self.args.style_loss > 0 :
                lossvgg.update(style_loss.item(), inputs.size(0))

            # measure elapsed time
            batch_time.update(time.time() - end)
            end = time.time()

            # plot progress (the line is only built when it is actually printed)
            if current_index % 1000 == 0:
                suffix  = "({batch}/{size}) Data: {data:.2f}s | Batch: {bt:.3f}s | Total: {total:} | ETA: {eta:} | Loss L2: {loss_label:.4f} | Loss Mask: {loss_mask:.4f} | loss WM: {loss_wm:.4f} | loss VGG: {loss_vgg:.4f} | loss SSIM: {loss_ssim:.4f}| loss MX: {loss_mx:.4f}".format(
                            batch=i + 1,
                            size=len(self.train_loader),
                            data=data_time.val,
                            bt=batch_time.val,
                            total=bar.elapsed_td,
                            eta=bar.eta_td,
                            loss_label=losses.avg,
                            loss_mask=lossMask.avg,
                            loss_wm=lossWM.avg,
                            loss_vgg=lossvgg.avg,
                            loss_ssim=lossssim.avg,
                            loss_mx=lossMX.avg
                            )
                print(suffix)

            if self.args.freq > 0 and current_index % self.args.freq == 0:
                self.validate(current_index)
                self.flush()
                self.save_checkpoint()

        self.record('train/loss_L2', losses.avg, epoch)
        self.record('train/loss_Mask', lossMask.avg, epoch)
        self.record('train/loss_WM', lossWM.avg, epoch)
        self.record('train/loss_VGG', lossvgg.avg, epoch)
        self.record('train/loss_SSIM', lossssim.avg, epoch)
        self.record('train/loss_MX', lossMX.avg, epoch)

    def validate(self, epoch):
        """Evaluate PSNR on the validation loader; saves a sample grid every 300 batches.

        ``epoch`` is also used as the iteration index when called from ``train``.
        Sets ``self.metric`` to the mean PSNR.
        """
        self.current_epoch = epoch
        
        batch_time = AverageMeter()
        data_time = AverageMeter()
        losses = AverageMeter()
        lossMask = AverageMeter()
        psnres = AverageMeter()
        ssimes = AverageMeter()

        # switch to evaluate mode
        self.model.eval()

        end = time.time()
        bar = Bar('Processing {} '.format(self.args.arch), max=len(self.val_loader))
        with torch.no_grad():
            for i, batches in enumerate(self.val_loader):

                current_index = len(self.val_loader) * epoch + i

                inputs = batches['image'].to(self.device)
                target = batches['target'].to(self.device)

                outputs = self.model(self.norm(inputs))
                imoutput,immask,imwatermark = outputs
                imoutput = imoutput[0] if is_dic(imoutput) else imoutput

                # Model output inside the predicted mask, original input outside it.
                imfinal = self.denorm(imoutput*immask + self.norm(inputs)*(1-immask))

                if i % 300 == 0:
                    # save the sample images
                    ims = torch.cat([inputs,target,imfinal,immask.repeat(1,3,1,1)],dim=3)
                    torchvision.utils.save_image(ims,os.path.join(self.args.checkpoint,'%s_%s.png'%(i,epoch)))

                psnr = self._psnr(F.mse_loss(imfinal,target).item())

                psnres.update(psnr, inputs.size(0))

                # measure elapsed time
                batch_time.update(time.time() - end)
                end = time.time()

                # plot progress
                bar.suffix  = '({batch}/{size}) Data: {data:.2f}s | Batch: {bt:.3f}s | Total: {total:} | ETA: {eta:} | Loss_L2: {loss_label:.4f} | Loss_Mask: {loss_mask:.4f} | PSNR: {psnr:.4f} | SSIM: {ssim:.4f}'.format(
                            batch=i + 1,
                            size=len(self.val_loader),
                            data=data_time.val,
                            bt=batch_time.val,
                            total=bar.elapsed_td,
                            eta=bar.eta_td,
                            loss_label=losses.avg,
                            loss_mask=lossMask.avg,
                            psnr=psnres.avg,
                            ssim=ssimes.avg
                            )
                bar.next()
        bar.finish()
        
        print("Iter:%s,Losses:%s,PSNR:%.4f,SSIM:%.4f"%(epoch, losses.avg,psnres.avg,ssimes.avg))
        self.record('val/loss_L2', losses.avg, epoch)
        self.record('val/lossMask', lossMask.avg, epoch)
        self.record('val/PSNR', psnres.avg, epoch)
        self.record('val/SSIM', ssimes.avg, epoch)
        self.metric = psnres.avg

        self.model.train()

    def test(self, ):
        """Save each restored image to ``args.checkpoint`` and print mean PSNR/SSIM.

        Metrics in parentheses are the tensor-space values (before uint8 quantisation).
        """
        # switch to evaluate mode
        self.model.eval()
        print("==> testing VM model ")
        ssimes = AverageMeter()
        psnres = AverageMeter()
        ssimesx = AverageMeter()
        psnresx = AverageMeter()

        with torch.no_grad():
            for i, batches in enumerate(tqdm(self.val_loader)):

                inputs = batches['image'].to(self.device)
                target = batches['target'].to(self.device)
                mask =batches['mask'].to(self.device)

                # select the outputs by the giving arch
                outputs = self.model(self.norm(inputs))
                imoutput,immask,imwatermark = outputs
                imoutput = imoutput[0] if is_dic(imoutput) else imoutput

                imfinal = self.denorm(imoutput*immask + self.norm(inputs)*(1-immask))
                psnrx = self._psnr(F.mse_loss(imfinal,target).item())
                # float(): keep a plain number in the meter instead of a GPU tensor.
                ssimx = float(pytorch_ssim.ssim(imfinal,target))
                # recover the image to 255
                imfinal = im_to_numpy(torch.clamp(imfinal[0]*255,min=0.0,max=255.0)).astype(np.uint8)
                target = im_to_numpy(torch.clamp(target[0]*255,min=0.0,max=255.0)).astype(np.uint8)

                skimage.io.imsave('%s/%s'%(self.args.checkpoint,batches['name'][0]), imfinal)

                psnr = compare_psnr(target,imfinal)
                ssim = self._ssim_rgb(target,imfinal)

                psnres.update(psnr, inputs.size(0))
                ssimes.update(ssim, inputs.size(0))
                psnresx.update(psnrx, inputs.size(0))
                ssimesx.update(ssimx, inputs.size(0))

        print("%s:PSNR:%.5f(%.5f),SSIM:%.5f(%.5f)"%(self.args.checkpoint,psnres.avg,psnresx.avg,ssimes.avg,ssimesx.avg))
        print("DONE.\n")

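Data structure (files are paired by identical file name across the four folders)
```
/content/data/image/000.png   (watermarked input image, RGBA32)
/content/data/mask/000.png    (mask of the watermark region; white = watermark, GRAY8BPP)
/content/data/natural/000.png (original image without the watermark, RGB/YUV)
/content/data/wm/000.png      (watermark, RGBA32)

/content/data/image_train.txt (list of file names in the /image folder)
/content/data/image_val.txt   (list of file names in the /natural folder)
```
Note: the patched `COCO.py` does not read these `.txt` lists; it scans `/content/data/image/` directly.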

In [None]:
#@title create folders
!mkdir /content/data
!mkdir /content/data/image
!mkdir /content/data/mask
!mkdir /content/data/wm
!mkdir /content/data/natural
!mkdir /content/eval

Create the file-name lists described above.

In [None]:
%%writefile /content/data/image_train.txt
/content/data/image/000.png

In [None]:
%%writefile /content/data/image_val.txt
/content/data/natural/000.png

In [None]:
#@title train
# CUDA_VISIBLE_DEVICES is set inline; main.py prints it in its startup banner.
# NOTE: the patched COCO.py reads every split from /content/data/ (hardcoded mypath), so
# --base-dir and --data do not change which files are loaded; with --limited-dataset 1 only one
# file per identifier (text before the first '-') is kept.
# -c presumably sets args.checkpoint (where VX saves sample PNGs and test outputs) — confirm in options.py.
!CUDA_VISIBLE_DEVICES=0 python /content/deep-blind-watermark-removal/main.py \
 --epochs 100\
 --schedule 100\
 --lr 1e-3\
 -c /content/eval/ \
 --arch vvv4n\
 --sltype vggx\
 --style-loss 0.025\
 --ssim-loss 0.15\
 --masked True\
 --loss-type hybrid\
 --limited-dataset 1\
 --machine vx\
 --input-size 256\
 --train-batch 1\
 --test-batch 1\
 --base-dir /content/ \
 --data /content/data/image