In [1]:
import copy
import multiprocessing as mp
import os
import random
import sys

# Make the local FastAutoAugment checkout importable.
sys.path.append('/home/ubuntu/fast-autoaugment')

import torch
from torch.nn.parallel.data_parallel import DataParallel
from torch.nn.parallel import DistributedDataParallel
from theconf import Config as C
from torchvision.transforms import transforms
from tqdm import tqdm
import matplotlib.pyplot as plt

from FastAutoAugment.networks import get_model
from FastAutoAugment.datasets import CIFAR10_mod
from FastAutoAugment.metrics import accuracy
from FastAutoAugment.hardness_measures import AVH
# Star import kept last so any names it re-exports shadow earlier ones exactly as before.
from FastAutoAugment.augmentations import *

In [2]:
# Load the experiment config (theconf); later cells read it back through C.get().
CONF_PATH = '/home/ubuntu/fast-autoaugment/confs/test.yaml'
_ = C(CONF_PATH)

In [3]:
# Build the network described by the config's 'model' section.
# The 10 is presumably the number of output classes (CIFAR-10) — confirm against get_model.
model_conf = C.get()['model']
model = get_model(model_conf, 10, local_rank=-1)

In [4]:
# Load trained weights: point save_path at the checkpoint of the run to analyse.
save_path = '/efs-cotton/outputs/fast-autoaugment/confs/hardnessaware/wresnet28x10_rcifar-AA-rerun3/test.pth'
# map_location='cpu' makes loading independent of the GPU the checkpoint was saved from;
# load_state_dict copies the tensors onto the model's own parameters/device.
# NOTE: torch.load unpickles the file — only use it on checkpoints you trust.
data = torch.load(save_path, map_location='cpu')

if 'epoch' not in data:
    # Bare state_dict: load it as-is.
    model.load_state_dict(data)
else:
    # Training checkpoint: weights live under 'model' or 'state_dict'.
    state_key = 'model' if 'model' in data else 'state_dict'
    prefix = 'module.'
    # Match the DataParallel 'module.' prefix to how `model` is wrapped. startswith() is
    # used because a substring test/replace would also rewrite keys that merely contain
    # 'module.' (e.g. 'layer.submodule.weight' -> 'layer.subweight').
    if not isinstance(model, (DataParallel, DistributedDataParallel)):
        model.load_state_dict({k[len(prefix):] if k.startswith(prefix) else k: v
                               for k, v in data[state_key].items()})
    else:
        model.load_state_dict({k if k.startswith(prefix) else prefix + k: v
                               for k, v in data[state_key].items()})

In [5]:
# Epoch recorded in the loaded checkpoint (KeyError for a bare state_dict without 'epoch').
data['epoch']

195

In [6]:
# Train / valid / test metrics stored in the checkpoint's 'log' entry.
data['log']

{'train': {'loss': 0.43232580375671387,
  'top1': 0.96325,
  'top5': 0.99775,
  'lr': 0.0006155829702431171},
 'valid': {'loss': 0.0, 'top1': 0.0, 'top5': 0.0},
 'test': {'loss': 0.5973012451648713, 'top1': 0.8451, 'top5': 0.9913}}

In [7]:
# Augmentation policies to compare. Each value is a list of
# (op_name, apply_probability, level) triples consumed by Augmentation2.
policy = {
    "Rotate30+Solarize200": [("Rotate", 1, 30), ("Solarize", 1, 200)],
    "Rotate30+Solarize100": [("Rotate", 1, 30), ("Solarize", 1, 100)],
    "Rotate30+Solarize50": [("Rotate", 1, 30), ("Solarize", 1, 50)],
    "Rotate30+Contrast0.1": [("Rotate", 1, 30), ("Contrast", 1, 0.1)],
    "Contrast0.1": [("Contrast", 1, 0.1)],
}

In [8]:
class Augmentation2(object):
    """Apply a fixed list of augmentation ops to an image.

    ``policies`` is a sequence of ``(op_name, probability, level)`` triples. Each op is
    applied independently with its probability (one ``random.random()`` draw per op),
    in list order, and ``level`` is passed unchanged to ``apply_augment_direct``.
    """

    def __init__(self, policies):
        self.policies = policies

    def __call__(self, img, hardness_score=None):
        # hardness_score is accepted but not used by this class.
        out = img
        for op_name, prob, magnitude in self.policies:
            skipped = random.random() > prob
            if not skipped:
                out = apply_augment_direct(out, op_name, magnitude)
        return out

In [9]:
# CIFAR per-channel normalisation statistics and the local dataset root.
_CIFAR_MEAN = (0.4914, 0.4822, 0.4465)
_CIFAR_STD = (0.2023, 0.1994, 0.2010)
dataroot = "data"

In [10]:
# Un-augmented view of the training set: tensor conversion + normalisation only.
no_transform = transforms.Compose(
    [transforms.ToTensor(), transforms.Normalize(_CIFAR_MEAN, _CIFAR_STD)]
)
basic_dataset = CIFAR10_mod(
    root=dataroot, train=True, download=True, transform=no_transform
)

Files already downloaded and verified


In [11]:
# Standard CIFAR training pipeline (random crop with padding + horizontal flip).
# Each policy's Augmentation2 is later prepended to a deep copy of this pipeline.
baseline_steps = [
    transforms.RandomCrop(32, padding=4),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize(_CIFAR_MEAN, _CIFAR_STD),
]
baseline_transform = transforms.Compose(baseline_steps)

In [12]:
# One pipeline per policy: a fresh copy of the baseline with the policy's
# Augmentation2 inserted first (so it runs before RandomCrop / ToTensor).
policy_transforms = {}
for policy_name, policy_ops in policy.items():
    pipeline = copy.deepcopy(baseline_transform)
    pipeline.transforms.insert(0, Augmentation2(policy_ops))
    policy_transforms[policy_name] = pipeline

In [13]:
# Debug aid (disabled): inspect the ops prepended to one policy's pipeline, e.g.
# policy_transforms['Contrast0.1'].transforms[0].policies

In [14]:
# One CIFAR10_mod training set per policy, differing only in the transform pipeline.
policy_datasets = {
    policy_name: CIFAR10_mod(root=dataroot, train=True, download=True, transform=pipeline)
    for policy_name, pipeline in policy_transforms.items()
}

Files already downloaded and verified
Files already downloaded and verified
Files already downloaded and verified
Files already downloaded and verified
Files already downloaded and verified


In [15]:
# Shared loader settings. shuffle=False keeps batches in dataset order: later cells
# index the un-augmented hardness array with dataset indices, which relies on it.
# num_workers=0 keeps all augmentation randomness in the main process.
_loader_kwargs = dict(batch_size=16, shuffle=False, num_workers=0, drop_last=False)
basic_dataloader = torch.utils.data.DataLoader(basic_dataset, **_loader_kwargs)
policy_dataloaders = {
    policy_name: torch.utils.data.DataLoader(dataset, **_loader_kwargs)
    for policy_name, dataset in policy_datasets.items()
}

In [16]:
# Forward the un-augmented training set once, keeping every batch's outputs.
# The batch variable is `images`, not `data`, so the checkpoint dict loaded earlier
# as `data` is neither overwritten nor deleted by this cell.
all_preds = []
all_embeddings = []
all_labels = []
all_indices = []
with torch.no_grad():
    model.eval()
    loader = tqdm(basic_dataloader, disable=False)
    for i, (images, label, index) in enumerate(loader):
        images, label = images.cuda(), label.cuda()
        preds, embeddings = model(images)

        all_preds.append(preds)
        all_embeddings.append(embeddings)
        all_labels.append(label)
        all_indices.append(index)
        if i % 1000 == 0:
            # Accuracy is only needed for this periodic progress print.
            top1, top5 = accuracy(preds, label, (1, 5))
            print("top1", top1, "top5", top5)
    # Drop the per-batch names; the all_* lists still reference every batch's outputs.
    del images, label, index, preds, embeddings

  0%|          | 4/3125 [00:01<55:53,  1.07s/it]  

top1 tensor(1., device='cuda:0') top5 tensor(1., device='cuda:0')


 32%|███▏      | 1004/3125 [00:52<01:49, 19.44it/s]

top1 tensor(0.8125, device='cuda:0') top5 tensor(1., device='cuda:0')


 64%|██████▍   | 2004/3125 [01:43<00:57, 19.44it/s]

top1 tensor(0.9375, device='cuda:0') top5 tensor(1., device='cuda:0')


 96%|█████████▌| 3004/3125 [02:34<00:06, 19.43it/s]

top1 tensor(0.8750, device='cuda:0') top5 tensor(1., device='cuda:0')


100%|██████████| 3125/3125 [02:40<00:00, 19.49it/s]


In [17]:
# AVH hardness measure (FastAutoAugment.hardness_measures); called with the model,
# the concatenated embeddings and the targets.
avh = AVH()

In [18]:
# AVH hardness of the un-augmented training set, computed from the embeddings and
# labels gathered over all batches (concatenated in loader order).
unaugmented_hardness = avh(model=model, embeddings=torch.cat(all_embeddings), 
                           targets=torch.cat(all_labels))

In [19]:
# Move results to host numpy arrays. Guarded so re-running the cell is a no-op:
# unguarded, a second run would call .cpu() on an ndarray / torch.cat on an ndarray.
if torch.is_tensor(unaugmented_hardness):
    unaugmented_hardness = unaugmented_hardness.cpu().numpy()
if isinstance(all_indices, list):
    all_indices = torch.cat(all_indices).cpu().numpy()

In [20]:
# Sort (hardness, dataset index) pairs ascending by hardness (ties broken by index),
# then transpose into (sorted hardness values, matching indices).
ranked_pairs = sorted(zip(unaugmented_hardness, all_indices))
combined = list(zip(*ranked_pairs))

In [21]:
# Dataset indices ordered from lowest to highest un-augmented hardness score.
sorted_indices = [*combined[1]]

In [22]:
# Number of lowest- / highest-scoring samples (by un-augmented AVH) to track.
NUM_EXTREME = 200
easy_indices = sorted_indices[:NUM_EXTREME]
# Explicit start index: `sorted_indices[-NUM_EXTREME:]` would return the whole
# list when NUM_EXTREME == 0.
hard_indices = sorted_indices[max(len(sorted_indices) - NUM_EXTREME, 0):]

In [23]:
def process(model, dataloader, log_every=1000):
    """Run ``model`` over ``dataloader`` and return AVH hardness scores per sample.

    Args:
        model: network whose forward returns ``(preds, embeddings)``; it is put in
            eval mode and run under ``torch.no_grad()``. Batches are moved with
            ``.cuda()``, so a CUDA device is required.
        dataloader: iterable of ``(images, labels, indices)`` batches.
        log_every: print the current batch's top-1/top-5 accuracy every this many
            batches; ``0``/``None`` disables the print.

    Returns:
        The scores returned by ``AVH()``, reordered by ascending dataset index, so
        entry ``i`` belongs to the sample with the ``i``-th smallest index. For an
        unshuffled loader over indices ``0..N-1`` this is the loader order unchanged.

    Raises:
        ValueError: if ``dataloader`` yields no batches.
    """
    all_embeddings = []
    all_labels = []
    all_indices = []
    with torch.no_grad():
        model.eval()
        loader = tqdm(dataloader, disable=False)
        for i, (images, label, index) in enumerate(loader):
            images, label = images.cuda(), label.cuda()
            preds, embeddings = model(images)

            all_embeddings.append(embeddings)
            all_labels.append(label)
            all_indices.append(index)
            if log_every and i % log_every == 0:
                # Accuracy is only needed for the progress print.
                top1, top5 = accuracy(preds, label, (1, 5))
                print("top1", top1, "top5", top5)

    if not all_embeddings:
        raise ValueError("dataloader yielded no batches")

    avh = AVH()
    hardness_scores = avh(model=model, embeddings=torch.cat(all_embeddings),
                          targets=torch.cat(all_labels))
    # Align scores with dataset indices so callers may index them by dataset index
    # even if the loader was shuffled (identity permutation when it was not).
    order = torch.argsort(torch.cat(all_indices))
    return hardness_scores[order.to(hardness_scores.device)]

In [24]:
# Hardness of every training sample under each augmentation policy.
# The pipelines are stochastic (RandomCrop/RandomHorizontalFlip, plus Augmentation2's
# per-op probability draws), so re-seed before each policy: results become reproducible
# and independent of policy order. torchvision uses torch's RNG in recent versions
# (Python's `random` in older ones), so both are seeded; num_workers=0 means all of
# this randomness runs in the main process.
POLICY_SEED = 0
policy_hardness_scores = {}
for policy_name, policy_loader in policy_dataloaders.items():
    random.seed(POLICY_SEED)
    torch.manual_seed(POLICY_SEED)
    policy_hardness_scores[policy_name] = process(model, policy_loader)

  0%|          | 3/3125 [00:00<02:07, 24.49it/s]

top1 tensor(0.6250, device='cuda:0') top5 tensor(0.9375, device='cuda:0')


 32%|███▏      | 1004/3125 [00:51<01:50, 19.24it/s]

top1 tensor(0.5625, device='cuda:0') top5 tensor(0.8750, device='cuda:0')


 64%|██████▍   | 2004/3125 [01:42<00:57, 19.33it/s]

top1 tensor(0.4375, device='cuda:0') top5 tensor(0.9375, device='cuda:0')


 96%|█████████▌| 3004/3125 [02:32<00:06, 19.32it/s]

top1 tensor(0.5625, device='cuda:0') top5 tensor(0.8750, device='cuda:0')


100%|██████████| 3125/3125 [02:39<00:00, 19.64it/s]
  0%|          | 4/3125 [00:00<02:48, 18.55it/s]

top1 tensor(0.5625, device='cuda:0') top5 tensor(1., device='cuda:0')


 32%|███▏      | 1004/3125 [00:51<01:49, 19.33it/s]

top1 tensor(0.6250, device='cuda:0') top5 tensor(0.8750, device='cuda:0')


 64%|██████▍   | 2004/3125 [01:42<00:58, 19.33it/s]

top1 tensor(0.4375, device='cuda:0') top5 tensor(0.7500, device='cuda:0')


 96%|█████████▌| 3004/3125 [02:32<00:06, 19.32it/s]

top1 tensor(0.5625, device='cuda:0') top5 tensor(0.8750, device='cuda:0')


100%|██████████| 3125/3125 [02:39<00:00, 19.63it/s]
  0%|          | 4/3125 [00:00<05:46,  9.00it/s]

top1 tensor(0.3125, device='cuda:0') top5 tensor(1., device='cuda:0')


 32%|███▏      | 1004/3125 [00:51<01:49, 19.31it/s]

top1 tensor(0.5000, device='cuda:0') top5 tensor(0.9375, device='cuda:0')


 64%|██████▍   | 2004/3125 [01:42<00:58, 19.30it/s]

top1 tensor(0.3750, device='cuda:0') top5 tensor(0.6875, device='cuda:0')


 96%|█████████▌| 3004/3125 [02:33<00:06, 19.31it/s]

top1 tensor(0.4375, device='cuda:0') top5 tensor(0.9375, device='cuda:0')


100%|██████████| 3125/3125 [02:39<00:00, 19.62it/s]
  0%|          | 4/3125 [00:00<05:47,  8.98it/s]

top1 tensor(0.0625, device='cuda:0') top5 tensor(0.7500, device='cuda:0')


 32%|███▏      | 1004/3125 [00:51<01:50, 19.26it/s]

top1 tensor(0.3750, device='cuda:0') top5 tensor(0.8125, device='cuda:0')


 64%|██████▍   | 2004/3125 [01:42<00:58, 19.25it/s]

top1 tensor(0.3750, device='cuda:0') top5 tensor(0.7500, device='cuda:0')


 96%|█████████▌| 3004/3125 [02:33<00:06, 19.24it/s]

top1 tensor(0.2500, device='cuda:0') top5 tensor(0.7500, device='cuda:0')


100%|██████████| 3125/3125 [02:39<00:00, 19.62it/s]
  0%|          | 4/3125 [00:00<05:46,  9.02it/s]

top1 tensor(0.9375, device='cuda:0') top5 tensor(1., device='cuda:0')


 32%|███▏      | 1004/3125 [00:51<01:50, 19.27it/s]

top1 tensor(0.6250, device='cuda:0') top5 tensor(0.9375, device='cuda:0')


 64%|██████▍   | 2004/3125 [01:42<00:58, 19.28it/s]

top1 tensor(0.6875, device='cuda:0') top5 tensor(0.9375, device='cuda:0')


 96%|█████████▌| 3004/3125 [02:33<00:06, 19.29it/s]

top1 tensor(0.6250, device='cuda:0') top5 tensor(0.9375, device='cuda:0')


100%|██████████| 3125/3125 [02:39<00:00, 19.61it/s]


In [25]:
# Sanity check: one hardness result per augmentation policy.
len(policy_hardness_scores)

5

In [26]:
# Hardness of the selected easy / hard samples, first without augmentation, then under
# each policy. The un-augmented array is already numpy; policy scores are tensors.
easy_hardness_scores = {"unaugmented": unaugmented_hardness[easy_indices]}
hard_hardness_scores = {"unaugmented": unaugmented_hardness[hard_indices]}
for policy_name, scores in policy_hardness_scores.items():
    easy_hardness_scores[policy_name] = scores[easy_indices].cpu().numpy()
    hard_hardness_scores[policy_name] = scores[hard_indices].cpu().numpy()

In [27]:
# Results go in a `hardness_scores/` directory next to the checkpoint. Built from the
# checkpoint's directory: string-replacing 'test.pth' would silently return the
# checkpoint path itself for any other checkpoint filename.
hardness_save_dir = os.path.join(os.path.dirname(save_path), 'hardness_scores')

In [28]:
# Create the output directory (`os` is imported in the first cell). exist_ok=True makes
# re-running a no-op and avoids the exists()/makedirs() race; it still raises if the
# path exists as a regular file.
os.makedirs(hardness_save_dir, exist_ok=True)

In [29]:
# File name for this experiment's hardness results.
HARDNESS_FILENAME = "exp3.pt"
hardness_save_path = os.path.join(hardness_save_dir, HARDNESS_FILENAME)

In [30]:
# Persist the policies alongside the easy/hard hardness scores they produced.
hardness_results = {
    'policy': policy,
    'easy_hardness_scores': easy_hardness_scores,
    "hard_hardness_scores": hard_hardness_scores,
}
torch.save(hardness_results, hardness_save_path)

In [31]:
# Location the hardness results were written to.
hardness_save_path

'/efs-cotton/outputs/fast-autoaugment/confs/hardnessaware/wresnet28x10_rcifar-AA-rerun3/hardness_scores/exp3.pt'