<a href="https://colab.research.google.com/github/hoa92ng/Homework/blob/main/embedding_homework.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
"""PatchCore and PatchCore detection methods."""
import logging
import os
import pickle

import numpy as np
import torch
import torch.nn.functional as F
import tqdm

import patchcore
import patchcore.backbones
import patchcore.common
import patchcore.sampler
from patchcore.siamese import SiameseNetwork

LOGGER = logging.getLogger(__name__)


class PatchCore(torch.nn.Module):
    def __init__(self, device):
        """Create an unconfigured PatchCore module bound to ``device``.

        Only the device is recorded here; the backbone and every helper
        component are attached later by :meth:`load`.
        """
        super().__init__()
        self.device = device

    def load(
        self,
        backbone,
        layers_to_extract_from,
        device,
        input_shape,
        pretrain_embed_dimension,
        target_embed_dimension,
        patchsize=3,
        patchstride=1,
        anomaly_score_num_nn=1,
        featuresampler=None,
        nn_method=None,
        sub_model=True,
        sub_model_path="./sub_model/siamese_wide_resnet50_2_final.pth",
        **kwargs,
    ):
        """Attach the backbone and all helper components to this instance.

        Args:
            backbone: feature extractor; moved to ``device``.
            layers_to_extract_from: names of the backbone layers whose outputs
                are used as features.
            device: device for the backbone and all helper modules.
            input_shape: shape of one input image; its last two entries are
                used as the target size of the segmentation maps.
            pretrain_embed_dimension: passed to
                ``patchcore.common.Preprocessing`` together with the per-layer
                feature dimensions.
            target_embed_dimension: output dimension of the aggregator and of
                the siamese sub-model.
            patchsize: side length of the square patches.
            patchstride: stride used when extracting patches.
            anomaly_score_num_nn: number of nearest neighbours used by the
                anomaly scorer.
            featuresampler: sampler applied to the memory bank; a fresh
                ``patchcore.sampler.IdentitySampler()`` when ``None``.
            nn_method: nearest-neighbour backend; a fresh
                ``patchcore.common.FaissNN(False, 4)`` when ``None``.
            sub_model: if True, build the siamese sub-model and load its
                weights from ``sub_model_path``.
            sub_model_path: checkpoint file holding the sub-model state dict.
            **kwargs: accepted and ignored.
        """
        # Build the defaults per call: instances created in the signature
        # would be shared by every PatchCore object in the process.
        if featuresampler is None:
            featuresampler = patchcore.sampler.IdentitySampler()
        if nn_method is None:
            nn_method = patchcore.common.FaissNN(False, 4)

        self.backbone = backbone.to(device)
        self.layers_to_extract_from = layers_to_extract_from
        self.input_shape = input_shape

        self.device = device
        self.patch_maker = PatchMaker(patchsize, stride=patchstride)

        self.forward_modules = torch.nn.ModuleDict({})

        feature_aggregator = patchcore.common.NetworkFeatureAggregator(
            self.backbone, self.layers_to_extract_from, self.device
        )
        feature_dimensions = feature_aggregator.feature_dimensions(input_shape)
        self.forward_modules["feature_aggregator"] = feature_aggregator

        preprocessing = patchcore.common.Preprocessing(
            feature_dimensions, pretrain_embed_dimension
        )
        self.forward_modules["preprocessing"] = preprocessing

        self.target_embed_dimension = target_embed_dimension
        preadapt_aggregator = patchcore.common.Aggregator(
            target_dim=target_embed_dimension
        )

        _ = preadapt_aggregator.to(self.device)

        self.forward_modules["preadapt_aggregator"] = preadapt_aggregator

        self.anomaly_scorer = patchcore.common.NearestNeighbourScorer(
            n_nearest_neighbours=anomaly_score_num_nn, nn_method=nn_method
        )

        self.anomaly_segmentor = patchcore.common.RescaleSegmentor(
            device=self.device, target_size=input_shape[-2:]
        )

        self.featuresampler = featuresampler

        if sub_model:
            # NOTE(review): 1536 is hard-coded; presumably the summed channel
            # count of the two extracted backbone layers — confirm before
            # using a different backbone or layer selection.
            self.sub_model = SiameseNetwork(
                in_dimension=1536, out_dimension=target_embed_dimension
            ).to(device=device)
            # map_location lets a checkpoint saved on another device (e.g.
            # CUDA) be restored on the device requested here.
            state_dict = torch.load(sub_model_path, map_location=device)
            self.sub_model.load_state_dict(state_dict)
            # The sub-model is only used for inference in this class; eval
            # mode makes BatchNorm use its stored running statistics instead
            # of the statistics of the current batch.
            _ = self.sub_model.eval()

    def embed(self, data):
        """Embed one batch, or every batch yielded by a ``DataLoader``.

        A ``DataLoader`` produces a list with one ``_embed`` result per batch
        (dict batches contribute their ``"image"`` entry); anything else is
        forwarded to ``_embed`` directly.
        """
        if not isinstance(data, torch.utils.data.DataLoader):
            return self._embed(data)

        embedded_batches = []
        for batch in data:
            batch_images = batch["image"] if isinstance(batch, dict) else batch
            with torch.no_grad():
                on_device = batch_images.to(torch.float).to(self.device)
                embedded_batches.append(self._embed(on_device))
        return embedded_batches

    def _embed(self, images, detach=True, provide_patch_shapes=False):
        """Returns feature embeddings for images.

        The backbone feature maps of every layer in
        ``self.layers_to_extract_from`` are cut into patches, the patch grids
        of the later layers are bilinearly resized to the grid of the first
        layer, and the result is passed through the "preprocessing" and
        "preadapt_aggregator" modules.

        Args:
            images: input batch, handed to the feature aggregator as-is.
            detach: if True, the result is converted by ``_detach`` (one numpy
                array per item obtained by iterating over the aggregated
                features); otherwise it is returned unchanged.
            provide_patch_shapes: if True, additionally return the per-layer
                patch grid sizes reported by ``patchify``.

        Returns:
            The (optionally detached) features, or a ``(features,
            patch_shapes)`` tuple when ``provide_patch_shapes`` is True.
        """

        def _detach(features):
            # Iterates over `features`; presumably a 2-D tensor at this point,
            # which yields one numpy array per row — TODO confirm.
            if detach:
                return [x.detach().cpu().numpy() for x in features]
            return features

        # The backbone is used as a frozen feature extractor.
        _ = self.forward_modules["feature_aggregator"].eval()
        with torch.no_grad():
            features = self.forward_modules["feature_aggregator"](images)

        # Keep only the requested layers, in the configured order.
        features = [features[layer] for layer in self.layers_to_extract_from]

        # patchify returns (patches, [grid_h, grid_w]) for each layer, with
        # patches laid out as (batch, grid_h * grid_w, channels, p, p).
        features = [
            self.patch_maker.patchify(x, return_spatial_info=True) for x in features
        ]
        patch_shapes = [x[1] for x in features]
        features = [x[0] for x in features]
        # The first layer's patch grid is the reference every other layer is
        # resized to.
        ref_num_patches = patch_shapes[0]

        for i in range(1, len(features)):
            _features = features[i]
            patch_dims = patch_shapes[i]

            # Split the flat patch axis back into its 2-D grid:
            # (batch, grid_h, grid_w, channels, p, p).
            _features = _features.reshape(
                _features.shape[0], patch_dims[0], patch_dims[1], *_features.shape[2:]
            )
            # Move the grid axes last so they can be treated as an image:
            # (batch, channels, p, p, grid_h, grid_w).
            _features = _features.permute(0, -3, -2, -1, 1, 2)
            perm_base_shape = _features.shape
            # Collapse everything but the grid into one leading axis and
            # resize each grid to the reference grid (a channel axis of size 1
            # is added only for F.interpolate and removed right after).
            _features = _features.reshape(-1, *_features.shape[-2:])
            _features = F.interpolate(
                _features.unsqueeze(1),
                size=(ref_num_patches[0], ref_num_patches[1]),
                mode="bilinear",
                align_corners=False,
            )
            _features = _features.squeeze(1)
            # Restore (batch, channels, p, p, ref_h, ref_w) ...
            _features = _features.reshape(
                *perm_base_shape[:-2], ref_num_patches[0], ref_num_patches[1]
            )
            # ... and return to the patch-major layout
            # (batch, ref_h * ref_w, channels, p, p).
            _features = _features.permute(0, -2, -1, 1, 2, 3)
            _features = _features.reshape(len(_features), -1, *_features.shape[-3:])
            features[i] = _features
        # Merge batch and patch axes: one entry per patch, (channels, p, p).
        features = [x.reshape(-1, *x.shape[-3:]) for x in features]

        # As different feature backbones & patching provide differently
        # sized features, these are brought into the correct form here.
        features = self.forward_modules["preprocessing"](features)
        features = self.forward_modules["preadapt_aggregator"](features)

        if provide_patch_shapes:
            return _detach(features), patch_shapes
        return _detach(features)

    def fit(self, training_data):
        """PatchCore training.

        Computes the embeddings of ``training_data`` and fills the memory
        bank by delegating to ``_fill_memory_bank_with_sub_model``.
        """
        # Variant without the siamese sub-model, kept for reference:
        # self._fill_memory_bank(training_data)
        self._fill_memory_bank_with_sub_model(training_data)

    def get_feature(self, input_image):
        """Return the feature aggregator outputs for ``input_image``.

        The input is cast to float and moved to ``self.device``; the result
        is a list with one entry per name in ``self.layers_to_extract_from``,
        in that order. Nothing is stored on the instance.
        """
        _ = self.forward_modules.eval()
        aggregator = self.forward_modules["feature_aggregator"]
        with torch.no_grad():
            batch = input_image.to(torch.float).to(self.device)
            _ = aggregator.eval()
            layer_outputs = aggregator(batch)
            return [layer_outputs[name] for name in self.layers_to_extract_from]

    def _fill_memory_bank(self, input_data):
        """Embed every batch of ``input_data`` and fit the anomaly scorer.

        The per-batch embeddings from ``_embed`` are concatenated, passed
        through ``self.featuresampler`` and handed to the anomaly scorer.
        """
        _ = self.forward_modules.eval()

        collected = []
        with tqdm.tqdm(
            input_data, desc="Computing support features...", position=1, leave=False
        ) as progress:
            for batch in progress:
                batch_images = batch["image"] if isinstance(batch, dict) else batch
                with torch.no_grad():
                    batch_images = batch_images.to(torch.float).to(self.device)
                    collected.append(self._embed(batch_images))

        memory_bank = self.featuresampler.run(np.concatenate(collected, axis=0))
        self.anomaly_scorer.fit(detection_features=[memory_bank])

    def _fill_memory_bank_with_sub_model(self, input_data, detach=True, provide_patch_shapes=False):
        """Fill the memory bank with residuals of raw and sub-model features.

        For every batch two per-patch embeddings are computed: the regular
        one from ``_embed`` and the siamese one from
        ``_embed_with_sub_model``. The stored feature is ``(raw - sub) * 2``,
        the same transform ``_predict`` applies at inference time.

        Args:
            input_data: iterable of batches; dict batches contribute their
                ``"image"`` entry.
            detach: kept for backward compatibility and ignored; the memory
                bank is always built from numpy features.
            provide_patch_shapes: kept for backward compatibility and
                ignored; patch shapes are not needed to fill the bank.

        Raises:
            ValueError: if ``input_data`` yields no batches.
        """
        _ = self.forward_modules.eval()
        # Eval mode makes BatchNorm in the sub-model use its stored running
        # statistics, so the stored features do not depend on how the data
        # happens to be batched.
        _ = self.sub_model.eval()

        features = []
        with tqdm.tqdm(
            input_data, desc="Computing support features...", position=1, leave=False
        ) as data_iterator:
            for image in data_iterator:
                if isinstance(image, dict):
                    image = image["image"]
                with torch.no_grad():
                    image = image.to(torch.float).to(self.device)
                    # Both calls return one numpy row per patch; stacking
                    # them allows a single vectorised subtraction.
                    raw = np.asarray(self._embed(image))
                    sub = np.asarray(self._embed_with_sub_model(image))
                features.append((raw - sub) * 2)

        if not features:
            raise ValueError("Cannot fill the memory bank from an empty data source.")

        features = np.concatenate(features, axis=0)
        features = self.featuresampler.run(features)

        self.anomaly_scorer.fit(detection_features=[features])


    def _embed_with_sub_model(self, input_data, detach=True, provide_patch_shapes=False):
        """Embed a batch with the siamese sub-model.

        The first two feature maps returned by ``get_feature`` are fed to
        ``self.sub_model.forward_once``; its output is rearranged to
        channels-last and flattened to rows of length
        ``self.target_embed_dimension``.

        Args:
            input_data: batch of images (tensor).
            detach: if True, return a list with one numpy array per row;
                otherwise return the 2-D tensor itself.
            provide_patch_shapes: if True, additionally return the
                ``[height, width]`` of every feature map from ``get_feature``.

        Returns:
            The features, or a ``(features, patch_shapes)`` tuple when
            ``provide_patch_shapes`` is True.
        """
        # Eval mode makes BatchNorm in the sub-model use its stored running
        # statistics instead of the statistics of the current batch.
        _ = self.sub_model.eval()

        with torch.no_grad():
            input_image = input_data.to(torch.float).to(self.device)
            layer_features = self.get_feature(input_image=input_image)
            patch_shapes = [[x.shape[-2], x.shape[-1]] for x in layer_features]
            img0 = layer_features[0].to(self.device)
            img1 = layer_features[1].to(self.device)
            out_feature = self.sub_model.forward_once(img0, img1)
            # (batch, channels, h, w) -> (batch, h, w, channels) -> one row
            # per spatial position.
            out_feature = out_feature.permute(0, 2, 3, 1)
            out_feature = torch.reshape(out_feature, (-1, self.target_embed_dimension))

        if detach:
            # A single device-to-host transfer; iterating the 2-D array then
            # yields one row per patch, as callers expect.
            out_feature = list(out_feature.detach().cpu().numpy())

        if provide_patch_shapes:
            return out_feature, patch_shapes
        return out_feature

    def predict(self, data):
        """Score ``data``: a ``DataLoader`` is processed batch by batch,
        anything else is treated as a single batch of images."""
        handler = (
            self._predict_dataloader
            if isinstance(data, torch.utils.data.DataLoader)
            else self._predict
        )
        return handler(data)

    def _predict_dataloader(self, dataloader):
        """Provide anomaly scores and maps for every batch of a dataloader.

        Returns ``(scores, masks, labels_gt, masks_gt)``. The ground-truth
        lists are only filled for dict batches, from their ``"is_anomaly"``
        and ``"mask"`` entries.
        """
        _ = self.forward_modules.eval()

        scores, masks = [], []
        labels_gt, masks_gt = [], []
        with tqdm.tqdm(dataloader, desc="Inferring...", leave=False) as progress:
            for batch in progress:
                if isinstance(batch, dict):
                    labels_gt += batch["is_anomaly"].numpy().tolist()
                    masks_gt += batch["mask"].numpy().tolist()
                    batch = batch["image"]
                batch_scores, batch_masks = self._predict(batch)
                pairs = list(zip(batch_scores, batch_masks))
                scores.extend(score for score, _ in pairs)
                masks.extend(mask for _, mask in pairs)
        return scores, masks, labels_gt, masks_gt

    def _predict(self, images):
        """Infer the image-level score and anomaly mask for a batch of images.

        The scored feature of each patch is ``(raw - sub) * 2``, where ``raw``
        comes from ``_embed`` and ``sub`` from ``_embed_with_sub_model``.
        Returns two lists: one score per image and one mask per image.
        """
        images = images.to(torch.float).to(self.device)
        _ = self.forward_modules.eval()

        n_images = images.shape[0]
        with torch.no_grad():
            raw_rows, _ = self._embed(images, provide_patch_shapes=True)
            sub_rows, patch_shapes = self._embed_with_sub_model(
                images, provide_patch_shapes=True
            )
            residuals = np.asarray(
                [(raw - sub) * 2 for raw, sub in zip(raw_rows, sub_rows)]
            )

            raw_scores = self.anomaly_scorer.predict([residuals])[0]

            # Image-level score: regroup per image, then reduce with
            # PatchMaker.score.
            image_scores = self.patch_maker.unpatch_scores(
                raw_scores, batchsize=n_images
            )
            image_scores = image_scores.reshape(*image_scores.shape[:2], -1)
            image_scores = self.patch_maker.score(image_scores)

            # Pixel-level map: regroup per image and lay the patch scores out
            # on the grid of the first feature map.
            grid = patch_shapes[0]
            patch_scores = self.patch_maker.unpatch_scores(
                raw_scores, batchsize=n_images
            )
            patch_scores = patch_scores.reshape(n_images, grid[0], grid[1])

            masks = self.anomaly_segmentor.convert_to_segmentation(patch_scores)

        return list(image_scores), list(masks)

    @staticmethod
    def _params_file(filepath, prepend=""):
        """Return the path of the pickled parameter file inside ``filepath``."""
        file_name = f"{prepend}patchcore_params.pkl"
        return os.path.join(filepath, file_name)

    def save_to_path(self, save_path: str, prepend: str = "") -> None:
        """Persist the anomaly scorer and the constructor parameters.

        The scorer saves itself into ``save_path``; the parameters are
        pickled next to it in the file named by ``_params_file``.
        """
        LOGGER.info("Saving PatchCore data.")
        self.anomaly_scorer.save(
            save_path, save_features_separately=False, prepend=prepend
        )

        preprocessing = self.forward_modules["preprocessing"]
        preadapt_aggregator = self.forward_modules["preadapt_aggregator"]

        patchcore_params = {}
        patchcore_params["backbone.name"] = self.backbone.name
        patchcore_params["layers_to_extract_from"] = self.layers_to_extract_from
        patchcore_params["input_shape"] = self.input_shape
        patchcore_params["pretrain_embed_dimension"] = preprocessing.output_dim
        patchcore_params["target_embed_dimension"] = preadapt_aggregator.target_dim
        patchcore_params["patchsize"] = self.patch_maker.patchsize
        patchcore_params["patchstride"] = self.patch_maker.stride
        # NOTE(review): this key is spelled "anomaly_scorer_num_nn", while the
        # matching keyword of `load` is `anomaly_score_num_nn`.
        patchcore_params["anomaly_scorer_num_nn"] = (
            self.anomaly_scorer.n_nearest_neighbours
        )

        params_path = self._params_file(save_path, prepend)
        with open(params_path, "wb") as save_file:
            pickle.dump(patchcore_params, save_file, pickle.HIGHEST_PROTOCOL)

    def load_from_path(
        self,
        load_path: str,
        device: torch.device,
        nn_method=None,
        prepend: str = "",
    ) -> None:
        """Restore a model written by ``save_to_path``.

        Args:
            load_path: directory containing the pickled parameters and the
                saved anomaly scorer.
            device: device the restored model is placed on.
            nn_method: nearest-neighbour backend; a fresh
                ``patchcore.common.FaissNN(False, 4)`` when ``None``.
            prepend: file-name prefix that was used when saving.

        Raises:
            KeyError: if the parameter file has no ``"backbone.name"`` entry.
        """
        LOGGER.info("Loading and initializing PatchCore.")
        if nn_method is None:
            nn_method = patchcore.common.FaissNN(False, 4)

        # NOTE: unpickling can execute arbitrary code; only load parameter
        # files that come from a trusted source.
        with open(self._params_file(load_path, prepend), "rb") as load_file:
            patchcore_params = pickle.load(load_file)

        backbone_name = patchcore_params.pop("backbone.name")
        backbone = patchcore.backbones.load(backbone_name)
        backbone.name = backbone_name
        patchcore_params["backbone"] = backbone

        # save_to_path stores the neighbour count as "anomaly_scorer_num_nn",
        # but `load` expects `anomaly_score_num_nn`; without this translation
        # the saved value disappears into **kwargs and the default is used.
        saved_num_nn = patchcore_params.pop("anomaly_scorer_num_nn", None)
        if saved_num_nn is not None:
            patchcore_params.setdefault("anomaly_score_num_nn", saved_num_nn)

        self.load(**patchcore_params, device=device, nn_method=nn_method)

        self.anomaly_scorer.load(load_path, prepend)


# Image handling classes.
class PatchMaker:
    """Cuts feature maps into square patches and pools per-patch scores."""

    def __init__(self, patchsize, stride=None):
        self.stride = stride
        self.patchsize = patchsize

    def patchify(self, features, return_spatial_info=False):
        """Convert a feature map into a tensor of its patches.

        Args:
            features: tensor of shape (batch, channels, height, width).
            return_spatial_info: if True, also return the patch grid size.
        Returns:
            Tensor of shape (batch, n_patches, channels, patchsize,
            patchsize); with ``return_spatial_info`` a tuple of that tensor
            and ``[n_patches_along_height, n_patches_along_width]``.
        """
        pad = int((self.patchsize - 1) / 2)
        patches = torch.nn.functional.unfold(
            features,
            kernel_size=self.patchsize,
            dilation=1,
            padding=pad,
            stride=self.stride,
        )
        # Number of sliding-window positions along each spatial axis.
        grid = [
            int((extent + 2 * pad - 1 * (self.patchsize - 1) - 1) / self.stride + 1)
            for extent in features.shape[-2:]
        ]
        # unfold yields (batch, channels * p * p, n_patches); split the
        # channel/kernel axis and move the patch axis next to the batch axis.
        patches = patches.reshape(
            *features.shape[:2], self.patchsize, self.patchsize, -1
        ).permute(0, 4, 1, 2, 3)

        return (patches, grid) if return_spatial_info else patches

    def unpatch_scores(self, x, batchsize):
        """Split the leading axis of ``x`` into (batchsize, patches per item)."""
        trailing = tuple(x.shape[1:])
        return x.reshape((batchsize, -1) + trailing)

    def score(self, x):
        """Reduce every axis but the first by repeatedly taking the maximum.

        Numpy input gives numpy output; tensors are returned as tensors.
        """
        came_as_numpy = isinstance(x, np.ndarray)
        if came_as_numpy:
            x = torch.from_numpy(x)
        for _ in range(x.ndim - 1):
            x = torch.max(x, dim=-1).values
        return x.numpy() if came_as_numpy else x


In [None]:
import numpy as np
import torch.utils.data
from torch import nn
from torchvision import models
import sys, os
# print(os.getcwd())
sys.path.append('./src')
os.environ['KMP_DUPLICATE_LIB_OK']='True'
import patchcore
from patchcore import patchcore as patchcore_model

from siamese import SiameseNetwork

from torch.utils.data import TensorDataset, ConcatDataset, DataLoader

# Maps a dataset name to [module path, dataset class name]; read by `dataset`.
_DATASETS = {"mvtec": ["patchcore.datasets.mvtec", "MVTecDataset"]}
# Sub-dataset (category) names used as the default `subdatasets` of `dataset`.
_sub_dataset = ('bottle',  'cable',  'capsule',  'carpet',  'grid', 'hazelnut', 'leather',  'metal_nut',  'pill', 'screw', 'tile', 'toothbrush', 'transistor', 'wood', 'zipper')
# Default `data_path` of `dataset`.
# NOTE(review): machine-specific absolute path — adjust for the local setup.
_data_path = r'D:\Non_Documents\9.Document\Dataset\mvtech_anomaly_detection\mvtech_anomaly_detection'

def _standard_patchcore(image_dimension, sub_model=True):
    """Build a CPU PatchCore instance on a wide_resnet50_2 backbone.

    Args:
        image_dimension: side length of the square RGB input images.
        sub_model: forwarded to ``PatchCore.load``. Pass False when only the
            backbone features are needed, e.g. while the siamese checkpoint
            does not exist yet.

    Returns:
        The loaded PatchCore instance.
    """
    cpu = torch.device("cpu")
    patchcore_instance = patchcore_model.PatchCore(cpu)
    # NOTE(review): pretrained=False gives a randomly initialised backbone —
    # confirm that this is intended outside of smoke tests.
    backbone = models.wide_resnet50_2(pretrained=False)
    backbone.name, backbone.seed = "wideresnet50", 0
    patchcore_instance.load(
        backbone=backbone,
        layers_to_extract_from=["layer2", "layer3"],
        device=cpu,
        input_shape=[3, image_dimension, image_dimension],
        pretrain_embed_dimension=1024,
        target_embed_dimension=1024,
        patchsize=3,
        patchstride=1,
        # NOTE(review): `spade_nn` is not a named parameter of the `load`
        # shown in this file and would be absorbed by its **kwargs — confirm
        # whether `anomaly_score_num_nn` was meant.
        spade_nn=2,
        sub_model=sub_model,
    )
    return patchcore_instance

def _dummy_constant_dataloader(number_of_examples, shape_of_examples):
    """Wrap the constant dummy examples in a ``DataLoader`` with batch size 1."""
    return torch.utils.data.DataLoader(
        _dummy_features(number_of_examples, shape_of_examples),
        batch_size=1,
    )


def _dummy_features(number_of_examples, shape_of_examples):
    return torch.Tensor(
        np.stack(number_of_examples * [np.ones(shape_of_examples)], axis=0)
    )

def dataset(
    name="mvtec",
    data_path=_data_path,
    subdatasets=_sub_dataset,
    train_val_split=1,
    batch_size=2,
    resize=256,
    imagesize=224,
    num_workers=8,
    augment=False,
):
    """Prepare a factory that builds the dataloaders of every sub-dataset.

    The dataset class is looked up through ``_DATASETS[name]``. Returns the
    tuple ``("get_dataloaders", get_dataloaders)``; calling
    ``get_dataloaders(seed)`` yields one dict per sub-dataset with the keys
    ``"training"``, ``"validation"`` and ``"testing"``. The validation entry
    is ``None`` unless ``train_val_split < 1``.
    """
    dataset_info = _DATASETS[name]
    module_name = dataset_info[0]
    class_name = dataset_info[1]
    dataset_library = __import__(module_name, fromlist=[class_name])

    def _make_loader(split_dataset):
        # Every split shares the same, non-shuffling loader configuration.
        return torch.utils.data.DataLoader(
            split_dataset,
            batch_size=batch_size,
            shuffle=False,
            num_workers=num_workers,
            pin_memory=True,
        )

    def get_dataloaders(seed=0):
        per_subdataset = []
        for subdataset in subdatasets:
            dataset_cls = dataset_library.__dict__[class_name]
            shared = dict(
                classname=subdataset,
                resize=resize,
                imagesize=imagesize,
                seed=seed,
            )

            train_dataset = dataset_cls(
                data_path,
                train_val_split=train_val_split,
                split=dataset_library.DatasetSplit.TRAIN,
                augment=augment,
                **shared,
            )
            test_dataset = dataset_cls(
                data_path,
                split=dataset_library.DatasetSplit.TEST,
                **shared,
            )

            train_dataloader = _make_loader(train_dataset)
            test_dataloader = _make_loader(test_dataset)

            # Only the training loader carries a name.
            if subdataset is None:
                train_dataloader.name = name
            else:
                train_dataloader.name = name + "_" + subdataset

            val_dataloader = None
            if train_val_split < 1:
                val_dataset = dataset_cls(
                    data_path,
                    train_val_split=train_val_split,
                    split=dataset_library.DatasetSplit.VAL,
                    **shared,
                )
                val_dataloader = _make_loader(val_dataset)

            per_subdataset.append(
                {
                    "training": train_dataloader,
                    "validation": val_dataloader,
                    "testing": test_dataloader,
                }
            )
        return per_subdataset

    return ("get_dataloaders", get_dataloaders)

def test_dummy_patchcore(image_path):
    """Smoke-test fitting PatchCore on four constant dummy images.

    ``image_path`` is accepted but not used.
    """
    image_dimension = 112
    example_shape = [3, image_dimension, image_dimension]

    model = _standard_patchcore(image_dimension)
    training_dataloader = _dummy_constant_dataloader(4, example_shape)
    print(model.featuresampler)
    model.fit(training_dataloader)

    # Prediction checks, currently disabled:
    # test_features = torch.Tensor(2 * np.ones([2, 3, image_dimension, image_dimension]))
    # scores, masks = model.predict(test_features)

    # assert all([score > 0 for score in scores])
    # for mask in masks:
    #     assert np.all(mask.shape == (image_dimension, image_dimension))

def reset_params(model):
    """Re-initialise each direct child of ``model`` that offers
    ``reset_parameters``; deeper descendants are not visited."""
    resettable = (
        child for child in model.children() if hasattr(child, 'reset_parameters')
    )
    for child in resettable:
        child.reset_parameters()

if __name__ == "__main__":
    # Trains the siamese sub-model on backbone features of the training
    # images and stores its weights for later use.
    epoch_num = 3
    checkpoint_path = './sub_model/siamese_wide_resnet50_2_final.pth'
    # Keeps the reciprocal loss finite when both outputs coincide.
    loss_eps = 1e-8

    network = SiameseNetwork(in_dimension=1536).to(device='cuda')
    criterion = nn.MSELoss()
    optimizer = torch.optim.Adam(network.parameters(), lr=0.001)
    _, x = dataset(num_workers=1, train_val_split=1)
    _dataset = x(0)
    image_dimension = 224
    # NOTE(review): `_standard_patchcore` calls `PatchCore.load`, whose
    # default `sub_model=True` loads the very checkpoint this script writes;
    # confirm how a first run is expected to bootstrap.
    model = _standard_patchcore(image_dimension)

    dt_training = ConcatDataset([x['training'].dataset for x in _dataset])
    dtloader_training = DataLoader(dt_training, batch_size=2, shuffle=True)

    for epoch in range(epoch_num):
        print(f"=========== epoch {epoch} ===========")
        network.train()
        for i, image in enumerate(dtloader_training):
            # The network compares the two samples of a batch, so a trailing
            # single-sample batch is skipped before any features are computed.
            if image["image"].shape[0] < 2:
                continue

            feature = model.get_feature(image["image"])
            img0, img1 = feature[0].cuda(), feature[1].cuda()

            optimizer.zero_grad()
            output1, output2 = network(img0, img1)
            distance = criterion(output1, output2)
            if image['classname'][0] == image['classname'][1]:
                # Same class: pull the two embeddings together.
                loss = distance
            else:
                # Different classes: push apart via the reciprocal distance.
                loss = 1 / (distance + loss_eps)
            loss.backward()
            optimizer.step()
            if i % 5 == 0:
                print(f"Epoch [{epoch+1}/{epoch_num}], Step [{i}/{len(dtloader_training)}], Loss: {loss.item():.4f}")

    # Make sure the target directory exists so the trained weights are not
    # lost to a missing-directory error at the very end.
    os.makedirs(os.path.dirname(checkpoint_path), exist_ok=True)
    torch.save(network.state_dict(), checkpoint_path)
        # # valid
        # for id, data in enumerate(_dataset):
        #     if id == 0:
        #         for i, image in enumerate(data['validation']):
        #             if (image[0].shape[0] < 2): continue
        #             feature = model.get_feature(image["image"])
        #             img0, img1 = feature[0], feature[1]
        #             img0, img1 = img0.cuda(), img1.cuda()

        #             optimizer.zero_grad()
        #             output1, output2 = network(img0, img1)
        #             loss = criterion(output1, output2)
        #             loss.backward()
        #             optimizer.step()
        #             if i % 2 == 0:
        #                 print(f"Epoch [{epoch+1}/{2}], Step [{i}/{len(data['training'])}], Loss: {loss.item():.4f}")



In [None]:
import torch
from torch import nn

from torchsummary import summary

class SiameseNetwork(nn.Module):
    """Fuses two feature maps and re-weights the result per channel.

    ``forward_once`` upsamples the second map by a factor of two,
    concatenates it with the first along the channel axis, applies
    conv -> batch norm -> ReLU and scales every output channel by a sigmoid
    gate computed from the globally pooled activations.

    Args:
        in_dimension: channel count of the concatenated input.
        out_dimension: channel count of the output.
    """

    def __init__(self, in_dimension=512, out_dimension=1024):
        super().__init__()

        self.out_dimension = out_dimension

        self.upsampling = nn.UpsamplingBilinear2d(scale_factor=2)

        self.cnn = nn.Conv2d(
            in_channels=in_dimension,
            out_channels=out_dimension,
            kernel_size=3,
            padding='same',
        )
        self.batchnorm = nn.BatchNorm2d(out_dimension)
        self.activation = nn.ReLU()

        # The gate consumes the pooled conv output (out_dimension values) and
        # must produce one weight per output channel, so both layer widths
        # follow out_dimension. The input widths are known up front, so eager
        # layers are used and no parameter is left uninitialised before the
        # first forward pass.
        self.linear_sub_1 = nn.Linear(in_features=out_dimension, out_features=256)
        self.linear_sub_2 = nn.Linear(in_features=256, out_features=out_dimension)
        self.global_pooling = nn.AdaptiveAvgPool2d(output_size=(1, 1))
        self.relu = nn.ReLU()
        self.sigmoid = nn.Sigmoid()

    def forward_once(self, x1, x2):
        """Fuse ``x1`` with the 2x-upsampled ``x2``.

        Args:
            x1: tensor (batch, c1, h, w).
            x2: tensor (batch, c2, h / 2, w / 2) with
                ``c1 + c2 == in_dimension``.

        Returns:
            Tensor (batch, out_dimension, h, w).
        """
        x = self.upsampling(x2)
        x = torch.cat((x1, x), dim=1)
        x = self.cnn(x)
        x = self.batchnorm(x)
        x = self.activation(x)

        pooled = self.global_pooling(x)

        # Move the channel axis last so the linear layers act on it, then
        # move it back to broadcast against x.
        attention_score = torch.transpose(pooled, -3, -1)
        attention_score = self.linear_sub_1(attention_score)
        attention_score = self.relu(attention_score)
        attention_score = self.linear_sub_2(attention_score)
        attention_score = self.sigmoid(attention_score)
        attention_score = torch.transpose(attention_score, -3, -1)

        return torch.mul(x, attention_score)

    def forward(self, input1, input2):
        """Run ``forward_once`` separately on samples 0 and 1 of the batch.

        Only the first two samples are used; each is processed as a batch of
        one. Returns the pair ``(output_of_sample_0, output_of_sample_1)``.
        """
        output1 = self.forward_once(
            torch.unsqueeze(input1[0, :, :, :], dim=0),
            torch.unsqueeze(input2[0, :, :, :], dim=0),
        )
        output2 = self.forward_once(
            torch.unsqueeze(input1[1, :, :, :], dim=0),
            torch.unsqueeze(input2[1, :, :, :], dim=0),
        )
        return output1, output2


# network = SiameseNetwork(1536, 1024).to(device='cuda')
# summary(network, ((1536, 14, 14), (1536, 14, 14)))
# tensor = torch.randn(4, 1536, 14, 14).to(device='cuda')
# print(network(tensor, tensor).shape)