# Notebook responsável por mostrar o funcionamento de baixo nível da metodologia

Importação obrigatória das bibliotecas utilizadas

In [2]:
import argparse
import os
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.autograd import Variable
import cv2
import numpy as np
import torch.nn.init
import matplotlib.pyplot as plt
import scipy.ndimage
# from skimage import measure
from mpl_toolkits.mplot3d.art3d import Poly3DCollection
import nibabel as nb
#conversor utilizado para realizar a leitura dos dicoms
from dicom_to_nifti import converter

ModuleNotFoundError: No module named 'nibabel'

In [None]:
# Flag consulted by every later cell: True when PyTorch can see a CUDA device,
# in which case tensors and the model are moved to the GPU.
use_cuda = torch.cuda.is_available()

In [None]:
# Command-line / notebook configuration of the segmentation run.
parser = argparse.ArgumentParser(description='PyTorch Unsupervised Segmentation')
parser.add_argument('--scribble', action='store_true', default=False, help='use scribbles')
parser.add_argument('--nChannel', metavar='N', default=100, type=int, help='number of channels')
parser.add_argument('--maxIter', metavar='T', default=50, type=int, help='number of maximum iterations')
parser.add_argument('--minLabels', metavar='minL', default=8, type=int, help='minimum number of labels')
parser.add_argument('--lr', metavar='LR', default=0.1, type=float, help='learning rate')
parser.add_argument('--nConv', metavar='M', default=3, type=int, help='number of convolutional layers')
parser.add_argument('--stepsize_con', metavar='CON', default=1, type=float, help='step size for continuity loss')
parser.add_argument('--stepsize_scr', metavar='SCR', default=1, type=float, help='step size for scribble loss')
parser.add_argument('--visualize', metavar='1 or 0', default=1, type=int, help='visualization flag')
parser.add_argument('--input', metavar='FILENAME',
                    default=r'E:\PycharmProjects\pythonProject\imagens\Dsc32909.jpg',
                    help='input image file name', required=False)
parser.add_argument('--stepsize_sim', metavar='SIM', default=1, type=float, help='step size for similarity loss', required=False)

# parse_known_args() instead of parse_args(): inside Jupyter, sys.argv holds
# the kernel's own flags (e.g. "-f <connection file>"), which parse_args()
# rejects with SystemExit. Unrecognised arguments are reported, not fatal.
args, _unknown_argv = parser.parse_known_args()
if _unknown_argv:
    print('Ignoring unrecognised arguments:', _unknown_argv)

In [None]:
def plotarHistograma(exame):
    """Plot and show a histogram of every value in ``exame``.

    The array is flattened, so any shape is accepted. Axis labels assume
    the values are Hounsfield units (not checked here).

    Args:
        exame: array-like exposing ``.flatten()`` (e.g. a NumPy volume).
    """
    valores = exame.flatten()
    plt.hist(valores, color='c')
    plt.xlabel("Hounsfield Units (HU)")
    plt.ylabel("Frequency")
    plt.show()

In [None]:
# Training exam: DICOM series folder to convert, and where the NIfTI file goes.
folder_dcm = r"E:\PycharmProjects\pythonProject\exame\CQ500CT257\Unknown Study\CT 0.625mm"
nifti_file = r"E:\PycharmProjects\pythonProject\exame\CQ500CT257.nii.gz"

In [None]:
# Convert the training DICOM series with the project's `converter` helper
# (imported from dicom_to_nifti). The returned array is the volume `vol` used below.
# NOTE(review): given the `nifti_file` argument, the helper presumably also
# writes a NIfTI file to that path — confirm in dicom_to_nifti.
vol = converter(folder_dcm, nifti_file)

In [None]:
# Save the converted training volume under build/, tagged with minLabels.
fileName = 'exame_linha_210_' + str(args.minLabels) + '.nii.gz'
# Identity affine: voxel indices are stored without spacing/orientation.
# vol.T reverses the axis order — presumably (slice, row, col) -> (x, y, z)
# for NIfTI; TODO confirm against the converter's output layout.
img = nb.Nifti1Image(vol.T, np.eye(4))
# nb.save does not create missing directories.
os.makedirs('build', exist_ok=True)
nb.save(img, os.path.join('build', fileName))

In [None]:
# View the volume as a stack of 512x512 slices. -1 lets the slice count follow
# the data instead of a fixed 256 (identical result for a 256-slice exam).
# NOTE(review): reshape only reinterprets the buffer, so this assumes `vol` is
# already slice-first — confirm against dicom_to_nifti.converter.
exame1 = vol.reshape(-1, 512, 512)

In [None]:
# Intensity distribution of the whole training volume.
plotarHistograma(exame=exame1)

In [None]:
# CNN model
class MyNet(nn.Module):
    """Fully convolutional network producing one response map per channel.

    ``conv1`` (3x3, ``input_dim`` -> ``n_channel``) is followed by
    ``n_conv - 1`` further 3x3 convolutions, each followed by ReLU then
    BatchNorm, and a final 1x1 convolution + BatchNorm (no ReLU). All
    convolutions use stride 1 with size-preserving padding, so an input of
    shape ``(B, input_dim, H, W)`` yields ``(B, n_channel, H, W)``.

    Args:
        input_dim: number of channels of the input tensor.
        n_channel: channels of every layer; defaults to ``args.nChannel``.
        n_conv: number of 3x3 convolutions (``conv1`` plus ``n_conv - 1``
            in ``conv2``); defaults to ``args.nConv``.
    """

    def __init__(self, input_dim, n_channel=None, n_conv=None):
        super(MyNet, self).__init__()
        # Fall back to the command-line settings only when not given, so the
        # network can also be built without the global ``args``.
        if n_channel is None:
            n_channel = args.nChannel
        if n_conv is None:
            n_conv = args.nConv
        self.conv1 = nn.Conv2d(input_dim, n_channel, kernel_size=3, stride=1, padding=1)
        self.bn1 = nn.BatchNorm2d(n_channel)
        self.conv2 = nn.ModuleList()
        self.bn2 = nn.ModuleList()
        for _ in range(n_conv - 1):
            self.conv2.append(nn.Conv2d(n_channel, n_channel, kernel_size=3, stride=1, padding=1))
            self.bn2.append(nn.BatchNorm2d(n_channel))
        self.conv3 = nn.Conv2d(n_channel, n_channel, kernel_size=1, stride=1, padding=0)
        self.bn3 = nn.BatchNorm2d(n_channel)

    def forward(self, x):
        """Return the ``(B, n_channel, H, W)`` response for input ``x``."""
        x = self.bn1(F.relu(self.conv1(x)))
        # Walk the registered layers rather than re-reading the global
        # args.nConv, which could disagree with what was actually built.
        for conv, bn in zip(self.conv2, self.bn2):
            x = bn(F.relu(conv(x)))
        return self.bn3(self.conv3(x))

In [None]:
# Whole volume as one float32 tensor with a leading axis of size 1,
# intensities scaled by 1/255.
# NOTE(review): this tensor is never used for training — the training loop
# rebuilds `data` slice by slice WITHOUT the /255 scaling. Confirm which
# normalisation is intended.
data = torch.from_numpy(np.expand_dims(exame1.astype('float32') / 255., axis=0))

In [None]:
# Put the full-volume tensor on the GPU when available and wrap it in a
# Variable (a legacy wrapper that just returns the tensor in PyTorch >= 0.4).
data = Variable(data.cuda() if use_cuda else data)

In [None]:
# Network fed one single-channel slice at a time.
model = MyNet(input_dim=1)
print(model)
if use_cuda:
    model = model.cuda()  # Module.cuda() moves parameters in place and returns the module
model.train()

In [None]:
# Similarity loss: cross-entropy between the per-pixel channel responses and
# their own arg-max labels (the training loop builds those targets).
loss_fn = torch.nn.CrossEntropyLoss()
# Continuity losses: mean absolute difference between neighbouring pixels'
# responses along rows (y) and columns (z), compared against zero targets.
# reduction='mean' replaces the deprecated size_average=True (same mean).
loss_hpy = torch.nn.L1Loss(reduction='mean')
loss_hpz = torch.nn.L1Loss(reduction='mean')

In [None]:
# Zero targets for the continuity (L1) losses: neighbouring pixels are pushed
# towards identical responses. Shapes match HPy / HPz in the training loop
# (one fewer row, resp. column, than a slice), taken from exame1 instead of
# a hard-coded 512.
HPy_target = torch.zeros(exame1.shape[1] - 1, exame1.shape[2], args.nChannel)
HPz_target = torch.zeros(exame1.shape[1], exame1.shape[2] - 1, args.nChannel)
if use_cuda:
    HPy_target = HPy_target.cuda()
    HPz_target = HPz_target.cuda()

In [None]:
# Adam over all network weights (SGD with momentum=0.9 is a possible alternative).
optimizer = optim.Adam(params=model.parameters(), lr=args.lr)

In [None]:
# One random RGB colour (components 0..254) per label/channel index. Sized by
# args.nChannel so label_colours[c % args.nChannel] never indexes past the end
# (the hard-coded 100 rows raised IndexError for --nChannel > 100).
label_colours = np.random.randint(255, size=(args.nChannel, 3))

In [None]:
# Early-stop flag ("parou" = "stopped"): the training loop sets it once the
# label count reaches args.minLabels, which also ends the outer loop.
parou = False

In [None]:
# Training: up to args.maxIter passes over the slices of exame1, one optimiser
# step per slice. Stops early as soon as a slice is segmented into
# args.minLabels labels or fewer.
n_slices, n_rows, n_cols = exame1.shape
for batch_idx in range(args.maxIter):
    if parou:
        break
    for slice_idx in range(n_slices):  # not `slice`: avoid shadowing the builtin
        data1 = exame1[slice_idx, :, :]
        # NOTE(review): slices are fed without intensity normalisation (the
        # unused full-volume tensor was divided by 255). Confirm intended.
        data = torch.from_numpy(data1.reshape(1, 1, n_rows, n_cols).astype('float32'))
        if use_cuda:
            data = data.cuda()
        data = Variable(data)

        # forwarding
        optimizer.zero_grad()
        output1 = model(data)[0]  # drop the batch axis -> (nChannel, H, W)
        # One row of channel responses per pixel: (H*W, nChannel).
        output = output1.permute(1, 2, 0).contiguous().view(-1, args.nChannel)

        # Continuity loss: differences between row-adjacent (HPy) and
        # column-adjacent (HPz) pixel responses, compared against zero.
        outputHP = output.reshape((data.shape[2], data.shape[3], args.nChannel))
        HPy = outputHP[1:, :, :] - outputHP[0:-1, :, :]
        HPz = outputHP[:, 1:, :] - outputHP[:, 0:-1, :]
        lhpy = loss_hpy(HPy, HPy_target)
        lhpz = loss_hpz(HPz, HPz_target)

        # Pseudo-labels: the arg-max channel of every pixel.
        _scores, target = torch.max(output, 1)
        im_target = target.data.cpu().numpy()
        nLabels = len(np.unique(im_target))

        if args.visualize:
            # Vectorised colour lookup (same result as a per-pixel comprehension).
            im_target_rgb = label_colours[im_target % args.nChannel]
            im_target_rgb = im_target_rgb.reshape(n_rows, n_cols, 3).astype(np.uint8)
            im_target_rgb = cv2.resize(im_target_rgb, (600, 600))
            data2 = cv2.resize(data1, (600, 600))
            cv2.imshow("output", im_target_rgb)
            cv2.imshow("original", data2)
            cv2.waitKey(10)

        loss = args.stepsize_sim * loss_fn(output, target) + args.stepsize_con * (lhpy + lhpz)
        loss.backward()
        optimizer.step()

        print(batch_idx, '/', args.maxIter, '|', ' label num :', nLabels, ' | loss :', loss.item())

        if nLabels <= args.minLabels:
            print("nLabels", nLabels, "reached minLabels", args.minLabels, ".")
            parou = True
            break

In [None]:
# Test exam: DICOM series folder to convert, and where the NIfTI file goes.
folder_dcm = r"E:\PycharmProjects\pythonProject\exame\CQ500CT420\Unknown Study\CT 0.625mm"
nifti_file = r"E:\PycharmProjects\pythonProject\exame\CQ500CT420.nii.gz"

In [None]:
# Convert the test DICOM series with the project's `converter` helper
# (imported from dicom_to_nifti); the returned array is the test volume.
# NOTE(review): presumably also writes a NIfTI file to `nifti_file` — confirm.
exame_teste = converter(folder_dcm, nifti_file)

In [None]:
# View the test volume as a stack of 512x512 slices; -1 lets the slice count
# follow the data instead of a fixed 256.
# NOTE(review): assumes the converter returns slice-first data — confirm.
exame1_teste = exame_teste.reshape(-1, 512, 512)
# Output buffer for the per-slice label maps (one uint8 label per voxel),
# pre-filled with ones; the inference loop writes each slice's labels here.
nifti_teste = np.ones(exame1_teste.shape, dtype=np.uint8)

In [None]:
# Inference on the test exam: segment every slice with the trained network,
# display it next to the original, and store its label map in nifti_teste.
# NOTE(review): the network is still in training mode (nothing calls
# model.eval()), so BatchNorm normalises each slice with its own statistics
# and keeps updating its running averages. Kept as-is because model.eval()
# would change the predictions — confirm which behaviour is intended.
n_rows_teste, n_cols_teste = exame1_teste.shape[1], exame1_teste.shape[2]
with torch.no_grad():  # inference only: don't build the autograd graph
    for slice_idx in range(exame1_teste.shape[0]):  # not `slice`: builtin
        data1 = exame1_teste[slice_idx, :, :]
        data_teste = torch.from_numpy(data1.reshape(1, 1, n_rows_teste, n_cols_teste).astype('float32'))
        if use_cuda:
            data_teste = data_teste.cuda()
        data_teste = Variable(data_teste)
        output_teste = model(data_teste)[0]
        output = output_teste.permute(1, 2, 0).contiguous().view(-1, args.nChannel)
        _scores, target = torch.max(output, 1)
        im_target = target.data.cpu().numpy()

        # Store the label map; uint8 would wrap labels only if nChannel > 256.
        nifti_teste[slice_idx, :, :] = im_target.reshape(n_rows_teste, n_cols_teste).astype(np.uint8)

        # Colour each pixel by its label (vectorised lookup) and display.
        im_target_rgb = label_colours[im_target % args.nChannel]
        im_target_rgb = im_target_rgb.reshape(n_rows_teste, n_cols_teste, 3).astype(np.uint8)
        im_target_rgb = cv2.resize(im_target_rgb, (600, 600))
        data2 = cv2.resize(data1, (600, 600))
        cv2.imshow("output", im_target_rgb)
        cv2.imshow("original", data2)
        cv2.waitKey(10)

In [None]:
# Save the test-exam segmentation under build/, tagged with minLabels.
fileName = 'teste_segmentation_' + str(args.minLabels) + '.nii.gz'
# Identity affine: voxel indices are stored without spacing/orientation.
img = nb.Nifti1Image(nifti_teste.T, np.eye(4))
# to_filename does not create missing directories.
os.makedirs('build', exist_ok=True)
img.to_filename(os.path.join('build', fileName))