Skip to content
This repository has been archived by the owner on Apr 2, 2024. It is now read-only.

Commit

Permalink
Add the training time to the Dashboard API
Browse files Browse the repository at this point in the history
  • Loading branch information
hugolgst committed Mar 21, 2020
1 parent 9e15211 commit 087043c
Show file tree
Hide file tree
Showing 3 changed files with 33 additions and 8 deletions.
28 changes: 21 additions & 7 deletions dashboard/api.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,8 @@ import (
var neuralNetwork network.Network

type Dashboard struct {
Layers Layers `json:"layers"`
LearningRate float64 `json:"learning_rate"`
ErrorLoss float64 `json:"error_loss"`
Layers Layers `json:"layers"`
Training Training `json:"training"`
}

type Layers struct {
Expand All @@ -23,6 +22,12 @@ type Layers struct {
OutputNodes int `json:"output"`
}

type Training struct {
Rate float64 `json:"rate"`
Error float64 `json:"error"`
Time float64 `json:"time"`
}

// Serve serves the dashboard REST API on the port 8081 by default.
func Serve(_neuralNetwork network.Network) {
// Set the current global network as a global variable
Expand All @@ -41,9 +46,8 @@ func GetDashboardData(w http.ResponseWriter, _ *http.Request) {
w.Header().Set("Content-Type", "application/json")

dashboard := Dashboard{
Layers: GetLayers(),
LearningRate: neuralNetwork.Rate,
ErrorLoss: neuralNetwork.Error,
Layers: GetLayers(),
Training: GetTraining(),
}

err := json.NewEncoder(w).Encode(dashboard)
Expand All @@ -52,7 +56,7 @@ func GetDashboardData(w http.ResponseWriter, _ *http.Request) {
}
}

// GetLayers returns the number of input, hidden and output layers
// GetLayers returns the number of input, hidden and output layers of the network
func GetLayers() Layers {
return Layers{
// Get the number of rows of the first layer to get the count of input nodes
Expand All @@ -63,3 +67,13 @@ func GetLayers() Layers {
OutputNodes: network.Columns(neuralNetwork.Output),
}
}

// GetTraining returns the learning rate, training time and error loss for the network
func GetTraining() Training {
// Retrieve the information from the neural network
return Training{
Rate: neuralNetwork.Rate,
Error: neuralNetwork.Error,
Time: neuralNetwork.Time,
}
}
11 changes: 11 additions & 0 deletions network/network.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,10 @@ package network
import (
"encoding/json"
"fmt"
"math"
"os"
"strconv"
"time"

"github.com/gookit/color"
"gopkg.in/cheggaaa/pb.v1"
Expand All @@ -19,6 +21,7 @@ type Network struct {
Output Matrix
Rate float64
Error float64
Time float64
}

// LoadNetwork returns a Network from a specified file
Expand Down Expand Up @@ -152,6 +155,9 @@ func (network *Network) ComputeError() float64 {
// Train trains the neural network with a given number of iterations by executing
// forward and back propagation
func (network *Network) Train(iterations int) {
// Initialize the start time
start := time.Now()

// Create the progress bar
bar := pb.New(iterations).Postfix(fmt.Sprintf(
" - %s",
Expand Down Expand Up @@ -179,5 +185,10 @@ func (network *Network) Train(iterations int) {
errorLoss, _ := strconv.ParseFloat(arrangedError, 5)
network.Error = errorLoss

// Calculate elapsed time
elapsed := time.Since(start)
// Round the elapsed time at two decimals
network.Time = math.Floor(elapsed.Seconds()*100) / 100

fmt.Printf("The error rate is %s.\n", color.FgGreen.Render(arrangedError))
}
2 changes: 1 addition & 1 deletion res/training.json

Large diffs are not rendered by default.

0 comments on commit 087043c

Please sign in to comment.