In [1]:
import pandas as pd
import numpy as np
import random
import re
import json
import warnings
warnings.filterwarnings("ignore")
from sklearn.feature_extraction.text import CountVectorizer
from sklearn.feature_extraction.text import TfidfTransformer
from numpy import dot
from numpy.linalg import norm

In [2]:
def get_cos_similarity(sms, templates):
    '''Compute the cosine similarity between one new sms vector and each template vector.

    Parameters
    ----------
    sms : 1-D array-like
        Vector representation of the new sms (e.g. one TF-IDF row).
    templates : iterable of 1-D array-like
        Template vectors; each must have the same length as ``sms``.

    Returns
    -------
    list
        One similarity per template, in template order. A pair involving an
        all-zero vector has no defined angle and is scored 0.0 instead of NaN.
    '''
    def cos_sim(a, b):
        denom = norm(a) * norm(b)
        if denom == 0:
            # 0/0 would produce NaN (silently, if warnings are filtered), which
            # makes any max/argmax over the returned scores meaningless.
            return 0.0
        return dot(a, b) / denom
    return [cos_sim(template, sms) for template in templates]


def tfIdfVector(corpus):
    '''Convert a list of sentences into a dense TF-IDF matrix.

    Example corpus: ['This is an example', 'hello world', ...]

    A vocabulary is learned from ``corpus`` itself, raw term counts are taken
    with CountVectorizer, and those counts are re-weighted with
    TfidfTransformer (both with their default settings). The result is a
    dense numpy array with one row per sentence, in input order.
    '''
    term_counts = CountVectorizer().fit_transform(corpus)
    weighted_sparse = TfidfTransformer().fit_transform(term_counts)
    return weighted_sparse.toarray()


class SmsRuleClf:
    '''Rule-based classifier: given one or a batch of sms, assign each the label of its closest template.

    Closeness is the cosine similarity between TF-IDF vectors. The TF-IDF model
    is re-fitted on the templates plus the sms being classified on every call
    to ``predict``.
    '''
    def __init__(self, labeled_templates_df):
        '''
        Parameters
        ----------
        labeled_templates_df : pandas.DataFrame
            Must provide an ``sms`` column (template text) and a ``label``
            column (the class of each template).
        '''
        self.labeled_templates_df = labeled_templates_df
        self.corpus, self.labels = self._get_template_corpus_labels()

    def _get_template_corpus_labels(self):
        '''Return the template texts and their labels as two parallel lists.'''
        templates = self.labeled_templates_df
        return templates.sms.tolist(), templates.label.tolist()

    def predict(self, sms):
        '''Classify a batch of sms by nearest template.

        Input: ['This is an example', 'hello world', ...]
        Output: [[sms_1, label_1, score_1], [sms_2, label_2, score_2], ...]
            One entry per input, in input order. ``label`` is the label of the
            most similar template and ``score`` is that cosine similarity.

        An empty list returns an empty list.

        Raises
        ------
        TypeError
            (a subclass of Exception) if ``sms`` is not a list of str.
        '''
        if not isinstance(sms, list) or not all(isinstance(s, str) for s in sms):
            raise TypeError('''sms type not allowed: should be with type: ['This is an example', 'hello world', ...]''')
        if not sms:
            # Nothing to classify; also avoids the `[:-0]` slice below, which
            # would select no templates at all.
            return []

        num_sms = len(sms)
        # Build a NEW list. Appending to self.corpus (an alias) would leak these
        # queries into the template set for every later call, misaligning the
        # templates with self.labels.
        all_corpus = list(self.corpus) + [single.lower() for single in sms]
        all_tfidf = tfIdfVector(all_corpus)
        template_tfidf = all_tfidf[:-num_sms]
        instances_tfidf = all_tfidf[-num_sms:]

        result = []
        for text, vector in zip(sms, instances_tfidf):
            cos_score = get_cos_similarity(vector, template_tfidf)
            best_idx = int(np.argmax(cos_score))
            result.append([text, self.labels[best_idx], np.max(cos_score)])
        return result

In [3]:
# Sanity check: classify a few rows drawn from the template set itself. Each is
# expected to come back with its own label and a score close to 1 (unless the
# template set contains duplicate texts with different labels).
TEMPLATES_CSV = 'gsm_templates_df.csv'
SEED = 42  # fixes which templates are sampled so Restart & Run All is reproducible

gsm_templates_df = pd.read_csv(TEMPLATES_CSV)
example = gsm_templates_df.sample(n=5, random_state=SEED).sms.tolist()
clf = SmsRuleClf(gsm_templates_df)
clf.predict(example)