In [29]:
from itertools import islice
import matplotlib.pyplot as plt

from tqdm import tqdm
import torch
from torchvision import models, transforms, datasets

In [2]:
inception_transforms = transforms.Compose([
            transforms.Resize(299),
            #transforms.CenterCrop(constants.INPUT_SIZE),
            transforms.ToTensor(),
            #transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
        ])

In [3]:
unlabeled_celeba = datasets.ImageFolder('imgs_by_label/celeba_unlabeled/', inception_transforms)
unlabeled_celeba

Dataset ImageFolder
    Number of datapoints: 734
    Root location: imgs_by_label/celeba_unlabeled/

In [4]:
unlabeled_celeba_loader = torch.utils.data.DataLoader(
        unlabeled_celeba, batch_size=1, shuffle=True, num_workers=1)

In [30]:
def get_inception_features(img_iter):
    inception_net = models.inception_v3(pretrained=True, transform_input=True)
    
    layers_to_grab = [inception_net.Conv2d_1a_3x3, inception_net.Conv2d_2b_3x3,
                 inception_net.Conv2d_3b_1x1, inception_net.Mixed_5d, inception_net.Mixed_6e,
                 inception_net.Mixed_7c, inception_net.fc]
    
    layer_features = [None for i in range(len(layers_to_grab))]
    
    
    def hook_fn(self, inp, out, container, layer_index):
        #print(layer_index, inp[0].shape, out.shape)

        num_channels = out.shape[1]
        if len(out.shape) > 2:
            #Warning: this will break for batch sizes > 1
            cur_features = out.squeeze().permute(1,2,0).reshape(-1, num_channels)
        else:
            cur_features = out

        if container[layer_index] is None:
            container[layer_index] = cur_features
        else:
            container[layer_index] = torch.cat((container[layer_index], cur_features))

    def hook_fn_i(container, i):
        return lambda self, inp, out: hook_fn(self, inp, out, container, i)

    for i, layer in enumerate(layers_to_grab):
        layer.register_forward_hook(hook_fn_i(layer_features, i))
        
    inception_net.eval()

    for x,y in tqdm(img_iter):
        print(x.shape, y)
        #plt.imshow((x).squeeze().permute(1, 2, 0))
        #plt.show()
        out = inception_net(x)
        #print(out.sum())
        
    return layer_features

In [None]:
unlabeled_celeba_features = get_inception_features(unlabeled_celeba_loader)

  0%|          | 0/734 [00:00<?, ?it/s]

torch.Size([1, 3, 299, 299]) tensor([0])


  0%|          | 1/734 [00:00<06:33,  1.86it/s]

torch.Size([1, 3, 299, 299]) tensor([0])


  0%|          | 2/734 [00:01<06:33,  1.86it/s]

torch.Size([1, 3, 299, 299]) tensor([0])


  0%|          | 3/734 [00:01<06:21,  1.91it/s]

torch.Size([1, 3, 299, 299]) tensor([0])


  1%|          | 4/734 [00:02<06:21,  1.91it/s]

torch.Size([1, 3, 299, 299]) tensor([0])


  1%|          | 5/734 [00:02<06:25,  1.89it/s]

torch.Size([1, 3, 299, 299]) tensor([0])


  1%|          | 6/734 [00:03<06:25,  1.89it/s]

torch.Size([1, 3, 299, 299]) tensor([0])


  1%|          | 7/734 [00:03<06:26,  1.88it/s]

torch.Size([1, 3, 299, 299]) tensor([0])


  1%|          | 8/734 [00:04<06:34,  1.84it/s]

torch.Size([1, 3, 299, 299]) tensor([0])


  1%|          | 9/734 [00:04<06:40,  1.81it/s]

torch.Size([1, 3, 299, 299]) tensor([0])


  1%|▏         | 10/734 [00:05<07:41,  1.57it/s]

torch.Size([1, 3, 299, 299]) tensor([0])


  1%|▏         | 11/734 [00:06<07:58,  1.51it/s]

torch.Size([1, 3, 299, 299]) tensor([0])


  2%|▏         | 12/734 [00:07<08:02,  1.50it/s]

torch.Size([1, 3, 299, 299]) tensor([0])


  2%|▏         | 13/734 [00:07<08:33,  1.40it/s]

torch.Size([1, 3, 299, 299]) tensor([0])


  2%|▏         | 14/734 [00:08<08:26,  1.42it/s]

torch.Size([1, 3, 299, 299]) tensor([0])


  2%|▏         | 15/734 [00:09<08:12,  1.46it/s]

torch.Size([1, 3, 299, 299]) tensor([0])


  2%|▏         | 16/734 [00:10<09:02,  1.32it/s]

torch.Size([1, 3, 299, 299]) tensor([0])


  2%|▏         | 17/734 [00:11<10:20,  1.15it/s]

torch.Size([1, 3, 299, 299]) tensor([0])


  2%|▏         | 18/734 [00:12<10:23,  1.15it/s]

torch.Size([1, 3, 299, 299]) tensor([0])


  3%|▎         | 19/734 [00:12<10:13,  1.17it/s]

torch.Size([1, 3, 299, 299]) tensor([0])


  3%|▎         | 20/734 [00:13<10:11,  1.17it/s]

torch.Size([1, 3, 299, 299]) tensor([0])


  3%|▎         | 21/734 [00:14<09:58,  1.19it/s]

torch.Size([1, 3, 299, 299]) tensor([0])


  3%|▎         | 22/734 [00:15<10:47,  1.10it/s]

torch.Size([1, 3, 299, 299]) tensor([0])


  3%|▎         | 23/734 [00:16<11:07,  1.07it/s]

torch.Size([1, 3, 299, 299]) tensor([0])


  3%|▎         | 24/734 [00:17<11:31,  1.03it/s]

torch.Size([1, 3, 299, 299]) tensor([0])


  3%|▎         | 25/734 [00:18<11:46,  1.00it/s]

torch.Size([1, 3, 299, 299]) tensor([0])


  4%|▎         | 26/734 [00:19<11:35,  1.02it/s]

torch.Size([1, 3, 299, 299]) tensor([0])


  4%|▎         | 27/734 [00:20<11:48,  1.00s/it]

torch.Size([1, 3, 299, 299]) tensor([0])


  4%|▍         | 28/734 [00:21<12:13,  1.04s/it]

torch.Size([1, 3, 299, 299]) tensor([0])


  4%|▍         | 29/734 [00:23<12:58,  1.10s/it]

torch.Size([1, 3, 299, 299]) tensor([0])


  4%|▍         | 30/734 [00:24<12:56,  1.10s/it]

torch.Size([1, 3, 299, 299]) tensor([0])


  4%|▍         | 31/734 [00:25<12:48,  1.09s/it]

torch.Size([1, 3, 299, 299]) tensor([0])


  4%|▍         | 32/734 [00:26<12:56,  1.11s/it]

torch.Size([1, 3, 299, 299]) tensor([0])


  4%|▍         | 33/734 [00:27<13:02,  1.12s/it]

torch.Size([1, 3, 299, 299]) tensor([0])


  5%|▍         | 34/734 [00:28<13:50,  1.19s/it]

torch.Size([1, 3, 299, 299]) tensor([0])


  5%|▍         | 35/734 [00:29<13:11,  1.13s/it]

torch.Size([1, 3, 299, 299]) tensor([0])


  5%|▍         | 36/734 [00:31<13:27,  1.16s/it]

torch.Size([1, 3, 299, 299]) tensor([0])


  5%|▌         | 37/734 [00:32<14:37,  1.26s/it]

torch.Size([1, 3, 299, 299]) tensor([0])


  5%|▌         | 38/734 [00:33<14:42,  1.27s/it]

torch.Size([1, 3, 299, 299]) tensor([0])


  5%|▌         | 39/734 [00:35<14:51,  1.28s/it]

torch.Size([1, 3, 299, 299]) tensor([0])


  5%|▌         | 40/734 [00:36<14:09,  1.22s/it]

torch.Size([1, 3, 299, 299]) tensor([0])


  6%|▌         | 41/734 [00:37<12:51,  1.11s/it]

torch.Size([1, 3, 299, 299]) tensor([0])


  6%|▌         | 42/734 [00:38<13:28,  1.17s/it]

torch.Size([1, 3, 299, 299]) tensor([0])


  6%|▌         | 43/734 [00:39<13:57,  1.21s/it]

torch.Size([1, 3, 299, 299]) tensor([0])


  6%|▌         | 44/734 [00:41<14:19,  1.25s/it]

torch.Size([1, 3, 299, 299]) tensor([0])


  6%|▌         | 45/734 [00:42<14:38,  1.27s/it]

torch.Size([1, 3, 299, 299]) tensor([0])


  6%|▋         | 46/734 [00:43<13:43,  1.20s/it]

torch.Size([1, 3, 299, 299]) tensor([0])


  6%|▋         | 47/734 [00:45<14:44,  1.29s/it]

torch.Size([1, 3, 299, 299]) tensor([0])


  7%|▋         | 48/734 [00:46<13:36,  1.19s/it]

torch.Size([1, 3, 299, 299]) tensor([0])


  7%|▋         | 49/734 [00:46<12:29,  1.09s/it]

torch.Size([1, 3, 299, 299]) tensor([0])


  7%|▋         | 50/734 [00:47<12:26,  1.09s/it]

torch.Size([1, 3, 299, 299]) tensor([0])


  7%|▋         | 51/734 [00:49<12:56,  1.14s/it]

torch.Size([1, 3, 299, 299]) tensor([0])


  7%|▋         | 52/734 [00:50<12:33,  1.11s/it]

torch.Size([1, 3, 299, 299]) tensor([0])


  7%|▋         | 53/734 [00:51<12:36,  1.11s/it]

torch.Size([1, 3, 299, 299]) tensor([0])


  7%|▋         | 54/734 [00:52<13:08,  1.16s/it]

torch.Size([1, 3, 299, 299]) tensor([0])


  7%|▋         | 55/734 [00:53<13:17,  1.17s/it]

torch.Size([1, 3, 299, 299]) tensor([0])


  8%|▊         | 56/734 [00:55<13:40,  1.21s/it]

torch.Size([1, 3, 299, 299]) tensor([0])


  8%|▊         | 57/734 [00:56<13:28,  1.19s/it]

torch.Size([1, 3, 299, 299]) tensor([0])


  8%|▊         | 58/734 [00:57<13:43,  1.22s/it]

torch.Size([1, 3, 299, 299]) tensor([0])


  8%|▊         | 59/734 [00:58<13:54,  1.24s/it]

torch.Size([1, 3, 299, 299]) tensor([0])


  8%|▊         | 60/734 [01:00<13:40,  1.22s/it]

torch.Size([1, 3, 299, 299]) tensor([0])


  8%|▊         | 61/734 [01:01<14:09,  1.26s/it]

torch.Size([1, 3, 299, 299]) tensor([0])


  8%|▊         | 62/734 [01:02<13:41,  1.22s/it]

torch.Size([1, 3, 299, 299]) tensor([0])


  9%|▊         | 63/734 [01:03<13:29,  1.21s/it]

torch.Size([1, 3, 299, 299]) tensor([0])


  9%|▊         | 64/734 [01:04<13:34,  1.22s/it]

torch.Size([1, 3, 299, 299]) tensor([0])


  9%|▉         | 65/734 [01:06<14:28,  1.30s/it]

torch.Size([1, 3, 299, 299]) tensor([0])


  9%|▉         | 66/734 [01:07<13:36,  1.22s/it]

torch.Size([1, 3, 299, 299]) tensor([0])


  9%|▉         | 67/734 [01:08<12:56,  1.16s/it]

torch.Size([1, 3, 299, 299]) tensor([0])


  9%|▉         | 68/734 [01:09<12:20,  1.11s/it]

torch.Size([1, 3, 299, 299]) tensor([0])


  9%|▉         | 69/734 [01:11<14:10,  1.28s/it]

torch.Size([1, 3, 299, 299]) tensor([0])


 10%|▉         | 70/734 [01:13<18:40,  1.69s/it]

torch.Size([1, 3, 299, 299]) tensor([0])


 10%|▉         | 71/734 [01:16<22:01,  1.99s/it]

torch.Size([1, 3, 299, 299]) tensor([0])


 10%|▉         | 72/734 [01:18<22:06,  2.00s/it]

torch.Size([1, 3, 299, 299]) tensor([0])


 10%|▉         | 73/734 [01:20<20:23,  1.85s/it]

torch.Size([1, 3, 299, 299]) tensor([0])


 10%|█         | 74/734 [01:21<18:00,  1.64s/it]

torch.Size([1, 3, 299, 299]) tensor([0])


 10%|█         | 75/734 [01:22<16:15,  1.48s/it]

torch.Size([1, 3, 299, 299]) tensor([0])


 10%|█         | 76/734 [01:23<15:41,  1.43s/it]

torch.Size([1, 3, 299, 299]) tensor([0])


 10%|█         | 77/734 [01:24<14:46,  1.35s/it]

torch.Size([1, 3, 299, 299]) tensor([0])


 11%|█         | 78/734 [01:26<16:35,  1.52s/it]

torch.Size([1, 3, 299, 299]) tensor([0])


 11%|█         | 79/734 [01:28<18:34,  1.70s/it]

torch.Size([1, 3, 299, 299]) tensor([0])


 11%|█         | 80/734 [01:31<20:23,  1.87s/it]

torch.Size([1, 3, 299, 299]) tensor([0])


 11%|█         | 81/734 [01:33<20:59,  1.93s/it]

torch.Size([1, 3, 299, 299]) tensor([0])


 11%|█         | 82/734 [01:35<20:50,  1.92s/it]

torch.Size([1, 3, 299, 299]) tensor([0])


 11%|█▏        | 83/734 [01:36<20:19,  1.87s/it]

torch.Size([1, 3, 299, 299]) tensor([0])


 11%|█▏        | 84/734 [01:38<20:46,  1.92s/it]

torch.Size([1, 3, 299, 299]) tensor([0])


 12%|█▏        | 85/734 [01:40<20:50,  1.93s/it]

torch.Size([1, 3, 299, 299]) tensor([0])


 12%|█▏        | 86/734 [01:42<20:27,  1.89s/it]

torch.Size([1, 3, 299, 299]) tensor([0])


 12%|█▏        | 87/734 [01:44<20:18,  1.88s/it]

torch.Size([1, 3, 299, 299]) tensor([0])


 12%|█▏        | 88/734 [01:46<19:20,  1.80s/it]

torch.Size([1, 3, 299, 299]) tensor([0])


 12%|█▏        | 89/734 [01:47<17:25,  1.62s/it]

torch.Size([1, 3, 299, 299]) tensor([0])


 12%|█▏        | 90/734 [01:48<17:16,  1.61s/it]

torch.Size([1, 3, 299, 299]) tensor([0])


 12%|█▏        | 91/734 [01:51<19:24,  1.81s/it]

torch.Size([1, 3, 299, 299]) tensor([0])


 13%|█▎        | 92/734 [01:52<18:24,  1.72s/it]

torch.Size([1, 3, 299, 299]) tensor([0])


 13%|█▎        | 93/734 [01:55<21:02,  1.97s/it]

torch.Size([1, 3, 299, 299]) tensor([0])


 13%|█▎        | 94/734 [01:57<21:43,  2.04s/it]

torch.Size([1, 3, 299, 299]) tensor([0])


 13%|█▎        | 95/734 [01:59<23:03,  2.16s/it]

torch.Size([1, 3, 299, 299]) tensor([0])


 13%|█▎        | 96/734 [02:02<23:36,  2.22s/it]

torch.Size([1, 3, 299, 299]) tensor([0])


 13%|█▎        | 97/734 [02:03<22:00,  2.07s/it]

torch.Size([1, 3, 299, 299]) tensor([0])


 13%|█▎        | 98/734 [02:05<21:34,  2.04s/it]

torch.Size([1, 3, 299, 299]) tensor([0])


 13%|█▎        | 99/734 [02:07<21:42,  2.05s/it]

torch.Size([1, 3, 299, 299]) tensor([0])


 14%|█▎        | 100/734 [02:10<22:03,  2.09s/it]

torch.Size([1, 3, 299, 299]) tensor([0])


In [24]:
[lf.shape for lf in unlabeled_celeba_features]

[torch.Size([66603, 32]),
 torch.Size([64827, 64]),
 torch.Size([15987, 80]),
 torch.Size([3675, 288]),
 torch.Size([867, 768]),
 torch.Size([192, 2048]),
 torch.Size([3, 1000])]