In [1]:
%matplotlib inline
import numpy as np
import cv2
import os
import time
from tqdm import tqdm
import matplotlib.pyplot as plt
import pywt
import re
import pydicom
import SimpleITK as sitk
import astra
from ipywidgets import interact

In [3]:
def find_paths_with_id(root_directory, list):    
    for dirpath, dirnames, filenames in os.walk(root_directory):
        for name in dirnames + filenames:
            if 'npy' in name:
                # Construct the full path and add to list
                full_path = os.path.join(dirpath, name)
                img = np.load(full_path)
                img[img<0] = 0
                list.append(img)

In [4]:
dic_noi = '/home/haoran/task1/project/scatter'
dic_clean = '/home/haoran/task1/project/clean'
projections_noise = []
projections = []
find_paths_with_id(dic_clean, projections)
find_paths_with_id(dic_noi, projections_noise)

In [5]:
len(projections_noise)

64

In [6]:
np.max(projections)

4.7473383

In [9]:
from skimage.metrics import structural_similarity
from skimage.metrics import mean_squared_error
import numpy as np

def compute_metrics(original, processed):
    slices = original.shape[0]
    
    mse_values = []
    ssim_values = []
    
    # Determine the data range for floating-point images
    data_range = original.max() - original.min()  # This assumes both images have the same scale
    
    for i in range(slices-1):
        slice_original = original[ i,:, :]
        slice_processed = processed[ i,:, :]
        
        # Compute MSE for the current slice
        mse_slice = mean_squared_error(slice_original, slice_processed)
        mse_values.append(mse_slice)
        
        # Compute SSIM for the current slice, including the data_range
        ssim_slice = structural_similarity(slice_original, slice_processed, data_range=data_range)
        ssim_values.append(ssim_slice)
   
# Calculate NMSE
    numerator = np.sum((original - processed) ** 2)
    denominator = np.sum(original ** 2)

    nmse = numerator / denominator


    # Average MSE and SSIM over all slices
    ssim = np.std(ssim_values)
    mse_avg = np.mean(mse_values)
    ssim_avg = np.mean(ssim_values)
    
    # Compute PSNR for the entire 3D dataset, using the average MSE
    max_pixel_value = data_range  # Adjusted to use the calculated data range
    psnr_avg = 20 * np.log10(max_pixel_value / np.sqrt(mse_avg))
    
    return nmse, 100*ssim_avg, psnr_avg

In [11]:
def wavelet(data):
    depth, height, width = data.shape
 
    final_data = np.zeros((depth, height , width))
    
    for z in range(depth):
        coeff = pywt.wavedec2(data[z], 'haar', level = 4)
        hh0 = coeff[4][2]
        hl0= coeff[4][1]
        lh0 = coeff[4][0]
        hh1 = coeff[3][2]
        hl1 = coeff[3][1]
        lh1 = coeff[3][0]
        hh2 = coeff[2][2]
        hl2 = coeff[2][1]
        lh2 = coeff[2][0]
        hh3 = coeff[1][2]
        hl3 = coeff[1][1]
        lh3 = coeff[1][0]
        ll = coeff[0]
        # print(hl0)
        # print(ll)
        final_data[z, 0:32, 0:32] = ll
        final_data[z, 32:64, 32:64] = hh3  
        final_data[z, 32:64, 0:32] = hl3
        final_data[z, 0:32, 32:64] = lh3  

        final_data[z, 64:128, 64:128] = hh2
        final_data[z, 64:128, 0:64] = hl2
        final_data[z, 0:64, 64:128] = lh2  

        final_data[z, 128:256, 128:256] = hh1  
        final_data[z, 128:256, 0:128] = hl1
        final_data[z, 0:128, 128:256] = lh1  

        final_data[z, 256:512, 256:512] = hh0  
        final_data[z, 256:512, 0:256] = hl0
        final_data[z, 0:256, 256:512] = lh0
    
    return final_data

In [12]:
def re_wavelet(final_data):
    depth, height, width = final_data.shape
    
    data = np.zeros((depth, height, width))
  
    for z in range(depth):
        ll = final_data[z, 0:32, 0:32] 
        hh3 =  final_data[z, 32:64, 32:64]  
        hl3 = final_data[z, 32:64, 0:32]  
        lh3 = final_data[z, 0:32, 32:64]   

        hh2 = final_data[z, 64:128, 64:128] 
        hl2 = final_data[z, 64:128, 0:64]  
        lh2 = final_data[z, 0:64, 64:128]    

        hh1 = final_data[z, 128:256, 128:256]  
        hl1 = final_data[z, 128:256, 0:128] 
        lh1 = final_data[z, 0:128, 128:256]  

        hh0 = final_data[z, 256:512, 256:512] 
        hl0 = final_data[z, 256:512, 0:256] 
        lh0 = final_data[z, 0:256, 256:512] 

        coeff = [ll, (lh3, hl3, hh3), (lh2, hl2, hh2), (lh1, hl1, hh1), (lh0, hl0, hh0)]
        data[z] = pywt.waverec2(coeff, 'haar')

        
        
    return data

In [None]:
ww_target = []


for projection in projections:
    ww = wavelet(projection)
    ww = (ww+12)/(70+12)
    ww_target.append(ww)

ww_input = []
for projection in projections_noise:
    ww = wavelet(projection)
    ww = (ww+9)/58
    ww_input.append(ww)

In [None]:
import torch
from torch import nn

In [None]:
np.max(ww_input)

1.0013726333092

In [None]:
np.min(ww_input)

0.013522246788287985

In [None]:
ww_target_array = np.array(ww_target)
print(np.max(ww_target_array), np.min(ww_target_array))
ww_target_array = np.expand_dims(ww_target_array,1)
# ww_target_array = np.expand_dims(ww_target_array,0)
ww_target_ten = torch.Tensor(ww_target_array)
# ww_target_ten = torch.unsqueeze(ww_target_ten, 0)

ww_input_array = np.array(ww_input)
print(np.max(ww_input_array), np.min(ww_input_array))
ww_input_array = np.expand_dims(ww_input_array,1)
# ww_input_array = np.expand_dims(ww_input_array,0)
ww_input_ten = torch.Tensor(ww_input_array)
# ww_input_ten = torch.unsqueeze(ww_input_ten, 0)

0.9900551772699123 0.003583884820705507
1.0013726333092 0.013522246788287985


In [None]:
from torch.utils.data import DataLoader, TensorDataset, random_split

dataset2 = TensorDataset(ww_input_ten, ww_target_ten)
train, test = random_split(dataset2, (50,14))

train_loader = DataLoader(train, batch_size=1, shuffle=False)
val_loader=DataLoader(test,batch_size=1,shuffle=False)
test_loader = DataLoader(test, batch_size=1, shuffle= False)

In [None]:
device=torch.device("cuda")

In [14]:
class DownsampleBlock(nn.Module):
    def __init__(self, in_channels, out_channels):
        super(DownsampleBlock, self).__init__()
        self.conv = nn.Conv3d(in_channels, out_channels, kernel_size=2, stride=2)
        self.actv = nn.PReLU(out_channels)

    def forward(self, x):
        return self.actv(self.conv(x))


class UpsampleBlock(nn.Module):
    def __init__(self, in_channels, cat_channels, out_channels):
        super(UpsampleBlock, self).__init__()

        self.conv = nn.Conv3d(in_channels + cat_channels, out_channels, 3, padding=1)
        self.conv_t = nn.ConvTranspose3d(in_channels, in_channels, 2, stride=2)
        self.actv = nn.PReLU(out_channels)
        self.actv_t = nn.PReLU(in_channels)

    def forward(self, x):
        upsample, concat = x
        upsample = self.actv_t(self.conv_t(upsample))
        return self.actv(self.conv(torch.cat([concat, upsample], 1)))


class InputBlock(nn.Module):
    def __init__(self, in_channels, out_channels):
        super(InputBlock, self).__init__()
        self.conv_1 = nn.Conv3d(in_channels, out_channels, 3, padding=1)
        self.conv_2 = nn.Conv3d(out_channels, out_channels, 3, padding=1)

        self.actv_1 = nn.PReLU(out_channels)
        self.actv_2 = nn.PReLU(out_channels)

    def forward(self, x):
        x = self.actv_1(self.conv_1(x))
        return self.actv_2(self.conv_2(x))


class OutputBlock(nn.Module):
    def __init__(self, in_channels, out_channels):
        super(OutputBlock, self).__init__()
        self.conv_1 = nn.Conv3d(in_channels, in_channels, 3, padding=1)
        self.conv_2 = nn.Conv3d(in_channels, out_channels, 3, padding=1)

        self.actv_1 = nn.PReLU(in_channels)
        self.actv_2 = nn.PReLU(out_channels)

    def forward(self, x):
        x = self.actv_1(self.conv_1(x))
        return self.actv_2(self.conv_2(x))


class DenoisingBlock(nn.Module):
    def __init__(self, in_channels, inner_channels, out_channels):
        super(DenoisingBlock, self).__init__()
        self.conv_0 = nn.Conv3d(in_channels, inner_channels, 3, padding=1)
        self.conv_1 = nn.Conv3d(in_channels + inner_channels, inner_channels, 3, padding=1)
        self.conv_2 = nn.Conv3d(in_channels + 2 * inner_channels, inner_channels, 3, padding=1)
        self.conv_3 = nn.Conv3d(in_channels + 3 * inner_channels, out_channels, 3, padding=1)

        self.actv_0 = nn.PReLU(inner_channels)
        self.actv_1 = nn.PReLU(inner_channels)
        self.actv_2 = nn.PReLU(inner_channels)
        self.actv_3 = nn.PReLU(out_channels)

    def forward(self, x):
        out_0 = self.actv_0(self.conv_0(x))

        out_0 = torch.cat([x, out_0], 1)
        out_1 = self.actv_1(self.conv_1(out_0))

        out_1 = torch.cat([out_0, out_1], 1)
        out_2 = self.actv_2(self.conv_2(out_1))

        out_2 = torch.cat([out_1, out_2], 1)
        out_3 = self.actv_3(self.conv_3(out_2))

        return out_3 + x


class RDUNet(nn.Module):
    r"""
    Residual-Dense U-net for image denoising.
    """
    def __init__(self, **kwargs):
        super().__init__()

        channels = 1
        filters_0 = 8
        filters_1 = 2 * filters_0
        filters_2 = 4 * filters_0
        filters_3 = 8 * filters_0

        # Encoder:
        # Level 0:
        self.input_block = InputBlock(channels, filters_0)
        self.block_0_0 = DenoisingBlock(filters_0, filters_0 // 2, filters_0)
        self.block_0_1 = DenoisingBlock(filters_0, filters_0 // 2, filters_0)
        self.down_0 = DownsampleBlock(filters_0, filters_1)

        # Level 1:
        self.block_1_0 = DenoisingBlock(filters_1, filters_1 // 2, filters_1)
        self.block_1_1 = DenoisingBlock(filters_1, filters_1 // 2, filters_1)
        self.down_1 = DownsampleBlock(filters_1, filters_2)

        # Level 2:
        self.block_2_0 = DenoisingBlock(filters_2, filters_2 // 2, filters_2)
        self.block_2_1 = DenoisingBlock(filters_2, filters_2 // 2, filters_2)
        self.down_2 = DownsampleBlock(filters_2, filters_3)

        # Level 3 (Bottleneck)
        self.block_3_0 = DenoisingBlock(filters_3, filters_3 // 2, filters_3)
        self.block_3_1 = DenoisingBlock(filters_3, filters_3 // 2, filters_3)

        # Decoder
        # Level 2:
        self.up_2 = UpsampleBlock(filters_3, filters_2, filters_2)
        self.block_2_2 = DenoisingBlock(filters_2, filters_2 // 2, filters_2)
        self.block_2_3 = DenoisingBlock(filters_2, filters_2 // 2, filters_2)

        # Level 1:
        self.up_1 = UpsampleBlock(filters_2, filters_1, filters_1)
        self.block_1_2 = DenoisingBlock(filters_1, filters_1 // 2, filters_1)
        self.block_1_3 = DenoisingBlock(filters_1, filters_1 // 2, filters_1)

        # Level 0:
        self.up_0 = UpsampleBlock(filters_1, filters_0, filters_0)
        self.block_0_2 = DenoisingBlock(filters_0, filters_0 // 2, filters_0)
        self.block_0_3 = DenoisingBlock(filters_0, filters_0 // 2, filters_0)

        self.output_block = OutputBlock(filters_0, channels)
        self.sigmoid = nn.Sigmoid()

    def forward(self, inputs):
        out_0 = self.input_block(inputs)    # Level 0
        out_0 = self.block_0_0(out_0)
        out_0 = self.block_0_1(out_0)

        out_1 = self.down_0(out_0)          # Level 1
        out_1 = self.block_1_0(out_1)
        out_1 = self.block_1_1(out_1)

        out_2 = self.down_1(out_1)          # Level 2
        out_2 = self.block_2_0(out_2)
        out_2 = self.block_2_1(out_2)

        out_3 = self.down_2(out_2)          # Level 3 (Bottleneck)
        out_3 = self.block_3_0(out_3)
        out_3 = self.block_3_1(out_3)

        out_4 = self.up_2([out_3, out_2])   # Level 2
        out_4 = self.block_2_2(out_4)
        out_4 = self.block_2_3(out_4)

        out_5 = self.up_1([out_2, out_1])   # Level 1
        out_5 = self.block_1_2(out_5)
        out_5 = self.block_1_3(out_5)

        out_6 = self.up_0([out_5, out_0])   # Level 0
        out_6 = self.block_0_2(out_6)
        out_6 = self.block_0_3(out_6)

        return self.output_block(out_6) + inputs

In [None]:
def train1(model1, train_loader, optimizer):
    loss_fn = nn.MSELoss()

    model1.train()
    train_loss = 0
    for x, y in train_loader:
        x, y = x.to(device), y.to(device)
        optimizer.zero_grad()
        x = model1(x)
        
        x1 = x[:,:,:, 0:32, 0:32]
        x2 = x[:,:,:, 32:64, 32:64]  
        x3 = x[:,:,:, 32:64, 0:32]  
        x4 = x[:,:,:, 0:32, 32:64] 

        x5 = x[:,:,:, 64:128, 64:128]   
        x6 = x[:,:,:, 64:128, 0:64]
        x7 = x[:,:,:, 0:64, 64:128] 
        x8 = x[:,:,:, 128:256, 128:256]

        x9 = x[:,:,:, 128:256, 0:128]
        x10 = x[:,:,:, 0:128, 128:256]
        x11 = x[:,:,:, 256:512, 256:512]
        x12 = x[:,:,:, 256:512, 0:256]

        x13 = x[:,:,:, 0:256, 256:512]

        y1 = y[:,:,:, 0:32, 0:32]
        y2 = y[:,:,:, 32:64, 32:64]  
        y3 = y[:,:,:, 32:64, 0:32]  
        y4 = y[:,:,:, 0:32, 32:64] 

        y5 = y[:,:,:, 64:128, 64:128]   
        y6 = y[:,:,:, 64:128, 0:64]
        y7 = y[:,:,:, 0:64, 64:128] 
        y8 = y[:,:,:, 128:256, 128:256]

        y9 = y[:,:,:, 128:256, 0:128]
        y10 = y[:,:,:, 0:128, 128:256]
        y11 = y[:,:,:, 256:512, 256:512]
        y12 = y[:,:,:, 256:512, 0:256]

        y13 = y[:,:,:, 0:256, 256:512]
        
        loss = loss_fn(x1, y1) +loss_fn(x2, y2) +loss_fn(x3, y3)+ loss_fn(x4, y4)+ loss_fn(x5, y5) + loss_fn(x6, y6) +loss_fn(x7, y7)+loss_fn(x8, y8)+ loss_fn(x9, y9) +loss_fn(x10, y10) +loss_fn(x11, y11)+ loss_fn(x12, y12)+ loss_fn(x13, y13)  
        # loss = 0.003*loss_fn(x1, y1) +0.032*loss_fn(x2, y2) +0.701*loss_fn(x3, y3)+ 0.018*loss_fn(x4, y4)+ 0.021*loss_fn(x5, y5) + 0.080*loss_fn(x6, y6) +0.02*loss_fn(x7, y7)+0.02*loss_fn(x8, y8)+ 0.025*loss_fn(x9, y9) +0.02*loss_fn(x10, y10) +0.02*loss_fn(x11, y11)+ 0.02*loss_fn(x12, y12)+ 0.02*loss_fn(x13, y13) 

        loss.backward()  # Accumulate gradients
        train_loss += loss.item()

        optimizer.step()  # Perform optimization step


    train_loss /= len(train_loader.dataset)

    return train_loss


In [None]:
def validate(model, val_loader):
    model.eval() 
    val_loss = 0
    loss_fn = nn.MSELoss()
    
    with torch.no_grad():
        for x, y in val_loader:
            x, y = x.to(device), y.to(device)
            x = model(x)
            x1 = x[:,:,:, 0:32, 0:32]
            x2 = x[:,:,:, 32:64, 32:64]  
            x3 = x[:,:,:, 32:64, 0:32]  
            x4 = x[:,:,:, 0:32, 32:64] 

            x5 = x[:,:,:, 64:128, 64:128]   
            x6 = x[:,:,:, 64:128, 0:64]
            x7 = x[:,:,:, 0:64, 64:128] 
            x8 = x[:,:,:, 128:256, 128:256]

            x9 = x[:,:,:, 128:256, 0:128]
            x10 = x[:,:,:, 0:128, 128:256]
            x11 = x[:,:,:, 256:512, 256:512]
            x12 = x[:,:,:, 256:512, 0:256]

            x13 = x[:,:,:, 0:256, 256:512]

            y1 = y[:,:,:, 0:32, 0:32]
            y2 = y[:,:,:, 32:64, 32:64]  
            y3 = y[:,:,:, 32:64, 0:32]  
            y4 = y[:,:,:, 0:32, 32:64] 

            y5 = y[:,:,:, 64:128, 64:128]   
            y6 = y[:,:,:, 64:128, 0:64]
            y7 = y[:,:,:, 0:64, 64:128] 
            y8 = y[:,:,:, 128:256, 128:256]

            y9 = y[:,:,:, 128:256, 0:128]
            y10 = y[:,:,:, 0:128, 128:256]
            y11 = y[:,:,:, 256:512, 256:512]
            y12 = y[:,:,:, 256:512, 0:256]

            y13 = y[:,:,:, 0:256, 256:512]
                
            # loss = 0.003*loss_fn(x1, y1) +0.032*loss_fn(x2, y2) +0.701*loss_fn(x3, y3)+ 0.018*loss_fn(x4, y4)+ 0.021*loss_fn(x5, y5) + 0.080*loss_fn(x6, y6) +0.02*loss_fn(x7, y7)+0.02*loss_fn(x8, y8)+ 0.025*loss_fn(x9, y9) +0.02*loss_fn(x10, y10) +0.02*loss_fn(x11, y11)+ 0.02*loss_fn(x12, y12)+ 0.02*loss_fn(x13, y13) 
        
   
            loss = loss_fn(x1, y1) +loss_fn(x2, y2) +loss_fn(x3, y3)+ loss_fn(x4, y4)+ loss_fn(x5, y5) + loss_fn(x6, y6) +loss_fn(x7, y7)+loss_fn(x8, y8)+ loss_fn(x9, y9) +loss_fn(x10, y10) +loss_fn(x11, y11)+ loss_fn(x12, y12)+ loss_fn(x13, y13) 

            val_loss += loss.item()
    val_loss /= len(val_loader.dataset)

    return val_loss

In [None]:
import numpy as np
import torch

class EarlyStopping:
    """Early stops the training if validation loss doesn't improve after a given patience."""
    def __init__(self, patience=7, verbose=False, delta=0, path='checkpoint.pt', trace_func=print):
        """
        Args:
            patience (int): How long to wait after last time validation loss improved.
                            Default: 7
            verbose (bool): If True, prints a message for each validation loss improvement. 
                            Default: False
            delta (float): Minimum change in the monitored quantity to qualify as an improvement.
                            Default: 0
            path (str): Path for the checkpoint to be saved to.
                            Default: 'checkpoint.pt'
            trace_func (function): trace print function.
                            Default: print            
        """
        self.patience = patience
        self.verbose = verbose
        self.counter = 0
        self.best_score = None
        self.early_stop = False
        self.val_loss_min = np.Inf
        self.delta = delta
        self.path = path
        self.trace_func = trace_func
    def __call__(self, val_loss, model):

        score = -val_loss

        if self.best_score is None:
            self.best_score = score
        elif score < self.best_score + self.delta:
            self.counter += 1
            self.trace_func(f'EarlyStopping counter: {self.counter} out of {self.patience}')
            if self.counter >= self.patience:
                self.early_stop = True
        else:
            self.best_score = score
            self.counter = 0



In [15]:
import torch.optim as optim
from torch.optim.lr_scheduler import CosineAnnealingLR

import os
epoch_save = 10

val_loss_set=[]
train_loss_set=[]

myUNet1 = RDUNet().to(device)
early_stopping = EarlyStopping(patience=20, verbose=True)
optimizer = optim.Adam( myUNet1.parameters(), lr=0.0001)
scheduler = CosineAnnealingLR(optimizer, T_max=300, eta_min=0)
for epoch in range(300):
    train_loss = train1(myUNet1,train_loader, optimizer)
    
    val_loss = validate(myUNet1, val_loader)
    scheduler.step()

    print('epoch', str(epoch),':train_loss:',train_loss, 'val_loss:',val_loss)
    # print("train_loss: %.6f, val_loss: %.6f", % (train_loss, val_loss))

    if (epoch+1) % epoch_save == 0:
        save_path=os.path.join("/home/haoran/task1/save_model", "model_ww" + str(epoch+1).zfill(4) + ".pth")
        torch.save(myUNet1.state_dict(), save_path)
        

    early_stopping(val_loss = val_loss, model=myUNet1.eval)

   
    save_path=os.path.join("/home/haoran/task1/save_model", "model_ww.pth")
    torch.save(myUNet1.state_dict(), save_path)
    
    # val_loss_set.append(val_loss)
    train_loss_set.append(train_loss)


NameError: name 'device' is not defined

In [16]:
myUNet1 = RDUNet()
myUNet1.load_state_dict(torch.load("*"))
# need to retrain mse 210 sino 170  me 240

<All keys matched successfully>

In [None]:
wl = []
wl_tar = []
recon = []
recon_target = []
myUNet1.cpu()
for (x,y),proj in tqdm(zip(test_loader,projections)):
    # print(x.shape)
    output = myUNet1(x)
    output = np.reshape((output.detach().numpy()),(32,512, 512))
    
    wl.append(output)
    output = (70+12)*output - 12
    img = re_wavelet(output)
    recon.append(img)
    
    target = np.reshape((y.detach().numpy()),(32,512,512))

    wl_tar.append(target)
    target = (70+12)*target - 12
    target = re_wavelet(target)
    recon_target.append(target)

14it [01:51,  7.99s/it]


In [None]:
def plot_image(index):
    plt.imshow(recon_target[0][index,:,:], cmap='gray')
    plt.title(f"Image {index}")
    plt.colorbar()
    plt.show()
interact(plot_image, index=(0, recon_target[0].shape[0]-1))

interactive(children=(IntSlider(value=15, description='index', max=31), Output()), _dom_classes=('widget-inter…

<function __main__.plot_image(index)>

In [None]:
def plot_image(index):
    plt.imshow(recon[0][index,:,:], cmap='gray')
    plt.title(f"Image {index}")
    plt.colorbar()
    plt.show()
interact(plot_image, index=(0, recon_target[0].shape[0]-1))

interactive(children=(IntSlider(value=15, description='index', max=31), Output()), _dom_classes=('widget-inter…

<function __main__.plot_image(index)>

In [None]:
def plot_image(index):
    plt.imshow(wl[0][index,:,:], cmap='gray')
    plt.title(f"Image {index}")
    plt.colorbar()
    plt.show()
interact(plot_image, index=(0, wl[0].shape[0]-1))

interactive(children=(IntSlider(value=15, description='index', max=31), Output()), _dom_classes=('widget-inter…

<function __main__.plot_image(index)>

In [None]:
def plot_image(index):
    plt.imshow(wl_tar[0][index,:,:], cmap='gray')
    plt.title(f"Image {index}")
    plt.colorbar()
    plt.show()
interact(plot_image, index=(0, wl[0].shape[0]-1))

interactive(children=(IntSlider(value=15, description='index', max=31), Output()), _dom_classes=('widget-inter…

<function __main__.plot_image(index)>

In [None]:
temp = recon[0]-recon_target[0]

In [None]:
def plot_image(index):
    plt.imshow(temp[index,:,:], cmap='gray')
    plt.title(f"Image {index}")
    plt.colorbar()
    plt.show()
interact(plot_image, index=(0, recon_target[0].shape[0]-1))

interactive(children=(IntSlider(value=15, description='index', max=31), Output()), _dom_classes=('widget-inter…

<function __main__.plot_image(index)>

In [None]:
def plot_image(index):
    plt.imshow(wl[0][index,:,:], cmap='gray')
    plt.title(f"Image {index}")
    plt.colorbar()
    plt.show()
interact(plot_image, index=(0, recon_target[0].shape[0]-1))

interactive(children=(IntSlider(value=15, description='index', max=31), Output()), _dom_classes=('widget-inter…

<function __main__.plot_image(index)>

In [None]:




mse_ = []
ssim_ = []
psnr_ = []
for a, b in zip(wl, wl_tar):
    mse, ssim, psnr = compute_metrics(a, b)
    mse_.append(mse)
    ssim_.append(ssim)
    psnr_.append(psnr)
print(f"MSE: {np.mean(mse_),np.std(mse_)}, SSIM: {np.mean(ssim_),np.std(ssim_)}, PSNR: {np.mean(psnr_),np.std(psnr_)}")


MSE: (5.5455216e-06, 1.8713496e-06), SSIM: (99.99376027011985, 0.001208096570093948), PSNR: (67.89076216380359, 0.8510168867694164)


In [None]:
mse_ = []
ssim_ = []
psnr_ = []
for a, b in zip(recon_target, recon):

    mse, ssim, psnr = compute_metrics(a, b)
    mse_.append(mse)
    ssim_.append(ssim)
    psnr_.append(psnr)
print(f"MSE: {np.mean(mse_),np.std(mse_)}, SSIM: {np.mean(ssim_),np.std(ssim_)}, PSNR: {np.mean(psnr_),np.std(psnr_)}")


MSE: (0.0005297821080679277, 0.00011584852215238191), SSIM: (97.42182720618158, 0.3869787891127376), PSNR: (43.210290443086215, 0.8817996182168156)


In [None]:
def plot_image(index):
    plt.imshow(recon_target[0][index,:,:], cmap='gray')
    plt.title(f"Image {index}")
    plt.colorbar()
    plt.show()
interact(plot_image, index=(0, recon_target[0].shape[0]-1))

interactive(children=(IntSlider(value=15, description='index', max=31), Output()), _dom_classes=('widget-inter…

<function __main__.plot_image(index)>

In [None]:
def plot_image(index):
    plt.imshow(recon[0][index,:,:], cmap='gray')
    plt.title(f"Image {index}")
    plt.colorbar()
    plt.show()
interact(plot_image, index=(0, recon_target[10].shape[0]-1))

interactive(children=(IntSlider(value=15, description='index', max=31), Output()), _dom_classes=('widget-inter…

<function __main__.plot_image(index)>