In [28]:
import numpy as np
from numpy.random import rand, randint, uniform

# Logistic Regression

Logistic regression is a type of linear classification model. The idea is to compute the probability of a (usually) binary outcome by calculating a linear combination of different input values (= features). The linear combination is commonly called the logit and is passed to a sigmoid function. Note that while the decision boundary of logistic regression is linear in the feature space, the relationship between feature values and predicted probabilities is nonlinear.

$P(y = 1 | x) = \frac{1}{ 1 + e^{-(xw^\top + b)}}$

where $w$ and $b$ are model parameters.

## Create Dataset

In [19]:
# Seed NumPy's legacy global RNG (the one `rand`/`randint` draw from) so that
# Restart & Run All regenerates exactly the same data.
np.random.seed(42)

# 150 samples with 4 features each, drawn uniformly from [0, 1).
X = rand(150, 4)
# Binary labels drawn independently of X: there is no signal to learn here, so
# the best achievable mean cross entropy is about ln(2) ~= 0.693.
y = randint(0, 2, 150)

## Implement Model

In [66]:
def sigmoid(x):
    """Logistic sigmoid ``1 / (1 + exp(-x))``, evaluated without overflow.

    Parameters
    ----------
    x : scalar or array_like
        Input value(s).

    Returns
    -------
    numpy.float64 or numpy.ndarray
        Element-wise sigmoid with the same shape as ``x`` (a scalar for
        scalar input).
    """
    x = np.asarray(x, dtype=float)
    # exp(-|x|) lies in (0, 1], so it cannot overflow the way exp(-x) does for
    # large negative x.
    z = np.exp(-np.abs(x))
    # x >= 0: 1 / (1 + e^-x);  x < 0: e^x / (1 + e^x) -- algebraically identical.
    out = np.where(x >= 0, 1.0 / (1.0 + z), z / (1.0 + z))
    # Indexing with () unwraps a 0-d result to a scalar and leaves arrays as-is.
    return out[()]

def prob(x, w, b):
    """Predicted probability of the positive class, P(y = 1 | x).

    Forms the logit ``np.dot(x, w) + b`` and passes it through ``sigmoid``.
    """
    logit = np.dot(x, w) + b
    return sigmoid(logit)

As a loss function, we use the (binary) cross entropy, which is defined as

$H(y, \hat{y}) = y \log(\frac{1}{\hat{y}}) + (1 - y) \log(\frac{1}{1 - \hat{y}})$

In [76]:
def cross_entropy(y_true, y_pred, eps=1e-15):
    """Element-wise binary cross entropy between labels and probabilities.

    Computes ``-(y_true * log(p) + (1 - y_true) * log(1 - p))``, which equals
    ``y_true * log(1 / p) + (1 - y_true) * log(1 / (1 - p))``.

    Parameters
    ----------
    y_true : array_like
        Ground-truth labels (0 or 1).
    y_pred : array_like
        Predicted probabilities of the positive class.
    eps : float, optional
        Probabilities are clipped to ``[eps, 1 - eps]`` so that predictions of
        exactly 0 or 1 give a large finite loss instead of ``inf``.

    Returns
    -------
    numpy.ndarray
        Per-sample loss; call ``.mean()`` on it for the average loss.
    """
    # Use the y_true argument, not whatever `y` happens to exist globally.
    y_true = np.asarray(y_true, dtype=float)
    p = np.clip(np.asarray(y_pred, dtype=float), eps, 1 - eps)
    return -(y_true * np.log(p) + (1 - y_true) * np.log(1 - p))

In [108]:
class LogisticRegression:
    """Binary logistic regression trained with full-batch gradient descent.

    Attributes set by ``fit``:
        w      -- weight vector, one entry per feature
        b      -- bias term
        loss_  -- mean cross entropy on the training data, evaluated with the
                  final parameters
    """

    def __init__(self, n_epochs: int):
        """Store the number of gradient-descent steps ``fit`` will perform."""
        self._n_epochs = n_epochs

    def fit(self, X, y, lr=0.01):
        """Fit weights and bias to ``(X, y)`` by gradient descent.

        Parameters
        ----------
        X : array_like, shape (n_samples, n_features)
            Feature matrix.
        y : array_like, shape (n_samples,)
            Binary labels.
        lr : float, optional
            Learning rate (step size) of each update.

        Returns
        -------
        LogisticRegression
            ``self``, so calls can be chained.

        Raises
        ------
        ValueError
            If ``X`` is not 2-D, is empty, or ``y`` is not a 1-D array with
            one label per row of ``X``.
        """
        X = np.asarray(X, dtype=float)
        y = np.asarray(y, dtype=float)
        if X.ndim != 2:
            raise ValueError("X must be 2-D with shape (n_samples, n_features)")
        n_samples = X.shape[0]
        if n_samples == 0:
            raise ValueError("X must contain at least one sample")
        # A column-vector y would silently broadcast (y_pred - y) to an
        # (n, n) matrix, so insist on one label per sample.
        if y.shape != (n_samples,):
            raise ValueError("y must be 1-D with one label per row of X")

        self.w = np.ones(X.shape[1])
        self.b = 1.0

        for _ in range(self._n_epochs):
            # Residual between predicted probability and label; it is the
            # common factor of both gradients of the mean cross entropy.
            error = prob(X, self.w, self.b) - y

            dw = np.dot(X.T, error) / n_samples
            db = np.sum(error) / n_samples

            self.w -= lr * dw
            self.b -= lr * db

        # Evaluate the loss with the *final* parameters. Reusing the
        # predictions from inside the loop would report the loss before the
        # last update and would fail outright when n_epochs == 0.
        loss_after = cross_entropy(y, prob(X, self.w, self.b)).mean()
        self.loss_ = loss_after
        print(loss_after)

        return self

    def predict_proba(self, X):
        """Return P(y = 1 | x) for every row of ``X`` using the fitted parameters.

        ``self.w`` and ``self.b`` only exist after ``fit`` has been called.
        """
        return prob(X, self.w, self.b)

In [112]:
model = LogisticRegression(10000).fit(X, y)

0.689849512203447


In [113]:
model.predict_proba(X)

array([0.46174205, 0.52484484, 0.5015246 , 0.47113937, 0.45112005,
       0.4791733 , 0.45705089, 0.49661943, 0.47141215, 0.44044461,
       0.48584784, 0.51228096, 0.49923372, 0.46421701, 0.47198672,
       0.45020637, 0.46886376, 0.50017205, 0.46367617, 0.51317321,
       0.44804543, 0.45200138, 0.4587059 , 0.4602901 , 0.44794341,
       0.46247757, 0.4898743 , 0.48428779, 0.45316083, 0.46417689,
       0.52173621, 0.46813287, 0.48709019, 0.44860636, 0.4496718 ,
       0.50139973, 0.47359144, 0.48143366, 0.48213527, 0.48286214,
       0.50517267, 0.43954398, 0.46727843, 0.42892573, 0.46618513,
       0.45061474, 0.50812395, 0.50117489, 0.50079967, 0.48154921,
       0.49441695, 0.49724913, 0.46577479, 0.49027358, 0.49094945,
       0.46214863, 0.51078777, 0.48081273, 0.45857779, 0.45345474,
       0.47217172, 0.45803397, 0.51134581, 0.44615603, 0.45609825,
       0.45655351, 0.44792969, 0.46131672, 0.48099104, 0.43700109,
       0.4719436 , 0.45975328, 0.4662103 , 0.49611247, 0.48739