# RFM Segmentation with python (The Data analytics approach to gain customer insights.)

## Import required libraries 

In [121]:
# importing required packages
import pandas as pd
import datetime as dt
from utils.config import config

## Download or import transaction data and customer data

In [122]:

transact_data_file = config.get('INPUT','transaction_data',
                                 fallback='./data/input/customer_transaction_data.csv')

customer_data_file = config.get('INPUT','customer_data',
                                 fallback='./data/input/customer_data.csv')

df_customer = pd.read_csv(customer_data_file, on_bad_lines='skip')
df_transactions = pd.read_csv(transact_data_file, parse_dates=['transactionDate'], on_bad_lines='skip')

In [123]:
df_customer.info()
df_customer.describe()

<class 'pandas.core.frame.DataFrame'>
RangeIndex: 1500 entries, 0 to 1499
Data columns (total 8 columns):
 #   Column                         Non-Null Count  Dtype 
---  ------                         --------------  ----- 
 0   customerID                     1500 non-null   int64 
 1   customerName                   1500 non-null   object
 2   customerAge                    1500 non-null   int64 
 3   customerGender                 1500 non-null   object
 4   customerLocation               1500 non-null   object
 5   customerEducation              1500 non-null   object
 6   customerIndustry               1500 non-null   object
 7   customerAuthorizedSignatories  1500 non-null   object
dtypes: int64(2), object(6)
memory usage: 93.9+ KB


Unnamed: 0,customerID,customerAge
count,1500.0,1500.0
mean,56682250.0,52.958
std,25943490.0,15.978061
min,10046130.0,20.0
25%,34485030.0,42.0
50%,57242780.0,54.5
75%,79241620.0,66.0
max,99914130.0,80.0


In [124]:
df_transactions.info()
df_transactions.describe()

<class 'pandas.core.frame.DataFrame'>
RangeIndex: 107895 entries, 0 to 107894
Data columns (total 11 columns):
 #   Column                Non-Null Count   Dtype         
---  ------                --------------   -----         
 0   transactionReference  107895 non-null  object        
 1   transactionDate       107895 non-null  datetime64[ns]
 2   payeeAccountNumber    107895 non-null  int64         
 3   payeeName             107895 non-null  object        
 4   payeeIndustry         0 non-null       float64       
 5   transactionCode       107895 non-null  object        
 6   amount                107895 non-null  float64       
 7   indicator             107895 non-null  object        
 8   transferNotes         87093 non-null   object        
 9   transactionCategory   107895 non-null  object        
 10  customerID            107895 non-null  int64         
dtypes: datetime64[ns](1), float64(2), int64(2), object(6)
memory usage: 9.1+ MB


Unnamed: 0,payeeAccountNumber,payeeIndustry,amount,customerID
count,107895.0,0.0,107895.0,107895.0
mean,4987604000.0,,3978.089271,56655440.0
std,2883222000.0,,3085.140476,26015200.0
min,64249.0,,100.1,10046130.0
25%,2492739000.0,,878.55,34227040.0
50%,4982691000.0,,3500.0,56898980.0
75%,7478428000.0,,6629.245,79864270.0
max,9999903000.0,,9998.99,99914130.0


## Data Preparation

In [125]:
df_customer[df_customer.duplicated()].any()

customerID                       False
customerName                     False
customerAge                      False
customerGender                   False
customerLocation                 False
customerEducation                False
customerIndustry                 False
customerAuthorizedSignatories    False
dtype: bool

In [126]:
df_transactions[df_transactions.duplicated()].any()

transactionReference    False
transactionDate         False
payeeAccountNumber      False
payeeName               False
payeeIndustry           False
transactionCode         False
amount                  False
indicator               False
transferNotes           False
transactionCategory     False
customerID              False
dtype: bool

## Data Cleaning

In [127]:
#Data Cleaning for customers 

#Dropping Duplicates
df_customer.drop_duplicates(inplace=True)
generations = lambda age: 'Gen Z' if age <= 24 else 'Millennials' if age <= 40 else 'Gen X' if age <= 55 else 'Baby Boomers' if age <= 75 else 'Silent Generation'
df_customer['customerGeneration'] = df_customer['customerAge'].apply(generations)
state = lambda x: x.split(",")[-1].split()[0] if len(x.split(",")[-1].split()) == 2 else None
df_customer['customerState'] = df_customer['customerLocation'].apply(state)

#Data cleaning for transactions 
df_transactions.drop_duplicates(inplace=True)
df_transactions = df_transactions[df_transactions['customerID'].isin(df_customer['customerID'])]
df_transactions['amount'] = df_transactions['amount'].apply(abs)


In [128]:
#Checking if duplicates have been dropped
df_transactions[df_transactions.duplicated()].any()

transactionReference    False
transactionDate         False
payeeAccountNumber      False
payeeName               False
payeeIndustry           False
transactionCode         False
amount                  False
indicator               False
transferNotes           False
transactionCategory     False
customerID              False
dtype: bool

In [129]:
#Checking for the number of unique Customers
df_transactions['customerID'].nunique()

1500

In [130]:
# Checking for the total number of transaction records
df_transactions.shape

(107895, 11)

In [131]:
# Checking for the max and min InvoiceData inorder to calculate number of months of data available
print('Min:{}; max:{}'.format(min(df_transactions.transactionDate),max(df_transactions.transactionDate)))

Min:2021-05-02 00:00:00; max:2023-05-01 00:00:00


# Cohort Analysis

Descriptive analytics tool used to group customers into mutually exclusive cohorts measured over time. Helps understand high level trends better by providing insight on metrics across products ans Customer life cycle.

## Assign Acquisition Month Cohort to each Customer
Assumption: Considert first transactionDate as acquisition date

In [132]:
# Define a function that will parse the date, it truncates given date obect to the first day of the month
def get_month(x): return dt.datetime(x.year, x.month, 1) 

def get_quarter_start(x):
    quarter_start_month = ((x.month - 1) // 3) * 3 + 1
    return dt.datetime(x.year, quarter_start_month, 1)

def get_cohort_start_date(x):
    cohort = config.get('GROUPING','cohort',
                                 fallback='MONTHLY')
    x = get_month(x) if cohort == 'MONTHLY' else get_quarter_start(x) if cohort == 'QUARTERLY' else x 
    return x 

# Apply get_month method to transactionDate and create acquisitionDate Column
df_transactions['acquisitionDate'] = df_transactions['transactionDate'].apply(get_cohort_start_date) 

# Create groupby Obj with customerID & use acquisitionDate column for further Manipulation
grouping = df_transactions.groupby('customerID')['acquisitionDate'] 

# Finally Transform with min function to assign the smallest acquisitionDate Value to each Customer in the DataSet
df_transactions['cohort'] = grouping.transform('min')

In [133]:
#Extract integer values from the data
def get_date_int(data,column):
    year=data[column].dt.year
    month=data[column].dt.month
    day=data[column].dt.day
    return year, month, day

In [134]:
# Assign Time Offset Value
invoice_year, invoice_month, _=get_date_int(df_transactions,'acquisitionDate')
cohort_year, cohort_month,_=get_date_int(df_transactions,'cohort')
year_diff= invoice_year-cohort_year
month_diff=invoice_month-cohort_month
#+1 for first month to be marked as one instead of 0 for better interpretetation
df_transactions['CohortIndex']= (year_diff*12) + (month_diff+1)
#check if the new column has been added. CohortIndex
df_transactions.head()

Unnamed: 0,transactionReference,transactionDate,payeeAccountNumber,payeeName,payeeIndustry,transactionCode,amount,indicator,transferNotes,transactionCategory,customerID,acquisitionDate,cohort,CohortIndex
0,d80f337c5f81428eb5cc67e93864d7af,2023-02-25,7165930885,Wijayanti Ltd Shipment LLC,,GIRO,8045.5,DB,transport,Transportation costs,68370032,2023-02-01,2021-05-01,22
1,48ddb2dd7c704bf0825005384356aa41,2022-11-27,6068871754,Megantara Ltd LLC,,TT,7532.19,DB,pay,Other Outgoing,68370032,2022-11-01,2021-05-01,19
2,a223f421cf2b4e69a780e3114b74020a,2021-05-13,4930832466,Puspasari-Sinaga Beverages and Sons,,MB,9010.51,DB,coffee,Meals and entertainment,68370032,2021-05-01,2021-05-01,1
3,091f3b20bcc74ab8bd15cfd048fabdfe,2021-05-07,7542270260,Prabowo-Utami Energy Ltd,,IB,3642.16,DB,heating,Utilities,68370032,2021-05-01,2021-05-01,1
4,aa9128c120de4d48b7bd84180ad14363,2022-10-28,137117784,"Mansur, Hutapea and Yulianti Development LLC",,GIRO,3607.85,DB,s/w,IT expenses,68370032,2022-10-01,2021-05-01,18


In [135]:
grouping = df_transactions.groupby(['cohort', 'CohortIndex'])

cohort_data = grouping['customerID'].apply(pd.Series.nunique)

cohort_data = cohort_data.reset_index()

cohort_counts = cohort_data.pivot(index='cohort',
                                 columns='CohortIndex',
                                 values='customerID')
print(cohort_counts)

CohortIndex     1      2      3      4      5      6      7      8      9   \
cohort                                                                       
2021-05-01   485.0  437.0  430.0  433.0  435.0  431.0  224.0  224.0  224.0   
2021-06-01    53.0   27.0   25.0   28.0   28.0    1.0    1.0    1.0    1.0   
2021-07-01   509.0  477.0  478.0  476.0  477.0  478.0  480.0  470.0  477.0   
2021-08-01    68.0   44.0   42.0   42.0   42.0   44.0   44.0   44.0   44.0   
2021-09-01    24.0    NaN    NaN    NaN    NaN    NaN    NaN    NaN    NaN   
2021-10-01    14.0    NaN    NaN    NaN    NaN    NaN    NaN    NaN    NaN   
2021-11-01   233.0  203.0  207.0  202.0  206.0  200.0  203.0  201.0  206.0   
2021-12-01    31.0    6.0    8.0    8.0    8.0    7.0    8.0    7.0    8.0   
2022-01-01    30.0    NaN    NaN    NaN    NaN    NaN    NaN    NaN    NaN   
2022-02-01    14.0    NaN    NaN    NaN    NaN    NaN    NaN    NaN    NaN   
2022-03-01    18.0    NaN    NaN    NaN    NaN    NaN    NaN    

In [136]:
cohort_sizes = cohort_counts.iloc[:,0]
retention = cohort_counts.divide(cohort_sizes, axis=0)
retention.index=retention.index.date
retention.round(3)*100

CohortIndex,1,2,3,4,5,6,7,8,9,10,...,16,17,18,19,20,21,22,23,24,25
2021-05-01,100.0,90.1,88.7,89.3,89.7,88.9,46.2,46.2,46.2,46.0,...,46.2,46.2,46.0,46.2,46.2,46.2,46.2,46.2,46.2,11.3
2021-06-01,100.0,50.9,47.2,52.8,52.8,1.9,1.9,1.9,1.9,1.9,...,1.9,1.9,1.9,1.9,1.9,1.9,1.9,1.9,,
2021-07-01,100.0,93.7,93.9,93.5,93.7,93.9,94.3,92.3,93.7,93.3,...,93.3,93.1,94.1,93.3,92.7,93.9,92.9,13.8,,
2021-08-01,100.0,64.7,61.8,61.8,61.8,64.7,64.7,64.7,64.7,64.7,...,64.7,64.7,64.7,64.7,64.7,64.7,10.3,,,
2021-09-01,100.0,,,,,,,,,,...,,,,,,,,,,
2021-10-01,100.0,,,,,,,,,,...,,,,,,,,,,
2021-11-01,100.0,87.1,88.8,86.7,88.4,85.8,87.1,86.3,88.4,89.7,...,83.7,89.3,89.3,4.7,,,,,,
2021-12-01,100.0,19.4,25.8,25.8,25.8,22.6,25.8,22.6,25.8,25.8,...,22.6,25.8,3.2,,,,,,,
2022-01-01,100.0,,,,,,,,,,...,,,,,,,,,,
2022-02-01,100.0,,,,,,,,,,...,,,,,,,,,,


In [137]:
import plotly.express as px

values = retention.fillna(0).applymap(lambda x: round(x*100)).values

fig = px.imshow(values,
                title='Retention Rates',
                x=list(retention.columns),
                y=list(retention.index),
                color_continuous_scale='ylgn',
                text_auto=True,
                height=400,
                width=800
                )

fig.update_layout(xaxis_title='Number of months', yaxis_title='Cohort',
                  coloraxis_colorbar=dict(title='% Activity'),
                  height=500, width=800)

fig.show()



# RFM Segmentation 

Recency (R) : Days Since Last Customer Transaction
Frequency (F): Number of transacations in the last 12 months
Monetary Value (M) : Total Spend in the last 12 months

# RFM Data Preperation

 Pandas built-in function #qcut will be used to calculate percentiles
 
To implement RFM Segmentation, we need to further process the data set in by the following steps:

Recency : For each customer ID, calculate the days since the last transaction. Create a hypothetical date maximum Date +1 to make it seem as though we are working on the most recent data substract the max Date of transaction(Most recent date of transation) of the customer. However, usually the data used is Real time data and using the present date would be ideal. 
Frequency: Count the number of invoices per customer to derive the frequency and 
Monetary Data: Sum the amount of money a customer transacted and divide it by Frequency, to get the amount per transaction on average, that is the Monetary data.

In [138]:
# create hypothetical snapshot_day as if anlysisng the most recent data
snapshot=max(df_transactions.transactionDate)+dt.timedelta(days=1)

In [139]:
datamart=df_transactions.groupby(['customerID']).agg({
    'transactionDate':lambda x:(snapshot-x.max()).days,
    'transactionReference':'count',
    'amount':'sum'
})

In [140]:
# Rename columns for easy interpretation
datamart.rename(columns={'transactionDate':'Recency',
                         'transactionReference': 'Frequency',
                         'amount': 'MonetaryValue'},inplace=True

)

In [141]:
#View of The RFM table
datamart.tail()

Unnamed: 0_level_0,Recency,Frequency,MonetaryValue
customerID,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1
99554488,4,96,393492.18
99623425,550,12,36272.03
99673872,630,1,9426.94
99881323,14,96,402700.56
99914131,4,48,180644.98


In [142]:
# Create Lables for Each RFM Metric:Create generator of values for labels with range function
r_labels=list(range(5,0,-1))
m_labels=range(1,6)
f_labels=range(1,5)

In [143]:
#Create quartile Values using qcut function
r_quartiles=pd.qcut(datamart['Recency'],5,labels=r_labels)
m_quartiles=pd.qcut(datamart['MonetaryValue'],5,labels=m_labels)
f_quartiles=pd.qcut(datamart['Frequency'],4,labels=f_labels)

In [144]:
#Assign R,F,M quartile values to customers
datamart=datamart.assign(R=r_quartiles.values)
datamart=datamart.assign(F=f_quartiles.values,M=m_quartiles.values)

In [145]:
#snealpeak of the added column-R
datamart.head()

Unnamed: 0_level_0,Recency,Frequency,MonetaryValue,R,F,M
customerID,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1
10046128,565,12,39126.1,1,1,2
10086469,1,192,737444.48,5,4,5
10116026,551,12,52698.61,2,1,2
10191698,561,12,25465.01,1,1,1
10208524,5,96,418488.78,4,3,5


In [146]:
# Sneakpeak of the new datamart
datamart.head()

Unnamed: 0_level_0,Recency,Frequency,MonetaryValue,R,F,M
customerID,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1
10046128,565,12,39126.1,1,1,2
10086469,1,192,737444.48,5,4,5
10116026,551,12,52698.61,2,1,2
10191698,561,12,25465.01,1,1,1
10208524,5,96,418488.78,4,3,5


In [147]:
# deriving RFM-Segment column
def join_rfm(x) : return str(x['R'])+str(x['F'])+str(x['M'])
datamart['RFM_Segment']=datamart.apply(join_rfm,axis=1)
# Deriving RFM Score column
datamart['RFM_Score']=datamart[['R','F','M']].sum(axis=1)

In [148]:
# snakpeak of  datamart 
datamart.head()

Unnamed: 0_level_0,Recency,Frequency,MonetaryValue,R,F,M,RFM_Segment,RFM_Score
customerID,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1
10046128,565,12,39126.1,1,1,2,1.01.02.0,4
10086469,1,192,737444.48,5,4,5,5.04.05.0,14
10116026,551,12,52698.61,2,1,2,2.01.02.0,5
10191698,561,12,25465.01,1,1,1,1.01.01.0,3
10208524,5,96,418488.78,4,3,5,4.03.05.0,12


In [149]:
datamart.groupby('RFM_Segment').size().sort_values(ascending=False)[:10]

RFM_Segment
1.01.01.0    160
2.01.01.0    140
1.01.02.0    139
5.04.05.0    130
5.03.04.0    121
3.03.04.0    101
2.01.02.0     86
4.03.04.0     62
3.02.03.0     58
5.03.03.0     55
dtype: int64

In [150]:
# Summary metrics per RFM Score
datamart.groupby('RFM_Score').agg({
  'Recency':'mean',
  'MonetaryValue' :'mean',
  'Frequency':['mean','count']
}).round(1)

Unnamed: 0_level_0,Recency,MonetaryValue,Frequency,Frequency
Unnamed: 0_level_1,mean,mean,mean,count
RFM_Score,Unnamed: 1_level_2,Unnamed: 2_level_2,Unnamed: 3_level_2,Unnamed: 4_level_2
3,631.9,9093.7,3.1,160
4,518.6,29979.4,7.1,279
5,550.8,51269.6,12.0,86
6,29.8,169497.7,48.0,13
7,17.9,186145.2,48.0,49
8,11.9,214753.5,53.8,83
9,10.2,281576.9,74.9,114
10,8.7,341278.5,85.7,186
11,4.8,378541.5,96.0,140
12,4.4,492211.2,121.2,187


## Grouping Customers into Named Segments
Now that we have competed the RFM segmentation, users can be groups into named categories for marketing and profiling purpsoses.
1. MVC (Most Valuable customer): RFM_Score >=12
2. Loyal Customers: RFM_Score between 9 and 11
3. Potentially Loyal: RFM_Score between 7 and 9
4. Need Attention : RFM Score between 5 and 6
5. Churned Folk : RFM_Score < 5


In [151]:
def segment_me(datamart):
    if datamart['RFM_Score']>=11 :
        return 'MVC'
    if(datamart['RFM_Score']>=9) and datamart['RFM_Score']<11:
        return 'Loyal '
    if(datamart['RFM_Score']>=7) and datamart['RFM_Score']<9:
        return 'Potentially Loyal'
    elif(datamart['RFM_Score']>=4) and datamart['RFM_Score']<6:
        return 'Need Attention' 
    else:
        return 'Churned Folk'

In [152]:
datamart['General_Segment']=datamart.apply(segment_me,axis=1)    
datamart.groupby('General_Segment').agg({
  'Recency':'mean',
  'MonetaryValue' :'mean',
  'Frequency':['mean','count']
    
    }).round(1)

Unnamed: 0_level_0,Recency,MonetaryValue,Frequency,Frequency
Unnamed: 0_level_1,mean,mean,mean,count
General_Segment,Unnamed: 1_level_2,Unnamed: 2_level_2,Unnamed: 3_level_2,Unnamed: 4_level_2
Churned Folk,586.7,21147.2,6.4,173
Loyal,9.3,318591.9,81.6,300
MVC,3.8,547662.0,136.8,530
Need Attention,526.2,34995.7,8.2,365
Potentially Loyal,14.1,204133.8,51.6,132


In [153]:
datamart.head()

Unnamed: 0_level_0,Recency,Frequency,MonetaryValue,R,F,M,RFM_Segment,RFM_Score,General_Segment
customerID,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1
10046128,565,12,39126.1,1,1,2,1.01.02.0,4,Need Attention
10086469,1,192,737444.48,5,4,5,5.04.05.0,14,MVC
10116026,551,12,52698.61,2,1,2,2.01.02.0,5,Need Attention
10191698,561,12,25465.01,1,1,1,1.01.01.0,3,Churned Folk
10208524,5,96,418488.78,4,3,5,4.03.05.0,12,MVC


In [154]:
datamart.describe()

Unnamed: 0,Recency,Frequency,MonetaryValue,RFM_Score
count,1500.0,1500.0,1500.0,1500.0
mean,200.143333,71.93,286143.96126,8.366667
std,266.721218,62.893578,250650.571119,3.720769
min,1.0,1.0,113.2,3.0
25%,3.0,12.0,47739.3375,4.0
50%,11.0,72.0,271693.505,9.0
75%,549.0,96.0,399498.0725,12.0
max,730.0,192.0,856285.64,14.0


## Display as chart for high-level segment

In [155]:
import plotly.express as px

# count the number of occurrences of each General_Segment value
segment_counts = datamart['General_Segment'].value_counts()

fig = px.treemap(
    title= "RFM Customer Segmentation",
    names=[f"{x}<br>{y}" for x, y in zip(segment_counts.index.to_list(), segment_counts.to_list())],
    parents=["Customer Segments"]*segment_counts.size,
    values=segment_counts.to_list(),
    labels=segment_counts.to_list()
)
fig.show()

## Further Analysis using customer demographic information

In [156]:
df_customer_datamart = pd.merge(df_customer, datamart, on='customerID', how='inner')
df_customer_datamart.head()

Unnamed: 0,customerID,customerName,customerAge,customerGender,customerLocation,customerEducation,customerIndustry,customerAuthorizedSignatories,customerGeneration,customerState,Recency,Frequency,MonetaryValue,R,F,M,RFM_Segment,RFM_Score,General_Segment
0,68370032,"Irwan Hartati, M.Ak",25,Female,"Gg. H.J Maemunah No. 8\nBalikpapan, PA 91500",Bachelor,Retail,"['Irwan Hartati, M.Ak']",Millennials,PA,8,192,780273.31,3,4,5,3.04.05.0,12,MVC
1,69040616,"Tgk. Amelia Yulianti, S.I.Kom",47,Female,"Jalan Jend. A. Yani No. 4\nSawahlunto, KS 52885",PhD,Automotive,"['Tgk. Amelia Yulianti, S.I.Kom']",Gen X,KS,6,192,758020.7,4,4,5,4.04.05.0,13,MVC
2,79864269,Endah Pradipta,41,Female,Gang Jend. Sudirman No. 3\nKota Administrasi J...,High School,Furniture,"['Endah Pradipta', 'R.A. Jamalia Uwais']",Gen X,,7,192,705223.52,3,4,5,3.04.05.0,12,MVC
3,93226633,Cahyadi Dongoran,21,Female,"Jalan Lembong No. 88\nDumai, Jawa Timur 33139",Bachelor,Healthcare,['Cahyadi Dongoran'],Gen Z,,8,192,691230.14,3,4,5,3.04.05.0,12,MVC
4,99087446,Prayoga Fujiati,31,Other,"Gang Sentot Alibasa No. 101\nPangkalpinang, RI...",PhD,Hospitality,['Prayoga Fujiati'],Millennials,RI,2,192,799200.87,5,4,5,5.04.05.0,14,MVC


## Show drill-down data or sub-segments

Feel free to add more entries to list variable `sub_segment_columns=['customerEducation', 'customerIndustry', 'customerGeneration']`

In [157]:
import plotly.express as px
import pandas as pd

segment_groups = df_customer_datamart.groupby('General_Segment')
for segment, segment_data in segment_groups:
    sub_segment_columns=['customerEducation', 'customerIndustry', 'customerGeneration']
    group_data = segment_data.groupby(sub_segment_columns).size().reset_index(name='count')
    group_data["all"] = segment # in order to have a single root node
    fig = px.treemap(group_data,path=sub_segment_columns, values='count', title=segment)
    fig.update_traces(root_color="lightgrey")
    fig.update_layout(margin = dict(t=50, l=25, r=25, b=25))
    fig.show()
    
for segment, segment_data in segment_groups:
    sub_segment_columns=['customerState']
    group_data = segment_data.groupby(sub_segment_columns).size().reset_index(name='count')
    group_data["all"] = segment # in order to have a single root node
    fig = px.treemap(group_data,path=sub_segment_columns, values='count', title=segment)
    fig.update_traces(root_color="lightgrey")
    fig.update_layout(margin = dict(t=50, l=25, r=25, b=25))
    fig.show()

## Save data for further evaluation and action.

In [158]:
import os


transaction_data_file = config.get('OUTPUT','transaction_data',
                                 fallback='./data/output/customer_transaction_data.csv')

customer_data_file = config.get('OUTPUT','customer_data',
                                 fallback='./data/output/customer_data.csv')

os.makedirs(os.path.dirname(transaction_data_file),exist_ok=True)
os.makedirs(os.path.dirname(customer_data_file),exist_ok=True)

df_transactions.to_csv(transaction_data_file, index=False)
df_customer_datamart.to_csv(customer_data_file, index=False)

print(f"Data is saved to files\n{customer_data_file}\n{transaction_data_file}\n")

Data is saved to files
./data/output/customer_data.csv
./data/output/customer_transaction_data.csv

