In [1]:
from datasets import load_dataset
from typing import List, Dict
import torch
dataset = load_dataset("parquet", data_files={
    "train": "/kaggle/input/dataset/train-00000-of-00001-baac38b53532b0da.parquet",
    "test": "/kaggle/input/dataset/test-00000-of-00001-1019821dbb200a34.parquet"
})

Generating train split: 0 examples [00:00, ? examples/s]

Generating test split: 0 examples [00:00, ? examples/s]

In [2]:
!pip install git+https://github.com/kmkurn/pytorch-crf.git --quiet

  Preparing metadata (setup.py) ... [?25l[?25hdone
  Building wheel for pytorch-crf (setup.py) ... [?25l[?25hdone


In [3]:
print(dataset["train"][0])

{'tokens': ['Selegiline', '-', 'induced', 'postural', 'hypotension', 'in', 'Parkinson', "'", 's', 'disease', ':', 'a', 'longitudinal', 'study', 'on', 'the', 'effects', 'of', 'drug', 'withdrawal', '.'], 'tags': [0, 0, 0, 1, 2, 0, 1, 2, 2, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], 'sentence_id': 'BC5CDR-0'}


In [4]:
from torch.utils.data import DataLoader, Dataset
from collections import Counter
def build_vocab(dataset: List[Dict], min_freq: int = 1):
    token_counter = Counter()
    for item in dataset:
        token_counter.update(item["tokens"])
    vocab = {"<PAD>": 0, "<UNK>": 1}
    for token, count in token_counter.items():
        if count >= min_freq:
            vocab[token] = len(vocab)
    return vocab
tag2idx = {"O": 0, "B-Disease": 1, "I-Disease": 2}
idx2tag = {v: k for k, v in tag2idx.items()}

class NERDataset(Dataset):
    def __init__(self, data: List[Dict], vocab: Dict[str, int], tag2idx: Dict[str, int], max_len: int = 128):
        self.data = data
        self.vocab = vocab
        self.tag2idx = tag2idx
        self.max_len = max_len

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        tokens = self.data[idx]["tokens"][:self.max_len]
        tags = self.data[idx]["tags"][:self.max_len]

        input_ids = [self.vocab.get(token, self.vocab["<UNK>"]) for token in tokens]
        tag_ids = tags
        pad_len = self.max_len - len(input_ids)

        return {
            "input_ids": torch.tensor(input_ids + [self.vocab["<PAD>"]]*pad_len, dtype=torch.long),
            "tags": torch.tensor(tag_ids + [0]*pad_len, dtype=torch.long),
            "attention_mask": torch.tensor([1]*len(input_ids) + [0]*pad_len, dtype=torch.long)
        }

In [5]:
vocab = build_vocab(dataset["train"], min_freq=1)

train_dataset = NERDataset(dataset["train"], vocab, tag2idx)
test_dataset = NERDataset(dataset["test"], vocab, tag2idx)
train_dataloader = DataLoader(train_dataset, batch_size=32, shuffle=True)
test_dataloader = DataLoader(test_dataset, batch_size=32, shuffle=False)


In [6]:
import torch
import torch.nn as nn
from torchcrf import CRF

class BiLSTM_CRF(nn.Module):
    def __init__(self, vocab_size, tagset_size,hidden_dim, embedding_dim=100, pad_idx=0):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, embedding_dim, padding_idx=pad_idx)
        self.lstm = nn.LSTM(embedding_dim, hidden_dim // 2, num_layers=1, bidirectional=True, batch_first=True)
        self.hidden2tag = nn.Linear(hidden_dim, tagset_size)
        self.crf = CRF(tagset_size, batch_first=True)
    
    def forward(self, input_ids, tags=None, mask=None):
        embedded = self.embedding(input_ids)
        lstm_out, _ = self.lstm(embedded)
        emissions = self.hidden2tag(lstm_out)
        
        if tags is not None:
            loss = -self.crf(emissions, tags, mask=mask, reduction='mean')
            return loss
        else:
            predictions = self.crf.decode(emissions, mask=mask)
            return predictions

In [7]:
from sklearn.metrics import classification_report
from tqdm.notebook import tqdm
def train_step(model, dataloader, optimizer, device):
    model.train()
    total_loss = 0

    for batch in tqdm(dataloader, desc="Training"):
        batch = {k: v.to(device) for k, v in batch.items()}
        loss = model(
            input_ids=batch["input_ids"],
            tags=batch["tags"],
            mask=batch["attention_mask"].bool()
        )

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        total_loss += loss.item()

    return total_loss / len(dataloader)

def eval_step(model, dataloader, device):
    model.eval()
    total_loss = 0
    all_preds, all_labels = [], []

    with torch.no_grad():
        for batch in dataloader:
            batch = {k: v.to(device) for k, v in batch.items()}
            mask = batch["attention_mask"].bool()
            loss = model(
                input_ids=batch["input_ids"],
                tags=batch["tags"],
                mask=mask
            )
            total_loss += loss.item()
            predictions = model(
                input_ids=batch["input_ids"],
                tags=None,
                mask=mask
            )
            for pred_seq, true_seq, mask_seq in zip(predictions, batch["tags"], mask):
                true_seq = true_seq[mask_seq].cpu().tolist()
                all_preds.extend(pred_seq)
                all_labels.extend(true_seq)

    report = classification_report(all_labels, all_preds, output_dict=True, zero_division=0)
    return total_loss / len(dataloader), report


In [8]:
from torch import nn
import torch.optim as optim
from transformers import get_scheduler
from tqdm.notebook import tqdm

def train_model(model, train_loader, test_loader, device, epochs=3, lr=2e-5):
    model.to(device)
    optimizer = optim.AdamW(model.parameters(), lr=lr)

    scheduler = get_scheduler(
        name="linear",
        optimizer=optimizer,
        num_warmup_steps=0,
        num_training_steps=epochs * len(train_loader)
    )

    for epoch in tqdm(range(epochs), desc="Epochs"):
        print(f"\nEpoch {epoch+1}/{epochs}")

        train_loss = train_step(model, train_loader, optimizer, device)
        test_loss, test_report = eval_step(model, test_loader, device)

        scheduler.step()

        print(f"Train Loss: {train_loss:.4f} | Test Loss: {test_loss:.4f}")
        print(f"F1-score: {test_report['macro avg']['f1-score']:.4f}")
        print(f"Recall:    {test_report['macro avg']['recall']:.4f}")
        print(f"Precision: {test_report['macro avg']['precision']:.4f}")

In [9]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = BiLSTM_CRF(
    vocab_size=len(vocab),
    tagset_size=len(tag2idx),
    hidden_dim=300,
    embedding_dim=100,
    pad_idx=vocab["<PAD>"],
)
train_model(model, train_dataloader, test_dataloader, device, epochs=30, lr=1e-3)

Epochs:   0%|          | 0/30 [00:00<?, ?it/s]


Epoch 1/30


Training:   0%|          | 0/484 [00:00<?, ?it/s]

Train Loss: 4.9025 | Test Loss: 2.9311
F1-score: 0.7005
Recall:    0.6190
Precision: 0.8575

Epoch 2/30


Training:   0%|          | 0/484 [00:00<?, ?it/s]

Train Loss: 2.0666 | Test Loss: 2.1014
F1-score: 0.8020
Recall:    0.7725
Precision: 0.8371

Epoch 3/30


Training:   0%|          | 0/484 [00:00<?, ?it/s]

Train Loss: 1.1820 | Test Loss: 1.8276
F1-score: 0.8258
Recall:    0.8136
Precision: 0.8389

Epoch 4/30


Training:   0%|          | 0/484 [00:00<?, ?it/s]

Train Loss: 0.6933 | Test Loss: 1.9813
F1-score: 0.8172
Recall:    0.8476
Precision: 0.7906

Epoch 5/30


Training:   0%|          | 0/484 [00:00<?, ?it/s]

Train Loss: 0.3742 | Test Loss: 1.8392
F1-score: 0.8305
Recall:    0.8192
Precision: 0.8424

Epoch 6/30


Training:   0%|          | 0/484 [00:00<?, ?it/s]

Train Loss: 0.1907 | Test Loss: 1.9641
F1-score: 0.8335
Recall:    0.8208
Precision: 0.8470

Epoch 7/30


Training:   0%|          | 0/484 [00:00<?, ?it/s]

Train Loss: 0.0937 | Test Loss: 2.3250
F1-score: 0.8271
Recall:    0.8286
Precision: 0.8258

Epoch 8/30


Training:   0%|          | 0/484 [00:00<?, ?it/s]

Train Loss: 0.0456 | Test Loss: 2.6385
F1-score: 0.8178
Recall:    0.8535
Precision: 0.7871

Epoch 9/30


Training:   0%|          | 0/484 [00:00<?, ?it/s]

Train Loss: 0.0241 | Test Loss: 2.6149
F1-score: 0.8313
Recall:    0.8188
Precision: 0.8446

Epoch 10/30


Training:   0%|          | 0/484 [00:00<?, ?it/s]

Train Loss: 0.0143 | Test Loss: 2.8434
F1-score: 0.8266
Recall:    0.8348
Precision: 0.8189

Epoch 11/30


Training:   0%|          | 0/484 [00:00<?, ?it/s]

Train Loss: 0.0686 | Test Loss: 2.7168
F1-score: 0.8139
Recall:    0.8425
Precision: 0.7887

Epoch 12/30


Training:   0%|          | 0/484 [00:00<?, ?it/s]

Train Loss: 0.0950 | Test Loss: 2.7201
F1-score: 0.8202
Recall:    0.8303
Precision: 0.8108

Epoch 13/30


Training:   0%|          | 0/484 [00:00<?, ?it/s]

Train Loss: 0.0270 | Test Loss: 2.6784
F1-score: 0.8319
Recall:    0.8234
Precision: 0.8408

Epoch 14/30


Training:   0%|          | 0/484 [00:00<?, ?it/s]

Train Loss: 0.0237 | Test Loss: 2.6615
F1-score: 0.8288
Recall:    0.8295
Precision: 0.8283

Epoch 15/30


Training:   0%|          | 0/484 [00:00<?, ?it/s]

Train Loss: 0.0151 | Test Loss: 2.7223
F1-score: 0.8373
Recall:    0.8268
Precision: 0.8483

Epoch 16/30


Training:   0%|          | 0/484 [00:00<?, ?it/s]

Train Loss: 0.0081 | Test Loss: 2.8422
F1-score: 0.8326
Recall:    0.8380
Precision: 0.8273

Epoch 17/30


Training:   0%|          | 0/484 [00:00<?, ?it/s]

Train Loss: 0.0036 | Test Loss: 2.9291
F1-score: 0.8370
Recall:    0.8284
Precision: 0.8459

Epoch 18/30


Training:   0%|          | 0/484 [00:00<?, ?it/s]

Train Loss: 0.0021 | Test Loss: 2.9913
F1-score: 0.8355
Recall:    0.8311
Precision: 0.8400

Epoch 19/30


Training:   0%|          | 0/484 [00:00<?, ?it/s]

Train Loss: 0.0015 | Test Loss: 3.0848
F1-score: 0.8345
Recall:    0.8349
Precision: 0.8341

Epoch 20/30


Training:   0%|          | 0/484 [00:00<?, ?it/s]

Train Loss: 0.0016 | Test Loss: 3.1564
F1-score: 0.8358
Recall:    0.8277
Precision: 0.8442

Epoch 21/30


Training:   0%|          | 0/484 [00:00<?, ?it/s]

Train Loss: 0.0016 | Test Loss: 3.1867
F1-score: 0.8349
Recall:    0.8301
Precision: 0.8398

Epoch 22/30


Training:   0%|          | 0/484 [00:00<?, ?it/s]

Train Loss: 0.0008 | Test Loss: 3.2756
F1-score: 0.8354
Recall:    0.8315
Precision: 0.8394

Epoch 23/30


Training:   0%|          | 0/484 [00:00<?, ?it/s]

Train Loss: 0.0011 | Test Loss: 3.2692
F1-score: 0.8343
Recall:    0.8319
Precision: 0.8368

Epoch 24/30


Training:   0%|          | 0/484 [00:00<?, ?it/s]

Train Loss: 0.0013 | Test Loss: 3.3226
F1-score: 0.8358
Recall:    0.8248
Precision: 0.8474

Epoch 25/30


Training:   0%|          | 0/484 [00:00<?, ?it/s]

Train Loss: 0.0009 | Test Loss: 3.3340
F1-score: 0.8343
Recall:    0.8298
Precision: 0.8388

Epoch 26/30


Training:   0%|          | 0/484 [00:00<?, ?it/s]

Train Loss: 0.0008 | Test Loss: 3.3469
F1-score: 0.8335
Recall:    0.8324
Precision: 0.8347

Epoch 27/30


Training:   0%|          | 0/484 [00:00<?, ?it/s]

Train Loss: 0.0007 | Test Loss: 3.4209
F1-score: 0.8340
Recall:    0.8337
Precision: 0.8344

Epoch 28/30


Training:   0%|          | 0/484 [00:00<?, ?it/s]

Train Loss: 0.0004 | Test Loss: 3.5181
F1-score: 0.8319
Recall:    0.8306
Precision: 0.8331

Epoch 29/30


Training:   0%|          | 0/484 [00:00<?, ?it/s]

Train Loss: 0.0006 | Test Loss: 3.5757
F1-score: 0.8297
Recall:    0.8349
Precision: 0.8246

Epoch 30/30


Training:   0%|          | 0/484 [00:00<?, ?it/s]

Train Loss: 0.1506 | Test Loss: 2.4532
F1-score: 0.8284
Recall:    0.8326
Precision: 0.8250
