## Intro
- `inference_loader_with_flag()`: 
    - For processed video (Sonosite) folders, input 3 frames t1/t2/t3, view them as L/C/R (left/center/right), and predict the needle (`x1 y1` & `x3 y3`) at C.
    - Call PK function to output **`raw_z1`** & **`raw_z3`** predicted by regression.
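- `線段檢查()` (line-segment check): set `find_train_end_left` & `find_train_end_right`, the first and last carriages of the "train" laid along the predicted needle.
- `PK()`: fit a linear regression to the average values of the bright & shadow areas of L & R, over the carriages from `find_train_end_left` to `find_train_end_right`.
- `save_pred_to_plot()`: save the inference & PK result plot to a folder (if the prediction mask is empty, nothing is saved).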
- `save_pred_to_json()`: save x,y prediction to json


- TODO (might not be needed)
    - Process raw video (developer/prodigy管理/zipper array data for PK): crop 38mm
        - Clipchamp? Crop to 1080x1080
    - show_3D()


## Packages & Read Config File

In [1]:
# @Packages
import json, os, random, math
import numpy as np
import torch
import torch.nn as nn
from tqdm import tqdm
from PIL import Image
import concurrent.futures   ## Threading
import time

## model
from model import Mask2Former
# try:
from model.memmask2former.inference_wrapper import MemInferenceWrapper
from model.memmask2former.mem_m2f import track_model_cfg
# except:
#     pass

from lib.config_helper import merge
from torch.utils.data import DataLoader
from dataset import UnlabeledDataset, Augmentation

import cv2
from sklearn.linear_model import LinearRegression
# from sklearn.metrics import explained_variance_score
from joblib import parallel_backend
import matplotlib.pyplot as plt

import torchvision.transforms as tf
from torchvision.transforms import v2
from sklearn.decomposition import PCA

from omegaconf import OmegaConf
import omegaconf

## Reproducibility: seed the Python, PyTorch and NumPy RNGs (in that order) with 0.
for _seed_fn in (random.seed, torch.manual_seed, np.random.seed):
    _seed_fn(0)
del _seed_fn

## Print NumPy floats in fixed-point rather than scientific notation.
np.set_printoptions(suppress=True)
## Per the PyTorch docs, 'high' lets float32 matmuls use faster reduced-precision kernels (e.g. TF32).
torch.set_float32_matmul_precision('high')
## Release cached, unused GPU memory held by PyTorch's caching allocator.
torch.cuda.empty_cache()

  from .autonotebook import tqdm as notebook_tqdm






In [2]:
## Read the user settings
from omegaconf import OmegaConf
config_PK = OmegaConf.load("./config_PK.yml")

run_only_one = False      ## Inference a single frame
preprocess_video = False  ## Assume frames are already saved as aXXXX.jpg in a folder
                          ## if True, then we need to preprocess first

## Check data: the first source that is set in config_PK.yml wins, in the order
## raw_video_dir -> sonosite_frame_dir -> prodigy_frame_dir.
if config_PK.Data.raw_video_dir is not None:
    print('[Process Video]')
    preprocess_video = True
    data_dir = config_PK.Data.raw_video_dir
    print(type(config_PK.Data.raw_video_dir))
    if isinstance(config_PK.Data.raw_video_dir, (list, omegaconf.listconfig.ListConfig)):
        print(config_PK.Data.raw_video_dir)
    elif not os.path.exists(config_PK.Data.raw_video_dir):
        raise FileNotFoundError(f"Data.raw_video_dir not found: {config_PK.Data.raw_video_dir}")
    ## TODO capture frame in 1 video and crop specific area into 3 LCR (or 2 LR) frames
elif config_PK.Data.sonosite_frame_dir is not None:
    print('[Process non-LCR Frames]')
    data_dir = config_PK.Data.sonosite_frame_dir
    if not os.path.exists(data_dir):
        raise FileNotFoundError(f"Data.sonosite_frame_dir not found: {data_dir}")
elif config_PK.Data.prodigy_frame_dir is not None:
    print('[Process LCR Frames]')
    data_dir = config_PK.Data.prodigy_frame_dir
    ## L and R sub-folders are mandatory
    for view in ("L", "R"):
        if not os.path.exists(os.path.join(data_dir, view)):
            raise FileNotFoundError(f"Missing '{view}' folder in Data.prodigy_frame_dir: {data_dir}")
    if not os.path.exists(os.path.join(data_dir, "C")):
        ## TODO get the middle frames
        raise NotImplementedError
else:
    ## Without this, data_dir stays undefined and later cells fail with a NameError.
    raise ValueError("config_PK.yml: set one of Data.raw_video_dir, "
                     "Data.sonosite_frame_dir or Data.prodigy_frame_dir")

save_json = config_PK.User_setting.save_json   ## save prediction needles coordinates to json
save_mask = config_PK.User_setting.save_mask   ## save PK result plot

[Process Video]
<class 'str'>


## Set coordinates
If the input is a raw Prodigy video, there are two options:
1. Capture the coordinates on the first frame of the video: run the 1st cell below (`get_coord`) and make 3 clicks on the frame (clicking the upper-left corners of the first 2 images and the bottom-right corner of the first image is enough). Then manually set the coordinates of the 3 square areas.
2. Use the default coordinate setting of `家庭資料室_Developer/Prodigy管理/zipper array data for PK豬肉打針/ultrasound_2025-06-13-15-07.mp4`: run the 2nd cell below.

In [None]:
## get_coord
coords = []  # (x, y) pixel positions collected by GetCoords' mouse callback


class GetCoords:
    """Show a frame in an OpenCV window and record left-clicks into the global ``coords``.

    The constructor blocks in ``cv2.waitKey()``. After ``n_clicks`` left-clicks
    the window is destroyed.

    Args:
        frame1: image to display; click markers are drawn onto it in place.
        n_clicks: number of clicks to collect before closing the window.
    """

    def __init__(self, frame1, n_clicks=3):
        self.counter = 0
        self.n_clicks = n_clicks
        cv2.imshow("Image", frame1)
        # frame1 is passed as the callback's ``param`` so the clicks can be drawn on it.
        cv2.setMouseCallback("Image", self.click_event, frame1)
        cv2.waitKey()

    def click_event(self, event, x, y, flag, param):
        """OpenCV mouse callback: on a left-click, store (x, y), mark it and count it."""
        if event != cv2.EVENT_LBUTTONDOWN:
            return
        coords.append((x, y))
        print(f"Clicked at: ({x}, {y})")
        # Optional: draw a circle
        cv2.circle(param, (x, y), 5, (0, 255, 0), -1)
        cv2.imshow("Image", param)
        self.counter += 1
        if self.counter == self.n_clicks:
            print('Completed colllection of coords.')
            cv2.destroyAllWindows()


# Load first frame from video  ------------------
cap = cv2.VideoCapture(config_PK.Data.raw_video_dir)
try:
    if not cap.isOpened():
        raise RuntimeError("Error opening video file.")
    ret, frame1 = cap.read()
finally:
    cap.release()  # release the handle even if opening or reading fails
if not ret:
    # cap.read() returns (False, None) when no frame could be decoded.
    raise RuntimeError("Could not read the first frame of the video.")
# -------------------------------------------------

## make 3 clicks to check the required coordinates
getcoord = GetCoords(frame1)


Clicked at: (113, 153)
Clicked at: (501, 582)
Clicked at: (623, 156)


: 

In [None]:
## Manual coordinate setting (must check for raw prodigy video)
## Default values for ultrasound_2025-06-13-15-07.mp4. The Left, Center and Right
## images sit side by side in the raw frame; each one is h x w pixels.
h = config_PK.Data.frame_h
w = config_PK.Data.frame_w

## Upper-left corner of the Left image; C and R follow directly to its right.
y_up, x_l = 150, 52
x_c = x_l + w
x_r = x_c + w

## Pixel extent of each image: rows y_1..y_2, columns x_?1..x_?2.
y_1 = y_up
y_2 = y_up + h
x_l1, x_c1, x_r1 = x_l, x_c, x_r
x_l2, x_c2, x_r2 = x_l1 + w, x_c1 + w, x_r1 + w

## Model Setup
- Read config to build model
- transform to tensor

In [4]:
# @Model setting
## Set Configuration from json
if config_PK.Detection_model.name == "m2f":
    with open("config_m2f.json", "r", encoding="utf-8") as f:
        config = json.load(f)
    with open("./model/mask2former/m2f_config.json", "r", encoding="utf-8") as f:
        m2f_config = json.load(f)
    config = merge(config, m2f_config)

    # --------------------------------------------------------------------------
    # Model Initialization
    # --------------------------------------------------------------------------
    encoder_type = config["Model"]["unet_backbone"]["encoder_type"]
    if encoder_type == "TransNeXt-Tiny":
        model_name = "TransNeXt-Mask2Former"
    elif encoder_type == "ConvNeXt":
        model_name = "ConvNeXt-Mask2Former"
    else:
        # model_name is only used for logging; don't fail with a NameError for other encoders.
        model_name = f"{encoder_type}-Mask2Former"

    model = Mask2Former(config)
    anchors_pos = None
    det_head = False

elif config_PK.Detection_model.name == "memm2f":
    with open("config_memm2f.json", "r", encoding="utf-8") as f:
        config = json.load(f)
    with open("./model/memmask2former/mem_m2f_config.json", "r", encoding="utf-8") as f:
        m2f_config = json.load(f)
    config = merge(config, m2f_config)
    model_name = 'MemM2F'

    ## recheck decoder input dimension
    ## per-encoder feature dims (presumably per-stage channels, deepest first — confirm)
    encoder_dim = {"ConvNeXt": [1024, 512, 256, 128], "TransNeXt-Tiny": [576, 288, 144, 72]}
    encoder_type = config["Model"]["unet_backbone"].get("encoder_type")
    decoder_type = config["Model"]["unet_backbone"]["decoder_type"]
    if "pixel" in decoder_type:
        track_model_cfg.mask_decoder.up_dims = [256, 256, 256]
        track_model_cfg.pixel_encoder.ms_dims = [256] * 4
        track_model_cfg.embed_dim = 256
        if decoder_type == "pixelup":
            track_model_cfg.pixel_encoder.ms_dims[-1] = encoder_dim[encoder_type][0]
    else:
        track_model_cfg.mask_decoder.up_dims[0] = encoder_dim[encoder_type][0]
        track_model_cfg.pixel_encoder.ms_dims = encoder_dim[encoder_type]
        track_model_cfg.embed_dim = encoder_dim[encoder_type][0]

    track_eval_cfg = OmegaConf.load("./model/memmask2former/eval_config.yaml")
    model = MemInferenceWrapper(cfg=config, track_model_cfg=track_model_cfg,
                                track_eval_cfg=track_eval_cfg)

else:
    # Without this, `config` / `model` stay undefined and the cell fails later with a NameError.
    raise ValueError(f"Unknown Detection_model.name: {config_PK.Detection_model.name!r} "
                     "(expected 'm2f' or 'memm2f')")

if config["Model"].get("dynamic_tanh"):
    from model.utils import convert_ln_to_dyt
    model = convert_ln_to_dyt('', '', model)

print('\n[MODEL]:', model_name)

## Set device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# device_name = torch.cuda.get_device_name(device) if torch.cuda.is_available() else "CPU"
print(f"[DEVICE]: {device}")

# --------------------------------------------------------------------------
# Read ckpt
# --------------------------------------------------------------------------
ckpt_path = config["Model"].get("ckpt_path")
if not ckpt_path or not os.path.exists(ckpt_path):
    raise FileNotFoundError(f"Checkpoint not found: {ckpt_path}")
# Load onto the selected device: a hard-coded "cuda" fails on CPU-only machines.
ckpt = torch.load(ckpt_path, map_location=device)  ## , weights_only=False
if "n_averaged" in ckpt.keys():  ## EMA (AveragedModel) ckpt: key names in ema_model & model are slightly different
    ema_model = torch.optim.swa_utils.AveragedModel(model, multi_avg_fn=torch.optim.swa_utils.get_ema_multi_avg_fn(config["Validation"]["ema"]))
    model = ema_model
model.load_state_dict(ckpt, strict=True)
print(f"[LOAD] Ckpt: {ckpt_path}")

[Convnext] from repo
[convnext unet] building pixelup_decoder...
building pixel_decoder...
ConvNeXt UNet params:  94613184

[MODEL]: ConvNeXt-Mask2Former
[DEVICE]: cuda
[LOAD] Ckpt: pretrained_weight/convnext_mask2former_cls_rB384_pixup_foc20aux1_v6imgLCR_k2.pth


In [5]:
# @Data pre-processing
## Inference-time transform: crop / rotate / color-jitter / flip are all disabled.
valid_transform = Augmentation(
    image_size=config["Model"]["image_size"],
    crop=False,
    rotate=False,
    color_jitter=False,
    horizontal_flip=False,
)
## Converts a PIL image / ndarray into a tensor (torchvision ToTensor).
trans_totensor = tf.ToTensor()

## Code from Radon_by_Jacky_no_PK
- 參數設定 (parameter settings)
- Flag
- PK

In [6]:
## PK hyper-parameters (config_PK.yml -> PK_hyperparam)
cutNum = config_PK.PK_hyperparam.cutNum  ## maximum cuts (carriages) of the train
left_bright_weight = config_PK.PK_hyperparam.left_bright_weight
left_shadow_weight = config_PK.PK_hyperparam.left_shadow_weight
right_bright_weight = config_PK.PK_hyperparam.right_bright_weight
right_shadow_weight = config_PK.PK_hyperparam.right_shadow_weight

## Frame size in pixels: width comes from frame_w and height from frame_h
## (same pairing as `h, w = frame_h, frame_w` used for the crop coordinates).
## Swapping them would only go unnoticed for square frames.
## NOTE(review): a value <= 0 is said to make the code read the size from a frame;
## that fallback is not in this cell — confirm where it is implemented.
width, height = config_PK.Data.frame_w, config_PK.Data.frame_h

In [None]:
class Flag:
    """Per-frame state shared by the line check (線段檢查) and PK ("train") steps.

    The frame width is split into ``cutNum`` equal slices. Each slice holds one
    "carriage" (ROI box) of the train laid along the predicted needle. The x
    range of every carriage is fixed here. The y range (``carriage_y``) is
    filled in later by 線段檢查() / PK().

    Args:
        cutNum: number of carriages (maximum length of the train).
        height: frame height in pixels.
        width: frame width in pixels; used to lay out ``carriage_x``.
    """

    def __init__(self, cutNum=32, height=1758, width=1758):
        self.needle_location = [0, 0, 0, 0]  # [x1, y1, x3, y3]
        self.height = height  ## origin: 1080
        self.width = width  ## origin: 1920

        self.filename = ""
        self.save_folder_name = ''
        self.found_needle = False  ## origin: hough_found_needle
        self.count1 = 0

        # Frame skipping
        self.run_only_one = False  # run a single frame only
        self.fps_reduce = 1

        # Train (a.k.a. PK / gradient)
        self.cutNum = cutNum

        self.pk_regression_box = [0, 0]
        # Per-carriage PK inputs (reserved; placeholders until filled in per carriage).
        self.pk_result_先左右互扣 = list(range(self.cutNum))
        self.pk_result_再左右互扣 = list(range(self.cutNum))
        # Fitted values written by pk_regression(); created here so they exist before the first fit.
        self.pk_regression_box_先左右互扣 = []
        self.pk_regression_box_再左右互扣 = []

        self.carriage_x = []  ## [Xleft_point, Xright_point] of each carriage (ROI bbox of the train)
        self.carriage_y = list(range(self.cutNum))  ## [Y_1_point, Y_0_point] of each carriage, set by 線段檢查()/PK()

        # Each carriage spans the middle 80% of its 1/cutNum slice (10% margin on both sides).
        for i in range(cutNum):
            Xleft_point = int(round(width * (i + 0.1) / cutNum, 0))
            Xright_point = int(round(width * (i + 0.9) / cutNum, 0))
            self.carriage_x.append([Xleft_point, Xright_point])

        # State read by the legacy helpers below and by 線段檢查(expand_len=True) /
        # set_train_expand_end_idx(); initialised so those paths don't raise AttributeError.
        self.綠框_result = list(range(self.cutNum))  # mean brightness of each carriage (expand_len only)
        self.find_train_threshold = 400
        self.run_time_list = []
        self.run_time_txt = []
        self.json_obj = {}

    def pk_regression(self):
        """Fit a straight line to both PK profiles over the train.

        Takes carriages ``find_train_end_left`` .. ``find_train_end_right``
        (inclusive; both must be set beforehand, e.g. by set_train_end_idx) of
        ``pk_result_先左右互扣`` and ``pk_result_再左右互扣``, regresses each on
        the carriage offset 0..N-1, and stores the fitted values in
        ``pk_regression_box_先左右互扣`` and ``pk_regression_box_再左右互扣``
        (the latter is the predicted depth).
        """
        lo, hi = self.find_train_end_left, self.find_train_end_right
        x = np.arange(hi - lo + 1).reshape(-1, 1)
        self.pk_regression_box_先左右互扣 = self._fit_line(x, self.pk_result_先左右互扣[lo:hi + 1])
        self.pk_regression_box_再左右互扣 = self._fit_line(x, self.pk_result_再左右互扣[lo:hi + 1])

    @staticmethod
    def _fit_line(x, values):
        """Least-squares line through (x, values); return the fitted values as a list."""
        with parallel_backend('threading', n_jobs=-2):  ## n_cpu + 1 + n_jobs are used for parallelizing
            model = LinearRegression().fit(x, np.array(values))
            return model.predict(x).tolist()

    # region ## legacy helpers (not used now)
    def rm_dummy(self, img):
        """Crop a square of side width/sqrt(2) from *img* and return it as float32.

        Both the row and the column ranges are derived from the image width, so
        the crop is centred on (width/2, width/2). Works for 2-D and 3-D arrays.
        """
        width = img.shape[1]
        half_side = width / 1.41421356 / 2
        lo = int(round(width / 2 - half_side))
        hi = int(round(width / 2 + half_side))
        return img[lo:hi, lo:hi].astype(np.float32)

    def run_time(self):
        """Pie chart of the time spent between consecutive entries of ``run_time_list``."""
        stamps = [t.timestamp() for t in self.run_time_list]
        durations = [later - earlier for earlier, later in zip(stamps, stamps[1:])]
        fig, ax = plt.subplots()
        ax.pie(durations, labels=self.run_time_txt[1:], autopct="%.0f%%")
        ax.set_title('Total time: ' + str(stamps[-1] - stamps[0]), fontsize=16)

    def save_func(self, j):
        """Copy /content/dataset_auto to Google Drive (replacing run *j*), then delete the local copy."""
        import shutil
        self.save_folder_name = str("/zip2(202201月6mm 225075)autolable/theta105(20)_sinogram(sobel03x07y185_1012subrectangle075_140zero_HoughR1L40G20)train(i+2)")
        target = '/content/gdrive/MyDrive/cloud_film' + self.save_folder_name + "_" + str(j)
        # NOTE: removes the previous remote copy first.
        shutil.rmtree(target, ignore_errors=True)
        shutil.copytree('/content/dataset_auto', target, symlinks=False, ignore=None)
        # NOTE: removes the local copy.
        shutil.rmtree("/content/dataset_auto/", ignore_errors=True)

    def show_3D(self, filename):
        """Render a voxel sketch of two fixed planes and the predicted needle to *filename*.

        Uses ``needle_location`` and ``pk_regression_box_再左右互扣`` (call
        pk_regression() first). If no needle was found, a blank placeholder
        image is copied to *filename* instead.
        """
        import shutil
        import matplotlib.pyplot as plt
        import numpy as np

        if not self.found_needle:
            shutil.copy('/content/img0000_3D_blank.jpg', filename)
            return

        size_3D = 2
        x, y, z = np.indices((100 * size_3D, 100 * size_3D, 100 * size_3D))
        cube_left = (70 * size_3D >= x) & (x >= 30 * size_3D) & (y == 52 * size_3D) & (60 * size_3D >= z) & (z >= 20 * size_3D)
        cube_right = (70 * size_3D >= x) & (x >= 30 * size_3D) & (y == 48 * size_3D) & (60 * size_3D >= z) & (z >= 20 * size_3D)

        self.json_obj["pk_regression_box"] = self.pk_regression_box
        self.json_obj["needle_location"] = self.needle_location

        # XY of the 3-D needle comes from the needle end points.
        x1, y1, x3, y3 = self.needle_location[:4]
        needle_width = x3 - x1
        needle_height = y3 - y1
        Y_middle = (-(y1 + y3) / 2 + (self.height / 2)) / 100 * size_3D

        # Z of the 3-D needle comes from the PK regression (predicted depth).
        depth = self.pk_regression_box_再左右互扣
        Z_deg = -(depth[-1] - depth[0]) / 300
        print("Z_deg=", Z_deg)
        X_middle = ((x1 + x3) / 2 - self.width / 4) / 50 * size_3D
        Z_middle = (depth[-1] + depth[0]) / -10 * size_3D  # 0619: Y was too far out
        print("Z_middle=", Z_middle)

        needle = None
        if needle_width != 0:  # a vertical needle has no XY slope; draw the planes only
            XY_deg = needle_height / needle_width / -2
            needle = (abs(Z_deg * (x - 50 * size_3D - X_middle) + 50 * size_3D + Z_middle - y) <= 2) \
                & (abs(XY_deg * (x - 40 * size_3D - X_middle) + 45 * size_3D + Y_middle - z) <= 1)  # 0619: Z was too high

        voxelarray = cube_left | cube_right
        colors = np.empty(voxelarray.shape, dtype=object)
        if needle is not None:
            voxelarray = voxelarray | needle
            colors[needle] = 'red'
        colors[cube_left] = 'blue'
        colors[cube_right] = 'green'

        fig = plt.figure(figsize=(15, 15))
        ax = fig.add_subplot(projection='3d')
        ax.voxels(voxelarray, facecolors=colors)
        ax.set_axis_off()
        ax.invert_xaxis()
        ax.azim = 30
        ax.dist = 10
        ax.elev = 20
        fig.savefig(filename)
        plt.close(fig)
        fig.clear()

        # Crop the empty margin around the plot: four successive centre crops.
        show_3D_img = cv2.imread(filename)
        for _ in range(4):
            show_3D_img = self.rm_dummy(show_3D_img)
        cv2.imwrite(filename, show_3D_img, [cv2.IMWRITE_JPEG_QUALITY, 95])
    # endregion

In [None]:
def set_train_end_idx(flag):
    """Map the predicted needle end points to carriage indices.

    Sets ``flag.x1_cut_idx`` / ``flag.x3_cut_idx`` (carriage containing x1 / x3)
    and ``flag.find_train_end_left`` / ``flag.find_train_end_right`` (their
    min / max), which bound the carriages used by PK() and pk_regression().

    Carriage ``i`` owns the pixel slice ``[i*width/cutNum, (i+1)*width/cutNum)``,
    so the index is ``floor(x*cutNum/width)``. The index is clamped to
    ``[0, cutNum-1]``, so end points on or beyond the frame border (e.g.
    ``x == width``) still map to a valid carriage. If both ends land in the
    same carriage, the train is widened to two adjacent carriages.

    Args:
        flag: object with ``needle_location`` ([x1, y1, x3, y3]), ``cutNum`` and ``width``.
    """
    x1, y1, x3, y3 = flag.needle_location[:4]
    last = flag.cutNum - 1

    def to_cut_idx(x):
        # floor() equals the former round(v - 0.5) except exactly on a carriage
        # border, where round-half-to-even made the result alternate left/right.
        return min(max(math.floor(x * flag.cutNum / flag.width), 0), last)

    flag.x1_cut_idx = to_cut_idx(x1)
    flag.x3_cut_idx = to_cut_idx(x3)

    ## deal with too short needle
    if flag.x3_cut_idx == flag.x1_cut_idx:
        if flag.x3_cut_idx == last:  ## right border: extend to the left
            flag.x1_cut_idx -= 1
        elif flag.x3_cut_idx == 0:  ## left border: extend to the right
            flag.x3_cut_idx += 1
        elif y3 > y1:
            flag.x3_cut_idx += 1
        else:
            flag.x1_cut_idx += 1

    flag.find_train_end_left = min(flag.x1_cut_idx, flag.x3_cut_idx)
    flag.find_train_end_right = max(flag.x1_cut_idx, flag.x3_cut_idx)

# region ## not used now
def set_train_expand_end_idx(flag, img, find_train_center_i):
    x1 = flag.needle_location[0]
    x3 = flag.needle_location[2]
    ################################### 往左延伸 find_train_end_left = i #############################################
    flag.find_train_end_left = 0
    for i in range(find_train_center_i, 0, -1):               #find_train_center_i = i
        Xleft_point,Xright_point = flag.carriage_x[i]
        Y_1_point,Y_0_point = flag.carriage_y[i]
        # Y_1_point = int(round( y1-needle_height/needle_width*x1 + needle_height/needle_width*(flag.width)/flag.cutNum*i - needle_height/flag.cutNum*1 - 13) )
        # Y_0_point = int(round( y1-needle_height/needle_width*x1 + needle_height/needle_width*(flag.width)/flag.cutNum*i - needle_height/flag.cutNum*1) + 13)
        #Y_m1_point = int(round( y1-needle_height/needle_width*x1 + needle_height/needle_width*(flag.width/2)/flag.cutNum*i - needle_height/flag.cutNum*1 + 200) )
        #火車、線段檢查，用不到Y_m1_point。    pk才需要Y_m1_point去畫陰影區域  這裡的Y_m1_point是畫圖展現pk陰影法用的

        #接下來算差值取平方
        ################################### A.Needle location內 判斷是否為針​(有無黑點)
        if Xleft_point > min(x1,x3):      # >x1 or >x1_ 才對，變數是否有更新過？
            flag.find_train_threshold = 1200
            if (( flag.綠框_result[i] - flag.綠框_result[i+1] ) ** 2 >= flag.find_train_threshold       #i, i+1
                and flag.綠框_result[i] < 30 ) \
                or (( flag.綠框_result[i] - flag.綠框_result[i+2] ) ** 2 >= 3.6 * flag.find_train_threshold       #i, i+2,   若綠框是等差級數，應該是4倍threshold
                and flag.綠框_result[i] < 30 ) \
                or flag.綠框_result[i] < 16:                                  #綠框_result[i]平均象素亮度會與Y point高度範圍有關
                #if (A&A) or C
                cv2.rectangle(img, (Xleft_point,Y_0_point), (Xright_point, Y_1_point), (255, 0, 0), 2)
                flag.find_train_end_left = i
                #加上這不是針 needleOrFake=false, 不啟動pk
                # print(f'[{Xleft_point} > x1] [R cv2.rectangle] blue')
                break
            else:
                #在center圖上畫畫
                cv2.rectangle(img, (Xleft_point,Y_0_point), (Xright_point, Y_1_point), (0, 0, 255), 2)
                # print(f'[{Xleft_point} > x1] [R cv2.rectangle] red')
        ################################### #B.Needle location外 是針就寬容找端點
        else:
            flag.find_train_threshold = 800
            if (( flag.綠框_result[i] - flag.綠框_result[i+1] ) ** 2 >= flag.find_train_threshold        #i, i+1
                and flag.綠框_result[i] < 60 ) \
                or flag.綠框_result[i] <= 30:
                #if (A&A) or C
                cv2.rectangle(img, (Xleft_point,Y_0_point), (Xright_point, Y_1_point), (18, 153, 255), 2) #BGR 橙色
                flag.find_train_end_left = i
                # print(f'[{Xleft_point} > x1] [R cv2.rectangle] org')
                break
            else:
                #在center圖上畫畫
                cv2.rectangle(img, (Xleft_point,Y_0_point), (Xright_point, Y_1_point), (0, 255, 255), 2)
                # print(f'[{Xleft_point} <= x1] [R cv2.rectangle] yellow')


    ################################### 往右延伸 flag.find_train_end_right = i #############################################
    flag.find_train_end_right = flag.cutNum
    for i in range(find_train_center_i, flag.cutNum):           #find_train_center_i = i
        Xleft_point,Xright_point = flag.carriage_x[i]
        Y_1_point,Y_0_point = flag.carriage_y[i]
        # Y_1_point = int(round( y1-needle_height/needle_width*x1 + needle_height/needle_width*(flag.width)/flag.cutNum*i - needle_height/flag.cutNum*1 - 13) )
        # Y_0_point = int(round( y1-needle_height/needle_width*x1 + needle_height/needle_width*(flag.width)/flag.cutNum*i - needle_height/flag.cutNum*1) + 13 )
        #Y_m1_point = int(round( y1-needle_height/needle_width*x1 + needle_height/needle_width*(flag.width/2)/flag.cutNum*i - needle_height/flag.cutNum*1 + 200) )

        #接下來算差值取平方
        ################################### A.Needle location內 判斷是否為針​
        if Xright_point < max(x1,x3): # <x3 or <x3_ 才對，變數是否有更新過？
            flag.find_train_threshold = 1200
            if (( flag.綠框_result[i] - flag.綠框_result[i-1] ) ** 2 >= flag.find_train_threshold      #i, i-1
                and flag.綠框_result[i] < 30 ) \
                or (( flag.綠框_result[i] - flag.綠框_result[i-2] ) ** 2 >= 3.6 * flag.find_train_threshold      #i, i-2,   若綠框是等差級數，應該是4倍threshold
                and flag.綠框_result[i] < 30 ) \
                or flag.綠框_result[i] < 16:
                #if (A&A) or C
                cv2.rectangle(img, (Xleft_point,Y_0_point), (Xright_point, Y_1_point), (255, 0, 0), 2)  #BGR 藍色
                flag.find_train_end_right = i
                #加上這不是針 needleOrFake=false, 不啟動pk
                break
            else:
                #在center圖上畫畫
                cv2.rectangle(img, (Xleft_point,Y_0_point), (Xright_point, Y_1_point), (0, 0, 255), 2) #BGR 紅色
        ################################### #B.Needle location外 是針就寬容找端點
        else:
            flag.find_train_threshold = 800
            if (( flag.綠框_result[i] - flag.綠框_result[i-1] ) ** 2 >= flag.find_train_threshold     #i, i-1
                and flag.綠框_result[i] < 60 ) \
                or flag.綠框_result[i] <= 30:
                #if (A&A) or C
                cv2.rectangle(img, (Xleft_point,Y_0_point), (Xright_point, Y_1_point), (18, 153, 255), 2) #BGR 橙色
                flag.find_train_end_right = i
                break
            else:
                #在center圖上畫畫
                cv2.rectangle(img, (Xleft_point,Y_0_point), (Xright_point, Y_1_point), (0, 255, 255), 2) #BGR 黃色

    return img

def 線段檢查(flag, img, expand_len=False):
    """Line-segment check: build the carriage ROIs along the predicted needle.

    Sets ``flag.find_train_end_left`` / ``find_train_end_right`` (via
    set_train_end_idx) and ``flag.carriage_y[i] = [Y_1_point, Y_0_point]``.
    The y range is a band of about +-13 px around the needle line, whose
    intercept is shifted by -needle_height/cutNum.

    Args:
        flag: Flag-like object with ``needle_location`` = [x1, y1, x3, y3],
            ``cutNum``, ``width``, ``carriage_x`` and ``carriage_y``.
        img: current_frame_center ((L+R)/2); only read and drawn on when expand_len.
        expand_len: if True, measure every carriage's mean brightness
            (``flag.綠框_result``) and grow the train outward from its centre
            with set_train_expand_end_idx(). If False (default), the two ends
            come directly from the model prediction.

    Returns:
        img (with the carriage boxes drawn on it when expand_len).

    Note:
        x1 == x3 (a vertical needle) makes the line slope undefined
        (ZeroDivisionError for plain Python numbers).
    """
    x1, y1, x3, y3 = flag.needle_location[:4]
    needle_width = x3 - x1  # may be negative (x3 left of x1)
    needle_height = y3 - y1

    # ---------------------------
    # Set the carriage idx of train start & end
    # ---------------------------
    set_train_end_idx(flag)

    if expand_len:
        # The expansion may walk over any carriage 0..cutNum-1, so every carriage
        # needs its ROI and brightness, not only the predicted span.
        carriage_ids = range(flag.cutNum)
        if not hasattr(flag, '綠框_result'):
            flag.綠框_result = [0.0] * flag.cutNum
    else:
        carriage_ids = range(flag.find_train_end_left, flag.find_train_end_right + 1)

    # Needle line evaluated at x = width * i / cutNum (the left edge of slice i).
    intercept = y1 - needle_height / needle_width * x1 - needle_height / flag.cutNum
    find_train_center_i = None
    for i in carriage_ids:
        # ---------------------------
        # Carriage coordinate (ROI bbox of each part of train): fixed +-13 px band
        # ---------------------------
        line_y = intercept + needle_height / needle_width * flag.width / flag.cutNum * i
        Y_1_point = int(round(line_y - 13))
        Y_0_point = int(round(line_y) + 13)
        flag.carriage_y[i] = [Y_1_point, Y_0_point]

        if expand_len:
            Xleft_point, Xright_point = flag.carriage_x[i]
            # Mean brightness of this carriage. Negative row bounds are clamped to 0
            # so the slice does not wrap around to the bottom of the image.
            # Scaled by 255 (presumably img is normalised to [0, 1] — confirm).
            temp_area = img[max(Y_1_point, 0):max(Y_0_point, 0), Xleft_point:Xright_point]
            flag.綠框_result[i] = np.average(temp_area) * 255  ## only used for expand_len

            # Train centre: first carriage whose left edge is right of the needle midpoint.
            if find_train_center_i is None and Xleft_point > (x3 + x1) / 2:
                find_train_center_i = i

    if expand_len:
        if find_train_center_i is None:
            # No carriage starts right of the needle midpoint: use the last one.
            find_train_center_i = flag.cutNum - 1
        # ---------------------------
        # Set the carriage idx of train start & end by expanding from center
        # ---------------------------
        img = set_train_expand_end_idx(flag, img, find_train_center_i)

    return img
# endregion

#左右二張圖 OutputImg_PK, OutputImg_pk_regression_先左右互扣, OutputImg_pk_regression_再左右互扣 = PK(flag, current_frame.copy())
def PK(flag, img, return_plot=False, set_train_idx_point=False):
    """
    Compare the left & right frames carriage by carriage and run the PK regression.

    For every carriage index ``i`` in
    ``[flag.find_train_end_left, flag.find_train_end_right]``, two horizontal bands
    are averaged inside the carriage columns ``flag.carriage_x[i]``:

    - a "bright" band, rows ``Y_1_point:Y_0_point``;
    - a "shadow" band below it, rows ``Y_0_point:Y_m1_point``.

    This is done once in the left half of ``img`` and once in the right half
    (columns shifted by ``round(flag.width)``). The weighted score
    ``bright_weight*bright - shadow_weight*shadow`` (module-level weights
    ``left_*_weight`` / ``right_*_weight``) of the right half is subtracted from
    the left one. The result is multiplied by 255 and stored in both
    ``flag.pk_result_先左右互扣[i]`` and ``flag.pk_result_再左右互扣[i]``.
    Finally ``flag.pk_regression()`` is called.

    Args:
        flag: state object. Reads ``needle_location`` ([x1, y1, x3, y3]),
            ``carriage_x``, ``carriage_y``, ``width``, ``height`` and ``cutNum``.
            Writes ``pk_result_*`` and, when ``set_train_idx_point``, also
            ``carriage_y``.
        img: left & right images concatenated along the width axis (gray or
            3-channel).
        return_plot: when True, draw the sampled bands on the returned image and
            draw the per-carriage bars plus regression line on the two plots.
        set_train_idx_point: when True, call ``set_train_end_idx(flag)`` and
            compute and store fresh bands in ``flag.carriage_y``. Otherwise reuse
            ``flag.carriage_y``, with the bright/shadow border moved up by 3 px.

    Returns:
        ``(OutputImg_PK, OutputImg_pk_regression_先左右互扣,
        OutputImg_pk_regression_再左右互扣)``, all float32. The two plots stay
        black when ``return_plot`` is False.

    Note:
        ``x1 != x3`` is required, because the needle slope
        ``needle_height / needle_width`` is used.
    """
    x1 = flag.needle_location[0]
    y1 = flag.needle_location[1]
    x3 = flag.needle_location[2]
    y3 = flag.needle_location[3]
    needle_width = x3 - x1   # may be negative; only x1 != x3 is required
    needle_height = y3 - y1

    if set_train_idx_point:
        # Set the carriage idx of train start & end
        set_train_end_idx(flag)

    if len(img.shape) == 2:
        img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR)

    right_offset = int(round(flag.width))   # column offset of the right frame
    carriage_range = range(flag.find_train_end_left, flag.find_train_end_right + 1)
    bands = []   # (Xleft, Xright, Y_1, Y_0, Y_m1) per carriage, drawn after scoring

    for i in carriage_range:
        Xleft_point, Xright_point = flag.carriage_x[i]

        # y of the needle line at carriage i, shifted up by one carriage height
        line_y = (y1 - needle_height/needle_width*x1
                  + needle_height/needle_width*(flag.width)/flag.cutNum*i
                  - needle_height/flag.cutNum*1)
        Y_m1_point = int(round(line_y + 200))
        if set_train_idx_point:
            Y_1_point = int(round(line_y - 13))
            Y_0_point = int(round(line_y)) + 10
            flag.carriage_y[i] = [Y_1_point, Y_0_point]
        else:
            Y_1_point, Y_0_point = flag.carriage_y[i]
            Y_0_point = Y_0_point - 3

        left_cols = slice(Xleft_point, Xright_point)
        right_cols = slice(Xleft_point + right_offset, Xright_point + right_offset)

        # Left frame: weighted bright (LB) minus weighted shadow (LS)
        bright = np.average(img[Y_1_point:Y_0_point, left_cols])
        shadow = np.average(img[Y_0_point:Y_m1_point, left_cols])
        left_result = left_bright_weight*bright - left_shadow_weight*shadow

        # Right frame: the same bands shifted into the right half (RB, RS)
        bright = np.average(img[Y_1_point:Y_0_point, right_cols])
        shadow = np.average(img[Y_0_point:Y_m1_point, right_cols])
        right_result = right_bright_weight*bright - right_shadow_weight*shadow

        # (wLB - wLS) - (wRB - wRS); both result buffers receive the same value
        diff = (left_result - right_result) * 255
        flag.pk_result_再左右互扣[i] = diff
        flag.pk_result_先左右互扣[i] = diff

        if return_plot:
            bands.append((Xleft_point, Xright_point, Y_1_point, Y_0_point, Y_m1_point))

    # Draw the bands only after every carriage has been scored. Drawing inside the
    # scoring loop would put 255-valued rectangle edges into pixels that later
    # carriages still average, so the PK result would depend on `return_plot`.
    for Xleft_point, Xright_point, Y_1_point, Y_0_point, Y_m1_point in bands:
        for off in (0, right_offset):   # left frame, then right frame
            # red: bright band; white: shadow band
            cv2.rectangle(img, (Xleft_point + off, Y_0_point), (Xright_point + off, Y_1_point), (0, 0, 255), 2)
            cv2.rectangle(img, (Xleft_point + off, Y_0_point), (Xright_point + off, Y_m1_point), (255, 255, 255), 2)

    # The regression needs all carriage results first
    flag.pk_regression()

    plot_shape = (int(round(flag.height/4)), int(round(flag.width)))
    OutputImg_pk_regression_先左右互扣 = cv2.cvtColor(np.zeros(plot_shape, dtype=np.float32), cv2.COLOR_GRAY2RGB)
    OutputImg_pk_regression_再左右互扣 = OutputImg_pk_regression_先左右互扣.copy()

    if return_plot and len(carriage_range) > 0:
        baseline = flag.height/8   # bars grow upwards from here (value * -2 px)
        for i in carriage_range:
            Xleft_point, Xright_point = flag.carriage_x[i]
            top_left = (int(round(Xleft_point)), int(round(baseline)))
            cv2.rectangle(OutputImg_pk_regression_先左右互扣, top_left,
                          (int(round(Xright_point)), int(round(baseline + flag.pk_result_先左右互扣[i]*-2))),
                          (255, 255, 255), -1)
            cv2.rectangle(OutputImg_pk_regression_再左右互扣, top_left,
                          (int(round(Xright_point)), int(round(baseline + flag.pk_result_再左右互扣[i]*-2))),
                          (255, 255, 255), -1)

        # Regression line from the first scored carriage to the right edge of the last one
        x_start = int(round(flag.width*(flag.find_train_end_left+0.1)/flag.cutNum, 0))
        x_end = int(round(flag.carriage_x[flag.find_train_end_right][1]))
        for canvas, reg_box in ((OutputImg_pk_regression_先左右互扣, flag.pk_regression_box_先左右互扣),
                                (OutputImg_pk_regression_再左右互扣, flag.pk_regression_box_再左右互扣)):
            cv2.line(canvas,
                     (x_start, int(round(baseline + reg_box[0]*-3))),
                     (x_end, int(round(baseline + reg_box[-1]*-3))),
                     (0, 0, 255), 20, cv2.LINE_AA)

    return (img.astype(np.float32),
            OutputImg_pk_regression_先左右互扣.astype(np.float32),
            OutputImg_pk_regression_再左右互扣.astype(np.float32))


## Other Utils
- Plot
- Postprocessing
- Inference

In [None]:
def save_pred_to_plot(image_C, pred_mask, pred_line, gt_mask, gt_line, 
              OutputImg_PK, OutputImg_pk_regression_先左右互扣, OutputImg_pk_regression_再左右互扣,
              fname="pred.png"):
    """
    Save a figure with the model prediction and the PK plots to ``fname``.

    Nothing is written when ``pred_mask`` holds a single unique value (empty
    prediction).

    Layout:
    - Row 1: the center image next to ``pred_mask``. The predicted end points
      (``pred_line`` = [x1, y1, x3, y3]) are red and titled with their integer
      coordinates; ``gt_line`` end points, if given, are green.
    - Rows 2-4: the PK image and the two PK regression plots. They are only
      drawn when ``OutputImg_PK`` is not None.

    ``gt_mask`` is accepted for interface compatibility but is not used.
    The three PK images get their text labels written into them in place.
    """
    ## do not save image if no pred
    if len(np.unique(pred_mask)) <= 1:
        return
    row, col, scale = 4, 1, 3
    fig, axes = plt.subplots(row, col, gridspec_kw={'height_ratios': [3, 3, 1, 1]},
                             figsize=(3*col*scale, row*scale))
    try:
        ## Row 1: center image | pred mask, with predicted / GT end points
        image_pred = np.concatenate([image_C, pred_mask], axis=-1)
        axes[0].imshow((image_pred * 255).astype(np.uint8), cmap='gray')
        if pred_line is not None:   # pred_mask is known to be non-empty here
            axes[0].scatter([pred_line[0], pred_line[2]], [pred_line[1], pred_line[3]], c='red', s=10, label='pred')
            axes[0].set_title(', '.join(str(int(x)) for x in pred_line))
        if gt_line is not None:
            axes[0].scatter([gt_line[0], gt_line[2]], [gt_line[1], gt_line[3]], c='green', s=10, label='gt')

        ## Rows 2~4: PK on Left & Right. The text labels are drawn only when the
        ## PK outputs exist, so a None input no longer reaches cv2.putText.
        if OutputImg_PK is not None:
            panels = [
                (OutputImg_PK, "10", 'OutputImg_PK'),
                (OutputImg_pk_regression_先左右互扣, "10-1:first", 'OutputImg_pk_first'),
                (OutputImg_pk_regression_再左右互扣, "10-2:after", 'OutputImg_pk_after'),
            ]
            for ax, (panel, text, title) in zip(axes[1:], panels):
                cv2.putText(panel, text, (0, int(round(panel.shape[0]*7/8))),
                            cv2.FONT_HERSHEY_SIMPLEX, 2, (255, 255, 255), 8, cv2.LINE_AA)
                ax.imshow(cv2.cvtColor(panel, cv2.COLOR_BGR2RGB).astype('uint8'))
                ax.set_title(title)

        for ax in axes:
            ax.axis('off')
        plt.margins(0, 0)
        plt.savefig(fname, bbox_inches='tight')
    finally:
        # always release the figure, even if drawing or saving fails
        plt.close(fig)

def save_pred_to_json(fnames, pred_coord_label, img_size, save_folder_dir=None):
    """
    Append a predicted shape to the labelme-style JSON file of ``fnames[0]``.

    The JSON file name is the image file name with its extension replaced by
    ``.json``. It is placed in ``save_folder_dir`` when that is given, otherwise
    next to the image. If the file already exists, ``pred_coord_label`` is
    appended to its ``"shapes"`` and the image size is updated. Otherwise a new
    document is created.

    Args:
        fnames: sequence whose first item is the image path.
        pred_coord_label: shape dict to append (e.g. from ``min_rect_2_line``).
        img_size: side length written to ``imageHeight`` and ``imageWidth``
            (square images).
        save_folder_dir: optional output directory.
    """
    image_path = fnames[0]
    if save_folder_dir is not None:
        stem = os.path.splitext(os.path.basename(image_path))[0]
        json_path = os.path.join(save_folder_dir, stem + ".json")
    else:
        # splitext (not str.replace) so that a non-.jpg image is never used as the JSON path
        json_path = os.path.splitext(image_path)[0] + ".json"

    if os.path.exists(json_path):
        with open(json_path, "r", encoding="utf-8") as f:
            json_dict = json.load(f)
        json_dict.setdefault("shapes", []).append(pred_coord_label)
        json_dict["imageHeight"] = img_size
        json_dict["imageWidth"] = img_size
    else:
        json_dict = {
            "version": "4.4.1",
            "flags": {},
            "shapes": [pred_coord_label],
            "imagePath": os.path.basename(image_path),
            "imageData": None,
            "imageHeight": img_size,
            "imageWidth": img_size,
        }
    with open(json_path, "w", encoding="utf-8") as json_file:
        json.dump(json_dict, json_file, indent=4)


In [None]:
## Postprocess method 1: PCA
def mask_to_line(pred_idx, img_size, flag):
    """
    Fit a straight line to the mask pixels with 1-component PCA and return its end points.

    Args:
        pred_idx: ``(rows, cols)`` index arrays, e.g. from ``np.where(mask > 0.5)``.
            At least 2 points are needed.
        img_size: the end points are clipped to ``[0, img_size]``.
        flag: ``flag.needle_location`` is set to ``[x1, y1, x3, y3]``. These are the
            same clipped values as in the returned label, with ``x1 != x3`` and
            ``y1 != y3`` guaranteed (``PK`` divides by ``x3 - x1``).

    Returns:
        labelme-style shape dict (label "needle", shape_type "rectangle") with the
        two end points ``[[x1, y1], [x3, y3]]``.
    """
    with parallel_backend('threading', n_jobs=-2):
        pca = PCA(n_components=1)
        # points as (row, col) = (y, x)
        mask_2d = np.stack([pred_idx[0], pred_idx[1]], axis=1)

        # Project onto the principal axis and back: every point now lies on the fitted line
        pca.fit(mask_2d)
        mask_2d_new = pca.inverse_transform(pca.transform(mask_2d))
        mask_2d_new = sorted(mask_2d_new, key=lambda p: p[0])   # sort by y (row)

    ## drop leading points whose x is negative (the last point is always kept)
    start = next((k for k, p in enumerate(mask_2d_new) if p[1] >= 0), len(mask_2d_new) - 1)
    mask_2d_new = mask_2d_new[start:]

    mask_line_end_points = np.clip(
        [[mask_2d_new[0][1], mask_2d_new[0][0]],      # x0, y0
         [mask_2d_new[-1][1], mask_2d_new[-1][0]]],   # x1, y1
        0, img_size).tolist()
    # avoid a degenerate (zero-width / zero-height) line
    if mask_line_end_points[0][0] == mask_line_end_points[1][0]:
        mask_line_end_points[1][0] += 1
    if mask_line_end_points[0][1] == mask_line_end_points[1][1]:
        mask_line_end_points[1][1] += 1

    pred_coord_label = {"label": "needle",
                        "points": mask_line_end_points,
                        "group_id": None,
                        "shape_type": "rectangle",
                        "flags": {}}
    # Use the same clipped/nudged points as the label. The raw PCA points could
    # give x1 == x3 and break the slope computation in PK.
    (x1, y1), (x3, y3) = mask_line_end_points
    flag.needle_location = [x1, y1, x3, y3]   # [x1,y1,x3,y3]
    return pred_coord_label  ## for json

## Postprocess method 2: fit a rotate-able minimum area rectangle
## default (slightly faster than mask_to_line by PCA)
def min_rect_2_line(mask: np.ndarray, flag):
    """
    Approximate the needle by the minimum-area rotated rectangle around a binary mask.

    All external contours are merged and their convex hull is fitted with
    ``cv2.minAreaRect``. The midpoints of the rectangle's two shorter edges are
    taken as the needle end points.
    NOTE: this may introduce bias if the mask width is unstable.

    Args:
        mask: binary mask of shape [H, W], a NumPy array accepted by
            ``cv2.findContours`` (e.g. uint8).
        flag: ``flag.needle_location`` is set to ``[x0, y0, x1, y1]``.

    Returns:
        labelme-style shape dict whose points ``[[x0, y0], [x1, y1]]`` are plain
        Python floats clamped to ``[0, max(H, W)]``. ``x0 != x1`` and ``y0 != y1``
        are guaranteed.

    Raises:
        ValueError: if the mask contains no contour.
    """
    img_size = max(mask.shape[-1], mask.shape[-2])

    # Find contours, merge them and fit the minimum-area bounding rectangle
    contours, _ = cv2.findContours(mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
    if len(contours) == 0:
        raise ValueError("No contours found in the mask.")
    hull = cv2.convexHull(np.vstack(contours))
    box = cv2.boxPoints(cv2.minAreaRect(hull))   # 4 corners, float32

    ## Corner 0 and its nearest corner form one short edge;
    ## the remaining two corners form the opposite short edge.
    sq_dist = [(box[0][0] - box[c][0])**2 + (box[0][1] - box[c][1])**2 for c in (1, 2, 3)]
    near = 1 + sq_dist.index(min(sq_dist))
    far_a, far_b = [c for c in (1, 2, 3) if c != near]

    midpoints = [
        (box[0][0] + box[near][0]) / 2, (box[0][1] + box[near][1]) / 2,
        (box[far_a][0] + box[far_b][0]) / 2, (box[far_a][1] + box[far_b][1]) / 2,
    ]
    # Clamp, then convert NumPy float32 scalars to Python floats: json.dump cannot
    # serialize np.float32, and these values end up in the saved JSON.
    shorter_edge_midpoints = [float(max(min(v, img_size), 0)) for v in midpoints]

    # avoid a degenerate (zero-width / zero-height) line
    if shorter_edge_midpoints[0] == shorter_edge_midpoints[2]:
        shorter_edge_midpoints[2] += 1
    if shorter_edge_midpoints[1] == shorter_edge_midpoints[3]:
        shorter_edge_midpoints[3] += 1

    mask_line_end_points = [
        [shorter_edge_midpoints[0], shorter_edge_midpoints[1]],   ## x0, y0
        [shorter_edge_midpoints[2], shorter_edge_midpoints[3]],   ## x1, y1
    ]
    pred_coord_label = {"label": "needle",
                        "points": mask_line_end_points,
                        "group_id": None,
                        "shape_type": "rectangle",
                        "flags": {}}
    flag.needle_location = shorter_edge_midpoints  # [x1,y1,x3,y3]
    return pred_coord_label  ## for json


In [None]:
## Inference helpers (including model prediction & PK)
## NOTE: only `inference_LCR_image_with_flag` is updated to the newest version

## input is torch dataloader
def inference_loader_with_flag(model, device, loader, flag, save_json=False, save_mask=False):
    '''
    Run needle prediction + PK on every `flag.fps_reduce`-th buffer of a loader.

    - Get 3 frames t1/t2/t3 from loader, view them as L/C/R (left/center/right), and model predicts the needle at C
    - PK on L & R frame based on predicted needle
    - Optionally save the prediction as JSON (`save_json`) and as a plot (`save_mask`)
      into `flag.save_folder_name`; the end points and regression depths are printed.
    `flag.count1` is incremented once per buffer.
    '''
    time_window = 3  # L/C/R frames

    with torch.no_grad():
        for buffer in loader:  # loader getitem() returns a buffer
            if flag.count1 % flag.fps_reduce == 0:
                # Image data
                images = buffer["images"][:, :time_window, :, :].to(device)  ## [N=1, T=3, h, w]
                fnames = buffer["img_path"][1]   ## path(s) of the center frame (time index 1)
                img_size = buffer["origin_img_size"][0].item()
                origin_image = buffer["origin_images"][:, :time_window, :, :]
                image_array_L = origin_image[:, 0, :, :].squeeze()
                image_array_C = origin_image[:, 1, :, :].squeeze().cpu().numpy()  ## [H,W]
                image_array_R = origin_image[:, 2, :, :].squeeze()
                image_array_LR = torch.concatenate([image_array_L, image_array_R], axis=-1).cpu().numpy()
                image_array_LR = cv2.cvtColor(image_array_LR, cv2.COLOR_GRAY2RGB)  ##[H,W*2,3]

                # Forward pass
                pred_masks = model(images)
                if isinstance(pred_masks, dict):  ## Mask2Former-based model
                    output_dict = pred_masks
                    pred_masks = output_dict["pred_masks"]
                    # The class score is only available for dict outputs; below 0.5 -> no needle
                    if output_dict["pred_class"].item() < 0.5:
                        pred_masks.zero_()

                # Resize back to the original image size and binarize
                pred_mask_resize = pred_masks
                if pred_mask_resize.shape[-1] != img_size:
                    pred_mask_resize = v2.functional.resize(pred_mask_resize, (img_size,img_size), interpolation=tf.InterpolationMode.BILINEAR, antialias=True)
                pred_mask_tensor = (pred_mask_resize > 0.5).to(dtype=torch.uint8).squeeze()
                pred_mask = pred_mask_tensor.detach().cpu().numpy()

                if (pred_mask_tensor != 0).sum() < 2:
                    flag.found_needle = False
                    print(f'idx {flag.count1} Needle not found\n------------')
                else:
                    flag.found_needle = True
                    pred_coord_label = min_rect_2_line(pred_mask, flag)

                    ## save json
                    if save_json:
                        save_pred_to_json(fnames, pred_coord_label, img_size, save_folder_dir=flag.save_folder_name)

                    (OutputImg_PK,
                     OutputImg_pk_regression_先左右互扣,
                     OutputImg_pk_regression_再左右互扣) = PK(flag, image_array_LR.copy(), return_plot=save_mask, set_train_idx_point=True)

                    ## save mask (plot)
                    if save_mask:
                        flag.filename = os.path.basename(fnames[0]).replace(".jpg", '_pred.png')
                        gt_mask, gt_line = None, None
                        save_pred_to_plot(image_array_C, pred_mask, flag.needle_location, gt_mask, gt_line,
                                          OutputImg_PK, OutputImg_pk_regression_先左右互扣, OutputImg_pk_regression_再左右互扣,
                                          fname=os.path.join(flag.save_folder_name, flag.filename))

                    ## Final estimation: regression depth at the two needle end points
                    x1, y1, x3, y3 = flag.needle_location
                    reg_box = flag.pk_regression_box_再左右互扣
                    offset = flag.find_train_end_left   # box entries appear to start at this carriage
                    if flag.x1_cut_idx < flag.x3_cut_idx:
                        x1y1_regression_depth = reg_box[flag.x1_cut_idx - offset]
                        x3y3_regression_depth = reg_box[flag.x3_cut_idx - 1 - offset]
                    else:
                        x1y1_regression_depth = reg_box[flag.x1_cut_idx - 1 - offset]
                        x3y3_regression_depth = reg_box[flag.x3_cut_idx - offset]
                    print(f'idx {flag.count1}')
                    print(f'[x1 y1 raw_z1] {x1:.4f}, {y1:.4f}, {x1y1_regression_depth:.4f}')
                    print(f'[x3 y3 raw_z3] {x3:.4f}, {y3:.4f}, {x3y3_regression_depth:.4f}')
                    print('------------')

            flag.count1 = flag.count1 + 1

## input is Sonosite folder
def inference_folder_with_flag(model, device, data_dir, flag, save_json=False, save_mask=False):
    '''
    Run needle prediction + PK over a Sonosite frame folder with a sliding L/C/R window.

    - Frames are the files in `data_dir` whose names start with "a" and end with ".jpg", in sorted order
    - Every 3 consecutive frames t1/t2/t3 are viewed as L/C/R (left/center/right), and the model predicts the needle at C
    - PK on L & R frame based on predicted needle
    - Saved JSON / plot files are named after the center frame C, the frame the prediction belongs to
    `flag.count1` is incremented once per full window.
    '''
    file_names = sorted(os.listdir(data_dir))
    image_names = [f for f in file_names if f[0] == "a" and f.endswith(".jpg")]
    lcr_queue, name_queue = [], []   # sliding windows of frame tensors and their file names

    with torch.no_grad():
        for f in image_names:
            image_path = os.path.join(data_dir, f)
            with Image.open(image_path) as pil_image:   # close the file handle right away
                image = pil_image.convert("L")
            image_tensor = trans_totensor(image)  ## [1,H,W]
            img_size = image_tensor.shape[-1]

            lcr_queue.append(image_tensor)
            name_queue.append(image_path)
            if len(lcr_queue) == 3 and flag.count1 % flag.fps_reduce == 0:
                images = torch.stack(lcr_queue, dim=1)
                images = v2.functional.resize(images, (config["Model"]["image_size"], config["Model"]["image_size"]),
                                              interpolation=tf.InterpolationMode.BILINEAR, antialias=True)
                images = images.to(device)
                # Outputs belong to the center frame, not to the newest frame `f`
                fnames = [name_queue[1]]

                image_array_L = lcr_queue[0].squeeze()
                image_array_C = lcr_queue[1].squeeze().cpu().numpy()  ## [H,W]
                image_array_R = lcr_queue[2].squeeze()
                image_array_LR = torch.concatenate([image_array_L, image_array_R], axis=-1).cpu().numpy()
                image_array_LR = cv2.cvtColor(image_array_LR, cv2.COLOR_GRAY2RGB)

                # Forward pass
                pred_masks = model(images)
                if isinstance(pred_masks, dict):  ## Mask2Former-based model
                    output_dict = pred_masks
                    pred_masks = output_dict["pred_masks"]
                    if output_dict["pred_class"].item() < 0.5:   # class score below 0.5 -> no needle
                        pred_masks.zero_()

                # Resize back to the original image size and binarize
                pred_mask_resize = pred_masks
                if pred_mask_resize.shape[-1] != img_size:
                    pred_mask_resize = v2.functional.resize(pred_mask_resize, (img_size,img_size), interpolation=tf.InterpolationMode.BILINEAR, antialias=True)
                pred_mask_tensor = (pred_mask_resize > 0.5).to(dtype=torch.uint8).squeeze()
                pred_mask = pred_mask_tensor.detach().cpu().numpy()

                if (pred_mask_tensor != 0).sum() < 2:
                    flag.found_needle = False
                    print(f'idx {flag.count1} Needle not found\n------------')
                else:
                    flag.found_needle = True
                    pred_coord_label = min_rect_2_line(pred_mask, flag)

                    ## save json
                    if save_json:
                        save_pred_to_json(fnames, pred_coord_label, img_size, save_folder_dir=flag.save_folder_name)

                    (OutputImg_PK,
                     OutputImg_pk_regression_先左右互扣,
                     OutputImg_pk_regression_再左右互扣) = PK(flag, image_array_LR.copy(), return_plot=save_mask, set_train_idx_point=True)

                    ## save mask (plot)
                    if save_mask:
                        flag.filename = os.path.basename(fnames[0]).replace(".jpg", '_pred.png')
                        gt_mask, gt_line = None, None
                        save_pred_to_plot(image_array_C, pred_mask, flag.needle_location, gt_mask, gt_line,
                                          OutputImg_PK, OutputImg_pk_regression_先左右互扣, OutputImg_pk_regression_再左右互扣,
                                          fname=os.path.join(flag.save_folder_name, flag.filename))

                    ## Final estimation: regression depth at the two needle end points
                    x1, y1, x3, y3 = flag.needle_location
                    reg_box = flag.pk_regression_box_再左右互扣
                    offset = flag.find_train_end_left   # box entries appear to start at this carriage
                    if flag.x1_cut_idx < flag.x3_cut_idx:
                        x1y1_regression_depth = reg_box[flag.x1_cut_idx - offset]
                        x3y3_regression_depth = reg_box[flag.x3_cut_idx - 1 - offset]
                    else:
                        x1y1_regression_depth = reg_box[flag.x1_cut_idx - 1 - offset]
                        x3y3_regression_depth = reg_box[flag.x3_cut_idx - offset]
                    print(f'idx {flag.count1}')
                    print(f'[x1 y1 raw_z1] {x1:.4f}, {y1:.4f}, {x1y1_regression_depth:.4f}')
                    print(f'[x3 y3 raw_z3] {x3:.4f}, {y3:.4f}, {x3y3_regression_depth:.4f}')
                    print('------------')

            if len(lcr_queue) == 3:
                flag.count1 = flag.count1 + 1
                ## slide the window: drop the oldest frame
                lcr_queue.pop(0)
                name_queue.pop(0)

## input is processed Prodigy video (crop into 3 videos)
def inference_LCRfolder_with_flag(model, device, data_dir, flag, save_json=False, save_mask=False):
    '''
    Run needle prediction + PK on a processed Prodigy video split into L/C/R folders.

    - Get same file name image from folder L/C/R (left/center/right), and model predicts the needle at C
    - Only names that start with "a" and end with ".jpg" in `data_dir/L` are used, in sorted order
    - PK on L & R frame based on predicted needle
    - Saved JSON / plot files are named after the C image (JSON goes next to it when
      `flag.save_folder_name` is None)
    `flag.count1` is incremented once per file name.
    '''
    with torch.no_grad():
        lcr_dir_path = [os.path.join(data_dir, "L"), os.path.join(data_dir, "C"), os.path.join(data_dir, "R")]

        file_names = sorted(os.listdir(lcr_dir_path[0]))
        image_names = [f for f in file_names if f[0] == "a" and f.endswith(".jpg")]  # ["a0001.jpg", "a0002.jpg", ...]

        for f in image_names:
            fname_list, lcr_queue = [], []
            for dir_path in lcr_dir_path:
                path = os.path.join(dir_path, f)
                with Image.open(path) as pil_image:   # close the file handle right away
                    lcr_queue.append(trans_totensor(pil_image.convert("L")))
                fname_list.append(path)
            img_size = lcr_queue[-1].shape[-1]

            if flag.count1 % flag.fps_reduce == 0:
                images = torch.stack(lcr_queue, dim=1)
                images = v2.functional.resize(images, (config["Model"]["image_size"], config["Model"]["image_size"]),
                                              interpolation=tf.InterpolationMode.BILINEAR, antialias=True)
                images = images.to(device)
                # Full path of the C image. A bare name would make the JSON land in the CWD
                # when no save folder is set.
                fnames = [fname_list[1]]

                ## get numpy array for PK
                image_array_L = lcr_queue[0].squeeze()
                image_array_C = lcr_queue[1].squeeze().cpu().numpy()  ## [H,W]
                image_array_R = lcr_queue[2].squeeze()
                image_array_LR = torch.concatenate([image_array_L, image_array_R], axis=-1).cpu().numpy()
                image_array_LR = cv2.cvtColor(image_array_LR, cv2.COLOR_GRAY2RGB)  ## 0.~1.

                # Forward pass
                pred_masks = model(images)
                if isinstance(pred_masks, dict):  ## Mask2Former-based model
                    output_dict = pred_masks
                    pred_masks = output_dict["pred_masks"]
                    if output_dict["pred_class"].item() < 0.5:   # class score below 0.5 -> no needle
                        pred_masks.zero_()

                # Resize back to the original image size and binarize
                pred_mask_resize = pred_masks
                if pred_mask_resize.shape[-1] != img_size:
                    pred_mask_resize = v2.functional.resize(pred_mask_resize, (img_size,img_size), interpolation=tf.InterpolationMode.BILINEAR, antialias=True)
                pred_mask_tensor = (pred_mask_resize > 0.5).to(dtype=torch.uint8).squeeze()
                pred_mask = pred_mask_tensor.detach().cpu().numpy()

                if (pred_mask_tensor != 0).sum() < 2:
                    flag.found_needle = False
                    print(f'idx {flag.count1} Needle not found\n------------')
                else:
                    flag.found_needle = True
                    pred_coord_label = min_rect_2_line(pred_mask, flag)

                    ## save json
                    if save_json:
                        save_pred_to_json(fnames, pred_coord_label, img_size, save_folder_dir=flag.save_folder_name)

                    (OutputImg_PK,
                     OutputImg_pk_regression_先左右互扣,
                     OutputImg_pk_regression_再左右互扣) = PK(flag, image_array_LR.copy(), return_plot=save_mask, set_train_idx_point=True)

                    ## save mask (plot)
                    if save_mask:
                        flag.filename = os.path.basename(fnames[0]).replace(".jpg", '_pred.png')
                        gt_mask, gt_line = None, None
                        save_pred_to_plot(image_array_C, pred_mask, flag.needle_location, gt_mask, gt_line,
                                          OutputImg_PK, OutputImg_pk_regression_先左右互扣, OutputImg_pk_regression_再左右互扣,
                                          fname=os.path.join(flag.save_folder_name, flag.filename))

                    ## Final estimation: regression depth at the two needle end points
                    x1, y1, x3, y3 = flag.needle_location
                    reg_box = flag.pk_regression_box_再左右互扣
                    offset = flag.find_train_end_left   # box entries appear to start at this carriage
                    if flag.x1_cut_idx < flag.x3_cut_idx:
                        x1y1_regression_depth = reg_box[flag.x1_cut_idx - offset]
                        x3y3_regression_depth = reg_box[flag.x3_cut_idx - 1 - offset]
                    else:
                        x1y1_regression_depth = reg_box[flag.x1_cut_idx - 1 - offset]
                        x3y3_regression_depth = reg_box[flag.x3_cut_idx - offset]
                    print(f'idx {flag.count1}')
                    print(f'[x1 y1 raw_z1] {x1:.4f}, {y1:.4f}, {x1y1_regression_depth:.4f}')
                    print(f'[x3 y3 raw_z3] {x3:.4f}, {y3:.4f}, {x3y3_regression_depth:.4f}')
                    print('------------')

            # L/C/R are re-read for every file name, so every name counts as one step
            flag.count1 = flag.count1 + 1

## input is online frame cropped from raw Prodigy video
def _lcr_frame_to_rgb_float(frame):
    '''Return one L/R frame as an RGB float32 array scaled to [0, 1] (used as PK input).

    - [H,W,3] input is taken as BGR (cv2 order) and flipped to RGB.
    - [H,W] grayscale input is replicated into 3 identical channels.
    '''
    if frame.ndim == 2:
        rgb = np.repeat(frame[:, :, None], 3, axis=2)
    else:
        rgb = frame[:, :, ::-1]  ## BGR to RGB
    return rgb.astype(np.float32) / 255.0


def _lcr_frame_to_gray_float(frame):
    '''Return one frame as a [H,W] float32 grayscale array scaled to [0, 1] (model input).'''
    if frame.ndim == 3:
        frame = cv2.cvtColor(frame, cv2.COLOR_BGR2GRAY)  ## [H,W,3]->[H,W]
    return frame.astype(np.float32) / 255.0


def _pk_regression_depths(flag):
    '''Return (raw_z1, raw_z3): PK regression depths at the two needle-end cut indices.

    `flag.pk_regression_box_再左右互扣` is indexed relative to `flag.find_train_end_left`;
    the larger cut index (x1 on a tie) is read one entry earlier (index - 1).
    '''
    box = flag.pk_regression_box_再左右互扣
    offset = flag.find_train_end_left
    if flag.x1_cut_idx < flag.x3_cut_idx:
        return box[flag.x1_cut_idx - offset], box[flag.x3_cut_idx - 1 - offset]
    return box[flag.x1_cut_idx - 1 - offset], box[flag.x3_cut_idx - offset]


def inference_LCR_image_with_flag(model, device, lcr_image, flag, save_json=False, save_mask=False, fname=None):
    '''
    Run needle inference + PK depth estimation on one online L/C/R frame triplet.

    - Input list of images L/C/R (left/center/right), BGR [H,W,3] as read by cv2
      ([H,W] grayscale frames are also accepted)
    - Model predicts the needle at C
    - PK on L & R frame based on predicted needle

    Work is only done when `flag.count1 % flag.fps_reduce == 0`; `flag.count1` is
    incremented on every call.

    Args:
        model: segmentation model; returns a mask tensor, or a dict with
            "pred_masks" and "pred_class" (Mask2Former-based models).
        device: torch device the input is moved to.
        lcr_image: [left, center, right] frames. The list is not modified.
        flag: Flag state object, updated in place (found_needle, filename, count1, ...).
        save_json: write the predicted endpoints to json in `flag.save_folder_name`.
        save_mask: save the inference & PK plot to `flag.save_folder_name`.
        fname: file name used for saved outputs. Defaults to "frame_<count1>.jpg" so each
            processed frame gets its own output file.

    Returns:
        The updated `flag`.
    '''
    with torch.no_grad():
        if flag.count1 % flag.fps_reduce == 0:
            if fname is None:
                fname = f'frame_{flag.count1:06d}.jpg'
            fnames = [fname]

            ## RGB float arrays of L & R for PK
            image_array_L = _lcr_frame_to_rgb_float(lcr_image[0])
            image_array_R = _lcr_frame_to_rgb_float(lcr_image[2])

            ## Gray float frames; built in a new list so the caller's frames stay untouched
            gray_frames = [_lcr_frame_to_gray_float(frame) for frame in lcr_image]
            model_size = config["Model"]["image_size"]
            ## Resize on the numpy side with cv2, then wrap as [1,S,S] tensors
            lcr_queue = [
                torch.from_numpy(cv2.resize(gray, (model_size, model_size), interpolation=cv2.INTER_LINEAR)).unsqueeze(0)
                for gray in gray_frames
            ]
            ## NOTE(review): width of the R frame; the mask below is resized to a square
            ## (img_size x img_size), so frames are assumed square — confirm for non-square crops.
            img_size = gray_frames[-1].shape[-1]

            images = torch.stack(lcr_queue, dim=1).to(device)  ## [1, 3, S, S]

            ## get numpy array for PK
            image_array_C = gray_frames[1].squeeze()  ## [H,W]
            image_array_LR = np.concatenate([image_array_L, image_array_R], axis=1)

            # Forward pass----------------------------------------------------
            pred_masks = model(images)
            if isinstance(pred_masks, dict):  ## Mask2Former-based model
                output_dict = pred_masks
                pred_masks = output_dict["pred_masks"]
                ## Classifier says "no needle": blank the mask so nothing is detected
                if output_dict["pred_class"].item() < 0.5:
                    pred_masks.zero_()

            ## Resize back to image size
            pred_mask_resize = pred_masks
            if pred_mask_resize.shape[-1] != img_size:
                pred_mask_resize = v2.functional.resize(pred_mask_resize, (img_size, img_size),
                                                        interpolation=tf.InterpolationMode.BILINEAR, antialias=True)

            pred_mask_tensor = (pred_mask_resize > 0.5).to(dtype=torch.uint8).squeeze()
            pred_mask = pred_mask_tensor.detach().cpu().numpy()

            ## Fewer than 2 foreground pixels cannot define a line
            if (pred_mask_tensor != 0).sum() < 2:
                flag.found_needle = False
                print(f'idx {flag.count1} Needle not found\n------------')
            else:
                flag.found_needle = True
                ## Post process mask into endpoints; flag.needle_location is read right after
                ## (presumably set by min_rect_2_line — confirm)
                pred_coord_label = min_rect_2_line(pred_mask, flag)
                x1, y1, x3, y3 = flag.needle_location

                ## Save json ---------------------------------------------
                if save_json:
                    save_pred_to_json(fnames, pred_coord_label, img_size, save_folder_dir=flag.save_folder_name)

                ## PK to estimate depth (z1 z3) --------------------------
                (OutputImg_PK,
                    OutputImg_pk_regression_先左右互扣,
                    OutputImg_pk_regression_再左右互扣) = PK(flag, image_array_LR.copy(), return_plot=save_mask, set_train_idx_point=True)

                x1y1_regression_depth, x3y3_regression_depth = _pk_regression_depths(flag)
                print(f'idx {flag.count1}')
                print(f'[x1 y1 raw_z1] {x1:.4f}, {y1:.4f}, {x1y1_regression_depth:.4f}')
                print(f'[x3 y3 raw_z3] {x3:.4f}, {y3:.4f}, {x3y3_regression_depth:.4f}')
                print('------------')

                ## Save mask ---------------------------------------------
                if save_mask:
                    flag.filename = os.path.basename(fnames[0]).replace(".jpg", '_pred.png')
                    gt_mask, gt_line = None, None
                    save_pred_to_plot(image_array_C, pred_mask, flag.needle_location, gt_mask, gt_line,
                                      OutputImg_PK, OutputImg_pk_regression_先左右互扣, OutputImg_pk_regression_再左右互扣,
                                      fname=os.path.join(flag.save_folder_name, flag.filename))

        flag.count1 = flag.count1 + 1
    return flag


## Main function

Output example:
```
idx <t>
[x1 y1 raw_z1] 7.5986, 250.5061, -0.2337
[x3 y3 raw_z3] 184.2970, 331.3508, -11.3992
```

Explanation:

In `inference_loader_with_flag()` (and the `inference_LCR*_with_flag()` variants), the model predicts the needle endpoints x1y1 & x3y3 for frame t, and the PK regression estimates their depths (denoted `raw_z1` & `raw_z3`).

`raw_z1` & `raw_z3` are raw (unscaled) values; they can be scaled manually during 3D reconstruction.

In [1]:
## track model
if __name__ == "__main__":
    count = 0
    ## Set the image raw size (fallback: size of the first .jpg in data_dir)
    if not (width > 0 and height > 0):
        for f in os.listdir(data_dir):
            if f.lower().endswith(".jpg"):
                image = Image.open(os.path.join(data_dir, f))
                width, height = image.size
                break
    flag = Flag(cutNum=cutNum, height=height, width=width)

    ## User Settings
    flag.run_only_one = run_only_one
    if save_json or save_mask:
        flag.save_folder_name = data_dir + '_PK'
        os.makedirs(flag.save_folder_name, exist_ok=True)
        print(f'[Path] Results are saved to {flag.save_folder_name}')
    else:
        print('...No extra files are saved')

    ## Reset model ----------------------------------
    frame_count = 0
    model.eval()
    model.compile()  ### torch.jit.script will cause error so just use compile
    model.to(device)
    flag.count1 = 0

    ### Reset Memory for track model -----------------
    if isinstance(model, MemInferenceWrapper):
        model.clear_memory()
        print(f"[Model reset memory] memory_engaged: {model.memory_engaged()}")

    ## Main loop starts
    ## Inference on a folder of frames ---------------------
    if not preprocess_video and not flag.run_only_one:
        video_dataset = UnlabeledDataset(data_dir, transform=valid_transform,
                                        time_window=3,
                                        buffer_num_sample=1)

        ## Test batch size 1
        loader = DataLoader(video_dataset, batch_size=1, shuffle=False, drop_last=False,
                            num_workers=4, persistent_workers=True, pin_memory=True)
        print(f"[Data] video length:{len(video_dataset)}")

        ## Similar to evaluate function but writes points to json file or plots PK result
        # inference_loader_with_flag(model, device, loader, flag, save_json, save_mask)
        # inference_folder_with_flag(model, device, data_dir, flag, save_json, save_mask)
        inference_LCRfolder_with_flag(model, device, data_dir, flag, save_json, save_mask)

    # region ## Inference on 1 image (not verified!)
    # elif flag.run_only_one:
    #     for i in range(0,1):
    #         current_frame = cv2.imread('/content/frames/a0000.jpg')

    #         flag.filename='img0000'
    #         flag, OutputImg_vconcat = get_sinogram_img(flag, current_frame)
    #         cv2.imwrite(os.path.join(flag.save_folder_name, flag.filename + '_vconcat.jpg'), OutputImg_vconcat, [cv2.IMWRITE_JPEG_QUALITY, 95])
    #         flag.run_time()
    # endregion

    ## Inference on raw prodigy video -------------------------------
    else:
        raw_video_dir = config_PK.Data.raw_video_dir
        ## Inference runs in one reused worker thread (https://stackoverflow.com/a/58829816).
        ## Each task is awaited immediately, so frames are still processed strictly in order.
        with concurrent.futures.ThreadPoolExecutor(max_workers=1) as executor:
            ## 3 video paths (left, center, right)
            if isinstance(raw_video_dir, (list, omegaconf.listconfig.ListConfig)):
                video_paths = [raw_video_dir[i] for i in range(3)]
                caps = [cv2.VideoCapture(path) for path in video_paths]
                try:
                    for path, cap in zip(video_paths, caps):
                        if not cap.isOpened():
                            raise IOError(f"Error opening video file {path}")

                    target_fps = 15
                    print(f"[VIDEO] Cut at fps {target_fps}")
                    original_fps = caps[0].get(cv2.CAP_PROP_FPS)
                    ## max(1, ...) avoids modulo-by-zero when the source FPS is unknown (0) or < target/2
                    frame_interval = max(1, int(round(original_fps / target_fps)))

                    while True:
                        reads = [cap.read() for cap in caps]
                        ## Stop as soon as ANY of the three streams runs out
                        if not all(ret for ret, _ in reads):
                            break
                        if frame_count % frame_interval == 0:
                            frame_l, frame_c, frame_r = (frame for _, frame in reads)
                            flag = executor.submit(inference_LCR_image_with_flag,
                                                   model, device, [frame_l, frame_c, frame_r],
                                                   flag, save_json, save_mask).result()
                        frame_count += 1
                finally:
                    for cap in caps:
                        cap.release()
                cv2.destroyAllWindows()

            ## Default for Prodigy video: Single video path, need to crop each frame into 3 frames
            else:
                cap = cv2.VideoCapture(raw_video_dir)
                try:
                    if not cap.isOpened():
                        raise IOError(f"Error opening video file {raw_video_dir}")

                    target_fps = config_PK.User_setting.target_fps
                    print(f"[VIDEO] Cut at fps {target_fps}")
                    original_fps = cap.get(cv2.CAP_PROP_FPS)  ## Prodigy video:30
                    frame_interval = max(1, int(round(original_fps / target_fps)))

                    ## capture frames --------------------------------------
                    while True:
                        ret, frame = cap.read()
                        if not ret:
                            break

                        if frame_count % frame_interval == 0:
                            ## crop to 3 frames using the module-level crop bounds
                            frame_l = frame[y_1:y_2, x_l1:x_l2]
                            frame_c = frame[y_1:y_2, x_c1:x_c2]
                            frame_r = frame[y_1:y_2, x_r1:x_r2]

                            ## TODO remove black edge (and resize)

                            flag = executor.submit(inference_LCR_image_with_flag,
                                                   model, device, [frame_l, frame_c, frame_r],
                                                   flag, save_json, save_mask).result()
                        frame_count += 1
                finally:
                    cap.release()
                cv2.destroyAllWindows()


NameError: name 'width' is not defined

In [None]:
## track model (frames are saved already)
# TORCH_COMPILE_DEBUG=1
# TORCHDYNAMO_VERBOSE=1
if __name__ == "__main__":
    count = 0
    ## Set the image raw size (fallback: size of the first .jpg in data_dir)
    if not (width > 0 and height > 0):
        for f in os.listdir(data_dir):
            if f.lower().endswith(".jpg"):
                image = Image.open(os.path.join(data_dir, f))
                width, height = image.size
                break
    flag = Flag(cutNum=cutNum, height=height, width=width)

    ## @Settings
    flag.run_only_one = run_only_one
    if save_json or save_mask:
        flag.save_folder_name = data_dir + '_PK'
        os.makedirs(flag.save_folder_name, exist_ok=True)
        print(f'[Path] Results are saved to {flag.save_folder_name}')
    else:
        print('...No extra files are saved')

    ## Reset model ----------------------------------
    frame_count = 0
    model.eval()
    model.compile()  ### torch.jit.script will cause error so just use compile
    model.to(device)
    flag.count1 = 0

    ## Reset Memory for track model -----------------
    if isinstance(model, MemInferenceWrapper):
        model.clear_memory()
        print(f"[Model reset memory] memory_engaged: {model.memory_engaged()}")

    ## Main loop starts
    ## Inference on a folder of frames ---------------------
    if not preprocess_video and not flag.run_only_one:
        video_dataset = UnlabeledDataset(data_dir, transform=valid_transform,
                                        time_window=3,
                                        buffer_num_sample=1)

        ## Test batch size 1
        loader = DataLoader(video_dataset, batch_size=1, shuffle=False, drop_last=False,
                            num_workers=4, persistent_workers=True, pin_memory=True)
        print(f"[Data] video length:{len(video_dataset)}")

        ## Similar to evaluate function but writes points to json file or plots PK result
        # inference_loader_with_flag(model, device, loader, flag, save_json, save_mask)
        # inference_folder_with_flag(model, device, data_dir, flag, save_json, save_mask)
        inference_LCRfolder_with_flag(model, device, data_dir, flag, save_json, save_mask)

    # region ## Inference on 1 image (not verified!)
    # elif flag.run_only_one:
    #     for i in range(0,1):
    #         current_frame = cv2.imread('/content/frames/a0000.jpg')

    #         flag.filename='img0000'
    #         flag, OutputImg_vconcat = get_sinogram_img(flag, current_frame)
    #         cv2.imwrite(os.path.join(flag.save_folder_name, flag.filename + '_vconcat.jpg'), OutputImg_vconcat, [cv2.IMWRITE_JPEG_QUALITY, 95])
    #         flag.run_time()
    # endregion

    ## Inference on raw video -------------------------------
    else:
        raw_video_dir = config_PK.Data.raw_video_dir
        if not isinstance(raw_video_dir, (list, omegaconf.listconfig.ListConfig)):
            ## single video, would need to be cropped into 3 videos
            raise NotImplementedError("Single raw video input is not supported here; "
                                      "provide 3 video paths (left, center, right)")

        ## 3 video paths (left center right): one folder per video, named after the file without extension
        for video_file_path in raw_video_dir:
            os.makedirs(os.path.splitext(video_file_path)[0], exist_ok=True)

        video_paths = [raw_video_dir[i] for i in range(3)]
        caps = [cv2.VideoCapture(path) for path in video_paths]
        try:
            for path, cap in zip(video_paths, caps):
                if not cap.isOpened():
                    raise IOError(f"Error opening video file {path}")

            target_fps = 15
            print(f"[VIDEO] Cut at fps {target_fps}")
            original_fps = caps[0].get(cv2.CAP_PROP_FPS)
            ## max(1, ...) avoids modulo-by-zero when the source FPS is unknown (0) or < target/2
            frame_interval = max(1, int(round(original_fps / target_fps)))

            while True:
                reads = [cap.read() for cap in caps]
                ## Stop as soon as ANY of the three streams runs out
                if not all(ret for ret, _ in reads):
                    break
                if frame_count % frame_interval == 0:
                    ## Frames are passed as read (BGR): inference_LCR_image_with_flag does its own
                    ## BGR->RGB (for PK) and BGR->gray (for the model) conversion.
                    frame_l, frame_c, frame_r = (frame for _, frame in reads)
                    flag = inference_LCR_image_with_flag(model, device, [frame_l, frame_c, frame_r],
                                                         flag, save_json, save_mask)
                frame_count += 1
        finally:
            for cap in caps:
                cap.release()
        cv2.destroyAllWindows()

## 0.3742913  -1.4015927
##idx 10
# [x1 y1 raw_z1] 520.0276, 513.0443, -2.4962
# [x3 y3 raw_z3] 737.9263, 377.2767, -3.4129

...No extra files are saved
[Data] video length:0


  return _VF.meshgrid(tensors, **kwargs)  # type: ignore[attr-defined]
If it is a Python builtin, please file an issue on GitHub so the PyTorch team can add support for it and see the next case for a workaround.
If it is a third-party C/C++ Python extension, please either wrap it into a PyTorch-understood custom operator (see https://pytorch.org/tutorials/advanced/custom_ops_landing_page.html for more details) or, if it is traceable, use `torch.compiler.allow_in_graph`.
  torch._dynamo.utils.warn_once(explanation + "\n" + "\n".join(hints))


idx 0 Needle not found
------------
idx 1 Needle not found
------------
idx 2 Needle not found
------------
idx 3 Needle not found
------------
idx 4 Needle not found
------------
idx 5 Needle not found
------------
idx 6 Needle not found
------------
idx 7 Needle not found
------------
idx 8 Needle not found
------------
idx 9 Needle not found
------------
idx 10
[x1 y1 raw_z1] 520.0276, 513.0443, -2.4819
[x3 y3 raw_z3] 737.9263, 377.2767, -3.4153
------------
idx 11
[x1 y1 raw_z1] 586.3554, 468.4446, -3.0919
[x3 y3 raw_z3] 743.9955, 366.9930, -5.1439
------------
idx 12
[x1 y1 raw_z1] 555.5448, 480.2503, -0.9539
[x3 y3 raw_z3] 833.6484, 313.7425, -4.0454
------------
idx 13
[x1 y1 raw_z1] 545.6659, 483.4122, 0.5394
[x3 y3 raw_z3] 902.9113, 280.3625, -5.1032
------------
idx 14
[x1 y1 raw_z1] 487.7260, 526.9653, -2.5146
[x3 y3 raw_z3] 811.1499, 324.4329, -2.4435
------------
idx 15
[x1 y1 raw_z1] 469.0780, 523.6935, -3.6492
[x3 y3 raw_z3] 916.7518, 263.4539, 1.1273
------------
idx 16

KeyboardInterrupt: 

In [10]:
## multi experiment to check time
EXP = 10
sum_time = 0
import time
## track model
if __name__ == "__main__":
    for _ in range(EXP):
        start_time = time.perf_counter()
        count = 0
        ## Set the image raw size (fallback: size of the first .jpg in data_dir)
        if not (width > 0 and height > 0):
            for f in os.listdir(data_dir):
                if f.lower().endswith(".jpg"):
                    image = Image.open(os.path.join(data_dir, f))
                    width, height = image.size
                    break
        flag = Flag(cutNum=cutNum, height=height, width=width)

        ## @Settings
        flag.run_only_one = run_only_one
        if save_json or save_mask:
            flag.save_folder_name = data_dir + '_PK'
            os.makedirs(flag.save_folder_name, exist_ok=True)
            print(f'[Path] Results are saved to {flag.save_folder_name}')
        else:
            print('...No extra files are saved')

        ## Start every experiment from the same state (as the main cells do) so runs are comparable:
        ## frame counter at 0 and an empty tracking memory.
        flag.count1 = 0
        if isinstance(model, MemInferenceWrapper):
            model.clear_memory()

        ## Main loop starts
        ## Inference on a folder of frames
        if not preprocess_video and not flag.run_only_one:
            if not os.path.exists(data_dir):
                raise FileNotFoundError(f"data_dir not found: {data_dir}")
            video_dataset = UnlabeledDataset(data_dir, transform=valid_transform,
                                            time_window=3,
                                            buffer_num_sample=1)

            ## Test batch size 1
            loader = DataLoader(video_dataset, batch_size=1, shuffle=False, drop_last=False,
                                num_workers=4, persistent_workers=True, pin_memory=True)
            print(f"[Data] video length:{len(video_dataset)}")

            ## Similar to evaluate function but writes points to json file or plots PK result
            inference_loader_with_flag(model, device, loader, flag, save_json, save_mask)
            # inference_folder_with_flag(model, device, data_dir, flag, save_json, save_mask)
            # inference_LCRfolder_with_flag(model, device, data_dir, flag, save_json, save_mask)
        else:
            ## Otherwise the measured time would silently be ~0
            print('[Warning] Folder inference is disabled (preprocess_video / run_only_one); nothing is timed')

        end_time = time.perf_counter()
        sum_time += (end_time - start_time)

print(f'[Avg time of {EXP} expirements] {(sum_time/EXP) :.3f} sec', )

...No extra files are saved
[Data] video length:638


  return _VF.meshgrid(tensors, **kwargs)  # type: ignore[attr-defined]
If it is a Python builtin, please file an issue on GitHub so the PyTorch team can add support for it and see the next case for a workaround.
If it is a third-party C/C++ Python extension, please either wrap it into a PyTorch-understood custom operator (see https://pytorch.org/tutorials/advanced/custom_ops_landing_page.html for more details) or, if it is traceable, use `torch.compiler.allow_in_graph`.
  torch._dynamo.utils.warn_once(explanation + "\n" + "\n".join(hints))


idx 0 Needle not found
------------
idx 1 Needle not found
------------
idx 2 Needle not found
------------
idx 3 Needle not found
------------
idx 4 Needle not found
------------
idx 5 Needle not found
------------
idx 6 Needle not found
------------
idx 7
[x1 y1 raw_z1] 264.9110, 227.9296, 7.2728
[x3 y3 raw_z3] 436.1849, 278.7472, 1.0620
------------
idx 8
[x1 y1 raw_z1] 197.9044, 214.7325, 1.5551
[x3 y3 raw_z3] 476.5899, 291.8687, 5.4290
------------
idx 9 Needle not found
------------
idx 10
[x1 y1 raw_z1] 13.4763, 181.8716, 0.8038
[x3 y3 raw_z3] 529.0816, 326.1342, 8.7666
------------
idx 11
[x1 y1 raw_z1] 4.5746, 178.1512, 2.4627
[x3 y3 raw_z3] 767.4755, 415.2524, 13.9871
------------
idx 12
[x1 y1 raw_z1] 0.7733, 175.7189, 4.6175
[x3 y3 raw_z3] 840.8597, 460.5142, 11.3680
------------
idx 13
[x1 y1 raw_z1] 4.0711, 177.8661, 2.8958
[x3 y3 raw_z3] 874.2724, 475.2766, 7.8651
------------
idx 14
[x1 y1 raw_z1] 2.6121, 178.1589, 4.3357
[x3 y3 raw_z3] 871.9572, 469.1530, 0.9856
------

### Process Utilization Memo
- Sonosite 638 frames, convnext m2f
    - pure inference +parallel_backend in mask2line: 74.9s
    - inf+PK w/o plot:
        * power(最佳效能): 89.7s
        * +parallel_backend in regression: 83.8s (critical)
        * min_rect_2_line & remove duplicate .cpu(): 65.7s (critical)
        * ''  + dynamic tanh: 69.7s (no help)
        * model.compile(): 60.3s (critical)
            * mode="reduce-overhead":96.0s
            * mode="max-autotune":fail

        * battery(平衡) +parallel_backend in regression & mask2line: 103.7s

        * place data under local demo folder: not helpful
        * Direct Image.open without Dataset: 109.9s

    - inf w/o PK & plot: 51.8s
    - 10 main function avg time:
        * min_rect_2_line & remove duplicate .cpu(): 59.75s
        * ''  + dynamic tanh: 59.83s
        * model.compile(): 51.1s

- Prodigy video (1327 frames)
    - online capture 3 videos: 175.6s
    - online capture 3 videos, compile(): 163.7s (critical)
    - online capture 3 videos, compile(), no resize: 412.0s (poor)
    - online capture 3 videos, compile(), cv2 resize instead of torch: 167.2s
    - online capture 3 videos, compile(), remove duplicate gray2rgb: 122.3s (critical!!!)
        - TIP: convert BGR → RGB directly in the inference function instead of going BGR → gray → RGB
        - now only applied in inference_LCR_image_with_flag()
    - online capture 3 videos, compile(), remove duplicate gray2rgb,ThreadPoolExecutor: 105.5s(critical!!!)
        - TODO: need to check memory sequence if using mem_m2f (tracking model)
    

In [None]:
!nvcc -V
# !nvidia-smi
import torch, os
torch.__version__
print("is_available", torch.cuda.is_available())
# print(torch.cuda.get_arch_list())
# print(os.environ.get('TORCH_CUDA_ARCH_LIST'))
# print(torch.__path__ )
print(torch.version.cuda)

nvcc: NVIDIA (R) Cuda compiler driver
Copyright (c) 2005-2025 NVIDIA Corporation
Built on Fri_Feb_21_20:23:50_PST_2025
Cuda compilation tools, release 12.8, V12.8.93
Build cuda_12.8.r12.8/compiler.35583870_0
True
12.8


## Source Code Archive
Archived for reference only; there is no need to run the code below.

In [None]:
### video to frames (archived; uses `flag` and `get_sinogram_img` from earlier cells)
cap = cv2.VideoCapture(config_PK.Data.raw_video_dir)
fps = cap.get(cv2.CAP_PROP_FPS)
count = 0  ## number of frames read

try:
    ## Upper bound on the number of 100-frame chunks; the loop stops once the video ends.
    for j in range(0, 100):
        flag.count1 = 100 * j  ## every 100 frames are saved as one chunk/folder
        print("j=", j)
        video_ended = False
        processed = 0
        ## 100 * fps_reduce reads per chunk (the frame span at the source 30 fps)
        for _ in range(100 * flag.fps_reduce):
            ret, current_frame = cap.read()
            if not ret:
                video_ended = True
                break
            count = count + 1
            flag.height, flag.width, __ = current_frame.shape
            ## Frame skipping is disabled: every frame is processed (same as flag.fps_reduce = 1).
            ## To skip frames, process only when `count % flag.fps_reduce == 1`.
            print("count1=", flag.count1)
            flag.filename = 'img{:04}'.format(flag.count1)
            flag, OutputImg_vconcat = get_sinogram_img(flag, current_frame)
            cv2.imwrite('/content/dataset_auto/' + flag.filename + '_vconcat.jpg', OutputImg_vconcat, [cv2.IMWRITE_JPEG_QUALITY, 95])
            flag.count1 = flag.count1 + 1
            processed += 1
        ## Save once per chunk (3D processing of 100 iterations takes hours). Each save
        ## appends the chunk number to the folder name; disable this call when re-running
        ## with changed parameters if the old results must not be overwritten.
        if processed:
            flag.save_func(j)
        if video_ended:
            break
finally:
    cap.release()


In [11]:
## Throughput check: 1327 matches the Prodigy video frame count in the memo above;
## 87.2 is presumably a measured runtime in seconds (confirm) -> frames per second.
1327/87.2

15.21788990825688