In [1]:
import sys 
sys.path.append('..')

import os
import torch
import torch.nn.functional as F
import argparse

from torch.optim.lr_scheduler import StepLR
from torch.utils.tensorboard import SummaryWriter
from torchvision import transforms

import scipy.sparse as sp
import numpy as np

from sklearn.metrics import mean_squared_error
from utils.utils import scipy_to_torch_sparse, genMatrixes, genMatrixesLH, genMatrixesOne
from utils.dataLoader import LandmarksDataset, ToTensor, ToTensorLH, Rescale, RandomScale, AugColor, Rotate

from models.hybridDoubleSkip import Hybrid as DoubleSkip
from models.hybridSkip import Hybrid as Skip
from models.hybrid import Hybrid
from models.hybridNoPool import Hybrid as HybridNoPool

from models.chebConv import Pool
from skimage.metrics import hausdorff_distance as hd



In [3]:
img_path = os.path.join("/home/hpcpin1/Seg_Project/Datasets/Cephalometric/Test", 'Images')
label_path = os.path.join("/home/hpcpin1/Seg_Project/Datasets/Cephalometric/Test", 'landmarks')
test_dataset = LandmarksDataset(img_path=img_path,
                                 label_path=label_path,
                                 transform = transforms.Compose([
                                             Rescale(1024),
                                             ToTensor()])
                                 )

device = 'cpu'

In [5]:
A, AD, D, U = genMatrixesOne(19)

A = sp.csc_matrix(A).tocoo()
AD = sp.csc_matrix(AD).tocoo()
D = sp.csc_matrix(D).tocoo()
U = sp.csc_matrix(U).tocoo()

D_ = [D.copy()]
U_ = [U.copy()]
A_ = [A.copy(), A.copy(), A.copy(), AD.copy(), AD.copy(), AD.copy()]

N1 = A.shape[0]
N2 = AD.shape[0]

config = {}
config['n_nodes'] = config['n_nodes'] = [N1, N1, N1, N2, N2, N2]

A_t, D_t, U_t = ([scipy_to_torch_sparse(x).to(device) for x in X] for X in (A_, D_, U_))

config['latents'] = 64
config['inputsize'] = 1024

f = 32
config['filters'] = [2, f, f, f, f//2, f//2, f//2]
config['skip_features'] = f

config['window'] = (3,3)

config['K'] = 6
config['l1'] = 6
config['l2'] = 5

double = DoubleSkip(config.copy(), D_t, U_t, A_t).to(device)
double.load_state_dict(torch.load("/home/hpcpin1/Seg_Project/Training/cephalo_28071330/bestMSE.pt", map_location=device))
double.eval()
print('Model loaded')

6-5
Model loaded


In [7]:
from utils.fun import drawOrgans

model_list = [double]
model_names = ['2-SC Layers 6-5']

i_ =[10, 13, 44]

fig = plt.figure(figsize=(3 * (len(model_list)+2), 9), dpi= 100)

c = 0

for i in i_:
    with torch.no_grad():
        sample = test_dataset[i]

        data, target = sample['seg'], sample['landmarks']
        data = torch.unsqueeze(data, 0).to(device)
        target =  target.reshape(-1).numpy()

        draw = data.cpu().numpy()[0,0,:,:]

        ax = plt.subplot(3, len(model_list) + 2, 1 + c * (len(model_list) + 2))
        plt.axis('off')
        plt.xlim(1, 1024)
        plt.ylim(1024, 1)
        
        target = np.clip(target, 0, 1)
        
        drawOrgans(ax, target[:38] * 1024, None, draw.copy())
        if c == 0:
            plt.title("Ground Truth", fontsize = 20)
        
        ax = plt.subplot(3, len(model_list) + 2, 2 + c * (len(model_list) + 2))
        
        data_MA = test_dataset.images[i].replace('Datasets/JSRT/Test/landmarks', "Results/MultiAtlas/JSRT/labels/output_points").replace(".png", ".npy")
        data_MA = np.load(data_MA)[:38]
        
        plt.axis('off')
        plt.xlim(1, 1024)
        plt.ylim(1024, 1)
        drawOrgans(ax, data_MA, None, draw.copy())
        if c == 0:
            plt.title("MultiAtlas", fontsize = 20)
        
        for j in range(0, len(model_list)):
            output = model_list[j](data)
            if len(output) > 1:
                output = output[0]
            output = output.cpu().numpy().reshape(-1) 
            output = np.clip(output, 0, 1)[:38]
            ax = plt.subplot(3, len(model_list) + 2, j + 3 + c * (len(model_list) + 2))
            plt.axis('off')
            drawOrgans(ax, output * 1024, None, draw.copy())
            if c == 0:
                plt.title(model_names[j], fontsize = 20)
            
            plt.xlim(1, 1024)
            plt.ylim(1024, 1)
        
        c += 1
        
fig.subplots_adjust(left=0, bottom=0, right=1, top=1, wspace=0.05, hspace=0)      
plt.savefig('figs/s2g.png', bbox_inches = 'tight', dpi = 200)      
plt.savefig('figs/s2g.pdf', bbox_inches = 'tight', dpi = 200)

NameError: name 'plt' is not defined