In [None]:
import torch
import numpy as np
import pandas as pd
import os
import matplotlib.pyplot as plt
import datasets
import models
import robustness_evaluate as re
from tqdm.notebook import tqdm
import timm

# Root folder of the ImageNet-9 data; replace the placeholder with a real path.
IMAGENET9_DIR = "path_to_imagenet9_dir"
# Each sub-folder of this directory is treated as one challenge variant by the
# evaluation cells further down.
IMAGENET9_BGCHALLENGE_DIR = os.path.join(IMAGENET9_DIR, 'bg_challenge')
# Training-method identifiers (matched against result-file names by get_best).
TMETHODS = ['standard', 'standardbackground', 'rrr', 'ada', 'actdiff', 'actdiffbackground', 'gradmask', 'fgsm']
# NOTE(review): not referenced by the cells visible here -- presumably used elsewhere.
DATASETS = ['OxfordFlower', 'CUB', 'Cars']
# Prefer the GPU, but fall back to CPU so the notebook still runs without CUDA.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

In [None]:
import torchvision.transforms.functional as TF
from torchvision import transforms
from PIL import Image

def transforms_test(image, img_size):
    """Deterministic test-time preprocessing: resize, centre-crop, tensorise, normalise.

    Args:
        image: image accepted by torchvision's ``Resize`` / ``to_tensor``
            (the dataset class in this notebook passes a PIL image).
        img_size: side length of the final square crop. Supported values are
            448, 224, 128, 96 and 32.

    Returns:
        The image as a tensor cropped to ``img_size`` x ``img_size`` and
        normalised with mean [0.485, 0.456, 0.406] / std [0.229, 0.224, 0.225].
        Single-channel inputs are replicated to three channels first.

    Raises:
        ValueError: if ``img_size`` is not one of the supported sizes
            (previously this surfaced as an UnboundLocalError).
    """
    # crop size -> (side used for the initial square resize, side of the centre crop)
    size_table = {
        448: (512, 448),
        224: (256, 224),
        128: (160, 128),
        96: (128, 96),
        32: (48, 32),
    }
    if img_size not in size_table:
        raise ValueError(
            f'Unsupported img_size {img_size!r}; expected one of {sorted(size_table)}'
        )
    resize_side, crop_side = size_table[img_size]

    # Resize to a slightly larger square, then take the central crop.
    image = transforms.Resize(size=(resize_side, resize_side))(image)
    image = transforms.CenterCrop(size=(crop_side, crop_side))(image)

    # Transform to tensor
    image = TF.to_tensor(image)

    # Grayscale input: repeat the single channel so normalisation sees 3 channels.
    if image.shape[0] == 1:
        image = torch.cat([image, image, image], dim=0)

    image = TF.normalize(image, [0.485, 0.456, 0.406], [0.229, 0.224, 0.225])

    return image

class Imagenet9Challenge():
    """In-memory dataset over one variant of the ImageNet-9 background challenge.

    Images are read from ``IMAGENET9_BGCHALLENGE_DIR/<split>/val/<class_folder>/``.
    Class folders are sorted by name and labelled 0, 1, 2, ... in that order.
    All images are decoded eagerly in ``__init__`` and kept as numpy arrays.

    Attributes set by loading:
        imgs: list of decoded images (numpy arrays).
        targets: list of integer labels, parallel to ``imgs``.
        num_of_categories: number of distinct labels found.
        class_weights: relative frequency of each label (see compute_class_weights).
    """

    # Extensions (lower-cased) that are decoded through PIL.
    _PIL_EXTENSIONS = ('.jpg', '.jpeg')

    def __init__(self, split='train', img_size=224):
        """Load every image of challenge variant ``split`` into memory.

        Args:
            split: name of the sub-folder of the challenge directory to read.
            img_size: crop size forwarded to ``transforms_test`` in ``__getitem__``.
        """
        self.dir = IMAGENET9_BGCHALLENGE_DIR
        self.original = os.path.join(self.dir, split, 'val')
        self.split = split
        self.img_size = img_size
        self.imgs = []
        self.masks = []
        self.targets = []
        self.class_weights = None
        self.load()

    def compute_class_weights(self):
        """Store the relative frequency of each label in ``self.class_weights``.

        If the smallest label is 1 the labels are shifted in place to be
        zero-based. The weight vector has one entry per label value from 0 to
        the largest label, so a label that never occurs gets weight 0 instead
        of causing an index error. With no samples the weights are an empty array.
        """
        if not self.targets:
            self.class_weights = np.zeros(0)
            return

        labels = np.asarray(self.targets)
        if labels.min() == 1:
            labels = labels - 1
            # Mutate the existing list so outside references stay in sync.
            self.targets[:] = [t - 1 for t in self.targets]

        counts = np.bincount(labels).astype(float)
        self.class_weights = counts / counts.sum()

    @staticmethod
    def _read_image(path):
        """Decode one file into a numpy array; return None for unsupported files."""
        ext = os.path.splitext(path)[1].lower()
        if ext == '.npy':
            # NOTE(security): allow_pickle=True can execute arbitrary code when
            # the file is malicious -- only load .npy files from a trusted source.
            return np.load(path, allow_pickle=True)
        if ext in Imagenet9Challenge._PIL_EXTENSIONS:
            # The context manager closes the file handle once pixels are copied out.
            with Image.open(path) as handle:
                return np.asarray(handle)
        return None

    def load(self):
        """Walk the class folders and fill ``imgs`` / ``targets``.

        Entries of the split folder that are not directories, and files with an
        unsupported extension, are skipped (previously a stray file either
        crashed the loop or re-appended the previously decoded image).
        """
        path_images = self.original
        class_folders = sorted(
            entry for entry in os.listdir(path_images)
            if os.path.isdir(os.path.join(path_images, entry))
        )
        for idx, folder in enumerate(class_folders):
            class_dir = os.path.join(path_images, folder)
            # Sorted so the sample order is reproducible across runs/filesystems.
            for file_name in sorted(os.listdir(class_dir)):
                decoded = self._read_image(os.path.join(class_dir, file_name))
                if decoded is None:
                    continue
                self.imgs.append(decoded)
                self.targets.append(idx)

        self.num_of_categories = len(set(self.targets))
        self.compute_class_weights()

    def __len__(self):
        """Number of loaded samples."""
        return len(self.targets)

    def __getitem__(self, idx):
        """Return ``(image_tensor, 0, label)`` for sample ``idx``.

        The constant 0 fills the middle slot of the triple (the evaluation loop
        in this notebook unpacks it as ``blob``); no mask is provided.
        """
        img = Image.fromarray(self.imgs[idx])
        y = self.targets[idx]

        img = transforms_test(img, self.img_size)

        # Defensive: make sure a single-channel tensor is expanded to 3 channels.
        if img.shape[0] == 1:
            img = torch.cat([img, img, img], 0)

        return img, 0, y

In [None]:
def get_best(dataset_name, tmethod, img_size=224):
    """Find the saved run with the highest evaluation accuracy for one method.

    Scans ``results-<dataset_name>/`` for files whose name contains ``tmethod``
    and ``str(img_size)``, loads each, and keeps the one whose
    ``'train_method'`` equals ``tmethod + '-baseline'`` and whose
    ``'best_eval_acc'`` is the largest (and greater than 0).

    Args:
        dataset_name: suffix of the results directory name.
        tmethod: training-method identifier, e.g. ``'rrr'``.
        img_size: image size that must appear in the file name.

    Returns:
        ``(best_path, best_eval)``; ``best_path`` is None and ``best_eval`` is
        0.0 when no file qualifies.

    Raises:
        OSError: if the results directory cannot be listed.
    """
    results_dir = f'results-{dataset_name}/'
    wanted_method = tmethod + '-baseline'
    size_tag = str(img_size)
    best_path = None
    best_eval = 0.0
    # Sorted so that ties are resolved the same way on every run.
    for file in sorted(f for f in os.listdir(results_dir) if tmethod in f):
        print(file)
        # Check the cheap file-name condition before deserialising a checkpoint.
        if size_tag not in file:
            continue
        log_path = os.path.join(results_dir, file)
        try:
            # map_location='cpu': only metadata is read here, so avoid copying
            # every candidate checkpoint onto the GPU.
            log = torch.load(log_path, map_location='cpu')
            if log['train_method'] == wanted_method and log['best_eval_acc'] > best_eval:
                best_eval = log['best_eval_acc']
                best_path = log_path
        except Exception as e:
            # Best effort: an unreadable or malformed file must not stop the scan.
            print('Error:', e)
            continue
    print()
    print('best path'.upper(), best_path)
    print('accuracy:'.upper(), best_eval)

    return best_path, best_eval


def get_model_from_path(path, num_of_categories):
    """Rebuild a ViT-B/16 model from a saved training summary.

    Args:
        path: checkpoint file readable by ``torch.load``; the loaded dict must
            provide ``'train_method'``, ``'best_eval_acc'``,
            ``'regularizer_rate'`` and ``'best_ckp'`` (a state dict).
        num_of_categories: number of output classes of the classifier head.

    Returns:
        ``(model, r)`` where ``model`` is the loaded network moved to the GPU
        (or CPU when CUDA is unavailable) and ``r`` is ``'r3'`` for a
        regularizer rate of 100.0, ``'r2'`` for 10.0 and ``'r1'`` otherwise.

    Raises:
        ValueError: if ``path`` is None (e.g. ``get_best`` found no run).
    """
    if path is None:
        raise ValueError('No checkpoint path given (no matching run was found).')

    # Fall back to CPU so the function is usable on machines without CUDA.
    device = 'cuda' if torch.cuda.is_available() else 'cpu'

    # Deserialise on the CPU; the weights are moved with the model below.
    summary = torch.load(path, map_location='cpu')
    print('Train method:', summary['train_method'])
    print('Best eval acc:', summary['best_eval_acc'])
    print('regularizer_rate', summary['regularizer_rate'])

    rate = summary['regularizer_rate']
    if rate == 100.0:
        r = 'r3'
    elif rate == 10.0:
        r = 'r2'
    else:
        r = 'r1'

    # The architecture is fixed to ViT-B/16 (an earlier ResNet-18 variant was
    # disabled in the original cell).
    model = timm.create_model('vit_base_patch16_224', pretrained=False, num_classes=num_of_categories)
    model.load_state_dict(summary['best_ckp'])
    model = model.to(device)

    return model, r

In [None]:
import timm
# Settings shared by every checkpoint loaded in this cell.
num_of_categories = 9
dataset_name = 'Imagenet9'

# Training-method name -> loaded model. A method whose checkpoint cannot be
# found or loaded is reported and simply left out of the mapping.
name2models = dict()

for m in ('standard', 'standardbackground', 'actdiff', 'gradmask', 'actdiffbackground', 'ada', 'rrr'):
    try:
        path, acc = get_best(dataset_name, m)
        loaded = get_model_from_path(path, num_of_categories)
        name2models[m] = loaded[0]
    except Exception as e:
        print(e)

In [None]:
# Show the challenge variants available on disk. The original call listed the
# current working directory instead of the challenge directory.
print('challenges:'.upper())
for challenge in sorted(os.listdir(IMAGENET9_BGCHALLENGE_DIR)):
    print('\t', challenge)

In [None]:
from sklearn.metrics import confusion_matrix, accuracy_score
from tqdm.notebook import tqdm

def evaluate_std(model, dataloader, criterion):
    """Run ``model`` over ``dataloader`` and collect labels, predictions and logits.

    Inference runs under ``torch.no_grad()``; the original version called
    ``loss.backward()`` on every batch, which built autograd graphs and
    accumulated gradients on the model during evaluation.

    Args:
        model: network called as ``model(inputs)``; expected to return one row
            of class scores per sample.
        dataloader: iterable yielding ``(inputs, blob, labels)`` triples; the
            middle element is ignored.
        criterion: accepted so the call signature stays unchanged (the later
            cell calls ``utils_train.evaluate_std`` with the same keywords).
            It is not used, because no loss is part of the returned summary.

    Returns:
        dict with keys ``'target_labels'`` (list of labels), ``'predictions'``
        (list of argmax class indices) and ``'logits'`` (nested list of raw
        model outputs). All three are empty lists for an empty dataloader.
    """
    model.eval()

    target_labels = []
    cat_preds = []
    vec_logits = []

    with torch.no_grad():
        for inputs, _blob, labels in tqdm(dataloader):
            inputs = inputs.to(device)

            outputs = model(inputs)
            _, preds = torch.max(outputs, 1)

            vec_logits.append(outputs.cpu().numpy())
            cat_preds.extend(preds.cpu().tolist())
            target_labels.extend(labels.cpu().tolist())

    # np.concatenate rejects an empty list, so handle "no batches" explicitly.
    logits = np.concatenate(vec_logits, 0).tolist() if vec_logits else []

    summary = {
        'target_labels': target_labels,
        'predictions': cat_preds,
        'logits': logits
        }

    return summary

# Log with confusion matrix and logits

In [None]:
import torch.nn as nn
import utils_train
import pandas as pd

# Evaluate every loaded model on every challenge variant and store the
# per-sample labels, predictions and logits as one pickle per (challenge, model).
os.makedirs("challenge-logits/", exist_ok=True)

NUM_WORKERS = 1
BATCH_SIZE = 128
img_size = 224

# One loss object is enough; it carries no per-model state.
criterion = nn.CrossEntropyLoss()

# Sorted so the challenges are processed in a reproducible order.
for challenge in sorted(os.listdir(IMAGENET9_BGCHALLENGE_DIR)):
    test_dataset = Imagenet9Challenge(challenge, img_size)
    test_dataloader = torch.utils.data.DataLoader(
        test_dataset,
        batch_size=BATCH_SIZE,
        # No shuffling at evaluation time: every model then sees the samples in
        # the same order, so the saved logits line up row by row across models.
        shuffle=False,
        num_workers=NUM_WORKERS*2
    )

    num_of_categories = test_dataset.num_of_categories
    print(challenge.upper())
    for name, model in name2models.items():
        ans = evaluate_std(
                model=model,
                dataloader=test_dataloader,
                criterion=criterion
                )
        print('\t', name, accuracy_score(ans['target_labels'], ans['predictions']))
        tmp_path_log = f'challenge-logits/{challenge}-{name}.pkl'
        pd.DataFrame.from_dict(ans).to_pickle(tmp_path_log)


In [None]:
import torch.nn as nn
import utils_train

# Build `summary`: one 'method' column plus one accuracy column per challenge
# variant, with rows in the iteration order of `name2models`.
NUM_WORKERS = 1
BATCH_SIZE = 128
img_size = 224

summary = {
    'method': list(name2models.keys())
}

# One loss object is enough; it carries no per-model state.
criterion = nn.CrossEntropyLoss()

# List the same directory the dataset class reads from. The original listed a
# hard-coded absolute path, which could disagree with IMAGENET9_BGCHALLENGE_DIR.
for challenge in sorted(os.listdir(IMAGENET9_BGCHALLENGE_DIR)):
    summary[challenge] = []
    test_dataset = Imagenet9Challenge(challenge, img_size)
    test_dataloader = torch.utils.data.DataLoader(
        test_dataset,
        batch_size=BATCH_SIZE,
        # Evaluation does not need shuffling; keep the sample order fixed.
        shuffle=False,
        num_workers=NUM_WORKERS*2
    )

    num_of_categories = test_dataset.num_of_categories
    print(challenge.upper())
    for name, model in name2models.items():
        ans = utils_train.evaluate_std(
                model=model,
                dataloader=test_dataloader,
                criterion=criterion
                )
        print('\t', name, ans['epoch_acc'])
        summary[challenge].append(ans['epoch_acc'])

In [None]:
# Accuracy table: one row per method, one column per challenge variant.
df = pd.DataFrame(summary)
# BG-Gap: accuracy on 'mixed_same' minus accuracy on 'mixed_rand'.
df['BG-Gap'] = df['mixed_same'].sub(df['mixed_rand'])
print(df)

In [None]:
# Load a previously saved results table from the working directory.
# NOTE(security): read_pickle can execute arbitrary code -- only open trusted files.
df2 = pd.read_pickle("imagenet9-challenge-results.pkl")

In [None]:
#pd.concat([df, df2], axis=0).to_pickle("challenge-results.pkl")

In [None]:
# Stack the freshly computed table on top of the previously saved one.
# Row indices are kept as-is, so index labels may repeat in the result.
resp = pd.concat([df, df2], axis=0)
resp

In [None]:
"""import random

path_challenges = IMAGENET9_BGCHALLENGE_DIR
path_original = os.path.join(IMAGENET9_BGCHALLENGE_DIR, 'original/val')
challenges = os.listdir(path_challenges)
print(challenges)

category = random.choice(os.listdir(path_original))
path_category = os.path.join(path_original, category)
img_name = random.choice(os.listdir(path_category))

print(category, img_name)

for cha in challenges:
    path_tmp = f'IMAGENET9_BGCHALLENGE_DIR/{cha}/val/{category}/{img_name}'
    print(path_tmp, os.path.exists(path_tmp))
    #img = Image.open(path_tmp)
    #plt.imshow(img)"""