# Examples: building and smoke-testing the ExtractGAN model

In [1]:
import torch
import torch.nn as nn
from torch.nn import init
from torch.nn import functional as F
import torchvision.models
import torchvision.transforms as transforms
from options.train_options import TrainOptions
import extract_gan as E
# Replace the process argument list with a single empty string, then drop the
# temporary `sys` name from the notebook namespace again.
# Presumably this keeps the option parser used in the next cell from seeing
# the notebook kernel's own command-line arguments -- TODO confirm.
import sys
sys.argv = ['']
del sys

In [2]:
# Parse the training options into `opt`. The option set is basically copied
# from CycleGAN, so the listing printed by this cell is very verbose; only a
# small subset of these options is useful for this model.
opt = TrainOptions().parse()

----------------- Options ---------------
               batch_size: 1                             
                    beta1: 0.5                           
          checkpoints_dir: ./checkpoints                 
           continue_train: False                         
             dataset_mode: unaligned                     
                direction: AtoB                          
              display_env: main                          
             display_freq: 400                           
               display_id: 1                             
            display_ncols: 4                             
             display_port: 8097                          
           display_server: http://localhost              
          display_winsize: 256                           
                    epoch: latest                        
              epoch_count: 1                             
                 fineSize: 256                           
                  gpu_ids: -1 

In [3]:
# Build the ExtractGAN model from the parsed options. The resulting object
# exposes the generator as `model.G` and the discriminator as `model.D`
# (both are used in the cells below); construction logs each sub-network as
# its weights are loaded or initialized.
model = E.ExtractGANModel(opt)

VGG parameters loaded.
Encoder weights initialized using xavier.
Decoder weights initialized using xavier.
StyleExtractor weights initialized using xavier.
StyleWhitener weights initialized using xavier.
Generator build success!
Discriminator weights initialized using xavier.
New ExtractGAN model initialized!


In [4]:
# Save the network parameters in ./checkpoints/ (the `checkpoints_dir` value
# in the options listing). Note that this cell writes to disk every time it runs.
# 'test' is the label handed to save_networks -- presumably it ends up in the
# checkpoint file names; confirm against the model's save_networks implementation.
model.save_networks('test')

In [5]:
# G's style extractor and D share one pretrained VGG-16: the identity check
# below evaluates to True, meaning both attributes refer to the very same
# module object rather than to two separate copies.
model.D.vgg16 is model.G.style_extractor.vgg16

True

In [6]:
# Generate two random input batches for the smoke tests below, each of shape
# (2, 3, 224, 224) -- presumably (batch, channels, height, width).
# The global RNG is seeded first so that re-running this cell (or the whole
# notebook) produces the same `a` and `b` every time.
# Note: this pins the inputs only. The networks were already randomly
# (xavier) initialized when the model was built, so the values printed by the
# following cells can still differ between kernel sessions.
torch.manual_seed(0)
a = torch.randn(2, 3, 224, 224)
b = torch.randn(2, 3, 224, 224)

In [7]:
# Smoke test for the generator: call G on the two random batches and display
# the returned tensor. This only demonstrates that the forward pass runs
# end to end on inputs of this shape; it says nothing about output quality.
model.G(a,b)

tensor([[[[0.4510, 0.4510, 0.4510,  ..., 0.4510, 0.4510, 0.4510],
          [0.4510, 0.4218, 0.3493,  ..., 0.4669, 0.5434, 0.4510],
          [0.4510, 0.4917, 0.2532,  ..., 0.2827, 0.4728, 0.4510],
          ...,
          [0.4510, 0.6842, 0.5601,  ..., 0.2391, 0.3749, 0.4510],
          [0.4510, 0.6204, 0.6330,  ..., 0.4565, 0.4432, 0.4510],
          [0.4510, 0.4510, 0.4510,  ..., 0.4510, 0.4510, 0.4510]],

         [[0.5186, 0.5186, 0.5186,  ..., 0.5186, 0.5186, 0.5186],
          [0.5186, 0.4924, 0.6310,  ..., 0.7007, 0.5704, 0.5186],
          [0.5186, 0.5589, 0.7337,  ..., 0.6737, 0.5053, 0.5186],
          ...,
          [0.5186, 0.6174, 0.6588,  ..., 0.4037, 0.3831, 0.5186],
          [0.5186, 0.5526, 0.5579,  ..., 0.4520, 0.4467, 0.5186],
          [0.5186, 0.5186, 0.5186,  ..., 0.5186, 0.5186, 0.5186]],

         [[0.4046, 0.4046, 0.4046,  ..., 0.4046, 0.4046, 0.4046],
          [0.4046, 0.3830, 0.2957,  ..., 0.3928, 0.4889, 0.4046],
          [0.4046, 0.3415, 0.2427,  ..., 0

In [8]:
# Smoke test for the discriminator: call D on the same pair of batches. It
# returns one sigmoid output per sample (a 2x1 tensor for these inputs).
# NOTE(review): both values print as 0.5000, i.e. sigmoid(0). That shows the
# forward pass runs, but it is worth confirming that D's output actually
# varies with its input and is not a constant.
model.D(a,b)

tensor([[0.5000],
        [0.5000]], grad_fn=<SigmoidBackward>)

In [9]:
# Print the discriminator's layer-by-layer module structure for inspection.
print(model.D)

Discriminator(
  (vgg16): Sequential(
    (0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU(inplace)
    (3): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (4): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (5): ReLU(inplace)
    (6): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (7): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (8): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (9): ReLU(inplace)
    (10): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (11): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (12): ReLU(inplace)
    (13): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (14): Conv2d(128, 256, kernel_size

In [10]:
# Print the generator's layer-by-layer module structure for inspection.
# NOTE(review): in the decoder listing, the final ConvTranspose2d(64, 3, ...)
# is followed by InstanceNorm2d(64, ...), whose feature count does not match
# the 3 output channels -- check the Decoder definition.
print(model.G)

Generator(
  (encoder): Encoder(
    (model): Sequential(
      (0): ZeroPad2d(padding=(1, 1, 1, 1), value=0)
      (1): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1))
      (2): InstanceNorm2d(64, eps=1e-05, momentum=0.1, affine=False, track_running_stats=False)
      (3): ReLU(inplace)
      (4): ZeroPad2d(padding=(1, 1, 1, 1), value=0)
      (5): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1))
      (6): InstanceNorm2d(64, eps=1e-05, momentum=0.1, affine=False, track_running_stats=False)
      (7): ReLU(inplace)
    )
  )
  (decoder): Decoder(
    (model): Sequential(
      (0): ZeroPad2d(padding=(1, 1, 1, 1), value=0)
      (1): ConvTranspose2d(64, 64, kernel_size=(3, 3), stride=(1, 1))
      (2): InstanceNorm2d(64, eps=1e-05, momentum=0.1, affine=False, track_running_stats=False)
      (3): ReLU()
      (4): ZeroPad2d(padding=(1, 1, 1, 1), value=0)
      (5): ConvTranspose2d(64, 3, kernel_size=(3, 3), stride=(1, 1))
      (6): InstanceNorm2d(64, eps=1e-05, momentum=0.1, affine