In [1]:
import pickle

with open("texts.pkl",'rb') as f:
    texts  = pickle.loads(f.read())
with open("all_label_bio.pkl",'rb') as f:
    labels  = pickle.loads(f.read())
print(len(texts), len(labels))

43026 43026


In [2]:
import torch  
from torch.utils.data import DataLoader, TensorDataset  
from transformers import BertTokenizer, BertForTokenClassification, AdamW, get_linear_schedule_with_warmup  
from tqdm import tqdm

dic = {"O":0, "B-LOC":1, "I-LOC":2}  
  
# initialize the tokenizer
tokenizer = BertTokenizer.from_pretrained('../bert-base-uncased')  
  
# turn texts into input_ids, attention_masks, label_ids
input_ids = []  
attention_masks = []  
label_ids = []  
  
max_seq_length = 100 
  
for text, label in tqdm(zip(texts, labels)):  
    encoded_text = tokenizer(text, return_tensors='pt', padding='max_length', truncation=True, max_length=max_seq_length)  
    input_ids.append(encoded_text['input_ids'].squeeze(0))  
    attention_masks.append(encoded_text['attention_mask'].squeeze(0))  
      
    # turn label into ids 
    encoded_label = [0]+[dic.get(l, 0) for l in label] 
    encoded_label += [0] * (max_seq_length - len(encoded_label)) 
    encoded_label = encoded_label[:100] 
    encoded_label = torch.tensor(encoded_label, dtype=torch.long)  
    label_ids.append(encoded_label)  
from torch.utils.data import random_split
# turn lists into tensors
input_ids = torch.stack(input_ids, dim=0)  
attention_masks = torch.stack(attention_masks, dim=0)  
label_ids = torch.stack(label_ids, dim=0) 

# create dataset
dataset = TensorDataset(input_ids, attention_masks, label_ids)

  torch.utils._pytree._register_pytree_node(
  torch.utils._pytree._register_pytree_node(
43026it [00:24, 1747.15it/s]


In [4]:
# Print the first few samples of the dataset
for i in range(5):
    input_id = dataset[i][0]
    attention_mask = dataset[i][1]
    label_id = dataset[i][2]
    
    print(f"Sample {i + 1}:")
    print("Input IDs:", input_id)
    print("Attention Mask:", attention_mask)
    print("Label IDs:", label_id)
    print()

Sample 1:
Input IDs: tensor([  101,  1045,  2444,  1999, 19372,  2237,  2803,  1010,  2012,  2028,
         2051,  1996,  3007,  1997,  2563,  1006,   999,  1007,  1998,  2073,
         2952,  5660,  2288,  2496,   999,  2009,  2036,  2018,  1996,  2922,
         5645,  4170,  1999,  2563,  1998,  2028,  1997,  1996,  2922, 15947,
         2127,  2332,  2888,  5296,  2009,  2091,  1999,  1996, 14883,  2015,
         1006,  1996,  2277,  2145,  3464,  1998,  2064,  2022,  2464,  2013,
         2115,  3332,  1007,  1012,  2009,  1005,  1055,  2747, 14996,  1037,
         2843,  1997, 15582,  2007,  3488,  2005,  1037,  3518,  2143,  2996,
         1010,  4435,  2047,  6023,  2803,  1998,  1037,  2047,  2276,  1012,
         2045,  2024,  2048,  9726,  2015,  2306,  3788,  3292,  1010,   102])
Attention Mask: tensor([1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
        1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
        1, 1, 1, 1, 1,

In [3]:
from torch.utils.data import DataLoader, random_split

# Split the dataset into training and testing sets
train_size = int(0.8 * len(dataset))
test_size = len(dataset) - train_size
train_dataset, test_dataset = random_split(dataset, [train_size, test_size])
print(len(train_dataset), len(test_dataset))

# Create data loaders
train_data_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
test_data_loader = DataLoader(test_dataset, batch_size=32, shuffle=False)


34420 8606


# BERT Trainning

In [9]:
import torch
from transformers import BertTokenizer, BertForTokenClassification, AdamW, get_linear_schedule_with_warmup
from torch.utils.data import DataLoader, TensorDataset, random_split
from sklearn.metrics import classification_report

# load the pre-trained model
model = BertForTokenClassification.from_pretrained('../bert-base-uncased', num_labels=3)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = model.to(device)

# define the optimizer
loss_fn = torch.nn.CrossEntropyLoss().to(device)
optimizer = AdamW(model.parameters(), lr=1e-4, eps=1e-8)

# define the scheduler
epochs = 3
total_steps = len(train_data_loader) * epochs
scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=0, num_training_steps=total_steps)


Some weights of BertForTokenClassification were not initialized from the model checkpoint at ../bert-base-uncased and are newly initialized: ['classifier.weight', 'classifier.bias']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [7]:
from torch.nn import functional as F  

def test():
    # Evaluation loop
    model.eval()
    all_labels = []
    all_preds = []
    with torch.no_grad():
        for batch in tqdm(test_data_loader):
            b_input_ids = batch[0].to(device)
            b_input_mask = batch[1].to(device)
            b_labels = batch[2].to(device)

            outputs = model(b_input_ids, attention_mask=b_input_mask, labels=b_labels)
            batch_predictions = outputs.logits
            # catch the predicted ids
            predicted_ids = torch.argmax(F.log_softmax(batch_predictions, dim=2), dim=2) 

            loss = outputs.loss

            # Only keep predictions for non-padding tokens
            for i, mask in enumerate(b_input_mask):
                true_len = torch.sum(mask).item()
                all_labels.extend(b_labels[i, :true_len].cpu().numpy())
                all_preds.extend(predicted_ids[i, :true_len].cpu().numpy())
                # print(len(all_labels), len(all_preds))

    print(f'Test Loss: {loss.item():.4f}')
    # Compute evaluation metrics
    report = classification_report(all_labels, all_preds, digits=4)
    print(report)
    return loss.item()

best_loss = test()
print('test loss before train:', best_loss)

100%|██████████| 269/269 [02:29<00:00,  1.79it/s]


Test Loss: 1.2260
              precision    recall  f1-score   support

           0     0.7393    0.1290    0.2197    584755
           1     0.0270    0.3775    0.0504     32399
           2     0.0243    0.0692    0.0360     33958

    accuracy                         0.1383    651112
   macro avg     0.2636    0.1919    0.1020    651112
weighted avg     0.6666    0.1383    0.2017    651112

test loss before train: 1.2259918451309204


In [10]:
# Training loop
for epoch in range(epochs):
    model.train()
    # Training loop for one epoch
    for batch_idx, batch in enumerate(train_data_loader):
        b_input_ids = batch[0].to(device)
        b_input_mask = batch[1].to(device)
        b_labels = batch[2].to(device)

        optimizer.zero_grad()
        outputs = model(b_input_ids, attention_mask=b_input_mask, labels=b_labels)
        loss = outputs.loss
        loss.backward()

        optimizer.step()
        scheduler.step()

        if (batch_idx + 1) % 100 == 0:
            print(f'Epoch [{epoch+1}/{epochs}], Step [{batch_idx+1}/{len(train_data_loader)}], Loss: {loss.item():.4f}')

    print(f'Epoch [{epoch+1}/{epochs}], Loss: {loss.item():.4f}')

    now_test_loss = test()
    # Check if current loss is the best
    if now_test_loss < best_loss:
        best_loss = now_test_loss
        # save the best model
        print('save model')
        torch.save(model.state_dict(), 'bert_ner_model.pt')
        
    model.train()

print('Best loss:', best_loss)


Epoch [1/3], Step [100/1076], Loss: 0.0667
Epoch [1/3], Step [200/1076], Loss: 0.0668
Epoch [1/3], Step [300/1076], Loss: 0.0349
Epoch [1/3], Step [400/1076], Loss: 0.0400
Epoch [1/3], Step [500/1076], Loss: 0.0257
Epoch [1/3], Step [600/1076], Loss: 0.0406
Epoch [1/3], Step [700/1076], Loss: 0.0647
Epoch [1/3], Step [800/1076], Loss: 0.0382
Epoch [1/3], Step [900/1076], Loss: 0.0345
Epoch [1/3], Step [1000/1076], Loss: 0.0309
Epoch [1/3], Loss: 0.0360


100%|██████████| 269/269 [02:35<00:00,  1.73it/s]


Test Loss: 0.0320
              precision    recall  f1-score   support

           0     0.9942    0.9894    0.9918    584755
           1     0.9231    0.9482    0.9355     32399
           2     0.8874    0.9387    0.9123     33958

    accuracy                         0.9847    651112
   macro avg     0.9349    0.9588    0.9465    651112
weighted avg     0.9851    0.9847    0.9849    651112

save model
Epoch [2/3], Step [100/1076], Loss: 0.0221
Epoch [2/3], Step [200/1076], Loss: 0.0520
Epoch [2/3], Step [300/1076], Loss: 0.0276
Epoch [2/3], Step [400/1076], Loss: 0.0262
Epoch [2/3], Step [500/1076], Loss: 0.0320
Epoch [2/3], Step [600/1076], Loss: 0.0290
Epoch [2/3], Step [700/1076], Loss: 0.0144
Epoch [2/3], Step [800/1076], Loss: 0.0284
Epoch [2/3], Step [900/1076], Loss: 0.0313
Epoch [2/3], Step [1000/1076], Loss: 0.0165
Epoch [2/3], Loss: 0.0376


100%|██████████| 269/269 [02:36<00:00,  1.71it/s]


Test Loss: 0.0317
              precision    recall  f1-score   support

           0     0.9941    0.9921    0.9931    584755
           1     0.9426    0.9483    0.9455     32399
           2     0.9129    0.9386    0.9256     33958

    accuracy                         0.9871    651112
   macro avg     0.9499    0.9597    0.9547    651112
weighted avg     0.9873    0.9871    0.9872    651112

save model
Epoch [3/3], Step [100/1076], Loss: 0.0182
Epoch [3/3], Step [200/1076], Loss: 0.0144
Epoch [3/3], Step [300/1076], Loss: 0.0161
Epoch [3/3], Step [400/1076], Loss: 0.0103
Epoch [3/3], Step [500/1076], Loss: 0.0166
Epoch [3/3], Step [600/1076], Loss: 0.0145
Epoch [3/3], Step [700/1076], Loss: 0.0251
Epoch [3/3], Step [800/1076], Loss: 0.0231
Epoch [3/3], Step [900/1076], Loss: 0.0096
Epoch [3/3], Step [1000/1076], Loss: 0.0195
Epoch [3/3], Loss: 0.0105


100%|██████████| 269/269 [02:58<00:00,  1.50it/s]


Test Loss: 0.0342
              precision    recall  f1-score   support

           0     0.9953    0.9915    0.9934    584755
           1     0.9409    0.9564    0.9486     32399
           2     0.9073    0.9518    0.9290     33958

    accuracy                         0.9877    651112
   macro avg     0.9478    0.9666    0.9570    651112
weighted avg     0.9880    0.9877    0.9878    651112

Best loss: 0.03171028941869736
