# 对于2018年CIC2rd的反馈数据进行文本分类

## install dependency and initial session and S3 bucket

In [364]:
import sagemaker
from sagemaker import get_execution_role
import json
import boto3
from random import shuffle

sess = sagemaker.Session()

role = get_execution_role()
print(role) # This is the role that SageMaker would use to leverage AWS resources (S3, CloudWatch) on your behalf

bucket = 'ray-ai-ml-bjs' #sess.default_bucket() # Replace with your own bucket name if needed
print(bucket)
prefix = 'classification/blazingtext/bmwticket' 

arn:aws-cn:iam::876820548815:role/Sagemaker-Bootcamp-SageMakerExecutionRole-Z3VF78G260T1
ray-ai-ml-bjs


In [365]:
!pip install -i https://pypi.tuna.tsinghua.edu.cn/simple jieba --upgrade
!pip install -i https://pypi.tuna.tsinghua.edu.cn/simple paddlepaddle-tiny

Looking in indexes: https://pypi.tuna.tsinghua.edu.cn/simple
Requirement already up-to-date: jieba in /home/ec2-user/anaconda3/envs/python3/lib/python3.6/site-packages (0.42.1)
Looking in indexes: https://pypi.tuna.tsinghua.edu.cn/simple


## Prepare classification category
- All return '\n' convert to `_` for example `LSC\n远程指令` convert to `LSC_远程指令`
- Convert alphabet to Uppercase

In [366]:
import jieba
import re
import pandas as pd
import os
import csv

In [71]:
index_to_label = {} #1001:APP显示问题
label_to_index = {} #APP显示问题:1001
with open("app-comments-classes.txt") as f:
    for i,label in enumerate(f.readlines()):
        ll = label.strip().split(',')
        index_to_label[ll[0]] = ll[1].upper()
        label_to_index[ll[1].upper()] = ll[0]
print(label_to_index)

{'APP显示问题': '1001', 'APP版本问题': '1002', 'APP登陆问题': '1003', 'APP车辆信息显示问题': '1004', 'BON': '1005', 'CARLIFE': '1006', 'CARPLAY': '1007', 'ETC': '1008', 'GPS': '1009', 'IBA用户手册': '1010', 'LSC': '1011', 'LSC_远程指令': '1012', 'PIN码': '1013', 'POI': '1014', 'QQ音乐': '1015', 'RSU': '1016', 'RTTI': '1017', 'SIM卡相关': '1019', 'WIFI': '1020', '个性化设置': '1021', '二手车预激活问题': '1022', '会员': '1023', '其他': '1024', '即时充电': '1025', '喜马拉雅': '1026', '地图自动更新服务': '1027', '天猫精灵': '1028', '实时路况功能': '1029', '意见建议': '1030', '数字钥匙': '1031', '智慧停车': '1032', '更新BMW服务': '1033', '满意': '1034', '节日祝福': '1035', '行程摘要': '1036', '行车摘要': '1037', '车机端登录': '1038', '远程3D视图': '1039', '远程指令': '1040', '远程指令_LSC': '1041', '远程服务': '1042', '远程软件升级': '1043', '违章代缴': '1044', '预开通激活': '1045', '预激活开通': '1046'}


## Prepare the input data
- parse the 'category detail' and 'comment'
- convert the 'category detail' as format `1001_!_APP显示问题_!_问题描述`
- Convert alphabet to Uppercase

In [361]:
excel_data_df = pd.read_excel('20200916_in-app-comments_EE-CN-12.xlsx', sheet_name='in-app comment', usecols=['category detail', 'comment'], index_col=None)
#print(excel_data_df[0:5])

raw_app_comments_file = 'raw-app-comments-0916.csv'
excel_data_df.to_csv(raw_app_comments_file, index=False)
sentiment_input = []
input_with_category = []
input_without_category = []
data_offset_counts = {}

csv_data_df = pd.read_csv(raw_app_comments_file, delimiter=',', index_col=None)
print(csv_data_df[0:10])

for index, data in csv_data_df.iterrows():
    category_str = data['category detail']
    comment_str = data['comment']
    if category_str=='' or pd.isnull(category_str):
        input_without_category.append(str(comment_str))
    else:
        if category_str=="远程指令\nLSC" or category_str=="LSC\n远程指令":
            category_str = '_'.join(category_str.split('\n'))
            #print(category_str, comment_str)
        category_str = category_str.upper()
        category_index = label_to_index[category_str]
        input_with_category.append(str(category_index) + '_!_' + str(category_str) + '_!_' + str(comment_str))
        data_offset_counts[category_str] = data_offset_counts.get(category_str, 0) + 1
    sentiment_input.append(comment_str)

print(input_with_category[0:3],'\n')
print(input_without_category[0:3],'\n')
print('Excel Sheet to CSV done - csv_data_df: ',len(csv_data_df), ' , input_with_category: ', len(input_with_category), ' , input_without_category: ', len(input_without_category), '\n')

print('data offset counts: ')
items = list(data_offset_counts.items())
items.sort(key=lambda x: x[1], reverse=True)
for i in range(30):
    word, count = items[i]
    print("{:<10}{:>7}".format(word, count))


print('\n','sentiment_input: ', len(sentiment_input), 'preview 10 items')
print(sentiment_input[0:10])

  category detail                                            comment
0             LSC                                      我的互联不更新了是什么原因
1             LSC                                           定位不准，乱定位
2            WIFI                                   开通互联驾驶后，手机怎么连接车辆
3           预激活开通                                     我怎么是开通不了互联驾驶功能
4            远程服务                      开通云端互联后车辆无法执行通风指令 且车辆里程等信息无显示
5             LSC                                          系统好几天不更新了
6             LSC                                            无法定位车辆！
7             LSC  你好，为什么在互联驾驶这个功能中定位不了我车了，显示说我在车上的GPS关闭了，但是我检查过了...
8           个性化设置                                        云端互联欢迎词修改不了
9          远程3D视图                                         远程3D视图看不了了
['1011_!_LSC_!_我的互联不更新了是什么原因', '1011_!_LSC_!_定位不准，乱定位', '1020_!_WIFI_!_开通互联驾驶后，手机怎么连接车辆'] 

['已经激活云端互联。但是不能连接车', '汽车公里数不和车同步', '突然之间导航没有任何声音了'] 

Excel Sheet to CSV done - csv_data_df:  4171  , input_with_category:  3502  , input_without_cat

In [5]:
session = boto3.Session(profile_name='global', region_name='us-east-1')
comprehend_client = session.client('comprehend')
sample_list = ['喜马拉雅FM可以像QQ音乐一样直接在车上应用么？还是只能在CARPLAY上使用', '你好，为什么在互联驾驶这个功能中定位不了我车了，显示说我在车上的GPS关闭了，但是我检查过了，也开着啊什么问题呢', '我喜欢使用互联驾驶这个功能中路线规划']
for s_input in sample_list:
    sentiment_response = comprehend_client.detect_sentiment(Text=s_input, LanguageCode='zh')
    sentiment_str = sentiment_response['Sentiment']
    sentiment_score = json.dumps(sentiment_response['SentimentScore'])
    print(s_input + ' - sentiment: ' + sentiment_str + ' - sentiment_score: ' + sentiment_score + '\n')


喜马拉雅FM可以像QQ音乐一样直接在车上应用么？还是只能在CARPLAY上使用 - sentiment: NEUTRAL - sentiment_score: {"Positive": 0.010348147712647915, "Negative": 0.0909975990653038, "Neutral": 0.8986495733261108, "Mixed": 4.764520781463943e-06}

你好，为什么在互联驾驶这个功能中定位不了我车了，显示说我在车上的GPS关闭了，但是我检查过了，也开着啊什么问题呢 - sentiment: NEGATIVE - sentiment_score: {"Positive": 0.0004973431350663304, "Negative": 0.9937689900398254, "Neutral": 0.005685559939593077, "Mixed": 4.810316386283375e-05}

我喜欢使用互联驾驶这个功能中路线规划 - sentiment: POSITIVE - sentiment_score: {"Positive": 0.9950544834136963, "Negative": 5.2080289606237784e-05, "Neutral": 0.004892547149211168, "Mixed": 8.196818157557573e-07}



In [6]:
# batch_detect_sentiment 
sentiment_result_file = 'sentiment_result.csv'

if os.path.isfile(sentiment_result_file):
    os.remove(sentiment_result_file)

with open(sentiment_result_file, 'a+') as sentiment_f:
    sentiment_f.write('comment' + '|' + 'sentiment' + '|' + 'sentiment_score' + '\n')

sentiment_chunks = [sentiment_input[x:x+25] for x in range(0, len(sentiment_input), 25)]
for chunk in sentiment_chunks:
    sentiment_response = comprehend_client.batch_detect_sentiment(TextList=chunk, LanguageCode='zh')
    #print(sentiment_response['ResultList'])
    with open(sentiment_result_file, 'a+') as sentiment_f:
        s_input_index = 0
        for s_input in chunk:
          sentiment_str = sentiment_response['ResultList'][s_input_index]['Sentiment']
          sentiment_score = json.dumps(sentiment_response['ResultList'][s_input_index]['SentimentScore'])
          sentiment_f.write(s_input + '|' + sentiment_str + '|' + sentiment_score + '\n')
          s_input_index +=1
sentiment_f.close

<function TextIOWrapper.close()>

In [7]:
sentiment_data_df = pd.read_csv(sentiment_result_file, delimiter='|', index_col=None)
print(sentiment_data_df.head())

print('Analysis sentiment done for total item: ', len(sentiment_data_df))

                         comment sentiment  \
0                  我的互联不更新了是什么原因  NEGATIVE   
1                       定位不准，乱定位  NEGATIVE   
2               开通互联驾驶后，手机怎么连接车辆   NEUTRAL   
3                 我怎么是开通不了互联驾驶功能  NEGATIVE   
4  开通云端互联后车辆无法执行通风指令 且车辆里程等信息无显示  NEGATIVE   

                                     sentiment_score  
0  {"Positive": 0.00021503579046111554, "Negative...  
1  {"Positive": 0.00015880828141234815, "Negative...  
2  {"Positive": 0.0008933115750551224, "Negative"...  
3  {"Positive": 0.0003960980975534767, "Negative"...  
4  {"Positive": 0.00016022840281948447, "Negative...  
Analysis sentiment done for total item:  4172


In [73]:
import logging,os,jieba
#!wget https://cdc-code.s3.cn-north-1.amazonaws.com.cn/chineseStopWords.txt
def get_stopwords(StopWordFileName):
    logging.basicConfig(format='%(asctime)s:%(levelname)s:%(message)s',level=logging.INFO)  
      #加载停用词表 
    stopword_set = set()
    with open(StopWordFileName,'r',encoding="utf-8") as stopwords:
        for stopword in stopwords: 
            stopword_set.add(stopword.strip("\n"))  
    return stopword_set
    
def clear_timestamp(mystr):
    patterns = [r"\w{3} \w{3} \d{2} \d{1,2}:\d{1,2}:\d{1,2} \d{4}\s*",    #sun aug 19 13:02:10 2018
        r"\w{3}, \d{2} \w{3} \d{4} \d{1,2}:\d{1,2}:\d{1,2} \w{2}\s*",     #Sun, 19 Aug 2018 13:02:08 ET
        r"\d{4}-\d{1,2}-\d{1,2} \d{1,2}:\d{1,2}:\s*",                       #2018-11-01 09:35:
        r"\d{4}/\d{1,2}/\d{1,2}\s*",                                    #2018/9/1
        r"\d{1,2}/\d{1,2}/\d{4}\s*",                                    #9/1/2018
        r"\d{4}.\d{1,2}.\d{0,2}\s*",                                    #2018.9.1
        r"\d{1,2}.\d{1,2}.\d{4}\s*",                                    #9.1.2018
        r"\d{4}-\d{1,2}-\d{1,2} \d{1,2}:\d{1,2}:\d{4}/\d{1,2}/\d{1,2}\s*",      #2018-11-01 11:18:2018/10/31
        r"\d{1,2}:\d{1,2}:\d{4}/\d{1,2}/\d{1,2}\s*",      #2018-11-01 11:18:2018/10/31 21:09:08
        r"\d{1,2}:\d{1,2}:\d{1,2}\s*(AM|PM|am|pm)\s*",        #4:00:58 PM
        r"\d{1,2}:\d{1,2}:\d{1,2}\s*",                                     #21:09:08
        r"\d{1,2}:\d{1,2}\s*",                                     #21:09:08
        r"(\d{4})年(\d{1,2})月(\d{1,2})日\s*",                 #2018年10月5日
        r"(\d{2,4})年\s*",                 #2018年
        r"(\d{4})年(\d{1,2})月\s*",                 #2018年10月
        r":\s*([\da-zA_Z]+\/)+([a-zA-Z0-9\.]+)"                     #URL
        ]

    s = mystr

    for p in patterns:
        s = re.sub(p,'', s)

    s = s.strip()
    return s

def clear_email_phone_colon(mystr):
    patterns = [r"1[0-9]{10}", #mobile
        r"(\(0\d{2}\) \d{8})|(\(0\d{3}\) \d{7})|(\(0\d{3}\)-\d{8}$)|(\(0\d{2}\)\d{8})|(\(0\d{3}\)\d{7})|(\(0\d{3}\)\d{8}$)",                       #phone
        r"(0\d{2}-\d{8})|(0\d{3}-\d{7})|(0\d{3}-\d{8}$)|(\d{8})",                       #phone
        r"\w+([-+.]\w+)*@\w+([-.]\w+)*\.\w+([-.]\w+)*"    #email
        ]

    s = mystr

    for p in patterns:
        s = re.sub(p,'', s)

    s = s.strip()
    return s

# Debug code
file  = 'stop-words-test.txt'
with open(file) as f:
    lines = f.readlines()
    
labels = []
for line in lines:
    s = clear_email_phone_colon(line)
    print(s)

1005_!_RS_!_客户罗先生（车主黄先生的家人）致电，表示其i豪华型车辆，之前已经接到互联驾驶通知协议手机号更改成功的通知，但目前使用新协议手机号登陆云端互联APP后，首页不显示远程控制选项了（更改协议手机号之前可以正常使用），针对此问题，烦请互联驾驶部门跟进处理。 签订互联驾驶协议时登记的车主姓名：黄先生 协议手机号码：（原协议手机号） VIN：LBV5S3104HSN87123 登陆云端互联密码：990308ljq 手机型号&版本：IOS 最新 云端互联APP版本：最新 联系人：罗先生 联系电话：
1002_!_Carplay_!_ya_gao 2018-11-01 09:35: 客户杨先生通过在线客户平台反馈，表示在2018/9/1在淮安润宝行店内购买的 530Li 尊享型 豪华套装车辆，客户表示互联驾驶已经开通但是还是没有Apple Carplay的选项，已建议客户发送邮件，针对客户问题烦请互联驾驶人员跟进处理。 联系人/协议手机号码： 车主/联系人：杨先生 车架号：LBVKY5109JSP87501
1002_!_Carplay_!_ya_gao 2018-11-01 10:36:2018/11/1 7:53:14客户李先生（车主刘力先生的朋友）致电，表示其（2018.09.29）在（唐山中宝）购买了（525LI M），（10.09）登陆BMW云端互联APP，遇到（CARPLAY无显示、无法使用）问题，已建议客户发送截图至互联驾驶邮箱，针对此问题烦请跟进处理。未转接。协议手机号码：+86


### 使用分词

In [74]:
!pip install --upgrade pip
!pip install datetime

Requirement already up-to-date: pip in /home/ec2-user/anaconda3/envs/python3/lib/python3.6/site-packages (20.2.3)


In [81]:
import jieba.posseg as pseg

def with_stop_words_paddle_pseg(mystr, word_counts):
    #get stopwords
    stopwords = get_stopwords('chinesestopwords_test.txt')
    #启用停用词过滤
    no_phone = clear_email_phone_colon(mystr)
    no_timestamp = clear_timestamp(no_phone)
    fenci = re.sub(r"[\s+\.\!\/_,$%^*()?;；:-【】+\"\']+|[+——一！，;:：。？、~@#￥%……&*（）]+", "", no_timestamp)
    words = pseg.cut(fenci,use_paddle=True) #paddle模式+posseg
    filled_words = set()
    for word, flag in words:
        #print('raw %s %s' % (word, flag))
        # 人名, 地名，机构，方位名词, 量词, 代词, 时间, 副词
        if (flag == 'nr' or flag == 'ns' or flag == 'PER' 
            or flag == 'LOC' or flag == 'ORG' or flag == 'f' or flag == 'r' or flag == 'q'
            or flag == 't' or flag == 'TIME' or flag == 'd'):
            #print('%s %s' % (word, flag))
            continue
        else:
            if word not in stopwords:          #不在停用词表中
                if len(word) == 1:
                    continue
                else:
                    word_counts[word] = word_counts.get(word, 0) + 1
                filled_words.add(word)
    return filled_words


def without_stop_words_paddle_pseg(mystr, word_counts):
    #get stopwords
    stopwords = get_stopwords('chinesestopwords_test.txt')
    #启用停用词过滤
    no_phone = clear_email_phone_colon(mystr)
    no_timestamp = clear_timestamp(no_phone)
    fenci = re.sub(r"[\s+\.\!\/_,$%^*()?;；:-【】+\"\']+|[+——一！，;:：。？、~@#￥%……&*（）]+", "", no_timestamp)
    words = pseg.cut(fenci,use_paddle=True) #paddle模式+posseg
    filled_words = set()
    for word, flag in words:
        #print('raw %s %s' % (word, flag))
        # 人名, 地名，机构，方位名词, 量词, 代词, 时间, 副词
        if (flag == 'nr' or flag == 'ns' or flag == 'PER' 
            or flag == 'LOC' or flag == 'ORG' or flag == 'f' or flag == 'r' or flag == 'q'
            or flag == 't' or flag == 'TIME' or flag == 'd'):
            #print('%s %s' % (word, flag))
            continue
        else:
            if len(word) == 1:
                continue
            else:
                word_counts[word] = word_counts.get(word, 0) + 1
        filled_words.add(word)
    return filled_words

def without_stop_words_cut(mystr, word_counts): 
    #get stopwords
    stopwords = get_stopwords('chinesestopwords_test.txt')
    #启用停用词过滤
    no_phone = clear_email_phone_colon(mystr)
    no_timestamp = clear_timestamp(no_phone)
    fenci = re.sub(r"[\s+\.\!\/_,$%^*()?;；:-【】+\"\']+|[+——一！，;:：。？、~@#￥%……&*（）]+", "", no_timestamp)
    words = jieba.cut(fenci,cut_all=False)
    filled_words = set()
    for word in words:
        #print('raw %s %s' % (word, flag))
        if len(word) == 1:
            continue
        else:
            word_counts[word] = word_counts.get(word, 0) + 1
        filled_words.add(word)
    return filled_words


def with_stop_words_cut(mystr, word_counts): 
    #get stopwords
    stopwords = get_stopwords('chinesestopwords_test.txt')
    #启用停用词过滤
    no_phone = clear_email_phone_colon(mystr)
    no_timestamp = clear_timestamp(no_phone)
    fenci = re.sub(r"[\s+\.\!\/_,$%^*()?;；:-【】+\"\']+|[+——一！，;:：。？、~@#￥%……&*（）]+", "", no_timestamp)
    words = jieba.cut(fenci,cut_all=False)
    filled_words = set()
    for word in words:
        #print('raw %s %s' % (word, flag))
        # 人名, 地名，机构，方位名词, 量词, 代词, 时间, 副词
        if word not in stopwords:          #不在停用词表中
            if len(word) == 1:
                continue
            else:
                word_counts[word] = word_counts.get(word, 0) + 1
            filled_words.add(word)
    return filled_words


def without_stop_words_paddle(mystr, word_counts): 
    #get stopwords
    stopwords = get_stopwords('chinesestopwords_test.txt')
    #启用停用词过滤
    no_phone = clear_email_phone_colon(mystr)
    no_timestamp = clear_timestamp(no_phone)
    fenci = re.sub(r"[\s+\.\!\/_,$%^*()?;；:-【】+\"\']+|[+——一！，;:：。？、~@#￥%……&*（）]+", "", no_timestamp)
    words = jieba.cut(fenci,use_paddle=True) #paddle模式
    filled_words = set()
    for word in words:
        #print('raw %s %s' % (word, flag))
        if len(word) == 1:
            continue
        else:
            word_counts[word] = word_counts.get(word, 0) + 1
        filled_words.add(word)
    return filled_words


def with_stop_words_cut_paddle(mystr, word_counts):   
    #get stopwords
    stopwords = get_stopwords('chinesestopwords_test.txt')
    #启用停用词过滤
    no_phone = clear_email_phone_colon(mystr)
    no_timestamp = clear_timestamp(no_phone)
    fenci = re.sub(r"[\s+\.\!\/_,$%^*()?;；:-【】+\"\']+|[+——一！，;:：。？、~@#￥%……&*（）]+", "", no_timestamp)
    words = jieba.cut(fenci,cut_all=False,use_paddle=True) #paddle模式
    filled_words = set()
    for word in words:
        #print('raw %s %s' % (word, flag))
        # 人名, 地名，机构，方位名词, 量词, 代词, 时间, 副词
        if word not in stopwords:          #不在停用词表中
            if len(word) == 1:
                continue
            else:
                word_counts[word] = word_counts.get(word, 0) + 1
            filled_words.add(word)
    return filled_words

In [337]:
import datetime

#单机并行分词
#jieba.enable_parallel(8)
#paddle模式，精确匹配需要关闭parallel
jieba.disable_parallel()
#启动paddle模式
jieba.enable_paddle()

begin = datetime.datetime.now()
    
labels = []
counts = {}                 #计数{word，frequency}
for line in input_with_category:
    label = []
    line = line.split('_!_')
    if line[0].strip('"').isdigit() and index_to_label[line[0].strip('"')]:
        label_code = index_to_label[line[0].strip('"')]
        label.append('__label__' + label_code)
        
        words_after_jieba = with_stop_words_cut_paddle(line[2], counts)
        
        label.extend(words_after_jieba)
        #print(label)
        labels.append(label)

shuffle(labels)

end = datetime.datetime.now()
print('data processing time %d' %(end - begin).seconds)
print(labels[0:5])

items = list(counts.items())
items.sort(key=lambda x: x[1], reverse=True)
for i in range(30):
    word, count = items[i]
    print("{:<10}{:>7}".format(word, count))

Paddle enabled successfully......
2020-10-13 05:22:14,144:DEBUG:Paddle enabled successfully......


data processing time 9
[['__label__其他', '手机', '自动', '云端', '默认', '启动', '打开', '关闭', '互联'], ['__label__喜马拉雅', '只能', '喜马拉雅', '音乐样', '车上'], ['__label__LSC', '不到', '查询', '状态', '更新', '车辆', '互联'], ['__label__LSC', '云端', '更新', '位置', '互联'], ['__label__远程服务', '解锁']]
更新           1350
车辆           1185
定位            825
互联            630
信息            626
显示            597
远程            501
云端            401
功能            396
位置            274
手机            250
状态            237
驾驶            160
数据            149
成功            142
解锁            137
软件            133
车门            131
打开            131
系统            122
连接            121
车子            116
发送            105
实时            104
刷新             98
通风             95
解决             87
准确             86
启动             86
里程             85


In [338]:
prefix = 'classification/blazingtext/bmwticket'

In [339]:
t_train_data = labels[0:int(len(labels)*0.8)]
t_validation_data = labels[int(len(labels)*0.8):]

In [340]:
#t_train_data[0:13]

In [341]:
import csv
t_train_file = 'tt.train'
t_validation_file = 'tt.validation'

with open(t_train_file, 'w') as csvoutfile:
    csv_writer = csv.writer(csvoutfile, delimiter=' ', lineterminator='\n')
    csv_writer.writerows(t_train_data)
    
with open(t_validation_file, 'w') as csvoutfile:
    csv_writer = csv.writer(csvoutfile, delimiter=' ', lineterminator='\n')
    csv_writer.writerows(t_validation_data)

In [342]:
%%time

t_train_channel = prefix + '/train'
t_validation_channel = prefix + '/validation'

sess.upload_data(path='tt.train', bucket=bucket, key_prefix=t_train_channel)
sess.upload_data(path='tt.validation', bucket=bucket, key_prefix=t_validation_channel)

s3_train_data = 's3://{}/{}'.format(bucket, t_train_channel)
s3_validation_data = 's3://{}/{}'.format(bucket, t_validation_channel)

s3_output_location = 's3://{}/{}/output'.format(bucket, prefix)

CPU times: user 46.5 ms, sys: 0 ns, total: 46.5 ms
Wall time: 187 ms


In [343]:
s3_output_location = 's3://{}/{}/output'.format(bucket, prefix)

In [344]:
region_name = boto3.Session().region_name
container = sagemaker.amazon.amazon_estimator.get_image_uri(region_name, "blazingtext", "latest")
print('Using SageMaker BlazingText container: {} ({})'.format(container, region_name))



Using SageMaker BlazingText container: 390948362332.dkr.ecr.cn-north-1.amazonaws.com.cn/blazingtext:latest (cn-north-1)


In [345]:
t_bt_model = sagemaker.estimator.Estimator(container,
                                         role, 
                                         train_instance_count=1, 
                                         train_instance_type='ml.c4.4xlarge',
                                         train_volume_size = 120,
                                         train_max_run = 360000,
                                         input_mode= 'File',
                                         output_path=s3_output_location,
                                         sagemaker_session=sess)
t_bt_model.set_hyperparameters(
    mode="supervised",
    epochs=20,
    min_count=2,
    learning_rate=0.1,
    vector_dim=10,
    early_stopping=True,
    patience=4,
    min_epochs=5,
    word_ngrams=1,
    embedding=200)



In [346]:
t_train_data = sagemaker.inputs.s3_input(s3_train_data, distribution='FullyReplicated', 
                        content_type='text/plain', s3_data_type='S3Prefix')
t_validation_data = sagemaker.inputs.s3_input(s3_validation_data, distribution='FullyReplicated', 
                             content_type='text/plain', s3_data_type='S3Prefix')
t_data_channels = {'train': t_train_data, 'validation': t_validation_data}



In [347]:
t_bt_model.fit(inputs=t_data_channels, logs=True)

2020-10-13 05:58:20,165:INFO:Creating training-job with name: blazingtext-2020-10-13-05-58-20-165


2020-10-13 05:58:20 Starting - Starting the training job...
2020-10-13 05:58:23 Starting - Launching requested ML instances......
2020-10-13 05:59:24 Starting - Preparing the instances for training......
2020-10-13 06:00:48 Downloading - Downloading input data
2020-10-13 06:00:48 Training - Downloading the training image..[34mArguments: train[0m
[34m[10/13/2020 06:01:02 INFO 140380736489280] nvidia-smi took: 0.0252020359039 secs to identify 0 gpus[0m
[34m[10/13/2020 06:01:02 INFO 140380736489280] Running single machine CPU BlazingText training using supervised mode.[0m
[34m[10/13/2020 06:01:02 INFO 140380736489280] Processing /opt/ml/input/data/train/tt.train . File size: 0 MB[0m
[34m[10/13/2020 06:01:02 INFO 140380736489280] Processing /opt/ml/input/data/validation/tt.validation . File size: 0 MB[0m
[34mRead 0M words[0m
[34mNumber of words:  773[0m
[34mLoading validation data from /opt/ml/input/data/validation/tt.validation[0m
[34mLoaded validation data.[0m
[34m----

In [348]:
t_text_classifier = t_bt_model.deploy(initial_instance_count = 1,instance_type = 'ml.c5.large')

2020-10-13 06:03:27,193:INFO:Creating model with name: blazingtext-2020-10-13-05-58-20-165
2020-10-13 06:03:27,733:INFO:Creating endpoint with name blazingtext-2020-10-13-05-58-20-165


-------------!

In [360]:
predict_category_result_file = 'predict_category_result.csv'
low_confidence = [] # confidence<confidence_thredhold
confidence_thredhold = 0.8

if os.path.isfile(predict_category_result_file):
    os.remove(predict_category_result_file)

with open(predict_category_result_file, 'a+') as predict_f:
    predict_f.write('category' + '|' + 'comment' + '|' + 'confidence' + '\n')

# #单机并行分词
# jieba.enable_parallel(8)
# #paddle模式，精确匹配需要关闭parallel
# #jieba.disable_parallel()
# #启动paddle模式
# jieba.enable_paddle()

for sentences in input_without_category:
    counts = {} 
    # using the same nltk tokenizer that we used during data preparation for training
    tokenized_sentences = [' '.join(with_stop_words_cut(sentences, counts))]
    #print('sentences: ', sentences, ' , jieba: ', tokenized_sentences)

    #payload = {"instances" : tokenized_sentences, "configuration": {"k": 2}}
    payload = {"instances" : tokenized_sentences}

    t_response = t_text_classifier.predict(json.dumps(payload))

    t_predictions = json.loads(t_response)
    #print(json.dumps(t_predictions, indent=2))
    predict_prob = t_predictions[0]['prob'][0]
    predict_category = t_predictions[0]['label'][0]
    if predict_prob < confidence_thredhold:
        low_confidence.append(sentences + '|' + predict_category + '|' + str(predict_prob) + '\n')
    
    with open(predict_category_result_file, 'a+') as predict_f:
        predict_f.write(sentences + '|' + predict_category + '|' + str(predict_prob) + '\n')

predict_f.close
print('total predict: ', len(input_without_category), ' , > ', confidence_thredhold, ' confidence: ', len(input_without_category) - len(low_confidence))

total predict:  669  ,more than  0.8  confidence:  372


In [362]:
t_text_classifier.delete_endpoint()

2020-10-13 07:28:39,347:INFO:Deleting endpoint configuration with name: blazingtext-2020-10-13-05-58-20-165
2020-10-13 07:28:39,424:INFO:Deleting endpoint with name: blazingtext-2020-10-13-05-58-20-165
