<a href="https://colab.research.google.com/github/doraemonidol/bert-crf/blob/master/punktuation_ner_2k.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
from google.colab import drive
drive.mount('/content/gdrive/')

Mounted at /content/gdrive/


In [2]:
# import sys
# sys.path.append('/content/gdrive/MyDrive/Colab/punktuation-ner')

In [4]:
%pip install forgebox==0.4.18.5 pytorch_lightning

Note: you may need to restart the kernel to use updated packages.


In [None]:
import os
import sys
import numpy as np
import pandas as pd

In [None]:
def load_data_for_split(file_path):
    texts, labels = [], []
    with open(file_path, 'r', encoding='utf-8') as file:
        for line in file:
            line = line.strip()
            text = line.replace(' ', '')
            label = []
            i = 0
            while (i < len(text)):
                if text[i] in ['，', '。', '？', '！']:
                    label.append('1')
                else:
                    label.append('0')
                i += 1
            text_list = list(text)
            # label[-1] = 2
            # print(text, label)
            # break
            texts.append(text_list)
            labels.append(''.join(label))

            if (len(text_list) != len(label)):
                print('Error:', text, label)
    return texts, labels

In [None]:
from sklearn.model_selection import train_test_split

def recursive_split(data, n_splits, idx=0):
    print(f'Splitting data_{idx} with {len(data)} samples')
    if idx >= n_splits:
        with open(f'data/split6/data_{idx - n_splits}.txt', 'w', encoding='utf-8') as file:
            for text in data['sentence'].to_list():
                file.write(''.join(text) + '\n')
        return
    train_data, test_data = train_test_split(data, test_size=0.5, random_state=42, stratify=data['label'])
    # print(f'Lenght of train_data: {len(train_data)}')
    # print(f'Lenght of test_data: {len(test_data)}')
    recursive_split(train_data, n_splits, idx * 2)
    recursive_split(test_data, n_splits, idx * 2 + 1)

In [None]:
sentences, labels = load_data_for_split('data/train_large_2.txt')

data = pd.DataFrame({'sentence': sentences, 'label': labels})

In [None]:
label_counts = data['label'].value_counts()
valid_labels = label_counts[label_counts >= 4096].index
filtered_data = data[data['label'].isin(valid_labels)]

In [None]:
print(f'Original data with {len(data)} samples')
print(f'Filtered data with {len(filtered_data)} samples')

In [None]:
recursive_split(filtered_data, 4096, 1)

# Punctuation NER

In [32]:
# Forgebox Imports
from forgebox.imports import *
from forgebox.category import Category
import pytorch_lightning as pl
from transformers import AutoTokenizer, BertForTokenClassification, BertModel
from transformers import pipeline
from typing import List
import re
from torch.utils.data import DataLoader, Dataset

In [33]:
# DATA = r'/content/gdrive/MyDrive/Colab/punktuation-ner/data'
DATA = r'C:\Users\Nhat Hung\Documents\GitHub\bert-crf\data\split6'

## Read Metadata

In [34]:
# From DATA, list all the files in the directory ending with .txt
LABELS = [f.name for f in Path(DATA).rglob("*.txt")]
print(LABELS)

['data_0.txt', 'data_1.txt', 'data_10.txt', 'data_100.txt', 'data_1000.txt', 'data_1001.txt', 'data_1002.txt', 'data_1003.txt', 'data_1004.txt', 'data_1005.txt', 'data_1006.txt', 'data_1007.txt', 'data_1008.txt', 'data_1009.txt', 'data_101.txt', 'data_1010.txt', 'data_1011.txt', 'data_1012.txt', 'data_1013.txt', 'data_1014.txt', 'data_1015.txt', 'data_1016.txt', 'data_1017.txt', 'data_1018.txt', 'data_1019.txt', 'data_102.txt', 'data_1020.txt', 'data_1021.txt', 'data_1022.txt', 'data_1023.txt', 'data_1024.txt', 'data_1025.txt', 'data_1026.txt', 'data_1027.txt', 'data_1028.txt', 'data_1029.txt', 'data_103.txt', 'data_1030.txt', 'data_1031.txt', 'data_1032.txt', 'data_1033.txt', 'data_1034.txt', 'data_1035.txt', 'data_1036.txt', 'data_1037.txt', 'data_1038.txt', 'data_1039.txt', 'data_104.txt', 'data_1040.txt', 'data_1041.txt', 'data_1042.txt', 'data_1043.txt', 'data_1044.txt', 'data_1045.txt', 'data_1046.txt', 'data_1047.txt', 'data_1048.txt', 'data_1049.txt', 'data_105.txt', 'data_1050

In [14]:
DATA_SIZE = 0
for label in LABELS:
    DATA_SIZE += len(open(DATA + '/' + label, 'r').readlines())
    break
print(DATA_SIZE)

1024


In [18]:
punkt_regex = r'[^\w\s]'

def position_of_all_punctuation(x):
    return [m.start() for m in re.finditer(punkt_regex, x)]

# simplify the punctuation
eng_punkt_to_cn_dict = {
    ".": "。",
    ",": "，",
    ":": "：",
    ";": "；",
    "?": "？",
    "!": "！",
    "“": "\"",
    "”": "\"",
    "‘": "\'",
    "’": "\'",
    "「": "（",
    "」": "）",
    "『": "\"",
    "』": "\"",
    "（": "（",
    "）": "）",
    "《": "【",
    "》": "】",
    "［": "【",
    "］": "】",
    }

def translate_eng_punkt_to_cn(char):
    if char == "O":
        return char
    if char in eng_punkt_to_cn_dict.values():
        return char
    result = eng_punkt_to_cn_dict.get(char)
    if result is None:
        return "。"
    return result

def punct_ner_pair(sentence):
    positions = position_of_all_punctuation(sentence)
    x = re.sub(punkt_regex, '', sentence)
    y = list("O"*len(x))

    for i, p in enumerate(positions):
        y[p-i-1] = sentence[p]
    p_df = pd.DataFrame({"x":list(x), "y":y})
    p_df["y"] = p_df["y"].apply(translate_eng_punkt_to_cn)
    return p_df

In [19]:
ALL_LABELS = ["O",]+list(eng_punkt_to_cn_dict.values())

In [20]:
cates = Category(ALL_LABELS)

In [28]:
import random


class PunctDataset(Dataset):
    def __init__(
        self,
        data_dir: Path,
        filelist: List[str],
        num_threads: int = 8,
        length: int = 1000,
        size: int = 540
    ):
        """
        Args:
            - filelist: list of file names
            - The dataset will open ```num_threads``` files, and hold
                in memory simoultaneously.
            - num_threads: number of threads to read files,
            - length: number of sentences per batch
            - size: number of characters per sentence
        """
        self.data_dir = Path(data_dir)
        self.filelist = filelist
        self.num_threads = num_threads
        self.length = length
        # open file strings, index is mod of num_threads
        self.current_files = dict(enumerate([""]*length))
        self.string_index = dict(enumerate([0]*length))
        self.to_open_idx = 0
        self.size = size
        self.get_counter = 0
        self.return_string = False

    def __len__(self):
        return self.length

    def __repr__(self):
        return f"PunctDataset: {len(self)}, on {len(self.filelist)} files"

    def new_file(self, idx_mod):
        filename = self.filelist[self.to_open_idx]
        with open(self.data_dir/filename, "r", encoding="utf-8") as f:
            self.current_files[idx_mod] = f.read()

        self.to_open_idx += 1

        # reset to open article file index
        if self.to_open_idx >= len(self.filelist):
            self.to_open_idx = 0

        # reset string_index within new article file
        self.string_index[idx_mod] = 0

        # if self.to_open_idx % 500 == 0:
        #     print(f"went through files:\t{self.to_open_idx}")

    def __getitem__(self, idx):
        idx_mod = self.get_counter % self. num_threads

        if self.string_index[idx_mod] >= len(self.current_files[idx_mod]):
            self.new_file(idx_mod)
        string_idx = self.string_index[idx_mod]

        # slicing a sentence
        sentence = self.current_files[idx_mod][string_idx:string_idx+self.size]

        # move the string_index within current article file
        self.string_index[idx_mod] += self.size

        # move the get_counter
        self.get_counter += 1
        p_df = punct_ner_pair(sentence)
        return list(p_df.x), list(p_df.y)

    def align_offsets(
        self,
        inputs,
        text_labels: List[List[str]],
        words: List[List[str]]
    ):
        """
        inputs: output if tokenizer
        text_labels: labels in form of list of list of strings
        words: words in form of list of list of strings
        """
        labels = torch.zeros_like(inputs.input_ids).long()
        labels -= 100
        text_lables_array = np.empty(labels.shape, dtype=object)
        words_array = np.empty(labels.shape, dtype=object)
        max_len = inputs.input_ids.shape[1]

        # print("Input_ids: ", inputs.input_ids)

        # print("Max Len: ", max_len)

        # print("Text Labels: ", text_labels)

        for row_id, input_ids in enumerate(inputs.input_ids):
            word_pos = inputs.word_ids(row_id)
            # print("Word Pos: ", word_pos)
            for idx, pos in enumerate(word_pos):
                # print("index: ", idx)
                if pos is None:
                    # print("Pos is None")
                    continue
                labels[row_id, idx] = self.cates.c2i[text_labels[row_id][pos]]
                if self.return_string:
                    text_lables_array[row_id,
                                        idx] = text_labels[row_id][pos]
                    words_array[row_id, idx] = words[row_id][pos]

        inputs['labels'] = labels

        if self.return_string:
            inputs['text_labels'] = text_lables_array.tolist()
            inputs['word'] = words_array.tolist()

        # for input_id in inputs['input_ids']:
        #     print("InPuT_iD: ", input_id)
        #     print("Word: ", self.tokenizer.convert_ids_to_tokens(input_id))
        #     print("Word Pos: ", inputs.word_ids(0))
        # for label in inputs['labels']:
        #     print("Label: ", label)

        return inputs

    def collate_fn(self, data):
        """
        data: list of tuple
        """
        words, text_labels = zip(*data)

        inputs = self.tokenizer(
            list(words),
            return_tensors='pt',
            padding=True,
            truncation=True,
            max_length=self.max_len,
            is_split_into_words=True,
            return_offsets_mapping=True,
            add_special_tokens=False,
        )

        # print("Original words", words)

        return self.align_offsets(inputs, text_labels, words)

    def dataloaders(self, tokenizer, cates, max_len: int = 512, batch_size: int = 32):
        self.tokenizer = tokenizer
        self.cates = cates
        self.max_len = max_len
        return DataLoader(
            self,
            batch_size=batch_size,
            shuffle=False,
            collate_fn=self.collate_fn,
        )

    def split(self, ratio: float = 0.9):

        self.length = (self.length // DATA_SIZE + (1 if self.length % DATA_SIZE != 0 else 0)) * DATA_SIZE
        fileCount = self.length // DATA_SIZE

        print(fileCount)

        self.filelist = self.filelist[:fileCount]

        filelist = self.filelist
        random.shuffle(filelist)
        train_filelist = filelist[:int(len(filelist)*ratio)]
        valid_filelist = filelist[int(len(filelist)*ratio):]

        print(f"Train files: {' '.join(train_filelist)}")
        print(f"Valid files: {' '.join(valid_filelist)}")

        valid_dataset = PunctDataset(
            self.data_dir,
            valid_filelist,
            num_threads=self.num_threads,
            length=int(self.length * (1 - ratio)),
            size=self.size,
        )
        train_dataset = PunctDataset(
            self.data_dir,
            train_filelist,
            num_threads=self.num_threads,
            length=int(self.length * ratio),
            size=self.size,
        )
        return train_dataset, valid_dataset

Create dataset object

* Length is the length of the epoch
* Size: is the sequence length
* num_threads: num of files that is opening at the same time

In [29]:
ds = PunctDataset(DATA + '/split', LABELS, num_threads=1, length=10240, size=512)
train_ds, valid_ds = ds.split(0.8)

10
Train files: data_101.txt data_0.txt data_1.txt data_100.txt data_103.txt data_106.txt data_10.txt data_102.txt
Valid files: data_104.txt data_105.txt


### lightning data module

In [30]:
class PunctDataModule(pl.LightningDataModule):
    def __init__(self, train_ds, valid_ds, tokenizer, cates,
    max_len=512, batch_size=32):
        super().__init__()
        self.train_ds, self.valid_ds = train_ds, valid_ds
        self.tokenizer = tokenizer
        self.cates = cates
        self.max_len = max_len
        self.batch_size = batch_size

    def split_data(self):

        return train_ds, valid_ds

    def train_dataloader(self):
        return self.train_ds.dataloaders(
            self.tokenizer,
            self.cates,
            self.max_len,
            self.batch_size,
        )

    def val_dataloader(self):
        return self.valid_ds.dataloaders(
            self.tokenizer,
            self.cates,
            self.max_len,
            self.batch_size*4)

## Load Pretrained

In [31]:
tokenizer = AutoTokenizer.from_pretrained("bert-base-chinese")

In [16]:
from forgebox.thunder.callbacks import DataFrameMetricsCallback
from forgebox.hf.train import NERModule

In [17]:


# Define BERT-CRF Model using PyTorch Lightning
class BERT_CRF(pl.LightningModule):
    def __init__(self, model, num_labels, learning_rate=1e-5):
        super().__init__()
        self.save_hyperparameters(ignore=['model'])

        self.model = model
        self.crf = CRF(num_labels, batch_first=True)
        self.dropout = nn.Dropout(0.1)
        self.fc = nn.Linear(self.model.config.hidden_size, num_labels)
        self.lr = learning_rate

    def forward(self, input_ids, attention_mask, labels=None):
        output = self.model(
            input_ids,
            attention_mask=attention_mask,
            labels=labels
        )
        sequence_output = output.last_hidden_state
        sequence_output = self.dropout(sequence_output)
        logits = self.fc(sequence_output)
        if labels is not None:
            loss = -self.crf(logits, labels, mask=attention_mask.bool(), reduction='mean')
            return {"loss": loss, "logits": logits}
        else:
            return {"logits": logits}

    def training_step(self, batch, batch_idx):
        input_ids = batch['input_ids']
        attention_mask = batch['attention_mask']
        labels = batch['labels']
        loss, logits = self(input_ids, attention_mask, labels)
        self.log('train_loss', loss)
        self.log('train_acc', self.crf.accuracy(logits, labels, attention_mask.byte()))
        return loss

    def validation_step(self, batch, batch_idx):
        input_ids = batch['input_ids']
        attention_mask = batch['attention_mask']
        labels = batch['labels']
        loss, logits = self(input_ids, attention_mask, labels)
        self.log('val_loss', loss)
        self.log('val_acc', self.crf.accuracy(logits, labels, attention_mask.byte()))
        return loss

    def configure_optimizers(self):
        optimizer = torch.optim.AdamW(self.parameters(), lr=self.lr)
        return optimizer


Load pretrained model with proper num of categories

In [18]:
model = BertForTokenClassification.from_pretrained("bert-base-chinese", num_labels=len(cates),)

model.safetensors:   0%|          | 0.00/412M [00:00<?, ?B/s]

Some weights of BertForTokenClassification were not initialized from the model checkpoint at bert-base-chinese and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [19]:
print(train_ds.length)
print(valid_ds.length)

3200
799


In [20]:
data_module = PunctDataModule(train_ds, valid_ds, tokenizer, cates,
                              batch_size=32,)

### Run data pipeline

In [21]:
inputs = next(iter(data_module.val_dataloader()))
print(inputs.input_ids)

tensor([[3189, 4212, 7198,  ...,    0,    0,    0],
        [6843, 4242, 1898,  ...,    0,    0,    0],
        [7412, 5782, 2255,  ..., 7491,    0,    0],
        ...,
        [3717,  756, 5682,  ..., 7414,    0,    0],
        [6816,  881,  782,  ...,    0,    0,    0],
        [4080, 3958, 3726,  ...,    0,    0,    0]])


In [22]:
inputs.input_ids.shape

torch.Size([128, 439])

In [23]:
inputs.labels.shape

torch.Size([128, 439])

## NER tranining module

In [24]:
module = NERModule(model)

In [25]:
save_callback = pl.callbacks.ModelCheckpoint(
    dirpath=f"{DATA}/ckpoint",
    save_top_k=1,
    verbose=True,
    monitor='val_loss',
    mode='min',
)
df_show = DataFrameMetricsCallback()

Reset the configure_optimizers function

In [26]:
def configure_optimizers(self):
        # discriminative learning rate
    param_groups = [
            {'params': self.model.bert.parameters(), 'lr': 5e-6},
            {'params': self.model.classifier.parameters(), 'lr': 1e-3},
        ]
    optimizer = torch.optim.Adam(param_groups, lr=1e-3)
    return optimizer

NERModule.configure_optimizers = configure_optimizers

In [27]:
import torch
print(torch.cuda.is_available())

False


Trainer

In [28]:
trainer = pl.Trainer(
    accelerator='gpu',
    devices=1,
    max_epochs=30,
    callbacks=[df_show, save_callback],
    )

MisconfigurationException: No supported gpu backend found!

In [None]:
trainer.fit(module, datamodule=data_module)

## Load the best model

In [None]:
module = NERModule.load_from_checkpoint(save_callback.best_model_path, model=model)

In [None]:
module.model.config.id2label = dict(enumerate(cates.i2c))
module.model.config.label2id = cates.c2i.dict

In [None]:
from transformers import pipeline

In [None]:
module.model = module.model.eval()
module.model = module.model.cpu()

In [None]:
# prompt: Store the model

torch.save(module.model, '/content/gdrive/MyDrive/Colab/punktuation-ner/best_ckpoint/punct_model_2.pth')


In [None]:
ner = pipeline("ner",module.model,tokenizer=tokenizer)

In [None]:
def mark_sentence(x: str):
    outputs = ner(x)
    print(outputs)
    print("hello")
    x_list = list(x)
    for i, output in enumerate(outputs):
        x_list.insert(output['end']+i, output['entity'])
    return "".join(x_list)

In [None]:
mark_sentence("洛水神龜單應兆天數九地數九九九八十一數數數混成三大道道合元始天尊一成有感")

In [None]:
mark_sentence("""郡邑置夫子庙于学以嵗时释奠盖自唐贞观以来未之或改我宋有天下因其制而损益之姑苏当浙右要区规模尤大更建炎戎马荡然无遗虽修学宫于荆榛瓦砾之余独殿宇未遑议也每春秋展礼于斋庐已则置不问殆为阙典今寳文阁直学士括苍梁公来牧之明年实绍兴十有一禩也二月上丁修祀既毕乃愓然自咎揖诸生而告之曰天子不以汝嘉为不肖俾再守兹土顾治民事神皆守之职惟是夫子之祀教化所基尤宜严且谨而拜跪荐祭之地卑陋乃尔其何以掲防妥灵汝嘉不敢避其责曩常去此弥年若有所负尚安得以罢輭自恕复累后人乎他日或克就绪愿与诸君落之于是谋之僚吏搜故府得遗材千枚取赢资以给其费鸠工庀役各举其任嵗月讫工民不与知像设礼器百用具修至于堂室廊序门牖垣墙皆一新之""")

In [None]:
# import pandas as pd
# df = pd.read_csv('test.txt', sep='\t')

# text_list = df['input'].tolist()[0:]
# predicted = []
# for text in text_list:
#     predicted.append(mark_sentence(text))

# # predicted save to csv
# df2 = pd.DataFrame(predicted)
# # output to csv but ignore header
# df2.to_csv('predicted.txt', index=False, header=False)
# df2 = pd.read_csv('predicted.txt', sep='\t', header=None)
# df2.columns = ['predicted']

# punctuation = ['，', '。', '！', '？', '；', '：', '、', '「', '」', '『', '』', '（', '）', '〔', '〕', '【', '】', '《', '》', '〈', '〉', '﹏', '＿', '～', '—', '…', '‥', '﹑', '﹔', '﹖', '﹪', '﹙', '﹚', '﹛', '﹜', '﹟', '﹠', '﹡', '﹢', '﹣', '﹤', '﹥', '﹦', '﹨', '﹩', '﹪', '﹫', '＃', '＄', '％', '＆', '＊', '＋', '－', '／', '＜', '＝', '＞', '＠', '＾', '＿', '｀', '｜', '～', '∕', '∥']

# predicted_parsed = []
# for text in df2['predicted']:
#     i = 0
#     line = []
#     while i < len(text) - 1:
#         if text[i + 1] in punctuation:
#             line.append(1)
#             i += 1
#         else:
#             line.append(0)
#         i += 1
#     if line[-1] != 1:
#         line.append(1)
#     predicted_parsed.append(' - '.join([str(x) for x in line]))

# df3 = pd.DataFrame(predicted_parsed)
# df3.to_csv('predicted_parsed.txt', index=False, header=False)