In [None]:
%load_ext autoreload
%autoreload 2
%matplotlib inline
from torch import nn
import numpy as np
from matplotlib import pyplot as plt
from matplotlib.colors import Normalize
import torch
from skimage.color import rgb2gray
import torch.nn.functional as F

In [None]:
plt.rcParams.update({
    'image.interpolation': 'nearest',  # show pixels/label values without smoothing
    'image.cmap': 'gray',
})
path="./cine_seg.npz" # input the path of cine_seg.npz in your environment
# NOTE(review): allow_pickle=True lets np.load unpickle object arrays, which can
# execute arbitrary code from a malicious file. Only load archives you trust; if
# the archive holds plain numeric arrays, allow_pickle=False is sufficient.
data = np.load(path, allow_pickle=True)

In [None]:
def imgshow(im, cmap=None, rgb_axis=None, dpi=100, figsize=(6.4, 4.8)):
    """Show one image in its own figure, with a colorbar spanning its value range.

    Args:
        im: numpy array or torch tensor. A tensor is detached and copied to host
            memory first, whatever device it lives on.
        cmap: matplotlib colormap; None uses rcParams['image.cmap'].
        rgb_axis: if given, the axis holding the color channels. It is moved to
            the last position and the image is converted with skimage's rgb2gray.
        dpi, figsize: forwarded to plt.figure.

    Side effect: calls plt.close('all') afterwards, closing every open figure.
    """
    image = im.detach().cpu().numpy() if isinstance(im, torch.Tensor) else im
    if rgb_axis is not None:
        image = rgb2gray(np.moveaxis(image, rgb_axis, -1))

    plt.figure(dpi=dpi, figsize=figsize)
    # Normalize to this image's own min/max so the colorbar covers the full range.
    plt.imshow(image, norm=Normalize(vmin=image.min(), vmax=image.max()), cmap=cmap)
    plt.colorbar()
    plt.show()
    plt.close('all')

In [None]:
# Stack every array stored in the .npz archive into one ndarray,
# one entry per key of data.files, in archive order.
data_new = np.array([data[key] for key in data.files])

In [None]:
# Drop samples whose label mask is all zeros (no annotated structure).
# data_new is (N, 2, H, W): data_new[:, 0] is the image, data_new[:, 1] the label.
print(data_new.shape)
has_foreground = data_new[:, 1].sum(axis=(1, 2)) != 0
count = int(np.count_nonzero(~has_foreground))  # number of all-zero labels removed
print(count)
data_new = data_new[has_foreground]
print(data_new.shape)

In [None]:
# Keep only samples whose label has more than 3 distinct values, i.e. background
# plus all three structures (labels are encoded 0/85/170/255). Slices in which
# any structure is missing are dropped.
# data_new is (N, 2, H, W) at this point (already filtered by the previous cell);
# data_new[:, 1] is the label.
print(data_new.shape)
# np.unique on each mask is the expensive part, so compute it once and reuse it
# for both the count and the filter. dtype=int keeps the mask boolean even when
# data_new is empty (an empty float array cannot be used as an index).
n_unique = np.array([np.unique(mask).size for mask in data_new[:, 1]], dtype=int)
keep = n_unique > 3
count = int(np.count_nonzero(~keep))  # number of incomplete labels removed
print(count)
data_new = data_new[keep]
print(data_new.shape)

In [None]:
# Split data_new into train / val / test along the sample axis.
total_samples = len(data_new)
# Nominal 4:1:2 sizes, printed for reference only.
num_train = int(total_samples * 4 / 7)
num_val = int(total_samples * 1 / 7)
num_test = total_samples - num_train - num_val
print(total_samples,num_train, num_val, num_test)
# The boundaries actually used are hard-coded, not taken from the nominal sizes
# above. They are rounded so the train (850) and val (250) splits are whole
# multiples of the DataLoader batch size (10), i.e. every batch is full.
TRAIN_END = 850
VAL_END = 1100
assert VAL_END <= total_samples, (
    f"expected at least {VAL_END} samples after filtering, got {total_samples}")
train_input = data_new[:TRAIN_END, 0]
val_input = data_new[TRAIN_END:VAL_END, 0]
test_input = data_new[VAL_END:, 0]
train_output = data_new[:TRAIN_END, 1]
val_output = data_new[TRAIN_END:VAL_END, 1]
test_output = data_new[VAL_END:, 1]
# Sizes actually used (these differ from the nominal ones printed above).
print(len(train_input), len(val_input), len(test_input))

In [None]:
# Map the gray-level label encoding 0/85/170/255 to class indices 0/1/2/3.
def _gray_to_class(labels):
    """Return a copy of ``labels`` with 85/170/255 replaced by 1/2/3 (0 stays 0).

    The split arrays are basic-slice views of data_new, so assigning into them
    in place would silently rewrite data_new as well. Working on a copy avoids
    that hidden state. Already-remapped input passes through unchanged, so the
    cell is safe to re-run.
    """
    out = labels.copy()
    for gray, cls in ((85, 1), (170, 2), (255, 3)):
        # Masks are computed on the untouched input, so mappings cannot chain.
        out[labels == gray] = cls
    return out

train_output = _gray_to_class(train_output)
val_output = _gray_to_class(val_output)
test_output = _gray_to_class(test_output)
# Fail early here instead of with an out-of-range class index inside the loss.
for _split in (train_output, val_output, test_output):
    assert np.isin(_split, (0, 1, 2, 3)).all(), "unexpected label value after remapping"

In [None]:
# Optional data augmentation, applied to the training split only
# (validation and test data are never augmented).
is_aug = False


def rotate90(img):
    """Rotate a 2-D image or mask by 90 degrees (np.rot90, k=1) in the row/column plane."""
    return np.rot90(img, k=1, axes=(0, 1))


def flip(img):
    """Mirror a 2-D image or mask along its second (column) axis."""
    return np.flip(img, axis=1)


if is_aug:
    # Every geometric transform must be applied to the image AND its mask.
    # Reusing the untransformed masks would misalign labels and pixels.
    # NOTE: train_input/train_output are overwritten, so re-running this cell
    # with is_aug=True augments an already augmented set.
    train_input_aug = np.concatenate([
        train_input,
        np.array([rotate90(x) for x in train_input]),
        np.array([flip(x) for x in train_input]),
    ])
    train_output_aug = np.concatenate([
        train_output,
        np.array([rotate90(y) for y in train_output]),
        np.array([flip(y) for y in train_output]),
    ])
    train_input = train_input_aug
    train_output = train_output_aug
print(train_input.shape, train_output.shape)

In [None]:
# Convert the numpy splits to float32 tensors and wrap them in DataLoaders.
train_data = torch.from_numpy(train_input).float()
train_label = torch.from_numpy(train_output).float()
val_data = torch.from_numpy(val_input).float()
val_label = torch.from_numpy(val_output).float()
test_data = torch.from_numpy(test_input).float()
test_label = torch.from_numpy(test_output).float()
BZ = 10  # batch size
# Only the training loader is shuffled. Shuffling the evaluation loaders would
# only make batch composition, and hence the averaged per-batch loss, vary
# between runs.
train = torch.utils.data.DataLoader(torch.utils.data.TensorDataset(train_data, train_label), batch_size=BZ, shuffle=True)
val = torch.utils.data.DataLoader(torch.utils.data.TensorDataset(val_data, val_label), batch_size=BZ, shuffle=False)
test = torch.utils.data.DataLoader(torch.utils.data.TensorDataset(test_data, test_label), batch_size=BZ, shuffle=False)

In [None]:
class DoubleConv(nn.Module):
    """Two stages of 3x3 conv -> BatchNorm -> ReLU; padding=1 keeps H and W unchanged.

    The convolutions use bias=False because the BatchNorm2d right after each one
    has its own affine shift.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()

        def conv_bn_relu(c_in, c_out):
            return [
                nn.Conv2d(c_in, c_out, kernel_size=3, padding=1, bias=False),
                nn.BatchNorm2d(c_out),
                nn.ReLU(inplace=True),
            ]

        # One flat Sequential named `layers`, so state_dict keys stay
        # layers.0 ... layers.5 and saved checkpoints keep loading.
        self.layers = nn.Sequential(
            *conv_bn_relu(in_channels, out_channels),
            *conv_bn_relu(out_channels, out_channels),
        )

    def forward(self, x):
        """Map x of shape [B, C_in, H, W] to [B, C_out, H, W]."""
        return self.layers(x)

class Down(nn.Module):
    """Encoder stage: 2x2 max-pool (halves H and W, rounding down), then DoubleConv."""

    def __init__(self, in_channels, out_channels):
        super().__init__()
        pool = nn.MaxPool2d(2)
        conv = DoubleConv(in_channels, out_channels)
        # Attribute name is part of the state_dict key layout; keep it.
        self.maxpool_conv = nn.Sequential(pool, conv)

    def forward(self, x):
        """Map x of shape [B, C_in, H, W] to [B, C_out, H // 2, W // 2]."""
        return self.maxpool_conv(x)
class Up(nn.Module):
    """Decoder stage: 2x transposed-conv upsampling, concat with the skip tensor, then DoubleConv."""

    def __init__(self, in_channels, out_channels):
        super().__init__()
        # Upsampling halves the channel count. After concatenation with a skip
        # tensor of in_channels - in_channels // 2 channels, DoubleConv sees
        # in_channels again.
        self.up = nn.ConvTranspose2d(in_channels, in_channels // 2, kernel_size=2, stride=2)
        self.conv = DoubleConv(in_channels, out_channels)

    def forward(self, x1, x2):
        """
        x1: [B, C_in, h, w] decoder feature map to upsample
        x2: [B, C_in // 2, H, W] encoder skip connection (for even C_in)
        out: [B, C_out, H, W]
        """
        x1 = self.up(x1)
        diff_h = x2.shape[2] - x1.shape[2]
        diff_w = x2.shape[3] - x1.shape[3]
        # Split the padding so both sides add up to the full difference. Using
        # diff // 2 on both sides left x1 one pixel short whenever the difference
        # was odd (inputs not divisible by 16), and torch.cat then failed.
        x1 = F.pad(x1, [
            diff_w // 2, diff_w - diff_w // 2,  # left, right
            diff_h // 2, diff_h - diff_h // 2,  # top, bottom
        ])

        x = torch.cat([x2, x1], dim=1)
        return self.conv(x)
class UNet(nn.Module):
    """U-Net with four down-sampling and four up-sampling stages joined by skip connections.

    Channel widths per level are C_base, 2*C_base, 4*C_base, 8*C_base and
    16*C_base (bottleneck). The attribute names below form the state_dict keys
    that torch.save / load_state_dict rely on, so keep them stable.
    """

    def __init__(self, n_channels, n_classes, C_base=64):
        super().__init__()
        self.n_channels = n_channels
        self.n_classes = n_classes

        c1, c2, c3, c4, c5 = (C_base * k for k in (1, 2, 4, 8, 16))

        self.in_conv = DoubleConv(n_channels, c1)

        # contracting path
        self.down1 = Down(c1, c2)
        self.down2 = Down(c2, c3)
        self.down3 = Down(c3, c4)
        self.down4 = Down(c4, c5)
        # expanding path
        self.up1 = Up(c5, c4)
        self.up2 = Up(c4, c3)
        self.up3 = Up(c3, c2)
        self.up4 = Up(c2, c1)
        # 1x1 conv producing one logit map per class
        self.out_projection = nn.Conv2d(c1, n_classes, kernel_size=1)

    def forward(self, x):
        """Map x of shape [B, n_channels, H, W] to logits of shape [B, n_classes, H, W]."""
        skip1 = self.in_conv(x)
        skip2 = self.down1(skip1)
        skip3 = self.down2(skip2)
        skip4 = self.down3(skip3)
        bottom = self.down4(skip4)

        y = self.up1(bottom, skip4)
        y = self.up2(y, skip3)
        y = self.up3(y, skip2)
        y = self.up4(y, skip1)
        return self.out_projection(y)

# Smoke test: a 256x256 single-channel batch must come out as 4 class maps of
# the same spatial size.
net = UNet(n_channels=1, n_classes=4)
x = torch.randn(2, 1, 256, 256)
with torch.no_grad():  # shape check only; no autograd graph is needed
    out = net(x)
assert (2, 4, 256, 256) == out.shape
# Free the throwaway network and tensors; they are not used by later cells.
del net, x, out

In [None]:
def pixel_wise_cross_entropy_loss_weighted(logits, labels, class_weights):
    '''
    Class-weighted cross-entropy for dense (per-pixel) classification.

    Args:
        logits: raw class scores, [B, C, H, W] as produced by UNet in this notebook.
        labels: int64 class indices, [B, H, W], each in [0, C).
        class_weights: length-C sequence or 1-D tensor of per-class weights.
            It is converted to the dtype and device of ``logits``.

    Returns:
        Scalar tensor: the weighted mean over all pixels, i.e.
        sum(w[y] * nll) / sum(w[y]) (torch cross_entropy with reduction='mean').
    '''
    # as_tensor accepts lists and tensors alike (torch.tensor warns on a tensor).
    # Matching the logits dtype lets float64 / float16 logits work too.
    weight = torch.as_tensor(class_weights, dtype=logits.dtype, device=logits.device)
    return nn.functional.cross_entropy(logits, labels, weight=weight)

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim

# Train the U-Net with class-weighted pixel-wise cross-entropy, recording the
# mean per-batch train and validation loss for every epoch.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = UNet(n_channels=1, n_classes=4).to(device)
criterion = pixel_wise_cross_entropy_loss_weighted
optimizer = optim.Adam(model.parameters(), lr=0.001)
num_epochs = 200
# Loss weights for classes 0..3. Class 0 (background) is weighted a third of
# each structure class.
CLASS_WEIGHTS = [0.1, 0.3, 0.3, 0.3]
train_losses = []
val_losses = []
train_cur_loss = 0
val_cur_loss = 0
for epoch in range(num_epochs):
    # Training pass: BatchNorm uses batch statistics and updates running stats.
    model.train()
    # Loop variable is `batch`, not `data`: reusing `data` would overwrite the
    # npz archive loaded at the top of the notebook.
    for batch in train:
        inputs, labels = batch
        inputs, labels = inputs.to(device), labels.to(device)
        optimizer.zero_grad()
        # [B, H, W] -> [B, 1, H, W]. Unlike reshape(BZ, ...), this also works
        # for a final batch smaller than BZ.
        inputs = inputs.unsqueeze(1)
        outputs = model(inputs)
        loss = criterion(outputs, labels.long(), CLASS_WEIGHTS)
        loss.backward()
        optimizer.step()
        train_cur_loss += loss.item()
    train_losses.append(train_cur_loss / len(train))
    train_cur_loss = 0

    # Validation pass. eval() makes BatchNorm use, and not update, its running
    # statistics. no_grad() skips building an autograd graph.
    model.eval()
    with torch.no_grad():
        for batch in val:
            inputs, labels = batch
            inputs, labels = inputs.to(device), labels.to(device)
            inputs = inputs.unsqueeze(1)
            outputs = model(inputs)
            loss = criterion(outputs, labels.long(), CLASS_WEIGHTS)
            val_cur_loss += loss.item()
    val_losses.append(val_cur_loss / len(val))
    val_cur_loss = 0
    print(f'Epoch {epoch+1}/{num_epochs}, Train Loss: {train_losses[-1]}, Val Loss: {val_losses[-1]}')
    

In [None]:
import os

# Save the trained weights and both loss curves under ./savedata/.
#name="(c)model200epslr001"
name = "test"
# os.path.join gives the right separator on every OS. The hard-coded ".\\savedata\\"
# only worked on Windows; on POSIX it produced a file literally named
# ".\savedata\test.pth" in the working directory.
save_dir = os.path.join(".", "savedata")
os.makedirs(save_dir, exist_ok=True)  # torch.save/np.savez do not create directories
torch.save(model.state_dict(), os.path.join(save_dir, name + ".pth"))
np.savez(os.path.join(save_dir, name + ".npz"), train_losses=train_losses, val_losses=val_losses)

In [None]:
import matplotlib.pyplot as plt

# Plot the per-epoch training and validation loss curves.
length_of_data_loss = len(train_losses)
x_ticks_loss = list(range(1, length_of_data_loss + 1))  # epochs are 1-based
fig, ax1 = plt.subplots(figsize=(20, 10))

for series, series_label in ((train_losses, 'Train Loss'), (val_losses, 'Validation Loss')):
    ax1.plot(x_ticks_loss, series, label=series_label, marker='o')

# A thin dashed gray guide line at every integer x from 0 through the last epoch.
for x_line in range(length_of_data_loss + 1):
    ax1.axvline(x=x_line, color='gray', linestyle='--', linewidth=0.5)

ax1.set_xlabel('Epoch')
ax1.set_ylabel('Loss')
ax1.set_title('Training and Validation Loss')
ax1.legend()

plt.tight_layout()
plt.show()


In [None]:
#Show the prediction for a single image (disabled cell)
# NOTE(review): if re-enabled, call model.eval() and wrap the forward pass in torch.no_grad()
# input=val_data[101].to(device)
# label=val_label[101].to(device)
# output=model(input.reshape(1,1,256,256))
# _, predicted_indices = torch.max(output, 1)  # torch.max over dim 1 (the class dimension) gives the class index per pixel

# # Map class indices back to gray levels
# # Build a lookup tensor
# mapping = torch.tensor([0, 85, 170, 255], dtype=torch.uint8).to(predicted_indices.device)
# # Index the lookup tensor with the class indices to get the gray values
# final_image = mapping[predicted_indices]
# imgshow(final_image[0])
# imgshow(label)

In [None]:
#Show the original image and label of a single sample (disabled cell)
# print(val_label[101].shape)
# print(np.unique(val_label[101]))
# #Print the pixel count of each class
# for i in range(4):
#     a=val_label[101]==i
#     print(sum(sum(a)))
# imgshow(val_data[101])
# imgshow(val_label[101])


In [None]:
# Dice coefficients on the validation set
def dice_coefficient(pred, true, label):
    """Dice overlap of one structure between a prediction and the ground truth.

    The two arrays use different encodings. ``pred`` holds gray levels
    (0/85/170/255) and ``true`` holds class indices (0/1/2/3). ``label`` is the
    structure's gray level (85, 170 or 255) and is translated to its class index
    before comparing with ``true``.

    Returns 2*|P & T| / (|P| + |T|), or 1.0 when the structure is absent from
    both arrays. Raises KeyError if ``label`` is not 85, 170 or 255.
    """
    in_pred = pred == label
    class_of_gray = {85: 1, 170: 2, 255: 3}
    in_true = true == class_of_gray[label]

    denominator = np.sum(in_pred) + np.sum(in_true)
    if denominator == 0:
        return 1.0  # nothing to segment and nothing predicted: perfect agreement
    overlap = np.sum(in_pred & in_true)
    return (2. * overlap) / denominator
# Per-structure Dice (RV / MYO / LV) on the validation split.
# Point eval_data / eval_label at test_data / test_label for held-out numbers.
# The split tensors themselves are no longer overwritten.
eval_data = val_data
eval_label = val_label
# When True, samples with Dice exactly 0 (complete misses) are left out of the
# mean/std, as in the original analysis. NOTE(review): this biases the reported
# mean upward; set False to average over every sample.
EXCLUDE_ZERO_DICE = True

model = model.to(device)
# eval(): BatchNorm uses its running statistics, not statistics of the single
# image being scored, and this data does not update them.
model.eval()
# Class index -> gray level, the encoding dice_coefficient expects for `pred`.
mapping = torch.tensor([0, 85, 170, 255], dtype=torch.uint8, device=device)
total_rv = []
total_myo = []
total_lv = []
cnt_rv = cnt_myo = cnt_lv = 0
min_rv = min_myo = min_lv = 1
index_min_rv = index_min_myo = index_min_lv = 0
with torch.no_grad():  # inference only; no autograd graph needed
    for i in range(len(eval_data)):
        image = eval_data[i].to(device)
        label = eval_label[i].to(device)
        output = model(image[None, None])  # [H, W] -> [1, 1, H, W]
        _, predicted_indices = torch.max(output, 1)  # argmax over the class dimension
        pred = mapping[predicted_indices][0].cpu().numpy()  # [H, W] gray levels
        true = label.cpu().numpy()  # [H, W] class indices
        dice_rv = dice_coefficient(pred, true, 85)
        dice_myo = dice_coefficient(pred, true, 170)
        dice_lv = dice_coefficient(pred, true, 255)
        # Track the worst sample per structure for later inspection.
        if dice_rv < min_rv:
            min_rv = dice_rv
            index_min_rv = i
        if dice_myo < min_myo:
            min_myo = dice_myo
            index_min_myo = i
        if dice_lv < min_lv:
            min_lv = dice_lv
            index_min_lv = i
        if dice_rv != 0 or not EXCLUDE_ZERO_DICE:
            total_rv.append(dice_rv)
            cnt_rv += 1
        if dice_myo != 0 or not EXCLUDE_ZERO_DICE:
            total_myo.append(dice_myo)
            cnt_myo += 1
        if dice_lv != 0 or not EXCLUDE_ZERO_DICE:
            total_lv.append(dice_lv)
            cnt_lv += 1
total_rv = np.array(total_rv)
total_myo = np.array(total_myo)
total_lv = np.array(total_lv)
# Mean Dice
mean_rv = np.mean(total_rv)
mean_myo = np.mean(total_myo)
mean_lv = np.mean(total_lv)
# Standard deviation
std_rv = np.std(total_rv)
std_myo = np.std(total_myo)
std_lv = np.std(total_lv)
print(f'RV Dice系数平均值: {mean_rv:.4f}, 标准差: {std_rv:.4f}')
print(f'MYO Dice系数平均值: {mean_myo:.4f}, 标准差: {std_myo:.4f}')
print(f'LV Dice系数平均值: {mean_lv:.4f}, 标准差: {std_lv:.4f}')
# print(min_rv,index_min_rv)
# print(min_myo,index_min_myo)
# print(min_lv,index_min_lv)