# RFM Segmentation with python (The Data analytics approach to gain customer insights.)

## Import required libraries 

In [1]:
# importing required packages
import pandas as pd
import datetime as dt
from utils.config import config

## Download or import transaction data and customer data

In [2]:

transact_data_file = config.get('INPUT','transaction_data',
                                 fallback='./data/input/customer_transaction_data.csv')

customer_data_file = config.get('INPUT','customer_data',
                                 fallback='./data/input/customer_data.csv')

df_customer = pd.read_csv(customer_data_file, on_bad_lines='skip')
df_transactions = pd.read_csv(transact_data_file, parse_dates=['transactionDate'], on_bad_lines='skip')

In [3]:
df_customer.info()
df_customer.describe()

<class 'pandas.core.frame.DataFrame'>
RangeIndex: 2000 entries, 0 to 1999
Data columns (total 8 columns):
 #   Column                         Non-Null Count  Dtype 
---  ------                         --------------  ----- 
 0   customerID                     2000 non-null   int64 
 1   customerName                   2000 non-null   object
 2   customerAge                    2000 non-null   int64 
 3   customerGender                 2000 non-null   object
 4   customerLocation               2000 non-null   object
 5   customerEducation              2000 non-null   object
 6   customerIndustry               2000 non-null   object
 7   customerAuthorizedSignatories  2000 non-null   object
dtypes: int64(2), object(6)
memory usage: 125.1+ KB


Unnamed: 0,customerID,customerAge
count,2000.0,2000.0
mean,54395640.0,53.6385
std,26184540.0,15.99155
min,10054370.0,20.0
25%,31532970.0,42.75
50%,53438500.0,55.0
75%,78019020.0,66.0
max,99943150.0,80.0


In [4]:
df_transactions.info()
df_transactions.describe()

<class 'pandas.core.frame.DataFrame'>
RangeIndex: 143860 entries, 0 to 143859
Data columns (total 11 columns):
 #   Column                Non-Null Count   Dtype         
---  ------                --------------   -----         
 0   transactionReference  143860 non-null  object        
 1   transactionDate       143860 non-null  datetime64[ns]
 2   payeeAccountNumber    143860 non-null  int64         
 3   payeeName             143860 non-null  object        
 4   payeeIndustry         0 non-null       float64       
 5   transactionCode       143860 non-null  object        
 6   amount                143860 non-null  float64       
 7   indicator             143860 non-null  object        
 8   transferNotes         116478 non-null  object        
 9   transactionCategory   143860 non-null  object        
 10  customerID            143860 non-null  int64         
dtypes: datetime64[ns](1), float64(2), int64(2), object(6)
memory usage: 12.1+ MB


Unnamed: 0,payeeAccountNumber,payeeIndustry,amount,customerID
count,143860.0,0.0,143860.0,143860.0
mean,5011971000.0,,3966.444975,53565560.0
std,2883764000.0,,3095.450172,26386080.0
min,38382.0,,100.0,10054370.0
25%,2507307000.0,,863.8075,30153780.0
50%,5027602000.0,,3500.0,52952830.0
75%,7503200000.0,,6633.845,77829310.0
max,9999921000.0,,9998.92,99943150.0


## Data Preparation

In [5]:
df_customer[df_customer.duplicated()].any()

customerID                       False
customerName                     False
customerAge                      False
customerGender                   False
customerLocation                 False
customerEducation                False
customerIndustry                 False
customerAuthorizedSignatories    False
dtype: bool

In [6]:
df_transactions[df_transactions.duplicated()].any()

transactionReference    False
transactionDate         False
payeeAccountNumber      False
payeeName               False
payeeIndustry           False
transactionCode         False
amount                  False
indicator               False
transferNotes           False
transactionCategory     False
customerID              False
dtype: bool

## Data Cleaning

In [7]:
#Data Cleaning for customers 

#Dropping Duplicates
df_customer.drop_duplicates(inplace=True)
generations = lambda age: 'Gen Z' if age <= 24 else 'Millennials' if age <= 40 else 'Gen X' if age <= 55 else 'Baby Boomers' if age <= 75 else 'Silent Generation'
df_customer['customerGeneration'] = df_customer['customerAge'].apply(generations)
state = lambda x: x.split(",")[-1].split()[0] if len(x.split(",")[-1].split()) == 2 else None
df_customer['customerState'] = df_customer['customerLocation'].apply(state)

#Data cleaning for transactions 
df_transactions.drop_duplicates(inplace=True)
df_transactions = df_transactions[df_transactions['customerID'].isin(df_customer['customerID'])]
df_transactions['amount'] = df_transactions['amount'].apply(abs)


In [8]:
#Checking if duplicates have been dropped
df_transactions[df_transactions.duplicated()].any()

transactionReference    False
transactionDate         False
payeeAccountNumber      False
payeeName               False
payeeIndustry           False
transactionCode         False
amount                  False
indicator               False
transferNotes           False
transactionCategory     False
customerID              False
dtype: bool

In [9]:
#Checking for the number of unique Customers
df_transactions['customerID'].nunique()

2000

In [10]:
# Checking for the total number of transaction records
df_transactions.shape

(143860, 11)

In [11]:
# Checking for the max and min InvoiceData inorder to calculate number of months of data available
print('Min:{}; max:{}'.format(min(df_transactions.transactionDate),max(df_transactions.transactionDate)))

Min:2021-04-24 00:00:00; max:2023-04-23 00:00:00


# Cohort Analysis

Descriptive analytics tool used to group customers into mutually exclusive cohorts measured over time. Helps understand high level trends better by providing insight on metrics across products ans Customer life cycle.

## Assign Acquisition Month Cohort to each Customer
Assumption: Considert first transactionDate as acquisition date

In [12]:
# Define a function that will parse the date, it truncates given date obect to the first day of the month
def get_month(x): return dt.datetime(x.year, x.month, 1) 

def get_quarter_start(x):
    quarter_start_month = ((x.month - 1) // 3) * 3 + 1
    return dt.datetime(x.year, quarter_start_month, 1)

def get_cohort_start_date(x):
    cohort = config.get('GROUPING','cohort',
                                 fallback='MONTHLY')
    x = get_month(x) if cohort == 'MONTHLY' else get_quarter_start(x) if cohort == 'QUARTERLY' else x 
    return x 

# Apply get_month method to transactionDate and create acquisitionDate Column
df_transactions['acquisitionDate'] = df_transactions['transactionDate'].apply(get_cohort_start_date) 

# Create groupby Obj with customerID & use acquisitionDate column for further Manipulation
grouping = df_transactions.groupby('customerID')['acquisitionDate'] 

# Finally Transform with min function to assign the smallest acquisitionDate Value to each Customer in the DataSet
df_transactions['cohort'] = grouping.transform('min')

In [13]:
#Extract integer values from the data
def get_date_int(data,column):
    year=data[column].dt.year
    month=data[column].dt.month
    day=data[column].dt.day
    return year, month, day

In [14]:
# Assign Time Offset Value
invoice_year, invoice_month, _=get_date_int(df_transactions,'acquisitionDate')
cohort_year, cohort_month,_=get_date_int(df_transactions,'cohort')
year_diff= invoice_year-cohort_year
month_diff=invoice_month-cohort_month
#+1 for first month to be marked as one instead of 0 for better interpretetation
df_transactions['CohortIndex']= (year_diff*12) + (month_diff+1)
#check if the new column has been added. CohortIndex
df_transactions.head()

Unnamed: 0,transactionReference,transactionDate,payeeAccountNumber,payeeName,payeeIndustry,transactionCode,amount,indicator,transferNotes,transactionCategory,customerID,acquisitionDate,cohort,CohortIndex
0,643f40764e524392a01399ee37ba88c1,2023-03-01,83407944,Vasquez Inc License Solutions Ltd,,AXS,4195.13,DB,license fee,Licenses and insurance,57691954,2023-03-01,2021-05-01,23
1,0abf25481d924be0af38c6aeab8fd58b,2022-12-13,8994077921,Henderson Inc Dining Group,,MB,6052.76,DB,,Meals and entertainment,57691954,2022-12-01,2021-05-01,20
2,5a081469bcfa441abe1288a4c523cc98,2022-09-22,6125994629,"Clements, Garcia and Hughes License Providers Ltd",,IB,1832.56,DB,accounting,Licenses and insurance,57691954,2022-09-01,2021-05-01,17
3,7fde23bda94543809f21255a402c7fda,2022-09-01,2444672539,"Grant, Reed and Jones Technologies and Sons",,TT,6691.37,DB,,IT expenses,57691954,2022-09-01,2021-05-01,17
4,c180344412e94c94b037a595588c2c68,2022-06-23,1993595922,Mr. Paul Stewart,,FX,708.46,CR,miscellaneous,Other Income,57691954,2022-06-01,2021-05-01,14


In [15]:
grouping = df_transactions.groupby(['cohort', 'CohortIndex'])

cohort_data = grouping['customerID'].apply(pd.Series.nunique)

cohort_data = cohort_data.reset_index()

cohort_counts = cohort_data.pivot(index='cohort',
                                 columns='CohortIndex',
                                 values='customerID')
print(cohort_counts)

CohortIndex     1      2      3      4      5      6      7      8      9   \
cohort                                                                       
2021-04-01   408.0  385.0  380.0  383.0  383.0  377.0  373.0  256.0  256.0   
2021-05-01   258.0  218.0  221.0  222.0  218.0  183.0   44.0   44.0   44.0   
2021-06-01    38.0   15.0   17.0   16.0   12.0    NaN    NaN    NaN    NaN   
2021-07-01   713.0  681.0  682.0  684.0  672.0  677.0  676.0  673.0  679.0   
2021-08-01    41.0   17.0   17.0   16.0   17.0   17.0   17.0   17.0   17.0   
2021-09-01    39.0    NaN    NaN    NaN    NaN    NaN    NaN    NaN    NaN   
2021-10-01   180.0  139.0  138.0  133.0  133.0  136.0  133.0  136.0  139.0   
2021-11-01   172.0  143.0  135.0  132.0  140.0  139.0  138.0  135.0  142.0   
2021-12-01    35.0    7.0    5.0    6.0    7.0    5.0    7.0    7.0    7.0   
2022-01-01    41.0    NaN    NaN    NaN    NaN    NaN    NaN    NaN    NaN   
2022-02-01    28.0    NaN    NaN    NaN    NaN    NaN    NaN    

In [16]:
cohort_sizes = cohort_counts.iloc[:,0]
retention = cohort_counts.divide(cohort_sizes, axis=0)
retention.index=retention.index.date
retention.round(3)*100

CohortIndex,1,2,3,4,5,6,7,8,9,10,...,16,17,18,19,20,21,22,23,24,25
2021-04-01,100.0,94.4,93.1,93.9,93.9,92.4,91.4,62.7,62.7,62.7,...,62.7,62.7,62.7,62.7,62.7,62.7,62.7,62.7,62.7,62.5
2021-05-01,100.0,84.5,85.7,86.0,84.5,70.9,17.1,17.1,17.1,17.1,...,17.1,17.1,17.1,17.1,17.1,17.1,17.1,17.1,17.1,
2021-06-01,100.0,39.5,44.7,42.1,31.6,,,,,,...,,,,,,,,,,
2021-07-01,100.0,95.5,95.7,95.9,94.2,95.0,94.8,94.4,95.2,95.1,...,94.7,94.7,95.4,94.8,94.4,95.0,93.7,,,
2021-08-01,100.0,41.5,41.5,39.0,41.5,41.5,41.5,41.5,41.5,41.5,...,41.5,41.5,41.5,41.5,39.0,41.5,,,,
2021-09-01,100.0,,,,,,,,,,...,,,,,,,,,,
2021-10-01,100.0,77.2,76.7,73.9,73.9,75.6,73.9,75.6,77.2,76.7,...,76.1,76.1,75.0,70.6,,,,,,
2021-11-01,100.0,83.1,78.5,76.7,81.4,80.8,80.2,78.5,82.6,81.4,...,82.6,75.6,76.2,,,,,,,
2021-12-01,100.0,20.0,14.3,17.1,20.0,14.3,20.0,20.0,20.0,20.0,...,20.0,17.1,,,,,,,,
2022-01-01,100.0,,,,,,,,,,...,,,,,,,,,,


In [17]:
import plotly.express as px

values = retention.fillna(0).applymap(lambda x: round(x*100)).values

fig = px.imshow(values,
                title='Retention Rates',
                x=list(retention.columns),
                y=list(retention.index),
                color_continuous_scale='ylgn',
                text_auto=True,
                height=400,
                width=800
                )

fig.update_layout(xaxis_title='Number of months', yaxis_title='Cohort',
                  coloraxis_colorbar=dict(title='% Activity'),
                  height=500, width=800)

fig.show()



# RFM Segmentation 

Recency (R) : Days Since Last Customer Transaction
Frequency (F): Number of transacations in the last 12 months
Monetary Value (M) : Total Spend in the last 12 months

# RFM Data Preperation

 Pandas built-in function #qcut will be used to calculate percentiles
 
To implement RFM Segmentation, we need to further process the data set in by the following steps:

Recency : For each customer ID, calculate the days since the last transaction. Create a hypothetical date maximum Date +1 to make it seem as though we are working on the most recent data substract the max Date of transaction(Most recent date of transation) of the customer. However, usually the data used is Real time data and using the present date would be ideal. 
Frequency: Count the number of invoices per customer to derive the frequency and 
Monetary Data: Sum the amount of money a customer transacted and divide it by Frequency, to get the amount per transaction on average, that is the Monetary data.

In [18]:
# create hypothetical snapshot_day as if anlysisng the most recent data
snapshot=max(df_transactions.transactionDate)+dt.timedelta(days=1)

In [19]:
datamart=df_transactions.groupby(['customerID']).agg({
    'transactionDate':lambda x:(snapshot-x.max()).days,
    'transactionReference':'count',
    'amount':'sum'
})

In [20]:
# Rename columns for easy interpretation
datamart.rename(columns={'transactionDate':'Recency',
                         'transactionReference': 'Frequency',
                         'amount': 'MonetaryValue'},inplace=True

)

In [21]:
#View of The RFM table
datamart.tail()

Unnamed: 0_level_0,Recency,Frequency,MonetaryValue
customerID,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1
99880787,552,12,56171.64
99893556,2,48,156178.37
99900651,2,192,756029.01
99923008,7,48,192463.08
99943153,461,1,1081.64


In [22]:
# Create Lables for Each RFM Metric:Create generator of values for labels with range function
r_labels=list(range(5,0,-1))
m_labels=range(1,6)
f_labels=range(1,5)

In [23]:
#Create quartile Values using qcut function
r_quartiles=pd.qcut(datamart['Recency'],5,labels=r_labels)
m_quartiles=pd.qcut(datamart['MonetaryValue'],5,labels=m_labels)
f_quartiles=pd.qcut(datamart['Frequency'],4,labels=f_labels)

In [24]:
#Assign R,F,M quartile values to customers
datamart=datamart.assign(R=r_quartiles.values)
datamart=datamart.assign(F=f_quartiles.values,M=m_quartiles.values)

In [25]:
#snealpeak of the added column-R
datamart.head()

Unnamed: 0_level_0,Recency,Frequency,MonetaryValue,R,F,M
customerID,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1
10054368,474,1,6359.9,2,1,1
10061286,37,48,233045.15,2,2,3
10387014,2,192,751096.67,5,4,5
10414616,705,1,7050.9,1,1,1
10428720,2,96,348497.24,5,3,3


In [26]:
# Sneakpeak of the new datamart
datamart.head()

Unnamed: 0_level_0,Recency,Frequency,MonetaryValue,R,F,M
customerID,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1
10054368,474,1,6359.9,2,1,1
10061286,37,48,233045.15,2,2,3
10387014,2,192,751096.67,5,4,5
10414616,705,1,7050.9,1,1,1
10428720,2,96,348497.24,5,3,3


In [27]:
# deriving RFM-Segment column
def join_rfm(x) : return str(x['R'])+str(x['F'])+str(x['M'])
datamart['RFM_Segment']=datamart.apply(join_rfm,axis=1)
# Deriving RFM Score column
datamart['RFM_Score']=datamart[['R','F','M']].sum(axis=1)

In [28]:
# snakpeak of  datamart 
datamart.head()

Unnamed: 0_level_0,Recency,Frequency,MonetaryValue,R,F,M,RFM_Segment,RFM_Score
customerID,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1
10054368,474,1,6359.9,2,1,1,2.01.01.0,4
10061286,37,48,233045.15,2,2,3,2.02.03.0,7
10387014,2,192,751096.67,5,4,5,5.04.05.0,14
10414616,705,1,7050.9,1,1,1,1.01.01.0,3
10428720,2,96,348497.24,5,3,3,5.03.03.0,11


In [29]:
datamart.groupby('RFM_Segment').size().sort_values(ascending=False)[:10]

RFM_Segment
2.01.01.0    203
1.01.01.0    197
1.01.02.0    195
5.04.05.0    160
5.03.04.0    149
4.03.04.0    117
3.03.04.0    117
2.01.02.0    105
4.04.05.0     97
5.03.03.0     79
dtype: int64

In [30]:
# Summary metrics per RFM Score
datamart.groupby('RFM_Score').agg({
  'Recency':'mean',
  'MonetaryValue' :'mean',
  'Frequency':['mean','count']
}).round(1)

Unnamed: 0_level_0,Recency,MonetaryValue,Frequency,Frequency
Unnamed: 0_level_1,mean,mean,mean,count
RFM_Score,Unnamed: 1_level_2,Unnamed: 2_level_2,Unnamed: 3_level_2,Unnamed: 4_level_2
3,619.5,10197.9,3.3,197
4,518.0,28297.9,6.9,398
5,550.5,50308.9,12.0,105
6,29.7,171649.6,48.0,21
7,21.3,189580.9,48.0,71
8,11.8,206125.2,51.8,100
9,9.4,275309.9,72.7,142
10,8.4,338174.6,86.1,228
11,5.0,378547.5,96.4,225
12,4.1,465422.1,114.2,221


## Grouping Customers into Named Segments
Now that we have competed the RFM segmentation, users can be groups into named categories for marketing and profiling purpsoses.
1. MVC (Most Valuable customer): RFM_Score >=12
2. Loyal Customers: RFM_Score between 9 and 11
3. Potentially Loyal: RFM_Score between 7 and 9
4. Need Attention : RFM Score between 5 and 6
5. Churned Folk : RFM_Score < 5


In [31]:
def segment_me(datamart):
    if datamart['RFM_Score']>=12 :
        return 'MVC'
    if(datamart['RFM_Score']>=9) and datamart['RFM_Score']<11:
        return 'Loyal '
    if(datamart['RFM_Score']>=7) and datamart['RFM_Score']<9:
        return 'Potentially Loyal'
    elif(datamart['RFM_Score']>=4) and datamart['RFM_Score']<6:
        return 'Churned Folk'
    else:
        return  'Need Attention'

In [32]:
datamart['General_Segment']=datamart.apply(segment_me,axis=1)    
datamart.groupby('General_Segment').agg({
  'Recency':'mean',
  'MonetaryValue' :'mean',
  'Frequency':['mean','count']
    
    }).round(1)

Unnamed: 0_level_0,Recency,MonetaryValue,Frequency,Frequency
Unnamed: 0_level_1,mean,mean,mean,count
General_Segment,Unnamed: 1_level_2,Unnamed: 2_level_2,Unnamed: 3_level_2,Unnamed: 4_level_2
Churned Folk,524.8,32892.6,8.0,503
Loyal,8.7,314048.2,81.0,370
MVC,3.5,610156.7,152.0,513
Need Attention,279.4,204936.3,52.7,443
Potentially Loyal,15.7,199255.9,50.2,171


In [33]:
datamart.head()

Unnamed: 0_level_0,Recency,Frequency,MonetaryValue,R,F,M,RFM_Segment,RFM_Score,General_Segment
customerID,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1
10054368,474,1,6359.9,2,1,1,2.01.01.0,4,Churned Folk
10061286,37,48,233045.15,2,2,3,2.02.03.0,7,Potentially Loyal
10387014,2,192,751096.67,5,4,5,5.04.05.0,14,MVC
10414616,705,1,7050.9,1,1,1,1.01.01.0,3,Need Attention
10428720,2,96,348497.24,5,3,3,5.03.03.0,11,Need Attention


In [34]:
datamart.describe()

Unnamed: 0,Recency,Frequency,MonetaryValue,RFM_Score
count,2000.0,2000.0,2000.0,2000.0
mean,197.7475,71.93,285306.38707,8.386
std,263.093255,62.888334,250341.248365,3.727191
min,1.0,1.0,106.65,3.0
25%,4.0,12.0,45929.915,4.0
50%,10.0,72.0,268674.715,9.0
75%,548.0,96.0,399094.6575,12.0
max,727.0,192.0,881015.56,14.0


## Display as chart for high-level segment

In [35]:
import plotly.express as px

# count the number of occurrences of each General_Segment value
segment_counts = datamart['General_Segment'].value_counts()

fig = px.treemap(
    title= "RFM Customer Segmentation",
    names=[f"{x}<br>{y}" for x, y in zip(segment_counts.index.to_list(), segment_counts.to_list())],
    parents=["Customer Segments"]*segment_counts.size,
    values=segment_counts.to_list(),
    labels=segment_counts.to_list()
)
fig.show()

## Further Analysis using customer demographic information

In [36]:
df_customer_datamart = pd.merge(df_customer, datamart, on='customerID', how='inner')
df_customer_datamart.head()

Unnamed: 0,customerID,customerName,customerAge,customerGender,customerLocation,customerEducation,customerIndustry,customerAuthorizedSignatories,customerGeneration,customerState,Recency,Frequency,MonetaryValue,R,F,M,RFM_Segment,RFM_Score,General_Segment
0,57691954,Miss Patricia Baird,45,Male,"65307 Chavez Bypass Apt. 116\nPort Sharonview,...",PhD,Finance,['Miss Patricia Baird'],Gen X,NJ,6,192,812626.87,4,4,5,4.04.05.0,13,MVC
1,10915242,Abigail Quinn,38,Male,"89836 Darryl Vista\nStaceystad, MT 17850",Bachelor,Healthcare,"['Abigail Quinn', 'Robert Walker']",Millennials,MT,2,192,716325.47,5,4,5,5.04.05.0,14,MVC
2,25972837,Robertson PLC,31,Female,USNS Vega\nFPO AA 24277,PhD,Tourism,['Robertson PLC'],Millennials,,4,192,766320.2,4,4,5,4.04.05.0,13,MVC
3,18390037,Robles Inc,56,Male,"237 Sullivan Drive\nSmithbury, RI 14765",PhD,Education,"['Robles Inc', 'Mr. Nicholas Rogers']",Baby Boomers,RI,6,192,791639.06,4,4,5,4.04.05.0,13,MVC
4,39736631,Amy Smith,74,Male,"166 Huber Street\nEast Angela, AZ 54158",High School,Furniture,"['Amy Smith', 'John Brooks']",Baby Boomers,AZ,12,192,785246.74,3,4,5,3.04.05.0,12,MVC


## Show drill-down data or sub-segments

Feel free to add more entries to list variable `sub_segment_columns=['customerEducation', 'customerIndustry', 'customerGeneration']`

In [37]:
import plotly.express as px
import pandas as pd

segment_groups = df_customer_datamart.groupby('General_Segment')
for segment, segment_data in segment_groups:
    sub_segment_columns=['customerEducation', 'customerIndustry', 'customerGeneration']
    group_data = segment_data.groupby(sub_segment_columns).size().reset_index(name='count')
    group_data["all"] = segment # in order to have a single root node
    fig = px.treemap(group_data,path=sub_segment_columns, values='count', title=segment)
    fig.update_traces(root_color="lightgrey")
    fig.update_layout(margin = dict(t=50, l=25, r=25, b=25))
    fig.show()
    
for segment, segment_data in segment_groups:
    sub_segment_columns=['customerState']
    group_data = segment_data.groupby(sub_segment_columns).size().reset_index(name='count')
    group_data["all"] = segment # in order to have a single root node
    fig = px.treemap(group_data,path=sub_segment_columns, values='count', title=segment)
    fig.update_traces(root_color="lightgrey")
    fig.update_layout(margin = dict(t=50, l=25, r=25, b=25))
    fig.show()

## Save data for further evaluation and action.

In [38]:
import os


transaction_data_file = config.get('OUTPUT','transaction_data',
                                 fallback='./data/output/customer_transaction_data.csv')

customer_data_file = config.get('OUTPUT','customer_data',
                                 fallback='./data/output/customer_data.csv')

os.makedirs(os.path.dirname(transaction_data_file),exist_ok=True)
os.makedirs(os.path.dirname(customer_data_file),exist_ok=True)

df_transactions.to_csv(transaction_data_file, index=False)
df_customer_datamart.to_csv(customer_data_file, index=False)

print(f"Data is saved to files\n{customer_data_file}\n{transaction_data_file}\n")

Data is saved to files
./data/output/customer_data.csv
./data/output/customer_transaction_data.csv

