# Test

## Group 29

1. Aizhigit MUSSALI
2. Ralif Rakhmatullin
3. Zhanaidar MUKANOV

In [1]:
import os
import glob
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
import torchvision.transforms as T
from torchvision.models import resnet50
import numpy as np

from PIL import Image
import matplotlib.pyplot as plt
from sklearn.metrics import accuracy_score, recall_score, precision_score, roc_auc_score

In [2]:
class DeepFakeDataset(Dataset):
    def __init__(self, imgs, transforms = None):
        super().__init__()
        self.imgs = imgs
        self.transforms = transforms
        
    def __getitem__(self, idx):
        image_name = self.imgs[idx]
        img = Image.open(image_name)
        img = img.resize((224, 224))
        if image_name.split(".")[1].split("/")[3] == "original": label = 1 # it is original
        else: label = 0 # it is fake
        label = torch.tensor(label, dtype = torch.float32)
        img = self.transforms(img)
        return img, label
            
        
    def __len__(self):
        return len(self.imgs)
    

In [3]:
imgs = glob.glob("./data/test/*")
dataset = DeepFakeDataset(imgs, T.Compose([
            T.ToTensor(),
            T.Normalize((0, 0, 0),(1, 1, 1))
        ]))
dataloader = DataLoader(dataset = dataset, batch_size = 32, shuffle = True)

In [4]:
def test(model, data_path):
    Accuracy, Recall, Precision, AUC = 0, 0, 0, 0
    
    model = resnet50(pretrained = True)
    model.fc = nn.Sequential(nn.Linear(2048, 1, bias = True), nn.Sigmoid())
    model.load_state_dict(torch.load("resnet50_best.pth", map_location=torch.device('cpu')))
    model.eval()
    
    preds_list = []
    labels_list = []

    for imgs, labels in dataloader:
#         images.to(device)
#         labels.to(device)
        labels.reshape((labels.shape[0], 1))
    
        preds = model(imgs)
        preds_list.extend(preds.tolist())
        labels_list.extend(labels.tolist())
    
    labels_list, preds_list = np.array(labels_list), np.array(preds_list)
    preds_list = preds_list > 0.5
    Accuracy = accuracy_score(labels_list, preds_list)
    Recall = recall_score(labels_list, preds_list)
    Precision = precision_score(labels_list, preds_list)
    AUC = roc_auc_score(labels_list, preds_list)
    return Accuracy, Recall, Precision, AUC

In [5]:
result = test("./resnet50_best.pth", "./data/test")
print(f"")

(0.9827127659574468, 0.98, 0.9683794466403162, 0.9820318725099602)
