# 对中文新闻题目进行分类

本示例使用头条客户端抓取的新闻题目分类，演示如何用Amazon Sagemaker内置算法BlazingText对新闻标题进行分类。

原数据集下载地址：https://github.com/skdjfla/toutiao-text-classfication-dataset

In [1]:
import sagemaker
from sagemaker import get_execution_role
import json
import boto3
from random import shuffle

sess = sagemaker.Session()

role = get_execution_role()
print(role) # This is the role that SageMaker would use to leverage AWS resources (S3, CloudWatch) on your behalf

bucket = 'ray-ai-ml' # Replace with your own bucket name if needed
print(bucket)
prefix = 'classification/blazingtext/supervised/toutiao' #Replace with the prefix under which you want to store the data if needed

arn:aws-cn:iam::876820548815:role/service-role/AmazonSageMaker-ExecutionRole-20200520T151303
ray-ai-ml


In [2]:
!pip install -i https://pypi.tuna.tsinghua.edu.cn/simple jieba

Looking in indexes: https://pypi.tuna.tsinghua.edu.cn/simple
[33mYou are using pip version 10.0.1, however version 20.2b1 is available.
You should consider upgrading via the 'pip install --upgrade pip' command.[0m


In [3]:
import jieba
import re
import os

In [4]:
index_to_label = {} 
with open("classes.txt") as f:
    for i,label in enumerate(f.readlines()):
        ll = label.strip().split(',')
        index_to_label[ll[0]] = ll[1]
        print(ll)
print(index_to_label)

['100', '民生']
['101', '文化']
['102', '娱乐']
['103', '体育']
['104', '财经']
['106', '房产']
['107', '汽车']
['108', '教育']
['109', '科技']
['110', '军事']
['112', '旅游']
['113', '国际']
['114', '证券']
['115', '农业']
['116', '电竞']
{'100': '民生', '101': '文化', '102': '娱乐', '103': '体育', '104': '财经', '106': '房产', '107': '汽车', '108': '教育', '109': '科技', '110': '军事', '112': '旅游', '113': '国际', '114': '证券', '115': '农业', '116': '电竞'}


### 使用分词

In [5]:
# define stopwords
def get_stopwords():  
    #加载停用词表  
    stopword_set = set()  
    with open("zhstopwords.txt",'r',encoding="utf-8") as stopwords:  
        for stopword in stopwords:  
            stopword_set.add(stopword.strip("\n"))  
    return stopword_set

# Parse chinese with stopwords
def parse_zh_words(read_file_path):
    
    #get stopwords
    stopword_set = get_stopwords()
#     i = 0
#     for x in stopword_set:
#         if i > 10:
#             break
#         print(x)
#         i += 1
    
    with open(read_file_path) as f:
        lines = f.readlines()
    
    labels = []
    for line in lines:
        label = []
        line = line.split('_!_')
        label_code = index_to_label[line[1]]
        label.append('__label__' + label_code)
        line[3] = re.sub(r"[\s+\.\!\/_,$%^*()?;；:-【】+\"\']+|[+——一！，;:：。？、~@#￥%……&*（）]+", "", line[3])
        words = jieba.cut(line[3],cut_all=False)
    
        filled_words = set()
        for word in words:
            if word not in stopword_set:
                filled_words.add(word)

        label.extend(filled_words)
        #print(label)
        labels.append(label)
    
    shuffle(labels)
    return labels
    
file  = 'toutiao_cat_data.txt'#'toutiao_cat_data.txt'
labels = parse_zh_words(file)
    
shuffle(labels)
print(labels[0:5])

Building prefix dict from the default dictionary ...
Loading model from cache /tmp/jieba.cache
Loading model cost 0.871 seconds.
Prefix dict has been built successfully.


[['__label__军事', '929', '波音', '俄', '大客机', '中', '787', '年', '对标', '宽体', '首飞', '2025'], ['__label__汽车', '谍照', '路试', '新款', '道奇'], ['__label__教育', '语文', '挂', '年', '142', '力荐', '保管', '董卿', '成语', '墙上', '打印', '背熟', '1500', '孩子'], ['__label__科技', '跨境', '运动', '下波', '造富', '商会', '电'], ['__label__体育', '常规赛', '火箭', '难', '惨不忍睹', '复苏', '戈登', '对比', '数据', '季后赛']]


In [7]:
prefix = 'classification/blazingtext/supervised/toutiao'

In [8]:
t_train_data = labels[0:int(len(labels)*0.8)]
t_validation_data = labels[int(len(labels)*0.8):]

In [9]:
t_train_data[0:13]

[['__label__国际', '时', '不好', '意义', '驱逐', '对方', '外交官', '两国关系'],
 ['__label__财经', '生物科技', '股有', '股中'],
 ['__label__财经', '茅台镇', '喝', '茅台酒', '老百姓'],
 ['__label__房产', '买房子', '去', '莱山区', '更好', '开发区', '烟台'],
 ['__label__电竞', '遇挫', '模仿', '抄袭', '游戏', '频频', '自主', '腾讯', '开发'],
 ['__label__农业',
  '有人',
  '农村',
  '封杀',
  '惨',
  '成线',
  '走出',
  '明星',
  '霸屏',
  '大',
  '却',
  '小花',
  '多年'],
 ['__label__汽车', '利用', '农村', '喇叭', '活动', '交警', '大', '教育', '交通安全', '临漳'],
 ['__label__汽车',
  '几招',
  '车身',
  '抖动',
  '汽车',
  '决不能',
  '忽视',
  '放出',
  '危险',
  '教',
  '找到',
  '信号',
  '这是',
  '原因'],
 ['__label__体育', '地上', '球员', '在场', '足球', '唾沫', '都', '高强度', '运动', '吐', '运动员'],
 ['__label__文化', '考述', '毛氏', '刻本', '汲古阁', '特色', '价值'],
 ['__label__电竞', '游', '总', '换部', '玩手', '像样', '不爽', '手机'],
 ['__label__军事', '时', '美国', '投', '没用', '直', '原子弹', '越南战争', '胜利'],
 ['__label__房产', '不涨', '房价', '涨工资', '涨']]

In [10]:
import csv
t_train_file = 'tt.train'
t_validation_file = 'tt.validation'

with open(t_train_file, 'w') as csvoutfile:
    csv_writer = csv.writer(csvoutfile, delimiter=' ', lineterminator='\n')
    csv_writer.writerows(t_train_data)
    
with open(t_validation_file, 'w') as csvoutfile:
    csv_writer = csv.writer(csvoutfile, delimiter=' ', lineterminator='\n')
    csv_writer.writerows(t_validation_data)

In [11]:
%%time

t_train_channel = prefix + '/train'
t_validation_channel = prefix + '/validation'

sess.upload_data(path='tt.train', bucket=bucket, key_prefix=t_train_channel)
sess.upload_data(path='tt.validation', bucket=bucket, key_prefix=t_validation_channel)

s3_train_data = 's3://{}/{}'.format(bucket, t_train_channel)
s3_validation_data = 's3://{}/{}'.format(bucket, t_validation_channel)

s3_output_location = 's3://{}/{}/output'.format(bucket, prefix)

CPU times: user 294 ms, sys: 67.2 ms, total: 361 ms
Wall time: 681 ms


In [12]:
s3_output_location = 's3://{}/{}/output'.format(bucket, prefix)

In [13]:
region_name = boto3.Session().region_name
container = sagemaker.amazon.amazon_estimator.get_image_uri(region_name, "blazingtext", "latest")
print('Using SageMaker BlazingText container: {} ({})'.format(container, region_name))

Using SageMaker BlazingText container: 387376663083.dkr.ecr.cn-northwest-1.amazonaws.com.cn/blazingtext:latest (cn-northwest-1)


In [14]:
t_bt_model = sagemaker.estimator.Estimator(container,
                                         role, 
                                         train_instance_count=1, 
                                         train_instance_type='ml.c4.4xlarge',
                                         train_volume_size = 30,
                                         train_max_run = 360000,
                                         input_mode= 'File',
                                         output_path=s3_output_location,
                                         sagemaker_session=sess)
t_bt_model.set_hyperparameters(mode="supervised",
                            epochs=10,
                            min_count=2,
                            learning_rate=0.05,
                            vector_dim=10,
                            early_stopping=True,
                            patience=4,
                            min_epochs=5,
                            word_ngrams=2)

In [15]:
t_train_data = sagemaker.inputs.s3_input(s3_train_data, distribution='FullyReplicated', 
                        content_type='text/plain', s3_data_type='S3Prefix')
t_validation_data = sagemaker.inputs.s3_input(s3_validation_data, distribution='FullyReplicated', 
                             content_type='text/plain', s3_data_type='S3Prefix')
t_data_channels = {'train': t_train_data, 'validation': t_validation_data}

In [16]:
t_bt_model.fit(inputs=t_data_channels, logs=True)

2020-06-19 07:30:34 Starting - Starting the training job...
2020-06-19 07:30:38 Starting - Launching requested ML instances......
2020-06-19 07:31:43 Starting - Preparing the instances for training......
2020-06-19 07:32:56 Downloading - Downloading input data...
2020-06-19 07:33:31 Training - Training image download completed. Training in progress..[34mArguments: train[0m
[34m[06/19/2020 07:33:31 INFO 140138826430272] nvidia-smi took: 0.0251858234406 secs to identify 0 gpus[0m
[34m[06/19/2020 07:33:31 INFO 140138826430272] Running single machine CPU BlazingText training using supervised mode.[0m
[34m[06/19/2020 07:33:31 INFO 140138826430272] Processing /opt/ml/input/data/train/tt.train . File size: 20 MB[0m
[34m[06/19/2020 07:33:31 INFO 140138826430272] Processing /opt/ml/input/data/validation/tt.validation . File size: 5 MB[0m
[34mRead 2M words[0m
[34mNumber of words:  79788[0m
[34mLoading validation data from /opt/ml/input/data/validation/tt.validation[0m
[34mLoaded

In [17]:
t_text_classifier = t_bt_model.deploy(initial_instance_count = 1,instance_type = 'ml.m4.2xlarge')

-------------!

In [24]:
sentences = "亚马逊云计算Q1营收过百亿"#"宝马推出新车型/亚马逊云计算Q1营收过百亿/美国航空母舰开往伊朗波斯湾/北京迎来黄金周小高峰/C罗纳尔多还是梅西/邓超加入春晚/新冠肺炎全球蔓延急需疫苗/谁是最可爱的人"

# using the same nltk tokenizer that we used during data preparation for training
tokenized_sentences = [' '.join(jieba.cut(sentences,cut_all=False))]

payload = {"instances" : tokenized_sentences}

# payload = {"instances" : tokenized_sentences,
#           "configuration": {"k": 2}}

t_response = t_text_classifier.predict(json.dumps(payload))

t_predictions = json.loads(t_response)
print(json.dumps(t_predictions, indent=2))
t_predictions[0]['label']

[
  {
    "prob": [
      0.9997926354408264
    ],
    "label": [
      "__label__\u79d1\u6280"
    ]
  }
]


['__label__科技']

In [25]:
# Tears down the SageMaker endpoint and endpoint configuration
t_text_classifier.delete_endpoint()