<a href="https://colab.research.google.com/github/lyzno1/lightning_example/blob/main/fabric.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
!pip install lightning

Collecting lightning
  Downloading lightning-2.4.0-py3-none-any.whl.metadata (38 kB)
Collecting lightning-utilities<2.0,>=0.10.0 (from lightning)
  Downloading lightning_utilities-0.11.6-py3-none-any.whl.metadata (5.2 kB)
Collecting torchmetrics<3.0,>=0.7.0 (from lightning)
  Downloading torchmetrics-1.4.1-py3-none-any.whl.metadata (20 kB)
Collecting pytorch-lightning (from lightning)
  Downloading pytorch_lightning-2.4.0-py3-none-any.whl.metadata (21 kB)
Collecting nvidia-cuda-nvrtc-cu12==12.1.105 (from torch<4.0,>=2.1.0->lightning)
  Using cached nvidia_cuda_nvrtc_cu12-12.1.105-py3-none-manylinux1_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cuda-runtime-cu12==12.1.105 (from torch<4.0,>=2.1.0->lightning)
  Using cached nvidia_cuda_runtime_cu12-12.1.105-py3-none-manylinux1_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cuda-cupti-cu12==12.1.105 (from torch<4.0,>=2.1.0->lightning)
  Using cached nvidia_cuda_cupti_cu12-12.1.105-py3-none-manylinux1_x86_64.whl.metadata (1.6 kB)
Collecti

In [3]:
import pandas as pd

data = pd.read_excel('usingtest.xlsx')

print(data.head())
for i, value in enumerate(data['label']):
    if value != 1.0 and value != 0.0:
        print(f"in row {i+1}")
print("yes")
assert all(data['label'].isin([0.0, 1.0])), "label contains values other than 0.0 and 1.0"
data = data[['内容', 'label']]

    序号       昵称 性别  省份                                                 内容  \
0    1  潇潇diana  女  北京  一个 妈妈 一天 心路历程 吃饭 篇 牛奶 有无 三聚氰胺 超标 会 不会 喝成 大头 面包...   
1    2   睡不饱的任镳  男  上海   发现 现在 媒体 微博后 关注度 会 大幅度 增加 快速 传播 影响 很大 有个 缺点 不...   
2    4    狙击手蝈蝈  男  广东  铁证如山 日军 性 暴行 受害者 两姐妹 证言 公布 救 其他人 时年 14 岁 彭 仁寿 ...   
3    7  邵井子1314  男  其他  疫苗 事件 转基因 事件 只不过 比较 两个 造假 事件 没收 转发 键 疫苗 事件 国产 ...   
4  181  时尚老太80后  女  其他   转基因 日前 农业部 回应 表示 转基因 谣言 已经 影响 转基因 健康 发展 实际上 科...   

            认证   编写日期  label  
0          未认证   3分钟前      1  
1  东方汇金期货研究员任镳   3分钟前      1  
2          未认证  13分钟前      1  
3       头条文章作者   3分钟前      1  
4          未认证  23分钟前      1  
yes


In [5]:
import torch
from torch.utils.data import Dataset, DataLoader

class WeiboDataset(Dataset):
    def __init__(self, data, tokenizer, max_len):
        self.data = data
        self.tokenizer = tokenizer
        self.max_len = max_len
    def __len__(self):
        return len(self.data)
    def __getitem__(self, idx):
        comment = self.data.iloc[idx]['内容']
        label = self.data.iloc[idx]['label']
        encoding = self.tokenizer.encode_plus(
            comment,
            add_special_tokens=True,
            max_length=self.max_len,
            return_token_type_ids=False,
            padding='max_length',
            truncation=True,
            return_attention_mask=True,
            return_tensors='pt',
        )#
        return {
            'comment_text': comment,
            'input_ids': encoding['input_ids'].flatten(),
            'attention_mask': encoding['attention_mask'].flatten(),
            'labels': torch.tensor(label, dtype=torch.long)
        }

In [21]:
import torch.nn as nn
from transformers import BertForSequenceClassification

class Classifer(nn.Module):

    def __init__(self, config):
        super().__init__()
        self.bert = BertForSequenceClassification.from_pretrained("bert-base-chinese", num_labels=config.num_labels).train()

    def forward(self, input_ids, attention_mask, labels=None):
        return self.bert(input_ids, attention_mask=attention_mask, labels=labels)

In [34]:
def validate(config, fabric, model, val_loader):
    model.eval()
    val_losses = []
    for batch in val_loader:
        input_ids = batch['input_ids']
        attention_mask = batch['attention_mask']
        labels = batch['labels']
        with torch.no_grad():
            outputs = model(input_ids, attention_mask, labels=labels)
            loss = outputs.loss
            val_losses.append(loss.item())

    avg_val_loss = sum(val_losses) / len(val_losses)
    fabric.print(f"Validation Loss\t{avg_val_loss}")

def train(config, fabric, model, train_loader, val_loader, optimizer, scheduler=None):
    for epoch in range(config.num_epochs):
        model.train()
        train_losses = []
        all_steps = len(train_loader)
        fabric.print(f"Epoch\t{epoch+1}\tTotal Steps\t{all_steps}")

        for i, batch in enumerate(train_loader):
            optimizer.zero_grad()
            input_ids = batch['input_ids']
            attention_mask = batch['attention_mask']
            labels = batch['labels']
            outputs = model(input_ids, attention_mask, labels=labels)
            loss = outputs.loss
            fabric.backward(loss)
            optimizer.step()

            if scheduler is not None:
                scheduler.step()
            if optimizer.param_groups[0]['lr'] != scheduler.get_last_lr()[0]:
                fabric.print(f"Learning Rate\t{optimizer.param_groups[0]['lr']}")

            train_losses.append(loss.item())
            avg_train_loss = sum(train_losses) / len(train_losses)
            fabric.print(f"Epoch[{epoch+1}/{config.num_epochs}]\tStep\t{i}\tAverage Train Loss\t{avg_train_loss}")

        if epoch % config.validate_every_n_epoch == 0:
            validate(config, fabric, model, val_loader)

def test(config, fabric, model, test_loader):
    from sklearn.metrics import f1_score

    model.eval()
    test_losses = []
    all_f1_scores = []

    with torch.no_grad():
        for batch in test_loader:
            input_ids = batch['input_ids']
            attention_mask = batch['attention_mask']
            labels = batch['labels']
            outputs = model(input_ids, attention_mask, labels=labels)
            preds = torch.argmax(outputs.logits, dim=1)
            loss = outputs.loss
            f1 = f1_score(labels.cpu().numpy(), preds.cpu().numpy(), average='binary')

            test_losses.append(loss.item())
            all_f1_scores.append(f1)

    avg_test_loss = sum(test_losses) / len(test_losses)
    avg_f1_score = sum(all_f1_scores) / len(all_f1_scores)
    fabric.print(f"Test Loss\t{avg_test_loss}")
    fabric.print(f"Test F1 Score\t{avg_f1_score}")


In [33]:
from lightning.fabric import Fabric

def main(config):
  fabric = Fabric()
  # fabric.launch() # if multi gpu
  fabric.seed_everything(config.seed)

  from sklearn.model_selection import train_test_split
  train_data, test_data = train_test_split(data, test_size=0.1, random_state=config.seed)
  train_data, val_data = train_test_split(train_data, test_size=0.2, random_state=config.seed)

  train_dataset = WeiboDataset(train_data, config.tokenizer, config.max_len)
  val_dataset = WeiboDataset(val_data, config.tokenizer, config.max_len)
  test_dataset = WeiboDataset(test_data, config.tokenizer, config.max_len)

  train_loader = DataLoader(train_dataset, batch_size=config.batch_size, shuffle=True)
  val_loader = DataLoader(val_dataset, batch_size=config.batch_size, shuffle=False)
  test_loader = DataLoader(test_dataset, batch_size=config.batch_size, shuffle=False)

  model = Classifer(config)
  optimizer = torch.optim.AdamW(model.parameters(), lr=config.lr)
  scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=config.step_size, gamma=config.gamma)
  model, optimizer = fabric.setup(model, optimizer)
  train_loader, val_loader, test_loader = fabric.setup_dataloaders(train_loader, val_loader, test_loader)

  train(config, fabric, model, train_loader, val_loader, optimizer, scheduler)
  test(config, fabric, model, test_loader)


In [35]:
from dataclasses import dataclass
from transformers import BertTokenizer

@dataclass
class Config():
  # seed
  seed = 42

  # bert
  num_labels = 2
  max_len = 256
  tokenizer = BertTokenizer.from_pretrained("bert-base-chinese")

  # loader
  batch_size = 32

  # scheduler
  step_size = 10
  gamma = 0.9

  # optimizer
  lr = 1e-3

  # training loop
  num_epochs = 2
  validate_every_n_epoch = 1


In [36]:
main(config=Config())

INFO: Seed set to 42
INFO:lightning.fabric.utilities.seed:Seed set to 42
Some weights of BertForSequenceClassification were not initialized from the model checkpoint at bert-base-chinese and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


Epoch	1	Total Steps	70
Epoch[1/2]	Step	0	Average Train Loss	0.7799716591835022
Epoch[1/2]	Step	1	Average Train Loss	0.6695522964000702
Epoch[1/2]	Step	2	Average Train Loss	1.2686284184455872
Epoch[1/2]	Step	3	Average Train Loss	1.0832367837429047
Epoch[1/2]	Step	4	Average Train Loss	0.976738166809082
Epoch[1/2]	Step	5	Average Train Loss	0.9190701643625895
Epoch[1/2]	Step	6	Average Train Loss	0.8630381396838597
Epoch[1/2]	Step	7	Average Train Loss	0.9318525269627571
Epoch[1/2]	Step	8	Average Train Loss	0.8766976098219553
Epoch[1/2]	Step	9	Average Train Loss	0.8668433934450149
Epoch[1/2]	Step	10	Average Train Loss	0.8489433174783533
Epoch[1/2]	Step	11	Average Train Loss	0.8148540457089742
Epoch[1/2]	Step	12	Average Train Loss	0.8003633847603431
Epoch[1/2]	Step	13	Average Train Loss	0.7771714500018528
Epoch[1/2]	Step	14	Average Train Loss	0.7594929416974385
Epoch[1/2]	Step	15	Average Train Loss	0.7434717677533627
Epoch[1/2]	Step	16	Average Train Loss	0.7362517293761758
Epoch[1/2]	Step	17	