##### Imports

In [1]:
import os
import sys
import pandas as pd
import matplotlib.pyplot as plt
from sdv.single_table import CTGANSynthesizer
from sdv.evaluation.single_table import run_diagnostic, evaluate_quality, get_column_plot, get_column_pair_plot
from sdv.metadata import SingleTableMetadata

##### Load dataset

In [2]:
# Load the merged dataset; the path is relative to the notebook's working directory.
data = pd.read_csv("data/merged_df_file_rutgers.csv")

In [3]:
# Preview the first rows of the raw frame.
data.head()

Unnamed: 0,serverTimestamp,day_part_x,user_id,numberRating,highestRating,lowestRating,medianRating,sdRating,numberLowRating,numberMediumRating,numberHighRating,numberMessageReceived,numberMessageRead,readAllMessage,reward,timestamp,day_part_y,action,message,day_part_numeric
0,10/6/2020,0,1BIGBILHOT,0,0,0,0.0,0.0,0,0,0,0,0,0,0.0,01:03.0,morning,0.0,No message was sent!,0.0
1,10/6/2020,1,1BIGBILHOT,0,0,0,0.0,0.0,0,0,0,1,1,1,0.5,01:04.3,afternoon,2.0,Did you forget what pleasant activities to do?...,1.0
2,10/6/2020,2,1BIGBILHOT,0,0,0,0.0,0.0,0,0,0,2,2,1,0.5,01:03.7,evening,3.0,"Many people sometimes feel sad, this is nothin...",2.0
3,10/7/2020,0,1BIGBILHOT,0,0,0,0.0,0.0,0,0,0,1,1,1,0.5,01:03.6,morning,3.0,Even if you don’t rate your mood at some point...,0.0
4,10/7/2020,1,1BIGBILHOT,0,0,0,0.0,0.0,0,0,0,2,2,1,0.5,01:03.7,afternoon,3.0,"Many people sometimes feel sad, this is nothin...",1.0


In [4]:
# Show column dtypes and non-null counts of the raw frame.
data.info()

<class 'pandas.core.frame.DataFrame'>
RangeIndex: 9564 entries, 0 to 9563
Data columns (total 20 columns):
 #   Column                 Non-Null Count  Dtype  
---  ------                 --------------  -----  
 0   serverTimestamp        9564 non-null   object 
 1   day_part_x             9564 non-null   int64  
 2   user_id                9564 non-null   object 
 3   numberRating           9564 non-null   int64  
 4   highestRating          9564 non-null   int64  
 5   lowestRating           9564 non-null   int64  
 6   medianRating           9564 non-null   float64
 7   sdRating               9564 non-null   float64
 8   numberLowRating        9564 non-null   int64  
 9   numberMediumRating     9564 non-null   int64  
 10  numberHighRating       9564 non-null   int64  
 11  numberMessageReceived  9564 non-null   int64  
 12  numberMessageRead      9564 non-null   int64  
 13  readAllMessage         9564 non-null   int64  
 14  reward                 9564 non-null   float64
 15  time

##### Data Cleaning (will drop any rows with NAs)

In [5]:
def process_csv(df):
    """Return ``df`` without any row that contains a missing value.

    Parameters
    ----------
    df : pandas.DataFrame
        Frame to clean. It is not modified.

    Returns
    -------
    pandas.DataFrame
        A new frame holding only the complete rows. The original index
        labels are kept (the index is not reset).
    """
    # Incomplete rows are dropped rather than imputed. A mean/mode imputation
    # draft used to live here inside a bare string literal; it never executed,
    # so it has been removed instead of being left as dead code.
    return df.dropna()

# Build the cleaned frame from the raw one; `data` is kept unchanged under its own name.
clean_data = process_csv(data)

In [6]:
# Inspect dtypes and row count after the rows with missing values were dropped.
clean_data.info()

<class 'pandas.core.frame.DataFrame'>
Index: 8764 entries, 0 to 9559
Data columns (total 20 columns):
 #   Column                 Non-Null Count  Dtype  
---  ------                 --------------  -----  
 0   serverTimestamp        8764 non-null   object 
 1   day_part_x             8764 non-null   int64  
 2   user_id                8764 non-null   object 
 3   numberRating           8764 non-null   int64  
 4   highestRating          8764 non-null   int64  
 5   lowestRating           8764 non-null   int64  
 6   medianRating           8764 non-null   float64
 7   sdRating               8764 non-null   float64
 8   numberLowRating        8764 non-null   int64  
 9   numberMediumRating     8764 non-null   int64  
 10  numberHighRating       8764 non-null   int64  
 11  numberMessageReceived  8764 non-null   int64  
 12  numberMessageRead      8764 non-null   int64  
 13  readAllMessage         8764 non-null   int64  
 14  reward                 8764 non-null   float64
 15  timestamp

##### Enforce the constraints on the real data (rows that violate them are dropped)

In [7]:
# Keep only rows whose daily counters lie in the closed range [0, 3].
for counter_column in ('numberRating', 'numberMessageReceived', 'numberMessageRead'):
    clean_data = clean_data[clean_data[counter_column].between(0, 3)]

# Drop rows whose reward exceeds 2.
clean_data = clean_data[clean_data['reward'] <= 2]

In [8]:
# Per-day-part ceilings: for day_part_x = 0 the columns numberRating, numberMessageReceived and
# numberMessageRead may be at most 1; for day_part_x = 1 at most 2; for day_part_x = 2 at most 3.

def day_part_no_check(df):
    """Drop rows whose counters exceed the ceiling allowed for their day part.

    Rows whose ``day_part_x`` is not 0, 1 or 2 are left untouched.

    Parameters
    ----------
    df : pandas.DataFrame
        Must contain ``day_part_x``, ``numberRating``,
        ``numberMessageReceived`` and ``numberMessageRead``.

    Returns
    -------
    pandas.DataFrame
        The rows of ``df`` that respect every ceiling, with their original
        index labels.
    """
    ceilings_by_part = {
        0: {'numberRating': 1, 'numberMessageReceived': 1, 'numberMessageRead': 1},
        1: {'numberRating': 2, 'numberMessageReceived': 2, 'numberMessageRead': 2},
        2: {'numberRating': 3, 'numberMessageReceived': 3, 'numberMessageRead': 3},
    }

    # Accumulate one row-wise "violates any ceiling" mask, then filter once.
    violates = pd.Series(False, index=df.index, dtype=bool)
    for part, ceilings in ceilings_by_part.items():
        in_part = df['day_part_x'] == part
        for column, ceiling in ceilings.items():
            violates = violates | (in_part & (df[column] > ceiling))

    return df[~violates]

# Drop the rows of the cleaned frame that exceed the per-day-part ceilings.
clean_data = day_part_no_check(clean_data)

In [9]:
# Both `action` and `day_part_x` may only take the values 0, 1 or 2.

def check_action_daypart(df):
    """Keep only the rows whose ``action`` and ``day_part_x`` are both 0, 1 or 2.

    Parameters
    ----------
    df : pandas.DataFrame
        Must contain the columns ``action`` and ``day_part_x``.

    Returns
    -------
    pandas.DataFrame
        The rows of ``df`` that satisfy both conditions, with their original
        index labels.
    """
    allowed_codes = [0, 1, 2]
    is_valid_row = df['action'].isin(allowed_codes) & df['day_part_x'].isin(allowed_codes)
    return df[is_valid_row]

# Drop the rows whose `action` or `day_part_x` lies outside {0, 1, 2}.
clean_data = check_action_daypart(clean_data)

In [10]:
# Inspect the frame again after all constraint-based row filters were applied.
clean_data.info()

<class 'pandas.core.frame.DataFrame'>
Index: 8172 entries, 0 to 9559
Data columns (total 20 columns):
 #   Column                 Non-Null Count  Dtype  
---  ------                 --------------  -----  
 0   serverTimestamp        8172 non-null   object 
 1   day_part_x             8172 non-null   int64  
 2   user_id                8172 non-null   object 
 3   numberRating           8172 non-null   int64  
 4   highestRating          8172 non-null   int64  
 5   lowestRating           8172 non-null   int64  
 6   medianRating           8172 non-null   float64
 7   sdRating               8172 non-null   float64
 8   numberLowRating        8172 non-null   int64  
 9   numberMediumRating     8172 non-null   int64  
 10  numberHighRating       8172 non-null   int64  
 11  numberMessageReceived  8172 non-null   int64  
 12  numberMessageRead      8172 non-null   int64  
 13  readAllMessage         8172 non-null   int64  
 14  reward                 8172 non-null   float64
 15  timestamp

##### Metadata Extraction

In [11]:
# Create an empty single-table metadata object and auto-detect its contents from the cleaned frame.
# The detected entries are refined by hand in the next cell.
metadata = SingleTableMetadata()
metadata.detect_from_dataframe(clean_data)

In [12]:
# Column groups, named once so each metadata call below stays on one line.
integer_columns = [
    'numberRating', 'numberMessageReceived', 'numberMessageRead', 'highestRating', 'lowestRating',
    'numberLowRating', 'numberMediumRating', 'numberHighRating', 'readAllMessage',
    'day_part_x', 'day_part_numeric', 'action',
]
float_columns = ['reward', 'medianRating', 'sdRating']
categorical_columns = ['day_part_y', 'message']
# strptime-style format of each datetime column, keyed by column name.
datetime_formats = {
    'serverTimestamp': '%m/%d/%Y',
    'timestamp': '%M:%S.%f',
}

# Numerical columns: integer-represented first, then float-represented.
metadata.update_columns(column_names=integer_columns, sdtype='numerical', computer_representation='Int64')
metadata.update_columns(column_names=float_columns, sdtype='numerical', computer_representation='Float')

# Datetime columns, each with its own format string.
for datetime_column, datetime_format in datetime_formats.items():
    metadata.update_column(column_name=datetime_column, sdtype='datetime', datetime_format=datetime_format)

# Free-text / label columns are treated as categorical.
metadata.update_columns(column_names=categorical_columns, sdtype='categorical')

# user_id is declared as an id column with a regex describing its format.
metadata.update_column(column_name='user_id', sdtype='id', regex_format="[A-Z0-9]{8,10}")

##### Model Initialization

In [13]:
# Initialize the CTGAN synthesizer with the hand-tuned metadata.
# NOTE(review): the recorded output of the fit cell shows a 100-step progress bar, which does
# not match epochs=1000 below. Re-run the notebook so the saved outputs reflect this setting.
synthesizer = CTGANSynthesizer(
    metadata=metadata,
    enforce_rounding=True,
    epochs=1000,
    verbose=True)




##### Define Constraints

In [14]:
def _scalar_range(column_name, low_value, high_value):
    """Build a ``ScalarRange`` constraint dict with inclusive bounds for one column."""
    return {
        'constraint_class': 'ScalarRange',
        'constraint_parameters': {
            'column_name': column_name,
            'low_value': low_value,
            'high_value': high_value,
            'strict_boundaries': False,
        },
    }


# A user can't input more than 3 ratings per day and can't read/receive more than 3 messages
# per day, so each of these counters is constrained to the closed range [0, 3].
nR_constraint = _scalar_range('numberRating', 0, 3)
nMRc_constraint = _scalar_range('numberMessageReceived', 0, 3)
nMRd_constraint = _scalar_range('numberMessageRead', 0, 3)

# Reward can't be higher than 2.
rwd_constraint = {
    'constraint_class': 'ScalarInequality',
    'constraint_parameters': {
        'column_name': 'reward',
        'relation': '<=',
        'value': 2.0,
    },
}

# `action` and `day_part_x` are both limited to the closed range [0, 2].
act_constraint = _scalar_range('action', 0, 2)
daypart_constraint = _scalar_range('day_part_x', 0, 2)

In [15]:
# Load the custom constraint class from its Python file.
# The path uses a forward slash like every other path in this notebook: a backslash
# inside a normal string literal is an invalid escape sequence ("\d") and only
# resolves as a directory separator on Windows.
synthesizer.load_custom_constraint_classes(
    filepath='data/day_part_logic.py',
    class_names=['day_partConstraintClass']
)

In [16]:
# Configuration for the custom day-part constraint: the class is referred to by name
# and receives the four columns it operates on.
day_partConstraint = dict(
    constraint_class='day_partConstraintClass',
    constraint_parameters=dict(
        column_names=['day_part_x', 'numberRating', 'numberMessageReceived', 'numberMessageRead'],
    ),
)

In [17]:
# Register each constraint exactly once. The custom day-part constraint used to be
# listed twice, while the reward bound (`rwd_constraint`) and the day_part_x range
# (`daypart_constraint`) were defined but never passed to the synthesizer.
synthesizer.add_constraints(
    constraints=[
        nR_constraint,
        nMRc_constraint,
        nMRd_constraint,
        rwd_constraint,
        day_partConstraint,
        daypart_constraint,
        act_constraint,
    ]
)

In [18]:
# Fit the synthesizer to the cleaned real data (constraints must already be registered).
synthesizer.fit(clean_data)

  from .autonotebook import tqdm as notebook_tqdm
  return Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
Gen. (-0.77) | Discrim. (-0.62): 100%|██████████| 100/100 [00:38<00:00,  2.60it/s]


In [19]:
# Draw 100 synthetic rows from the fitted synthesizer.
synthetic_data = synthesizer.sample(100)

Sampling rows: 100%|██████████| 100/100 [00:00<00:00, 485.35it/s]


In [20]:
# Write the synthetic rows to a CSV file, creating the target folder first if it is missing.
output_csv = "data/ctgan_data.csv"
output_dir = os.path.dirname(output_csv)
os.makedirs(output_dir, exist_ok=True)
synthetic_data.to_csv(output_csv, index=False)

In [21]:
# Run the diagnostic report (data validity and data structure) on synthetic vs. real data.
# The quality evaluation is done separately in the next cell.
diagnostic = run_diagnostic(real_data=clean_data, synthetic_data=synthetic_data, metadata=metadata)

Generating report ...

(1/2) Evaluating Data Validity: |██████████| 20/20 [00:00<00:00, 1665.07it/s]|
Data Validity Score: 100.0%

(2/2) Evaluating Data Structure: |██████████| 1/1 [00:00<00:00, 499.38it/s]|
Data Structure Score: 100.0%

Overall Score (Average): 100.0%



In [22]:
# Score how closely the synthetic data matches the real data (column shapes and column pair trends).
quality_report = evaluate_quality(real_data=clean_data, synthetic_data=synthetic_data, metadata=metadata)

Generating report ...

(1/2) Evaluating Column Shapes: |██████████| 20/20 [00:00<00:00, 624.96it/s]|
Column Shapes Score: 81.01%

(2/2) Evaluating Column Pair Trends: |██████████| 190/190 [00:00<00:00, 233.05it/s]|
Column Pair Trends Score: 69.73%

Overall Score (Average): 75.37%



In [23]:
# Disabled plotting code: it is wrapped in a string literal, so nothing is plotted and the
# cell merely echoes the string as its output.
# NOTE(review): either re-enable these plots or delete the cell — confirm with the author why it was turned off.
'''column_plot = get_column_plot(real_data=clean_data, synthetic_data=synthetic_data, column_name="user_id", metadata=metadata)
column_pair_plot = get_column_pair_plot(real_data=clean_data, synthetic_data=synthetic_data, column_names=["numberMessageReceived", "numberMessageRead"], metadata=metadata)

column_plot.show()
column_pair_plot.show()'''

'column_plot = get_column_plot(real_data=clean_data, synthetic_data=synthetic_data, column_name="user_id", metadata=metadata)\ncolumn_pair_plot = get_column_pair_plot(real_data=clean_data, synthetic_data=synthetic_data, column_names=["numberMessageReceived", "numberMessageRead"], metadata=metadata)\n\ncolumn_plot.show()\ncolumn_pair_plot.show()'

In [24]:
# Display the sampled synthetic rows for visual inspection.
synthetic_data

Unnamed: 0,serverTimestamp,day_part_x,user_id,numberRating,highestRating,lowestRating,medianRating,sdRating,numberLowRating,numberMediumRating,numberHighRating,numberMessageReceived,numberMessageRead,readAllMessage,reward,timestamp,day_part_y,action,message,day_part_numeric
0,2020-10-26,2,AAAAAACI,0,0,0,0.000000,0.000845,0,0,0,3,0,0,0.001023,1900-01-01 00:01:26.453324,evening,2.000000,Reminders may help you to not forget about the...,1.982616
1,2020-10-19,1,AAAAAAAU,0,0,0,0.006996,0.000056,0,0,0,2,0,0,0.005020,1900-01-01 00:01:29.008252,afternoon,0.999869,It is good that you take part in Moodbuster Li...,1.033070
2,2020-12-22,2,AAAAAABI,0,0,0,0.007277,0.000000,0,0,0,1,0,0,0.005019,1900-01-01 00:01:41.196897,morning,0.980500,No message was sent!,0.013339
3,2021-01-11,0,AAAAAAA9,0,0,0,0.004137,0.000000,0,0,0,1,0,0,0.000791,1900-01-01 00:01:09.622512,morning,0.994356,Good that you are still rating your mood on a ...,0.028673
4,2020-12-13,2,AAAAAAA0,0,0,0,0.000000,0.000000,0,0,0,3,2,0,0.003690,1900-01-01 00:01:13.876671,evening,1.998506,Reminders may help you to not forget about the...,2.000000
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
95,2020-10-06,2,AAAAAAB6,2,0,5,5.626724,0.000000,0,2,0,3,3,1,0.002956,1900-01-01 00:01:12.459539,evening,1.999494,Don’t forget to set a reminder in order to not...,2.000000
96,2020-10-22,2,AAAAAACN,0,0,0,0.000000,0.000000,0,0,0,3,0,0,0.000647,1900-01-01 00:01:28.822169,evening,1.998167,Do you know you can always review the material...,2.000000
97,2020-12-17,2,AAAAAABD,0,0,0,0.014496,0.000000,0,0,0,3,0,0,0.000098,1900-01-01 00:01:36.730289,evening,1.997871,Did you forget what pleasant activities to do?...,1.981640
98,2020-10-20,0,AAAAAACP,0,0,0,0.000000,0.000000,0,0,0,1,0,0,0.001256,1900-01-01 00:01:26.805814,morning,0.992164,It’s great that you are making time for yourse...,0.032509
