In [None]:
import os
import random
import sys
import numpy as np
import argparse
from collections import defaultdict

sys.path.append("../..")

import torch
from torch.utils.data import random_split
from sklearn.neighbors import KNeighborsRegressor
from sklearn.model_selection import train_test_split
from utils.eval import retrieval_normalized_dcg_all, retrieval_precision_all
from utils.toolbox import same_seeds, show_settings, record_settings, get_preprocess_document, get_preprocess_document_embs, get_preprocess_document_labels, get_word_embs

# Fixed cap on torch's CPU intra-op thread pool (presumably to keep the
# notebook from monopolising a shared machine — adjust to taste).
TORCH_NUM_THREADS = 8
torch.set_num_threads(TORCH_NUM_THREADS)

In [None]:
config = {'dataset': '20news', 'target': 'tf-idf', 'seed': 123, 'ratio': 0.8, 'preprocess_config_dir': 'parameters_baseline2', 'encoder': 'mpnet'}

# Per-dataset vocabulary-filtering thresholds as (min_df, max_df, min_doc_word).
# These are forwarded to get_preprocess_document(**config); presumably they are
# document-frequency bounds and a minimum words-per-document filter — confirm
# against utils.toolbox.
DATASET_VOCAB_FILTERS = {
    '20news': (62, 1.0, 15),
    'agnews': (425, 1.0, 15),
    'IMDB': (166, 1.0, 15),
    'wiki': (2872, 1.0, 15),
}

# Fail fast on a typo'd / unsupported dataset instead of silently leaving the
# thresholds unset and erroring somewhere deep inside preprocessing.
if config['dataset'] not in DATASET_VOCAB_FILTERS:
    raise ValueError("Unknown dataset {!r}; expected one of {}".format(
        config['dataset'], sorted(DATASET_VOCAB_FILTERS)))

config["min_df"], config['max_df'], config['min_doc_word'] = DATASET_VOCAB_FILTERS[config['dataset']]

In [None]:
# Print the active configuration, then fix the random seeds (see
# utils.toolbox.same_seeds for exactly which RNGs) before any stochastic step.
show_settings(config)
same_seeds(config["seed"])

# Data preprocessing: every config entry is forwarded as a keyword argument.
# Returns the unpreprocessed and preprocessed corpora.
unpreprocessed_corpus, preprocessed_corpus = get_preprocess_document(**config)

# Document embeddings from the configured encoder; the helper also hands back
# the encoder model and the device it ran on.
encoder_name = config['encoder']
doc_embs, doc_model, device = get_preprocess_document_embs(preprocessed_corpus, encoder_name)
print("Get doc embedding done.")

In [None]:
# Shuffle a *copy* of the corpus with a dedicated RNG seeded from the config.
# The previous in-place `random.shuffle(preprocessed_corpus)` reordered the
# corpus after `doc_embs` / `unpreprocessed_corpus` were built from it (their
# order presumably mirrors the corpus), and made this cell non-idempotent:
# each re-run reshuffled an already-shuffled list and produced a new split.
split_rng = random.Random(config['seed'])
shuffled_corpus = list(preprocessed_corpus)
split_rng.shuffle(shuffled_corpus)

# First `ratio` fraction of the shuffled documents -> train, the rest -> test.
train_size = int(len(shuffled_corpus) * config['ratio'])

# Only the second return value (the vocabulary container) is used downstream.
_, voc_train = get_preprocess_document_labels(shuffled_corpus[:train_size])
_, voc_test = get_preprocess_document_labels(shuffled_corpus[train_size:])

In [None]:
# Keep only the vocabulary for the configured target representation
# (config['target'], currently 'tf-idf') rather than hard-coding the key, so
# changing the target in the config cell is respected here too.
# NOTE: this rebinds voc_train / voc_test from the container returned by
# get_preprocess_document_labels to one of its entries, so this cell cannot be
# re-run on its own — re-run the split cell first.
target_key = config['target']
voc_train = voc_train[target_key]
voc_test = voc_test[target_key]

In [None]:
# Out-of-vocabulary analysis: which test-vocabulary words never occur in the
# training vocabulary.
train_vocab_set = set(voc_train)  # hash lookups instead of scanning voc_train for every test word
wordNotInTrain = [w for w in voc_test if w not in train_vocab_set]
print("Training voc size:{}".format(len(voc_train)))
print("Testing voc size:{}".format(len(voc_test)))
print("Word in test but not in train:{}".format(len(wordNotInTrain)))
# The missing share is relative to the *test* vocabulary (the population the
# missing words are drawn from); dividing by the training vocabulary size, as
# before, did not give a meaningful proportion. Guard the empty-vocab case.
missing_ratio = len(wordNotInTrain) / len(voc_test) if len(voc_test) else 0.0
print("Word missing percentage:{:.2%}".format(missing_ratio))