In [1]:
import csv
import torch
from transformers import AutoTokenizer, AutoModel
from tqdm import tqdm


  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Choose the compute device first: GPU when CUDA is available, CPU otherwise
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Load the pretrained Chinese BERT tokenizer and encoder.
# eval() puts the module in inference mode (dropout layers become no-ops)
# and returns the module itself, so it can be chained onto the load.
tokenizer = AutoTokenizer.from_pretrained("bert-base-chinese")
model = AutoModel.from_pretrained("bert-base-chinese").eval()

# Move the weights to the chosen device; as the last expression of the cell
# this also displays the model architecture.
model.to(device)


BertModel(
  (embeddings): BertEmbeddings(
    (word_embeddings): Embedding(21128, 768, padding_idx=0)
    (position_embeddings): Embedding(512, 768)
    (token_type_embeddings): Embedding(2, 768)
    (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
    (dropout): Dropout(p=0.1, inplace=False)
  )
  (encoder): BertEncoder(
    (layer): ModuleList(
      (0): BertLayer(
        (attention): BertAttention(
          (self): BertSelfAttention(
            (query): Linear(in_features=768, out_features=768, bias=True)
            (key): Linear(in_features=768, out_features=768, bias=True)
            (value): Linear(in_features=768, out_features=768, bias=True)
            (dropout): Dropout(p=0.1, inplace=False)
          )
          (output): BertSelfOutput(
            (dense): Linear(in_features=768, out_features=768, bias=True)
            (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
            (dropout): Dropout(p=0.1, inplace=False)
          

In [None]:
# Preview the file: print the split fields of each line read and collect the
# content field of the first 5 lines that have one.
file_path = '../WeiboData/weibo_predict_data.txt'
contents = []
with open(file_path, 'r', encoding='utf-8') as f:
    for line in f:
        parts = line.strip().split('\t')
        print(parts)
        # Fields are uid, mid, time, content. A line with fewer than four
        # fields has no content to collect; indexing parts[3] on it would
        # raise IndexError, so it is only printed and then skipped.
        if len(parts) < 4:
            continue
        contents.append(parts[3])
        if len(contents) >= 5:
            break

['c01014739c046cd31d6f1b4fb71b440f', '0cd5ef13eb11ed0070f7625b14136ec9', '2015-08-19 22:44:55', 'Xah Emacs Tutorial http://t.cn/zWoY9IZ']
['fa5aed172c062c61e196eac61038a03b', '7cce78a4ad39a91ec1f595bcc7fb5eba', '2015-08-01 14:06:31', '卖水果老人因没住处夜宿酒店门口 被车碾死 http://t.cn/RL0Hw8J （分享自@凤凰新闻客户端）']
['77fc723c196a45203e70f4d359c96946', 'a3494d8cf475a92739a2ffd421640ddf', '2015-08-04 10:51:38', '不要学习没有用的理论？ 不是：要学习，但要知道这个理论，为什么没有用？真实有用的理论是什么。 像需求定律就是一个没有什么用的理论，没有用的原因是它建立在“意图需求”，建立在完全竞争、完全信息这种现实真空的条件下， 无法对“真实需求”作出解释，如中国的房价、少林寺的香的价格']
['e4097b07f34366399b623b94f174f60c', '6b89aea5aa7af093dde0894156c49dd3', '2015-08-16 14:59:19', '[幸运之星] -恭喜！您的新浪微博账号已被系统确认为“新浪五周年”活动二等奖幸运用户。请登陆:http://t.cn/RLrKzhO查收。[礼物]   @苏友朋xiaoyu @王玮Qten[吐]rZp8w']
['d43f7557c303b84070b13aa4eeeb21d3', '0bdeff19392e15737775abab46dc5437', '2015-08-04 22:30:46', '【Lennart Poettering宣布首届Systemd会议】受争议Linux初始化系统和服务管理器Systemd的创始人Lennart Poettering宣布了首届Systemd会议。systemd.conf将于11月5日到7日在 ... http://t.cn/RLjj4UX']


In [8]:
# Display the list of post texts currently held in `contents`
contents

['Xah Emacs Tutorial http://t.cn/zWoY9IZ',
 '卖水果老人因没住处夜宿酒店门口 被车碾死 http://t.cn/RL0Hw8J （分享自@凤凰新闻客户端）',
 '不要学习没有用的理论？ 不是：要学习，但要知道这个理论，为什么没有用？真实有用的理论是什么。 像需求定律就是一个没有什么用的理论，没有用的原因是它建立在“意图需求”，建立在完全竞争、完全信息这种现实真空的条件下， 无法对“真实需求”作出解释，如中国的房价、少林寺的香的价格',
 '[幸运之星] -恭喜！您的新浪微博账号已被系统确认为“新浪五周年”活动二等奖幸运用户。请登陆:http://t.cn/RLrKzhO查收。[礼物]   @苏友朋xiaoyu @王玮Qten[吐]rZp8w',
 '【Lennart Poettering宣布首届Systemd会议】受争议Linux初始化系统和服务管理器Systemd的创始人Lennart Poettering宣布了首届Systemd会议。systemd.conf将于11月5日到7日在 ... http://t.cn/RLjj4UX']

In [None]:
# File paths: the tab-separated input posts and the CSV that receives the features
input_file, output_file = (
    '../WeiboData/weibo_predict_data.txt',
    '../features/weibo_predict_bert_features.csv',
)
# Parameters: number of texts per forward pass, and the token cap passed to the
# tokenizer (512 matches the size of the model's position-embedding table)
batch_size, max_length = 32, 512


In [10]:
# Batch the texts to speed up inference: these lists accumulate the ids and
# texts of the batch currently being filled.
uids = []
mids = []
contents = []
batch = []

def write_batch(writer, uids, mids, contents):
    """Encode one batch of texts with the BERT model and write one row per text.

    Each written row is ``[uid, mid, feature]``, where ``feature`` is the vector
    at sequence position 0 (the [CLS] token) of the model's last hidden state,
    serialised as space-separated numbers.

    Uses the module-level ``tokenizer``, ``model``, ``device`` and ``max_length``.

    Args:
        writer: object exposing ``writerow(list)``, e.g. a ``csv.writer``.
        uids: user ids, one per text.
        mids: post ids, one per text.
        contents: the texts to encode.

    Raises:
        ValueError: if ``uids``, ``mids`` and ``contents`` differ in length.
            Without this check ``zip`` would silently drop or misalign rows.
    """
    if not (len(uids) == len(mids) == len(contents)):
        raise ValueError(
            "uids, mids and contents must have the same length, got "
            "{}, {} and {}".format(len(uids), len(mids), len(contents))
        )
    # Nothing to encode: avoid handing an empty batch to the tokenizer/model.
    if not contents:
        return

    # Pad to the longest text in the batch and truncate anything beyond max_length.
    inputs = tokenizer(contents, return_tensors="pt", padding=True,
                       truncation=True, max_length=max_length)
    # The input tensors must live on the same device as the model.
    inputs = {k: v.to(device) for k, v in inputs.items()}

    # Inference only: no gradients are needed.
    with torch.no_grad():
        outputs = model(**inputs)
        cls_vectors = outputs.last_hidden_state[:, 0, :]  # (batch, hidden_size)

    # Bring the vectors back to host memory before serialising them.
    cls_vectors = cls_vectors.cpu().numpy()
    for uid, mid, vec in zip(uids, mids, cls_vectors):
        vec_str = ' '.join(map(str, vec))
        writer.writerow([uid, mid, vec_str])


In [16]:
# Stream the input file, encode it batch by batch and write one feature row per post.
with open(output_file, 'w', newline='', encoding='utf-8') as csvfile:
    writer = csv.writer(csvfile)
    writer.writerow(['uid', 'mid', 'content_feature'])

    # Start every run from empty accumulators so the cell is re-runnable on its
    # own: items left over from a previous run (or from another cell that
    # reuses the name `contents`) must never leak into the output.
    uids, mids, contents = [], [], []

    with open(input_file, 'r', encoding='utf-8') as f:
        for line_no, line in enumerate(tqdm(f, desc="Processing"), start=1):
            # Only the line terminator is removed before splitting, so an empty
            # content field is kept as a field. maxsplit=3 keeps any tab inside
            # the content as part of the content instead of cutting it off.
            fields = line.rstrip('\r\n').split('\t', 3)
            if len(fields) < 3:
                if not line.strip():
                    continue  # blank line, nothing to encode
                raise ValueError(
                    f"{input_file}:{line_no}: expected tab-separated "
                    f"uid, mid, time, content but got {len(fields)} field(s)"
                )
            uid, mid = fields[0], fields[1]
            # A missing content field is encoded as an empty text so the post
            # still gets a row in the output.
            content = fields[3].rstrip() if len(fields) > 3 else ''
            uids.append(uid)
            mids.append(mid)
            contents.append(content)

            if len(contents) >= batch_size:
                write_batch(writer, uids, mids, contents)
                uids, mids, contents = [], [], []

        # Flush the final, partially filled batch.
        if contents:
            write_batch(writer, uids, mids, contents)

Processing: 178297it [13:40, 217.25it/s]


- predict: 178,298
- train: 1,229,619