# Multiclass LSTM classification model

In this notebook, we replicate the multiclass LSTM model from the Endgame paper below to classify which DGA (Domain Generation Algorithm) category, if any, a domain name belongs to:

"Predicting Domain Generation Algorithms with Long Short-Term Memory Networks"
http://arxiv.org/abs/1611.00791v1


In [43]:
import json
import os
import pickle

import pandas as pd
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score, classification_report

from keras.models import Sequential, model_from_json
from keras.layers.core import Dense, Dropout, Activation
from keras.layers.embeddings import Embedding
from keras.layers.recurrent import LSTM
from keras.preprocessing import sequence
from keras.preprocessing import text
from keras.utils import to_categorical

from tensorflow.python.client import device_lib

In [22]:
# Confirm the GPU is picked up for acceleration: look for a "/device:GPU:0" entry below
available_devices = device_lib.list_local_devices()
print(available_devices)

[name: "/device:CPU:0"
device_type: "CPU"
memory_limit: 268435456
locality {
}
incarnation: 15068813791508331409
, name: "/device:GPU:0"
device_type: "GPU"
memory_limit: 2572488704
locality {
  bus_id: 1
  links {
  }
}
incarnation: 9104953581547433324
physical_device_desc: "device: 0, name: GeForce GTX 780M, pci bus id: 0000:01:00.0, compute capability: 3.0"
]


In [23]:
# Path and file variables for saving model information.
# os.path.join keeps these portable; the trailing '' keeps a trailing separator on path_dir.
path_dir = os.path.join('.', 'saved_models', '')
name_encoder      = os.path.join(path_dir, 'multiclass_tokenizer.pkl')
name_model        = os.path.join(path_dir, 'multiclass_LSTM.json')
name_weights      = os.path.join(path_dir, 'multiclass_LSTM.h5')
name_categories   = os.path.join(path_dir, 'multiclass_categories.pkl')
name_m_report     = os.path.join(path_dir, 'multiclass_metrics_report')   # extension appended when dumping
name_c_report     = os.path.join(path_dir, 'multiclass_class_report')     # extension appended when dumping

# Format and report dump switches
dump_pred_results    = 0x03                                 # bitmask switches to dump prediction results:
                                                            #       0x01: prediction metrics
                                                            #       0x02: false-positive / false-negative domain tables
format_m_report       = 'json'                              # 'json' format only, csv doesn't fit correctly here
format_c_report       = ['json', 'csv']                     # at least 1 of: 'csv' or 'json'

# Category names
name_nonDGA = 'non-DGA'

In [24]:
# Read the high-confidence DGA feed and the Cisco top-1M list (used as the non-DGA class).
# skiprows=15 presumably skips the feed's leading comment header — NOTE(review): re-check if the feed format changes.
data_dir = os.path.join('..', 'data', '2018_0923')
dga_df = pd.read_csv(os.path.join(data_dir, 'dga-feed-high.csv'), header=None, skiprows=15)
cisco_df = pd.read_csv(os.path.join(data_dir, 'top-1m.csv'), header=None)

In [25]:
""" Display head/tail/sample of the DGA and/or nonDGA data frames """
def display_df(dga_df_=None, cisco_df_=None, sample='head', seed=21):
    """Display the shape and head/tail/random sample of the DGA and/or non-DGA frames.

    Parameters
    ----------
    dga_df_, cisco_df_ : pandas.DataFrame or None
        Frames to show; ``None`` skips that frame.
    sample : str
        ``'head'``, ``'tail'``, or ``'sample<N>'`` (e.g. ``'sample5'``) for N random rows.
        A bare ``'sample'`` shows 5 random rows.
    seed : int
        ``random_state`` for the random sample, so the displayed rows are reproducible.

    Raises
    ------
    ValueError
        If ``sample`` is not one of the modes above (checked before anything is displayed).
    """
    select_rows = _row_selector(sample, seed)
    for title, frame in (("DGA feed sample: {}", dga_df_), ("Cisco feed sample: {}", cisco_df_)):
        if frame is not None:
            display(title.format(frame.shape))
            display(select_rows(frame))


def _row_selector(sample, seed):
    """Translate a display_df ``sample`` mode into a function frame -> rows to show."""
    if sample == 'head':
        return lambda frame: frame.head()
    if sample == 'tail':
        return lambda frame: frame.tail()
    if sample.startswith('sample'):
        # Slice the prefix off; str.strip('sample') would strip a character set, not a prefix.
        count_str = sample[len('sample'):]
        count = int(count_str) if count_str.strip() else 5
        return lambda frame: frame.sample(n=count, random_state=seed)
    raise ValueError("unknown sample mode: {!r}".format(sample))
            

In [26]:
display_df(dga_df_=dga_df, cisco_df_=cisco_df, sample='head')   # quick look at both raw feeds

'DGA feed sample: (381953, 4)'

Unnamed: 0,0,1,2,3
0,plvklpgwivery.com,Domain used by Cryptolocker - Flashback DGA fo...,2018-06-23,http://osint.bambenekconsulting.com/manual/cl.txt
1,dnuxdhcgblsgy.net,Domain used by Cryptolocker - Flashback DGA fo...,2018-06-23,http://osint.bambenekconsulting.com/manual/cl.txt
2,qjlullhfkiowp.biz,Domain used by Cryptolocker - Flashback DGA fo...,2018-06-23,http://osint.bambenekconsulting.com/manual/cl.txt
3,elkidddodxdly.ru,Domain used by Cryptolocker - Flashback DGA fo...,2018-06-23,http://osint.bambenekconsulting.com/manual/cl.txt
4,rnbfwuprlwfor.org,Domain used by Cryptolocker - Flashback DGA fo...,2018-06-23,http://osint.bambenekconsulting.com/manual/cl.txt


'Cisco feed sample: (1000000, 2)'

Unnamed: 0,0,1
0,1,netflix.com
1,2,api-global.netflix.com
2,3,prod.netflix.com
3,4,push.prod.netflix.com
4,5,google.com


In [27]:
# Keep only the domain and description columns of the DGA feed, and label every Cisco
# domain as non-DGA, so that both frames share the columns ['domain', 'dga'].
# (Cisco column 0 appears to be the popularity rank and is dropped.)
dga_df_slim = (
    dga_df
    .drop(columns=range(2, dga_df.shape[1]))
    .rename(columns={0: 'domain', 1: 'dga'})
)

cisco_df_slim = (
    cisco_df
    .drop(columns=[0])
    .rename(columns={1: 'domain'})
    .assign(dga=name_nonDGA)
)

display_df(dga_df_slim, None, 'sample5')

'DGA feed sample: (381953, 2)'

Unnamed: 0,domain,dga
57569,a94421b9c998056fb42456ad25ea55bfb9.hk,Domain used by dyre DGA for 26 Jun 2018
171315,wjb92vcerh.net,Domain used by shiotob/urlzone/bebloh DGA - no...
267753,vhnvvwx.net,Domain used by pykspa (varying date seeds)
250789,csyyhnyiwejluy.su,Domain used by ranbyus (uses previous 31 days ...
154897,ypuyuvscckuc.pw,Domain used by tinba DGA for 25 Jun 2018


In [28]:
# Suffixes tried IN THIS ORDER; the first one found ends the category name
# (so ' DGA' wins over an earlier ' -', keeping e.g. 'Cryptolocker - Flashback' intact).
SUFFIXES = [' DGA', ' (', ' -']

def strip_cat(input_str_row, lstrip_str="Domain used by ", rtrunc_str=SUFFIXES, verbose=False):
    """Extract the DGA category name from a feed description string.

    Parameters
    ----------
    input_str_row : mapping (e.g. a DataFrame row)
        Must provide the description under the key ``'dga'``.
    lstrip_str : str
        Prefix removed from the start of the description, if present.
    rtrunc_str : sequence of str
        Suffix markers tried in order; the description is cut at the first marker found.
    verbose : bool
        Print each intermediate step.

    Returns
    -------
    str
        The category name, e.g. ``'pykspa'`` for ``'Domain used by pykspa (varying date seeds)'``.
    """
    description = input_str_row['dga']
    if verbose:
        print('-'*50, '\nInput:    ', description)
    # Remove the prefix only at the start; str.replace would also delete it mid-string.
    if description.startswith(lstrip_str):
        description = description[len(lstrip_str):]
    if verbose:
        print('Lstrip:   ', description)
    for suffix in rtrunc_str:
        idx = description.find(suffix)
        if idx != -1:
            description = description[:idx]
            if verbose:
                print('Trimmed:  ', description)
            break
    return description

In [29]:
# Trim each feed description down to its DGA category name.
# DataFrame.apply forwards extra keyword arguments (here `verbose`) to strip_cat.
verbosity = False

dga_df_slim['dga'] = dga_df_slim.apply(strip_cat, axis=1, verbose=verbosity)

display_df(dga_df_slim, None, 'sample5')


'DGA feed sample: (381953, 2)'

Unnamed: 0,domain,dga
57569,a94421b9c998056fb42456ad25ea55bfb9.hk,dyre
171315,wjb92vcerh.net,shiotob/urlzone/bebloh
267753,vhnvvwx.net,pykspa
250789,csyyhnyiwejluy.su,ranbyus
154897,ypuyuvscckuc.pw,tinba


In [30]:
# Extract unique DGA categories (before sparse ones are merged) and add the non-DGA class.
# num_categories here is provisional; it is recomputed after merging and factorizing.
categories = list(dga_df_slim['dga'].unique())
print("Categories of DGA domains: {}\n".format(len(categories)))
print(categories)
# Use the same label the Cisco frame is tagged with, not a hard-coded spelling variant.
categories.append(name_nonDGA)
print("\nTotal output classes will be: {}\n".format(len(categories)))
num_categories = len(categories)

Categories of DGA domains: 43

['Cryptolocker - Flashback', 'Post Tovar GOZ', 'geodo', 'dyre', 'corebot', 'symmi', 'padcrypt', 'locky', 'tinba', 'pushdo', 'P2P Gameover Zeus', 'shiotob/urlzone/bebloh', 'hesperbot', 'cryptowall', 'ramnit', 'dircrypt', 'ranbyus', 'pykspa', 'murofet', 'Volatile Cedar / Explosive', 'beebone', 'bedep', 'fobber', 'necurs', 'qakbot', 'tempedreve', 'ramdo', 'kraken', 'bamital', 'vawtrak', 'sisron', 'chinad', 'gozi', 'sphinx', 'proslikefan', 'vidro', 'madmax', 'dromedan', 'g01', 'pandabanker', 'mirai', 'unknownjs', 'unknowndropper']

Total output classes will be: 44



In [31]:
# Check skewness in the dataset with respect to DGA categories.
# to_frame(name='dga') pins the count column's name: pandas >= 2.0 would otherwise call it
# 'count', and trim_categories reads the count via row['dga'].
counts = dga_df_slim['dga'].value_counts().to_frame(name='dga')
print("\nMost frequent categories:", end='')
display(counts.head())
print("\nLeast frequent categories:", end='')
display(counts.tail())


Most frequent categories:

Unnamed: 0,dga
tinba,66688
Post Tovar GOZ,66000
ramnit,56174
necurs,43008
murofet,28520



Least frequent categories:

Unnamed: 0,dga
gozi,24
mirai,3
dromedan,2
madmax,1
g01,1


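The counts above are heavily skewed: the dataset is highly imbalanced, and the categories in the lower half of the sorted list have very few domains. This is likely to hurt both training and prediction for those categories, so categories with fewer than `THRESHOLD_COUNT` domains are merged into a single category below.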

In [32]:
THRESHOLD_COUNT = 50                     # categories with fewer domains than this are merged together
MERGED_CAT_STR = 'mergedDGA'             # name of merged category

MERGED_CAT_LIST = []                     # names of categories merged so far (filled by trim_categories)

def trim_categories(input_row, threshold=THRESHOLD_COUNT):
    """Map one row of the category-count frame to its (possibly merged) category name.

    ``input_row.name`` is the category and ``input_row['dga']`` its domain count.
    Categories with fewer than ``threshold`` domains map to MERGED_CAT_STR and are
    recorded (once) in MERGED_CAT_LIST; all others keep their own name.
    """
    category = input_row.name
    if input_row['dga'] >= threshold:
        return category
    # Record each merged category only once so re-running the apply cell doesn't duplicate it.
    if category not in MERGED_CAT_LIST:
        MERGED_CAT_LIST.append(category)
    return MERGED_CAT_STR


In [33]:
# Start from an empty merged list so re-running this cell gives the same result.
MERGED_CAT_LIST.clear()
counts['newCat'] = counts.apply(trim_categories, axis=1)
print("Merged categories: ", MERGED_CAT_LIST)
display(counts)

Merged categories:  ['pandabanker', 'gozi', 'mirai', 'dromedan', 'madmax', 'g01']


Unnamed: 0,dga,newCat
tinba,66688,tinba
Post Tovar GOZ,66000,Post Tovar GOZ
ramnit,56174,ramnit
necurs,43008,necurs
murofet,28520,murofet
ranbyus,26040,ranbyus
qakbot,20000,qakbot
pykspa,14215,pykspa
shiotob/urlzone/bebloh,12521,shiotob/urlzone/bebloh
kraken,8988,kraken


In [34]:
# Update DGA frame with new categories

def update_categories(input_row):
    """Return MERGED_CAT_STR if the row's 'dga' category was merged, else the category unchanged."""
    category = input_row['dga']
    return MERGED_CAT_STR if category in MERGED_CAT_LIST else category

# Vectorised equivalent of applying update_categories row by row (much faster on ~380k rows).
dga_df_slim['dga'] = dga_df_slim['dga'].mask(dga_df_slim['dga'].isin(MERGED_CAT_LIST), MERGED_CAT_STR)


In [35]:
display(pd.unique(dga_df_slim['dga']))   # sparse categories should now appear only as the merged name

array(['Cryptolocker - Flashback', 'Post Tovar GOZ', 'geodo', 'dyre',
       'corebot', 'symmi', 'padcrypt', 'locky', 'tinba', 'pushdo',
       'P2P Gameover Zeus', 'shiotob/urlzone/bebloh', 'hesperbot',
       'cryptowall', 'ramnit', 'dircrypt', 'ranbyus', 'pykspa', 'murofet',
       'Volatile Cedar / Explosive', 'beebone', 'bedep', 'fobber',
       'necurs', 'qakbot', 'tempedreve', 'ramdo', 'kraken', 'bamital',
       'vawtrak', 'sisron', 'chinad', 'mergedDGA', 'sphinx',
       'proslikefan', 'vidro', 'unknownjs', 'unknowndropper'],
      dtype=object)

In [36]:
# Combine DGA/nonDGA dataframes, and factorize categories (mapping to integer indices).
# sort=True makes the label -> index mapping alphabetical, hence reproducible across runs.
unified_df = pd.concat([cisco_df_slim, dga_df_slim], ignore_index=True)
unified_df['catIndex'], labels = pd.factorize(unified_df['dga'], sort=True)
num_categories = len(labels)

print("Final number of categories to be used in model: ", num_categories)
# Persist the index -> label mapping so predictions can be decoded later.
# NOTE(review): this overwrites the saved mapping even when a pre-trained model is
# reloaded (TRAIN_MODEL False); if the input data changed, it may no longer match that model.
with open(name_categories, 'wb') as catEnc:
    pickle.dump(labels, catEnc, protocol=pickle.HIGHEST_PROTOCOL)


Final number of categories to be used in model:  39


In [37]:
# Separate input sequences (domains) and output labels (integer category indices), and do train/test split
# NOTE(review): when TRAIN_MODEL is False a saved model is reloaded; it was presumably trained on this exact
# split, so changing test_size/random_state (or stratifying) would leak its training rows into X_test.

X = unified_df['domain']
Y = unified_df['catIndex']
X_train, X_test, Y_train, Y_test = train_test_split(X, Y, test_size=0.2,random_state=23)
# One-hot encode the training labels for the softmax / categorical_crossentropy output layer
Y_train_binarized = to_categorical(Y_train, num_classes=num_categories)

In [38]:
# Multiclass LSTM model

TRAIN_MODEL = False                                          # Load saved model otherwise
max_features = 1000                                          # Embedding input size (upper bound on character indices)
max_len = 75                                                 # domains padded/truncated to this length; must match the maxlen used for X_test
batch_size = 128                                             # input batch size
num_epochs = 5                                               # epochs to train
num_labels = num_categories                                  # final number of output classes, after potentially merging DGA categories

if not TRAIN_MODEL:
    # Rebuild the saved architecture, then restore its weights, tokenizer and label mapping.
    with open(name_model, 'r') as model_file:
        model = model_from_json(model_file.read())
    model.load_weights(name_weights)
    # NOTE: pickle.load can execute arbitrary code; only load artifacts this notebook produced.
    with open(name_encoder, 'rb') as tokenEnc:
        encoder = pickle.load(tokenEnc)
    with open(name_categories, 'rb') as catEnc:
        labels = pickle.load(catEnc)

    print('MODEL TRAINING SKIPPED.\nSAVED MODEL IS NOW LOADED!')

else:                                                        # train the model
    # Character-level tokenizer: encodes each character of a domain as an integer
    encoder = text.Tokenizer(num_words=500, char_level=True)
    encoder.fit_on_texts(X_train)                            # build character indices
    X_train_tz = encoder.texts_to_sequences(X_train)

    # Model definition - this is the core model from Endgame
    model = Sequential()
    model.add(Embedding(max_features, 128, input_length=max_len))
    model.add(LSTM(128))
    model.add(Dropout(0.5))
    model.add(Dense(num_labels))
    model.add(Activation('softmax'))
    model.compile(loss='categorical_crossentropy', optimizer='rmsprop')

    # Pad/truncate every encoded domain to max_len integers so the batch is rectangular
    X_train_pad = sequence.pad_sequences(X_train_tz, maxlen=max_len)

    # Train against the one-hot encoded category labels
    model.fit(X_train_pad, Y_train_binarized, batch_size=batch_size, epochs=num_epochs)

MODEL TRAINING SKIPPED.
SAVED MODEL IS NOW LOADED!


In [39]:
# Validation on test dataset
# Encode/pad the test domains the same way as the training data (maxlen must match training).
X_test_pad = sequence.pad_sequences(encoder.texts_to_sequences(X_test), maxlen=75)
# One forward pass: class probabilities, then the most likely class per domain.
# (Sequential.predict_classes ran a second pass and is deprecated/removed in newer TensorFlow.)
Y_pred_prob = model.predict(X_test_pad)
Y_pred = Y_pred_prob.argmax(axis=-1)

In [40]:
# Inspect a few prediction probabilities

pred_table = X_test.to_frame()
pred_table.columns = ['domain']
pred_table['trueClass'] = [labels[i] for i in Y_test]
pred_table['predClass'] = [labels[i] for i in Y_pred]
# Y_pred is the argmax of Y_pred_prob, so the predicted class's probability is the row maximum.
pred_table['predProb'] = Y_pred_prob.max(axis=1)

is_correct = pred_table['trueClass'] == pred_table['predClass']
print('\nCorrectly predicted Domains:')
display(pred_table[is_correct].head(10))

print('\nMis-predicted Domains:')
display(pred_table[~is_correct].head(10))

is_true_nonDGA = pred_table['trueClass'] == name_nonDGA
is_pred_nonDGA = pred_table['predClass'] == name_nonDGA

# .copy() so later cells can add columns to these tables without touching pred_table
# (and without pandas' SettingWithCopyWarning).
# NOTE: both rates below are fractions of the WHOLE test set, not of the nonDGA / DGA subsets.
pred_table_FP = pred_table[is_true_nonDGA & ~is_pred_nonDGA].copy()
pred_FP_frac = pred_table_FP.shape[0]/pred_table.shape[0]
print('\nPercentage of False Positives (i.e. nonDGA domains classified as DGA): {:6.4f} %'
      .format(100*pred_FP_frac))

pred_table_FN = pred_table[~is_true_nonDGA & is_pred_nonDGA].copy()
pred_FN_frac = pred_table_FN.shape[0]/pred_table.shape[0]
print('\nPercentage of False Negatives (i.e. DGA domains classified as nonDGA): {:6.4f} %'
      .format(100*pred_FN_frac))


Correctly predicted Domains:


Unnamed: 0,domain,trueClass,predClass,predProb
124546,ns47.domaincontrol.com,non-DGA,non-DGA,1.0
660921,britishlibrary.typepad.co.uk,non-DGA,non-DGA,1.0
446456,a538.casalemedia.com,non-DGA,non-DGA,1.0
600919,ign-ar8de21s8pinm-8d3d0d118-4d8d69dgoogleplayd...,non-DGA,non-DGA,1.0
1186650,gbggekvj.eu,ramnit,ramnit,0.831845
115543,ewr-66.ewr-rtb1.rfihub.com,non-DGA,non-DGA,1.0
1360357,vsagkcaahpxrfbmqljnnxutj.com,qakbot,qakbot,0.5613
464912,static.bladeandsoul.com,non-DGA,non-DGA,1.0
1097547,dlpyniywfxxp.com,tinba,tinba,0.873484
606453,trans11212.addressy.com,non-DGA,non-DGA,1.0



Mis-predicted Domains:


Unnamed: 0,domain,trueClass,predClass,predProb
1359727,uxamuoylbidlktngprh.com,qakbot,ramnit,0.982235
90390,6htb5ck86hk8i9.com,non-DGA,shiotob/urlzone/bebloh,0.590963
610853,wvxlsagkeuye.ir,non-DGA,necurs,0.914297
1259769,gsilnc.net,pykspa,non-DGA,0.870627
1158840,ibcmcycuemsvstbepeybarwpbey.info,P2P Gameover Zeus,murofet,0.676546
1290765,slsykrrahowsxw.net,murofet,ranbyus,0.783558
1342033,yapyerh.su,necurs,non-DGA,0.79247
1030207,vouoqkmqvcrgpb.com,Cryptolocker - Flashback,ramnit,0.547136
1370861,nyxaeltggspy.com,kraken,ramnit,0.896213
1299994,pfqbnmiymhjrjijchb.com,bedep,ramnit,0.983375



Percentage of False Positives (i.e. nonDGA domains classified as DGA): 0.6444 %

Percentage of False Negatives (i.e. DGA domains classified as nonDGA): 0.4577 %


In [41]:
acc = accuracy_score(Y_test, Y_pred)
print("Model's overall accuracy = {:8.3f} %\n".format(acc*100))
# Pass every class index explicitly so each target name lines up with its index,
# even when a rare class is absent from both Y_test and Y_pred.
metrics_report = classification_report(Y_test, Y_pred, labels=list(range(len(labels))), target_names=labels)
print(metrics_report)

Model's overall accuracy =   96.027 %

                            precision    recall  f1-score   support

  Cryptolocker - Flashback       0.50      0.51      0.50      1223
         P2P Gameover Zeus       0.50      0.00      0.01       393
            Post Tovar GOZ       1.00      1.00      1.00     12996
Volatile Cedar / Explosive       0.99      0.97      0.98       204
                   bamital       1.00      1.00      1.00        47
                     bedep       0.00      0.00      0.00        42
                   beebone       0.00      0.00      0.00        36
                    chinad       0.97      0.92      0.95       317
                   corebot       0.97      0.97      0.97        62
                cryptowall       0.00      0.00      0.00        23
                  dircrypt       0.00      0.00      0.00       139
                      dyre       1.00      1.00      1.00      1593
                    fobber       0.40      0.02      0.04        99
        

  'precision', 'predicted', average, warn_for)


In [44]:
# Dump classification metrics and FP/FN domains

# accuracy, precision, recall, f1, false positive, false negative
if dump_pred_results & 0x01:
    metrics_report = classification_report(Y_test, Y_pred, labels=list(range(len(labels))),
                                           target_names=labels, output_dict=True)
    metrics_report['accuracy'] = acc
    metrics_report['false positives'] = pred_FP_frac
    metrics_report['false negatives'] = pred_FN_frac

    if format_m_report == 'json':
        fileName = name_m_report + '.' + format_m_report
        with open(fileName, 'w') as report_file:
            json.dump(metrics_report, fp=report_file)

# False Positives and False Negatives
if dump_pred_results & 0x02:
    # Tag copies rather than the shared tables, so re-running this cell does not fail
    # with "cannot insert type, already exists".
    fp_report = pred_table_FP.copy()
    fp_report.insert(0, 'type', 'FP')
    fn_report = pred_table_FN.copy()
    fn_report.insert(0, 'type', 'FN')

    for extn in format_c_report:
        filePath = name_c_report + '.' + extn
        if extn == 'csv':
            fp_report.to_csv(filePath, mode='w', index=False, header=True)
            fn_report.to_csv(filePath, mode='a', index=False, header=False)
        elif extn == 'json':
            # Write FPs and FNs together (DataFrame.append returned a new frame that was discarded,
            # so the JSON previously contained only FPs).
            pd.concat([fp_report, fn_report]).to_json(filePath, orient='table', index=False)



  'precision', 'predicted', average, warn_for)


In [21]:
# Save model architecture, weights and tokenizer (only when the model was trained in this session)
if TRAIN_MODEL:
    with open(name_model, 'w') as model_file:
        model_file.write(model.to_json())
    model.save_weights(name_weights)
    with open(name_encoder, 'wb') as tokenEnc:
        pickle.dump(encoder, tokenEnc, protocol=pickle.HIGHEST_PROTOCOL)
    print('MODEL SAVED TO DISK!')
else:
    print('MODEL ALREADY SAVED TO DISK.')

MODEL AREADY SAVED TO DISK.


## Look ahead and next steps:
__1__ Look more closely at the misclassified domains. Does any particular DGA category stand out? What do we need to improve? 

__2__ Improve classification accuracy with a more balanced dataset, especially for the multiclass classification.

__3__ Training from scratch takes significant time. We need to implement incremental model updates on batches of newly collected domains.

__4__ We may get a dynamic dataset with, for example, more than 60 categories. The code then needs to trim the categories down to an upper limit, say 50 at most. Alternatively, it could drop categories with inadequate data, say fewer than 1000 available domain names.

__5__ Reduce inference time by calling predict() once and taking argmax() of the probabilities, instead of calling both predict() and predict_classes().