In [81]:
from transformers import AutoImageProcessor, ResNetForImageClassification
import torch
import os
from PIL import Image

In [82]:
import re

def extract_names_from_file_name(file_name):
    regex = r'([^_]+)_\d+\w\+([^_]+)_\d+\w'
    matches = re.findall(regex, file_name)

    if matches and len(matches) == 1:
        answer1, answer2 = matches[0]
        return [answer1, answer2]

    return []

def get_file_list(directory):
    answers_list = []
    file_list = []
    for file in os.listdir(directory):
        answers_list.append(extract_names_from_file_name(file))
        # replace is for windows
        file_list.append(os.path.join(directory, file).replace("\\","/"))
    return file_list, answers_list
        
    

In [83]:
file_list, answers_list = get_file_list('../res/generated')
# print(file_list[:10], answers_list[:10])

['../res/generated/airplane_01b+banana_01b.jpg', '../res/generated/airplane_01b+banana_02s.jpg', '../res/generated/airplane_01b+banana_03s.jpg', '../res/generated/airplane_01b+banana_04s.jpg', '../res/generated/airplane_01b+banana_05s.jpg', '../res/generated/airplane_01b+banana_06s.jpg', '../res/generated/airplane_01b+banana_07s.jpg', '../res/generated/airplane_01b+banana_08s.jpg', '../res/generated/airplane_01b+banana_09s.jpg', '../res/generated/airplane_01b+banana_10s.jpg'] [['airplane', 'banana'], ['airplane', 'banana'], ['airplane', 'banana'], ['airplane', 'banana'], ['airplane', 'banana'], ['airplane', 'banana'], ['airplane', 'banana'], ['airplane', 'banana'], ['airplane', 'banana'], ['airplane', 'banana']]


In [85]:
N = 10 # number of images to open

image_list = []
for file_name in file_list[:N]:
    # Open the image file
    try:
        image = Image.open(file_name)
        image_list.append(image)
    except IOError:
        print(f"Failed to open {file_name}")

In [101]:
class ResNet:
    def __init__(self):
        self.processor = AutoImageProcessor.from_pretrained("microsoft/resnet-50")
        self.model = ResNetForImageClassification.from_pretrained("microsoft/resnet-50")

    def test(self, image_list, answers_list):
        print("========================= ResNet50 Test ==============================")
        for i, image in enumerate(image_list):
            inputs = self.processor(image, return_tensors="pt")
            
            with torch.no_grad():
                logits = self.model(**inputs).logits
            # model predicts one of the 1000 ImageNet classes
            # predicted_label = logits.argmax(-1).item()
            predicted_label = torch.argsort(logits)
            classifications = []
            found_1 = False
            found_2 = False
            result = "Found: "

            # tries the 5 most probable inferences
            for j in range(0,5):
                labels = self.model.config.id2label[predicted_label[0][j].item()].split(", ")
                for label in labels:
                    classifications.append(label)
                if answers_list[i][0] in labels:
                    found_1 = True 
                    found += f"{answer_list[i][0]} as n.{j+1} with highest probability; "
                if answers_list[i][1] in labels:
                    found_2 = True
                    found += f"{answer_list[i][1]} as n.{j+1} with highest probability; "
            if not found_1 and not found_2:
                result += " None."
            print(result)           
        print("=====================================================================")

In [102]:
resnet = ResNet()
resnet.test(image_list, answers_list)

Could not find image processor class in the image processor config or the model config. Loading based on pattern matching with the model's feature extractor configuration.


Found:  None.
Found:  None.
Found:  None.
Found:  None.
Found:  None.
Found:  None.
Found:  None.
Found:  None.
Found:  None.
Found:  None.
