In [1]:
from torchic import Sigmoid, TanH, ReLU, LeakyReLU, Linear, Softmax
from torchic import CrossEntropyLoss, MeanSquareError
from torchic import SGD, Adam
from torchic import Activation, Cost, Optimizer
from torchic import Layer
from torchic import Dataloader

from utils.datasets import Datasets
from utils.visualization import (
    plot_train_val_curve, 
    plot_activations_histogram,
    plot_gradients_histogram,
    plot_weights_histogram
)

import numpy as np

In [2]:
class MLP:
    def __init__(self):
        self.layers = [
            Layer(input_size=13, output_size=64, activation=Sigmoid(), initialization='normal'),
            Layer(input_size=64, output_size=64, activation=Sigmoid(), initialization='normal'),
            Layer(input_size=64, output_size=64, activation=Sigmoid(), initialization='normal'),
            Layer(input_size=64, output_size=64, activation=Sigmoid(), initialization='normal'),
            Layer(input_size=64, output_size=1, activation=Linear(), initialization='normal'),
        ]
        self.criterion = MeanSquareError()
        self.optimizer = SGD(parameters=self.layers, lr=1e-3)
        
        
    def forward(self, X):
        z = X
        for layer in self.layers:
            z = layer(z)
        return z
    
    def backward(self):
        dJ_dA = self.criterion.backward()
        for layer in reversed(self.layers):
            dJ_dA = layer.backward(dJ_dA)
            
    def train(self, dataloader: Dataloader):
        epoch_losses = []
        for batch, (X, y_true) in dataloader:
            self.optimizer.zero_grad()
            y_pred = self.forward(X)
            print(y_pred)
            loss = self.criterion(y_pred, y_true)
            epoch_losses.append(loss)
            self.backward()
            self.optimizer.step()
            if batch % 100 == 0:
                batch_size = X.shape[0]
                total_samples = dataloader.num_samples
                current = (batch + 1) * batch_size
                print(f'Loss: {loss:>7f}, [{current:>5d}/{total_samples}]')
        avg_loss = np.mean(epoch_losses) 
        return avg_loss
    
    def val(self, dataloader: Dataloader):
        epoch_losses = []
        predictions = []
        targets = []
        
        for _, (X, y_true) in dataloader:
            y_pred = self.forward(X)
            epoch_losses.append(self.criterion(y_pred, y_true))
            predictions.extend(y_pred.flatten())
            targets.extend(y_true.flatten())
        
        predictions = np.array(predictions)
        targets = np.array(targets)
        
        avg_loss = np.mean(epoch_losses)
        mse = np.mean((predictions - targets)**2)
        rmse = np.sqrt(mse)
        mae = np.mean(np.abs(predictions - targets))
        
        print(f'Validation Error: \n MSE: {mse:>0.4f}, RMSE: {rmse:>0.4f}, MAE: {mae:>0.4f}, Avg loss: {avg_loss:>0.8f}\n')
        
        return avg_loss
    
    def test(self, dataloader: Dataloader):
        epoch_losses = []
        predictions = []
        targets = []
        
        for _, (X, y_true) in dataloader:
            y_pred = self.forward(X)
            epoch_losses.append(self.criterion(y_pred, y_true))
            predictions.extend(y_pred.flatten())
            targets.extend(y_true.flatten())
        
        predictions = np.array(predictions)
        targets = np.array(targets)
        
        avg_loss = np.mean(epoch_losses)
        mse = np.mean((predictions - targets)**2)
        rmse = np.sqrt(mse)
        mae = np.mean(np.abs(predictions - targets))
        
        print(f'Test Error: \n MSE: {mse:>0.4f}, RMSE: {rmse:>0.4f}, MAE: {mae:>0.4f}, Avg loss: {avg_loss:>0.8f}\n')
        
    def fit(self, train_dataloader: Dataloader, val_dataloader: Dataloader, test_dataloader: Dataloader, epochs: int):
        train_loss_per_epoch = []
        val_loss_per_epoch = []
        
        print('Training MLP...\n')
        for epoch in range(epochs):
            print(f'Epoch: {epoch+1}')
            train_loss_per_epoch.append(self.train(train_dataloader))
            val_loss_per_epoch.append(self.val(val_dataloader))
            plot_activations_histogram(self.layers)
            plot_gradients_histogram(self.layers)
        plot_train_val_curve(train_loss_per_epoch, val_loss_per_epoch)
        self.test(test_dataloader)
        
    def get_topology(self) -> str:
        architecture = ''
        for i, layer in enumerate(self.layers):
            architecture += (
                f'Layer: {i+1}\n'
                f'    Input: {layer.input_size} | Output: {layer.output_size}\n'
                f'    Theta.shape: {layer.theta.shape} | Bias.shape: {layer.bias.shape}\n'
                f'    Activation: {layer.activation}\n\n'
            )
        architecture += f'Loss function: {self.criterion} | Optimizer: {self.optimizer}'
        return architecture   

    def save_model(self, file_path: str):
        parameters_to_save = {}
        for i, layer in enumerate(self.layers):
            parameters_to_save[f'layer_{i}_theta'] = layer.theta
            parameters_to_save[f'layer_{i}_bias'] = layer.bias
        np.savez(file_path, **parameters_to_save)
        print('Model saved.')
    
    def load_model(self, file_path):
        parameters = np.load(file_path)
        for i, layer in enumerate(self.layers):
            layer.theta = parameters[f'layer_{i}_theta']
            layer.bias = parameters[f'layer_{i}_bias']
        print(f"Model weights loaded from {file_path}")
    
    def __str__(self) -> str:
        return self.get_topology()

In [3]:
datasets = Datasets()
batch_size = 32

X_train, X_val, X_test, y_train, y_val, y_test = datasets('boston_housing')

train_dataloader = Dataloader(X_train, y_train, batch_size=batch_size)
val_dataloader = Dataloader(X_val, y_val, batch_size=batch_size)
test_dataloader = Dataloader(X_test, y_test, batch_size=batch_size)

In [4]:
mlp = MLP()
print(mlp)

Layer: 1
    Input: 13 | Output: 64
    Theta.shape: (64, 13) | Bias.shape: (1, 64)
    Activation: Sigmoid

Layer: 2
    Input: 64 | Output: 64
    Theta.shape: (64, 64) | Bias.shape: (1, 64)
    Activation: Sigmoid

Layer: 3
    Input: 64 | Output: 64
    Theta.shape: (64, 64) | Bias.shape: (1, 64)
    Activation: Sigmoid

Layer: 4
    Input: 64 | Output: 64
    Theta.shape: (64, 64) | Bias.shape: (1, 64)
    Activation: Sigmoid

Layer: 5
    Input: 64 | Output: 1
    Theta.shape: (1, 64) | Bias.shape: (1, 1)
    Activation: Linear

Loss function: MSE | Optimizer: SGD(learning_rate=0.001)


In [5]:
plot_weights_histogram(mlp.layers)

In [6]:
mlp.fit(
    train_dataloader=train_dataloader, 
    val_dataloader=val_dataloader,
    test_dataloader=test_dataloader,
    epochs=15
)

Training MLP...

Epoch: 1
[[-0.27294087]
 [-0.27115302]
 [-0.27097388]
 [-0.27767454]
 [-0.27809174]
 [-0.27014616]
 [-0.27475823]
 [-0.27133393]
 [-0.27296912]
 [-0.28300603]
 [-0.27064434]
 [-0.27322588]
 [-0.27050386]
 [-0.27152933]
 [-0.27483706]
 [-0.27858738]
 [-0.2796222 ]
 [-0.27364776]
 [-0.27569417]
 [-0.27415174]
 [-0.27937695]
 [-0.2777113 ]
 [-0.27273278]
 [-0.27123407]
 [-0.2748626 ]
 [-0.27688066]
 [-0.27142951]
 [-0.27037263]
 [-0.28045416]
 [-0.27075096]
 [-0.27426237]
 [-0.27291935]]
Loss: 845.982564, [   32/303]
[[0.25390312]
 [0.25736017]
 [0.2488198 ]
 [0.25536447]
 [0.25825912]
 [0.25349733]
 [0.25317383]
 [0.25630891]
 [0.25094741]
 [0.25656008]
 [0.25639816]
 [0.25662412]
 [0.24897117]
 [0.25582122]
 [0.25751964]
 [0.25781037]
 [0.25696799]
 [0.25221483]
 [0.24986485]
 [0.2587654 ]
 [0.25771606]
 [0.25798275]
 [0.24981708]
 [0.25014445]
 [0.25078365]
 [0.2539763 ]
 [0.24675828]
 [0.2585811 ]
 [0.25453211]
 [0.2537019 ]
 [0.25705731]
 [0.25341794]]
[[0.65361244]


Epoch: 2
[[3.9640208 ]
 [3.96654861]
 [3.96622316]
 [3.95979825]
 [3.9589353 ]
 [3.96673548]
 [3.96235497]
 [3.96628114]
 [3.96478788]
 [3.95113574]
 [3.96650163]
 [3.96352576]
 [3.96677271]
 [3.96551331]
 [3.96015648]
 [3.95529764]
 [3.9535471 ]
 [3.96329539]
 [3.9610135 ]
 [3.96302321]
 [3.953466  ]
 [3.9561547 ]
 [3.96465452]
 [3.96618536]
 [3.96176331]
 [3.95850661]
 [3.96584993]
 [3.96706601]
 [3.95227372]
 [3.96650644]
 [3.96251766]
 [3.96474622]]
Loss: 635.103218, [   32/303]
[[4.43910883]
 [4.44353137]
 [4.43160388]
 [4.44218622]
 [4.44487566]
 [4.44011782]
 [4.43847428]
 [4.44280741]
 [4.4335388 ]
 [4.44308068]
 [4.44269495]
 [4.44213493]
 [4.43081473]
 [4.44160784]
 [4.44407717]
 [4.4440736 ]
 [4.4430042 ]
 [4.43846644]
 [4.43243351]
 [4.44475456]
 [4.44470751]
 [4.44412126]
 [4.43198066]
 [4.43555254]
 [4.43328763]
 [4.4382405 ]
 [4.42949458]
 [4.44461757]
 [4.44064386]
 [4.43804641]
 [4.44352541]
 [4.437716  ]]
[[4.78516443]
 [4.78512291]
 [4.78306827]
 [4.78796013]
 [4.782

Epoch: 3
[[7.79997868]
 [7.80309151]
 [7.80230034]
 [7.79580913]
 [7.79499287]
 [7.80234993]
 [7.7984653 ]
 [7.80240184]
 [7.80130008]
 [7.78401573]
 [7.80247378]
 [7.79910749]
 [7.80278597]
 [7.80134183]
 [7.7936662 ]
 [7.7884833 ]
 [7.78556487]
 [7.79911471]
 [7.7968712 ]
 [7.79918341]
 [7.78522971]
 [7.78878881]
 [7.8010068 ]
 [7.80239459]
 [7.79689193]
 [7.79248078]
 [7.80206574]
 [7.80326935]
 [7.78389142]
 [7.80267176]
 [7.79814187]
 [7.801418  ]]
Loss: 475.147294, [   32/303]
[[8.25760164]
 [8.2627212 ]
 [8.24782803]
 [8.26185702]
 [8.26453337]
 [8.25934213]
 [8.25659844]
 [8.26243697]
 [8.249049  ]
 [8.26285942]
 [8.26193276]
 [8.26079556]
 [8.24565053]
 [8.2603856 ]
 [8.26361753]
 [8.26333489]
 [8.26203198]
 [8.25733938]
 [8.24788699]
 [8.26362737]
 [8.26440933]
 [8.26316326]
 [8.2470788 ]
 [8.25379556]
 [8.24871157]
 [8.25545156]
 [8.24506406]
 [8.26353242]
 [8.25934161]
 [8.25480319]
 [8.263141  ]
 [8.25494821]]
[[8.56560679]
 [8.5658609 ]
 [8.56385711]
 [8.56866834]
 [8.562

Epoch: 4
[[11.30573701]
 [11.30923725]
 [11.30801623]
 [11.3011611 ]
 [11.30107545]
 [11.30750813]
 [11.30438525]
 [11.30795546]
 [11.3073884 ]
 [11.28660181]
 [11.30814281]
 [11.30434104]
 [11.30829964]
 [11.30676588]
 [11.29671784]
 [11.29200192]
 [11.28724433]
 [11.30471726]
 [11.30261519]
 [11.30513866]
 [11.2867025 ]
 [11.29118535]
 [11.30708322]
 [11.30813   ]
 [11.30144967]
 [11.29599146]
 [11.30792686]
 [11.30898373]
 [11.28526495]
 [11.30848694]
 [11.30346454]
 [11.30787394]]
Loss: 354.692348, [   32/303]
[[11.73896238]
 [11.74449818]
 [11.72748916]
 [11.74408649]
 [11.74680553]
 [11.74087124]
 [11.73724754]
 [11.74478211]
 [11.72732181]
 [11.74539181]
 [11.7437723 ]
 [11.74214552]
 [11.72319395]
 [11.74184439]
 [11.74575101]
 [11.7452471 ]
 [11.74367   ]
 [11.73852077]
 [11.72608292]
 [11.74502655]
 [11.7464825 ]
 [11.7447086 ]
 [11.7250008 ]
 [11.7345163 ]
 [11.7269449 ]
 [11.73527059]
 [11.72331088]
 [11.74495878]
 [11.74035201]
 [11.73385495]
 [11.74548197]
 [11.73478414]]

Epoch: 5
[[14.34458165]
 [14.34826467]
 [14.34666559]
 [14.33929551]
 [14.34046874]
 [14.34560039]
 [14.34339279]
 [14.34640253]
 [14.34638116]
 [14.32246038]
 [14.34684977]
 [14.34260482]
 [14.34668656]
 [14.34515026]
 [14.33296345]
 [14.3292135 ]
 [14.32220837]
 [14.34345075]
 [14.34149198]
 [14.34416563]
 [14.32150394]
 [14.32700258]
 [14.34613576]
 [14.34672361]
 [14.33901917]
 [14.33260699]
 [14.34669248]
 [14.34753978]
 [14.32006516]
 [14.34722393]
 [14.3418715 ]
 [14.34734165]]
Loss: 270.159913, [   32/303]
[[14.73771753]
 [14.74341829]
 [14.72521597]
 [14.7434824 ]
 [14.74620124]
 [14.73943477]
 [14.73515352]
 [14.74431828]
 [14.72324756]
 [14.74509345]
 [14.7428084 ]
 [14.7407476 ]
 [14.71831595]
 [14.74059678]
 [14.74502176]
 [14.74436254]
 [14.7425071 ]
 [14.73673701]
 [14.72193151]
 [14.74356951]
 [14.74557888]
 [14.74337168]
 [14.72068469]
 [14.7323871 ]
 [14.72289829]
 [14.73246201]
 [14.71910444]
 [14.74352016]
 [14.73849465]
 [14.73021884]
 [14.74502673]
 [14.73199502]]

Epoch: 6
[[16.80230953]
 [16.80601819]
 [16.80410941]
 [16.79620454]
 [16.79885708]
 [16.80257308]
 [16.80126872]
 [16.80368785]
 [16.80412909]
 [16.77759234]
 [16.80444727]
 [16.79976939]
 [16.80389304]
 [16.80239609]
 [16.78844641]
 [16.78581237]
 [16.77653589]
 [16.80112804]
 [16.79927084]
 [16.80204935]
 [16.77570378]
 [16.78223902]
 [16.8039549 ]
 [16.80410273]
 [16.79559308]
 [16.78833445]
 [16.80421035]
 [16.80485637]
 [16.77435836]
 [16.80472752]
 [16.79921574]
 [16.8055774 ]]
Loss: 215.294266, [   32/303]
[[17.14679925]
 [17.15249677]
 [17.13379502]
 [17.15302712]
 [17.15570985]
 [17.1481626 ]
 [17.14340837]
 [17.15396962]
 [17.12995374]
 [17.15485543]
 [17.15203576]
 [17.14960332]
 [17.1242061 ]
 [17.14961799]
 [17.15441937]
 [17.15365128]
 [17.15154354]
 [17.14513753]
 [17.12857543]
 [17.15233793]
 [17.1547837 ]
 [17.15223977]
 [17.12726114]
 [17.14050999]
 [17.12968632]
 [17.14012801]
 [17.12560882]
 [17.15229355]
 [17.14691586]
 [17.13715541]
 [17.1546941 ]
 [17.13968702]]

Epoch: 7
[[18.66819154]
 [18.67183162]
 [18.6696851 ]
 [18.66128526]
 [18.66536958]
 [18.66781464]
 [18.66726387]
 [18.66915983]
 [18.66994394]
 [18.64131269]
 [18.67024575]
 [18.66512588]
 [18.66931585]
 [18.66784943]
 [18.65248861]
 [18.65086956]
 [18.63957278]
 [18.66700642]
 [18.66520142]
 [18.66804767]
 [18.6386321 ]
 [18.64613776]
 [18.66982087]
 [18.66966522]
 [18.66047686]
 [18.65250132]
 [18.66981861]
 [18.6703217 ]
 [18.63744212]
 [18.67033105]
 [18.66476314]
 [18.67184029]]
Loss: 181.702904, [   32/303]
[[18.96543072]
 [18.97104373]
 [18.95221384]
 [18.97197491]
 [18.97464867]
 [18.96640033]
 [18.96130324]
 [18.97299396]
 [18.94665724]
 [18.97392603]
 [18.97072403]
 [18.96799052]
 [18.94015193]
 [18.96814002]
 [18.97323936]
 [18.97238424]
 [18.97005291]
 [18.96309074]
 [18.94523701]
 [18.97069448]
 [18.97344236]
 [18.97067904]
 [18.94391931]
 [18.95822293]
 [18.94650036]
 [18.95753757]
 [18.94210339]
 [18.97062941]
 [18.96491998]
 [18.95397531]
 [18.9737326 ]
 [18.95713022]]

Epoch: 8
[[20.01638669]
 [20.0199125 ]
 [20.01759196]
 [20.00874576]
 [20.01407689]
 [20.01553371]
 [20.01552654]
 [20.01699941]
 [20.01801154]
 [19.98771906]
 [20.01841904]
 [20.01282487]
 [20.01717494]
 [20.01569935]
 [19.99918612]
 [19.9983721 ]
 [19.98540847]
 [20.01522463]
 [20.01343094]
 [20.01631428]
 [19.98436505]
 [19.99272898]
 [20.01791068]
 [20.01763979]
 [20.00778844]
 [19.99922872]
 [20.017724  ]
 [20.01816143]
 [19.98336073]
 [20.01823777]
 [20.01264681]
 [20.02030114]]
Loss: 161.761157, [   32/303]
[[20.27390295]
 [20.27940851]
 [20.26057935]
 [20.28064417]
 [20.28338452]
 [20.27449266]
 [20.26913509]
 [20.28173018]
 [20.25354985]
 [20.2826503 ]
 [20.27919533]
 [20.27622881]
 [20.24638082]
 [20.27645489]
 [20.2818309 ]
 [20.28089296]
 [20.27835763]
 [20.27095076]
 [20.25210651]
 [20.27900554]
 [20.28191707]
 [20.27905846]
 [20.25081906]
 [20.26586324]
 [20.2535119 ]
 [20.26495596]
 [20.2488229 ]
 [20.27888709]
 [20.27280611]
 [20.2609242 ]
 [20.28247485]
 [20.26458905]]

Epoch: 9
[[20.95650979]
 [20.9599039 ]
 [20.9574596 ]
 [20.948206  ]
 [20.95457042]
 [20.95534134]
 [20.95567269]
 [20.95681739]
 [20.95796526]
 [20.92633452]
 [20.95857863]
 [20.95246887]
 [20.95709229]
 [20.95555937]
 [20.93806355]
 [20.93783557]
 [20.92354267]
 [20.95538261]
 [20.95357118]
 [20.95646756]
 [20.92239527]
 [20.9315009 ]
 [20.95786027]
 [20.9576568 ]
 [20.94709609]
 [20.9380685 ]
 [20.9575651 ]
 [20.9580099 ]
 [20.92158671]
 [20.95808398]
 [20.9524621 ]
 [20.96059785]]
Loss: 150.004324, [   32/303]
[[21.18357413]
 [21.18897755]
 [21.17015145]
 [21.19042063]
 [21.19331671]
 [21.18381305]
 [21.17824442]
 [21.19157815]
 [21.1618841 ]
 [21.19244064]
 [21.18882825]
 [21.18568601]
 [21.1541473 ]
 [21.18592598]
 [21.19158805]
 [21.19056298]
 [21.18783453]
 [21.1800876 ]
 [21.16043239]
 [21.18864489]
 [21.19159749]
 [21.18875456]
 [21.15918965]
 [21.17479212]
 [21.16196285]
 [21.1736977 ]
 [21.15703452]
 [21.18844138]
 [21.18192072]
 [21.16928311]
 [21.19231898]
 [21.17337642]]

Epoch: 10
[[21.59630311]
 [21.59956153]
 [21.59703033]
 [21.58739658]
 [21.59460707]
 [21.59495121]
 [21.59545219]
 [21.59634468]
 [21.59756072]
 [21.56481861]
 [21.59845541]
 [21.59179194]
 [21.59679055]
 [21.59515696]
 [21.57678527]
 [21.57695989]
 [21.56161095]
 [21.59521445]
 [21.5933661 ]
 [21.5962578 ]
 [21.56035597]
 [21.570104  ]
 [21.59743128]
 [21.59744477]
 [21.58611109]
 [21.57670706]
 [21.59709195]
 [21.59760291]
 [21.55974358]
 [21.59761823]
 [21.59194246]
 [21.60049675]]
Loss: 143.012378, [   32/303]
[[21.80134758]
 [21.80666355]
 [21.78779603]
 [21.80823264]
 [21.81136425]
 [21.80126411]
 [21.79551367]
 [21.8094715 ]
 [21.778482  ]
 [21.81024341]
 [21.80653951]
 [21.80326551]
 [21.77025841]
 [21.80346539]
 [21.80943297]
 [21.808314  ]
 [21.80539804]
 [21.79739446]
 [21.77703252]
 [21.80649951]
 [21.80939638]
 [21.80665716]
 [21.77583923]
 [21.79189819]
 [21.77867148]
 [21.79062794]
 [21.77355731]
 [21.80618496]
 [21.79915603]
 [21.78588872]
 [21.81019962]
 [21.79035591]

Epoch: 11
[[22.02469977]
 [22.02782382]
 [22.02523152]
 [22.01523582]
 [22.0231475 ]
 [22.02326317]
 [22.02380725]
 [22.02450338]
 [22.02574059]
 [21.99204153]
 [22.02697014]
 [22.01972298]
 [22.0251763 ]
 [22.02340816]
 [22.00422722]
 [22.00466631]
 [21.98846401]
 [22.02364969]
 [22.02175213]
 [22.0246267 ]
 [21.98709781]
 [21.99741201]
 [22.02557227]
 [22.02591395]
 [22.01374953]
 [22.00403607]
 [22.02523775]
 [22.02585811]
 [21.98667841]
 [22.02577282]
 [22.02001898]
 [22.02895118]]
Loss: 138.786858, [   32/303]
[[22.21439409]
 [22.21963777]
 [22.20067703]
 [22.22127119]
 [22.22470058]
 [22.21401105]
 [22.20809628]
 [22.22260132]
 [22.19045859]
 [22.22326005]
 [22.2195086 ]
 [22.2161346 ]
 [22.18180999]
 [22.21625304]
 [22.22254527]
 [22.2213261 ]
 [22.21822553]
 [22.2100264 ]
 [22.18901857]
 [22.21971426]
 [22.22248447]
 [22.21991354]
 [22.18787534]
 [22.20433393]
 [22.19075255]
 [22.20288947]
 [22.18549674]
 [22.2192693 ]
 [22.21167572]
 [22.19786371]
 [22.22330982]
 [22.20266909]

Epoch: 12
[[22.30852224]
 [22.31151473]
 [22.3088789 ]
 [22.29853319]
 [22.30704126]
 [22.30707089]
 [22.30756837]
 [22.30810731]
 [22.30933396]
 [22.27478111]
 [22.31093489]
 [22.30308247]
 [22.30904798]
 [22.30712074]
 [22.28717247]
 [22.28777786]
 [22.27086527]
 [22.30750958]
 [22.30555503]
 [22.30840401]
 [22.26938496]
 [22.28021021]
 [22.30911692]
 [22.30986502]
 [22.296825  ]
 [22.28684916]
 [22.30882224]
 [22.30958222]
 [22.2691547 ]
 [22.30936696]
 [22.30351493]
 [22.31279884]]
Loss: 136.188255, [   32/303]
[[22.48775273]
 [22.49293662]
 [22.47383827]
 [22.49459033]
 [22.49836251]
 [22.48708691]
 [22.48101849]
 [22.49601982]
 [22.46281613]
 [22.49655035]
 [22.49278017]
 [22.48932826]
 [22.45378701]
 [22.48933558]
 [22.49596782]
 [22.49464352]
 [22.49135993]
 [22.48300786]
 [22.4613904 ]
 [22.49330371]
 [22.49589769]
 [22.49354016]
 [22.46029639]
 [22.47712227]
 [22.46320936]
 [22.475502  ]
 [22.45784496]
 [22.49271476]
 [22.48451463]
 [22.47021439]
 [22.49670374]
 [22.47533481]

Epoch: 13
[[22.49529115]
 [22.49815487]
 [22.49548725]
 [22.48479924]
 [22.49383088]
 [22.49387297]
 [22.49426234]
 [22.49467082]
 [22.49586581]
 [22.46052724]
 [22.4978625 ]
 [22.48938999]
 [22.49590749]
 [22.49380412]
 [22.47311508]
 [22.47381936]
 [22.45629439]
 [22.49431458]
 [22.49229811]
 [22.49511586]
 [22.45469787]
 [22.46599582]
 [22.49559351]
 [22.4968013 ]
 [22.4828534 ]
 [22.47264724]
 [22.49536283]
 [22.49628294]
 [22.45465339]
 [22.49591773]
 [22.48995241]
 [22.49957115]]
Loss: 134.564991, [   32/303]
[[22.66747904]
 [22.67261251]
 [22.65334279]
 [22.67425617]
 [22.67840235]
 [22.66654225]
 [22.66032675]
 [22.67579101]
 [22.64158606]
 [22.67618349]
 [22.67241333]
 [22.66889849]
 [22.63220724]
 [22.66877421]
 [22.67575772]
 [22.67432473]
 [22.67085902]
 [22.66238287]
 [22.64017775]
 [22.67330409]
 [22.6756882 ]
 [22.67357468]
 [22.63913187]
 [22.65630615]
 [22.64207461]
 [22.6545079 ]
 [22.63662499]
 [22.67256195]
 [22.66372575]
 [22.64897439]
 [22.67644692]
 [22.65439486]

Epoch: 14
[[22.61768501]
 [22.62042221]
 [22.61773055]
 [22.60670508]
 [22.61621102]
 [22.61633222]
 [22.61657206]
 [22.616868  ]
 [22.61801746]
 [22.58193748]
 [22.62042604]
 [22.61132335]
 [22.61841987]
 [22.61612882]
 [22.59471572]
 [22.59547353]
 [22.57740184]
 [22.61674335]
 [22.6146618 ]
 [22.61744465]
 [22.57568763]
 [22.5874323 ]
 [22.61768558]
 [22.61938873]
 [22.60451044]
 [22.59409536]
 [22.61753524]
 [22.61862917]
 [22.57582638]
 [22.61810067]
 [22.61201124]
 [22.621954  ]]
Loss: 133.537974, [   32/303]
[[22.78516008]
 [22.7902498 ]
 [22.77078445]
 [22.79186293]
 [22.7964043 ]
 [22.78396032]
 [22.77760183]
 [22.79350737]
 [22.75833947]
 [22.79375537]
 [22.78999738]
 [22.78642961]
 [22.74863179]
 [22.78615985]
 [22.79350262]
 [22.79195843]
 [22.78831108]
 [22.77973001]
 [22.75695063]
 [22.79128851]
 [22.79344025]
 [22.79159119]
 [22.75595191]
 [22.77346341]
 [22.75892056]
 [22.77148522]
 [22.75340166]
 [22.79038702]
 [22.78089443]
 [22.76571607]
 [22.79413299]
 [22.77142698]

Epoch: 15
[[22.69770972]
 [22.70032206]
 [22.69761143]
 [22.68625138]
 [22.6961989 ]
 [22.69644375]
 [22.69650647]
 [22.69670174]
 [22.69779655]
 [22.66100351]
 [22.70062778]
 [22.69088791]
 [22.69858175]
 [22.69609523]
 [22.67396837]
 [22.67474953]
 [22.65617446]
 [22.69680207]
 [22.69465339]
 [22.69739893]
 [22.65434151]
 [22.66651572]
 [22.69740224]
 [22.69962454]
 [22.6838002 ]
 [22.67319024]
 [22.69734318]
 [22.69862006]
 [22.6546617 ]
 [22.69791938]
 [22.69169828]
 [22.70195824]]
Loss: 132.881615, [   32/303]
[[22.8620311 ]
 [22.86708171]
 [22.8474037 ]
 [22.86865065]
 [22.87360151]
 [22.86057363]
 [22.85407461]
 [22.87040763]
 [22.83430095]
 [22.87050693]
 [22.86676897]
 [22.86315485]
 [22.82427816]
 [22.86273029]
 [22.87043796]
 [22.86878081]
 [22.86495201]
 [22.85627853]
 [22.8329329 ]
 [22.86848262]
 [22.87038702]
 [22.86881605]
 [22.83198062]
 [22.84982272]
 [22.83497269]
 [22.84766305]
 [22.82939516]
 [22.86741768]
 [22.85725466]
 [22.84166473]
 [22.87100143]
 [22.84766004]

Test Error: 
 MSE: 87.5529, RMSE: 9.3570, MAE: 6.8759, Avg loss: 78.47124765



In [7]:
file_path = './models/boston_model.npz'
mlp.save_model(file_path)

Model saved.


In [8]:
mlp_2 = MLP()
mlp_2.test(test_dataloader)

Test Error: 
 MSE: 551.8305, RMSE: 23.4911, MAE: 21.5857, Avg loss: 498.21535283



In [9]:
mlp_2.load_model(file_path)
mlp_2.test(test_dataloader)

Model weights loaded from ./models/boston_model.npz
Test Error: 
 MSE: 87.5529, RMSE: 9.3570, MAE: 6.8759, Avg loss: 78.47124765

