I want to train MaskTrackRCNN, but it raises KeyError: "YouTubeVISDataset: 'image_id'" #814

Open
eatbreakfast111 opened this issue Dec 27, 2022 · 2 comments

Comments

@eatbreakfast111

Hello!
I want to train MaskTrackRCNN on the official YouTube-VIS dataset,
but it raises KeyError: "YouTubeVISDataset: 'image_id'".
Here is my data tree:
[image: screenshot of the dataset directory tree]

Traceback (most recent call last):
  File "/home/music/miniconda3/envs/mmtrack/lib/python3.8/site-packages/mmcv/utils/registry.py", line 69, in build_from_cfg
    return obj_cls(**args)
  File "/home/music/Downloads/mmtracking-master/mmtrack/datasets/youtube_vis_dataset.py", line 44, in __init__
    super().__init__(*args, **kwargs)
  File "/home/music/Downloads/mmtracking-master/mmtrack/datasets/coco_video_dataset.py", line 46, in __init__
    super().__init__(*args, **kwargs)
  File "/home/music/miniconda3/envs/mmtrack/lib/python3.8/site-packages/mmdet/datasets/custom.py", line 97, in __init__
    self.data_infos = self.load_annotations(local_path)
  File "/home/music/Downloads/mmtracking-master/mmtrack/datasets/coco_video_dataset.py", line 61, in load_annotations
    data_infos = self.load_video_anns(ann_file)
  File "/home/music/Downloads/mmtracking-master/mmtrack/datasets/coco_video_dataset.py", line 73, in load_video_anns
    self.coco = CocoVID(ann_file)
  File "/home/music/Downloads/mmtracking-master/mmtrack/datasets/parsers/coco_video_parser.py", line 22, in __init__
    super(CocoVID, self).__init__(annotation_file=annotation_file)
  File "/home/music/miniconda3/envs/mmtrack/lib/python3.8/site-packages/mmdet/datasets/api_wrappers/coco_api.py", line 23, in __init__
    super().__init__(annotation_file=annotation_file)
  File "/home/music/miniconda3/envs/mmtrack/lib/python3.8/site-packages/pycocotools/coco.py", line 86, in __init__
    self.createIndex()
  File "/home/music/Downloads/mmtracking-master/mmtrack/datasets/parsers/coco_video_parser.py", line 57, in createIndex
    imgToAnns[ann['image_id']].append(ann)
KeyError: 'image_id'

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "tools/train.py", line 213, in <module>
    main()
  File "tools/train.py", line 188, in main
    datasets = [build_dataset(cfg.data.train)]
  File "/home/music/miniconda3/envs/mmtrack/lib/python3.8/site-packages/mmdet/datasets/builder.py", line 82, in build_dataset
    dataset = build_from_cfg(cfg, DATASETS, default_args)
  File "/home/music/miniconda3/envs/mmtrack/lib/python3.8/site-packages/mmcv/utils/registry.py", line 72, in build_from_cfg
    raise type(e)(f'{obj_cls.__name__}: {e}')
KeyError: "YouTubeVISDataset: 'image_id'"

Thank you!

@dyhBUPT
Collaborator

dyhBUPT commented Dec 28, 2022

It seems that there is an error in your annotation file: the traceback shows an annotation that has no 'image_id' key.

You can refer to:
https://github.com/open-mmlab/mmtracking/blob/master/docs/en/dataset.md#2-convert-annotations
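
The traceback ends in `createIndex`, which reads `ann['image_id']` for every entry under `annotations`, so the error means at least one annotation lacks that key. A minimal sketch for checking an annotation file before training is shown here. The names `find_missing_keys`, `check_file`, and `REQUIRED_KEYS` are hypothetical helpers for illustration and are not part of mmtracking; the set of required keys beyond `image_id` is an assumption.

```python
import json

# Keys this sketch assumes every CocoVID-style annotation should carry.
# 'image_id' is the key the traceback fails on; the others are assumptions.
REQUIRED_KEYS = ("image_id", "category_id", "bbox")


def find_missing_keys(coco_dict, required=REQUIRED_KEYS):
    """Return {annotation index: [missing keys]} for incomplete annotations."""
    problems = {}
    for idx, ann in enumerate(coco_dict.get("annotations", [])):
        missing = [key for key in required if key not in ann]
        if missing:
            problems[idx] = missing
    return problems


def check_file(path):
    """Load a JSON annotation file and report annotations with missing keys."""
    with open(path, encoding="utf-8") as f:
        return find_missing_keys(json.load(f))


# Tiny in-memory example: the second annotation has no 'image_id'.
example = {
    "annotations": [
        {"id": 1, "image_id": 1, "category_id": 1, "bbox": [0, 0, 10, 10]},
        {"id": 2, "video_id": 1, "category_id": 1, "bbox": [0, 0, 10, 10]},
    ]
}
print(find_missing_keys(example))  # {1: ['image_id']}
```

An empty result from `check_file` on the training annotation file would suggest that the `KeyError` above should not occur for that file.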

@eatbreakfast111
Author

Thank you so much. It works.
I have written a script, labelme2cocovid, but when I try to train on the resulting dataset, this problem occurs:
[image: screenshot of the error]
Here is my labelme2cocovid script. Could you please give me some advice?

#!/usr/bin/env python
# -*- coding:utf-8 -*-

import os
import cv2
import numpy as np
from tqdm import tqdm
import json

############################################################
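# Directory containing the images and the labelme JSON files
srcpath = '/home/music/Downloads/mmtracking-master/data/youtube_vis_2019/valid/JPEGImages/'
# Output JSON file for the result
dstjson = 'valid.json'
############################################################

# Read an image; works around cv2.imread failing on non-ASCII (e.g. Chinese) paths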
def cv_imread(filePath):
    cv_img = cv2.imdecode(np.fromfile(filePath,dtype=np.uint8),-1)
    return cv_img

if __name__ == '__main__':

    # Final JSON structure
    dst_ann = {
    "categories": [],
    "videos": [],
    "images": [],
    "annotations": []}
    categories_list = []
    
    # Walk the video directories (sorted so the order is deterministic)
    videos_list = sorted(os.listdir(srcpath))
    for videoname in videos_list:
        print('================='+videoname+'=================')
        videopath = srcpath + videoname + '/'
        fileslist = sorted(os.listdir(videopath))
        #############################################################
        # Write the video entry
        curr_video_id = len(dst_ann['videos'])+1
        dst_ann['videos'].append({"id": curr_video_id,"name": videoname})
        #############################################################
        # Walk the images and annotations in one video folder
        for filename in tqdm(fileslist):
            if filename.endswith('.json'):
                jsonpath = videopath + filename
                imgpath = jsonpath[:-4] + 'jpg'

                img = cv_imread(imgpath)
                # img.shape is (height, width[, channels]), not (width, height)
                height, width = img.shape[:2]

                #############################################################
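                # Write the image entry
                curr_img_id = len(dst_ann['images'])+1
                img_info =  {
                    # Path relative to the image prefix (assumed to be the JPEGImages directory)
                    "file_name": videoname + '/' + filename[:-4] + 'jpg', 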
                    "height": height,"width": width,
                    "id": curr_img_id,
                    "video_id": curr_video_id,
                    "frame_id": int(filename[:-5]) }
                dst_ann['images'].append(img_info)
                #############################################################

                # Read the labelme JSON
                with open(jsonpath,encoding="utf-8") as f:
                    data = json.load(f)
                    shapes = data['shapes']
                    for s in shapes:
                        label = s['label']
                        points = s['points']
                        curr_category_id = -1
                        #############################################################
                        # Write the category entry (only the first time a label is seen)
                        if label not in categories_list:
                            curr_category_id = len(categories_list)+1
                            categories_list.append(label)
                            dst_ann['categories'].append({"id": curr_category_id,"name": label})
                        else:
                            for c in dst_ann['categories']:
                                if c['name'] == label:
                                    curr_category_id = c['id']
                                    break
                        #############################################################
                        # Compute the bounding rectangle of the polygon points
                        xmin,xmax,ymin,ymax = width,0,height,0
                        for p in points:
                            x,y = p[0], p[1]
                            xmin = min(xmin, x)
                            xmax = max(xmax,x)
                            ymin = min(ymin, y)
                            ymax = max(ymax, y)
                        # check bbox
                        if 0:
                            cv2.rectangle(img, (int(xmin),int(ymin)), (int(xmax),int(ymax)), (255, 0, 255), 3)
                            cv2.imshow('',img)
                            cv2.waitKey(-1)
                        # Compute the area (of the bounding box, not of the polygon)
                        area = (xmax-xmin) * (ymax-ymin)

                        #############################################################
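                        # Write the annotation entry
                        curr_ann_id = len(dst_ann['annotations'])+1
                        ann = {"id": curr_ann_id,
                        "image_id": curr_img_id,
                        "video_id": curr_video_id,
                        "category_id": curr_category_id,
                        # NOTE: -1 for every object gives no identity to follow across frames;
                        # tracking likely needs a consistent per-object ID here
                        "instance_id": -1,
                        "bbox": [xmin,ymin,xmax-xmin,ymax-ymin],
                        "area": area,
                        # COCO-style polygon: a list of flat [x1, y1, x2, y2, ...] lists
                        "segmentation": [[c for p in points for c in p]],
                        "occluded": False,
                        "truncated": False,
                        "iscrowd": False,
                        "ignore": False,
                        "is_vid_train_frame": True,
                        "visibility": 1.0}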
                        dst_ann['annotations'].append(ann)
                        #############################################################
    #############################################################
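    # Write the result file
    with open(dstjson,"w",encoding="utf-8") as f:
        json.dump(dst_ann, f,
              indent=2,  # indent with spaces and write one item per line
              sort_keys=False,  # keep the keys in insertion order
              ensure_ascii=False)  # keep non-ASCII (e.g. Chinese) characters readable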

Thank you!
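
COCO-style annotations conventionally store an instance polygon under `segmentation` as a list of flat `[x1, y1, x2, y2, ...]` lists, with `bbox` given as `[x, y, width, height]`. The following is a minimal sketch of how one labelme polygon could be turned into those fields. The function name `shape_to_coco_fields` is hypothetical, the shoelace formula is swapped in here to get the polygon area instead of the bounding-box area, and it is an assumption that the training pipeline reads exactly these keys.

```python
def shape_to_coco_fields(points):
    """Convert labelme polygon points [[x, y], ...] to COCO-style fields.

    Hypothetical helper: returns 'bbox' as [x, y, width, height], 'area' as
    the polygon area, and 'segmentation' as a list with one flat polygon.
    """
    xs = [float(p[0]) for p in points]
    ys = [float(p[1]) for p in points]
    xmin, xmax = min(xs), max(xs)
    ymin, ymax = min(ys), max(ys)

    # Shoelace formula: polygon area from the ordered vertex list.
    n = len(points)
    twice_area = 0.0
    for i in range(n):
        j = (i + 1) % n  # next vertex, wrapping around to the first
        twice_area += xs[i] * ys[j] - xs[j] * ys[i]
    area = abs(twice_area) / 2.0

    # Flatten [[x1, y1], [x2, y2], ...] into [x1, y1, x2, y2, ...].
    flat = [coord for x, y in zip(xs, ys) for coord in (x, y)]

    return {
        "bbox": [xmin, ymin, xmax - xmin, ymax - ymin],
        "area": area,
        "segmentation": [flat],
    }


# A 20 x 20 square with its top-left corner at (10, 20).
square = [[10, 20], [30, 20], [30, 40], [10, 40]]
fields = shape_to_coco_fields(square)
print(fields["bbox"])          # [10.0, 20.0, 20.0, 20.0]
print(fields["area"])          # 400.0
print(fields["segmentation"])  # [[10.0, 20.0, 30.0, 20.0, 30.0, 40.0, 10.0, 40.0]]
```

The returned dictionary could be merged into each annotation entry with `ann.update(fields)`.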
