In [1]:
import torch
import torch.nn as nn

import albumentations as A
from albumentations.pytorch import ToTensorV2

import tqdm
import glob
import mmcv
import numpy as np
from accelerate import Accelerator
from mmdet.structures import DetDataSample
from mmengine.structures import InstanceData
from mmdet.models.test_time_augs.merge_augs import merge_aug_results

from utils import modules, custom_dataset

import json
from ensemble_boxes import weighted_boxes_fusion

In [2]:
# Create the accelerate helper with mixed precision explicitly disabled and
# remember the device it selected.
accelerator = Accelerator(mixed_precision=None)
device = accelerator.device

# Select PyTorch's 'high' float32 matmul precision mode.
# NOTE(review): this is a process-wide setting; see the torch documentation for
# the exact precision/speed trade-off on the target hardware.
torch.set_float32_matmul_precision('high')

In [3]:
# ---- Data loading ----
# The TTA class only reads the first prediction of every batch, so keep this at 1.
BATCH_SIZE = 1
# (sic) The misspelt name is kept because the DataLoader cell refers to it.
NUM_WORKDERS = 1
# Directory that is globbed for *.jpg test images.
IMG_PREFIX = "data/test2/"

# ---- Pseudo labelling ----
# Fused detections scoring strictly above this value become pseudo annotations.
# (sic) The misspelt name is kept because the fusion cell refers to it.
PSEUDO_THREHSOLD = 0.6

# ---- Model ----
MODEL_CFG_CONFIG = "config.py"
MODEL_NAME = "RTMDet"
model_path = "work_dir/RTMDet_model.pt"

# ---- Test-time augmentations, one full inference pass each ----
# None = original image, "hor"/"ver" = horizontal/vertical flip,
# float = rescale factor, 90 = rotation via np.rot90.
tta_type = [None, "hor", 1.3, 0.8, "ver", 90]

In [4]:
class CustomDatset(torch.utils.data.Dataset):
    """Image-only dataset that can apply one switchable test-time augmentation.

    The augmentation is chosen through the mutable ``type`` attribute:
    ``None`` (no change), ``"hor"`` / ``"ver"`` (flips), a ``float`` (rescale
    factor) or ``90`` (``np.rot90`` with k=1).
    """

    # Final step for every sample: A.Normalize with its library-default
    # parameters followed by conversion to a tensor.
    after_transform = A.Compose([A.Normalize(), ToTensorV2()])

    # Deterministic flips (always applied when selected).
    hor = A.Compose([A.HorizontalFlip(p=1.0)])
    ver = A.Compose([A.VerticalFlip(p=1.0)])

    def __init__(
        self,
        img_prefix="data/test1/",
        fill_pad_factor=32,
        fill_pad_value=255
    ):
        """Collect the image paths and store the padding settings.

        Args:
            img_prefix (str): Root directory of the images; every ``*.jpg``
                directly inside it is used, in sorted order.
            fill_pad_factor (int): Height and width are padded up to the next
                multiple of this value.
            fill_pad_value (int): Value written into the padded area.
        """
        self.images = sorted(glob.glob(f"{img_prefix}/*.jpg"))
        self.len = len(self.images)

        self.fill_pad_factor = fill_pad_factor
        self.fill_pad_value = fill_pad_value

        # Currently active augmentation; None means "leave the image as is".
        self.type = None

    def __getitem__(self, idx):
        """Read image ``idx`` as RGB and run it through :meth:`pipeline`."""
        image = mmcv.imread(self.images[idx], channel_order="rgb")
        return self.pipeline(image)

    def pipeline(self, img, image_id=None):
        """Augment, pad and normalise one image and build its data sample.

        Args:
            img: Image array indexed as (height, width, channel).
            image_id: Optional identifier stored in the meta information.

        Returns:
            tuple: ``(tensor, data_sample)`` where ``data_sample`` is a
            ``DetDataSample`` carrying the shape / scale meta information.

        Raises:
            ValueError: If ``self.type`` is not one of the supported modes.
        """
        org_h, org_w, c = img.shape
        h, w = org_h, org_w
        scale_factor = (1., 1.)
        mode = self.type

        if mode is None:
            # No augmentation requested.
            pass
        elif mode == "hor":
            img = self.hor(image=img)["image"]
        elif mode == "ver":
            img = self.ver(image=img)["image"]
        elif isinstance(mode, float):
            # Rescale by the given factor; the scale factor is derived from
            # the truncated integer size that was actually produced.
            resized = A.Resize(int(h * mode), int(w * mode))(image=img)["image"]
            new_h, new_w, c = resized.shape
            scale_factor = (new_w / w, new_h / h)
            img, h, w = resized, new_h, new_w
        elif mode == 90:
            # Rotation swaps height and width.
            img = np.rot90(img, 1, (0, 1))
            h, w, c = img.shape
        else:
            raise ValueError(f"不能解释的类型：{self.type}")

        # Round each side up to the next multiple of the padding factor.
        factor = self.fill_pad_factor
        pad_h = h if h % factor == 0 else (h // factor + 1) * factor
        pad_w = w if w % factor == 0 else (w // factor + 1) * factor
        if (pad_h, pad_w) != (h, w):
            # The image stays in the top-left corner; the rest is fill value.
            canvas = np.full((pad_h, pad_w, c), self.fill_pad_value, dtype=img.dtype)
            canvas[:h, :w] = img
            img = canvas

        tensor = self.after_transform(image=img)["image"]

        data_sample = DetDataSample()
        data_sample.set_metainfo(
            dict(
                org_shape=(org_h, org_w),
                img_shape=(h, w),
                pad_shape=(pad_h, pad_w),
                scale_factor=scale_factor,
                image_id=image_id,
                keep_ratio=True,
                # Flips are never reported here; the TTA class maps flipped
                # boxes back itself.
                flip=False,
                flip_direction="y",
            )
        )

        return tensor, data_sample

    def __len__(self):
        """Number of images found at construction time."""
        return self.len

In [5]:
class TTA:
    """Runs the model over a dataloader under one test-time augmentation.

    The augmentation is selected by setting ``dataloader.dataset.type``; the
    predicted boxes are then mapped back into the coordinate frame of the
    original image.
    """

    def __init__(self, model, dataloader):
        """Store the model (switched to eval mode) and the dataloader.

        Intended for a batch size of 1: only the first prediction of every
        batch is used.
        """
        self.model = model
        self.model.eval()
        self.dataloader = dataloader

    @torch.no_grad()
    def __call__(self, types=None):
        """Predict every image of the dataloader under one augmentation.

        Args:
            types: Image transform to apply.
                None: original image, no augmentation.
                float: rescale factor.
                "hor": horizontal flip.
                "ver": vertical flip.
                90: rotation as produced by ``np.rot90(img, 1, (0, 1))``.

        Returns:
            list: One ``pred_instances`` object per batch, with boxes in
            original-image coordinates and ``img_shape`` metainfo set to the
            original ``(height, width)``.

        Raises:
            ValueError: If ``types`` is not one of the supported values.
        """
        results = []
        self.dataloader.dataset.type = types
        try:
            for batch in tqdm.tqdm(self.dataloader):
                out = self.model(batch, mode="predict")[0]
                pred = out.pred_instances
                # Height / width of the original, un-augmented image.
                h, w = batch["data_samples"][0].org_shape
                img_meta = dict(
                    img_shape=(h, w),
                )
                if isinstance(types, float) or types is None:
                    # Rescaling: per the original author's note, the mmdet
                    # predict path already maps boxes back to original scale.
                    pass
                elif types == "hor":
                    # Undo the horizontal flip: mirror x and swap x1/x2.
                    bbox = pred.bboxes.clone()
                    pred.bboxes[:, [0, 2]] = w - bbox[:, [2, 0]]
                elif types == "ver":
                    # Undo the vertical flip: mirror y and swap y1/y2.
                    bbox = pred.bboxes.clone()
                    pred.bboxes[:, [1, 3]] = h - bbox[:, [3, 1]]
                elif types == 90:
                    # Undo the rotation: the rotated y axis maps to the
                    # mirrored original x axis, the rotated x axis maps to
                    # the original y axis.
                    bbox = pred.bboxes.clone()
                    pred.bboxes[:, 0] = w - bbox[:, 3]
                    pred.bboxes[:, 1] = bbox[:, 0]
                    pred.bboxes[:, 2] = w - bbox[:, 1]
                    pred.bboxes[:, 3] = bbox[:, 2]
                else:
                    raise ValueError(f"不能解释的类型：{types}")
                pred.set_metainfo(img_meta)

                results.append(pred)
        finally:
            # Always leave the dataset un-augmented, even when inference
            # fails or the cell is interrupted half-way through.
            self.dataloader.dataset.type = None
        return results

In [6]:
# Test images are served in sorted-path order (no shuffling) so that the
# prediction lists of the different TTA passes line up index by index.
test_dataloader = torch.utils.data.DataLoader(
    dataset=CustomDatset(IMG_PREFIX),
    batch_size=BATCH_SIZE,
    shuffle=False,
    num_workers=NUM_WORKDERS,
    collate_fn=custom_dataset.collate_fn,
)

# Build the model from its config and load the trained weights on the CPU;
# accelerate moves it to the target device below.
model = modules.Model(MODEL_CFG_CONFIG)
model.load_state_dict(torch.load(model_path, map_location="cpu")["model"])

model, test_dataloader = accelerator.prepare(model, test_dataloader)

tta = TTA(model, test_dataloader)

All Keys Matching


In [7]:
# One full inference pass over the test set per augmentation, in tta_type order.
outputs = list(map(tta, tta_type))

  return _VF.meshgrid(tensors, **kwargs)  # type: ignore[attr-defined]
100%|██████████| 1887/1887 [01:26<00:00, 21.86it/s]
100%|██████████| 1887/1887 [01:11<00:00, 26.36it/s]
100%|██████████| 1887/1887 [01:46<00:00, 17.64it/s]
100%|██████████| 1887/1887 [00:57<00:00, 32.87it/s]
100%|██████████| 1887/1887 [01:12<00:00, 26.11it/s]
100%|██████████| 1887/1887 [01:21<00:00, 23.26it/s]


In [8]:
# results = []
# for out in outputs:
#     res = []
#     for label_id in [1, 4, 3, 8, 9, 6, 5, 2, 7]:
#         r = []
#         pred_instances = out
#         for i in range(len(pred_instances.labels)):
#             if (
#                 pred_instances.labels[i].item() == label_id and
#                 pred_instances.scores[i].item() > 0.03
#             ):
#                 r.append([
#                     b.item()
#                     for b in pred_instances.bboxes[i].detach().cpu()
#                 ] + [pred_instances.scores[i].detach().cpu().item()]
#                 )
#         res.append(r)
#     results.append(res)

In [9]:
# Empty COCO-style container for pseudo labels. Category ids are simply the
# positions in the name list below (0 is the background class).
pseudo_coco = {
    "images": [],
    "annotations": [],
    "categories": [
        {"id": category_id, "name": category_name}
        for category_id, category_name in enumerate([
            "background",
            "knife",
            "scissor",
            "glassbottle",
            "tongs",
            "metalcup",
            "umbrella",
            "lighter",
            "pressure",
            "laptop",
        ])
    ],
}

# Running ids for the next pseudo-labelled image / annotation.
img_id = ann_id = 0

In [10]:
# Order in which the per-class detection lists are emitted for every image.
# NOTE(review): presumably the class order expected by whatever consumes
# results.json — confirm before changing.
RESULT_LABEL_ORDER = [1, 4, 3, 8, 9, 6, 5, 2, 7]

# Reset the pseudo-label accumulators so that re-running this cell starts from
# scratch instead of appending duplicates on top of an earlier run.
pseudo_coco["images"] = []
pseudo_coco["annotations"] = []
img_id = 0
ann_id = 0

results = []
for img_idx in range(len(outputs[0])):
    # Original image size, recorded as metainfo by TTA.__call__.
    h, w = outputs[0][img_idx].img_shape

    # Boxes are divided by the image size before fusion and multiplied back
    # afterwards, so weighted_boxes_fusion works on normalised coordinates.
    norm = torch.tensor([w, h, w, h]).reshape(1, -1)
    boxes, scores, labels = weighted_boxes_fusion(
        boxes_list=[
            (out[img_idx].bboxes.cpu() / norm).tolist()
            for out in outputs
        ],
        scores_list=[out[img_idx].scores.tolist() for out in outputs],
        labels_list=[out[img_idx].labels.tolist() for out in outputs],
        conf_type="max",
        iou_thr=.55,
        skip_box_thr=0.02,
    )
    # Back to pixel coordinates; convert the *product* to plain Python lists.
    boxes = (boxes * np.array([w, h, w, h]).reshape(1, -1)).tolist()
    scores = scores.tolist()
    labels = labels.tolist()

    cur_ann = 0  # pseudo annotations created for this image
    res = []
    for label_id in RESULT_LABEL_ORDER:
        r = []
        for box, score, label in zip(boxes, scores, labels):
            if label != label_id:
                continue

            # Prediction output: [x1, y1, x2, y2, score].
            r.append([*box, score])

            # Pseudo label: only confident detections, stored as an
            # integer-truncated COCO [x, y, width, height] box.
            if score > PSEUDO_THREHSOLD:
                x1, y1, x2, y2 = box
                coco_box = [
                    int(x1),
                    int(y1),
                    int(x2 - x1),
                    int(y2 - y1),
                ]

                cur_ann += 1
                pseudo_coco["annotations"].append({
                    "image_id": img_id,
                    "id": ann_id,
                    "category_id": int(label),
                    "bbox": coco_box,
                    "area": coco_box[2] * coco_box[3],
                    "segmentation": [],
                    # The original expression `0 if "polygon" else 1` was
                    # always 0, so the constant is written directly.
                    "iscrowd": 0,
                    "score": score
                })
                ann_id += 1
        res.append(r)
    results.append(res)

    # Only images with at least one pseudo annotation get an image record.
    if cur_ann > 0:
        pseudo_coco["images"].append({
            "id": img_id,
            # NOTE(review): assumes the test images are named by their
            # zero-padded position in the sorted file list — confirm against
            # the actual file names.
            "file_name": str(img_idx).rjust(5, "0") + ".jpg",
            "height": h,
            "width": w
        })

        img_id += 1



In [11]:
# Persist the fused per-image, per-class detections.
with open("results.json", "w") as results_file:
    json.dump(results, results_file)

In [12]:
# with open("data/pseudo_coco_from814_thre0.6.json", "w") as f:
#     json.dump(pseudo_coco, f)