In [None]:
import torch
import numpy as np
from sklearn.metrics import confusion_matrix
from tqdm import tqdm
from dataloaders import create_dataloaders
from custom_built_nn import SimplePestCNN
import matplotlib.pyplot as plt
import seaborn as sns

device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")

# Machine-specific locations; adjust for your checkout.
CHECKPOINT_PATH = "/Users/cgp/Desktop/Portfolio/Crop_pest_identifier/pestven/models/SimplePestCNN_attention.pth"
DATA_DIR = "/Users/cgp/Desktop/Portfolio/Crop_pest_identifier/pestven/Pest_data"

model = SimplePestCNN(num_classes=12).to(device)

# map_location remaps the stored tensors onto the device chosen above, so a
# checkpoint saved on another backend (e.g. MPS or CUDA) still loads when this
# runs on the CPU fallback.
state_dict = torch.load(CHECKPOINT_PATH, map_location=device)

# NOTE(review): this "_attention" checkpoint stores keys such as conv1.*, bn1.*,
# se1.excitation.*, fc1.*, fc2.*, whereas SimplePestCNN defines features.* and
# classifier.*. It appears to have been saved from a different (attention)
# model class. TODO: instantiate that class here, or load a checkpoint that was
# saved from SimplePestCNN.
try:
    model.load_state_dict(state_dict)
except RuntimeError as err:
    # Summarise the checkpoint's top-level module names so the mismatch is
    # obvious at a glance; the full key lists remain in the chained error.
    saved_modules = sorted({str(key).split(".", 1)[0] for key in state_dict})
    raise RuntimeError(
        f"Checkpoint {CHECKPOINT_PATH!r} does not match the "
        f"{type(model).__name__} architecture (checkpoint top-level modules: "
        f"{saved_modules}). Instantiate the model class that produced this "
        f"checkpoint, or point CHECKPOINT_PATH at a checkpoint saved from "
        f"{type(model).__name__}."
    ) from err
model.eval()

# Load data: only the test split is used here; the other returned values are discarded.
# NOTE(review): use_mps=True is passed even when `device` fell back to CPU;
# confirm create_dataloaders tolerates that (or pass device.type == "mps").
_, _, test_loader, _ = create_dataloaders(
    DATA_DIR,
    batch_size=32, num_workers=0, use_mps=True
)

# Display names for the 12 pest classes, indexed by integer label (used for
# per-class report rows and heatmap tick labels).
# NOTE(review): assumed to match the dataset's label order. The list is sorted
# ASCII-wise ('Weevil' before lowercase names), as torchvision's ImageFolder
# would sort folder names. Confirm against create_dataloaders.
class_names = [
    'Weevil',
    'ants',
    'bees',
    'beetle',
    'caterpillar',
    'earthworms',
    'earwig',
    'grasshopper',
    'moth',
    'slug',
    'snail',
    'wasp',
]

# Get predictions: run the model over the whole test split and collect the
# predicted and true class index of every sample (as NumPy integers).
all_preds = []
all_labels = []

with torch.no_grad():
    for images, labels in tqdm(test_loader):
        outputs = model(images.to(device))
        # Index of the largest logit along dim 1 (presumably the class dimension
        # of (batch, num_classes) outputs). Ties resolve to the first index, as
        # with torch.max(outputs, 1).
        preds = outputs.argmax(dim=1)

        all_preds.extend(preds.cpu().numpy())
        # Labels are only needed on the host, so they are never sent to `device`;
        # .cpu() is a no-op if they are already there.
        all_labels.extend(labels.cpu().numpy())

# Confusion matrix: rows are true labels, columns are predicted labels.
# Passing labels= pins the shape to len(class_names) x len(class_names) even if
# some class never occurs in the test labels or predictions. Without it, sklearn
# keeps only observed labels, so rows would misalign with class_names and cm[i]
# could raise IndexError below.
num_classes = len(class_names)
cm = confusion_matrix(all_labels, all_preds, labels=np.arange(num_classes))

# Row-normalize so each row sums to 1. Rows with no true samples stay 0 rather
# than becoming NaN from a 0/0 division.
support = cm.sum(axis=1, keepdims=True)
cm_normalized = np.divide(
    cm.astype(float),
    support,
    out=np.zeros(cm.shape, dtype=float),
    where=support > 0,
)

# Per-class accuracy (recall): correct predictions / true samples of the class.
# Classes with no test samples are skipped.
print("\nPer-Class Accuracy:")
for i, name in enumerate(class_names):
    if support[i, 0] > 0:
        acc = cm_normalized[i, i] * 100
        print(f"{name:15}: {acc:5.1f}%")

# Confusion matrix display
print("\nConfusion Matrix:")

# Create the figure and keep an explicit handle to the heatmap axes, so labels,
# tick styling and the accuracy text do not depend on implicit "current axes" state.
fig, ax = plt.subplots(figsize=(12, 10))

# Annotated heatmap of raw counts; tick labels follow class_names order.
sns.heatmap(cm, annot=True, fmt='d', cmap='Blues',
            xticklabels=class_names,
            yticklabels=class_names,
            cbar_kws={'label': 'Count'},
            ax=ax)

ax.set_title('Confusion Matrix - SimplePestCNN', fontsize=16)
ax.set_xlabel('Predicted Class', fontsize=12)
ax.set_ylabel('Actual Class', fontsize=12)
plt.setp(ax.get_xticklabels(), rotation=45, ha='right')
plt.setp(ax.get_yticklabels(), rotation=0)

# Overall accuracy = diagonal (correct) / all samples. Guard the empty case so an
# empty test set does not produce NaN from 0/0.
total = cm.sum()
if total > 0:
    overall_acc = np.trace(cm) / total * 100
    accuracy_text = f'Overall Accuracy: {overall_acc:.1f}%'
else:
    overall_acc = float('nan')
    accuracy_text = 'Overall Accuracy: n/a (no test samples)'

# Place the text just below the axes, in axes-relative coordinates.
ax.text(0.5, -0.15, accuracy_text,
        transform=ax.transAxes, ha='center', fontsize=12)
fig.tight_layout()
plt.show()

RuntimeError: Error(s) in loading state_dict for SimplePestCNN:
	Missing key(s) in state_dict: "features.0.weight", "features.0.bias", "features.1.weight", "features.1.bias", "features.1.running_mean", "features.1.running_var", "features.4.weight", "features.4.bias", "features.5.weight", "features.5.bias", "features.5.running_mean", "features.5.running_var", "features.8.weight", "features.8.bias", "features.9.weight", "features.9.bias", "features.9.running_mean", "features.9.running_var", "features.12.weight", "features.12.bias", "features.13.weight", "features.13.bias", "features.13.running_mean", "features.13.running_var", "classifier.1.weight", "classifier.1.bias", "classifier.4.weight", "classifier.4.bias". 
	Unexpected key(s) in state_dict: "conv1.weight", "conv1.bias", "bn1.weight", "bn1.bias", "bn1.running_mean", "bn1.running_var", "bn1.num_batches_tracked", "se1.excitation.0.weight", "se1.excitation.2.weight", "conv2.weight", "conv2.bias", "bn2.weight", "bn2.bias", "bn2.running_mean", "bn2.running_var", "bn2.num_batches_tracked", "se2.excitation.0.weight", "se2.excitation.2.weight", "conv3.weight", "conv3.bias", "bn3.weight", "bn3.bias", "bn3.running_mean", "bn3.running_var", "bn3.num_batches_tracked", "se3.excitation.0.weight", "se3.excitation.2.weight", "conv4.weight", "conv4.bias", "bn4.weight", "bn4.bias", "bn4.running_mean", "bn4.running_var", "bn4.num_batches_tracked", "se4.excitation.0.weight", "se4.excitation.2.weight", "fc1.weight", "fc1.bias", "fc2.weight", "fc2.bias". 