## 1. Set Up Paths & Imports

In [None]:
import os
import torch
import numpy as np
from pathlib import Path
import matplotlib.pyplot as plt
from anomalib.models import Patchcore
from anomalib.data import Folder
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.loggers import TensorBoardLogger

# Choose where tensors would live: the CUDA device when one is visible,
# otherwise the CPU.
_has_cuda = torch.cuda.is_available()
DEVICE = torch.device("cuda" if _has_cuda else "cpu")
print(f"Device: {DEVICE}")

# Filesystem layout used by the rest of the notebook.
DATASET_ROOT = Path("dataset")
TRAIN_GOOD = DATASET_ROOT.joinpath("train", "good")
CHECKPOINT_DIR = Path("checkpoints")
CHECKPOINT_DIR.mkdir(exist_ok=True)  # no error if it already exists

# Echo the resolved absolute locations so the user can verify them.
for _label, _path in (
    ("\nDataset", DATASET_ROOT),
    ("Train folder", TRAIN_GOOD),
    ("Checkpoints will be saved to", CHECKPOINT_DIR),
):
    print(f"{_label}: {_path.absolute()}")

## 2. Set Up Data Module

In [None]:
# Data-module settings, defined once so the summary printed below always
# matches what is actually passed to Folder.
IMAGE_SIZE = 224
BATCH_SIZE = 32
VAL_SPLIT = 0.2  # fraction of the good images held out for validation

# Configure the Anomalib Folder datamodule.
# Only train/good is used as normal training data; test/defect supplies
# abnormal samples for validation/testing (optional).
# NOTE(review): Folder's keyword names differ between anomalib releases
# (e.g. `batch_size` vs `train_batch_size`/`eval_batch_size`, and
# `train_val_split` vs `val_split_mode`/`val_split_ratio`) — confirm these
# against the installed anomalib version.
datamodule = Folder(
    root=str(DATASET_ROOT),
    normal_dir="train/good",  # train on good images only
    abnormal_dir="test/defect",  # defects for validation/testing (optional)
    task="classification",  # or "segmentation" if pixel-level anomalies
    image_size=IMAGE_SIZE,
    batch_size=BATCH_SIZE,
    num_workers=0,  # set to 0 on Windows; increase on Linux/Mac
    train_val_split=VAL_SPLIT,
    seed=42,
)

# Derive the printed percentages from VAL_SPLIT instead of hard-coding them.
_train_pct = round((1 - VAL_SPLIT) * 100)
print("✓ Data module configured")
print(f"  - Image size: {IMAGE_SIZE}x{IMAGE_SIZE}")
print(f"  - Batch size: {BATCH_SIZE}")
print(f"  - Train/val split: {_train_pct}/{100 - _train_pct}")

## 3. Initialize Patchcore Model

In [None]:
# PatchCore hyper-parameters, defined once so the printed summary cannot
# drift from what the model is actually built with.
# Backbone options: resnet18, resnet50, wide_resnet50_2 (recommended), vgg16, etc.
BACKBONE = "wide_resnet50_2"  # strong backbone for feature extraction
LAYERS = ["layer2", "layer3"]  # intermediate layers for local anomalies
NUM_NEIGHBORS = 9  # number of neighbors for patch correlation

# Initialize the Patchcore model.
# NOTE(review): in some anomalib releases score normalization is configured
# outside the model (engine/callbacks) — confirm that `normalization_method`
# is accepted by the installed Patchcore constructor.
model = Patchcore(
    backbone=BACKBONE,
    layers=LAYERS,
    num_neighbors=NUM_NEIGHBORS,
    normalization_method="min_max",  # normalize anomaly scores to [0, 1]
)

print("✓ Patchcore model initialized")
print(f"  - Backbone: {BACKBONE}")
print(f"  - Layers: {', '.join(LAYERS)}")
print(f"  - Num neighbors: {NUM_NEIGHBORS}")

## 4. Set Up PyTorch Lightning Trainer

In [None]:
# Configure logger: TensorBoard files go under logs/patchcore_training/v1.
logger = TensorBoardLogger(
    save_dir="logs",
    name="patchcore_training",
    version="v1",
)

# Configure checkpoint callback.
# Training runs for a single epoch (MAX_EPOCHS below), so there is only one
# candidate checkpoint and "best by metric" selection adds nothing. The
# previously monitored key "val_anomaly_map_auroc" names an anomaly-map
# (pixel-level) metric that is not produced for task="classification";
# Lightning's ModelCheckpoint raises when a monitored key is missing after
# validation. Without `monitor`, the checkpoint of the latest epoch is kept.
checkpoint_callback = ModelCheckpoint(
    dirpath=CHECKPOINT_DIR,
    filename="patchcore-{epoch:02d}",
    save_last=True,
    verbose=True,
)

MAX_EPOCHS = 1  # Patchcore typically trains in 1 epoch
_accelerator = "gpu" if torch.cuda.is_available() else "cpu"

# Create trainer.
trainer = Trainer(
    max_epochs=MAX_EPOCHS,
    accelerator=_accelerator,
    devices=1,
    logger=logger,
    callbacks=[checkpoint_callback],
    # anomalib's own PatchCore trainer arguments disable the sanity check:
    # the memory bank is only filled during the training epoch, so a
    # validation pass before training has nothing to compare patches against.
    num_sanity_val_steps=0,
    enable_progress_bar=True,
    enable_model_summary=True,
)

print("✓ Trainer configured")
print(f"  - Max epochs: {MAX_EPOCHS}")
print(f"  - Accelerator: {_accelerator.upper()}")
print("  - Logger: TensorBoard")

## 5. Train Model

In [None]:
_RULE = "=" * 50  # horizontal banner line

print("\n" + _RULE)
print("TRAINING PATCHCORE")
print(_RULE + "\n")

# Fit the model on the datamodule configured above.
trainer.fit(model, datamodule=datamodule)

print("\n" + _RULE)
print("✓ TRAINING COMPLETE")
print(_RULE)

## 6. Save Trained Model

In [None]:
# Save the full Lightning checkpoint (model plus trainer state) for inference.
model_save_path = CHECKPOINT_DIR / "patchcore_trained.ckpt"
trainer.save_checkpoint(model_save_path)

_ckpt_mb = model_save_path.stat().st_size / (1024 ** 2)
print(f"✓ Model saved to: {model_save_path}")
print(f"  Checkpoint size: {_ckpt_mb:.2f} MB")

# Also save the bare state_dict (parameters and registered buffers only,
# no trainer state).
model_weights_path = CHECKPOINT_DIR / "patchcore_weights.pth"
torch.save(model.state_dict(), model_weights_path)
print(f"✓ Weights saved to: {model_weights_path}")

## 7. Summary

In [None]:
_RULE = "=" * 50  # horizontal banner line

print("\n" + _RULE)
print("TRAINING SUMMARY")
print(_RULE)
print("\n✓ Model: Patchcore (wide_resnet50_2)")
print(f"✓ Training data: {TRAIN_GOOD} (good images only)")
print("✓ Training mode: Unsupervised (learns normal distribution)")
print(f"✓ Checkpoint: {model_save_path}")
# Read the log directory from the logger instead of repeating its config.
print(f"✓ Logs: {logger.log_dir} (TensorBoard)")
print("\n🚀 Next: Run 03_anomalib_evaluate.ipynb to test on good/defect images")
print(_RULE)