In [1]:
import time
import cv2
import datetime
import torch.nn as nn
import torch.nn.functional as F
import torch
import torch.optim as optim
import torchvision
from torchvision import models
from torchvision.models.vgg import VGG
from torch.optim import lr_scheduler
from torch.autograd import Variable
from torch.utils.data import Dataset, DataLoader
import torchvision.transforms as transforms
from torchvision.utils import save_image
import numpy as np
import matplotlib.pyplot as plt
from PIL import Image
import os
import sys
if sys.version_info[0] == 2:
    import xml.etree.cElementTree as ET
else:
    import xml.etree.ElementTree as ET

In [2]:
def one_hot(x, bins):
    '''
    Convert tensor x to one-hot tensor
    '''
    x = x.numpy()
    idxs = np.digitize(x, bins, right=True)
    idxs = idxs.reshape(-1,1)
    z = torch.zeros(len(x), len(bins)+1).scatter_(1, torch.tensor(idxs), 1)
    return z

In [3]:
class Conv(nn.Module):
    def __init__(self, in_size, out_size, kernel_size=4, stride=2, padding=1, dropout=0.0, bn=True):
        super(Conv, self).__init__()
        layers = [nn.Conv2d(in_size, out_size, 4, 2, 1, bias=False)]
        if bn:
            layers.append(nn.BatchNorm2d(out_size))
        #layers.append(nn.LeakyReLU(0.2))
        if dropout:
            layers.append(nn.Dropout(dropout))
        self.model = nn.Sequential(*layers)

    def forward(self, x):
        return self.model(x)

class DeConv(nn.Module):
    def __init__(self, in_size, out_size, kernel_size=4, stride=2, padding=1, dropout=0.0):
        super(DeConv, self).__init__()
        layers = [  nn.ConvTranspose2d(in_size, out_size, 4, 2, 1, bias=False),
                    nn.BatchNorm2d(out_size),
                    nn.ReLU(inplace=True)]
        if dropout:
            layers.append(nn.Dropout(dropout))

        self.model = nn.Sequential(*layers)
    
    def forward(self, x):
        return self.model(x)

    '''def forward(self, x, skip_input):
        x = self.model(x)
        x = torch.cat((x, skip_input), 1)
        return x'''

In [4]:
n_encode = 64
n_gen = 64
latent_s = 100

In [5]:
class Encoder(nn.Module):
    def __init__(self, in_channels=3, out_channels=3):
        super(Encoder, self).__init__()
        self.n_encode = 32
        self.conv1 = Conv(in_channels, n_encode, 5, 2, 1)
        self.conv2 = Conv(n_encode, 2*n_encode, 5, 2, 1)
        self.conv3 = Conv(2*n_encode, 4*n_encode, 5, 2, 1)
        self.conv4 = Conv(4*n_encode, 8*n_encode, 5, 2, 1, dropout=0.5)
        self.fc1 = nn.Sequential(nn.Linear(8*n_encode*8*8, latent_s))
        self.fc2 = nn.Sequential(nn.Linear(4096, 100))
    def forward(self, x):
        c1 = F.relu(self.conv1(x))
        c2 = F.relu(self.conv2(c1))
        c3 = F.relu(self.conv3(c2))
        c4 = F.relu(self.conv4(c3))
        c = c4.view(-1, 8*n_encode*8*8)
        f1 = self.fc1(c)
        out = f1
        '''f = d5.view(d5.size(0), -1)
        l1 = F.relu(self.linear1(f))
        l2 = F.relu(self.linear2(l1))
        y = l2.unsqueeze(-1).unsqueeze(-1)'''
        '''print("x : ", x.shape)
        print("c1: ", c1.shape)
        print("c2: ", c2.shape)
        print("c3: ", c3.shape)
        print("c4: ", c4.shape)
        print("c : ", c.shape)
        print("f1: ", f1.shape)
        print("out: ", out.shape)'''
        return out

x = torch.zeros([1, 3, 128, 128])
x.shape
e = Encoder()
y = e(x)
print(y.shape)

cls = torch.tensor(2).unsqueeze(0)
cls = one_hot(cls, [0, 1])
print(cls, cls.shape)

z = torch.cat((y, cls), 1)
z.shape

In [6]:
class Generator(nn.Module):
    def __init__(self, in_channels=3, out_channels=2):
        super(Generator, self).__init__()
        #in_channels, out_channels, kernel_size, stride, padding
        self.up1 = DeConv(8*n_gen, 4*n_gen, 4, 2, 0, 0.0)
        self.up2 = DeConv(4*n_gen, 2*n_gen, 4, 2, 1, 0.0)
        self.up3 = DeConv(2*n_gen, 1*n_gen, 4, 2, 1, 0.0)
        self.up4 = DeConv(1*n_gen, out_channels, 4, 2, 1, 0.5)
        self.fc1 = nn.Sequential(nn.Linear(latent_s + 3, 8*8*n_gen*8))
        
    def forward(self, x, z):
        z = torch.tensor(z).unsqueeze(0).cpu()
        z = one_hot(z, [0, 1])
        if cuda:
            z = z.cuda()
        x = torch.cat((x, z), 1)
        fc = self.fc1(x).view(-1, 8*n_gen,8,8)
        u1 = F.relu(self.up1(fc))
        u2 = F.relu(self.up2(u1))
        u3 = F.relu(self.up3(u2))
        u4 = F.relu(self.up4(u3))
        y = u4
        '''print("x : ", x.shape)
        print("fc1: ", fc.shape)
        print("u1: ", u1.shape)
        print("u2: ", u2.shape)
        print("u3: ", u3.shape)
        print("u4: ", u4.shape)
        print("u5: ", u5.shape)
        print("y: ", y.shape)'''
        return y

G = Generator()
#print(z.shape)
g = G(y, 3)
print(g.shape)

In [7]:
class Discriminator(nn.Module):
    def __init__(self, y_size, conv_dim=32):
        super(Discriminator, self).__init__()
        self.conv_dim = conv_dim
        self.y_size = y_size
        self.conv1 = Conv(2 + 3 + y_size, conv_dim, 4)
        self.conv2 = Conv(conv_dim, conv_dim*2, 4)
        self.conv3 = Conv(conv_dim*2, conv_dim*4, 4)
        self.conv4 = Conv(conv_dim*4, conv_dim*8, 4, bn=False)
        self.fc1 = nn.Sequential(
            nn.Linear(conv_dim*8*8*8, 1024),
            nn.ReLU()
        )
        self.fc2 = nn.Sequential(
            nn.Linear(1024, 1),
            nn.Softmax()
        )
            
    def forward(self, img_A, img_B, z):
        img_input = torch.cat((img_A, img_B), 1)
        z = torch.tensor(z).unsqueeze(0).cpu()
        z = one_hot(z, [0, 1])
        if cuda:
            z = z.cuda()
        z = z.view(-1,z.size()[-1],1,1)
        z = z.expand(-1,-1,img_input.size()[-2], img_input.size()[-1])
        x = torch.cat((img_input, z), 1)
        x = F.relu(self.conv1(x))
        x = F.relu(self.conv2(x))
        x = F.relu(self.conv3(x))
        x = F.relu(self.conv4(x))
        x = x.view(1, -1)
        x = self.fc1(x)
        x = self.fc2(x)
        return x

D = Discriminator(3)

img = torch.zeros([1, 2, 128, 128])
d = D(img, img, 2)
print(d.shape)
print(d)

In [8]:
class ImageDataset(Dataset):
    def __init__(self, root = "/media/arg_ws3/5E703E3A703E18EB/data/subt_all/", mode='train'):
        transforms_ = [ transforms.Resize((128, 128), Image.BICUBIC),
                        transforms.ToTensor(),
                        transforms.Normalize((0.5,0.5,0.5), (0.5,0.5,0.5))]

        self.transform_ = transforms.Compose(transforms_)
        self.root = root
        self.files = []
        self.obj_class = ['bb_extinguisher', 'bb_drill', 'bb_backpack']
        for line in open(os.path.join(root, mode + '.txt')):
            self.files.append(line.strip())

    def __getitem__(self, index):
        idx = index % len(self.files)
        ann_path = self.root + 'Annotations/' + self.files[idx] + '.xml'
        mask_path = self.root + 'mask/' + self.files[idx] + '.png'
        rgb_path = self.root + 'image/' + self.files[idx] + '.jpg'

        bbx = self.get_ann(ann_path)

        img_mask = Image.open(mask_path).crop(bbx).resize((128, 128), Image.ANTIALIAS)
        img_rgb = Image.open(rgb_path).crop(bbx).resize((128, 128), Image.ANTIALIAS)

        img_mask = img_mask.convert('L')

        img_mask = np.array(img_mask)

        img_rgb = np.array(img_rgb)

        if np.random.random() < 0.5:
            img_mask = np.array(img_mask)[::-1, :]
            img_rgb = np.array(img_rgb)[::-1, :]

        # 1 channel to 2 channel classifier (one hot encoder)
        n_class = 2
        h, w = img_mask.shape
        target = torch.zeros(n_class, h, w)
        clss = img_mask[img_mask!=0]
        cls_counts = np.bincount(clss)
        label = np.argmax(cls_counts)
        img_mask[img_mask!=0] = 1 # 255 to 1
        img_mask = torch.from_numpy(img_mask.copy()).long()
        for i in range(n_class):
            target[i][img_mask == i] = 1

        img_rgb = self.transform_(Image.fromarray(img_rgb))

        return {'A': target, 'B': img_rgb, 'C': label}

    def get_ann(self, ann_path):
        target = ET.parse(ann_path).getroot()
        res = []
        for obj in target.iter('object'):
            name = obj.find('name').text.lower().strip()
            if name not in self.obj_class:
                continue
            bbox = obj.find('bndbox')
            if bbox is not None:
                pts = ['xmin', 'ymin', 'xmax', 'ymax']
                bndbox = []
                for i, pt in enumerate(pts):
                    cur_pt = int(bbox.find(pt).text) - 1
                    # scale height or width
                    #cur_pt = cur_pt / width if i % 2 == 0 else cur_pt / height
                    bndbox.append(cur_pt)
                res += [bndbox]  # [xmin, ymin, xmax, ymax]
            else: # For LabelMe tool
                polygons = obj.find('polygon')
                x = []
                y = []
                bndbox = []
                for polygon in polygons.iter('pt'):
                    # scale height or width
                    x.append(int(polygon.find('x').text))
                    y.append(int(polygon.find('y').text))
                bndbox.append(min(x))
                bndbox.append(min(y))
                bndbox.append(max(x))
                bndbox.append(max(y))
                res += [bndbox] # [xmin, ymin, xmax, ymax]
        return res[0]

    def __len__(self):
        return len(self.files)

In [9]:
dataloader = DataLoader(ImageDataset("/media/arg_ws3/5E703E3A703E18EB/data/subt_all/"),
                        batch_size=1, shuffle=True, num_workers=4)

val_dataloader = DataLoader(ImageDataset("/media/arg_ws3/5E703E3A703E18EB/data/subt_all/", mode='val'),
                            batch_size=1, shuffle=True, num_workers=1)

In [10]:
dataset_name = 'test1'
root = '/media/arg_ws3/5E703E3A703E18EB/research/InstaceGAN/'
os.makedirs(root + 'images/%s' % dataset_name, exist_ok=True)
os.makedirs(root + 'saved_models/%s' % dataset_name, exist_ok=True)

cuda = True if torch.cuda.is_available() else False
Tensor = torch.cuda.FloatTensor if cuda else torch.FloatTensor
print("CUDA: ", cuda)

# Loss functions
criterion_GAN = torch.nn.MSELoss()
criterion_pixelwise = torch.nn.L1Loss()

# Loss weight of L1 pixel-wise loss between translated image and real image
lambda_pixel = 100

# Initialize generator and discriminator
G = Generator()
E = Encoder()
D = Discriminator(3)

if cuda:
    G = G.cuda()
    E = E.cuda()
    D = D.cuda()
    criterion_GAN.cuda()
    criterion_pixelwise.cuda()
    
# Optimizers
optimizer_G = torch.optim.Adam(G.parameters(), lr=0.0002, betas=(0.5, 0.999))
optimizer_E = torch.optim.Adam(E.parameters(), lr=0.0002, betas=(0.5, 0.999))
optimizer_D = torch.optim.Adam(D.parameters(), lr=0.0002, betas=(0.5, 0.999))

CUDA:  True


In [43]:
def sample_images(batches_done):
    """Saves a generated sample from the validation set"""
    imgs = next(iter(val_dataloader))
    r_rgb = Variable(imgs['B'].type(Tensor)) # rgb
    r_mask = Variable(imgs['A'].type(Tensor)) # mask
    r_label = Variable(imgs['C'].type(Tensor)) # label
    
    y = E(r_rgb)
    
    f_mask = G(y, r_label)
    
    f_mask = f_mask.data.cpu().numpy()
    N, _, h, w = f_mask.shape
    f_mask = f_mask.transpose(0, 2, 3, 1).reshape(-1, 2).argmax(axis = 1).reshape(N, h, w)
    
    #fake_B_d = f_mask[:,1,:,:].unsqueeze(1).data
    fake_B_d = torch.tensor(f_mask, dtype=torch.float).unsqueeze(1).cuda()
    fake_B_d = torch.cat((fake_B_d, fake_B_d, fake_B_d), 1)
    real_B_d = r_mask[:,1,:,:].unsqueeze(1).data
    real_B_d = torch.cat((real_B_d, real_B_d, real_B_d), 1)
    img_sample = torch.cat((r_rgb.data, fake_B_d, real_B_d), 0)
    save_image(img_sample, root + 'images/%s/%s.png' % (dataset_name, batches_done), nrow=5, normalize=False)

sample_images(0)

  if sys.path[0] == '':


In [12]:
prev_time = time.time()
n_epochs = 50
for epoch in range(n_epochs):
    for i, batch in enumerate(dataloader):
        # Model inputs
        r_rgb = Variable(batch['B'].type(Tensor)) # rgb
        r_mask = Variable(batch['A'].type(Tensor)) # mask
        r_label = Variable(batch['C'].type(Tensor)) # label
        #print(r_rgb.shape, r_mask.shape, r_label.shape)
        # Adversarial ground truths
        #valid = Variable(Tensor(np.ones((real_B.size(0), *patch))), requires_grad=False)
        #fake = Variable(Tensor(np.zeros((real_B.size(0), *patch))), requires_grad=False)
        valid = Variable(Tensor(np.ones(1)), requires_grad=False) #1
        fake = Variable(Tensor(np.zeros(1)), requires_grad=False) #0
        
        # ------------------
        #  Train Generators
        # ------------------
        optimizer_G.zero_grad()
        optimizer_E.zero_grad()

        y = E(r_rgb)
        # GAN loss
        f_mask = G(y, r_label)
        f_pred = D(f_mask, r_rgb, r_label)
        loss_GAN = criterion_GAN(f_pred, valid)
        # Pixel-wise loss
        
        loss_pixel = criterion_pixelwise(f_mask, r_mask)

        # Total loss
        loss_G = loss_GAN + lambda_pixel * loss_pixel
        #loss_G = loss_GAN

        loss_G.backward()

        optimizer_E.step()
        optimizer_G.step()
        
        # ---------------------
        #  Train Discriminator
        # ---------------------

        optimizer_D.zero_grad()

        # Real loss
        r_pred = D(r_mask, r_rgb, r_label)
        loss_real = criterion_GAN(r_pred, valid)

        # Fake loss
        f_pred = D(f_mask.detach(), r_rgb, r_label)
        loss_fake = criterion_GAN(f_pred, fake)

        # Total loss
        loss_D = 0.5 * (loss_real + loss_fake)

        loss_D.backward()
        optimizer_D.step()

        # --------------
        #  Log Progress
        # --------------

        # Determine approximate time left
        batches_done = epoch * len(dataloader) + i
        batches_left = n_epochs * len(dataloader) - batches_done
        time_left = datetime.timedelta(seconds=batches_left * (time.time() - prev_time))
        prev_time = time.time()

        # Print log
        if i % 1000 == 0:
            sample_images(batches_done)
        if i % 500 == 0:
            print("\r[Epoch %d/%d] [Batch %d/%d] [D loss: %f] [G loss: %f] ETA: %s" %
                                (epoch, n_epochs,
                                i, len(dataloader),
                                loss_D.item(), loss_G.item(),
                                time_left))
    if epoch % 2 == 0 and epoch != 0:
        torch.save(G.state_dict(), root + 'saved_models/%s/G_%d.pth' % (dataset_name, epoch))
        torch.save(E.state_dict(), root + 'saved_models/%s/E_%d.pth' % (dataset_name, epoch))
        torch.save(D.state_dict(), root + 'saved_models/%s/D_%d.pth' % (dataset_name, epoch))

  if sys.path[0] == '':
  input = module(input)


[Epoch 0/50] [Batch 0/3405] [D loss: 0.500000] [G loss: 57.224880] ETA: 6:37:38.682275
[Epoch 0/50] [Batch 500/3405] [D loss: 0.500000] [G loss: 36.090855] ETA: 1:17:14.114325
[Epoch 0/50] [Batch 1000/3405] [D loss: 0.500000] [G loss: 38.851040] ETA: 1:19:54.867337
[Epoch 0/50] [Batch 1500/3405] [D loss: 0.500000] [G loss: 33.046883] ETA: 1:19:38.207839
[Epoch 0/50] [Batch 2000/3405] [D loss: 0.500000] [G loss: 32.442081] ETA: 1:19:27.981350
[Epoch 0/50] [Batch 2500/3405] [D loss: 0.500000] [G loss: 32.264893] ETA: 1:20:30.201924
[Epoch 0/50] [Batch 3000/3405] [D loss: 0.500000] [G loss: 32.167072] ETA: 1:18:58.127410
[Epoch 1/50] [Batch 0/3405] [D loss: 0.500000] [G loss: 32.698483] ETA: 7:38:37.803183
[Epoch 1/50] [Batch 500/3405] [D loss: 0.500000] [G loss: 32.411083] ETA: 1:18:57.831686
[Epoch 1/50] [Batch 1000/3405] [D loss: 0.500000] [G loss: 31.828892] ETA: 1:18:12.511849
[Epoch 1/50] [Batch 1500/3405] [D loss: 0.500000] [G loss: 32.739037] ETA: 1:19:36.878446
[Epoch 1/50] [Batc

[Epoch 13/50] [Batch 500/3405] [D loss: 0.500000] [G loss: 28.264120] ETA: 0:59:07.940413
[Epoch 13/50] [Batch 1000/3405] [D loss: 0.500000] [G loss: 28.014576] ETA: 0:58:47.307376
[Epoch 13/50] [Batch 1500/3405] [D loss: 0.500000] [G loss: 26.628527] ETA: 0:59:27.391287
[Epoch 13/50] [Batch 2000/3405] [D loss: 0.500000] [G loss: 26.947868] ETA: 1:00:01.748548
[Epoch 13/50] [Batch 2500/3405] [D loss: 0.500000] [G loss: 28.153154] ETA: 0:59:22.875806
[Epoch 13/50] [Batch 3000/3405] [D loss: 0.500000] [G loss: 29.096348] ETA: 0:58:56.104932
[Epoch 14/50] [Batch 0/3405] [D loss: 0.500000] [G loss: 29.339569] ETA: 5:12:41.739120
[Epoch 14/50] [Batch 500/3405] [D loss: 0.500000] [G loss: 28.914993] ETA: 0:58:27.930222
[Epoch 14/50] [Batch 1000/3405] [D loss: 0.500000] [G loss: 27.159119] ETA: 0:57:46.170230
[Epoch 14/50] [Batch 1500/3405] [D loss: 0.500000] [G loss: 29.475838] ETA: 0:58:19.345207
[Epoch 14/50] [Batch 2000/3405] [D loss: 0.500000] [G loss: 27.299934] ETA: 0:56:57.479396
[Epo

[Epoch 26/50] [Batch 500/3405] [D loss: 0.500000] [G loss: 28.279284] ETA: 0:38:18.587904
[Epoch 26/50] [Batch 1000/3405] [D loss: 0.500000] [G loss: 29.080725] ETA: 0:38:14.060097
[Epoch 26/50] [Batch 1500/3405] [D loss: 0.500000] [G loss: 29.016769] ETA: 0:38:07.997761
[Epoch 26/50] [Batch 2000/3405] [D loss: 0.500000] [G loss: 26.881742] ETA: 0:38:43.230515
[Epoch 26/50] [Batch 2500/3405] [D loss: 0.500000] [G loss: 28.785700] ETA: 0:40:25.006447
[Epoch 26/50] [Batch 3000/3405] [D loss: 0.500000] [G loss: 27.328199] ETA: 0:37:09.938049
[Epoch 27/50] [Batch 0/3405] [D loss: 0.500000] [G loss: 27.030176] ETA: 15:25:02.116005
[Epoch 27/50] [Batch 500/3405] [D loss: 0.500000] [G loss: 28.343899] ETA: 0:38:12.611792
[Epoch 27/50] [Batch 1000/3405] [D loss: 0.500000] [G loss: 26.711241] ETA: 0:36:48.368527
[Epoch 27/50] [Batch 1500/3405] [D loss: 0.500000] [G loss: 28.961622] ETA: 0:36:30.460700
[Epoch 27/50] [Batch 2000/3405] [D loss: 0.500000] [G loss: 28.840330] ETA: 0:36:22.698257
[Ep

[Epoch 39/50] [Batch 500/3405] [D loss: 0.500000] [G loss: 26.991034] ETA: 0:18:11.326991
[Epoch 39/50] [Batch 1000/3405] [D loss: 0.500000] [G loss: 26.512754] ETA: 0:17:15.528567
[Epoch 39/50] [Batch 1500/3405] [D loss: 0.500000] [G loss: 29.509199] ETA: 0:17:00.014166
[Epoch 39/50] [Batch 2000/3405] [D loss: 0.500000] [G loss: 27.034027] ETA: 0:16:44.198117
[Epoch 39/50] [Batch 2500/3405] [D loss: 0.500000] [G loss: 28.653898] ETA: 0:16:35.478581
[Epoch 39/50] [Batch 3000/3405] [D loss: 0.500000] [G loss: 28.139019] ETA: 0:16:52.340051
[Epoch 40/50] [Batch 0/3405] [D loss: 0.500000] [G loss: 27.259836] ETA: 1:30:40.794003
[Epoch 40/50] [Batch 500/3405] [D loss: 0.500000] [G loss: 28.470612] ETA: 0:16:16.167047
[Epoch 40/50] [Batch 1000/3405] [D loss: 0.500000] [G loss: 28.027111] ETA: 0:15:40.722048
[Epoch 40/50] [Batch 1500/3405] [D loss: 0.500000] [G loss: 27.685541] ETA: 0:15:46.023488
[Epoch 40/50] [Batch 2000/3405] [D loss: 0.500000] [G loss: 28.177256] ETA: 0:15:51.886308
[Epo