Here is a tool for nonlinear approximation of the retention rate using a model containing 3 elements: the main trend function, offsets for patches and weekly seasonality.

Run the notebook's first block, then run the second block and use the UI to fill in the settings.

Finally, run the third block to start the learning process and display the results.

Dmitry Baltin, 2023,

https://github.com/dmitrybaltin/retention-rate-approximator

In [4]:
#@title Import libraries, define classes and functions
import numpy as np
import torch
import torch.nn as nn

import os
from google.colab import files
import pandas as pd
import io
import matplotlib.pyplot as plt

class ConstantFunction(torch.nn.Module):
    """Trainable constant: ``f(x) = w0`` for every element of ``x``."""

    def __init__(self, initial_weights=None):
        """Create the function.

        ``initial_weights`` is wrapped into a one-element list before it is
        turned into the parameter; 0.25 is used when it is omitted.
        """
        super().__init__()

        start_value = 0.25 if initial_weights is None else initial_weights
        self.w = nn.Parameter(torch.Tensor([start_value]))

    def forward(self, x):
        """Return a tensor shaped like ``x`` filled with the weight."""
        return torch.ones_like(x) * self.w[0]

    def reset_weights(self, new_w=None):
        """Replace the parameter with ``new_w`` (``[0]`` when omitted)."""
        replacement = [0] if new_w is None else new_w
        self.w = nn.Parameter(torch.Tensor(replacement))

    def init_weights_from_train_data(self, x_train, y_train):
        """Set the weight to the mean of ``y_train``; ``x_train`` is not used."""
        self.w = nn.Parameter(torch.mean(y_train).unsqueeze(0))

class LinearFunction(torch.nn.Module):
    """Trainable straight line: ``f(x) = w0 + w1 * x``."""

    def __init__(self, initial_weights=None):
        """Create the line from ``[w0, w1]``; ``[0.25, 0.25]`` when omitted."""
        super().__init__()

        if initial_weights is not None:
            self.w = nn.Parameter(torch.Tensor(initial_weights))
        else:
            self.w = nn.Parameter(torch.Tensor([0.25, 0.25]))

    def forward(self, x):
        """Evaluate ``w0 + w1 * x`` element-wise."""
        return x * self.w[1] + self.w[0]

    def reset_weights(self, new_w=None):
        """Replace the parameter with ``new_w`` (``[0, 0]`` when omitted)."""
        if new_w is None:
            new_w = [0, 0]
        self.w = nn.Parameter(torch.Tensor(new_w))

    def init_weights_from_train_data(self, x_train, y_train):
        """Start from the line through the first and the last training point.

        "First" and "last" are the samples with the smallest and the largest
        ``x_train``. The previous version stored ``[last, first - last]``,
        which is the initial guess of the hyperbolic model and, for this
        linear model, describes a line through neither point.
        """
        first_index = torch.argmin(x_train)
        last_index = torch.argmax(x_train)

        x_first = float(x_train[first_index])
        x_last = float(x_train[last_index])
        y_first = float(y_train[first_index])
        y_last = float(y_train[last_index])

        # A single distinct x gives no slope information: use a flat line.
        span = x_last - x_first
        slope = (y_last - y_first) / span if span != 0 else 0.0
        intercept = y_first - slope * x_first

        self.w = nn.Parameter(torch.Tensor([intercept, slope]))

# Approximation of a decreasing process by a linear-fractional function
class InverseFunction(torch.nn.Module):
    """Trainable hyperbola: ``f(x) = w0 + w1 / (x + w2)``."""

    def __init__(self, initial_weights=None):
        """Create the function from ``[w0, w1, w2]``; ``[0.25, 0.25, 1]`` when omitted."""
        super().__init__()

        weights = [0.25, 0.25, 1] if initial_weights is None else initial_weights
        self.w = nn.Parameter(torch.Tensor(weights))

    def forward(self, x):
        """Evaluate the hyperbola element-wise."""
        offset, scale, shift = self.w[0], self.w[1], self.w[2]
        return offset + torch.ones_like(x) * scale / (x + shift)

    def reset_weights(self, new_w=None):
        """Replace the parameter with ``new_w`` (``[0, 0, 10]`` when omitted)."""
        # The fallback shift of 10 was obtained by experience.
        replacement = [0, 0, 10] if new_w is None else new_w
        self.w = nn.Parameter(torch.Tensor(replacement))

    def init_weights_from_train_data(self, x_train, y_train):
        """Derive the initial weights from the samples at the smallest and largest ``x_train``."""
        y_at_min_x = y_train[torch.argmin(x_train)]
        y_at_max_x = y_train[torch.argmax(x_train)]

        # The initial shift of 1 was obtained by experience.
        guess = [y_at_max_x, y_at_min_x - y_at_max_x, 1.0]
        self.w = nn.Parameter(torch.Tensor(guess))

class InverseFunction_4w(torch.nn.Module):
    """Trainable curve ``f(x) = w0 + w1 / (pow(x, w3) + w2)``."""

    def __init__(self, initial_weights=None):
        """Create the function from ``[w0, w1, w2, w3]``; ``[0.25, 0.25, 1, 1]`` when omitted."""
        super().__init__()

        weights = [0.25, 0.25, 1, 1] if initial_weights is None else initial_weights
        self.w = nn.Parameter(torch.Tensor(weights))

        # todo: !!!! add constrains. all the weights must be > 0

    def forward(self, x):
        """Evaluate the curve element-wise."""
        denominator = torch.pow(x, self.w[3]) + self.w[2]
        return self.w[0] + torch.ones_like(x) * self.w[1] / denominator

    def reset_weights(self, new_w=None):
        """Replace the parameter with ``new_w`` (``[0, 0, 10, 1]`` when omitted)."""
        # The fallback value 10 was obtained by experience.
        replacement = [0, 0, 10, 1] if new_w is None else new_w
        self.w = nn.Parameter(torch.Tensor(replacement))

    def init_weights_from_train_data(self, x_train, y_train):
        """Derive the initial weights from the samples at the smallest and largest ``x_train``."""
        y_at_min_x = y_train[torch.argmin(x_train)]
        y_at_max_x = y_train[torch.argmax(x_train)]

        # The initial value 1 (third weight) was obtained by experience.
        guess = [y_at_max_x, y_at_min_x - y_at_max_x, 1.0, 1.0]
        self.w = nn.Parameter(torch.Tensor(guess))

class LinearFractionalFunction_new(torch.nn.Module):
    """Trainable decay ``f(x) = w0 + (w1 - w0) * x / (x + 1 / w2)``.

    ``w0`` is the value at ``x = 0``, ``w1`` the value approached for large
    ``x`` and ``1 / w2`` the ``x`` at which half of the drop has happened.
    Every weight is stored as its own one-element parameter
    (``self.w0``, ``self.w1``, ``self.w2``).
    """

    def __init__(self, initial_weights=None):
        """Create the function from ``[w0, w1, w2]``; ``[0.5, 0.25, 0.1]`` when omitted."""
        super().__init__()

        # reset_weights applies exactly the defaults/indexing needed here.
        self.reset_weights(initial_weights)

    def forward(self, x):
        """Evaluate the decay element-wise."""
        return (x / (x + 1/self.w2[0])) * (self.w1[0] - self.w0[0]) + self.w0[0]

    def reset_weights(self, new_w=None):
        """Replace the three parameters with ``new_w`` (``[0.5, 0.25, 0.1]`` when omitted)."""
        if new_w is None:
            new_w = [0.5, 0.25, 0.1]
        self.w0 = nn.Parameter(torch.Tensor([new_w[0]]))
        self.w1 = nn.Parameter(torch.Tensor([new_w[1]]))
        self.w2 = nn.Parameter(torch.Tensor([new_w[2]]))

    def init_weights_from_train_data(self, x_train, y_train):
        """Estimate starting weights from the training data.

        * ``w0`` - the largest ``y_train`` (0.1 if that is negative);
        * ``w1`` - the smallest ``y_train`` clamped into ``[0, w0]``. The
          previous version stored the drop ``max - min`` here, which is the
          meaning ``w1`` has in ``LinearFractionalFunction`` but not in this
          parametrisation, where ``forward`` uses ``w1`` as the asymptote;
        * ``w2`` - chosen so that the curve passes through the point
          ``(mean(x_train), mean(y_train))``:
          ``1 / w2 = x_mean * ((w0 - w1) / (w0 - y_mean) - 1)``.
          Falls back to 0.05 when that equation has no positive solution
          (for example constant data), instead of producing NaN/inf.
        """
        start = float(torch.max(y_train))
        floor = float(torch.min(y_train))

        if start < 0:
            start = 0.1
        floor = min(max(floor, 0.0), start)

        x_mean = float(torch.mean(x_train))
        y_mean = float(torch.mean(y_train))

        w2 = 0.05
        drop_at_mean = start - y_mean
        if x_mean != 0 and drop_at_mean != 0:
            half_drop_x = x_mean * ((start - floor) / drop_at_mean - 1)
            if half_drop_x > 0:
                w2 = 1 / half_drop_x

        self.reset_weights([start, floor, w2])

class LinearFractionalFunction(torch.nn.Module):
    """Trainable decay ``f(x) = w0 - w1 * x / (x + w2)``."""

    def __init__(self, initial_weights=None):
        """Create the function from ``[w0, w1, w2]``; ``[0.5, 0.25, 20]`` when omitted."""
        super().__init__()

        weights = [0.5, 0.25, 20] if initial_weights is None else initial_weights
        self.w = nn.Parameter(torch.Tensor(weights))

    def forward(self, x):
        """Evaluate the decay element-wise."""
        saturation = x / (x + self.w[2])
        return saturation * (-self.w[1]) + self.w[0]

    def reset_weights(self, new_w=None):
        """Replace the parameter with ``new_w`` (``[0, 0, 10]`` when omitted)."""
        # The fallback value 10 was obtained by experience.
        replacement = [0, 0, 10] if new_w is None else new_w
        self.w = nn.Parameter(torch.Tensor(replacement))

    def init_weights_from_train_data(self, x_train, y_train):
        """Estimate starting weights from the extremes and the means of the data."""
        y_top = torch.max(y_train)
        y_bottom = torch.min(y_train)

        # w0: largest observed value, replaced by 0.1 when negative.
        start = 0.1 if y_top < 0 else y_top

        # w1: observed drop, kept inside [0, w0].
        drop = y_top - y_bottom
        if drop < 0:
            drop = 0
        if drop > start:
            drop = start

        x_mean = torch.mean(x_train)
        y_mean = torch.mean(y_train)

        # w2: makes the curve pass through (x_mean, y_mean); 20 when the
        # mean equals w0 and the expression would divide by zero.
        if y_mean - start != 0:
            half_drop_x = x_mean * (drop / (start - y_mean) - 1)
        else:
            half_drop_x = 20

        self.reset_weights([start, drop, half_drop_x])

class SigmaFunction(torch.nn.Module):
    """Trainable sigmoid decay ``f(x) = w0 + w1 * sigmoid(-2 * w2 * x)``.

    ``f(0) = w0 + w1 / 2`` and ``f`` approaches ``w0`` as ``x * w2`` grows.
    """

    def __init__(self, initial_weights=None):
        """Create the function from ``[w0, w1, w2]``; ``[0.5, 0.25, 20]`` when omitted."""
        super().__init__()

        if initial_weights is not None:
            self.w = nn.Parameter(torch.Tensor(initial_weights))
        else:
            self.w = nn.Parameter(torch.Tensor([0.5, 0.25, 20]))

    def forward(self, x):
        """Evaluate the sigmoid decay element-wise."""
        return torch.nn.Sigmoid()(-x*self.w[2]*2) * self.w[1] + self.w[0]

    def reset_weights(self, new_w=None):
        """Replace the parameter with ``new_w`` (``[0, 0, 10]`` when omitted)."""
        if new_w is None:
            new_w = [0, 0, 10]  # The value 10 was obtained by experience
        self.w = nn.Parameter(torch.Tensor(new_w))

    def init_weights_from_train_data(self, x_train, y_train):
        """Estimate starting weights that match ``forward``'s parametrisation.

        * ``w0`` - the smallest ``y_train`` (the level the curve decays to);
        * ``w1`` - twice the observed drop, because ``sigmoid(0) = 0.5`` and
          therefore ``f(0) = w0 + w1 / 2`` equals the largest ``y_train``;
        * ``w2`` - chosen so that the curve passes through
          ``(mean(x_train), mean(y_train))``; 0.05 when that is not solvable
          (``x_mean == 0`` or constant data).

        The previous version reused the guess of the linear-fractional
        models (``w0 = max``, ``w1 = max - min``), which puts the whole curve
        above the data, and it divided by ``w0 - y_mean`` without a guard.
        """
        import math  # local import: this module's import block has no math

        y_top = float(torch.max(y_train))
        y_bottom = float(torch.min(y_train))
        x_mean = float(torch.mean(x_train))
        y_mean = float(torch.mean(y_train))

        w0 = y_bottom
        w1 = 2 * (y_top - y_bottom)

        w2 = 0.05
        if x_mean != 0 and w1 != 0:
            # sigmoid(-2 * w2 * x_mean) must equal this ratio.
            ratio = (y_mean - w0) / w1
            if 0 < ratio < 1:
                w2 = math.log((1 - ratio) / ratio) / (2 * x_mean)

        self.reset_weights([w0, w1, w2])

# Approximation of weekly fluctuations. Every day of week has its own constant weight
class WeekFunction(torch.nn.Module):
    """Weekly seasonality: one trainable weight per day of the week."""

    def __init__(self, first_day_of_week:int, regularizer_base:int, initial_weights=None):
        """Store the settings and create the 7 weights (all ones when omitted)."""
        super().__init__()

        self.first_day_of_week = first_day_of_week
        self.regularizer_base = regularizer_base

        if initial_weights is None:
            self.w = nn.Parameter(torch.ones(7))
        else:
            self.w = nn.Parameter(torch.Tensor(initial_weights))

    def forward(self, day_numbers):
        """Return, for every day number, the weight of its day of the week."""
        shifted = day_numbers.type(torch.long) + 7 - self.first_day_of_week
        return self.w[shifted % 7]

    def regularize(self):
        """Squared distance between the mean weight and ``regularizer_base``."""
        deviation = torch.mean(self.w) - self.regularizer_base
        return torch.square(deviation)


def multiply_connector(input1, input2):
    """Combine two inputs by element-wise multiplication."""
    product = torch.mul(input1, input2)
    return product

def additive_connector(input1, input2):
    """Combine two inputs by element-wise addition."""
    total = torch.add(input1, input2)
    return total


class ApproximatorsFactory():
    """Registry of the available trend functions, chain functions and connectors.

    Every table row is ``[label, callable, third value]``. For the function
    tables the third value is a list of default weights; for the connectors
    it is the number returned next to the connector by ``create_connector``.
    A row can be selected by its index, by the index written as a string,
    or by its label.
    """

    # NOTE(review): the last label mentions w3, while SigmaFunction.forward
    # reads w[0], w[1] and w[2]. The label is left untouched because it is
    # also the key used to select the row.
    main_functions = [
        ['w0', ConstantFunction, [0.25]],
        ['w0+w1*x', LinearFunction, [0.25, 0]],
        ['w0+w1/(w2+x)', InverseFunction, [0.25, 0.25, 1.0]],
        ['w0-w1*x/(w2+x)', LinearFractionalFunction, [0.5, 0.25, 10.0]],
        ['w0-(w0-w1)*x/(1/w2+x)', LinearFractionalFunction_new, [0.5, 0.25, 0.1]],
        ['w0+w1/(w2+pow(x,w3))', InverseFunction_4w, [0.25, 0.25, 1.0, 1]],
        ['w0+w1*Sigmoid(x*w3)', SigmaFunction, [0.5, 0.25, 0.05]]]

    chain_functions = [
        ['w0', ConstantFunction, [0.25]],
        ['w0+w1*x', LinearFunction, [0.25, 0]]]

    connectors = [
        ['mul', multiply_connector,   1],
        ['add', additive_connector,   0]]

    @staticmethod
    def _find_row(table, key):
        """Return the row of ``table`` selected by ``key``, or ``None``."""
        for index, row in enumerate(table):
            if key == index or key == str(index) or key == row[0]:
                return row
        return None

    @staticmethod
    def create_main_function(function_type, initial_weights):
        """Instantiate a main trend function with ``initial_weights``.

        ``None`` or an unknown ``function_type`` selects the first entry of
        ``main_functions``, the same fallback ``create_chain_function`` and
        ``create_connector`` use. (An unknown type used to return ``None``.)
        """
        row = ApproximatorsFactory._find_row(ApproximatorsFactory.main_functions, function_type)
        if row is None:
            row = ApproximatorsFactory.main_functions[0]
        return row[1](initial_weights)

    @staticmethod
    def create_chain_function(function_type, initial_weights):
        """Instantiate a chain (per-patch) function with ``initial_weights``.

        ``None`` or an unknown ``function_type`` selects the first entry of
        ``chain_functions``.
        """
        row = ApproximatorsFactory._find_row(ApproximatorsFactory.chain_functions, function_type)
        if row is None:
            row = ApproximatorsFactory.chain_functions[0]
        return row[1](initial_weights)

    @staticmethod
    def create_connector(connector_type):
        """Return ``(connector callable, third table value)``.

        ``None`` or an unknown ``connector_type`` selects the first entry of
        ``connectors``.
        """
        row = ApproximatorsFactory._find_row(ApproximatorsFactory.connectors, connector_type)
        if row is None:
            row = ApproximatorsFactory.connectors[0]
        return row[1], row[2]

    @staticmethod
    def get_main_function_weights_number(function_type):
        """Return the default weight list of a main function, or ``None``.

        Despite the name, the list itself is returned; callers take
        ``len()`` of it to get the number of weights.
        """
        row = ApproximatorsFactory._find_row(ApproximatorsFactory.main_functions, function_type)
        return None if row is None else row[2]

class ComplexApproximator(nn.Module):
    """Full model: trend (main function plus per-patch chain functions)
    combined with the weekly function through the selected connector."""

    def __init_main_function(self,
                            main_function_type='0',
                            main_function_weights=None):
        """Create the main trend function through the factory."""
        self.main_function = ApproximatorsFactory.create_main_function(
            main_function_type, main_function_weights)

    def __init_chains_functions(self,
                                patches_dates=None,
                                chain_functions_type='0',
                                chain_functions_weights=None):
        """Store the patch days and create one chain function per patch."""
        # Only positive days are kept; duplicates are dropped and the rest is
        # stored in ascending order.
        if patches_dates is None:
            self.patches_dates = []
        else:
            self.patches_dates = sorted(
                {patch_date for patch_date in patches_dates if patch_date > 0})

        # Patches without a matching entry in chain_functions_weights get
        # None, i.e. the chain function's own default weights.
        known = 0 if chain_functions_weights is None else len(chain_functions_weights)
        per_patch = [
            ApproximatorsFactory.create_chain_function(
                chain_functions_type,
                chain_functions_weights[position] if position < known else None)
            for position in range(len(self.patches_dates))
        ]
        self.chain_functions = nn.ModuleList(per_patch)

    def __init_week_function(self,
                             first_day_of_week,
                             regularizer_base,
                             week_function_initial_weights=None):
        """Create the weekly function."""
        self.week_function = WeekFunction(
            first_day_of_week, regularizer_base, week_function_initial_weights)

    def __init_connector(self, connector_type):
        """Store the connector and return the value the factory pairs with it."""
        connector, regularizer_base = ApproximatorsFactory.create_connector(connector_type)
        self.connector = connector
        return regularizer_base

    def __init__(self,
                 first_day_of_week: int,
                 patches_dates=None,
                 main_function_type=0,
                 chain_functions_type=0,
                 connector_type='mul',
                 main_function_initial_weights=None,
                 chain_functions_initial_weights=None,
                 week_function_initial_weights=None):
        """Build all parts of the model from the given types and weights."""
        super().__init__()

        self.__init_main_function(main_function_type, main_function_initial_weights)
        self.__init_chains_functions(
            patches_dates, chain_functions_type, chain_functions_initial_weights)

        # The value paired with the connector becomes the target of the
        # weekly function's regularizer.
        regularizer_base = self.__init_connector(connector_type)
        self.__init_week_function(
            first_day_of_week, regularizer_base, week_function_initial_weights)

    def init_weights_from_train_data(self, x_initial, y_initial):
        """Let the main function pick its starting weights from the data."""
        self.main_function.init_weights_from_train_data(x_initial, y_initial)

    def forward(self, x):
        """Connect the trend and the weekly values for the days ``x``."""
        trend = self.forward_trend_function(x)
        weekly = self.forward_week_function(x)
        return self.connector(trend, weekly)

    def forward_trend_function(self, x):
        """Main function plus, for every day, the chain function of its patch."""
        trend = self.main_function.forward(x)

        if not self.patches_dates:
            return trend

        # The last patch is open-ended.
        tail = (x >= self.patches_dates[-1])
        trend[tail] += self.chain_functions[-1](x[tail])

        # Every earlier patch lasts until the next patch starts.
        for position in reversed(range(len(self.chain_functions) - 1)):
            lower = self.patches_dates[position]
            upper = self.patches_dates[position + 1]
            window = (x >= lower) & (x < upper)
            trend[window] += self.chain_functions[position](x[window])

        return trend

    def forward_week_function(self, x):
        """Weekly values for the days ``x``."""
        return self.week_function(x)

    def regularize(self):
        """Regularization term of the weekly function."""
        return self.week_function.regularize()

    def print_summary(self, print_function_types=False):
        """Print the trainable parameters, optionally preceded by the sub-modules."""
        if print_function_types:
            print('Main function = ', self.main_function)
            print('Chain functions = ', self.chain_functions)
            print('Week function = ', self.week_function)

        trainable = ((name, param) for name, param in self.named_parameters()
                     if param.requires_grad)
        for name, param in trainable:
            print(name, param.data)


def separate_dataset_by_indices(indices_list_2d, data):
    """Return one selection ``data[group]`` per index group of ``indices_list_2d``."""
    return [data[group] for group in indices_list_2d]

def indices_of_x_by_patches(x_data, patches_dates):
    """Group the positions of ``x_data`` by the patch they belong to.

    Returns two lists with one entry per patch: the positions inside
    ``x_data`` and the matching selections of ``x_data``. Values below the
    first patch date end up in no group.
    """
    positions = range(len(x_data))

    # Every patch except the last covers [its date, date of the next patch).
    grouped = [
        [position for position in positions
         if patches_dates[patch_index - 1] <= x_data[position].tolist() < patches_dates[patch_index]]
        for patch_index in range(1, len(patches_dates))
    ]

    # The last patch is open-ended.
    grouped.append(
        [position for position in positions
         if x_data[position].tolist() >= patches_dates[-1]])

    return grouped, separate_dataset_by_indices(grouped, x_data)

def custom_mse_loss(y_pred,
                    y_true,
                    sample_size_true,
                    regularizer,
                    regualizer_lambda):
    """Squared error weighted by sample size and predicted spread, plus a penalty.

    Every squared residual is multiplied by the sample size of its point,
    divided by ``y_pred * (1 - y_pred)`` (written as ``-y_pred * (y_pred - 1)``)
    and by the total sample size; the mean of these terms is returned with
    ``regularizer() * regualizer_lambda`` added.

    The sample sizes are passed explicitly as ``sample_size_true``. The spread
    is taken from the prediction; using ``y_true`` instead is the alternative
    the author had tried before.
    """
    squared_error = torch.square(y_pred - y_true)
    spread = -y_pred * (y_pred - 1)
    total_samples = torch.sum(sample_size_true)

    per_point = squared_error * sample_size_true / spread / total_samples
    penalty = regularizer() * regualizer_lambda

    return torch.mean(per_point) + penalty

def save_retention_data_to_local_csv(days,
                        installs,
                        retention,
                        retention_mean = None,
                        date_of_release = None):
    """Write the dataset to a CSV file and hand it to the browser for download.

    Without ``date_of_release`` the ``date`` column holds the values of
    ``days`` as they are; otherwise it holds ``date_of_release`` shifted by
    that many days. ``retention_mean`` defaults to zeros shaped like
    ``installs``. Returns the name of the written file.
    """
    # datetime is imported at module level only in a later notebook cell;
    # importing it here keeps this function usable on its own.
    import datetime

    print('date_of_release = ', date_of_release)

    if retention_mean is None:
        retention_mean = torch.zeros_like(installs)

    if date_of_release is None:
        table = torch.stack((days, installs, retention, retention_mean), 1).detach().numpy()
        df = pd.DataFrame(table)
        df.columns = ['date', 'installs', 'retention', 'retention_mean']
    else:
        date_list = [date_of_release + datetime.timedelta(days=day) for day in days.tolist()]
        date_series = pd.Series(date_list, name='date')

        table = torch.stack((installs, retention, retention_mean), 1).detach().numpy()
        df = pd.DataFrame(table)
        df.columns = ['installs', 'retention', 'retention_mean']

        df = pd.concat([date_series, df], axis=1)

    # NOTE(review): the file name looks like a typo of 'export.csv'; it is
    # kept because it is what users currently receive.
    filename = 'exoprt.csv'
    df.to_csv(filename, index=False)
    files.download(filename)

    return filename

def load_retention_from_local_csv(file):
    """Load a retention dataset from a CSV file.

    The file needs the columns ``date``, ``installs``, ``retention`` and
    ``retention_mean``.

    Returns ``(date, installs, retention, retention_mean, first_date, df)``:
    ``date`` is the number of days since the earliest date in the file (float
    tensor), ``installs`` is clamped to at least 1, both retention tensors are
    clamped to [0, 1], ``first_date`` is the earliest date and ``df`` the
    parsed frame. When the file cannot be parsed or a column is missing, a
    message is printed and six ``None`` values are returned, so callers can
    always unpack six values (the error path used to return only five).
    """
    try:
        df = pd.read_csv(file)

        df['date'] = pd.to_datetime(df['date'])
        dates_list = df['date'].tolist()
        first_date = min(dates_list)
        days = [(date_ - first_date).days for date_ in dates_list]

        date = torch.clamp(torch.tensor(days), 0, 100000000000000).float()
        installs = torch.clamp(torch.tensor(df['installs'].values), 1, 10000000000)
        retention = torch.clamp(torch.tensor(df['retention'].values), 0,1)
        retention_mean = torch.clamp(torch.tensor(df['retention_mean'].values), 0,1)

        return date, installs, retention, retention_mean, first_date, df

    except (ValueError, KeyError):
        # ValueError: unparsable file or dates; KeyError: a required column
        # is missing.
        print('Error! Wrong file format')
        return None, None, None, None, None, None


#Tools functions for evaluating confidence interval of approximation
def separate_intervals(x_tensor, y_tensor, x_intervals_by_start):
    """Split ``x_tensor`` and ``y_tensor`` into consecutive ``x`` intervals.

    The interval borders are 0, every value of ``x_intervals_by_start`` and
    infinity; each interval includes its start and excludes its end. Returns
    the list of ``x`` pieces and the list of matching ``y`` pieces.
    """
    borders = [0.0] + x_intervals_by_start + [float('inf')]

    x_sets = []
    y_sets = []
    for start, end in zip(borders[:-1], borders[1:]):
        inside = (x_tensor >= start) & (x_tensor < end)
        x_sets.append(x_tensor[inside])
        y_sets.append(y_tensor[inside])

    return x_sets, y_sets

def extract_avg_values_of_patches(x_tensor, x_intervals_by_start, y_tensor_by_intervals):
    """Repeat each interval's value once per ``x_tensor`` element inside it.

    The intervals are built from 0, ``x_intervals_by_start`` and infinity.
    The repeated values are concatenated in interval order.
    """
    borders = [0.0] + x_intervals_by_start + [float('inf')]

    pieces = []
    for position, (start, end) in enumerate(zip(borders[:-1], borders[1:])):
        inside = (x_tensor >= start) & (x_tensor < end)
        count = int(torch.sum(inside).item())
        pieces.append(torch.ones(count) * y_tensor_by_intervals[position])

    return torch.cat(pieces)

def evaluate_ideal_confidence_interval(_dates,
                                       _retention,
                                       _installs,
                                       _patches_dates,
                                       sigmas_number,
                                       _dates_for_plotting,
                                       _retention_for_plotting):
    """Return the lower and upper bound of the per-patch confidence band.

    For every interval produced by ``separate_intervals`` the install-weighted
    mean retention ``p`` and the total installs ``N`` are computed and
    ``sigma = sqrt(p * (1 - p) / N)`` is derived from them. The band is
    ``_retention_for_plotting -/+ sigmas_number * sigma``, with the sigma of
    each interval repeated for the plotting dates inside it.

    The weighted mean used to take the dates as weights; it now uses the
    installs, which is also what ``N`` is built from.
    """
    _dates = _dates.to(device='cpu')
    _retention = _retention.to(device='cpu')

    _, retention_per_patches = separate_intervals(_dates, _retention, _patches_dates)
    _, installs_per_patch = separate_intervals(_dates, _installs, _patches_dates)

    overal_inst_per_patches = [torch.sum(inst) for inst in installs_per_patch]
    avg_retention_per_patches = torch.stack([
        torch.sum(ret_per_patch * ins) / torch.sum(ins)
        for ret_per_patch, ins in zip(retention_per_patches, installs_per_patch)])

    # Standard deviation of a proportion p estimated from N samples.
    sigma_per_patch = torch.stack([
        torch.sqrt(-p * (p - 1) / N)
        for p, N in zip(avg_retention_per_patches, overal_inst_per_patches)]).to(device='cpu')

    estimated_sigma = extract_avg_values_of_patches(_dates_for_plotting, _patches_dates, sigma_per_patch)
    margin = estimated_sigma * sigmas_number
    return _retention_for_plotting - margin, _retention_for_plotting + margin

#Generate an example of dataset
def generate_retention_dataset_2(total_days,
                                 first_day_of_week,
                                 patches_dates,
                                 main_function_type,
                                 chains_functions_type,
                                 main_function_weights,
                                 chains_functions_weights,
                                 week_function_weights,
                                 daily_installs_mean,
                                 daily_installs_sigma):
    """Generate a synthetic retention dataset for days ``0 .. total_days - 1``.

    Returns, in this order: the days, the sample size per day, the noisy
    retention, the model output with default weekly weights, the model output
    with the given weekly weights (clamped to [0, 1]), ``patches_dates``, the
    randomly flagged days, ``week_function_weights``, the per-chain
    contributions and the model built with the weekly weights.
    """
    x = torch.FloatTensor(list(range(total_days)))

    #todo: Add random to all the parameters

    shared_settings = dict(patches_dates=patches_dates,
                           main_function_type=main_function_type,
                           chain_functions_type=chains_functions_type,
                           main_function_initial_weights=main_function_weights,
                           chain_functions_initial_weights=chains_functions_weights)

    # Same model twice: first with the default weekly weights, then with the
    # requested ones.
    trend_model = ComplexApproximator(first_day_of_week, **shared_settings)
    y_modeled_trend = trend_model.forward(x)

    model = ComplexApproximator(first_day_of_week,
                                week_function_initial_weights=week_function_weights,
                                **shared_settings)
    y_modeled_with_oscillations = torch.clamp(model.forward(x), 0, 1)

    # Contribution of every chain function from its patch date onwards.
    y_modeled_chains = [
        (x >= model.patches_dates[position]) * chain.forward(x)
        for position, chain in enumerate(model.chain_functions)
    ]

    # Add noise to data using binomial model of returned users
    sample_size = torch.rand_like(y_modeled_with_oscillations) * daily_installs_sigma + daily_installs_mean
    sigma = torch.sqrt(-y_modeled_with_oscillations * (y_modeled_with_oscillations - 1) / sample_size)
    y_modeled_final = torch.normal(mean=y_modeled_with_oscillations, std=sigma)

    # Randomly flag days; the flagged days are only reported, the data is not changed.
    probability_of_anomaly = 0.3
    flags = np.random.choice(a=[True, False], size=(total_days),
                             p=[probability_of_anomaly, 1 - probability_of_anomaly])
    bad_days = [day for day, flagged in enumerate(flags) if flagged != 0]
    #todo: modify retention at abnormal days

    return (x, sample_size, y_modeled_final, y_modeled_trend, y_modeled_with_oscillations,
            patches_dates, bad_days, week_function_weights, y_modeled_chains, model)

def plot_generated_retention_dataset(x_modeled,
                                     y_modeled_final,
                                     y_modeled_trend = None,
                                     y_modeled_with_oscillations = None,
                                     y_modeled_chains=None):
    """Plot a generated dataset; every optional series is drawn only when given."""
    plt.figure(figsize=(20, 5))

    chains = () if y_modeled_chains is None else y_modeled_chains
    for number, chain in enumerate(chains):
        plt.plot(x_modeled, chain.detach().numpy(),
                 color='Grey', label=f"chain {number}", marker='.', linestyle='')

    if y_modeled_with_oscillations is not None:
        oscillated = y_modeled_with_oscillations.detach().numpy()
        plt.plot(x_modeled, oscillated, color='Green', label="y_modeled_with_oscillations")

    plt.plot(x_modeled, y_modeled_final.detach().numpy(), color='Blue', label="Train data")

    if y_modeled_trend is not None:
        trend = y_modeled_trend.detach().numpy()
        plt.plot(x_modeled, trend, color='Red', label="y_modeled_decreasing", linewidth=3)

    plt.legend(bbox_to_anchor=(1.05, 0.95), loc=2, borderaxespad=0., fontsize=12)
    plt.title("Train data", fontsize=12)

    plt.show()


# Smoke run: generate one synthetic dataset when the cell is executed.
(dates, installs, retention, retention_mean, retention_oscillated,
 patches_dates, bad_days, week_weights, y_modeled_chains, generator_model) = generate_retention_dataset_2(
    total_days=160,
    first_day_of_week=2,
    patches_dates=[30,60,90,120,150],
    main_function_type='4',
    chains_functions_type='0',
    main_function_weights=[0.5,0.4,0.05],
    chains_functions_weights=[0.01, 0.02, 0.02, 0.03, 0.04],
    week_function_weights=[1,1,1,1,1.05,1.05,0.9],
    daily_installs_mean=1000,
    daily_installs_sigma=200)
#generator_model.print_summary(True)
#plot_generated_retention_dataset(dates, retention, retention_mean, retention_oscillated)
#plot_generated_retention_dataset(dates, retention, None, None)

In [None]:
#@title User interface

#%pip install -q ipywidgets

# Notebook-wide state. on_generate_button_click declares all of these names
# global. Running this cell resets the values assigned by the dataset
# generation at the end of the first cell.
date_of_release = None
first_day_of_week = 0
dates = None
installs = None
retention = None
retention_mean = None
retention_oscillated = None
patches_dates = None
bad_days = None
week_weights = None

from ipywidgets import interact, interactive, fixed, interact_manual, Layout
import ipywidgets as widgets
import datetime

# Output area widget.
output = widgets.Output()

# Define buttons
data_upload_button = widgets.FileUpload(
    multiple=False,
    button_style='info')

data_save_button = widgets.Button(
    description='Save current train data',
    button_style='', # 'success', 'info', 'warning', 'danger' or ''
    tooltip='Save dataset',
    icon='check' # (FontAwesome names without the `fa-` prefix)
)

generator_start_button = widgets.Button(
    description='Generate dataset',
    button_style='info', # 'success', 'info', 'warning', 'danger' or ''
    tooltip='Generate dataset',
    icon='check' # (FontAwesome names without the `fa-` prefix)
)

# Define the generator's widgets: settings of the synthetic dataset.
# The text areas hold comma-separated numbers (parsed by the check_* functions).
generator_start_date_picker = widgets.DatePicker(value=datetime.date.today())
generator_x_points_text_input = widgets.IntText(value='160')
generator_patches_dates_text_input = widgets.Textarea(value='0, 60, 90, 150')
# Function drop-downs list the labels registered in ApproximatorsFactory.
generator_main_function_dropdown = widgets.Dropdown(
    options=[name for name, _, _ in ApproximatorsFactory.main_functions],
    value=ApproximatorsFactory.main_functions[3][0])
generator_main_function_weights_text_input = widgets.Textarea(value='0.4, 0.08, 20')
generator_chains_function_dropdown = widgets.Dropdown(
    options=[name for name, _, _ in ApproximatorsFactory.chain_functions],
    value=ApproximatorsFactory.chain_functions[0][0])
generator_chains_function_weights_text_input = widgets.Textarea(value='0.01, 0.02, 0.03')
generator_week_weights_text_input = widgets.Textarea(value='1,1,1,1,1.05,1.05,0.9')
generator_daily_installs_mean_text_input = widgets.IntText(value='1000')
generator_daily_installs_sigma_text_input = widgets.IntText(value='200')

# Define the approximator's widgets: settings of the model to be trained.
approximator_main_function_dropdown = widgets.Dropdown(
    options=[name for name, _, _ in ApproximatorsFactory.main_functions],
    value=ApproximatorsFactory.main_functions[4][0])

approximator_chain_function_dropdown = widgets.Dropdown(
    options=[name for name, _, _ in ApproximatorsFactory.chain_functions],
    value=ApproximatorsFactory.chain_functions[0][0])

# No explicit value is passed for the connector drop-down.
approximator_connector_dropdown = widgets.Dropdown(
    options=[name for name, _, _ in ApproximatorsFactory.connectors])

approximator_patches_dates_text_input = widgets.Textarea(
    value='', layout=Layout(width='50%'))

approximator_bad_dates_text_input = widgets.Textarea(
    value='', layout=Layout(width='50%'))

approximator_week_weights_text_input = widgets.Textarea(
    value='1, 1, 1, 1, 1, 1, 1', layout=Layout(width='50%'))

approximator_exclude_bad_dates_checkbox = widgets.Checkbox(
    value=True,
    description = 'Exlude bad dates from training set',
    layout=Layout(width='50%'))

approximator_exclude_patch_dates_checkbox = widgets.Checkbox(
    value=True,
    description = 'Exlude patch dates from training set',
    layout=Layout(width='50%'))


#Event handlers

def check_decreaser_weights(x):
    """Validate the comma-separated chain-function weights typed by the user.

    One weight is expected for every generator patch date except the first
    one in the list returned by ``check_dates``. Returns
    ``(is_valid, parsed floats, message)``.
    """
    patch_dates_defined, generator_patches_dates, _ = check_dates(generator_patches_dates_text_input.value)

    if not patch_dates_defined:
        # Double quotes keep the apostrophe; 'don''t' is two adjacent
        # literals and evaluates to "dont".
        return False, [], "Error! The dates of patches don't defined"

    # Never negative: with no patch dates at all, zero weights are expected.
    expected_count = max(len(generator_patches_dates) - 1, 0)

    if x.strip() == '':
        # Empty input means "no weights"; it is valid when none are expected.
        float_array = []
    else:
        try:
            float_array = [float(item_value) for item_value in x.split(",")]
        except ValueError:
            return False, [], "Error! Can't convert data to float array"

    if len(float_array) != expected_count:
        return False, float_array, 'Error! You should type exactly {0} float numbers'.format(expected_count)
    return True, float_array, 'OK'

def check_main_decreaser_weights(x):
    """Validate the comma-separated main-function weights typed by the user.

    The expected count is the length of the default weight list registered
    for the function selected in ``generator_main_function_dropdown``.
    Returns ``(is_valid, parsed floats, message)``.
    """
    len_main_decreaser_weights = len(ApproximatorsFactory.get_main_function_weights_number(generator_main_function_dropdown.value))

    string_array = x.split(",")
    try:
        float_array = [float(item_value) for item_value in string_array]
    except ValueError:
        # Double quotes keep the apostrophe; 'Can''t' is two adjacent
        # literals and evaluates to "Cant".
        return False, [], "Error! Can't convert data to float array"

    if len(float_array) != len_main_decreaser_weights:
        return False, float_array, 'Error! You should type exactly {0} float numbers'.format(len_main_decreaser_weights)
    return True, float_array, 'OK'

def check_oscillator_weights(x):
    """Validate the comma-separated weights of the days of the week.

    Args:
        x: Raw text of the weights field; exactly 7 floats are expected.

    Returns:
        A ``(ok, weights, message)`` tuple. ``ok`` is True when the input is
        valid, ``weights`` is the parsed list of floats (empty when nothing
        could be parsed) and ``message`` is ``'OK'`` or an error description.
    """
    try:
        float_array = [float(item_value) for item_value in x.split(",")]
    except ValueError:
        return False, [], "Error! Can't convert data to float array"

    if len(float_array) != 7:
        return False, float_array, 'Error! You should type exactly 7 float numbers'
    return True, float_array, 'OK'

def check_dates(x):
    """Parse a comma-separated list of non-negative day numbers.

    The parsed values are sorted in ascending order and day 0 is prepended
    when it is not already the first value. Duplicates are kept.

    Args:
        x: Raw text of the dates field. A blank string is valid and yields
            an empty list.

    Returns:
        A ``(ok, days, message)`` tuple. ``ok`` is True when the input is
        valid, ``days`` is the sorted list of ints (empty on error or blank
        input) and ``message`` is ``'OK'`` or an error description.
    """
    if x.strip() == '':
        return True, [], 'OK'

    try:
        int_array = sorted(int(item_value) for item_value in x.split(","))
    except ValueError:
        return False, [], "Error! Can't convert data to int array"

    # The list is sorted, so checking the smallest value covers all of them.
    if int_array[0] < 0:
        return False, [], 'Error! All the values must be >=0'

    if int_array[0] != 0:
        int_array.insert(0, 0)
    return True, int_array, 'OK'

def check_decreaser_weights_and_print_result(x):
    """Validate the per-patch weights text and print the outcome.

    Prints the validation message and then the number of parsed values.

    Args:
        x: Raw text of the weights field.

    Returns:
        The validity flag reported by ``check_decreaser_weights``.
    """
    is_valid, parsed_values, message = check_decreaser_weights(x)
    print(message, '{0} values found'.format(len(parsed_values)), sep='\n')
    return is_valid

def check_main_decreaser_weights_and_print_result(x):
    """Validate the main trend weights text and print the outcome.

    Prints the validation message and then the number of parsed values.

    Args:
        x: Raw text of the weights field.

    Returns:
        The validity flag reported by ``check_main_decreaser_weights``.
    """
    is_valid, parsed_values, message = check_main_decreaser_weights(x)
    print(message, '{0} values found'.format(len(parsed_values)), sep='\n')
    return is_valid

def check_oscillator_weights_and_print_result(x):
    """Validate the day-of-week weights text and print the outcome.

    Prints the validation message and then the number of parsed values.

    Args:
        x: Raw text of the weights field.

    Returns:
        The validity flag reported by ``check_oscillator_weights``.
    """
    is_valid, parsed_values, message = check_oscillator_weights(x)
    print(message, '{0} values found'.format(len(parsed_values)), sep='\n')
    return is_valid

def check_dates_and_print_result(x):
    """Validate a list-of-days text and print the outcome.

    Prints the validation message and then the number of parsed values.

    Args:
        x: Raw text of the dates field.

    Returns:
        The validity flag reported by ``check_dates``.
    """
    is_valid, parsed_values, message = check_dates(x)
    print(message, '{0} values found'.format(len(parsed_values)), sep='\n')
    return is_valid

def on_generate_button_click(x):
    """Generate a synthetic retention dataset from the generator form.

    Reads and validates the generator settings, builds the dataset with
    ``generate_retention_dataset_2``, stores the results in module-level
    globals, copies the patch dates, bad days and week weights into the
    approximator form and plots the dataset inside the ``output`` widget.
    Prints an error and returns early when the selected date or any of the
    validated fields is invalid; the globals are left untouched in that case.

    Args:
        x: Argument passed by the button's click callback (unused).
    """
    global dates
    global installs
    global retention
    global retention_mean
    global retention_oscillated
    global patches_dates
    global bad_days
    global week_weights
    global first_day_of_week
    global date_of_release

    generator_main_decreaser_type = generator_main_function_dropdown.value
    generator_decreasers_type = generator_chains_function_dropdown.value
    generator_patches_dates_OK, generator_patches_dates, _ = check_dates(generator_patches_dates_text_input.value)
    generator_main_decreaser_weights_ok, generator_main_decreaser_weights, _ = check_main_decreaser_weights(generator_main_function_weights_text_input.value)
    generator_decreasers_weights_ok, generator_decreasers_weights, _ = check_decreaser_weights(generator_chains_function_weights_text_input.value)
    generator_oscillator_weights_ok, generator_oscillator_weights, _ = check_oscillator_weights(generator_week_weights_text_input.value)

    # The weekday of the release date is what the generator needs; a value
    # that does not format as YYYY-MM-DD makes strptime raise ValueError.
    selected_date = generator_start_date_picker.value
    try:
        local_first_day_of_week = datetime.datetime.strptime(str(selected_date), '%Y-%m-%d').weekday()
    except ValueError:
        print("Error! Select correct date")
        return

    if not (generator_patches_dates_OK and
            generator_main_decreaser_weights_ok and
            generator_decreasers_weights_ok and
            generator_oscillator_weights_ok):
        print("Error! Input parameters are incorrect : patches dates {0}, main weights {1}, chain weights {2}, weeks weights {3}".
              format(generator_patches_dates_OK,
                     generator_main_decreaser_weights_ok,
                     generator_decreasers_weights_ok,
                     generator_oscillator_weights_ok))
        return

    dates, installs, retention, retention_mean, retention_oscillated,\
        patches_dates, bad_days, week_weights, y_modeled_chains, generator_model = generate_retention_dataset_2(
                            generator_x_points_text_input.value,
                            local_first_day_of_week,
                            generator_patches_dates,
                            generator_main_decreaser_type,
                            generator_decreasers_type,
                            generator_main_decreaser_weights,
                            generator_decreasers_weights,
                            generator_oscillator_weights,
                            generator_daily_installs_mean_text_input.value,
                            generator_daily_installs_sigma_text_input.value)
    # Publish the date-related globals only after generation has succeeded.
    first_day_of_week = local_first_day_of_week
    date_of_release = selected_date

    generator_model.print_summary()

    # Pre-fill the approximator form with the parameters of the generated data.
    approximator_patches_dates_text_input.value = ",".join(str(value) for value in patches_dates)
    approximator_bad_dates_text_input.value = ",".join(str(value) for value in bad_days)
    approximator_week_weights_text_input.value = ",".join(str(value) for value in week_weights)

    output.clear_output()

    with output:
        plot_generated_retention_dataset(dates, retention, retention_mean, retention_oscillated, None)

def on_save_button_click(x):
    """Save the current dataset through ``save_retention_data_to_local_csv``.

    Does nothing unless ``dates``, ``installs`` and ``retention`` are all set.

    Args:
        x: Argument passed by the button's click callback (unused).
    """
    global dates
    global installs
    global retention
    global retention_mean
    global date_of_release

    if dates is not None and installs is not None and retention is not None:
        # retention_mean is passed as is, without a .cpu() call.
        save_retention_data_to_local_csv(dates.cpu(), installs.cpu(), retention.cpu(), retention_mean, date_of_release=date_of_release)

def on_upload_change(change):
    """Load a retention dataset from the file that was just uploaded.

    Replaces the dataset globals with the result of
    ``load_retention_from_local_csv`` and, on success, plots the data inside
    the ``output`` widget. The noise-free reference curve
    ``retention_oscillated`` is reset to None so that a curve left over from
    an earlier generated dataset is not mixed with the uploaded data.

    Args:
        change: Change notification of the upload widget; ``change['new']``
            is read as a mapping from file name to a dict that has a
            ``'content'`` bytes entry.
    """
    global dates
    global installs
    global retention
    global retention_mean
    global retention_oscillated
    global first_day_of_week
    global date_of_release

    uploaded_files = change['new']
    if not uploaded_files:
        # The value can be empty (e.g. when it is cleared): nothing to load.
        return

    filename = next(iter(uploaded_files))
    content = uploaded_files[filename]['content']

    dates, installs, retention, retention_mean, date_of_release, df = load_retention_from_local_csv(io.StringIO(content.decode()))

    # The dataset globals were just replaced, so any previous reference curve
    # no longer belongs to them.
    retention_oscillated = None

    if df is not None:
        first_day_of_week = date_of_release.weekday()
        output.clear_output()
        with output:
            print('Uploading is successful')
            plot_generated_retention_dataset(dates, retention, retention_mean, retention_oscillated)
    else:
        with output:
            display('Uploading error')

# Connect the buttons and the upload widget to the handlers defined above.
generator_start_button.on_click(on_generate_button_click)
data_save_button.on_click(on_save_button_click)
# The upload handler is registered as an observer of the widget's 'value' trait.
data_upload_button.observe(on_upload_change, names='value')

#Display all the widgets

#Compose generator widgets
# Generator panel: one row per setting, each made of the input widget, a
# describing label and, for validated fields, a live validation output.
generator_settings = widgets.VBox([
    widgets.HTML("<b>Generator</b>"),
    widgets.HTML("Fill all the parameters below and then press the button 'Generate dataset'"),
    widgets.HBox([generator_start_date_picker,
                  widgets.Label('Select the date of release')]),
    widgets.HBox([generator_x_points_text_input,
                  widgets.Label('Total quantity of days, int')]),
    widgets.HBox([generator_patches_dates_text_input,
                  widgets.Label('List of patches, integers separated by commas, excluding release patch'),
                  widgets.interactive_output(check_dates_and_print_result, {'x': generator_patches_dates_text_input})]),
    widgets.HBox([generator_main_function_dropdown,
                  widgets.Label('Type of the main trend')]),
    widgets.HBox([generator_chains_function_dropdown,
                  widgets.Label('Type of approximators of the patches')]),
    widgets.HBox([generator_main_function_weights_text_input,
                  widgets.Label('Weights of the main trend'),
                  widgets.interactive_output(check_main_decreaser_weights_and_print_result, {'x': generator_main_function_weights_text_input})]),
    widgets.HBox([generator_chains_function_weights_text_input,
                  widgets.Label('Weights of approximators of the patches'),
                  widgets.interactive_output(check_decreaser_weights_and_print_result, {'x': generator_chains_function_weights_text_input})]),
    widgets.HBox([generator_daily_installs_mean_text_input,
                  widgets.Label('Mean value of installs every day, int')]),
    widgets.HBox([generator_daily_installs_sigma_text_input,
                  widgets.Label('Standard deviation of installs every day, int')]),
    widgets.HBox([generator_week_weights_text_input,
                  widgets.Label('Weights of days of week, 7 floats from Mon to Sun'),
                  widgets.interactive_output(check_oscillator_weights_and_print_result, {'x': generator_week_weights_text_input})]),
    widgets.HBox([generator_start_button,
                  widgets.Label('Generate dataset, display it and fill the parameters of approximator on the right')])
    ],
    layout=widgets.Layout(border='solid 2px gray', padding='10px', width='50%'))


#Compose approximator widgets
# Approximator panel: data upload plus the settings that the training cell
# reads, with a live validation output next to every validated text field.
approximator_settings = widgets.VBox([
    widgets.HTML("<b>Approximator</b>"),
    widgets.HTML("Upload or generate train data, then fill all the parameters <br>and then launch the block below 'Create the regression model and train it'"),
    widgets.HBox([data_upload_button,
                  widgets.Label('Upload source retention data from local csv-file with 3 columns: date:int>=0, installs:int>0, retention:float∈[0,1]')]),
    widgets.HBox([approximator_patches_dates_text_input,
                  widgets.Label('List of patches, integers separated by commas, excluding release patch'),
                  widgets.interactive_output(check_dates_and_print_result, {'x': approximator_patches_dates_text_input})]),
    widgets.HBox([approximator_main_function_dropdown,
                  widgets.Label('Type of the main trend')]),
    widgets.HBox([approximator_chain_function_dropdown,
                  widgets.Label('Type of approximators of the patches')]),
    widgets.HBox([approximator_connector_dropdown,
                  widgets.Label('Type of connection function')]),
    widgets.HBox([approximator_week_weights_text_input,
                  widgets.Label('Weights of days of week, 7 floats from Mon to Sun'),
                  widgets.interactive_output(check_oscillator_weights_and_print_result,
                                             {'x': approximator_week_weights_text_input})]),
    widgets.HBox([approximator_bad_dates_text_input,
                  widgets.Label('List of bad days, integers separated by commas'),
                  widgets.interactive_output(check_dates_and_print_result, {'x': approximator_bad_dates_text_input})]),
    approximator_exclude_bad_dates_checkbox,
    approximator_exclude_patch_dates_checkbox
    ],
    layout=widgets.Layout(border='solid 2px gray', padding='10px', width='50%'))

# Put the generator and approximator panels into one horizontal box.
all_settings = widgets.HBox([
    generator_settings,
    approximator_settings])

# Show the settings panels, then the shared `output` widget that the handlers
# write into, then the row with the save button.
display(all_settings)
display(output)
display(widgets.HBox([data_save_button,
                      widgets.Label('Save dataset to local file')]))

In [None]:
# @title Create the regression model and train it

#import Retention_approximators_1d from Retention_approximators_1d.py
#from Retention_approximators_1d.py import ComplexApproximatorNew
from torch.utils.data import DataLoader, TensorDataset
import sys
from tqdm import tqdm

# Training schedule.
# Each row configures one training iteration as a tuple:
#   (optimizer_name, epochs, trend_function_trainable, week_function_trainable, learning_rate)
# With little data it is recommended to train only the trend and to leave the
# weekly function untrained.
training_strategy = (
    2 * [('Adam', 500, True, False, 0.01)]
    + 2 * [('LBFGS', 5, True, False, 0.01)]
    + 6 * [('Adam', 500, True, False, 0.01)]
    + 1 * [('Adam', 500, True, False, 0.001)]
    + 2 * [('Adam', 500, True, True, 0.001)]
)

# Loss selection: 'mse' - standard mse,
# 'custom' - weighted least squares with weeks regularizer.
loss_function = 'custom'
# Weight of the regularizer term.
regualizer_lambda = 100
# Width of the plotted confidence interval, in sigmas.
number_of_sigmas_for_plotting = 3

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Load the settings from the form above
main_function_type = approximator_main_function_dropdown.value
chain_function_type = approximator_chain_function_dropdown.value
# NOTE(review): connector_type is read from the form but is not passed to the
# model below - confirm whether ComplexApproximator should receive it.
connector_type = approximator_connector_dropdown.value
_, patches_dates, _ = check_dates(approximator_patches_dates_text_input.value)
_, bad_dates, _ = check_dates(approximator_bad_dates_text_input.value)
week_weights_ok, oscillator_initial_weights, _ = check_oscillator_weights(approximator_week_weights_text_input.value)

# Remove duplicated patch dates and keep them in ascending order
patches_dates = sorted(set(patches_dates))

if approximator_exclude_bad_dates_checkbox.value:
    bad_dates = [*set(bad_dates)]
else:
    bad_dates = []

if approximator_exclude_patch_dates_checkbox.value:
    bad_dates = [*set(bad_dates).union(patches_dates)]

if dates is None:
    sys.exit("Training dataset is not found")

# Indices of the days that take part in training
good_days_indices = [index for index, day in enumerate(dates) if day not in bad_dates]

dates = dates.to(device=device)
retention = retention.to(device=device)
installs = installs.to(device=device)

dates_excluding_bad = dates[good_days_indices].to(device=device).detach()
retention_excluding_bad = retention[good_days_indices].to(device=device).detach()
installs_excluding_bad = installs[good_days_indices].to(device=device).detach()

if retention_oscillated is not None:
    retention_oscillated_excluding_bad = retention_oscillated[good_days_indices].to(device=device).detach()

# Start the weekly function from the weights typed into the approximator form.
# When that field does not validate, fall back to the weights left by the
# generator, which is what was always used before.
if week_weights_ok:
    week_function_initial_weights = oscillator_initial_weights
else:
    print("Warning! Weights of days of week in the form are incorrect, the generator weights are used instead")
    week_function_initial_weights = week_weights

model = ComplexApproximator(first_day_of_week,
                            patches_dates=patches_dates,
                            main_function_type=main_function_type,
                            chain_functions_type=chain_function_type,
                            main_function_initial_weights=[0.4, 0.08, 20],
                            week_function_initial_weights=week_function_initial_weights)

model.init_weights_from_train_data(dates_excluding_bad, retention_excluding_bad)

print("Initial weights")
model.print_summary()

# Move the model to the same device as the training tensors
model.to(device)

#Training

model.train()


def _to_numpy(tensor):
    """Return ``tensor`` as a detached CPU NumPy array for matplotlib."""
    return tensor.cpu().detach().numpy()


for i, (optimizer_name, epochs, trend_function_trainable,
        week_function_trainable, learning_rate) in enumerate(training_strategy):

    print("\nIteration ", i+1)

    # Apply configuration of training: a fresh optimizer for every iteration
    if optimizer_name == 'LBFGS':
        optimizer = torch.optim.LBFGS(model.parameters(), lr=learning_rate)
    else:
        optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

    # Disabling weights if required
    model.week_function.requires_grad_(week_function_trainable)
    for dec_chain in model.chain_functions:
        dec_chain.requires_grad_(trend_function_trainable)

    def compute_loss():
        """Evaluate the configured loss on the training days."""
        retention_predicted = model(dates_excluding_bad)
        if loss_function == "custom":
            return custom_mse_loss(retention_predicted,
                                   retention_excluding_bad,
                                   installs_excluding_bad,
                                   model.regularize,
                                   regualizer_lambda)
        return torch.nn.functional.mse_loss(retention_predicted, retention_excluding_bad)

    def closure():
        """Zero the gradients, evaluate the loss and backpropagate it."""
        optimizer.zero_grad()
        loss_value = compute_loss()
        loss_value.backward()
        return loss_value

    for epoch in tqdm(range(epochs)):
        optimizer.step(closure)

    # The reported loss is the one after the last step. It is evaluated once
    # here rather than after every epoch, because the extra evaluations did
    # not take part in the optimisation.
    with torch.no_grad():
        loss = compute_loss().item()

    # Visualize the results of training; no gradients are needed for plotting
    with torch.no_grad():
        retention_predicted_for_plot = model.forward(dates)
        retention_predicted_avg = model.forward_trend_function(dates)
        retention_predicted_osc = model.forward_week_function(dates)
        retention_predicted_trend = model.forward_trend_function(dates_excluding_bad)

    # sqrt(r * (1 - r) / installs) for every day, then the source data divided
    # by the predicted weekly function, +/- the requested number of sigmas
    sigma_retention = torch.sqrt(-retention * (retention-1)/installs)
    retention_min = retention / retention_predicted_osc - number_of_sigmas_for_plotting * sigma_retention
    retention_max = retention / retention_predicted_osc + number_of_sigmas_for_plotting * sigma_retention

    print('Results of iteration')
    print('Loss = ', loss)
    if retention_oscillated is not None:
        # Loss of the noise-free reference curve against the same train data
        loss_original = custom_mse_loss(retention_oscillated_excluding_bad,
                                        retention_excluding_bad,
                                        installs_excluding_bad,
                                        model.regularize,
                                        regualizer_lambda)
        print('Ideal loss = ', loss_original.item())
    model.print_summary()

    plt.figure(figsize=(30, 12))

    dates_np = _to_numpy(dates)

    plt.fill_between(dates_np,
                     _to_numpy(retention_min),
                     _to_numpy(retention_max),
                     color='#000044',   label="Source confidence interval", linestyle='dotted')

    # Every tensor argument is moved to the CPU, including `dates`, so that
    # the helper never receives tensors from two different devices.
    retention_estimated_min, retention_estimated_max = \
        evaluate_ideal_confidence_interval(dates_excluding_bad.cpu().detach(),
                                           retention_predicted_trend.cpu().detach(),
                                           installs_excluding_bad.cpu().detach(),
                                           patches_dates,
                                           number_of_sigmas_for_plotting,
                                           dates.cpu().detach(),
                                           retention_predicted_avg.cpu().detach())

    plt.fill_between(dates_np,
                     _to_numpy(retention_estimated_min),
                     _to_numpy(retention_estimated_max),
                     color='#BBBBBB', label="Best confidence interval", alpha=0.8)

    # The reference curve is drawn only when it matches the dataset in shape
    if retention_oscillated is not None and retention_oscillated.shape == dates.shape:
        plt.plot(dates_np, _to_numpy(retention_oscillated), color='Green', label="Source data without noise")

    plt.plot(dates_np, _to_numpy(retention), color='Blue', label="Source data")
    plt.plot(_to_numpy(dates_excluding_bad), _to_numpy(retention_excluding_bad), color='Blue', label="Train data", marker='o', linestyle='')

    plt.plot(dates_np, _to_numpy(retention_predicted_for_plot), color='Magenta', label="Predicted data", linewidth=3)

    if retention_mean is not None:
        plt.plot(dates_np, _to_numpy(retention_mean), color='Red', label="Original trend")

    plt.plot(dates_np, _to_numpy(retention_predicted_avg), color='Red', label="Predicted trend", linewidth=5)

    plt.legend(bbox_to_anchor=(1.05, 0.95), loc=2, borderaxespad=0., fontsize=17)
    plt.title("Iteration {0} of {1}".format(i+1, len(training_strategy)), fontsize=17)

    plt.show()