In [1]:
# 训练unet模型
# 1.搭建unet模型
# 2.自定义loss 函数
# 3.开始训练

In [2]:
# 仍然是加载数据

In [3]:
import glob
import os

import imgaug as ia
import imgaug.augmenters as iaa
import matplotlib.pyplot as plt
import numpy as np
import torch
from imgaug.augmentables.segmaps import SegmentationMapsOnImage
from torch.utils.data import Dataset, DataLoader

  from .autonotebook import tqdm as notebook_tqdm


In [4]:
class SegmentDataset(Dataset):
    """Dataset of ``.npy`` images paired with their segmentation masks.

    Image files are discovered with the glob pattern
    ``processed/<where>/*/img_*``. The mask path for an image is obtained by
    replacing every ``img`` in the image path with ``label``.

    Args:
        where: Sub-directory of ``processed`` to read from (e.g. ``'train'``).
        seq: Optional augmentation pipeline, called as
            ``seq(image=..., segmentation_maps=...)``. A falsy value disables
            augmentation.
    """

    def __init__(self,where='train',seq=None):
        # Sorted so that the index -> file mapping is the same on every run
        # (glob itself returns files in arbitrary order).
        self.img_list = sorted(glob.glob('processed/{}/*/img_*'.format(where)))
        # Mask paths, derived with the same rule __getitem__ applies, so the two
        # lists always line up index by index.
        self.mask_list = [path.replace('img','label') for path in self.img_list]
        # Augmentation pipeline (may be None)
        self.seq = seq

    def __len__(self):
        """Return the number of image files found."""
        return len(self.img_list)

    def __getitem__(self, idx):
        """Load sample ``idx`` and return ``(image, mask)`` with a leading channel axis.

        When an augmentation pipeline was given, the same transform is applied
        to the image and to its mask.
        """
        img_file = self.img_list[idx]
        mask_file = img_file.replace('img','label')
        img = np.load(img_file)
        mask = np.load(mask_file)

        if self.seq:
            segmap = SegmentationMapsOnImage(mask, shape=mask.shape)
            # Use the pipeline stored on this instance, not a notebook-level global.
            img, segmap = self.seq(image=img, segmentation_maps=segmap)
            # Unwrap the augmented mask back into a plain array
            mask = segmap.get_arr()

        # Add a channel axis in front: (H, W) -> (1, H, W)
        return np.expand_dims(img,0) , np.expand_dims(mask,0)
        

In [5]:
# Augmentation pipeline used for the training split
# Affine step: scale range (0.8, 1.2) and rotation range (-45, 45)
affine_aug = iaa.Affine(scale=(0.8, 1.2), rotate=(-45, 45))
# Elastic deformation with the library's default settings
elastic_aug = iaa.ElasticTransformation()
seq = iaa.Sequential([affine_aug, elastic_aug])

In [6]:
# Wrap both splits in DataLoaders
batch_size = 12
num_workers = 0

# Only the training split goes through the augmentation pipeline
train_dataset = SegmentDataset(where='train', seq=seq)
test_dataset = SegmentDataset(where='test', seq=None)

# Settings shared by both loaders; only the training loader shuffles
loader_kwargs = dict(batch_size=batch_size, num_workers=num_workers)
train_loader = DataLoader(train_dataset, shuffle=True, **loader_kwargs)
test_loader = DataLoader(test_dataset, shuffle=False, **loader_kwargs)


In [7]:
# 构建模型

In [8]:
# Building block: two consecutive convolutions
class ConvBlock(torch.nn.Module):
    """Two 3x3 convolutions (padding 1, stride 1), each followed by a ReLU.

    The first convolution maps ``in_channels`` to ``out_channels``; the second
    keeps ``out_channels``. With this kernel/padding/stride combination the
    spatial size of the input is preserved.
    """

    def __init__(self,in_channels,out_channels):
        super().__init__()
        nn = torch.nn
        # Both convolutions share the same geometry
        conv_kwargs = dict(kernel_size=3, padding=1, stride=1)
        layers = [
            nn.Conv2d(in_channels=in_channels, out_channels=out_channels, **conv_kwargs),
            nn.ReLU(),
            nn.Conv2d(in_channels=out_channels, out_channels=out_channels, **conv_kwargs),
            nn.ReLU(),
        ]
        self.step = nn.Sequential(*layers)

    def forward(self,x):
        """Run ``x`` through conv -> ReLU -> conv -> ReLU."""
        return self.step(x)
    

In [9]:
class UNet(torch.nn.Module):
    """U-Net style encoder/decoder for single-channel images.

    The encoder applies four ``ConvBlock`` stages (64, 128, 256, 512 channels)
    with 2x max-pooling between them. The decoder upsamples by 2 three times,
    each time concatenating the matching encoder output along the channel axis
    before another ``ConvBlock``. A final 1x1 convolution produces one output
    channel.

    ``forward`` returns the output of that last convolution directly;
    ``self.sigmoid`` is defined but not applied (the notebook trains with
    ``BCEWithLogitsLoss``).
    """

    def __init__(self):
        super().__init__()
        nn = torch.nn

        # Encoder (contracting path)
        self.layer1 = ConvBlock(1, 64)
        self.layer2 = ConvBlock(64, 128)
        self.layer3 = ConvBlock(128, 256)
        self.layer4 = ConvBlock(256, 512)

        # Decoder (expanding path); input channels = upsampled + skip channels
        self.layer5 = ConvBlock(512 + 256, 256)
        self.layer6 = ConvBlock(256 + 128, 128)
        self.layer7 = ConvBlock(128 + 64, 64)

        # Final 1x1 convolution down to a single channel
        self.layer8 = nn.Conv2d(in_channels=64, out_channels=1, kernel_size=1, padding=0, stride=1)

        # Parameter-free helpers shared across stages
        self.maxpool = nn.MaxPool2d(kernel_size=2)
        self.upsample = nn.Upsample(scale_factor=2,mode='bilinear')
        self.sigmoid = nn.Sigmoid()

    def forward(self,x):
        """Compute the single-channel output map for a batch ``x``.

        Shape notes below are for a 1 x 256 x 256 input. Because of the three
        2x poolings, height and width presumably need to be divisible by 8 for
        the skip concatenations to line up.
        """
        # Encoder: 64x256x256 -> 128x128x128 -> 256x64x64 -> 512x32x32
        skip1 = self.layer1(x)
        skip2 = self.layer2(self.maxpool(skip1))
        skip3 = self.layer3(self.maxpool(skip2))
        features = self.layer4(self.maxpool(skip3))

        # Decoder: upsample, concatenate the matching encoder output, convolve.
        # 768x64x64 -> 256x64x64, 384x128x128 -> 128x128x128, 192x256x256 -> 64x256x256
        decoder_stages = (
            (self.layer5, skip3),
            (self.layer6, skip2),
            (self.layer7, skip1),
        )
        for conv_block, skip in decoder_stages:
            features = torch.cat([self.upsample(features), skip], dim=1)
            features = conv_block(features)

        # 1x1 convolution: 64x256x256 -> 1x256x256; no sigmoid is applied here
        return self.layer8(features)
        
        
        

In [10]:
# 测试模型

# 模型架构可视化
from torchsummary import summary

In [11]:
# Use the first CUDA device when one is available, otherwise the CPU
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")

In [12]:
# Build the network, then move it to the selected device
model = UNet()
model = model.to(device)

In [13]:
# Layer-by-layer overview for a single-channel 256x256 input
summary(model, input_size=(1, 256, 256))

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1         [-1, 64, 256, 256]             640
              ReLU-2         [-1, 64, 256, 256]               0
            Conv2d-3         [-1, 64, 256, 256]          36,928
              ReLU-4         [-1, 64, 256, 256]               0
        DoubleConv-5         [-1, 64, 256, 256]               0
         MaxPool2d-6         [-1, 64, 128, 128]               0
            Conv2d-7        [-1, 128, 128, 128]          73,856
              ReLU-8        [-1, 128, 128, 128]               0
            Conv2d-9        [-1, 128, 128, 128]         147,584
             ReLU-10        [-1, 128, 128, 128]               0
       DoubleConv-11        [-1, 128, 128, 128]               0
        MaxPool2d-12          [-1, 128, 64, 64]               0
           Conv2d-13          [-1, 256, 64, 64]         295,168
             ReLU-14          [-1, 256,



In [14]:
# 模拟输入

In [15]:
# Smoke test: one random single-channel 256x256 image through the network
dummy_shape = (1, 1, 256, 256)
random_input = torch.randn(*dummy_shape)
random_input = random_input.to(device)
output = model(random_input)

In [16]:
# Displayed as the cell result: the output should keep the input's spatial size
output.size()

torch.Size([1, 1, 256, 256])

In [17]:
# 准备训练

In [18]:
# 定义损失

In [19]:
# Binary cross-entropy computed on raw logits, averaged over all elements
loss_fn = torch.nn.BCEWithLogitsLoss(reduction='mean')

In [20]:
# 定义优化器

In [21]:
# Adam over all model parameters with a learning rate of 1e-4
optimizer = torch.optim.Adam(params=model.parameters(), lr=0.0001)

In [22]:
# 动态减少LR
from torch.optim.lr_scheduler import ReduceLROnPlateau
# 'min' mode: the monitored value is expected to decrease
scheduler = ReduceLROnPlateau(optimizer, mode='min')

In [23]:
import time

In [24]:
# Average loss over a held-out loader
def check_test_loss(loader,model):
    """Return the mean per-batch loss of ``model`` over ``loader``.

    The model is switched to evaluation mode for the duration of the pass, so
    layers such as dropout or batch norm behave deterministically. Its previous
    train/eval mode is restored afterwards. Gradients are not recorded.

    Uses the notebook-level ``loss_fn`` and ``device``.

    Args:
        loader: Iterable of ``(x, y)`` batches that also supports ``len()``.
        model: The network to evaluate.

    Returns:
        The sum of the per-batch losses divided by ``len(loader)``.

    Raises:
        ZeroDivisionError: If ``loader`` is empty.
    """
    was_training = model.training
    model.eval()
    loss = 0
    try:
        # No gradients are needed for evaluation
        with torch.no_grad():
            for x, y in loader:
                # Images
                x = x.to(device,dtype=torch.float32)
                # Labels
                y = y.to(device,dtype=torch.float32)
                # Predictions
                y_pred = model(x)
                # Accumulate this batch's loss
                loss += loss_fn(y_pred, y)
    finally:
        # Hand the model back in whatever mode the caller had it in
        model.train(was_training)
    return loss / len(loader)

In [25]:
# 使用tensorboard记录参数
from torch.utils.tensorboard import SummaryWriter

In [26]:
# TensorBoard writer; event files go to ./log
writer = SummaryWriter('./log')

In [27]:
# 训练

In [28]:
# Train for EPOCH_NUM epochs
EPOCH_NUM = 200
# Lowest test loss seen so far; +inf so the first epoch always counts as an improvement
best_test_loss = float('inf')

# Create the checkpoint directory up front, so a missing folder does not abort
# the run at the first save after a whole epoch of training.
os.makedirs('./save_model', exist_ok=True)

for epoch in range(EPOCH_NUM):
    start_time = time.time()
    # Make sure the model is in training mode, whatever evaluation code did before
    model.train()
    # Running sum of batch losses, kept as a plain Python float
    loss = 0.0
    for x, y in train_loader:
        # Clear gradients left over from the previous update
        optimizer.zero_grad()
        # Images
        x = x.to(device,dtype=torch.float32)
        # Labels
        y = y.to(device,dtype=torch.float32)
        # Predictions
        y_pred = model(x)
        # Loss for this batch
        loss_batch = loss_fn(y_pred, y)

        # Backpropagate and update the weights
        loss_batch.backward()
        optimizer.step()

        # Record the batch loss as a float so no tensor/graph is kept alive
        loss_batch = loss_batch.item()
        print(loss_batch)
        loss += loss_batch

    # Mean training loss of this epoch
    loss = loss / len(train_loader)
    # Let the scheduler lower the LR when the training loss stops decreasing
    scheduler.step(loss)

    # Mean loss on the test split, converted to a float for comparison and logging
    test_loss = float(check_test_loss(test_loader,model))

    # TensorBoard: Loss/train
    writer.add_scalar('Loss/train', loss, epoch)
    # TensorBoard: Loss/test
    writer.add_scalar('Loss/test', test_loss, epoch)
    # Flush so the events survive an interrupted run
    writer.flush()

    # Track the best test loss and save the corresponding weights
    if test_loss < best_test_loss:
        best_test_loss = test_loss
        torch.save(model.state_dict(), './save_model/unet_best.pt')
        print('第{}个EPOCH达到最低的测试loss:{}'.format(epoch,best_test_loss))

    # Per-epoch report: elapsed time, train loss, test loss
    print('第{}个epoch执行时间：{}s，train loss为：{}，test loss为：{}'.format(
        epoch,
        time.time()-start_time,
        loss,
        test_loss
    ) )
    # Always keep the most recent weights as well
    torch.save(model.state_dict(), './save_model/unet_latest.pt')

0.6476278305053711
0.6448570489883423
0.6427306532859802
0.6413723826408386
0.6385133266448975
0.6362111568450928
0.6347560882568359
0.6316084861755371
0.6278462409973145


KeyboardInterrupt: 