In [None]:
# Copyright 2019, The TensorFlow Federated Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""
Note: This file may have been modified by ByteDance Inc.
"""

import matplotlib
import matplotlib.pyplot as plt
import numpy as np
import collections
import os
import re
import time
import sys
from IPython import display

# This notebook only runs on Python 3; stop right away on older interpreters.
if sys.version_info < (3, 0):
  sys.stdout.write("This notebook requires Python 3.x\n")
  sys.exit(1)

# Directory holding the experiment logs; point this at your own log folder.
LOG_DIR = "/tmp/sc_paper/logs"

def set_matplotlib_params(font_size=8, params=None):
  """Restores matplotlib's defaults, then applies publication-friendly ones.

  Args:
    font_size: Font size applied to axis labels, titles, legend, tick labels
      and the general font.
    params: Optional dict of further rcParams that are applied on top of (and
      may override) the defaults set here, e.g. {'figure.figsize': [7, 5]}.
  """
  matplotlib.rcdefaults()

  # Every text element shares the same font size.
  size_keys = (
      'axes.labelsize',
      'axes.titlesize',
      'font.size',
      'legend.fontsize',
      'xtick.labelsize',
      'ytick.labelsize',
  )
  settings = dict.fromkeys(size_keys, font_size)
  settings.update({
      'font.family': 'sans-serif',
      # Font type 42 avoids Type 3 fonts in publication plots.
      'pdf.fonttype': 42,
      'ps.fonttype': 42,
      'legend.frameon': False,
  })
  # Caller-supplied values win over the defaults above.
  settings.update(params or {})
  matplotlib.rcParams.update(settings)

# Notebook-wide plot style: 10pt fonts, 6x4 figures, slightly thicker lines
# and longer legend handles.
set_matplotlib_params(
    font_size=10,
    params={
        'legend.handlelength': 3.0,
        'figure.figsize': [6, 4],
        'lines.linewidth': 1.5,
    })

def get_mean_and_std(x):
  """Computes column-wise mean and standard deviation of a list of lists.

  Args:
    x: a list of lists of numbers. Inner lists must have same lengths.

  Returns:
    A dictionary with entries 'mean' and 'std', each a list holding the mean
    values and standard deviations computed along axis 0 (i.e. across the
    inner lists).
  """
  matrix = np.asarray(x)
  stats = {}
  stats['mean'] = matrix.mean(axis=0).tolist()
  stats['std'] = matrix.std(axis=0).tolist()
  return stats


In [13]:
def find_key_val(lines, key, num_results=1):
  """Gets the values of "key=value" entries from a list of lines.

  Args:
    lines: a list of strings containing key=value pairs.
    key: the prefix to search for, including the trailing "=" (e.g. "lr=").
    num_results: expected number of lines that start with `key`.

  Returns:
    If num_results == 1, the text following `key` on the single matching line.
    Otherwise a list with the text following `key` on every matching line, in
    input order.

  Raises:
    AssertionError: if the number of lines starting with `key` differs from
      num_results.
  """
  matches = [l for l in lines if l.startswith(key)]
  if len(matches) != num_results:
    # Raised explicitly (instead of via `assert`) so the check survives -O.
    raise AssertionError(
        'expected %d line(s) starting with %r, found %d'
        % (num_results, key, len(matches)))
  # Slice the prefix off. str.lstrip(key) would treat `key` as a *set* of
  # characters and also eat leading characters of the value, e.g. turning
  # "mode=online-learning" into "nline-learning".
  values = [l[len(key):] for l in matches]
  return values[0] if num_results == 1 else values


# Part of an accuracy log line that is common to all modes; the text in front
# of it depends on the mode (see _MODE_PARSERS). Capture groups: train group,
# test group, day, test group (repeated), num_train_examples, accuracy.
_ACCURACY_LINE_TAIL = (
    r' (\d+) on (\d+): day (\d+), group (\d+): num_train_examples (\d+) '
    r'\(dt=\d+s\): num correct: \d+/\d+ \(([\d.]+)\)$')

# mode value found in the log -> (text preceding _ACCURACY_LINE_TAIL on that
# mode's accuracy lines, key under which the per-day summary is stored).
_MODE_PARSERS = {
    'pluralistic': ('', 'pl'),
    'sep': ('sep', 'avg'),
    'fourier': ('fourier', 'pl'),
    'online-learning': ('oco', 'pl'),
    'time-feature': ('time', 'pl'),
}


def _parse_losses(lines, num_groups, num_examples_per_day_per_group):
  """Extracts training losses and the matching example counts per group.

  Not called by parse_logfile; kept as a utility for training-loss lines of
  the form "day D, group G: trained on N examples, loss=L".

  Args:
    lines: log file contents as a list of lines.
    num_groups: number of groups; group ids in the log must be below this.
    num_examples_per_day_per_group: multiplied by the day index and added to
      the per-line example count.

  Returns:
    A tuple (num_examples, losses) of lists indexed by group. For every
    matching line, num_examples[group] gets
    day * num_examples_per_day_per_group + N and losses[group] gets L.
  """
  num_examples = [[] for _ in range(num_groups)]
  losses = [[] for _ in range(num_groups)]
  pattern = re.compile(
      r'^day (\d+), group (\d+): trained on (\d+) examples, loss=([\d.]+)$')
  for l in lines:
    m = pattern.search(l)
    if not m:
      continue
    day = int(m.group(1))
    group = int(m.group(2))
    num_examples[group].append(
        day * num_examples_per_day_per_group + int(m.group(3)))
    # Group 4 is the loss; group 3 is the example count.
    losses[group].append(float(m.group(4)))
  return num_examples, losses


def _parse_accuracies(lines, line_prefix, num_groups, num_days, summary_key):
  """Collects test accuracies from the accuracy lines of one mode.

  Args:
    lines: log file contents as a list of lines.
    line_prefix: text that precedes _ACCURACY_LINE_TAIL on the lines to parse.
    num_groups: number of groups averaged over in the per-day summary.
    num_days: number of days (0 .. num_days - 1) in the per-day summary.
    summary_key: key under which the per-day summary is stored.

  Returns:
    A dict with
      'raw': {train_group: {test_group: [accuracy, ...]}}, accuracies in the
        order in which their days first appear in the log, and
      summary_key: for each day, the mean over groups g of the accuracy
        logged for train group g tested on group g.

  Raises:
    AssertionError: if the two test-group fields of a line disagree.
    KeyError: if an accuracy needed for the per-day summary was not logged.
  """
  pattern = re.compile(line_prefix + _ACCURACY_LINE_TAIL)
  accuracies = {}
  for l in lines:
    m = pattern.search(l)
    if not m:
      continue
    train_group = int(m.group(1))
    test_group = int(m.group(2))
    day = int(m.group(3))
    # The test group is logged twice on each line; both fields must agree.
    if test_group != int(m.group(4)):
      raise AssertionError('inconsistent test group in line: %r' % l)
    per_day = accuracies.setdefault(train_group, {}).setdefault(test_group, {})
    per_day[day] = float(m.group(6))

  res = {'raw': {}}
  for trg, by_test_group in accuracies.items():
    res['raw'][trg] = {
        tsg: list(by_day.values()) for tsg, by_day in by_test_group.items()}
  # Only the diagonal (train group == test group) enters the summary.
  res[summary_key] = [
      np.mean([accuracies[g][g][d] for g in range(num_groups)])
      for d in range(num_days)]
  return res


def parse_logfile(f, results, abort_function):
  """Parses one experiment log file and stores its accuracies in `results`.

  The file must end with the line "END_MARKER" and contain exactly one line
  starting with each of the configuration keys lr=, vocab_size=, mode=,
  num_train_examples_per_day=, num_days=, num_groups=, replica= and
  batch_size=.

  Args:
    f: log file path.
    results: dictionary that is updated in place. After a successful parse,
      results[replica][learning_rate][mode] holds a dict with
        'raw': {train_group: {test_group: [accuracy, ...]}} and
        'pl' ('avg' when mode is "sep"): a list with, for each day in
          range(num_days), the mean over g in range(num_groups) of the
          accuracy of train group g tested on group g.
      Supported modes are "pluralistic", "sep", "fourier", "online-learning"
      and "time-feature".
    abort_function: function that takes the log file contents as list of lines
      and returns true if this file - based on the contents - should be
      ignored. This can be used to filter out files that do not match certain
      criteria, e.g. log files that do not contain a certain configuration
      such as "bias=0.5".

  Returns:
    None. The outcome is written to `results`; nothing is written when
    abort_function returns true.

  Raises:
    AssertionError: if the file does not end with "END_MARKER", or a
      configuration key does not occur exactly once.
    ValueError: if the mode is not supported (or a numeric configuration value
      cannot be converted).
    KeyError: if an accuracy needed for the per-day summary is missing.
  """
  # Open the log file, check for integrity, and extract the configuration from
  # this run (learning rate, replica number, etc.).
  with open(f) as log:
    lines = [l.rstrip('\n') for l in log]
  if not lines or lines[-1] != 'END_MARKER':
    raise AssertionError('log file %s does not end with END_MARKER' % f)
  if abort_function(lines):
    return

  learning_rate = float(find_key_val(lines, 'lr='))
  # vocab_size, num_train_examples_per_day and batch_size are not needed
  # below; they are still parsed so that a log lacking them is rejected.
  int(find_key_val(lines, 'vocab_size='))
  mode = find_key_val(lines, 'mode=')
  int(find_key_val(lines, 'num_train_examples_per_day='))
  num_days = int(find_key_val(lines, 'num_days='))
  num_groups = int(find_key_val(lines, 'num_groups='))
  replica = int(find_key_val(lines, 'replica='))
  int(find_key_val(lines, 'batch_size='))

  # Earlier versions of find_key_val mangled "online-learning" into
  # "nline-learning"; keep accepting that spelling.
  if mode == 'nline-learning':
    mode = 'online-learning'
  if mode not in _MODE_PARSERS:
    raise ValueError('unknown mode %s' % mode)

  # Extract the results from this run, depending on what mode was used.
  line_prefix, summary_key = _MODE_PARSERS[mode]
  per_lr = results.setdefault(replica, {}).setdefault(learning_rate, {})
  per_lr[mode] = _parse_accuracies(
      lines, line_prefix, num_groups, num_days, summary_key)


In [None]:
# Parse log files, plot results.
files = [f for f in os.listdir(LOG_DIR) if f.endswith('.log')]

# Figures are saved below ./fig; create the directory up front so that
# savefig does not fail when it is missing.
os.makedirs('./fig', exist_ok=True)

# One entry per plotted method:
# (mode key in the parsed results, key used in ravg, legend label, color).
series = [
    ('pluralistic', 'pluralistic', 'Pluralistic', 'red'),
    ('fourier', 'fourier', 'Fourier Learning', 'purple'),
    ('online-learning', 'oco', 'Online Learning', 'blue'),
    ('time-feature', 'time', 'Time-Feature', 'orange'),
]

# Only use the results from runs whose data bias equals `bias`.
for bias in [0.7]:
  r = {}
  for f in files:
    # The abort function makes parse_logfile skip logs with a different bias.
    parse_logfile(os.path.join(LOG_DIR, f), r,
                  lambda lines: float(find_key_val(lines, 'bias=')) != bias)

  # Holds results with average and std values from across replicas.
  ravg = {}
  for lr in sorted(r[0]):
    # All methods are summarized the same way: the per-day 'pl' list of every
    # replica, reduced to mean and std across replicas.
    ravg[lr] = {}
    for mode_key, avg_key, _, _ in series:
      ravg[lr][avg_key] = {
          'pl': get_mean_and_std([r[rep][lr][mode_key]['pl'] for rep in r])}
    # Not used by the plot below; kept as notebook globals in case other cells
    # read them - TODO confirm.
    num_groups = len(r[0][lr]['pluralistic']['raw'])
    num_days = len(r[0][lr]['pluralistic']['raw'][0][0])

    # Plot test accuracy as a function of days, for the four different modes.
    plt.figure()
    for _, avg_key, label, color in series:
      stats = ravg[lr][avg_key]['pl']
      # The x axis follows the number of parsed days instead of a fixed 15.
      days = range(1, len(stats['mean']) + 1)
      plt.errorbar(days, stats['mean'], stats['std'], label=label, color=color)
    plt.legend(loc='lower right')
    plt.ylim([0.5, 0.8])
    plt.xlabel('Day')
    plt.ylabel('Test Accuracy')
    plt.grid()
    plt.savefig("./fig/{lr}.eps".format(lr=lr))
plt.show()