In [None]:
import os
import sys

import pandas as pd
import numpy as np

from PIL import Image

import torch
import torch.nn as nn
import torch.nn.functional as F

from torchvision.utils import save_image
from torch.utils.data import Dataset, DataLoader

from torchvision import transforms, utils
from torchvision.utils import make_grid

import matplotlib.pyplot as plt
from tqdm.auto import tqdm

## Данные и модель

In [None]:
from dataset import ArchDataset, GT_NEGATIVE_PATHS, GT_POSITIVE_PATHS, GT_NEGATIVE_COORDS, GT_POSITIVE_COORDS

In [None]:
from model import VQVAE

# Параметры модели и обучения

In [None]:
# Prefer the first GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device('cuda:0') if torch.cuda.is_available() else torch.device('cpu')
device

In [None]:
# Experiment output directory and the checkpoint file evaluated in this notebook.
results_dir = 'v5'
checkpoint_path = '/'.join((results_dir, 'checkpoint_14.pth'))

In [None]:
# Hyper-parameters shared by the data loaders and the VQ-VAE constructor below.
params = dict(
    batch_size=512,
    img_size=64,          # should agree with the 64 x 64 CenterCrop in the transform
    channels=3,
    embedding_dim=64,
    num_embeddings=128,
    beta=0.25,
    n_epochs=15,          # earlier runs presumably used 10; not read by the evaluation cells below
)

# Датасет

In [None]:
# Preprocessing for 64 x 64 crops:
#   1. center-crop to 64 x 64,
#   2. convert to a float tensor (values in [0, 1] for 8-bit images),
#   3. normalize every channel with mean 0.5 / std 0.5, i.e. map [0, 1] -> [-1, 1].
transform_crop_64 = transforms.Compose(
    [
        transforms.CenterCrop(size=(64, 64)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.5], std=[0.5]),
    ]
)

In [None]:
# Training images: the GT_NEGATIVE_* samples.
train_dataset = ArchDataset(
    img_paths=GT_NEGATIVE_PATHS,
    coords=GT_NEGATIVE_COORDS,
    anomalies=False,
    transform=transform_crop_64,
)

# Anomaly images: the GT_POSITIVE_* samples.
# NOTE(review): this set is also built with anomalies=False although it wraps the
# positive (anomalous) paths; confirm what that flag controls inside ArchDataset.
anomalies_dataset = ArchDataset(
    img_paths=GT_POSITIVE_PATHS,
    coords=GT_POSITIVE_COORDS,
    anomalies=False,
    transform=transform_crop_64,
)

# The training loader shuffles; the anomaly loader keeps dataset order.
train_loader = DataLoader(dataset=train_dataset, batch_size=params['batch_size'], shuffle=True)
anomalies_loader = DataLoader(dataset=anomalies_dataset, batch_size=params['batch_size'], shuffle=False)

In [None]:
# Report the number of images and batches for each split.
for split_name, split_ds, split_dl in (
    ('train', train_dataset, train_loader),
    ('anomalies', anomalies_dataset, anomalies_loader),
):
    print(f'{split_name}: {len(split_ds)} images, {len(split_dl)} batches')

In [None]:
# Visual sanity check of the inputs: top row = anomaly samples, bottom row = training samples.
# Tensors are mapped from the normalized range [-1, 1] back to [0, 1] for imshow.
ncols = 6

fig, axs = plt.subplots(nrows=2, ncols=ncols, figsize=(16, 8))
fig.suptitle('Top: anomalies dataset, bottom: train dataset')
for i in range(ncols):
    axs[0, i].axis('off')
    axs[1, i].axis('off')
    axs[0, i].imshow((anomalies_dataset[i][0].permute(1, 2, 0) + 1) / 2)
    axs[1, i].imshow((train_dataset[i][0].permute(1, 2, 0) + 1) / 2)
fig.tight_layout()
# plt.show() rather than fig.show(): Figure.show targets GUI backends and can warn
# under the non-GUI inline backend used in notebooks.
plt.show()

# Создание модели

In [None]:
# Build the VQ-VAE from the hyper-parameters in `params` and load the trained weights.
model = VQVAE(
    in_channels=params['channels'],
    img_size=params['img_size'],
    embedding_dim=params['embedding_dim'],
    num_embeddings=params['num_embeddings'],
    beta=params['beta'],
)
model.to(device)
# NOTE(review): torch.load unpickles arbitrary objects; only load trusted checkpoints
# (on torch >= 1.13, weights_only=True restricts loading to tensors).
model.load_state_dict(torch.load(checkpoint_path, map_location=device))
# The model is only used for inference below: switch to eval mode so layers such as
# dropout / batch-norm (if VQVAE has any) behave deterministically.
model.eval();

# Метрики

In [None]:
# Perceptual metrics: LPIPS with an AlexNet backbone, and a ViT-B/16 used as an
# image-embedding extractor for cosine-similarity scoring.
import lpips
from sklearn.metrics.pairwise import cosine_similarity
from torchvision.models import vit_b_16, ViT_B_16_Weights

loss_fn_alex = lpips.LPIPS(net='alex')

weights = ViT_B_16_Weights.IMAGENET1K_V1
vit_model = vit_b_16(weights=weights).to(device)
# Replace the classification head with an identity so the forward pass returns embeddings.
vit_model.heads = torch.nn.Identity()
vit_model.eval();

In [None]:
def load_image(names, transform=None):
    """Load image files as RGB and apply a transform to each one.

    Args:
        names: iterable of image file paths.
        transform: callable applied to each RGB PIL image. Defaults to
            ``transform_crop_64`` (64 x 64 center crop, normalized tensor).

    Returns:
        List with one transformed image per path, in input order.
    """
    if transform is None:
        # NOTE(review): a grayscale variant (`transform_crop_64_grayscale`) used to be
        # referenced here but is not defined in this notebook; pass it explicitly if added.
        transform = transform_crop_64
    imgs = []
    for name in names:
        # `with` closes the underlying file; convert() loads the pixel data first.
        with Image.open(name) as pil_img:
            imgs.append(transform(pil_img.convert('RGB')))
    return imgs

def align_image(img):
    """Min-max rescale ``img`` to the [0, 1] range (used for display).

    Works for torch tensors and numpy arrays. A constant image (max == min) is
    mapped to all zeros instead of the NaNs a 0/0 division would produce.
    """
    shifted = img - img.min()
    span = shifted.max()
    if span == 0:
        # Multiplying by 0.0 yields a floating-point zero image, like the division would.
        return shifted * 0.0
    return shifted / span

def draw_image(imgs):
    """Plot images side by side in one row, each min-max rescaled for display.

    Args:
        imgs: sequence of image tensors in C x H x W layout. Single-channel images
            are drawn with a gray colormap. An empty sequence draws nothing.
    """
    if len(imgs) == 0:
        return  # plt.subplots rejects ncols=0

    # squeeze=False keeps `axes` 2-D even for a single image, so one loop covers every count.
    _, axes = plt.subplots(ncols=len(imgs), figsize=(len(imgs) * 2, 2), squeeze=False)
    for ax, img in zip(axes[0], imgs):
        cmap = 'gray' if img.shape[0] == 1 else None
        # detach().cpu() so GPU tensors or tensors that require grad can be handed to matplotlib.
        ax.imshow(align_image(img).detach().cpu().permute(1, 2, 0), cmap=cmap)
    plt.show()

In [None]:
def vit_cosine_similarity(img1, img2):
    """Cosine similarity between the ViT-B/16 embeddings of two images.

    Each image is resized to 224 px and both are embedded by ``vit_model`` (whose
    classification head is an ``Identity``) in a single forward pass.

    Args:
        img1, img2: tensors of shape C x H x W with the same number of channels.
            Single-channel images are repeated to 3 channels. Presumably square,
            since vit_b_16 expects 224 x 224 input. TODO confirm for other inputs.

    Returns:
        Python float: cosine similarity of the two embedding vectors.

    Raises:
        AssertionError: if an input is not 3-D or the channel counts differ.
    """
    assert len(img1.shape) == len(img2.shape) == 3, \
        f'expected images of shape C x H x W, got {img1.shape=} {img2.shape=}'

    assert img1.shape[0] == img2.shape[0], \
        f'different number of channels: {img1.shape[0]=} {img2.shape[0]=}'

    if img1.shape[0] == 1:  # grayscale: replicate to the 3 channels the ViT takes
        img1 = img1.repeat(3, 1, 1)
        img2 = img2.repeat(3, 1, 1)

    # NOTE(review): images are fed in the caller's value range (about [-1, 1] with the
    # notebook's transform), without the ImageNet mean/std normalization the
    # IMAGENET1K_V1 weights are usually paired with. Confirm this is intended.
    resize = transforms.Resize(size=224, antialias=True)
    vit_device = next(vit_model.parameters()).device
    # Resize separately (the inputs may differ in size), then embed both as one batch.
    batch = torch.cat([
        resize(img1.unsqueeze(0)).to(vit_device),
        resize(img2.unsqueeze(0)).to(vit_device),
    ])

    # Scoring needs no gradients. Without no_grad, the reconstruction's autograd graph
    # would be extended through the whole ViT on every call.
    with torch.no_grad():
        features = vit_model(batch).cpu().numpy()

    sim = cosine_similarity(features[:1], features[1:])
    return sim[0, 0].item()


In [None]:
def test(real_img, return_img=False):
    """Score one image by how well the VQ-VAE reconstructs it.

    The image is reconstructed by ``model`` and compared with the input by
    ``vit_cosine_similarity``. Lower similarity means a worse reconstruction.

    Args:
        real_img: a single image tensor (C x H x W), e.g. ``dataset[i][0]``.
        return_img: if True, also return the reconstruction.

    Returns:
        ``sim`` (float), or ``(sim, reconstructed_img)`` when ``return_img`` is True.
        The reconstruction is a detached C x H x W tensor on the CPU, so callers
        can call ``.numpy()`` on it or plot it directly.
    """
    real_img_dev = real_img.to(device)
    with torch.no_grad():
        # Add a batch dimension. Presumably model(...) returns (reconstruction_batch, ...),
        # hence the [0][0] indexing. TODO confirm against VQVAE.forward.
        reconstructed_img_dev = model(real_img_dev[None, :])[0][0]

    sim = vit_cosine_similarity(real_img_dev, reconstructed_img_dev)

    if return_img:
        return sim, reconstructed_img_dev.detach().cpu()
    return sim
    

In [None]:
# ViT similarity score for every anomaly image. Each sample unpacks to exactly four
# fields; only the image is used here.
anomaly_dataset_scores = [
    test(img)
    for img, _, _, _ in tqdm(anomalies_dataset, total=len(anomalies_dataset))
]

In [None]:
# ViT similarity score for every training image. Each sample unpacks to exactly four
# fields; only the image is used here.
train_dataset_scores = [
    test(img)
    for img, _, _, _ in tqdm(train_dataset, total=len(train_dataset))
]

In [None]:
# Compare the two similarity distributions on shared bin edges. With separate
# bins=20 calls, each histogram got its own bin widths and the bars were not comparable.
# density=True because the two sets may differ greatly in size.
hist_bins = np.histogram_bin_edges(
    np.concatenate([np.asarray(anomaly_dataset_scores), np.asarray(train_dataset_scores)]),
    bins=20,
)
fig, ax = plt.subplots()
ax.hist(anomaly_dataset_scores, bins=hist_bins, color='lightcoral', alpha=0.7,
        density=True, label='anomalies')
ax.hist(train_dataset_scores, bins=hist_bins, color='royalblue', alpha=0.7,
        density=True, label='train')
ax.set(xlabel='ViT cosine similarity (input vs. reconstruction)', ylabel='density',
       title='Reconstruction similarity scores')
ax.legend()
plt.show()

In [None]:
# Convert the score lists to numpy arrays so the cells below can argsort / negate them.
anomaly_dataset_scores, train_dataset_scores = (
    np.array(anomaly_dataset_scores),
    np.array(train_dataset_scores),
)

In [None]:
# Anomaly images with the LOWEST input/reconstruction similarity (most suspicious first).
# Top row: input image, bottom row: VQ-VAE reconstruction.
anomalies_sorted_idxs = np.argsort(anomaly_dataset_scores)
nrows, ncols = 2, min(6, len(anomalies_sorted_idxs))

fig, axs = plt.subplots(nrows=nrows, ncols=ncols, figsize=(3 * ncols, 3 * nrows), squeeze=False)
fig.suptitle('Anomalies dataset: most suspicious')

for j in range(ncols):
    for ax in axs[:, j]:
        ax.axis('off')  # also hides grid lines; grid('off') would actually turn the grid on

    idx = anomalies_sorted_idxs[j]
    img = anomalies_dataset[idx][0]
    sim, reconstructed_img = test(img, return_img=True)

    # Map [-1, 1] back to [0, 1]. .cpu() in case a tensor is on the GPU; clamp because
    # reconstructions may fall slightly outside the range.
    axs[0, j].imshow(((img.detach().cpu().permute(1, 2, 0) + 1) / 2).clamp(0, 1).numpy())
    axs[1, j].imshow(((reconstructed_img.detach().cpu().permute(1, 2, 0) + 1) / 2).clamp(0, 1).numpy())
    axs[0, j].set_title(f'sim: {sim:.4f}')

fig.tight_layout()
plt.show()

In [None]:
# Anomaly images with the HIGHEST input/reconstruction similarity (least suspicious first).
# Ascending sort, then read from the end. Sorting the negated scores AND reading from
# the end would reverse twice and show the most suspicious images instead.
# Top row: input image, bottom row: VQ-VAE reconstruction.
anomalies_sorted_idxs = np.argsort(anomaly_dataset_scores)
nrows, ncols = 2, min(6, len(anomalies_sorted_idxs))

fig, axs = plt.subplots(nrows=nrows, ncols=ncols, figsize=(3 * ncols, 3 * nrows), squeeze=False)
fig.suptitle('Anomalies dataset: least suspicious')

for j in range(ncols):
    for ax in axs[:, j]:
        ax.axis('off')  # also hides grid lines; grid('off') would actually turn the grid on

    idx = anomalies_sorted_idxs[-(j + 1)]
    img = anomalies_dataset[idx][0]
    sim, reconstructed_img = test(img, return_img=True)

    # Map [-1, 1] back to [0, 1]. .cpu() in case a tensor is on the GPU; clamp because
    # reconstructions may fall slightly outside the range.
    axs[0, j].imshow(((img.detach().cpu().permute(1, 2, 0) + 1) / 2).clamp(0, 1).numpy())
    axs[1, j].imshow(((reconstructed_img.detach().cpu().permute(1, 2, 0) + 1) / 2).clamp(0, 1).numpy())
    axs[0, j].set_title(f'sim: {sim:.4f}')

fig.tight_layout()
plt.show()

In [None]:
# Training images with the LOWEST input/reconstruction similarity (most suspicious first).
# Top row: input image, bottom row: VQ-VAE reconstruction.
train_sorted_idxs = np.argsort(train_dataset_scores)
nrows, ncols = 2, min(6, len(train_sorted_idxs))

fig, axs = plt.subplots(nrows=nrows, ncols=ncols, figsize=(3 * ncols, 3 * nrows), squeeze=False)
fig.suptitle('Train dataset: most suspicious')

for j in range(ncols):
    for ax in axs[:, j]:
        ax.axis('off')  # also hides grid lines; grid('off') would actually turn the grid on

    idx = train_sorted_idxs[j]
    img = train_dataset[idx][0]
    sim, reconstructed_img = test(img, return_img=True)

    # Map [-1, 1] back to [0, 1]. .cpu() in case a tensor is on the GPU; clamp because
    # reconstructions may fall slightly outside the range.
    axs[0, j].imshow(((img.detach().cpu().permute(1, 2, 0) + 1) / 2).clamp(0, 1).numpy())
    axs[1, j].imshow(((reconstructed_img.detach().cpu().permute(1, 2, 0) + 1) / 2).clamp(0, 1).numpy())
    axs[0, j].set_title(f'sim: {sim:.4f}')

fig.tight_layout()
plt.show()

In [None]:
# Training images with the HIGHEST input/reconstruction similarity (least suspicious first):
# ascending sort, read from the end.
# Top row: input image, bottom row: VQ-VAE reconstruction.
train_sorted_idxs = np.argsort(train_dataset_scores)
nrows, ncols = 2, min(6, len(train_sorted_idxs))

fig, axs = plt.subplots(nrows=nrows, ncols=ncols, figsize=(3 * ncols, 3 * nrows), squeeze=False)
fig.suptitle('Train dataset: least suspicious')

for j in range(ncols):
    for ax in axs[:, j]:
        ax.axis('off')  # also hides grid lines; grid('off') would actually turn the grid on

    idx = train_sorted_idxs[-(j + 1)]
    img = train_dataset[idx][0]
    sim, reconstructed_img = test(img, return_img=True)

    # Map [-1, 1] back to [0, 1]. .cpu() in case a tensor is on the GPU; clamp because
    # reconstructions may fall slightly outside the range.
    axs[0, j].imshow(((img.detach().cpu().permute(1, 2, 0) + 1) / 2).clamp(0, 1).numpy())
    axs[1, j].imshow(((reconstructed_img.detach().cpu().permute(1, 2, 0) + 1) / 2).clamp(0, 1).numpy())
    axs[0, j].set_title(f'sim: {sim:.4f}')

fig.tight_layout()
plt.show()

In [None]:
# Overview of the training images with the lowest similarity (up to 100), laid out
# 10 per row as row pairs: input image above, VQ-VAE reconstruction below.
# A single 2 x 100 strip would need a 300 x 6 inch figure, so a grid is used instead.
train_sorted_idxs = np.argsort(train_dataset_scores)
n_show = min(100, len(train_sorted_idxs))
ncols = 10
nrows = 2 * -(-n_show // ncols)  # two rows (input + reconstruction) per ceil(n_show / ncols)

fig, axs = plt.subplots(nrows=nrows, ncols=ncols, figsize=(2 * ncols, 2 * nrows), squeeze=False)
fig.suptitle(f'Train dataset: {n_show} most suspicious')
for ax in axs.flat:
    ax.axis('off')  # also hides grid lines and leaves unused trailing cells blank

for k in range(n_show):
    row, j = divmod(k, ncols)
    idx = train_sorted_idxs[k]
    img = train_dataset[idx][0]
    sim, reconstructed_img = test(img, return_img=True)

    # Map [-1, 1] back to [0, 1]. .cpu() in case a tensor is on the GPU; clamp because
    # reconstructions may fall slightly outside the range.
    axs[2 * row, j].imshow(((img.detach().cpu().permute(1, 2, 0) + 1) / 2).clamp(0, 1).numpy())
    axs[2 * row + 1, j].imshow(((reconstructed_img.detach().cpu().permute(1, 2, 0) + 1) / 2).clamp(0, 1).numpy())
    axs[2 * row, j].set_title(f'sim: {sim:.4f}', fontsize=8)

fig.tight_layout()
plt.show()