## LSM Fine-Tuned Activity Recognition Confusion Matrix
##### Colab Kernel (Brainframe CPU)
##### Dataset (Electrodes)

Grants command for Access on Demand (AoD):

https://grants.corp.google.com/#/grants?request=20h%2Fchr-ards-electrodes-deid-colab-jobs&reason=b%2F314799341

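### About This Notebook:

This notebook loads validation confusion matrices (and their class labels) saved as `.npy` files for fine-tuned LSM activity-recognition runs, and renders each one as an annotated, row-normalized heatmap that is also exported to PDF.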


In [None]:
# @title Imports

from google3.learning.deepmind.xmanager2.client import xmanager_api
import matplotlib as mpl
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import collections
import numpy as np

from google3.pyglib import gfile
import os
import tensorflow as tf
import itertools
import tensorflow as tf

from typing import Sequence

In [None]:
import os
import re
import tempfile
import warnings
import collections
import numpy as np
import pandas as pd
import seaborn as sns
import matplotlib as mpl
import matplotlib.pyplot as plt

from google3.learning.deepmind.xmanager2.client import xmanager_api
from google3.pyglib import gfile
from google3.pyglib.function_utils import memoize
from matplotlib import font_manager

import matplotlib.patches as mpatches
import matplotlib.lines as mlines
from matplotlib.lines import Line2D
from matplotlib.ticker import FixedLocator  # Import for the fix
from matplotlib.ticker import MaxNLocator
from matplotlib.ticker import LogLocator
from matplotlib.ticker import FuncFormatter
#@title Google Sans Import

# google3-relative path to the Google Sans font family. It is joined onto a
# depot root before being handed to the font importer, so it must stay relative
# (no leading slash).
_GOOGLE_SANS_PATH = (
    'google3/third_party/googlefonts/api/googlerestricted/googlesans/'
)

@memoize.Memoize()
def import_google3_fonts(font_path: str) -> None:
  """Copy fonts out of google3 and register them with Matplotlib.

  The fonts are copied to a local temporary location first and then added to
  Matplotlib's default font manager. The temporary copies are not removed by
  this function.

  Args:
    font_path: Either a directory containing .ttf fonts or the path of a single
      .ttf font file. Anything that is not a directory is treated as a file.
  """
  if not gfile.IsDirectory(font_path):
    # Single-file case: reserve a unique local '.ttf' name. Leaving the `with`
    # block closes (and thereby removes) the placeholder, so the copy below
    # writes to a fresh path.
    with tempfile.NamedTemporaryFile(suffix='.ttf') as placeholder:
      local_font = placeholder.name
    gfile.Copy(font_path, local_font)
    discovered_fonts = [local_font]
  else:
    # Directory case: mirror the whole tree locally, then let Matplotlib find
    # the font files inside it.
    staging_dir = tempfile.mkdtemp()
    gfile.RecursivelyCopyDir(font_path, staging_dir, overwrite=True)
    discovered_fonts = font_manager.findSystemFonts(fontpaths=staging_dir)

  # Register every discovered font with the default font manager.
  for local_path in discovered_fonts:
    font_manager.fontManager.addfont(local_path)


def import_default_google_fonts() -> None:
  """Register the Google Sans font family with Matplotlib."""
  # google3 paths are relative; resolve them against the head depot root.
  depot_root = '/google_src/head/depot'
  google_sans_dir = os.path.join(depot_root, _GOOGLE_SANS_PATH)
  import_google3_fonts(google_sans_dir)


# Import and register Google fonts with Matplotlib so we can use them.
import_default_google_fonts()
#@title Set up Plot Settings

# Display DataFrames in full: no row or column truncation.
pd.set_option('display.max_rows', None)
pd.set_option('display.max_columns', None)
# Silence UserWarnings raised from matplotlib modules.
warnings.filterwarnings("ignore", category=UserWarning, module="matplotlib")

xm_client = xmanager_api.XManagerApi(xm_deployment_env='alphabet')

# Base font size shared by text, axis labels, titles and tick labels.
MEDIUM_SIZE = 18

# `plt.rcParams` and `mpl.rcParams` are the same object and `plt.rc(...)` writes
# into it as well, so every setting is applied exactly once here instead of
# being repeated through each of the three APIs.
mpl.rcParams.update({
    # Fonts.
    'font.family': 'Google Sans',
    'font.size': MEDIUM_SIZE,
    # Axes text and frame.
    'axes.labelsize': MEDIUM_SIZE,
    'axes.titlesize': MEDIUM_SIZE,
    'axes.linewidth': 1,
    'axes.edgecolor': '#777777',
    'axes.facecolor': '#FFFFFF',
    # Tick labels.
    'xtick.labelsize': MEDIUM_SIZE,
    'ytick.labelsize': MEDIUM_SIZE,
    # Legend text is slightly smaller than the rest.
    'legend.fontsize': MEDIUM_SIZE - 5,
    # Figure-level title.
    'figure.titlesize': MEDIUM_SIZE,
})

elegant_palette = sns.color_palette('muted')



In [None]:
# @title CM Plotting Fn

def confusion_matrix_fig(
    confusion_matrix: tf.Tensor, labels: Sequence[str], scale: float = 0.8
) -> plt.Figure:
  """Returns a matplotlib plot of the given confusion matrix.

  Forked from:
  google3/fitbit/research/sensor_algorithms/training/logging/
  confusion_matrix_logging.py

  Args:
      confusion_matrix: Confusion matrix as 2D numpy array.
      labels: List of class names, will be used as axis labels.
      scale: Scale for the image size.
  """
  label_totals = np.sum(confusion_matrix, axis=1, keepdims=True)
  prediction_totals = np.sum(confusion_matrix, axis=0, keepdims=True)

  cm_normalized = np.nan_to_num(confusion_matrix / label_totals)

  num_labels = len(labels)
  longest_label = max([len(label) for label in labels])

  # Guesstimating an appropriate size.
  image_size = scale * (num_labels + (longest_label / 8.0))

  fig = plt.figure(
      figsize=(image_size, image_size), facecolor='w', edgecolor='k', dpi=100)
  ax = fig.add_subplot(1, 1, 1)
  ax.imshow(cm_normalized, cmap='Blues')

  tick_marks = np.arange(num_labels)

  # ax.set_xlabel('Predicted')
  ax.set_xticks(tick_marks)
  x_labels = (
      f'{label} ({int(count):,})'
      for label, count in zip(labels, prediction_totals[0, :])
  )
  ax.set_xticklabels(x_labels, rotation=-90, ha='center')
  ax.xaxis.set_label_position('bottom')
  ax.xaxis.tick_bottom()

  # ax.set_ylabel('True Label')
  ax.set_yticks(tick_marks)
  y_labels = (
      f'{label} ({int(count):,})'
      for label, count in zip(labels, label_totals[:, 0])
  )
  ax.set_yticklabels(y_labels, va='center')
  ax.yaxis.set_label_position('left')
  ax.yaxis.tick_left()

  for row_idx, col_idx in itertools.product(
      range(confusion_matrix.shape[0]), range(confusion_matrix.shape[1])
  ):
    text_color = 'white' if cm_normalized[row_idx, col_idx] >= 0.5 else 'black'
    if confusion_matrix[row_idx, col_idx] == 0:
      text_str = '.'
    else:
      text_str = (
          f'{cm_normalized[row_idx,col_idx]:2.0%}\n'
          f'({int(confusion_matrix[row_idx, col_idx]):,})'
      )
    ax.text(
        col_idx,
        row_idx,
        text_str,
        horizontalalignment='center',
        verticalalignment='center',
        color=text_color,
    )

  fig.set_tight_layout(True)
  plt.tight_layout()
  plt.subplots_adjust(bottom=0.1)  # Make space for the legend at the bottom
  plt.savefig("/tmp/confusion_matrix.pdf", bbox_inches='tight', format="pdf")
  plt.show()
  %download_file /tmp/confusion_matrix.pdf
  return fig

In [None]:
# @title LSMv1 Confusion Matrix

# Loads the saved validation confusion matrix and its labels for one run and
# plots them with confusion_matrix_fig.

# Run identifiers, used below only as path components (presumably the XManager
# experiment ID and work-unit ID - TODO confirm).
XID = 126268296
WID = 1

# Suffix selecting which saved matrix to load (presumably the training step at
# which it was written - TODO confirm).
step = 300
# Despite its name, `file_name` is the run's directory: <base>/<XID>/<WID>.
file_name = os.path.join('/cns/dz-d/home/xliucs/lsm/xm/', str(XID), str(WID))
cm_file_name = os.path.join(file_name, f'valid_confusion_matrix_{step}.npy')
cm_labels_file_name = os.path.join(file_name, f'valid_confusion_matrix_labels_{step}.npy')

print('Reading CM File:', cm_file_name)
with gfile.Open(cm_file_name, 'rb') as f:
  cm = np.load(f)

print('Reading CM Labels File:', cm_labels_file_name)
with gfile.Open(cm_labels_file_name, 'rb') as f:
  cm_labels = np.load(f)
# Swap the 'Strength training' label for a two-line 'Strength' / 'Training'
# version; every other label is left unchanged.
cm_labels = np.where(cm_labels == 'Strength training', 'Strength\nTraining', cm_labels)
# The trailing semicolon keeps the notebook from echoing the returned figure.
confusion_matrix_fig(cm, cm_labels, scale=1.2);

In [None]:
# @title LSMv2 Confusion Matrix

# Loads the saved validation confusion matrix and its labels for one run and
# plots them with confusion_matrix_fig, using larger fonts than the defaults.

# Run identifiers, used below only as path components (presumably the XManager
# experiment ID and work-unit ID - TODO confirm).
XID = 163102248
WID = 1
# Base directory under which each run is stored as <XID>/<WID>.
WORKING_DIR = '/namespace/fitbit-medical-sandboxes/jg/partner/encrypted/chr-ards-fitbit-prod-research/deid/exp/dmcduff/ttl=52w/xm/'
# Suffix selecting which saved matrix to load (presumably the training step at
# which it was written - TODO confirm).
step = 1000

# Enlarge the base font and tick-label sizes for this figure.
# NOTE(review): these write to the global rcParams, so they stay in effect for
# every cell run afterwards (including a re-run of the LSMv1 cell) until they
# are reset.
plt.rc('font', size=25)  # Base text size
plt.rcParams['xtick.labelsize'] = 25 # x tick label size
plt.rcParams['ytick.labelsize'] = 25 # y tick label size

# Despite its name, `file_name` is the run's directory: <WORKING_DIR>/<XID>/<WID>.
file_name = os.path.join(WORKING_DIR, str(XID), str(WID))
cm_file_name = os.path.join(file_name, f'valid_confusion_matrix_{step}.npy')
cm_labels_file_name = os.path.join(file_name, f'valid_confusion_matrix_labels_{step}.npy')

print('Reading CM File:', cm_file_name)
with gfile.Open(cm_file_name, 'rb') as f:
  cm = np.load(f)

print('Reading CM Labels File:', cm_labels_file_name)
with gfile.Open(cm_labels_file_name, 'rb') as f:
  cm_labels = np.load(f)
# Swap the 'Strength training' label for a two-line 'Strength' / 'Training'
# version; every other label is left unchanged.
cm_labels = np.where(cm_labels == 'Strength training', 'Strength\nTraining', cm_labels)

# Disabled: the same two-line treatment for the 'Indoor climbing' label.
# cm_labels = np.where(cm_labels == 'Indoor climbing', 'Indoor\nclimbing', cm_labels)


# The trailing semicolon keeps the notebook from echoing the returned figure.
confusion_matrix_fig(cm, cm_labels, scale=1.2);