In [1]:
from __future__ import division
from __future__ import print_function

import time
import argparse
import numpy as np

import torch
import torch.nn.functional as F
import torch.optim as optim
from torchvision.utils import make_grid, save_image

from tensorboardX import SummaryWriter
from tqdm import tqdm
from copy import deepcopy

from utils import *
from models import *
from fid_score import *
from inception_score import *

!mkdir checkpoint
!mkdir generated_imgs
!pip install tensorboardX
!mkdir fid_stat
%cd fid_stat
%cd ..

mkdir: cannot create directory ‘checkpoint’: File exists
mkdir: cannot create directory ‘generated_imgs’: File exists
mkdir: cannot create directory ‘fid_stat’: File exists
/home/jishnu/Projects/FB/transgan/fid_stat
/home/jishnu/Projects/FB/transgan


In [2]:
# training hyperparameters given by code author

lr_gen = 0.0001 #Learning rate for generator
lr_dis = 0.0001 #Learning rate for discriminator
latent_dim = 1024 #Latent dimension
gener_batch_size = 32 #Batch size for generator
dis_batch_size = 32 #Batch size for discriminator
epoch = 10 #Number of epoch
weight_decay = 1e-3 #Weight decay
drop_rate = 0.5 #dropout
n_critic = 5 #
max_iter = 500000
img_name = "img_name"
lr_decay = True

# architecture details by authors
image_size = 32 #H,W size of image for discriminator
initial_size = 8 #Initial size for generator
patch_size = 4 #Patch size for generated image
num_classes = 1 #Number of classes for discriminator 
output_dir = 'checkpoint' #saved model path
dim = 384 #Embedding dimension 
optimizer = 'Adam' #Optimizer
loss = "wgangp_eps" #Loss function
phi = 1 #
beta1 = 0 #
beta2 = 0.99 #
diff_aug = "translation,cutout,color" #data augmentation

In [3]:
if torch.cuda.is_available():
    dev = "cuda:0"
else:
    dev = "cpu"

device = torch.device(dev)

with torch.no_grad():
    generator= Generator(depth1=5, depth2=4, depth3=2, initial_size=8, dim=384, heads=4, mlp_ratio=4, drop_rate=0.5)#,device = device)
    generator.to(device)

    generator.load_state_dict(torch.load('../transgan_models/cifar10_mixed/0/checkpoint.pth')['generator_state_dict'])

In [4]:
from net.densenet import densenet121
from net.resnet import resnet50, resnet110
from net.wide_resnet import wrn_28_10
from net.inception import inceptionv3
from net.vgg import vgg16
from utils.ensemble_utils import ensemble_forward_pass

ensemble = []

with torch.no_grad():
    densenet_model = densenet121().cuda()
    resnet50_model = resnet50().cuda()
    resnet110_model = resnet110().cuda()
    wide_resnet_model = wrn_28_10().cuda()
    inception_v3_model = inceptionv3().cuda()
    vgg16_model = vgg16().cuda()

    densenet_model.load_state_dict(torch.load('../ood_ensemble/densenet121/densenet121_1_350.model'))
    resnet50_model.load_state_dict(torch.load('../ood_ensemble/resnet50/resnet50_1_350.model'))
    resnet110_model.load_state_dict(torch.load('../ood_ensemble/resnet110/resnet110_1_350.model'))
    wide_resnet_model.load_state_dict(torch.load('../ood_ensemble/wide_resnet/wide_resnet_1_350.model'))
    inception_v3_model.load_state_dict(torch.load('../ood_ensemble/inception_v3/inception_v3_1_350.model'))
    vgg16_model.load_state_dict(torch.load('../ood_ensemble/vgg16/vgg16_1_350.model'))

    ensemble = [densenet_model,
                resnet50_model,
                resnet110_model,
                wide_resnet_model,
                inception_v3_model,
                vgg16_model]

In [5]:
sel_images = []
generator.eval()
total_images = 0
with torch.no_grad():
    while(True):
        noise = torch.cuda.FloatTensor(np.random.normal(0, 1, (50, latent_dim)))#noise(img, latent_dim)#= args.latent_dim)
        gen_images = generator(noise)
        _, _, mut_info = ensemble_forward_pass(ensemble, gen_images)
        sel_gen_images = gen_images[mut_info < 0.1]
        total_images += sel_gen_images.shape[0]
        sel_images.append(sel_gen_images)
        print (total_images)
        if (total_images >= 5000):
            break

10
15
23
32
37
43
47
53
58
62
72
81
86
91
101
109
113
123
128
134
140
146
153
157
164
167
177
182
192
196
204
215
221
227
232
238
246
256
266
273
276
282
287
292
296
303
310
320
325
331
337
340
348
353
357
362
367
373
376
383
390
395
400
404
407
414
417
424
428
431
436
442
445
454
460
465
469
476
484
488
491
496
499
507
514
521
527
533
540
546
557
566
571
579
588
596
598
603
607
611
617
622
625
636
641
647
651
658
666
673
679
687
693
700
706
715
721
728
734
742
747
757
764
770
774
781
784
788
799
805
815
821
827
836
844
849
851
858
864
873
878
883
892
899
906
911
915
922
926
931
938
944
950
953
963
967
971
977
979
983
987
988
994
1000
1004
1013
1021
1029
1035
1040
1047
1052
1060
1067
1076
1081
1087
1090
1098
1100
1108
1113
1116
1126
1130
1137
1144
1156
1160
1169
1176
1184
1192
1198
1200
1208
1218
1224
1231
1235
1244
1249
1259
1267
1273
1283
1294
1299
1308
1317
1324
1330
1333
1336
1345
1349
1354
1359
1364
1367
1373
1383
1387
1392
1397
1399
1403
1408
1413
1423
1429
1436
1441
1444
1452
14

In [6]:
sel_images = torch.cat(sel_images, dim=0)
print (sel_images.shape)

torch.Size([5004, 3, 32, 32])


In [7]:
torch.save(sel_images.cpu(), './sel_images_1.pt')

In [8]:
print (sel_images.min())
print (sel_images.max())

tensor(-2.1381, device='cuda:0')
tensor(2.2282, device='cuda:0')


In [9]:
sel_images_vis = sel_images[:200]
save_image(sel_images_vis, f'generated_images/sel_gen_images_0_1.jpg', nrow=10, normalize=True, scale_each=True)

In [6]:
with torch.no_grad():
    gen_images = generator(noise)
    print (gen_images.shape)

torch.Size([100, 3, 32, 32])


In [None]:
with torch.no_grad():
    mean_pred, pred_entropy, mut_info = ensemble_forward_pass(ensemble, gen_images)

In [7]:
print (gen_images.min())
print (gen_images.max())

tensor(-1.3904, device='cuda:0')
tensor(1.4602, device='cuda:0')


In [8]:
save_image(gen_images, f'generated_images/generated_img_1.jpg', nrow=5, normalize=True, scale_each=True)

In [8]:
from net.densenet import densenet121
from net.resnet import resnet50, resnet110
from net.wide_resnet import wrn_28_10
from net.inception import inceptionv3
from net.vgg import vgg16

In [9]:
ensemble = []

with torch.no_grad():
    densenet_model = densenet121().cuda()
    resnet50_model = resnet50().cuda()
    resnet110_model = resnet110().cuda()
    wide_resnet_model = wrn_28_10().cuda()
    inception_v3_model = inceptionv3().cuda()
    vgg16_model = vgg16().cuda()

    densenet_model.load_state_dict(torch.load('../ood_ensemble/densenet121/densenet121_1_350.model'))
    resnet50_model.load_state_dict(torch.load('../ood_ensemble/resnet50/resnet50_1_350.model'))
    resnet110_model.load_state_dict(torch.load('../ood_ensemble/resnet110/resnet110_1_350.model'))
    wide_resnet_model.load_state_dict(torch.load('../ood_ensemble/wide_resnet/wide_resnet_1_350.model'))
    inception_v3_model.load_state_dict(torch.load('../ood_ensemble/inception_v3/inception_v3_1_350.model'))
    vgg16_model.load_state_dict(torch.load('../ood_ensemble/vgg16/vgg16_1_350.model'))

    ensemble = [densenet_model,
                resnet50_model,
                resnet110_model,
                wide_resnet_model,
                inception_v3_model,
                vgg16_model]

In [10]:
from utils.ensemble_utils import ensemble_forward_pass

with torch.no_grad():
    mean_pred, pred_entropy, mut_info = ensemble_forward_pass(ensemble, gen_images)

In [14]:
print (mut_info > 0.9)

tensor([False, False, False, False, False, False, False, False, False, False,
        False, False, False,  True,  True, False, False, False, False, False,
        False, False,  True, False, False], device='cuda:0')
