In [None]:
import json
import logging
import os
import sys
import tempfile

import numpy as np
import torch
import torchvision
from PIL import Image
from torchvision import transforms
from tqdm.auto import tqdm

from pytorch_slim_cnn.slimnet import SlimNet

# Make CUDA device 0 the current GPU; the model and every input batch are later
# moved to torch.device('cuda').
torch.cuda.set_device(0)

In [None]:
class UnsupervisedImageFolder(torchvision.datasets.ImageFolder):
    """Label-free view of a directory of images.

    ``ImageFolder`` expects ``root/<class_name>/<image>``. Here ``root`` is
    symlinked as a single pseudo-class named ``dummy`` inside a private
    temporary directory. That temporary directory is what ``ImageFolder``
    indexes, so the images found under ``root`` are returned without a label.

    Args:
        root: directory holding the images. A relative path is made absolute
            before symlinking.
        transform: transform applied to each loaded image.
        max_size: if not None and the folder holds more than ``max_size``
            images, only a random subset of ``max_size`` of them is exposed.
        get_path: if True, items are ``(sample, path)`` pairs; otherwise just
            ``sample``.
    """

    def __init__(self, root, transform=None, max_size=None, get_path=False):
        source_root = root  # kept for log messages; ImageFolder sees the temp dir
        self.temp_dir = tempfile.TemporaryDirectory()
        # A relative symlink target is resolved against the link's own directory
        # (the temp dir), not the CWD, so make the target absolute first.
        os.symlink(os.path.abspath(source_root),
                   os.path.join(self.temp_dir.name, 'dummy'))
        super().__init__(self.temp_dir.name, transform=transform)
        self.get_path = get_path
        self.perm = None  # optional index remapping used when subsampling
        if max_size is not None:
            actual_size = super().__len__()
            if actual_size > max_size:
                # Not seeded here: the subset depends on torch's global RNG state.
                self.perm = torch.randperm(actual_size)[:max_size].clone()
                logging.info(f"{source_root} has {actual_size} images, downsample to {max_size}")
            else:
                logging.info(f"{source_root} has {actual_size} images <= max_size={max_size}")

    def _find_classes(self, dir):
        # Single pseudo-class: the 'dummy' symlink created in __init__.
        # NOTE(review): newer torchvision calls `find_classes` instead of this
        # hook; its default would also find only the 'dummy' entry — confirm.
        return ['./dummy'], {'./dummy': 0}

    def __getitem__(self, key):
        """Return the image at ``key``, plus its file path when ``get_path``."""
        if self.perm is not None:
            key = self.perm[key].item()
        sample = super().__getitem__(key)[0]  # drop the constant pseudo-class target
        if self.get_path:
            path, _ = self.samples[key]
            return sample, path
        return sample

    def __len__(self):
        """Number of exposed images (the subset size when subsampled)."""
        if self.perm is not None:
            return self.perm.size(0)
        return super().__len__()

In [None]:
# Preprocessing applied to every image before it is fed to SlimNet: resize,
# convert to a tensor, then map each of the 3 channels from [0, 1] to [-1, 1].
# NOTE(review): torchvision's Resize takes (h, w), so this yields 178 rows x 218
# columns; presumably this mirrors SlimNet's training preprocessing — confirm
# before changing it.
transform = transforms.Compose(
    [
        transforms.Resize((178, 218)),
        transforms.ToTensor(),
        transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
    ]
)


In [None]:
# Device for the model and for every input batch (the current CUDA GPU).
device = torch.device('cuda')

In [None]:
# The 40 CelebA attribute names. Boolean per-output prediction masks from the
# model index into this array, so its order must presumably match the model's
# output order — confirm against the SlimNet checkpoint before editing.
labels = np.array("""
    5_o_Clock_Shadow Arched_Eyebrows Attractive Bags_Under_Eyes Bald Bangs
    Big_Lips Big_Nose Black_Hair Blond_Hair Blurry Brown_Hair Bushy_Eyebrows
    Chubby Double_Chin Eyeglasses Goatee Gray_Hair Heavy_Makeup High_Cheekbones
    Male Mouth_Slightly_Open Mustache Narrow_Eyes No_Beard Oval_Face Pale_Skin
    Pointy_Nose Receding_Hairline Rosy_Cheeks Sideburns Smiling Straight_Hair
    Wavy_Hair Wearing_Earrings Wearing_Hat Wearing_Lipstick Wearing_Necklace
    Wearing_Necktie Young
""".split())

In [None]:
# Pretrained SlimNet attribute classifier, moved to `device` and switched to
# inference (eval) mode.
model = SlimNet.load_pretrained('./pytorch_slim_cnn/models/celeba_20.pth')
model = model.to(device).eval()


In [None]:
# Dataset name -> directory of images to label. Each name is also the stem of the
# label file written for it (`<name>.json`).
data = {
    'ffhq/mustache/poisson': '/data/vision/torralba/ganprojects/placesgan/tracer/baselines/pyflow/mustaches/poisson',
    'ffhq/mustache/laplace': '/data/vision/torralba/ganprojects/placesgan/tracer/baselines/pyflow/mustaches/laplace',
    'ffhq/mustache/naive': '/data/vision/torralba/ganprojects/placesgan/tracer/baselines/pyflow/mustaches/naive',
    'ffhq/mustache/ours': '/data/vision/torralba/ganprojects/placesgan/tracer/utils/samples/edited/',
    'ffhq/mustache/ours_stdcovariance': '/data/vision/torralba/distillation/gan_rewriting/results/ablations/stylegan-celebhq-mustache-11-1-10001-0.01-ours-100-stdcovariance/images',
    'ffhq/mustache/overfit': '/data/vision/torralba/distillation/gan_rewriting/results/ablations/stylegan-celebhq-mustache-11-1-2001-0.0001-overfit/images',
    'ffhq/mustache/multikey_ours': '/data/vision/torralba/distillation/gan_rewriting/results/ablations/stylegan-celebhq-multikey_mustache-11-1-2001-0.05-ours-10/images',
    'ffhq/mustache/multikey_overfit': '/data/vision/torralba/distillation/gan_rewriting/results/ablations/stylegan-celebhq-multikey_mustache-11-1-2001-0.0001-overfit/images',
    'ffhq/clean': '/data/vision/torralba/ganprojects/placesgan/tracer/utils/samples/clean/',
    'ffhq/smiling/ours': '/data/vision/torralba/ganprojects/placesgan/tracer/utils/samples/edited_smiles/',
    'ffhq/smiling/ours_stdcovariance': '/data/vision/torralba/distillation/gan_rewriting/results/ablations/stylegan-celebhq-smile-10-1-2001-0.05-ours-10-stdcovariance/images',
    'ffhq/smiling/ours_stdcovariance_FIXED': '/data/vision/torralba/distillation/gan_rewriting/results/ablations/stylegan-celebhq-smile-10-1-2001-0.05-ours-10-stdcovariance-sseed/images',
    'ffhq/smiling/poisson': '/data/vision/torralba/ganprojects/placesgan/tracer/baselines/pyflow/smiles/poisson',
    'ffhq/smiling/laplace': '/data/vision/torralba/ganprojects/placesgan/tracer/baselines/pyflow/smiles/laplace',
    'ffhq/smiling/naive': '/data/vision/torralba/ganprojects/placesgan/tracer/baselines/pyflow/smiles/naive',
    # NOTE(review): unlike the other ablation entries this path has no trailing
    # '/images' — confirm it points at the intended directory.
    'ffhq/smiling/overfit': '/data/vision/torralba/distillation/gan_rewriting/results/ablations/stylegan-celebhq-smile-10-1-2001-0.0001-overfit',
    'ffhq/smiling/overfit_FIXED': '/data/vision/torralba/distillation/gan_rewriting/results/ablations/stylegan-celebhq-smile-10-1-2001-0.0001-overfit-sseed/images',
}

In [None]:
def get_info_path(PATH, batch_size=512, num_workers=20):
    """Predict attribute names for every image under a directory.

    Uses the notebook-level ``model``, ``transform``, ``device`` and ``labels``.

    Args:
        PATH: image directory, read through ``UnsupervisedImageFolder``.
        batch_size: images per forward pass (default 512).
        num_workers: DataLoader worker processes (default 20; lower this on
            machines with fewer cores).

    Returns:
        dict mapping each image's file name without extension to the list of
        attribute names whose sigmoid score exceeds 0.5. Keys are bare file
        stems, so two images with the same file name in different
        subdirectories collide (the later one wins).
    """
    dataset = UnsupervisedImageFolder(PATH, transform=transform, get_path=True)
    loader = torch.utils.data.DataLoader(
        dataset, num_workers=num_workers, batch_size=batch_size, pin_memory=True)

    info = {}
    with torch.no_grad():
        for x, paths in tqdm(loader):
            logits = model(x.to(device))
            sigmoid_logits = torch.sigmoid(logits)
            # One boolean row per image: True where the attribute is predicted.
            predictions = (sigmoid_logits > 0.5).cpu().numpy().astype(bool)
            for path, p in zip(paths, predictions):
                k = os.path.splitext(os.path.basename(path))[0]
                info[k] = labels[p].tolist()

    return info

In [None]:
# List the contents of ffhq/ (e.g. to see which label JSON files already exist).
!ls ffhq

In [None]:
# Predictions for the real images; each value keeps its attribute list under
# ['image']['labels'].
with open('ffhq/real_labeled.json') as real_labels_file:
    realinfo = json.load(real_labels_file)

In [None]:
# Image keys with a mustache / with a beard / with either.
m, b, mb = [], [], []

In [None]:
# Split the real images by predicted facial hair. The lists are rebuilt here so
# re-running this cell does not append duplicate keys.
m, b, mb = [], [], []  # mustache / beard / mustache-or-beard image keys
for k, v in tqdm(realinfo.items()):
    attr_names = v['image']['labels']
    hasm = 'Mustache' in attr_names
    hasb = 'No_Beard' not in attr_names  # bearded == not labeled 'No_Beard'
    if hasm:
        m.append(k)
    if hasb:
        b.append(k)
    if hasm or hasb:
        mb.append(k)

In [None]:
# Save the real images predicted to have a mustache. The ffhq/mustache directory
# is otherwise only created by the labeling loop further down, so create it here.
os.makedirs('ffhq/mustache', exist_ok=True)
with open('ffhq/mustache/real_mustache_labeled.json', 'w') as f:
    json.dump({k: realinfo[k] for k in m}, f)

In [None]:
# Save the real images predicted to have a beard.
with open('ffhq/mustache/real_beard_labeled.json', mode='w') as out_file:
    json.dump(dict((key, realinfo[key]) for key in b), out_file)

In [None]:
# Save the real images predicted to have a mustache or a beard.
with open('ffhq/mustache/real_mustache_beard_labeled.json', mode='w') as out_file:
    json.dump(dict((key, realinfo[key]) for key in mb), out_file)

In [None]:
# Number of real images whose predicted labels include 'Mustache'.
len(m)

In [None]:
# Number of real images with a beard (labels do not include 'No_Beard').
len(b)

In [None]:
# Number of real images with a mustache or a beard.
len(mb)

In [None]:
def perc_attr(n, *attrs):
    """Fraction of entries in ``<n>.json`` whose labels contain every attribute in ``attrs``.

    ``<n>.json`` maps image keys to lists of predicted attribute names. Each
    requested attribute must appear in the notebook-level ``labels``
    (AssertionError otherwise). With no ``attrs`` every entry matches.
    An empty file raises ZeroDivisionError.
    """
    for attr in attrs:
        assert attr in labels
    with open(f'{n}.json', 'r') as f:
        info = json.load(f)
    hits = sum(1 for names in info.values() if all(attr in names for attr in attrs))
    return hits / len(info)

In [None]:
def mus(n):
    """Facial-hair statistics for the predictions stored in ``<n>.json``.

    ``<n>.json`` maps image keys to lists of predicted attribute names. An image
    counts as bearded when its labels do NOT include 'No_Beard'.

    Returns:
        (fraction with 'Mustache', fraction bearded, fraction with a mustache
        or a beard), each as a float.

    Raises:
        ZeroDivisionError: if the file holds no entries.
    """
    with open(f'{n}.json', 'r') as f:
        info = json.load(f)
    attr_lists = list(info.values())
    total = float(len(attr_lists))
    n_mustache = sum('Mustache' in v for v in attr_lists)
    n_beard = sum('No_Beard' not in v for v in attr_lists)
    n_either = sum('Mustache' in v or 'No_Beard' not in v for v in attr_lists)
    return n_mustache / total, n_beard / total, n_either / total

In [None]:
# Mustache / beard / either fractions for every mustache-edit dataset.
# Reads the `<name>.json` files written by the labeling loop further down.
for key in data:
    if 'mustache' not in key:
        continue
    print(key, *mus(key))

In [None]:
# (mustache, beard, mustache-or-beard) fractions for the ffhq/clean samples.
mus('ffhq/clean')

In [None]:
# Fraction of ffhq/mustache/ours images predicted to have a mustache.
perc_attr('ffhq/mustache/ours', 'Mustache')

In [None]:
# (fraction with 'Mustache', fraction bearded = 1 - fraction with 'No_Beard')
# for the ffhq/clean samples.
perc_attr('ffhq/clean', 'Mustache'), 1-perc_attr('ffhq/clean', 'No_Beard')

In [None]:
# Fraction of ffhq/smiling/ours_stdcovariance_FIXED images predicted as 'Smiling'.
perc_attr('ffhq/smiling/ours_stdcovariance_FIXED', 'Smiling')

In [None]:
# Fraction of ffhq/smiling/overfit_FIXED images predicted as 'Smiling'.
perc_attr('ffhq/smiling/overfit_FIXED', 'Smiling')

In [None]:
# Fraction of ffhq/smiling/ours_stdcovariance images predicted as 'Smiling'.
perc_attr('ffhq/smiling/ours_stdcovariance', 'Smiling')

In [None]:
# Fraction of ffhq/smiling/ours images predicted as 'Smiling'.
perc_attr('ffhq/smiling/ours', 'Smiling')

In [None]:
# Fraction of ffhq/smiling/overfit images predicted as 'Smiling'.
perc_attr('ffhq/smiling/overfit', 'Smiling')

In [None]:
# Fraction of ffhq/smiling/poisson images predicted as 'Smiling'.
perc_attr('ffhq/smiling/poisson', 'Smiling')

In [None]:
# Fraction of ffhq/smiling/naive images predicted as 'Smiling'.
perc_attr('ffhq/smiling/naive', 'Smiling')

In [None]:
# Fraction of ffhq/smiling/laplace images predicted as 'Smiling'.
perc_attr('ffhq/smiling/laplace', 'Smiling')

In [None]:
# Label every dataset in `data` and write its predictions to `<name>.json`
# (e.g. 'ffhq/mustache/ours' -> 'ffhq/mustache/ours.json').
# After the loop, `info` still holds the predictions of the last dataset.
for name, PATH in data.items():
    print(name)
    info = get_info_path(PATH)
    save = name + '.json'
    save_dir = os.path.dirname(save)
    if save_dir:  # os.makedirs('') raises, e.g. for a top-level name like 'clean'
        os.makedirs(save_dir, exist_ok=True)

    with open(save, 'w') as f:
        json.dump(info, f)

    print(f'saved to {save}')

In [None]:
# NOTE(review): `root` is reassigned (to the clean-samples directory) before it is
# ever read, so this assignment currently has no effect.
root = '/data/vision/torralba/ganprojects/placesgan/tracer/utils/samples/edited/'

In [None]:
def get_labels(img):
    """Predict attribute names for a single PIL image.

    Uses the notebook-level ``model``, ``transform``, ``device`` and ``labels``.
    ``transform`` normalizes exactly three channels, so the image is converted
    to RGB first; RGBA or grayscale inputs would otherwise have the wrong
    channel count.

    Returns:
        list of attribute names whose sigmoid score exceeds 0.5.
    """
    rgb = img.convert('RGB')
    with torch.no_grad():
        logits = model(transform(rgb)[None].to(device))  # add a batch dim of 1
        sigmoid_logits = torch.sigmoid(logits)
        # Take the single batch row explicitly rather than squeezing all dims.
        predictions = (sigmoid_logits > 0.5)[0].cpu().numpy().astype(bool)

    return labels[predictions].tolist()

In [None]:
# Load one clean sample, resized to 256x256, and display it.
root = '/data/vision/torralba/ganprojects/placesgan/tracer/utils/samples/clean/'
with Image.open(os.path.join(root, 'clean_1855.png')) as source_img:
    img = source_img.resize([256, 256])
img

In [None]:
# Predicted attributes for clean_1855.png, loaded in the previous cell.
get_labels(img)

In [None]:
# Load the matching edited_smiles sample, resized to 256x256, and display it.
root = '/data/vision/torralba/ganprojects/placesgan/tracer/utils/samples/edited_smiles/'
with Image.open(os.path.join(root, 'edited_smiles_1855.png')) as source_img:
    img = source_img.resize([256, 256])
img

In [None]:
# Predicted attributes for edited_smiles_1855.png, loaded in the previous cell.
get_labels(img)

In [None]:
# NOTE(review): `info` is whatever the labeling loop left behind, i.e. the
# predictions for the last entry of `data`; confirm this is the intended dataset.
sum('Mustache' in attr_list for attr_list in info.values())

In [None]:
# Keys of the images in `info` predicted to have a mustache.
# NOTE(review): `info` is leftover state from the labeling loop (last entry of
# `data`); confirm this is the intended dataset.
[key for key in info if 'Mustache' in info[key]]