In [1]:
#@title Install kfac_jax

# need version of kfac_jax from after September 13 due to:
# https://github.com/google-deepmind/kfac-jax/issues/142

!pip uninstall kfac-jax -y
!pip install git+https://github.com/google-deepmind/kfac-jax@d3643a1ad85cd34ce9ec096a64c5a44708743217

Found existing installation: kfac-jax 0.0.5
Uninstalling kfac-jax-0.0.5:
  Successfully uninstalled kfac-jax-0.0.5
Collecting git+https://github.com/google-deepmind/kfac-jax@d3643a1ad85cd34ce9ec096a64c5a44708743217
  Cloning https://github.com/google-deepmind/kfac-jax (to revision d3643a1ad85cd34ce9ec096a64c5a44708743217) to /tmp/pip-req-build-l8inly0w
  Running command git clone --filter=blob:none --quiet https://github.com/google-deepmind/kfac-jax /tmp/pip-req-build-l8inly0w
  Running command git rev-parse -q --verify 'sha^d3643a1ad85cd34ce9ec096a64c5a44708743217'
  Running command git fetch -q https://github.com/google-deepmind/kfac-jax d3643a1ad85cd34ce9ec096a64c5a44708743217
  Running command git checkout -q d3643a1ad85cd34ce9ec096a64c5a44708743217
  Resolved https://github.com/google-deepmind/kfac-jax to commit d3643a1ad85cd34ce9ec096a64c5a44708743217
  Preparing metadata (setup.py) ... [?25l[?25hdone
Building wheels for collected packages: kfac-jax
  Building wheel for kfac-ja

'\nfrom colabtools import adhoc_import\nwith adhoc_import.Google3Head():\n  import kfac_jax\n#'

In [2]:
#@title Import Jax libraries

import flax.linen as nn
import jax
import jax.numpy as jnp
from jax import random

import kfac_jax

import pandas as pd

import matplotlib.pyplot as plt
import matplotlib.cm as cm
import matplotlib.colors as colors

import os
import shutil
import time


In [3]:

# Parameter counts of the eight models; this is the x-axis used for every fit
# in the fitting cell (plotted as "Number of Model Parameters").
x__params = [125e6, 350e6, 760e6, 1.3e9, 2.7e9, 6.7e9, 13e9, 175e9]
# NOTE(review): presumably the training compute of the same eight models
# (units are not stated here) — confirm against the paper. Not referenced by
# the visible cells.
x__compute = [2.60, 7.42, 1.58e1, 2.75e1, 5.52e1, 1.39e2, 2.68e2, 3.64e3]
# NOTE(review): presumably the upstream (pre-training) loss of the same eight
# models — confirm against the paper. Not referenced by the visible cells.
x__upstream_loss = [2.75, 2.47, 2.3, 2.25, 2.1, 2.03, 1.97, 1.79]

#x = x__params

# This data is from "Table H.1: Scores for every task, setting and model that we investigate in this paper" of GPT-3 arxiv paper: https://arxiv.org/abs/2005.14165
columns = ["Name", "Metric", "Split", "Fine-tune SotA", "K", "Zero-Shot", "One-Shot", "Few-Shot", "(test, server)"]
data = [
["HellaSwag", "acc", "dev", 85.6, 20, [33.7, 43.6, 51.0, 54.7, 62.8, 67.4, 70.9, 78.9], [33.0, 42.9, 50.5, 53.5, 61.9, 66.5, 70.0, 78.1], [33.5, 43.1, 51.3, 54.9, 62.9, 67.3, 71.3, 79.3], "None"],
["LAMBADA", "acc", "test", 68.0, 15, [42.7, 54.3, 60.4, 63.6, 67.1, 70.3, 72.5, 76.2], [22.0, 47.1, 52.6, 58.3, 61.1, 65.4, 69.0, 72.5], [22.0, 40.4, 63.2, 57.0, 78.1, 79.1, 81.3, 86.4], "None"],
["LAMBADA", "ppl", "test", 8.63, 15, [18.6, 9.09, 6.53, 5.44, 4.60, 4.00, 3.56, 3.00], [165.0, 11.6, 8.29, 6.46, 5.53, 4.61, 4.06, 3.35], [165.0, 27.6, 6.63, 7.45, 2.89, 2.56, 2.56, 1.92], "None"],
["StoryCloze", "acc", "test", 91.8, 70, [63.3, 68.5, 72.4, 73.4, 77.2, 77.7, 79.5, 83.2], [62.3, 68.7, 72.3, 74.2, 77.3, 78.7, 79.7, 84.7], [62.3, 70.2, 73.9, 76.1, 80.2, 81.2, 83.0, 87.7], "None"],
["NQs", "acc", "test", 44.5, 64, [0.64, 1.75, 2.71, 4.40, 6.01, 5.79, 7.84, 14.6], [1.19, 3.07, 4.79, 5.43, 8.73, 9.78, 13.7, 23.0], [1.72, 4.46, 7.89, 9.72, 13.2, 17.0, 21.0, 29.9], "None"],
["TriviaQA", "acc", "dev", 68.0, 64, [4.15, 7.61, 14.0, 19.7, 31.3, 38.7, 41.8, 64.3], [4.19, 12.9, 20.5, 26.5, 35.9, 44.4, 51.3, 68.0], [6.96, 16.3, 26.5, 32.1, 42.3, 51.6, 57.5, 71.2], 71.2],
["WebQs", "acc", "test", 45.5, 64, [1.77, 3.20, 4.33, 4.63, 7.92, 7.73, 8.22, 14.4], [2.56, 6.20, 8.51, 9.15, 14.5, 15.1, 19.0, 25.3], [5.46, 12.6, 15.9, 19.6, 24.8, 27.7, 33.5, 41.5], "None"],
["Ro→En 16", "BLEU-mb", "test", 39.9, 64, [2.08, 2.71, 3.09, 3.15, 16.3, 8.34, 20.2, 19.9], [0.55, 15.4, 23.0, 26.3, 30.6, 33.2, 35.6, 38.6], [1.25, 20.7, 25.8, 29.2, 33.1, 34.8, 37.0, 39.5], "None"],
["Ro→En 16", "BLEU-sb", "test", "None", 64, [2.39, 3.08, 3.49, 3.56, 16.8, 8.75, 20.8, 20.9], [0.65, 15.9, 23.6, 26.8, 31.3, 34.2, 36.7, 40.0], [1.40, 21.3, 26.6, 30.1, 34.3, 36.2, 38.4, 41.3], "None"],
["En→Ro 16", "BLEU-mb", "test", 38.5, 64, [2.14, 2.65, 2.53, 2.50, 3.46, 4.24, 5.32, 14.1], [0.35, 3.30, 7.89, 8.72, 13.2, 15.1, 17.3, 20.6], [1.25, 5.90, 9.33, 10.7, 14.3, 16.3, 18.0, 21.0], "None"],
["En→Ro 16", "BLEU-sb", "test", "None", 64, [2.61, 3.11, 3.07, 3.09, 4.26, 5.31, 6.43, 18.0], [0.55, 3.90, 9.15, 10.3, 15.7, 18.2, 20.8, 24.9], [1.64, 7.40, 10.9, 12.9, 17.2, 19.6, 21.8, 25.8], "None"],
["Fr→En 14", "BLEU-mb", "test", 35.0, 64, [1.81, 2.53, 3.47, 3.13, 20.6, 15.1, 21.8, 21.2], [1.28, 15.9, 23.7, 26.3, 29.0, 30.5, 30.2, 33.7], [4.98, 25.5, 28.5, 31.1, 33.7, 34.9, 36.6, 39.2], "None"],
["Fr→En 14", "BLEU-sb", "test", "None", 64, [2.29, 2.99, 3.90, 3.60, 21.2, 15.5, 22.4, 21.9], [1.50, 16.3, 24.4, 27.0, 30.0, 31.6, 31.4, 35.6], [5.30, 26.2, 29.5, 32.2, 35.1, 36.4, 38.3, 41.4], "None"],
["En→Fr 14", "BLEU-mb", "test", 45.6, 64, [1.74, 2.16, 2.73, 2.15, 15.1, 8.82, 12.0, 25.2], [0.49, 8.00, 14.8, 15.9, 20.3, 23.3, 24.9, 28.3], [4.08, 14.5, 19.3, 21.5, 24.9, 27.3, 29.5, 32.6], "None"],
["En→Fr 14", "BLEU-sb", "test", 45.9, 64, [2.44, 2.75, 3.54, 2.82, 19.3, 11.4, 15.3, 31.3], [0.81, 10.0, 18.2, 19.3, 24.7, 28.3, 30.1, 34.1], [5.31, 18.0, 23.6, 26.1, 30.3, 33.3, 35.5, 39.9], "None"],
["De→En 16", "BLEU-mb", "test", 40.2, 64, [2.06, 2.87, 3.41, 3.63, 21.5, 17.3, 23.0, 27.2], [0.83, 16.2, 22.5, 24.7, 28.2, 30.7, 33.0, 30.4], [3.25, 22.7, 26.2, 29.2, 32.7, 34.8, 37.3, 40.6], "None"],
["De→En 16", "BLEU-sb", "test", "None", 64, [2.39, 3.27, 3.85, 4.04, 22.5, 18.2, 24.4, 28.6], [0.93, 17.1, 23.4, 25.8, 29.2, 31.9, 34.5, 32.1], [3.60, 23.8, 27.5, 30.5, 34.1, 36.5, 39.1, 43.0], "None"],
["En→De 16", "BLEU-mb", "test", 41.2, 64, [1.70, 2.27, 2.31, 2.43, 12.9, 8.66, 10.4, 24.6], [0.50, 7.00, 12.9, 13.1, 18.3, 20.9, 22.5, 26.2], [3.42, 12.3, 15.4, 17.1, 20.9, 23.0, 26.6, 29.7], "None"],
["En→De 16", "BLEU-sb", "test", 41.2, 64, [2.09, 2.65, 2.75, 2.92, 13.7, 9.36, 11.0, 25.3], [0.54, 7.40, 13.4, 13.4, 18.8, 21.7, 23.3, 27.3], [3.78, 12.9, 16.1, 17.7, 21.7, 24.1, 27.7, 30.9], "None"],
["Winograd", "acc", "test", 93.8, 7, [66.3, 72.9, 74.7, 76.9, 82.4, 85.7, 87.9, 88.3], [63.4, 68.5, 72.9, 76.9, 82.4, 84.6, 86.1, 89.7], [63.4, 67.4, 73.6, 76.9, 84.3, 85.4, 82.4, 88.6], "None"],
["Winogrande", "acc", "dev", 84.6, 50, [52.0, 52.1, 57.4, 58.7, 62.3, 64.5, 67.9, 70.2], [51.3, 53.0, 58.3, 59.1, 61.7, 65.8, 66.9, 73.2], [51.3, 52.6, 57.5, 59.1, 62.6, 67.4, 70.0, 77.7], "None"],
["PIQA", "acc", "dev", 77.1, 50, [64.6, 70.2, 72.9, 75.1, 75.6, 78.0, 78.5, 81.0], [64.3, 69.3, 71.8, 74.4, 74.3, 76.3, 77.8, 80.5], [64.3, 69.4, 72.0, 74.3, 75.4, 77.8, 79.9, 82.3], 82.8],
["ARC (Challenge)", "acc", "test", 78.5, 50, [26.6, 29.5, 31.8, 35.5, 38.0, 41.4, 43.7, 51.4], [25.5, 30.2, 31.6, 36.4, 38.4, 41.5, 43.1, 53.2], [25.5, 28.4, 32.3, 36.7, 39.5, 43.7, 44.8, 51.5], "None"],
["ARC (Easy)", "acc", "test", 92.0, 50, [43.6, 46.5, 53.0, 53.8, 58.2, 60.2, 63.8, 68.8], [42.7, 48.2, 54.6, 55.9, 60.3, 62.6, 66.8, 71.2], [42.7, 51.0, 58.1, 59.1, 62.1, 65.8, 69.1, 70.1], "None"],
["OpenBookQA", "acc", "test", 87.2, 100, [35.6, 43.2, 45.2, 46.8, 53.0, 50.4, 55.6, 57.6], [37.0, 39.8, 46.2, 46.4, 53.4, 53.0, 55.8, 58.8], [37.0, 43.6, 48.0, 50.6, 55.6, 55.2, 60.8, 65.4], "None"],
["Quac", "f1", "dev", 74.4, 5, [21.2, 26.8, 31.0, 30.1, 34.7, 36.1, 38.4, 41.5], [21.1, 26.9, 31.9, 32.3, 37.4, 39.0, 40.6, 43.4], [21.6, 27.6, 32.9, 34.2, 38.2, 39.9, 40.9, 44.3], "None"],
["RACE-h", "acc", "test", 90.0, 10, [35.2, 37.9, 40.1, 40.9, 42.4, 44.1, 44.6, 45.5], [34.3, 37.7, 40.0, 42.0, 43.8, 44.3, 44.6, 45.9], [34.3, 37.0, 40.4, 41.4, 42.3, 44.7, 45.1, 46.8], "None"],
["RACE-m", "acc", "test", 93.1, 10, [42.1, 47.2, 52.1, 52.3, 54.7, 54.4, 56.7, 58.4], [42.3, 47.3, 51.7, 55.2, 56.1, 54.7, 56.9, 57.4], [42.3, 47.0, 52.7, 53.0, 55.6, 55.4, 58.1, 58.1], "None"],
["SQuADv2 em", "em", "dev", 90.7, 16, [22.6, 32.8, 33.9, 43.1, 43.6, 45.4, 49.0, 52.6], [25.1, 37.5, 37.9, 47.9, 47.9, 51.1, 56.0, 60.1], [27.5, 40.5, 39.2, 53.5, 50.0, 56.6, 62.6, 64.9], "None"],
["SQuADv2 f1", "f1", "dev", 93.0, 16, [28.3, 40.2, 41.4, 50.3, 51.0, 52.7, 56.3, 59.5], [30.1, 43.6, 44.1, 54.0, 54.1, 57.1, 61.8, 65.4], [32.1, 45.5, 44.9, 58.7, 55.9, 62.1, 67.7, 69.8], "None"],
["CoQA", "f1", "dev", 90.7, 5, [34.5, 55.0, 61.8, 65.3, 71.1, 72.8, 76.3, 81.5], [30.6, 52.1, 61.6, 66.1, 71.8, 75.1, 77.9, 84.0], [31.1, 52.0, 62.7, 66.8, 73.2, 77.3, 79.9, 85.0], "None"],
["DROP", "f1", "dev", 89.1, 20, [9.40, 13.6, 14.4, 16.4, 19.7, 17.0, 24.0, 23.6], [11.7, 18.1, 20.9, 23.0, 26.4, 27.3, 29.2, 34.3], [12.9, 18.7, 24.0, 25.6, 29.7, 29.7, 32.3, 36.5], "None"],
["BoolQ", "acc", "dev", 91.0, 32, [49.7, 60.3, 58.9, 62.4, 67.1, 65.4, 66.2, 60.5], [52.6, 61.7, 60.4, 63.7, 68.4, 68.7, 69.0, 76.7], [43.1, 60.6, 62.0, 64.1, 70.3, 70.0, 70.2, 77.5], 76.4],
["CB", "acc", "dev", 96.9, 32, [0.00, 32.1, 8.93, 19.6, 19.6, 28.6, 19.6, 46.4], [55.4, 53.6, 53.6, 48.2, 57.1, 33.9, 55.4, 64.3], [42.9, 58.9, 53.6, 69.6, 67.9, 60.7, 66.1, 82.1], 75.6],
["CB", "f1", "dev", 93.9, 32, [0.00, 29.3, 11.4, 17.4, 22.4, 25.1, 20.3, 42.8], [60.1, 39.8, 45.6, 37.5, 45.7, 28.5, 44.6, 52.5], [26.1, 40.4, 32.6, 48.3, 45.7, 44.6, 46.0, 57.2], 52.0],
["Copa", "acc", "dev", 94.8, 32, [66.0, 68.0, 73.0, 77.0, 76.0, 80.0, 84.0, 91.0], [62.0, 64.0, 66.0, 74.0, 76.0, 82.0, 86.0, 87.0], [67.0, 64.0, 72.0, 77.0, 83.0, 83.0, 86.0, 92.0], 92.0],
["RTE", "acc", "dev", 92.5, 32, [47.7, 49.8, 48.4, 56.0, 46.6, 55.2, 62.8, 63.5], [53.1, 47.3, 49.5, 49.5, 54.9, 54.9, 56.3, 70.4], [52.3, 48.4, 46.9, 50.9, 56.3, 49.5, 60.6, 72.9], 69.0],
["WiC", "acc", "dev", 76.1, 32, [0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00], [50.0, 50.3, 50.3, 49.2, 49.4, 50.3, 50.0, 48.6], [49.8, 55.0, 53.0, 53.0, 51.6, 53.1, 51.1, 55.3], 49.4],
["WSC", "acc", "dev", 93.8, 32, [59.6, 56.7, 65.4, 61.5, 66.3, 60.6, 64.4, 65.4], [58.7, 58.7, 60.6, 62.5, 66.3, 60.6, 66.3, 69.2], [58.7, 60.6, 54.8, 49.0, 62.5, 67.3, 75.0, 75.0], 80.1],
["MultiRC", "acc", "dev", 62.3, 32, [4.72, 9.65, 12.3, 13.6, 14.3, 18.4, 24.2, 27.6], [4.72, 9.65, 12.3, 13.6, 14.3, 18.4, 24.2, 27.6], [6.09, 11.8, 16.8, 20.8, 24.7, 23.8, 25.0, 32.5], 30.5],
["MultiRC", "f1a", "dev", 88.2, 32, [57.0, 59.7, 60.4, 59.9, 60.0, 64.5, 71.4, 72.9], [57.0, 59.7, 60.4, 59.9, 60.0, 64.5, 71.4, 72.9], [45.0, 55.9, 64.2, 65.4, 69.5, 66.4, 69.3, 74.8], 75.4],
["ReCoRD", "acc", "dev", 92.5, 32, [70.8, 78.5, 82.1, 84.1, 86.2, 88.6, 89.0, 90.2], [69.8, 77.0, 80.7, 83.0, 85.9, 88.0, 88.8, 90.2], [69.8, 77.2, 81.3, 83.1, 86.6, 87.9, 88.9, 89.0], 90.2],
["ReCoRD", "f1", "dev", 93.3, 32, [71.9, 79.2, 82.8, 85.2, 87.3, 89.5, 90.4, 91.0], [70.7, 77.8, 81.6, 83.9, 86.8, 88.8, 89.7, 91.2], [70.7, 77.9, 82.1, 84.0, 87.5, 88.8, 89.8, 90.1], 91.1],
["SuperGLUE", "average", "dev", 89.0, "None", [40.6, 47.4, 46.8, 49.6, 50.1, 52.3, 54.4, 58.2], [54.4, 55.1, 56.7, 57.8, 61.2, 59.7, 64.3, 68.9], [50.2, 56.2, 56.8, 60.0, 64.3, 63.6, 66.9, 73.2], 71.8],
["ANLI R1", "acc", "test", 73.8, 50, [33.4, 34.2, 33.4, 33.4, 34.2, 32.3, 33.2, 34.6], [32.1, 31.6, 31.9, 34.6, 30.6, 31.6, 32.7, 32.0], [32.1, 32.5, 30.9, 32.5, 33.5, 33.1, 33.3, 36.8], "None"],
["ANLI R2", "acc", "test", 50.7, 50, [33.2, 31.9, 33.3, 33.3, 33.8, 33.5, 33.5, 35.4], [35.7, 33.7, 33.2, 32.7, 32.7, 33.9, 33.9, 33.9], [35.7, 33.8, 32.1, 31.4, 32.6, 33.3, 32.6, 34.0], "None"],
["ANLI R3", "acc", "test", 48.3, 50, [33.6, 34.0, 33.8, 33.4, 35.3, 34.8, 34.4, 34.5], [35.0, 32.6, 33.0, 33.9, 34.1, 33.1, 32.5, 35.1], [35.0, 34.4, 35.1, 36.0, 32.7, 33.9, 34.5, 40.2], "None"],
["2D+", "acc", "n/a", "None", 50, [0.70, 0.65, 0.70, 0.85, 1.10, 2.54, 15.4, 76.9], [2.00, 0.55, 3.15, 4.00, 12.1, 19.6, 73.0, 99.6], [2.00, 4.10, 3.50, 4.50, 8.90, 11.9, 55.5, 100.0], "None"],
["2D-", "acc", "n/a", "None", 50, [1.25, 1.25, 1.25, 1.25, 1.60, 7.60, 12.6, 58.0], [1.15, 0.95, 1.45, 1.95, 3.85, 11.5, 44.6, 86.4], [1.15, 1.45, 2.25, 2.70, 7.35, 13.6, 52.4, 98.9], "None"],
["3D+", "acc", "n/a", "None", 50, [0.10, 0.10, 0.05, 0.10, 0.10, 0.25, 1.40, 34.2], [0.15, 0.00, 0.10, 0.30, 0.45, 0.95, 15.4, 65.5], [0.15, 0.45, 0.30, 0.55, 0.75, 0.90, 8.40, 80.4], "None"],
["3D-", "acc", "n/a", "None", 50, [0.05, 0.05, 0.05, 0.05, 0.05, 0.45, 1.35, 48.3], [0.05, 0.15, 0.25, 0.30, 0.55, 1.60, 6.15, 78.7], [0.05, 0.10, 0.15, 0.35, 0.65, 1.05, 9.20, 94.2], "None"],
["4D+", "acc", "n/a", "None", 50, [0.05, 0.05, 0.00, 0.00, 0.05, 0.05, 0.15, 4.00], [0.00, 0.00, 0.10, 0.00, 0.00, 0.10, 0.80, 14.0], [0.00, 0.05, 0.05, 0.00, 0.15, 0.15, 0.40, 25.5], "None"],
["4D-", "acc", "n/a", "None", 50, [0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.10, 7.50], [0.00, 0.00, 0.00, 0.00, 0.05, 0.00, 0.50, 14.0], [0.00, 0.05, 0.00, 0.00, 0.10, 0.05, 0.40, 26.8], "None"],
["5D+", "acc", "n/a", "None", 50, [0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.65], [0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.05, 3.45], [0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.05, 9.30], "None"],
["5D-", "acc", "n/a", "None", 50, [0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.80], [0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.05, 3.75], [0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 9.90], "None"],
["2Dx", "acc", "n/a", "None", 50, [2.20, 2.25, 2.65, 2.10, 2.55, 5.80, 6.15, 19.8], [1.35, 2.35, 3.35, 2.35, 4.75, 9.15, 11.0, 27.4], [1.35, 2.90, 2.70, 2.85, 4.25, 6.10, 7.05, 29.2], "None"],
["1DC", "acc", "n/a", "None", 50, [1.25, 2.95, 2.75, 0.05, 0.30, 2.35, 0.75, 9.75], [1.90, 2.80, 2.85, 3.65, 6.45, 9.15, 8.20, 14.3], [1.70, 2.15, 3.90, 5.75, 6.20, 7.60, 9.95, 21.3], "None"],
["Cycled Letters", "acc", "n/a", "None", 100, [0.62, 0.71, 2.85, 0.00, 0.63, 1.35, 2.58, 3.66], [1.67, 4.36, 5.68, 6.46, 6.25, 9.41, 15.1, 21.7], [4.63, 9.27, 10.7, 14.5, 16.7, 21.9, 27.7, 37.9], "None"],
["Anagrams 1", "acc", "n/a", "None", 100, [0.10, 0.14, 0.40, 0.00, 0.27, 0.69, 1.16, 2.28], [0.21, 0.61, 1.12, 1.27, 1.60, 2.72, 3.72, 8.62], [0.50, 1.27, 2.13, 3.05, 3.81, 5.49, 8.38, 15.1], "None"],
["Anagrams 2", "acc", "n/a", "None", 100, [0.81, 1.21, 2.69, 0.01, 1.71, 3.75, 4.53, 8.91], [1.19, 2.62, 4.70, 4.77, 6.97, 10.2, 14.6, 25.9], [1.94, 4.80, 7.59, 9.87, 12.6, 18.9, 25.6, 39.7], "None"],
["Symbol Insertion", "acc", "n/a", "None", 100, [0.00, 0.00, 0.10, 0.00, 0.05, 0.42, 0.89, 8.26], [0.03, 0.05, 0.57, 1.18, 1.67, 3.46, 6.62, 45.4], [0.11, 0.28, 2.19, 4.18, 6.61, 11.0, 27.3, 67.2], "None"],
["Reversed Words", "acc", "n/a", "None", 100, [0.00, 0.01, 0.01, 0.01, 0.02, 0.03, 0.03, 0.09], [0.02, 0.01, 0.01, 0.00, 0.05, 0.07, 0.11, 0.48], [0.00, 0.05, 0.00, 0.17, 0.24, 0.30, 0.42, 0.44], "None"],
["SAT Analogies", "acc", "n/a", "None", 20, [35.6, 39.0, 45.2, 44.1, 50.0, 49.2, 52.7, 53.7], [30.5, 41.2, 43.1, 46.5, 55.1, 54.3, 53.5, 59.1], [30.5, 40.4, 42.8, 40.6, 48.4, 51.9, 53.5, 65.2], "None"],
]

# NOTE(review): one value per model size, but not referenced by any visible
# cell; presumably another per-model score from the same paper — confirm or
# remove.
news = [76, 61, 68, 62, 62, 60, 55, 52]

# One row per (task, metric) pair. The "Zero-Shot", "One-Shot" and "Few-Shot"
# cells each hold a list of eight scores, one per entry of x__params.
df = pd.DataFrame(data, columns = columns)

# Show the distinct metric names; the fitting cell branches on these values.
print(df["Metric"].unique())

['acc' 'ppl' 'BLEU-mb' 'BLEU-sb' 'f1' 'em' 'f1a' 'average']


In [4]:
#@title The functional forms

def count_params(params):
  """Returns the total number of scalar entries in a parameter pytree.

  Args:
    params: any pytree whose leaves expose a `.size` attribute (JAX or NumPy
      arrays).

  Returns:
    The sum of the leaf sizes as a Python int; 0 for an empty pytree.
  """
  # The previous body called the non-existent `jnp.arratarget` (a mangled
  # `jnp.array`) and therefore always raised AttributeError.  Using `.size`
  # also keeps the count an exact Python int instead of a 32-bit device value.
  return sum(leaf.size for leaf in jax.tree_util.tree_leaves(params))

def bnsl_dim(x, n_breaks, name):
  """Builds one input dimension's contribution: a linear term plus a softplus term.

  Must be called from inside a Flax `@nn.compact` method, because it creates
  submodules on the fly.

  Args:
    x: input array fed to both the linear layer and the break layer.
    n_breaks: output width of the break layer (one softplus unit per break).
    name: class name given to the dynamically created `nn.Dense` subclass used
      as the break layer (presumably so its parameters are stored under that
      name — verify against Flax auto-naming).

  Returns:
    `Dense(1)(x) + Dense(1, use_bias=False)(softplus(break_layer(x)))`.
  """
  # Modules are created in the same order as in a single left-to-right
  # expression so that Flax's automatic submodule names stay the same.
  linear_term = nn.Dense(1)(x)
  combine_breaks = nn.Dense(1, use_bias=False)
  break_layer_cls = type(name, (nn.Dense,), {})
  break_activations = nn.softplus(break_layer_cls(n_breaks)(x))
  return linear_term + combine_breaks(break_activations)

class FunctionalForm(nn.Module):
  """BNSL operating in log-log space (i.e. inputs (and children of inputs) are never explicitly raised to a power).

  The module takes the log of its input, standardises it with the mean and
  standard deviation of the log training inputs, and returns
  `exp(logsumexp([bnsl_dim(x), a, log(1e-20)]))`, where `a` is a learned
  constant (the bias of a dense layer applied to zeros).

  Attributes:
    n_breaks: number of softplus units passed to `bnsl_dim`.
  """
  n_breaks: int

  @nn.compact
  def __call__(self, x, batch_train):
    """Evaluates the functional form.

    Args:
      x: positive inputs; the last axis is the feature axis that the
        logsumexp reduces over after concatenation.
      batch_train: `(x_train, y_train)` tuple. Only `x_train` is used, to
        compute the normalisation statistics. `y_train` is accepted for
        interface compatibility and does not influence the output.

    Returns:
      Predictions in the original (non-log) space, with the last axis of `x`
      reduced away.
    """
    # NOTE(review): earlier revisions also computed the mean of
    # log(y_train + 1e-16) and added it to the log-prediction, but that sum
    # was never returned.  The dead computation has been removed; the output
    # is unchanged.
    x_train, _ = batch_train

    x = jnp.log(x)
    x_train = jnp.log(x_train)

    # Standardise with training statistics; a zero std (e.g. a single
    # training point) is replaced by 1 to avoid dividing by zero.
    x_mean = jnp.mean(x_train, axis=0, keepdims=True)
    x_std = jnp.std(x_train, axis=0, keepdims=True)
    x_std = jnp.where(x_std == 0., 1., x_std)
    x = (x - x_mean)/x_std

    # Constant log-space floor, log(1e-20), included as its own logsumexp term.
    eps_2 = jnp.log(1e-20)
    eps_array = jnp.ones_like(x) * eps_2

    # A dense layer applied to zeros outputs only its bias, i.e. a learned
    # constant term (presumably the irreducible entropy, given the "irr_ent"
    # name — confirm).
    x0 = jnp.zeros_like(x)
    a = type("irr_ent", (nn.Dense,), {})(1, use_bias=True)(x0)

    _mbnsl = jnp.concatenate([
          bnsl_dim(x, self.n_breaks, 'x'),
          a,
          eps_array,
      ], axis=-1)

    # Log of the sum of the three terms, computed stably in log space.
    mbnsl = jax.scipy.special.logsumexp(_mbnsl, axis=-1, keepdims=False)

    return jnp.exp(mbnsl)

def functional_form__sle(params, _model, batch, batch_train):
    """Per-point squared log error of the model's predictions.

    Args:
      params: parameters handed to `_model.apply`.
      _model: object exposing `apply(params, x, batch_train)`.
      batch: `(x, y)` pair to evaluate on.
      batch_train: training batch forwarded unchanged to `_model.apply`.

    Returns:
      Array of `(log(prediction) - log(y)) ** 2`, one entry per point.
    """
    inputs, targets = batch
    log_prediction = jnp.log(_model.apply(params, inputs, batch_train))
    log_target = jnp.log(targets)
    return jnp.square(log_prediction - log_target)

def functional_form__standard_le(params, _model, batch, batch_train):
    """Standard-error style spread of the root mean squared log error.

    Computes `sqrt(mean + std / sqrt(n)) - sqrt(mean)`, where `mean` and `std`
    are taken over the per-point squared log errors and `n` is the number of
    points.

    Args:
      params: parameters handed to `_model.apply`.
      _model: object exposing `apply(params, x, batch_train)`.
      batch: `(x, y)` pair to evaluate on.
      batch_train: training batch forwarded unchanged to `_model.apply`.

    Returns:
      Scalar array with the value described above.
    """
    inputs, targets = batch
    log_prediction = jnp.log(_model.apply(params, inputs, batch_train))
    log_target = jnp.log(targets)

    squared_errors = (log_prediction - log_target) ** 2
    mean_squared_error = jnp.mean(squared_errors)
    # Standard error of the mean of the squared errors.
    standard_error = jnp.std(squared_errors) / (len(log_prediction)**0.5)

    upper = jnp.sqrt(mean_squared_error + standard_error)
    return upper - jnp.sqrt(mean_squared_error)


In [None]:
#@title Fits the functional form

# Start every run from an empty "plots" directory.  Only "directory does not
# exist yet" is expected and ignored; any other failure (e.g. permissions) is
# raised here instead of surfacing later as a confusing FileExistsError.
try:
  shutil.rmtree("plots")
except FileNotFoundError:
  pass
os.mkdir("plots")

# Accumulates one entry per fitted curve; each key becomes a column of the
# error table written at the end of the cell.
data_dict = {
    column: []
    for column in (
        "name",
        "rmsle_train",
        "rmsle_extrap",
        "rmsle_all",
        "standard_le_train",
        "standard_le_all",
    )
}

# Axis label for each value of the "Metric" column.  A metric missing from this
# table falls back to its raw name instead of silently reusing the label left
# over from the previous row.
METRIC_LABELS = {
    "acc": "Error Rate",
    "ppl": "Cross-Entropy",
    "BLEU-mb": "1 - BLEU-mb normalized",
    "BLEU-sb": "1 - BLEU-sb normalized",
    "f1": "1 - f1 normalized",
    "em": "1 - em normalized",
    "f1a": "1 - f1a normalized",
    "average": "average",
}

# The model sizes are the same for every task, so the train / held-out split of
# the x-axis is computed once.  Sizes below `x_split` are used for fitting; the
# remaining ones only measure extrapolation error.
x_split = 100e9
x = jnp.array(x__params)
is_train = x < x_split

x_train = jnp.expand_dims(x[is_train], -1)
x_test = jnp.expand_dims(x[~is_train], -1)
x_all = jnp.expand_dims(x, -1)

n_breaks = 1

for _, row in df.iterrows():
    n = str(row["Name"])
    m = str(row["Metric"])
    metric = METRIC_LABELS.get(m, m)

    for t in ["Zero-Shot", "One-Shot", "Few-Shot"]:
        y = jnp.array(row[t])

        # Perplexity becomes cross-entropy (its log); every other metric is a
        # percentage score that is turned into the fraction (100 - score) / 100.
        if m != "ppl":
            y = (100.0 - y)/100
        else:
            y = jnp.log(y)

        title = n + "  " + t
        plot_name = (n + "  " + metric + "  " + t).replace(" ", "_")

        y_train = y[is_train]
        y_test = y[~is_train]
        y_all = y

        model = FunctionalForm(n_breaks=n_breaks)

        def loss(params, batch, offset=1e-16):
          """Mean squared error between log-predictions and log-targets.

          Both logs are divided by the standard deviation of the log-targets
          (or by 1 when that is zero).  `offset` keeps the logs finite at zero.
          The batch itself is also passed to the model as its training batch.
          """
          inputs, targets = batch
          y_pred = model.apply(params, inputs, batch)

          pred, target = jnp.log(offset + y_pred), jnp.log(offset + targets)

          y_std = jnp.std(target)
          y_std = jnp.where(y_std == 0., 1., y_std)

          target = target / y_std
          pred = pred / y_std

          # Tells kfac_jax which loss this is so it can build its curvature estimate.
          kfac_jax.register_squared_error_loss(pred, target)

          return jnp.mean(jnp.square(pred - target))
        loss_and_grad = jax.value_and_grad(loss, argnums=0)

        # using a second order optimizer makes training way faster
        optimizer = kfac_jax.Optimizer(
          value_and_grad_func=loss_and_grad,
          l2_reg=0.0, # 1e-4,
          value_func_has_aux=False,
          value_func_has_state=False,
          value_func_has_rng=False,
          use_adaptive_learning_rate=True,
          use_adaptive_momentum=True,
          use_adaptive_damping=True,
          initial_damping=1.0,
          multi_device=False,
        )

        start_time = time.time()

        # Fits in <1min on 1 GPU
        # Fallback parameters (seed 21) in case no restart ever records a usable loss.
        best_params = model.init(random.PRNGKey(21), x_train, (x_train, y_train))
        best_loss = 1e6

        # Ten random restarts of 1000 optimizer steps each; the parameters with
        # the lowest recorded loss across all restarts are kept.
        for j in range(int(1e1)):
          params = model.init(random.PRNGKey(j), x_train, (x_train, y_train))
          rng = random.PRNGKey(j)
          rng, init_rng = random.split(rng)
          opt_state = optimizer.init(params, init_rng, (x_train, y_train))
          for i in range(int(1e3)):
            rng, step_rng = jax.random.split(rng)
            params, opt_state, stats = optimizer.step(
                params, opt_state, step_rng, batch=(x_train, y_train), global_step_int=i)
            # A NaN loss means this restart diverged; move on to the next seed.
            if jnp.isnan(stats['loss']):
              break
            # Losses at or below 1e-12 are not accepted as a new best.
            # NOTE(review): stats['loss'] may be the loss evaluated before this
            # step's update was applied, in which case the snapshot below is one
            # step ahead of the recorded loss — confirm against kfac_jax.
            if (stats['loss'] < best_loss) and (stats['loss'] > 1e-12):
              # Snapshot a copy of the parameters.  `jax.tree_util.tree_map` is
              # the supported spelling; the top-level `jax.tree_map` alias is
              # deprecated and absent from newer JAX releases.
              best_params = jax.tree_util.tree_map(jnp.array, params)
              best_loss = stats['loss']

        bc = best_params
        print()
        print(title)
        print(bc)
        print("time:", time.time() - start_time)

        sle_all = functional_form__sle(bc, model, (x_all, y_all), (x_train, y_train))
        sle_train = functional_form__sle(bc, model, (x_train, y_train), (x_train, y_train))
        sle_test = functional_form__sle(bc, model, (x_test, y_test), (x_train, y_train))

        standard_le_all = functional_form__standard_le(bc, model, (x_all, y_all), (x_train, y_train))
        standard_le_train = functional_form__standard_le(bc, model, (x_train, y_train), (x_train, y_train))

        rmsle_all = jnp.sqrt(jnp.mean(sle_all))
        rmsle_train = jnp.sqrt(jnp.mean(sle_train))
        rmsle_test = jnp.sqrt(jnp.mean(sle_test))

        print("rmsle all:   ", rmsle_all)
        print("rmsle train: ", rmsle_train)
        print("rmsle test:  ", rmsle_test)

        print("standard_le_all:   ", standard_le_all)
        print("standard_le_train: ", standard_le_train)

        data_dict["name"].append(plot_name)
        data_dict["rmsle_train"].append(rmsle_train)
        data_dict["rmsle_extrap"].append(rmsle_test)
        data_dict["rmsle_all"].append(rmsle_all)
        data_dict["standard_le_train"].append(standard_le_train)
        data_dict["standard_le_all"].append(standard_le_all)

        # Dense grid of model sizes used to draw the fitted curve.
        points = 1024
        x_tile = jnp.expand_dims(jnp.logspace(-1, 15, points), -1)

        pred = model.apply(bc, x_tile, (x_train, y_train))

        # Black: points used for fitting.  Green: held-out points.  Red: the fit.
        plt.plot(x_train, y_train, 'o', color='black', markersize=9.1)
        plt.plot(x_test, y_test, 'o', color=[0.0, 0.835, 0.0], markersize=9.37, markerfacecolor=[0.0, 1.0, 0.0])
        plt.plot(x_tile, pred, color=[1.0, 0.1275, 0.1275], linewidth=2.15, alpha=1.0)

        plt.xlim(x.min()*.865,x.max()*1.175)
        plt.ylim(y.min()*.9,y.max()*1.05)

        plt.title(title, fontsize=16)
        plt.xlabel("Number of Model Parameters", fontsize=12)
        plt.ylabel(metric, fontsize=12)
        plt.tick_params(axis='both', labelsize=12)

        plt.xscale('log')
        plt.yscale('log')

        plt.savefig('plots/'+plot_name+'.png', bbox_inches='tight')

        plt.show()
        # Release the figure so that many fits do not accumulate open figures.
        plt.close()

# Store the error table under its own name.  Reassigning `df` would replace the
# input table that the fitting loop iterates over, so re-running that loop
# without first re-running the data cell would fail.
errors_df = pd.DataFrame(data_dict)
errors_df.to_csv("error.csv")

# Write plots.zip (entries are stored under "plots/") from scratch.  Unlike
# `zip -r` on an existing archive, this never keeps entries left over from an
# earlier run, and it does not depend on a `zip` binary being installed.
shutil.make_archive("plots", "zip", root_dir=".", base_dir="plots")

In [None]:
#df = pd.DataFrame(data_dict)
#df.to_csv("error.csv")

In [None]:
#!zip -r plots.zip plots/