In [4]:
import json

import numpy as np
import sklearn
import sklearn.metrics
import torch
from datasets import load_dataset
from torch import nn
from torch.utils.data import DataLoader
from transformers import (AutoTokenizer, AutoModel, AutoConfig,
                          DataCollatorWithPadding)
from transformers.modeling_outputs import SequenceClassifierOutput

In [5]:
# Notebook configuration. The "?" strings are placeholders that must be
# replaced with real locations before the notebook can run.
TOKEN_PATH = "?"  # passed to AutoTokenizer.from_pretrained
MODEL_PATH = "?"  # passed to AutoModel / AutoConfig.from_pretrained
DATA_PATH = "?"  # CSV file handed to load_dataset("csv", ...)
WEIGHTS_PATH = "?"  # checkpoint read with torch.load; must contain a "state_dict" entry
CLASS_JSON = "?"  # JSON file whose values are used as class display names
NUM_CLASSES = 30  # length of the one-hot label vector built by one_hotting

In [6]:
# Read the evaluation data with the generic CSV builder of `datasets`.
dataset = load_dataset("csv", data_files=DATA_PATH)

In [7]:
def one_hotting(example, num_classes=None):
    """Replace a comma-separated label string with a multi-hot vector.

    Args:
        example: Mapping with a ``'labels'`` entry. When that entry is a
            string such as ``"3, 17"`` it is parsed as a list of integer
            class ids; any other type is left untouched.
        num_classes: Length of the produced vector. Defaults to the
            notebook-level ``NUM_CLASSES`` when omitted.

    Returns:
        The same ``example`` object, with ``example['labels']`` replaced by a
        list of ``num_classes`` floats (1.0 at every listed class id, 0.0
        elsewhere) when the original value was a string.

    Raises:
        ValueError: If a non-empty token is not an integer.
        IndexError: If a class id lies outside ``[0, num_classes)``.
    """
    if num_classes is None:
        num_classes = NUM_CLASSES

    raw_labels = example['labels']
    if not isinstance(raw_labels, str):
        return example

    one_hot = np.zeros(num_classes)
    # Split on the bare comma and strip each token so that "1,2", "1, 2" and
    # "1 , 2" are all accepted; empty tokens (e.g. a trailing comma) are skipped.
    for token in raw_labels.split(","):
        token = token.strip()
        if not token:
            continue
        class_id = int(token)
        # Reject negative ids explicitly: numpy would otherwise wrap them
        # around and silently mark the wrong class.
        if not 0 <= class_id < num_classes:
            raise IndexError(
                f"label id {class_id} is outside the range [0, {num_classes})"
            )
        one_hot[class_id] = 1

    example['labels'] = one_hot.tolist()
    return example

# Non-batched map: one_hotting receives one example at a time and the result
# is stored under a new name so the raw `dataset` stays available.
mod_dataset = dataset.map(one_hotting)

In [8]:
tokenizer = AutoTokenizer.from_pretrained(TOKEN_PATH)


def preprocess_function(examples):
    """Tokenize the ``"texts"`` column of a batch of rows.

    Sequences are truncated to the tokenizer's maximum length but are NOT
    padded here. Padding inside a batched ``map`` would stretch every row to
    the longest text of its whole map batch; per-batch padding is left to the
    data collator used by the DataLoader instead.

    Args:
        examples: Batch mapping that contains a ``"texts"`` list of strings.

    Returns:
        The tokenizer output for the batch (e.g. ``input_ids`` and
        ``attention_mask``), with rows of varying length.
    """
    return tokenizer(examples["texts"], truncation=True)


encoded_dataset = mod_dataset.map(preprocess_function, batched=True)

In [9]:
# Expose only the model inputs and the targets, returned as torch tensors.
# Note: set_format changes `encoded_dataset` in place.
encoded_dataset.set_format(
    type="torch",
    columns=["input_ids", "attention_mask", "labels"],
)

In [10]:
# Pads every batch to the length of its longest sequence.
# `pad_to_multiple_of` expects an integer (or None); the boolean True that was
# passed before is treated as 1 and therefore had no effect. None states the
# same behaviour explicitly. Use e.g. 8 to align lengths for tensor cores.
collater = DataCollatorWithPadding(
    tokenizer=tokenizer,
    padding=True,
    pad_to_multiple_of=None,
)

In [11]:
# Evaluation runs over the "train" split in its stored order. The order must
# not be shuffled because the predictions are later compared positionally
# against encoded_dataset["train"]["labels"].
eval_dataloader = DataLoader(
    encoded_dataset["train"],
    batch_size=10,
    shuffle=False,
    collate_fn=collater,
)

In [12]:
# Load the configuration as a real config object.
# The previous call used return_unused_kwargs=True, which makes
# AutoConfig.from_pretrained return a (config, unused_kwargs) tuple. A tuple is
# not a valid `config=` argument, so AutoModel ignored it and reloaded the
# default configuration, silently dropping output_hidden_states=True.
# trust_remote_code mirrors the flag used for the model itself.
rubert_config = AutoConfig.from_pretrained(
    MODEL_PATH,
    trust_remote_code=True,
    # Now actually honoured. CustomModel only reads `last_hidden_state`, so
    # this flag can be removed if the per-layer hidden states are not needed.
    output_hidden_states=True,
)

rubert_model = AutoModel.from_pretrained(
    MODEL_PATH,
    trust_remote_code=True,
    config=rubert_config,
)

In [13]:
class CustomModel(nn.Module):
    """Multi-label classification head on top of a pretrained encoder.

    The hidden state at sequence position 0 (assumed to be the [CLS] token of
    a BERT-style encoder) is passed through dropout and a linear layer that
    produces one logit per label.

    Attribute names (``model``, ``dropout``, ``classifier``) determine the
    ``state_dict`` keys and must stay stable for existing checkpoints to load.

    Args:
        model: Encoder module exposing ``config.hidden_size`` and returning an
            object with a ``last_hidden_state`` attribute.
        num_labels: Number of independent binary labels to predict.
    """

    def __init__(self, model, num_labels):
        super().__init__()
        self.num_labels = num_labels
        self.model = model
        self.dropout = nn.Dropout(0.2)
        self.classifier = nn.Linear(model.config.hidden_size, num_labels)

    def forward(self, input_ids, attention_mask, token_type_ids=None, labels=None):
        """Run the encoder and the classification head.

        Args:
            input_ids: Token id tensor forwarded to the encoder.
            attention_mask: Attention mask tensor forwarded to the encoder.
            token_type_ids: Optional segment ids; forwarded only when given.
            labels: Optional multi-hot targets reshapeable to
                ``(-1, num_labels)``. When given, a BCE-with-logits loss is
                computed.

        Returns:
            SequenceClassifierOutput with ``logits`` and, if ``labels`` was
            provided, ``loss`` (otherwise ``loss`` is None).
        """
        encoder_inputs = {
            "input_ids": input_ids,
            "attention_mask": attention_mask,
        }
        # Only forward segment ids when they exist, so encoders whose forward
        # does not accept `token_type_ids` keep working.
        if token_type_ids is not None:
            encoder_inputs["token_type_ids"] = token_type_ids
        outputs = self.model(**encoder_inputs)

        # Take position 0 first, then apply dropout: dropout is element-wise,
        # so there is no need to run it over every token of the sequence.
        first_token_state = outputs.last_hidden_state[:, 0, :]
        logits = self.classifier(self.dropout(first_token_state))

        loss = None
        if labels is not None:
            # BCEWithLogitsLoss needs floating-point targets of the same dtype
            # as the logits; cast so integer or float64 labels do not fail.
            targets = labels.view(-1, self.num_labels).to(logits.dtype)
            loss_fct = nn.BCEWithLogitsLoss()
            loss = loss_fct(logits.view(-1, self.num_labels), targets)

        return SequenceClassifierOutput(loss=loss, logits=logits)

In [14]:
# Prefer the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Use the shared NUM_CLASSES constant instead of a second hard-coded 30 so the
# head size cannot drift from the one-hot label width.
model = CustomModel(rubert_model, NUM_CLASSES).to(device)

In [15]:
# map_location remaps the stored tensors onto the device selected above, so a
# checkpoint saved on a GPU still loads on a CPU-only machine.
# NOTE(review): torch.load unpickles the file; only load checkpoints from a
# trusted source, and consider weights_only=True if the checkpoint holds
# nothing but tensors and plain containers.
stats = torch.load(WEIGHTS_PATH, map_location=device)

model.load_state_dict(stats["state_dict"])

  stats = torch.load(WEIGHTS_PATH)


<All keys matched successfully>

In [16]:
# Collect hard 0/1 predictions for every batch of the evaluation loader.
model.eval()  # disable dropout
preds = []

# A single no-grad scope around the whole loop: nothing here needs autograd.
with torch.no_grad():
    for batch in eval_dataloader:
        batch = {name: tensor.to(device) for name, tensor in batch.items()}
        outputs = model(**batch)

        logits = outputs.logits
        activation = torch.sigmoid(logits)
        # A label is predicted when its sigmoid probability reaches 0.5.
        predictions = (activation >= 0.5).float()
        preds.append(predictions)

# Stack the per-batch results into one (num_examples, num_labels) tensor.
preds = torch.cat(preds, dim=0)

  attn_output = torch.nn.functional.scaled_dot_product_attention(


In [17]:
# Read the class-name mapping. The encoding is given explicitly because the
# class names contain Cyrillic text and the platform default encoding is not
# guaranteed to be UTF-8.
with open(CLASS_JSON, encoding="utf-8") as f:
    d = json.load(f)

# Display names in file order.
# NOTE(review): assumes the JSON entries are listed in class-index order
# (0 .. NUM_CLASSES - 1), since they are used positionally as target_names.
labels = list(d.values())

In [18]:
# Per-class precision / recall / F1 of the thresholded predictions against the
# multi-hot ground truth. print() is needed so the report's newlines render.
# zero_division=0 reports 0 for classes without predicted or true samples,
# the same value as the default, without one UndefinedMetricWarning per metric.
print(sklearn.metrics.classification_report(
    encoded_dataset["train"]["labels"].cpu().numpy(),
    preds.cpu().numpy(),
    target_names=labels,
    zero_division=0,
))

                                                                                                                  precision    recall  f1-score   support

                                                                                                           Other       0.97      0.97      0.97       332
                                                                           Наименование юридического лица или ИП       1.00      0.91      0.95        34
                                                                                               Юридический адрес       1.00      1.00      1.00        19
                                                                                             Контактные телефоны       1.00      1.00      1.00         9
                                                                                                          E-mail       1.00      0.96      0.98        23
                                                                           

  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
