In [6]:
import pandas as pd
import numpy as np
import os
os.chdir(r'C:\Users\Daniel\Documents\Python Scripts\Kaggle')

In [1]:
def predict_sales(radio, weight, bias):
    return weight*radio + bias


In [2]:
def cost_function(radio, sales, weight, bias):
    companies = len(radio)
    total_error = 0.0
    for i in range(companies):
        total_error += (sales[i] - (weight*radio[i] + bias))**2
    return total_error / companies

In [3]:
def update_weights(radio, sales, weight, bias, learning_rate):
    weight_deriv = 0
    bias_deriv = 0
    companies = len(radio)

    for i in range(companies):
        # Calculate partial derivatives
        # -2x(y - (mx + b))
        weight_deriv += -2*radio[i] * (sales[i] - (weight*radio[i] + bias))

        # -2(y - (mx + b))
        bias_deriv += -2*(sales[i] - (weight*radio[i] + bias))

    # We subtract because the derivatives point in direction of steepest ascent
    weight -= (weight_deriv / companies) * learning_rate
    bias -= (bias_deriv / companies) * learning_rate

    return weight, bias

In [5]:
def train(radio, sales, weight, bias, learning_rate, iters):
    cost_history = []

    for i in range(iters):
        weight,bias = update_weights(radio, sales, weight, bias, learning_rate)

        #Calculate cost for auditing purposes
        cost = cost_function(radio, sales, weight, bias)
        cost_history.append(cost)

        # Log Progress
        if i % 10 == 0:
            print("iter={:d}    weight={:.2f}    bias={:.4f}    cost={:.2}".format(i, weight, bias, cost))

    return weight, bias, cost_history

In [7]:
adv = pd.read_csv('advertising.csv')
sales = list(adv['Sales'])
radio = list(adv['Radio'])

In [8]:
train(radio, sales, 1, 1, .01, 50)

iter=0    weight=-7.09    bias=0.8173    cost=4.4e+04
iter=10    weight=-2563802288829.82    bias=-78453407752.1893    cost=5e+27
iter=20    weight=-869790250988801773010944.00    bias=-26615940518679924441088.0000    cost=5.8e+50
iter=30    weight=-295083237896761287432512507547222016.00    bias=-9029668818420438703048493625245696.0000    cost=6.6e+73
iter=40    weight=-100109327724297252049901816458456750254995800064.00    bias=-3063386729209543434406942450028056668431974400.0000    cost=7.6e+96


(2.3875449080957245e+57,
 7.305985918709849e+55,
 [43575.929012523724,
  8805468.661484092,
  1781785945.5607898,
  360546692125.3616,
  72957092200527.44,
  1.4762962521846224e+16,
  2.9873046724857687e+18,
  6.044849868750675e+20,
  1.2231832351177458e+23,
  2.4751271895233017e+25,
  5.008452068694185e+27,
  1.0134667919525475e+30,
  2.0507632384278386e+32,
  4.1497460927994086e+34,
  8.397065205784292e+36,
  1.699157068731091e+39,
  3.4382664341227682e+41,
  6.957376860306123e+43,
  1.4078342590303945e+46,
  2.8487709386673263e+48,
  5.7645250560849765e+50,
  1.1664591445803134e+53,
  2.360345254356682e+55,
  4.776189329604525e+57,
  9.664681245306039e+59,
  1.9556608234603732e+62,
  3.957305118857631e+64,
  8.007658391411322e+66,
  1.620360093235639e+69,
  3.278819729081701e+71,
  6.6347343782997e+73,
  1.3425471330477955e+76,
  2.7166615898747854e+78,
  5.497200070098584e+80,
  1.1123655858838384e+83,
  2.2508862345927043e+85,
  4.554697579081706e+87,
  9.216489806579082e+89,
  1.