In [1]:
import os
import sys
import torch
from transformers import BertTokenizerFast
import lightning as pl

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
sys.path.append(os.path.abspath(os.path.join(os.getcwd(), os.pardir)))  
pl.seed_everything(42, workers=True)
torch.set_float32_matmul_precision(precision="high")

Seed set to 42


In [3]:
args = {
    "pretrain": '/home/zhulin/pretrain/bert_pretrain_uncased/',
    "model": "./SingleChannelPredictor.pt",
    "dataset": "/home/zhulin/datasets/cdatasets.test.5.csv"
}


In [11]:
### load model
from core.predictor import SingleChannelPredictor
tokenizer = BertTokenizerFast.from_pretrained(args["pretrain"], use_fast=True)
predictor = torch.jit.load(args["model"])

In [5]:
### load datasets
import numpy as np
import pandas as pd
import datatable as dt

data = dt.fread(args["dataset"], fill=True, max_nrows=1024).to_pandas()

In [9]:
@torch.no_grad()
def interface(tokenizer, predictor, data, batchsize):
    n = len(data)
    for l in range(0, n, batchsize):
        padded_sent_seq = tokenizer(data.iloc[l:l+batchsize]["channel"].to_list(), padding=True, truncation=True, max_length=2048, return_tensors="pt")
        pred = predictor(padded_sent_seq["input_ids"].cuda(), padded_sent_seq["attention_mask"].cuda())


In [12]:
from loguru import logger
from tqdm import tqdm

predictor.cuda().eval()
# 预热, GPU 平时可能为了节能而处于休眠状态, 因此需要预热
logger.info('[+] warm up ...\n')
with torch.no_grad():
    for _ in range(10):
        # _ = predictor(dummy_input)
        interface(tokenizer, predictor, data, 8)
torch.cuda.synchronize()

# 设置用于测量时间的 cuda Event, 这是PyTorch 官方推荐的接口,理论上应该最靠谱
starter, ender = torch.cuda.Event(enable_timing=True), torch.cuda.Event(enable_timing=True)
# 初始化一个时间容器
timings = np.zeros((100, 1))

logger.info('testing ...\n')
with torch.no_grad():
    for rep in tqdm(range(100)):
        starter.record()
        interface(tokenizer, predictor, data, 8)
        ender.record()
        torch.cuda.synchronize() # 等待GPU任务完成
        curr_time = starter.elapsed_time(ender) # 从 starter 到 ender 之间用时,单位为毫秒
        timings[rep] = curr_time

avg = timings.sum()/100
logger.info('\navg={}\n'.format(avg))

[32m2024-09-12 16:24:55.718[0m | [1mINFO    [0m | [36m__main__[0m:[36m<module>[0m:[36m6[0m - [1m[+] warm up ...
[0m
[32m2024-09-12 16:25:16.175[0m | [1mINFO    [0m | [36m__main__[0m:[36m<module>[0m:[36m18[0m - [1mtesting ...
[0m
100%|██████████| 100/100 [03:16<00:00,  1.97s/it]
[32m2024-09-12 16:28:32.746[0m | [1mINFO    [0m | [36m__main__[0m:[36m<module>[0m:[36m29[0m - [1m
avg=1964.0056689453124
[0m


In [13]:
from torch.profiler import profile, record_function, ProfilerActivity

logger.info('[+] warm up ...\n')
with torch.no_grad():
    for _ in range(10):
        # _ = predictor(dummy_input)
        interface(tokenizer, predictor, data, 8)
torch.cuda.synchronize()

with profile(activities=[ProfilerActivity.CPU], record_shapes=True) as prof:
    with record_function("model_inference"):
        interface(tokenizer, predictor, data, 8)
print(prof.key_averages().table(sort_by="cpu_time_total", row_limit=10))

[32m2024-09-12 16:31:29.442[0m | [1mINFO    [0m | [36m__main__[0m:[36m<module>[0m:[36m3[0m - [1m[+] warm up ...
[0m
STAGE:2024-09-12 16:31:49 38310:38310 ActivityProfilerController.cpp:312] Completed Stage: Warm Up
STAGE:2024-09-12 16:31:53 38310:38310 ActivityProfilerController.cpp:318] Completed Stage: Collection
STAGE:2024-09-12 16:31:53 38310:38310 ActivityProfilerController.cpp:322] Completed Stage: Post Processing


-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                                   Name    Self CPU %      Self CPU   CPU total %     CPU total  CPU time avg    # of Calls  
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                        model_inference        41.39%        1.388s       100.00%        3.353s        3.353s             1  
                                                forward         9.68%     324.656ms        54.26%        1.819s       1.776ms          1024  
                                           aten::linear         2.23%      74.924ms        14.30%     479.448ms      66.887us          7168  
                                            aten::addmm         7.36%     246.771ms         9.70%     325.220ms      45.371us          7168  
      