In [0]:
# Download the LSUN "conference_room" training set (LMDB archive) into ./data.
# Archive listing copied from http://dl.yf.io/lsun/scenes/ (file name, timestamp, size):
"""
bridge_train_lmdb.zip                              12-Mar-2017 23:01     15G
church_outdoor_train_lmdb.zip                      12-Mar-2017 23:03      2G
classroom_train_lmdb.zip                           12-Mar-2017 23:04      3G
conference_room_train_lmdb.zip                     12-Mar-2017 23:06      4G
dining_room_train_lmdb.zip                         12-Mar-2017 23:12     11G
restaurant_train_lmdb.zip                          12-Mar-2017 23:46     13G
tower_train_lmdb.zip                               12-Mar-2017 23:52     11G
"""
# Target directory for the dataset.
%mkdir data
# The cloned repository provides the download.py script invoked below.
!git clone https://github.com/fyu/lsun
%cd lsun
# -c selects the scene category, -o the output directory.
!python download.py -c conference_room -o ../data/ 
%cd ../data
# Extract the archive and delete the zip only if extraction succeeded (&&).
!unzip conference_room_train_lmdb.zip && rm conference_room_train_lmdb.zip
# Return to the starting directory so later relative paths resolve as before.
%cd ..

In [0]:
import torch
import torchvision
import torchvision.transforms as transforms
import torchvision.utils as vutils

import matplotlib.pyplot as plt
import numpy as np
import random
import math

import os

In [0]:
# ---- Paths and run name ----
data_PATH = "/content/data/"
log_PATH = os.path.join("/gdrive", "My Drive", "notebooks", "logs", "wgan-gp")
modelName = "WGAN-gp_DCGAN_ln_conference_lsun"

# ---- Data loading and schedule ----
batch_size = 64
workers = 2
epochs = 15

# ---- Model dimensions ----
latent_size = 100  # length of the latent vector fed to the generator
gf_dim = 64        # base channel width used by the generator
df_dim = 64        # base channel width used by the critic
in_h = 64          # image height after resizing
in_w = 64          # image width after resizing
c_dim = 3          # number of image channels

# ---- Optimisation ----
n_critic = 5  # the number of iterations of the critic per generator iteration
learning_rate = 1e-4
beta1 = 0.5
beta2 = 0.9
gp_lambda = 10  # weight of the gradient-penalty term in the critic loss

# Use the GPU when one is visible, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# ---- Reproducibility: seed Python's and PyTorch's generators ----
manualSeed = 3734
print("Random Seed: ",manualSeed)
random.seed(manualSeed)
torch.manual_seed(manualSeed)

In [0]:
# Preprocessing pipeline: resize to (in_h, in_w), convert to a tensor, then
# normalise each of the three channels as (x - 0.5) / 0.5.
transform = transforms.Compose([
    transforms.Resize((in_h, in_w)),
    transforms.ToTensor(),
    transforms.Normalize((0.5,) * 3, (0.5,) * 3),
])
def transform_inverse(y):
  """Undo the (x - 0.5) / 0.5 normalisation for a single image tensor.

  The channel count is read from the first dimension of ``y``: a size of 1
  selects single-channel statistics, anything else the three-channel ones.
  Normalising with mean -1 and std 2 computes ``(y + 1) / 2``.
  """
  n_channels = y.size()[0]
  if n_channels == 1:
    mean, std = (-1,), (2,)
  else:
    mean, std = (-1, -1, -1), (2, 2, 2)
  undo = torchvision.transforms.Normalize(mean, std)
  return undo(y)

def batch_transform_inverse(y):
  """Undo the (x - 0.5) / 0.5 normalisation for a whole batch of images.

  Args:
    y: tensor whose values were normalised with mean 0.5 and std 0.5
       (the indexing in the previous implementation treated it as
       batch-first with channels in dimension 1).

  Returns:
    A new tensor of the same shape holding ``(y + 1) / 2``; ``y`` itself is
    left untouched.

  The previous implementation computed ``y * 2 - 1``, which is the forward
  normalisation rather than its inverse and disagreed with
  ``transform_inverse``. It also filled only channels 0..2 of an
  uninitialised buffer, leaving any further channels undefined. The
  element-wise form below handles every channel count.
  """
  return (y + 1) / 2

In [0]:
# LSUN conference-room training split, preprocessed with `transform`.
lsun_dataset = torchvision.datasets.LSUN(
    root=data_PATH,
    classes=['conference_room_train'],
    transform=transform,
)
# Shuffled mini-batches of `batch_size` images, loaded with `workers` workers.
train_loader = torch.utils.data.DataLoader(
    lsun_dataset,
    batch_size=batch_size,
    shuffle=True,
    num_workers=workers,
)
print(lsun_dataset)

In [0]:
# Sanity check: pull one batch from the loader and display it as an image grid.
real_batch = next(iter(train_loader))
plt.figure(figsize=(8,8))
plt.axis('off')
plt.title('Training Images')
# real_batch[0] holds the images; the grid is transposed with (1,2,0) before
# being handed to imshow.
plt.imshow(np.transpose(vutils.make_grid(real_batch[0].to(device),padding=2, normalize=True).cpu(),(1,2,0)))

# Print the size of the image tensor of this batch.
print(real_batch[0].size())

In [0]:
import torch.nn as nn
def norm_layer(out_shape, mode='bn'):
  """Create a normalisation layer for a feature map of the given shape.

  Args:
    out_shape: sequence describing the feature map. A length-4 sequence has
      its first entry dropped, so the remaining entries start with the
      channel count; otherwise the sequence is used as given.
    mode: 'bn' for BatchNorm2d, 'ln' for LayerNorm over the whole remaining
      shape, 'in' for InstanceNorm2d.

  Returns:
    The requested ``nn.Module``.

  Raises:
    NameError: if ``mode`` is not one of 'bn', 'ln', 'in' (exception type
      kept for backward compatibility).
  """
  if len(out_shape)==4:
    out_shape = out_shape[1:]
  if mode=='bn':
    return nn.BatchNorm2d(out_shape[0],momentum=0.1,eps=1e-5)
  elif mode == 'ln':
    return nn.LayerNorm(out_shape[:],eps=1e-5)
  elif mode == 'in':
    # Fixed: this branch used the undefined name `output_shape`.
    return nn.InstanceNorm2d(out_shape[0],momentum=0.1,eps=1e-5)
  else:
    raise NameError("'%s' is not valid normalization type"%mode)


def conv_bn_layer(in_channels,out_channels,kernel_size,stride=1,padding=0):
    """Return a bias-free Conv2d followed by BatchNorm2d, wrapped in a Sequential."""
    conv = nn.Conv2d(
        in_channels, out_channels, kernel_size,
        stride=stride, padding=padding, bias=False,
    )
    batch_norm = nn.BatchNorm2d(out_channels, momentum=0.1, eps=1e-5)
    return nn.Sequential(conv, batch_norm)


def tconv_bn_layer(in_channels,out_channels,kernel_size,stride=1,padding=0):
  """Return a bias-free ConvTranspose2d followed by BatchNorm2d, wrapped in a Sequential."""
  transposed_conv = nn.ConvTranspose2d(
      in_channels, out_channels, kernel_size,
      stride=stride, padding=padding, bias=False,
  )
  batch_norm = nn.BatchNorm2d(out_channels, momentum=0.1, eps=1e-5)
  return nn.Sequential(transposed_conv, batch_norm)
def tconv_layer(in_channels,out_channels,kernel_size,stride=1,padding=0):
  """Return a plain ConvTranspose2d (default bias, no normalisation)."""
  transposed_conv = nn.ConvTranspose2d(
      in_channels, out_channels, kernel_size,
      stride=stride, padding=padding,
  )
  return transposed_conv

def conv_layer(in_channels,out_channels,kernel_size,stride=1,padding=0):
    """Return a plain Conv2d (default bias, no normalisation)."""
    conv = nn.Conv2d(
        in_channels, out_channels, kernel_size,
        stride=stride, padding=padding,
    )
    return conv

def fc_layer(in_features,out_features):
  """Return a fully connected (Linear) layer with the default bias."""
  linear = nn.Linear(in_features, out_features)
  return linear

def fc_bn_layer(in_features,out_features):
  """Return a bias-free Linear layer followed by BatchNorm1d, wrapped in a Sequential."""
  linear = nn.Linear(in_features, out_features, bias=False)
  batch_norm = nn.BatchNorm1d(out_features)
  return nn.Sequential(linear, batch_norm)

In [0]:
def conv_out_size_same(size, stride):
  """Return ceil(size / stride) as an int."""
  ratio = float(size) / float(stride)
  return int(math.ceil(ratio))
# Spatial sizes after 0, 1, 2, 3 and 4 successive stride-2 reductions of the
# input resolution (each step is a ceil-division by 2).
s_h, s_w = in_h, in_w

s_h2 = conv_out_size_same(s_h, 2)
s_w2 = conv_out_size_same(s_w, 2)

s_h4 = conv_out_size_same(s_h2, 2)
s_w4 = conv_out_size_same(s_w2, 2)

s_h8 = conv_out_size_same(s_h4, 2)
s_w8 = conv_out_size_same(s_w4, 2)

s_h16 = conv_out_size_same(s_h8, 2)
s_w16 = conv_out_size_same(s_w8, 2)

In [0]:
import torch.nn as nn
import torch.nn.functional as F

class Generator(nn.Module):
  """DCGAN-style generator.

  A fully connected + batch-norm projection turns the latent vector into a
  (gf_dim*8, s_h16, s_w16) feature map. Three transposed-conv + batch-norm
  stages and one final transposed convolution (each kernel 4, stride 2,
  padding 1) then produce a c_dim-channel output squashed by tanh.
  """

  def __init__(self):
    super().__init__()
    projected_features = s_h16 * s_w16 * gf_dim * 8
    self.fc_bn_layer1 = fc_bn_layer(latent_size, projected_features)
    self.up_sample_layer2 = tconv_bn_layer(gf_dim * 8, gf_dim * 4, 4, stride=2, padding=1)
    self.up_sample_layer3 = tconv_bn_layer(gf_dim * 4, gf_dim * 2, 4, stride=2, padding=1)
    self.up_sample_layer4 = tconv_bn_layer(gf_dim * 2, gf_dim, 4, stride=2, padding=1)
    # Last stage has no batch norm; its output goes straight into tanh.
    self.up_sample_layer5 = tconv_layer(gf_dim, c_dim, 4, stride=2, padding=1)
    self.tanh = nn.Tanh()

  def forward(self, x):
    """Map latent vectors ``x`` to generated images with values in (-1, 1)."""
    hidden = F.relu(self.fc_bn_layer1(x))
    # Reshape the flat projection into a channel-first feature map.
    hidden = hidden.view(-1, gf_dim * 8, s_h16, s_w16)
    hidden_stages = (
        self.up_sample_layer2,
        self.up_sample_layer3,
        self.up_sample_layer4,
    )
    for up_sample in hidden_stages:
      hidden = F.relu(up_sample(hidden))
    return self.tanh(self.up_sample_layer5(hidden))

In [0]:
"""
batch normalization changes the form of the discriminator's problem from mapping a single input to a single output to mapping from an entire batch of inputs to a batch of outputs.
Since we penalize the norm of the critic's gradient with respect to each input independently, and not the entire batch, our penalized training objective is no longer valid in this setting.
"""
class Critic(nn.Module):
  """WGAN-GP critic using layer normalisation instead of batch normalisation.

  Four stride-2 convolutions (kernel 4, padding 1) reduce the input; the
  second to fourth are followed by LayerNorm over the full (C, H, W) feature
  map. A final linear layer maps the flattened features to one unbounded
  score per sample.
  """

  def __init__(self):
    super().__init__()
    self.down_sample_layer1 = conv_layer(c_dim, df_dim, 4, stride=2, padding=1)
    self.down_sample_layer2 = conv_layer(df_dim, df_dim * 2, 4, stride=2, padding=1)
    self.norm_layer2 = norm_layer([df_dim * 2, s_h4, s_w4], mode='ln')
    self.down_sample_layer3 = conv_layer(df_dim * 2, df_dim * 4, 4, stride=2, padding=1)
    self.norm_layer3 = norm_layer([df_dim * 4, s_h8, s_w8], mode='ln')
    self.down_sample_layer4 = conv_layer(df_dim * 4, df_dim * 8, 4, stride=2, padding=1)
    self.norm_layer4 = norm_layer([df_dim * 8, s_h16, s_w16], mode='ln')
    self.fc_layer5 = fc_layer(df_dim * 8 * s_h16 * s_w16, 1)

  def forward(self, x):
    """Return the critic score for each sample in ``x``."""
    # First stage: convolution + leaky ReLU, no normalisation.
    hidden = F.leaky_relu(self.down_sample_layer1(x), 0.2)
    normalised_stages = (
        (self.down_sample_layer2, self.norm_layer2),
        (self.down_sample_layer3, self.norm_layer3),
        (self.down_sample_layer4, self.norm_layer4),
    )
    for down_sample, normalise in normalised_stages:
      hidden = F.leaky_relu(normalise(down_sample(hidden)), 0.2)
    # Keep the batch dimension, flatten everything else for the linear head.
    flat = hidden.flatten(1)
    return self.fc_layer5(flat)


In [0]:
import torch.optim as optim

# Instantiate both networks on the selected device; the critic is bound to `fw`.
G = Generator().to(device)
fw = Critic().to(device)

# Both networks are optimised with Adam using the same learning rate, betas and weight decay.
G_optimizer = optim.Adam(G.parameters(),lr=learning_rate,betas=(beta1,beta2),weight_decay=1e-3) # when using layer (or batch) normalization, weight decay is recommended
critic_optimizer = optim.Adam(fw.parameters(),lr=learning_rate,betas=(beta1,beta2),weight_decay=1e-3)

# Mean-squared error; the training loop uses it for the gradient-penalty term.
L2_criterion = nn.MSELoss()

# One fixed batch of latent vectors, reused whenever sample images are rendered.
fixed_noise = torch.randn(batch_size, latent_size,device=device)

In [0]:
def weights_init(m):
  """Initialise conv and normalisation weights; meant for ``Module.apply``.

  Modules whose class name contains 'Conv' get weights drawn from
  N(0, 0.02); modules whose class name contains 'Norm' get weights drawn
  from N(1, 0.02). Every other module is left untouched.

  Modules without a learnable ``weight`` are skipped. This covers
  normalisation layers built without affine parameters, such as the
  ``nn.InstanceNorm2d`` created by ``norm_layer(mode='in')``, which
  previously raised AttributeError on ``m.weight.data``.
  """
  classname = m.__class__.__name__
  weight = getattr(m, 'weight', None)
  if weight is None:
    return
  if classname.find('Conv')!=-1:
    nn.init.normal_(weight.data,0.0,0.02)
  elif classname.find('Norm')!=-1:
    nn.init.normal_(weight.data,1.0,0.02)

# Run weights_init over every sub-module of both networks and print the
# resulting module trees.
print(G.apply(weights_init))
print(fw.apply(weights_init))

In [0]:
# Baseline: render the generator's output for the fixed noise before training.
with torch.no_grad():
  fake_batch=G(fixed_noise)
plt.figure(figsize=(8,8))
plt.axis('off')
plt.title('fake Images Before learning')
# NOTE(review): view() hard-codes 64 samples; this matches batch_size = 64 in
# the config cell and would need updating if that value changes.
plt.imshow(np.transpose(vutils.make_grid(fake_batch.to(device).view(64,c_dim,in_h,in_w),padding=2, normalize=True).cpu(),(1,2,0)))

In [0]:
# Image grids rendered from the fixed noise during training.
img_list = []

# Histories, extended once every `iter_per_plot` training iterations.
G_losses, critic_losses = [], []
w_losses, gp_losses = [], []

# Logging interval (in iterations) and the resulting number of log points per epoch.
iter_per_plot = 500
plot_per_eps = int(len(train_loader) / iter_per_plot)

# Converts an image tensor to a PIL image so it can be saved to disk.
transform_PIL = transforms.ToPILImage()

In [0]:
import pickle
def log_list_save(l,file_name):
  """Pickle ``l`` to ``<log_PATH>/<file_name>.logs``, overwriting any existing file."""
  target_path = os.path.join(log_PATH, file_name + ".logs")
  with open(target_path, "wb") as handle:
    pickle.dump(l, handle)

def log_list_load(file_name):
  """Unpickle and return the object stored in ``<log_PATH>/<file_name>.logs``.

  Only use this on log files written by this notebook: unpickling a file
  from an untrusted source can execute arbitrary code.
  """
  source_path = os.path.join(log_PATH, file_name + ".logs")
  with open(source_path, "rb") as handle:
    restored = pickle.load(handle)
  return restored

In [0]:
import torch.autograd as autograd
from torch.autograd import Variable

# WGAN-GP training loop.
#
# Every iteration performs one critic update (Wasserstein loss + gradient
# penalty); every `n_critic`-th iteration additionally performs one generator
# update. Every `iter_per_plot` iterations the current losses are printed and
# stored, a sample grid is rendered from `fixed_noise`, and the loss
# histories plus both state dicts are written to `log_PATH`.

# Create the log directory up front so the first image/checkpoint save cannot
# fail on a missing path after hundreds of training iterations.
os.makedirs(log_PATH, exist_ok=True)

for ep in range(epochs):
  for i, (real_data, _) in enumerate(train_loader):
    b_size=real_data.shape[0]
    real_data = real_data.to(device)
    z = torch.randn(b_size,latent_size).to(device)
    fake_data = G(z)
    # One interpolation coefficient per sample, expanded to the image shape.
    alpha = torch.rand([b_size,1,1,1],device=device)
    alpha = alpha.expand(real_data.size())

    # ---- Critic update ----
    fw.zero_grad()

    # Gradient penalty: evaluate the critic on random interpolates between
    # real and fake samples and penalise the squared deviation of the
    # per-sample gradient norm from 1.
    # detach() + requires_grad_() replaces the deprecated Variable/.data idiom.
    interpolates = alpha * real_data.detach() + (1-alpha)*fake_data.detach()
    interpolates.requires_grad_(True)

    disc_interpolates = fw(interpolates)

    # create_graph=True keeps the graph so the penalty itself can be
    # back-propagated into the critic parameters.
    gradients = autograd.grad(outputs=disc_interpolates, inputs=interpolates,
                              grad_outputs=torch.ones(disc_interpolates.size(),device=device),
                              create_graph=True)[0].view(interpolates.size(0),-1)
    slopes = gradients.norm(2,dim=1)
    lipschitz_gradient_norm = torch.ones(slopes.size(),device=device)

    loss_gp = gp_lambda*L2_criterion(slopes,lipschitz_gradient_norm)

    real_critic = fw(real_data)
    # The fake batch is detached so this step does not back-propagate into G.
    fake_critic = fw(fake_data.detach())
    loss_w = -(real_critic.mean()-fake_critic.mean())

    loss_critic = loss_w + loss_gp
    loss_critic.backward()
    critic_optimizer.step()

    # ---- Generator update (once every n_critic critic updates) ----
    if (i+1)%n_critic==0:
      G.zero_grad()
      fwg = fw(fake_data)

      loss_G = -fwg.mean()

      loss_G.backward()
      G_optimizer.step()

    # ---- Periodic logging / checkpointing ----
    if (i+1)%iter_per_plot == 0:
      # Generator objective (-fw(G(z))) on this iteration's fake batch.
      # Previously the un-negated critic score was printed and stored as the
      # generator loss, i.e. with the wrong sign. Computing it from
      # fake_critic also avoids depending on loss_G, which only exists after
      # the first generator step.
      g_loss_value = -fake_critic.mean().item()
      # The epoch is printed 1-based, consistent with the 1-based step counter.
      print('Epoch [{}/{}], Step [{}/{}], critic_loss: {:.4f}, g_loss: {:.4f}, fw(x): {:.4f}, fw(G(z)): {:.4f}, gp : {:.4f} '
            .format(ep+1, epochs, i+1, len(train_loader), loss_critic.item(), g_loss_value,
                    real_critic.mean().item(), fake_critic.mean().item(),loss_gp.item()))
      G_losses.append(g_loss_value)
      critic_losses.append(loss_critic.item())

      w_losses.append(loss_w.item())
      gp_losses.append(loss_gp.item())

      with torch.no_grad():
        G.eval()
        fake = G(fixed_noise).detach().cpu()
        # `fake` already has the image shape and as many rows as fixed_noise.
        # The former reshape used the current batch size `b_size`, which
        # fails whenever the current (e.g. last, shorter) batch differs.
        img_list.append(vutils.make_grid(fake[:64], padding=2, normalize=True))
        transform_PIL(img_list[-1]).save(os.path.join(log_PATH,str(ep)+modelName+"_Last.png"))
        G.train()

      # log_list_save already prefixes log_PATH and appends ".logs".
      log_list_save(G_losses,"G_losses")
      log_list_save(critic_losses,"critic_losses")
      log_list_save(w_losses,"w_losses")
      log_list_save(gp_losses,"gp_losses")

      torch.save(G.state_dict(),os.path.join(log_PATH,"G"+modelName+".pth"))
      torch.save(fw.state_dict(),os.path.join(log_PATH,"fw"+modelName+".pth"))

In [0]:
# Render the generator's output for the fixed noise after training.
with torch.no_grad():
  # Evaluation mode while sampling; training mode is restored right after.
  G.eval()
  fake_batch=G(fixed_noise)
  G.train()
plt.figure(figsize=(8,8))
plt.axis('off')
plt.title('fake Images')
# NOTE(review): the slice and view() hard-code 64 samples, matching
# batch_size = 64 in the config cell.
plt.imshow(np.transpose(vutils.make_grid(fake_batch.to(device)[:64].view(64,c_dim,in_h,in_w),padding=2, normalize=True).cpu(),(1,2,0)))

In [0]:
# Side-by-side figure: a fresh batch of real images next to the most recent
# generated grid stored in img_list.
real_batch = next(iter(train_loader))

# Plot the real images
fig=plt.figure(figsize=(15,8))
plt.subplot(1,2,1)
plt.axis("off")
plt.title("Real Images")
plt.imshow(np.transpose(vutils.make_grid(real_batch[0].to(device)[:64], padding=5, normalize=True).cpu(),(1,2,0)))

# Plot the most recently recorded grid of fake images
plt.subplot(1,2,2)
plt.axis("off")
plt.title("Fake Images")
plt.imshow(np.transpose(img_list[-1],(1,2,0)))

# Save the comparison figure under log_PATH with the base name "Compare".
fig.savefig(os.path.join(log_PATH,"Compare"))


In [0]:
import matplotlib.animation as animation
from IPython.display import HTML
# Animate the sample grids recorded during training, one frame per snapshot.
# The frames come from `img_list`, which the training loop fills. The name
# `short_img_list` used here before is not defined anywhere in this notebook,
# so the cell failed on a fresh kernel.
fig = plt.figure(figsize=(8,8))
plt.axis("off")
ims = [[plt.imshow(np.transpose(frame,(1,2,0)), animated=True)] for frame in img_list]
ani = animation.ArtistAnimation(fig, ims, interval=1000, repeat_delay=1000, blit=True)

HTML(ani.to_jshtml())

In [0]:
from PIL import Image
# Export the recorded sample grids as an animated GIF and save the last grid
# as a PNG.
transform_PIL=transforms.ToPILImage()

# Frames come from `img_list`, filled by the training loop. The previously
# used `short_img_list` is not defined anywhere in this notebook.
p_img_list = [transform_PIL(p_image) for p_image in img_list]
# Pillow interprets `duration` in milliseconds; the former value 0.5 left
# frames with effectively no display time. 500 ms shows each frame for half
# a second.
p_img_list[0].save(os.path.join(log_PATH,modelName+'s.gif'), save_all=True,append_images=p_img_list[1:], optimize=False, duration=500, loop=0)
p_img_list[-1].save(os.path.join(log_PATH,modelName+"_last_result.png"))

In [0]:
# Draw fresh random latent vectors and display the first generated image.
with torch.no_grad():
  # Evaluation mode while sampling; training mode is restored right after.
  G.eval()
  random_noise=torch.randn(batch_size,latent_size,device=device)
  fake=G(random_noise)
  G.train()
fake = fake.squeeze().cpu()
# Undo the input normalisation for the first image of the batch before plotting.
fake_image=transform_inverse(fake[0])
plt.axis("off")
plt.imshow(np.transpose(fake_image,(1,2,0)))

In [0]:
# Generator and critic loss over training.
plt.title("Losses")
# One point is logged every `iter_per_plot` iterations, i.e. `plot_per_eps`
# points per epoch, so index / plot_per_eps is the position in epochs.
# The axis is derived from the recorded history instead of a hard-coded 50
# epochs, which did not match `epochs` and made x and y lengths differ.
X = np.arange(len(G_losses))/plot_per_eps

plt.plot(X,G_losses,label="G loss")
# Fixed: `i_critic_losses` is not defined; the history is `critic_losses`.
plt.plot(X,critic_losses,label="critic loss")
plt.legend(loc=2)
plt.xticks(np.arange(0,epochs+1,5))
plt.ylabel("loss")
plt.xlabel("Epochs")
#plt.show()
# plt.savefig(os.path.join(modelName+"_loss_figure.png"))

In [0]:
# Critic loss and gradient-penalty term over training.
plt.title("Losses")
# One point is logged every `iter_per_plot` iterations, i.e. `plot_per_eps`
# points per epoch, so index / plot_per_eps is the position in epochs.
# The axis is derived from the recorded history instead of a hard-coded 50
# epochs, which did not match `epochs` and made x and y lengths differ.
X = np.arange(len(critic_losses))/plot_per_eps

# Fixed: `i_critic_losses` is not defined; the history is `critic_losses`.
plt.plot(X,critic_losses,label="critic loss")
plt.plot(X,gp_losses,label="gp loss")
plt.legend(loc=2)
plt.xticks(np.arange(0,epochs+1,5))
plt.ylabel("loss")
plt.xlabel("Epochs")
#plt.show()
plt.savefig(os.path.join(modelName+"_loss_figure_gp,critic.png"))

In [0]:
# Generator loss over training.
plt.title("Losses")
# One point is logged every `iter_per_plot` iterations, i.e. `plot_per_eps`
# points per epoch, so index / plot_per_eps is the position in epochs.
# The axis is derived from the recorded history instead of a hard-coded 50
# epochs, which did not match `epochs` and made x and y lengths differ.
X = np.arange(len(G_losses))/plot_per_eps

plt.plot(X,G_losses,label="G loss")
plt.legend(loc=2)
plt.xticks(np.arange(0,epochs+1,5))
plt.ylabel("loss")
plt.xlabel("Epochs")
#plt.show()
plt.savefig(os.path.join(modelName+"G_loss_figure.png"))

In [0]:
# Critic loss over training.
plt.title("Losses")
# One point is logged every `iter_per_plot` iterations, i.e. `plot_per_eps`
# points per epoch, so index / plot_per_eps is the position in epochs.
# The axis is derived from the recorded history instead of a hard-coded 50
# epochs, which did not match `epochs` and made x and y lengths differ.
X = np.arange(len(critic_losses))/plot_per_eps
# Fixed: `i_critic_losses` is not defined; the history is `critic_losses`.
plt.plot(X,critic_losses,label="critic loss")

plt.legend(loc=2)
plt.xticks(np.arange(0,epochs+1,5))
plt.ylabel("loss")
plt.xlabel("Epochs")
#plt.show()
plt.savefig(os.path.join(modelName+"critic_loss_figure.png"))

In [0]:
# Gradient-penalty term over training.
plt.title("Losses")
# One point is logged every `iter_per_plot` iterations, i.e. `plot_per_eps`
# points per epoch, so index / plot_per_eps is the position in epochs.
# The axis is derived from the recorded history instead of a hard-coded 50
# epochs, which did not match `epochs` and made x and y lengths differ.
X = np.arange(len(gp_losses))/plot_per_eps
# Fixed: the gradient-penalty curve was labelled "critic loss".
plt.plot(X,gp_losses,label="gp loss")

plt.legend(loc=2)
plt.xticks(np.arange(0,epochs+1,5))
plt.ylabel("loss")
plt.xlabel("Epochs")
#plt.show()
plt.savefig(os.path.join(modelName+"gp_loss_figure.png"))

In [0]:
# Write the final generator weights under log_PATH, then load them back from
# the same location.
G_save_path = os.path.join(log_PATH, "./G_" + modelName + ".pth")
torch.save(G.state_dict(), G_save_path)
G_load_path = os.path.join(log_PATH, ("G_" + modelName + ".pth"))
G.load_state_dict(torch.load(G_load_path))

In [0]:
# Write the final critic weights under log_PATH, then load them back from the
# same location.
critic_save_path = os.path.join(log_PATH, "./critic_" + modelName + ".pth")
torch.save(fw.state_dict(), critic_save_path)
critic_load_path = os.path.join(log_PATH, ("critic_" + modelName + ".pth"))
fw.load_state_dict(torch.load(critic_load_path))