# split datasets
version: 1

info:
- split json into train.json, val.json and test.json 

author: nuno costa

In [1]:
from annotate_v5 import *
import platform 
import numpy as np
import time
import pandas as pd
from IPython.display import Image, display
import copy
import os
from shutil import copyfile
import matplotlib.pyplot as plt
from matplotlib.image import imread
from matplotlib.patches import Rectangle
import random

#Define root dir dependent on OS
rdir='D:/external_datasets/MOLA/annotations/' 
if str(platform.platform()).find('linux')>-1: rdir=rdir.replace('D:/','/mnt/d/')
print('OS: {}'.format(platform.platform()))
print('root dir: {}'.format(rdir))

OS: Windows-10-10.0.20279-SP0
root dir: D:/external_datasets/MOLA/annotations/


## 1. Init vars

In [2]:
train=70
val=20
test=100-(train+val)
injsonfile='mola_fix_equal.json'
infilename=injsonfile.split('.')[0]

In [3]:
# init json
molajson =  json.load(open(rdir+injsonfile))
for k in molajson:
    print(k, len(molajson[k]))

info 5
licenses 9
categories 1261
videos 1488
images 177936
tracks 8132
segment_info 0
annotations 1338002
datasets 2


## 2. Import ids
#### #NOTE: work with ids and index so you can use numpy for faster operations

In [4]:
# categories id
catids=[]
for c in molajson['categories']:
    catids.append(c['id'])

In [5]:
# annotations category_id
ann_catids=[]
ann_ids=[]
for an in tqdm(molajson['annotations']):
    ann_catids.append(an['category_id'])
    ann_ids.append(an['id'])
print(len(ann_ids))

100%|████████████████████████████████████████████████████████████████| 1338002/1338002 [00:00<00:00, 1406837.80it/s]

1338002





In [6]:
#TEST dupplicates v1 - slow
# duplicates_l=list(set([x for x in ann_ids if ann_ids.count(x) > 1])) # duplicates l 
#TEST dupplicates v2 - fast
#from collections import Counter
#duplicates_l=[item for item, count in Counter(ann_ids).items() if count > 1]
#TEST duplicates v3 -faster
u, c = np.unique(np.array(ann_ids), return_counts=True)
duplicates_l= u[c > 1].tolist()
print(len(duplicates_l))

0


## 3. split annotations
#QUESTION Seeded random or not?

In [59]:
ann_catids_np=np.array(ann_catids)
train_ann_catidx=[]
val_ann_catidx=[]
test_ann_catidx=[]
for catid in tqdm(catids):
    ann_idx_np = np.where(ann_catids_np==catid)[0] #annotation index of ids
    if not ann_idx_np.any(): continue
    #print("\n>> ", catid)
    
    #assert ann_idx_np
    u, c = np.unique(ann_idx_np, return_counts=True)
    duplicates_l= u[c > 1].tolist()
    assert len(duplicates_l)==0 #assert duplicates (above is already)
    assert all([True if ann_catids[i]==catid else False for i in ann_idx_np] ) #assert index belongs to catid
    
    #parameters
    train_size=len(ann_idx_np) * train // 100 #floor division
    val_size=len(ann_idx_np) * val // 100
    test_size=len(ann_idx_np) * test // 100
    
    #select data
    random.shuffle(ann_idx_np) 
    train_ann_catidx.extend(ann_idx_np[:train_size].tolist())
    val_ann_catidx.extend(ann_idx_np[train_size+1:train_size+val_size-1].tolist())
    test_ann_catidx.extend(ann_idx_np[train_size+val_size+1:train_size+val_size+test_size].tolist())
    
    """
    OLD  #BUG . np.random.choice() duplicates
    remain_idx_np=ann_idx_np.copy() #start 100%
    #train
    train_idx_np = np.random.choice(remain_idx_np, train_size)
    train_ann_catidx.extend(train_idx_np.tolist())
    remain_idx_np=remain_idx_np[~np.in1d(remain_idx_np,train_idx_np)]
    #val
    val_idx_np = np.random.choice(remain_idx_np, val_size)
    val_ann_catidx.extend(val_idx_np.tolist())
    remain_idx_np=remain_idx_np[~np.in1d(remain_idx_np,val_idx_np)]
    #test
    test_idx_np = np.random.choice(remain_idx_np, test_size)
    test_ann_catidx.extend(test_idx_np.tolist())
    remain_idx_np=remain_idx_np[~np.in1d(remain_idx_np,test_idx_np)]
    """
    
    

print((len(train_ann_catidx)/len(ann_catids))*100)
print((len(val_ann_catidx)/len(ann_catids))*100)
print((len(test_ann_catidx)/len(ann_catids))*100)

100%|██████████████████████████████████████████████████████████████████████████| 1261/1261 [00:03<00:00, 323.67it/s]

69.98823619097729
19.93248141632075
9.960074798094471





In [60]:
l_dup=[train_ann_catidx, val_ann_catidx,test_ann_catidx ]
for i in l_dup:
    print('original: ', len(i))
    u, c = np.unique(np.array(i), return_counts=True)
    duplicates_l= u[c > 1].tolist()
    print('duplicate: ',len(duplicates_l))

original:  936444
duplicate:  0
original:  266697
duplicate:  0
original:  133266
duplicate:  0


### 4. Save splited jsons

In [8]:
percent_idx=[train_ann_catidx,val_ann_catidx, test_ann_catidx]
percent_names=['train', 'val', 'test']

In [9]:
annotations=copy.copy(molajson['annotations']) #reset annotations
for i, percent_i in enumerate(tqdm(percent_idx)):
    #get new annotations
    molajson['annotations']=[annotations[index] for index in percent_i]
    # save
    print('\n >> SAVING {}...'.format(percent_names[i]))
    outjsonfile=rdir+'split_{}/'.format(infilename)+'{}.json'.format(percent_names[i]) #rdir+'{}_{}.json'.format(percent_names[i],infilename)
    with open(outjsonfile, 'w') as f:
        json.dump(molajson, f)
    print("JSON SAVED : {} \n".format(outjsonfile))
    for k in molajson:
        print(k, len(molajson[k]))

  0%|                                                                                          | 0/3 [00:00<?, ?it/s]


 >> SAVING train...


 33%|███████████████████████████                                                      | 1/3 [02:33<05:06, 153.09s/it]

JSON SAVED : D:/external_datasets/MOLA/annotations/train.json 

info 5
licenses 9
categories 1261
videos 1488
images 177936
tracks 8132
segment_info 0
annotations 936444
datasets 2

 >> SAVING val...


 67%|██████████████████████████████████████████████████████                           | 2/3 [03:22<02:01, 121.99s/it]

JSON SAVED : D:/external_datasets/MOLA/annotations/val.json 

info 5
licenses 9
categories 1261
videos 1488
images 177936
tracks 8132
segment_info 0
annotations 267463
datasets 2

 >> SAVING test...


100%|██████████████████████████████████████████████████████████████████████████████████| 3/3 [03:50<00:00, 76.84s/it]

JSON SAVED : D:/external_datasets/MOLA/annotations/test.json 

info 5
licenses 9
categories 1261
videos 1488
images 177936
tracks 8132
segment_info 0
annotations 133641
datasets 2





### 5. TEST SPLIT ANNOTATIONS DUPLICATES

In [5]:
injsonfile='mola_mix_aggressive.json'
outjsonfile=rdir+'split_{}/'.format(infilename)+'test.json'
# init json
molajson =  json.load(open(outjsonfile))
for k in molajson:
    print(k, len(molajson[k]))

info 5
licenses 9
categories 1261
videos 1488
images 177936
tracks 8132
segment_info 0
annotations 133641
datasets 2


In [6]:
a = molajson['annotations']

In [10]:
#compare
comp=[]
for i,d1 in enumerate(tqdm(a)):
    for ii,d2 in enumerate(a[i+1:]): 
        comp.append(d1 == d2)
        if d1 == d2:
            print(">> d1: ", d1, i)
            print(">> d2: ", d2, i+1+ii)
            print(">> diff: ", 1+ii)
    if i==20:break

  0%|                                                                           | 4/133641 [00:00<2:11:10, 16.98it/s]

>> d1:  {'segmentation': [[261.3, 350.1, 250.41, 324.93, 272.86, 286.15, 285.11, 266.43, 304.16, 258.94, 307.56, 229.01, 321.17, 216.76, 336.81, 225.61, 346.34, 235.81, 346.34, 268.47, 374.23, 285.47, 371.51, 318.81, 398.72, 350.1, 384.44, 377.32, 344.98, 355.55, 320.49, 350.1, 305.52, 335.82, 288.51, 321.53, 287.83, 309.29, 276.27, 324.93, 272.18, 352.83]], 'num_keypoints': 12, 'area': 11012.1809, 'iscrowd': 0, 'keypoints': [317, 269, 2, 328, 261, 2, 311, 259, 2, 342, 258, 2, 0, 0, 0, 362, 290, 2, 290, 272, 2, 328, 345, 2, 266, 317, 2, 316, 308, 2, 269, 339, 2, 367, 367, 2, 315, 348, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], 'image_id': 750, 'bbox': [250.41, 216.76, 148.31, 160.56], 'category_id': 1, 'id': 920879, 'dataset': 1} 1
>> d2:  {'segmentation': [[261.3, 350.1, 250.41, 324.93, 272.86, 286.15, 285.11, 266.43, 304.16, 258.94, 307.56, 229.01, 321.17, 216.76, 336.81, 225.61, 346.34, 235.81, 346.34, 268.47, 374.23, 285.47, 371.51, 318.81, 398.72, 350.1, 384.44, 377.32, 344.98, 355.5

  0%|                                                                          | 10/133641 [00:00<2:08:31, 17.33it/s]

>> d1:  {'segmentation': [[357.49, 235.88, 358.27, 238.98, 358.14, 239.76, 356.33, 239.63, 355.94, 241.57, 355.68, 242.73, 354.13, 243.89, 353.61, 245.58, 353.74, 246.87, 352.32, 247.0, 353.09, 248.42, 351.41, 249.07, 349.86, 251.27, 348.57, 261.48, 348.44, 263.17, 351.41, 263.42, 353.35, 262.39, 353.87, 259.8, 353.09, 257.35, 354.13, 256.44, 355.55, 255.53, 356.2, 256.83, 355.94, 258.9, 355.94, 261.61, 356.2, 264.46, 358.27, 265.49, 359.43, 263.81, 359.95, 261.1, 359.43, 258.77, 358.66, 256.96, 359.69, 254.63, 362.28, 253.08, 363.83, 251.91, 364.6, 250.49, 366.93, 248.68, 365.77, 245.32, 364.99, 242.08, 363.7, 241.31, 362.92, 241.18, 363.18, 240.01, 362.54, 239.63, 362.54, 238.46, 363.83, 238.46, 362.41, 236.91, 361.11, 235.62, 359.82, 235.62]], 'area': 273.03360000000026, 'iscrowd': 0, 'image_id': 98627, 'bbox': [348.44, 235.62, 18.49, 29.87], 'category_id': 1, 'id': 568605, 'dataset': 1} 5
>> d2:  {'segmentation': [[357.49, 235.88, 358.27, 238.98, 358.14, 239.76, 356.33, 239.63, 355

  0%|                                                                          | 16/133641 [00:00<2:05:11, 17.79it/s]

>> d1:  {'segmentation': [[597.98, 135.59, 597.12, 128.24, 598.41, 124.79, 602.3, 123.92, 606.62, 126.52, 607.48, 135.59, 606.19, 137.32, 614.4, 143.8, 613.53, 154.17, 613.53, 159.78, 612.67, 162.37, 610.51, 163.24, 609.21, 164.97, 608.35, 183.54, 607.92, 192.62, 601.87, 193.48, 600.57, 202.12, 599.71, 205.58, 604.89, 207.31, 603.16, 209.03, 594.96, 209.03, 594.52, 202.99, 594.52, 195.64, 594.09, 191.32, 591.07, 194.34, 589.77, 204.28, 589.34, 209.47, 581.13, 207.74, 583.72, 199.53, 584.59, 187.86, 588.04, 168.42, 588.04, 164.1, 587.61, 159.35, 585.02, 156.76, 585.45, 151.14, 587.18, 144.66, 590.2, 139.91, 595.39, 136.02, 597.12, 134.29]], 'area': 1671.5567000000008, 'iscrowd': 0, 'image_id': 69623, 'bbox': [581.13, 123.92, 33.27, 85.55], 'category_id': 1, 'id': 329351, 'dataset': 1} 14
>> d2:  {'segmentation': [[597.98, 135.59, 597.12, 128.24, 598.41, 124.79, 602.3, 123.92, 606.62, 126.52, 607.48, 135.59, 606.19, 137.32, 614.4, 143.8, 613.53, 154.17, 613.53, 159.78, 612.67, 162.37, 61

  0%|                                                                          | 20/133641 [00:01<2:11:17, 16.96it/s]

>> d1:  {'segmentation': [[347.64, 388.87, 338.66, 432.64, 334.73, 445.55, 309.48, 444.98, 319.58, 414.68, 324.63, 394.48, 326.31, 383.26, 323.51, 368.11, 323.51, 362.49, 333.61, 366.42, 342.59, 377.08, 345.95, 382.7, 347.64, 385.5, 347.64, 386.62]], 'num_keypoints': 8, 'area': 1726.499, 'iscrowd': 0, 'keypoints': [324, 393, 1, 326, 391, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 332, 402, 2, 327, 405, 2, 333, 387, 2, 0, 0, 0, 330, 372, 2, 0, 0, 0, 330, 427, 2, 321, 428, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], 'image_id': 3493, 'bbox': [309.48, 362.49, 38.16, 83.06], 'category_id': 1, 'id': 926218, 'dataset': 1} 19
>> d2:  {'segmentation': [[347.64, 388.87, 338.66, 432.64, 334.73, 445.55, 309.48, 444.98, 319.58, 414.68, 324.63, 394.48, 326.31, 383.26, 323.51, 368.11, 323.51, 362.49, 333.61, 366.42, 342.59, 377.08, 345.95, 382.7, 347.64, 385.5, 347.64, 386.62]], 'num_keypoints': 8, 'area': 1726.499, 'iscrowd': 0, 'keypoints': [324, 393, 1, 326, 391, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 332, 402, 2, 327, 40


