In [1]:
import torch
import torch.nn as nn
import numpy as np
from tqdm.auto import tqdm

MAX_LEN = 256

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
TRAIN_PATH = 'data/span_detection_datasets_split_word_IOB/train.jsonl'
DEV_PATH = 'data/span_detection_datasets_split_word_IOB/dev.jsonl'
TEST_PATH = 'data/span_detection_datasets_split_word_IOB/test.jsonl'

# Load material

In [3]:
# function read jsonl file as dataframe
import pandas as pd
import json

def read_jsonl_to_dataframe(file_path):
    """Load a JSON Lines file into a pandas DataFrame.

    Each line of the file is parsed as one JSON value and becomes one row.
    Lines that cannot be parsed are reported on stdout and skipped.

    :param file_path: path of the UTF-8 encoded ``.jsonl`` file
    :return: DataFrame built from the successfully parsed records
    """
    records = []
    with open(file_path, 'r', encoding='utf-8') as handle:
        for raw_line in handle:
            try:
                records.append(json.loads(raw_line))
            except json.JSONDecodeError as e:
                # best effort: a malformed line must not abort the whole load
                print(f"Skipping invalid JSON: {e}")
    return pd.DataFrame(records)

In [4]:
import json

# load embedding
# embedding_maxtrix = np.load('embedding/embedding_matrix.npy')

# load vocab
# with open('data/vocab.txt', 'r') as f:
#     vocab = f.read().split('\n')

# load tag_to_id
# (explicit UTF-8 so the result does not depend on the platform's default
# encoding, matching how the JSONL files are opened)
with open('data/tag_to_id.json', 'r', encoding='utf-8') as f:
    tag_to_id = json.load(f)

# load train and dev data
train_data = read_jsonl_to_dataframe(TRAIN_PATH)
dev_data = read_jsonl_to_dataframe(DEV_PATH)

# every `text` entry is a sequence of string pieces; join them into one
# space-separated sentence string
train_sentences = [" ".join(pieces) for pieces in train_data.text]
dev_sentences = [" ".join(pieces) for pieces in dev_data.text]

train_labels = list(train_data.labels)
dev_labels = list(dev_data.labels)

# Module Data

In [5]:
import os

# SECURITY: never commit access tokens to a notebook. The Hugging Face token is
# read from the HF_TOKEN environment variable; it is None when the variable is
# not set (enough for public checkpoints).
# NOTE(review): the token that used to be hard-coded here must be considered
# leaked and should be revoked in the Hugging Face account settings.
AUTH_TOKEN = os.environ.get('HF_TOKEN')
TOKENIZER_PATH = 'nguyenvulebinh/vi-mrc-large'

## Datasets

In this solution we use a pretrained tokenizer from [HuggingFace](https://huggingface.co/nguyenvulebinh/vi-mrc-large).

In [6]:
# helper function
import torch
import torch.autograd as autograd
import torch.nn as nn
import torch.optim as optim

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [7]:
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained(TOKENIZER_PATH, use_fast=True)

class SpanDetectionDataset(torch.utils.data.Dataset):
    """Map-style dataset pairing tokenized sentences with padded tag-id rows.

    All sentences are tokenized once in the constructor and padded/truncated to
    ``max_len``. Every label sequence is converted to ids and padded (or
    truncated) to ``max_len`` with the id of the ``'<PAD>'`` tag.

    NOTE(review): label ``j`` is stored at position ``j`` of the label row,
    whereas ``input_ids`` come from the tokenizer, which may add special tokens
    and split words into several pieces. Confirm that labels and token
    positions are really aligned for this tokenizer.

    :param sentences: list of sentence strings
    :param labels: list of tag sequences (one sequence of tag strings per sentence)
    :param tag_to_id: mapping from tag string to integer id; must contain ``'<PAD>'``
    :param tokenizer: tokenizer object providing ``batch_encode_plus``
    :param max_len: fixed length of every encoded sentence and label row
    :raises ValueError: if ``sentences`` and ``labels`` differ in length
    """

    def __init__(self, sentences, labels, tag_to_id, tokenizer, max_len=MAX_LEN):
        if len(sentences) != len(labels):
            raise ValueError(
                f'sentences and labels must have the same length, got {len(sentences)} and {len(labels)}'
            )

        self.sentences = sentences
        self.labels = labels

        self.max_len = max_len

        # encode all sentences
        self.tokenizer = tokenizer
        self.encoded_sentences = self.tokenizer.batch_encode_plus(
            self.sentences,
            padding='max_length',
            truncation=True,
            return_attention_mask=True,
            return_tensors='pt',
            max_length=self.max_len,
        )

        # tags to ids
        self.tag_to_id = tag_to_id
        self.encoded_labels = self.convert_labels_to_ids()

    def __len__(self):
        """Number of sentences in the dataset."""
        return len(self.sentences)

    def __getitem__(self, index):
        """Return the encoded sample at ``index``.

        :return: dict with ``input_ids``, ``attention_mask`` and ``labels``,
            each a tensor of length ``max_len``
        """
        return {
            'input_ids': self.encoded_sentences['input_ids'][index],
            'attention_mask': self.encoded_sentences['attention_mask'][index],
            'labels': self.encoded_labels[index]
        }

    def convert_labels_to_ids(self):
        """Build the ``(num_sentences, max_len)`` tensor of tag ids.

        Rows start filled with the ``'<PAD>'`` id. Label sequences longer than
        ``max_len`` are truncated, mirroring the truncation of the sentences.
        The tensor is kept on the CPU like the tokenizer output; move batches
        to the target device in the training loop.

        :return: ``torch.long`` tensor of shape ``(len(labels), max_len)``
        """
        pad_id = self.tag_to_id['<PAD>']
        encoded_labels = torch.full((len(self.labels), self.max_len), pad_id, dtype=torch.long)

        for row, label in enumerate(self.labels):
            tag_ids = [self.tag_to_id[tag] for tag in label[:self.max_len]]
            if tag_ids:
                encoded_labels[row, :len(tag_ids)] = torch.tensor(tag_ids, dtype=torch.long)

        return encoded_labels
        

train_dataset = SpanDetectionDataset(train_sentences, train_labels, tag_to_id, tokenizer)
dev_dataset = SpanDetectionDataset(dev_sentences, dev_labels, tag_to_id, tokenizer)

In [8]:
# create PyTorch data loaders
BATCH_SIZE = 2

# Training batches are reshuffled every epoch; the dev set is read in a fixed
# order so that evaluation is repeatable from run to run.
train_dataloader = torch.utils.data.DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
dev_dataloader = torch.utils.data.DataLoader(dev_dataset, batch_size=BATCH_SIZE, shuffle=False)

# Model

## Embedding model

In [None]:
# import fasttext

# # Load the pre-trained model
# embedding_model = fasttext.load_model('pretrained-weights/cc.vi.300.bin')

# vocabulary = tokenizer.get_vocabulary()
# vector_dim = embedding_model.get_dimension()

# embedding_matrix = np.zeros((len(vocabulary), vector_dim))
# for i, word in enumerate(vocabulary):
#         embedding_matrix[i] = embedding_model.get_word_vector(word)

# embedding_matrix_file = 'embedding/embedding_matrix.npy'

# np.save(embedding_matrix_file, embedding_matrix)

In [None]:
# load embedding
# embedding_maxtrix = np.load('embedding/embedding_matrix.npy')

In [None]:
# torch.manual_seed(1)

# def argmax(vec):
#     # return the argmax as a python int
#     _, idx = torch.max(vec, 1)
#     return idx.item()


# def prepare_sequence(seq, to_ix):
#     idxs = [to_ix[w] for w in seq]
#     return torch.tensor(idxs, dtype=torch.long)


# # Compute log sum exp in a numerically stable way for the forward algorithm
# def log_sum_exp(vec):
#     max_score = vec[0, argmax(vec)]
#     max_score_broadcast = max_score.view(1, -1).expand(1, vec.size()[1])
#     return max_score + \
#         torch.log(torch.sum(torch.exp(vec - max_score_broadcast)))

# class BiLSTM_CRF(nn.Module):
#     def __init__(self, vocab_size, tag_to_id, batch_size, embedding_matrix=None, embedding_dim=None, hidden_dim=200, units='lstm', droput=0.2, recurrent_dropout=0.2, max_len=MAX_LEN):
#         super(BiLSTM_CRF, self).__init__()

#         self.embedding_dim = embedding_dim
#         self.hidden_dim = hidden_dim
#         self.vocab_size = vocab_size
#         self.tag_to_ix = tag_to_id
#         self.tagset_size = len(tag_to_id)

#         self.max_len = max_len
#         self.batch_size = batch_size

#         # check embedding matrix and embedding dimension
#         if embedding_matrix is None and embedding_dim is None:
#             raise ValueError('You must provide either embedding matrix or embedding dimension')
#         if embedding_matrix is not None and embedding_dim is not None:
#             raise ValueError('You must provide either embedding matrix or embedding dimension, not both')
        
#         if embedding_matrix is None:
#             self.word_embeds = nn.Embedding(vocab_size, embedding_dim)

#         if embedding_matrix is not None:
#             self.word_embeds = nn.Embedding.from_pretrained(torch.FloatTensor(embedding_matrix))

#         num_layers = 1
#         if units == 'lstm':
#             self.lstm = nn.LSTM(embedding_dim, hidden_dim // 2, bidirectional=True, dropout=recurrent_dropout, num_layers=num_layers)
#         elif units == 'gru':
#             self.lstm = nn.GRU(embedding_dim, hidden_dim // 2, bidirectional=True, dropout=recurrent_dropout, num_layers=num_layers)
#         elif units == 'rnn':
#             self.lstm = nn.RNN(embedding_dim, hidden_dim // 2, bidirectional=True, dropout=recurrent_dropout, num_layers=num_layers)
#         else:
#             raise ValueError('Invalid unit type, must be one of "lstm", "gru", "rnn"')
        
#         self.hidden2tag = nn.Linear(hidden_dim, len(self.tag_to_ix))

#         # self.dropout = nn.Dropout(droput)

#         self.transitions = nn.Parameter(
#             torch.randn(self.tagset_size, self.tagset_size).to(device)
#             )

#         self.transitions.data[tag_to_id[START_TAG], :] = -10000
#         self.transitions.data[:, tag_to_id[STOP_TAG]] = -10000

#         self.hidden = self.init_hidden()

#     def init_hidden(self):
#         return (torch.randn(2, self.max_len, self.hidden_dim // 2).to(device),
#                 torch.randn(2, self.max_len, self.hidden_dim // 2).to(device))
    
#     def _forward_alg(self, feats):
#         # Do the forward algorithm to compute the partition function
#         init_alphas = torch.full((1, self.tagset_size), -10000.).to(device)
#         # START_TAG has all of the score.
#         init_alphas[0][self.tag_to_ix[START_TAG]] = 0.

#         # Wrap in a variable so that we will get automatic backprop
#         forward_var = init_alphas

#         # Iterate through the sentence
#         for feat in feats:
#             alphas_t = []  # The forward tensors at this timestep
#             for next_tag in range(self.tagset_size):
#                 # broadcast the emission score: it is the same regardless of
#                 # the previous tag
#                 emit_score = feat[next_tag].view(
#                     1, -1).expand(1, self.tagset_size)
#                 # the ith entry of trans_score is the score of transitioning to
#                 # next_tag from i
#                 trans_score = self.transitions[next_tag].view(1, -1)
#                 # The ith entry of next_tag_var is the value for the
#                 # edge (i -> next_tag) before we do log-sum-exp
#                 next_tag_var = forward_var + trans_score + emit_score
#                 # The forward variable for this tag is log-sum-exp of all the
#                 # scores.
#                 alphas_t.append(log_sum_exp(next_tag_var).view(1))
#             forward_var = torch.cat(alphas_t).view(1, -1)
#         terminal_var = forward_var + self.transitions[self.tag_to_ix[STOP_TAG]]
#         alpha = log_sum_exp(terminal_var)
#         return alpha
    
#     def _get_lstm_features(self, sentence):
#         self.hidden = self.init_hidden()
#         # embeds = self.word_embeds(sentence).view(len(sentence), 1, -1)
#         embeds = self.word_embeds(sentence)

#         # print(embeds.shape)
#         # print(self.hidden[0].shape, self.hidden[1].shape)
        
#         lstm_out, self.hidden = self.lstm(embeds, self.hidden)

#         # print(lstm_out.shape, self.hidden[0].shape, self.hidden[1].shape)
        
#         # lstm_out = lstm_out.view(len(sentence), self.hidden_dim)

#         lstm_feats = self.hidden2tag(lstm_out)

#         # lstm_feats = self.dropout(lstm_feats)
#         return lstm_feats
    
#     def _score_sentence(self, feats, tags):
#         # Gives the score of a provided tag sequence
#         score = torch.zeros(1, device=device)
#         tags = torch.cat([torch.tensor([self.tag_to_ix[START_TAG]], dtype=torch.long, device=device), tags])
#         for i, feat in enumerate(feats):
#             score = score + \
#                 self.transitions[tags[i + 1], tags[i]] + feat[tags[i + 1]]
#         score = score + self.transitions[self.tag_to_ix[STOP_TAG], tags[-1]]
#         return score

#     def _viterbi_decode(self, feats):
#         backpointers = []

#         # Initialize the viterbi variables in log space
#         init_vvars = torch.full((1, self.tagset_size), -10000.).to(device)
#         init_vvars[0][self.tag_to_ix[START_TAG]] = 0

#         # forward_var at step i holds the viterbi variables for step i-1
#         forward_var = init_vvars
#         for feat in feats:
#             bptrs_t = []  # holds the backpointers for this step
#             viterbivars_t = []  # holds the viterbi variables for this step

#             for next_tag in range(self.tagset_size):
#                 # next_tag_var[i] holds the viterbi variable for tag i at the
#                 # previous step, plus the score of transitioning
#                 # from tag i to next_tag.
#                 # We don't include the emission scores here because the max
#                 # does not depend on them (we add them in below)
#                 next_tag_var = forward_var + self.transitions[next_tag]
#                 best_tag_id = argmax(next_tag_var)
#                 bptrs_t.append(best_tag_id)
#                 viterbivars_t.append(next_tag_var[0][best_tag_id].view(1))
#             # Now add in the emission scores, and assign forward_var to the set
#             # of viterbi variables we just computed
#             forward_var = (torch.cat(viterbivars_t) + feat).view(1, -1)
#             backpointers.append(bptrs_t)

#         # Transition to STOP_TAG
#         terminal_var = forward_var + self.transitions[self.tag_to_ix[STOP_TAG]]
#         best_tag_id = argmax(terminal_var)
#         path_score = terminal_var[0][best_tag_id]

#         # Follow the back pointers to decode the best path.
#         best_path = [best_tag_id]
#         for bptrs_t in reversed(backpointers):
#             best_tag_id = bptrs_t[best_tag_id]
#             best_path.append(best_tag_id)
#         # Pop off the start tag (we dont want to return that to the caller)
#         start = best_path.pop()
#         assert start == self.tag_to_ix[START_TAG]  # Sanity check
#         best_path.reverse()
#         return path_score, best_path

#     def neg_log_likelihood(self, batch_sentence, tags):
#         batch_feats = self._get_lstm_features(batch_sentence) #[batch_size, max_len, hidden_dim/2]

#         batch_forward_score = torch.zeros(1, device=device)
#         batch_gold_score = torch.zeros(1, device=device)
        
#         for feats, tag in zip(batch_feats, tags):
#             forward_score = self._forward_alg(feats) # this function get input is [max_len, hidden_dim/2]
#             gold_score = self._score_sentence(feats, tag)

#             batch_forward_score += forward_score
#             batch_gold_score += gold_score

#         return batch_forward_score - batch_gold_score
    
#     def forward(self, sentence):  # dont confuse this with _forward_alg above.
#         # Get the emission scores from the BiLSTM
#         lstm_feats = self._get_lstm_features(sentence)

#         # Find the best path, given the features.
#         batch_score, batch_tag_seq = torch.zeros(self.batch_size, device=device), torch.zeros(self.batch_size, device=device)
        
#         for feats in lstm_feats:
#             score, tag_seq = self._viterbi_decode(feats)
#             batch_score += score
#             batch_tag_seq += tag_seq

#         return batch_score, batch_tag_seq


## Span detection model

In [9]:
import numpy as np
import matplotlib.pyplot as plt

START_TAG = "<START>"
STOP_TAG = "<STOP>"

In [10]:
def log_sum_exp(x):
    """Numerically stable ``log(sum(exp(x)))`` over the last dimension.

    Uses the identity ``log(sum(exp(x))) = max(x) + log(sum(exp(x - max(x))))``
    so that the exponentials never overflow.

    :param x: tensor of scores
    :return: tensor with the last dimension reduced
    """
    peak, _ = x.max(-1)
    shifted = x - peak.unsqueeze(-1)
    return peak + shifted.exp().sum(-1).log()


IMPOSSIBLE = -1e4


class CRF(nn.Module):
    """Linear-chain CRF layer with a built-in projection to tag space.

    The module owns a linear layer (``fc``) that maps the input features to
    per-tag emission scores, plus a learned transition matrix. Two internal
    tags, START and STOP, are appended after the caller's tags.

    :param in_features: number of features for the input
    :param num_tags: number of tags. DO NOT include START, STOP tags, they are added internally.
    """

    def __init__(self, in_features, num_tags):
        super(CRF, self).__init__()

        # two extra slots: START is the second-to-last index, STOP the last
        self.num_tags = num_tags + 2
        self.start_idx = self.num_tags - 2
        self.stop_idx = self.num_tags - 1

        self.fc = nn.Linear(in_features, self.num_tags)

        # transition factor, Tij means transition from j to i
        self.transitions = nn.Parameter(torch.randn(self.num_tags, self.num_tags), requires_grad=True)
        # nothing may transition into START and nothing may leave STOP
        self.transitions.data[self.start_idx, :] = IMPOSSIBLE
        self.transitions.data[:, self.stop_idx] = IMPOSSIBLE

    def forward(self, features, masks):
        """Decode the best tag sequence for every item of the batch.

        The features are first projected to tag space with ``fc``; the masks
        are cropped to the feature length and cast to float before decoding.

        :param features: [B, L, D], batch of input features
        :param masks: [B, L'] masks with L' >= L; only the first L columns are used
        :return: (best_score, best_paths)
            best_score: [B] tensor
            best_paths: list of B lists of tag ids; list b has as many entries
                as masks[b] has non-zero positions
        """
        features = self.fc(features)
        return self.__viterbi_decode(features, masks[:, :features.size(1)].float())

    def loss(self, features, ys, masks):
        """negative log likelihood loss, averaged over the batch
        B: batch size, L: sequence length, D: dimension

        Tags and masks are cropped to the length of ``features`` before use.

        :param features: [B, L, D]
        :param ys: tags, [B, L'] with L' >= L
        :param masks: masks for padding, [B, L'] with L' >= L
        :return: scalar loss, mean of (partition score - gold path score)
        """
        features = self.fc(features)

        L = features.size(1)
        masks_ = masks[:, :L].float()

        forward_score = self.__forward_algorithm(features, masks_)
        gold_score = self.__score_sentence(features, ys[:, :L].long(), masks_)
        loss = (forward_score - gold_score).mean()
        return loss

    def __score_sentence(self, features, tags, masks):
        """Gives the score of a provided tag sequence

        :param features: [B, L, C] emission scores
        :param tags: [B, L]
        :param masks: [B, L] float masks
        :return: [B] score in the log space
        """
        B, L, C = features.shape

        # emission score: pick, at every position, the score of the gold tag
        emit_scores = features.gather(dim=2, index=tags.unsqueeze(-1)).squeeze(-1)

        # transition score: prepend START so that position t pairs tag t-1 -> tag t
        start_tag = torch.full((B, 1), self.start_idx, dtype=torch.long, device=tags.device)
        tags = torch.cat([start_tag, tags], dim=1)  # [B, L+1]
        trans_scores = self.transitions[tags[:, 1:], tags[:, :-1]]

        # last transition score to STOP tag; because START was prepended, index
        # mask.sum() of the shifted tags is the tag at position mask.sum() - 1
        last_tag = tags.gather(dim=1, index=masks.sum(1).long().unsqueeze(1)).squeeze(1)  # [B]
        last_score = self.transitions[self.stop_idx, last_tag]

        # masked positions contribute nothing to the path score
        score = ((trans_scores + emit_scores) * masks).sum(1) + last_score
        return score

    def __viterbi_decode(self, features, masks):
        """decode to tags using viterbi algorithm

        Back-tracking walks over the first ``masks[b].sum()`` time steps of each
        item, i.e. it presumes that the unmasked positions form a prefix of the
        sequence.

        :param features: [B, L, C], batch of unary scores
        :param masks: [B, L] float masks
        :return: (best_score, best_paths)
            best_score: [B] tensor
            best_paths: list of B lists of tag ids, list b has length masks[b].sum()
        """
        B, L, C = features.shape

        bps = torch.zeros(B, L, C, dtype=torch.long, device=features.device)  # back pointers

        # Initialize the viterbi variables in log space: all mass on START
        max_score = torch.full((B, C), IMPOSSIBLE, device=features.device)  # [B, C]
        max_score[:, self.start_idx] = 0

        for t in range(L):
            mask_t = masks[:, t].unsqueeze(1)  # [B, 1]
            emit_score_t = features[:, t]  # [B, C]

            # [B, 1, C] + [C, C]
            acc_score_t = max_score.unsqueeze(1) + self.transitions  # [B, C, C]
            # best previous tag for every current tag, remembered as back pointer
            acc_score_t, bps[:, t, :] = acc_score_t.max(dim=-1)
            acc_score_t += emit_score_t
            # keep the old score where the position is masked out
            max_score = acc_score_t * mask_t + max_score * (1 - mask_t)  # max_score or acc_score_t

        # Transition to STOP_TAG
        max_score += self.transitions[self.stop_idx]
        best_score, best_tag = max_score.max(dim=-1)

        # Follow the back pointers to decode the best path.
        best_paths = []
        bps = bps.cpu().numpy()
        for b in range(B):
            best_tag_b = best_tag[b].item()
            seq_len = int(masks[b, :].sum().item())

            best_path = [best_tag_b]
            for bps_t in reversed(bps[b, :seq_len]):
                best_tag_b = bps_t[best_tag_b]
                best_path.append(best_tag_b)
            # drop the last collected entry (the pointer out of step 0) and reverse the rest
            best_paths.append(best_path[-2::-1])

        return best_score, best_paths

    def __forward_algorithm(self, features, masks):
        """calculate the partition function with forward algorithm.
        TRICK: log_sum_exp([x1, x2, x3, x4, ...]) = log_sum_exp([log_sum_exp([x1, x2]), log_sum_exp([x3, x4]), ...])

        :param features: features. [B, L, C]
        :param masks: [B, L] float masks
        :return:    [B], score in the log space
        """
        B, L, C = features.shape

        # all mass starts on START
        scores = torch.full((B, C), IMPOSSIBLE, device=features.device)  # [B, C]
        scores[:, self.start_idx] = 0.
        trans = self.transitions.unsqueeze(0)  # [1, C, C]

        # Iterate through the sentence
        for t in range(L):
            emit_score_t = features[:, t].unsqueeze(2)  # [B, C, 1]
            score_t = scores.unsqueeze(1) + trans + emit_score_t  # [B, 1, C] + [1, C, C] + [B, C, 1] => [B, C, C]
            score_t = log_sum_exp(score_t)  # [B, C]

            # masked positions carry the previous scores forward unchanged
            mask_t = masks[:, t].unsqueeze(1)  # [B, 1]
            scores = score_t * mask_t + scores * (1 - mask_t)
        # close every path with the transition to STOP
        scores = log_sum_exp(scores + self.transitions[self.stop_idx])
        return scores

In [16]:
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence

class BiLSTM_CRF(nn.Module):
    """Embedding -> bidirectional recurrent encoder -> CRF sequence tagger.

    Exactly one of ``embedding_matrix`` and ``embedding_dim`` must be given.
    With ``embedding_matrix`` the embedding layer is created with
    ``nn.Embedding.from_pretrained`` and the embedding size is taken from the
    matrix; otherwise a fresh ``nn.Embedding(vocab_size, embedding_dim)`` is used.

    :param vocab_size: number of rows of the embedding table (used when no matrix is given)
    :param tag_to_id: mapping from tag to id; its length is the number of tags handed to the CRF
    :param batch_size: stored on the instance only
    :param embedding_matrix: optional pretrained embedding weights, shape (vocab, dim)
    :param embedding_dim: embedding size when no pretrained matrix is given
    :param hidden_dim: size of the concatenated forward/backward encoder output;
        each direction gets ``hidden_dim // 2`` units
    :param units: recurrent cell type, one of ``'lstm'``, ``'gru'``, ``'rnn'``
    :param recurrent_dropout: dropout between stacked recurrent layers; unused
        while the encoder has a single layer
    :param max_len: stored on the instance only
    :param pad_idx: token id treated as padding when building the mask.
        NOTE(review): the default 0 keeps the previous behaviour, but the pad id
        of the tokenizer in use may differ — confirm and pass
        ``tokenizer.pad_token_id`` if so.
    :raises ValueError: on an invalid embedding specification or unit type
    """

    def __init__(self, vocab_size, tag_to_id, batch_size, embedding_matrix=None, embedding_dim=None, hidden_dim=200, units='lstm', recurrent_dropout=0.2, max_len=MAX_LEN, pad_idx=0):
        super(BiLSTM_CRF, self).__init__()

        # check embedding matrix and embedding dimension
        if embedding_matrix is None and embedding_dim is None:
            raise ValueError('You must provide either embedding matrix or embedding dimension')
        if embedding_matrix is not None and embedding_dim is not None:
            raise ValueError('You must provide either embedding matrix or embedding dimension, not both')

        if embedding_matrix is None:
            self.word_embeds = nn.Embedding(vocab_size, embedding_dim)
        else:
            self.word_embeds = nn.Embedding.from_pretrained(torch.FloatTensor(embedding_matrix))
            # the encoder input size must follow the pretrained matrix; it was
            # previously left as None on this path, which broke construction
            embedding_dim = self.word_embeds.embedding_dim

        self.embedding_dim = embedding_dim
        self.hidden_dim = hidden_dim
        self.vocab_size = vocab_size
        self.tag_to_ix = tag_to_id
        self.tagset_size = len(tag_to_id)

        self.max_len = max_len
        self.batch_size = batch_size
        self.pad_idx = pad_idx

        unit_types = {'lstm': nn.LSTM, 'gru': nn.GRU, 'rnn': nn.RNN}
        if units not in unit_types:
            raise ValueError('Invalid unit type, must be one of "lstm", "gru", "rnn"')

        num_layers = 1
        # PyTorch applies RNN dropout only between stacked layers and warns when
        # it is non-zero for a single layer, so pass it only when it can act.
        self.lstm = unit_types[units](
            embedding_dim,
            hidden_dim // 2,
            bidirectional=True,
            dropout=recurrent_dropout if num_layers > 1 else 0.0,
            num_layers=num_layers,
        )

        # the CRF owns the projection from encoder output to tag scores
        self.crf = CRF(hidden_dim, self.tagset_size)

    def __build_features(self, sentences):
        """Encode a batch of token ids.

        The sequence length of each row is the number of non-padding ids; the
        rows are packed with these lengths, so padding is expected at the end
        of each row.

        :param sentences: [B, L] tensor of token ids
        :return: (lstm_out, masks)
            lstm_out: [B, L_max, hidden_dim] where L_max is the longest length in the batch
            masks: [B, L] boolean tensor, True where the id differs from ``pad_idx``
        """
        masks = sentences.ne(self.pad_idx)
        embeds = self.word_embeds(sentences.long())

        # packing rejects zero-length rows; treat an all-padding row as length 1
        # (its positions stay masked out for the CRF)
        seq_length = masks.sum(1).clamp(min=1)

        # enforce_sorted=False lets PyTorch sort the batch and restore the order
        pack_sequence = pack_padded_sequence(
            embeds, lengths=seq_length.to('cpu'), batch_first=True, enforce_sorted=False
        )
        packed_output, _ = self.lstm(pack_sequence)
        lstm_out, _ = pad_packed_sequence(packed_output, batch_first=True)

        return lstm_out, masks

    def loss(self, xs, tags):
        """CRF negative log likelihood of ``tags`` for the token ids ``xs``.

        :param xs: [B, L] token ids
        :param tags: [B, L] gold tag ids
        :return: scalar loss tensor
        """
        features, masks = self.__build_features(xs)
        loss = self.crf.loss(features, tags, masks=masks)
        return loss

    def forward(self, xs):
        """Decode the best tag sequence for every row of ``xs``.

        :param xs: [B, L] token ids
        :return: (scores, tag_seq) as returned by the CRF decoder
        """
        features, masks = self.__build_features(xs)
        scores, tag_seq = self.crf(features, masks)
        return scores, tag_seq

## Train model

In [17]:
def train_model(dataloader, epochs=20, lr=0.01, weight_decay=0.01, early_stopping=5):
    """Train a BiLSTM-CRF tagger with early stopping on the dev loss.

    The model is optimised with AdamW; the learning rate is multiplied by 0.1
    after every epoch (``StepLR(step_size=1, gamma=0.1)``).
    NOTE(review): a tenfold decay per epoch is very aggressive — confirm this
    schedule is intended.

    The weights of the epoch with the lowest dev loss are snapshotted and
    restored into the returned model, so the caller really gets the best model
    rather than the state after the last epoch.

    :param dataloader: pair ``(train_loader, dev_loader)`` yielding dict batches
        with ``'input_ids'`` and ``'labels'``
    :param epochs: maximum number of epochs
    :param lr: initial learning rate
    :param weight_decay: AdamW weight decay
    :param early_stopping: stop after this many consecutive epochs without a
        lower dev loss
    :return: ``(best_model, train_loss, dev_loss)``; the losses are those of the
        last epoch that ran (``None`` if no epoch ran) and ``best_model`` is
        ``None`` if the dev loss never improved
    """
    train_loader, dev_loader = dataloader

    model = BiLSTM_CRF(vocab_size=len(tokenizer.get_vocab()), tag_to_id=tag_to_id, batch_size=BATCH_SIZE, embedding_dim=300)
    model.to(device)

    optimizer = optim.AdamW(model.parameters(), lr=lr, weight_decay=weight_decay)
    scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=1, gamma=0.1)

    best_loss = np.inf
    best_state = None  # CPU snapshot of the best weights seen so far
    early_stopping_counter = 0
    train_loss, dev_loss = None, None

    for epoch in range(epochs):
        model.train()
        train_loss = 0
        for batch in train_loader:
            input_ids = batch['input_ids'].to(device)
            labels = batch['labels'].to(device)

            optimizer.zero_grad()

            loss = model.loss(input_ids, labels)
            loss.backward()
            optimizer.step()

            train_loss += loss.item()

        train_loss /= len(train_loader)

        model.eval()
        dev_loss = 0
        # evaluation needs no gradients; skipping the autograd graph saves memory
        with torch.no_grad():
            for batch in dev_loader:
                input_ids = batch['input_ids'].to(device)
                labels = batch['labels'].to(device)

                loss = model.loss(input_ids, labels)

                dev_loss += loss.item()

        dev_loss /= len(dev_loader)

        print(f'Epoch {epoch + 1}/{epochs}, train_loss: {train_loss}, dev_loss: {dev_loss}')

        # calculate F1 score

        # early stopping
        if dev_loss < best_loss:
            best_loss = dev_loss
            # copy the tensors: keeping a reference to `model` would silently
            # follow every later update
            best_state = {name: tensor.detach().cpu().clone() for name, tensor in model.state_dict().items()}
            early_stopping_counter = 0
        else:
            early_stopping_counter += 1

        if early_stopping_counter >= early_stopping:
            print('Early stopping')
            break

        scheduler.step()

    if best_state is None:
        return None, train_loss, dev_loss

    # restore the best weights (load_state_dict copies them back onto the model's device)
    model.load_state_dict(best_state)
    return model, train_loss, dev_loss

model, train_loss, dev_loss = train_model((train_dataloader, dev_dataloader), epochs=20, lr=0.001, weight_decay=0.01, early_stopping=5)



KeyboardInterrupt: 

# Plot results

In [None]:
# save model
# BiLSTM_CRF is a plain nn.Module and has no `save_pretrained`; persist its
# weights with torch.save (reload with `load_state_dict(torch.load('model'))`).
torch.save(model.state_dict(), 'model')

In [None]:
# load test data
test_data = read_jsonl_to_dataframe(TEST_PATH)
test_data

In [None]:
idx = 0
sentence = test_data.text[idx]
label = test_data.labels[idx]

def end_to_end_predict(sentence):
    """Tag one sentence with the trained model.

    The model is switched to evaluation mode and run without gradient tracking.

    NOTE(review): tokens are paired with tags after dropping the first and last
    token, as before; whether this lines up with the way labels were aligned
    during training should be confirmed.

    :param sentence: a sentence string, or a sequence of string pieces that is
        joined with spaces (the same way the training sentences are built)
    :return: list of ``(token, tag)`` pairs, as long as the shorter of the two sequences
    """
    # a list passed to the tokenizer would be treated as a batch of sentences
    if not isinstance(sentence, str):
        sentence = " ".join(sentence)

    # tokenize sentence
    tokenized_sentence = tokenizer(sentence, return_tensors='pt', padding='max_length', truncation=True, max_length=MAX_LEN)

    # predict
    model.eval()
    with torch.no_grad():
        scores, tag_seq = model(tokenized_sentence['input_ids'].to(device))

    # convert tag ids to tags; the decoder returns one plain list of ids per sentence
    id_to_tag = {tag_id: tag for tag, tag_id in tag_to_id.items()}
    # ids unknown to tag_to_id (e.g. the CRF's internal tags) are shown as their number
    tags = [id_to_tag.get(int(tag_id), str(int(tag_id))) for tag_id in tag_seq[0]]

    # convert to list of (word, tag)
    words = tokenizer.convert_ids_to_tokens(tokenized_sentence['input_ids'][0].tolist())
    words = words[1:-1]
    return list(zip(words, tags))

end_to_end_predict(sentence)

# End