In [1]:
from transformers import AutoImageProcessor, ResNetForImageClassification, DetrImageProcessor, DetrForObjectDetection
import torch
import os
from PIL import Image

In [2]:
import re

def extract_names_from_file_name(file_name):
    """Return the two subject names encoded in an image file name.

    The name must contain exactly one occurrence of the pattern
    ``<name1>_<digits><char>+<name2>_<digits><char>`` (e.g. ``dog_1a+cat_2b.png``),
    where each name is a run of non-underscore characters.

    Returns:
        ``[name1, name2]`` when exactly one such occurrence is found, otherwise ``[]``.
    """
    found = re.findall(r'([^_]+)_\d+\w\+([^_]+)_\d+\w', file_name)
    # Zero or several occurrences are both treated as "cannot tell".
    if len(found) != 1:
        return []
    first_name, second_name = found[0]
    return [first_name, second_name]

def get_file_list(directory):
    """List the entries of ``directory`` with the subject names parsed from each name.

    Entries are returned in sorted order. ``os.listdir`` returns them in an arbitrary,
    platform-dependent order, which would make any later (even seeded) shuffle
    non-reproducible.

    Returns:
        (file_list, answers_list): paths joined with ``directory`` using forward
        slashes, and for each path the list returned by
        ``extract_names_from_file_name`` (``[]`` when the name does not match).
    """
    file_list = []
    answers_list = []
    for file in sorted(os.listdir(directory)):
        answers_list.append(extract_names_from_file_name(file))
        # Normalise Windows backslash separators to forward slashes.
        file_list.append(os.path.join(directory, file).replace("\\", "/"))
    return file_list, answers_list

In [3]:
import random

SEED = 0  # fixed seed so the shuffled order (and hence the first N images used) is reproducible

file_list, answers_list = get_file_list('../res/generated')

# Keep only files whose names encode both subjects: the tests index answers[0] and
# answers[1], so an unparsable name (e.g. a stray non-image file) would raise IndexError.
labelled_files = [(f, a) for f, a in zip(file_list, answers_list) if len(a) == 2]
n_skipped = len(file_list) - len(labelled_files)
if n_skipped:
    print(f"Skipped {n_skipped} files whose names do not encode two subjects")

# randomize list with a private seeded RNG (does not disturb the global `random` state)
random.Random(SEED).shuffle(labelled_files)

# Unzip explicitly: `zip(*labelled_files)` would fail to unpack an empty list.
file_list = tuple(f for f, _ in labelled_files)
answers_list = tuple(a for _, a in labelled_files)

In [None]:
N = 10000  # number of images to open

# Load the first N (shuffled) images. A file that fails to open is skipped together
# with its answers; otherwise every answer after the failure would be paired with
# the wrong image in the tests, which index answers_list by image position.
images_list = []
loaded_files = []
loaded_answers = []
for file_name, answers in zip(file_list[:N], answers_list[:N]):
    try:
        # Image.open is lazy and keeps the file handle open until the data is read.
        # Reading inside `with` and keeping a copy releases the handle, so opening
        # thousands of files does not exhaust file descriptors.
        # NOTE(review): images are passed to the processors as-is; if some are
        # RGBA/grayscale, consider .convert("RGB") here — confirm with the dataset.
        with Image.open(file_name) as image:
            image.load()
            images_list.append(image.copy())
    except OSError:  # IOError is an alias of OSError; also covers UnidentifiedImageError
        print(f"Failed to open {file_name}")
        continue
    loaded_files.append(file_name)
    loaded_answers.append(answers)

# Rebind both lists to the loaded subset so answers_list[i] matches images_list[i];
# re-running this cell then reproduces the same selection.
file_list, answers_list = tuple(loaded_files), tuple(loaded_answers)

In [None]:
class ResNet:
    """ImageNet classifier (microsoft/resnet-50) checked against expected subject names."""

    def __init__(self):
        # Load the pretrained processor and model via from_pretrained.
        self.processor = AutoImageProcessor.from_pretrained("microsoft/resnet-50")
        self.model = ResNetForImageClassification.from_pretrained("microsoft/resnet-50")

    def test(self, images_list, answers_list, top_k=5):
        """Classify each image and report whether its expected subjects are in the top-k classes.

        Args:
            images_list: images accepted by the processor; images_list[i] is checked
                against answers_list[i].
            answers_list: per-image pair of expected subject names (indices 0 and 1 are read).
            top_k: number of most probable classes inspected per image (default 5).

        Prints one line per image plus a summary with the number of subjects found
        (each subject counted at most once per image).
        """
        print("=============================================== ResNet50 Tests =================================================================")
        tot_found = 0
        for i, image in enumerate(images_list):
            inputs = self.processor(image, return_tensors="pt")

            with torch.no_grad():
                logits = self.model(**inputs).logits
            # model predicts one of the ImageNet classes. torch.argsort is ascending,
            # so the most probable classes are taken with topk (descending scores);
            # k is clamped for models with fewer classes than top_k.
            k = min(top_k, logits.shape[-1])
            top_ids = torch.topk(logits, k, dim=-1).indices[0].tolist()
            classifications = []
            found_1 = False
            found_2 = False
            result = f"Answers = {answers_list[i]};\t\t\tFound: "

            for rank, class_id in enumerate(top_ids, start=1):
                # an id2label entry may list several comma-separated synonyms
                labels = self.model.config.id2label[class_id].split(", ")
                classifications.extend(labels)
                # Only the best rank is reported, even if a name appears in two classes.
                if not found_1 and answers_list[i][0] in labels:
                    found_1 = True
                    result += f"{answers_list[i][0]} as n.{rank} with highest probability; "
                if not found_2 and answers_list[i][1] in labels:
                    found_2 = True
                    result += f"{answers_list[i][1]} as n.{rank} with highest probability; "
            tot_found += found_1 + found_2
            if not found_1 and not found_2:
                result += f"NONE (classified as {classifications[:3]});"

            print(result)
        print("=================================================================================================================================")
        print(f"Images analysed: {len(images_list)}; subjects recognized: {tot_found}.")
        print("=================================================================================================================================\n\n")

class Detr_ResNet:
    """Object detector (facebook/detr-resnet-50) checked against expected subject names."""

    def __init__(self):
        # Load the pretrained processor and model via from_pretrained.
        self.processor = DetrImageProcessor.from_pretrained("facebook/detr-resnet-50")
        self.model = DetrForObjectDetection.from_pretrained("facebook/detr-resnet-50")

    def test(self, images_list, answers_list, threshold=0.9):
        """Run detection on each image and report which expected subjects were detected.

        Args:
            images_list: images with a PIL-style ``size`` of (width, height);
                images_list[i] is checked against answers_list[i].
            answers_list: per-image pair of expected subject names (indices 0 and 1 are read).
            threshold: score threshold passed to post_process_object_detection (default 0.9).

        Prints one line per image and a summary with how many images had at least one /
        both subjects detected (percentages only when images_list is non-empty).
        """
        one_found = 0
        both_found = 0
        print("============================================== DETR-ResNet50 Tests ==============================================================")
        for i, image in enumerate(images_list):
            inputs = self.processor(image, return_tensors="pt")

            # Inference only: no_grad avoids building the autograd graph.
            with torch.no_grad():
                outputs = self.model(**inputs)

            found_1 = False
            found_2 = False
            result = f"Answers = {answers_list[i]};\t\t\tFound: "

            # convert outputs (bounding boxes and class logits) to COCO API,
            # keeping only detections whose score passes `threshold`.
            # PIL size is (width, height); target_sizes is given as (height, width).
            target_sizes = torch.tensor([image.size[::-1]])
            output = self.processor.post_process_object_detection(outputs, target_sizes=target_sizes, threshold=threshold)[0]
            labels = []
            scores = []
            for score, label in zip(output["scores"], output["labels"]):
                output_label = self.model.config.id2label[label.item()]
                confidence = round(score.item(), 3)
                labels.append(output_label)
                scores.append(confidence)
                if answers_list[i][0] == output_label:
                    found_1 = True
                    result += f"{answers_list[i][0]} (detected with confidence {confidence}); "
                if answers_list[i][1] == output_label:
                    found_2 = True
                    result += f"{answers_list[i][1]} (detected with confidence {confidence}); "
            # found_1/found_2 are True exactly when the subject appears in `labels`.
            if found_1 or found_2:
                one_found += 1
            if found_1 and found_2:
                both_found += 1

            if not found_1 and not found_2:
                result += f"NONE ({labels} detected with confidence {scores});"
            print(result)
        n_images = len(images_list)
        print("=================================================================================================================================")
        print(f"Images analysed: {n_images}; at least one subject recognized: {one_found}; both subjects recognized: {both_found}.")
        # Guarded: an empty images_list would otherwise raise ZeroDivisionError.
        if n_images:
            print(f"\t\t     at least one subject recognized: {one_found/n_images*100}%; both subjects recognized: {both_found/n_images*100}%.")
        print("=================================================================================================================================\n\n")

In [None]:
# Instantiate both wrappers; each __init__ loads its pretrained processor and model via from_pretrained.
resnet = ResNet()
detr_resnet = Detr_ResNet()

In [None]:
# The ResNet50 classification test is disabled by default; set to True to run it as well.
RUN_RESNET_TEST = False

if RUN_RESNET_TEST:
    resnet.test(images_list, answers_list)
detr_resnet.test(images_list, answers_list)