In [14]:
# This cell performs a single preprocessing step:
# 1. undistort the RGB (color) and depth images

import numpy as np
import os
import cv2
import matplotlib.pyplot as plt
import tqdm

%matplotlib inline

os.chdir("/home/data/workspace/heqi/monogastroendo/")

# Load the camera matrix and distortion coefficients that `preprocessor`
# hands to cv2.undistort. Validate with an explicit raise rather than `assert`
# (asserts are stripped under `python -O`), and check BOTH files up front so a
# missing one fails with a clear message before any processing starts.
for _calib_path in ("c3vd_data/matrix.npy", "c3vd_data/distortion.npy"):
    if not os.path.isfile(_calib_path):
        raise FileNotFoundError(f"calibration file not found: {_calib_path}")
matrix = np.load("c3vd_data/matrix.npy")
distortion = np.load("c3vd_data/distortion.npy")

class preprocessor():
    """Undistort images with a fixed camera calibration.

    Parameters
    ----------
    matrix : array-like
        Camera matrix passed to ``cv2.undistort`` (presumably 3x3 intrinsics —
        TODO confirm against ``c3vd_data/matrix.npy``).
    distortion : array-like
        Distortion coefficients passed to ``cv2.undistort`` as ``distCoeffs``.
    """

    def __init__(self, matrix, distortion) -> None:
        self.matrix = matrix
        self.distortion = distortion

    def run(self, img):
        """Return an undistorted copy of *img* using this instance's calibration.

        Raises ``cv2.error`` if OpenCV rejects the input (e.g. ``img`` is None).
        """
        # Use the parameters stored on the instance, not module-level globals,
        # so every preprocessor honours the calibration it was constructed with.
        img_undistort = cv2.undistort(img, self.matrix, distCoeffs=self.distortion)
        return img_undistort

def test():
    """Visual sanity check of the undistortion on a single checkerboard frame.

    Loads frame ``0.tiff`` of the checkerboard capture, displays the
    undistorted image with matplotlib and prints the raw image's shape.
    """
    undistorter = preprocessor(matrix, distortion)
    sample_path = "c3vd_data/cfhq190l_10x10mm_checkerboard_images/frames/0.tiff"
    frame = cv2.imread(sample_path)
    corrected = undistorter.run(frame)
    plt.imshow(corrected)
    print(frame.shape)

def read_from_root(root):
    """Collect all color/depth frames found below *root*.

    Walks *root* recursively and returns a list of ``[directory, filename]``
    pairs, in ``os.walk`` order, for every file whose name contains
    ``'color.png'`` or ``'depth.tiff'``.
    """
    wanted = ('color.png', 'depth.tiff')
    return [
        [dirpath, name]
        for dirpath, _subdirs, filenames in os.walk(root)
        for name in filenames
        if any(tag in name for tag in wanted)
    ]

def save_undistort_to_root(src):
    """Undistort each image in *src* and write it into a ``rect_``-prefixed tree.

    Parameters
    ----------
    src : iterable of (directory, filename) pairs
        E.g. the output of ``read_from_root``. Each image is read from
        ``os.path.join(directory, filename)`` and written to
        ``os.path.join("rect_" + directory, filename)``.

    Images whose destination file already exists are skipped, so an
    interrupted run can be resumed. Per-image failures (unreadable input,
    OpenCV errors, failed writes) are printed together with the directory and
    filename, and processing continues with the next image.
    """
    pp = preprocessor(matrix, distortion)
    for src_dir, file in tqdm.tqdm(src):
        des = "rect_" + src_dir
        os.makedirs(des, exist_ok=True)
        des_path = os.path.join(des, file)
        # if the undistorted file is already there, skip it (makes runs resumable)
        if os.path.exists(des_path):
            continue
        # -1 == cv2.IMREAD_UNCHANGED: keep the stored bit depth / channel count
        # (e.g. for the depth .tiff files) instead of converting to 8-bit BGR.
        img = cv2.imread(os.path.join(src_dir, file), -1)
        if img is None:
            print(src_dir, file, "(could not read image)")
            continue
        try:
            img = pp.run(img)
            written = cv2.imwrite(des_path, img)
        except cv2.error as err:
            # Best effort: report the failing image and keep going.
            print(src_dir, file, err)
            continue
        # cv2.imwrite signals most failures by returning False, not by raising.
        if not written:
            print(src_dir, file, "(could not write image)")
        

import threading
class MyThread(threading.Thread):
    """Worker thread that undistorts one chunk of ``(directory, filename)`` pairs.

    Starting the thread calls ``save_undistort_to_root`` on the chunk given
    to the constructor.
    """

    def __init__(self, img_list):
        super().__init__()
        # the slice of image entries this worker is responsible for
        self.img_list = img_list

    def run(self):
        save_undistort_to_root(self.img_list)

def multi_thread_save_undistort_to_root(src, n_thr=3):
    """Undistort *src* in parallel with *n_thr* ``MyThread`` workers.

    *src* is cut into ``n_thr`` contiguous chunks of ``ceil(len(src) / n_thr)``
    entries each (the trailing chunks may be shorter or empty). Each chunk is
    handled by its own thread. The call returns once every worker has finished.
    """
    # integer ceiling division: ceil(len(src) / n_thr)
    chunk = -(-len(src) // n_thr)
    workers = [MyThread(src[k * chunk:(k + 1) * chunk]) for k in range(n_thr)]
    for worker in workers:
        worker.start()
    for worker in workers:
        worker.join()

# Undistort every color/depth frame found under c3vd_data/ and write the
# results into the mirrored "rect_c3vd_data/..." tree using 3 worker threads.
# Alternatives: `test()` for a quick visual check of the calibration, or
# `save_undistort_to_root(src=img_list)` for a single-threaded run.
img_list = read_from_root("c3vd_data/")
multi_thread_save_undistort_to_root(src=img_list, n_thr=3)


  0%|          | 0/6744 [00:00<?, ?it/s]
 24%|██▍       | 1617/6744 [00:00<00:00, 16162.12it/s]
100%|██████████| 6742/6742 [00:00<00:00, 37537.63it/s]
100%|██████████| 6744/6744 [00:00<00:00, 36507.65it/s]
100%|██████████| 6744/6744 [00:00<00:00, 30356.62it/s]
