In [1]:
from mda.data.bow_dataset import BagOfWordsSingleBatchDataset
from mda.data.data_collection import DataCollection
from mda.logreg import lexicon_predict, train_lexicon
from mda.util import load_json
from repo_root import get_full_path
import torch

In [2]:
IN_DOMAINS = [
    "deathpenalty",
    "guncontrol",
    "immigration",
    "samesex",
    "tobacco",
]
OUT_DOMAINS = ["climate"]

### Build a vocabulary

In [3]:
train_collection = DataCollection.parse_obj(
    load_json(get_full_path("data/mfc.train.json"))
)
train_dataset = BagOfWordsSingleBatchDataset(
    batch_size=-1,
    num_workers=-1,
    collection=train_collection,
    use_domain_strs=IN_DOMAINS,
    vocab_size=5000,
)
vocab = train_dataset.vocab

In [4]:
# vocab size
len(vocab)

5000

In [5]:
# most common vocab words
vocab[:10]

['said',
 'gun',
 'would',
 'state',
 'new',
 'court',
 'law',
 'marriage',
 'states',
 'people']

### Train a lexicon

In [6]:
lexicon_df = train_lexicon(dataset=train_dataset)

100%|██████████| 5000/5000 [00:13<00:00, 375.66it/s]


In [7]:
lexicon_df

Unnamed: 0,word,Economic,Capacity and Resources,Morality,Fairness and Equality,"Legality, Constitutionality, Jurisdiction",Policy Prescription and Evaluation,Crime and Punishment,Security and Defense,Health and Safety,Quality of Life,Cultural Identity,Public Sentiment,Political,External Regulation and Reputation,Other
0,said,0.067012,0.179501,-0.144857,-5.015747e-03,-0.003872,-0.344471,0.090793,-0.010305,-0.034390,0.076832,-0.156841,-0.000014,-0.005524,0.239455,2.972137e-05
1,gun,0.000019,0.000013,0.112756,1.091244e-04,-0.123926,-0.209876,-0.073423,-0.000035,0.187286,-0.082601,-0.000007,0.176232,-0.026121,0.000043,-7.774811e-06
2,would,-0.000088,0.060649,-0.000031,-1.287555e-02,-0.002946,0.514685,0.000102,-0.091849,0.008815,-0.008114,-0.232657,-0.139944,0.134798,0.077551,-5.086828e-05
3,state,-0.199858,0.142883,-0.127448,-1.527587e-02,0.080521,0.045492,-0.074152,0.000056,0.000041,-0.000861,-0.021786,0.028198,-0.124993,0.240542,-1.417165e-05
4,new,0.144268,-0.076395,0.000052,-1.417813e-05,0.000038,0.153550,-0.174919,0.085598,-0.055154,-0.000029,0.055116,0.069544,-0.220090,0.000067,-3.173600e-05
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
4995,sworn,0.000054,-0.000019,0.000008,8.142626e-02,0.176276,-0.000097,-0.170792,0.000002,0.000008,0.000008,-0.000010,-0.000019,-0.000065,-0.000008,2.918449e-05
4996,commercials,-0.000060,0.000019,0.000015,6.154087e-05,-0.000066,0.000003,-0.000079,0.000013,0.000010,0.000084,-0.049537,0.000050,0.000008,0.000033,-3.333278e-06
4997,profiling,-0.000007,0.000016,-0.000033,4.399460e-03,0.000137,-0.000019,-0.020404,0.000060,-0.000071,-0.000019,-0.000065,0.000061,0.000027,-0.000033,1.055479e-06
4998,auto,-0.000066,-0.000009,-0.000013,3.308873e-07,0.000060,-0.000113,-0.000004,-0.000022,-0.000031,0.000097,0.000109,0.000016,-0.000057,-0.000002,-6.898679e-08


In [8]:
# words with highest weights in frame "Economics"
sorted(
    list(zip(lexicon_df["word"].to_list(), lexicon_df["Economic"].to_list())),
    reverse=True,
    key=lambda x: x[1],
)[:10]


[('economic', 0.9426248669624329),
 ('business', 0.9138430953025818),
 ('economy', 0.8519132137298584),
 ('financial', 0.8444198966026306),
 ('costs', 0.7670855522155762),
 ('jobs', 0.764188289642334),
 ('budget', 0.7124324440956116),
 ('income', 0.7074215412139893),
 ('tax', 0.6934945583343506),
 ('sales', 0.6821240782737732)]

### Evaluate the lexicon on labeled, in domain data

In [9]:
test_collection = DataCollection.parse_obj(
    load_json(get_full_path("data/mfc.test.json"))
)
in_domain_test_dataset = BagOfWordsSingleBatchDataset(
    batch_size=-1,
    num_workers=-1,
    collection=test_collection,
    use_domain_strs=IN_DOMAINS,
    vocab_override=lexicon_df["word"].to_list(),
)


In [10]:
probs = lexicon_predict(
    lexicon_df=lexicon_df,
    dataset=in_domain_test_dataset,
)
preds = torch.argmax(probs, dim=-1)
acc = (
    (
        preds
        == torch.cat(
            [batch["class_idx"] for batch in in_domain_test_dataset.get_loader()], dim=0
        )
    )
    * 1.0
).mean()
acc.item()


0.597000002861023

### Evaluate the lexicon on labeled, out of domain data


In [11]:
out_domain_test_dataset = BagOfWordsSingleBatchDataset(
    batch_size=-1,
    num_workers=-1,
    collection=test_collection,
    use_domain_strs=OUT_DOMAINS,
    vocab_override=lexicon_df["word"].to_list(),
)

In [12]:
probs = lexicon_predict(lexicon_df=lexicon_df, dataset=out_domain_test_dataset)
preds = torch.argmax(probs, dim=-1)
acc = (
    (
        preds
        == torch.cat(
            [batch["class_idx"] for batch in out_domain_test_dataset.get_loader()],
            dim=0,
        )
    )
    * 1.0
).mean()
acc.item()


0.5249999761581421

### Predict with the lexicon on partially labeled, out of domain data

Use the subset of samples that are labeled to estimate a class distribution of this unseen domain, then use it for domain-specific bias when predicting.

In [13]:
# build a partially labeled out-of-domain data collection
partially_labeled_collection = DataCollection(
    class_strs=train_collection.class_strs,
    domain_strs=["climate"],
)
samples = [s for s in train_collection.samples.values() if s.domain_str == "climate"]
for sample in samples[:250]:  # only first 250 samples are labeled
    partially_labeled_collection.add_sample(sample)
for sample in samples[250:]:
    sample.class_idx = sample.class_str = None
    partially_labeled_collection.add_sample(sample)
partially_labeled_collection.populate_class_distribution()


In [14]:
# class distribution estimated from the subset of samples that are labeled
partially_labeled_collection.class_dist

{'climate': [0.06799999999919999,
  0.32399999984559996,
  0.020000000027999995,
  0.008000000035199998,
  0.020000000027999995,
  0.11999999996799997,
  0.004000000037599999,
  0.016000000030399995,
  0.008000000035199998,
  0.036000000018399994,
  0.04800000001119999,
  0.020000000027999995,
  0.17999999993199997,
  0.12799999996319997,
  3.999999997599999e-11]}

In [15]:
partially_labeled_dataset = BagOfWordsSingleBatchDataset(
    batch_size=-1,
    num_workers=-1,
    collection=partially_labeled_collection,
    vocab_override=lexicon_df["word"].to_list(),
)
probs = lexicon_predict(
    lexicon_df=lexicon_df,
    dataset=partially_labeled_dataset,
)
preds = torch.argmax(probs, dim=-1)


In [16]:
# sample text
list(partially_labeled_collection.samples.values())[1000].text

'Heat eases, but more waves may come\nHeat eases, but more waves may come\nThe heat wave that affected at least 200 million people in the United States during the past week and a half has finally subsided, after shattering or tying thousands of records.\nNationally, 1,966 daily high maximum temperature records were broken or tied this month (through July 23). More impressive, however, are the figures for nighttime lows. A whopping 4,376 highest minimum temperature records were broken or tied through July 23.\nIn the Mid-Atlantic, where the the heat peaked July 22-23:\nl Washington Dulles International Airport broke its all-time record July 22, reaching 105 degrees.\nl Baltimore Washington International Thurgood Matshall Airport reached its second-highest all-time temperature July 22, 106 degrees.\nl Washington Reagan National Airport tied its all-time record high minimum temperature of 84 on July 23 and 24.\nIn many ways, this heat wave exemplified the type of extreme heat events that 

In [17]:
# predicted class
preds[251].item(), partially_labeled_collection.class_strs[preds[1000]]

(1, 'Capacity and Resources')