# MEMO

In [1]:
import torch
import torchvision.transforms as T
import torch.nn as nn
import torch.nn.functional as F
import torchvision.models as models
import torchvision

In [2]:
from MEMO.MEMO import MEMO

In [3]:
# Dataset roots, given as paths relative to the notebook's working directory.
# imagenet_b_path points at the ImageNet-V2 matched-frequency split (see the directory name).
imagenet_a_path, imagenet_b_path = "imagenet-a", "imagenetv2-matched-frequency-format-val/"

In [4]:
# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [5]:
import torchvision.transforms as T

# Pool of single-transform augmentations. The list order is kept exactly as authored,
# because it may matter to seeded sampling inside MEMO.
# The p=1 arguments make the "Random*" flip/sharpness/grayscale/invert/autocontrast
# transforms always apply.
augmentations = [
    # flips and rotations
    T.RandomHorizontalFlip(p=1),
    T.RandomVerticalFlip(p=1),
    T.RandomRotation(degrees=30),
    T.RandomRotation(degrees=60),
    # one ColorJitter per property
    T.ColorJitter(brightness=0.2),
    T.ColorJitter(contrast=0.2),
    T.ColorJitter(saturation=0.2),
    T.ColorJitter(hue=0.2),
    # translation and a milder rotation
    T.RandomAffine(degrees=0, translate=(0.1, 0.1)),
    T.RandomRotation(degrees=15),
    # sharpness / color-space / contrast / blur
    T.RandomAdjustSharpness(sharpness_factor=2, p=1),
    T.RandomGrayscale(p=1),
    T.RandomInvert(p=1),
    T.RandomAutocontrast(p=1),
    T.GaussianBlur(kernel_size=5),
]

# 18 independent AugMix instances. The (severity, mixture_width) settings
# (3, 3), (2, 2), (4, 4) are cycled six times, in that order, and every instance
# uses chain_depth=3 and alpha=1.0.
augmix_augmentations = [
    T.AugMix(severity=level, mixture_width=level, chain_depth=3, alpha=1.0)
    for _ in range(6)
    for level in (3, 2, 4)
]

## ResNet-50 with MEMO on ImageNet-A

In [6]:
# Experiment directory handed to MEMO for the ResNet-50 / ImageNet-A / SGD run.
# NOTE(review): absolute SageMaker path; presumably where MEMO writes its artifacts — confirm in MEMO/MEMO.py.
exp_path_a = (
    "/home/sagemaker-user/Domain-Shift-Computer-Vision/experiments/Resnet50_ImagenetA_SGD"
)

In [7]:
# MEMO runner configured with torchvision's ResNet-50 constructor (the callable itself,
# not an instance) and the SGD optimizer class.
# NOTE(review): how MEMO uses exp_path and device is defined in MEMO/MEMO.py — confirm there.
MEMO_resnet50 = MEMO(
    device=device,
    exp_path=exp_path_a,
    model=models.resnet50,
    optimizer=torch.optim.SGD,
)

In [12]:
# Shape: [ {group name: [parameter names, learning rate]}, 0 ].
# NOTE(review): presumably only the final fc layer is adapted at lr 2.5e-4; the meaning
# of the trailing 0 is defined by MEMO.test_MEMO — confirm there.
lr_setting = [
    {"classifier": [["fc.weight", "fc.bias"], 0.00025]},
    0,
]

In [11]:
# Evaluate MEMO on ImageNet-A using the AugMix pool and the classifier-only lr_setting.
# NOTE(review): the semantics of num_augmentations / seed_augmentations / top_augmentations /
# stable_entropy are defined in MEMO/MEMO.py — confirm there before changing them.
MEMO_resnet50.test_MEMO(
    img_root=imagenet_a_path,
    batch_size=64,
    weights_imagenet=models.ResNet50_Weights.IMAGENET1K_V1,
    augmentations=augmix_augmentations,
    num_augmentations=8,
    seed_augmentations=32,
    top_augmentations=0,
    MEMO=True,
    lr_setting=lr_setting,
    stable_entropy=True,
)

Batch 83/118, Accuracy: 1.06%

KeyboardInterrupt: 

## Test pipeline: baseline ResNet-50 on ImageNet-A (no MEMO adaptation)

In [4]:
from utility.get_data import get_data

In [5]:
# Loader preprocessing: resize every image to 224x224 and convert it to a tensor.
# ImageNet mean/std normalization is applied later, in the evaluation loop.
transform_loader = T.Compose([T.Resize((224, 224)), T.ToTensor()])

test_loader = get_data(
    batch_size=64,
    img_root=imagenet_a_path,
    transform=transform_loader,
    split_data=False,
)

In [78]:
def get_model(weights_imagenet="DEFAULT"):
    """Build a torchvision ResNet-50 in eval mode on the notebook-level ``device``.

    Args:
        weights_imagenet: weights spec forwarded to ``models.resnet50(weights=...)``,
            e.g. ``models.ResNet50_Weights.IMAGENET1K_V1``,
            ``models.ResNet50_Weights.IMAGENET1K_V2`` or ``"DEFAULT"``.
            Defaults to ``"DEFAULT"``, which is what the previous version always
            loaded regardless of this argument.

    Returns:
        The ResNet-50 module, switched to eval mode and moved to ``device``.
    """
    # Forward the requested weights; the argument used to be ignored, so asking for
    # IMAGENET1K_V1 silently returned the DEFAULT weights instead.
    model = models.resnet50(weights=weights_imagenet)
    model.eval()  # inference mode for BatchNorm/Dropout layers
    model.to(device)
    return model

In [7]:
import json
def get_imagenetA_masking(
    imagenetA_masking_path="/home/sagemaker-user/Domain-Shift-Computer-Vision/MEMO/imagenetA_masking.json",
):
    """Return the ImageNet-1k logit indices kept for ImageNet-A evaluation.

    The JSON file is a mapping whose string keys are ImageNet-1k class indices. A key
    is kept when its value is not ``-1``.

    Args:
        imagenetA_masking_path: path to the masking JSON. Defaults to the location used
            by the original SageMaker setup, so existing zero-argument calls still work.

    Returns:
        A list of ``int`` indices into the 1000-way logits, in the JSON file's key order.
        NOTE(review): column j of ``logits[:, indices]`` is compared against label j of
        the ImageNet-A loader, so this order must match that loader's label order — confirm.
    """
    with open(imagenetA_masking_path, "r", encoding="utf-8") as json_file:
        imagenetA_masking = json.load(json_file)
    return [int(key) for key, value in imagenetA_masking.items() if value != -1]

In [8]:
import tqdm

In [79]:
# Baseline (no adaptation) evaluation of ResNet-50 on ImageNet-A.
# Produces all_logits (N+1, 1000) and all_targets (N+1,). Row 0 of each is a zero
# placeholder, which the next cell strips with [1:].
verbose = True
log_interval = 1
samples = 0  # images seen so far
cumulative_accuracy = 0.0  # correct masked predictions so far (a count, not a rate)
imagenetA_masking = get_imagenetA_masking()

# Built once, outside the loop: the ImageNet mean/std normalization used by the
# torchvision weights (the loader only resizes and converts to tensor).
normalize_input = T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
model = get_model(models.ResNet50_Weights.IMAGENET1K_V2)

# Collect per-batch outputs in lists and concatenate once at the end. Calling torch.cat
# on every iteration re-copies everything accumulated so far (quadratic).
# The zero placeholder rows keep the original layout that the next cell slices off.
logit_batches = [torch.zeros(1, 1000, device=device)]
target_batches = [torch.zeros(1, dtype=torch.long, device=device)]

with torch.no_grad():
    for batch_idx, (inputs, targets) in enumerate(test_loader):
        inputs, targets = inputs.to(device), targets.to(device)
        logits = model(normalize_input(inputs))
        logit_batches.append(logits)
        target_batches.append(targets)

        # Running top-1 accuracy over the masked (ImageNet-A) classes. This is the same
        # metric the final cell computes on the concatenated logits.
        predicted = logits[:, imagenetA_masking].argmax(dim=1)
        cumulative_accuracy += (predicted == targets).sum().item()
        # Must be updated before the division below. It used to be commented out,
        # leaving samples == 0 and raising ZeroDivisionError.
        samples += inputs.shape[0]

        if verbose and batch_idx % log_interval == 0:
            current_accuracy = cumulative_accuracy / samples * 100
            print(f"Batch {batch_idx}/{len(test_loader)}, Accuracy: {current_accuracy:.2f}%", end='\r')

all_logits = torch.cat(logit_batches, dim=0)
all_targets = torch.cat(target_batches)

Batch 117/118

In [80]:
# Drop row 0, the zero placeholder that seeded both accumulators in the evaluation loop.
actual_logits, actual_targets = all_logits[1:], all_targets[1:]

In [84]:
# Keep only the 1000-way logit columns selected by the ImageNet-A mask.
actual_logits_imagenetA = actual_logits[..., imagenetA_masking]

In [85]:
# Top-1 accuracy (%) over the masked classes. Divide by the number of evaluated samples
# rather than a hard-coded dataset size (7500), so a partial or different split still
# reports a correct percentage.
(actual_targets == actual_logits_imagenetA.argmax(dim=1)).sum() / len(actual_targets) * 100

tensor(12.6933, device='cuda:0')

ImageNet-A top-1 accuracy of the un-adapted ResNet-50:
- `models.ResNet50_Weights.IMAGENET1K_V1`: 2.4 %
- `models.ResNet50_Weights.IMAGENET1K_V2` (same as `DEFAULT`): 12.6933 %