In [1]:
%%capture
!pip install xgboost smart_open pandas sklearn
!pip install gretel-synthetics --upgrade
!pip install tensorflow==2.9.0

## Using Synthetics to balance data with extreme class imbalance

Traditionally, techniques to work with imbalanced datasets focus on resampling.
In this notebook, we will examine a different approach, which instead of subsampling
to balance data (resulting in data loss), we will train a synthetic model, including
positive (fraudulent) examples and their nearest neighbors from the negative (non-fraudulent), 
and then use it to generate artificial records in the under-represented class.

This notebook accompanies our Medium post at: https://medium.com/gretel-ai/improve-fraud-detection-under-an-extreme-class-imbalance-with-synthetic-data-7dd3d856bbdf 


In [2]:
from smart_open import open
import pandas as pd
import numpy as np

# Load our sample dataset
# 10k records sampled from https://www.kaggle.com/mlg-ulb/creditcardfraud

training_set = 's3://gretel-public-website/datasets/creditcard_train.csv'
df = pd.read_csv(training_set, nrows=999999).round(6)
df

Unnamed: 0,Time,V1,V2,V3,V4,V5,V6,V7,V8,V9,V10,V11,V12,V13,V14,V15,V16,V17,V18,V19,V20,V21,V22,V23,V24,V25,V26,V27,V28,Amount,Class
0,2235,-1.787204,0.013612,2.249977,2.763243,0.746002,3.336379,-0.915483,-1.154089,0.612705,1.158881,1.238839,1.320342,-0.313427,-1.169979,-1.499442,-1.867284,1.411472,-2.018529,-0.333217,-0.958321,1.299637,0.148985,-0.030648,-0.973078,-1.345001,-0.030742,-0.080531,-0.189799,98.59,0
1,4833,1.021000,-0.083708,1.055318,1.605248,-0.289435,1.108695,-0.685134,0.406702,1.973717,-0.488662,1.932661,-1.351542,0.775229,1.378726,-1.827262,-0.925200,1.237457,-0.509727,-0.533960,-0.285997,-0.157535,0.124962,-0.040400,-0.339393,0.464752,-0.282935,0.054502,0.002592,15.00,0
2,2108,-0.386097,1.115191,1.274662,0.008660,0.254300,-0.500077,0.672194,0.037168,-0.473125,-0.571201,0.093082,0.264732,0.585162,-0.497726,1.092386,0.067395,0.259014,-0.646916,-0.453188,0.116085,-0.229752,-0.534956,0.011962,0.035986,-0.216450,0.099894,0.264855,0.096023,4.56,0
3,3138,-1.045938,0.610737,0.358115,-0.203802,-0.645612,-0.605557,1.136754,0.375660,-0.226405,-0.871793,-0.658783,-0.179832,-1.246345,0.628470,-0.376542,-0.366958,0.301025,-0.854270,-1.017236,0.118375,0.218629,0.400002,0.435944,0.394657,-0.428569,0.302804,0.177020,0.181093,183.29,0
4,10044,-0.996416,0.041710,2.450331,-0.613888,-0.709753,-0.570799,-0.167753,0.025144,1.820153,-1.086424,0.181111,-3.174768,0.827891,1.207836,0.404479,1.066242,-0.094182,0.771189,-1.130399,-0.187603,0.145444,0.559481,-0.081808,0.389500,-0.557130,0.883625,-0.054558,0.127527,83.47,0
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
7995,2440,-1.312804,1.669903,0.912969,0.889439,-0.286737,-0.763205,0.531421,-0.930255,-0.481063,0.326102,-0.048958,0.750987,1.187114,0.116962,0.881338,-0.451869,-0.038362,-0.156263,0.263371,-0.113889,0.923671,0.386139,-0.073646,0.753070,-0.230089,-0.400181,-0.201132,-0.052964,38.19,0
7996,10879,-0.418463,0.619981,2.263142,-0.105900,-0.185224,-0.360789,0.595857,-0.298291,0.968862,-0.787718,1.292230,-2.351918,2.053873,1.282558,0.908528,-0.484597,0.917204,-0.497691,0.192288,0.215590,-0.168063,-0.099271,0.032625,0.401172,-0.492768,0.923389,-0.159249,-0.126262,46.04,0
7997,681,1.120872,0.142425,0.782354,1.517495,-0.451665,-0.179163,-0.099553,0.011572,0.422085,-0.177543,-0.268722,0.916821,0.287166,-0.185884,-0.281333,-0.817230,0.373960,-1.007960,-0.523609,-0.167307,-0.004126,0.292506,-0.059899,0.434357,0.636963,-0.263692,0.057531,0.023309,9.99,0
7998,4956,1.300774,0.165218,-1.158493,-0.273887,2.206840,3.231120,-0.588503,0.722496,1.210062,-0.591265,1.079232,-2.534780,1.652524,1.323319,0.563430,0.784480,0.237039,0.441292,-0.030286,0.003743,-0.488931,-1.338442,0.100869,0.885454,0.344020,0.071279,-0.049894,0.018446,8.99,0


In [3]:
import numpy as np
from sklearn.neighbors import NearestNeighbors

# Separate out positive and negative records
positive = df[df['Class'] == 1]
negative = df[df['Class'] == 0]
print(f"Fraudulent records shape (rows, columns): {positive.shape}")
print(f"Non-fraudulent records shape (rows, columns): {negative.shape}")

# Train a nearest neighbor model on the negative dataset
neighbors = NearestNeighbors(n_neighbors=5, algorithm='ball_tree')
neighbors.fit(negative)

Fraudulent records shape (rows, columns): (31, 31)
Non-fraudulent records shape (rows, columns): (7969, 31)


NearestNeighbors(algorithm='ball_tree', leaf_size=30, metric='minkowski',
                 metric_params=None, n_jobs=None, n_neighbors=5, p=2,
                 radius=1.0)

In [4]:
# Locate the nearest neighbors to fraudulent records
# from the non-fraudulent set

MAX_NEIGHBORS = 5

nn = neighbors.kneighbors(positive, MAX_NEIGHBORS, return_distance=False)
nn_idx = list(set([item for sublist in nn for item in sublist]))
nearest_neighbors = negative.iloc[nn_idx, :]

print("Computed dataframe of nearest neighbors")
nearest_neighbors


Computed dataframe of nearest neighbors


Unnamed: 0,Time,V1,V2,V3,V4,V5,V6,V7,V8,V9,V10,V11,V12,V13,V14,V15,V16,V17,V18,V19,V20,V21,V22,V23,V24,V25,V26,V27,V28,Amount,Class
6678,11086,-1.085383,0.507854,0.904237,-1.156567,0.876260,0.263153,0.130753,0.170491,2.246222,-0.766036,-1.205004,-3.203415,0.609478,1.081984,-1.293146,-0.063452,0.415835,-0.259458,-0.311460,-0.023759,-0.304590,-0.304902,-0.297577,-1.333377,-0.066861,0.974165,0.224810,0.285037,1.60,0
4627,11119,1.168168,-0.186397,0.739630,0.283902,-0.617703,-0.059249,-0.516005,0.097901,1.892134,-0.586254,2.232486,-1.547079,0.762419,1.576431,-1.429804,-0.520132,1.063914,-0.381379,0.115907,-0.208079,-0.182320,-0.059071,0.005182,0.245466,0.267776,1.075531,-0.084632,-0.019960,3.57,0
4118,12613,0.797503,3.840937,-8.129344,5.180961,4.170041,1.643695,-0.287454,1.642456,-1.167485,-1.344935,2.585227,-4.220693,1.325096,-5.101458,0.227495,2.563968,7.413594,4.049234,-0.686911,0.100035,-0.157389,-0.360011,0.359804,0.259060,-0.819655,0.432642,-0.211101,-0.166596,1.00,0
4636,8619,1.129109,0.342468,1.663494,2.919192,-0.761252,0.183303,-0.676978,0.106030,1.457268,0.135100,-0.079504,-2.556549,1.274313,1.100595,-0.996644,0.611266,0.339163,0.123260,-1.041203,-0.233489,-0.241923,-0.313467,0.061480,0.296673,0.261252,-0.079955,0.013650,0.028282,0.00,0
1043,8805,-1.079056,0.953523,2.130665,0.033482,0.268279,0.465453,0.957703,-1.445607,3.560716,2.353309,0.143993,-3.376028,1.573371,-0.649317,0.496419,-1.167004,0.342842,0.342391,2.037970,0.959375,-0.428685,0.430396,-0.413251,-0.411012,-0.449524,1.161113,-1.188817,-0.848404,17.02,0
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
2028,12078,-2.296963,2.129603,1.429219,2.233739,-0.866471,0.183034,-0.489894,0.569271,0.913899,0.229134,-0.408016,-2.267845,1.972766,1.276751,-1.436354,0.421778,0.651924,0.289051,-0.868988,-0.527016,0.253111,0.444265,-0.152910,0.363314,-0.227920,-0.076940,-1.359426,0.042595,22.81,0
3056,7508,-0.680392,0.771942,1.288085,2.027619,0.971212,-0.133492,0.180973,0.250592,0.686421,-0.603909,1.072059,-3.100767,-0.398053,0.540224,-2.727196,0.925790,0.906979,1.180030,-1.115507,-0.223080,-0.094699,-0.108767,-0.046942,-0.202751,-0.528246,-0.225556,0.129102,0.198958,7.58,0
7689,8816,1.268242,-0.045050,0.606489,-0.042602,-0.592296,-0.490880,-0.485831,-0.016197,1.606747,-0.381650,1.942281,-2.277953,0.661623,1.956090,-0.043769,0.719422,0.060161,0.589146,0.246064,-0.158693,-0.267566,-0.595112,0.070135,-0.024782,0.072075,0.872417,-0.108545,-0.013005,0.92,0
7176,12599,1.338234,0.211595,0.018598,0.663102,-0.051332,-0.593435,-0.074188,-0.182849,1.621555,-0.341634,-0.426270,-3.537989,0.097667,2.063575,0.449892,0.427537,0.172813,0.428520,-0.075278,-0.240762,-0.228316,-0.447310,-0.133073,-0.524905,0.558941,0.417162,-0.080218,-0.009743,1.50,0


In [5]:
# Build our training dataset

# The neural network needs a lot of samples to learn the columnar structure,
# so we repeat the dataset.

oversample = pd.concat([positive] * 5)
training_set = pd.concat([oversample, nearest_neighbors] * 25).sample(frac=1)
print("Created synthetic dataset from positive (fraudulent) and nearest neighbors (non-fraudulent)")
training_set

Created synthetic dataset from positive (fraudulent) and nearest neighbors (non-fraudulent)


Unnamed: 0,Time,V1,V2,V3,V4,V5,V6,V7,V8,V9,V10,V11,V12,V13,V14,V15,V16,V17,V18,V19,V20,V21,V22,V23,V24,V25,V26,V27,V28,Amount,Class
4145,8154,1.139394,0.460056,0.887066,1.788764,-0.103472,0.052514,-0.217782,0.017468,0.347879,0.294000,2.703087,-1.425727,2.218396,1.733381,-0.690611,0.546343,0.280375,-0.330058,-1.062582,-0.140148,-0.146679,-0.180720,0.104411,0.221911,0.102892,0.836159,-0.087110,-0.007068,0.01,0
6822,7610,0.725646,2.300894,-5.329976,4.007683,-1.730411,-1.732193,-3.968593,1.063728,-0.486097,-4.624985,5.588724,-7.148243,1.680451,-6.210258,0.495282,-3.599540,-4.830324,-0.649090,2.250123,0.504646,0.589669,0.109541,0.601045,-0.364700,-1.843078,0.351909,0.594550,0.099372,1.00,1
4746,406,-2.312227,1.951992,-1.609851,3.997906,-0.522188,-1.426545,-2.537387,1.391657,-2.770089,-2.772272,3.202033,-2.899907,-0.595222,-4.289254,0.389724,-1.140747,-2.830056,-0.016822,0.416956,0.126911,0.517232,-0.035049,-0.465211,0.320198,0.044519,0.177840,0.261145,-0.143276,0.00,1
1685,13323,-5.454362,8.287421,-12.752811,8.594342,-3.106002,-3.179949,-9.252794,4.245062,-6.329801,-13.136698,11.228470,-17.131301,-0.169401,-18.049998,-1.366236,-9.723565,-14.744902,-5.247301,-0.574675,1.305862,1.846165,-0.267172,-0.310804,-1.201685,1.352176,0.608425,1.574715,0.808725,1.00,1
2458,11092,0.378275,3.914797,-5.726872,6.094141,1.698875,-2.807314,-0.591118,-0.123496,-2.530713,-5.153095,4.654088,-7.839539,1.371819,-9.634690,-0.739597,-0.663204,0.891935,0.978676,-2.005477,0.440439,0.149896,-0.601967,-0.613724,-0.403114,1.568445,0.521884,0.527938,0.411910,1.00,1
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
7814,14073,-4.153014,8.204797,-15.031714,10.330100,-3.994426,-3.250013,-10.415698,4.620804,-5.711248,-11.797181,11.277921,-16.728339,0.241368,-17.721638,-0.387300,-10.322017,-13.959085,-5.030710,1.197266,1.412625,1.976988,0.256510,0.485908,-1.198821,-0.526567,0.634874,1.627209,0.723235,1.00,1
6822,7610,0.725646,2.300894,-5.329976,4.007683,-1.730411,-1.732193,-3.968593,1.063728,-0.486097,-4.624985,5.588724,-7.148243,1.680451,-6.210258,0.495282,-3.599540,-4.830324,-0.649090,2.250123,0.504646,0.589669,0.109541,0.601045,-0.364700,-1.843078,0.351909,0.594550,0.099372,1.00,1
6409,7535,0.026779,4.132464,-6.560600,6.348557,1.329666,-2.513479,-1.689102,0.303253,-3.139409,-6.045468,6.754625,-8.948179,0.702725,-10.733854,-1.379520,-1.638960,-1.746350,0.776744,-1.327357,0.587743,0.370509,-0.576752,-0.669605,-0.759908,1.605056,0.540675,0.737040,0.496699,1.00,1
3691,8169,0.857321,4.093912,-7.423894,7.380245,0.973366,-2.730762,-1.496497,0.543015,-2.351190,-3.944238,6.355078,-7.309748,0.748451,-9.057993,-0.648945,-1.073117,1.524501,1.831364,-0.089724,0.483303,0.375026,0.145400,0.240603,-0.234649,-1.004881,0.435832,0.618324,0.148469,1.00,1


In [6]:
# Train the synthetic model on our DataFrame

from gretel_synthetics.batch import DataFrameBatch
from pathlib import Path

!rm -Rf ./checkpoints/*

config_template = {
    "max_lines": 0,
    "max_line_len": 2048,
    "epochs": 10,
    "vocab_size": 20000,
    "gen_lines": 1000,
    "dp": False,
    "field_delimiter": ",",
    "overwrite": True,
    "checkpoint_dir": str(Path.cwd() / "checkpoints")
}

batcher = DataFrameBatch(df=training_set, batch_size=32, config=config_template)
batcher.create_training_data()
batcher.train_all_batches()

2020-07-23 19:55:35,328 : MainThread : INFO : Creating directory structure for batch jobs...
2020-07-23 19:55:35,331 : MainThread : INFO : Generating training DF and CSV for batch 0
2020-07-23 19:55:35,595 : MainThread : INFO : Loading training data from /content/checkpoints/batch_0/train.csv
2020-07-23 19:55:35,613 : MainThread : INFO : Storing annotations to training_data.txt
2020-07-23 19:55:35,627 : MainThread : INFO : Dataset size: 7050 lines, 2392250 characters
2020-07-23 19:55:35,628 : MainThread : INFO : Training SentencePiece tokenizer
2020-07-23 19:55:36,630 : MainThread : INFO : Loading tokenizer from: m.model
2020-07-23 19:55:36,639 : MainThread : INFO : Tokenizer model vocabulary size: 6904 tokens
2020-07-23 19:55:36,646 : MainThread : INFO : Mapping first line of training data

'8154<d>1.139394<d>0.460056<d>0.887066<d>1.788764<d>-0.103472<d>0.052514<d>-0.217782<d>0.017468<d>0.347879<d>0.294<d>2.703087<d>-1.425727<d>2.218396<d>1.733381<d>-0.690611<d>0.546343<d>0.280375<d>-

Epoch 1/10
Epoch 2/10
Epoch 3/10
Epoch 4/10
Epoch 5/10
Epoch 6/10
Epoch 7/10
Epoch 8/10
Epoch 9/10
Epoch 10/10


2020-07-23 19:57:10,976 : MainThread : INFO : Saving model history to model_history.csv
2020-07-23 19:57:10,979 : MainThread : INFO : Saving model history to model_params.json
2020-07-23 19:57:10,984 : MainThread : INFO : Saving model to /content/checkpoints/batch_0/synthetic-10


In [7]:
# Now generate synthetic records.
# Custom record validator
def validate_record(line):
    rec = line.split(",")
    rows = training_set.shape[1]
    if len(rec) == rows:
        assert int(rec[0])
        assert int(rec[30])
        assert [float(rec[x]) for x in range(1,30)]
    else:
        raise Exception(f'record is {len(rec)} columns, not {rows} as expected')


batcher.set_batch_validator(0, validator=validate_record)
status = batcher.generate_all_batch_lines(max_invalid=5000)

HBox(children=(FloatProgress(value=0.0, description='Valid record count ', max=1000.0, style=ProgressStyle(des…

HBox(children=(FloatProgress(value=0.0, description='Invalid record count ', max=5000.0, style=ProgressStyle(d…

2020-07-23 19:57:11,082 : MainThread : INFO : Latest checkpoint: /content/checkpoints/batch_0/synthetic-10
2020-07-23 19:57:11,083 : MainThread : INFO : Loading SentencePiece tokenizer


Model: "sequential_1"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
embedding_1 (Embedding)      (1, None, 256)            1767424   
_________________________________________________________________
dropout_3 (Dropout)          (1, None, 256)            0         
_________________________________________________________________
lstm_2 (LSTM)                (1, None, 256)            525312    
_________________________________________________________________
dropout_4 (Dropout)          (1, None, 256)            0         
_________________________________________________________________
lstm_3 (LSTM)                (1, None, 256)            525312    
_________________________________________________________________
dropout_5 (Dropout)          (1, None, 256)            0         
_________________________________________________________________
dense_1 (Dense)              (1, None, 6904)          

In [8]:
# Convert generated records to a dataframe,
# only keep synthetic "Fraudulent" records generated by the model (Class == 1)

df_synthetic = batcher.batches_to_df()
df_synthetic = df_synthetic[df_synthetic['Class'] == 1]
df_synthetic

Unnamed: 0,Time,V1,V2,V3,V4,V5,V6,V7,V8,V9,V10,V11,V12,V13,V14,V15,V16,V17,V18,V19,V20,V21,V22,V23,V24,V25,V26,V27,V28,Amount,Class
0,14152,-4.710529,8.636214,-15.496222,10.313349,-4.351341,-3.322689,-10.788373,5.060381,-5.689311,-11.712187,11.152491,-16.558197,0.302645,-17.475921,-0.412393,-10.222203,-13.799148,-5.008585,1.162026,1.434240,1.990545,0.223785,0.554408,-1.204042,-0.450685,0.641836,1.605958,0.721644,1.0,1
1,8408,-1.813280,4.917851,-5.926130,5.701500,1.204393,-3.035138,-1.713402,0.561257,-3.796354,-7.454841,7.388055,-10.475229,-0.379315,-11.736729,-2.086989,-2.442354,-3.535524,0.130360,-2.071450,0.576656,0.615642,-0.406427,-0.737018,-0.279642,1.106766,0.323885,0.894767,0.569519,1.0,1
2,8090,-1.783229,3.402794,-3.822742,2.625368,-1.976415,-2.731689,-3.430559,1.413204,-0.776941,-6.199882,4.366713,-8.243262,0.345761,-6.590550,0.265576,-3.028452,-4.214486,-1.213608,-0.265422,0.364089,0.454032,-0.577526,0.045967,0.461700,0.044146,0.305704,0.530981,0.243746,1.0,1
3,8090,-1.783229,3.402794,-3.822742,2.625368,-1.976415,-2.731689,-3.430559,1.413204,-0.776941,-6.199882,4.366713,-8.243262,0.345761,-6.590550,0.265576,-3.028452,-4.214486,-1.213608,-0.265422,0.364089,0.454032,-0.577526,0.045967,0.461700,0.044146,0.305704,0.530981,0.243746,1.0,1
4,8757,-1.863756,3.442644,-4.468260,2.805336,-2.118412,-2.332285,-4.261237,1.701682,-1.439396,-6.999907,6.316210,-8.670818,0.316024,-7.417712,-0.436537,-3.652802,-6.293145,-1.243248,0.364810,0.360924,0.667927,-0.516242,-0.012218,0.070614,0.058504,0.304883,0.418012,0.208858,1.0,1
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
995,7526,0.008430,4.137837,-6.240697,6.675732,0.768307,-3.353060,-1.631735,0.154612,-2.795892,-6.187891,5.664395,-9.854485,-0.306167,-10.691196,-0.638498,-2.041974,-1.129056,0.116453,-1.934666,0.488378,0.364514,-0.608097,-0.539528,0.128940,1.488481,0.507963,0.735822,0.513574,1.0,1
996,8757,-1.863756,3.442644,-4.468260,2.805336,-2.118412,-2.332285,-4.261237,1.701682,-1.439396,-6.999907,6.316210,-8.670818,0.316024,-7.417712,-0.436537,-3.652802,-6.293145,-1.243248,0.364810,0.360924,0.667927,-0.516242,-0.012218,0.070614,0.058504,0.304883,0.418012,0.208858,1.0,1
997,8757,-1.863756,3.442644,-4.468260,2.805336,-2.118412,-2.332285,-4.261237,1.701682,-1.439396,-6.999907,6.316210,-8.670818,0.316024,-7.417712,-0.436537,-3.652802,-6.293145,-1.243248,0.364810,0.360924,0.667927,-0.516242,-0.012218,0.070614,0.058504,0.304883,0.418012,0.208858,1.0,1
998,14152,-4.710529,8.636214,-15.496222,10.313349,-4.351341,-3.322689,-10.788373,5.060381,-5.689311,-11.712187,11.152491,-16.558197,0.302645,-17.475921,-0.412393,-10.222203,-13.799148,-5.008585,1.162026,1.434240,1.990545,0.223785,0.554408,-1.204042,-0.450685,0.641836,1.605958,0.721644,1.0,1


In [9]:
# Save synthetic dataframe to disk or S3
#df_synthetic.to_csv('s3://[YOUR BUCKET HERE]/datasets/creditcard_synthetic.csv', index=False)
df_synthetic.to_csv('creditcard_synthetic.csv', index=False)