![part1](part1.jpg)

![part2](part2.jpg)

In [2]:
# -*- coding: utf-8 -*-
"""
Created on Thu Jun 29 14:48:17 2023

@author: Shree
"""

import numpy as np
from scipy.stats import gamma

def generate_data(n_features, n_points):
    alphas = np.random.uniform(0, 10, n_points)
    betas = np.random.uniform(0, 10, n_points)
    
    X = np.zeros((n_points, n_features))
    for i, (alpha, beta) in enumerate(zip(alphas, betas)):
        X[i] = gamma.rvs(alpha, scale=beta, size=n_features)
    
    return X, alphas, betas

def descent(a, b, c, d, X, alphas, betas, lr=1e-4):
    diff1 = ((a@X.T + b@np.square(X.T)) - alphas)
    diff2 = ((c@X.T + d@np.square(X.T)) - betas)
    
    new_a = a - lr * (diff1*X.T).mean(1)
    new_b = b - lr * (diff1*np.square(X.T)).mean(1)
    new_c = c - lr * (diff2*X.T).mean(1)
    new_d = d - lr * (diff2*np.square(X.T)).mean(1)
    
    # print(new_a, new_b, new_c, new_d)
    
    return new_a, new_b, new_c, new_d

def predict(a, b, c, d, X):
    return a@X.T + b@np.square(X).T, c@X.T + d@np.square(X).T

def error(a, b, c, d, X, true_alpha, true_beta):
    alpha, beta = predict(a, b, c, d, X)
    return abs(alpha-true_alpha).mean(), abs(beta-true_beta).mean()

In [6]:
n_features = 10
n_points = 1000

a = np.random.normal(0, 1, (1, n_features))
b = np.random.normal(0, 1, (1, n_features))
c = np.random.normal(0, 1, (1, n_features))
d = np.random.normal(0, 1, (1, n_features))

X, alphas, betas = generate_data(n_features, n_points)

for i in range(5000):
    # print(a, b, c, d)
    a, b, c, d = descent(a, b, c, d, X, alphas, betas, 2e-8)
    print(error(a, b, c, d, X, alphas, betas))

(1815.9414744762944, 2580.756860491041)
(1712.4989768236583, 2436.6133271287567)
(1614.9464494009846, 2302.58949260006)
(1523.1630993839944, 2176.2876400373775)
(1436.9434079688094, 2057.078894538144)
(1355.8788711307943, 1944.50388914041)
(1279.591444591117, 1838.2007808933988)
(1207.7819184240295, 1737.775554336844)
(1140.1672334730615, 1642.9031709835776)
(1076.7838812254133, 1553.4320400234396)
(1017.0977634249195, 1468.9752202962527)
(960.8669953008789, 1389.1894919605074)
(907.8754324058499, 1313.8244397297083)
(858.0098275342909, 1242.6239408877764)
(811.2668313079544, 1175.4531896139974)
(767.3289675209749, 1112.0096748519409)
(725.9325146996036, 1052.1140666505694)
(686.9035896759668, 995.5520767309869)
(650.0719049446984, 942.1111310544277)
(615.3520035687185, 891.6302051336737)
(582.604146926101, 843.9240087749962)
(551.7050191361351, 798.8208071891582)
(522.688360178828, 756.191750973878)
(495.33101887397027, 715.9016167175712)
(469.56052445023926, 677.8041968977093)
(445.2