# CRBL Anomaly Detection - PyTorch Inference

이 노트북은 PyTorch로 이식된 CRBL 이상 탐지 모델을 사용한 추론 코드입니다.


In [None]:
import os 
import numpy as np 
import matplotlib.pyplot as plt
import pandas as pd 
import torch
import torch.nn as nn
import torchvision.transforms as transforms
from PIL import Image
import cv2 
from pathlib import Path

# Import our custom modules
from model_pytorch import create_model
from loadDataset_generator_pytorch import crop_image, origin_image, TestDataset

# --- Inference configuration ---
# Side length of the square model input and the channels-first shape built from it.
INPUT_DIM = 128
INPUT_SHAPE = (3, INPUT_DIM, INPUT_DIM)

# Score cut-offs used to binarize predictions for cropped and original images.
CROPPED_THRESHOLD = 0.01
FULL_THRESHOLD = 0.5

# Checkpoint file, evaluation CSV and image folder used by the cells below.
weight_path = "./weights/CRBL_250328_pytorch.pth"
test_csv_path = "./data/csv/valid.csv"
image_path = "./data/images"

# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print(f"Using device: {device}")


In [None]:
class InferenceModel:
    """Inference wrapper for the CRBL model.

    Builds the network with ``create_model``, loads a saved state dict into
    it, moves it to ``device`` and switches it to eval mode. Images may be
    given as file paths, numpy arrays or ready-made tensors.
    """

    def __init__(self, model_path, input_shape, device):
        """Create the model and load its weights.

        Args:
            model_path: Path of the state-dict checkpoint to load.
            input_shape: Shape forwarded to ``create_model``; also kept on
                the instance as ``self.input_shape``.
            device: ``torch.device`` on which the model and inputs live.
        """
        self.device = device
        self.input_shape = input_shape

        # Load model
        self.model = create_model(input_shape=input_shape, num_classes=1)
        # NOTE(security): torch.load unpickles the file, so only load
        # checkpoints from a trusted source (recent PyTorch versions offer
        # ``weights_only=True`` for plain state dicts).
        self.model.load_state_dict(torch.load(model_path, map_location=device))
        self.model.to(device)
        self.model.eval()

        # Define transforms: ndarray -> PIL image -> float tensor
        self.transform = transforms.Compose([
            transforms.ToPILImage(),
            transforms.ToTensor(),
        ])

    @staticmethod
    def _read_rgb(path):
        """Read an image file with OpenCV and return it in RGB channel order.

        Raises:
            FileNotFoundError: If OpenCV cannot read the file (``cv2.imread``
                returns ``None`` instead of raising).
        """
        image = cv2.imread(os.fspath(path))
        if image is None:
            raise FileNotFoundError(f"Could not read image file: {path}")
        return cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

    def _to_tensor(self, image):
        """Turn a path or numpy image into a tensor; pass other inputs through."""
        if isinstance(image, (str, os.PathLike)):
            image = self._read_rgb(image)
        if isinstance(image, np.ndarray):
            image = self.transform(image)
        return image

    def preprocess_image(self, image):
        """Preprocess a single image for inference.

        Args:
            image: File path, numpy image, or an already converted tensor.

        Returns:
            The image tensor with a leading batch dimension, on ``self.device``.
        """
        image = self._to_tensor(image)

        # Add batch dimension
        image = image.unsqueeze(0)
        return image.to(self.device)

    def predict_single(self, image):
        """Predict the anomaly score for a single image.

        Returns:
            The model output for ``image`` as a Python float.
        """
        with torch.no_grad():
            output = self.model(self.preprocess_image(image))
        return output.squeeze().cpu().item()

    def predict_batch(self, images, batch_size=None):
        """Predict anomaly scores for a batch of images.

        Args:
            images: Either a list of images (paths, numpy images or tensors)
                or a numpy array that is transposed with ``(0, 3, 1, 2)``
                before being fed to the model.
            batch_size: Optional maximum number of images per forward pass.
                ``None`` (the default) runs everything in one pass.

        Returns:
            A numpy array with one score per image. A single-image batch
            yields an array of length 1 (not a 0-d array), and an empty
            input yields an empty array.

        Raises:
            ValueError: If ``batch_size`` is given and smaller than 1.
        """
        if batch_size is not None and batch_size < 1:
            raise ValueError("batch_size must be a positive integer")

        total = len(images)
        if total == 0:
            return np.empty((0,), dtype=np.float32)
        step = total if batch_size is None else int(batch_size)

        results = []
        with torch.no_grad():
            for start in range(0, total, step):
                chunk = images[start:start + step]
                if isinstance(chunk, list):
                    # Process list of images
                    batch_tensor = torch.stack(
                        [self._to_tensor(img) for img in chunk]
                    ).to(self.device)
                else:
                    # Process numpy array batch.
                    # NOTE(review): unlike the list path (ToTensor), this path
                    # does not rescale values; confirm the arrays are already
                    # in the range the model expects.
                    batch_tensor = torch.from_numpy(
                        chunk.transpose(0, 3, 1, 2)
                    ).float().to(self.device)

                outputs = self.model(batch_tensor)
                # Keep the batch axis: a bare squeeze() would collapse a
                # single-sample batch into a 0-d array that cannot be indexed.
                scores = outputs.reshape(outputs.shape[0], -1)
                if scores.shape[1] == 1:
                    scores = scores.squeeze(1)
                results.append(scores.cpu().numpy())

        return np.concatenate(results)


In [None]:
# Build the inference wrapper from the saved checkpoint
print("Loading inference model...")
inference_model = InferenceModel(
    model_path=weight_path,
    input_shape=INPUT_SHAPE,
    device=device,
)

# Read the test set in its cropped and original variants (images, labels)
print("Loading test data...")
cropped_set = crop_image(test_csv_path, image_path, INPUT_DIM)
original_set = origin_image(test_csv_path, image_path, INPUT_DIM)
X_cropped, y_cropped = cropped_set
X_original, y_original = original_set

print(f"Test samples: {len(X_cropped)}")


In [None]:
# Predict on test data
print("Predicting on cropped images...")
predictions_cropped = inference_model.predict_batch(X_cropped)

print("Predicting on original images...")
predictions_original = inference_model.predict_batch(X_original)

# Flatten scores and labels to 1-D before comparing them. Without this, a
# label array shaped (N, 1) would broadcast against (N,) predictions into an
# (N, N) comparison and silently report a wrong accuracy, and a 0-d score
# (single-sample batch) could not be indexed by the cells below.
predictions_cropped = np.asarray(predictions_cropped).reshape(-1)
predictions_original = np.asarray(predictions_original).reshape(-1)
labels_cropped = np.asarray(y_cropped).reshape(-1)
labels_original = np.asarray(y_original).reshape(-1)
if (len(predictions_cropped) != len(labels_cropped)
        or len(predictions_original) != len(labels_original)):
    raise ValueError("Number of predictions does not match number of labels")

# Calculate accuracy: a score above the threshold counts as class 1
pred_binary_cropped = (predictions_cropped > CROPPED_THRESHOLD).astype(int)
accuracy_cropped = np.mean(pred_binary_cropped == labels_cropped)

pred_binary_original = (predictions_original > FULL_THRESHOLD).astype(int)
accuracy_original = np.mean(pred_binary_original == labels_original)

print("\nResults:")
print(f"Cropped images - Accuracy: {accuracy_cropped:.4f}")
print(f"Original images - Accuracy: {accuracy_original:.4f}")


In [None]:
# Visualize predictions: up to 10 samples, cropped on the top row and
# original on the bottom row. The count is limited by both image sets so
# neither row is indexed past its end.
num_shown = min(10, len(X_cropped), len(X_original))

if num_shown == 0:
    # plt.subplots rejects a grid with zero columns, so skip plotting.
    print("No samples to visualize.")
else:
    # squeeze=False keeps `axes` two-dimensional even when only one sample is
    # shown; otherwise axes[row, col] raises for a single column.
    fig, axes = plt.subplots(2, num_shown, figsize=(20, 6), squeeze=False)

    panels = (
        ("Cropped", X_cropped, predictions_cropped, y_cropped),
        ("Original", X_original, predictions_original, y_original),
    )
    for row, (label, images, scores, truths) in enumerate(panels):
        # Guard against a 0-d score array from a single-sample prediction.
        scores = np.atleast_1d(scores)
        for col in range(num_shown):
            ax = axes[row, col]
            ax.imshow(images[col])
            ax.set_title(f'{label}\nPred: {scores[col]:.3f}\nTrue: {truths[col]}')
            ax.axis('off')

    plt.tight_layout()
    plt.show()


In [None]:
# Per-sample report: score, label and whether the thresholded decision
# matches the label, first for the cropped set and then for the original set.
print("Detailed Results:")
report_groups = (
    ("Cropped images:", predictions_cropped, y_cropped, pred_binary_cropped),
    ("\nOriginal images:", predictions_original, y_original, pred_binary_original),
)
for heading, scores, labels, decisions in report_groups:
    print(heading)
    for idx, (score, label, decision) in enumerate(zip(scores, labels, decisions)):
        print(f"  Sample {idx}: Prediction={score:.4f}, True={label}, Correct={decision == label}")

print("\nInference completed successfully!")
