In [1]:
from mda.data.roberta_dataset import RobertaTokenizeDataset
from mda.data.data_collection import DataCollection
from mda.api import predict
from mda.model.roberta import RobertaClassifier
from mda.util import load_json, AUTO_DEVICE
from repo_root import get_full_path
import torch


In [2]:
# Domain names treated as "in domain" when building the in-domain test set.
IN_DOMAINS = ["deathpenalty", "guncontrol", "immigration", "samesex", "tobacco"]

# Domain name(s) evaluated separately as "out of domain".
OUT_DOMAINS = ["climate"]

### Evaluate the RoBERTa classifier on labeled, in-domain data

In [3]:
# Parse the MFC test split into a DataCollection, then wrap it in a dataset
# that is passed the IN_DOMAINS names via `use_domain_strs`.
mfc_test_path = get_full_path("data/mfc.test.json")
test_collection = DataCollection.parse_obj(load_json(mfc_test_path))

in_domain_test_dataset = RobertaTokenizeDataset(
    collection=test_collection,
    use_domain_strs=IN_DOMAINS,
    batch_size=100,
    num_workers=8,
)


In [4]:
# Build a classifier whose class/domain counts come from the test collection,
# then restore trained weights from disk.
# NOTE(review): judging by the checkpoint path, this is the model trained with
# the "climate" domain held out -- confirm against the training run.
model = RobertaClassifier(
    n_classes=len(test_collection.class_strs),
    n_domains=len(test_collection.domain_strs),
    use_domain_specific_bias=True,
)
checkpoint_path = get_full_path(
    "wkdir/holdout_domain/mfc/roberta_dsbias/climate/checkpoint.pth"
)
# map_location="cpu" lets a checkpoint that was saved from a GPU be loaded on a
# machine without one; the weights are moved to AUTO_DEVICE right after.
model.load_state_dict(torch.load(checkpoint_path, map_location="cpu"))
# Switch to eval mode so dropout is disabled during prediction (harmless if
# predict() already does this). Chained into the assignment so the cell does
# not echo the module repr as its output.
model = model.to(AUTO_DEVICE).eval()

Some weights of the model checkpoint at roberta-base were not used when initializing RobertaModel: ['lm_head.dense.bias', 'lm_head.bias', 'lm_head.decoder.weight', 'lm_head.dense.weight', 'lm_head.layer_norm.weight', 'lm_head.layer_norm.bias']
- This IS expected if you are initializing RobertaModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing RobertaModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


In [5]:
# Run the loaded model over the in-domain test dataset. The next cell takes an
# argmax over the last axis of the result to obtain predicted class indices.
probs = predict(model, in_domain_test_dataset)

100%|██████████| 20/20 [00:10<00:00,  1.90it/s]


In [6]:
# Predicted class index per example: argmax over the last axis of `probs`.
preds = torch.argmax(probs, dim=-1)

# Gold class indices, gathered batch by batch from an unshuffled loader.
# NOTE(review): this assumes predict() visits examples in the same order as
# get_loader(shuffle=False) -- confirm.
in_domain_labels = torch.cat(
    [
        batch["class_idx"]
        for batch in in_domain_test_dataset.get_loader(shuffle=False)
    ],
    dim=0,
)

# Boolean hit mask -> floats (via * 1.0) -> mean gives the accuracy.
in_domain_hits = preds.cpu() == in_domain_labels
acc = (in_domain_hits * 1.0).mean()
acc.item()


0.7049999833106995

### Evaluate the RoBERTa classifier on labeled, out-of-domain data

In [11]:
# Out-of-domain evaluation: the OUT_DOMAINS examples are read from the *train*
# split file.
# NOTE(review): presumably valid because the checkpoint was trained with that
# domain held out, so these examples were never seen in training -- confirm.
train_collection = DataCollection.parse_obj(
    load_json(get_full_path("data/mfc.train.json"))
)
out_domain_test_dataset = RobertaTokenizeDataset(
    batch_size=100,
    num_workers=10,
    collection=train_collection,
    use_domain_strs=OUT_DOMAINS,
)

# NOTE: `probs`, `preds` and `acc` are reused from the in-domain cells, so the
# in-domain values are overwritten once this cell runs.
probs = predict(model, out_domain_test_dataset)
preds = torch.argmax(probs, dim=-1)

# Gold class indices from an unshuffled loader. `shuffle=False` is passed by
# keyword (as in the in-domain cell) so the label order cannot silently depend
# on the position of get_loader's parameters.
out_domain_labels = torch.cat(
    [
        batch["class_idx"]
        for batch in out_domain_test_dataset.get_loader(shuffle=False)
    ],
    dim=0,
)

# Guard against silent broadcasting: mismatched shapes would still yield a
# number from the comparison below, just not the accuracy.
assert preds.shape == out_domain_labels.shape, (
    f"prediction/label shape mismatch: "
    f"{tuple(preds.shape)} vs {tuple(out_domain_labels.shape)}"
)

# Boolean hit mask -> floats (via * 1.0) -> mean gives the accuracy.
acc = ((preds.cpu() == out_domain_labels) * 1.0).mean()
acc.item()


100%|██████████| 38/38 [00:20<00:00,  1.88it/s]


0.6511198878288269