# In [None]:
from lightning import Trainer
import torch
from data_loader import MVTecDRAEMTestDataset, MVTecDRAEMTrainDataset
from torch.utils.data import DataLoader
import os
from models.Draem import DraemModel
import gc
# Free unreachable Python objects first so any CUDA tensors they hold are
# returned to PyTorch's caching allocator; empty_cache() then releases the
# now-unoccupied cached blocks. The reverse order leaves those blocks cached.
gc.collect()
torch.cuda.empty_cache()


def get_lr(optimizer):
    """Return the ``'lr'`` entry of the optimizer's first parameter group.

    Only the first group is inspected. If ``optimizer.param_groups`` is empty,
    ``None`` is returned.
    """
    first_group_lrs = (group['lr'] for group in optimizer.param_groups)
    return next(first_group_lrs, None)

def weights_init(m):
    """Initialise a single module's parameters; meant for ``model.apply(weights_init)``.

    * Class name containing ``Conv``: weight ~ N(0.0, 0.02).
    * Class name containing ``BatchNorm``: weight ~ N(1.0, 0.02), bias = 0.
    * Anything else is left untouched.

    Parameters that are ``None`` are skipped instead of raising. Examples are
    ``BatchNorm*(affine=False)`` and a module whose name merely contains
    ``Conv`` without owning a weight.
    """
    classname = type(m).__name__
    weight = getattr(m, "weight", None)
    bias = getattr(m, "bias", None)

    if "Conv" in classname:
        if weight is not None:
            weight.data.normal_(0.0, 0.02)
    elif "BatchNorm" in classname:
        if weight is not None:
            weight.data.normal_(1.0, 0.02)
        if bias is not None:
            bias.data.fill_(0)

def train_on_device(obj_names, args):
    """Train (or restore) a DRAEM model for each object category in ``obj_names``.

    For every name, one of two things happens:

    * If ``args.load_from_checkpoint`` is set, a ``DraemModel`` is restored from
      ``args.model_1_cpt`` / ``args.model_2_cpt``. It is paired with a
      ``max_epochs=0`` trainer, so no training runs.
    * Otherwise a fresh ``DraemModel`` is fitted for ``args.epochs`` epochs on
      ``<args.data_path><obj_name>/train/good/``. Anomalies are synthesised
      from ``args.anomaly_source_path``.

    Args:
        obj_names: Iterable of category directory names under ``args.data_path``.
        args: Namespace providing ``checkpoint_path``, ``log_path``,
            ``data_path``, ``anomaly_source_path``, ``bs``, ``lr``, ``epochs``,
            ``load_from_checkpoint``, ``model_1_cpt`` and ``model_2_cpt``.
            ``gpu_id`` is optional and defaults to 0.

    Returns:
        ``(trainer, model)`` for the last processed object, or
        ``(None, None)`` if ``obj_names`` is empty.
    """
    # exist_ok avoids the check-then-create pattern (and its race).
    os.makedirs(args.checkpoint_path, exist_ok=True)
    os.makedirs(args.log_path, exist_ok=True)

    # Lightning resolves an integer ``devices=1`` to GPU index 0, ignoring any
    # surrounding ``torch.cuda.device`` context, so pin the requested GPU.
    gpu_index = getattr(args, "gpu_id", 0)

    trainer, model = None, None
    for obj_name in obj_names:
        print(f'############## object_name: {args.data_path}/{obj_name}')

        if args.load_from_checkpoint:
            # max_epochs=0: nothing is trained; the restored model is returned.
            trainer = Trainer(accelerator="cuda", devices=[gpu_index], max_epochs=0)
            # NOTE(review): the evaluation cell constructs DraemModel with
            # ``load_check_models=[...]``; confirm which keyword set DraemModel
            # currently accepts.
            model = DraemModel(load_check=True, load_check_model=args.model_1_cpt,
                               load_check_model_seg=args.model_2_cpt)
            continue

        # The training data is only needed when actually fitting.
        dataset = MVTecDRAEMTrainDataset(args.data_path + obj_name + "/train/good/",
                                         args.anomaly_source_path,
                                         resize_shape=[256, 256],
                                         ignore_black_region=True)
        dataloader = DataLoader(dataset, batch_size=args.bs, shuffle=True, num_workers=1)

        model = DraemModel(lr=args.lr, epochs=args.epochs)
        trainer = Trainer(accelerator="cuda", devices=[gpu_index], max_epochs=args.epochs)
        trainer.fit(model, dataloader)

    # Returned after the loop so that every object is processed; previously the
    # return sat inside the loop and only the first object was handled.
    return trainer, model
    


import argparse

###########RUNNING ON SPLIT IMAGES
# Default hyper-parameters and paths. Each key becomes a ``--<key>`` option
# whose type is inferred from its default value.
default_args = {
    'obj_id': 0,
    'bs': 8,
    'lr': 0.0001,
    'epochs': 700,
    'gpu_id': 0,
    'data_path': '/home/data/',
    'anomaly_source_path': '/home/datasets/dtd/images/',
    'checkpoint_path': './checkpoints/',
    'log_path': './logs/',
    'visualize': False,
    'load_from_checkpoint': False,
    'model_1_cpt': '',
    'model_2_cpt': '',
}


def _str2bool(value):
    """Parse a command-line boolean.

    ``bool("False")`` is True, so ``type=bool`` cannot be used with argparse.
    """
    if isinstance(value, bool):
        return value
    lowered = value.strip().lower()
    if lowered in {"1", "true", "yes", "y", "on"}:
        return True
    if lowered in {"0", "false", "no", "n", "off"}:
        return False
    raise argparse.ArgumentTypeError(f"expected a boolean value, got {value!r}")


# Build the parser from the defaults above.
parser = argparse.ArgumentParser()
for arg_name, arg_value in default_args.items():
    arg_type = _str2bool if isinstance(arg_value, bool) else type(arg_value)
    parser.add_argument(f'--{arg_name}', action='store', type=arg_type, default=arg_value)

# An empty argv is parsed on purpose, so only the defaults are used. This is
# presumably because the code runs in a notebook, where sys.argv belongs to the
# kernel.
args = parser.parse_args([])


obj_batch = [
    ['cars'],
    ['capsule'],
    ['bottle'],
    ['carpet'],
    ['leather'],
    ['pill'],
    ['transistor'],
    ['tile'],
    ['cable'],
    ['zipper'],
    ['toothbrush'],
    ['metal_nut'],
    ['hazelnut'],
    ['screw'],
    ['grid'],
    ['wood'],
]

# NOTE(review): picked_classes is computed but not used; the call below trains
# on the hard-coded ['cars_all'] split instead. Confirm which is intended.
picked_classes = obj_batch[int(args.obj_id)]

# NOTE(review): Lightning's ``devices=1`` does not follow this context, so the
# GPU selection may not take effect here.
with torch.cuda.device(args.gpu_id):
    trainer, model = train_on_device(['cars_all'], args)


# In [None]:
from lightning import Trainer
import torch
from data_loader import MVTecDRAEMTestDataset, MVTecDRAEMTrainDataset
from torch.utils.data import DataLoader
from torch import optim
from tensorboard_visualizer import TensorboardVisualizer
from pytorch_lightning.loggers import TensorBoardLogger
from models.Draem import DraemModel

def test(obj_name, trainer, model, dim=256, data_path='/home/data/'):
    """Evaluate ``model`` on one category's test split via ``trainer.test``.

    Args:
        obj_name: Category directory name under ``data_path``.
        trainer: Lightning ``Trainer`` that runs the evaluation.
        model: Model passed to ``trainer.test``.
        dim: Square resize edge, forwarded as ``resize_shape=[dim, dim]``.
        data_path: Root directory containing ``<obj_name>/test/``. The default
            keeps the previously hard-coded location.

    Returns:
        The value returned by ``trainer.test`` (previously discarded).
    """
    import os

    # With the default root this yields '/home/data/<obj_name>/test/', exactly
    # as before. It also tolerates a root given without a trailing slash.
    test_dir = os.path.join(data_path, obj_name, "test", "")
    dataset = MVTecDRAEMTestDataset(test_dir, resize_shape=[dim, dim])

    # One image per batch, in a fixed order.
    dataloader = DataLoader(dataset, batch_size=1, shuffle=False, num_workers=0)

    return trainer.test(model, dataloader)


        




# Candidate evaluation splits, kept for reference. Only ``obj_list`` below is
# actually evaluated. (They were previously consecutive, overwritten
# assignments to ``obj_list``.)
CAR_SPLITS = [
    'cars_all',
    'cars_broken',
    'cars_broken_wind',
    'cars_dent',
    'cars_scratch',
    'cars_scratch_dent',
    'cars_wind',
]
DAMAGE_SPLITS = [
    "broken_part_dent",
    "broken_part_scratch",
    "broken_part_scratch_dent",
    "broken_part_scratch_wind_screen",
    "broken_part_wind_screen",
    "broken_part_wind_screen_dent",
    "dent",
    "misc",
    "scratch",
    "scratch_dent",
    "scratch_wind_screen",
    "scratch_wind_screen_dent",
    "wind_screen",
    "wind_screen_dent",
]

# Splits evaluated in this run. Alternatives: ['cars'], CAR_SPLITS, DAMAGE_SPLITS.
obj_list = ['broken_part']

dims = [256]

# VAEAttention checkpoints passed to DraemModel (main file + its ``_seg`` file).
CHECKPOINTS = [
    'checkpoints/cars_focused_lpips_VAEAttention_700.pckl',
    'checkpoints/cars_focused_lpips_VAEAttention_700_seg.pckl',
]

for obj in obj_list:
    for dim in dims:
        print(obj, '##################')
        logger = TensorBoardLogger("logs", name="focused_lpips")
        # max_epochs=0: the trainer is only used for trainer.test below.
        trainer = Trainer(accelerator="cuda",
                          devices=1,
                          max_epochs=0,
                          logger=logger)
        # The model is reloaded for every run, as before. Hoisting it out of
        # the loop would be cheaper, but first verify that trainer.test does not
        # mutate model state.
        model = DraemModel(load_check=True, load_check_models=list(CHECKPOINTS),
                           USE_MODEL='VAEAttention')
        test(obj, trainer, model, dim=dim)
        torch.cuda.empty_cache()