In [1]:
import numpy as np

In [2]:
import torch
import matplotlib
import torch.nn as nn
import random
import cv2

In [3]:
def positive_pair(data):
    [batch_size,q,h,w]=data.shape
    ret=np.array(data)
    NOISE_R=0.7
    FLIP_R=0.6
    r1=random.random()
    #add random noise
    if r1<NOISE_R:
        ret=ret+np.random.randn(batch_size,q,h,w)*12
    #flip image randomly on all dimensions
    r2=random.random()
    if r2<FLIP_R:
        ret=ret[:,::-1,:,:]
    r3=random.random()
    if r3<FLIP_R:
        ret=ret[:,:,::-1,:]
    if random.random()<0.6 or (r1>NOISE_R and r2>FLIP_R and r1>FLIP_R):
        ret=ret[:,:,:,::-1]
    return torch.as_tensor(ret.copy())

In [4]:
# Not used
def data_padding(data):
    [batch_size,nx,ny,nz]=data.shape
    maxr=max(nx,ny,nz)
    data2=np.zeros([batch_size,maxr,maxr,maxr])
    for i in range(batch_size):
        for x in range(nx):
            for y in range(ny):
                for z in range(nz):
                    data2[i,x+(maxr-nx)//2,y+(maxr-ny)//2,z+(maxr-nz)//2]=data[i,x,y,z]
    return data2.copy()
    

In [5]:
# Different shape on different slices
class cnn_multi_dim(nn.Module):
    def __init__(self,dim=0,output_dim=10):
        super(cnn_multi_dim,self).__init__()
        if dim==1:
            self.conv=nn.Sequential(
                nn.Conv2d(1,8,5,1,0),
                nn.ReLU(),
                nn.MaxPool2d(2),
                nn.Conv2d(8,32,3,1,0),
                nn.ReLU(),
                nn.MaxPool2d(2),
                nn.Conv2d(32,32,3,1,0),
                nn.ReLU(),
                nn.MaxPool2d(2),
            )
            self.out=nn.Linear(2592,output_dim)
        else:
            self.conv=nn.Sequential(
                nn.Conv2d(1,8,5,1,0),
                nn.ReLU(),
                nn.MaxPool2d(2),
                nn.Conv2d(8,32,3,1,0),
                nn.ReLU(),
                nn.MaxPool2d(2),
                nn.Conv2d(32,32,3,1,0),
                nn.ReLU(),
                nn.MaxPool2d(2),
            )
            self.out=nn.Linear(3168,output_dim)    
        self.output_dim=output_dim
            
    def forward(self,x):
            y=x.view([x.shape[1]*x.shape[0],1,x.shape[2],x.shape[3]])
            y=self.conv(y)
            y=y.view(y.size(0),-1)
            y=self.out(y)
            y=y.view(1,y.shape[0],y.shape[1])
            y=nn.AvgPool2d(kernel_size=[x.shape[1],1],stride=[x.shape[1],1])(y)
            y=y.view(y.shape[1],y.shape[2])
            return y

In [6]:
def train(loader,ep,lrate,alpha):
    cnn=[cnn_multi_dim(0),cnn_multi_dim(1),cnn_multi_dim(2)]
    optimizer=torch.optim.Adam([{"params":cnn[0].parameters()},{"params":cnn[1].parameters()},{"params":cnn[2].parameters()}],lrate)
    for epoch in range(ep):
        losses=[]
        for _,d in enumerate(loader):
            if (_>0):
                positive=positive_pair(d)
                # Generate slices
                tran_d=[d,d.permute(0,2,1,3),d.permute(0,3,1,2)]
                tran_neg=[negative,negative.permute(0,2,1,3),negative.permute(0,3,1,2)]
                tran_pos=[positive,positive.permute(0,2,1,3),positive.permute(0,3,1,2)]
                pred_d=[_,_,_]
                pred_pos=[_,_,_]
                pred_neg=[_,_,_]
                loss=0
                for dim in range(3):
                    pred_d[dim]=cnn[dim](tran_d[dim].float())
                    pred_pos[dim]=cnn[dim](tran_pos[dim].float())
                    pred_neg[dim]=cnn[dim](tran_neg[dim].float())
                    d1=pred_d[dim]-pred_pos[dim]
                    d2=pred_d[dim]-pred_neg[dim]
                    d1=torch.norm(d1)
                    d2=torch.norm(d2)
                    loss=loss+nn.ReLU()(d1-d2+alpha)

                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
                #print(epoch,_,loss.item())
                losses.append(loss.item())
            negative=d
        print('epoch:',epoch,'loss:',np.array(losses).mean())
        state={0:cnn[0].state_dict(),1:cnn[1].state_dict(),2:cnn[0].state_dict()}
        torch.save(state,'checkpointat{}.pth'.format(epoch))
    return cnn

In [25]:
def output(nets,images):
    ret=np.zeros((images.shape[0],3,nets[0].output_dim))
    loader=torch.utils.data.DataLoader(dataset=images,batch_size=images.shape[0],shuffle=False)
    with torch.no_grad():
        for _,d in enumerate(loader):
            # Generate slices
            tran_d=[d,d.permute(0,2,1,3),d.permute(0,3,1,2)]
            pred_d=[_,_,_]
            for dim in range(3):
                nets[dim].eval()
                pred_d[dim]=nets[dim](tran_d[dim].float())
                ret[:,dim,:]=pred_d[dim].numpy()
    return ret
    

In [8]:
#182 218 182
data=np.random.rand(12,91,109,91)*200
batch_size=3
dataloader=torch.utils.data.DataLoader(dataset=data,batch_size=batch_size,shuffle=True)
train(dataloader,2,1e-3,10)


epoch: 0 loss: 29.903202692667644
epoch: 1 loss: 28.640853881835938


[cnn_multi_dim(
   (conv): Sequential(
     (0): Conv2d(1, 8, kernel_size=(5, 5), stride=(1, 1))
     (1): ReLU()
     (2): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
     (3): Conv2d(8, 32, kernel_size=(3, 3), stride=(1, 1))
     (4): ReLU()
     (5): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
     (6): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1))
     (7): ReLU()
     (8): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
   )
   (out): Linear(in_features=3168, out_features=10, bias=True)
 ),
 cnn_multi_dim(
   (conv): Sequential(
     (0): Conv2d(1, 8, kernel_size=(5, 5), stride=(1, 1))
     (1): ReLU()
     (2): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
     (3): Conv2d(8, 32, kernel_size=(3, 3), stride=(1, 1))
     (4): ReLU()
     (5): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
     (6): Conv2d(32, 32, kernel_size=(3, 3), s

In [9]:
import nibabel as nib
#img=nib.load('MPRFlirt_0.nii.gz')
#imgdata=np.array(img.get_fdata())
#print(imgdata.shape)

In [10]:
def MeanPool3d(img,kernel):
    s=img.shape
    ret=np.zeros((s[0]//kernel,s[1]//kernel,s[2]//kernel))
    for i in range(0,s[0],kernel):
        for j in range(0,s[1],kernel):
            for k in range(0,s[2],kernel):
                r=img[i:i+kernel,j:j+kernel,k:k+kernel]
                avg=r.mean()
                ret[i//kernel,j//kernel,k//kernel]=avg
    return ret            
                

In [11]:
ndata=32
data=[]
for i in range(ndata):
    img=nib.load('MPRFlirt_{}.nii.gz'.format(i))
    imgdata=np.array(img.get_fdata())
    imgdata=MeanPool3d(imgdata,2)
    data.append(imgdata)
    print('.',end='')
data=np.array(data)
print(data.shape)


................................(32, 91, 109, 91)


In [12]:
batch_size=2
dataloader=torch.utils.data.DataLoader(dataset=data,batch_size=batch_size,shuffle=True)

In [42]:
cnns=train(dataloader,10,1e-5,40)


epoch: 0 loss: 108.20485763549804
epoch: 1 loss: 80.13615417480469
epoch: 2 loss: 67.57412618001302
epoch: 3 loss: 48.27934010823568
epoch: 4 loss: 37.01170260111491
epoch: 5 loss: 23.869907506306966
epoch: 6 loss: 18.490877405802408
epoch: 7 loss: 15.882590993245442
epoch: 8 loss: 12.98216183980306
epoch: 9 loss: 7.426420084635416


In [47]:
features=output(cnns,data[0:4])

In [54]:
print(features-np.mean(features,axis=0))

[[[-18.04083443 -30.10380173   9.25126839 -24.64769363  19.53997231
   -23.32155609  12.49834633 -22.53586769 -19.83687019 -27.53133774]
  [ 14.11794281   9.75600624 -13.1545105   18.08741379  17.60120773
    18.08508301  17.818573    12.37252617 -12.44902802  13.12752151]
  [  9.53529358   7.29224586   6.42493439  -7.90307045  -6.19275856
    -7.53355789  -4.94872665   6.91421318   6.90393257   8.67295074]]

 [[-13.49768257 -24.45940781  10.14556313 -19.16000748  14.18871498
   -19.68483734   8.31364632 -17.77193642 -14.22083473 -21.8798027 ]
  [ 20.84435272  13.15213203 -21.20957947  22.71351242  23.44721603
    22.60429382  22.89402771  15.6972332  -22.35839081  20.63333511]
  [ 11.5473938   11.02936745  14.00681305 -10.71984291 -11.41077614
   -13.0578804   -9.56788254  11.11262321  10.13861275  11.70465851]]

 [[ 18.55711174  30.03129959 -11.36250877  25.46561813 -20.0010891
    24.33745575 -12.10825157  23.25816154  18.68953514  27.5301857 ]
  [-16.34326935 -10.33164024  15.41595