In [1]:
import tqdm
import numpy as np
import pandas as pd
import torch
import random
from collections import defaultdict
from transformers import BertTokenizer, EncoderDecoderModel, AutoModelForSequenceClassification


from readers import tg_reader, lenta_reader, ria_reader
from custom_datasets.agency_title_dataset import AgencyTitleDatasetGeneration

In [2]:
# Root directories for checkpoints and datasets. Every path below is derived
# from these two constants, so pointing the notebook at another machine only
# requires editing these lines.
MODELS_DIR = '/home/aobuhtijarov/models'
DATASETS_DIR = '/home/aobuhtijarov/datasets'

# Vocabulary directory passed to BertTokenizer.from_pretrained.
tokenizer_path = f'{MODELS_DIR}/rubert_cased_L-12_H-768_A-12_pt/'

# Checkpoints: the title generator and the agency discriminator.
model_path = f'{MODELS_DIR}/style_gen_title_from_pretrained/checkpoint-6000/'
discr_model_path = f'{MODELS_DIR}/agency_discriminator/checkpoint-4000/'

# Evaluation corpora, one file per reader (tg_reader, lenta_reader, ria_reader).
test_data = f'{DATASETS_DIR}/telegram_news/ru_tg_0511_0517.jsonl'
lenta_path = f'{DATASETS_DIR}/lenta/lenta-ru-news.test.csv'
ria_path = f'{DATASETS_DIR}/ria/ria.shuffled.test.json'

In [3]:
# Load the tokenizer with lower-casing and basic (pre-WordPiece) tokenization
# both switched off.
tokenizer = BertTokenizer.from_pretrained(
    tokenizer_path,
    do_lower_case=False,
    do_basic_tokenize=False,
)

In [4]:
# Load the encoder-decoder title generator and move it to the GPU in one step
# (.cuda() returns the module itself, so the binding is the same object).
model = EncoderDecoderModel.from_pretrained(model_path).cuda()

In [5]:
# Load the agency classifier that scores generated titles and move it to the
# GPU in one step (.cuda() returns the module itself).
discriminator = AutoModelForSequenceClassification.from_pretrained(discr_model_path).cuda()

In [6]:
# Materialise each reader into a list; tqdm only reports progress while the
# reader is being consumed.
test_records = list(tqdm.tqdm(tg_reader(test_data)))
lenta_records = list(tqdm.tqdm(lenta_reader(lenta_path)))
ria_records = list(tqdm.tqdm(ria_reader(ria_path)))

120050it [01:17, 1541.12it/s]
75925it [00:02, 30571.63it/s]
47440it [00:29, 1601.41it/s]


In [7]:
agency_list = ["РИА Novости".replace("Nov", "Нов"), "lenta.ru"] if False else ["РИА Новости", "lenta.ru"]

# Agency -> id of its reserved vocabulary entry: the k-th agency in
# `agency_list` (1-based) maps to the tokenizer's '[unused{k}]' token.
agency_to_special_token_id = {
    agency: tokenizer.vocab[f'[unused{slot}]']
    for slot, agency in enumerate(agency_list, start=1)
}

# Agency -> discriminator class index. Indices follow the alphabetically
# SORTED agency names, which is a different order from `agency_list` above.
agency_to_discr_target = dict(zip(sorted(agency_list), range(len(agency_list))))

In [8]:
def _build_generation_dataset(records, filter_agencies=None):
    """Wrap `records` in an AgencyTitleDatasetGeneration.

    The tokenizer and the agency -> special-token mapping are the same for
    every dataset in this notebook; only the records and the optional
    `filter_agencies` argument vary.
    """
    return AgencyTitleDatasetGeneration(
        records,
        tokenizer,
        filter_agencies=filter_agencies,
        agency_to_special_token_id=agency_to_special_token_id,
    )


ria_data = _build_generation_dataset(ria_records)

lenta_data = _build_generation_dataset(lenta_records)

# The Telegram dump is built with an explicit list of three agencies.
other_data = _build_generation_dataset(
    test_records,
    filter_agencies=['Невские Новости', 'ФедералПресс', 'Dynamomania.com'],
)

In [9]:
# Number of examples in each evaluation dataset (RIA, Lenta, other).
tuple(map(len, (ria_data, lenta_data, other_data)))

(47440, 75925, 1946)

In [10]:
@torch.no_grad()
def eval_on_dataset(dataset, target_agency, n=1000, max_tokens_title=48, seed=None):
    """Estimate how often generated titles are classified as `target_agency`.

    For a random sample of items (drawn without replacement) the token at
    position 1 of `input_ids` is replaced by the special token id of
    `target_agency`, a title is generated with beam search, and the
    discriminator classifies the decoded title.

    Args:
        dataset: Indexable collection supporting `len()`. Each item is a
            mapping with 'input_ids' and 'attention_mask' tensors.
        target_agency: Key of `agency_to_special_token_id` and
            `agency_to_discr_target`.
        n: Requested sample size. Capped at `len(dataset)` so small datasets
            no longer make the without-replacement draw fail.
        max_tokens_title: Length the generated title is padded/truncated to
            before it is fed to the discriminator.
        seed: Optional integer seed for the sampling. When None the global
            NumPy random state is used, as before.

    Returns:
        Fraction of the evaluated items whose generated title was assigned the
        class index of `target_agency`; NaN if there is nothing to evaluate.

    Raises:
        KeyError: If `target_agency` is not a known agency.
    """
    # Resolve both lookups up front so an unknown agency fails immediately.
    agency_token_id = agency_to_special_token_id[target_agency]
    target_class = agency_to_discr_target[target_agency]

    n_samples = min(n, len(dataset))
    if n_samples == 0:
        return float('nan')

    rng = np.random if seed is None else np.random.RandomState(seed)
    sampled_indices = rng.choice(len(dataset), n_samples, replace=False)

    hits = 0
    for idx in tqdm.tqdm(sampled_indices, total=n_samples):
        item = dataset[int(idx)]

        # Work on a copy: if the dataset hands out stored tensors, writing the
        # agency token in place would alter the dataset for later readers.
        input_ids = item['input_ids'].clone()
        input_ids[1] = agency_token_id

        # NOTE(review): max_length is 20 here while gen_title() in this
        # notebook uses 22 -- confirm which value is intended.
        gen_ids = model.generate(
            input_ids=input_ids.cuda().unsqueeze(0),
            attention_mask=item['attention_mask'].cuda().unsqueeze(0),
            decoder_start_token_id=model.config.decoder.pad_token_id,
            min_length=7,
            max_length=20,
            num_beams=6
        )

        # Only the first returned sequence is scored.
        title = tokenizer.decode(gen_ids[0], skip_special_tokens=True)

        encoded = tokenizer(
            title,
            add_special_tokens=True, max_length=max_tokens_title,
            padding='max_length', truncation=True
        )

        logits = discriminator(
            input_ids=torch.LongTensor(encoded['input_ids']).cuda().unsqueeze(0),
            attention_mask=torch.LongTensor(encoded['attention_mask']).cuda().unsqueeze(0)
        )[0]
        if torch.argmax(logits).item() == target_class:
            hits += 1

    # Divide by the number of items actually evaluated, not the requested n.
    return hits / n_samples

In [11]:
def gen_title(input_ids, attention_mask, min_length=7, max_length=22, num_beams=6, verbose=True):
    """Generate one title with beam search and return it as text.

    Args:
        input_ids: Token ids passed to `model.generate`; callers in this
            notebook pass a CUDA tensor with a leading batch axis of size 1.
        attention_mask: Attention mask passed alongside `input_ids`.
        min_length: Minimum generated length, forwarded to `model.generate`.
        max_length: Maximum generated length, forwarded to `model.generate`.
        num_beams: Beam width, forwarded to `model.generate`.
        verbose: When True (the default, matching the previous behaviour) the
            raw generated id tensor is printed for inspection.

    Returns:
        The first generated sequence decoded with special tokens skipped.
    """
    gen_ids = model.generate(
        input_ids=input_ids,
        attention_mask=attention_mask,
        decoder_start_token_id=model.config.decoder.pad_token_id,
        min_length=min_length,
        max_length=max_length,
        num_beams=num_beams
    )

    if verbose:
        print(gen_ids)

    # Decode only the first sequence; the result is no longer bound to a local
    # that shadows this function's own name.
    return tokenizer.decode(gen_ids[0], skip_special_tokens=True)

In [12]:
# Length titles are padded/truncated to before being fed to the discriminator
# in the ad-hoc cells below (same value as eval_on_dataset's default).
max_tokens_title = 48

In [13]:
# Display the agency -> discriminator class index mapping, needed to read the
# integer predictions printed further down.
agency_to_discr_target

{'lenta.ru': 0, 'РИА Новости': 1}

In [14]:
def _discr_predict(text):
    """Return the discriminator's predicted class index for `text`.

    The text is tokenized and padded/truncated to `max_tokens_title`; the
    forward pass runs without gradient tracking because only the argmax of
    the logits is needed.
    """
    encoded = tokenizer(
        text,
        add_special_tokens=True, max_length=max_tokens_title,
        padding='max_length', truncation=True
    )
    with torch.no_grad():
        logits = discriminator(
            input_ids=torch.LongTensor(encoded['input_ids']).cuda().unsqueeze(0),
            attention_mask=torch.LongTensor(encoded['attention_mask']).cuda().unsqueeze(0)
        )[0]
    return torch.argmax(logits).item()


# Fixed example index, so the cell shows the same item on every run.
n = 11306
x = lenta_data[n]
a = gen_title(x['input_ids'].cuda().unsqueeze(0), x['attention_mask'].cuda().unsqueeze(0))
# NOTE(review): assumes lenta_data[n] was built from lenta_records[n] -- confirm
# the dataset keeps the record order.
b = lenta_records[n]['title']
print('Gen:', a)
print('Ref:', b)

# Classify the generated and the reference title with the same helper.
a_pred = _discr_predict(a)
b_pred = _discr_predict(b)

print('Gen:', a_pred)
print('Ref:', b_pred)

tensor([[    0, 20622, 42873, 96388,  6188, 45206, 34124, 16327, 12938,  1650,
         27453,   131, 17407, 29079, 17609, 24461,   102,   102,   102,   102,
           102,   102]], device='cuda:0')
Gen: адвокат емельяненко допустил оспаривание решения об условно-досрочном освобождении
Ref: адвокат потерпевшей рассказал о возможном оспаривании удо александра емельяненко
Gen: 0
Ref: 0


In [12]:
%%time

# result[k] holds, for the k-th dataset (RIA, Lenta, other), one accuracy per
# target agency in `agency_list` order.
result = defaultdict(list)

for dataset_idx, dataset in enumerate((ria_data, lenta_data, other_data)):
    result[dataset_idx].extend(
        eval_on_dataset(dataset, agency) for agency in agency_list
    )

100%|██████████| 1000/1000 [06:09<00:00,  2.71it/s]
100%|██████████| 1000/1000 [06:15<00:00,  2.67it/s]
100%|██████████| 1000/1000 [06:13<00:00,  2.68it/s]
100%|██████████| 1000/1000 [06:27<00:00,  2.58it/s]
100%|██████████| 1000/1000 [06:24<00:00,  2.60it/s]
100%|██████████| 1000/1000 [06:27<00:00,  2.58it/s]

CPU times: user 37min 58s, sys: 8.07 s, total: 38min 6s
Wall time: 37min 57s





In [13]:
# Build every row first and construct the frame once. DataFrame.append was
# deprecated in pandas 1.4 and removed in 2.0, and appending inside a loop
# copies the whole frame on each iteration.
rows = [
    {
        'Data': dataset_name,
        # result[i][j] is the accuracy on dataset i when steering towards
        # agency_list[j].
        **{agency: result[i][j] for j, agency in enumerate(agency_list)},
    }
    for i, dataset_name in enumerate(('RIA', 'Lenta', 'Other'))
]
df = pd.DataFrame(rows, columns=['Data'] + agency_list)
df

Unnamed: 0,Data,РИА Новости,lenta.ru
0,RIA,0.935,0.054
1,Lenta,0.426,0.609
2,Other,0.808,0.207


In [14]:
# Use the dataset name as the row index (the 'Data' column itself is kept).
df.index = df['Data']

In [15]:
# Remove the 'Data' column in place; the frame is modified, so running this
# cell a second time raises KeyError.
df.drop(columns=['Data'], inplace=True)

In [16]:
# Display the table rounded to two decimals; `df` itself is not modified.
df.round(decimals=2)

Unnamed: 0_level_0,РИА Новости,lenta.ru
Data,Unnamed: 1_level_1,Unnamed: 2_level_1
RIA,0.94,0.05
Lenta,0.43,0.61
Other,0.81,0.21
