In [1]:
import pandas as pd
import numpy as np
from sklearn.metrics import accuracy_score
import torch
from torch.utils.data import Dataset, DataLoader
from transformers import BertTokenizer, BertConfig, BertForTokenClassification


In [2]:
from torch import cuda
# Run on the GPU when PyTorch reports one, otherwise fall back to the CPU.
if cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'
print(device)

cuda


In [3]:
# Token-level corpus: one row per word, with its sentence id and its O/B/I label.
# NOTE(review): absolute, user-specific path; the notebook only runs on this machine as written.
TOKEN_CSV_PATH = '/home/chudeo/coding-evidence-extraction-main/33k_sentence.csv'
token_df = pd.read_csv(TOKEN_CSV_PATH)

In [4]:
# Non-missing entries per column; a column with fewer entries than the others has gaps.
token_df.notna().sum()

sentence_id    625510
words          625262
labels         625510
dtype: int64

In [5]:
# Count the missing entries in each column.
token_df.isna().sum()

sentence_id      0
words          248
labels           0
dtype: int64

In [6]:
# Forward-fill missing cells: each missing value is replaced by the value from the
# closest preceding row that has one. `DataFrame.ffill()` is the supported spelling;
# `fillna(method='ffill')` is deprecated and emits a FutureWarning.
# NOTE(review): only `words` has gaps, so a missing token becomes a copy of the previous
# token. The gaps may be literal tokens such as "NA"/"null" that `read_csv` parsed as
# missing. TODO confirm, and consider `keep_default_na=False` when loading instead.
data = token_df.ffill()
data.head()

  data = token_df.fillna(method='ffill')


Unnamed: 0,sentence_id,words,labels
0,0,Baseline,O
1,0,artifact,O
2,0,.,O
3,1,Probable,O
4,1,sinus,B


In [7]:
# Group the token rows by sentence once and reuse the grouping for both new columns.
tokens_by_sentence = data.groupby('sentence_id')

# Full sentence text (words joined by single spaces), repeated on every row of the sentence.
sentence_text = tokens_by_sentence['words'].transform(' '.join)
# Comma-separated tag sequence of the sentence, repeated on every row of the sentence.
sentence_tags = tokens_by_sentence['labels'].transform(','.join)

data['sentence'] = sentence_text
data['word_labels'] = sentence_tags
data.head()

Unnamed: 0,sentence_id,words,labels,sentence,word_labels
0,0,Baseline,O,Baseline artifact .,"O,O,O"
1,0,artifact,O,Baseline artifact .,"O,O,O"
2,0,.,O,Baseline artifact .,"O,O,O"
3,1,Probable,O,Probable sinus tachycardia with atrial prematu...,"O,B,I,O,O,O,O,O"
4,1,sinus,B,Probable sinus tachycardia with atrial prematu...,"O,B,I,O,O,O,O,O"


In [8]:

# Tags in order of first appearance; a tag's position in this list is its integer id.
unique_labels = list(data.labels.unique())
label2id = {label: index for index, label in enumerate(unique_labels)}
id2label = dict(enumerate(unique_labels))
label2id

{'O': 0, 'B': 1, 'I': 2}

In [9]:
# Go from one row per token to one row per distinct (sentence, tag sequence) pair.
# Note that `data` is rebound here: the token-level columns are gone afterwards.
sentence_level = data[["sentence", "word_labels"]]
sentence_level = sentence_level.drop_duplicates()
data = sentence_level.reset_index(drop=True)
data.head()

Unnamed: 0,sentence,word_labels
0,Baseline artifact .,"O,O,O"
1,Probable sinus tachycardia with atrial prematu...,"O,B,I,O,O,O,O,O"
2,Low limb lead voltage .,"O,O,O,O,O"
3,Leftward axis .,"O,O,O"
4,Late R wave progression .,"O,O,O,O,O"


In [10]:
# Number of distinct sentences left after de-duplication.
data.shape[0]


11262

In [11]:
# Spot-check: text of the second sentence.
data.iloc[1]["sentence"]

'Probable sinus tachycardia with atrial premature beats .'

In [12]:
# Spot-check: tag sequence of the same sentence (one tag per word, comma-separated).
data.iloc[1]["word_labels"]

'O,B,I,O,O,O,O,O'

##### **Sampled sentences**

In [13]:
from sklearn.utils import resample


# True for sentences whose every tag is 'O'. Computed once and reused for both halves
# of the split instead of re-scanning the whole frame a second time.
is_all_outside = data['word_labels'].apply(
    lambda joined_tags: all(tag == 'O' for tag in joined_tags.split(','))
)

# Separate majority (all-'O' sentences) and minority (at least one non-'O' tag) classes
majority_class = data[is_all_outside]
minority_class = data[~is_all_outside]

# Downsample majority class
# `n_samples` is capped at the size of the all-'O' group: sampling without replacement
# raises if asked for more rows than exist, which would happen on a dataset where the
# all-'O' sentences are actually the smaller group.
majority_downsampled = resample(
    majority_class,
    replace=False,    # sample without replacement
    n_samples=min(len(majority_class), len(minority_class)),  # match minority class size
    random_state=42,  # reproducible results
)

# Combine minority class with downsampled majority class
# The result is NOT shuffled (sampled all-'O' rows first, then the rest) and keeps the
# original row index.
balanced_data = pd.concat([majority_downsampled, minority_class])



In [14]:
# Size of the balanced subset (sampled all-'O' sentences plus every other sentence).
balanced_data.shape[0]

4378

In [15]:
# From here on `data` refers to the balanced subset, not the full sentence-level table.
data = balanced_data

In [None]:
# Extract words and labels from DataFrame
# Both entries are lists with one inner list per sentence, so split_data["words"][i]
# and split_data["labels"][i] describe the same sentence. (Previously "words" was
# flattened into a single list of tokens while "labels" stayed nested, so the two
# could not be zipped together.)
# NOTE(review): a sentence's word and label counts only match if no original token was
# empty or contained whitespace. TODO confirm before relying on token-level alignment.
split_data = {
    "words": data["sentence"].str.split().tolist(),
    "labels": [joined_tags.split(',') for joined_tags in data["word_labels"]],
}

In [None]:
# Bind the "words" and "labels" entries of split_data to shorter names.
sentences, labels = split_data["words"], split_data["labels"]

In [17]:
import os
from typing import List, Tuple

import torch
from torch import nn
from transformers import AutoModel


In [18]:

# Large finite stand-in for "infinity" in log space (10e5 == 1e6).
# NOTE(review): not referenced by any cell shown here; confirm whether it is still needed.
LOG_INF = 1e6


class BertCrf(nn.Module):
    """Pretrained encoder with a linear tagging head and an optional linear-chain CRF.

    With ``use_crf=True`` the forward pass returns the batch-mean CRF negative
    log-likelihood and ``decode`` runs Viterbi. With ``use_crf=False`` the forward
    pass returns token-level cross-entropy and ``decode`` takes the per-token argmax.

    The CRF helpers work on time-major tensors: ``features`` is
    ``(seq_len, batch, num_labels)`` and ``mask`` / ``labels`` are ``(seq_len, batch)``.
    Position 0 of every sequence is always treated as valid (``mask[0]`` is never read),
    and valid positions are assumed to form a prefix of each sequence.
    """

    def __init__(
        self,
        num_labels: int,
        bert_name: str,
        dropout: float = 0.2,
        use_crf: bool = True,
    ):
        """Build the encoder, the tagging head and the CRF parameters.

        Args:
            num_labels: number of distinct tag ids.
            bert_name: model name or path handed to ``AutoModel.from_pretrained``.
            dropout: dropout probability applied to the encoder output.
            use_crf: train/decode with the CRF layer (True) or with plain
                cross-entropy / argmax (False).
        """
        super().__init__()
        self.num_labels = num_labels
        self.use_crf = use_crf
        self.cross_entropy = nn.CrossEntropyLoss()

        self.bert = AutoModel.from_pretrained(bert_name)

        self.dropout = nn.Dropout(dropout)
        self.hidden2label = nn.Linear(self.bert.config.hidden_size, num_labels)

        # CRF scores: start_transitions[j] for starting in tag j, end_transitions[j] for
        # ending in tag j, transitions[i, j] for moving from tag i to tag j.
        self.start_transitions = nn.Parameter(torch.empty(num_labels))
        self.end_transitions = nn.Parameter(torch.empty(num_labels))
        self.transitions = nn.Parameter(torch.empty(num_labels, num_labels))

        nn.init.uniform_(self.start_transitions, -0.1, 0.1)
        nn.init.uniform_(self.end_transitions, -0.1, 0.1)
        nn.init.uniform_(self.transitions, -0.1, 0.1)

    def _compute_log_denominator(self, features: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
        """Log of the summed exponentiated scores of all tag paths (forward algorithm).

        Args:
            features: emission scores, ``(seq_len, batch, num_labels)``.
            mask: boolean ``(seq_len, batch)``; where False the running score is carried
                over unchanged.

        Returns:
            Tensor of shape ``(batch,)``.
        """
        seq_len = features.shape[0]

        # (batch, num_labels): log-score of all partial paths ending in each tag.
        log_score_over_all_seq = self.start_transitions + features[0]

        for i in range(1, seq_len):
            # Broadcast to (batch, prev_tag, next_tag) and sum out the previous tag.
            next_log_score_over_all_seq = torch.logsumexp(
                log_score_over_all_seq.unsqueeze(2) + self.transitions + features[i].unsqueeze(1),
                dim=1,
            )
            log_score_over_all_seq = torch.where(
                mask[i].unsqueeze(1),
                next_log_score_over_all_seq,
                log_score_over_all_seq,
            )
        # Out-of-place add keeps every tensor recorded by autograd untouched.
        log_score_over_all_seq = log_score_over_all_seq + self.end_transitions
        return torch.logsumexp(log_score_over_all_seq, dim=1)

    def _compute_log_numerator(self, features: torch.Tensor, labels: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
        """Score of the given tag path for every sequence in the batch.

        Args:
            features: emission scores, ``(seq_len, batch, num_labels)``.
            labels: integer tag ids, ``(seq_len, batch)``. Values at masked positions
                are still used as indices (their contribution is multiplied by zero),
                so they must be valid tag ids.
            mask: boolean ``(seq_len, batch)``.

        Returns:
            Tensor of shape ``(batch,)``.
        """
        seq_len, bs, _ = features.shape
        # Built once, on the same device as the tensors it indexes.
        batch_index = torch.arange(bs, device=features.device)

        score_over_seq = self.start_transitions[labels[0]] + features[0, batch_index, labels[0]]

        for i in range(1, seq_len):
            step_score = self.transitions[labels[i - 1], labels[i]] + features[i, batch_index, labels[i]]
            # Masked positions contribute nothing.
            score_over_seq = score_over_seq + step_score * mask[i]
        # Index of the last valid position of every sequence.
        seq_lens = mask.sum(dim=0) - 1
        last_tags = labels[seq_lens.long(), batch_index]
        score_over_seq = score_over_seq + self.end_transitions[last_tags]
        return score_over_seq

    def get_bert_features(
        self, input_ids: torch.Tensor, attention_mask: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Run the encoder and the tagging head.

        Returns:
            ``(emissions, hidden)``: the per-token label scores produced by
            ``hidden2label`` and the dropout-applied ``last_hidden_state`` they were
            computed from.
        """
        hidden = self.bert(input_ids, attention_mask=attention_mask)["last_hidden_state"]
        hidden = self.dropout(hidden)
        return self.hidden2label(hidden), hidden

    def forward(
        self,
        input_ids: torch.Tensor,
        attention_mask: torch.Tensor,
        labels: torch.Tensor,
    ) -> torch.Tensor:
        """Training loss for a padded batch.

        Args:
            input_ids: token ids, batch-first.
            attention_mask: same shape as ``input_ids``; non-zero marks real tokens.
            labels: integer tag ids, same shape as ``input_ids``.

        Returns:
            Scalar loss: batch-mean CRF negative log-likelihood when ``use_crf`` is set,
            otherwise cross-entropy over the unmasked tokens.
        """
        features, _ = self.get_bert_features(input_ids=input_ids, attention_mask=attention_mask)
        attention_mask = attention_mask.bool()

        if self.use_crf:
            # The CRF helpers expect time-major tensors.
            features = torch.swapaxes(features, 0, 1)
            attention_mask = torch.swapaxes(attention_mask, 0, 1)
            labels = torch.swapaxes(labels, 0, 1)

            log_numerator = self._compute_log_numerator(features=features, labels=labels, mask=attention_mask)
            log_denominator = self._compute_log_denominator(features=features, mask=attention_mask)

            return torch.mean(log_denominator - log_numerator)
        else:
            # -100 is the default ignore_index of nn.CrossEntropyLoss, so padded
            # positions do not contribute to the loss.
            return self.cross_entropy(
                features.flatten(end_dim=1),
                torch.where(attention_mask.bool(), labels, -100).flatten(end_dim=1),
            )

    def _viterbi_decode(self, features: torch.Tensor, mask: torch.Tensor) -> List[List[int]]:
        """Highest-scoring tag path per sequence.

        Args:
            features: emission scores, ``(seq_len, batch, num_labels)``.
            mask: boolean ``(seq_len, batch)``.

        Returns:
            One list of tag ids per sequence; its length is the number of True entries
            in that sequence's mask column.
        """
        seq_len, bs, _ = features.shape

        log_score_over_all_seq = self.start_transitions + features[0]

        # backpointers[i, b, j]: best previous tag when sequence b is in tag j at step i.
        # Stored as integers directly rather than round-tripping through the feature
        # dtype; slot 0 is never written and is dropped below.
        backpointers = torch.empty_like(features, dtype=torch.long)

        for i in range(1, seq_len):
            next_log_score_over_all_seq = (
                log_score_over_all_seq.unsqueeze(2) + self.transitions + features[i].unsqueeze(1)
            )

            next_log_score_over_all_seq, indices = next_log_score_over_all_seq.max(dim=1)

            log_score_over_all_seq = torch.where(
                mask[i].unsqueeze(1),
                next_log_score_over_all_seq,
                log_score_over_all_seq,
            )
            # Also written at masked steps; those entries are never followed because
            # backtracking only visits the first `seq_lens` steps.
            backpointers[i] = indices

        backpointers = backpointers[1:]

        log_score_over_all_seq = log_score_over_all_seq + self.end_transitions
        # Number of transitions to walk back for every sequence (valid length - 1).
        seq_lens = (mask.sum(dim=0) - 1).tolist()

        best_paths = []
        for seq_ind in range(bs):
            best_label_id = torch.argmax(log_score_over_all_seq[seq_ind]).item()
            best_path = [best_label_id]

            for backpointer in reversed(backpointers[: seq_lens[seq_ind]]):
                best_path.append(backpointer[seq_ind][best_path[-1]].item())

            best_path.reverse()
            best_paths.append(best_path)

        return best_paths

    def decode(self, input_ids: torch.Tensor, attention_mask: torch.Tensor) -> List[List[int]]:
        """Predict tag ids for a padded batch.

        Does not switch the module to eval mode or disable gradients; callers are
        expected to do that.

        Returns:
            One list of tag ids per sequence, covering only its unmasked positions.
        """
        features, _ = self.get_bert_features(input_ids=input_ids, attention_mask=attention_mask)
        attention_mask = attention_mask.bool()

        if self.use_crf:
            features = torch.swapaxes(features, 0, 1)
            mask = torch.swapaxes(attention_mask, 0, 1)
            return self._viterbi_decode(features=features, mask=mask)
        else:
            labels = torch.argmax(features, dim=2)
            predictions = []
            for i in range(len(labels)):
                predictions.append(labels[i][attention_mask[i]].tolist())
            return predictions

    def save_to(self, path: str):
        """Write the module's ``state_dict`` to ``path``, creating parent directories.

        A bare file name such as ``"model.pt"`` has an empty directory part;
        ``os.makedirs('')`` would raise, so directory creation is skipped in that case.
        """
        directory = os.path.dirname(path)
        if directory:
            os.makedirs(directory, exist_ok=True)
        torch.save(self.state_dict(), path)

    def load_from(self, path: str):
        """Load a ``state_dict`` written by ``save_to`` into this module.

        The checkpoint is deserialized onto the CPU first so that one saved from a GPU
        run can be restored on a machine without CUDA; ``load_state_dict`` then copies
        the values into this module's parameters wherever they currently live.
        """
        self.load_state_dict(torch.load(path, map_location="cpu"))


In [26]:
import json
from typing import Dict, List, Optional

import torch
from IPython import display
from matplotlib import pyplot as plt
from torch.nn.utils.rnn import pad_sequence
from torch.optim import Adam
from sklearn.metrics import f1_score
from tqdm import tqdm


In [21]:
def dict_to_device(
    dict: Dict[str, torch.Tensor],
    device: str = "cuda" if torch.cuda.is_available() else "cpu",
):
    """Move every tensor in ``dict`` to ``device`` and return the same dictionary.

    The mapping is updated in place, entry by entry, so the caller's object is
    modified as well. The default ``device`` is chosen once, when the function is
    defined.
    """
    for name in dict:
        dict[name] = dict[name].to(device)
    return dict


def draw_plots(loss_history: List[float], f1: List[float]):
    """Redraw the live training dashboard in the current cell output.

    Clears the previous output, plots the loss curve and the F1 curve on two stacked
    axes, and prints the most recent value of each series when it is non-empty.

    Args:
        loss_history: one training-loss value per optimisation step so far.
        f1: one micro-F1 value per evaluation run so far.
    """
    # wait=True defers the clear until new output arrives, which avoids flicker.
    display.clear_output(wait=True)

    f, (ax1, ax2) = plt.subplots(2)
    f.set_figwidth(15)
    f.set_figheight(10)

    ax1.set_title("training loss")
    ax2.set_title("f1 micro")

    ax1.plot(loss_history)
    ax2.plot(f1)

    plt.show()
    # A new figure is created on every call; release it explicitly so repeated calls
    # do not accumulate open figures. Closing an already-closed figure is a no-op.
    plt.close(f)

    if len(loss_history) > 0:
        print(f"Current loss: {loss_history[-1]}")
    if len(f1) > 0:
        print(f"Current f1: {f1[-1]}")

In [23]:
class NerDataset(Dataset):
    """Pre-tokenized NER examples read from a JSON-lines file.

    Every non-blank line of the file is parsed as one JSON object. ``collate_function``
    reads the ``"input_ids"`` and ``"labels"`` keys of each example.
    """

    def __init__(self, tokenized_texts_path: str):
        """Load all examples from ``tokenized_texts_path`` into memory.

        The file is read as UTF-8 regardless of the platform default, and blank lines
        are skipped instead of crashing the JSON parser.
        """
        self.tokenized_texts = []
        with open(tokenized_texts_path, "r", encoding="utf-8") as tokenized_texts_file:
            for line in tokenized_texts_file:
                if not line.strip():
                    continue
                self.tokenized_texts.append(json.loads(line))

    def __len__(self):
        """Number of loaded examples."""
        return len(self.tokenized_texts)

    def __getitem__(self, index) -> Dict[str, List]:
        """Return the parsed JSON object at ``index``."""
        return self.tokenized_texts[index]

    def collate_function(self, batch: List[Dict[str, torch.Tensor]]) -> Dict[str, torch.Tensor]:
        """Pad a list of examples into batch-first tensors.

        Returns:
            Dict with ``"input_ids"``, ``"labels"`` and ``"attention_mask"``, each padded
            with zeros to the longest example in the batch. The attention mask is a
            float tensor: 1.0 on real tokens and 0.0 on padding.
        """
        input_ids = [torch.tensor(item["input_ids"]) for item in batch]
        labels = [torch.tensor(item["labels"]) for item in batch]
        attention_mask = [torch.ones(len(item)) for item in input_ids]

        # NOTE(review): labels are padded with 0, which is also a real label id; padded
        # positions are only told apart through the attention mask.
        return {
            "input_ids": pad_sequence(input_ids, batch_first=True),
            "labels": pad_sequence(labels, batch_first=True),
            "attention_mask": pad_sequence(attention_mask, batch_first=True),
        }


In [27]:
def train_ner(
    num_labels: int,
    bert_name: str,
    train_tokenized_texts_path: str,
    test_tokenized_texts_path: str,
    dropout: float,
    batch_size: int,
    epochs: int,
    log_every: int,
    lr_bert: float,
    lr_new_layers: float,
    use_crf: bool = True,
    save_to: Optional[str] = None,
    device="cuda" if torch.cuda.is_available() else "cpu",
):
    """Train a ``BertCrf`` tagger and periodically score it on the test set.

    Args:
        num_labels: number of distinct tag ids.
        bert_name: pretrained encoder name or path passed to ``BertCrf``.
        train_tokenized_texts_path: JSON-lines file with the training examples.
        test_tokenized_texts_path: JSON-lines file with the evaluation examples.
        dropout: dropout probability passed to ``BertCrf``.
        batch_size: batch size for both data loaders.
        epochs: number of passes over the training set.
        log_every: run an evaluation pass every this many optimisation steps.
        lr_bert: learning rate for the encoder parameters.
        lr_new_layers: learning rate for the tagging head and the CRF parameters.
        use_crf: passed to ``BertCrf``.
        save_to: if given, the final model weights are written to this path.
        device: device used for the model and for every batch, training and evaluation.

    Returns:
        None. The trained model is only kept if ``save_to`` is given.
    """
    model = BertCrf(num_labels, bert_name, dropout=dropout, use_crf=use_crf)
    model = model.to(device)
    model.train()

    train_dataset = NerDataset(train_tokenized_texts_path)
    train_data_loader = DataLoader(
        train_dataset,
        batch_size=batch_size,
        shuffle=True,
        collate_fn=train_dataset.collate_function,
    )

    test_dataset = NerDataset(test_tokenized_texts_path)
    test_data_loader = DataLoader(
        test_dataset,
        batch_size=batch_size,
        shuffle=False,
        collate_fn=test_dataset.collate_function,
    )

    # Two learning rates: the encoder group overrides the default with `lr_bert`;
    # every other group falls back to `lr_new_layers`.
    optimizer = Adam(
        [
            {"params": model.start_transitions},
            {"params": model.end_transitions},
            {"params": model.hidden2label.parameters()},
            {"params": model.transitions},
            {"params": model.bert.parameters(), "lr": lr_bert},
        ],
        lr=lr_new_layers,
    )

    loss_history = []
    f1 = []

    step = 0
    for epoch in range(1, epochs + 1):
        for batch in tqdm(train_data_loader):
            step += 1

            optimizer.zero_grad()

            batch = dict_to_device(batch, device)

            loss = model(**batch)

            loss.backward()
            optimizer.step()

            loss_history.append(loss.item())

            if step % log_every == 0:
                model.eval()
                predictions = []
                ground_truth = []
                with torch.no_grad():
                    # A separate loop variable keeps the training batch from being
                    # shadowed by the evaluation batches.
                    for eval_batch in test_data_loader:
                        labels = eval_batch.pop("labels")
                        # Select the gold labels while labels and mask are both still on
                        # the CPU; after the move below the mask lives on `device` and
                        # could no longer be combined with the CPU labels.
                        flatten_labels = torch.masked_select(
                            labels, eval_batch["attention_mask"].bool()
                        ).tolist()

                        # Use the device this function was asked to use, not the helper's
                        # own default.
                        eval_batch = dict_to_device(eval_batch, device)

                        prediction = model.decode(**eval_batch)

                        flatten_prediction = [item for sublist in prediction for item in sublist]

                        predictions.extend(flatten_prediction)
                        ground_truth.extend(flatten_labels)
                f1_micro = f1_score(ground_truth, predictions, average="micro")
                f1.append(f1_micro)
                model.train()

            # Redrawn after every optimisation step, not only after evaluations.
            draw_plots(loss_history, f1)
            print(f"Epoch {epoch}/{epochs}")
    if save_to is not None:
        model.save_to(save_to)