In [18]:
import csv
import os
from collections import defaultdict
from itertools import islice

import torch
from torchvision import models, transforms, datasets

In [3]:
dataset_counts = defaultdict(int)
img_counts = defaultdict(lambda : defaultdict(int))
unique_imgs = defaultdict(set)
img_votes = defaultdict(lambda: defaultdict(lambda: [0,0])) #no,yes tuples by dataset and image

with open('../all_gans_inf.csv') as f:
    reader = csv.DictReader(f)
    for row in islice(reader, None):
        img = row['img']        
#         dataset_name = img[:img.index('/')]
        split_name = img.split('/')
        assert(len(split_name) == 2)
        dataset_name, img_name = split_name
        
        vote_index = 1 if row['correctness'] == row['realness'] else 0
        img_votes[dataset_name][img_name][vote_index] += 1
        
        dataset_counts[dataset_name] += 1
        img_counts[dataset_name][img_name] += 1
unique_imgs = {dataset: len(img_counts[dataset]) for dataset in img_counts}

In [39]:
began_labeled = img_votes['began5000'].keys()
len(began_labeled)

2397

In [36]:
celeba_labeled = img_votes['styleganceleba5000'].keys()
celeba_all = os.listdir('../../64celeba_stylegan/')
print(len(celeba_labeled), len(celeba_all))
set(celeba_labeled).issubset(set(celeba_all))

3103 5000


True

In [38]:
began_all = os.listdir('imgs_by_label/began5000/')
len(began_all)

5000

In [37]:
set(began_labeled).difference(set(began_all))

set()

In [40]:
for f in celeba_labeled:
    os.system('cp ../../64celeba_stylegan/{} imgs_by_label/stylegan_labeled/stylegan_labeled_imgs/{}'.format(f,f))

In [40]:
began_include = set(began_labeled).intersection(set(began_all))
len(began_include)

2397

In [48]:
for f in began_include:
    os.system('cp ../../began_files/{} imgs_by_label/began_labeled/began_labeled_imgs/{}'.format(f,f))

In [47]:
%ls imgs_by_label/began_labeled/began_labeled_imgs

In [9]:
wgan_labeled = set(img_votes['wgangp5000'].keys())
len(wgan_labeled)

4251

In [10]:
wgan_all = set(os.listdir('../wgangp5000/'))
len(wgan_all)

6637

In [12]:
len(wgan_labeled.intersection(wgan_all))

4251

In [13]:
for f in wgan_labeled:
    os.system('cp ../wgangp5000/{} imgs_by_label/wgan_labeled/wgan_labeled_imgs/{}'.format(f,f))

In [14]:
wgan_labeled

{'wgan_gp_25849.png',
 'wgan_gp_23941.png',
 'wgan_gp_25400.png',
 'wgan_gp_4454.png',
 'wgan_gp_27329.png',
 'wgan_gp_37675.png',
 'wgan_gp_600.png',
 'wgan_gp_30287.png',
 'wgan_gp_16550.png',
 'wgan_gp_41388.png',
 'wgan_gp_4814.png',
 'wgan_gp_47476.png',
 'wgan_gp_2679.png',
 'wgan_gp_49641.png',
 'wgan_gp_26022.png',
 'wgan_gp_7651.png',
 'wgan_gp_31423.png',
 'wgan_gp_3880.png',
 'wgan_gp_10634.png',
 'wgan_gp_27864.png',
 'wgan_gp_16525.png',
 'wgan_gp_8221.png',
 'wgan_gp_3622.png',
 'wgan_gp_26871.png',
 'wgan_gp_7574.png',
 'wgan_gp_38541.png',
 'wgan_gp_26693.png',
 'wgan_gp_40330.png',
 'wgan_gp_38353.png',
 'wgan_gp_47815.png',
 'wgan_gp_16917.png',
 'wgan_gp_3790.png',
 'wgan_gp_10142.png',
 'wgan_gp_24527.png',
 'wgan_gp_46355.png',
 'wgan_gp_23459.png',
 'wgan_gp_48661.png',
 'wgan_gp_12759.png',
 'wgan_gp_44085.png',
 'wgan_gp_29323.png',
 'wgan_gp_8213.png',
 'wgan_gp_25872.png',
 'wgan_gp_1475.png',
 'wgan_gp_2895.png',
 'wgan_gp_39638.png',
 'wgan_gp_35313.png',
 '

In [19]:
labeled_wgan = datasets.ImageFolder('imgs_by_label/wgan_labeled/')
print(labeled_wgan)
labeled_wgan_loader = torch.utils.data.DataLoader(
        labeled_wgan, batch_size=1, shuffle=False, num_workers=1)

Dataset ImageFolder
    Number of datapoints: 4251
    Root location: imgs_by_label/wgan_labeled/


In [21]:
wgan_filenames = []
for i in range(len(labeled_wgan)):
    filename, _ = labeled_wgan_loader.dataset.samples[i]
    wgan_filenames.append(filename)

In [23]:
len(wgan_filenames)

4251

In [24]:
#split into train, val, test

#Let's split such that train set is 0 through 7 mod 10, val is 8 mod 10, test is 9 mod 10
N = len(wgan_filenames)
train_indices = [i for i in range(N) if i % 10 in range(8)]
val_indices =   [i for i in range(N) if i % 10 == 8]
test_indices =  [i for i in range(N) if i % 10 == 9]

ltrain = len(train_indices)
lval = len(val_indices)
ltest = len(test_indices)
print(ltrain, lval, ltest, ltrain+lval+ltest, N)

train_files = [wgan_filenames[i] for i in train_indices]
val_files = [wgan_filenames[i] for i in val_indices]
test_files = [wgan_filenames[i] for i in test_indices]


3401 425 425 4251 4251


In [26]:
len(train_files), len(val_files), len(test_files)

(3401, 425, 425)

In [27]:
with open('wgan_train_set.txt', 'w') as f:
    for x in train_files:
        f.write(x + '\n')
        
with open('wgan_val_set.txt', 'w') as f:
    for x in val_files:
        f.write(x + '\n')
        
with open('wgan_test_set.txt', 'w') as f:
    for x in test_files:
        f.write(x + '\n')

In [48]:
with open('began_train_set.txt') as f:
    began_files = [os.path.split(x.strip())[-1] for x in f.readlines()]
        
with open('began_val_set.txt') as f:
    began_files += [os.path.split(x.strip())[-1] for x in f.readlines()]
    
with open('began_test_set.txt') as f:
    began_files += [os.path.split(x.strip())[-1] for x in f.readlines()]

len(began_files)

1966

In [49]:
for f in began_files:
    os.system('cp {} {}'.format('./imgs_by_label/began5000/' + f, 'imgs_by_label/began_labeled/began_labeled_imgs/'))