In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
import math
from pathlib import Path

import numpy as np
import torch
import torch.nn.functional as F
import torch.optim as optim
import torchvision.utils
from matplotlib import pyplot as plt
from matplotlib.colors import Normalize
from mpl_toolkits.axes_grid1 import make_axes_locatable
from skimage.color import rgb2gray
from skimage.metrics import structural_similarity
from torch import nn

from proj.show import imsshow
from proj.resblock import ResBlock

## 读取第一题的数据并分隔

In [None]:
# Load the `recon` and `mask` arrays saved after problem 1, plus the reference
# `dataset` array from cine.npz (presumably the fully sampled images -- the
# variable name suggests so; confirm against the data description).
# Each archive is opened exactly once inside a `with` block so the underlying
# file handle is closed as soon as the arrays have been read.
with np.load('data/data_after_prob1.npz') as prob1_archive:
    img_recon = prob1_archive["recon"]
    mask = prob1_archive['mask']
with np.load('data/cine.npz') as cine_archive:
    img_fully = cine_archive['dataset']
print(img_fully.shape, img_fully.dtype)
print(img_recon.shape, img_recon.dtype)
# Preview the first sample of each array.
imsshow(img_fully[0], num_col=5, cmap='gray', is_colorbar=True)
imsshow(img_recon[0], num_col=5, cmap='gray', is_colorbar=True)


In [None]:
# Min-max normalisation of an image array.
def normalize(img):
    """Linearly rescale ``img`` using its global minimum and maximum.

    The global minimum is mapped to 0 and the result is divided by
    ``max - min + 1e-8``; the small constant keeps a constant-valued input
    from triggering a division by zero (such an input maps to all zeros).

    Args:
        img: array-like object exposing ``min()`` / ``max()`` and arithmetic
            (e.g. a NumPy array).

    Returns:
        A new array of rescaled values; ``img`` itself is not modified.
    """
    lowest = img.min()
    value_range = img.max() - lowest + 1e-8
    return (img - lowest) / value_range
# Rescale both arrays (each with its own global min / max), then preview the
# first sample of the rescaled reference array.
img_fully, img_recon = (normalize(volume) for volume in (img_fully, img_recon))
imsshow(img_fully[0], num_col=5, cmap='gray', is_colorbar=True)

In [None]:
# Split along the first axis at indices 125 and 150 into training / validation /
# test subsets (125 / 25 / 50 samples, i.e. 5:1:2, for a 200-sample dataset).
# The sub-arrays are views of the original arrays, exactly like plain slices.
#
# Rotation-based augmentation of the training subset (np.rot90 over the last
# two axes of the (125, 20, 192, 192) arrays, concatenated onto the originals)
# was tried earlier and did not give good results, so it is not applied.
train_label, val_label, test_label = np.split(img_fully, [125, 150])
train_data, val_data, test_data = np.split(img_recon, [125, 150])
print(train_label.shape, train_data.shape)

In [None]:
# Convert the splits to float32 tensors and wrap them in DataLoaders.
# torch.as_tensor accepts ndarrays as well as tensors, so re-running this cell
# after the names have already been rebound to tensors no longer raises
# (torch.from_numpy rejects tensor input).
train_data = torch.as_tensor(train_data).float()
train_label = torch.as_tensor(train_label).float()
val_data = torch.as_tensor(val_data).float()
val_label = torch.as_tensor(val_label).float()
test_data = torch.as_tensor(test_data).float()
test_label = torch.as_tensor(test_label).float()

# Only the training loader is shuffled. Validation / test loaders keep a fixed
# order so metrics that are computed per batch and then averaged (e.g. PSNR,
# which is a non-linear function of the batch MSE) are reproducible from one
# evaluation to the next.
train = torch.utils.data.DataLoader(
    torch.utils.data.TensorDataset(train_data, train_label), batch_size=5, shuffle=True
)
val = torch.utils.data.DataLoader(
    torch.utils.data.TensorDataset(val_data, val_label), batch_size=5, shuffle=False
)
test = torch.utils.data.DataLoader(
    torch.utils.data.TensorDataset(test_data, test_label), batch_size=5, shuffle=False
)

In [None]:
# Small model: BatchNorm (used for problem 4) is active; a dropout layer is
# defined but currently not applied.
class ResBlock_Mini(nn.Module):
    """Four 3-D convolutions with a global residual (input + correction) output.

    Channel widths go 1 -> 16 -> 32 -> 16 -> 1. All convolutions use 3x3x3
    kernels with padding 1, so the output has the same shape as the input.
    The first three convolutions are each followed by BatchNorm3d and ReLU;
    the last one is added directly to the input.
    """

    def __init__(self):
        super().__init__()
        # Attribute names and creation order are part of the state_dict
        # layout / seeded initialisation and are therefore left untouched.
        self.conv1 = nn.Conv3d(in_channels=1, out_channels=16, kernel_size=3, padding=1)
        self.conv2 = nn.Conv3d(in_channels=16, out_channels=32, kernel_size=3, padding=1)
        # Registered but not used by forward() at the moment.
        self.dropout = nn.Dropout3d(p=0.4)
        self.conv7 = nn.Conv3d(in_channels=32, out_channels=16, kernel_size=3, padding=1)
        self.conv8 = nn.Conv3d(in_channels=16, out_channels=1, kernel_size=3, padding=1)
        self.bn1 = nn.BatchNorm3d(num_features=16)
        self.bn2 = nn.BatchNorm3d(num_features=32)
        self.bn3 = nn.BatchNorm3d(num_features=16)

    def forward(self, x):
        """Return ``x`` plus the correction predicted by the conv stack.

        Args:
            x: 5-D tensor with a single channel, i.e. (batch, 1, D, H, W).

        Returns:
            Tensor with the same shape as ``x``.
        """
        features = F.relu(self.bn1(self.conv1(x)))
        features = F.relu(self.bn2(self.conv2(features)))
        # Dropout would go here; it is intentionally disabled:
        # features = self.dropout(features)
        features = F.relu(self.bn3(self.conv7(features)))
        correction = self.conv8(features)
        correction += x  # residual connection
        return correction


In [None]:
# Larger model: BatchNorm (used for problem 4) is active; a dropout layer is
# defined but currently not applied.
# NOTE: defining this class here shadows the `ResBlock` imported from
# `proj.resblock` at the top of the notebook.
class ResBlock(nn.Module):
    """Eight-convolution 3-D CNN with one down/up-sampling stage and a global residual.

    Layout: conv1 (1->16) at full resolution, 2x max-pooling, conv2..conv7
    (16->32->64->128->64->32->16) at half resolution, trilinear upsampling back
    to the input resolution, conv8 (16->1), then the input is added to the
    result. conv1..conv7 are each followed by BatchNorm3d and ReLU.

    Attribute names are kept exactly as they were so existing checkpoints
    (state_dict keys) stay loadable. That is also why ``bn8`` remains
    registered even though ``forward`` never calls it.
    """

    def __init__(self):
        super(ResBlock, self).__init__()
        self.conv1 = nn.Conv3d(1, 16, kernel_size=3, padding=1)
        self.pool = nn.MaxPool3d(2, 2)
        self.conv2 = nn.Conv3d(16, 32, kernel_size=3, padding=1)
        # Dropout layer; registered but not applied in forward() at the moment.
        self.dropout = nn.Dropout3d(0.4)
        self.conv3 = nn.Conv3d(32, 64, kernel_size=3, padding=1)
        self.conv4 = nn.Conv3d(64, 128, kernel_size=3, padding=1)
        self.conv5 = nn.Conv3d(128, 64, kernel_size=3, padding=1)
        self.conv6 = nn.Conv3d(64, 32, kernel_size=3, padding=1)
        self.conv7 = nn.Conv3d(32, 16, kernel_size=3, padding=1)
        self.conv8 = nn.Conv3d(16, 1, kernel_size=3, padding=1)
        self.up = nn.Upsample(scale_factor=2, mode='trilinear', align_corners=True)
        self.bn = nn.BatchNorm3d(16)
        self.bn2 = nn.BatchNorm3d(32)
        self.bn3 = nn.BatchNorm3d(64)
        self.bn4 = nn.BatchNorm3d(128)
        self.bn5 = nn.BatchNorm3d(64)
        self.bn6 = nn.BatchNorm3d(32)
        self.bn7 = nn.BatchNorm3d(16)
        # Unused by forward(); kept only so state_dict keys do not change.
        self.bn8 = nn.BatchNorm3d(1)

    def forward(self, x):
        """Return ``x`` plus the correction predicted by the network.

        Args:
            x: 5-D tensor with a single channel, i.e. (batch, 1, D, H, W).
                Each of D, H, W must be at least 2 because of the 2x pooling.

        Returns:
            Tensor with the same shape as ``x``.
        """
        identity = x
        out = F.relu(self.bn(self.conv1(x)))
        out = self.pool(out)
        out = F.relu(self.bn2(self.conv2(out)))
        out = F.relu(self.bn3(self.conv3(out)))
        # out = self.dropout(out)
        out = F.relu(self.bn4(self.conv4(out)))
        out = F.relu(self.bn5(self.conv5(out)))
        # out = self.dropout(out)
        out = F.relu(self.bn6(self.conv6(out)))
        out = F.relu(self.bn7(self.conv7(out)))

        target_size = tuple(identity.shape[2:])
        if all(dim % 2 == 0 for dim in target_size):
            # Even sizes: pooling halved every axis exactly, so the fixed x2
            # upsampling restores the input resolution (original code path).
            out = self.up(out)
        else:
            # Odd sizes: pooling floors the halved size, so x2 upsampling would
            # come out one element short and the residual add would fail.
            # Interpolate straight to the input resolution instead.
            out = F.interpolate(
                out, size=target_size, mode='trilinear', align_corners=True
            )

        out = self.conv8(out)
        out += identity  # residual connection
        return out

In [None]:
# Train the network with Adam on an MSE objective. Every epoch the mean
# training and validation loss are logged; every `metric_interval` epochs the
# validation PSNR and SSIM are logged as well.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(f"Training on {device}")
model = ResBlock_Mini().to(device)
# model = ResBlock().to(device)
criterion = nn.MSELoss()
# Plain SGD was tried first but fitted too slowly:
# optimizer = optim.SGD(model.parameters(), lr=0.01)
optimizer = optim.Adam(model.parameters(), weight_decay=1e-8)

num_epochs = 500  # number of training epochs
metric_interval = 5  # compute PSNR / SSIM every this many epochs
data_range = 1.0  # intensity range used by PSNR and SSIM (data is min-max normalised)

PSNR = []
SSIM = []
loss_train = []
loss_val = []


def _add_channel_axis(batch, device):
    """Turn a (B, D, H, W) batch into the (B, 1, D, H, W) layout Conv3d expects."""
    return batch.unsqueeze(1).to(device)


def _run_validation(model, loader, criterion, device, with_metrics=False, data_range=1.0):
    """Evaluate ``model`` on ``loader`` in a single pass.

    The model is switched to eval mode for the duration of the call, so
    BatchNorm layers use their running statistics and do not update them with
    validation data; the previous train/eval mode is restored afterwards.

    Args:
        model: network to evaluate.
        loader: iterable of ``(inputs, labels)`` batches shaped (B, D, H, W).
        criterion: loss returning a scalar tensor (mean squared error here).
        device: device the batches are moved to.
        with_metrics: when True, PSNR and SSIM are computed as well.
        data_range: peak intensity used for PSNR and passed to SSIM.

    Returns:
        ``(mean_loss, mean_psnr, mean_ssim)``. The loss and PSNR are averaged
        over batches, SSIM over individual samples. PSNR / SSIM are ``None``
        when ``with_metrics`` is False.

    Raises:
        ValueError: if ``loader`` yields no batches.
    """
    was_training = model.training
    model.eval()
    loss_sum = 0.0
    psnr_sum = 0.0
    ssim_sum = 0.0
    num_batches = 0
    num_samples = 0
    try:
        with torch.no_grad():
            for batch_inputs, batch_labels in loader:
                inputs = _add_channel_axis(batch_inputs, device)
                labels = _add_channel_axis(batch_labels, device)
                outputs = model(inputs)
                batch_mse = criterion(outputs, labels).item()
                loss_sum += batch_mse
                num_batches += 1
                if not with_metrics:
                    continue
                # PSNR from the batch MSE; the epsilon guards log10(0).
                psnr_sum += 20 * math.log10(data_range) - 10 * math.log10(batch_mse + 1e-10)
                predictions = outputs[:, 0].cpu().numpy()
                references = labels[:, 0].cpu().numpy()
                for prediction, reference in zip(predictions, references):
                    # data_range must be given explicitly for floating-point
                    # images; it matches the peak value used for PSNR above.
                    ssim_sum += structural_similarity(
                        prediction, reference, data_range=data_range
                    )
                num_samples += len(predictions)
    finally:
        model.train(was_training)

    if num_batches == 0:
        raise ValueError("validation loader produced no batches")
    mean_loss = loss_sum / num_batches
    if not with_metrics:
        return mean_loss, None, None
    return mean_loss, psnr_sum / num_batches, ssim_sum / num_samples


for epoch in range(num_epochs):
    print(f"Epoch {epoch + 1}")
    model.train()
    epoch_loss = 0.0
    for batch_inputs, batch_labels in train:
        inputs = _add_channel_axis(batch_inputs, device)
        labels = _add_channel_axis(batch_labels, device)
        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        # .item() stores a plain float, so no autograd graph is kept alive in
        # the running sum or in the logged history.
        epoch_loss += loss.item()
    loss_train.append(epoch_loss / len(train))

    want_metrics = (epoch + 1) % metric_interval == 0
    val_loss, val_psnr, val_ssim = _run_validation(
        model, val, criterion, device, with_metrics=want_metrics, data_range=data_range
    )
    loss_val.append(val_loss)
    if want_metrics:
        print(f"Validation PSNR: {val_psnr}")
        PSNR.append(val_psnr)
        print(f"Validation SSIM: {val_ssim}")
        SSIM.append(val_ssim)
print('Finished Training')

In [None]:
# Save the trained model and the logged curves under data/<problem>/.
name = "resblock_32conv_noupdown_500epoch_with_batchnorm"
problem = "problem3"
# Build the location with pathlib instead of hard-coded backslashes so the
# files land in the same data/<problem>/ folder on every operating system,
# and create the folder if it does not exist yet.
out_dir = Path("data") / problem
out_dir.mkdir(parents=True, exist_ok=True)

# NOTE(review): this pickles the whole module object, so loading it later
# requires the same class definition to be available.
torch.save(model, out_dir / f"{name}.pth")

# float(...) accepts plain numbers as well as 0-dim tensors (on any device,
# with or without grad), so the conversion works whatever the training cell
# stored in the history lists, and also when this cell is run a second time.
PSNR = torch.tensor([float(value) for value in PSNR])
torch.save(PSNR, out_dir / f"{name}_PSNR.pth")
SSIM = torch.tensor([float(value) for value in SSIM])
torch.save(SSIM, out_dir / f"{name}_SSIM.pth")
loss_val = torch.tensor([float(value) for value in loss_val])
torch.save(loss_val, out_dir / f"{name}_loss_val.pth")
loss_train = torch.tensor([float(value) for value in loss_train])
torch.save(loss_train, out_dir / f"{name}_loss_train.pth")

In [None]:
# Show the first sample of the *test* set: network input, network output and
# reference label, frame by frame.
inputs, labels = test_data[0], test_label[0]
# Add batch and channel axes: (D, H, W) -> (1, 1, D, H, W).
inputs = inputs[None, None]
labels = labels[None, None]

# Inference: eval mode makes BatchNorm use its running statistics instead of
# the statistics of this single sample, and no_grad avoids building an
# autograd graph. The previous train/eval mode is restored afterwards.
was_training = model.training
model.eval()
try:
    with torch.no_grad():
        outputs = model(inputs.to(device))
finally:
    model.train(was_training)

# Drop the channel axis: (1, 1, D, H, W) -> (1, D, H, W), as a NumPy array.
outputs = outputs.cpu().numpy()[:, 0]
print(outputs[0][0].shape)

num_frames = outputs.shape[1]
imsshow(inputs[0, 0], num_col=5, cmap='gray', is_colorbar=True, titles=["input"] * num_frames)
imsshow(outputs[0], num_col=5, cmap='gray', is_colorbar=True, titles=["output"] * num_frames)
imsshow(labels[0, 0], num_col=5, cmap='gray', is_colorbar=True, titles=["label"] * num_frames)