# Convolutional autoencoder for image anomaly detection

This notebook uses a convolutional autoencoder to find anomalies in images. The dataset is the MVTec AD bottle category, and the loss function is [SSIM](https://en.wikipedia.org/wiki/Structural_similarity_index_measure). The approach was prompted by [The MVTec Anomaly Detection Dataset: A Comprehensive Real-World Dataset for Unsupervised Anomaly Detection](https://link.springer.com/content/pdf/10.1007/s11263-020-01400-4.pdf) by Bergmann et al.

- Anomalous images: the [MVTEC](https://www.mvtec.com/company/research/datasets/mvtec-ad) bottles dataset.
- Distance measure: Piqa [SSIM](https://piqa.readthedocs.io/en/stable/api/piqa.ssim.html)

Steps:
1. augment the data
2. build the autoencoder
3. train using SSIM loss
4. test

## SSIM loss
We use piqa to provide the SSIM loss.

In [None]:
!pip install piqa
import piqa

In [None]:

import numpy as np 
import pandas as pd 
import os
import torch
import torch.nn as nn
from torch.nn import functional as F
import torchvision
from torch.utils.data import Dataset, DataLoader
import cv2
import random
import matplotlib.pyplot as plt
from PIL import Image
import shutil
from torch.utils.tensorboard import SummaryWriter
from datetime import datetime
import math
import piqa
from sklearn.metrics import roc_curve
from sklearn.metrics import roc_auc_score


In [None]:
# Directories
# Root folder containing 'mvtec-bottles/' and receiving all generated output.
# Override with the MVTEC_BASEDIR environment variable (defaults to the cwd).
BASEDIR = os.environ.get('MVTEC_BASEDIR', '.')
INPUT_TRAIN_DIR = os.path.join(BASEDIR, 'mvtec-bottles/bottle/train/good/')
TEST_DIR = os.path.join(BASEDIR, 'mvtec-bottles/bottle/test')
AUGMENTED_DIR = os.path.join(BASEDIR, 'augmented/')
os.makedirs(AUGMENTED_DIR, exist_ok=True)
OUT_DIR = os.path.join(BASEDIR, 'outputs')
os.makedirs(OUT_DIR, exist_ok=True)
TB_DIR = os.path.join(BASEDIR, 'tb_runs')
RESULT_DIR = os.path.join(BASEDIR, 'results')
os.makedirs(RESULT_DIR, exist_ok=True)

# Device selection: CUDA first, then Apple MPS, then CPU.
CPU = torch.device('cpu')
_mps = getattr(torch.backends, 'mps', None)  # the mps backend is absent on older torch builds
if torch.cuda.is_available():
    DEVICE = torch.device('cuda')
elif _mps is not None and _mps.is_available():
    DEVICE = torch.device('mps')
else:
    DEVICE = torch.device('cpu')
print(DEVICE)


RANDOM_SEED = 13
random.seed(RANDOM_SEED)
torch.manual_seed(RANDOM_SEED)

# augmenting and resizing
NUM_AUGMENTATIONS = 20  # augmented variants generated per training image
NEW_IMAGE_SIZE = 256    # side length (pixels) of the resized square images

# training parameters
BATCH_SIZE = 4
EPOCHS = 7
LR = 0.001
DROPOUT = 0.2          # dropout probability applied to the autoencoder bottleneck
LOG_INTERVAL = 100     # log the training loss every LOG_INTERVAL batches
VAL_LOG_INTERVAL = 20  # log the validation loss every VAL_LOG_INTERVAL batches



In [4]:
# Utilities

# display a bunch of images in a grid

def tile_images(imgs, rows, cols, is_paths=False):
    """
    Tile a list of images into a ``rows`` x ``cols`` grid and show it.

    Args:
        imgs: A non-empty list of image file paths (when ``is_paths`` is True)
            or of channel-first CPU torch image tensors of shape (C, H, W).
        rows: Number of rows in the grid.
        cols: Number of columns in the grid.
        is_paths: Whether ``imgs`` holds file paths instead of tensors.

    Returns:
        The matplotlib figure containing the tiled image grid. The figure size
        is derived from the first image; images beyond ``rows * cols`` are not
        drawn and unused cells are left empty.

    Raises:
        ValueError: if ``imgs`` is empty.
    """
    if not imgs:
        raise ValueError('tile_images: no images to display')

    if is_paths:
        loaded = []
        for path in imgs:
            # copy() loads the pixels, so the file can be closed immediately.
            with Image.open(path) as im:
                loaded.append(im.copy())
        imgs = loaded
        width, height = imgs[0].size  # PIL: .size is a (width, height) attribute
    else:
        # matplotlib expects channels-last images.
        imgs = [x.permute(1, 2, 0) for x in imgs]
        height, width, _ = imgs[0].shape  # (H, W, C) after the permute

    # squeeze=False keeps axs two-dimensional even when rows or cols is 1.
    fig, axs = plt.subplots(rows, cols, figsize=(cols * width / 100, rows * height / 100),
                            facecolor='white', squeeze=False)

    # axs.flat walks the grid row by row.
    for i, ax in enumerate(axs.flat):
        if i < len(imgs):
            ax.imshow(imgs[i])
            ax.axis('off')  # Hide axes for cleaner visualization

    # Adjust layout to prevent overlapping labels
    fig.tight_layout()
    plt.show()
    return fig

# count trainable parameters in a model
def count_parameters(model, print_all_parms=False):
    """Count (and optionally list) the trainable parameters of ``model``.

    Args:
        model: a torch ``nn.Module``.
        print_all_parms: when True, print a header and then each trainable
            parameter's name and element count before the total.

    Returns:
        Total number of elements over all parameters with ``requires_grad``.
    """
    trainable = [(pname, p.numel())
                 for pname, p in model.named_parameters()
                 if p.requires_grad]
    if print_all_parms:
        print('\nTrainable Parameters')
        for pname, count in trainable:
            print(pname, '\t', count)
    total = sum(count for _, count in trainable)
    print(f"Total Trainable Params: {total}")
    return total


In [None]:
def view_train(n=6, rows=3, cols=2):
    """Show ``n`` distinct, randomly chosen training images in a ``rows`` x ``cols`` grid.

    Raises:
        ValueError: if ``rows * cols != n`` or INPUT_TRAIN_DIR holds fewer than
            ``n`` ``.png``/``.jpg`` images.
    """
    if rows * cols != n:
        raise ValueError(f"rows * cols ({rows * cols}) must equal n ({n})")
    image_files = [f for f in os.listdir(INPUT_TRAIN_DIR) if f.endswith(('.png', '.jpg'))]
    # Sample without replacement so the grid never shows the same image twice.
    to_view = [os.path.join(INPUT_TRAIN_DIR, f) for f in random.sample(image_files, n)]
    tile_images(to_view, rows, cols, True)
    
#view_train()

## Data augmentation

We have only about 200 training images, which is too few, so we augment the data. Bergmann et al. apply rotation, mirroring and translation. The bottle images appear centered and all of the same size, so we start with rotation and mirroring.

Bergmann et al. also scale the images to 256x256, so we do that too.


In [None]:
# Define the augmentation functions (modify as needed)
def rotate(image, angle):
    """Rotate an image about its centre by ``angle`` degrees.

    Positive angles rotate counter-clockwise (OpenCV convention). The output
    keeps the input size; areas uncovered by the rotation are filled white.

    Args:
        image: image array as returned by ``cv2.imread``, either (H, W, C) or (H, W).
        angle: rotation angle in degrees.

    Returns:
        The rotated image, with the same size and dtype as ``image``.
    """
    rows, cols = image.shape[:2]  # works for colour and single-channel images
    center = (cols // 2, rows // 2)
    rot_mat = cv2.getRotationMatrix2D(center, angle, 1.0)
    return cv2.warpAffine(image, rot_mat, (cols, rows),
                          borderMode=cv2.BORDER_CONSTANT,
                          borderValue=(255, 255, 255))

def flip(image, mode):
    """Mirror ``image`` using ``cv2.flip``.

    ``mode`` is OpenCV's flipCode: 0 flips around the x-axis (upside down), a
    positive value flips around the y-axis (left-right), and a negative value
    flips around both axes.
    """
    flipped = cv2.flip(image, mode)
    return flipped

# Augmentation parameters (adjust as needed):
# - rotation_range: rotation angles are drawn uniformly from
#   [-rotation_range, rotation_range] degrees.
# - flip_probability: chance that an augmented image is also flipped.
rotation_range, flip_probability = 180, 0.5

def augment_images(n_augment=NUM_AUGMENTATIONS, n_files=-1):
    """Write randomly rotated/flipped, resized variants of the training images.

    Each ``.jpg``/``.png`` file in ``INPUT_TRAIN_DIR`` is copied unchanged into
    ``AUGMENTED_DIR``. ``n_augment`` variants are then written next to it as
    ``aug_<i>_<filename>``. Each variant is:

    - rotated by a uniform random angle in ``[-rotation_range, rotation_range]``;
    - flipped with probability ``flip_probability`` (flip code 0 or 1, equally likely);
    - resized to ``NEW_IMAGE_SIZE`` x ``NEW_IMAGE_SIZE``.

    Args:
        n_augment: number of augmented variants to create per source image.
        n_files: maximum number of source images to process; ``<= 0`` means all.

    Returns:
        The number of augmented images written (the copied originals are not counted).
    """
    output_dir = AUGMENTED_DIR
    n_written = 0    # augmented images written so far (the return value)
    n_processed = 0  # source images handled so far
    # Sorted so that, together with the seeded ``random`` module, runs are
    # reproducible (os.listdir order is arbitrary).
    for filename in sorted(os.listdir(INPUT_TRAIN_DIR)):
        if not filename.endswith((".jpg", ".png")):
            continue
        # n_files limits the number of *source* images, not augmented outputs.
        if n_files > 0 and n_processed >= n_files:
            break
        if n_processed % 20 == 0:
            print('augmented ', n_written)
        src_path = os.path.join(INPUT_TRAIN_DIR, filename)
        # Read each source image once, not once per augmentation.
        image = cv2.imread(src_path)
        if image is None:
            print('WARNING: could not read ', src_path)
            continue
        shutil.copy(src_path, os.path.join(output_dir, filename))
        for i in range(n_augment):
            angle = random.uniform(-rotation_range, rotation_range)
            augmented_image = rotate(image, angle)  # rotate() does not modify its input
            if random.random() < flip_probability:
                augmented_image = flip(augmented_image, random.randint(0, 1))
            augmented_image = cv2.resize(augmented_image, (NEW_IMAGE_SIZE, NEW_IMAGE_SIZE),
                                         interpolation=cv2.INTER_AREA)
            # Save the augmented image with a modified filename
            cv2.imwrite(os.path.join(output_dir, f"aug_{i}_{filename}"), augmented_image)
            n_written += 1
        n_processed += 1
    return n_written


In [None]:
def test_augment():
    """Augment a single source image five times, then preview AUGMENTED_DIR as a 3x2 grid."""
    augment_images(5, 1)
    preview = [os.path.join(AUGMENTED_DIR, name) for name in os.listdir(AUGMENTED_DIR)]
    tile_images(preview, 3, 2, True)

test_augment()

In [None]:
# Build the full augmented training set (NUM_AUGMENTATIONS variants per image).
augment_images(n_augment=NUM_AUGMENTATIONS, n_files=-1)

## Dataset

The augmented data is split randomly into train/val when the dataloaders are built (see `get_dataloaders`). Labels are available for the test set.

In [5]:
class BottleDataset(Dataset):
    """
    Dataset over a list of image file paths, optionally paired with labels.

    Each item is ``(image, label)``:

    - ``image`` is a float tensor of shape (3, 256, 256) scaled to [0, 1];
    - images are always decoded as RGB, so grayscale or RGBA files still yield
      3 channels;
    - when no labels are given, every label is -1, so the default DataLoader
      collate function can batch the items (it cannot batch ``None``).
    """
    def __init__(self, fnames, labels=None):
        super().__init__()
        if labels is not None and len(labels) != len(fnames):
            raise ValueError(f"got {len(fnames)} files but {len(labels)} labels")
        self.fnames = fnames
        print('dataset len fnames ', len(fnames))
        self.labels = labels

    def __len__(self):
        return len(self.fnames)

    def __getitem__(self, idx):
        # Force RGB decoding: the autoencoder expects exactly 3 input channels.
        image = torchvision.io.read_image(self.fnames[idx],
                                          mode=torchvision.io.ImageReadMode.RGB)
        image = torchvision.transforms.functional.resize(image, (256, 256))
        image = image / 255
        # -1 is the "no label" placeholder (same convention as get_dataloaders).
        label = -1 if self.labels is None else self.labels[idx]
        return image, label

def get_dataloaders(file_dir,   # all files are in this dir
                    props=(1,),  # proportions to split the data in
                    batch_size=BATCH_SIZE,  # dataloader batch size
                    num_samples=-1):
    """Randomly split the ``.png`` files in ``file_dir`` into shuffled DataLoaders.

    Args:
        file_dir: directory containing the images; non-``.png`` files are ignored.
        props: split proportions, one per returned DataLoader. Each split except
            the last gets ``floor(prop * n)`` files (capped at what is left); the
            last split gets all remaining files, so ``props`` need not sum to 1.
        batch_size: batch size for every DataLoader.
        num_samples: if > 0, keep only this many files (after shuffling).

    Returns:
        A list of ``len(props)`` DataLoaders over unlabeled ``BottleDataset``s
        (every label is -1).
    """
    # Sort before shuffling so the split depends only on the seeded RNG, not on
    # the arbitrary order returned by os.listdir.
    file_list = sorted(x for x in os.listdir(file_dir) if x.endswith('.png'))
    print(f"{file_dir} num pngs: {len(file_list)}")
    random.shuffle(file_list)
    if num_samples > 0:
        file_list = file_list[:num_samples]
    n = len(file_list)
    print(f"len file_list: {n}")
    dls = []
    lb = 0
    for i, prop in enumerate(props):
        is_last = i == len(props) - 1
        ub = n if is_last else min(lb + math.floor(prop * n), n)
        paths = [os.path.join(file_dir, x) for x in file_list[lb:ub]]
        # Training images are unlabeled; -1 is the placeholder label.
        ds = BottleDataset(paths, [-1] * len(paths))
        dls.append(DataLoader(ds, batch_size=batch_size, shuffle=True))
        lb = ub
    return dls

def get_test_dataloader(test_dir, batch_size=BATCH_SIZE, num_samples=10):
    """Build an unshuffled DataLoader over the labelled MVTec bottle test set.

    Args:
        test_dir: directory with subdirectories ``broken_large``, ``broken_small``
            and ``contamination`` (label 1, anomalous) and ``good`` (label 0).
        batch_size: DataLoader batch size.
        num_samples: if > 0, keep only this many images after one shuffle.

    Returns:
        A DataLoader yielding ``(images, labels)`` batches.
    """
    dir_label = [('broken_large', 1), ('broken_small', 1), ('contamination', 1),
                 ('good', 0)]
    file_label = []
    for subdir, label in dir_label:
        sub_path = os.path.join(test_dir, subdir)
        # Sorted so the shuffle below depends only on the seeded RNG.
        file_label.extend((os.path.join(sub_path, f), label)
                          for f in sorted(os.listdir(sub_path)) if f.endswith('.png'))
    random.shuffle(file_label)
    print('TEST SIZE: ', len(file_label))
    if num_samples > 0:
        file_label = file_label[:num_samples]
    files = [path for path, _ in file_label]
    labels = [label for _, label in file_label]
    t_ds = BottleDataset(files, labels)
    return DataLoader(t_ds, batch_size=batch_size)
        
        

In [None]:
# test: load a few training images through BottleDataset and show them in a 2x2 grid.
# BottleDataset expects full paths and takes (fnames, labels=None).

dataset = BottleDataset([os.path.join(INPUT_TRAIN_DIR, f)
                         for f in ['003.png', '054.png', '143.png', '043.png']])
print(len(dataset))
images = []
for i in range(4):
    ximg = dataset[i][0]
    print(ximg.size())
    images.append(ximg)
_ = tile_images(images, 2, 2, False)
dataset = None

## Convolutional autoencoder

Bergmann et al. do not provide details about the autoencoder, except that SSIM is used as the loss function.



In [6]:
class Autoencoder(nn.Module):
    """Convolutional autoencoder for 3-channel images.

    Encoder:
    - four 3x3 stride-2 convolutions (3 -> 16 -> 32 -> 64 -> 128 channels),
      each followed by ReLU;
    - this halves each spatial dimension four times (e.g. 256x256 -> 16x16).

    Bottleneck: dropout with probability ``DROPOUT``.

    Decoder:
    - mirrors the encoder with four stride-2 transposed convolutions back to
      3 channels;
    - ends in a sigmoid, so outputs lie in (0, 1).

    Input height and width should be divisible by 16 for the output to match
    the input size. Layer order inside ``encoder``/``decoder`` (and hence the
    ``state_dict`` keys) is fixed so saved models stay loadable.
    """
    def __init__(self):
        super().__init__()
        self.dropout = nn.Dropout(p=DROPOUT)
        channels = [3, 16, 32, 64, 128]

        down = []
        for c_in, c_out in zip(channels[:-1], channels[1:]):
            down.append(nn.Conv2d(c_in, c_out, kernel_size=3, stride=2, padding=1))
            down.append(nn.ReLU(inplace=True))
        self.encoder = nn.Sequential(*down)

        rev = channels[::-1]
        up = []
        n_stages = len(rev) - 1
        for stage, (c_in, c_out) in enumerate(zip(rev[:-1], rev[1:])):
            up.append(nn.ConvTranspose2d(c_in, c_out, kernel_size=3, stride=2,
                                         padding=1, output_padding=1))
            # Output between 0 and 1 for image reconstruction on the last stage.
            up.append(nn.Sigmoid() if stage == n_stages - 1 else nn.ReLU(inplace=True))
        self.decoder = nn.Sequential(*up)

    def forward(self, x):
        return self.decoder(self.dropout(self.encoder(x)))

In [None]:
# test autoencoder

def test_autoencoder():
    """Smoke-test an untrained Autoencoder on four training images.

    Steps:
    1. Load ``001.png``-``004.png`` from INPUT_TRAIN_DIR.
    2. Resize each to 256x256 and scale to [0, 1].
    3. Run them through a freshly initialised model in eval mode.
    4. Print the input/output sizes and the loss ``1 - SSIM``.
    """
    names = ['001.png', '002.png', '003.png', '004.png']
    # Resize every image individually (each slot gets its own image).
    imgs = [torchvision.transforms.functional.resize(
                torchvision.io.read_image(os.path.join(INPUT_TRAIN_DIR, name)), (256, 256)) / 255
            for name in names]
    p = torch.stack(imgs)
    print('input ', p.size())
    ae = Autoencoder()
    # Same loss as train(): piqa's SSIM similarity turned into a loss via 1 - SSIM.
    ssim = piqa.SSIM()
    ae.eval()
    with torch.no_grad():
        preds = ae(p)
        print('preds ', preds.size())
        loss = 1 - ssim(preds, p)
        print('loss: ', loss.item())

# test_autoencoder()

In [7]:
def _train_epoch(model, train_dl, optimizer, ssim_loss, tb_writer, epoch):
    """Run one training epoch with loss = 1 - SSIM and return the mean batch loss."""
    model.train()
    int_loss, total_loss = 0.0, 0.0
    for b, data in enumerate(train_dl):
        images = data[0].to(DEVICE)  # moved once; used as both input and target
        optimizer.zero_grad()
        reconstructed = model(images)
        loss = 1 - ssim_loss(reconstructed, images)
        loss.backward()
        optimizer.step()
        batch_loss = loss.item()
        int_loss += batch_loss
        total_loss += batch_loss
        # Log the mean over each full window of LOG_INTERVAL batches.
        if (b + 1) % LOG_INTERVAL == 0:
            step = epoch * len(train_dl) + b  # same global step for TensorBoard and stdout
            avg_loss = int_loss / LOG_INTERVAL
            tb_writer.add_scalar('train_Loss', avg_loss, step)
            print('train_Loss ', step, ': ', avg_loss)
            int_loss = 0.0
    return total_loss / len(train_dl)


def _validate_epoch(model, val_dl, ssim_loss, tb_writer, epoch):
    """Evaluate on the (non-empty) validation set, show the last batch, return the mean batch loss."""
    model.eval()
    print('Validating with batches: ', len(val_dl))
    int_loss, total_loss = 0.0, 0.0
    with torch.no_grad():
        for b, data in enumerate(val_dl):
            images = data[0].to(DEVICE)
            preds = model(images)
            batch_loss = (1 - ssim_loss(preds, images)).item()
            int_loss += batch_loss
            total_loss += batch_loss
            # Uses the validation accumulator and interval, reset after each log.
            if (b + 1) % VAL_LOG_INTERVAL == 0:
                step = epoch * len(val_dl) + b
                avg_val_loss = int_loss / VAL_LOG_INTERVAL
                tb_writer.add_scalar('val_loss', avg_val_loss, step)
                print(f"val_loss {step}: {avg_val_loss}")
                int_loss = 0.0
    # Show input/reconstruction pairs from the final validation batch.
    data_pred = []
    for original, recon in zip(data[0], preds):
        data_pred += [original.to(CPU), recon.to(CPU)]
    tile_images(data_pred, len(data[0]), 2, False)
    return total_loss / len(val_dl)


def train(train_file_dir = AUGMENTED_DIR, learning_rate=LR, epochs=EPOCHS, num_samples=-1):
    """Train an Autoencoder on ``train_file_dir`` with loss ``1 - SSIM``.

    The ``.png`` files are split 85/15 into train/validation loaders.

    Outputs:
    - TensorBoard logs under ``TB_DIR``, written every epoch;
    - a checkpoint (model + optimizer state, epoch, loss) in ``OUT_DIR``,
      written whenever the mean validation loss improves;
    - after training, the weights of the best epoch, saved to
      ``OUT_DIR/<timestamp>_model.pt``.

    Args:
        train_file_dir: directory of training images.
        learning_rate: Adam learning rate.
        epochs: number of training epochs.
        num_samples: if > 0, use only this many images in total.

    Returns:
        Path of the saved model state dict.

    Raises:
        ValueError: if the train or validation split is empty.
        RuntimeError: if no epoch produced a usable validation loss
            (e.g. ``epochs <= 0`` or the loss was NaN).
    """
    # TODO load a checkpoint
    stamp = datetime.now().strftime('%b%d_%H%M')
    trained_model_name = os.path.join(OUT_DIR, f"{stamp}_model.pt")
    checkpoint_name = os.path.join(OUT_DIR, f"{stamp}_checkpoint.pt")
    tb_dir = os.path.join(TB_DIR, stamp)
    os.makedirs(tb_dir, exist_ok=True)
    train_dl, val_dl = get_dataloaders(train_file_dir, [0.85, 0.15], BATCH_SIZE, num_samples)
    if len(train_dl) == 0 or len(val_dl) == 0:
        raise ValueError(f"empty split: {len(train_dl)} train / {len(val_dl)} val batches")

    model = Autoencoder().to(DEVICE)
    count_parameters(model, True)
    optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
    ssim_loss = piqa.SSIM().to(DEVICE)

    best_epoch, best_loss, best_state = -1, math.inf, None
    tb_writer = SummaryWriter(tb_dir)
    try:
        for epoch in range(epochs):
            avg_train_loss = _train_epoch(model, train_dl, optimizer, ssim_loss, tb_writer, epoch)
            avg_val_loss = _validate_epoch(model, val_dl, ssim_loss, tb_writer, epoch)
            tb_writer.add_scalars("Epoch train val loss", {'train_loss': avg_train_loss,
                                                          'val_loss': avg_val_loss}, epoch)
            print(f"Epoch {epoch} train/val loss: {avg_train_loss} / {avg_val_loss}")
            if avg_val_loss < best_loss:
                best_loss, best_epoch = avg_val_loss, epoch
                # state_dict() tensors share storage with the live parameters;
                # clone them, otherwise later epochs overwrite the "best" weights.
                best_state = {k: v.detach().clone() for k, v in model.state_dict().items()}
                checkpoint = {'model_state_dict': best_state,
                              'epoch': epoch,
                              'optimizer_state_dict': optimizer.state_dict(),
                              'loss': avg_val_loss}
                print(f"Saving checkpoint at epoch {epoch}")
                torch.save(checkpoint, checkpoint_name)
    finally:
        tb_writer.close()  # flush pending TensorBoard events

    if best_state is None:
        raise RuntimeError('no checkpoint was produced (epochs <= 0 or validation loss was NaN)')
    torch.save(best_state, trained_model_name)
    print(f'saved model from epoch {best_epoch}')
    return trained_model_name
        

In [None]:
# Train on the full augmented set; train() writes its checkpoint and model files to OUT_DIR.
train(num_samples=-1, epochs=EPOCHS)

In [None]:
def quick_run_AE(n=3):
    """Shape-check an untrained Autoencoder on ``n`` random training images.

    Images are chosen with replacement from the ``.png``/``.jpg`` files in
    INPUT_TRAIN_DIR. For each one:
    - the input, encoded and decoded tensor sizes are printed;
    - the input is shown next to its reconstruction in an ``n`` x 2 grid.

    Encoder and decoder are called directly, so the bottleneck dropout used in
    ``forward`` is not applied.
    """
    image_files = [f for f in os.listdir(INPUT_TRAIN_DIR) if f.endswith(('.png', '.jpg'))]
    test_paths = [os.path.join(INPUT_TRAIN_DIR, random.choice(image_files)) for _ in range(n)]
    ximages = []
    ae = Autoencoder()
    with torch.no_grad():
        for f in test_paths:
            image = torchvision.io.read_image(f)
            image = torchvision.transforms.functional.resize(image, (256, 256))
            image = image / 255
            print(image.size())
            ximages.append(image)
            enc = ae.encoder(image)
            print('enc: ', enc.size())
            dec = ae.decoder(enc)
            print('dec ', dec.size())
            # Keep the reconstruction in [0, 1] like the input: imshow expects float
            # pixels in [0, 1] and would clip values scaled by 255 to white.
            ximages.append(dec)
    print(type(ximages[0]))
    tile_images(ximages, n, 2, False)

#quick_run_AE()

In [10]:
# test
# the test dir has 3 bad subdirs: broken_large, broken_small and contamination
# and one good subdir: good
#
# generate the ROC curve and compute the AUC
TEST_NUM_TO_DISPLAY = 2  # input/reconstruction pairs shown per test batch

def run_test(model, dataloader):
    """Score every test image by reconstruction MSE and report the ROC curve and AUC.

    For each batch:
    - the per-image mean squared error between input and reconstruction is the
      anomaly score (higher = more anomalous);
    - up to ``TEST_NUM_TO_DISPLAY`` random input/reconstruction pairs are shown.

    After all batches:
    - the ROC curve and AUC are computed and printed;
    - labels/scores and the ROC points are written as CSV files to RESULT_DIR;
    - the ROC curve is plotted.

    Args:
        model: the autoencoder; expected to already be on DEVICE and in eval mode.
        dataloader: yields ``(images, labels)`` with label 1 = anomalous, 0 = good.

    Returns:
        The ROC AUC score.
    """
    scores = []
    labels = []
    start_time = datetime.now().strftime('%b%d_%H%M')
    sl_fname = os.path.join(RESULT_DIR, f"{start_time}_pred_true.csv")
    roc_fname = os.path.join(RESULT_DIR, f"{start_time}_roc.csv")
    # Anomaly score = per-image mean squared reconstruction error. A per-image
    # 1 - SSIM (piqa.SSIM(reduction='none')) would match the training loss instead.
    mseloss = torch.nn.MSELoss(reduction='none')
    for i, batch in enumerate(dataloader):
        print('Batch ', i)
        images = batch[0].to(DEVICE)
        with torch.no_grad():
            reconstructed = model(images)
            diff = mseloss(reconstructed, images).mean(dim=(1, 2, 3))
        scores.extend(list(diff.cpu().numpy()))
        labels.extend(list(batch[1].cpu().numpy()))
        idxs = list(range(len(batch[0])))
        random.shuffle(idxs)
        display_images = []
        for idx in idxs[:TEST_NUM_TO_DISPLAY]:
            display_images.extend([batch[0][idx], reconstructed[idx].detach().to(CPU)])
        tile_images(display_images, len(display_images) // 2, 2, False)
    auc = roc_auc_score(labels, scores)
    fpr, tpr, thresholds = roc_curve(labels, scores)
    print('FPR ', fpr)
    print('TPR ', tpr)
    print('Thresholds ', thresholds)
    print('AUC ', auc)
    sl_df = pd.DataFrame({'labels': labels, 'scores': scores})
    roc_df = pd.DataFrame({'fpr': fpr, 'tpr': tpr, 'thresholds': thresholds})
    sl_df.to_csv(sl_fname, header=True, index=False)
    roc_df.to_csv(roc_fname, header=True, index=False)
    plt.figure()
    plt.plot(fpr, tpr, 'b', label='AUC = %0.2f' % auc)
    plt.plot([0, 1], [0, 1], 'r--', label='chance')
    plt.xlabel('False positive rate')
    plt.ylabel('True positive rate')
    plt.title('ROC curve')
    plt.legend(loc='lower right')  # without this the AUC label is never displayed
    plt.show()
    return auc
    

def dp_test(model_fname, test_dir):
    """Load a saved Autoencoder state dict and evaluate it on the full test set.

    Args:
        model_fname: path of a state dict saved with ``torch.save``.
        test_dir: MVTec bottle test directory (see get_test_dataloader).
    """
    net = Autoencoder()
    # NOTE(review): torch.load unpickles the file; only load trusted model files.
    net.load_state_dict(torch.load(model_fname, map_location=DEVICE))
    net = net.to(DEVICE).eval()
    loader = get_test_dataloader(test_dir, batch_size=BATCH_SIZE, num_samples=-1)
    run_test(net, loader)


In [None]:
# Evaluate the most recently saved model. train() names its output
# "<timestamp>_model.pt" inside OUT_DIR.
model_files = [os.path.join(OUT_DIR, f) for f in os.listdir(OUT_DIR) if f.endswith('_model.pt')]
if not model_files:
    raise FileNotFoundError(f"no *_model.pt file in {OUT_DIR}; run train() first")
saved_model = max(model_files, key=os.path.getmtime)
print('testing ', saved_model)
dp_test(saved_model, TEST_DIR)