In [3]:
import torch
from PIL import Image
from torchvision import transforms
import os
import difflib
import torch.nn as nn

In [4]:
# Define AlexNet model
class AlexNet(nn.Module):
    def __init__(self, num_classes=1000):
        super(AlexNet, self).__init__()
        self.features = nn.Sequential(
            nn.Conv2d(3, 64, kernel_size=11, stride=4, padding=2),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2),
            nn.Conv2d(64, 192, kernel_size=5, padding=2),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2),
            nn.Conv2d(192, 384, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(384, 256, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(256, 256, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2),
        )
        self.classifier = nn.Sequential(
            nn.Dropout(),
            nn.Linear(256 * 6 * 6, 4096),
            nn.ReLU(inplace=True),
            nn.Dropout(),
            nn.Linear(4096, 4096),
            nn.ReLU(inplace=True),
            nn.Linear(4096, num_classes),
        )

    def forward(self, x):
        x = self.features(x)
        x = torch.flatten(x, 1)
        x = self.classifier(x)
        return x

# Create an instance of the model
model = AlexNet()

In [5]:
# Download and load the pretrained weights
weights_url = "https://download.pytorch.org/models/alexnet-owt-4df8aa71.pth"
weights_path = "alexnet.pth"
torch.hub.download_url_to_file(weights_url, weights_path)
model.load_state_dict(torch.load(weights_path))
model.eval()

  0%|          | 0.00/233M [00:00<?, ?B/s]

AlexNet(
  (features): Sequential(
    (0): Conv2d(3, 64, kernel_size=(11, 11), stride=(4, 4), padding=(2, 2))
    (1): ReLU(inplace=True)
    (2): MaxPool2d(kernel_size=3, stride=2, padding=0, dilation=1, ceil_mode=False)
    (3): Conv2d(64, 192, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
    (4): ReLU(inplace=True)
    (5): MaxPool2d(kernel_size=3, stride=2, padding=0, dilation=1, ceil_mode=False)
    (6): Conv2d(192, 384, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (7): ReLU(inplace=True)
    (8): Conv2d(384, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (9): ReLU(inplace=True)
    (10): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (11): ReLU(inplace=True)
    (12): MaxPool2d(kernel_size=3, stride=2, padding=0, dilation=1, ceil_mode=False)
  )
  (classifier): Sequential(
    (0): Dropout(p=0.5, inplace=False)
    (1): Linear(in_features=9216, out_features=4096, bias=True)
    (2): ReLU(inplace=True)
    (3): Dropout(p=0.5, 

In [6]:
# Define the preprocessing function
preprocess = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

In [7]:
# Download ImageNet labels
!wget -q https://raw.githubusercontent.com/pytorch/hub/master/imagenet_classes.txt

# Read the categories
with open("imagenet_classes.txt", "r") as f:
    categories = [s.strip() for s in f.readlines()]

In [8]:
# Define the function to classify multiple images
def classify_images(image_folder, image_names):
    results = []
    
    for image_name in image_names:
        image_path = os.path.join(image_folder, image_name)
        try:
            input_image = Image.open(image_path)
            input_tensor = preprocess(input_image)
            input_batch = input_tensor.unsqueeze(0)  # create a mini-batch as expected by the model

            # Move the input and model to GPU for speed if available
            if torch.cuda.is_available():
                input_batch = input_batch.to('cuda')
                model.to('cuda')
            elif torch.backends.mps.is_available():
                input_batch = input_batch.to('mps')
                model.to('mps')

            with torch.no_grad():
                output = model(input_batch)

            # Get probabilities
            probabilities = torch.nn.functional.softmax(output[0], dim=0)

            # Get top 1 category
            top1_prob, top1_catid = torch.topk(probabilities, 1)
            
            # Append the recognized class and its probability to the results list
            recognized_class = categories[top1_catid[0]]
            probability = top1_prob[0].item()
            results.append((image_name, recognized_class, probability))
        except Exception as e:
            print(f"Error processing image {image_name}: {e}")

    return results

In [9]:
# Define the function to check similarity
def is_similar(name1, name2, threshold=0.5):
    similarity_ratio = difflib.SequenceMatcher(None, name1, name2).ratio()
    return similarity_ratio >= threshold

In [10]:
# Example usage
image_folder = "/Users/baonguyen/Documents/GitHub/240616 AlexNet/Test Dataset"

# Get a list of image names
image_names = [img for img in os.listdir(image_folder) if img.lower().endswith(('.jpg', '.jpeg', '.png'))]

# Ensure we have up to 100 images
image_names = image_names[:100]

In [11]:
# Check if there are images to process
if not image_names:
    print("No images found in the specified directory.")
else:
    results = classify_images(image_folder, image_names)
    
    total = 0
    correct = 0

    # Print the results and compare file name with recognized class
    for image_name, recognized_class, probability in results:
        # Remove file extension from the image name for comparison
        base_image_name = os.path.splitext(image_name)[0]
        total += 1
        
        if is_similar(base_image_name, recognized_class, 0.5):
            result = ""
            correct += 1
        else:
            result = "INCORRECT"
        
        print(f"Image: {image_name}, Recognized class: {recognized_class}, Probability: {probability*100:.2f}%, Result: {result}")


[W NNPACK.cpp:51] Could not initialize NNPACK! Reason: Unsupported hardware.


Image: n07615774_ice_lolly.JPEG, Recognized class: conch, Probability: 50.07%, Result: INCORRECT
Image: n04005630_prison.JPEG, Recognized class: prison, Probability: 80.27%, Result: 
Image: n04147183_schooner.JPEG, Recognized class: flagpole, Probability: 47.12%, Result: INCORRECT
Image: n07613480_trifle.JPEG, Recognized class: ice cream, Probability: 74.12%, Result: INCORRECT
Image: n03777568_Model_T.JPEG, Recognized class: Model T, Probability: 99.28%, Result: 
Image: n03100240_convertible.JPEG, Recognized class: convertible, Probability: 78.50%, Result: 
Image: n02112350_keeshond.JPEG, Recognized class: keeshond, Probability: 96.63%, Result: 
Image: n03110669_cornet.JPEG, Recognized class: cornet, Probability: 99.22%, Result: 
Image: n01776313_tick.JPEG, Recognized class: tick, Probability: 90.36%, Result: INCORRECT
Image: n03903868_pedestal.JPEG, Recognized class: pedestal, Probability: 84.70%, Result: 
Image: n02100877_Irish_setter.JPEG, Recognized class: Irish setter, Probability

In [12]:
print(f"Percentage correct: {correct*100/total:.2f}%")

Percentage correct: 72.00%
