# Libraries

In [1]:
import torch
import torch.nn as nn
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence
import torch.optim as optim
from torch.utils.data import TensorDataset, DataLoader

from tqdm import tqdm
import glob
import json
from os.path import join, exists
from os import mkdir

import pandas as pd
import numpy as np

from torch.utils.data import TensorDataset, DataLoader

#  Competition Score Script RuREBus

In [2]:
# Competition Score Script RuREBus
# https://github.com/dialogue-evaluation/RuREBus/blob/master/eval_scripts/evaluate_ners.py

def cacl_ner_tp_fp_fn(true_ners, pred_ners):

    true_positive = 0
    false_positive = 0
    false_negative = 0

    i = 0
    j = 0
    while i != len(true_ners) and j != len(pred_ners):
        if true_ners[i] == pred_ners[j]:
            true_positive += 1
            i += 1
            j += 1
            continue
        if true_ners[i][1] >= pred_ners[j][2]:
            false_positive += 1
            j += 1
            continue
        if true_ners[i][2] <= pred_ners[j][1]:
            false_negative += 1
            i += 1
            continue
        if true_ners[i][1] < pred_ners[j][1]:
            false_negative += 1
            i += 1
            continue
        if true_ners[i][1] > pred_ners[j][1]:
            false_positive += 1
            j += 1
            continue

        false_positive += 1
        false_negative += 1
        j += 1
        i += 1

    false_negative += len(true_ners) - i
    false_positive += len(pred_ners) - j

    return true_positive, false_positive, false_negative

def compute_precision_and_recall(true_positive, false_positive, false_negative):

    if false_positive + true_positive > 0:
        precision = float(true_positive) / (true_positive + false_positive)
    else:
        precision = 0
    if false_negative + true_positive > 0:
        recall = float(true_positive) / (true_positive + false_negative)
    else:
        recall = 0
    return recall, precision

In [3]:
# converting from mine to evaluation format with Script RuREBus
def for_m(tags):
    if tags == []:
        return []
    start = 0
    end = 1
    ent = tags[0]
    count = 1
    l = []
    while count != len(tags):
        if tags[count] == ent:
            end += 1
        else:
            l+= [[ent,start,end]]
            start = end
            end = start + 1
            ent = tags[count]
        count += 1
    l+= [[ent,start,end]]
    return l

# Models

In [4]:
class CRF(nn.Module):

    def __init__(self, in_features, num_tags):
        super(CRF, self).__init__()

        self.num_tags = num_tags + 2
        self.start_idx = self.num_tags - 2
        self.stop_idx = self.num_tags - 1
        self.fc = nn.Linear(in_features, self.num_tags)
        self.transitions = nn.Parameter(torch.randn(self.num_tags, self.num_tags), requires_grad=True)
        self.transitions.data[self.start_idx, :] = -1e4
        self.transitions.data[:, self.stop_idx] = -1e4

    def forward(self, features, masks):
        features = self.fc(features)
        return self.__viterbi_decode(features, masks[:, :features.size(1)].float())

    def loss(self, features, ys, masks):
        features = self.fc(features)
        L = features.size(1)
        masks_ = masks[:, :L].float()
        forward_score = self.__forward_algorithm(features, masks_)
        gold_score = self.__score_sentence(features, ys[:, :L].long(), masks_)
        loss = (forward_score - gold_score).mean()
        return loss

    def __score_sentence(self, features, tags, masks):
        B, L, C = features.shape
        emit_scores = features.gather(dim=2, index=tags.unsqueeze(-1)).squeeze(-1)
        start_tag = torch.full((B, 1), self.start_idx, dtype=torch.long, device=tags.device)
        tags = torch.cat([start_tag, tags], dim=1)  # [B, L+1]
        trans_scores = self.transitions[tags[:, 1:], tags[:, :-1]]
        last_tag = tags.gather(dim=1, index=masks.sum(1).long().unsqueeze(1)).squeeze(1)  # [B]
        last_score = self.transitions[self.stop_idx, last_tag]
        score = ((trans_scores + emit_scores) * masks).sum(1) + last_score
        return score

    def __viterbi_decode(self, features, masks):

        B, L, C = features.shape
        bps = torch.zeros(B, L, C, dtype=torch.long, device=features.device) 
        max_score = torch.full((B, C), -1e4, device=features.device) 
        max_score[:, self.start_idx] = 0
        for t in range(L):
            mask_t = masks[:, t].unsqueeze(1)
            emit_score_t = features[:, t]  
            acc_score_t = max_score.unsqueeze(1) + self.transitions
            acc_score_t, bps[:, t, :] = acc_score_t.max(dim=-1)
            acc_score_t += emit_score_t
            max_score = acc_score_t * mask_t + max_score * (1 - mask_t) 
        max_score += self.transitions[self.stop_idx]
        best_score, best_tag = max_score.max(dim=-1)
        best_paths = []
        bps = bps.cpu().numpy()
        for b in range(B):
            best_tag_b = best_tag[b].item()
            seq_len = int(masks[b, :].sum().item())
            best_path = [best_tag_b]
            for bps_t in reversed(bps[b, :seq_len]):
                best_tag_b = bps_t[best_tag_b]
                best_path.append(best_tag_b)
            best_paths.append(best_path[-2::-1])
        return best_score, best_paths

    def __forward_algorithm(self, features, masks):

        B, L, C = features.shape
        scores = torch.full((B, C), -1e4, device=features.device)
        scores[:, self.start_idx] = 0.
        trans = self.transitions.unsqueeze(0) 
        for t in range(L):
            emit_score_t = features[:, t].unsqueeze(2) 
            score_t = scores.unsqueeze(1) + trans + emit_score_t  
            score_t = score_t.max(-1)[0] + (score_t - score_t.max(-1)[0].unsqueeze(-1)).exp().sum(-1).log()
            mask_t = masks[:, t].unsqueeze(1)  
            scores = score_t * mask_t + scores * (1 - mask_t)
        scores = scores + self.transitions[self.stop_idx]
        scores =  scores.max(-1)[0] + (scores - scores.max(-1)[0].unsqueeze(-1)).exp().sum(-1).log()
        return scores

In [5]:
class BiLSTMCRF(nn.Module):
    def __init__(self, vocab_size, tagset_size, embedding_dim, hidden_dim):
        super(BiLSTMCRF, self).__init__()
        
        self.embedding_dim = embedding_dim
        self.hidden_dim = hidden_dim
        self.vocab_size = vocab_size
        self.tagset_size = tagset_size
        self.embedding = nn.Embedding(vocab_size, embedding_dim)
        self.lstm = nn.LSTM(embedding_dim, hidden_dim // 2, bidirectional=True, batch_first=True)
        self.crf = CRF(hidden_dim, self.tagset_size)

    def __build_features(self, sentences):
        masks = sentences.gt(0)
        embeds = self.embedding(sentences.long())
        seq_length = masks.sum(1)
        sorted_seq_length, perm_idx = seq_length.sort(descending=True)
        embeds = embeds[perm_idx, :]
        pack_sequence = pack_padded_sequence(embeds, lengths=sorted_seq_length.to('cpu'), batch_first=True)
        packed_output, _ = self.lstm(pack_sequence)
        lstm_out, _ = pad_packed_sequence(packed_output, batch_first=True)
        _, unperm_idx = perm_idx.sort()
        lstm_out = lstm_out[unperm_idx, :]
        return lstm_out, masks

    def loss(self, xs, tags):
        features, masks = self.__build_features(xs)
        loss = self.crf.loss(features, tags, masks=masks)
        return loss

    def forward(self, xs):
        features, masks = self.__build_features(xs)
        scores, tag_seq = self.crf(features, masks)
        return scores, tag_seq

# Read data

In [6]:
def read_data(path):
    data = []
    for i in tqdm(glob.glob(path)):
        try:
            with open(i) as f:
                f = f.read()
            i[:-3] + "ann"
            df = pd.read_csv(i[:-3] + "ann",  sep="\t", header=None)
            t = np.array([i.split() for i in df[1]])
            df["entity"] = t[:,0]
            df["start"] = t[:,1]
            df["end"] = t[:,2]
            del df[1], df[0]
            df =  df[df[2].notnull()]
            df =  df.reset_index(drop=True)
            starts = []
            entity = []
            if "0" not in df["start"]:
                starts =  [0] + [int(i) for i in df["start"]] + [int(i) for i in df["end"]]
                starts = sorted(starts)
            else:
                assert 1==0 , "FUCK"
            text = []
            enteties = []
            all_en_starts = [int(i) for i in df["start"]]
            for n,start in enumerate(starts):
                if len(starts)-1==n:
                    break
                new_part = f[start: starts[n+1]].split()
                text += new_part
                if start in all_en_starts:
                    enteties += [df["entity"][all_en_starts.index(start)] for j in range(len(new_part))]
                else:
                    enteties += ["O" for j in range(len(new_part))]

            data += [(text, enteties)]
        except:
            print(i)
    return data

## train

In [7]:
training_data = read_data("/kaggle/input/train-data-ner/train_data/*.txt")

 21%|██        | 39/188 [00:01<00:03, 43.82it/s]

/kaggle/input/train-data-ner/train_data/31339061024501948020025_29_part_0.txt


 70%|██████▉   | 131/188 [00:03<00:00, 58.78it/s]

/kaggle/input/train-data-ner/train_data/31339011023301254426027_6_part_0.txt


100%|██████████| 188/188 [00:04<00:00, 39.50it/s]


## validation

In [8]:
validation_data = read_data("/kaggle/input/validdataner/*.txt")

100%|██████████| 30/30 [00:00<00:00, 52.93it/s]


# Prepare data

In [9]:
START_TAG = "<START>"
STOP_TAG = "<STOP>"
list_of_tags = ["O", "OUT", "ACT", "BIN", "CMP", "ECO", "INST", "MET", "SOC", "QUA", "N", START_TAG, STOP_TAG]
dict_of_tags = {"O":0, "OUT":1, "ACT":2, "BIN":3, "CMP":4, "ECO":5, "INST":6, "MET":7, "SOC":8, "QUA":9, "N":10, START_TAG:11, STOP_TAG:12}


training_data = training_data

word_to_ix = {}
word_to_ix["<PAD>"] = len(word_to_ix)

for sentence, tags in training_data:
    for word in sentence:
        if word not in word_to_ix:
            word_to_ix[word] = len(word_to_ix)
            
for sentence, tags in validation_data:
    for word in sentence:
        if word not in word_to_ix:
            word_to_ix[word] = len(word_to_ix)          
            
word_to_ix["<OOV>"] = len(word_to_ix)

In [10]:
l = []
for i in training_data:
    l+= [len(i[0])]
for i in validation_data:
    l+= [len(i[0])]
max(l)

8025

In [11]:
max_seq_len = max(l)

x_train, y_train, x_val, y_val, x_test, y_test = [], [], [], [], [], []

for i in training_data:
    xi_train = [int(word_to_ix[n]) for n in i[0]][:max_seq_len] + [0]*(max_seq_len-len(i[1]))
    yi_train = [dict_of_tags[n] for n in i[1]][:max_seq_len] + [0]*(max_seq_len-len(i[1]))
    x_train += [xi_train]
    y_train += [yi_train]
x_train = torch.Tensor(x_train).type(torch.LongTensor)
y_train = torch.Tensor(y_train).type(torch.LongTensor)

for i in validation_data:
    xi_val = [int(word_to_ix[n]) for n in i[0]][:max_seq_len] + [0]*(max_seq_len-len(i[1]))
    yi_val = [dict_of_tags[n] for n in i[1]][:max_seq_len] + [0]*(max_seq_len-len(i[1]))
    x_val += [xi_val]
    y_val += [yi_val]
x_val = torch.Tensor(x_val).type(torch.LongTensor)
y_val = torch.Tensor(y_val).type(torch.LongTensor)


print("Train samples: ", len(x_train))
print("Val samples: ",len(y_val))

train_dl = DataLoader(TensorDataset(x_train, y_train), batch_size=1000, shuffle=True)
valid_dl = DataLoader(TensorDataset(x_val, y_val), batch_size=1000 * 2)

Train samples:  186
Val samples:  30


# Train

In [12]:
model = BiLSTMCRF(len(word_to_ix), len(list_of_tags), embedding_dim=100, hidden_dim=128)
optimizer = optim.Adam(model.parameters(), lr=1e-3, weight_decay=0.)
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
model.to(device)
all_valid_f1 = []
all_val_loss = []
all_train_loss = []
for epoch in range(1000):
    model.train()
    loss_ep = 0
    for bi, (xb, yb) in enumerate(train_dl):
        model.zero_grad()
        loss = model.loss(xb.to(device), yb.to(device))
        loss.backward()
        optimizer.step()
        loss_ep += loss
    all_train_loss += [float(loss_ep)/len(train_dl)]

    model.eval()
    with torch.no_grad():
        f1_ep = []
        loss_ep = 0
        for xb, yb in valid_dl:
            loss = model.loss(xb.to(device), yb.to(device))
            loss_ep += loss
            xxx = (len(xb), )
            loss= (loss.cpu(), )
            preds = model(xb.to(device))[1]
            for x1,x2 in zip(preds, yb):
                x2_l = [int(i) for i in x2]
                pred = for_m(list(filter(("uuu").__ne__, ["uuu" if x2_l[i]==0 else x1[i] for i in range(len(x1))])))
                true = for_m(list(filter((0).__ne__, x2_l)))
                precision, recall = compute_precision_and_recall(*cacl_ner_tp_fp_fn(true,pred))
                if precision==0 and recall==0:
                    f_measure = 0
                else:
                    f_measure = 2 * precision * recall / (precision + recall)
                f1_ep += [f_measure]
    print(f"epoch { epoch}/f1_valid : ",sum(f1_ep)/len(f1_ep))
    all_val_loss += [float(loss_ep)/len(valid_dl)]
    all_valid_f1 += [sum(f1_ep)/len(f1_ep)]
    with open("all_train_loss.txt", "w") as w:
        w.write(json.dumps(all_train_loss))
    with open("all_valid_f1.txt", "w") as w:
        w.write(json.dumps(all_valid_f1))
    with open("all_val_loss.txt", "w") as w:
        w.write(json.dumps(all_val_loss))
            

epoch 0/f1_valid :  0.0017599220909822431
epoch 1/f1_valid :  0.0015122977701838647
epoch 2/f1_valid :  0.0009347768457742882
epoch 3/f1_valid :  0.0007156576818409642
epoch 4/f1_valid :  0.001127917325522116
epoch 5/f1_valid :  0.0004096028486272389
epoch 6/f1_valid :  0.0004252511620932674
epoch 7/f1_valid :  0.0004406511774932828
epoch 8/f1_valid :  0.0004518328218559134
epoch 9/f1_valid :  0.0001455604075691412
epoch 10/f1_valid :  0.0009493352497305067
epoch 11/f1_valid :  0.002846994423356216
epoch 12/f1_valid :  0.005721027861775446
epoch 13/f1_valid :  0.008721442674491589
epoch 14/f1_valid :  0.009616822982645993
epoch 15/f1_valid :  0.014224638265105661
epoch 16/f1_valid :  0.015979068502141633
epoch 17/f1_valid :  0.017645856657720876
epoch 18/f1_valid :  0.019022973990561912
epoch 19/f1_valid :  0.017305506714146986
epoch 20/f1_valid :  0.021879528232204175
epoch 21/f1_valid :  0.021580634759761807
epoch 22/f1_valid :  0.02276521528637789
epoch 23/f1_valid :  0.023374765256

KeyboardInterrupt: 

In [74]:
def viz(text, yb, preds):
    output_text = ""
    for i in range(len(text[0])):
        if yb[i]==0 and preds[i] == 0:
            output_text += '\033[1m' + text[0][i] + '\033[0m'
        elif yb[i]==0 and preds[i] != 0:
            output_text +=  text[0][i] + "[" + '\u0336' + '\u0336'.join(list_of_tags[preds[i]]) + list_of_tags[yb[i]] + "]"
        elif yb[i] == yb[i]:
            output_text += '\033[1m' + text[0][i] + "[" + list_of_tags[preds[i]] + "]"+ '\033[0m'
        else:
            output_text +=  text[0][i] + "[" + '\u0336' + '\u0336'.join(list_of_tags[preds[i]]) + list_of_tags[yyb[i]] + "]"
        output_text += " "
    print(output_text)
    print(f1_ep)

In [78]:
viz(validation_data[20],yb[20], preds[20])
print("____")
print(f1_ep[20])

[1mГОСУДАРСТВЕННАЯ[0m [1mПРОГРАММА[0m [1mРЕСПУБЛИКИ[0m [1mКОМИ[0m [1m"[0m [1mЮСТИЦИЯ[SOC][0m [1mИ[0m [1mОБЕСПЕЧЕНИЕ[BIN][0m [1mПРАВОПОРЯДКА[SOC][0m [1mВ[0m [1mРЕСПУБЛИКЕ[0m [1mКОМИ"[0m [1mПАСПОРТ[0m [1mГосударственной[0m [1mпрограммы[0m [1mРеспублики[0m [1mКоми[0m [1m"[0m [1mЮстиция[SOC][0m [1mи[0m [1mобеспечение[BIN][0m [1mправопорядка[SOC][0m [1mв[0m [1mРеспублике[0m [1mКоми"[0m [1m(далее[0m [1m-[0m [1mПрограмма)[0m [1mПаспорт[0m [1mподпрограммы[0m [1m"[0m [1mПравовая[SOC][0m [1mзащищенность[SOC][0m [1mнаселения[0m [1mРеспублики[0m [1mКоми"[0m [1m(далее[0m [1m-[0m [1mПодпрограмма[0m [1m1)[0m [1mПаспорт[0m [1mподпрограммы[0m [1m"[0m [1mПравопорядок[O][0m [1m"[0m [1m(далее[0m [1m-[0m [1mПодпрограмма[0m [1m2)[0m [1mПаспорт[0m [1mподпрограммы[0m [1m"Государственная[0m [1mрегистрация[0m [1mактов[0m [1mгражданского[0m [1mсостояния[0m [1mв[0m [1mРеспублике[0m [1mКоми"[0m [1

In [80]:
viz(validation_data[0],yb[0], preds[0])
print("____")
print(f1_ep[0])

АДМИНИСТРАЦИЯ[̶I̶N̶S̶TO] РЕПЬЕВСКОГО[̶I̶N̶S̶TO] МУНИЦИПАЛЬНОГО[̶I̶N̶S̶TO] [1mРАЙОНА[0m [1mВОРОНЕЖСКОЙ[0m [1mОБЛАСТИ[0m [1mПОСТАНОВЛЕНИЕ[0m [1m«[0m [1m08[0m [1m»[0m [1mноября[0m [1m2016[0m [1mг[0m [1m.[0m [1m№[0m [1m226а[0m [1mс.[0m [1mРепьевка[0m [1mОб[0m [1mутверждении[0m [1mмуниципальной[0m [1mпрограммы[0m Репьевского[̶Q̶U̶AO] [1mмуниципального[0m [1mрайона[0m [1m«[0m [1mРазвитие[BIN][0m [1mтранспортной[ECO][0m [1mсистемы[ECO][0m [1m»[0m [1m([0m [1mв[0m [1mред.[0m [1mпост.[0m [1mот[0m [1m12.05.2017[0m [1m№[0m [1m157[0m [1m,[0m [1mот[0m 14.02.2018[̶E̶C̶OO] [1m№[0m [1m63[0m [1m,[0m [1mот[0m [1m17.04.2018[0m [1m№[0m [1m147[0m [1m)[0m [1mВ[0m [1mсоответствии[0m [1mс[0m [1mФедеральным[0m [1mзаконом[0m [1mот[0m [1m06.10.2003[0m [1m№[0m [1m131[0m [1m-[0m [1mФЗ[0m [1m«[0m [1mОб[0m [1mобщих[0m [1mпринципах[0m [1mорганизации[0m [1mместного[0m [1mсамоуправления[0m [1mв[0

In [81]:
viz(validation_data[1],yb[1], preds[1])
print("____")
print(f1_ep[1])

[1mмуниципальной[0m [1mпрограммы[0m Реализация[̶B̶I̶NO] мероприятия[̶A̶C̶TO] Программы[̶A̶C̶TO] осуществляется[̶B̶I̶NO] [1mза[0m [1mсчет[0m [1mсредств[ECO][0m [1mобластного[ECO][0m [1mбюджета[ECO][0m [1m,[0m [1mрайонного[ECO][0m [1mбюджета[ECO][0m и[̶E̶C̶OO] [1mвнебюджетных[ECO][0m [1mисточников[ECO][0m [1m.[0m [1mОбъемы[MET][0m [1mфинансирования[MET][0m [1mПрограммы[0m [1mподлежат[0m [1mежегодному[0m [1mуточнению[0m [1mв[0m [1mрамках[0m бюджетного[̶E̶C̶OO] [1mцикла[0m [1m.[0m Ресурсное[̶E̶C̶OO] обеспечение[̶E̶C̶OO] [1mи[0m [1mпрогнозная[0m [1m([0m [1mсправочная[0m [1m)[0m [1mоценка[0m [1mрасходов[ECO][0m [1mобластного[ECO][0m [1mбюджета[ECO][0m [1m,[0m [1mмуниципального[ECO][0m [1mбюджета[ECO][0m [1mи[0m [1mвнебюджетных[ECO][0m [1mисточников[ECO][0m [1mна[0m [1mреализацию[0m [1mПрограммы[0m [1mна[0m [1mпериод[0m [1m2017[0m [1m-[0m [1m2022[0m [1mгодов[0m [1mпредставлены[0m [1mв[0m [1mприл