ARGAN : Attentive Recurrent Generative Adversarial Network
for Shadow Detection and Removal

# import lib & data

In [None]:
!nvidia-smi

Mon Apr 12 15:33:23 2021       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 460.67       Driver Version: 460.32.03    CUDA Version: 11.2     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|   0  Tesla K80           Off  | 00000000:00:04.0 Off |                    0 |
| N/A   36C    P8    29W / 149W |      0MiB / 11441MiB |      0%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+
                                                                               
+-----------------------------------------------------------------------------+
| Proces

In [None]:
# import lib
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
import os

import torchvision.transforms as transforms
import matplotlib.pyplot as plt
import numpy as np

from PIL import Image
from skimage import color

# Mount Google Drive (Colab only) so the dataset and checkpoints are reachable
# under /content/drive.
from google.colab import drive
drive.mount('/content/drive')

# Select the compute device: first CUDA GPU when available, otherwise the CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(device)

# Root folders of the train / test splits on Google Drive.
# make_dataset() looks for the sub-folders {train,test}_A / _B / _C below them.
#!ls '/content/drive/My Drive/KU/4/ISTD_Dataset/train/'
img_path = '/content/drive/My Drive/KU/4/ISTD_Dataset/train/'
test_path = '/content/drive/My Drive/KU/4/ISTD_Dataset/test/'

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).
cuda:0


Functions

In [None]:
# display image
def imshow(image):
  """Display a channel-first (C, H, W) image tensor with matplotlib.

  Args:
    image: 3-D tensor; it is transposed to (H, W, C) before plotting.
      May live on any device and may require grad.
  """
  # detach() drops the autograd graph; cpu() makes the conversion work for
  # tensors that live on the GPU (it is a no-op for CPU tensors).
  npimage = image.detach().cpu().numpy()
  plt.imshow(np.transpose(npimage, (1,2,0)))
  plt.show()

# Save output tensor as image
def save_batch(images, nrow, PATH):
  """Tile a batch of image tensors into one grid and save it as an image file.

  Args:
    images: batch of channel-first image tensors with values nominally in [0, 1].
    nrow: number of images per grid row (forwarded to make_grid).
    PATH: destination file name; the format is chosen by PIL from the extension.
  """
  img = torchvision.utils.make_grid(images, nrow=nrow)
  # cpu() lets callers pass GPU tensors directly (no-op for CPU tensors).
  img_out = np.transpose(img.detach().cpu().numpy().astype('float64'), (1,2,0))
  # Clip before the uint8 cast: values outside [0, 1] would otherwise wrap
  # around (e.g. 1.1 * 255 = 280 -> 24) and show up as speckles.
  img_out = np.clip(img_out, 0.0, 1.0)
  img_out = (255*img_out).astype('uint8')
  img_out = Image.fromarray(img_out)
  img_out.save(PATH)

def save_batch_LAB(images, nrow, PATH):
  """Tile a batch of CIELAB image tensors into a grid and save it as RGB.

  Args:
    images: batch of channel-first tensors holding LAB values.
    nrow: number of images per grid row (forwarded to make_grid).
    PATH: destination file name; the format is chosen by PIL from the extension.
  """
  img = torchvision.utils.make_grid(images, nrow=nrow)
  # cpu() lets callers pass GPU tensors directly (no-op for CPU tensors).
  img_out = np.transpose(img.detach().cpu().numpy().astype('float64'), (1,2,0))
  # Convert the whole grid back to RGB before quantising to 8 bit.
  img_out = color.lab2rgb(img_out)
  img_out = (255*img_out).astype('uint8')
  img_out = Image.fromarray(img_out)
  img_out.save(PATH)

Dataloader

In [None]:
# My dataset loading function
def make_dataset(root, test) -> list:
  """Collect (shadow, matte, shadow-free) file path triplets found under root.

  The three images of a sample share one file name and live in the sub-folders
  ``<split>_A`` (shadow image), ``<split>_B`` (shadow matte) and ``<split>_C``
  (shadow-free image), where ``<split>`` is ``test`` when *test* is True and
  ``train`` otherwise.

  Args:
    root: directory that contains the three sub-folders.
    test: True selects the ``test_*`` folders, anything else ``train_*``.

  Returns:
    List of ``(src_path, matt_path, free_path)`` tuples, ordered by file name.
    Samples with a missing matte or shadow-free file are reported on stdout
    and skipped.
  """
  prefix = 'test' if test is True else 'train'
  src_dir = prefix + '_A'
  matt_dir = prefix + '_B'
  free_dir = prefix + '_C'

  # Source names drive the iteration order; the other two folders are only
  # probed for membership, so sets give O(1) look-ups.
  src_fnames = sorted(os.listdir(os.path.join(root, src_dir)))
  matt_fnames = set(os.listdir(os.path.join(root, matt_dir)))
  free_fnames = set(os.listdir(os.path.join(root, free_dir)))

  dataset = []
  for src_fname in src_fnames:
    # The messages previously referenced undefined names (matt_fname /
    # free_fname) and raised NameError on the first incomplete sample.
    if src_fname not in matt_fnames:
      print(src_fname, 'Shadow matte file missing')
      continue
    if src_fname not in free_fnames:
      print(src_fname, 'Shadow free file missing')
      continue
    dataset.append((os.path.join(root, src_dir, src_fname),
                    os.path.join(root, matt_dir, src_fname),
                    os.path.join(root, free_dir, src_fname)))

  return dataset


class ARGAN_Dataset(torchvision.datasets.vision.VisionDataset):
  """Dataset of (shadow image, shadow matte, shadow-free image) triplets.

  File triplets are discovered with make_dataset(); each item is loaded with
  *loader* and optionally transformed.

  Args:
    root: dataset root passed to make_dataset().
    loader: callable that turns a file path into an image.
    is_test: selects the test split; when it is not False each item also
      carries the untransformed matte and shadow-free image as tensors.
    src_trans: transform applied to the shadow and shadow-free images.
    matt_trans: transform applied to the shadow matte.
  """
  def __init__(self, root, loader=torchvision.datasets.folder.default_loader,
               is_test=False, src_trans=None, matt_trans=None):
    super().__init__(root, transform=src_trans, target_transform=matt_trans)
    self.test = is_test
    self.loader = loader
    # list of (src_path, matt_path, free_path)
    self.samples = make_dataset(self.root, test=is_test)
    self.trans2tensor = transforms.ToTensor()

  def __getitem__(self, index):
    """Return the sample at *index*.

    Returns:
      ``(src, matte, free)`` when ``self.test`` is False, otherwise
      ``(src, matte, free, raw_matte, raw_free)`` where the last two are the
      loaded images converted with ToTensor only (no resize).
    """
    src_path, matt_path, free_path = self.samples[index]
    src_sample = self.loader(src_path)
    matt_sample = self.loader(matt_path)
    free_sample = self.loader(free_path)

    # Untransformed references are only returned in test mode, so they are
    # only computed there; this has to happen before the transforms below
    # rebind matt_sample / free_sample.
    if self.test is not False:
      matt = self.trans2tensor(matt_sample)
      free = self.trans2tensor(free_sample)

    if self.transform is not None:
      # transform for RGB images: shadow image and shadow-free image
      src_sample = self.transform(src_sample)
      free_sample = self.transform(free_sample)
    if self.target_transform is not None:
      # transform for the binary image: shadow matte
      matt_sample = self.target_transform(matt_sample)

    if self.test is False:
      return src_sample, matt_sample, free_sample
    return src_sample, matt_sample, free_sample, matt, free

  def __len__(self):
    """Number of complete triplets found by make_dataset()."""
    return len(self.samples)

Transforms

In [None]:
# Image transforms
# NOTE: the original author's remark says the paper trains on 128x128 images;
# this notebook resizes every input to 256x256 instead.
# RGB inputs (shadow / shadow-free): resize, then convert to a tensor.
img2tensor = transforms.Compose([
                                 transforms.Resize(size=(256,256)),
                                 transforms.ToTensor()
                                 # additional tasks
])
# Shadow matte: resize, collapse to a single channel, then convert to a tensor.
matt2tensor = transforms.Compose([
                                  transforms.Resize(size=(256,256)),
                                  transforms.Grayscale(1),
                                  transforms.ToTensor()
                                  # additional tasks
])

# Load images
batch_num = 4   # training batch size
dprow = 2       # images per row when batches are tiled into a grid

# Training split: shuffled, yields (image, matte, free).
train_img = ARGAN_Dataset(img_path, src_trans=img2tensor, matt_trans=matt2tensor, is_test=False)
trainloader = torch.utils.data.DataLoader(train_img, batch_size=batch_num, shuffle=True)

# Test split: fixed order, yields (image, matte, free, raw matte, raw free).
test_img = ARGAN_Dataset(test_path, src_trans=img2tensor, matt_trans=matt2tensor, is_test=True)
testloader = torch.utils.data.DataLoader(test_img, batch_size=4, shuffle=False)

# Check Dataset

Loaded Dataset

In [None]:
print(train_img)

# Pull a single batch from the training loader to sanity-check it.
dataiter = iter(trainloader)
print(type(dataiter))
# Use the built-in next(): DataLoader iterators no longer expose a .next()
# method in current PyTorch releases.
images, mattes, frees = next(dataiter)

print(images.shape)
print(mattes.shape)
print(frees.shape)

# Show shadow images, mattes and shadow-free targets as grids.
imshow(torchvision.utils.make_grid(images, nrow=dprow))
imshow(torchvision.utils.make_grid(mattes, nrow=dprow))
imshow(torchvision.utils.make_grid(frees, nrow=dprow))

#imshow(frees)

NameError: ignored

Test Dataset

In [None]:
print(test_img)

# Pull a single batch from the test loader to sanity-check it.
testiter = iter(testloader)

# The test dataset is built with is_test=True, so every batch carries five
# entries: the three transformed tensors plus the untransformed matte and
# shadow-free image. Unpacking only three raised ValueError.
images, mattes, frees, src_mattes, src_frees = next(testiter)

print(images.shape)
print(mattes.shape)
print(frees.shape)

# Show shadow images, mattes and shadow-free targets as grids.
imshow(torchvision.utils.make_grid(images, nrow=dprow))
imshow(torchvision.utils.make_grid(mattes, nrow=dprow))
imshow(torchvision.utils.make_grid(frees, nrow=dprow))

#imshow(frees)

NameError: ignored

# Generative Network

Convolutional LSTM

In [None]:
# LSTM layer
class ConvLSTM(nn.Module):
    """Single convolutional LSTM cell built from four 3x3 gate convolutions.

    Args:
        inp_dim: channels of the input feature map.
        hid_dim: channels of the hidden state that is concatenated to the input.
        out_dim: channels produced by every gate; defaults to ``inp_dim``.

    NOTE: the returned hidden state has ``out_dim`` channels while the gates
    expect ``hid_dim`` hidden channels, so feeding the state back in only
    works when ``out_dim == hid_dim``.
    """

    def __init__(self, inp_dim, hid_dim, out_dim=None):
        super().__init__()
        self.in_dim = inp_dim + hid_dim
        self.hidden_dim = hid_dim
        self.out_dim = inp_dim if out_dim is None else out_dim
        # Gate names and creation order are kept as-is: they define the
        # state_dict keys of saved checkpoints.
        for gate_name in ('conv_i', 'conv_f', 'conv_c', 'conv_o'):
            setattr(self, gate_name,
                    nn.Conv2d(self.in_dim, self.out_dim,
                              kernel_size=3, stride=1, padding=1))
        self.sig = nn.Sigmoid()
        self.tanh = nn.Tanh()

    def forward(self, x, c_prev, h_prev):
        """Advance the cell by one step.

        Args:
            x: input feature map.
            c_prev: previous cell state.
            h_prev: previous hidden state (concatenated to x along dim 1).

        Returns:
            Tuple ``(c, h)``: the new cell state and the new hidden state.
        """
        stacked = torch.cat([x, h_prev], dim=1)
        input_gate = self.sig(self.conv_i(stacked))
        forget_gate = self.sig(self.conv_f(stacked))
        candidate = self.tanh(self.conv_c(stacked))
        output_gate = self.sig(self.conv_o(stacked))
        c_next = (forget_gate * c_prev) + (input_gate * candidate)
        h_next = output_gate * self.tanh(c_next)
        return c_next, h_next

    def init_hidden(self, batch_size, image_size):
        """Return an all-zero hidden state of shape (batch, hid_dim, H, W)."""
        rows, cols = image_size
        return torch.zeros((batch_size, self.hidden_dim, rows, cols))

Generator

In [None]:
# Conv + BN + LReLU
class ConvL(nn.Module):
    """Conv2d + BatchNorm2d + LeakyReLU building block.

    Args:
        inp_ch: input channels.
        out_ch: output channels.
        k: kernel size. When omitted the layer is 3x3 / stride 1 / padding 1
            (spatial size preserved) and ``s`` / ``p`` are ignored.
        s: stride, defaults to 1 when ``k`` is given.
        p: padding, defaults to 0 when ``k`` is given.
    """

    def __init__(self, inp_ch, out_ch, k=None, s=None, p=None):
        super(ConvL, self).__init__()
        if k is None:
            # (3, 1, 1) default
            k, s, p = 3, 1, 1
        else:
            # Resolve stride and padding independently. The former
            # if/elif chain left p=None whenever s was missing too,
            # which made the convolution unusable.
            if s is None:
                s = 1
            if p is None:
                p = 0
        self.conv = nn.Sequential(
            nn.Conv2d(inp_ch, out_ch, kernel_size=k, stride=s, padding=p),
            nn.BatchNorm2d(out_ch),
            nn.LeakyReLU())

    def forward(self, x):
        """Apply convolution, batch norm and LeakyReLU to x."""
        return self.conv(x)

# ConvT + BN + LReLU
class DConvL(nn.Module):
    """ConvTranspose2d + BatchNorm2d + LeakyReLU building block.

    Args:
        inp_ch: input channels.
        out_ch: output channels.
        k: kernel size. When omitted the layer is 3x3 / stride 1 / padding 1
            (spatial size preserved) and ``s`` / ``p`` are ignored.
        s: stride, defaults to 1 when ``k`` is given.
        p: padding, defaults to 0 when ``k`` is given.
    """

    def __init__(self, inp_ch, out_ch, k=None, s=None, p=None):
        super(DConvL, self).__init__()
        if k is None:
            # (3, 1, 1) keeps the spatial resolution
            k, s, p = 3, 1, 1
        else:
            # Resolve stride and padding independently. The former
            # if/elif chain left p=None whenever s was missing too.
            if s is None:
                s = 1
            if p is None:
                p = 0
        self.dconv = nn.Sequential(
            nn.ConvTranspose2d(inp_ch, out_ch,
                               kernel_size=k, stride=s, padding=p),
            nn.BatchNorm2d(out_ch),
            nn.LeakyReLU())

    def forward(self, x):
        """Apply transposed convolution, batch norm and LeakyReLU to x."""
        return self.dconv(x)

# Attention Detector : 10 (Conv + BN + LReLU) layers
class AttDet(nn.Module):
    """Attention detector: a stack of ten ConvL layers (3 -> 64 channels).

    Every layer uses ConvL's default 3x3 / stride 1 / padding 1 setting.
    """

    def __init__(self):
        super().__init__()
        # Channel width before/after each of the ten layers.
        widths = (3, 8, 8, 16, 16, 16, 32, 32, 64, 64, 64)
        stages = [ConvL(c_in, c_out)
                  for c_in, c_out in zip(widths[:-1], widths[1:])]
        self.block = nn.Sequential(*stages)

    def forward(self, x):
        """Run x through the ten-layer stack and return the feature map."""
        return self.block(x)

# Removal Encoder : 8 Conv + 8 DConv +3 Conv Layers
class REncoder(nn.Module):
    """Removal encoder: 8 strided ConvL, 8 DConvL with skips, 3 refinement convs.

    The network output is multiplied by the attention map and added to the
    input image.
    """

    def __init__(self):
        super().__init__()
        # Encoder: eight stride-2 3x3 convolutions.
        self.conv0 = ConvL(3, 64, k=3, s=2, p=3)
        self.conv1 = ConvL(64, 128, k=3, s=2, p=2)
        self.conv2 = ConvL(128, 256, k=3, s=2, p=2)
        self.conv3 = ConvL(256, 512, k=3, s=2, p=2)
        self.conv4 = ConvL(512, 512, k=3, s=2, p=2)
        self.conv5 = ConvL(512, 512, k=3, s=2, p=2)
        self.conv6 = ConvL(512, 512, k=3, s=2, p=2)
        self.conv7 = ConvL(512, 512, k=3, s=2, p=2)

        # Decoder: eight stride-2 4x4 transposed convolutions.
        self.dconv0 = DConvL(512, 512, k=4, s=2, p=2)
        self.dconv1 = DConvL(512, 512, k=4, s=2, p=2)
        self.dconv2 = DConvL(512, 512, k=4, s=2, p=2)
        self.dconv3 = DConvL(512, 512, k=4, s=2, p=2)
        self.dconv4 = DConvL(512, 256, k=4, s=2, p=2)
        self.dconv5 = DConvL(256, 128, k=4, s=2, p=2)
        self.dconv6 = DConvL(128, 64, k=4, s=2, p=2)
        self.dconv7 = DConvL(64, 3, k=4, s=2, p=3)

        # Refinement head that turns the decoder output into the residual.
        self.rem0 = ConvL(3, 3)
        self.rem1 = ConvL(3, 3)
        self.rem2 = nn.Sequential(
            nn.Conv2d(3, 3, kernel_size=3, stride=1, padding=1),
            nn.Sigmoid())

    def forward(self, x, att_map):
        """Return ``x + residual * att_map``.

        Args:
            x: input image batch.
            att_map: attention map multiplied onto the predicted residual.
        """
        # Encoder pass, remembering every intermediate feature map.
        encoder = (self.conv0, self.conv1, self.conv2, self.conv3,
                   self.conv4, self.conv5, self.conv6, self.conv7)
        feats = []
        current = x
        for layer in encoder:
            current = layer(current)
            feats.append(current)

        # Decoder pass: the first seven up-sampling stages each receive the
        # matching encoder feature (deepest first) as an additive skip.
        decoder = (self.dconv0, self.dconv1, self.dconv2, self.dconv3,
                   self.dconv4, self.dconv5, self.dconv6)
        xx = feats[-1]
        for layer, skip in zip(decoder, reversed(feats[:-1])):
            xx = layer(xx)
            xx += skip
        xx = self.dconv7(xx)

        for refine in (self.rem0, self.rem1, self.rem2):
            xx = refine(xx)

        return xx * att_map + x

In [None]:
# Generative Network
class Gen(nn.Module):
    """Generator: ``step_num`` progressive attention + removal stages.

    Each stage owns its own AttDet and REncoder; the ConvLSTM cell and the
    attention-map head are shared by all stages.

    Args:
        batch_size: expected batch size (only used to pre-allocate the hidden
            state; the real batch size is read from the input in forward()).
        step_num: number of progressive stages.
    """

    def __init__(self, batch_size=None, step_num=None):
        super(Gen, self).__init__()
        self.batch_size = batch_size
        self.step = step_num
        # One attention detector and one removal encoder per stage. They are
        # created interleaved so the random initialisation order is unchanged.
        att_layers, rem_layers = [], []
        for _ in range(self.step):
            att_layers.append(AttDet())
            rem_layers.append(REncoder())
        self.attL = nn.ModuleList(att_layers)
        self.remE = nn.ModuleList(rem_layers)
        # Convolutional LSTM cell
        self.lstm = ConvLSTM(inp_dim=64, hid_dim=64)
        # Initial hidden state; it cannot be allocated without a batch size,
        # forward() creates it from the actual input in any case.
        self.hidden = None
        if self.batch_size is not None:
            self.hidden = self.lstm.init_hidden(self.batch_size, (256,256))
        # Attention map head: 64 feature channels -> 1 channel in (0, 1)
        self.attM = nn.Sequential(
                  nn.Conv2d(64, 1, kernel_size=3, stride=1, padding=1),
                  nn.Sigmoid()
        )

    def init_h(self):
        """Reset the stored hidden state to zeros for a 256x256 input."""
        self.hidden = self.lstm.init_hidden(self.batch_size, (256,256))

    def forward(self, x):
        """Run all progressive stages on x.

        Args:
            x: batch of input images; stage i+1 receives the output of stage i.

        Returns:
            Tuple ``(att_map, out)`` with shapes ``(step, B, 1, H, W)`` and
            ``(step, B, 3, H, W)``: the attention map and the restored image
            of every stage.
        """
        # Batch size, spatial size and device are taken from the input, not
        # from constructor arguments / notebook globals, so the module also
        # works for a short last batch and on whatever device it was moved to.
        self.batch_size = x.shape[0]
        height, width = x.shape[-2:]
        dev = x.device
        # Anomaly detection is a debugging aid that slows execution; it is
        # kept because the training cell relies on it for error traces.
        with torch.autograd.set_detect_anomaly(True):
            att_map = torch.empty(self.step, self.batch_size, 1,
                                  height, width, device=dev)
            out = torch.empty(self.step, self.batch_size, 3,
                              height, width, device=dev)
            # lstm_out carries the cell state, self.hidden the hidden state.
            lstm_out = torch.zeros(self.batch_size, self.lstm.out_dim,
                                   height, width, device=dev)
            self.hidden = self.lstm.init_hidden(
                self.batch_size, (height, width)).to(dev)

            # for N progressive steps
            for i in range(self.step):
                # attention detector
                lstm_in = self.attL[i](x)
                # LSTM layer
                lstm_out, self.hidden = self.lstm(lstm_in, lstm_out, self.hidden)
                # NOTE(review): the attention map is generated from the
                # detector features (lstm_in); the LSTM state computed above is
                # only fed back into the LSTM and never reaches the output.
                # Left unchanged because existing checkpoints were trained
                # this way - confirm whether self.hidden was intended here.
                temp = self.attM(lstm_in)
                # removal encoder
                res = self.remE[i](x, temp)
                # store this stage's results and feed the result forward
                att_map[i] = temp
                out[i] = res
                x = res

            return att_map, out

# Discriminative Network

In [None]:
# Discriminator
class Disc(nn.Module):
    """Discriminator: five stride-2 ConvL stages and a sigmoid-activated linear head.

    The linear layer expects 256*64 flattened features.
    """

    def __init__(self, batch_size=None):
        super().__init__()
        self.batch_size = batch_size
        self.conv0 = ConvL(3, 64, k=4, s=2, p=1)
        self.conv1 = ConvL(64, 128, k=4, s=2, p=1)
        self.conv2 = ConvL(128, 256, k=4, s=2, p=1)
        self.conv3 = ConvL(256, 512, k=4, s=2, p=1)
        self.conv4 = ConvL(512, 256, k=4, s=2, p=1)
        self.fc = nn.Sequential(
            nn.Linear(256*64, 1),
            nn.Sigmoid())

    def forward(self, inp):
        """Return one score in (0, 1) per input image."""
        with torch.autograd.set_detect_anomaly(True):
            # Remember the size of the batch that was seen last.
            self.batch_size = inp.shape[0]
            feat = inp
            for stage in (self.conv0, self.conv1, self.conv2,
                          self.conv3, self.conv4):
                feat = stage(feat)
            # Flatten everything except the batch dimension.
            return self.fc(feat.flatten(1))

# Model Summary

In [None]:
# Generator model summary
# Prints the layer table for gen_net with a 3x256x256 input on the GPU.
# NOTE(review): gen_net is created in a later cell of this notebook, so that
# cell has to be run first.
from torchsummary import summary
summary(gen_net, (3, 256, 256), device='cuda')

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1          [-1, 8, 256, 256]             224
       BatchNorm2d-2          [-1, 8, 256, 256]              16
         LeakyReLU-3          [-1, 8, 256, 256]               0
             ConvL-4          [-1, 8, 256, 256]               0
            Conv2d-5          [-1, 8, 256, 256]             584
       BatchNorm2d-6          [-1, 8, 256, 256]              16
         LeakyReLU-7          [-1, 8, 256, 256]               0
             ConvL-8          [-1, 8, 256, 256]               0
            Conv2d-9         [-1, 16, 256, 256]           1,168
      BatchNorm2d-10         [-1, 16, 256, 256]              32
        LeakyReLU-11         [-1, 16, 256, 256]               0
            ConvL-12         [-1, 16, 256, 256]               0
           Conv2d-13         [-1, 16, 256, 256]           2,320
      BatchNorm2d-14         [-1, 16, 2

In [None]:
# Discriminator model summary
from torchsummary import summary
# Disc halves the resolution five times and its linear layer expects
# 256*64 = 256*8*8 features, which corresponds to a 256x256 input; a 128x128
# input flattens to 256*4*4 features and fails in the linear layer.
# NOTE(review): every line that creates dis_net is commented out in this
# notebook, so it has to be instantiated before this cell can run.
summary(dis_net, (3, 256, 256), device='cpu')

# Training Network

In [None]:
# Perceptual loss by VGG16
class VGGPerceptualLoss(torch.nn.Module):
    """Perceptual loss: summed MSE between VGG16 feature maps of two images.

    Features are compared after four consecutive slices of the pretrained
    VGG16 feature extractor (layers [0:4], [4:9], [9:16], [16:23]).

    Args:
        resize: when True both images are bilinearly resized to 224x224
            before they are passed through VGG.
    """

    def __init__(self, resize=True):
        super(VGGPerceptualLoss, self).__init__()
        # Load the pretrained network once and slice it, not once per slice.
        features = torchvision.models.vgg16(pretrained=True).features
        blocks = [
            features[:4].eval(),
            features[4:9].eval(),
            features[9:16].eval(),
            features[16:23].eval(),
        ]
        for bl in blocks:
            # Iterating a Sequential yields sub-modules, not parameters, so
            # the VGG weights have to be frozen through .parameters().
            for p in bl.parameters():
                p.requires_grad = False
        self.blocks = torch.nn.ModuleList(blocks)
        self.transform = torch.nn.functional.interpolate
        # Normalisation constants are fixed: keep them as buffers so they
        # follow .to(device) and show up in the state_dict under the same
        # keys, but are not returned by .parameters().
        self.register_buffer(
            'mean', torch.tensor([0.485, 0.456, 0.406]).view(1,3,1,1))
        self.register_buffer(
            'std', torch.tensor([0.229, 0.224, 0.225]).view(1,3,1,1))
        self.resize = resize

    def forward(self, input, target):
        """Return the summed feature-space MSE between input and target.

        Inputs whose channel dimension is not 3 are repeated three times along
        the channel axis (intended for single-channel images).
        """
        if input.shape[1] != 3:
            input = input.repeat(1, 3, 1, 1)
            target = target.repeat(1, 3, 1, 1)
        input = (input-self.mean) / self.std
        target = (target-self.mean) / self.std
        if self.resize:
            input = self.transform(input, mode='bilinear', size=(224, 224), align_corners=False)
            target = self.transform(target, mode='bilinear', size=(224, 224), align_corners=False)
        loss = 0.0
        x = input
        y = target
        # Each slice continues from the previous slice's output.
        for block in self.blocks:
            x = block(x)
            y = block(y)
            loss += torch.nn.functional.mse_loss(x, y)
        return loss

In [None]:
# learning parameters
steps = 3   # Number of progressive step
beta = 0.7  # weight for MSE step
lamb = 0.7  # weight for Semi-Supervised learning (not referenced by the training cell below)
l_rate_g = 0.0002   # generator learning rate
#l_rate_d = 0.0002

# Generator under training; the discriminator is currently disabled.
gen_net = Gen(batch_size=batch_num, step_num=steps)
#dis_net = Disc(batch_size=batch_num)

# learning loss
MSE = nn.MSELoss()    # Mean Square Error
VGG = VGGPerceptualLoss(resize=True).to(device)   # perceptual loss on VGG16 features
ADV = nn.BCELoss()    # Adversarial Loss : Binary CE

# optimizer
gen_optim = torch.optim.SGD(gen_net.parameters(), lr=l_rate_g, momentum=0.9)
#dis_optim = torch.optim.Adam(dis_net.parameters(), lr=l_rate_d)

Downloading: "https://download.pytorch.org/models/vgg16-397923af.pth" to /root/.cache/torch/hub/checkpoints/vgg16-397923af.pth


HBox(children=(FloatProgress(value=0.0, max=553433881.0), HTML(value='')))




In [None]:
# Load model params
PATH = '/content/drive/My Drive/KU/4/'
SAVE_PATH = PATH + 'ARGAN_BN/'
gen_PATH = PATH + 'ARGAN256_gen_temp.pth'
#dis_PATH = PATH + 'ARGAN256_dis_temp.pth'

# Resume from the temporary checkpoint when one exists. map_location keeps the
# load working when the checkpoint was written on a different device.
trained = 0
if os.path.isfile(gen_PATH):
  gen_net.load_state_dict(torch.load(gen_PATH, map_location=device))
  trained = 0
# Discriminator checkpoint loading is disabled together with the discriminator:
# if os.path.isfile(dis_PATH):
#   dis_net.load_state_dict(torch.load(dis_PATH))
#   trained = 0

# torch.cuda.is_available is a function and has to be called; the bare
# function object is always truthy and selected "cuda:0" on CPU-only hosts.
if torch.cuda.is_available():
  print('CUDA available')
  device = "cuda:0"

# Move the model once, not on every iteration.
gen_net.to(device)
gen_net.train()
#dis_net.train()

# The per-epoch preview images are written here; make sure the folder exists.
os.makedirs(SAVE_PATH, exist_ok=True)

for epoch in range(10):
    # loss per epoch (plain Python floats)
    det_loss = 0.0
    rem_loss = 0.0
    adv_loss = 0.0

    for i, datas in enumerate(trainloader):

        det_err = 0.0
        rem_err = 0.0

        # from dataset
        image, matte, free = datas
        image = image.to(device)
        matte = matte.to(device)
        free = free.to(device)

        # Generator output: per-step attention maps and restored images
        mattes, frees = gen_net(image)

        with torch.autograd.set_detect_anomaly(True):
            # --- Discriminator training (disabled) ---------------------
            # dis_optim.zero_grad()
            # # real data error
            # real_out = dis_net(free)
            # real_label = torch.ones(free.shape[0],1).to(device)
            # real_err = ADV(real_out, real_label)
            # real_err.backward()
            # # fake data error
            # fake_out = dis_net(frees[steps-1].detach())
            # fake_label = torch.zeros(frees[steps-1].shape[0], 1).to(device)
            # fake_err = ADV(fake_out, fake_label)
            # fake_err.backward()
            #
            # dis_err = real_err + fake_err
            # dis_optim.step()
            # ------------------------------------------------------------

            # train Generative Network
            gen_optim.zero_grad()
            # for N steps
            for n in range(steps):
                step_weight = pow(beta, steps-n)
                # detector loss : MSE
                det_err += step_weight * MSE(matte, mattes[n])
                # removal loss : acc loss + perceptual loss
                rem_err += step_weight * MSE(free, frees[n])
                rem_err += VGG(free, frees[n]) / 10

            # Adversarial loss (disabled)
            # out = dis_net(frees[steps-1])
            # adv_err = ADV(out, real_label)

            total_loss = det_err + rem_err
            total_loss.backward()
            gen_optim.step()

            # Accumulate plain numbers: adding the loss tensors themselves
            # would keep every iteration's autograd graph alive for the
            # whole epoch.
            det_loss += det_err.item()
            rem_loss += rem_err.item()
            #adv_loss += adv_err.item()

            # Saving every batch (disabled):
            # img_fname = temp_path + str(i+1) + "_img.jpg"
            # matt_fname = temp_path + str(i+1) + '_matt.jpg'
            # fre_fname = temp_path + str(i+1) + '_free.jpg'
            # save_batch(image.cpu(), dprow, img_fname)
            # save_batch(mattes[steps-1].cpu(), dprow, matt_fname)
            # save_batch(frees[steps-1].cpu(), dprow, fre_fname)

    # 1 epoch finished
    total = det_loss + rem_loss + adv_loss
    print('[%d epoch]\t det : %f, rem : %f, adv : %f, total : %f'
            %(epoch+trained+1, det_loss, rem_loss, adv_loss, total))
    torch.save(gen_net.state_dict(), gen_PATH)
#    torch.save(dis_net.state_dict(), dis_PATH)

    # Preview of the last batch of the epoch (final progressive step)
    img_fname = SAVE_PATH + str(epoch+trained+1) + "_img.jpg"
    mat_fname = SAVE_PATH + str(epoch+trained+1) + "_matt.jpg"
    fre_fname = SAVE_PATH + str(epoch+trained+1) + "_free.jpg"

    # data out : to 'cpu'
    img_out = image.cpu()
    save_batch(img_out, dprow, img_fname)
    matt_out = mattes[steps-1].cpu()
    save_batch(matt_out, dprow, mat_fname)
    free_out = frees[steps-1].cpu()
    save_batch(free_out, dprow, fre_fname)


CUDA available
[1 epoch]	 det : 93.487061, rem : 305.857819, adv : 0.000000, total : 399.344879
[2 epoch]	 det : 90.005196, rem : 266.892517, adv : 0.000000, total : 356.897705
[3 epoch]	 det : 85.225288, rem : 262.741516, adv : 0.000000, total : 347.966797


KeyboardInterrupt: ignored

Training takes about 7 minutes per epoch.

# Test Network

In [None]:
# Model parameters
steps = 3   # Number of progressive step

# Fresh generator for evaluation; its weights are loaded from a checkpoint
# in the evaluation cells below.
gen_net = Gen(batch_size=batch_num, step_num=steps)
#dis_net = Disc(batch_size=batch_num)

Metrics

In [None]:
# Balanced Error Rate
class BERScore(nn.Module):
  """Balanced Error Rate between an estimated and a ground-truth mask.

  BER = 1 - (sensitivity + specificity) / 2.

  Args:
    thresh: estimates strictly above this value count as positive.
  """
  def __init__(self, thresh=None):
    super().__init__()
    self.thr = thresh

  def forward(self, est, gt):
    """Return the BER as a Python float.

    Args:
      est: estimated mask; binarised with ``est > thresh``.
      gt: ground-truth mask; values above 0.5 count as positive.
        est and gt are broadcast against each other.
    """
    est_pos = est > self.thr
    # Binarise the ground truth too. The former division trick
    # (est / gt compared with 1, inf, nan, 0) silently dropped every pixel
    # whose ground truth was not exactly 0 or 1.
    gt_pos = gt > 0.5

    tp = torch.sum(est_pos & gt_pos).item()
    fp = torch.sum(est_pos & ~gt_pos).item()
    tn = torch.sum(~est_pos & ~gt_pos).item()
    fn = torch.sum(~est_pos & gt_pos).item()

    # A class that does not occur cannot be misclassified: score it as 1.0
    # and avoid the division by zero.
    sensit = tp / (tp+fn) if (tp+fn) > 0 else 1.0
    specif = tn / (tn+fp) if (tn+fp) > 0 else 1.0
    return (1 - (sensit + specif)/2 )

# Balanced Error Rate & sensitivity & specificity
class BERScores(nn.Module):
  """Sensitivity, specificity and Balanced Error Rate of an estimated mask.

  Args:
    thresh: estimates strictly above this value count as positive.
  """
  def __init__(self, thresh=None):
    super().__init__()
    self.thr = thresh

  def forward(self, est, gt):
    """Return ``(sensitivity, specificity, BER)`` as Python floats.

    Args:
      est: estimated mask; binarised with ``est > thresh``.
      gt: ground-truth mask; values above 0.5 count as positive.
        est and gt are broadcast against each other.
    """
    est_pos = est > self.thr
    # Binarise the ground truth too. The former division trick
    # (est / gt compared with 1, inf, nan, 0) silently dropped every pixel
    # whose ground truth was not exactly 0 or 1.
    gt_pos = gt > 0.5

    tp = torch.sum(est_pos & gt_pos).item()
    fp = torch.sum(est_pos & ~gt_pos).item()
    tn = torch.sum(~est_pos & ~gt_pos).item()
    fn = torch.sum(~est_pos & gt_pos).item()

    # A class that does not occur cannot be misclassified: score it as 1.0
    # and avoid the division by zero.
    sensit = tp / (tp+fn) if (tp+fn) > 0 else 1.0
    specif = tn / (tn+fp) if (tn+fp) > 0 else 1.0
    return sensit, specif, (1 - (sensit + specif)/2 )

# Root Mean Square
class RMSEScore(nn.Module):
  """Root-mean-square error between two tensors, returned as a Python float."""

  def __init__(self):
    super().__init__()
    self.mse = nn.MSELoss()

  def forward(self, est, gt):
    """Return sqrt(mean((est - gt)^2)) taken over all elements."""
    mean_sq_err = self.mse(est, gt)
    return torch.sqrt(mean_sq_err).item()

# Metric instances shared by the evaluation cells; masks are binarised at 0.5.
BER = BERScore(thresh=0.5)     # balanced error rate only
BERs = BERScores(thresh=0.5)   # sensitivity, specificity and balanced error rate
RMSE = RMSEScore()             # root-mean-square error

In [None]:
# image transforms
def rgb2lab(image):
  """Convert a batch of channel-first RGB images to CIELAB.

  Args:
    image: 4-D tensor (N, C, H, W); every image is converted with
      skimage.color.rgb2lab.

  Returns:
    Tensor with the same shape and dtype as *image*, on the same device,
    holding the LAB values.
  """
  rgb = image.detach().cpu()
  out = torch.empty_like(rgb)
  # skimage expects channel-last arrays; permute on the tensor itself rather
  # than relying on np.transpose accepting a torch tensor.
  channels_last = rgb.permute(0, 2, 3, 1).numpy()
  for i in range(channels_last.shape[0]):
    lab = color.rgb2lab(channels_last[i])
    # back to channel-first; the assignment casts to out's dtype
    out[i] = torch.as_tensor(np.transpose(lab, (2,0,1)))
  # Return on the input's own device; the notebook-global `device` is not used.
  return out.to(image.device)

# Resize network outputs to 480x640 before they are compared with the
# untransformed test images.
# NOTE(review): 480x640 is presumably the native size of the dataset images - confirm.
matt2src = transforms.Compose([
                               transforms.Resize(size=(480,640)),
])
# Same resize, followed by an RGB -> LAB conversion (used for the RMSE metric).
free2src = transforms.Compose([
                               transforms.Resize(size=(480,640)),
                               transforms.Lambda(rgb2lab)
])


Network Evaluation with batch_size 4

In [None]:
# Generate Testset Output
# torch.cuda.is_available is a function and has to be called; the bare
# function object is always truthy, which forced the CUDA path on CPU-only hosts.
if torch.cuda.is_available():
#  torch.cuda.empty_cache()
#  dis_net.cuda()
  gen_net.cuda()
  device = "cuda:0"

gen_net.eval()
#dis_net.eval()
# NOTE(review): presumably meant to make BatchNorm normalise with batch
# statistics at test time - confirm that this flag has that effect on the
# installed PyTorch version.
for m in gen_net.modules():
    if isinstance(m, nn.BatchNorm2d):
        m.track_running_stats = False

PATH = '/content/drive/My Drive/KU/4/'
gen_PATH = PATH + 'ARGAN256_gen_BN.pth'
#dis_PATH = PATH + 'ARGAN256_dis_net.pth'
SAVE_PATH = PATH + 'ARGAN256_test/zzz/'
# Preview images are written here; make sure the folder exists.
os.makedirs(SAVE_PATH, exist_ok=True)

if os.path.isfile(gen_PATH):
  # map_location keeps the load working when the checkpoint was written on a
  # different device than the one used now.
  gen_net.load_state_dict(torch.load(gen_PATH, map_location=device))
else:
  print("No trained Gen Net")
# Discriminator loading is disabled together with the discriminator:
# if os.path.isfile(dis_PATH):
#   dis_net.load_state_dict(torch.load(dis_PATH, map_location=device))
# else:
#   print("No trained Disc Net")

total_BER = []
total_RMSE = []
total_SRMSE = []

for i, datas in enumerate(testloader):
  with torch.no_grad():
    image, matte, free, src_matte, src_free = datas
    image = image.to(device)
    matte = matte.to(device)
    free = free.to(device)
    src_matte = src_matte.to(device)
    src_free = src_free.to(device)

    mattes, frees = gen_net(image)

    # Bring the final-step outputs to the size of the untransformed test
    # images; shadow-free images are additionally converted to LAB.
    image_ = matt2src(image)
    mattes_ = matt2src(mattes[steps-1])
    frees_ = free2src(frees[steps-1])
    src_free_lab = free2src(src_free)

    ber = BER(mattes_, src_matte)
    rmse = RMSE(frees_, src_free_lab)
    # RMSE restricted to the ground-truth shadow region
    s_rmse = RMSE(frees_ * src_matte, src_free_lab * src_matte)
    print("[%d batch]\tber : %f, rmse : %f, s_rmse : %f" %(i+1, ber, rmse, s_rmse))

    total_BER.append(ber)
    total_RMSE.append(rmse)
    total_SRMSE.append(s_rmse)

    # Save every 10th batch. The former test `i%10 == 10` can never be true,
    # so no image was ever written.
    if i%10 == 0:
      img_fname = SAVE_PATH + str(i) + "_img.jpg"
      mat_fname = SAVE_PATH + str(i) + "_matt.jpg"
      fre_fname = SAVE_PATH + str(i) + "_free.jpg"
      # data out : to 'cpu'
      img_out = image_.cpu()
      save_batch(img_out, dprow, img_fname)
      matt_out = mattes_.cpu()
      save_batch(matt_out, dprow, mat_fname)
      free_out = frees_.cpu()
      save_batch_LAB(free_out, dprow, fre_fname)

    torch.cuda.empty_cache()

avg_BER = np.mean(total_BER)
avg_RMSE = np.mean(total_RMSE)
avg_SRMSE = np.mean(total_SRMSE)

print("DONE============")
print("BER : %f" %(avg_BER))
print("RMSE : %f" %(avg_RMSE))
print("Shadow region RMSE : %f" %(avg_SRMSE))

[1 batch]	ber : 0.031697, rmse : 5.336680, s_rmse : 3.165593
[2 batch]	ber : 0.029675, rmse : 2.893386, s_rmse : 1.594625
[3 batch]	ber : 0.069604, rmse : 7.194751, s_rmse : 3.547878
[4 batch]	ber : 0.039225, rmse : 4.845188, s_rmse : 1.494015
[5 batch]	ber : 0.037525, rmse : 4.811147, s_rmse : 1.863109
[6 batch]	ber : 0.019533, rmse : 3.613688, s_rmse : 1.509352
[7 batch]	ber : 0.035284, rmse : 4.533610, s_rmse : 2.036932
[8 batch]	ber : 0.015640, rmse : 3.757565, s_rmse : 1.795089
[9 batch]	ber : 0.076878, rmse : 3.740987, s_rmse : 1.958181
[10 batch]	ber : 0.014604, rmse : 3.873204, s_rmse : 1.770228
[11 batch]	ber : 0.026468, rmse : 4.225317, s_rmse : 2.313511
[12 batch]	ber : 0.018267, rmse : 7.145534, s_rmse : 3.883172
[13 batch]	ber : 0.059398, rmse : 5.831346, s_rmse : 2.613397
[14 batch]	ber : 0.045773, rmse : 6.937132, s_rmse : 3.446935
[15 batch]	ber : 0.027680, rmse : 9.896238, s_rmse : 4.600909
[16 batch]	ber : 0.011324, rmse : 7.861331, s_rmse : 4.200420
[17 batch]	ber : 

Network Evaluation with batch_size 1

In [None]:
# Rebuild the test loader so that every batch holds exactly one image.
batch_num = 1
dprow = 1   # one image per grid row when saving

test_img = ARGAN_Dataset(test_path, src_trans=img2tensor, matt_trans=matt2tensor, is_test=True)
testloader = torch.utils.data.DataLoader(test_img, batch_size=1, shuffle=False)

In [None]:
# Generate Testset Output
# torch.cuda.is_available is a function and has to be called; the bare
# function object is always truthy, which forced the CUDA path on CPU-only hosts.
if torch.cuda.is_available():
#  torch.cuda.empty_cache()
#  dis_net.cuda()
  gen_net.cuda()
  device = "cuda:0"

gen_net.eval()
#dis_net.eval()
# NOTE(review): presumably meant to make BatchNorm normalise with batch
# statistics at test time - confirm that this flag has that effect on the
# installed PyTorch version.
for m in gen_net.modules():
    if isinstance(m, nn.BatchNorm2d):
        m.track_running_stats = False

PATH = '/content/drive/My Drive/KU/4/'
gen_PATH = PATH + 'ARGAN256_gen_BN.pth'
#dis_PATH = PATH + 'ARGAN256_dis_net.pth'
SAVE_PATH = PATH + 'ARGAN256_test/zzz/'
# Every test image is written below; make sure the folder exists.
os.makedirs(SAVE_PATH, exist_ok=True)

if os.path.isfile(gen_PATH):
  # map_location keeps the load working when the checkpoint was written on a
  # different device than the one used now.
  gen_net.load_state_dict(torch.load(gen_PATH, map_location=device))
else:
  print("No trained Gen Net")
# Discriminator loading is disabled together with the discriminator:
# if os.path.isfile(dis_PATH):
#   dis_net.load_state_dict(torch.load(dis_PATH, map_location=device))
# else:
#   print("No trained Disc Net")

total_BER = []
total_RMSE = []

for i, datas in enumerate(testloader):
  with torch.no_grad():
    image, matte, free, src_matte, src_free = datas
    image = image.to(device)
    matte = matte.to(device)
    free = free.to(device)
    src_matte = src_matte.to(device)
    src_free = src_free.to(device)

    mattes, frees = gen_net(image)

    # Bring the final-step outputs to the size of the untransformed test
    # images; shadow-free images are additionally converted to LAB.
    image_ = matt2src(image)
    mattes_ = matt2src(mattes[steps-1])
    frees_ = free2src(frees[steps-1])
    src_free_lab = free2src(src_free)

    ber = BER(mattes_, src_matte)
    rmse = RMSE(frees_, src_free_lab)
    print("[%d image]\tber : %f, rmse : %f" %(i+1, ber, rmse))

    total_BER.append(ber)
    total_RMSE.append(rmse)

    # Input, predictions and ground truth of every image are saved.
    img_fname = SAVE_PATH + str(i) + "_img.jpg"
    mat_fname = SAVE_PATH + str(i) + "_matt.jpg"
    fre_fname = SAVE_PATH + str(i) + "_free.jpg"
    gt_m_fname = SAVE_PATH + str(i) + "_matt_gt.jpg"
    gt_f_fname = SAVE_PATH + str(i) + "_free_gt.jpg"

    # data out : to 'cpu'
    img_out = image_.cpu()
    save_batch(img_out, dprow, img_fname)
    matt_out = mattes_.cpu()
    save_batch(matt_out, dprow, mat_fname)
    free_out = frees_.cpu()
    save_batch_LAB(free_out, dprow, fre_fname)
    gt_m_out = src_matte.cpu()
    save_batch(gt_m_out, dprow, gt_m_fname)
    gt_f_out = src_free.cpu()
    save_batch(gt_f_out, dprow, gt_f_fname)

    torch.cuda.empty_cache()

avg_BER = np.mean(total_BER)
avg_RMSE = np.mean(total_RMSE)

print("DONE============")
print("BER : %f" %(avg_BER))
print("RMSE : %f" %(avg_RMSE))

[1 image]	ber : 0.026574, rmse : 4.532673
[2 image]	ber : 0.020301, rmse : 7.038940
[3 image]	ber : 0.042702, rmse : 5.012364
[4 image]	ber : 0.036504, rmse : 4.324929
[5 image]	ber : 0.033626, rmse : 2.895464
[6 image]	ber : 0.019955, rmse : 2.931724
[7 image]	ber : 0.024172, rmse : 2.855391
[8 image]	ber : 0.040451, rmse : 2.890459
[9 image]	ber : 0.024752, rmse : 5.514907
[10 image]	ber : 0.081063, rmse : 6.431958
[11 image]	ber : 0.103735, rmse : 10.223772
[12 image]	ber : 0.075263, rmse : 5.545083
[13 image]	ber : 0.040020, rmse : 4.391103
[14 image]	ber : 0.041718, rmse : 4.524572
[15 image]	ber : 0.039346, rmse : 4.404221
[16 image]	ber : 0.037194, rmse : 5.895140
[17 image]	ber : 0.035893, rmse : 5.775198
[18 image]	ber : 0.041946, rmse : 4.461034
[19 image]	ber : 0.038854, rmse : 4.504193
[20 image]	ber : 0.033634, rmse : 4.364290
[21 image]	ber : 0.006324, rmse : 4.265287
[22 image]	ber : 0.034213, rmse : 3.063166
[23 image]	ber : 0.037318, rmse : 3.506435
[24 image]	ber : 0.

# For test

In [None]:
# Evaluation with size recovering
# No images saved
# torch.cuda.is_available is a function and has to be called; the bare
# function object is always truthy, which forced the CUDA path on CPU-only hosts.
if torch.cuda.is_available():
  gen_net.cuda()
  device = "cuda:0"

gen_net.eval()
# NOTE(review): presumably meant to make BatchNorm normalise with batch
# statistics at test time - confirm that this flag has that effect on the
# installed PyTorch version.
for m in gen_net.modules():
    if isinstance(m, nn.BatchNorm2d):
        m.track_running_stats = False

PATH = '/content/drive/My Drive/KU/4/'
gen_PATH = PATH + 'ARGAN256_gen_BL.pth'

if os.path.isfile(gen_PATH):
  # map_location keeps the load working when the checkpoint was written on a
  # different device than the one used now.
  gen_net.load_state_dict(torch.load(gen_PATH, map_location=device))
else:
  print("No trained Gen Net")

total_BER = []
total_sens = []
total_spec = []
total_RMSE = []

for i, datas in enumerate(testloader):
  with torch.no_grad():
    image, matte, free, src_matte, src_free = datas
    image = image.to(device)
    matte = matte.to(device)
    free = free.to(device)
    src_matte = src_matte.to(device)
    src_free = src_free.to(device)

    mattes, frees = gen_net(image)

    # Bring the final-step outputs to the size of the untransformed test
    # images; shadow-free images are additionally converted to LAB.
    image_ = matt2src(image)
    mattes_ = matt2src(mattes[steps-1])
    frees_ = free2src(frees[steps-1])
    src_free_lab = free2src(src_free)

    sens, spec, ber = BERs(mattes_, src_matte)
    rmse = RMSE(frees_, src_free_lab)
    print("[%d image]\tber : %f (%f / %f), rmse : %f" %(i+1, ber, sens, spec, rmse))

    total_BER.append(ber)
    total_sens.append(sens)
    total_spec.append(spec)
    total_RMSE.append(rmse)

    torch.cuda.empty_cache()

avg_BER = np.mean(total_BER)
avg_sens = np.mean(total_sens)
avg_spec = np.mean(total_spec)
avg_RMSE = np.mean(total_RMSE)

print("DONE============")
print("BER : %f" %(avg_BER))
print("sens : %f" %(avg_sens))
print("spec : %f" %(avg_spec))
print("RMSE : %f" %(avg_RMSE))

[1 image]	ber : 0.006055 (0.997673 / 0.990217), rmse : 4.008686
[2 image]	ber : 0.018797 (0.981029 / 0.981376), rmse : 6.761482
[3 image]	ber : 0.034071 (0.942675 / 0.989183), rmse : 4.594495
[4 image]	ber : 0.015643 (0.989570 / 0.979145), rmse : 3.413022
[5 image]	ber : 0.021569 (0.976200 / 0.980662), rmse : 2.496001
[6 image]	ber : 0.016291 (0.974692 / 0.992726), rmse : 2.792156
[7 image]	ber : 0.023387 (0.965468 / 0.987758), rmse : 2.625678
[8 image]	ber : 0.022112 (0.973510 / 0.982265), rmse : 2.478117
[9 image]	ber : 0.001413 (0.998541 / 0.998633), rmse : 4.359787
[10 image]	ber : 0.020153 (0.999948 / 0.959746), rmse : 4.880087
[11 image]	ber : 0.039770 (0.934414 / 0.986046), rmse : 9.831825
[12 image]	ber : 0.022752 (0.990128 / 0.964369), rmse : 4.109680
[13 image]	ber : 0.033081 (0.985491 / 0.948346), rmse : 4.084582
[14 image]	ber : 0.028034 (0.992741 / 0.951190), rmse : 4.086577
[15 image]	ber : 0.028017 (0.992450 / 0.951515), rmse : 3.988919
[16 image]	ber : 0.023324 (0.99867