# Megascale Network Performance Analysis

Given one or more xprof profiles, this notebook generates graphs and metrics that help identify and understand potential network performance issues.

**Use:** Upload one or more xprof profiles to use as input to this Colab.

## 1) Install dependencies and generate required data structures
The cells below must be run before jumping to other sections.

In [None]:
# @title Install necessary packages

import shutil

if shutil.which("pip") is None:
  print("pip is not installed. Skipping package installation.")
else:
  # The nightly release of JAX is required until the official release contains the necessary ProfileData API.
  !pip install --pre -U libtpu-nightly -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
  !pip install --pre -U jax jaxlib -i https://us-python.pkg.dev/ml-oss-artifacts-published/jax/simple/
  !pip install git+https://github.com/jax-ml/jax
  print("Finished package installation.")

In [None]:
# @title Define helper functions

from IPython import display
import jax
import matplotlib.pyplot as plt
import pandas as pd


def bytes_to_human(n, precision=2):
  """Convert bytes to a human-readable string (e.g., B, KiB, MiB, GiB)."""
  if n < 0:
    return "-NaN"

  # Define the units and the base for the conversion
  units = ["B", "KiB", "MiB", "GiB", "TiB", "PiB", "EiB"]
  base = 1024

  # Special case for bytes
  if n < base:
    return f"{n} {units[0]}"

  # Find the correct unit and divide by the base
  for i, unit in enumerate(units):
    if n < base ** (i + 1):
      return f"{n / (base**i):.{precision}f} {unit}"

  # Handle extremely large numbers
  return f"{n / base**(len(units)-1):.{precision}f} {units[-1]}"

### Upload xprof profile(s)

Use the "Files" button in the vertical menu on the left-hand side to upload one or more xprof profiles.
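If you uploaded several files, a small helper can build the comma-separated `profile_paths` and `labels` strings that the next cell expects. This is a sketch that assumes all profiles sit in one directory (Colab's upload location is `/content`) and end in `.pb`, matching the default path below. `find_profiles` is a hypothetical helper, not part of any library.

```python
from pathlib import Path


def find_profiles(directory='/content', suffix='.pb'):
  """Returns (profile_paths, labels) as comma-separated strings.

  Hypothetical helper: assumes every file in `directory` whose name ends in
  `suffix` is an xprof profile, and uses the file name minus the suffix as
  its label.
  """
  paths = sorted(
      p for p in Path(directory).iterdir()
      if p.is_file() and p.name.endswith(suffix)
  )
  profile_paths = ','.join(str(p) for p in paths)
  labels = ','.join(p.name[: -len(suffix)] for p in paths)
  return profile_paths, labels


# Example: paste the returned strings into the form fields of the next cell.
# print(find_profiles())
```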

In [None]:
# @title Generate a DataFrame for each profile.

profile_paths = '/content/example_profile_1.pb'  # @param {'type': 'string', isTemplate: true}
labels = 'profile1'  # @param {'type':'string', isTemplate: true}

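# Split the comma-separated form inputs and drop stray whitespace around entries.
profile_paths = [path.strip() for path in profile_paths.split(',')]
labels = [label.strip() for label in labels.split(',')]
assert len(profile_paths) == len(labels), (
    'profile_paths and labels must have the same number of entries.'
)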

dfs = []
for i, xplane_path in enumerate(profile_paths):
  plane = jax.profiler.ProfileData.from_file(xplane_path).find_plane_with_name(
      '/host:CPU'
  )
  rows = []
  for line in plane.lines:
    if not line.name.startswith('MegascaleEM2_Worker'):
      continue
    for event in line.events:
      if event.name == 'MegaScale: Communication Transport Receive':
        stats = dict(event.stats)
        source_id = f'{stats["dcn_source_slice_id"]}-{stats["dcn_source_per_slice_device_id"]}'
        destination_id = f'{stats["dcn_destination_slice_id"]}-{stats["dcn_destination_per_slice_device_id"]}'
        latency_us = stats['duration_us']
        start_ns = event.start_ns
        end_ns = start_ns + latency_us * 1000
        timestamp = pd.to_datetime(end_ns, unit='ns')

        rows.append([
            latency_us,
            stats['payload_size_bytes'],
            source_id,
            destination_id,
            start_ns,
            end_ns,
            timestamp,
        ])
  df = pd.DataFrame(
      rows,
      columns=[
          'latency_us',
          'bytes',
          'src',
          'dst',
          'start_ns',
          'end_ns',
          'timestamp',
      ],
  )
  df.set_index('timestamp', inplace=True)
  df.attrs['label'] = labels[i]
  dfs.append(df)
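The loop above turns each `MegaScale: Communication Transport Receive` event into one row. Pulled out as a pure function, so it can be checked without a profile, the mapping looks roughly like this. `receive_event_to_row` is a hypothetical name; the stat keys are the ones the cell above reads, and the example values are made up.

```python
def receive_event_to_row(stats, start_ns):
  """Maps one receive event's stats to [latency_us, bytes, src, dst, start_ns, end_ns].

  Device IDs use the '{slice}-{per_slice_device_id}' format from the cell above.
  `duration_us` is converted to nanoseconds (rounded to an int) to get the end time.
  """
  src = f'{stats["dcn_source_slice_id"]}-{stats["dcn_source_per_slice_device_id"]}'
  dst = f'{stats["dcn_destination_slice_id"]}-{stats["dcn_destination_per_slice_device_id"]}'
  latency_us = stats['duration_us']
  end_ns = start_ns + int(latency_us * 1000)
  return [latency_us, stats['payload_size_bytes'], src, dst, start_ns, end_ns]


# Made-up stats for illustration only.
example_stats = {
    'dcn_source_slice_id': 0,
    'dcn_source_per_slice_device_id': 3,
    'dcn_destination_slice_id': 1,
    'dcn_destination_per_slice_device_id': 3,
    'duration_us': 250,
    'payload_size_bytes': 1048576,
}
print(receive_event_to_row(example_stats, start_ns=1_000_000))
# -> [250, 1048576, '0-3', '1-3', 1000000, 1250000]
```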

## 2) Examine transfer latencies

Check for outliers or persistently high latency. Small transfers should have lower latency than large ones.

High latency may be caused by a network slowdown or by a single slow host.

In [None]:
# @title Network transfer latency

for df in dfs:
  plt.figure(figsize=(10, 6))
  # Plot one series per transfer size, in ascending size order.
  for size in sorted(df['bytes'].unique()):
    series_data = df[df['bytes'] == size]
    plt.scatter(
        series_data.index,
        series_data['latency_us'] / 1000,  # Convert us to ms.
        label=bytes_to_human(size),
        s=20,
    )

  plt.title(f'Network transfer latency over time for {df.attrs.get("label")}')
  plt.xlabel('Time')
  plt.ylabel('Latency (ms)')
  plt.legend(title='Transfer Size')
  plt.grid(True)
  plt.tight_layout()
  plt.show()
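Scatter plots make outliers visible, but a per-size summary makes the tail easier to compare across profiles. The sketch below assumes a DataFrame with the `bytes` and `latency_us` columns built earlier; reporting p50, p99, and max is an assumption, not something the notebook prescribes.

```python
import pandas as pd


def latency_summary(df):
  """Per transfer size: count, p50, p99 and max latency in microseconds."""
  grouped = df.groupby('bytes')['latency_us']
  return pd.DataFrame({
      'count': grouped.size(),
      'p50_us': grouped.quantile(0.5),
      'p99_us': grouped.quantile(0.99),
      'max_us': grouped.max(),
  })


# Synthetic example: one slow 1 KiB transfer stands out in max_us.
example = pd.DataFrame({
    'bytes': [1024, 1024, 1024, 1048576, 1048576],
    'latency_us': [10, 12, 900, 300, 320],
})
print(latency_summary(example))
```

In the notebook, `for df in dfs: display.display(latency_summary(df))` would show one table per profile.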

## 3) Examine the distribution of transfer sizes

Sanity-check the transfer size distribution. Generally, fewer large transfers are preferable to many small ones, because each transfer carries some fixed overhead regardless of its size.

In [None]:
# @title Distribution of transfer sizes

for df in dfs:
  count_by_bytes = df.groupby('bytes').size()
  total = count_by_bytes.sum()

  # Build and display one table per profile (previously only the last
  # profile's table was shown because this code ran after the loop).
  data = {
      'Buffer size': [],
      'Count': [],
      'Percentage': [],
  }
  for size, count in count_by_bytes.items():
    percentage = (count / total) * 100
    data['Buffer size'].append(bytes_to_human(size))
    data['Count'].append(count)
    data['Percentage'].append(f'{percentage:.2f}')

  print(f'Transfer size distribution for {df.attrs.get("label")}')
  display.display(pd.DataFrame(data))
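Transfer count alone can hide where the bytes actually go. A sketch that puts each size's share of transfers next to its share of bytes moved, assuming the same `bytes` column; `size_share` is a hypothetical helper:

```python
import pandas as pd


def size_share(df):
  """Per transfer size: count, % of all transfers and % of all bytes moved."""
  counts = df['bytes'].value_counts().sort_index()
  total_bytes = counts.index.to_series() * counts  # size * number of transfers
  return pd.DataFrame({
      'count': counts,
      'pct_of_transfers': 100 * counts / counts.sum(),
      'pct_of_bytes': 100 * total_bytes / total_bytes.sum(),
  })


# Synthetic example: 90 small transfers move only a small share of the bytes.
example = pd.DataFrame({'bytes': [4096] * 90 + [1048576] * 10})
print(size_share(example).round(2))
```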

## 4) Examine inflight transfer count over time

This chart shows how many transfers are in flight (started but not yet finished) at each point in the profiling window, broken down by transfer size. If the chart is spiky or stays consistently high, the program may not overlap compute and communication well.

In [None]:
# @title Inflight transfers by size.

agg_window_ms = 100

for i, df in enumerate(dfs):
  grouped_by_size = df.groupby('bytes')

  legend_labels = []
  resampled_series_list = []

  for size, group_df in grouped_by_size:
    # Create time series for the start and end of each transfer for this size.
    # At the start of a transfer, the inflight count increases by 1.
    # At the end of a transfer, the inflight count decreases by 1.
    inflight_changes = pd.concat([
        pd.Series(1, index=pd.to_datetime(group_df['start_ns'], unit='ns')),
        pd.Series(-1, index=pd.to_datetime(group_df['end_ns'], unit='ns')),
    ]).sort_index()

    # Handle duplicate timestamps.
    inflight_changes = (
        inflight_changes.groupby(inflight_changes.index).sum().sort_index()
    )

    # Calculate the cumulative sum to get the number of inflight transfers over time for this size.
    cumulative_inflight = inflight_changes.cumsum()

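    # Resample to the desired window and take the max. A window with no start
    # or end events still holds the count carried over from the previous
    # window, so combine each window's max with the prior window's final value.
    window = f'{agg_window_ms}ms'
    window_max = cumulative_inflight.resample(window).max()
    carried_over = (
        cumulative_inflight.resample(window).last().ffill().shift(1).fillna(0)
    )
    inflight_count_per_window = pd.concat(
        [window_max, carried_over], axis=1
    ).max(axis=1)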

    # Add the series.
    resampled_series_list.append(inflight_count_per_window)
    legend_labels.append(f'{bytes_to_human(int(size))}')

  # Concatenate all resampled series into a single DataFrame and reindex to a common time index.
  combined_df = pd.concat(resampled_series_list, axis=1).fillna(0)
  time_labels_datetime = combined_df.index
  inflight_data_for_stacking = (
      combined_df.values.T
  )  # Transpose to get data in the correct shape for stackplot.

  # Plot the results as a stacked area chart.
  plt.figure(figsize=(12, 5))
  plt.stackplot(
      time_labels_datetime, inflight_data_for_stacking, labels=legend_labels
  )
  plt.title(f'Inflight Transfers By Size For {labels[i]}')
  plt.xlabel('Time')
  plt.ylabel('Inflight transfers')
  plt.legend(title='Transfer Size')
  plt.grid(True)
  plt.tight_layout()
  plt.show()
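The cell above uses a sweep: every transfer contributes +1 at its start and -1 at its end, and the running sum is the number of transfers in flight. A tiny worked example with three made-up transfers (times in nanoseconds) shows the idea without resampling:

```python
import pandas as pd

# Three hypothetical transfers: [start_ns, end_ns).
starts = [0, 10, 20]
ends = [30, 15, 40]

changes = pd.concat([
    pd.Series(1, index=starts),  # A transfer begins.
    pd.Series(-1, index=ends),   # A transfer completes.
])
# Sum changes that share a timestamp, then accumulate in time order.
inflight = changes.groupby(level=0).sum().sort_index().cumsum()
print([(int(t), int(n)) for t, n in inflight.items()])
# -> [(0, 1), (10, 2), (15, 1), (20, 2), (30, 1), (40, 0)]
print('peak inflight:', int(inflight.max()))
# -> peak inflight: 2
```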

## 5) Examine network throughput

Optionally, enter the platform's per-task maximum bandwidth to draw it as a reference line on the graph (set it to 0 to hide the line). For example, on TPU v6e with two tasks per machine, the per-task bandwidth is 4 NICs per machine * 200 Gbps per NIC / 2 tasks per machine = 400 Gbps.

If throughput stays consistently near the platform's theoretical maximum, the workload is likely network-bound.
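Both numbers used in this section are simple arithmetic: the per-task ceiling follows the v6e example above, and each transfer's average bandwidth is its payload in bits divided by its latency, which is what the next cell stores as `average_bandwidth_gbps`. The function names are hypothetical, and the NIC count and speed are the example's values, not measured ones.

```python
def per_task_bandwidth_gbps(nics_per_machine, gbps_per_nic, tasks_per_machine):
  """Theoretical per-task network ceiling in Gbps."""
  return nics_per_machine * gbps_per_nic / tasks_per_machine


def transfer_bandwidth_gbps(payload_bytes, latency_us):
  """Average bandwidth of one transfer: bits / seconds / 1e9."""
  return payload_bytes * 8 / (latency_us * 1e-6) / 1e9


print(per_task_bandwidth_gbps(4, 200, 2))  # -> 400.0
# 1,000,000 bytes received in 100 us.
print(round(transfer_bandwidth_gbps(1_000_000, 100), 3))  # -> 80.0
```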

In [None]:
# @title Network transfer throughput over time.

max_bandwidth_gbps = 100  # @param {'type':'number', isTemplate: true}

agg_window_ms = 1  # @param {'type':'number', isTemplate: true}

for df in dfs:
  df['average_bandwidth_gbps'] = (
      (df['bytes'] * 8) / (df['latency_us'] * 1e-6) / 1e9
  )

  # Create a time series for the start and end of each transfer.
  # At the start of a transfer, bandwidth increases by average_bandwidth_gbps.
  # At the end of a transfer, bandwidth decreases by average_bandwidth_gbps.
  bandwidth_changes = pd.concat([
      pd.Series(
          df['average_bandwidth_gbps'].values,
          index=pd.to_datetime(df['start_ns'], unit='ns'),
      ),
      pd.Series(
          -df['average_bandwidth_gbps'].values,
          index=pd.to_datetime(df['end_ns'], unit='ns'),
      ),
  ]).sort_index()

  # Handle duplicate timestamps.
  bandwidth_changes = (
      bandwidth_changes.groupby(bandwidth_changes.index).sum().sort_index()
  )

  # Calculate the cumulative bandwidth over time.
  cumulative_bandwidth = bandwidth_changes.cumsum()

  # To get a correct time-weighted average, we first create a dense time series
  # by upsampling and forward-filling, then we downsample and take the mean.
  # Note: The upsampling frequency should be high enough to capture the data's dynamics.
  dense_bandwidth = cumulative_bandwidth.resample(
      f'{agg_window_ms/100}ms'
  ).ffill()
  average_bandwidth_per_window = (
      dense_bandwidth.resample(f'{agg_window_ms}ms').mean().fillna(0)
  )

  # Plot the results
  plt.figure(figsize=(12, 5))
  plt.plot(
      average_bandwidth_per_window.index, average_bandwidth_per_window.values
  )

  if max_bandwidth_gbps > 0:
    plt.axhline(
        y=max_bandwidth_gbps,
        color='r',
        linestyle='--',
        label=f'Platform Max Bandwidth ({max_bandwidth_gbps} Gbps)',
    )
    plt.legend()  # Show the reference line's label.
  plt.title(f'Network Throughput (Gbps) for {df.attrs.get("label")}')
  plt.xlabel('Time')
  plt.ylabel('Throughput (Gbps)')
  plt.grid(True)
  plt.tight_layout()
  plt.show()
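To put a number on "consistently near the platform's theoretical maximum", one option is to count the share of aggregation windows whose throughput reaches some fraction of the ceiling. This is a sketch: the 90% default threshold is an arbitrary assumption, and the example uses synthetic values in place of the `average_bandwidth_per_window` series computed above.

```python
import pandas as pd


def fraction_near_max(throughput_gbps, max_bandwidth_gbps, threshold=0.9):
  """Share of windows with throughput >= threshold * max_bandwidth_gbps."""
  if max_bandwidth_gbps <= 0 or len(throughput_gbps) == 0:
    return float('nan')
  near = throughput_gbps >= threshold * max_bandwidth_gbps
  return float(near.mean())


# Synthetic per-window throughput (Gbps) against a 100 Gbps ceiling.
example = pd.Series([95.0, 20.0, 92.0, 99.0])
print(fraction_near_max(example, 100))  # -> 0.75
```

In the notebook, this could be called as `fraction_near_max(average_bandwidth_per_window, max_bandwidth_gbps)` inside the loop.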

## 6) Examine {source, destination} pairs for slow transfers

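Lists the source and destination global device ID pairs (formatted as `{slice}-{per_slice_device_id}`) for transfers that start after `min_start_time_ns` and take longer than `min_latency_us`, along with how many such transfers each pair had. `min_start_time_ns` is compared against each transfer's `start_ns`, so it uses the same clock as the profile's timestamps.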

This data can help identify bad network cards, overloaded network switches, sub-optimal sharding, etc.

In [None]:
# @title Long tail host pairs

min_start_time_ns = 0  # @param
min_latency_us = 1000  # @param

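for i, df in enumerate(dfs):
  # Combine both filters in one boolean mask. Chaining a second .loc with a
  # mask built from the unfiltered frame forces pandas to realign the mask,
  # which fails when the timestamp index contains duplicates.
  is_slow = (df['start_ns'] > min_start_time_ns) & (
      df['latency_us'] > min_latency_us
  )
  long_tail = (
      df.loc[is_slow]
      .groupby(['src', 'dst'])
      .size()
      .reset_index(name='counts')
      .sort_values('counts', ascending=False)
  )
  print(f'Long-tail src/dst pairs for {labels[i]}')
  display.display(long_tail)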
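When many slow pairs share one endpoint, that device (or its host's network card) is a likelier suspect than the network as a whole. A sketch that tallies slow transfers per source and per destination device, assuming the `src`, `dst`, and `latency_us` columns built earlier; `slow_transfers_by_device` is a hypothetical helper, and its cutoff plays the same role as `min_latency_us`:

```python
import pandas as pd


def slow_transfers_by_device(df, min_latency_us):
  """Counts slow transfers per device, split by whether it sent or received."""
  slow = df[df['latency_us'] > min_latency_us]
  as_src = slow['src'].value_counts().rename('slow_as_src')
  as_dst = slow['dst'].value_counts().rename('slow_as_dst')
  counts = pd.concat([as_src, as_dst], axis=1).fillna(0).astype(int)
  counts['total'] = counts['slow_as_src'] + counts['slow_as_dst']
  return counts.sort_values('total', ascending=False)


# Synthetic example: device '1-2' appears in every slow transfer.
example = pd.DataFrame({
    'src': ['0-1', '1-2', '0-3', '1-2'],
    'dst': ['1-2', '0-1', '1-2', '0-0'],
    'latency_us': [1500, 2000, 1200, 50],
})
print(slow_transfers_by_device(example, min_latency_us=1000))
```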