# In [None]:
import numpy as np
import pandas as pd
from functools import reduce
from glob import glob
import slreg as slr


def fetch_country(COMPANY_OID):
    """Look up the country in which the buyer's company is located.

    Parameters
    ----------
    COMPANY_OID
        Identifier of the buyer's company.

    Returns
    -------
    None
        Placeholder only: the lookup has not been implemented yet, so the
        function currently returns nothing.
    """
    # TODO: not implemented yet.
    return None


def fetch_suppliers(code):
    """Fetch the TT suppliers that sell the product identified by `code`.

    Parameters
    ----------
    code : str
        TAITRA code of len 4.

    Returns
    -------
    suppliers : DataFrame
        Intended result: one row per supplier, indexed by `BAN_REAL`, with
        columns `n_items`, `item_name`, `item_desc`, `upload_date`, and
        `update_date`.

    Notes
    -----
    Placeholder only: the query has not been implemented yet, so the function
    currently returns None.
    """
    # TODO: not implemented yet.
    return None


def fetch_export(supp_ban, ctry):
    """Fetch export records for every supplier listed in `supp_ban`.

    Parameters
    ----------
    supp_ban : list-like
        `BAN_REAL` of target suppliers.
    ctry
        Buyer's country; used to decide `isexporter`.

    Returns
    -------
    export : DataFrame
        Intended result: one row per supplier, indexed by `BAN_REAL`, with
        columns `n_comm` (number of unique commodities exported by the
        supplier), `hs_desc` (HS descriptions of commodities exported), and
        `isexporter` (boolean, whether the supplier has shipped to buyer's
        country in the last five years).

    Notes
    -----
    Placeholder only: the query has not been implemented yet, so the function
    currently returns None.
    """
    # TODO: not implemented yet.
    return None


def compute_features(inq):
    """Build the per-supplier feature table for a single buyer inquiry.

    Parameters
    ----------
    inq : Series
        Buyer inquiry. The fields `PRODUCT_CATEGORY_OID`, `PRODUCT_NAME`,
        `SPECIFICATION`, and `COMPANY_OID` are read from it.

    Returns
    -------
    features : DataFrame
        Fourteen features are planned at present:

        Catalogue size
            1) n_items : number of items in the supplier's catalogue.

        Inquiry product name vs. catalogue item names
            2) name_comp : number of words shared by the inquiry's product name
               and the catalogue items, normalized by `n_items`
               ('total compatibility').
            3) name_max_comp : max number of shared words found in any single
               catalogue item ('maximal compatibility').
            4) name_min_comp : min number of shared words found in any single
               catalogue item ('minimal compatibility').

        Inquiry specification vs. catalogue items
            5) desc_comp : like `name_comp`, but taken between the inquiry's
               product specification and the catalogue items; normalized by
               `n_items`.
            6) desc_max_comp : like `name_max_comp`.
            7) desc_min_comp : like `name_min_comp`.

        Recency
            8) upload_rec : upload recency, equal to
               max(0, `CREATION_DATE` from inquiry - `upload_date` from suppliers).
            9) update_rec : update recency, equal to
               max(0, `CREATION_DATE` from inquiry - `update_date` from suppliers).

        Export activity
            10) n_comm : number of unique commodities the supplier exports.
            11) hs_comp : like `name_comp`, but taken between the inquiry's
                product name and the HS commodity descriptions; normalized by
                `n_comm`.
            12) hs_max_comp : like `name_max_comp`.
            13) hs_min_comp : like `name_min_comp`.
            14) isexporter : whether the supplier has shipped to buyer's
                country in the last five years.

    Notes
    -----
    The transformation step is still a placeholder, so `features` is never
    assigned and the final `return` raises NameError until it is written.
    """
    # Pull the inquiry fields that drive the supplier lookup and the comparison.
    code = inq['PRODUCT_CATEGORY_OID']
    product = inq['PRODUCT_NAME']
    spec = inq['SPECIFICATION']
    ctry = fetch_country(inq['COMPANY_OID'])

    # Suppliers selling the product identified by `code`.
    suppliers = fetch_suppliers(code)
    # Export records for exactly those suppliers (looked up via their index).
    exports = fetch_export(suppliers.index, ctry)
    # Put catalogue and export columns side by side, aligned on the index.
    suppliers = pd.concat([suppliers, exports], axis=1)

    # ==============================================================
    # TODO: transform `suppliers` to features using `product` and `spec`
    # ==============================================================

    return features


def estimate_feature_dist(inqs):
    """Estimate and save population mean and standard deviation for each feature.

    Features are computed for every inquiry with `compute_features`, stacked
    row-wise, and the per-column mean and standard deviation are written to
    'feature_distribution.csv' (columns `mean` and `std`, one row per feature).

    Parameters
    ----------
    inqs : DataFrame
        Each row represents an inquiry.

    Raises
    ------
    ValueError
        If `inqs` has no rows, since no distribution can be estimated.
    """
    frames = [compute_features(row) for _, row in inqs.iterrows()]
    if not frames:
        raise ValueError('`inqs` contains no inquiries; cannot estimate the feature distribution')

    # One concat instead of pairwise `DataFrame.append`: `append` was removed
    # in pandas 2.0 and re-copied the accumulated frame on every step.
    stacked = pd.concat(frames)

    dist = pd.concat([stacked.mean(), stacked.std()], axis=1, keys=['mean', 'std'])
    dist.to_csv('feature_distribution.csv')


def normalize_features(X):
    """Standardize `X` using the module-level `mean` and `std`.

    Parameters
    ----------
    X
        Feature values to normalize.

    Returns
    -------
    The result of subtracting `mean` from `X` and dividing by `std`.
    """
    centered = X - mean
    return centered / std


# Some useful values
n = 14  # number of features; theta holds n + 1 values (features + intercept)
alpha = 0.01  # passed to `slr.gradient_descent` (presumably the learning rate)
dist = pd.read_csv('feature_distribution.csv', index_col=0)
mean, std = dist['mean'], dist['std']

# For each incoming inquiry Series `inq`, run:
# ================================================================
# Load current theta or initialize if not exists
if glob('theta.txt'):
    with open('theta.txt', 'r') as f:
        # split() on any whitespace so a trailing newline or blank line in the
        # file does not produce an empty token that float() would reject.
        theta = np.array([float(x) for x in f.read().split()])
else:
    theta = np.zeros(n + 1)

# Get data ready
# NOTE: `inq` is not defined in this file; it must be supplied by the caller.
X = compute_features(inq)
X = normalize_features(X)
X['intercept'] = 1

# Predict probabilities
X['prob'] = slr.predict_prob(X, theta)

# Get top 10 suppliers (fewer if fewer candidates exist); `prob` is dropped so
# each row holds the n features plus the intercept.
top10 = X.sort_values('prob', ascending=False).drop('prob', axis=1).head(10)

# Fetch user response Series `y`
# NOTE: `y` is not defined in this file; it must be fetched before the update.

# Save y together with `inq` (broadcasted) and 10 `BAN_REAL`s

# Update theta with one gradient-descent step per returned supplier (at most 10)
for i in range(len(top10)):
    # Each row carries n features plus the intercept, matching len(theta).
    x = top10.iloc[i].values.reshape((1, n + 1))
    theta, J = slr.gradient_descent(x, y[[i]], theta, alpha)

with open('theta.txt', 'w') as f:
    f.write('\n'.join([str(x) for x in theta]))
# ================================================================