## Diabetes Progression

Predicts diabetes progression<br>
Last update: 1/1/24

In [176]:
import numpy as np
from sandbox import activations, costs, initializers, layers, model, optimizers, utils
from sklearn import datasets

In [178]:
# Load iris dataset
dataset = datasets.load_diabetes()
x = np.asarray(dataset.data)
y = np.asarray(dataset.target).reshape(-1, 1)

# Normalize, shuffle, and split data
mean_x = np.mean(x, axis=0)
std_x = np.std(x, axis=0)
x = (x - mean_x) / std_x

mean_y = np.mean(y, axis=0)
std_y = np.std(y, axis=0)
y = (y - mean_y) / std_y

x, y = utils.shuffle(x, y)
(train_x, train_y), (test_x, test_y) = utils.train_test_split(x, y, test_size=0.2)

In [200]:
# Create model
diabetes = model.Model()
diabetes.add(layers.Dense(units=12, activation=activations.ReLU()))
diabetes.add(layers.Dense(units=12, activation=activations.ReLU()))
diabetes.add(layers.Dense(units=1, activation=activations.Linear()))

diabetes.configure(
    input_size=train_x.shape[1],
    cost_type=costs.MSE(),
    optimizer=optimizers.Adam()
)

# Train model
diabetes.train(train_x, train_y, learning_rate=0.005, epochs=100, batch_size=64, verbose=True)

Cost on epoch 10: 0.44136
Cost on epoch 20: 0.38863
Cost on epoch 30: 0.40184
Cost on epoch 40: 0.27668
Cost on epoch 50: 0.37374
Cost on epoch 60: 0.31526
Cost on epoch 70: 0.35531
Cost on epoch 80: 0.26697
Cost on epoch 90: 0.20937
Cost on epoch 100: 0.25532


In [201]:
# Get model prediction and loss on test data
pred = diabetes.predict(test_x)
loss = costs.MSE().forward(pred, test_y)
print('Test MSE loss: ', loss)

# Note - because the labels were normalized, model outputs must be adjusted before being used, as so:
pred_actual = pred * std_y + mean_y

print('\nPredicted:\n', np.squeeze(np.round(pred_actual.T)))
print('Actual:\n', np.squeeze(test_y.T * std_y + mean_y))

Test MSE loss:  0.6618107755526453

Predicted:
 [181.  64. 180.  86. 110. 113. 104.  88. 184. 267. 172. 211.  95.  85.
 196. 149.  98. 154. 189. 222. 248. 212. 138. 118. 149.  97.  82.  56.
  82. 214.  77.  85. 151. 240.  78. 227. 151. 144. 100. 118.  80.  83.
 199. 198. 243.  79.  70.  91. 103. 230. 225. 146. 243. 108. 160. 187.
 181. 112. 116. 109. 169.  96. 266. 102. 310. 300.  68. 131.  89.  76.
 178. 134. 215.  76.  89. 278. 151. 172. 140. 226. 244.  91. 243. 259.
  98.  62. 145. 149.  88.]
Actual:
 [115. 181. 131.  53. 161.  83.  75.  72. 233. 252. 134. 173.  60.  93.
 109. 229. 170.  84. 180. 197. 221. 156. 241. 113. 172. 163.  77.  89.
  87. 296.  53.  71.  92. 233. 179. 110. 174. 109.  51. 144.  48. 201.
 164. 124. 268. 127. 135.  94. 160. 281. 265. 249. 303.  44. 131. 261.
  90.  80.  70. 181.  63.  63. 277. 187. 308. 243. 199.  92.  65.  60.
  91. 259. 258. 135.  90. 274.  85. 217.  68. 346. 163.  97.  84. 206.
  85. 116.  52.  55.  48.]
