In [1]:
import json
import torch
from pandas import read_parquet
from transformers import BertModel, BertTokenizerFast
from torch.utils.data import Dataset, DataLoader

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Directory of the fine-tuned multilingual BERT checkpoint; the fast tokenizer
# is loaded from it here.
model_path = "data/mBERT/fine"
tokenizer = BertTokenizerFast.from_pretrained(pretrained_model_name_or_path=model_path)

In [4]:
# Merged NER splits (read below via their "tokens" / "ner_tags" columns) plus
# the tag and character vocabularies built for them.
MERGE_DIR = "data/merge"

train_data = read_parquet(f"{MERGE_DIR}/train.parquet")
dev_data = read_parquet(f"{MERGE_DIR}/dev.parquet")
test_data = read_parquet(f"{MERGE_DIR}/test.parquet")


def _load_json(path):
    """Read and parse a JSON file, always decoding it as UTF-8."""
    # JSON text is UTF-8 by spec (RFC 8259). Without an explicit encoding,
    # open() uses the locale default (e.g. cp1252/GBK on Windows), which can
    # fail on non-ASCII entries such as those of a multilingual char vocabulary.
    with open(path, "r", encoding="utf-8") as f:
        return json.load(f)


tags2idx = _load_json(f"{MERGE_DIR}/tags_2_idx.json")
idx2tags = _load_json(f"{MERGE_DIR}/idx_2_tags.json")
chars2idx = _load_json(f"{MERGE_DIR}/chars2idx.json")

In [5]:
def _split_columns(frame):
    """Return a split's ``tokens`` and ``ner_tags`` columns as Python lists (one entry per row)."""
    return frame["tokens"].values.tolist(), frame["ner_tags"].values.tolist()


sentences_train, tags_train = _split_columns(train_data)
sentences_dev, tags_dev = _split_columns(dev_data)
sentences_test, tags_test = _split_columns(test_data)

In [58]:
# Number of character slots reserved per token in the character-id matrices.
MAX_WORD_LEN = 30

# Run on the GPU when torch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

In [7]:
def align_label(tokenized_input, tags, tags_2_idx, idx_2_tags, label_all_tokens=True):
    """Expand word-level NER tags to one label per subword token.

    Args:
        tokenized_input: tokenizer output for a single pre-split sentence; only
            its ``word_ids()`` mapping (token position -> source word index, or
            ``None`` for tokens not tied to a word) is used.
        tags: per-word tag ids from the dataset, indexed by word index.
        tags_2_idx, idx_2_tags: currently unused; kept for call compatibility.
        label_all_tokens: if True, every subword of a word gets that word's tag;
            if False, only the first subword of each word is labelled and the
            remaining subwords get -100.

    Returns:
        A list with one label per token. Tokens with no word (``None``) and
        tokens whose word index has no entry in ``tags`` get -100 (presumably
        the loss ignore-index, as in torch's CrossEntropyLoss default --
        confirm against the training loss).
    """
    word_ids = tokenized_input.word_ids()
    num_tags = len(tags)
    label_ids = []
    previous_word_idx = None
    for word_idx in word_ids:
        if word_idx is None or word_idx >= num_tags:
            # No word, or a word missing from `tags`. This is checked for every
            # subword so an untagged word never raises on its continuation pieces.
            label_ids.append(-100)
        elif word_idx != previous_word_idx or label_all_tokens:
            label_ids.append(tags[word_idx])
        else:
            label_ids.append(-100)
        previous_word_idx = word_idx
    return label_ids

def _tokenize_and_align(words, word_tags, max_length=512):
    """Tokenize one pre-split sentence and align its word-level tags to its subword tokens.

    Args:
        words: the sentence's words; ``.tolist()`` is called on it, so presumably
            a numpy array as produced by the parquet columns -- confirm.
        word_tags: per-word tag ids for the sentence.
        max_length: length every sequence is padded/truncated to.

    Returns:
        (encoding, label_ids): the tokenizer output (PyTorch tensors, padded to
        ``max_length``) and the per-token labels produced by ``align_label``.
    """
    encoding = tokenizer(
        words.tolist(),
        padding="max_length",
        max_length=max_length,
        truncation=True,
        return_tensors="pt",
        is_split_into_words=True,
    )
    label_ids = align_label(encoding, word_tags, tags2idx, idx2tags)
    return encoding, label_ids


def generate_tokenized_input(sentences_raw, tags_raw, max_length=512):
    """Tokenize every sentence and expand its word tags to subword labels.

    Args:
        sentences_raw: sequence of pre-split sentences.
        tags_raw: sequence of per-word tag-id lists, parallel to ``sentences_raw``.
        max_length: length every sequence is padded/truncated to.

    Returns:
        (sentences, tags): lists holding one tokenizer output and one aligned
        label list per input sentence.

    Raises:
        ValueError: if ``sentences_raw`` and ``tags_raw`` differ in length.
    """
    if len(sentences_raw) != len(tags_raw):
        raise ValueError(
            f"got {len(sentences_raw)} sentences but {len(tags_raw)} tag sequences"
        )
    sentences = []
    tags = []
    for words, word_tags in zip(sentences_raw, tags_raw):
        encoding, label_ids = _tokenize_and_align(words, word_tags, max_length)
        sentences.append(encoding)
        tags.append(label_ids)
    return sentences, tags


def generate_tokenized_input_with_words(sentences_raw, tags_raw, max_length=512):
    """Like ``generate_tokenized_input``, but also return subword strings and character ids.

    Returns:
        (sentences, tags, words, chars):
        - ``sentences`` / ``tags``: exactly what ``generate_tokenized_input`` returns.
        - ``words[k]``: the ``max_length`` token strings of sentence k, including
          the special and padding tokens the tokenizer inserted.
        - ``chars[k]``: a ``[max_length, MAX_WORD_LEN]`` tensor whose row t holds
          the ``chars2idx`` ids of token t's characters. Unknown characters map
          to ``chars2idx['<unk>']``, unused slots stay 0, and characters past
          ``MAX_WORD_LEN`` are dropped instead of overflowing the row.

    NOTE(review): the ``chars`` tensors are float32 (the ``torch.zeros`` default);
    cast them with ``.long()`` before using them as embedding indices.
    """
    sentences, tags = generate_tokenized_input(sentences_raw, tags_raw, max_length)
    unk_id = chars2idx['<unk>']
    words = []
    chars = []
    for encoding in sentences:
        token_words = tokenizer.convert_ids_to_tokens(encoding["input_ids"][0])
        words.append(token_words)
        char_ids = torch.zeros(max_length, MAX_WORD_LEN)
        for tok_pos, token in enumerate(token_words):
            for char_pos, ch in enumerate(token[:MAX_WORD_LEN]):
                char_ids[tok_pos, char_pos] = chars2idx.get(ch, unk_id)
        chars.append(char_ids)
    return sentences, tags, words, chars

In [8]:
# Tokenize each split and expand its word-level tags to subword labels,
# in train -> dev -> test order.
(train_sentences, train_tags), (dev_sentences, dev_tags), (test_sentences, test_tags) = [
    generate_tokenized_input(split_sentences, split_tags)
    for split_sentences, split_tags in (
        (sentences_train, tags_train),
        (sentences_dev, tags_dev),
        (sentences_test, tags_test),
    )
]

In [None]:
# Scratch check (disabled): builds the character-id matrix for the first
# training sentence by hand, mirroring the per-sentence character loop in
# generate_tokenized_input_with_words. Uncomment to inspect.
# NOTE(review): as written, a token longer than MAX_WORD_LEN would index past
# the width of char_ids and raise IndexError.
# token_ids = train_sentences[0]["input_ids"][0]
# token_words = tokenizer.convert_ids_to_tokens(token_ids)
# print(token_words)
# char_ids = torch.zeros(512, MAX_WORD_LEN)
# for i in range(len(token_words)):
#     for j in range(len(token_words[i])):
#         char_ids[i][j] = chars2idx.get(token_words[i][j], chars2idx['<unk>'])

In [None]:
# Re-tokenize the training split, this time also keeping each sentence's
# subword strings and character-id matrices.
(
    train_sentences1,
    train_tags1,
    train_words,
    train_chars,
) = generate_tokenized_input_with_words(sentences_raw=sentences_train, tags_raw=tags_train)

In [12]:
# Load the encoder from the same checkpoint directory as the tokenizer.
# Reusing model_path keeps the two in sync: a separately typed literal is easy
# to get wrong ("mBert" vs "mBERT"), and that breaks on case-sensitive filesystems.
bert = BertModel.from_pretrained(model_path).to(device)

  return self.fget.__get__(instance, owner)()


In [55]:
class MultilingualDataset(Dataset):
    """Map-style dataset pairing stored tokenizer outputs with their label-id lists.

    Each item is ``(encoding, labels)``: ``encoding`` is whatever was stored at
    that index, and ``labels`` is a LongTensor with a leading axis of size 1
    (e.g. a 512-long label list becomes shape [1, 512]), so items can be
    concatenated along dim 0 when batched.
    """

    def __init__(self, sentences, labels):
        self.sentences, self.labels = sentences, labels

    def __len__(self):
        # Dataset size is taken from the label list.
        return len(self.labels)

    def get_text_tokenized(self, idx):
        """Return the stored tokenizer output for sample ``idx``."""
        return self.sentences[idx]

    def get_labels(self, idx):
        """Return sample ``idx``'s label ids as a 1-D LongTensor."""
        return torch.LongTensor(self.labels[idx])

    def __getitem__(self, idx):
        # Prepend a batch axis to the labels: [seq_len] -> [1, seq_len].
        return self.get_text_tokenized(idx), self.get_labels(idx).unsqueeze(0)
    
def collate_fn(batch):
    """Stack a batch of dataset items and run it through the global ``bert`` model.

    Args:
        batch: list of ``(encoding, labels)`` pairs as returned by
            ``MultilingualDataset.__getitem__``. Each encoding is presumably a
            single-sentence tokenizer output whose ``input_ids`` and
            ``attention_mask`` have a leading axis of 1, and each ``labels`` is
            shaped [1, seq_len].

    Returns:
        ``(bert_embeddings, labels)``: BERT's ``last_hidden_state`` for the batch
        and the labels concatenated along dim 0. Both are on the device that
        holds ``bert``.

    Notes:
        - The inputs are moved to the model's device. Tokenizer outputs are CPU
          tensors, so without this, a model placed on CUDA raises a device-mismatch
          error.
        - BERT runs with autograd enabled here, so the embeddings carry a graph
          back into ``bert``.
        - NOTE(review): running a CUDA model inside collate_fn is only safe with
          ``num_workers=0``.
    """
    encodings, label_seqs = zip(*batch)
    model_device = next(bert.parameters()).device
    input_ids = torch.cat([enc["input_ids"] for enc in encodings]).to(model_device)
    attention_mask = torch.cat([enc["attention_mask"] for enc in encodings]).to(model_device)
    labels = torch.cat(label_seqs).to(model_device)
    bert_output = bert(input_ids, attention_mask=attention_mask)
    return bert_output["last_hidden_state"], labels

In [48]:
# Wrap each split's tokenizer outputs and aligned labels as a torch Dataset.
train_dataset = MultilingualDataset(sentences=train_sentences, labels=train_tags)
dev_dataset = MultilingualDataset(sentences=dev_sentences, labels=dev_tags)
test_dataset = MultilingualDataset(sentences=test_sentences, labels=test_tags)

In [14]:
# Smoke test: one forward pass on the first training sentence. The inputs are
# moved to `device`, where the model was placed, so this also works on a GPU.
# Copies are made with .to(); the stored encoding itself is not modified.
first_encoding = train_sentences[0]
output = bert(
    first_encoding["input_ids"].to(device),
    attention_mask=first_encoding["attention_mask"].to(device),
)

In [57]:
BATCH_SIZE = 4
SEED = 42  # fixes the shuffle order so the inspected batch is reproducible

train_loader = DataLoader(
    dataset=train_dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
    collate_fn=collate_fn,
    generator=torch.Generator().manual_seed(SEED),
)

# Peek at a single batch to check the shapes produced by collate_fn.
for data in train_loader:
    bert_embeddings, labels = data
    print(bert_embeddings.shape, labels.shape)
    break

torch.Size([16, 512, 768]) torch.Size([16, 512])
