Dans ce script, nous allons tester différents modèles de type convolution Nelle pour examiner leur performance. Nous allons ensuite tester s'ils sont robustes lorsqu'on applique une rotation à l'image d'entrée.

In [1]:
import retinoto_py as fovea

args = fovea.Params()
args

Running on MPS device (Apple Silicon/MacOS) - macos_version = 26.1
Random seed 1998 has been set.
Welcome on macOS-26.1-arm64-arm-64bit-Mach-O	User laurentperrinet Working on host obiwan.local with device mps, pytorch==2.9.1


Params(image_size=224, num_epochs=5, n_train_stop=0, seed=1998, batch_size=250, verbose=True)

In [2]:
args

Params(image_size=224, num_epochs=5, n_train_stop=0, seed=1998, batch_size=250, verbose=True)

# testing each network on the validation dataset

In [3]:
import torch
import torchvision.models as models


VAL_IMAGE_DIR = args.DATAROOT / 'Imagenet_full' / 'train'

# --- 3. Load the Pre-trained ResNet Model ---
# We'll use ResNet50, a powerful and common choice.
# `pretrained=True` downloads the model weights trained on ImageNet.
model = models.resnet50(weights=models.ResNet50_Weights.DEFAULT)

# Move the model to the selected device (GPU or CPU)
model = model.to(args.device)

# Set the model to evaluation mode.
model.eval()

ResNet(
  (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU(inplace=True)
  (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (layer1): Sequential(
    (0): Bottleneck(
      (conv1): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn3): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (downsample): Sequential(
        (0): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 

In [4]:
extensions = ['*.jpeg', '*.jpg', '*.png', '*.JPEG', '*.JPG', '*.PNG']

image_files = []
for ext in extensions:
    # rglob returns a generator, so we extend the list with its results
    image_files.extend(VAL_IMAGE_DIR.rglob(ext))

print(f'In folder {VAL_IMAGE_DIR}, I found {len(image_files)} images')

In folder /Users/laurentperrinet/data/Imagenet_full/train, I found 1281167 images


In [5]:
import torchvision.transforms as transforms
from PIL import Image
import os
import requests
import json # Don't forget to import json

# --- 4. Download and Load the ImageNet Class Index (with caching) ---
LABELS_URL = 'https://s3.amazonaws.com/deep-learning-models/image-models/imagenet_class_index.json'
LABELS_FILE = args.data_cache / 'imagenet_class_index.json' # Local cache file name

try:
    # Check if we already have the file
    if not os.path.exists(LABELS_FILE):
        print(f"Downloading labels to {LABELS_FILE}...")
        response = requests.get(LABELS_URL)
        response.raise_for_status()
        with open(LABELS_FILE, 'w') as f:
            json.dump(response.json(), f)
    else:
        print(f"Loading labels from local cache {LABELS_FILE}...")
        
    # In both cases, load from the local file
    with open(LABELS_FILE, 'r') as f:
        class_idx = json.load(f)

    # Create a simple mapping from index to class name for easy lookup
    idx2label = [class_idx[str(k)][1] for k in range(len(class_idx))]

except requests.exceptions.RequestException as e:
    print(f"Error downloading labels: {e}")
    exit()
except (IOError, json.JSONDecodeError) as e:
    print(f"Error handling local label file: {e}")
    exit()
print(f'Got a list with {len(idx2label)} labels in {LABELS_FILE} ')

Loading labels from local cache cached_data/imagenet_class_index.json...
Got a list with 1000 labels in cached_data/imagenet_class_index.json 


In [6]:
# --- 5. Define Image Pre-processing ---
# The images must be pre-processed in the exact same way the model was trained on.
# This includes resizing, cropping, and normalizing.
preprocess = transforms.Compose([
    transforms.Resize(256),                # Resize the shortest side to 256px
    transforms.CenterCrop(224),            # Crop the center 224x224 pixels
    transforms.ToTensor(),                 # Convert the image to a PyTorch Tensor
    transforms.Normalize(                  # Normalize with ImageNet mean and std
        mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225]
    )
])


In [10]:

print(f"\nStarting inference on images in '{VAL_IMAGE_DIR}'...")
print("-" * 50)


for image_file in image_files:

    try:
        # Load the image
        img = Image.open(image_file).convert('RGB')

        # Pre-process the image and add a batch dimension
        # PyTorch models expect a batch of images, so we unsqueeze a dimension.
        img_t = preprocess(img)
        batch_t = torch.unsqueeze(img_t, 0).to(args.device)

        # Run inference
        # `torch.no_grad()` tells PyTorch not to compute gradients, saving memory and computation
        with torch.no_grad():
            output = model(batch_t)

        # --- Get the Prediction ---
        # The output is a tensor of logits. We find the index of the highest logit.
        # `torch.max` returns the max value and its index. We only need the index.
        _, index = torch.max(output, 1)
        
        # Move the result to CPU and convert it to a Python integer
        predicted_index = index[0].item()
        
        # Get the human-readable label from our mapping
        predicted_label = idx2label[predicted_index]

        # --- Print the result ---
        print(f"Image: {image_file.name:<30} | Predicted: {predicted_label}")

    except Exception as e:
        print(f"Could not process {image_file.name}: {e}")

print("-" * 50)
print("Inference complete.")


Starting inference on images in '/Users/laurentperrinet/data/Imagenet_full/train'...
--------------------------------------------------
Image: n04542943_2211.JPEG            | Predicted: waffle_iron
Image: n04542943_486.JPEG             | Predicted: waffle_iron
Image: n04542943_4795.JPEG            | Predicted: waffle_iron
Image: n04542943_6252.JPEG            | Predicted: waffle_iron
Image: n04542943_1597.JPEG            | Predicted: waffle_iron
Image: n04542943_1082.JPEG            | Predicted: waffle_iron
Image: n04542943_5868.JPEG            | Predicted: waffle_iron
Image: n04542943_12684.JPEG           | Predicted: waffle_iron
Image: n04542943_2704.JPEG            | Predicted: waffle_iron
Image: n04542943_10590.JPEG           | Predicted: waffle_iron
Image: n04542943_8535.JPEG            | Predicted: waffle_iron
Image: n04542943_5307.JPEG            | Predicted: waffle_iron
Image: n04542943_2991.JPEG            | Predicted: waffle_iron
Image: n04542943_7947.JPEG            | Pred

KeyboardInterrupt: 