In [1]:
import os
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision.datasets import ImageFolder
from torchvision.transforms import transforms
from torch.utils.data import DataLoader

In [2]:
carpet_root = 'D:/MvTec/mvtec_anomaly_detection/capsule'  

test_dir = os.path.join(carpet_root, 'test')
train_dir = os.path.join(carpet_root, 'train')
ground_truth_dir = os.path.join(carpet_root, 'ground_truth')


In [8]:

# transformation images and convert to tensors
transform = transforms.Compose([
    transforms.Resize((256, 256)),
    transforms.ToTensor(),  
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) 
])

In [5]:

# Load the test dataset
test_dataset = ImageFolder(test_dir, transform=transform)
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=32, shuffle=False)

# Load the train dataset
train_dataset = ImageFolder(train_dir, transform=transform)
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=32, shuffle=True)

# Load the ground truth dataset
ground_truth_dataset = ImageFolder(ground_truth_dir, transform=transform)
ground_truth_loader = torch.utils.data.DataLoader(ground_truth_dataset, batch_size=32, shuffle=False)

In [6]:
# Define the defect labels
defect_labels = ['crack', 'faulty_imprint', 'poke', 'scratch', 'squeeze']


In [41]:
defect_attributes = {
    'crack': ['visible', 'thin', 'narrow'],
    'faulty_imprint': ['misaligned', 'distorted', 'uneven'],
    'poke': ['deep', 'sharp', 'small'],
    'scratch': ['visible', 'linear', 'surface'],
    'squeeze': ['deformed', 'compressed', 'misshapen']
}


In [7]:
# Initialize variables for counting correct predictions
correct_predictions = 0
total_samples = 0
correct_predictions_by_defect = {label: 0 for label in defect_labels}
total_samples_by_defect = {label: 0 for label in defect_labels}


In [43]:
def attribute_classifier(attributes):
    # Define the attribute-based classification rules
    if 'visible' in attributes:
        return 'good'
    elif 'thin' in attributes and 'long' in attributes and 'narrow' in attributes:
        return 'crack'
    elif 'irregular' in attributes and 'misaligned' in attributes and 'distorted' in attributes:
        return 'faulty_imprint'
    elif 'deep' in attributes and 'sharp'in attributes and 'small' in attributes:
        return 'poke'
    elif 'visible' in attributes and 'linear' in attributes and 'surface' in attributes:
        return 'scratch'
    elif 'deformed' in attributes and  'compressed' in attributes and  'misshapen' in attributes:
        return 'squeeze'
    else:
        return 'unknown'

In [39]:
class ZeroShotDefectClassifier(nn.Module):
    def __init__(self, num_attributes, num_classes):
        super(ZeroShotDefectClassifier, self).__init__()
        self.features = nn.Sequential(
            # feature extraction layers
            nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=2, stride=2),
            nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=2, stride=2),
            nn.Conv2d(128, 256, kernel_size=3, stride=1, padding=1),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=2, stride=2),
        )
        self.avgpool = nn.AdaptiveAvgPool2d((4, 4))  # Adjust output size
        self.classifier = nn.Linear(256 * 4 * 4, num_attributes)  # Adjust input size
        self.fc = nn.Linear(num_attributes, num_classes)

    def forward(self, x):
        x = self.features(x)
        x = self.avgpool(x)
        x = torch.flatten(x, 1)
        x = self.classifier(x)
        x = self.fc(x)
        return x


# Model Parameters
num_attributes = 10  
num_classes = 6  
model = ZeroShotDefectClassifier(num_attributes, num_classes)

# Save Model
torch.save(model.state_dict(), "model_weights.pth")

# Loss Function
criterion = nn.CrossEntropyLoss()

# Optimizer
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

# Training loop
num_epochs = 10
for epoch in range(num_epochs):
    running_loss = 0.0
    for images, labels in train_loader:
        images = images.to(device)
        labels = labels.to(device)
        
        # Zero the gradients
        optimizer.zero_grad()
        
        # Forward pass
        outputs = model(images)
        
        # Compute the loss
        loss = criterion(outputs, labels)
        
        # Backward pass
        loss.backward()
        
        # Update the weights
        optimizer.step()
        
        # Accumulate the loss
        running_loss += loss.item()
    
    # Print the average loss for the epoch
    epoch_loss = running_loss / len(train_loader)
    print(f"Epoch [{epoch+1}/{num_epochs}], Loss: {epoch_loss:.4f}")


Epoch [1/10], Loss: 0.4574
Epoch [2/10], Loss: 0.0000
Epoch [3/10], Loss: 0.0000
Epoch [4/10], Loss: 0.0000
Epoch [5/10], Loss: 0.0000
Epoch [6/10], Loss: 0.0000
Epoch [7/10], Loss: 0.0000
Epoch [8/10], Loss: 0.0000
Epoch [9/10], Loss: 0.0000
Epoch [10/10], Loss: 0.0000


In [55]:
# Convert predicted labels to numeric indices
defect_label_map = {'good': 0, 'crack': 1, 'faulty_imprint': 2, 'poke': 3, 'scratch': 4, 'squeeze': 5, 'unknown': 6}

# Set the model to evaluation mode
model.eval()

# Initialize variables for evaluation
correct_predictions = 0
total_samples = 0

# Iterate over the test data and labels
for images, labels in test_loader:
    images = images.to(device)
    labels = labels.to(device)

    # Forward pass
    outputs = model(images)
    _, predicted = torch.max(outputs, 1)

    # Perform attribute-based classification
    predicted_labels = []
    for attributes in defect_attributes.values():
        predicted_label = attribute_classifier(attributes)
        predicted_labels.append(predicted_label)

    predicted_labels_indices = [defect_label_map[label] for label in predicted_labels]
    predicted_labels_tensor = torch.tensor(predicted_labels_indices).to(device)

    # one-hot encoded 
    predicted_labels_onehot = torch.nn.functional.one_hot(predicted_labels_tensor, num_classes=len(defect_label_map)).float()

    # Update the evaluation metrics
    correct_predictions += torch.sum(torch.argmax(predicted_labels_onehot, dim=1) == labels.unsqueeze(1)).item()
    total_samples += labels.size(0)

accuracy = correct_predictions / total_samples

print(f"Accuracy: {accuracy:.4f}")


Accuracy: 0.6325
