In [96]:
from adversarial_dataset import AdversarialDataset
from torch.utils.data import DataLoader
import torchvision.transforms as transforms
import torch
import matplotlib.pyplot as plt
from torchvision import models
import numpy as np

In [8]:
def show_image(tensor):
    permuted = torch.permute(tensor.cpu(), (1, 2, 0))
    plt.imshow(permuted)

In [79]:
mean = (0.485, 0.456, 0.406)
std = (0.229, 0.224, 0.225)

img_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=mean, std=std)
])

noise_transform = transforms.Compose([
    transforms.ToTensor()
])

model = models.inception_v3(weights=models.Inception_V3_Weights.DEFAULT)
model.eval()
dataset = AdversarialDataset("images.csv",
                             "categories.csv",
                             "images",
                             None,
                             img_transform=img_transform,
                             noise_transform=noise_transform)

dataloader = DataLoader(dataset, batch_size=32, shuffle=True)

In [80]:
dataset[0][0].shape

torch.Size([3, 299, 299])

In [98]:
def find_kp_point(image, model):
    n = image.shape[1]
    output = model(image[None])
    _, c_original = torch.max(output, dim=1)

    for k in reversed(range(n)):
        output = model(image[None, :, :k, :])
        _, c = torch.max(output, dim=1)

        if c != c_original:
            probabilities = torch.nn.functional.softmax(output, dim=1)
            return k, probabilities[0, c].item()

In [99]:
find_kp_point(dataset[0][0], model)

(122, 0.6097415685653687)

In [100]:
kp_set = np.zeros((2*len(dataset), 3))

for i in range(2*len(dataset)):
    row = dataset[i]

    img = row[0]
    noise = row[1]

    is_adv = i - len(dataset) >= 0
    if is_adv:
        img = img+noise

    point = find_kp_point(img, model)
    print(f"Point: {point}. Adv: {is_adv}")

    kp_set[i] = [point[0], point[1], is_adv]

Point: (122, 0.6097415685653687). Adv: False
Point: (218, 0.6544612646102905). Adv: False
Point: (170, 0.7488264441490173). Adv: False
Point: (146, 0.6222823858261108). Adv: False
Point: (274, 0.42696118354797363). Adv: False


KeyboardInterrupt: 