In [2]:
import copy

import torch

from mda.api import lexicon_predict, train_lexicon
from mda.data.bow_dataset import BagOfWordsSingleBatchDataset
from mda.data.data_collection import DataCollection
from mda.util import load_json
from repo_root import get_full_path

In [3]:
# arXiv categories the lexicon is trained on and evaluated against in-domain:
#   cs.CL - computation and language
#   cs.CV - computer vision and pattern recognition
#   cs.LG - machine learning
#   cs.NE - neural and evolutionary computing
#   cs.SI - social and information networks
IN_DOMAINS = ["cs.CL", "cs.CV", "cs.LG", "cs.NE", "cs.SI"]

# arXiv category held out as the unseen domain:
#   cs.AI - artificial intelligence
OUT_DOMAINS = ["cs.AI"]

### Build a vocabulary

In [4]:
# Load the training split and parse it into a DataCollection.
train_path = get_full_path("data/arxiv.train.json")
train_collection = DataCollection.parse_obj(load_json(train_path))

# Bag-of-words dataset restricted to the in-domain categories; the vocabulary
# (capped at 5000 entries) is read back from the dataset right after.
# NOTE(review): batch_size=-1 / num_workers=-1 presumably select "single full
# batch" and "default workers" -- confirm in BagOfWordsSingleBatchDataset.
train_dataset = BagOfWordsSingleBatchDataset(
    collection=train_collection,
    use_domain_strs=IN_DOMAINS,
    vocab_size=5000,
    batch_size=-1,
    num_workers=-1,
)
vocab = train_dataset.vocab

In [5]:
# vocabulary size -- should match the vocab_size requested when the training
# dataset was built
len(vocab)

5000

In [6]:
# first 10 vocabulary entries; presumably the vocab is ordered by frequency, so
# these are the most common words -- confirm against the dataset's vocab builder
vocab[:10]

['learning',
 'data',
 'model',
 'models',
 'network',
 'method',
 'methods',
 'using',
 'paper',
 'performance']

### Train a lexicon

In [7]:
# Train the lexicon on the in-domain training dataset. This is the slow cell of
# the notebook: the recorded run took a little over a minute, with the progress
# bar counting 5000 steps (same as the vocabulary size).
lexicon_df = train_lexicon(dataset=train_dataset)

100%|██████████| 5000/5000 [01:13<00:00, 67.90it/s]


In [8]:
# Inspect the trained lexicon: a "word" column plus one weight column per time
# period (presumably one per class -- confirm against class_strs).
lexicon_df

Unnamed: 0,word,upto2008,2009-2014,2015-2018,2019after
0,learning,-5.128001e-02,1.680402e-02,0.000004,0.067732
1,data,-8.412749e-02,1.026811e-02,0.073600,0.041268
2,model,-3.845137e-02,-1.000125e-03,0.038513,0.076289
3,models,-5.908057e-02,8.870528e-03,0.008726,0.121469
4,network,-2.848798e-02,1.328365e-07,-0.000007,0.054982
...,...,...,...,...,...
4995,6d,2.025091e-06,-1.087420e-05,-0.000012,0.000078
4996,equivariance,8.801538e-06,-6.828160e-06,-0.027495,0.088229
4997,timedependent,-3.742034e-07,-6.719351e-07,-0.000011,0.000018
4998,mot,2.620446e-06,-7.150712e-06,-0.000032,0.000089


In [9]:
# ten words carrying the largest weights for papers from 2019 onwards,
# as (word, weight) pairs in descending weight order
word_weight_pairs = zip(
    lexicon_df["word"].to_list(),
    lexicon_df["2019after"].to_list(),
)
sorted(word_weight_pairs, key=lambda pair: pair[1], reverse=True)[:10]


[('covid19', 1.4918851852416992),
 ('bert', 1.1331787109375),
 ('federated', 0.7731612920761108),
 ('transformer', 0.7572790384292603),
 ('selfsupervised', 0.7344726920127869),
 ('transformerbased', 0.6686000823974609),
 ('pandemic', 0.6529555916786194),
 ('fewshot', 0.6176268458366394),
 ('transformers', 0.5596126914024353),
 ('sota', 0.5254027247428894)]

### Evaluate the lexicon on labeled, in-domain data

In [17]:
# Load the test split and parse it into a DataCollection.
test_path = get_full_path("data/arxiv.test.json")
test_collection = DataCollection.parse_obj(load_json(test_path))

# In-domain test set, featurised with the lexicon's own vocabulary (taken from
# the lexicon's "word" column) instead of building a new one.
in_domain_test_dataset = BagOfWordsSingleBatchDataset(
    collection=test_collection,
    use_domain_strs=IN_DOMAINS,
    vocab_override=lexicon_df["word"].to_list(),
    batch_size=-1,
    num_workers=-1,
)


In [18]:
# Score the in-domain test set and keep the highest-scoring class per sample.
probs = lexicon_predict(lexicon_df=lexicon_df, dataset=in_domain_test_dataset)
preds = probs.argmax(dim=-1)

# Gold labels, gathered from a second pass over the loader.
# NOTE(review): this assumes the loader yields samples in the same order as
# during prediction -- confirm it does not shuffle.
in_domain_labels = torch.cat(
    [batch["class_idx"] for batch in in_domain_test_dataset.get_loader()],
    dim=0,
)

# Multiplying the boolean match mask by 1.0 makes it floating point so that it
# can be averaged into an accuracy.
acc = ((preds == in_domain_labels) * 1.0).mean()
acc.item()


1it [00:00, 118.16it/s]


0.7234599590301514

### Evaluate the lexicon on labeled, out-of-domain data


In [19]:
# Out-of-domain test set: same test collection, restricted to the held-out
# categories and featurised with the lexicon's vocabulary.
out_domain_test_dataset = BagOfWordsSingleBatchDataset(
    collection=test_collection,
    use_domain_strs=OUT_DOMAINS,
    vocab_override=lexicon_df["word"].to_list(),
    batch_size=-1,
    num_workers=-1,
)

In [20]:
# Score the out-of-domain test set and keep the highest-scoring class per sample.
probs = lexicon_predict(lexicon_df=lexicon_df, dataset=out_domain_test_dataset)
preds = probs.argmax(dim=-1)

# Gold labels, gathered from a second pass over the loader.
# NOTE(review): this assumes the loader yields samples in the same order as
# during prediction -- confirm it does not shuffle.
out_domain_labels = torch.cat(
    [batch["class_idx"] for batch in out_domain_test_dataset.get_loader()],
    dim=0,
)

# Multiplying the boolean match mask by 1.0 makes it floating point so that it
# can be averaged into an accuracy.
acc = ((preds == out_domain_labels) * 1.0).mean()
acc.item()


1it [00:00, 1356.94it/s]


0.5878968238830566

### Predict with the lexicon on partially labeled, out-of-domain data

Use the labeled subset of samples to estimate the class distribution of this unseen domain, then use that estimate as a domain-specific bias when predicting.

In [26]:
# build a partially labeled out-of-domain data collection
NUM_LABELED = 250  # only the first NUM_LABELED out-of-domain samples keep their labels

partially_labeled_collection = DataCollection(
    class_strs=train_collection.class_strs,
    domain_strs=OUT_DOMAINS,
)
samples = [
    s for s in train_collection.samples.values() if s.domain_str == OUT_DOMAINS[0]
]

# labeled part: added unchanged
for sample in samples[:NUM_LABELED]:
    partially_labeled_collection.add_sample(sample)

# unlabeled part: strip the label from a *copy* of each sample. These objects
# still belong to train_collection, so clearing class_idx / class_str in place
# would silently erase the training collection's out-of-domain labels and make
# every later use of train_collection depend on whether this cell has run.
# deepcopy is used rather than copy.copy because a shallow copy of some model
# classes (e.g. pydantic v1 models) can share the attribute dict with the
# original -- NOTE(review): the sample class is not visible here.
for sample in samples[NUM_LABELED:]:
    unlabeled_sample = copy.deepcopy(sample)
    unlabeled_sample.class_idx = unlabeled_sample.class_str = None
    partially_labeled_collection.add_sample(unlabeled_sample)

# estimate the per-domain class distribution from the labeled samples
partially_labeled_collection.populate_class_distribution()


In [27]:
# class distribution estimated from the subset of samples that are labeled;
# the result is keyed by domain string with one proportion per class
# (presumably in class_strs order -- confirm)
partially_labeled_collection.class_dist

{'cs.AI': [0.036000000034240004,
  0.19600000000864,
  0.20400000000736,
  0.56399999994976]}

In [28]:
# Featurise the partially labeled collection with the lexicon's vocabulary.
# No domain filter is passed here, unlike the test datasets.
partially_labeled_dataset = BagOfWordsSingleBatchDataset(
    collection=partially_labeled_collection,
    vocab_override=lexicon_df["word"].to_list(),
    batch_size=-1,
    num_workers=-1,
)

# Class scores for every sample, then the highest-scoring class index.
probs = lexicon_predict(lexicon_df=lexicon_df, dataset=partially_labeled_dataset)
preds = probs.argmax(dim=-1)


1it [00:00, 1842.84it/s]


In [56]:
# abstract text of the sample at position 5 of the collection
collection_samples = list(partially_labeled_collection.samples.values())
collection_samples[5].text

'  The problem of learning Markov equivalence classes of Bayesian network\nstructures may be solved by searching for the maximum of a scoring metric in a\nspace of these classes. This paper deals with the definition and analysis of\none such search space. We use a theoretically motivated neighbourhood, the\ninclusion boundary, and represent equivalence classes by essential graphs. We\nshow that this search space is connected and that the score of the neighbours\ncan be evaluated incrementally. We devise a practical way of building this\nneighbourhood for an essential graph that is purely graphical and does not\nexplicitely refer to the underlying independences. We find that its size can be\nintractable, depending on the complexity of the essential graph of the\nequivalence class. The emphasis is put on the potential use of this space with\ngreedy hill -climbing search\n'

In [57]:
# predicted class at position 5, shown as (class index, class name)
# NOTE(review): assumes preds follows the collection's sample order -- confirm
predicted_idx = preds[5].item()
(predicted_idx, partially_labeled_collection.class_strs[predicted_idx])

(1, '2009-2014')