In [None]:
# Import 

import torch
import torch.nn as nn
import torchvision.models as models
from torchvision import transforms
from PIL import Image
import numpy as np
import pickle
import matplotlib.pyplot as plt
from sklearn.neighbors import NearestNeighbors
import cv2 # Import OpenCV for contour detection

# Run on the GPU when one is visible to PyTorch, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(f"Using device: {device}")

In [None]:
# --- Load the Universal Feature Extractor ---
# WideResNet-50-2 trunk: the last two children (avgpool and fc in torchvision's
# ResNet) are dropped so the model returns spatial feature maps.
backbone_arch = models.wide_resnet50_2(weights=None)
feature_extractor = nn.Sequential(*list(backbone_arch.children())[:-2]).to(device)
# map_location remaps the saved tensors onto the current device, so a checkpoint
# written on a GPU still loads on a CPU-only machine (and vice versa).
feature_extractor.load_state_dict(
    torch.load("universal_feature_extractor.pth", map_location=device)
)
# Inference mode (e.g. BatchNorm uses its running statistics). The trailing ';'
# stops the notebook from printing the whole model repr as cell output.
feature_extractor.eval();

In [None]:
# --- Load the Product-Specific Coreset ---
product_name = "ENTER_PRODUCT_NAME"
CORESET_DIR = "/kaggle/working"  # directory containing <product_name>_coreset.pkl
# SECURITY: pickle.load can execute arbitrary code embedded in the file.
# Only load coreset files that you produced yourself.
with open(f"{CORESET_DIR}/{product_name}_coreset.pkl", "rb") as f:
    coreset = pickle.load(f)
print(f"✅ Loaded feature extractor and '{product_name}' coreset.")

# --- Prepare for efficient nearest-neighbor search ---
# Only the single closest coreset entry is needed: its distance becomes the
# patch's anomaly score during inference.
print("Fitting nearest-neighbor search algorithm on the coreset...")
nn_searcher = NearestNeighbors(n_neighbors=1, algorithm='ball_tree').fit(coreset)
print("✅ Search algorithm ready.")

# --- Register hooks to capture intermediate features ---
# Filled in by the forward hooks: maps a layer name to that layer's most recent output.
features = dict()

def get_features_hook(name):
    """Return a forward hook that saves a module's output in `features[name]`.

    The returned callable follows PyTorch's forward-hook signature
    (module, inputs, output). Each call overwrites the previous entry for `name`
    and returns None, so the module's output is left unchanged.
    """
    def _store_output(module, inputs, output):
        features[name] = output
    return _store_output
# Re-running this cell would otherwise stack another copy of each hook on the
# same modules. Drop the handles from any previous run before registering again.
for _old_handle in globals().get("hook_handles", []):
    _old_handle.remove()
# Children 5 and 6 of the truncated torchvision ResNet (conv1, bn1, relu, maxpool,
# layer1, layer2, layer3, ...) are layer2 and layer3.
hook_handles = [
    feature_extractor[5].register_forward_hook(get_features_hook('layer2')),
    feature_extractor[6].register_forward_hook(get_features_hook('layer3')),
]

In [None]:
# --- 2. Load and Preprocess a New Test Image ---
test_image_path = "PATH_TO_YOUR_TEST_IMAGE.jpg" # Replace with your test image path
# convert() decodes the pixels into a new image, so the source file handle can
# be closed right away instead of lingering until garbage collection.
with Image.open(test_image_path) as raw_image:
    image_pil = raw_image.convert("RGB")
image_np_orig = np.array(image_pil)  # (H, W, 3) uint8 RGB copy, used to draw contours on

# NOTE(review): this should match the preprocessing used when the coreset was
# built (not visible here) - confirm before changing sizes or normalisation.
inference_transform = transforms.Compose([
    transforms.Resize((256, 256)),
    transforms.CenterCrop((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])
test_tensor = inference_transform(image_pil).unsqueeze(0).to(device)

print("\nPerforming inference on the test image...")
with torch.no_grad():
    # --- Feature Extraction ---
    # The forward pass runs only to fire the layer2/layer3 hooks, which store
    # their outputs in `features`. The model's own return value is unused.
    _ = feature_extractor(test_tensor)
    layer2_features = features['layer2']
    layer3_features = features['layer3']
    # Upsample layer3 to layer2's spatial size so both can be stacked channel-wise.
    upsampled_layer3 = torch.nn.functional.interpolate(layer3_features, size=layer2_features.shape[2:], mode='bilinear', align_corners=False)
    combined_features = torch.cat((layer2_features, upsampled_layer3), dim=1)
    # (1, C, H, W) -> (H*W, C): one embedding row per feature-map location.
    patch_embeddings = combined_features.permute(0, 2, 3, 1).flatten(0, 2).cpu().numpy()

    # --- Nearest-Neighbor Search ---
    # A patch's score is its distance to the closest coreset entry.
    distances, _ = nn_searcher.kneighbors(patch_embeddings)
    patch_scores = distances.flatten()

    # --- Anomaly Map Generation ---
    feature_map_size = combined_features.shape[2:]
    anomaly_map_low_res = patch_scores.reshape(feature_map_size)

    # The backbone only saw the centre crop of the resized image, so the map
    # covers that crop window, not the whole frame. Locate the window in
    # original-image pixels and upsample the map into it. Stretching the map over
    # the full image misaligned every detection with the picture.
    # Assumes inference_transform is [Resize((h, w)), CenterCrop((h, w)), ...].
    resize_h, resize_w = inference_transform.transforms[0].size
    crop_h, crop_w = inference_transform.transforms[1].size
    orig_w, orig_h = image_pil.size
    # Same offset rounding that torchvision's center_crop uses.
    off_y = int(round((resize_h - crop_h) / 2.0))
    off_x = int(round((resize_w - crop_w) / 2.0))
    top = int(round(off_y * orig_h / resize_h))
    left = int(round(off_x * orig_w / resize_w))
    bottom = max(top + 1, int(round((off_y + crop_h) * orig_h / resize_h)))
    right = max(left + 1, int(round((off_x + crop_w) * orig_w / resize_w)))

    crop_map = torch.nn.functional.interpolate(
        torch.tensor(anomaly_map_low_res)[None, None],
        size=(bottom - top, right - left),
        mode='bilinear',
        align_corners=False
    )[0, 0].cpu().numpy()

    # Pixels outside the crop were never scored. Give them the map's minimum so
    # they cannot be flagged and do not change the map's min/max.
    anomaly_map_high_res = np.full((orig_h, orig_w), crop_map.min(), dtype=crop_map.dtype)
    anomaly_map_high_res[top:bottom, left:right] = crop_map

    # --- Image-Level Scoring and Decision ---
    # The image score is the largest patch score.
    image_level_score = np.max(patch_scores)

    IMAGE_THRESHOLD = 3.5  # in patch-distance units; presumably tuned per product - confirm
    decision = "ANOMALOUS 🔴" if image_level_score > IMAGE_THRESHOLD else "NORMAL 🟢"
    print(f"Inference complete. Image score: {image_level_score:.4f}. Decision: {decision}")

# Min-max scale the heatmap to 0-255 so a fixed uint8 threshold can be applied.
# NOTE: the scaling is per image, so PIXEL_THRESHOLD is relative to this image's
# own score range, not an absolute anomaly score.
map_min = float(np.min(anomaly_map_high_res))
map_range = float(np.max(anomaly_map_high_res)) - map_min
if map_range > 0:
    norm_anomaly_map = (255 * (anomaly_map_high_res - map_min) / map_range).astype(np.uint8)
else:
    # A constant map has no hotspot; avoid 0/0 -> NaN -> undefined uint8 values.
    norm_anomaly_map = np.zeros(anomaly_map_high_res.shape, dtype=np.uint8)

PIXEL_THRESHOLD = 200
_, binary_mask = cv2.threshold(norm_anomaly_map, PIXEL_THRESHOLD, 255, cv2.THRESH_BINARY)

# Opening (3x3) removes isolated specks; closing (9x9) then fills small gaps
# inside the remaining regions.
kernel_open = np.ones((3, 3), np.uint8)
kernel_close = np.ones((9, 9), np.uint8)

cleaned_mask = cv2.morphologyEx(binary_mask, cv2.MORPH_OPEN, kernel_open, iterations=1)
cleaned_mask = cv2.morphologyEx(cleaned_mask, cv2.MORPH_CLOSE, kernel_close, iterations=1)

# With per-image scaling the map's maximum always maps to 255, so something
# nearly always clears PIXEL_THRESHOLD. Only outline regions when the
# image-level decision is anomalous; otherwise NORMAL images would also get
# "defect" contours.
if image_level_score > IMAGE_THRESHOLD:
    contours, _ = cv2.findContours(cleaned_mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
else:
    contours = []
image_with_contours = image_np_orig.copy()
if len(contours) > 0:
    # Green outline, 2 px thick, drawn in place on the copy.
    cv2.drawContours(image_with_contours, contours, -1, (0, 255, 0), 2)

In [None]:
# --- 5. Visualize the Results ---
# Four panels: input, raw heatmap (with colour scale), heatmap overlay, contour overlay.
fig, axes = plt.subplots(1, 4, figsize=(20, 5))
ax_orig, ax_heat, ax_overlay, ax_contours = axes

ax_orig.imshow(image_pil)
ax_orig.set_title('Original Test Image')

im = ax_heat.imshow(anomaly_map_high_res, cmap='jet')
ax_heat.set_title('Anomaly Heatmap')
fig.colorbar(im, ax=ax_heat, orientation='vertical', fraction=0.046, pad=0.04)

ax_overlay.imshow(image_pil)
ax_overlay.imshow(anomaly_map_high_res, cmap='jet', alpha=0.5)
ax_overlay.set_title(f'Overlay Heatmap - {decision}')

ax_contours.imshow(image_with_contours)
ax_contours.set_title(f'Overlay Contours - {decision}')

# Hide ticks and frames on every panel.
for ax in axes:
    ax.axis('off')

fig.suptitle(f'Anomaly Detection for Product: {product_name.capitalize()}', fontsize=16)
fig.tight_layout()
plt.show()