In [1]:
import os

import pandas as pd
import torch
import torch.nn as nn
import torch.utils.data
from torch.autograd import Variable  # NOTE(review): unused here; Variable is deprecated in modern PyTorch

# Dataset

In [2]:
class CTVolumesDataset(torch.utils.data.Dataset):
    """Map-style dataset of paired low-/high-resolution CT patches.

    The index CSV is read with ``pandas.read_csv`` (so its first row is treated
    as a header). For every remaining row, column 0 holds the path of the
    low-resolution patch and column 1 the path of the matching high-resolution
    patch. Both files are loaded with ``torch.load`` each time an item is
    requested; nothing is cached.

    Args:
        csv_file: path to the index CSV.
        root_dir: optional directory that the paths in the CSV are joined onto.
            ``None`` (default) uses the paths exactly as written in the CSV.
    """

    def __init__(self, csv_file, root_dir=None):
        self.dataframe = pd.read_csv(csv_file)
        self.root_dir = root_dir

    def __len__(self):
        """Number of (low-res, high-res) pairs, i.e. data rows in the CSV."""
        return len(self.dataframe)

    def _resolve(self, path):
        """Prefix ``path`` with ``root_dir`` when one was given."""
        if self.root_dir is None:
            return path
        # os.path.join keeps `path` unchanged if it is already absolute.
        return os.path.join(self.root_dir, path)

    def __getitem__(self, idx):
        """Load row ``idx`` and return the tuple ``(lo_res_patch, hi_res_patch)``."""
        lo_res_path = self._resolve(self.dataframe.iloc[idx, 0])
        hi_res_path = self._resolve(self.dataframe.iloc[idx, 1])

        # torch.load unpickles the file contents: only point the CSV at trusted files.
        lo_res_patch = torch.load(lo_res_path)
        hi_res_patch = torch.load(hi_res_path)
        return lo_res_patch, hi_res_patch


# Autoencoder

In [15]:

class Autoencoder(nn.Module):
    """3D convolutional encoder-decoder mapping a 1-channel volume to a 1-channel volume.

    Six encoder stages (``Conv3d`` with kernel 4, stride 2, padding 1) each halve
    every spatial side, and six decoder stages (``ConvTranspose3d`` with the same
    hyper-parameters) each double it. A ``(N, 1, 64, 64, 64)`` input is therefore
    reduced to ``(N, code_len, 1, 1, 1)`` and restored to ``(N, 1, 64, 64, 64)``.
    The output matches the input size only when every spatial side is a
    multiple of 64.

    Args:
        cube_len: channel width of the first encoder stage (default 64). Despite
            its name it sets channel counts, not the volume side. The encoder
            widths are cube_len, 2*cube_len, 4*cube_len, then
            code_len = 8*cube_len for the last three stages.
    """

    def __init__(self, cube_len=64):
        super(Autoencoder, self).__init__()

        self.cube_len = cube_len
        # Channel width of the bottleneck ("code") feature maps.
        self.code_len = cube_len * 8

        c1, c2, c4, code = cube_len, cube_len * 2, cube_len * 4, self.code_len

        # Contracting path
        self.enc_1 = self._down(1, c1)
        self.enc_2 = self._down(c1, c2)
        self.enc_3 = self._down(c2, c4)
        self.enc_4 = self._down(c4, code)
        self.enc_5 = self._down(code, code)
        # No BatchNorm here. For a 64^3 input this stage outputs a 1x1x1 map, so
        # at batch_size=1 BatchNorm3d would see a single value per channel.
        # PyTorch rejects that in training mode.
        self.enc_6 = self._down(code, code, batch_norm=False)

        # Expansive path
        self.dec_1 = self._up(code, code)
        self.dec_2 = self._up(code, code)
        self.dec_3 = self._up(code, c4)
        self.dec_4 = self._up(c4, c2)
        self.dec_5 = self._up(c2, c1)
        # The output stage is deliberately linear. BatchNorm + ReLU here would
        # re-standardise the reconstruction and clip it at zero, which stops the
        # network from regressing the target voxel values directly.
        self.dec_6 = nn.Sequential(
            nn.ConvTranspose3d(c1, 1, kernel_size=4, stride=2, padding=1),
        )

    @staticmethod
    def _down(in_ch, out_ch, batch_norm=True):
        """Encoder stage: Conv3d halving each spatial side, optional BatchNorm3d, then ReLU."""
        layers = [nn.Conv3d(in_ch, out_ch, kernel_size=4, stride=2, padding=1)]
        if batch_norm:
            layers.append(nn.BatchNorm3d(out_ch))
        layers.append(nn.ReLU())
        return nn.Sequential(*layers)

    @staticmethod
    def _up(in_ch, out_ch):
        """Decoder stage: ConvTranspose3d doubling each spatial side, BatchNorm3d, then ReLU."""
        return nn.Sequential(
            nn.ConvTranspose3d(in_ch, out_ch, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm3d(out_ch),
            nn.ReLU(),
        )

    def forward(self, x):
        """Encode ``x`` down to the bottleneck, then decode it back.

        Args:
            x: tensor of shape (N, 1, D, H, W). D, H and W should be multiples
                of 64 for the output to match the input size.

        Returns:
            Tensor of shape (N, 1, D, H, W) under that condition.
        """
        stages = (
            # downconvolutions; enc_6 produces the code
            self.enc_1, self.enc_2, self.enc_3, self.enc_4, self.enc_5, self.enc_6,
            # transposed convolutions
            self.dec_1, self.dec_2, self.dec_3, self.dec_4, self.dec_5, self.dec_6,
        )
        out = x
        for stage in stages:
            out = stage(out)
        return out
        


# Train loop

In [16]:
num_epochs = 1
batch_size = 1
learning_rate = 0.001
SEED = 0  # seeds weight init and DataLoader shuffling so runs are reproducible

# Index CSV of (lo-res, hi-res) patch paths. Set the CT_CSV_PATH environment
# variable to override this machine-specific default instead of editing it.
CSV_PATH = os.environ.get(
    'CT_CSV_PATH',
    r'C:\Users\Juanig\Desktop\Desktop_\ct images\para probar el procesamiento de los volumenes\100_FBPPhil_500FBP.csv',
)

torch.manual_seed(SEED)

#instantiating the model
model = Autoencoder()
# NOTE(review): float64 presumably matches the dtype of the loaded patches, but it
# makes these 3D convolutions far slower than float32. Consider float32 data + model.
model.double()

#loss function
criterion = nn.MSELoss()
#optimizer algorithm
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

#dataset, dataloader
dataset = CTVolumesDataset(CSV_PATH)
dataloader = torch.utils.data.DataLoader(dataset=dataset, batch_size=batch_size, shuffle=True)

### Inspecting the shape of the data

In [30]:
# Fetch one sample once: each dataset[i] call re-reads both patch files from disk.
first_pair = dataset[0]

#elements in our dataset
print(len(dataset))
# type of each element: the (lo_res, hi_res) pair
print(type(first_pair))
# shape of the low-res patch (printed as a plain tuple because the loaded patches are numpy arrays)
print(first_pair[0].shape)
# number of batches in the dataloader (equal to len(dataset) because batch_size = 1)
print(len(dataloader))

1171
<class 'tuple'>
(64, 64, 64)
1171


In [17]:
#set the model to train
model.train()
total_step = len(dataloader)
LOG_EVERY = 50  # print the running loss every LOG_EVERY steps (and on the last step)

# Cast inputs/targets to the model's parameter dtype so the loop works whether
# the model runs in float32 or float64 (model.double()).
param_dtype = next(model.parameters()).dtype

#train
for epoch in range(num_epochs):
    for i, (lo_res, hi_res) in enumerate(dataloader):
        # Add the channel axis Conv3d expects:
        # (batch_size, 64, 64, 64) -> (batch_size, 1, 64, 64, 64)
        lo_res = lo_res.unsqueeze(1).to(param_dtype)
        hi_res = hi_res.unsqueeze(1).to(param_dtype)

        #forward pass
        outputs = model(lo_res)
        loss = criterion(outputs, hi_res)

        #backward & optimize
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        if (i + 1) % LOG_EVERY == 0 or (i + 1) == total_step:
            print('Epoch [{}/{}], Step [{}/{}], Loss: {:.4f}'
                  .format(epoch + 1, num_epochs, i + 1, total_step, loss.item()))

#torch.save(model.state_dict(), 'toy_3d_autoencoder')

...
torch.Size([1, 1, 64, 64, 64])
torch.Size([1, 64, 32, 32, 32])
torch.Size([1, 128, 16, 16, 16])
torch.Size([1, 256, 8, 8, 8])
torch.Size([1, 512, 4, 4, 4])
torch.Size([1, 512, 2, 2, 2])
torch.Size([1, 64, 32, 32, 32])
...
torch.Size([1, 1, 64, 64, 64])
torch.Size([1, 64, 32, 32, 32])
torch.Size([1, 128, 16, 16, 16])
torch.Size([1, 256, 8, 8, 8])
torch.Size([1, 512, 4, 4, 4])
torch.Size([1, 512, 2, 2, 2])
torch.Size([1, 64, 32, 32, 32])
...
torch.Size([1, 1, 64, 64, 64])
torch.Size([1, 64, 32, 32, 32])
torch.Size([1, 128, 16, 16, 16])
torch.Size([1, 256, 8, 8, 8])
torch.Size([1, 512, 4, 4, 4])
torch.Size([1, 512, 2, 2, 2])
torch.Size([1, 64, 32, 32, 32])
...
torch.Size([1, 1, 64, 64, 64])
torch.Size([1, 64, 32, 32, 32])
torch.Size([1, 128, 16, 16, 16])
torch.Size([1, 256, 8, 8, 8])
torch.Size([1, 512, 4, 4, 4])
torch.Size([1, 512, 2, 2, 2])
torch.Size([1, 64, 32, 32, 32])
...
torch.Size([1, 1, 64, 64, 64])
torch.Size([1, 64, 32, 32, 32])
torch.Size([1, 128, 16, 16, 16])
torch.Size([

...
torch.Size([1, 1, 64, 64, 64])
torch.Size([1, 64, 32, 32, 32])
torch.Size([1, 128, 16, 16, 16])
torch.Size([1, 256, 8, 8, 8])
torch.Size([1, 512, 4, 4, 4])
torch.Size([1, 512, 2, 2, 2])
torch.Size([1, 64, 32, 32, 32])
...
torch.Size([1, 1, 64, 64, 64])
torch.Size([1, 64, 32, 32, 32])
torch.Size([1, 128, 16, 16, 16])
torch.Size([1, 256, 8, 8, 8])
torch.Size([1, 512, 4, 4, 4])
torch.Size([1, 512, 2, 2, 2])
torch.Size([1, 64, 32, 32, 32])
...
torch.Size([1, 1, 64, 64, 64])
torch.Size([1, 64, 32, 32, 32])
torch.Size([1, 128, 16, 16, 16])
torch.Size([1, 256, 8, 8, 8])
torch.Size([1, 512, 4, 4, 4])
torch.Size([1, 512, 2, 2, 2])
torch.Size([1, 64, 32, 32, 32])
...
torch.Size([1, 1, 64, 64, 64])
torch.Size([1, 64, 32, 32, 32])
torch.Size([1, 128, 16, 16, 16])
torch.Size([1, 256, 8, 8, 8])
torch.Size([1, 512, 4, 4, 4])
torch.Size([1, 512, 2, 2, 2])
torch.Size([1, 64, 32, 32, 32])
...
torch.Size([1, 1, 64, 64, 64])
torch.Size([1, 64, 32, 32, 32])
torch.Size([1, 128, 16, 16, 16])
torch.Size([

KeyboardInterrupt: 