In [1]:
import numpy as np

X = np.random.rand(10000,5)
m = np.random.randint(low = 1, high = 20,size = (5,1))  #parametri random tra low e high
q = np.random.rand(1)
y = (X @ m) + q 

noise = np.random.randn(y.shape[0], y.shape[1])
y = y + noise

X.shape, m.shape, q.shape, y.shape

((10000, 5), (5, 1), (1,), (10000, 1))

In [2]:
X = np.concatenate([X , np.ones((X.shape[0],1))], axis = 1)
m = np.concatenate([m,q.reshape(1,-1)],axis = 0)

In [3]:
def partial_derivative(X_batch, y_batch, m_stat):

  y_pred = X_batch @ m_stat
  n = len(X_batch)

  df_dm =  (-2/n) * (X_batch.T @ (y_batch - y_pred))
  df_dm = df_dm.reshape(len(df_dm),-1)

  return df_dm

In [4]:
def mean_squared_error(X,y,m_stat):
  y_pred = X @ m_stat
  mse = np.sum(((y_pred - y)**2),axis = 0) / len(X)
  
  return mse

In [5]:
def training(X, y, batch_size, lr, epochs):
  
  for epoch in range(epochs):

    #random initial statistics
    if epoch == 0:
      m_stat = np.random.rand(X.shape[1],1)

    #shuffle X and y using same permutation
    indices = np.arange(X.shape[0])
    np.random.shuffle(indices)

    X = X[indices]
    y = y[indices]

    #store comulative derivative
    cumulative_derivative = np.zeros((X.shape[1],1))

    for batch in range(len(X)//batch_size):
      start = batch*batch_size
      stop = (batch*batch_size) + batch_size

      X_batch = X[start:stop]
      y_batch = y[start:stop]
      
      #derivative
      cumulative_derivative = cumulative_derivative + partial_derivative(X_batch, y_batch, m_stat)

      #updating rule
      m_stat = m_stat - (lr*cumulative_derivative)
    
    print(f"epoch: {epoch} ----> MSE: {mean_squared_error(X,y,m_stat)}")
      
  return m_stat

In [6]:
batch_size = 1024
lr = 0.01
epochs = 500

m_stat = training(X,y, batch_size,lr,epochs)

epoch: 0 ----> MSE: [122.86222286]
epoch: 1 ----> MSE: [38.18675133]
epoch: 2 ----> MSE: [19.10298882]
epoch: 3 ----> MSE: [13.89025953]
epoch: 4 ----> MSE: [11.86348889]
epoch: 5 ----> MSE: [10.53663454]
epoch: 6 ----> MSE: [9.46984912]
epoch: 7 ----> MSE: [8.57656662]
epoch: 8 ----> MSE: [7.79019504]
epoch: 9 ----> MSE: [7.10216248]
epoch: 10 ----> MSE: [6.48409789]
epoch: 11 ----> MSE: [5.94352933]
epoch: 12 ----> MSE: [5.46686768]
epoch: 13 ----> MSE: [5.04491259]
epoch: 14 ----> MSE: [4.67214305]
epoch: 15 ----> MSE: [4.34021642]
epoch: 16 ----> MSE: [4.03634788]
epoch: 17 ----> MSE: [3.76978775]
epoch: 18 ----> MSE: [3.5289234]
epoch: 19 ----> MSE: [3.31236866]
epoch: 20 ----> MSE: [3.11615417]
epoch: 21 ----> MSE: [2.93958502]
epoch: 22 ----> MSE: [2.78049688]
epoch: 23 ----> MSE: [2.63941138]
epoch: 24 ----> MSE: [2.5108445]
epoch: 25 ----> MSE: [2.39377481]
epoch: 26 ----> MSE: [2.28668959]
epoch: 27 ----> MSE: [2.18982116]
epoch: 28 ----> MSE: [2.10093945]
epoch: 29 ----> MSE

epoch: 296 ----> MSE: [1.00651167]
epoch: 297 ----> MSE: [1.00651101]
epoch: 298 ----> MSE: [1.00651932]
epoch: 299 ----> MSE: [1.00651247]
epoch: 300 ----> MSE: [1.00657837]
epoch: 301 ----> MSE: [1.00651819]
epoch: 302 ----> MSE: [1.00670537]
epoch: 303 ----> MSE: [1.00659083]
epoch: 304 ----> MSE: [1.00650936]
epoch: 305 ----> MSE: [1.00656312]
epoch: 306 ----> MSE: [1.0066982]
epoch: 307 ----> MSE: [1.0065013]
epoch: 308 ----> MSE: [1.00663249]
epoch: 309 ----> MSE: [1.00658412]
epoch: 310 ----> MSE: [1.00653608]
epoch: 311 ----> MSE: [1.00662293]
epoch: 312 ----> MSE: [1.00656408]
epoch: 313 ----> MSE: [1.00655478]
epoch: 314 ----> MSE: [1.00650316]
epoch: 315 ----> MSE: [1.00651828]
epoch: 316 ----> MSE: [1.0065302]
epoch: 317 ----> MSE: [1.00668333]
epoch: 318 ----> MSE: [1.00653633]
epoch: 319 ----> MSE: [1.00662479]
epoch: 320 ----> MSE: [1.0066602]
epoch: 321 ----> MSE: [1.00651463]
epoch: 322 ----> MSE: [1.00659126]
epoch: 323 ----> MSE: [1.00652216]
epoch: 324 ----> MSE: [1

In [7]:
print(m_stat,"\n")
print(m)

X_test = np.random.rand(500,5)
X_test = np.concatenate([X_test,np.ones(shape = (500,1))] , axis = 1)
y_test = X_test @ m

y_preds = X_test @ m_stat
mse = mean_squared_error(X_test, y_test, m_stat)

print("mse" , mse)

print(y_test[:5])
print(y_preds[:5])

[[ 7.97757429]
 [ 2.96965385]
 [17.02938106]
 [13.95256803]
 [ 4.01255622]
 [ 0.06539434]] 

[[ 8.       ]
 [ 3.       ]
 [17.       ]
 [14.       ]
 [ 4.       ]
 [ 0.0260738]]
mse [0.00052078]
[[14.17020469]
 [12.15764071]
 [20.86863596]
 [31.55506619]
 [22.64320689]]
[[14.2041179 ]
 [12.16685685]
 [20.87678543]
 [31.54678972]
 [22.62908573]]
