In [1]:
import torch
from datasets import load_from_disk

In [5]:
# 定义数据集
class Dataset(torch.utils.data.Dataset):
    def __init__(self, split):
        self.datasets = load_from_disk('../data/ChnSentiCorp')
        self.dataset = self.datasets[split]
    
    def __len__(self):
        return len(self.dataset)
    
    def __getitem__(self, i):
        text = self.dataset[i]['text']
        label = self.dataset[i]['label']
        return text, label
    
dataset = Dataset('train')
dataset

<__main__.Dataset at 0x2cf4f74fb08>

In [6]:
dataset.dataset

Dataset({
    features: ['text', 'label'],
    num_rows: 9600
})

In [7]:
len(dataset)

9600

In [8]:
dataset[0]

('选择珠江花园的原因就是方便，有电动扶梯直接到达海边，周围餐馆、食廊、商场、超市、摊位一应俱全。酒店装修一般，但还算整洁。 泳池在大堂的屋顶，因此很小，不过女儿倒是喜欢。 包的早餐是西式的，还算丰富。 服务吗，一般',
 1)

In [9]:
from transformers import BertTokenizer

# 加载字典和分词工具
tokenizer = BertTokenizer.from_pretrained('bert-base-chinese')
tokenizer

BertTokenizer(name_or_path='bert-base-chinese', vocab_size=21128, model_max_length=512, is_fast=False, padding_side='right', truncation_side='right', special_tokens={'unk_token': '[UNK]', 'sep_token': '[SEP]', 'pad_token': '[PAD]', 'cls_token': '[CLS]', 'mask_token': '[MASK]'})

In [10]:
# 重写collate_fn
def collate_fn(data):
    sents = [i[0] for i in data]
    labels = [i[1] for i in data]
    
    # 编码
    data = tokenizer.batch_encode_plus(batch_text_or_text_pairs=sents, 
                               truncation=True,
                               padding='max_length',
                               return_tensors='pt',
                               return_length='True')
    # 编码之后的数字
    input_ids = data['input_ids']
    attention_mask = data['attention_mask']
    token_type_ids = data['token_type_ids']
    labels = torch.LongTensor(labels)
    
    return input_ids, attention_mask, token_type_ids, labels

In [11]:
# 数据加载器
loader = torch.utils.data.DataLoader(
    dataset=dataset,
    batch_size=16,
    collate_fn=collate_fn,
    shuffle=True,
    drop_last=True
)

for i,(input_ids, attention_mask, token_type_ids, labels) in enumerate(loader):
    break

In [12]:
print(len(loader))

600


In [15]:
attention_mask

tensor([[1, 1, 1,  ..., 0, 0, 0],
        [1, 1, 1,  ..., 0, 0, 0],
        [1, 1, 1,  ..., 0, 0, 0],
        ...,
        [1, 1, 1,  ..., 0, 0, 0],
        [1, 1, 1,  ..., 0, 0, 0],
        [1, 1, 1,  ..., 0, 0, 0]])

In [16]:
print(input_ids.shape, attention_mask.shape, token_type_ids.shape, labels)

torch.Size([16, 512]) torch.Size([16, 512]) torch.Size([16, 512]) tensor([0, 0, 0, 1, 0, 0, 0, 1, 0, 1, 1, 1, 0, 1, 1, 0])


In [17]:
from transformers import BertModel

# 加载预训练模型
pretrained = BertModel.from_pretrained('bert-base-chinese')

# 固定bert的参数
for param in pretrained.parameters():
    param.requires_grad_(False)
    

Downloading (…)"pytorch_model.bin";:   0%|          | 0.00/412M [00:00<?, ?B/s]

To support symlinks on Windows, you either need to activate Developer Mode or to run Python as an administrator. In order to see activate developer mode, see this article: https://docs.microsoft.com/en-us/windows/apps/get-started/enable-your-device-for-development
Some weights of the model checkpoint at bert-base-chinese were not used when initializing BertModel: ['cls.seq_relationship.weight', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.decoder.weight', 'cls.seq_relationship.bias', 'cls.predictions.transform.dense.weight', 'cls.predictions.transform.dense.bias', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.bias']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exact

In [18]:
# 模型试算
out = pretrained(input_ids=input_ids, attention_mask=attention_mask, token_type_ids=token_type_ids)

In [22]:
out.last_hidden_state[:, 0].shape

torch.Size([16, 768])

In [24]:
# 定义下游任务模型
class Model(torch.nn.Module):
    def __init__(self, pretrained):
        super().__init__()
        self.fc = torch.nn.Linear(768, 2)
        self.pretrained = pretrained
        
    def forward(self, input_ids, attention_mask, token_type_ids):
        with torch.no_grad():
            out = pretrained(input_ids=input_ids, attention_mask=attention_mask, token_type_ids=token_type_ids)
            
        out = self.fc(out.last_hidden_state[:, 0])
        out = out.softmax(dim=1)
        return out
    
model = Model(pretrained)
model(input_ids=input_ids, attention_mask=attention_mask, token_type_ids=token_type_ids).shape

torch.Size([16, 2])

In [28]:
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
device

device(type='cuda', index=0)

In [33]:
model.device

AttributeError: 'Model' object has no attribute 'device'

In [38]:
from transformers import AdamW

# 训练
optimizer = AdamW(model.parameters(), lr=5e-4)

criterion = torch.nn.CrossEntropyLoss()

model = Model(pretrained)
model.train()
model.to(device)
for i, (input_ids, attention_mask, token_type_ids, labels) in enumerate(loader):
    
    input_ids = input_ids.to(device)
    attention_mask = attention_mask.to(device)
    token_type_ids = token_type_ids.to(device)
    labels = labels.to(device)
    
    out = model(input_ids=input_ids, attention_mask=attention_mask, token_type_ids=token_type_ids)
    
    loss = criterion(out, labels)
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()
    
    if i % 5 == 0:
        out = out.argmax(dim=1)
        accuracy = (out == labels).sum().item() / len(labels)
        print(i, loss.item(), accuracy)
        
    if i == 600:
        break

0 0.6827474236488342 0.5625
5 0.662400484085083 0.625
10 0.6865236759185791 0.5625
15 0.7083545923233032 0.375
20 0.6344900131225586 0.75
25 0.658557653427124 0.625
30 0.6382225155830383 0.8125
35 0.715989351272583 0.375
40 0.7118325233459473 0.5
45 0.6845795512199402 0.625
50 0.685475766658783 0.5
55 0.6963281035423279 0.4375
60 0.6833534240722656 0.5
65 0.6950926780700684 0.4375
70 0.72159743309021 0.5
75 0.6438045501708984 0.75
80 0.7265025973320007 0.3125
85 0.7005784511566162 0.5
90 0.72879958152771 0.4375
95 0.7319801449775696 0.3125
100 0.6244087219238281 0.8125
105 0.6907938718795776 0.5625
110 0.728121817111969 0.4375
115 0.6624838709831238 0.5625
120 0.6880730986595154 0.5
125 0.7138139009475708 0.5
130 0.6859719157218933 0.625
135 0.6897076964378357 0.5
140 0.703993558883667 0.4375
145 0.6741662621498108 0.625
150 0.6896845102310181 0.5625
155 0.6807734966278076 0.5
160 0.6992517709732056 0.375
165 0.6861604452133179 0.5625
170 0.6964953541755676 0.5
175 0.7046523690223694 0

In [40]:
# 测试
def test():
    model.eval()
    correct = 0
    total = 0
    
    loader_test = torch.utils.data.DataLoader(dataset=Dataset('validation'),
                                             batch_size=32,
                                             collate_fn=collate_fn,
                                             shuffle=True,
                                             drop_last=True)
    for i, (input_ids, attention_mask, token_type_ids, labels) in enumerate(loader_test):
        input_ids = input_ids.to(device)
        attention_mask = attention_mask.to(device)
        token_type_ids = token_type_ids.to(device)
        labels = labels.to(device)
        if i == 5:
            break
        print(i)
        with torch.no_grad():
            out = model(input_ids=input_ids,
                       attention_mask=attention_mask,
                       token_type_ids=token_type_ids)
        out = out.argmax(dim=1)
        correct += (out==labels).sum().item()
        total += len(labels)
    print(correct / total)
test()

0
1
2
3
4
0.58125
