In [1]:
import torch
from torch.utils.data import Dataset
import torchvision.transforms.functional as TF
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
from torchvision import models
import json
from sklearn.model_selection import train_test_split
import warnings
import os
from PIL import Image
import matplotlib.pyplot as plt

from utils.dataloader import *
from models import Create_nets

In [2]:
def get_paths_mvtec_loco(args,verbose=True):
    

    anomaly_categories=args.dataset_parameters['anomaly_categories']
    category=args.data_category
    validation=args.dataset_parameters['use_validation']
    validation_split=args.dataset_parameters['validation_split']
    DATA_PATH=args.data_root
    
    NORMAL_PATH = os.path.join(DATA_PATH, f'{category}/train/good')
    VALIDATION_PATH = os.path.join(DATA_PATH, f'{category}/validation/good')
    ANOMALY_PATH = os.path.join(DATA_PATH , f'{category}/test')
    
    file_path = []
    for root, dirs, files in os.walk(ANOMALY_PATH):
        for file in files:
            file_path.append( os.path.join(root, file))
        
    anomaly_images_test=[item for item in file_path if "good" not in item]
    good_images_test=[item for item in file_path if "good" in item]
    
    normal_images=[os.path.join(NORMAL_PATH,item) for item in os.listdir(NORMAL_PATH)]
    validation_images=[os.path.join(VALIDATION_PATH,item) for item in os.listdir(VALIDATION_PATH)]

    
    n_samples = int((len(normal_images)+len(validation_images))*args.contamination_rate)
    valid_train_ratio=float(len(validation_images)/(len(normal_images)+len(validation_images)))
    
    
    
    sampled_anomalies_for_train, remaining_anomalies_test = stratified_sample(anomaly_images_test,anomaly_categories[category], n_samples, args.seed)

    if validation:
        warnings.warn(f"Vaidation split is set to {validation_split}, but the dataset is already split into train and validation sets by publisher. Ignoring validation split ratio.")
        sampled_anomalies_for_train, sampled_anomalies_for_val = train_test_split(sampled_anomalies_for_train, test_size=valid_train_ratio, random_state=args.seed)
    else:
        sampled_anomalies_for_val = []

    if verbose:
        print(f'category: {category}, normals train:  {len(normal_images)}, anomalies test: {len(anomaly_images_test)}, normal test: {len(good_images_test)}')       
        print(f'anomalies test total:     {count_files_by_class(anomaly_images_test, anomaly_categories[category])}')
        print(f'anomalies test sampled:   {count_files_by_class(sampled_anomalies_for_train, anomaly_categories[category])}')
        print(f'anomalies test remaining: {count_files_by_class(remaining_anomalies_test, anomaly_categories[category])}')
        
    return normal_images, validation_images, sampled_anomalies_for_train, sampled_anomalies_for_val, good_images_test, remaining_anomalies_test

In [3]:
def get_paths_mvtec(args,verbose=True,):
    
    anomaly_categories=args.dataset_parameters['anomaly_categories']
    category=args.data_category
    validation=args.dataset_parameters['use_validation']
    validation_split=args.dataset_parameters['validation_split']
    DATA_PATH=args.data_root
    
    
    NORMAL_PATH = os.path.join(DATA_PATH, f'{category}/train/good')
    ANOMALY_PATH = os.path.join(DATA_PATH , f'{category}/test')
    
        
    normal_images=[os.path.join(NORMAL_PATH,item) for item in os.listdir(NORMAL_PATH)]
    file_path = []
    for root, _, files in os.walk(ANOMALY_PATH):
        for file in files:
            file_path.append( os.path.join(root, file))
        
    anomaly_images_test=[item for item in file_path if "good" not in item]
    good_images_test=[item for item in file_path if "good" in item]
    

    
    n_samples = int(len(normal_images)*args.contamination_rate)
    
    sampled_anomalies_for_train, remaining_anomalies_test = stratified_sample(anomaly_images_test, anomaly_categories[category], n_samples, args.seed)

    if validation_split > 0:
        if validation!= True:
            raise ValueError('validation_split is set to > 0 but use_validation is set to False')
        normal_images,validation_images = train_test_split(normal_images, test_size=validation_split, random_state=args.seed)
        sampled_anomalies_for_train, sampled_anomalies_for_val = train_test_split(sampled_anomalies_for_train, test_size=validation_split, random_state=args.seed)
    else:
        sampled_anomalies_for_val = []
        validation_images = []


    if verbose:
        print(f'category: {category}, normals train:  {len(normal_images)}, anomalies test: {len(anomaly_images_test)}, normal test: {len(good_images_test)}')       
        print(f'anomalies test total:     {count_files_by_class(anomaly_images_test, anomaly_categories[category])}')
        print(f'anomalies test sampled:   {count_files_by_class(sampled_anomalies_for_train, anomaly_categories[category])}')
        print(f'anomalies test remaining: {count_files_by_class(remaining_anomalies_test, anomaly_categories[category])}')
    
    return normal_images, validation_images, sampled_anomalies_for_train, sampled_anomalies_for_val, good_images_test, remaining_anomalies_test

In [4]:
def get_paths_beantec(args,verbose=True,):
    
    anomaly_categories=args.dataset_parameters['anomaly_categories']
    category=args.data_category
    validation=args.dataset_parameters['use_validation']
    validation_split=args.dataset_parameters['validation_split']
    DATA_PATH=args.data_root
    
    NORMAL_PATH = os.path.join(DATA_PATH, f'{category}/train/ok')
    ANOMALY_PATH = os.path.join(DATA_PATH , f'{category}/test')
    
        
    normal_images=[os.path.join(NORMAL_PATH,item) for item in os.listdir(NORMAL_PATH)]
    file_path = []
    for root, _, files in os.walk(ANOMALY_PATH):
        for file in files:
            file_path.append( os.path.join(root, file))
        
    anomaly_images_test=[item for item in file_path if "ko" not in os.path.split(item)[0].split(os.sep)]
    good_images_test=[item for item in file_path if "ko" in os.path.split(item)[0].split(os.sep)]
    

    
    n_samples = int(len(normal_images)*args.contamination_rate)
    sampled_anomalies_for_train, remaining_anomalies_test = stratified_sample(anomaly_images_test, anomaly_categories[category], n_samples, args.seed)

    if validation_split > 0:
        if validation!= True:
            raise ValueError('validation_split is set to > 0 but use_validation is set to False')
        normal_images,validation_images = train_test_split(normal_images, test_size=validation_split, random_state=args.seed)
        sampled_anomalies_for_train, sampled_anomalies_for_val = train_test_split(sampled_anomalies_for_train, test_size=validation_split, random_state=args.seed)
    else:
        sampled_anomalies_for_val = []
        validation_images = []

    if verbose:
        print(f'category: {category}, normals train:  {len(normal_images)}, anomalies test: {len(anomaly_images_test)}, normal test: {len(good_images_test)}')       
        print(f'anomalies test total:     {count_files_by_class(anomaly_images_test, anomaly_categories[category])}')
        print(f'anomalies test sampled:   {count_files_by_class(sampled_anomalies_for_train, anomaly_categories[category])}')
        print(f'anomalies test remaining: {count_files_by_class(remaining_anomalies_test, anomaly_categories[category])}')
    
    return normal_images, validation_images, sampled_anomalies_for_train, sampled_anomalies_for_val, good_images_test, remaining_anomalies_test

In [5]:
class ImageDataset_mvtec(Dataset):
    def __init__(self, args, root, transforms_=None, mode='train', train_paths=None, test_paths=None):
        
        self.img_size = 280 * args.factor
        self.crop_size = 256 * args.factor
        self.args = args
        self.mode = mode
        if train_paths is None and test_paths is None:
            raise ValueError("either test or train paths must be provided depending on the mode")
        
        self.train_paths = train_paths
        self.test_paths = test_paths
        
        self.transform_train = transforms.Compose([ transforms.Resize((self.crop_size, self.crop_size), Image.BICUBIC),
                                                transforms.Pad(int(self.crop_size/10),fill=0,padding_mode='constant'),
                                                transforms.RandomRotation(10),
                                                transforms.RandomCrop((self.crop_size, self.crop_size)),
                                                transforms.ToTensor(),
                                                transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                                    std=[0.229, 0.224, 0.225 ]) ])
        
        self.resize_transform_loco = transforms.Resize((self.crop_size, self.crop_size), Image.BICUBIC)

        if mode == 'train':
            self.files = train_paths
        elif mode == 'test':
            self.files = test_paths
            
    def _align_transform(self, img, mask):
        #resize to 224
        img = TF.resize(img, self.crop_size, Image.BICUBIC)
        mask = TF.resize(mask, self.crop_size, Image.NEAREST)
        #toTensor
        img = TF.to_tensor(img)
        mask = TF.to_tensor(mask)
        #normalize
        img = TF.normalize(img, mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225 ])
        return img, mask
    
    def _unalign_transform(self, img, mask):
        #resize to 256
        img = TF.resize(img, self.img_size, Image.BICUBIC)
        mask = TF.resize(mask, self.img_size, Image.NEAREST)
        #random rotation
        angle = transforms.RandomRotation.get_params([-10, 10])
        img = TF.rotate(img, angle, fill=(0,))
        mask = TF.rotate(mask, angle, fill=(0,))
        #random crop
        i, j, h, w = transforms.RandomCrop.get_params(img, output_size=(self.crop_size, self.crop_size))
        img = TF.crop(img, i, j, h, w)
        mask = TF.crop(mask, i, j, h, w)
        #toTensor
        img = TF.to_tensor(img)
        mask = TF.to_tensor(mask)
        #normalize
        img = TF.normalize(img, mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225 ])
        return img, mask

    def __getitem__(self, index):
        filename = self.files[index]
        img = Image.open(filename)
        img = img.convert('RGB')
        

        if self.args.mode=='mvtec_loco':
            img = self.resize_transform_loco(img)       

        
        if self.mode == 'train':
            img = self.transform_train(img)            
            return filename, img
        
        elif self.mode == 'test':
            transform_test = self._unalign_transform if self.args.unalign_test else self._align_transform
            img_size = (img.size[0], img.size[1])
            
            
            if 'good' in filename:    
                ground_truth =Image.new('L',(img_size[0],img_size[1]),0)
                img, ground_truth = transform_test(img, ground_truth)
                
                return filename, img, ground_truth, 0
            else:   
                # different ground truth schema for mvtec_loco
                if self.args.mode=='mvtec_loco':
                    ground_truth = Image.open(filename.replace("test", "ground_truth").replace(".png", "/000.png"))                        
                if self.args.mode=='mvtec':
                    ground_truth = Image.open(filename.replace("test", "ground_truth").replace(".png", "_mask.png"))
                
                ground_truth = self.resize_transform_loco(ground_truth)  
                img, ground_truth = transform_test(img, ground_truth)
                
                return filename, img, ground_truth, 1
            
    def __len__(self):
        return len(self.files)


In [6]:
class ImageDataset_beantec(Dataset):
    def __init__(self, args, root, transforms_=None, mode='train', train_paths=None, test_paths=None):
        
        self.img_size = 280 * args.factor
        self.crop_size = 256 * args.factor
        self.args = args
        self.mode = mode
        if train_paths is None and test_paths is None:
            raise ValueError("either test or train paths must be provided depending on the mode")
        
        self.train_paths = train_paths
        self.test_paths = test_paths
        
        self.transform_train = transforms.Compose([ transforms.Resize((self.crop_size, self.crop_size), Image.BICUBIC),
                                                transforms.Pad(int(self.crop_size/10),fill=0,padding_mode='constant'),
                                                transforms.RandomRotation(10),
                                                transforms.RandomCrop((self.crop_size, self.crop_size)),
                                                transforms.ToTensor(),
                                                transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                                    std=[0.229, 0.224, 0.225 ]) ])
        
        self.resize_transform_bean = transforms.Resize((self.crop_size, self.crop_size), Image.BICUBIC)

        if mode == 'train':
            self.files = train_paths
        elif mode == 'test':
            self.files = test_paths
            
    def _align_transform(self, img, mask):
        #resize to 224
        img = TF.resize(img, self.crop_size, Image.BICUBIC)
        mask = TF.resize(mask, self.crop_size, Image.NEAREST)
        #toTensor
        img = TF.to_tensor(img)
        mask = TF.to_tensor(mask)
        #normalize
        img = TF.normalize(img, mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225 ])
        return img, mask
    
    def _unalign_transform(self, img, mask):
        #resize to 256
        img = TF.resize(img, self.img_size, Image.BICUBIC)
        mask = TF.resize(mask, self.img_size, Image.NEAREST)
        #random rotation
        angle = transforms.RandomRotation.get_params([-10, 10])
        img = TF.rotate(img, angle, fill=(0,))
        mask = TF.rotate(mask, angle, fill=(0,))
        #random crop
        i, j, h, w = transforms.RandomCrop.get_params(img, output_size=(self.crop_size, self.crop_size))
        img = TF.crop(img, i, j, h, w)
        mask = TF.crop(mask, i, j, h, w)
        #toTensor
        img = TF.to_tensor(img)
        mask = TF.to_tensor(mask)
        #normalize
        img = TF.normalize(img, mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225 ])
        return img, mask

    def __getitem__(self, index):
        filename = self.files[index]
        img = Image.open(filename)
        img = img.convert('RGB')
        img = self.resize_transform_bean(img)

        
        if self.mode == 'train':
            img = self.transform_train(img)            
            return filename, img
        
        elif self.mode == 'test':
            transform_test = self._unalign_transform if self.args.unalign_test else self._align_transform
            img_size = (img.size[0], img.size[1])
            
            
            if 'ok' in os.path.split(filename)[0].split(os.sep):    
                ground_truth =Image.new('L',(img_size[0],img_size[1]),0)
                img, ground_truth = transform_test(img, ground_truth)
                
                return filename, img, ground_truth, 0
            else:   
 
                ground_truth = Image.open(filename.replace("test", "ground_truth").replace(".bmp", ".png"))
                
                ground_truth = self.resize_transform_bean(ground_truth)  
                
                img, ground_truth = transform_test(img, ground_truth)
                
                return filename, img, ground_truth, 1
            
    def __len__(self):
        return len(self.files)

In [7]:

def Get_dataloader(args):

    if args.mode == 'mvtec' or args.mode == 'mvtec_loco':
        
        EXPERIMENT_LOG_PATH = os.path.join(args.results_dir,args.data_set ,
                                    f'contamination_{int(args.contamination_rate*100)}',
                                    f'{args.exp_name}-{args.data_category}')
        if os.path.exists(os.path.join(EXPERIMENT_LOG_PATH,"experiment_paths.json")) and not args.development:
            with open(os.path.join(EXPERIMENT_LOG_PATH,"experiment_paths.json"), "r") as file:
                experiment_paths = json.load(file)     
        else:    
            if args.mode == 'mvtec':    
                normal_images, validation_images, sampled_anomalies_for_train, sampled_anomalies_for_val, good_images_test, remaining_anomalies_test = get_paths_mvtec(args,verbose=True)
            if args.mode == 'mvtec_loco':
                normal_images, validation_images, sampled_anomalies_for_train, sampled_anomalies_for_val, good_images_test, remaining_anomalies_test = get_paths_mvtec_loco(args,verbose=True)
                
            train_paths = normal_images + sampled_anomalies_for_train
            valid_paths = validation_images + sampled_anomalies_for_val
            test_paths = good_images_test + remaining_anomalies_test
            experiment_paths={'train':train_paths,'test':test_paths,'valid':valid_paths,'contamination_rate':args.contamination_rate,'seed':args.seed}
            
            if not args.development:
                os.makedirs(EXPERIMENT_LOG_PATH, exist_ok=True)
                with open(os.path.join(EXPERIMENT_LOG_PATH,"experiment_paths.json"), "w") as file:
                    json.dump(experiment_paths, file)
        
        DATA_PATH=os.path.join(args.data_root,args.data_category)
        train_dataloader = DataLoader(ImageDataset_mvtec(args,DATA_PATH,mode='train',train_paths = experiment_paths['train'],test_paths = None),
                                                        batch_size=args.batch_size,shuffle=True,num_workers=args.n_cpu,drop_last=False)
        if len(experiment_paths['valid']) > 0:
            valid_dataloader = DataLoader(ImageDataset_mvtec(args,DATA_PATH,mode='test',train_paths = None,test_paths = experiment_paths['valid']),
                                                batch_size=args.batch_size,shuffle=False,num_workers=1,drop_last=False)
        else:
            valid_dataloader = None
        test_dataloader = DataLoader(ImageDataset_mvtec(args,DATA_PATH,mode='test',train_paths = None,test_paths = experiment_paths['test']),
                                                        batch_size=args.batch_size,shuffle=False,num_workers=1,drop_last=False)
    

    if args.mode == 'beantec':
        
        EXPERIMENT_LOG_PATH = os.path.join(args.results_dir,args.data_set ,
                                    f'contamination_{int(args.contamination_rate*100)}',
                                    f'{args.exp_name}-{args.data_category}')
        
        if os.path.exists(os.path.join(EXPERIMENT_LOG_PATH,"experiment_paths.json")) and not args.development:
            with open(os.path.join(EXPERIMENT_LOG_PATH,"experiment_paths.json"), "r") as file:
                experiment_paths = json.load(file)     
                
        else:    
            normal_images, validation_images, sampled_anomalies_for_train, sampled_anomalies_for_val, good_images_test, remaining_anomalies_test = get_paths_beantec(args,verbose=True)    
            train_paths = normal_images + sampled_anomalies_for_train
            valid_paths = validation_images + sampled_anomalies_for_val
            test_paths = good_images_test + remaining_anomalies_test
            experiment_paths={'train':train_paths,'test':test_paths,'valid':valid_paths,'contamination_rate':args.contamination_rate,'seed':args.seed}
            
            if not args.development:
                os.makedirs(EXPERIMENT_LOG_PATH, exist_ok=True)
                with open(os.path.join(EXPERIMENT_LOG_PATH,"experiment_paths.json"), "w") as file:
                    json.dump(experiment_paths, file)
        
        DATA_PATH=os.path.join(args.data_root,args.data_category)
        train_dataloader = DataLoader(ImageDataset_beantec(args,DATA_PATH,mode='train',train_paths = experiment_paths['train'],test_paths = None),
                                                        batch_size=args.batch_size,shuffle=True,num_workers=args.n_cpu,drop_last=False)
        if len(experiment_paths['valid']) > 0:
            valid_dataloader = DataLoader(ImageDataset_beantec(args,DATA_PATH,mode='test',train_paths = None,test_paths = experiment_paths['valid']),
                                                batch_size=args.batch_size,shuffle=False,num_workers=1,drop_last=False)
        else:
            valid_dataloader = None
        test_dataloader = DataLoader(ImageDataset_beantec(args,DATA_PATH,mode='test',train_paths = None,test_paths = experiment_paths['test']),
                                                        batch_size=args.batch_size,shuffle=False,num_workers=1,drop_last=False)
        
        
    
    if args.mode == 'visa':
        pass        
    if args.mode == 'utrad_mvtec':
        
        if args.contamination_rate != 0.0:
            raise ValueError("Contamination rate should be 0.0 for clean original implementation")
        DATA_PATH=os.path.join(args.data_root,args.data_category)
        
        train_dataloader = DataLoader(ImageDataset(args,DATA_PATH, mode='train'),
                                                    batch_size=args.batch_size, shuffle=True, num_workers=args.n_cpu, drop_last=False)
        test_dataloader = DataLoader(ImageDataset(args, DATA_PATH, mode='test'),
                            batch_size=args.batch_size, shuffle=False, num_workers=1, drop_last=False)


    return train_dataloader,valid_dataloader, test_dataloader

MVTEC


In [8]:
anomaly_categories={
    "bottle": ["broken_large", "broken_small", "contamination"],
    "cable": ["bent_wire", "cable_swap", "combined", "cut_inner_insulation", "cut_outer_insulation", "missing_cable", "missing_wire", "poke_insulation"],
    "capsule": ["crack", "faulty_imprint", "poke", "scratch","squeeze"],
    "carpet": ["color", "cut", "hole", "metal_contamination", "thread"],
    "grid": ["bent", "broken", "glue", "metal_contamination", "thread"],
    "hazelnut": ["crack", "cut", "hole", "print"],
    "leather": ["color", "cut", "fold", "glue", "poke"],
    "metal_nut": ["bent", "color", "flip", "scratch"],
    "pill": ["color", "combined","contamination", "crack", "faulty_imprint", "pill_type","scratch"],
    "screw": ["manipulated_front", "scratch_head", "scratch_neck","thread_side", "thread_top"],
    "tile": ["crack", "glue_strip", "gray_stroke", "oil","rough"],
    "toothbrush": ["defective"],
    "transistor": ["bent_lead", "cut_lead", "damaged_case", "misplaced"],
    "wood": ["color", "combined", "hole", "liquid", "scratch"],
    "zipper": ["broken_teeth", "combined","fabric_border", "fabric_interior","split_teeth","rough", "squeezed_teeth"]}


#args = TrainOptions().parse() # surpass kernelerror with this:
class TrainOptions:
    def __init__(self):
        self.exp_name = "Exp_18_02_24"
        self.epoch_start = 0
        self.epoch_num = 150
        self.factor = 1
        self.seed = 233
        self.num_row = 4
        self.activation = 'gelu'
        self.unalign_test = False
        self.data_root = '/home/bule/projects/datasets/mvtec_anomaly_detection/'
        self.data_category = "bottle"
        self.batch_size = 20
        self.lr = 1e-4
        self.b1 = 0.5
        self.b2 = 0.999
        self.n_cpu = 8
        self.image_result_dir = 'result_images'
        self.model_result_dir = 'saved_models'
        self.validation_image_dir = 'validation_images'
        self.contamination_rate = 0.1
        self.validation= 0.0
        self.data_set = 'mvtec'
        self.mode = 'mvtec'
        self.results_dir = 'results'
        self.development = True
        
        
args = TrainOptions()
torch.manual_seed(args.seed)


try:
    with open(os.path.join('configurations',f'{args.data_set}.json' ), 'r') as file:
        dataset_parameters = json.load(file)
    setattr(args, 'dataset_parameters', dataset_parameters)
except FileNotFoundError:
    print(f"Configuration file for {args.data_set} not found. Proceeding with default parameters.")
    setattr(args, 'dataset_parameters', {})


train_dataloader, valid_dataloader , test_dataloader = Get_dataloader(args)
dataiter = iter(train_dataloader)

try:
    images, _ = next(dataiter) 
    print(images)# Use next() function to get the next batch
    print(f'anomalies per batch: {count_files_by_class(images, anomaly_categories[args.data_category])}')

except ValueError:  # Adjust this based on the structure your dataloader returns
    print("Error: Adjust the unpacking based on your dataloader's return value")


category: bottle, normals train:  209, anomalies test: 63, normal test: 20
anomalies test total:     {'broken_large': 20, 'broken_small': 22, 'contamination': 21}
anomalies test sampled:   {'broken_large': 6, 'broken_small': 6, 'contamination': 6}
anomalies test remaining: {'broken_large': 14, 'broken_small': 16, 'contamination': 15}


('/home/bule/projects/datasets/mvtec_anomaly_detection/bottle/train/good/023.png', '/home/bule/projects/datasets/mvtec_anomaly_detection/bottle/train/good/178.png', '/home/bule/projects/datasets/mvtec_anomaly_detection/bottle/train/good/134.png', '/home/bule/projects/datasets/mvtec_anomaly_detection/bottle/train/good/111.png', '/home/bule/projects/datasets/mvtec_anomaly_detection/bottle/train/good/098.png', '/home/bule/projects/datasets/mvtec_anomaly_detection/bottle/train/good/100.png', '/home/bule/projects/datasets/mvtec_anomaly_detection/bottle/train/good/081.png', '/home/bule/projects/datasets/mvtec_anomaly_detection/bottle/train/good/017.png', '/home/bule/projects/datasets/mvtec_anomaly_detection/bottle/train/good/137.png', '/home/bule/projects/datasets/mvtec_anomaly_detection/bottle/train/good/004.png', '/home/bule/projects/datasets/mvtec_anomaly_detection/bottle/train/good/077.png', '/home/bule/projects/datasets/mvtec_anomaly_detection/bottle/train/good/061.png', '/home/bule/pro

MVTEC LOCO

In [9]:
# loco
anomaly_categories = {
    'breakfast_box': ['logical_anomalies', 'structural_anomalies'],
    'juice_bottle': ['logical_anomalies', 'structural_anomalies'],
    'pushpins': ['logical_anomalies', 'structural_anomalies'],
    'screw_bag': ['logical_anomalies', 'structural_anomalies'],
    'splicing_connectors': ['logical_anomalies', 'structural_anomalies'],
}

#args = TrainOptions().parse() # surpass kernelerror with this:
class TrainOptions:
    def __init__(self):
        self.exp_name = "Exp_18_02_24"
        self.epoch_start = 0
        self.epoch_num = 150
        self.factor = 1
        self.seed = 233
        self.num_row = 4
        self.activation = 'gelu'
        self.unalign_test = False
        self.data_root = '/home/bule/projects/datasets/mvtec_loco_anomaly_detection/'
        self.data_category = "breakfast_box"
        self.batch_size = 20
        self.lr = 1e-4
        self.b1 = 0.5
        self.b2 = 0.999
        self.n_cpu = 8
        self.image_result_dir = 'result_images'
        self.model_result_dir = 'saved_models'
        self.validation_image_dir = 'validation_images'
        self.contamination_rate = 0.2
        self.validation= 0.0
        self.data_set = 'mvtec_loco'
        self.mode = 'mvtec_loco'
        self.results_dir = 'results'
        self.development = True
               
args = TrainOptions()
torch.manual_seed(args.seed)


try:
    with open(os.path.join('configurations',f'{args.data_set}.json' ), 'r') as file:
        dataset_parameters = json.load(file)
    setattr(args, 'dataset_parameters', dataset_parameters)
except FileNotFoundError:
    print(f"Configuration file for {args.data_set} not found. Proceeding with default parameters.")
    setattr(args, 'dataset_parameters', {})


train_dataloader, valid_dataloader , test_dataloader = Get_dataloader(args)
dataiter = iter(train_dataloader)

try:
    images, _ = next(dataiter) 
    print(images)# Use next() function to get the next batch
    print(f'anomalies per batch: {count_files_by_class(images, anomaly_categories[args.data_category])}')

except ValueError:  # Adjust this based on the structure your dataloader returns
    print("Error: Adjust the unpacking based on your dataloader's return value")


category: breakfast_box, normals train:  351, anomalies test: 173, normal test: 102
anomalies test total:     {'logical_anomalies': 83, 'structural_anomalies': 90}
anomalies test sampled:   {'logical_anomalies': 39, 'structural_anomalies': 42}
anomalies test remaining: {'logical_anomalies': 44, 'structural_anomalies': 48}


('/home/bule/projects/datasets/mvtec_loco_anomaly_detection/breakfast_box/train/good/023.png', '/home/bule/projects/datasets/mvtec_loco_anomaly_detection/breakfast_box/train/good/179.png', '/home/bule/projects/datasets/mvtec_loco_anomaly_detection/breakfast_box/test/logical_anomalies/014.png', '/home/bule/projects/datasets/mvtec_loco_anomaly_detection/breakfast_box/train/good/140.png', '/home/bule/projects/datasets/mvtec_loco_anomaly_detection/breakfast_box/train/good/163.png', '/home/bule/projects/datasets/mvtec_loco_anomaly_detection/breakfast_box/train/good/065.png', '/home/bule/projects/datasets/mvtec_loco_anomaly_detection/breakfast_box/train/good/277.png', '/home/bule/projects/datasets/mvtec_loco_anomaly_detection/breakfast_box/train/good/067.png', '/home/bule/projects/datasets/mvtec_loco_anomaly_detection/breakfast_box/test/structural_anomalies/043.png', '/home/bule/projects/datasets/mvtec_loco_anomaly_detection/breakfast_box/train/good/256.png', '/home/bule/projects/datasets/mv

BEANTEC

In [11]:
# beantec
anomaly_categories = {
    "01": ["ko"],
    "02": ["ko"],
    "03": ["ko"]}



#args = TrainOptions().parse() # surpass kernelerror with this:
class TrainOptions:
    def __init__(self):
        self.exp_name = "Exp_18_02_24"
        self.epoch_start = 0
        self.epoch_num = 150
        self.factor = 1
        self.seed = 233
        self.num_row = 4
        self.activation = 'gelu'
        self.unalign_test = False
        self.data_root = '/home/bule/projects/datasets/BTech_Dataset_transformed/'
        self.data_category = "02"
        self.batch_size = 20
        self.lr = 1e-4
        self.b1 = 0.5
        self.b2 = 0.999
        self.n_cpu = 8
        self.image_result_dir = 'result_images'
        self.model_result_dir = 'saved_models'
        self.validation_image_dir = 'validation_images'
        self.contamination_rate = 0.1
        self.validation= 0.0
        self.data_set = 'beantec'
        self.mode = 'beantec'
        self.results_dir = 'results'
        self.development = True
               
args = TrainOptions()
torch.manual_seed(args.seed)


try:
    with open(os.path.join('configurations',f'{args.data_set}.json' ), 'r') as file:
        dataset_parameters = json.load(file)
    setattr(args, 'dataset_parameters', dataset_parameters)
except FileNotFoundError:
    print(f"Configuration file for {args.data_set} not found. Proceeding with default parameters.")
    setattr(args, 'dataset_parameters', {})


train_dataloader, valid_dataloader , test_dataloader = Get_dataloader(args)
dataiter = iter(train_dataloader)

try:
    images, _ = next(dataiter) 
    print(images)# Use next() function to get the next batch
    print(f'anomalies per batch: {count_files_by_class(images, anomaly_categories[args.data_category])}')

except ValueError:  # Adjust this based on the structure your dataloader returns
    print("Error: Adjust the unpacking based on your dataloader's return value")


#BUG does not get the samples , probably because of class "01"  or "ko" and "ok"

category: 02, normals train:  399, anomalies test: 30, normal test: 200
anomalies test total:     {'ko': 0}
anomalies test sampled:   {'ko': 0}
anomalies test remaining: {'ko': 0}
('/home/bule/projects/datasets/BTech_Dataset_transformed/02/train/ok/0275.png', '/home/bule/projects/datasets/BTech_Dataset_transformed/02/train/ok/0228.png', '/home/bule/projects/datasets/BTech_Dataset_transformed/02/train/ok/0101.png', '/home/bule/projects/datasets/BTech_Dataset_transformed/02/train/ok/0246.png', '/home/bule/projects/datasets/BTech_Dataset_transformed/02/train/ok/0357.png', '/home/bule/projects/datasets/BTech_Dataset_transformed/02/train/ok/0117.png', '/home/bule/projects/datasets/BTech_Dataset_transformed/02/train/ok/0058.png', '/home/bule/projects/datasets/BTech_Dataset_transformed/02/train/ok/0036.png', '/home/bule/projects/datasets/BTech_Dataset_transformed/02/train/ok/0334.png', '/home/bule/projects/datasets/BTech_Dataset_transformed/02/train/ok/0059.png', '/home/bule/projects/datasets