In [None]:
import pydicom
import pandas as pd
from vertebra_4_point import *
from vertebra_centrioid_model import *

# Heatmap resolution (rows, cols) that the keypoint decoders below reshape/normalise by.
heat_img_height, heat_img_width = 384, 224
# Network input resolution (rows, cols); matches input_shape=(768, 448, 3) of both models.
input_height, input_width = 768, 448

In [None]:
# 4-point (vertebra corner) model: 69 heatmap channels, weights from test_with_seg.h5.
pose_model_4 = StackedHourglassNetwork_4(
    input_shape=(input_height, input_width, 3),
    num_stack=4,
    num_residual=1,
    num_heatmap=69,
    num_seg=17,
)
pose_model_4.load_weights('./test_with_seg.h5')

In [None]:
# Centroid model: one heatmap channel per vertebra (17), weights from test_with_with_centrioid.h5.
pose_model = StackedHourglassNetwork(
    input_shape=(input_height, input_width, 3),
    num_stack=4,
    num_residual=1,
    num_heatmap=17,
    num_seg=17,
)
pose_model.load_weights('./test_with_with_centrioid.h5')

In [None]:

# for the centroid model
def find_max_coordinates(heatmaps):
    """Locate the arg-max pixel of every heatmap channel.

    Args:
        heatmaps: array-like (NumPy array or eager tensor) whose last two
            axes before the channel axis are (H, W), e.g. (H, W, C). In this
            notebook it is the squeezed centroid-model output with C = 17.

    Returns:
        np.ndarray of shape (C, 2) with integer ``[x, y]`` (column, row) of
        the maximum of each channel.
    """
    heatmaps = np.asarray(heatmaps)
    channels = heatmaps.shape[-1]
    width = heatmaps.shape[-2]
    # Flatten the spatial axes so each column is one channel. In row-major
    # order every `width` consecutive values form one heatmap row.
    flat_indices = np.argmax(heatmaps.reshape(-1, channels), axis=0)
    # Integer divmod gives (row, column) without a float round-trip.
    y, x = np.divmod(flat_indices, width)
    return np.stack([x, y], axis=1)

# for the centroid model
def extract_keypoints_from_heatmap(heatmaps):
    """Turn centroid-model heatmaps into normalised sub-pixel keypoints.

    For every channel, the arg-max pixel from ``find_max_coordinates`` is
    shifted a quarter pixel towards its strongest 8-neighbour. The result is
    clipped to the heatmap extent and divided by ``heat_img_width`` /
    ``heat_img_height``.

    Args:
        heatmaps: NumPy array of shape (H, W, C); presumably
            (heat_img_height, heat_img_width, 17) here — TODO confirm.

    Returns:
        list of C ``(x, y)`` tuples with both coordinates in [0, 1].
    """
    max_keypoints = find_max_coordinates(heatmaps)
    # Zero-pad so the 3x3 window around a peak on the border stays in bounds.
    padded_heatmap = np.pad(heatmaps, [[1, 1], [1, 1], [0, 0]])
    normalized_keypoints = []
    for channel, keypoint in enumerate(max_keypoints):
        peak_x, peak_y = keypoint[0], keypoint[1]
        # Padding shifts indices by +1, so padded rows/cols peak..peak+2 are
        # the 3x3 neighbourhood centred on the peak.
        patch = padded_heatmap[peak_y:peak_y + 3, peak_x:peak_x + 3, channel].astype(np.float64)
        patch[1, 1] = -np.inf  # exclude the peak itself (astype made a copy)
        neighbour = np.argmax(patch)
        delta_x = delta_y = 0.0
        # Only move towards a neighbour that carries signal. In a flat window,
        # argmax would return index 0 and bias the point up-left.
        if patch.flat[neighbour] > 0:
            next_y, next_x = divmod(int(neighbour), 3)
            delta_y = (next_y - 1) / 4
            delta_x = (next_x - 1) / 4
        adjusted_x = peak_x + delta_x
        adjusted_y = peak_y + delta_y
        # Clip in case the offset pushed a point past the border.
        norm_x = np.clip(adjusted_x, 0, heat_img_width) / heat_img_width
        norm_y = np.clip(adjusted_y, 0, heat_img_height) / heat_img_height
        normalized_keypoints.append((norm_x, norm_y))
    return normalized_keypoints

In [None]:
# for 4 point model
def find_max_coordinates_4(heatmaps):
    """Locate the arg-max pixel of every 4-point-model heatmap channel.

    Args:
        heatmaps: array-like (NumPy array or eager tensor) whose last two
            axes before the channel axis are (H, W), e.g. (H, W, C). In this
            notebook it is the squeezed 4-point-model output with C = 69.

    Returns:
        np.ndarray of shape (C, 2) with integer ``[x, y]`` (column, row) of
        the maximum of each channel.
    """
    heatmaps = np.asarray(heatmaps)
    channels = heatmaps.shape[-1]
    width = heatmaps.shape[-2]
    # Flatten the spatial axes so each column is one channel. In row-major
    # order every `width` consecutive values form one heatmap row.
    flat_indices = np.argmax(heatmaps.reshape(-1, channels), axis=0)
    # Integer divmod gives (row, column) without a float round-trip.
    y, x = np.divmod(flat_indices, width)
    return np.stack([x, y], axis=1)

# for 4 point model
def extract_keypoints_from_heatmap_4(heatmaps):
    """Turn 4-point-model heatmaps into normalised sub-pixel keypoints.

    For every channel, the arg-max pixel from ``find_max_coordinates_4`` is
    shifted a quarter pixel towards its strongest 8-neighbour. The result is
    clipped to the heatmap extent and divided by ``heat_img_width`` /
    ``heat_img_height``.

    Args:
        heatmaps: NumPy array of shape (H, W, C); presumably
            (heat_img_height, heat_img_width, 69) here — TODO confirm.

    Returns:
        list of C ``(x, y)`` tuples with both coordinates in [0, 1].
    """
    max_keypoints = find_max_coordinates_4(heatmaps)
    # Zero-pad so the 3x3 window around a peak on the border stays in bounds.
    padded_heatmap = np.pad(heatmaps, [[1, 1], [1, 1], [0, 0]])
    normalized_keypoints = []
    for channel, keypoint in enumerate(max_keypoints):
        peak_x, peak_y = keypoint[0], keypoint[1]
        # Padding shifts indices by +1, so padded rows/cols peak..peak+2 are
        # the 3x3 neighbourhood centred on the peak.
        patch = padded_heatmap[peak_y:peak_y + 3, peak_x:peak_x + 3, channel].astype(np.float64)
        patch[1, 1] = -np.inf  # exclude the peak itself (astype made a copy)
        neighbour = np.argmax(patch)
        delta_x = delta_y = 0.0
        # Only move towards a neighbour that carries signal. In a flat window,
        # argmax would return index 0 and bias the point up-left.
        if patch.flat[neighbour] > 0:
            next_y, next_x = divmod(int(neighbour), 3)
            delta_y = (next_y - 1) / 4
            delta_x = (next_x - 1) / 4
        adjusted_x = peak_x + delta_x
        adjusted_y = peak_y + delta_y
        # Clip in case the offset pushed a point past the border.
        norm_x = np.clip(adjusted_x, 0, heat_img_width) / heat_img_width
        norm_y = np.clip(adjusted_y, 0, heat_img_height) / heat_img_height
        normalized_keypoints.append((norm_x, norm_y))
    return normalized_keypoints



In [None]:
def draw_keypoints_on_image(image, keypoints, true_keypoints, dcm_ID, index=None):
    """Show predicted (red) and ground-truth (yellow) points over ``image``.

    Args:
        image: array displayed in grayscale.
        keypoints: predicted points as fractions of the image size; x is
            scaled by ``image.shape[1]`` and y by ``image.shape[0]``.
        true_keypoints: ground-truth points already in pixel coordinates.
        dcm_ID: unused while saving is disabled; kept for call compatibility.
        index: if given, draw only the point at this position of each list.
    """
    plt.figure(figsize=(12, 7))
    plt.imshow(image, cmap='gray')

    def selected(position):
        return index is None or index == position

    for position, point in enumerate(keypoints):
        px = point[0] * image.shape[1]
        py = point[1] * image.shape[0]
        if selected(position):
            plt.scatter(px, py, s=1, c='red', marker='o')

    for position, point in enumerate(true_keypoints):
        px, py = point[0], point[1]
        if selected(position):
            plt.scatter(px, py, s=1, c='yellow', marker='o')

    # Saving to '<dir>/<dcm_ID>.jpg' is intentionally disabled here.
    plt.show()

def draw_true_on_image(image, true_keypoints, dcm_ID, index=None,
                       save_dir='/Users/linweichen/Desktop/NTU_ortho/Spine/spine_examination'):
    """Show ground-truth points (yellow) over ``image`` and optionally save it.

    Args:
        image: array displayed in grayscale.
        true_keypoints: sequence of points in pixel coordinates (x, y).
        dcm_ID: file stem; the figure is saved as ``<save_dir>/<dcm_ID>.jpg``.
        index: if given, draw only the point at this position.
        save_dir: output directory, or None to skip saving. The default is
            the original author's local path and will not exist on other
            machines, so pass your own directory (or None).
    """
    plt.figure(figsize=(12, 7))
    plt.imshow(image, cmap='gray')

    for i, joint in enumerate(true_keypoints):
        if index is not None and index != i:
            continue
        plt.scatter(joint[0], joint[1], s=1, c='yellow', marker='o')

    if save_dir is not None:
        plt.savefig(os.path.join(save_dir, dcm_ID + '.jpg'))
    plt.show()

In [None]:
def calculate_angle(line1, line2):
    """Return the angle in degrees (0-180) between two line segments.

    Each line is a pair of points ``((x0, y0), (x1, y1))``. Its direction
    vector runs from the first point to the second.

    A zero-length segment has no direction, so the result is undefined:
    ZeroDivisionError for Python floats, nan for NumPy scalars.
    """
    vector1 = (line1[1][0] - line1[0][0], line1[1][1] - line1[0][1])
    vector2 = (line2[1][0] - line2[0][0], line2[1][1] - line2[0][1])

    dot_product = vector1[0] * vector2[0] + vector1[1] * vector2[1]
    magnitude1 = math.hypot(vector1[0], vector1[1])
    magnitude2 = math.hypot(vector2[0], vector2[1])

    cos_angle = dot_product / (magnitude1 * magnitude2)
    # Rounding can push the ratio just outside [-1, 1] for (anti-)parallel
    # lines, where math.acos would raise ValueError. Clamp it, but leave
    # nan untouched (the comparisons are False for nan).
    if cos_angle > 1.0:
        cos_angle = 1.0
    elif cos_angle < -1.0:
        cos_angle = -1.0
    return math.degrees(math.acos(cos_angle))

In [None]:
import numpy as np
import matplotlib.pyplot as plt
from scipy.integrate import quad
import math

def polynomial_curve_fit(x, y, degree):
    """Least-squares fit of a ``degree`` polynomial to (x, y), returned as np.poly1d."""
    return np.poly1d(np.polyfit(x, y, degree))

def calculate_area_difference(x, y, poly_func):
    """Sum over consecutive samples of | |trapezoid area| - |integral of poly_func| |.

    For each interval [x[i], x[i+1]], the absolute trapezoid area under the
    straight segment is compared with the absolute integral of ``poly_func``
    (numerical, via scipy ``quad``). The absolute differences are summed.
    Returns 0 when there are fewer than two samples.
    """
    total_difference = 0
    for left in range(len(x) - 1):
        right = left + 1
        chord_area = abs(0.5 * (y[left] + y[right]) * (x[right] - x[left]))
        curve_area = abs(quad(poly_func, x[left], x[right])[0])
        total_difference += abs(chord_area - curve_area)
    return total_difference

def find_best_polynomial_degree(x, y, max_degree=7):
    """Pick the polynomial degree whose fit best matches the piecewise-linear area.

    Every degree 0..max_degree is fitted with ``polynomial_curve_fit`` and
    scored with ``calculate_area_difference``. The lowest score wins; ties
    go to the lower degree.

    Returns:
        (best_degree, r_squared_values): best_degree is None when no degree
        is tried (max_degree < 0). r_squared_values is a list of
        (degree, R^2) pairs; R^2 is nan when y is constant (undefined).
    """
    # The total sum of squares does not depend on the degree.
    y_mean = np.mean(y)
    ss_total = sum((yi - y_mean) ** 2 for yi in y)

    best_degree = None
    min_area_difference = float('inf')
    r_squared_values = []

    for degree in range(max_degree + 1):
        poly_func = polynomial_curve_fit(x, y, degree)
        area_difference = calculate_area_difference(x, y, poly_func)

        ss_residual = sum((yi - poly_func(xi)) ** 2 for xi, yi in zip(x, y))
        # R^2 is undefined for a constant target (ss_total == 0).
        r_squared = float('nan') if ss_total == 0 else 1 - ss_residual / ss_total
        r_squared_values.append((degree, r_squared))

        if area_difference < min_area_difference:
            min_area_difference = area_difference
            best_degree = degree

    return best_degree, r_squared_values

def find_apex_and_limits(points, poly_func):
    """Locate apexes of a fitted curve and the vertebrae bounding each one.

    Apex candidates are the real roots of the first derivative. Each is
    mapped to the vertebra whose first coordinate is nearest; apexes landing
    on 't1' or 'l5' are dropped. For every apex, the nearest real inflection
    point (root of the second derivative) on each side gives the upper/lower
    limit vertebra. Without one, the limit defaults to 't1' / 'l5'.

    Args:
        points: dict of vertebra name -> sequence whose first element is the
            coordinate on the polynomial's independent axis.
        poly_func: np.poly1d fitted against those coordinates.

    Returns:
        dict apex_name -> {"apex_x", "upper_limit", "lower_limit"}. Apexes
        whose limit coincides with the apex itself are omitted.
    """
    def closest_key(value):
        # Vertebra whose first coordinate is nearest to `value`.
        return min(points.keys(), key=lambda k: abs(points[k][0] - value))

    first_derivative = poly_func.deriv(m=1)
    second_derivative = poly_func.deriv(m=2)

    # Real roots of f' are the apex candidates.
    critical_points = np.roots(first_derivative)
    critical_points = critical_points[np.isreal(critical_points)].real

    apexes = {}
    for cp in critical_points:
        key = closest_key(cp)
        if key not in ['t1', 'l5']:  # an apex at either end is not a curve apex
            apexes[key] = cp

    # Real roots of f'' are the inflection points.
    inflection_points = np.roots(second_derivative)
    inflection_points = inflection_points[np.isreal(inflection_points)].real

    results = {}
    for apex_key, apex_x in apexes.items():
        left_inflection = max([ip for ip in inflection_points if ip < apex_x], default=None)
        right_inflection = min([ip for ip in inflection_points if ip > apex_x], default=None)

        # Test against None explicitly: an inflection at exactly 0.0 is valid.
        upper_limit = closest_key(left_inflection) if left_inflection is not None else 't1'
        lower_limit = closest_key(right_inflection) if right_inflection is not None else 'l5'

        # Skip curves whose limit collapses onto the apex vertebra.
        if upper_limit == apex_key or lower_limit == apex_key:
            continue

        results[apex_key] = {
            "apex_x": apex_x,
            "upper_limit": upper_limit,
            "lower_limit": lower_limit
        }

    return results


def print_apex_and_limits(results):
    """Print apex, upper/lower end vertebra and, if present, the Cobb angle per curve.

    Args:
        results: dict shaped like find_apex_and_limits' output, optionally
            with an 'angle' value per apex (degrees or None). The key may be
            missing entirely, e.g. before angles have been computed.
    """
    for apex_key, data in results.items():
        # Skip output if apex is the same as upper or lower limit
        if data['upper_limit'] == apex_key or data['lower_limit'] == apex_key:
            continue

        print(f"Apex: {apex_key}")
        print(f"Upper Endplate: {data['upper_limit']}")
        print(f"Lower Endplate: {data['lower_limit']}")

        # find_apex_and_limits does not set 'angle'; tolerate its absence.
        angle = data.get('angle')
        if angle is not None:
            print(f"Cobb angle: {angle} degrees")
        print('')
        



    


def comparing_true_and_pred_ntuh(dcm_ID):
    """Run both keypoint models on one radiograph and report curves and Cobb angles.

    Reads ``<dcm_ID>.dcm`` and ``<dcm_ID>_new_with_center.csv`` (relative to
    the current working directory unless dcm_ID is an absolute stem). The
    CSV's first row supplies 17 ground-truth centroids (column pairs 2k/2k+1
    for k = 69..85), one per vertebra t1..l5.

    Steps:
      1. run the centroid model (pose_model) on the preprocessed image;
      2. fit a polynomial through the ground-truth centroids and locate
         apexes / limit vertebrae with find_apex_and_limits;
      3. plot the ground truth and the centroid predictions;
      4. run the 4-point model (pose_model_4). For every apex, measure the
         angle between the upper limit vertebra's first two corner points and
         the lower limit vertebra's last two.

    Returns:
        The dict from find_apex_and_limits, each entry extended with
        'angle' (degrees, or None).
    """
    vertebra_names = ['t1', 't2', 't3', 't4', 't5', 't6', 't7', 't8', 't9',
                      't10', 't11', 't12', 'l1', 'l2', 'l3', 'l4', 'l5']

    # Read the DICOM once; pixel spacing is not needed, and dcm.get(...) on a
    # missing (0028,0030) tag returns None, which crashed on `.value`.
    dcm = pydicom.dcmread(dcm_ID + '.dcm', force=True)
    dcm_img = dcm.pixel_array
    height, width = dcm_img.shape

    df = pd.read_csv(dcm_ID + '_new_with_center.csv', sep=',')
    columns = list(df.columns)
    # Column pairs 2k / 2k+1 (k = 69..85) hold the 17 centroid (x, y) pairs.
    true_lst = [(df[columns[2 * k]][0], df[columns[2 * k + 1]][0]) for k in range(69, 86)]

    # Min-max scale to 0..255 and invert. Work in float64 so max - min cannot
    # overflow an integer pixel dtype, and avoid 0/0 for a constant image.
    pixels = dcm_img.astype(np.float64)
    pixel_min = pixels.min()
    pixel_range = pixels.max() - pixel_min
    if pixel_range == 0:
        pixel_range = 1.0
    new_img = 255 - ((pixels - pixel_min) / pixel_range * 255).astype(np.uint8)
    raw_img = cv2.resize(new_img, (input_width, input_height), interpolation=cv2.INTER_CUBIC)
    color_img = cv2.cvtColor(raw_img, cv2.COLOR_GRAY2BGR) / 255
    inputs_neg = tf.expand_dims(color_img, 0)

    # Centroid model: keypoints are only plotted below.
    outputs = pose_model([inputs_neg], training=False)
    heatmap = tf.squeeze(outputs[3], axis=0).numpy()
    kp = extract_keypoints_from_heatmap(heatmap)

    # Ground-truth centroids stored as [row, column]; the curve is fitted
    # with the image row as the independent variable.
    points = {name: [true_lst[i][1], true_lst[i][0]] for i, name in enumerate(vertebra_names)}
    x = np.array([v[0] for v in points.values()])
    y = np.array([v[1] for v in points.values()])

    plt.figure(figsize=(12, 7))
    plt.plot(x, y)
    plt.xlim(0, height)
    plt.ylim(0, width)

    best_degree, _ = find_best_polynomial_degree(x, y, max_degree=7)
    # NOTE(review): the header says "Predicted", but this curve is fitted to
    # the ground-truth centroids (true_lst), not to the model's kp — confirm.
    print('Predicted Result')
    print(f"Best polynomial degree: {best_degree}")
    poly_func = polynomial_curve_fit(x, y, best_degree)
    results = find_apex_and_limits(points, poly_func)

    draw_true_on_image(dcm_img, true_lst, dcm_ID, index=None)
    draw_keypoints_on_image(dcm_img, kp, true_lst, dcm_ID, index=None)

    # 4-point model: corner points used for the angle measurement.
    outputs_4 = pose_model_4([inputs_neg], training=False)
    heatmap_4 = tf.squeeze(outputs_4[3], axis=0).numpy()
    kp_4 = extract_keypoints_from_heatmap_4(heatmap_4)

    # Normalised keypoints -> original-image pixel coordinates.
    joints = [(joint[0] * width, joint[1] * height) for joint in kp_4]
    # Consecutive groups of four keypoints belong to one vertebra, t1 first.
    new_xy = {name: [joints[4 * i + j] for j in range(4)]
              for i, name in enumerate(vertebra_names)}

    for data in results.values():
        # Upper limit vertebra: its first two corners; lower: its last two.
        upper_limit_coords = new_xy[data["upper_limit"]][:2]
        lower_limit_coords = new_xy[data["lower_limit"]][2:]
        angle = None
        if upper_limit_coords and lower_limit_coords:
            angle = calculate_angle(upper_limit_coords, lower_limit_coords)
        data['angle'] = angle

    print_apex_and_limits(results)

    print('....')
    return results





In [None]:
# demo
# Set the directory that holds the study files and the file stem.
# comparing_true_and_pred_ntuh reads '<ID>.dcm' and '<ID>_new_with_center.csv'
# relative to the current working directory, so both placeholders below must
# be replaced before running this cell.
os.chdir('folder to your dcm file')
ID = ""
results = comparing_true_and_pred_ntuh(ID)