In [26]:
import torch
from torch import nn, optim
import numpy as np
import pandas as pd
from matplotlib import pyplot as plt
def make_net(hidden_depth, width):
    """Yield, in order, the layers of a scalar-in / scalar-out ReLU MLP.

    The sequence is ``hidden_depth`` pairs of (``nn.Linear``, ``nn.ReLU``)
    followed by a final ``nn.Linear(width, 1)``.  The first linear layer maps
    1 -> ``width``; every later hidden one maps ``width`` -> ``width``.

    This is a generator, so nothing is executed (including the depth check)
    until the caller starts iterating, e.g. ``nn.Sequential(*make_net(2, 5))``.

    Raises:
        AssertionError: on first iteration, if ``hidden_depth`` is below 1.
    """
    assert hidden_depth >= 1
    # Input size of each hidden linear layer: the scalar input first, then
    # ``width`` for every additional hidden layer.
    fan_ins = [1] + [width] * (hidden_depth - 1)
    for fan_in in fan_ins:
        yield nn.Linear(fan_in, width)
        yield nn.ReLU()
    yield nn.Linear(width, 1)


In [18]:
# Scratch stand-in for a loss record: ten float entries, all equal to 1.0.
loss = np.full(10, 1.0)

In [19]:
# True if any entry of the loss record is NaN.
np.isnan(loss).any()

False

In [20]:
# Training data: samples of exp(x) on an evenly spaced grid over [xr[0], xr[1]].
xr = (-2, 5)
x_lo, x_hi = xr
xsize = 10 * (x_hi - x_lo) + 1  # 10 grid intervals per unit of x, both endpoints included

# Column tensors of shape (xsize, 1): one scalar feature per sample.
xt = torch.linspace(x_lo, x_hi, xsize)[:, None]
yt = xt.exp()

# Detached handles on the same data, used when plotting.
xd, yd = xt.detach(), yt.detach()

# Label for the target function and x-range.
STR_NAME = f"exp:{xr[0]}:{xr[1]}"
def get_pred(hidden, width, nepoch=200, lr=0.002, momentum=0.9, debug=False):
    """Train a fresh MLP on the module-level ``(xt, yt)`` data and record its progress.

    A new network is built from ``make_net(hidden, width)`` and trained
    full-batch with MSE loss and SGD (with momentum) for ``nepoch`` steps.

    Args:
        hidden: number of hidden layers, forwarded to ``make_net``.
        width: units per hidden layer, forwarded to ``make_net``.
        nepoch: number of optimisation steps.
        lr: SGD learning rate.
        momentum: SGD momentum.
        debug: when true, print the loss at every epoch and a first-vs-final
            summary at the end.

    Returns:
        ``(data_out, loss_t)``.  ``data_out`` has shape
        ``(nepoch, xt.shape[0])``; row ``e`` is the prediction made at epoch
        ``e`` *before* that epoch's parameter update.  ``loss_t`` has shape
        ``(nepoch,)`` and holds the matching loss values.  Neither tensor
        tracks gradients.
    """
    net = nn.Sequential(*make_net(hidden, width))
    lossfunc = nn.MSELoss()
    optimizer = optim.SGD(net.parameters(), lr=lr, momentum=momentum)
    data_out = torch.zeros(nepoch, xt.shape[0])
    loss_t = torch.zeros(nepoch)

    # The pre-training loss is only needed for the debug summary, so skip the
    # extra forward pass otherwise and never build an autograd graph for it.
    og_loss = None
    if debug:
        with torch.no_grad():
            og_loss = lossfunc(net(xt), yt).item()

    final_loss = None  # stays None when nepoch == 0
    for epoch in range(nepoch):
        optimizer.zero_grad()
        ypred = net(xt)
        loss = lossfunc(ypred, yt)
        final_loss = loss.item()

        if debug:
            print(epoch, final_loss)

        loss_t[epoch] = final_loss
        # Detach before storing: writing a grad-tracking tensor into data_out
        # would hook the history buffer into every epoch's autograd graph and
        # keep all of those graphs alive until the function returns.
        data_out[epoch, :] = ypred.detach().squeeze(-1)

        loss.backward()
        optimizer.step()

    if debug:
        print(f"First loss {og_loss} v final {final_loss}")
    return data_out, loss_t

In [21]:
# Sanity check: the largest x on the grid, rounded to two decimals.
round(float(xt.max()), 2)

5.0

In [12]:
# Architecture for this experiment: `h` hidden layers, each `w` units wide.
h, w = 2, 5
nepoch = 200

# Train 50 freshly initialised networks and keep only each run's per-epoch
# predictions (the first element of get_pred's result) as a NumPy array.
data_np = [
    predictions.numpy()
    for predictions, _losses in (get_pred(h, w, nepoch) for _run in range(50))
]

In [13]:
# Peek at the first run: one row per epoch, one column per x grid point.
data_np[0]

array([[ 0.25278017,  0.24895099,  0.24512184, ..., -0.3529262 ,
        -0.36501575, -0.3771053 ],
       [ 0.344922  ,  0.34108284,  0.33724368, ...,  0.14829408,
         0.14497976,  0.14166556],
       [ 0.5292239 ,  0.5251156 ,  0.5210073 , ...,  0.71337706,
         0.7215554 ,  0.7297338 ],
       ...,
       [21.946367  , 21.946367  , 21.946367  , ..., 21.946367  ,
        21.946367  , 21.946367  ],
       [21.946442  , 21.946442  , 21.946442  , ..., 21.946442  ,
        21.946442  , 21.946442  ],
       [21.946514  , 21.946514  , 21.946514  , ..., 21.946514  ,
        21.946514  , 21.946514  ]], dtype=float32)

In [23]:
# Tabular view of a single run (index 30): rows are epochs, columns are grid points.
df = pd.DataFrame(data=data_np[30])

In [24]:
# Predictions from the first five epochs of that run.
df.head(5)

Unnamed: 0,0,1,2,3,4,5,6,7,8,9,...,61,62,63,64,65,66,67,68,69,70
0,-0.150518,-0.14683,-0.143143,-0.139456,-0.135769,-0.132082,-0.128395,-0.124707,-0.12102,-0.117333,...,-0.080606,-0.082697,-0.084788,-0.086879,-0.08897,-0.091061,-0.093152,-0.095243,-0.097333,-0.099424
1,-0.028314,-0.024463,-0.020611,-0.01676,-0.012908,-0.009057,-0.005205,-0.001353,0.002498,0.00635,...,0.221273,0.224105,0.226937,0.229769,0.232601,0.235433,0.238265,0.241096,0.243928,0.24676
2,0.182819,0.186935,0.191051,0.195166,0.199282,0.203398,0.207513,0.211629,0.215744,0.21986,...,0.732505,0.74349,0.754475,0.765461,0.776446,0.787431,0.798417,0.809402,0.820387,0.831373
3,0.45371,0.455469,0.459852,0.464235,0.468618,0.473001,0.477384,0.481766,0.486149,0.490532,...,1.565957,1.591109,1.61626,1.641412,1.666563,1.691715,1.716866,1.742018,1.767169,1.792321
4,0.851718,0.846865,0.842012,0.837159,0.832307,0.827454,0.822601,0.817748,0.812896,0.808043,...,3.832177,3.904046,3.975913,4.047781,4.119648,4.191516,4.263384,4.335252,4.40712,4.478987


In [None]:
import numpy as np
import seaborn as sns
from matplotlib import pyplot as plt
from matplotlib import animation

sns.set()
plt.rcParams["figure.figsize"] = (14.0, 7.0)

# First set up the figure, the axis, and the plot element we want to animate
fig = plt.figure()

ax = plt.axes()
plt.title(f"Approximated exp with {h} Hidden of {w} Width")
# Target curve: the training samples, drawn as points.
ax.plot(xd, yd, ".")

# One initially empty line artist per training run stored in data_np.
# animate() indexes data_np by position in this list, so the list must be
# sized from data_np itself.
line_ref = [ax.plot([], [], lw=2)[0] for _ in data_np]

    

# initialization function: plot the background of each frame
def init():
    """Clear the data of every animated line and return the line artists."""
    for artist in line_ref:
        artist.set_data([], [])
    return line_ref


# animation function.  This is called sequentially
def animate(i):
    """Draw frame ``i``: show each run's prediction at epoch ``i``.

    The ``k``-th line in ``line_ref`` is given the x grid ``xd`` and row ``i``
    of ``data_np[k]``.  The list of line artists is returned.
    """
    for run_idx, artist in enumerate(line_ref):
        artist.set_data(xd, data_np[run_idx][i])
    return line_ref

# Build the animation: one frame per training epoch.
# blit=True means only re-draw the parts that have changed.
anim = animation.FuncAnimation(
    fig,
    animate,
    init_func=init,
    frames=nepoch,
    interval=20,
    blit=True,
)

# Write the video at 30 fps using the libx264 codec.
# NOTE(review): the file name contains spaces and ':' characters; ':' is not
# accepted in file names on some platforms (e.g. Windows) -- confirm where this runs.
video_name = f'training{h}h{w}w e:{nepoch} {STR_NAME}.mp4'
anim.save(video_name, fps=30, extra_args=['-vcodec', 'libx264'])
print("And Done")