In [1]:
import numpy as np
import png
import pydicom
from sklearn.preprocessing import normalize
import warnings
warnings.filterwarnings('ignore')
from os import listdir
from os.path import isfile, join
import os
import torch
import torch.nn as nn
from torchvision.transforms import Compose
import transform_classes

from roi import RoiLearn
from roi_dataset import RoiDataset
from preprocessor import Preprocessor

import preprocess_img

In [4]:
import torch.nn.functional as F

class Autoencoder(nn.Module):
    def __init__(self, n_inp, n_hidden):
        super(Autoencoder, self).__init__()
        self.encoder = nn.Linear(n_inp, n_hidden)
        self.decoder = nn.Linear(n_hidden, n_inp)
        self.n_inp = n_inp
        self.n_hidden = n_hidden

    def forward(self, x):
        encoded = F.sigmoid(self.encoder(x))
        decoded = F.sigmoid(self.decoder(encoded))
        return encoded, decoded

def kl_divergence(p, q):
    '''
    args:
        2 tensors `p` and `q`
    returns:
        kl divergence between the softmax of `p` and `q`
    '''
    p = F.softmax(p)
    q = F.softmax(q)

    s1 = torch.sum(p * torch.log(p / q))
    s2 = torch.sum((1 - p) * torch.log((1 - p) / (1 - q)))
    return s1 + s2

In [73]:
import numpy as np
import png
import pydicom
from sklearn.preprocessing import normalize
import torch.nn.functional as F
from os import listdir
from os.path import isfile, join
import os
import torch
import torch.nn as nn


class Flatten(nn.Module):
    def forward(self, input):
        return input.view(input.size(0), -1)

class RoiLearn:
    def __init__(self):
        torch.manual_seed(12)
        self.conv1 = nn.Conv2d(1,100, (11,11))
        self.softmax = nn.Softmax()
        self.sigmoid = nn.Sigmoid()
        self.avgpool = nn.AvgPool2d(6)
        self.flatten = Flatten()
        self.full = nn.Linear(8100,1024)
        #self.encoder = nn.Linear(121,100)
        #self.decoder = nn.Linear(100,121)
          
    # Autoencoder architecture
    def build_ae(self):
        self.autoencoder = Autoencoder(121,100)
        self.autoencoder = self.autoencoder.double()
        
    # Autoencoder W2 and b2 to the original model conv1 layer features and biases.
    # From the parameters list - index 0 is the weights
    #                          - index 1 is the biases
    def ae_weights2model_feature_set(self):
        
        w2 = list(self.encoder.parameters())
        
        b2 = w2[1].detach().numpy()
        # weights shape here (100,121)
        w2 = np.expand_dims(w2[0].detach().numpy().reshape((100,11,11)), axis = 1)
        # weights shape (100,1,11,11)
        
        conv1_features = list(self.conv1.parameters())
        conv1_features[0] = torch.nn.Parameter(torch.from_numpy(w2))
        conv1_features[1] = torch.nn.Parameter(torch.from_numpy(b2))
        conv1_features[0].requires_grad=False
        conv1_features[1].requires_grad=False
        
    # Rho - sparsity penalty pj = p    
    def learn_ae(self, dataset_loader, optimizer,criterion, ep = 1, lr = 0.01, BETA = 3, RHO = 0.1):
        rho = torch.FloatTensor([[RHO for _ in range(self.autoencoder.n_hidden)] for _ in range(dataset_loader.batch_size)]).double()
        crit2 = nn.KLDivLoss(size_average=False)
        for epoch in range(ep):
            for i_batch, sample_batched in enumerate(dataset_loader):
                #print(i_batch, sample_batched['image'].size(),sample_batched['mask'].size())
                #print(sample_batched['image'].shape)
                encoded, decoded = self.autoencoder(sample_batched['image'])
                first_loss = criterion(sample_batched['image'],decoded)
                sparsity_loss = crit2(rho, encoded)
                #MSE_loss.view(1, -1).sum(1)
                #MSE_loss = MSE_loss.view(1, -1).sum(1) / dataset_loader.batch_size
                
                #y_pred = self.autoencoder(sample_batched['image'])
                #rho_hat = torch.sum(encoded, dim=0, keepdim=True) / dataset_loader.batch_size
                #sparsity_penalty = BETA * F.kl_div( rho,rho_hat)
                print(sparsity_loss)
                loss = first_loss + BETA*sparsity_loss
                
                
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
            print('epoch: ', epoch,' loss: ', loss.item())
        
                
    
    def build_model(self):
        self.model = nn.Sequential(self.conv1,
                            self.avgpool,
                            self.softmax,
                            self.flatten,
                            self.full,
                            self.softmax
                            )
        self.model = self.model.double()
    
    def propagate_from_dataLoader(self,dl):
        for i_batch, sample_batched in enumerate(dl):
            print(self.model(sample_batched[0]))
        
    def propagate(self):
        return self.model(self.x)
    
    def save_image( self,npdata, outfilename ) :
        img = Image.fromarray( np.asarray( np.clip(npdata,0,255), dtype="uint8"), "L" )
        img.save( outfilename )


In [47]:
# if we don't have the .csv file
preprocess_img.write_all_rectangle2file('O:\\ProgrammingSoftwares\\anaconda_projects\\heart_contour\\sa_all_1\\')

Dicom files were read in!
Con files were read in!
Dicom files were read in!
Con files were read in!
Dicom files were read in!
Con files were read in!
Dicom files were read in!
Con files were read in!


In [74]:
csv_file = 'O:/ProgrammingSoftwares/anaconda_projects/heart_contour/sa_all_1/rectangle.csv'

compose3 = Compose([transform_classes.GetRandomPatch(),transform_classes.StandardScale2(),transform_classes.ToTensor()])

ds2 = RoiDataset(csv_file, compose3)
roi = RoiLearn()
roi.build_ae()

crit = torch.nn.MSELoss(size_average = True)
opt = torch.optim.Adam(roi.autoencoder.parameters(),  weight_decay = 0.0001 )

# Random 1000 sample
weighted_rnd_sample = torch.utils.data.WeightedRandomSampler([float(1/len(ds2)) for i in range(len(ds2))], 1000, replacement=True)
dataset_loader = torch.utils.data.DataLoader(ds2,batch_size=8, num_workers=0, sampler=weighted_rnd_sample)

roi.learn_ae(dataset_loader, optimizer = opt, criterion = crit,  ep = 10)

#roi.ae_weights2model_feature_set()

tensor(-302.4361, dtype=torch.float64, grad_fn=<KlDivBackward>)
tensor(-304.0619, dtype=torch.float64, grad_fn=<KlDivBackward>)
tensor(-306.4438, dtype=torch.float64, grad_fn=<KlDivBackward>)
tensor(-303.4931, dtype=torch.float64, grad_fn=<KlDivBackward>)
tensor(-304.0738, dtype=torch.float64, grad_fn=<KlDivBackward>)
tensor(-304.5660, dtype=torch.float64, grad_fn=<KlDivBackward>)
tensor(-306.3905, dtype=torch.float64, grad_fn=<KlDivBackward>)
tensor(-307.1995, dtype=torch.float64, grad_fn=<KlDivBackward>)
tensor(-306.6145, dtype=torch.float64, grad_fn=<KlDivBackward>)
tensor(-305.4734, dtype=torch.float64, grad_fn=<KlDivBackward>)
tensor(-308.3309, dtype=torch.float64, grad_fn=<KlDivBackward>)
tensor(-306.5888, dtype=torch.float64, grad_fn=<KlDivBackward>)
tensor(-305.7754, dtype=torch.float64, grad_fn=<KlDivBackward>)
tensor(-309.7013, dtype=torch.float64, grad_fn=<KlDivBackward>)
tensor(-309.6545, dtype=torch.float64, grad_fn=<KlDivBackward>)
tensor(-307.4005, dtype=torch.float64, g

tensor(-319.4951, dtype=torch.float64, grad_fn=<KlDivBackward>)
tensor(-317.3405, dtype=torch.float64, grad_fn=<KlDivBackward>)
tensor(-319.3658, dtype=torch.float64, grad_fn=<KlDivBackward>)
tensor(-318.5754, dtype=torch.float64, grad_fn=<KlDivBackward>)
tensor(-318.8458, dtype=torch.float64, grad_fn=<KlDivBackward>)
tensor(-317.5928, dtype=torch.float64, grad_fn=<KlDivBackward>)
tensor(-319.5735, dtype=torch.float64, grad_fn=<KlDivBackward>)
tensor(-317.6623, dtype=torch.float64, grad_fn=<KlDivBackward>)
tensor(-318.7681, dtype=torch.float64, grad_fn=<KlDivBackward>)
tensor(-319.1903, dtype=torch.float64, grad_fn=<KlDivBackward>)
tensor(-317.8335, dtype=torch.float64, grad_fn=<KlDivBackward>)
tensor(-319.4497, dtype=torch.float64, grad_fn=<KlDivBackward>)
tensor(-318.9635, dtype=torch.float64, grad_fn=<KlDivBackward>)
tensor(-318.9003, dtype=torch.float64, grad_fn=<KlDivBackward>)
tensor(-318.1489, dtype=torch.float64, grad_fn=<KlDivBackward>)
tensor(-319.3140, dtype=torch.float64, g

tensor(-322.3799, dtype=torch.float64, grad_fn=<KlDivBackward>)
tensor(-322.1801, dtype=torch.float64, grad_fn=<KlDivBackward>)
tensor(-321.3388, dtype=torch.float64, grad_fn=<KlDivBackward>)
tensor(-320.5626, dtype=torch.float64, grad_fn=<KlDivBackward>)
tensor(-322.1229, dtype=torch.float64, grad_fn=<KlDivBackward>)
tensor(-321.3973, dtype=torch.float64, grad_fn=<KlDivBackward>)
tensor(-322.3450, dtype=torch.float64, grad_fn=<KlDivBackward>)
tensor(-321.8967, dtype=torch.float64, grad_fn=<KlDivBackward>)
tensor(-321.5705, dtype=torch.float64, grad_fn=<KlDivBackward>)
tensor(-322.7734, dtype=torch.float64, grad_fn=<KlDivBackward>)
tensor(-322.3542, dtype=torch.float64, grad_fn=<KlDivBackward>)
tensor(-321.3346, dtype=torch.float64, grad_fn=<KlDivBackward>)
tensor(-322.3847, dtype=torch.float64, grad_fn=<KlDivBackward>)
tensor(-322.5423, dtype=torch.float64, grad_fn=<KlDivBackward>)
tensor(-321.2743, dtype=torch.float64, grad_fn=<KlDivBackward>)
tensor(-322.6146, dtype=torch.float64, g

tensor(-323.7402, dtype=torch.float64, grad_fn=<KlDivBackward>)
tensor(-323.5641, dtype=torch.float64, grad_fn=<KlDivBackward>)
tensor(-323.3521, dtype=torch.float64, grad_fn=<KlDivBackward>)
tensor(-323.9638, dtype=torch.float64, grad_fn=<KlDivBackward>)
tensor(-323.3325, dtype=torch.float64, grad_fn=<KlDivBackward>)
tensor(-323.9303, dtype=torch.float64, grad_fn=<KlDivBackward>)
tensor(-323.5418, dtype=torch.float64, grad_fn=<KlDivBackward>)
tensor(-323.4448, dtype=torch.float64, grad_fn=<KlDivBackward>)
tensor(-324.0804, dtype=torch.float64, grad_fn=<KlDivBackward>)
tensor(-323.5278, dtype=torch.float64, grad_fn=<KlDivBackward>)
tensor(-323.4025, dtype=torch.float64, grad_fn=<KlDivBackward>)
tensor(-323.8947, dtype=torch.float64, grad_fn=<KlDivBackward>)
tensor(-324.0913, dtype=torch.float64, grad_fn=<KlDivBackward>)
tensor(-323.8563, dtype=torch.float64, grad_fn=<KlDivBackward>)
tensor(-323.9159, dtype=torch.float64, grad_fn=<KlDivBackward>)
tensor(-324.3135, dtype=torch.float64, g

tensor(-323.7924, dtype=torch.float64, grad_fn=<KlDivBackward>)
tensor(-323.8524, dtype=torch.float64, grad_fn=<KlDivBackward>)
tensor(-323.6646, dtype=torch.float64, grad_fn=<KlDivBackward>)
tensor(-324.0712, dtype=torch.float64, grad_fn=<KlDivBackward>)
tensor(-324.3159, dtype=torch.float64, grad_fn=<KlDivBackward>)
tensor(-324.5614, dtype=torch.float64, grad_fn=<KlDivBackward>)
tensor(-324.1272, dtype=torch.float64, grad_fn=<KlDivBackward>)
tensor(-323.9973, dtype=torch.float64, grad_fn=<KlDivBackward>)
tensor(-324.2473, dtype=torch.float64, grad_fn=<KlDivBackward>)
tensor(-324.4817, dtype=torch.float64, grad_fn=<KlDivBackward>)
tensor(-324.3790, dtype=torch.float64, grad_fn=<KlDivBackward>)
tensor(-323.9736, dtype=torch.float64, grad_fn=<KlDivBackward>)
tensor(-324.3961, dtype=torch.float64, grad_fn=<KlDivBackward>)
tensor(-324.1972, dtype=torch.float64, grad_fn=<KlDivBackward>)
tensor(-324.3902, dtype=torch.float64, grad_fn=<KlDivBackward>)
tensor(-324.7392, dtype=torch.float64, g

tensor(-324.6710, dtype=torch.float64, grad_fn=<KlDivBackward>)
tensor(-324.4102, dtype=torch.float64, grad_fn=<KlDivBackward>)
tensor(-323.7993, dtype=torch.float64, grad_fn=<KlDivBackward>)
tensor(-324.4669, dtype=torch.float64, grad_fn=<KlDivBackward>)
tensor(-323.8408, dtype=torch.float64, grad_fn=<KlDivBackward>)
tensor(-323.5704, dtype=torch.float64, grad_fn=<KlDivBackward>)
tensor(-323.0560, dtype=torch.float64, grad_fn=<KlDivBackward>)
tensor(-324.1256, dtype=torch.float64, grad_fn=<KlDivBackward>)
tensor(-324.7932, dtype=torch.float64, grad_fn=<KlDivBackward>)
tensor(-323.9793, dtype=torch.float64, grad_fn=<KlDivBackward>)
tensor(-324.1625, dtype=torch.float64, grad_fn=<KlDivBackward>)
tensor(-324.2591, dtype=torch.float64, grad_fn=<KlDivBackward>)
tensor(-324.1595, dtype=torch.float64, grad_fn=<KlDivBackward>)
tensor(-324.2651, dtype=torch.float64, grad_fn=<KlDivBackward>)
tensor(-324.2561, dtype=torch.float64, grad_fn=<KlDivBackward>)
tensor(-324.3787, dtype=torch.float64, g

tensor(-324.0128, dtype=torch.float64, grad_fn=<KlDivBackward>)
tensor(-324.4125, dtype=torch.float64, grad_fn=<KlDivBackward>)
tensor(-324.1939, dtype=torch.float64, grad_fn=<KlDivBackward>)
tensor(-323.5194, dtype=torch.float64, grad_fn=<KlDivBackward>)
tensor(-324.1864, dtype=torch.float64, grad_fn=<KlDivBackward>)
tensor(-323.5220, dtype=torch.float64, grad_fn=<KlDivBackward>)
tensor(-324.3491, dtype=torch.float64, grad_fn=<KlDivBackward>)
tensor(-324.1895, dtype=torch.float64, grad_fn=<KlDivBackward>)
tensor(-324.4621, dtype=torch.float64, grad_fn=<KlDivBackward>)
tensor(-324.0490, dtype=torch.float64, grad_fn=<KlDivBackward>)
tensor(-324.1803, dtype=torch.float64, grad_fn=<KlDivBackward>)
tensor(-324.3043, dtype=torch.float64, grad_fn=<KlDivBackward>)
tensor(-324.4400, dtype=torch.float64, grad_fn=<KlDivBackward>)
tensor(-324.4309, dtype=torch.float64, grad_fn=<KlDivBackward>)
tensor(-323.6031, dtype=torch.float64, grad_fn=<KlDivBackward>)
tensor(-323.9738, dtype=torch.float64, g

tensor(-324.1542, dtype=torch.float64, grad_fn=<KlDivBackward>)
tensor(-323.4339, dtype=torch.float64, grad_fn=<KlDivBackward>)
tensor(-323.4628, dtype=torch.float64, grad_fn=<KlDivBackward>)
tensor(-323.2316, dtype=torch.float64, grad_fn=<KlDivBackward>)
tensor(-323.9059, dtype=torch.float64, grad_fn=<KlDivBackward>)
tensor(-324.2870, dtype=torch.float64, grad_fn=<KlDivBackward>)
tensor(-324.1464, dtype=torch.float64, grad_fn=<KlDivBackward>)
tensor(-323.5489, dtype=torch.float64, grad_fn=<KlDivBackward>)
tensor(-323.5251, dtype=torch.float64, grad_fn=<KlDivBackward>)
tensor(-323.9390, dtype=torch.float64, grad_fn=<KlDivBackward>)
tensor(-324.3981, dtype=torch.float64, grad_fn=<KlDivBackward>)
tensor(-323.9729, dtype=torch.float64, grad_fn=<KlDivBackward>)
tensor(-323.2075, dtype=torch.float64, grad_fn=<KlDivBackward>)
tensor(-323.6702, dtype=torch.float64, grad_fn=<KlDivBackward>)
tensor(-324.0357, dtype=torch.float64, grad_fn=<KlDivBackward>)
tensor(-323.6286, dtype=torch.float64, g

tensor(-322.6796, dtype=torch.float64, grad_fn=<KlDivBackward>)
tensor(-322.0707, dtype=torch.float64, grad_fn=<KlDivBackward>)
tensor(-322.7968, dtype=torch.float64, grad_fn=<KlDivBackward>)
tensor(-323.7327, dtype=torch.float64, grad_fn=<KlDivBackward>)
tensor(-324.7121, dtype=torch.float64, grad_fn=<KlDivBackward>)
tensor(-323.9271, dtype=torch.float64, grad_fn=<KlDivBackward>)
tensor(-322.2302, dtype=torch.float64, grad_fn=<KlDivBackward>)
tensor(-323.8626, dtype=torch.float64, grad_fn=<KlDivBackward>)
tensor(-323.0400, dtype=torch.float64, grad_fn=<KlDivBackward>)
tensor(-322.7477, dtype=torch.float64, grad_fn=<KlDivBackward>)
tensor(-324.1461, dtype=torch.float64, grad_fn=<KlDivBackward>)
tensor(-321.6014, dtype=torch.float64, grad_fn=<KlDivBackward>)
tensor(-323.7184, dtype=torch.float64, grad_fn=<KlDivBackward>)
tensor(-323.4173, dtype=torch.float64, grad_fn=<KlDivBackward>)
tensor(-321.0196, dtype=torch.float64, grad_fn=<KlDivBackward>)
tensor(-323.3615, dtype=torch.float64, g

tensor(-322.9410, dtype=torch.float64, grad_fn=<KlDivBackward>)
tensor(-321.1495, dtype=torch.float64, grad_fn=<KlDivBackward>)
tensor(-322.1767, dtype=torch.float64, grad_fn=<KlDivBackward>)
tensor(-323.7685, dtype=torch.float64, grad_fn=<KlDivBackward>)
tensor(-321.8958, dtype=torch.float64, grad_fn=<KlDivBackward>)
tensor(-322.1154, dtype=torch.float64, grad_fn=<KlDivBackward>)
tensor(-323.0424, dtype=torch.float64, grad_fn=<KlDivBackward>)
tensor(-321.7334, dtype=torch.float64, grad_fn=<KlDivBackward>)
tensor(-323.9113, dtype=torch.float64, grad_fn=<KlDivBackward>)
tensor(-322.0759, dtype=torch.float64, grad_fn=<KlDivBackward>)
tensor(-322.0905, dtype=torch.float64, grad_fn=<KlDivBackward>)
tensor(-324.2297, dtype=torch.float64, grad_fn=<KlDivBackward>)
tensor(-324.0322, dtype=torch.float64, grad_fn=<KlDivBackward>)
tensor(-321.4253, dtype=torch.float64, grad_fn=<KlDivBackward>)
tensor(-322.8807, dtype=torch.float64, grad_fn=<KlDivBackward>)
tensor(-323.2567, dtype=torch.float64, g

ezek a tensorok a sparsity penalty értékei batch-enként.

In [30]:
roi.autoencoder.encoder.weight

Parameter containing:
tensor([[ 0.0278, -0.0151, -0.0972,  ...,  0.1100,  0.1107,  0.0876],
        [ 0.0771,  0.0222,  0.0840,  ..., -0.0344, -0.0756,  0.0668],
        [-0.0089, -0.0514,  0.0511,  ...,  0.0566, -0.0500, -0.1025],
        ...,
        [-0.0133,  0.0592, -0.0483,  ..., -0.0714, -0.0549, -0.1544],
        [ 0.0342,  0.0797,  0.0925,  ..., -0.0649,  0.0245,  0.0406],
        [ 0.0799,  0.0340, -0.0507,  ...,  0.0252,  0.1131,  0.0358]],
       dtype=torch.float64, requires_grad=True)

In [None]:
import matplotlib.pyplot as plt
import matplotlib.image as mpimg
from torchvision.transforms import Compose
import transform_classes

csv_file = 'O:/ProgrammingSoftwares/anaconda_projects/heart_contour/sa_all_1/rectangle.csv'
compose1 = Compose([transform_classes.ReScale64(),transform_classes.StandardScale(),transform_classes.ToTensor()])
compose2 = Compose([transform_classes.ReScale32(),transform_classes.ToTensor()])
ds = RoiDataset(csv_file, compose1, compose2)

roi = RoiLearn()
roi.build_model()

criterion = torch.nn.MSELoss()
optimizer = torch.optim.Adam(roi.model.parameters(), lr=0.1)

dataset_loader = torch.utils.data.DataLoader(ds,batch_size=32, shuffle=True,num_workers=0)