In [1]:
import torch
import lightning as L

from torch.utils.data import Dataset, DataLoader

from nltk.corpus.reader.conll import ConllCorpusReader

In [2]:
with open("data/train.txt", "r") as f:
    file = f.readlines()

file[:10]

['-DOCSTART- -X- -X- O\n',
 '\n',
 'EU NNP B-NP B-ORG\n',
 'rejects VBZ B-VP O\n',
 'German JJ B-NP B-MISC\n',
 'call NN I-NP O\n',
 'to TO B-VP O\n',
 'boycott VB I-VP O\n',
 'British JJ B-NP B-MISC\n',
 'lamb NN I-NP O\n']

In [3]:
root = "data/"
file_id = "train.txt"

reader = ConllCorpusReader(root, file_id, columntypes=("words", "pos", "chunk", "ne"))

In [4]:
iob = reader.iob_words()

words = [word[0] for word in iob]

In [5]:
# get words
reader.words()

['EU', 'rejects', 'German', 'call', 'to', 'boycott', ...]

In [6]:
# get sentences
reader.sents()

[[], ['EU', 'rejects', 'German', 'call', 'to', 'boycott', 'British', 'lamb', '.'], ...]

In [7]:
# get tagged words
reader.tagged_words()

[('EU', 'NNP'), ('rejects', 'VBZ'), ('German', 'JJ'), ...]

In [8]:
# get tagged by sentences
reader.tagged_sents()

[[], [('EU', 'NNP'), ('rejects', 'VBZ'), ('German', 'JJ'), ('call', 'NN'), ('to', 'TO'), ('boycott', 'VB'), ('British', 'JJ'), ('lamb', 'NN'), ('.', '.')], ...]

In [9]:
# get chunk words
reader.chunked_words()

[Tree('NP', [('EU', 'NNP')]), Tree('VP', [('rejects', 'VBZ')]), ...]

In [10]:
# get chunk sentences
reader.chunked_sents()

[Tree('S', []), Tree('S', [Tree('NP', [('EU', 'NNP')]), Tree('VP', [('rejects', 'VBZ')]), Tree('NP', [('German', 'JJ'), ('call', 'NN')]), Tree('VP', [('to', 'TO'), ('boycott', 'VB')]), Tree('NP', [('British', 'JJ'), ('lamb', 'NN')]), ('.', '.')]), ...]

In [11]:
grid = reader._grids()
grid

[[], [['EU', 'NNP', 'B-NP', 'B-ORG'], ['rejects', 'VBZ', 'B-VP', 'O'], ['German', 'JJ', 'B-NP', 'B-MISC'], ['call', 'NN', 'I-NP', 'O'], ['to', 'TO', 'B-VP', 'O'], ['boycott', 'VB', 'I-VP', 'O'], ['British', 'JJ', 'B-NP', 'B-MISC'], ['lamb', 'NN', 'I-NP', 'O'], ['.', '.', 'O', 'O']], ...]

In [12]:
sentences = list(grid)[1:]
sentences[0]

[['EU', 'NNP', 'B-NP', 'B-ORG'],
 ['rejects', 'VBZ', 'B-VP', 'O'],
 ['German', 'JJ', 'B-NP', 'B-MISC'],
 ['call', 'NN', 'I-NP', 'O'],
 ['to', 'TO', 'B-VP', 'O'],
 ['boycott', 'VB', 'I-VP', 'O'],
 ['British', 'JJ', 'B-NP', 'B-MISC'],
 ['lamb', 'NN', 'I-NP', 'O'],
 ['.', '.', 'O', 'O']]

In [13]:
all_words = reader.words()

len(all_words)

203621

In [14]:
words = list(set(all_words))

# add for padding
words.append("<pad>")

n_words = len(words)
len(words)

23624

In [15]:
# number of sentences
len(sentences)

14986

In [16]:
len([tag[3] for sentence in sentences for tag in sentence])

203621

In [17]:
all_tags_in_sentence = [[c[3] for c in sentence] for sentence in sentences]
all_tags = []

print(all_tags_in_sentence)
for sentence in all_tags_in_sentence:
    for c in sentence:
        all_tags.append(c)

print(set(all_tags))
print(len(all_tags))

[['B-ORG', 'O', 'B-MISC', 'O', 'O', 'O', 'B-MISC', 'O', 'O'], ['B-PER', 'I-PER'], ['B-LOC', 'O'], ['O', 'B-ORG', 'I-ORG', 'O', 'O', 'O', 'O', 'O', 'O', 'B-MISC', 'O', 'O', 'O', 'O', 'O', 'B-MISC', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O'], ['B-LOC', 'O', 'O', 'O', 'O', 'B-ORG', 'I-ORG', 'O', 'O', 'O', 'B-PER', 'I-PER', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'B-LOC', 'O', 'O', 'O', 'O', 'O', 'O', 'O'], ['O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'B-ORG', 'O', 'O', 'O', 'B-PER', 'I-PER', 'I-PER', 'I-PER', 'O', 'O', 'O', 'O', 'O'], ['O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'B-ORG', 'I-ORG', 'O'], ['O', 'O', 'O', 'O', 'O', 'O', 'O', 'B-ORG', 'O', 'O', 'B-PER', 'I-PER', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O'], ['B-PER', 'O', 

In [18]:
tags = set(all_tags)

n_tags = len(tags)
print(tags)

{'I-MISC', 'I-PER', 'B-PER', 'B-MISC', 'I-LOC', 'I-ORG', 'B-ORG', 'B-LOC', 'O'}


In [19]:
all_pos_in_sentence = [[c[1] for c in sentence] for sentence in sentences]
all_pos = []

for sentence in all_pos_in_sentence:
    for c in sentence:
        all_pos.append(c)

print((all_pos[:5]))
print(len(all_pos))

['NNP', 'VBZ', 'JJ', 'NN', 'TO']
203621


In [20]:
pos = set(all_pos)
print(pos)

n_pos = len(pos)

{',', 'RP', 'CC', 'RBS', 'LS', '(', 'FW', 'PRP', 'JJ', 'NN|SYM', 'NNP', 'NN', 'DT', 'MD', 'IN', '"', 'NNS', 'CD', 'RB', 'VBZ', 'WRB', 'NNPS', 'VBD', 'RBR', 'VB', 'VBG', ')', 'UH', 'JJR', 'WDT', 'SYM', 'VBN', 'TO', 'JJS', '.', 'WP$', 'EX', 'PRP$', "''", ':', 'WP', 'VBP', 'POS', 'PDT', '$'}


In [21]:
all_chunk_in_sentence = [[c[2] for c in sentence] for sentence in sentences]
all_chunk = []

for sentence in all_chunk_in_sentence:
    for c in sentence:
        all_chunk.append(c)

print((all_chunk[:5]))
print(len(all_chunk))

['B-NP', 'B-VP', 'B-NP', 'I-NP', 'B-VP']
203621


In [22]:
chunks = set(all_chunk)

n_chunks = len(chunks); n_chunks

20

In [23]:
sentences[0]

[['EU', 'NNP', 'B-NP', 'B-ORG'],
 ['rejects', 'VBZ', 'B-VP', 'O'],
 ['German', 'JJ', 'B-NP', 'B-MISC'],
 ['call', 'NN', 'I-NP', 'O'],
 ['to', 'TO', 'B-VP', 'O'],
 ['boycott', 'VB', 'I-VP', 'O'],
 ['British', 'JJ', 'B-NP', 'B-MISC'],
 ['lamb', 'NN', 'I-NP', 'O'],
 ['.', '.', 'O', 'O']]

In [24]:
[w[0] for w in sentences[0]]

['EU', 'rejects', 'German', 'call', 'to', 'boycott', 'British', 'lamb', '.']

# Dataset

In [25]:
class ConllDataset(Dataset):
    def __init__(
        self, 
        root_dir: str, 
        file_id: str, 
        max_len: int,
        token_vocab: dict=None,
        pos_vocab: dict=None,
        chunk_vocab: dict=None,
        tags_vocab: dict=None
    ):

        self.root_dir = root_dir
        self.file_id = file_id
        self.max_len = max_len

        self.token_vocab = token_vocab
        self.pos_vocab = pos_vocab,
        self.chunk_vocab = chunk_vocab,
        self.tags_vocab = tags_vocab

        self._load_data()

    def _load_data(self):
        self.reader = ConllCorpusReader(
            self.root_dir, 
            self.file_id, 
            columntypes=("words", "pos", "chunk", "ne")
        )

        self.iob = self.reader.iob_words()

        words = [word[0] for word in self.iob]
        self.words = list(set(words))

        pos = [pos[1] for pos in self.iob]
        self.pos = list(set(pos))

        chunks = [chunk[2] for chunk in self.iob]
        self.chunks = list(set(chunks))

        grid = reader._grids()
        self.sentences = list(grid)[1:]

        tags = [tag[3] for sentence in sentences for tag in sentence]
        self.tags = list(set(tags))

    def setup(self, stage="train"):
        if stage=="train":
            self.token_vocab = {w: i+2 for i, w in enumerate(self.words)}
            self.token_vocab["<pad>"] = 0
            self.token_vocab["<unk>"] = 1

            self.pos_vocab = {w: i+1 for i, w in enumerate(self.pos)}
            self.pos_vocab["<pad>"] = 0

            self.chunk_vocab = {w: i+1 for i, w in enumerate(self.chunks)}
            self.chunk_vocab["<pad>"] = 0

            self.tags_vocab = {w: i+1 for i, w in enumerate(self.tags)}
            self.tags_vocab["<pad>"] = 0

    def _pad_sequence(self, sequence, pad_idx):
        sequence = sequence[:self.max_len]
        return sequence + [pad_idx] * (self.max_len - len(sequence))

    def __len__(self):
        return len(self.sentences)

    def __getitem__(self, idx):
        sentence = self.sentences[idx]
        
        token = [self.token_vocab.get(w[0], self.token_vocab["<unk>"]) for w in sentence]
        pos = [self.pos_vocab.get(w[1]) for w in sentence]
        chunk = [self.chunk_vocab.get(w[2]) for w in sentence]
        tags = [self.tags_vocab.get(w[3]) for w in sentence]

        token = self._pad_sequence(token, self.token_vocab["<pad>"])
        pos = self._pad_sequence(pos, self.pos_vocab["<pad>"])
        chunk = self._pad_sequence(chunk, self.chunk_vocab["<pad>"])
        tags = self._pad_sequence(tags, self.tags_vocab["<pad>"])

        return (
            torch.tensor(token, dtype=torch.long),
            torch.tensor(pos, dtype=torch.long),
            torch.tensor(chunk, dtype=torch.long),
            torch.tensor(tags, dtype=torch.long)
        )


In [26]:
data = ConllDataset(root_dir="data", file_id="train.txt", max_len=65)
data.setup()

In [27]:
data[1]

(tensor([11589,  1125,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0]),
 tensor([11, 11,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,
          0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,
          0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,
          0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0]),
 tensor([17,  7,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,
          0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,
          0,  0,  0,

In [28]:
data.pos_vocab

{',': 1,
 'RP': 2,
 'CC': 3,
 'RBS': 4,
 'LS': 5,
 '(': 6,
 'FW': 7,
 'PRP': 8,
 'JJ': 9,
 'NN|SYM': 10,
 'NNP': 11,
 'NN': 12,
 'DT': 13,
 'MD': 14,
 'IN': 15,
 '"': 16,
 'NNS': 17,
 'CD': 18,
 'RB': 19,
 'VBZ': 20,
 'WRB': 21,
 'NNPS': 22,
 'VBD': 23,
 'RBR': 24,
 'VB': 25,
 'VBG': 26,
 ')': 27,
 'UH': 28,
 'JJR': 29,
 'WDT': 30,
 'SYM': 31,
 'VBN': 32,
 'TO': 33,
 'JJS': 34,
 '.': 35,
 'WP$': 36,
 'EX': 37,
 'PRP$': 38,
 "''": 39,
 ':': 40,
 'WP': 41,
 'VBP': 42,
 'POS': 43,
 'PDT': 44,
 '$': 45,
 '<pad>': 0}

In [29]:
data.chunk_vocab

{'I-ADJP': 1,
 'I-CONJP': 2,
 'B-LST': 3,
 'I-VP': 4,
 'B-ADVP': 5,
 'O': 6,
 'I-NP': 7,
 'B-ADJP': 8,
 'B-VP': 9,
 'B-CONJP': 10,
 'I-INTJ': 11,
 'I-LST': 12,
 'B-PP': 13,
 'I-PP': 14,
 'B-INTJ': 15,
 'B-PRT': 16,
 'B-NP': 17,
 'I-ADVP': 18,
 'B-SBAR': 19,
 'I-SBAR': 20,
 '<pad>': 0}

In [30]:
data.tags_vocab

{'I-MISC': 1,
 'I-PER': 2,
 'B-PER': 3,
 'B-MISC': 4,
 'I-LOC': 5,
 'I-ORG': 6,
 'B-ORG': 7,
 'B-LOC': 8,
 'O': 9,
 '<pad>': 0}

In [31]:
class ConllDataModule(L.LightningModule):
    def __init__(self, root_dir: str, max_len: int, batch_size: int):
        super().__init__()

        self.root_dir = root_dir
        self.max_len = max_len
        self.batch_size = batch_size

    def setup(self, stage=None):
        self.train = ConllDataset(
            root_dir = self.root_dir,
            file_id = "train.txt", 
            max_len = self.max_len
        )
        self.train.setup()

        self.val = ConllDataset(
            root_dir = self.root_dir,
            file_id = "valid.txt", 
            max_len = self.max_len,
            token_vocab = self.train.token_vocab,
            pos_vocab = self.train.pos_vocab,
            chunk_vocab = self.train.chunk_vocab,
            tags_vocab = self.train.tags_vocab
        )

        self.test = ConllDataset(
            root_dir = self.root_dir,
            file_id = "test.txt", 
            max_len = self.max_len,
            token_vocab = self.train.token_vocab,
            pos_vocab = self.train.pos_vocab,
            chunk_vocab = self.train.chunk_vocab,
            tags_vocab = self.train.tags_vocab
        )

    def train_dataloader(self):
        return DataLoader(
            self.train, 
            batch_size=self.batch_size, 
            num_workers=11, 
            shuffle=True
        )

    def val_dataloader(self):
        return DataLoader(
            self.val, 
            batch_size=32, 
            num_workers=11
        )

    def test_dataloader(self):
        return DataLoader(
            self.test_dataset, 
            batch_size=32, 
            num_workers=11
        )

In [32]:
datamodule = ConllDataModule(root_dir="data", max_len=65, batch_size=32)
datamodule.setup()

In [33]:
train = datamodule.train_dataloader()

token, pos, chunk, tag = next(iter(train))

In [34]:
token.shape

torch.Size([32, 65])

In [35]:
pos.shape

torch.Size([32, 65])

In [36]:
chunk.shape

torch.Size([32, 65])

In [37]:
tag.shape

torch.Size([32, 65])

# Model

In [38]:
n_words

23624

In [39]:
n_pos

45

In [40]:
n_chunks

20

In [41]:
n_tags

9

In [42]:
import torch
import torch.nn as nn

class NERModel(nn.Module):
    def __init__(self, n_words, n_pos, n_chunks, n_tags):
        super(NERModel, self).__init__()
        self.word_embedding = nn.Embedding(n_words + 2, 20, padding_idx=0)
        self.pos_embedding = nn.Embedding(n_pos + 2, 20, padding_idx=0)
        self.chunk_embedding = nn.Embedding(n_chunks + 2, 20, padding_idx=0)

        self.spatial_dropout = nn.Dropout(0.3)
        self.lstm = nn.LSTM(
            input_size=60,  # 20 (word) + 20 (pos) + 20 (chunk)
            hidden_size=50,
            num_layers=1,
            bidirectional=True,
            batch_first=True,
            dropout=0.6,
        )
        self.output_layer = nn.Linear(50 * 2, n_tags + 1)  # Bidirectional LSTM doubles hidden size
        self.sigmoid = nn.Sigmoid()

    def forward(self, word_input, pos_input, chunk_input):
        # Embeddings
        word_emb = self.word_embedding(word_input)
        pos_emb = self.pos_embedding(pos_input)
        chunk_emb = self.chunk_embedding(chunk_input)

        # Concatenate embeddings along the last dimension
        x = torch.cat((word_emb, pos_emb, chunk_emb), dim=-1)

        # Apply spatial dropout (dropout on the entire embedding sequence)
        x = self.spatial_dropout(x)

        # LSTM
        x, _ = self.lstm(x)

        # Output layer with time-distributed dense equivalent
        x = self.output_layer(x)

        # Apply sigmoid activation for each time step
        x = self.sigmoid(x)

        return x


In [43]:
model = NERModel(n_words, n_pos, n_chunks, n_tags)



In [45]:
pred = model(token, pos, chunk)

pred.shape

torch.Size([32, 65, 10])

In [46]:
tag.shape

torch.Size([32, 65])

In [None]:
tag.view(32, 65, 1)

tensor([[[9],
         [9],
         [9],
         ...,
         [0],
         [0],
         [0]],

        [[9],
         [9],
         [9],
         ...,
         [0],
         [0],
         [0]],

        [[9],
         [9],
         [9],
         ...,
         [0],
         [0],
         [0]],

        ...,

        [[9],
         [9],
         [9],
         ...,
         [0],
         [0],
         [0]],

        [[9],
         [9],
         [9],
         ...,
         [0],
         [0],
         [0]],

        [[3],
         [9],
         [9],
         ...,
         [0],
         [0],
         [0]]])