In [None]:
import pandas as pd
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.optim.lr_scheduler import CosineAnnealingLR
from torchvision import models, transforms
from torch.utils.data import DataLoader, Dataset
from torchsummary import summary
from PIL import Image
from torch.optim.lr_scheduler import ReduceLROnPlateau
from tqdm import tqdm
import random
from tqdm import tqdm
import os
import matplotlib.pyplot as plt
import wandb
import logging
import math

In [None]:
class CustomDataset(Dataset):
    """Image dataset backed by a CSV file that has an ``img_path`` column.

    Args:
        csv_file: path of the CSV file; it must contain an ``img_path`` column.
        image_folder: optional directory holding the images. When given, only
            the file name of each ``img_path`` entry is kept and joined onto
            this folder. When ``None`` (default), ``img_path`` is used exactly
            as written in the CSV (presumably relative to the working
            directory - verify against the CSV contents).
        transform: optional callable applied to the loaded PIL image.
    """

    def __init__(self, csv_file, image_folder=None, transform=None):
        self.df = pd.read_csv(csv_file)
        self.image_folder = image_folder
        self.transform = transform

    def __len__(self):
        return len(self.df)

    def _resolve_path(self, idx):
        """Return the filesystem path of the image in row ``idx``."""
        img_path = self.df['img_path'].iloc[idx]
        if self.image_folder is None:
            return img_path
        return os.path.join(self.image_folder, img_path.split("/")[-1])

    def __getitem__(self, idx):
        # Decode inside ``with`` so the file handle is closed immediately, and
        # force 3-channel RGB: the notebook's transform normalizes with three
        # per-channel statistics, which fails on grayscale/RGBA/palette images.
        with Image.open(self._resolve_path(idx)) as img:
            image = img.convert("RGB")

        if self.transform:
            image = self.transform(image)

        return image

In [None]:
# Preprocessing for the ImageNet-pretrained backbone: resize to 256x256,
# center-crop to 224x224, convert to a tensor and normalize each channel with
# the ImageNet mean / std.
_preprocess_steps = [
    transforms.Resize((256, 256)),
    transforms.CenterCrop((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(
        mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225],
    ),
]
transform = transforms.Compose(_preprocess_steps)

In [None]:
# Run on the GPU when one is visible, otherwise on the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
# Directory where the memory-bank shards are written and read back.
memory_bank_path = "./memory_bank/resnet18"

device

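## PatchCore
Shape notation used in the code comments below ([paper](https://arxiv.org/abs/2106.08265)):
- N: batch size
- N': memory bank size (number of stored patch features)
- query: query size (number of query images scored at once)
- D: target_dim, the dimension of a pooled patch feature
- |A|: patch collection size (number of patches extracted per image)
- |P|: patch_size, the neighbourhood size used to aggregate features around each patch centre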

In [None]:
# ImageNet-pretrained ResNet-18; it is the backbone handed to PatchCore below.
pretrained_weights = models.ResNet18_Weights.IMAGENET1K_V1
model = models.resnet18(weights=pretrained_weights)
model.eval()

# NOTE(review): `test` (the ResNet without its last three children) is not used
# anywhere else in this notebook.
resnet_children = list(model.children())
test = nn.Sequential(*resnet_children[:-3])
model = model.to(device)

# Patchcore Implementation - https://arxiv.org/abs/2106.08265
class PatchCore(nn.Module):
    """PatchCore anomaly scorer driven by forward hooks on a ResNet-style backbone.

    Forward hooks on ``backbone.layer2`` and ``backbone.layer3`` capture the
    intermediate feature maps; the ``layer3`` map is upsampled 2x so both share
    one spatial grid. Every admissible location is described by the features of
    its ``patch_size`` x ``patch_size`` neighbourhood, average-pooled to
    ``target_dim``, and scored by Euclidean distance to a memory bank stored as
    ``*.pth`` shards in ``memory_bank_path``.

    Several instances may share one backbone: each keeps the handles of the
    hooks it registered (see ``remove_hooks``), and the hooks only hold a weak
    reference to their instance so discarded instances can be collected.
    """

    def __init__(
        self,
        backbone,
        per_memory_bank_size=4,
        memory_bank_path="./memory_bank",
        device=None,
        target_dim=384 * 3,  # |D|
        patch_size=3,  # |P|
        d=3,  # nearest
    ):
        """
        Args:
            backbone: torch.nn.Module exposing ``layer2`` and ``layer3``
                submodules (e.g. a torchvision ResNet).
            per_memory_bank_size: int, stored on the instance; not used by
                this class.
            memory_bank_path: str, directory holding the ``*.pth`` shards.
            device: torch.device used when loading shards and for empty
                placeholder tensors; ``None`` selects CUDA when available,
                otherwise CPU.
            target_dim: int, pooled patch-feature dimension D.
            patch_size: int, neighbourhood size |P| used to aggregate features.
            d: int, number of nearest memory distances used to reweight scores.
        """
        super().__init__()
        self.backbone = backbone
        self.backbone.eval()

        self.layer2_output, self.layer3_output = None, None
        self._hook_handles = []
        self.register_hook_for_layer2()  # register hook for layer2
        self.register_hook_for_layer3()  # register hook for layer3

        if device is None:
            device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        self.memory_bank = None
        self.per_memory_bank_size = per_memory_bank_size
        self.memory_bank_path = memory_bank_path
        self.device = device
        self.target_dim = target_dim
        self.average_pool = nn.AdaptiveAvgPool1d(self.target_dim)
        self.patch_size = patch_size
        self.d = d

    # ------------------------------------------------------------------ hooks
    def _make_weak_hook(self, hook_fn):
        """Wrap ``hook_fn(owner, module, inputs, output)`` as a forward hook that
        references this instance only weakly.

        A strong reference from the (possibly shared) backbone's hook table
        would keep every PatchCore ever built on it alive, together with its
        captured feature maps, and keep running its hook on every forward pass.
        """
        import weakref

        owner_ref = weakref.ref(self)

        def hook(module, inputs, output):
            owner = owner_ref()
            if owner is not None:
                hook_fn(owner, module, inputs, output)

        return hook

    def register_hook_for_layer2(self):
        """Capture ``backbone.layer2`` outputs into ``self.layer2_output``.

        Returns:
            the hook's removable handle (also kept in ``self._hook_handles``).
        """
        handle = self.backbone.layer2.register_forward_hook(
            self._make_weak_hook(type(self)._register_hook_for_layer2)
        )
        self._hook_handles.append(handle)
        return handle

    def _register_hook_for_layer2(self, module, input, output):  # (B, 128, 28, 28)
        """Forward hook body: store the raw layer2 feature map."""
        self.layer2_output = output  # (B, 128, 28, 28)

    def register_hook_for_layer3(self):
        """Capture upsampled ``backbone.layer3`` outputs into ``self.layer3_output``.

        Returns:
            the hook's removable handle (also kept in ``self._hook_handles``).
        """
        handle = self.backbone.layer3.register_forward_hook(
            self._make_weak_hook(type(self)._register_hook_for_layer3)
        )
        self._hook_handles.append(handle)
        return handle

    def _register_hook_for_layer3(self, module, input, output):  # (B, 256, 14, 14)
        """Forward hook body: upsample the layer3 map 2x and store it."""
        self.layer3_output = nn.functional.interpolate(
            output, scale_factor=2, mode="bilinear"
        )  # (B, 256, 28, 28)

    def remove_hooks(self):
        """Detach every forward hook this instance registered on the backbone."""
        for handle in self._hook_handles:
            handle.remove()
        self._hook_handles.clear()

    def __del__(self):
        # Read __dict__ directly: __init__ may have failed before the handle
        # list existed, and nn.Module.__getattr__ would raise in that case.
        for handle in self.__dict__.get("_hook_handles", ()):
            handle.remove()

    # ----------------------------------------------------------- memory bank
    def save_memory_bank(self, train_batch, file_name):
        """Run a training batch through the backbone and save its patch features.

        The (N, |A|, D) patch collection is flattened to (N * |A|, D) and written
        with ``torch.save`` to ``memory_bank_path/file_name``.
        """
        # Feature extraction only: building (and saving) an autograd graph
        # would just waste memory.
        with torch.no_grad():
            self.backbone(train_batch)
            new_memory = self.patch_collection(self.patch_size)
        new_memory = new_memory.reshape(-1, new_memory.shape[-1])
        path = os.path.join(self.memory_bank_path, file_name)
        torch.save(new_memory, path)

    def loader_memory_bank(self):
        """Lazily yield every ``*.pth`` memory-bank shard, in sorted file-name order."""
        file_list = sorted(
            file for file in os.listdir(self.memory_bank_path) if file.endswith(".pth")
        )
        for file in file_list:
            path = os.path.join(self.memory_bank_path, file)
            yield torch.load(path, map_location=self.device)

    # ----------------------------------------------------------------- score
    def forward(self, x):
        """Return one anomaly score per image of ``x``.

        Returns:
            (N,) tensor of re-weighted anomaly scores.

        Raises:
            RuntimeError: if ``memory_bank_path`` contains no ``*.pth`` shard.
        """
        self.backbone(x)
        query = self.patch_collection(self.patch_size)  # (N, |A|, D)
        # Compute per-shard distances, then concatenate once along the memory axis.
        distances = [self.cal_l2(query, memory) for memory in self.loader_memory_bank()]
        if not distances:
            raise RuntimeError(
                f"No memory bank shards (*.pth) found in {self.memory_bank_path!r}"
            )
        l2 = torch.cat(distances, dim=1)  # (N, N', |A|)

        s_ = self.get_anomaly_score(l2)
        return self.update_anomaly_score(s_, l2, self.d)

    def get_anomaly_score(self, l2):
        """Raw score: the largest nearest-memory distance over all query patches.

        Args:
            l2: (query, N', |A|)
        Returns:
            (query,)
        """
        min_l2 = l2.min(dim=1).values  # (query, |A|)
        max_min_l2 = min_l2.max(dim=1).values  # (query, )
        return max_min_l2

    def update_anomaly_score(self, s_, l2: torch.Tensor, d):
        """Re-weight the raw scores ``s_`` as in PatchCore.

        For each query, take its most anomalous patch, its ``d`` smallest
        distances to the memory bank, and scale the score by
        ``1 - exp(min_dist) / sum(exp(d_nearest_dists))``.

        Args:
            s_: (query,) raw scores from ``get_anomaly_score``.
            l2: (query, N', |A|) query-to-memory distances.
            d: number of nearest memory distances to use (clamped to N').
        Returns:
            (query,) re-weighted scores.
        """
        m_train = l2.min(dim=1)  # (query, |A|)
        m_test = m_train.values.max(dim=1)  # (query,)
        # Pick, per query, the distance column of that query's own most
        # anomalous patch (plain l2[:, :, idx] would mix queries when batch > 1).
        query_idx = torch.arange(l2.shape[0], device=l2.device)
        m_for_test = l2[query_idx, :, m_test.indices]  # (query, N')
        k = min(d, m_for_test.shape[1])
        # Nearest neighbours are the *smallest* distances.
        m_train_nearest = m_for_test.topk(k=k, dim=1, largest=False).values  # (query, k)

        # exp(a) / sum(exp(b)) == exp(a - logsumexp(b)), without overflow.
        update_weight = 1 - torch.exp(
            m_test.values - torch.logsumexp(m_train_nearest, dim=1)
        )
        return s_ * update_weight

    def cal_l2(self, query, memory_bank):
        """Euclidean distance between every query patch and every memory patch.

        Args:
            query: (query_size, |A|, D)
            memory_bank: (N', D)
        Returns:
            (query_size, N', |A|)
        """
        # cdist avoids materialising an (N', |A|, D) difference tensor; the
        # non-matmul mode keeps the exact sqrt(sum((x - y)^2)) formulation,
        # which matters for near-zero nearest-neighbour distances.
        return torch.stack(
            [
                torch.cdist(memory_bank, q, compute_mode="donot_use_mm_for_euclid_dist")
                for q in query
            ],
            dim=0,
        )

    # -------------------------------------------------------------- features
    def feature(self, h, w):
        """Concatenated layer2 / layer3 features at location (h, w).

        Returns:
            (N, C), or an empty tensor when (h, w) lies outside the feature map.
        """
        H, W = self.layer2_output.shape[2], self.layer2_output.shape[3]
        if not (0 <= h < H and 0 <= w < W):
            return torch.tensor([]).to(self.device)
        layer2 = self.layer2_output[:, :, h, w]  # (B, C)

        # TODO: revisit the code below (fallback when the layer3 hook never fired).
        if self.layer3_output is not None:
            layer3 = self.layer3_output[:, :, h, w]  # (B, C')
        else:
            layer3 = torch.tensor([]).to(self.device)

        return torch.cat((layer2, layer3), dim=1)

    def neighborhood_features(self, h, w, patch_size):
        """Features of the neighbourhood centred on (h, w).

        Offsets run from -(patch_size // 2) to +(patch_size // 2) in both
        directions, as in the paper; offsets outside the map are skipped.

        Returns:
            (N, |P|, C)
        """
        radius = patch_size // 2
        offsets = range(-radius, radius + 1)
        features = []
        for i in offsets:
            for j in offsets:
                feature = self.feature(h + i, w + j)
                if feature.shape[0] == 0:
                    continue
                features.append(feature)
        return torch.stack(features, dim=1)

    def patch(self, h, w, patch_size):
        """Neighbourhood features around (h, w) pooled to ``target_dim``.

        Returns:
            (N, D)
        """
        features = self.neighborhood_features(h, w, patch_size)  # (N, |P|, C)
        features = features.permute(0, 2, 1)  # (N, C, |P|)
        features = features.reshape(features.shape[0], -1)  # (N, C x |P|)
        return self.average_pool(features)  # (N, target_dim) = (N, D)

    def patch_collection(self, patch_size):
        """Pooled patch features for every admissible centre of the feature map.

        Returns:
            (N, |A|, D)
        """
        H, W = self.layer2_output.shape[2], self.layer2_output.shape[3]
        # Only use centres whose whole neighbourhood lies inside the feature map,
        # so average pooling always sees a complete patch
        # (margin ceil(patch_size / 2) >= patch_size // 2).
        margin = math.ceil(patch_size / 2)
        patches = [
            self.patch(h, w, patch_size)
            for h in range(margin, H - margin)
            for w in range(margin, W - margin)
        ]
        return torch.stack(patches, dim=1)

In [None]:
# PatchCore hyper-parameters shared by the threshold search and the inference run.
params = dict(
    patch_size=5,  # neighbourhood size |P|
    per_memory_bank_size=1,
    d=2,  # nearest memory distances used to re-weight the score
    target_dim=1920,  # pooled patch-feature dimension D
)

In [None]:
# Finding Threshold of PatchCore
# train.csv is split into consecutive validation folds of k rows each.
k = 13

# Threshold
data = pd.read_csv("./train.csv")
# NOTE: `sum` shadows the builtin; the name is kept because the Threshold cell reads it.
sum = 0.0
# Number of complete folds, derived from the data instead of a hard-coded row count.
m = len(data) // k


def save2(batch_size=1, config=None, csv_file="./train2.csv"):  # save function used to find the Threshold
    """Rebuild the memory bank from the training split of the current fold.

    Empties ``memory_bank_path`` (creating it if needed) and writes one
    ``memory_bank_<offset>.pth`` shard per batch, where ``<offset>`` is the
    index of the batch's first image.

    Args:
        batch_size: number of images per shard.
        config: keyword arguments forwarded to ``PatchCore`` (``None`` means
            no extra arguments).
        csv_file: CSV with an ``img_path`` column listing the fold's training
            images.
    """
    config = {} if config is None else config
    # Reset the memory_bank_path folder.
    if os.path.exists(memory_bank_path):
        for file in os.listdir(memory_bank_path):
            os.remove(os.path.join(memory_bank_path, file))
    else:
        os.makedirs(memory_bank_path)

    patch_core = PatchCore(model, memory_bank_path=memory_bank_path, **config)
    train_data = CustomDataset(csv_file=csv_file, transform=transform)
    train_loader = DataLoader(train_data, batch_size=batch_size, shuffle=False)

    # Feature extraction only: no autograd graph is needed.
    with torch.no_grad():
        for idx, x in enumerate(tqdm(train_loader)):
            # Keep the batch dimension: x[0] would drop it, and the ResNet's
            # BatchNorm2d layers reject 3-D input.
            x = x.to(device)
            patch_core.save_memory_bank(x, f"memory_bank_{batch_size*idx}.pth")


for i in range(m):
    j = i * k
    validation_indices = range(j, j + k)  # k consecutive rows per fold
    validation_idx = data.loc[validation_indices]
    train_idx = data.drop(validation_indices)

    train_idx.to_csv("./train2.csv", index=False)
    validation_idx.to_csv("./validation.csv", index=False)

    save2(batch_size=1, config=params)

    patch_core = PatchCore(model, memory_bank_path=memory_bank_path, **params)
    test_data = CustomDataset(csv_file="./validation.csv", transform=transform)
    test_loader = DataLoader(test_data, batch_size=1, shuffle=False)
    anomaly_score = torch.tensor([], device=device)
    with torch.no_grad():
        for idx, x in enumerate(tqdm(test_loader, desc="query")):
            # Keep the batch dimension: x[0] would drop it, and the ResNet's
            # BatchNorm2d layers reject 3-D input.
            x = x.to(device)
            l2 = patch_core.forward(x)

            print(f"l2: {l2}")

            anomaly_score = torch.cat([anomaly_score, l2], dim=0)
    # Sort once after scoring instead of re-sorting on every image.
    anomaly_score_sorted_idx = anomaly_score.sort(descending=True).indices

    anomaly_score = anomaly_score.cpu()

    max_value, max_index = torch.max(anomaly_score, 0)
    print(f"{i+1}th slice max value : {max_value.item():.4f} \n")
    sum += max_value.item()

In [None]:
# Decision threshold: the mean of the per-fold maximum validation scores
# accumulated in `sum` over the `m` folds of the threshold search.
if m == 0:
    raise ValueError("No validation folds were evaluated (m == 0); cannot compute Threshold.")
Threshold = sum / m
Threshold

In [None]:
def save(batch_size=1, config=None, csv_file="./train.csv"):  # save function used for inference
    """Rebuild the memory bank from every image listed in ``csv_file``.

    Empties ``memory_bank_path`` (creating it if needed) and writes one
    ``memory_bank_<offset>.pth`` shard per batch, where ``<offset>`` is the
    index of the batch's first image.

    Args:
        batch_size: number of images per shard.
        config: keyword arguments forwarded to ``PatchCore`` (``None`` means
            no extra arguments).
        csv_file: CSV with an ``img_path`` column listing the training images.
    """
    config = {} if config is None else config
    # Reset the memory_bank_path folder.
    if os.path.exists(memory_bank_path):
        for file in os.listdir(memory_bank_path):
            os.remove(os.path.join(memory_bank_path, file))
    else:
        os.makedirs(memory_bank_path)

    patch_core = PatchCore(model, memory_bank_path=memory_bank_path, **config)
    train_data = CustomDataset(csv_file=csv_file, transform=transform)
    train_loader = DataLoader(train_data, batch_size=batch_size, shuffle=False)

    # Feature extraction only: no autograd graph is needed.
    with torch.no_grad():
        for idx, x in enumerate(tqdm(train_loader)):
            # Keep the batch dimension: x[0] would drop it, and the ResNet's
            # BatchNorm2d layers reject 3-D input.
            x = x.to(device)
            patch_core.save_memory_bank(x, f"memory_bank_{batch_size*idx}.pth")

In [None]:
save(batch_size=1, config=params)

patch_core = PatchCore(model, memory_bank_path=memory_bank_path, **params)

test_data = CustomDataset(csv_file="./test.csv", transform=transform)
test_loader = DataLoader(test_data, batch_size=1, shuffle=False)
anomaly_score = torch.tensor([], device=device)
with torch.no_grad():
    for idx, x in enumerate(tqdm(test_loader, desc="query")):
        # Keep the batch dimension: x[0] would drop it, and the ResNet's
        # BatchNorm2d layers reject 3-D input.
        x = x.to(device)
        l2 = patch_core.forward(x)

        print(f"l2: {l2}")

        anomaly_score = torch.cat([anomaly_score, l2], dim=0)
# Sort once after scoring instead of re-sorting on every image.
anomaly_score_sorted_idx = anomaly_score.sort(descending=True).indices

anomaly_score = anomaly_score.cpu()
fig, ax = plt.subplots()
ax.plot(anomaly_score)
ax.set(
    title="PatchCore anomaly score per test image",
    xlabel="test image index",
    ylabel="anomaly score",
)
plt.show()

In [None]:
# Images whose anomaly score is greater than Threshold are classified as outliers.
data = anomaly_score
outliers_idx1 = [position for position, score in enumerate(data) if score > Threshold]
data = outliers_idx1
print(outliers_idx1)