<a href="https://colab.research.google.com/github/halldm2000/NOAA-AI-2020-TUTORIAL/blob/master/taylor_series_pytorch.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

**Functions for generating and plotting the data**

In [1]:
import torch
import matplotlib.pyplot as plt
!mkdir -p images

def generate_data(nterms,  noise=0.2, rand_seed= 1, npoints=100, x_range=(-1.0, 1.0)):
  """Generate noisy samples of a random polynomial (a truncated Taylor series).

  y = sum_{k < nterms} c_k * x**k + noise * eps,  with c_k and eps drawn from torch.randn.

  Args:
    nterms: number of polynomial terms (coefficients c_0 .. c_{nterms-1}).
    noise: multiplier applied to the standard-normal noise added to every sample.
    rand_seed: seed passed to torch.manual_seed. NOTE: this reseeds torch's
      *global* RNG, so it also affects any later random calls in the session.
    npoints: number of evenly spaced sample points (default 100).
    x_range: (start, end) of the sampled interval, both ends included
      (default (-1.0, 1.0)).

  Returns:
    (x, y): two 1-D tensors of length `npoints`.
  """
  torch.manual_seed(rand_seed)
  x_start, x_end = x_range
  x = torch.linspace(x_start, x_end, npoints)
  # The coefficients are drawn in order c_0, c_1, ... and the noise is drawn last,
  # so a given seed always yields the same polynomial and the same noise.
  y = sum(torch.randn(1) * x**k for k in range(nterms)) + noise * torch.randn_like(x)
  return x, y

def plot_data(x,y,pred,i,loss, weights=None):
  """Plot the training points and the model prediction, then save the frame as a PNG.

  Args:
    x: sample positions (e.g. the `x` returned by generate_data).
    y: noisy targets at `x`.
    pred: model prediction at `x`. It may be a tensor that requires grad;
      a detached copy is plotted.
    i: epoch number, shown in the title and used in the frame file name.
    loss: scalar loss value shown in the title.
    weights: polynomial coefficients rendered as an equation above the plot.
      Defaults to the global `w` defined in the training cell.

  Side effects:
    Displays the figure and writes ./images/img_{i:03d}.png, the file pattern
    that the ffmpeg cell reads.
  """
  coeffs = w if weights is None else weights

  # matplotlib converts tensors with np.asarray, which raises for tensors that
  # require grad (torch.no_grad() in the caller does not change that flag).
  if torch.is_tensor(pred):
    pred = pred.detach()

  # Equation label such as "y =$+0.12x^0$ $-0.34x^1$ ..."
  terms = [f'${float(c):+.2f}x^{k}$ ' for k, c in enumerate(coeffs)]
  eqn = 'y ='+ ''.join(terms)

  fig, ax = plt.subplots(figsize=(5,4), dpi=1.5*72)
  ax.plot(x, y, '.')
  ax.plot(x, pred, linewidth=3)
  # The trailing newline pushes the title up so it does not overlap the
  # equation text placed just above the axes (y=1.02 in axes coordinates).
  ax.set_title(f"epoch={i}, loss = {loss:.4f}\n")
  ax.set_ylim(float(y.min()), float(y.max()))
  ax.text(0.5, 1.02, eqn, transform=ax.transAxes, fontsize=8, horizontalalignment='center')
  # Explicit .png so the name matches the ffmpeg input pattern img_%03d.png
  # regardless of the configured default savefig format.
  fig.savefig(f'./images/img_{i:03d}.png', bbox_inches='tight')
  plt.show()
  # One figure is created per epoch; close it so they do not accumulate.
  plt.close(fig)


**Fit Model to Data**

In [None]:
import torch
import matplotlib.pyplot as plt

# CONFIG
NTERMS = 6        # number of polynomial terms in both the data and the model
N_EPOCHS = 100    # one saved frame per epoch

# DATA: noisy samples of a random NTERMS-term polynomial
x,y = generate_data(nterms = NTERMS, noise=0.1, rand_seed=5)

# MODEL: polynomial with learnable coefficients w[0..NTERMS-1], initialised to zero.
# `w` stays a module-level global because plot_data reads it to label each frame.
w = torch.zeros(NTERMS, requires_grad = True)
def model(x): return sum(w[i]*x**i for i in range(len(w)))

# OPTIMIZER
optimizer = torch.optim.Adam(params = [w], lr=2e-2)

# TRAIN: minimise the mean absolute error (L1 loss)
for epoch in range(N_EPOCHS):

  prediction = model(x)

  optimizer.zero_grad()
  loss = (prediction-y).abs().mean()
  loss.backward()

  # Plot before optimizer.step() so the curve, the loss and the equation (built
  # from the global `w`) all describe the same weights. Detach the prediction:
  # matplotlib cannot convert a tensor that requires grad to a NumPy array.
  with torch.no_grad():
    plot_data(x,y,prediction.detach(),epoch,loss.item())

  optimizer.step()


**Convert figures into a movie**

In [3]:
# Encode the saved frames into out.mp4: width scaled to 1280 px (height chosen to keep
# the aspect ratio, rounded to an even number), yuv420p pixels, overwriting any old file.
# The frames are deleted only if ffmpeg succeeded (&&), so a failed encode can be retried.
!ffmpeg -loglevel warning -i ./images/img_%03d.png -vf scale=1280:-2 -pix_fmt yuv420p -y out.mp4 && rm ./images/img_*.png

**Embed the movie in the notebook**

In [4]:
from IPython.display import HTML
from base64 import b64encode

# Inline the movie as a base64 data URL so the video is stored in the notebook
# output itself rather than referenced by file path.
with open('out.mp4','rb') as f:   # `with` closes the file handle after reading
  mp4 = f.read()
data_url = "data:video/mp4;base64," + b64encode(mp4).decode()

# Last expression of the cell: rendered as an HTML video player.
HTML("""
<video width=800 controls><source src="%s" type="video/mp4"></video>
""" % data_url)