In [None]:
import os
import cv2
import albumentations as A
import torchvision.transforms as T
from torch.utils.data import Dataset


# Converts a NumPy image (as produced by OpenCV / albumentations) into a torch
# tensor via torchvision's ToTensor. Used for both images and masks in MyDataset.
totensor = T.Compose(
    [
        T.ToTensor(),
    ]
)
# Training-time augmentation pipeline. MyDataset.__getitem__ calls it with
# `image=` and `mask=` so both receive the same augmentation call.
transform = A.Compose(
    [
        A.HorizontalFlip(p=0.5),  # horizontal flip, applied with probability 0.5
        A.VerticalFlip(p=0.5),  # vertical flip, applied with probability 0.5
        # First group (entered with probability 0.5): one transform out of the list.
        A.OneOf(
            [
                A.RandomGamma(p=1),  # random gamma adjustment
                A.RandomBrightnessContrast(p=1),  # random brightness / contrast
                A.Blur(p=1),  # blur
                A.OpticalDistortion(p=1),  # optical distortion
            ],
            p=0.5,
        ),
        # Second group (entered with probability 0.5): one transform out of the list.
        A.OneOf(
            [
                A.ElasticTransform(p=1),  # elastic transform
                A.GridDistortion(p=1),  # grid distortion
                A.MotionBlur(p=1),  # motion blur
                A.HueSaturationValue(p=1),  # random hue / saturation / value shift
            ],
            p=0.5,
        ),
    ]
)


class MyDataset(Dataset):
    """Segmentation dataset that reads PNG images (and masks, if present) from disk.

    Expected layout under ``path``::

        path/image/<name>.png   input images
        path/mask/<name>.png    masks sharing the image file names (training only)

    The dataset is in ``"train"`` mode when ``path`` contains an entry called
    ``mask`` and in ``"test"`` mode otherwise.
    """

    def __init__(self, path):
        """Index the PNG files found in ``path/image``.

        Args:
            path: Dataset root directory. A trailing slash is optional because
                sub-paths are built with ``os.path.join``.
        """
        self.mode = "train" if "mask" in os.listdir(path) else "test"
        self.path = path
        self.image_dir = os.path.join(path, "image")
        self.mask_dir = os.path.join(path, "mask")
        # os.listdir returns entries in arbitrary order; sorting makes the
        # index -> file mapping stable across runs and machines.
        self.name = sorted(
            n for n in os.listdir(self.image_dir) if n.endswith(".png")
        )

    def __len__(self):
        """Return the number of indexed PNG images."""
        return len(self.name)

    @staticmethod
    def _read_bgr(file_path):
        """Read an image with OpenCV, failing loudly if it cannot be decoded."""
        img = cv2.imread(file_path)
        if img is None:
            # cv2.imread signals failure by returning None instead of raising,
            # which would otherwise surface as an opaque cvtColor assertion.
            raise FileNotFoundError(f"cannot read image: {file_path}")
        return img

    def __getitem__(self, index):
        """Load one sample.

        Returns:
            ``(image_tensor, mask_tensor)`` in train mode (after the shared
            augmentation pipeline), or ``image_tensor`` alone in test mode.

        Raises:
            IndexError: If ``index`` is out of range. Plain iteration over the
                dataset (``for x in dataset``) relies on this to terminate.
            FileNotFoundError: If an image or mask file cannot be read.
        """
        name = self.name[index]
        image = cv2.cvtColor(
            self._read_bgr(os.path.join(self.image_dir, name)), cv2.COLOR_BGR2RGB
        )
        if self.mode == "test":
            return totensor(image)

        # The mask is read as a 3-channel image and collapsed to one channel.
        mask = cv2.cvtColor(
            self._read_bgr(os.path.join(self.mask_dir, name)), cv2.COLOR_BGR2GRAY
        )
        transformed = transform(image=image, mask=mask)
        return totensor(transformed["image"]), totensor(transformed["mask"])


def get_data(
    mode,
    train_path="train/",
    test_path="test/",
):
    """Build the dataset for the requested split.

    Args:
        mode: ``"train"`` or ``"test"``.
        train_path: Root directory of the training data.
        test_path: Root directory of the test data.

    Returns:
        A ``MyDataset`` rooted at the directory matching ``mode``.

    Raises:
        ValueError: If ``mode`` is neither ``"train"`` nor ``"test"``
            (previously this silently returned ``None``).
    """
    if mode == "train":
        return MyDataset(train_path)
    if mode == "test":
        return MyDataset(test_path)
    raise ValueError(f"unknown mode {mode!r}; expected 'train' or 'test'")

In [None]:
import torch


# dice_loss
def dice_loss(prob, target):
    """Soft Dice loss, averaged over the batch.

    Each sample is flattened, its Dice coefficient is computed with a
    smoothing term of 1.0, and the loss is ``1 - mean(dice)``.

    Args:
        prob: Predictions with the batch on dimension 0. They are used as-is
            (no sigmoid is applied here), so they are expected to already be
            probabilities.
        target: Ground truth with the same number of elements per sample.

    Returns:
        A scalar tensor holding the loss.
    """
    smooth = 1.0
    batch = prob.size(0)
    # reshape instead of view: view raises on non-contiguous tensors
    # (e.g. after a transpose/permute), reshape copies only when needed.
    prob = prob.reshape(batch, -1)
    target = target.reshape(batch, -1)
    intersection = torch.sum(prob * target, dim=1)
    denominator = torch.sum(prob, dim=1) + torch.sum(target, dim=1)
    dice = (2 * intersection + smooth) / (denominator + smooth)
    return 1.0 - torch.mean(dice)


# bce_loss
def bce_loss():
    """Create and return a new ``torch.nn.BCELoss`` criterion."""
    criterion = torch.nn.BCELoss()
    return criterion


# bce_dice_loss
def bce_dice_loss(prob, target, alpha=0.2):
    """Weighted sum of binary cross-entropy and Dice loss.

    Args:
        prob: Predicted probabilities (passed unchanged to both losses).
        target: Ground-truth tensor.
        alpha: Weight of the BCE term; the Dice term is weighted ``1 - alpha``.
            Defaults to 0.2, the value that used to be hard-coded.

    Returns:
        ``alpha * BCE(prob, target) + (1 - alpha) * dice_loss(prob, target)``.
    """
    bce_term = torch.nn.BCELoss()(prob, target)
    dice_term = dice_loss(prob, target)
    return alpha * bce_term + (1 - alpha) * dice_term


def get_loss(type):
    """Look up a loss callable by name.

    Args:
        type: ``"dice"``, ``"bce"`` or ``"bce_dice"``. (The parameter name
            shadows the builtin; it is kept for caller compatibility.)

    Returns:
        A callable taking ``(prob, target)`` and returning the loss.

    Raises:
        ValueError: If ``type`` is not one of the supported names
            (previously this silently returned ``None``).
    """
    if type == "dice":
        return dice_loss
    if type == "bce":
        return torch.nn.BCELoss()
    if type == "bce_dice":
        return bce_dice_loss
    raise ValueError(
        f"unknown loss type {type!r}; expected 'dice', 'bce' or 'bce_dice'"
    )

In [None]:
import segmentation_models_pytorch as smp


def get_model(
    name,
    encoder_name="resnet50",  # other encoder names noted by the author: efficientnet-b5, se_resnext50_32x4d
    encoder_weights="imagenet",
    in_channels=3,
    classes=1,
    activation="sigmoid",
):
    """Build a segmentation_models_pytorch model by short name.

    Args:
        name: ``"deeplabv3p"``, ``"unetpp"`` or ``"unet"``.
        encoder_name: Encoder (backbone) name forwarded to the model class.
        encoder_weights: Pretrained-weight identifier forwarded to the model class.
        in_channels: Number of input channels.
        classes: Number of output channels.
        activation: Final activation forwarded to the model class.

    Returns:
        The constructed model.

    Raises:
        ValueError: If ``name`` is not a supported architecture
            (previously this silently returned ``None``).
    """
    # Short name -> class name in segmentation_models_pytorch. All three
    # classes are constructed with the same keyword arguments.
    architectures = {
        "deeplabv3p": "DeepLabV3Plus",
        "unetpp": "UnetPlusPlus",
        "unet": "Unet",
    }
    if name not in architectures:
        raise ValueError(
            f"unknown model name {name!r}; expected one of {sorted(architectures)}"
        )
    model_cls = getattr(smp, architectures[name])
    return model_cls(
        encoder_name=encoder_name,
        encoder_weights=encoder_weights,
        in_channels=in_channels,
        classes=classes,
        activation=activation,
    )


In [None]:
import os
import torch
from tqdm import tqdm
from torch.utils.data import DataLoader, Subset
from sklearn.model_selection import KFold

"""
from model import get_model
from data import get_data
from loss import get_loss
"""



@torch.no_grad()
def val(val_loader, model, device, loss_fun):
    """Compute the mean per-batch loss of ``model`` over ``val_loader``.

    The model is switched to eval mode for the duration of the pass and then
    restored to whatever mode it was in before the call, so a caller in the
    middle of training does not silently keep training in eval mode.

    Args:
        val_loader: Sized iterable yielding ``(inputs, labels)`` batches.
        model: The ``torch.nn.Module`` to evaluate.
        device: Device the batches are moved to.
        loss_fun: Callable ``(output, labels) -> scalar tensor``.

    Returns:
        The sum of batch losses divided by ``len(val_loader)`` (a float).
    """
    was_training = model.training
    model.eval()
    try:
        val_loss_total = 0
        for inputs, labels in val_loader:
            inputs = inputs.to(device)
            labels = labels.to(device)
            out = model(inputs)
            val_loss_total += loss_fun(out, labels).item()
    finally:
        # Restore the caller's train/eval mode even if a batch fails.
        model.train(was_training)
    return val_loss_total / len(val_loader)


def train(
    model_name,
    traindataset,
    valdataset,
    checkpoint_path,
    model_save_path,
    loss_name,
    epochs,
    lr=4e-3,
    weight_decay=0,
    step_size=20,
    gamma=0.5,
    batch_size=16,
):
    """Train one model and save a checkpoint whenever validation loss improves.

    Args:
        model_name: Architecture name passed to ``get_model``.
        traindataset: Dataset yielding ``(inputs, labels)`` for training.
        valdataset: Dataset yielding ``(inputs, labels)`` for validation.
        checkpoint_path: Optional state-dict file to start from; falsy to skip.
        model_save_path: Directory for checkpoints (created if missing).
        loss_name: Loss name passed to ``get_loss``.
        epochs: Number of epochs to run.
        lr: AdamW learning rate.
        weight_decay: AdamW weight decay.
        step_size: StepLR period, in epochs.
        gamma: StepLR decay factor.
        batch_size: Batch size for both loaders.

    Returns:
        None. Progress and the best checkpoint path are printed.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = get_model(model_name).to(device)
    if checkpoint_path:
        # NOTE(security): torch.load unpickles the file; only load checkpoints
        # from a trusted source.
        checkpoint = torch.load(checkpoint_path, map_location=device)
        model.load_state_dict(checkpoint)
    optim = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=weight_decay)
    scheduler = torch.optim.lr_scheduler.StepLR(optim, step_size=step_size, gamma=gamma)
    trainloader = DataLoader(traindataset, batch_size=batch_size, shuffle=True)
    # No shuffling for validation: the metric is a mean of per-batch losses, so
    # a fixed batch composition keeps epochs comparable when picking the best one.
    valloader = DataLoader(valdataset, batch_size=batch_size, shuffle=False)
    loss_f = get_loss(loss_name)
    os.makedirs(model_save_path, exist_ok=True)
    best_val_loss = float("inf")
    best_model_name = ""
    for epoch in range(1, epochs + 1):
        # Re-enter train mode every epoch: validation puts the model in eval
        # mode, and without this every epoch after the first would train with
        # frozen BatchNorm statistics and disabled dropout.
        model.train()
        train_loss_total = 0
        last_batch_loss = float("nan")
        for inputs, labels in tqdm(
            trainloader,
            desc=f"Epoch {epoch}/{epochs}",
            ascii=True,
            total=len(trainloader),
        ):
            inputs, labels = inputs.to(device), labels.to(device)
            out = model(inputs)
            loss = loss_f(out, labels)
            last_batch_loss = loss.item()
            train_loss_total += last_batch_loss
            optim.zero_grad()
            loss.backward()
            optim.step()
        scheduler.step()
        loss_train = train_loss_total / len(trainloader)
        loss_val = val(valloader, model, device, loss_f)
        # Keep a checkpoint only when validation loss beats the best so far.
        if loss_val < best_val_loss:
            best_val_loss = loss_val
            best_model_name = os.path.join(
                model_save_path,
                "epoch{}_valloss{:.5f}_trainloss{:.5f}.pth".format(
                    epoch, loss_val, loss_train
                ),
            )
            torch.save(model.state_dict(), best_model_name)
        # The last field used to print the raw tensor of the final batch under
        # the label "dice_loss", regardless of the configured loss; it is now a
        # formatted float with an accurate label.
        print(
            f"Epoch: {epoch}/{epochs},train_Loss:{loss_train:.5f},val_loss:{loss_val:.5f},last_batch_loss:{last_batch_loss:.5f}"
        )
    print(f"best model is:{best_model_name}")


def k_fold_train(
    fold_num,
    model_name,
    checkpoint_path,
    model_save_path,
    loss_name,
    epochs=100,
    lr=1e-3,
    weight_decay=0,
    step_size=20,
    gamma=0.5,
    batch_size=16,
    random_state=None,
):
    """Run K-fold cross-validation training over the training dataset.

    The dataset from ``get_data("train")`` is split into ``fold_num`` shuffled
    folds; for each fold a fresh ``train`` run is started whose checkpoints go
    to ``<model_save_path>/fold<k>/``.

    Args:
        fold_num: Number of folds.
        model_name: Architecture name forwarded to ``train``.
        checkpoint_path: Optional starting checkpoint forwarded to ``train``.
        model_save_path: Parent directory for the per-fold checkpoint folders;
            a trailing slash is optional.
        loss_name: Loss name forwarded to ``train``.
        epochs, lr, weight_decay, step_size, gamma, batch_size: Forwarded to
            ``train`` unchanged.
        random_state: Seed for the fold shuffle. ``None`` (default) keeps the
            previous behaviour of a different split on every run; pass an int
            to make the folds reproducible.

    Returns:
        None.
    """
    skf = KFold(n_splits=fold_num, shuffle=True, random_state=random_state)
    dataset = get_data("train")
    for fold_idx, (train_idx, valid_idx) in enumerate(skf.split(dataset)):
        train_dataset = Subset(dataset, train_idx)
        valid_dataset = Subset(dataset, valid_idx)
        # The empty last component keeps a trailing separator on the path, so
        # it stays usable by code that appends a file name by concatenation.
        temp_save_path = os.path.join(model_save_path, f"fold{fold_idx}", "")
        os.makedirs(temp_save_path, exist_ok=True)
        print(f"training fold {fold_idx}......")
        print(f"checkpoint is saving to {temp_save_path}")
        train(
            model_name=model_name,
            traindataset=train_dataset,
            valdataset=valid_dataset,
            checkpoint_path=checkpoint_path,
            model_save_path=temp_save_path,
            loss_name=loss_name,
            epochs=epochs,
            lr=lr,
            weight_decay=weight_decay,
            step_size=step_size,
            gamma=gamma,
            batch_size=batch_size,
        )


# Launch 5-fold cross-validation training of DeepLabV3+ with the combined
# BCE + Dice loss, starting from scratch (no checkpoint).
# NOTE(review): the save directory name mentions "se_resnext", but train()
# calls get_model(model_name) without an encoder argument, so the default
# encoder ("resnet50") is what gets built -- confirm which one is intended.
k_fold_train(
    fold_num=5,
    checkpoint_path=None,
    model_name="deeplabv3p",
    model_save_path="5_fold_deeplabv3p_with_se_resnext_bcedice/",
    loss_name="bce_dice",
    epochs=100,
    lr=1e-4,
    weight_decay=1e-4,
    step_size=20,
    gamma=0.5,
    batch_size=8,
)  # the optimizer used inside train() is AdamW

In [None]:
import os
import torch
from tqdm import tqdm
from PIL import Image
import numpy as np
import zipfile
"""
from model import get_model
from data import get_data
"""




# Directory that infer() writes the predicted mask images into.
img_save_path = "infers/"

# exist_ok=True replaces the check-then-create pair, which could race with
# another process creating the directory between the two calls.
os.makedirs(img_save_path, exist_ok=True)

def zip_files(file_paths, output_path):
    """Write every file in ``file_paths`` into a new deflate-compressed zip.

    Args:
        file_paths: Iterable of paths to add; each is stored under the name
            ``ZipFile.write`` derives from the path itself.
        output_path: Destination archive; an existing file is overwritten.
    """
    with zipfile.ZipFile(
        output_path, mode="w", compression=zipfile.ZIP_DEFLATED
    ) as archive:
        for member_path in file_paths:
            archive.write(member_path)


@torch.no_grad()
def infer(cp_path, model_name,  threshold):
    """Predict binary masks for the test set with a checkpoint ensemble and flip TTA.

    Every file under ``cp_path/<fold>/`` is loaded as a state dict into a fresh
    ``get_model(model_name)``. For each test image the prediction is the mean
    over all models and four views (original, flip along dim 2, flip along
    dim 3, both), each flipped back before averaging. Pixels whose mean is
    ``>= threshold`` become white; the rest black. Results are saved as 1-bit
    images in ``img_save_path`` under the test image's file name.

    Args:
        cp_path: Directory containing one sub-directory of checkpoints per fold.
        model_name: Architecture name passed to ``get_model``.
        threshold: Cut-off applied to the averaged prediction.

    Raises:
        ValueError: If no checkpoint is found under ``cp_path``.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    models = []
    # Sorted so the ensemble is assembled in a deterministic order.
    for fold in sorted(os.listdir(cp_path)):
        fold_dir = os.path.join(cp_path, fold)
        for model_checkpoint in sorted(os.listdir(fold_dir)):
            model = get_model(model_name).to(device)
            # NOTE(security): torch.load unpickles the file; only load
            # checkpoints from a trusted source.
            weight = torch.load(
                os.path.join(fold_dir, model_checkpoint), map_location=device
            )
            model.load_state_dict(weight)
            model.eval()
            models.append(model)
    if not models:
        raise ValueError(f"no checkpoints found under {cp_path!r}")

    # Flip axes for test-time augmentation, on a (batch, channel, H, W) tensor.
    flip_dims = ([2], [3], [2, 3])
    testdata = get_data("test")
    for i, inputs in tqdm(enumerate(testdata), total=len(testdata)):
        # unsqueeze adds the batch dimension without hard-coding the image
        # size (the original reshaped to a fixed 1x3x320x640).
        batch = inputs.unsqueeze(0).to(device)
        views = [(None, batch)] + [(dims, batch.flip(dims=dims)) for dims in flip_dims]
        prob_sum = 0
        for model in models:
            for dims, view in views:
                pred = model(view)
                # Undo the flip so all predictions are pixel-aligned.
                prob_sum = prob_sum + (pred if dims is None else pred.flip(dims=dims))
        # Divide by models * views. Dividing by len(models) alone (as before)
        # left the "mean" four times too large, so the threshold was effectively
        # applied at threshold / 4.
        mean_prob = prob_sum / (len(models) * len(views))
        # First sample, first output channel -> (H, W) boolean mask.
        mask = (mean_prob >= threshold)[0, 0].cpu().numpy()
        img = Image.fromarray(np.where(mask, 255, 0).astype(np.uint8))
        img = img.convert("1")
        img.save(os.path.join(img_save_path, testdata.name[i]))

        

# Run ensemble inference on the test set: every checkpoint found under
# cp_v2/<fold>/ is loaded into a DeepLabV3+ model, and the averaged
# prediction is binarised at 0.5 and written to `img_save_path`.
infer(
    model_name="deeplabv3p",
    cp_path="cp_v2",
    threshold=0.5,
)

# Bundle every predicted PNG from the output directory into one archive.
zip_out_path = "infer.zip"
file_paths = [
    img_save_path + filename
    for filename in os.listdir(img_save_path)
    if filename.endswith("png")
]
zip_files(file_paths, zip_out_path)

