In [1]:
# %pip install wandb

In [2]:
import os
import sys
sys.path.append('../scripts')

import pandas as pd
import numpy as np

from PIL import Image

import torch
import torch.autograd as autograd
from torchvision.utils import save_image
from torch.utils.data import Dataset, DataLoader
import torch.nn as nn
import torch.nn.functional as F

from torchvision import transforms, utils
from torchvision.utils import make_grid

import matplotlib.pyplot as plt
from tqdm.auto import tqdm

import wandb

# Датасет, гиперпараметры, модели

In [26]:
from dataset import ArchNegatives, ArchPositives

from model_conv import Generator, Discriminator
from model_encoder import Encoder

In [27]:
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
device

device(type='cuda', index=0)

In [28]:
params = {
    'batch_size': 64,
    'channels': 3,
    'img_size': 64,
    'latent_dim': 100,
    'lr': 0.0002,
    'b1': 0.5,
    'b2': 0.999,
    'n_epochs': 200,
    'n_critic': 5,
    'sample_interval': 400,
}

In [29]:
transform_crop_64 = transforms.Compose([
    transforms.CenterCrop((64, 64)),
    transforms.ToTensor(),
    transforms.Normalize([0.5], [0.5])
])

In [30]:
train_dataset = ArchNegatives(transform=transform_crop_64)
train_loader = DataLoader(train_dataset, batch_size=params['batch_size'], shuffle=True)

anomalies_dataset = ArchPositives(transform=transform_crop_64)
anomalies_loader = DataLoader(anomalies_dataset, batch_size=params['batch_size'], shuffle=False)

In [31]:
print(f'train: {len(train_dataset)} images, {len(train_loader)} batches')
print(f'anomalies: {len(anomalies_dataset)} images, {len(anomalies_loader)} batches')

train: 139945 images, 2187 batches
anomalies: 163 images, 3 batches


In [32]:
generator = Generator(params)
discriminator = Discriminator(params)
encoder = Encoder(params)

In [33]:
# map_location=torch.device('cpu')

generator.load_state_dict(torch.load("../train_generator_discriminator/results_conv/generator"))
discriminator.load_state_dict(torch.load("../train_generator_discriminator/results_conv/discriminator"))
encoder.load_state_dict(torch.load("../train_izif_encoder/results_encoder/encoder"))

<All keys matched successfully>

In [34]:
generator.to(device).eval();
discriminator.to(device).eval();
encoder.to(device).eval();

# Модель сравнения изображений

In [138]:
def compute_anomaly_scores(imgs, kappa=1.0):
    
    # imgs: batch_size x num_channels x img_size x img_size
    
    mse = nn.MSELoss(reduction='none')
    
    real_imgs = imgs.to(device)
    real_zs = encoder(real_imgs)
    fake_imgs = generator(real_zs)
    fake_zs = encoder(fake_imgs)

    real_features = discriminator.forward_features(real_imgs)
    fake_features = discriminator.forward_features(fake_imgs)
    
    img_distances = torch.mean(mse(fake_imgs, real_imgs), dim=(1, 2, 3))
    loss_features = torch.mean(mse(fake_features, real_features), dim=(1))
    anomaly_scores = img_distances + kappa * loss_features
    z_distances = torch.mean(mse(fake_zs, real_zs), dim=(1))
    
    return img_distances, loss_features, anomaly_scores, z_distances
    

# Замеряем скорость инференса

In [141]:
%%time
for i, (imgs, labels) in enumerate(anomalies_loader):
    img_distances, loss_features, anomaly_scores, z_distances = compute_anomaly_scores(imgs)

CPU times: user 437 ms, sys: 13.9 ms, total: 451 ms
Wall time: 449 ms


In [None]:
%%time
for i, (imgs, labels) in enumerate(train_loader):
    img_distances, loss_features, anomaly_scores, z_distances = compute_anomaly_scores(imgs)

CPU times: user 6min 45s, sys: 20.1 s, total: 7min 5s
Wall time: 18min 15s
