## pytorch-lightning版本的训练方式

torch version: 1.7.1+cu101
pytorch_lightning version: 1.2.7

In [1]:
import pytorch_lightning as pl
pl.__version__

'1.2.7'

In [3]:
# config.py
import transformers

class config:
    """Static settings shared by the dataset, model and training cells.

    The attributes are read directly from the class (``config.max_len`` etc.);
    no instance is created anywhere in this file.
    """
    # Sequence length handed to the tokenizer as ``max_length`` by Dataset.
    max_len = 32 # i.e. pad_size
    # Passed to train() as ``epochs`` (used there as the Trainer's max_epochs).
    N_EPOCHS = 3
    # Learning rate read by LightningModel.configure_optimizers.
    LEARNING_RATE = 3e-6
    # Location of the pretrained BERT files given to ``from_pretrained``.
    BERT_PATH = './pretrain/bert-base-chinese'
    # NOTE: the tokenizer is loaded when this class body executes, so defining
    # `config` already requires the files under BERT_PATH to be readable.
    TOKENIZER = transformers.BertTokenizer.from_pretrained(BERT_PATH)
    # Forwarded to BertForSequenceClassification.from_pretrained as num_labels.
    num_labels = 10
    # Passed to train() as ``save_directory`` (the TensorBoardLogger save_dir).
    save_dir = './lightningloss_test'

In [4]:
# dataset.py
import torch
def tokenizering(tokenizer, input, max_len, return_type=None):
    """Encode a single text through ``tokenizer.encode_plus``.

    The call requests padding to ``max_length`` and the ``'only_first'``
    truncation strategy.

    Args:
        tokenizer: Object exposing an ``encode_plus`` method.
        input: The text to encode (forwarded as ``text``).
        max_len: Target sequence length (forwarded as ``max_length``).
        return_type: Forwarded as ``return_tensors``; ``None`` by default.

    Returns:
        Whatever ``tokenizer.encode_plus`` returns for these options.
    """
    encode_options = {
        'text': input,
        'max_length': max_len,
        'padding': 'max_length',
        'truncation': 'only_first',
        'return_tensors': return_type,
    }
    return tokenizer.encode_plus(**encode_options)

# Plays the same role as the dataset classes found in torch.utils.data.
class Dataset:
    """Map-style dataset that tokenizes ``(text, label, ...)`` records on access.

    Args:
        data: Indexable collection of records; field 0 of each record is the
            text and field 1 is a label that ``int()`` can convert.
    """

    def __init__(self, data):
        self.data = data
        self.tokenizer = config.TOKENIZER
        self.max_len = config.max_len  # pad size

    def __len__(self):
        return len(self.data)

    def __getitem__(self, item):
        """Return the encoded sample at index ``item`` as a dict of long tensors.

        BertForSequenceClassification's forward expects its inputs under these
        keys, which is why the sample is shaped as this dict; other models may
        need a different layout.
        """
        record = self.data[item]
        text, target = record[0], record[1]
        encoded = tokenizering(self.tokenizer, text, self.max_len, return_type=None)

        def as_long(values):
            return torch.tensor(values, dtype=torch.long)

        return {
            'input_ids': as_long(encoded['input_ids']),
            'token_type_ids': as_long(encoded['token_type_ids']),
            'attention_mask': as_long(encoded['attention_mask']),
            'labels': as_long(int(target)),
        }

In [5]:
# Read each split: strip every line, drop the ones that end up empty, and split
# the remainder on tabs. Dataset reads field 0 as the text and field 1 as the
# label.
with open('THUCNews/data/train.txt', 'r', encoding='UTF-8') as f:
    data = [row.split('\t') for row in map(str.strip, f) if row]
train_dataset = Dataset(data)
with open('THUCNews/data/dev.txt', 'r', encoding='UTF-8') as f:
    data = [row.split('\t') for row in map(str.strip, f) if row]
val_dataset = Dataset(data)

In [6]:
# model.py
from transformers import BertForSequenceClassification
from pytorch_lightning import LightningModule
from transformers import AdamW

class LightningModel(LightningModule):
    """Lightning wrapper that fine-tunes ``BertForSequenceClassification``.

    Args:
        BASE_MODEL_PATH: Path or name handed to ``from_pretrained``.
        num_labels: Number of classes for the classification head.
        lr: Learning rate for the optimizer. ``None`` (the default) falls back
            to ``config.LEARNING_RATE``, which was the only option before.
    """

    def __init__(self, BASE_MODEL_PATH, num_labels, lr=None):
        super().__init__()
        self.model = BertForSequenceClassification.from_pretrained(
            BASE_MODEL_PATH, num_labels=num_labels
        )
        # Kept as an attribute so callers can override it after construction.
        self.lr = config.LEARNING_RATE if lr is None else lr

    def forward(self, *args, **kwargs):
        """Delegate straight to the wrapped Hugging Face model."""
        return self.model(*args, **kwargs)

    def configure_optimizers(self):
        """Return an AdamW optimizer over the wrapped model's parameters."""
        return AdamW(self.model.parameters(), lr=self.lr)

    def _shared_step(self, batch, stage):
        """Run one batch, log ``{stage}_loss`` / ``{stage}_acc``, return both.

        ``batch`` is unpacked as keyword arguments into the model and must
        contain a ``'labels'`` entry; ``stage`` is the metric-name prefix.
        """
        outputs = self(**batch)
        loss, logits = outputs.loss, outputs.logits
        targets = batch['labels']
        # Stay on the logits' device: no .item() round trip per step.
        acc = (logits.argmax(dim=1) == targets).float().mean().detach()

        self.log(
            f'{stage}_loss',
            loss,
            prog_bar=True,
            logger=True,
            on_step=True,
            on_epoch=True,
        )
        # Epoch-level accuracy is logged so that callbacks can monitor it
        # (e.g. a checkpoint callback watching "val_acc").
        self.log(
            f'{stage}_acc',
            acc,
            prog_bar=True,
            logger=True,
            on_step=False,
            on_epoch=True,
        )
        return {'loss': loss, 'acc': acc}

    def training_step(self, batch, batch_idx):
        """Compute and log the training loss and accuracy for one batch."""
        return self._shared_step(batch, 'train')

    def training_epoch_end(self, training_step_outputs):
        """Print the epoch index and the averaged training metrics."""
        loss, acc = self.calculate_metrics(training_step_outputs)
        print(f'Epoch: {self.current_epoch:2}')
        print(f' Train_loss: {loss:.3f}  | Train_acc: {acc*100:.2f}%')

    def validation_step(self, batch, batch_idx):
        """Compute and log the validation loss and accuracy for one batch."""
        return self._shared_step(batch, 'val')

    def validation_epoch_end(self, validation_step_outputs):
        """Print the averaged validation metrics."""
        loss, acc = self.calculate_metrics(validation_step_outputs)
        print(f' valid_loss: {loss:.3f} | valid_acc: {acc*100:.2f}%')

    def calculate_metrics(self, step_outputs):
        """Average the per-batch ``'loss'`` and ``'acc'`` entries.

        NOTE: this is a plain mean over batches, so a smaller final batch
        carries the same weight as a full one.

        Returns:
            Tuple ``(loss, acc)`` of 0-dim tensors.
        """
        loss = torch.stack([step['loss'] for step in step_outputs]).mean()
        acc = torch.stack([step['acc'] for step in step_outputs]).mean()
        return loss, acc

In [7]:
# train.py
from torch.utils.data import DataLoader
from pytorch_lightning import Trainer
from pytorch_lightning.loggers import TensorBoardLogger
from pytorch_lightning.callbacks import EarlyStopping
from pytorch_lightning.callbacks import ModelCheckpoint
from collections import OrderedDict

def train(base_model_path, save_directory, train_dataset, val_dataset, batch_size, lr, epochs, num_labels):
    """Fine-tune a ``LightningModel`` and write TensorBoard logs.

    Args:
        base_model_path: Pretrained model location; its last path component
            also names the TensorBoard run directory.
        save_directory: Root directory for the TensorBoard logger.
        train_dataset: Dataset used for training batches (shuffled).
        val_dataset: Dataset used for validation batches (not shuffled).
        batch_size: Batch size for both data loaders.
        lr: Requested learning rate; stored on the module as ``lr``.
        epochs: Upper bound on the number of training epochs.
        num_labels: Number of classes for the classification head.
    """
    # Local import keeps this block self-contained.
    from pathlib import Path

    train_dataloaders = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=0)
    # Validation order does not need to be random; a fixed order keeps the
    # evaluated batches identical from epoch to epoch.
    val_dataloaders = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, num_workers=0)

    LTmodel = LightningModel(base_model_path, num_labels)
    if lr is not None:
        # The optimizer only honours this if LightningModel.configure_optimizers
        # reads `self.lr`; otherwise the value from `config` stays in effect.
        LTmodel.lr = lr

    # To checkpoint on a metric, pass callbacks=[ModelCheckpoint(monitor=...)]
    # to the Trainer; the monitored key must be logged by the module.
    logger = TensorBoardLogger(
        save_dir=save_directory,
        # Path.name also copes with a trailing separator, unlike split('/')[-1].
        name=Path(base_model_path).name,
    )
    trainer = Trainer(
        logger=logger,
        min_epochs=1,
        max_epochs=epochs,
        gpus=[0],
    )
    print(trainer.logger.log_dir)
    trainer.fit(
        LTmodel,
        train_dataloader=train_dataloaders,
        val_dataloaders=val_dataloaders,
    )
    
# Kick off fine-tuning with the settings from `config` and a batch size of 128.
train(
    base_model_path=config.BERT_PATH,
    save_directory=config.save_dir,
    train_dataset=train_dataset,
    val_dataset=val_dataset,
    batch_size=128,
    lr=config.LEARNING_RATE,
    epochs=config.N_EPOCHS,
    num_labels=config.num_labels,
)

Some weights of the model checkpoint at ./pretrain/bert-base-chinese were not used when initializing BertForSequenceClassification: ['cls.predictions.bias', 'cls.predictions.transform.dense.weight', 'cls.predictions.transform.dense.bias', 'cls.predictions.decoder.weight', 'cls.seq_relationship.weight', 'cls.seq_relationship.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.LayerNorm.bias']
- This IS expected if you are initializing BertForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of BertForSequenceClassification were not initialized from the model ch

./lightningloss_test\bert-base-chinese\version_0



  | Name  | Type                          | Params
--------------------------------------------------------
0 | model | BertForSequenceClassification | 102 M 
--------------------------------------------------------
102 M     Trainable params
0         Non-trainable params
102 M     Total params
409.101   Total estimated model params size (MB)


Validation sanity check: 0it [00:00, ?it/s]

 valid_loss: 2.492 | valid_acc: 7.81%


Training: 0it [00:00, ?it/s]

