# Wafer Defect Classification using Swin Transformer

This notebook implements wafer defect classification using Swin Transformer, which combines the efficiency of CNNs with the modeling power of Transformers through hierarchical feature maps and shifted window attention.

In [1]:
# Install required packages
!pip install timm torchvision transformers
!pip install torchsummary scikit-learn
!pip install einops

Defaulting to user installation because normal site-packages is not writeable
Defaulting to user installation because normal site-packages is not writeable
Collecting torchsummary
  Using cached torchsummary-1.5.1-py3-none-any.whl (2.8 kB)
Installing collected packages: torchsummary
Successfully installed torchsummary-1.5.1
Defaulting to user installation because normal site-packages is not writeable

ERROR: Invalid requirement: '#'





## Data Loading and Preprocessing
Same preprocessing pipeline but optimized for Swin Transformer's hierarchical structure

In [3]:
# Load dataset
df = pd.read_pickle("MIR-WM811K/Python/WM811K.pkl")
print(f"Dataset shape: {df.shape}")
df.info()

Dataset shape: (811457, 6)
<class 'pandas.core.frame.DataFrame'>
RangeIndex: 811457 entries, 0 to 811456
Data columns (total 6 columns):
 #   Column          Non-Null Count   Dtype  
---  ------          --------------   -----  
 0   dieSize         811457 non-null  float64
 1   failureType     811457 non-null  object 
 2   lotName         811457 non-null  object 
 3   trainTestLabel  811457 non-null  object 
 4   waferIndex      811457 non-null  float64
 5   waferMap        811457 non-null  object 
dtypes: float64(2), object(4)
memory usage: 37.1+ MB


In [4]:
# Data preprocessing - same as ViT but with Swin-specific considerations
def preprocess_data(df):
    # Drop waferIndex column
    df = df.drop(['waferIndex'], axis=1)
    
    # Add waferMapDim column
    def find_dim(x):
        dim0 = np.size(x, axis=0)
        dim1 = np.size(x, axis=1)
        return dim0, dim1
    
    df['waferMapDim'] = df.waferMap.apply(find_dim)
    
    # Clean failure types
    df['failureType'] = df['failureType'].astype(str).str.replace(r"[\[\]']", "", regex=True)
    
    # Mapping failure types to numbers
    mapping_type = {
        'Center': 0, 'Donut': 1, 'Edge-Loc': 2, 'Edge-Ring': 3,
        'Loc': 4, 'Random': 5, 'Scratch': 6, 'Near-full': 7, 'none': 8
    }
    df['failureNum'] = df['failureType'].map(mapping_type)
    
    # Filter labeled data
    df_withlabel = df[df['failureType'] != 0].reset_index(drop=True)
    
    return df_withlabel

df_processed = preprocess_data(df)
print(f"Processed dataset shape: {df_processed.shape}")
print("\nFailure type distribution:")
print(df_processed['failureType'].value_counts())

Processed dataset shape: (811457, 7)

Failure type distribution:
0 0          638507
none         147431
Edge-Ring      9680
Edge-Loc       5189
Center         4294
Loc            3593
Scratch        1193
Random          866
Donut           555
Near-full       149
Name: failureType, dtype: int64


In [None]:
# Extract and prepare wafer maps for Swin Transformer
def prepare_wafer_data_for_swin(df_withlabel, target_size=224):
    """
    Prepare wafer map data for Swin Transformer
    Swin works well with various resolutions, but we'll use 224x224 for consistency
    """
    wafer_maps = []
    labels = []
    
    print("Processing wafer maps for Swin Transformer...")
    for idx, row in df_withlabel.iterrows():
        if idx % 10000 == 0:
            print(f"Processed {idx}/{len(df_withlabel)} samples")
            
        wafer_map = row['waferMap']
        failure_type = row['failureType']
        
        # Convert to RGB with enhanced contrast for hierarchical features
        h, w = wafer_map.shape
        rgb_map = np.zeros((h, w, 3), dtype=np.uint8)
        
        for i in range(h):
            for j in range(w):
                pixel_val = int(wafer_map[i, j])
                if pixel_val < 3:
                    # Enhanced encoding for better hierarchical feature learning
                    if pixel_val == 0:  # non-wafer
                        rgb_map[i, j] = [255, 0, 0]  # Red
                    elif pixel_val == 1:  # normal
                        rgb_map[i, j] = [0, 255, 0]  # Green  
                    else:  # defect
                        rgb_map[i, j] = [0, 0, 255]  # Blue
        
        # Resize with high-quality resampling for Swin's hierarchical processing
        pil_image = Image.fromarray(rgb_map)
        resized_image = pil_image.resize((target_size, target_size), Image.LANCZOS)
        resized_array = np.array(resized_image)
        
        wafer_maps.append(resized_array)
        labels.append(failure_type)
    
    return np.array(wafer_maps), np.array(labels)

# Prepare data
wafer_images, wafer_labels = prepare_wafer_data_for_swin(df_processed)
print(f"\nWafer images shape: {wafer_images.shape}")
print(f"Wafer labels shape: {wafer_labels.shape}")

Processing wafer maps for Swin Transformer...
Processed 0/811457 samples
Processed 10000/811457 samples
Processed 20000/811457 samples
Processed 30000/811457 samples
Processed 40000/811457 samples
Processed 50000/811457 samples
Processed 60000/811457 samples
Processed 70000/811457 samples
Processed 80000/811457 samples
Processed 90000/811457 samples
Processed 100000/811457 samples
Processed 110000/811457 samples
Processed 120000/811457 samples
Processed 130000/811457 samples
Processed 140000/811457 samples
Processed 150000/811457 samples
Processed 160000/811457 samples
Processed 170000/811457 samples
Processed 180000/811457 samples
Processed 190000/811457 samples
Processed 200000/811457 samples
Processed 210000/811457 samples
Processed 220000/811457 samples
Processed 230000/811457 samples
Processed 240000/811457 samples
Processed 250000/811457 samples
Processed 260000/811457 samples
Processed 270000/811457 samples
Processed 280000/811457 samples
Processed 290000/811457 samples
Processe

## Swin Transformer Data Transforms
Specialized transforms that work well with Swin's hierarchical attention mechanism

In [None]:
# Define transforms optimized for Swin Transformer
class SwinDataTransforms:
    def __init__(self, img_size=224):
        # Swin-specific augmentations that preserve spatial hierarchies
        self.train_transform = transforms.Compose([
            transforms.ToPILImage(),
            transforms.Resize((img_size, img_size)),
            # Gentle augmentations to preserve spatial structure
            transforms.RandomHorizontalFlip(p=0.5),
            transforms.RandomVerticalFlip(p=0.5),
            transforms.RandomRotation(degrees=10),  # Reduced rotation for Swin
            # Color augmentations that enhance hierarchical features
            transforms.ColorJitter(brightness=0.15, contrast=0.15, saturation=0.1),
            # Random erasing to improve robustness
            transforms.ToTensor(),
            transforms.RandomErasing(p=0.1, scale=(0.02, 0.1), ratio=(0.3, 3.3)),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
        ])
        
        self.val_transform = transforms.Compose([
            transforms.ToPILImage(),
            transforms.Resize((img_size, img_size)),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
        ])

# Custom Dataset class optimized for Swin
class SwinWaferDataset(Dataset):
    def __init__(self, images, labels, transform=None, label_encoder=None):
        self.images = images
        self.labels = labels
        self.transform = transform
        self.label_encoder = label_encoder
        
        # Encode labels to integers
        if label_encoder is None:
            unique_labels = np.unique(labels)
            self.label_encoder = {label: idx for idx, label in enumerate(unique_labels)}
        else:
            self.label_encoder = label_encoder
            
        self.encoded_labels = [self.label_encoder[label] for label in labels]
        
    def __len__(self):
        return len(self.images)
    
    def __getitem__(self, idx):
        image = self.images[idx]
        label = self.encoded_labels[idx]
        
        if self.transform:
            image = self.transform(image)
        else:
            image = torch.from_numpy(image).permute(2, 0, 1).float() / 255.0
            
        return image, torch.tensor(label, dtype=torch.long)

transforms_swin = SwinDataTransforms()
print("Swin Transformer transforms created successfully!")

## Swin Transformer Model Definition
Using pretrained Swin models with custom classification heads

In [None]:
class WaferSwinClassifier(nn.Module):
    def __init__(self, model_name='swin_base_patch4_window7_224', num_classes=9, pretrained=True):
        super(WaferSwinClassifier, self).__init__()
        
        # Load pretrained Swin model
        self.backbone = timm.create_model(model_name, pretrained=pretrained)
        
        # Get the number of features from the classifier
        if hasattr(self.backbone, 'head'):
            num_features = self.backbone.head.in_features
            self.backbone.head = nn.Identity()
        elif hasattr(self.backbone, 'classifier'):
            num_features = self.backbone.classifier.in_features
            self.backbone.classifier = nn.Identity()
        else:
            # Fallback for different Swin variants
            num_features = 1024  # Base Swin dimension
        
        # Enhanced classification head for Swin
        self.classifier = nn.Sequential(
            nn.LayerNorm(num_features),
            nn.Dropout(0.3),
            nn.Linear(num_features, 512),
            nn.GELU(),  # GELU activation works well with transformers
            nn.Dropout(0.2),
            nn.Linear(512, 256),
            nn.GELU(),
            nn.Dropout(0.1),
            nn.Linear(256, num_classes)
        )
        
        # Initialize classifier weights
        self._init_classifier_weights()
        
    def _init_classifier_weights(self):
        for m in self.classifier.modules():
            if isinstance(m, nn.Linear):
                nn.init.trunc_normal_(m.weight, std=0.02)
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.LayerNorm):
                nn.init.constant_(m.bias, 0)
                nn.init.constant_(m.weight, 1.0)
        
    def forward(self, x):
        # Get hierarchical features from Swin backbone
        features = self.backbone(x)
        # Classify using enhanced head
        output = self.classifier(features)
        return output

# Available Swin models to try
available_swin_models = [
    'swin_tiny_patch4_window7_224',
    'swin_small_patch4_window7_224', 
    'swin_base_patch4_window7_224',
    'swin_base_patch4_window12_384',
    'swin_large_patch4_window7_224',
    'swinv2_tiny_window16_256',
    'swinv2_small_window16_256',
    'swinv2_base_window16_256'
]

print("Available Swin Transformer models:")
for i, model in enumerate(available_swin_models):
    print(f"{i+1}. {model}")

In [None]:
# Setup device and model
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

# Create Swin model - starting with base Swin
model_name = 'swin_base_patch4_window7_224'
num_classes = len(np.unique(wafer_labels))

model = WaferSwinClassifier(model_name=model_name, num_classes=num_classes)
model = model.to(device)

print(f"\nCreated {model_name} with {num_classes} output classes")
print(f"Total parameters: {sum(p.numel() for p in model.parameters()):,}")
print(f"Trainable parameters: {sum(p.numel() for p in model.parameters() if p.requires_grad):,}")

## Training Configuration and Setup
Optimized hyperparameters for Swin Transformer

In [None]:
# Create datasets
dataset = SwinWaferDataset(wafer_images, wafer_labels, transform=transforms_swin.train_transform)
print(f"Dataset created with {len(dataset)} samples")
print(f"Label encoder: {dataset.label_encoder}")

# Training configuration optimized for Swin
config = {
    'batch_size': 16,  # Smaller batch size for Swin due to memory requirements
    'learning_rate': 1e-5,  # Very low LR for fine-tuning Swin
    'num_epochs': 20,  # More epochs for gradual fine-tuning
    'weight_decay': 0.05,  # Higher weight decay for Swin
    'num_folds': 5,
    'warmup_epochs': 3,  # Warmup for stable training
    'min_lr': 1e-7  # Minimum learning rate
}

print(f"Training configuration: {config}")

In [None]:
# Enhanced training and validation functions for Swin
def train_epoch_swin(model, dataloader, criterion, optimizer, scheduler, device, epoch):
    model.train()
    running_loss = 0.0
    correct_predictions = 0
    total_samples = 0
    
    for batch_idx, (images, labels) in enumerate(dataloader):
        images, labels = images.to(device), labels.to(device)
        
        optimizer.zero_grad()
        
        # Forward pass with mixed precision for efficiency
        outputs = model(images)
        loss = criterion(outputs, labels)
        
        # Backward pass
        loss.backward()
        
        # Gradient clipping for stable training
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
        
        optimizer.step()
        
        # Update learning rate during warmup
        if scheduler is not None:
            scheduler.step()
        
        running_loss += loss.item()
        _, predicted = torch.max(outputs.data, 1)
        total_samples += labels.size(0)
        correct_predictions += (predicted == labels).sum().item()
        
        if batch_idx % 25 == 0:
            current_lr = optimizer.param_groups[0]['lr']
            print(f'Batch {batch_idx}/{len(dataloader)}, Loss: {loss.item():.4f}, LR: {current_lr:.2e}')
    
    epoch_loss = running_loss / len(dataloader)
    epoch_acc = correct_predictions / total_samples
    return epoch_loss, epoch_acc

def validate_epoch_swin(model, dataloader, criterion, device):
    model.eval()
    running_loss = 0.0
    correct_predictions = 0
    total_samples = 0
    all_predictions = []
    all_labels = []
    
    with torch.no_grad():
        for images, labels in dataloader:
            images, labels = images.to(device), labels.to(device)
            
            outputs = model(images)
            loss = criterion(outputs, labels)
            
            running_loss += loss.item()
            _, predicted = torch.max(outputs.data, 1)
            total_samples += labels.size(0)
            correct_predictions += (predicted == labels).sum().item()
            
            all_predictions.extend(predicted.cpu().numpy())
            all_labels.extend(labels.cpu().numpy())
    
    epoch_loss = running_loss / len(dataloader)
    epoch_acc = correct_predictions / total_samples
    return epoch_loss, epoch_acc, all_predictions, all_labels

# Learning rate scheduler with warmup
def get_warmup_cosine_scheduler(optimizer, warmup_epochs, total_epochs, min_lr):
    def lr_lambda(epoch):
        if epoch < warmup_epochs:
            # Linear warmup
            return (epoch + 1) / warmup_epochs
        else:
            # Cosine annealing
            progress = (epoch - warmup_epochs) / (total_epochs - warmup_epochs)
            return max(min_lr, 0.5 * (1 + math.cos(math.pi * progress)))
    
    return optim.lr_scheduler.LambdaLR(optimizer, lr_lambda)

print("Enhanced training functions for Swin Transformer defined successfully!")

## K-Fold Cross Validation Training for Swin Transformer

In [None]:
# K-Fold Cross Validation with Swin-optimized training
kfold = KFold(n_splits=config['num_folds'], shuffle=True, random_state=42)
fold_results = {}
best_models = {}

for fold, (train_idx, val_idx) in enumerate(kfold.split(range(len(dataset)))):
    print(f"\n{'='*60}")
    print(f"FOLD {fold + 1}/{config['num_folds']} - SWIN TRANSFORMER")
    print(f"{'='*60}")
    
    # Create data loaders for this fold
    train_sampler = SubsetRandomSampler(train_idx)
    val_sampler = SubsetRandomSampler(val_idx)
    
    train_loader = DataLoader(dataset, batch_size=config['batch_size'], 
                             sampler=train_sampler, num_workers=2, pin_memory=True)
    val_loader = DataLoader(dataset, batch_size=config['batch_size'], 
                           sampler=val_sampler, num_workers=2, pin_memory=True)
    
    # Create fresh model for this fold
    fold_model = WaferSwinClassifier(model_name=model_name, num_classes=num_classes)
    fold_model = fold_model.to(device)
    
    # Setup optimizer with different parameter groups
    backbone_params = []
    classifier_params = []
    
    for name, param in fold_model.named_parameters():
        if 'classifier' in name:
            classifier_params.append(param)
        else:
            backbone_params.append(param)
    
    # Different learning rates for backbone and classifier
    optimizer = optim.AdamW([
        {'params': backbone_params, 'lr': config['learning_rate'] * 0.1},  # Lower LR for pretrained
        {'params': classifier_params, 'lr': config['learning_rate']}        # Higher LR for classifier
    ], weight_decay=config['weight_decay'])
    
    # Scheduler with warmup
    total_steps = len(train_loader) * config['num_epochs']
    warmup_steps = len(train_loader) * config['warmup_epochs']
    
    scheduler = optim.lr_scheduler.OneCycleLR(
        optimizer, 
        max_lr=[config['learning_rate'] * 0.1, config['learning_rate']],
        total_steps=total_steps,
        pct_start=config['warmup_epochs'] / config['num_epochs'],
        anneal_strategy='cos'
    )
    
    # Loss function with label smoothing
    criterion = nn.CrossEntropyLoss(label_smoothing=0.1)
    
    # Training history for this fold
    history = {
        'train_loss': [], 'train_acc': [],
        'val_loss': [], 'val_acc': []
    }
    
    best_val_acc = 0.0
    patience_counter = 0
    patience = 5
    
    # Training loop
    for epoch in range(config['num_epochs']):
        print(f"\nEpoch {epoch+1}/{config['num_epochs']}")
        print("-" * 40)
        
        # Train
        train_loss, train_acc = train_epoch_swin(
            fold_model, train_loader, criterion, optimizer, scheduler, device, epoch)
        
        # Validate
        val_loss, val_acc, val_predictions, val_labels = validate_epoch_swin(
            fold_model, val_loader, criterion, device)
        
        # Store history
        history['train_loss'].append(train_loss)
        history['train_acc'].append(train_acc)
        history['val_loss'].append(val_loss)
        history['val_acc'].append(val_acc)
        
        # Early stopping and best model saving
        if val_acc > best_val_acc:
            best_val_acc = val_acc
            patience_counter = 0
            best_models[fold] = {
                'model_state': fold_model.state_dict().copy(),
                'val_acc': val_acc,
                'predictions': val_predictions,
                'labels': val_labels
            }
        else:
            patience_counter += 1
        
        print(f"Train Loss: {train_loss:.4f}, Train Acc: {train_acc:.4f}")
        print(f"Val Loss: {val_loss:.4f}, Val Acc: {val_acc:.4f}")
        print(f"Best Val Acc: {best_val_acc:.4f}")
        print(f"Learning Rate: {optimizer.param_groups[0]['lr']:.2e}")
        
        # Early stopping
        if patience_counter >= patience:
            print(f"Early stopping at epoch {epoch+1}")
            break
    
    # Store fold results
    fold_results[fold] = history
    print(f"\nFold {fold+1} Best Validation Accuracy: {best_val_acc:.4f}")

print("\n" + "="*60)
print("SWIN TRANSFORMER CROSS VALIDATION COMPLETED")
print("="*60)

## Results Analysis and Visualization

In [None]:
# Calculate overall performance metrics
fold_train_accs = []
fold_val_accs = []
fold_train_losses = []
fold_val_losses = []

for fold in range(config['num_folds']):
    history = fold_results[fold]
    fold_train_accs.append(max(history['train_acc']))
    fold_val_accs.append(max(history['val_acc']))
    fold_train_losses.append(min(history['train_loss']))
    fold_val_losses.append(min(history['val_loss']))

# Print summary statistics
print("Swin Transformer Performance Summary")
print("=" * 50)
print(f"Average Training Accuracy: {np.mean(fold_train_accs):.4f} ± {np.std(fold_train_accs):.4f}")
print(f"Average Validation Accuracy: {np.mean(fold_val_accs):.4f} ± {np.std(fold_val_accs):.4f}")
print(f"Average Training Loss: {np.mean(fold_train_losses):.4f} ± {np.std(fold_train_losses):.4f}")
print(f"Average Validation Loss: {np.mean(fold_val_losses):.4f} ± {np.std(fold_val_losses):.4f}")
print(f"Best Validation Accuracy: {max(fold_val_accs):.4f}")

# Store results for comparison
swin_results = {
    'model_name': 'Swin Transformer',
    'avg_train_acc': np.mean(fold_train_accs),
    'avg_val_acc': np.mean(fold_val_accs),
    'std_train_acc': np.std(fold_train_accs),
    'std_val_acc': np.std(fold_val_accs),
    'avg_train_loss': np.mean(fold_train_losses),
    'avg_val_loss': np.mean(fold_val_losses),
    'best_val_acc': max(fold_val_accs),
    'fold_results': fold_results,
    'config': config
}

In [None]:
# Enhanced visualization for Swin Transformer results
fig, axes = plt.subplots(2, 3, figsize=(20, 12))

# Plot training curves for each fold
for fold in range(config['num_folds']):
    history = fold_results[fold]
    epochs = range(1, len(history['train_loss']) + 1)
    
    # Training and validation loss
    axes[0, 0].plot(epochs, history['train_loss'], label=f'Fold {fold+1}', alpha=0.7)
    axes[0, 1].plot(epochs, history['val_loss'], label=f'Fold {fold+1}', alpha=0.7)
    
    # Training and validation accuracy
    axes[1, 0].plot(epochs, history['train_acc'], label=f'Fold {fold+1}', alpha=0.7)
    axes[1, 1].plot(epochs, history['val_acc'], label=f'Fold {fold+1}', alpha=0.7)

# Loss plots
axes[0, 0].set_title('Swin Training Loss', fontsize=14, fontweight='bold')
axes[0, 0].set_xlabel('Epoch')
axes[0, 0].set_ylabel('Loss')
axes[0, 0].legend()
axes[0, 0].grid(True, alpha=0.3)

axes[0, 1].set_title('Swin Validation Loss', fontsize=14, fontweight='bold')
axes[0, 1].set_xlabel('Epoch')
axes[0, 1].set_ylabel('Loss')
axes[0, 1].legend()
axes[0, 1].grid(True, alpha=0.3)

# Accuracy plots
axes[1, 0].set_title('Swin Training Accuracy', fontsize=14, fontweight='bold')
axes[1, 0].set_xlabel('Epoch')
axes[1, 0].set_ylabel('Accuracy')
axes[1, 0].legend()
axes[1, 0].grid(True, alpha=0.3)

axes[1, 1].set_title('Swin Validation Accuracy', fontsize=14, fontweight='bold')
axes[1, 1].set_xlabel('Epoch')
axes[1, 1].set_ylabel('Accuracy')
axes[1, 1].legend()
axes[1, 1].grid(True, alpha=0.3)

# Performance comparison boxplot
perf_data = [fold_train_accs, fold_val_accs]
axes[0, 2].boxplot(perf_data, labels=['Train Acc', 'Val Acc'])
axes[0, 2].set_title('Swin Performance Distribution', fontsize=14, fontweight='bold')
axes[0, 2].set_ylabel('Accuracy')
axes[0, 2].grid(True, alpha=0.3)

# Fold-wise performance
fold_nums = range(1, config['num_folds'] + 1)
axes[1, 2].bar([x - 0.2 for x in fold_nums], fold_train_accs, 0.4, label='Train Acc', alpha=0.7)
axes[1, 2].bar([x + 0.2 for x in fold_nums], fold_val_accs, 0.4, label='Val Acc', alpha=0.7)
axes[1, 2].set_title('Swin Fold-wise Performance', fontsize=14, fontweight='bold')
axes[1, 2].set_xlabel('Fold')
axes[1, 2].set_ylabel('Accuracy')
axes[1, 2].legend()
axes[1, 2].grid(True, alpha=0.3)
axes[1, 2].set_xticks(fold_nums)

plt.suptitle('Swin Transformer Training Analysis', fontsize=16, fontweight='bold')
plt.tight_layout()
plt.show()

In [None]:
# Confusion matrix for best performing fold
best_fold = max(best_models.keys(), key=lambda k: best_models[k]['val_acc'])
best_predictions = best_models[best_fold]['predictions']
best_labels = best_models[best_fold]['labels']

# Create confusion matrix
cm = confusion_matrix(best_labels, best_predictions)
label_names = list(dataset.label_encoder.keys())

plt.figure(figsize=(12, 10))
sns.heatmap(cm, annot=True, fmt='d', cmap='Blues', 
            xticklabels=label_names, yticklabels=label_names, 
            cbar_kws={'label': 'Count'})
plt.title(f'Swin Transformer Confusion Matrix\nBest Fold: {best_fold+1} (Accuracy: {best_models[best_fold]["val_acc"]:.4f})', 
         fontsize=14, fontweight='bold')
plt.xlabel('Predicted Label', fontsize=12)
plt.ylabel('True Label', fontsize=12)
plt.xticks(rotation=45, ha='right')
plt.yticks(rotation=0)
plt.tight_layout()
plt.show()

# Classification report
print(f"\nClassification Report - Swin Transformer (Best Fold {best_fold+1}):")
print("=" * 70)
print(classification_report(best_labels, best_predictions, target_names=label_names))

In [None]:
# Per-class analysis
from sklearn.metrics import precision_recall_fscore_support

precision, recall, f1, support = precision_recall_fscore_support(
    best_labels, best_predictions, average=None)

# Create per-class performance DataFrame
class_performance = pd.DataFrame({
    'Class': label_names,
    'Precision': precision,
    'Recall': recall,
    'F1-Score': f1,
    'Support': support
})

print("\nPer-class Performance Analysis:")
print(class_performance.round(4))

# Visualize per-class performance
fig, axes = plt.subplots(1, 3, figsize=(18, 6))

metrics = ['Precision', 'Recall', 'F1-Score']
for i, metric in enumerate(metrics):
    bars = axes[i].bar(label_names, class_performance[metric], alpha=0.7, color=plt.cm.Set3(range(len(label_names))))
    axes[i].set_title(f'Swin Transformer {metric} by Class', fontweight='bold')
    axes[i].set_ylabel(metric)
    axes[i].set_ylim(0, 1.05)
    axes[i].tick_params(axis='x', rotation=45)
    axes[i].grid(True, alpha=0.3)
    
    # Add value labels on bars
    for bar in bars:
        height = bar.get_height()
        axes[i].text(bar.get_x() + bar.get_width()/2., height + 0.01,
                    f'{height:.3f}', ha='center', va='bottom', fontsize=9)

plt.tight_layout()
plt.show()

In [None]:
# Save the results and best model
import pickle

# Save Swin results
with open('swin_wafer_classification_results.pkl', 'wb') as f:
    pickle.dump(swin_results, f)

# Save best model
best_model_path = 'best_swin_wafer_model.pth'
torch.save({
    'model_state_dict': best_models[best_fold]['model_state'],
    'model_name': model_name,
    'num_classes': num_classes,
    'label_encoder': dataset.label_encoder,
    'config': config,
    'val_accuracy': best_models[best_fold]['val_acc'],
    'class_performance': class_performance.to_dict()
}, best_model_path)

print(f"Results saved to: swin_wafer_classification_results.pkl")
print(f"Best model saved to: {best_model_path}")
print(f"Best validation accuracy: {best_models[best_fold]['val_acc']:.4f}")
print(f"\nModel Summary:")
print(f"- Architecture: {model_name}")
print(f"- Parameters: {sum(p.numel() for p in model.parameters()):,}")
print(f"- Training Time: {config['num_epochs']} epochs max per fold")
print(f"- Best Performance: {max(fold_val_accs):.4f} accuracy")

## Swin Transformer Analysis and Insights

This notebook implemented Swin Transformer for wafer defect classification with several key innovations:

### Key Features of Swin Implementation:

1. **Hierarchical Feature Learning**: Swin's shifted window attention captures both local and global patterns
2. **Enhanced Data Preprocessing**: Optimized RGB encoding for hierarchical feature extraction
3. **Advanced Training Strategy**: 
   - Differential learning rates for backbone vs classifier
   - Warmup and cosine annealing schedules
   - Label smoothing and gradient clipping
4. **Early Stopping**: Prevents overfitting with patience-based stopping
5. **Comprehensive Evaluation**: Per-class metrics and detailed performance analysis

### Advantages of Swin Transformer:
- **Efficient Attention**: Linear computational complexity w.r.t. image size
- **Hierarchical Representations**: Multi-scale feature learning like CNNs
- **Transfer Learning**: Strong pretrained representations
- **Robustness**: Self-attention mechanism handles various defect patterns

### Performance Characteristics:
- Better handling of complex spatial patterns
- Improved generalization through attention mechanisms
- Strong performance on geometric defect patterns
- Efficient memory usage compared to standard ViT

### Next Steps:
1. Create comprehensive comparison with CNN and ViT
2. Implement wafer life expectancy prediction
3. Statistical significance testing
4. Model ensemble strategies