In [None]:
import os
import cv2
import numpy as np
import matplotlib.pyplot as plt

import config
from utils.load_utils import load_image

In [None]:
# Dataset root and the two split sub-directories below it.
DATAPATH = config.SCHULTHESS_DATAPATH
test_datapath, train_datapath = (
    os.path.join(DATAPATH, split) for split in ('test', 'train')
)

In [None]:
# Every PNG anywhere under the test split.
# os.walk yields entries in arbitrary, filesystem-dependent order. Later cells
# index into this list (e.g. `filenames[2]`), so sort it to make that index pick
# the same file on every machine and every run.
filenames = sorted(
    os.path.join(root, name)
    for root, _dirs, files in os.walk(test_datapath)
    for name in files
    if name.endswith('.png')
)
   

In [None]:
# Sub-folders of the test split, sorted because os.listdir order is arbitrary.
sorted(os.listdir(test_datapath))

In [None]:
# Only a cell's last expression is displayed, so a bare `filenames[0]` on its
# own line was silently discarded. Show the path together with the name of its
# parent folder (the test sub-folder it lives in).
filenames[0], os.path.basename(os.path.dirname(filenames[0]))

In [None]:
# Prototype the pipeline on a single example from the test split.
ex = filenames[2]
img = load_image(ex)

# `img` is presumably a PIL-style image here (it exposes `.size` and `.mode`);
# it is converted to a NumPy array a few cells further down.
width, height = img.size
print(f"Width: {width}, Height: {height}")
print(f"Image Mode: {img.mode}")

In [None]:
# Quick look at the raw example. The title makes the figure identifiable when
# the notebook is skimmed. `plt.show()` renders it and stops the cell from
# echoing the return value of `plt.axis`.
plt.imshow(img, cmap='gray')
plt.title(os.path.basename(ex))
plt.axis('off')
plt.show()

In [None]:
# NumPy view of the example. Reload from `ex` rather than wrapping the current
# `img`: later cells may draw markers onto `img` in place, and re-running this
# cell with `np.array(img)` would silently pick those annotations up.
img = np.array(load_image(ex))
# Untouched copy, kept for clean cropping / visualisation.
img2 = img.copy()

# Binary mask: pixels brighter than 100 -> 255, everything else -> 0.
_, thres = cv2.threshold(img, 100, 255, cv2.THRESH_BINARY)

plt.imshow(thres, cmap='gray')
plt.axis('off')
plt.show()

In [None]:
# Invert the mask so dark pixels (<= 100) become foreground (255).
# THRESH_BINARY_INV gives exactly bitwise_not of the THRESH_BINARY mask above.
# Computing it from the copy `img2` (clean on a top-to-bottom run) keeps this
# cell idempotent; the previous in-place `thres = cv2.bitwise_not(thres)`
# flipped the mask back on every re-run.
_, thres = cv2.threshold(img2, 100, 255, cv2.THRESH_BINARY_INV)

# First heuristic: the joint row is the row with the most dark pixels, and the
# joint column is simply the image's center column.
row_sums = np.sum(thres, axis=1)
joint_row = np.argmax(row_sums)
center_col = img.shape[1] // 2

# Annotate a colour copy instead of `img` itself.
# - `img` is presumably single-channel (show_img_cropping converts the same kind
#   of array with COLOR_GRAY2BGR), and BGR2RGB on a 1-channel array raises.
# - Drawing in place would leave markers in `img` for every later cell.
vis = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR) if img.ndim == 2 else img.copy()
center = (int(center_col), int(joint_row))  # cv2 drawing expects plain ints
cv2.circle(vis, center, 10, (0, 0, 255), -1)
cv2.putText(vis, "Joint Center", (center[0] + 10, center[1]),
            cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 0, 255), 2)

# Plot
plt.figure(figsize=(6, 6))
plt.imshow(cv2.cvtColor(vis, cv2.COLOR_BGR2RGB))
plt.title(f"Estimated Joint Center: ({center_col}, {joint_row})")
plt.axis("off")
plt.show()

In [None]:
# Column profile of the inverted mask: per-column sum of foreground (255) pixels.
col_sums = thres.sum(axis=0)
np.shape(col_sums)

In [None]:
# Spot-check five consecutive column sums at a few fixed offsets.
for start in (0, 500, 1000):
    print(col_sums[start:start + 5])

In [None]:
# Column with the smallest foreground sum over the full width.
col_sums.argmin()

In [None]:
# Second heuristic: keep the darkest-row estimate for the joint row, but take
# the joint column as the column with the smallest foreground sum.
# NOTE(review): argmin over the full width can land on an image border;
# show_img_cropping restricts this search to a window around the center.
row_sums = np.sum(thres, axis=1)
joint_row = np.argmax(row_sums)
center_col = np.argmin(col_sums)

# Annotate a colour copy. BGR2RGB on the single-channel `img` raises, and
# in-place drawing would accumulate markers across cells.
vis = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR) if img.ndim == 2 else img.copy()
center = (int(center_col), int(joint_row))  # cv2 drawing expects plain ints
cv2.circle(vis, center, 10, (0, 0, 255), -1)
cv2.putText(vis, "Joint Center", (center[0] + 10, center[1]),
            cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 0, 255), 2)

# Plot
plt.figure(figsize=(6, 6))
plt.imshow(cv2.cvtColor(vis, cv2.COLOR_BGR2RGB))
plt.title(f"Estimated Joint Center: ({center_col}, {joint_row})")
plt.axis("off")
plt.show()

In [None]:
# Same annotation as the previous figure, with the title intentionally left off.
# Drawn on a colour copy: BGR2RGB on the single-channel `img` raises, and
# drawing in place would stack another marker onto `img`.
vis = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR) if img.ndim == 2 else img.copy()
center = (int(center_col), int(joint_row))  # cv2 drawing expects plain ints
cv2.circle(vis, center, 10, (0, 0, 255), -1)
cv2.putText(vis, "Joint Center", (center[0] + 10, center[1]),
            cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 0, 255), 2)

# Plot
plt.figure(figsize=(6, 6))
plt.imshow(cv2.cvtColor(vis, cv2.COLOR_BGR2RGB))
plt.axis("off")
plt.show()

In [None]:
# Final joint-center estimate used for cropping below (x = column, y = row).
joint_center_x, joint_center_y = center_col, joint_row

### Crop the Image Around the Joint

First, visualize the planned crop as a bounding box centered on the estimated joint.

In [None]:
# Crop size in pixels, before resizing.
box_width = 600
box_height = 600

# Use the joint_center_* names consistently; cv2 drawing expects plain ints.
cx, cy = int(joint_center_x), int(joint_center_y)
top_left = (cx - box_width // 2, cy - box_height // 2)
bottom_right = (cx + box_width // 2, cy + box_height // 2)

# Annotate a colour copy. BGR2RGB on the single-channel `img` raises, and
# in-place drawing would accumulate markers across cells.
vis = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR) if img.ndim == 2 else img.copy()
cv2.circle(vis, (cx, cy), 10, (0, 0, 255), -1)
cv2.putText(vis, "Joint Center", (cx + 10, cy),
            cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 0, 255), 2)
cv2.rectangle(vis, top_left, bottom_right, (255, 0, 0), 2)

# Plot
plt.figure(figsize=(6, 6))
plt.imshow(cv2.cvtColor(vis, cv2.COLOR_BGR2RGB))
plt.title(f"Estimated Joint Center: ({cx}, {cy})")
plt.axis("off")
plt.show()

In [None]:
# Same figure as the previous one, with the title intentionally left off.
# Drawn on a colour copy: BGR2RGB on the single-channel `img` raises, and
# drawing in place would stack another marker/box onto `img`.
vis = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR) if img.ndim == 2 else img.copy()
center = (int(joint_center_x), int(joint_center_y))  # cv2 expects plain ints
cv2.circle(vis, center, 10, (0, 0, 255), -1)
cv2.putText(vis, "Joint Center", (center[0] + 10, center[1]),
            cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 0, 255), 2)
cv2.rectangle(vis, tuple(int(v) for v in top_left),
              tuple(int(v) for v in bottom_right), (255, 0, 0), 2)

# Plot
plt.figure(figsize=(6, 6))
plt.imshow(cv2.cvtColor(vis, cv2.COLOR_BGR2RGB))
plt.axis("off")
plt.show()

In [None]:
# Crop size in pixels, before resizing.
box_width = 600
box_height = 600

cx, cy = int(joint_center_x), int(joint_center_y)  # cv2 expects plain ints
top_left = (cx - box_width // 2, cy - box_height // 2)
bottom_right = (cx + box_width // 2, cy + box_height // 2)

# Draw on a copy of the clean image. Drawing straight onto `img2`, which the
# next cell crops from, made the rectangle outline leak into the ROI.
# Convert to BGR first, because BGR2RGB on a single-channel array raises.
vis = cv2.cvtColor(img2, cv2.COLOR_GRAY2BGR) if img2.ndim == 2 else img2.copy()
cv2.rectangle(vis, top_left, bottom_right, (255, 0, 0), 2)

# Plot
plt.figure(figsize=(6, 6))
plt.imshow(cv2.cvtColor(vis, cv2.COLOR_BGR2RGB))
plt.title(f"Estimated Joint Center: ({cx}, {cy})")
plt.axis("off")
plt.show()

In [None]:
# Crop a box_height x box_width window centered on the joint and resize it to
# 224x224. The source is re-read from disk because `img2` may already carry a
# rectangle drawn by an earlier cell.
source = np.array(load_image(ex))
cx, cy = int(joint_center_x), int(joint_center_y)

# In-bounds part of the window, in original image coordinates.
top = max(cy - box_height // 2, 0)
bottom = min(cy + box_height // 2, source.shape[0])
left = max(cx - box_width // 2, 0)
right = min(cx + box_width // 2, source.shape[1])

# Pad with black so a window that runs past the border keeps its full size and
# the joint stays centered. Clamping alone produced a non-square ROI, which the
# 224x224 resize then stretched.
pad_y, pad_x = box_height // 2, box_width // 2
padded = cv2.copyMakeBorder(source, pad_y, pad_y, pad_x, pad_x,
                            cv2.BORDER_CONSTANT, value=0)
# (cx, cy) sits at (cx + pad_x, cy + pad_y) in `padded`, so the centered window
# starts at row cy, column cx there.
roi = padded[cy:cy + box_height, cx:cx + box_width]

# Resize to 224x224
roi_resized = cv2.resize(roi, (224, 224), interpolation=cv2.INTER_AREA)
# cmap='gray' because the crop is presumably single-channel (BGR2RGB on it
# raises); matplotlib ignores cmap for colour arrays.
plt.imshow(roi_resized, cmap='gray')
plt.axis('off')
plt.show()

In [None]:
# Final 224x224 crop. cmap='gray' because the crop is presumably single-channel
# (BGR2RGB on a 1-channel array raises); matplotlib ignores cmap for colour arrays.
plt.imshow(roi_resized, cmap='gray')
plt.axis("off")
plt.show()

In [None]:
# Output folder whose name encodes the crop size (e.g. "600_600_joint_center").
# NOTE(review): nothing in this notebook writes to SAVEPATH yet.
crop_dirname = f"{box_width}_{box_height}_joint_center"
SAVEPATH = os.path.join(DATAPATH, crop_dirname)

# Check joint-center detection and crop boxes on test images

In [None]:
def _locate_joint_center(mask_inv, row_range=(300, 650), col_half_window=50):
    """Estimate the joint center ``(col, row)`` from an inverted binary mask.

    Args:
        mask_inv: 2-D mask whose non-zero pixels are the foreground, e.g. the
            ``cv2.bitwise_not`` of a ``THRESH_BINARY`` mask (dark source pixels).
        row_range: inclusive ``(min_row, max_row)`` band searched for the joint
            row; clipped to the mask height.
        col_half_window: half-width of the column window searched around the
            mask's center column.

    Returns:
        ``(center_col, joint_row)`` as Python ints. ``joint_row`` is the first
        row in the band with the largest row sum. ``center_col`` is the first
        column in the window with the smallest column sum.

    Raises:
        ValueError: if the row band or the column window does not overlap the mask.
    """
    row_sums = np.sum(mask_inv, axis=1)
    n_rows = row_sums.shape[0]
    row_lo = max(int(row_range[0]), 0)
    row_hi = min(int(row_range[1]), n_rows - 1)
    if row_lo > row_hi:
        raise ValueError(
            f"row_range {row_range} does not overlap a mask of height {n_rows}")
    # argmax restricted to the band replaces the old "zero the global max and
    # retry" loop. The result is the same whenever some in-band row is non-zero,
    # but it no longer spins forever when every in-band row sum is 0.
    joint_row = row_lo + int(np.argmax(row_sums[row_lo:row_hi + 1]))

    col_sums = np.sum(mask_inv, axis=0)
    n_cols = col_sums.shape[0]
    mid = n_cols // 2  # 512 for 1024-px-wide images, the previously hard-coded value
    col_lo = max(mid - col_half_window, 0)
    col_hi = min(mid + col_half_window, n_cols)
    if col_lo >= col_hi:
        raise ValueError(
            f"col_half_window {col_half_window} gives an empty column window")
    center_col = col_lo + int(np.argmin(col_sums[col_lo:col_hi]))
    return center_col, joint_row


def show_img_cropping(ex, t=100, box_width=600, box_height=600,
                      row_range=(300, 650), col_half_window=50):
    """Plot an image, its threshold mask, and the estimated joint center and crop box.

    Args:
        ex: path of the image, loaded with ``load_image``.
        t: grey-level threshold; pixels ``> t`` are white in the mask.
        box_width, box_height: size of the crop box drawn around the joint center.
        row_range, col_half_window: search limits passed to
            ``_locate_joint_center``. The defaults reproduce the previous
            hard-coded rows 300..650 and +/-50 columns around the center.

    Returns:
        ``(center_col, joint_row)``. This used to be ``None``, so callers that
        ignore the result are unaffected.
    """
    # NOTE(review): assumes load_image returns a single-channel image
    # (COLOR_GRAY2BGR below requires it) -- confirm.
    img_cv = np.array(load_image(ex))

    # After inversion, dark pixels (<= t) are the foreground.
    _, thres = cv2.threshold(img_cv, t, 255, cv2.THRESH_BINARY)
    thres_inv = cv2.bitwise_not(thres)

    center_col, joint_row = _locate_joint_center(thres_inv, row_range, col_half_window)

    # cvtColor returns a new array, so img_cv itself stays unannotated.
    annotated = cv2.cvtColor(img_cv, cv2.COLOR_GRAY2BGR)
    cv2.circle(annotated, (center_col, joint_row), 10, (0, 0, 255), -1)
    cv2.putText(annotated, "Joint Center", (center_col + 10, joint_row),
                cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 0, 255), 2)

    top_left = (center_col - box_width // 2, joint_row - box_height // 2)
    bottom_right = (center_col + box_width // 2, joint_row + box_height // 2)
    cv2.rectangle(annotated, top_left, bottom_right, (255, 0, 0), 2)

    fig, axes = plt.subplots(1, 3, figsize=(15, 6))
    panels = [
        (img_cv, "Original", "gray"),
        (thres, "Thresholded", "gray"),
        (cv2.cvtColor(annotated, cv2.COLOR_BGR2RGB), "Annotated with Joint Center", None),
    ]
    for ax, (image, title, cmap) in zip(axes, panels):
        ax.imshow(image, cmap=cmap)
        ax.set_title(title)
        ax.axis("off")

    plt.tight_layout()
    plt.show()
    # This function is called once per image in a loop; release each figure.
    plt.close(fig)
    return center_col, joint_row

In [None]:
# Visually check the joint-center estimate on every test image in one sub-folder.
kl = "4"  # name of the test sub-folder to inspect
class_dir = os.path.join(test_datapath, kl)
# Sorted so the figures always appear in the same order (os.listdir order is arbitrary).
for fname in sorted(os.listdir(class_dir)):
    if fname.endswith('.png'):
        ex = os.path.join(class_dir, fname)
        show_img_cropping(ex, t=100)