In [6]:
import numpy as np 
import gc, os, pickle, cv2
from scipy.ndimage import gaussian_filter
from sampling.kcenter_greedy import KCenterGreedy
from sklearn.random_projection import SparseRandomProjection

import pytorch_lightning as pl

import torch
from torchvision import transforms

In [7]:
from torch import nn

class BackBoneModel(nn.Module):
    def __init__(self, args):
        super().__init__()
        self.args = args
        self.features = []
        self.pretrained_model = None

    def hook_t(self, module, input, output):
        self.features.append(output)

    def forward(self, x):
        self.features = []
        _ = self.pretrained_model(x)
        return self.features

    def eval(self):
        self.pretrained_model.eval()


class WideResNet50(BackBoneModel):
    def __init__(self, args):
        super().__init__(args)
        self.pretrained_model = torch.hub.load('pytorch/vision:v0.8.2',
                                               'wide_resnet50_2', pretrained=True)

        for param in self.pretrained_model.parameters():
            param.requires_grad = False

        self.pretrained_model.layer2[-1].register_forward_hook(self.hook_t)
        self.pretrained_model.layer3[-1].register_forward_hook(self.hook_t)

In [8]:
import math
# import faiss

import torch
import numpy as np
from sklearn.metrics import (roc_curve, 
                             roc_auc_score, 
                             recall_score,
                             precision_score, 
                             confusion_matrix, 
                             f1_score, accuracy_score)
from torch.nn import functional as F


def min_max_norm(image, min_=None, max_=None):
    a_min = image.min() if min_ is None else min_
    a_max = image.max() if max_ is None else max_
    return (image - a_min) / (a_max - a_min)


def embedding_concat(x, y):
    B, C1, H1, W1 = x.size()
    _, C2, H2, W2 = y.size()
    s = int(H1 / H2)
    x = F.unfold(x, kernel_size=s, dilation=1, stride=s)
    x = x.view(B, C1, -1, H2, W2)
    z = torch.zeros(B, C1 + C2, x.size(2), H2, W2)
    for i in range(x.size(2)):
        z[:, :, i, :, :] = torch.cat((x[:, :, i, :, :], y), 1)
    z = z.view(B, -1, H2 * W2)
    z = F.fold(z, kernel_size=s, output_size=(H1, W1), stride=s)
    return z


def reshape_embedding(embedding):
    embedding_list = []
    for k in range(embedding.shape[0]):
        for i in range(embedding.shape[2]):
            for j in range(embedding.shape[3]):
                embedding_list.append(embedding[k, :, i, j])
    return embedding_list


def features_to_embedding(features, kernel_size=3, stride=1, padding=1):
    embeddings = []
    for feature in features:
        m = torch.nn.AvgPool2d(kernel_size=kernel_size, stride=stride, padding=padding)
        embeddings.append(m(feature))

    embeddings_num = len(embeddings)
    for index in range(embeddings_num):
        if index == 0:
            embedding = embeddings[index]
        else:
            b1, c1, h1, w1 = embedding.shape
            b2, c2, h2, w2 = embeddings[index].shape
            if h1 % h2 != 0:
                diff_h = abs(h1 - (h2 * 2))
                embedding = torch.nn.ZeroPad2d((diff_h, 0, diff_h, 0))(embedding)
            embedding = embedding_concat(embedding, embeddings[index]).cuda()
    reshaped_embedding = reshape_embedding(np.array(embedding.cpu()))
    return reshaped_embedding


def distance_matrix(x, y=None, p=2):  # pairwise distance of vectors
    y = x if type(y) == type(None) else y

    n = x.size(0)
    m = y.size(0)
    d = x.size(1)

    x = x.unsqueeze(1).expand(n, m, d)
    y = y.unsqueeze(0).expand(n, m, d)
    dist = torch.pow(x - y, p).sum(dim=2)
    return dist


def tensor_from_numpy(data):
    """
    numpyからPyTorchのTensorを作成する
    :param data:
    :return:
    """
    d = torch.from_numpy(data)
    if torch.cuda.is_available():
        d = d.cuda()
    return d


def get_knn_distance_faiss(coreset, test_data, k, batch_size=512):
    """
    coresetとtest_dataの最短距離を上位k個求める
    :param coreset:
    :param test_data:
    :param k:
    :return:
    """
    cpu_index = faiss.IndexFlatL2(coreset.shape[1])
    index = faiss.index_cpu_to_all_gpus(cpu_index)
    index.add(coreset)
    distance, index_info = index.search(test_data, k)
    return distance


def get_knn_distance(coreset, test_data, k, batch_size=512):
    """
    coresetとtest_dataの最短距離を上位k個求める
    :param coreset:
    :param test_data:
    :param k:
    :return:
    """
    p = 2
    coreset_tsr = tensor_from_numpy(coreset)
    test_data_tsr = tensor_from_numpy(test_data)

    max_test_data_num = test_data_tsr.shape[0]
    max_loop_num = int(np.ceil(max_test_data_num / batch_size))

    dist = None
    for index in range(max_loop_num):
        start = index * batch_size
        end = (index + 1) * batch_size
        dist_m = distance_matrix(test_data_tsr[start:end], coreset_tsr, p=2)
        dist_batch = dist_m ** (1 / p)
        if dist is None:
            dist = dist_batch
        else:
            dist = torch.cat([dist, dist_batch], dim=0)

    dist_knn = dist.topk(k, largest=False)
    dist_knn = dist_knn[0].cpu().detach().numpy()
    return dist_knn


def denormalization(x):
    mean = np.array([0.485, 0.456, 0.406])
    std = np.array([0.229, 0.224, 0.225])
    x = (((x.transpose(1, 2, 0) * std) + mean) * 255.).astype(np.uint8)
    return x


def calc_roc_best_score(fpr, tpr, thresholds):
    """
    fpr, tprの相乗平均(幾何平均)を最大にする閾値とindexを返す
    :param fpr:
    :param tpr:
    :param thresholds:
    :return:
    """
    gmeans = np.sqrt(tpr * (1 - fpr))
    ix = np.argmax(gmeans)
    roc_threshold = thresholds[ix]
    return roc_threshold, ix

def to_labels(pos_probs, threshold):
    return (pos_probs >= threshold).astype('int')

def specificity_score(y_true, y_pred):
    """
    特異度を返す
    :param y_true:
    :param y_pred:
    :return:
    """
    tn, fp, fn, tp = confusion_matrix(y_true, y_pred).flatten()
    return tn / (tn + fp)


def calc_metrics(img_scores, gt_list):
    """
    閾値を0.1刻みで変化させ、recall, specificity, precisionを求める
    :param img_scores: 0〜１に正規化されている必要がある
    :param gt_list: 0〜１に正規化されている必要がある
    :return: thresholds, recall, specificity, precision
    """
    thresholds = np.arange(0, 1.01, 0.01)
    recall = np.array([recall_score(gt_list, to_labels(img_scores, th_i)) for th_i in thresholds])
    specificity = np.array([specificity_score(gt_list, to_labels(img_scores, th_i)) for th_i in thresholds])
    precision = np.array([precision_score(gt_list, to_labels(img_scores, th_i)) for th_i in thresholds])
    return thresholds, recall, specificity, precision


def calc_f1_score_precision_recall(recall, precision, thresholds):
    """
    recallとprecisionの調和平均を求め、最大の値とindex, thresholdを返す
    :param recall:
    :param precision:
    :param thresholds:
    :return:
    """
    a = 2 * precision * recall
    b = precision + recall
    f1_precision_recall = np.divide(a, b, out=np.zeros_like(a), where=b != 0)
    idx = np.argmax(f1_precision_recall)
    pr_threshold = thresholds[idx]
    pr_best_f1score = f1_precision_recall[idx]
    return pr_threshold, pr_best_f1score, idx


def calc_f1_score_specificity_recall(recall, specificity, thresholds):
    c = 2 * specificity * recall
    d = specificity + recall
    f1_specificity_recall = np.divide(c, d, out=np.zeros_like(c), where=d != 0)
    idx = np.argmax(f1_specificity_recall)
    sr_threshold = thresholds[idx]
    sr_best_f1score = f1_specificity_recall[idx]
    return sr_threshold, sr_best_f1score, idx


def get_best_threshold(scores, labels, thr):
    best_thr = -1
    best_f1 = -1
    for t in thr:
        pred_label = np.copy(labels)
        good_index = np.where(scores < t)
        pred_label[good_index] = 0
        bad_index = np.where(t <= scores)
        pred_label[bad_index] = 1

        f1 = f1_score(labels, pred_label)
        # cm = confusion_matrix(labels, pred_label)
        # acc = accuracy_score(labels, pred_label)
        # pre = precision_score(labels, pred_label)
        # rec = recall_score(labels, pred_label)

        if best_f1 < f1:
            best_f1 = f1
            best_thr = t

    return best_f1, best_thr


if __name__ == '__main__':
    pass

In [13]:
# noinspection PyAttributeOutsideInit
class PatchCore(pl.LightningModule):
    def __init__(self, args, category):
        super(PatchCore, self).__init__()
        self.args = args
        self.category = category
        
        # https://pytorch-lightning.readthedocs.io/en/latest/common/hyperparameters.html#save-hyperparameters
#         self.save_hyperparameters(self.args)
        self.init_results_list()

        self.inv_normalize = transforms.Normalize(
            mean=[-0.485/0.229, -0.456/0.224, -0.406/0.255],
            std=[1/0.229, 1/0.224, 1/0.255])

        self.backbone_model = globals()[self.args.backbone](self.args)
        # self.knn = None

        # dummy
        self.criterion = torch.nn.MSELoss(reduction='sum')

    def init_results_list(self):
        self.gt_pixel = []
        self.score_pixel = []
        self.best_thr_pixel = []
        self.gt_image = []
        self.score_image = []
        self.img_path_list = []

    def configure_optimizers(self):
        return None

    def forward(self, x_t):
        features = self.backbone_model(x_t)
        return features

    def on_train_start(self):
        self.backbone_model.eval()
        self.embedding_list = []

    def training_step(self, batch, batch_idx):
        x, _, _, file_name, _ = batch
        features = self(x)
        # https://gitlab.com/chowagiken/toyota_kinuura/anomaly_detection/patch_core/-/blob/main/model/tensor_util.py#L40
        embedding = features_to_embedding(features, 
                                          self.args.kernel_size, 
                                          self.args.stride, self.args.padding)
        self.embedding_list.extend(embedding)

    def training_epoch_end(self, outputs):
        if (self.args.num_epochs - 1) == self.current_epoch:
            self.embedding_coreset = self.create_coreset(self.embedding_list)
            gc.collect()

    def on_test_start(self):
        self.init_results_list()
        self.sample_path = os.path.join(self.logger.log_dir, 'sample')
        os.makedirs(self.sample_path, exist_ok=True)

        coreset_file_path = os.path.join(self.args.model_path, f'{self.category}.pickle')
        self.load_model_data = pickle.load(open(coreset_file_path, 'rb'))
        self.embedding_coreset = self.load_model_data['model_data']
        # self.embedding_coreset = pickle.load(open(coreset_file_path, 'rb'))
        # self.knn = FaissKNeighbors(self.embedding_coreset, k=self.args.n_neighbors)
        # self.score_patches_array = None

    def test_step(self, batch, batch_idx):
        x, gt, label, file_name, x_type = batch

        features = self(x)
        embedding = features_to_embedding(features)
        embedding_test = np.array(embedding)

        # score_patches, index_info = self.knn.search(embedding_test)
        
        # score_patches = get_knn_distance_faiss(
        #     self.embedding_coreset,
        #     embedding_test,
        #     k=self.args.n_neighbors,
        #     batch_size=self.args.calc_distance_batch_size)

        score_patches = get_knn_distance(
            self.embedding_coreset,
            embedding_test,
            k=self.args.n_neighbors,
            batch_size=self.args.calc_distance_batch_size)

        # Image-level score
        nearest_distance = score_patches[np.argmax(score_patches[:, 0])]
        w = (1 - (np.max(np.exp(nearest_distance)) / np.sum(np.exp(nearest_distance))))
        
        if math.isnan(w):
            # distanceが大きすぎる場合の暫定処理
            w = (1 - (np.max(np.exp(nearest_distance * (1 / 16))) / np.sum(np.exp(nearest_distance * (1 / 16)))))
        
        score = w * max(score_patches[:, 0])

        # pixel level anomaly map
        reshape_size = np.sqrt(score_patches[:, 0].shape)[0].astype(np.int)
        anomaly_map = score_patches[:, 0].reshape((reshape_size, reshape_size))
        anomaly_map_resized = cv2.resize(anomaly_map, (self.args.input_size, self.args.input_size))
        anomaly_map_resized_blur = gaussian_filter(anomaly_map_resized, sigma=4)
        anomaly_map_norm = min_max_norm(anomaly_map_resized_blur)

        # set results
        gt_np = gt.cpu().numpy()[0, 0].astype(int)
        self.gt_pixel.append(gt_np)
        # self.gt_pixel.append(gt_np.ravel())
        # self.gt_pixel.extend(gt_np.ravel())

        self.score_pixel.append(anomaly_map_resized_blur)
        # self.score_pixel.append(anomaly_map_resized_blur)
        # self.score_pixel.append(anomaly_map_resized_blur.ravel())
        # self.score_pixel.extend(anomaly_map_resized_blur.ravel())

        if 1 < np.unique(gt_np.ravel()).size:
            # 正常、異常のピクセルがある場合のみ
            fpr, tpr, thr = roc_curve(gt_np.ravel(), anomaly_map_resized_blur.ravel())
            best_thr, best_idx = calc_roc_best_score(fpr, tpr, thr)
            self.best_thr_pixel.append(best_thr)

        self.gt_image.append(label.cpu().numpy()[0])
        self.score_image.append(score)
        self.img_path_list.extend(file_name)

        features.clear()
        del features
        embedding.clear()
        del embedding
        del embedding_test

    def test_epoch_end(self, outputs):
        score_image = np.array(self.score_image).ravel()
        norm_score = min_max_norm(score_image)
        self.score_image_norm = norm_score

        if self.embedding_coreset is not None:
            del self.embedding_coreset
        gc.collect()

    def create_coreset(self, embedding_list):
        total_embeddings = np.array(embedding_list)

        # Random projection
        # 'auto' => Johnson-Lindenstrauss lemma
        try:
            randomprojector = SparseRandomProjection(n_components='auto', eps=0.9)
            randomprojector.fit(total_embeddings)
        except Exception as ex:
            randomprojector = None

        # Coreset Subsampling
        selector = KCenterGreedy(X=total_embeddings, y=0, seed=0)

        selected_idx = selector.select_batch_torch(
            model=randomprojector,
            already_selected=[],
            N=int(total_embeddings.shape[0] * self.args.coreset_sampling_ratio))

        # selected_idx = selector.select_batch(
        #     model=randomprojector,
        #     already_selected=[],
        #     N=int(total_embeddings.shape[0] * self.args.coreset_sampling_ratio))

        embedding_coreset = total_embeddings[selected_idx]

        embedding_list.clear()
        del embedding_list
        del total_embeddings
        return embedding_coreset

    def save_coreset(self, save_file_path):
        try:
            with open(save_file_path, 'wb') as f:
                pickle.dump(
                    {
                        'args': self.args,
                        'model_data': self.embedding_coreset
                    }
                    , f)
                # pickle.dump(
                #     {
                #         'model': self.args.model,
                #         'backbone': self.args.backbone,
                #         'backbone_layers': self.args.backbone_layers,
                #         'model_data': self.embedding_coreset}
                #     , f)
                # pickle.dump(self.embedding_coreset, f)
        except Exception as ex:
            raise AnodetException(f'学習済みモデルの保存に失敗しました。\n {save_file_path}を確認してください', ex)

    def test_step_other(self, batch, batch_idx):
        x, gt, label, file_name, x_type = batch

        features = self(x)
        embedding = features_to_embedding(features)
        embedding_test = np.array(embedding)

        # score_patches, index_info = self.knn.search(embedding_test)

        # score_patches = get_knn_distance_faiss(
        #     self.embedding_coreset,
        #     embedding_test,
        #     k=self.args.n_neighbors,
        #     batch_size=self.args.calc_distance_batch_size)

        score_patches = get_knn_distance(
            self.embedding_coreset,
            embedding_test,
            k=self.args.n_neighbors,
            batch_size=self.args.calc_distance_batch_size)

        if self.score_patches_array is None:
            self.score_patches_array = score_patches[np.newaxis]
        else:
            self.score_patches_array = np.concatenate([self.score_patches_array, score_patches[np.newaxis]], axis=0)

        # set results
        gt_np = gt.cpu().numpy()[0, 0].astype(int)
        self.gt_pixel.append(gt_np)
        # self.gt_pixel.append(gt_np.ravel())
        # self.gt_pixel.extend(gt_np.ravel())

        self.gt_image.append(label.cpu().numpy()[0])
        self.img_path_list.extend(file_name)

        features.clear()
        del features
        embedding.clear()
        del embedding

    def test_epoch_other(self, outputs):
        self.score_patches_array = min_max_norm(self.score_patches_array)
        max_count = self.score_patches_array.shape[0]
        for index in range(max_count):
            score_patches = self.score_patches_array[index]
            gt = self.gt_pixel[index]
            self.eval_for_one_image(score_patches, gt)

        if self.embedding_coreset is not None:
            del self.embedding_coreset
        if self.score_patches_array is not None:
            del self.score_patches_array
        gc.collect()

    def eval_for_one_image(self, score_patches, gt):
        # score_patches, index_info = self.knn.search(embedding_test)

        # score_patches = get_knn_distance_faiss(
        #     self.embedding_coreset,
        #     embedding_test,
        #     k=self.args.n_neighbors,
        #     batch_size=self.args.calc_distance_batch_size)

        # score_patches = get_knn_distance(
        #     self.embedding_coreset,
        #     embedding_test,
        #     k=self.args.n_neighbors,
        #     batch_size=self.args.calc_distance_batch_size)

        # Image-level score
        nearest_distance = score_patches[np.argmax(score_patches[:, 0])]
        w = (1 - (np.max(np.exp(nearest_distance)) / np.sum(np.exp(nearest_distance))))
        if math.isnan(w):
            # distanceが大きすぎる場合の暫定処理
            w = (1 - (np.max(np.exp(nearest_distance * (1 / 16))) / np.sum(np.exp(nearest_distance * (1 / 16)))))
        score = w * max(score_patches[:, 0])

        # pixel level anomaly map
        reshape_size = np.sqrt(score_patches[:, 0].shape)[0].astype(np.int)
        anomaly_map = score_patches[:, 0].reshape((reshape_size, reshape_size))
        anomaly_map_resized = cv2.resize(anomaly_map, (self.args.input_size, self.args.input_size))
        anomaly_map_resized_blur = gaussian_filter(anomaly_map_resized, sigma=4)
        anomaly_map_norm = min_max_norm(anomaly_map_resized_blur)

        # set results
        # gt_np = gt.cpu().numpy()[0, 0].astype(int)
        # self.gt_pixel.append(gt_np)
        # # self.gt_pixel.append(gt_np.ravel())
        # # self.gt_pixel.extend(gt_np.ravel())

        self.score_pixel.append(anomaly_map_resized_blur)
        # self.score_pixel.append(anomaly_map_resized_blur)
        # self.score_pixel.append(anomaly_map_resized_blur.ravel())
        # self.score_pixel.extend(anomaly_map_resized_blur.ravel())

        if 1 < np.unique(gt.ravel()).size:
            # 正常、異常のピクセルがある場合のみ
            fpr, tpr, thr = roc_curve(gt.ravel(), anomaly_map_resized_blur.ravel())
            best_thr, best_idx = calc_roc_best_score(fpr, tpr, thr)
            self.best_thr_pixel.append(best_thr)

        # self.gt_image.append(label.cpu().numpy()[0])
        self.score_image.append(score)
        # self.img_path_list.extend(file_name)


In [14]:
import argparse
from distutils.util import strtobool

def get_args():
    parser = argparse.ArgumentParser(description='ANOMALYDETECTION')
    parser.add_argument('--phase', choices=['train', 'test'], default='train')
    parser.add_argument('--backbone', choices=['WideResNet50',
                                               'WideResNet101',
                                               'EfficientNetB5', 
                                               'EfficientNetB7'],
                        default='WideResNet50')
    parser.add_argument('--dataset_path', default='../dataset', type=str)
    parser.add_argument('--project_root_path', default='./results', type=str)
    parser.add_argument('--categories', default='', nargs='*', type=str)
    parser.add_argument('--num_epochs', default=1, type=int)
    parser.add_argument('--batch_size', default=32, type=int)
    parser.add_argument('--load_size', default=256, type=int)
    parser.add_argument('--input_size', default=224, type=int)
    parser.add_argument('--coreset_sampling_ratio', default=0.01, type=float)
    parser.add_argument('--kernel_size', default=3, type=int, help='average poolingのkernel size')
    parser.add_argument('--stride', default=1, type=int, help='average poolingのstride')
    parser.add_argument('--padding', default=1, type=int, help='average poolingのpadding')
    parser.add_argument('--coreset_save_root', type=str, default='embeddings')
    parser.add_argument('--save_anomaly_map', default=True)
    parser.add_argument('--n_neighbors', type=int, default=9)
    parser.add_argument('--calc_distance_batch_size', type=int, default=512)
    parser.add_argument('--seed', type=int, default=8)
    parser.add_argument('--gpus', type=int, default=-1)
    parser.add_argument("--num_workers", type=int, default=0,
                        help="データを読み込む際に使用するスレッド数を指定")
    parser.add_argument("--is_plot_pixel_graph", type=strtobool, default=True, help="")
    parser.add_argument("--is_force_train", type=strtobool, default=False, help="")
    parser.add_argument("--random_crop", type=strtobool, default=False, help="")
    parser.add_argument('--iter', type=int, default=1)
    parser.add_argument('--comment', type=str, default='comment')

    args = parser.parse_args()
    return args

In [15]:
# args = get_args()

category = 'wcvt'
class Args:
    data = './data/penn'
    backbone = 'WideResNet50'
    num_epochs = 1
    batch_size = 16
    input_size = 224
    load_size=256
    is_plot_pixel_graph = 1
    is_force_train = 0
    kernel_size = 3
    stride = 1
    padding = 1
    coreset_sampling_ratio=0.01
    random_crop=0
    iter=1
    n_neighbors=9
    calc_distance_batch_size=512
    categories = category
    
args = Args()

In [16]:
model = PatchCore(args, category)

Downloading: "https://github.com/pytorch/vision/archive/v0.8.2.zip" to C:\Users\innat/.cache\torch\hub\v0.8.2.zip


In [17]:
model

PatchCore(
  (inv_normalize): Normalize(mean=[-2.1179039301310043, -2.0357142857142856, -1.5921568627450982], std=[4.366812227074235, 4.464285714285714, 3.9215686274509802])
  (backbone_model): WideResNet50(
    (pretrained_model): ResNet(
      (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
      (layer1): Sequential(
        (0): Bottleneck(
          (conv1): Conv2d(64, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (conv3): C

In [23]:
model(torch.rand((1, 3, 224, 224)))[0].shape

torch.Size([1, 512, 28, 28])

In [24]:
model(torch.rand((1, 3, 224, 224)))[1].shape

torch.Size([1, 1024, 14, 14])