In [1]:
import json
import os
from collections import defaultdict

import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
import wandb
from accelerate import Accelerator, DistributedDataParallelKwargs
from accelerate.logging import get_logger
from accelerate.utils import set_seed
os.chdir("/workspace/DiffusionAD")
from data.dataset_beta_thresh import (
    DAGMTestDataset,
    DAGMTrainDataset,
    MPDDTestDataset,
    MPDDTrainDataset,
    MVTecTestDataset,
    MVTecTrainDataset,
    VisATestDataset,
    VisATrainDataset,
)
from models.DDPM import GaussianDiffusionModel, get_beta_schedule
from models.Recon_subnetwork import UNetModel
from models.Seg_subnetwork import SegmentationSubNetwork
from piq import psnr, ssim
from sklearn.metrics import roc_auc_score
from torch import optim
from torch.utils.data import DataLoader
from torchvision.transforms.functional import to_pil_image
from tqdm import tqdm
import argparse

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
def weights_init(m):
    classname = m.__class__.__name__
    if classname.find("Conv") != -1:
        m.weight.data.normal_(0.0, 0.02)
    elif classname.find("BatchNorm") != -1:
        m.weight.data.normal_(1.0, 0.02)
        m.bias.data.fill_(0)


def defaultdict_from_json(jsonDict):
    func = lambda: defaultdict(str)
    dd = func()
    dd.update(jsonDict)
    return dd


class BinaryFocalLoss(nn.Module):
    def __init__(self, alpha=0.5, gamma=4, logits=False, reduce=True):
        super(BinaryFocalLoss, self).__init__()
        self.alpha = alpha
        self.gamma = gamma
        self.logits = logits
        self.reduce = reduce

    def forward(self, inputs, targets):
        if self.logits:
            BCE_loss = F.binary_cross_entropy_with_logits(inputs, targets, reduction="none")
        else:
            BCE_loss = F.binary_cross_entropy(inputs, targets, reduction="none")
        pt = torch.exp(-BCE_loss)
        F_loss = self.alpha * (1 - pt) ** self.gamma * BCE_loss

        if self.reduce:
            return torch.mean(F_loss)
        else:
            return F_loss



In [3]:
sub_class ="candle"
file = "args1.json"
# load the json args
with open(f"./args/{file}", "r") as f:
    args = json.load(f)
args["arg_num"] = file[4:-5]
args = defaultdict_from_json(args)
set_seed(args["seed"])
mvtec_classes = [
    "carpet",
    "grid",
    "leather",
    "tile",
    "wood",
    "bottle",
    "cable",
    "capsule",
    "hazelnut",
    "metal_nut",
    "pill",
    "screw",
    "toothbrush",
    "transistor",
    "zipper",
]

visa_classes = [
    "candle",
    "capsules",
    "cashew",
    "chewinggum",
    "fryum",
    "macaroni1",
    "macaroni2",
    "pcb1",
    "pcb2",
    "pcb3",
    "pcb4",
    "pipe_fryum",
]

mpdd_classes = ["bracket_black", "bracket_brown", "bracket_white", "connector", "metal_plate", "tubes"]
dagm_class = ["Class1", "Class2", "Class3", "Class4", "Class5", "Class6", "Class7", "Class8", "Class9", "Class10"]

current_classes = visa_classes

class_type = ""
#for sub_class in current_classes:
print("class", sub_class)
if sub_class in visa_classes:
    subclass_path = os.path.join(args["visa_root_path"], sub_class)
    print(subclass_path)
    training_dataset = VisATrainDataset(subclass_path, sub_class, img_size=args["img_size"], args=args)
    testing_dataset = VisATestDataset(
        subclass_path,
        sub_class,
        img_size=args["img_size"],
    )
    class_type = "VisA"
elif sub_class in mpdd_classes:
    subclass_path = os.path.join(args["mpdd_root_path"], sub_class)
    training_dataset = MPDDTrainDataset(subclass_path, sub_class, img_size=args["img_size"], args=args)
    testing_dataset = MPDDTestDataset(
        subclass_path,
        sub_class,
        img_size=args["img_size"],
    )
    class_type = "MPDD"
elif sub_class in mvtec_classes:
    subclass_path = os.path.join(args["mvtec_root_path"], sub_class)
    training_dataset = MVTecTrainDataset(subclass_path, sub_class, img_size=args["img_size"], args=args)
    testing_dataset = MVTecTestDataset(
        subclass_path,
        sub_class,
        img_size=args["img_size"],
    )
    class_type = "MVTec"
elif sub_class in dagm_class:
    subclass_path = os.path.join(args["dagm_root_path"], sub_class)
    training_dataset = DAGMTrainDataset(subclass_path, sub_class, img_size=args["img_size"], args=args)
    testing_dataset = DAGMTestDataset(
        subclass_path,
        sub_class,
        img_size=args["img_size"],
    )
    class_type = "DAGM"

print(file, args)

data_len = len(testing_dataset)
training_dataset_loader = DataLoader(
    training_dataset,
    batch_size=args["Batch_Size"],
    shuffle=True,
    num_workers=8,
    pin_memory=True,
    drop_last=True,
)
test_loader = DataLoader(testing_dataset, batch_size=1, shuffle=False, num_workers=4)

# make arg specific directories
for i in [
    f'{args["output_path"]}/model/diff-params-ARGS={args["arg_num"]}/{sub_class}',
    f'{args["output_path"]}/diffusion-training-images/ARGS={args["arg_num"]}/{sub_class}',
    f'{args["output_path"]}/metrics/ARGS={args["arg_num"]}/{sub_class}',
]:
    try:
        os.makedirs(i)
    except OSError:
        pass

class candle
/workspace2/VisA_diffad/candle
args1.json defaultdict(<class 'str'>, {'seed': 42, 'img_size': [256, 256], 'Batch_Size': 4, 'EPOCHS': 3000, 'T': 1000, 'base_channels': 128, 'beta_schedule': 'linear', 'loss_type': 'l2', 'diffusion_lr': 0.0001, 'seg_lr': 1e-05, 'random_slice': True, 'weight_decay': 0.0, 'save_imgs': True, 'save_vids': False, 'dropout': 0, 'attention_resolutions': '32,16,8', 'num_heads': 2, 'num_head_channels': -1, 'noise_fn': 'gauss', 'channels': 3, 'mvtec_root_path': '/workspace2/mvtec_ad_diffad', 'visa_root_path': '/workspace2/VisA_diffad', 'dagm_root_path': '/workspace2/dagm', 'mpdd_root_path': '/workspace2/mpdd', 'anomaly_source_path': '/workspace2/dtd', 'noisier_t_range': 600, 'less_t_range': 300, 'condition_w': 1, 'eval_normal_t': 200, 'eval_noisier_t': 400, 'output_path': '/workspace3/diffusion_ad', 'arg_num': '1'})


In [4]:
unet_model = UNetModel(
        args["img_size"][0],
        args["base_channels"],
        channel_mults=args["channel_mults"],
        dropout=args["dropout"],
        n_heads=args["num_heads"],
        n_head_channels=args["num_head_channels"],
        in_channels=args["channels"],
    )
seg_model = SegmentationSubNetwork(in_channels=6, out_channels=1)

In [15]:
print(sum(p.numel() for p in unet_model.parameters() if p.requires_grad))

131654403


In [7]:
from torchsummary import summary

In [5]:
seg_model = seg_model.cuda()

In [14]:
summary(seg_model, [(6, 224, 224)], batch_size=16)

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1         [16, 64, 224, 224]           3,520
       BatchNorm2d-2         [16, 64, 224, 224]             128
              ReLU-3         [16, 64, 224, 224]               0
            Conv2d-4         [16, 64, 224, 224]          36,928
       BatchNorm2d-5         [16, 64, 224, 224]             128
              ReLU-6         [16, 64, 224, 224]               0
         MaxPool2d-7         [16, 64, 112, 112]               0
            Conv2d-8        [16, 128, 112, 112]          73,856
       BatchNorm2d-9        [16, 128, 112, 112]             256
             ReLU-10        [16, 128, 112, 112]               0
           Conv2d-11        [16, 128, 112, 112]         147,584
      BatchNorm2d-12        [16, 128, 112, 112]             256
             ReLU-13        [16, 128, 112, 112]               0
        MaxPool2d-14          [16, 128,