In [1]:
import os
# Set CUDA_VISIBLE_DEVICES for this process to the device list '6, 7'.
# NOTE(review): presumably this has to run before any CUDA-using library is
# initialised in the kernel, and the indices are specific to the machine the
# notebook was originally run on -- confirm before reusing elsewhere.
os.environ["CUDA_VISIBLE_DEVICES"] = '6, 7'

In [2]:
import json
from os import path
from argparse import ArgumentParser
from ud_treebank_utils.reader import UDTreebankReader
from os import path
from config import cfg as default_cfg
from utils import *

# `cfg` is bound to the very same object as the imported `default_cfg` (no copy
# is made), so any later item assignment on `cfg` also changes `default_cfg`.
cfg = default_cfg
# Device entry of the configuration; the cells below pass it to the trainer and
# to the value models.
device = cfg['device']

In [3]:
# Resolve the train / validation / test treebank files for the language and
# embedding selected in the configuration.
treebank, treebank_valid, treebank_test = (
    UDTreebankReader.get_treebank_file(cfg['language'], embedding=cfg['embedding'], **split_flag)
    for split_flag in ({}, {'valid_file': True}, {'test_file': True})
)

print(f"Treebank: {treebank}")
print(f"Valid Treebank: {treebank_valid}")
print(f"Test Treebank: {treebank_test}")

# Load the words of every split. Each word comes with the values of its
# attributes, e.g. Shaikh({'Number': 'SG', 'Part of Speech': 'PROPN'}).
words, words_valid, words_test = (
    UDTreebankReader.read([split_file])
    for split_file in (treebank, treebank_valid, treebank_test)
)

# One attribute/value counter per split, in train, validation, test order.
counters = [
    UDTreebankReader.get_attribute_value_counter(split_words)
    for split_words in (words, words_valid, words_test)
]

# Dictionary of the attributes to probe and their values, e.g.
# {'Number': ['SG', 'PL'],
#  'Part of Speech': ['PROPN', 'ADJ', 'N', 'V', 'V.PTCP', 'NUM', 'ADV'],
#  'Tense': ['PST', 'PRS']}
# NOTE(review): min_count=100 presumably drops values that occur too rarely
# in the splits -- confirm against UDTreebankReader.
attr_vals_dict = UDTreebankReader.get_attributes_to_values_dict_from_counters(counters, min_count=100)

# All three readers share the same attribute -> values dictionary.
reader, reader_valid, reader_test = (
    UDTreebankReader(split_words, attr_vals_dict)
    for split_words in (words, words_valid, words_test)
)

Treebank: /data2/zhihao/intrinsic-probing/data/ud/ud-treebanks-v2.1/UD_English/en-um-train-bert-base-multilingual-cased.pkl
Valid Treebank: /data2/zhihao/intrinsic-probing/data/ud/ud-treebanks-v2.1/UD_English/en-um-dev-bert-base-multilingual-cased.pkl
Test Treebank: /data2/zhihao/intrinsic-probing/data/ud/ud-treebanks-v2.1/UD_English/en-um-test-bert-base-multilingual-cased.pkl


In [4]:
from ud_treebank_utils.cache import AttributeValueGaussianCache
from ud_treebank_utils.trainer import MLETrainer, MAPTrainer

print("Building caches...")
# Pick the estimator named by the configuration.
if cfg['trainer'] == "mle":
    trainer = MLETrainer()
elif cfg['trainer'] == "map":
    trainer = MAPTrainer.from_data(device=device)
else:
    # Fail fast: without this branch an unrecognised name would leave `trainer`
    # undefined, or silently reuse a stale `trainer` left in the kernel by an
    # earlier run of this cell.
    raise ValueError(f"Unknown trainer '{cfg['trainer']}': expected 'mle' or 'map'.")

cache_attr_vals_dict = attr_vals_dict


def _build_cache(split_reader):
    """Build the Gaussian cache for the words of one data split.

    Every split uses the same trainer, the same attribute -> values dictionary
    and the same `diagonalize` setting, so the caches differ only in the words
    they are built from.
    """
    return AttributeValueGaussianCache(
        split_reader.get_words(), trainer=trainer, attribute_values_dict=cache_attr_vals_dict,
        diagonal_only=cfg['diagonalize'])


cache = _build_cache(reader)
cache_valid = _build_cache(reader_valid)
cache_test = _build_cache(reader_test)

# Probe either the single attribute requested in the configuration, or every
# attribute in the dictionary except the ones in `ignore_list`.
if cfg['attribute'] is not None:
    attributes_queue = [cfg['attribute']]
else:
    ignore_list = ["Part of Speech"]
    attributes_queue = [x for x in attr_vals_dict if x not in ignore_list]

print(f"Attributes queue: {attributes_queue}")

Building caches...


  embeddings_scatter_prior = (embeddings_mean - self.mu) @ (embeddings_mean - self.mu).T
Build Cache: 100%|██████████| 3/3 [00:12<00:00,  4.27s/it]
Build Cache: 100%|██████████| 3/3 [00:06<00:00,  2.07s/it]
Build Cache: 100%|██████████| 3/3 [00:06<00:00,  2.14s/it]

Attributes queue: ['Number', 'Tense']





In [6]:
from ud_treebank_utils.models import ValueModel
from ud_treebank_utils.runner import Runner

# wandb is needed only when logging is enabled, so import it once here instead
# of on every loop iteration.
if cfg['log_wandb'] is True:
    import wandb

for attribute in attributes_queue:
    # Sanity checks. These raise instead of calling exit(): inside a notebook
    # exit() shuts the kernel down, which would throw away the caches built above.
    if not cache.has_attribute(attribute):
        raise ValueError(f"Attribute '{attribute}' does not exist in this dataset/language combination.")

    # All possible values of the attribute, e.g. ['SG', 'PL'] for 'Number'.
    attribute_values = cache.get_all_attribute_values(attribute)
    if len(attribute_values) < 2:
        raise ValueError(f"Attribute '{attribute}' has fewer than 2 values in this dataset/language combination.")

    if cfg['log_wandb'] is True:
        tags = [cfg['language'], cfg['embedding'], attribute]
        if cfg['tag'] is not None:
            tags.append(cfg['tag'])

        # Hand wandb a per-run copy of the configuration rather than writing the
        # attribute into the shared `cfg`: mutating `cfg` would make a re-run of
        # the queue-building cell see a non-None cfg['attribute'].
        # NOTE(review): assumes `cfg` is a plain mapping -- confirm.
        # NOTE(review): the original script passed config=args here, so wandb
        # logging may not have been exercised with this notebook's cfg -- confirm.
        run_cfg = {**cfg, 'attribute': attribute}
        run = wandb.init(project="interp-bert", tags=tags, config=run_cfg, reinit=True)

        # Run name: "<attribute> (<embedding>-<language>) (<criterion>[, diag]) [<run id>]"
        diag_suffix = ", diag" if cfg['diagonalize'] else ""
        run.name = (
            f"{attribute} ({cfg['embedding']}-{cfg['language']}) "
            f"({cfg['selection_criterion']}{diag_suffix}) [{run.id}]"
        )
        run.save()

    print("Computing MI for '{}'. Possible values: {}".format(attribute, attribute_values))

    # Build one value model per split from the cache entries of each value.
    value_model = ValueModel.from_cache_entries(
        [cache.get_cache_entry(attribute, v) for v in attribute_values], device=device
    )

    value_model_valid = ValueModel.from_cache_entries(
        [cache_valid.get_cache_entry(attribute, v) for v in attribute_values], device=device
    )

    value_model_test = ValueModel.from_cache_entries(
        [cache_test.get_cache_entry(attribute, v) for v in attribute_values], device=device
    )

    runner_config = {
        "reader": reader,
        "reader_valid": reader_valid,
        "reader_test": reader_test,
        "device": device,
        "cache": cache,
        "cache_valid": cache_valid,
        "cache_test": cache_test,
        "value_model": value_model,
        "value_model_valid": value_model_valid,
        "value_model_test": value_model_test,
        "attribute": attribute,
        "selection_criterion": cfg['selection_criterion'],
    }

    if cfg['log_wandb'] is True:
        runner_config["wandb_run"] = run

    total_dims = reader.get_dimensionality()  # 768 in the recorded run
    runner = Runner(runner_config)
    selected_results = runner.main_loop(max_iter=cfg['max_iter'])

    # Draw graphs
    graphs = runner.draw_graphs(selected_results)
    mi_fig = graphs["mi"]
    normalized_mi_fig = graphs["normalized_mi"]
    accuracy_fig = graphs["accuracy"]

    # The scatter plot uses the first two selected dimensions, so it can only be
    # drawn when at least two results exist (e.g. not when max_iter is 1).
    scatter_fig = None
    if len(selected_results) >= 2:
        scatter_fig = runner.plot_dims(
            selected_results[0]["candidate_dim"], selected_results[1]["candidate_dim"], test_data=True,
            log_prob_dim_pool=list(selected_results[-1]["candidate_dim_pool"])
        )
    else:
        print(f"Skipping scatter plot for '{attribute}': fewer than 2 dimensions were selected.")

    if cfg['show_charts'] is True:
        mi_fig.show()
        normalized_mi_fig.show()
        accuracy_fig.show()
        if scatter_fig is not None:
            scatter_fig.show()

Computing MI for 'Number'. Possible values: ['SG', 'PL']



Creating a tensor from a list of numpy.ndarrays is extremely slow. Please consider converting the list to a single numpy.ndarray with numpy.array() before converting to a tensor. (Triggered internally at ../torch/csrc/utils/tensor_new.cpp:261.)

100%|██████████| 768/768 [00:19<00:00, 40.40it/s]


Selected '387'
	I(H_[387]; V_a): 0.04658160232393693
	Accuracy: 0.8424102067947388
	Confusion Matrix: 
[[7208, 58], [1302, 62]]



100%|██████████| 767/767 [00:17<00:00, 44.04it/s]


Selected '160'
	I(H_[160, 387]; V_a): 0.0941754299385682
	Accuracy: 0.8412514328956604
	Confusion Matrix: 
[[7039, 227], [1143, 221]]



100%|██████████| 766/766 [00:21<00:00, 35.03it/s]


Selected '453'
	I(H_[160, 387, 453]; V_a): 0.13157073437588895
	Accuracy: 0.8524913191795349
	Confusion Matrix: 
[[7000, 266], [1007, 357]]



100%|██████████| 765/765 [00:26<00:00, 28.97it/s]


Selected '450'
	I(H_[160, 450, 387, 453]; V_a): 0.17018872043860606
	Accuracy: 0.8626883029937744
	Confusion Matrix: 
[[6949, 317], [868, 496]]



100%|██████████| 764/764 [00:25<00:00, 29.70it/s]


Selected '223'
	I(H_[160, 450, 387, 453, 223]; V_a): 0.20026389471371775
	Accuracy: 0.874275803565979
	Confusion Matrix: 
[[6967, 299], [786, 578]]

Full Vector Accuracy: 0.9646582007408142



torch.cholesky is deprecated in favor of torch.linalg.cholesky and will be removed in a future PyTorch release.
L = torch.cholesky(A)
should be replaced with
L = torch.linalg.cholesky(A)
and
U = torch.cholesky(A, upper=True)
should be replaced with
U = torch.linalg.cholesky(A).mH
This transform will produce equivalent results for all valid (symmetric positive definite) inputs. (Triggered internally at ../aten/src/ATen/native/BatchLinearAlgebra.cpp:1692.)



Computing MI for 'Tense'. Possible values: ['PST', 'PRS']


100%|██████████| 768/768 [00:09<00:00, 80.74it/s]


Selected '477'
	I(H_[477]; V_a): 0.17616291681608764
	Accuracy: 0.7048412561416626
	Confusion Matrix: 
[[740, 241], [326, 614]]



100%|██████████| 767/767 [00:14<00:00, 52.15it/s]


Selected '179'
	I(H_[179, 477]; V_a): 0.3174778493891156
	Accuracy: 0.771473228931427
	Confusion Matrix: 
[[813, 168], [271, 669]]



100%|██████████| 766/766 [00:18<00:00, 40.45it/s]


Selected '753'
	I(H_[753, 179, 477]; V_a): 0.44033135860377326
	Accuracy: 0.8172826766967773
	Confusion Matrix: 
[[802, 179], [172, 768]]



100%|██████████| 765/765 [00:25<00:00, 30.20it/s]


Selected '464'
	I(H_[464, 753, 179, 477]; V_a): 0.4808665583377604
	Accuracy: 0.8375846147537231
	Confusion Matrix: 
[[826, 155], [157, 783]]



100%|██████████| 764/764 [00:25<00:00, 30.48it/s]


Selected '554'
	I(H_[464, 753, 179, 554, 477]; V_a): 0.542170257108153
	Accuracy: 0.8599687814712524
	Confusion Matrix: 
[[855, 126], [143, 797]]

Full Vector Accuracy: 0.985424280166626
