# Data Preparation

In [1]:
%pip install -q pytorch-crf seqeval

[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m43.6/43.6 kB[0m [31m1.5 MB/s[0m eta [36m0:00:00[0m
[?25h  Preparing metadata (setup.py) ... [?25l[?25hdone
  Building wheel for seqeval (setup.py) ... [?25l[?25hdone
Note: you may need to restart the kernel to use updated packages.


In [2]:
!git clone https://github.com/bborisggg/chinese_word_segmentation.git

Cloning into 'chinese_word_segmentation'...
remote: Enumerating objects: 121, done.[K
remote: Counting objects: 100% (61/61), done.[K
remote: Compressing objects: 100% (54/54), done.[K
remote: Total 121 (delta 20), reused 0 (delta 0), pack-reused 60 (from 2)[K
Receiving objects: 100% (121/121), 81.30 MiB | 26.62 MiB/s, done.
Resolving deltas: 100% (24/24), done.
Updating files: 100% (36/36), done.


In [3]:
import os
import sys
sys.path.insert(1, '/kaggle/working/chinese_word_segmentation/')
os.chdir('/kaggle/working/chinese_word_segmentation/')

In [4]:
import codecs
import argparse
import pickle
import warnings
import collections
from utils import get_processing_word, read_pretrained_embeddings, is_dataset_tag, make_sure_path_exists, to_id_list
from convert_corpus import convert_corpus
from seqeval.metrics import accuracy_score, f1_score, precision_score, recall_score
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from transformers import BertTokenizerFast, BertForTokenClassification, Trainer, TrainingArguments,BertConfig
import numpy as np
from tqdm import tqdm
from torchcrf import CRF
from copy import deepcopy

In [5]:
convert_corpus()

Converting sighan2005 Simplified Chinese corpus
Converting sighan bakeoff 2005 corpus: pku
Converting sighan bakeoff 2005 corpus: msr
Converting sighan bakeoff 2005 corpus: as
Converting sighan bakeoff 2005 corpus: cityu
Combining sighan2005 corpus to one joint Simplified Chinese corpus
Converting extra 6 corpora
Converting corpus sxu
Converting corpus ctb
Converting corpus zx
Converting corpus cnc
Converting corpus udc
Converting corpus wtb
Combining those 10 corpora to one joint corpus


In [6]:
Instance = collections.namedtuple("Instance", ["sentence", "tags"])

UNK_TAG = "<UNK>"
NONE_TAG = "<NONE>"
START_TAG = "<START>"
END_TAG = "<STOP>"
PADDING_CHAR = "<*>"

In [7]:
def read_file(filename, w2i, t2i, c2i, max_iter=sys.maxsize, processing_word=get_processing_word(lowercase=False)):
    """
    Read in a dataset and turn it into a list of instances.
    Modifies the w2i, t2is and c2i dicts, adding new words/attributes/tags/chars 
    as it s
    ees them.
    """
    instances = []
    vocab_counter = collections.Counter()
    niter = 0
    with codecs.open(filename, "r", "utf-8") as f:
        words, tags = [], []
        for line in f:
            line = line.strip()
            if len(line) == 0 or line.startswith("-DOCSTART-"):
                if len(words) != 0:
                    niter += 1
                    if max_iter is not None and niter > max_iter:
                        break
                    instances.append(Instance(words, tags))
                    words, tags = [], []
            else:
                word, tag = line.split()
                word = processing_word(word)
                vocab_counter[word] += 1
                if word not in w2i:
                    w2i[word] = len(w2i)
                if tag not in t2i:
                    t2i[tag] = len(t2i)
                if is_dataset_tag(word):
                    if word not in c2i:
                        c2i[word] = len(c2i)
                else:
                    for c in word:
                        if c not in c2i:
                            c2i[c] = len(c2i)
                words.append(w2i[word])
                tags.append(t2i[tag])
    return instances, vocab_counter

In [8]:
options = {'training_data':'./data/ctb/bmes/train-all.txt',
          'dev_data':'./data/ctb/bmes/dev.txt',
          'test_data':'./data/ctb/bmes/test.txt',
          'output':'dataset/ctb/dataset.pkl'}

In [9]:
w2i = {}  # mapping from word to index
t2i = {}  # mapping from tag to index
c2i = {}

print('Making training dataset')
training_instances, training_vocab = read_file(options['training_data'], w2i, t2i, c2i)
print('Making dev dataset')
dev_instances, dev_vocab = read_file(options['dev_data'], w2i, t2i, c2i)
print('Making test dataset')
test_instances, test_vocab = read_file(options['test_data'], w2i, t2i, c2i)

# Add special tokens / tags / chars to dicts
w2i[UNK_TAG] = len(w2i)
t2i[START_TAG] = len(t2i)
t2i[END_TAG] = len(t2i)
c2i[UNK_TAG] = len(c2i)


i2w = to_id_list(w2i)  # Inverse mapping
i2t = to_id_list(t2i)
i2c = to_id_list(c2i)

Making training dataset
Making dev dataset
Making test dataset


In [17]:
MAX_LENGTH=128
num_labels = len(i2t)

# Create a custom Dataset
class WordSegmentationDataset(Dataset):
    def __init__(self, data, max_length=MAX_LENGTH, padding_tag = -100):
        self.data = data
        self.max_length = max_length
        self.padding_tag = padding_tag

    def __len__(self):
        return len(self.data)
    
    def __getitem__(self, idx):
        input_ids = self.data[idx].sentence
        labels = self.data[idx].tags
        length = len(input_ids)
        # Padding
        padding_length = self.max_length - len(input_ids)
        if padding_length > 0:
            input_ids = input_ids + ([0] * padding_length)
            labels = labels + ([self.padding_tag] * padding_length)  # -100 will be ignored in loss calculation
            attention_mask = [1] * (self.max_length -padding_length) + ([0] * padding_length)
        else:
            input_ids = input_ids[:self.max_length]
            labels = labels[:self.max_length]
            attention_mask = [1] * self.max_length
        
        return {
            'input_ids': torch.tensor(input_ids, dtype=torch.long),
            'attention_mask': torch.tensor(attention_mask, dtype=torch.long),
            'labels': torch.tensor(labels, dtype=torch.long),
            'length': length
        }

# BERT model

In [18]:
# Define compute_metrics function for evaluation

def compute_metrics(p):
    predictions, labels = p
    predictions = np.argmax(predictions, axis=2)
    
    true_labels = []
    true_predictions = []

    for prediction, label in zip(predictions, labels):
        temp_labels = []
        temp_preds = []

        for pred, lbl in zip(prediction, label):
            if lbl != -100:
                temp_labels.append(i2t[lbl])
                temp_preds.append(i2t[pred])

        true_labels.append(temp_labels)
        true_predictions.append(temp_preds)
    
    return {
        "accuracy": accuracy_score(true_labels, true_predictions),
        "precision": precision_score(true_labels, true_predictions, average='macro'),
        "recall": recall_score(true_labels, true_predictions, average='macro'),
        "f1": f1_score(true_labels, true_predictions, average='macro'),
    }


# Bi-LSTM

In [22]:
class BiLSTMTagger(nn.Module):
       def __init__(self, vocab_size, tagset_size, embedding_dim=256, hidden_dim=256, pad_token_id=0):
           super(BiLSTMTagger, self).__init__()
           self.embedding = nn.Embedding(vocab_size, embedding_dim, padding_idx=pad_token_id)
           self.lstm = nn.LSTM(embedding_dim, hidden_dim // 2, num_layers=1, 
                               bidirectional=True, batch_first=True)
           self.hidden2tag = nn.Linear(hidden_dim, tagset_size)
       
       def forward(self, input_ids, attention_mask, lengths):
           embeds = self.embedding(input_ids)
           packed_input = nn.utils.rnn.pack_padded_sequence(embeds, lengths, 
                                                            batch_first=True, enforce_sorted=False)
           packed_output, _ = self.lstm(packed_input)
           # Use total_length to ensure consistent sequence length
           lstm_out, _ = nn.utils.rnn.pad_packed_sequence(packed_output, batch_first=True, 
                                                          total_length=input_ids.size(1))
           tag_space = self.hidden2tag(lstm_out)
           return tag_space

In [23]:
# Initialize the model
vocab_size = len(i2c) + 1  # +1 for padding token
tagset_size = num_labels
pad_token_id = -100

In [24]:
bilstm_model = BiLSTMTagger(vocab_size, tagset_size, pad_token_id=pad_token_id)
bilstm_model.load_state_dict(torch.load('./models/bilstm_model/bilstm_model.pth', weights_only=True))
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
bilstm_model.to(device)

BiLSTMTagger(
  (embedding): Embedding(4317, 256, padding_idx=4217)
  (lstm): LSTM(256, 128, batch_first=True, bidirectional=True)
  (hidden2tag): Linear(in_features=256, out_features=6, bias=True)
)

# BERT-BiLSTM-CRF

In [27]:
from transformers import BertModel, BertConfig
class BertBiLSTMCRF(nn.Module):
    def __init__(self, num_labels, vocab_size=30522,hidden_size=256,num_hidden_layers=4,
                 num_attention_heads=8,intermediate_size=512,max_position_embeddings=MAX_LENGTH,
                 hidden_dim_lstm=256, num_lstm_layers=1, pad_token=0):
        super().__init__()
        config = BertConfig(
            vocab_size=vocab_size, hidden_size=hidden_size, num_hidden_layers=num_hidden_layers,
            num_attention_heads=num_attention_heads, intermediate_size=intermediate_size,
            pad_token_id=pad_token
        )
        self.bert = BertModel(config)
        self.lstm = nn.LSTM(
            input_size=hidden_size, hidden_size=hidden_dim_lstm // 2,
            num_layers=num_lstm_layers, bidirectional=True, batch_first=True
        )
        self.fc = nn.Linear(hidden_dim_lstm, num_labels)
        self.crf = CRF(num_labels, batch_first=True)

    def forward(self, input_ids, attention_mask, labels=None):
        outputs = self.bert(input_ids=input_ids, attention_mask=attention_mask)
        lstm_out, _ = self.lstm(outputs.last_hidden_state)
        emissions = self.fc(lstm_out)
        if labels is not None:
            loss = -self.crf(emissions, labels, mask=attention_mask.bool(), reduction='mean')
            return loss
        else:
            return self.crf.decode(emissions, mask=attention_mask.bool())

In [43]:
from transformers.models.bert.modeling_bert import BertEmbeddings
from torch.nn.modules.sparse import Embedding
torch.serialization.add_safe_globals([Embedding])

In [47]:
#bilstmbertcrf_model = BertBiLSTMCRF(num_labels=tagset_size, vocab_size=len(i2c)+1, pad_token=1)
#bilstmbertcrf_model.load_state_dict(torch.load('./models/bilstmbertcrf_model/bilstmbertcrf_model.pth', weights_only=False))
bilstmbertcrf_model =torch.load('./models/bilstmbertcrf_model/bilstmbertcrf_model.pth', weights_only=False)
bilstmbertcrf_model.to(device)

BertBiLSTMCRF(
  (bert): BertModel(
    (embeddings): BertEmbeddings(
      (word_embeddings): Embedding(4317, 256, padding_idx=0)
      (position_embeddings): Embedding(512, 256)
      (token_type_embeddings): Embedding(2, 256)
      (LayerNorm): LayerNorm((256,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (encoder): BertEncoder(
      (layer): ModuleList(
        (0-3): 4 x BertLayer(
          (attention): BertAttention(
            (self): BertSdpaSelfAttention(
              (query): Linear(in_features=256, out_features=256, bias=True)
              (key): Linear(in_features=256, out_features=256, bias=True)
              (value): Linear(in_features=256, out_features=256, bias=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
            (output): BertSelfOutput(
              (dense): Linear(in_features=256, out_features=256, bias=True)
              (LayerNorm): LayerNorm((256,), eps=1e-12, elementwise_af

# Evaluating on test data

In [49]:
bert_model = BertForTokenClassification.from_pretrained('./models/bert_model')

test_args = TrainingArguments(
     report_to="none",
    output_dir='./results',
    do_train = False,
    do_predict = True,
    per_device_eval_batch_size = 32,   
    dataloader_drop_last = False    
)

# init trainer
trainer = Trainer(
              model = bert_model, 
              args = test_args, 
              compute_metrics = compute_metrics)

test_dataset = WordSegmentationDataset(test_instances)
test_results = trainer.predict(test_dataset)
test_results.metrics





{'test_loss': 0.22861121594905853,
 'test_accuracy': 0.9237090427939691,
 'test_precision': 0.9047110611595017,
 'test_recall': 0.9136792720531205,
 'test_f1': 0.9091730512207518,
 'test_runtime': 3.1533,
 'test_samples_per_second': 2194.543,
 'test_steps_per_second': 34.567}

In [52]:
test_dataset = WordSegmentationDataset(test_instances)
batch_size = 16
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)
bilstm_model.eval()
true_labels = []
true_predictions = []
with torch.no_grad():
    for batch in test_loader:
        input_ids = batch['input_ids'].to(device)
        attention_mask = batch['attention_mask'].to(device)
        labels = batch['labels'].to(device)
        lengths = batch['length']

        
        outputs = bilstm_model(input_ids, attention_mask, lengths)
        
        # Get predictions
        _, preds = torch.max(outputs, dim=2)
        
        preds = preds.cpu().numpy()
        labels = labels.cpu().numpy()
        
        for pred, label, mask in zip(preds, labels, attention_mask.cpu().numpy()):
            temp_labels = []
            temp_preds = []
            
            for i in range(len(pred)):
                if mask[i]:
                    temp_labels.append(i2t[label[i]])
                    temp_preds.append(i2t[pred[i]])
            true_labels.append(temp_labels)
            true_predictions.append(temp_preds)
acc = accuracy_score(true_labels, true_predictions)
prec = precision_score(true_labels, true_predictions, average='macro')
rec = recall_score(true_labels, true_predictions, average='macro')
f1 = f1_score(true_labels, true_predictions, average='macro')
print(f"Validation Accuracy: {acc:.4f}, Precision: {prec:.4f}, Recall: {rec:.4f}, F1-score: {f1:.4f}\n")

Validation Accuracy: 0.9452, Precision: 0.9334, Recall: 0.9374, F1-score: 0.9354



In [54]:
test_dataset = WordSegmentationDataset(test_instances, padding_tag=1)
batch_size = 16
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)
bilstmbertcrf_model.eval()
true_labels = []
true_predictions = []
for batch in tqdm(test_loader):
    with torch.no_grad():
        input_ids = batch['input_ids'].to(device)
        attention_mask = batch['attention_mask'].to(device)
        labels = batch['labels']
        
        preds = bilstmbertcrf_model(input_ids, attention_mask)
        
        
        for pred, label, mask in zip(preds, labels, attention_mask.cpu().numpy()):
            temp_labels = []
            temp_preds = []
            
            for i in range(len(pred)):
                if mask[i]:
                    temp_labels.append(i2t[label[i]])
                    temp_preds.append(i2t[pred[i]])
            true_labels.append(temp_labels)
            true_predictions.append(temp_preds)
acc = accuracy_score(true_labels, true_predictions)
prec = precision_score(true_labels, true_predictions, average='macro')
rec = recall_score(true_labels, true_predictions, average='macro')
f1 = f1_score(true_labels, true_predictions, average='macro')
print(f"Validation Accuracy: {acc:.4f}, Precision: {prec:.4f}, Recall: {rec:.4f}, F1-score: {f1:.4f}\n")


100%|██████████| 433/433 [00:08<00:00, 49.52it/s]


Validation Accuracy: 0.9569, Precision: 0.9498, Recall: 0.9506, F1-score: 0.9502



# DEMO

In [101]:
sentence1 = '我喜欢编写自然语言处理机器学习任务'

In [102]:
def transform(sentence):
    input_ids = [c2i[el] for el in sentence]
    length = len(input_ids)
    padding_length = MAX_LENGTH - len(input_ids)
    input_ids = input_ids + ([0] * padding_length)
    attention_mask = [1] * (MAX_LENGTH -padding_length) + ([0] * padding_length)
    preds = bilstmbertcrf_model(torch.tensor(input_ids, dtype=torch.long).unsqueeze(0).to(device), torch.tensor(attention_mask, dtype=torch.long).unsqueeze(0).to(device))
    ans = ""
    for i in range(len(preds[0])):
        ans+= sentence[i]
        if preds[0][i]==1 or preds[0][i]==2:
            ans += " "
    return ans

In [103]:
transform(sentence1)

'我 喜欢 编写 自然 语言 处理 机器 学习 任务 '