In [None]:
import torch
from torch import nn
from transformers import AutoModel, AutoTokenizer
from pathlib import Path

In [None]:
import tika
# initVM() is called before the parser module is imported; keep this order.
tika.initVM()
from tika import parser

In [None]:
import pandas as pd

In [None]:
from tqdm.auto import tqdm
import json

In [None]:
# Hook tqdm into pandas so the `progress_apply` calls in later cells are available.
tqdm.pandas()

In [None]:
# Load the mapping from document id (file name) to class label that is used to
# build the `target` column. JSON is UTF-8 by specification, so decode it
# explicitly instead of relying on the platform's default encoding, which can
# garble or reject non-ASCII text on some systems.
with Path('hacka-aka-embedika/classes.json').open('r', encoding='utf-8') as f:
    tgt = json.load(f)

In [None]:
# Tokenizer for the same 'sberbank-ai/ruRoberta-large' checkpoint that ModelClassify loads.
base_tokenizer = AutoTokenizer.from_pretrained('sberbank-ai/ruRoberta-large')

In [None]:
from pathlib import Path
# Collect every document under the docs folder.
# `rglob` yields entries in filesystem order and also matches directories whose
# name contains a dot, so keep regular files only and sort them: the row order
# feeds the label indexing and the seeded cross-validation split, and therefore
# has to be stable across runs and machines.
doc_files = sorted(p for p in Path('hacka-aka-embedika/docs').rglob('*.*') if p.is_file())

In [None]:
# One row per document; 'file' initially holds the Path objects collected above.
df = pd.DataFrame({'file': doc_files})

In [None]:
# Use the bare file name as the document id (it is the lookup key into `tgt`).
# Relies on 'file' still holding Path objects, so run this before the str conversion.
df['id'] = df['file'].apply(lambda x: x.name)

In [None]:
# Store the paths as plain strings from here on.
df['file'] = df['file'].apply(str)

In [None]:
# Look up each document's class label by id; an id missing from `tgt` raises here.
df['target'] = df['id'].apply(lambda x: tgt[x])

In [None]:
# NOTE(review): identical to the earlier str conversion of 'file'; a no-op when
# the cells are run in order.
df['file'] = df['file'].apply(str)

In [None]:
# Extract the raw text of every document with tika, showing a progress bar.
# NOTE(review): 'content' is presumably None when tika extracts no text from a
# file - confirm; the later cells assume a str.
df['parsed'] = df['file'].progress_apply(lambda x: parser.from_file(x)['content'])

In [None]:
# Distinct class labels; a label's position in this list becomes its class index.
tgt_items = list(df['target'].unique())

In [None]:
# Encode each label as its position in `tgt_items`.
df['target_idx'] = df['target'].apply(lambda x: tgt_items.index(x))

In [None]:
import re
from typing import Optional

# Raw string literals: '\s' inside a plain literal is an invalid escape sequence
# that newer Python versions warn about.
_re_space = re.compile(r'\s+')
_re_under = re.compile(r'_+')


def preprocess_str(s: Optional[str]) -> str:
    """Normalise extracted document text.

    Every run of whitespace becomes a single space, every run of underscores
    (including a single underscore) becomes exactly two underscores, and
    leading/trailing whitespace is stripped.

    Args:
        s: Raw text; ``None`` or an empty string is accepted.

    Returns:
        The normalised text, or ``''`` when there is no text to normalise.
    """
    if not s:
        # Nothing was extracted for this document; keep the column a plain str.
        return ''
    s = _re_space.sub(' ', s)
    s = _re_under.sub('__', s)
    return s.strip()

In [None]:
# Apply preprocess_str to the extracted text of every document.
df['parsed_preproc'] = df['parsed'].progress_apply(preprocess_str)

In [None]:
# Tokenize each whole document once, without special tokens: MyDataset crops the
# token sequence and adds BOS/EOS itself.
df['tokenization_results'] = df['parsed_preproc'].progress_apply(lambda x: base_tokenizer(x, add_special_tokens=False))

In [None]:
import plotly.express as px

In [None]:
# Histogram of document lengths measured in tokens.
px.histogram(df['tokenization_results'].apply(lambda x: len(x['input_ids'])), nbins=100)

In [None]:
class ModelClassify(nn.Module):
    """Pretrained transformer encoder followed by a linear classification head.

    Args:
        n_cls: Number of output classes (defaults to the original 5).
        model_name: Checkpoint passed to ``AutoModel.from_pretrained``
            (defaults to the original 'sberbank-ai/ruRoberta-large').
    """

    def __init__(self, n_cls: int = 5, model_name: str = 'sberbank-ai/ruRoberta-large'):
        super().__init__()
        self._bert = AutoModel.from_pretrained(model_name)
        self._n_cls = n_cls
        # Size the head from the encoder config instead of a hard-coded 1024 so
        # the head always matches the loaded checkpoint.
        self._cls = nn.Linear(self._bert.config.hidden_size, n_cls)

    def forward(self, input_ids, attention_mask):
        """Return the classification head's output for the encoder's `pooler_output`."""
        x = self._bert(input_ids=input_ids, attention_mask=attention_mask).pooler_output
        x = self._cls(x)
        return x

    def restate_text_enc(self, freeze: bool):
        """Freeze (``freeze=True``) or unfreeze the encoder; the head is not touched."""
        for param in self._bert.parameters():
            param.requires_grad = not freeze

In [None]:
from transformers import DataCollatorWithPadding

# Collate function for the training loop: pads dataset items with base_tokenizer
# and returns PyTorch tensors.
collator = DataCollatorWithPadding(base_tokenizer, return_tensors='pt')

In [None]:
from typing import Optional

from torch.utils.data import Dataset
import random
import copy


# Bounds on the number of document tokens per sample, excluding BOS/EOS.
MIN_TOKENS = 300
MAX_TOKENS = 510


class MyDataset(Dataset):
    """Serves cropped, BOS/EOS-wrapped token windows of pre-tokenized documents.

    Each row of ``df`` must provide ``'tokenization_results'`` (a mapping with
    ``'input_ids'`` and ``'attention_mask'`` lists produced without special
    tokens) and ``'target_idx'``.

    Args:
        df: Frame with one document per row.
        train_mode: When true, every access draws a random window of
            MIN_TOKENS..MAX_TOKENS tokens (or the whole document if it is
            shorter); otherwise the first MAX_TOKENS tokens are used.
        bos_token_id: Id prepended to every sample; defaults to
            ``base_tokenizer.bos_token_id``.
        eos_token_id: Id appended to every sample; defaults to
            ``base_tokenizer.eos_token_id``.
    """

    def __init__(self, df: pd.DataFrame, train_mode: bool,
                 bos_token_id: Optional[int] = None, eos_token_id: Optional[int] = None):
        self._df = df
        self._train_mode = train_mode
        self._bos_token_id = base_tokenizer.bos_token_id if bos_token_id is None else bos_token_id
        self._eos_token_id = base_tokenizer.eos_token_id if eos_token_id is None else eos_token_id

    def __getitem__(self, item):
        """Return a copy of the row's tokenization cropped to one window, plus 'target'."""
        row = self._df.iloc[item]
        # Work on a copy so the cached tokenization in the frame is never modified.
        dct = copy.deepcopy(row['tokenization_results'])
        n_tokens = len(dct['input_ids'])
        if self._train_mode:
            take_n = random.randint(min(MIN_TOKENS, n_tokens), min(MAX_TOKENS, n_tokens))
            # randint is inclusive on both ends: valid start offsets are
            # 0..n_tokens - take_n. The previous upper bound was one lower,
            # which raised ValueError whenever the window covered the whole
            # document, and its slice also skipped token 0 although the
            # documents were tokenized without special tokens.
            take_at = random.randint(0, n_tokens - take_n)
        else:
            take_n = min(MAX_TOKENS, n_tokens)
            take_at = 0
        window = slice(take_at, take_at + take_n)
        dct['input_ids'] = [self._bos_token_id] + dct['input_ids'][window] + [self._eos_token_id]
        # The two extra ones cover the BOS and EOS positions.
        dct['attention_mask'] = [1, 1] + dct['attention_mask'][window]
        dct['target'] = row['target_idx']
        return dct

    def __len__(self):
        return len(self._df)

In [None]:
from sklearn.model_selection import StratifiedKFold
import catboost as cb

# Seeded, shuffled 5-fold stratified splitter.
folder = StratifiedKFold(n_splits=5, shuffle=True, random_state=0xDEADBEEF)

df_idx = df.index

# Baseline: one CatBoost classifier per fold, trained on the preprocessed text
# as a text feature and evaluated on the held-out fold with macro F1.
for i, (train_index, test_index) in enumerate(folder.split(df_idx, df['target_idx'])):
    print('Fold', i)
    train_df = df.loc[df_idx[train_index]]
    val_df = df.loc[df_idx[test_index]]
    cls = cb.CatBoostClassifier(
        loss_function='MultiClass',
        eval_metric='TotalF1:average=Macro',
        depth=7,
        n_estimators=100,
        random_state=0xDEADBEEF,
    )
    cls.fit(
        train_df[['parsed_preproc']],
        train_df['target_idx'],
        eval_set=(val_df[['parsed_preproc']], val_df['target_idx']),
        text_features=['parsed_preproc'],
        plot=False,
        verbose=True,
    )

In [None]:
from collections import defaultdict
from xztrainer.logger.tensorboard import TensorboardLoggingEngineConfig
from transformers import get_linear_schedule_with_warmup, AdamW
from xztrainer.engine.standard import StandardEngineConfig
from xztrainer import XZTrainable, TrainContext, BaseContext, BaseTrainContext, XZTrainer, XZTrainerConfig, \
    SchedulerType, SavePolicy
import torch.nn.functional as F
import numpy as np
import sklearn.metrics as skm

# mdl.restate_text_enc(freeze=True)


class Trainer(XZTrainable):
    """Trainable for the document classifier: cross-entropy loss plus accuracy / macro-F1."""

    def __init__(self, unfreeze_at: int):
        """Remember the step at which the text encoder gets unfrozen and build the loss."""
        super().__init__()
        self.unfreeze_at = unfreeze_at
        self.loss_fn = nn.CrossEntropyLoss()

    def on_update(self, context: TrainContext, step):
        """Unfreeze the model's text encoder when `step` equals `unfreeze_at`."""
        if step != self.unfreeze_at:
            return
        print('Unfreezing model')
        context.model_unwrapped.restate_text_enc(freeze=False)

    def step(self, context: BaseContext, data):
        """Run the model on one batch.

        Returns the loss together with a dict holding the targets, the
        predicted class indices and the class probabilities.
        """
        logits = context.model(data['input_ids'], data['attention_mask'])
        probabilities = F.softmax(logits, dim=1)
        outputs = {
            'target': data['target'],
            'predict': torch.argmax(probabilities, dim=1),
            'predict_proba': probabilities,
        }
        return self.loss_fn(logits, data['target']), outputs

    def calculate_metrics(self, context: BaseContext, model_outputs):
        """Compute mean loss, accuracy and macro-averaged F1 from the collected outputs."""
        targets = model_outputs['target']
        predictions = model_outputs['predict']
        return {
            'loss': np.mean(model_outputs['loss']),
            'accuracy': skm.accuracy_score(targets, predictions),
            'f1_score': skm.f1_score(targets, predictions, average='macro'),
        }

In [None]:
# Fine-tune a fresh ModelClassify on every fold produced by the shared
# stratified splitter; training items are random crops, validation items are
# the head of each document.
# NOTE(review): the encoder is never frozen before training (the
# `restate_text_enc(freeze=True)` call is commented out), so the unfreeze at
# step 100 looks like a no-op - confirm the intent.
for i, (train_index, test_index) in enumerate(folder.split(df_idx, df['target_idx'])):
    print('Fold', i)
    trainer = XZTrainer(XZTrainerConfig(
        engine=StandardEngineConfig(),
        batch_size=2,
        accumulation_batches=1,
        batch_size_eval=4,
        epochs=20,
        # NOTE(review): the original inline comment says "don't clip" while a
        # value of 1 is passed - confirm against xztrainer's semantics.
        gradient_clipping=1, # don't clip
        collate_fn=collator,
        # Linear warmup over the first 10% of the steps, then linear decay.
        scheduler=lambda optim, stps: get_linear_schedule_with_warmup(optim, stps*0.1, stps),
        # transformers.AdamW is deprecated in favour of torch.optim.AdamW;
        # eps=1e-6 keeps the default the transformers implementation used.
        optimizer=lambda m: torch.optim.AdamW(m.parameters(), lr=1e-5, eps=1e-6, weight_decay=0.00001),
        scheduler_type=SchedulerType.STEP,
        shuffle_train_dataset=True,
        dataloader_num_workers=0,
        dataloader_persistent_workers=False,
        save_policy=SavePolicy.EVERY_EPOCH,
        print_steps=0,
        # logger=TensorboardLoggingEngineConfig()
    ),  ModelClassify(), Trainer(unfreeze_at=100), device=torch.device('cuda:0'))
    train_df = df.loc[df_idx[train_index]]
    val_df = df.loc[df_idx[test_index]]

    trainer.train(MyDataset(train_df, True), MyDataset(val_df, False))