In [None]:
import jsonlines
import pandas as pd
import nltk
import numpy as np
from nltk.stem.snowball import SnowballStemmer
from nltk.probability import FreqDist
import matplotlib.pyplot as plt
from sklearn import svm
from sklearn.metrics import classification_report
from sklearn.ensemble import RandomForestClassifier,AdaBoostClassifier, VotingClassifier
from sklearn.metrics import confusion_matrix
import json
from nltk.sentiment.vader import SentimentIntensityAnalyzer
from sklearn.linear_model import LogisticRegression
import math
from sklearn.metrics import accuracy_score
from sklearn import tree
import sklearn
from nltk.translate.ribes_score import position_of_ngram
from sklearn.model_selection import StratifiedKFold
from sklearn.feature_selection import RFECV
from sklearn.neural_network import MLPClassifier
from sklearn.metrics import roc_auc_score
from sklearn import metrics
import pickle
from sklearn.naive_bayes import GaussianNB
from sklearn.externals import joblib
from sklearn.model_selection import StratifiedShuffleSplit
from nltk.data import load
from sklearn.model_selection import learning_curve
from sklearn.model_selection import ShuffleSplit
import matplotlib.pyplot as plt
%matplotlib inline  
from readability_score.calculators.fleschkincaid import *
# from readability_score.calculators.dalechall import *
# NLTK's 'upenn_tagset' help table; its keys (the tag names) are used further
# down as the list of POS tags for the POS-unigram features.
tagdict = load('help/tagsets/upenn_tagset.pickle')

In [None]:
import warnings
# Silences every warning for the rest of the session, deprecation notices included.
warnings.filterwarnings('ignore')

In [None]:
def readData(instance_path, truth_path):
    """Load a JSON-lines instance file and its truth file and join them on 'id'.

    Parameters
    ----------
    instance_path : str
        JSON-lines file of post instances. Each record's 'postText' is
        expected to be a list; only its first element is kept.
    truth_path : str
        JSON-lines file of labels that share an 'id' field with the instances.

    Returns
    -------
    tuple
        ``(data, instance, label)`` where ``instance`` and ``label`` are the
        DataFrames built from the two files and ``data`` is their merge on 'id'.
    """
    instances = []
    with jsonlines.open(instance_path) as reader:
        for obj in reader:
            postText = obj['postText']
            # An empty list becomes '' instead of raising IndexError.
            obj['postText'] = postText[0] if postText else ''
            instances.append(obj)

    with jsonlines.open(truth_path) as reader:
        truths = list(reader)

    instance = pd.DataFrame.from_dict(instances)
    label = pd.DataFrame.from_dict(truths)
    data = pd.merge(instance, label, on='id')

    return data, instance, label

In [None]:
def analyze(data):
    """Label-encode, POS-tag and tokenise the posts in ``data``.

    Parameters
    ----------
    data : pandas.DataFrame
        Must provide the columns 'id', 'postText' and 'truthClass'
        ('clickbait' / 'no-clickbait').

    Returns
    -------
    tuple
        ``(ids, allPostText, labels, POSs, posTokens, negTokens, tokensDist,
        sentTokens, tokens)``:
        the 'id' column, the raw texts, int64 labels (1 = clickbait),
        one (token, tag) list per post, the distinct tokens of clickbait and
        non-clickbait posts, the overall token frequency distribution, one
        token list per post, and the distinct tokens overall.
    """
    def _objectArray(items):
        # One Python object per slot. np.array() on nested lists raises on
        # recent NumPy when rows are ragged and silently becomes
        # multi-dimensional when every row happens to have the same length.
        out = np.empty(len(items), dtype=object)
        for i, item in enumerate(items):
            out[i] = item
        return out

    # .values works on every pandas version (get_values was removed in 1.0).
    allPostText = data['postText'].values

    labels = np.array(data['truthClass'].values, dtype=object)
    labels[labels == 'no-clickbait'] = 0
    labels[labels == 'clickbait'] = 1
    labels = np.array(labels, dtype='int64')

    POSs = _objectArray([generatePosSequence(text) for text in allPostText])

    texts = np.array(list(map(cleaning, allPostText)))

    # Tokenise once and slice by class instead of tokenising three times.
    sentTokens = _objectArray([generateTokens(text) for text in texts])
    negSentTokens = sentTokens[labels == 0]
    posSentTokens = sentTokens[labels == 1]

    tokensDist, tokens = generateFdist(sentTokens)
    _, posTokens = generateFdist(posSentTokens)
    _, negTokens = generateFdist(negSentTokens)

    return data['id'], allPostText, labels, POSs, posTokens, negTokens, tokensDist, sentTokens, tokens

In [None]:
def generatePosSequence(s):
    """Return the (token, tag) pairs of ``s``.

    The text is split with ``nltk.wordpunct_tokenize`` and tagged with
    ``nltk.pos_tag``.
    """
    return nltk.pos_tag(nltk.wordpunct_tokenize(s))

# Built once and shared by every call instead of being re-created per text.
_ENGLISH_STEMMER = SnowballStemmer("english")


def generateTokens(s):
    """Tokenise ``s`` with ``nltk.word_tokenize`` and stem every token.

    Returns a list with one Snowball-stemmed (English) string per token.
    """
    return [_ENGLISH_STEMMER.stem(token) for token in nltk.word_tokenize(s)]

def generateFdist(sentTokens):
    """Count tokens over a collection of tokenised sentences.

    Parameters
    ----------
    sentTokens : iterable of iterables
        One token sequence per sentence.

    Returns
    -------
    tuple
        ``(fdist, tokens)``: a ``FreqDist`` of token counts and the list of
        distinct tokens in order of first appearance.
    """
    fdist = FreqDist()
    tokens = []
    for sentence in sentTokens:
        for token in sentence:
            # Membership in the distribution is a hash lookup; scanning the
            # `tokens` list made this loop quadratic in the vocabulary size.
            if token not in fdist:
                tokens.append(token)
            fdist[token] += 1
    return fdist, tokens

def cleaning(s):
    """Text-cleaning hook; currently returns ``s`` unchanged."""
    return s

def findPosPattern(grammar, POSs, patternName):
    """Collect (and print) every chunk labelled ``patternName`` across sentences.

    Parameters
    ----------
    grammar : str
        Grammar string for ``nltk.RegexpParser``.
    POSs : iterable
        One list of (token, tag) pairs per sentence.
    patternName : str
        Chunk label to look for in the parse trees.

    Returns
    -------
    numpy.ndarray
        1-D object array holding one matching subtree per element.
    """
    matches = []
    parser = nltk.RegexpParser(grammar)
    for sentPOSs in POSs:
        for subtree in parser.parse(sentPOSs).subtrees():
            if subtree.label() == patternName:
                print(sentPOSs)
                print('==>{}'.format(subtree))
                matches.append(subtree)

    # Trees are list subclasses, so np.array(matches) would try to unpack them
    # (ragged rows fail on recent NumPy). Fill an object array slot by slot.
    result = np.empty(len(matches), dtype=object)
    for i, subtree in enumerate(matches):
        result[i] = subtree
    return result

def isPosPattern(grammar, sentPOSs, patternName):
    """Return True if parsing ``sentPOSs`` with ``grammar`` yields at least
    one subtree labelled ``patternName``, else False."""
    parsed = nltk.RegexpParser(grammar).parse(sentPOSs)
    return any(node.label() == patternName for node in parsed.subtrees())

def countPosPattern(grammar, sentPOSs, patternName):
    """Return how many subtrees labelled ``patternName`` appear when
    ``sentPOSs`` is parsed with ``grammar``."""
    parsed = nltk.RegexpParser(grammar).parse(sentPOSs)
    return sum(1 for node in parsed.subtrees() if node.label() == patternName)

def findPosPatternIncludeWord(grammar, sentPOSs, patternName, words):
    """Find chunks labelled ``patternName`` that involve one of ``words``.

    Parameters
    ----------
    grammar : str
        Grammar string for ``nltk.RegexpParser``.
    sentPOSs : list
        (token, tag) pairs of one sentence.
    patternName : str
        Chunk label to look for.
    words : iterable
        A chunk is kept when any of these is contained in the first element
        of an entry of ``subtree.pos()``.

    Returns
    -------
    numpy.ndarray
        1-D object array with one matching subtree per element, so ``len()``
        of the result is the number of matching chunks.
    """
    matches = []
    parsed = nltk.RegexpParser(grammar).parse(sentPOSs)
    for subtree in parsed.subtrees():
        if subtree.label() != patternName:
            continue
        if any(word in leaf[0] for leaf in subtree.pos() for word in words):
            matches.append(subtree)

    # Trees are list subclasses, so np.array(matches) would try to unpack them
    # (ragged rows fail on recent NumPy). Fill an object array slot by slot.
    result = np.empty(len(matches), dtype=object)
    for i, subtree in enumerate(matches):
        result[i] = subtree
    return result

In [None]:
def generatePosUnigramFeatures(ids, POSs, patterns, patternNames):
    """Count, per sentence, how often each single-tag pattern occurs.

    Parameters
    ----------
    ids : sequence
        Stored in the 'id' column of the result.
    POSs : iterable
        One list of (token, tag) pairs per sentence.
    patterns : sequence of str
        Tag patterns; pattern ``i`` is wrapped as ``{<pattern>}``.
    patternNames : sequence of str
        Chunk label and output column name for pattern ``i``.

    Returns
    -------
    pandas.DataFrame
        Column 'id' plus one count column per pattern name.
    """
    import re

    df = pd.DataFrame()
    df['id'] = ids

    # Build each grammar once instead of once per sentence.
    grammars = []
    for i in range(len(patterns)):
        # A bare '$' inside a tag pattern is a regex end-of-string anchor, so
        # '<PRP$>' can never match; escape it so tags such as '$', 'PRP$' and
        # 'WP$' are counted. Already-escaped '\$' is left alone.
        # NOTE(review): '.' still acts as a wildcard, so the '.' tag pattern
        # matches other tags too -- confirm whether that is intended.
        pattern = re.sub(r'(?<!\\)\$', r'\\$', patterns[i])
        grammars.append(r"""{}:{}""".format(patternNames[i], "{<" + pattern + ">}"))

    feat = {}
    for patternName in patternNames:
        feat[patternName] = []

    for sentPOSs in POSs:
        for i in range(len(patterns)):
            patternName = patternNames[i]
            feat[patternName].append(countPosPattern(grammars[i], sentPOSs, patternName))

    for patternName in patternNames:
        df[patternName] = feat[patternName]

    return df

In [None]:
def generatePosNgramsFeatures(ids, POSs, m, n):
    """Count, per sentence, occurrences of the ``m`` most frequent POS n-grams.

    Parameters
    ----------
    ids : sequence
        Stored in the 'id' column of the result.
    POSs : iterable
        One list of (token, tag) pairs per sentence.
    m : int
        Number of most frequent tag n-grams to use (via ``findTopPosNgrams``).
    n : int
        N-gram order passed to ``findTopPosNgrams``.

    Returns
    -------
    pandas.DataFrame
        Column 'id' plus one column 'POS_<n>_gram_<tag-tag...>' per n-gram.
    """
    import re

    df = pd.DataFrame()
    df['id'] = ids

    patternNames = []
    grammars = []
    for tagSeq, _freq in findTopPosNgrams(POSs, m, n):
        patternName = '-'.join(list(tagSeq))
        # The tags are literal strings taken from the data, not regexes:
        # unescaped, '.' is a wildcard and '$' an anchor that never matches
        # (so e.g. PRP$ n-grams were always counted as 0).
        chunk = "".join("<" + re.escape(tag) + ">" for tag in tagSeq)
        patternNames.append(patternName)
        # Built once here instead of once per sentence.
        grammars.append(r"""{}:{}""".format(patternName, "{" + chunk + "}"))

    feat = {}
    for patternName in patternNames:
        feat[patternName] = []

    for sentPOSs in POSs:
        for patternName, grammar in zip(patternNames, grammars):
            feat[patternName].append(countPosPattern(grammar, sentPOSs, patternName))

    print(len(feat))
    for patternName in patternNames:
        df["POS_" + str(n) + "_gram_" + patternName] = feat[patternName]

    return df

In [None]:
# Most frequent POS tag n-grams (n = 2 or n = 3)
def findTopPosNgrams(POSs, m, n):
    """Return the ``m`` most common tag n-grams as (ngram, count) pairs.

    Tags are taken from the second element of each pair in every sentence of
    ``POSs``. Bigrams are built for ``n == 2`` and trigrams for ``n == 3``;
    for any other ``n`` the individual tags are counted. Entries containing
    ':', '#', '@' or '?' are skipped.
    """
    excluded = (':', '#', '@', '?')
    counts = FreqDist()

    for sentPOSs in POSs:
        grams = [pair[1] for pair in sentPOSs]
        if n == 2:
            grams = nltk.bigrams(grams)
        elif n == 3:
            grams = nltk.trigrams(grams)

        for gram in grams:
            if any(symbol in gram for symbol in excluded):
                continue
            counts[gram] += 1

    return counts.most_common(m)

In [None]:
def ngramFreqClass(allPostText, n):
    """Frequency distribution of stemmed-token n-grams over all texts.

    Parameters
    ----------
    allPostText : iterable of str
        Texts to tokenise with ``generateTokens``.
    n : int
        N-gram order, at least 1. For ``n == 1`` the keys are the token
        strings themselves; for larger ``n`` they are n-tuples of tokens.

    Returns
    -------
    FreqDist
        Count of every n-gram.

    Raises
    ------
    ValueError
        If ``n`` is smaller than 1.
    """
    if n < 1:
        raise ValueError('n must be >= 1, got {}'.format(n))

    fdist = FreqDist()
    for text in allPostText:
        sent = generateTokens(text)
        # nltk.ngrams covers bigrams, trigrams and any higher order; the old
        # if/elif chain left `ngrams` unbound for n outside {1, 2, 3}.
        grams = sent if n == 1 else nltk.ngrams(sent, n)
        for gram in grams:
            fdist[gram] += 1

    return fdist

def ngramFreq(allPostText):
    """Return a ``FreqDist`` counting every bigram of stemmed tokens
    (``generateTokens``) across ``allPostText``."""
    tokenised = [generateTokens(text) for text in allPostText]
    return FreqDist(bigram for sent in tokenised for bigram in nltk.bigrams(sent))

def generateNgramFeatures(ids, allPostText, sentTokens, k, n):
    """Binary presence features for the ``k`` most frequent token n-grams.

    Parameters
    ----------
    ids : sequence
        Stored in the 'id' column of the result.
    allPostText : iterable of str
        Texts from which n-gram frequencies are computed (``ngramFreqClass``).
    sentTokens : iterable
        One token list per sentence; each n-gram is searched for in these.
    k : int
        Number of most frequent n-grams to turn into features.
    n : int
        N-gram order.

    Returns
    -------
    pandas.DataFrame
        Column 'id' plus one 0/1 column '<n>-gram_<tok>_<tok>...' per n-gram.
    """
    df = pd.DataFrame()
    df['id'] = ids

    for ngram, _freq in ngramFreqClass(allPostText, n).most_common(k):
        # Unigram keys are bare strings: list(ngram) used to split them into
        # characters, so '1-gram' features searched for letter sequences.
        gram = tuple(ngram) if isinstance(ngram, tuple) else (ngram,)
        # The n-gram already consists of stemmed tokens, so it is matched
        # directly rather than being joined, re-tokenised and re-stemmed.
        df[str(n) + '-gram_' + '_'.join(gram)] = [
            int(position_of_ngram(gram, sentToken) is not None)
            for sentToken in sentTokens
        ]

    return df

def findNgram(ngram, sentToken):
    """Tokenise ``ngram`` with ``generateTokens`` and return whatever
    ``position_of_ngram`` reports for it within ``sentToken``."""
    return position_of_ngram(tuple(generateTokens(ngram)), sentToken)

def isNgramExist(ngram, sentToken):
    """Return True when ``findNgram`` locates ``ngram`` in ``sentToken``
    (i.e. its result is not None)."""
    if findNgram(ngram, sentToken) is None:
        return False
    return True

In [None]:
def countAverageWordLength(allPostText):
    """Average token length of each text.

    Each text is split with ``nltk.word_tokenize``. Returns a list with one
    value per text: the mean token length, or 0 for a text with no tokens.
    """
    feat = []
    for text in allPostText:
        tokens = nltk.word_tokenize(text)
        if tokens:
            # float() forces true division; on Python 2 the plain '/' between
            # two ints silently floored the average.
            feat.append(sum(map(len, tokens)) / float(len(tokens)))
        else:
            feat.append(0)
    return feat

def lengthLongestWord(allPostText):
    """Return, for each text, the length of its longest
    ``nltk.word_tokenize`` token (0 when the text has no tokens)."""
    lengths = []
    for text in allPostText:
        tokens = nltk.word_tokenize(text)
        lengths.append(max(len(token) for token in tokens) if tokens else 0)
    return lengths

In [11]:
def generateReadabilityDF(ids, allPostText):
    """Build a one-column frame with the Flesch-Kincaid ``min_age`` per text.

    ``ids`` is accepted but not used: unlike the other feature builders in
    this notebook, the result carries no 'id' column.
    """
    ages = [FleschKincaid(text).min_age for text in allPostText]

    df = pd.DataFrame()
    df['readability_min_age'] = ages
    return df

In [12]:
def generateFeatureDF(ids, POSs, allPostText, sentTokens):
    """Build the hand-crafted feature table for a set of posts.

    Parameters
    ----------
    ids : sequence
        Stored in the 'id' column of the result.
    POSs : iterable
        One list of (token, tag) pairs per post.
    allPostText : iterable of str
        Raw post texts (sentiment and word-length features).
    sentTokens : iterable
        One stemmed-token list per post (token-presence features).

    Returns
    -------
    pandas.DataFrame
        Column 'id' followed by sentiment, POS-pattern existence, POS-pattern
        count, token-presence and length features, in a fixed column order.
    """
    df = pd.DataFrame()
    df['id'] = ids

    # Sentiment score features
    print('Generating SENTIMENT SCORE FEATURE')
    sid = SentimentIntensityAnalyzer()
    # Score every text once; both features derive from the same value.
    feat_Sentiment = [math.fabs(sid.polarity_scores(text)['compound'])
                      for text in allPostText]
    df['sentiment_score_high'] = [int(score > 0.8) for score in feat_Sentiment]
    df['sentiment_score'] = feat_Sentiment

    # Does a POS pattern exist in the sentence?
    # Lists of pairs (not dicts) keep the column order deterministic.
    print('Generating IF POS PATTERN FEATURES')
    existPatterns = [
        ('EXIST_POS_NUMBER_NP_THAT', r"""CHUNK: {<CD><JJ.*>?<N.*><WDT><VB.*|VB>}"""),
        ('EXIST_POS_NUMBER_NP_VB', r"""CHUNK: {<CD><JJ.*>?<N.*><PRP.*><VB.*|VB>}"""),
    ]
    for name, grammar in existPatterns:
        df[name] = [int(isPosPattern(grammar, sentPOSs, 'CHUNK')) for sentPOSs in POSs]

    # How often does a POS pattern occur in the sentence?
    print('Generating COUNT OF POS PATTERN FEATURES')
    countPatterns = [
        ('POS_pattern_COUNT_NUM_SHORTTEN', r"""CHUNK: {<''><VBP|MD>}"""),
        ('POS_pattern_COUNT_DT', r"""CHUNK: {<DT>}"""),
        ('POS_pattern_COUNT_WRB', r"""CHUNK: {<WRB>}"""),
        # '$' must be escaped: unescaped it is a regex anchor and the
        # pattern never matches, leaving this feature constantly 0.
        ('POS_pattern_COUNT_PRP_Dollar', r"""CHUNK: {<PRP\$>}"""),
        ('POS_pattern_COUNT_MD', r"""CHUNK: {<MD>}"""),
        ('POS_pattern_COUNT_WDT', r"""CHUNK: {<WDT>}"""),
        ('POS_pattern_COUNT_PRP', r"""CHUNK: {<PRP>}"""),
        ('POS_pattern_COUNT_RB', r"""CHUNK: {<RB>}"""),
        ('POS_pattern_COUNT_WP', r"""CHUNK: {<WP>}"""),
    ]
    for name, grammar in countPatterns:
        df[name] = [countPosPattern(grammar, sentPOSs, 'CHUNK') for sentPOSs in POSs]

    # Count of determiner+noun chunks that involve this/these
    print('Generating COUNT OF POS Pattern with Conditions')
    df['POS_pattern_COUNT_this-these_NN'] = [
        len(findPosPatternIncludeWord(r"""CHUNK: {<DT><NN.*>}""", sentPOSs, 'CHUNK',
                                      ['this', 'these', 'This', 'These']))
        for sentPOSs in POSs
    ]

    # Presence of specific tokens / token sequences
    print('Generating TOKEN DICTIONARY FEATURE')
    ngramsToCheck = ['@', 'http', '?', '#', '!', '! ?', '. . .', '* * *', '! !', '! ! !']
    ngramsNames = ['AT', 'WEB', 'QM', 'OC', "EX",
                   'EX-QM', 'TRIPLE-DOT', 'TRIPLE-AS', 'DOUBLE-EX', 'TRIPLE-EX']
    for i in range(len(ngramsToCheck)):
        # Tokenise the target once rather than once per sentence.
        target = tuple(generateTokens(ngramsToCheck[i]))
        df['CONTAINS_' + ngramsNames[i]] = [
            int(position_of_ngram(target, sentToken) is not None)
            for sentToken in sentTokens
        ]

    print('Generating LENGTH/NUMERIC FEATURES')
    # number of tokens
    df['NUM_TOKENS'] = [len(sentPOSs) for sentPOSs in POSs]

    # average word length
    df['AVG_WORD_LENGTH'] = np.array(countAverageWordLength(allPostText))

    # length of longest word
    df['LEN_LONGEST_WORD'] = lengthLongestWord(allPostText)

    return df

#### Loading & Splitting Data

In [13]:
# Folder holding the clickbait17 downloads; change this one line to relocate the data.
DATA_DIR = '/Users/lethai/Downloads'
TRAIN_DIR = DATA_DIR + '/clickbait17-train-170331'
VALID_DIR = DATA_DIR + '/clickbait17-validation-170630'

trainData, trainInstance, trainLabel = readData(TRAIN_DIR + '/instances.jsonl',
                                                TRAIN_DIR + '/truth.jsonl')
validData, validInstance, validLabel = readData(VALID_DIR + '/instances.jsonl',
                                                VALID_DIR + '/truth.jsonl')

# Stack the two releases; the original row indices are kept as they are.
data_df = pd.concat([trainData, validData])
instance_df = pd.concat([trainInstance, validInstance])
label_df = pd.concat([trainLabel, validLabel])

In [14]:
data = data_df
# Take the labels from the merged frame so they are row-aligned with `data`
# by construction (label_df follows the truth-file order, which need not
# match the merged order). Work on a copy so no DataFrame is mutated.
truthClass = data_df['truthClass'].values
label = np.array(truthClass, dtype=object)
label[truthClass == 'clickbait'] = 1
label[truthClass == 'no-clickbait'] = 0

In [15]:
# Stratified hold-out: 30% of the rows go to the test split.
# (No separate validation split is carved out of the training rows here.)
sss = StratifiedShuffleSplit(n_splits=1, test_size=0.3, random_state=0)
train_index, test_index = next(sss.split(data, label))
print("TRAIN:", train_index, "TEST:", test_index)

X_train = data.iloc[train_index,:]
X_test = data.iloc[test_index,:]
y_train = label[train_index]
y_test = label[test_index]

('TRAIN:', array([16334,  8421, 20783, ...,  1168, 13390, 10493]), 'TEST:', array([ 7181,  8799, 18085, ..., 18412, 17895, 13068]))


In [16]:
# Shapes of the training split, the test split and the full frame.
X_train.shape, X_test.shape, data.shape

((15397, 14), (6600, 14), (21997, 14))

#### Feature Engineering

#### Training

In [17]:
# Tokenise / tag the training split and show how many posts fall in class 0 and class 1.
ids, allPostText, labels, POSs, posTokens, negTokens, tokensDist, sentTokens, tokens = analyze(X_train)
print("Labels distribution:{}".format(np.bincount(labels)))

Labels distribution:[11579  3818]


In [18]:
# Every tag of the tagset except '(', ')' and ':'. Built as a new list because
# on Python 3 dict.keys() is a view without .remove().
tags = [tag for tag in tagdict.keys() if tag not in ('(', ')', ':')]
df_pos_unigram = generatePosUnigramFeatures(ids, POSs, tags, tags)
df_pos_unigram.to_csv('clickbait_pos_unigram_train',sep='\t', encoding='utf-8')
# DataFrame.from_csv was removed in pandas 1.0; read_csv with index_col=0 reads the same file.
df_pos_unigram = pd.read_csv('clickbait_pos_unigram_train', sep='\t', encoding='utf-8', index_col=0)











In [None]:
# Counts of the 50 most frequent POS bigrams, written to a TSV file and read back.
df_pos_bigram_50 = generatePosNgramsFeatures(ids, POSs, 50, 2)
df_pos_bigram_50.to_csv('clickbait_pos_bigram_top_50_train',sep='\t', encoding='utf-8')
# DataFrame.from_csv was removed in pandas 1.0; read_csv with index_col=0 reads the same file.
df_pos_bigram_50 = pd.read_csv('clickbait_pos_bigram_top_50_train', sep='\t', encoding='utf-8', index_col=0)

In [None]:
# Counts of the 50 most frequent POS trigrams, written to a TSV file and read back.
df_pos_trigram_50 = generatePosNgramsFeatures(ids, POSs, 50, 3)
df_pos_trigram_50.to_csv('clickbait_pos_trigram_top_50_train',sep='\t', encoding='utf-8')
# DataFrame.from_csv was removed in pandas 1.0; read_csv with index_col=0 reads the same file.
df_pos_trigram_50 = pd.read_csv('clickbait_pos_trigram_top_50_train', sep='\t', encoding='utf-8', index_col=0)

In [None]:
# Presence flags for the 50 most frequent token unigrams, written to a TSV file and read back.
df_unigram_50 = generateNgramFeatures(ids, allPostText, sentTokens, 50, 1)
df_unigram_50.to_csv('clickbait_unigram_top_50_train',sep='\t', encoding='utf-8')
# DataFrame.from_csv was removed in pandas 1.0; read_csv with index_col=0 reads the same file.
df_unigram_50 = pd.read_csv('clickbait_unigram_top_50_train', sep='\t', encoding='utf-8', index_col=0)

In [None]:
# Presence flags for the 50 most frequent token bigrams, written to a TSV file and read back.
df_bigram_50 = generateNgramFeatures(ids, allPostText, sentTokens, 50, 2)
df_bigram_50.to_csv('clickbait_bigram_top_50_train', sep='\t', encoding='utf-8')
# DataFrame.from_csv was removed in pandas 1.0; read_csv with index_col=0 reads the same file.
df_bigram_50 = pd.read_csv('clickbait_bigram_top_50_train', sep='\t', encoding='utf-8', index_col=0)

In [None]:
# Presence flags for the 50 most frequent token trigrams, written to a TSV file and read back.
df_trigram_50 = generateNgramFeatures(ids, allPostText, sentTokens, 50, 3)
df_trigram_50.to_csv('clickbait_trigram_top_50_train',sep='\t', encoding='utf-8')
# DataFrame.from_csv was removed in pandas 1.0; read_csv with index_col=0 reads the same file.
df_trigram_50 = pd.read_csv('clickbait_trigram_top_50_train', sep='\t', encoding='utf-8', index_col=0)

In [19]:
# Hand-crafted features (sentiment, POS patterns, token presence, lengths) for the training split.
X_df_features = generateFeatureDF(ids, POSs, allPostText, sentTokens)

Generating SENTIMENT SCORE FEATURE
Generating IF POS PATTERN FEATURES
Generating COUNT OF POS PATTERN FEATURES


Generating COUNT OF POS Pattern with Conditions
Generating TOKEN DICTIONARY FEATURE
Generating LENGTH/NUMERIC FEATURES


In [20]:
# Column-wise join of the hand-crafted features and the POS-unigram counts.
# Both inputs carry an 'id' column, so the result contains 'id' twice.
X_df = pd.concat([X_df_features, df_pos_unigram], axis = 1)

In [None]:
# Feature matrix without the 'id' column(s). `.values` is used because
# `get_values()` was removed in pandas 1.0.
X = X_df.drop(['id'], axis = 1).values
y = labels
X_df.shape, X.shape

((15397, 71), (15397, 69))

#### Generating Testing Features

In [None]:
# Same preprocessing as for the training split, applied to the held-out test rows.
ids_test, allPostText_test, labels_test, POSs_test, posTokens_test, negTokens_test, tokensDist_test, sentTokens_test, tokens_test = analyze(X_test)
print("Labels distribution:{}".format(np.bincount(labels_test)))
# Display the raw test texts as the cell output.
allPostText_test

In [None]:
# Every tag of the tagset except '(', ')' and ':'. Built as a new list because
# on Python 3 dict.keys() is a view without .remove().
tags = [tag for tag in tagdict.keys() if tag not in ('(', ')', ':')]
df_pos_unigram_test = generatePosUnigramFeatures(ids_test, POSs_test, tags, tags)
df_pos_unigram_test.to_csv('clickbait_pos_unigram_test',sep='\t', encoding='utf-8')
# DataFrame.from_csv was removed in pandas 1.0; read_csv with index_col=0 reads the same file.
df_pos_unigram_test = pd.read_csv('clickbait_pos_unigram_test', sep='\t', encoding='utf-8', index_col=0)

In [None]:
# Hand-crafted features for the test split, built with the same function as for training.
X_test_df_features = generateFeatureDF(ids_test, POSs_test, allPostText_test, sentTokens_test)

In [None]:
# Column-wise join of the hand-crafted test features and the POS-unigram counts
# (no bigram/trigram features are added here).
X_test_df = pd.concat([X_test_df_features, df_pos_unigram_test], axis = 1)
# Alternative without the POS-unigram counts:
# X_test_df = X_test_df_features

In [None]:
# Test feature matrix without the 'id' column(s). `.values` is used because
# `get_values()` was removed in pandas 1.0.
Xtest = X_test_df.drop(['id'], axis = 1).values
ytest = labels_test
X_test_df.shape, Xtest.shape

#### Generating Validation Features

In [None]:
# To be done

#### Training & Validation

In [None]:
def evaluateModel(model, X, y):
    """Print a classification report, accuracy and (when available) ROC AUC.

    Parameters
    ----------
    model : fitted classifier
        If it exposes ``predict_proba``, the predicted class is the column
        with the highest probability and the AUC uses column 1. Otherwise
        (e.g. a hard-voting ``VotingClassifier``) ``predict`` is used and the
        AUC is reported as unavailable.
    X : array-like
        Feature matrix.
    y : array-like
        True labels.
    """
    # getattr with a default also covers estimators whose predict_proba
    # attribute raises AttributeError on access.
    predictProba = getattr(model, 'predict_proba', None)
    proba = predictProba(X) if predictProba is not None else None

    if proba is not None:
        predicted = np.argmax(proba, axis=1)
    else:
        predicted = model.predict(X)

    print(classification_report(y, predicted))
    print('Accuracy:{}'.format(accuracy_score(y, predicted)))
    if proba is not None:
        print('AUC:{}'.format(roc_auc_score(y, proba[:, 1])))
    else:
        print('AUC:n/a (model has no predict_proba)')

##### AdaBoost

In [None]:
# random_state pins the run so the reported scores can be reproduced.
clfAB = AdaBoostClassifier(random_state=0)
clfAB.fit(X,y)
evaluateModel(clfAB, X, y)

In [None]:
# Same metrics for AdaBoost on the held-out test features.
evaluateModel(clfAB, Xtest, ytest)

##### Multi-layer Perceptron

In [None]:
# random_state pins weight initialisation and shuffling so the scores can be reproduced.
clfMP = MLPClassifier(alpha=0.1, random_state=0)
clfMP.fit(X,y)
evaluateModel(clfMP, X, y)

In [None]:
# Same metrics for the MLP on the held-out test features.
evaluateModel(clfMP, Xtest, ytest)

##### RandomForest 

In [None]:
# random_state pins the bootstrap and feature sampling so the scores can be reproduced.
clfRF = RandomForestClassifier(n_estimators=200, max_features = 10, random_state=0)
clfRF.fit(X, y)
evaluateModel(clfRF, X, y)

In [None]:
# Same metrics for the random forest on the held-out test features.
evaluateModel(clfRF, Xtest, ytest)

##### Random Forest Weight Balanced

In [None]:
# class_weight='balanced' reweights the classes; random_state makes the run reproducible.
clfRFweight = RandomForestClassifier(n_estimators=300, max_features = 20, class_weight='balanced',
                                     random_state=0)
clfRFweight.fit(X, y)
evaluateModel(clfRFweight, X, y)

In [None]:
# Same metrics for the class-weighted random forest on the held-out test features.
evaluateModel(clfRFweight, Xtest, ytest)

##### Logistic Regression

In [None]:
# Logistic regression (C=1, lbfgs solver) fitted and scored on the training features.
clfLR = LogisticRegression(C=1., solver='lbfgs')
clfLR.fit(X, y)
evaluateModel(clfLR, X, y)

In [None]:
# Same metrics for logistic regression on the held-out test features.
evaluateModel(clfLR, Xtest, ytest)

##### Decision Tree

In [None]:
# random_state pins tie-breaking between equally good splits so the scores can be reproduced.
clfDT = sklearn.tree.DecisionTreeClassifier(random_state=0)
clfDT.fit(X, y)
evaluateModel(clfDT, X, y)

In [None]:
# Same metrics for the decision tree on the held-out test features.
evaluateModel(clfDT, Xtest, ytest)

##### Majority Voting

In [None]:
# Majority vote over a class-weighted random forest, logistic regression and AdaBoost.
# The stochastic members are seeded so the ensemble can be reproduced.
clf1 = RandomForestClassifier(n_estimators=300, max_features = 20, class_weight='balanced',
                              random_state=0)
clf2 = LogisticRegression(C=1., solver='lbfgs')
clf3 = AdaBoostClassifier(random_state=0)
# NOTE: hard voting exposes no predict_proba, so probability-based metrics
# such as AUC are not available for this ensemble.
eclf1 = VotingClassifier(estimators=[('RF', clf1), ('LR', clf2), ('AB', clf3)], voting='hard')
eclf1.fit(X, y)
evaluateModel(eclf1, X, y)

#### Plot Learning Curve

In [None]:
def plot_learning_curve(estimator, title, X, y, ylim=None, cv=None,
                        n_jobs=1, train_sizes=np.linspace(.1, 1.0, 5)):
    """Plot training and cross-validation scores against training-set size.

    A new figure is created with ``title`` (and ``ylim`` when given),
    ``learning_curve`` is run with ``cv``, ``n_jobs`` and ``train_sizes``,
    and for each of the two score sets the mean is drawn as a line with a
    shaded band of one standard deviation around it.

    Returns the ``matplotlib.pyplot`` module.
    """
    plt.figure()
    plt.title(title)
    if ylim is not None:
        plt.ylim(*ylim)
    plt.xlabel("Training examples")
    plt.ylabel("Score")

    train_sizes, train_scores, test_scores = learning_curve(
        estimator, X, y, cv=cv, n_jobs=n_jobs, train_sizes=train_sizes)

    # (mean, std, colour, legend label) per curve, training first.
    curves = [
        (np.mean(scores, axis=1), np.std(scores, axis=1), colour, legend)
        for scores, colour, legend in (
            (train_scores, "r", "Training score"),
            (test_scores, "g", "Cross-validation score"),
        )
    ]

    plt.grid()

    # Shaded bands first, then the mean lines on top.
    for mean, std, colour, _ in curves:
        plt.fill_between(train_sizes, mean - std, mean + std, alpha=0.1,
                         color=colour)
    for mean, _, colour, legend in curves:
        plt.plot(train_sizes, mean, 'o-', color=colour, label=legend)

    plt.legend(loc="best")
    return plt

In [None]:
# Learning curve of a class-weighted random forest over 10 shuffled 80/20 splits.
cv = ShuffleSplit(n_splits=10, test_size=0.2, random_state=0)
# Seed the forest as well as the splitter so the curve can be reproduced.
clf = RandomForestClassifier(n_estimators=100, max_features=15, class_weight='balanced',
                             random_state=0)
# The trailing ';' keeps the returned pyplot module from being echoed as cell output.
plot_learning_curve(clf, 'Learning Curve', X, y, (0.7, 1.01), cv=cv, n_jobs=4);

In [None]:
# Learning curve of the MLP over 10 shuffled 80/20 splits.
cv = ShuffleSplit(n_splits=10, test_size=0.2, random_state=0)
# Seed the network as well as the splitter so the curve can be reproduced.
clf = MLPClassifier(alpha=0.1, random_state=0)
# The trailing ';' keeps the returned pyplot module from being echoed as cell output.
plot_learning_curve(clf, 'Learning Curve', X, y, (0.7, 1.01), cv=cv, n_jobs=4);