In [37]:
from itertools import islice
import matplotlib.pyplot as plt
import sys

from tqdm import tqdm_notebook
import torch
from torchvision import models, transforms, datasets

In [26]:
# Preferred GPU ordinal; change this to match the machine the notebook runs on.
GPU_INDEX = 2

# Name the device type explicitly instead of relying on a bare int, and fall back
# to the CPU when the requested GPU is not present so the notebook still runs.
device = torch.device(
    'cuda:{}'.format(GPU_INDEX)
    if torch.cuda.is_available() and torch.cuda.device_count() > GPU_INDEX
    else 'cpu')
device.type

'cuda'

In [2]:
# Preprocessing applied to every image before it reaches Inception-v3:
# Resize(299) followed by ToTensor().
#
# Steps that are currently disabled (kept here for reference):
#   transforms.CenterCrop(constants.INPUT_SIZE)                                  -- after Resize
#   transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])  -- after ToTensor
#
# NOTE(review): get_inception_features builds the network with transform_input=True,
# which in torchvision appears to assume ImageNet-normalized input -- confirm whether
# the Normalize step should be re-enabled.
inception_transforms = transforms.Compose(
    [transforms.Resize(299), transforms.ToTensor()]
)

In [15]:
def _folder_with_loader(root):
    """Open the image folder at `root` and wrap it in a DataLoader.

    The dataset is built with `inception_transforms` and printed; the loader
    yields one shuffled image per batch using a single worker process.

    Returns:
        (dataset, loader) tuple.
    """
    dataset = datasets.ImageFolder(root, inception_transforms)
    print(dataset)
    # batch_size=1: the feature hooks in get_inception_features squeeze away the
    # batch dimension and are documented there as breaking for larger batches.
    loader = torch.utils.data.DataLoader(
        dataset, batch_size=1, shuffle=True, num_workers=1)
    return dataset, loader


unlabeled_celeba, unlabeled_celeba_loader = _folder_with_loader(
    'imgs_by_label/celeba_unlabeled/')

labeled_celeba, labeled_celeba_loader = _folder_with_loader(
    'imgs_by_label/celeba_labeled/')

labeled_progan, labeled_progan_loader = _folder_with_loader(
    'imgs_by_label/progan_labeled/')

Dataset ImageFolder
    Number of datapoints: 734
    Root Location: imgs_by_label/celeba_unlabeled/
    Transforms (if any): Compose(
                             Resize(size=299, interpolation=PIL.Image.BILINEAR)
                             ToTensor()
                         )
    Target Transforms (if any): None
Dataset ImageFolder
    Number of datapoints: 66
    Root Location: imgs_by_label/celeba_labeled/
    Transforms (if any): Compose(
                             Resize(size=299, interpolation=PIL.Image.BILINEAR)
                             ToTensor()
                         )
    Target Transforms (if any): None
Dataset ImageFolder
    Number of datapoints: 2233
    Root Location: imgs_by_label/progan_labeled/
    Transforms (if any): Compose(
                             Resize(size=299, interpolation=PIL.Image.BILINEAR)
                             ToTensor()
                         )
    Target Transforms (if any): None


In [31]:
def get_inception_features(img_iter, device=None):
    """Run images through a pretrained Inception-v3 and collect intermediate activations.

    Forward hooks are attached to seven layers (Conv2d_1a_3x3, Conv2d_2b_3x3,
    Conv2d_3b_1x1, Mixed_5d, Mixed_6e, Mixed_7c and fc). Every forward pass
    appends that batch's activations to the per-layer result list.

    Args:
        img_iter: iterable yielding ``(images, labels)`` pairs. ``images`` is a
            batch tensor fed straight to the network; ``labels`` are ignored.
        device: device on which to run the network. ``None`` (the default)
            leaves both the model and the inputs where they already are.

    Returns:
        A list with one entry per hooked layer, in the order given above. Each
        entry is a list with one CPU tensor per batch: activations with more
        than two dimensions are flattened from ``(B, C, ...)`` to
        ``(B*..., C)`` (one row per spatial position); two-dimensional
        activations (``fc``) are stored as-is. If ``img_iter`` is empty every
        entry is an empty list.
    """
    inception_net = models.inception_v3(pretrained=True, transform_input=True)
    if device is not None:
        # The inputs are moved to `device` below, so the weights must live there too.
        inception_net = inception_net.to(device)
    inception_net.eval()

    layers_to_grab = [inception_net.Conv2d_1a_3x3, inception_net.Conv2d_2b_3x3,
                      inception_net.Conv2d_3b_1x1, inception_net.Mixed_5d,
                      inception_net.Mixed_6e, inception_net.Mixed_7c,
                      inception_net.fc]

    layer_features = [[] for _ in layers_to_grab]

    def make_hook(layer_index):
        """Return a forward hook that records the output of layer `layer_index`."""
        def hook(module, inp, out):
            out = out.detach()
            if out.dim() > 2:
                # Move channels last, then collapse batch and spatial dims:
                # (B, C, H, W) -> (B, H, W, C) -> (B*H*W, C). Works for any batch size
                # and matches the old squeeze/permute result when B == 1.
                num_channels = out.shape[1]
                channels_last = out.permute(0, *range(2, out.dim()), 1)
                out = channels_last.reshape(-1, num_channels)
            # Keep results on the CPU so accumulated features do not fill GPU memory.
            layer_features[layer_index].append(out.cpu())
        return hook

    handles = [layer.register_forward_hook(make_hook(i))
               for i, layer in enumerate(layers_to_grab)]

    try:
        # Inference only: without no_grad every stored activation would keep its
        # autograd graph alive.
        with torch.no_grad():
            for x, _ in tqdm_notebook(img_iter):
                if device is not None:
                    x = x.to(device)
                inception_net(x)
    finally:
        for handle in handles:
            handle.remove()

    return layer_features

In [32]:
#unlabeled_celeba_features = get_inception_features(unlabeled_celeba_loader)

#flat_unlabeled_celeba_features = [torch.cat(lf, dim=0) for lf in unlabeled_celeba_features]

#print([(len(lf), lf[0].shape) for lf in unlabeled_celeba_features])
#print([lf.shape for lf in flat_unlabeled_celeba_features])

#torch.save(flat_unlabeled_celeba_features, 'flat_unlabeled_celeba_features.pt')
# The features from these 734 reference images are 8.9 gigs on disk, yikes!

HBox(children=(IntProgress(value=0, max=734), HTML(value='')))

In [None]:
#TODO: modify the get_inception_features function to have a "don't flatten" mode for these?
#Or just feed them in one at a time so we don't care about the flattening. (So only mess with
#it if we need batch size > 1)
#celeba_features = get_inception_features(labeled_celeba_loader)

In [None]:
#progan_features_50 = get_inception_features(labeled_progan_loader)

In [None]:
# Features (for single image): #layers x (H*W for that layer) x (C for that layer)
# Reference set (for N comparison images): # layers x (N*H*W for that layer) x (C for that layer)
def layerwise_nn_features(features, reference_set):
    """Mean nearest-neighbour distance to a reference set, computed per layer.

    For every layer, each feature vector of the query image is compared with
    every reference vector of the same layer; the Euclidean distance to the
    closest reference vector is taken and those distances are averaged.

    Args:
        features: sequence with one tensor per layer. The last dimension is the
            channel dimension C; all leading dimensions are treated as
            positions, so ``(H, W, C)``, ``(H*W, C)`` and ``(1, C)`` are all
            accepted.
        reference_set: sequence with one ``(N, C)`` tensor per layer, with the
            same number of layers and the same C per layer as ``features``.

    Returns:
        1-D float tensor of length ``len(features)`` holding the mean
        closest-reference distance for each layer.

    Note:
        A ``(P, N, C)`` difference tensor is materialized per layer (P = number
        of query positions), so memory grows with P * N * C.
    """
    assert len(features) == len(reference_set)
    L = len(features)
    mean_layer_closest_dists = torch.zeros(L)

    for l in range(L):
        lf = features[l]  # layer features
        rlf = reference_set[l]  # reference layer features

        C = lf.shape[-1]
        N, C2 = rlf.shape
        assert C == C2

        # Flatten all positions into one axis: (P, C).
        positions = lf.reshape(-1, C)
        P = positions.shape[0]

        # Broadcast (P, 1, C) against (1, N, C) to get every pairwise difference.
        diffs = positions.reshape(P, 1, C) - rlf.reshape(1, N, C)
        assert diffs.shape == (P, N, C)

        sqr_dists = torch.sum(diffs ** 2, dim=2)
        assert sqr_dists.shape == (P, N)

        # torch.min along a dim returns (values, indices); only the values are needed.
        min_sqr_dists, _ = torch.min(sqr_dists, dim=1)
        assert min_sqr_dists.shape == (P,)

        min_dists = torch.sqrt(min_sqr_dists)
        assert min_dists.shape == (P,)

        mean_layer_closest_dists[l] = torch.mean(min_dists)

    return mean_layer_closest_dists

In [None]:
labeled_celeba_x = []
labeled_celeba_y = []
# TODO: Pull % rated as real for each image!
# NOTE(review): `pct_real_votes` and `unlabeled_celeba_features` are not defined in any
# visible cell (the cell that computed the latter is commented out), so this cell
# cannot run until both exist.
# NOTE(review): `x` is the raw image batch from the loader, while
# layerwise_nn_features expects per-layer feature tensors -- presumably `x` should
# go through get_inception_features first; confirm.
# Uses tqdm_notebook because that is the only tqdm name this notebook imports.
for x,y in tqdm_notebook(labeled_celeba_loader):
    cur_features = layerwise_nn_features(x, unlabeled_celeba_features)
    labeled_celeba_x.append(cur_features)

    # Now pull the % label
    labeled_celeba_y.append(pct_real_votes)
    

In [None]:
labeled_progan_x = []
labeled_progan_y = []
# TODO: Pull % rated as real for each image!
# NOTE(review): `pct_real_votes` and `unlabeled_celeba_features` are not defined in any
# visible cell (the cell that computed the latter is commented out), so this cell
# cannot run until both exist.
# NOTE(review): `x` is the raw image batch from the loader, while
# layerwise_nn_features expects per-layer feature tensors -- presumably `x` should
# go through get_inception_features first; confirm.
# Uses tqdm_notebook because that is the only tqdm name this notebook imports.
for x,y in tqdm_notebook(labeled_progan_loader):
    cur_features = layerwise_nn_features(x, unlabeled_celeba_features)
    labeled_progan_x.append(cur_features)

    # Now pull the % label
    labeled_progan_y.append(pct_real_votes)
    

In [None]:
#TODO: Break the features/labels into train/val/test and train a logistic regression model;
#see how well it does out of sample!