In [115]:
import matplotlib
import matplotlib.pyplot as plt
import pandas as pd
from matplotlib import rc

from lowAltitude_classification.gsd_utils import papermode

In [116]:
# Base font size; reused by the plotting cell for the rotated y-axis annotation.
font_size = 8

# Project helper from gsd_utils; presumably applies paper-style rcParams with
# LaTeX text rendering enabled -- TODO confirm against its definition.
papermode(plt, font_size, has_latex=True)

# Load the LaTeX `color` package in the preamble used for usetex rendering.
matplotlib.rcParams['text.latex.preamble'] = r'\usepackage{color}'

# Switch to the PostScript backend.
matplotlib.use('ps')

In [117]:
# Final scaling-experiment metrics. Earlier exports
# ('results/scaling-test-METRICS.csv', 'results/scaling-test-METRICS_version2.csv')
# are superseded by this file.
data = pd.read_csv('results/scaling-test-METRICS_versionFinal2.csv')

# Size of the full image set; each row's factor scales it to a subset size.
num_images = 143065

# Express F1 as a percentage (the table below shows values such as 20.95).
data['F1'] = data['F1'] * 100


def _fraction_to_float(text: str) -> float:
    """Convert an 'a/b' string (or a bare integer 'a') to a float.

    Explicit parsing replaces the previous `eval()` call, so text coming from
    the CSV is never executed as Python code.

    Raises:
        ValueError: if either side of the '/' is not an integer literal.
        ZeroDivisionError: if the denominator is 0.
    """
    numerator, _, denominator = text.partition('/')
    return int(numerator) / int(denominator or 1)


# The CSV encodes the fraction a/b as the decimal-looking token "a.b"
# (e.g. 1.512 stands for 1/512), so swapping '.' for '/' recovers the fraction.
# NOTE(review): the column is parsed as float, so a denominator ending in 0
# (e.g. 1/10 written as 1.10) would silently lose its trailing zero. Read the
# column with dtype=str if such factors are ever added.
data['fraction'] = data['factor'].map(lambda value: str(value).replace('.', '/'))
data['factor'] = data['fraction'].map(_fraction_to_float)

# Subset size in images, truncated toward zero, then ordered for plotting.
data['num_images'] = (num_images * data['factor']).astype(int)
data = data.sort_values('num_images')

data

Unnamed: 0,factor,F1,pAcc,fraction,num_images
10,0.001953,20.95,0.3403,1/512,279
9,0.003906,27.97,0.4052,1/256,558
8,0.007812,32.23,0.4428,1/128,1117
7,0.015625,34.89,0.4703,1/64,2235
6,0.03125,35.12,0.4791,1/32,4470
5,0.0625,36.33,0.4755,1/16,8941
4,0.125,38.04,0.4728,1/8,17883
3,0.25,37.69,0.4891,1/4,35766
2,0.5,37.82,0.486,1/2,71532
1,0.75,38.93,0.4936,3/4,107298


In [128]:
# Figure size in inches. A golden-ratio height (width / 1.618) was tried and
# replaced by a fixed, flatter value.
width = 3.5
height = 1.5

# Axis limits, shared by the axes and by the reference lines drawn below.
x_min, x_max = 200, num_images + 50000
y_min, y_max = 15, 50

fig, ax = plt.subplots(figsize=(width, height))

# Main curve: F1 (in %) against the number of images.
ax.plot(data['num_images'], data['F1'], color='teal', marker='o')
# Raw strings keep the LaTeX backslashes without invalid-escape warnings.
ax.set_xlabel(r'Pre-training Set Size (\# images)', labelpad=0)
# The \hspace gap in the y-label leaves room for the blue dataset symbol,
# which is drawn as a separate text artist so it can have its own colour.
ax.set_ylabel(r"$F1$ Score on \hspace{2.2em} (\%)")
ax.text(-0.105, 0.75, r"$D_{test}^{drone}$", fontsize=font_size, color="blue",
        transform=ax.transAxes, ha='center', va='center', rotation=90)
ax.set_xscale('log', base=2)
ax.grid(True, which="major", ls="--", color='gray')
ax.set_xlim(x_min, x_max)
ax.set_ylim(y_min, y_max)

# Bottom x ticks: one per data point, labelled with the image count.
bottom_labels = data['num_images'].tolist()
ax.set_xticks(data['num_images'], bottom_labels, rotation=45, ha='right', rotation_mode="anchor")
ax.tick_params(axis='x', pad=0)

# Top x ticks: same positions on a twin axis, labelled with the fraction.
ax2 = ax.twiny()
ax2.set_xscale('log', base=2)
ax2.set_xlim(ax.get_xlim())
special_top_labels = {'1/1': '1', '3/4': r'$\frac{3}{4}$'}
top_labels = [special_top_labels.get(frac, frac) for frac in data['fraction']]
ax2.set_xticks(data['num_images'], top_labels)

# Reference lines for other datasets and methods. Each text label gets an
# opaque white box so the dashed lines and grid do not cross it.
label_style = dict(
    ha='right',
    backgroundcolor='white',
    bbox=dict(facecolor='white', alpha=1, edgecolor='none', boxstyle='round,pad=0'),
)
ax.text(1_200, 43, 'Pseudo-labels', color='chocolate', **label_style)
ax.hlines(40.69, x_min, x_max, linestyles='dashed', color='chocolate', label='Pseudo-labels')
ax.text(180_000, 21.0, 'Supervised', color='black', **label_style)
ax.hlines(26.83, x_min, x_max, linestyles='dashed', color='black', label='Supervised')
# NOTE(review): the visible annotation cites [33] while the artist label says
# [35]. The label is not rendered (no legend is drawn), but confirm the
# correct reference number and make the two agree.
ax.text(19_500, 43, r'Soltani \textit{et al.} [33]', color='green', **label_style)
ax.vlines(20_523, y_min, y_max, linestyles='dashed', color='green', label='Soltani et al. [35]')

fig.subplots_adjust(top=0.835, bottom=0.305, left=0.12, right=0.99)

fig.savefig('results/M2F_scaling.pdf')
fig.savefig('results/M2F_scaling.png')

# No fig.show(): the 'ps' backend selected earlier is non-interactive, so
# show() cannot display anything and only emits a UserWarning. Inspect the
# saved PDF/PNG instead.

  fig.show()
