In [14]:
import glob
import json
import os
import shutil
from pathlib import Path
from typing import Any, Dict, Optional, Tuple

import pandas as pd
from sklearn.model_selection import train_test_split
from sklearn.metrics import confusion_matrix, roc_curve, precision_score, recall_score, f1_score, auc
import matplotlib
import matplotlib.pyplot as plt
from ludwig.api import LudwigModel
import requests
import yaml
import plotly.graph_objects as go
from sdv.metadata import SingleTableMetadata


In [3]:
def read_data(data_location) -> pd.DataFrame:
    """Load the credit-card transactions CSV found at ``data_location``.

    ``data_location`` is handed straight to ``pd.read_csv`` (a path or buffer).
    """
    return pd.read_csv(data_location)


In [53]:
# Location of the raw credit-card CSV. The CREDITCARD_CSV environment variable overrides the
# default so the notebook can run on machines without this Google Drive mount.
DATA_PATH = Path(os.environ.get(
    "CREDITCARD_CSV",
    "G:/My Drive/Data-Centric Solutions/07. Blog Posts/kedro/data/creditcard.csv",
))

# Read the full transactions table (fraud rows are isolated in a later cell).
df = read_data(DATA_PATH)

In [67]:
# Row count per label. `value_counts()` with no arguments counts; the previous positional `0`
# silently bound to `normalize` rather than anything axis-like.
df['Class'].value_counts()

Class
0    284315
1       492
Name: count, dtype: int64

In [54]:
# Copy each row's index into an explicit transaction_id column; it is declared as the SDV primary key later.
df = df.assign(transaction_id=df.index)

In [55]:
# Keep only the fraudulent transactions (Class == 1); the synthesizer is fitted on these rows.
df_fraud = df[df["Class"].eq(1)]
df_fraud.head()

Unnamed: 0,Time,V1,V2,V3,V4,V5,V6,V7,V8,V9,...,V22,V23,V24,V25,V26,V27,V28,Amount,Class,transaction_id
541,406.0,-2.312227,1.951992,-1.609851,3.997906,-0.522188,-1.426545,-2.537387,1.391657,-2.770089,...,-0.035049,-0.465211,0.320198,0.044519,0.17784,0.261145,-0.143276,0.0,1,541
623,472.0,-3.043541,-3.157307,1.088463,2.288644,1.359805,-1.064823,0.325574,-0.067794,-0.270953,...,0.435477,1.375966,-0.293803,0.279798,-0.145362,-0.252773,0.035764,529.0,1,623
4920,4462.0,-2.30335,1.759247,-0.359745,2.330243,-0.821628,-0.075788,0.56232,-0.399147,-0.238253,...,-0.932391,0.172726,-0.08733,-0.156114,-0.542628,0.039566,-0.153029,239.93,1,4920
6108,6986.0,-4.397974,1.358367,-2.592844,2.679787,-1.128131,-1.706536,-3.496197,-0.248778,-0.247768,...,0.176968,-0.436207,-0.053502,0.252405,-0.657488,-0.827136,0.849573,59.0,1,6108
6329,7519.0,1.234235,3.01974,-4.304597,4.732795,3.624201,-1.357746,1.713445,-0.496358,-1.282858,...,-0.704181,-0.656805,-1.632653,1.488901,0.566797,-0.010016,0.146793,1.0,1,6329


In [56]:
# Build an SDV single-table schema from the fraud rows, then declare transaction_id
# as an id column and make it the table's primary key.
metadata = SingleTableMetadata()
metadata.detect_from_dataframe(data=df_fraud)
metadata.update_column(column_name='transaction_id', sdtype='id')
metadata.set_primary_key(column_name='transaction_id')

# Sanity-check the finished schema before it is handed to a synthesizer.
# (Call metadata.to_dict() to inspect it.)
metadata.validate()

Detected metadata:
{
    "columns": {
        "Time": {
            "sdtype": "numerical"
        },
        "V1": {
            "sdtype": "numerical"
        },
        "V2": {
            "sdtype": "numerical"
        },
        "V3": {
            "sdtype": "numerical"
        },
        "V4": {
            "sdtype": "numerical"
        },
        "V5": {
            "sdtype": "numerical"
        },
        "V6": {
            "sdtype": "numerical"
        },
        "V7": {
            "sdtype": "numerical"
        },
        "V8": {
            "sdtype": "numerical"
        },
        "V9": {
            "sdtype": "numerical"
        },
        "V10": {
            "sdtype": "numerical"
        },
        "V11": {
            "sdtype": "numerical"
        },
        "V12": {
            "sdtype": "numerical"
        },
        "V13": {
            "sdtype": "numerical"
        },
        "V14": {
            "sdtype": "numerical"
        },
        "V15": {
            "sdtype": "

In [68]:
from sdv.single_table import TVAESynthesizer
import warnings

# Sample as many synthetic rows as there are genuine (Class == 0) transactions
# (284315 for this dataset), derived from the data rather than hard-coded.
n_synthetic_rows = int((df['Class'] == 0).sum())

synthesizer = TVAESynthesizer(
    metadata,  # required
    enforce_min_max_values=True,
    enforce_rounding=False,
    epochs=1000,
)

# Suppress FutureWarnings only while SDV fits and samples, instead of for every later cell.
# NOTE(review): no random seed is set here, so fit/sample reproducibility should be confirmed.
with warnings.catch_warnings():
    warnings.simplefilter('ignore', category=FutureWarning)
    synthesizer.fit(df_fraud)
    synthetic_data = synthesizer.sample(num_rows=n_synthetic_rows)

# NOTE(review): synthetic transaction_id values restart at 0 and so overlap the real ids in `df`;
# re-key them before concatenating with the real data.
synthetic_data

Fitting table None metadata
Fitting formatters for table None
Fitting constraints for table None
Setting the configuration for the ``HyperTransformer`` for table None
Fitting HyperTransformer for table None
Guidance: There are no missing values in column Time. Extra column not created.
Guidance: There are no missing values in column V1. Extra column not created.
Guidance: There are no missing values in column V2. Extra column not created.
Guidance: There are no missing values in column V3. Extra column not created.
Guidance: There are no missing values in column V4. Extra column not created.
Guidance: There are no missing values in column V5. Extra column not created.
Guidance: There are no missing values in column V6. Extra column not created.
Guidance: There are no missing values in column V7. Extra column not created.
Guidance: There are no missing values in column V8. Extra column not created.
Guidance: There are no missing values in column V9. Extra column not created.
Guidance: T

Unnamed: 0,Time,V1,V2,V3,V4,V5,V6,V7,V8,V9,...,V22,V23,V24,V25,V26,V27,V28,Amount,Class,transaction_id
0,136488.787096,1.618500,2.173212,-1.565718,2.317882,0.404058,-1.866094,0.443151,0.770524,-1.569363,...,0.394648,0.311787,0.353490,0.163342,0.234732,0.329153,0.312123,0.000000,1,0
1,160603.720851,-2.102464,6.229663,-4.716104,3.457660,-1.106527,-2.527197,0.365203,-0.060808,-0.414211,...,0.531727,1.226812,-0.440880,-0.209805,0.375805,0.014289,0.141168,518.621036,1,1
2,140495.032974,-1.857006,1.387231,-5.701704,6.245074,1.312756,-3.974903,-2.590195,0.599971,-2.649374,...,0.462317,-0.162221,-0.250340,0.496197,0.569073,0.169167,0.341447,1.005218,1,2
3,90858.024568,2.132386,2.140755,-1.172046,5.477995,0.976377,-1.988365,1.051325,0.745753,-2.805199,...,0.639101,0.271760,-0.267855,-0.291630,0.833107,0.362174,-0.032101,0.000000,1,3
4,82350.702313,-0.539535,2.403030,-4.835757,3.746537,-2.500718,-3.032350,-2.421035,1.725791,-1.031977,...,-0.241521,0.190159,-0.586993,-0.589626,-0.361861,0.290281,-0.847919,0.000000,1,4
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
284310,43884.294500,0.178875,0.510346,-3.876855,5.424047,-4.161251,-2.061733,-2.356398,0.733864,-3.534016,...,-0.432307,-0.487183,-0.935083,0.038641,0.663172,0.768475,0.248638,19.756449,1,284310
284311,35449.643925,1.394540,3.516081,-5.914680,3.443014,-0.152192,-0.156487,-2.409597,1.305848,-3.190067,...,-0.567638,-0.052506,0.040408,0.489111,0.948033,0.586421,0.096384,0.000000,1,284311
284312,31579.127284,0.688387,2.890277,-3.662365,1.993986,-3.035947,-1.331333,-0.699600,0.889672,-3.213919,...,0.273415,-0.052719,0.584996,-0.004712,-0.375809,0.256899,0.043171,14.793215,1,284312
284313,66103.855638,-3.468717,2.922210,-3.049703,9.453860,-3.817322,-3.250291,-17.589726,0.469920,-6.223264,...,0.768849,-0.412001,-0.672110,0.083143,0.709613,2.030518,0.290361,0.000000,1,284313


In [61]:
# First ten synthetic rows.
# NOTE(review): the saved output was produced at In[61], before the sampling cell (In[68]) was last run,
# so it is stale until the notebook is re-executed top to bottom.
synthetic_data.iloc[:10]

Unnamed: 0,Time,V1,V2,V3,V4,V5,V6,V7,V8,V9,...,V22,V23,V24,V25,V26,V27,V28,Amount,Class,transaction_id
0,34860.104038,-1.47155,1.784722,0.26756,2.029218,-0.973223,-0.367268,-1.034723,-0.150202,-1.942132,...,0.545172,-0.308779,0.187879,0.054956,0.297517,0.525928,0.366039,0.0,1,0
1,49945.197603,-2.178765,6.040311,-15.489178,5.527469,-0.424573,-1.642288,-16.482189,9.121632,-3.982,...,-0.715013,-0.411088,-0.852509,1.48653,0.202372,0.928911,0.657786,11.505783,1,1
2,28959.202832,-1.033353,3.904716,-4.655551,2.804259,0.481845,-0.525394,0.587499,-0.33389,-1.062518,...,-1.169871,-0.438001,-0.313941,0.087025,0.336122,0.316062,0.25726,0.0,1,2
3,48100.121272,0.322934,3.151862,-2.82564,9.060824,1.491815,-1.557936,-1.100007,-0.135239,-2.794572,...,-0.567855,-0.075404,-0.806891,0.405359,0.528041,0.169199,0.473117,0.0,1,3
4,94940.779268,-4.820729,-4.869568,-4.774974,3.114344,10.26136,-2.444714,-2.58941,0.498777,-1.626008,...,-1.888999,-4.269953,-0.594741,-0.586769,0.039305,0.302669,-0.37809,0.0,1,4
5,149645.939555,-0.428399,2.835349,-2.981499,5.753949,1.243314,-3.262469,-4.573647,0.697825,-0.529181,...,0.064998,-0.805682,-0.535222,0.540199,0.469909,0.574244,0.28493,0.0,1,5
6,58158.275679,-5.554175,8.870599,-6.645761,3.814752,-11.605797,-3.087501,-4.890628,6.485213,-1.360038,...,-0.140615,-0.620214,0.009074,0.577353,-0.252924,1.47534,0.32947,100.102809,1,6
7,54144.908255,-0.490094,3.251503,-5.099711,5.318124,1.803051,-2.655521,-0.218959,-0.334152,-2.297105,...,-1.272912,-0.354569,-1.7175,0.503031,0.560964,0.376259,0.110344,0.0,1,7
8,67883.888495,-1.801738,1.385782,-3.131311,3.395037,-0.149663,-0.866081,-5.789967,0.186474,-1.791981,...,0.472092,1.070815,-0.296639,-0.087829,-0.545111,0.41548,0.083114,622.179995,1,8
9,52223.804827,-6.398552,0.839058,-6.096985,5.971092,0.185958,-1.557263,-6.567804,0.258361,-4.712058,...,-0.021412,-0.704549,-0.086496,0.63906,-0.376015,1.759029,0.613175,0.0,1,9


In [None]:
# Pipeline

def read_data(data_location) -> pd.DataFrame:
    """Pipeline node: read the raw credit-card CSV.

    Args:
        data_location: Path (or buffer) accepted by ``pd.read_csv``.

    Returns:
        The parsed transactions table.
    """
    raw_transactions = pd.read_csv(data_location)
    return raw_transactions


def split_data(
    creditcard: pd.DataFrame,
    test_size: float = 0.2,
    random_state: int = 42,
    stratify_column: Optional[str] = None,
) -> Tuple[pd.DataFrame, pd.DataFrame]:
    """Split the transactions into a training frame and a holdout frame.

    Args:
        creditcard: Full dataset to split.
        test_size: Passed to ``train_test_split``; a float is the fraction of rows
            assigned to the holdout frame.
        random_state: Seed for the shuffle, so the split is reproducible.
        stratify_column: Optional label column to stratify on. The label here is heavily
            imbalanced (492 fraud rows vs 284315 genuine), so stratifying on ``"Class"``
            keeps the fraud rate equal in both frames. The default ``None`` reproduces the
            original unstratified split.

    Returns:
        ``(train_df, holdout_df)``.
    """
    stratify = creditcard[stratify_column] if stratify_column is not None else None
    train_df, holdout_df = train_test_split(
        creditcard,
        test_size=test_size,
        random_state=random_state,
        stratify=stratify,
    )
    return train_df, holdout_df


def run_experiment(train_df: pd.DataFrame, model_yaml, output_dir, timeout: float = 30.0) -> pd.DataFrame:
    """Train and evaluate a Ludwig model described by a remotely hosted YAML config.

    Args:
        train_df: Dataset handed to ``LudwigModel.experiment``.
        model_yaml: URL of the Ludwig config; fetched with an HTTP GET and parsed with
            ``yaml.safe_load``.
        output_dir: Passed to Ludwig as ``output_directory``. Downstream steps locate the
            newest experiment directory under it to reload the model and its training stats.
        timeout: Seconds to wait for the config download before giving up.

    Returns:
        An empty DataFrame. It carries no data. It exists only so a downstream step
        (``run_predictions`` takes it as ``exp_run``) can be ordered after training.

    Raises:
        requests.HTTPError: if the config download returns an error status.
        requests.Timeout: if the download takes longer than ``timeout``.
    """
    config = _load_remote_yaml(model_yaml, timeout=timeout)

    model = LudwigModel(config=config)
    # Results are written under output_dir; nothing from the return value is needed here.
    model.experiment(dataset=train_df, output_directory=output_dir)

    # NOTE(review): delete_file is defined outside this view; confirm what it cleans up.
    delete_file()

    return pd.DataFrame()


def _load_remote_yaml(url, timeout: float) -> Any:
    """Download ``url`` and parse its body as YAML (safe loader); raise on HTTP errors."""
    response = requests.get(url, timeout=timeout)
    response.raise_for_status()
    return yaml.safe_load(response.text)


def run_predictions(holdout_df: pd.DataFrame, exp_run: pd.DataFrame, output_dir) -> pd.DataFrame:
    """Score the holdout rows with the most recently trained Ludwig model.

    Args:
        holdout_df: Rows to predict on. Its columns are appended after the prediction columns.
        exp_run: Placeholder output of the training step. Its value is unused; the parameter
            only orders this step after training.
        output_dir: Directory searched (via ``get_latest_experiment_dir``) for the newest
            experiment, whose ``model`` subdirectory is loaded.

    Returns:
        Ludwig's prediction columns followed by ``holdout_df``'s columns, indexed like
        ``holdout_df``. ``Class_predictions`` is mapped from True/False to 1/0.

    Raises:
        ValueError: if the model returns a different number of rows than ``holdout_df`` has.
    """
    del exp_run  # dependency marker only; see docstring

    latest_experiment_dir = get_latest_experiment_dir(output_dir)
    model_path = Path(latest_experiment_dir) / 'model'
    model = LudwigModel.load(model_path)

    predictions, _ = model.predict(dataset=holdout_df)

    if len(predictions) != len(holdout_df):
        raise ValueError(
            f"Expected {len(holdout_df)} prediction rows, got {len(predictions)}"
        )

    # Join by row position rather than by index label. An index merge silently mismatches or
    # drops rows if the predictions do not carry holdout_df's (shuffled) index.
    # Assumes predictions come back in input row order.
    full_predictions = pd.concat(
        [predictions.reset_index(drop=True), holdout_df.reset_index(drop=True)],
        axis=1,
    )
    full_predictions.index = holdout_df.index

    full_predictions['Class_predictions'] = full_predictions['Class_predictions'].map({True: 1, False: 0})

    return full_predictions


def model_training_diagnostics(
    full_predictions: pd.DataFrame,
    output_dir,
    score_column: str = 'Class_predictions',
) -> Tuple[go.Figure, go.Figure]:
    """Build an ROC plot for the holdout predictions and a per-epoch loss plot.

    Args:
        full_predictions: Output of ``run_predictions``. Needs the ``Class`` label column and
            ``score_column``.
        output_dir: Directory searched (via ``get_latest_experiment_dir``) for the newest
            experiment. Its ``training_statistics.json`` supplies the loss history.
        score_column: Column used as the ROC score. The default (hard 0/1 predictions) yields
            a curve with a single operating point. Passing a per-row probability column gives
            a full curve.

    Returns:
        ``(loss_plot, roc_curve_plot)``, both Plotly figures.
    """
    roc_curve_plot = _plot_roc_curve(full_predictions['Class'], full_predictions[score_column])

    stats_path = Path(get_latest_experiment_dir(output_dir)) / 'training_statistics.json'
    loss_plot = _plot_loss_curves(stats_path)

    return loss_plot, roc_curve_plot


def _plot_roc_curve(y_true: pd.Series, y_score: pd.Series) -> go.Figure:
    """Plot the ROC curve of ``y_score`` against binary labels ``y_true``, with a chance diagonal."""
    fpr, tpr, _ = roc_curve(y_true, y_score)
    roc_auc = auc(fpr, tpr)

    fig = go.Figure()
    fig.add_trace(go.Scatter(x=fpr, y=tpr, mode='lines', name=f'ROC curve (area = {roc_auc:.2f})'))
    fig.add_trace(go.Scatter(x=[0, 1], y=[0, 1], mode='lines', name='Random Guess', line=dict(dash='dash')))
    fig.update_layout(
        xaxis_title='False Positive Rate',
        yaxis_title='True Positive Rate',
        # Equal axis scaling so the chance diagonal is drawn at 45 degrees.
        yaxis=dict(scaleanchor="x", scaleratio=1),
        xaxis=dict(constrain='domain'),
        width=700, height=700,
        title='Receiver Operating Characteristic'
    )
    return fig


def _plot_loss_curves(stats_path) -> go.Figure:
    """Plot the per-epoch ``Class`` loss for each split recorded in a training-statistics JSON file."""
    with open(stats_path, 'r') as f:
        train_stats = json.load(f)

    # Training loss is required (KeyError if absent, as before).
    splits = [('Training loss', train_stats['training']['Class']['loss'])]
    # Validation/test loss are plotted only when recorded, e.g. a config with no test split.
    for key, label in (('validation', 'Validation loss'), ('test', 'Test loss')):
        losses = train_stats.get(key, {}).get('Class', {}).get('loss')
        if losses is not None:
            splits.append((label, losses))

    fig = go.Figure()
    for label, losses in splits:
        epochs = list(range(1, len(losses) + 1))
        fig.add_trace(go.Scatter(x=epochs, y=losses, mode='lines', name=label))
    fig.update_layout(title='Training, Validation and Test Loss', xaxis_title='Epochs', yaxis_title='Loss')
    return fig