In [10]:
import os
os.environ["CUDA_DEVICE_ORDER"]="PCI_BUS_ID"   # see issue #152
os.environ["CUDA_VISIBLE_DEVICES"]="1"

In [11]:
import torch
import torch.nn as nn
import transformers
from transformers import AdamW, get_linear_schedule_with_warmup, get_constant_schedule
import pandas as pd
import numpy as np
from tqdm import tqdm
import os
import re
from collections import defaultdict
from collections import Counter
from sklearn.model_selection import KFold

In [12]:
MAX_LEN = 500
TRAIN_BATCH_SIZE = 40
VALID_BATCH_SIZE = 40
EPOCHS = 100
BERT_MODEL = 'xlm-roberta-base'
#BERT_MODEL = 'klue/roberta-base'
TOKENIZER = transformers.AutoTokenizer.from_pretrained(BERT_MODEL, use_fast=True)

DATASET = "ETRI_syl" #MODU21 #ETRI # MODU19_syl #MODU21_syl #ETRI_syl
VALID_FILE = None
if DATASET == "NAVER":
    #TRAIN_FILE = "./data/bio_group1_test.conllu"
    #VALID_FILE = "./data/bio_group1_test.conllu"
    TRAIN_FILE = "./data/train_data.bio.converted"
elif DATASET == "KLUE":
    TRAIN_FILE = "./data/klue-ner-v1.1_train.converted"
    VALID_FILE = "./data/klue-ner-v1.1_dev.converted"
elif DATASET == "MODU19":
    TRAIN_FILE = "./data/NXNE2102008030.converted" #MODU19
elif DATASET == "MODU21":
    TRAIN_FILE = "./data/NXNE2102203310.converted" #MODU21
elif DATASET == "ETRI":
    TRAIN_FILE = "./data/EXOBRAIN_NE_CORPUS_10000.converted" #ETRI15
elif DATASET == "MODU19_syl":
    TRAIN_FILE = "./data/NXNE2102008030.corrected" #MODU19
elif DATASET == "MODU21_syl":
    TRAIN_FILE = "./data/NXNE2102203310.corrected" #MODU21
elif DATASET == "ETRI_syl":
    TRAIN_FILE = "./data/EXOBRAIN_NE_CORPUS_10000.corrected_conll" #ETRI15
else:
    print("Please set your DATASET")

In [13]:
DEVICE=0
ID, FORM, LEMMA, UPOS, XPOS, FEATS, HEAD, DEPREL, DEPS, MISC = range(10)

In [14]:
def normalize(word):
    return re.sub(r"\d", "0", word).lower()


def strong_normalize(word):
    w = ftfy.fix_text(word.lower())
    w = re.sub(r".+@.+", "*EMAIL*", w)
    w = re.sub(r"@\w+", "*AT*", w)
    w = re.sub(r"(https?://|www\.).*", "*url*", w)
    w = re.sub(r"([^\d])\1{2,}", r"\1\1", w)
    w = re.sub(r"([^\d][^\d])\1{2,}", r"\1\1", w)
    w = re.sub(r"``", '"', w)
    w = re.sub(r"''", '"', w)
    w = re.sub(r"\d", "0", w)
    return w


def buildVocab(graphs, cutoff=1):
    wordsCount = Counter()
    charsCount = Counter()
    uposCount = Counter()
    xposCount = Counter()
    relCount = Counter()
    featCount = Counter()
    langCount = Counter()

    for graph in graphs:
        wordsCount.update([node.norm for node in graph.nodes[1:]])
        for node in graph.nodes[1:]:
            charsCount.update(list(node.word))
            featCount.update(node.feats_set)
            #  charsCount.update(list(node.norm))
        uposCount.update([node.upos for node in graph.nodes[1:]])
        xposCount.update([node.xupos for node in graph.nodes[1:]])
        relCount.update([rel for rel in graph.rels[1:]])
        langCount.update([node.lang for node in graph.nodes[1:]])
        

    wordsCount = Counter({w: i for w, i in wordsCount.items() if i >= cutoff})
    print("Vocab containing {} words".format(len(wordsCount)))
    print("Charset containing {} chars".format(len(charsCount)))
    print("UPOS containing {} tags".format(len(uposCount)), uposCount)
    #print("XPOS containing {} tags".format(len(xposCount)), xposCount)
    print("Rels containing {} tags".format(len(relCount)), relCount)
    print("Feats containing {} tags".format(len(featCount)), featCount)
    print("lang containing {} tags".format(len(langCount)), langCount)

    ret = {
        "vocab": list(wordsCount.keys()),
        "wordfreq": wordsCount,
        "charset": list(charsCount.keys()),
        "charfreq": charsCount,
        "upos": list(uposCount.keys()),
        "xpos": list(xposCount.keys()),
        "rels": list(relCount.keys()),
        "feats": list(featCount.keys()),
        "lang": list(langCount.keys()),
    }

    return ret

def shuffled_stream(data):
    len_data = len(data)
    while True:
        for d in random.sample(data, len_data):
            yield d

def shuffled_balanced_stream(data):
    for ds in zip(*[shuffled_stream(s) for s in data]):
        ds = list(ds)
        random.shuffle(ds)
        for d in ds:
            yield d
            
            
def parse_dict(features):
    if features is None or features == "_":
        return {}

    ret = {}
    lst = features.split("|")
    for l in lst:
        k, v = l.split("=")
        ret[k] = v
    return ret


def parse_features(features):
    if features is None or features == "_":
        return set()

    return features.lower().split("|")


class Word:

    def __init__(self, word, upos, lemma=None, xpos=None, feats=None, misc=None, lang=None):
        self.word = word
        self.norm = normalize(word) #strong_normalize(word)
        self.lemma = lemma if lemma else "_"
        self.upos = upos
        self.xpos = xpos if xpos else "_"
        self.xupos = self.upos + "|" + self.xpos
        self.feats = feats if feats else "_"
        self.feats_set = parse_features(self.feats)
        self.misc = misc if misc else "_"
        self.lang = lang if lang else "_"

    def cleaned(self):
        return Word(self.word, "_")

    def clone(self):
        return Word(self.word, self.upos, self.lemma, self.xpos, self.feats, self.misc)

    def __repr__(self):
        return "{}_{}".format(self.word, self.upos)


class DependencyGraph(object):

    def __init__(self, words, tokens=None):
        #  Token is a tuple (start, end, form)
        if tokens is None:
            tokens = []
        self.nodes = np.array([Word("*root*", "*root*")] + list(words))
        self.tokens = tokens
        self.heads = np.array([-1] * len(self.nodes))
        self.rels = np.array(["_"] * len(self.nodes), dtype=object)

    def __copy__(self):
        cls = self.__class__
        result = cls.__new__(cls)
        result.nodes = self.nodes
        result.tokens = self.tokens
        result.heads = self.heads.copy()
        result.rels = self.rels.copy()
        return result

    def cleaned(self, node_level=True):
        if node_level:
            return DependencyGraph([node.cleaned() for node in self.nodes[1:]], self.tokens)
        else:
            return DependencyGraph([node.clone() for node in self.nodes[1:]], self.tokens)

    def attach(self, head, tail, rel):
        self.heads[tail] = head
        self.rels[tail] = rel

    def __repr__(self):
        return "\n".join(["{} ->({})  {} ({})".format(str(self.nodes[i]), self.rels[i], self.heads[i], self.nodes[self.heads[i]]) for i in range(len(self.nodes))])


def read_conll(filename, lang_code=None):
    
    print("read_conll with", lang_code)
    def get_word(columns):
        return Word(columns[FORM], columns[UPOS], lemma=columns[LEMMA], xpos=columns[XPOS], feats=columns[FEATS], misc=columns[MISC], lang=lang_code)

    def get_graph(graphs, words, tokens, edges):
        graph = DependencyGraph(words, tokens)
        for (h, d, r) in edges:
            graph.attach(h, d, r)
        graphs.append(graph)

    file = open(filename, "r", encoding="UTF-8")

    graphs = []
    words = []
    tokens = []
    edges = []

    num_sent = 0
    sentence_start = False
    while True:
        line = file.readline()
        if not line:
            if len(words) > 0:
                get_graph(graphs, words, tokens, edges)
                words, tokens, edges = [], [], []
            break
        line = line.rstrip("\r\n")

        # Handle sentence start boundaries
        if not sentence_start:
            # Skip comments
            if line.startswith("#"):
                continue
            # Start a new sentence
            sentence_start = True
        if not line:
            sentence_start = False
            if len(words) > 0:
                if (len(words) < 250):
                    get_graph(graphs, words, tokens, edges)
                words, tokens, edges = [], [], []
                num_sent += 1
            continue

        # Read next token/word
        columns = line.split("\t")

        # Skip empty nodes
        if "." in columns[ID]:
            continue

        # Handle multi-word tokens to save word(s)
        if "-" in columns[ID]:
            start, end = map(int, columns[ID].split("-"))
            tokens.append((start, end + 1, columns[FORM]))

            for _ in range(start, end + 1):
                word_line = file.readline().rstrip("\r\n")
                word_columns = word_line.split("\t")
                words.append(get_word(word_columns))
                if word_columns[HEAD].isdigit():
                    head = int(word_columns[HEAD])
                else:
                    head = -1
                edges.append((head, int(word_columns[ID]), word_columns[DEPREL].split(":")[0]))
        # Basic tokens/words
        else:
            words.append(get_word(columns))
            if columns[HEAD].isdigit():
                head = int(columns[HEAD])
            else:
                head = -1
            edges.append((head, int(columns[ID]), columns[DEPREL].split(":")[0]))

    file.close()

    return graphs

In [15]:
# 2. Data Loader
class CoNLLDataset:
    def __init__(self, graphs, tokenizer, max_len, fullvocab=None):
        self.conll_graphs = graphs
        self.tokenizer = tokenizer
        self.max_len = max_len
        
        self._fullvocab = fullvocab if fullvocab else buildVocab(self.conll_graphs, cutoff=1)
            
        self._upos = {p: i for i, p in enumerate(self._fullvocab["upos"])}
        self._iupos = self._fullvocab["upos"]
        self._xpos = {p: i for i, p in enumerate(self._fullvocab["xpos"])}
        self._ixpos = self._fullvocab["xpos"]
        self._vocab = {w: i+3 for i, w in enumerate(self._fullvocab["vocab"])}
        self._wordfreq = self._fullvocab["wordfreq"]
        self._charset = {c: i+3 for i, c in enumerate(self._fullvocab["charset"])}
        self._charfreq = self._fullvocab["charfreq"]
        self._rels = {r: i for i, r in enumerate(self._fullvocab["rels"])}
        self._irels = self._fullvocab["rels"]
        self._feats = {f: i for i, f in enumerate(self._fullvocab["feats"])}
        self._langs = {r: i+2 for i, r in enumerate(self._fullvocab["lang"])}
        self._ilangs = self._fullvocab["lang"]
        
        #self._posRels = {r: i for i, r in enumerate(self._fullvocab["posRel"])}
        #self._iposRels = self._fullvocab["posRel"]
        
        self._vocab['*pad*'] = 0
        self._charset['*pad*'] = 0
        self._langs['*pad*'] = 0
        
        self._vocab['*root*'] = 1
        self._charset['*whitespace*'] = 1
        
        self._vocab['*unknown*'] = 2
        self._charset['*unknown*'] = 2
        
        
    
    def __len__(self):
        return len(self.conll_graphs)
        
        
    def __getitem__(self, item):
        
        graph = self.conll_graphs[item]
        word_list = [node.word for node in graph.nodes]
        upos_list = [node.upos for node in graph.nodes]
        feat_list = [node.feats for node in graph.nodes]
        
        encoded = self.tokenizer.encode_plus(' '.join(word_list[1:]),
                                             None,
                                             add_special_tokens=True,
                                             max_length = self.max_len,
                                             truncation=True,
                                             pad_to_max_length = True)
        
        ids, mask = encoded['input_ids'], encoded['attention_mask']
        
        bpe_head_mask = [0]; upos_ids = [-1]; feat_ids = [-1] # --> CLS token
        
        for word, upos, feat in zip(word_list[1:], upos_list[1:], feat_list[1:]):
            bpe_len = len(self.tokenizer.tokenize(word))
            head_mask = [1] + [0]*(bpe_len-1)
            bpe_head_mask.extend(head_mask)
            upos_mask = [self._upos.get(upos)] + [-1]*(bpe_len-1)
            upos_ids.extend(upos_mask)
            feat_mask = [self._feats.get(feat.lower(), 2)] + [-1]*(bpe_len-1)
            feat_ids.extend(feat_mask)
            
            #print("head_mask", head_mask)
        
        bpe_head_mask.append(0); upos_ids.append(-1); feat_ids.append(-1) # --> END token
        bpe_head_mask.extend([0] * (self.max_len - len(bpe_head_mask))) ## --> padding by max_len
        upos_ids.extend([-1] * (self.max_len - len(upos_ids))) ## --> padding by max_len
        feat_ids.extend([-1] * (self.max_len - len(feat_ids))) ## --> padding by max_len
        
        return {
                'ids': torch.tensor(ids, dtype=torch.long),
                'mask': torch.tensor(mask, dtype=torch.long),
                'bpe_head_mask': torch.tensor(bpe_head_mask, dtype=torch.long),
                'upos_ids': torch.tensor(upos_ids, dtype=torch.long),
                'feat_ids': torch.tensor(feat_ids, dtype=torch.long)
               }
    
    

  

In [16]:
def f1_score(total_pred, total_targ, noNER_idx):
    
    p = 0 # (retrived SB and real SB) / retrived SB  # The percentage of (the number of correct predictions) / (the number of predction that system predicts as B-SENT)
    r = 0
    f1= 0
    
    np_total_pred = np.array(total_pred)
    np_total_tag = np.array(total_targ)
    
    #Get noPad
    incidence_nopad = np.where(np_total_tag != -1) ## eliminate paddings
    np_total_pred_nopad = np_total_pred[incidence_nopad]
    np_total_tag_nopad = np_total_tag[incidence_nopad]
    
    
    #precision
    incidence_nopad_sb = np.where(np_total_pred_nopad != noNER_idx)
    np_total_pred_nopad_sb = np_total_pred_nopad[incidence_nopad_sb]
    np_total_tag_nopad_sb = np_total_tag_nopad[incidence_nopad_sb]
    
    count_active_tokens_p = len(np_total_pred_nopad_sb)
    count_correct_p = np.count_nonzero((np_total_pred_nopad_sb==np_total_tag_nopad_sb) == True)
    
    '''
    np_total_pred_incid = np_total_pred[incidence_p]
    print("np_total_pred_incid", np_total_pred_incid)
    ids_sb_pred_p = np.where(np_total_pred_incid==1)
    np_total_pred_p = np_total_pred_incid[ids_sb_pred_p]
    np_total_tag_p = np_total_tag[ids_sb_pred_p]
    
    print("ids_sb_pred_p", ids_sb_pred_p)
    print("np_total_pred_p", np_total_pred_p)
    print("np_total_tag_p", np_total_tag_p)
    
    count_active_tokens_p = len(np_total_pred_p)
    count_correct_p = np.count_nonzero((np_total_pred_p==np_total_tag_p) == True)
    '''
    
    print("count_correct_p", count_correct_p)
    print("count_active_tokens_p", count_active_tokens_p)
    
    p = count_correct_p/count_active_tokens_p
    print("precision:", p)

    
    #recall
    ids_sb_pred_r = np.where(np_total_tag_nopad != noNER_idx)
    np_total_pred_r = np_total_pred_nopad[ids_sb_pred_r]
    np_total_tag_r = np_total_tag_nopad[ids_sb_pred_r]
    
    #print("ids_sb_pred_r", ids_sb_pred_r)
    #print("np_total_pred_r", np_total_pred_r)
    #print("np_total_tag_r", np_total_tag_r)
    
    count_active_tokens_r = len(np_total_pred_r)
    count_correct_r = np.count_nonzero((np_total_pred_r==np_total_tag_r) == True)
    
    print("count_active_tokens_r", count_active_tokens_r)
    print("count_correct_r", count_correct_r)
    
    r = count_correct_r/count_active_tokens_r
    print("recall:", r)
    
    
    #F1
    #f1 = 2*(p*r) / (p+r)
    print("F1:", f1)
    
    #count_active_tokens_recall = np.count_nonzero(np.array(total_targ) > -1)
    #print("count_active_tokens_recall", count_active_tokens_recall)
    #count_active_tokens_precision = np.count_nonzero(np.array(total_targ) > -1)
    
    #count_correct = np.count_nonzero((np.array(total_pred)==np.array(total_targ)) == True)
    #print("count_correct",count_correct)
    #print("ACCURACY:", count_correct/count_active_tokens)
    

In [17]:
class XLMRobertaEncoder(nn.Module):
    def __init__(self, num_upos, num_feat):
        super(XLMRobertaEncoder, self).__init__()
        self.xlm_roberta = transformers.AutoModel.from_pretrained(BERT_MODEL) #transformers.XLMRobertaModel.from_pretrained('xlm-roberta-base')
        self.dropout = nn.Dropout(0.33)
        self.linear = nn.Linear(768, num_upos)
        
        self.f_dropout = nn.Dropout(0.33)
        self.f_linear = nn.Linear(768, num_feat)
            
    def forward(self, ids, mask):
        o1, o2 = self.xlm_roberta(ids, mask, return_dict=False)
        
        #apool = torch.mean(o1, 1)
        #mpool, _ = torch.max(o1, 1)
        #cat = torch.cat((apool, mpool), 1)
        #bo = self.dropout(cat)
        p_logits = self.linear(o1)        
        f_logits = self.f_linear(o1)   
        
        return p_logits, f_logits
        


In [18]:
if VALID_FILE is not None:
    train_graphs = read_conll(TRAIN_FILE, 'ko')
    valid_graphs = read_conll(VALID_FILE, 'ko')
else:
    graphs = read_conll(TRAIN_FILE, 'ko')
    if DATASET.startswith("KLUE"):
        valid_graphs = graphs[9001:18000]
        train_graphs = graphs[18001:]
    elif DATASET.startswith("MODU21"):
        valid_graphs = graphs[69485:]
        train_graphs = graphs[:68400]
    elif DATASET.startswith("MODU19"):
        valid_graphs = graphs[round(len(graphs)*0.8):]
        train_graphs = graphs[:round(len(graphs)*0.8)]
    elif DATASET.startswith("ETRI"):
        valid_graphs = graphs[round(len(graphs)*0.8):]
        train_graphs = graphs[:round(len(graphs)*0.8)]
    else:
        print("Please set the dataset among [KLUE, MODU21, MODU19, ETRI, NAVER]")

train_dataset = CoNLLDataset(graphs=train_graphs, tokenizer=TOKENIZER, max_len=MAX_LEN)
train_loader = torch.utils.data.DataLoader(train_dataset, num_workers=4, batch_size=TRAIN_BATCH_SIZE, shuffle=True)
valid_dataset = CoNLLDataset(graphs=valid_graphs, tokenizer=TOKENIZER, max_len=MAX_LEN, fullvocab=train_dataset._fullvocab)
valid_loader = torch.utils.data.DataLoader(valid_dataset, num_workers=4, batch_size=VALID_BATCH_SIZE, shuffle=False)

read_conll with ko
Vocab containing 1504 words
Charset containing 1539 chars
UPOS containing 1 tags Counter({'_': 388005})
Rels containing 1 tags Counter({'_': 388005})
Feats containing 11 tags Counter({'o': 306949, 'i-og': 20055, 'i-ps': 17752, 'i-dt': 11734, 'b-ps': 7602, 'i-lc': 6638, 'b-og': 6347, 'b-dt': 4032, 'b-lc': 3931, 'i-ti': 2239, 'b-ti': 726})
lang containing 1 tags Counter({'ko': 388005})


In [19]:
num_upos = len(train_dataset._upos)
num_feat = len(train_dataset._feats)
model = XLMRobertaEncoder(num_upos, num_feat)
model = nn.DataParallel(model)
model = model.cuda()

Some weights of the model checkpoint at xlm-roberta-base were not used when initializing XLMRobertaModel: ['lm_head.dense.weight', 'lm_head.layer_norm.bias', 'lm_head.bias', 'lm_head.decoder.weight', 'lm_head.layer_norm.weight', 'lm_head.dense.bias']
- This IS expected if you are initializing XLMRobertaModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing XLMRobertaModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


In [None]:
loss_fn = nn.CrossEntropyLoss(ignore_index=-1)
lr = 0.000005
optimizer = AdamW(model.parameters(), lr=lr)

In [None]:
def train_loop_fn(train_loader, model, optimizer, DEVICE, scheduler=None):
    model.train()
    
    p_total_pred = []
    p_total_targ = []
    p_total_loss = []
    
    f_total_pred = []
    f_total_targ = []
    f_total_loss = []
    
    for idx, batch in tqdm(enumerate(train_loader), total=len(train_loader)):
        optimizer.zero_grad()
        
        p_logits, f_logits = model(batch['ids'].cuda(), batch['mask'].cuda())
        
        #UPOS
        b,s,l = p_logits.size()
        #print(p_logits.view(b*s,l), p_logits.view(b*s,l).size())
        #print(batch['upos_ids'].cuda().view(b*s), batch['upos_ids'].cuda().view(b*s).size())
        p_loss = loss_fn(p_logits.view(b*s,l), batch['upos_ids'].cuda().view(b*s))
        p_total_loss.append(p_loss.item())
        p_total_pred.extend(torch.argmax(p_logits.view(b*s,l), 1).cpu().tolist())
        p_total_targ.extend(batch['upos_ids'].cuda().view(b*s).cpu().tolist())
        
        #FEAT
        b,s,l = f_logits.size()
        f_loss = loss_fn(f_logits.view(b*s,l), batch['feat_ids'].cuda().view(b*s))
        f_total_loss.append(f_loss.item())
        f_total_pred.extend(torch.argmax(f_logits.view(b*s,l), 1).cpu().tolist())
        f_total_targ.extend(batch['feat_ids'].cuda().view(b*s).cpu().tolist())
        
        #loss = p_loss+f_loss
        loss = f_loss
        loss.backward()
        optimizer.step()
        
    count_active_tokens = np.count_nonzero(np.array(p_total_targ) > -1)
    count_correct = np.count_nonzero((np.array(p_total_pred)==np.array(p_total_targ)) == True)
    print("TRAINING POS ACCURACY:", count_correct/count_active_tokens)
    
    count_active_tokens = np.count_nonzero(np.array(f_total_targ) > -1)
    count_correct = np.count_nonzero((np.array(f_total_pred)==np.array(f_total_targ)) == True)
    f1_score(f_total_pred, f_total_targ, train_dataset._feats.get('o', 2))
    print("TRAINING FEAT ACCURACY:", count_correct/count_active_tokens)


In [None]:
def valid_loop_fn(dev_loader, model, DEVICE):
    model.eval()
    
    p_total_pred = []
    p_total_targ = []
    p_total_loss = []
    
    f_total_pred = []
    f_total_targ = []
    f_total_loss = []
    with torch.no_grad():
        for idx, batch in tqdm(enumerate(dev_loader), total=len(dev_loader)):

            p_logits, f_logits = model(batch['ids'].cuda(), batch['mask'].cuda())

            #UPOS
            b,s,l = p_logits.size()
            p_loss = loss_fn(p_logits.view(b*s,l), batch['upos_ids'].cuda().view(b*s))
            p_total_loss.append(p_loss.item())
            p_total_pred.extend(torch.argmax(p_logits.view(b*s,l), 1).cpu().tolist())
            p_total_targ.extend(batch['upos_ids'].cuda().view(b*s).cpu().tolist())

            #FEAT
            b,s,l = f_logits.size()
            f_loss = loss_fn(f_logits.view(b*s,l), batch['feat_ids'].cuda().view(b*s))
            f_total_loss.append(f_loss.item())
            f_total_pred.extend(torch.argmax(f_logits.view(b*s,l), 1).cpu().tolist())
            f_total_targ.extend(batch['feat_ids'].cuda().view(b*s).cpu().tolist())

            loss = p_loss+f_loss
        
    count_active_tokens = np.count_nonzero(np.array(p_total_targ) > -1)
    count_correct = np.count_nonzero((np.array(p_total_pred)==np.array(p_total_targ)) == True)
    print("VALIDATION POS ACCURACY:", count_correct/count_active_tokens)
    
    count_active_tokens = np.count_nonzero(np.array(f_total_targ) > -1)
    count_correct = np.count_nonzero((np.array(f_total_pred)==np.array(f_total_targ)) == True)
    f1_score(f_total_pred, f_total_targ, train_dataset._feats.get('o', 2))
    print("VALIDATION FEAT ACCURACY:", count_correct/count_active_tokens)


In [None]:
print(train_dataset._feats.get('o'))

0


In [29]:
for idx in range(EPOCHS):
    train_loop_fn(train_loader, model, optimizer, DEVICE)
    valid_loop_fn(valid_loader, model, DEVICE)

  0%|          | 1/351 [00:03<23:05,  3.96s/it]  1%|          | 2/351 [00:07<23:09,  3.98s/it]  1%|          | 3/351 [00:11<21:33,  3.72s/it]  1%|          | 4/351 [00:14<20:46,  3.59s/it]  1%|▏         | 5/351 [00:18<20:22,  3.53s/it]  2%|▏         | 6/351 [00:21<20:04,  3.49s/it]  2%|▏         | 7/351 [00:25<19:57,  3.48s/it]  2%|▏         | 8/351 [00:28<19:45,  3.46s/it]  3%|▎         | 9/351 [00:31<19:35,  3.44s/it]  3%|▎         | 10/351 [00:35<19:30,  3.43s/it]  3%|▎         | 11/351 [00:38<19:23,  3.42s/it]  3%|▎         | 12/351 [00:42<19:18,  3.42s/it]  4%|▎         | 13/351 [00:45<19:13,  3.41s/it]  4%|▍         | 14/351 [00:48<19:12,  3.42s/it]  4%|▍         | 15/351 [00:52<19:07,  3.41s/it]  5%|▍         | 16/351 [00:55<19:02,  3.41s/it]  5%|▍         | 17/351 [00:59<18:58,  3.41s/it]  5%|▌         | 18/351 [01:02<18:54,  3.41s/it]  5%|▌         | 19/351 [01:05<18:52,  3.41s/it]  6%|▌         | 20/351 [01:09<18:54,  3.43s/it]  6%|▌         | 21/351 [01:

TRAINING POS ACCURACY: 0.8209902486085848
count_correct_p 33713
count_active_tokens_p 51028
precision: 0.6606764913380889
count_active_tokens_r 84556
count_correct_r 33713
recall: 0.39870618288471543
F1: 0
TRAINING FEAT ACCURACY: 0.9052323335789189


  1%|          | 1/84 [00:01<01:47,  1.30s/it]  2%|▏         | 2/84 [00:02<01:40,  1.23s/it]  4%|▎         | 3/84 [00:03<01:38,  1.21s/it]  5%|▍         | 4/84 [00:04<01:35,  1.20s/it]  6%|▌         | 5/84 [00:06<01:34,  1.20s/it]  7%|▋         | 6/84 [00:07<01:32,  1.19s/it]  8%|▊         | 7/84 [00:08<01:31,  1.19s/it] 10%|▉         | 8/84 [00:09<01:30,  1.19s/it] 11%|█         | 9/84 [00:10<01:29,  1.19s/it] 12%|█▏        | 10/84 [00:11<01:27,  1.19s/it] 13%|█▎        | 11/84 [00:13<01:26,  1.19s/it] 14%|█▍        | 12/84 [00:14<01:25,  1.19s/it] 15%|█▌        | 13/84 [00:15<01:24,  1.19s/it] 17%|█▋        | 14/84 [00:16<01:23,  1.19s/it] 18%|█▊        | 15/84 [00:17<01:22,  1.19s/it] 19%|█▉        | 16/84 [00:19<01:20,  1.19s/it] 20%|██        | 17/84 [00:20<01:19,  1.19s/it] 21%|██▏       | 18/84 [00:21<01:18,  1.19s/it] 23%|██▎       | 19/84 [00:22<01:17,  1.19s/it] 24%|██▍       | 20/84 [00:23<01:16,  1.19s/it] 25%|██▌       | 21/84 [00:25<01:14,  1.19s/it]

VALIDATION POS ACCURACY: 0.943528792839689
count_correct_p 17500
count_active_tokens_p 23377
precision: 0.7485990503486333
count_active_tokens_r 25013
count_correct_r 17500
recall: 0.6996361891816255
F1: 0
VALIDATION FEAT ACCURACY: 0.9408360618289804


  0%|          | 1/351 [00:03<20:50,  3.57s/it]  1%|          | 2/351 [00:06<20:11,  3.47s/it]  1%|          | 3/351 [00:10<20:16,  3.50s/it]  1%|          | 4/351 [00:13<19:58,  3.45s/it]  1%|▏         | 5/351 [00:17<19:49,  3.44s/it]  2%|▏         | 6/351 [00:20<19:47,  3.44s/it]  2%|▏         | 7/351 [00:24<19:38,  3.43s/it]  2%|▏         | 8/351 [00:27<19:32,  3.42s/it]  3%|▎         | 9/351 [00:30<19:32,  3.43s/it]  3%|▎         | 10/351 [00:34<19:26,  3.42s/it]  3%|▎         | 11/351 [00:37<19:20,  3.41s/it]  3%|▎         | 12/351 [00:41<19:16,  3.41s/it]  4%|▎         | 13/351 [00:44<19:11,  3.41s/it]  4%|▍         | 14/351 [00:48<19:09,  3.41s/it]  4%|▍         | 15/351 [00:51<19:20,  3.45s/it]  5%|▍         | 16/351 [00:54<19:11,  3.44s/it]  5%|▍         | 17/351 [00:58<19:04,  3.43s/it]  5%|▌         | 18/351 [01:01<18:59,  3.42s/it]  5%|▌         | 19/351 [01:05<18:54,  3.42s/it]  6%|▌         | 20/351 [01:08<18:49,  3.41s/it]  6%|▌         | 21/351 [01:

TRAINING POS ACCURACY: 0.9562676213308923
count_correct_p 65291
count_active_tokens_p 78673
precision: 0.8299035247162305
count_active_tokens_r 84556
count_correct_r 65291
recall: 0.7721628270022234
F1: 0
TRAINING FEAT ACCURACY: 0.9615882636193639


  1%|          | 1/84 [00:01<01:48,  1.31s/it]  2%|▏         | 2/84 [00:02<01:41,  1.23s/it]  4%|▎         | 3/84 [00:03<01:38,  1.21s/it]  5%|▍         | 4/84 [00:04<01:36,  1.20s/it]  6%|▌         | 5/84 [00:06<01:34,  1.20s/it]  7%|▋         | 6/84 [00:07<01:33,  1.19s/it]  8%|▊         | 7/84 [00:08<01:31,  1.19s/it] 10%|▉         | 8/84 [00:09<01:30,  1.19s/it] 11%|█         | 9/84 [00:10<01:29,  1.19s/it] 12%|█▏        | 10/84 [00:11<01:28,  1.19s/it] 13%|█▎        | 11/84 [00:13<01:26,  1.19s/it] 14%|█▍        | 12/84 [00:14<01:25,  1.19s/it] 15%|█▌        | 13/84 [00:15<01:24,  1.19s/it] 17%|█▋        | 14/84 [00:16<01:23,  1.19s/it] 18%|█▊        | 15/84 [00:17<01:21,  1.19s/it] 19%|█▉        | 16/84 [00:19<01:20,  1.19s/it] 20%|██        | 17/84 [00:20<01:19,  1.19s/it] 21%|██▏       | 18/84 [00:21<01:18,  1.19s/it] 23%|██▎       | 19/84 [00:22<01:17,  1.19s/it] 24%|██▍       | 20/84 [00:23<01:16,  1.19s/it] 25%|██▌       | 21/84 [00:25<01:14,  1.19s/it]

VALIDATION POS ACCURACY: 0.9631947280618567
count_correct_p 21151
count_active_tokens_p 24297
precision: 0.8705189941144997
count_active_tokens_r 25013
count_correct_r 21151
recall: 0.8456002878503178
F1: 0
VALIDATION FEAT ACCURACY: 0.9660397472017056


  0%|          | 1/351 [00:03<20:27,  3.51s/it]  1%|          | 2/351 [00:06<20:01,  3.44s/it]  1%|          | 3/351 [00:10<19:54,  3.43s/it]  1%|          | 4/351 [00:13<19:50,  3.43s/it]  1%|▏         | 5/351 [00:17<19:44,  3.42s/it]  2%|▏         | 6/351 [00:20<19:39,  3.42s/it]  2%|▏         | 7/351 [00:23<19:35,  3.42s/it]  2%|▏         | 8/351 [00:27<19:30,  3.41s/it]  3%|▎         | 9/351 [00:30<19:26,  3.41s/it]  3%|▎         | 10/351 [00:34<19:22,  3.41s/it]  3%|▎         | 11/351 [00:37<19:18,  3.41s/it]  3%|▎         | 12/351 [00:41<19:14,  3.41s/it]  4%|▎         | 13/351 [00:44<19:10,  3.40s/it]  4%|▍         | 14/351 [00:47<19:07,  3.41s/it]  4%|▍         | 15/351 [00:51<19:03,  3.40s/it]  5%|▍         | 16/351 [00:54<19:06,  3.42s/it]  5%|▍         | 17/351 [00:58<19:01,  3.42s/it]  5%|▌         | 18/351 [01:01<18:55,  3.41s/it]  5%|▌         | 19/351 [01:04<19:00,  3.44s/it]  6%|▌         | 20/351 [01:08<18:53,  3.43s/it]  6%|▌         | 21/351 [01:

In [None]:
valid_loop_fn(valid_loader, model, DEVICE)

In [None]:
valid_dataset = CoNLLDataset(graphs=valid_graphs, tokenizer=TOKENIZER, max_len=MAX_LEN, fullvocab=train_dataset._fullvocab)
valid_loader = torch.utils.data.DataLoader(valid_dataset, num_workers=4, batch_size=1, shuffle=False)

In [None]:
with torch.no_grad():
    for idx, batch in tqdm(enumerate(valid_loader), total=len(valid_loader)):

        p_logits, f_logits = model(batch['ids'].cuda(), batch['mask'].cuda())

        #UPOS
        b,s,l = p_logits.size()
        p_loss = loss_fn(p_logits.view(b*s,l), batch['upos_ids'].cuda().view(b*s))
        p_total_loss.append(p_loss.item())
        p_total_pred.extend(torch.argmax(p_logits.view(b*s,l), 1).cpu().tolist())
        p_total_targ.extend(batch['upos_ids'].cuda().view(b*s).cpu().tolist())

        #FEAT
        b,s,l = f_logits.size()
        f_loss = loss_fn(f_logits.view(b*s,l), batch['feat_ids'].cuda().view(b*s))
        f_total_loss.append(f_loss.item())
        f_total_pred.extend(torch.argmax(f_logits.view(b*s,l), 1).cpu().tolist())
        f_total_targ.extend(batch['feat_ids'].cuda().view(b*s).cpu().tolist())