### Imports

In [1]:
%load_ext autoreload
%autoreload 2

import sys
from typing import List
from datetime import datetime

import matplotlib.pyplot as plt
import pandas as pd
import numpy as np
import json
import torch
import pytorch_lightning as pl
from sklearn.model_selection import train_test_split

from src.manipulation_helpers.models import ManipulationTargetLightningModel
from src.manipulation_helpers.data_preparation import markup_conll, encode_tags, Markup, \
                                                      read_markup, create_concat_data

OSError: [WinError 126] Не найден указанный модуль. Error loading "E:\anaconda\envs\manipulation\lib\site-packages\torch\lib\shm.dll" or one of its dependencies.

In [None]:
from IPython.display import clear_output

In [None]:
# Timestamp of when this cell was executed, displayed as the cell's output
# (isoformat with a space separator is the same text as str(datetime)).
datetime.now().isoformat(" ")

## Data processing

In [None]:
# Annotation export read as JSON Lines (lines=True: one JSON record per line).
PATH_TO_MARKUP = "data/markup_union_matched.json"
markup = pd.read_json(path_or_buf=PATH_TO_MARKUP, lines=True)
# Report how many annotation records were loaded.
print("Всего разметок:", markup.shape[0])

In [None]:
# Convert every annotation record into a word-level text plus its markup via
# markup_conll. The three source columns are iterated in lock-step, row order preserved.
texts, markups = [], []
for raw_text, raw_result, raw_entities in zip(
    markup["input_input"],
    markup["output_result"],
    markup["input_entitiesdata"],
):
    text, row_markup = markup_conll(raw_text, raw_result, raw_entities)
    texts.append(text)
    markups.append(row_markup)
# NOTE(review): presumably hides output emitted by markup_conll — confirm.
clear_output()

In [None]:
# All distinct manipulation classes that occur in the markup.
unique_tags = {x.manipulation_class for m in markups for x in m}
# Enumerate in sorted order so the tag -> id mapping is identical across kernel
# restarts. Iteration order of a set of strings depends on per-process hash
# randomization, so enumerating the raw set gives a different mapping each run,
# which makes saved models and logged label ids inconsistent between sessions.
# key=str keeps the sort well-defined even if tags mix types (e.g. None).
tag2id = {tag: idx for idx, tag in enumerate(sorted(unique_tags, key=str))}
id2tag = {idx: tag for tag, idx in tag2id.items()}

In [None]:
# Hold out 20% of the documents for validation. The split is seeded, and the
# original row positions are split alongside so each sample can be traced back.
(
    train_texts, val_texts,
    train_markups, val_markups,
    train_ids, val_ids,
) = train_test_split(
    texts,
    markups,
    range(len(texts)),
    test_size=0.2,
    random_state=42,
)

### Dataset & tokenization

In [None]:
# Pretrained checkpoint identifier, shared by the tokenizer and the BERT backbone.
MODEL_NAME: str = "sberbank-ai/ruBert-base"

In [None]:
from functools import partial

from transformers import BertTokenizerFast

# Maximum number of tokens kept per text; longer inputs are truncated.
MAX_LENGTH = 512

tokenizer = BertTokenizerFast.from_pretrained(MODEL_NAME)

# Shared tokenizer settings:
# - texts arrive already split into words (is_split_into_words=True);
# - no [CLS]/[SEP] are added here (add_special_tokens=False);
# - padding=True pads to the longest sequence of the call, so the train and
#   validation splits are each padded to their own maximum length.
encode = partial(
    tokenizer,
    add_special_tokens=False,
    is_split_into_words=True,
    padding=True,
    truncation=True,
    max_length=MAX_LENGTH,
)

# Id of the [SEP] token, passed to create_concat_data (presumably as the
# separator between concatenated segments — TODO confirm). Read directly from
# the tokenizer instead of round-tripping the "[SEP]" string through `encode`.
encoded_sep_token = tokenizer.sep_token_id

train_encodings = encode(train_texts)
val_encodings = encode(val_texts)

In [None]:
# Align three word-level markup attributes with the sub-word encodings of each
# split (train first, then validation). Only `manipulation_class` is mapped
# through tag2id; the other two attributes are encoded without a mapping.
train_labels, val_labels = (
    encode_tags(split_markups, split_encodings, "manipulation_class", tag2id=tag2id)
    for split_markups, split_encodings in (
        (train_markups, train_encodings),
        (val_markups, val_encodings),
    )
)

train_entities, val_entities = (
    encode_tags(split_markups, split_encodings, "entity_id")
    for split_markups, split_encodings in (
        (train_markups, train_encodings),
        (val_markups, val_encodings),
    )
)

train_manipulation_targets, val_manipulation_targets = (
    encode_tags(split_markups, split_encodings, "manipulation_target")
    for split_markups, split_encodings in (
        (train_markups, train_encodings),
        (val_markups, val_encodings),
    )
)

In [None]:
# Build (inputs, targets) for each split from the token encodings, the entity
# ids and the manipulation-target tags, using the [SEP] token id.
# NOTE(review): the exact input layout is defined in data_preparation — confirm there.
train_x, train_y = create_concat_data(
    train_encodings,
    train_entities,
    train_manipulation_targets,
    encoded_sep_token,
)
val_x, val_y = create_concat_data(
    val_encodings,
    val_entities,
    val_manipulation_targets,
    encoded_sep_token,
)

In [None]:
# Every model input needs exactly one target. Each split is checked on its own
# so a failure reports which split is inconsistent and by how much.
assert len(train_x) == len(train_y), f"train split: {len(train_x)} inputs vs {len(train_y)} targets"
assert len(val_x) == len(val_y), f"val split: {len(val_x)} inputs vs {len(val_y)} targets"

In [None]:
import torch
from torch.utils.data import DataLoader, TensorDataset

BATCH_SIZE = 8

# torch.as_tensor returns an existing tensor unchanged, so re-running this cell
# does not re-wrap already-converted data (torch.tensor would copy and warn).
train_x, train_y = torch.as_tensor(train_x), torch.as_tensor(train_y)
val_x, val_y = torch.as_tensor(val_x), torch.as_tensor(val_y)

train_dataset = TensorDataset(train_x, train_y)
val_dataset = TensorDataset(val_x, val_y)

# Reshuffle the training data every epoch; without shuffle=True the model would
# see the same batch order in every epoch. A seeded generator keeps the order
# reproducible. Validation order does not affect metrics, so it stays fixed.
train_loader = DataLoader(
    train_dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
    generator=torch.Generator().manual_seed(42),
)
val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE)

## Model pipeline

### PyTorch Lightning model

In [None]:
# Seed Python's `random`, NumPy and torch generators so that model
# initialization and training are repeatable across kernel restarts.
pl.seed_everything(42)

config = {
    'bert_model_name': MODEL_NAME,
    'optimizer': torch.optim.AdamW,  # optimizer class, not an instance
    'lr': 0.001,
    'freeze_bert': True,  # presumably only layers on top of BERT are trained — TODO confirm
    'loss_function': torch.nn.CrossEntropyLoss(),
}
model = ManipulationTargetLightningModel(**config)
trainer = pl.Trainer(max_epochs=25)
trainer.fit(model, train_loader, val_loader)

### Training

### Evaluation