In [1]:
from edge_probing_utils import (
    JiantDatasetSingleSpan,
    JiantDatasetTwoSpan
    )

import edge_probing as ep
import torch
import torch.nn as nn
import torch.utils.data as data
from transformers import AutoModel, AutoTokenizer


**Setup:**

In [2]:
# Tasks to run in this session. Currently disabled (add the name back to
# the list to enable it):
#   "ner", "semeval", "sup-squad", "ques", "sup-babi", "sup-hotpot"
tasks = ["coref"]

# Checkpoint names; each one is handed to AutoTokenizer.from_pretrained
# and to ep.probing by the probing loop.
models = [
    "bert-base-uncased",
    "csarron/bert-base-uncased-squad-v1",
]

# Span arity of every known task. The value selects the dataset wrapper
# (single-span vs. two-span) and is forwarded to ep.probing.
task_type = dict(
    [
        ("ner", "single_span"),
        ("semeval", "single_span"),
        ("coref", "two_span"),
        ("sup-squad", "two_span"),
        ("ques", "single_span"),
        ("sup-babi", "two_span"),
        ("sup-hotpot", "two_span"),
    ]
)

# Raw label string -> integer id, per task. Only "coref" has a mapping so
# far; enabling another task requires adding its entry here as well.
task_label_to_id = {
    "coref": {str(label): label for label in (0, 1)},
}

In [4]:
import torch_xla_py.xla_model as xm
import os

loss_function = nn.BCELoss()
batch_size = 32
# range(1, 13, 2) yields 1, 3, 5, 7, 9, 11; it is passed unchanged to
# ep.probing, which presumably treats the values as layer indices -- confirm.
num_layers = range(1, 13, 2)
num_workers = 8

# Alternative devices, kept so the target can be switched by hand:
#device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
#device = xm.xla_device()
device = torch.device("cpu")

# Disable warnings.
os.environ["TOKENIZERS_PARALLELISM"] = "false"

# Dataset wrapper class for every supported value of task_type[task].
_DATASET_WRAPPERS = {
    "single_span": JiantDatasetSingleSpan,
    "two_span": JiantDatasetTwoSpan,
}


def _load_split(tokenizer, task, split, label_to_id, wrapper, device):
    """Read, tokenize and wrap one split of a task's "small" dataset.

    Args:
        tokenizer: Tokenizer forwarded to ep.tokenize_jiant_dataset.
        task: Task name; selects the directory ../data/{task}/small/.
        split: File stem of the split, e.g. "train" or "val".
        label_to_id: Label-string -> id mapping for the task.
        wrapper: Dataset class applied to the tokenized data.
        device: Device forwarded to ep.tokenize_jiant_dataset.

    Returns:
        The tokenized split wrapped in ``wrapper``.
    """
    tokenized = ep.tokenize_jiant_dataset(
        tokenizer,
        *ep.read_jiant_dataset(f"../data/{task}/small/{split}.jsonl"),
        label_to_id,
        device=device,
    )
    return wrapper(tokenized)


for model in models:
    tokenizer = AutoTokenizer.from_pretrained(model)
    for task in tasks:
        label_to_id = task_label_to_id[task]

        # Resolve the wrapper before reading any data, so an unsupported
        # task type fails immediately instead of silently handing the raw
        # tokenized data to the DataLoader.
        span_type = task_type[task]
        if span_type not in _DATASET_WRAPPERS:
            raise ValueError(
                f"Unsupported task type {span_type!r} for task {task!r}; "
                f"expected one of {sorted(_DATASET_WRAPPERS)}"
            )
        wrapper = _DATASET_WRAPPERS[span_type]

        train_data = _load_split(tokenizer, task, "train", label_to_id, wrapper, device)
        val_data = _load_split(tokenizer, task, "val", label_to_id, wrapper, device)

        # NOTE(review): pin_memory targets host-to-accelerator copies; with
        # the CPU device selected above it is presumably a no-op -- confirm
        # before switching devices.
        train_loader = data.DataLoader(
            train_data,
            batch_size=batch_size,
            shuffle=True,
            num_workers=num_workers,
            pin_memory=True,
        )
        val_loader = data.DataLoader(
            val_data,
            batch_size=batch_size,
            num_workers=num_workers,
            pin_memory=True,
        )

        ep.probing(
            train_loader,
            val_loader,
            model,
            num_layers,
            loss_function,
            span_type,
            device=device,
        )


100%|██████████| 260/260 [00:00<00:00, 102588.81it/s]
100%|██████████| 500/500 [00:00<00:00, 6286.20it/s]
100%|██████████| 31/31 [00:00<00:00, 50299.20it/s]
100%|██████████| 51/51 [00:00<00:00, 4315.13it/s]Reading ../data/coref/small/train.jsonl
Tokenizing
Reading ../data/coref/small/val.jsonl
Tokenizing
Probing layer 1 of 11

Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertEdgeProbingTwoSpan: ['cls.predictions.bias', 'cls.predictions.transform.dense.weight', 'cls.predictions.transform.dense.bias', 'cls.predictions.decoder.weight', 'cls.seq_relationship.weight', 'cls.seq_relationship.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.LayerNorm.bias', 'bert.pooler.dense.weight', 'bert.pooler.dense.bias', 'bert.encoder.layer.1.attention.self.query.weight', 'bert.encoder.layer.1.attention.self.query.bias', 'bert.encoder.layer.1.attention.self.key.weight', 'bert.encoder.layer.1.attention.self.key.bias', 'bert.encode