In [None]:
import torch
import gc
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, utils, models
from torch import nn
from torch import optim
import json
import PIL
from PIL import Image
import io
import cv2
import torchvision.transforms.functional as TF
from torchvision.utils import save_image
from torchvision.transforms import ToPILImage
from matplotlib import pyplot as plt
import numpy as np
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable
import torch.utils.data as data
import torch.optim as optim
from torch.optim import lr_scheduler
from tqdm import tqdm
import time
import os
import copy
import pickle
import urllib.request
import requests
from matplotlib.pyplot import imshow
import random
from utils import *
from resnet50_ft_dims_2048 import *

In [None]:
# Preprocessing pipeline: resize to the square network input, convert to a
# tensor, then subtract fixed per-channel offsets (a std of 1 leaves the scale
# untouched).
# NOTE(review): `model_transform` is reassigned in a later cell without the
# Normalize step, so which definition is live depends on cell execution order.
model_img_size = 224
model_transform = transforms.Compose(
    [
        transforms.Resize((model_img_size, model_img_size)),
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[91.4953, 103.8827, 131.0912],
            std=[1, 1, 1],
        ),
    ]
)

In [None]:
class ResNet50_Classifier(nn.Module):
    """Linear classification head: 2048-d feature vectors in, 500 logits out."""

    def __init__(self):
        super().__init__()
        # The attribute must keep the name `fc`: state_dict keys derive from it.
        self.fc = nn.Linear(in_features=2048, out_features=500)

    def forward(self, feats):
        """Return class logits for already-flattened backbone features."""
        logits = self.fc(feats)
        return logits

In [None]:
# Use the first GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Classification head: restore the trained weights.  `map_location` remaps the
# stored tensors onto `device`, so a checkpoint written from a GPU session can
# still be loaded on a CPU-only machine (the fallback chosen above).
model_class = ResNet50_Classifier()
state_dict = torch.load("vgg2_classifier_500.pt", map_location=device)
model_class.load_state_dict(state_dict)

# Backbone, built by the project-local `resnet50_ft` helper from its weight file.
model = resnet50_ft("vgg_face_testimages/resnet50_ft_dims_2048.pth")

# Both networks are only used for inference in this notebook.
model.to(device)
model.eval()

model_class.to(device)
model_class.eval()

In [None]:
# Per-channel pixel means on the 0-255 scale, used by the prepare_input*
# helpers in this cell; kept as a float64 CPU tensor.
mean_pixel = torch.tensor([131.0912, 103.8827, 91.4953], dtype=torch.float64)
def bounding_crop(img, bounding_box, half_extension=0.15):
    """Crop ``img`` to a loosened version of ``bounding_box``.

    Args:
        img: Object exposing PIL's ``size`` (width, height) and ``crop(box)``.
        bounding_box: ``(x, y, w, h)`` - top-left corner plus width and height.
        half_extension: Fraction of the box width/height used as margin.
            Defaults to 0.15, the value that used to be hard-coded.

    Returns:
        The result of ``img.crop`` for the expanded box, clamped to the image.
    """
    img_w, img_h = img.size
    x, y, w, h = bounding_box
    # Left/top edges move outwards by `half_extension` of the box size (min 0).
    left = max(0, x - half_extension * w)
    top = max(0, y - half_extension * h)
    # NOTE(review): right/bottom edges move outwards by 2 * half_extension
    # (measured from x + w / y + h), i.e. twice the left/top margin.  Kept as
    # is so crops match earlier runs - confirm whether this is intended.
    right = min(img_w, x + w * (1 + half_extension * 2))
    bottom = min(img_h, y + h * (1 + half_extension * 2))
    return img.crop((left, top, right, bottom))

def load_data(img, shape=None, bounding_box=None, mean=(131.0912, 103.8827, 91.4953)):
    """Resize, centre-crop and mean-subtract one PIL image.

    The image is converted to RGB, scaled so that its shorter side is 224
    pixels, centre-cropped to ``shape[:2]`` and returned as a float array with
    ``mean`` subtracted per channel.

    Args:
        img: PIL image.
        shape: Crop size; only the first two entries (height, width) are read.
            ``None`` selects 224x224 (previously ``None`` raised a TypeError).
        bounding_box: Accepted for call compatibility; currently unused.
        mean: Per-channel values to subtract.  The default equals the tuple
            the notebook assigns to the global ``mean`` in a later cell, so the
            function no longer depends on that cell having been run first.

    Returns:
        ``numpy.ndarray`` of shape (crop_h, crop_w, 3), dtype float64.
    """
    crop_size = (224, 224) if shape is None else shape
    short_size = 224.0

    # PIL reports size as (width, height).
    im_shape = np.array(img.size)
    img = img.convert('RGB')

    # Scale factor that brings the shorter side to `short_size` pixels.
    ratio = float(short_size) / np.min(im_shape)
    img = img.resize(size=(int(np.ceil(im_shape[0] * ratio)),   # width
                           int(np.ceil(im_shape[1] * ratio))),  # height
                     resample=PIL.Image.BILINEAR)

    x = np.array(img)  # array layout is (height, width, channels)
    newshape = x.shape[:2]
    # Centre crop.
    h_start = (newshape[0] - crop_size[0]) // 2
    w_start = (newshape[1] - crop_size[1]) // 2
    x = x[h_start:h_start + crop_size[0], w_start:w_start + crop_size[1]]
    return x - np.asarray(mean, dtype=np.float64)
def fetch_images(paths):
    """Open every path in ``paths`` with PIL and return the images in order."""
    return [Image.open(path) for path in paths]
def prepare_input(img_list):
    """Turn PIL images into one channels-first float tensor.

    Every image goes through ``load_data`` with a 224x224 crop; the stacked
    (N, H, W, C) array is then reordered to (N, C, H, W).
    """
    processed = []
    for picture in img_list:
        processed.append(load_data(img=picture, shape=(224, 224, 3)))
    batch_nhwc = np.array(processed)
    batch_nchw = batch_nhwc.transpose(0, 3, 1, 2)
    return torch.Tensor(batch_nchw)

def prepare_input2(img_list):
    """Convert [0, 1]-scaled images into mean-subtracted 0-255 network input.

    For every image, channels 0-2 have ``mean_pixel[c] / 255`` subtracted and
    the whole image is then multiplied by 255.

    The arithmetic runs on a clone, so the caller's tensors are left untouched.
    The previous version wrote into them in place, which silently altered the
    caller's data and - when the images are slices obtained by iterating a
    grad-tracking batch - is the kind of in-place edit on a multi-output view
    that recent PyTorch autograd can reject.

    Args:
        img_list: Iterable of tensors indexed ``[channel, ...]`` with at least
            three channels (a batched tensor iterates over its first dim).

    Returns:
        The processed images stacked along a new leading dimension.
    """
    t_list = []
    for img in img_list:
        shifted = img.clone()  # clone() keeps the result in the autograd graph
        for ch in range(3):
            shifted[ch] = shifted[ch] - mean_pixel[ch] / 255
        t_list.append(shifted * 255)
    return torch.stack(t_list)


def prepare_input_adv(img_list):
    """Preprocess adversarially lit images; identical to ``prepare_input2``.

    Kept as a separate name for the optimisation/evaluation call sites; it
    delegates so the two can no longer drift apart.
    """
    return prepare_input2(img_list)

def prepare_input_inv(img_list):
    """Map mean-subtracted 0-255 images back to the [0, 1] scale.

    For every image, channels 0-2 become ``value / 255 + mean_pixel[c] / 255``
    (the inverse of the ``prepare_input2`` arithmetic); any further channels
    are passed through unchanged.

    The work is done on a clone, so the caller's tensors are not modified
    (the previous version overwrote them in place).

    Args:
        img_list: Iterable of tensors indexed ``[channel, ...]`` with at least
            three channels.

    Returns:
        The restored images stacked along a new leading dimension.
    """
    t_list = []
    for img in img_list:
        restored = img.clone()
        for ch in range(3):
            restored[ch] = restored[ch] / 255 + mean_pixel[ch] / 255
        t_list.append(restored)
    return torch.stack(t_list)

def multi_acc(y_pred, y_test):
    """Top-1 accuracy, in percent, of logits ``y_pred`` against labels ``y_test``."""
    log_probs = torch.log_softmax(y_pred, dim=1)
    predicted = torch.max(log_probs, dim=1).indices
    hits = (predicted == y_test).float()
    return (hits.sum() / len(hits)) * 100

def get_prediction(y_pred, classid=None):
    """Summarise softmax confidences for a batch of logits.

    Args:
        y_pred: Logits with classes along dim 1.
        classid: Optional class index.  When given (index 0 included), the
            confidence is the batch-mean softmax probability of that class.

    Returns:
        ``(tags, conf)``.  With ``classid``: ``tags`` is ``classid`` itself and
        ``conf`` a scalar tensor.  Without: the per-sample arg-max class
        indices and their probabilities.
    """
    y_pred_softmax = torch.softmax(y_pred, dim=1)
    # `is not None` instead of truthiness: class index 0 is a valid request and
    # used to fall through to an UnboundLocalError.
    if classid is not None:
        return classid, torch.mean(y_pred_softmax[:, classid])
    # No class requested: report the top-1 prediction of every sample (this
    # path previously raised because nothing was assigned).
    conf, y_pred_tags = torch.max(y_pred_softmax, dim=1)
    return y_pred_tags, conf
'''
im_array = np.array([load_data(img=i, shape=(224, 224, 3), bounding_box = bounding_dict[image_path_list[ind].split("test/")[1].split(".")[0]]) for ind, i in enumerate(images)])
im_tensor = torch.Tensor(im_array.transpose(0, 3, 1, 2))
labels = [self.class_dict[int(path.split("/")[-2].split("n00")[-1])] for path in image_path_list]

labels = torch.tensor(labels,dtype=torch.long)
labels = labels.squeeze(0)
'''

In [None]:
import csv

vgg_bounding = "/nobackup/vgg2face/bb_landmark/"

# Maps the first CSV column to the next four columns parsed as ints; the
# values are later passed to bounding_crop, which unpacks them as x, y, w, h.
bounding_dict = {}
with open(vgg_bounding+'loose_bb_test.csv', newline='') as csv_file:
    csv_reader = csv.reader(csv_file, delimiter=',')
    # The first row holds the column names; an empty file simply yields none.
    header = next(csv_reader, None)
    if header is not None:
        print(f'Column names are {", ".join(header)}')
    for row in csv_reader:
        bounding_dict[row[0]] = [int(row[col]) for col in range(1, 5)]

# NOTE(security): pickle executes arbitrary code while loading - only point
# this at a file you created yourself.  The `with` block closes the handle,
# which the previous bare open() never did.
with open("vgg2_testset_classdict.pk", "rb") as class_dict_file:
    class_dict = pickle.load(class_dict_file)

In [None]:
# Reverse lookup: print every key of `class_dict` whose value is 272.
for matching_key in [key for key, value in class_dict.items() if value == 272]:
    print(matching_key)

In [None]:
# `test_dir` is assigned here too (same value as in the later cells) so that
# this cell does not raise NameError when the notebook is run top to bottom.
test_dir = "/nobackup/vgg2face/test/"

# First 35 directory entries of identity n004064.
# NOTE(review): os.listdir returns entries in arbitrary order, so "first 35"
# is not guaranteed to be the same set on every machine.
img_file = ["n004064/"+x for x in os.listdir(test_dir+"n004064")[:35]]
src_img = fetch_images([test_dir+x for x in img_file])
# Bounding boxes are keyed by the relative path without its extension.
src_img = [bounding_crop(x, bounding_dict[y.split(".")[0]]) for x,y in zip(src_img,img_file)]

In [None]:
test_dir = "/nobackup/vgg2face/test/"

# One source image and one target example, each cropped to its bounding box.
img_file = "n004891/"+os.listdir(test_dir+"n004891")[29]
tar_file = "n005148/"+os.listdir(test_dir+"n005148")[8]
src_img = fetch_images([test_dir+img_file])[0]
trg_exm = fetch_images([test_dir+tar_file])[0]
src_img = bounding_crop(src_img, bounding_dict[img_file.split(".")[0]])
display(src_img)
trg_exm = bounding_crop(trg_exm, bounding_dict[tar_file.split(".")[0]])
display(trg_exm)

mean = (131.0912, 103.8827, 91.4953)
# `src_img` is a single PIL image in this cell (not a list), so it is passed
# directly; the former `src_img[3]` raised TypeError.
# NOTE(review): `model_transform` has two definitions in this notebook; the
# one without Normalize is presumably the intended one here - confirm.
model_input = model_transform(src_img)
model_input = prepare_input2([model_input])

# Build the converter locally instead of relying on `transToPIL`, which is
# only defined in a later cell; display() already renders, no print needed.
display(transforms.ToPILImage()(model_input[0]))
model_input = model_input.to(device)

# Inference only: no autograd graph required.
with torch.no_grad():
    y_pred = model(model_input)[1]
    y_pred = y_pred.squeeze(-1).squeeze(-1)
    y_pred = model_class(y_pred)

# Top-1 class and its softmax confidence, computed inline so the cell does not
# depend on get_prediction() handling a missing class id.
top_conf, top_class = torch.softmax(y_pred, dim=1).max(dim=1)
print((top_class, top_conf))
print(class_dict[3009])
print(class_dict[5427])

In [None]:
# Converters between PIL images and tensors.
transToPIL = transforms.ToPILImage()
transToTensor = transforms.ToTensor()

# Input pipeline: Resize(256), random 224x224 crop, conversion to a tensor.
_pipeline_steps = [
    transforms.Resize(256),
    transforms.RandomCrop((224, 224)),
    transforms.ToTensor(),
]
model_transform = transforms.Compose(_pipeline_steps)

# Normalize computes (x - mean) / std, so the negated means add the offsets
# back while the positive ones subtract them.
_channel_means = [131.0912, 103.8827, 91.4953]
reverse_normalize = transforms.Normalize(
    mean=[-value for value in _channel_means],
    std=[1, 1, 1],
)
forward_normalize = transforms.Normalize(
    mean=list(_channel_means),
    std=[1, 1, 1],
)
#display(src_img)
#x = np.array(src_img)
#x = x - mean
#x = x.transpose(2,0,1)
#xx = torch.Tensor(x)
#display(transToPIL(xx))

In [None]:
# Spatial size of the network input and an all-ones mask of that size
# (every pixel is affected by the illumination signal).
input_size = (224, 224)
mask = torch.ones(*input_size, dtype=torch.float, device=device)

# Target and original class indices.
targidx, classidx = 363, 209

im_height = 224
print(im_height)

In [None]:
# --- Signal and sampling settings ---
sz = 120               # Length of input signal
c = .1                 # Ambient light ratio
c_limits = [0.3, 0.7]
batch = 16
channels = 1

# Change-of-variable term that the optimiser updates.
w = torch.rand(channels, sz, 1, dtype=torch.float, device=device).requires_grad_(True)
rescale_factor = 1

# --- Target and original class labels as one-element long tensors ---
target = torch.tensor([targidx], dtype=torch.long, device=device)
orig = torch.tensor([classidx], dtype=torch.long, device=device)

# --- Optimiser and loss ---
lr = 0.1
n_epochs = 5
optimizer = optim.SGD(params=[w], lr=lr)
loss_fn = nn.CrossEntropyLoss()

In [None]:
# Replace the source images with 25 references to one blank white 800x1280
# RGB canvas.
src_img = [Image.new('RGB', (800,1280), (255, 255, 255))] * 25

# Optimise towards class 272.  The `target` tensor built in the setup cell
# captured the earlier `targidx`, so it is rebuilt here; otherwise the
# optimisation loop and the later evaluation would use different classes.
targidx = 272
target = torch.tensor([targidx], dtype=torch.long, device=device)

In [None]:
# Loss histories for the final plot: one plain float every 100 epochs.
targloss = []
origloss = []
n_epochs = 3000

# Optimisation loop over the illumination signal `w`.
for epoch in tqdm(range(n_epochs)):

    # The untargeted warm-up phase is switched off: `half` is always False and
    # the switch-over branch below never fires (epoch is never -1).  Both are
    # kept so the schedule can be re-enabled via the commented expressions.
    half = False#epoch < n_epochs//6
    if epoch == -1:#n_epochs//6:
        tops = out.topk(2).indices[0]
        targidx = tops[0].item() if tops[0].item() != classidx else tops[1].item()
        target = torch.tensor([targidx], dtype=torch.long, device=device)
        print("Switching from untarget to target {}".format(targidx))

    # A single-channel signal is replicated across the three colour channels.
    if channels==1:
        n_w = w.repeat(3,1,1)
    else:
        n_w = w
    # Optional upsampling of the signal along its length.
    if rescale_factor!=1:
        dim = len(torch.flatten(n_w))//3
        n_dim = rescale_factor*dim
        n_w = n_w.unsqueeze(0)
        t = nn.Upsample(size=(n_dim,1), mode='bilinear')
        n_w = t(n_w)[0]

    gy, new_w = fttogy(n_w, batch, mask, c_limits, im_height+4)

    # Cycle through the source images (index epoch % 25) and light them.
    model_input = model_transform(src_img[epoch%25])
    input_img2 = model_input.to(device)
    inp = gy*input_img2
    inp2 = prepare_input_adv(inp)

    # Backbone features -> drop the two trailing singleton dims -> class logits.
    out = model(inp2)[1]
    out = out.squeeze(-1)
    out = out.squeeze(-1)
    out = model_class(out)

    #Calculate Loss depended on if targeted or untargeted
    if not half: targLoss = loss_fn(out, target.repeat(batch))
    origLoss = loss_fn(out, orig.repeat(batch))
    loss = -origLoss if half else targLoss
    if epoch%100 == 0:
        # Store Python floats.  Appending the loss tensors themselves kept each
        # logged autograd graph alive and handed matplotlib grad-tracking
        # (possibly CUDA) tensors, which it cannot convert for plotting.
        targloss.append(0 if half else targLoss.item())
        origloss.append(origLoss.item())
        if not half: print(targLoss, origLoss)
    loss.backward()

    optimizer.step()

    optimizer.zero_grad()

    del loss
    # Keep `inp` / `new_w` from the final epoch: a later cell inspects `inp`.
    if epoch!=n_epochs-1:
        del inp
        del new_w
    #else:
        #saving w to be used for prediction
        #torch.save(n_w,'w_0.5_764.pt')

    torch.cuda.empty_cache()

#View original loss and target loss
plt.plot(targloss, label="target")
plt.plot(origloss, label="original")
plt.legend()
plt.show()

In [None]:
# Score the optimised signal on the first source image only.
# NOTE(review): get_results is defined in a later cell, so that cell has to be
# executed before this one.
get_results(n_ww=n_w, classidx=classidx, targidx=targidx, src_img=[src_img[0]])

In [None]:
# Size of the first lit image kept from the final optimisation epoch.
print(inp[0].size())
# display() renders the image itself and returns None; wrapping it in print()
# only added a stray "None" line to the output.
display(transToPIL((gy.cpu()*model_transform(src_img[0]))[0]))

In [None]:
def applyMask(w, batch_limits, mask, c_limits):
    """Build the illumination pattern for a range of shifts.

    Args:
        w: Signal tensor passed to the project helper ``stack``.
        batch_limits: ``[start, stop]``; one pattern is produced for every
            integer shift in ``range(start, stop)``.
        mask: Tensor multiplied onto the pattern to restrict where it applies.
        c_limits: ``[low, high]`` range from which the ambient-light ratio
            ``c`` is drawn uniformly, one value per shift.

    Returns:
        ``(c + (1 - c) * gy * mask, new_w, c)`` where ``new_w`` is the shifted
        signal squashed into [0, 1] and ``gy`` is the output of the project
        layer ``lay`` applied to it.
    """
    # Length the signal is stacked to.  NOTE(review): equals the `im_height + 4`
    # passed to fttogy in the optimisation loop - presumably the same quantity.
    stacked_len = 228

    #stack the signal to fit the input size
    oot = stack(w, stacked_len)
    batch = batch_limits[1] - batch_limits[0]
    # EOT sampling for ambient light; the shifts are enumerated, not sampled.
    c = torch.rand([batch,1,1,1], device=device) * (c_limits[1] - c_limits[0]) + c_limits[0]
    shift = torch.arange(batch_limits[0], batch_limits[1], dtype=torch.int)
    #Shift the signal
    repeated = oot.unsqueeze(0).repeat(batch,1,1,1).view(-1, stacked_len, 1)
    ootn = shift_operation(repeated, shift).view(batch, 3, stacked_len, 1)
    #Fit w into the range [0,1]. new_w is the same as ft
    new_w = .5 * (torch.tanh(ootn) + 1)

    #Convolution of ft and the shutter
    gy = lay(new_w.transpose(0,3).transpose(0,1)).transpose(0,1).transpose(0,3)
    #Mask the signal to only affect the object
    gy_mask = gy * mask
    return (c + (1-c)*gy_mask), new_w, c

In [None]:
test_dir = "/nobackup/vgg2face/test/"
# Saved signals: every entry of face_rec_results/ whose name starts with "n00".
result_list = list(filter(lambda name: name.startswith("n00"), os.listdir("face_rec_results/")))

In [None]:
def get_results(n_ww, classidx, targidx, src_img):
    """Score a signal on each source image, averaged over shifts.

    For every image the signal is applied through ``applyMask`` with
    ``batch_limits=[0, 120]`` and ambient light drawn from [0.3, 0.7]; each
    resulting pattern is multiplied onto the image, classified, and the
    softmax confidences of ``classidx`` and ``targidx`` are averaged.

    Args:
        n_ww: Signal tensor (moved to ``device`` here).
        classidx: Index of the original class.
        targidx: Index of the target class.
        src_img: List of PIL images.

    Returns:
        Two numpy arrays ``(class_conf, target_conf)`` with one mean
        confidence per source image.
    """
    n_ww = n_ww.to(device)
    classid_over = []
    targid_over = []
    batch_size = 1
    # Evaluation only: skip autograd bookkeeping for the whole sweep.
    with torch.no_grad():
        for s in range(0, len(src_img), batch_size):
            src_img_batch = src_img[s:s+batch_size]
            model_input = torch.stack([model_transform(x) for x in src_img_batch])
            # Loop-invariant: move the image batch to the device once.
            input_img2 = model_input.to(device)
            batch_limits = [0,120]
            gy, signal, c = applyMask(n_ww, batch_limits, mask, [0.3,0.7])

            classid = []
            targid = []
            for i in range(len(gy)):
                multi = gy[i]*input_img2
                if i==0:
                    # Show the first lit image as a visual sanity check.
                    display(transToPIL(multi[0].cpu()))
                multi2 = prepare_input_adv(multi)
                y_pred = model(multi2)[1]
                y_pred = y_pred.squeeze(-1).squeeze(-1)
                y_pred = model_class(y_pred)

                classid.append(get_prediction(y_pred, classidx)[1].item())
                targid.append(get_prediction(y_pred, targidx)[1].item())
            classid_over.append(np.array(classid).mean())
            targid_over.append(np.array(targid).mean())
    return np.array(classid_over), np.array(targid_over)
#print("source", classidx, np.array(classid_over).mean())
#print("target", targidx, np.array(targid_over).mean())

In [None]:
# (classidx, 'source' | 'target', confidence) records for the bar plot.
results = []
for res in tqdm(result_list):
    # Names are split as "<identity>_<classidx>_<targidx>.<ext>".
    name_parts = res.split("_")
    if name_parts[1]!='212': continue
    identity = name_parts[0]
    # map_location lets signals saved from a GPU session load on any device.
    n_ww = torch.load("face_rec_results/"+res, map_location=device)

    img_file = [identity+"/"+x for x in os.listdir(test_dir+identity)]
    selected_files = img_file[:1]
    src_img = fetch_images([test_dir+x for x in selected_files])
    src_img = [bounding_crop(x, bounding_dict[y.split(".")[0]]) for x,y in zip(src_img,selected_files)]

    classidx = int(name_parts[1])
    targidx = int(name_parts[2].split(".")[0])

    classid_over, targid_over = get_results(n_ww, classidx, targidx, src_img)
    results.extend((classidx, 'source', co) for co in classid_over)
    results.extend((classidx, 'target', co) for co in targid_over)

In [None]:
import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns
sns.set_context('paper')

# Build the results frame and attach the per-(class, type) mean confidence.
result_columns = ['Class_id', 'Type', 'Conf']
df = pd.DataFrame(results, columns=result_columns)
grouped_conf = df.groupby(['Class_id', 'Type'])['Conf']
df['aggr_conf'] = grouped_conf.transform('mean')
classes = df.Class_id.unique()

# No filtering is applied at the moment: the whole frame is plotted.
filtered = df
sns.barplot(data=filtered, x='Class_id', y='Conf', hue='Type',
            palette='Blues', edgecolor='w')
plt.show()

In [None]:
# Column-wise maximum for each value of 'Type'.
print(df.groupby('Type').max())

In [None]:
def _class_and_target_conf(gy_sample, input_img, classidx, targidx):
    """Classify one lit image batch; return (class_conf, target_conf) floats."""
    multi2 = prepare_input_adv(gy_sample*input_img)
    y_pred = model(multi2)[1]
    y_pred = y_pred.squeeze(-1).squeeze(-1)
    y_pred = model_class(y_pred)
    class_acc = get_prediction(y_pred, classidx)[1].item()
    targ_acc = get_prediction(y_pred, targidx)[1].item()
    return class_acc, targ_acc


def get_results_across_transformation(n_ww, classidx, targidx, src_img, shift_range, light_range):
    """Score a signal per shift and per ambient-light level.

    Args:
        n_ww: Signal tensor (moved to ``device`` here).
        classidx: Index of the original class.
        targidx: Index of the target class.
        src_img: List of PIL images.
        shift_range: ``[start, stop]`` passed to ``applyMask`` as batch limits.
        light_range: ``[low, high]`` integers; light levels are these / 10.

    Returns:
        One dict per source image with
        ``'shift'``: a ``(class_conf, target_conf)`` tuple per pattern returned
        by ``applyMask``, with light drawn from the full ``light_range``;
        ``'light'``: a tuple per integer level in ``light_range`` (inclusive),
        each the mean over all patterns at that fixed level.
    """
    n_ww = n_ww.to(device)
    results = []
    batch_size = 1
    # Evaluation only: skip autograd bookkeeping for the whole sweep.
    with torch.no_grad():
        for s in range(0, len(src_img), batch_size):
            result = {'shift':[],'light':[]}
            src_img_batch = src_img[s:s+batch_size]
            model_input = torch.stack([model_transform(x) for x in src_img_batch])
            # Loop-invariant: move the image batch to the device once.
            input_img2 = model_input.to(device)
            batch_limits = shift_range

            # Sweep over shifts, light sampled from the whole range.
            gy, signal, c = applyMask(n_ww, batch_limits, mask,
                                      [float(light_range[0])/10, float(light_range[1])/10])
            for i in range(len(gy)):
                result['shift'].append(
                    _class_and_target_conf(gy[i], input_img2, classidx, targidx))

            # Sweep over fixed light levels, averaging across shifts.
            for level in range(light_range[0], light_range[1]+1):
                gy, signal, c = applyMask(n_ww, batch_limits, mask,
                                          [float(level)/10, float(level)/10])
                per_shift = [_class_and_target_conf(gy[j], input_img2, classidx, targidx)
                             for j in range(len(gy))]
                class_ac = [pair[0] for pair in per_shift]
                targ_ac = [pair[1] for pair in per_shift]
                result['light'].append((np.array(class_ac).mean(), np.array(targ_ac).mean()))
            results.append(result)
    return results

In [None]:
# Records: (class id + image index, 'source' | 'target', shift or light, conf).
results_shift = []
results_light = []
light_range = [3, 7]
# Iterate over the sliced entries themselves.  The earlier index loop walked
# range(len(result_list[30:50])) but read result_list[r], i.e. entries 0-19.
for res in tqdm(result_list[30:50]):
    # Names are split as "<identity>_<classidx>_<targidx>.<ext>".
    name_parts = res.split("_")
    identity = name_parts[0]
    # map_location lets signals saved from a GPU session load on any device.
    n_ww = torch.load("face_rec_results/"+res, map_location=device)

    img_file = [identity+"/"+x for x in os.listdir(test_dir+identity)]
    # Pair each loaded image with its OWN file name: zipping against the full
    # listing applied the bounding box of entry 0 to image 30.
    selected_files = img_file[30:31]
    src_img = fetch_images([test_dir+x for x in selected_files])
    src_img = [bounding_crop(x, bounding_dict[y.split(".")[0]]) for x,y in zip(src_img,selected_files)]

    classidx = int(name_parts[1])
    targidx = int(name_parts[2].split(".")[0])

    result = get_results_across_transformation(n_ww, classidx, targidx, src_img, [0,120], light_range)
    for j, per_image in enumerate(result):
        image_key = str(classidx)+str(j)
        for i, (source_conf, target_conf) in enumerate(per_image['shift']):
            results_shift.append((image_key, 'source', i, source_conf))
            results_shift.append((image_key, 'target', i, target_conf))
        for i, (source_conf, target_conf) in enumerate(per_image['light']):
            # Light entries start at the lowest level of `light_range`.
            results_light.append((image_key, 'source', light_range[0]+i, source_conf))
            results_light.append((image_key, 'target', light_range[0]+i, target_conf))

In [None]:
import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns
sns.set_context('paper')

# One bar chart per transformation: first the shift sweep, then the light sweep.
for records, transform_column in ((results_shift, 'shift'), (results_light, 'light')):
    df = pd.DataFrame(records, columns=['Class_id', 'Type', transform_column, 'Conf'])
    sns.barplot(data=df, x='Class_id', y='Conf', hue='Type',
                palette='Blues', edgecolor='w')
    plt.show()

In [None]:
files_done = os.listdir("face_rec_results/")
files_all = os.listdir("/nobackup/vgg2face/test/")
# Prefixes (text before the first "_") of the finished result files, built
# once as a set instead of being rebuilt for every entry of files_all.
done_identities = {name.split("_")[0] for name in files_done}
files_left = [x for x in files_all if x not in done_identities]
random.shuffle(files_left)
print(len(files_left))



In [None]:
# Write one shell script per GPU id, each covering its own 30-entry chunk of
# files_left.  The `with` block closes the file, so no explicit close() call.
for gpu_id, chunk_start in (("2", 0), ("1", 30), ("3", 60)):
    with open("script_gpu_"+gpu_id+".sh","w") as ff:
        for f in files_left[chunk_start:chunk_start+30]:
            ff.write("python face_recog_generate.py "+f+" "+gpu_id+"\n")