# Task F

# In [3]:
import numpy as np
from sklearn.preprocessing import PolynomialFeatures
from sklearn.model_selection import train_test_split

from models import *
from utils import * 

# Fit a polynomial surface to the Franke function with mini-batch gradient
# descent, printing the test-set MSE after every epoch.

# Define parameters
degree = 5        # maximum total degree of the polynomial features in (x, y)
epochs = 100      # number of passes over the training set
lr = 0.001        # gradient-descent step size
batch_size = 10   # forwarded to generate_batches (presumably samples per batch)

# Create data: a regular grid on [0, 1) x [0, 1) with spacing 0.05
x = np.arange(0, 1, 0.05)
y = np.arange(0, 1, 0.05)
x, y = np.meshgrid(x, y)

# Compute z
z = FrankeFunction(x, y)

# Create features as pairs of (x, y), one row per grid point
features = np.stack([x.ravel(), y.ravel()], axis=1)
outputs = z.ravel()

# Split dataset into train and test set (80/20, fixed seed for reproducibility)
x_train, x_test, y_train, y_test = train_test_split(
    features, outputs, test_size=0.2, random_state=20
)

# Get polynomial features; the transform is fitted on the training split only
poly = PolynomialFeatures(degree, include_bias=True)
x_train = poly.fit_transform(x_train)
x_test = poly.transform(x_test)

model = LinearRegression(dimension=x_train.shape[1], random_init=True)

# Iterate over epochs
for _ in range(epochs):

    # Generate batches (re-generated at the start of every epoch)
    x_batches, y_batches = generate_batches(x_train, y_train, batch_size)

    # Iterate through batches. Distinct loop names keep the grid arrays
    # x and y above from being overwritten by the last mini-batch.
    for x_batch, y_batch in zip(x_batches, y_batches):

        # Determine the gradient for this batch
        gradient = model.gradient(x_batch, y_batch)

        # Gradient-descent step on the coefficient vector
        model.beta = model.beta - lr * gradient

    # Print current error on the held-out test set
    y_hat = model.predict(x_test)
    mse = MSE(y_test, y_hat)
    print(mse)
    


0.1267984960478647
0.028124426724559025
0.02192347939935015
0.020382948220870615
0.017955140324206968
0.01635995424946327
0.015861150908451107
0.015007863538615792
0.01437643717438633
0.013509710692137234
0.013352681152489863
0.01261390870858578
0.012743480470113177
0.01234340976774524
0.013344432089651973
0.012459560455439968
0.012128214078153546
0.011959658189318762
0.011560926007599873
0.011672501670189083
0.011756439851340167
0.012022844833803264
0.01139906929287431
0.011420975459521787
0.011537762436803474
0.011562053292307856
0.011564670097975753
0.011666616217593411
0.011518917329146472
0.011564548731657664
0.011252363288363574
0.011395799355575926
0.010977469225851248
0.011117871347978429
0.010939253710080357
0.01139498076168977
0.011355894692787055
0.010892851492393638
0.010744540946631114
0.011157331210170674
0.01058223356468691
0.011484159488614585
0.011097276787390512
0.010564268757583355
0.01047643303681801
0.01053626394986833
0.010699667318727218
0.010536303481512643
0.01