## 1. Setup & Load Model

In [None]:
import os
import torch
import cv2
import numpy as np
from pathlib import Path
import json
import matplotlib.pyplot as plt
from PIL import Image
from torchvision import transforms
from anomalib.models import Patchcore
import warnings
warnings.filterwarnings('ignore')

# Device: use the GPU when PyTorch can see one
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Device: {DEVICE}")

# Paths
CHECKPOINT_DIR = Path("checkpoints")
INFERENCE_INPUT = Path("inference_images")  # Folder for new images
INFERENCE_OUTPUT = Path("inference_results")
INFERENCE_OUTPUT.mkdir(exist_ok=True)
INFERENCE_INPUT.mkdir(exist_ok=True)

# Load the decision threshold written by the evaluation notebook. A missing,
# unreadable or malformed file falls back to the default instead of raising,
# and the fallback is reported so it is not mistaken for a tuned value.
DEFAULT_THRESHOLD = 0.5
EVAL_RESULTS = Path("evaluation_results") / "evaluation_results.json"
THRESHOLD = DEFAULT_THRESHOLD
if EVAL_RESULTS.exists():
    try:
        with open(EVAL_RESULTS, encoding="utf-8") as f:
            eval_data = json.load(f)
        # float() guards against a non-numeric JSON value breaking the :.4f format below
        THRESHOLD = float(eval_data['metrics']['optimal_threshold'])
    except (OSError, KeyError, TypeError, ValueError) as e:
        # json.JSONDecodeError is a subclass of ValueError
        print(f"⚠ Could not read threshold from {EVAL_RESULTS} ({e!r}); using default")
else:
    print(f"⚠ {EVAL_RESULTS} not found (evaluation not run?); using default threshold")

print(f"Using threshold: {THRESHOLD:.4f}")

# Load model. `model` is pre-set to None so a failed load leaves a defined name
# that later cells can check, instead of an undefined name (NameError).
model = None
try:
    model = Patchcore.load_from_checkpoint(CHECKPOINT_DIR / "patchcore_trained.ckpt")
    model = model.to(DEVICE)
    model.eval()
    print("✓ Model loaded successfully")
except Exception as e:
    # Do not keep a partially initialised model (e.g. loaded but not moved to DEVICE)
    model = None
    print(f"Error loading model: {e}")
    print("Make sure to run 02_anomalib_train_patchcore.ipynb first")

## 2. Define Inference Function

In [None]:
# Preprocessing applied to every inference image; intended to mirror the
# training/evaluation pipeline: resize to 224x224, convert the PIL image to a
# float tensor, then normalise with the standard ImageNet channel statistics.
inference_transform = transforms.Compose(
    [
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
    ]
)

def _to_float_score(value):
    """Convert a single-element tensor, array or number to a Python float.

    Raises:
        ValueError: if *value* holds more than one element (for example an
            anomaly map passed where an image-level score was expected).
    """
    if isinstance(value, torch.Tensor):
        if value.numel() != 1:
            raise ValueError(
                f"expected a single anomaly score, got a tensor of shape {tuple(value.shape)}"
            )
        return float(value.item())
    arr = np.asarray(value, dtype=np.float64)
    if arr.size != 1:
        raise ValueError(f"expected a single anomaly score, got an array of shape {arr.shape}")
    return float(arr.item())


def _extract_anomaly_score(output):
    """Return the image-level anomaly score contained in a model output, as a float.

    Supported outputs:
      * dict   -> first non-None value among 'anomaly_score', 'score', 'pred_score'
      * object -> its ``pred_score`` attribute, when present and not None
      * otherwise the output itself (tensor / array / number with one element)

    Raises:
        KeyError: a dict output contains none of the known score keys.
        ValueError: the score holds more than one element.
    """
    if isinstance(output, dict):
        for key in ('anomaly_score', 'score', 'pred_score'):
            if output.get(key) is not None:
                return _to_float_score(output[key])
        # A missing score must not silently default to 0.0, which would label the image NORMAL
        raise KeyError(f"no anomaly score in model output (keys: {list(output)})")
    pred_score = getattr(output, 'pred_score', None)
    if pred_score is not None:
        return _to_float_score(pred_score)
    return _to_float_score(output)


def infer_single_image(model, image_path, device, threshold=THRESHOLD):
    """Run the anomaly model on one image file.

    Args:
        model: loaded anomaly model; its ``predict`` method is called on a
            (1, 3, 224, 224) tensor produced by ``inference_transform``.
        image_path: path of the image file to score.
        device: torch device the input tensor is moved to.
        threshold: the image is labelled an anomaly when its clipped score is
            strictly greater than this value.

    Returns:
        A dict with keys
          - 'image': the image as an RGB PIL image
          - 'anomaly_score': model score clipped to [0, 1] (float)
          - 'raw_score': model score before clipping (float)
          - 'is_anomaly': True if anomaly_score > threshold (bool)
          - 'confidence': |anomaly_score - threshold|, a distance from the
            threshold rather than a probability
          - 'status': 'ANOMALY' or 'NORMAL'
        or None if any step fails; the error is printed, so one bad file does
        not abort a batch run.
    """
    try:
        # The context manager closes the file handle even if decoding fails;
        # convert() returns an independent, fully loaded image.
        with Image.open(image_path) as img:
            img_pil = img.convert('RGB')
        img_tensor = inference_transform(img_pil).unsqueeze(0).to(device)

        # NOTE(review): confirm `predict` is the intended inference entry point
        # of the installed anomalib Patchcore version.
        with torch.no_grad():
            output = model.predict(img_tensor)

        raw_score = _extract_anomaly_score(output)
        if np.isnan(raw_score):
            # NaN compares False against any threshold and would be reported as NORMAL
            raise ValueError("model returned a NaN anomaly score")

        # Clip (this does not rescale) into [0, 1]; out-of-range scores saturate.
        # NOTE(review): if THRESHOLD was chosen on unclipped scores and is >= 1,
        # no image can ever be flagged -- confirm both notebooks use the same scale.
        anomaly_score = float(np.clip(raw_score, 0.0, 1.0))

        is_anomaly = bool(anomaly_score > threshold)

        # Distance from the decision boundary
        confidence = abs(anomaly_score - threshold)

        return {
            'image': img_pil,
            'anomaly_score': anomaly_score,
            'raw_score': raw_score,
            'is_anomaly': is_anomaly,
            'confidence': confidence,
            'status': 'ANOMALY' if is_anomaly else 'NORMAL'
        }

    except Exception as e:
        print(f"Error processing {image_path}: {e}")
        return None

print("✓ Inference function defined")

## 3. Run Inference on Folder

In [None]:
# Collect images from the inference folder. Suffixes are compared lower-cased so
# upper-case extensions (e.g. "IMG_01.JPG") are found too -- glob patterns such as
# '*.jpg' are case-sensitive on case-sensitive filesystems (Linux) -- and no file
# can be listed twice.
image_extensions = {'.png', '.jpg', '.jpeg', '.bmp'}
image_files = sorted(
    path for path in INFERENCE_INPUT.iterdir()
    if path.is_file() and path.suffix.lower() in image_extensions
)

# Always defined, so the following cells see an empty list instead of a NameError
results = []

if not image_files:
    print(f"⚠ No images found in {INFERENCE_INPUT}")
    print("Please add images to this folder and rerun this cell.")
elif globals().get('model') is None:
    # The setup cell only prints an error when the checkpoint fails to load
    print("⚠ No model loaded; run 02_anomalib_train_patchcore.ipynb, then rerun the setup cell.")
else:
    print(f"Found {len(image_files)} images to process")

    # Run inference; images that fail are reported by infer_single_image and skipped
    for idx, img_path in enumerate(image_files, start=1):
        print(f"  [{idx}/{len(image_files)}] Processing {img_path.name}...")
        result = infer_single_image(model, img_path, DEVICE, THRESHOLD)

        if result is not None:
            result['filename'] = img_path.name
            results.append(result)

    print(f"\n✓ Inference complete: {len(results)} images processed")

    # Summary
    anomaly_count = sum(1 for r in results if r['is_anomaly'])
    normal_count = len(results) - anomaly_count

    print("\n📊 Results:")
    print(f"  Normal: {normal_count}")
    print(f"  Anomalous: {anomaly_count}")

## 4. Save Detailed Results

In [None]:
# Save per-image results as JSON. Only scalar fields are written; the PIL image
# stored under 'image' is not JSON-serialisable.
results_json = [
    {
        'filename': result['filename'],
        'anomaly_score': float(result['anomaly_score']),
        'status': result['status'],
        'confidence': float(result['confidence']),
        'threshold': THRESHOLD,
    }
    for result in results
]

results_json_path = INFERENCE_OUTPUT / "inference_results.json"
with open(results_json_path, 'w', encoding='utf-8') as f:
    json.dump(results_json, f, indent=2)

print(f"✓ Results saved to: {results_json_path}")

## 5. Visualize Results

In [None]:
# Visualize the first 12 results (processing order, i.e. sorted by filename)
vis_path = INFERENCE_OUTPUT / "inference_visualizations.png"
num_vis = min(12, len(results))
n_cols = 4

if num_vis == 0:
    # plt.subplots(0, n_cols) would raise; there is nothing to draw
    print("⚠ No results to visualize")
else:
    n_rows = (num_vis + n_cols - 1) // n_cols  # ceiling division

    # squeeze=False always returns a 2-D array of Axes, so flatten() yields one
    # Axes per grid cell for any grid size (including a single row).
    fig, axes = plt.subplots(n_rows, n_cols, figsize=(16, 4 * n_rows), squeeze=False)
    axes = axes.flatten()
    fig.suptitle('Anomalib Inference Results', fontsize=16, fontweight='bold')

    for ax, result in zip(axes, results[:num_vis]):
        # Display image
        ax.imshow(result['image'])

        # Color by status
        color = 'red' if result['is_anomaly'] else 'green'
        status_text = f"{result['status']}\nScore: {result['anomaly_score']:.4f}\nConf: {result['confidence']:.4f}"

        ax.set_title(status_text, color=color, fontweight='bold', fontsize=11)
        ax.set_xlabel(result['filename'], fontsize=9)
        # Hide ticks and frame only: ax.axis('off') would also hide the filename label
        ax.set_xticks([])
        ax.set_yticks([])
        for spine in ax.spines.values():
            spine.set_visible(False)

    # Hide unused subplots
    for ax in axes[num_vis:]:
        ax.axis('off')

    plt.tight_layout()

    # Save
    plt.savefig(vis_path, dpi=150, bbox_inches='tight')
    print(f"✓ Visualizations saved to: {vis_path}")
    plt.show()

## 6. Score Distribution

In [None]:
# Plot score distribution
dist_path = INFERENCE_OUTPUT / "score_distribution.png"
scores = [r['anomaly_score'] for r in results]

if not scores:
    # np.min / np.max raise on an empty sequence; nothing to plot or summarise
    print("⚠ No scores to plot")
else:
    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(14, 5))

    # Histogram
    ax1.hist(scores, bins=20, color='steelblue', edgecolor='black', alpha=0.7)
    ax1.axvline(THRESHOLD, color='red', linestyle='--', linewidth=2, label=f'Threshold: {THRESHOLD:.4f}')
    ax1.set_xlabel('Anomaly Score')
    ax1.set_ylabel('Frequency')
    ax1.set_title('Anomaly Score Distribution')
    ax1.legend()
    ax1.grid(alpha=0.3)

    # Bar plot: one bar per image in processing order (red = anomaly, green = normal)
    bar_colors = ['red' if r['is_anomaly'] else 'green' for r in results]
    ax2.barh(range(len(results)), scores, color=bar_colors)
    ax2.axvline(THRESHOLD, color='blue', linestyle='--', linewidth=2, label='Threshold')
    ax2.set_xlabel('Anomaly Score')
    ax2.set_ylabel('Image Index')
    ax2.set_title('Anomaly Scores by Image')
    ax2.legend()
    ax2.grid(alpha=0.3, axis='x')

    plt.tight_layout()

    # Save
    plt.savefig(dist_path, dpi=150)
    print(f"✓ Distribution plot saved to: {dist_path}")
    plt.show()

    # Statistics
    print("\n📈 Score Statistics:")
    print(f"  Mean: {np.mean(scores):.4f}")
    print(f"  Median: {np.median(scores):.4f}")
    print(f"  Min: {np.min(scores):.4f}")
    print(f"  Max: {np.max(scores):.4f}")
    print(f"  Std: {np.std(scores):.4f}")

## 7. Summary Report

In [None]:
# Summary report
anomalous_count = sum(1 for r in results if r['is_anomaly'])

print("\n" + "=" * 60)
print("INFERENCE SUMMARY")
print("=" * 60)
# NOTE(review): the backbone name is hard-coded here, not read from the
# checkpoint -- confirm it matches the trained model.
print("\nModel: Patchcore (wide_resnet50_2)")
print(f"Threshold: {THRESHOLD:.4f}")
print("\n📊 Results:")
print(f"  Total images: {len(results)}")
print(f"  Normal: {len(results) - anomalous_count}")
print(f"  Anomalous: {anomalous_count}")
print("\n📁 Output files:")
for label, path in (
    ("Results JSON", results_json_path),
    ("Visualizations", vis_path),
    ("Score distribution", dist_path),
):
    # Plot cells may skip writing when there are no results; flag missing files
    suffix = "" if Path(path).exists() else " (not written)"
    print(f"  - {label}: {path}{suffix}")
print("\n✓ Inference complete!")
print("=" * 60)