# In [None]:
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torchvision.transforms as transforms
from PIL import Image
from torchvision import models

def _load_image_tensor(image_path, device):
    """Load an image as RGB and return a normalized (1, 3, 256, 256) tensor on ``device``."""
    transform = transforms.Compose([
        transforms.Resize((256, 256)),
        transforms.ToTensor(),
        # Standard ImageNet per-channel mean/std used by torchvision backbones.
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                             std=[0.229, 0.224, 0.225]),
    ])
    # The context manager closes the file handle; convert() returns an
    # in-memory copy that stays valid after the file is closed.
    with Image.open(image_path) as img:
        rgb = img.convert("RGB")
    return transform(rgb).unsqueeze(0).to(device)


def _capture_layer_output(model, input_tensor, layer_name):
    """Run one forward pass and return the detached output of module ``layer_name``.

    Raises:
        KeyError: If ``layer_name`` is not a submodule of ``model``, or the
            module was never invoked during the forward pass.
    """
    modules = dict(model.named_modules())
    if layer_name not in modules:
        raise KeyError(f"Layer {layer_name!r} not found in model.named_modules()")

    captured = {}

    def hook(_module, _inputs, output):
        captured["output"] = output.detach()

    handle = modules[layer_name].register_forward_hook(hook)
    try:
        # Only activations are needed, so skip building an autograd graph.
        with torch.no_grad():
            model(input_tensor)
    finally:
        # Always detach the hook; otherwise every call stacks another hook
        # on the same layer of the caller's model.
        handle.remove()

    if "output" not in captured:
        raise KeyError(f"Layer {layer_name!r} was not called during the forward pass")
    return captured["output"]


def _plot_feature_maps(feature_maps, max_maps):
    """Show up to ``max_maps`` channels of a (C, H, W) tensor in a near-square grid."""
    import math

    ncols = math.ceil(math.sqrt(max_maps))
    nrows = math.ceil(max_maps / ncols)
    # Move to host memory once rather than once per channel.
    shown = feature_maps[:max_maps].cpu().numpy()

    plt.figure(figsize=(15, 15))
    for idx, fmap in enumerate(shown):
        plt.subplot(nrows, ncols, idx + 1)
        plt.imshow(fmap, cmap='viridis')
        plt.axis('off')
        plt.title(f"Feature Map {idx+1}")
    plt.tight_layout()
    plt.show()


def visualize_feature_maps(model, image_path, layer_name="features.0", *,
                           device=None, max_maps=16):
    """Plot the activations of one layer of ``model`` for a single image.

    The image is resized to 256x256, converted to a tensor and normalized.
    It is then passed through ``model`` once, with a temporary forward hook
    on ``layer_name``. Up to ``max_maps`` channels of that layer's output
    are drawn with matplotlib.

    Args:
        model: A ``torch.nn.Module``. It is put into eval mode and moved to
            ``device``; both changes persist after the call.
        image_path: Path to an image file readable by PIL.
        layer_name: Module name as listed by ``model.named_modules()``,
            e.g. ``"features.0"``.
        device: Device to run on. If ``None``, the device of the model's first
            parameter is used (CPU for a parameter-less model).
        max_maps: Maximum number of channels to plot. The default of 16 gives
            a 4x4 grid.

    Raises:
        ValueError: If ``max_maps`` < 1, or the layer's output (after dropping
            the batch dimension) is not a 3-D (C, H, W) tensor.
        KeyError: If ``layer_name`` is missing from the model or is not called
            during the forward pass.
    """
    if max_maps < 1:
        raise ValueError(f"max_maps must be >= 1, got {max_maps}")

    if device is None:
        first_param = next(model.parameters(), None)
        device = first_param.device if first_param is not None else torch.device("cpu")

    model.eval()
    model.to(device)

    input_image = _load_image_tensor(image_path, device)
    # Drop the batch dimension of size 1 added in _load_image_tensor.
    feature_maps = _capture_layer_output(model, input_image, layer_name).squeeze(0)
    if feature_maps.dim() != 3:
        raise ValueError(
            f"Layer {layer_name!r} produced shape {tuple(feature_maps.shape)}; "
            "expected a (C, H, W) feature map"
        )

    _plot_feature_maps(feature_maps, max_maps)

# Example usage
NUM_CLASSES = 10  # change to your number of classes

# `weights=None` is the current spelling of the deprecated `pretrained=False`.
model = models.mobilenet_v2(weights=None)
model.classifier[1] = nn.Linear(model.classifier[1].in_features, NUM_CLASSES)
# The freshly built model lives on CPU, and load_state_dict copies into the
# existing parameters, so map the checkpoint straight to CPU.
# weights_only=True refuses arbitrary pickled objects; a plain state_dict
# needs nothing more.
model.load_state_dict(torch.load("best_model.pth", map_location="cpu", weights_only=True))
visualize_feature_maps(model, "/kaggle/input/your_image.jpg", layer_name="features.0")