In [None]:
import sys; sys.path.append('..')
from mvtecad_test import *
from fastai.callbacks import SaveModelCallback
from PIL import ImageFilter

# Optional: uncomment to call fastai_progress_as_text() (presumably switches
# training progress output to plain text, e.g. for headless runs).
# fastai_progress_as_text()

import os

# Visualization toggle.
VISUALIZE = True
# Dataset root. Set the MVTEC_AD_PATH environment variable to point elsewhere
# instead of editing the notebook; an unset/empty variable falls back to the
# original location.
PATH = Path(os.environ.get('MVTEC_AD_PATH') or '/mnt/dataset/mvtec_ad')

class DefectOnBlobImageList(AnomalyTwinImageList):
    """Anomaly-twin image list that draws a synthetic scratch on the object.

    The scratch starts at a random point on the object "blob" and is drawn as
    a straight line of random colour, width and length. The point is chosen as
    follows:
    - pick a random column containing pixels whose summed absolute vertical
      intensity change (across channels) exceeds ``BLOB_TH``;
    - pick a random row between the top-most and bottom-most such pixel in
      that column.

    Class attributes (shared by all instances; change them via ``set_params``):
        BLOB_TH:    threshold for the blob-edge detection described above.
        WIDTH_MIN:  minimum line width in pixels (inclusive).
        WIDTH_MAX:  maximum line width in pixels (inclusive).
        LENGTH_MAX: maximum offset of the line end point along each axis.
        COLOR:      if False, the line colour is forced to grey.
    """
    BLOB_TH = 20
    WIDTH_MIN = 1
    WIDTH_MAX = 14
    LENGTH_MAX = 30
    COLOR = True

    @classmethod
    def set_params(cls, blob_th=20, width_min=1, width_max=14, length=225//5, color=True):
        """Set the defect-generation parameters on the class (affects all instances).

        Args:
            blob_th: new ``BLOB_TH``.
            width_min, width_max: inclusive line-width range in pixels.
            length: new ``LENGTH_MAX``. NOTE(review): the default ``225//5``
                (= 45) differs from the class default of 30 — confirm intent.
            color: new ``COLOR``.
        """
        cls.BLOB_TH = blob_th
        cls.WIDTH_MIN, cls.WIDTH_MAX = width_min, width_max
        cls.LENGTH_MAX, cls.COLOR = length, color

    def _random_blob_point(self, np_img):
        """Return a random (x, y) on the object blob of the array ``np_img``.

        Falls back to a uniformly random pixel when no pixel exceeds
        ``BLOB_TH`` (e.g. a flat image), instead of failing on an empty
        selection.
        """
        if np_img.ndim == 2:
            # Single-channel image: add a channel axis so the sum below works.
            np_img = np_img[..., np.newaxis]
        edge = np.sum(np.abs(np.diff(np_img, axis=0)), axis=2) > self.BLOB_TH
        ys, xs = np.where(edge)
        if xs.size == 0:
            return random.randrange(np_img.shape[1]), random.randrange(np_img.shape[0])
        x = random.choice(xs)
        ys_x = ys[xs == x]
        # Any row between the top-most and bottom-most edge pixel of this column.
        y = random.randint(ys_x.min(), ys_x.max())
        return x, y

    def anomaly_twin(self, image):
        """Draw a random scratch on ``image`` and return it.

        A smoothed copy of ``image`` is used only to locate the blob; the line
        itself is drawn onto ``image`` in place via ``ImageDraw``.

        Args:
            image: PIL image to draw on.
        Returns:
            The same ``image`` object, with the scratch drawn on it.
        """
        np_img = np.array(image.filter(ImageFilter.SMOOTH)).astype(np.float32)
        scar_max = self.LENGTH_MAX
        # SIZE is not defined in this class; presumably set by AnomalyTwinImageList.
        half = self.SIZE // 2
        x, y = self._random_blob_point(np_img)
        dx, dy = random.randint(0, scar_max), random.randint(0, scar_max)
        # The end point is offset toward the image centre on each axis.
        x2 = x + dx if x < half else x - dx
        y2 = y + dy if y < half else y - dy
        # randint is inclusive on both ends: cap at 255 to stay in the 8-bit range.
        # NOTE(review): green is capped at 224 (kept from the original) — confirm intent.
        c = (random.randint(0, 255), random.randint(0, 224), random.randint(0, 255))
        if not self.COLOR:
            c = (c[0], c[0], c[0])
        w = random.randint(self.WIDTH_MIN, self.WIDTH_MAX)
        ImageDraw.Draw(image).line((x, y, x2, y2), fill=c, width=w)
        return image

# Configure the synthetic-defect generator BEFORE building the test harness.
# MVTecADTest takes a `skip_data_creation` option, which suggests it may
# generate the artificial images during construction; if so, the parameters
# must already be in place. set_params only writes class attributes, so
# calling it first is safe either way.
DefectOnBlobImageList.set_params(blob_th=20, width_min=1, width_max=16, length=30, color=False)

# Further MVTec AD categories (e.g. 'capsule') can be added to `testcases`.
# A `skip_data_creation=True` argument is also accepted (presumably to reuse
# previously generated data — confirm).
mvtecad = MVTecADTest(PATH, artificial_image_list_cls=DefectOnBlobImageList,
                      testcases=['screw'], img_size=224)

In [None]:
# Build a databunch for the first test case and (optionally) preview a batch
# of real and generated samples.
mvtecad.set_test(0, 0)
data = mvtecad.databunch()
if VISUALIZE:
    data.show_batch()

In [None]:
def learner_ArcFace(data, arch=None, margin=0.5, head_epochs=(10, 10), head_lrs=(1e-2, 1e-4),
                    finetune_epochs=20, finetune_lr=slice(1e-6, 1e-5)):
    """Build and train a CNN learner whose model is wrapped by XFaceNet with ArcMarginProduct.

    Training schedule:
    - one one-cycle run per (epochs, lr) pair in ``head_epochs``/``head_lrs``;
    - ``unfreeze()``;
    - a final one-cycle run with ``SaveModelCallback`` attached.

    All defaults reproduce the original hard-coded schedule.

    Args:
        data: DataBunch to train on.
        arch: backbone passed to ``cnn_learner``; ``None`` means ``models.resnet34``.
        margin: value passed as ``m`` to ``XFaceNet``.
        head_epochs: epochs of each pre-unfreeze one-cycle run.
        head_lrs: ``max_lr`` of each pre-unfreeze run (paired with ``head_epochs``).
        finetune_epochs: epochs of the run after ``unfreeze()``.
        finetune_lr: ``max_lr`` of that run (a ``slice`` gives per-group rates).
    Returns:
        The trained Learner.
    Raises:
        ValueError: if ``head_epochs`` and ``head_lrs`` differ in length.
    """
    if arch is None:
        arch = models.resnet34
    if len(head_epochs) != len(head_lrs):
        raise ValueError('head_epochs and head_lrs must have the same length')
    learn = cnn_learner(data, arch, metrics=accuracy)
    # NOTE(review): the model is replaced after cnn_learner built its layer
    # groups — confirm the new head's parameters are seen by the optimizer.
    learn.model = XFaceNet(learn.model, data, ArcMarginProduct, m=margin)
    learn.callback_fns.append(partial(LabelCatcher))
    for n_epochs, lr in zip(head_epochs, head_lrs):
        learn.fit_one_cycle(n_epochs, max_lr=lr)
    learn.unfreeze()
    learn.fit_one_cycle(finetune_epochs, max_lr=finetune_lr,
                        callbacks=[SaveModelCallback(learn)])
    return learn

# Registry of models to evaluate: display name -> learner factory.
model_defs = dict(ArcFace=learner_ArcFace)

# Run every registered model against every configured test case in turn.
for model_name, make_learner in model_defs.items():
    for case_idx, _testcase in enumerate(mvtecad.testcases):
        mvtecad.set_test(case_idx, 0)
        mvtecad.test(model_name, make_learner, vis_class=None)

In [None]:
# Summarise the results collected in `mvtecad` by the test runs above,
# presumably in the layout of Table 2 of the reference paper (the helper is not
# defined here; presumably it comes from the mvtecad_test star import).
# NOTE(review): confirm what `reorder=False` controls.
paper_table2_compatible_result(mvtecad, reorder=False)