In [10]:
import glob
import os
import warnings

import numpy as np
import pandas as pd
import plotly.express as px
import plotly.graph_objects as go
import tqdm
from plotly.subplots import make_subplots

In [11]:
def cm_to_iou(cm):
    """Per-class intersection-over-union from a square confusion matrix.

    For each class ``k``:  IoU_k = TP_k / (TP_k + FP_k + FN_k), where TP is the
    diagonal, FP is the column sum minus TP and FN is the row sum minus TP.
    (IoU is symmetric in FP/FN, so the row/column convention does not change
    the result.)

    Classes with no ground truth *and* no prediction give 0/0 = NaN; they are
    dropped rather than counted as IoU 0.

    Args:
        cm: square confusion matrix (array-like of counts).

    Returns:
        np.ndarray of IoU values for the remaining classes, squeezed — so a
        single remaining class yields a 0-d array.
    """
    cm = np.asarray(cm)
    tps = np.diag(cm)
    fps = cm.sum(axis=0) - tps
    fns = cm.sum(axis=1) - tps
    # 0/0 for absent classes is expected and handled just below; silence
    # numpy's RuntimeWarning instead of letting it clutter the output.
    with np.errstate(divide='ignore', invalid='ignore'):
        iou = tps / (tps + fps + fns)
    iou = iou[~np.isnan(iou)]
    return iou.squeeze()

def cm_to_miou(cm):
    """Mean IoU of a confusion matrix.

    Averages only the classes that ``cm_to_iou`` keeps, i.e. classes absent
    from both ground truth and prediction do not pull the mean down. If no
    class survives, numpy's mean of an empty array gives NaN.
    """
    per_class = cm_to_iou(cm)
    return per_class.mean()

def read_metrics(metrics_path):
    """Read the value after the last ':' on the last non-blank line of a metrics file.

    Presumably that final line is the mIoU entry of ``metrics.txt`` (e.g.
    ``mIoU: 0.73``) — TODO confirm the writer's format.

    Blank or whitespace-only trailing lines are ignored; previously a single
    trailing blank line made this fail with ``float('')``.

    Args:
        metrics_path: path to the text file.

    Returns:
        float parsed from the last non-blank line.

    Raises:
        ValueError: if the file has no non-blank line, or the value after the
            last ':' is not a float.
    """
    with open(metrics_path, 'r') as f:
        content_lines = [line for line in f if line.strip()]
    if not content_lines:
        raise ValueError(f'no metrics found in {metrics_path!r}')
    return float(content_lines[-1].split(':')[-1].strip())

def add2dict(metrics_dict, label_set, seg_method, lulc_type, d3_type, res, refine_type, dataset, trajectory, cm, miou, jitter, jitter_type):
    """Append one run's metadata and metrics as a new row of ``metrics_dict``.

    ``metrics_dict`` is column-oriented: a dict mapping each column name to a
    list. Every column listed below must already exist (else ``KeyError``).
    Extra columns are left untouched. Mutates ``metrics_dict`` in place and
    returns ``None``.
    """
    new_row = (
        ('dataset', dataset),
        ('trajectory', trajectory),
        ('label_set', label_set),
        ('seg_method', seg_method),
        ('lulc_type', lulc_type),
        ('d3_type', d3_type),
        ('res', res),
        ('refine_type', refine_type),
        ('cm', cm),
        ('miou', miou),
        ('jitter', jitter),
        ('jitter_type', jitter_type),
    )
    for column, value in new_row:
        metrics_dict[column].append(value)


In [12]:
# Class-index -> class-name lookup tables, one per label set, from finest
# (6 classes) to coarsest (3 classes). The names match the 'label_set'
# values used in the plot ('common', 'more_common', 'most_common'), so
# presumably these are the same label sets; no visible cell reads them.
common_idx2names = dict(enumerate([
    'water',
    'trees',
    'low_vegetation',
    'built',
    'ground',
    'sky',
]))

more_common_idx2names = dict(enumerate([
    'water',
    'vegetation',
    'built',
    'ground',
    'sky',
]))

most_common_idx2names = dict(enumerate([
    'water',
    'ground',
    'sky',
]))

In [13]:
# Nine wildcard directory levels; the next cell unpacks them, in order, as
# label_set/seg_method/lulc_type/jitter/d3_type/res/refine_type/dataset/trajectory.
# glob() returns matches in arbitrary, filesystem-dependent order — sort so
# the row order of the metrics table is reproducible across machines/runs.
cm_paths = sorted(glob.glob('./jitter_outputs/*/*/*/*/*/*/*/*/*/confusion_matrix.npy'))

In [14]:
# Build a column-oriented table: one row per confusion matrix in `cm_paths`.
# Column names mirror the keyword arguments of add2dict.
METRIC_COLUMNS = [
    'dataset', 'trajectory', 'label_set', 'seg_method', 'lulc_type',
    'd3_type', 'res', 'refine_type', 'cm', 'miou', 'jitter', 'jitter_type',
]
# GPS / altitude jitter magnitudes above this are excluded from the analysis
# (presumably metres, judging by the "2m" plot markers — TODO confirm).
MAX_POSITION_JITTER = 10

metrics_dict = {column: [] for column in METRIC_COLUMNS}
skipped_paths = []

for cm_filepath in tqdm.tqdm(cm_paths):
    # The nine directory levels between ./jitter_outputs and the file name.
    (label_set, seg_method, lulc_type, jitter_dir, d3_type,
     res, refine_type, dataset, trajectory) = cm_filepath.split(os.path.sep)[-10:-1]

    # The jitter directory is named "<gps>_<altitude>_<imu>". Each run is
    # expected to perturb one source; classify it by which one is nonzero.
    gps_jitter, altitude_jitter, imu_jitter = map(float, jitter_dir.split('_'))
    if gps_jitter == 0 and altitude_jitter == 0:
        # NOTE(review): this branch also catches the all-zero baseline, so the
        # unperturbed run only appears under 'imu' — confirm that is intended.
        jitter_type, jitter = 'imu', imu_jitter
    elif gps_jitter == 0 and imu_jitter == 0:
        jitter_type, jitter = 'altitude', altitude_jitter
    elif imu_jitter == 0 and altitude_jitter == 0:
        jitter_type, jitter = 'gps', gps_jitter
    else:
        # More than one source perturbed: cannot be attributed to a single
        # jitter type, so skip it instead of reusing a stale classification.
        skipped_paths.append(cm_filepath)
        continue

    # Filter before loading so excluded runs cost no I/O.
    if jitter_type in ('gps', 'altitude') and jitter > MAX_POSITION_JITTER:
        continue

    # mIoU is recomputed from the confusion matrix; metrics.txt is not used.
    cm = np.load(cm_filepath)
    miou = cm_to_iou(cm).mean()

    add2dict(metrics_dict, label_set, seg_method, lulc_type, d3_type, res, refine_type, dataset, trajectory, cm, miou, jitter, jitter_type)

if skipped_paths:
    warnings.warn(
        f'skipped {len(skipped_paths)} run(s) with more than one nonzero '
        f'jitter source, e.g. {skipped_paths[0]}'
    )
    



invalid value encountered in divide

100%|██████████| 792/792 [00:00<00:00, 5954.06it/s]


In [15]:
# Column-oriented dict of lists -> DataFrame with one row per confusion matrix.
df = pd.DataFrame(metrics_dict)

In [16]:
# Runs sharing a label set and a (jitter source, magnitude) pair are pooled,
# i.e. statistics are taken over every other column (trajectory, dataset, ...).
unique_ids = ['label_set', 'jitter_type', 'jitter']
grouped_df = df.groupby(by=unique_ids)

In [17]:
# Group statistics broadcast back onto every row (transform keeps df's
# index), so each row carries the summary of its unique_ids group.
df['trajectory_count'] = grouped_df['miou'].transform('count')
# Intended as the element-wise sum of the group's confusion matrices
# (object column of ndarrays); it is consumed as a matrix just below.
df['cm_sum'] = grouped_df['cm'].transform('sum')
df['trajectory_avg_miou'] = grouped_df['miou'].transform('mean')
df['trajectory_miou_std'] = grouped_df['miou'].transform('std')
# mIoU of the pooled confusion matrix. Series.map always yields a Series,
# whereas a row-wise apply(axis=1) returns a DataFrame on an empty frame.
df['miou (total)'] = df['cm_sum'].map(cm_to_miou)

# Stable sort so rows with equal jitter keep their existing relative order;
# reassign instead of sorting in place.
df = df.sort_values(by='jitter', kind='stable')

In [18]:
# Every row of a unique_ids group carries the same group statistics, so keep
# one row per group: the drawn lines are unchanged, but points are no longer
# stacked once per trajectory. Sorting keeps x monotonic along each line.
plot_df = (
    df.drop_duplicates(subset=unique_ids)
      .sort_values(by='jitter', kind='stable')
)

fig = px.line(plot_df, x='jitter', y='trajectory_avg_miou',
              color='label_set', facet_col='jitter_type', facet_col_spacing=0.04, markers=True, title=None,
              color_discrete_map={
                    'common': '#FF7912',
                    'more_common': '#0091FF',
                    'most_common': '#2E5C80'
              },
              category_orders={'jitter_type': ['altitude', 'gps', 'imu']},
              symbol='label_set',
              symbol_map={
                    'common': 'circle',
                    'more_common': 'circle',
                    'most_common': 'x',
              }
              )
# Unlink the facet x axes so the imu panel (col 3) can use its own range.
fig.update_xaxes(matches=None)

# Blank the automatic facet titles. This must run before add_vline below,
# whose labels are appended to the same annotation list.
for anno in fig['layout']['annotations']:
    anno['text'] = ''

# Facet columns follow category_orders: 1 = altitude, 2 = gps, 3 = imu.
for col in (1, 2):
    fig.update_xaxes(range=[0, 11], row=1, col=col)
fig.update_yaxes(range=[0.3, 0.85])
fig.update_xaxes(title_text=None, nticks=6)
fig.update_xaxes(tickfont=dict(size=16), tickangle=0, tickformat=".1f", row=1, col=3)
fig.update_layout(
    title=None,
    legend=dict(
        title=None,
        orientation="h",
        yanchor="top",
        y=-0.2,
        xanchor="right",
        x=1,
    ),
    showlegend=False,  # the legend settings above apply only if re-enabled
    template='simple_white',
    width=700,
    height=200,
    margin=dict(l=20, r=20, t=10, b=0),
    font=dict(size=16),
    xaxis_title=None,
    yaxis_title=None,
)

# Dotted reference markers: 2m on altitude/gps, 1.8 degrees on imu.
vline_style = dict(line_width=1, line_dash="dot", line_color="black",
                   annotation_position='right', annotation_font=dict(size=16))
for col, x_ref, label in ((1, 2, "2m"), (2, 2, "2m"), (3, 1.8, "1.8&deg;")):
    fig.add_vline(x=x_ref, annotation_text=label, row=1, col=col, **vline_style)

fig.show()
# write_image does not create missing directories.
os.makedirs('jitter', exist_ok=True)
fig.write_image('jitter/jitter_miou.pdf')