Skip to content
Permalink
Branch: master
Find file Copy path
Find file Copy path
Fetching contributors…
Cannot retrieve contributors at this time
363 lines (307 sloc) 9.27 KB
#~
Translated from Tariq Rashid's book
"Make Your Own Neural Network"
~#
use System.IO.File;
use System.Matrix;
use System.Time;
use Collection.Generic;
# Three-layer (input, hidden, output) neural network together with the
# command line driver (Main) that trains it and scores it on CSV samples.
class NeuralNetwork {
# layer sizes; assigned only by the four-argument constructor
@input_nodes : Float;
@hidden_nodes : Float;
@output_nodes : Float;
# assigned by both constructors; Train() takes its rate as a parameter instead
@learning_rate : Float;
# weights applied to the input signals to produce the hidden layer (see Query/Train)
@weight_inputs_hidden : Float[,];
# weights applied to the hidden layer outputs to produce the final outputs
@weight_outputs_hidden : Float[,];
# Program entry point.
#   two arguments: trains a new network on the file named by args[0], stores the
#                  weights with Store(), then scores the network on the file
#                  named by args[1] (compact per-sample output).
#   one argument:  restores the weights with Load() and scores the network on the
#                  file named by args[0] (verbose per-sample output).
# Any other argument count does nothing.
function : Main(args : String[]) ~ Nil {
	input_nodes := 784;
	hidden_nodes := 200;
	output_nodes := 10;
	learning_rate := 0.2;
	epoch := 7;
	# receives the expected digit that ParseInput() reads from each line
	value := IntHolder->New();

	if(args->Size() = 2) {
		timer := Timer->New();
		timer->Start();

		inputs := LoadInput(args[0]);
		size := inputs->Size();
		"[Training with {$size} sample(s)]"->PrintLine();

		count := 0;
		# total number of training steps, used only for the progress display
		max := epoch * size;
		network := NeuralNetwork->New(input_nodes, hidden_nodes, output_nodes, learning_rate);

		# The number of passes is taken from 'epoch' so that it always agrees with
		# 'max', and the pass counter has its own name so that the sample index 'i'
		# of the inner loop cannot interfere with it.
		for(pass_index := 0; pass_index < epoch; pass_index += 1;) {
			each(i : inputs) {
				train_input := ParseInput(inputs->Get(i), value);
				number := value->Get();
				train_target := GetTrainingTarget(number);
				network->Train(train_input, train_target, learning_rate);
				# show progress
				network->ShowProgress(count, max);
				count += 1;
			};
		};
		'\n'->Print();
		network->Store();

		timer->End();
		elapsed := timer->GetElapsedTime();
		"[Time: {$elapsed} second(s)]"->PrintLine();
		"---"->PrintLine();

		# score the freshly trained network against the second file
		timer->Start();
		inputs := LoadInput(args[1]);
		size := inputs->Size();
		"[Testing {$size} input(s)]"->PrintLine();

		correct_count := 0;
		each(i : inputs) {
			test_input := ParseInput(inputs->Get(i), value);
			test_output := network->Query(test_input);
			if(TestQuery(test_output, value->Get(), false)) {
				correct_count += 1;
			};
		};

		timer->End();
		elapsed := timer->GetElapsedTime();
		raw_correct := 100.0 * correct_count->As(Float) / size->As(Float);
		"\n[Time: {$elapsed} second(s), correct: {$raw_correct}%]"->PrintLine();
	}
	else if(args->Size() = 1) {
		# this constructor restores the weights through Load()
		network := NeuralNetwork->New(learning_rate);

		inputs := LoadInput(args[0]);
		size := inputs->Size();
		"---"->PrintLine();
		"[Testing {$size} input(s)]"->PrintLine();

		timer := Timer->New();
		timer->Start();

		correct_count := 0;
		each(i : inputs) {
			test_input := ParseInput(inputs->Get(i), value);
			test_output := network->Query(test_input);
			if(TestQuery(test_output, value->Get())) {
				correct_count += 1;
			};
		};

		timer->End();
		elapsed := timer->GetElapsedTime();
		raw_correct := 100.0 * correct_count->As(Float) / size->As(Float);
		# this mode reports the score truncated to a whole percentage
		correct := raw_correct->As(Int);
		"\n[Time: {$elapsed} second(s), correct: {$correct}%]"->PrintLine();
	};
}
# Builds a network whose weight matrices are restored by Load().
#   learning_rate: value kept in @learning_rate
New(learning_rate : Float) {
	# Load() only fills the two weight matrices, so the order of these
	# two steps does not matter
	Load();
	@learning_rate := learning_rate;
}
# Builds an untrained network with randomly initialized weights.
#   input_nodes, hidden_nodes, output_nodes: layer sizes
#   learning_rate: value kept in @learning_rate
# @weight_inputs_hidden is created as hidden_nodes x input_nodes and
# @weight_outputs_hidden as output_nodes x hidden_nodes.
New(input_nodes : Float, hidden_nodes : Float, output_nodes : Float, learning_rate : Float) {
	@input_nodes := input_nodes;
	@hidden_nodes := hidden_nodes;
	@output_nodes := output_nodes;
	@learning_rate := learning_rate;

	# The spread of each matrix is scaled by the number of links feeding the
	# layer it produces: input_nodes^-0.5 for the hidden layer ...
	# NOTE(review): the first two RandomNormal arguments are presumably the mean
	# and the spread of the distribution -- confirm against the Matrix2D API.
	@weight_inputs_hidden := Matrix2D->RandomNormal(0.01, Float->Power(@input_nodes, -0.5),
		@hidden_nodes, @input_nodes);
	# ... and hidden_nodes^-0.5 for the output layer, which is fed by the hidden
	# nodes (this previously reused the input node count).
	@weight_outputs_hidden := Matrix2D->RandomNormal(0.01, Float->Power(@hidden_nodes, -0.5),
		@output_nodes, @hidden_nodes);
}
# Feeds 'inputs' forward through the network and returns the output layer signals.
#   inputs: input signals (Main passes the matrix built by ParseInput)
method : Query(inputs : Float[,]) ~ Float[,] {
	# the inner call yields the hidden layer outputs, the outer call feeds them
	# through the hidden-to-output weights
	return Matrix2D->DotSigmoid(@weight_outputs_hidden,
		Matrix2D->DotSigmoid(@weight_inputs_hidden, inputs));
}
# Runs one training step: a forward pass, error calculation and an update of
# both weight matrices.
#   inputs:  input signals for one sample
#   targets: desired output signals for that sample
#   rate:    learning rate handed to Adjust()
method : Train(inputs : Float[,], targets : Float[,], rate : Float) ~ Nil {
	# forward pass
	hidden_signals := Matrix2D->DotSigmoid(@weight_inputs_hidden, inputs);
	final_signals := Matrix2D->DotSigmoid(@weight_outputs_hidden, hidden_signals);

	# error at the output layer: target - actual
	final_delta := Matrix2D->Subtract(targets, final_signals);

	# error at the hidden layer: the output error pushed back through the
	# (not yet updated) hidden-to-output weights
	transposed_weights := Matrix2D->Transpose(@weight_outputs_hidden);
	hidden_delta := Matrix2D->Dot(transposed_weights, final_delta);

	# input-to-hidden links
	hidden_step := Adjust(rate, hidden_delta, hidden_signals, inputs);
	@weight_inputs_hidden := Matrix2D->Add(@weight_inputs_hidden, hidden_step);

	# hidden-to-output links
	final_step := Adjust(rate, final_delta, final_signals, hidden_signals);
	@weight_outputs_hidden := Matrix2D->Add(@weight_outputs_hidden, final_step);
}
# Computes the weight change for one layer:
#   rate * ((errors * outputs * (1 - outputs)) . transpose(inputs))
#   rate:    learning rate
#   errors:  error signals of the layer
#   outputs: sigmoid outputs of the layer
#   inputs:  signals that were fed into the layer
# Returns the matrix that Train() adds to the layer's weights.
method : Adjust(rate : Float, errors : Float[,], outputs : Float[,], inputs : Float[,]) ~ Float[,] {
	# derivative of the sigmoid expressed through its output: o * (1 - o);
	# the constant must be 1.0 (it previously was 0.99, which skews the
	# gradient and flips its sign for outputs above 0.99)
	slope := Matrix2D->Multiple(outputs, Matrix2D->Subtract(1.0, outputs));
	gradient := Matrix2D->Multiple(errors, slope);
	return Matrix2D->Multiple(rate, Matrix2D->Dot(gradient, Matrix2D->Transpose(inputs)));
}
# Writes both weight matrices to "train.dat": the input-to-hidden weights
# first, then the hidden-to-output weights, each in the layout produced by
# WriteWeights(). Reports an error instead of a success message when the
# file cannot be opened.
method : Store() ~ Nil {
	writer := FileWriter->New("train.dat");
	leaving {
		writer->Close();
	};

	if(writer->IsOpen()) {
		dims := WriteWeights(writer, @weight_inputs_hidden);
		rows := dims[0];
		cols := dims[1];
		"[Stored input weights: {$rows}x{$cols}]"->PrintLine();

		dims := WriteWeights(writer, @weight_outputs_hidden);
		rows := dims[0];
		cols := dims[1];
		"[Stored output weights: {$rows}x{$cols}]"->PrintLine();
	}
	else {
		"Unable to open train.dat for writing!"->ErrorLine();
	};
}

# Writes one matrix as a "rows,cols" header line followed by one
# comma-separated line per row; returns the matrix dimensions.
method : private : WriteWeights(writer : FileWriter, weights : Float[,]) ~ Int[] {
	dims := weights->Size();
	rows := dims[0];
	cols := dims[1];

	writer->WriteString("{$rows},{$cols}\n");
	for(i := 0; i < rows; i += 1;) {
		for(j := 0; j < cols; j += 1;) {
			writer->WriteString(weights[i,j]->ToString());
			# separator between values, not after the last one
			if(j + 1 < cols) {
				writer->WriteString(",");
			};
		};
		writer->WriteString("\n");
	};

	return dims;
}
# Restores both weight matrices from "train.dat", the file written by Store().
# Expected layout: a "rows,cols" header line followed by 'rows' comma-separated
# lines for the input-to-hidden weights, then the same again for the
# hidden-to-output weights. Exits the program with status 1 when a data line
# does not hold exactly 'cols' values.
method : Load() ~ Nil {
lines := LoadInput("train.dat");
# first line: dimensions of the input-to-hidden matrix
line := lines->Get(0);
dims := line->Split(",");
rows := dims[0]->ToInt();
cols := dims[1]->ToInt();
"Loading input weights: {$rows}, {$cols}"->PrintLine();
# declared ahead of the loop because 'i' is read again after the loop ends
i : Int;
@weight_inputs_hidden := Float->New[rows, cols];
# data lines 1..rows map to matrix rows 0..rows-1
for(i := 1; i <= rows; i += 1;) {
line := lines->Get(i);
col := line->Split(",");
if(col->Size() <> cols) {
"Invalid Row!"->ErrorLine();
Runtime->Exit(1);
};
for(j := 0; j < cols; j += 1;) {
@weight_inputs_hidden[i - 1, j] := col[j]->ToFloat();
};
};
# 'i' is now rows + 1: the header line of the hidden-to-output matrix
line := lines->Get(i);
dims := line->Split(",");
# remember the row count of the first matrix; it offsets every index below
old_row := rows;
rows := dims[0]->ToInt();
cols := dims[1]->ToInt();
"Loading output weights: {$rows}, {$cols}"->PrintLine();
@weight_outputs_hidden := Float->New[rows, cols];
# data lines old_row + 2 .. old_row + rows + 1 map to matrix rows 0..rows-1
for(i := i + 1; i <= rows + old_row + 1; i += 1;) {
line := lines->Get(i);
col := line->Split(",");
if(col->Size() <> cols) {
error := col->Size();
"Invalid row size: {$error}!"->ErrorLine();
Runtime->Exit(1);
};
for(j := 0; j < cols; j += 1;) {
# subtract both header lines and the rows of the first matrix
@weight_outputs_hidden[i - 2 - old_row, j] := col[j]->ToFloat();
};
};
}
# ---
# Compares the strongest output of a query with the expected digit and prints
# the outcome.
#   test:    output signals, one row per digit
#   expect:  expected digit
#   verbose: true prints every row, marking the expected digit with '*' and a
#            differing prediction with '-'; false prints the expected digit
#            followed by ',' (or '*,' when the prediction was wrong)
# Returns true when the predicted digit equals 'expect'.
function : TestQuery(test : Float[,], expect : Int, verbose : Bool := true) ~ Bool {
	predicted := LargestIndex(test);

	if(<>verbose) {
		# compact mode: one token per sample
		if(predicted <> expect) {
			"{$expect}*,"->Print();
		}
		else {
			"{$expect},"->Print();
		};
	}
	else {
		"---"->PrintLine();
		expect->PrintLine();
		"---"->PrintLine();

		dims := test->Size();
		i := 0;
		while(i < dims[0]) {
			probability := test[i,0];
			if(i = expect) {
				"{$i}:\t*[{$probability}]"->PrintLine();
			}
			else if(i = predicted) {
				"{$i}:\t-[{$probability}]"->PrintLine();
			}
			else {
				"{$i}:\t [{$probability}]"->PrintLine();
			};
			i += 1;
		};
	};

	return predicted = expect;
}
# Builds the 10x1 target column for a digit: 0.99 in the row of 'number' and
# 0.01 in every other row.
#   number: digit in the range 0..9
# Returns Nil when 'number' is outside that range.
function : native : GetTrainingTarget(number : Int) ~ Float[,] {
	if(number >= 0 & number <= 9) {
		targets := Float->New[10, 1];
		for(digit := 0; digit < 10; digit += 1;) {
			if(digit = number) {
				targets[digit, 0] := 0.99;
			}
			else {
				targets[digit, 0] := 0.01;
			};
		};
		return targets;
	};

	return Nil;
}
# Returns the row index of the largest value in the first column of 'b'.
# The search starts from 0.0, so 0 is returned when no value is greater than
# 0.0; on ties the first row wins.
function : native : LargestIndex(b : Float[,]) ~ Int {
	best_row := 0;
	best_value := 0.0;

	dims := b->Size();
	row := 0;
	while(row < dims[0]) {
		# strictly greater keeps the earliest row on equal values
		if(b[row, 0] > best_value) {
			best_value := b[row, 0];
			best_row := row;
		};
		row += 1;
	};

	return best_row;
}
# Converts one comma-separated sample line into a 784x1 input column.
#   line:  first field is the expected digit, the remaining fields are values
#          that get scaled from 0..255 into 0.01..1.0
#   value: receives the expected digit
# Rows without a matching field keep the array's initial value; fields beyond
# the 784th value are ignored.
function : ParseInput(line : String, value : IntHolder) ~ Float[,] {
	matrix := Float->New[784, 1];

	values := line->Split(",");
	value->Set(values[0]->ToInt());

	# field i fills row i - 1; the second condition keeps an over-long line
	# from writing past the last row of the matrix
	for(i := 1; i < values->Size() & i <= 784; i += 1;) {
		matrix[i - 1, 0] := values[i]->ToFloat() / 255.0 * 0.99 + 0.01;
	};

	return matrix;
}
# Reads a text file and returns its lines.
#   file: path of the file to read
# Lines are split on '\n' and every '\r' is dropped. A final line without a
# terminating newline is returned as well. When the file cannot be opened the
# returned vector is empty.
function : LoadInput(file : String) ~ Vector<String> {
	inputs := Vector->New()<String>;

	reader := FileReader->New(file);
	leaving {
		reader->Close();
	};

	if(reader->IsOpen()) {
		buffer_size := 4096 * 4;
		buffer := Char->New[buffer_size + 1];
		line := "";

		while(<>reader->IsEOF()) {
			read := reader->ReadBuffer(0, buffer_size, buffer);
			for(i := 0; i < read; i += 1;) {
				if(buffer[i] = '\n') {
					inputs->AddBack(line);
					line := "";
				}
				else if(buffer[i] <> '\r') {
					line += buffer[i];
				};
			};
		};

		# keep the last record when the file does not end with a newline
		if(line->Size() > 0) {
			inputs->AddBack(line);
		};
	};

	return inputs;
}
# Prints the progress count / max as a whole percentage that overwrites the
# previously printed value, then flushes the console.
#   count: number of completed steps
#   max:   total number of steps
function : ShowProgress(count : Int, max : Int) ~ Nil {
	fraction := count->As(Float) / max->As(Float);
	done := (fraction * 100.0)->As(Int);

	# two-digit values need one more trailing backspace than the other cases
	if(done > 9 & done < 100) {
		"\b{$done}%\b\b"->Print();
	}
	else if(done >= 1) {
		"\b{$done}%\b"->Print();
	}
	else {
		"\b0%\b"->Print();
	};

	# every branch flushes, so it is done once here
	System.IO.Console->Flush();
}
}
You can’t perform that action at this time.