In [1]:
import sys
sys.path.append('/home/ubuntu/fast-autoaugment')
import torch
from torch.nn.parallel.data_parallel import DataParallel
from torch.nn.parallel import DistributedDataParallel
from FastAutoAugment.networks import get_model
from theconf import Config as C
import random
import copy
from torchvision.transforms import transforms
import multiprocessing as mp
from tqdm import tqdm
import matplotlib.pyplot as plt

from FastAutoAugment.datasets import CIFAR10_mod 
from FastAutoAugment.metrics import accuracy
from FastAutoAugment.hardness_measures import AVH
from FastAutoAugment.augmentations import *

In [2]:
_ = C('/home/ubuntu/fast-autoaugment/confs/test.yaml')

In [3]:
model = get_model(C.get()['model'], 10, local_rank=-1)

In [52]:
torch.cuda.is_available()

True

In [4]:
## If trained model is needed. Specify the save_path of the required run.
save_path = '/efs-cotton/outputs/fast-autoaugment/confs/hardnessaware/wresnet28x10_rcifar-AA-rerun3/test.pth'
data = torch.load(save_path)
key = 'model' if 'model' in data else 'state_dict'

if 'epoch' not in data:
    model.load_state_dict(data)
else:
    if not isinstance(model, (DataParallel, DistributedDataParallel)):
        model.load_state_dict({k.replace('module.', ''): v for k, v in data[key].items()})
    else:
        model.load_state_dict({k if 'module.' in k else 'module.'+k: v for k, v in data[key].items()})

In [50]:
data.keys()

dict_keys(['epoch', 'log', 'optimizer', 'model', 'ema'])

In [6]:
data['log']

{'train': {'loss': 0.43232580375671387,
  'top1': 0.96325,
  'top5': 0.99775,
  'lr': 0.0006155829702431171},
 'valid': {'loss': 0.0, 'top1': 0.0, 'top5': 0.0},
 'test': {'loss': 0.5973012451648713, 'top1': 0.8451, 'top5': 0.9913}}

In [7]:
policy = {"Rotate30": [("Rotate", 1, 30)],
          "Rotate60": [("Rotate", 1, 60)],
          "Rotate90": [("Rotate", 1, 90)],
          "Solarize0": [("Solarize", 1, 0)],
          "Posterize0": [("Posterize", 1, 0)],
          "Cutout20": [("CutoutAbs", 1, 20)],
          "Contrast0": [("Contrast", 1, 0)],
          "Brightness0": [("Brightness", 1, 0)],
#           "Flip": [("Flip", 1, 1)],
          }

In [8]:
class Augmentation2(object):
    def __init__(self, policies):
        self.policies = policies

    def __call__(self, img, hardness_score=None):
        for name, pr, level in self.policies:
            if random.random() > pr:
                continue
#             if level == 'random':
#                 level = random.choice(range(10))*0.1
            img = apply_augment_direct(img, name, level)
        return img

In [9]:
_CIFAR_MEAN, _CIFAR_STD = (0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)
dataroot = "data"

In [10]:
no_transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize(_CIFAR_MEAN, _CIFAR_STD),
        ])
basic_dataset = CIFAR10_mod(root=dataroot, train=True, download=True, transform=no_transform)

Files already downloaded and verified


In [11]:
baseline_transform = transforms.Compose([
            transforms.RandomCrop(32, padding=4),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            transforms.Normalize(_CIFAR_MEAN, _CIFAR_STD),
        ])

In [14]:
dataset1 = CIFAR10_mod(root=dataroot, train=True, download=True, transform=no_transform)
dataset2 = CIFAR10_mod(root=dataroot, train=True, download=True, transform=baseline_transform)

Files already downloaded and verified
Files already downloaded and verified


In [47]:
seed = 0
random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
np.random.seed(seed)
torch.cuda.manual_seed_all(seed)
dataloader1 = torch.utils.data.DataLoader(dataset1, batch_size=16, shuffle=True, 
                                          num_workers=0, drop_last=False)
random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
np.random.seed(seed)
torch.cuda.manual_seed_all(seed)

dataloader2 = torch.utils.data.DataLoader(dataset2, batch_size=16, shuffle=True, 
                                          num_workers=0, drop_last=False)

In [49]:
random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
np.random.seed(seed)
torch.cuda.manual_seed_all(seed)


iterator2 = iter(dataloader2)
random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
np.random.seed(seed)
torch.cuda.manual_seed_all(seed)


for data1, label1, index1 in iter(dataloader1):
    data2, label2, index2 = next(iterator2)
    print(index1, index2)
    index1.sort()
    

tensor([14933, 16284,  2523, 19485, 42663, 32433,  1879, 10687, 38029,  4091,
        48716, 32603,  5342, 43725,  9284,  1303]) tensor([ 3694, 22570, 44315,  2295, 29628, 24748,  6750, 35446, 21881, 27637,
         1947, 43187, 28544, 24281, 46584, 15590])
tensor([45962,  3815,  7738, 14167, 45391,  2462,  6798, 24229, 16397, 27344,
         6305, 38426, 27173, 42602, 39503, 16572]) tensor([  490, 13200,  7095, 32602, 11899, 49360, 33493,  9698, 11011, 25759,
        21634, 37415, 18783,  5913, 32477, 39078])
tensor([15171, 15127, 32048, 24274, 16079, 13740, 30604, 18928,  8016,  8544,
         5121,  7921, 49947, 41955,  2602,  5838]) tensor([42042, 38298, 48471, 16757, 15367, 25833, 46901, 10492, 36697, 45279,
        24308, 45630, 36064, 25825, 26684, 14179])
tensor([30896,  8174, 14023,  7049, 46123, 36682, 44753, 20646,  6436, 19880,
        27315, 36801,  3583,  5593, 33159,   138]) tensor([ 3278, 47475, 45118, 34408, 47733, 25580,  2442,  3794, 26863, 42550,
        29249, 1244

In [49]:
policy_transforms = {}
for key, value in policy.items():
    transform = copy.deepcopy(baseline_transform)
    transform.transforms.insert(0, Augmentation2(policy[key]))
    policy_transforms[key] = transform

In [50]:
# policy_transforms['Rotatem30'].transforms[0].policies

In [51]:
policy_datasets = {}
for key, value in policy_transforms.items():
    policy_datasets[key] = CIFAR10_mod(root=dataroot, train=True, download=True, transform=policy_transforms[key])

Files already downloaded and verified
Files already downloaded and verified
Files already downloaded and verified
Files already downloaded and verified
Files already downloaded and verified
Files already downloaded and verified
Files already downloaded and verified
Files already downloaded and verified


In [52]:
basic_dataloader = torch.utils.data.DataLoader(basic_dataset, batch_size=16, shuffle=False, num_workers=0, drop_last=False)
policy_dataloaders = {}
for key, value in policy_datasets.items():
    policy_dataloaders[key] = torch.utils.data.DataLoader(
        policy_datasets[key], batch_size=16, shuffle=False, 
        num_workers=0, drop_last=False)

In [53]:
all_preds = []
all_embeddings = []
all_labels = []
all_indices = []
with torch.no_grad():
    model.eval()
    loader = tqdm(basic_dataloader, disable=False)
    for i, (data, label, index) in enumerate(loader):
        data, label = data.cuda(), label.cuda()
        preds, embeddings = model(data)

        all_preds.append(preds)
        all_embeddings.append(embeddings)
        all_labels.append(label)
        all_indices.append(index)
        top1, top5 = accuracy(preds, label, (1, 5))
        if i%1000==0:
            print("top1", top1, "top5", top5)
    del data, label, index, preds, embeddings

  0%|          | 3/3125 [00:00<02:56, 17.72it/s]

top1 tensor(1., device='cuda:0') top5 tensor(1., device='cuda:0')


 32%|███▏      | 1004/3125 [00:51<01:49, 19.30it/s]

top1 tensor(0.8125, device='cuda:0') top5 tensor(1., device='cuda:0')


 64%|██████▍   | 2004/3125 [01:42<00:58, 19.31it/s]

top1 tensor(0.9375, device='cuda:0') top5 tensor(1., device='cuda:0')


 96%|█████████▌| 3004/3125 [02:33<00:06, 19.31it/s]

top1 tensor(0.8750, device='cuda:0') top5 tensor(1., device='cuda:0')


100%|██████████| 3125/3125 [02:39<00:00, 19.61it/s]


In [54]:
avh = AVH()

In [55]:
unaugmented_hardness = avh(model=model, embeddings=torch.cat(all_embeddings), 
                           targets=torch.cat(all_labels))

In [56]:
unaugmented_hardness = unaugmented_hardness.cpu().numpy()
all_indices = torch.cat(all_indices).cpu().numpy()

In [57]:
combined = list(zip(*sorted(zip(unaugmented_hardness, all_indices))))

In [58]:
sorted_indices = list(combined[1])

In [59]:
easy_indices = sorted_indices[0:200]
hard_indices = sorted_indices[-200:]
mid_indices = sorted_indices[24000:26000]

In [60]:
def process(model, dataloader):
    all_preds = []
    all_embeddings = []
    all_labels = []
    all_indices = []
    with torch.no_grad():
        model.eval()
        loader = tqdm(dataloader, disable=False)
        for i, (data, label, index) in enumerate(loader):
            data, label = data.cuda(), label.cuda()
            preds, embeddings = model(data)

            all_preds.append(preds)
            all_embeddings.append(embeddings)
            all_labels.append(label)
            all_indices.append(index)
            top1, top5 = accuracy(preds, label, (1, 5))
            if i%1000==0:
                print("top1", top1, "top5", top5)
        del data, label, index, preds, embeddings
    
    avh = AVH()
    hardness_scores = avh(model=model, embeddings=torch.cat(all_embeddings), 
                           targets=torch.cat(all_labels))
    return hardness_scores

In [61]:
policy_hardness_scores = {}
for key, value in policy_dataloaders.items():
    policy_hardness_scores[key] = process(model, policy_dataloaders[key])

  0%|          | 3/3125 [00:00<02:07, 24.52it/s]

top1 tensor(0.6250, device='cuda:0') top5 tensor(1., device='cuda:0')


 32%|███▏      | 1004/3125 [00:51<01:50, 19.23it/s]

top1 tensor(0.6875, device='cuda:0') top5 tensor(0.8750, device='cuda:0')


 64%|██████▍   | 2004/3125 [01:42<00:58, 19.30it/s]

top1 tensor(0.5625, device='cuda:0') top5 tensor(0.8750, device='cuda:0')


 96%|█████████▌| 3004/3125 [02:33<00:06, 19.32it/s]

top1 tensor(0.6250, device='cuda:0') top5 tensor(0.9375, device='cuda:0')


100%|██████████| 3125/3125 [02:39<00:00, 19.62it/s]
  0%|          | 4/3125 [00:00<05:45,  9.03it/s]

top1 tensor(0.2500, device='cuda:0') top5 tensor(0.9375, device='cuda:0')


 32%|███▏      | 1004/3125 [00:51<01:49, 19.30it/s]

top1 tensor(0.2500, device='cuda:0') top5 tensor(1., device='cuda:0')


 64%|██████▍   | 2004/3125 [01:42<00:58, 19.32it/s]

top1 tensor(0.3750, device='cuda:0') top5 tensor(0.8750, device='cuda:0')


 96%|█████████▌| 3004/3125 [02:33<00:06, 19.31it/s]

top1 tensor(0.1875, device='cuda:0') top5 tensor(0.8750, device='cuda:0')


100%|██████████| 3125/3125 [02:39<00:00, 19.60it/s]
  0%|          | 4/3125 [00:00<05:45,  9.04it/s]

top1 tensor(0.2500, device='cuda:0') top5 tensor(0.8125, device='cuda:0')


 32%|███▏      | 1004/3125 [00:51<01:49, 19.34it/s]

top1 tensor(0.6250, device='cuda:0') top5 tensor(0.9375, device='cuda:0')


 64%|██████▍   | 2004/3125 [01:42<00:58, 19.33it/s]

top1 tensor(0.1875, device='cuda:0') top5 tensor(0.6875, device='cuda:0')


 96%|█████████▌| 3004/3125 [02:33<00:06, 19.33it/s]

top1 tensor(0.4375, device='cuda:0') top5 tensor(0.9375, device='cuda:0')


100%|██████████| 3125/3125 [02:39<00:00, 19.61it/s]
  0%|          | 4/3125 [00:00<05:47,  8.98it/s]

top1 tensor(0.8750, device='cuda:0') top5 tensor(1., device='cuda:0')


 32%|███▏      | 1004/3125 [00:51<01:49, 19.33it/s]

top1 tensor(0.7500, device='cuda:0') top5 tensor(0.9375, device='cuda:0')


 64%|██████▍   | 2004/3125 [01:42<00:58, 19.33it/s]

top1 tensor(0.7500, device='cuda:0') top5 tensor(1., device='cuda:0')


 96%|█████████▌| 3004/3125 [02:33<00:06, 19.32it/s]

top1 tensor(0.8125, device='cuda:0') top5 tensor(1., device='cuda:0')


100%|██████████| 3125/3125 [02:39<00:00, 19.61it/s]
  0%|          | 4/3125 [00:00<05:48,  8.96it/s]

top1 tensor(0.1250, device='cuda:0') top5 tensor(0.6250, device='cuda:0')


 32%|███▏      | 1004/3125 [00:51<01:49, 19.34it/s]

top1 tensor(0.1250, device='cuda:0') top5 tensor(0.3750, device='cuda:0')


 64%|██████▍   | 2004/3125 [01:42<00:57, 19.36it/s]

top1 tensor(0.1250, device='cuda:0') top5 tensor(0.5625, device='cuda:0')


 96%|█████████▌| 3004/3125 [02:32<00:06, 19.35it/s]

top1 tensor(0.1250, device='cuda:0') top5 tensor(0.3750, device='cuda:0')


100%|██████████| 3125/3125 [02:39<00:00, 19.64it/s]
  0%|          | 4/3125 [00:00<05:46,  9.00it/s]

top1 tensor(1., device='cuda:0') top5 tensor(1., device='cuda:0')


 32%|███▏      | 1004/3125 [00:51<01:49, 19.34it/s]

top1 tensor(0.7500, device='cuda:0') top5 tensor(0.9375, device='cuda:0')


 64%|██████▍   | 2004/3125 [01:42<00:57, 19.35it/s]

top1 tensor(0.6250, device='cuda:0') top5 tensor(1., device='cuda:0')


 96%|█████████▌| 3004/3125 [02:33<00:06, 19.33it/s]

top1 tensor(0.8750, device='cuda:0') top5 tensor(0.9375, device='cuda:0')


100%|██████████| 3125/3125 [02:39<00:00, 19.62it/s]
  0%|          | 4/3125 [00:00<05:46,  9.00it/s]

top1 tensor(0., device='cuda:0') top5 tensor(0.3750, device='cuda:0')


 32%|███▏      | 1004/3125 [00:51<01:49, 19.30it/s]

top1 tensor(0.3125, device='cuda:0') top5 tensor(0.5000, device='cuda:0')


 64%|██████▍   | 2004/3125 [01:42<00:58, 19.30it/s]

top1 tensor(0.1250, device='cuda:0') top5 tensor(0.3750, device='cuda:0')


 96%|█████████▌| 3004/3125 [02:33<00:06, 19.28it/s]

top1 tensor(0.1875, device='cuda:0') top5 tensor(0.6250, device='cuda:0')


100%|██████████| 3125/3125 [02:39<00:00, 19.63it/s]
  0%|          | 4/3125 [00:00<05:44,  9.07it/s]

top1 tensor(0.1250, device='cuda:0') top5 tensor(0.6250, device='cuda:0')


 32%|███▏      | 1004/3125 [00:51<01:49, 19.36it/s]

top1 tensor(0.1250, device='cuda:0') top5 tensor(0.3750, device='cuda:0')


 64%|██████▍   | 2004/3125 [01:42<00:57, 19.36it/s]

top1 tensor(0.1250, device='cuda:0') top5 tensor(0.5625, device='cuda:0')


 96%|█████████▌| 3004/3125 [02:32<00:06, 19.36it/s]

top1 tensor(0.1250, device='cuda:0') top5 tensor(0.3750, device='cuda:0')


100%|██████████| 3125/3125 [02:39<00:00, 19.64it/s]


In [62]:
policy_hardness_scores.keys()

dict_keys(['Rotate30', 'Rotate60', 'Rotate90', 'Solarize0', 'Posterize0', 'Cutout20', 'Contrast0', 'Brightness0'])

In [63]:
len(policy_hardness_scores)

8

In [64]:
hardness_scores = {}
hardness_scores["unaugmented"] = unaugmented_hardness
for key, value in policy_hardness_scores.items():
    hardness_scores[key] = value.cpu().numpy()

In [65]:
hardness_save_dir = save_path.replace('test.pth', 'hardness_scores')

In [66]:
import os
if not os.path.exists(hardness_save_dir):
    os.makedirs(hardness_save_dir)

In [67]:
import os
hardness_save_path = os.path.join(hardness_save_dir, "exp7.pt")

In [68]:
torch.save({'policy': policy, 'hardness_scores': hardness_scores, 'easy_indices': easy_indices, 'mid_indices': mid_indices, "hard_indices": hard_indices}, hardness_save_path)

In [69]:
hardness_save_path

'/efs-cotton/outputs/fast-autoaugment/confs/hardnessaware/wresnet28x10_rcifar-AA-rerun3/hardness_scores/exp7.pt'

In [70]:
policy_hardness_scores.keys()

dict_keys(['Rotate30', 'Rotate60', 'Rotate90', 'Solarize0', 'Posterize0', 'Cutout20', 'Contrast0', 'Brightness0'])