## Russian Text Detoxification

In [None]:
# Notebook shell command: install the Hugging Face `transformers` package (-q = quiet).
!pip install transformers -q

     5.8/5.8 MB 47.6 MB/s eta 0:00:00
     182.4/182.4 KB 17.9 MB/s eta 0:00:00
     7.6/7.6 MB 81.4 MB/s eta 0:00:00

In [None]:
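# For GPUs with <= 8 GB of VRAM: uncomment the next line (it also needs `import os`) to tune the CUDA caching allocator.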
# os.environ['PYTORCH_CUDA_ALLOC_CONF']  = 'max_split_size_mb:4096'

In [None]:
import pandas as pd
from sklearn.utils import shuffle
import torch
from torch.utils.data import Dataset, DataLoader
from transformers import Trainer, TrainingArguments, T5ForConditionalGeneration, AutoTokenizer, AutoModel
from transformers.file_utils import cached_property
from typing import Tuple, List, Dict, Union
from sklearn.model_selection import train_test_split
import gc
from tqdm.auto import tqdm, trange

In [None]:
# Load the tab-separated training set (its first column is the index) and
# turn missing cells into empty strings.
df = pd.read_csv(
    'https://raw.githubusercontent.com/s-nlp/russe_detox_2022/main/data/input/train.tsv',
    sep='\t',
    index_col=0,
).fillna('')

In [None]:
# Flatten the table into parallel lists: one (toxic, neutral) pair per
# non-empty reference. Scanning a row stops at its first empty reference,
# so later columns of that row are not considered.
reference_columns = ['neutral_comment1', 'neutral_comment2', 'neutral_comment3']
df_train_toxic, df_train_neutral = [], []

for _, record in df.iterrows():
    for column in reference_columns:
        candidate = record[column]
        if len(candidate) == 0:
            break
        df_train_toxic.append(record['toxic_comment'])
        df_train_neutral.append(candidate)

In [None]:
# Assemble the flattened (toxic, neutral) pairs into a two-column frame.
df = pd.DataFrame({
    'toxic_comment': df_train_toxic,
    'neutral_comment': df_train_neutral
})

# Shuffle with a fixed seed. train_model() splits with random_state=42, but
# that split is only reproducible if the row order fed into it is
# reproducible as well.
df = shuffle(df, random_state=42)

### TRAIN

In [None]:
class PairsDataset(torch.utils.data.Dataset):
    """Map-style dataset pairing encoded source texts with encoded targets.

    Args:
        x: mapping of field name -> per-example sequence for the source side;
            must contain ``'input_ids'``.
        y: mapping for the target side; must contain ``'input_ids'`` and
            ``'attention_mask'``.
    """

    def __init__(self, x, y):
        self.x = x
        self.y = y

    def __getitem__(self, idx):
        """Return every source field of example ``idx`` plus its target fields.

        The target ids are exposed as ``'labels'`` and the target mask as
        ``'decoder_attention_mask'``.

        Raises:
            IndexError: if ``idx`` is past the last example. Raising
                IndexError (rather than failing an ``assert``) lets plain
                iteration such as ``list(dataset)`` terminate normally and
                keeps the check alive under ``python -O``.
        """
        if idx >= len(self):
            raise IndexError(f'index {idx} is out of range for dataset of size {len(self)}')
        item = {key: val[idx] for key, val in self.x.items()}
        item['decoder_attention_mask'] = self.y['attention_mask'][idx]
        item['labels'] = self.y['input_ids'][idx]
        return item

    @property
    def n(self):
        """Number of examples, taken from the source ``'input_ids'``."""
        return len(self.x['input_ids'])

    def __len__(self):
        return self.n

In [None]:
class DataCollatorWithPadding:
    """Collate per-example feature dicts into one padded batch of tensors.

    Each feature is expected to carry ``'labels'`` and
    ``'decoder_attention_mask'`` next to the fields ``tokenizer.pad`` handles.
    """

    def __init__(self, tokenizer):
        self.tokenizer = tokenizer

    def __call__(self, features: List[Dict[str, Union[List[int], torch.Tensor]]]) -> Dict[str, torch.Tensor]:
        # First pass: let the tokenizer pad the source-side fields.
        padded = self.tokenizer.pad(features, padding=True)

        # Second pass: present the targets under the names the tokenizer
        # pads, then move the padded result back to the target keys.
        target_fields = {
            'input_ids': padded['labels'],
            'attention_mask': padded['decoder_attention_mask'],
        }
        padded_targets = self.tokenizer.pad(target_fields, padding=True)
        padded['labels'] = padded_targets['input_ids']
        padded['decoder_attention_mask'] = padded_targets['attention_mask']

        tensors = {}
        for name, values in padded.items():
            tensors[name] = torch.tensor(values)
        return tensors

In [None]:
def cleanup():
    """Run the Python garbage collector, then release PyTorch's cached CUDA memory."""
    # Order matters: collect first so freed tensors can be dropped from the cache.
    for release in (gc.collect, torch.cuda.empty_cache):
        release()


# Start from a clean memory state.
cleanup()

In [None]:
def evaluate_model(model, test_dataloader):
    """Return the mean loss of ``model`` over ``test_dataloader``, weighted by batch size.

    Args:
        model: callable that accepts a batch's entries as keyword arguments
            and returns an object with a scalar ``loss``; must expose
            ``device``. Switching to eval mode is left to the caller.
        test_dataloader: iterable of dict batches whose values support
            ``.to(device)`` and share the same leading (batch) dimension.

    Returns:
        The example-weighted average loss as a float, or ``nan`` when the
        dataloader yields no examples (instead of dividing by zero).
    """
    weighted_loss_sum = 0.0
    num_examples = 0

    with torch.no_grad():
        for batch in test_dataloader:
            loss = model(**{k: v.to(model.device) for k, v in batch.items()}).loss
            # `batch` is a dict, so len(batch) would count its keys. Weight
            # by the real number of examples so a short final batch does not
            # count as much as a full one.
            batch_size = len(next(iter(batch.values())))
            weighted_loss_sum += batch_size * loss.item()
            num_examples += batch_size

    # NOTE(review): train_loop masks label id 0 with -100 before the forward
    # pass; no such masking happens here, so this loss may also cover padded
    # target positions and not be directly comparable - confirm intent.
    if num_examples == 0:
        return float('nan')
    return weighted_loss_sum / num_examples

In [None]:
def train_loop(
    model, train_dataloader, val_dataloader,
    max_epochs=10,
    max_steps=5_00,
    lr=2e-5,
    gradient_accumulation_steps=1,
    cleanup_step=100,
    report_step=300,
    window=100,
    save_step=1000,
    save_path='t5_base_checkpoint',
):
    """Fine-tune ``model`` in place with Adam, with periodic validation and checkpoints.

    Args:
        model: model whose forward pass returns an object with ``loss``; must
            expose ``device``, ``parameters()``, ``train()``, ``eval()`` and
            ``save_pretrained()``.
        train_dataloader: iterable of dict batches containing ``'labels'``.
        val_dataloader: validation batches, or ``None`` to skip validation.
        max_epochs: upper bound on passes over ``train_dataloader``.
        max_steps: upper bound on optimizer updates.
        lr: Adam learning rate.
        gradient_accumulation_steps: number of batches whose gradients are
            summed before each optimizer update.
        cleanup_step: call ``cleanup()`` every this many batches.
        report_step: run validation every this many batches (and on the last
            batch of each epoch).
        window: cap on the running-average horizon for the displayed loss.
        save_step: save a checkpoint every this many optimizer updates;
            a falsy value disables checkpointing.
        save_path: checkpoint path prefix; the update count is appended as
            ``f'{save_path}_{step}'``. ``None`` disables checkpointing.
    """
    cleanup()
    optimizer = torch.optim.Adam(params=[p for p in model.parameters() if p.requires_grad], lr=lr)

    ewm_loss = 0
    step = 0  # number of optimizer updates performed so far
    model.train()

    for epoch in trange(max_epochs):
        print(step, max_steps)
        if step >= max_steps:
            break
        tq = tqdm(train_dataloader)
        for i, batch in enumerate(tq):
            try:
                # Label id 0 (presumably the pad token) is replaced by -100
                # so the loss ignores those positions.
                batch['labels'][batch['labels'] == 0] = -100
                out = model(**{k: v.to(model.device) for k, v in batch.items()})
                loss = out.loss
                loss.backward()
            except Exception as e:
                # Deliberate best effort: report the failing batch, free
                # memory and carry on with the next one.
                print('error on step', i, e)
                cleanup()
                continue

            # Count batches from 1 so the very first batch can trigger an
            # update too (`i % k == 0` guarded by `i` skipped batch 0).
            if (i + 1) % gradient_accumulation_steps == 0:
                optimizer.step()
                optimizer.zero_grad()
                step += 1
                # Checkpoint only right after an update and under a
                # step-specific name, so saves neither repeat on every batch
                # nor overwrite each other.
                if save_path is not None and save_step and step % save_step == 0:
                    model.save_pretrained(f'{save_path}_{step}')
                if step >= max_steps:
                    break

            if i % cleanup_step == 0:
                cleanup()

            # Running mean over the first `window` batches of the epoch,
            # exponentially weighted afterwards.
            w = 1 / min(i + 1, window)
            ewm_loss = ewm_loss * (1 - w) + loss.item() * w
            tq.set_description(f'loss: {ewm_loss:4.4f}')

            if (i and i % report_step == 0 or i == len(train_dataloader) - 1) and val_dataloader is not None:
                model.eval()
                eval_loss = evaluate_model(model, val_dataloader)
                model.train()
                print(f'epoch {epoch}, step {i}/{step}: train loss: {ewm_loss:4.4f}  val loss: {eval_loss:4.4f}')

    cleanup()

In [None]:
def train_model(x, y, model_name, test_size=0.1, batch_size=32, **kwargs):
    """Load a pretrained T5 model and fine-tune it on parallel texts.

    Args:
        x: list of source texts.
        y: list of target texts, aligned with ``x``.
        model_name: name or path passed to ``from_pretrained`` for both the
            model and the tokenizer.
        test_size: share of the pairs held out for validation.
        batch_size: batch size of the training and validation loaders.
        **kwargs: forwarded unchanged to ``train_loop``.

    Returns:
        The fine-tuned model.
    """
    # Use the GPU when there is one; fall back to CPU instead of crashing.
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    model = T5ForConditionalGeneration.from_pretrained(model_name).to(device)
    tokenizer = AutoTokenizer.from_pretrained(model_name)

    x1, x2, y1, y2 = train_test_split(x, y, test_size=test_size, random_state=42)
    train_dataset = PairsDataset(tokenizer(x1), tokenizer(y1))
    test_dataset = PairsDataset(tokenizer(x2), tokenizer(y2))

    data_collator = DataCollatorWithPadding(tokenizer=tokenizer)
    train_dataloader = DataLoader(train_dataset, batch_size=batch_size, drop_last=False, shuffle=True, collate_fn=data_collator)
    # Validation gains nothing from shuffling; a fixed order keeps the
    # batches identical from one evaluation to the next.
    val_dataloader = DataLoader(test_dataset, batch_size=batch_size, drop_last=False, shuffle=False, collate_fn=data_collator)

    train_loop(model, train_dataloader, val_dataloader, **kwargs)
    return model

In [None]:
# Pretrained checkpoint that the training cell below passes to train_model().
model_name = 'sberbank-ai/ruT5-base'

In [None]:
# Free garbage and cached GPU memory before training.
cleanup()

In [None]:
# Training sets keyed by name; the key is interpolated into the directory
# name under which the trained model is saved.
datasets = {
    'train': df
}

In [None]:
# Each entry: [learning rate, gradient accumulation steps, loss-averaging window].
parametrs = [[2e-5, 1, 100]]
steps = 20000
for lr, gr, w in parametrs:
    for dname, d in datasets.items():
        print(f'\n\n\n  {dname}  {steps} \n=====================\n\n')
        toxic_texts = d['toxic_comment'].tolist()
        neutral_texts = d['neutral_comment'].tolist()
        model = train_model(
            toxic_texts,
            neutral_texts,
            model_name=model_name,
            batch_size=8,
            max_epochs=1000,
            max_steps=steps,
            lr=lr,
            gradient_accumulation_steps=gr,
            window=w,
        )
        # The directory name records the dataset and every hyper-parameter used.
        output_dir = f't5_base_{dname}_params_{steps}_{lr}_{gr}_{w}'
        model.save_pretrained(output_dir)

### INFERENCE

In [None]:
# Load the tab-separated test set (this rebinds `df`) and keep the toxic
# sentences as a plain list for batched generation.
df = pd.read_csv('https://raw.githubusercontent.com/s-nlp/russe_detox_2022/main/data/input/test.tsv', sep='\t')
toxic_inputs = df['toxic_comment'].tolist()

In [None]:
# Base checkpoint, used below to load the tokenizer.
base_model_name = 'sberbank-ai/ruT5-base'
# Directory produced by the training cell for dname='train', steps=20000,
# lr=2e-05, gr=1, w=100.
model_name = 't5_base_train_params_20000_2e-05_1_100'

In [None]:
# The tokenizer comes from the base checkpoint (training saved only the
# model); the weights come from the fine-tuned directory.
tokenizer = AutoTokenizer.from_pretrained(base_model_name)
model = T5ForConditionalGeneration.from_pretrained(model_name)

In [None]:
# Move the model to the GPU; the trailing semicolon keeps the notebook from
# echoing the returned module.
model.cuda();

In [None]:
def paraphrase(text, model, n=None, max_length='auto', temperature=0.0, beams=3):
    """Generate rewrites of ``text`` with ``model`` using beam search.

    Relies on the module-level ``tokenizer`` for encoding and decoding.

    Args:
        text: one string or a list of strings.
        model: model exposing ``generate`` and ``device``.
        n: sequences to return per input; a falsy value means one.
        max_length: generation length limit, or ``'auto'`` to use
            ``int(1.2 * padded input length) + 10``.
        temperature: forwarded to ``generate`` (sampling is turned off).
        beams: number of beams.

    Returns:
        A single string when ``text`` is a string and ``n`` is falsy;
        otherwise the list of all decoded sequences.
    """
    single_input = isinstance(text, str)
    batch = [text] if single_input else text
    input_ids = tokenizer(batch, return_tensors='pt', padding=True)['input_ids'].to(model.device)

    if max_length == 'auto':
        max_length = int(input_ids.shape[1] * 1.2) + 10

    generated = model.generate(
        input_ids,
        num_return_sequences=n or 1,
        do_sample=False,
        temperature=temperature,
        repetition_penalty=3.0,
        max_length=max_length,
        bad_words_ids=[[2]],  # unk
        num_beams=beams,
    )
    decoded = [tokenizer.decode(sequence, skip_special_tokens=True) for sequence in generated]
    return decoded[0] if single_input and not n else decoded

In [None]:
para_results = []
problematic_batch = []  # start offsets of the batches whose generation failed
batch_size = 8

# Paraphrase the test sentences batch by batch.
for i in tqdm(range(0, len(toxic_inputs), batch_size)):
    batch = toxic_inputs[i:i + batch_size]
    try:
        para_results.extend(paraphrase(batch, model, temperature=0.0))
    except Exception as e:
        # Best-effort fallback: keep the untouched inputs so the output stays
        # aligned one-to-one with the test set. `extend` (not `append`) keeps
        # para_results a flat list of strings, which the line-per-sentence
        # file writer needs.
        print(i, e)
        problematic_batch.append(i)
        para_results.extend(batch)

  0%|          | 0/110 [00:00<?, ?it/s]

In [None]:
# Write the results, one sentence per line.
with open('t5_base.txt', 'w', encoding="utf-8") as out_file:
    out_file.write(''.join(sentence + '\n' for sentence in para_results))