In [None]:
from architectures import *
from dataloader import iqa_dataset
from config import *
from torch.utils.data import DataLoader
from torchvision import transforms
import copy

In [None]:
# Root folder of the IQA database. The trailing "/" matters: file names are appended directly.
DB_NAME = "tid2013"
db = f"../Databases/{DB_NAME}/"

In [None]:
# Scale of the ground-truth quality scores used by the dataset (e.g. 10 for TID2013).
# It is cross-checked against the database folder in the next cell.
scale = int(input("Scale of the ground-truth quality scores: "))

In [None]:
# Expected ground-truth score scale for each supported database, keyed by folder name.
DB_SCALES = {"tid2013": 10, "koniq": 5, "livew": 100}

# `db` ends with "/", so the database folder name is the second-to-last path component.
db_name = db.split("/")[-2]
scale2 = DB_SCALES.get(db_name)  # None for databases without a known scale

if scale2 is None:
    # Previously this path fell through to the comparison and raised NameError.
    print("SCALE WASNT DEFINED")
elif scale != scale2:
    print("Check scale value!")

In [None]:
batch_size = 1

# Both splits read the same score/ID pickles and only convert images to tensors,
# so the arguments are built once and every path is formed the same way.
to_tensor = transforms.Compose([transforms.ToTensor()])
split_kwargs = dict(
    labels_path=db + 'scores.pickle',
    db_path=db + 'Images/',
    ids_path=db + 'IDs.pickle',
    transform=to_tensor,
)

train_data = iqa_dataset(part='train', **split_kwargs)
val_data = iqa_dataset(part='test', **split_kwargs)

print(f"Length of Train Data : {len(train_data)}")
print(f"Length of Validation Data : {len(val_data)}")

train_dl = DataLoader(train_data, batch_size, shuffle=True)
# Validation runs one image at a time: the attack loop converts each prediction
# with float(), which only works for a single-element batch.
val_dl = DataLoader(val_data, 1, shuffle=False)


In [None]:
# Initialize the model for this run
# Build the 'inception' backbone for this run.
# NOTE(review): the meaning of the two positional flags is defined by
# architectures.initialize_model, which is not visible here — confirm there.
model = initialize_model(
    'inception',
    False,
    True,
)

In [None]:
# Wrap the backbone in the FC head whose weights are restored from the checkpoint below.
# NOTE(review): 2 / 1024 / 0.25 presumably are layer count / hidden width / dropout —
# confirm against architectures.FC.
model = FC(
    model, 2, 1024, 0.25, 'inception'
)

In [None]:
# Only the `fc` sub-module is restored from the checkpoint; the rest of the network
# keeps whatever initialize_model / FC set up.
weights_path = '../pretrained/iqaModel_tid_inception.pth'
# NOTE(review): torch.load unpickles the file — only load trusted checkpoints.
model.fc.load_state_dict(
    torch.load(weights_path, map_location=device)
)
model = model.eval()  # inference mode for evaluation and attacks

In [None]:
# Move all parameters and buffers to the compute device chosen in config.
model = model.to(device=device)

In [None]:
from attacks.fgm import fast_gradient_method 
from attacks.pgd import projected_gradient_descent 

from tqdm import tqdm
import numpy as np

from eval_funct import *

In [None]:
# Baseline: performance of the clean (unattacked) model on the validation split.
SRCC, KRCC, PLCC, RMSE, fr, mos = evaluate(
    val_dl, scale, model, normalize_imagenet, None, None
)

print("------------------- performance of the NR IQA metric: -------------------")
for metric_label, metric_value in (
    ("SRCC: ", SRCC),
    ("PLCC: ", PLCC),
    ("KRCC: ", KRCC),
    ("RMSE: ", RMSE),
):
    print(metric_label, metric_value)

In [None]:
#Define parametters and attacks
iterations = [10]
epsilons = [0.001]#,0.01,0.1]
attacks = ["pgd"]#["bim","pgd","fgm"]
losses = ["mse(y_tielda,y)"]#,'mse','sqr(max-y)']

In [None]:
#Initialization
perf = dict()
epsilon_dict = dict()
iter_dict = dict()
attack_dict = dict()
loss_dict = dict()
results = dict()

In [None]:
# Snapshot the clean-model metrics (as strings, ready for YAML) before any attack runs.
perf.update(
    srcc=str(SRCC),
    plcc=str(PLCC),
    krcc=str(KRCC),
    rmse=str(RMSE),
)
results['original_performance'] = copy.deepcopy(perf)

In [None]:
# LPIPS full-reference metric, recorded per image as the 'lpips' distance between the
# clean and adversarial images. as_loss=True is pyiqa's differentiable/loss mode.
import pyiqa

fr_metric = pyiqa.create_metric(
    'lpips',
    as_loss=True,
    device=device,
)

In [None]:
# Attack every validation image for each (loss, attack, iteration count, epsilon)
# combination. For each combination, record:
#   - correlation of the adversarial predictions with the ground-truth MOS;
#   - the LPIPS distance between clean and adversarial images.
SEED = 0
PGD_STEP_SIZE = 0.001  # PGD per-iteration step (eps_iter); independent of epsilon
torch.manual_seed(SEED)             # PGD starts from a random point (rand_init=True)
rng = np.random.default_rng(SEED)   # sampling of the noisy MOS estimate below

loss_fct = torch.nn.MSELoss()
# Target for targeted losses: a normalized score of 1, shape (1, 1).
# Presumably this matches the model output for batch size 1 — TODO confirm.
y_max_target = torch.ones(1, 1, device=device)
y_target = y_max_target
# Std of the noisy MOS estimate used by 'mse(y_tielda,y)': 3x the clean-model RMSE.
mos_noise_std = 3 * float(results['original_performance']['rmse'])

loss_dict = dict()
for loss in losses:
    # 'mse' and 'mse(y_tielda,y)' push the prediction away from a reference score
    # (untargeted). Every other loss pulls it towards y_max_target (targeted).
    targeted = loss not in ('mse', 'mse(y_tielda,y)')
    attack_dict = dict()  # fresh per loss so results never leak between losses
    for attack in attacks:
        iter_dict = dict()  # fresh per attack so e.g. FGM never inherits PGD entries
        # FGM is single-step, so only the first iteration count is evaluated for it.
        attack_iterations = iterations[:1] if attack == 'fgm' else iterations
        for it in attack_iterations:
            epsilon_dict = dict()
            for epsilon in epsilons:
                y_adv_list = []   # predictions on adversarial images, in MOS units
                y = []            # ground-truth MOS
                lpips_list = []   # LPIPS(clean, adversarial) per image

                for im, label in tqdm(val_dl):
                    im = im.to(device)

                    if targeted:
                        # Reset every time: untargeted branches overwrite y_target.
                        y_target = y_max_target
                    elif loss == 'mse':
                        # Ground-truth MOS available: use it, normalized by `scale`.
                        y_target = torch.unsqueeze(label.float().to(device) / scale, 0)
                    else:  # loss == 'mse(y_tielda,y)'
                        # No MOS assumed: estimate it as the clean prediction plus Gaussian
                        # noise (mean of 10 draws), then normalize by `scale`.
                        with torch.no_grad():
                            y_pred = float(model(normalize_imagenet(im)).cpu()) * scale
                        mos_estimate = rng.normal(y_pred, mos_noise_std, 10).mean()
                        y_target = torch.tensor(
                            [[mos_estimate / scale]], dtype=torch.float32, device=device
                        )

                    y.append(float(label.detach().cpu()))

                    if attack == "fgm":
                        img_adv = fast_gradient_method(model, im, epsilon, np.inf,
                                                       preprocess=normalize_imagenet, y=y_target,
                                                       loss_fn=loss_fct, targeted=targeted)
                    elif attack == "pgd":
                        img_adv = projected_gradient_descent(model, im, epsilon, eps_iter=PGD_STEP_SIZE,
                                                             preprocess=normalize_imagenet,
                                                             nb_iter=it, norm=np.inf, y=y_target,
                                                             loss_fn=loss_fct, targeted=targeted,
                                                             rand_init=True)
                    elif attack == "bim":
                        # NOTE(review): steps are not projected back onto the epsilon-ball,
                        # so the total perturbation may reach it * epsilon — confirm intended.
                        img_adv = torch.clone(im)
                        for _ in range(it):
                            img_adv = fast_gradient_method(model, img_adv, epsilon, np.inf,
                                                           y=y_target, loss_fn=loss_fct,
                                                           targeted=targeted,
                                                           preprocess=normalize_imagenet)
                    else:
                        raise ValueError(f"Unknown attack: {attack}")

                    # Evaluation only: no autograd graph needed.
                    with torch.no_grad():
                        y_adv_list.append(float(model(normalize_imagenet(img_adv)).cpu()) * scale)
                        lpips_list.append(float(fr_metric(im, img_adv).mean()))

                # Local names keep the clean-model SRCC/KRCC/PLCC/RMSE globals intact.
                adv_srcc, adv_krcc, adv_plcc, adv_rmse = compute_metrics(y, y_adv_list)
                perf = {
                    'srcc': str(adv_srcc),
                    'krcc': str(adv_krcc),
                    'plcc': str(adv_plcc),
                    'rmse': str(adv_rmse),
                    'lpips': str(np.mean(lpips_list)),
                }
                epsilon_dict[str(epsilon)] = perf
                print(f'{loss}\n {attack}\n  {it}\n   {epsilon_dict}')

            iter_dict[str(it)] = epsilon_dict
        attack_dict[attack] = iter_dict
    loss_dict[loss] = attack_dict

results['results'] = loss_dict
    



In [None]:
#Save stats and performance measures
# Save the clean and adversarial performance measures.
import yaml

# Name the file after what was actually evaluated: TID2013 with the Inception backbone
# (the previous name, tid_resnet.yaml, mislabelled the results).
f = 'tid_inception.yaml'
with open(f, 'w') as outfile:
    # `results` holds only nested dicts with string keys/values, so safe_dump is sufficient
    # and guarantees no Python-specific YAML tags end up in the file.
    yaml.safe_dump(results, outfile, default_flow_style=False)