In [1]:
import pandas as pd
import numpy as np
from transformers import BertTokenizer, BertForSequenceClassification, Trainer, TrainingArguments
from sklearn.metrics import accuracy_score, precision_recall_fscore_support, classification_report

from transformers.trainer_callback import EarlyStoppingCallback

import torch
from torch.utils.data import Dataset, DataLoader

In [2]:
# All three splits live in the same Kaggle dataset directory; change it here
# to run the notebook against another copy of the data.
DATA_DIR = '/kaggle/input/unsafe2'

df = pd.read_csv(f'{DATA_DIR}/train.csv')
df_val = pd.read_csv(f'{DATA_DIR}/val.csv')
df_test = pd.read_csv(f'{DATA_DIR}/test.csv')

In [3]:
# Peek at the raw training frame: 'text' plus one 0/1 flag column per topic.
df.head()

Unnamed: 0,text,crime_real,crime_web,drugs,gambling,pornography,prostitution,slavery,suicide,terrorism,weapons,body_shaming,halth_shaming,politics,racism,religion,sex_minorities,sexism,social
0,Убийства и мы все знаем что убийца там ☝️,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0
1,...а потом граждане возмущаются что ктото кое ...,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0
2,"Да преступление не тяжкое, могут под домашний ...",1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0
3,Не льсти себе: Вот моя бывшая вообще мило пост...,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0
4,Стать правителем и посадить их всех в тюрьму. ...,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0


In [4]:
# Every column after 'text' is a 0/1 topic flag (crime_real ... social).
columns = list(df.columns)[1:]


def build_label_strings(frame, label_columns):
    """Collapse each row's 0/1 topic flags into one comma-joined label string.

    Args:
        frame: DataFrame containing every column named in ``label_columns``.
        label_columns: ordered topic column names; output keeps this order.

    Returns:
        A list with one string per row of ``frame``: the names of the columns
        whose value equals 1, joined with ',' (e.g. ``'drugs,weapons'``), or
        ``'none'`` when no flag in the row equals 1.
    """
    # Boolean matrix (rows x topics) computed once instead of per-row iterrows.
    flags = frame[label_columns].eq(1).to_numpy()
    # object dtype so selected names stay plain Python str for ','.join.
    names = np.array(label_columns, dtype=object)
    # An empty join means no active topic -> fall back to 'none'.
    return [','.join(names[row_flags]) or 'none' for row_flags in flags]


train_labels = build_label_strings(df, columns)
val_labels = build_label_strings(df_val, columns)
test_labels = build_label_strings(df_test, columns)

In [5]:
# Number of distinct label combinations present in the training split.
len(set(train_labels))        

392

In [6]:
# Pair each split's raw text with the label string derived from its topic flags.
df_train_adjusted, df_val_adjusted, df_test_adjusted = (
    pd.DataFrame({'text': list(split['text']), 'labels': split_labels})
    for split, split_labels in (
        (df, train_labels),
        (df_val, val_labels),
        (df_test, test_labels),
    )
)

In [7]:
# Assign each distinct label combination an integer class id in first-seen
# order: 'none' is always 0, then new combinations from train, test, val.
# NOTE(review): combinations that occur only in test/val also get ids (so the
# encoding step never fails), but the model never sees them during training.
mapping = {'none': 0}
for split_labels in (train_labels, test_labels, val_labels):
    for label in split_labels:
        mapping.setdefault(label, len(mapping))

In [8]:
# Total number of class ids: every combination seen in train/test/val, plus 'none'.
len(mapping)

393

In [9]:
# Encode every label string as its integer class id (unknown strings raise KeyError).
for adjusted_frame in (df_train_adjusted, df_test_adjusted, df_val_adjusted):
    adjusted_frame['class'] = [mapping[label] for label in adjusted_frame['labels']]

In [10]:
# Validation split after encoding: text, comma-joined label string, class id.
df_val_adjusted

Unnamed: 0,text,labels,class
0,Требуют забрать заявление об увольнении с угрозой,crime_real,1
1,"да расстреляют на месте, за сопротивление",crime_real,1
2,Белорусы в форме омона бьют белорусов в штатском,crime_real,1
3,"Одного нужно захватить, чтобы узнать мотивы со...",crime_real,1
4,сел в тюрьму за кражу глазурованного сырка,crime_real,1
...,...,...,...
1012,"ХУЁГЕ БЛЯДЬ, НИКУДА ОН НЕ ПОЕДЕТ, ИДИОТЫ",none,0
1013,Какие таблетки? Просто жри нормально и всё. По...,none,0
1014,"Александр, а если уравновесить право обязаннос...",none,0
1015,Наличие у пользователя полномочий Глобального ...,none,0


In [11]:
# Validation label strings as a plain list.
# NOTE(review): labels_val is not referenced by any later cell shown here;
# consider removing it if nothing else uses it.
labels_val = df_val_adjusted['labels'].tolist()

In [12]:
# Plain Python lists of texts (x_*) and integer class ids (y_*) for each split.
x_train, y_train = (df_train_adjusted[col].tolist() for col in ('text', 'class'))
x_test, y_test = (df_test_adjusted[col].tolist() for col in ('text', 'class'))
x_val, y_val = (df_val_adjusted[col].tolist() for col in ('text', 'class'))

In [14]:
class UnsafeData(Dataset):
    """Map-style dataset that tokenizes one text per item on access.

    Each item is a dict of ``torch.long`` tensors: every field the tokenizer
    returns for the text (requested truncated/padded to ``max_len``) plus
    ``'labels'`` holding that text's class id.
    """

    def __init__(self, texts, targets, tokenizer, max_len):
        super().__init__()
        self.texts = texts            # indexable sequence of raw strings
        self.targets = targets        # class id per text, aligned with texts
        self.max_len = max_len        # max_length passed to the tokenizer
        self.tokenizer = tokenizer    # called as tokenizer(text, ...) -> mapping

    def __len__(self):
        return len(self.texts)

    def __getitem__(self, index):
        encoded = self.tokenizer(
            self.texts[index],
            truncation=True,
            max_length=self.max_len,
            padding='max_length',
        )
        sample = dict()
        for field_name, values in encoded.items():
            sample[field_name] = torch.tensor(values, dtype=torch.long)
        sample['labels'] = torch.tensor(self.targets[index], dtype=torch.long)
        return sample

In [106]:
# Pretrained checkpoint id; used below for both the tokenizer and the model.
model_name = 'DeepPavlov/rubert-base-cased-conversational'

In [107]:
tokenizer = BertTokenizer.from_pretrained(model_name)
# Size the classification head from the label mapping instead of a hard-coded
# count, so it always matches the number of class ids produced above.
model = BertForSequenceClassification.from_pretrained(model_name, num_labels=len(mapping))

Some weights of BertForSequenceClassification were not initialized from the model checkpoint at DeepPavlov/rubert-base-cased-conversational and are newly initialized: ['classifier.weight', 'classifier.bias']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [108]:
# Token length of every example (longer texts are truncated, shorter ones padded
# by UnsafeData); defined once so all three splits stay consistent.
MAX_LEN = 60

train_dataset = UnsafeData(x_train, y_train, tokenizer, max_len=MAX_LEN)
test_dataset = UnsafeData(x_test, y_test, tokenizer, max_len=MAX_LEN)
val_dataset = UnsafeData(x_val, y_val, tokenizer, max_len=MAX_LEN)

In [109]:
# Number of examples in the train / test / val datasets.
len(train_dataset), len(test_dataset), len(val_dataset)

(31130, 1156, 1017)

In [110]:
# Inspect one encoded validation example (token ids, masks and class id).
val_dataset[10]

{'input_ids': tensor([  101,  2270, 47970,   994,   846,  2181,   132,   458,  2396,  7370,
          1536,  1967,   838,  3005,   132,  1235,   322, 19121,   322, 28114,
           846,  2181,   132, 75832,   371,   801,  5827,   130,  1064,   802,
          7134,   322, 37442,   846,  2181,  1981,  6080,   132,  1190,  4302,
           340, 11728,  1143,  2838,  1088, 11757,   132,   102,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0,     0]),
 'token_type_ids': tensor([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]),
 'attention_mask': tensor([1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]),
 'labels': tensor(1)}

In [111]:
# Multi-hot layout: position 0 is 'none', followed by the topic columns in CSV order.
topics_list = ['none', *columns]

In [112]:
# Inverse of `mapping`: class id -> comma-joined label string.
target_vaiables_id2topic_dict = dict(zip(mapping.values(), mapping.keys()))

In [113]:
def adjust_multilabel(y, is_pred=False, id2topic=None, topics=None):
    """Expand single class ids (or score rows) into multi-hot topic vectors.

    Each class id stands for a comma-joined label string such as
    ``'drugs,weapons'``; this decodes it back into a 0/1 vector over
    ``topics`` so multi-label metrics can be computed.

    Args:
        y: iterable of integer class ids, or of per-class score rows when
            ``is_pred`` is true.
        is_pred: if true, each element of ``y`` is reduced to a class id with
            ``np.argmax`` before decoding.
        id2topic: class id -> comma-joined label string. Defaults to the
            notebook-level ``target_vaiables_id2topic_dict``.
        topics: ordered topic names defining the vector layout. Defaults to
            the notebook-level ``topics_list``.

    Returns:
        A list with one list of 0/1 ints per element of ``y``; each inner list
        has length ``len(topics)``.

    Raises:
        KeyError: if a class id or a decoded tag is not known.
    """
    if id2topic is None:
        id2topic = target_vaiables_id2topic_dict
    if topics is None:
        topics = topics_list
    # Tag -> vector position, computed once instead of list.index() per tag.
    position = {topic: i for i, topic in enumerate(topics)}
    y_adjusted = []
    for y_c in y:
        class_id = int(np.argmax(y_c)) if is_pred else y_c
        # Width follows the topic list rather than a hard-coded 19, so adding
        # or removing topic columns cannot silently misalign the vectors.
        multi_hot = [0] * len(topics)
        for tag in id2topic[class_id].split(","):
            multi_hot[position[tag]] = 1
        y_adjusted.append(multi_hot)
    return y_adjusted

In [114]:
def compute_metrics(pred):
    """Trainer metric hook: weighted multi-label scores for one evaluation.

    Gold class ids (``pred.label_ids``) and prediction scores
    (``pred.predictions``, reduced with argmax) are both decoded into
    multi-hot topic vectors before scoring.

    Returns:
        dict with 'accuracy' (exact-match over whole vectors), and weighted
        'f1', 'precision' and 'recall' (zero_division=0).
    """
    y_true = adjust_multilabel(pred.label_ids, is_pred = False)
    y_pred = adjust_multilabel(pred.predictions, is_pred = True)
    precision, recall, f1, _support = precision_recall_fscore_support(
        y_true, y_pred, average='weighted', zero_division = 0
    )
    return dict(
        accuracy=accuracy_score(y_true, y_pred),
        f1=f1,
        precision=precision,
        recall=recall,
    )

In [115]:
# Evaluate and checkpoint on the same cadence. load_best_model_at_end (required
# by EarlyStoppingCallback) can only restore a checkpoint that was saved at an
# evaluation step, and newer transformers versions reject a save_steps that is
# not a round multiple of eval_steps (the old 5000 vs 600 combination).
EVAL_EVERY_STEPS = 600

training_args = TrainingArguments(
    output_dir='/kaggle/working/bert1',
    num_train_epochs=10,
    per_device_train_batch_size=64,
    per_device_eval_batch_size=64,
    logging_steps=EVAL_EVERY_STEPS,
    evaluation_strategy='steps',
    eval_steps=EVAL_EVERY_STEPS,
    save_steps=EVAL_EVERY_STEPS,
    # Keep disk usage bounded on Kaggle; the best checkpoint is retained
    # when load_best_model_at_end is set.
    save_total_limit=2,
    warmup_steps=500,
    weight_decay=0.01,
    logging_dir='/kaggle/working/bert1/logs',
    metric_for_best_model='f1',
    greater_is_better=True,
    load_best_model_at_end=True,
)

In [116]:
# Fine-tune on train_dataset; val_dataset is used for the periodic evaluations
# that feed compute_metrics and best-model selection.
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=val_dataset,
    compute_metrics=compute_metrics,
)

W&B installed but not logged in. Run `wandb login` or set the WANDB_API_KEY env variable.


In [117]:
# Stop once the selected metric (metric_for_best_model) has not improved for
# 4 consecutive evaluations.
trainer.add_callback(EarlyStoppingCallback(early_stopping_patience=4))

In [118]:
# Run fine-tuning with the arguments and callbacks configured above.
trainer.train()

Step,Training Loss,Validation Loss,Accuracy,F1,Precision,Recall,Runtime,Samples Per Second
600,2.6379,2.709408,0.46706,0.552528,0.744309,0.493662,1.9557,520.007
1200,0.8131,2.307045,0.505408,0.643172,0.745199,0.583099,1.9577,519.484
1800,0.5316,2.200252,0.528024,0.667087,0.727995,0.626056,1.9537,520.551
2400,0.3723,2.37711,0.514258,0.664307,0.732823,0.621831,1.9866,511.931
3000,0.2676,2.510347,0.516224,0.676894,0.723119,0.64507,1.9523,520.918
3600,0.2103,2.601495,0.52704,0.691697,0.730733,0.665493,1.9546,520.3
4200,0.1654,2.5956,0.519174,0.697182,0.720685,0.678873,1.9863,511.996
4800,0.1395,2.692417,0.534907,0.695313,0.734905,0.666901,1.9537,520.538


TrainOutput(global_step=4870, training_loss=0.6350461391942457, metrics={'train_runtime': 2204.6277, 'train_samples_per_second': 2.209, 'total_flos': 19965548168676000, 'epoch': 10.0})

In [128]:
# Metrics of the final model on the trainer's eval set (val_dataset).
# NOTE(review): val_dataset also drives early stopping and best-model
# selection, so these scores are optimistic; the test-set report below is
# the held-out estimate.
trainer.evaluate()

{'eval_loss': 2.595600128173828,
 'eval_accuracy': 0.5191740412979351,
 'eval_f1': 0.6971824976816428,
 'eval_precision': 0.7206845731822252,
 'eval_recall': 0.6788732394366197,
 'eval_runtime': 1.9869,
 'eval_samples_per_second': 511.849,
 'epoch': 10.0}

In [127]:
#!rm -r ./bert1

In [126]:
# Save the trainer's current model to the 'multi-class' directory (relative to the CWD).
# NOTE(review): no tokenizer was passed to Trainer, so presumably it is not
# saved here; save it separately if the directory must be self-contained.
trainer.save_model('multi-class')

Оценка на val_dataset

In [120]:
# Run inference on the validation split.
pred = trainer.predict(val_dataset)

In [121]:
# Raw per-class prediction scores for the validation split.
pr = pred.predictions

In [122]:
# Per-topic precision/recall/F1 on the validation split, after decoding both
# gold class ids and prediction scores into multi-hot topic vectors.
val_true_multi_hot = adjust_multilabel(y_val, is_pred = False)
val_pred_multi_hot = adjust_multilabel(pr, is_pred = True)
print(classification_report(val_true_multi_hot, val_pred_multi_hot,
                            target_names=topics_list, zero_division = 0))

                precision    recall  f1-score   support

          none       0.65      0.59      0.62       188
    crime_real       0.62      0.58      0.60        76
     crime_web       0.50      0.48      0.49        23
         drugs       0.79      0.86      0.83        58
      gambling       0.40      1.00      0.57         2
   pornography       0.71      0.73      0.72       128
  prostitution       0.82      0.76      0.79        55
       slavery       0.80      0.73      0.76        22
       suicide       0.75      0.75      0.75         4
     terrorism       0.50      0.42      0.46        26
       weapons       0.87      0.95      0.91        96
  body_shaming       0.85      0.74      0.79        68
 halth_shaming       0.79      0.72      0.75        72
      politics       0.69      0.56      0.62       159
        racism       0.79      0.72      0.75       127
      religion       0.90      0.84      0.87        63
sex_minorities       0.73      0.55      0.63  

In [None]:
#t_names = [key for key, val in mapping.items() if val in y_val]

Оценка на test_dataset

In [123]:
# Run inference on the held-out test split.
pred2 = trainer.predict(test_dataset)

In [124]:
# Raw per-class prediction scores for the test split.
pr2 = pred2.predictions

In [None]:
#t_names2 = [key for key, val in mapping.items() if val in y_test]

In [125]:
# Per-topic precision/recall/F1 on the test split, after decoding both gold
# class ids and prediction scores into multi-hot topic vectors.
test_true_multi_hot = adjust_multilabel(y_test, is_pred = False)
test_pred_multi_hot = adjust_multilabel(pr2, is_pred = True)
print(classification_report(test_true_multi_hot, test_pred_multi_hot,
                            target_names=topics_list, zero_division = 0))

                precision    recall  f1-score   support

          none       0.61      0.59      0.60       188
    crime_real       0.63      0.49      0.55       108
     crime_web       0.46      0.46      0.46        28
         drugs       0.84      0.87      0.85        70
      gambling       0.42      0.83      0.56         6
   pornography       0.68      0.59      0.63       163
  prostitution       0.69      0.65      0.67        77
       slavery       0.78      0.64      0.70        33
       suicide       0.50      0.33      0.40         6
     terrorism       0.66      0.49      0.56        39
       weapons       0.83      0.86      0.84       107
  body_shaming       0.82      0.75      0.78        89
 halth_shaming       0.82      0.60      0.69        85
      politics       0.67      0.61      0.64       191
        racism       0.85      0.66      0.74       163
      religion       0.91      0.70      0.79        83
sex_minorities       0.68      0.51      0.59  