# DO NOT USE, DEBUG ONLY

In [2]:
import os
import time
import torch
import numpy as np
import onnxruntime
from torch.utils.data import DataLoader
from data.dataset import WeedDataset
from tqdm.notebook import tqdm
import matplotlib.pyplot as plt
from sklearn.metrics import confusion_matrix, classification_report
import seaborn as sns

In [3]:
# Path to optimized model
MODEL_PATH = 'optimized_models/model_quantized.onnx'  # Use quantized ONNX model
DATA_PATH = 'data'  # Path to data directory
BATCH_SIZE = 32  # Batch size handed to the validation DataLoader
    
# This check only prints a hint: nothing is raised, so execution continues
# and the later InferenceSession(MODEL_PATH) call is what actually fails
# when the file is absent.
if not os.path.exists(MODEL_PATH):
    print(f"Model file {MODEL_PATH} not found. Please run convert_model.py first.")

In [4]:
# Create validation dataset and dataloader
print("Loading validation dataset...")
val_dataset = WeedDataset(DATA_PATH, split='val')
# shuffle=False keeps batches in dataset order; no drop_last is passed, so the
# final batch may contain fewer than BATCH_SIZE samples.
val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False)

# Report the sample count and whatever WeedDataset.get_class_distribution()
# returns (the recorded output shows a {label: count} mapping).
print(f"Validation dataset size: {len(val_dataset)} images")
class_dist = val_dataset.get_class_distribution()
print(f"Class distribution: {class_dist}")

Loading validation dataset...
Validation dataset size: 169 images
Class distribution: {0: 102, 1: 67}


In [None]:
# class names mapping
# presumably these indices match the integer labels produced by WeedDataset —
# TODO confirm against data/dataset.py
class_names = {0: 'Broadleaf', 1: 'Grass'}

# init ONNX Runtime session
# NOTE(review): the output recorded for this cell shows InferenceSession failing
# with NOT_IMPLEMENTED for a ConvInteger node, i.e. the installed runtime could
# not execute an op in the quantized model. The remedy presumably lives in the
# export/quantization step (convert_model.py), not here — TODO confirm.
print(f"Loading ONNX model from {MODEL_PATH}...")
session = onnxruntime.InferenceSession(MODEL_PATH)
# Name of the model's first declared input; used as the feed key at inference.
input_name = session.get_inputs()[0].name

Loading ONNX model from optimized_models/model_quantized.onnx...


NotImplemented: [ONNXRuntimeError] : 9 : NOT_IMPLEMENTED : Could not find an implementation for ConvInteger(10) node with name '/backbone/stem/net/net.0/Conv_quant'

In [None]:
# Testing the optimized ONNX model on validation dataset

import os
import time
import torch
import numpy as np
import onnxruntime
from torch.utils.data import DataLoader
from data.dataset import WeedDataset
from tqdm.notebook import tqdm
import matplotlib.pyplot as plt
from sklearn.metrics import confusion_matrix, classification_report
import seaborn as sns

# Path to optimized model
MODEL_PATH = 'optimized_models/model_quantized.onnx'  # Use quantized ONNX model
DATA_PATH = 'data'  # Path to data directory
BATCH_SIZE = 32  # Batch size handed to the validation DataLoader
    
# This check only prints a hint: nothing is raised, so execution continues
# and the InferenceSession(MODEL_PATH) call below is what actually fails
# when the file is absent.
if not os.path.exists(MODEL_PATH):
    print(f"Model file {MODEL_PATH} not found. Please run convert_model.py first.")

# Create validation dataset and dataloader
print("Loading validation dataset...")
val_dataset = WeedDataset(DATA_PATH, split='val')
# shuffle=False keeps batches in dataset order; no drop_last is passed, so the
# final batch may contain fewer than BATCH_SIZE samples.
val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False)

print(f"Validation dataset size: {len(val_dataset)} images")
class_dist = val_dataset.get_class_distribution()
print(f"Class distribution: {class_dist}")

# Class names mapping
# presumably these indices match the integer labels produced by WeedDataset —
# TODO confirm against data/dataset.py
class_names = {0: 'Broadleaf', 1: 'Grass'}

# Initialize ONNX Runtime session
# NOTE(review): the output recorded earlier in this notebook for the same call
# shows it failing with NOT_IMPLEMENTED for a ConvInteger node, i.e. the
# installed runtime could not execute an op in the quantized model. The remedy
# presumably lives in the export/quantization step (convert_model.py) —
# TODO confirm.
print(f"Loading ONNX model from {MODEL_PATH}...")
session = onnxruntime.InferenceSession(MODEL_PATH)
# Name of the model's first declared input; used as the feed key at inference.
input_name = session.get_inputs()[0].name

# Function to run inference
def onnx_inference(session, input_data):
    """Run a single forward pass through an ONNX Runtime session.

    Args:
        session: Object exposing ``get_inputs()`` and ``run(output_names, feeds)``
            (an ``onnxruntime.InferenceSession`` in this notebook).
        input_data: Value fed to the model's first declared input.

    Returns:
        The first element of the list returned by ``session.run``.
    """
    # Look the feed name up on the session that was actually passed in. Relying
    # on a notebook-level ``input_name`` silently breaks as soon as this helper
    # is called with a different session (or before that global exists).
    feed_name = session.get_inputs()[0].name
    # ``None`` asks the runtime for every model output; only the first is used.
    return session.run(None, {feed_name: input_data})[0]

# Evaluate the model
def evaluate_model():
    """Run the ONNX session over every batch of ``val_loader`` and score it.

    Reads the notebook-level ``val_loader`` and ``session`` and delegates each
    forward pass to ``onnx_inference``.

    Returns:
        tuple: ``(accuracy, avg_inference_time, all_predictions, all_labels)``
        where ``accuracy`` is correct / total samples, ``avg_inference_time``
        is the mean per-batch inference time in milliseconds, and the two
        lists hold one predicted / true label per sample in loader order.

    Raises:
        ValueError: If the loader yields no samples, since accuracy would be
            undefined.
    """
    all_predictions = []
    all_labels = []
    total_correct = 0
    total_samples = 0
    inference_times = []

    print("Evaluating model on validation set...")
    for images, labels in tqdm(val_loader):
        # ONNX Runtime is fed NumPy arrays, so convert the torch batch.
        input_data = images.numpy()
        labels_np = labels.numpy()

        # perf_counter is a monotonic, high-resolution clock; time.time() can
        # jump with system clock adjustments and skew short measurements.
        start_time = time.perf_counter()
        outputs = onnx_inference(session, input_data)
        inference_times.append((time.perf_counter() - start_time) * 1000)  # ms

        # Predicted class = index of the largest value along axis 1
        # (assumes outputs are (batch, num_classes) scores — TODO confirm).
        batch_predictions = np.argmax(outputs, axis=1)

        total_correct += (batch_predictions == labels_np).sum()
        total_samples += labels.size(0)

        # Keep per-sample results for the confusion matrix / report.
        all_predictions.extend(batch_predictions)
        all_labels.extend(labels_np)

    if total_samples == 0:
        raise ValueError(
            "Validation loader produced no samples; cannot compute accuracy."
        )

    accuracy = total_correct / total_samples
    avg_inference_time = np.mean(inference_times)

    return accuracy, avg_inference_time, all_predictions, all_labels

# Run evaluation
accuracy, avg_inference_time, predictions, labels = evaluate_model()

# avg_inference_time is a mean over batches, and the loader is built without
# drop_last, so the final batch can hold fewer than BATCH_SIZE images.
# Dividing the batch mean by BATCH_SIZE therefore misreports per-image cost.
# Recover the total measured time and divide by the real number of images.
total_inference_time = avg_inference_time * len(val_loader)
avg_time_per_image = total_inference_time / len(val_dataset)

# Print results
print("\n===== Model Evaluation Results =====")
print(f"Validation Accuracy: {accuracy:.4f} ({accuracy*100:.2f}%)")
print(f"Average Inference Time: {avg_inference_time:.2f} ms per batch")
print(f"Average Inference Time: {avg_time_per_image:.2f} ms per image")
print(f"Throughput: {1000/avg_time_per_image:.1f} images/second")

# Plot confusion matrix
# Class names ordered by label index 0..N-1; reused for both heatmap axes and
# for the classification report.
ordered_class_names = [class_names[i] for i in range(len(class_names))]

cm = confusion_matrix(labels, predictions)
fig, ax = plt.subplots(figsize=(8, 6))
sns.heatmap(
    cm,
    annot=True,
    fmt='d',
    cmap='Blues',
    xticklabels=ordered_class_names,
    yticklabels=ordered_class_names,
    ax=ax,
)
ax.set_xlabel('Predicted')
ax.set_ylabel('True')
ax.set_title('Confusion Matrix')
fig.tight_layout()
plt.show()

# Print classification report
print("\nClassification Report:")
report = classification_report(
    labels,
    predictions,
    target_names=ordered_class_names,
    digits=3,
)
print(report)

# Compare with original PyTorch model accuracy (if available)
# Best-effort: any failure (missing module, missing checkpoint, mismatched
# weights) is reported and the rest of the notebook carries on.
try:
    from models.tinyresvit import TinyResViT
    # Not referenced in the visible code; kept in case other cells rely on it.
    from utils import accuracy as acc_metric
    
    print("\nComparing with original PyTorch model...")
    original_model = TinyResViT(num_classes=2)
    # NOTE(review): torch.load unpickles the file, so only point this at
    # checkpoints you trust.
    checkpoint = torch.load('full_training/run_20250503_045116/best_model.pth', map_location='cpu')
    
    # Accept either a full training checkpoint (weights under
    # 'model_state_dict') or a bare state dict.
    if 'model_state_dict' in checkpoint:
        original_model.load_state_dict(checkpoint['model_state_dict'])
    else:
        original_model.load_state_dict(checkpoint)
        
    original_model.eval()
    
    # Evaluate original PyTorch model
    correct = 0
    total = 0
    
    with torch.no_grad():
        # Use batch-specific loop names: this runs at notebook top level, so
        # iterating with `labels` would overwrite the full list of true labels
        # returned by evaluate_model() and break the per-class metrics that
        # are computed from `labels` / `predictions` afterwards.
        for batch_images, batch_labels in tqdm(val_loader):
            outputs = original_model(batch_images)
            _, predicted = torch.max(outputs, 1)
            total += batch_labels.size(0)
            correct += (predicted == batch_labels).sum().item()
    
    pytorch_acc = correct / total
    print(f"Original PyTorch Model Accuracy: {pytorch_acc:.4f} ({pytorch_acc*100:.2f}%)")
    print(f"Accuracy difference: {(pytorch_acc - accuracy)*100:.2f}%")
    
except Exception as e:
    print(f"Could not load original PyTorch model for comparison: {e}")

# Generate per-class metrics
# For each class: the fraction of samples whose true label is that class and
# that were also predicted as that class.
print("\nPer-class Performance:")
for class_idx in class_names:
    class_name = class_names[class_idx]
    class_mask = np.array(labels) == class_idx
    predicted_as_class = np.array(predictions) == class_idx
    class_correct = np.sum(predicted_as_class & class_mask)
    class_total = np.sum(class_mask)
    # A class with no samples reports 0 rather than dividing by zero.
    if class_total > 0:
        class_acc = class_correct / class_total
    else:
        class_acc = 0
    print(f"{class_name}: {class_acc:.4f} ({class_correct}/{class_total})")
    
# Display some example predictions
def show_predictions(num_samples=5):
    """Plot up to ``num_samples`` validation images with true/predicted labels.

    Builds its own single-image, shuffled loader over the validation split,
    runs each image through the ONNX session via ``onnx_inference`` and stacks
    the images vertically in one figure. Titles are green when the prediction
    matches the label and red otherwise. No seed is set in the visible code,
    so the images shown can differ between runs.

    Args:
        num_samples: Maximum number of images to display.
    """
    # Separate dataset/loader so the main evaluation loader is left untouched.
    vis_dataset = WeedDataset(DATA_PATH, split='val')
    vis_loader = DataLoader(vis_dataset, batch_size=1, shuffle=True)

    plt.figure(figsize=(15, 3*num_samples))

    # Per-channel statistics used to undo input normalization for display
    # (presumably the values WeedDataset normalizes with — TODO confirm).
    channel_mean = torch.tensor([0.485, 0.456, 0.406]).reshape(3, 1, 1)
    channel_std = torch.tensor([0.229, 0.224, 0.225]).reshape(3, 1, 1)

    for idx, (image, label) in enumerate(vis_loader):
        if idx >= num_samples:
            break

        # Predicted class for the single image in this batch.
        scores = onnx_inference(session, image.numpy())
        pred = np.argmax(scores, axis=1)[0]
        true_idx = label.item()

        # Drop the batch axis, denormalize, move channels last and clamp to
        # the [0, 1] range imshow expects for float images.
        restored = image.squeeze(0) * channel_std + channel_mean
        pixels = np.clip(restored.permute(1, 2, 0).numpy(), 0, 1)

        plt.subplot(num_samples, 1, idx+1)
        plt.imshow(pixels)
        if pred == true_idx:
            title_color = 'green'
        else:
            title_color = 'red'
        plt.title(
            f"True: {class_names[true_idx]}, Predicted: {class_names[pred]}",
            color=title_color,
        )
        plt.axis('off')

    plt.tight_layout()
    plt.show()

# Show some example predictions
show_predictions(5)