In [None]:
import os
import re
import tempfile
import warnings
import collections
import numpy as np
import pandas as pd
import seaborn as sns
import matplotlib as mpl
import matplotlib.pyplot as plt

from google3.learning.deepmind.xmanager2.client import xmanager_api
from google3.pyglib import gfile
from google3.pyglib.function_utils import memoize
from matplotlib import font_manager

import matplotlib.patches as mpatches
import matplotlib.lines as mlines
from matplotlib.lines import Line2D
from matplotlib.ticker import FixedLocator  # Import for the fix
from matplotlib.ticker import MaxNLocator
from matplotlib.ticker import LogLocator
from matplotlib.ticker import FuncFormatter

In [None]:
#@title Google Sans Import

# Import Google font family
_GOOGLE_SANS_PATH = (
    'google3/third_party/googlefonts/api/googlerestricted/googlesans/'
)

@memoize.Memoize()
def import_google3_fonts(font_path: str) -> None:
  """Copies fonts stored in google3 locally and registers them with Matplotlib.

  The font file(s) are first copied out of google3 into a fresh temporary
  directory, and every copied font is then added to Matplotlib's default
  font manager so it can be selected by family name.

  Args:
    font_path: google3 path to either a directory that contains .ttf fonts or to
      a specific .ttf font file.
  """
  # One private temp directory serves as the destination in both branches.
  # Copying into it avoids the previous pattern of closing (and thereby
  # deleting) a NamedTemporaryFile and then reusing its name, which leaves a
  # window in which another process could claim that path.
  tmp_dir = tempfile.mkdtemp()
  if gfile.IsDirectory(font_path):
    # Copy the whole font directory, then let Matplotlib discover the fonts.
    gfile.RecursivelyCopyDir(font_path, tmp_dir, overwrite=True)
    font_files = font_manager.findSystemFonts(fontpaths=tmp_dir)
  else:
    # Assume the path points to a single font file if it's not a directory.
    # Keeping the original basename also keeps the original file extension.
    local_path = os.path.join(tmp_dir, os.path.basename(font_path))
    gfile.Copy(font_path, local_path)
    font_files = [local_path]

  # Add fonts to the default font manager.
  for font_file in font_files:
    font_manager.fontManager.addfont(font_file)


def import_default_google_fonts(
    depot_root: str = '/google_src/head/depot',
) -> None:
  """Registers the Google Sans font family with Matplotlib.

  Args:
    depot_root: Local directory that google3-relative paths are joined onto
      before copying. Defaults to '/google_src/head/depot', the previously
      hard-coded location.
  """
  import_google3_fonts(os.path.join(depot_root, _GOOGLE_SANS_PATH))


# Import and register Google Sans with Matplotlib so later cells can select it
# by family name.
import_default_google_fonts()

In [None]:
#@title Set up Plot Settings

# Display every row and column of a DataFrame instead of truncating.
for _pd_option in ('display.max_rows', 'display.max_columns'):
  pd.set_option(_pd_option, None)

# Hide UserWarnings whose originating module matches "matplotlib".
warnings.filterwarnings("ignore", category=UserWarning, module="matplotlib")

xm_client = xmanager_api.XManagerApi(xm_deployment_env='alphabet')

# Base text size for all figure text; legends default to a smaller size.
MEDIUM_SIZE = 12

# Single source of truth for the figure style (each key was previously set up
# to three times via rcParams.update, rcParams[...] and plt.rc).
mpl.rcParams.update({
    # Text sizes.
    'font.size': MEDIUM_SIZE,
    'axes.titlesize': MEDIUM_SIZE,
    'axes.labelsize': MEDIUM_SIZE,
    'xtick.labelsize': MEDIUM_SIZE,
    'ytick.labelsize': MEDIUM_SIZE,
    'legend.fontsize': MEDIUM_SIZE - 5,
    'figure.titlesize': MEDIUM_SIZE,
    # Axes frame styling.
    'axes.linewidth': 1,
    'axes.edgecolor': '#777777',
    'axes.facecolor': '#FFFFFF',
    # Font family registered by the font-import cell.
    'font.family': 'Google Sans',
})

elegant_palette = sns.color_palette('muted')

def log_float_formatter(y, pos):
    """Formats a tick value with exactly two decimal places.

    The (value, position) signature matches what
    `matplotlib.ticker.FuncFormatter` passes to its callback.

    Args:
      y: Tick value to format.
      pos: Tick position index; not used.

    Returns:
      `y` rendered as a fixed-point string with two decimals, e.g. '3.14'.
    """
    del pos  # Unused; present only to satisfy the callback signature.
    return format(y, '.2f')


In [None]:
# Create the figure and axes for the subplots
def log_float_formatter(y, pos):
    return f'{y:.2f}'

#@title Data vs. Subject Scaling

map_results = {
    'num_of_shots': [5, 10, 15, 20],
    'lsm_ft': [18.558182325716704, 18.594820292849427, 21.030589082598738, 24.356139254039988],
    # 'lsm_conv_lp': [18.22190621877023, 20.018072182267236, 20.018072182267236, 22.597159887334872],
    'lsm_lp': [14.712633201006353, 14.804711532807987, 15.451279084073391, 17.756282961336474],
    'supervised_baseline': [12.781144280731132, 13.56088290669249, 13.222288420289424, 15.153588158587056]
}


df = pd.DataFrame(map_results)

fig, axes = plt.subplots(1, 1, figsize=(3, 3), sharex=True, dpi=100)
# Create a color palette using the "Blues" color map
palette = sns.color_palette("Blues", n_colors=3)

# Plot the lines with the "Blues" palette
sns.lineplot(
    data=map_results,
    x='num_of_shots',
    y='lsm_ft',
    color=palette[2],
    marker='o',
    label='LSM + FT',
    markersize=5,
    linestyle='--',
)

# sns.lineplot(
#     data=map_results,
#     x='num_of_shots',
#     y='lsm_conv_lp',
#     color=palette[2],
#     marker='o',
#     label='LSM + Conv Probe',
#     markersize=5,
#     linestyle='--',
# )

sns.lineplot(
    data=map_results,
    x='num_of_shots',
    y='lsm_lp',
    color=palette[1],
    marker='o',
    label='LSM + Linear Probe',
    markersize=5,
    linestyle='--',
)

sns.lineplot(
    data=map_results,
    x='num_of_shots',
    y='supervised_baseline',
    color=palette[0],
    marker='o',
    label='Baseline',
    markersize=5,
    linestyle='--',
)
axes.set_ylabel('mAP')
axes.set_xlabel('Number of Shots')
plt.legend(frameon=False, fontsize=MEDIUM_SIZE-2)
plt.tight_layout()
plt.savefig("/tmp/few_shot_lsm_map.pdf", bbox_inches='tight', format="pdf")
%download_file /tmp/few_shot_lsm_map.pdf
plt.show()


In [None]:
accuracy_results = {
    'num_of_shots': [5, 10, 15, 20],
    'lsm_ft': [19.374616799503563, 30.104230533405858, 36.57265481298695, 51.16492949109407],
    # 'lsm_conv_lp': [12.660944206004702, 20.570202329852677, 27.86633966890623, 27.86633966890623, 36.879215205384156],
    'lsm_lp': [12.231759656648611, 20.079705702017144, 21.030042918448487, 22.256284488037323],
    'supervised_baseline': [10.33108522378592, 16.400980993250645, 16.27835683629176, 18.5469037400311]
}

fig, axes = plt.subplots(1, 1, figsize=(3.3, 3), sharex=True, dpi=100)
# Create a color palette using the "Blues" color map
palette = sns.color_palette("Blues", n_colors=3)

# Plot the lines with the "Blues" palette
sns.lineplot(
    data=accuracy_results,
    x='num_of_shots',
    y='lsm_ft',
    color=palette[2],
    marker='o',
    label='LSM + FT',
    markersize=5,
    linestyle='--',
)

# sns.lineplot(
#     data=accuracy_results,
#     x='num_of_shots',
#     y='lsm_conv_lp',
#     color=palette[2],
#     marker='o',
#     label='LSM + Conv Probe',
#     markersize=5,
#     linestyle='--',
# )

sns.lineplot(
    data=accuracy_results,
    x='num_of_shots',
    y='lsm_lp',
    color=palette[1],
    marker='o',
    label='LSM + Linear Probe',
    markersize=5,
    linestyle='--',
)

sns.lineplot(
    data=accuracy_results,
    x='num_of_shots',
    y='supervised_baseline',
    color=palette[0],
    marker='o',
    label='Baseline',
    markersize=5,
    linestyle='--',
)
axes.set_ylabel('Accuracy (%)')
axes.set_xlabel('Number of Shots')
plt.legend(frameon=False, fontsize=MEDIUM_SIZE-2)
plt.tight_layout()
plt.savefig("/tmp/few_shot_lsm_accuracy.pdf", bbox_inches='tight', format="pdf")
%download_file /tmp/few_shot_lsm_accuracy.pdf
plt.show()
