In [None]:
!pip install dvach.api
!pip install transformers
!pip install sentencepiece

In [None]:
import torch
from matplotlib import pyplot as plt
import numpy as np
import os
import copy
import random
import sentencepiece

from sklearn.metrics.pairwise import cosine_similarity

import transformers

In [None]:
OMIKRON = 1  # global scale applied to every reward in get_reward
MAX_OUTPUT_LEN = 512  # `max_length` passed to model.generate

# Toxicity classifier used to score generated text.
classifier_name = 'SkolkovoInstitute/russian_toxicity_classifier'
classifier = transformers.BertForSequenceClassification.from_pretrained(classifier_name)
classifier_tokenizer = transformers.BertTokenizer.from_pretrained(classifier_name)

# Generator: a RoBERTa checkpoint re-configured as a causal (decoder) LM.
model_name = 'blinoff/roberta-base-russian-v0'
model_tokenizer = transformers.RobertaTokenizer.from_pretrained(model_name)
model_config = transformers.RobertaConfig.from_pretrained(model_name)
model_config.is_decoder = True
model = transformers.RobertaForCausalLM.from_pretrained(model_name, config=model_config)

# `torch.cuda.is_available` must be *called*: the bare function object is
# always truthy. `device` must also be defined when no GPU is present.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
if device.type == 'cuda':
  print('Now on gpu!')
model.to(device)
classifier.to(device)

# Build the optimizer after moving the model, so it holds the final parameters.
optimizer = torch.optim.AdamW(model.parameters(), lr=5e-5)

In [None]:
X, Y = 0.1, 0.5  # similarity threshold X; low similarities are remapped into [-1, -Y]

def get_reward(similarity, toxicity, omikron=OMIKRON):
  """Combine a similarity score and a toxicity score into a scalar reward.

  Similarities below the threshold ``X`` are linearly remapped from
  ``[-1, X]`` onto ``[-1, -Y]``. Poor reconstructions therefore receive a
  clearly negative reward. There is a deliberate jump at ``X``: just below
  it the remapped value approaches ``-Y``; at ``X`` it is ``X`` itself.

  Toxicity then penalises the (remapped) similarity by the factor
  ``toxicity + 1``:
    - a positive value is divided by the factor;
    - a negative value is multiplied by it.
  A toxic prediction can therefore never *raise* the reward. (Dividing a
  negative value would have moved it towards zero.)

  Args:
    similarity: similarity between input and generated text.
    toxicity: toxicity score, presumably 0.0 (neutral) or 1.0 (toxic) as
      produced by get_toxicity — TODO confirm for other classifiers.
    omikron: global reward scale.

  Returns:
    The scaled, toxicity-penalised reward.
  """
  if similarity < X:
    similarity = similarity * (1 - Y) / (1 + X) - (X + Y) / (1 + X)
  penalty = toxicity + 1
  if similarity >= 0:
    return omikron * similarity / penalty
  return omikron * similarity * penalty

def get_similarity(y, y_pred):
  """Cosine similarity between two token-id sequences.

  The shorter sequence is right-padded with zeros to the length of the longer
  one, then the first rows are compared with sklearn's ``cosine_similarity``.

  Args:
    y, y_pred: 2-D tensors of token ids (batch, seq_len). Presumably batch
      is 1, as produced by ``encode`` — only the first row pair is returned.

  Returns:
    The cosine similarity of the first rows, as returned by sklearn.
  """
  a = y.clone().detach().cpu()
  b = y_pred.clone().detach().cpu()
  # Cosine similarity is symmetric, so always pad whichever one is shorter.
  if a.shape[1] > b.shape[1]:
    a, b = b, a
  # Pad only the last dimension. Unlike concatenating a float32 zeros tensor
  # of hard-coded shape [1, n], this keeps the batch size and integer dtype.
  a = torch.nn.functional.pad(a, (0, b.shape[1] - a.shape[1]))
  return cosine_similarity(a, b)[0][0]

def get_toxicity(vector):
  """Classify encoded text with the toxicity classifier.

  Args:
    vector: tensor of classifier token ids (batch, seq_len), as produced by
      ``encode(..., classifier_tokenizer)``.

  Returns:
    The index of the highest-scoring class for the first batch item, as a
    float. For this classifier presumably 0.0 = neutral, 1.0 = toxic —
    confirm against the model's label mapping.
  """
  # Generated text can re-tokenize to more ids than the classifier has
  # position embeddings for; truncate instead of crashing.
  max_len = classifier.config.max_position_embeddings
  vector = vector[:, :max_len].to(device)
  # The classifier is only used for scoring: don't build an autograd graph.
  with torch.no_grad():
    logits = classifier(vector)['logits']
  return logits.argmax(1).float().tolist()[0]

def encode(sentence, tokenizer):
  """Tokenize ``sentence`` with ``tokenizer``, returning a PyTorch tensor of ids."""
  token_ids = tokenizer.encode(sentence, return_tensors='pt')
  return token_ids

def decode(y_pred, tokenizer):
  """Decode the first sequence of a batch of token ids back into text."""
  first_sequence = y_pred[0]
  return tokenizer.decode(first_sequence)


In [None]:
out_time = 10  # log every `out_time`-th call of train_one_step
timer = 9      # call counter; starts at 9 so the very first call is logged


def _generate_continuation(text):
  """Sample a continuation of ``text`` from ``model``.

  Returns (prompt_ids, sequence_ids, continuation_text):
    - prompt_ids: the prompt's token ids on ``device``;
    - sequence_ids: the full generated sequence (prompt + continuation);
    - continuation_text: the decoded continuation only.
  """
  prompt_ids = model_tokenizer(text, add_special_tokens=False, return_tensors="pt")['input_ids'].to(device)
  sequence_ids = model.generate(prompt_ids, max_length=MAX_OUTPUT_LEN, do_sample=True, top_p=0.95, top_k=60)
  # Decode only the newly generated tokens. Slicing the decoded string by the
  # prompt's character length breaks whenever detokenisation does not
  # reproduce the prompt exactly.
  continuation_ids = sequence_ids[0, prompt_ids.shape[1]:]
  continuation_text = model_tokenizer.decode(continuation_ids, skip_special_tokens=True).strip()
  return prompt_ids, sequence_ids, continuation_text


def _policy_gradient_loss(sequence_ids, prompt_len, reward):
  """REINFORCE loss: -reward * mean log-probability of the generated tokens.

  ``sequence_ids`` is the output of ``model.generate`` for a single prompt.
  It must contain at least one token after the first ``prompt_len`` ones.
  """
  logits = model(sequence_ids).logits
  # logits[:, t] scores the token at position t + 1
  log_probs = torch.log_softmax(logits[:, :-1], dim=-1)
  token_log_probs = log_probs.gather(-1, sequence_ids[:, 1:].unsqueeze(-1)).squeeze(-1)
  # Keep only the positions that predicted generated (non-prompt) tokens.
  generated_log_probs = token_log_probs[:, max(prompt_len - 1, 0):]
  return -float(reward) * generated_log_probs.mean()


def train_one_step(text, losses=None, rewards=None):
  """Run one reward-weighted (REINFORCE) update of ``model`` on ``text``.

  Steps:
    1. Sample a continuation of ``text``.
    2. Score it by its similarity to ``text`` and by its toxicity (see
       ``get_reward``).
    3. Take one optimizer step that raises the log-probability of the sampled
       tokens when the reward is positive and lowers it when negative.

  On every ``out_time``-th call, the prediction, loss and reward are printed
  and appended to ``losses`` / ``rewards``. Fresh lists are used when these
  are omitted, instead of a shared mutable default.
  """
  global timer
  if losses is None:
    losses = []
  if rewards is None:
    rewards = []

  timer += 1

  prompt_ids, sequence_ids, output_text = _generate_continuation(text)
  if sequence_ids.shape[1] <= prompt_ids.shape[1]:
    # Nothing was generated, so there is nothing to reinforce.
    return

  encoded_output_text = encode(output_text, classifier_tokenizer)
  encoded_input_text = encode(text, classifier_tokenizer)

  similarity = get_similarity(encoded_output_text, encoded_input_text)
  toxicity = get_toxicity(encoded_output_text)

  reward = get_reward(similarity, toxicity)

  loss = _policy_gradient_loss(sequence_ids, prompt_ids.shape[1], reward)
  if timer % out_time == 0:
    print(f' predict: {output_text}\n',
          f'loss: {loss.item()}\n',
          f'reward: {reward}\n')
    losses.append(loss.item())
    rewards.append(reward)

  optimizer.zero_grad()
  loss.backward()
  optimizer.step()

rwd = []
lss = []
# Smoke test on a single sentence. Pass the lists by keyword: the signature is
# (text, losses, rewards), so the former positional (rwd, lss) swapped them.
train_one_step('Я ВПЛ который в прошлом году хуево сдал егэ. Не завалил, просто набрал не очень много баллов.', losses=lss, rewards=rwd)

In [None]:
NUM_OF_TRAIN_ON_SENTENCE = 10  # training steps per corpus sentence


def train_on_corpus(filename, losses, rewards):
  """Train on every non-blank line of ``filename``, then save the model.

  Each sentence is passed to ``train_one_step`` ``NUM_OF_TRAIN_ON_SENTENCE``
  times. The whole model object is then saved to ``weights`` with
  ``torch.save``.
  """
  with open(filename, 'r') as text_corpus:
    # Read the corpus once. The old count re-opened a hard-coded 'data.txt'
    # (ignoring `filename`) and never closed that handle.
    sentences = [line for line in text_corpus if line.strip()]
  print(f'{len(sentences)} обучающих предложений')
  for sentence in sentences:
    for _ in range(NUM_OF_TRAIN_ON_SENTENCE):
      train_one_step(sentence, losses, rewards)

  torch.save(model, 'weights')

In [None]:
def show_plot(data):
  """Plot ``data`` against its index and display the figure."""
  xs = list(range(len(data)))
  plt.plot(xs, data)
  plt.show()

In [None]:
import api2ch

# Boards to scrape. Each entry is fetched once, so keep them unique
# ('po' used to be listed twice, duplicating its posts in the corpus).
brds = ['b', 'd', 'sn', 'cc', 'soc', 'po', 'un', 'o']
ch = []  # not used in the visible notebook
fieldnames = ['board', 'post', 'tags']  # not used in the visible notebook

def clean_post(text):
  """Strip lowercase Latin letters and markup punctuation from a post.

  Removes every character in ``a-z`` plus ``/ < > & " : ; # = $ ^ *``. This
  makes tags like ``<br>`` and entities like ``&gt;`` disappear. Cyrillic,
  digits, uppercase Latin letters, whitespace and other punctuation are kept.
  """
  import re
  markup_chars = re.compile(r'[a-z/<>&":;#=$^*]')
  return markup_chars.sub('', text)


print(clean_post(r'<br>Ребята, вы самые лучшие! Вот бы никогда-никогда отсюда не уезжать <br/>!'))

def parse_2ch():
  """Scrape posts from every board in ``brds`` into ``data.txt``, one per line.

  Posts are processed as follows:
    - Each comment is split on ``</a><br>`` and only the second piece is kept
      (presumably the text after a leading reply link — TODO confirm).
    - Comments without that marker are skipped.
    - Text is cleaned with ``clean_post``; empty results are not written.

  A failure on a single post or thread is printed and scraping continues.
  Finally, the number of lines written is printed.
  """
  print('Парсим двач')
  with open('data.txt', 'w', newline='\n') as f:
    for b in brds:
      api = api2ch.DvachApi(b)
      board = api.get_board()

      for thread_ref in board:
        try:
          thread = api.get_thread(thread_ref)
          for post in thread:
            try:
              parts = post.comment.split('</a><br>')
              if len(parts) < 2:
                # No reply-link marker. The old code hit IndexError here and
                # printed it for every such post.
                continue
              comment = clean_post(parts[1])
              if comment.strip():
                f.write(f'{comment}\n')
            except Exception as e:
              print(e)
        except Exception as e:
          # `except Exception` rather than a bare except, so Ctrl-C still works.
          print(f'err: {e}')

  with open('data.txt', 'r') as f:
    length = sum(1 for _ in f)
  print(f'{length} строк')


Ребята, вы самые лучшие! Вот бы никогда-никогда отсюда не уезжать !


In [None]:
parse_2ch()
# `system` was never imported (NameError); the intended call is os.system.
# Note: in a Jupyter kernel this does not clear the cell's own output.
os.system('cls' if os.name == 'nt' else 'clear')
losses, rewards = [], []

train_on_corpus('data.txt', losses, rewards)
show_plot(losses)
show_plot(rewards)