In [1]:
import torch
import random
import numpy as np

SEED = 1234

random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
torch.backends.cudnn.deterministic = True

In [2]:
from torchtext.datasets import SequenceTaggingDataset

class AICup2020(SequenceTaggingDataset):
    @classmethod
    def splits(cls, fields, path ='aidai/',
               root='aidai/', train="train3.txt",
               validation="dev3.txt",
               test="test3.txt", **kwargs):
        return super(AICup2020, cls).splits(path = path,
            fields=fields, root=root, train=train, validation=validation,
            test=test, **kwargs)

In [3]:
from torchtext import data

TEXT = data.Field(batch_first = True)
CHAR = data.NestedField(data.Field(tokenize=list, batch_first = True))
PHI_TAGS = data.Field(unk_token = None, batch_first = True)
fields = ((("text", "char"), (TEXT, CHAR)), ("labels", PHI_TAGS))

In [4]:
train_data, valid_data, test_data = AICup2020.splits(fields)
print(f"Number of training examples: {len(train_data)}")
print(f"Number of validation examples: {len(valid_data)}")
print(f"Number of testing examples: {len(test_data)}")

Number of training examples: 15496
Number of validation examples: 2804
Number of testing examples: 4922


In [5]:
import torch

MIN_FREQ = 2
TEXT.build_vocab(train_data, min_freq = MIN_FREQ)

MIN_CHAR_FREQ = 5
CHAR.build_vocab(train_data, min_freq = MIN_CHAR_FREQ)

PHI_TAGS.build_vocab(train_data)

In [6]:
print(f"Unique tokens in TEXT vocabulary: {len(TEXT.vocab)}")
print(f"Unique tokens in PHI_TAG vocabulary: {len(PHI_TAGS.vocab)}")
print(f"Unique tokens in CHAR vocabulary: {len(CHAR.vocab)}")

print(TEXT.vocab.freqs.most_common(5))
print(PHI_TAGS.vocab.freqs.most_common(5))
print(CHAR.vocab.freqs.most_common(5))

Unique tokens in TEXT vocabulary: 3124
Unique tokens in PHI_TAG vocabulary: 21
Unique tokens in CHAR vocabulary: 1182
[('：', 14155), ('。', 12271), ('，', 11703), ('民眾', 6239), ('醫師', 5812)]
[('O', 180578), ('I-time', 1482), ('B-time', 1434), ('B-med_exam', 220), ('B-name', 169)]
[('：', 14359), ('。', 12496), ('，', 11824), ('師', 7434), ('是', 6706)]


In [7]:
def tag_percentage(tag_counts):
    
    total_count = sum([count for tag, count in tag_counts])
    
    tag_counts_percentages = [(tag, count, count/total_count) for tag, count in tag_counts]
        
    return tag_counts_percentages

print(f"{'Tag':<16s}\t{'Count':8s}\t{'Percentage':5s}\n")

for tag, count, percent in tag_percentage(PHI_TAGS.vocab.freqs.most_common()):
    print(f"{tag:<16s}\t{count:>8}\t{percent:>10.3f}")

Tag             	Count   	Percentage

O               	  180578	     0.979
I-time          	    1482	     0.008
B-time          	    1434	     0.008
B-med_exam      	     220	     0.001
B-name          	     169	     0.001
B-location      	     161	     0.001
I-name          	      91	     0.000
B-money         	      78	     0.000
I-location      	      30	     0.000
I-money         	      30	     0.000
B-family        	      25	     0.000
I-med_exam      	      25	     0.000
B-contact       	      19	     0.000
B-profession    	      13	     0.000
B-ID            	       8	     0.000
I-ID            	       5	     0.000
B-clinical_event	       5	     0.000
B-education     	       3	     0.000
I-family        	       3	     0.000
B-organization  	       1	     0.000


In [8]:
BATCH_SIZE = 256
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
train_iterator, valid_iterator, test_iterator = data.BucketIterator.splits(
    (train_data, valid_data, test_data), 
    batch_size = BATCH_SIZE, device = device)

In [170]:
for i in train_iterator:
    print(i.labels[3])
    break
# PHI_TAGS.vocab.itos

tensor([1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], device='cuda:0')


In [171]:
PHI_TAGS.vocab.itos

['<pad>',
 'O',
 'I-time',
 'B-time',
 'B-med_exam',
 'B-name',
 'B-location',
 'I-name',
 'B-money',
 'I-location',
 'I-money',
 'B-family',
 'I-med_exam',
 'B-contact',
 'B-profession',
 'B-ID',
 'B-clinical_event',
 'I-ID',
 'B-education',
 'I-family',
 'B-organization']

In [194]:
from torch import nn
import torch.nn.functional as F
from torchcrf import CRF

class CharWordEmbLSTMCRF(nn.Module):
    def __init__(self, input_dim, embedding_dim,                  
                 hidden_dim, output_dim, 
                 dropout,
                 text_pad_idx,

                 char_emb_dim,
                 char_input_dim,
                 char_cnn_filter_num,
                 char_cnn_kernel_size,
                 cnn_dropout,
                 char_pad_idx,
                 tag_pad_idx
                 ):
        super().__init__()        
        self.embedding = nn.Embedding(input_dim, embedding_dim, padding_idx = text_pad_idx)        
        self.char_emb_dim = char_emb_dim
        self.char_emb = nn.Embedding(char_input_dim, char_emb_dim, padding_idx=char_pad_idx)
        
        self.my_pad = nn.ConstantPad1d(1, char_pad_idx)
        self.char_cnn = nn.Conv1d(
            char_emb_dim,
            out_channels=char_emb_dim * char_cnn_filter_num,
            kernel_size=char_cnn_kernel_size,
            groups=char_emb_dim  # different 1d conv for each embedding dim
        )
        
        self.dropout = nn.Dropout(dropout)
        self.cnn_dropout = nn.Dropout(cnn_dropout)
        self.rnn = nn.LSTM(embedding_dim + (char_emb_dim * char_cnn_filter_num), hidden_dim)
        self.fc = nn.Linear(hidden_dim, output_dim)
        # new adding part
        self.tag_pad_idx = tag_pad_idx
        self.crf = CRF(num_tags=output_dim, batch_first = True)
        
    def forward(self, words, chars, tags=None):
#         print(tags)
        embedded = self.embedding(words)
        embedded = self.dropout(embedded) # [batch size, seq len, embed dim]
        chars = self.my_pad(chars)
        char_emb = self.dropout(self.char_emb(chars)) # [batch size, sent len, word len, embed dim]
        batch_size, sent_len, word_len, char_emb_dim = char_emb.shape
        char_cnn_max_out = torch.zeros(batch_size, sent_len, self.char_cnn.out_channels).to(device)
        for sent_i in range(sent_len):
            sent_char_emb = char_emb[:, sent_i, :, :]  # get the character field of sent i            
            sent_char_emb_p = sent_char_emb.permute(0, 2, 1)  # the channel (char emb dim) has to be the last dimension
            char_cnn_sent_out = self.char_cnn(sent_char_emb_p)
            char_cnn_max_out[:, sent_i, :], _ = torch.max(char_cnn_sent_out, dim=2)  # max pooling over the word length dimension
        char_cnn = self.cnn_dropout(char_cnn_max_out)
        word_features = torch.cat((embedded, char_cnn), dim=2)
        output, hidden = self.rnn(word_features) 
        fc_out = self.fc(self.dropout(output))
        print(fc_out)
        if tags is not None:
            print(tags)
            mask = tags != self.tag_pad_idx
            crf_out = self.crf.decode(fc_out, mask=mask)
            crf_loss = -self.crf(fc_out, tags=tags, mask=mask)
        else:
            crf_out = self.crf.decode(fc_out)
            crf_loss = None
        return crf_out, crf_loss


    ### BEGIN MODIFIED SECTION: CRF OUTPUT ###
    def init_crf_transitions(self, tag_names, imp_value=-100):
        num_tags = len(tag_names)
        for i in range(num_tags):
            tag_name = tag_names[i]
            if tag_name[0] in ("I") or tag_name == "<pad>":
                torch.nn.init.constant_(self.crf.start_transitions[i], imp_value)
            if tag_name[0] in ("B", "I"):
                torch.nn.init.constant_(self.crf.end_transitions[i], imp_value)
        tag_is = {}
        for tag_position in ("B", "I", "O"):
            tag_is[tag_position] = [i for i, tag in enumerate(tag_names) if tag[0] == tag_position]
        
        impossible_transitions_position = {
            #"B": "BOUP",
            #"I": "BOUP",
            "O": "I",        
        }
        for from_tag, to_tag_list in impossible_transitions_position.items():
            to_tags = list(to_tag_list)
            for from_tag_i in tag_is[from_tag]:
                for to_tag in to_tags:
                    for to_tag_i in tag_is[to_tag]:
                        torch.nn.init.constant_(
                            self.crf.transitions[from_tag_i, to_tag_i], imp_value
                        )
        impossible_transitions_tags = {
            "B": "I",
            "I": "I"
        }
        for from_tag, to_tag_list in impossible_transitions_tags.items():
            to_tags = list(to_tag_list)
            for from_tag_i in tag_is[from_tag]:
                for to_tag in to_tags:
                    for to_tag_i in tag_is[to_tag]:
                        if tag_names[from_tag_i].split("-")[1] != tag_names[to_tag_i].split("-")[1]:
                            torch.nn.init.constant_(
                                self.crf.transitions[from_tag_i, to_tag_i], imp_value
                            )
    ### END MODIFIED SECTION: CRF OUTPUT ###

In [195]:
INPUT_DIM = len(TEXT.vocab)
EMBEDDING_DIM = 300
HIDDEN_DIM = 256
OUTPUT_DIM = len(PHI_TAGS.vocab)
DROPOUT = 0

CHAR_EMBEDDING_DIM = 100
CHAR_INPUT_DIM = len(CHAR.vocab)
FILTER_NUM = 5
KERNEL_SIZE = 2

TEXT_PAD_IDX = TEXT.vocab.stoi[TEXT.pad_token]
CHAR_PAD_IDX = CHAR.vocab.stoi[CHAR.pad_token]
TAG_PAD_IDX = PHI_TAGS.vocab.stoi[PHI_TAGS.pad_token]

model = CharWordEmbLSTMCRF(
    INPUT_DIM, EMBEDDING_DIM, HIDDEN_DIM,
    OUTPUT_DIM, DROPOUT, TEXT_PAD_IDX,
    CHAR_EMBEDDING_DIM,
    CHAR_INPUT_DIM,
    FILTER_NUM,
    KERNEL_SIZE,
    DROPOUT,    
    CHAR_PAD_IDX,
    TAG_PAD_IDX
)

model

CharWordEmbLSTMCRF(
  (embedding): Embedding(3124, 300, padding_idx=1)
  (char_emb): Embedding(1182, 100, padding_idx=1)
  (my_pad): ConstantPad1d(padding=(1, 1), value=1)
  (char_cnn): Conv1d(100, 500, kernel_size=(2,), stride=(1,), groups=100)
  (dropout): Dropout(p=0, inplace=False)
  (cnn_dropout): Dropout(p=0, inplace=False)
  (rnn): LSTM(800, 256)
  (fc): Linear(in_features=256, out_features=21, bias=True)
  (crf): CRF(num_tags=21)
)

In [196]:
# CRF transitions initialization for impossible transitions
model.init_crf_transitions(tag_names=PHI_TAGS.vocab.itos)
# print_crf_transitions(PHI_TAGS.vocab, model.crf)  

In [197]:
def count_parameters(model):
    return sum(p.numel() for p in model.parameters() if p.requires_grad)

print(f'The model has {count_parameters(model):,} trainable parameters')

The model has 2,146,172 trainable parameters


In [198]:
import torch.optim as optim
optimizer = optim.Adam(model.parameters())
model = model.to(device)

In [199]:
from seqeval.metrics import f1_score

def calculate_f_score(preds, ys, tag_pad_idx):
    # ys' shape: [batch size, max length of the seq]
    y_pred = []
    for y in preds:
        s = []        
        for p in y:
            if p != tag_pad_idx:
                s.append(PHI_TAGS.vocab.itos[p])
            else:
                s.append('O')
        y_pred.append(s)
    
    y_gold = []
    for y in ys:
        s = []        
        for g in y:
            if g != tag_pad_idx:
                s.append(PHI_TAGS.vocab.itos[g])
            #else:
            #    s.append('O')
        y_gold.append(s)

    f_score = f1_score(y_gold, y_pred, zero_division = 0)
    return f_score

In [200]:
def train(model, iterator, optimizer, tag_pad_idx):    
    epoch_loss = 0
    epoch_f_score = 0    
    model.train()
    
    for batch in iterator:
        
        text = batch.text
        tags = batch.labels
        char = batch.char
        
        optimizer.zero_grad()
        
        #text = [batch size, sent len]
        #char = [batch size, sent len, word len]
        
        ### BEGIN MODIFIED SECTION: CRF OUTPUT ###
        pred_tags_list, batch_loss = model(text, char, tags)
        
        f_score = calculate_f_score(pred_tags_list, tags, tag_pad_idx)
        
        batch_loss.backward()        
        optimizer.step()
        
        epoch_loss += batch_loss.item()
        epoch_f_score += f_score
        break
        
    return epoch_loss / len(iterator), epoch_f_score / len(iterator)

In [201]:
def evaluate(model, iterator, tag_pad_idx):    
    epoch_loss = 0
    epoch_f_score = 0    
    model.eval()
    
    with torch.no_grad():    
        for batch in iterator:
            text = batch.text
            tags = batch.labels
            char = batch.char

            # MODIFIED PARTS
            predictions, batch_loss = model(text, char, tags)
            # END

            f_score = calculate_f_score(predictions, tags, tag_pad_idx)
            
            epoch_loss += batch_loss.item()
            epoch_f_score += f_score
        
    return epoch_loss / len(iterator), epoch_f_score / len(iterator)

In [202]:
def epoch_time(start_time, end_time):
    elapsed_time = end_time - start_time
    elapsed_mins = int(elapsed_time / 60)
    elapsed_secs = int(elapsed_time - (elapsed_mins * 60))
    return elapsed_mins, elapsed_secs

In [203]:
def save_checkpoint(epoch, model, optimizer, val_f1, word_map, char_map, tag_map, lm_vocab_size, is_best, is_best_loss, fname):
    state = {'epoch': epoch,
             'f1': val_f1,
             'model': model,
             'optimizer': optimizer,
             'word_map': word_map,
             'glove_word_map': glove_word_map,
             'tag_map': tag_map,
             'char_map': char_map,
             'lm_vocab_size': lm_vocab_size,
             'epochs_since_improvement': epochs_since_improvement,
             'elmo_embedder': elmo_embedder,
             'train_loader': train_loader,
             'umls_map': umls_map,
             'pos_map': pos_map,
             'val_loader' : val_loader,
             'vb_decoder' : vb_decoder,
             'crf_criterion' : crf_criterion,
             'lm_criterion' : lm_criterion,
             'start_epoch' : epoch,
             'rev_tag_map': rev_tag_map}
    
    if not os.path.exists(OUTPUT_FOLDER):
        os.mkdir(OUTPUT_FOLDER)
    filename = os.path.join(OUTPUT_FOLDER, f'{fname}.tar')
    if epoch % SAVE_RATE == 0 or epoch == epochs:
        print('Saving checkpoint...')
        torch.save(state, filename)
    # If checkpoint is the best so far, create a copy to avoid being overwritten by a subsequent worse checkpoint
    if is_best:
        print('Saving best F-score checkpoint...')
        filename = os.path.join(OUTPUT_FOLDER, f'BEST_{fname}.tar')
        torch.save(state, filename)
    elif is_best_loss:
        print('Saving best loss checkpoint...')
        filename = os.path.join(OUTPUT_FOLDER, f'BEST_L_{fname}.tar')
        torch.save(state, filename)

In [204]:
import time

N_EPOCHS = 1

best_valid_loss = float('inf')


BATCH_SIZE = 16

train_iterator, valid_iterator, test_iterator = data.BucketIterator.splits(
    (train_data, valid_data, test_data), 
    batch_size = BATCH_SIZE, device = device)

for epoch in range(N_EPOCHS):

    start_time = time.time()
    
    train_loss, train_f = train(model, train_iterator, optimizer, TAG_PAD_IDX)
#     valid_loss, valid_f = evaluate(model, valid_iterator, TAG_PAD_IDX)
    
    end_time = time.time()

    epoch_mins, epoch_secs = epoch_time(start_time, end_time)
    
#     if valid_loss < best_valid_loss:
#         best_valid_loss = valid_loss
#         torch.save(model.state_dict(), 'ai_cup-model.pt')
    
    print(f'Epoch: {epoch+1:02} | Epoch Time: {epoch_mins}m {epoch_secs}s')
    print(f'\tTrain Loss: {train_loss:.3f} | Train F-score: {train_f*100:.2f}%')
#     print(f'\t Val. Loss: {valid_loss:.3f} |  Val. F-score: {valid_f*100:.2f}%')

tensor([[[ 0.0221,  0.0516,  0.1097,  ...,  0.0451, -0.0285,  0.0193],
         [ 0.0267,  0.0340, -0.0666,  ...,  0.1129, -0.1310,  0.1072],
         [-0.0522,  0.0818,  0.0996,  ...,  0.0932, -0.1045,  0.1519],
         ...,
         [-0.0194, -0.0088,  0.0496,  ...,  0.0785,  0.0221, -0.0118],
         [-0.0194, -0.0088,  0.0496,  ...,  0.0785,  0.0221, -0.0118],
         [-0.0194, -0.0088,  0.0496,  ...,  0.0785,  0.0221, -0.0118]],

        [[-0.0528,  0.0090,  0.1162,  ..., -0.0009,  0.1785,  0.0559],
         [ 0.0043,  0.0340, -0.0952,  ...,  0.1467, -0.1867,  0.1472],
         [-0.1470,  0.0040,  0.0361,  ...,  0.1401, -0.0278,  0.0942],
         ...,
         [-0.0550, -0.0179,  0.0786,  ...,  0.0861,  0.0206, -0.0234],
         [-0.0550, -0.0179,  0.0786,  ...,  0.0861,  0.0206, -0.0234],
         [-0.0550, -0.0179,  0.0786,  ...,  0.0861,  0.0206, -0.0234]],

        [[-0.0545, -0.0122,  0.1040,  ..., -0.0345,  0.2568,  0.0940],
         [-0.0101,  0.0334, -0.1111,  ...,  0

In [35]:
def save_model(epoch, model, fname):
    state = {'epoch': epoch, 'model': model}
    
    filename = f'{fname}.tar'
    torch.save(state, filename)
save_model(20, model, 'e20')    

In [134]:
a = []
for i in PHI_TAGS.vocab.itos:
    if i != 'O' and i != '<pad>' and i[2:] not in a:
        a.append(i[2:])
    elif i == 'O' or i == '<pad>':
        a.append(i)
a

['<pad>',
 'O',
 'time',
 'med_exam',
 'name',
 'location',
 'money',
 'family',
 'contact',
 'profession',
 'ID',
 'clinical_event',
 'education',
 'organization']

In [143]:
from torchcrf import CRF
### BEGIN MODIFIED SECTION: CRF OUTPUT ###
def init_crf_transitions(tag_names, imp_value=-100):
    crf = CRF(num_tags=len(PHI_TAGS.vocab))
    num_tags = len(tag_names)
    for i in range(num_tags):
        tag_name = tag_names[i]
        if tag_name[0] in ("I") or tag_name == "<pad>":
            torch.nn.init.constant_(crf.start_transitions[i], imp_value)
    tag_is = {}
    for tag_position in ("B", "I", "O"):
        tag_is[tag_position] = [i for i, tag in enumerate(tag_names) if tag[0] == tag_position]
    
    impossible_transitions_position = {
        "O": "I",        
    }
    for from_tag, to_tag_list in impossible_transitions_position.items():
        to_tags = list(to_tag_list)
        for from_tag_i in tag_is[from_tag]:
            for to_tag in to_tags:
                for to_tag_i in tag_is[to_tag]:
                    torch.nn.init.constant_(
                        crf.transitions[from_tag_i, to_tag_i], imp_value
                    )
    # init impossible B and I transitions to different entity types
    impossible_transitions_tags = {
        "B": "I",
        "I": "I"
    }
    for from_tag, to_tag_list in impossible_transitions_tags.items():
        to_tags = list(to_tag_list)
        for from_tag_i in tag_is[from_tag]:
            for to_tag in to_tags:
                for to_tag_i in tag_is[to_tag]:
                    if tag_names[from_tag_i].split("-")[1] != tag_names[to_tag_i].split("-")[1]:
                        torch.nn.init.constant_(
                            crf.transitions[from_tag_i, to_tag_i], imp_value
                        )
    return crf

def print_crf_transitions(tag_vocab, crf):
#     tags = tag_vocab.itos
#     max_len_tag = max([len(tag) for tag in tags])
#     print("Start and end tag transitions:")
#     print(f"{'TAG'.ljust(max_len_tag)}\tSTART\tEND")
#     for tag, start_prob, end_prob in zip(tags, crf.start_transitions.tolist(), crf.end_transitions.tolist()):
#         print(f"{tag.ljust(max_len_tag)}\t{round(start_prob, 2)}\t{round(end_prob, 2)}")
    print("\nBetween tags transitions:")
    print(tag_vocab.itos)
    persons_i = [i for i, tag in enumerate(tag_vocab.itos) if "time" in tag or 'money' in tag or 'family' in tag or 'contact' in tag or 'profession' in tag or 'ID' in tag\
                 or 'clinical_event' in tag or 'education' in tag or 'organization' in tag or "med_exam" in tag or "name" in tag or 'location' in tag or tag == "O"]
    print(persons_i)
    max_len_tag = max([len(tag) for tag in tag_vocab.itos])
    print(max_len_tag)
    print(tag_vocab.itos[max_len_tag])
#     transitions = crf.transitions
#     to_tags = "TO".rjust(max_len_tag) + "\t" + "\t".join([tag.ljust(max_len_tag) for tag in tags if "time" in tag or 'money' in tag or 'family' in tag or 'contact' in tag or 'profession' in tag or 'ID' in tag\
#                  or 'clinical_event' in tag or 'education' in tag or 'organization' in tag or "med_exam" in tag or "name" in tag or 'location' in tag or tag == "O"])
#     print(to_tags)
#     print("FROM")
#     for from_tag_i, from_tag_probs in enumerate(transitions[persons_i]):
#         to_tag_str = f"{tags[persons_i[from_tag_i]].ljust(max_len_tag)}"
#         for to_tag_prob in from_tag_probs[persons_i]:
#             to_tag_str += f"\t{str(round(to_tag_prob.item(), 2)).ljust(max_len_tag)}"
#         print(to_tag_str)


# crf = CRF(num_tags=len(PHI_TAGS.vocab))
# print_crf_transitions(PHI_TAGS.vocab, crf)  

crf = init_crf_transitions(tag_names=PHI_TAGS.vocab.itos)
print_crf_transitions(PHI_TAGS.vocab, crf)  


Between tags transitions:
['<pad>', 'O', 'I-time', 'B-time', 'B-med_exam', 'B-name', 'B-location', 'I-name', 'B-money', 'I-location', 'I-money', 'B-family', 'I-med_exam', 'B-contact', 'B-profession', 'B-ID', 'B-clinical_event', 'I-ID', 'B-education', 'I-family', 'B-organization']
[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20]
16
B-clinical_event


In [23]:
def tag_sentence(model, device, sentence, text_field, tag_field):
    model.eval()
    
    tokens = [token for token in sentence.split()]
        
    numericalized_tokens = [text_field.vocab.stoi[t] for t in tokens]

    numericalized_chars = []
    char_pad_id = CHAR_PAD_IDX
    
    max_word_len = max([len(token) for token in tokens])

    for token in tokens:
        numericalized_chars.append(
            [CHAR.vocab.stoi[char] for char in token] + [char_pad_id for _ in range(max_word_len - len(token))]
            )

    token_tensor = torch.LongTensor(numericalized_tokens)
    char_tensor = torch.LongTensor(numericalized_chars)
    token_tensor = token_tensor.unsqueeze(0).to(device)
    char_tensor = char_tensor.unsqueeze(0).to(device)
    
    ## MODIFIED PARTS
    # ignore the loss
    predictions, _ = model(token_tensor, char_tensor)    
    predicted_tags = [tag_field.vocab.itos[t] for t in predictions[0]]
    
    return tokens, predicted_tags

In [None]:
sent = '四月 跟 五月 那 之間 ， 因為 他 開刀 也 都 是 在 高雄 開 。'
#t, p = 
for t, p in zip(*tag_sentence(model, device, sent, TEXT, PHI_TAGS)):
    print(f'{t}\t{p}')

四月	B-time
跟	O
五月	B-time
那	O
之間	O
，	O
因為	O
他	O
開刀	O
也	O
都	O
是	O
在	O
高雄	B-location
開	O
。	O


In [None]:
sent = '1月 跟 五月 那 之間 ， 因為 他 開刀 也 都 是 在 高醫 開 。'
#t, p = 
for t, p in zip(*tag_sentence(model, device, sent, TEXT, PHI_TAGS)):
    print(f'{t}\t{p}')

1月	B-time
跟	O
五月	B-time
那	O
之間	O
，	O
因為	O
他	O
開刀	O
也	O
都	O
是	O
在	O
高醫	O
開	O
。	O


In [None]:
sent = '民眾 ： 身高 170 。 醫師 ： 170 這樣 ， 體重 有 增加 嗎 ?'
#t, p = 
for t, p in zip(*tag_sentence(model, device, sent, TEXT, PHI_TAGS)):
    print(f'{t}\t{p}')

民眾	O
：	O
身高	O
170	B-med_exam
。	O
醫師	O
：	O
170	B-med_exam
這樣	O
，	O
體重	O
有	O
增加	O
嗎	O
?	O
