In [1]:
import torch
import numpy as np
import matplotlib.pyplot as plt
import torch.nn.functional as F

In [27]:
noddi= torch.from_numpy(np.load('../data/noddi.npy'))

In [3]:
# show img
def show_img(restore_x1, name1, restore_x2, name2, font_size= 16):
    if isinstance(restore_x1, torch.Tensor) or isinstance(restore_x2, torch.Tensor):
        restore_x1, restore_x2= restore_x1.numpy(), restore_x2.numpy()
    fig, axs = plt.subplots(nrows=2, ncols=3, figsize=(25.2, 20))
    a, b, c = restore_x1[0].shape
    # add image
    axs[0, 0].imshow(restore_x1[0, :, :, c // 2].T[::-1], cmap='gray')
    axs[0, 1].imshow(restore_x1[1, :, :, c // 2].T[::-1], cmap='gray')
    axs[0, 2].imshow(restore_x1[2, :, :, c // 2].T[::-1], cmap='gray')
    axs[1, 0].imshow(restore_x2[0, :, :, c // 2].T[::-1], cmap='gray')
    axs[1, 1].imshow(restore_x2[1, :, :, c // 2].T[::-1], cmap='gray')
    axs[1, 2].imshow(restore_x2[2, :, :, c // 2].T[::-1], cmap='gray')
    # drop axis
    for ax in axs.flat:
        ax.axis('off')
    # add titles
    col_titles = ['Axial plane1', 'Axial plane2', 'Axial plane3']
    row_titles = [name1, name2]
    font_dict = {'family': 'Times New Roman', 'size': font_size}
    for ax, col in zip(axs[0], col_titles):
        ax.set_title(col, fontdict=font_dict, fontweight='bold')
    for i, row in enumerate(row_titles):
        fig.text(0.1, 0.7- i* 0.4, row, ha= 'center', va= 'center', fontsize= font_size, fontfamily= 'Times New Roman', rotation= 90, fontweight='bold')
    # adjust gaps
    plt.subplots_adjust(wspace=0, hspace=0)
    plt.show()

In [4]:
# apply gaussian blur on image
def gaussian_blur_3d(img, kernel_size= 5, sigma= 1.0):
    channels, depth, height, width= img.shape
    # 3d gaussian kernel
    grid= torch.stack(torch.meshgrid([torch.arange(size, dtype= torch.float32) for size in [kernel_size]* 3]), dim= -1)
    grid= grid- (kernel_size- 1)/ 2.0
    kernel= torch.exp(-(grid** 2).sum(-1)/ (2* sigma** 2))
    kernel= kernel/ kernel.sum()
    kernel= kernel.view(1, 1, *kernel.shape)  # [1, 1, k, k, k]
    kernel= kernel.repeat(channels, 1, 1, 1, 1)  # [C, 1, k, k, k]
    # conv
    img= img.unsqueeze(0)  # [1, C, D, H, W]
    blurred_img= F.conv3d(img, kernel, padding= kernel_size//2, groups= channels)
    return blurred_img.squeeze(0)  # [C, D, H, W]

In [8]:
noddi_blur= gaussian_blur_3d(noddi, 3, 1e0)
%matplotlib qt
show_img(noddi, 'NODDI', noddi_blur, 'After blur')

In [9]:
# apply sharp on image
def sharpen_3d(img, alpha= 0.1, ratio= 0.4):
    channels, depth, height, width= img.shape
    # define laplacian kernel
    laplacian_kernel= torch.tensor([[[[0,  0,  0],
                                       [0, -1,  0],
                                       [0,  0,  0]],
                                      [[0, -1,  0],
                                       [-1, 6, -1],
                                       [0, -1,  0]],
                                      [[0,  0,  0],
                                       [0, -1,  0],
                                       [0,  0,  0]]]], dtype= torch.float32)
    # adjust kernel, [1, 1, 3, 3, 3]
    laplacian_kernel= laplacian_kernel.view(1, 1, 3, 3, 3)
    laplacian_kernel= laplacian_kernel.repeat(channels, 1, 1, 1, 1)
    # add batch
    img= img.unsqueeze(0) # [1, C, D, H, W]
    # 3dconv
    sharpened_img= F.conv3d(img, alpha* laplacian_kernel, padding= 1, groups= channels)
    # mix by interpolate
    return ratio* sharpened_img.squeeze(0)+ (1- ratio)* img.squeeze(0)  # [C, D, H, W]

In [10]:
noddi_sharp= sharpen_3d(noddi, 1, 1e-2)
noddi_sharp[noddi_sharp<= 1e-2]= 0
show_img(noddi, 'NODDI', noddi_sharp, 'After sharp')

In [11]:
# pseudo jpeg compression 
# quality \in [0, 100)
def compression_3d(img, quality= 99.9):
    quantization_factor= (100- quality)/ 100.0
    quantized_img= torch.round(img/ (quantization_factor+ 1e-10))* quantization_factor
    return torch.clamp(quantized_img, 0, 1)

In [12]:
noddi_cprs= compression_3d(noddi, 80)
show_img(noddi, 'NODDI', noddi_cprs, 'After compress')

In [13]:
# downsample and upsample
def dusample_3d(x):
    assert(x.shape[-1]% 2== 0 and x.shape[-2]% 2== 0 and x.shape[-3]% 2== 0, 'Dim should be even.')
    x_ds= F.interpolate(x.unsqueeze(0), scale_factor= (0.5, 0.5, 0.5))
    return F.interpolate(x_ds, scale_factor= (2, 2, 2)).squeeze(0)

  assert(x.shape[-1]% 2== 0 and x.shape[-2]% 2== 0 and x.shape[-3]% 2== 0, 'Dim should be even.')


In [14]:
noddi_du= dusample_3d(noddi)
show_img(noddi, 'NODDI', noddi_du, 'After downsample and upsample')

In [16]:
def erosion_3d(img, kernel_size= 3):
    eroded_img= -F.max_pool3d(-img.unsqueeze(0), kernel_size= kernel_size, stride= 1, padding= kernel_size//2)
    return eroded_img.squeeze(0)

In [19]:
noddi_eros= erosion_3d(noddi, 2)
show_img(noddi, 'NODDI', noddi_eros, 'After erosion')

In [20]:
def add_gaussian_noise(img, mean= 0.0, std= 0.1):
    noise= torch.randn_like(img)* std+ mean
    noisy_img= img+ noise
    return torch.clamp(noisy_img, 0, 1)

In [22]:
noddi_gauss_noise= add_gaussian_noise(noddi)
noddi_gauss_noise[noddi<= 1e-2]= 0
show_img(noddi, 'NODDI', noddi_gauss_noise, 'After add gaussian noise')

In [23]:
def add_salt_and_pepper_noise(img, salt_prob= 0.01, pepper_prob= 0.01):
    salt_mask= torch.rand_like(img)< salt_prob
    pepper_mask= torch.rand_like(img)< pepper_prob
    img[salt_mask]= 1.0
    img[pepper_mask]= 0.0
    return img

In [29]:
noddi_sp_noise= add_salt_and_pepper_noise(noddi.clone())
noddi_sp_noise[noddi<= 1e-2]= 0
show_img(noddi, 'NODDI', noddi_sp_noise, 'After add salt and pepper noise')

#### Meshgrid

##### 关于index的区别

In [95]:
torch.stack(torch.meshgrid([torch.tensor([0, 1]), torch.tensor([1, 2, 3])], indexing= 'ij'), dim= -1)

tensor([[[0, 1],
         [0, 2],
         [0, 3]],

        [[1, 1],
         [1, 2],
         [1, 3]]])

In [97]:
torch.stack(torch.meshgrid([torch.tensor([0, 1]), torch.tensor([1, 2, 3])], indexing= 'xy'), dim= -1)

tensor([[[0, 1],
         [1, 1]],

        [[0, 2],
         [1, 2]],

        [[0, 3],
         [1, 3]]])

结论：torch.meshgrid本质都是卡氏积，只不过表达形式不同而已

##### 应用(张量选择元素)

In [86]:
mat= torch.arange(0, 33, 1).view(11, 3)

我要选择元素，行号为[0, 4, 8]，列号为[0, 2]

In [87]:
mat[[0, 4, 8], :][:, [0, 2]]

tensor([[ 0,  2],
        [12, 14],
        [24, 26]])

In [88]:
mat[torch.meshgrid(torch.tensor([0, 4, 8]), torch.tensor([0, 2]))]

tensor([[ 0,  2],
        [12, 14],
        [24, 26]])