In [14]:
from transformers import BertForTokenClassification, BertTokenizer, Trainer, TrainingArguments
from datasets import Dataset
import pandas as pd
import numpy as np
import torch
import torch
from torch import nn
import ast
from transformers import BertModel
from sklearn.metrics import accuracy_score, f1_score


In [15]:
class DocNerBERT(nn.Module):
    def __init__(self, num_doc_labels, num_token_labels):
        super().__init__()
        
        self.bert = BertModel.from_pretrained('bert-base-uncased')
        self.dropout = nn.Dropout(0.1)

        self.doc_classifier = nn.Linear(self.bert.config.hidden_size, num_doc_labels)
        self.token_classifier = nn.Linear(self.bert.config.hidden_size, num_token_labels)

        self.loss = nn.CrossEntropyLoss()

    def forward(self, input_ids, attention_mask, doc_labels=None, token_labels=None):
        outputs = self.bert(input_ids, attention_mask = attention_mask)

        # Document classification
        doc_output = outputs.pooler_output
        doc_output = self.dropout(doc_output)
        doc_logits = self.doc_classifier(doc_output)
        
        # Token classification
        token_output = outputs.last_hidden_state
        token_output = self.dropout(token_output)
        token_logits = self.token_classifier(token_output)

        if doc_labels is not None and token_labels is not None:
            doc_loss = self.loss(doc_logits, doc_labels)
            token_loss = self.loss(token_logits.view(-1, token_logits.shape[-1]), token_labels.view(-1))
            loss = doc_loss + token_loss

        return {
            'loss': loss,
            'doc_logits': doc_logits,
            'token_logits': token_logits
        }


In [16]:
def prepare_data(df):
    texts = df['tokens'].apply(lambda x: ' '.join(ast.literal_eval(x))).tolist()
    doc_labels = df['sentence_label'].tolist()
    token_labels = df['ner_tags'].apply(ast.literal_eval).tolist()
    
    return {
        "text": texts,
        "doc_labels": doc_labels,
        "token_labels": token_labels
    }

def tokenize_and_align_labels(examples, tokenizer, max_length=128):
    tokenized_inputs = tokenizer(examples["text"], truncation=True, padding="max_length", max_length=max_length)
    
    labels = []
    for i, label in enumerate(examples["token_labels"]):
        word_ids = tokenized_inputs.word_ids(batch_index=i)
        previous_word_idx = None
        label_ids = []
        for word_idx in word_ids:
            if word_idx is None:
                label_ids.append(-100)
            elif word_idx != previous_word_idx:
                label_ids.append(label[word_idx] if word_idx < len(label) else -100)
            else:
                label_ids.append(-100)
            previous_word_idx = word_idx
        labels.append(label_ids)
    
    tokenized_inputs["token_labels"] = labels
    tokenized_inputs["doc_labels"] = examples["doc_labels"]
    return tokenized_inputs

def compute_metrics(eval_pred):
    logits, labels = eval_pred
    doc_logits, token_logits = logits
    doc_labels, token_labels = labels

    doc_predictions = np.argmax(doc_logits, axis=-1)
    token_predictions = np.argmax(token_logits, axis=-1)

    # Compute document classification accuracy
    doc_correct = (doc_predictions == doc_labels).sum()
    doc_total = len(doc_labels)
    doc_accuracy = doc_correct / doc_total

    # Compute token classification accuracy (ignoring padding tokens)
    token_correct = ((token_predictions == token_labels) & (token_labels != -100)).sum()
    token_total = (token_labels != -100).sum()
    token_accuracy = token_correct / token_total

    return {
        "doc_accuracy": doc_accuracy,
        "token_accuracy": token_accuracy
    }


In [17]:
# Read CSV file
df = pd.read_csv('train_augmented.csv')

# Prepare data
data = prepare_data(df)


In [18]:

# Initialize tokenizer
tokenizer = BertTokenizerFast.from_pretrained('bert-base-uncased')


In [19]:

# Create a Dataset
dataset = Dataset.from_dict(data)
tokenized_dataset = dataset.map(
    lambda examples: tokenize_and_align_labels(examples, tokenizer),
    batched=True,
    remove_columns=dataset.column_names
)

Map: 100%|██████████| 1000/1000 [00:00<00:00, 4919.91 examples/s]


In [20]:


# Initialize model
num_doc_labels = len(df['sentence_label'].unique())
num_token_labels = max(max(ast.literal_eval(x)) for x in df['ner_tags']) + 1
model = DocNerBERT(num_doc_labels, num_token_labels)

# Define training arguments
training_args = TrainingArguments(
    output_dir="./results",
    num_train_epochs=3,
    per_device_train_batch_size=8,
    per_device_eval_batch_size=8,
    warmup_steps=500,
    weight_decay=0.01,
    logging_dir="./logs",
    logging_steps=10,
)

In [21]:


# Initialize Trainer
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=tokenized_dataset,
    compute_metrics=compute_metrics
)

# Train the model
trainer.train()

# Save the model
trainer.save_model("./docnerbert_model")

  3%|▎         | 10/375 [01:18<44:00,  7.23s/it] 

{'loss': 3.8495, 'grad_norm': 12.97952651977539, 'learning_rate': 1.0000000000000002e-06, 'epoch': 0.08}


  5%|▌         | 20/375 [02:27<40:58,  6.92s/it]

{'loss': 3.7017, 'grad_norm': 11.461528778076172, 'learning_rate': 2.0000000000000003e-06, 'epoch': 0.16}


  8%|▊         | 30/375 [03:37<39:56,  6.95s/it]

{'loss': 3.5959, 'grad_norm': 14.218560218811035, 'learning_rate': 3e-06, 'epoch': 0.24}


 11%|█         | 40/375 [04:46<38:21,  6.87s/it]

{'loss': 3.3548, 'grad_norm': 12.133085250854492, 'learning_rate': 4.000000000000001e-06, 'epoch': 0.32}


 13%|█▎        | 50/375 [05:55<37:47,  6.98s/it]

{'loss': 3.0544, 'grad_norm': 10.131518363952637, 'learning_rate': 5e-06, 'epoch': 0.4}


 16%|█▌        | 60/375 [07:05<36:28,  6.95s/it]

{'loss': 2.7357, 'grad_norm': 9.325644493103027, 'learning_rate': 6e-06, 'epoch': 0.48}


 19%|█▊        | 70/375 [08:15<35:14,  6.93s/it]

{'loss': 2.4257, 'grad_norm': 10.599355697631836, 'learning_rate': 7.000000000000001e-06, 'epoch': 0.56}


 21%|██▏       | 80/375 [09:25<34:21,  6.99s/it]

{'loss': 2.2537, 'grad_norm': 7.2992072105407715, 'learning_rate': 8.000000000000001e-06, 'epoch': 0.64}


 24%|██▍       | 90/375 [10:41<33:18,  7.01s/it]

{'loss': 2.1955, 'grad_norm': 9.399588584899902, 'learning_rate': 9e-06, 'epoch': 0.72}


 27%|██▋       | 100/375 [11:53<33:06,  7.22s/it]

{'loss': 2.1767, 'grad_norm': 9.410879135131836, 'learning_rate': 1e-05, 'epoch': 0.8}


 29%|██▉       | 110/375 [13:15<33:09,  7.51s/it]

{'loss': 1.8361, 'grad_norm': 9.77276611328125, 'learning_rate': 1.1000000000000001e-05, 'epoch': 0.88}


 32%|███▏      | 120/375 [14:26<30:43,  7.23s/it]

{'loss': 1.802, 'grad_norm': 11.23577880859375, 'learning_rate': 1.2e-05, 'epoch': 0.96}


 35%|███▍      | 130/375 [15:39<29:37,  7.25s/it]

{'loss': 1.8478, 'grad_norm': 8.7840576171875, 'learning_rate': 1.3000000000000001e-05, 'epoch': 1.04}


 37%|███▋      | 140/375 [16:49<27:11,  6.94s/it]

{'loss': 1.7572, 'grad_norm': 6.267681121826172, 'learning_rate': 1.4000000000000001e-05, 'epoch': 1.12}


 40%|████      | 150/375 [17:57<25:35,  6.82s/it]

{'loss': 1.4731, 'grad_norm': 5.59513521194458, 'learning_rate': 1.5e-05, 'epoch': 1.2}


 43%|████▎     | 160/375 [19:05<24:27,  6.83s/it]

{'loss': 1.4271, 'grad_norm': 5.680927276611328, 'learning_rate': 1.6000000000000003e-05, 'epoch': 1.28}


 45%|████▌     | 170/375 [20:16<23:37,  6.91s/it]

{'loss': 1.4987, 'grad_norm': 5.352529525756836, 'learning_rate': 1.7000000000000003e-05, 'epoch': 1.36}


 48%|████▊     | 180/375 [21:25<22:23,  6.89s/it]

{'loss': 1.3541, 'grad_norm': 10.9909086227417, 'learning_rate': 1.8e-05, 'epoch': 1.44}


 51%|█████     | 190/375 [22:33<21:00,  6.81s/it]

{'loss': 1.3008, 'grad_norm': 11.67233943939209, 'learning_rate': 1.9e-05, 'epoch': 1.52}


 53%|█████▎    | 200/375 [23:46<21:23,  7.33s/it]

{'loss': 1.3435, 'grad_norm': 7.6928935050964355, 'learning_rate': 2e-05, 'epoch': 1.6}


 56%|█████▌    | 210/375 [24:57<19:20,  7.03s/it]

{'loss': 1.4332, 'grad_norm': 9.19049072265625, 'learning_rate': 2.1e-05, 'epoch': 1.68}


 59%|█████▊    | 220/375 [26:08<17:59,  6.97s/it]

{'loss': 1.2855, 'grad_norm': 13.92219352722168, 'learning_rate': 2.2000000000000003e-05, 'epoch': 1.76}


 61%|██████▏   | 230/375 [27:17<16:47,  6.95s/it]

{'loss': 1.187, 'grad_norm': 4.876608371734619, 'learning_rate': 2.3000000000000003e-05, 'epoch': 1.84}


 64%|██████▍   | 240/375 [28:26<15:37,  6.94s/it]

{'loss': 1.1155, 'grad_norm': 12.742841720581055, 'learning_rate': 2.4e-05, 'epoch': 1.92}


 67%|██████▋   | 250/375 [29:36<14:23,  6.91s/it]

{'loss': 1.2688, 'grad_norm': 7.570932865142822, 'learning_rate': 2.5e-05, 'epoch': 2.0}


 69%|██████▉   | 260/375 [30:45<13:14,  6.91s/it]

{'loss': 1.131, 'grad_norm': 21.518108367919922, 'learning_rate': 2.6000000000000002e-05, 'epoch': 2.08}


 72%|███████▏  | 270/375 [31:55<12:17,  7.02s/it]

{'loss': 0.9778, 'grad_norm': 6.184206485748291, 'learning_rate': 2.7000000000000002e-05, 'epoch': 2.16}


 75%|███████▍  | 280/375 [33:04<11:00,  6.96s/it]

{'loss': 1.1431, 'grad_norm': 11.963001251220703, 'learning_rate': 2.8000000000000003e-05, 'epoch': 2.24}


 77%|███████▋  | 290/375 [34:14<09:49,  6.94s/it]

{'loss': 1.1364, 'grad_norm': 15.883962631225586, 'learning_rate': 2.9e-05, 'epoch': 2.32}


 80%|████████  | 300/375 [35:25<09:19,  7.45s/it]

{'loss': 0.8667, 'grad_norm': 13.561271667480469, 'learning_rate': 3e-05, 'epoch': 2.4}


 83%|████████▎ | 310/375 [37:02<08:34,  7.91s/it]

{'loss': 0.9513, 'grad_norm': 12.236526489257812, 'learning_rate': 3.1e-05, 'epoch': 2.48}


 85%|████████▌ | 320/375 [38:12<06:26,  7.02s/it]

{'loss': 0.8017, 'grad_norm': 1.843321442604065, 'learning_rate': 3.2000000000000005e-05, 'epoch': 2.56}


 88%|████████▊ | 330/375 [39:21<05:10,  6.90s/it]

{'loss': 0.9124, 'grad_norm': 3.238100528717041, 'learning_rate': 3.3e-05, 'epoch': 2.64}


 91%|█████████ | 340/375 [40:30<04:01,  6.89s/it]

{'loss': 0.8481, 'grad_norm': 50.50942611694336, 'learning_rate': 3.4000000000000007e-05, 'epoch': 2.72}


 93%|█████████▎| 350/375 [41:40<02:55,  7.03s/it]

{'loss': 0.901, 'grad_norm': 5.574615478515625, 'learning_rate': 3.5e-05, 'epoch': 2.8}


 96%|█████████▌| 360/375 [42:49<01:43,  6.88s/it]

{'loss': 1.1017, 'grad_norm': 16.127687454223633, 'learning_rate': 3.6e-05, 'epoch': 2.88}


 99%|█████████▊| 370/375 [43:59<00:34,  6.91s/it]

{'loss': 0.9666, 'grad_norm': 17.190067291259766, 'learning_rate': 3.7e-05, 'epoch': 2.96}


100%|██████████| 375/375 [44:33<00:00,  7.13s/it]


{'train_runtime': 2673.9341, 'train_samples_per_second': 1.122, 'train_steps_per_second': 0.14, 'train_loss': 1.6900064697265624, 'epoch': 3.0}
