In [23]:
!pip install requests # for retrieiving data

Collecting requests
  Using cached requests-2.32.3-py3-none-any.whl.metadata (4.6 kB)
Collecting charset-normalizer<4,>=2 (from requests)
  Using cached charset_normalizer-3.3.2-cp312-cp312-macosx_10_9_x86_64.whl.metadata (33 kB)
Collecting idna<4,>=2.5 (from requests)
  Downloading idna-3.8-py3-none-any.whl.metadata (9.9 kB)
Collecting urllib3<3,>=1.21.1 (from requests)
  Using cached urllib3-2.2.2-py3-none-any.whl.metadata (6.4 kB)
Collecting certifi>=2017.4.17 (from requests)
  Downloading certifi-2024.8.30-py3-none-any.whl.metadata (2.2 kB)
Using cached requests-2.32.3-py3-none-any.whl (64 kB)
Downloading certifi-2024.8.30-py3-none-any.whl (167 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m167.3/167.3 kB[0m [31m4.5 MB/s[0m eta [36m0:00:00[0m
[?25hUsing cached charset_normalizer-3.3.2-cp312-cp312-macosx_10_9_x86_64.whl (122 kB)
Downloading idna-3.8-py3-none-any.whl (66 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m66.9/66.9 kB[0m [31m2

In [1]:
from tandv.track.common import read_pickle, TrainingStats
from tandv.viz import exp_hist,scalar_global_heatmap,scalar_line,interactive
import pickle,requests

## Loading Example Data

In [2]:
# Download sample data from s3
lf_url = 'https://graphcore-research-public.s3.eu-west-1.amazonaws.com/tandv/sampleLogFrame.pkl'
ts_url = 'https://graphcore-research-public.s3.eu-west-1.amazonaws.com/tandv/sampleTrainingStats.pkl'

# The log frame is loaded by handing its URL straight to `read_pickle`.
lf = read_pickle(lf_url)

# The training stats are fetched over HTTP and unpickled by hand.
# A timeout is set so that an unreachable or stalled host raises an exception
# instead of blocking the kernel indefinitely (requests has no default timeout).
res = requests.get(ts_url, timeout=60)
res.raise_for_status()  # stop on an HTTP error status rather than unpickling an error body

# NOTE(security): unpickling can execute arbitrary code. Only load pickles
# from a source you trust, as with the URLs above.
tstats = pickle.loads(res.content)


## Interactive Exponent Histogram
Pass the `exp_hist` function into `interactive` along with valid arguments to generate an initial plot.

You can then use the various widgets to query along different dimensions (step, layer, tensor type, etc.) and subsequently generate visualizations.

In [3]:
# Initial plot: exponent histogram for layer 'output', `tt` (presumably the
# tensor type) 'Activation', at step 0, with a 'float8_e4m3fn' dtype annotation.
interactive(
    exp_hist,
    df=lf,
    layer='output',
    tt='Activation',
    step=0,
    dtype_annotation='float8_e4m3fn',
)

Output(layout=Layout(width='100%'), outputs=({'output_type': 'display_data', 'data': {'text/plain': "HBox(chil…

Output(layout=Layout(overflow='scroll hidden', width='1500px'), outputs=({'output_type': 'display_data', 'data…

## Interactive Scalar Line 
Pass the `scalar_line` function into `interactive` along with valid arguments to generate an initial plot.

You can then use the various widgets to query along different dimensions (scalar_metric, layer, tensor type, etc.) and subsequently generate visualizations.

In [30]:
# Initial plot: the 'std' scalar metric for layer 'output' with `tt='Gradient'`.
interactive(scalar_line, df=lf, layer='output', tt='Gradient', scalar_metric='std')

Output(layout=Layout(width='100%'), outputs=({'output_type': 'display_data', 'data': {'text/plain': "HBox(chil…

Output(layout=Layout(overflow='scroll hidden', width='1500px'), outputs=({'output_type': 'display_data', 'data…

## Interactive Global Heatmap 
Pass the `scalar_global_heatmap` function into `interactive` along with valid arguments to generate an initial plot.

You can then use the various widgets to query along different dimensions (scalar_metric, layer, tensor type, etc.) and subsequently generate visualizations.

You can also click on the heatmap patches to generate an `exp_hist` visualization along those query dimensions (step, tensor_type and layer).

In [31]:
# Initial plot: global heatmap of the 'rm2' scalar metric for
# `tt='Optimiser_State.exp_avg'`.
# NOTE(review): `inc=50` is presumably a step increment - confirm against the
# `scalar_global_heatmap` documentation.
interactive(
    scalar_global_heatmap, df=lf,
    tt='Optimiser_State.exp_avg', scalar_metric='rm2', inc=50,
)

Output(layout=Layout(width='100%'), outputs=({'output_type': 'display_data', 'data': {'text/plain': "HBox(chil…

Output(layout=Layout(overflow='scroll hidden', width='1500px'), outputs=({'output_type': 'display_data', 'data…

# Cross Referencing with Training Stats

Query exponent histograms from the loss curve

In [32]:
# Names in `lf.metadata.name` that contain 'layers.5.feed_forward.w'.
ffn_layer_names = [
    name
    for name in lf.metadata.name.unique().tolist()
    if 'layers.5.feed_forward.w' in name
]

# Exponent histograms for the selected layers, with the training stats passed
# in alongside the log frame so the histograms can be queried from the loss curve.
interactive(
    exp_hist,
    train_stats=tstats,
    df=lf,
    layer=ffn_layer_names,
    tt='Activation',
    step=0,
    dtype_annotation='float8_e4m3fn',
    col_wrap=3,
)

Output(layout=Layout(width='100%'), outputs=({'output_type': 'display_data', 'data': {'text/plain': "HBox(chil…

Output(layout=Layout(overflow='scroll hidden', width='1500px'), outputs=({'output_type': 'display_data', 'data…

  return f"$2^{{{int(np.log2(value))}}}$"


Cross referencing loss curve(s) with scalar statistics of various tensors in the network

In [9]:
# Names in `lf.metadata.name` that contain 'layers.5.feed_forward.w'.
ffn_layer_names = [
    name
    for name in lf.metadata.name.unique().tolist()
    if 'layers.5.feed_forward.w' in name
]

# Line plots of two scalar metrics ('std' and 'rm2') with `tt='Gradient'` for
# the selected layers, with the training stats passed in so they can be
# cross-referenced against the loss curve(s).
# NOTE(review): the meaning/units of `mouse_sensitivity=20000` are not visible
# here - confirm against the `interactive` documentation.
interactive(
    scalar_line,
    train_stats=tstats,
    mouse_sensitivity=20000,
    df=lf,
    layer=ffn_layer_names,
    tt='Gradient',
    scalar_metric=['std', 'rm2'],
    col_wrap=3,
)

Output(layout=Layout(width='100%'), outputs=({'output_type': 'display_data', 'data': {'text/plain': "HBox(chil…

Output(layout=Layout(overflow='scroll hidden', width='1500px'), outputs=({'output_type': 'display_data', 'data…