In [None]:
# This Python 3 environment comes with many helpful analytics libraries installed
# It is defined by the kaggle/python Docker image: https://github.com/kaggle/docker-python
# For example, here are several helpful packages to load

import numpy as np # linear algebra
import pandas as pd # data processing, CSV file I/O (e.g. pd.read_csv)

# Input data files are available in the read-only "../input/" directory.
# Count the files under it as a quick sanity check that the data is attached. An image
# dataset can contain thousands of files, so a single count is reported instead of
# printing every path.

import os

input_file_count = sum(len(filenames) for _, _, filenames in os.walk('/kaggle/input'))
print(f"Found {input_file_count} files under /kaggle/input")

# You can write up to 20GB to the current directory (/kaggle/working/) that gets preserved as output when you create a version using "Save & Run All" 
# You can also write temporary files to /kaggle/temp/, but they won't be saved outside of the current session

# **INSTALLATION & IMPORTS**

In [None]:
print("Installing required packages...")
!pip install git+https://github.com/openai/CLIP.git -q
!pip install imageio[ffmpeg] -q

import os
import torch
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from PIL import Image, ImageDraw, ImageFont
import imageio
import warnings
from tqdm import tqdm
from sklearn.metrics import roc_auc_score
from torch.utils.data import Dataset, DataLoader
import clip

warnings.filterwarnings("ignore")

# Set seeds for reproducibility
torch.manual_seed(42)
np.random.seed(42)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

# **CONFIGURATION**

In [None]:
class Config:
    """Centralized configuration for easy tuning and reproducibility.

    Every setting is a class attribute, so values can be read from the class itself or
    from the shared ``config`` instance created right after the class.
    """

    # ---- Paths (Kaggle-specific) ----
    DATA_PATH = "/kaggle/input/mvtec-ad"
    GIF_OUTPUT_PATH = "/kaggle/working/anomaly_detection_comparison.gif"

    # ---- Model ----
    # "ViT-B/32" is the fast, low-memory option; "ViT-L/14" or "ViT-L/14@336px" trade
    # speed and GPU memory for accuracy.
    CLIP_MODEL = "ViT-B/32"

    # ---- Runtime ----
    DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
    BATCH_SIZE = 32

    # ---- Visualization ----
    NUM_GIF_FRAMES = 40

    # ---- Dataset / reporting ----
    # The 15 MVTec AD object and texture categories (alphabetical).
    CATEGORIES = (
        "bottle cable capsule carpet grid hazelnut leather metal_nut "
        "pill screw tile toothbrush transistor wood zipper"
    ).split()
    # Column names used when reporting per-method results.
    METHODS = ["AnomalyCLIP", "WinCLIP", "PA-CLIP", "AA/AF-CLIP"]

    # ---- Prompt templates shared by several detectors ----
    NORMAL_TEMPLATES = ["a photo of a {}", "a good photo of a {}", "a cropped photo of a {}"]
    ANOMALY_TEMPLATES = ["a photo of a {}", "a bad photo of a {}", "a defective {}"]

config = Config()

print(
    "Configuration loaded.",
    f"Device: {config.DEVICE}",
    f"CLIP Model: {config.CLIP_MODEL}",
    f"Categories: {len(config.CATEGORIES)}",
    sep="\n",
)

# **DATASET**

In [None]:
class MVTecADDataset(Dataset):
    """
    PyTorch Dataset for MVTec Anomaly Detection.

    Loads every image under ``<root_path>/<category>/<split>/``. Images in the ``good``
    sub-folder are labelled 0 (normal); images in every other sub-folder (one folder per
    defect type) are labelled 1 (anomalous). Normal images come first, followed by the
    defect folders in alphabetical order, so the sample order is the same on every
    filesystem.

    Args:
        root_path: dataset root containing one folder per category.
        category: category folder name, e.g. "bottle".
        split: split folder name (default "test").
        transform: optional callable applied to each RGB PIL image in ``__getitem__``.

    Raises:
        ValueError: if ``<root_path>/<category>/<split>`` does not exist.
    """

    # File suffixes treated as images (compared case-insensitively).
    IMAGE_EXTENSIONS = ('.png', '.jpg', '.jpeg')

    def __init__(self, root_path: str, category: str, split: str = "test", transform=None):
        self.root_path = root_path
        self.category = category
        self.split = split
        self.transform = transform

        self.image_paths = []
        self.labels = []  # 0: normal (good), 1: anomalous

        base_path = os.path.join(root_path, category, split)

        if not os.path.exists(base_path):
            raise ValueError(f"Path does not exist: {base_path}")

        # os.listdir order is filesystem-dependent; sort the defect folders so dataset order
        # (and anything derived from it, e.g. sampled GIF frames) is reproducible.
        defect_types = sorted(
            name for name in os.listdir(base_path)
            if name != "good" and os.path.isdir(os.path.join(base_path, name))
        )

        for folder, label in [("good", 0)] + [(defect, 1) for defect in defect_types]:
            for img_path in self._list_images(os.path.join(base_path, folder)):
                self.image_paths.append(img_path)
                self.labels.append(label)

        n_anomalous = sum(self.labels)
        print(f"{category} | {split} | Loaded {len(self.image_paths)} images "
              f"({n_anomalous} anomalous, {len(self.labels) - n_anomalous} normal)")

    @classmethod
    def _list_images(cls, folder: str) -> list:
        """Return the sorted image file paths inside ``folder`` ([] if it is not a directory)."""
        if not os.path.isdir(folder):
            return []
        return [
            os.path.join(folder, name)
            for name in sorted(os.listdir(folder))
            if name.lower().endswith(cls.IMAGE_EXTENSIONS)
        ]

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        """Return ``(image, label, path)``; ``image`` is transformed when a transform is set."""
        img_path = self.image_paths[idx]
        # The context manager closes the underlying file even if decoding fails.
        with Image.open(img_path) as img:
            image = img.convert("RGB")
        label = self.labels[idx]

        if self.transform:
            image = self.transform(image)

        return image, label, img_path

# **UTILITIES**

In [None]:
def get_text_features(model, texts: list) -> torch.Tensor:
    """Tokenize ``texts`` with CLIP, encode them with ``model`` and L2-normalize each row.

    Encoding runs without gradient tracking, and the tokens are moved to ``config.DEVICE``
    first. The result holds one unit-length embedding per prompt.
    """
    with torch.no_grad():
        encoded = model.encode_text(clip.tokenize(texts).to(config.DEVICE))
        encoded /= encoded.norm(dim=-1, keepdim=True)
    return encoded

def _unit_prototype(features: torch.Tensor) -> torch.Tensor:
    """Ensemble prompt embeddings of shape (K, D) (or (D,)) into one unit-length (D,) vector."""
    rows = features.float().reshape(-1, features.shape[-1])
    rows = torch.nn.functional.normalize(rows, dim=-1)
    # The mean of unit vectors is shorter than 1 by an amount that depends on how spread out
    # the prompts are; re-normalizing keeps normal and anomaly similarities on the same scale.
    return torch.nn.functional.normalize(rows.mean(dim=0), dim=0)


def compute_anomaly_score(image_features: torch.Tensor,
                          normal_features: torch.Tensor,
                          anomaly_features: torch.Tensor) -> np.ndarray:
    """
    Compute an image-level anomaly score from cosine similarities to text prototypes.

        score = ((1 - cos(image, normal)) + cos(image, anomaly)) / 2

    Higher score → more anomalous. The score lies in [-0.5, 1.5] and exceeds 0.5 exactly
    when the image is closer to the anomaly prototype than to the normal one.

    Args:
        image_features: (B, D) image embeddings (re-normalized here).
        normal_features: (K, D) or (D,) normal prompt embeddings; all rows are ensembled
            into one unit-length prototype.
        anomaly_features: same as ``normal_features``, for anomalous prompts.

    Returns:
        float32 NumPy array of shape (B,).
    """
    # Work in float32: fp16 CLIP outputs on GPU quantize the scores enough to create
    # spurious ties that degrade AUROC ranking.
    img = torch.nn.functional.normalize(image_features.float(), dim=-1)
    normal_proto = _unit_prototype(normal_features).to(img.device)
    anomaly_proto = _unit_prototype(anomaly_features).to(img.device)

    sim_normal = img @ normal_proto    # (B,)
    sim_anomaly = img @ anomaly_proto  # (B,)
    scores = (1.0 - sim_normal + sim_anomaly) / 2.0
    return scores.detach().cpu().numpy()

# **ANOMALY DETECTION METHODS**

In [None]:
class BaseAnomalyDetector:
    """Shared interface for the zero-shot CLIP anomaly detectors.

    Subclasses override ``detect`` to turn a batch of preprocessed images into one
    anomaly score per image for the given category.
    """

    def __init__(self, clip_model, preprocess):
        # Keep handles to the CLIP model and its matching image preprocessing transform.
        self.clip_model, self.preprocess = clip_model, preprocess

    def detect(self, images: torch.Tensor, category: str) -> np.ndarray:
        """Score ``images`` for ``category``; must be implemented by a subclass."""
        raise NotImplementedError

class AnomalyCLIPDetector(BaseAnomalyDetector):
    """Simple AnomalyCLIP-style prompting.

    Fills the shared ``config`` templates with "good <category>" (normal) and
    "defective <category>" (anomalous). Each prompt set is ensembled into one unit-length
    text prototype, and images are scored with ``compute_anomaly_score``. The prototypes
    depend only on the category, so they are encoded once per category and cached.
    """

    def __init__(self, clip_model, preprocess):
        super().__init__(clip_model, preprocess)
        # category -> (normal_prototype, anomaly_prototype)
        self._prototype_cache = {}

    @staticmethod
    def _prototype(features: torch.Tensor) -> torch.Tensor:
        """Average unit-length prompt embeddings (K, D) into a re-normalized (1, D) prototype."""
        # The mean of unit vectors is shorter than 1 by a prompt-set-dependent amount;
        # re-normalizing keeps the normal and anomaly similarities on the same scale.
        mean = features.mean(0, keepdim=True)
        return mean / mean.norm(dim=-1, keepdim=True)

    def _text_prototypes(self, category: str):
        """Return the cached (normal, anomaly) text prototypes for ``category``."""
        if category not in self._prototype_cache:
            # "metal_nut" -> "metal nut" so the tokenizer sees ordinary words.
            name = category.replace("_", " ")
            normal_texts = [t.format(f"good {name}") for t in config.NORMAL_TEMPLATES]
            anomaly_texts = [t.format(f"defective {name}") for t in config.ANOMALY_TEMPLATES]
            self._prototype_cache[category] = (
                self._prototype(get_text_features(self.clip_model, normal_texts)),
                self._prototype(get_text_features(self.clip_model, anomaly_texts)),
            )
        return self._prototype_cache[category]

    def detect(self, images: torch.Tensor, category: str) -> np.ndarray:
        """Return one anomaly score per image in the preprocessed batch ``images``."""
        normal_feat, anomaly_feat = self._text_prototypes(category)

        with torch.no_grad():
            img_feat = self.clip_model.encode_image(images)
            img_feat /= img_feat.norm(dim=-1, keepdim=True)

        return compute_anomaly_score(img_feat, normal_feat, anomaly_feat)

class WinCLIPDetector(BaseAnomalyDetector):
    """WinCLIP-style: rich multi-state prompts.

    Crosses the shared ``config`` templates with several normal or anomalous state words.
    Each prompt set is ensembled into one unit-length text prototype per category, and
    images are scored with ``compute_anomaly_score``. Prototypes are encoded once per
    category and cached.
    """

    NORMAL_STATES = ["good", "clean", "perfect", "flawless", "undamaged"]
    ANOMALY_STATES = ["broken", "damaged", "defective", "scratched", "cracked"]

    def __init__(self, clip_model, preprocess):
        super().__init__(clip_model, preprocess)
        # category -> (normal_prototype, anomaly_prototype)
        self._prototype_cache = {}

    @staticmethod
    def _prototype(features: torch.Tensor) -> torch.Tensor:
        """Average unit-length prompt embeddings (K, D) into a re-normalized (1, D) prototype."""
        # Re-normalize so prompt sets of different spread yield comparable similarities.
        mean = features.mean(0, keepdim=True)
        return mean / mean.norm(dim=-1, keepdim=True)

    def _text_prototypes(self, category: str):
        """Return the cached (normal, anomaly) text prototypes for ``category``."""
        if category not in self._prototype_cache:
            # "metal_nut" -> "metal nut" so the tokenizer sees ordinary words.
            name = category.replace("_", " ")
            normal_texts = [t.format(f"{state} {name}")
                            for t in config.NORMAL_TEMPLATES for state in self.NORMAL_STATES]
            anomaly_texts = [t.format(f"{state} {name}")
                             for t in config.ANOMALY_TEMPLATES for state in self.ANOMALY_STATES]
            self._prototype_cache[category] = (
                self._prototype(get_text_features(self.clip_model, normal_texts)),
                self._prototype(get_text_features(self.clip_model, anomaly_texts)),
            )
        return self._prototype_cache[category]

    def detect(self, images: torch.Tensor, category: str) -> np.ndarray:
        """Return one anomaly score per image in the preprocessed batch ``images``."""
        normal_feat, anomaly_feat = self._text_prototypes(category)

        with torch.no_grad():
            img_feat = self.clip_model.encode_image(images)
            img_feat /= img_feat.norm(dim=-1, keepdim=True)

        return compute_anomaly_score(img_feat, normal_feat, anomaly_feat)

class PACLIPDetector(BaseAnomalyDetector):
    """Pseudo-Anomaly Aware CLIP: enhances normal representation with pseudo-anomalies.

    The normal prototype is an equal-weight blend of two prototypes: one from clean
    prompts and one from "clean but perturbed" prompts (noise, shadow, slight
    deformation, clutter). Images are scored against this blend and an anomaly prototype
    via ``compute_anomaly_score``. Prototypes are encoded once per category and cached.
    """

    PSEUDO_ADDITIONS = ["with noise", "with shadow", "slightly deformed", "with background clutter"]

    def __init__(self, clip_model, preprocess):
        super().__init__(clip_model, preprocess)
        # category -> (enhanced_normal_prototype, anomaly_prototype)
        self._prototype_cache = {}

    @staticmethod
    def _prototype(features: torch.Tensor) -> torch.Tensor:
        """Average unit-length embeddings (K, D) into a re-normalized (1, D) prototype."""
        mean = features.mean(0, keepdim=True)
        return mean / mean.norm(dim=-1, keepdim=True)

    def _text_prototypes(self, category: str):
        """Return the cached (enhanced_normal, anomaly) text prototypes for ``category``."""
        if category not in self._prototype_cache:
            # "metal_nut" -> "metal nut" so the tokenizer sees ordinary words.
            name = category.replace("_", " ")
            normal_base = [f"a photo of a clean {name}", f"a flawless {name}"]
            pseudo_texts = [f"{nt} {pa}" for nt in normal_base for pa in self.PSEUDO_ADDITIONS]
            anomaly_texts = [
                f"a photo of a defective {name}",
                f"a broken {name}",
                f"a scratched {name}",
            ]

            normal_feat = self._prototype(get_text_features(self.clip_model, normal_base))
            pseudo_feat = self._prototype(get_text_features(self.clip_model, pseudo_texts))
            anomaly_feat = self._prototype(get_text_features(self.clip_model, anomaly_texts))

            # Enhanced normal representation: both inputs are unit length, so their mean
            # weights them equally; re-normalize the blend to unit length.
            enhanced_normal = self._prototype(torch.cat([normal_feat, pseudo_feat], dim=0))
            self._prototype_cache[category] = (enhanced_normal, anomaly_feat)
        return self._prototype_cache[category]

    def detect(self, images: torch.Tensor, category: str) -> np.ndarray:
        """Return one anomaly score per image in the preprocessed batch ``images``."""
        enhanced_normal, anomaly_feat = self._text_prototypes(category)

        with torch.no_grad():
            img_feat = self.clip_model.encode_image(images)
            img_feat /= img_feat.norm(dim=-1, keepdim=True)

        return compute_anomaly_score(img_feat, enhanced_normal, anomaly_feat)

class AAFCLIPDetector(BaseAnomalyDetector):
    """Anomaly-Aware / Anomaly-Focused CLIP with high-quality templates.

    Normal prompts combine "high quality photo" framing with normal-state adjectives.
    Anomaly prompts use the same framing without those adjectives, paired with a set of
    defect descriptions. Each prompt set is ensembled into one unit-length prototype,
    and prototypes are encoded once per category and cached.
    """

    QUALITY_NORMAL_TEMPLATES = [
        "a high quality photo of a normal {}",
        "a detailed photo of an undamaged {}",
        "a clear photo of a perfect {}",
        "a close-up of a flawless {}",
    ]
    # Same photographic framing, minus the normal/undamaged/perfect/flawless adjectives:
    # appending defects to those templates produced self-contradictory prompts such as
    # "a clear photo of a perfect bottle with crack".
    QUALITY_NEUTRAL_TEMPLATES = [
        "a high quality photo of a {}",
        "a detailed photo of a {}",
        "a clear photo of a {}",
        "a close-up of a {}",
    ]
    # Each entry is formatted with the category name before being placed in a template.
    ANOMALY_STATES = [
        "{} with defect", "{} with scratch", "{} with crack",
        "{} with contamination", "abnormal {}", "deformed {}",
    ]

    def __init__(self, clip_model, preprocess):
        super().__init__(clip_model, preprocess)
        # category -> (normal_prototype, anomaly_prototype)
        self._prototype_cache = {}

    @staticmethod
    def _prototype(features: torch.Tensor) -> torch.Tensor:
        """Average unit-length prompt embeddings (K, D) into a re-normalized (1, D) prototype."""
        mean = features.mean(0, keepdim=True)
        return mean / mean.norm(dim=-1, keepdim=True)

    def _text_prototypes(self, category: str):
        """Return the cached (normal, anomaly) text prototypes for ``category``."""
        if category not in self._prototype_cache:
            # "metal_nut" -> "metal nut" so the tokenizer sees ordinary words.
            name = category.replace("_", " ")
            normal_texts = [t.format(name) for t in self.QUALITY_NORMAL_TEMPLATES]
            anomaly_texts = [t.format(state.format(name))
                             for t in self.QUALITY_NEUTRAL_TEMPLATES
                             for state in self.ANOMALY_STATES]
            self._prototype_cache[category] = (
                self._prototype(get_text_features(self.clip_model, normal_texts)),
                self._prototype(get_text_features(self.clip_model, anomaly_texts)),
            )
        return self._prototype_cache[category]

    def detect(self, images: torch.Tensor, category: str) -> np.ndarray:
        """Return one anomaly score per image in the preprocessed batch ``images``."""
        normal_feat, anomaly_feat = self._text_prototypes(category)

        with torch.no_grad():
            img_feat = self.clip_model.encode_image(images)
            img_feat /= img_feat.norm(dim=-1, keepdim=True)

        return compute_anomaly_score(img_feat, normal_feat, anomaly_feat)

# **EVALUATION**

In [None]:
def evaluate_category(clip_model, preprocess, category: str, detectors: dict) -> dict:
    """Evaluate all methods on one MVTec AD category and return AUROC scores.

    Args:
        clip_model: the loaded CLIP model. It is not used in this body, since each
            detector holds its own reference; it is kept for call-site compatibility.
        preprocess: CLIP image transform applied to every test image.
        category: MVTec AD category folder name.
        detectors: mapping of method name -> object exposing ``detect(images, category)``.

    Returns:
        dict mapping method name -> image-level AUROC as an unrounded float (round only
        for display, so aggregates such as the mean stay exact). The value is NaN when
        the split does not contain both classes, because AUROC is undefined there.
    """

    print(f"\nEvaluating category: {category.upper()}")

    dataset = MVTecADDataset(config.DATA_PATH, category, "test", transform=preprocess)
    dataloader = DataLoader(dataset, batch_size=config.BATCH_SIZE, shuffle=False, num_workers=2)

    results = {name: [] for name in detectors}
    all_labels = []

    with torch.no_grad():
        for images, labels, _ in tqdm(dataloader, desc="Processing", leave=False):
            images = images.to(config.DEVICE)
            all_labels.extend(labels.tolist())

            for name, detector in detectors.items():
                results[name].extend(detector.detect(images, category))

    # roc_auc_score raises ValueError unless both labels are present (also when empty).
    has_both_classes = len(set(all_labels)) == 2
    if not has_both_classes:
        print(f"  Only one class present in {category}; AUROC is undefined (reported as NaN).")

    auroc_scores = {}
    for name, scores in results.items():
        auroc = float(roc_auc_score(all_labels, scores)) if has_both_classes else float("nan")
        auroc_scores[name] = auroc
        print(f"{name:15} AUROC = {auroc:.4f}")

    return auroc_scores

# **VISUALIZATION**

In [None]:
def plot_results(df: pd.DataFrame):
    """Generate professional plots: bar, box, heatmap.

    ``df`` is expected to hold a "Category" column plus one AUROC column per method.
    Each figure is saved at 300 dpi under /kaggle/working/ and then displayed.
    """
    sns.set(style="whitegrid", font_scale=1.2)

    title_style = {"fontsize": 16, "fontweight": "bold"}

    def save_and_show(fig, path):
        # Shared epilogue for every figure: tighten layout, write to disk, display.
        fig.tight_layout()
        fig.savefig(path, dpi=300, bbox_inches="tight")
        plt.show()

    # 1) Grouped bars: one group per category, one bar per method.
    long_df = df.melt(id_vars="Category", var_name="Method", value_name="AUROC")
    fig, ax = plt.subplots(figsize=(16, 8))
    sns.barplot(x="Category", y="AUROC", hue="Method", data=long_df, palette="viridis", ax=ax)
    ax.set_title("Zero-Shot Anomaly Detection AUROC per Category", **title_style)
    ax.set_ylabel("AUROC")
    ax.set_ylim(0, 1.05)
    plt.setp(ax.get_xticklabels(), rotation=45, ha="right")
    ax.legend(title="Method", bbox_to_anchor=(1.02, 1), loc="upper left")
    save_and_show(fig, "/kaggle/working/bar_plot.png")

    # 2) Per-method spread of AUROC across categories.
    fig, ax = plt.subplots(figsize=(10, 6))
    sns.boxplot(data=df.drop(columns="Category"), palette="Set2", ax=ax)
    ax.set_title("AUROC Distribution Across Categories", **title_style)
    ax.set_ylabel("AUROC")
    ax.set_ylim(0, 1.05)
    save_and_show(fig, "/kaggle/working/box_plot.png")

    # 3) Category x method grid with annotated values.
    fig, ax = plt.subplots(figsize=(10, 12))
    sns.heatmap(df.set_index("Category"), annot=True, cmap="YlOrRd", fmt=".3f",
                linewidths=.5, cbar_kws={"shrink": 0.8}, ax=ax)
    ax.set_title("AUROC Heatmap", **title_style)
    save_and_show(fig, "/kaggle/working/heatmap.png")

def create_comparison_gif(clip_model, preprocess):
    """Generate animated GIF comparing all methods on random test samples.

    Draws up to three random test images per category using the (seeded) global NumPy
    RNG, shuffles the pool and keeps the first ``config.NUM_GIF_FRAMES``. Each kept image
    becomes one frame showing every detector's score. The GIF is written to
    ``config.GIF_OUTPUT_PATH``; nothing is written if no test images are found.
    """

    print("\nGenerating comparison GIF...")

    detectors = {
        "AnomalyCLIP": AnomalyCLIPDetector(clip_model, preprocess),
        "WinCLIP": WinCLIPDetector(clip_model, preprocess),
        "PA-CLIP": PACLIPDetector(clip_model, preprocess),
        "AA/AF-CLIP": AAFCLIPDetector(clip_model, preprocess)
    }

    # Collect random samples
    samples = []
    for cat in config.CATEGORIES:
        dataset = MVTecADDataset(config.DATA_PATH, cat, "test", transform=None)
        if len(dataset) == 0:
            continue
        # The dataset lists every 'good' image before any defect, so taking the first
        # indices would only ever show normal samples; draw random indices instead.
        picks = np.random.choice(len(dataset), size=min(3, len(dataset)), replace=False)
        samples.extend((dataset.image_paths[i], cat, dataset.labels[i]) for i in picks)

    np.random.shuffle(samples)
    selected = samples[:config.NUM_GIF_FRAMES]
    if not selected:
        print("No test images found; skipping GIF.")
        return

    frames = []
    for img_path, cat, label in tqdm(selected, desc="Creating frames"):
        # Decode each file once (closing it immediately) and reuse the RGB copy both as
        # model input and as the displayed thumbnail.
        with Image.open(img_path) as raw:
            rgb = raw.convert("RGB")
        tensor_img = preprocess(rgb).unsqueeze(0).to(config.DEVICE)

        scores = {name: detector.detect(tensor_img, cat)[0] for name, detector in detectors.items()}

        frame = create_gif_frame(rgb.resize((224, 224)), scores, cat, label)
        frames.append(np.array(frame))

    # Recent imageio releases write GIFs through the Pillow plugin, which rejects `fps` and
    # takes a per-frame `duration` in milliseconds; 500 ms reproduces the intended 2 fps.
    imageio.mimsave(config.GIF_OUTPUT_PATH, frames, duration=500, loop=0)
    print(f"GIF saved: {config.GIF_OUTPUT_PATH}")

def create_gif_frame(pil_img: Image.Image, scores: dict, category: str, label: int,
                     threshold: float = 0.5) -> Image.Image:
    """Create one frame for the comparison GIF.

    The frame contains:
    - a title with the category and ground truth;
    - one bordered copy of ``pil_img`` per method, annotated with its score and
      predicted class (score > ``threshold`` → ANOMALY with a red border, otherwise
      NORMAL with a green border);
    - a colour legend.

    Args:
        pil_img: RGB image to display (already resized by the caller).
        scores: mapping method name -> anomaly score.
        category: category name shown in the title.
        label: ground truth; truthy means anomalous.
        threshold: decision threshold for the predicted class (default 0.5).

    Returns:
        The composed RGB frame.
    """

    w, h = pil_img.size
    margin = 20
    # One (w + margin)-wide slot per method, but never fewer than four slots so the title
    # and legend have room. With four methods this equals the original width of 4*w + 100.
    frame_w = margin + max(len(scores), 4) * (w + margin)
    frame_h = h + 200
    frame = Image.new("RGB", (frame_w, frame_h), "white")
    draw = ImageDraw.Draw(frame)

    font_path = "/usr/share/fonts/truetype/dejavu/DejaVuSans-Bold.ttf"
    try:
        font = ImageFont.truetype(font_path, 24)
        title_font = ImageFont.truetype(font_path, 32)
    except (OSError, ImportError):
        # Font file missing (OSError) or FreeType support unavailable (ImportError):
        # fall back to Pillow's built-in font instead of swallowing every error.
        font = ImageFont.load_default()
        title_font = ImageFont.load_default()

    title = f"Category: {category.upper()} | Ground Truth: {'ANOMALY' if label else 'NORMAL'}"
    draw.text((20, 20), title, fill="black", font=title_font)
    draw.text((20, 70), "Zero-Shot CLIP-based Anomaly Detection Comparison", fill="gray", font=font)

    for i, (method, score) in enumerate(scores.items()):
        is_anomaly = score > threshold
        bordered = Image.new("RGB", (w + 12, h + 12), "red" if is_anomaly else "green")
        bordered.paste(pil_img, (6, 6))

        d = ImageDraw.Draw(bordered)
        text_lines = [method, f"Score: {score:.3f}", "ANOMALY" if is_anomaly else "NORMAL"]
        for j, line in enumerate(text_lines):
            d.text((10, 10 + 30 * j), line, fill="white", font=font,
                   stroke_fill="black", stroke_width=2)

        frame.paste(bordered, (margin + i * (w + margin), 120))

    # Legend
    y_legend = frame_h - 50
    draw.rectangle([20, y_legend, 40, y_legend+20], fill="green")
    draw.text((50, y_legend), "Predicted Normal", fill="black", font=font)
    draw.rectangle([220, y_legend, 240, y_legend+20], fill="red")
    draw.text((250, y_legend), "Predicted Anomaly", fill="black", font=font)

    return frame

# **MAIN EXECUTION**

In [None]:
def main():
    """Run the full zero-shot benchmark on MVTec AD.

    Steps:
    1. Load CLIP and evaluate every detector on every category in ``config.CATEGORIES``.
    2. Append a MEAN row and write the results CSV to /kaggle/working/.
    3. Produce the plots (``plot_results``) and the comparison GIF (``create_comparison_gif``).
    """
    print("\n" + "="*80)
    print("ZERO-SHOT ANOMALY DETECTION ON MVTEC AD - PROFESSIONAL IMPLEMENTATION")
    print("="*80)

    # Load CLIP
    print("\nLoading CLIP model...")
    clip_model, preprocess = clip.load(config.CLIP_MODEL, device=config.DEVICE)
    clip_model.eval()
    print(f"CLIP {config.CLIP_MODEL} loaded on {config.DEVICE}")

    # Initialize detectors
    detectors = {
        "AnomalyCLIP": AnomalyCLIPDetector(clip_model, preprocess),
        "WinCLIP": WinCLIPDetector(clip_model, preprocess),
        "PA-CLIP": PACLIPDetector(clip_model, preprocess),
        "AA/AF-CLIP": AAFCLIPDetector(clip_model, preprocess)
    }
    # Derive the method columns from the detectors actually run, so the MEAN row cannot
    # drift out of sync with a separately maintained name list.
    method_names = list(detectors)

    # Evaluate all categories ("Category" first so the table and CSV read left-to-right)
    all_results = []
    for category in config.CATEGORIES:
        aurocs = evaluate_category(clip_model, preprocess, category, detectors)
        all_results.append({"Category": category, **aurocs})

    # Create DataFrame; pandas' mean skips NaN (e.g. a category with undefined AUROC).
    results_df = pd.DataFrame(all_results)
    mean_row = pd.DataFrame([{
        "Category": "MEAN",
        **results_df[method_names].mean().to_dict()
    }])
    results_df = pd.concat([results_df, mean_row], ignore_index=True)

    # Save and display results
    csv_path = "/kaggle/working/mvtec_ad_zero_shot_results.csv"
    results_df.to_csv(csv_path, index=False)

    print("\n" + "="*80)
    print("FINAL RESULTS (Image-level AUROC)")
    print("="*80)
    print(results_df.round(4).to_string(index=False))
    print(f"\nResults saved to: {csv_path}")

    # Visualizations
    plot_results(results_df.iloc[:-1])  # Exclude mean row from plots

    # GIF
    create_comparison_gif(clip_model, preprocess)

    print("\n" + "="*80)
    print("EXPERIMENT COMPLETED SUCCESSFULLY")
    print("="*80)
    print("Outputs:")
    print(f"  • CSV Results: {csv_path}")
    print(f"  • GIF Comparison: {config.GIF_OUTPUT_PATH}")
    print("  • Plots: bar_plot.png, box_plot.png, heatmap.png")
    print("="*80)

# **RUN**

In [None]:
if __name__ == "__main__":
    # Run the benchmark only when the dataset is attached; otherwise print setup instructions.
    if os.path.exists(config.DATA_PATH):
        main()
    else:
        print(
            f"ERROR: Dataset not found at {config.DATA_PATH}",
            "Please add the 'MVTec AD' dataset to your Kaggle notebook:",
            "   → Add Data → Search 'mvtec ad' → Add",
            sep="\n",
        )