In [1]:
from __future__ import division
from __future__ import print_function

import time
import argparse
import numpy as np

import torch
import torch.nn.functional as F
import torch.optim as optim
from torchvision.utils import make_grid, save_image

from tensorboardX import SummaryWriter
from tqdm import tqdm
from copy import deepcopy

from utils import *
from models import *
from fid_score import *
from inception_score import *

!mkdir checkpoint
!mkdir generated_imgs
!pip install tensorboardX
!mkdir fid_stat
%cd fid_stat
%cd ..

mkdir: cannot create directory ‘checkpoint’: File exists
mkdir: cannot create directory ‘generated_imgs’: File exists
mkdir: cannot create directory ‘fid_stat’: File exists
/home/jishnu/Projects/FB/transgan/fid_stat
/home/jishnu/Projects/FB/transgan


In [2]:
# training hyperparameters given by code author

lr_gen = 0.0001 #Learning rate for generator
lr_dis = 0.0001 #Learning rate for discriminator
latent_dim = 1024 #Latent dimension
gener_batch_size = 32 #Batch size for generator
dis_batch_size = 32 #Batch size for discriminator
epoch = 10 #Number of epoch
weight_decay = 1e-3 #Weight decay
drop_rate = 0.5 #dropout
n_critic = 5 #
max_iter = 500000
img_name = "img_name"
lr_decay = True

# architecture details by authors
image_size = 32 #H,W size of image for discriminator
initial_size = 8 #Initial size for generator
patch_size = 4 #Patch size for generated image
num_classes = 1 #Number of classes for discriminator 
output_dir = 'checkpoint' #saved model path
dim = 384 #Embedding dimension 
optimizer = 'Adam' #Optimizer
loss = "wgangp_eps" #Loss function
phi = 1 #
beta1 = 0 #
beta2 = 0.99 #
diff_aug = "translation,cutout,color" #data augmentation

In [3]:
if torch.cuda.is_available():
    dev = "cuda:0"
else:
    dev = "cpu"

device = torch.device(dev)

with torch.no_grad():
    generator= Generator(depth1=5, depth2=4, depth3=2, initial_size=8, dim=384, heads=4, mlp_ratio=4, drop_rate=0.5)#,device = device)
    generator.to(device)

    generator.load_state_dict(torch.load('../transgan_models/cifar10_mixed/0/checkpoint.pth')['generator_state_dict'])

In [4]:
from net.densenet import densenet121
from net.resnet import resnet50, resnet110
from net.wide_resnet import wrn_28_10
from net.inception import inceptionv3
from net.vgg import vgg16
from utils.ensemble_utils import ensemble_forward_pass

ensemble = []

with torch.no_grad():
    densenet_model = densenet121().cuda()
    resnet50_model = resnet50().cuda()
    resnet110_model = resnet110().cuda()
    wide_resnet_model = wrn_28_10().cuda()
    inception_v3_model = inceptionv3().cuda()
    vgg16_model = vgg16().cuda()

    densenet_model.load_state_dict(torch.load('../ood_ensemble/densenet121/densenet121_1_350.model'))
    resnet50_model.load_state_dict(torch.load('../ood_ensemble/resnet50/resnet50_1_350.model'))
    resnet110_model.load_state_dict(torch.load('../ood_ensemble/resnet110/resnet110_1_350.model'))
    wide_resnet_model.load_state_dict(torch.load('../ood_ensemble/wide_resnet/wide_resnet_1_350.model'))
    inception_v3_model.load_state_dict(torch.load('../ood_ensemble/inception_v3/inception_v3_1_350.model'))
    vgg16_model.load_state_dict(torch.load('../ood_ensemble/vgg16/vgg16_1_350.model'))

    ensemble = [densenet_model,
                resnet50_model,
                resnet110_model,
                wide_resnet_model,
                inception_v3_model,
                vgg16_model]

In [10]:
sel_images = []
generator.eval()
total_images = 0
with torch.no_grad():
    while(True):
        noise = torch.cuda.FloatTensor(np.random.normal(0, 1, (50, latent_dim)))#noise(img, latent_dim)#= args.latent_dim)
        gen_images = generator(noise)
        _, _, mut_info = ensemble_forward_pass(ensemble, gen_images)
        sel_gen_images = gen_images[mut_info < 0.1]
        total_images += sel_gen_images.shape[0]
        sel_images.append(sel_gen_images)
        print (total_images)
        if (total_images >= 5000):
            break

3
10
12
18
26
30
34
40
46
48
57
65
72
79
83
91
96
98
106
109
112
119
122
128
136
141
145
149
155
158
163
167
169
174
180
185
192
196
201
212
218
220
229
237
246
254
257
266
269
278
283
288
292
296
304
309
315
321
323
328
336
339
344
352
356
360
362
372
379
385
390
399
405
411
419
426
429
431
437
447
454
459
462
467
472
481
487
490
497
501
506
513
526
531
533
540
544
551
554
560
565
572
575
582
588
592
597
603
609
610
615
622
631
634
641
648
655
659
666
672
677
683
687
694
699
702
708
717
727
735
741
748
751
756
759
765
771
779
788
796
803
806
810
815
823
825
832
837
843
845
852
859
862
866
877
881
884
888
896
899
908
912
920
929
939
948
958
964
971
976
984
989
995
999
1006
1014
1020
1023
1031
1040
1044
1051
1055
1058
1065
1067
1070
1078
1088
1094
1101
1109
1116
1121
1128
1135
1142
1145
1151
1161
1167
1169
1175
1182
1190
1192
1199
1208
1213
1222
1225
1229
1234
1238
1245
1253
1258
1262
1265
1271
1279
1289
1293
1305
1312
1318
1322
1328
1330
1335
1342
1347
1354
1360
1365
1375
1378
1384
138

In [11]:
sel_images = torch.cat(sel_images, dim=0)
print (sel_images.shape)

torch.Size([5003, 3, 32, 32])


In [12]:
torch.save(sel_images.cpu(), './sel_images_2.pt')

In [13]:
print (sel_images.min())
print (sel_images.max())

tensor(-1.9734, device='cuda:0')
tensor(2.2485, device='cuda:0')


In [14]:
sel_images_vis = sel_images[:200]
save_image(sel_images_vis, f'generated_images/sel_gen_images_0_2.jpg', nrow=10, normalize=True, scale_each=True)

In [6]:
with torch.no_grad():
    gen_images = generator(noise)
    print (gen_images.shape)

torch.Size([100, 3, 32, 32])


In [None]:
with torch.no_grad():
    mean_pred, pred_entropy, mut_info = ensemble_forward_pass(ensemble, gen_images)

In [7]:
print (gen_images.min())
print (gen_images.max())

tensor(-1.3904, device='cuda:0')
tensor(1.4602, device='cuda:0')


In [8]:
save_image(gen_images, f'generated_images/generated_img_1.jpg', nrow=5, normalize=True, scale_each=True)

In [8]:
from net.densenet import densenet121
from net.resnet import resnet50, resnet110
from net.wide_resnet import wrn_28_10
from net.inception import inceptionv3
from net.vgg import vgg16

In [9]:
ensemble = []

with torch.no_grad():
    densenet_model = densenet121().cuda()
    resnet50_model = resnet50().cuda()
    resnet110_model = resnet110().cuda()
    wide_resnet_model = wrn_28_10().cuda()
    inception_v3_model = inceptionv3().cuda()
    vgg16_model = vgg16().cuda()

    densenet_model.load_state_dict(torch.load('../ood_ensemble/densenet121/densenet121_1_350.model'))
    resnet50_model.load_state_dict(torch.load('../ood_ensemble/resnet50/resnet50_1_350.model'))
    resnet110_model.load_state_dict(torch.load('../ood_ensemble/resnet110/resnet110_1_350.model'))
    wide_resnet_model.load_state_dict(torch.load('../ood_ensemble/wide_resnet/wide_resnet_1_350.model'))
    inception_v3_model.load_state_dict(torch.load('../ood_ensemble/inception_v3/inception_v3_1_350.model'))
    vgg16_model.load_state_dict(torch.load('../ood_ensemble/vgg16/vgg16_1_350.model'))

    ensemble = [densenet_model,
                resnet50_model,
                resnet110_model,
                wide_resnet_model,
                inception_v3_model,
                vgg16_model]

In [10]:
from utils.ensemble_utils import ensemble_forward_pass

with torch.no_grad():
    mean_pred, pred_entropy, mut_info = ensemble_forward_pass(ensemble, gen_images)

In [14]:
print (mut_info > 0.9)

tensor([False, False, False, False, False, False, False, False, False, False,
        False, False, False,  True,  True, False, False, False, False, False,
        False, False,  True, False, False], device='cuda:0')
