In [1]:
from transformers import CLIPProcessor, CLIPModel
import torch
import torchvision
from torchvision.models import resnet50
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
import numpy as np
import clip
from PIL import Image
import requests
import torch.hub
import time
import pickle
import math
import torch.nn as nn

from utils import matching, stats, proggan, nethook, zdataset

In [2]:
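# Reference only (not executed): Generator written with a single `conv_blocks`
# Sequential. The checkpoint keys remapped further down (`conv_blocks.0`, `.2`,
# `.3`, `.6`, `.7`, `.9`) follow the layer indices of this layout.
# class Generator(nn.Module):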
#     def __init__(self):
#         super(Generator, self).__init__()

#         self.init_size = 32 // 4
#         self.l1 = nn.Sequential(nn.Linear(100, 128 * self.init_size ** 2))

#         self.conv_blocks = nn.Sequential(
#           nn.BatchNorm2d(128),
#           nn.Upsample(scale_factor=2),
#           nn.Conv2d(128, 128, 3, stride=1, padding=1),
#           nn.BatchNorm2d(128, 0.8),
#           nn.LeakyReLU(0.2, inplace=True),
#           nn.Upsample(scale_factor=2),
#           nn.Conv2d(128, 64, 3, stride=1, padding=1),
#           nn.BatchNorm2d(64, 0.8),
#           nn.LeakyReLU(0.2, inplace=True),
#           nn.Conv2d(64, 1, 3, stride=1, padding=1),
#           nn.Tanh(),
#         )

#     def forward(self, z):
#         out = self.l1(z)
#         out = out.view(out.shape[0], 128, self.init_size, self.init_size)
#         img = self.conv_blocks(out)
#         return img

In [3]:
class Generator(nn.Module):
    """DCGAN-style generator that maps a latent vector to a square image.

    The convolutional stack is split into three separately named stages
    (``conv_block1``, ``conv_block2`` and ``out``) instead of one Sequential,
    so each stage is its own submodule. With the default arguments the
    submodule names, parameter shapes and state-dict keys are the same as in
    the original hard-coded version.

    Args:
        latent_dim: Length of the input noise vector. Defaults to 100.
        img_size: Height and width of the generated image. Must be a positive
            multiple of 4, because the initial feature map is upsampled by a
            factor of 2 twice. Defaults to 32.
        channels: Number of channels in the generated image. Defaults to 1.

    Raises:
        ValueError: If ``img_size`` is not a positive multiple of 4.
    """

    def __init__(self, latent_dim=100, img_size=32, channels=1):
        super().__init__()

        if img_size <= 0 or img_size % 4 != 0:
            raise ValueError(
                "img_size must be a positive multiple of 4, got {!r}".format(img_size)
            )

        # Spatial size of the feature map before the two 2x upsampling steps.
        self.init_size = img_size // 4
        self.l1 = nn.Sequential(nn.Linear(latent_dim, 128 * self.init_size ** 2))

        self.conv_block1 = nn.Sequential(
            nn.BatchNorm2d(128),
            nn.Upsample(scale_factor=2),
            nn.Conv2d(128, 128, 3, stride=1, padding=1),
            # The second positional argument of BatchNorm2d is `eps` (not
            # momentum). It is spelled out by keyword here; the value 0.8 is
            # unchanged from the original positional call.
            nn.BatchNorm2d(128, eps=0.8),
            nn.LeakyReLU(0.2, inplace=True),
        )
        self.conv_block2 = nn.Sequential(
            nn.Upsample(scale_factor=2),
            nn.Conv2d(128, 64, 3, stride=1, padding=1),
            nn.BatchNorm2d(64, eps=0.8),  # eps, see note in conv_block1
            nn.LeakyReLU(0.2, inplace=True),
        )
        self.out = nn.Sequential(
            nn.Conv2d(64, channels, 3, stride=1, padding=1),
            nn.Tanh(),
        )

    def forward(self, z):
        """Generate images from latent vectors.

        Args:
            z: Tensor of shape ``(batch, latent_dim)``.

        Returns:
            Tensor of shape ``(batch, channels, img_size, img_size)``; values
            lie in [-1, 1] because the last layer is ``Tanh``.
        """
        out = self.l1(z)
        # Reshape the flat projection into a 128-channel feature map.
        out = out.view(out.shape[0], 128, self.init_size, self.init_size)
        out = self.conv_block1(out)
        out = self.conv_block2(out)
        img = self.out(out)
        return img

In [4]:
# Run on the first CUDA GPU when one is available; otherwise fall back to the
# CPU so the notebook still runs on machines without a GPU.
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

In [5]:
# Build the first generator on the selected device, then read its checkpoint.
# The checkpoint tensors are mapped onto the same device as the model.
gan1_ckpt_path = "/home/amildravid/CS496/weights/seed_0/DCGAN_EPOCH_199.pt"
gan1 = Generator()
gan1 = gan1.to(device)
ckpt = torch.load(gan1_ckpt_path, map_location=device)



In [6]:
# Translate checkpoint keys (single `conv_blocks` Sequential) into the key
# names used by the split Generator (`conv_block1`, `conv_block2`, `out`).
# Each row is (model prefix, checkpoint prefix, tensor names to copy).
_PARAM_FIELDS = ('weight', 'bias')
_BATCHNORM_FIELDS = _PARAM_FIELDS + ('running_mean', 'running_var', 'num_batches_tracked')
_KEY_LAYOUT = (
    ('l1.0', 'l1.0', _PARAM_FIELDS),
    ('conv_block1.0', 'conv_blocks.0', _BATCHNORM_FIELDS),
    ('conv_block1.2', 'conv_blocks.2', _PARAM_FIELDS),
    ('conv_block1.3', 'conv_blocks.3', _BATCHNORM_FIELDS),
    ('conv_block2.1', 'conv_blocks.6', _PARAM_FIELDS),
    ('conv_block2.2', 'conv_blocks.7', _BATCHNORM_FIELDS),
    ('out.0', 'conv_blocks.9', _PARAM_FIELDS),
)
sd = {
    model_prefix + '.' + field: ckpt[ckpt_prefix + '.' + field]
    for model_prefix, ckpt_prefix, fields in _KEY_LAYOUT
    for field in fields
}

In [7]:
# Submodule names of gan1 that are later handed to the activation matching call.
gan1_layers = ['conv_block1', 'conv_block2', 'out']
# Copy the remapped weights into the model; every key must match.
gan1.load_state_dict(sd, strict=True)
# Switch to inference mode; as the last expression this also displays the model.
gan1.eval()

Generator(
  (l1): Sequential(
    (0): Linear(in_features=100, out_features=8192, bias=True)
  )
  (conv_block1): Sequential(
    (0): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (1): Upsample(scale_factor=2.0, mode=nearest)
    (2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (3): BatchNorm2d(128, eps=0.8, momentum=0.1, affine=True, track_running_stats=True)
    (4): LeakyReLU(negative_slope=0.2, inplace=True)
  )
  (conv_block2): Sequential(
    (0): Upsample(scale_factor=2.0, mode=nearest)
    (1): Conv2d(128, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (2): BatchNorm2d(64, eps=0.8, momentum=0.1, affine=True, track_running_stats=True)
    (3): LeakyReLU(negative_slope=0.2, inplace=True)
  )
  (out): Sequential(
    (0): Conv2d(64, 1, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): Tanh()
  )
)

In [8]:
# Total number of trainable scalar parameters in gan1.
gan1_trainable_sizes = [param.numel() for param in gan1.parameters() if param.requires_grad]
sum(gan1_trainable_sizes)

1049985

In [9]:
# Build the second generator on the selected device, then read its checkpoint.
# NOTE(review): this is the same checkpoint file that gan1 loads, so gan2 ends
# up with the same weights as gan1 -- confirm whether comparing a model with
# itself is intended or whether a different checkpoint was meant here.
gan2_ckpt_path = "/home/amildravid/CS496/weights/seed_0/DCGAN_EPOCH_199.pt"
gan2 = Generator()
gan2 = gan2.to(device)
ckpt = torch.load(gan2_ckpt_path, map_location=device)

In [10]:
# Translate checkpoint keys (single `conv_blocks` Sequential) into the key
# names used by the split Generator (`conv_block1`, `conv_block2`, `out`).
# Each row is (model prefix, checkpoint prefix, tensor names to copy).
_PARAM_FIELDS = ('weight', 'bias')
_BATCHNORM_FIELDS = _PARAM_FIELDS + ('running_mean', 'running_var', 'num_batches_tracked')
_KEY_LAYOUT = (
    ('l1.0', 'l1.0', _PARAM_FIELDS),
    ('conv_block1.0', 'conv_blocks.0', _BATCHNORM_FIELDS),
    ('conv_block1.2', 'conv_blocks.2', _PARAM_FIELDS),
    ('conv_block1.3', 'conv_blocks.3', _BATCHNORM_FIELDS),
    ('conv_block2.1', 'conv_blocks.6', _PARAM_FIELDS),
    ('conv_block2.2', 'conv_blocks.7', _BATCHNORM_FIELDS),
    ('out.0', 'conv_blocks.9', _PARAM_FIELDS),
)
sd = {
    model_prefix + '.' + field: ckpt[ckpt_prefix + '.' + field]
    for model_prefix, ckpt_prefix, fields in _KEY_LAYOUT
    for field in fields
}

In [11]:
# Submodule names of gan2 that are later handed to the activation matching call.
gan2_layers = ['conv_block1', 'conv_block2', 'out']
# Copy the remapped weights into the model; every key must match.
gan2.load_state_dict(sd, strict=True)
# Switch to inference mode; as the last expression this also displays the model.
gan2.eval()

Generator(
  (l1): Sequential(
    (0): Linear(in_features=100, out_features=8192, bias=True)
  )
  (conv_block1): Sequential(
    (0): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (1): Upsample(scale_factor=2.0, mode=nearest)
    (2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (3): BatchNorm2d(128, eps=0.8, momentum=0.1, affine=True, track_running_stats=True)
    (4): LeakyReLU(negative_slope=0.2, inplace=True)
  )
  (conv_block2): Sequential(
    (0): Upsample(scale_factor=2.0, mode=nearest)
    (1): Conv2d(128, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (2): BatchNorm2d(64, eps=0.8, momentum=0.1, affine=True, track_running_stats=True)
    (3): LeakyReLU(negative_slope=0.2, inplace=True)
  )
  (out): Sequential(
    (0): Conv2d(64, 1, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): Tanh()
  )
)

In [12]:
# Settings passed to the activation matching run further down.
# NOTE(review): exact meaning of `epochs` and `batch_size` is defined inside
# utils.matching -- confirm there before changing them.
batch_size, epochs = 100, 100
save_path = "results_1"  # passed to the matching call as its save location argument

In [13]:
# Time the full activation matching run between gan1 and gan2.
# perf_counter() is a monotonic clock, so the measured duration is not affected
# by system clock adjustments that may happen during a long run.
start = time.perf_counter()
matching.activ_match_gan(gan1, gan1_layers,
                         gan2, gan2_layers,
                         epochs,
                         batch_size,
                         save_path,
                         device)
end = time.perf_counter()
# Elapsed time in seconds.
print(end - start)

Collecting Dataset Statistics
0
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
Done Iteration for Stats
Done
Starting Activation Matching
Iteration: 0
GAN1 Layer:0
________________________
GAN2 Layer 0
GAN2 Layer 1
GAN2 Layer 2
GAN1 Layer:1
________________________
GAN2 Layer 0
GAN2 Layer 1
GAN2 Layer 2
GAN1 Layer:2
________________________
GAN2 Layer 0
GAN2 Layer 1
GAN2 Layer 2
Iteration: 1
GAN1 Layer:0
________________________
GAN2 Layer 0
GAN2 Layer 1
GAN2 Layer 2
GAN1 Layer:1
________________________
GAN2 Layer 0
GAN2 Layer 1
GAN2 Layer 2
GAN1 Layer:2
________________________
GAN2 Layer 0
GAN2 Layer 1
GAN2 Layer 2
Iteration: 2
GAN1 Layer:0
________________________
GAN2 Layer 0
GAN2 Layer 1
GAN2 Layer 2
GAN1 Layer:1
________________________
GAN2 