# Preparing data

In [None]:
# Install the CRF layer (imported below as `torchcrf`) and seqeval, used for the tagging metrics.
%pip install -q pytorch-crf seqeval

In [None]:
# Clone the project repository. The next cell puts the cloned directory on sys.path and makes it
# the working directory (assumes the notebook starts in /kaggle/working — confirm on other hosts).
!git clone https://github.com/bborisggg/chinese_word_segmentation.git

In [None]:
import os
import sys

# The cloned repository is both the import root for the project helpers
# (utils, convert_corpus) and the directory the relative ./data paths resolve against.
_REPO_DIR = '/kaggle/working/chinese_word_segmentation/'
sys.path.insert(1, _REPO_DIR)
os.chdir(_REPO_DIR)

In [None]:
import codecs
import argparse
import pickle
import warnings
import collections
from utils import get_processing_word, read_pretrained_embeddings, is_dataset_tag, make_sure_path_exists, to_id_list
from convert_corpus import convert_corpus
from seqeval.metrics import accuracy_score, f1_score, precision_score, recall_score
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from transformers import BertTokenizerFast, BertForTokenClassification, Trainer, TrainingArguments,BertConfig
import numpy as np
from tqdm import tqdm
from torchcrf import CRF
from copy import deepcopy

In [None]:
# Run the project's corpus conversion step. Presumably this writes the tagged files under
# ./data/ctb/bmes/ that are read below — confirm in convert_corpus.py.
convert_corpus()

In [None]:
# One tagged sentence: parallel lists of word ids (`sentence`) and tag ids (`tags`).
Instance = collections.namedtuple("Instance", "sentence tags")

# Special markers. UNK is appended to the word/char maps and START/STOP to the
# tag map once all datasets have been read.
UNK_TAG = "<UNK>"
NONE_TAG = "<NONE>"
START_TAG = "<START>"
END_TAG = "<STOP>"
PADDING_CHAR = "<*>"

In [None]:
def read_file(filename, w2i, t2i, c2i, max_iter=sys.maxsize, processing_word=get_processing_word(lowercase=False)):
    """
    Read a two-column ("<word> <tag>") file and turn it into a list of instances.

    Sentences are separated by blank lines or ``-DOCSTART-`` lines; a final
    sentence that is not followed by a separator is kept as well. The w2i, t2i
    and c2i dicts are modified in place: every new word, tag and character is
    given the next free index as it is encountered.

    Args:
        filename: path of the UTF-8 encoded file to read.
        w2i: word -> index map, extended in place.
        t2i: tag -> index map, extended in place.
        c2i: character -> index map, extended in place. Words for which
            ``is_dataset_tag`` is true are stored whole rather than per character.
        max_iter: maximum number of sentences to return (``None`` for no limit).
        processing_word: callable applied to every word before it is looked up.

    Returns:
        (instances, vocab_counter): a list of ``Instance(sentence, tags)`` whose
        fields are word-index and tag-index lists, and a Counter of the
        processed words seen.
    """
    instances = []
    vocab_counter = collections.Counter()
    niter = 0
    words, tags = [], []
    with codecs.open(filename, "r", "utf-8") as f:
        for line in f:
            line = line.strip()
            if not line or line.startswith("-DOCSTART-"):
                if words:
                    niter += 1
                    if max_iter is not None and niter > max_iter:
                        break
                    instances.append(Instance(words, tags))
                    words, tags = [], []
            else:
                word, tag = line.split()
                word = processing_word(word)
                vocab_counter[word] += 1
                if word not in w2i:
                    w2i[word] = len(w2i)
                if tag not in t2i:
                    t2i[tag] = len(t2i)
                if is_dataset_tag(word):
                    # Dataset marker tokens are kept as a single "character".
                    if word not in c2i:
                        c2i[word] = len(c2i)
                else:
                    for c in word:
                        if c not in c2i:
                            c2i[c] = len(c2i)
                words.append(w2i[word])
                tags.append(t2i[tag])
    # A file that does not end with a blank line leaves its last sentence
    # unflushed; keep it unless the max_iter limit has already been reached
    # (after a `break`, niter is max_iter + 1 and the pending sentence is dropped).
    if words and (max_iter is None or niter < max_iter):
        instances.append(Instance(words, tags))
    return instances, vocab_counter

In [None]:
# Input splits under data/ctb/bmes and the output pickle path, relative to the working directory.
options = dict(
    training_data='./data/ctb/bmes/train-all.txt',
    dev_data='./data/ctb/bmes/dev.txt',
    test_data='./data/ctb/bmes/test.txt',
    output='dataset/ctb/dataset.pkl',
)

In [None]:
w2i = {}  # mapping from word to index
t2i = {}  # mapping from tag to index
c2i = {}  # mapping from character (or dataset marker token) to index

# All three splits share the same maps, so ids are consistent across them.
# NOTE(review): dev/test words, tags and chars are added to the maps here too,
# i.e. the vocabulary is built from every split — confirm this is intended.
print('Making training dataset')
training_instances, training_vocab = read_file(options['training_data'], w2i, t2i, c2i)
print('Making dev dataset')
dev_instances, dev_vocab = read_file(options['dev_data'], w2i, t2i, c2i)
print('Making test dataset')
test_instances, test_vocab = read_file(options['test_data'], w2i, t2i, c2i)

# Add special tokens / tags / chars to dicts (after all real entries, so they take the highest ids)
w2i[UNK_TAG] = len(w2i)
t2i[START_TAG] = len(t2i)
t2i[END_TAG] = len(t2i)
c2i[UNK_TAG] = len(c2i)


i2w = to_id_list(w2i)  # Inverse mapping (presumably index -> key; see utils.to_id_list)
i2t = to_id_list(t2i)
i2c = to_id_list(c2i)

In [None]:
# Total number of tokens in the test split.
sum(len(inst.sentence) for inst in test_instances)

In [None]:
# Longest training sentence — compare with MAX_LENGTH, beyond which sentences are truncated.
max(len(inst.sentence) for inst in training_instances)

In [None]:
# Fixed sequence length: shorter sentences are padded and longer ones truncated to it.
MAX_LENGTH = 128

In [None]:
num_labels = len(i2t)

# Create a custom Dataset
class WordSegmentationDataset(Dataset):
    """Pads or truncates tagged sentences to a fixed length for token classification.

    Each item is a dict with ``input_ids``, ``attention_mask`` and ``labels``
    (LongTensors of length ``max_length``) plus ``length``, the number of real
    positions kept (never more than ``max_length``).

    Args:
        data: sequence of instances exposing ``sentence`` (token ids) and
            ``tags`` (tag ids) lists of equal length.
        max_length: length every item is padded or truncated to.
        padding_tag: label stored at padded positions; the default -100 is the
            ``ignore_index`` of ``nn.CrossEntropyLoss``.
    """

    def __init__(self, data, max_length=MAX_LENGTH, padding_tag=-100):
        self.data = data
        self.max_length = max_length
        self.padding_tag = padding_tag

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        instance = self.data[idx]
        # Clamp to the truncated length so `length` always agrees with the
        # attention mask; the BiLSTM packs sequences by this value against
        # max_length-wide tensors.
        length = min(len(instance.sentence), self.max_length)
        padding_length = self.max_length - length
        # NOTE(review): inputs are padded with id 0, which may also be a real
        # vocabulary id; the attention mask / `length` is what marks padding.
        input_ids = list(instance.sentence[:length]) + [0] * padding_length
        labels = list(instance.tags[:length]) + [self.padding_tag] * padding_length
        attention_mask = [1] * length + [0] * padding_length

        return {
            'input_ids': torch.tensor(input_ids, dtype=torch.long),
            'attention_mask': torch.tensor(attention_mask, dtype=torch.long),
            'labels': torch.tensor(labels, dtype=torch.long),
            'length': length,
        }

In [None]:
# Train / validation datasets with the default MAX_LENGTH and -100 label padding.
train_dataset = WordSegmentationDataset(data=training_instances)
val_dataset = WordSegmentationDataset(data=dev_instances)

# BERT model

In [None]:
# Small BERT token classifier built from a config alone, i.e. trained from
# scratch (no pretrained weights): 4 layers, 8 heads, 256-d hidden states.
config = BertConfig(
    vocab_size=len(i2c) + 1,  # +1 for padding token
    pad_token_id=0,
    max_position_embeddings=MAX_LENGTH,  # datasets pad/truncate to MAX_LENGTH
    hidden_size=256,
    intermediate_size=512,
    num_hidden_layers=4,
    num_attention_heads=8,
    num_labels=num_labels,
)
# Token-classification head on top of the encoder's per-position outputs
model = BertForTokenClassification(config)

In [None]:
# Define compute_metrics function for evaluation

def compute_metrics(p):
    """Compute seqeval tagging metrics for the Trainer's evaluation loop.

    Args:
        p: (predictions, labels) pair. Predictions hold per-label scores of
            shape (batch, seq, num_labels); labels hold tag ids, with -100 at
            positions that must be ignored.

    Returns:
        dict with accuracy and macro-averaged precision, recall and f1.
    """
    scores, label_ids = p
    pred_ids = np.argmax(scores, axis=2)

    gold_seqs, pred_seqs = [], []
    for pred_row, label_row in zip(pred_ids, label_ids):
        # Keep only the positions that carry a real label.
        kept = [(lbl, pred) for pred, lbl in zip(pred_row, label_row) if lbl != -100]
        gold_seqs.append([i2t[lbl] for lbl, _ in kept])
        pred_seqs.append([i2t[pred] for _, pred in kept])

    return {
        "accuracy": accuracy_score(gold_seqs, pred_seqs),
        "precision": precision_score(gold_seqs, pred_seqs, average='macro'),
        "recall": recall_score(gold_seqs, pred_seqs, average='macro'),
        "f1": f1_score(gold_seqs, pred_seqs, average='macro'),
    }


In [None]:
# Keep Weights & Biases logging off and silence library warnings for the runs below.
os.environ.update(WANDB_DISABLED="true")
warnings.filterwarnings("ignore")

In [None]:
# Training arguments
training_args = TrainingArguments(
    report_to="none",
    output_dir='./results',
    num_train_epochs=10,
    per_device_train_batch_size=32,
    per_device_eval_batch_size=32,
    logging_steps=500,
    # NOTE(review): recent transformers releases rename this to `eval_strategy`;
    # switch if the installed version rejects `evaluation_strategy`.
    evaluation_strategy="steps",
    eval_steps=500,
    # Checkpoints land on evaluation steps, as load_best_model_at_end requires.
    save_steps=500,
    weight_decay=0.01,
    logging_dir='./logs',
    # Reload the checkpoint with the highest validation 'f1' from compute_metrics.
    load_best_model_at_end=True,
    metric_for_best_model='f1',
)

# Initialize Trainer
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=val_dataset,
    compute_metrics=compute_metrics,
)

# Start Training
trainer.train()

In [None]:
# Keep a reference to the trained BERT tagger; `model` is rebound by the following sections.
bert_model = model

# Bi-LSTM

In [None]:
# Create DataLoaders (reusing the BERT datasets, whose labels are padded with -100)
batch_size = 16
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)

In [None]:
# Define the BiLSTM model
class BiLSTMTagger(nn.Module):
    """Embedding -> single-layer bidirectional LSTM -> linear tag scorer.

    Args:
        vocab_size: number of rows in the embedding table.
        tagset_size: number of output tags.
        embedding_dim: size of the input embeddings.
        hidden_dim: total BiLSTM output size, split as hidden_dim // 2 per
            direction (so it should be even to match the linear layer).
        pad_token_id: forwarded to ``nn.Embedding`` as ``padding_idx``; that
            row stays at zero and receives no gradient.
    """

    def __init__(self, vocab_size, tagset_size, embedding_dim=256, hidden_dim=256, pad_token_id=0):
        super(BiLSTMTagger, self).__init__()
        self.embedding = nn.Embedding(vocab_size, embedding_dim, padding_idx=pad_token_id)
        self.lstm = nn.LSTM(embedding_dim, hidden_dim // 2, num_layers=1,
                            bidirectional=True, batch_first=True)
        self.hidden2tag = nn.Linear(hidden_dim, tagset_size)

    def forward(self, input_ids, attention_mask, lengths):
        """Return tag scores of shape (batch, seq_len, tagset_size).

        Args:
            input_ids: (batch, seq_len) LongTensor of token ids.
            attention_mask: unused; padding is handled through ``lengths``.
            lengths: real length of each sequence (list or tensor on any
                device); every length must be >= 1. Positions past a
                sequence's length are zero LSTM outputs, so their scores equal
                the linear layer's bias.
        """
        embeds = self.embedding(input_ids)
        # pack_padded_sequence requires CPU int64 lengths when given a tensor.
        lengths = torch.as_tensor(lengths, dtype=torch.int64).cpu()
        packed_input = nn.utils.rnn.pack_padded_sequence(embeds, lengths,
                                                         batch_first=True, enforce_sorted=False)
        packed_output, _ = self.lstm(packed_input)
        # Use total_length to ensure consistent sequence length
        lstm_out, _ = nn.utils.rnn.pad_packed_sequence(packed_output, batch_first=True,
                                                       total_length=input_ids.size(1))
        return self.hidden2tag(lstm_out)

In [None]:
# Initialize the model
vocab_size = len(i2c) + 1  # +1 for padding token
tagset_size = num_labels
# Input padding id: the datasets pad input_ids with 0. (-100 is the *label*
# ignore index; as an nn.Embedding padding_idx it would wrap around to
# vocab_size - 100 and silently freeze a real character's embedding.)
pad_token_id = 0

In [None]:
model = BiLSTMTagger(vocab_size, tagset_size, pad_token_id=pad_token_id)
# Define loss function and optimizer
# ignore_index=-100 matches the label padding of the datasets in use,
# so padded positions contribute nothing to the loss.
criterion = nn.CrossEntropyLoss(ignore_index=-100)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
# Move model to GPU if available (`device` is reused by the later training cells)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)

In [None]:
num_epochs = 10
# Start below any reachable score so best_state is always set after the first
# epoch, even if the validation F1 never rises above 0.
best_f1 = float('-inf')
best_state = None
for epoch in range(num_epochs):
    model.train()
    total_loss = 0
    for batch in tqdm(train_loader):
        input_ids = batch['input_ids'].to(device)
        attention_mask = batch['attention_mask'].to(device)
        labels = batch['labels'].to(device)
        # Real (unpadded) lengths; left on the CPU for pack_padded_sequence.
        lengths = batch['length']

        optimizer.zero_grad()
        outputs = model(input_ids, attention_mask, lengths)

        # Flatten to (batch * seq_len, tagset_size) scores vs (batch * seq_len,)
        # labels; padded label positions hold -100 and are ignored by criterion.
        outputs = outputs.view(-1, tagset_size)
        labels = labels.view(-1)

        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        total_loss += loss.item()

    avg_train_loss = total_loss / len(train_loader)
    print(f"Epoch {epoch+1}/{num_epochs}, Training Loss: {avg_train_loss:.4f}")

    # Evaluation: compare gold and predicted tag strings on unpadded positions only.
    model.eval()
    true_labels = []
    true_predictions = []
    with torch.no_grad():
        for batch in val_loader:
            input_ids = batch['input_ids'].to(device)
            attention_mask = batch['attention_mask'].to(device)
            labels = batch['labels'].to(device)
            # The dataset exposes the real lengths under 'length' (singular).
            lengths = batch['length']

            outputs = model(input_ids, attention_mask, lengths)

            # Highest-scoring tag at every position
            preds = outputs.argmax(dim=2).cpu().numpy()
            labels = labels.cpu().numpy()

            for pred, label, mask in zip(preds, labels, attention_mask.cpu().numpy()):
                kept = [(lbl, prd) for prd, lbl, keep in zip(pred, label, mask) if keep]
                true_labels.append([i2t[lbl] for lbl, _ in kept])
                true_predictions.append([i2t[prd] for _, prd in kept])
    acc = accuracy_score(true_labels, true_predictions)
    prec = precision_score(true_labels, true_predictions, average='macro')
    rec = recall_score(true_labels, true_predictions, average='macro')
    f1 = f1_score(true_labels, true_predictions, average='macro')
    print(f"Validation Accuracy: {acc:.4f}, Precision: {prec:.4f}, Recall: {rec:.4f}, F1-score: {f1:.4f}\n")
    if f1 > best_f1:
        best_state = deepcopy(model.state_dict())
        best_f1 = f1

In [None]:
# Restore the weights from the epoch with the best validation F1.
model.load_state_dict(best_state)  
bilstm_model = model

# BERT-BiLSTM-CRF

In [None]:
# Rebuild the datasets with a valid tag id (1) at padded label positions: pytorch-crf
# indexes its transition/emission scores with the tag at every position, masked or
# not, so the -100 used for CrossEntropyLoss would be out of range. Padded positions
# are masked out of the CRF likelihood, so which valid id is used does not matter.
train_dataset = WordSegmentationDataset(training_instances, padding_tag=1)
val_dataset = WordSegmentationDataset(dev_instances, padding_tag=1)

# Create DataLoaders
batch_size = 16
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)

In [None]:
class BertBiLSTMCRF(nn.Module):
    """BERT encoder -> BiLSTM -> linear emission scores -> CRF, built from scratch.

    Args:
        num_labels: number of tags scored by the CRF.
        vocab_size, hidden_size, num_hidden_layers, num_attention_heads,
        intermediate_size, max_position_embeddings: forwarded to ``BertConfig``.
        hidden_dim_lstm: total BiLSTM output size (hidden_dim_lstm // 2 per direction).
        num_lstm_layers: number of stacked BiLSTM layers.
        pad_token: ``BertConfig.pad_token_id``, the input id BERT treats as padding.
    """

    def __init__(self, num_labels, vocab_size=30522, hidden_size=256, num_hidden_layers=4,
                 num_attention_heads=8, intermediate_size=512, max_position_embeddings=MAX_LENGTH,
                 hidden_dim_lstm=256, num_lstm_layers=1, pad_token=0):
        super().__init__()
        # Imported here so the class does not rely on a later cell importing it.
        from transformers import BertModel

        config = BertConfig(
            vocab_size=vocab_size,
            hidden_size=hidden_size,
            num_hidden_layers=num_hidden_layers,
            num_attention_heads=num_attention_heads,
            intermediate_size=intermediate_size,
            # Honour the requested position-table size instead of BertConfig's default.
            max_position_embeddings=max_position_embeddings,
            pad_token_id=pad_token,
        )
        self.bert = BertModel(config)
        self.lstm = nn.LSTM(
            input_size=hidden_size, hidden_size=hidden_dim_lstm // 2,
            num_layers=num_lstm_layers, bidirectional=True, batch_first=True
        )
        self.fc = nn.Linear(hidden_dim_lstm, num_labels)
        self.crf = CRF(num_labels, batch_first=True)

    def forward(self, input_ids, attention_mask, labels=None):
        """Return the mean CRF negative log-likelihood when ``labels`` is given,
        otherwise the decoded tag-id list for each sequence.

        ``attention_mask`` doubles as the CRF mask, so it must be 1 at the first
        position of every sequence (a pytorch-crf requirement), and ``labels``
        must hold valid tag ids even at masked positions.
        """
        outputs = self.bert(input_ids=input_ids, attention_mask=attention_mask)
        lstm_out, _ = self.lstm(outputs.last_hidden_state)
        emissions = self.fc(lstm_out)
        mask = attention_mask.bool()
        if labels is not None:
            # CRF returns the log-likelihood; negate it to get a loss.
            return -self.crf(emissions, labels, mask=mask, reduction='mean')
        return self.crf.decode(emissions, mask=mask)

In [None]:
from transformers import BertModel, BertConfig
from tqdm import tqdm

# The datasets pad input_ids with 0, so that is the id BERT must treat as
# padding (padding_tag=1 above only concerns the labels).
model = BertBiLSTMCRF(num_labels=tagset_size, vocab_size=len(i2c)+1, pad_token=0)
model.to(device)
optimizer = torch.optim.AdamW(model.parameters(), lr=4e-4)
num_epochs = 10
# Start below any reachable score so best_state is always set after the first epoch.
best_f1 = float('-inf')
best_state = None
for epoch in range(num_epochs):
    model.train()
    total_loss = 0
    for batch in tqdm(train_loader):
        input_ids = batch['input_ids'].to(device)
        attention_mask = batch['attention_mask'].to(device)
        labels = batch['labels'].to(device)

        optimizer.zero_grad()
        # With labels, the model returns the mean CRF negative log-likelihood.
        loss = model(input_ids, attention_mask, labels)
        loss.backward()
        optimizer.step()
        total_loss += loss.item()

    avg_train_loss = total_loss / len(train_loader)
    print(f"Epoch {epoch+1}/{num_epochs}, Training Loss: {avg_train_loss:.4f}")

    # Evaluation: decode and compare tag strings on unpadded positions only.
    model.eval()
    true_labels = []
    true_predictions = []
    with torch.no_grad():
        for batch in tqdm(val_loader):
            input_ids = batch['input_ids'].to(device)
            attention_mask = batch['attention_mask'].to(device)
            labels = batch['labels'].tolist()

            # Without labels, the model returns one list of tag ids per sequence,
            # truncated to that sequence's masked length.
            preds = model(input_ids, attention_mask)

            for pred, label, mask in zip(preds, labels, attention_mask.cpu().numpy()):
                kept = [(lbl, prd) for prd, lbl, keep in zip(pred, label, mask) if keep]
                true_labels.append([i2t[lbl] for lbl, _ in kept])
                true_predictions.append([i2t[prd] for _, prd in kept])
    acc = accuracy_score(true_labels, true_predictions)
    prec = precision_score(true_labels, true_predictions, average='macro')
    rec = recall_score(true_labels, true_predictions, average='macro')
    f1 = f1_score(true_labels, true_predictions, average='macro')
    print(f"Validation Accuracy: {acc:.4f}, Precision: {prec:.4f}, Recall: {rec:.4f}, F1-score: {f1:.4f}\n")
    if f1 > best_f1:
        best_state = deepcopy(model.state_dict())
        best_f1 = f1

In [None]:
# Restore the BERT-BiLSTM-CRF weights from the epoch with the best validation F1.
model.load_state_dict(best_state)  
bilstmbertcrf_model = model

# Saving models

In [None]:
# Save the Trainer's model to ./bert_model (expected to be the best checkpoint,
# since load_best_model_at_end=True was set).
trainer.save_model("bert_model")

In [None]:
# Only the state dict is saved; reloading needs BertBiLSTMCRF built with the same arguments.
torch.save(bilstmbertcrf_model.state_dict(), "bilstmbertcrf_model.pth")

In [None]:
# Only the state dict is saved; reloading needs BiLSTMTagger built with the same arguments.
torch.save(bilstm_model.state_dict(), "bilstm_model.pth")