In [20]:
!pip install ydata-synthetic==1.1.0




In [21]:
from ydata_synthetic.synthesizers.regular import RegularSynthesizer
from ydata_synthetic.synthesizers import ModelParameters, TrainParameters
import pandas as pd
import io
import os
import requests
import numpy as np
from sklearn import metrics
from sklearn.preprocessing import LabelEncoder
le = LabelEncoder()

#Load data and define the data processor parameters
df = pd.read_csv('https://raw.githubusercontent.com/mahayasa/various-sampling-churn-prediction/main/data/mobile-churn.csv')
#filling missing value with mean
df = df.fillna(df.mean())
df=df.drop(['year','month','user_account_id'],axis=1)

df_numerical_features = df.select_dtypes(exclude='object')
df_categorical_features = df.select_dtypes(include='object')

num_cols = list(df_numerical_features)
cat_cols = list(df_categorical_features)


In [22]:
train_data = df.loc[ df['churn']==1 ].copy()

In [23]:


#Defining the training parameters
noise_dim = 128
dim = 128
batch_size = 500

log_step = 100
epochs = 100
learning_rate = [5e-4, 3e-3]
beta_1 = 0.5
beta_2 = 0.9
models_dir = '../cache'

gan_args = ModelParameters(batch_size=batch_size,
                           lr=learning_rate,
                           betas=(beta_1, beta_2),
                           noise_dim=noise_dim,
                           layers_dim=dim)

train_args = TrainParameters(epochs=epochs,
                             sample_interval=log_step)

synth = RegularSynthesizer(modelname='wgangp', model_parameters=gan_args, n_critic=2)
synth.fit(train_data, train_args, num_cols, cat_cols)

synth.save('adult_wgangp_model.pkl')



  1%|          | 1/100 [00:18<30:29, 18.47s/it]

Epoch: 0 | disc_loss: 0.04222998768091202 | gen_loss: 0.051663581281900406


  2%|▏         | 2/100 [00:24<18:26, 11.29s/it]

Epoch: 1 | disc_loss: 10.701451301574707 | gen_loss: -0.09393060207366943


  3%|▎         | 3/100 [00:28<12:33,  7.76s/it]

Epoch: 2 | disc_loss: 0.6659924387931824 | gen_loss: -0.03932418301701546


  4%|▍         | 4/100 [00:31<09:46,  6.11s/it]

Epoch: 3 | disc_loss: 0.5305032134056091 | gen_loss: -0.03453649580478668


  5%|▌         | 5/100 [00:36<08:44,  5.52s/it]

Epoch: 4 | disc_loss: 0.019331417977809906 | gen_loss: -0.02535809762775898


  6%|▌         | 6/100 [00:42<09:13,  5.89s/it]

Epoch: 5 | disc_loss: 0.5329607725143433 | gen_loss: 0.01659471169114113


  7%|▋         | 7/100 [00:46<08:08,  5.25s/it]

Epoch: 6 | disc_loss: 0.09803629666566849 | gen_loss: 0.018322298303246498


  8%|▊         | 8/100 [00:50<07:13,  4.71s/it]

Epoch: 7 | disc_loss: 0.7291606664657593 | gen_loss: 0.00613913731649518


  9%|▉         | 9/100 [00:54<06:42,  4.42s/it]

Epoch: 8 | disc_loss: 0.09260205179452896 | gen_loss: -0.10873964428901672


 10%|█         | 10/100 [01:01<07:47,  5.19s/it]

Epoch: 9 | disc_loss: 0.1428181529045105 | gen_loss: 0.007778870407491922


 11%|█         | 11/100 [01:05<07:23,  4.98s/it]

Epoch: 10 | disc_loss: 3.066908597946167 | gen_loss: 0.00828565750271082


 12%|█▏        | 12/100 [01:09<06:41,  4.56s/it]

Epoch: 11 | disc_loss: 0.0495523065328598 | gen_loss: -0.12251190096139908


 13%|█▎        | 13/100 [01:12<06:10,  4.26s/it]

Epoch: 12 | disc_loss: 0.3640748858451843 | gen_loss: 0.016781862825155258


 14%|█▍        | 14/100 [01:19<06:57,  4.86s/it]

Epoch: 13 | disc_loss: 0.3377641439437866 | gen_loss: -0.0050474717281758785


 15%|█▌        | 15/100 [01:24<07:00,  4.95s/it]

Epoch: 14 | disc_loss: 0.02836805209517479 | gen_loss: -0.00970156118273735


 16%|█▌        | 16/100 [01:27<06:20,  4.53s/it]

Epoch: 15 | disc_loss: 0.12131357192993164 | gen_loss: -0.030027253553271294


 17%|█▋        | 17/100 [01:31<05:53,  4.25s/it]

Epoch: 16 | disc_loss: 0.39386725425720215 | gen_loss: 0.00486451480537653


 18%|█▊        | 18/100 [01:36<06:21,  4.65s/it]

Epoch: 17 | disc_loss: 0.09345854818820953 | gen_loss: 0.0037411837838590145


 19%|█▉        | 19/100 [01:42<06:45,  5.01s/it]

Epoch: 18 | disc_loss: 0.2546987533569336 | gen_loss: 0.03698151931166649


 20%|██        | 20/100 [01:46<06:05,  4.57s/it]

Epoch: 19 | disc_loss: 1.7540315389633179 | gen_loss: -0.01199072040617466


 21%|██        | 21/100 [01:51<06:18,  4.79s/it]

Epoch: 20 | disc_loss: 0.035447411239147186 | gen_loss: 0.011835521087050438


 22%|██▏       | 22/100 [02:00<07:40,  5.90s/it]

Epoch: 21 | disc_loss: 0.8302493691444397 | gen_loss: 0.0089015057310462


 23%|██▎       | 23/100 [02:06<07:37,  5.95s/it]

Epoch: 22 | disc_loss: 0.1265329271554947 | gen_loss: 0.006148923188447952


 24%|██▍       | 24/100 [02:09<06:37,  5.23s/it]

Epoch: 23 | disc_loss: 0.2705945670604706 | gen_loss: 0.04426192119717598


 25%|██▌       | 25/100 [02:13<05:55,  4.74s/it]

Epoch: 24 | disc_loss: 0.10411501675844193 | gen_loss: 0.010893119499087334


 26%|██▌       | 26/100 [02:18<05:50,  4.74s/it]

Epoch: 25 | disc_loss: -0.0001646876335144043 | gen_loss: 0.01709415763616562


 27%|██▋       | 27/100 [02:24<06:25,  5.29s/it]

Epoch: 26 | disc_loss: 0.15110567212104797 | gen_loss: 0.05479712039232254


 28%|██▊       | 28/100 [02:28<05:44,  4.79s/it]

Epoch: 27 | disc_loss: 1.023840069770813 | gen_loss: 0.02327628806233406


 29%|██▉       | 29/100 [02:31<05:14,  4.42s/it]

Epoch: 28 | disc_loss: 1.357872724533081 | gen_loss: 0.006842650938779116


 30%|███       | 30/100 [02:35<05:02,  4.33s/it]

Epoch: 29 | disc_loss: 0.07182635366916656 | gen_loss: -0.007747280411422253


 31%|███       | 31/100 [02:42<05:45,  5.01s/it]

Epoch: 30 | disc_loss: 0.7223649621009827 | gen_loss: 0.024072939530014992


 32%|███▏      | 32/100 [02:46<05:25,  4.79s/it]

Epoch: 31 | disc_loss: 0.0013353079557418823 | gen_loss: 0.020099354907870293


 33%|███▎      | 33/100 [02:50<04:55,  4.42s/it]

Epoch: 32 | disc_loss: 0.043597858399152756 | gen_loss: -0.026953300461173058


 34%|███▍      | 34/100 [02:53<04:34,  4.16s/it]

Epoch: 33 | disc_loss: 0.14575040340423584 | gen_loss: 0.045586276799440384


 35%|███▌      | 35/100 [03:00<05:15,  4.85s/it]

Epoch: 34 | disc_loss: 0.9078197479248047 | gen_loss: 0.045705895870923996


 36%|███▌      | 36/100 [03:06<05:34,  5.22s/it]

Epoch: 35 | disc_loss: 0.3312055468559265 | gen_loss: -0.032862596213817596


 37%|███▋      | 37/100 [03:10<05:05,  4.85s/it]

Epoch: 36 | disc_loss: 0.12296368926763535 | gen_loss: -0.05694831162691116


 38%|███▊      | 38/100 [03:14<04:37,  4.48s/it]

Epoch: 37 | disc_loss: 0.0018267761915922165 | gen_loss: 0.04898812249302864


 39%|███▉      | 39/100 [03:20<05:12,  5.11s/it]

Epoch: 38 | disc_loss: -0.033626988530159 | gen_loss: 0.059933651238679886


 40%|████      | 40/100 [03:25<04:59,  4.98s/it]

Epoch: 39 | disc_loss: 0.00864245556294918 | gen_loss: -0.07693919539451599


 41%|████      | 41/100 [03:28<04:28,  4.56s/it]

Epoch: 40 | disc_loss: 0.16897907853126526 | gen_loss: 0.09236659109592438


 42%|████▏     | 42/100 [03:32<04:07,  4.26s/it]

Epoch: 41 | disc_loss: -0.04144293814897537 | gen_loss: 0.12478187680244446


 43%|████▎     | 43/100 [03:38<04:36,  4.85s/it]

Epoch: 42 | disc_loss: -0.0003387089818716049 | gen_loss: 0.05130815878510475


 44%|████▍     | 44/100 [03:44<04:40,  5.01s/it]

Epoch: 43 | disc_loss: 0.05962885916233063 | gen_loss: 0.08124895393848419


 45%|████▌     | 45/100 [03:50<05:01,  5.48s/it]

Epoch: 44 | disc_loss: -0.010607998818159103 | gen_loss: 0.07069558650255203


 46%|████▌     | 46/100 [03:57<05:22,  5.97s/it]

Epoch: 45 | disc_loss: 0.10526838153600693 | gen_loss: 0.09775341302156448


 47%|████▋     | 47/100 [04:03<05:14,  5.93s/it]

Epoch: 46 | disc_loss: 0.014945922419428825 | gen_loss: 0.031948503106832504


 48%|████▊     | 48/100 [04:07<04:31,  5.23s/it]

Epoch: 47 | disc_loss: 0.7949094772338867 | gen_loss: 0.1268710196018219


 49%|████▉     | 49/100 [04:10<04:01,  4.73s/it]

Epoch: 48 | disc_loss: 0.07438699901103973 | gen_loss: 0.07521533966064453


 50%|█████     | 50/100 [04:15<03:59,  4.79s/it]

Epoch: 49 | disc_loss: 0.1688183695077896 | gen_loss: 0.0524221733212471


 51%|█████     | 51/100 [04:22<04:18,  5.28s/it]

Epoch: 50 | disc_loss: 0.07170755416154861 | gen_loss: 0.10440120100975037


 52%|█████▏    | 52/100 [04:25<03:49,  4.77s/it]

Epoch: 51 | disc_loss: 0.08586933463811874 | gen_loss: 0.08297252655029297


 53%|█████▎    | 53/100 [04:29<03:26,  4.40s/it]

Epoch: 52 | disc_loss: -0.01404098141938448 | gen_loss: 0.06771798431873322


 54%|█████▍    | 54/100 [04:34<03:28,  4.54s/it]

Epoch: 53 | disc_loss: -0.021054089069366455 | gen_loss: 0.06616050750017166


 55%|█████▌    | 55/100 [04:40<03:51,  5.15s/it]

Epoch: 54 | disc_loss: 0.29449260234832764 | gen_loss: 0.023491661995649338


 56%|█████▌    | 56/100 [04:44<03:28,  4.73s/it]

Epoch: 55 | disc_loss: 0.09139303863048553 | gen_loss: 0.03711170330643654


 57%|█████▋    | 57/100 [04:48<03:08,  4.39s/it]

Epoch: 56 | disc_loss: 0.05232944339513779 | gen_loss: 0.06142429634928703


 58%|█████▊    | 58/100 [04:52<02:59,  4.29s/it]

Epoch: 57 | disc_loss: -0.010441907681524754 | gen_loss: 0.018718264997005463


 59%|█████▉    | 59/100 [04:58<03:24,  4.98s/it]

Epoch: 58 | disc_loss: 0.08355944603681564 | gen_loss: 0.01910744234919548


 60%|██████    | 60/100 [05:02<03:10,  4.75s/it]

Epoch: 59 | disc_loss: 0.6631407141685486 | gen_loss: 0.023774871602654457


 61%|██████    | 61/100 [05:06<02:51,  4.41s/it]

Epoch: 60 | disc_loss: 1.1417481899261475 | gen_loss: 0.02805725857615471


 62%|██████▏   | 62/100 [05:10<02:38,  4.16s/it]

Epoch: 61 | disc_loss: -0.01678720861673355 | gen_loss: 0.05322450399398804


 63%|██████▎   | 63/100 [05:16<03:00,  4.87s/it]

Epoch: 62 | disc_loss: 0.30283495783805847 | gen_loss: 0.027180438861250877


 64%|██████▍   | 64/100 [05:21<02:54,  4.85s/it]

Epoch: 63 | disc_loss: 0.11060860753059387 | gen_loss: 0.023431651294231415


 65%|██████▌   | 65/100 [05:24<02:36,  4.47s/it]

Epoch: 64 | disc_loss: 0.08240535855293274 | gen_loss: 0.006494541186839342


 66%|██████▌   | 66/100 [05:28<02:22,  4.19s/it]

Epoch: 65 | disc_loss: 0.03296396881341934 | gen_loss: -0.10308728367090225


 67%|██████▋   | 67/100 [05:34<02:35,  4.72s/it]

Epoch: 66 | disc_loss: 0.21708287298679352 | gen_loss: -0.13572800159454346


 68%|██████▊   | 68/100 [05:40<02:38,  4.95s/it]

Epoch: 67 | disc_loss: 0.016651568934321404 | gen_loss: 0.04942023754119873


 69%|██████▉   | 69/100 [05:43<02:20,  4.54s/it]

Epoch: 68 | disc_loss: -0.03906470537185669 | gen_loss: 0.053722500801086426


 70%|███████   | 70/100 [05:47<02:08,  4.27s/it]

Epoch: 69 | disc_loss: -0.0019920766353607178 | gen_loss: 0.07263438403606415


 71%|███████   | 71/100 [05:52<02:12,  4.57s/it]

Epoch: 70 | disc_loss: 0.12177841365337372 | gen_loss: 0.051899366080760956


 72%|███████▏  | 72/100 [05:58<02:19,  5.00s/it]

Epoch: 71 | disc_loss: 1.3439832925796509 | gen_loss: 0.05882921814918518


 73%|███████▎  | 73/100 [06:02<02:03,  4.57s/it]

Epoch: 72 | disc_loss: 0.016121579334139824 | gen_loss: 0.017661312595009804


 74%|███████▍  | 74/100 [06:08<02:12,  5.08s/it]

Epoch: 73 | disc_loss: 1.797581672668457 | gen_loss: 0.030834412202239037


 75%|███████▌  | 75/100 [06:15<02:21,  5.67s/it]

Epoch: 74 | disc_loss: -0.01204914040863514 | gen_loss: 0.010493924841284752


 76%|███████▌  | 76/100 [06:21<02:17,  5.74s/it]

Epoch: 75 | disc_loss: 0.26364827156066895 | gen_loss: 0.058126404881477356


 77%|███████▋  | 77/100 [06:24<01:57,  5.10s/it]

Epoch: 76 | disc_loss: 0.04875367507338524 | gen_loss: 0.08208627998828888


 78%|███████▊  | 78/100 [06:28<01:42,  4.64s/it]

Epoch: 77 | disc_loss: -0.001018250361084938 | gen_loss: 0.0772184506058693


 79%|███████▉  | 79/100 [06:33<01:39,  4.73s/it]

Epoch: 78 | disc_loss: -0.019305303692817688 | gen_loss: 0.08437075465917587


 80%|████████  | 80/100 [06:39<01:44,  5.24s/it]

Epoch: 79 | disc_loss: -0.00283743254840374 | gen_loss: 0.06827837228775024


 81%|████████  | 81/100 [06:43<01:30,  4.74s/it]

Epoch: 80 | disc_loss: 0.030863050371408463 | gen_loss: 0.07036975026130676


 82%|████████▏ | 82/100 [06:47<01:19,  4.41s/it]

Epoch: 81 | disc_loss: 0.01946014165878296 | gen_loss: 0.034700870513916016


 83%|████████▎ | 83/100 [06:51<01:14,  4.40s/it]

Epoch: 82 | disc_loss: 0.19154199957847595 | gen_loss: 0.023294448852539062


 84%|████████▍ | 84/100 [06:58<01:21,  5.10s/it]

Epoch: 83 | disc_loss: 0.017039582133293152 | gen_loss: 0.015029357746243477


 85%|████████▌ | 85/100 [07:02<01:11,  4.76s/it]

Epoch: 84 | disc_loss: 0.48255428671836853 | gen_loss: 0.03764244168996811


 86%|████████▌ | 86/100 [07:05<01:01,  4.41s/it]

Epoch: 85 | disc_loss: 0.06391225755214691 | gen_loss: 0.04219461604952812


 87%|████████▋ | 87/100 [07:09<00:55,  4.24s/it]

Epoch: 86 | disc_loss: 0.19353923201560974 | gen_loss: -0.027790004387497902


 88%|████████▊ | 88/100 [07:16<00:59,  4.95s/it]

Epoch: 87 | disc_loss: -0.004985250998288393 | gen_loss: 0.02480296976864338


 89%|████████▉ | 89/100 [07:20<00:53,  4.83s/it]

Epoch: 88 | disc_loss: 0.10746642202138901 | gen_loss: 0.041117288172245026


 90%|█████████ | 90/100 [07:24<00:44,  4.46s/it]

Epoch: 89 | disc_loss: 0.0556815043091774 | gen_loss: 0.03648800402879715


 91%|█████████ | 91/100 [07:27<00:37,  4.19s/it]

Epoch: 90 | disc_loss: 0.12254973500967026 | gen_loss: 0.030030306428670883


 92%|█████████▏| 92/100 [07:34<00:38,  4.84s/it]

Epoch: 91 | disc_loss: 0.11880950629711151 | gen_loss: 0.03800388053059578


 93%|█████████▎| 93/100 [07:39<00:34,  4.90s/it]

Epoch: 92 | disc_loss: 0.30507418513298035 | gen_loss: 0.05053644999861717


 94%|█████████▍| 94/100 [07:42<00:27,  4.51s/it]

Epoch: 93 | disc_loss: 0.03834468871355057 | gen_loss: 0.03819749504327774


 95%|█████████▌| 95/100 [07:46<00:21,  4.24s/it]

Epoch: 94 | disc_loss: 0.002075091004371643 | gen_loss: 0.04406016319990158


 96%|█████████▌| 96/100 [07:52<00:18,  4.70s/it]

Epoch: 95 | disc_loss: 0.03771727532148361 | gen_loss: -0.0015200810739770532


 97%|█████████▋| 97/100 [07:57<00:14,  4.94s/it]

Epoch: 96 | disc_loss: -0.008465536870062351 | gen_loss: -0.019655562937259674


 98%|█████████▊| 98/100 [08:02<00:09,  4.94s/it]

Epoch: 97 | disc_loss: 0.38708868622779846 | gen_loss: -0.13478681445121765


 99%|█████████▉| 99/100 [08:09<00:05,  5.51s/it]

Epoch: 98 | disc_loss: 0.060236405581235886 | gen_loss: -0.08612337708473206


100%|██████████| 100/100 [08:16<00:00,  4.96s/it]

Epoch: 99 | disc_loss: 0.0180684644728899 | gen_loss: 0.0048463596031069756





In [24]:
#########################################################
#    Loading and sampling from a trained synthesizer    #
#########################################################
synth = RegularSynthesizer.load('adult_wgangp_model.pkl')
synth_data = synth.sample(38000)

Synthetic data generation: 100%|██████████| 77/77 [00:00<00:00, 94.39it/s]


In [25]:
synth_data[synth_data['churn']==1]

Unnamed: 0,user_lifetime,user_intake,user_no_outgoing_activity_in_days,user_account_balance_last,user_spendings,user_has_outgoing_calls,user_has_outgoing_sms,user_use_gprs,user_does_reload,reloads_inactive_days,...,last_100_calls_outgoing_duration,last_100_calls_outgoing_to_onnet_duration,last_100_calls_outgoing_to_offnet_duration,last_100_calls_outgoing_to_abroad_duration,last_100_sms_outgoing_count,last_100_sms_outgoing_to_onnet_count,last_100_sms_outgoing_to_offnet_count,last_100_sms_outgoing_to_abroad_count,last_100_gprs_usage,churn
2,6518,0,182,8.709260,-22.963715,0,0,0,0,350,...,141.651535,-18.718710,87.907143,-15.841368,-33,-42,-258,7,-34.904579,1
6,2970,0,145,12.333520,-1.041637,0,0,0,0,444,...,81.338493,-38.994225,129.246521,-5.862953,-108,-17,-360,32,-65.884804,1
9,7270,0,219,7.319665,-12.019320,0,0,0,0,377,...,144.639084,-25.582415,149.867935,-24.654305,23,-35,-328,46,-70.366508,1
10,3821,0,233,14.314448,4.186330,0,0,0,0,539,...,74.776917,-41.422882,156.227737,-17.509182,-288,8,-458,50,-79.814507,1
14,5933,0,194,8.818342,-10.091899,0,0,0,0,410,...,67.116982,-17.934464,196.635132,-37.047180,9,-35,-250,39,-56.644238,1
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
38491,6222,0,233,13.073122,-10.093932,0,0,0,0,384,...,100.903358,-23.744555,125.060020,-20.725622,-168,-17,-531,18,-29.317562,1
38492,1071,0,139,13.415490,5.788873,0,0,0,0,479,...,-84.989594,-43.058472,79.640549,-6.802803,-437,15,-349,39,-79.180756,1
38493,2114,0,120,11.050632,2.616423,0,0,0,0,412,...,25.912218,-30.549259,114.815948,-11.859159,-322,-10,-126,33,-77.665016,1
38494,8350,0,249,2.834426,-21.275934,0,0,0,0,276,...,247.317657,-5.156565,79.180382,-30.424553,67,-30,-397,-4,-74.148140,1


In [26]:
#concat original data and gan data
data_concat = pd.concat([df, synth_data])
# combine data churn and not churn
data=pd.concat([df, data_concat])

In [27]:
data

Unnamed: 0,user_lifetime,user_intake,user_no_outgoing_activity_in_days,user_account_balance_last,user_spendings,user_has_outgoing_calls,user_has_outgoing_sms,user_use_gprs,user_does_reload,reloads_inactive_days,...,last_100_calls_outgoing_duration,last_100_calls_outgoing_to_onnet_duration,last_100_calls_outgoing_to_offnet_duration,last_100_calls_outgoing_to_abroad_duration,last_100_sms_outgoing_count,last_100_sms_outgoing_to_onnet_count,last_100_sms_outgoing_to_offnet_count,last_100_sms_outgoing_to_abroad_count,last_100_gprs_usage,churn
0,1000,0,1,0.050000,0.000000,1,1,0,0,66,...,75.270000,0.000000,63.430000,0.000000,210,1,84,0,0.000000,0
1,1000,0,25,28.310000,3.450000,1,0,0,0,1276,...,13.380000,11.180000,2.000000,11.180000,0,0,0,0,0.000000,0
2,1005,0,8,15.620000,1.970000,1,0,0,0,1276,...,30.000000,0.000000,0.000000,10.450000,0,0,0,0,0.000000,0
3,1013,0,11,5.620000,0.000000,1,0,0,0,1276,...,0.000000,0.000000,0.000000,0.000000,0,0,0,0,0.000000,1
4,1032,0,2,5.860000,0.150000,1,0,0,1,17,...,2.580000,0.000000,1.000000,0.000000,0,0,0,0,0.000000,0
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
38495,4039,0,194,14.831997,-0.631287,0,0,0,0,503,...,131.384308,-38.928391,60.383308,-6.202411,-163,4,-439,22,-77.923729,0
38496,3614,0,123,9.457420,1.139228,0,0,0,0,418,...,-4.257732,-15.151215,103.011719,-34.145947,-258,-32,-449,57,-115.211189,0
38497,5080,0,216,14.709961,-2.131133,0,0,0,0,543,...,130.212753,-30.515043,253.926071,-30.014433,-381,7,-168,63,-54.951759,1
38498,1533,0,95,15.540998,9.002070,0,0,0,0,478,...,-74.228493,-45.623802,132.119141,-1.750406,-498,2,-357,49,-96.746460,0


In [28]:
# -*- coding: utf-8 -*-
"""
Created on Mon Jun  1 17:31:40 2020

@author: manav

Modifed on 23 AUG 2022

by mahayasa adiputra
"""

import numpy as np
import pandas as pd
from sklearn.metrics import precision_recall_curve
from sklearn.metrics import f1_score
from sklearn.metrics import auc
import matplotlib.pyplot as plt
from xgboost import XGBClassifier
from sklearn.model_selection import GridSearchCV
from sklearn.preprocessing import LabelEncoder
from sklearn.preprocessing import StandardScaler
from sklearn.model_selection import train_test_split
from sklearn.metrics import confusion_matrix , accuracy_score
from sklearn.metrics import roc_curve, roc_auc_score
import sklearn.metrics as mt
from imblearn.under_sampling import EditedNearestNeighbours
from imblearn.under_sampling import TomekLinks
from imblearn.under_sampling import NeighbourhoodCleaningRule
from sklearn.model_selection import KFold
from sklearn.model_selection import cross_val_score
from sklearn.preprocessing import LabelEncoder
from sklearn.neighbors import KNeighborsClassifier
from numpy import mean
from numpy import std
from sklearn.metrics import make_scorer
from imblearn.metrics import specificity_score
import math
from sklearn.tree import DecisionTreeClassifier
from sklearn.naive_bayes import GaussianNB
from sklearn.ensemble import GradientBoostingClassifier
import time

start1=time.time()

X=data.drop(['churn'],axis=1)
y=data["churn"]

enn = EditedNearestNeighbours(n_neighbors=3)
X, y = enn.fit_resample(X, y)
#ncr = NeighbourhoodCleaningRule(n_neighbors=5, kind_sel='all')
#X, y = ncr.fit_resample(X, y)
#tomek_links = TomekLinks()
#X, y = tomek_links.fit_resample(X, y)

#kfold cross validation
cv = KFold(n_splits=5, random_state=1, shuffle=True)
end1 = time.time()
print("The time of execution of preprocess:",
      (end1-start1), "s")

The time of execution of preprocess: 232.7059919834137 s


In [29]:
startknn=time.time()
classifier=KNeighborsClassifier()
score=cross_val_score(classifier, X, y, scoring='f1', cv=cv, n_jobs=-1)
print('===============KNN Performance====================')

print('F1 score: %.3f' % (mean(score)))
print('STD F1 Score: %.3f' % (std(score)))

rc=cross_val_score(classifier, X, y, scoring='recall', cv=cv, n_jobs=-1)
print('Recall: %.3f' % (mean(rc)))


# Define the specificity scorer
scorer = make_scorer(specificity_score)

# Calculate the cross-validated specificity score
sp = cross_val_score(classifier, X, y, cv=5, scoring=scorer)
print('Specitifity: %.3f' % (mean(sp)))

auc=cross_val_score(classifier, X, y, scoring='roc_auc', cv=cv, n_jobs=-1)
print('AUC ROC: %.3f' % (mean(auc)))

gmean=rc*sp
gmean=mean(gmean)
sqrtg = math.sqrt(gmean)
print('G-Mean: %.3f' % (sqrtg))
print('======================================================')

endknn = time.time()
print("The time of execution of knn:",
      (endknn-startknn), "s")

F1 score: 0.891
STD F1 Score: 0.004
Recall: 0.834
Specitifity: 0.827
AUC ROC: 0.972
G-Mean: 0.830
The time of execution of knn: 798.3735768795013 s


In [30]:
start3=time.time()

classifier = GradientBoostingClassifier()

score=cross_val_score(classifier, X, y, scoring='f1', cv=cv, n_jobs=-1)
print('===============GBM Performance====================')

print('F1 score: %.3f' % (mean(score)))
print('STD F1 Score: %.3f' % (std(score)))

rc=cross_val_score(classifier, X, y, scoring='recall', cv=cv, n_jobs=-1)
print('Recall: %.3f' % (mean(rc)))


# Define the specificity scorer
scorer = make_scorer(specificity_score)

# Calculate the cross-validated specificity score
sp = cross_val_score(classifier, X, y, cv=5, scoring=scorer)
print('Specitifity: %.3f' % (mean(sp)))

auc=cross_val_score(classifier, X, y, scoring='roc_auc', cv=cv, n_jobs=-1)
print('AUC ROC: %.3f' % (mean(auc)))


gmean=rc*sp
gmean=mean(gmean)
sqrtg = math.sqrt(gmean)
print('G-Mean: %.3f' % (sqrtg))
print('======================================================')

end3 = time.time()
print("The time of execution of gbm:",
      (end3-start3), "s")





F1 score: 0.868
STD F1 Score: 0.003
Recall: 0.827
Specitifity: 0.809
AUC ROC: 0.962
G-Mean: 0.817
The time of execution of gbm: 2404.5905771255493 s


In [31]:
startdt=time.time()
classifier=DecisionTreeClassifier()
score=cross_val_score(classifier, X, y, scoring='f1', cv=cv, n_jobs=-1)
print('===============DT Performance====================')

print('F1 score: %.3f' % (mean(score)))
print('STD F1 Score: %.3f' % (std(score)))

rc=cross_val_score(classifier, X, y, scoring='recall', cv=cv, n_jobs=-1)
print('Recall: %.3f' % (mean(rc)))


# Define the specificity scorer
scorer = make_scorer(specificity_score)

# Calculate the cross-validated specificity score
sp = cross_val_score(classifier, X, y, cv=5, scoring=scorer)
print('Specitifity: %.3f' % (mean(sp)))

auc=cross_val_score(classifier, X, y, scoring='roc_auc', cv=cv, n_jobs=-1)
print('AUC ROC: %.3f' % (mean(auc)))

gmean=rc*sp
gmean=mean(gmean)
sqrtg = math.sqrt(gmean)
print('G-Mean: %.3f' % (sqrtg))
print('======================================================')

enddt = time.time()
print("The time of execution of dt:",
      (enddt-startdt), "s")

F1 score: 0.933
STD F1 Score: 0.003
Recall: 0.933
Specitifity: 0.843
AUC ROC: 0.955
G-Mean: 0.887
The time of execution of dt: 175.36278295516968 s


In [32]:
start5 = time.time()
classifier=GaussianNB()
score=cross_val_score(classifier, X, y, scoring='f1', cv=cv, n_jobs=-1)
print('===============NB Performance====================')

print('F1 score: %.3f' % (mean(score)))
print('STD F1 Score: %.3f' % (std(score)))

rc=cross_val_score(classifier, X, y, scoring='recall', cv=cv, n_jobs=-1)
print('Recall: %.3f' % (mean(rc)))


# Define the specificity scorer
scorer = make_scorer(specificity_score)

# Calculate the cross-validated specificity score
sp = cross_val_score(classifier, X, y, cv=5, scoring=scorer)
print('Specitifity: %.3f' % (mean(sp)))

auc=cross_val_score(classifier, X, y, scoring='roc_auc', cv=cv, n_jobs=-1)
print('AUC ROC: %.3f' % (mean(auc)))

gmean=rc*sp
gmean=mean(gmean)
sqrtg = math.sqrt(gmean)
print('G-Mean: %.3f' % (sqrtg))
print('======================================================')

end5 = time.time()
print("The time of execution of NB:",
      (end5-start5), "s")

F1 score: 0.760
STD F1 Score: 0.004
Recall: 0.886
Specitifity: 0.728
AUC ROC: 0.922
G-Mean: 0.802
The time of execution of NB: 6.657418966293335 s


In [None]:


start2=time.time()
#Training XGBoost
#classifier = XGBClassifier(eta=0.3, max_depth = 4, gamma=0, min_child_weight=1)
classifier = XGBClassifier()

score=cross_val_score(classifier, X, y, scoring='f1', cv=cv, n_jobs=-1)
print('===============XGBoost Performance====================')

print('F1 score: %.3f' % (mean(score)))
print('STD F1 Score: %.3f' % (std(score)))

rc=cross_val_score(classifier, X, y, scoring='recall', cv=cv, n_jobs=-1)
print('Recall: %.3f' % (mean(rc)))



# Define the specificity scorer
scorer = make_scorer(specificity_score)

# Calculate the cross-validated specificity score
sp = cross_val_score(classifier, X, y, cv=5, scoring=scorer)
print('Specitifity: %.3f' % (mean(sp)))

auc=cross_val_score(classifier, X, y, scoring='roc_auc', cv=cv, n_jobs=-1)
print('AUC ROC: %.3f' % (mean(auc)))


gmean=rc*sp
gmean=mean(gmean)
sqrtg = math.sqrt(gmean)
print('G-Mean: %.3f' % (sqrtg))
print('======================================================')

end2 = time.time()
print("The time of execution of XGBOOST:",
      (end2-start2), "s")

F1 score: 0.887
STD F1 Score: 0.003
Recall: 0.846
Specitifity: 0.815
AUC ROC: 0.977
G-Mean: 0.830
The time of execution of XGBOOST: 1354.87202835083 s


In [None]:
start3=time.time()

from sklearn.ensemble import RandomForestClassifier

classifier = RandomForestClassifier()

score=cross_val_score(classifier, X, y, scoring='f1', cv=cv, n_jobs=-1)
print('===============Random Forest Performance====================')

print('F1 score: %.3f' % (mean(score)))
print('STD F1 Score: %.3f' % (std(score)))

rc=cross_val_score(classifier, X, y, scoring='recall', cv=cv, n_jobs=-1)
print('Recall: %.3f' % (mean(rc)))

from sklearn.metrics import make_scorer
from imblearn.metrics import specificity_score

# Define the specificity scorer
scorer = make_scorer(specificity_score)

# Calculate the cross-validated specificity score
sp = cross_val_score(classifier, X, y, cv=5, scoring=scorer)
print('Specitifity: %.3f' % (mean(sp)))

auc=cross_val_score(classifier, X, y, scoring='roc_auc', cv=cv, n_jobs=-1)
print('AUC ROC: %.3f' % (mean(auc)))

import math

gmean=rc*sp
gmean=mean(gmean)
sqrtg = math.sqrt(gmean)
print('G-Mean: %.3f' % (sqrtg))
print('======================================================')

end3 = time.time()
print("The time of execution of random forest:",
      (end3-start3), "s")

F1 score: 0.960
STD F1 Score: 0.002
Recall: 0.943
Specitifity: 0.841
AUC ROC: 0.994
G-Mean: 0.891
The time of execution of random forest: 1154.6475343704224 s
