In [None]:
from fastai.basics import *
from fastai.vision.core import *
from fastai.vision.data import *
from fastai.vision.augment import *
from fastai.vision.models.unet import *
from fastai.vision.learner import *
from fastai.vision.models import *
from fastai.callback.wandb import *
import wandb 

In [None]:
# Start a Weights & Biases run under the "mvtec" project.
# NOTE(review): presumably this must run before `WandbCallback` is attached to the learner — confirm.
wandb.init(project="mvtec")

Dataset: https://www.mvtec.com/company/research/datasets/mvtec-ad

In [None]:
# Direct download link for the MVTec AD archive (a `.tar.xz` file).
data_url='https://www.mydrive.ch/shares/38536/3830184030e49fe74747669442f0f282/download/420938113-1629952094/mvtec_anomaly_detection.tar.xz'

In [None]:
# Extraction root for the dataset, inside the current user's `.fastai/data` folder.
# Built from `Path.home()` instead of a hard-coded `/home/<user>` prefix so the
# notebook is not tied to one account or machine.
dest = Path.home()/'.fastai'/'data'/'mvtec'

In [None]:
# fastai's stock extractor doesn't recognize .xz archives, so supply our own.
def file_extract(fname, dest=None):
    """Extract the archive `fname` into `dest` using `tarfile` or `zipfile`.

    fname: path (str or Path) to an archive whose name ends in `gz`, `xz` or `zip`.
    dest:  target directory; defaults to the directory that contains `fname`.
    Raises `Exception` when the file name matches none of the known endings.
    """
    import tarfile, zipfile  # local import: do not rely on a star-import to provide these
    if dest is None: dest = Path(fname).parent
    fname = str(fname)
    # Context managers close the archive handle even when extraction fails;
    # the bare `open(...).extractall(...)` form never closed it.
    # NOTE: `extractall` trusts the member paths stored in the archive, so only
    # use this on downloads from a trusted source.
    if fname.endswith('gz'):
        with tarfile.open(fname, 'r:gz') as archive: archive.extractall(dest)
    elif fname.endswith('xz'):
        with tarfile.open(fname, 'r:xz') as archive: archive.extractall(dest)
    elif fname.endswith('zip'):
        with zipfile.ZipFile(fname) as archive: archive.extractall(dest)
    else: raise Exception(f'Unrecognized archive: {fname}')

In [None]:
# Fetch and unpack the dataset with the xz-aware extractor, then step up one
# level from the path `untar_data` returns; the category folders (e.g. `carpet`)
# are addressed relative to this directory.
# NOTE(review): the original author suspected that dashes in the name break this — unconfirmed.
path=untar_data(data_url,dest=dest,extract_func=file_extract).parent
path.absolute()

In [None]:
# Test split of the carpet category; every sub-folder except `good` is treated
# as a defect folder.
carpet_path=path/'carpet'/'test'
# `.name`, not `.stem`: these entries are directories, so a dot in a folder name
# is not a file extension and must not be stripped before matching on it.
defect_folders=carpet_path.ls().map(lambda p:p.name).filter(lambda name:name!='good')
image_files=get_image_files(carpet_path,folders=defect_folders)

In [None]:
# Spot-check one of the collected defect images.
Image.open(image_files[10])

In [None]:
def label_func(p):
    """Return the ground-truth mask path for the test image `p`.

    For `<root>/test/<defect>/<stem><suffix>` this yields
    `<root>/ground_truth/<defect>/<stem>_mask<suffix>`.
    """
    defect_dir = p.parent
    category_root = defect_dir.parent.parent
    # `.name`, not `.stem`: the defect folder is a directory, so a dot in its
    # name is not a file extension and must be kept.
    return category_root/'ground_truth'/defect_dir.name/f'{p.stem}_mask{p.suffix}'

The ground-truth masks contain only two pixel values, 0 and 255. Passing `div_mask=255.` to `IntToFloatTensor` divides the mask by 255, so its values become the class indices 0 and 1 that PyTorch expects.

In [None]:
# Patch `SpaceTfm.encodes` for `TensorImage` inputs so that the call to
# `self.space_fn` (given `x` and the composed `self.fs` transforms) runs inside
# `torch.no_grad()`.
# NOTE(review): presumably a workaround for autograd tracking in the
# Saturation/Hue transforms used below — confirm it is still needed with the
# installed fastai version.
@patch 
def encodes(self:SpaceTfm,x:TensorImage): 
    with torch.no_grad(): return self.space_fn(x,partial(compose_tfms, tfms=self.fs))

In [None]:
# Augmentation list for the first training phase: flips (including vertical),
# rotation, zoom, lighting and warp settings, plus extra Saturation and Hue
# transforms, with output size 512x512. It is appended to `batch_tfms` of the
# first DataBlock.
image_tfms=aug_transforms(mult=1.0,
    do_flip=True,
    flip_vert=True,
    max_rotate=10.0,
    size=(512,512),
    min_zoom=1.0,
    max_zoom=2.,
    max_lighting=0.2,
    max_warp=0.2,
    p_affine=0.75,
    p_lighting=0.75,
    xtra_tfms=[Saturation(max_lighting=0.1, p=0.75),
               Hue(max_hue=0.1, p=0.75)],
    mode='bilinear',
    pad_mode='reflection',
    align_corners=True,
    batch=False,
    min_scale=1.0,)

In [None]:
# Segmentation DataBlock: image inputs and masks with the classes ['good','bad'].
# Items are the images in the defect folders; each mask is located by `label_func`.
# `div_mask=255.` rescales the stored mask values (0/255) for training.
mvtec = DataBlock(blocks=(ImageBlock, MaskBlock(['good','bad'])),
                   get_items = partial(get_image_files,folders=defect_folders),
                   get_y = label_func,
                   # Fixed seed so the train/valid split is reproducible across kernel restarts.
                   splitter=RandomSplitter(seed=42),
                   batch_tfms=image_tfms+[IntToFloatTensor(div_mask=255.) ])

In [None]:
# List the ground-truth mask folders that sit next to the test split. Derived
# from `carpet_path` (the same layout `label_func` relies on) instead of a
# hard-coded absolute path under one user's home directory.
(carpet_path.parent/'ground_truth').ls()

In [None]:
# Build the DataLoaders from the carpet test folder with a batch size of 8.
dls = mvtec.dataloaders(carpet_path,bs=8)

In [None]:
# Display a sample batch as a visual sanity check of the images and their masks.
dls.show_batch()

In [None]:
# U-Net learner on a resnet34 backbone, with the W&B callback attached and
# converted to fp16 via `to_fp16()`.
learn = unet_learner(dls, resnet34,cbs=[WandbCallback]).to_fp16()
# Phase 1 trains with the learner frozen.
# NOTE(review): presumably this keeps the pretrained encoder weights fixed — confirm.
learn.freeze()
#learn.lr_find()

In [None]:
# Phase 1: 8 epochs of one-cycle training on the frozen learner, peak learning rate 5e-4.
learn.fit_one_cycle(8,lr_max=0.0005)

In [None]:
# Phase 2 data: same items and labels, but with the default `aug_transforms`
# settings at 512x512 instead of the custom `image_tfms` list.
# Pin the validation set of the current `dls` before rebuilding. A fresh
# `RandomSplitter()` would draw a new split, so images the model already trained
# on in phase 1 could land in the new validation set and make the validation
# results look better than they are.
valid_items = set(dls.valid_ds.items)
mvtec = DataBlock(blocks=(ImageBlock, MaskBlock(['good','bad'])),
                   get_items = partial(get_image_files,folders=defect_folders),
                   get_y = label_func,
                   # True -> validation set, False -> training set.
                   splitter=FuncSplitter(lambda p: p in valid_items),
                   batch_tfms=aug_transforms(size=(512,512))+[IntToFloatTensor(div_mask=255.) ])
dls = mvtec.dataloaders(carpet_path,bs=8)
# Swap the new DataLoaders into the existing learner; the trained weights are kept.
learn.dls=dls

In [None]:
#learn.lr_find()

In [None]:
# Phase 2: unfreeze the learner and train 8 more one-cycle epochs at a lower
# peak learning rate (1e-4).
learn.unfreeze()
learn.fit_one_cycle(8,lr_max=0.0001)

In [None]:
# Show at most 2 example results from the trained model.
learn.show_results(max_n=2)