# Classification with linear models

In [40]:
import numpy as np
import matplotlib.pyplot as plt
import torch
import sklearn as sk
from sklearn import datasets
from sklearn.model_selection import train_test_split

# Regression

In [41]:
diabetes = sk.datasets.load_diabetes()
#both numpy arrays
#feature matrix
data = diabetes.data
#one dimensional numpy array
target = diabetes.target

In [42]:
print(f"Number of data points: {data.shape[0]}\nNumber of features: {data.shape[1]}")

Number of data points: 442
Number of features: 10


In [43]:
diabetes.feature_names

['age', 'sex', 'bmi', 'bp', 's1', 's2', 's3', 's4', 's5', 's6']

In [44]:
#Lets make train test split (80% train, 20% test)
X_train, X_test, y_train, y_test = train_test_split(data, target, test_size=0.2, random_state=42)

In [45]:
def initialize_parameters(X):
    #create column vector of ones
    weights = np.ones(X.shape[1]).reshape(-1,1)
    bias = 0.0
    return weights, bias

In [46]:
def compute_predictions(X, weights, bias):
    #linear regression is as follows, y = Xw + b
    y = X @ weights + bias
    return y

In [47]:
w, b = initialize_parameters(X_train)
y_pred = compute_predictions(X_train, w, b)

print(f"Sanity checking y dimensions, should be ({X_train.shape[0]}, 1): {y_pred.shape}")
print(f"Y prediction vector: {y_pred}")


array([[ 2.94364608e-01],
       [-1.61538251e-04],
       [ 7.08786361e-02],
       [-4.40872738e-01],
       [-4.03884100e-01],
       [-4.45705710e-01],
       [ 3.15682654e-02],
       [ 3.47144087e-01],
       [-8.94502307e-02],
       [-2.68210507e-01],
       [-1.69220533e-01],
       [ 1.70663118e-01],
       [-3.89581204e-01],
       [ 2.14891613e-01],
       [-3.59112230e-01],
       [ 1.85784650e-01],
       [-5.18393379e-01],
       [-1.35447447e-02],
       [ 8.38341960e-02],
       [ 2.83971411e-01],
       [-1.73815267e-01],
       [ 2.22397777e-02],
       [ 1.57380392e-01],
       [ 2.31576856e-01],
       [ 2.98887835e-01],
       [ 1.79538141e-01],
       [-1.33177387e-01],
       [-2.53868917e-01],
       [-1.15754610e-01],
       [-4.14282630e-01],
       [-4.35780579e-03],
       [-1.37693586e-01],
       [-2.32303841e-01],
       [-1.72309017e-01],
       [ 1.60901659e-01],
       [-2.48718359e-02],
       [ 1.24185267e-01],
       [ 2.71684934e-02],
       [-2.9