In [2]:
#加载一些基础的库
import torch
import os
import numpy as np
import torchvision
from tqdm import tqdm #一个实现进度条的库
import random

In [6]:
from torch.utils.data import Dataset
from torchvision import transforms
from torchvision.utils import save_image
from PIL import Image
import cv2
import albumentations as A

totensor =transforms.Compose({
    #转化为Tensor
    transforms.ToTensor()
})

#数据增强
transform = A.Compose(
    [
        A.HorizontalFlip(p=0.5),#水平翻转
        A.VerticalFlip(p=0.5),#垂直翻转
        A.OneOf(#随机选择一个OneOf
            [
                A.RandomGamma(p=1),#随机伽马变换
                A.RandomBrightnessContrast(p=1),#随机亮度
                A.Blur(p=1),#模糊
                A.OpticalDistortion(p=1),#光学畸变
            ],
            p=0.5,
        ),
        A.OneOf(
            [
                A.ElasticTransform(p=1),#弹性变换
                A.GridDistortion(p=1),#网格失真
                A.MotionBlur(p=1),#运动模糊
                A.HueSaturationValue(p=1),#色调，饱和度值随机变化
            ],
            p=0.5,
        ),
    ]
)

#首先继承Dataset写一个对于数据进行读入和处理的方式
class MyDataset(Dataset):
    def __init__(self,path):
        self.mode=('train' if 'mask' in os.listdir(path) else 'test')#表示训练模式
        self.path=path#图片路径
        dirlist=os.listdir(path+'image/')#图片的名称
        self.name=[n for n in dirlist if n[-3:]=='png'] #只读取图片
        
    def __len__(self):
        return len(self.name)
    
    def __getitem__(self,index):#获取数据的处理方式
        name=self.name[index]
        #读取原始图片和标签
        if self.mode=='train':#训练模式
            ori_img=cv2.imread(self.path+'image/'+name)#原始图片
            lb_img=cv2.imread(self.path+'mask/'+name)#标签图片
            ori_img=cv2.cvtColor(ori_img,cv2.COLOR_BGR2RGB)#转为RGB三通道图
            lb_img=cv2.cvtColor(lb_img,cv2.COLOR_BGR2GRAY)#掩膜转为灰度图
            transformed = transform(image=ori_img,mask=lb_img)
            return totensor(transformed['image']),totensor(transformed['mask'])
        
        if self.mode=='test':#测试模式
            ori_img=cv2.imread(self.path+'image/'+name)#原始图片
            ori_img=cv2.cvtColor(ori_img,cv2.COLOR_BGR2RGB)#转为RGB三通道图
            return totensor(ori_img)

#加载数据集
train_path=r'D:\暑期考核\train/'
traindata=MyDataset(train_path)
test_path=''

In [7]:
#查看图片读取效果
import matplotlib.pyplot as plt
o_img,l_img=traindata[np.random.randint(0,2000)]
plt.subplot(1,2,1)
plt.imshow(o_img.permute(1,2, 0))
plt.subplot(1,2,2)
plt.imshow(l_img.permute(1,2, 0))
print("原始图片张量的形状:",o_img.shape)
print("标签图片张量的形状:",l_img.shape)#([1, 320, 640]) 其中 1 表示分类类别，我们为2分类任务,类别表示为01

error: OpenCV(4.10.0) D:\a\opencv-python\opencv-python\opencv\modules\imgproc\src\color.cpp:196: error: (-215:Assertion failed) !_src.empty() in function 'cv::cvtColor'


In [4]:
#配置模型超参数
#模型保存的路径
model_path='models/'
if not os.path.exists(model_path):
    os.makedirs(model_path)
#推荐使用gpu进行训练
device=torch.device("cuda" if torch.cuda.is_available() else "cpu")
#学习率
lr=1e-3
#学习率衰减
weight_decay=1e-4
#批大小
bs=8
#训练轮次
epochs=100




In [5]:
import torchvision
import torch.nn as nn
import segmentation_models_pytorch as smp
'''
加载经典用于医学图像分割的UNet,encoder_name为模型的backbone
encoder_weigths可选imagenet或者None代表是否加载预训练参数
in_channel为输入图像的通道数
classes为分类数目
'''
model = smp.Unet(
        encoder_name="resnet50",  
        encoder_weights='imagenet',
        in_channels=3,
        classes=1,
        #activation="sigmoid",
    )
#打印模型信息
#print(model) 




In [6]:
#训练前准备
from torch.utils.data import DataLoader
#加载模型到gpu或cpu
model.to(device)
#使用Binary CrossEntropy作为损失函数，主要处理二分类问题
# BCEloss=nn.BCELoss()
#加载优化器,使用Adam,主要是炼的快(๑ت๑)
optim=torch.optim.Adam(model.parameters(),lr=lr, weight_decay=weight_decay)
#学习率调整机制
scheduler = torch.optim.lr_scheduler.StepLR(optim,step_size=10,gamma=0.5)
#使用traindata创建dataloader对象
trainloader=DataLoader(traindata,batch_size=bs, shuffle=True, num_workers=0)
#根据赛题评测选用dice_loss，这个是开源代码
def dice_loss(logits, target):
    smooth = 1.
    prob  = torch.sigmoid(logits)
    batch = prob.size(0)
    prob   = prob.view(batch,1,-1)
    target = target.view(batch,1,-1)
    intersection = torch.sum(prob*target, dim=2)
    denominator  = torch.sum(prob, dim=2) + torch.sum(target, dim=2)
    dice = (2*intersection + smooth) / (denominator + smooth)
    dice = torch.mean(dice)
    dice_loss = 1. - dice
    return dice_loss

In [7]:
#开始炼丹 没有做验证集，各位可以以自己需要去添加
loss_last=99999
best_model_name='x'
#记录loss变化
for epoch in range(1,epochs+1):
    for step,(inputs,labels) in tqdm(enumerate(trainloader),desc=f"Epoch {epoch}/{epochs}",
                                       ascii=True, total=len(trainloader)):
        #原始图片和标签
        inputs, labels = inputs.to(device), labels.to(device)
        out = model(inputs)
        loss = dice_loss(out, labels)
        # 后向
        optim.zero_grad()
        #梯度反向传播
        loss.backward()
        optim.step()
    scheduler.step()
    #损失小于上一轮则添加
    if loss<loss_last:
        loss_last=loss
        torch.save(model.state_dict(),model_path+'model_epoch{}_loss{}.pth'.format(epoch,loss))
        best_model_name=model_path+'model_epoch{}_loss{}.pth'.format(epoch,loss)
    print(f"\nEpoch: {epoch}/{epochs},DiceLoss:{loss}")

Epoch 1/100: 100%|##########| 250/250 [01:28<00:00,  2.83it/s]



Epoch: 1/100,DiceLoss:0.10695207118988037


Epoch 2/100: 100%|##########| 250/250 [01:24<00:00,  2.95it/s]



Epoch: 2/100,DiceLoss:0.13612818717956543


Epoch 3/100: 100%|##########| 250/250 [01:25<00:00,  2.94it/s]



Epoch: 3/100,DiceLoss:0.09359920024871826


Epoch 4/100: 100%|##########| 250/250 [01:23<00:00,  2.99it/s]



Epoch: 4/100,DiceLoss:0.08488637208938599


Epoch 5/100: 100%|##########| 250/250 [01:24<00:00,  2.96it/s]



Epoch: 5/100,DiceLoss:0.07217812538146973


Epoch 6/100: 100%|##########| 250/250 [01:25<00:00,  2.93it/s]



Epoch: 6/100,DiceLoss:0.12712466716766357


Epoch 7/100: 100%|##########| 250/250 [01:24<00:00,  2.96it/s]



Epoch: 7/100,DiceLoss:0.08231562376022339


Epoch 8/100: 100%|##########| 250/250 [01:26<00:00,  2.91it/s]



Epoch: 8/100,DiceLoss:0.09928834438323975


Epoch 9/100: 100%|##########| 250/250 [01:25<00:00,  2.92it/s]



Epoch: 9/100,DiceLoss:0.08739686012268066


Epoch 10/100: 100%|##########| 250/250 [01:24<00:00,  2.95it/s]



Epoch: 10/100,DiceLoss:0.09424811601638794


Epoch 11/100: 100%|##########| 250/250 [01:21<00:00,  3.05it/s]



Epoch: 11/100,DiceLoss:0.11943531036376953


Epoch 12/100: 100%|##########| 250/250 [01:21<00:00,  3.08it/s]



Epoch: 12/100,DiceLoss:0.08157360553741455


Epoch 13/100: 100%|##########| 250/250 [01:20<00:00,  3.10it/s]



Epoch: 13/100,DiceLoss:0.07318770885467529


Epoch 14/100: 100%|##########| 250/250 [01:20<00:00,  3.09it/s]



Epoch: 14/100,DiceLoss:0.08388876914978027


Epoch 15/100: 100%|##########| 250/250 [01:20<00:00,  3.09it/s]



Epoch: 15/100,DiceLoss:0.0740659236907959


Epoch 16/100: 100%|##########| 250/250 [01:28<00:00,  2.83it/s]



Epoch: 16/100,DiceLoss:0.09631085395812988


Epoch 17/100: 100%|##########| 250/250 [01:21<00:00,  3.08it/s]



Epoch: 17/100,DiceLoss:0.07004344463348389


Epoch 18/100: 100%|##########| 250/250 [01:21<00:00,  3.05it/s]



Epoch: 18/100,DiceLoss:0.0829918384552002


Epoch 19/100: 100%|##########| 250/250 [01:21<00:00,  3.08it/s]



Epoch: 19/100,DiceLoss:0.09192550182342529


Epoch 20/100: 100%|##########| 250/250 [01:24<00:00,  2.95it/s]



Epoch: 20/100,DiceLoss:0.08303797245025635


Epoch 21/100: 100%|##########| 250/250 [01:25<00:00,  2.94it/s]



Epoch: 21/100,DiceLoss:0.09928947687149048


Epoch 22/100: 100%|##########| 250/250 [01:25<00:00,  2.94it/s]



Epoch: 22/100,DiceLoss:0.08781594038009644


Epoch 23/100: 100%|##########| 250/250 [01:25<00:00,  2.92it/s]



Epoch: 23/100,DiceLoss:0.08189541101455688


Epoch 24/100: 100%|##########| 250/250 [01:25<00:00,  2.93it/s]



Epoch: 24/100,DiceLoss:0.08340150117874146


Epoch 25/100: 100%|##########| 250/250 [01:25<00:00,  2.92it/s]



Epoch: 25/100,DiceLoss:0.0700230598449707


Epoch 26/100: 100%|##########| 250/250 [01:25<00:00,  2.91it/s]



Epoch: 26/100,DiceLoss:0.0701669454574585


Epoch 27/100: 100%|##########| 250/250 [01:26<00:00,  2.90it/s]



Epoch: 27/100,DiceLoss:0.08356636762619019


Epoch 28/100: 100%|##########| 250/250 [01:25<00:00,  2.91it/s]



Epoch: 28/100,DiceLoss:0.08394801616668701


Epoch 29/100: 100%|##########| 250/250 [01:27<00:00,  2.85it/s]



Epoch: 29/100,DiceLoss:0.09484469890594482


Epoch 30/100: 100%|##########| 250/250 [01:27<00:00,  2.87it/s]



Epoch: 30/100,DiceLoss:0.08372056484222412


Epoch 31/100: 100%|##########| 250/250 [01:25<00:00,  2.91it/s]



Epoch: 31/100,DiceLoss:0.08470964431762695


Epoch 32/100: 100%|##########| 250/250 [01:25<00:00,  2.92it/s]



Epoch: 32/100,DiceLoss:0.06603515148162842


Epoch 33/100: 100%|##########| 250/250 [01:25<00:00,  2.93it/s]



Epoch: 33/100,DiceLoss:0.08943235874176025


Epoch 34/100: 100%|##########| 250/250 [01:25<00:00,  2.92it/s]



Epoch: 34/100,DiceLoss:0.07831847667694092


Epoch 35/100:  62%|######2   | 156/250 [00:54<00:32,  2.87it/s]


KeyboardInterrupt: 

In [1]:
#加载最优模型
model.load_state_dict(torch.load('models\model_epoch32_loss0.06603515148162842.pth'))
#加载测试集
test_path='test/'
testdata=MyDataset(test_path)
#测试模型的预测效果
x=np.random.randint(0,500)
inputs=testdata[x].to(device)
with torch.no_grad():
    # 模型预测
    t = model(inputs.view(1,3,320,640))
plt.subplot(1,2,1)
plt.imshow(testdata[x].permute(1,2,0))
#对预测的图片采取一定的阈值进行分类
threshold=0.5
t= torch.where(t >=threshold, torch.tensor(255,dtype=torch.float).to(device), t)
t= torch.where(t < threshold, torch.tensor(0,dtype=torch.float).to(device), t)
t=t.cpu().view(1,320,640)
plt.subplot(1,2,2)
plt.imshow(t.permute(1,2,0))

NameError: name 'model' is not defined

In [9]:
from torchvision.utils import save_image
from PIL import Image

img_save_path='infers/'
if not os.path.exists(img_save_path):
    os.makedirs(img_save_path)
for i,inputs in tqdm(enumerate(testdata)):
    #原始图片和标签
    inputs=inputs.reshape(1,3,320,640).to(device)
    # 输出生成的图像
    out = model(inputs.view(1,3,320,640)) # 模型预测
    #TTA
    inputs1 = inputs.flip(dims=[2]).to(device)
    inputs2 = inputs.flip(dims=[3]).to(device)
    inputs3 = inputs.flip(dims=[2,3]).to(device)
    
    out1 = model(inputs1).flip(dims=[2])
    out2 = model(inputs2).flip(dims=[3])
    out3 = model(inputs3).flip(dims=[2,3])

    out = (out+out1+out2+out3) / 4

    #对输出的图像进行后处理
    threshold=0.5
    out= torch.where(out >=threshold, torch.tensor(255,dtype=torch.float).to(device),out)
    out= torch.where(out < threshold, torch.tensor(0,dtype=torch.float).to(device),out)
    #保存图像
    out= out.detach().cpu().numpy().reshape(1,320,640)
    #注意保存为1位图提交
    img = Image.fromarray(out[0].astype(np.uint8))
    img = img.convert('1')
    img.save(img_save_path + testdata.name[i])

500it [00:13, 36.58it/s]


In [None]:
#对保存的图像进行打包
import zipfile

def zip_files(file_paths, output_path):
    with zipfile.ZipFile(output_path, 'w', zipfile.ZIP_DEFLATED) as zipf:
        for file in file_paths:
            zipf.write(file)
            
#打包图片
file_paths = [img_save_path+i for i in os.listdir(img_save_path) if i[-3:]=='png']
output_path = 'infer.zip'
zip_files(file_paths, output_path)