In [55]:
from utils.dataloader_utils import *

from torch.utils.data import Dataset


import torchvision.transforms as transforms
from torch.utils.data import DataLoader
import torch
import torch.nn.functional as F
from models import Create_nets
#from options import TrainOptions
from torchvision import models
import os
from PIL import Image
import torchvision.transforms.functional as TF
import torchvision.transforms as transforms

#args = TrainOptions().parse() # surpass kernelerror with this:

class TrainOptions:
    def __init__(self):
        self.exp_name = "Exp0-r18"
        self.epoch_start = 0
        self.epoch_num = 150
        self.factor = 1
        self.seed = 233
        self.num_row = 4
        self.activation = 'gelu'
        self.unalign_test = False
        self.data_root = '/home/bule/projects/datasets/mvtec_anomaly_detection/'
        self.dataset_name = "cable"
        self.batch_size = 2
        self.lr = 1e-4
        self.b1 = 0.5
        self.b2 = 0.999
        self.n_cpu = 8
        self.image_result_dir = 'result_images'
        self.model_result_dir = 'saved_models'
        self.validation_image_dir = 'validation_images'
        self.contamination_rate = 0.1
        
        
        
# Example of how to use this class
args = TrainOptions()

torch.manual_seed(args.seed)

<torch._C.Generator at 0x7fe617bff430>

In [54]:
anomaly_categories = {
    'bottle': ['broken_large', 'broken_small', 'contamination'],
    'cable': ['bent_wire', 'cable_swap', 'combined', 'cut_inner_insulation', 'cut_outer_insulation', 'missing_cable', 'missing_wire', 'poke_insulation'],
    'capsule': ['crack', 'faulty_imprint', 'poke', 'scratch','squeeze'],
    'carpet': ['color', 'cut', 'hole', 'metal_contamination', 'thread'],
    'grid': ['bent', 'broken', 'glue', 'metal_contamination', 'thread'],
    'hazelnut': ['crack', 'cut', 'hole', 'print'],
    'leather': ['color', 'cut', 'fold', 'glue', 'poke'],
    'metal_nut': ['bent', 'color', 'flip', 'scratch'],
    'pill': ['color', 'combined','contamination', 'crack', 'faulty_imprint', 'pill_type','scratch'],
    'screw': ['manipulated_front', 'scratch_head', 'scratch_neck','thread_side', 'thread_top'],
    'tile': ['crack', 'glue_strip', 'gray_stroke', 'oil','rough'],
    'toothbrush': ['defective'],
    'transistor': ['bent_lead', 'cut_lead', 'damaged_case', 'misplaced'],
    'wood': ['color', 'combined', 'hole', 'liquid', 'scratch'],
    'zipper': ['broken_teeth', 'combined','fabric_border', 'fabric_interior','split_teeth','rough', 'squeezed_teeth']
}

In [61]:
def Get_dataloader(args,mode='original_mvtec'):

    # TODO test the dataloader for contaminated mvtec

    if mode == 'contaminated_mvtec':
        ## get dataset paths for test and train 
        
        normal_images, sampled_anomalies_for_train, good_images_test, remaining_anomalies_test = get_paths_mvtec(contamination=args.contamination_rate,
                                                                                                                 category=args.dataset_name,
                                                                                                                 DATA_PATH=args.data_root,
                                                                                                                 verbose=True)
        
        train_paths = normal_images + sampled_anomalies_for_train
        test_paths = good_images_test + remaining_anomalies_test
        
        print(f"Train paths: {len(train_paths)}")
        print(f"Test paths: {len(test_paths)}")
        
        train_dataloader = DataLoader(ImageDataset_mvtec(args, "%s/%s" % (args.data_root,args.dataset_name), mode='train', train_paths = train_paths, test_paths = test_paths),
                        batch_size=args.batch_size, shuffle=True, num_workers=args.n_cpu, drop_last=False)

        test_dataloader = DataLoader(ImageDataset_mvtec(args, "%s/%s" % (args.data_root,args.dataset_name), mode='test', train_paths = train_paths, test_paths = test_paths),
                            batch_size=args.batch_size, shuffle=False, num_workers=1, drop_last=False)


    # if mode == 'original_mvtec':
    #     train_dataloader = DataLoader(ImageDataset(args, "%s/%s" % (args.data_root,args.dataset_name), mode='train'),
    #                     batch_size=args.batch_size, shuffle=True, num_workers=args.n_cpu, drop_last=False)

    #     test_dataloader = DataLoader(ImageDataset(args, "%s/%s" % (args.data_root,args.dataset_name), mode='test'),
    #                         batch_size=args.batch_size, shuffle=False, num_workers=1, drop_last=False)


    return train_dataloader, test_dataloader



class ImageDataset_mvtec(Dataset):
    def __init__(self, args, root, transforms_=None, mode='train', train_paths=None, test_paths=None):
        
        self.img_size = 280 * args.factor
        self.crop_size = 256 * args.factor
        self.args = args
        self.mode = mode
        if train_paths is None or test_paths is None:
            raise ValueError("train_paths and test_paths must be provided")
        
        self.train_paths = train_paths
        self.test_paths = test_paths
        
        self.transform_train = transforms.Compose([ transforms.Resize((self.crop_size, self.crop_size), Image.BICUBIC),
                                                transforms.Pad(int(self.crop_size/10),fill=0,padding_mode='constant'),
                                                transforms.RandomRotation(10),
                                                transforms.RandomCrop((self.crop_size, self.crop_size)),
                                                transforms.ToTensor(),
                                                transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                                    std=[0.229, 0.224, 0.225 ]) ])
        if mode == 'train':
            self.files = train_paths
        elif mode == 'test':
            self.files = test_paths
            
    def _align_transform(self, img, mask):
        #resize to 224
        img = TF.resize(img, self.crop_size, Image.BICUBIC)
        mask = TF.resize(mask, self.crop_size, Image.NEAREST)
        #toTensor
        img = TF.to_tensor(img)
        mask = TF.to_tensor(mask)
        #normalize
        img = TF.normalize(img, mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225 ])
        return img, mask
    
    def _unalign_transform(self, img, mask):
        #resize to 256
        img = TF.resize(img, self.img_size, Image.BICUBIC)
        mask = TF.resize(mask, self.img_size, Image.NEAREST)
        #random rotation
        angle = transforms.RandomRotation.get_params([-10, 10])
        img = TF.rotate(img, angle, fill=(0,))
        mask = TF.rotate(mask, angle, fill=(0,))
        #random crop
        i, j, h, w = transforms.RandomCrop.get_params(img, output_size=(self.crop_size, self.crop_size))
        img = TF.crop(img, i, j, h, w)
        mask = TF.crop(mask, i, j, h, w)
        #toTensor
        img = TF.to_tensor(img)
        mask = TF.to_tensor(mask)
        #normalize
        img = TF.normalize(img, mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225 ])
        return img, mask

    def __getitem__(self, index):
        filename = self.files[index]
        img = Image.open(filename)
        img = img.convert('RGB')
        
        if self.mode == 'train':
            img = self.transform_train(img)
            return filename, img
        
        elif self.mode == 'test':
            transform_test = self._unalign_transform if self.args.unalign_test else self._align_transform
            img_size = (img.size[0], img.size[1])
            
            if 'good' in filename:
                ground_truth = Image.new('L',(img_size[0],img_size[1]),0)
                img, ground_truth = transform_test(img, ground_truth)
                return filename, img, ground_truth, 0
            else:
                ground_truth = Image.open(filename.replace("test", "ground_truth").replace(".png", "_mask.png"))
                img, ground_truth = transform_test(img, ground_truth)
                return filename, img, ground_truth, 1


In [62]:
## BUG: The dataloader is not working for the contaminated mvtec dataset
Get_dataloader(args,mode='contaminated_mvtec')

category: cable, normals train:  224, anomalies test: 92, normal test: 58
anomalies test total: {'bent_wire': 13, 'cable_swap': 12, 'combined': 11, 'cut_inner_insulation': 14, 'cut_outer_insulation': 10, 'missing_cable': 12, 'missing_wire': 10, 'poke_insulation': 10}
anomalies test sampled: {'bent_wire': 3, 'cable_swap': 2, 'combined': 2, 'cut_inner_insulation': 3, 'cut_outer_insulation': 2, 'missing_cable': 2, 'missing_wire': 2, 'poke_insulation': 2}
anomalies test remaining: {'bent_wire': 10, 'cable_swap': 10, 'combined': 9, 'cut_inner_insulation': 11, 'cut_outer_insulation': 8, 'missing_cable': 10, 'missing_wire': 8, 'poke_insulation': 8}
Train paths: 242
Test paths: 132


TypeError: object of type 'ImageDataset_mvtec' has no len()