In [None]:
import cv2
import os 
import torch
import json
import copy
import random
import numpy as np
import pandas as pd
from PIL import Image
import torch.nn as nn
import torch.functional as F
from torchvision import datasets, transforms
from torch.utils.data import Dataset, DataLoader, random_split, Subset, ConcatDataset
import matplotlib.pyplot as plt
from collections import OrderedDict

# Seed PyTorch's global random number generator so runs are repeatable.
torch.manual_seed(7)
# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [None]:
class dataset_defectDetection(Dataset):
    """Segmentation dataset driven by JSON polygon annotations.

    Expected layout under ``path``::

        path/data/<name>.jpg     images
        path/label/<name>.json   annotations with ``imageHeight``,
                                 ``imageWidth`` and ``shapes[*].points``

    The dataset is indexed by annotation file: ``len()`` is the number of
    files in ``label/`` and each item loads the ``.jpg`` with the same stem.

    Args:
        path: Root directory containing the ``data`` and ``label`` folders.
        transform: Optional callable applied to both the PIL image and the
            PIL mask (e.g. ``Resize`` + ``ToTensor``).
    """

    def __init__(self, path, transform=None):
        super(dataset_defectDetection, self).__init__()
        self.path = path
        self.component = ["data", "label"]
        self.transform = transform
        self.path_data = os.path.join(self.path, self.component[0])
        self.path_label = os.path.join(self.path, self.component[1])
        # Only the top-level listing of the data folder is wanted.  Unpacking
        # the os.walk generator directly would require the walk to yield
        # exactly three directories, so take its first entry instead.
        _, _, data_names = next(os.walk(self.path_data), (self.path_data, [], []))
        # Sort both listings: directory order is platform dependent, and a
        # stable order keeps seeded random_split() results reproducible.
        self.data = [os.path.join(self.path_data, data_name) for data_name in sorted(data_names)]
        self.label = [os.path.join(self.path_label, label_name) for label_name in sorted(os.listdir(self.path_label))]

    def __len__(self):
        """Return the number of annotation files (not the number of images)."""
        return len(self.label)

    def __getitem__(self, item):
        """Load one ``(image, label)`` pair.

        Returns:
            image: ``transform(image)`` when a transform is set, otherwise
                the raw PIL image.
            label: Two-channel float tensor; channel 0 is the binary defect
                mask and channel 1 is its complement.
        """
        # Rasterise the annotated polygons into a filled 0/255 mask.
        path_label = self.label[item]
        with open(path_label, "r", encoding="utf8") as fp:
            annotation = json.load(fp)
        label_mask = np.zeros((annotation["imageHeight"], annotation["imageWidth"]), np.uint8)
        for shape in annotation["shapes"]:
            label_points = np.array(shape["points"], np.int32)
            label_mask = cv2.drawContours(label_mask, [label_points], -1, 255, -1)

        # Look the image up from the label name (the dataset holds 1104 images
        # in total, of which only 394 carry an annotation).
        stem = os.path.splitext(os.path.split(path_label)[-1])[0]
        path_image = os.path.join(self.path_data, stem + ".jpg")
        image = Image.open(path_image)

        if self.transform is not None:
            # Re-seed Python's RNG before each call so image and mask see the
            # same random draws.
            # NOTE(review): this only synchronises transforms that draw from
            # Python's `random`; transforms using torch's RNG are unaffected.
            random.seed(7)
            image = self.transform(image)
            random.seed(7)
            label = self.transform(Image.fromarray(np.uint8(label_mask)))
        else:
            # Without a transform the mask must still become a (1, H, W)
            # float tensor in [0, 1] for the channel stacking below.
            label = torch.from_numpy(label_mask).unsqueeze(0).float() / 255.

        # Resizing interpolates the mask edges; snap back to hard 0/1 values.
        label[label >= 0.5] = 1.
        label[label < 0.5] = 0.
        label = torch.cat([label, (1 - label)], dim=0)

        return image, label
    
# Disabled usage example for dataset_defectDetection, kept as an inert string.
# It must be a *raw* string: the Windows path inside contains "\U", which a
# normal string literal parses as a (truncated) unicode escape -> SyntaxError.
r'''
if __name__ == "__main__":

    path = r"C:\Users\风\Desktop\表面缺陷检测\BSData-main\BSData-main"
    input_size = (256, 512)
    batch_size = 2
    shuffle = True
    num_workers = 0
    pin_memory = True
    drop_last = True

    transform = transforms.Compose([
            transforms.Resize(input_size),
            transforms.ToTensor()
    ])

    image_datasets = dataset_defectDetection(path, transform=transform)
    torch.manual_seed(7) # 设置随机种子以便dataset分割的结果可重复
    image_datasets_split = random_split(image_datasets, [len(image_datasets)//5, len(image_datasets)-len(image_datasets)//5])

    loader = {
        "train": DataLoader(image_datasets_split[1], batch_size=batch_size, shuffle=shuffle, num_workers=num_workers, pin_memory=pin_memory, drop_last=drop_last),
        "eval": DataLoader(image_datasets_split[0], batch_size=batch_size, shuffle=shuffle, num_workers=num_workers, pin_memory=pin_memory, drop_last=drop_last), 
    }
'''

In [None]:
class LBG_layer(nn.Module):
    """Fully connected block: ``Linear -> BatchNorm1d -> ReLU``.

    Args:
        in_feature: Size of each input vector.
        out_feature: Size of each output vector.
    """

    def __init__(self, in_feature, out_feature):
        super().__init__()
        self.in_feature, self.out_feature = in_feature, out_feature

        stages = [
            nn.Linear(self.in_feature, self.out_feature),
            nn.BatchNorm1d(self.out_feature),
            # Alternative activation left disabled by the author: nn.GELU()
            nn.ReLU(),
        ]
        self.lbg_conv = nn.Sequential(*stages)

    def forward(self, x):
        """Apply the three stages to ``x`` and return the result."""
        return self.lbg_conv(x)

In [None]:
class LBS_layer(nn.Module):
    """Fully connected block: ``Linear -> BatchNorm1d -> Sigmoid``.

    Args:
        in_feature: Size of each input vector.
        out_feature: Size of each output vector.
    """

    def __init__(self, in_feature, out_feature):
        super().__init__()
        self.in_feature, self.out_feature = in_feature, out_feature

        stages = [
            nn.Linear(self.in_feature, self.out_feature),
            nn.BatchNorm1d(self.out_feature),
            # Alternative activation left disabled by the author: nn.GELU()
            nn.Sigmoid(),
        ]
        self.lbs_conv = nn.Sequential(*stages)

    def forward(self, x):
        """Apply the three stages to ``x`` and return the result."""
        return self.lbs_conv(x)

In [None]:
class CBG_layer(nn.Module):
    """Convolutional block: ``Conv2d -> BatchNorm2d -> ReLU``.

    Args:
        in_feature: Number of input channels.
        out_feature: Number of output channels.
        kernel_size: Convolution kernel size.
        stride: Convolution stride.
        padding: Zero padding added by the convolution (default 0).
    """

    def __init__(self, in_feature, out_feature, kernel_size, stride, padding=0):
        super().__init__()
        self.in_feature, self.out_feature = in_feature, out_feature
        self.kernel_size, self.stride, self.padding = kernel_size, stride, padding

        convolution = nn.Conv2d(
            self.in_feature,
            self.out_feature,
            self.kernel_size,
            self.stride,
            self.padding,
        )
        normalisation = nn.BatchNorm2d(self.out_feature)
        # Alternative activation left disabled by the author: nn.GELU()
        activation = nn.ReLU()
        self.cbg_conv = nn.Sequential(convolution, normalisation, activation)

    def forward(self, x):
        """Apply the three stages to ``x`` and return the result."""
        return self.cbg_conv(x)

In [None]:
class CBS_layer(nn.Module):
    """Convolutional block: ``Conv2d -> BatchNorm2d -> Sigmoid``.

    Args:
        in_feature: Number of input channels.
        out_feature: Number of output channels.
        kernel_size: Convolution kernel size.
        stride: Convolution stride.
        padding: Zero padding added by the convolution (default 0).
    """

    def __init__(self, in_feature, out_feature, kernel_size, stride, padding=0):
        super().__init__()
        self.in_feature, self.out_feature = in_feature, out_feature
        self.kernel_size, self.stride, self.padding = kernel_size, stride, padding

        convolution = nn.Conv2d(
            self.in_feature,
            self.out_feature,
            self.kernel_size,
            self.stride,
            self.padding,
        )
        normalisation = nn.BatchNorm2d(self.out_feature)
        # Alternative activation left disabled by the author: nn.GELU()
        activation = nn.Sigmoid()
        self.cbs_conv = nn.Sequential(convolution, normalisation, activation)

    def forward(self, x):
        """Apply the three stages to ``x`` and return the result."""
        return self.cbs_conv(x)

In [None]:
class DBG_layer(nn.Module):
    """Upsampling block: ``ConvTranspose2d -> BatchNorm2d -> ReLU``.

    Args:
        in_feature: Number of input channels.
        out_feature: Number of output channels.
        kernel_size: Transposed-convolution kernel size.
        stride: Transposed-convolution stride.
        padding: Padding argument of the transposed convolution (default 0).
        output_padding: Extra size added to the output (default 0).
    """

    def __init__(self, in_feature, out_feature, kernel_size, stride, padding=0, output_padding=0):
        super().__init__()
        self.in_feature, self.out_feature = in_feature, out_feature
        self.kernel_size, self.stride = kernel_size, stride
        self.padding, self.output_padding = padding, output_padding

        deconvolution = nn.ConvTranspose2d(
            self.in_feature,
            self.out_feature,
            self.kernel_size,
            self.stride,
            self.padding,
            self.output_padding,
        )
        normalisation = nn.BatchNorm2d(self.out_feature)
        # Alternative activation left disabled by the author: nn.GELU()
        activation = nn.ReLU()
        self.dbg_conv = nn.Sequential(deconvolution, normalisation, activation)

    def forward(self, x):
        """Apply the three stages to ``x`` and return the result."""
        return self.dbg_conv(x)

In [None]:
class Scope_layer(nn.Module):
    """Chain of ``n`` CBG stages whose successive outputs are concatenated.

    Stage ``k`` is applied to the output of stage ``k-1`` and all ``n``
    intermediate outputs are concatenated along the channel axis, so the
    result has ``n * out_feature`` channels.  The concatenation requires
    every stage to keep the spatial size.

    Args:
        n: Number of chained stages.
        in_feature: Channels of the input tensor.
        out_feature: Channels produced by each stage.
        kernel_size, stride, padding: Forwarded to every ``CBG_layer``.
        share_weights: When True (default, the historical behaviour) one
            single ``CBG_layer`` instance is reused for all ``n`` stages, so
            every stage shares the same convolution weights and BatchNorm
            statistics, and ``in_feature`` must equal ``out_feature`` for
            ``n > 1``.  When False each stage gets its own independent
            ``CBG_layer`` (stages after the first take ``out_feature``
            input channels).
            NOTE(review): the weight sharing looks accidental (one module
            object repeated in a list) - confirm which variant is intended.
    """

    def __init__(self, n, in_feature, out_feature, kernel_size, stride, padding, share_weights=True):
        super(Scope_layer, self).__init__()

        self.n = n
        self.in_feature = in_feature
        self.out_feature = out_feature
        self.kernel_size = kernel_size
        self.stride = stride
        self.padding = padding
        self.share_weights = share_weights
        self.cbg_conv = CBG_layer(self.in_feature, self.out_feature, self.kernel_size, self.stride, self.padding)

        if self.share_weights:
            # Same module object n times: parameters are tied across stages.
            stages = [self.cbg_conv for _ in range(self.n)]
        else:
            # First stage maps in_feature -> out_feature; later stages consume
            # the previous stage's out_feature channels.
            stages = [self.cbg_conv] if self.n > 0 else []
            stages.extend(
                CBG_layer(self.out_feature, self.out_feature, self.kernel_size, self.stride, self.padding)
                for _ in range(self.n - 1)
            )
        self.scope_conv = nn.Sequential(*stages)

    def forward(self, x):
        """Return the channel-wise concatenation of every stage's output.

        With ``n == 0`` the input is returned unchanged.
        """
        collected = []
        for layer in self.scope_conv:
            x = layer(x)
            collected.append(x)
        if len(collected) <= 1:
            # Zero stages: untouched input; one stage: its output as-is.
            return x
        return torch.cat(collected, dim=1)

In [None]:
class Res_layer(nn.Module):
    """Bottleneck branch whose output is concatenated onto its input.

    The branch is ``1x1 conv -> kernel_size conv -> 1x1 conv`` (each a
    ``CBG_layer``) working at ``in_feature / 2`` channels and ending with
    ``out_feature - in_feature`` channels, so ``cat([x, branch])`` has
    ``out_feature`` channels in total.

    Args:
        in_feature: Channels of the input tensor.
        out_feature: Channels of the concatenated output.
        kernel_size: Kernel size of the middle convolution.
        stride: Stored on the instance only; every convolution in the
            branch is built with stride 1.
        padding: Padding of the middle convolution.
    """

    def __init__(self, in_feature, out_feature, kernel_size, stride, padding):
        super().__init__()

        self.in_feature, self.out_feature = in_feature, out_feature
        self.kernel_size, self.stride, self.padding = kernel_size, stride, padding

        squeezed = int(self.in_feature / 2)
        added = self.out_feature - self.in_feature

        self.cbg_conv1 = CBG_layer(self.in_feature, squeezed, 1, 1)
        self.cbg_conv2 = CBG_layer(squeezed, squeezed, self.kernel_size, 1, self.padding)
        self.cbg_conv3 = CBG_layer(squeezed, added, 1, 1)

    def forward(self, x):
        """Run the branch on ``x`` and append its output along the channel axis."""
        branch = self.cbg_conv3(self.cbg_conv2(self.cbg_conv1(x)))
        return torch.cat([x, branch], dim=1)

In [None]:
class SegNet(nn.Module):
    """Encoder/decoder segmentation network with two channel-gating branches.

    The input is expected to be ``(N, 3, H, W)`` with ``(H, W) == input_size``:
    the first gating branch average-pools with a kernel of exactly
    ``input_size`` and reshapes the result to ``(N, C)``.  The encoder halves
    the resolution four times and the decoder doubles it four times with skip
    concatenations (so ``H`` and ``W`` presumably need to be multiples of 16).
    The output is ``(N, 2, H, W)`` after a softmax over the channel axis.

    Args:
        input_size: ``(height, width)`` of the images fed to the network.
        device: Stored on the instance; not used by the layers themselves.
    """

    def __init__(self, input_size, device):
        super().__init__()
        self.input_size = input_size
        self.device = device

        # NOTE: sub-modules are created in a fixed order so that seeded
        # parameter initialisation stays the same.

        # --- stem: 3 -> 6 channels, then 4 chained 6-channel stages (24) ----
        self.cbg_conv1 = CBG_layer(3, 6, 3, 1, 1)
        self.scope_conv1 = Scope_layer(4, 6, 6, 3, 1, 1)

        # --- input gate: global average of a 30-channel projection ----------
        self.cbg_conv2 = CBG_layer(3, 30, 1, 1)
        self.pool1 = nn.AvgPool2d((input_size[0], input_size[1]), (1, 1), count_include_pad=False)
        self.lbg_conv1 = LBG_layer(30, 5)
        self.lbg_conv2 = LBG_layer(5, 30)

        # --- encoder ----------------------------------------------------------
        self.cbg_conv3 = CBG_layer(30, 32, 5, 1, 2)
        self.pool2 = nn.MaxPool2d(2, 2)

        self.res_conv1 = Res_layer(32, 64, 3, 1, 1)
        self.pool3 = nn.MaxPool2d(2, 2)

        self.res_conv2 = Res_layer(64, 128, 3, 1, 1)
        self.pool4 = nn.MaxPool2d(2, 2)

        self.res_conv3 = Res_layer(128, 256, 3, 1, 1)
        self.pool5 = nn.MaxPool2d(2, 2)

        self.cbg_conv4 = CBG_layer(256, 256, 5, 1, 2)

        # --- decoder (input channels include the concatenated skip) ----------
        self.d_conv1 = DBG_layer(256, 128, 3, 2, 1, 1)
        self.cbg_conv5 = CBG_layer(256, 256, 5, 1, 2)
        self.d_conv2 = DBG_layer(256, 64, 3, 2, 1, 1)
        self.cbg_conv6 = CBG_layer(128, 128, 5, 1, 2)
        self.d_conv3 = DBG_layer(128, 32, 3, 2, 1, 1)
        self.cbg_conv7 = CBG_layer(64, 64, 5, 1, 2)
        self.d_conv4 = DBG_layer(64, 16, 3, 2, 1, 1)

        # 46 = 16 decoder channels + 30 gated stem channels
        self.cbg_conv8 = CBG_layer(46, 32, 3, 1, 1)
        self.cbg_conv9 = CBG_layer(32, 8, 3, 1, 1)

        # --- output gate: global average of the 1/16-resolution bottleneck ---
        self.pool6 = nn.AvgPool2d((int(input_size[0]/16), int(input_size[1]/16)), (1, 1), count_include_pad=False)
        self.lbg_conv3 = LBG_layer(256, 8)

        # --- head -------------------------------------------------------------
        self.cbg_conv10 = CBG_layer(8, 8, 3, 1, 1)
        self.cbg_conv11 = CBG_layer(8, 2, 3, 1, 1)

    def forward(self, x):
        """Map an image batch to per-pixel two-channel softmax scores."""
        # Stem: 6 channels plus the 24 channels of the chained stages -> 30.
        stem = self.cbg_conv1(x)
        stem = torch.cat([stem, self.scope_conv1(stem)], dim=1)

        # Input gate: one weight per stem channel, computed from the raw image.
        gate = self.cbg_conv2(x)
        gate = self.pool1(gate).view(gate.shape[0], gate.shape[1])
        gate = self.lbg_conv2(self.lbg_conv1(gate))
        skip0 = stem * gate.view(gate.shape[0], gate.shape[1], 1, 1)

        # Encoder: each level is kept for the matching decoder skip.
        skip1 = self.pool2(self.cbg_conv3(skip0))
        skip2 = self.pool3(self.res_conv1(skip1))
        skip3 = self.pool4(self.res_conv2(skip2))
        bottleneck = self.cbg_conv4(self.pool5(self.res_conv3(skip3)))

        # Decoder: upsample, concatenate the skip, refine.
        up = self.d_conv1(bottleneck)
        up = self.cbg_conv5(torch.cat([up, skip3], dim=1))
        up = self.d_conv2(up)
        up = self.cbg_conv6(torch.cat([up, skip2], dim=1))
        up = self.d_conv3(up)
        up = self.cbg_conv7(torch.cat([up, skip1], dim=1))
        up = self.d_conv4(up)
        up = self.cbg_conv8(torch.cat([up, skip0], dim=1))
        up = self.cbg_conv9(up)

        # Output gate: one weight per decoder channel, from the bottleneck.
        scale = self.pool6(bottleneck).view(bottleneck.shape[0], bottleneck.shape[1])
        scale = self.lbg_conv3(scale)
        up = up * scale.view(scale.shape[0], scale.shape[1], 1, 1)

        # Head: reduce to two channels and normalise across them.
        up = self.cbg_conv11(self.cbg_conv10(up))
        return nn.Softmax(dim=1)(up)

In [None]:
class evaluator():
    """Loss and accuracy metrics for one batch of segmentation outputs.

    Args:
        outputs: Predicted probabilities, 4-D ``(N, C, H, W)``; channel 0 is
            the one scored by the IoU metrics.
        labels: Target tensor of the same shape holding 0/1 values.
    """

    def __init__(self, outputs, labels):
        self.outputs = outputs
        self.labels = labels
        self.shape = self.outputs.shape

    def loss_fn(self):
        """Return the binary cross-entropy between outputs and labels."""
        loss_value = nn.BCELoss()(self.outputs, self.labels)
        return loss_value

    @staticmethod
    def _smoothed_iou(pred, target, value):
        """IoU of the pixels equal to ``value``, with +1 smoothing on both terms."""
        pred_hit = pred == value
        target_hit = target == value
        intersection = torch.sum(pred_hit & target_hit).item()
        union = torch.sum(pred_hit).item() + torch.sum(target_hit).item() - intersection
        return (intersection + 1) / (union + 1)

    def acc_fn(self):
        """Compute pixel accuracy and IoU scores for the batch.

        The outputs are thresholded at 0.5 on a private copy, so neither
        ``self.outputs`` nor the tensor handed in by the caller is modified.
        (Thresholding in place breaks a later ``loss.backward()`` on CPU,
        where ``.cpu()`` returns the very same tensor.)

        Returns:
            Tuple ``(acc, acc_mIou, acc_miou_p, acc_miou_n)``:
            fraction of matching elements over all channels; smoothed IoU on
            channel 0 averaged over the label values present and the batch;
            batch-mean smoothed IoU of value 1; batch-mean smoothed IoU of
            value 0.
        """
        preds = self.outputs.detach().cpu().clone()
        labels = self.labels.detach().cpu()
        preds[preds >= 0.5] = 1.
        preds[preds < 0.5] = 0.

        acc = torch.sum(preds == labels).item() / (self.shape[0]*self.shape[1]*self.shape[2]*self.shape[3])

        batch = self.shape[0]
        # Label values present anywhere in the batch (computed once).
        classes = np.unique(labels.numpy())

        acc_mIou = 0.
        acc_miou_p = 0.
        acc_miou_n = 0.
        for j in range(batch):
            pred_j = preds[j, 0, :, :]
            label_j = labels[j, 0, :, :]
            for value in classes:
                acc_mIou += self._smoothed_iou(pred_j, label_j, value)
            acc_miou_p += self._smoothed_iou(pred_j, label_j, 1)
            acc_miou_n += self._smoothed_iou(pred_j, label_j, 0)

        acc_mIou /= (classes.shape[0] * batch)
        acc_miou_p /= batch
        acc_miou_n /= batch
        return acc, acc_mIou, acc_miou_p, acc_miou_n

In [None]:
class model_train():
    """Minimal train/eval loop that keeps the best-scoring weights.

    Args:
        model: Network to optimise.
        dataloader: Mapping from phase name to an iterable of
            ``(inputs, labels)`` batches.
        evaluator: Callable ``evaluator(outputs, labels)`` returning an object
            with ``loss_fn()`` and ``acc_fn()``; ``acc_fn()`` must return
            ``(acc, acc_mIou, acc_miou_p, acc_miou_n)``.
        optimizer: Optimiser stepping the model parameters.
        num_epoch: Number of epochs to run.
        state: Sequence of phase names run in order every epoch.  ``"train"``
            enables gradients and optimiser steps; ``"eval"`` records history
            and drives best-model selection.
        device: Device the batches are moved to.
    """

    def __init__(self, model, dataloader, evaluator, optimizer, num_epoch, state, device):
        self.model = model 
        self.dataloader = dataloader
        self.evaluator = evaluator
        self.optimizer = optimizer
        self.num_epoch = num_epoch
        self.state = state
        self.device = device
        self.running_loss = 0.
        self.running_acc = 0.
        self.running_acc_mIou = 0.
        self.running_acc_miou_p = 0.
        self.running_acc_miou_n = 0.
        self.best_acc = 0.
        self.acc_history_s = []
        self.acc_history = []
        self.acc_mIou_history_s = []
        self.acc_mIou_history = []

    def train(self):
        """Run all epochs and return the model with its best weights loaded.

        Returns:
            ``(model, (acc_history, acc_mIou_history, acc_history_s,
            acc_mIou_history_s))``.  The first two lists hold one value per
            "eval" phase; the ``*_s`` lists hold the running (batch-size
            weighted, cumulative within a phase) sums after every batch.
        """
        # Stays None until an "eval" phase improves on best_acc, so a run
        # without evaluation simply returns the trained model.
        best_model_wts = None

        # Use the values given to the constructor, not notebook globals.
        for epoch in range(self.num_epoch):
            for state in self.state:
                is_train = state == "train"
                self.running_loss = 0.
                self.running_acc = 0.
                self.running_acc_mIou = 0.
                self.running_acc_miou_p = 0.
                self.running_acc_miou_n = 0.
                # Samples actually processed; differs from len(dataset) when
                # the loader drops the last incomplete batch.
                num_samples = 0

                if is_train:
                    self.model.train()
                else:
                    self.model.eval()

                for inputs, labels in self.dataloader[state]:
                    inputs, labels = inputs.to(self.device), labels.to(self.device)
                    batch_size = inputs.size(0)

                    with torch.set_grad_enabled(is_train):
                        outputs = self.model(inputs)
                        loss = self.evaluator(outputs, labels).loss_fn()

                    if is_train:
                        self.optimizer.zero_grad()
                        loss.backward()
                        self.optimizer.step()

                    # Metrics are computed after the backward pass and on a
                    # detached tensor, so an evaluator that edits `outputs`
                    # in place cannot disturb gradient computation.
                    with torch.no_grad():
                        acc, acc_mIou, acc_miou_p, acc_miou_n = self.evaluator(outputs.detach(), labels).acc_fn()

                    num_samples += batch_size
                    self.running_loss += loss.item() * batch_size
                    self.running_acc += acc * batch_size
                    self.running_acc_mIou += acc_mIou * batch_size
                    self.running_acc_miou_p += acc_miou_p * batch_size
                    self.running_acc_miou_n += acc_miou_n * batch_size
                    self.acc_history_s.append(self.running_acc)
                    self.acc_mIou_history_s.append(self.running_acc_mIou)

                denominator = max(num_samples, 1)  # avoid 0/0 on an empty loader
                epoch_loss = self.running_loss / denominator
                epoch_acc = self.running_acc / denominator
                epoch_acc_mIou = self.running_acc_mIou / denominator
                epoch_acc_miou_p = self.running_acc_miou_p / denominator
                epoch_acc_miou_n = self.running_acc_miou_n / denominator
                print("Epoch {}/{} __ Phase {} loss: {:.4f}, acc: {:.4f}, acc_mIou: {:.4f}, acc_miou_p: {:.4f}, acc_miou_n: {:.4f}". format(epoch+1, self.num_epoch, state, epoch_loss, epoch_acc, epoch_acc_mIou, epoch_acc_miou_p, epoch_acc_miou_n))

                if state == "eval":
                    print("{}".format("_"*45))
                    self.acc_history.append(epoch_acc)
                    self.acc_mIou_history.append(epoch_acc_mIou)
                    # Model selection uses the positive-class IoU.
                    if epoch_acc_miou_p > self.best_acc:
                        self.best_acc = epoch_acc_miou_p
                        best_model_wts = copy.deepcopy(self.model.state_dict())

        if best_model_wts is not None:
            self.model.load_state_dict(best_model_wts)
        return self.model, (self.acc_history, self.acc_mIou_history, self.acc_history_s, self.acc_mIou_history_s)

In [None]:
# --- Configuration -----------------------------------------------------------
# Dataset location; set the BSDATA_ROOT environment variable to override the
# author's local default.
path = os.environ.get("BSDATA_ROOT", r"C:\Users\风\Desktop\表面缺陷检测\BSData-main\BSData-main")
state = ["train", "eval"]
input_size = (256, 512)  # (height, width) the images are resized to for training
batch_size = 2
shuffle = True
num_workers = 0
pin_memory = torch.cuda.is_available()  # pinning only helps host -> GPU copies
drop_last = True

lr = 1e-3
beta1 = 0.9
beta2 = 0.999
epsilon = 1e-8
weight_decay = 0.
amsgrad = False
num_epoch = 50  # number of epochs

# --- Data --------------------------------------------------------------------
transform = transforms.Compose([
        transforms.Resize(input_size),
        transforms.ToTensor()
])

image_datasets = dataset_defectDetection(path, transform=transform)
torch.manual_seed(7)  # make the train/eval split repeatable
num_eval = len(image_datasets) // 5
image_datasets_split = random_split(image_datasets, [num_eval, len(image_datasets) - num_eval])

loader = {
    # Training drops the last incomplete batch: the model's BatchNorm1d layers
    # cannot compute batch statistics from a single sample in train mode.
    "train": DataLoader(image_datasets_split[1], batch_size=batch_size, shuffle=shuffle, num_workers=num_workers, pin_memory=pin_memory, drop_last=drop_last),
    # Evaluation sees every sample, in a fixed order.
    "eval": DataLoader(image_datasets_split[0], batch_size=batch_size, shuffle=False, num_workers=num_workers, pin_memory=pin_memory, drop_last=False),
}

# --- Model and training ------------------------------------------------------
model = SegNet(input_size, device).to(device)
optimizer = torch.optim.Adam(filter(lambda param: param.requires_grad, model.parameters()), lr=lr, betas=(beta1, beta2), eps=epsilon, weight_decay=weight_decay, amsgrad=amsgrad)
# Keep the trainer under its own name instead of rebinding `model` to it.
trainer = model_train(model, loader, evaluator, optimizer, num_epoch, state, device)
model, Acc = trainer.train()

# --- Validation curves -------------------------------------------------------
plt.title("Validation Accuracy vs Number of Training Epochs")
plt.xlabel("Training Epochs")
plt.ylabel("Validation Accuracy")
# Labels are required for plt.legend() to have anything to show.
plt.plot(range(1, len(Acc[0])+1), Acc[0], label="pixel accuracy")
plt.plot(range(1, len(Acc[1])+1), Acc[1], label="mIoU")
plt.ylim((0, 1))
plt.xticks(np.arange(1, len(Acc[0])+1, 5))
plt.legend()
plt.show()