In [1]:
import os
import glob
import sys
sys.path.append('../')

import numpy as np
import pandas as pd
import seaborn as sns
import matplotlib as mpl
import matplotlib.pyplot as plt

import torch
import torch.nn as nn

from IPython.display import display

from datasets.wafer import WM811K
from datasets.wafer import get_dataloader
from datasets.transforms import WM811KTransform
from models.config import RESNET_BACKBONE_CONFIGS
from models.resnet import ResNetBackbone
from models.head import GAPClassifier

In [2]:
# Evaluation data: the labeled WM-811K test split, preprocessed with
# WM811KTransform in 'test' mode at size (96, 96).
TEST_DATA_DIR = '../data/wm811k/labeled/test'
test_transform = WM811KTransform(mode='test', size=(96, 96))
test_set = WM811K(TEST_DATA_DIR, transform=test_transform)

In [3]:
# Class names in index order. The final entry of idx2label is excluded
# (presumably a non-class placeholder — TODO confirm against WM811K).
labels = test_set.idx2label[:-1]
print('\n'.join(f"{idx}:{name}" for idx, name in enumerate(labels)))

0:center
1:donut
2:edge-loc
3:edge-ring
4:loc
5:random
6:scratch
7:near-full
8:none


In [4]:
# ResNet-18 ('18.original' layer layout) feature extractor.
# in_channels=2 must match the channel count produced by WM811KTransform
# (assumed — TODO confirm what the two channels encode).
backbone_layers = RESNET_BACKBONE_CONFIGS['18.original']
backbone = ResNetBackbone(layer_config=backbone_layers, in_channels=2)

In [5]:
# Classification head over the backbone's output channels, one logit per
# test-set class. Dropout is disabled; this notebook only evaluates.
head = GAPClassifier(
    dropout=0.,
    num_classes=test_set.num_classes,
    in_channels=backbone.out_channels,
)

In [15]:
def inference(dataset: torch.utils.data.Dataset, model: nn.Module, device: str):
    """Run ``model`` over every sample of ``dataset`` and collect its outputs.

    Samples are batched with ``get_dataloader`` (batch size 1024, no
    shuffling). Each batch's outputs are written back at the positions given by
    ``batch['idx']``, so row ``k`` of the result belongs to dataset item ``k``.
    A '.' is printed per batch as a progress indicator.

    Gradients are disabled internally, so the returned tensors never carry an
    autograd graph, even when the caller did not wrap the call in
    ``torch.no_grad()``. The model's train/eval mode is NOT changed; call
    ``model.eval()`` beforehand.

    Args:
        dataset: dataset exposing ``num_classes`` and ``__len__``. Its loader
            must yield dicts with 'idx', 'x' and 'y' entries.
        model: module mapping a batch ``x`` to logits. Assumed to return
            ``(batch, num_classes)`` so it fits the preallocated buffer.
        device: torch device string to run on, e.g. 'cuda:0' or 'cpu'.

    Returns:
        ``(logits, targets)`` on CPU:
        - float logits of shape ``(len(dataset), num_classes)``;
        - long targets of shape ``(len(dataset),)``.

    Raises:
        RuntimeError: if the loader did not yield every dataset index. Without
            this check, those rows would hold uninitialized memory.
    """
    data_loader = get_dataloader(
        dataset=dataset,
        batch_size=1024,
        shuffle=False,
        pin_memory=True,
    )

    num_samples, num_classes = len(dataset), dataset.num_classes
    logits = torch.empty(num_samples, num_classes, dtype=torch.float, device=device)
    targets = torch.empty(num_samples, dtype=torch.long, device=device)
    # torch.empty leaves unwritten rows as garbage, so track which were filled.
    filled = torch.zeros(num_samples, dtype=torch.bool, device=device)

    model = model.to(device)
    try:
        with torch.no_grad():
            for batch in data_loader:
                idx = batch['idx']
                logits[idx] = model(batch['x'].to(device))
                targets[idx] = batch['y'].to(device)
                filled[idx] = True
                print('.', end='')
    finally:
        # Release device memory even if a batch fails (e.g. CUDA OOM).
        model.to('cpu')
        torch.cuda.empty_cache()

    if not bool(filled.all()):
        missing = int((~filled).sum())
        raise RuntimeError(f"data loader skipped {missing} of {num_samples} samples")

    return logits.cpu(), targets.cpu()

In [16]:
from pytorch_lightning.metrics.functional import confusion_matrix

def plot_conf_mat(logits: torch.Tensor,
                  targets: torch.Tensor,
                  labels: list,
                  normalize=False):

    pred = logits.argmax(dim=1)
    
    cm = confusion_matrix(pred, targets, normalize=normalize)
    if not normalize:
        cm = cm.long()
    cm = pd.DataFrame(cm.numpy(), index=labels, columns=labels)
    
    with pd.option_context('precision', 3):
        display(cm.style.background_gradient(cmap=plt.cm.Greens, axis=0))

In [17]:
from pytorch_lightning.metrics.functional import stat_scores_multiple_classes

def plot_stat_scores(logits: torch.Tensor,
                     targets: torch.Tensor,
                     labels: list):
    
    probs = nn.functional.softmax(logits, dim=1)
    
    tp, fp, tn, fn, _ = stat_scores_multiple_classes(probs, targets, len(labels))
    precision = tp / (fp + tp)
    recall = tp / (fn + tp)
    f1 = 2 * ((precision * recall) / (precision + recall))
    
    stats = torch.stack([precision, recall, f1], dim=0)
    stats = pd.DataFrame(stats.numpy(), index=['precision', 'recall', 'f1'], columns=labels)
    stats['Average'] = stats.mean(axis=1)

    with pd.option_context("precision", 3):
        display(stats.style.background_gradient(cmap=plt.cm.Blues, axis=1))

In [18]:
def evaluate_model(ckpt: str, backbone: nn.Module, head: nn.Module, device: str,
                   dataset: torch.utils.data.Dataset = None, class_names: list = None):
    """Load a checkpoint into ``backbone`` + ``head`` and display test metrics.

    The checkpoint's 'backbone' and 'classifier' weights are loaded into the
    given modules, which are then chained and switched to eval mode. Both
    operations act on the objects passed in, which are shared across cells.
    The model is run with ``inference``. The function then shows:
    - the raw confusion matrix;
    - the row-normalized confusion matrix;
    - per-class precision/recall/F1.

    Args:
        ckpt: path to a checkpoint file.
        backbone: feature extractor receiving the 'backbone' weights.
        head: classifier receiving the 'classifier' weights.
        device: torch device string to run inference on.
        dataset: dataset to evaluate. Defaults to the notebook's ``test_set``.
        class_names: class names in index order. Defaults to the notebook's
            ``labels``.

    Raises:
        FileNotFoundError: if ``ckpt`` does not exist. Unlike an ``assert``,
            this check is not stripped under ``python -O``.
    """
    if not os.path.exists(ckpt):
        raise FileNotFoundError(f"Model checkpoint not found: {ckpt}")
    if dataset is None:
        dataset = test_set
    if class_names is None:
        class_names = labels

    backbone.load_weights_from_checkpoint(path=ckpt, key='backbone')
    head.load_weights_from_checkpoint(path=ckpt, key='classifier')
    model = nn.Sequential(backbone, head)
    model.eval()

    with torch.no_grad():
        logits, targets = inference(dataset=dataset, model=model, device=device)

    plot_conf_mat(logits, targets, class_names, normalize=False)
    plot_conf_mat(logits, targets, class_names, normalize=True)
    plot_stat_scores(logits, targets, class_names)

In [20]:
# WaPIRL (ResNet-18, Label=1%), checkpoints.balanced/ tree
ckpt_dir = '../checkpoints.balanced/wm811k/classification_pirl/resnet.18.original/SEED_003/LABEL_0.010/'
# glob's order is filesystem-dependent; sort so the pick is reproducible.
ckpt_matches = sorted(glob.glob(os.path.join(ckpt_dir, '**/last_model.pt'), recursive=True))
if not ckpt_matches:
    raise FileNotFoundError(f"No 'last_model.pt' found under {ckpt_dir}")
ckpt = ckpt_matches[0]

evaluate_model(ckpt, backbone, head, device='cuda:0')

.................

Unnamed: 0,center,donut,edge-loc,edge-ring,loc,random,scratch,near-full,none
center,396,0,1,0,4,0,3,0,25
donut,4,35,1,0,9,0,6,0,1
edge-loc,4,0,392,19,13,13,6,1,71
edge-ring,2,0,22,931,0,0,0,0,13
loc,25,6,47,0,207,1,17,0,56
random,3,0,0,1,4,76,0,0,3
scratch,3,0,4,1,28,0,20,0,63
near-full,0,0,0,0,0,6,0,9,0
none,43,0,82,62,75,2,62,0,14417


Unnamed: 0,center,donut,edge-loc,edge-ring,loc,random,scratch,near-full,none
center,0.923,0.0,0.002,0.0,0.011,0.0,0.025,0.0,0.002
donut,0.009,0.625,0.002,0.0,0.025,0.0,0.05,0.0,0.0
edge-loc,0.009,0.0,0.755,0.02,0.036,0.149,0.05,0.067,0.005
edge-ring,0.005,0.0,0.042,0.962,0.0,0.0,0.0,0.0,0.001
loc,0.058,0.107,0.091,0.0,0.577,0.011,0.143,0.0,0.004
random,0.007,0.0,0.0,0.001,0.011,0.874,0.0,0.0,0.0
scratch,0.007,0.0,0.008,0.001,0.078,0.0,0.168,0.0,0.004
near-full,0.0,0.0,0.0,0.0,0.0,0.069,0.0,0.6,0.0
none,0.1,0.0,0.158,0.064,0.209,0.023,0.521,0.0,0.978


Unnamed: 0,center,donut,edge-loc,edge-ring,loc,random,scratch,near-full,none,Average
precision,0.825,0.854,0.714,0.918,0.609,0.776,0.175,0.9,0.984,0.751
recall,0.923,0.625,0.755,0.962,0.577,0.874,0.168,0.6,0.978,0.718
f1,0.871,0.722,0.734,0.939,0.592,0.822,0.172,0.72,0.981,0.728


In [13]:
# Scratch (ResNet-18, Label=1%), checkpoints.balanced/ tree.
# NOTE(review): this uses SEED_001 while the WaPIRL 1% cell uses SEED_003;
# confirm the seeds are meant to differ before comparing the two.
ckpt_dir = '../checkpoints.balanced/wm811k/classification_scratch/resnet.18.original/SEED_001/LABEL_0.010/'
# glob's order is filesystem-dependent; sort so the pick is reproducible.
ckpt_matches = sorted(glob.glob(os.path.join(ckpt_dir, '**/last_model.pt'), recursive=True))
if not ckpt_matches:
    raise FileNotFoundError(f"No 'last_model.pt' found under {ckpt_dir}")
ckpt = ckpt_matches[0]

evaluate_model(ckpt, backbone, head, device='cuda:0')

.................

Unnamed: 0,center,donut,edge-loc,edge-ring,loc,random,scratch,near-full,none
center,341,28,1,0,5,4,47,1,2
donut,0,42,1,0,7,0,6,0,0
edge-loc,4,2,255,86,25,6,42,0,99
edge-ring,1,0,46,908,1,0,8,0,4
loc,20,15,54,1,129,3,74,0,63
random,2,5,3,1,0,75,0,1,0
scratch,0,1,5,2,26,2,52,0,31
near-full,0,0,0,0,0,5,0,10,0
none,88,3,451,164,619,19,3242,0,10157


Unnamed: 0,center,donut,edge-loc,edge-ring,loc,random,scratch,near-full,none
center,0.795,0.5,0.002,0.0,0.014,0.046,0.395,0.067,0.0
donut,0.0,0.75,0.002,0.0,0.019,0.0,0.05,0.0,0.0
edge-loc,0.009,0.036,0.491,0.089,0.07,0.069,0.353,0.0,0.007
edge-ring,0.002,0.0,0.089,0.938,0.003,0.0,0.067,0.0,0.0
loc,0.047,0.268,0.104,0.001,0.359,0.034,0.622,0.0,0.004
random,0.005,0.089,0.006,0.001,0.0,0.862,0.0,0.067,0.0
scratch,0.0,0.018,0.01,0.002,0.072,0.023,0.437,0.0,0.002
near-full,0.0,0.0,0.0,0.0,0.0,0.057,0.0,0.667,0.0
none,0.205,0.054,0.869,0.169,1.724,0.218,27.244,0.0,0.689


Unnamed: 0,center,donut,edge-loc,edge-ring,loc,random,scratch,near-full,none,Average
precision,0.748,0.438,0.312,0.781,0.159,0.658,0.015,0.833,0.981,0.547
recall,0.795,0.75,0.491,0.938,0.359,0.862,0.437,0.667,0.689,0.665
f1,0.771,0.553,0.382,0.853,0.22,0.746,0.029,0.741,0.809,0.567


In [14]:
# Release blocks held by PyTorch's CUDA caching allocator. `inference` already
# calls torch.cuda.empty_cache() when it completes, so this cell is presumably
# only needed after an interrupted evaluation.
torch.cuda.empty_cache()

In [21]:
# Scratch (ResNet-18, Label=100%), checkpoints/ tree

ckpt_pattern = (
    '../checkpoints/wm811k/classification_scratch/resnet.18.original/'
    'SEED_001/LABEL_1.000/**/best_model.pt'
)
# glob's order is filesystem-dependent, so sort before taking the last match.
# This selects the lexicographically greatest path (presumably the most recent
# run if run directories are timestamped — TODO confirm).
ckpt_matches = sorted(glob.glob(ckpt_pattern, recursive=True))
if not ckpt_matches:
    raise FileNotFoundError(f"No checkpoint matches {ckpt_pattern!r}")
ckpt = ckpt_matches[-1]

evaluate_model(ckpt, backbone, head, device='cuda:1')

....................................................................

Unnamed: 0,center,donut,edge-loc,edge-ring,loc,random,scratch,near-full,none
center,412,1,0,0,2,0,0,0,14
donut,3,49,0,0,4,0,0,0,0
edge-loc,1,1,436,19,11,0,0,0,51
edge-ring,0,0,6,959,0,0,0,0,3
loc,6,3,13,0,297,0,4,0,36
random,2,0,0,0,4,81,0,0,0
scratch,0,1,1,0,4,0,94,0,19
near-full,0,0,0,0,0,0,0,15,0
none,12,1,30,12,18,2,18,0,14650


Unnamed: 0,center,donut,edge-loc,edge-ring,loc,random,scratch,near-full,none
center,0.96,0.018,0.0,0.0,0.006,0.0,0.0,0.0,0.001
donut,0.007,0.875,0.0,0.0,0.011,0.0,0.0,0.0,0.0
edge-loc,0.002,0.018,0.84,0.02,0.031,0.0,0.0,0.0,0.003
edge-ring,0.0,0.0,0.012,0.991,0.0,0.0,0.0,0.0,0.0
loc,0.014,0.054,0.025,0.0,0.827,0.0,0.034,0.0,0.002
random,0.005,0.0,0.0,0.0,0.011,0.931,0.0,0.0,0.0
scratch,0.0,0.018,0.002,0.0,0.011,0.0,0.79,0.0,0.001
near-full,0.0,0.0,0.0,0.0,0.0,0.0,0.0,1.0,0.0
none,0.028,0.018,0.058,0.012,0.05,0.023,0.151,0.0,0.994


Unnamed: 0,center,donut,edge-loc,edge-ring,loc,random,scratch,near-full,none,Average
precision,0.945,0.875,0.897,0.969,0.874,0.976,0.81,1.0,0.992,0.926
recall,0.96,0.875,0.84,0.991,0.827,0.931,0.79,1.0,0.994,0.912
f1,0.953,0.875,0.868,0.98,0.85,0.953,0.8,1.0,0.993,0.919


In [22]:
# Scratch (ResNet-18, Label=100%), checkpoints.balanced/ tree

ckpt_pattern = (
    '../checkpoints.balanced/wm811k/classification_scratch/resnet.18.original/'
    'SEED_001/LABEL_1.000/**/best_model.pt'
)
# glob's order is filesystem-dependent, so sort before taking the last match.
# This selects the lexicographically greatest path (presumably the most recent
# run if run directories are timestamped — TODO confirm).
ckpt_matches = sorted(glob.glob(ckpt_pattern, recursive=True))
if not ckpt_matches:
    raise FileNotFoundError(f"No checkpoint matches {ckpt_pattern!r}")
ckpt = ckpt_matches[-1]

evaluate_model(ckpt, backbone, head, device='cuda:1')

....................................................................

Unnamed: 0,center,donut,edge-loc,edge-ring,loc,random,scratch,near-full,none
center,422,2,0,0,2,0,0,0,3
donut,2,50,0,0,4,0,0,0,0
edge-loc,0,0,478,10,11,0,2,0,18
edge-ring,0,0,10,958,0,0,0,0,0
loc,8,2,20,0,299,1,12,0,17
random,2,0,1,0,3,80,0,1,0
scratch,0,0,2,0,5,0,107,0,5
near-full,0,0,0,0,0,0,0,15,0
none,36,1,105,19,57,6,71,0,14448


Unnamed: 0,center,donut,edge-loc,edge-ring,loc,random,scratch,near-full,none
center,0.984,0.036,0.0,0.0,0.006,0.0,0.0,0.0,0.0
donut,0.005,0.893,0.0,0.0,0.011,0.0,0.0,0.0,0.0
edge-loc,0.0,0.0,0.921,0.01,0.031,0.0,0.017,0.0,0.001
edge-ring,0.0,0.0,0.019,0.99,0.0,0.0,0.0,0.0,0.0
loc,0.019,0.036,0.039,0.0,0.833,0.011,0.101,0.0,0.001
random,0.005,0.0,0.002,0.0,0.008,0.92,0.0,0.067,0.0
scratch,0.0,0.0,0.004,0.0,0.014,0.0,0.899,0.0,0.0
near-full,0.0,0.0,0.0,0.0,0.0,0.0,0.0,1.0,0.0
none,0.084,0.018,0.202,0.02,0.159,0.069,0.597,0.0,0.98


Unnamed: 0,center,donut,edge-loc,edge-ring,loc,random,scratch,near-full,none,Average
precision,0.898,0.909,0.776,0.971,0.785,0.92,0.557,0.938,0.997,0.861
recall,0.984,0.893,0.921,0.99,0.833,0.92,0.899,1.0,0.98,0.935
f1,0.939,0.901,0.842,0.98,0.808,0.92,0.688,0.968,0.988,0.893


In [114]:
# WaPIRL-pretrained (ResNet-18, Label=1%), checkpoints/ tree

ckpt_pattern = (
    '../checkpoints/wm811k/classification_pirl/resnet.18.original/'
    'SEED_001/LABEL_0.010/**/best_model.pt'
)
ckpt_matches = glob.glob(ckpt_pattern, recursive=True)
# Require exactly one match so a stale or duplicate run is never picked silently.
if len(ckpt_matches) != 1:
    raise ValueError(
        f"Expected exactly one checkpoint for {ckpt_pattern!r}, "
        f"found {len(ckpt_matches)}: {ckpt_matches}")
ckpt = ckpt_matches[0]

evaluate_model(ckpt, backbone, head, device='cuda:3')

....................................................................

Unnamed: 0,center,donut,edge-loc,edge-ring,loc,random,scratch,near-full,none
center,371,10,3,0,8,2,0,0,35
donut,1,43,1,0,5,4,0,0,2
edge-loc,0,0,341,19,34,7,5,0,113
edge-ring,0,0,65,864,0,1,0,0,38
loc,5,10,19,0,221,4,19,0,81
random,3,5,1,0,1,69,0,4,4
scratch,0,1,6,0,18,1,18,0,75
near-full,0,0,0,0,0,2,0,13,0
none,21,1,49,3,20,0,38,0,14611


Unnamed: 0,center,donut,edge-loc,edge-ring,loc,random,scratch,near-full,none
center,0.865,0.179,0.006,0.0,0.022,0.023,0.0,0.0,0.002
donut,0.002,0.768,0.002,0.0,0.014,0.046,0.0,0.0,0.0
edge-loc,0.0,0.0,0.657,0.02,0.095,0.08,0.042,0.0,0.008
edge-ring,0.0,0.0,0.125,0.893,0.0,0.011,0.0,0.0,0.003
loc,0.012,0.179,0.037,0.0,0.616,0.046,0.16,0.0,0.005
random,0.007,0.089,0.002,0.0,0.003,0.793,0.0,0.267,0.0
scratch,0.0,0.018,0.012,0.0,0.05,0.011,0.151,0.0,0.005
near-full,0.0,0.0,0.0,0.0,0.0,0.023,0.0,0.867,0.0
none,0.049,0.018,0.094,0.003,0.056,0.0,0.319,0.0,0.991


Unnamed: 0,center,donut,edge-loc,edge-ring,loc,random,scratch,near-full,none,Average
precision,0.925,0.614,0.703,0.975,0.72,0.767,0.225,0.765,0.977,0.741
recall,0.865,0.768,0.657,0.893,0.616,0.793,0.151,0.867,0.991,0.733
f1,0.894,0.683,0.679,0.932,0.664,0.78,0.181,0.812,0.984,0.734


In [115]:
# Scratch (ResNet-18, Label=1%), checkpoints/ tree

ckpt_pattern = (
    '../checkpoints/wm811k/classification_scratch/resnet.18.original/'
    'SEED_001/LABEL_0.010/**/best_model.pt'
)
ckpt_matches = glob.glob(ckpt_pattern, recursive=True)
# Require exactly one match so a stale or duplicate run is never picked silently.
if len(ckpt_matches) != 1:
    raise ValueError(
        f"Expected exactly one checkpoint for {ckpt_pattern!r}, "
        f"found {len(ckpt_matches)}: {ckpt_matches}")
ckpt = ckpt_matches[0]

evaluate_model(ckpt, backbone, head, device='cuda:3')

....................................................................

Unnamed: 0,center,donut,edge-loc,edge-ring,loc,random,scratch,near-full,none
center,369,13,0,0,2,2,0,0,43
donut,2,39,1,0,5,8,0,0,1
edge-loc,0,0,295,67,17,5,0,0,135
edge-ring,0,0,49,886,0,0,0,0,33
loc,14,10,55,1,178,6,2,0,93
random,2,3,3,0,0,76,0,2,1
scratch,0,0,6,1,17,3,3,0,89
near-full,0,0,0,0,0,4,0,11,0
none,24,0,51,32,32,9,4,0,14591


Unnamed: 0,center,donut,edge-loc,edge-ring,loc,random,scratch,near-full,none
center,0.86,0.232,0.0,0.0,0.006,0.023,0.0,0.0,0.003
donut,0.005,0.696,0.002,0.0,0.014,0.092,0.0,0.0,0.0
edge-loc,0.0,0.0,0.568,0.069,0.047,0.057,0.0,0.0,0.009
edge-ring,0.0,0.0,0.094,0.915,0.0,0.0,0.0,0.0,0.002
loc,0.033,0.179,0.106,0.001,0.496,0.069,0.017,0.0,0.006
random,0.005,0.054,0.006,0.0,0.0,0.874,0.0,0.133,0.0
scratch,0.0,0.0,0.012,0.001,0.047,0.034,0.025,0.0,0.006
near-full,0.0,0.0,0.0,0.0,0.0,0.046,0.0,0.733,0.0
none,0.056,0.0,0.098,0.033,0.089,0.103,0.034,0.0,0.99


Unnamed: 0,center,donut,edge-loc,edge-ring,loc,random,scratch,near-full,none,Average
precision,0.898,0.6,0.641,0.898,0.709,0.673,0.333,0.846,0.974,0.73
recall,0.86,0.696,0.568,0.915,0.496,0.874,0.025,0.733,0.99,0.684
f1,0.879,0.645,0.603,0.906,0.584,0.76,0.047,0.786,0.982,0.688


In [116]:
# WaPIRL-pretrained (ResNet-18, Label=10%), checkpoints/ tree

ckpt_pattern = (
    '../checkpoints/wm811k/classification_pirl/resnet.18.original/'
    'SEED_001/LABEL_0.100/**/best_model.pt'
)
ckpt_matches = glob.glob(ckpt_pattern, recursive=True)
# Require exactly one match so a stale or duplicate run is never picked silently.
if len(ckpt_matches) != 1:
    raise ValueError(
        f"Expected exactly one checkpoint for {ckpt_pattern!r}, "
        f"found {len(ckpt_matches)}: {ckpt_matches}")
ckpt = ckpt_matches[0]

evaluate_model(ckpt, backbone, head, device='cuda:3')

....................................................................

Unnamed: 0,center,donut,edge-loc,edge-ring,loc,random,scratch,near-full,none
center,407,2,0,0,1,0,0,0,19
donut,3,48,0,0,3,1,0,0,1
edge-loc,2,1,419,13,16,1,0,0,67
edge-ring,0,0,13,929,0,1,0,0,25
loc,14,8,23,0,262,0,17,0,35
random,5,2,7,0,1,68,1,2,1
scratch,2,1,3,0,18,2,34,0,59
near-full,0,0,2,0,0,1,0,12,0
none,19,1,41,8,28,2,19,0,14625


Unnamed: 0,center,donut,edge-loc,edge-ring,loc,random,scratch,near-full,none
center,0.949,0.036,0.0,0.0,0.003,0.0,0.0,0.0,0.001
donut,0.007,0.857,0.0,0.0,0.008,0.011,0.0,0.0,0.0
edge-loc,0.005,0.018,0.807,0.013,0.045,0.011,0.0,0.0,0.005
edge-ring,0.0,0.0,0.025,0.96,0.0,0.011,0.0,0.0,0.002
loc,0.033,0.143,0.044,0.0,0.73,0.0,0.143,0.0,0.002
random,0.012,0.036,0.013,0.0,0.003,0.782,0.008,0.133,0.0
scratch,0.005,0.018,0.006,0.0,0.05,0.023,0.286,0.0,0.004
near-full,0.0,0.0,0.004,0.0,0.0,0.011,0.0,0.8,0.0
none,0.044,0.018,0.079,0.008,0.078,0.023,0.16,0.0,0.992


Unnamed: 0,center,donut,edge-loc,edge-ring,loc,random,scratch,near-full,none,Average
precision,0.9,0.762,0.825,0.978,0.796,0.895,0.479,0.857,0.986,0.831
recall,0.949,0.857,0.807,0.96,0.73,0.782,0.286,0.8,0.992,0.796
f1,0.924,0.807,0.816,0.969,0.762,0.834,0.358,0.828,0.989,0.81


In [117]:
# Scratch (ResNet-18, Label=10%), checkpoints/ tree

ckpt_pattern = (
    '../checkpoints/wm811k/classification_scratch/resnet.18.original/'
    'SEED_001/LABEL_0.100/**/best_model.pt'
)
ckpt_matches = glob.glob(ckpt_pattern, recursive=True)
# Require exactly one match so a stale or duplicate run is never picked silently.
if len(ckpt_matches) != 1:
    raise ValueError(
        f"Expected exactly one checkpoint for {ckpt_pattern!r}, "
        f"found {len(ckpt_matches)}: {ckpt_matches}")
ckpt = ckpt_matches[0]

evaluate_model(ckpt, backbone, head, device='cuda:3')

....................................................................

Unnamed: 0,center,donut,edge-loc,edge-ring,loc,random,scratch,near-full,none
center,403,2,0,0,0,0,0,0,24
donut,2,44,0,0,9,0,0,0,1
edge-loc,1,0,400,19,19,1,0,0,79
edge-ring,0,0,10,939,0,1,0,0,18
loc,12,8,25,0,249,0,17,0,48
random,6,3,1,0,0,74,0,1,2
scratch,0,1,4,1,18,1,27,0,67
near-full,0,0,0,0,0,3,0,12,0
none,19,1,36,11,13,1,12,0,14650


Unnamed: 0,center,donut,edge-loc,edge-ring,loc,random,scratch,near-full,none
center,0.939,0.036,0.0,0.0,0.0,0.0,0.0,0.0,0.002
donut,0.005,0.786,0.0,0.0,0.025,0.0,0.0,0.0,0.0
edge-loc,0.002,0.0,0.771,0.02,0.053,0.011,0.0,0.0,0.005
edge-ring,0.0,0.0,0.019,0.97,0.0,0.011,0.0,0.0,0.001
loc,0.028,0.143,0.048,0.0,0.694,0.0,0.143,0.0,0.003
random,0.014,0.054,0.002,0.0,0.0,0.851,0.0,0.067,0.0
scratch,0.0,0.018,0.008,0.001,0.05,0.011,0.227,0.0,0.005
near-full,0.0,0.0,0.0,0.0,0.0,0.034,0.0,0.8,0.0
none,0.044,0.018,0.069,0.011,0.036,0.011,0.101,0.0,0.994


Unnamed: 0,center,donut,edge-loc,edge-ring,loc,random,scratch,near-full,none,Average
precision,0.91,0.746,0.84,0.968,0.808,0.914,0.482,0.923,0.984,0.842
recall,0.939,0.786,0.771,0.97,0.694,0.851,0.227,0.8,0.994,0.781
f1,0.924,0.765,0.804,0.969,0.747,0.881,0.309,0.857,0.989,0.805
