In [14]:
import pandas as pd
from sklearn.datasets import load_diabetes
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
from sklearn.metrics import mean_squared_error, mean_absolute_error
import matplotlib.pyplot as plt
from tqdm import tqdm

from minigrad.loss import HuberLoss
from minigrad.nn import MLP
from minigrad.optim import ADAGrad

In [2]:
diabetes = load_diabetes()
print(diabetes.data.shape)

X = pd.DataFrame(diabetes.data, columns=diabetes.feature_names)
y = pd.DataFrame(diabetes.target, columns=['TARGET'])

X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.10, random_state=8, shuffle=True)

# normalize all features
scaler = StandardScaler() # scaling is important here because math.exp() overflows for arguments>700
X_train[diabetes.feature_names] = scaler.fit_transform(X_train[diabetes.feature_names])
X_test[diabetes.feature_names] = scaler.transform(X_test[diabetes.feature_names])

(442, 10)


In [16]:
classifier = MLP(
    nin=len(X_train.columns),
    nouts=[16, 1],
    activation='relu'
)

# setup hyperparameters
epochs = 200

# setup loss function and optimizer
huber_loss = HuberLoss()
optimizer = ADAGrad(
    params=classifier.parameters()
)

actuals = [float(v) for v in list(y_train.TARGET.values)]

for e in tqdm(range(epochs), total=epochs):
# for e in range(epochs):
    # forward pass
    ypreds = [classifier(x.values) for _, x in X_train.iterrows()]
    loss = huber_loss(actuals, ypreds)

    # zero grad
    classifier.zero_grad()

    # backward prop
    loss.backward()

    # recalculate the new values for all parameters - optimizer.step()
    optimizer.step()

    print(f'Epoch: {e}, Loss: {loss}')

  0%|          | 1/200 [00:01<06:03,  1.83s/it]

Epoch: 0, Loss: Value(data=152.02716279446776, grad=1, label=)


  1%|          | 2/200 [00:03<05:53,  1.78s/it]

Epoch: 1, Loss: Value(data=150.84479506414536, grad=1, label=)


  2%|▏         | 3/200 [00:05<06:23,  1.94s/it]

Epoch: 2, Loss: Value(data=149.54536642717042, grad=1, label=)


  2%|▏         | 4/200 [00:07<06:03,  1.86s/it]

Epoch: 3, Loss: Value(data=147.98029170329687, grad=1, label=)


  2%|▎         | 5/200 [00:09<05:50,  1.80s/it]

Epoch: 4, Loss: Value(data=146.17633485798348, grad=1, label=)


  3%|▎         | 6/200 [00:11<06:05,  1.88s/it]

Epoch: 5, Loss: Value(data=144.17813097989094, grad=1, label=)


  4%|▎         | 7/200 [00:12<05:55,  1.84s/it]

Epoch: 6, Loss: Value(data=142.0220606228075, grad=1, label=)


  4%|▍         | 8/200 [00:14<05:44,  1.79s/it]

Epoch: 7, Loss: Value(data=139.7247184751802, grad=1, label=)


  4%|▍         | 9/200 [00:16<06:13,  1.96s/it]

Epoch: 8, Loss: Value(data=137.30749912144609, grad=1, label=)


  5%|▌         | 10/200 [00:18<05:56,  1.88s/it]

Epoch: 9, Loss: Value(data=134.8331395201589, grad=1, label=)


  6%|▌         | 11/200 [00:20<05:42,  1.81s/it]

Epoch: 10, Loss: Value(data=132.28823335109968, grad=1, label=)


  6%|▌         | 12/200 [00:22<05:54,  1.89s/it]

Epoch: 11, Loss: Value(data=129.78822037905672, grad=1, label=)


  6%|▋         | 13/200 [00:24<05:43,  1.84s/it]

Epoch: 12, Loss: Value(data=127.33837147783035, grad=1, label=)


  7%|▋         | 14/200 [00:25<05:34,  1.80s/it]

Epoch: 13, Loss: Value(data=124.88354348412535, grad=1, label=)


  8%|▊         | 15/200 [00:27<05:25,  1.76s/it]

Epoch: 14, Loss: Value(data=122.53933146598523, grad=1, label=)


  8%|▊         | 16/200 [00:29<05:48,  1.89s/it]

Epoch: 15, Loss: Value(data=120.24844897459438, grad=1, label=)


  8%|▊         | 17/200 [00:31<05:38,  1.85s/it]

Epoch: 16, Loss: Value(data=118.1162265636025, grad=1, label=)


  9%|▉         | 18/200 [00:33<05:54,  1.95s/it]

Epoch: 17, Loss: Value(data=116.11113390232491, grad=1, label=)


 10%|▉         | 19/200 [00:35<06:03,  2.01s/it]

Epoch: 18, Loss: Value(data=114.11906785071909, grad=1, label=)


 10%|█         | 20/200 [00:37<05:48,  1.94s/it]

Epoch: 19, Loss: Value(data=112.17766424347249, grad=1, label=)


 10%|█         | 21/200 [00:39<05:35,  1.88s/it]

Epoch: 20, Loss: Value(data=110.26588032602209, grad=1, label=)


 11%|█         | 22/200 [00:40<05:23,  1.82s/it]

Epoch: 21, Loss: Value(data=108.37132289182165, grad=1, label=)


 12%|█▏        | 23/200 [00:42<05:35,  1.89s/it]

Epoch: 22, Loss: Value(data=106.46473896524253, grad=1, label=)


 12%|█▏        | 24/200 [00:44<05:25,  1.85s/it]

Epoch: 23, Loss: Value(data=104.58163120761321, grad=1, label=)


 12%|█▎        | 25/200 [00:46<05:14,  1.80s/it]

Epoch: 24, Loss: Value(data=102.70954956618785, grad=1, label=)


 13%|█▎        | 26/200 [00:48<05:27,  1.88s/it]

Epoch: 25, Loss: Value(data=100.8703453314015, grad=1, label=)


 14%|█▎        | 27/200 [00:50<05:20,  1.85s/it]

Epoch: 26, Loss: Value(data=99.02906260237529, grad=1, label=)


 14%|█▍        | 28/200 [00:51<05:08,  1.79s/it]

Epoch: 27, Loss: Value(data=97.18435268249797, grad=1, label=)


 14%|█▍        | 29/200 [00:54<05:20,  1.88s/it]

Epoch: 28, Loss: Value(data=95.33392394228322, grad=1, label=)


 15%|█▌        | 30/200 [00:55<05:12,  1.84s/it]

Epoch: 29, Loss: Value(data=93.51478072361068, grad=1, label=)


 16%|█▌        | 31/200 [00:57<05:05,  1.81s/it]

Epoch: 30, Loss: Value(data=91.75300750697167, grad=1, label=)


 16%|█▌        | 32/200 [00:59<04:57,  1.77s/it]

Epoch: 31, Loss: Value(data=90.1122845625405, grad=1, label=)


 16%|█▋        | 33/200 [01:01<05:10,  1.86s/it]

Epoch: 32, Loss: Value(data=88.5598366887024, grad=1, label=)


 17%|█▋        | 34/200 [01:03<05:04,  1.83s/it]

Epoch: 33, Loss: Value(data=87.03821577080443, grad=1, label=)


 18%|█▊        | 35/200 [01:05<05:33,  2.02s/it]

Epoch: 34, Loss: Value(data=85.54148425563271, grad=1, label=)


 18%|█▊        | 36/200 [01:07<05:35,  2.05s/it]

Epoch: 35, Loss: Value(data=84.1327552566745, grad=1, label=)


 18%|█▊        | 37/200 [01:09<05:23,  1.99s/it]

Epoch: 36, Loss: Value(data=82.74204219063176, grad=1, label=)


 19%|█▉        | 38/200 [01:11<05:15,  1.95s/it]

Epoch: 37, Loss: Value(data=81.37198992970842, grad=1, label=)


 20%|█▉        | 39/200 [01:13<05:21,  2.00s/it]

Epoch: 38, Loss: Value(data=80.00660073919342, grad=1, label=)


 20%|██        | 40/200 [01:15<05:12,  1.95s/it]

Epoch: 39, Loss: Value(data=78.68467775780381, grad=1, label=)


 20%|██        | 41/200 [01:16<05:00,  1.89s/it]

Epoch: 40, Loss: Value(data=77.44841281283736, grad=1, label=)


 21%|██        | 42/200 [01:18<04:48,  1.82s/it]

Epoch: 41, Loss: Value(data=76.30666996176419, grad=1, label=)


 22%|██▏       | 43/200 [01:20<05:00,  1.91s/it]

Epoch: 42, Loss: Value(data=75.19856642324477, grad=1, label=)


 22%|██▏       | 44/200 [01:22<04:50,  1.86s/it]

Epoch: 43, Loss: Value(data=74.15913904141735, grad=1, label=)


 22%|██▎       | 45/200 [01:24<04:41,  1.82s/it]

Epoch: 44, Loss: Value(data=73.17140081351783, grad=1, label=)


 23%|██▎       | 46/200 [01:26<04:54,  1.91s/it]

Epoch: 45, Loss: Value(data=72.2658120529018, grad=1, label=)


 24%|██▎       | 47/200 [01:28<04:45,  1.86s/it]

Epoch: 46, Loss: Value(data=71.44479667786197, grad=1, label=)


 24%|██▍       | 48/200 [01:29<04:37,  1.83s/it]

Epoch: 47, Loss: Value(data=70.70312028661098, grad=1, label=)


 24%|██▍       | 49/200 [01:31<04:30,  1.79s/it]

Epoch: 48, Loss: Value(data=70.09304775930127, grad=1, label=)


 25%|██▌       | 50/200 [01:33<04:42,  1.88s/it]

Epoch: 49, Loss: Value(data=69.5202016488571, grad=1, label=)


 26%|██▌       | 51/200 [01:35<04:34,  1.84s/it]

Epoch: 50, Loss: Value(data=68.97819238934493, grad=1, label=)


 26%|██▌       | 52/200 [01:37<04:25,  1.80s/it]

Epoch: 51, Loss: Value(data=68.44743555315463, grad=1, label=)


 26%|██▋       | 53/200 [01:39<04:38,  1.89s/it]

Epoch: 52, Loss: Value(data=67.93002938350499, grad=1, label=)


 27%|██▋       | 54/200 [01:41<04:36,  1.90s/it]

Epoch: 53, Loss: Value(data=67.42049685057518, grad=1, label=)


 28%|██▊       | 55/200 [01:42<04:25,  1.83s/it]

Epoch: 54, Loss: Value(data=66.91604122975235, grad=1, label=)


 28%|██▊       | 56/200 [01:44<04:37,  1.92s/it]

Epoch: 55, Loss: Value(data=66.46608206602454, grad=1, label=)


 28%|██▊       | 57/200 [01:46<04:28,  1.88s/it]

Epoch: 56, Loss: Value(data=66.02740408572356, grad=1, label=)


 29%|██▉       | 58/200 [01:48<04:21,  1.84s/it]

Epoch: 57, Loss: Value(data=65.5851130794008, grad=1, label=)


 30%|██▉       | 59/200 [01:50<04:13,  1.80s/it]

Epoch: 58, Loss: Value(data=65.14344600039361, grad=1, label=)


 30%|███       | 60/200 [01:52<04:28,  1.92s/it]

Epoch: 59, Loss: Value(data=64.715565908952, grad=1, label=)


 30%|███       | 61/200 [01:54<04:20,  1.87s/it]

Epoch: 60, Loss: Value(data=64.29470870141779, grad=1, label=)


 31%|███       | 62/200 [01:55<04:10,  1.82s/it]

Epoch: 61, Loss: Value(data=63.896288782296324, grad=1, label=)


 32%|███▏      | 63/200 [01:57<04:22,  1.92s/it]

Epoch: 62, Loss: Value(data=63.515141015330805, grad=1, label=)


 32%|███▏      | 64/200 [01:59<04:14,  1.87s/it]

Epoch: 63, Loss: Value(data=63.14311362123716, grad=1, label=)


 32%|███▎      | 65/200 [02:01<04:07,  1.83s/it]

Epoch: 64, Loss: Value(data=62.767693847184404, grad=1, label=)


 33%|███▎      | 66/200 [02:03<03:59,  1.79s/it]

Epoch: 65, Loss: Value(data=62.390587713053975, grad=1, label=)


 34%|███▎      | 67/200 [02:05<04:11,  1.89s/it]

Epoch: 66, Loss: Value(data=62.03317365195959, grad=1, label=)


 34%|███▍      | 68/200 [02:07<04:03,  1.85s/it]

Epoch: 67, Loss: Value(data=61.69320273746708, grad=1, label=)


 34%|███▍      | 69/200 [02:08<03:59,  1.83s/it]

Epoch: 68, Loss: Value(data=61.371106638977906, grad=1, label=)


 35%|███▌      | 70/200 [02:11<04:21,  2.01s/it]

Epoch: 69, Loss: Value(data=61.06534016233906, grad=1, label=)


 36%|███▌      | 71/200 [02:13<04:11,  1.95s/it]

Epoch: 70, Loss: Value(data=60.76904336511151, grad=1, label=)


 36%|███▌      | 72/200 [02:14<04:03,  1.90s/it]

Epoch: 71, Loss: Value(data=60.477494335409325, grad=1, label=)


 36%|███▋      | 73/200 [02:16<04:10,  1.97s/it]

Epoch: 72, Loss: Value(data=60.19531721658946, grad=1, label=)


 37%|███▋      | 74/200 [02:18<04:00,  1.91s/it]

Epoch: 73, Loss: Value(data=59.914003063331414, grad=1, label=)


 38%|███▊      | 75/200 [02:20<03:53,  1.87s/it]

Epoch: 74, Loss: Value(data=59.63191088731361, grad=1, label=)


 38%|███▊      | 76/200 [02:22<03:45,  1.82s/it]

Epoch: 75, Loss: Value(data=59.35096051523412, grad=1, label=)


 38%|███▊      | 77/200 [02:24<03:55,  1.91s/it]

Epoch: 76, Loss: Value(data=59.07423472766391, grad=1, label=)


 39%|███▉      | 78/200 [02:26<03:47,  1.86s/it]

Epoch: 77, Loss: Value(data=58.80181015612288, grad=1, label=)


 40%|███▉      | 79/200 [02:27<03:39,  1.81s/it]

Epoch: 78, Loss: Value(data=58.53134053643937, grad=1, label=)


 40%|████      | 80/200 [02:29<03:50,  1.92s/it]

Epoch: 79, Loss: Value(data=58.25961939359927, grad=1, label=)


 40%|████      | 81/200 [02:31<03:42,  1.87s/it]

Epoch: 80, Loss: Value(data=57.98599330207217, grad=1, label=)


 41%|████      | 82/200 [02:33<03:34,  1.82s/it]

Epoch: 81, Loss: Value(data=57.71226577306095, grad=1, label=)


 42%|████▏     | 83/200 [02:36<04:05,  2.10s/it]

Epoch: 82, Loss: Value(data=57.447868200479924, grad=1, label=)


 42%|████▏     | 84/200 [02:38<03:57,  2.05s/it]

Epoch: 83, Loss: Value(data=57.19743027186864, grad=1, label=)


 42%|████▎     | 85/200 [02:39<03:45,  1.96s/it]

Epoch: 84, Loss: Value(data=56.9525676442248, grad=1, label=)


 43%|████▎     | 86/200 [02:41<03:40,  1.94s/it]

Epoch: 85, Loss: Value(data=56.708153102853345, grad=1, label=)


 44%|████▎     | 87/200 [02:43<03:44,  1.99s/it]

Epoch: 86, Loss: Value(data=56.46194642521501, grad=1, label=)


 44%|████▍     | 88/200 [02:45<03:35,  1.92s/it]

Epoch: 87, Loss: Value(data=56.21163308706199, grad=1, label=)


 44%|████▍     | 89/200 [02:47<03:25,  1.85s/it]

Epoch: 88, Loss: Value(data=55.96299092142196, grad=1, label=)


 45%|████▌     | 90/200 [02:49<03:32,  1.93s/it]

Epoch: 89, Loss: Value(data=55.71339560533817, grad=1, label=)


 46%|████▌     | 91/200 [02:51<03:26,  1.90s/it]

Epoch: 90, Loss: Value(data=55.463876294431664, grad=1, label=)


 46%|████▌     | 92/200 [02:52<03:19,  1.84s/it]

Epoch: 91, Loss: Value(data=55.21841776061101, grad=1, label=)


 46%|████▋     | 93/200 [02:54<03:13,  1.81s/it]

Epoch: 92, Loss: Value(data=54.9795563788024, grad=1, label=)


 47%|████▋     | 94/200 [02:56<03:21,  1.90s/it]

Epoch: 93, Loss: Value(data=54.749128915083794, grad=1, label=)


 48%|████▊     | 95/200 [02:58<03:15,  1.86s/it]

Epoch: 94, Loss: Value(data=54.52233134921292, grad=1, label=)


 48%|████▊     | 96/200 [03:00<03:08,  1.81s/it]

Epoch: 95, Loss: Value(data=54.29846467764214, grad=1, label=)


 48%|████▊     | 97/200 [03:02<03:15,  1.90s/it]

Epoch: 96, Loss: Value(data=54.07994138956783, grad=1, label=)


 49%|████▉     | 98/200 [03:04<03:08,  1.85s/it]

Epoch: 97, Loss: Value(data=53.86813104914562, grad=1, label=)


 50%|████▉     | 99/200 [03:05<03:02,  1.81s/it]

Epoch: 98, Loss: Value(data=53.66624868043991, grad=1, label=)


 50%|█████     | 100/200 [03:07<03:10,  1.90s/it]

Epoch: 99, Loss: Value(data=53.47465123935922, grad=1, label=)


 50%|█████     | 101/200 [03:09<03:04,  1.87s/it]

Epoch: 100, Loss: Value(data=53.29287901109269, grad=1, label=)


 51%|█████     | 102/200 [03:11<02:59,  1.83s/it]

Epoch: 101, Loss: Value(data=53.12179830275754, grad=1, label=)


 52%|█████▏    | 103/200 [03:13<02:54,  1.80s/it]

Epoch: 102, Loss: Value(data=52.962100324262096, grad=1, label=)


 52%|█████▏    | 104/200 [03:15<03:04,  1.92s/it]

Epoch: 103, Loss: Value(data=52.80289707274279, grad=1, label=)


 52%|█████▎    | 105/200 [03:17<02:57,  1.87s/it]

Epoch: 104, Loss: Value(data=52.64485895780213, grad=1, label=)


 53%|█████▎    | 106/200 [03:19<02:58,  1.90s/it]

Epoch: 105, Loss: Value(data=52.49195748228968, grad=1, label=)


 54%|█████▎    | 107/200 [03:21<03:19,  2.14s/it]

Epoch: 106, Loss: Value(data=52.34439217456491, grad=1, label=)


 54%|█████▍    | 108/200 [03:23<03:07,  2.03s/it]

Epoch: 107, Loss: Value(data=52.20124194655523, grad=1, label=)


 55%|█████▍    | 109/200 [03:25<02:56,  1.94s/it]

Epoch: 108, Loss: Value(data=52.06086573423388, grad=1, label=)


 55%|█████▌    | 110/200 [03:27<03:00,  2.00s/it]

Epoch: 109, Loss: Value(data=51.919054286078065, grad=1, label=)


 56%|█████▌    | 111/200 [03:29<02:52,  1.94s/it]

Epoch: 110, Loss: Value(data=51.776079146829936, grad=1, label=)


 56%|█████▌    | 112/200 [03:31<02:45,  1.88s/it]

Epoch: 111, Loss: Value(data=51.63281471400873, grad=1, label=)


 56%|█████▋    | 113/200 [03:32<02:39,  1.84s/it]

Epoch: 112, Loss: Value(data=51.493012187309716, grad=1, label=)


 57%|█████▋    | 114/200 [03:34<02:45,  1.92s/it]

Epoch: 113, Loss: Value(data=51.36092296078905, grad=1, label=)


 57%|█████▊    | 115/200 [03:36<02:39,  1.88s/it]

Epoch: 114, Loss: Value(data=51.22973610440617, grad=1, label=)


 58%|█████▊    | 116/200 [03:38<02:34,  1.83s/it]

Epoch: 115, Loss: Value(data=51.09976839067123, grad=1, label=)


 58%|█████▊    | 117/200 [03:40<02:41,  1.95s/it]

Epoch: 116, Loss: Value(data=50.9729104651017, grad=1, label=)


 59%|█████▉    | 118/200 [03:42<02:35,  1.89s/it]

Epoch: 117, Loss: Value(data=50.84870210164448, grad=1, label=)


 60%|█████▉    | 119/200 [03:44<02:29,  1.85s/it]

Epoch: 118, Loss: Value(data=50.724003051535796, grad=1, label=)


 60%|██████    | 120/200 [03:45<02:23,  1.80s/it]

Epoch: 119, Loss: Value(data=50.59967021854729, grad=1, label=)


 60%|██████    | 121/200 [03:47<02:29,  1.89s/it]

Epoch: 120, Loss: Value(data=50.477655512099936, grad=1, label=)


 61%|██████    | 122/200 [03:49<02:24,  1.85s/it]

Epoch: 121, Loss: Value(data=50.355842640522724, grad=1, label=)


 62%|██████▏   | 123/200 [03:51<02:18,  1.80s/it]

Epoch: 122, Loss: Value(data=50.23566388725218, grad=1, label=)


 62%|██████▏   | 124/200 [03:53<02:23,  1.89s/it]

Epoch: 123, Loss: Value(data=50.12010426466716, grad=1, label=)


 62%|██████▎   | 125/200 [03:55<02:18,  1.85s/it]

Epoch: 124, Loss: Value(data=50.006868878683804, grad=1, label=)


 63%|██████▎   | 126/200 [03:56<02:13,  1.80s/it]

Epoch: 125, Loss: Value(data=49.894735407530604, grad=1, label=)


 64%|██████▎   | 127/200 [03:58<02:18,  1.90s/it]

Epoch: 126, Loss: Value(data=49.78342416301779, grad=1, label=)


 64%|██████▍   | 128/200 [04:00<02:13,  1.86s/it]

Epoch: 127, Loss: Value(data=49.67340503891684, grad=1, label=)


 64%|██████▍   | 129/200 [04:02<02:09,  1.82s/it]

Epoch: 128, Loss: Value(data=49.56354503653388, grad=1, label=)


 65%|██████▌   | 130/200 [04:04<02:04,  1.78s/it]

Epoch: 129, Loss: Value(data=49.454614535280456, grad=1, label=)


 66%|██████▌   | 131/200 [04:06<02:09,  1.87s/it]

Epoch: 130, Loss: Value(data=49.347157116985024, grad=1, label=)


 66%|██████▌   | 132/200 [04:08<02:06,  1.86s/it]

Epoch: 131, Loss: Value(data=49.24017737075563, grad=1, label=)


 66%|██████▋   | 133/200 [04:09<02:01,  1.82s/it]

Epoch: 132, Loss: Value(data=49.133345918126324, grad=1, label=)


 67%|██████▋   | 134/200 [04:11<02:06,  1.92s/it]

Epoch: 133, Loss: Value(data=49.02695822419869, grad=1, label=)


 68%|██████▊   | 135/200 [04:13<02:01,  1.87s/it]

Epoch: 134, Loss: Value(data=48.92127566970562, grad=1, label=)


 68%|██████▊   | 136/200 [04:15<01:57,  1.83s/it]

Epoch: 135, Loss: Value(data=48.81686148078814, grad=1, label=)


 68%|██████▊   | 137/200 [04:17<01:52,  1.78s/it]

Epoch: 136, Loss: Value(data=48.713780876281064, grad=1, label=)


 69%|██████▉   | 138/200 [04:19<01:56,  1.87s/it]

Epoch: 137, Loss: Value(data=48.61416921038205, grad=1, label=)


 70%|██████▉   | 139/200 [04:20<01:52,  1.84s/it]

Epoch: 138, Loss: Value(data=48.518885944200925, grad=1, label=)


 70%|███████   | 140/200 [04:22<01:48,  1.80s/it]

Epoch: 139, Loss: Value(data=48.42630143928306, grad=1, label=)


 70%|███████   | 141/200 [04:25<02:06,  2.14s/it]

Epoch: 140, Loss: Value(data=48.33420373153363, grad=1, label=)


 71%|███████   | 142/200 [04:27<01:57,  2.03s/it]

Epoch: 141, Loss: Value(data=48.24467209642801, grad=1, label=)


 72%|███████▏  | 143/200 [04:29<01:50,  1.94s/it]

Epoch: 142, Loss: Value(data=48.15799031244315, grad=1, label=)


 72%|███████▏  | 144/200 [04:31<01:53,  2.02s/it]

Epoch: 143, Loss: Value(data=48.072662564628466, grad=1, label=)


 72%|███████▎  | 145/200 [04:33<01:48,  1.97s/it]

Epoch: 144, Loss: Value(data=47.98866332754976, grad=1, label=)


 73%|███████▎  | 146/200 [04:34<01:43,  1.92s/it]

Epoch: 145, Loss: Value(data=47.9052703066688, grad=1, label=)


 74%|███████▎  | 147/200 [04:36<01:39,  1.88s/it]

Epoch: 146, Loss: Value(data=47.822315481194764, grad=1, label=)


 74%|███████▍  | 148/200 [04:38<01:42,  1.97s/it]

Epoch: 147, Loss: Value(data=47.7411599783311, grad=1, label=)


 74%|███████▍  | 149/200 [04:40<01:39,  1.94s/it]

Epoch: 148, Loss: Value(data=47.65984313392209, grad=1, label=)


 75%|███████▌  | 150/200 [04:42<01:37,  1.94s/it]

Epoch: 149, Loss: Value(data=47.580365992184, grad=1, label=)


 76%|███████▌  | 151/200 [04:45<01:42,  2.09s/it]

Epoch: 150, Loss: Value(data=47.50169487750648, grad=1, label=)


 76%|███████▌  | 152/200 [04:46<01:35,  1.99s/it]

Epoch: 151, Loss: Value(data=47.42298924328871, grad=1, label=)


 76%|███████▋  | 153/200 [04:48<01:30,  1.92s/it]

Epoch: 152, Loss: Value(data=47.34434774441961, grad=1, label=)


 77%|███████▋  | 154/200 [04:50<01:31,  1.99s/it]

Epoch: 153, Loss: Value(data=47.26554545953186, grad=1, label=)


 78%|███████▊  | 155/200 [04:52<01:26,  1.93s/it]

Epoch: 154, Loss: Value(data=47.1873978335342, grad=1, label=)


 78%|███████▊  | 156/200 [04:54<01:24,  1.91s/it]

Epoch: 155, Loss: Value(data=47.10982199026216, grad=1, label=)


 78%|███████▊  | 157/200 [04:56<01:22,  1.92s/it]

Epoch: 156, Loss: Value(data=47.032185255248336, grad=1, label=)


 79%|███████▉  | 158/200 [04:58<01:24,  2.00s/it]

Epoch: 157, Loss: Value(data=46.95409224058627, grad=1, label=)


 80%|███████▉  | 159/200 [05:00<01:19,  1.94s/it]

Epoch: 158, Loss: Value(data=46.87722111910157, grad=1, label=)


 80%|████████  | 160/200 [05:02<01:15,  1.88s/it]

Epoch: 159, Loss: Value(data=46.801972584167856, grad=1, label=)


 80%|████████  | 161/200 [05:04<01:17,  1.99s/it]

Epoch: 160, Loss: Value(data=46.72835081492042, grad=1, label=)


 81%|████████  | 162/200 [05:06<01:16,  2.02s/it]

Epoch: 161, Loss: Value(data=46.65715118097404, grad=1, label=)


 82%|████████▏ | 163/200 [05:08<01:11,  1.93s/it]

Epoch: 162, Loss: Value(data=46.588305734262875, grad=1, label=)


 82%|████████▏ | 164/200 [05:09<01:06,  1.86s/it]

Epoch: 163, Loss: Value(data=46.520974253867124, grad=1, label=)


 82%|████████▎ | 165/200 [05:12<01:07,  1.93s/it]

Epoch: 164, Loss: Value(data=46.45515470280491, grad=1, label=)


 83%|████████▎ | 166/200 [05:13<01:04,  1.91s/it]

Epoch: 165, Loss: Value(data=46.390436964867426, grad=1, label=)


 84%|████████▎ | 167/200 [05:15<01:01,  1.86s/it]

Epoch: 166, Loss: Value(data=46.32739875008842, grad=1, label=)


 84%|████████▍ | 168/200 [05:17<01:02,  1.94s/it]

Epoch: 167, Loss: Value(data=46.26548685148579, grad=1, label=)


 84%|████████▍ | 169/200 [05:19<00:58,  1.89s/it]

Epoch: 168, Loss: Value(data=46.20489385877579, grad=1, label=)


 85%|████████▌ | 170/200 [05:21<00:55,  1.84s/it]

Epoch: 169, Loss: Value(data=46.14503426758261, grad=1, label=)


 86%|████████▌ | 171/200 [05:23<00:56,  1.94s/it]

Epoch: 170, Loss: Value(data=46.08610794442938, grad=1, label=)


 86%|████████▌ | 172/200 [05:25<00:53,  1.89s/it]

Epoch: 171, Loss: Value(data=46.028148178754385, grad=1, label=)


 86%|████████▋ | 173/200 [05:27<00:50,  1.86s/it]

Epoch: 172, Loss: Value(data=45.97089280344248, grad=1, label=)


 87%|████████▋ | 174/200 [05:29<00:52,  2.02s/it]

Epoch: 173, Loss: Value(data=45.91510418528889, grad=1, label=)


 88%|████████▊ | 175/200 [05:31<00:51,  2.06s/it]

Epoch: 174, Loss: Value(data=45.85990657647471, grad=1, label=)


 88%|████████▊ | 176/200 [05:33<00:47,  1.96s/it]

Epoch: 175, Loss: Value(data=45.805601312914234, grad=1, label=)


 88%|████████▊ | 177/200 [05:35<00:44,  1.95s/it]

Epoch: 176, Loss: Value(data=45.7516731874944, grad=1, label=)


 89%|████████▉ | 178/200 [05:37<00:44,  2.01s/it]

Epoch: 177, Loss: Value(data=45.69806120484533, grad=1, label=)


 90%|████████▉ | 179/200 [05:39<00:40,  1.95s/it]

Epoch: 178, Loss: Value(data=45.64413292645139, grad=1, label=)


 90%|█████████ | 180/200 [05:41<00:38,  1.93s/it]

Epoch: 179, Loss: Value(data=45.589880317600546, grad=1, label=)


 90%|█████████ | 181/200 [05:43<00:38,  2.01s/it]

Epoch: 180, Loss: Value(data=45.53633391685518, grad=1, label=)


 91%|█████████ | 182/200 [05:45<00:35,  1.96s/it]

Epoch: 181, Loss: Value(data=45.4833083962015, grad=1, label=)


 92%|█████████▏| 183/200 [05:46<00:32,  1.90s/it]

Epoch: 182, Loss: Value(data=45.43171875676016, grad=1, label=)


 92%|█████████▏| 184/200 [05:48<00:29,  1.85s/it]

Epoch: 183, Loss: Value(data=45.38108169223573, grad=1, label=)


 92%|█████████▎| 185/200 [05:51<00:32,  2.14s/it]

Epoch: 184, Loss: Value(data=45.33077182264746, grad=1, label=)


 93%|█████████▎| 186/200 [05:53<00:28,  2.05s/it]

Epoch: 185, Loss: Value(data=45.281150883669945, grad=1, label=)


 94%|█████████▎| 187/200 [05:55<00:25,  1.96s/it]

Epoch: 186, Loss: Value(data=45.232320503422166, grad=1, label=)


 94%|█████████▍| 188/200 [05:57<00:24,  2.05s/it]

Epoch: 187, Loss: Value(data=45.18481842793399, grad=1, label=)


 94%|█████████▍| 189/200 [05:59<00:21,  1.97s/it]

Epoch: 188, Loss: Value(data=45.13828792515729, grad=1, label=)


 95%|█████████▌| 190/200 [06:00<00:19,  1.91s/it]

Epoch: 189, Loss: Value(data=45.09130387231922, grad=1, label=)


 96%|█████████▌| 191/200 [06:02<00:16,  1.85s/it]

Epoch: 190, Loss: Value(data=45.0440731730156, grad=1, label=)


 96%|█████████▌| 192/200 [06:04<00:15,  1.93s/it]

Epoch: 191, Loss: Value(data=44.99698606332082, grad=1, label=)


 96%|█████████▋| 193/200 [06:06<00:13,  1.88s/it]

Epoch: 192, Loss: Value(data=44.95056760856839, grad=1, label=)


 97%|█████████▋| 194/200 [06:08<00:10,  1.82s/it]

Epoch: 193, Loss: Value(data=44.90519722147429, grad=1, label=)


 98%|█████████▊| 195/200 [06:10<00:09,  1.91s/it]

Epoch: 194, Loss: Value(data=44.86233108601493, grad=1, label=)


 98%|█████████▊| 196/200 [06:11<00:07,  1.87s/it]

Epoch: 195, Loss: Value(data=44.82071922324997, grad=1, label=)


 98%|█████████▊| 197/200 [06:13<00:05,  1.82s/it]

Epoch: 196, Loss: Value(data=44.7799622390086, grad=1, label=)


 99%|█████████▉| 198/200 [06:15<00:03,  1.91s/it]

Epoch: 197, Loss: Value(data=44.74061167827466, grad=1, label=)


100%|█████████▉| 199/200 [06:17<00:01,  1.87s/it]

Epoch: 198, Loss: Value(data=44.70260036064233, grad=1, label=)


100%|██████████| 200/200 [06:19<00:00,  1.90s/it]

Epoch: 199, Loss: Value(data=44.66549774757104, grad=1, label=)





In [17]:
all_res = list()
X_test.reset_index(drop=True, inplace=True)
y_test.reset_index(drop=True, inplace=True)
for counter, x in X_test.iterrows():
    pred = classifier(x.values).data
    actual = y_test.iloc[counter, :].TARGET

    res = {
        'PREDICTED': pred,
        'ACTUAL': actual
    }

    all_res.append(res)

In [18]:
res_df = pd.DataFrame.from_records(all_res)
mse = mean_squared_error(res_df.ACTUAL.values, res_df.PREDICTED.values)
mae = mean_absolute_error(res_df.ACTUAL.values, res_df.PREDICTED.values)
print(f'MSE: {mse}, MAE: {mae}')

MSE: 3624.4397875038653, MAE: 50.725837505346945
