##### Imports

In [1]:
import os
import pandas as pd
import matplotlib.pyplot as plt
from sdv.single_table import CTGANSynthesizer
from sdv.evaluation.single_table import run_diagnostic, evaluate_quality, get_column_plot, get_column_pair_plot
from sdv.metadata import SingleTableMetadata

##### Load dataset

In [2]:
data = pd.read_csv("data/merged_df_file_rutgers.csv")

# Sort data by user_id, timestamp, and day_part_x
data = data.sort_values(by=['user_id', 'timestamp', 'day_part_x']) # there is a mismatch between timestamp and day_part_x

In [3]:
data.head()

Unnamed: 0,serverTimestamp,day_part_x,user_id,numberRating,highestRating,lowestRating,medianRating,sdRating,numberLowRating,numberMediumRating,numberHighRating,numberMessageReceived,numberMessageRead,readAllMessage,reward,timestamp,day_part_y,action,message,day_part_numeric
16,10/11/2020,1,1BIGBILHOT,0,0,0,0.0,0.0,0,0,0,1,0,0,0.0,01:02.9,afternoon,0.0,No message was sent!,1.0
0,10/6/2020,0,1BIGBILHOT,0,0,0,0.0,0.0,0,0,0,0,0,0,0.0,01:03.0,morning,0.0,No message was sent!,0.0
10,10/9/2020,1,1BIGBILHOT,1,4,4,4.0,0.0,0,1,0,1,0,0,0.5,01:03.2,afternoon,0.0,No message was sent!,1.0
3,10/7/2020,0,1BIGBILHOT,0,0,0,0.0,0.0,0,0,0,1,1,1,0.5,01:03.6,morning,3.0,Even if you don’t rate your mood at some point...,0.0
9,10/9/2020,0,1BIGBILHOT,1,4,4,4.0,0.0,0,1,0,1,0,0,0.5,01:03.6,morning,2.0,Don’t forget to set a reminder in order to not...,0.0


In [4]:
data.info()

<class 'pandas.core.frame.DataFrame'>
Index: 9564 entries, 16 to 9563
Data columns (total 20 columns):
 #   Column                 Non-Null Count  Dtype  
---  ------                 --------------  -----  
 0   serverTimestamp        9564 non-null   object 
 1   day_part_x             9564 non-null   int64  
 2   user_id                9564 non-null   object 
 3   numberRating           9564 non-null   int64  
 4   highestRating          9564 non-null   int64  
 5   lowestRating           9564 non-null   int64  
 6   medianRating           9564 non-null   float64
 7   sdRating               9564 non-null   float64
 8   numberLowRating        9564 non-null   int64  
 9   numberMediumRating     9564 non-null   int64  
 10  numberHighRating       9564 non-null   int64  
 11  numberMessageReceived  9564 non-null   int64  
 12  numberMessageRead      9564 non-null   int64  
 13  readAllMessage         9564 non-null   int64  
 14  reward                 9564 non-null   float64
 15  timestam

##### Data Cleaning (will drop any rows with NAs)

In [5]:
def process_csv(df):

    df = df.dropna()
    '''for col in df.columns:
        if pd.api.types.is_numeric_dtype(df[col]): # Handle numerical missing values (mean imputation)
            if df[col].isnull().any():
                df[col].fillna(df[col].mean(), inplace=True)
        else: # Handle missing values in categorical or other object columns
            if pd.api.types.is_object_dtype(df[col]):
                if df[col].isnull().any():
                    df[col].fillna(df[col].mode()[0], inplace=True)  # Impute with mode for categorical data'''
    return df

clean_data = process_csv(data)

In [6]:
clean_data.info()

<class 'pandas.core.frame.DataFrame'>
Index: 8764 entries, 16 to 9285
Data columns (total 20 columns):
 #   Column                 Non-Null Count  Dtype  
---  ------                 --------------  -----  
 0   serverTimestamp        8764 non-null   object 
 1   day_part_x             8764 non-null   int64  
 2   user_id                8764 non-null   object 
 3   numberRating           8764 non-null   int64  
 4   highestRating          8764 non-null   int64  
 5   lowestRating           8764 non-null   int64  
 6   medianRating           8764 non-null   float64
 7   sdRating               8764 non-null   float64
 8   numberLowRating        8764 non-null   int64  
 9   numberMediumRating     8764 non-null   int64  
 10  numberHighRating       8764 non-null   int64  
 11  numberMessageReceived  8764 non-null   int64  
 12  numberMessageRead      8764 non-null   int64  
 13  readAllMessage         8764 non-null   int64  
 14  reward                 8764 non-null   float64
 15  timestam

##### Check if constraints applied on real data

In [7]:
clean_data = clean_data[(clean_data['numberRating'] >= 0) & (clean_data['numberRating'] <= 3)]
clean_data = clean_data[(clean_data['numberMessageReceived'] >= 0) & (clean_data['numberMessageReceived'] <= 3)]
clean_data = clean_data[(clean_data['numberMessageRead'] >= 0) & (clean_data['numberMessageRead'] <= 3)]
clean_data = clean_data[(clean_data['reward'] <= 2)]

In [8]:
# When generating at day_part_x  = 0, columns numberRating, numberMessageReceived and numberMessageRead have a max of 1, 
# for day_part_x  = 1 the max is 2 and for day_part_x  = 2 the max is 3.

def day_part_no_check(df):
    conditions = {
        0: {'numberRating': 1, 'numberMessageReceived': 1, 'numberMessageRead': 1},
        1: {'numberRating': 2, 'numberMessageReceived': 2, 'numberMessageRead': 2},
        2: {'numberRating': 3, 'numberMessageReceived': 3, 'numberMessageRead': 3}
    }
    
    for day_part, max_values in conditions.items():
        for col, max_value in max_values.items():
            df = df[~((df['day_part_x'] == day_part) & (df[col] > max_value))]
    
    return df

# Filter the DataFrame
clean_data = day_part_no_check(clean_data)

In [9]:
# Action can only take 0, 1, 2, same applies to day_part_x.

def check_action_daypart(df):
    # Define conditions
    condition_action = df['action'].isin([0, 1, 2])
    condition_day_part_x = df['day_part_x'].isin([0, 1, 2])
    
    # Filter DataFrame based on conditions
    df_filtered = df[condition_action & condition_day_part_x]
    
    return df_filtered

# Call the function
clean_data = check_action_daypart(clean_data)

In [10]:
clean_data.info()

<class 'pandas.core.frame.DataFrame'>
Index: 8172 entries, 16 to 9552
Data columns (total 20 columns):
 #   Column                 Non-Null Count  Dtype  
---  ------                 --------------  -----  
 0   serverTimestamp        8172 non-null   object 
 1   day_part_x             8172 non-null   int64  
 2   user_id                8172 non-null   object 
 3   numberRating           8172 non-null   int64  
 4   highestRating          8172 non-null   int64  
 5   lowestRating           8172 non-null   int64  
 6   medianRating           8172 non-null   float64
 7   sdRating               8172 non-null   float64
 8   numberLowRating        8172 non-null   int64  
 9   numberMediumRating     8172 non-null   int64  
 10  numberHighRating       8172 non-null   int64  
 11  numberMessageReceived  8172 non-null   int64  
 12  numberMessageRead      8172 non-null   int64  
 13  readAllMessage         8172 non-null   int64  
 14  reward                 8172 non-null   float64
 15  timestam

##### Metadata Extraction

In [11]:
metadata = SingleTableMetadata()
metadata.detect_from_dataframe(clean_data)

In [12]:
metadata.update_columns(
    column_names=['numberRating', 'numberMessageReceived', 'numberMessageRead', 'highestRating', 'lowestRating', 
                  'numberLowRating', 'numberMediumRating', 'numberHighRating', 'readAllMessage', 'day_part_x', 'day_part_numeric', 'action'],
    sdtype='numerical',
    computer_representation='Int64'
)

metadata.update_columns(
    column_names=['reward', 'medianRating', 'sdRating'],
    sdtype='numerical',
    computer_representation='Float'
)

metadata.update_column(
    column_name='serverTimestamp',
    sdtype='datetime',
    datetime_format='%m/%d/%Y'
)

metadata.update_column(
    column_name='timestamp',
    sdtype='datetime',
    datetime_format='%M:%S.%f'
)

metadata.update_columns(
    column_names=['day_part_y', 'message'],
    sdtype='categorical'
)

metadata.update_column(
    column_name='user_id',
    sdtype='id',
    regex_format="[A-Z0-9]{8,10}"
)

##### Model Initialization

In [13]:
# Initialize the model

ctgan_synthesizer = CTGANSynthesizer(
    metadata=metadata,
    enforce_rounding=True,
    epochs=1000,
    verbose=True)




In [14]:
from sdv.single_table import CopulaGANSynthesizer

cop_synthesizer = CopulaGANSynthesizer(
    metadata=metadata,
    enforce_rounding=True,
    epochs=1000,
    verbose=True)

In [15]:
from sdv.single_table import TVAESynthesizer

tvae_synthesizer = TVAESynthesizer(
    metadata=metadata,
    enforce_rounding=True,
    epochs=1000)

In [16]:
from sdv.single_table import GaussianCopulaSynthesizer

gauss_synthesizer = GaussianCopulaSynthesizer(
    metadata=metadata,
    enforce_rounding=True,
    default_distribution='norm' # no idea if this is the case
)

##### Define Constraints

In [17]:
# A user can’t input more than 3 ratings per day and can’t read/receive more than 3 messages per day. 
# Thus, rows where columns numberRating, numberMessageReceived and numberMessageRead have a value > 3 can be capped to 3.

nR_constraint = {
    'constraint_class': 'ScalarRange',
    'constraint_parameters': {
        'column_name': 'numberRating',
        'low_value': 0,
        'high_value': 3,
        'strict_boundaries': False
    }
}

nMRc_constraint = {
    'constraint_class': 'ScalarRange',
    'constraint_parameters': {
        'column_name': 'numberMessageReceived',
        'low_value': 0,
        'high_value': 3,
        'strict_boundaries': False
    }
}

nMRd_constraint = {
    'constraint_class': 'ScalarRange',
    'constraint_parameters': {
        'column_name': 'numberMessageRead',
        'low_value': 0,
        'high_value': 3,
        'strict_boundaries': False
    }
}

# Reward can’t be higher than 2.

rwd_constraint = {
    'constraint_class': 'ScalarInequality',
    'constraint_parameters': {
        'column_name': 'reward',
        'relation': '<=',
        'value':  2.0
    }
}

act_constraint = {
    'constraint_class': 'ScalarRange',
    'constraint_parameters': {
        'column_name': 'action',
        'low_value': 0,
        'high_value': 2,
        'strict_boundaries': False
    }
}

daypart_constraint = {
    'constraint_class': 'ScalarRange',
    'constraint_parameters': {
        'column_name': 'day_part_x',
        'low_value': 0,
        'high_value': 2,
        'strict_boundaries': False
    }
}

In [18]:
# load the constraint from the file

ctgan_synthesizer.load_custom_constraint_classes(
    filepath='data\day_part_logic.py',
    class_names=['day_partConstraintClass']
)

cop_synthesizer.load_custom_constraint_classes(
    filepath='data\day_part_logic.py',
    class_names=['day_partConstraintClass']
)

tvae_synthesizer.load_custom_constraint_classes(
    filepath='data\day_part_logic.py',
    class_names=['day_partConstraintClass']
)

gauss_synthesizer.load_custom_constraint_classes(
    filepath='data\day_part_logic.py',
    class_names=['day_partConstraintClass']
)

In [19]:
day_partConstraint= {
    'constraint_class': 'day_partConstraintClass',
    'constraint_parameters': {
        'column_names': ['day_part_x', 'numberRating','numberMessageReceived','numberMessageRead']
    }
}

In [20]:
ctgan_synthesizer.add_constraints(
    constraints=[nR_constraint, nMRc_constraint, nMRd_constraint, day_partConstraint, day_partConstraint, act_constraint]
)

In [21]:
tvae_synthesizer.add_constraints(
    constraints=[nR_constraint, nMRc_constraint, nMRd_constraint, day_partConstraint, day_partConstraint, act_constraint]
)

In [22]:
gauss_synthesizer.add_constraints(
    constraints=[nR_constraint, nMRc_constraint, nMRd_constraint, day_partConstraint, day_partConstraint, act_constraint]
)

In [23]:
cop_synthesizer.add_constraints(
    constraints=[nR_constraint, nMRc_constraint, nMRd_constraint, day_partConstraint, day_partConstraint, act_constraint]
)

##### Model Training

In [24]:
# Fit the synthesizer to real data
ctgan_synthesizer.fit(clean_data)

  from .autonotebook import tqdm as notebook_tqdm
  return Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
Gen. (-0.14) | Discrim. (0.06): 100%|██████████| 1000/1000 [21:17<00:00,  1.28s/it]


In [25]:
# Generate synthetic data
ctgan_synthetic_data = ctgan_synthesizer.sample(100)

Sampling rows: 100%|██████████| 100/100 [00:00<00:00, 406.41it/s]


In [26]:
# Save synthetic data to CSV file
os.makedirs(os.path.dirname("data/ctgan_data.csv"), exist_ok=True)
ctgan_synthetic_data.to_csv("data/ctgan_data.csv", index=False)

In [27]:
# Fit the synthesizer to real data
gauss_synthesizer.fit(clean_data)

In [28]:
# Generate synthetic data
gauss_synthetic_data = gauss_synthesizer.sample(100)

Sampling rows: 100%|██████████| 100/100 [00:00<00:00, 760.32it/s]


In [29]:
# Save synthetic data to CSV file
os.makedirs(os.path.dirname("data/gauss_data.csv"), exist_ok=True)
gauss_synthetic_data.to_csv("data/gauss_data.csv", index=False)

In [30]:
# Fit the synthesizer to real data
tvae_synthesizer.fit(clean_data)

In [31]:
# Generate synthetic data
tvae_synthetic_data = tvae_synthesizer.sample(100)

  0.26844807  0.10007719  0.1207542   0.19553155  0.05437207  0.20476521
 -0.13554234  0.07110971  0.00190919 -0.10666125  0.44464896  0.05561612
  0.08888733 -0.23364693 -0.09261802 -0.26220827  0.07131798  0.13479111
 -0.14908821  0.00701614  0.07594368 -0.31570651 -0.02036282 -0.10273044
 -0.04083713 -0.08224043  0.22014695 -0.1253006   0.06436819  0.28799836
  0.2138711   0.06561034 -0.05363421  0.11051807 -0.25784844  0.09967976
  0.36392871  0.08618446 -0.06906532  0.02146169  0.12922461  0.16130234
  0.17501859  0.2507657  -0.29074641  0.16667663 -0.24107238  0.18836035
 -0.33968588 -0.0023842   0.08267678  0.10351596 -0.03195605  0.20180599
 -0.21605361  0.45413815  0.16556976 -0.35382471  0.31515741 -0.14171086
  0.47839106  0.4037034   0.19566851  0.56138071  0.36962822 -0.30888235
  0.43291852 -0.03166009  0.10960221 -0.20016758 -0.43994384 -0.05713774
 -0.26383254  0.46924447  0.12642546  0.1232959   0.46496566 -0.13151061
 -0.0950113  -0.0799969  -0.28610891  0.06204304 -0

In [32]:
# Save synthetic data to CSV file
os.makedirs(os.path.dirname("data/tvae_data.csv"), exist_ok=True)
tvae_synthetic_data.to_csv("data/tvae_data.csv", index=False)

In [33]:
# Fit the synthesizer to real data
cop_synthesizer.fit(clean_data)

Gen. (-0.02) | Discrim. (-0.34): 100%|██████████| 1000/1000 [14:41<00:00,  1.13it/s]


In [34]:
cop_synthetic_data = cop_synthesizer.sample(100)

Sampling rows: 100%|██████████| 100/100 [00:00<00:00, 323.26it/s]


In [35]:
# Save synthetic data to CSV file
os.makedirs(os.path.dirname("data/cop_synthetic_data.csv"), exist_ok=True)
cop_synthetic_data.to_csv("data/cop_synthetic_data.csv", index=False)

##### Evaluation

In [36]:
# Run diagnostic and evaluate quality
diagnostic = run_diagnostic(real_data=clean_data, synthetic_data=ctgan_synthetic_data, metadata=metadata)

Generating report ...

(1/2) Evaluating Data Validity: |██████████| 20/20 [00:00<00:00, 1052.58it/s]|
Data Validity Score: 100.0%

(2/2) Evaluating Data Structure: |██████████| 1/1 [00:00<00:00, 500.10it/s]|
Data Structure Score: 100.0%

Overall Score (Average): 100.0%



In [37]:
quality_report = evaluate_quality(real_data=clean_data, synthetic_data=ctgan_synthetic_data, metadata=metadata)

Generating report ...

(1/2) Evaluating Column Shapes: |██████████| 20/20 [00:00<00:00, 579.25it/s]|
Column Shapes Score: 91.32%

(2/2) Evaluating Column Pair Trends: |██████████| 190/190 [00:01<00:00, 149.60it/s]|
Column Pair Trends Score: 73.15%

Overall Score (Average): 82.23%



In [38]:
# Run diagnostic and evaluate quality
diagnostic = run_diagnostic(real_data=clean_data, synthetic_data=tvae_synthetic_data, metadata=metadata)

Generating report ...

(1/2) Evaluating Data Validity: |██████████| 20/20 [00:00<00:00, 1250.07it/s]|
Data Validity Score: 100.0%

(2/2) Evaluating Data Structure: |██████████| 1/1 [00:00<00:00, 1000.07it/s]|
Data Structure Score: 100.0%

Overall Score (Average): 100.0%



In [39]:
quality_report = evaluate_quality(real_data=clean_data, synthetic_data=tvae_synthetic_data, metadata=metadata)

Generating report ...

(1/2) Evaluating Column Shapes: |██████████| 20/20 [00:00<00:00, 487.82it/s]|
Column Shapes Score: 91.51%

(2/2) Evaluating Column Pair Trends: |██████████| 190/190 [00:01<00:00, 152.43it/s]|
Column Pair Trends Score: 69.25%

Overall Score (Average): 80.38%



In [40]:
# Run diagnostic and evaluate quality
diagnostic = run_diagnostic(real_data=clean_data, synthetic_data=gauss_synthetic_data, metadata=metadata)

Generating report ...

(1/2) Evaluating Data Validity: |██████████| 20/20 [00:00<00:00, 1176.41it/s]|
Data Validity Score: 100.0%

(2/2) Evaluating Data Structure: |██████████| 1/1 [00:00<00:00, 996.98it/s]|
Data Structure Score: 100.0%

Overall Score (Average): 100.0%



In [41]:
quality_report = evaluate_quality(real_data=clean_data, synthetic_data=gauss_synthetic_data, metadata=metadata)

Generating report ...

(1/2) Evaluating Column Shapes: |██████████| 20/20 [00:00<00:00, 481.56it/s]|
Column Shapes Score: 82.21%

(2/2) Evaluating Column Pair Trends: |██████████| 190/190 [00:01<00:00, 143.19it/s]|
Column Pair Trends Score: 69.98%

Overall Score (Average): 76.09%



In [42]:
# Run diagnostic and evaluate quality
diagnostic = run_diagnostic(real_data=clean_data, synthetic_data=cop_synthetic_data, metadata=metadata)

Generating report ...

(1/2) Evaluating Data Validity: |██████████| 20/20 [00:00<00:00, 952.23it/s]|
Data Validity Score: 100.0%

(2/2) Evaluating Data Structure: |██████████| 1/1 [00:00<00:00, 333.15it/s]|
Data Structure Score: 100.0%

Overall Score (Average): 100.0%



In [43]:
quality_report = evaluate_quality(real_data=clean_data, synthetic_data=cop_synthetic_data, metadata=metadata)

Generating report ...

(1/2) Evaluating Column Shapes: |██████████| 20/20 [00:00<00:00, 512.85it/s]|
Column Shapes Score: 83.74%

(2/2) Evaluating Column Pair Trends: |██████████| 190/190 [00:01<00:00, 126.25it/s]|
Column Pair Trends Score: 75.73%

Overall Score (Average): 79.73%

