In [1]:
import pandas as pd
import glob
import os

from utils import remap, cats_to_list

In [2]:
def generate_inference_df(input_path='runs/detect/predict/labels/', conf_threshold=0.25):
    '''Generates dataframe of information from inference output files.

    Every ``*.txt`` file in ``input_path`` is read as the detections of one image,
    one detection per line: ``category x y w h confidence`` (whitespace separated).
    Blank lines are ignored.

    Args:
        input_path: (string) directory holding the prediction label files;
            a trailing path separator is optional
        conf_threshold: (float) minimum confidence threshold for valid detection
    Returns:
        df: (pd.DataFrame) dataframe of inference output, one row per label file, with
            columns id, categories, location, conf, weak_shallow, strong_shallow and
            no_detection. Empty (no columns) when no label file is found.
    Raises:
        ValueError: if a non-blank line does not hold exactly six fields, or a
            field cannot be converted to a number
    '''
    out = {}
    # os.path.join works with or without a trailing separator on input_path;
    # plain concatenation ('labels' + '*.txt') silently matched no files.
    filelist = glob.glob(os.path.join(input_path, '*.txt'))

    cat_df = pd.read_json('category_key.json')
    shallow = cat_df[cat_df.shallow_species == True]['index'].to_list()
    # NOTE(review): to_dict()['id'] is keyed by the DataFrame row label, not by the
    # 'index' column selected alongside it - this is only right if the row labels
    # coincide with the model's class indices. Confirm against category_key.json.
    mapper = cat_df[['id', 'index']].to_dict()['id']

    for i, file in enumerate(filelist):
        cats = []
        conf = []
        location = []
        weak_shallow = 0
        strong_shallow = 0
        no_detection = 0

        with open(file, 'r') as f:
            for line in f:
                fields = line.split()
                if not fields:
                    continue  # tolerate blank / whitespace-only lines

                category, x, y, w, h, conf_value = fields
                category = int(category)
                conf_value = float(conf_value)
                loc = tuple(float(v) for v in (x, y, w, h))

                if category in shallow:
                    weak_shallow = 1    # weakly shallow if there is a shallow detection at any confidence
                    if conf_value >= conf_threshold:
                        strong_shallow = 1  # strongly shallow if a high conf shallow detection

                # dedup: only the first confident detection of each category is kept
                if (category not in cats) and (conf_value >= conf_threshold):
                    cats.append(category)
                    conf.append(conf_value)
                    location.append(loc)

        # translate the collected class values through the category key
        cats = remap(cats, mapper)

        if len(cats) == 0:
            no_detection = 1

        out[i] = {'id': os.path.splitext(os.path.basename(file))[0],
                  'categories': cats,
                  'location': location,
                  'conf': conf,
                  'weak_shallow': weak_shallow,
                  'strong_shallow': strong_shallow,
                  'no_detection': no_detection,
                  }

    df = pd.DataFrame.from_dict(out, orient='index')
    return df

In [4]:
# pd.read_json('supercat_key.json')

In [5]:
# pd.read_json('category_key.json')

In [84]:
# ids of every category whose shallow_species flag is True in the category key
cat_df = pd.read_json('category_key.json')
shallow = cat_df.loc[cat_df['shallow_species'] == True, 'id'].tolist()
shallow

[10, 51, 61, 103, 104, 105, 116, 119, 133, 160, 214, 259, 260, 274]

In [3]:
def generate_inference_df_sup(input_path='runs/detect/predict_superL/labels/', conf_threshold=0.15):
    '''Generates dataframe of information from super-category inference output files.

    Every ``*.txt`` file in ``input_path`` is read as the detections of one image,
    one detection per line: ``supercategory x y w h confidence`` (whitespace
    separated). Blank lines are ignored. Each super-category is translated to a
    category via ``supercat_key.json`` before the shallow checks and dedup.

    Args:
        input_path: (string) directory holding the prediction label files;
            a trailing path separator is optional
        conf_threshold: (float) minimum confidence threshold for valid detection
    Returns:
        df: (pd.DataFrame) dataframe of inference output, one row per label file, with
            columns id, supercategory, categories_s, location_s, conf_s,
            weak_shallow_s, strong_shallow_s and no_detection_s. Empty (no columns)
            when no label file is found.
    Raises:
        ValueError: if a non-blank line does not hold exactly six fields, or a
            field cannot be converted to a number
    '''
    out = {}
    # os.path.join works with or without a trailing separator on input_path;
    # plain concatenation ('labels' + '*.txt') silently matched no files.
    filelist = glob.glob(os.path.join(input_path, '*.txt'))

    cat_df = pd.read_json('category_key.json')
    scat_df = pd.read_json('supercat_key.json')
    shallow = cat_df[cat_df.shallow_species == True]['id'].to_list()
    # NOTE(review): to_dict()['top_category_id'] is keyed by the DataFrame row label,
    # not by the 'supercat_id' column - correct only if the two coincide. Confirm.
    mapper = scat_df[['top_category_id', 'supercat_id']].to_dict()['top_category_id']
    # keys become strings because the raw text token is what gets looked up below
    mapper = {str(k): v for k, v in mapper.items()}

    for i, file in enumerate(filelist):
        cats = []
        supercats = []
        conf = []
        location = []
        weak_shallow = 0
        strong_shallow = 0
        no_detection = 0

        with open(file, 'r') as f:
            for line in f:
                fields = line.split()
                if not fields:
                    continue  # tolerate blank / whitespace-only lines

                supercat, x, y, w, h, conf_value = fields
                category = remap(supercat, mapper)[0]
                supercat = int(supercat)
                conf_value = float(conf_value)
                loc = tuple(float(v) for v in (x, y, w, h))

                if category in shallow:
                    weak_shallow = 1    # weakly shallow if there is a shallow detection at any confidence
                    if conf_value >= conf_threshold:
                        strong_shallow = 1  # strongly shallow if a high conf shallow detection

                # dedup: only the first confident detection of each category is kept
                if (category not in cats) and (conf_value >= conf_threshold):
                    cats.append(category)
                    supercats.append(supercat)
                    conf.append(conf_value)
                    location.append(loc)

        if len(cats) == 0:
            no_detection = 1

        out[i] = {'id': os.path.splitext(os.path.basename(file))[0],
                  'supercategory': supercats,
                  'categories_s': cats,
                  'location_s': location,
                  'conf_s': conf,
                  'weak_shallow_s': weak_shallow,
                  'strong_shallow_s': strong_shallow,
                  'no_detection_s': no_detection
                  }

    df = pd.DataFrame.from_dict(out, orient='index')
    return df

In [97]:
# Super-category predictions from the function's default label directory,
# with the confidence threshold raised from its 0.15 default to 0.25.
df_sup = generate_inference_df_sup(conf_threshold=0.25)
df_sup

Unnamed: 0,id,supercategory,categories_s,location_s,conf_s,weak_shallow_s,strong_shallow_s,no_detection_s
0,9cdaaebd-acd7-403c-829c-af5ba519c8a7,[1],[37],"[(0.914333, 0.298541, 0.0802597, 0.0408547)]",[0.407764],0,0,0
1,49254842-f712-4875-ad09-57ebf8508bd1,"[7, 1, 2]","[214, 37, 119]","[(0.435158, 0.679759, 0.0353025, 0.0544208), (...","[0.850874, 0.750873, 0.472209]",1,1,0
2,c7f61e23-a0c1-4687-ad8a-b65a1db37fdb,[],[],[],[],0,0,1
3,67d55379-18ca-40ec-b9da-6aa7117e4e1a,[0],[160],"[(0.0714349, 0.58565, 0.114581, 0.127358)]",[0.857191],1,1,0
4,e6fdae2f-86ee-46c7-8c75-fba6a35b8f0d,[1],[37],"[(0.265701, 0.731048, 0.115989, 0.105286)]",[0.826337],0,0,0
...,...,...,...,...,...,...,...,...
10739,90508480-7df0-450f-bb1e-80462661cd7b,[],[],[],[],1,0,1
10740,01230b6e-f53c-431a-850c-b1cbe877f08b,[2],[119],"[(0.696914, 0.572213, 0.10094, 0.115592)]",[0.729477],1,1,0
10741,6af67476-63ea-4ac1-afc3-b79b3ef019bf,"[1, 6]","[37, 203]","[(0.858225, 0.395525, 0.0630371, 0.0601834), (...","[0.433666, 0.359158]",0,0,0
10742,5b9257ee-ee54-4936-a192-12facfa04be7,"[2, 1, 3, 0]","[119, 37, 10, 160]","[(0.305819, 0.701176, 0.0425898, 0.062373), (0...","[0.755491, 0.514154, 0.424596, 0.37873]",1,1,0


In [164]:
# Category predictions read from the predict40m label directory; only
# detections with confidence >= 0.5 are kept in the list columns.
df_cat = generate_inference_df('runs/detect/predict40m/labels/', 0.5)
df_cat

Unnamed: 0,id,categories,location,conf,weak_shallow,strong_shallow,no_detection
0,9cdaaebd-acd7-403c-829c-af5ba519c8a7,[],[],[],0,0,1
1,49254842-f712-4875-ad09-57ebf8508bd1,"[219, 52, 218]","[(0.752803, 0.512912, 0.0655638, 0.0840544), (...","[0.68369, 0.591417, 0.535339]",0,0,0
2,c7f61e23-a0c1-4687-ad8a-b65a1db37fdb,[],[],[],0,0,1
3,67d55379-18ca-40ec-b9da-6aa7117e4e1a,[160],"[(0.0715589, 0.585966, 0.112564, 0.121046)]",[0.801862],1,1,0
4,e6fdae2f-86ee-46c7-8c75-fba6a35b8f0d,[51],"[(0.267414, 0.730147, 0.110058, 0.102909)]",[0.561465],1,1,0
...,...,...,...,...,...,...,...
10739,90508480-7df0-450f-bb1e-80462661cd7b,[],[],[],0,0,1
10740,01230b6e-f53c-431a-850c-b1cbe877f08b,[120],"[(0.695897, 0.572669, 0.100813, 0.11052)]",[0.832999],0,0,0
10741,6af67476-63ea-4ac1-afc3-b79b3ef019bf,[],[],[],0,0,1
10742,5b9257ee-ee54-4936-a192-12facfa04be7,"[119, 52]","[(0.306548, 0.701045, 0.0442856, 0.061971), (0...","[0.829782, 0.689188]",1,1,0


In [165]:
# Combine both prediction tables per image id. merge() defaults to an inner
# join, so an id that appears in only one of the two tables is dropped.
df = df_cat.merge(df_sup, on='id')
df

Unnamed: 0,id,categories,location,conf,weak_shallow,strong_shallow,no_detection,supercategory,categories_s,location_s,conf_s,weak_shallow_s,strong_shallow_s,no_detection_s
0,9cdaaebd-acd7-403c-829c-af5ba519c8a7,[],[],[],0,0,1,[1],[37],"[(0.914333, 0.298541, 0.0802597, 0.0408547)]",[0.407764],0,0,0
1,49254842-f712-4875-ad09-57ebf8508bd1,"[219, 52, 218]","[(0.752803, 0.512912, 0.0655638, 0.0840544), (...","[0.68369, 0.591417, 0.535339]",0,0,0,"[7, 1, 2]","[214, 37, 119]","[(0.435158, 0.679759, 0.0353025, 0.0544208), (...","[0.850874, 0.750873, 0.472209]",1,1,0
2,c7f61e23-a0c1-4687-ad8a-b65a1db37fdb,[],[],[],0,0,1,[],[],[],[],0,0,1
3,67d55379-18ca-40ec-b9da-6aa7117e4e1a,[160],"[(0.0715589, 0.585966, 0.112564, 0.121046)]",[0.801862],1,1,0,[0],[160],"[(0.0714349, 0.58565, 0.114581, 0.127358)]",[0.857191],1,1,0
4,e6fdae2f-86ee-46c7-8c75-fba6a35b8f0d,[51],"[(0.267414, 0.730147, 0.110058, 0.102909)]",[0.561465],1,1,0,[1],[37],"[(0.265701, 0.731048, 0.115989, 0.105286)]",[0.826337],0,0,0
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
10739,90508480-7df0-450f-bb1e-80462661cd7b,[],[],[],0,0,1,[],[],[],[],1,0,1
10740,01230b6e-f53c-431a-850c-b1cbe877f08b,[120],"[(0.695897, 0.572669, 0.100813, 0.11052)]",[0.832999],0,0,0,[2],[119],"[(0.696914, 0.572213, 0.10094, 0.115592)]",[0.729477],1,1,0
10741,6af67476-63ea-4ac1-afc3-b79b3ef019bf,[],[],[],0,0,1,"[1, 6]","[37, 203]","[(0.858225, 0.395525, 0.0630371, 0.0601834), (...","[0.433666, 0.359158]",0,0,0
10742,5b9257ee-ee54-4936-a192-12facfa04be7,"[119, 52]","[(0.306548, 0.701045, 0.0442856, 0.061971), (0...","[0.829782, 0.689188]",1,1,0,"[2, 1, 3, 0]","[119, 37, 10, 160]","[(0.305819, 0.701176, 0.0425898, 0.062373), (0...","[0.755491, 0.514154, 0.424596, 0.37873]",1,1,0


In [180]:
def detect_osd(row):
    '''Assign an ``osd`` score to one merged prediction row and patch empty categories.

    The score is looked up from the four shallow flags
    ``(strong_shallow, strong_shallow_s, weak_shallow, weak_shallow_s)``.
    When the category model detected nothing, ``categories`` is replaced: by
    ``[52]`` if the super-category model detected nothing either, otherwise by
    ``categories_s``.

    Args:
        row: row-like object (e.g. a pd.Series from ``DataFrame.apply(axis=1)``)
            exposing the flag fields as attributes and supporting item assignment
    Returns:
        row: the same object, modified in place
    '''
    # flag tuple -> osd score
    # NOTE(review): the original if-chain assigned 0.5 and then 0.4 to (0,1,0,1);
    # the later value (0.4) is the effective one and is what is kept here.
    osd_by_flags = {
        (0, 0, 0, 0): 0.9,
        (0, 0, 0, 1): 0.7,
        (0, 1, 0, 1): 0.4,
        (0, 0, 1, 0): 0.4,
        (0, 0, 1, 1): 0.3,
        (1, 0, 1, 0): 0.2,
        (0, 1, 1, 1): 0.1,
        (1, 0, 1, 1): 0.1,
        (1, 1, 1, 1): 0.0,
    }
    flags = (row.strong_shallow, row.strong_shallow_s, row.weak_shallow, row.weak_shallow_s)

    if row.no_detection:
        if row.no_detection_s:
            # no detections from either model - must be osd
            # NOTE(review): this 1.0 is replaced below whenever the flag tuple is
            # in the table, so it only survives for unlisted combinations - confirm intent.
            row['osd'] = 1.0
            row['categories'] = [52] # setting to the most common deep object
        else:
            # nothing detected by cat, something detected by super
            row['categories'] = row.categories_s

    if flags in osd_by_flags:
        row['osd'] = osd_by_flags[flags]

    return row

In [None]:
import stat


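# Draft osd score table (scratch notes, not executable code).
# NOTE(review): the header is kept as written; the four flag columns appear to
# follow detect_osd's tuple order - confirm. Rows without a score were never filled in.
# stat ss w ws
# 0 0 0 0 0.9
# 0 0 0 1 0.8
#
# 0 1 0 1 0.5
# 0 0 1 1 0.4
# 0 1 1 1 0.2
# 1 0 1 1
# 1 1 1 1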


In [182]:
# Run detect_osd on every row (axis=1): adds the 'osd' column and replaces
# 'categories' on rows where the category model detected nothing.
out_df = df.apply(detect_osd, axis=1)

In [183]:
# Sanity check: show rows for which detect_osd assigned no osd value.
out_df[out_df['osd'].isnull()]

Unnamed: 0,id,categories,location,conf,weak_shallow,strong_shallow,no_detection,supercategory,categories_s,location_s,conf_s,weak_shallow_s,strong_shallow_s,no_detection_s,osd


In [143]:
def select_top(lst):
    '''Return the first element of ``lst`` rendered as the string ``'[<element>]'``.

    Raises IndexError when ``lst`` is empty.
    '''
    first = lst[0]
    return '[{}]'.format(first)


In [185]:
def format_cat(lst):
    '''Render a list of categories as a string.

    A single-element list becomes ``'[<element>]'``; any other length (including
    empty) becomes the elements joined by single spaces, without brackets.
    '''
    if len(lst) != 1:
        return ' '.join(map(str, lst))
    return '[{}]'.format(lst[0])

In [184]:
# Keep only the three columns that go into the submission file; .copy() makes
# `out` independent of out_df so later edits to it leave out_df untouched.
out = out_df[['id', 'categories', 'osd']].copy()
out

Unnamed: 0,id,categories,osd
0,9cdaaebd-acd7-403c-829c-af5ba519c8a7,[37],1.0
1,49254842-f712-4875-ad09-57ebf8508bd1,"[219, 52, 218]",0.3
2,c7f61e23-a0c1-4687-ad8a-b65a1db37fdb,[52],1.0
3,67d55379-18ca-40ec-b9da-6aa7117e4e1a,[160],0.0
4,e6fdae2f-86ee-46c7-8c75-fba6a35b8f0d,[51],0.0
...,...,...,...
10739,90508480-7df0-450f-bb1e-80462661cd7b,[52],0.7
10740,01230b6e-f53c-431a-850c-b1cbe877f08b,[120],0.3
10741,6af67476-63ea-4ac1-afc3-b79b3ef019bf,"[37, 203]",1.0
10742,5b9257ee-ee54-4936-a192-12facfa04be7,"[119, 52]",0.0


In [186]:
# out.categories = out.categories.apply(select_top)
# Format from out_df (which still holds the list values) instead of from `out`
# itself: re-running this cell then gives the same result rather than applying
# format_cat to strings it has already formatted.
out['categories'] = out_df['categories'].apply(format_cat)
out

Unnamed: 0,id,categories,osd
0,9cdaaebd-acd7-403c-829c-af5ba519c8a7,[37],1.0
1,49254842-f712-4875-ad09-57ebf8508bd1,219 52 218,0.3
2,c7f61e23-a0c1-4687-ad8a-b65a1db37fdb,[52],1.0
3,67d55379-18ca-40ec-b9da-6aa7117e4e1a,[160],0.0
4,e6fdae2f-86ee-46c7-8c75-fba6a35b8f0d,[51],0.0
...,...,...,...
10739,90508480-7df0-450f-bb1e-80462661cd7b,[52],0.7
10740,01230b6e-f53c-431a-850c-b1cbe877f08b,[120],0.3
10741,6af67476-63ea-4ac1-afc3-b79b3ef019bf,37 203,1.0
10742,5b9257ee-ee54-4936-a192-12facfa04be7,119 52,0.0


In [187]:
# Write the submission file; index=False leaves the DataFrame index out of the CSV.
out.to_csv('submission_26.csv', index=False)

In [120]:
# Per-cell null mask for the first 20 rows of out_df (True marks a missing value).
out_df.isnull().head(20)

Unnamed: 0,categories,categories_s,conf,conf_s,id,location,location_s,no_detection,no_detection_s,osd,strong_shallow,strong_shallow_s,supercategory,weak_shallow,weak_shallow_s
0,False,False,False,False,False,False,False,False,False,False,False,False,False,False,False
1,False,False,False,False,False,False,False,False,False,False,False,False,False,False,False
2,False,False,False,False,False,False,False,False,False,False,False,False,False,False,False
3,False,False,False,False,False,False,False,False,False,False,False,False,False,False,False
4,False,False,False,False,False,False,False,False,False,False,False,False,False,False,False
5,False,False,False,False,False,False,False,False,False,False,False,False,False,False,False
6,False,False,False,False,False,False,False,False,False,False,False,False,False,False,False
7,False,False,False,False,False,False,False,False,False,False,False,False,False,False,False
8,False,False,False,False,False,False,False,False,False,False,False,False,False,False,False
9,False,False,False,False,False,False,False,False,False,False,False,False,False,False,False


In [5]:
# df.to_json('runs/predict.json')

In [6]:
# Display the predictions stored in runs/predict133m.json.
pd.read_json('runs/predict133m.json')

Unnamed: 0,id,categories,location,conf,weak_shallow,strong_shallow,no_detection
0,9cdaaebd-acd7-403c-829c-af5ba519c8a7,[37],"[[0.39912000000000003, 0.428155, 0.450171, 0.8...",[0.44142699999999996],1,1,0
1,49254842-f712-4875-ad09-57ebf8508bd1,"[219, 218, 52]","[[0.970996, 0.800569, 0.058007699999999995, 0....","[0.849691, 0.601008, 0.491423]",0,0,0
2,c7f61e23-a0c1-4687-ad8a-b65a1db37fdb,[],[],[],0,0,1
3,67d55379-18ca-40ec-b9da-6aa7117e4e1a,[160],"[[0.8186859999999999, 0.389267, 0.110638, 0.09...",[0.7810469999999999],0,0,0
4,e6fdae2f-86ee-46c7-8c75-fba6a35b8f0d,"[51, 52, 283]","[[0.264021, 0.7309869999999999, 0.114717, 0.10...","[0.488928, 0.318465, 0.275636]",0,0,0
...,...,...,...,...,...,...,...
10739,90508480-7df0-450f-bb1e-80462661cd7b,[1],"[[0.6215609999999999, 0.0895358, 0.06382059999...",[0.48776899999999995],0,0,0
10740,01230b6e-f53c-431a-850c-b1cbe877f08b,[88],"[[0.697519, 0.569529, 0.0993145, 0.11152799999...",[0.8101659999999999],0,0,0
10741,6af67476-63ea-4ac1-afc3-b79b3ef019bf,[],[],[],0,0,1
10742,5b9257ee-ee54-4936-a192-12facfa04be7,"[9, 119, 88, 160, 51]","[[0.360134, 0.615629, 0.0440992, 0.110397], [0...","[0.7970959999999999, 0.793121, 0.634007, 0.584...",0,0,0


In [4]:
# Display the predictions stored in runs/predict40m.json.
pd.read_json('runs/predict40m.json')

Unnamed: 0,id,categories,location,conf,weak_shallow,strong_shallow,no_detection
0,9cdaaebd-acd7-403c-829c-af5ba519c8a7,[],[],[],0,0,1
1,49254842-f712-4875-ad09-57ebf8508bd1,"[219, 52, 218, 242]","[[0.752803, 0.5129119999999999, 0.065563799999...","[0.68369, 0.591417, 0.535339, 0.29825599999999...",0,0,0
2,c7f61e23-a0c1-4687-ad8a-b65a1db37fdb,[],[],[],0,0,1
3,67d55379-18ca-40ec-b9da-6aa7117e4e1a,[160],"[[0.0715589, 0.585966, 0.112564, 0.121046]]",[0.801862],1,1,0
4,e6fdae2f-86ee-46c7-8c75-fba6a35b8f0d,[51],"[[0.267414, 0.730147, 0.11005799999999999, 0.1...",[0.561465],1,1,0
...,...,...,...,...,...,...,...
10739,90508480-7df0-450f-bb1e-80462661cd7b,[],[],[],0,0,1
10740,01230b6e-f53c-431a-850c-b1cbe877f08b,[120],"[[0.695897, 0.572669, 0.100813, 0.110520000000...",[0.8329989999999999],0,0,0
10741,6af67476-63ea-4ac1-afc3-b79b3ef019bf,[205],"[[0.750776, 0.728356, 0.0532152, 0.105277]]",[0.432376],0,0,0
10742,5b9257ee-ee54-4936-a192-12facfa04be7,"[119, 52, 160, 9]","[[0.306548, 0.7010449999999999, 0.0442856, 0.0...","[0.8297819999999999, 0.689188, 0.4957209999999...",1,1,0
