# In[None]:
import numpy as np
import torch
import math
from tensorflow.keras.applications.mobilenet_v2 import preprocess_input
# from tensorflow.keras.applications.convnext import preprocess_input
# Seed numpy's global RNG so runs are reproducible.
np.random.seed(0)

# Device string checked by the torch code paths: 'cuda' when torch sees a GPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

def pytorch_switch(tensor_image):
    """Move axis 0 of a 3-D tensor to the end: (A, B, C) -> (B, C, A).

    Typically used to turn a channels-first (C, H, W) image into a
    channels-last (H, W, C) one.
    """
    last_axis_first = (1, 2, 0)
    return tensor_image.permute(last_axis_first)


def to_pytorch(tensor_image):
    """Wrap a 3-D numpy array as a torch tensor with axis 2 moved to the front.

    (A, B, C) -> (C, A, B), e.g. an (H, W, C) image becomes (C, H, W).
    ``torch.from_numpy`` shares memory with the input array (no copy).
    """
    wrapped = torch.from_numpy(tensor_image)
    channels_first = (2, 0, 1)
    return wrapped.permute(channels_first)


class UnTargeted:
    """Untargeted attack objective around a classifier exposing ``predict``.

    Calling an instance on a single image returns ``[is_adversarial, margin]``:

    * ``is_adversarial`` -- True when the predicted label differs from ``true``.
    * ``margin`` -- ``g(score_true) - g(best other score)`` with
      ``g(s) = log(exp(s) + 1e-30)``. Because ``g`` is increasing, the margin
      is negative exactly when some other class outscores the true one.

    Args:
        model: object whose ``predict`` takes a batch of one image and returns
            per-class scores (a torch tensor on the ``to_pytorch`` path).
        true: index of the true class.
        unormalize: if True, the image is multiplied by 255 before prediction
            (presumably images scaled to [0, 1] -- confirm with callers).
        to_pytorch: if True, the image is converted with ``to_pytorch`` and
            given a leading batch axis; otherwise it is batched with numpy and
            passed through the MobileNetV2 ``preprocess_input``.
    """

    # log(1e-30): additive floor applied to exp(score) before taking the log.
    _LOG_FLOOR = math.log(1e-30)

    def __init__(self, model, true, unormalize=False, to_pytorch=False):
        self.model = model
        self.true = true
        self.unormalize = unormalize
        self.to_pytorch = to_pytorch

    @classmethod
    def _log_exp_floor(cls, score):
        """Return ``log(exp(score) + 1e-30)`` as a float without overflow.

        ``np.logaddexp`` evaluates ``log(exp(a) + exp(b))`` stably, so large
        scores no longer raise ``OverflowError`` the way ``math.exp`` did.
        """
        return float(np.logaddexp(float(score), cls._LOG_FLOOR))

    def _predict(self, img):
        """Run the model on one image and return ``(preds, label)``.

        ``preds`` is the flattened model output (a torch tensor on the
        ``to_pytorch`` path, a numpy array otherwise) and ``label`` its argmax.
        """
        img_ = img * 255. if self.unormalize else img

        if self.to_pytorch:
            batch = to_pytorch(img_)[None, :]
            if device == 'cuda':
                batch = batch.to('cuda')
            preds = self.model.predict(batch).flatten()
            return preds, int(torch.argmax(preds))

        # Keras' preprocess_input overwrites float inputs in place, and
        # np.expand_dims returns a view of img_, so copy first; otherwise the
        # caller's image is rescaled on every query.
        batch = np.expand_dims(img_, axis=0).copy()
        preds = self.model.predict(preprocess_input(batch)).flatten()
        return preds, np.argmax(preds)

    def _score(self, preds, y):
        """Turn class scores and the predicted label into ``[is_adv, margin]``."""
        # Private float copy: the true class is masked below.
        scores = [float(p) for p in preds]
        is_adversarial = bool(y != self.true)

        f_true = self._log_exp_floor(scores[self.true])
        # Mask the true class so max() picks the strongest competitor.
        scores[self.true] = -math.inf
        f_other = self._log_exp_floor(max(scores))
        return [is_adversarial, float(f_true - f_other)]

    def get_label(self, img):
        """Return the model's predicted label for ``img``."""
        return self._predict(img)[1]

    def __call__(self, img):
        """Return ``[is_adversarial, margin]`` for ``img`` (see class docstring)."""
        preds, y = self._predict(img)
        if self.to_pytorch:
            # Convert the tensor to a plain Python list of floats.
            preds = preds.tolist()
        return self._score(preds, y)


class Targeted:
    """Targeted attack objective around a classifier exposing ``predict``.

    Calling an instance on a single image returns ``[is_adversarial, loss]``:

    * ``is_adversarial`` -- True when the predicted label equals ``target``.
    * ``loss`` -- ``logsumexp(scores) - scores[target]``; if the scores are
      logits this is the cross-entropy of the target class.

    Args:
        model: object whose ``predict`` takes a batch of one image and returns
            per-class scores (a torch tensor on the ``to_pytorch`` path).
        true: index of the true class (stored; not used by these methods).
        target: index of the class the attack is steering towards.
        unormalize: if True, the image is multiplied by 255 before prediction
            (presumably images scaled to [0, 1] -- confirm with callers).
        to_pytorch: if True, the image is converted with ``to_pytorch`` and
            given a leading batch axis; otherwise it is batched with numpy and
            passed through the MobileNetV2 ``preprocess_input``.
    """

    def __init__(self, model, true, target, unormalize=False, to_pytorch=False):
        self.model = model
        self.true = true
        self.target = target
        self.unormalize = unormalize
        self.to_pytorch = to_pytorch

    def _predict(self, img):
        """Run the model on one image and return ``(preds, label)``.

        ``preds`` is the flattened model output (a torch tensor on the
        ``to_pytorch`` path, a numpy array otherwise) and ``label`` its argmax.
        """
        img_ = img * 255. if self.unormalize else img

        if self.to_pytorch:
            batch = to_pytorch(img_)[None, :]
            if device == 'cuda':
                batch = batch.to('cuda')
            preds = self.model.predict(batch).flatten()
            return preds, int(torch.argmax(preds))

        # Keras' preprocess_input overwrites float inputs in place, and
        # np.expand_dims returns a view of img_, so copy first; otherwise the
        # caller's image is rescaled on every query.
        batch = np.expand_dims(img_, axis=0).copy()
        preds = self.model.predict(preprocess_input(batch)).flatten()
        return preds, np.argmax(preds)

    def _score(self, preds, y):
        """Turn class scores and the predicted label into ``[is_adv, loss]``.

        The log-sum-exp is computed with ``np.logaddexp.reduce`` so large
        scores do not overflow ``math.exp``.
        """
        scores = [float(p) for p in preds]
        is_adversarial = bool(y == self.target)

        f_target = scores[self.target]
        f_other = float(np.logaddexp.reduce(scores))
        return [is_adversarial, f_other - f_target]

    def get_label(self, img):
        """Return the model's predicted label for ``img``."""
        return self._predict(img)[1]

    def __call__(self, img):
        """Return ``[is_adversarial, loss]`` for ``img`` (see class docstring)."""
        preds, y = self._predict(img)
        if self.to_pytorch:
            # Convert the tensor to a plain Python list of floats.
            preds = preds.tolist()
        return self._score(preds, y)