In [2]:
import os
import json
import glob
import shutil
from math import dist
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
import cv2
try:
    from mmdet.apis import inference_detector, init_detector
    has_mmdet = True
except (ImportError, ModuleNotFoundError):
    has_mmdet = False

from mmpose.apis import inference_topdown
from mmpose.apis import init_model as init_pose_estimator
from mmpose.evaluation.functional import nms
from mmpose.registry import VISUALIZERS
from mmpose.structures import merge_data_samples, split_instances
from mmpose.utils import adapt_mmdet_pipeline
import sys
sys.path.append(os.path.abspath("../code_base"))
from _info_ import ear_types, degrees, acupoints_name, cm, rotation_angles
from _common_ import angles, split_xy_xyv
from prediction import pred_csv
# from ..code_base._info_ import ear_types, degrees, acupoints_name, cm, rotation_angles
# from ..code_base._common_ import angles, split_xy_xyv
# from ..code_base.prediction import pred_csv

In [7]:
class rotation_pred_csv(pred_csv):
    """Predict ear keypoints on rotated images and store them as CSV files.

    The parent ``pred_csv`` is initialised with the ``"MAT_inpainting"`` data
    folder. The attributes ``data_folder``, ``kpt_cfg``, ``det_cfg`` and
    ``det_ckpt`` read below are presumably set by ``pred_csv`` -- TODO confirm.

    For every ear type, every trained model directory under
    ``../<data_folder>/<ear_type>/model_save/<name>``, every degree and every
    rotation angle, the images in
    ``<rotation_img>/<ear_type>/<name>/<deg>/<angle>/img`` are run through an
    optional detector plus the top-down pose estimator. The keypoints are
    written to ``<rotation_csv>/<ear_type>/<name>/<deg>/<angle>/pred.csv``.
    """

    def __init__(self):
        super().__init__("MAT_inpainting")
        # Root folder holding the rotated input images.
        self.rotation_img = "../rotation_img"
        # Root folder receiving one pred.csv per (ear_type, name, deg, angle).
        self.rotation_csv = "../rotation_csv"

    def read_csv(self, fpath1):
        """Load a CSV (e.g. a pred.csv written by ``generate``) using column 0 as index."""
        return pd.read_csv(fpath1, index_col=0)

    def process_one_image(self, img, detector, pose_estimator,
                          det_cat_id=0, bbox_thr=0.3, nms_thr=0.3):
        """Predict keypoints for one image and return those of the first instance.

        Args:
            img: Image passed to ``inference_detector`` / ``inference_topdown``.
                ``generate`` passes a file path.
            detector: An mmdet detector, or ``None`` to skip detection. With
                ``None``, ``inference_topdown`` receives ``bboxes=None``.
            pose_estimator: An mmpose top-down pose model.
            det_cat_id: Detector label kept as the target object (default 0).
            bbox_thr: Minimum detection score kept (strictly greater than).
            nms_thr: IoU threshold used by ``nms``.

        Returns:
            The ``"keypoints"`` entry of the first predicted instance.
            ``generate`` unpacks each element as an ``(x, y)`` pair.

        Raises:
            RuntimeError: if the pose estimator returns no instance at all.
        """
        bboxes = None
        if detector is not None:
            det_result = inference_detector(detector, img)
            pred_instance = det_result.pred_instances.cpu().numpy()
            # Rows are [x1, y1, x2, y2, score] so nms can rank by score.
            bboxes = np.concatenate(
                (pred_instance.bboxes, pred_instance.scores[:, None]), axis=1)
            keep = np.logical_and(pred_instance.labels == det_cat_id,
                                  pred_instance.scores > bbox_thr)
            bboxes = bboxes[keep]
            # Drop the score column after NMS; inference_topdown wants xyxy boxes.
            bboxes = bboxes[nms(bboxes, nms_thr), :4]

        pose_results = inference_topdown(pose_estimator, img, bboxes)
        data_samples = merge_data_samples(pose_results)
        pred_instances = data_samples.get('pred_instances', None)
        pred_instances_list = split_instances(pred_instances)
        if not pred_instances_list:
            raise RuntimeError(
                "pose estimator returned no instance for the given image")
        return pred_instances_list[0]["keypoints"]

    @staticmethod
    def _find_checkpoint(model_dir):
        """Return the ``best*.pth`` checkpoint inside ``model_dir``.

        When several files match, the lexicographically first is used. This is
        deterministic, unlike raw ``glob`` order.

        Raises:
            FileNotFoundError: if no ``best*.pth`` file exists in ``model_dir``.
        """
        ckpts = sorted(glob.glob(os.path.join(model_dir, "best*.pth")))
        if not ckpts:
            raise FileNotFoundError(
                f"no best*.pth checkpoint found in {model_dir}")
        return ckpts[0]

    def _predict_folder(self, ear_type, name, deg, angle,
                        detector, pose_estimator):
        """Predict every image of one rotation folder and write its pred.csv.

        Output: one row per image (in numeric filename order) and one column
        per keypoint index. Each cell holds the string ``"x,y"``.
        """
        rel = (ear_type, name, deg, str(angle))
        img_dir = os.path.join(self.rotation_img, *rel, "img")

        # Images are named by an integer index ("0.png", "1.png", ...). Skip
        # anything else (e.g. .DS_Store, Thumbs.db) instead of crashing in int().
        img_files = sorted(
            (f for f in os.listdir(img_dir)
             if os.path.splitext(f)[0].isdecimal()),
            key=lambda f: int(os.path.splitext(f)[0]))

        # Collect rows first and build the frame once; concat-in-a-loop is quadratic.
        rows = []
        for img_file in img_files:
            kpts = self.process_one_image(
                os.path.join(img_dir, img_file), detector, pose_estimator)
            rows.append({i: str(x) + "," + str(y)
                         for i, (x, y) in enumerate(kpts)})

        out_dir = os.path.join(self.rotation_csv, *rel)
        os.makedirs(out_dir, exist_ok=True)
        pd.DataFrame(rows).to_csv(os.path.join(out_dir, "pred.csv"))

    def generate(self, has_detector=True, device="cuda:0"):
        """Run pose prediction over all ear types, models, degrees and angles.

        Args:
            has_detector: If true, build the mmdet detector from ``det_cfg`` /
                ``det_ckpt`` and crop with its boxes. Otherwise pass
                ``bboxes=None`` to the pose estimator.
            device: Torch device used for both detector and pose model.

        Raises:
            ImportError: if ``has_detector`` is true but mmdet is unavailable.
            FileNotFoundError: if a model directory has no ``best*.pth``.
        """
        if has_detector:
            if not has_mmdet:
                raise ImportError(
                    "mmdet is not installed; call generate(has_detector=False) "
                    "to run pose estimation without a detector")
            detector = init_detector(self.det_cfg, self.det_ckpt, device=device)
            detector.cfg = adapt_mmdet_pipeline(detector.cfg)
        else:
            detector = None

        for ear_type in ear_types:
            # The pose config depends only on the ear type, so resolve it once here.
            rtmpose_cfg = self.kpt_cfg.format(et=ear_type)
            model_root = os.path.join("..", self.data_folder, ear_type, "model_save")
            # Only sub-directories hold checkpoints; stray files would break the glob.
            names = sorted(
                n for n in os.listdir(model_root)
                if os.path.isdir(os.path.join(model_root, n)))
            for name in names:
                rtmpose_ckp = self._find_checkpoint(os.path.join(model_root, name))
                pose_estimator = init_pose_estimator(
                    rtmpose_cfg,
                    rtmpose_ckp,
                    device=device,
                    cfg_options=dict(
                        model=dict(test_cfg=dict(output_heatmaps=False))))
                for deg in degrees:
                    for angle in rotation_angles:
                        self._predict_folder(ear_type, name, deg, angle,
                                             detector, pose_estimator)
    


                    

In [8]:
# Build the rotation-prediction generator. Its parent pred_csv is initialised
# with the "MAT_inpainting" data folder; images are read from ../rotation_img
# and CSVs are written under ../rotation_csv.
A = rotation_pred_csv()

In [9]:
# Run detector + pose inference over every ear type / model / degree / rotation
# angle and write one pred.csv per folder (detector enabled by default, models
# loaded on cuda:0). The checkpoint-loading log below is this cell's output.
A.generate()

Loads checkpoint by local backend from path: ../mmdetection/work_dirs/rtmdet_nano_320-8xb32_coco-ear/epoch_120.pth
Loads checkpoint by local backend from path: ..\MAT_inpainting\free\model_save\cather\best_EPE_epoch_20.pth
Loads checkpoint by local backend from path: ..\MAT_inpainting\free\model_save\daniel\best_EPE_epoch_30.pth
Loads checkpoint by local backend from path: ..\MAT_inpainting\free\model_save\dominic\best_EPE_epoch_10.pth
Loads checkpoint by local backend from path: ..\MAT_inpainting\free\model_save\jack\best_EPE_epoch_10.pth
Loads checkpoint by local backend from path: ..\MAT_inpainting\free\model_save\jakaria\best_EPE_epoch_30.pth
Loads checkpoint by local backend from path: ..\MAT_inpainting\free\model_save\jimmy\best_EPE_epoch_10.pth
Loads checkpoint by local backend from path: ..\MAT_inpainting\free\model_save\vicky\best_EPE_epoch_10.pth
Loads checkpoint by local backend from path: ..\MAT_inpainting\free\model_save\wayne\best_EPE_epoch_40.pth
Loads checkpoint by loca