In [17]:
from google.colab import drive
drive.mount('/content/drive')

%cd /content/drive/MyDrive/Colab\ Notebooks/

/content/drive/My Drive/Colab Notebooks
/content/drive/MyDrive/Colab Notebooks
Cloning into 'COPM-Project'...
remote: Enumerating objects: 58, done.[K
remote: Counting objects: 100% (58/58), done.[K
remote: Compressing objects: 100% (35/35), done.[K
remote: Total 58 (delta 17), reused 53 (delta 15), pack-reused 0[K
Unpacking objects: 100% (58/58), done.


In [None]:
!git clone https://github.com/egilltor17/COPM-Project.git

Cloning into 'COPM-Project'...
fatal: could not read Username for 'https://github.com': No such device or address


In [1]:
import os
from time import time

import numpy as np

import torch
import torch.backends.cudnn as cudnn
from torch.utils.data import DataLoader

!pip install Visdom
from visdom import Visdom

from dataset.dataset import Dataset

from loss.Dice import DiceLoss
from loss.ELDice import ELDiceLoss
from loss.WBCE import WCELoss
from loss.Jaccard import JaccardLoss
from loss.SS import SSLoss
from loss.Tversky import TverskyLoss
from loss.Hybrid import HybridLoss
from loss.BCE import BCELoss

from net.ResUNet import net

import parameter as para

# Visdom: live training-loss curve (requires a visdom server listening on port 666).
viz = Visdom(port=666)
step_list = [0]
win = viz.line(X=np.array([0]), Y=np.array([1.0]), opts=dict(title='loss'))

# GPU settings. CUDA_VISIBLE_DEVICES only takes effect if CUDA has not been initialised yet
# in this process -- NOTE(review): confirm nothing imported above touches CUDA.
os.environ['CUDA_VISIBLE_DEVICES'] = para.gpu
cudnn.benchmark = para.cudnn_benchmark

# Network: wrap in DataParallel, move to GPU, switch to training mode.
net = torch.nn.DataParallel(net).cuda()
net.train()

# Dataset: CT volumes and segmentation masks live in sibling 'ct' / 'seg' folders.
train_ds = Dataset(os.path.join(para.training_set_path, 'ct'), os.path.join(para.training_set_path, 'seg'))

# Data loader (third positional argument is shuffle=True).
train_dl = DataLoader(train_ds, para.batch_size, True, num_workers=para.num_workers, pin_memory=para.pin_memory)

# Loss function selection; index 5 is TverskyLoss.
loss_func_list = [DiceLoss(), ELDiceLoss(), WCELoss(), JaccardLoss(), SSLoss(), TverskyLoss(), HybridLoss(), BCELoss()]
loss_func = loss_func_list[5]

# Optimizer.
opt = torch.optim.Adam(net.parameters(), lr=para.learning_rate)

# Learning-rate decay at the epochs listed in para.learning_rate_decay.
lr_decay = torch.optim.lr_scheduler.MultiStepLR(opt, para.learning_rate_decay)

# Deep-supervision weight applied to the three auxiliary outputs.
alpha = para.alpha

# Schedule constants (in steps / epochs).
LOG_EVERY_STEPS = 5          # log to visdom / stdout every N minibatches
CHECKPOINT_EVERY_EPOCHS = 50  # save weights every N epochs (epoch 0 excluded)
ALPHA_DECAY_EVERY_EPOCHS = 40  # shrink alpha every N epochs (epoch 0 excluded)
ALPHA_DECAY_FACTOR = 0.8
MODULE_DIR = './module'

# torch.save does not create missing directories.
os.makedirs(MODULE_DIR, exist_ok=True)

# Train the network.
start = time()
for epoch in range(para.Epoch):

    epoch_losses = []  # loss4 (main output) of every minibatch in this epoch
    loss = None        # total loss of the last minibatch; stays None if the loader is empty

    for step, (ct, seg) in enumerate(train_dl):

        ct = ct.cuda()
        seg = seg.cuda()

        outputs = net(ct)

        # outputs[0..2] are the deep-supervision heads, outputs[3] is the final prediction.
        loss1 = loss_func(outputs[0], seg)
        loss2 = loss_func(outputs[1], seg)
        loss3 = loss_func(outputs[2], seg)
        loss4 = loss_func(outputs[3], seg)

        loss = (loss1 + loss2 + loss3) * alpha + loss4

        epoch_losses.append(loss4.item())

        opt.zero_grad()
        loss.backward()
        opt.step()

        if step % LOG_EVERY_STEPS == 0:

            step_list.append(step_list[-1] + 1)
            viz.line(X=np.array([step_list[-1]]), Y=np.array([loss4.item()]), win=win, update='append')

            print('epoch:{}, step:{}, loss1:{:.3f}, loss2:{:.3f}, loss3:{:.3f}, loss4:{:.3f}, time:{:.3f} min'
                  .format(epoch, step, loss1.item(), loss2.item(), loss3.item(), loss4.item(), (time() - start) / 60))

    # Since PyTorch 1.1 the scheduler must be stepped AFTER the optimizer; stepping it at the
    # start of the epoch applied every milestone one epoch early.
    lr_decay.step()

    # Guard against an empty loader (would otherwise raise ZeroDivisionError).
    mean_loss = sum(epoch_losses) / len(epoch_losses) if epoch_losses else float('nan')

    # Save a checkpoint. File name: epoch + loss of the last minibatch + mean loss4 of this epoch.
    if epoch % CHECKPOINT_EVERY_EPOCHS == 0 and epoch != 0 and loss is not None:
        torch.save(net.state_dict(),
                   os.path.join(MODULE_DIR, 'net{}-{:.3f}-{:.3f}.pth'.format(epoch, loss.item(), mean_loss)))

    # Decay the deep-supervision weight.
    if epoch % ALPHA_DECAY_EVERY_EPOCHS == 0 and epoch != 0:
        alpha *= ALPHA_DECAY_FACTOR

# Deep-supervision weight after k decays (0.8**k, assuming para.alpha == 1.0 -- confirm):
# 1.000
# 0.800
# 0.640
# 0.512
# 0.410
# 0.328
# 0.262
# 0.210
# 0.168
# 0.134
# 0.107
# 0.086
# 0.069
# 0.055
# 0.044
# 0.035
# 0.028
# 0.023
# 0.018
# 0.014
# 0.012
# 0.009
# 0.007
# 0.006
# 0.005
# 0.004
# 0.003
# 0.002
# 0.002
# 0.002
# 0.001
# 0.001
# 0.001
# 0.001
# 0.001
# 0.000
# 0.000


Collecting Visdom
[?25l  Downloading https://files.pythonhosted.org/packages/c9/75/e078f5a2e1df7e0d3044749089fc2823e62d029cc027ed8ae5d71fafcbdc/visdom-0.1.8.9.tar.gz (676kB)
[K     |▌                               | 10kB 18.1MB/s eta 0:00:01[K     |█                               | 20kB 9.1MB/s eta 0:00:01[K     |█▌                              | 30kB 7.3MB/s eta 0:00:01[K     |██                              | 40kB 6.7MB/s eta 0:00:01[K     |██▍                             | 51kB 4.3MB/s eta 0:00:01[K     |███                             | 61kB 4.8MB/s eta 0:00:01[K     |███▍                            | 71kB 4.9MB/s eta 0:00:01[K     |███▉                            | 81kB 5.2MB/s eta 0:00:01[K     |████▍                           | 92kB 5.5MB/s eta 0:00:01[K     |████▉                           | 102kB 5.5MB/s eta 0:00:01[K     |█████▎                          | 112kB 5.5MB/s eta 0:00:01[K     |█████▉                          | 122kB 5.5MB/s eta 0:00:01[K 

ModuleNotFoundError: ignored