In [1]:
# 训练数据
# 1.导入unet模型
# 2.自定义dice loss函数、
# 3.开始训练

In [1]:
import torch
from torch.utils.data import Dataset,DataLoader
import numpy as np
import matplotlib.pyplot as plt
import glob
# 数据增强相关包
import imgaug as ia
import imgaug.augmenters as iaa
from imgaug.augmentables.segmaps import SegmentationMapsOnImage

  from .autonotebook import tqdm as notebook_tqdm


In [3]:
class SegmentDataset(Dataset):
    def __init__(self,where='train',seq=None):
        # 获取numpy文件数据
        # 图片列表
        self.img_list = glob.glob('processed/{}/*/img_*'.format(where))
        # self.mask_list = glob.glob('processed/{}/*/label_*'.format(where))
        # 数据增强的处理流程
        self.seq = seq
        
    def __len__(self):
        # 获取数据集大小
        return len(self.img_list)
    
    def __getitem__(self,idx):
        # 获取具体某一个数据
        
        # 获取图片文件名
        img_file = self.img_list[idx]
        # 获取标注文件名
        mask_file = img_file.replace('img','label')
        
        # 加载数据
        img = np.load(img_file)
        mask = np.load(mask_file)
        
        # 数据增强处理
        if self.seq:
            segmap = SegmentationMapsOnImage(mask,shape=mask.shape)
            img,mask = self.seq(image=img, segmentation_maps=segmap)
            # 获取数组内容
            mask = mask.get_arr()

            
        # 扩张维度变成张量
        return np.expand_dims(img,0),np.expand_dims(mask,0)

In [4]:
# 数据增强的处理流程
seq = iaa.Sequential([
    iaa.Affine(
        scale=(0.8,1.2), # 缩放
        rotate=(-45,45) # 旋转
    ),
    iaa.ElasticTransformation() # 弹性形变
])

In [5]:
# 使用dataloader加载数据
batch_size = 12
num_workers = 0

train_dataset = SegmentDataset('train',seq)
test_dataset = SegmentDataset('test',None)

# dataloader
train_loader = torch.utils.data.DataLoader(train_dataset,batch_size=batch_size,num_workers=num_workers,shuffle=True)
test_loader = torch.utils.data.DataLoader(test_dataset,batch_size=batch_size,num_workers=num_workers,shuffle=False)

In [6]:
len(train_loader)

161

In [7]:
# 导入unet模型

In [8]:
# 定义两次卷积操作
class ConvBlock(torch.nn.Module):
    def __init__(self,in_channels,out_channels):
        super().__init__()
        
        self.step = torch.nn.Sequential(
            # 第一次卷积
            torch.nn.Conv2d(in_channels=in_channels,out_channels=out_channels,kernel_size=3,padding=1,stride=1),
            # ReLU
            torch.nn.ReLU(),
            # 第二次卷积
            torch.nn.Conv2d(in_channels=out_channels,out_channels=out_channels,kernel_size=3,padding=1,stride=1),
            # ReLU
            torch.nn.ReLU()
        )
    
    def forward(self,x):
        
        return self.step(x)
    

In [9]:
class UNet(torch.nn.Module):
    def __init__(self):
        super().__init__()
        # 定义左侧编码器的操作
        self.layer1 = ConvBlock(1,64)
        self.layer2 = ConvBlock(64,128)
        self.layer3 = ConvBlock(128,256)
        self.layer4 = ConvBlock(256,512)
        
        # 定义右侧解码器的操作
        self.layer5 = ConvBlock(256+512,256)
        self.layer6 = ConvBlock(128+256,128)
        self.layer7 = ConvBlock(64+128,64)
        
        #最后一个卷积
        self.layer8  = torch.nn.Conv2d(in_channels=64,out_channels=1,kernel_size=1,padding=0,stride=1)
        
        # 定一些其他操作
        # 池化
        self.maxpool = torch.nn.MaxPool2d(kernel_size=2)
        #上采样
        self.upsample = torch.nn.Upsample(scale_factor=2,mode='bilinear')
        # sigmoid
        self.sigmoid = torch.nn.Sigmoid()
    
    def forward(self,x):
        # 对输入数据进行处理
        
        # 定义下采样部分
        
        # input:1X256x256, output: 64x256x256
        x1 = self.layer1(x)
        # input:64x256x256, output: 64 x 128 x 128
        x1_p = self.maxpool(x1)
        
        # input:  64 x 128 x 128 , output: 128 x 128 x 128
        x2 = self.layer2(x1_p)
        # input:128 x 128 x 128 , output: 128 x 64 x 64
        x2_p = self.maxpool(x2)
        
        # input: 128 x 64 x 64, output: 256 x 64 x 64
        x3 = self.layer3(x2_p)
        #input:256 x 64 x 64, output: 256 x 32 x 32
        x3_p = self.maxpool(x3)
        
        #input: 256 x 32 x 32, output: 512 x 32 x 32
        x4 = self.layer4(x3_p)
        
        
        
        # 定义上采样
        # input: 512 x 32 x 32，output: 512 x 64 x 64
        x5 = self.upsample(x4)
        # 拼接,output: 768x 64 x 64
        x5 = torch.cat([x5,x3],dim=1)
        # input: 768x 64 x 64,output: 256 x 64 x 64
        x5 = self.layer5(x5)
        
        # input: 256 x 64 x 64,output: 256 x 128 x 128
        x6  = self.upsample(x5)
        # 拼接,output: 384 x 128 x 128
        x6 = torch.cat([x6,x2],dim=1)
        # input: 384 x 128 x 128, output: 128 x 128 x 128
        x6 = self.layer6(x6)
        
        
        # input:128 x 128 x 128, output: 128 x 256 x 256
        x7 = self.upsample(x6)
        # 拼接, output: 192 x 256 x256
        x7 = torch.cat([x7,x1],dim=1)
        # input: 192 x 256 x256, output: 64 x 256 x 256
        x7 = self.layer7(x7)
        
        # 最后一次卷积,input: 64 x 256 x 256, output: 1 x 256 x 256
        x8 = self.layer8(x7)
        
        #sigmoid
        # x9= self.sigmoid(x8)
        
        
        
        return x8
        
        
        

In [10]:
# 测试模型

In [11]:
from torchsummary import summary

In [12]:
# device 
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

In [13]:
device

device(type='cuda', index=0)

In [14]:
# 模型实例化
model = UNet().to(device)

In [15]:
summary(model,(1,256,256))

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1         [-1, 64, 256, 256]             640
              ReLU-2         [-1, 64, 256, 256]               0
            Conv2d-3         [-1, 64, 256, 256]          36,928
              ReLU-4         [-1, 64, 256, 256]               0
         ConvBlock-5         [-1, 64, 256, 256]               0
         MaxPool2d-6         [-1, 64, 128, 128]               0
            Conv2d-7        [-1, 128, 128, 128]          73,856
              ReLU-8        [-1, 128, 128, 128]               0
            Conv2d-9        [-1, 128, 128, 128]         147,584
             ReLU-10        [-1, 128, 128, 128]               0
        ConvBlock-11        [-1, 128, 128, 128]               0
        MaxPool2d-12          [-1, 128, 64, 64]               0
           Conv2d-13          [-1, 256, 64, 64]         295,168
             ReLU-14          [-1, 256,



In [16]:
# 模拟输入输出

In [17]:
random_input = torch.randn((1,1,256,256)).to(device)

In [18]:
output = model(random_input)

In [19]:
output

tensor([[[[0.4702, 0.4719, 0.4711,  ..., 0.4712, 0.4723, 0.4720],
          [0.4734, 0.4704, 0.4706,  ..., 0.4719, 0.4727, 0.4727],
          [0.4712, 0.4747, 0.4716,  ..., 0.4729, 0.4725, 0.4727],
          ...,
          [0.4717, 0.4717, 0.4710,  ..., 0.4724, 0.4726, 0.4716],
          [0.4708, 0.4731, 0.4731,  ..., 0.4715, 0.4735, 0.4722],
          [0.4720, 0.4708, 0.4728,  ..., 0.4712, 0.4742, 0.4733]]]],
       device='cuda:0', grad_fn=<SigmoidBackward>)

In [20]:
output.shape

torch.Size([1, 1, 256, 256])

In [21]:
# 准备训练

In [22]:
#定义 loss 函数

In [24]:
loss_fn = torch.nn.BCEWithLogitsLoss()

In [25]:
#定义优化器

In [26]:
optimizer = torch.optim.Adam(model.parameters(),lr=1e-4)

In [27]:
# 动态减少学习率
from torch.optim.lr_scheduler import ReduceLROnPlateau

In [28]:
scheduler = ReduceLROnPlateau(optimizer,'min')

In [29]:
# 使用tensorboard可视化
from torch.utils.tensorboard import SummaryWriter

In [30]:
writer = SummaryWriter(log_dir='./log')

In [31]:
import time

In [32]:
len(test_loader)

29

In [33]:
# 计算测试集的loss
def check_test_loss(loader,model):
    # 记录loss
    loss = 0
    # 不记录梯度
    with torch.no_grad():
        # 遍历测试数据
        for i,(x,y) in enumerate(loader):
            # 获取图像
            x = x.to(device,dtype=torch.float32)
            # 获取标注
            y = y.to(device,dtype=torch.float32)

            # 获取预测值
            y_pred = model(x)

            # 计算loss
            loss_batch = loss_fn(y_pred,y)
            
            # 累加
            loss += loss_batch
    
    loss = loss/len(loader)
    return loss

In [34]:
# 开始训练
EPOCH_NUM = 200
# 记录最小的测试loss
best_test_loss = 100

for epoch in range(EPOCH_NUM):
    # 获取每一批次图像信息
    # 计算整批数据的loss
    loss = 0
    # 记录一个epoch运行的时间
    start_time = time.time()
    for i,(x,y) in enumerate(train_loader):
        # 每次update清空梯度
        model.zero_grad()
        # 获取图像
        x = x.to(device,dtype=torch.float32)
        # 获取标注
        y = y.to(device,dtype=torch.float32)
        
        # 获取预测值
        y_pred = model(x)
        
        # 计算loss
        loss_batch = loss_fn(y_pred,y)
        
        # 计算梯度
        loss_batch.backward()
        optimizer.step()
        optimizer.zero_grad()
        
        # 获取每个batch的训练loss
        loss_batch = loss_batch.detach().cpu()
        # print
        print(loss_batch)
        loss+=loss_batch
        
    # 计算loss
    loss = loss / len(train_loader)
    # 降低LR：如果loss连续10个epoch不再下降，则降低LR
    scheduler.step(loss)

    # 计算测试集loss
    test_loss = check_test_loss(test_loader,model)
    
    # 记录到tensorboard可视化
    writer.add_scalar('LOSS/train',loss,epoch)
    writer.add_scalar('LOSS/test',test_loss,epoch)
    
    # 保存最佳模型
    if best_test_loss > test_loss:
        # 赋值
        best_test_loss = test_loss
        # 保存模型
        torch.save(model.state_dict(),'saved_model/unet_course_best.pt')
        # 输出信息
        print('第{}个EPOCH达到最低的测试LOSS'.format(epoch))
        
    
    print('第{}个epoch执行时间{}s,train loss为{},test loss 为{}'.format(
        epoch,
        time.time() - start_time,
        loss,
        test_loss
    ))

tensor(0.9859)
tensor(0.9841)
tensor(0.9859)
tensor(0.9926)
tensor(0.9858)
tensor(0.9898)
tensor(0.9887)
tensor(0.9870)
tensor(0.9915)
tensor(0.9911)
tensor(0.9925)
tensor(0.9797)
tensor(0.9925)
tensor(0.9906)
tensor(0.9862)
tensor(0.9817)
tensor(0.9799)
tensor(0.9791)
tensor(0.9833)
tensor(0.9936)
tensor(0.9817)
tensor(0.9885)
tensor(0.9875)
tensor(0.9790)
tensor(0.9951)
tensor(0.9841)
tensor(0.9628)
tensor(0.9142)
tensor(0.7962)
tensor(0.7979)
tensor(0.8303)
tensor(0.7448)
tensor(0.7398)
tensor(0.9175)
tensor(0.7898)
tensor(0.9971)
tensor(0.6858)
tensor(0.7681)
tensor(0.7243)
tensor(0.6306)
tensor(0.6135)
tensor(0.6614)
tensor(0.7772)
tensor(0.5373)
tensor(0.6541)
tensor(0.4955)
tensor(0.6770)
tensor(0.7131)
tensor(0.6002)
tensor(0.5722)
tensor(0.5872)
tensor(0.5619)
tensor(0.5302)
tensor(0.6186)
tensor(0.5945)
tensor(0.6427)
tensor(0.5634)
tensor(0.6602)
tensor(0.5104)
tensor(0.6638)
tensor(0.6217)
tensor(0.6387)
tensor(0.4631)
tensor(0.5970)
tensor(0.6890)
tensor(0.6843)
tensor(0.5

KeyboardInterrupt: 