In [None]:
import numpy as np

def sigmoid(z):
    """Return the logistic function ``1 / (1 + exp(-z))`` element-wise.

    The value is computed in a numerically stable way, so large-magnitude
    inputs saturate to 0 or 1 without triggering floating-point overflow.

    Parameters
    ----------
    z : scalar or array-like
        Input logits. Non-floating input (ints, bools, lists of ints) is
        converted to float64; floating arrays keep their dtype.

    Returns
    -------
    NumPy scalar or ndarray
        Values in [0, 1] with the same shape as ``z``.
    """
    z = np.asarray(z)
    if not np.issubdtype(z.dtype, np.floating):
        z = z.astype(np.float64)
    # exp(-|z|) always lies in (0, 1], so it cannot overflow, whereas the
    # naive exp(-z) blows up for large negative z.
    decay = np.exp(-np.abs(z))
    # For z >= 0: 1 / (1 + e^-z).  For z < 0: e^z / (1 + e^z).  Both are the
    # same function; each branch is the overflow-free form for its sign.
    out = np.where(z >= 0, 1 / (1 + decay), decay / (1 + decay))
    # Indexing with () unwraps a 0-d result to a NumPy scalar and leaves
    # n-d arrays as arrays, matching the scalar-in / scalar-out behaviour.
    return out[()]
    
def binary_cross_entropy(y, y_pred):
    """Return the mean binary cross-entropy between targets and predictions.

    Parameters
    ----------
    y : ndarray
        Target values; its first dimension is used as the sample count.
    y_pred : ndarray
        Predicted probabilities, broadcast-compatible with ``y``.

    Returns
    -------
    float
        ``-1/m * sum(y*log(y_pred + eps) + (1 - y)*log(1 - y_pred + eps))``
        with ``m = y.shape[0]`` and ``eps = 1e-8``.
    """
    # Small offset keeps log() finite when a prediction is exactly 0 or 1.
    eps = 1e-8
    n_samples = y.shape[0]
    positive_term = y * np.log(y_pred + eps)
    negative_term = (1 - y) * np.log(1 - y_pred + eps)
    total = np.sum(positive_term + negative_term)
    return -1 / n_samples * total

class LogisticRegression:
    """
    Logistic regression trained with batch gradient descent.

    Supports both binary and multiclass classification. Every problem is
    handled one-vs-rest: one binary classifier is fitted per distinct label,
    and ``predict`` returns the label whose classifier gives the highest
    probability.

    Parameters
    ----------
    lr : float
        Gradient-descent step size.
    tol : float
        Training of a binary classifier stops early once both the norm of
        the weight gradient and the absolute bias gradient are below ``tol``.
    max_iter : int
        Maximum number of gradient steps per binary classifier.

    Attributes
    ----------
    W : list of ndarray, or None before ``fit``
        One ``(n_features, 1)`` weight vector per class.
    b : list of float, or None before ``fit``
        One bias per class.
    classes_ : ndarray, or None before ``fit``
        Sorted distinct labels seen in ``fit``; ``W[i]``/``b[i]`` belong to
        ``classes_[i]``.
    n_classes : int, or None before ``fit``
        Number of distinct labels.
    """

    def __init__(self, lr=0.001, tol=0.0001, max_iter=1000):
        self.lr = lr
        self.tol = tol
        self.max_iter = max_iter
        self.W = None
        self.b = None
        self.classes_ = None
        self.n_classes = None

    def fit_binary(self, X, y):
        """Fit a single binary classifier and return its ``(w, b)``.

        Parameters
        ----------
        X : array-like, shape (n_samples, n_features)
            Training observations.
        y : array-like, n_samples elements
            Binary targets (0/1 or boolean).

        Returns
        -------
        w : ndarray, shape (n_features, 1)
        b : float
        """
        X = np.asarray(X, dtype=float)
        n_samples, n_features = X.shape
        # Cast so boolean one-vs-rest masks become 0.0/1.0 targets, and use a
        # column vector so it lines up with the (n_samples, 1) predictions.
        y = np.asarray(y, dtype=float).reshape(n_samples, 1)

        w = np.random.randn(n_features, 1)
        b = 0.0

        for _ in range(self.max_iter):
            residual = sigmoid(X @ w + b) - y
            dw = (X.T @ residual) / n_samples
            db = np.sum(residual) / n_samples
            w -= self.lr * dw
            b -= self.lr * db
            # Converged when the gradient is (almost) flat in every direction.
            if np.linalg.norm(dw) < self.tol and abs(db) < self.tol:
                break
        return w, b

    def fit(self, X, y):
        """Fit one one-vs-rest classifier per distinct label in ``y``.

        Parameters
        ----------
        X : array-like, shape (n_samples, n_features)
            Training observations.
        y : array-like, n_samples elements
            Class labels; any values comparable with ``==`` and sortable by
            ``np.unique`` (ints, strings, ...).

        Raises
        ------
        ValueError
            If ``X`` and ``y`` disagree on the number of samples.
        """
        X = np.asarray(X, dtype=float)
        y = np.asarray(y).ravel()
        if X.shape[0] != y.shape[0]:
            raise ValueError(
                f"X has {X.shape[0]} samples but y has {y.shape[0]} labels"
            )

        classes = np.unique(y)
        fitted = [self.fit_binary(X, y == label) for label in classes]

        # Assign only after every classifier trained, so a failure above does
        # not leave the model half-populated.
        self.W = [w for w, _ in fitted]
        self.b = [b for _, b in fitted]
        self.classes_ = classes
        self.n_classes = len(classes)

    def predict_binary(self, X, w, b):
        """Return the ``(n_samples, 1)`` probabilities of one classifier."""
        logits = np.asarray(X, dtype=float) @ w + b
        return sigmoid(logits)

    def predict(self, X):
        """Predict a class label for every row of ``X``.

        Parameters
        ----------
        X : array-like, shape (n_samples, n_features)

        Returns
        -------
        ndarray, shape (n_samples,)
            The original labels passed to ``fit`` (not column indices).

        Raises
        ------
        RuntimeError
            If the model has not been fitted.
        """
        if self.W is None or self.b is None:
            raise RuntimeError(
                "LogisticRegression.predict() called before fit()"
            )
        X = np.asarray(X, dtype=float)
        # Column i holds the probability assigned by the classifier of class i.
        probs = np.hstack(
            [self.predict_binary(X, w, b) for w, b in zip(self.W, self.b)]
        )
        best = probs.argmax(axis=1)
        if self.classes_ is None:
            # Weights were set by hand without labels; fall back to indices.
            return best
        return self.classes_[best]

# Train a classifier with the default hyperparameters (lr=0.001, tol=0.0001,
# max_iter=1000).
# NOTE(review): X_train and y_train are not defined in this cell; presumably
# they come from an earlier notebook cell. Confirm before running standalone.
log_reg = LogisticRegression()
log_reg.fit(X_train, y_train)