In [1]:
import ipywidgets as widgets

In [2]:
import matplotlib.pyplot as plt
import random


def getConsumption(assets: float, expectedincomepath: list, periodevaluated: int, interestrate: float) -> float:
    '''Consumption for the 1-indexed period ``periodevaluated``.

    The baseline spreads ``assets`` plus all income from ``periodevaluated``
    to the end of ``expectedincomepath`` evenly over those remaining periods.
    The baseline is then shifted by the average of
    ``interestrate * (income - baseline)`` over the same periods.
    Algebraically that shift equals ``-interestrate * assets / n``
    (n = number of remaining periods), up to floating-point rounding.
    NOTE(review): confirm this is the intended treatment of interest.

    Returns the result rounded to 4 decimal places. Raises ZeroDivisionError
    when no periods remain (periodevaluated > len(expectedincomepath)).
    '''
    remaining = expectedincomepath[periodevaluated - 1:]
    n_remaining = len(remaining)
    baseline = (assets + sum(remaining)) / n_remaining
    # sum() adds left to right from 0, same as an explicit accumulator loop.
    shift = sum(interestrate * (income - baseline) for income in remaining)
    return round(baseline + shift / n_remaining, 4)

class InteractiveSmoothing:
    '''Consumption-smoothing game state: income path, player path, optimal path.

    Income is INCOME_PER_PERIOD in every period except:
    - two random "expected" periods, known from the start;
    - two random "unexpected" periods, whose income is replaced by a shock once
      that period is reached.
    Period keys in ``expected`` and ``unexpected`` are 1-indexed. The draws use
    the module-level ``random`` generator, so seeding ``random`` before
    construction rebuilds the same game.
    '''

    def __init__(
        self,
        l: list,
        INCOME_PER_PERIOD: int = 100,
        PERIODS: int = 20,
        ASSETS: int = 0,
        R: float = 0) -> None:
        '''Draw the special periods and compute the income and optimal paths.

        Args:
            l: the player's consumption choices so far, one per period in order.
               Stored by reference, not copied.
            INCOME_PER_PERIOD: income in ordinary periods.
            PERIODS: number of periods in the game.
            ASSETS: initial assets.
            R: per-period interest rate.

        Raises:
            ValueError: from random.sample when PERIODS is too small to hold
                two expected and two unexpected periods.
        '''
        # TODO: make the value/period ranges of the random draws configurable.
        self.PERIODS = PERIODS
        self.R = R
        self.ASSETS = ASSETS

        # Keep the order of these draws fixed: callers reseed `random` and
        # rebuild the game, relying on getting identical paths back.
        randis = random.sample(range(75, 125, 5), k=4)  # incomes for the 4 special periods
        # Upper bounds are clamped to the horizon so no special period falls
        # past PERIODS (the ranges are unchanged for PERIODS >= 20).
        randpsexpected = random.sample(range(2, min(15, PERIODS + 1)), k=2)
        randpsunexpected = random.sample(range(6, min(21, PERIODS + 1)), k=2)

        self.expected = {randpsexpected[0]: randis[0], randpsexpected[1]: randis[1]}
        self.unexpected = {randpsunexpected[0]: randis[2], randpsunexpected[1]: randis[3]}
        self.unexpectedperiods = list(self.unexpected.keys())

        # Income known at the start: baseline plus the expected periods.
        self.initialincomepath = [INCOME_PER_PERIOD] * PERIODS
        for period, income in self.expected.items():
            self.initialincomepath[period - 1] = income

        self.consumptionpath = l
        self.optimalpath = self.getOptimalPath()

    def getOptimalPath(self):
        '''Compute the consumption path chosen by the model consumer.

        Walks periods 1..PERIODS. At the start of each period, interest R
        accrues on the carried-over assets. If the period is unexpected, its
        income is replaced by the shock before consumption is chosen with
        getConsumption. Savings (income minus consumption) are then added to
        assets, so interest compounds across periods.

        Side effect: stores the income path including all shocks in
        ``self.expectedincomepath``.

        Returns:
            list with one consumption value per period.
        '''
        self.expectedincomepath = self.initialincomepath.copy()
        assets = self.ASSETS
        consumptionpath = []
        for period in range(1, self.PERIODS + 1):
            assets = assets * (1 + self.R)
            if period in self.unexpectedperiods:
                # The shock becomes known at the start of its own period.
                self.expectedincomepath[period - 1] = self.unexpected[period]
            consumption = getConsumption(assets, self.expectedincomepath, period, self.R)
            consumptionpath.append(consumption)
            assets += self.expectedincomepath[period - 1] - consumption
        return consumptionpath

    @staticmethod
    def _plot_steps(values, color):
        '''Draw `values` as steps (index i spans x in [i, i+1]) with gray connectors.'''
        for i, value in enumerate(values):
            plt.plot([i, i + 1], [value, value], c=color)
            if i + 1 < len(values):
                plt.plot([i + 1, i + 1], [value, values[i + 1]], c='gray')

    def _plot_shock(self, period):
        '''Draw the unexpected income of 1-indexed `period` in blue over its own slot.'''
        i = period - 1
        shock = self.unexpected[period]
        plt.plot([i, i + 1], [shock, shock], c='b')
        # Connect to the neighbouring initial incomes, staying inside the horizon.
        if i > 0:
            plt.plot([i, i], [self.initialincomepath[i - 1], shock], c='gray')
        if i + 1 < len(self.initialincomepath):
            plt.plot([i + 1, i + 1], [shock, self.initialincomepath[i + 1]], c='gray')

    @staticmethod
    def _show():
        '''Label the axes and display the current figure.'''
        plt.xlabel('period')
        plt.ylabel('income / consumption')
        plt.show()

    def PlotAll(self, plot_optimal=False, plot_unexpected=False):
        '''Plot the game state and show the figure.

        Always draws the initial income path (green) for every period and the
        player's consumption so far (red).

        Args:
            plot_optimal: also draw the full optimal consumption path (pink).
            plot_unexpected: also draw (blue) the shock of every unexpected
                period up to and including the next period to be chosen.
                getOptimalPath learns a shock at the start of its period, so
                this gives the player the same information. Each shock is
                drawn over its own period.
        '''
        self._plot_steps(self.initialincomepath, 'g')
        self._plot_steps(self.consumptionpath, 'r')
        if plot_optimal:
            self._plot_steps(self.optimalpath, 'pink')
        if plot_unexpected:
            next_period = min(len(self.consumptionpath) + 1, self.PERIODS)
            for period in self.unexpectedperiods:
                if period <= next_period:
                    self._plot_shock(period)
        self._show()

    def PlotInitial(self):
        '''Plot only the initial income path (green) and show the figure.'''
        self._plot_steps(self.initialincomepath, 'g')
        self._show()


In [3]:
# Bordered output area that hosts the whole game UI (controls and plots).
out1 = widgets.Output(
    layout=dict(border='2px solid gray', padding='5px 5px 5px 5px', margin='0 0 0 0'),
)
out1

Output(layout=Layout(border='2px solid gray', margin='0 0 0 0', padding='5px 5px 5px 5px'))

In [4]:
# Game UI: pick a seed, press Start, then submit one consumption value per period.
l = []  # player's consumption choices so far; passed to InteractiveSmoothing
a = widgets.FloatText(value=100, description='consumption')
sub = widgets.Button(description='Submit')
start = widgets.Button(description='Start')
seed = widgets.IntText(value=random.randint(0, 10000), description='random seed')

with out1:
    out1.clear_output(wait=True)
    display(seed, start)


def on_start(b):
    '''Start (or restart) a game with the chosen seed and show the income path.'''
    # Drop choices from any previous game; otherwise a restart would keep
    # appending to the old path and overrun the horizon.
    l.clear()
    random.seed(seed.value)
    sim = InteractiveSmoothing([])
    with out1:
        out1.clear_output(wait=True)
        display(a, sub)
        sim.PlotInitial()


def on_sub(b):
    '''Record one consumption choice and redraw; show the optimal path at the end.'''
    # Reseeding rebuilds exactly the same game (income path and shocks).
    random.seed(seed.value)
    l.append(a.value)
    sim = InteractiveSmoothing(l)
    with out1:
        out1.clear_output(wait=True)
        if len(l) >= sim.PERIODS:
            # Print after clearing: with wait=True the clear is deferred until
            # the next output arrives, so a message printed earlier is wiped.
            print('DONE - pink = optimal path, red = your path')
            sim.PlotAll(True, True)
        else:
            display(a, sub)
            sim.PlotAll(plot_unexpected=True)


start.on_click(on_start)
sub.on_click(on_sub)

