# 繰り返し処理による異常検知

---
## 目的

Variational Auto Encoder (VAE) を用いた繰り返し処理による異常検知の仕組みについて理解する．

## 準備

### Google Colaboratoryの設定確認・変更
本チュートリアルではPyTorchを利用してニューラルネットワークの実装を確認，学習および評価を行います．
**GPUを用いて処理を行うために，上部のメニューバーの「ランタイム」→「ランタイムのタイプを変更」からハードウェアアクセラレータをGPUにしてください．**

## モジュールのインポート
はじめに必要なモジュールをインポートする．

### GPUの確認
GPUを使用した計算が可能かどうかを確認します．

`Use CUDA: True`と表示されれば，GPUを使用した計算をPyTorchで行うことが可能です．
Falseとなっている場合は，上記の「Google Colaboratoryの設定確認・変更」に記載している手順にしたがって，設定を変更した後に，モジュールのインポートから始めてください．

In [None]:
# モジュールのインポート
import os
from time import time
import numpy as np
from PIL import Image
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable

import torchvision
import torchvision.transforms as transforms

# Check whether CUDA is available to PyTorch; `use_cuda` is read by the
# later cells to decide whether the model and tensors are moved to the GPU.
use_cuda = torch.cuda.is_available()
print('Use CUDA:', use_cuda)

## データセット

**MVTec-AD**

演習に使用するデータセット


In [None]:
!wget http://www.mprg.cs.chubu.ac.jp/~hirakawa/share/tutorial_data/anomaly_detection_data.zip
!unzip anomaly_detection_data.zip

In [None]:
class MVTecAD(torch.utils.data.Dataset):
    """Dataset over a single folder of sequentially numbered PNG images.

    Sample ``i`` is read from ``'{:0>3}.png'.format(i)`` inside ``image_dir``
    (i.e. ``000.png``, ``001.png``, ...).

    Args:
        image_dir: directory that contains the PNG images.
        transform: callable applied to each loaded PIL image; a falsy value
            (e.g. ``None``) returns the PIL image unchanged.
    """

    def __init__(self, image_dir, transform):
        self.transform = transform
        self.image_dir = image_dir

    def __len__(self):
        """Return the number of ``.png`` files in ``image_dir``."""
        # Only PNG files are counted: other directory entries (notebook
        # checkpoints, stray files, ...) cannot be opened by __getitem__ and
        # would otherwise make the reported length too large.
        return sum(1 for name in os.listdir(self.image_dir) if name.endswith('.png'))

    def __getitem__(self, i):
        """Load image ``i`` and return ``(image, label)``.

        The label is always ``torch.zeros(1)``; the visible code never reads it.
        """
        filename = '{:0>3}.png'.format(i)
        image = Image.open(os.path.join(self.image_dir, filename))
        if self.transform:
            image = self.transform(image)
        return image, torch.zeros(1)

## ネットワークモデル


In [None]:
class VAE(nn.Module):
    """Convolutional variational autoencoder.

    The encoder reduces the input image to a ``z_dim``-channel 1x1 feature,
    two linear heads turn it into the mean and log-variance of the latent
    distribution, and the decoder (nearest-neighbour upsampling + convolutions,
    ending in a sigmoid) maps a sampled latent vector back to an image.

    The spatial sizes in the per-layer comments assume a 128x128 input.

    Args:
        z_dim: dimensionality of the latent vector.
        input_c: number of channels of the input (and reconstructed) image.
    """

    def __init__(self, z_dim=100, input_c=1):
        super(VAE, self).__init__()

        self.z_dim = z_dim

        # Encoder
        self.encoder = nn.Sequential(
                nn.Conv2d(input_c, 32, kernel_size=4, stride=2, padding=1), # 128 -> 64
                nn.LeakyReLU(negative_slope=0.2, inplace=True),
                nn.Conv2d(32, 32, kernel_size=4, stride=2, padding=1),      # 64 -> 32
                nn.LeakyReLU(negative_slope=0.2, inplace=True),
                nn.Conv2d(32, 32, kernel_size=3, stride=1, padding=1),      # 32 -> 32
                nn.LeakyReLU(negative_slope=0.2, inplace=True),
                nn.Conv2d(32, 64, kernel_size=4, stride=2, padding=1),      # 32 -> 16
                nn.LeakyReLU(negative_slope=0.2, inplace=True),
                nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1),      # 16 -> 16
                nn.LeakyReLU(negative_slope=0.2, inplace=True),
                nn.Conv2d(64, 128, kernel_size=4, stride=2, padding=1),     # 16 -> 8
                nn.LeakyReLU(negative_slope=0.2, inplace=True),
                nn.Conv2d(128, 64, kernel_size=3, stride=1, padding=1),     # 8 -> 8
                nn.LeakyReLU(negative_slope=0.2, inplace=True),
                nn.Conv2d(64, 32, kernel_size=3, stride=1, padding=1),      # 8 -> 8
                nn.LeakyReLU(negative_slope=0.2, inplace=True),
                nn.Conv2d(32, z_dim, kernel_size=8, stride=1)               # 8 -> 1
            )

        # Linear heads producing the latent mean and log-variance
        self.mu_fc = nn.Linear(z_dim, z_dim)
        self.logvar_fc = nn.Linear(z_dim, z_dim)

        # Decoder: each Upsample enlarges the map, the following convs refine it
        self.decoder = nn.Sequential(
            nn.Upsample(scale_factor=8, mode='nearest'),
            nn.Conv2d(z_dim, 32, kernel_size=3, stride=1, padding=1),  # 1 -> 8
            nn.LeakyReLU(negative_slope=0.2, inplace=True),
            nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1),     # 8 -> 8
            nn.LeakyReLU(negative_slope=0.2, inplace=True),
            nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1),    # 8 -> 8
            nn.LeakyReLU(negative_slope=0.2, inplace=True),
            nn.Upsample(scale_factor=2, mode='nearest'),
            nn.Conv2d(128, 64, kernel_size=3, stride=1, padding=1),    # 8 -> 16
            nn.LeakyReLU(negative_slope=0.2, inplace=True),
            nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1),     # 16 -> 16
            nn.LeakyReLU(negative_slope=0.2, inplace=True),
            nn.Upsample(scale_factor=2, mode='nearest'),
            nn.Conv2d(64, 32, kernel_size=3, stride=1, padding=1),     # 16 -> 32
            nn.LeakyReLU(negative_slope=0.2, inplace=True),
            nn.Conv2d(32, 32, kernel_size=3, stride=1, padding=1),     # 32 -> 32
            nn.LeakyReLU(negative_slope=0.2, inplace=True),
            nn.Upsample(scale_factor=2, mode='nearest'),
            nn.Conv2d(32, 32, kernel_size=3, stride=1, padding=1),     # 32 -> 64
            nn.LeakyReLU(negative_slope=0.2, inplace=True),
            nn.Upsample(scale_factor=2, mode='nearest'),
            nn.Conv2d(32, input_c, kernel_size=3, stride=1, padding=1),  # 64 -> 128
            nn.Sigmoid())

    def initialize(self):
        """Re-initialise all sub-module parameters in place.

        Conv2d weights get Kaiming-normal init, Linear weights Xavier-uniform
        init, BatchNorm2d/LayerNorm weights are set to 1, and every bias is
        set to 0. ``__init__`` does not call this method; it has to be invoked
        explicitly.
        """
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight)
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.BatchNorm2d) or isinstance(m, nn.LayerNorm):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.Linear):
                nn.init.xavier_uniform_(m.weight)
                nn.init.constant_(m.bias, 0)

    def reparameterize(self, mu, logvar):
        """Sample ``z = mu + eps * std`` with ``eps ~ N(0, I)``.

        ``std`` is ``exp(0.5 * logvar)``. The sample is drawn on every call;
        there is no check of ``self.training``, so eval mode is stochastic too.
        """
        std = torch.exp(0.5*logvar)
        eps = torch.randn_like(std)
        return mu + eps*std

    def forward(self, x):
        """Encode ``x``, sample a latent vector and return the reconstruction.

        Side effect: the latent mean and log-variance of this call are stored
        in ``self.mu`` / ``self.logvar`` so the caller can build the KL term.
        Both are passed through ``squeeze()``, which removes every size-1
        dimension (including the batch dimension when the batch size is 1).
        """
        h      = self.encoder(x)

        # Drop the trailing spatial dimensions before the linear heads
        h = torch.flatten(h, start_dim=1)
        mu     = self.mu_fc(h)                    # mean vector
        logvar = self.logvar_fc(h)                # log-variance
        z      = self.reparameterize(mu, logvar)  # latent variable

        # Reshape the latent vector to (batch, channels, 1, 1) for the decoder
        x_hat  = self.decoder(z.view(z.size(0), -1, 1, 1))
        self.mu     = mu.squeeze()
        self.logvar = logvar.squeeze()
        return x_hat

## 学習

In [None]:
# Training hyper-parameters read by the training loop below
epochs = 1000     # number of passes over the training set
batch_size = 10   # samples per mini-batch

# Loss function
def loss_function(recon_x, x, mu, logvar):
    """Return ``(total, reconstruction, kl)`` for a VAE output.

    The reconstruction term is the binary cross-entropy between ``recon_x``
    and ``x`` summed over all elements; the KL term is
    ``-0.5 * sum(1 + logvar - mu^2 - exp(logvar))``. ``total`` is their sum.
    """
    bce_term = F.binary_cross_entropy(recon_x, x, reduction='sum')
    per_dim = 1 + logvar - mu.pow(2) - logvar.exp()
    kl_term = -0.5 * per_dim.sum()
    total = bce_term + kl_term
    return total, bce_term, kl_term

# Dataset: resize to 128x128, apply random affine / flip augmentation and
# convert to a tensor.
transform = transforms.Compose([
                transforms.Resize((128, 128)),
                transforms.RandomAffine(degrees=[-60, 60], translate=(0.1, 0.1), scale=(0.5, 1.5)),
                transforms.RandomHorizontalFlip(p=0.5),
                transforms.RandomVerticalFlip(p=0.5),
                transforms.ToTensor()
            ])
train_data = MVTecAD(image_dir="./mvtec_anomaly_detection/capsule/train/good", transform=transform)
# Never start more loader processes than the machine has CPUs.
num_workers = min(40, os.cpu_count() or 1)
train_loader = torch.utils.data.DataLoader(train_data, batch_size=batch_size, shuffle=True, num_workers=num_workers)

# Network model and optimizer
model = VAE(z_dim=100, input_c=3)
if use_cuda:
    model = model.cuda()
optimizer = torch.optim.Adam(model.parameters(), lr=2e-4, betas=(0.5, 0.999), weight_decay=1e-5)

# torch.save does not create missing directories, so make sure the snapshot
# folder exists before the long training run starts.
snapshot_dir = "snapshot-mvtec/capsule_vae"
os.makedirs(snapshot_dir, exist_ok=True)

model.train()
for epoch in range(1, epochs + 1):
    for idx, (inputs, _) in enumerate(train_loader):
        if use_cuda:
            inputs = inputs.cuda()
        output = model(inputs)
        # model.mu / model.logvar are set as a side effect of the forward pass
        loss, _, _ = loss_function(output, inputs, model.mu, model.logvar)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Log the first batch of every 10th epoch
        if idx % 100 == 0 and epoch % 10 == 0:
            print('%d epoch [%d/%d] | loss: %.4f |' % (epoch, idx, len(train_loader), loss.item()))

    # Save a snapshot every 100 epochs
    if epoch % 100 == 0:
        torch.save(model.state_dict(), os.path.join(snapshot_dir, "anomaly_det_model_%04d.pt" % epoch))

### テスト（学習データで）

In [None]:
# Load the final snapshot. map_location makes a checkpoint that was saved on
# the GPU loadable on a CPU-only runtime as well.
checkpoint_path = "snapshot-mvtec/capsule_vae/anomaly_det_model_1000.pt"
model.load_state_dict(torch.load(checkpoint_path, map_location='cuda' if use_cuda else 'cpu'))

model.eval()
with torch.no_grad():
    for ind, (inputs, _) in enumerate(train_loader):
        if use_cuda:
            inputs = inputs.cuda()

        reconstructed = model(inputs)
        # First sample of the batch, converted from CHW to HWC for matplotlib
        b = inputs.cpu().numpy()[0].transpose(1,2,0)
        a = reconstructed.cpu().numpy()[0].transpose(1,2,0)

        # Per-pixel absolute reconstruction error
        diff = np.abs(a - b)

        # Left: input, middle: reconstruction, right: absolute difference
        plt.figure(figsize=(12, 4))
        plt.subplot(1, 3, 1)
        plt.imshow(b, cmap='gray', vmin = 0, vmax = 1, interpolation='none')
        plt.subplot(1, 3, 2)
        plt.imshow(a, cmap='gray', vmin = 0, vmax = 1, interpolation='none')
        plt.subplot(1, 3, 3)
        plt.imshow(diff, interpolation='none')
        plt.show()

        # Only the first two batches are visualised
        if ind == 1:
            break

## テスト

In [None]:
# Test-time preprocessing: resize only, no augmentation
transform_test = transforms.Compose([transforms.Resize((128,128)), transforms.ToTensor()])
# Normal and defective ("crack") test images get distinct names so that the
# second dataset no longer silently overwrites the first one.
test_good_data = MVTecAD(image_dir="./mvtec_anomaly_detection/capsule/test/good", transform=transform_test)
test_bad_data = MVTecAD(image_dir="./mvtec_anomaly_detection/capsule/test/crack", transform=transform_test)
# The defective images are visualised; pass test_good_data here to inspect
# normal images instead.
test_loader = torch.utils.data.DataLoader(test_bad_data, batch_size=1, shuffle=False)

model.eval()
with torch.no_grad():
    for ind, (inputs, _) in enumerate(test_loader):
        if use_cuda:
            inputs = inputs.cuda()

        reconstructed = model(inputs)
        # Convert from CHW to HWC for matplotlib
        b = inputs.cpu().numpy()[0].transpose(1,2,0)
        a = reconstructed.cpu().numpy()[0].transpose(1,2,0)

        # Per-pixel absolute reconstruction error
        diff = np.abs(a - b)

        # Left: input, middle: reconstruction, right: absolute difference
        plt.figure(figsize=(12, 4))
        plt.subplot(1, 3, 1)
        plt.imshow(b, cmap='gray', vmin = 0, vmax = 1, interpolation='none')
        plt.subplot(1, 3, 2)
        plt.imshow(a, cmap='gray', vmin = 0, vmax = 1, interpolation='none')
        plt.subplot(1, 3, 3)
        plt.imshow(diff, interpolation='none')
        plt.show()

        # Only the first three samples are visualised
        if ind == 2:
            break

## 繰り返しによる異常検知

In [None]:
# Hyper-parameters of the iterative projection
max_iter = 99       # number of update steps after the initial one
alpha = 0.5         # step size of each update
lam = 0.05          # weight of the L1 term that keeps x_t close to the input
decay_rate = 0.1    # factor applied to the step size when decay is enabled
minimum = 1e12      # initial value of the lowest reconstruction loss seen
th = 0.5            # stop as soon as the reconstruction loss is <= th
# NOTE(review): `use_decay_lr` was referenced but never defined, and the
# branch that read it could not be reached in practice. False keeps the
# results the cell produced so far; set it to True to shrink the step size
# whenever the reconstruction loss stops improving.
use_decay_lr = False

transform_test = transforms.Compose([transforms.Resize((128, 128)), transforms.ToTensor()])
test_bad_data = MVTecAD(image_dir="./mvtec_anomaly_detection/capsule/test/crack", transform=transform_test)
test_bad_loader = torch.utils.data.DataLoader(test_bad_data, batch_size=1, shuffle=False)

model.eval()

loss_func = nn.MSELoss(reduction='sum')
if use_cuda:
    loss_func = loss_func.cuda()

for index, (x_org, _) in enumerate(test_bad_loader):
    # The second sample is used as the example, so skip the first one
    if index == 0:
        continue

    # Keep the untouched input (HWC) for plotting
    orig_img = x_org[0].numpy().transpose(1, 2, 0)

    x_t_images = []
    grad_images = []
    reconstructed_images = []

    if use_cuda:
        x_org = x_org.cuda()

    # Per-sample copies so the module-level settings are left untouched
    step_size = alpha
    best_loss = minimum

    # Initial step: gradient of the reconstruction error w.r.t. the input
    x_org.requires_grad_(True)
    rec_x = model(x_org).detach()

    loss = loss_func(x_org, rec_x)
    loss.backward()
    grads = x_org.grad.detach().clone()
    # From here on the input is a constant reference. Detaching it keeps the
    # L1 term below from accumulating gradients into x_org.grad.
    x_org = x_org.detach()
    x_t = x_org - step_size*grads*(x_org - rec_x)**2

    grad_images.append(grads[0].cpu().numpy().transpose(1,2,0))
    reconstructed_images.append(rec_x[0].cpu().numpy().transpose(1,2,0))
    x_t_images.append(x_t[0].cpu().numpy().transpose(1,2,0))

    losses = torch.zeros(max_iter)

    for i in range(max_iter):
        # Clamp to the valid pixel range and start a fresh leaf tensor
        x_t = x_t.detach().clamp(min=0, max=1).requires_grad_(True)
        rec_x = model(x_t).detach()
        rec_loss = loss_func(x_t, rec_x)
        rec_value = rec_loss.item()
        losses[i] = rec_value

        # Optional step-size decay when the loss got worse than the best so far
        if use_decay_lr and rec_value > best_loss:
            step_size = step_size * decay_rate
        best_loss = min(best_loss, rec_value)
        if rec_value <= th:
            break

        # Reconstruction error plus L1 distance to the original input
        l1 = torch.abs(x_t - x_org).sum()
        loss = rec_loss + lam*l1
        loss.backward()
        grads = x_t.grad.detach()

        # Weight the gradient by the squared reconstruction error per pixel
        mask = (x_t - rec_x)**2
        energy = grads * mask

        x_t = x_t - step_size*energy

        grad_images.append(grads[0].cpu().numpy().transpose(1,2,0))
        reconstructed_images.append(rec_x[0].cpu().numpy().transpose(1,2,0))
        x_t_images.append(x_t[0].detach().cpu().numpy().transpose(1,2,0))

    plt.imshow(orig_img, vmin = 0, vmax = 1, interpolation='none')
    plt.title("orig image")
    plt.show()

    # 10 images per row; the grid grows when max_iter is raised above 99
    n_cols = 10
    n_rows = max(10, -(-len(x_t_images) // n_cols))
    plt.figure(figsize=(30, 3 * n_rows))
    for k, step_img in enumerate(x_t_images):
        plt.subplot(n_rows, n_cols, k + 1)
        plt.imshow(step_img, vmin = 0, vmax = 1, interpolation='none')
    plt.show()

    break

## 課題

1. 反復回数を増加させたときにどのような再構成画像が生成されるか確認しましょう．

## 参考文献

[1] David Dehaene, Oriel Frigo, Sébastien Combrexelle, Pierre Eline, "Iterative energy-based projection on a normal data manifold for anomaly localization," in ICLR, 2020.