In [1]:
from sklearn.cluster import AgglomerativeClustering
from _jsonnet import evaluate_file as jsonnet_evaluate_file
from transformers import AutoTokenizer, EncoderDecoderModel, logging, AutoModel, BertTokenizer
from models.bottleneck_encoder_decoder import BottleneckEncoderDecoderModel

In [2]:
from evaluation.clustering_utils import get_gold_markup, get_data_to_cluster, doc2vec_bert, calc_clustering_metrics

In [3]:
def get_text_to_vector_func(text_to_vec_func, model, tokenizer):
    """Build a one-argument document embedder for the requested strategy.

    Args:
        text_to_vec_func: Strategy name, either 'bert-MeanSum' or
            'bert-FirstCLS'.
        model: Forwarded unchanged to ``doc2vec_bert``.
        tokenizer: Forwarded unchanged to ``doc2vec_bert``.

    Returns:
        A callable ``doc -> doc2vec_bert(doc, model, tokenizer, mode)`` where
        ``mode`` is 'MeanSum' or 'FirstCLS' respectively.

    Raises:
        NotImplementedError: If ``text_to_vec_func`` is not a supported name.
    """
    # Strategy name -> mode string handed to doc2vec_bert.
    modes = {
        'bert-MeanSum': 'MeanSum',
        'bert-FirstCLS': 'FirstCLS',
    }
    try:
        mode = modes[text_to_vec_func]
    except (KeyError, TypeError):
        # TypeError covers unhashable arguments, which previously also ended
        # up in the NotImplementedError branch.
        raise NotImplementedError(
            f'Unsupported text_to_vec_func {text_to_vec_func!r}; '
            f'expected one of {sorted(modes)}'
        ) from None

    return lambda doc: doc2vec_bert(doc, model, tokenizer, mode)

In [4]:
logging.set_verbosity_info()

In [5]:
model = BottleneckEncoderDecoderModel.from_pretrained('/data/aobuhtijarov/models/gen_title_bottleneck_v2/checkpoint-10000/')

loading configuration file /data/aobuhtijarov/models/gen_title_bottleneck_v2/checkpoint-10000/config.json
Model config EncoderDecoderConfig {
  "architectures": [
    "BottleneckEncoderDecoderModel"
  ],
  "decoder": {
    "_name_or_path": "/data/aobuhtijarov/models/pretrained_dec_6_layers",
    "add_cross_attention": true,
    "architectures": [
      "BertModel"
    ],
    "attention_probs_dropout_prob": 0.2,
    "bad_words_ids": null,
    "bos_token_id": null,
    "chunk_size_feed_forward": 0,
    "decoder_start_token_id": null,
    "directionality": "bidi",
    "diversity_penalty": 0.0,
    "do_sample": false,
    "early_stopping": false,
    "eos_token_id": null,
    "finetuning_task": null,
    "gradient_checkpointing": false,
    "hidden_act": "gelu",
    "hidden_dropout_prob": 0.2,
    "hidden_size": 768,
    "id2label": {
      "0": "LABEL_0",
      "1": "LABEL_1"
    },
    "initializer_range": 0.02,
    "intermediate_size": 3072,
    "is_decoder": true,
    "is_encoder_decod

In [6]:
gold_markup_file = '/data/aobuhtijarov/datasets/telegram_news/ru_pairs_raw_markup.tsv'
clustering_data_file = '/data/aobuhtijarov/datasets/telegram_news/ru_clustering_data.jsonl'

In [7]:
### BEWARE: the tokenizer is loaded from a different directory than the model
### checkpoint. NOTE(review): presumably this must be the vocabulary the
### checkpoint was trained with -- confirm before changing either path.

tokenizer_model_path = '/data/aobuhtijarov/models/rubert_cased_L-12_H-768_A-12_pt'

# do_lower_case=False keeps the input casing; do_basic_tokenize=False skips the
# basic pre-tokenization step and goes straight to WordPiece.
tokenizer = BertTokenizer.from_pretrained(tokenizer_model_path, do_lower_case=False, do_basic_tokenize=False)

Model name '/data/aobuhtijarov/models/rubert_cased_L-12_H-768_A-12_pt' not found in model shortcut name list (bert-base-uncased, bert-large-uncased, bert-base-cased, bert-large-cased, bert-base-multilingual-uncased, bert-base-multilingual-cased, bert-base-chinese, bert-base-german-cased, bert-large-uncased-whole-word-masking, bert-large-cased-whole-word-masking, bert-large-uncased-whole-word-masking-finetuned-squad, bert-large-cased-whole-word-masking-finetuned-squad, bert-base-cased-finetuned-mrpc, bert-base-german-dbmdz-cased, bert-base-german-dbmdz-uncased, TurkuNLP/bert-base-finnish-cased-v1, TurkuNLP/bert-base-finnish-uncased-v1, wietsedv/bert-base-dutch-cased). Assuming '/data/aobuhtijarov/models/rubert_cased_L-12_H-768_A-12_pt' is a path, a model identifier, or url to a directory containing tokenizer files.
Didn't find file /data/aobuhtijarov/models/rubert_cased_L-12_H-768_A-12_pt/added_tokens.json. We won't load it.
Didn't find file /data/aobuhtijarov/models/rubert_cased_L-12_H

In [8]:
max_tokens_text=250
max_tokens_title=48
text_to_vec_func='bert-FirstCLS'

### Train dataset

In [10]:
from custom_datasets.agency_title_dataset import AgencyTitleDataset
from readers.tg_reader import tg_reader
import tqdm
import random
import numpy as np

In [11]:
agency_list = ["ТАСС", "РИАМО", "RT на русском", "Новости Мойка78"]

In [12]:
train_file = '/data/aobuhtijarov/datasets/telegram_news/ru_tg_1101_0510.jsonl'

In [13]:
# Fraction of records to keep. 1.0 keeps every record, because
# random.random() returns values in [0.0, 1.0).
train_sample_rate = 1.0

# NOTE(review): no random seed is set in the visible cells, so a rate below
# 1.0 would give a different subsample on every run.
train_records = [r for r in tqdm.tqdm(tg_reader(train_file, agency_list)) if random.random() <= train_sample_rate]

18032it [05:40, 52.93it/s] 


In [15]:
from evaluation.clustering_utils import get_data_to_cluster

In [16]:
# Read the clustering data file. NOTE(review): url2record / filename2url are
# not used by any later cell visible here.
url2record, filename2url = get_data_to_cluster(clustering_data_file)

# Attach the text token limit to the tokenizer object itself.
# NOTE(review): presumably doc2vec_bert reads it from there -- confirm.
tokenizer.max_tokens_text = max_tokens_text

# Embedder used below to turn each article into a vector.
text_to_vector_func = get_text_to_vector_func(text_to_vec_func, model, tokenizer)


In [18]:
# Wrap the loaded records in the project's dataset class, using the same
# token limits as configured for the embedding step.
train_dataset = AgencyTitleDataset(
    train_records,
    tokenizer,
    agency_list,
    max_tokens_text=max_tokens_text,
    max_tokens_title=max_tokens_title
)

In [19]:
len(train_dataset)

18032

In [22]:
total_articles = len(train_dataset)

# One row per article, filled in by the embedding loop.
# NOTE(review): 768 is hard-coded; it equals the "hidden_size" printed in the
# model config -- confirm it is the width text_to_vector_func produces.
train_embeds = np.zeros((total_articles, 768))

In [24]:
# Embed every article: title and body are joined, lower-cased and stripped of
# non-breaking spaces before being passed to the embedder.
for i in tqdm.trange(total_articles):
    record = train_dataset.get_strings(i)
    text = ' '.join((record["title"], record["text"]))
    text = text.lower()
    text = text.replace('\xa0', ' ')
    embedding = text_to_vector_func(text)
    # Detach from the autograd graph and flatten into the article's row.
    train_embeds[i] = embedding.detach().numpy().ravel()

100%|██████████| 18032/18032 [48:26<00:00,  6.20it/s] 


In [25]:
train_embeds

array([[-0.67644489, -0.63231045,  0.00695466, ...,  1.8090775 ,
         0.84440309, -0.50852436],
       [ 0.1376366 , -0.21560137, -1.94320238, ...,  0.68795353,
         1.11395466, -0.95550954],
       [-0.59580934,  0.22169754, -0.86293459, ...,  0.10177908,
        -0.1572016 , -0.29767632],
       ...,
       [-0.31069651,  0.70517147, -0.94575864, ...,  1.59404671,
        -0.26357779, -0.19272354],
       [-0.38505411,  0.04302118,  0.53257304, ...,  0.49253193,
        -0.77190143, -0.01345761],
       [ 1.25519109, -0.28259528, -0.28828409, ...,  0.35324875,
         0.18214597,  0.1251189 ]])

### Fitting agglomerative clustering

In [26]:
# Single-linkage agglomerative clustering over cosine distance. With
# n_clusters=None the tree is cut at distance_threshold rather than at a fixed
# number of clusters.
clustering_params = dict(
    n_clusters=None,
    distance_threshold=0.18,
    linkage="single",
)
try:
    # scikit-learn >= 1.2 calls the distance argument `metric`; `affinity`
    # was deprecated in 1.2 and removed in 1.4.
    clustering_model = AgglomerativeClustering(metric="cosine", **clustering_params)
except TypeError:
    # Older scikit-learn releases only accept `affinity`.
    clustering_model = AgglomerativeClustering(affinity="cosine", **clustering_params)

clustering_model.fit(train_embeds)
train_labels = clustering_model.labels_

In [27]:
train_labels.shape

(18032,)

In [28]:
len(set(train_labels))

14793

In [29]:
from collections import defaultdict
cluster_to_inds = defaultdict(list)

In [30]:
# Group article indices by cluster label. The mapping is rebuilt from scratch
# here so that re-running this cell does not append every index a second time.
cluster_to_inds = defaultdict(list)
for i, label in enumerate(train_labels):
    cluster_to_inds[label].append(i)

In [31]:
def print_articles(inds, dataset):
    """Print agency, title and text for the given dataset entries.

    Args:
        inds: Iterable of indices accepted by ``dataset.get_strings``.
        dataset: Object whose ``get_strings(i)`` returns a mapping with the
            keys "agency", "title" and "text".
    """
    for i in inds:
        # Fetch the strings once per article instead of once per printed field.
        strings = dataset.get_strings(i)
        print(f'Agency: {strings["agency"]}\nTitle:\n{strings["title"]}\nText:\n{strings["text"]}\n\n')

In [37]:
# Dump every cluster together with its articles to a text file for manual
# inspection.
with open('clusters.txt', 'w', encoding='utf-8') as f:
    for cl_id, inds in cluster_to_inds.items():
        f.write(f'\tCluster #{cl_id}:\n\n')

        for i in inds:
            # Fetch the strings once per article rather than once per field.
            strings = train_dataset.get_strings(i)
            f.write(f'Agency: {strings["agency"]}\nTitle:\n{strings["title"]}\nText:\n{strings["text"]}\n\n')

        # Visual separator between clusters.
        f.write('\n\n' + ('-'*50+'\n') * 2 + '\n\n')

In [39]:
# Count (and print) the clusters that contain exactly one article from each
# agency in agency_list.
expected_agencies = set(agency_list)
cnt = 0
for inds in cluster_to_inds.values():
    # The required size follows from agency_list instead of a hard-coded 4,
    # so the check stays correct if the agency list changes.
    if len(inds) != len(expected_agencies):
        continue

    cluster_agencies = {train_dataset.get_strings(i)["agency"] for i in inds}
    if cluster_agencies == expected_agencies:
        cnt += 1
        print_articles(inds, train_dataset)
        print('\n\n' + ('-'*50+'\n') * 2 + '\n\n')

print(cnt)

Agency: ТАСС
Title:
в шереметьево совершил экстренную посадку ssj-100 из-за отказа двигателя
Text:
москва, 5 мая. /тасс/. пассажирский самолет sukhoi superjet 100 совершил экстренную посадку в аэропорту шереметьево из-за отказа двигателя. об этом тасс сообщили в экстренных службах. "после вылета в саратов, по предварительным данным, у самолета отказал левый двигатель. экипаж принял решение о возвращении. в 07:16 он совершил экстренную посадку в штатном режиме", - сказал собеседник агентства. пострадавших нет. в экстренных службах уточнили, что ssj-100 вылетел в саратов около 06:45. в воздухе он пробыл около получаса, сделав несколько кругов перед посадкой. самолет будет осмотрен технической службой. причины произошедшего устанавливаются. в новость внесена правка (07:27 мск) - передается с уточнением пункта вылета, верно - саратов.


Agency: Новости Мойка78
Title:
в шереметьево приземлился самолет после отказа левого двигателя
Text:
самолет sukhoi superjet 100, который летел из саратова

In [48]:
# Number of articles in each cluster, in the mapping's iteration order.
cluster_sizes = list(map(len, cluster_to_inds.values()))

In [54]:
from collections import Counter

In [55]:
Counter(cluster_sizes)

Counter({1: 12607,
         2: 1186,
         4: 108,
         9: 12,
         5: 55,
         106: 1,
         3: 287,
         8: 14,
         12: 5,
         24: 2,
         6: 35,
         7: 22,
         15: 3,
         11: 6,
         16: 2,
         10: 5,
         13: 3,
         18: 2,
         38: 1,
         25: 1,
         45: 1,
         14: 4,
         37: 2,
         17: 2,
         28: 1,
         59: 1,
         21: 1,
         20: 1,
         19: 1})