Dans ce script, nous allons tester différents modèles de type convolution Nelle pour examiner leur performance. Nous allons ensuite tester s'ils sont robustes lorsqu'on applique une rotation à l'image d'entrée.

In [7]:
import retinoto_py as fovea

args = fovea.Params(batch_size=1)
args

Params(image_size=224, num_epochs=5, n_train_stop=0, seed=1998, batch_size=1, model_name='resnet50', do_scratch=False, verbose=True)

# testing each network on the validation dataset

In [8]:
import torch
VAL_DATA_DIR = args.DATAROOT / 'Imagenet_full' / 'val'

In [9]:
from retinoto_py import get_idx_to_label
idx_to_label = get_idx_to_label(args)
idx_to_label[0]

Loading labels from local cache cached_data/imagenet_class_index.json...


'tench'

In [None]:
from retinoto_py.utils import get_loader
val_loader, class_to_idx, idx_to_class = get_loader(args, VAL_DATA_DIR)

In [None]:
# Load the Pre-trained ResNet Model ---
import torchvision.models as models
# We'll use ResNet50, a powerful and common choice.
# `pretrained=True` downloads the model weights trained on ImageNet.
model = models.resnet50(weights=models.ResNet50_Weights.DEFAULT)

# Move the model to the selected device (GPU or CPU)
model = model.to(args.device)

# Set the model to evaluation mode.
model.eval()

# Make sure to load a model trained on the same number of classes
num_classes = len(val_loader.dataset.classes)
num_ftrs = model.fc.out_features
print(f'Model has {num_ftrs} output features to final FC layer for {num_classes} classes.')


Model has 1000 output features to final FC layer for 1000 classes.


In [6]:
from tqdm.auto import tqdm

# --- 4. Evaluation Loop ---
correct_predictions = 0
total_predictions = 0

print(f"Starting evaluation on {len(val_loader.dataset)} images...")
print("-" * 50)

# We use tqdm to create a nice progress bar
for images, true_labels in tqdm(val_loader, desc="Evaluating"):
    # Move data to the correct device
    images = images.to(args.device)
    true_labels = true_labels.to(args.device)

    # Get predictions (no need for gradients)
    with torch.no_grad():
        outputs = model(images)
        _, predicted_labels = torch.max(outputs, 1)

    # --- Check for Correctness ---
    # The batch size is 1, so we get the first element of the tensors
    predicted_label_idx = predicted_labels[0].item()
    true_label_idx = true_labels[0].item()

    # Get the human-readable names (folder names)
    predicted_class_name = idx_to_class[predicted_label_idx]
    true_class_name = idx_to_class[true_label_idx]

    # Check if the prediction was correct
    is_correct = (predicted_label_idx == true_label_idx)
    if is_correct:
        correct_predictions += 1
    total_predictions += 1

    # Get the image filename for printing
    # The DataLoader doesn't easily give filenames, so we can't print them without more work.
    # For this example, we'll just print the class names.
    status = "✅ Correct" if is_correct else "❌ Incorrect"
    if torch.rand(1) < .0001: print(f"True: {true_class_name:<20} ({idx_to_label[true_label_idx]}) | Predicted: {predicted_class_name:<20} (({idx_to_label[predicted_label_idx]})) | {status}")
        # # Get the human-readable label from our mapping
        # predicted_label = idx2label[predicted_index]

print("-" * 50)
accuracy = 100 * correct_predictions / total_predictions
print(f"Evaluation complete.")
print(f"Accuracy: {correct_predictions}/{total_predictions} ({accuracy:.2f}%)")

Starting evaluation on 50000 images...
--------------------------------------------------


Evaluating:   0%|          | 0/50000 [00:00<?, ?it/s]

True: n02113624            (toy_poodle) | Predicted: n02098413            ((Lhasa)) | ❌ Incorrect
True: n02690373            (airliner) | Predicted: n02690373            ((airliner)) | ✅ Correct
True: n03223299            (doormat) | Predicted: n02110185            ((Siberian_husky)) | ❌ Incorrect
True: n04525305            (vending_machine) | Predicted: n04525305            ((vending_machine)) | ✅ Correct
True: n13052670            (hen-of-the-woods) | Predicted: n07718747            ((artichoke)) | ❌ Incorrect
--------------------------------------------------
Evaluation complete.
Accuracy: 40173/50000 (80.35%)
