# Model

Convolutional AutoEncoder (CAE)

## Imports

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import (
    Dataset,
    DataLoader,
)
import torchvision.transforms as T

from sklearn.metrics import (
    f1_score,
    mean_squared_error,
    confusion_matrix,
    ConfusionMatrixDisplay,
)

In [None]:
from torchsummary import summary
from torchview import draw_graph

In [None]:
import pandas as pd
import numpy as np

from PIL import Image
import matplotlib.pyplot as plt

from glob import glob
from typing import Optional
from tqdm.notebook import tqdm

## GPU

In [None]:
def assign_gpu(gpu_id: int=0) -> torch.device:
    """Pick the compute device and, on CUDA hosts, make it the current one.

    Args:
        gpu_id: Index of the CUDA device to select.

    Returns:
        ``torch.device('cuda:<gpu_id>')`` when CUDA is available, otherwise
        ``torch.device('cpu')``.
    """
    if not torch.cuda.is_available():
        # torch.cuda.set_device would raise on a CPU-only host; fall back so
        # the rest of the notebook can still run.
        device = torch.device('cpu')
        print('CUDA is not available; assigned [device:CPU]')
        return device

    device = torch.device(f'cuda:{gpu_id}')
    torch.cuda.set_device(device)
    curr_gpu = torch.cuda.current_device()
    print(f'Assigned [device:GPU:{curr_gpu}]')
    return device

In [None]:
# Index of the CUDA device handed to assign_gpu; `device` is reused by every
# later cell that moves the model or tensors.
gpu_id = 2
device = assign_gpu(gpu_id=gpu_id)

## Dataset

In [None]:
class ThresholdTransform:
    """Binarise a tensor against a cut-off expressed on the 0-255 scale."""

    def __init__(self, thr_255):
        # Divide by 255 once so __call__ can compare values directly
        # (presumably inputs are already scaled to [0, 1] - confirm upstream).
        self.thr = thr_255 / 255.

    def __call__(self, x):
        """Return a tensor of ``x``'s dtype: 1 where ``x > thr``, 0 elsewhere."""
        above_threshold = x > self.thr
        return above_threshold.to(x.dtype)

In [None]:
class CustomTransform:
    """Split-aware preprocessing: resize, centre-crop, grayscale, binarise.

    Args:
        target_size: Side length of the square the image is resized to.
        crop_size: Size passed to the centre crop.
        bw_thresh: Black/white cut-off on the 0-255 scale.
    """

    def __init__(self, target_size: int, crop_size: int, bw_thresh: int) -> None:
        self.target_size = target_size
        self.crop_size = crop_size
        self.bw_thresh = bw_thresh
        self.transform_output = self._transform()

    def _transform(self):
        """Return a ``{'train': pipeline, 'test': pipeline}`` mapping."""
        def build_pipeline():
            side = self.target_size
            return T.Compose([
                T.Resize((side, side)),
                T.CenterCrop(self.crop_size),
                T.Grayscale(),
                T.ToTensor(),
                ThresholdTransform(thr_255=self.bw_thresh),
            ])

        # No data augmentation: both splits get the same steps, but each one
        # owns a separately built pipeline object.
        return {split: build_pipeline() for split in ('train', 'test')}

    def __call__(self, img: Image, split: str):
        """Apply the pipeline registered for ``split`` to ``img``."""
        pipeline = self.transform_output[split]
        return pipeline(img)

In [None]:
class CustomDataset(Dataset):
    """Image dataset read from ``<path>/normal`` and ``<path>/abnormal``.

    The ``'test'`` split collects both class folders; any other split collects
    only ``normal`` images. Labels are ``0`` for normal and ``1`` for abnormal.

    Args:
        path: Root directory containing the class sub-directories.
        split: Split name; also forwarded to ``transform``.
        extension: File extension (without dot) of the images to collect.
        transform: Optional callable invoked as ``transform(img, split)``.
    """

    # Class folder names; position 0 is treated as the normal class.
    label = ['normal', 'abnormal']

    def __init__(self, path, split: str='train',
                 extension: str='png', transform=None) -> None:
        self.path = path
        self.split = split
        self.extension = extension
        self.transform = transform
        self.image_path = []
        self.label_list = []

        self._get_image_path()

    def _get_image_path(self) -> None:
        """Fill ``image_path`` / ``label_list`` for the configured split."""
        classes = CustomDataset.label if self.split == 'test' else ['normal']
        for class_ in classes:
            self._get_split_dir(class_)

    def _get_split_dir(self, class_: str) -> None:
        """Append every image of ``class_`` together with its integer label."""
        label = 0 if class_ == 'normal' else 1
        img_dir_path = f'{self.path}/{class_}/*.{self.extension}'
        # glob returns files in filesystem order; sorting keeps sample indices
        # reproducible across runs and machines.
        full_image_path = sorted(glob(img_dir_path))
        self.image_path += full_image_path
        self.label_list += [label] * len(full_image_path)

    def __len__(self) -> int:
        return len(self.image_path)

    def __getitem__(self, idx) -> tuple[object, int]:
        """Return ``(sample, label)`` for position ``idx``.

        ``sample`` is the transform's output when a transform is set,
        otherwise an in-memory copy of the opened image.
        """
        file_path = self.image_path[idx]
        label = self.label_list[idx]
        # The context manager releases the file handle once pixels are read.
        with Image.open(file_path) as img:
            if self.transform is not None:
                sample = self.transform(img, self.split)
            else:
                # Detach the pixel data from the file before it is closed.
                sample = img.copy()
        return sample, label

### - Dataset Object

In [None]:
resize = 224  # side length images are resized to (CustomTransform target_size)
crop_size = 32  # centre-crop size; also the spatial size given to the model
bw_thresh = 90  # black/white cut-off on the 0-255 scale

In [None]:
# One shared transform; the dataset passes its own split name to it.
# The train split collects only `normal` images, the test split both classes.
transform = CustomTransform(resize, crop_size, bw_thresh)
train_set = CustomDataset(path='dataset/train', split='train', transform=transform)
test_set = CustomDataset(path='dataset/test', split='test', transform=transform)

### - DataLoader Object

In [None]:
# Mini-batch size used by both DataLoaders (and by the draw_graph input size).
batch_size = 128

In [None]:
# Training batches are shuffled; test batches keep the dataset order.
train_dataloader = DataLoader(
    dataset=train_set,
    batch_size=batch_size,
    shuffle=True,
)
test_dataloader = DataLoader(
    dataset=test_set,
    batch_size=batch_size,
    shuffle=False,
)

## Stage 1

- 국소패턴 예측

In [None]:
class Pooling(nn.Module):
    """2-D pooling layer whose flavour is selected by name.

    Args:
        type_: Names starting with ``'max'`` (case-insensitive) select max
            pooling; anything else selects average pooling.
        *args, **kwargs: Forwarded to both ``nn.MaxPool2d`` and ``nn.AvgPool2d``.
    """

    def __init__(self, type_: str, *args, **kwargs) -> None:
        super().__init__()
        self.type_ = type_.upper()
        # Both variants are built from the same arguments; forward picks one.
        self.max_pooling = nn.MaxPool2d(*args, **kwargs)
        self.avg_pooling = nn.AvgPool2d(*args, **kwargs)

    def forward(self, x) -> torch.Tensor:
        """Pool ``x`` with the layer chosen at construction time."""
        use_max = self.type_.startswith('MAX')
        pool = self.max_pooling if use_max else self.avg_pooling
        return pool(x)

In [None]:
class ConvBlock(nn.Module):
    """Encoder stage: convolution, ReLU, then 2x2 max pooling with stride 2.

    Args:
        c_in: Number of input channels.
        c_out: Number of output channels.
        *args, **kwargs: Forwarded to ``nn.Conv2d``.
    """

    def __init__(self, c_in: int, c_out: int, *args, **kwargs) -> None:
        super().__init__()
        stages = [
            nn.Conv2d(c_in, c_out, *args, **kwargs),
            nn.ReLU(),
            Pooling('max', kernel_size=2, stride=2),
        ]
        self.block = nn.Sequential(*stages)

    def forward(self, x) -> torch.Tensor:
        """Run ``x`` through the conv / ReLU / pooling stack."""
        out = self.block(x)
        return out

In [None]:
class TransConvBlock(nn.Module):
    """Decoder stage: transposed convolution, optionally followed by ReLU.

    Args:
        c_in: Number of input channels.
        c_out: Number of output channels.
        last_layer: When True the ReLU is skipped and the raw transposed
            convolution output is returned.
        *args, **kwargs: Forwarded to ``nn.ConvTranspose2d``.
    """

    def __init__(self, c_in: int, c_out: int, last_layer: bool=False, *args, **kwargs) -> None:
        super().__init__()
        self.last_layer = last_layer
        upsample = nn.ConvTranspose2d(c_in, c_out, *args, **kwargs)
        self.block = nn.Sequential(upsample)
        self.output = nn.ReLU()

    def forward(self, x) -> torch.Tensor:
        """Upsample ``x``; apply the ReLU unless this is the last layer."""
        upsampled = self.block(x)
        return upsampled if self.last_layer else self.output(upsampled)

In [None]:
class CAE(nn.Module):
    """Convolutional auto-encoder: conv encoder, dense bottleneck, deconv decoder.

    The encoder stacks one ``ConvBlock`` per adjacent channel pair (each halves
    the spatial size), the bottleneck squeezes the flattened feature map down
    to ``latent_dim`` and back, and the decoder mirrors the encoder with
    ``TransConvBlock`` stages (each doubles the spatial size). A sigmoid is
    applied to the final output.

    Args:
        enc_channel_list: Channel widths of the encoder, starting with the
            input channel count, e.g. ``[1, 8, 16, 32]``.
        input_size: Side length of the square input image. Must be divisible
            by ``2 ** (len(enc_channel_list) - 1)`` so the decoder restores it.
        latent_dim: Width of the dense bottleneck.

    Raises:
        ValueError: If fewer than two channel sizes are given, or if
            ``input_size`` is not divisible by the total down-sampling factor.
    """

    def __init__(self, enc_channel_list: list[int], input_size: int=32,
                 latent_dim: int=128) -> None:
        super(CAE, self).__init__()
        if len(enc_channel_list) < 2:
            raise ValueError(
                'enc_channel_list needs at least an input and an output channel size'
            )
        n_blocks = len(enc_channel_list) - 1
        downscale = 2 ** n_blocks
        if input_size % downscale != 0:
            raise ValueError(
                f'input_size ({input_size}) must be divisible by {downscale}'
            )
        dec_channel_list = enc_channel_list[::-1]

        # Encoder: one block per adjacent (in, out) channel pair.
        enc_layers = [
            ConvBlock(input_ch, output_ch, kernel_size=3, stride=1, padding=1)
            for input_ch, output_ch in zip(enc_channel_list, enc_channel_list[1:])
        ]

        # Bottleneck sizes follow from the encoder output instead of being
        # hard-coded; the defaults reproduce 32 channels x 4 x 4 -> 128.
        feat_ch = enc_channel_list[-1]
        feat_size = input_size // downscale
        n_features = feat_ch * feat_size * feat_size
        bottleneck_layers = [
            nn.Flatten(),
            nn.Linear(n_features, latent_dim),
            nn.Sequential(
                nn.Linear(latent_dim, n_features),
                nn.ReLU(),
            ),
            nn.Unflatten(dim=1, unflattened_size=(feat_ch, feat_size, feat_size)),
        ]

        # Decoder: only the final block skips its ReLU, because the sigmoid in
        # forward() is the output activation.
        dec_layers = [
            TransConvBlock(
                input_ch,
                output_ch,
                last_layer=(block_idx == n_blocks - 1),
                kernel_size=2,
                stride=2,
            )
            for block_idx, (input_ch, output_ch)
            in enumerate(zip(dec_channel_list, dec_channel_list[1:]))
        ]

        layers = enc_layers + bottleneck_layers + dec_layers
        self.net = nn.Sequential(*layers)
        self.output = nn.Sigmoid()

    def forward(self, x) -> torch.Tensor:
        """Reconstruct ``x``; the sigmoid keeps every output value in (0, 1)."""
        x = self.net(x)
        x = self.output(x)
        return x

Parameter setting

In [None]:
input_channel = 1  # channel count of the model input (the transform applies T.Grayscale)
enc_channel_list = [input_channel, 8, 16, 32]  # encoder channel widths, input first
lr = 1e-3  # learning rate handed to the Adam optimizer
n_epoch = 30  # number of passes over train_dataloader

Display model

In [None]:
# Summarise a throwaway model instance for one (channels, height, width) input.
summary(CAE(enc_channel_list).to(device), (input_channel, crop_size, crop_size))

In [None]:
# Draw the graph of another throwaway instance, traced with a full batch shape.
model_graph = draw_graph(
    CAE(enc_channel_list),
    input_size=(batch_size, input_channel, crop_size, crop_size),
)
# Last expression of the cell, so the notebook renders the graph.
model_graph.visual_graph

### - Train model

In [None]:
%%time
# Fresh model, reconstruction loss and optimizer for this training run.
model = CAE(enc_channel_list).to(device)
loss_fn = nn.MSELoss()
optimizer = torch.optim.Adam(model.parameters(), lr=lr)
verbose = True
loss_hist = []  # one entry per epoch: the sum of that epoch's batch losses

model.train()
for epoch in range(1, n_epoch + 1):
    batch_losses = []
    for x, _ in train_dataloader:
        x = x.to(device)
        # Clear gradients left over from the previous step before the new pass.
        optimizer.zero_grad()
        output = model(x)
        # The auto-encoder target is its own input.
        loss = loss_fn(output, x)
        loss.backward()
        optimizer.step()
        batch_losses.append(loss.item())
    train_loss = sum(batch_losses)
    loss_hist.append(train_loss)
    if verbose:
        print(f'[Epoch: {epoch}/{n_epoch}] Training Loss: {train_loss:.6f}')

### - Sub pattern 예측

In [None]:
# Per-sample reconstruction error (MSE) and true label for the whole test set.
reconstruction_error = []
true_labels = []

model.eval()
# Inference only: no_grad stops autograd from recording a graph per sample.
with torch.no_grad():
    for X_test, y_test in tqdm(test_set):
        X_test = X_test.to(device)
        X_test = X_test.unsqueeze(dim=0)  # [1, 32, 32] -> [1, 1, 32, 32]
        reconstruction = model(X_test)
        X_test = X_test.to('cpu').numpy()
        reconstruction = reconstruction.to('cpu').numpy()
        # Compare the single 2-D image plane of the reconstruction and input.
        mse = mean_squared_error(reconstruction[0][0], X_test[0][0])

        reconstruction_error.append(mse)
        true_labels.append(y_test)

Sub pattern과 SEM Image No 매칭

In [None]:
# Keep the part after the last '/', then the part before the first '_' as the
# SEM image number (assumes file names look like '<sem_no>_<rest>').
filename_list = [p.rsplit('/', 1)[-1] for p in test_set.image_path]
sem_no_list = [name.partition('_')[0] for name in filename_list]

# One row per sub-pattern: reconstruction error, true label and file info.
container = dict(
    re=reconstruction_error,
    label=true_labels,
    image_path=test_set.image_path,
    filename=filename_list,
    sem_no=sem_no_list,
)
result_table = pd.DataFrame(container)
normal = result_table[result_table['label'] == 0]
abnormal = result_table[result_table['label'] == 1]

In [None]:
plt.figure(figsize=(10, 5))
# Histogram of the normal reconstruction errors; keep its bin edges.
normal_hist = plt.hist(normal['re'], bins=100, color='grey', label='Normal')
bins = normal_hist[1]
plt.hist(abnormal['re'], bins=100, color='red', alpha=0.5, label='Abnormal')
plt.legend()
# The threshold is the largest bin edge of the normal histogram.
threshold = bins.max()
plt.axvline(threshold, color='grey', linestyle=':');

In [None]:
# Mark a sub-pattern abnormal (1) when its reconstruction error exceeds the
# threshold, otherwise normal (0).
prediction = (result_table['re'] > threshold).astype(int)
result_table['prediction'] = prediction

### - Reconstruction

In [None]:
# Pick one random test sample and show it next to its reconstruction.
random_test = np.random.choice(len(test_set))
filename = (test_set
            .image_path[random_test]
            .split('/')[-1])

model.eval()
X_test, y_test = test_set[random_test]
X_test = X_test.view((1, input_channel, crop_size, crop_size)).to(device)
# Inference only: no autograd graph is needed for the forward pass or the loss.
with torch.no_grad():
    reconstruction = model(X_test)
    mse = loss_fn(reconstruction, X_test).item()

# result_table rows follow test_set.image_path order, so look the row up by
# position; a filename match is ambiguous when both class folders contain
# a file with the same name.
prediction = result_table['prediction'].iloc[random_test]

print(f'{filename}')
print(f'Reconstruction Error: {mse}')
print('True Class:', y_test)
print('Predicted Class:', prediction)

# Channel-first (C, H, W) -> channel-last (H, W, C) for imshow.
transpose_axes = (1, 2, 0)
transformed_input_img = np.transpose(
    X_test.reshape(-1, crop_size, crop_size).to('cpu').numpy(),
    transpose_axes,
)
recon_img = np.transpose(
    reconstruction.reshape(-1, crop_size, crop_size).to('cpu').numpy(),
    transpose_axes,
)
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(6, 3))
ax1.imshow(transformed_input_img)
ax1.set_title('Transformed Original Image')
ax2.imshow(recon_img)
ax2.set_title('Reconstructed Image')
plt.tight_layout();

## Stage 2

- 국소패턴 예측 결과를 바탕으로 전체 SEM Image에 대해 정상/불량 예측

In [None]:
# Collapse the per-sub-pattern rows of result_table into one row per SEM image.
# sort=False keeps the images in order of first appearance.
sem_no_res = result_table.groupby('sem_no', sort=False)

# An image is abnormal (1) as soon as one of its sub-patterns is: the max of
# 0/1 values is 1 exactly when any value is 1. The same rule is used for the
# true label and for the prediction so the two columns are comparable.
# NOTE(review): the previous version labelled an image normal whenever it had
# any sub-pattern in the normal folder, even if it also had abnormal ones;
# confirm the any-abnormal rule matches how the test folders were assembled.
final_prediction = (
    sem_no_res
    .agg(true_label=('label', 'max'), prediction=('prediction', 'max'))
    .reset_index()
    .rename(columns={'sem_no': 'test_sem_no'})
)

## Evaluation

### - F1 Score

In [None]:
# F1 score of the per-SEM-image predictions against the per-image true labels.
f1_score(final_prediction['true_label'], final_prediction['prediction'])

### - Confusion Matrix

In [None]:
# Confusion matrix over SEM images (true labels first, predictions second).
cm = confusion_matrix(final_prediction['true_label'], final_prediction['prediction'])
cm_plot = ConfusionMatrixDisplay(cm)
cm_plot.plot(cmap='Blues');

## Result Analysis

In [None]:
def plot_input(idx: int, img_path, n_row: int, n_col: int,
               title: Optional[str]=None, figure=None) -> None:
    """Draw the image at ``img_path`` into one cell of a subplot grid.

    Args:
        idx: Zero-based cell index within the ``n_row`` x ``n_col`` grid.
        img_path: Path of the image file to display.
        n_row: Number of grid rows.
        n_col: Number of grid columns.
        title: Optional axes title; skipped when empty or None.
        figure: Figure to draw on. Defaults to the current pyplot figure, so
            callers that create a figure right before calling need not pass it.
    """
    target_fig = plt.gcf() if figure is None else figure
    # Open first so a missing file fails before an empty subplot is added;
    # the context manager closes the file once imshow has read the pixels.
    with Image.open(img_path) as img_obj:
        ax = target_fig.add_subplot(n_row, n_col, idx + 1)
        ax.set_axis_off()
        if title:
            ax.set(title=title)
        ax.imshow(img_obj)

### - Wrong Prediction

In [None]:
# NOTE: despite its name, `correct_cond` is True where the prediction is WRONG
# (true label != prediction), so `wrong_pred` holds the misclassified SEM images.
correct_cond = final_prediction['true_label'] != final_prediction['prediction']
wrong_pred = final_prediction[correct_cond]

In [None]:
# Grid of the full SEM images whose image-level prediction was wrong.
ext = 'JPG'
n_col = 4
n_sem_no = len(wrong_pred)
n_row = int(np.ceil(n_sem_no / n_col))
fig = plt.figure(figsize=(15, 15))

# The first three columns are SEM number, true label and prediction.
# `sem_no` keeps its last value after the loop; the next cell reads it.
wrong_rows = wrong_pred.iloc[:, :3].itertuples(index=False, name=None)
for i, (sem_no, true_label, pred) in enumerate(wrong_rows):
    img_path = f'image/{sem_no}.{ext}'
    title = f'SEM No: {sem_no} | Class: {true_label} | Pred: {pred}'
    plot_input(idx=i, img_path=img_path, n_row=n_row, n_col=n_col, title=title)

fig.tight_layout()

In [None]:
# `sem_no` is whatever the wrong-prediction loop left behind, i.e. the last
# misclassified SEM image when the cells are run in order.
target = result_table[result_table['sem_no'] == sem_no]

In [None]:
# Paths of that image's sub-patterns whose prediction differs from the label.
paths = list(target[target['label'] != target['prediction']]['image_path'])

In [None]:
# Grid of the misclassified sub-pattern images collected in `paths`.
n_col = 4
n_pattern = len(paths)
n_row = (n_pattern + n_col - 1) // n_col  # ceiling division
fig = plt.figure(figsize=(15, 15))

for i, file in enumerate(paths):
    # The title is the file name up to its first dot.
    basename = file.rsplit('/', 1)[-1]
    sem_no = basename.partition('.')[0]
    title = f'Image: {sem_no}'
    plot_input(idx=i, img_path=file, n_row=n_row, n_col=n_col, title=title)

fig.tight_layout()

### - By Sample

In [None]:
# Image paths of the sub-patterns of SEM image "00341" predicted abnormal.
is_sample = result_table['sem_no'] == "00341"
is_flagged = result_table['prediction'] == 1
sample_filepath = result_table[is_sample & is_flagged].get('image_path')

In [None]:
# Grid of the selected sample's flagged sub-patterns, titled with their paths.
n_col = 4
n_sample = len(sample_filepath)
n_row = (n_sample + n_col - 1) // n_col  # ceiling division

fig = plt.figure(figsize=(15, 15))

for i, path in enumerate(sample_filepath):
    plot_input(idx=i, img_path=path, n_row=n_row, n_col=n_col, title=path)

fig.tight_layout()