Skip to content

Commit 6acaef7

Browse files
authored
Merge pull request #40 from devfym/develop
revise LinearRegression
2 parents 0178b9e + adb83af commit 6acaef7

File tree

2 files changed

+135
-63
lines changed

2 files changed

+135
-63
lines changed

src/Regression/LinearRegression.php

Lines changed: 74 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -2,101 +2,137 @@
22

33
namespace devfym\IntelliPHP\Regression;
44

5+
use devfym\IntelliPHP\Data\DataFrame;
6+
57
class LinearRegression
68
{
79
/**
810
* @var array
911
*/
10-
11-
private $predictors;
12+
private $xColumn;
1213

1314
/**
1415
* @var array
1516
*/
16-
17-
private $outcomes;
17+
private $yColumn;
1818

1919
/**
2020
* @var int
2121
*/
22-
23-
private $sample_size;
22+
private $df;
2423

2524
/**
2625
* @var float
2726
*/
28-
2927
private $slope;
3028

3129
/**
3230
* @var float
3331
*/
34-
3532
private $intercept;
3633

3734
/**
38-
* LinearRegression constructor.
35+
* @var String
3936
*/
37+
private $metric;
4038

39+
/**
40+
* LinearRegression constructor.
41+
*/
4142
public function __construct()
4243
{
43-
$this->predictors = [];
44-
$this->outcomes = [];
45-
$this->sample_size = 0;
44+
$this->xColumn = '';
45+
$this->yColumn = '';
46+
$this->df = new DataFrame();
4647
$this->slope = 0;
4748
$this->intercept = 0;
4849
}
4950

50-
private function getMean($list = []) : float
51+
/**
52+
* @param DataFrame $df
53+
*/
54+
public function setTrain(DataFrame $df) : void
5155
{
52-
$mean = (array_sum($list) / count($list));
53-
54-
return round($mean, 4);
56+
$this->df = $df;
5557
}
5658

57-
public function setTrain($predictors = [], $outcomes = []) : void
59+
/**
60+
* @param string $xColumn
61+
* @param string $yColumn
62+
* @param string $metric
63+
*/
64+
public function model($xColumn = '', $yColumn = '', $metric = '') : void
5865
{
59-
$this->sample_size = count($predictors);
60-
$this->predictors = $predictors;
61-
$this->outcomes = $outcomes;
62-
}
66+
$this->xColumn = $xColumn;
67+
$this->yColumn = $yColumn;
6368

64-
public function model() : void
65-
{
66-
$mx = $this->getMean($this->predictors);
67-
$my = $this->getMean($this->outcomes);
69+
$mx = $this->df->{$xColumn}->mean();
70+
$my = $this->df->{$yColumn}->mean();
6871

6972
$this->slope = $my / $mx;
73+
7074
$this->intercept = $my - ($mx * $this->slope);
71-
}
7275

73-
public function getSlope() : float
74-
{
75-
return round($this->slope, 4);
76+
$this->metric = $metric;
7677
}
7778

78-
public function getIntercept() : float
79+
/**
80+
* @param array $x
81+
* @return array
82+
*/
83+
public function predict($x = []) : array
7984
{
80-
return round($this->intercept, 4);
81-
}
85+
$y = [];
8286

83-
public function predict($p = 0) : float
84-
{
85-
return round(($p * $this->slope) + $this->intercept, 2);
87+
foreach ($x as $p) {
88+
array_push($y, round(($p * $this->slope) + $this->intercept, 2));
89+
}
90+
91+
return $y;
8692
}
8793

88-
public function validate($validation_type, $y_train, $y_test) : float
94+
/**
95+
* @param $y_train
96+
* @param $y_test
97+
* @return float
98+
*/
99+
public function validate($y_train, $y_test) : float
89100
{
90101
$n = count($y_train);
91102
$total_diff = 0;
92103

93-
if ($validation_type == 'mean_squared_error') {
104+
if ($this->metric == 'mean_squared_error') {
94105
for ($i = 0; $i < $n; $i++) {
95106
$total_diff += $y_test[$i] - $y_train[$i];
96107
}
97108
$total_diff /= $n;
98109
}
99110

100-
return $total_diff;
111+
return round($total_diff, 4);
112+
}
113+
114+
public function saveModel() : string
115+
{
116+
$model = [
117+
'type' => 'LinearRegression',
118+
'xColumn' => $this->xColumn,
119+
'yColumn' => $this->yColumn,
120+
'slope' => $this->slope,
121+
'intercept' => $this->intercept,
122+
'metric' => $this->metric
123+
];
124+
125+
return json_encode($model);
126+
}
127+
128+
public function loadModel($model = '') : void
129+
{
130+
$model = json_decode($model);
131+
132+
$this->xColumn = $model->xColumn;
133+
$this->yColumn = $model->yColumn;
134+
$this->slope = $model->slope;
135+
$this->intercept = $model->intercept;
136+
$this->metric = $model->metric;
101137
}
102138
}

tests/Regression/LinearRegressionTest.php

Lines changed: 61 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -2,48 +2,84 @@
22

33
namespace devfym\Tests\Regression;
44

5+
use devfym\IntelliPHP\Data\DataFrame;
56
use devfym\IntelliPHP\Regression\LinearRegression;
67
use PHPUnit\Framework\TestCase;
8+
use SebastianBergmann\Diff\Line;
79

810
class LinearRegressionTest extends TestCase
911
{
12+
/**
13+
* Sample Data
14+
*/
15+
protected $data = [
16+
'student_id' => [1, 2, 3, 4, 5],
17+
'name' => ['aaron', 'bambi', 'celine', 'dennise', 'edwin'],
18+
'age' => [14.5, 12.2, 15, 14, 12.2],
19+
'height_cm' => [162, 158, 162, 170, 168],
20+
'weight_kg' => [68, 58, 56, 56, 52],
21+
'gpa' => [1.25, 4.0, 2.75, 4.0, 2.25]
22+
];
23+
24+
/**
25+
* @var DataFrame
26+
*/
27+
protected $df;
28+
29+
/**
30+
* @var string
31+
*/
32+
protected $model;
33+
34+
/**
35+
* LinearRegressionTest constructor.
36+
* @param null|string $name
37+
* @param array $data
38+
* @param string $dataName
39+
*/
40+
public function __construct(?string $name = null, array $data = [], string $dataName = '')
41+
{
42+
parent::__construct($name, $data, $dataName);
43+
44+
$this->df = new DataFrame();
45+
46+
$this->df->readArray($this->data);
47+
}
48+
1049
public function testExample() : void
1150
{
1251
$linear = new LinearRegression();
1352

14-
srand(2019);
53+
$linear->setTrain($this->df);
1554

16-
$x_train = [];
17-
$y_train = [];
55+
$linear->model('height_cm', 'weight_kg', 'mean_squared_error');
1856

19-
for ($i = 0; $i < 1000000; $i++) {
20-
$random_value = round(rand(0, 1000000) / rand(1, 10), 4);
21-
$x_train[$i] = $random_value;
22-
$y_train[$i] = round($random_value + (rand(1, 10) * rand(1, 5)), 4);
23-
}
57+
$y_test = $linear->predict($this->data['height_cm']);
2458

25-
$linear->setTrain($x_train, $y_train);
26-
$linear->model();
59+
$this->assertEquals(-0.002, $linear->validate($this->data['weight_kg'], $y_test));
2760

28-
// Test Slope
29-
$this->assertEquals(1.0001, $linear->getSlope());
61+
$linear->saveModel();
62+
/*
63+
* @return string
64+
* "type":"LinearRegression","xColumn":"height_cm","yColumn":"weight_kg","slope":0.35365853658536583,"intercept":0,"metric":"mean_squared_error"
65+
*
66+
*/
3067

31-
// Test Intercept
32-
$this->assertEquals(0, $linear->getIntercept());
68+
}
3369

34-
// Test Error Rate
35-
for ($n = 0; $n < 100; $n++) {
36-
$y_predict = $linear->predict($x_train[$n]);
37-
$this->assertGreaterThan($y_train[$n] * 0.95, $y_predict);
38-
$this->assertLessThan($y_train[$n] * 1.05, $y_predict);
39-
}
70+
public function testLoadModel() : void
71+
{
72+
$linear = new LinearRegression();
4073

41-
$y_predict = [];
74+
$linear->setTrain($this->df);
4275

43-
for ($n = 0; $n < count($y_train); $n++) {
44-
$y_predict[$n] = $linear->predict($x_train[$n]);
45-
}
76+
$model = '{"type":"LinearRegression","xColumn":"height_cm","yColumn":"weight_kg","slope":0.35365853658536583,"intercept":0,"metric":"mean_squared_error"}';
4677

47-
$this->assertNotEquals(0, $linear->validate('mean_squared_error', $y_train, $y_predict));
78+
$linear->loadModel($model);
79+
80+
$y_test = $linear->predict($this->data['height_cm']);
81+
82+
$this->assertEquals(-0.002, $linear->validate($this->data['weight_kg'], $y_test));
4883
}
84+
4985
}

0 commit comments

Comments
 (0)