# Supervised Training on Wafer Map Data

In this notebook, we train a slightly modified ResNet-18 on our labeled training splits in a fully supervised, end-to-end fashion. This serves as a baseline against which to compare models that are pretrained with self-supervised learning and then fine-tuned.

## Imports

In [4]:
import lightly
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import pytorch_lightning as pl
import seaborn as sns
import torch
import torch.nn.functional as F
import torchvision
import torchvision.transforms as T
from lightly.data import LightlyDataset
from lightly.transforms.rotation import RandomRotate
from pytorch_lightning.callbacks.early_stopping import EarlyStopping
from torch import nn
from torch.utils.data import DataLoader, Dataset
from torchmetrics.classification import (
    MulticlassAccuracy,
    MulticlassAUROC,
    MulticlassF1Score,
)
from torchvision.transforms.functional import InterpolationMode

## Loading in Wafer Data

Below we load the cleaned data splits: three training sets (corresponding to 1%, 10%, and 20% of the total labels), plus the validation and test sets.

In [5]:
root = "../data/cleaned_splits"
train_1_split = pd.read_pickle(f"{root}/train_1_split.pkl")
train_10_split = pd.read_pickle(f"{root}/train_10_split.pkl")
train_20_split = pd.read_pickle(f"{root}/train_20_split.pkl")
val_data = pd.read_pickle(f"{root}/val_data.pkl")
test_data = pd.read_pickle(f"{root}/test_data.pkl")

In [6]:
class WaferMapDataset(Dataset):
    def __init__(self, X, y, transform=None):
        self.data = pd.concat([X, y], axis="columns")
        # Wafer maps come in different sizes and all resizing happens in the transforms,
        # so keep one tensor per wafer in a list instead of a single stacked tensor
        self.X_list = [torch.tensor(ndarray) for ndarray in X]
        self.y_list = [torch.tensor(ndarray) for ndarray in y]
        self.transform = transform

    def __getitem__(self, index):
        x = self.X_list[index]
        y = self.y_list[index]
        if self.transform:
            x = self.transform(x)
        return x, y

    def __len__(self):
        return len(self.X_list)
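Because the raw wafer maps do not share one size, they cannot live in a single dense array, which is why the dataset above keeps a list of per-wafer tensors. A quick NumPy illustration with made-up shapes (the sizes here are hypothetical, chosen only to differ):

```python
import numpy as np

# Two toy "wafer maps" of different sizes (hypothetical shapes for illustration)
small = np.zeros((26, 26), dtype=np.uint8)
large = np.zeros((52, 45), dtype=np.uint8)

# A list holds arrays of different shapes without complaint, one entry per wafer
maps = [small, large]
shapes = [m.shape for m in maps]

# Stacking them into one dense array fails because the shapes disagree
try:
    np.stack(maps)
    stack_failed = False
except ValueError:
    stack_failed = True

print(shapes, stack_failed)
```

The same constraint applies at batch time: PyTorch's default collate function stacks the items of a batch, so batching only works after the transforms have resized every map to a common shape.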

## Transforms for Datasets

Here we create the sets of transforms for training and holdout datasets. We use the same augmentations for the training data as we used for self-supervised pretraining. The reason for doing so is to investigate whether a fully supervised baseline trained end-to-end can learn visual representations that are invariant to these augmentations.
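One of these augmentations, `DieNoise` (defined in the next cell), calls `np.random.choice` once per pixel inside `np.vectorize`, which amounts to a Python-level loop over the map. The sketch below is a vectorized NumPy alternative that draws all flips at once. It is not the transform used for the results in this notebook; it assumes 2-D `uint8` maps where 0 marks non-die area, 128 a passing die, and 255 a failing die, as described in the `DieNoise` docstring, and the function name `die_noise` is hypothetical.

```python
import numpy as np

PASS, FAIL = 128, 255  # die values assumed from the DieNoise docstring; 0 is non-die area


def die_noise(sample: np.ndarray, p: float = 0.03, rng=None) -> np.ndarray:
    """Flip each passing/failing die with probability p; leave non-die pixels alone."""
    rng = np.random.default_rng() if rng is None else rng
    flip = rng.random(sample.shape) < p  # one Bernoulli(p) draw per pixel
    out = sample.copy()
    # Masks are computed from the original sample, so no die is flipped twice
    out[flip & (sample == PASS)] = FAIL
    out[flip & (sample == FAIL)] = PASS
    return out


wafer = np.array([[0, 128, 255], [128, 255, 0]], dtype=np.uint8)
print(die_noise(wafer, p=1.0))  # every die flips; the 0's are untouched
```

Building the result with boolean masks also keeps the input dtype (`uint8` here), whereas `np.vectorize` infers its output dtype from the values the wrapped function returns.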

In [7]:
class DieNoise:
    def __call__(self, sample: torch.Tensor) -> torch.Tensor:
        def flip(item, p=0.03):
            """
            Given a wafer map die, flips pass to fail and vice versa with probability p.
            Does nothing to the non-die area (0's); 128's and 255's are passes and fails, respectively.
            """
            prob = np.random.choice([False, True], p=[1 - p, p])
            if prob:
                if item == 128:
                    return 255
                elif item == 255:
                    return 128
                else:
                    return item
            return item

        vflip = np.vectorize(flip)
        out = vflip(sample)
        return torch.from_numpy(out)

train_transforms = T.Compose([
    # Add die noise before anything else
    DieNoise(),
    # Convert to PIL Image, then perform all torchvision transforms
    T.ToPILImage(),
    T.Resize([128, 128], interpolation=InterpolationMode.NEAREST),
    RandomRotate(),
    T.RandomVerticalFlip(),
    T.RandomHorizontalFlip(),
    T.RandomApply(
        torch.nn.ModuleList(
            [T.RandomRotation(90, interpolation=InterpolationMode.NEAREST)]
        ),
        0.25,
    ),
    # Finally, create a 3-channel image since we use ResNet, and convert to tensor
    T.Grayscale(num_output_channels=3),  # R == G == B
    T.ToTensor(),
])

# Test transforms for val and test datasets don't contain any augmentations
# All that's here is to ensure data quality: correct size, correct channels, tensor
test_transforms = T.Compose([
    T.ToPILImage(),
    T.Resize([128, 128], interpolation=InterpolationMode.NEAREST),
    T.Grayscale(num_output_channels=3),
    T.ToTensor(),
])
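To make the effect of `test_transforms` concrete, here is a NumPy-only sketch of what it does to a single map: a nearest-neighbour resize to 128×128, replication into three identical channels (what `Grayscale(num_output_channels=3)` does for a single-channel image), and scaling of `uint8` values to [0, 1] in channel-first layout, as `ToTensor` does. The simple index mapping in `nearest_resize` is an assumption for illustration and may not match PIL's nearest-neighbour rounding pixel for pixel; the function names are hypothetical.

```python
import numpy as np


def nearest_resize(img: np.ndarray, size=(128, 128)) -> np.ndarray:
    """Nearest-neighbour resize by index mapping (a sketch, not PIL's exact rule)."""
    h, w = img.shape
    rows = np.arange(size[0]) * h // size[0]  # source row for each output row
    cols = np.arange(size[1]) * w // size[1]  # source column for each output column
    return img[rows[:, None], cols[None, :]]


def eval_transform(img: np.ndarray) -> np.ndarray:
    resized = nearest_resize(img)  # values stay in {0, 128, 255}
    rgb = np.repeat(resized[None, :, :], 3, axis=0)  # 3 identical channels, CHW layout
    return rgb.astype(np.float32) / 255.0  # ToTensor-style scaling of uint8 to [0, 1]


# A random toy map of 0/128/255 values with an odd size
wafer = np.random.default_rng(0).choice(np.array([0, 128, 255], dtype=np.uint8), size=(26, 30))
x = eval_transform(wafer)
print(x.shape)  # (3, 128, 128)
```

Nearest-neighbour interpolation only copies existing pixel values, so every output pixel is still a non-die, pass, or fail value. A bilinear resize would instead blend a passing die (128) and a failing die (255) into intermediate values such as 191.5, which correspond to no valid die state.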

Our fully supervised network is defined below. It is essentially lightly's ResNet-18 with an average-pooling layer and a 9-class linear head added on top. The differences between lightly's implementation and the torchvision version are explained [here](https://docs.lightly.ai/self-supervised-learning/lightly.models.html):

> *Note that the architecture we present here differs from the one used in torchvision. We replace the first $7 \times 7$ convolution by a $3 \times 3$ convolution to make the model faster and run better on smaller input image resolutions*
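To put the quoted change in numbers, the weight count of a bias-free convolution is `in_channels * out_channels * k * k`. The sketch below compares the stem convolutions for a 3-channel input, assuming the 3×3 stem keeps torchvision's 64 output channels, and also counts the 512 → 9 linear head defined in the next cell. The helper names are hypothetical.

```python
def conv_params(in_ch: int, out_ch: int, k: int, bias: bool = False) -> int:
    """Number of weights in a k x k Conv2d layer (ResNet stems use bias=False)."""
    return in_ch * out_ch * k * k + (out_ch if bias else 0)


def linear_params(in_features: int, out_features: int) -> int:
    """Weights plus biases in an nn.Linear layer."""
    return in_features * out_features + out_features


stem_7x7 = conv_params(3, 64, 7)  # torchvision ResNet-18 stem
stem_3x3 = conv_params(3, 64, 3)  # 3 x 3 stem, assuming 64 output channels as well
head = linear_params(512, 9)      # the 512 -> 9 classifier head used in this notebook
print(stem_7x7, stem_3x3, head)   # 9408 1728 4617
```

The head's 4,617 parameters match the "4.6 K" reported for `fc` in the model summary printed during training, and the difference between the two stems is small next to the 11.2 M parameters of the whole backbone.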

Several classification metrics from [torchmetrics](https://torchmetrics.readthedocs.io/en/stable/) are also logged to monitor performance on the validation and test sets; during training, only the loss is logged.
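All three metrics are macro-averaged, so each of the nine classes counts equally regardless of how many samples it has. For multiclass problems, macro accuracy is the mean of per-class recall, which can sit well below plain (micro) accuracy when a frequent class is predicted well and rare classes are not. A small NumPy sketch on toy labels; it assumes every class appears in the labels and does not reproduce torchmetrics' handling of edge cases such as classes absent from the data.

```python
import numpy as np


def per_class_counts(y_true, y_pred, num_classes):
    """True positives, false positives and false negatives for each class."""
    tp = np.array([np.sum((y_pred == c) & (y_true == c)) for c in range(num_classes)])
    fp = np.array([np.sum((y_pred == c) & (y_true != c)) for c in range(num_classes)])
    fn = np.array([np.sum((y_pred != c) & (y_true == c)) for c in range(num_classes)])
    return tp, fp, fn


def macro_scores(y_true, y_pred, num_classes):
    tp, fp, fn = per_class_counts(y_true, y_pred, num_classes)
    recall = tp / (tp + fn)           # per-class accuracy
    f1 = 2 * tp / (2 * tp + fp + fn)  # per-class F1
    return recall.mean(), f1.mean()   # macro average: unweighted mean over classes


# Imbalanced toy labels: class 0 dominates
y_true = np.array([0, 0, 0, 0, 0, 0, 1, 1, 2, 2])
y_pred = np.array([0, 0, 0, 0, 0, 0, 0, 1, 0, 2])
micro_acc = np.mean(y_true == y_pred)
macro_acc, macro_f1 = macro_scores(y_true, y_pred, 3)
print(round(micro_acc, 3), round(macro_acc, 3), round(macro_f1, 3))  # 0.8 0.667 0.73
```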

In [8]:
class WaferMapModule(pl.LightningModule):
    def __init__(self):
        super().__init__()
        
        resnet = lightly.models.ResNetGenerator("resnet-18")
        # lightly's resnet doesn't have the AdaptiveAvgPool, so add it back in
        self.backbone = nn.Sequential(
            *list(resnet.children())[:-1],  # everything up to but not including Linear layer at end
            nn.AdaptiveAvgPool2d(1),  # add in the AdaptiveAvgPool
        )
        self.fc = nn.Linear(512, 9)  # add in the Linear layer at the end

        # Note that MulticlassAccuracy and MulticlassAUROC use macro averaging by default
        self.val_accuracy = MulticlassAccuracy(num_classes=9)
        self.val_auroc = MulticlassAUROC(num_classes=9)
        self.val_f1 = MulticlassF1Score(num_classes=9, average="macro")
        

        self.test_accuracy = MulticlassAccuracy(num_classes=9)
        self.test_auroc = MulticlassAUROC(num_classes=9)
        self.test_f1 = MulticlassF1Score(num_classes=9, average="macro")

    def forward(self, x):
        f = self.backbone(x).flatten(start_dim=1)
        p = self.fc(f)
        return F.log_softmax(p, dim=1)
    
    def training_step(self, batch, batch_idx):
        x, y = batch
        logits = self(x)
        loss = F.nll_loss(logits, y)
        self.log("train_loss", loss, prog_bar=True)
        return loss

    def validation_step(self, batch, batch_idx):
        x, y = batch
        logits = self(x)
        loss = F.nll_loss(logits, y)
        self.val_accuracy.update(logits, y)
        self.val_auroc.update(logits, y)
        self.val_f1.update(logits, y)

        self.log("val_loss", loss, prog_bar=True)
        self.log("val_acc", self.val_accuracy, prog_bar=True)
        self.log("val_auroc", self.val_auroc, prog_bar=True)
        self.log("val_f1", self.val_f1, prog_bar=True)

    def test_step(self, batch, batch_idx):
        x, y = batch
        logits = self(x)
        loss = F.nll_loss(logits, y)
        self.test_accuracy.update(logits, y)
        self.test_auroc.update(logits, y)
        self.test_f1.update(logits, y)

        self.log("test_loss", loss, prog_bar=True)
        self.log("test_acc", self.test_accuracy, prog_bar=True)
        self.log("test_auroc", self.test_auroc, prog_bar=True)
        self.log("test_f1", self.test_f1, prog_bar=True)


    def configure_optimizers(self):
        optim = torch.optim.AdamW(self.parameters())
        return optim
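The model returns log-probabilities from `F.log_softmax` and trains with `F.nll_loss`; together these compute the standard cross-entropy loss, the same quantity `F.cross_entropy` computes directly from raw scores. A NumPy check of that identity on random scores (the helper functions are re-implementations for illustration, not the PyTorch ones):

```python
import numpy as np


def log_softmax(z):
    z = z - z.max(axis=1, keepdims=True)  # subtract the row max for numerical stability
    return z - np.log(np.exp(z).sum(axis=1, keepdims=True))


def nll_loss(log_probs, y):
    # Mean negative log-probability assigned to the true class
    return -log_probs[np.arange(len(y)), y].mean()


def cross_entropy(z, y):
    # Textbook form: -log(softmax(z)[y]) averaged over the batch
    p = np.exp(z) / np.exp(z).sum(axis=1, keepdims=True)
    return -np.log(p[np.arange(len(y)), y]).mean()


rng = np.random.default_rng(0)
z = rng.normal(size=(4, 9))  # raw scores for 4 samples and 9 classes
y = np.array([0, 3, 8, 5])
print(np.isclose(nll_loss(log_softmax(z), y), cross_entropy(z, y)))  # True
```

Exponentiating the log-probabilities recovers a proper probability distribution over the nine classes, so the rows of `np.exp(log_softmax(z))` each sum to 1.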

## Training, Validating, and Testing

The `train_val_test` function builds datasets and dataloaders from the train, validation, and test data passed in. It then trains a fresh model for at most 50 epochs, with early stopping on the validation loss: training ends once the validation loss has failed to improve for 5 consecutive epochs (for training sets with fewer than 1,000 samples) or 3 consecutive epochs (for larger training sets). After training, the test loss, accuracy, AUROC, and F1 score are reported on the test set.
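A simplified sketch of that stopping rule in plain Python, assuming Lightning's default `min_delta=0` and one validation run per epoch. Lightning's `EarlyStopping` callback has further options (for example stopping on non-finite values) that are not modeled here, and the function names are hypothetical.

```python
def early_stopping_epoch(val_losses, patience, max_epochs=50):
    """Return the epoch at which a patience-based stop would trigger (simplified sketch)."""
    best = float("inf")
    wait = 0
    for epoch, loss in enumerate(val_losses[:max_epochs], start=1):
        if loss < best:  # a strict improvement resets the counter
            best, wait = loss, 0
        else:
            wait += 1
            if wait >= patience:
                return epoch
    return min(len(val_losses), max_epochs)


def patience_for(n_train):
    # Mirrors the rule in train_val_test: more patience for small training sets
    return 5 if n_train < 1000 else 3


losses = [1.0, 0.8, 0.7, 0.72, 0.71, 0.75, 0.9, 0.8, 0.85]
print(early_stopping_epoch(losses, patience_for(500)))   # 8
print(early_stopping_epoch(losses, patience_for(5000)))  # 6
```

With the same loss curve, the larger patience lets training run two more epochs, which gives noisy validation losses from tiny training sets more chances to improve before training halts.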

In [10]:
def train_val_test(train_data, val_data, test_data, batch_size=32):
    train_dataset = WaferMapDataset(
        X=train_data.waferMap, y=train_data.failureCode, transform=train_transforms
    )
    val_dataset = WaferMapDataset(
        X=val_data.waferMap, y=val_data.failureCode, transform=test_transforms
    )
    test_dataset = WaferMapDataset(
        X=test_data.waferMap, y=test_data.failureCode, transform=test_transforms
    )

    train_loader = DataLoader(
        train_dataset, batch_size=batch_size, shuffle=True, drop_last=True
    )
    val_loader = DataLoader(
        val_dataset, batch_size=batch_size, shuffle=False, drop_last=False
    )
    test_loader = DataLoader(
        test_dataset, batch_size=batch_size, shuffle=False, drop_last=False
    )

    model = WaferMapModule()

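    use_gpu = torch.cuda.is_available()

    trainer = pl.Trainer(
        max_epochs=50,
        accelerator="gpu" if use_gpu else "cpu",
        # devices=-1 (all available devices) is only valid for GPUs;
        # the CPU accelerator expects a positive integer
        devices=-1 if use_gpu else 1,
        callbacks=[
            EarlyStopping(
                monitor="val_loss",
                mode="min",
                # be more patient with very small training sets
                patience=5 if len(train_data) < 1000 else 3,
            )
        ],
        # 16-bit mixed precision on GPU; full precision on CPU
        precision=16 if use_gpu else 32,
        enable_progress_bar=False,
    )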
    trainer.fit(model, train_dataloaders=train_loader, val_dataloaders=val_loader)
    trainer.test(model, test_loader)

    # Delete the model and trainer, then release cached CUDA memory before the next run
    del model
    del trainer
    torch.cuda.empty_cache()

In [11]:
import warnings

# suppress annoying torchmetrics and lightning warnings
warnings.filterwarnings("ignore", ".*does not have many workers.*")
warnings.filterwarnings("ignore", ".*meaningless.*")
warnings.filterwarnings("ignore", ".*log_every_n_steps.*")

# Train and evaluate on 1%, then 10%, then 20% of the labeled data
for train_split in (train_1_split, train_10_split, train_20_split):
    train_val_test(train_split, val_data, test_data)

Using 16bit native Automatic Mixed Precision (AMP)
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name          | Type               | Params
-----------------------------------------------------
0 | backbone      | Sequential         | 11.2 M
1 | fc            | Linear             | 4.6 K 
2 | val_accuracy  | MulticlassAccuracy | 0     
3 | val_auroc     | MulticlassAUROC    | 0     
4 | val_f1        | MulticlassF1Score  | 0     
5 | test_accuracy | MulticlassAccuracy | 0     
6 | test_auroc    | MulticlassAUROC    | 0     
7 | test_f1       | MulticlassF1Score  | 0     
-----------------------------------------------------
11.2 M    Trainable params
0         Non-trainable params
11.2 M    Total params
22.347    Total estimated model params size (MB)


NaN or Inf found in input tensor.
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
       Test metric             DataLoader 0
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
        test_acc            0.4580616354942322
       test_auroc           0.9077513217926025
         test_f1            0.4450047016143799
        test_loss           0.6908338665962219
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────


Using 16bit native Automatic Mixed Precision (AMP)
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
       Test metric             DataLoader 0
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
        test_acc             0.605180561542511
       test_auroc           0.9592916965484619
         test_f1            0.5469847917556763
        test_loss            0.470514178276062
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────


Using 16bit native Automatic Mixed Precision (AMP)
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
       Test metric             DataLoader 0
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
        test_acc             0.763986349105835
       test_auroc            0.989201009273529
         test_f1            0.7402098774909973
        test_loss           0.24100734293460846
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
