In [1]:
import os
import re
import random
import pickle
import warnings
warnings.filterwarnings("ignore")

import numpy as np
from tqdm.auto import tqdm

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset

from transformers import AutoTokenizer, AutoModel, AutoConfig
from transformers import get_linear_schedule_with_warmup, get_cosine_schedule_with_warmup
from transformers import DataCollatorWithPadding
%env TOKENIZERS_PARALLELISM=false

# Run on the GPU when one is visible, otherwise fall back to the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

env: TOKENIZERS_PARALLELISM=false


In [2]:
class CFG:
    """Inference settings shared by the 11-class and 39-class classifiers.

    A ``tokenizer`` attribute is attached to this class later, once the
    tokenizer has been loaded from disk.
    """
    # Directories holding each fine-tuned checkpoint and its artifacts.
    path_11 = "/kaggle/input/me5_11_cls/pytorch/default/1/kaggle/working/me5_instruct_11_classes"
    path_39 = "/kaggle/input/me5_39_cls/pytorch/default/1/kaggle/working/me5_instruct_39_classes"
    # torch-serialized backbone configs stored next to each checkpoint.
    config_path_11 = path_11 + '/config.pth'
    config_path_39 = path_39 + '/config.pth'
    # Hugging Face model id of the backbone; also used to build checkpoint file names.
    model = "intfloat/multilingual-e5-large-instruct"
    gradient_checkpointing = False
    max_len = 512
    batch_size = 32
    num_workers = 8
    seed = 42


In [3]:
def seed_everything(seed=42):
    """Seed every random number generator this notebook relies on.

    Seeds Python's ``random``, NumPy's legacy global RNG and PyTorch (CPU and
    every CUDA device), and puts cuDNN into deterministic mode.

    Args:
        seed: integer seed applied to all generators.
    """
    random.seed(seed)
    # Exported for child processes only: the running interpreter's str-hash
    # seed is fixed at start-up and is not changed by this assignment.
    os.environ['PYTHONHASHSEED'] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    # Explicitly cover every visible GPU, not only the current one.
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    # cuDNN autotuning may pick different kernels from run to run, which
    # defeats `deterministic`; make sure it is off.
    torch.backends.cudnn.benchmark = False
    
# Use the seed from CFG so there is a single source of truth for it.
seed_everything(seed=CFG.seed)

In [4]:
# This must be the same preprocessing that was applied to the training data.
def preprocess(text):
    """Strip every character outside the allowed set.

    Allowed characters: Cyrillic letters (including ё/Ё), Latin letters,
    digits, spaces and ``-.,?!+``. Maximal runs of allowed characters are
    kept and re-joined with a single space. A disallowed character between
    two runs therefore turns into a space, while leading or trailing ones
    simply vanish. Text made only of disallowed characters becomes ``""``.
    """
    allowed = r"[а-яА-Я0-9 ёЁ\-\.,?!+a-zA-Z]+"
    return " ".join(match.group(0) for match in re.finditer(allowed, text))

def get_detailed_instruct(task_description: str, query: str) -> str:
    """Format a query in the prompt layout used with instruct-style models.

    Returns two lines: ``Instruct: <task_description>`` followed by
    ``Query: <query>``.
    """
    template = "Instruct: {}\nQuery: {}"
    return template.format(task_description, query)

In [5]:
def prepare_input(cfg, text):
    """Tokenize one text into a dict of ``torch.long`` tensors.

    Args:
        cfg: config object exposing ``tokenizer`` (a Hugging Face tokenizer)
            and ``max_len`` (sequences longer than this are truncated).
        text: the string to encode (already wrapped in its prompt).

    Returns:
        dict mapping each tokenizer output name (e.g. ``input_ids``,
        ``attention_mask``) to a 1-D ``torch.long`` tensor.

    No padding is applied here, so callers must batch with a padding collator.
    The DataLoaders in this notebook use ``DataCollatorWithPadding(padding='longest')``,
    which pads each batch only to its longest member. The previous
    pad-everything-to-512 approach (via the deprecated ``pad_to_max_length``)
    made that collator a no-op and pushed hundreds of pad tokens through the
    encoder for a one-sentence query. Pad positions are excluded by the
    attention mask, so logits should only differ by floating-point noise.
    """
    encoded = cfg.tokenizer.encode_plus(
        text,
        return_tensors=None,
        add_special_tokens=True,
        max_length=cfg.max_len,
        truncation=True,
    )
    return {name: torch.tensor(values, dtype=torch.long) for name, values in encoded.items()}


class TestDataset(Dataset):
    """Map-style dataset that tokenizes its texts lazily via ``prepare_input``.

    Args:
        cfg: config object forwarded to ``prepare_input``.
        texts: a single string, or any iterable of strings (list, tuple,
            numpy array, pandas Series, ...).
    """

    def __init__(self, cfg, texts):
        self.cfg = cfg
        # A bare string is one sample; anything else is a sequence of samples.
        # Checking only for `list` used to wrap a tuple/array/Series whole
        # into a single bogus "text".
        self.texts = [texts] if isinstance(texts, str) else list(texts)

    def __len__(self):
        return len(self.texts)

    def __getitem__(self, item):
        # Tokenization happens per item, when the DataLoader requests it.
        return prepare_input(self.cfg, self.texts[item])

In [6]:
def average_pool(last_hidden_states, attention_mask):
    """Mean of the hidden states over positions where ``attention_mask`` is non-zero.

    Masked positions are zeroed (not multiplied) before summing along dim 1.
    The sum is then divided by the per-row count of mask entries.
    """
    keep = attention_mask.unsqueeze(-1).bool()
    zeroed = last_hidden_states.masked_fill(~keep, 0.0)
    token_counts = attention_mask.sum(dim=1).unsqueeze(-1)
    return zeroed.sum(dim=1) / token_counts


class CustomModel(nn.Module):
    """Transformer backbone + masked mean pooling + linear classification head.

    Args:
        cfg: config object; reads ``cfg.model`` (Hugging Face model id) and
            ``cfg.gradient_checkpointing``.
        num_classes: number of output logits.
        config_path: path to a ``torch.save``-d backbone config. When ``None``
            the config is fetched for ``cfg.model`` with all dropout set to 0.
        pretrained: if True, load backbone weights with
            ``AutoModel.from_pretrained``; otherwise build it from the config
            with ``AutoModel.from_config``. In this notebook the weights are
            then replaced with ``load_state_dict``.
    """

    def __init__(self, cfg, num_classes, config_path=None, pretrained=False):
        super().__init__()
        self.cfg = cfg
        if config_path is None:
            self.config = AutoConfig.from_pretrained(
                cfg.model, output_hidden_states=True)
            self.config.hidden_dropout = 0.
            self.config.hidden_dropout_prob = 0.
            self.config.attention_dropout = 0.
            self.config.attention_probs_dropout_prob = 0.
            # Log through the stdlib logger: no global LOGGER exists in this
            # notebook, so referencing one raised NameError on this branch.
            import logging
            logging.getLogger(__name__).info(self.config)
        else:
            # The file holds a pickled config *object*, not just tensors.
            # torch.load defaults to weights_only=True since PyTorch 2.6,
            # which refuses such objects, so opt out explicitly (the kwarg
            # needs torch >= 1.13). Unpickling can run arbitrary code: only
            # point this at trusted files.
            self.config = torch.load(config_path, weights_only=False)
        if pretrained:
            self.model = AutoModel.from_pretrained(
                cfg.model, config=self.config)
        else:
            self.model = AutoModel.from_config(self.config)
        if self.cfg.gradient_checkpointing:
            self.model.gradient_checkpointing_enable()

        self.fc = nn.Linear(self.config.hidden_size, num_classes)
        self._init_weights(self.fc)

    def _init_weights(self, module):
        """Initialize Linear/Embedding/LayerNorm layers in place.

        Linear and Embedding weights are drawn from
        N(0, ``config.initializer_range``). Biases and the embedding padding
        row are zeroed. LayerNorm is reset to weight=1, bias=0.
        """
        if isinstance(module, nn.Linear):
            module.weight.data.normal_(
                mean=0.0, std=self.config.initializer_range)
            if module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.Embedding):
            module.weight.data.normal_(
                mean=0.0, std=self.config.initializer_range)
            if module.padding_idx is not None:
                module.weight.data[module.padding_idx].zero_()
        elif isinstance(module, nn.LayerNorm):
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)

    def feature(self, inputs):
        """Encode ``inputs`` and mean-pool the last hidden state.

        ``inputs`` is a mapping of keyword tensors for the backbone; it must
        contain ``attention_mask``, which also drives the pooling.
        Returns one pooled embedding per input row.
        """
        outputs = self.model(**inputs)
        feature = average_pool(outputs.last_hidden_state,
                               inputs['attention_mask'])
        return feature

    def forward(self, inputs):
        """Return raw classification logits (no softmax) for ``inputs``."""
        feature = self.feature(inputs)
        output = self.fc(feature)

        return output

In [7]:
def inference_fn(test_loader, model, device):
    """Run ``model`` over every batch of ``test_loader`` without gradients.

    Puts the model in eval mode and moves it to ``device``. Each batch
    (a mapping of tensors) is moved to ``device`` in place before the
    forward pass.

    Returns:
        numpy array of the per-batch model outputs concatenated along axis 0.
    """
    model.eval()
    model.to(device)
    batch_outputs = []
    progress = tqdm(test_loader, total=len(test_loader))
    with torch.no_grad():
        for batch in progress:
            for name in list(batch.keys()):
                batch[name] = batch[name].to(device)
            logits = model(batch)
            batch_outputs.append(logits.to('cpu').numpy())
    return np.concatenate(batch_outputs)

In [8]:
# One tokenizer (saved with the 11-class checkpoint) serves both models.
# NOTE(review): assumes the 39-class run used the same tokenizer — confirm.
tokenizer_dir = os.path.join(CFG.path_11, 'tokenizer')
CFG.tokenizer = AutoTokenizer.from_pretrained(tokenizer_dir)

In [9]:
# Example user request, cleaned with the same preprocessing as the training data.
raw_query = "Здравствуйте! Где я могу узнать про монетизацию RUTUBE?"
user_query = preprocess(raw_query)
print(user_query)

Здравствуйте! Где я могу узнать про монетизацию RUTUBE?


In [10]:
# Instruct checkpoints expect each query wrapped in an instruction prompt
# (presumably the same one used in training — confirm). Order matters:
# index 0 feeds the 11-class model, index 1 the 39-class model.
INSTRUCT_MODELS = ['intfloat/multilingual-e5-large-instruct']
if CFG.model not in INSTRUCT_MODELS:
    # Fail here with a clear message rather than with an IndexError on the
    # empty prompt list further down.
    raise ValueError(f"No prompt template defined for model {CFG.model!r}")

task = """Classify the detailed category of the given user request into one of {num_cats} categories"""
user_queries = [
    get_detailed_instruct(task.format(num_cats=num_cats), user_query)
    for num_cats in ['eleven', 'thirty nine']
]

print(user_queries[1])

Instruct: Classify the detailed category of the given user request into one of thirty nine categories
Query: Здравствуйте! Где я могу узнать про монетизацию RUTUBE?


In [11]:
# One single-sample dataset per model, each holding the prompt built for its class count.
test_dataset_11, test_dataset_39 = (TestDataset(CFG, user_queries[idx]) for idx in (0, 1))

In [12]:
def make_test_loader(dataset):
    """Build an ordered (unshuffled) DataLoader that pads each batch to its longest sample."""
    return DataLoader(
        dataset,
        batch_size=CFG.batch_size,
        shuffle=False,
        collate_fn=DataCollatorWithPadding(tokenizer=CFG.tokenizer, padding='longest'),
        # Spawning more worker processes than there are samples only adds
        # start-up cost (here the datasets hold a single query each).
        num_workers=min(CFG.num_workers, len(dataset)),
        # Pinned host memory only helps host-to-GPU copies.
        pin_memory=device.type == 'cuda',
        drop_last=False,
    )


test_loader_11 = make_test_loader(test_dataset_11)
test_loader_39 = make_test_loader(test_dataset_39)

In [13]:
def load_finetuned_model(num_classes, model_dir, config_path, fold=0):
    """Build a ``CustomModel`` and load its fine-tuned weights (on CPU).

    Args:
        num_classes: output size of the classification head.
        model_dir: directory containing ``<model-id>_fold<fold>_best.pth``.
        config_path: path of the saved backbone config for this checkpoint.
        fold: fold index embedded in the checkpoint file name.

    Returns:
        the model with the checkpoint's ``'model'`` state dict loaded.
    """
    model = CustomModel(CFG, num_classes=num_classes, config_path=config_path, pretrained=False)
    checkpoint_path = os.path.join(model_dir, f"{CFG.model.replace('/', '-')}_fold{fold}_best.pth")
    # NOTE(review): since torch 2.6 torch.load defaults to weights_only=True.
    # That works only if the checkpoint holds plain tensors/containers —
    # confirm, otherwise pass weights_only=False for this trusted file.
    state = torch.load(checkpoint_path, map_location=torch.device('cpu'))
    model.load_state_dict(state['model'])
    # `state` goes out of scope on return instead of lingering as a global,
    # so its separate copy of the weights can be freed.
    return model


model_11 = load_finetuned_model(11, CFG.path_11, CFG.config_path_11)
model_39 = load_finetuned_model(39, CFG.path_39, CFG.config_path_39)

<All keys matched successfully>

In [14]:
# Each model scores the prompt variant built for its own label set.
predictions_11 = inference_fn(test_loader=test_loader_11, model=model_11, device=device)
predictions_39 = inference_fn(test_loader=test_loader_39, model=model_39, device=device)

  0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

In [15]:
# Index of the highest-scoring class for every row of logits.
final_labels_11 = list(np.argmax(predictions_11, axis=-1))
final_labels_39 = list(np.argmax(predictions_39, axis=-1))

In [16]:
def load_label_encoder(model_dir):
    """Load the pickled label encoder (``executor_le.pkl``) stored in ``model_dir``."""
    # pickle can execute arbitrary code while loading: only use trusted artifacts.
    with open(os.path.join(model_dir, "executor_le.pkl"), "rb") as f:
        return pickle.load(f)


# Paths come from CFG so they cannot drift from the checkpoint directories.
exec_le_11 = load_label_encoder(CFG.path_11)
exec_le_39 = load_label_encoder(CFG.path_39)

le_final_labels_11 = exec_le_11.inverse_transform(final_labels_11)
le_final_labels_39 = exec_le_39.inverse_transform(final_labels_39)

In [17]:
# Decoded label predicted by the 11-class model.
le_final_labels_11

array(['МОНЕТИЗАЦИЯ'], dtype='<U34')

In [18]:
# Decoded label predicted by the 39-class model.
le_final_labels_39

array(['Отключение/подключение монетизации'], dtype='<U38')