In [21]:
import torch
from PIL import Image
from torchvision.models import vgg

import torchvision.models as models
import torchvision.transforms as transforms

# Load VGG-16 with its ImageNet-1K (V1) pretrained weights.
# The explicit ``weights=`` enum replaces the deprecated ``pretrained=True`` flag
# (which torchvision maps to this same IMAGENET1K_V1 checkpoint), and it is the
# same weights entry whose category names ``infer`` uses to decode predictions.
vgg16 = models.vgg16(weights=models.VGG16_Weights.IMAGENET1K_V1)
vgg16.eval()  # Evaluation mode: turns off training-only behavior such as dropout

def _build_transform():
    """Return the preprocessing pipeline applied to every image before VGG-16."""
    steps = [
        # Force a fixed 224x224 input; the aspect ratio is not preserved.
        transforms.Resize((224, 224)),
        # PIL image -> float tensor laid out as (C, H, W).
        transforms.ToTensor(),
        # Per-channel (x - mean) / std using the standard ImageNet statistics
        # that torchvision's pretrained classifiers expect.
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
    ]
    return transforms.Compose(steps)


transform = _build_transform()

def infer(image_path, K=5):
    """Classify an image with the module-level VGG-16 model and return its top-K labels.

    Args:
        image_path: Anything ``PIL.Image.open`` accepts (typically a file path).
        K: Number of highest-probability ImageNet categories to return.

    Returns:
        A list of K category-name strings, ordered from most to least probable.
    """
    # Open inside a context manager so the underlying file handle is released.
    # ``convert`` returns a new, fully loaded image, so it remains valid after
    # the original is closed.
    with Image.open(image_path) as img:
        image = img.convert('RGB')  # Ensure the image is in RGB format
    input_tensor = transform(image).unsqueeze(0)  # Add a batch dimension

    # Perform inference without building an autograd graph.
    with torch.no_grad():
        outputs = vgg16(input_tensor)
        probabilities = torch.nn.functional.softmax(outputs, dim=1)
        _, top_k_indices = torch.topk(probabilities, K)

    # Category names that belong to the same weights the model was loaded with.
    class_labels = vgg.VGG16_Weights.IMAGENET1K_V1.meta["categories"]

    # Row 0 is the single image in the batch; convert to plain ints for indexing.
    return [class_labels[idx] for idx in top_k_indices[0].tolist()]

original_image = "image.png"
colormaps = ['viridis', 'Spectral', 'plasma', 'gray']
top_k = 5  # Number of top labels compared; passed to every infer() call so both sides match

# Baseline inference on the original image
baseline_results = infer(original_image, K=top_k)
print("Baseline results:", baseline_results)
baseline_labels = set(baseline_results)  # Built once for O(1) membership tests below

# Compare each depth_est_[colormap] to the baseline
accuracy_comparison = {}
for colormap in colormaps:
    depth_image_path = f"depth_est_{colormap}.png"
    depth_results = infer(depth_image_path, K=top_k)

    print(f"{colormap} results:", depth_results)

    # "Accuracy" here is top-K agreement: the percentage of this image's top-K
    # labels that also appear anywhere in the baseline's top-K (order ignored).
    # Guard against an empty prediction list to avoid ZeroDivisionError.
    matches = sum(label in baseline_labels for label in depth_results)
    accuracy = matches / len(depth_results) * 100 if depth_results else 0.0

    accuracy_comparison[colormap] = accuracy

print("Accuracy comparison:", accuracy_comparison)

Baseline results: ['picket fence', 'worm fence', 'patio', 'mobile home', 'pot']
viridis results: ['stage', 'spotlight', 'picket fence', 'scuba diver', 'screen']
Spectral results: ['picket fence', 'shower curtain', 'stage', 'theater curtain', 'digital clock']
plasma results: ['stage', 'spotlight', 'theater curtain', 'screen', 'traffic light']
gray results: ['picket fence', 'window screen', 'bannister', 'window shade', 'prison']
Accuracy comparison: {'viridis': 20.0, 'Spectral': 20.0, 'plasma': 0.0, 'gray': 20.0}
