In [2]:
import torch
from torch.utils.data import Dataset
import torchvision.transforms.functional as TF
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
from torchvision import models
import json
from sklearn.model_selection import train_test_split

import os
from PIL import Image
import matplotlib.pyplot as plt

from utils.dataloader import *
from models import Create_nets

In [2]:
anomaly_categories = {
    'breakfast_box': ['logical_anomalies', 'structural_anomalies'],
    'juice_bottle': ['logical_anomalies', 'structural_anomalies'],
    'pushpins': ['logical_anomalies', 'structural_anomalies'],
    'screw_bag': ['logical_anomalies', 'structural_anomalies'],
    'splicing_connectors': ['logical_anomalies', 'structural_anomalies'],
}

In [None]:
def get_paths_mvtec_loco(contamination=0.0,category='breakfast_box',DATA_PATH='/home/bule/projects/datasets/mvtec_anomaly_detection',verbose=True,seed=223,valid_split=0.1):
    
    
    ## ADD args.specific data
    anomaly_categories = {
        'breakfast_box': ['logical_anomalies', 'structural_anomalies'],
        'juice_bottle': ['logical_anomalies', 'structural_anomalies'],
        'pushpins': ['logical_anomalies', 'structural_anomalies'],
        'screw_bag': ['logical_anomalies', 'structural_anomalies'],
        'splicing_connectors': ['logical_anomalies', 'structural_anomalies']} 
    
    NORMAL_PATH = os.path.join(DATA_PATH, f'{category}/train/good')
    VALIDATION_PATH = os.path.join(DATA_PATH, f'{category}/validation/good')
    ANOMALY_PATH = os.path.join(DATA_PATH , f'{category}/test')
    
    file_path = []
    for root, dirs, files in os.walk(ANOMALY_PATH):
        for file in files:
            file_path.append( os.path.join(root, file))
        
    anomaly_images_test=[item for item in file_path if "good" not in item]
    good_images_test=[item for item in file_path if "good" in item]
    
    normal_images=[os.path.join(NORMAL_PATH,item) for item in os.listdir(NORMAL_PATH)]
    validation_images=[os.path.join(VALIDATION_PATH,item) for item in os.listdir(VALIDATION_PATH)]

    
    n_samples = int((len(normal_images)+len(validation_images))*contamination)
    
    sampled_anomalies_for_train, remaining_anomalies_test = stratified_sample(anomaly_images_test, anomaly_categories[category], n_samples, seed)

    if valid_split > 0:
        sampled_anomalies_for_train, sampled_anomalies_for_val = train_test_split(sampled_anomalies_for_train, test_size=valid_split, random_state=seed)
    else:
        sampled_anomalies_for_val = []

    if verbose:
        print(f'category: {category}, normals train:  {len(normal_images)}, normal validiation:   {len(validation_images)}, anomalies test: {len(anomaly_images_test)}, normal test: {len(good_images_test)}')       
        print(f'anomalies test total:     {count_files_by_class(anomaly_images_test, anomaly_categories[category])}')
        print(f'anomalies test sampled:   {count_files_by_class(sampled_anomalies_for_train, anomaly_categories[category])}')
        print(f'anomalies test remaining: {count_files_by_class(remaining_anomalies_test, anomaly_categories[category])}')
        
    return normal_images, validation_images, sampled_anomalies_for_train, sampled_anomalies_for_val, good_images_test, remaining_anomalies_test

In [None]:
def get_paths_mvtec(args,contamination=0.0,category='bottle',DATA_PATH='/home/bule/projects/datasets/mvtec_anomaly_detection',verbose=True,seed=123,valid_split=0.0):
    
    
    ## ADD args.specific data

    
    anomaly_categories = {
    'bottle': ['broken_large', 'broken_small', 'contamination'],
    'cable': ['bent_wire', 'cable_swap', 'combined', 'cut_inner_insulation', 'cut_outer_insulation', 'missing_cable', 'missing_wire', 'poke_insulation'],
    'capsule': ['crack', 'faulty_imprint', 'poke', 'scratch','squeeze'],
    'carpet': ['color', 'cut', 'hole', 'metal_contamination', 'thread'],
    'grid': ['bent', 'broken', 'glue', 'metal_contamination', 'thread'],
    'hazelnut': ['crack', 'cut', 'hole', 'print'],
    'leather': ['color', 'cut', 'fold', 'glue', 'poke'],
    'metal_nut': ['bent', 'color', 'flip', 'scratch'],
    'pill': ['color', 'combined','contamination', 'crack', 'faulty_imprint', 'pill_type','scratch'],
    'screw': ['manipulated_front', 'scratch_head', 'scratch_neck','thread_side', 'thread_top'],
    'tile': ['crack', 'glue_strip', 'gray_stroke', 'oil','rough'],
    'toothbrush': ['defective'],
    'transistor': ['bent_lead', 'cut_lead', 'damaged_case', 'misplaced'],
    'wood': ['color', 'combined', 'hole', 'liquid', 'scratch'],
    'zipper': ['broken_teeth', 'combined','fabric_border', 'fabric_interior','split_teeth','rough', 'squeezed_teeth']}
    
    NORMAL_PATH = os.path.join(DATA_PATH, f'{category}/train/good')
    ANOMALY_PATH = os.path.join(DATA_PATH , f'{category}/test')
    
    
    print(NORMAL_PATH)
    
    normal_images=[os.path.join(NORMAL_PATH,item) for item in os.listdir(NORMAL_PATH)]
    file_path = []
    for root, dirs, files in os.walk(ANOMALY_PATH):
        for file in files:
            file_path.append( os.path.join(root, file))
        
    anomaly_images_test=[item for item in file_path if "good" not in item]
    good_images_test=[item for item in file_path if "good" in item]
    

    
    n_samples = int(len(normal_images)*contamination)
    
    sampled_anomalies_for_train, remaining_anomalies_test = stratified_sample(anomaly_images_test, anomaly_categories[category], n_samples, seed)

    if valid_split > 0:
        normal_images,validation_images = train_test_split(normal_images, test_size=valid_split, random_state=seed)
        sampled_anomalies_for_train, sampled_anomalies_for_val = train_test_split(sampled_anomalies_for_train, test_size=valid_split, random_state=seed)
    else:
        sampled_anomalies_for_val = []
        validation_images = []


    if verbose:
        print(f'category: {category}, normals train:  {len(normal_images)}, anomalies test: {len(anomaly_images_test)}, normal test: {len(good_images_test)}')       
        print(f'anomalies test total:     {count_files_by_class(anomaly_images_test, anomaly_categories[category])}')
        print(f'anomalies test sampled:   {count_files_by_class(sampled_anomalies_for_train, anomaly_categories[category])}')
        print(f'anomalies test remaining: {count_files_by_class(remaining_anomalies_test, anomaly_categories[category])}')
    
    return normal_images, validation_images, sampled_anomalies_for_train, sampled_anomalies_for_val, good_images_test, remaining_anomalies_test

In [3]:
class ImageDataset_mvtec(Dataset):
    def __init__(self, args, root, transforms_=None, mode='train', train_paths=None, test_paths=None):
        
        self.img_size = 280 * args.factor
        self.crop_size = 256 * args.factor
        self.args = args
        self.mode = mode
        if train_paths is None and test_paths is None:
            raise ValueError("either test or train paths must be provided depending on the mode")
        
        self.train_paths = train_paths
        self.test_paths = test_paths
        
        self.transform_train = transforms.Compose([ transforms.Resize((self.crop_size, self.crop_size), Image.BICUBIC),
                                                transforms.Pad(int(self.crop_size/10),fill=0,padding_mode='constant'),
                                                transforms.RandomRotation(10),
                                                transforms.RandomCrop((self.crop_size, self.crop_size)),
                                                transforms.ToTensor(),
                                                transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                                    std=[0.229, 0.224, 0.225 ]) ])
        if mode == 'train':
            self.files = train_paths
        elif mode == 'test':
            self.files = test_paths
            
    def _align_transform(self, img, mask):
        #resize to 224
        img = TF.resize(img, self.crop_size, Image.BICUBIC)
        mask = TF.resize(mask, self.crop_size, Image.NEAREST)
        #toTensor
        img = TF.to_tensor(img)
        mask = TF.to_tensor(mask)
        #normalize
        img = TF.normalize(img, mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225 ])
        return img, mask
    
    def _unalign_transform(self, img, mask):
        #resize to 256
        img = TF.resize(img, self.img_size, Image.BICUBIC)
        mask = TF.resize(mask, self.img_size, Image.NEAREST)
        #random rotation
        angle = transforms.RandomRotation.get_params([-10, 10])
        img = TF.rotate(img, angle, fill=(0,))
        mask = TF.rotate(mask, angle, fill=(0,))
        #random crop
        i, j, h, w = transforms.RandomCrop.get_params(img, output_size=(self.crop_size, self.crop_size))
        img = TF.crop(img, i, j, h, w)
        mask = TF.crop(mask, i, j, h, w)
        #toTensor
        img = TF.to_tensor(img)
        mask = TF.to_tensor(mask)
        #normalize
        img = TF.normalize(img, mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225 ])
        return img, mask

    def __getitem__(self, index):
        filename = self.files[index]
        img = Image.open(filename)
        img = img.convert('RGB')
        
        
        
        # CHANGE here for MVTEC
        if self.mode == 'train':
            img = self.transform_train(img)
            return filename, img
        
        elif self.mode == 'test':
            transform_test = self._unalign_transform if self.args.unalign_test else self._align_transform
            img_size = (img.size[0], img.size[1])
            
            if 'good' in filename:
                ground_truth = Image.new('L',(img_size[0],img_size[1]),0)
                img, ground_truth = transform_test(img, ground_truth)
                return filename, img, ground_truth, 0
            else:
                ground_truth = Image.open(filename.replace("test", "ground_truth").replace(".png", "_mask.png"))
                img, ground_truth = transform_test(img, ground_truth)
                return filename, img, ground_truth, 1


    def __len__(self):
        return len(self.files)




def Get_dataloader(args):

    if args.mode == 'mvtec' or args.mode == 'mvtec_loco':
        ## get dataset paths for test and train 
        EXPERIMENT_LOG_PATH = os.path.join(args.results_dir,args.data_set ,
                                    f'contamination_{int(args.contamination_rate*100)}',
                                    f'{args.exp_name}-{args.data_category}',
                                    "experiment_paths.json")
        
        # Load if already exists
        if os.path.exists(EXPERIMENT_LOG_PATH):
            with open(EXPERIMENT_LOG_PATH, "r") as file:
                experiment_paths = json.load(file)
        else:    
            if args.mode == 'mvtec':    
                # TODO change args.specific data, only args as input
                normal_images, sampled_anomalies_for_train, good_images_test, remaining_anomalies_test = get_paths_mvtec(contamination=args.contamination_rate,
                                                                                                                            category=args.data_category,
                                                                                                                            DATA_PATH=args.data_root,
                                                                                                                            verbose=True)
                valid_paths=None
                train_paths = normal_images + sampled_anomalies_for_train
                
                # TODO change args.specific data, only args as input
            if args.mode == 'mvtec_loco':
                normal_images, validation_images, sampled_anomalies_for_train, sampled_anomalies_for_val, good_images_test, remaining_anomalies_test = get_paths_mvtec_loco(contamination=args.contamination_rate,
                                                                                                                            category=args.data_category,
                                                                                                                            DATA_PATH=args.data_root,
                                                                                                                            verbose=True)
                
            train_paths = normal_images + sampled_anomalies_for_train
            valid_paths = validation_images + sampled_anomalies_for_val
            test_paths = good_images_test + remaining_anomalies_test
            
            
            experiment_paths={'train':train_paths,'test':test_paths,'valid':valid_paths,'contamination_rate':args.contamination_rate,'seed':args.seed}

            with open(os.path.join(EXPERIMENT_LOG_PATH), "w") as file:
                json.dump(experiment_paths, file)
        
        
        DATA_PATH=os.path.join(args.data_root,args.data_category)
        train_dataloader = DataLoader(ImageDataset_mvtec(args,DATA_PATH,mode='train',train_paths = experiment_paths['train'],test_paths = None),
                                                        batch_size=args.batch_size,shuffle=True,num_workers=args.n_cpu,drop_last=False)
        if len(experiment_paths['valid']) > 0:
            valid_dataloader = DataLoader(ImageDataset_mvtec(args,DATA_PATH,mode='test',train_paths = None,test_paths = experiment_paths['valid']),
                                                batch_size=args.batch_size,shuffle=False,num_workers=1,drop_last=False)
        else:
            valid_dataloader = None
        

        test_dataloader = DataLoader(ImageDataset_mvtec(args,DATA_PATH,mode='test',train_paths = None,test_paths = experiment_paths['test']),
                                                        batch_size=args.batch_size,shuffle=False,num_workers=1,drop_last=False)
    

    if args.mode == 'beantec':
        pass
    if args.mode == 'visa':
        pass        
    if args.mode == 'utrad_mvtec':
        
        if args.contamination_rate != 0.0:
            raise ValueError("Contamination rate should be 0.0 for clean original implementation")
        DATA_PATH=os.path.join(args.data_root,args.data_category)
        
        train_dataloader = DataLoader(ImageDataset(args,DATA_PATH, mode='train'),
                                                    batch_size=args.batch_size, shuffle=True, num_workers=args.n_cpu, drop_last=False)
        test_dataloader = DataLoader(ImageDataset(args, DATA_PATH, mode='test'),
                            batch_size=args.batch_size, shuffle=False, num_workers=1, drop_last=False)


    return train_dataloader,valid_dataloader, test_dataloader

In [5]:
#args = TrainOptions().parse() # surpass kernelerror with this:
class TrainOptions:
    def __init__(self):
        self.exp_name = "Exp0-r18"
        self.epoch_start = 0
        self.epoch_num = 150
        self.factor = 1
        self.seed = 233
        self.num_row = 4
        self.activation = 'gelu'
        self.unalign_test = False
        self.data_root = '/home/bule/projects/datasets/mvtec_anomaly_detection/'
        self.dataset_name = "cable"
        self.batch_size = 20
        self.lr = 1e-4
        self.b1 = 0.5
        self.b2 = 0.999
        self.n_cpu = 8
        self.image_result_dir = 'result_images'
        self.model_result_dir = 'saved_models'
        self.validation_image_dir = 'validation_images'
        self.contamination_rate = 0.1
        self.validation= 0.0
               
args = TrainOptions()
torch.manual_seed(args.seed)

train_dataloader, test_dataloader = Get_dataloader(args, mode='contaminated_mvtec')
dataiter = iter(train_dataloader)

try:
    images, _ = next(dataiter) 
    print(images)# Use next() function to get the next batch
    print(f'anomalies per batch: {count_files_by_class(images, anomaly_categories[args.dataset_name])}')

except ValueError:  # Adjust this based on the structure your dataloader returns
    print("Error: Adjust the unpacking based on your dataloader's return value")


/home/bule/projects/datasets/mvtec_anomaly_detection/cable/train/good
category: cable, normals train:  224, anomalies test: 92, normal test: 58
anomalies test total:     {'bent_wire': 13, 'cable_swap': 12, 'combined': 11, 'cut_inner_insulation': 14, 'cut_outer_insulation': 10, 'missing_cable': 12, 'missing_wire': 10, 'poke_insulation': 10}
anomalies test sampled:   {'bent_wire': 3, 'cable_swap': 2, 'combined': 2, 'cut_inner_insulation': 3, 'cut_outer_insulation': 2, 'missing_cable': 2, 'missing_wire': 2, 'poke_insulation': 2}
anomalies test remaining: {'bent_wire': 10, 'cable_swap': 10, 'combined': 9, 'cut_inner_insulation': 11, 'cut_outer_insulation': 8, 'missing_cable': 10, 'missing_wire': 8, 'poke_insulation': 8}
Train paths: 242
Test paths: 132
('/home/bule/projects/datasets/mvtec_anomaly_detection/cable/train/good/140.png', '/home/bule/projects/datasets/mvtec_anomaly_detection/cable/train/good/043.png', '/home/bule/projects/datasets/mvtec_anomaly_detection/cable/test/poke_insulati