In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
import math
from pathlib import Path

import numpy as np
import torch
import torch.nn.functional as F
import torchvision.utils
from matplotlib import pyplot as plt
from matplotlib.colors import Normalize
from mpl_toolkits.axes_grid1 import make_axes_locatable
from skimage.color import rgb2gray
from skimage.metrics import structural_similarity
from torch import nn

from proj.resblock import ResBlock
from proj.show import imsshow

In [None]:
# Load the problem-1 outputs (reconstructions, sampling mask, masked k-space) and the
# fully-sampled reference images. Each .npz archive is opened once inside a `with`
# block so its file handle is closed; indexing an NpzFile reads the array into memory,
# so the arrays stay valid after the archive is closed.
with np.load('../data_after_prob1.npz') as prob1:
    img_recon = prob1["recon"]
    mask = prob1['mask']
    k_masked = prob1['k_masked'].astype(np.complex64)
with np.load('../cine.npz') as cine:
    img_fully = cine['dataset']
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
mask = torch.from_numpy(mask)

In [None]:
# Min-max normalize the images.
def normalize(img):
    """Rescale ``img`` to roughly [0, 1] using its global minimum and maximum.

    A 1e-8 term in the denominator avoids division by zero for constant input.
    """
    lowest = img.min()
    span = img.max() - lowest + 1e-8
    return (img - lowest) / span
img_fully, img_recon = normalize(img_fully), normalize(img_recon)
# Preview the first fully-sampled item, five images per row, with a colorbar.
imsshow(img_fully[0], num_col=5, cmap='gray', is_colorbar=True)

In [None]:
# Split the 200 samples 5:1:2 into training (125), validation (25) and test (50) sets.
# Inputs are the reconstructions, labels are the fully-sampled images, and the masked
# k-space is split identically so each sample keeps its own measurements.
train_part, val_part, test_part = slice(None, 125), slice(125, 150), slice(150, None)

train_label, train_data, train_k_masked = img_fully[train_part], img_recon[train_part], k_masked[train_part]
val_label, val_data, val_k_masked = img_fully[val_part], img_recon[val_part], k_masked[val_part]
test_label, test_data, test_k_masked = img_fully[test_part], img_recon[test_part], k_masked[test_part]


In [None]:
# Convert the numpy splits to torch tensors (images as float32; k-space keeps its
# complex dtype) and wrap each split in a DataLoader yielding
# (input image, label image, masked k-space) batches.
train_data, train_label, val_data, val_label, test_data, test_label = (
    torch.from_numpy(arr).float()
    for arr in (train_data, train_label, val_data, val_label, test_data, test_label)
)
train_k_masked, val_k_masked, test_k_masked = (
    torch.from_numpy(arr) for arr in (train_k_masked, val_k_masked, test_k_masked)
)

BATCH_SIZE = 5


def _make_loader(images, labels, kspace, shuffle):
    """Wrap aligned image/label/k-space tensors in a DataLoader with BATCH_SIZE batches."""
    dataset = torch.utils.data.TensorDataset(images, labels, kspace)
    return torch.utils.data.DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=shuffle)


# Only training needs shuffling; evaluation splits are iterated in a fixed order so
# the batches used for metrics and visualization are reproducible.
train = _make_loader(train_data, train_label, train_k_masked, shuffle=True)
val = _make_loader(val_data, val_label, val_k_masked, shuffle=False)
test = _make_loader(test_data, test_label, test_k_masked, shuffle=False)

In [None]:
class ResBlock_Mini(nn.Module):
    """Small 3-D residual denoiser: four 3x3x3 convolutions plus an identity skip.

    The input goes through conv1 -> conv2 -> conv7 -> conv8 (ReLU after the first
    three) and is added back to itself, so the convolutions learn a residual
    correction. Every convolution uses kernel_size=3, padding=1 and stride 1, and conv8
    maps back to ``in_channels``, so the output has the same shape as the input.
    There is no pooling/upsampling path; the volume stays at full resolution.
    Attribute names conv1/conv2/conv7/conv8 are kept so existing state dicts load.

    Args:
        in_channels: channels of the input (and output) volume; default 1.
        mid_channels: output width of conv1 and conv7; default 16.
        wide_channels: output width of conv2; default 32.
    """

    def __init__(self, in_channels=1, mid_channels=16, wide_channels=32):
        super().__init__()
        self.conv1 = nn.Conv3d(in_channels, mid_channels, kernel_size=3, padding=1)
        self.conv2 = nn.Conv3d(mid_channels, wide_channels, kernel_size=3, padding=1)
        self.conv7 = nn.Conv3d(wide_channels, mid_channels, kernel_size=3, padding=1)
        self.conv8 = nn.Conv3d(mid_channels, in_channels, kernel_size=3, padding=1)

    def forward(self, x):
        """Return ``x`` plus the learned residual; same shape as ``x``."""
        out = F.relu(self.conv1(x))
        out = F.relu(self.conv2(out))
        out = F.relu(self.conv7(out))
        # Residual (skip) connection back to the block input.
        return self.conv8(out) + x


In [None]:
# 2-D Fourier transform of a tensor over its last two axes.
# NOTE(review): despite the "c" suffix no fftshift is applied and the default
# (unnormalized forward) scaling is used — confirm this matches how k_masked was built.
def fft2c(img):
    """Forward 2-D FFT of ``img`` over dims (-2, -1)."""
    return torch.fft.fft2(img, dim=(-2, -1))


# Inverse 2-D Fourier transform over the last two axes.
def ifft2c(img):
    """Inverse 2-D FFT of ``img`` over dims (-2, -1)."""
    return torch.fft.ifft2(img, dim=(-2, -1))

In [None]:
class DataConsistencyLayer(nn.Module):
    """Soft data-consistency step between the network output and measured k-space.

    The network's image is taken to k-space with ``fft2c``. Wherever ``mask == 1`` the
    predicted value ``k`` is replaced by ``(k + lambda * k0) / (1 + lambda)``, where
    ``k0`` is the measured k-space. Elsewhere the prediction is kept. The result is
    transformed back with ``ifft2c`` and its magnitude is returned.

    Args:
        mask: sampling mask; entries equal to 1 mark acquired samples. It must
            broadcast against the k-space tensor's trailing dimensions.
        lambda_value: weight of the measured data. 0 keeps the network output;
            larger values pull sampled entries toward the measurement.
    """

    def __init__(self, mask, lambda_value=0.05):
        super().__init__()
        # Non-persistent buffer: moves with the module on .to(device) but is not
        # written to state_dict, so existing checkpoints still load.
        self.register_buffer("mask", torch.as_tensor(mask), persistent=False)
        # lambda can be fixed or trainable. As a trainable
        # nn.Parameter(torch.tensor(lambda_value)) it was observed to converge to 0,
        # so it is kept as a fixed float.
        self.lambda_value = lambda_value

    def print_lambda(self):
        print(self.lambda_value)

    def forward(self, cnn_output, original_kspace):
        """Blend ``cnn_output``'s k-space with ``original_kspace`` at sampled locations.

        The mask is broadcast to the k-space shape, so any batch size works.
        Returns the magnitude of the inverse transform.
        """
        cnn_output_in_kspace = fft2c(cnn_output)
        sampled = (self.mask == 1).expand_as(cnn_output_in_kspace)
        blended = (cnn_output_in_kspace + original_kspace * self.lambda_value) / (1 + self.lambda_value)
        # Out-of-place selection instead of an in-place masked write.
        consistent_kspace = torch.where(sampled, blended, cnn_output_in_kspace)
        return ifft2c(consistent_kspace).abs()

In [None]:
class CascadingModel_with_multi_layer(nn.Module):
    """A single denoise-then-data-consistency stage.

    ``forward(x, k_masked)`` runs the denoiser ``nn1`` on the image ``x`` and hands its
    output, together with the measured k-space ``k_masked``, to the data-consistency
    module ``dc``. More stages can be added by registering further denoiser/DC pairs
    and chaining them in ``forward``.
    """

    def __init__(self, nn1, dc):
        super().__init__()
        # These attribute names become state_dict key prefixes; keep them stable.
        self.denoising_network1 = nn1
        self.data_consistency_layer1 = dc

    def forward(self, x, k_masked):
        return self.data_consistency_layer1(self.denoising_network1(x), k_masked)

In [None]:
mask = mask.to(device)
denoising_net1 = ResBlock_Mini().to(device)
data_consistency_layer = DataConsistencyLayer(mask).to(device)
model = CascadingModel_with_multi_layer(denoising_net1, data_consistency_layer).to(device)
criterion = nn.MSELoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
num_epochs = 100
PSNR = []        # mean validation PSNR, recorded every 5th epoch
SSIM = []        # mean validation SSIM, recorded every 5th epoch
loss_train = []  # mean training MSE per epoch (python floats)
loss_val = []    # mean validation MSE per epoch (python floats)


def _batch_to_device(batch):
    """Reshape the three tensors of a loader batch to (B, 1, 20, 192, 192) and move them to ``device``.

    The batch list is updated in place, so the last batch remains inspectable in
    ``data`` after the loop.
    """
    for idx in range(3):
        batch[idx] = batch[idx].reshape(-1, 1, 20, 192, 192)
    return tuple(t.to(device) for t in batch[:3])


for epoch in range(num_epochs):
    print(f"Epoch {epoch + 1}")
    model.train()
    loss_tmp = 0.0
    for data in train:
        inputs, labels, kmask = _batch_to_device(data)
        optimizer.zero_grad()
        outputs = model(inputs, kmask)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        # .item() returns a python float; accumulating the tensor itself would chain
        # every batch's autograd history and store device tensors in loss_train.
        loss_tmp += loss.item()
    loss_train.append(loss_tmp / len(train))

    # One validation pass per epoch; every 5th epoch PSNR/SSIM are computed from the
    # same forward passes instead of re-running the whole validation set.
    compute_metrics = (epoch + 1) % 5 == 0
    model.eval()
    with torch.no_grad():
        val_loss = 0.0
        val_psnr = 0.0
        val_ssim = 0.0
        n_ssim = 0
        for data in val:
            inputs, labels, kmask = _batch_to_device(data)
            outputs = model(inputs, kmask)
            mse = criterion(outputs, labels).item()
            val_loss += mse + 1e-10
            if compute_metrics:
                # PSNR with peak intensity 1 (images were min-max normalized).
                val_psnr += 20 * math.log10(1) - 10 * math.log10(mse + 1e-5)
                for j in range(outputs.shape[0]):
                    # NOTE(review): recent scikit-image releases require data_range for
                    # float inputs; confirm with the installed version (setting it
                    # changes the SSIM values).
                    val_ssim += structural_similarity(outputs[j][0].cpu().numpy(), labels[j][0].cpu().numpy())
                    n_ssim += 1
        loss_val.append(val_loss / len(val))
    if compute_metrics:
        print(f"Validation PSNR: {val_psnr / len(val)}")
        print(f"Validation SSIM: {val_ssim / n_ssim}")
        PSNR.append(val_psnr / len(val))
        SSIM.append(val_ssim / n_ssim)

In [None]:
# Persist the trained model and its learning curves under data/<problem>/.
name = "bonus_resblock_32conv_noupdown_100epoch_withcascade_0times_new"
problem = "problem5"
# pathlib inserts the OS-specific separator, so this works on Windows and POSIX alike;
# the directory is created if it does not exist yet.
save_dir = Path("data") / problem
save_dir.mkdir(parents=True, exist_ok=True)
torch.save(model, save_dir / f"{name}.pth")
for suffix, values in (("PSNR", PSNR), ("SSIM", SSIM), ("loss_val", loss_val), ("loss_train", loss_train)):
    torch.save(torch.tensor(values), save_dir / f"{name}_{suffix}.pth")

In [None]:
# Show input / output / label volumes for the first sample of one validation batch.
# The batch is drawn explicitly rather than relying on a loop variable leaked from
# the training cell, so this cell also works after a kernel restart.
model.eval()
with torch.no_grad():
    batch = [t.reshape(-1, 1, 20, 192, 192) for t in next(iter(val))]
    inputs, labels, kmask = (t.to(device) for t in batch)
    outputs = model(inputs, kmask)
inp = inputs.cpu().numpy()
out = outputs.cpu().numpy()
lab = labels.cpu().numpy()
for volume, title in ((inp, "input"), (out, "output"), (lab, "label")):
    imsshow(volume[0].reshape(20, 192, 192), num_col=5, cmap='gray', is_colorbar=True, titles=[title] * 20)

In [None]:
#输出测试集的PSNR和SSIM的平均值和标准差
model=torch.load("resblock_32conv_noupdown_100epoch.pth")
with torch.no_grad():
    #print("lambda:",model.data_consistency_layer.print_lambda())
    test_loss = []
    test_ssim = []
    for i, data in enumerate(test, 0):
        data[0]=data[0].reshape(5,1,20,192,192)
        data[1]=data[1].reshape(5,1,20,192,192)
        inputs, labels=  data[0].to(device), data[1].to(device)
        outputs = model(inputs)
        #20 * math.log10(max_intensity) - 10 * np.log10(compute_mse(reconstructed_im, target_im) + eps)
        test_loss .append( 20 * math.log10(1) - 10 * math.log10(criterion(outputs, labels).item()+1e-5))
        for j in range(5):
            test_ssim .append( structural_similarity(outputs[j][0].cpu().numpy(), labels[j][0].cpu().numpy()))
    print(f"Mean Test PSNR after Reconstruction: {sum(test_loss) / len(test_loss)}")
    print(f"Std Test PSNR after Reconstruction: {np.std(test_loss)}")
    print("------------------------------------------------------------")
    print(f"Mean Test SSIM after Reconstruction: {sum(test_ssim) / len(test_ssim)}")
    print(f"Std Test SSIM after Reconstruction: {np.std(test_ssim)}")
    #print(f"Validation PSNR: {test_loss / len(test)}")
    #print(f"Validation SSIM: {test_ssim / len(test) / 5}")

In [None]:
#显示训练集的第一个数据
inputs, labels = test_data[0], test_label[0]
inputs=inputs.reshape(1,1,20,192,192)
labels=labels.reshape(1,1,20,192,192)
outputs = model(inputs.to(device))
outputs = outputs.cpu().detach().numpy()
outputs = outputs.reshape(1,20,192,192)
#input
imsshow (inputs[0].reshape(20,192,192),num_col=5,cmap='gray', is_colorbar=True,titles=["input"]*20)
imsshow(outputs[0],num_col=5,cmap='gray', is_colorbar=True,titles=["output"]*20)
imsshow(labels[0].reshape(20,192,192),num_col=5,cmap='gray', is_colorbar=True,titles=["label"]*20)