In [1]:
import os
from collections import defaultdict

import numpy as np
import pandas as pd
import torch
import tqdm
from transformers import AutoTokenizer, EncoderDecoderModel, AutoModelForSequenceClassification

from readers.tg_reader import tg_reader
from custom_datasets.agency_title_dataset import AgencyTitleDatasetGeneration

In [2]:
# All models and data live under one root directory. Set the NEWS_STYLE_ROOT
# environment variable to run the notebook on another machine; with it unset the
# paths below are exactly the original hard-coded ones.
ROOT_DIR = os.environ.get('NEWS_STYLE_ROOT', '/home/aobuhtijarov')

tokenizer_path = os.path.join(ROOT_DIR, 'models/rubert_cased_L-12_H-768_A-12_pt/')

model_path = os.path.join(ROOT_DIR, 'models/style_gen_title_from_pretrained/checkpoint-5000/')
discr_model_path = os.path.join(ROOT_DIR, 'models/agency_discriminator/checkpoint-3000/')

test_data = os.path.join(ROOT_DIR, 'datasets/telegram_news/ru_tg_0511_0517.jsonl')

In [3]:
# Cased RuBERT tokenizer: lower-casing is explicitly disabled so titles keep their case.
tokenizer = AutoTokenizer.from_pretrained(
    tokenizer_path,
    do_lower_case=False,
)

In [4]:
# Title generator (encoder-decoder) loaded from its checkpoint and moved to the GPU;
# Module.cuda() returns the module itself, so the call can be chained.
model = EncoderDecoderModel.from_pretrained(model_path).cuda()

In [5]:
# Agency classifier used to judge which agency a generated title "sounds like";
# loaded from its checkpoint and placed on the GPU in one chained expression.
discriminator = AutoModelForSequenceClassification.from_pretrained(discr_model_path).cuda()

In [6]:
# Materialize every record yielded by the Telegram reader, with a progress bar.
test_records = list(tqdm.tqdm(tg_reader(test_data)))

120050it [01:26, 1392.31it/s]


In [7]:
# Number of records loaded from the test file (displayed as the cell output).
n_test_records = len(test_records)
n_test_records

120050

In [8]:
agency_list = ["ТАСС", "RT на русском", "lenta.ru"]

# Generator control tokens: the k-th agency (1-based) is signalled by the
# tokenizer's '[unused{k}]' vocabulary entry.
agency_to_special_token_id = {
    agency: tokenizer.vocab[f'[unused{k}]']
    for k, agency in enumerate(agency_list, start=1)
}

# Discriminator class ids: each agency's position in sorted() (code-point) order.
# NOTE(review): presumably matches the label encoding used to train the discriminator — confirm.
agency_to_discr_target = dict(zip(sorted(agency_list), range(len(agency_list))))

In [9]:
def _agency_subset(agencies):
    """Build an AgencyTitleDatasetGeneration over test_records with filter_agencies=agencies."""
    return AgencyTitleDatasetGeneration(
        test_records,
        tokenizer,
        filter_agencies=agencies,
        agency_to_special_token_id=agency_to_special_token_id,
    )


tass_data = _agency_subset(["ТАСС"])
rt_data = _agency_subset(["RT на русском"])
lenta_data = _agency_subset(["lenta.ru"])
# Agencies that have no control token in agency_to_special_token_id.
other_data = _agency_subset(['Невские Новости', 'ФедералПресс', 'РИАМО', 'Известия'])

In [10]:
# Sizes of the four evaluation splits, in the order TASS, RT, lenta, other.
tuple(len(split) for split in (tass_data, rt_data, lenta_data, other_data))

(3989, 1263, 1613, 2977)

In [11]:
# Scratch values for calling eval_on_dataset by hand on a single split/target.
# NOTE(review): no later cell reads these globals — the evaluation loop passes its own arguments.
dataset, target_agency, max_tokens_title = rt_data, 'ТАСС', 48

In [12]:
@torch.no_grad()
def eval_on_dataset(dataset, target_agency, n=1200, max_tokens_title=48, seed=0,
                    min_length=7, max_length=20, num_beams=6):
    """Share of generated titles that the discriminator attributes to `target_agency`.

    For `n` examples sampled without replacement from `dataset`:

    1. The agency control token at position 1 of the encoder input is replaced with
       the token for `target_agency`.
    2. The global `model` generates a title with beam search.
    3. The global `discriminator` classifies the decoded title.

    Args:
        dataset: indexable dataset whose items are dicts holding tensor 'input_ids'
            and 'attention_mask' (unbatched — a batch dim is added here).
        target_agency: key of `agency_to_special_token_id` and `agency_to_discr_target`.
        n: number of examples to sample; clipped to len(dataset).
        max_tokens_title: length the generated title is padded/truncated to before
            it is passed to the discriminator.
        seed: seed for the example sampler, so results are reproducible and every
            target agency is scored on the same examples of a given dataset.
            Pass None for a fresh random sample on each call.
        min_length, max_length, num_beams: forwarded to `model.generate`.

    Returns:
        float in [0, 1]: fraction of sampled titles classified as `target_agency`.

    Raises:
        ValueError: if `dataset` is empty.
    """
    n_total = len(dataset)
    if n_total == 0:
        raise ValueError("eval_on_dataset: dataset is empty")
    # Sampling more than the population without replacement would raise inside numpy.
    n = min(n, n_total)

    rng = np.random.default_rng(seed)
    indices = rng.choice(n_total, size=n, replace=False)

    # Loop-invariant lookups.
    special_token_id = agency_to_special_token_id[target_agency]
    target_label = agency_to_discr_target[target_agency]
    gen_device = model.device
    discr_device = discriminator.device

    hits = 0
    for idx in tqdm.tqdm(indices, total=n):
        item = dataset[int(idx)]
        # Clone so the control-token swap cannot leak into the dataset's own tensors.
        # NOTE(review): position 1 is assumed to hold the agency token (right after
        # [CLS]) — confirm against AgencyTitleDatasetGeneration.
        input_ids = item['input_ids'].clone()
        input_ids[1] = special_token_id

        gen_ids = model.generate(
            input_ids=input_ids.unsqueeze(0).to(gen_device),
            attention_mask=item['attention_mask'].unsqueeze(0).to(gen_device),
            decoder_start_token_id=model.config.decoder.pad_token_id,
            min_length=min_length,
            max_length=max_length,
            num_beams=num_beams,
        )
        # Only the first returned sequence is scored; decode just that one.
        gen_title = tokenizer.decode(gen_ids[0], skip_special_tokens=True)

        inp = tokenizer(
            gen_title,
            add_special_tokens=True, max_length=max_tokens_title,
            padding='max_length', truncation=True, return_tensors='pt',
        )
        logits = discriminator(
            input_ids=inp['input_ids'].to(discr_device),
            attention_mask=inp['attention_mask'].to(discr_device),
        )[0]
        hits += int(torch.argmax(logits).item() == target_label)

    return hits / n

In [None]:
%%time

result = defaultdict(list)

for i, data in enumerate((tass_data, rt_data, lenta_data, other_data)):
    for target_a in agency_list:
        acc = eval_on_dataset(data, target_a)
        result[i].append(acc)

 89%|████████▉ | 1069/1200 [04:59<00:36,  3.59it/s]

In [None]:
# One row per evaluated split (same order as the evaluation loop); one column per
# target agency. Rows are collected first and the frame is built once, because
# DataFrame.append was removed in pandas 2.0 and copies the whole frame on every call.
rows = []
for i, dataset_name in enumerate(('TASS', 'RT', 'Lenta', 'Other')):
    row = {'Data': dataset_name}
    for j, a in enumerate(agency_list):
        row[a] = result[i][j]
    rows.append(row)

df = pd.DataFrame(rows, columns=['Data'] + agency_list)
df

In [19]:
# Label rows by split name; the 'Data' column itself is kept (drop=False).
df = df.set_index('Data', drop=False)

In [21]:
# Remove the now-redundant 'Data' column (its values are the index). Rebinding
# instead of inplace=True, plus errors='ignore', lets this cell be re-run without a KeyError.
df = df.drop(columns='Data', errors='ignore')

In [23]:
# Display the accuracy table rounded to two decimals (df itself is unchanged).
df.round(decimals=2)

Unnamed: 0_level_0,ТАСС,RT на русском,lenta.ru
Data,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1
TASS,0.88,0.08,0.02
RT,0.29,0.64,0.07
Lenta,0.35,0.23,0.43
Other,0.72,0.17,0.09
