In [None]:
import os
import torch
import torch.nn as nn
import random
import argparse
import itertools
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import os.path as osp
import glob
import yaml

from PIL import Image
from tqdm import tqdm
from torch.utils import data
from torchvision import transforms
import torchvision

import sys
sys.path.insert(0,'/home/argusm/lang/')
from LDVCE.data.imagenet_classnames import name_map


class CFDataset():
    """Pairs of (original, counterfactual) images with their source label and target class.

    ``path`` must contain ``bucket*`` sub-folders, each holding ``original/*.png`` and
    ``counterfactual/*.png``; within a bucket the i-th sorted original is paired with the
    i-th sorted counterfactual. Items are
    ``(original_tensor, label, counterfactual_tensor, target)`` where
    ``label = idx % num_classes`` and ``target`` comes from the ``idx_to_tgt`` YAML file.
    """

    def __init__(self, path, idx_to_tgt="/home/argusm/lang/LDVCE/data/image_idx_to_tgt.yaml",
                 num_images=50000, num_classes=1000):
        """
        Args:
            path: root folder containing the ``bucket*`` sub-folders.
            idx_to_tgt: YAML file with the target class per image index, stored either as a
                list or as a dict keyed by index.
            num_images: size of the image index space (presumably the 50k ImageNet val set).
            num_classes: number of classes; labels are ``idx % num_classes``.

        Raises:
            ValueError: if a bucket has a different number of original and counterfactual images.
        """
        self.path = path
        self.num_classes = num_classes
        self.images = []
        # NOTE(review): plain lexicographic sort; labels/targets are index-based, so bucket and
        # file names must sort in index order (e.g. zero-padded) — confirm for each dataset.
        for bucket_folder in sorted(glob.glob(os.path.join(self.path, "bucket*"))):
            originals = sorted(glob.glob(os.path.join(bucket_folder, "original", "*.png")))
            counterfactuals = sorted(glob.glob(os.path.join(bucket_folder, "counterfactual", "*.png")))
            # zip() would silently truncate and mis-pair everything after a missing image.
            if len(originals) != len(counterfactuals):
                raise ValueError(
                    f"{bucket_folder}: {len(originals)} original vs "
                    f"{len(counterfactuals)} counterfactual images")
            for original, counterfactual in zip(originals, counterfactuals):
                # Swap only the extension; str.replace("png", "pth") would also rewrite
                # any "png" occurring inside the file name itself.
                stem = os.path.splitext(os.path.basename(original))[0]
                self.images.append((original, counterfactual, os.path.join(bucket_folder, stem + ".pth")))

        imagenet_mean = (0.485, 0.456, 0.406)
        imagenet_std = (0.229, 0.224, 0.225)
        self.transform = torchvision.transforms.Compose([
            torchvision.transforms.Resize((256, 256)),
            torchvision.transforms.CenterCrop((224, 224)),
            torchvision.transforms.ToTensor(),
            torchvision.transforms.Normalize(mean=imagenet_mean, std=imagenet_std),
        ])

        with open(idx_to_tgt, 'r') as file:
            targets = yaml.safe_load(file)
        if isinstance(targets, dict):
            # assumes integer keys 0..len-1 — TODO confirm against the YAML file
            targets = [targets[i] for i in range(len(targets))]
        # Regroup with stride num_images // num_classes; presumably this matches the order in
        # which the buckets enumerate images — confirm.
        self.idx_to_tgt_cls = self._interleave_targets(targets, num_images // num_classes)

    @staticmethod
    def _interleave_targets(targets, stride):
        """Concatenate ``targets[0::stride]``, ``targets[1::stride]``, ..., ``targets[stride-1::stride]``."""
        reordered = []
        for offset in range(stride):
            reordered.extend(targets[offset::stride])
        return reordered

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        """Return ``(original, idx % num_classes, counterfactual, target_class)``."""
        original_path, counterfactual_path, pth_file = self.images[idx]
        original = self.load_img(original_path)
        counterfactual = self.load_img(counterfactual_path)
        # data = torch.load(pth_file, map_location="cpu")
        # counterfactual = data["gen_image"]
        return original, idx % self.num_classes, counterfactual, self.idx_to_tgt_cls[idx]

    def load_img(self, path):
        """Open ``path`` as RGB and apply the resize / center-crop / normalize transform."""
        img = Image.open(path).convert("RGB")
        return self.transform(img)
    

In [None]:
# Counterfactual method under evaluation and where its images live;
# swap in one of the commented alternatives to evaluate another method.
method = "LDCE"
path = "/misc/lmbraid21/faridk/LDCE_w382_cc23"
#method, path = "SVCE", "/misc/lmbraid21/faridk/ImageNetSVCEs_robustOnly"
#method, path = "DVCE", "/misc/lmbraid21/faridk/ImageNetDVCEs_"

dataset = CFDataset(path)
print(len(dataset))

# DataLoader settings read by eval_model.
batch_size, num_workers = 12, 8

In [None]:
import torchvision
import timm
device = torch.device("cuda")

# Classifiers evaluated on every counterfactual set ("timm_resnet18" deliberately left out).
models = [
    "pytorch10_resnet50",
    "pytorch10_resnet18",
    "pytorch10_resnet101",
    "pytorch10_inception_v3",
    "pytorch10_resnext50_32x4d",
    "timm_resnet50",
    "timm_resnet101",
]

def get_model(name="pytorch10_resnet50"):
    """Build a pretrained ImageNet classifier from its short name.

    Names prefixed ``pytorch10_`` are created via ``torchvision.models`` and names prefixed
    ``timm_`` via ``timm.create_model``, both with ``pretrained=True``.

    Raises:
        ValueError: if ``name`` is not one of the known model names.
    """
    # Lambdas keep construction lazy: only the requested model is instantiated.
    factories = {
        "pytorch10_resnet50": lambda: torchvision.models.resnet50(pretrained=True),
        "pytorch10_resnet18": lambda: torchvision.models.resnet18(pretrained=True),
        "pytorch10_resnet101": lambda: torchvision.models.resnet101(pretrained=True),
        "pytorch10_inception_v3": lambda: torchvision.models.inception_v3(pretrained=True),
        "pytorch10_resnext50_32x4d": lambda: torchvision.models.resnext50_32x4d(pretrained=True),
        "timm_resnet50": lambda: timm.create_model('resnet50', pretrained=True),
        "timm_resnet18": lambda: timm.create_model('resnet18', pretrained=True),
        "timm_resnet101": lambda: timm.create_model('resnet101', pretrained=True),
    }
    try:
        factory = factories[name]
    except KeyError:
        raise ValueError(
            f"unknown model name {name!r}; expected one of {sorted(factories)}") from None
    return factory()

In [None]:
def eval_model(name):
    """Predict top-1 classes for every original/counterfactual pair and save them.

    Uses the notebook globals ``dataset``, ``batch_size``, ``num_workers``, ``device`` and
    ``method``. Writes ``f"{method}_{name}.npz"`` with arrays ``orig_pred``, ``orig_label``
    (dataset label), ``counter_pred`` and ``counter_label`` (target class).
    """
    net = get_model(name)
    net.to(device)
    net.eval()

    def top1(batch):
        # move to the device as float32, take argmax over classes, return as numpy
        return net(batch.to(device, dtype=torch.float)).argmax(1).cpu().numpy()

    batches = data.DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=False,
        num_workers=num_workers,
        pin_memory=True,
    )
    collected = {"orig_pred": [], "orig_label": [], "counter_pred": [], "counter_label": []}

    with torch.inference_mode():
        for orig, label, counter, target in tqdm(batches):
            collected["orig_label"].append(label)
            collected["counter_label"].append(target)
            collected["orig_pred"].append(top1(orig))
            collected["counter_pred"].append(top1(counter))

    np.savez(f"{method}_{name}.npz",
             **{key: np.concatenate(chunks) for key, chunks in collected.items()})

In [None]:
# this takes long: runs every classifier in `models` over the whole dataset and
# writes one f"{method}_{model_name}.npz" result file per classifier.
for model_name in models:
    print("evaluating", model_name)
    eval_model(model_name)

#eval_model("pytorch10_resnet50")

In [None]:
# Flip rate of pytorch10_resnet50 for the current `method`: the fraction of
# counterfactuals classified as their target class.
# `with` closes the .npz handle; indexing loads each array into memory first.
with np.load(f"{method}_pytorch10_resnet50.npz") as resnet50_run:
    orig_pred = resnet50_run["orig_pred"]
    orig_label = resnet50_run["orig_label"]
    counter_pred = resnet50_run["counter_pred"]
    counter_label = resnet50_run["counter_label"]

np.mean(counter_pred == counter_label)

In [None]:
import os
def eval_classes(class_file):
    """Load one saved evaluation run.

    Args:
        class_file: path to an ``.npz`` file containing the keys written by ``eval_model``.

    Returns:
        Tuple ``(orig_pred, orig_label, counter_pred, counter_label)`` of numpy arrays.
    """
    # The context manager closes the NpzFile deterministically instead of leaving it to GC;
    # indexing materializes each array, so the arrays stay valid after the file is closed.
    with np.load(class_file) as run:
        return run["orig_pred"], run["orig_label"], run["counter_pred"], run["counter_label"]

def get_runs():
    """Yield the names of ``.npz`` result files in the current directory, sorted.

    ``class_labels.npz`` is skipped.
    """
    for fname in sorted(os.listdir()):
        # endswith: a substring test would also match e.g. "run.npz.bak".
        if fname.endswith(".npz") and fname != "class_labels.npz":
            yield fname

# Per method and classifier:
#   fr - flip rate: counterfactual predicted as its target class
#   cr - change rate: counterfactual prediction differs from the original prediction
#   cf - agreement of counterfactual predictions with pytorch10_resnet50
# Loop variables are named so they do not overwrite the global `method`
# (used by eval_model to name its output files).
tmp = {}
for cf_method in ("LDCE", "SVCE", "DVCE"):
    base_orig_pred, base_orig_label, base_counter_pred, base_counter_label = eval_classes(
        f"{cf_method}_pytorch10_resnet50.npz")
    tmp[cf_method] = {}
    for model_name in models:
        orig_pred, orig_label, counter_pred, counter_label = eval_classes(f"{cf_method}_{model_name}.npz")
        tmp[cf_method][model_name] = dict(
            fr=np.mean(counter_pred == counter_label),
            cr=np.mean(counter_pred != orig_pred),
            cf=np.mean(base_counter_pred == counter_pred),
        )

In [None]:
def short(x):
    """Abbreviate a model name for compact table headers, e.g. "timm_resnet101" -> "t_rn101"."""
    # Applied in order; later rules see the output of earlier ones.
    for long_form, abbrev in (
        ("pytorch10_", ""),
        ("resnet", "rn"),
        ("inception", "in"),
        ("resnext50_32x4d", "rNeXt"),
        ("timm_", "t_"),
    ):
        x = x.replace(long_form, abbrev)
    return x
# One row per method: agreement ("cf" in `tmp`) of each classifier's counterfactual
# predictions with pytorch10_resnet50. `cf_method` avoids overwriting the global
# `method` that eval_model uses for file names.
print("    ", "\t".join([short(x) for x in models]))
for cf_method in ("SVCE", "DVCE", "LDCE"):
    print(cf_method, "\t".join(str(tmp[cf_method][model]["cf"]) for model in models))

# For every saved run: change rate (counterfactual prediction differs from the original
# prediction) and flip rate (counterfactual predicted as its target class).
for x in get_runs():
    print(x)
    orig_pred, orig_label, counter_pred, counter_label = eval_classes(x)
    print("Change Rate:", np.mean(counter_pred != orig_pred).round(12))
    print("Flip Rate*: ", np.mean(counter_pred == counter_label).round(12))
    #print("Same Rate: ", np.mean(counter_pred == orig_label).round(3))
    #print("Start Rate: ", np.mean(orig_pred == orig_label).round(3))

    #selection = orig_pred == orig_label
    #print("Change Rate$:", np.mean(counter_pred[selection] != orig_pred[selection]).round(3))
    #print("Flip Rate$: ", np.mean(counter_pred[selection] == counter_label[selection]).round(3))
    print()

def _baseline_run(run_file, base_model="pytorch10_resnet50"):
    """Return the ``base_model`` result file to compare ``run_file`` against.

    eval_model writes f"{method}_{name}.npz", so "LDCE_timm_resnet50.npz" is compared with
    "LDCE_pytorch10_resnet50.npz". Un-prefixed legacy files (e.g. "timm_resnet18.npz") have no
    such match and fall back to the un-prefixed f"{base_model}.npz".
    """
    prefixed = f"{run_file.split('_')[0]}_{base_model}.npz"
    return prefixed if os.path.exists(prefixed) else f"{base_model}.npz"

# Agreement of each run's predictions with the pytorch10_resnet50 run of the same method;
# "cf*" restricts to images where both models agree on the original.
for x in get_runs():
    print(x)
    base_orig_pred, base_orig_label, base_counter_pred, base_counter_label = eval_classes(_baseline_run(x))
    orig_pred, orig_label, counter_pred, counter_label = eval_classes(x)
    print("corr pred orig", np.mean(base_orig_pred == orig_pred))
    print("corr pred cf  ", np.mean(base_counter_pred == counter_pred))
    select = base_orig_pred == orig_pred
    print("corr pred cf* ", np.mean(base_counter_pred[select] == counter_pred[select]).round(4))
    print()

### Check whether the PyTorch (torchvision) and timm ResNet-18 models are the same (yes)

In [None]:
# Prediction agreement between the torchvision and timm ResNet-18 runs
# (un-prefixed result files; timm_resnet18 is not part of `models`).
orig_pred, _, counter_pred, _ = eval_classes("pytorch10_resnet18.npz")
orig_pred2, _, counter_pred2, _ = eval_classes("timm_resnet18.npz")
print(np.mean(orig_pred == orig_pred2))
print(np.mean(counter_pred == counter_pred2))

pytorch_18 = get_model(name="pytorch10_resnet18")
# Must be the timm variant: loading "pytorch10_resnet18" again would compare the model with itself.
timm_18 = get_model(name="timm_resnet18")

def _same_local_tensors(module_a, module_b):
    """Compare parameters and buffers registered directly on two modules (no recursion)."""
    for tensors_a, tensors_b in ((module_a._parameters, module_b._parameters),
                                 (module_a._buffers, module_b._buffers)):
        if list(tensors_a.keys()) != list(tensors_b.keys()):
            return False
        for key, tensor_a in tensors_a.items():
            tensor_b = tensors_b[key]
            # parameters/buffers may be registered as None (e.g. bias=False)
            if tensor_a is None or tensor_b is None:
                if tensor_a is not tensor_b:
                    return False
            elif not torch.equal(tensor_a.data, tensor_b.data):
                return False
    return True


def compareModelWeights(model_a, model_b):
    """Return True if two models have the same module tree and identical tensors.

    At every level the child modules must have the same names in the same order, and every
    parameter and buffer (weights, biases, running statistics, ...) must be equal according
    to ``torch.equal``. Every child is visited, including custom ``nn.Module`` subclasses.
    """
    children_a = list(model_a._modules.items())
    children_b = list(model_b._modules.items())
    if [name for name, _ in children_a] != [name for name, _ in children_b]:
        return False
    if not _same_local_tensors(model_a, model_b):
        return False
    for (_, child_a), (_, child_b) in zip(children_a, children_b):
        # _modules may hold None placeholders
        if child_a is None or child_b is None:
            if child_a is not child_b:
                return False
        elif not compareModelWeights(child_a, child_b):
            return False
    return True

print("same:", compareModelWeights(pytorch_18, timm_18))
# Release both ResNet-18 instances once they have been compared.
del pytorch_18, timm_18