In [1]:
# Route HTTP(S) traffic from this kernel through the proxy.
# NOTE: `!export VAR=...` runs in a short-lived subshell, so the variable never
# reaches the kernel process; setting os.environ (or using `%env`) is required
# for clients running inside this kernel to see it.
import os

PROXY_URL = "http://10.12.44.139:7890"
os.environ["http_proxy"] = PROXY_URL
os.environ["https_proxy"] = PROXY_URL


In [2]:
import torch
from transformers import AutoTokenizer
from torch.utils.data import Dataset, DataLoader

# Toy corpus: each sentence is paired with a binary label
# (1 = positive review, 0 = negative review, per the author's convention).
# NOTE(review): "Programming is fun" is labelled 0 (negative) — confirm this is intended.
texts, labels = map(list, zip(
    ("I love programming in Python", 1),
    ("Python is a great language", 1),
    ("Programming is fun", 0),
))

# Load the pretrained cased-BERT tokenizer from a local checkpoint directory
# (the relative path is resolved against the kernel's working directory).
tokenizer = AutoTokenizer.from_pretrained(
    "../.pretrained_models/bert-base-cased", clean_up_tokenization_spaces=True,
)


# Dataset class definition
class TextDataset(Dataset):
    """Map-style dataset pairing raw texts with labels, tokenized on access.

    Each item is tokenized lazily in ``__getitem__`` with truncation and
    padding to a fixed ``max_length``, so that samples can be stacked into a
    batch by the DataLoader's default collate function.

    Args:
        texts: Indexable sequence of input strings.
        labels: Indexable sequence of labels, one per text.
        tokenizer: Callable tokenizer (e.g. a Hugging Face tokenizer) accepting
            ``truncation``, ``padding``, ``max_length`` and ``return_tensors``
            keyword arguments and returning a mapping with ``'input_ids'`` and
            ``'attention_mask'`` tensors.
        max_length: Sequence length to truncate/pad every sample to.

    Raises:
        ValueError: If ``texts`` and ``labels`` differ in length.
    """

    def __init__(self, texts, labels, tokenizer, max_length=10):
        # Fail fast: a mismatch would otherwise either silently drop surplus
        # labels or raise an IndexError later, deep inside DataLoader iteration.
        if len(texts) != len(labels):
            raise ValueError(
                "texts and labels must have the same length, "
                f"got {len(texts)} and {len(labels)}"
            )
        self.texts = texts
        self.labels = labels
        self.tokenizer = tokenizer
        self.max_length = max_length

    def __len__(self):
        """Return the number of samples (one per text)."""
        return len(self.texts)

    def __getitem__(self, idx):
        """Tokenize sample ``idx``.

        Returns:
            Tuple ``(input_ids, attention_mask, label)``. The first two are the
            tokenizer outputs with the leading size-1 dimension removed
            (expected length ``max_length`` given ``padding='max_length'`` and
            ``truncation=True``); ``label`` is a float32 scalar tensor.
        """
        text = self.texts[idx]
        label = self.labels[idx]

        # Tokenize and encode a single text, padded/truncated to a fixed length.
        encoding = self.tokenizer(
            text,
            truncation=True,
            padding='max_length',
            max_length=self.max_length,
            return_tensors='pt'
        )

        # return_tensors='pt' adds a leading batch dim of size 1; drop it so the
        # DataLoader can add its own batch dimension when collating.
        input_ids = encoding['input_ids'].squeeze(0)
        attention_mask = encoding['attention_mask'].squeeze(0)

        # Float label — presumably for a BCE-style loss; confirm against training code.
        return input_ids, attention_mask, torch.tensor(label, dtype=torch.float)


# Build the dataset.
dataset = TextDataset(texts, labels, tokenizer)

# Build the data loader. A seeded generator makes the shuffle order — and thus
# the batch printed below — reproducible under Restart & Run All.
batch_size = 2
SEED = 42
data_loader = DataLoader(
    dataset,
    batch_size=batch_size,
    shuffle=True,
    generator=torch.Generator().manual_seed(SEED),
)

# Inspect a single batch from the loader.
batch = next(iter(data_loader))
input_ids, attention_mask, targets = batch
print("Input IDs:", input_ids)
print("Attention Mask:", attention_mask)
print("Targets:", targets)

# Size of the tokenizer vocabulary, as reported by len(tokenizer).
print("Vocabulary size:", len(tokenizer))


Input IDs: tensor([[  101, 21076,  1110,  4106,   102,     0,     0,     0,     0,     0],
        [  101, 23334,  1110,   170,  1632,  1846,   102,     0,     0,     0]])
Attention Mask: tensor([[1, 1, 1, 1, 1, 0, 0, 0, 0, 0],
        [1, 1, 1, 1, 1, 1, 1, 0, 0, 0]])
Targets: tensor([0., 1.])
Vocabulary size: 28996
