<a href="https://colab.research.google.com/github/haonan2012/Dive2DeepLearning/blob/master/1_LinearRegression_scratch.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Utilities

In [None]:
def add_to_class(Class):
  """Register functions as methods in created class.

  Returns a decorator that sets the decorated object as an attribute of
  ``Class`` under the object's own ``__name__``. Decorating a function
  therefore adds it as a method to an already-defined class, which lets a
  class be extended in later notebook cells.

  Args:
    Class: The class that receives the decorated object.

  Returns:
    A decorator that registers its argument on ``Class`` and returns it
    unchanged.
  """
  def wrapper(obj):
    setattr(Class, obj.__name__, obj)
    # Hand the object back so the decorated name stays bound to the
    # function; without this return the name would be rebound to None.
    return obj
  return wrapper

In [None]:
import inspect

class HyperParameters:
  """The base class of hyperparameters."""
  def save_hyperparameters(self, ignore=()):
    """Save function arguments into class attributes.

    Reads the local variables of the calling function's frame (e.g. an
    ``__init__``) and stores each one as an attribute of ``self``. The
    following names are skipped:

    - ``self``;
    - names listed in ``ignore``;
    - names starting with an underscore.

    The saved name/value pairs are also kept in ``self.hparams``.

    Args:
      ignore: Names of caller locals that must not be saved. Any iterable
        of strings (list, tuple, set, ...) is accepted.
    """
    frame = inspect.currentframe().f_back
    try:
      _, _, _, local_vars = inspect.getargvalues(frame)
      # Build the exclusion set once rather than once per local variable.
      skip = set(ignore) | {'self'}
      self.hparams = {k: v for k, v in local_vars.items()
                      if k not in skip and not k.startswith('_')}
    finally:
      # Release the frame reference promptly so no reference cycle keeps
      # the caller's frame alive (recommended by the inspect docs).
      del frame
    for k, v in self.hparams.items():
      setattr(self, k, v)

In [None]:
import collections
class ProgressBoard(HyperParameters):
  """The board that plots data points in animation.

  Each ``draw`` call records one ``(x, y)`` sample for a named series.
  Every ``every_n`` samples of a series are averaged into a single plotted
  point. When ``display`` is true, the figure is then redrawn and shown.
  """
  # Created once with the class instead of on every ``draw`` call.
  _Point = collections.namedtuple('Point', ['x', 'y'])

  def __init__(self, xlabel=None, ylabel=None, xlim=None,
               ylim=None, xscale='linear', yscale='linear',
               ls=('-', '--', '-.', ':'), colors=('C0', 'C1', 'C2', 'C3'),
               fig=None, axes=None, figsize=(3.5, 2.5), display=True):
    # Stores every argument above as an attribute of the same name.
    self.save_hyperparameters()

  def draw(self, x, y, label, every_n=1):
    """Record the sample ``(x, y)`` for series ``label`` and maybe redraw.

    Args:
      x: Horizontal coordinate of the sample.
      y: Vertical coordinate of the sample.
      label: Series name; also used as its legend entry.
      every_n: How many raw samples are averaged into one plotted point.
    """
    if not self._record(x, y, label, every_n):
      return
    if not self.display:
      return
    self._render()

  def _record(self, x, y, label, every_n):
    """Buffer one sample; return True once a new averaged point is added."""
    if not hasattr(self, 'raw_points'):
      # Per-series state is created on first use.
      self.raw_points = collections.OrderedDict()
      self.data = collections.OrderedDict()
    if label not in self.raw_points:
      self.raw_points[label] = []
      self.data[label] = []
    points = self.raw_points[label]
    points.append(self._Point(x, y))
    # '<' rather than '!=' so a buffer that already holds more than
    # every_n samples (e.g. every_n lowered mid-run) is still flushed.
    if len(points) < every_n:
      return False
    n = len(points)
    self.data[label].append(self._Point(sum(p.x for p in points) / n,
                                        sum(p.y for p in points) / n))
    points.clear()
    return True

  def _render(self):
    """Redraw all recorded series on the board's axes and show the figure."""
    import itertools  # only needed when actually plotting
    d2l.use_svg_display()
    if self.fig is None:
      self.fig = d2l.plt.figure(figsize=self.figsize)
    # Draw on the board's own axes/figure so the plotted lines end up on
    # the same axes that receive limits, labels and the legend below.
    axes = self.axes if self.axes else self.fig.gca()
    plt_lines, labels = [], []
    # Cycle the styles so series beyond len(ls)/len(colors) are still
    # drawn instead of being silently dropped by zip.
    styles = zip(itertools.cycle(self.ls), itertools.cycle(self.colors))
    # NOTE(review): nothing here clears the axes, so each redraw adds new
    # line artists on top of the previous ones.
    for (k, v), (ls, color) in zip(self.data.items(), styles):
      plt_lines.append(axes.plot([p.x for p in v], [p.y for p in v],
                                 linestyle=ls, color=color)[0])
      labels.append(k)
    if self.xlim: axes.set_xlim(self.xlim)
    if self.ylim: axes.set_ylim(self.ylim)
    # An unset (None) label renders as empty text, same as ylabel below.
    axes.set_xlabel(self.xlabel)
    axes.set_ylabel(self.ylabel)
    axes.set_xscale(self.xscale)
    axes.set_yscale(self.yscale)
    axes.legend(plt_lines, labels)
    # ``display`` here is a module-level name (presumably IPython.display),
    # not the ``self.display`` flag; it must be imported elsewhere.
    display.display(self.fig)
    display.clear_output(wait=True)