In [1]:
import math
import torch
import torch.nn as nn

class Token_performer(nn.Module):
    """Performer-style token mixer used by the tokens-to-token stages.

    Softmax attention is approximated with positive random features
    (``prm_exp``). The projected attention output is added to ``v`` (not to the
    input), so the block changes the token width from ``dim`` to
    ``in_dim * head_cnt``. A residual MLP follows.

    Args:
        dim: width of the incoming tokens.
        in_dim: output width per "head"; the block emits ``in_dim * head_cnt``.
        head_cnt: multiplier on ``in_dim`` (no per-head split is performed).
        kernel_ratio: number of random features as a fraction of the output width.
        dp1: dropout applied after the attention projection.
        dp2: dropout applied at the end of the MLP.
    """

    def __init__(self, dim, in_dim, head_cnt=1, kernel_ratio=0.5, dp1=0.1, dp2 = 0.1):
        super().__init__()
        # Submodules are created in the original order so seeded initialisation is unchanged.
        emb = in_dim * head_cnt
        self.emb = emb
        self.kqv = nn.Linear(dim, 3 * emb)
        self.dp = nn.Dropout(dp1)
        self.proj = nn.Linear(emb, emb)
        self.head_cnt = head_cnt
        self.norm1 = nn.LayerNorm(dim)
        self.norm2 = nn.LayerNorm(emb)
        self.epsilon = 1e-8  # keeps the attention normaliser away from zero

        self.mlp = nn.Sequential(
            nn.Linear(emb, emb),
            nn.GELU(),
            nn.Linear(emb, emb),
            nn.Dropout(dp2),
        )

        # Fixed (non-trainable) orthogonal random projection, scaled by sqrt(m).
        self.m = int(emb * kernel_ratio)
        random_w = torch.randn(self.m, emb)
        self.w = nn.Parameter(nn.init.orthogonal_(random_w) * math.sqrt(self.m),
                              requires_grad=False)

    def prm_exp(self, x):
        """Positive random features: ``exp(w x - |x|^2 / 2) / sqrt(m)`` per token."""
        half_sq_norm = (x * x).sum(dim=-1, keepdim=True) / 2   # broadcasts over the m features
        projected = torch.einsum('bti,mi->btm', x.float(), self.w)
        return torch.exp(projected - half_sq_norm) / math.sqrt(self.m)

    def single_attn(self, x):
        """Linear-time attention over the tokens of ``x``, plus a skip connection on ``v``."""
        k, q, v = torch.split(self.kqv(x), self.emb, dim=-1)
        kp = self.prm_exp(k)
        qp = self.prm_exp(q)
        # Normaliser: for each query, dot product with the sum of all key features.
        normaliser = torch.einsum('bti,bi->bt', qp, kp.sum(dim=1)).unsqueeze(dim=2)
        kptv = torch.einsum('bin,bim->bnm', v.float(), kp)
        attended = torch.einsum('bti,bni->btn', qp, kptv) / (normaliser + self.epsilon)
        # skip connection
        return v + self.dp(self.proj(attended))

    def forward(self, x):
        attended = self.single_attn(self.norm1(x))
        return attended + self.mlp(self.norm2(attended))


In [2]:
import torch.nn as nn
from timm.models.layers import DropPath

class Attention(nn.Module):
    def __init__(self, dim, num_heads=8, in_dim = None, qkv_bias=False, qk_scale=None,
                 attn_drop=0., proj_drop=0.):
        super().__init__()
        self.num_heads = num_heads
        self.in_dim = in_dim
        head_dim = dim // num_heads
        self.scale = qk_scale or head_dim ** -0.5

        self.qkv = nn.Linear(dim, in_dim * 3, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(in_dim, in_dim)
        self.proj_drop = nn.Dropout(proj_drop)

    def forward(self, x):
        B, N, C = x.shape

        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, self.in_dim).permute(2, 0, 3, 1, 4)
        q, k, v = qkv[0], qkv[1], qkv[2]

        attn = (q @ k.transpose(-2, -1)) * self.scale
        attn = attn.softmax(dim=-1)
        attn = self.attn_drop(attn)

        x = (attn @ v).transpose(1, 2).reshape(B, N, self.in_dim)
        x = self.proj(x)
        x = self.proj_drop(x)

        x = v.squeeze(1) + x 

        return x


class Token_transformer(nn.Module):
    def __init__(self, dim, in_dim, num_heads, mlp_ratio=1., qkv_bias=False, qk_scale=None,
                 drop=0., attn_drop=0., drop_path=0., act_layer=nn.GELU,
                 norm_layer=nn.LayerNorm):
        super().__init__()
        self.norm1 = norm_layer(dim)
        self.attn = Attention(
            dim, in_dim=in_dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale,
            attn_drop=attn_drop, proj_drop=drop)
        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
        self.norm2 = norm_layer(in_dim)
        self.mlp = Mlp(in_features=in_dim, hidden_features=int(in_dim*mlp_ratio),
                       out_features=in_dim, act_layer=act_layer, drop=drop)

    def forward(self, x):
        x = self.attn(self.norm1(x))
        x = x + self.drop_path(self.mlp(self.norm2(x)))
        return x

In [3]:
import torch
import torch.nn as nn
import numpy as np
from timm.models.layers import DropPath

def intra_att_f(querys, keys, values):
    """Gate ``values`` with an intra-feature attention map built from all query/key pairs.

    For query row ``q`` and feature ``j``, the original double Python loop summed
    ``q[i] * k[j]`` over every key ``k`` and every ``i != j`` (the outer product with
    its diagonal masked). That equals ``(sum(q) - q[j]) * sum_k k[j]``. The closed form
    is used here, which avoids O(Q * K * L^2) Python loops. It also keeps the
    computation on the inputs' device and dtype; the old code allocated CPU tensors,
    which fails for GPU inputs.

    Args:
        querys: tensor of shape (Q, L).
        keys: tensor of shape (K, L).
        values: tensor broadcastable with (Q, L).

    Returns:
        ``values * att`` where ``att`` has shape (Q, L).
    """
    # Sum of each query's features excluding feature j (the masked diagonal).
    q_excl = querys.sum(dim=-1, keepdim=True) - querys   # (Q, L)
    # Per-feature sum over all keys.
    k_total = keys.sum(dim=0)                             # (L,)
    att = q_excl * k_total
    return values * att


class Mlp(nn.Module):
    """Feed-forward network: Linear -> activation -> Dropout -> Linear -> Dropout.

    ``hidden_features`` and ``out_features`` fall back to ``in_features`` when they
    are not given (or are 0). One Dropout module is shared by both dropout points.
    """

    def __init__(self, in_features, hidden_features=None, out_features=None,
                 act_layer=nn.GELU, drop=0.):
        super().__init__()
        hidden = hidden_features or in_features
        out = out_features or in_features
        self.fc1 = nn.Linear(in_features, hidden)
        self.act = act_layer()
        self.fc2 = nn.Linear(hidden, out)
        self.drop = nn.Dropout(drop)

    def forward(self, x):
        hidden = self.drop(self.act(self.fc1(x)))
        return self.drop(self.fc2(hidden))


class Attention(nn.Module):
    """Multi-head self-attention with an output projection (input and output are (B, N, dim)).

    ``dim`` is split evenly across ``num_heads``. The attention temperature is
    ``qk_scale`` if given, otherwise ``(dim // num_heads) ** -0.5``.
    """

    def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0.,
                 proj_drop=0.):
        super().__init__()
        self.num_heads = num_heads
        self.scale = qk_scale or (dim // num_heads) ** -0.5

        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

    def forward(self, x):
        batch, tokens, channels = x.shape
        per_head = channels // self.num_heads
        # (B, N, 3C) -> (3, B, heads, N, C / heads)
        packed = self.qkv(x).reshape(batch, tokens, 3, self.num_heads, per_head)
        q, k, v = packed.permute(2, 0, 3, 1, 4).unbind(0)

        weights = (q @ k.transpose(-2, -1)) * self.scale
        weights = self.attn_drop(weights.softmax(dim=-1))

        mixed = (weights @ v).transpose(1, 2).reshape(batch, tokens, channels)
        return self.proj_drop(self.proj(mixed))


class intra_att(nn.Module):
    """Learned q/k/v projection followed by intra-feature attention gating.

    Input ``x`` has shape (N, C) and the output has the same shape. Each output is
    ``v[n, j] * (sum_i q[n, i] - q[n, j]) * sum_m k[m, j]``. This is the same map the
    masked outer-product loop computes (see ``intra_att_f``), evaluated in closed form
    on the input's device.

    Args:
        dim: feature width ``C``.
    """

    def __init__(self, dim=64):
        super().__init__()
        self.qkv = nn.Linear(dim, dim * 3)

    def forward(self, x):
        if x.dim() != 2:
            raise ValueError(
                "intra_att expects a 2-D (N, C) input, got shape {}".format(tuple(x.shape)))
        # Split the channel dimension into q, k, v. The previous ``qkv[0], qkv[1], qkv[2]``
        # picked the first three *rows* (tokens) instead and then crashed in reshape.
        querys, keys, values = self.qkv(x).chunk(3, dim=-1)        # each (N, C)
        att = (querys.sum(dim=-1, keepdim=True) - querys) * keys.sum(dim=0)
        return values * att


class Block(nn.Module):
    """Pre-norm transformer encoder block.

    ``forward`` computes ``x + DropPath(Attention(norm1(x)))`` and then
    ``x + DropPath(Mlp(norm2(x)))``. One DropPath module (an Identity when
    ``drop_path > 0.`` is false) is shared by both residual branches.
    """

    def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0.,
                 attn_drop=0., drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm):
        super().__init__()
        self.norm1 = norm_layer(dim)
        self.attn = Attention(dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale,
                              attn_drop=attn_drop, proj_drop=drop)
        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
        self.norm2 = norm_layer(dim)
        hidden = int(dim * mlp_ratio)
        self.mlp = Mlp(in_features=dim, hidden_features=hidden, act_layer=act_layer, drop=drop)

    def forward(self, x):
        attn_branch = self.drop_path(self.attn(self.norm1(x)))
        x = x + attn_branch
        mlp_branch = self.drop_path(self.mlp(self.norm2(x)))
        return x + mlp_branch


def get_sinusoid_encoding(n_position, d_hid):
    """Build the fixed sinusoidal position table.

    Entry ``(pos, j)`` starts as ``pos / 10000 ** (2 * (j // 2) / d_hid)``. Even
    columns then take its sine and odd columns its cosine.

    Returns:
        FloatTensor of shape (1, n_position, d_hid).
    """
    # One denominator per hidden index; columns 2i and 2i+1 share a frequency.
    denominators = [np.power(10000, 2 * (j // 2) / d_hid) for j in range(d_hid)]
    table = np.array([[pos / denom for denom in denominators] for pos in range(n_position)])
    table[:, 0::2] = np.sin(table[:, 0::2])
    table[:, 1::2] = np.cos(table[:, 1::2])
    return torch.FloatTensor(table).unsqueeze(0)

In [4]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange

from timm.models.helpers import load_pretrained
from timm.models.registry import register_model
from timm.models.layers import trunc_normal_
import numpy as np


class MultiHeadDense(nn.Module):
    """Bias-free linear layer mapping the last dimension from ``in_ch`` to ``out_ch``.

    The weight is stored as (out_ch, in_ch), the layout ``F.linear`` expects. It is
    initialised like ``nn.Linear`` (Kaiming-uniform with ``a = sqrt(5)``). Previously
    it was an uninitialised ``torch.Tensor(in_ch, out_ch)``, which held arbitrary
    memory and mapped ``out_ch -> in_ch``.
    """

    def __init__(self, in_ch, out_ch):
        super(MultiHeadDense, self).__init__()
        self.weight = nn.Parameter(torch.empty(out_ch, in_ch))
        nn.init.kaiming_uniform_(self.weight, a=5 ** 0.5)

    def forward(self, x):
        return F.linear(x, self.weight)


class T2T_module(nn.Module):
    """Tokens-to-token (T2T) encoder front-end.

    Pipeline:
    Unfold 7x7/stride 2 -> attention1 -> square map -> roll (2, 2) ->
    Unfold 3x3 (dilation 2) -> attention2 -> square map -> roll (2, 2) ->
    Unfold 3x3 -> linear projection to ``embed_dim``.

    For a 64x64 input the token grids are 29x29, 25x25 and finally 23x23. That is
    where the hard-coded ``num_patches = 529`` comes from. ``img_size``, ``kernel``
    and ``stride`` are accepted for API compatibility but are not used.

    Args:
        tokens_type: 'transformer' (Token_transformer) or 'performer' (Token_performer).
        in_chans: number of input image channels.
        embed_dim: width of the output tokens.
        token_dim: width of the intermediate attention outputs.

    Raises:
        ValueError: if ``tokens_type`` is not supported. Previously the module was
            built without any layers and only failed later, in ``forward``.
    """

    def __init__(self, img_size=64, tokens_type='performer', in_chans=1, embed_dim=256, token_dim=64, kernel=32, stride=32):
        super().__init__()

        if tokens_type not in ('transformer', 'performer'):
            raise ValueError(
                "tokens_type must be 'transformer' or 'performer', got {!r}".format(tokens_type))
        if tokens_type == 'transformer':
            print('adopt transformer encoder for tokens-to-token')

        # The three soft splits are identical for both token types.
        self.soft_split0 = nn.Unfold(kernel_size=(7, 7), stride=(2, 2))
        self.soft_split1 = nn.Unfold(kernel_size=(3, 3), stride=(1, 1), dilation=(2, 2))
        self.soft_split2 = nn.Unfold(kernel_size=(3, 3), stride=(1, 1))

        if tokens_type == 'transformer':
            self.attention1 = Token_transformer(dim=in_chans * 7 * 7, in_dim=token_dim,
                                                num_heads=1, mlp_ratio=1.0)
            self.attention2 = Token_transformer(dim=token_dim * 3 * 3, in_dim=token_dim,
                                                num_heads=1, mlp_ratio=1.0)
        else:
            self.attention1 = Token_performer(dim=in_chans * 7 * 7, in_dim=token_dim,
                                              kernel_ratio=0.5)
            self.attention2 = Token_performer(dim=token_dim * 3 * 3, in_dim=token_dim,
                                              kernel_ratio=0.5)
        self.project = nn.Linear(token_dim * 3 * 3, embed_dim)

        # 23 * 23 tokens for a 64x64 input (see class docstring).
        self.num_patches = 529

    @staticmethod
    def _tokens_to_map(tokens):
        """Reshape a (B, HW, C) token sequence into a (B, C, H, W) square feature map."""
        B, new_HW, C = tokens.shape
        side = int(np.sqrt(new_HW))
        return tokens.transpose(1, 2).reshape(B, C, side, side)

    def forward(self, x):
        """Return ``(tokens, res_11, res_22)``.

        ``res_11`` and ``res_22`` are the outputs of attention1 and attention2. The
        decoder (Token_back_Image) receives them as skip inputs.
        """
        x = self.soft_split0(x)
        x = self.attention1(x.transpose(1, 2))
        res_11 = x
        x = self._tokens_to_map(x)
        x = torch.roll(x, shifts=(2, 2), dims=(2, 3))  # cyclic shift before the next split
        x = self.soft_split1(x)

        x = self.attention2(x.transpose(1, 2))
        res_22 = x
        x = self._tokens_to_map(x)
        x = torch.roll(x, shifts=(2, 2), dims=(2, 3))
        x = self.soft_split2(x)

        x = self.project(x.transpose(1, 2))
        return x, res_11, res_22


class Token_back_Image(nn.Module):
    """Decoder mirroring T2T_module: tokens back to a 64x64 image through ``nn.Fold``.

    ``forward(x, res_11, res_22)`` does the following:
    - Projects tokens to ``token_dim * 9`` features and folds to 25x25.
    - Undoes the (2, 2) roll, adds ``res_22`` and applies attention2.
    - Folds to 29x29, undoes the roll, adds ``res_11`` and applies attention1.
    - Finally folds to a 64x64 image with ``in_chans`` channels.

    The Fold output sizes are hard-coded for 64x64 images. ``kernel`` and ``stride``
    are accepted but unused.

    Raises:
        ValueError: if ``tokens_type`` is not 'transformer' or 'performer'. Previously
            the module was built without layers and failed later in ``forward``.
    """

    def __init__(self, img_size=64, tokens_type='performer', in_chans=1, embed_dim=256, token_dim=64, kernel=32, stride=32):
        super().__init__()

        if tokens_type not in ('transformer', 'performer'):
            raise ValueError(
                "tokens_type must be 'transformer' or 'performer', got {!r}".format(tokens_type))
        if tokens_type == 'transformer':
            print('adopt transformer encoder for tokens-to-token')

        # The folds are identical for both token types.
        self.soft_split0 = nn.Fold((64, 64), kernel_size=(7, 7), stride=(2, 2))
        self.soft_split1 = nn.Fold((29, 29), kernel_size=(3, 3), stride=(1, 1), dilation=(2, 2))
        self.soft_split2 = nn.Fold((25, 25), kernel_size=(3, 3), stride=(1, 1))

        if tokens_type == 'transformer':
            self.attention1 = Token_transformer(dim=token_dim, in_dim=in_chans*7*7, num_heads=1, mlp_ratio=1.0)
            self.attention2 = Token_transformer(dim=token_dim, in_dim=token_dim*3*3, num_heads=1, mlp_ratio=1.0)
        else:
            self.attention1 = Token_performer(dim=token_dim, in_dim=in_chans*7*7, kernel_ratio=0.5)
            self.attention2 = Token_performer(dim=token_dim, in_dim=token_dim*3*3, kernel_ratio=0.5)
        self.project = nn.Linear(embed_dim, token_dim * 3 * 3)

        self.num_patches = (img_size // (1 * 2 * 2)) * (img_size // (1 * 2 * 2))

    @staticmethod
    def _unroll_to_tokens(feature_map):
        """Undo the encoder's (2, 2) roll and flatten (B, C, H, W) -> (B, HW, C)."""
        feature_map = torch.roll(feature_map, shifts=(-2, -2), dims=(-1, -2))
        return rearrange(feature_map, 'b c h w -> b c (h w)').transpose(1, 2)

    def forward(self, x, res_11, res_22):
        x = self.project(x).transpose(1, 2)

        x = self._unroll_to_tokens(self.soft_split2(x)) + res_22
        x = self.attention2(x).transpose(1, 2)

        x = self._unroll_to_tokens(self.soft_split1(x)) + res_11
        x = self.attention1(x).transpose(1, 2)

        return self.soft_split0(x)


class MARformer(nn.Module):
    """Metal-artifact-reduction transformer with a CTformer-style encoder/decoder.

    ``forward(x)`` proceeds as follows:
    - Tokenises ``x`` with T2T_module and adds a fixed sinusoidal position embedding.
    - Runs ``depth`` transformer Blocks and a final norm.
    - Decodes tokens back to an image with Token_back_Image, using the two
      intermediate T2T outputs as skips.
    - Returns ``x - decoded``.

    NOTE(review): the default ``tokens_type='convolution'`` matches neither branch of
    T2T_module or Token_back_Image; pass 'performer' or 'transformer' explicitly.
    ``cls_token`` and ``head`` are created but not used by ``forward``.
    """

    def __init__(self, img_size=512, tokens_type='convolution', in_chans=1, num_classes=1000,
                 embed_dim=768, depth=12,
                 num_heads=12, kernel=32, stride=32, mlp_ratio=4., qkv_bias=False,
                 qk_scale=None, drop_rate=0.1, attn_drop_rate=0.1,
                 drop_path_rate=0.1, norm_layer=nn.LayerNorm, token_dim=1024):
        super().__init__()
        self.num_classes = num_classes
        self.num_features = self.embed_dim = embed_dim

        # Encoder front-end: image -> token sequence.
        self.tokens_to_token = T2T_module(
            img_size=img_size, tokens_type=tokens_type, in_chans=in_chans,
            embed_dim=embed_dim, token_dim=token_dim, kernel=kernel, stride=stride)
        num_patches = self.tokens_to_token.num_patches

        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
        # Fixed (non-trainable) sinusoidal table of shape (1, num_patches, embed_dim).
        pos_table = get_sinusoid_encoding(n_position=num_patches, d_hid=embed_dim)
        self.pos_embed = nn.Parameter(data=pos_table, requires_grad=False)
        self.pos_drop = nn.Dropout(p=drop_rate)

        # Stochastic-depth rate rises linearly from 0 to drop_path_rate across the blocks.
        drop_path_rates = torch.linspace(0, drop_path_rate, depth).tolist()
        self.blocks = nn.ModuleList(
            Block(dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias,
                  qk_scale=qk_scale, drop=drop_rate, attn_drop=attn_drop_rate,
                  drop_path=rate, norm_layer=norm_layer)
            for rate in drop_path_rates)
        self.norm = norm_layer(embed_dim)

        # CTformer decoder
        self.dconv1 = Token_back_Image(img_size=img_size, tokens_type=tokens_type, in_chans=in_chans,
                                       embed_dim=embed_dim, token_dim=token_dim, kernel=kernel, stride=stride)
        self.head = nn.Linear(embed_dim, num_classes) if num_classes > 0 else nn.Identity()

        trunc_normal_(self.cls_token, std=.02)
        self.apply(self._init_weights)

    def _init_weights(self, m):
        """Truncated-normal Linear weights with zero bias; LayerNorm reset to weight 1, bias 0."""
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=.02)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)

    @torch.jit.ignore
    def no_weight_decay(self):
        return {'cls_token'}

    def get_classifier(self):
        return self.head

    def reset_classifier(self, num_classes, global_pool=''):
        self.num_classes = num_classes
        self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity()

    def forward(self, x):
        tokens, skip_1, skip_2 = self.tokens_to_token(x)
        tokens = self.pos_drop(tokens + self.pos_embed)
        for blk in self.blocks:
            tokens = blk(tokens)
        tokens = self.norm(tokens)
        # The decoder output is subtracted from the original input.
        return x - self.dconv1(tokens, skip_1, skip_2)

In [5]:
import time
import os
import torch
from d2l import torch as d2l
from utils.plot_util import Animator
from tqdm import tqdm


def train(net, train_iter, val_iter, num_epochs, lr, device, loss, lr_period, lr_decay, pretrainModel=None):
    """Train ``net`` with Adam and StepLR, validating after every epoch.

    Args:
        net: model to train; it is moved to ``device``.
        train_iter: loader yielding ``(X, y)`` batches for training.
        val_iter: loader passed to ``validate`` after every epoch.
        num_epochs: number of epochs.
        lr: initial learning rate for Adam.
        device: torch device.
        loss: callable ``loss(y_hat, y)``; the optimised value is ``loss * 100 + 1e-4``.
        lr_period: StepLR ``step_size``.
        lr_decay: StepLR ``gamma``.
        pretrainModel: optional path to a state dict. When omitted, every ``nn.Linear``
            and ``nn.Conv2d`` weight is re-initialised with Xavier-uniform (this
            overrides any initialisation done in the model's constructor).

    Relies on ``validate``, ``evaluate``, ``Animator``, ``d2l`` and ``tqdm`` from the
    notebook namespace.
    """
    def init_weights(m):
        if type(m) in (nn.Linear, nn.Conv2d):
            nn.init.xavier_uniform_(m.weight)

    if pretrainModel is None:
        net.apply(init_weights)
    else:
        # NOTE: torch.load unpickles the file; only load trusted checkpoints.
        # map_location lets a GPU-saved checkpoint load on any machine before .to(device).
        net.load_state_dict(torch.load(pretrainModel, map_location='cpu'))

    print('training on', device)
    net.to(device)
    optimizer = torch.optim.Adam(net.parameters(), lr=lr)
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, lr_period, lr_decay)
    animator = Animator(xlabel='epoch', legend=['train loss', 'val loss', 'val ssim', 'val psnr'], figsize=(4.0, 3.0))
    num_batches = len(train_iter)
    # Plot about five points per epoch; max() avoids modulo-by-zero with fewer than 5 batches.
    plot_every = max(num_batches // 5, 1)

    for epoch in range(num_epochs):
        metric = d2l.Accumulator(2)  # (sum of loss * batch size, number of samples)
        net.train()
        train_l = 0.0
        start_time = time.time()

        with tqdm(train_iter,
                  desc="train epoch {}/{} ".format(epoch + 1, num_epochs),
                  postfix={'loss': train_l}) as tbar:
            for i, (X, y) in enumerate(tbar):
                optimizer.zero_grad()
                X, y = X.to(device), y.to(device)
                y_hat = net(X)

                l = loss(y_hat, y) * 100 + 1e-4
                l.backward()
                optimizer.step()

                with torch.no_grad():
                    metric.add(l * X.shape[0], X.shape[0])

                train_l = metric[0] / metric[1]
                if (i + 1) % plot_every == 0 or i == num_batches - 1:
                    animator.add(epoch + (i + 1) / num_batches, (train_l, None, None, None))

                # Iterating over ``tbar`` already advances the bar; an extra update()
                # double-counted progress.
                tbar.set_postfix({'loss': train_l})

        scheduler.step()

        val_l, ssim, psnr = validate(net, val_iter, device, loss, evaluate)

        animator.add(epoch + 1, (None, val_l, ssim * 10, psnr))

        print("epoch [{}/{}], train loss: {:.3f}, val loss: {:.3f}, val ssim: {:.4f}, val psnr: {:.2f}, time of this epoch: {:.1f}s"
                .format(epoch + 1, num_epochs, train_l, val_l, ssim, psnr, time.time() - start_time))

In [6]:
from skimage.metrics import structural_similarity as compare_ssim
from skimage.metrics import peak_signal_noise_ratio as compare_psnr

def evaluate(imgs1, imgs2, ssim_data_range=2.0, psnr_data_range=1.0):
    """Return the SUM (not the mean) of per-image SSIM and PSNR over a batch.

    Args:
        imgs1: tensor of shape (B, C, H, W); each image is compared with ``channel_axis=0``.
        imgs2: tensor with the same shape as ``imgs1``.
        ssim_data_range: dynamic range for SSIM. 2.0 matches the [-1, 1] output of
            ``DeepLesionDataset.normalize``. Older scikit-image releases inferred this
            value for float input; newer releases require it explicitly.
        psnr_data_range: dynamic range for PSNR. It stays at the previous 1.0 so
            reported numbers remain comparable. NOTE(review): for [-1, 1] data the
            true range is 2.0.

    Returns:
        ``(ssim_sum, psnr_sum)``; callers divide by the sample count.

    Fixes: a stray ``print(img1.shape)`` referenced an undefined name, and PSNR was
    computed on the whole batch once per image instead of on each image.
    """
    imgs1 = imgs1.detach().cpu().numpy().astype(np.float64)
    imgs2 = imgs2.detach().cpu().numpy().astype(np.float64)

    ssim = 0.0
    psnr = 0.0
    for img1, img2 in zip(imgs1, imgs2):
        ssim += compare_ssim(img1, img2, channel_axis=0, data_range=ssim_data_range)
        psnr += compare_psnr(img1, img2, data_range=psnr_data_range)

    return ssim, psnr


def validate(net, val_iter, device, loss, evaluate):
    """Evaluate ``net`` on ``val_iter`` with overlapping-patch inference.

    Each image is split into 64x64 patches (``split_arr``). The patches go through
    ``net`` in chunks of 64, and the centre crops are stitched back (``agg_arr``) to
    the input's width.

    Args:
        net: model to evaluate (switched to eval mode).
        val_iter: loader yielding ``(X, y)``.
        device: device used for inference.
        loss: callable ``loss(y_hat, y)``; reported as ``loss * 100 + 1e-4``.
        evaluate: callable returning per-batch ``(ssim_sum, psnr_sum)``.

    Returns:
        ``(val_loss, mean_ssim, mean_psnr)`` averaged per sample, or zeros when
        ``val_iter`` is empty (previously an UnboundLocalError).

    NOTE(review): ``split_arr`` reads only ``X[0, 0]`` and ``agg_arr`` emits a single
    image, so this assumes batch size 1 and square images; confirm against the loader.
    """
    metric = d2l.Accumulator(4)  # (loss sum, ssim sum, psnr sum, sample count)
    net.eval()

    val_l, ssim_avg, psnr_avg = 0.0, 0.0, 0.0
    with torch.no_grad():
        for X, y in val_iter:
            # Previously used an undefined lowercase ``x`` here.
            arrs = split_arr(X, 64).to(device)
            # Run all patches in chunks of 64 (the old code hard-coded exactly 4 chunks).
            for start in range(0, arrs.shape[0], 64):
                arrs[start:start + 64] = net(arrs[start:start + 64])
            # Stitch back to the input size instead of a hard-coded 256.
            y_hat = agg_arr(arrs, X.shape[-1]).to('cpu')

            l = loss(y_hat, y) * 100 + 1e-4
            ssim, psnr = evaluate(y_hat.cpu(), y.cpu())
            metric.add(l * y.shape[0], ssim, psnr, y.shape[0])

    if metric[3] > 0:
        val_l = metric[0] / metric[3]
        ssim_avg = metric[1] / metric[3]
        psnr_avg = metric[2] / metric[3]

    return val_l, ssim_avg, psnr_avg


def test(net, test_iter, device, model=None):
    """Print the mean SSIM and PSNR of ``net`` on ``test_iter`` using patch-wise inference.

    Args:
        net: model to evaluate.
        test_iter: loader yielding ``(X, y)``.
        device: device used for inference.
        model: optional path to a state dict loaded before testing.

    Images are split into 64x64 patches (``split_arr``), run through ``net`` in
    chunks of 64 and stitched back with ``agg_arr``. Metrics from ``evaluate`` are
    averaged per sample; the averages are 0 when ``test_iter`` is empty.

    NOTE(review): ``split_arr`` reads only ``X[0, 0]``, so batch size 1 is assumed.
    """
    if model is not None:
        # NOTE: torch.load unpickles the file; only load trusted checkpoints.
        net.load_state_dict(torch.load(model, map_location='cpu'))

    net.to(device)
    net.eval()

    print('testing on', device)

    metric = d2l.Accumulator(3)  # (ssim sum, psnr sum, sample count)
    start_time = time.time()

    with torch.no_grad():
        for X, y in test_iter:
            # Previously used an undefined lowercase ``x`` here.
            arrs = split_arr(X, 64).to(device)
            for start in range(0, arrs.shape[0], 64):
                arrs[start:start + 64] = net(arrs[start:start + 64])
            y_hat = agg_arr(arrs, X.shape[-1]).to('cpu')

            ssim, psnr = evaluate(y_hat.cpu(), y.cpu())
            metric.add(ssim, psnr, y.shape[0])

    count = metric[2]
    ssim_avg = metric[0] / count if count else 0.0
    psnr_avg = metric[1] / count if count else 0.0

    print("test ssim: {:.4f}, test psnr: {:.2f}, time of this test: {:.1f}s"
                .format(ssim_avg, psnr_avg, time.time() - start_time))

In [7]:
import torch
def data_augmentation(image, mode):
    """Apply one of eight rotate/flip transforms to a 2-D array.

    ``mode`` selects the transform:
    - 0 is the identity (the input object itself is returned).
    - 2, 4 and 6 rotate by 90, 180 and 270 degrees with ``np.rot90``.
    - Each odd mode applies the preceding even mode's rotation, then ``np.flipud``.
    - Any other ``mode`` returns ``image`` unchanged.
    """
    # Index = mode; value = (number of 90-degree rotations, flip up/down afterwards).
    transforms = ((0, False), (0, True),
                  (1, False), (1, True),
                  (2, False), (2, True),
                  (3, False), (3, True))
    for candidate, (rotations, flip) in enumerate(transforms):
        if mode == candidate:
            result = np.rot90(image, k=rotations) if rotations else image
            return np.flipud(result) if flip else result
    return image


def get_patch(full_input_img, full_target_img, patch_n, patch_size):
    """Sample aligned random square patches from a 2-D input/target image pair.

    For each of ``patch_n // 2`` random locations, two patches are appended: the
    plain crop, and a copy transformed by ``data_augmentation`` with a random mode
    in 1..7 (the same mode for input and target). The result is
    ``2 * (patch_n // 2)`` patches stacked into arrays of shape
    (N, patch_size, patch_size).

    If ``patch_size`` equals the image height, the full 2-D images are returned unchanged.
    """
    assert full_input_img.shape == full_target_img.shape
    h, w = full_input_img.shape
    if patch_size == h:
        return full_input_img, full_target_img

    patch_input_imgs = []
    patch_target_imgs = []
    for _ in range(patch_n // 2):
        # np.random.randint excludes ``high``: +1 makes the last valid offset reachable
        # and avoids a ValueError when one side equals patch_size.
        top = np.random.randint(0, h - patch_size + 1)
        left = np.random.randint(0, w - patch_size + 1)
        rows = slice(top, top + patch_size)
        cols = slice(left, left + patch_size)
        patch_input_img = full_input_img[rows, cols]
        patch_target_img = full_target_img[rows, cols]

        patch_input_imgs.append(patch_input_img)
        patch_target_imgs.append(patch_target_img)

        # Same random transform for input and target; mode 0 (identity) is excluded.
        mode = np.random.randint(1, 8)
        patch_input_imgs.append(data_augmentation(patch_input_img, mode))
        patch_target_imgs.append(data_augmentation(patch_target_img, mode))

    return np.array(patch_input_imgs), np.array(patch_target_imgs)


def split_arr(arr, patch_size, stride=32):
    """Cut the first image of a 4-D tensor into overlapping square patches.

    ``arr`` is zero-padded by 16 pixels on each spatial side. Patches of
    ``patch_size`` are then taken from ``[0, 0]`` on a ``stride`` grid of
    ``num x num`` positions, where ``num = padded_height // stride - 1``.
    Patches are ordered row-major and returned as a tensor of shape
    (num * num, 1, patch_size, patch_size) with torch's default dtype.

    NOTE(review): only batch element 0 / channel 0 is used.
    """
    padded = nn.functional.pad(arr, (16, 16, 16, 16), "constant", 0)
    _, _, height, _ = padded.shape
    num = height // stride - 1
    patches = torch.zeros(num * num, 1, patch_size, patch_size)

    origins = [(row * stride, col * stride) for row in range(num) for col in range(num)]
    for idx, (top, left) in enumerate(origins):
        patches[idx, 0] = padded[0, 0, top:top + patch_size, left:left + patch_size]
    return patches


def agg_arr(arrs, size, stride=32):
    """Stitch the centre crops of a 4-D patch tensor into one ``size`` x ``size`` image.

    With ``num = size // stride``, patch ``i * num + j`` contributes its
    ``[..., 16:48, 16:48]`` crop at grid cell (i, j). Returns shape (1, 1, size, size).
    """
    _n, _c, _h, _w = arrs.shape  # arrs must be 4-D
    canvas = torch.zeros(size, size)
    cells = size // stride

    grid = [(row, col) for row in range(cells) for col in range(cells)]
    for idx, (row, col) in enumerate(grid):
        rows = slice(row * stride, (row + 1) * stride)
        cols = slice(col * stride, (col + 1) * stride)
        canvas[rows, cols] = arrs[idx, :, 16:48, 16:48]
    return canvas[None, None]

In [8]:
from torch.utils.data import Dataset, DataLoader
from utils.file_util import read_dir
import os.path as path
import json
import scipy.io as sio
import numpy as np
from d2l import torch as d2l
from random import choice


class DeepLesionDataset(Dataset):
    """DeepLesion metal-artifact dataset built from ``gt.mat`` files and their sibling ``.mat`` files.

    ``pre_process`` scans ``<dataset_dir>/train`` (modes 'train' and 'validate') or
    ``<dataset_dir>/test`` for ``gt.mat`` files. Every other ``*.mat`` file in the same
    directory is treated as an artifact image for that ground truth. The mapping is
    cached as JSON in ``dataset_dir`` and reused on later runs.

    Args:
        mode: 'train', 'validate' or 'test'. 'train' and 'validate' share the train
            split, divided by ``partial_holdout``.
        dataset_dir: root directory. If it is not an existing directory (given as a
            ``str``), the dataset is empty.
        partial_holdout: fraction of the sorted train entries used for validation;
            a falsy value keeps everything in train.
        hu_offset: offset subtracted from ground-truth values in ``convert2coefficient``.

    Raises:
        ValueError: for an unknown ``mode``. Previously ``__len__`` returned None and
            ``__getitem__`` raised UnboundLocalError.
    """

    _MODES = ('train', 'validate', 'test')

    def __init__(self, mode, dataset_dir, partial_holdout=0.2, hu_offset=32768):
        if mode not in self._MODES:
            raise ValueError("mode must be one of {}, got {!r}".format(self._MODES, mode))
        self.mode = mode
        self.partial_holdout = partial_holdout
        self.hu_offset = hu_offset
        self.train_list = []
        self.validate_list = []
        self.test_list = []

        self.pre_process(dataset_dir)

    def pre_process(self, dataset_dir):
        """Fill the split lists from the JSON cache, building the cache when it is missing."""
        data_dict = {}
        if self.mode in ('train', 'validate'):
            cache_name = 'train_cache.json'
            dest_dir = dataset_dir + '/train'
        else:  # 'test'
            cache_name = 'test_cache.json'
            dest_dir = dataset_dir + '/test'

        if type(dataset_dir) is str and path.isdir(dataset_dir):
            cache_file = path.join(dataset_dir, cache_name)
            if path.isfile(cache_file):
                with open(cache_file) as f:
                    data_dict = json.load(f)
            else:
                gt_files = read_dir(dest_dir, predicate=lambda x: x == "gt.mat", recursive=True)
                for gt_file in gt_files:
                    metal_dir = path.split(gt_file)[0]
                    data_dict[gt_file] = sorted(read_dir(
                        metal_dir, predicate=lambda x: x.endswith("mat") and x != "gt.mat"))

                with open(cache_file, 'w') as f:
                    json.dump(data_dict, f)

        # Sorted (gt_path, [artifact_paths]) pairs give a deterministic train/validate split.
        entries = sorted(data_dict.items())

        if self.mode in ('train', 'validate'):
            if self.partial_holdout:
                train_size = int(len(entries) * (1 - self.partial_holdout))
                self.train_list = entries[:train_size]
                self.validate_list = entries[train_size:]
            else:
                self.train_list = entries
                self.validate_list = []
        elif self.mode == 'test':
            self.test_list = entries

    def convert2coefficient(self, image, MIUWATER=0.192):
        """Map offset HU values to attenuation: ``MIUWATER * (1 + HU / 1000)``, HU clipped at -1000."""
        image = np.array(image, dtype=np.float32)
        image = image - self.hu_offset
        image[image < -1000] = -1000
        image = image / 1000 * MIUWATER + MIUWATER
        return image

    def load_data(self, data_file):
        """Load ``(artifact_path, gt_path)``; only the ground truth is converted and transposed.

        NOTE(review): the artifact image is used exactly as stored. Confirm it is already
        in attenuation units and in the same orientation as the transposed ground truth.
        """
        with_art = sio.loadmat(data_file[0])['image']
        gt = sio.loadmat(data_file[1])['image']
        gt = self.convert2coefficient(gt).T
        return with_art, gt

    def normalize(self, data, min=0.0, max=0.5):
        """Clip to [min, max] and map linearly to [-1, 1]."""
        data = np.clip(data, min, max)
        data = (data - min) / (max - min)
        data = data * 2.0 - 1.0
        return data

    def denormalize(self, data, min=0.0, max=0.5):
        """Inverse of ``normalize`` (clipped values are not restored)."""
        data = data * 0.5 + 0.5
        data = data * (max - min) + min
        return data

    def _split_list(self):
        """Return the entry list for the current mode."""
        if self.mode == 'train':
            return self.train_list
        if self.mode == 'validate':
            return self.validate_list
        if self.mode == 'test':
            return self.test_list
        raise ValueError("unknown mode {!r}".format(self.mode))

    def __len__(self):
        return len(self._split_list())

    def __getitem__(self, idx):
        gt_file, art_files = self._split_list()[idx]

        # One artifact variant is drawn at random on every access (in every mode).
        data_file = choice(art_files), gt_file
        art_img, gt = self.load_data(data_file)
        art_img, gt = self.normalize(art_img), self.normalize(gt)

        if self.mode == 'train':
            art_img, gt = get_patch(art_img, gt, patch_n=4, patch_size=64)

        return torch.tensor(art_img), torch.tensor(gt)

    def get_dataset(self, batch_size=8, shuffle=True, num_workers=10):
        return DataLoader(dataset=self, batch_size=batch_size, shuffle=shuffle, num_workers=num_workers)