## Code for "RSTAR4D: Rotational Streak Artifact Reduction in 4D CBCT Using Separable 4D Convolutions", published in IEEE TRPMS

Authors: Ziheng Deng, Jun Zhao, SJTU

The RSTAR4D-Net is a 4D CNN with separable 4D Convolutions. To effectively train the model with limited 4D data, we propose the Tetris training strategy. 

This notebook implements Tetris training stage 1. We first train the model on 2D+T data. The input and output data size is 1\*2\*10\*1\*256\*256 (batch, channel, phase (temporal), slice (z-axis, SI), width (y-axis, LR), height (x-axis, AP)); please check the example_data. In this stage, the z-axis (SI) convolution is frozen.

In [None]:
import time
import math
import torch
import numpy as np
from torch import nn, optim, autograd
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset
from torchvision import datasets, models, transforms
from torch.optim.optimizer import Optimizer, required
from torch.autograd import Variable
from torch.nn import Parameter
from prefetch_generator import BackgroundGenerator
import os
from PIL import Image
import h5py
import matplotlib.pyplot as plt
import torchvision
from math import exp
from pytorch_msssim import ms_ssim,ssim
from torch.cuda.amp import GradScaler, autocast

In [None]:
# Global training configuration used by the cells below.
BATCH_SIZE = 1
EPOCHS = 50

# Use the first CUDA device when one is available, otherwise fall back to CPU.
if torch.cuda.is_available():
    DEVICE = torch.device("cuda:0")
else:
    DEVICE = torch.device("cpu")

# Enable the cuDNN benchmark mode (auto-tuning of convolution algorithms).
torch.backends.cudnn.benchmark = True

In [None]:
class simudata(Dataset):
    """Training dataset: one HDF5 file per sample, read from a directory.

    Every entry of ``simu_dir`` (sorted by name) is treated as an HDF5 file
    holding the datasets ``'cbct'``, ``'ct'`` and ``'prior'``.

    Args:
        simu_dir: Directory that contains the sample files.
        transform: Stored on the instance for API compatibility; it is not
            applied anywhere in this class.
    """

    def __init__(self,simu_dir,transform=None):
        self.simu_dir = simu_dir
        self.transform = transform
        # Sorted so that index -> file mapping is reproducible across runs.
        self.simu = sorted(os.listdir(self.simu_dir))

    def __len__(self):
        """Return the number of directory entries (= number of samples)."""
        return len(self.simu)

    def __getitem__(self,index):
        """Load sample ``index`` and return ``(cbct, ct, prior)`` as float32.

        All three volumes are divided by 1000. ``cbct`` and ``ct`` get a
        singleton axis inserted at position 1 and then at position 0;
        ``prior`` gets two leading singleton axes.
        NOTE(review): presumably cbct/ct are stored as (phase, width, height)
        and prior as (width, height) -- confirm against example_data.

        Raises:
            KeyError: if one of the expected datasets is missing in the file.
        """
        simu_path = os.path.join(self.simu_dir,self.simu[index])

        # Open the file once; the context manager closes it, so no explicit
        # close() is needed. Reading straight into float32 avoids the
        # intermediate float64 copy.
        with h5py.File(simu_path,'r') as f:
            cbct, ct, prior = [
                torch.tensor(np.array(f[key]), dtype=torch.float32)
                for key in ('cbct', 'ct', 'prior')
            ]

        cbct = cbct.unsqueeze(1).unsqueeze(0)/1000
        ct = ct.unsqueeze(1).unsqueeze(0)/1000
        prior = prior.unsqueeze(0).unsqueeze(0)/1000

        return cbct,ct,prior

In [None]:
class evaldata(Dataset):
    """Evaluation dataset: one HDF5 file per sample, read from a directory.

    Every entry of ``simu_dir`` (sorted by name) is treated as an HDF5 file
    holding the datasets ``'cbct'``, ``'ct'`` and ``'prior'``.

    Args:
        simu_dir: Directory that contains the sample files.
        transform: Stored on the instance for API compatibility; it is not
            applied anywhere in this class.
    """

    def __init__(self,simu_dir,transform=None):
        self.simu_dir = simu_dir
        self.transform = transform
        # Sorted so that index -> file mapping is reproducible across runs.
        self.simu = sorted(os.listdir(self.simu_dir))

    def __len__(self):
        """Return the number of directory entries (= number of samples)."""
        return len(self.simu)

    def __getitem__(self,index):
        """Load sample ``index`` and return ``(cbct4, ct, prior)`` as float32.

        All three volumes are divided by 1000. ``cbct4`` and ``ct`` get a
        singleton axis inserted at position 1 and then at position 0;
        ``prior`` gets two leading singleton axes.
        NOTE(review): presumably cbct/ct are stored as (phase, width, height)
        and prior as (width, height) -- confirm against example_data.

        Raises:
            KeyError: if one of the expected datasets is missing in the file.
        """
        simu_path = os.path.join(self.simu_dir,self.simu[index])

        # Open the file once; the context manager closes it, so no explicit
        # close() is needed. Reading straight into float32 avoids the
        # intermediate float64 copy.
        with h5py.File(simu_path,'r') as f:
            cbct4, ct, prior = [
                torch.tensor(np.array(f[key]), dtype=torch.float32)
                for key in ('cbct', 'ct', 'prior')
            ]

        cbct4 = cbct4.unsqueeze(1).unsqueeze(0)/1000
        ct = ct.unsqueeze(1).unsqueeze(0)/1000
        prior = prior.unsqueeze(0).unsqueeze(0)/1000

        return cbct4,ct,prior

In [None]:
# Stage-1 example data: one dataset for training, one for evaluation.
simudataset = simudata(simu_dir='example_data/train_data_stage1')
evaldataset = evaldata(simu_dir='example_data/eval_data_stage1')

In [None]:
class DataLoaderX(DataLoader):
    """DataLoader whose iterator is wrapped in a ``BackgroundGenerator``.

    ``BackgroundGenerator`` comes from the ``prefetch_generator`` package and
    is used here so that batches are fetched ahead of the consumer.
    """

    def __iter__(self):
        base_iterator = super().__iter__()
        return BackgroundGenerator(base_iterator)

In [None]:
# Training batches are shuffled every epoch.
simuloader = DataLoaderX(simudataset,batch_size=BATCH_SIZE,shuffle=True,num_workers=0)
# Evaluation only aggregates metrics over the whole set, so shuffling adds
# nothing; a fixed order keeps evaluation runs deterministic.
evalloader = DataLoaderX(evaldataset,batch_size=BATCH_SIZE,shuffle=False,num_workers=0)

In [None]:
class convwithactivation(nn.Module):
    """In-plane (Y, X) convolution + LeakyReLU on a 6D tensor (B, C, T, Z, Y, X).

    Used as the spatial downsampling layer: the Y/X stride of ``conv1``
    controls the in-plane downsampling factor.

    ``conv2`` (convolution along Z) is constructed, so its parameters are part
    of the module's ``state_dict``, but it is NOT applied in ``forward``: in
    Tetris stage 1 the z-axis convolution is frozen.

    Args:
        in_ch: Number of input channels.
        out_ch: Number of output channels.
        kernel_size: (kz, ky, kx) kernel sizes.
        padding: (pz, py, px) paddings.
        stride: (sz, sy, sx) strides. ``sz`` only configures the unused
            ``conv2``.
        padding_mode: Padding mode of the in-plane convolution.
    """

    # Tuple defaults instead of lists: avoids shared mutable default arguments.
    def __init__(self,in_ch,out_ch,kernel_size=(3,3,3),padding=(1,1,1),stride=(1,1,1),padding_mode='zeros'):
        super(convwithactivation,self).__init__()
        # XY convolution / downsampling, applied on a (B, C, T*Z, Y, X) view.
        self.conv1=nn.Conv3d(in_ch,out_ch,[1,kernel_size[1],kernel_size[2]],[1,stride[1],stride[2]],[0,padding[1],padding[2]],padding_mode=padding_mode)
        # Z convolution / downsampling (kept for its parameters, unused in forward).
        self.conv2=nn.Conv3d(out_ch,out_ch,[1,kernel_size[0],1],[1,stride[0],1],[0,padding[0],0],padding_mode='replicate')
        self.lrelu=nn.LeakyReLU(0.2, inplace=True)
        self.stride=stride
        self.out_ch=out_ch

    def forward(self,x):
        """Apply the in-plane conv to every (T, Z) slice.

        Args:
            x: Tensor of shape (B, C, T, Z, Y, X).

        Returns:
            Tensor of shape (B, out_ch, T, Z, Y', X'), where Y' and X' are the
            sizes actually produced by ``conv1``.
        """
        B,C,T,Z,Y,X = x.shape
        # Fold T and Z together so Conv3d with a depth-1 kernel acts per slice.
        x = x.reshape(B,C,T*Z,Y,X)
        x = self.conv1(x)
        x = self.lrelu(x)
        # Read Y'/X' from the conv output instead of assuming Y//stride:
        # the two differ for odd sizes (and Z is untouched because conv2 is
        # not applied), which previously made the view raise.
        return x.view(B,self.out_ch,T,Z,x.shape[-2],x.shape[-1])

In [None]:
class convwithactivation4DRSTAR(nn.Module):
    """Separable 4D convolution block operating on (B, C, T, Z, Y, X) tensors.

    Two stages are applied back to back. Each stage sums an in-plane (Y, X)
    convolution and a temporal (T) convolution of the same input. The Z-axis
    convolutions ``conv2`` / ``conv2_2`` are constructed (so their parameters
    exist in the state dict) but are not added in ``forward``.

    ``kernel_size``, ``padding`` and ``stride`` are indexed as (T, Z, Y, X).
    When ``ifrelu == 1`` a LeakyReLU(0.2) is applied to the result.
    """

    def __init__(self,in_ch,out_ch,kernel_size=[3,3,3,3],padding=[1,1,1,1],stride=[1,1,1,1],ifrelu=1):
        super().__init__()
        c_in = int(in_ch)
        c_out = int(out_ch)

        def plane_conv(channels):
            # XY conv, used on the (B, C, T*Z, Y, X) view
            return nn.Conv3d(
                channels, c_out,
                kernel_size=[1, kernel_size[2], kernel_size[3]],
                stride=[1, stride[2], stride[3]],
                padding=[0, padding[2], padding[3]],
            )

        def slice_conv(channels):
            # Z conv, used on the (B, C, T, Z, Y*X) view
            return nn.Conv3d(
                channels, c_out,
                kernel_size=[1, kernel_size[1], 1],
                stride=[1, stride[1], 1],
                padding=[0, padding[1], 0],
                padding_mode='replicate',
            )

        def phase_conv(channels):
            # T conv, used on the (B, C, T, Z, Y*X) view; circular padding along T
            return nn.Conv3d(
                channels, c_out,
                kernel_size=[kernel_size[0], 1, 1],
                stride=[stride[0], 1, 1],
                padding=[padding[0], 0, 0],
                padding_mode='circular',
            )

        # First stage (in_ch -> out_ch). Construction order is kept fixed so
        # that parameter initialisation and state-dict ordering do not change.
        self.conv1 = plane_conv(c_in)
        self.conv2 = slice_conv(c_in)
        self.conv3 = phase_conv(c_in)

        # Second stage (out_ch -> out_ch).
        self.conv1_2 = plane_conv(c_out)
        self.conv2_2 = slice_conv(c_out)
        self.conv3_2 = phase_conv(c_out)

        self.lrelu1 = nn.LeakyReLU(0.2, inplace=True)
        self.in_ch = in_ch
        self.out_ch = out_ch
        self.ifrelu = ifrelu

    def forward(self,x):
        B, C, T, Z, Y, X = x.shape
        width = self.out_ch

        # Stage 1: in-plane branch + temporal branch (Z branch not added).
        in_plane = self.conv1(x.view(B, C, T*Z, Y, X)).view(B, width, T, Z, -1)
        temporal = self.conv3(x.view(B, C, T, Z, -1))
        feat = in_plane + temporal

        # Stage 2: same structure on the stage-1 features.
        in_plane = self.conv1_2(feat.view(B, width, T*Z, Y, X)).view(B, width, T, Z, -1)
        temporal = self.conv3_2(feat)
        feat = in_plane + temporal

        if self.ifrelu == 1:
            feat = self.lrelu1(feat)
        return feat.view(B, width, T, Z, Y, X)

In [None]:
class upconvwithactivation(nn.Module):
    """Spatial upsampling: nearest-neighbour resize followed by a 4D conv block.

    ``scale_factor`` is indexed as (T, Z, Y, X); only the Z, Y and X entries
    are used by ``forward``. The remaining arguments are forwarded to
    ``convwithactivation4DRSTAR``.
    """

    def __init__(self,in_ch,out_ch,kernel_size=[3,3,3,3],padding=[1,1,1,1],stride=[1,1,1,1],scale_factor=[1,2,2,2],ifrelu=1):
        super().__init__()
        self.scale_factor = scale_factor
        self.conv = convwithactivation4DRSTAR(in_ch, out_ch, kernel_size, padding, stride, ifrelu)

    def forward(self,x):
        B, C, T, Z, Y, X = x.shape
        factors = self.scale_factor

        # Fold T into the channel axis so the 3D interpolation leaves it alone.
        folded = x.view(B, C*T, Z, Y, X)
        resized = F.interpolate(folded, scale_factor=factors[1:], mode='nearest')

        # Unfold back to 6D with the enlarged spatial extent.
        unfolded = resized.view(B, C, T, Z*factors[1], Y*factors[2], X*factors[3])
        return self.conv(unfolded)

In [None]:
class myUNet(nn.Module):
    """U-Net style encoder/decoder built from the 4D convolution blocks.

    Encoder: four levels with 16/32/64/128 channels; levels 2-4 halve Y and X
    with a stride-2 in-plane conv. Decoder: three upsampling levels, each
    concatenating the matching encoder features, then a final 1-channel
    output layer without activation. Every level adds a residual connection
    around its pair of separable 4D conv blocks.

    The input is expected to have 2 channels and layout (B, 2, T, Z, Y, X).
    """

    def __init__(self):
        super().__init__()
        cnum = 16

        def down(cin, cout, stride):
            # In-plane conv (+ optional Y/X stride) used at the start of a level.
            return convwithactivation(cin, cout, kernel_size=[3,3,3], padding=[1,1,1],
                                      stride=stride, padding_mode='zeros')

        def rstar(cin, cout):
            # Separable 4D conv block with 3x3x3x3 kernels and unit stride.
            return convwithactivation4DRSTAR(cin, cout, kernel_size=[3,3,3,3],
                                             padding=[1,1,1,1], stride=[1,1,1,1])

        def up(cin, cout, scale_factor, ifrelu=1):
            # Nearest-neighbour upsampling followed by a separable 4D conv block.
            return upconvwithactivation(cin, cout, kernel_size=[3,3,3,3], stride=[1,1,1,1],
                                        padding=[1,1,1,1], scale_factor=scale_factor,
                                        ifrelu=ifrelu)

        # NOTE: layers are registered in the same order as before so that
        # parameter initialisation and state-dict ordering are unchanged.

        # ---- encoder ----
        self.conv1 = down(2, cnum, [1,1,1])
        self.conv1_2 = rstar(cnum, cnum)
        self.conv1_3 = rstar(cnum, cnum)

        self.conv2 = down(cnum, 2*cnum, [1,2,2])
        self.conv2_2 = rstar(2*cnum, 2*cnum)
        self.conv2_3 = rstar(2*cnum, 2*cnum)

        self.conv3 = down(2*cnum, 4*cnum, [1,2,2])
        self.conv3_2 = rstar(4*cnum, 4*cnum)
        self.conv3_3 = rstar(4*cnum, 4*cnum)

        self.conv4 = down(4*cnum, 8*cnum, [1,2,2])
        self.conv4_2 = rstar(8*cnum, 8*cnum)
        self.conv4_3 = rstar(8*cnum, 8*cnum)

        # ---- decoder ----
        self.conv5 = up(8*cnum, 4*cnum, [1,1,2,2])
        self.conv5_2 = rstar(8*cnum, 4*cnum)
        self.conv5_3 = rstar(4*cnum, 4*cnum)

        self.conv6 = up(4*cnum, 2*cnum, [1,1,2,2])
        self.conv6_2 = rstar(4*cnum, 2*cnum)
        self.conv6_3 = rstar(2*cnum, 2*cnum)

        self.conv7 = up(2*cnum, 1*cnum, [1,1,2,2])
        self.conv7_2 = rstar(2*cnum, 1*cnum)
        self.conv7_3 = rstar(1*cnum, 1*cnum)

        # ---- output: single channel, no upsampling, no activation ----
        self.conv8 = up(1*cnum, 1, [1,1,1,1], ifrelu=0)

    def forward(self,x):
        # Encoder path; each level keeps its features for the skip connection.
        enc1 = self.conv1(x)
        enc1 = self.conv1_3(self.conv1_2(enc1)) + enc1

        enc2 = self.conv2(enc1)
        enc2 = self.conv2_3(self.conv2_2(enc2)) + enc2

        enc3 = self.conv3(enc2)
        enc3 = self.conv3_3(self.conv3_2(enc3)) + enc3

        bottom = self.conv4(enc3)
        bottom = self.conv4_3(self.conv4_2(bottom)) + bottom

        # Decoder path: upsample, fuse with the skip features, add residual.
        dec = self.conv5(bottom)
        dec = self.conv5_3(self.conv5_2(torch.cat([dec, enc3], dim=1))) + dec

        dec = self.conv6(dec)
        dec = self.conv6_3(self.conv6_2(torch.cat([dec, enc2], dim=1))) + dec

        dec = self.conv7(dec)
        dec = self.conv7_3(self.conv7_2(torch.cat([dec, enc1], dim=1))) + dec

        return self.conv8(dec)

In [None]:
# Build the network and move it to the selected device.
model = myUNet()
model = model.to(DEVICE)

# Parameters handed to the optimizer further below.
params = list(model.parameters())

# L1 reconstruction loss and gradient scaler for mixed-precision training.
loss_fn = nn.L1Loss()
scaler = GradScaler()

In [None]:
def train1(model,device,trainloader,optimizer,epoch):
    """Run one training epoch with mixed precision.

    For every batch the network input is the CBCT volume concatenated (along
    the channel axis) with the prior expanded to the CBCT shape. The loss is
    ``L1 + 0.1 * (1 - SSIM)``; the backward pass and optimizer step go through
    the module-level ``scaler``. The module-level ``loss_fn`` is used for the
    L1 term. Per-batch losses and the time elapsed since the start of the
    epoch are printed after every batch.
    """
    epoch_start = time.time()
    running_l1 = 0
    running_ssim = 0
    log_every = 1  # print (and reset the running sums) after every batch

    model.train()
    for step, (cbct, ct, prior) in enumerate(trainloader):
        ct = ct.to(device)
        cbct = cbct.to(device)
        prior = prior.to(device)

        with autocast():
            prior_expanded = prior.unsqueeze(2).expand_as(cbct)
            net_input = torch.cat([cbct, prior_expanded], dim=1)
            cbct_refine = model(net_input)
            loss_sim = loss_fn(cbct_refine.squeeze(1), ct.squeeze(1))
        # SSIM is evaluated outside autocast on a float32 copy of the output.
        loss_ssim = ssim(cbct_refine.float()*1000, ct*1000)
        with autocast():
            loss_G = loss_sim + (1 - loss_ssim)*0.1

        optimizer.zero_grad()
        scaler.scale(loss_G).backward()
        scaler.step(optimizer)
        scaler.update()

        running_l1 += loss_sim.item()
        running_ssim += loss_ssim.item()

        if (step + 1) % log_every == 0:
            elapsed = time.time() - epoch_start
            print('Train Epoch: %d, loss_sim %.4f, loss_ssim %.4f, time %.1f sec' % (epoch, running_l1, running_ssim, elapsed))
            running_l1 = 0
            running_ssim = 0

In [None]:
def eval1(model,device,trainloader,epoch):
    """Evaluate the model on a loader and print L1 and SSIM.

    Args:
        model: Network to evaluate; it is switched to eval mode.
        device: Device the batches are moved to.
        trainloader: Iterable yielding ``(cbct, ct, prior)`` batches (the
            evaluation loader, despite the parameter name).
        epoch: Epoch number, used only in the printed message.

    The L1 loss uses the module-level ``loss_fn``. Gradients are disabled.
    If the loader yields no batches, a notice is printed instead of failing
    on an undefined batch counter.
    """
    model.eval()
    loss_sim_sum = 0
    loss_ssim_sum = 0
    num_batches = 0
    with torch.no_grad():
        for cbct,ct,prior in trainloader:
            cbct = cbct.to(device)
            ct = ct.to(device)
            prior = prior.to(device)
            # Same input construction as in training: CBCT + expanded prior.
            cbct_refine = model(torch.cat([cbct,prior.unsqueeze(2).expand_as(cbct)],dim=1))

            loss_sim = loss_fn(cbct_refine,ct)
            loss_ssim = ssim(cbct_refine.squeeze(1)*1000,ct.squeeze(1)*1000)
            loss_sim_sum += loss_sim.item()
            loss_ssim_sum += loss_ssim.item()
            num_batches += 1

    if num_batches == 0:
        print('eval Epoch: %d, no evaluation batches were provided' % epoch)
        return

    # NOTE(review): loss_sim is reported as a SUM over batches while loss_ssim
    # is a MEAN; kept as-is so logs stay comparable -- confirm this is intended.
    print('eval Epoch: %d, loss_sim %.4f, loss_ssim %.4f' % (epoch,loss_sim_sum,loss_ssim_sum/num_batches))

In [None]:
# Stage-1 training driver: Adam with a step-wise halving learning rate.
lrate=0.00015
optimizer=torch.optim.Adam(params,lr=lrate,betas=(0.9, 0.999))
for epoch in range(1,61):
    train1(model,DEVICE,simuloader,optimizer,epoch)
    #torch.save(model.state_dict(),'checkpoints/checkpoints_Tetris1/model_%s.pth'%(epoch))
    # Halve the learning rate after epochs 20, 30 and 40.
    if epoch in (20, 30, 40):
        lrate=lrate*0.5
        # A dedicated loop name avoids rebinding the module-level `params`
        # list (the model parameters) to a param-group dict.
        for param_group in optimizer.param_groups:
            param_group['lr'] = lrate
    eval1(model,DEVICE,evalloader,epoch)
    # Release cached GPU memory between epochs.
    torch.cuda.empty_cache()