In [1]:
import copy
import multiprocessing as mp
import random
import sys

# Make the local fast-autoaugment checkout importable (FastAutoAugment.*).
sys.path.append('/home/ubuntu/fast-autoaugment')

import matplotlib.pyplot as plt
import numpy as np  # explicit: `np` was previously only reachable via the star import below
import torch
from theconf import Config as C
from torch.nn.parallel import DistributedDataParallel
from torch.nn.parallel.data_parallel import DataParallel
from torchvision.transforms import transforms
from tqdm import tqdm

from FastAutoAugment.networks import get_model
from FastAutoAugment.datasets import CIFAR10_mod
from FastAutoAugment.metrics import accuracy
from FastAutoAugment.hardness_measures import AVH
# Star import kept last, as originally; presumably supplies apply_augment used by Augmentation2.
from FastAutoAugment.augmentations import *

In [2]:
_ = C('/home/ubuntu/fast-autoaugment/confs/test.yaml')

In [3]:
model = get_model(C.get()['model'], 10, local_rank=-1)

In [5]:
save_path = '/efs-cotton/outputs/fast-autoaugment/confs/hardnessaware/wresnet28x10_rcifar-AA-rerun6/test.pth'
data = torch.load(save_path)
key = 'model' if 'model' in data else 'state_dict'

if 'epoch' not in data:
    model.load_state_dict(data)
else:
    if not isinstance(model, (DataParallel, DistributedDataParallel)):
        model.load_state_dict({k.replace('module.', ''): v for k, v in data[key].items()})
    else:
        model.load_state_dict({k if 'module.' in k else 'module.'+k: v for k, v in data[key].items()})

In [7]:
data['epoch']

50

In [8]:
# Candidate op names (geometric first, then photometric / colour ops).
augmentations = [
    "ShearX", "ShearY", "Rotate",
    "AutoContrast", "Invert", "Equalize", "Solarize",
    "Contrast", "Color", "Brightness", "Sharpness",
    "Posterize2", "TranslateXAbs", "TranslateYAbs",
]

In [9]:
# Draw 10 random sub-policies of 5 distinct ops each (p=0.5, random level).
# NOTE(review): unseeded, so not reproducible; the next cell replaces `policy`
# with a fixed list anyway.
policy = [
    [(aug, 0.5, 'random') for aug in random.sample(augmentations, 5)]
    for i in range(10)
]

In [10]:
# Fixed set of 10 sub-policies. Every op is applied with probability 0.5 at a
# level drawn at random per call; only the op names differ between entries.
policy = [
    [(op_name, 0.5, 'random') for op_name in op_names]
    for op_names in (
        ('Contrast', 'Rotate', 'TranslateXAbs', 'Solarize', 'Invert'),
        ('Rotate', 'ShearX', 'TranslateXAbs', 'TranslateYAbs', 'ShearY'),
        ('Solarize', 'Sharpness', 'Rotate', 'TranslateYAbs', 'Posterize2'),
        ('TranslateYAbs', 'TranslateXAbs', 'AutoContrast', 'Posterize2', 'Solarize'),
        ('TranslateYAbs', 'ShearY', 'Invert', 'Contrast', 'TranslateXAbs'),
        ('ShearX', 'Rotate', 'Invert', 'TranslateYAbs', 'ShearY'),
        ('Solarize', 'AutoContrast', 'Color', 'Posterize2', 'Brightness'),
        ('ShearY', 'Contrast', 'ShearX', 'TranslateYAbs', 'TranslateXAbs'),
        ('Solarize', 'Color', 'Brightness', 'TranslateXAbs', 'Equalize'),
        ('Brightness', 'Solarize', 'ShearX', 'Equalize', 'Invert'),
    )
]

In [11]:
class Augmentation2(object):
    """Apply a sequence of augmentation ops, each independently with its own probability.

    ``policies`` is a list of ``(name, probability, level)`` triples. On every
    call the triples are visited in order; each op is applied with probability
    ``probability`` via ``apply_augment(img, name, level)``. A level of
    ``'random'`` draws ``0.1 * k`` for a uniformly random ``k`` in 0..9 on each
    application; any other level is passed through unchanged.
    """

    def __init__(self, policies):
        self.policies = policies

    def __call__(self, img, hardness_score=None):
        # hardness_score is accepted but unused here (presumably to match a
        # caller's signature -- TODO confirm).
        for name, pr, level in self.policies:
            if random.random() > pr:
                continue  # op skipped with probability 1 - pr
            if level == 'random':
                # randrange(10) draws exactly like random.choice(range(10)) but
                # keeps working if the name `range` is rebound in the notebook.
                level = random.randrange(10) * 0.1
            img = apply_augment(img, name, level)
        return img

    def __repr__(self):
        # Shown when a Compose pipeline is displayed, so each pipeline's policy is visible.
        return '{}(policies={!r})'.format(self.__class__.__name__, self.policies)

In [12]:
# Per-channel (R, G, B) normalization statistics used by transforms.Normalize.
_CIFAR_MEAN = (0.4914, 0.4822, 0.4465)
_CIFAR_STD = (0.2023, 0.1994, 0.2010)
# Directory the CIFAR-10 files are stored in / downloaded to.
dataroot = "data"

In [13]:
# Standard CIFAR training pipeline: random 32x32 crop (4 px padding), random
# horizontal flip, tensor conversion, per-channel normalization.
# NOTE(review): RandomCrop/RandomHorizontalFlip make this "unaugmented" baseline
# stochastic, so baseline scores differ between passes -- if a deterministic
# baseline was intended, use ToTensor + Normalize only.
basic_steps = [
    transforms.RandomCrop(32, padding=4),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize(_CIFAR_MEAN, _CIFAR_STD),
]
basic_transform = transforms.Compose(basic_steps)
basic_dataset = CIFAR10_mod(root=dataroot, train=True, download=True,
                            transform=basic_transform)

Files already downloaded and verified


In [14]:
# One pipeline per sub-policy: a deep copy of basic_transform (so its own
# transform list is not mutated) with the policy's Augmentation2 prepended.
# Iterating over `policy` itself handles any number of sub-policies and does
# not depend on the builtin `range`.
policy_transforms = []
for sub_policy in policy:
    transform = copy.deepcopy(basic_transform)
    transform.transforms.insert(0, Augmentation2(sub_policy))
    policy_transforms.append(transform)

In [15]:
# Inspect pipeline 0: Augmentation2 should come first, ahead of crop/flip/ToTensor/Normalize.
list(policy_transforms[0].transforms)

[<__main__.Augmentation2 at 0x7f92abced790>,
 RandomCrop(size=(32, 32), padding=4),
 RandomHorizontalFlip(p=0.5),
 ToTensor(),
 Normalize(mean=(0.4914, 0.4822, 0.4465), std=(0.2023, 0.1994, 0.201))]

In [16]:
# Inspect pipeline 1: a different Augmentation2 instance, same trailing steps.
list(policy_transforms[1].transforms)

[<__main__.Augmentation2 at 0x7f92abced290>,
 RandomCrop(size=(32, 32), padding=4),
 RandomHorizontalFlip(p=0.5),
 ToTensor(),
 Normalize(mean=(0.4914, 0.4822, 0.4465), std=(0.2023, 0.1994, 0.201))]

In [17]:
# One training-set view per policy pipeline. All read the same files under
# `dataroot` and differ only in `transform`. Iterating over policy_transforms
# (instead of a hard-coded range(10)) follows however many policies exist.
policy_datasets = [
    CIFAR10_mod(root=dataroot, train=True, download=True, transform=policy_transform)
    for policy_transform in policy_transforms
]

Files already downloaded and verified
Files already downloaded and verified
Files already downloaded and verified
Files already downloaded and verified
Files already downloaded and verified
Files already downloaded and verified
Files already downloaded and verified
Files already downloaded and verified
Files already downloaded and verified
Files already downloaded and verified


In [18]:
import torch.utils.data

# Identical loader settings for the baseline and every policy.
# shuffle=False is required: later cells compare scores across loaders by
# position / dataset index.
loader_kwargs = dict(batch_size=16, shuffle=False, num_workers=0, drop_last=False)
basic_dataloader = torch.utils.data.DataLoader(basic_dataset, **loader_kwargs)
policy_dataloaders = [torch.utils.data.DataLoader(policy_dataset, **loader_kwargs)
                      for policy_dataset in policy_datasets]

In [19]:
# Baseline pass over the training set through basic_transform: collect the
# model's predictions and the embeddings it returns alongside them, plus
# labels and dataset indices, batch by batch (all kept on the GPU).
# The batch is bound to `images`, not `data`, so the loop does not overwrite
# (and the trailing `del` does not delete) any global `data`, such as a loaded
# checkpoint.
all_preds = []
all_embeddings = []
all_labels = []
all_indices = []
with torch.no_grad():
    model.eval()
    loader = tqdm(basic_dataloader, disable=False)
    for i, (images, label, index) in enumerate(loader):
        images, label = images.cuda(), label.cuda()
        preds, embeddings = model(images)

        all_preds.append(preds)
        all_embeddings.append(embeddings)
        all_labels.append(label)
        all_indices.append(index)
        # Accuracy of the current batch only -- a noisy sanity check, printed every 100 batches.
        top1, top5 = accuracy(preds, label, (1, 5))
        if i % 100 == 0:
            print("top1", top1, "top5", top5)
    del images, label, index, preds, embeddings

  0%|          | 3/3125 [00:01<33:38,  1.55it/s]

top1 tensor(0.7500, device='cuda:0') top5 tensor(1., device='cuda:0')


  3%|▎         | 101/3125 [00:10<05:42,  8.84it/s]

top1 tensor(0.6875, device='cuda:0') top5 tensor(0.9375, device='cuda:0')


  6%|▋         | 201/3125 [00:20<05:33,  8.77it/s]

top1 tensor(0.8125, device='cuda:0') top5 tensor(1., device='cuda:0')


 10%|▉         | 301/3125 [00:30<05:21,  8.77it/s]

top1 tensor(0.8750, device='cuda:0') top5 tensor(1., device='cuda:0')


 13%|█▎        | 401/3125 [00:40<05:12,  8.70it/s]

top1 tensor(0.5625, device='cuda:0') top5 tensor(0.9375, device='cuda:0')


 16%|█▌        | 501/3125 [00:50<05:00,  8.73it/s]

top1 tensor(0.7500, device='cuda:0') top5 tensor(1., device='cuda:0')


 19%|█▉        | 601/3125 [00:59<04:49,  8.72it/s]

top1 tensor(0.6875, device='cuda:0') top5 tensor(1., device='cuda:0')


 22%|██▏       | 701/3125 [01:09<04:38,  8.71it/s]

top1 tensor(0.8125, device='cuda:0') top5 tensor(1., device='cuda:0')


 26%|██▌       | 801/3125 [01:19<04:24,  8.80it/s]

top1 tensor(0.8750, device='cuda:0') top5 tensor(1., device='cuda:0')


 29%|██▉       | 901/3125 [01:29<04:13,  8.78it/s]

top1 tensor(0.6875, device='cuda:0') top5 tensor(1., device='cuda:0')


 32%|███▏      | 1001/3125 [01:39<04:00,  8.82it/s]

top1 tensor(0.6875, device='cuda:0') top5 tensor(1., device='cuda:0')


 35%|███▌      | 1101/3125 [01:49<03:52,  8.72it/s]

top1 tensor(0.8750, device='cuda:0') top5 tensor(1., device='cuda:0')


 38%|███▊      | 1201/3125 [01:58<03:39,  8.78it/s]

top1 tensor(0.7500, device='cuda:0') top5 tensor(1., device='cuda:0')


 42%|████▏     | 1301/3125 [02:08<03:27,  8.81it/s]

top1 tensor(0.5625, device='cuda:0') top5 tensor(0.9375, device='cuda:0')


 45%|████▍     | 1401/3125 [02:18<03:16,  8.75it/s]

top1 tensor(0.8125, device='cuda:0') top5 tensor(1., device='cuda:0')


 48%|████▊     | 1501/3125 [02:28<03:06,  8.73it/s]

top1 tensor(0.8750, device='cuda:0') top5 tensor(0.9375, device='cuda:0')


 51%|█████     | 1601/3125 [02:38<02:54,  8.76it/s]

top1 tensor(0.8125, device='cuda:0') top5 tensor(0.9375, device='cuda:0')


 54%|█████▍    | 1701/3125 [02:48<02:40,  8.89it/s]

top1 tensor(0.6875, device='cuda:0') top5 tensor(1., device='cuda:0')


 58%|█████▊    | 1801/3125 [02:57<02:30,  8.79it/s]

top1 tensor(0.7500, device='cuda:0') top5 tensor(1., device='cuda:0')


 61%|██████    | 1901/3125 [03:07<02:20,  8.71it/s]

top1 tensor(0.6875, device='cuda:0') top5 tensor(0.9375, device='cuda:0')


 64%|██████▍   | 2001/3125 [03:17<02:09,  8.70it/s]

top1 tensor(0.6875, device='cuda:0') top5 tensor(0.9375, device='cuda:0')


 67%|██████▋   | 2101/3125 [03:27<01:55,  8.85it/s]

top1 tensor(0.5625, device='cuda:0') top5 tensor(1., device='cuda:0')


 70%|███████   | 2201/3125 [03:37<01:44,  8.80it/s]

top1 tensor(0.6875, device='cuda:0') top5 tensor(0.9375, device='cuda:0')


 74%|███████▎  | 2301/3125 [03:47<01:34,  8.76it/s]

top1 tensor(0.6875, device='cuda:0') top5 tensor(1., device='cuda:0')


 77%|███████▋  | 2401/3125 [03:57<01:22,  8.79it/s]

top1 tensor(0.7500, device='cuda:0') top5 tensor(1., device='cuda:0')


 80%|████████  | 2501/3125 [04:06<01:11,  8.77it/s]

top1 tensor(0.8125, device='cuda:0') top5 tensor(1., device='cuda:0')


 83%|████████▎ | 2601/3125 [04:16<00:59,  8.76it/s]

top1 tensor(0.8750, device='cuda:0') top5 tensor(1., device='cuda:0')


 86%|████████▋ | 2701/3125 [04:26<00:48,  8.80it/s]

top1 tensor(0.8125, device='cuda:0') top5 tensor(1., device='cuda:0')


 90%|████████▉ | 2801/3125 [04:36<00:36,  8.89it/s]

top1 tensor(0.6875, device='cuda:0') top5 tensor(0.9375, device='cuda:0')


 93%|█████████▎| 2901/3125 [04:46<00:25,  8.87it/s]

top1 tensor(0.7500, device='cuda:0') top5 tensor(1., device='cuda:0')


 96%|█████████▌| 3001/3125 [04:55<00:14,  8.84it/s]

top1 tensor(0.7500, device='cuda:0') top5 tensor(0.9375, device='cuda:0')


 99%|█████████▉| 3101/3125 [05:05<00:02,  8.82it/s]

top1 tensor(0.7500, device='cuda:0') top5 tensor(0.9375, device='cuda:0')


100%|██████████| 3125/3125 [05:08<00:00, 10.15it/s]


In [20]:
avh = AVH()

In [21]:
unaugmented_hardness = avh(model=model, embeddings=torch.cat(all_embeddings), 
                           targets=torch.cat(all_labels))

RuntimeError: CUDA out of memory. Tried to allocate 124.00 MiB (GPU 0; 7.44 GiB total capacity; 667.95 MiB already allocated; 49.00 MiB free; 730.00 MiB reserved in total by PyTorch)

In [None]:
# Move baseline scores and dataset indices to NumPy for sorting/indexing below.
# Guarded so re-running the cell is a no-op instead of calling .cpu() on an
# ndarray or torch.cat on an already-concatenated array.
if torch.is_tensor(unaugmented_hardness):
    unaugmented_hardness = unaugmented_hardness.cpu().numpy()
if isinstance(all_indices, list):
    all_indices = torch.cat(all_indices).cpu().numpy()

In [None]:
combined = list(zip(*sorted(zip(unaugmented_hardness, all_indices))))

In [None]:
sorted_indices = list(combined[1])

In [None]:
easy_indices = sorted_indices[0:50]
hard_indices = sorted_indices[-50:]

In [None]:
def process(model, dataloader, print_every=1000):
    """Run ``model`` over ``dataloader`` and return AVH hardness scores.

    Args:
        model: network whose forward pass returns ``(preds, embeddings)``;
            it is switched to eval mode. Batches are moved with ``.cuda()``.
        dataloader: yields ``(images, labels, indices)`` batches. Scores come
            back in iteration order, so use ``shuffle=False`` to keep them
            aligned with dataset positions.
        print_every: print the current batch's top-1/top-5 accuracy every this
            many batches; ``0`` or ``None`` disables printing.

    Returns:
        Whatever ``AVH()(model=..., embeddings=..., targets=...)`` returns for
        the concatenated embeddings and labels (indexed like a tensor by
        later cells).

    Raises:
        ValueError: if ``dataloader`` yields no batches.
    """
    embeddings_per_batch = []
    labels_per_batch = []
    model.eval()
    with torch.no_grad():
        for step, (images, labels, _batch_indices) in enumerate(tqdm(dataloader, disable=False)):
            images, labels = images.cuda(), labels.cuda()
            preds, embeddings = model(images)
            embeddings_per_batch.append(embeddings)
            labels_per_batch.append(labels)
            # Accuracy is only needed for the periodic printout.
            if print_every and step % print_every == 0:
                top1, top5 = accuracy(preds, labels, (1, 5))
                print("top1", top1, "top5", top5)

    if not embeddings_per_batch:
        raise ValueError("dataloader yielded no batches")

    avh = AVH()
    return avh(model=model, embeddings=torch.cat(embeddings_per_batch),
               targets=torch.cat(labels_per_batch))

In [None]:
policy_hardness_scores = []
for i in range(10):
    policy_hardness_scores.append(process(model, policy_dataloaders[i]))

In [None]:
len(policy_hardness_scores)

In [None]:
easy_hardness_scores = []
hard_hardness_scores = []
unaugmented_easy_hardness_scores = unaugmented_hardness[easy_indices]
unaugmented_hard_hardness_scores = unaugmented_hardness[hard_indices]
for hardness_scores in policy_hardness_scores:
    easy_hardness_scores.append(hardness_scores[easy_indices].cpu().numpy())
    hard_hardness_scores.append(hardness_scores[hard_indices].cpu().numpy())

In [None]:
easy_hardness_scores = np.array(easy_hardness_scores)
hard_hardness_scores = np.array(hard_hardness_scores)

In [None]:
easy_hardness_scores.mean(axis=1)

In [None]:
easy_hardness_scores.std(axis=1)

In [None]:
hard_hardness_scores.mean(axis=1)

In [None]:
hard_hardness_scores.std(axis=1)

In [None]:
policy_mean_easy = easy_hardness_scores.mean(axis=0)

In [None]:
policy_std_easy = easy_hardness_scores.std(axis=0)

In [None]:
policy_mean_hard = hard_hardness_scores.mean(axis=0)

In [None]:
policy_std_hard = hard_hardness_scores.std(axis=0)

In [None]:
# Per-image scores of the easy and hard groups: baseline score vs. mean and
# mean ± std across the augmentation policies. Both groups share the x-axis
# (rank within the group) although they are different images.
# Artists are labelled via label= instead of a positional legend list: with
# positional labels, matplotlib >= 3.5 orders fill_between bands among the
# lines, so the "" placeholders shifted names onto the wrong artists.
fig, ax = plt.subplots(figsize=(15, 7))
x_easy = np.arange(len(policy_mean_easy))
x_hard = np.arange(len(policy_mean_hard))

ax.plot(x_easy, unaugmented_easy_hardness_scores, color='green', label='unaugmented_easy')
ax.plot(x_hard, unaugmented_hard_hardness_scores, color='black', label='unaugmented_hard')

ax.plot(x_easy, policy_mean_easy, color='r', label='easy')
ax.plot(x_easy, policy_mean_easy - policy_std_easy, color='b', label='mean ± std over policies')
ax.plot(x_easy, policy_mean_easy + policy_std_easy, color='b')
ax.fill_between(x_easy, policy_mean_easy - policy_std_easy, policy_mean_easy + policy_std_easy,
                color='blue', alpha=0.1)

# Distinct colour so the two mean curves can be told apart in the legend.
ax.plot(x_hard, policy_mean_hard, color='darkorange', label='hard')
ax.plot(x_hard, policy_mean_hard - policy_std_hard, color='b')
ax.plot(x_hard, policy_mean_hard + policy_std_hard, color='b')
ax.fill_between(x_hard, policy_mean_hard - policy_std_hard, policy_mean_hard + policy_std_hard,
                color='blue', alpha=0.1)

# NOTE(review): fixed upper limit -- points above 0.2 are clipped from view.
ax.set_ylim(0, 0.2)
ax.set_xlabel('images')
ax.set_ylabel('avh_scores')
ax.legend()
plt.show();

In [None]:
range = (0, unaugmented_hard_hardness_scores.max()+0.05)

In [None]:
plt.hist(unaugmented_easy_hardness_scores, bins=50, align='mid', range=range)

In [None]:
plt.hist(unaugmented_hard_hardness_scores, bins=50, align='mid', range=range)

In [None]:
plt.hist(policy_mean_easy, bins=50, align='mid', range=range)

In [None]:
plt.hist(policy_mean_hard, bins=50, align='mid', range=range)

In [None]:
# Overlay all six distributions on one shared set of 50 bins. Each plt.hist
# call would otherwise choose its own bins, making bar widths incomparable.
# alpha keeps later series from hiding earlier ones.
# The std_* series are spreads across policies, not scores themselves; they
# share the axis only because they have the same units.
overlay_series = {
    "unaugmented_easy": unaugmented_easy_hardness_scores,
    "unaugmented_hard": unaugmented_hard_hardness_scores,
    "augmented_easy": policy_mean_easy,
    "augmented_hard": policy_mean_hard,
    "std_easy": policy_std_easy,
    "std_hard": policy_std_hard,
}
overlay_bins = np.histogram_bin_edges(np.concatenate(list(overlay_series.values())), bins=50)
for series_name, series_values in overlay_series.items():
    plt.hist(series_values, bins=overlay_bins, align='mid', alpha=0.5, label=series_name)
plt.ylabel('number of images')
plt.xlabel('avh')
plt.legend()
plt.show();

In [None]:
# Per-image change in score caused by augmentation (policy mean minus baseline);
# positive means the image scored higher (the "hard" direction) when augmented.
difference_easy = np.subtract(policy_mean_easy, unaugmented_easy_hardness_scores)
difference_hard = np.subtract(policy_mean_hard, unaugmented_hard_hardness_scores)

In [None]:
# Per-image score change for the 50 easy and 50 hard images.
fig, ax = plt.subplots()
image_positions = np.arange(50)
ax.plot(image_positions, difference_easy)
ax.plot(image_positions, difference_hard)
ax.legend(["difference_easy", "difference_hard"])
ax.set_xlabel('images')
ax.set_ylabel('(augmented_avh - unaugmented_avh)')
plt.show();

In [None]:
# Distribution of per-image score changes. The shared bin edges and
# transparency keep the two overlaid histograms comparable; previously each
# series picked its own bins and the second hid the first.
difference_bins = np.histogram_bin_edges(np.concatenate([difference_easy, difference_hard]), bins=50)
plt.hist(difference_easy, bins=difference_bins, alpha=0.5, label="difference_easy")
plt.hist(difference_hard, bins=difference_bins, alpha=0.5, label="difference_hard")
plt.legend()
plt.ylabel('number of images')
plt.xlabel('(augmented_avh - unaugmented_avh)')
plt.show();

In [None]:
print("Differnce_easy", "mean", difference_easy.mean(), "std", difference_easy.std())

In [None]:
print("Differnce_hard", "mean", difference_hard.mean(), "std", difference_hard.std())

In [140]:
# Another pass over basic_dataloader, collecting predictions, embeddings,
# labels and dataset indices for the augmented_hardness scores computed next.
# NOTE(review): this run's recorded per-batch top-1 (about 0.0-0.25) is near
# chance for 10 classes, while the earlier pass over the same loader printed
# about 0.56-0.88. The model or data changed in between (via cells not shown
# here); verify before interpreting augmented_hardness.
# The batch is bound to `images`, not `data`, so the loop does not overwrite
# (and the trailing `del` does not delete) any global `data`, such as a loaded
# checkpoint.
all_preds = []
all_embeddings = []
all_labels = []
all_indices = []
with torch.no_grad():
    model.eval()
    loader = tqdm(basic_dataloader, disable=False)
    for i, (images, label, index) in enumerate(loader):
        images, label = images.cuda(), label.cuda()
        preds, embeddings = model(images)

        all_preds.append(preds)
        all_embeddings.append(embeddings)
        all_labels.append(label)
        all_indices.append(index)
        # Accuracy of the current batch only, printed every 100 batches.
        top1, top5 = accuracy(preds, label, (1, 5))
        if i % 100 == 0:
            print("top1", top1, "top5", top5)
    del images, label, index, preds, embeddings

  0%|          | 3/3125 [00:00<02:09, 24.17it/s]

top1 tensor(0.1250, device='cuda:0') top5 tensor(0.6875, device='cuda:0')


  3%|▎         | 104/3125 [00:05<02:36, 19.31it/s]

top1 tensor(0.0625, device='cuda:0') top5 tensor(0.5625, device='cuda:0')


  7%|▋         | 204/3125 [00:10<02:31, 19.33it/s]

top1 tensor(0., device='cuda:0') top5 tensor(0.2500, device='cuda:0')


 10%|▉         | 304/3125 [00:15<02:26, 19.32it/s]

top1 tensor(0.1875, device='cuda:0') top5 tensor(0.4375, device='cuda:0')


 13%|█▎        | 404/3125 [00:20<02:20, 19.31it/s]

top1 tensor(0.0625, device='cuda:0') top5 tensor(0.5625, device='cuda:0')


 16%|█▌        | 504/3125 [00:25<02:15, 19.31it/s]

top1 tensor(0.0625, device='cuda:0') top5 tensor(0.1875, device='cuda:0')


 19%|█▉        | 604/3125 [00:30<02:10, 19.31it/s]

top1 tensor(0.0625, device='cuda:0') top5 tensor(0.6250, device='cuda:0')


 23%|██▎       | 704/3125 [00:35<02:05, 19.31it/s]

top1 tensor(0.1250, device='cuda:0') top5 tensor(0.3125, device='cuda:0')


 26%|██▌       | 804/3125 [00:40<02:00, 19.30it/s]

top1 tensor(0.1250, device='cuda:0') top5 tensor(0.4375, device='cuda:0')


 29%|██▉       | 904/3125 [00:46<01:55, 19.30it/s]

top1 tensor(0.1875, device='cuda:0') top5 tensor(0.6250, device='cuda:0')


 32%|███▏      | 1004/3125 [00:51<01:49, 19.31it/s]

top1 tensor(0.1250, device='cuda:0') top5 tensor(0.5000, device='cuda:0')


 35%|███▌      | 1104/3125 [00:56<01:44, 19.31it/s]

top1 tensor(0.0625, device='cuda:0') top5 tensor(0.5000, device='cuda:0')


 39%|███▊      | 1204/3125 [01:01<01:39, 19.31it/s]

top1 tensor(0.0625, device='cuda:0') top5 tensor(0.8125, device='cuda:0')


 42%|████▏     | 1304/3125 [01:06<01:34, 19.31it/s]

top1 tensor(0.1250, device='cuda:0') top5 tensor(0.5625, device='cuda:0')


 45%|████▍     | 1404/3125 [01:11<01:29, 19.31it/s]

top1 tensor(0.0625, device='cuda:0') top5 tensor(0.3750, device='cuda:0')


 48%|████▊     | 1504/3125 [01:16<01:23, 19.30it/s]

top1 tensor(0.0625, device='cuda:0') top5 tensor(0.6250, device='cuda:0')


 51%|█████▏    | 1604/3125 [01:21<01:18, 19.31it/s]

top1 tensor(0.1250, device='cuda:0') top5 tensor(0.6250, device='cuda:0')


 55%|█████▍    | 1704/3125 [01:26<01:13, 19.29it/s]

top1 tensor(0.1250, device='cuda:0') top5 tensor(0.5625, device='cuda:0')


 58%|█████▊    | 1804/3125 [01:31<01:08, 19.31it/s]

top1 tensor(0.1250, device='cuda:0') top5 tensor(0.5625, device='cuda:0')


 61%|██████    | 1904/3125 [01:36<01:03, 19.31it/s]

top1 tensor(0.0625, device='cuda:0') top5 tensor(0.5000, device='cuda:0')


 64%|██████▍   | 2004/3125 [01:42<00:58, 19.30it/s]

top1 tensor(0.2500, device='cuda:0') top5 tensor(0.6875, device='cuda:0')


 67%|██████▋   | 2104/3125 [01:47<00:52, 19.30it/s]

top1 tensor(0.1250, device='cuda:0') top5 tensor(0.5625, device='cuda:0')


 71%|███████   | 2204/3125 [01:52<00:47, 19.31it/s]

top1 tensor(0.0625, device='cuda:0') top5 tensor(0.3750, device='cuda:0')


 74%|███████▎  | 2304/3125 [01:57<00:42, 19.30it/s]

top1 tensor(0.0625, device='cuda:0') top5 tensor(0.3750, device='cuda:0')


 77%|███████▋  | 2404/3125 [02:02<00:37, 19.30it/s]

top1 tensor(0.1250, device='cuda:0') top5 tensor(0.5000, device='cuda:0')


 80%|████████  | 2504/3125 [02:07<00:32, 19.30it/s]

top1 tensor(0.0625, device='cuda:0') top5 tensor(0.6250, device='cuda:0')


 83%|████████▎ | 2604/3125 [02:12<00:26, 19.30it/s]

top1 tensor(0.1250, device='cuda:0') top5 tensor(0.6250, device='cuda:0')


 87%|████████▋ | 2704/3125 [02:17<00:21, 19.30it/s]

top1 tensor(0.1250, device='cuda:0') top5 tensor(0.6250, device='cuda:0')


 90%|████████▉ | 2804/3125 [02:22<00:16, 19.30it/s]

top1 tensor(0.1875, device='cuda:0') top5 tensor(0.3750, device='cuda:0')


 93%|█████████▎| 2904/3125 [02:27<00:11, 19.31it/s]

top1 tensor(0., device='cuda:0') top5 tensor(0.5000, device='cuda:0')


 96%|█████████▌| 3004/3125 [02:33<00:06, 19.29it/s]

top1 tensor(0.1250, device='cuda:0') top5 tensor(0.6250, device='cuda:0')


 99%|█████████▉| 3104/3125 [02:38<00:01, 19.29it/s]

top1 tensor(0.1250, device='cuda:0') top5 tensor(0.6250, device='cuda:0')


100%|██████████| 3125/3125 [02:39<00:00, 19.63it/s]


In [143]:
augmented_hardness = avh(model=model, embeddings=torch.cat(all_embeddings), 
                           targets=torch.cat(all_labels))

In [144]:
augmented_hardness = unaugmented_hardness.cpu().numpy()
all_indices = torch.cat(all_indices).cpu().numpy()