In [1]:
from sklearn.datasets import load_boston

In [2]:
# Load the Boston housing dataset bundle; later cells read its 'data' and 'target' entries.
# NOTE(review): `load_boston` is reported as deprecated in scikit-learn 1.0 and removed in 1.2 —
# confirm the installed version, otherwise the import in the first cell fails.
data = load_boston()

In [3]:
# Unpack the feature matrix and the target vector from the dataset bundle.
# `Y` (upper case) is not referenced again in the visible notebook; the training
# cell re-unpacks the same entries into `X` and a lower-case `y`.
X, Y = data['data'], data['target']

In [12]:
def price(rm, k, b):
    """Evaluate the linear model f(x) = k * x + b.

    Args:
        rm: the input value x (the training cell passes entries of ``X_rm``).
        k: slope of the line.
        b: intercept of the line.

    Returns:
        The predicted value ``k * rm + b``.
    """
    predicted = k * rm + b
    return predicted

## loss
$$ loss = \frac{1}{n} \sum_{i=1}^{n} (y_i - \hat{y}_i)^2 $$
$$ loss = \frac{1}{n} \sum_{i=1}^{n} (y_i - (kx_i + b))^2 $$
$$ \frac{\partial{loss}}{\partial{k}} = -\frac{2}{n}\sum_{i=1}^{n} (y_i - (kx_i + b))x_i $$
$$ \frac{\partial{loss}}{\partial{k}} = -\frac{2}{n}\sum_{i=1}^{n} (y_i - \hat{y}_i)x_i $$
$$ \frac{\partial{loss}}{\partial{b}} = -\frac{2}{n}\sum_{i=1}^{n} (y_i - \hat{y}_i) $$

In [13]:
def loss(y, y_hat):
    """Mean squared error between targets and predictions.

    Computes ``sum((y_i - y_hat_i) ** 2) / n`` where ``n`` is the number of
    targets.

    Args:
        y: iterable of observed values (list, array, or one-shot iterator).
        y_hat: iterable of predicted values, same length as ``y``.

    Returns:
        The average of the squared differences.

    Raises:
        ValueError: if the two inputs have different lengths or are empty.
    """
    # Materialise each input exactly once so that one-shot iterators
    # (generators) are not exhausted before their length is taken.
    targets = list(y)
    predictions = list(y_hat)

    if len(targets) != len(predictions):
        # zip() would silently truncate to the shorter input while the
        # divisor stayed at len(y), giving a wrong average.
        raise ValueError(
            'y and y_hat must have the same length, got {} and {}'.format(
                len(targets), len(predictions)))
    if not targets:
        raise ValueError('loss is undefined for empty input')

    squared_errors = ((y_i - y_hat_i) ** 2 for y_i, y_hat_i in zip(targets, predictions))
    return sum(squared_errors) / len(targets)

In [14]:
def partial_k(x, y, y_hat):
    """Partial derivative of the mean-squared-error loss with respect to k.

    Implements ``-2/n * sum((y_i - y_hat_i) * x_i)``.

    Args:
        x: input values.
        y: observed values; must support ``len()``.
        y_hat: predicted values.

    Returns:
        The gradient of the loss with respect to the slope ``k``.
    """
    n = len(y)

    weighted_residuals = 0
    samples = zip(list(x), list(y), list(y_hat))
    for feature, target, predicted in samples:
        residual = target - predicted
        weighted_residuals += residual * feature

    scale = -2 / n
    return scale * weighted_residuals


def partial_b(x, y, y_hat):
    """Partial derivative of the mean-squared-error loss with respect to b.

    Implements ``-2/n * sum(y_i - y_hat_i)``.

    Args:
        x: accepted for symmetry with ``partial_k``; not read by this function.
        y: observed values; must support ``len()``.
        y_hat: predicted values.

    Returns:
        The gradient of the loss with respect to the intercept ``b``.
    """
    n = len(y)

    residual_total = 0
    pairs = zip(list(y), list(y_hat))
    for target, predicted in pairs:
        residual_total += target - predicted

    scale = -2 / n
    return scale * residual_total

In [19]:
# Fit price = k * rm + b by plain gradient descent on the mean squared error.
import random
# Number of gradient-descent iterations to run.
trying_times = 20000

X, y = data['data'], data['target']
# Use only column 5 of the feature matrix as the single input feature.
# NOTE(review): the name suggests this is the "RM" (rooms) column — confirm against the dataset's feature names.
X_rm = X[:, 5]

# Best (lowest) loss seen so far; starts at +inf so the first iteration always improves on it.
min_loss = float('inf') 

# Random starting point: k and b are each drawn uniformly from [-100, 100).
# No seed is set, so every run starts from a different point.
current_k = random.random() * 200 - 100
current_b = random.random() * 200 - 100

# Step size applied to both gradients.
learning_rate = 1e-03


# NOTE(review): never read or updated anywhere in this cell — looks like a leftover.
update_time = 0

for i in range(trying_times):
    
    # Predictions for every sample under the current parameters.
    price_by_k_and_b = [price(r, current_k, current_b) for r in X_rm]
    
    current_loss = loss(y, price_by_k_and_b)

    if current_loss < min_loss: # performance became better
        min_loss = current_loss
        
        # Report only on every 50th iteration, and only if that iteration improved the loss.
        if i % 50 == 0: 
            print('When time is : {}, get best_k: {} best_b: {}, and the loss is: {}'.format(i, current_k, current_b, min_loss))

    # Gradients are evaluated from the predictions computed at the top of this iteration.
    k_gradient = partial_k(X_rm, y, price_by_k_and_b)
    
    b_gradient = partial_b(X_rm, y, price_by_k_and_b)
    
    # Move each parameter against its gradient, scaled by the learning rate.
    current_k = current_k + (-1 * k_gradient) * learning_rate

    current_b = current_b + (-1 * b_gradient) * learning_rate

When time is : 0, get best_k: 43.28735114904168 best_b: -0.8214234323919527, and the loss is: 62466.565876456596
When time is : 50, get best_k: 5.275022511902216 best_b: -6.831371404869759, and the loss is: 65.16133260642869
When time is : 100, get best_k: 4.751586247666949 best_b: -6.947834322162105, and the loss is: 53.071239280356686
When time is : 150, get best_k: 4.749471115366129 best_b: -6.982301846779568, and the loss is: 53.04616400107392
When time is : 200, get best_k: 4.754599364209697 best_b: -7.015589656708968, and the loss is: 53.023476489542674
When time is : 250, get best_k: 4.759822143105302 best_b: -7.048821636874418, and the loss is: 53.00084392039752
When time is : 300, get best_k: 4.765040046279076 best_b: -7.082013463598246, and the loss is: 52.97826571165747
When time is : 350, get best_k: 4.770251698814842 best_b: -7.115165402195812, and the loss is: 52.95574173266913
When time is : 400, get best_k: 4.775457089022385 best_b: -7.148277503617354, and the loss is: 

When time is : 3850, get best_k: 5.119926307733239 best_b: -9.339486854539643, and the loss is: 51.50636750072987
When time is : 3900, get best_k: 5.12471151032812 best_b: -9.369926092945898, and the loss is: 51.487378818785814
When time is : 3950, get best_k: 5.129490962768936 best_b: -9.400328753943592, and the loss is: 51.46843574504398
When time is : 4000, get best_k: 5.1342646719653855 best_b: -9.43069488148609, and the loss is: 51.449538169959645
When time is : 4050, get best_k: 5.139032644818849 best_b: -9.461024519473932, and the loss is: 51.4306859842513
When time is : 4100, get best_k: 5.143794888222427 best_b: -9.491317711754906, and the loss is: 51.41187907889995
When time is : 4150, get best_k: 5.148551409060927 best_b: -9.52157450212411, and the loss is: 51.393117345148475
When time is : 4200, get best_k: 5.153302214210891 best_b: -9.551794934324018, and the loss is: 51.3744006745006
When time is : 4250, get best_k: 5.158047310540592 best_b: -9.58197905204455, and the los

When time is : 7500, get best_k: 5.454555557986983 best_b: -11.46810300489438, and the loss is: 50.23351597478609
When time is : 7550, get best_k: 5.458938652268924 best_b: -11.495984384869354, and the loss is: 50.21758450712771
When time is : 7600, get best_k: 5.463316479591547 best_b: -11.523832261094562, and the loss is: 50.20169130466331
When time is : 7650, get best_k: 5.467689046283912 best_b: -11.551646673829888, and the loss is: 50.18583627548514
When time is : 7700, get best_k: 5.472056358667472 best_b: -11.579427663286856, and the loss is: 50.170019327906175
When time is : 7750, get best_k: 5.476418423056086 best_b: -11.607175269628645, and the loss is: 50.154240370459696
When time is : 7800, get best_k: 5.480775245756021 best_b: -11.634889532970186, and the loss is: 50.138499311898634
When time is : 7850, get best_k: 5.485126833065972 best_b: -11.662570493378208, and the loss is: 50.12279606119508
When time is : 7900, get best_k: 5.48947319127706 best_b: -11.690218190871278,

When time is : 11200, get best_k: 5.765080146941997 best_b: -13.443386560144527, and the loss is: 49.152229519951405
When time is : 11250, get best_k: 5.769090098354639 best_b: -13.468894334293424, and the loss is: 49.13889515362747
When time is : 11300, get best_k: 5.773095231196291 best_b: -13.49437145694362, and the loss is: 49.125592814617384
When time is : 11350, get best_k: 5.777095551257204 best_b: -13.519817964927578, and the loss is: 49.112322425995785
When time is : 11400, get best_k: 5.781091064320672 best_b: -13.545233895033519, and the loss is: 49.09908391102237
When time is : 11450, get best_k: 5.785081776163042 best_b: -13.570619284005447, and the loss is: 49.08587719314098
When time is : 11500, get best_k: 5.789067692553712 best_b: -13.595974168543222, and the loss is: 49.07270219597934
When time is : 11550, get best_k: 5.793048819255154 best_b: -13.621298585302577, and the loss is: 49.05955884334874
When time is : 11600, get best_k: 5.797025162022912 best_b: -13.646592

When time is : 15050, get best_k: 6.0601615708065175 best_b: -15.320434319811072, and the loss is: 48.21380966936583
When time is : 15100, get best_k: 6.063816936678773 best_b: -15.34368653349305, and the loss is: 48.20272925844417
When time is : 15150, get best_k: 6.067467910068892 best_b: -15.366910806077804, and the loss is: 48.191675461145415
When time is : 15200, get best_k: 6.071114496255118 best_b: -15.390107171140844, and the loss is: 48.18064821354739
When time is : 15250, get best_k: 6.074756700509342 best_b: -15.413275662217325, and the loss is: 48.16964745188124
When time is : 15300, get best_k: 6.07839452809713 best_b: -15.436416312802116, and the loss is: 48.15867311253149
When time is : 15350, get best_k: 6.082027984277717 best_b: -15.45952915634984, and the loss is: 48.147725132035376
When time is : 15400, get best_k: 6.085657074304022 best_b: -15.482614226274903, and the loss is: 48.13680344708251
When time is : 15450, get best_k: 6.089281803422643 best_b: -15.50567155

When time is : 18750, get best_k: 6.319129530870593 best_b: -16.967760074504707, and the loss is: 47.46176983847708
When time is : 18800, get best_k: 6.322473707062189 best_b: -16.9890327739046, and the loss is: 47.452495723786754
When time is : 18850, get best_k: 6.325813864713724 best_b: -17.01027991089696, and the loss is: 47.443243884244254
When time is : 18900, get best_k: 6.3291500086540955 best_b: -17.031501516198926, and the loss is: 47.434014266347745
When time is : 18950, get best_k: 6.332482143706385 best_b: -17.052697620490747, and the loss is: 47.42480681672394
When time is : 19000, get best_k: 6.335810274687896 best_b: -17.073868254415796, and the loss is: 47.41562148212762
When time is : 19050, get best_k: 6.339134406410125 best_b: -17.095013448580616, and the loss is: 47.406458209441574
When time is : 19100, get best_k: 6.342454543678799 best_b: -17.11613323355499, and the loss is: 47.39731694567608
When time is : 19150, get best_k: 6.345770691293863 best_b: -17.1372276

## Dynamic Programming

In [20]:
from collections import defaultdict

# Listed prices; the entry at position i (0-based) is the price for length i + 1.
original_price = [1, 5, 8, 9, 10, 17, 17, 20, 24, 30, 35]

# Length -> price lookup; any length without a listed price reads as 0.
# Note: this rebinds the name `price`, which an earlier cell defined as the
# linear-model function, so that function is no longer reachable by this name.
price = defaultdict(int, enumerate(original_price, start=1))

In [21]:
price[10]

30

In [22]:
price[11]

35

In [25]:
def r(n):
    """Best revenue obtainable from a piece of length ``n``.

    Compares selling the piece whole (``price[n]``) with every way of
    splitting it into two parts, each solved recursively. No results are
    cached, so the same sub-lengths are recomputed many times.
    """
    best = price[n]
    for cut in range(1, n):
        revenue = r(cut) + r(n - cut)
        if revenue > best:
            best = revenue
    return best

In [26]:
# Plain decorator (no functools.wraps) that additionally records how many
# times each decorated function has been called, keyed by function name.
called_time = defaultdict(int)


def get_call_times(f):
    """Decorate ``f`` so that every call is announced and counted.

    Args:
        f: the function to wrap.

    Returns:
        A wrapper that forwards all arguments to ``f``, prints a message,
        increments ``called_time[f.__name__]`` and returns ``f``'s result.
        The wrapper does not copy ``f``'s name or docstring.
    """
    # Return a wrapper instead of invoking f here: calling f() at decoration
    # time would bind the decorated name to a result (or fail outright for
    # functions that require arguments).
    def counted(*args, **kwargs):
        result = f(*args, **kwargs)
        print('function: {} called once! '.format(f.__name__))
        called_time[f.__name__] += 1

        return result

    return counted

In [35]:
r(12) # raises the issue: the computation takes too long

36

In [36]:
from functools import wraps
def get_call_time(f):
    """Decorator that counts how often ``f`` is called with each argument.

    Counts are kept in ``get_call_time.already_computed``, a mapping from
    ``(f.__name__, n)`` to the number of completed calls. Cells may rebind
    that attribute to a fresh ``defaultdict(int)`` to reset the counts; the
    wrapper looks it up on every call.

    Args:
        f: a function of one positional argument.

    Returns:
        A wrapper with ``f``'s name and docstring (via ``functools.wraps``).
    """
    # Make sure the counter exists, so a decorated function can be called even
    # when no cell has assigned `get_call_time.already_computed` beforehand.
    if not hasattr(get_call_time, 'already_computed'):
        from collections import defaultdict
        get_call_time.already_computed = defaultdict(int)

    @wraps(f)
    def wrap(n):
        """Call f(n) and record the call (replaced by f's docstring via wraps)."""
        result = f(n)
        # Counted only after f returns, so a call that raises is not recorded.
        get_call_time.already_computed[(f.__name__, n)] += 1
        return result
    return wrap

In [41]:
# Fresh bookkeeping for this definition of `r`: the chosen split per length
# and the per-argument call counters.
solution = {}
get_call_time.already_computed = defaultdict(int)


@get_call_time
def r(n):
    """Best revenue for a piece of length ``n``, recording how it is split.

    Args:
        n: the piece length.

    Returns:
        The maximum revenue. As a side effect ``solution[n]`` is set to
        ``(n - split, split)``, where ``split == 0`` means "sell uncut".
    """
    # Option 0 is "do not cut"; every other option cuts off `cut` units.
    candidates = [(price[n], 0)]
    for cut in range(1, n):
        candidates.append((r(cut) + r(n - cut), cut))

    # max() keeps the first of equally good candidates, i.e. the smallest cut.
    max_price, max_split = max(candidates, key=lambda candidate: candidate[0])

    solution[n] = (n - max_split, max_split)

    return max_price

In [42]:
r(10)

30

In [65]:
# Reset the per-argument call counters, run r(200), then display the counters.
# NOTE(review): the recorded output is an empty counter, and this cell's execution
# count (65) is later than that of the cell defining the @memo version of `r` (61),
# so `r` was presumably the memoized version when this ran — confirm by re-running in order.
get_call_time.already_computed = defaultdict(int)
r(200)
get_call_time.already_computed # still does not solve the problem that the traversal takes too long

defaultdict(int, {})

In [48]:
solution

{1: (1, 0),
 2: (2, 0),
 3: (3, 0),
 4: (2, 2),
 5: (3, 2),
 6: (6, 0),
 7: (6, 1),
 8: (6, 2),
 9: (6, 3),
 10: (10, 0),
 11: (11, 0),
 12: (11, 1),
 13: (11, 2),
 14: (11, 3),
 15: (13, 2),
 16: (14, 2)}

In [49]:
def parse_solution(n):
    """Expand the recorded split for length ``n`` into a flat list of pieces.

    Reads ``solution[n]`` as a ``(left, right)`` pair. A ``right`` of 0 marks
    a piece that is kept whole; otherwise both parts are expanded
    recursively, left part first.

    Returns:
        A list of piece lengths.
    """
    kept, cut_off = solution[n]

    if right_is_cut := (cut_off != 0):
        return parse_solution(kept) + parse_solution(cut_off)

    return [kept]

In [52]:
parse_solution(16)

[11, 3, 2]

In [59]:
def memo(f):
    """Memoize a one-argument function.

    Every decorated function gets its own cache, available as the
    ``already_computed`` attribute of the returned wrapper. The cache of the
    most recently decorated function is also published as
    ``memo.already_computed`` so it can be inspected from a cell. To drop
    cached results call ``.clear()`` on the cache; rebinding
    ``memo.already_computed`` does not affect functions already decorated.

    Args:
        f: a function of one hashable argument.

    Returns:
        A wrapper with ``f``'s name and docstring that calls ``f`` at most
        once per distinct argument.
    """
    # One cache per decorated function: a single cache keyed only by the
    # argument would let two memoized functions return each other's results.
    cache = defaultdict(int)
    memo.already_computed = cache

    @wraps(f)
    def _wrap(arg):
        # Membership test rather than indexing, so the defaultdict never
        # fabricates a 0 entry for an argument that has not been computed.
        if arg not in cache:
            cache[arg] = f(arg)

        return cache[arg]

    _wrap.already_computed = cache
    return _wrap

In [60]:
memo.already_computed = {}


In [61]:
# Start from an empty split table and an empty memo cache attribute before
# redefining `r` with memoization.
solution = {}
memo.already_computed = defaultdict(int)


@memo
def r(n):
    """Memoized best revenue for a piece of length ``n``.

    Args:
        n: the piece length.

    Returns:
        The maximum revenue. As a side effect ``solution[n]`` is set to
        ``(n - split, split)``, where ``split == 0`` means "sell uncut".
    """
    # Start from "sell uncut" and keep the first strictly better split.
    max_price, max_split = price[n], 0
    for cut in range(1, n):
        revenue = r(cut) + r(n - cut)
        if revenue > max_price:
            max_price, max_split = revenue, cut

    solution[n] = (n - max_split, max_split)

    return max_price

In [72]:
# NOTE(review): presumably a reminder that a much larger n would exceed Python's recursion limit — confirm.
r(300) # maximum recursion depth

953

In [73]:
memo.already_computed

defaultdict(int,
            {1: 1,
             2: 5,
             3: 8,
             4: 10,
             5: 13,
             6: 17,
             7: 18,
             8: 22,
             9: 25,
             10: 30,
             11: 35,
             12: 36,
             13: 40,
             14: 43,
             15: 45,
             16: 48,
             17: 52,
             18: 53,
             19: 57,
             20: 60,
             21: 65,
             22: 70,
             23: 71,
             24: 75,
             25: 78,
             26: 80,
             27: 83,
             28: 87,
             29: 88,
             30: 92,
             31: 95,
             32: 100,
             33: 105,
             34: 106,
             35: 110,
             36: 113,
             37: 115,
             38: 118,
             39: 122,
             40: 123,
             41: 127,
             42: 130,
             43: 135,
             44: 140,
             45: 141,
             46: 145,
             4

In [74]:
solution

{1: (1, 0),
 2: (2, 0),
 3: (3, 0),
 4: (2, 2),
 5: (3, 2),
 6: (6, 0),
 7: (6, 1),
 8: (6, 2),
 9: (6, 3),
 10: (10, 0),
 11: (11, 0),
 12: (11, 1),
 13: (11, 2),
 14: (11, 3),
 15: (13, 2),
 16: (14, 2),
 17: (11, 6),
 18: (17, 1),
 19: (17, 2),
 20: (17, 3),
 21: (11, 10),
 22: (11, 11),
 23: (22, 1),
 24: (22, 2),
 25: (22, 3),
 26: (24, 2),
 27: (25, 2),
 28: (22, 6),
 29: (28, 1),
 30: (28, 2),
 31: (28, 3),
 32: (22, 10),
 33: (22, 11),
 34: (33, 1),
 35: (33, 2),
 36: (33, 3),
 37: (35, 2),
 38: (36, 2),
 39: (33, 6),
 40: (39, 1),
 41: (39, 2),
 42: (39, 3),
 43: (33, 10),
 44: (33, 11),
 45: (44, 1),
 46: (44, 2),
 47: (44, 3),
 48: (46, 2),
 49: (47, 2),
 50: (44, 6),
 51: (50, 1),
 52: (50, 2),
 53: (50, 3),
 54: (44, 10),
 55: (44, 11),
 56: (55, 1),
 57: (55, 2),
 58: (55, 3),
 59: (57, 2),
 60: (58, 2),
 61: (55, 6),
 62: (61, 1),
 63: (61, 2),
 64: (61, 3),
 65: (55, 10),
 66: (55, 11),
 67: (66, 1),
 68: (66, 2),
 69: (66, 3),
 70: (68, 2),
 71: (69, 2),
 72: (66, 6),


In [76]:
parse_solution(88)

[11, 11, 11, 11, 11, 11, 11, 11]

In [77]:
parse_solution(55)

[11, 11, 11, 11, 11]