In [29]:
from itertools import islice
import matplotlib.pyplot as plt

from tqdm import tqdm
import torch
from torchvision import models, transforms, datasets

In [2]:
inception_transforms = transforms.Compose([
            transforms.Resize(299),
            #transforms.CenterCrop(constants.INPUT_SIZE),
            transforms.ToTensor(),
            #transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
        ])

In [3]:
unlabeled_celeba = datasets.ImageFolder('imgs_by_label/celeba_unlabeled/', inception_transforms)
print(unlabeled_celeba)
unlabeled_celeba_loader = torch.utils.data.DataLoader(
        unlabeled_celeba, batch_size=1, shuffle=True, num_workers=1)

labeled_celeba = datasets.ImageFolder('imgs_by_label/celeba_labeled/', inception_transforms)
print(labeled_celeba)
labeled_celeba_loader = torch.utils.data.DataLoader(
        labeled_celeba, batch_size=1, shuffle=True, num_workers=1)

labeled_progan = datasets.ImageFolder('imgs_by_label/progan_labeled/', inception_transforms)
print(labeled_progan)
labeled_progan_loader = torch.utils.data.DataLoader(
        labeled_progan, batch_size=1, shuffle=True, num_workers=1)

Dataset ImageFolder
    Number of datapoints: 734
    Root location: imgs_by_label/celeba_unlabeled/

In [30]:
def get_inception_features(img_iter):
    inception_net = models.inception_v3(pretrained=True, transform_input=True)
    
    layers_to_grab = [inception_net.Conv2d_1a_3x3, inception_net.Conv2d_2b_3x3,
                 inception_net.Conv2d_3b_1x1, inception_net.Mixed_5d, inception_net.Mixed_6e,
                 inception_net.Mixed_7c, inception_net.fc]
    
    layer_features = [None for i in range(len(layers_to_grab))]
    
    
    def hook_fn(self, inp, out, container, layer_index):
        #print(layer_index, inp[0].shape, out.shape)

        num_channels = out.shape[1]
        if len(out.shape) > 2:
            #Warning: this will break for batch sizes > 1
            cur_features = out.squeeze().permute(1,2,0).reshape(-1, num_channels)
        else:
            cur_features = out

        if container[layer_index] is None:
            container[layer_index] = cur_features
        else:
            container[layer_index] = torch.cat((container[layer_index], cur_features))

    def hook_fn_i(container, i):
        return lambda self, inp, out: hook_fn(self, inp, out, container, i)

    for i, layer in enumerate(layers_to_grab):
        layer.register_forward_hook(hook_fn_i(layer_features, i))
        
    inception_net.eval()

    for x,y in tqdm(img_iter):
        print(x.shape, y)
        #plt.imshow((x).squeeze().permute(1, 2, 0))
        #plt.show()
        out = inception_net(x)
        #print(out.sum())
        
    return layer_features

In [None]:
unlabeled_celeba_features = get_inception_features(unlabeled_celeba_loader)

  0%|          | 0/734 [00:00<?, ?it/s]

torch.Size([1, 3, 299, 299]) tensor([0])


  0%|          | 1/734 [00:00<06:33,  1.86it/s]

torch.Size([1, 3, 299, 299]) tensor([0])


  0%|          | 2/734 [00:01<06:33,  1.86it/s]

torch.Size([1, 3, 299, 299]) tensor([0])


  0%|          | 3/734 [00:01<06:21,  1.91it/s]

torch.Size([1, 3, 299, 299]) tensor([0])


  1%|          | 4/734 [00:02<06:21,  1.91it/s]

torch.Size([1, 3, 299, 299]) tensor([0])


  1%|          | 5/734 [00:02<06:25,  1.89it/s]

torch.Size([1, 3, 299, 299]) tensor([0])


  1%|          | 6/734 [00:03<06:25,  1.89it/s]

torch.Size([1, 3, 299, 299]) tensor([0])


  1%|          | 7/734 [00:03<06:26,  1.88it/s]

torch.Size([1, 3, 299, 299]) tensor([0])


  1%|          | 8/734 [00:04<06:34,  1.84it/s]

torch.Size([1, 3, 299, 299]) tensor([0])


  1%|          | 9/734 [00:04<06:40,  1.81it/s]

torch.Size([1, 3, 299, 299]) tensor([0])


  1%|▏         | 10/734 [00:05<07:41,  1.57it/s]

torch.Size([1, 3, 299, 299]) tensor([0])


  1%|▏         | 11/734 [00:06<07:58,  1.51it/s]

torch.Size([1, 3, 299, 299]) tensor([0])


  2%|▏         | 12/734 [00:07<08:02,  1.50it/s]

torch.Size([1, 3, 299, 299]) tensor([0])


  2%|▏         | 13/734 [00:07<08:33,  1.40it/s]

torch.Size([1, 3, 299, 299]) tensor([0])


  2%|▏         | 14/734 [00:08<08:26,  1.42it/s]

torch.Size([1, 3, 299, 299]) tensor([0])


  2%|▏         | 15/734 [00:09<08:12,  1.46it/s]

torch.Size([1, 3, 299, 299]) tensor([0])


  2%|▏         | 16/734 [00:10<09:02,  1.32it/s]

torch.Size([1, 3, 299, 299]) tensor([0])


  2%|▏         | 17/734 [00:11<10:20,  1.15it/s]

torch.Size([1, 3, 299, 299]) tensor([0])


  2%|▏         | 18/734 [00:12<10:23,  1.15it/s]

torch.Size([1, 3, 299, 299]) tensor([0])


  3%|▎         | 19/734 [00:12<10:13,  1.17it/s]

torch.Size([1, 3, 299, 299]) tensor([0])


  3%|▎         | 20/734 [00:13<10:11,  1.17it/s]

torch.Size([1, 3, 299, 299]) tensor([0])


  3%|▎         | 21/734 [00:14<09:58,  1.19it/s]

torch.Size([1, 3, 299, 299]) tensor([0])


  3%|▎         | 22/734 [00:15<10:47,  1.10it/s]

torch.Size([1, 3, 299, 299]) tensor([0])


  3%|▎         | 23/734 [00:16<11:07,  1.07it/s]

torch.Size([1, 3, 299, 299]) tensor([0])


  3%|▎         | 24/734 [00:17<11:31,  1.03it/s]

torch.Size([1, 3, 299, 299]) tensor([0])


  3%|▎         | 25/734 [00:18<11:46,  1.00it/s]

torch.Size([1, 3, 299, 299]) tensor([0])


  4%|▎         | 26/734 [00:19<11:35,  1.02it/s]

torch.Size([1, 3, 299, 299]) tensor([0])


  4%|▎         | 27/734 [00:20<11:48,  1.00s/it]

torch.Size([1, 3, 299, 299]) tensor([0])


  4%|▍         | 28/734 [00:21<12:13,  1.04s/it]

torch.Size([1, 3, 299, 299]) tensor([0])


  4%|▍         | 29/734 [00:23<12:58,  1.10s/it]

torch.Size([1, 3, 299, 299]) tensor([0])


  4%|▍         | 30/734 [00:24<12:56,  1.10s/it]

torch.Size([1, 3, 299, 299]) tensor([0])


  4%|▍         | 31/734 [00:25<12:48,  1.09s/it]

torch.Size([1, 3, 299, 299]) tensor([0])


  4%|▍         | 32/734 [00:26<12:56,  1.11s/it]

torch.Size([1, 3, 299, 299]) tensor([0])


  4%|▍         | 33/734 [00:27<13:02,  1.12s/it]

torch.Size([1, 3, 299, 299]) tensor([0])


  5%|▍         | 34/734 [00:28<13:50,  1.19s/it]

torch.Size([1, 3, 299, 299]) tensor([0])


  5%|▍         | 35/734 [00:29<13:11,  1.13s/it]

torch.Size([1, 3, 299, 299]) tensor([0])


  5%|▍         | 36/734 [00:31<13:27,  1.16s/it]

torch.Size([1, 3, 299, 299]) tensor([0])


  5%|▌         | 37/734 [00:32<14:37,  1.26s/it]

torch.Size([1, 3, 299, 299]) tensor([0])


  5%|▌         | 38/734 [00:33<14:42,  1.27s/it]

torch.Size([1, 3, 299, 299]) tensor([0])


  5%|▌         | 39/734 [00:35<14:51,  1.28s/it]

torch.Size([1, 3, 299, 299]) tensor([0])


  5%|▌         | 40/734 [00:36<14:09,  1.22s/it]

torch.Size([1, 3, 299, 299]) tensor([0])


  6%|▌         | 41/734 [00:37<12:51,  1.11s/it]

torch.Size([1, 3, 299, 299]) tensor([0])


  6%|▌         | 42/734 [00:38<13:28,  1.17s/it]

torch.Size([1, 3, 299, 299]) tensor([0])


  6%|▌         | 43/734 [00:39<13:57,  1.21s/it]

torch.Size([1, 3, 299, 299]) tensor([0])


  6%|▌         | 44/734 [00:41<14:19,  1.25s/it]

torch.Size([1, 3, 299, 299]) tensor([0])


  6%|▌         | 45/734 [00:42<14:38,  1.27s/it]

torch.Size([1, 3, 299, 299]) tensor([0])


  6%|▋         | 46/734 [00:43<13:43,  1.20s/it]

torch.Size([1, 3, 299, 299]) tensor([0])


  6%|▋         | 47/734 [00:45<14:44,  1.29s/it]

torch.Size([1, 3, 299, 299]) tensor([0])


  7%|▋         | 48/734 [00:46<13:36,  1.19s/it]

torch.Size([1, 3, 299, 299]) tensor([0])


  7%|▋         | 49/734 [00:46<12:29,  1.09s/it]

torch.Size([1, 3, 299, 299]) tensor([0])


  7%|▋         | 50/734 [00:47<12:26,  1.09s/it]

torch.Size([1, 3, 299, 299]) tensor([0])


  7%|▋         | 51/734 [00:49<12:56,  1.14s/it]

torch.Size([1, 3, 299, 299]) tensor([0])


  7%|▋         | 52/734 [00:50<12:33,  1.11s/it]

torch.Size([1, 3, 299, 299]) tensor([0])


  7%|▋         | 53/734 [00:51<12:36,  1.11s/it]

torch.Size([1, 3, 299, 299]) tensor([0])


  7%|▋         | 54/734 [00:52<13:08,  1.16s/it]

torch.Size([1, 3, 299, 299]) tensor([0])


  7%|▋         | 55/734 [00:53<13:17,  1.17s/it]

torch.Size([1, 3, 299, 299]) tensor([0])


  8%|▊         | 56/734 [00:55<13:40,  1.21s/it]

torch.Size([1, 3, 299, 299]) tensor([0])


  8%|▊         | 57/734 [00:56<13:28,  1.19s/it]

torch.Size([1, 3, 299, 299]) tensor([0])


  8%|▊         | 58/734 [00:57<13:43,  1.22s/it]

torch.Size([1, 3, 299, 299]) tensor([0])


  8%|▊         | 59/734 [00:58<13:54,  1.24s/it]

torch.Size([1, 3, 299, 299]) tensor([0])


  8%|▊         | 60/734 [01:00<13:40,  1.22s/it]

torch.Size([1, 3, 299, 299]) tensor([0])


  8%|▊         | 61/734 [01:01<14:09,  1.26s/it]

torch.Size([1, 3, 299, 299]) tensor([0])


  8%|▊         | 62/734 [01:02<13:41,  1.22s/it]

torch.Size([1, 3, 299, 299]) tensor([0])


  9%|▊         | 63/734 [01:03<13:29,  1.21s/it]

torch.Size([1, 3, 299, 299]) tensor([0])


  9%|▊         | 64/734 [01:04<13:34,  1.22s/it]

torch.Size([1, 3, 299, 299]) tensor([0])


  9%|▉         | 65/734 [01:06<14:28,  1.30s/it]

torch.Size([1, 3, 299, 299]) tensor([0])


  9%|▉         | 66/734 [01:07<13:36,  1.22s/it]

torch.Size([1, 3, 299, 299]) tensor([0])


  9%|▉         | 67/734 [01:08<12:56,  1.16s/it]

torch.Size([1, 3, 299, 299]) tensor([0])


  9%|▉         | 68/734 [01:09<12:20,  1.11s/it]

torch.Size([1, 3, 299, 299]) tensor([0])


  9%|▉         | 69/734 [01:11<14:10,  1.28s/it]

torch.Size([1, 3, 299, 299]) tensor([0])


 10%|▉         | 70/734 [01:13<18:40,  1.69s/it]

torch.Size([1, 3, 299, 299]) tensor([0])


 10%|▉         | 71/734 [01:16<22:01,  1.99s/it]

torch.Size([1, 3, 299, 299]) tensor([0])


 10%|▉         | 72/734 [01:18<22:06,  2.00s/it]

torch.Size([1, 3, 299, 299]) tensor([0])


 10%|▉         | 73/734 [01:20<20:23,  1.85s/it]

torch.Size([1, 3, 299, 299]) tensor([0])


 10%|█         | 74/734 [01:21<18:00,  1.64s/it]

torch.Size([1, 3, 299, 299]) tensor([0])


 10%|█         | 75/734 [01:22<16:15,  1.48s/it]

torch.Size([1, 3, 299, 299]) tensor([0])


 10%|█         | 76/734 [01:23<15:41,  1.43s/it]

torch.Size([1, 3, 299, 299]) tensor([0])


 10%|█         | 77/734 [01:24<14:46,  1.35s/it]

torch.Size([1, 3, 299, 299]) tensor([0])


 11%|█         | 78/734 [01:26<16:35,  1.52s/it]

torch.Size([1, 3, 299, 299]) tensor([0])


 11%|█         | 79/734 [01:28<18:34,  1.70s/it]

torch.Size([1, 3, 299, 299]) tensor([0])


 11%|█         | 80/734 [01:31<20:23,  1.87s/it]

torch.Size([1, 3, 299, 299]) tensor([0])


 11%|█         | 81/734 [01:33<20:59,  1.93s/it]

torch.Size([1, 3, 299, 299]) tensor([0])


 11%|█         | 82/734 [01:35<20:50,  1.92s/it]

torch.Size([1, 3, 299, 299]) tensor([0])


 11%|█▏        | 83/734 [01:36<20:19,  1.87s/it]

torch.Size([1, 3, 299, 299]) tensor([0])


 11%|█▏        | 84/734 [01:38<20:46,  1.92s/it]

torch.Size([1, 3, 299, 299]) tensor([0])


 12%|█▏        | 85/734 [01:40<20:50,  1.93s/it]

torch.Size([1, 3, 299, 299]) tensor([0])


 12%|█▏        | 86/734 [01:42<20:27,  1.89s/it]

torch.Size([1, 3, 299, 299]) tensor([0])


 12%|█▏        | 87/734 [01:44<20:18,  1.88s/it]

torch.Size([1, 3, 299, 299]) tensor([0])


 12%|█▏        | 88/734 [01:46<19:20,  1.80s/it]

torch.Size([1, 3, 299, 299]) tensor([0])


 12%|█▏        | 89/734 [01:47<17:25,  1.62s/it]

torch.Size([1, 3, 299, 299]) tensor([0])


 12%|█▏        | 90/734 [01:48<17:16,  1.61s/it]

torch.Size([1, 3, 299, 299]) tensor([0])


 12%|█▏        | 91/734 [01:51<19:24,  1.81s/it]

torch.Size([1, 3, 299, 299]) tensor([0])


 13%|█▎        | 92/734 [01:52<18:24,  1.72s/it]

torch.Size([1, 3, 299, 299]) tensor([0])


 13%|█▎        | 93/734 [01:55<21:02,  1.97s/it]

torch.Size([1, 3, 299, 299]) tensor([0])


 13%|█▎        | 94/734 [01:57<21:43,  2.04s/it]

torch.Size([1, 3, 299, 299]) tensor([0])


 13%|█▎        | 95/734 [01:59<23:03,  2.16s/it]

torch.Size([1, 3, 299, 299]) tensor([0])


 13%|█▎        | 96/734 [02:02<23:36,  2.22s/it]

torch.Size([1, 3, 299, 299]) tensor([0])


 13%|█▎        | 97/734 [02:03<22:00,  2.07s/it]

torch.Size([1, 3, 299, 299]) tensor([0])


 13%|█▎        | 98/734 [02:05<21:34,  2.04s/it]

torch.Size([1, 3, 299, 299]) tensor([0])


 13%|█▎        | 99/734 [02:07<21:42,  2.05s/it]

torch.Size([1, 3, 299, 299]) tensor([0])


 14%|█▎        | 100/734 [02:10<22:03,  2.09s/it]

torch.Size([1, 3, 299, 299]) tensor([0])


 14%|█▍        | 101/734 [02:12<22:47,  2.16s/it]

torch.Size([1, 3, 299, 299]) tensor([0])


 14%|█▍        | 102/734 [02:14<22:02,  2.09s/it]

torch.Size([1, 3, 299, 299]) tensor([0])


 14%|█▍        | 103/734 [02:15<19:52,  1.89s/it]

torch.Size([1, 3, 299, 299]) tensor([0])


 14%|█▍        | 104/734 [02:17<20:15,  1.93s/it]

torch.Size([1, 3, 299, 299]) tensor([0])


 14%|█▍        | 105/734 [02:19<19:06,  1.82s/it]

torch.Size([1, 3, 299, 299]) tensor([0])


 14%|█▍        | 106/734 [02:21<18:54,  1.81s/it]

torch.Size([1, 3, 299, 299]) tensor([0])


 15%|█▍        | 107/734 [02:22<18:37,  1.78s/it]

torch.Size([1, 3, 299, 299]) tensor([0])


 15%|█▍        | 108/734 [02:24<18:37,  1.79s/it]

torch.Size([1, 3, 299, 299]) tensor([0])


 15%|█▍        | 109/734 [02:27<22:14,  2.14s/it]

torch.Size([1, 3, 299, 299]) tensor([0])


 15%|█▍        | 110/734 [02:30<25:46,  2.48s/it]

torch.Size([1, 3, 299, 299]) tensor([0])


 15%|█▌        | 111/734 [02:34<29:09,  2.81s/it]

torch.Size([1, 3, 299, 299]) tensor([0])


 15%|█▌        | 112/734 [02:35<24:53,  2.40s/it]

torch.Size([1, 3, 299, 299]) tensor([0])


 15%|█▌        | 113/734 [02:37<22:09,  2.14s/it]

torch.Size([1, 3, 299, 299]) tensor([0])


 16%|█▌        | 114/734 [02:39<21:19,  2.06s/it]

torch.Size([1, 3, 299, 299]) tensor([0])


 16%|█▌        | 115/734 [02:40<19:58,  1.94s/it]

torch.Size([1, 3, 299, 299]) tensor([0])


 16%|█▌        | 116/734 [02:42<19:56,  1.94s/it]

torch.Size([1, 3, 299, 299]) tensor([0])


 16%|█▌        | 117/734 [02:44<19:00,  1.85s/it]

torch.Size([1, 3, 299, 299]) tensor([0])


 16%|█▌        | 118/734 [02:46<18:42,  1.82s/it]

torch.Size([1, 3, 299, 299]) tensor([0])


 16%|█▌        | 119/734 [02:48<19:22,  1.89s/it]

torch.Size([1, 3, 299, 299]) tensor([0])


 16%|█▋        | 120/734 [02:50<19:26,  1.90s/it]

torch.Size([1, 3, 299, 299]) tensor([0])


 16%|█▋        | 121/734 [02:52<21:38,  2.12s/it]

torch.Size([1, 3, 299, 299]) tensor([0])


 17%|█▋        | 122/734 [02:54<20:07,  1.97s/it]

torch.Size([1, 3, 299, 299]) tensor([0])


 17%|█▋        | 123/734 [02:56<18:42,  1.84s/it]

torch.Size([1, 3, 299, 299]) tensor([0])


 17%|█▋        | 124/734 [02:58<20:00,  1.97s/it]

torch.Size([1, 3, 299, 299]) tensor([0])


 17%|█▋        | 125/734 [03:00<19:34,  1.93s/it]

torch.Size([1, 3, 299, 299]) tensor([0])


 17%|█▋        | 126/734 [03:01<18:37,  1.84s/it]

torch.Size([1, 3, 299, 299]) tensor([0])


 17%|█▋        | 127/734 [03:03<19:31,  1.93s/it]

torch.Size([1, 3, 299, 299]) tensor([0])


 17%|█▋        | 128/734 [03:05<18:18,  1.81s/it]

torch.Size([1, 3, 299, 299]) tensor([0])


 18%|█▊        | 129/734 [03:07<18:22,  1.82s/it]

torch.Size([1, 3, 299, 299]) tensor([0])


 18%|█▊        | 130/734 [03:09<20:43,  2.06s/it]

torch.Size([1, 3, 299, 299]) tensor([0])


 18%|█▊        | 131/734 [03:11<20:31,  2.04s/it]

torch.Size([1, 3, 299, 299]) tensor([0])


 18%|█▊        | 132/734 [03:14<21:55,  2.18s/it]

torch.Size([1, 3, 299, 299]) tensor([0])


 18%|█▊        | 133/734 [03:16<22:40,  2.26s/it]

torch.Size([1, 3, 299, 299]) tensor([0])


 18%|█▊        | 134/734 [03:21<29:08,  2.91s/it]

torch.Size([1, 3, 299, 299]) tensor([0])


 18%|█▊        | 135/734 [03:25<31:36,  3.17s/it]

torch.Size([1, 3, 299, 299]) tensor([0])


 19%|█▊        | 136/734 [03:32<44:50,  4.50s/it]

torch.Size([1, 3, 299, 299]) tensor([0])


 19%|█▊        | 137/734 [03:38<48:35,  4.88s/it]

torch.Size([1, 3, 299, 299]) tensor([0])


 19%|█▉        | 138/734 [03:43<48:07,  4.84s/it]

torch.Size([1, 3, 299, 299]) tensor([0])


 19%|█▉        | 139/734 [03:46<44:27,  4.48s/it]

torch.Size([1, 3, 299, 299]) tensor([0])


 19%|█▉        | 140/734 [03:49<40:17,  4.07s/it]

torch.Size([1, 3, 299, 299]) tensor([0])


 19%|█▉        | 141/734 [03:52<34:19,  3.47s/it]

torch.Size([1, 3, 299, 299]) tensor([0])


 19%|█▉        | 142/734 [03:54<31:12,  3.16s/it]

torch.Size([1, 3, 299, 299]) tensor([0])


 19%|█▉        | 143/734 [03:58<34:42,  3.52s/it]

torch.Size([1, 3, 299, 299]) tensor([0])


 20%|█▉        | 144/734 [04:05<43:48,  4.46s/it]

torch.Size([1, 3, 299, 299]) tensor([0])


 20%|█▉        | 145/734 [04:07<37:44,  3.85s/it]

torch.Size([1, 3, 299, 299]) tensor([0])


 20%|█▉        | 146/734 [04:09<32:10,  3.28s/it]

torch.Size([1, 3, 299, 299]) tensor([0])


 20%|██        | 147/734 [04:11<28:34,  2.92s/it]

torch.Size([1, 3, 299, 299]) tensor([0])


 20%|██        | 148/734 [04:14<26:59,  2.76s/it]

torch.Size([1, 3, 299, 299]) tensor([0])


 20%|██        | 149/734 [04:17<27:03,  2.77s/it]

torch.Size([1, 3, 299, 299]) tensor([0])


 20%|██        | 150/734 [04:19<26:21,  2.71s/it]

torch.Size([1, 3, 299, 299]) tensor([0])


 21%|██        | 151/734 [04:22<26:42,  2.75s/it]

torch.Size([1, 3, 299, 299]) tensor([0])


 21%|██        | 152/734 [04:25<27:12,  2.80s/it]

torch.Size([1, 3, 299, 299]) tensor([0])


 21%|██        | 153/734 [04:28<27:58,  2.89s/it]

torch.Size([1, 3, 299, 299]) tensor([0])


 21%|██        | 154/734 [04:30<26:15,  2.72s/it]

torch.Size([1, 3, 299, 299]) tensor([0])


 21%|██        | 155/734 [04:34<28:46,  2.98s/it]

torch.Size([1, 3, 299, 299]) tensor([0])


 21%|██▏       | 156/734 [04:36<25:54,  2.69s/it]

torch.Size([1, 3, 299, 299]) tensor([0])


 21%|██▏       | 157/734 [04:41<32:41,  3.40s/it]

torch.Size([1, 3, 299, 299]) tensor([0])


 22%|██▏       | 158/734 [04:46<35:42,  3.72s/it]

torch.Size([1, 3, 299, 299]) tensor([0])


 22%|██▏       | 159/734 [04:49<34:28,  3.60s/it]

torch.Size([1, 3, 299, 299]) tensor([0])


 22%|██▏       | 160/734 [04:53<35:59,  3.76s/it]

torch.Size([1, 3, 299, 299]) tensor([0])


 22%|██▏       | 161/734 [04:57<35:38,  3.73s/it]

torch.Size([1, 3, 299, 299]) tensor([0])


 22%|██▏       | 162/734 [05:02<41:29,  4.35s/it]

torch.Size([1, 3, 299, 299]) tensor([0])


 22%|██▏       | 163/734 [05:07<41:50,  4.40s/it]

torch.Size([1, 3, 299, 299]) tensor([0])


 22%|██▏       | 164/734 [05:10<37:14,  3.92s/it]

torch.Size([1, 3, 299, 299]) tensor([0])


 22%|██▏       | 165/734 [05:13<34:23,  3.63s/it]

torch.Size([1, 3, 299, 299]) tensor([0])


 23%|██▎       | 166/734 [05:15<31:31,  3.33s/it]

torch.Size([1, 3, 299, 299]) tensor([0])


 23%|██▎       | 167/734 [05:19<31:01,  3.28s/it]

torch.Size([1, 3, 299, 299]) tensor([0])


 23%|██▎       | 168/734 [05:21<27:32,  2.92s/it]

torch.Size([1, 3, 299, 299]) tensor([0])


 23%|██▎       | 169/734 [05:24<28:13,  3.00s/it]

torch.Size([1, 3, 299, 299]) tensor([0])


 23%|██▎       | 170/734 [05:28<31:31,  3.35s/it]

torch.Size([1, 3, 299, 299]) tensor([0])


 23%|██▎       | 171/734 [05:33<37:28,  3.99s/it]

torch.Size([1, 3, 299, 299]) tensor([0])


 23%|██▎       | 172/734 [05:38<38:06,  4.07s/it]

torch.Size([1, 3, 299, 299]) tensor([0])


 24%|██▎       | 173/734 [05:41<35:32,  3.80s/it]

torch.Size([1, 3, 299, 299]) tensor([0])


 24%|██▎       | 174/734 [05:45<37:46,  4.05s/it]

torch.Size([1, 3, 299, 299]) tensor([0])


 24%|██▍       | 175/734 [05:50<38:18,  4.11s/it]

torch.Size([1, 3, 299, 299]) tensor([0])


 24%|██▍       | 176/734 [05:54<39:34,  4.25s/it]

torch.Size([1, 3, 299, 299]) tensor([0])


 24%|██▍       | 177/734 [05:57<35:17,  3.80s/it]

torch.Size([1, 3, 299, 299]) tensor([0])


 24%|██▍       | 178/734 [06:00<31:43,  3.42s/it]

torch.Size([1, 3, 299, 299]) tensor([0])


 24%|██▍       | 179/734 [06:03<31:52,  3.45s/it]

torch.Size([1, 3, 299, 299]) tensor([0])


 25%|██▍       | 180/734 [06:06<29:17,  3.17s/it]

torch.Size([1, 3, 299, 299]) tensor([0])


 25%|██▍       | 181/734 [06:08<27:13,  2.95s/it]

torch.Size([1, 3, 299, 299]) tensor([0])


 25%|██▍       | 182/734 [06:12<28:58,  3.15s/it]

torch.Size([1, 3, 299, 299]) tensor([0])


 25%|██▍       | 183/734 [06:15<29:44,  3.24s/it]

torch.Size([1, 3, 299, 299]) tensor([0])


 25%|██▌       | 184/734 [06:20<34:32,  3.77s/it]

torch.Size([1, 3, 299, 299]) tensor([0])


 25%|██▌       | 185/734 [06:24<35:17,  3.86s/it]

torch.Size([1, 3, 299, 299]) tensor([0])


 25%|██▌       | 186/734 [06:28<35:48,  3.92s/it]

torch.Size([1, 3, 299, 299]) tensor([0])


 25%|██▌       | 187/734 [06:31<33:05,  3.63s/it]

torch.Size([1, 3, 299, 299]) tensor([0])


 26%|██▌       | 188/734 [06:34<31:26,  3.45s/it]

torch.Size([1, 3, 299, 299]) tensor([0])


 26%|██▌       | 189/734 [06:37<29:10,  3.21s/it]

torch.Size([1, 3, 299, 299]) tensor([0])


 26%|██▌       | 190/734 [06:41<30:58,  3.42s/it]

torch.Size([1, 3, 299, 299]) tensor([0])


 26%|██▌       | 191/734 [06:44<31:21,  3.47s/it]

torch.Size([1, 3, 299, 299]) tensor([0])


 26%|██▌       | 192/734 [06:47<30:03,  3.33s/it]

torch.Size([1, 3, 299, 299]) tensor([0])


 26%|██▋       | 193/734 [06:50<27:47,  3.08s/it]

torch.Size([1, 3, 299, 299]) tensor([0])


 26%|██▋       | 194/734 [06:52<25:40,  2.85s/it]

torch.Size([1, 3, 299, 299]) tensor([0])


 27%|██▋       | 195/734 [06:55<24:29,  2.73s/it]

torch.Size([1, 3, 299, 299]) tensor([0])


 27%|██▋       | 196/734 [06:57<24:12,  2.70s/it]

torch.Size([1, 3, 299, 299]) tensor([0])


 27%|██▋       | 197/734 [07:00<24:17,  2.71s/it]

torch.Size([1, 3, 299, 299]) tensor([0])


 27%|██▋       | 198/734 [07:03<24:29,  2.74s/it]

torch.Size([1, 3, 299, 299]) tensor([0])


 27%|██▋       | 199/734 [07:06<24:56,  2.80s/it]

torch.Size([1, 3, 299, 299]) tensor([0])


 27%|██▋       | 200/734 [07:09<25:06,  2.82s/it]

torch.Size([1, 3, 299, 299]) tensor([0])


 27%|██▋       | 201/734 [07:12<26:44,  3.01s/it]

torch.Size([1, 3, 299, 299]) tensor([0])


 28%|██▊       | 202/734 [07:16<27:52,  3.14s/it]

torch.Size([1, 3, 299, 299]) tensor([0])


 28%|██▊       | 203/734 [07:19<28:50,  3.26s/it]

torch.Size([1, 3, 299, 299]) tensor([0])


 28%|██▊       | 204/734 [07:22<28:22,  3.21s/it]

torch.Size([1, 3, 299, 299]) tensor([0])


 28%|██▊       | 205/734 [07:25<28:30,  3.23s/it]

torch.Size([1, 3, 299, 299]) tensor([0])


 28%|██▊       | 206/734 [07:28<27:28,  3.12s/it]

torch.Size([1, 3, 299, 299]) tensor([0])


 28%|██▊       | 207/734 [07:31<26:29,  3.02s/it]

torch.Size([1, 3, 299, 299]) tensor([0])


 28%|██▊       | 208/734 [07:34<26:28,  3.02s/it]

torch.Size([1, 3, 299, 299]) tensor([0])


 28%|██▊       | 209/734 [07:37<26:29,  3.03s/it]

torch.Size([1, 3, 299, 299]) tensor([0])


 29%|██▊       | 210/734 [07:40<26:32,  3.04s/it]

torch.Size([1, 3, 299, 299]) tensor([0])


 29%|██▊       | 211/734 [07:43<26:18,  3.02s/it]

torch.Size([1, 3, 299, 299]) tensor([0])


 29%|██▉       | 212/734 [07:46<26:15,  3.02s/it]

torch.Size([1, 3, 299, 299]) tensor([0])


 29%|██▉       | 213/734 [07:49<24:48,  2.86s/it]

torch.Size([1, 3, 299, 299]) tensor([0])


 29%|██▉       | 214/734 [07:51<23:50,  2.75s/it]

torch.Size([1, 3, 299, 299]) tensor([0])


 29%|██▉       | 215/734 [07:54<22:45,  2.63s/it]

torch.Size([1, 3, 299, 299]) tensor([0])


 29%|██▉       | 216/734 [07:56<21:46,  2.52s/it]

torch.Size([1, 3, 299, 299]) tensor([0])


 30%|██▉       | 217/734 [07:58<21:34,  2.50s/it]

torch.Size([1, 3, 299, 299]) tensor([0])


 30%|██▉       | 218/734 [08:01<21:35,  2.51s/it]

torch.Size([1, 3, 299, 299]) tensor([0])


 30%|██▉       | 219/734 [08:04<22:37,  2.64s/it]

torch.Size([1, 3, 299, 299]) tensor([0])


 30%|██▉       | 220/734 [08:07<22:55,  2.68s/it]

torch.Size([1, 3, 299, 299]) tensor([0])


 30%|███       | 221/734 [08:09<22:31,  2.64s/it]

torch.Size([1, 3, 299, 299]) tensor([0])


 30%|███       | 222/734 [08:12<22:30,  2.64s/it]

torch.Size([1, 3, 299, 299]) tensor([0])


 30%|███       | 223/734 [08:14<22:25,  2.63s/it]

torch.Size([1, 3, 299, 299]) tensor([0])


 31%|███       | 224/734 [08:17<22:23,  2.63s/it]

torch.Size([1, 3, 299, 299]) tensor([0])


 31%|███       | 225/734 [08:20<22:40,  2.67s/it]

torch.Size([1, 3, 299, 299]) tensor([0])


 31%|███       | 226/734 [08:22<22:06,  2.61s/it]

torch.Size([1, 3, 299, 299]) tensor([0])


 31%|███       | 227/734 [08:25<22:30,  2.66s/it]

torch.Size([1, 3, 299, 299]) tensor([0])


 31%|███       | 228/734 [08:28<23:00,  2.73s/it]

torch.Size([1, 3, 299, 299]) tensor([0])


 31%|███       | 229/734 [08:31<24:31,  2.91s/it]

torch.Size([1, 3, 299, 299]) tensor([0])


 31%|███▏      | 230/734 [08:35<25:32,  3.04s/it]

torch.Size([1, 3, 299, 299]) tensor([0])


 31%|███▏      | 231/734 [08:38<26:00,  3.10s/it]

torch.Size([1, 3, 299, 299]) tensor([0])


 32%|███▏      | 232/734 [08:42<28:02,  3.35s/it]

torch.Size([1, 3, 299, 299]) tensor([0])


 32%|███▏      | 233/734 [08:46<29:29,  3.53s/it]

torch.Size([1, 3, 299, 299]) tensor([0])


 32%|███▏      | 234/734 [08:49<28:52,  3.46s/it]

torch.Size([1, 3, 299, 299]) tensor([0])


 32%|███▏      | 235/734 [08:53<30:56,  3.72s/it]

torch.Size([1, 3, 299, 299]) tensor([0])


 32%|███▏      | 236/734 [08:57<31:27,  3.79s/it]

torch.Size([1, 3, 299, 299]) tensor([0])


 32%|███▏      | 237/734 [09:01<30:36,  3.70s/it]

torch.Size([1, 3, 299, 299]) tensor([0])


 32%|███▏      | 238/734 [09:04<29:11,  3.53s/it]

torch.Size([1, 3, 299, 299]) tensor([0])


 33%|███▎      | 239/734 [09:09<31:58,  3.88s/it]

torch.Size([1, 3, 299, 299]) tensor([0])


 33%|███▎      | 240/734 [09:12<29:54,  3.63s/it]

torch.Size([1, 3, 299, 299]) tensor([0])


 33%|███▎      | 241/734 [09:15<28:31,  3.47s/it]

torch.Size([1, 3, 299, 299]) tensor([0])


 33%|███▎      | 242/734 [09:18<26:50,  3.27s/it]

torch.Size([1, 3, 299, 299]) tensor([0])


 33%|███▎      | 243/734 [09:22<29:16,  3.58s/it]

torch.Size([1, 3, 299, 299]) tensor([0])


 33%|███▎      | 244/734 [09:25<27:41,  3.39s/it]

torch.Size([1, 3, 299, 299]) tensor([0])


 33%|███▎      | 245/734 [09:28<26:53,  3.30s/it]

torch.Size([1, 3, 299, 299]) tensor([0])


 34%|███▎      | 246/734 [09:32<29:28,  3.62s/it]

torch.Size([1, 3, 299, 299]) tensor([0])


 34%|███▎      | 247/734 [09:37<31:11,  3.84s/it]

torch.Size([1, 3, 299, 299]) tensor([0])


 34%|███▍      | 248/734 [09:41<32:12,  3.98s/it]

torch.Size([1, 3, 299, 299]) tensor([0])


 34%|███▍      | 249/734 [09:43<28:37,  3.54s/it]

torch.Size([1, 3, 299, 299]) tensor([0])


 34%|███▍      | 250/734 [09:46<26:06,  3.24s/it]

torch.Size([1, 3, 299, 299]) tensor([0])


 34%|███▍      | 251/734 [09:50<27:38,  3.43s/it]

torch.Size([1, 3, 299, 299]) tensor([0])


 34%|███▍      | 252/734 [09:52<25:37,  3.19s/it]

torch.Size([1, 3, 299, 299]) tensor([0])


 34%|███▍      | 253/734 [09:55<24:31,  3.06s/it]

torch.Size([1, 3, 299, 299]) tensor([0])


 35%|███▍      | 254/734 [09:58<24:03,  3.01s/it]

torch.Size([1, 3, 299, 299]) tensor([0])


 35%|███▍      | 255/734 [10:03<27:31,  3.45s/it]

torch.Size([1, 3, 299, 299]) tensor([0])


 35%|███▍      | 256/734 [10:06<27:06,  3.40s/it]

torch.Size([1, 3, 299, 299]) tensor([0])


 35%|███▌      | 257/734 [10:11<31:46,  4.00s/it]

torch.Size([1, 3, 299, 299]) tensor([0])


 35%|███▌      | 258/734 [10:14<28:36,  3.61s/it]

torch.Size([1, 3, 299, 299]) tensor([0])


 35%|███▌      | 259/734 [10:17<26:47,  3.38s/it]

torch.Size([1, 3, 299, 299]) tensor([0])


 35%|███▌      | 260/734 [10:21<29:50,  3.78s/it]

torch.Size([1, 3, 299, 299]) tensor([0])


 36%|███▌      | 261/734 [10:24<27:45,  3.52s/it]

torch.Size([1, 3, 299, 299]) tensor([0])


 36%|███▌      | 262/734 [10:29<29:10,  3.71s/it]

torch.Size([1, 3, 299, 299]) tensor([0])


 36%|███▌      | 263/734 [10:33<29:46,  3.79s/it]

torch.Size([1, 3, 299, 299]) tensor([0])


 36%|███▌      | 264/734 [10:36<28:13,  3.60s/it]

torch.Size([1, 3, 299, 299]) tensor([0])


 36%|███▌      | 265/734 [10:39<27:19,  3.50s/it]

torch.Size([1, 3, 299, 299]) tensor([0])


 36%|███▌      | 266/734 [10:44<31:22,  4.02s/it]

torch.Size([1, 3, 299, 299]) tensor([0])


 36%|███▋      | 267/734 [10:49<33:58,  4.36s/it]

torch.Size([1, 3, 299, 299]) tensor([0])


 37%|███▋      | 268/734 [10:54<34:00,  4.38s/it]

torch.Size([1, 3, 299, 299]) tensor([0])


 37%|███▋      | 269/734 [10:57<30:12,  3.90s/it]

torch.Size([1, 3, 299, 299]) tensor([0])


 37%|███▋      | 270/734 [10:59<27:20,  3.54s/it]

torch.Size([1, 3, 299, 299]) tensor([0])


 37%|███▋      | 271/734 [11:02<25:52,  3.35s/it]

torch.Size([1, 3, 299, 299]) tensor([0])


 37%|███▋      | 272/734 [11:06<26:00,  3.38s/it]

torch.Size([1, 3, 299, 299]) tensor([0])


 37%|███▋      | 273/734 [11:09<27:05,  3.53s/it]

torch.Size([1, 3, 299, 299]) tensor([0])


 37%|███▋      | 274/734 [11:21<46:08,  6.02s/it]

torch.Size([1, 3, 299, 299]) tensor([0])


 37%|███▋      | 275/734 [11:28<47:36,  6.22s/it]

torch.Size([1, 3, 299, 299]) tensor([0])


 38%|███▊      | 276/734 [11:35<50:14,  6.58s/it]

torch.Size([1, 3, 299, 299]) tensor([0])


 38%|███▊      | 277/734 [11:42<50:53,  6.68s/it]

torch.Size([1, 3, 299, 299]) tensor([0])


 38%|███▊      | 278/734 [11:45<42:24,  5.58s/it]

torch.Size([1, 3, 299, 299]) tensor([0])


 38%|███▊      | 279/734 [11:48<36:43,  4.84s/it]

torch.Size([1, 3, 299, 299]) tensor([0])


 38%|███▊      | 280/734 [11:52<34:09,  4.51s/it]

torch.Size([1, 3, 299, 299]) tensor([0])


 38%|███▊      | 281/734 [11:55<30:33,  4.05s/it]

torch.Size([1, 3, 299, 299]) tensor([0])


 38%|███▊      | 282/734 [11:58<28:23,  3.77s/it]

torch.Size([1, 3, 299, 299]) tensor([0])


 39%|███▊      | 283/734 [12:03<30:10,  4.01s/it]

torch.Size([1, 3, 299, 299]) tensor([0])


 39%|███▊      | 284/734 [12:06<28:30,  3.80s/it]

torch.Size([1, 3, 299, 299]) tensor([0])


 39%|███▉      | 285/734 [12:11<29:45,  3.98s/it]

torch.Size([1, 3, 299, 299]) tensor([0])


 39%|███▉      | 286/734 [12:16<32:27,  4.35s/it]

torch.Size([1, 3, 299, 299]) tensor([0])


 39%|███▉      | 287/734 [12:19<29:24,  3.95s/it]

torch.Size([1, 3, 299, 299]) tensor([0])


 39%|███▉      | 288/734 [12:22<27:35,  3.71s/it]

torch.Size([1, 3, 299, 299]) tensor([0])


 39%|███▉      | 289/734 [12:25<26:33,  3.58s/it]

torch.Size([1, 3, 299, 299]) tensor([0])


 40%|███▉      | 290/734 [12:31<30:13,  4.08s/it]

torch.Size([1, 3, 299, 299]) tensor([0])


 40%|███▉      | 291/734 [12:34<28:52,  3.91s/it]

torch.Size([1, 3, 299, 299]) tensor([0])


 40%|███▉      | 292/734 [12:38<28:45,  3.90s/it]

torch.Size([1, 3, 299, 299]) tensor([0])


 40%|███▉      | 293/734 [12:42<28:30,  3.88s/it]

torch.Size([1, 3, 299, 299]) tensor([0])


 40%|████      | 294/734 [12:45<27:24,  3.74s/it]

torch.Size([1, 3, 299, 299]) tensor([0])


 40%|████      | 295/734 [12:50<30:51,  4.22s/it]

torch.Size([1, 3, 299, 299]) tensor([0])


 40%|████      | 296/734 [12:57<35:31,  4.87s/it]

torch.Size([1, 3, 299, 299]) tensor([0])


 40%|████      | 297/734 [13:00<32:04,  4.40s/it]

torch.Size([1, 3, 299, 299]) tensor([0])


 41%|████      | 298/734 [13:03<29:17,  4.03s/it]

torch.Size([1, 3, 299, 299]) tensor([0])


 41%|████      | 299/734 [13:07<29:00,  4.00s/it]

torch.Size([1, 3, 299, 299]) tensor([0])


 41%|████      | 300/734 [13:12<31:34,  4.37s/it]

torch.Size([1, 3, 299, 299]) tensor([0])


 41%|████      | 301/734 [13:15<28:31,  3.95s/it]

torch.Size([1, 3, 299, 299]) tensor([0])


 41%|████      | 302/734 [13:19<26:33,  3.69s/it]

torch.Size([1, 3, 299, 299]) tensor([0])


 41%|████▏     | 303/734 [13:23<28:28,  3.96s/it]

torch.Size([1, 3, 299, 299]) tensor([0])


 41%|████▏     | 304/734 [13:27<27:45,  3.87s/it]

torch.Size([1, 3, 299, 299]) tensor([0])


 42%|████▏     | 305/734 [13:31<28:16,  3.96s/it]

torch.Size([1, 3, 299, 299]) tensor([0])


 42%|████▏     | 306/734 [13:35<27:57,  3.92s/it]

torch.Size([1, 3, 299, 299]) tensor([0])


 42%|████▏     | 307/734 [13:39<29:16,  4.11s/it]

torch.Size([1, 3, 299, 299]) tensor([0])


 42%|████▏     | 308/734 [13:43<27:09,  3.82s/it]

torch.Size([1, 3, 299, 299]) tensor([0])


 42%|████▏     | 309/734 [13:48<30:11,  4.26s/it]

torch.Size([1, 3, 299, 299]) tensor([0])


 42%|████▏     | 310/734 [13:51<28:36,  4.05s/it]

torch.Size([1, 3, 299, 299]) tensor([0])


 42%|████▏     | 311/734 [13:55<28:22,  4.02s/it]

torch.Size([1, 3, 299, 299]) tensor([0])


 43%|████▎     | 312/734 [13:58<26:16,  3.73s/it]

torch.Size([1, 3, 299, 299]) tensor([0])


 43%|████▎     | 313/734 [14:02<25:44,  3.67s/it]

torch.Size([1, 3, 299, 299]) tensor([0])


 43%|████▎     | 314/734 [14:06<27:22,  3.91s/it]

torch.Size([1, 3, 299, 299]) tensor([0])


 43%|████▎     | 315/734 [14:10<26:45,  3.83s/it]

torch.Size([1, 3, 299, 299]) tensor([0])


 43%|████▎     | 316/734 [14:14<27:54,  4.00s/it]

torch.Size([1, 3, 299, 299]) tensor([0])


 43%|████▎     | 317/734 [14:22<34:59,  5.04s/it]

torch.Size([1, 3, 299, 299]) tensor([0])


 43%|████▎     | 318/734 [14:28<36:27,  5.26s/it]

torch.Size([1, 3, 299, 299]) tensor([0])


 43%|████▎     | 319/734 [14:33<35:47,  5.17s/it]

torch.Size([1, 3, 299, 299]) tensor([0])


 44%|████▎     | 320/734 [14:37<33:29,  4.86s/it]

torch.Size([1, 3, 299, 299]) tensor([0])


 44%|████▎     | 321/734 [14:42<34:01,  4.94s/it]

torch.Size([1, 3, 299, 299]) tensor([0])


 44%|████▍     | 322/734 [14:48<36:28,  5.31s/it]

torch.Size([1, 3, 299, 299]) tensor([0])


 44%|████▍     | 323/734 [14:55<39:36,  5.78s/it]

torch.Size([1, 3, 299, 299]) tensor([0])


 44%|████▍     | 324/734 [14:59<36:21,  5.32s/it]

torch.Size([1, 3, 299, 299]) tensor([0])


 44%|████▍     | 325/734 [15:03<32:42,  4.80s/it]

torch.Size([1, 3, 299, 299]) tensor([0])


 44%|████▍     | 326/734 [15:09<34:36,  5.09s/it]

torch.Size([1, 3, 299, 299]) tensor([0])


 45%|████▍     | 327/734 [15:13<32:27,  4.78s/it]

torch.Size([1, 3, 299, 299]) tensor([0])


 45%|████▍     | 328/734 [15:17<30:50,  4.56s/it]

torch.Size([1, 3, 299, 299]) tensor([0])


 45%|████▍     | 329/734 [15:20<29:01,  4.30s/it]

torch.Size([1, 3, 299, 299]) tensor([0])


 45%|████▍     | 330/734 [15:24<28:27,  4.23s/it]

torch.Size([1, 3, 299, 299]) tensor([0])


 45%|████▌     | 331/734 [15:29<28:53,  4.30s/it]

torch.Size([1, 3, 299, 299]) tensor([0])


 45%|████▌     | 332/734 [15:33<28:09,  4.20s/it]

torch.Size([1, 3, 299, 299]) tensor([0])


 45%|████▌     | 333/734 [15:36<25:54,  3.88s/it]

torch.Size([1, 3, 299, 299]) tensor([0])


 46%|████▌     | 334/734 [15:40<27:02,  4.06s/it]

torch.Size([1, 3, 299, 299]) tensor([0])


 46%|████▌     | 335/734 [15:45<28:49,  4.33s/it]

torch.Size([1, 3, 299, 299]) tensor([0])


 46%|████▌     | 336/734 [15:49<28:15,  4.26s/it]

torch.Size([1, 3, 299, 299]) tensor([0])


 46%|████▌     | 337/734 [15:55<29:49,  4.51s/it]

torch.Size([1, 3, 299, 299]) tensor([0])


 46%|████▌     | 338/734 [16:00<32:03,  4.86s/it]

torch.Size([1, 3, 299, 299]) tensor([0])


 46%|████▌     | 339/734 [16:05<31:20,  4.76s/it]

torch.Size([1, 3, 299, 299]) tensor([0])


 46%|████▋     | 340/734 [16:11<34:18,  5.22s/it]

torch.Size([1, 3, 299, 299]) tensor([0])


 46%|████▋     | 341/734 [16:16<33:20,  5.09s/it]

torch.Size([1, 3, 299, 299]) tensor([0])


 47%|████▋     | 342/734 [16:20<30:38,  4.69s/it]

torch.Size([1, 3, 299, 299]) tensor([0])


 47%|████▋     | 343/734 [16:23<27:46,  4.26s/it]

torch.Size([1, 3, 299, 299]) tensor([0])


 47%|████▋     | 344/734 [16:27<26:37,  4.10s/it]

torch.Size([1, 3, 299, 299]) tensor([0])


 47%|████▋     | 345/734 [16:31<27:08,  4.19s/it]

torch.Size([1, 3, 299, 299]) tensor([0])


 47%|████▋     | 346/734 [16:38<32:51,  5.08s/it]

torch.Size([1, 3, 299, 299]) tensor([0])


 47%|████▋     | 347/734 [16:46<38:01,  5.89s/it]

torch.Size([1, 3, 299, 299]) tensor([0])


 47%|████▋     | 348/734 [16:49<33:19,  5.18s/it]

torch.Size([1, 3, 299, 299]) tensor([0])


 48%|████▊     | 349/734 [16:54<32:14,  5.02s/it]

torch.Size([1, 3, 299, 299]) tensor([0])


 48%|████▊     | 350/734 [17:02<38:30,  6.02s/it]

torch.Size([1, 3, 299, 299]) tensor([0])


 48%|████▊     | 351/734 [17:12<45:45,  7.17s/it]

torch.Size([1, 3, 299, 299]) tensor([0])


 48%|████▊     | 352/734 [17:18<42:53,  6.74s/it]

torch.Size([1, 3, 299, 299]) tensor([0])


 48%|████▊     | 353/734 [17:22<37:15,  5.87s/it]

torch.Size([1, 3, 299, 299]) tensor([0])


 48%|████▊     | 354/734 [17:27<36:34,  5.77s/it]

torch.Size([1, 3, 299, 299]) tensor([0])


 48%|████▊     | 355/734 [17:39<47:40,  7.55s/it]

torch.Size([1, 3, 299, 299]) tensor([0])


 49%|████▊     | 356/734 [17:45<44:54,  7.13s/it]

torch.Size([1, 3, 299, 299]) tensor([0])


 49%|████▊     | 357/734 [17:50<40:54,  6.51s/it]

torch.Size([1, 3, 299, 299]) tensor([0])


 49%|████▉     | 358/734 [17:54<36:09,  5.77s/it]

torch.Size([1, 3, 299, 299]) tensor([0])


 49%|████▉     | 359/734 [17:59<34:32,  5.53s/it]

torch.Size([1, 3, 299, 299]) tensor([0])


 49%|████▉     | 360/734 [18:06<36:03,  5.78s/it]

torch.Size([1, 3, 299, 299]) tensor([0])


 49%|████▉     | 361/734 [18:09<32:09,  5.17s/it]

torch.Size([1, 3, 299, 299]) tensor([0])


 49%|████▉     | 362/734 [18:13<28:46,  4.64s/it]

torch.Size([1, 3, 299, 299]) tensor([0])


 49%|████▉     | 363/734 [18:16<26:27,  4.28s/it]

torch.Size([1, 3, 299, 299]) tensor([0])


 50%|████▉     | 364/734 [18:20<24:57,  4.05s/it]

torch.Size([1, 3, 299, 299]) tensor([0])


 50%|████▉     | 365/734 [18:25<26:10,  4.26s/it]

torch.Size([1, 3, 299, 299]) tensor([0])


 50%|████▉     | 366/734 [18:28<25:19,  4.13s/it]

torch.Size([1, 3, 299, 299]) tensor([0])


 50%|█████     | 367/734 [18:32<24:17,  3.97s/it]

torch.Size([1, 3, 299, 299]) tensor([0])


 50%|█████     | 368/734 [18:37<26:33,  4.35s/it]

torch.Size([1, 3, 299, 299]) tensor([0])


 50%|█████     | 369/734 [18:41<25:38,  4.22s/it]

torch.Size([1, 3, 299, 299]) tensor([0])


 50%|█████     | 370/734 [18:45<24:07,  3.98s/it]

torch.Size([1, 3, 299, 299]) tensor([0])


 51%|█████     | 371/734 [18:51<27:56,  4.62s/it]

torch.Size([1, 3, 299, 299]) tensor([0])


 51%|█████     | 372/734 [19:02<39:08,  6.49s/it]

torch.Size([1, 3, 299, 299]) tensor([0])


 51%|█████     | 373/734 [19:10<42:12,  7.01s/it]

torch.Size([1, 3, 299, 299]) tensor([0])


 51%|█████     | 374/734 [19:15<39:25,  6.57s/it]

torch.Size([1, 3, 299, 299]) tensor([0])


 51%|█████     | 375/734 [19:22<39:48,  6.65s/it]

torch.Size([1, 3, 299, 299]) tensor([0])


 51%|█████     | 376/734 [19:27<36:31,  6.12s/it]

torch.Size([1, 3, 299, 299]) tensor([0])


 51%|█████▏    | 377/734 [19:31<33:19,  5.60s/it]

torch.Size([1, 3, 299, 299]) tensor([0])


 51%|█████▏    | 378/734 [19:36<31:42,  5.35s/it]

torch.Size([1, 3, 299, 299]) tensor([0])


 52%|█████▏    | 379/734 [19:40<29:14,  4.94s/it]

torch.Size([1, 3, 299, 299]) tensor([0])


 52%|█████▏    | 380/734 [19:45<28:43,  4.87s/it]

torch.Size([1, 3, 299, 299]) tensor([0])


 52%|█████▏    | 381/734 [19:48<26:12,  4.46s/it]

torch.Size([1, 3, 299, 299]) tensor([0])


 52%|█████▏    | 382/734 [19:53<26:10,  4.46s/it]

torch.Size([1, 3, 299, 299]) tensor([0])


 52%|█████▏    | 383/734 [19:59<28:57,  4.95s/it]

torch.Size([1, 3, 299, 299]) tensor([0])


 52%|█████▏    | 384/734 [20:02<26:23,  4.52s/it]

torch.Size([1, 3, 299, 299]) tensor([0])


 52%|█████▏    | 385/734 [20:08<27:31,  4.73s/it]

torch.Size([1, 3, 299, 299]) tensor([0])


 53%|█████▎    | 386/734 [20:13<29:08,  5.02s/it]

torch.Size([1, 3, 299, 299]) tensor([0])


 53%|█████▎    | 387/734 [20:18<28:44,  4.97s/it]

torch.Size([1, 3, 299, 299]) tensor([0])


 53%|█████▎    | 388/734 [20:25<31:59,  5.55s/it]

torch.Size([1, 3, 299, 299]) tensor([0])


 53%|█████▎    | 389/734 [20:30<30:44,  5.35s/it]

torch.Size([1, 3, 299, 299]) tensor([0])


 53%|█████▎    | 390/734 [20:34<28:00,  4.89s/it]

torch.Size([1, 3, 299, 299]) tensor([0])


 53%|█████▎    | 391/734 [20:38<26:03,  4.56s/it]

torch.Size([1, 3, 299, 299]) tensor([0])


 53%|█████▎    | 392/734 [20:44<28:43,  5.04s/it]

torch.Size([1, 3, 299, 299]) tensor([0])


 54%|█████▎    | 393/734 [20:49<29:10,  5.13s/it]

torch.Size([1, 3, 299, 299]) tensor([0])


 54%|█████▎    | 394/734 [20:55<30:42,  5.42s/it]

torch.Size([1, 3, 299, 299]) tensor([0])


 54%|█████▍    | 395/734 [20:59<28:40,  5.07s/it]

torch.Size([1, 3, 299, 299]) tensor([0])


 54%|█████▍    | 396/734 [21:05<29:08,  5.17s/it]

torch.Size([1, 3, 299, 299]) tensor([0])


 54%|█████▍    | 397/734 [21:10<28:33,  5.09s/it]

torch.Size([1, 3, 299, 299]) tensor([0])


 54%|█████▍    | 398/734 [21:17<31:45,  5.67s/it]

torch.Size([1, 3, 299, 299]) tensor([0])


 54%|█████▍    | 399/734 [21:23<32:05,  5.75s/it]

torch.Size([1, 3, 299, 299]) tensor([0])


 54%|█████▍    | 400/734 [21:28<31:06,  5.59s/it]

torch.Size([1, 3, 299, 299]) tensor([0])


 55%|█████▍    | 401/734 [21:35<32:56,  5.94s/it]

torch.Size([1, 3, 299, 299]) tensor([0])


 55%|█████▍    | 402/734 [21:42<35:07,  6.35s/it]

torch.Size([1, 3, 299, 299]) tensor([0])


 55%|█████▍    | 403/734 [21:47<32:48,  5.95s/it]

torch.Size([1, 3, 299, 299]) tensor([0])


 55%|█████▌    | 404/734 [21:51<29:20,  5.33s/it]

torch.Size([1, 3, 299, 299]) tensor([0])


 55%|█████▌    | 405/734 [21:55<26:36,  4.85s/it]

torch.Size([1, 3, 299, 299]) tensor([0])


 55%|█████▌    | 406/734 [21:59<25:58,  4.75s/it]

torch.Size([1, 3, 299, 299]) tensor([0])


 55%|█████▌    | 407/734 [22:03<24:58,  4.58s/it]

torch.Size([1, 3, 299, 299]) tensor([0])


 56%|█████▌    | 408/734 [22:12<31:18,  5.76s/it]

torch.Size([1, 3, 299, 299]) tensor([0])


 56%|█████▌    | 409/734 [22:19<33:26,  6.17s/it]

torch.Size([1, 3, 299, 299]) tensor([0])


 56%|█████▌    | 410/734 [22:27<35:42,  6.61s/it]

torch.Size([1, 3, 299, 299]) tensor([0])


 56%|█████▌    | 411/734 [22:31<31:12,  5.80s/it]

torch.Size([1, 3, 299, 299]) tensor([0])


 56%|█████▌    | 412/734 [22:36<30:57,  5.77s/it]

torch.Size([1, 3, 299, 299]) tensor([0])


 56%|█████▋    | 413/734 [22:43<31:48,  5.95s/it]

torch.Size([1, 3, 299, 299]) tensor([0])


 56%|█████▋    | 414/734 [22:51<35:13,  6.60s/it]

torch.Size([1, 3, 299, 299]) tensor([0])


 57%|█████▋    | 415/734 [22:57<33:52,  6.37s/it]

torch.Size([1, 3, 299, 299]) tensor([0])


 57%|█████▋    | 416/734 [23:03<34:31,  6.52s/it]

torch.Size([1, 3, 299, 299]) tensor([0])


 57%|█████▋    | 417/734 [23:12<38:03,  7.20s/it]

torch.Size([1, 3, 299, 299]) tensor([0])


 57%|█████▋    | 418/734 [23:21<39:55,  7.58s/it]

torch.Size([1, 3, 299, 299]) tensor([0])


 57%|█████▋    | 419/734 [23:29<41:47,  7.96s/it]

torch.Size([1, 3, 299, 299]) tensor([0])


 57%|█████▋    | 420/734 [23:38<42:41,  8.16s/it]

torch.Size([1, 3, 299, 299]) tensor([0])


 57%|█████▋    | 421/734 [23:46<42:00,  8.05s/it]

torch.Size([1, 3, 299, 299]) tensor([0])


 57%|█████▋    | 422/734 [23:51<36:30,  7.02s/it]

torch.Size([1, 3, 299, 299]) tensor([0])


 58%|█████▊    | 423/734 [23:55<32:26,  6.26s/it]

torch.Size([1, 3, 299, 299]) tensor([0])


 58%|█████▊    | 424/734 [23:59<28:58,  5.61s/it]

torch.Size([1, 3, 299, 299]) tensor([0])


 58%|█████▊    | 425/734 [24:05<29:05,  5.65s/it]

torch.Size([1, 3, 299, 299]) tensor([0])


 58%|█████▊    | 426/734 [24:09<26:34,  5.18s/it]

torch.Size([1, 3, 299, 299]) tensor([0])


 58%|█████▊    | 427/734 [24:14<26:37,  5.20s/it]

torch.Size([1, 3, 299, 299]) tensor([0])


 58%|█████▊    | 428/734 [24:22<30:08,  5.91s/it]

torch.Size([1, 3, 299, 299]) tensor([0])


 58%|█████▊    | 429/734 [24:29<32:41,  6.43s/it]

torch.Size([1, 3, 299, 299]) tensor([0])


 59%|█████▊    | 430/734 [24:38<35:46,  7.06s/it]

torch.Size([1, 3, 299, 299]) tensor([0])


 59%|█████▊    | 431/734 [24:43<32:01,  6.34s/it]

torch.Size([1, 3, 299, 299]) tensor([0])


 59%|█████▉    | 432/734 [24:47<29:03,  5.77s/it]

torch.Size([1, 3, 299, 299]) tensor([0])


 59%|█████▉    | 433/734 [24:52<28:15,  5.63s/it]

torch.Size([1, 3, 299, 299]) tensor([0])


 59%|█████▉    | 434/734 [24:59<29:32,  5.91s/it]

torch.Size([1, 3, 299, 299]) tensor([0])


 59%|█████▉    | 435/734 [25:06<31:44,  6.37s/it]

torch.Size([1, 3, 299, 299]) tensor([0])


 59%|█████▉    | 436/734 [25:13<31:43,  6.39s/it]

torch.Size([1, 3, 299, 299]) tensor([0])


 60%|█████▉    | 437/734 [25:20<32:59,  6.66s/it]

torch.Size([1, 3, 299, 299]) tensor([0])


 60%|█████▉    | 438/734 [25:28<34:43,  7.04s/it]

torch.Size([1, 3, 299, 299]) tensor([0])


 60%|█████▉    | 439/734 [25:35<34:22,  6.99s/it]

torch.Size([1, 3, 299, 299]) tensor([0])


 60%|█████▉    | 440/734 [25:42<34:45,  7.09s/it]

torch.Size([1, 3, 299, 299]) tensor([0])


 60%|██████    | 441/734 [25:49<34:33,  7.08s/it]

torch.Size([1, 3, 299, 299]) tensor([0])


 60%|██████    | 442/734 [25:56<33:19,  6.85s/it]

torch.Size([1, 3, 299, 299]) tensor([0])


 60%|██████    | 443/734 [26:03<34:15,  7.06s/it]

torch.Size([1, 3, 299, 299]) tensor([0])


 60%|██████    | 444/734 [26:12<36:09,  7.48s/it]

torch.Size([1, 3, 299, 299]) tensor([0])


 61%|██████    | 445/734 [26:20<36:51,  7.65s/it]

torch.Size([1, 3, 299, 299]) tensor([0])


 61%|██████    | 446/734 [26:27<36:28,  7.60s/it]

torch.Size([1, 3, 299, 299]) tensor([0])


 61%|██████    | 447/734 [26:34<35:18,  7.38s/it]

torch.Size([1, 3, 299, 299]) tensor([0])


 61%|██████    | 448/734 [26:40<33:56,  7.12s/it]

torch.Size([1, 3, 299, 299]) tensor([0])


 61%|██████    | 449/734 [26:49<35:09,  7.40s/it]

torch.Size([1, 3, 299, 299]) tensor([0])


 61%|██████▏   | 450/734 [26:55<34:23,  7.26s/it]

torch.Size([1, 3, 299, 299]) tensor([0])


 61%|██████▏   | 451/734 [27:01<31:42,  6.72s/it]

torch.Size([1, 3, 299, 299]) tensor([0])


 62%|██████▏   | 452/734 [27:08<32:07,  6.84s/it]

torch.Size([1, 3, 299, 299]) tensor([0])


 62%|██████▏   | 453/734 [27:16<33:06,  7.07s/it]

torch.Size([1, 3, 299, 299]) tensor([0])


 62%|██████▏   | 454/734 [27:22<31:33,  6.76s/it]

torch.Size([1, 3, 299, 299]) tensor([0])


 62%|██████▏   | 455/734 [27:32<36:45,  7.91s/it]

torch.Size([1, 3, 299, 299]) tensor([0])


 62%|██████▏   | 456/734 [27:47<46:08,  9.96s/it]

torch.Size([1, 3, 299, 299]) tensor([0])


 62%|██████▏   | 457/734 [27:54<41:33,  9.00s/it]

torch.Size([1, 3, 299, 299]) tensor([0])


 62%|██████▏   | 458/734 [28:00<38:11,  8.30s/it]

torch.Size([1, 3, 299, 299]) tensor([0])


 63%|██████▎   | 459/734 [28:10<39:07,  8.54s/it]

torch.Size([1, 3, 299, 299]) tensor([0])


 63%|██████▎   | 460/734 [28:16<36:47,  8.06s/it]

torch.Size([1, 3, 299, 299]) tensor([0])


 63%|██████▎   | 461/734 [28:23<34:49,  7.65s/it]

torch.Size([1, 3, 299, 299]) tensor([0])


 63%|██████▎   | 462/734 [28:32<35:56,  7.93s/it]

torch.Size([1, 3, 299, 299]) tensor([0])


 63%|██████▎   | 463/734 [28:39<34:53,  7.72s/it]

torch.Size([1, 3, 299, 299]) tensor([0])


 63%|██████▎   | 464/734 [28:47<35:12,  7.82s/it]

torch.Size([1, 3, 299, 299]) tensor([0])


 63%|██████▎   | 465/734 [28:54<34:22,  7.67s/it]

torch.Size([1, 3, 299, 299]) tensor([0])


 63%|██████▎   | 466/734 [29:02<34:14,  7.67s/it]

torch.Size([1, 3, 299, 299]) tensor([0])


 64%|██████▎   | 467/734 [29:10<34:58,  7.86s/it]

torch.Size([1, 3, 299, 299]) tensor([0])


 64%|██████▍   | 468/734 [29:19<35:57,  8.11s/it]

torch.Size([1, 3, 299, 299]) tensor([0])


 64%|██████▍   | 469/734 [29:27<36:06,  8.17s/it]

torch.Size([1, 3, 299, 299]) tensor([0])


 64%|██████▍   | 470/734 [29:40<42:23,  9.64s/it]

torch.Size([1, 3, 299, 299]) tensor([0])


 64%|██████▍   | 471/734 [29:49<41:08,  9.38s/it]

torch.Size([1, 3, 299, 299]) tensor([0])


 64%|██████▍   | 472/734 [29:56<37:03,  8.49s/it]

torch.Size([1, 3, 299, 299]) tensor([0])


 64%|██████▍   | 473/734 [30:03<35:48,  8.23s/it]

torch.Size([1, 3, 299, 299]) tensor([0])


 65%|██████▍   | 474/734 [30:13<37:37,  8.68s/it]

torch.Size([1, 3, 299, 299]) tensor([0])


 65%|██████▍   | 475/734 [30:19<34:21,  7.96s/it]

torch.Size([1, 3, 299, 299]) tensor([0])


 65%|██████▍   | 476/734 [30:24<30:39,  7.13s/it]

torch.Size([1, 3, 299, 299]) tensor([0])


 65%|██████▍   | 477/734 [30:31<29:56,  6.99s/it]

torch.Size([1, 3, 299, 299]) tensor([0])


 65%|██████▌   | 478/734 [30:39<30:33,  7.16s/it]

torch.Size([1, 3, 299, 299]) tensor([0])


 65%|██████▌   | 479/734 [30:48<33:15,  7.83s/it]

torch.Size([1, 3, 299, 299]) tensor([0])


 65%|██████▌   | 480/734 [30:56<33:02,  7.80s/it]

torch.Size([1, 3, 299, 299]) tensor([0])


 66%|██████▌   | 481/734 [31:03<32:04,  7.61s/it]

torch.Size([1, 3, 299, 299]) tensor([0])


 66%|██████▌   | 482/734 [31:13<34:31,  8.22s/it]

torch.Size([1, 3, 299, 299]) tensor([0])


 66%|██████▌   | 483/734 [31:27<41:46,  9.99s/it]

torch.Size([1, 3, 299, 299]) tensor([0])


 66%|██████▌   | 484/734 [31:37<41:29,  9.96s/it]

torch.Size([1, 3, 299, 299]) tensor([0])


 66%|██████▌   | 485/734 [31:45<38:58,  9.39s/it]

torch.Size([1, 3, 299, 299]) tensor([0])


 66%|██████▌   | 486/734 [31:56<41:06,  9.94s/it]

torch.Size([1, 3, 299, 299]) tensor([0])


 66%|██████▋   | 487/734 [32:05<39:17,  9.55s/it]

torch.Size([1, 3, 299, 299]) tensor([0])


 66%|██████▋   | 488/734 [32:12<36:31,  8.91s/it]

torch.Size([1, 3, 299, 299]) tensor([0])


 67%|██████▋   | 489/734 [32:21<35:57,  8.81s/it]

torch.Size([1, 3, 299, 299]) tensor([0])


 67%|██████▋   | 490/734 [32:29<35:28,  8.72s/it]

torch.Size([1, 3, 299, 299]) tensor([0])


 67%|██████▋   | 491/734 [32:44<42:27, 10.48s/it]

torch.Size([1, 3, 299, 299]) tensor([0])


 67%|██████▋   | 492/734 [32:58<47:31, 11.78s/it]

torch.Size([1, 3, 299, 299]) tensor([0])


 67%|██████▋   | 493/734 [33:07<43:55, 10.93s/it]

torch.Size([1, 3, 299, 299]) tensor([0])


 67%|██████▋   | 494/734 [33:16<41:13, 10.31s/it]

torch.Size([1, 3, 299, 299]) tensor([0])


 67%|██████▋   | 495/734 [33:25<39:38,  9.95s/it]

torch.Size([1, 3, 299, 299]) tensor([0])


 68%|██████▊   | 496/734 [33:33<37:18,  9.41s/it]

torch.Size([1, 3, 299, 299]) tensor([0])


 68%|██████▊   | 497/734 [33:42<35:50,  9.08s/it]

torch.Size([1, 3, 299, 299]) tensor([0])


 68%|██████▊   | 498/734 [33:52<36:37,  9.31s/it]

torch.Size([1, 3, 299, 299]) tensor([0])


 68%|██████▊   | 499/734 [34:00<35:14,  9.00s/it]

torch.Size([1, 3, 299, 299]) tensor([0])


 68%|██████▊   | 500/734 [34:13<39:58, 10.25s/it]

torch.Size([1, 3, 299, 299]) tensor([0])


 68%|██████▊   | 501/734 [34:30<48:06, 12.39s/it]

torch.Size([1, 3, 299, 299]) tensor([0])


 68%|██████▊   | 502/734 [34:44<49:21, 12.77s/it]

torch.Size([1, 3, 299, 299]) tensor([0])


In [24]:
[lf.shape for lf in unlabeled_celeba_features]

[torch.Size([66603, 32]),
 torch.Size([64827, 64]),
 torch.Size([15987, 80]),
 torch.Size([3675, 288]),
 torch.Size([867, 768]),
 torch.Size([192, 2048]),
 torch.Size([3, 1000])]

In [None]:
#TODO: modify the get_inception_features function to have a "don't flatten" mode for these?
#Or just feed them in one at a time so we don't care about the flattening. (So only mess with
#it if we need batch size > 1)
#celeba_features = get_inception_features(labeled_celeba_loader)

In [None]:
#progan_features_50 = get_inception_features(labeled_progan_loader)

In [None]:
# Features (for single image): #layers x (H*W for that layer) x (C for that layer)
# Reference set (for N comparison images): # layers x (N*H*W for that layer) x (C for that layer)
def layerwise_nn_features(features, reference_set):
    assert(len(features) == len(reference_set))
    L = len(features)
    mean_layer_closest_dists = torch.zeros(L)
    
    for l in range(L):
        lf = features[l] #layer features
        rlf = reference_set[l] #reference layer features
        
        #layer is HxWxC
        #rlf[i] is NxC
        H,W,C = lf.shape
        N,C2 = rlf.shape
        assert(C == C2)

        x = lf.reshape(H*W, 1, C)
        cur_refs = rlf.reshape(1, N, C)

        diffs = x - cur_refs
        assert(diffs.shape == (H*W, N, C))

        sqr_dists = torch.sum(diffs**2, dim=2)
        assert(sqr_dists.shape == (H*W, N))

        min_sqr_dists = torch.min(sqr_dists, dim=1)
        assert(min_dists.shape == (H*W))
        
        min_dists = torch.sqrt(min_sqr_dists)
        assert(min_dists.shape == (H*W))
        
        mean_layer_closest_dists[l] = torch.mean(min_dists) 
    
    return mean_layer_closest_dists

In [None]:
labeled_celeba_x = []
labeled_celeba_y = []
# TODO: Pull % rated as real for each image!
for x,y in tqdm(labeled_celeba_loader):
    cur_features = layerwise_nn_features(x, unlabeled_celeba_features)
    labeled_celeba_x.append(cur_features)
    
    # Now pull the % label
    labeled_celeba_y.append(pct_real_votes)
    

In [None]:
labeled_progan_x = []
labeled_progan_y = []
# TODO: Pull % rated as real for each image!
for x,y in tqdm(labeled_progan_loader):
    cur_features = layerwise_nn_features(x, unlabeled_celeba_features)
    labeled_progan_x.append(cur_features)
    
    # Now pull the % label
    labeled_progan_y.append(pct_real_votes)
    

In [None]:
#TODO: Break the features/labels into train/val/test and train a logistic regression model;
#see how well it does out of sample!