# split datasets
version: 3

info:
- split json into train.json, val.json and test.json 

author: nuno costa

In [1]:
from annotate_v5 import *
import platform 
import numpy as np
import time
import pandas as pd
from IPython.display import Image, display
import copy
import os
from shutil import copyfile
import matplotlib.pyplot as plt
from matplotlib.image import imread
from matplotlib.patches import Rectangle
import random

In [2]:
#Define root dir dependent on OS
rdir='D:/external_datasets/MOLA/annotations/' 
if str(platform.platform()).find('linux')>-1: rdir=rdir.replace('D:/','/mnt/d/')
print('OS: {}'.format(platform.platform()))
print('root dir: {}'.format(rdir))

OS: Windows-10-10.0.21292-SP0
root dir: D:/external_datasets/MOLA/annotations/


## 1. Init vars

In [3]:
train=70
val=20
test=100-(train+val)
injsonfile='coco2017_reorder_cleanclass.json'
infilename=injsonfile.split('.')[0]

In [4]:
# init json
molajson =  json.load(open(rdir+injsonfile))
for k in molajson:
    print(k, len(molajson[k]))

info 6
licenses 8
images 123287
annotations 1170251
categories 80


## 2. Import ids
#### #NOTE: work with ids and index so you can use numpy for faster operations

In [5]:
# categories id
catids=[]
cats=[]
for c in molajson['categories']:
    catids.append(c['id'])
    cats.append(c['name'])
#print(cats)

In [6]:
# annotations category_id
ann_catids=[]
ann_ids=[]
for an in tqdm(molajson['annotations']):
    ann_catids.append(an['category_id'])
    ann_ids.append(an['id'])
print(len(ann_ids))

100%|█████████████████████████████████████████████████████████████████| 1170251/1170251 [00:00<00:00, 1215215.04it/s]

1170251





In [7]:
#TEST dupplicates v1 - slow
# duplicates_l=list(set([x for x in ann_ids if ann_ids.count(x) > 1])) # duplicates l 
#TEST dupplicates v2 - fast
#from collections import Counter
#duplicates_l=[item for item, count in Counter(ann_ids).items() if count > 1]
#TEST duplicates v3 -faster
u, c = np.unique(np.array(ann_ids), return_counts=True)
duplicates_l= u[c > 1].tolist()
print(len(duplicates_l))

273469


## 3. split by annotations
#QUESTION Seeded random or not?

In [8]:
ann_catids_np=np.array(ann_catids)
train_ann_catidx=[]
val_ann_catidx=[]
test_ann_catidx=[]
for catid in tqdm(catids):
    ann_idx_np = np.where(ann_catids_np==catid)[0] #annotation index of ids
    if not ann_idx_np.any(): continue
    #print("\n>> ", catid)
    
    #assert ann_idx_np
    u, c = np.unique(ann_idx_np, return_counts=True)
    duplicates_l= u[c > 1].tolist()
    assert len(duplicates_l)==0 #assert duplicates (above is already)
    assert all([True if ann_catids[i]==catid else False for i in ann_idx_np] ) #assert index belongs to catid
    
    #parameters
    train_size=len(ann_idx_np) * train // 100 #floor division
    val_size=len(ann_idx_np) * val // 100
    test_size=len(ann_idx_np) * test // 100
    
    #select data
    random.shuffle(ann_idx_np) 
    train_ann_catidx.extend(ann_idx_np[:train_size].tolist())
    val_ann_catidx.extend(ann_idx_np[train_size+1:train_size+val_size-1].tolist())
    test_ann_catidx.extend(ann_idx_np[train_size+val_size+1:train_size+val_size+test_size].tolist())


print((len(train_ann_catidx)/len(ann_catids))*100)
print((len(val_ann_catidx)/len(ann_catids))*100)
print((len(test_ann_catidx)/len(ann_catids))*100)

100%|████████████████████████████████████████████████████████████████████████████████| 80/80 [00:01<00:00, 55.25it/s]

69.99677846889257
19.983832528235396
9.989993599663661





In [9]:
l_dup=[train_ann_catidx, val_ann_catidx,test_ann_catidx ]
for i in l_dup:
    print('original: ', len(i))
    u, c = np.unique(np.array(i), return_counts=True)
    duplicates_l= u[c > 1].tolist()
    print('duplicate: ',len(duplicates_l))

original:  819138
duplicate:  0
original:  233861
duplicate:  0
original:  116908
duplicate:  0


### 4. Save splited jsons

In [10]:
percent_idx=[train_ann_catidx,val_ann_catidx, test_ann_catidx]
percent_names=['train', 'val', 'test']

In [11]:
newjson=copy.copy(molajson)

In [12]:
annotations=copy.copy(molajson['annotations']) 
for i, percent_i in enumerate(tqdm(percent_idx)):
    #get new annotations
    newjson['annotations']=[annotations[index] for index in percent_i]
    # save
    print('\n >> SAVING {}...'.format(percent_names[i]))
    outpath=rdir+'splitann_{}/'.format(infilename)
    assure_path_exists(outpath)
    outjsonfile=outpath+'{}.json'.format(percent_names[i]) #rdir+'{}_{}.json'.format(percent_names[i],infilename)
    with open(outjsonfile, 'w') as f:
        json.dump(newjson, f)
    print("JSON SAVED : {} \n".format(outjsonfile))
    for k in molajson:
        print(k, len(newjson[k]))

  0%|                                                                                          | 0/3 [00:00<?, ?it/s]


 >> SAVING train...


 33%|███████████████████████████                                                      | 1/3 [02:09<04:19, 129.99s/it]

JSON SAVED : D:/external_datasets/MOLA/annotations/splitann_coco2017_reorder_cleanclass/train.json 

info 6
licenses 8
images 123287
annotations 819138
categories 80

 >> SAVING val...


 67%|██████████████████████████████████████████████████████                           | 2/3 [02:48<01:42, 102.65s/it]

JSON SAVED : D:/external_datasets/MOLA/annotations/splitann_coco2017_reorder_cleanclass/val.json 

info 6
licenses 8
images 123287
annotations 233861
categories 80

 >> SAVING test...


100%|██████████████████████████████████████████████████████████████████████████████████| 3/3 [03:09<00:00, 63.32s/it]

JSON SAVED : D:/external_datasets/MOLA/annotations/splitann_coco2017_reorder_cleanclass/test.json 

info 6
licenses 8
images 123287
annotations 116908
categories 80





### 5. TEST SPLIT ANNOTATIONS DUPLICATES

In [13]:
injsonfile='mola_mix_aggressive.json'
outjsonfile=rdir+'split_{}/'.format(infilename)+'test.json'
# init json
molajson =  json.load(open(outjsonfile))
for k in molajson:
    print(k, len(molajson[k]))

FileNotFoundError: [Errno 2] No such file or directory: 'D:/external_datasets/MOLA/annotations/split_coco2017_reorder_cleanclass/test.json'

In [72]:
# annotations category_id
ann_ids=[]
for an in tqdm(molajson['annotations']):
    ann_ids.append(an['id'])
print(len(ann_ids))

#TEST duplicates v3 -faster
u, c = np.unique(np.array(ann_ids), return_counts=True)
duplicates_l= u[c > 1].tolist()
print(len(duplicates_l))

100%|██████████████████████████████████████████████████████████████████| 133266/133266 [00:00<00:00, 1497403.10it/s]

133266
0



