In [None]:
# Delete any existing ./diffusionAD directory (relative to the current working
# directory) so the `git clone` in the next cell does not fail on an existing
# target. -f keeps this silent when the directory is not there.
!rm -rf diffusionAD

In [None]:
# Clone the diffusionAD repository into ./diffusionAD.
!git clone https://github.com/nguyenduchuyiu/diffusionAD.git

In [None]:
# Switch the kernel's working directory into the clone; every relative path in
# the cells below (args/, src/, ./args/...) is resolved from here.
%cd diffusionAD
# Check out one fixed commit so the notebook always runs the same code
# regardless of what is later pushed to the default branch.
!git checkout c9c39c9e4e96a65b06ad1e0c4509aeb0bfc4bc4f

In [None]:
%%writefile args/args1.json

{
  "img_size": [512, 512],
  "Batch_Size": 8,
  "EPOCHS": 3000,
  "T": 1000,
  "base_channels": 128,
  "beta_schedule": "linear",
  "loss_type": "l2",
  "diffusion_lr": 1e-4,
  "seg_lr": 1e-5,
  "random_slice": true,
  "weight_decay": 0.0,
  "save_imgs": true,
  "save_vids": false,
  "dropout": 0,
  "attention_resolutions": "32,16,8",
  "num_heads": 4,
  "num_head_channels": -1,
  "noise_fn": "gauss",
  "channels": 3,
  "data_name": "RealIAD",
  "data_root_path": "/kaggle/input/pcb-dataset",
  "anomaly_source_path": "/kaggle/input/pcb-dataset/dtd",
  "noisier_t_range": 600,
  "less_t_range": 300,
  "condition_w": 1,
  "eval_normal_t": 200,
  "eval_noisier_t": 400,
  "output_path": "outputs",
  "gradient_accumulation_steps": 2,
  "use_mixed_precision": true,
  "channel_mults": [1, 1, 1, 2, 2, 4, 4],
  "loss_weight": "uniform",
  "loss-type": "l2",
  "resume_training": true,
  "use_gradient_checkpointing": true,
  "use_bfloat16": false
}

In [None]:
# Run the repository's training script in a subprocess. `-u` turns off Python's
# stdout/stderr buffering so log lines appear in the cell output as they are
# produced rather than in large delayed chunks.
# NOTE(review): presumably train.py reads the args/args1.json written by the
# previous cell -- confirm against src/train.py.
!python3 -u src/train.py

In [None]:
# Evaluate every available checkpoint ("best" and "last") for every class
# directory found under <data_root_path>/<data_name>.
import gc
import json
import os

import torch
from torch.utils.data import DataLoader

# Evaluation helpers provided by the repository.
from eval import testing, load_parameters, defaultdict_from_json
# NOTE(review): the model and dataset classes used below were never imported in
# this cell. They are pulled from eval's namespace on the assumption that
# eval.py imports them at module level -- TODO confirm; if it does not, import
# them from their defining modules instead.
from eval import UNetModel, SegmentationSubNetwork, RealIADTestDataset

# Drop unreferenced Python objects and return cached CUDA memory before any
# model is loaded.
gc.collect()
torch.cuda.empty_cache()
torch.cuda.ipc_collect()

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Report how many GPUs are visible; more than one enables DataParallel below.
num_gpus = torch.cuda.device_count()
print(f"Number of available GPUs: {num_gpus}")
if num_gpus > 1:
    print(f"Using {num_gpus} GPUs for evaluation")
elif num_gpus == 1:
    print("Using single GPU for evaluation")
else:
    print("Using CPU for evaluation")

file = "args1.json"
# Load the JSON args written earlier in the notebook.
with open(f'./args/{file}', 'r') as f:
    args = json.load(f)
# "args1.json"[4:-5] -> "1": strip the "args" prefix and the ".json" suffix.
args['arg_num'] = file[4:-5]
args = defaultdict_from_json(args)

# One sub-directory per class. Plain files are ignored so they are not mistaken
# for classes, and the list is sorted so the evaluation order is reproducible.
dataset_dir = os.path.join(args["data_root_path"], args['data_name'])
real_iad_classes = sorted(
    entry for entry in os.listdir(dataset_dir)
    if os.path.isdir(os.path.join(dataset_dir, entry))
)

current_classes = real_iad_classes
checkpoint_types = ['best', 'last']

for sub_class in current_classes:
    for checkpoint_type in checkpoint_types:
        # From here on `args` is whatever load_parameters returns; it replaces
        # the dict loaded above.
        try:
            args, output = load_parameters(device, sub_class, checkpoint_type)
        except FileNotFoundError:
            print(f"Checkpoint {checkpoint_type} not found for class {sub_class}, skipping.")
            continue

        print(f"args{args['arg_num']}")
        print("class", sub_class)

        in_channels = args["channels"]

        unet_model = UNetModel(
            args['img_size'][0],
            args['base_channels'],
            channel_mults=args['channel_mults'],
            dropout=args["dropout"],
            n_heads=args["num_heads"],
            n_head_channels=args["num_head_channels"],
            in_channels=in_channels,
        ).to(device)

        seg_model = SegmentationSubNetwork(in_channels=6, out_channels=1).to(device)

        # Restore trained weights. Both models already live on `device`, so no
        # second .to(device) is needed after loading.
        unet_model.load_state_dict(output["unet_model_state_dict"])
        seg_model.load_state_dict(output["seg_model_state_dict"])

        # Enable multi-GPU for evaluation if available.
        if num_gpus > 1:
            print(f"Wrapping models with DataParallel for {num_gpus} GPUs")
            unet_model = torch.nn.DataParallel(unet_model)
            seg_model = torch.nn.DataParallel(seg_model)

        unet_model.eval()
        seg_model.eval()

        print("EPOCH:", output['n_epoch'])

        testing_dataset = RealIADTestDataset(
            args["data_root_path"], sub_class, img_size=args["img_size"]
        )
        class_type = args['data_name']

        data_len = len(testing_dataset)
        test_loader = DataLoader(testing_dataset, batch_size=1, shuffle=False, num_workers=4)

        # Create the per-args, per-class metrics directory. exist_ok=True
        # tolerates an existing directory while still surfacing real failures
        # (e.g. permission errors) instead of silently ignoring every OSError.
        metrics_dir = f'{args["output_path"]}/metrics/ARGS={args["arg_num"]}/{sub_class}'
        os.makedirs(metrics_dir, exist_ok=True)

        testing(test_loader, args, unet_model, seg_model, data_len, sub_class, class_type, checkpoint_type, device)


In [None]:
import sys
!{sys.executable} -m pip install matplotlib