In [1]:
import argparse
import json
import random
import tqdm
import torch
import wandb
import numpy as np
import os

from transformers import BertTokenizer, AutoModelForSequenceClassification, logging
from custom_datasets import LentaRiaDataset

In [2]:
# Root directory of the pretrained checkpoints. Override it with the LENTA_RIA_MODELS_DIR
# environment variable instead of editing the notebook; the default reproduces the original paths.
MODELS_DIR = os.environ.get('LENTA_RIA_MODELS_DIR', '/home/aobuhtijarov/models')

# Trailing '' keeps the trailing slash of the original directory paths.
tokenizer_path = os.path.join(MODELS_DIR, 'rubert_cased_L-12_H-768_A-12_pt', '')
discriminator_path = os.path.join(MODELS_DIR, 'discriminator_on_clusters_from_rubert', '')
paired_discr_path = os.path.join(MODELS_DIR, 'paired_clf_from_rubert', 'checkpoint-5000', '')

In [3]:
tokenizer = BertTokenizer.from_pretrained(tokenizer_path, do_lower_case=False, do_basic_tokenize=False)


def _load_eval_classifier(checkpoint_path):
    """Load a 2-label sequence classifier from ``checkpoint_path``, switch it to eval mode and move it to the GPU.

    The model is created with ``output_attentions=True``, so attention weights are returned as well.
    """
    model = AutoModelForSequenceClassification.from_pretrained(
        checkpoint_path, num_labels=2, output_attentions=True
    )
    model.eval()  # inference only (e.g. dropout disabled)
    model.cuda()
    return model


# Classifier applied to single headlines.
discriminator = _load_eval_classifier(discriminator_path)
# Classifier applied to "<title> [SEP] <title>" pairs.
paired_discriminator = _load_eval_classifier(paired_discr_path)

In [4]:
agency_list = ["РИА Новости", "lenta.ru"]

# Agency -> id of a reserved BERT vocab entry: the n-th agency (1-based, list order) gets '[unusedN]'.
agency_to_special_token_id = dict(
    (agency, tokenizer.vocab[f'[unused{n}]'])
    for n, agency in enumerate(agency_list, start=1)
)

# Agency -> discriminator class index, assigned in sorted (code point) order of the names.
# Here 'lenta.ru' sorts before the Cyrillic 'РИА Новости', so lenta.ru -> 0 and РИА Новости -> 1.
agency_to_discr_target = dict(
    (agency, label) for label, agency in enumerate(sorted(agency_list))
)

In [5]:
def reader(path):
    """Lazily yield one parsed JSON object per non-blank line of a JSON Lines file.

    Args:
        path: path to a UTF-8 encoded ``.jsonl`` file.

    Yields:
        The object decoded from each non-blank line, in file order.
    """
    # Explicit UTF-8: the data contains non-ASCII (Cyrillic) text, and the platform
    # default encoding is not guaranteed to be UTF-8.
    with open(path, 'r', encoding='utf-8') as f:
        for line in f:
            line = line.strip()
            if not line:  # tolerate empty lines, e.g. an extra blank line at the end of the file
                continue
            yield json.loads(line)
            
            
# Materialize the whole test split; tqdm only reports read progress.
records = list(tqdm.tqdm(reader('../../datasets/full_lenta_ria.test.jsonl')))

2000it [00:00, 5565.62it/s]


In [6]:
# Two parallel lists: position i of each holds the headlines taken from the same record.
ria_style_headlines, lenta_style_headlines = (
    [rec[field] for rec in records]
    for field in ('ria_title', 'lenta_title')
)

In [11]:
# Success criterion for the single-headline discriminator in the evaluation loop:
#   'HARD' -- each headline must be classified as its own agency;
#   'SOFT' -- the Lenta headline only needs a higher lenta.ru logit than the RIA headline.
mode = 'SOFT'

# Device for the input tensors. The models are moved with .cuda() when they are loaded,
# so keep this consistent with that.
device = 'cuda'

In [18]:
# Evaluate both style classifiers on every (RIA, Lenta) headline pair of the test set:
#   paired_discr_ok -- pairs where the paired classifier labels *both* orderings correctly;
#   disc_ok         -- pairs where the single-headline classifier separates the two styles
#                      (criterion selected by `mode`).
PAIRED_MAX_LENGTH = 100  # token budget for "<title> [SEP] <title>" inputs
SINGLE_MAX_LENGTH = 48   # token budget for a single headline

if mode not in ('HARD', 'SOFT'):
    # An unknown mode would otherwise silently leave disc_ok at 0.
    raise ValueError(f"mode must be 'HARD' or 'SOFT', got {mode!r}")

ria_target = agency_to_discr_target['РИА Новости']
lenta_target = agency_to_discr_target['lenta.ru']


def _classify(model, text, max_length):
    """Return ``model``'s logits for a single ``text``.

    The text is tokenized with special tokens and padded/truncated to ``max_length``.
    It is then sent to ``device`` as a batch of one. The batch dimension is dropped
    from the model's first output (the logits).
    """
    encoded = tokenizer(
        text,
        add_special_tokens=True,
        max_length=max_length,
        padding="max_length",
        truncation=True,
    )
    input_ids = torch.LongTensor(encoded['input_ids']).to(device).unsqueeze(0)
    attention_mask = torch.LongTensor(encoded['attention_mask']).to(device).unsqueeze(0)
    return model(input_ids=input_ids, attention_mask=attention_mask)[0][0]


with torch.no_grad():
    paired_discr_ok = 0
    disc_ok = 0

    for ria_title, lenta_title in tqdm.tqdm(zip(ria_style_headlines, lenta_style_headlines),
                                            total=len(ria_style_headlines)):
        # Paired discr: "lenta [SEP] ria" counts as correct with label 0, "ria [SEP] lenta" with label 1.
        logits_0 = _classify(paired_discriminator, ' [SEP] '.join([lenta_title, ria_title]), PAIRED_MAX_LENGTH)
        logits_1 = _classify(paired_discriminator, ' [SEP] '.join([ria_title, lenta_title]), PAIRED_MAX_LENGTH)
        pred_0 = torch.argmax(logits_0).item()
        pred_1 = torch.argmax(logits_1).item()
        paired_discr_ok += int(pred_0 == 0 and pred_1 == 1)

        # Vanilla discr: classify each headline on its own.
        logits_ria = _classify(discriminator, ria_title, SINGLE_MAX_LENGTH)
        logits_lenta = _classify(discriminator, lenta_title, SINGLE_MAX_LENGTH)
        pred_ria = torch.argmax(logits_ria).item()
        pred_lenta = torch.argmax(logits_lenta).item()

        if mode == 'HARD':
            # Both headlines must receive their own agency's label.
            disc_ok += int(pred_ria == ria_target and pred_lenta == lenta_target)
        else:  # 'SOFT'
            # Only the ranking matters: the Lenta headline must score higher on the lenta.ru logit.
            disc_ok += int(logits_ria[lenta_target] < logits_lenta[lenta_target])
        

100%|██████████| 2000/2000 [01:04<00:00, 31.01it/s]


In [19]:
# Accuracy in percent, computed from the actual number of evaluated pairs (not a hard-coded 2000).
round(100 * paired_discr_ok / len(ria_style_headlines), 2)

91.55

In [20]:
# Accuracy in percent, computed from the actual number of evaluated pairs (not a hard-coded 2000).
round(100 * disc_ok / len(ria_style_headlines), 2)

92.9