# RNN Applications
- Time series prediction
- Language modeling (Text generation)
- Text sentiment Analysis
- Named entity recognition
- Translation
- Speech recognition
- Music Composition

In [1]:
import torch
import torch.nn as nn
from torch.autograd import Variable
import sys

### Define different cells using nn module

In [2]:
# output is hidden_size
cell = nn.RNN(input_size = 4, hidden_size = 2, batch_first = True)
cell

RNN(4, 2, batch_first=True)

In [3]:
cell = nn.LSTM(input_size = 4, hidden_size = 2, batch_first = True)
cell

LSTM(4, 2, batch_first=True)

In [4]:
cell = nn.GRU(input_size = 4, hidden_size = 2, batch_first = True)
cell

GRU(4, 2, batch_first=True)

In [5]:
# inputs is (batch_size, seq_len, input_size) with batch_first = True 
# hidden is (num_layers, batch_size, hidden_size)

In [6]:
# One-hot encoding of the four letters of "hello", used as RNN inputs.
# Vocabulary order: h, e, l, o (row i has its single 1 in column i).
h, e, l, o = [[1, 0, 0, 0],
              [0, 1, 0, 0],
              [0, 0, 1, 0],
              [0, 0, 0, 1]]

In [7]:
# batch_size, seq_len, input_size
inputs = torch.autograd.Variable(torch.Tensor([[h]]))
inputs.shape

torch.Size([1, 1, 4])

In [8]:
# num_layers, batch_size, hidden_size (initialized randomly)
hidden = torch.autograd.Variable(torch.randn(1, 1, 2))
hidden.shape

torch.Size([1, 1, 2])

In [9]:
out , hidden = cell(inputs, hidden)
# Expect two values per output step (hidden_size = 2)
print(out) 
print(hidden)

tensor([[[-1.3292,  0.8165]]], grad_fn=<TransposeBackward0>)
tensor([[[-1.3292,  0.8165]]], grad_fn=<ViewBackward>)


### Build RNN for multiple inputs

In [10]:
# batch_size, seq_len, input_size
cell = nn.RNN(input_size = 4, hidden_size = 2, batch_first = True)
inputs = torch.autograd.Variable(torch.Tensor([[h, e, l, l, o]]))
inputs.shape

torch.Size([1, 5, 4])

In [11]:
# num_layers, batch_size, hidden_size (initialized randomly)
hidden = torch.autograd.Variable(torch.randn(1, 1, 2))
out , hidden = cell(inputs, hidden)
out

tensor([[[ 0.5726,  0.7125],
         [-0.5578, -0.2074],
         [ 0.3379,  0.4796],
         [-0.5314,  0.0656],
         [ 0.5346, -0.1308]]], grad_fn=<TransposeBackward0>)

### Build Multiple Batch input

In [12]:
# batch_size, seq_len, input_size
cell = nn.RNN(input_size = 4, hidden_size = 2, batch_first = True)
inputs = torch.autograd.Variable(torch.Tensor([[h, e, l, l, o],
                                               [e, o, l, l, l],
                                               [l, l, e, e, l]]))
# Notice batch_first = True
# batch_size = 3
# seq_len = 5
# input_size = 4
inputs.shape

torch.Size([3, 5, 4])

In [13]:
# num_layers, batch_size, hidden_size (initialized randomly)
hidden = torch.autograd.Variable(torch.randn(1, 3, 2))
out , hidden = cell(inputs, hidden)
out

tensor([[[ 0.3644, -0.3014],
         [-0.3402, -0.0756],
         [ 0.3476, -0.6521],
         [ 0.5901, -0.8711],
         [-0.0508, -0.7759]],

        [[-0.4725,  0.0927],
         [-0.5096, -0.1364],
         [ 0.3390, -0.6280],
         [ 0.5838, -0.8674],
         [ 0.6595, -0.9117]],

        [[ 0.2894, -0.5575],
         [ 0.5621, -0.8534],
         [-0.1338, -0.3973],
         [-0.3892,  0.0837],
         [ 0.2920, -0.5933]]], grad_fn=<TransposeBackward0>)

### Train the network

We want to predict the following sequence:
- h --> i
- i --> h
- h --> e
- e --> l
- l --> l
- l --> o

There are five **input** letters: h, i, e, l, o

There are five **output** letters: h, i, e, l, o

We use **cross-entropy** as the loss function, treating each next-character prediction as a 5-way classification.


In [14]:
torch.manual_seed(777)  # reproducibility
#            0    1    2    3    4
idx2char = ['h', 'i', 'e', 'l', 'o']

# Teach hihell -> ihello
x_data = [0, 1, 0, 2, 3, 3]   # hihell
one_hot_lookup = [[1, 0, 0, 0, 0],  # 0
                  [0, 1, 0, 0, 0],  # 1
                  [0, 0, 1, 0, 0],  # 2
                  [0, 0, 0, 1, 0],  # 3
                  [0, 0, 0, 0, 1]]  # 4

y_data = [1, 0, 2, 3, 3, 4]    # ihello
x_one_hot = [one_hot_lookup[x] for x in x_data]

# As we have one batch of samples, we will change them to variables only once
# inputs = Variable(torch.Tensor(x_one_hot))
# labels = Variable(torch.LongTensor(y_data))

In [15]:
# cell = nn.RNN(input_size = 5, hidden_size = 5, batch_first = True)
# inputs = torch.autograd.Variable(torch.Tensor(x_one_hot))
# labels = torch.autograd.Variable(torch.LongTensor(y_data))
# Notice batch_first = True
# batch_size = 1
# seq_len = 6
# input_size = 5
# inputs.shape

In [16]:
# As we have one batch of samples, we will change them to variables only once
inputs = Variable(torch.Tensor(x_one_hot))
labels = Variable(torch.LongTensor(y_data))

In [17]:
# Network dimensions
input_size = 5       # length of each one-hot input vector
hidden_size = 5      # RNN output size; equal to num_classes so outputs score characters directly
num_classes = 5      # number of distinct characters to predict
num_layers = 1       # a single recurrent layer

# Feeding scheme
batch_size = 1       # one training sentence
sequence_length = 1  # characters are fed one step at a time

In [18]:
class Model(nn.Module):
    """Character-level RNN whose raw outputs are used directly as class scores.

    Dimensions are read from the notebook-level globals ``input_size``,
    ``hidden_size``, ``num_layers``, ``batch_size``, ``sequence_length`` and
    ``num_classes`` (defined in the hyperparameter cell). There is no output
    layer, so ``hidden_size`` must equal ``num_classes``.
    """

    def __init__(self):
        super(Model, self).__init__()
        # Pass num_layers so the RNN agrees with the hidden state created in
        # init_hidden(); without it, any num_layers != 1 makes forward() fail
        # with a hidden-size mismatch.
        self.rnn = nn.RNN(input_size=input_size,
                          hidden_size=hidden_size,
                          num_layers=num_layers,
                          batch_first=True)

    def forward(self, hidden, x):
        """Run the RNN over one chunk of input.

        Args:
            hidden: previous hidden state,
                (num_layers * num_directions, batch_size, hidden_size).
            x: input holding batch_size * sequence_length * input_size
                elements (e.g. a single one-hot vector when
                sequence_length == 1).

        Returns:
            (hidden, out): the new hidden state and the RNN output reshaped to
            (batch_size * sequence_length, num_classes).
        """
        # Reshape to (batch, seq_len, input_size) because batch_first=True.
        x = x.view(batch_size, sequence_length, input_size)

        # Propagate input through the RNN.
        out, hidden = self.rnn(x, hidden)
        return hidden, out.view(-1, num_classes)

    def init_hidden(self):
        """Return an all-zero initial hidden state.

        Shape: (num_layers * num_directions, batch_size, hidden_size). The RNN
        is unidirectional, so num_directions == 1. A plain RNN has no separate
        cell state (unlike an LSTM), so only the hidden state is created.
        """
        return torch.zeros(num_layers, batch_size, hidden_size)

In [19]:
model = Model()
print(model)
# Cross-entropy loss between predicted character scores and target indices
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr =0.1)

Model(
  (rnn): RNN(5, 5, batch_first=True)
)


In [23]:
# Train the model
for epoch in range(100):
    optimizer.zero_grad()
    loss = 0
    hidden = model.init_hidden()
{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Simple RNN\n",
    "\n",
    "In this notebook, we're going to train a simple RNN to do **time-series prediction**. Given some set of input data, it should be able to generate a prediction for the next time step!\n",
    "<img src='assets/time_prediction.png' width=40% />\n",
    "\n",
    "> * First, we'll create our data\n",
    "* Then, define an RNN in PyTorch\n",
    "* Finally, we'll train our network and see how it performs"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Import resources and create data "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "from torch import nn\n",
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "%matplotlib inline"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAecAAAEyCAYAAADA/hjIAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDIuMi4zLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvIxREBQAAG0xJREFUeJzt3X+QVOW95/HPlxnGMQFMBUg2MupQudwlSAFem8EGUvY6msVbG4hZ3YC4XpIodZOwJtkkJbgp40ol5F7dtcrVJOtNLC+J8cfqqqxFyroZ7YjaRpqo2QAhNQFcB1JhLhrE3OAww3f/OM1kZuxhzjDndD/T/X5VTZ053c885zvPnD6fPk/3nDZ3FwAACMeEahcAAAAGI5wBAAgM4QwAQGAIZwAAAkM4AwAQGMIZAIDAEM4AAASGcAYAIDCEMwAAgWms1oanTZvmra2t1do8AAAVtWPHjn929+lx2lYtnFtbW1UsFqu1eQAAKsrMXovblmltAAACQzgDABAYwhkAgMBU7TVnAED4jh8/rq6uLh07dqzapYwbzc3Namlp0cSJE0+7D8IZADCsrq4uTZ48Wa2trTKzapcTPHfX4cOH1dXVpZkzZ552P0xrAwCGdezYMU2dOpVgjsnMNHXq1DHPNBDOAIBTIphHJ4nxGjGczexeMztkZr8a5n4zszvNrNPMfmlmfzXmqgAAqGNxzpzvk7TsFPdfLmlW6WutpO+OvSwAo1EoSJs2RcuwOwVGb/HixYn3uX//fv34xz9OvN+kjPiGMHd/1sxaT9FkhaTN7u6SXjSz95nZh9z9dwnVCOAUCgWpvV3q6ZGamqSODimbDbFT4PS88MILifd5MpyvvvrqxPtOQhKvOc+Q9PqA9a7Sbe9iZmvNrGhmxe7u7gQ2DSCfjzK0ry9a5vOhdoq6kfCsy6RJkyRJ+XxeuVxOV155pWbPnq3Vq1crOi+MLgl94403qq2tTW1tbers7JQkrVmzRo888si7+lq/fr22bdumBQsW6I477hh229u3b9e8efN07Ngx/fGPf9T555+vX/2q7Ku8iUriX6nKvfLt5Rq6+z2S7pGkTCZTtg1QywqFKOdyueRORHM5qamxTz0npKZGKZdrSKTTQsNS5U8sUa7heWVzubH3KaUzAAhLyrMuL7/8snbu3Kmzzz5bS5Ys0fPPP6+lS5dKkqZMmaKXXnpJmzdv1pe+9CU9+eSTw/bz7W9/W7fffvsp20jSwoULtXz5cn3961/Xn/70J11zzTWaO3duYr/PcJII5y5J5wxYb5F0MIF+gZqS1jErq4I6fIPyWqKcP6+sNkkaW8cFZdVuHeqRqclcHWoYY49iqrxelJt1SfDv3NbWppaWFknSggULtH///v5wXrVqVf/yy1/+cmLbvPnmm7Vw4UI1NzfrzjvvTKzfU0liWnuLpGtL79q+SNIRXm8G3i21meJ8Xtm+57TBv6Vs33OJdJzPSz29DerzCerpbWCqHPHlctGTr4aGaJnUrEvJGWec0f99Q0ODent7+9cH/gvTye8bGxt14sQJSdEFQnp6eka9zTfeeENvv/22jh49WrErpcX5V6oHJBUk/Wsz6zKzz5rZ35rZ35aabJW0V1KnpH+Q9PnUqgXGsdSOWSl0nEqtKR+0EYhsNpoV2bix4rMjDz30UP8yW9pua2urduzYIUl64okndPz4cUnS5MmTdfTo0f6fPXDggNrb28v2u3btWm3cuFGrV6/WjTfemOav0C/Ou7VXjXC/S/pCYhUBNerkMSvxl1xT6DiVWlMbAAQnm63K3/edd97RokWLdOLECT3wwAOSpOuvv14rVqxQW1ub2tvb9d73vleSNG/ePDU2Nmr+/Plas2aNPvrRj6qx8d2RuHnzZjU2Nurqq69WX1+fFi9erKefflqXXHJJqr+LnXynW6VlMhkvFotV2TYAIJ7du3frIx/5SLXLGFFra6uK
xaKmTZt2Wj9/11136dxzz9Xy5csTqafcuJnZDnfPxPl5PvgCAFD31q1bV+0SBiGcAQDj3v79+6tdQqL44AsAAAJDOAPD4HrVyUvt16/zcUXtYVobKIPrVScvtV+/zscVtYkzZ6AMrledvDQvwlLP44raRDgDZXARjuSNp4uwIBx/+MMf9J3vfKci28rn86l8AtbpYFobKIOLcCRvPF2EBeE4Gc6f/3z8i0+6u9xdEyaM7vwzn89r0qRJqXx+9Kid/CUq/XXhhRc6ACBsu3btGvXPvPCC+7e+FS3H6lOf+pQ3Nzf7/Pnz/atf/aofPXrUL7nkEr/gggt87ty5/vjjj7u7+759+3z27Nn+uc99zhcsWOD79+/373//+z5r1iy/+OKL/brrrvMvfOEL7u5+6NAh/+QnP+mZTMYzmYw/99xzvm/fPv/gBz/oZ599ts+fP9+fffbZYWtaunSpv/zyy/3rixcv9ldffXVQm3LjJqnoMTOScAYADGu04fzCC+5nnune0BAtxxrQ+/bt8/PPP79//fjx437kyBF3d+/u7vYPf/jDfuLECd+3b5+bmRcKBXd3P3DggJ933nl++PBh7+np8aVLl/aH86pVq3zbtm3u7v7aa6/57Nmz3d39G9/4ht92220j1nTffff5F7/4RXd337Nnj5fLs7GGM9PaAIDEpPyJkXJ33XTTTXr22Wc1YcIEHThwQL///e8lSeedd54uuugiSdJLL72kiy++WO9///slSVdddZV+85vfSJJ++tOfateuXf19vvXWW4M+BGMkV111lTZu3KjbbrtN9957r9asWZPQb/dnhDMAIDEn35938j/bkn5/3v3336/u7m7t2LFDEydOVGtra//HOJ78UAspCvHhnDhxQoVCQWeeeeZp1fCe97xHl112mZ544gk9/PDDSuNzIni3NgAgMUl/YuTQj3Y8cuSIPvCBD2jixIl65pln9Nprr5X9uba2Nv3sZz/Tm2++qd7eXj366KP9933sYx/TXXfd1b/+yiuvlN3WY489pg0bNpTt/7rrrtMNN9yghQsX9p+dJ4lwBgAkKpuVNmxIZjp76tSpWrJkiebOnauvfe1rWr16tYrFojKZjO6//37Nnj277M/NmDFDN910kxYtWqRLL71Uc+bM0VlnnSVJuvPOO1UsFjVv3jzNmTNH3/ve9yRJH//4x/XYY49pwYIF2rZtm377299qypQpZfu/8MILNWXKFH36058e+y9ZBh8ZCQAY1nj5yMhy3n77bU2aNEm9vb264oor9JnPfEZXXHFF7J+/5pprdMcdd2j69Onvuu/gwYPK5XL69a9/XfZftsb6kZGcOQMAatItt9yiBQsWaO7cuZo5c6Y+8YlPjOrnf/SjH5UN5s2bN2vRokX65je/Oer/pY6LN4QBAGrS7bffnkq/1157ra699tpU+j6JM2cAwClV6+XP8SqJ8SKcURP4eMf6xt8/Pc3NzTp8+DABHZO76/Dhw2pubh5TP0xrY9zj4x3rG3//dLW0tKirq0vd3d3VLmXcaG5uVktLy5j6IJwx7qVyRaK0L3OExPD3T9fEiRM1c+bMapdRd5jWxrjHxzvWN/7+qEX8nzNqQqGQwicGptIp0sDfH+PBaP7PmXAGAKACuAgJAADjGOEMAEBgCGcAAAJDOAMAEBjCGQCAwBDOAAAEhnAGACAwhDMAAIEhnAEACAzhDABAYAhnAAACQzgDABAYwhkAgMAQzgAABIZwBgAgMLHC2cyWmdkeM+s0s/Vl7j/XzJ4xs5fN7Jdm9tfJl4paUChImzZFy/HRMepVKrsU+yliahypgZk1SLpb0mWSuiRtN7Mt7r5rQLOvS3rY3b9rZnMkbZXUmkK9GMcKBam9XerpkZqapI4OKZsNuWPUq1R2KfZTjEKcM+c2SZ3uvtfdeyQ9KGnFkDYuaUrp+7MkHUyuRNSKfD46LvX1Rct8PvSOUa9S2aXYTzEKccJ5hqTXB6x3lW4b6BZJ15hZ
l6Kz5v+USHWoKblcdMLQ0BAtc7nQO0a9SmWXYj/FKIw4rS3JytzmQ9ZXSbrP3f+bmWUl/dDM5rr7iUEdma2VtFaSzj333NOpF+NYNhvN5OXz0XEpsRm91DpGvUpll2I/xSiY+9CcHdIgCttb3P3fltY3SJK7bxrQZqekZe7+eml9r6SL3P3QcP1mMhkvFotj/w0AABgHzGyHu2fitI0zrb1d0iwzm2lmTZJWStoypM3/k9Re2vhHJDVL6o5fMgAAOGnEcHb3XknrJD0labeid2XvNLNbzWx5qdlXJF1vZq9KekDSGh/plBwAAJQV5zVnuftWRW/0GnjbzQO+3yVpSbKlAQBQn7hCGAAAgSGcAQAIDOEMAEBgCGcAAAJDOAMAEBjCGQCAwBDOAAAEhnAGACAwhDMAAIEhnAEACAzhDABAYAhnAAACQzgDABAYwhkAgMAQzgAABIZwBgAgMIQzhlUoSJs2RcuwOwXGh9R2fx5XNaex2gUgTIWC1N4u9fRITU1SR4eUzYbYKTA+pLb787iqSZw5o6x8Pnqs9/VFy3w+1E6B8SG13Z/HVU0inFFWLhc9CW9oiJa5XKidAuNDars/j6uaZO5elQ1nMhkvFotV2TbiKRSiJ+G5XIKzZKl0CowPqe3+PK7GBTPb4e6ZWG0JZwAA0jeacGZaGwCAwBDOAAAEhnAGACAwhDMAAIEhnAEACAzhDABAYAhnAAACQzgDABAYwhkAgMAQzgAABIZwBgAgMIQzAACBIZwBAAgM4QwAQGAIZwAAAkM4AwAQGMIZAIDAEM4AAASGcAYAIDCxwtnMlpnZHjPrNLP1w7T5D2a2y8x2mtmPky0TAID60ThSAzNrkHS3pMskdUnabmZb3H3XgDazJG2QtMTd3zSzD6RVMAAAtS7OmXObpE533+vuPZIelLRiSJvrJd3t7m9KkrsfSrZMAADqR5xwniHp9QHrXaXbBvpLSX9pZs+b2YtmtqxcR2a21syKZlbs7u4+vYoBAKhxccLZytzmQ9YbJc2SlJO0StL3zex97/oh93vcPePumenTp4+2VgyjUJA2bYqW46NjAElK5aHK47+qRnzNWdGZ8jkD1lskHSzT5kV3Py5pn5ntURTW2xOpEsMqFKT2dqmnR2pqkjo6pGw25I4BJCmVhyqP/6qLc+a8XdIsM5tpZk2SVkraMqTN45L+jSSZ2TRF09x7kywU5eXz0eOnry9a5vOhdwwgSak8VHn8V92I4ezuvZLWSXpK0m5JD7v7TjO71cyWl5o9Jemwme2S9Iykr7n74bSKxp/lctET24aGaJnLhd4xgCSl8lDl8V915j705ePKyGQyXiwWq7LtWlMoRE9sc7mEZ55S6xhAklJ5qPL4T5yZ7XD3TKy2hDMAAOkbTThz+U4AAAJDOAMAEBjCGQCAwBDOAAAEhnAGACAwhDMAAIEhnAEACAzhDABAYAhnAAACQzgDABAYwhkAgMAQzgAABIZwBgAgMIQzAACBIZwBAAgM4QwAQGAIZwAAAkM4AwAQGMIZAIDAEM4AAASGcAYAIDCEMwAAgSGcAQAIDOEMAEBgCGcAAAJDOAMAEBjCGQCAwBDOFVYoSJs2RcuwOwVQzzhWVVdjtQuoJ4WC1N4u9fRITU1SR4eUzYbYKYB6xrGq+jhzrqB8Ptov+/qiZT4faqcA6hnHquojnCsol4ueMDY0RMtcLtROAdQzjlXVZ+5elQ1nMhkvFotV2XY1FQrRE8ZcLsEZnVQ6BVDPOFYlz8x2uHsmVlvCGQCA9I0mnJnWBgAgMIQzAACBIZwBAAgM4QwAQGAIZwAAAkM4AwAQmFjhbGbLzGyPmXWa2fpTtLvSzNzMYr1VHAAAvNuI4WxmDZLulnS5pDmSVpnZnDLtJku6QdLPky4SAIB6EufMuU1Sp7vvdfceSQ9KWlGm3UZJfy/pWIL1AQBQd+KE8wxJrw9Y7yrd1s/MLpB0jrs/mWBtAADUpTjhbGVu67/mp5lNkHSH
pK+M2JHZWjMrmlmxu7s7fpUAANSROOHcJemcAestkg4OWJ8saa6kvJntl3SRpC3l3hTm7ve4e8bdM9OnTz/9qgEAqGFxwnm7pFlmNtPMmiStlLTl5J3ufsTdp7l7q7u3SnpR0nJ351MtAAA4DSOGs7v3Slon6SlJuyU97O47zexWM1uedoEAANSbxjiN3H2rpK1Dbrt5mLa5sZcFAED94gphAAAEhnAGACAwhDMAAIEhnAEACAzhDABAYAhnAAACQzgDABAYwhkAgMAQzgAABIZwBgAgMIQzAACBIZwBAAgM4QwAQGAIZwAAAkM4AwAQGML5FAoFadOmaBl2pwAQvtQOfzV4XG2sdgGhKhSk9napp0dqapI6OqRsNsROASB8qR3+avS4ypnzMPL56G/d1xct8/lQOwWA8KV2+KvR4yrhPIxcLnoS1tAQLXO5UDsFgPCldvir0eOquXtVNpzJZLxYLFZl23EVCtGTsFwuwVmSVDoFgPCldvgbJ8dVM9vh7plYbQlnAADSN5pwZlobAIDAEM4AAASGcAYAIDCEMwAAgSGcAQAIDOEMAEBgCGcAAAJDOAMAEBjCGQCAwBDOAAAEhnAGACAwhDMAAIEhnAEACAzhDABAYAhnAAACQzgDABAYwhkAgMAQzgAABIZwBgAgMIQzAACBiRXOZrbMzPaYWaeZrS9z/382s11m9ksz6zCz85IvFQCA+jBiOJtZg6S7JV0uaY6kVWY2Z0izlyVl3H2epEck/X3ShQIAUC/inDm3Sep0973u3iPpQUkrBjZw92fc/V9Kqy9Kakm2TAAA6keccJ4h6fUB612l24bzWUk/KXeHma01s6KZFbu7u+NXCQBAHYkTzlbmNi/b0OwaSRlJt5W7393vcfeMu2emT58ev0oAAOpIY4w2XZLOGbDeIung0EZmdqmk/yLpYnd/J5nyAACoP3HOnLdLmmVmM82sSdJKSVsGNjCzCyT9T0nL3f1Q8mUCAFA/Rgxnd++VtE7SU5J2S3rY3Xea2a1mtrzU7DZJkyT9LzN7xcy2DNMdAAAYQZxpbbn7Vklbh9x284DvL024LgAA6hZXCAMAIDCEMwAAgSGcAQAITE2Ec6EgbdoULcdHxwCApKRyqK7y8T/WG8JCVihI7e1ST4/U1CR1dEjZbMgdAwCSksqhOoDj/7g/c87no/Hr64uW+XzoHQMAkpLKoTqA4/+4D+dcLnpi09AQLXO50DsGACQllUN1AMd/cy97mezUZTIZLxaLifRVKERPbHK5hGceUusYAJCUVA7VKXRqZjvcPROrbS2EMwAAoRtNOI/7aW0AAGoN4QwAQGAIZwAAAkM4AwAQGMIZAIDAEM4AAASGcAYAIDCEMwAAgSGcAQAIDOEMAEBgCGcAAAJDOAMAEBjCGQCAwBDOAAAEhnAGACAwhDMAAIEhnAEACAzhDABAYAhnAAACQzgDABAYwhkAgMAQzgAABIZwBgAgMIQzAACBIZwBAAgM4QwAQGAIZwAAAkM4AwAQGMIZAIDAEM4AAASGcAYAIDCxwtnMlpnZHjPrNLP1Ze4/w8weKt3/czNrTbpQAADqxYjhbGYNku6WdLmkOZJWmdmcIc0+K+lNd/8LSXdI+rukCz2lQkHatClaAgAwRtWOlcYYbdokdbr7XkkyswclrZC0a0CbFZJuKX3/iKS7zMzc3ROstbxCQWpvl3p6pKYmqaNDymZT3ywAoDaFECtxprVnSHp9wHpX6baybdy9V9IRSVOHdmRma82saGbF7u7u06t4qHw+GsG+vmiZzyfTLwCgLoUQK3HC2crcNvSMOE4bufs97p5x98z06dPj1DeyXC56atPQEC1zuWT6BQDUpRBiJc60dpekcwast0g6OEybLjNrlHSWpDcSqXAk2Ww055DPRyPIlDYAYAxCiJU44bxd0iwzmynpgKSVkq4e0maLpL+RVJB0paSnK/J680nZLKEMAEhMtWNlxHB2914zWyfpKUkNku51951m
dqukortvkfQDST80s05FZ8wr0ywaAIBaFufMWe6+VdLWIbfdPOD7Y5KuSrY0AADqE1cIAwAgMIQzAACBIZwBAAgM4QwAQGAIZwAAAkM4AwAQGMIZAIDAWCUv5DVow2bdkl6rysbDM03SP1e7iIAwHoMxHoMxHoMxHoOFPB7nuXusD5aoWjjjz8ys6O6ZatcRCsZjMMZjMMZjMMZjsFoZD6a1AQAIDOEMAEBgCOcw3FPtAgLDeAzGeAzGeAzGeAxWE+PBa84AAASGM2cAAAJDOAMAEBjCuYLMbJmZ7TGzTjNbX+b+NWbWbWavlL6uq0adlWBm95rZITP71TD3m5ndWRqrX5rZX1W6xkqKMR45MzsyYN+4uVy7WmFm55jZM2a228x2mtkXy7Spm30k5njUzT5iZs1m9pKZvVoaj/9aps0ZZvZQaf/4uZm1Vr7S09dY7QLqhZk1SLpb0mWSuiRtN7Mt7r5rSNOH3H1dxQusvPsk3SVp8zD3Xy5pVulrkaTvlpa16j6dejwkaZu7/7vKlFN1vZK+4u6/MLPJknaY2T8NebzU0z4SZzyk+tlH3pF0ibu/bWYTJT1nZj9x9xcHtPmspDfd/S/MbKWkv5P0qWoUezo4c66cNkmd7r7X3XskPShpRZVrqhp3f1bSG6doskLSZo+8KOl9ZvahylRXeTHGo664++/c/Rel749K2i1pxpBmdbOPxByPulH6m79dWp1Y+hr67uYVkv6x9P0jktrNzCpU4pgRzpUzQ9LrA9a7VP7B9e9LU3SPmNk5lSktSHHHq55kS9N4PzGz86tdTKWUpiMvkPTzIXfV5T5yivGQ6mgfMbMGM3tF0iFJ/+Tuw+4f7t4r6YikqZWt8vQRzpVT7hnb0Gd6/0dSq7vPk/RT/flZXz2KM1715BeKrss7X9L/kPR4leupCDObJOlRSV9y97eG3l3mR2p6HxlhPOpqH3H3PndfIKlFUpuZzR3SZFzvH4Rz5XRJGngm3CLp4MAG7n7Y3d8prf6DpAsrVFuIRhyveuLub52cxnP3rZImmtm0KpeVqtJriY9Kut/d/3eZJnW1j4w0HvW4j0iSu/9BUl7SsiF39e8fZtYo6SyNo5eOCOfK2S5plpnNNLMmSSslbRnYYMjrZcsVva5Ur7ZIurb0jtyLJB1x999Vu6hqMbN/dfL1MjNrU/TYPVzdqtJT+l1/IGm3u//3YZrVzT4SZzzqaR8xs+lm9r7S92dKulTSr4c02yLpb0rfXynpaR9HV93i3doV4u69ZrZO0lOSGiTd6+47zexWSUV33yLpBjNbruidmW9IWlO1glNmZg9IykmaZmZdkr6h6E0dcvfvSdoq6a8ldUr6F0mfrk6llRFjPK6U9Dkz65X0J0krx9OB5jQskfQfJf3f0uuKknSTpHOlutxH4oxHPe0jH5L0j6X/gpkg6WF3f3LI8fQHkn5oZp2Kjqcrq1fu6HH5TgAAAsO0NgAAgSGcAQAIDOEMAEBgCGcAAAJDOAMAEBjCGQCAwBDOAAAE5v8D/l6MKUP+AowAAAAASUVORK5CYII=\n",
      "text/plain": [
       "<Figure size 576x360 with 1 Axes>"
      ]
     },
     "metadata": {
      "needs_background": "light"
     },
     "output_type": "display_data"
    }
   ],
   "source": [
    "plt.figure(figsize=(8,5))\n",
    "\n",
    "# how many time steps/data pts are in one batch of data\n",
    "seq_length = 20\n",
    "\n",
    "# generate evenly spaced data pts\n",
    "time_steps = np.linspace(0, np.pi, seq_length + 1)\n",
    "data = np.sin(time_steps)\n",
    "data.resize((seq_length + 1, 1)) # size becomes (seq_length+1, 1), adds an input_size dimension\n",
    "\n",
    "x = data[:-1] # all but the last piece of data\n",
    "y = data[1:] # all but the first\n",
    "\n",
    "# display the data\n",
    "plt.plot(time_steps[1:], x, 'r.', label='input, x') # x\n",
    "plt.plot(time_steps[1:], y, 'b.', label='target, y') # y\n",
    "\n",
    "plt.legend(loc='best')\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "---\n",
    "## Define the RNN\n",
    "\n",
    "Next, we define an RNN in PyTorch. We'll use `nn.RNN` to create an RNN layer, then we'll add a last, fully-connected layer to get the output size that we want. An RNN takes in a number of parameters:\n",
    "* **input_size** - the size of the input\n",
    "* **hidden_dim** - the number of features in the RNN output and in the hidden state\n",
    "* **n_layers** - the number of layers that make up the RNN, typically 1-3; greater than 1 means that you'll create a stacked RNN\n",
    "* **batch_first** - whether or not the input/output of the RNN will have the batch_size as the first dimension (batch_size, seq_length, hidden_dim)\n",
    "\n",
    "Take a look at the [RNN documentation](https://pytorch.org/docs/stable/nn.html#rnn) to read more about recurrent layers."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "class RNN(nn.Module):\n",
    "    def __init__(self, input_size, output_size, hidden_dim, n_layers):\n",
    "        super(RNN, self).__init__()\n",
    "        \n",
    "        self.hidden_dim=hidden_dim\n",
    "\n",
    "        # define an RNN with specified parameters\n",
    "        # batch_first means that the first dim of the input and output will be the batch_size\n",
    "        self.rnn = nn.RNN(input_size, hidden_dim, n_layers, batch_first=True)\n",
    "        \n",
    "        # last, fully-connected layer\n",
    "        self.fc = nn.Linear(hidden_dim, output_size)\n",
    "\n",
    "    def forward(self, x, hidden):\n",
    "        # x (batch_size, seq_length, input_size)\n",
    "        # hidden (n_layers, batch_size, hidden_dim)\n",
    "        # r_out (batch_size, time_step, hidden_size)\n",
    "        batch_size = x.size(0)\n",
    "        \n",
    "        # get RNN outputs\n",
    "        r_out, hidden = self.rnn(x, hidden)\n",
    "        # shape output to be (batch_size*seq_length, hidden_dim)\n",
    "        r_out = r_out.view(-1, self.hidden_dim)  \n",
    "        \n",
    "        # get final output \n",
    "        output = self.fc(r_out)\n",
    "        \n",
    "        return output, hidden\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Check the input and output dimensions\n",
    "\n",
    "As a check that your model is working as expected, test out how it responds to input data."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Input size:  torch.Size([1, 20, 1])\n",
      "Output size:  torch.Size([20, 1])\n",
      "Hidden state size:  torch.Size([2, 1, 10])\n"
     ]
    }
   ],
   "source": [
    "# test that dimensions are as expected\n",
    "test_rnn = RNN(input_size=1, output_size=1, hidden_dim=10, n_layers=2)\n",
    "\n",
    "# generate evenly spaced, test data pts\n",
    "time_steps = np.linspace(0, np.pi, seq_length)\n",
    "data = np.sin(time_steps)\n",
    "data.resize((seq_length, 1))\n",
    "\n",
    "test_input = torch.Tensor(data).unsqueeze(0) # give it a batch_size of 1 as first dimension\n",
    "print('Input size: ', test_input.size())\n",
    "\n",
    "# test out rnn sizes\n",
    "test_out, test_h = test_rnn(test_input, None)\n",
    "print('Output size: ', test_out.size())\n",
    "print('Hidden state size: ', test_h.size())"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "---\n",
    "## Training the RNN\n",
    "\n",
    "Next, we'll instantiate an RNN with some specified hyperparameters. Then train it over a series of steps, and see how it performs."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "RNN(\n",
      "  (rnn): RNN(1, 32, batch_first=True)\n",
      "  (fc): Linear(in_features=32, out_features=1, bias=True)\n",
      ")\n"
     ]
    }
   ],
   "source": [
    "# decide on hyperparameters\n",
    "input_size=1 \n",
    "output_size=1\n",
    "hidden_dim=32\n",
    "n_layers=1\n",
    "\n",
    "# instantiate an RNN\n",
    "rnn = RNN(input_size, output_size, hidden_dim, n_layers)\n",
    "print(rnn)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Loss and Optimization\n",
    "\n",
    "This is a regression problem: can we train an RNN to accurately predict the next data point, given a current data point?\n",
    "\n",
    ">* The data points are coordinate values, so to compare a predicted and ground_truth point, we'll use a regression loss: the mean squared error.\n",
    "* It's typical to use an Adam optimizer for recurrent models."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [],
   "source": [
    "# MSE loss and Adam optimizer with a learning rate of 0.01\n",
    "criterion = nn.MSELoss()\n",
    "optimizer = torch.optim.Adam(rnn.parameters(), lr=0.01) "
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Defining the training function\n",
    "\n",
    "This function takes in an rnn, a number of steps to train for, and returns a trained rnn. This function is also responsible for displaying the loss and the predictions, every so often.\n",
    "\n",
    "#### Hidden State\n",
    "\n",
    "Pay close attention to the hidden state, here:\n",
    "* Before looping over a batch of training data, the hidden state is initialized\n",
    "* After a new hidden state is generated by the rnn, we get the latest hidden state, and use that as input to the rnn for the following steps"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [],
   "source": [
    "# train the RNN\n",
    "def train(rnn, n_steps, print_every):\n",
    "    \n",
    "    # initialize the hidden state\n",
    "    hidden = None      \n",
    "    \n",
    "    for batch_i, step in enumerate(range(n_steps)):\n",
    "        # defining the training data \n",
    "        time_steps = np.linspace(step * np.pi, (step+1)*np.pi, seq_length + 1)\n",
    "        data = np.sin(time_steps)\n",
    "        data.resize((seq_length + 1, 1)) # input_size=1\n",
    "\n",
    "        x = data[:-1]\n",
    "        y = data[1:]\n",
    "        \n",
    "        # convert data into Tensors\n",
    "        x_tensor = torch.Tensor(x).unsqueeze(0) # unsqueeze gives a 1, batch_size dimension\n",
    "        y_tensor = torch.Tensor(y)\n",
    "\n",
    "        # outputs from the rnn\n",
    "        prediction, hidden = rnn(x_tensor, hidden)\n",
    "\n",
    "        ## Representing Memory ##\n",
    "        # make a new variable for hidden and detach the hidden state from its history\n",
    "        # this way, we don't backpropagate through the entire history\n",
    "        hidden = hidden.data\n",
    "\n",
    "        # calculate the loss\n",
    "        loss = criterion(prediction, y_tensor)\n",
    "        # zero gradients\n",
    "        optimizer.zero_grad()\n",
    "        # perform backprop and update weights\n",
    "        loss.backward()\n",
    "        optimizer.step()\n",
    "\n",
    "        # display loss and predictions\n",
    "        if batch_i%print_every == 0:        \n",
    "            print('Loss: ', loss.item())\n",
    "            plt.plot(time_steps[1:], x, 'r.') # input\n",
    "            plt.plot(time_steps[1:], prediction.data.numpy().flatten(), 'b.') # predictions\n",
    "            plt.show()\n",
    "    \n",
    "    return rnn\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Loss:  0.3019888401031494\n"
     ]
    },
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAXcAAAD8CAYAAACMwORRAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDIuMi4zLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvIxREBQAAEkhJREFUeJzt3X+M5Hddx/HXq9suNLbShFtj096xVY/EE1HaybWTJjJmC177x90fVnMlCCXIGU1FIzFpFQu2ho0QJUEbyikNpcH+sBqyNkeqLp2Q2Ll6c/wo3NWa5RRuaZMutRYNwnLH2z++s2WYzt58d/e73+98P/N8JJuZ78xnZt7f/X6/r/3OZ2c+H0eEAABpOa/qAgAAxSPcASBBhDsAJIhwB4AEEe4AkCDCHQASRLgDQIIIdwBIEOEOAAk6v6oX3rFjR8zOzlb18gBQS8ePH/9mRMyMaldZuM/Ozqrb7Vb18gBQS7a/lqcd3TIAkCDCHQASRLgDQIIIdwBIEOEOAAkaGe6277H9nO2vrHO/bX/E9pLtJ21fWXyZAICNyHPm/glJ+85x//WSdvd+Dkn66NbLAsZUpyPNz2eXwBgb+Tn3iPic7dlzNDkg6ZORzdd31PYlti+NiGcLqhEYD52ONDcnra5K09PS4qLUbFZdFTBUEX3ul0k63be83LvtZWwfst213V1ZWSngpYEStdtZsJ89m12221VXBKyriHD3kNuGzrodEYcjohERjZmZkd+eBYq3lW6VVis7Y5+ayi5brfJrAHIqYviBZUk7+5Yvl/RMAc8LFGur3SrNZvaYdjsL9s10ydC1g5IUcea+IOltvU/NXCPpRfrbMZaK6FZpNqXbbtt8INO1g5KMPHO3fb+klqQdtpclvU/SBZIUEXdLOiLpBklLkr4t6R3bVSywJWvdKmtnzZvtVql7DZgIzj7kUr5GoxGMConSdTpb61ZJpQbUlu3jEdEY2Y5wB4D6yBvuDD8AAAki3FEvfIyQ3wFyqWwmJmDD+BghvwPkxpk76oOPEfI7QG6EO+qjqG+I1hm/A+REtwzqo4hviNYdvwPkxEchAaBG+CgkAEwwwh0AEkS4A0CCCHcASBDhDgAJItwBIEGEO8rFuCjVYxtMBL7EhPIwLkr12AYTgzN3lIdxUarHNpgYhDvKw7go1WMbTAy6ZVAexkWpHttgYjC2DADUCGPLAMAEI9wBIEGEOwAkiHAHgAQR7gCQIMIdABJEuANAggh3AEgQ4Q4ACSLcASBBucLd9j7bT9tesn3rkPt32X7M9hdsP2n7huJLxVhgLHCwD9TCyIHDbE9JukvSmyQtSzpmeyEiTvY1e6+khyLio7b3SDoiaXYb6kWVGAsc7AO1kefMfa+kpYg4FRGrkh6QdGCgTUj60d71V0l6prgSMTYYCxzsA7WRZ8jfyySd7ltelnT1QJv3S/pH278t6UckXVdIdRgva2OBr521MRb45GEfqI084e4htw2OE3yTpE9ExJ/Zbkq6z/brIuL7P/RE9iFJhyRp165dm6kXVWIscLAP1MbI8dx7Yf3+iPil3vJtkhQR831tTkjaFxGne8unJF0TEc+t97yM5w4AG1fkeO7HJO22fYXtaUkHJS0MtPm6pLneC/+0pFdKWtlYyQCAoowM94g4I+kWSY9KekrZp2JO2L7D9v5es/dIepftL0m6X9LNUdUUTwCAfHOoRsQRZR9v7L/t9r7rJyVdW2xpAIDN4huqAJAgwh0AEkS4A0CCCHcASBDhDgAJItwBIEGEOwAkiHAHgAQR7gCQIMIdABJEuANAggj3ScP8l6ga+2Apcg0chkQw/yWqxj5YGs7cJwnzX6Jq7IOlIdwnydr8l1NTzH+JarAPloZumUnC/JeoGvtgaUbOobpdmEMVADauyDlUAQA1Q7gDQIIIdwBIEOEOAAki3AEgQYQ7ACSIcAeABBHuAJAgwh0AEkS4
A0CCCHcASBDhDgAJItwBIEGEOwAkKFe4295n+2nbS7ZvXafNr9o+afuE7b8ptkwAwEaMnKzD9pSkuyS9SdKypGO2FyLiZF+b3ZJuk3RtRLxg+8e2q2AAwGh5ztz3SlqKiFMRsSrpAUkHBtq8S9JdEfGCJEXEc8WWCQDYiDzhfpmk033Ly73b+r1W0mtt/4vto7b3DXsi24dsd213V1ZWNlcxAGCkPOHuIbcNzs13vqTdklqSbpL017YvedmDIg5HRCMiGjMzMxutFZLU6Ujz89klMIk4BnLJM0H2sqSdfcuXS3pmSJujEfE9Sf9h+2llYX+skCqR6XSkuTlpdTWbOX5xkQmGMVk4BnLLc+Z+TNJu21fYnpZ0UNLCQJtPS/pFSbK9Q1k3zakiC4WyGeNXV6WzZ7PLdrvqioBycQzkNjLcI+KMpFskPSrpKUkPRcQJ23fY3t9r9qik522flPSYpN+PiOe3q+iJ1WplZytTU9llq1V1RUC5OAZyc8Rg93k5Go1GdLvdSl671jqd7Gyl1eLtKCbThB8Dto9HRGNkO8IdAOojb7gz/AAAJIhwB4AEEe4AkCDCHQASRLgDQIIIdwBIEOEOAAki3AEgQYQ7ACSIcAeABBHuAJAgwh0AEkS4A0CCCHcASBDhDgAJItwBIEGEOwAkiHAHgAQR7gCQIMIdABJEuJet05Hm57NLAOWbkGPw/KoLmCidjjQ3J62uStPT0uKi1GxWXRUwOSboGOTMvUztdrZTnT2bXbbbVVcETJYJOgYJ9zK1WtnZwtRUdtlqVV0RMFkm6BikW6ZMzWb2NrDdznaqRN8OAmNrgo5BR0QlL9xoNKLb7Vby2gBQV7aPR0RjVDu6ZQAgQYQ7ACSIcAeABOUKd9v7bD9te8n2redod6PtsD2yPwgAsH1GhrvtKUl3Sbpe0h5JN9neM6TdxZLeLemJoosEAGxMnjP3vZKWIuJURKxKekDSgSHt7pT0QUnfKbA+AMAm5An3yySd7lte7t32EttvkLQzIh4psDYAwCblCXcPue2lD8fbPk/ShyW9Z+QT2Ydsd213V1ZW8lcJANiQPOG+LGln3/Llkp7pW75Y0usktW3/p6RrJC0M+6dqRByOiEZENGZmZjZfNQDgnPKE+zFJu21fYXta0kFJC2t3RsSLEbEjImYjYlbSUUn7I4KvnwJARUaGe0SckXSLpEclPSXpoYg4YfsO2/u3u0AAwMblGjgsIo5IOjJw2+3rtG1tvSwAwFbwDVUASBDhDgAJItwBIEGEOwAkiHAHgAQR7gCQIMIdABJEuANAggh3AEgQ4Q4ACSLcASBBhPtGdTrS/Hx2CWDy1CQDcg0chp5OR5qbk1ZXpelpaXFRajarrgpAWWqUAZy5b0S7nW3Us2ezy3a76ooAlKlGGUC4b0Srlf21nprKLlutqisCUKYaZQDdMhvRbGZvw9rtbKOO6dsxANukRhngiBjdahs0Go3odpmJDwA2wvbxiHjZHNWD6JYBgAQR7gCQIMIdABJEuANAggh3AEgQ4Q4ACSLcASBBhDsAJIhwB4AEEe4AkCDCHQASRLgDQIII9w2qySQsACYcQ/5uQI0mYQEw4XKdudveZ/tp20u2bx1y/+/ZPmn7SduLtl9TfKnVq9EkLAAm3Mhwtz0l6S5J10vaI+km23sGmn1BUiMiXi/pYUkfLLrQcVDEJCx06wAoQ55umb2SliLilCTZfkDSAUkn1xpExGN97Y9KemuRRY6LrU7CQrcOgLLkCffLJJ3uW16WdPU52r9T0meG3WH7kKRDkrRr166cJY6XZnPzgTysW4dwByZLp1POLH15wt1Dbhs6N5/tt0pqSHrjsPsj4rCkw1I2zV7OGpOx1q2zduZexdy6W92xytoxge1S5TFQ5rv3POG+LGln3/Llkp4ZbGT7Okl/KOmNEfHdYspLSxFz61a5YxWxY/LHBXUO160+vsx373nC/Zik3bavkPQNSQclvaW/ge03SPqYpH0R
8VzhVSZkK906Ve9YW3181QfW2nNU/cel6hoI1+qOgTLfvY8M94g4Y/sWSY9KmpJ0T0ScsH2HpG5ELEj6kKSLJP2tbUn6ekTs376yJ1PVO9ZWH1/1gVV1sIxDDVU/vup9oOpjoIh373nl+hJTRByRdGTgttv7rl9XcF3bps5v66vesbb6+KoPrKqDZRxqqPrxVe8DVR8Da89RSvZERCU/V111VZTt8ccjLrwwYmoqu3z88dJL2LLHH4/4wAfqWXvE1uvfyuO3uv2L2H+qrqHqx689R1X7QAqU9ZiMzFhnbcvXaDSi2+2W+prz89IfvTd09vvW1HmhO//Euu22UktAxarurx6HGqp+/MTb4i/Q9vGIaIxsN0nh3jn8Zc39xk9qVRdoWt/T4se+quahny21BgATrIB/3OQN94kaFbL5/CNaPO/NulO3a/G8N6v5/CNVlwRgkpQ4QNVkjQrZaqn5ijvVXD3a+2/Mh6quCMAkKfGzkLUL9y11V5X5OSQAGFRiBtWqz52BtwBMuiT73BlPHQDyqVW4FzGeOgBMglr1udNlDgD51CrcpRK/ugsANVarbhkAQD6EOwAkiHAHgAQR7gCQIMIdABJEuANAggh3AEgQ4Q4ACSLcASBBhDsAJIhwB4AEEe4AkCDCHQASRLgDQIIIdwBIEOEOAAki3AEgQYQ7ACSIcAeABOUKd9v7bD9te8n2rUPuf4XtB3v3P2F7tuhCAQD5jQx321OS7pJ0vaQ9km6yvWeg2TslvRARPyXpw5L+tOhCAQD55Tlz3ytpKSJORcSqpAckHRhoc0DSvb3rD0uas+3iyuzT6Ujz89klAGCo83O0uUzS6b7lZUlXr9cmIs7YflHSqyV9s4giX9LpSHNz0uqqND0tLS5KzWahLwEAKchz5j7sDDw20Ua2D9nu2u6urKzkqe+HtdtZsJ89m1222xt/DgCYAHnCfVnSzr7lyyU9s14b2+dLepWk/xp8oog4HBGNiGjMzMxsvNpWKztjn5rKLlutjT8HAEyAPN0yxyTttn2FpG9IOijpLQNtFiS9XVJH0o2SPhsRLztz37JmM+uKabezYKdLBgCGGhnuvT70WyQ9KmlK0j0RccL2HZK6EbEg6eOS7rO9pOyM/eC2VdxsEuoAMEKeM3dFxBFJRwZuu73v+nck/UqxpQEANotvqAJAggh3AEgQ4Q4ACSLcASBBhDsAJMjb8XH0XC9sr0j6WiUvXqwdKnqYheqwLuOJdRlPVa3LayJi5LdAKwv3VNjuRkSj6jqKwLqMJ9ZlPI37utAtAwAJItwBIEGE+9YdrrqAArEu44l1GU9jvS70uQNAgjhzB4AEEe455Zgk/GbbK7a/2Pv59SrqHMX2Pbafs/2Vde637Y/01vNJ21eWXWNeOdalZfvFvm1y+7B248D2TtuP2X7K9gnbvzOkTS22Tc51qcW2sf1K2/9q+0u9dfnjIW1eYfvB3nZ5wvZs+ZUOERH8jPhRNtTxVyX9hKRpSV+StGegzc2S/rLqWnOsyy9IulLSV9a5/wZJn1E2u9Y1kp6ouuYtrEtL0iNV15lzXS6VdGXv+sWS/n3IPlaLbZNzXWqxbXq/64t61y+Q9ISkawba/Jaku3vXD0p6sOq6I4Iz95zyTBJeCxHxOQ2ZJavPAUmfjMxRSZfYvrSc6jYmx7rURkQ8GxGf713/H0lPKZubuF8ttk3OdamF3u/6f3uLF/R+Bv9ReUDSvb3rD0uasz1s6tFSEe75DJskfNjO+su9t8sP29455P46yLuuddHsvaX+jO2fqbqYPHpv69+g7CyxX+22zTnWRarJtrE9ZfuLkp6T9E8Rse52iYgzkl6U9Opyq3w5wj2fPBOA/4Ok2Yh4vaR/1g/+ktdNrsnOa+Lzyr6q/XOS/kLSpyuuZyTbF0n6O0m/GxHfGrx7yEPGdtuMWJfabJuIOBsRP69s/ui9tl830GQstwvhns/IScIj4vmI+G5v8a8kXVVSbUXLMyF6LUTE
t9beUkc2m9gFtndUXNa6bF+gLAw/FRF/P6RJbbbNqHWp27aRpIj4b0ltSfsG7nppu9g+X9KrNAbdhYR7Pi9NEm57Wtk/TRb6Gwz0fe5X1s9YRwuS3tb7ZMY1kl6MiGerLmozbP/4Wt+n7b3K9vfnq61quF6dH5f0VET8+TrNarFt8qxLXbaN7Rnbl/SuXyjpOkn/NtBsQdLbe9dvlPTZ6P13tUq55lCddJFvkvB3294v6Yyyv9o3V1bwOdi+X9knFXbYXpb0PmX/JFJE3K1srtwbJC1J+rakd1RT6Wg51uVGSb9p+4yk/5N0cBwOunVcK+nXJH25178rSX8gaZdUu22TZ13qsm0ulXSv7Sllf4AeiohHBo79j0u6z/aSsmP/YHXl/gDfUAWABNEtAwAJItwBIEGEOwAkiHAHgAQR7gCQIMIdABJEuANAggh3AEjQ/wNYq0vqkvMD/QAAAABJRU5ErkJggg==\n",
      "text/plain": [
       "<Figure size 432x288 with 1 Axes>"
      ]
     },
     "metadata": {
      "needs_background": "light"
     },
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Loss:  0.05813450366258621\n"
     ]
    },
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAX8AAAD8CAYAAACfF6SlAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDIuMi4zLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvIxREBQAAEwhJREFUeJzt3X+wXGddx/H3l5SUGaGY0F+RcE2VohTHAb0WrkxLpm1sqQ5pFbCgEMZihgH+0RGNU4c/2nESYBRGZZQY0ICDgGBthAqkF4MyXrA3AkLLlKSl0LSZpp3WHx2cdpp+/WNPhntvzuZu7tm7e84+79fMnd29++w+57ln95Mnz3nOcyIzkSSV5Wnj3gBJ0ugZ/pJUIMNfkgpk+EtSgQx/SSqQ4S9JBTL8JalAhr8kFcjwl6QCnTHuDejn7LPPzk2bNo17MySpUw4ePPhwZp6zXLnWhv+mTZuYn58f92ZIUqdExHcHKeewjyQVyPCXpAIZ/pJUIMNfkgpk+EtSgQx/SSrQRIb/3Bzs3Nm7lSSdrLXz/Fdqbg4uvxyeeALWroXZWZiZGfdWSVK7DKXnHxFXRcRdEXE4InbUPH9mRHy8ev4rEbFpGPXWOXCgF/zHj/duDxxYrZokqbsah39ErAHeD7wSuAh4XURctKTY9cCjmfl84L3Au5rW28/mzbD2jOOsieOsPeM4mzevVk2S1F3D6PlfDBzOzHsy8wngY8DWJWW2Anur+58ELo+IGELdJ5lhjtm8nJt4J7N5OTM48C9JSw1jzP+5wH0LHh8BXtqvTGY+GRH/DTwHeHhhoYjYDmwHmJqaWtnWHDjAzPEvMZNfhONreuM+DvpL6oi5uV5sbd68utE1jPCv68HnCsqQmbuB3QDT09MnPT+QzZt7R3pPHPF13EdSR4xywsowhn2OAM9b8Hgj8EC/MhFxBvBs4JEh1H2ymZneX+ymm5zqI6lTRjlhZRg9/9uBCyPiAuB+4Drg9UvK7AO2AXPAq4EvZObKevaDmJkx9CV1zigHLhqHfzWG/3bgc8Aa4EOZeUdE3AjMZ+Y+4IPARyLiML0e/3VN65WkSXNi4GIUY/6xmh3wJqanp9OLuUjS6YmIg5k5vVy5iVzeQZJ0aoa/JBXI8JekAhn+klQgw1+SCmT4S1KBDH9JKpDhL0kFMvwlqUCGvyQVyPCXpAIZ/pJUIMNfkgpk+EtSgQx/SSqQ4V9jbg527uzdStIkGsZlHCfKKC+gLEnjYs9/iVFeQFmSxsXwX+LEBZTXrFn9CyhL0rg47LPEKC+gLGnyzM11Iz8M/xozM+3eaZLaqUvHDB32kaQh6dIxQ8NfkoakS8cMHfap05VBO0mtMpRjhiPKH8N/qS4N2klqnUbHDEeYPw77LNWlQTtJk2WE+WP4L9WlQTtJk2WE+eOwz1JO9Jc0LiPMn8jMVXvzJqanp3N+fn7cmyFJnRIRBzNzerlyDvtIUoEahX9ErI+I/RFxqLpdV1PmxRExFxF3RMR/RsSvNqlTktRc057/DmA2My8EZqvHS30feGNmvgi4CnhfRPxww3olSQ00Df+twN7q/l7gmqUFMvPbmXmouv8AcAw4p2G9kqQGmob/eZl5FKC6PfdUhSPiYmAtcHef57dHxHxEzD/00EMNN02S1M+yUz0j4jbg/JqnbjidiiJiA/ARYFtmPlVXJjN3A7uhN9vndN5fkjS4ZcM/M6/o91xEPBgRGzLzaBXux/qUOwv4DPAHmfnlFW+tJGkomg777AO2Vfe3AbcsLRARa4GbgQ9n5t81rE+SNARNw38XsCUiDgFbqsdExHRE7KnKvBa4FHhTRHyt+nlxw3olSQ14hu8qcEVoSeMy6Bm+ru0zZK4ILakLXN5hyFwRWlIXGP5D5orQkrrAYZ8hc0VoqdtKOWZn+K+CRpdxkzQ2JR2zc9hHkiolHbMz/CWpUtIxO4d9JKlS0jE7w1+S
FijlmJ3DPpJUIMNfkgpk+EtSgQx/SSqQ4b8a5uZg587eraSydOT772yfYSvpFEFJi3Xo+2/Pf9hKOkVQ0mId+v4b/sNW0imCkhbr0PffYZ9hK+kUQUmLdej772UcJWmCDHoZR4d9WqgjkwUkdZjDPi3TockCkjrMnn/LdGiygKQOM/xbpkOTBSR1mMM+LdOhyQJSK5VyDd6mDP8WKmU9cWnYPGY2OId9JE0Mj5kNzvCXNDE8ZjY4h30kTQyPmQ3O8Jc0UTxmNphGwz4RsT4i9kfEoep23SnKnhUR90fEnzWpU5LUXNMx/x3AbGZeCMxWj/u5Cfhiw/okSUPQNPy3Anur+3uBa+oKRcTPAucBn29YnyRpCJqG/3mZeRSguj13aYGIeBrwR8A7GtYlSRqSZQ/4RsRtwPk1T90wYB1vBW7NzPsiYrm6tgPbAaampgZ8e0nS6Vo2/DPzin7PRcSDEbEhM49GxAbgWE2xGeCSiHgr8ExgbUQ8lpknHR/IzN3Abuit5z9oIyRJp6fpVM99wDZgV3V7y9ICmflrJ+5HxJuA6brglySNTtMx/13Alog4BGypHhMR0xGxp+nGaWW8GIyk5XgZxwnjwlZS2byMY6Fc2ErSIAz/CePCVpIG4do+E8aFrSQNwvCfQC5spS7zSlyjYfi3kZ9+FaoVExYK+f4Z/m3Tik+/NB51ExZG+vEv6PvnAd+2cbqOCjb2CQsFff/s+bfNiU//iZ6H03VUkLFPWCjo++dJXm1UyJij1Eod//4NepKX4S9JE8QzfCVJfRn+klQgw1+SCmT46yQuCS1NPqd6apGCznGRimbPX4sUdI6LVDTDX4uM/QxLSSPhsI8WGfsZlpJGwvDXSVwSWk10/ATZYhj+kobGCQPd4Zi/pKFxwkB3GP6ShsYJA93hsI+koXHCQHcY/pKGygkD3eCwjyQVyPDX0Lk2kNR+DvtoqJzqJ3WDPX8NlVP9pG4w/DVUTvWTusFhHw2VU/2kbmgU/hGxHvg4sAm4F3htZj5aU24K2AM8D0jg6sy8t0ndOoUxL67iVL9u6/zaPJ1vwGg07fnvAGYzc1dE7Kge/15NuQ8Df5iZ+yPimcBTDetVPx5xVQOd//h0vgGj03TMfyuwt7q/F7hmaYGIuAg4IzP3A2TmY5n5/Yb1qh+PuKqBzn98Ot+A0Wka/udl5lGA6vbcmjIvAP4rIv4+Ir4aEe+JiDUN61U/HnFVA53/+HS+AaOz7LBPRNwGnF/z1A2nUcclwEuA79E7RvAm4IM1dW0HtgNMTU0N+PZaZAKOuDpkOz6d//h0vgGjE5m58hdH3AVszsyjEbEBOJCZP7GkzMuAXZm5uXr8BuBlmfm2U7339PR0zs/Pr3jb1E0O2UrNRMTBzJxerlzTYZ99wLbq/jbglpoytwPrIuKc6vFlwJ0N69WEcshWGo2m4b8L2BIRh4At1WMiYjoi9gBk5nHgd4DZiPgGEMBfNqxXE8ohW2k0Gg37rCaHfcrlmL+0coMO+3iGr1rHk8Sa8R9PDcLwlyaIB8w1KBd208Qp+XoCHjDXoOz5a6KU3vM9ccD8RPs9YK5+DH9NlLqeb0nh7zlOGpThr4kyCT3fpgdsPWCuQRj+mijD6PmOc7ZM6cNWGh3DXxOnSc93GOHb5B+P0oetNDqGv7RA0/Bt+o/HJAxbqRsMf2mBpuHb9B8PD9hqVAx/aYGm4TuMnrsHbDUKhr+0RJPwteeurjD8pSGz564ucHkHSSqQ4a+Tlbw4jsbPz99IOOyjxTzLSOPk529k7PlrMZeF1Dj5+RsZw1+LeR1FjZOfv5Fx2EeLOVdR4+Tnb2S8hq8kTZBBr+HrsI8kFcjwl6QCGf6SVCDDX5IKZPhLUoEMf0kqkOEvSQUy/CWpQIa/JBXI8JekAjUK/4hYHxH7I+JQdbuuT7l3R8QdEfGtiPiTiIgm9UqSmmna898BzGbmhcBs9XiRiPh54OXATwM/Bfwc8IqG9UqSGmga/luB
vdX9vcA1NWUSeAawFjgTeDrwYMN6JUkNNA3/8zLzKEB1e+7SApk5B/wzcLT6+VxmfqthvZKkBpZdzz8ibgPOr3nqhkEqiIjnAy8ENla/2h8Rl2bmv9SU3Q5sB5iamhrk7dVGc3Oux14y938nLBv+mXlFv+ci4sGI2JCZRyNiA3Cspti1wJcz87HqNf8EvAw4KfwzczewG3rr+Q/WBLWK12Atm/u/M5oO++wDtlX3twG31JT5HvCKiDgjIp5O72Cvwz6Tymuwls393xlNw38XsCUiDgFbqsdExHRE7KnKfBK4G/gG8HXg65n5jw3rVVt5Ddayuf87w8s4avgc8y2b+3+sBr2Mo+EvSRPEa/hKkvoy/CWpQIa/JBXI8JekAhn+klQgw1+SCmT4S1KBDH9JKpDhL0kFMvwlqUCGv9pnbg527uzdavT8+xdh2fX8pZFyPfjx8u9fDHv+ahfXgx8v//7FMPzVLq4HP17+/YvhsI/aZWamN9TgevDj4d+/GK7nL0kTxPX8JUl9Gf6SVCDDX5IKZPhLUoEMf0kqkOGvyVP68gSlt18DcZ6/JkvpyxOU3n4NzJ6/JkvpyxOU3n4NzPDXZCl9eYLS26+BOeyjyVL68gSlt18Dc3kHSZogLu8gSerL8JekAjUK/4h4TUTcERFPRUTf/2ZExFURcVdEHI6IHU3qlFbduOfJj7t+FaHpAd9vAr8MfKBfgYhYA7wf2AIcAW6PiH2ZeWfDuqXhG/c8+XHXr2I06vln5rcy865lil0MHM7MezLzCeBjwNYm9UqrZhjz5Jv03J2nrxEZxVTP5wL3LXh8BHjpCOqVTt+JefInet6nO0++ac+9af3SgJYN/4i4DTi/5qkbMvOWAeqImt/Vzi+NiO3AdoCpqakB3loasqbz5Ot67qfzHs7T14gsG/6ZeUXDOo4Az1vweCPwQJ+6dgO7oTfPv2G90srMzKw8dIfRc29SvzSgUQz73A5cGBEXAPcD1wGvH0G90ujZc1dHNAr/iLgW+FPgHOAzEfG1zLwyIn4E2JOZV2fmkxHxduBzwBrgQ5l5R+Mtl9rKnrs6oFH4Z+bNwM01v38AuHrB41uBW5vUJUkaHs/wlaQCGf6SVCDDX5IKZPhLUoEMf0kqUGsv5hIRDwHfHfd21DgbeHjcGzEktqWdbEs7daUtP5qZ5yxXqLXh31YRMT/IVXK6wLa0k21pp0lqCzjsI0lFMvwlqUCG/+nbPe4NGCLb0k62pZ0mqS2O+UtSiez5S1KBDP8FImJNRHw1Ij5dPf7XiPha9fNARPxDn9cdX1Bu32i3ul5NWy6PiP+otvFLEfH8Pq/7/Yg4HBF3RcSVo93qeitpS0Rsioj/W7Bf/mL0W36ymrZcVrXlmxGxNyJqF1uMiG0Rcaj62Tbara7XoC1t/L7cGxHfqLZpvvrd+ojYX/3N90fEuj6vbd2+GUhm+lP9AL8NfBT4dM1znwLe2Od1j41725drC/Bt4IXV/bcCf13zmouArwNnAhcAdwNrOtqWTcA3x73tp2oLvc7XfcALquduBK6vec164J7qdl11f10X21I918bvy73A2Ut+925gR3V/B/CuruybQX7s+VciYiPwi8CemueeBVwG1Pb826ZPWxI4q7r/bOqvprYV+FhmPp6Z3wEOAxev5rYup0FbWqemLc8BHs/Mb1eP9wO/UvPSK4H9mflIZj5albtqtbf3VBq0pUu2Anur+3uBa2rKtG7fDMrw/4H3Ab8LPFXz3LXAbGb+T5/XPiMi5iPiyxFR9wEZtbq2vBm4NSKOAG8AdtW87rn0em8nHKl+N04rbQvABdWwxBcj4pJV3s5BLG3Lw8DTI+LEiUOvZvElT0/own4ZtC3Qvu8L9DoUn4+Ig9W1xAHOy8yjANXtuTWva+O+GYjhD0TELwHHMvNgnyKvA/72FG8xlb0z/14PvC8ifnzY2zioU7Tlt4CrM3Mj8FfAH9e9vOZ3Y5sO1rAtR+ntl5dQDU9ExFk1
5Uairi3ZGze4DnhvRPw78L/Ak3Uvr/ldq/bLabQFWvR9WeDlmfkzwCuBt0XEpQO+rlX75nSM4hq+XfBy4FURcTXwDOCsiPibzPz1iHgOvaGPa/u9OHtXLiMz74mIA8BL6I2Xj0NdWz4D/GRmfqUq83HgszWvPcLi3tpGxjuksuK2ZObjwOPV/YMRcTfwAmB+JFt+sr6fMeASgIj4hWoblzoCbF7weCNwYFW39tSatKVt35el23QsIm6m951/MCI2ZObRiNgAHKt5adv2zeDGfdChbT/0duSnFzx+C7D3FOXXAWdW988GDgEXjbsdC9tC7x/5h/nBwbjrgU/VlH8Riw/43kMLDviusC3nnNh24MeA+4H1427H0s8YcG51eyYwC1xWU3498J3qs7auut/VtrTu+wL8EPCsBff/jd64/XtYfMD33V3aN8v92PNf3nUsGVOuxjXfkplvBl4IfCAinqI3jLYrM+8c/Wb2l5lPRsRvAp+qtvNR4DcAIuJVwHRmvjMz74iITwB30vsv+9sy8/jYNrzGoG0BLgVujIgngeP09tcj49ruU3hHNYzyNODPM/MLsPgzlpmPRMRNwO3Va27salto5/flPODmiIBe5+KjmfnZiLgd+EREXA98D3gNdHbfnMQzfCWpQB7wlaQCGf6SVCDDX5IKZPhLUoEMf0kqkOEvSQUy/CWpQIa/JBXo/wEHJ9sO4FQk+QAAAABJRU5ErkJggg==\n",
      "text/plain": [
       "<Figure size 432x288 with 1 Axes>"
      ]
     },
     "metadata": {
      "needs_background": "light"
     },
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Loss:  0.005655461922287941\n"
     ]
    },
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAXsAAAD8CAYAAACW/ATfAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDIuMi4zLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvIxREBQAAEvNJREFUeJzt3W+MXFd5x/HvkzUGQSmg2Aia2CS0jkSavqBdAtsUtNSATFTZbUUhQVVLG2GBFPoPWqWiSlH8IvwRQryIaF2IKFQlpFVbWchVoIaFiq6p10AS7JBiXEM2rogJCAlFZHHy9MXMhmEz670zOzt3zp7vR1rtzsy9O+fsmfntuWfOPTcyE0nS5nZR2wWQJG08w16SKmDYS1IFDHtJqoBhL0kVMOwlqQKGvSRVwLCXpAoY9pJUgS1tPfG2bdvysssua+vpJalIx48f/25mbh90v9bC/rLLLmNhYaGtp5ekIkXEt4bZz2EcSaqAYS9JFTDsJakChr0kVcCwl6QKGPaSVAHDXlWZn4dbb+18b2N/qS2tzbOXxm1+HnbvhqUl2LoVjhyBmZnx7S+1yZ69qjE31wnqxx7rfJ+bG+/+UpsMe1VjdrbTI5+a6nyfnR3v/lKbHMZRUebnOz3q2dnBh1BmZjpDL23tL7UpMrOVJ56enk7XxtEgNsuY+Xr+YUkRcTwzpwfdz569itFvzLy0sNws/7BUHsfsVYzNMGbuh7xqiz17FWMzjJkv/8Na7tmX+A9LZTLsVZSZmTJDftlm+IelMq0Z9hFxO/AbwEOZeVWfxwP4IHAt8Ajwpsz88qgLKm0Wpf/DUpmajNl/FNhzgcdfC+zqfu0HPrT+YkmSRmnNsM/MLwDfu8Am+4CPZcdR4NkR8fxRFVCbS/FryxRfAdVqFGP2lwAP9Nxe7N73fyP43dpEip92WHwFVLNRTL2MPvf1PVMrIvZHxEJELJw7d24ET62STMS0w/X0zEdVgXUeHXhwoWGMome/COzouX0pcLbfhpl5EDgInTNoR/DcKkjr0w7X2zMfRQXWWQYPLjSsUfTsDwG/Fx0vA36QmQ7h6EmWpx0eONBSSK23Zz6KCqyzDBNxdKQiNZl6+QlgFtgWEYvAXwNPAcjMvwEO05l2eYrO1Ms/2KjCqnytTjscRc98vRVYZxlaPzpSsVwITXWZhFXI1lmGSaiC2jPsQmiGvSQVZNiwdyE0SaqAYa+yOO/Qv4GG4kJoKofzDv0baGj27DWQVjuVzjv0b6Ch2bNXY613Kp136N9AQzPs1VjrlwV0MXj/BhqaYa/GJqJT6WLw/g00FMNejdmp3Bw8KatOhr0GYqeybK1/7qLWOBtHqoiTeepl2EsVWf7cZWrKyTy1cRhHqoifu9TLsNd4+elg62aYZ4Y5OiuX2wa1MOw1Pn462D7boFqO2Wt8/HSwfbZBtQx7jY+fDrbPNqiWwzgaHz8dbJ9tUC2vVCVJBfFKVWrE615IdXIYpyJOxJDqZc++Ik7EkOpl2FfEiRhSvRzGqYgTMaR6GfaVcYliqU4O40hSBQx7SaqAYS9pIJ6rUSbH7CU15rka5WrUs4+IPRFxf0Scioib+jy+MyI+FxFfiYh7IuLa0RdVE8FuXdXm5mDp0eycq/Foeq5GQdbs2UfEFHAb8GpgETgWEYcy82TPZn8F3JmZH4qIK4HDwGUbUF61yW5d9WYvvpetj/88SzyFrY//mNmLvwn8UtvFUgNNevZXA6cy83RmLgF3APtWbJPAz3Z/fhZwdnRF1MTwFNzqzTz8KY5c9BoOcDNHLnoNMw9/qu0iqaEmY/aXAA/03F4EXrpim3cBn46ItwHPAF7V7xdFxH5gP8DOnTsHLavatnwK7nLP3lNw6zM7y8xTDzCzdLT7Gnhf2yVSQ03CPvrct3Jd5OuBj2bm+yNiBvh4RFyVmY//1E6ZB4GD0FnieJgCq0WegitfA8VqEvaL
wI6e25fy5GGaG4A9AJk5HxFPA7YBD42ikJognoIrXwNFajJmfwzYFRGXR8RW4Drg0Iptvg3sBoiIFwFPA86NsqCSpOGtGfaZeR64EbgLuI/OrJsTEXFLROztbvZ24M0RcTfwCeBN2dYlsCRJT9LopKrMPExnOmXvfTf3/HwSuGa0RVM/8/MOl0oanGfQFsRp7pKG5do4BXGau6RhGfYF8UpTkoblME5BnOIsaViGfWGc4ixpGA7jSFIFDHtJqoBhL0kVMOwlqQKGfW280pTa5muwFc7GqYmn4KptvgZbY8++Jp6Cq7b5GmyNPfuaeKUptW12lvmpX2Pu8WuYnfoiM74Gx8awr4mn4Kpl88ywO46wRLA1kiNM4atwPAz72ngKrlo0NwdL56d4LGHpfOe2L8fxcMxe0ti4mF977NlLGhtHEttj2EsaK0cS2+EwjiRVwLCXpAoY9mPmmeKS2uCY/Rh5prikttizHyPPFJfUFsN+jJxjLKktDuOMkXOMJbXFsB8z5xhLaoPDOJJUAcNekirQKOwjYk9E3B8RpyLiplW2eX1EnIyIExHxj6MtpiRpPdYcs4+IKeA24NXAInAsIg5l5smebXYBfwlck5nfj4jnblSBJUmDa9Kzvxo4lZmnM3MJuAPYt2KbNwO3Zeb3ATLzodEWU5K0Hk3C/hLggZ7bi937el0BXBERX4yIoxGxp98vioj9EbEQEQvnzp0brsSSpIE1Cfvoc1+uuL0F2AXMAtcDH46IZz9pp8yDmTmdmdPbt28ftKwCF9eRfA8Mpck8+0VgR8/tS4GzfbY5mpk/Bv43Iu6nE/7HRlJKdbi4jmrne2BoTXr2x4BdEXF5RGwFrgMOrdjm34BXAkTENjrDOqdHWVDh4jqS74GhrRn2mXkeuBG4C7gPuDMzT0TELRGxt7vZXcDDEXES+Bzw55n58EYVulourqPa+R4YWmSuHH4fj+np6VxYWGjluYs2P+/iOqpb5e+BiDiemdMD72fYS1I5hg17l0uQpAoY9pJUAcNeUlGcZj8c17OXVAyn2Q/Pnr2kYjjNfniG/YA8hJTa4zT74TmMMwAPIaV2eR3n4Rn2A+h3COmLTRovr+M8HIdxBuAhpKRS2bMfgIeQkkpl2A/IQ0hJJXIYR5IqYNhLUgUMe0mqgGEvSRUw7CWpAoa9JFXAsJekChj2klQBw37cXDZTalel70HPoB0nl82U2lXxe9Ce/Th55QWpXRW/Bw37cXLZTKldFb8HHcYZJ5fNlNpV8XswMrOVJ56ens6FhYVWnluSShURxzNzetD9HMaRpAoY9pJUgUZhHxF7IuL+iDgVETddYLvXRURGxMCHGJKkjbNm2EfEFHAb8FrgSuD6iLiyz3bPBP4I+NKoCylJWp8mPfurgVOZeTozl4A7gH19tjsAvBf40QjLJ0kagSZhfwnwQM/txe59T4iIFwM7MvNTIyybJGlEmoR99LnvifmaEXER8AHg7Wv+ooj9EbEQEQvnzp1rXsoRqnRZDEmVa3JS1SKwo+f2pcDZntvPBK4C5iIC4HnAoYjYm5k/NZE+Mw8CB6Ezz34d5R5KxctiSOqan6/ynKpGYX8M2BURlwMPAtcBb1x+MDN/AGxbvh0Rc8A7Vgb9JOi3LEZNjS3VruYO35rDOJl5HrgRuAu4D7gzM09ExC0RsXejCzhKFS+LIYmq10FrtjZOZh4GDq+47+ZVtp1df7E2RsXLYkjiJx2+5Z59TR2+6hZCm5kx5KVa1dzhqy7sJdWt1g6fa+NIUgUMe0mqgGEvSRUw7CWpAoa9JFXAsJekChj2klQBw35QLpsp1a3QDPCkqkHUvIqSpKIzwJ79IGpeRUlS0Rlg2A/CZTOluhWcAQ7jDKLmVZQkFZ0BkTn2C0YBnStVLSxM3PVNJGmiRcTxzJwedD+HcSSpAoa9JFXAsJekChj2klQBw16SKmDYS1IFDHtJqoBhL0kVMOwlqQLFhX2hq4tKUquKWhun4NVFJalV
RfXsC15dVJJaVVTYF7y6qCS1qlHYR8SeiLg/Ik5FxE19Hv+ziDgZEfdExJGIeMHoi/qT1UUPHHAIR5IGseaYfURMAbcBrwYWgWMRcSgzT/Zs9hVgOjMfiYi3Au8F3rARBZ6ZMeQlaVBNevZXA6cy83RmLgF3APt6N8jMz2XmI92bR4FLR1tMSZoMpc4IbDIb5xLggZ7bi8BLL7D9DcC/r6dQkjSJSp4R2KRnH33u63t5q4j4XWAaeN8qj++PiIWIWDh37lzzUkrSBCh5RmCTsF8EdvTcvhQ4u3KjiHgV8E5gb2Y+2u8XZebBzJzOzOnt27cPU15Jak3JMwKbDOMcA3ZFxOXAg8B1wBt7N4iIFwN/C+zJzIdGXkpJmgAFX2987bDPzPMRcSNwFzAF3J6ZJyLiFmAhMw/RGbb5GeCfIgLg25m5dwPLLUmtKHVGYKPlEjLzMHB4xX039/z8qhGXS5I0QkWdQStJGk59YV/qJFlJm0NLGVTUqpfrVvIkWUnlazGD6urZlzxJVlL5WsygusK+5EmyksrXYgbVNYxT8iRZSeVrMYMis+/KBxtueno6FxYWWnluSSpVRBzPzOlB96trGEeSKmXYS1IFDHtJqoBhL0kVMOwlqQKGvSRVwLCXpAoY9pJUAcNekipg2EtSBQx7SaqAYS9JFTDsJakChr0kVcCwl6QKGPaSVAHDXpIqYNhLUgUMe0mqgGEvSRUw7CVpjObn4dZbO9/HaUuTjSJiD/BBYAr4cGa+e8XjTwU+BvwK8DDwhsw8M9qiSlLZ5udh925YWoKtW+HIEZiZGc9zr9mzj4gp4DbgtcCVwPURceWKzW4Avp+ZvwB8AHjPqAsqSaWbm+sE/WOPdb7PzY3vuZsM41wNnMrM05m5BNwB7FuxzT7g77s//zOwOyJidMXs0dYxkCSt0+xsp0c/NdX5Pjs7vuduMoxzCfBAz+1F4KWrbZOZ5yPiB8DFwHdHUcgntHkMJEnrNDPTia25uU7QjzO+moR9vx56DrENEbEf2A+wc+fOBk+9Qr9jIMNeUkFmZtqJrSbDOIvAjp7blwJnV9smIrYAzwK+t/IXZebBzJzOzOnt27cPXto2j4EkqWBNevbHgF0RcTnwIHAd8MYV2xwCfh+YB14HfDYzn9SzX7c2j4EkqWBrhn13DP5G4C46Uy9vz8wTEXELsJCZh4CPAB+PiFN0evTXbViJ2zoGkqSCNZpnn5mHgcMr7ru55+cfAb8z2qJJkkbFM2glqQKGvSRVwLCXpAoY9pJUAcNekioQGzEdvtETR5wDvrXBT7ONUS/Z0B7rMpmsy2TazHV5QWYOfFZqa2E/DhGxkJnTbZdjFKzLZLIuk8m6PJnDOJJUAcNekiqw2cP+YNsFGCHrMpmsy2SyLits6jF7SVLHZu/ZS5IoNOwj4o8j4msRcSIi/mTFY++IiIyIbavs+1hEfLX7dWg8JV5dv7pExLsi4sGecl67yr57IuL+iDgVETeNt+R9y7OeupyJiHu72yyMt+R9y9P3NRYRb+v+zU9ExHtX2Xfi26V7f5O6THy7RMQne15fZyLiq6vsO/HtMkBdBm+XzCzqC7gK+BrwdDqrdv4HsKv72A46SzF/C9i2yv4/bLsOa9UFeBfwjjX2nQK+CbwQ2ArcDVxZYl26+59Zrc0mqC6v7P781O52zy24XdasSyntsmKb9wM3l9ouTeoybLuU2LN/EXA0Mx/JzPPA54Hf6j72AeAv6HNJxAl1obqspcmF4MdpPXWZNKvV5a3AuzPzUYDMfKjPvqW0S5O6TJoLvsYiIoDXA5/os28p7QKsWZehlBj2XwNeEREXR8TTgWuBHRGxF3gwM+9eY/+nRcRCRByNiN/c8NJeWN+6dB+7MSLuiYjbI+I5ffbtdyH4Sza2uBe0nrpA5x/0pyPiePdaxW1arS5XAC+PiC9FxOcj4iV99i2lXZrUBcpol2UvB76Tmd/os28p7bLsQnWBIdql0cVLJklm3hcR
7wE+A/yQzuHYeeCdwGsa/IqdmXk2Il4IfDYi7s3Mb25ciVd3gbp8CDhAp0EP0Dmc+8MVuze6yPu4rLMuANd02+W5wGci4uuZ+YXxlP6nXaAuW4DnAC8DXgLcGREvzO5xdVcp7dKkLlBGuyy7ntV7wqW0y7IL1QWGaJcSe/Zk5kcy85cz8xV0LoN4BrgcuDsiztC5KPqXI+J5ffY92/1+GpgDXjymYvfVpy7fyMzvZOZjmfk48Hd0DkFXanIh+LFaR1162+Uh4F9X225c+tWFzt/8X7Ljv4HH6axb0quIdqFZXUppFyJiC/DbwCdX2bWUdmlSl+HaZVQfOIzzi+6HScBO4OvAc1Y8foY+H17Q6cksfyC1rfvHbe1DmtXqAjy/5/E/Be7os98W4DSdf3LLHzj9YqF1eQbwzJ6f/wvYM4F1eQtwS/f+K+gMC0Sh7dKkLkW0S/f2HuDzF9iviHZpWJeh2qW1iq7zj/SfwMlug+3u8/gTYQ9MAx/u/vyrwL3d/e4FbpjEugAf75bvHuDQcmACPwcc7tn3WuB/6MwyeGepdaEzQ+Lu7teJCa7LVuAf6Iy3fhn49YLbZc26lNIu3fs/CrxlxbbFtUuTugzbLp5BK0kVKHLMXpI0GMNekipg2EtSBQx7SaqAYS9JFTDsJakChr0kVcCwl6QK/D9o6TzCXL0iKgAAAABJRU5ErkJggg==\n",
      "text/plain": [
       "<Figure size 432x288 with 1 Axes>"
      ]
     },
     "metadata": {
      "needs_background": "light"
     },
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Loss:  0.0006452086381614208\n"
     ]
    },
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAX8AAAD8CAYAAACfF6SlAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDIuMi4zLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvIxREBQAAFBtJREFUeJzt3X+MbGV9x/H3l4WlaQ3lN1yB66VKorZQqys6VmTqhVRt0wstWqqRa4q5aY1N/IMWCNU0JebSmkZqtWlvUQGbFhWLUMFaWJ1q4kBZWvSKihct6oUbLiK0klpWLt/+MWd1WWZ3Z/bMzs7M834lmzM/njnPeebMfObsc55zTmQmkqSyHLTRCyBJGj7DX5IKZPhLUoEMf0kqkOEvSQUy/CWpQIa/JBXI8JekAhn+klSggzd6AZZz9NFH55YtWzZ6MSRprNx1113fy8xjVis3suG/ZcsW5ubmNnoxJGmsRMS3eylnt48kFcjwl6QCGf6SVCDDX5IKZPhLUoEGEv4R8ZqIuDci7ouIS7o8f2hEfLR6/o6I2DKIeiVJa1M7/CNiCvgA8FrghcDvRMQLlxS7EHg0M58HvBf4s7r1rqjdhp07O1NJGiPDiq9BjPM/HbgvM78FEBHXAduAry4qsw34k+r29cD7IyJyPa4h2W7D1q0wPw/T0zA7C43GwKuRpEEbZnwNotvnBOC7i+7vrR7rWiYznwT+Gzhq6YwiYkdEzEXE3MMPP7y2pWm1Ou/cgQOdaau1tvlI0pANM74GEf7R5bGlW/S9lCEzd2XmTGbOHHPMqkcnd9dsdn4yp6Y602ZzbfORpCEbZnwNottnL3DSovsnAg8uU2ZvRBwM/Czw/QHU/UyNRud/pVar887Z5SNpTAwzvgYR/ncCp0TEycADwPnAG5eUuQnYDrSB84DPrkt//4JGw9CXNJaGFV+1wz8zn4yItwOfAaaAD2XmPRHxp8BcZt4EfBD4SETcR2eL//y69UqS1m4gZ/XMzFuAW5Y89q5Ft/8PeP0g6pIk1ecRvpJUIMNfkgpk+EtSgQx/SSqQ4S9JBTL8JalAhr8kFcjwl6QCGf6SVCDDX5IKZPhLUoEMf0kqkOEvSQUy/CWpQIa/JBXI8JekAhn+klQgw1+SCmT4S1KBDH9JKpDhL0kFMvwlqUATGf7tNuzc2ZlKkp7p4I1egEFrt2HrVpifh+lpmJ2FRmMNM2m1oNlcw4slqYYh5c/EhX+r1Qn+Awc601arz/dvIL8ekrQGQ8yfiev2aTY779nUVGfabPY5g26/HpI0DEPMn4nb8m80Oj+Wa/6vaeHXY+GXt+9fD0laoyHmT2Tmus28jpmZmZybm9uYyu3zl7RGteOj5gwi4q7MnFm1nOEvSYMxCrsMew3/ievzl6SNMk67DGuFf0QcGRG3RsSeanpElzIvioh2RNwTEV+OiN+uU6ckjaraA06GqO6W/yXAbGaeAsxW95f6X+CCzPx54DXAlRFxeM16JWnkLAw4ufzy0R8lXne0zzagWd2+BmgBFy8ukJnfWHT7wYjYDxwDPFazbkkaOY3GaIf+grpb/sdl5j6AanrsSoUj4nRgGvhmzXolSTWsuuUfEbcBx3d56rJ+KoqITcBHgO2Z+dQyZXYAOwA2b97cz+wlSX1YNfwz86zlnouIhyJiU2buq8J9/zLlDgNuBv44M29foa5dwC7oDPVcbdkkSWtTt9vnJmB7dXs7cOPSAhExDdwAXJuZH69ZnyRpAOqG/xXA2RGxBzi7uk9EzETEVVWZNwCvAt4SEXdXfy+qWa8kqQaP8JWkCeIRvpKkZRn+klQgw1+SCmT4S1KBDH9JKpDhL0kFMvwlqUCGvyQVyPCXpAIZ/pJUIMNfkgpk+EtSgQx/SSqQ4S9JBTL810O7DTt3dqaSyjIm3/9VL+NYonYbWi1oNqHRWMOLt26F+XmYnobZ2TXMRNJYGqPvv1v+
Syysu3e+szPt+8e71eqs+AMHOtNWax2WUtJIGqPvv+G/RO1112x2fvGnpjrTZnPgyyhpRI3R999unyUW1t3Cf219r7tGo/Ov3pr7jSSNrTH6/nsN3y5q9flL0gbq9Rq+bvl30WgY+pImm33+krTImIzUrM0tf0mqjNFIzdrc8pekyhiN1KzN8JekyhiN1KzNbh9JqozRSM3aDH9JWqSU0X52+0hSgQx/SSqQ4S9JBTL8JalAtcI/Io6MiFsjYk81PWKFsodFxAMR8f46dUqS6qu75X8JMJuZpwCz1f3lXA78W836JEkDUDf8twHXVLevAc7pVigiXgIcB/xrzfokSQNQN/yPy8x9ANX02KUFIuIg4C+AP1xtZhGxIyLmImLu4YcfrrlokqTlrHqQV0TcBhzf5anLeqzjbcAtmfndiFixYGbuAnZB53z+Pc5fktSnVcM/M89a7rmIeCgiNmXmvojYBOzvUqwBnBERbwOeBUxHxOOZudL+AUnSOqp7eoebgO3AFdX0xqUFMvNNC7cj4i3AjMEvSRurbp//FcDZEbEHOLu6T0TMRMRVdRdOkrQ+vIavJE2QXq/h6xG+klQgw1+SCmT4S1KBDP910G7Dzp2d6cbMQNKaFfL980peA9Zuw9atnYs/T093LgnX11WBas9A0poV9P1zy3/AWq3O5+bAgc601Rr2DCStWUHfP8N/wJrNzgbD1FRn2mwOewaS1qyg75/dPgPWaHT+U2y1Op+bvv9jrD0DSWtW0PfPg7wkaYJ4kJckaVmGvyQVyPCXNFEKGaZfmzt8JU2Mgobp1+aWv6SJUdAw/doMf0kTo6Bh+rXZ7SNpYhQ0TL82w1/SRGk0DP1e2O0jSQUy/CWpQIa/JBXI8JekAhn+klQgw1+SCmT4S1KBDH9JKpDhL0kFMvwlqUCGvyQVyPAfQV6MQtJ688RuI2YgF6Notz2tocrl578ntcI/Io4EPgpsAe4H3pCZj3Yptxm4CjgJSOB1mXl/nbonVbeLUfT1+fVSRiqZn/+e1e32uQSYzcxTgNnqfjfXAu/JzBcApwP7a9Y7sWpfjMJLGalkfv57VrfbZxvQrG5fA7SAixcXiIgXAgdn5q0Amfl4zTonWu2LUSz8eixs+XgpI5XEz3/PIjPX/uKIxzLz8EX3H83MI5aUOQd4KzAPnAzcBlySmQdWmvfMzEzOzc2tedmKZp+nSlb45z8i7srMmdXKrbrlHxG3Acd3eeqyHpflYOAM4JeA79DZR/AW4INd6toB7ADYvHlzj7PXM3gpI5XMz39PVg3/zDxrueci4qGI2JSZ+yJiE9378vcC/5mZ36pe80ng5XQJ/8zcBeyCzpZ/b02QJPWr7g7fm4Dt1e3twI1dytwJHBERx1T3Xw18tWa9kqQa6ob/FcDZEbEHOLu6T0TMRMRVAFXf/kXAbETsBgL4u5r1SpJqqDXaJzMfAbZ2eXyOzk7ehfu3AqfVqUuSNDie3kHSSPH0JsPh6R0kjQwP0B0et/wljQwP0B0ew1/SyKh9ehP1zG4fSSOj9ulN1DPDX9JI8QDd4bDbR5IKZPhLUoEMf0kqkOEvSQUy/CeQR0hKWo2jfSaMR0hK6oVb/hPGIyQl9cLwnzAeISmpF3b7TJiBHCFZ+DVQtcH8/A2F4T+Bah0h6U4DbSQ/f0Njt4+ezp0G2kh+/obG8NfTudNAG8nP39DY7aOn87SK2kh+/oYmMnOjl6GrmZmZnJub2+jFkKSxEhF3ZebMauXs9pGkAhn+klQgw1+SCmT4SxooTyw4HhztI2lgPEZrfLjlL2lgPEZrfBj+kgbGY7TGh90+kgbGY7TGh+EvaaBqnVhQQ2O3j57B0RrS5HPLX0/jaA2pDLW2/CPiyIi4NSL2VNMjlin35xFxT0R8LSLeFxFRp16tH0drSGWo2+1zCTCbmacAs9X9p4mIVwC/DJwG/ALwUuDMmvVqnThaQypD3W6fbUCzun0N0AIuXlImgZ8CpoEADgEe
qlmv1omjNaQy1A3/4zJzH0Bm7ouIY5cWyMx2RHwO2Ecn/N+fmV/rNrOI2AHsANi8eXPNRdNa1R6t4TVYy+b6Hwurhn9E3AYc3+Wpy3qpICKeB7wAOLF66NaIeFVmfn5p2czcBeyCzvn8e5m/Rox7jMvm+h8bq4Z/Zp613HMR8VBEbKq2+jcB+7sUOxe4PTMfr17zaeDlwDPCXxOg2x5jv/zlcP2Pjbo7fG8Ctle3twM3dinzHeDMiDg4Ig6hs7O3a7ePJoB7jMvm+h8bdfv8rwA+FhEX0gn51wNExAzwe5n5VuB64NXAbjo7f/8lM/+5Zr0aVe4xLpvrf2x4DV9JmiBew1eStCzDX5IKZPhLehpP7FcGT+wm6cccpl8Ot/wl/Zgn9iuH4a+Bs9tgfDlMvxx2+2ig7DYYbw7TL4fhr4Hy6P7x52UYy2C3jwbKbgNpPLjlr4Gy20AaD4a/Bs7rAYw53/8iGP4aLe4x3li+/8Wwz1+jxYHmG8v3vxiGv0aLe4w3lu9/Mez20Whxj/HG8v0vhufzl6QJ4vn8pUJ5eg31wm4fjRxHGq6dg3XUK8NfI8XwqsfTa6hXdvtopDjSsB4H66hXbvlrpCyE18KWv+HVHwfrqFeGv0aK4VWfZ+VULwx/jRzPDVRT6e1XTwx/TZbS9xiX3n71zB2+miyl7zEuvf3qmeGvyTIBw11qHaQ1Ae3XcNjto8nSaNC+8g5an3iE5m8dRaNx6kYvUV9q99q4x1w9Mvw1Udpt2PqOUzvh+QWYPXW88m8gB2k53Ec9sNtHE2Xcu7zttdGwuOWviTKQg8Q2cKhkowGzV+4e224rjY9aW/4R8fqIuCcinoqIZU8hGhGviYh7I+K+iLikTp3SSha6vC+/fI2jHNtt2s1L2XnZ47Sbl65pr2utHbbtNo13vIxLZ8+i8Y6XeWpOrZu6W/5fAX4T+NvlCkTEFPAB4GxgL3BnRNyUmV+tWbfUVZ0u7/a1e9g6fwvzTDM9P8/stdfT6GNmtXfYemY2DUmtLf/M/Fpm3rtKsdOB+zLzW5k5D1wHbKtTr7ReWpzJPNMc4GDmOYQWZ/b3+hbMP5Gd7H4i+9/nYKe/hmQYO3xPAL676P7e6jFp5DQveA7ThwZTcYDpQw+iecFz+nv9UbuZfuqHTPEjpp/6Ic2jdve3ALX7raTerNrtExG3Acd3eeqyzLyxhzqiy2Ndrx0ZETuAHQCbN2/uYdbSYDUaMPu5qTXv72088ilmD7qZ1lNn0DzoCzQe+TWgz522DtXUEKwa/pl5Vs069gInLbp/IvDgMnXtAnZB5xq+NeuV1qRW9jabNA69nMb87VW3zXsGumzSoAxjqOedwCkRcTLwAHA+8MYh1CsNn0fYakzUCv+IOBf4K+AY4OaIuDszfzUing1clZmvy8wnI+LtwGeAKeBDmXlP7SWXRpXdNhoDtcI/M28Abujy+IPA6xbdvwW4pU5dkqTB8fQOklQgw1+SCmT4S1KBDH9JKpDhL0kFiszRPJYqIh4Gvt3ny44GvrcOi7MRbMtosi2jybb8xHMy85jVCo1s+K9FRMxl5rKnlh4ntmU02ZbRZFv6Z7ePJBXI8JekAk1a+O/a6AUYINsymmzLaLItfZqoPn9JUm8mbctfktSDkQ3/iPhQROyPiK90ee6iiMiIOLq6//yIaEfEExFx0QrzvDoi/isi7q7+XrSebVhUbz9teVNEfLn6+2JE/OIy8zw5Iu6IiD0R8dGImF7vdlT1rkdbxmG9bKvacXdEzEXEK5eZ50siYndE3BcR74uIbhczGrh1aksrIu5dtF6OXe92VPX23JZFj780Ig5ExHnLzHPk18uix1dry2DWS2aO5B/wKuDFwFeWPH4SndNDfxs4unrsWOClwLuBi1aY59XAeSPellcAR1S3Xwvcscw8PwacX93+G+D3x7gt47BensVPuklPA76+zDz/HWjQuYLdp4HXjnFbWsDMKK+X6vEp4LN0zhzc9XM0Duul
j7YMZL2M7JZ/Zn4e+H6Xp94L/BGLLgWZmfsz807gR0NavL702ZYvZuaj1d3b6Vz57GmqrZZXA9dXD10DnDPIZV7OoNuykfpsy+NZffOAn6HLpUgjYhNwWGa2q7LXMprrZdW2bKR+2lL5A+ATwP5u8xuX9VJZsS2DNLLh301E/AbwQGZ+qcZs3l39y/veiDh0UMvWrx7bciGdrZSljgIey8wnq/t7gRMGvIg9q9mWBSO/XiLi3Ij4OnAz8LtdXn4CnXWxYGTXSw9tWfDhqmvhncPqKulmubZExAnAuXT++13OWKyXHtuyoPZ6GZvwj4ifBi4D3lVjNpcCz6fTRXQkcPEAFq1vvbQlIn6FTmB2W8ZuK3tDtt4G0BYYk/WSmTdk5vPpbDVe3m0W3V42uCXs3QDaAvCmzDwVOKP6e/N6LOtqVmnLlcDFmXlgpVl0eWwU10svbYEBrZexCX/gucDJwJci4n46XQj/ERHH9zqDzNyXHU8AHwZOX5clXd2KbYmI04CrgG2Z+UiX138PODwiFq7EdiLw4LovdXd12zI262VB9a/8c5fuqKOzRbm4a2tk18uCFdpCZj5QTX8A/AOjuV5mgOuqx88D/joilnbpjMt66aUtA1svw7iA+0Bk5m46O3YBqN6gmczs+QRIEbEpM/dV/yadAzxjD/wwrNSWiNgM/BPw5sz8xjKvz4j4HJ0PyHXAduDGdV/w7stSqy3Va8ZhvTwP+Gb13r8YmAYeWfL6fRHxg4h4OXAHcAGda1wPXd22VBsWh1flDwF+HbhtaA1YZJXv/smLHr8a+FRmfnLJ68divdBDWwa6XuruMV6vP+AfgX10duLuBS5c8vz9/GT0wvFVmf8BHqtuH1Y9dwvw7Or2Z4HddMLl74FnjWBbrgIeBe6u/uYWlVvclp+jM4LhPuDjwKFj3JZxWC8XA/dU7WgDr1xU7u5Ft2eqdnwTeD/VqJpxawudHcF3AV+uyv4lMDVqbVny+NUsGiEzbuull7YMcr14hK8kFWic+vwlSQNi+EtSgQx/SSqQ4S9JBTL8JalAhr8kFcjwl6QCGf6SVKD/B4hzZniRN4fHAAAAAElFTkSuQmCC\n",
      "text/plain": [
       "<Figure size 432x288 with 1 Axes>"
      ]
     },
     "metadata": {
      "needs_background": "light"
     },
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Loss:  0.0002867949951905757\n"
     ]
    },
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAXcAAAD8CAYAAACMwORRAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDIuMi4zLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvIxREBQAAFMdJREFUeJzt3X2QJHddx/H3lw0bqiTydKcV7oGLegJXqKhrdAqUhVNJUlbOB7ASRRCjV1rGB1DLIIKppKzjQYvSSgo8NCBYGIFCPa3TSB0ZUWqCt1eEkERO10TJkhQ5FPGpyCbH1z+6lwyb2due7Z6ZnZ73q2prZnp6en6/6ZnP/ubbPd2RmUiS2uVxk26AJKl5hrsktZDhLkktZLhLUgsZ7pLUQoa7JLWQ4S5JLWS4S1ILGe6S1ELnTeqJd+zYkfv27ZvU00vSVDp16tRnM3PnZvNNLNz37dvH0tLSpJ5ekqZSRPxblfksy0hSCxnuktRChrsktZDhLkktZLhLUgttGu4RcVNEPBgRd25wf0TE70bEckTcERHf0nwzJUnDqDJyfydwyTnuvxTYX/4dBt5av1nSNtXrwZEjxaW0jW26n3tmfjgi9p1jlkPAu7I4X99tEfHkiLgwMx9oqI3S9tDrwcGDsLoK8/Nw4gR0OpNulTRQEzX3XcB9fbdXymmPERGHI2IpIpbOnDnTwFNLw6k18O52i2A/e7a47HYbbp3UnCZ+oRoDpg0863ZmHgWOAiwsLHhmbo1VrwcHX3iW1dVgfj45cevccAPvxUV6c8+n+8XnsTj3ETqLi1tvSLcLi4uO/DUyTYT7CrCn7/Zu4P4Glis1qvuuf2P1oV2cZY7Vhx6m+64VOp1nVH58jw4H4wSrBPORnGCOoaPZ0o7GpImyzDHg5eVeM98BfN56u0alTlllkb9lnlXmeJh5HmaRvx3q8d0urD4yx9l8HKuPzG2tKmNpR2Oy6cg9Iv4YWAR2RMQK8BvA4wEy823AceAyYBn4P+CVo2qsZlvdQW/n5fs5cdNldB9+HouP/widlx8Z6vkXF4vnXXv+LVVlGlmItLkqe8tcucn9CfxsYy2SNjBo0DtURaPTodM9QqfbhcUjQ5dDOp3iH0qtcnkjC5E2N7FD/krDamTQ2+nUCtSaDweK2n2XDoswfM1eqshw19Row6DX7akaF48to6nSocdrOEKH6fyFaCPbU/2VrCpw5K7p0YJhb+3SUgteA42HI3dNjxbsRrhWWrr++i3mcgteA42HI3eNVa0fZ7ZkN8JaG2Vb8hpo9Ax3jU3tikIbtqjW5Wugigx3jU3t/dShmX0Rp52vgSqw5q6xWasozM1ZUZBGzZG7xsaKgjQ+hrvGyopCfR4xWFUY7tIUcTd3VWXNXZoi7uauqgx3jZc/na+lkY3SroOZYFlG42NNobbaG6VdBzPDcNf4NLKju2ptlHYdzAzLMhofd3SfPNfBzHDkrqHU2g3PHd0nz3UwM6I4S974LSws5NLS0kSeW1tjuVaavIg4lZkLm81nWUaVuRueND0Md1VmuVaaHtbcVZnlWml6GO4aiseGkaaDZRlJaiHDXZJayHCXZoyHlpkN1tylGeJvFWaHI3dphvhbhdlhuEszxN8qzI5K4R4Rl0TE6YhYjohrBty/NyJujYiPRcQdEXFZ803VtmDBdqqt/Vbh+utrlGR8D0yFTWvuETEH3Ah8D7ACnIyIY5l5d99svw68NzPfGhEHgOPAvhG0V5NkwbYVav1WwffA1Kgycr8YWM7MezJzFbgZOLRungS+srz+JOD+5pqobcOCrXwPTI0q4b4LuK/v9ko5rd+1wMsiYoVi1P5zjbROjav1jdqCrXwPTI0qu0LGgGnrjxN8JfDOzPztiOgA746I52TmF79sQRGHgcMAe/fu3Up7VUPtb9QeXEa+B6ZGlXBfAfb03d7NY8suVwGXAGRmLyKeAOwAHuyf
KTOPAkehOJ77FtusLWrkDGseXEa+B6ZClbLMSWB/RFwUEfPAFcCxdfN8CjgIEBHPBp4AnGmyoarPb9TS7Nh05J6Zj0TE1cAtwBxwU2beFRHXAUuZeQz4JeDtEfEqipLNj+ekTvGkDfmNWpodnmZPkqaIp9mTpBlmuEtSCxnuktRChrsktZDhLkktZLhLUgsZ7pKG4hF/p4On2ZNUmUf8nR6O3CVV5hF/p4fhLqkyj080PSzLSKrM4xNND8N9yvR6NT9YtRegWVf7iL++B8fCcJ8itTdmuTVMk+Z7cGysuU+R2huz3BqmSfM9ODaG+xSpvTHLrWGaNN+DY2NZZorU3pjl1jBNmu/BsfFkHZI0RTxZhyTNMMNdklrIcJekFjLcJamFDHdJaiHDXZJayHCXpBYy3CWphQx3SWohw12SWshwl6QWMtwlqYUMd0lqoUrhHhGXRMTpiFiOiGs2mOeHI+LuiLgrIt7TbDMltUWvB0eOFJcanU2P5x4Rc8CNwPcAK8DJiDiWmXf3zbMfeA3wvMz8XER81agaLGl6eZa98akycr8YWM7MezJzFbgZOLRunp8CbszMzwFk5oPNNrM9HLVolnmWvfGpciamXcB9fbdXgG9fN8/XA0TER4A54NrM/OtGWtgijlo069bOsrf2GfAse6NTJdxjwLT1p286D9gPLAK7gb+LiOdk5n9+2YIiDgOHAfbu3Tt0Y6fdoFHL0OHe63mKMk2tRs6y52egkirhvgLs6bu9G7h/wDy3ZebDwL0RcZoi7E/2z5SZR4GjUJxmb6uNnla1Ry0O/dUCnU6Nt62fgcqq1NxPAvsj4qKImAeuAI6tm+fPgBcCRMQOijLNPU02tA3WRi3XX7/F96QFS806PwOVbTpyz8xHIuJq4BaKevpNmXlXRFwHLGXmsfK+742Iu4GzwK9k5r+PsuHTqtaoxYKlZp2fgcoiczLVkYWFhVxaWprIc081642adTP+GYiIU5m5sOl8hrskTY+q4e7hBySphQx3SWohw12SWshwl6QWMtwlqYUMd0lqIcNdklrIcJekFjLcJamFDHdJaiHDXZJayHCXpBYy3CWphQx3SWohw12SWshwl6QWMtyH1OvBkSPFpSRtV5ueQ1WP8sTr0uTN+Fn2KjPchzDoxOu+uaTxcYBVnWWZIaydeH1uzhOvS5MwaIClwRy5D6HTKUYKtb4S+p1S2rK1AdbayH1LA6wZ+Qwa7kPqdGq8H/xOKdVSe4A1Q59Bw32cLNpLtdUaYM3QZ9Ca+zhZtJcma4Y+g47cx6mRor2kLZuhz2Bk5kSeeGFhIZeWliby3JI0rSLiVGYubDafZRlJaiHDXZJayHCXpBaqFO4RcUlEnI6I5Yi45hzzvSQiMiI2rQdJkkZn03CPiDngRuBS4ABwZUQcGDDfBcDPAx9tupGSpOFUGblfDCxn5j2ZuQrcDBwaMN/1wJuALzTYPknSFlQJ913AfX23V8ppXxIR3wzsycy/bLBtkqQtqhLuMWDal3aOj4jHAW8BfmnTBUUcjoiliFg6c+ZM9VZKkoZSJdxXgD19t3cD9/fdvgB4DtCNiH8FvgM4NmijamYezcyFzFzYuXPn1lstSTqnKuF+EtgfERdFxDxwBXBs7c7M/Hxm7sjMfZm5D7gNuDwz/fmpJE3IpuGemY8AVwO3AP8IvDcz74qI6yLi8lE3UJI0vEoHDsvM48DxddNev8G8i/WbJUmqw1+oSlILGe6S1EKGuyS1kOEuSS1kuEtSC81cuPd6cORIcSlJbTVT51Dt9eDgweKk5/PzxakUW3wKRUkzbKZG7t1uEexnzxaX3e6kWyRJozFT4b64WIzY5+aKy8XFSbdIkkZjpsoynU5Riul2i2C3JCPNnl5vNjJgpsIdipVZa4XOyjtDaqFGtrtNSQbMXLjX4hZZaaoN2u421Ed4ijJgpmrutblFVppqtbe7TVEGOHIfxto7Y+2/tltkpalSe7vbFGVAZObmc43AwsJC
Li1N4fk8pqTeJmlEJpwBEXEqMx9zprvHzGe4S9L0qBru1twlqYUMd0lqIcNdklrIcJekFjLcJamFDHdJaiHDXZJayHCXpBYy3CWphQx3SWohw12SWshwl6QWMtwlqYUMd0lqoUrhHhGXRMTpiFiOiGsG3P/qiLg7Iu6IiBMR8YzmmypJqmrTcI+IOeBG4FLgAHBlRBxYN9vHgIXM/Ebg/cCbmm6oJKm6KiP3i4HlzLwnM1eBm4FD/TNk5q2Z+X/lzduA3c02U5I0jCrhvgu4r+/2SjltI1cBf1WnUZKkeqqcIDsGTBt4br6IeBmwALxgg/sPA4cB9u7dW7GJkqRhVRm5rwB7+m7vBu5fP1NEfDfwWuDyzHxo0IIy82hmLmTmws6dO7fSXklSBVXC/SSwPyIuioh54ArgWP8MEfHNwO9RBPuDzTfzUb0eHDlSXEqSBtu0LJOZj0TE1cAtwBxwU2beFRHXAUuZeQx4M/BE4H0RAfCpzLy86cb2enDwIKyuwvw8nDgBnU7TzyJJ069KzZ3MPA4cXzft9X3Xv7vhdg3U7RbBfvZscdntGu6SNMhU/UJ1cbEYsc/NFZeLi5NukSRtT5VG7ttFp1OUYrrdIti3NGrv9WouQJJqGFMGTVW4Q/FabPn1sGgvqaZa2TzGDJqqskxtg4r2klTRWja/7nXF5dB77Y0xg2Yr3C3aS6qhdjaPMYOmrixTSyNFe0mzai2b16oqQ2fzGDMoMgceSWDkFhYWcmlpaSLPLUlbNel9MiLiVGYubDbfbI3cJammWjt1jNFs1dwlaUYY7pLUQoa7JLWQ4S5JLWS4S1ILGe6S1EKGuyS1kOEuSS1kuEtSCxnuktRChrsktZDhLkktZLhLUgsZ7pLUQoa7JLWQ4S5JLWS4S1ILGe6S1EKGuyS1kOEuSS1kuEtSCxnuktRClcI9Ii6JiNMRsRwR1wy4//yI+JPy/o9GxL6mGypJqm7TcI+IOeBG4FLgAHBlRBxYN9tVwOcy8+uAtwBvbLqhX9LrwZEjxaUkaaDzKsxzMbCcmfcARMTNwCHg7r55DgHXltffD9wQEZGZ2WBbi0A/eBBWV2F+Hk6cgE6n0aeQpDaoUpbZBdzXd3ulnDZwnsx8BPg88LT1C4qIwxGxFBFLZ86cGb613W4R7GfPFpfd7vDLkKQZUCXcY8C09SPyKvOQmUczcyEzF3bu3FmlfV9ucbEYsc/NFZeLi8MvQ5JmQJWyzAqwp+/2buD+DeZZiYjzgCcB/9FIC/t1OkUpptstgt2SjKQp0+uNJ8KqhPtJYH9EXAR8GrgC+JF18xwDXgH0gJcAH2q83r6m0zHUJU2lcW423LQsU9bQrwZuAf4ReG9m3hUR10XE5eVsfwA8LSKWgVcDj9ldUpJm3Tg3G1YZuZOZx4Hj66a9vu/6F4CXNts0SWqXtc2GayP3UW42rBTukqT6xrnZ0HCXpDEa12ZDjy0jSS1kuEtSCxnuktRChrsktZDhLkktZLhLUgvFqI4SsOkTR/w3cHoiT968HcBnJ92IhtiX7cm+bE+T6MszMnPTIy9Ocj/305m5MMHnb0xELNmX7ce+bE/2ZTwsy0hSCxnuktRCkwz3oxN87qbZl+3JvmxP9mUMJrZBVZI0OpZlJKmFGgn3iLgpIh6MiDv7pj03Im6LiNvLk2JfXE5/UkT8RUR8PCLuiohXbrDMbkScLh9/e0R8VRNtbbgvT4mIP42IOyLiHyLiORss86KI+GhE/HNE/ElEzE9xX94ZEff2rZfnTrAv3xQRvYj4RPme+sq++14TEcvle+jFGyxzO62Xun0Z+3oZph8R8bSIuDUi/icibjjHMp8aER8s18kHI+Ipo+7HCPtybUR8um+dXDaOvnxJZtb+A74L+Bbgzr5pfwNcWl6/DOiW138NeGN5fSfFuVbnByyzCyw00b4R9uXNwG+U158FnNhgme8Friivvw34mSnuyzuBl2yT9XISeEF5/SeA68vrB4CPA+cD
FwH/Asxt8/VSty9jXy9D9uMrgOcDPw3ccI5lvgm4prx+zVpWTGlfrgV+eZzrpP+vkZF7Zn6Yx54QO4G10ceTePSk2glcEBEBPLF83CNNtKMJQ/blAHCifNwngX0R8dX9Dyz7+SLg/eWkPwS+v/mWP1bTfZmkDfryTODD5fUPAj9UXj8E3JyZD2XmvcAycHH/A7fhetlyXyZlmH5k5v9m5t8DX9hksYco1gVs03UyRF8mapQ1918E3hwR9wG/BbymnH4D8GyKUPkE8AuZ+cUNlvGO8uvM68oP46Rs1JePAz8IUJY3ngHsXvfYpwH/mcW5aAFWgF0jb/HG6vRlzW+W5Zu3RMT5o27wOdwJrJ3H96XAnvL6LuC+vvkGvebbbb3U6cua7bBeNupHVV+dmQ8AlJdjKcduoG5fAK4u18lN4yoxrRlluP8M8KrM3AO8iuIk2gAvBm4Hng48F7ihv77Y50cz8xuA7yz/fmyEbd3MRn15A/CUiLgd+DngYzz2W8igf0qT3EWpTl+g+GfwLODbgKcCvzryFm/sJ4CfjYhTwAXAajm9ymu+3dZLnb7A9lkvG/VjGtXty1uBr6XIuQeA3262eec2ynB/BfCB8vr7ePSr5CuBD2RhGbiX4k35ZTLz0+XlfwPvYbJfRQf2JTP/KzNfmZnPBV5OsQ3h3nWP/Szw5IhYO9TDbh4thUxCnb6QmQ+U6+4h4B1McL1k5icz83sz81uBP6aoR0Mxuu0fZQ16zbfVeqnZl22zXs7Rj6o+ExEXApSXDzbdxqrq9iUzP5OZZ8vKxNsZ8zoZZbjfD7ygvP4i4J/L658CDgKUNd1nAvf0PzAizouIHeX1xwPfR/EVaVIG9iUinty3h8VPAh/OzP/qf2AWW1ZuBV5STnoF8Ocjb/HGttyXcr61D15Q1EMntl6i3IMqIh4H/DrFRlGAY8AVEXF+RFwE7Af+of+x22291OlL+bhtsV7O0Y+qjlGsC9i+66Tq4y/su/kDjHudNLFVluK/2gPAwxQjjasotiafoqjlfhT41nLep1PssfGJsrMv61vO7fno1uhTwB3AXcDvMGAPgVH8DdmXDkU4fpJiNPyUvuUcB55eXv8aig/kMsVo+fwp7suH+tbdHwFPnGBffgH4p/LvDZQ/yivnfy3FSOs05d5B23y91O3L2NfLFvrxrxQbLf+nnP9AOf33KfeMo9gWcqJ8L54AnrpN10mVvry7XCd3UPzTunAcfVn78xeqktRC/kJVklrIcJekFjLcJamFDHdJaiHDXZJayHCXpBYy3CWphQx3SWqh/wevf8fK9+HKxgAAAABJRU5ErkJggg==\n",
      "text/plain": [
       "<Figure size 432x288 with 1 Axes>"
      ]
     },
     "metadata": {
      "needs_background": "light"
     },
     "output_type": "display_data"
    }
   ],
   "source": [
    "# train the rnn and monitor results\n",
    "n_steps = 75\n",
    "print_every = 15\n",
    "\n",
    "trained_rnn = train(rnn, n_steps, print_every)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Time-Series Prediction\n",
    "\n",
    "Time-series prediction can be applied to many tasks. Think about weather forecasting or predicting the ebb and flow of stock market prices. You can even try to generate predictions much further in the future than just one time step!"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.1"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}

    sys.stdout.write("predicted string: ")
    for input, label in zip(inputs, labels):
        # print(input.size(), label.size())
        hidden, output = model(hidden, input)
        val, idx = output.max(1)
        sys.stdout.write(idx2char[idx.data[0]])
        loss += criterion(output, label)

    print(", epoch: %d, loss: %1.3f" % (epoch + 1, loss.data[0]))

    loss.backward()
    optimizer.step()

print("Learning finished!")

predicted string: l

RuntimeError: dimension specified as 0 but tensor has no dimensions