In [1]:
from transformers import AutoTokenizer

# 记载分词器
tokenizer = AutoTokenizer.from_pretrained('t5-small')
print(tokenizer)

T5TokenizerFast(name_or_path='t5-small', vocab_size=32100, model_max_length=512, is_fast=True, padding_side='right', truncation_side='right', special_tokens={'eos_token': '</s>', 'unk_token': '<unk>', 'pad_token': '<pad>', 'additional_special_tokens': ['<extra_id_0>', '<extra_id_1>', '<extra_id_2>', '<extra_id_3>', '<extra_id_4>', '<extra_id_5>', '<extra_id_6>', '<extra_id_7>', '<extra_id_8>', '<extra_id_9>', '<extra_id_10>', '<extra_id_11>', '<extra_id_12>', '<extra_id_13>', '<extra_id_14>', '<extra_id_15>', '<extra_id_16>', '<extra_id_17>', '<extra_id_18>', '<extra_id_19>', '<extra_id_20>', '<extra_id_21>', '<extra_id_22>', '<extra_id_23>', '<extra_id_24>', '<extra_id_25>', '<extra_id_26>', '<extra_id_27>', '<extra_id_28>', '<extra_id_29>', '<extra_id_30>', '<extra_id_31>', '<extra_id_32>', '<extra_id_33>', '<extra_id_34>', '<extra_id_35>', '<extra_id_36>', '<extra_id_37>', '<extra_id_38>', '<extra_id_39>', '<extra_id_40>', '<extra_id_41>', '<extra_id_42>', '<extra_id_43>', '<extra_i

For now, this behavior is kept to avoid breaking backwards compatibility when padding/encoding with `truncation is True`.
- Be aware that you SHOULD NOT rely on t5-small automatically truncating your input to 512 when padding/encoding.
- If you want to encode/pad to sequences longer than 512 you can either instantiate this tokenizer with `model_max_length` or pass `max_length` when encoding/padding.


In [2]:
tokenizer.batch_encode_plus(['Hello , this is one sentence', 'This is another sentence.'])

{'input_ids': [[8774, 3, 6, 48, 19, 80, 7142, 1], [100, 19, 430, 7142, 5, 1]], 'attention_mask': [[1, 1, 1, 1, 1, 1, 1, 1], [1, 1, 1, 1, 1, 1]]}

In [3]:
with tokenizer.as_target_tokenizer():
    print(tokenizer.batch_encode_plus(['Hello , this is one sentence', 'This is another sentence.']))

{'input_ids': [[8774, 3, 6, 48, 19, 80, 7142, 1], [100, 19, 430, 7142, 5, 1]], 'attention_mask': [[1, 1, 1, 1, 1, 1, 1, 1], [1, 1, 1, 1, 1, 1]]}


  "`as_target_tokenizer` is deprecated and will be removed in v5 of Transformers. You can tokenize your "


In [4]:
from datasets import load_dataset


In [5]:
# 有可能会遇到ConnectionError的问题. 
# 解决办法, 翻墙, 设置全局代理, 代理中执行下面代码
import os

# 注意要把端口号, 替换成你自己的vpn的端口号.
os.environ['http_proxy'] = 'http://127.0.0.1:10809'
os.environ['https_proxy'] = 'http://127.0.0.1:10809'

In [6]:
dataset = load_dataset('xsum')
# 从本地加载 load_from_disk(path)

Found cached dataset xsum (C:/Users/SupercoldZzz/.cache/huggingface/datasets/xsum/default/1.2.0/082863bf4754ee058a5b6f6525d0cb2b18eadb62c7b370b095d1364050a52b71)


  0%|          | 0/3 [00:00<?, ?it/s]

In [7]:
dataset

DatasetDict({
    train: Dataset({
        features: ['document', 'summary', 'id'],
        num_rows: 204045
    })
    validation: Dataset({
        features: ['document', 'summary', 'id'],
        num_rows: 11332
    })
    test: Dataset({
        features: ['document', 'summary', 'id'],
        num_rows: 11334
    })
})

In [8]:
# 采样
dataset['train'] = dataset['train'].shuffle(1).select(range(20000))
dataset['validation'] = dataset['validation'].shuffle(1).select(range(1000))
dataset['test'] = dataset['test'].shuffle(1).select(range(1000))

Loading cached shuffled indices for dataset at C:\Users\SupercoldZzz\.cache\huggingface\datasets\xsum\default\1.2.0\082863bf4754ee058a5b6f6525d0cb2b18eadb62c7b370b095d1364050a52b71\cache-f5f90b78d2c462d1.arrow
Loading cached shuffled indices for dataset at C:\Users\SupercoldZzz\.cache\huggingface\datasets\xsum\default\1.2.0\082863bf4754ee058a5b6f6525d0cb2b18eadb62c7b370b095d1364050a52b71\cache-6dd63dfa49190643.arrow
Loading cached shuffled indices for dataset at C:\Users\SupercoldZzz\.cache\huggingface\datasets\xsum\default\1.2.0\082863bf4754ee058a5b6f6525d0cb2b18eadb62c7b370b095d1364050a52b71\cache-fbae00ea393505b7.arrow


In [9]:
print(dataset['train'][0])

{'document': "Clay, who has agreed a two-year deal, made 39 appearances for Scottish Premiership club Motherwell last season after joining them in June 2016.\nThe 25-year-old had spent the two previous seasons with Grimsby, playing 74 National League games.\nClay is Leyton Orient's ninth signing since being relegated from League Two last season.", 'summary': 'National League side Leyton Orient have signed Motherwell midfielder Craig Clay on a free transfer.', 'id': '40635923'}


In [10]:
# 对数据进行预处理
def f(examples, tokenizer):
    data = tokenizer.batch_encode_plus(['summarize:' + i for i in examples['document']],
                                      max_length=1024,
                                      truncation=True)
    
    # 编码label
    data['labels'] = tokenizer.batch_encode_plus(examples['summary'], max_length=128, truncation=True)['input_ids']
    return data


In [11]:
dataset = dataset.map(f,
                     batched=True,
                     batch_size=1000,
                     num_proc=12,
                     remove_columns=['document', 'summary', 'id'],
                     fn_kwargs={'tokenizer': tokenizer})

 

Loading cached processed dataset at C:\Users\SupercoldZzz\.cache\huggingface\datasets\xsum\default\1.2.0\082863bf4754ee058a5b6f6525d0cb2b18eadb62c7b370b095d1364050a52b71\cache-ce43cf79ccfa91c0.arrow


 

Loading cached processed dataset at C:\Users\SupercoldZzz\.cache\huggingface\datasets\xsum\default\1.2.0\082863bf4754ee058a5b6f6525d0cb2b18eadb62c7b370b095d1364050a52b71\cache-4e36ccc96175af80.arrow


 

Loading cached processed dataset at C:\Users\SupercoldZzz\.cache\huggingface\datasets\xsum\default\1.2.0\082863bf4754ee058a5b6f6525d0cb2b18eadb62c7b370b095d1364050a52b71\cache-28209578b8c68e07.arrow


 

Loading cached processed dataset at C:\Users\SupercoldZzz\.cache\huggingface\datasets\xsum\default\1.2.0\082863bf4754ee058a5b6f6525d0cb2b18eadb62c7b370b095d1364050a52b71\cache-db8f1e84dc9f75a2.arrow


 

Loading cached processed dataset at C:\Users\SupercoldZzz\.cache\huggingface\datasets\xsum\default\1.2.0\082863bf4754ee058a5b6f6525d0cb2b18eadb62c7b370b095d1364050a52b71\cache-30c872fc35ab4e84.arrow


 

Loading cached processed dataset at C:\Users\SupercoldZzz\.cache\huggingface\datasets\xsum\default\1.2.0\082863bf4754ee058a5b6f6525d0cb2b18eadb62c7b370b095d1364050a52b71\cache-96a14e0e8932224e.arrow


 

Loading cached processed dataset at C:\Users\SupercoldZzz\.cache\huggingface\datasets\xsum\default\1.2.0\082863bf4754ee058a5b6f6525d0cb2b18eadb62c7b370b095d1364050a52b71\cache-77e4490c5c3543f4.arrow


 

Loading cached processed dataset at C:\Users\SupercoldZzz\.cache\huggingface\datasets\xsum\default\1.2.0\082863bf4754ee058a5b6f6525d0cb2b18eadb62c7b370b095d1364050a52b71\cache-8a348ba83132080d.arrow


 

Loading cached processed dataset at C:\Users\SupercoldZzz\.cache\huggingface\datasets\xsum\default\1.2.0\082863bf4754ee058a5b6f6525d0cb2b18eadb62c7b370b095d1364050a52b71\cache-cd5b239fb87f2be1.arrow


 

Loading cached processed dataset at C:\Users\SupercoldZzz\.cache\huggingface\datasets\xsum\default\1.2.0\082863bf4754ee058a5b6f6525d0cb2b18eadb62c7b370b095d1364050a52b71\cache-a02ea666b951f4a4.arrow


 

Loading cached processed dataset at C:\Users\SupercoldZzz\.cache\huggingface\datasets\xsum\default\1.2.0\082863bf4754ee058a5b6f6525d0cb2b18eadb62c7b370b095d1364050a52b71\cache-7a25468867840b95.arrow


 

Loading cached processed dataset at C:\Users\SupercoldZzz\.cache\huggingface\datasets\xsum\default\1.2.0\082863bf4754ee058a5b6f6525d0cb2b18eadb62c7b370b095d1364050a52b71\cache-6fcfdb5e10ca7c01.arrow


 

Loading cached processed dataset at C:\Users\SupercoldZzz\.cache\huggingface\datasets\xsum\default\1.2.0\082863bf4754ee058a5b6f6525d0cb2b18eadb62c7b370b095d1364050a52b71\cache-1b69d2d326cca8ad.arrow


 

Loading cached processed dataset at C:\Users\SupercoldZzz\.cache\huggingface\datasets\xsum\default\1.2.0\082863bf4754ee058a5b6f6525d0cb2b18eadb62c7b370b095d1364050a52b71\cache-b8d7c6a4ff75fb4e.arrow


 

Loading cached processed dataset at C:\Users\SupercoldZzz\.cache\huggingface\datasets\xsum\default\1.2.0\082863bf4754ee058a5b6f6525d0cb2b18eadb62c7b370b095d1364050a52b71\cache-c1b15b6614fdb7a0.arrow


 

Loading cached processed dataset at C:\Users\SupercoldZzz\.cache\huggingface\datasets\xsum\default\1.2.0\082863bf4754ee058a5b6f6525d0cb2b18eadb62c7b370b095d1364050a52b71\cache-86885a042b296927.arrow


 

Loading cached processed dataset at C:\Users\SupercoldZzz\.cache\huggingface\datasets\xsum\default\1.2.0\082863bf4754ee058a5b6f6525d0cb2b18eadb62c7b370b095d1364050a52b71\cache-4911eb024b91844c.arrow


 

Loading cached processed dataset at C:\Users\SupercoldZzz\.cache\huggingface\datasets\xsum\default\1.2.0\082863bf4754ee058a5b6f6525d0cb2b18eadb62c7b370b095d1364050a52b71\cache-4225565a98db5d2a.arrow


 

Loading cached processed dataset at C:\Users\SupercoldZzz\.cache\huggingface\datasets\xsum\default\1.2.0\082863bf4754ee058a5b6f6525d0cb2b18eadb62c7b370b095d1364050a52b71\cache-403599fafa11774b.arrow


 

Loading cached processed dataset at C:\Users\SupercoldZzz\.cache\huggingface\datasets\xsum\default\1.2.0\082863bf4754ee058a5b6f6525d0cb2b18eadb62c7b370b095d1364050a52b71\cache-fdc5a7032a4c1ad5.arrow


 

Loading cached processed dataset at C:\Users\SupercoldZzz\.cache\huggingface\datasets\xsum\default\1.2.0\082863bf4754ee058a5b6f6525d0cb2b18eadb62c7b370b095d1364050a52b71\cache-a82038f51ff3394a.arrow


 

Loading cached processed dataset at C:\Users\SupercoldZzz\.cache\huggingface\datasets\xsum\default\1.2.0\082863bf4754ee058a5b6f6525d0cb2b18eadb62c7b370b095d1364050a52b71\cache-9dcd96be7bc1cd1d.arrow


 

Loading cached processed dataset at C:\Users\SupercoldZzz\.cache\huggingface\datasets\xsum\default\1.2.0\082863bf4754ee058a5b6f6525d0cb2b18eadb62c7b370b095d1364050a52b71\cache-bd79b3dc4204f7d6.arrow


 

Loading cached processed dataset at C:\Users\SupercoldZzz\.cache\huggingface\datasets\xsum\default\1.2.0\082863bf4754ee058a5b6f6525d0cb2b18eadb62c7b370b095d1364050a52b71\cache-5bd663b62f8b57cb.arrow


 

Loading cached processed dataset at C:\Users\SupercoldZzz\.cache\huggingface\datasets\xsum\default\1.2.0\082863bf4754ee058a5b6f6525d0cb2b18eadb62c7b370b095d1364050a52b71\cache-b383fb41e18670d1.arrow


 

Loading cached processed dataset at C:\Users\SupercoldZzz\.cache\huggingface\datasets\xsum\default\1.2.0\082863bf4754ee058a5b6f6525d0cb2b18eadb62c7b370b095d1364050a52b71\cache-f9c4668734531fcd.arrow


 

Loading cached processed dataset at C:\Users\SupercoldZzz\.cache\huggingface\datasets\xsum\default\1.2.0\082863bf4754ee058a5b6f6525d0cb2b18eadb62c7b370b095d1364050a52b71\cache-4c743f673c323d09.arrow


 

Loading cached processed dataset at C:\Users\SupercoldZzz\.cache\huggingface\datasets\xsum\default\1.2.0\082863bf4754ee058a5b6f6525d0cb2b18eadb62c7b370b095d1364050a52b71\cache-3d6b1f6c3356a985.arrow


 

Loading cached processed dataset at C:\Users\SupercoldZzz\.cache\huggingface\datasets\xsum\default\1.2.0\082863bf4754ee058a5b6f6525d0cb2b18eadb62c7b370b095d1364050a52b71\cache-4c5a9e8d16256c1e.arrow


 

Loading cached processed dataset at C:\Users\SupercoldZzz\.cache\huggingface\datasets\xsum\default\1.2.0\082863bf4754ee058a5b6f6525d0cb2b18eadb62c7b370b095d1364050a52b71\cache-dfd7d1ca31a59edc.arrow


 

Loading cached processed dataset at C:\Users\SupercoldZzz\.cache\huggingface\datasets\xsum\default\1.2.0\082863bf4754ee058a5b6f6525d0cb2b18eadb62c7b370b095d1364050a52b71\cache-28e2b14e837ce409.arrow


 

Loading cached processed dataset at C:\Users\SupercoldZzz\.cache\huggingface\datasets\xsum\default\1.2.0\082863bf4754ee058a5b6f6525d0cb2b18eadb62c7b370b095d1364050a52b71\cache-1863e355cc124d06.arrow


 

Loading cached processed dataset at C:\Users\SupercoldZzz\.cache\huggingface\datasets\xsum\default\1.2.0\082863bf4754ee058a5b6f6525d0cb2b18eadb62c7b370b095d1364050a52b71\cache-69da917f6f527548.arrow


 

Loading cached processed dataset at C:\Users\SupercoldZzz\.cache\huggingface\datasets\xsum\default\1.2.0\082863bf4754ee058a5b6f6525d0cb2b18eadb62c7b370b095d1364050a52b71\cache-f7defaf2c746f35f.arrow


 

Loading cached processed dataset at C:\Users\SupercoldZzz\.cache\huggingface\datasets\xsum\default\1.2.0\082863bf4754ee058a5b6f6525d0cb2b18eadb62c7b370b095d1364050a52b71\cache-fb9cc2f8befe389c.arrow


 

Loading cached processed dataset at C:\Users\SupercoldZzz\.cache\huggingface\datasets\xsum\default\1.2.0\082863bf4754ee058a5b6f6525d0cb2b18eadb62c7b370b095d1364050a52b71\cache-73a5fccf16531a29.arrow


In [12]:
dataset


DatasetDict({
    train: Dataset({
        features: ['input_ids', 'attention_mask', 'labels'],
        num_rows: 20000
    })
    validation: Dataset({
        features: ['input_ids', 'attention_mask', 'labels'],
        num_rows: 1000
    })
    test: Dataset({
        features: ['input_ids', 'attention_mask', 'labels'],
        num_rows: 1000
    })
})

In [13]:
print(dataset['train'][0])

{'input_ids': [21603, 10, 254, 5595, 6, 113, 65, 4686, 3, 9, 192, 18, 1201, 1154, 6, 263, 6352, 3179, 7, 21, 12580, 6552, 2009, 1886, 8007, 2091, 336, 774, 227, 6109, 135, 16, 1515, 4619, 37, 944, 18, 1201, 18, 1490, 141, 1869, 8, 192, 1767, 9385, 28, 23427, 7, 969, 6, 1556, 3, 4581, 868, 3815, 1031, 5, 20988, 19, 312, 21220, 3, 16495, 31, 7, 24651, 8097, 437, 271, 3, 60, 8791, 26, 45, 3815, 2759, 336, 774, 5, 1], 'attention_mask': [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1], 'labels': [868, 3815, 596, 312, 21220, 3, 16495, 43, 3814, 8007, 2091, 2076, 1846, 49, 12870, 20988, 30, 3, 9, 339, 2025, 5, 1]}


In [14]:
import torch


def collate_fn(data):
    # 求最长的labels
    max_length = max([len(i['labels']) for i in data])
    
    # 把所有的label都pad到max_length
    for i in data:
        pads = [-100] * (max_length - len(i['labels']))
        i['labels'] = i['labels'] + pads
        
    
    # 把多个数据整合成一个tensor
    data = tokenizer.pad(
        encoded_inputs=data,
        padding=True,
        max_length=None,
        pad_to_multiple_of=None,
        return_tensors='pt'
    )
    
    data['decoder_input_ids'] = torch.zeros_like(data['labels'], dtype=torch.long)
    data['decoder_input_ids'][:, 1:] = data['labels'][:, :-1]
    data['decoder_input_ids'][data['decoder_input_ids']== -100] = 0
    return data


In [15]:
data = [{
    'input_ids': [21603, 10, 37, 3719, 13],
    'attention_mask': [1, 1, 1, 1, 1],
    'labels': [10455, 120, 80]
}, {
    'input_ids': [21603, 10, 7086, 8408, 563],
    'attention_mask': [1, 1, 1, 1, 1],
    'labels': [301, 53, 4074, 1669]
}]
collate_fn(data)

You're using a T5TokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.


{'input_ids': tensor([[21603,    10,    37,  3719,    13],
        [21603,    10,  7086,  8408,   563]]), 'attention_mask': tensor([[1, 1, 1, 1, 1],
        [1, 1, 1, 1, 1]]), 'labels': tensor([[10455,   120,    80,  -100],
        [  301,    53,  4074,  1669]]), 'decoder_input_ids': tensor([[    0, 10455,   120,    80],
        [    0,   301,    53,  4074]])}

In [16]:
# 数据加载器
loader = torch.utils.data.DataLoader(
    dataset = dataset['train'],
    batch_size=4,
    collate_fn=collate_fn,
    shuffle=True,
    drop_last=True
)

for data in loader:
    break

In [17]:
for k, v in data.items():
    print(k, v.shape)

input_ids torch.Size([4, 809])
attention_mask torch.Size([4, 809])
labels torch.Size([4, 47])
decoder_input_ids torch.Size([4, 47])


In [18]:
len(loader)

5000

In [19]:
from transformers import AutoModelForSeq2SeqLM, T5Model

In [20]:
512 * 32100

16435200

In [21]:
# 定义下游任务模型
class Model(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.pretrained = T5Model.from_pretrained('t5-small')
        # vocab_size 是32100
        # 我们打算用AutoModelForSeq2SeqLM的参数, 但是它的输出层参数是32128
        self.fc = torch.nn.Linear(512, 32128, bias=False)
        
        # 加载预训练权重
        parameters = AutoModelForSeq2SeqLM.from_pretrained('t5-small')
        self.fc.load_state_dict(parameters.lm_head.state_dict())
        
        self.criterion = torch.nn.CrossEntropyLoss(ignore_index=-100)
        
    def forward(self, input_ids, attention_mask, labels, decoder_input_ids):
        logits = self.pretrained.encoder(input_ids=input_ids,
                                        attention_mask=attention_mask)
        logits = logits.last_hidden_state
        logits = self.pretrained.decoder(
            input_ids=decoder_input_ids,
            encoder_hidden_states=logits,
            encoder_attention_mask=attention_mask
        )
        logits = logits.last_hidden_state
        logits = logits * (512** -0.5)
        logits = self.fc(logits)
        loss = self.criterion(logits.reshape(-1, 32128), labels.reshape(-1))
        
        return {'loss': loss, 'logits': logits}
    
    
                                        

In [22]:
model = Model()

In [23]:
out = model(**data)

In [24]:
out['loss']

tensor(3.5750, grad_fn=<NllLossBackward0>)

In [25]:
out['logits'].shape

torch.Size([4, 47, 32128])

In [26]:
# 参数量
print(sum(i.numel() for i in model.parameters()))

76956160


In [27]:
# 测试
# 测试
def test(model):
    model.eval()
    
    loader_test = torch.utils.data.DataLoader(
        dataset=dataset['test'],
        batch_size=4,
        collate_fn=collate_fn,
        shuffle=True,
        drop_last=True
    )
    
    for data in loader_test:
        break
        
    with torch.no_grad():
        out = model(**data)
        

    
    for i in range(4):
        input_ids = tokenizer.decode(data['input_ids'][i])
        pred = tokenizer.decode(out['logits'].argmax(dim=2)[i])
        label = tokenizer.decode(data['decoder_input_ids'][i])
        
        print('pred:', pred)
        print('label:', label)
        print()

In [28]:
test(model)

pred: Annakeepero Oinomar was thes in  uncle beat the  round.  the football mped out. 4ship   of the fifth capital  the:
label: <pad> Goal hero Rabin Omar made headlines when his club from the fourth tier of Scottish football dumped a Premiership side out of the Scottish Cup.</s>

pred: world Hamilton says startinghe was "actual  about about finishing starting start in both race One race    the the the the the the the the world world world world world world
label: <pad> Lewis Hamilton said he was "not worried" about his difficult start to the Formula 1 season.</s><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad>

pred: scientistss have on's spaceetta probe to  will  comet'.P  have they have have been  of the they a structures formed formed  the
label: <pad> Scientists working on Europe's Rosetta probe, which is tracking Comet 67P, say they may have found evidence for how such icy objects were formed.

pred: football's also latest thing of football football football

In [29]:
from transformers.trainer_pt_utils import get_parameter_names
from transformers import AdamW
from transformers.optimization import get_scheduler

In [30]:
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
device

device(type='cuda', index=0)

In [31]:
data

{'input_ids': tensor([[21603,    10,   667,  ...,     0,     0,     0],
        [21603,    10, 15743,  ...,  5719,   535,     1],
        [21603,    10,   634,  ...,     0,     0,     0],
        [21603,    10, 24607,  ...,     0,     0,     0]]), 'attention_mask': tensor([[1, 1, 1,  ..., 0, 0, 0],
        [1, 1, 1,  ..., 1, 1, 1],
        [1, 1, 1,  ..., 0, 0, 0],
        [1, 1, 1,  ..., 0, 0, 0]]), 'labels': tensor([[   71,  1249, 17030,    18,  8861, 26737,   297,    13, 26238,    31,
             7, 13017,  1635,     7,   107,  2309,  2050,    65,   118,  3754,
             5,     1,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,
          -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,
          -100,  -100,  -100,  -100,  -100,  -100,  -100],
        [17159,   115,    17,     7,    43,   118,  3279,    81,   823,  2789,
            31,     7, 17804,    91,    18,   858,    18,  5842,     7,   199,
           747,    19,     3,   179,    12,  2862,  2

In [32]:
# 训练代码
def train():
    parameter_names = get_parameter_names(model, [torch.nn.LayerNorm])
    # weight_decay 权重衰减
    parameter_names = [i for i in parameter_names if 'bias' not in i]
    
    parameter_names = [
        {
            'params': [p for i, p in model.named_parameters() if i in parameter_names],
            'weight_decay': 1e-2
        },
        {
            'params': [p for i, p in model.named_parameters() if i not in parameter_names],
            'weight_decay': 0.0
        }
    ]
    
    # 定义优化器
    optimizer = AdamW(parameter_names, betas=(0.9, 0.999), eps=1e-8, lr=2e-5)
    
    scheduler = get_scheduler(name='linear',
                             num_warmup_steps=0,
                             num_training_steps=len(loader),
                             optimizer=optimizer)
    
    model.to(device)
    model.train()
    
    for i, data in enumerate(loader):
        input_ids, attention_mask = data['input_ids'], data['attention_mask']
        labels, decoder_input_ids = data['labels'], data['decoder_input_ids']
        input_ids = input_ids.to(device)
        attention_mask = attention_mask.to(device)
        labels = labels.to(device)
        decoder_input_ids = decoder_input_ids.to(device)
        
        out = model(input_ids=input_ids, 
                    attention_mask=attention_mask, 
                    labels=labels, 
                    decoder_input_ids=decoder_input_ids)
        loss = out['loss']
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        optimizer.step()
        scheduler.step()
        
        optimizer.zero_grad()
        model.zero_grad()
        
        if i % 50 == 0:
            lr = optimizer.state_dict()['param_groups'][0]['lr']
            
            pred = tokenizer.decode(out['logits'].argmax(dim=2)[0])
            label = tokenizer.decode(data['decoder_input_ids'][0])
            
            print(i, loss.item(), lr)
            print('pred: ', pred)
            print('label: ', label)
            print()

In [33]:
train()



0 5.284295082092285 1.9996000000000003e-05
pred:  fsea cancer hospital,  cancer   to pedal in.   t   the world </s></s>   the the the the the the the  the
label:  <pad> A Swansea University engineer with terminal cancer is set to ride a baked bean-shaped bike around the UK.</s><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad>

50 3.1944425106048584 1.9796e-05
pred:  six jury is thea  who of  in f crash crash claim was. was sak" thea jury heard heard </s>  the the the the Beth Beth  the jury jury the the seven six jury the  the the
label:  <pad> The case against a woman accused of involvement in a car insurance fraud ring is "weak", a court has heard.</s><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad>

100 3.390977382659912 1.9596e-05
pred:    of thesa's Trade into the tax of of people work, been that are no that are  the-employed as to create tax </s></s>
label:  <pad> The head of Theresa May's inquiry into the way millions of peopl

1100 2.8634085655212402 1.5596e-05
pred:  ComJ giant UScast has  build a new9% stake of in Tokyo Studios''USJ) in its1.5 billionillion in£16mb)</s>
label:  <pad> US entertainment giant Comcast is to buy a 51% majority stake in Universal Studios Japan (USJ) for $1.5bn (£987m).

1150 2.5308940410614014 1.5396000000000003e-05
pred:  The annual event's worldice hockey event in in Du Du Hockey in Dumfries has generated aailed as a in " boost to Du region economy.</s>t the
label:  <pad> An international women's ice hockey competition held at the Ice Bowl in Dumfries has been hailed for bringing an economic boost to the local area.</s><pad>

1200 2.91882061958313 1.5196000000000002e-05
pred:  The mayor Mayor of Tower Hamlets has beenreowed to " the name from the election election election. hishe court Court ruling. him guilty of corruption fraud.</s>
label:  <pad> The former mayor of Tower Hamlets has vowed to clear his name in his first public speech since a High Court ruling found him guilt

2250 3.3192496299743652 1.0996e-05
pred:  Ss Cup Sea Sihull Moors have ita  debut to the at the Premier halftier of  S Irelands Sutton United-1.</s></s> A
label:  <pad> National League North winners Solihull Moors made a promising start to life in the fifth tier by beating southern champions Sutton 3-1.</s><pad>

2300 3.111926794052124 1.0796e-05
pred:  The new councilboroughdevelopment project in in in Ball  town in Ball ofrim has been approved by thelors.</s></s>  The The The The The The The The The The The The The The A The
label:  <pad> A major redevelopment project centred around an evangelical church in County Antrim has been approved by councillors.</s><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad>

2350 2.9417884349823 1.0596000000000002e-05
pred:  Dony Toye, 33ry Gallvanagh and To To have been for off the new-leagueing squad team to theegal.</s></s> The The The To A Don
label:  <pad> Christy Toye, Rory Kavanagh and David Walsh have called

3350 3.1043970584869385 6.596e-06
pred:  Thetsman, Gillan has been: had have to play  in the thanthe-" roles. the  in villain in</s></s> Guardian Guardian Guardian The Guardian The The Marvel Guardian The Guardian Guardian
label:  <pad> Scots actress Karen Gillan has said she would like to be cast in more "bad guy" roles after playing her first villain.</s><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad>

3400 2.8847649097442627 6.396e-06
pred:  Aer of photographer Drinkwater and</s>:  The The A The The A A The The A A A A A A A A A A A A
label:  <pad> Photographs by Marcus Drinkwater.</s><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad>

3450 3.004556894302368 6.196000000000001e-06
pred:  Aillian Whyte has Joshua mentals over Anthony Joshua has him ands on  makes- in  fighte the title heavyweight title.</s>
label:  <pad> Dillian Whyte believes his amateur victory over Anthony Joshua left mental scars that will re

4450 2.81040358543396 2.1960000000000002e-06
pred:  Arankss havephrew chicken chicken into through the kitchen hatch ofs at the chicken's inpPrbyininshed chicken in Mondays..</s>
label:  <pad> Pranksters threw live chickens through the serving hatches of two McDonald's 'drive-thru' restaurants on Teesside.

4500 2.8710155487060547 1.996e-06
pred:  A passenger in been in a car witha  cars in  road4 in Birminghamwellleys, Birminghams,onan,, said said.</s>
label:  <pad> A woman has died after a collision involving two cars on the A44 in Powys, Dyfed-Powys Police has said.

4550 2.8274359703063965 1.7960000000000003e-06
pred:  The theae   the to of the  on on thet 67P, it have have how   thea dark space.</s>, The The The The
label:  <pad> When Philae first sent back images of its landing location on Comet 67P, researchers could see it was in a dark ditch.</s><pad><pad><pad><pad>

4600 3.1162917613983154 1.596e-06
pred:  The shoppers have beifys at who the taxi taxi Cup Cup. in Cardiff.  of