# Wafer Defect Classification using Vision Transformer (ViT)

This notebook implements wafer defect classification using pretrained Vision Transformer models.
We'll use the `timm` library for access to state-of-the-art pretrained ViT models.

In [1]:
# Install required packages
!pip install timm torchvision transformers
!pip install torchsummary scikit-learn

Defaulting to user installation because normal site-packages is not writeable
Defaulting to user installation because normal site-packages is not writeable


In [2]:
# Import libraries
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from PIL import Image
import warnings
warnings.filterwarnings('ignore')

# PyTorch imports
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader, TensorDataset, random_split, SubsetRandomSampler
import torchvision.transforms as transforms

# Transformer and model imports
import timm
from torchsummary import summary

# ML imports
from sklearn.model_selection import KFold, train_test_split
from sklearn.metrics import classification_report, confusion_matrix, accuracy_score, f1_score

# Set random seeds for reproducibility
torch.manual_seed(42)
np.random.seed(42)

# Set style
sns.set_theme(style="whitegrid")
%matplotlib inline

print(f"PyTorch version: {torch.__version__}")
print(f"TIMM version: {timm.__version__}")
print(f"CUDA available: {torch.cuda.is_available()}")

PyTorch version: 2.0.0+cpu
TIMM version: 1.0.21
CUDA available: False


## Data Loading and Preprocessing
Using the same data preprocessing pipeline as the original notebook but adapted for Vision Transformers

In [3]:
# Load dataset
df = pd.read_pickle("MIR-WM811K/Python/WM811K.pkl")
print(f"Dataset shape: {df.shape}")
df.info()

Dataset shape: (811457, 6)
<class 'pandas.core.frame.DataFrame'>
RangeIndex: 811457 entries, 0 to 811456
Data columns (total 6 columns):
 #   Column          Non-Null Count   Dtype  
---  ------          --------------   -----  
 0   dieSize         811457 non-null  float64
 1   failureType     811457 non-null  object 
 2   lotName         811457 non-null  object 
 3   trainTestLabel  811457 non-null  object 
 4   waferIndex      811457 non-null  float64
 5   waferMap        811457 non-null  object 
dtypes: float64(2), object(4)
memory usage: 37.1+ MB


In [4]:
# Data preprocessing - same as original but optimized
def preprocess_data(df):
    # Drop waferIndex column
    df = df.drop(['waferIndex'], axis=1)
    
    # Add waferMapDim column
    def find_dim(x):
        dim0 = np.size(x, axis=0)
        dim1 = np.size(x, axis=1)
        return dim0, dim1
    
    df['waferMapDim'] = df.waferMap.apply(find_dim)
    
    # Clean failure types
    df['failureType'] = df['failureType'].astype(str).str.replace(r"[\[\]']", "", regex=True)
    
    # Mapping failure types to numbers
    mapping_type = {
        'Center': 0, 'Donut': 1, 'Edge-Loc': 2, 'Edge-Ring': 3,
        'Loc': 4, 'Random': 5, 'Scratch': 6, 'Near-full': 7, 'none': 8
    }
    df['failureNum'] = df['failureType'].map(mapping_type)
    
    # Filter labeled data
    df_withlabel = df[df['failureType'] != 0].reset_index(drop=True)
    
    return df_withlabel

df_processed = preprocess_data(df)
print(f"Processed dataset shape: {df_processed.shape}")
print("\nFailure type distribution:")
print(df_processed['failureType'].value_counts())

Processed dataset shape: (811457, 7)

Failure type distribution:
0 0          638507
none         147431
Edge-Ring      9680
Edge-Loc       5189
Center         4294
Loc            3593
Scratch        1193
Random          866
Donut           555
Near-full       149
Name: failureType, dtype: int64


In [None]:
# Extract and prepare wafer maps for ViT
def prepare_wafer_data_for_vit(df_withlabel, target_size=224):
    """
    Prepare wafer map data for Vision Transformer
    ViT typically works with 224x224 RGB images
    """
    wafer_maps = []
    labels = []
    
    print("Processing wafer maps...")
    for idx, row in df_withlabel.iterrows():
        if idx % 10000 == 0:
            print(f"Processed {idx}/{len(df_withlabel)} samples")
            
        wafer_map = row['waferMap']
        failure_type = row['failureType']
        
        # Convert to RGB (0: non-wafer -> R, 1: normal -> G, 2: defect -> B)
        h, w = wafer_map.shape
        rgb_map = np.zeros((h, w, 3), dtype=np.uint8)
        
        for i in range(h):
            for j in range(w):
                pixel_val = int(wafer_map[i, j])
                if pixel_val < 3:  # Ensure valid pixel values
                    rgb_map[i, j, pixel_val] = 255
        
        # Resize to target size for ViT
        pil_image = Image.fromarray(rgb_map)
        resized_image = pil_image.resize((target_size, target_size), Image.LANCZOS)
        resized_array = np.array(resized_image)
        
        wafer_maps.append(resized_array)
        labels.append(failure_type)
    
    return np.array(wafer_maps), np.array(labels)

# Prepare data
wafer_images, wafer_labels = prepare_wafer_data_for_vit(df_processed)
print(f"\nWafer images shape: {wafer_images.shape}")
print(f"Wafer labels shape: {wafer_labels.shape}")

Processing wafer maps...
Processed 0/811457 samples
Processed 10000/811457 samples
Processed 20000/811457 samples
Processed 30000/811457 samples
Processed 40000/811457 samples
Processed 50000/811457 samples
Processed 60000/811457 samples
Processed 70000/811457 samples
Processed 80000/811457 samples
Processed 90000/811457 samples
Processed 100000/811457 samples
Processed 110000/811457 samples
Processed 120000/811457 samples
Processed 130000/811457 samples
Processed 140000/811457 samples
Processed 150000/811457 samples
Processed 160000/811457 samples
Processed 170000/811457 samples
Processed 180000/811457 samples
Processed 190000/811457 samples
Processed 200000/811457 samples
Processed 210000/811457 samples


In [None]:
# Visualize sample wafer maps
plt.figure(figsize=(15, 10))
failure_types = np.unique(wafer_labels)
for i, failure_type in enumerate(failure_types[:9]):
    idx = np.where(wafer_labels == failure_type)[0][0]
    plt.subplot(3, 3, i+1)
    plt.imshow(wafer_images[idx])
    plt.title(f'Failure Type: {failure_type}')
    plt.axis('off')
plt.tight_layout()
plt.show()

## Vision Transformer Data Transforms
Define proper data augmentation and normalization for ViT models

In [None]:
# Define transforms for ViT
class ViTDataTransforms:
    def __init__(self, img_size=224):
        # ImageNet normalization (standard for pretrained models)
        self.train_transform = transforms.Compose([
            transforms.ToPILImage(),
            transforms.Resize((img_size, img_size)),
            transforms.RandomHorizontalFlip(p=0.5),
            transforms.RandomVerticalFlip(p=0.5),
            transforms.RandomRotation(degrees=15),
            transforms.ColorJitter(brightness=0.2, contrast=0.2),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
        ])
        
        self.val_transform = transforms.Compose([
            transforms.ToPILImage(),
            transforms.Resize((img_size, img_size)),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
        ])

# Custom Dataset class
class WaferDataset(Dataset):
    def __init__(self, images, labels, transform=None, label_encoder=None):
        self.images = images
        self.labels = labels
        self.transform = transform
        self.label_encoder = label_encoder
        
        # Encode labels to integers
        if label_encoder is None:
            unique_labels = np.unique(labels)
            self.label_encoder = {label: idx for idx, label in enumerate(unique_labels)}
        else:
            self.label_encoder = label_encoder
            
        self.encoded_labels = [self.label_encoder[label] for label in labels]
        
    def __len__(self):
        return len(self.images)
    
    def __getitem__(self, idx):
        image = self.images[idx]
        label = self.encoded_labels[idx]
        
        if self.transform:
            image = self.transform(image)
        else:
            image = torch.from_numpy(image).permute(2, 0, 1).float() / 255.0
            
        return image, torch.tensor(label, dtype=torch.long)

transforms_vit = ViTDataTransforms()
print("ViT transforms created successfully!")

## Vision Transformer Model Definition
Using pretrained ViT models from `timm` library

In [None]:
class WaferViTClassifier(nn.Module):
    def __init__(self, model_name='vit_base_patch16_224', num_classes=9, pretrained=True):
        super(WaferViTClassifier, self).__init__()
        
        # Load pretrained ViT model
        self.backbone = timm.create_model(model_name, pretrained=pretrained)
        
        # Get the number of features from the classifier
        if hasattr(self.backbone, 'head'):
            num_features = self.backbone.head.in_features
            # Replace the head with our custom classifier
            self.backbone.head = nn.Identity()
        elif hasattr(self.backbone, 'classifier'):
            num_features = self.backbone.classifier.in_features
            self.backbone.classifier = nn.Identity()
        else:
            # Fallback - assume 768 for base ViT
            num_features = 768
        
        # Custom classification head
        self.classifier = nn.Sequential(
            nn.Dropout(0.3),
            nn.Linear(num_features, 512),
            nn.ReLU(),
            nn.Dropout(0.2),
            nn.Linear(512, num_classes)
        )
        
    def forward(self, x):
        # Get features from backbone
        features = self.backbone(x)
        # Classify
        output = self.classifier(features)
        return output

# Available ViT models to try
available_models = [
    'vit_tiny_patch16_224',
    'vit_small_patch16_224', 
    'vit_base_patch16_224',
    'vit_base_patch16_384',
    'vit_large_patch16_224'
]

print("Available ViT models:")
for i, model in enumerate(available_models):
    print(f"{i+1}. {model}")

In [None]:
# Setup device and model
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

# Create model - starting with base ViT
model_name = 'vit_base_patch16_224'
num_classes = len(np.unique(wafer_labels))

model = WaferViTClassifier(model_name=model_name, num_classes=num_classes)
model = model.to(device)

print(f"\nCreated {model_name} with {num_classes} output classes")
print(f"Total parameters: {sum(p.numel() for p in model.parameters()):,}")
print(f"Trainable parameters: {sum(p.numel() for p in model.parameters() if p.requires_grad):,}")

## Data Preparation and Training Setup

In [None]:
# Create datasets
dataset = WaferDataset(wafer_images, wafer_labels, transform=transforms_vit.train_transform)
print(f"Dataset created with {len(dataset)} samples")
print(f"Label encoder: {dataset.label_encoder}")

# Training configuration
config = {
    'batch_size': 32,
    'learning_rate': 3e-5,  # Lower LR for fine-tuning
    'num_epochs': 15,
    'weight_decay': 1e-4,
    'num_folds': 5
}

print(f"Training configuration: {config}")

In [None]:
# Training and validation functions
def train_epoch(model, dataloader, criterion, optimizer, device):
    model.train()
    running_loss = 0.0
    correct_predictions = 0
    total_samples = 0
    
    for batch_idx, (images, labels) in enumerate(dataloader):
        images, labels = images.to(device), labels.to(device)
        
        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        
        running_loss += loss.item()
        _, predicted = torch.max(outputs.data, 1)
        total_samples += labels.size(0)
        correct_predictions += (predicted == labels).sum().item()
        
        if batch_idx % 50 == 0:
            print(f'Batch {batch_idx}/{len(dataloader)}, Loss: {loss.item():.4f}')
    
    epoch_loss = running_loss / len(dataloader)
    epoch_acc = correct_predictions / total_samples
    return epoch_loss, epoch_acc

def validate_epoch(model, dataloader, criterion, device):
    model.eval()
    running_loss = 0.0
    correct_predictions = 0
    total_samples = 0
    all_predictions = []
    all_labels = []
    
    with torch.no_grad():
        for images, labels in dataloader:
            images, labels = images.to(device), labels.to(device)
            
            outputs = model(images)
            loss = criterion(outputs, labels)
            
            running_loss += loss.item()
            _, predicted = torch.max(outputs.data, 1)
            total_samples += labels.size(0)
            correct_predictions += (predicted == labels).sum().item()
            
            all_predictions.extend(predicted.cpu().numpy())
            all_labels.extend(labels.cpu().numpy())
    
    epoch_loss = running_loss / len(dataloader)
    epoch_acc = correct_predictions / total_samples
    return epoch_loss, epoch_acc, all_predictions, all_labels

print("Training functions defined successfully!")

## K-Fold Cross Validation Training

In [None]:
# K-Fold Cross Validation
kfold = KFold(n_splits=config['num_folds'], shuffle=True, random_state=42)
fold_results = {}
best_models = {}

for fold, (train_idx, val_idx) in enumerate(kfold.split(range(len(dataset)))):
    print(f"\n{'='*50}")
    print(f"FOLD {fold + 1}/{config['num_folds']}")
    print(f"{'='*50}")
    
    # Create data loaders for this fold
    train_sampler = SubsetRandomSampler(train_idx)
    val_sampler = SubsetRandomSampler(val_idx)
    
    train_loader = DataLoader(dataset, batch_size=config['batch_size'], 
                             sampler=train_sampler, num_workers=2)
    val_loader = DataLoader(dataset, batch_size=config['batch_size'], 
                           sampler=val_sampler, num_workers=2)
    
    # Create fresh model for this fold
    fold_model = WaferViTClassifier(model_name=model_name, num_classes=num_classes)
    fold_model = fold_model.to(device)
    
    # Setup optimizer and criterion
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.AdamW(fold_model.parameters(), 
                           lr=config['learning_rate'], 
                           weight_decay=config['weight_decay'])
    scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, config['num_epochs'])
    
    # Training history for this fold
    history = {
        'train_loss': [], 'train_acc': [],
        'val_loss': [], 'val_acc': []
    }
    
    best_val_acc = 0.0
    
    # Training loop
    for epoch in range(config['num_epochs']):
        print(f"\nEpoch {epoch+1}/{config['num_epochs']}")
        print("-" * 30)
        
        # Train
        train_loss, train_acc = train_epoch(fold_model, train_loader, criterion, optimizer, device)
        
        # Validate
        val_loss, val_acc, val_predictions, val_labels = validate_epoch(
            fold_model, val_loader, criterion, device)
        
        # Update learning rate
        scheduler.step()
        
        # Store history
        history['train_loss'].append(train_loss)
        history['train_acc'].append(train_acc)
        history['val_loss'].append(val_loss)
        history['val_acc'].append(val_acc)
        
        # Save best model
        if val_acc > best_val_acc:
            best_val_acc = val_acc
            best_models[fold] = {
                'model_state': fold_model.state_dict().copy(),
                'val_acc': val_acc,
                'predictions': val_predictions,
                'labels': val_labels
            }
        
        print(f"Train Loss: {train_loss:.4f}, Train Acc: {train_acc:.4f}")
        print(f"Val Loss: {val_loss:.4f}, Val Acc: {val_acc:.4f}")
        print(f"Learning Rate: {optimizer.param_groups[0]['lr']:.6f}")
    
    # Store fold results
    fold_results[fold] = history
    print(f"\nFold {fold+1} Best Validation Accuracy: {best_val_acc:.4f}")

print("\n" + "="*50)
print("CROSS VALIDATION COMPLETED")
print("="*50)

## Results Analysis and Visualization

In [None]:
# Calculate overall performance metrics
fold_train_accs = []
fold_val_accs = []
fold_train_losses = []
fold_val_losses = []

for fold in range(config['num_folds']):
    history = fold_results[fold]
    fold_train_accs.append(max(history['train_acc']))
    fold_val_accs.append(max(history['val_acc']))
    fold_train_losses.append(min(history['train_loss']))
    fold_val_losses.append(min(history['val_loss']))

# Print summary statistics
print("Vision Transformer (ViT) Performance Summary")
print("=" * 50)
print(f"Average Training Accuracy: {np.mean(fold_train_accs):.4f} ± {np.std(fold_train_accs):.4f}")
print(f"Average Validation Accuracy: {np.mean(fold_val_accs):.4f} ± {np.std(fold_val_accs):.4f}")
print(f"Average Training Loss: {np.mean(fold_train_losses):.4f} ± {np.std(fold_train_losses):.4f}")
print(f"Average Validation Loss: {np.mean(fold_val_losses):.4f} ± {np.std(fold_val_losses):.4f}")
print(f"Best Validation Accuracy: {max(fold_val_accs):.4f}")

# Store results for comparison
vit_results = {
    'model_name': 'Vision Transformer (ViT)',
    'avg_train_acc': np.mean(fold_train_accs),
    'avg_val_acc': np.mean(fold_val_accs),
    'std_train_acc': np.std(fold_train_accs),
    'std_val_acc': np.std(fold_val_accs),
    'avg_train_loss': np.mean(fold_train_losses),
    'avg_val_loss': np.mean(fold_val_losses),
    'best_val_acc': max(fold_val_accs),
    'fold_results': fold_results,
    'config': config
}

In [None]:
# Visualization of training curves
fig, axes = plt.subplots(2, 2, figsize=(15, 10))

# Plot training curves for each fold
for fold in range(config['num_folds']):
    history = fold_results[fold]
    epochs = range(1, len(history['train_loss']) + 1)
    
    # Training and validation loss
    axes[0, 0].plot(epochs, history['train_loss'], label=f'Fold {fold+1}')
    axes[0, 1].plot(epochs, history['val_loss'], label=f'Fold {fold+1}')
    
    # Training and validation accuracy
    axes[1, 0].plot(epochs, history['train_acc'], label=f'Fold {fold+1}')
    axes[1, 1].plot(epochs, history['val_acc'], label=f'Fold {fold+1}')

axes[0, 0].set_title('Training Loss')
axes[0, 0].set_xlabel('Epoch')
axes[0, 0].set_ylabel('Loss')
axes[0, 0].legend()
axes[0, 0].grid(True)

axes[0, 1].set_title('Validation Loss')
axes[0, 1].set_xlabel('Epoch')
axes[0, 1].set_ylabel('Loss')
axes[0, 1].legend()
axes[0, 1].grid(True)

axes[1, 0].set_title('Training Accuracy')
axes[1, 0].set_xlabel('Epoch')
axes[1, 0].set_ylabel('Accuracy')
axes[1, 0].legend()
axes[1, 0].grid(True)

axes[1, 1].set_title('Validation Accuracy')
axes[1, 1].set_xlabel('Epoch')
axes[1, 1].set_ylabel('Accuracy')
axes[1, 1].legend()
axes[1, 1].grid(True)

plt.tight_layout()
plt.show()

In [None]:
# Confusion matrix for best performing fold
best_fold = max(best_models.keys(), key=lambda k: best_models[k]['val_acc'])
best_predictions = best_models[best_fold]['predictions']
best_labels = best_models[best_fold]['labels']

# Create confusion matrix
cm = confusion_matrix(best_labels, best_predictions)
label_names = list(dataset.label_encoder.keys())

plt.figure(figsize=(10, 8))
sns.heatmap(cm, annot=True, fmt='d', cmap='Blues', 
            xticklabels=label_names, yticklabels=label_names)
plt.title(f'Confusion Matrix - ViT (Best Fold: {best_fold+1})')
plt.xlabel('Predicted')
plt.ylabel('Actual')
plt.xticks(rotation=45)
plt.yticks(rotation=0)
plt.tight_layout()
plt.show()

# Classification report
print(f"\nClassification Report - Best Fold ({best_fold+1}):")
print(classification_report(best_labels, best_predictions, target_names=label_names))

In [None]:
# Save the results and best model
import pickle

# Save ViT results
with open('vit_wafer_classification_results.pkl', 'wb') as f:
    pickle.dump(vit_results, f)

# Save best model
best_model_path = 'best_vit_wafer_model.pth'
torch.save({
    'model_state_dict': best_models[best_fold]['model_state'],
    'model_name': model_name,
    'num_classes': num_classes,
    'label_encoder': dataset.label_encoder,
    'config': config,
    'val_accuracy': best_models[best_fold]['val_acc']
}, best_model_path)

print(f"Results saved to: vit_wafer_classification_results.pkl")
print(f"Best model saved to: {best_model_path}")
print(f"Best validation accuracy: {best_models[best_fold]['val_acc']:.4f}")

## Model Interpretation and Analysis

In [None]:
# Load best model for analysis
analysis_model = WaferViTClassifier(model_name=model_name, num_classes=num_classes)
analysis_model.load_state_dict(best_models[best_fold]['model_state'])
analysis_model = analysis_model.to(device)
analysis_model.eval()

# Feature analysis function
def analyze_predictions(model, dataset, device, num_samples=10):
    model.eval()
    
    # Get some random samples
    indices = np.random.choice(len(dataset), num_samples, replace=False)
    
    fig, axes = plt.subplots(2, 5, figsize=(20, 8))
    axes = axes.flatten()
    
    with torch.no_grad():
        for i, idx in enumerate(indices):
            image, true_label = dataset[idx]
            image_batch = image.unsqueeze(0).to(device)
            
            # Get prediction
            outputs = model(image_batch)
            probabilities = F.softmax(outputs, dim=1)
            predicted_class = torch.argmax(probabilities, dim=1).item()
            confidence = probabilities[0, predicted_class].item()
            
            # Convert back to label names
            label_names = list(dataset.label_encoder.keys())
            true_label_name = label_names[true_label]
            pred_label_name = label_names[predicted_class]
            
            # Display original image (denormalize for visualization)
            img_display = image.clone()
            # Denormalize
            mean = torch.tensor([0.485, 0.456, 0.406]).view(3, 1, 1)
            std = torch.tensor([0.229, 0.224, 0.225]).view(3, 1, 1)
            img_display = img_display * std + mean
            img_display = torch.clamp(img_display, 0, 1)
            
            axes[i].imshow(img_display.permute(1, 2, 0))
            axes[i].set_title(f'True: {true_label_name}\nPred: {pred_label_name}\nConf: {confidence:.3f}',
                             color='green' if true_label == predicted_class else 'red')
            axes[i].axis('off')
    
    plt.suptitle('ViT Prediction Analysis', fontsize=16)
    plt.tight_layout()
    plt.show()

analyze_predictions(analysis_model, dataset, device)
print("Prediction analysis completed!")

## Summary and Next Steps

This notebook implemented Vision Transformer (ViT) for wafer defect classification with the following key features:

1. **Pretrained ViT Model**: Used `timm` library for state-of-the-art pretrained transformers
2. **Proper Data Preprocessing**: Adapted wafer maps to RGB format suitable for ViT
3. **Data Augmentation**: Applied appropriate transforms for better generalization
4. **K-Fold Cross Validation**: Robust evaluation methodology
5. **Comprehensive Analysis**: Training curves, confusion matrices, and prediction analysis

### Key Advantages of ViT:
- **Global Context**: Attention mechanism captures long-range dependencies
- **Transfer Learning**: Benefits from large-scale pretraining
- **Scalability**: Performs better with larger datasets
- **Interpretability**: Attention maps can provide insights

### Next Steps:
1. Create Swin Transformer implementation
2. Develop comparative analysis framework
3. Implement wafer life expectancy prediction
4. Optimize hyperparameters and model architectures