In [1]:
import wandb
# Start a Weights & Biases run for this notebook under the "EntitiesAsExperts" project.
wandb.init(project="EntitiesAsExperts")

# Record the device name in the run config.
# NOTE(review): later cells read wandb.config.squad_batch_size and
# wandb.config.squad_epochs, which are never assigned in this notebook --
# presumably they come from a wandb defaults file or a sweep; confirm.
wandb.config.device = "cpu"

[34m[1mwandb[0m: Currently logged in as: [33merolm_a[0m (use `wandb login --relogin` to force relogin)
[34m[1mwandb[0m: wandb version 0.10.17 is available!  To upgrade, please run:
[34m[1mwandb[0m:  $ pip install wandb --upgrade


In [2]:
import datasets

# Load the two SQuAD scorers one at a time: the v1.1 metric first, then the v2 metric.
squad_metric = datasets.load_metric('squad')
squad_v2_metric = datasets.load_metric('squad_v2')

In [3]:
import datasets
from tools.dataloaders import SQuADDataloader
import torch
from torch.utils.data import DataLoader, SubsetRandomSampler
import numpy as np
# Seed NumPy's global RNG so the subset indices drawn with np.random.choice
# further down in this cell are reproducible across runs.
np.random.seed(42)

# Project dataset wrapper; its train_dataset / validation_dataset attributes
# feed the DataLoaders built below.
squad_dataset = SQuADDataloader()


def squad_collate_fn(rows):
    """Turn a list of per-example dicts into one dict of per-field lists.

    The field names are taken from the first row; every row must provide
    each of those fields. No padding or tensor conversion happens here.
    """
    collated = {field: [] for field in rows[0].keys()}
    for row in rows:
        for field, column in collated.items():
            column.append(row[field])
    return collated

squad_train_dataset = squad_dataset.train_dataset

# When False, train on a fixed random 10% subset of the training split;
# when True, iterate over the whole training split.
FULL_FINETUNING=False
if not FULL_FINETUNING:
    squad_dev_size = int(0.1*len(squad_dataset.train_dataset))
    # replace=False is required here: np.random.choice samples WITH replacement
    # by default, which would put duplicate indices into the subset, and
    # SubsetRandomSampler would then yield those examples more than once per epoch.
    squad_dev_indices = np.random.choice(len(squad_dataset.train_dataset),
                                         size=squad_dev_size,
                                         replace=False)
    # The sampler permutes the chosen indices each epoch; the seeded generator
    # makes that permutation reproducible.
    squad_train_sampler = SubsetRandomSampler(squad_dev_indices,
                                              generator=torch.Generator().manual_seed(42))
    squad_train_dataloader = DataLoader(squad_train_dataset,
                                        sampler=squad_train_sampler,
                                        batch_size=wandb.config.squad_batch_size,
                                        collate_fn=squad_collate_fn)

else:
    # NOTE(review): no sampler/shuffle is passed here, so the full split is
    # visited in dataset order -- confirm that this is intended for training.
    squad_train_dataloader = DataLoader(squad_train_dataset,
                                        batch_size=wandb.config.squad_batch_size,
                                        collate_fn=squad_collate_fn)

# Validation always uses the complete validation split, in dataset order.
squad_validation_dataset = squad_dataset.validation_dataset
squad_validation_dataloader = DataLoader(squad_validation_dataset,
                                         batch_size=wandb.config.squad_batch_size,
                                         collate_fn=squad_collate_fn)

Reusing dataset squad (/root/.cache/huggingface/datasets/squad/plain_text/1.0.0/4c81550d83a2ac7c7ce23783bd8ff36642800e6633c1f18417fb58c3ff50cdd7)
Loading cached processed dataset at /root/.cache/huggingface/datasets/squad/plain_text/1.0.0/4c81550d83a2ac7c7ce23783bd8ff36642800e6633c1f18417fb58c3ff50cdd7/cache-091847320c309abd.arrow
Loading cached processed dataset at /root/.cache/huggingface/datasets/squad/plain_text/1.0.0/4c81550d83a2ac7c7ce23783bd8ff36642800e6633c1f18417fb58c3ff50cdd7/cache-68d7f72cecbe1b9d.arrow


In [4]:
from models import EaEForQuestionAnswering, EntitiesAsExperts
from models.training import train_model, get_optimizer, get_schedule, MetricWrapper, load_model
from tools.dataloaders import WikipediaCBOR
from models import EntitiesAsExperts, EaEForQuestionAnswering

from transformers import BertForMaskedLM, BertForTokenClassification
# Pretrained BERT masked-LM; handed to load_model below as a constructor argument.
model_masked_lm = BertForMaskedLM.from_pretrained('bert-base-uncased')

# Passed positionally to load_model below.
# NOTE(review): presumably the number of transformer layers before (l0) and
# after (l1) the entity memory -- confirm against EntitiesAsExperts.
l0 = 4
l1 = 8

entity_embedding_size = 256 # TODO: move this to the config zone

# Wikipedia CBOR dump plus its partitions directory. Only max_entity_num is
# read from this object in this cell.
wikipedia_cbor = WikipediaCBOR("wikipedia/car-wiki2020-01-01/enwiki2020.cbor", "wikipedia/car-wiki2020-01-01/partitions",
                                       # Disabled option, kept for reference:
                                       # top 2% most frequent items,  roughly at least 100 occurrences, with a total of  ~ 20000 entities
                                       # cutoff_frequency=0.02, recount=True 
                                       # TODO: is this representative enough?
)


# Restore the "pretrained_eae_minimal" checkpoint into an EntitiesAsExperts model.
# NOTE(review): the arguments after the checkpoint name look like the model's
# constructor arguments -- confirm against models.training.load_model.
pretraining_model = load_model(EntitiesAsExperts, "pretrained_eae_minimal", model_masked_lm, l0, l1,
                               wikipedia_cbor.max_entity_num, entity_embedding_size)

Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertForMaskedLM: ['cls.seq_relationship.weight', 'cls.seq_relationship.bias']
- This IS expected if you are initializing BertForMaskedLM from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertForMaskedLM from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


Loaded from cache




In [5]:
from models.device import get_available_device

# NOTE(review): the device is taken from get_available_device(), while the first
# cell separately records wandb.config.device = "cpu"; the two are not linked,
# so the logged config may not match the device actually used -- confirm.
DEVICE = get_available_device()
# TODO: make sure that while training a model gets moved to the DEVICE
# Wrap the pretrained model in the question-answering model and move it to DEVICE.
model_qa = EaEForQuestionAnswering(pretraining_model).to(DEVICE)

# Register the model with the active wandb run.
wandb.watch(model_qa)

[<wandb.wandb_torch.TorchGraph at 0x7efe04c44668>]

In [7]:
from models.training import train_model, get_optimizer, get_schedule, MetricWrapper
import torch

# Number of fine-tuning epochs, read from the wandb run config.
# NOTE(review): squad_epochs is not assigned anywhere in this notebook -- confirm
# that it is supplied through wandb defaults or a sweep.
squad_epochs = wandb.config.squad_epochs

def parse_batch(batch):
    """Convert a collated SQuAD batch into tensors, labels included.

    Returns a pair ``(tensors, extras)``: ``tensors`` is
    ``(input_ids, attention_mask, token_type_ids, answer_start, answer_end)``
    with the attention mask as a float tensor, and ``extras`` is a 1-tuple
    holding the untouched input ``batch``.
    """
    to_tensor = lambda field: torch.tensor(batch[field])

    tensors = (
        to_tensor('input_ids'),
        torch.FloatTensor(batch['attention_mask']),
        to_tensor('token_type_ids'),
        to_tensor('answer_start'),
        to_tensor('answer_end'),
    )
    return tensors, (batch,)

def parse_batch_2(batch):
    """Convert a collated SQuAD batch into model inputs, without answer labels.

    Returns a pair ``(tensors, extras)``: ``tensors`` is
    ``(input_ids, attention_mask, token_type_ids)`` with the attention mask
    as a float tensor, and ``extras`` is a 1-tuple holding the untouched
    input ``batch``.

    The answer spans are not part of the result, so they are no longer
    converted to tensors; a batch without ``answer_start`` / ``answer_end``
    is therefore accepted.
    """
    input_ids = torch.tensor(batch['input_ids'])
    attention_mask = torch.FloatTensor(batch['attention_mask'])
    token_type_ids = torch.tensor(batch['token_type_ids'])

    return (input_ids, attention_mask, token_type_ids), (batch,)



class SQuADMetric(MetricWrapper):
    """Accumulate SQuAD exact-match / F1 and the summed loss over batches.

    Predictions are decoded greedily: the arg-max start and end positions are
    turned back into text through ``squad_dataset.reconstruct_sentences`` and
    scored with the HuggingFace ``squad`` metric. ``compute`` logs the scores
    to wandb.
    """

    def __init__(self, squad_dataset: SQuADDataloader):
        """Keep the dataset wrapper (used to decode predictions) and start empty."""
        self.squad_dataset = squad_dataset
        self.reset()

    def reset(self):
        """Discard everything accumulated so far."""
        # A fresh metric object is loaded so no predictions carry over.
        self.squad_metric = datasets.load_metric('squad')
        self.loss = 0.0

    def add_batch(self, inputs, outputs, loss):
        """Record one batch.

        Args:
            inputs: sequence whose first item is the ``input_ids`` tensor and
                whose last item is the raw collated batch dict (it must carry
                the ``"id"`` and ``"answers"`` fields).
            outputs: ``(total_loss, answer_start_logits, answer_end_logits)``.
            loss: loss of this batch; added to the running total.
        """
        self.loss += float(loss)

        batch_input = inputs[-1]

        # outputs = total_loss, answer_start_logits, answer_end_logits
        answer_start_logits = outputs[1].detach().cpu()
        # The end logits are the THIRD output; reading outputs[1] here would
        # make every predicted span start and end on the same token.
        answer_end_logits = outputs[2].detach().cpu()

        answer_starts = torch.argmax(answer_start_logits, 1).tolist()
        answer_ends = torch.argmax(answer_end_logits, 1).tolist()

        input_ids = inputs[0].detach().cpu().tolist()

        prediction_texts = self.squad_dataset.reconstruct_sentences(input_ids, answer_starts, answer_ends)

        predictions = [{
            "id": example_id,
            "prediction_text": prediction_text,
        } for example_id, prediction_text in zip(batch_input["id"], prediction_texts)]

        references = [{
            "id": example_id,
            "answers": answers
        } for example_id, answers in zip(batch_input["id"], batch_input['answers'])]

        self.squad_metric.add_batch(predictions=predictions, references=references)

    def compute(self, epoch: int) -> float:
        """Score everything accumulated, log it to wandb, and return the loss.

        Args:
            epoch: epoch number attached to the logged values.

        Returns:
            The summed loss accumulated since the last ``reset`` (the same
            value that is logged as ``val_loss``).
        """
        scores = self.squad_metric.compute()
        wandb.log({'exact_match': scores['exact_match'],
                     'epoch': epoch,
                     'f1': scores['f1'],
                     'val_loss': self.loss})
        # The signature promises a float; previously this fell through and returned None.
        return self.loss

# Metric accumulator handed to train_model below.
my_metric = SQuADMetric(squad_dataset)

# Optimizer and schedule are both built from the QA model (model_qa).
optimizer = get_optimizer(model_qa)
scheduler = get_schedule(squad_epochs, optimizer, squad_train_dataloader)

# NOTE(review): this call trains model_qa.eae (the wrapped pretraining model)
# with parse_batch_2, which does not pass the answer_start / answer_end labels,
# while the optimizer above was built from model_qa and SQuADMetric expects
# (loss, start_logits, end_logits) outputs. Presumably the intended call is
# train_model(model_qa, ..., parse_batch, ...) -- confirm. As written, the run
# stops on the first batch with the TypeError shown in this cell's output.
train_model(model_qa.eae, squad_train_dataloader, squad_validation_dataloader,
                parse_batch_2, optimizer, scheduler, squad_epochs, my_metric, gradient_accumulation_factor=1)

  0%|          | 0/1095 [00:00<?, ?it/s]

[tensor([[  101,  4895, 22540,  ...,     0,     0,     0],
        [  101,  1999,  2254,  ...,     0,     0,     0],
        [  101,  2047,  6768,  ...,     0,     0,     0],
        ...,
        [  101,  2039,  2127,  ...,     0,     0,     0],
        [  101,  2144,  1996,  ...,     0,     0,     0],
        [  101, 18847, 10760,  ...,     0,     0,     0]]), tensor([[1., 1., 1.,  ..., 0., 0., 0.],
        [1., 1., 1.,  ..., 0., 0., 0.],
        [1., 1., 1.,  ..., 0., 0., 0.],
        ...,
        [1., 1., 1.,  ..., 0., 0., 0.],
        [1., 1., 1.,  ..., 0., 0., 0.],
        [1., 1., 1.,  ..., 0., 0., 0.]]), tensor([[0, 0, 0,  ..., 0, 0, 0],
        [0, 0, 0,  ..., 0, 0, 0],
        [0, 0, 0,  ..., 0, 0, 0],
        ...,
        [0, 0, 0,  ..., 0, 0, 0],
        [0, 0, 0,  ..., 0, 0, 0],
        [0, 0, 0,  ..., 0, 0, 0]])]


  0%|          | 0/1095 [00:01<?, ?it/s]


TypeError: 'NoneType' object is not subscriptable

In [12]:
# Debug cell: call the wrapped EaE model directly on every training batch,
# passing only input_ids and attention_mask. It reproduces the TypeError seen
# during training (see this cell's output).
for b in squad_train_dataloader:
    b_input, _ = parse_batch_2(b)
    model_qa.eae(b_input[0], b_input[1])
    

TypeError: 'NoneType' object is not subscriptable

In [10]:
# Debug cell: dump the module tree of the wrapped EaE model.
print(model_qa.eae)

EntitiesAsExperts(
  (first_block): TruncatedModel(
    (model): BertModel(
      (embeddings): BertEmbeddings(
        (word_embeddings): Embedding(30522, 768, padding_idx=0)
        (position_embeddings): Embedding(512, 768)
        (token_type_embeddings): Embedding(2, 768)
        (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
        (dropout): Dropout(p=0.1, inplace=False)
      )
      (encoder): TruncatedEncoder(
        (encoder): BertEncoder(
          (layer): ModuleList(
            (0): BertLayer(
              (attention): BertAttention(
                (self): BertSelfAttention(
                  (query): Linear(in_features=768, out_features=768, bias=True)
                  (key): Linear(in_features=768, out_features=768, bias=True)
                  (value): Linear(in_features=768, out_features=768, bias=True)
                  (dropout): Dropout(p=0.1, inplace=False)
                )
                (output): BertSelfOutput(
                  (den

In [8]:
# Debug cell: print only the first collated training batch to inspect its fields.
for b in squad_train_dataloader:
    print(b)
    break

{'answer_end': [58, 88, 10, 69, 3, 25, 31, 123], 'answer_start': [56, 87, 10, 68, 3, 23, 27, 123], 'answers': [{'answer_start': [246], 'text': ['women seeking assistance']}, {'answer_start': [448], 'text': ['Miscellaneous objections']}, {'answer_start': [42], 'text': ['hundred']}, {'answer_start': [359], 'text': ['Starr Pass']}, {'answer_start': [7], 'text': ['eight']}, {'answer_start': [118], 'text': ['the Black Death']}, {'answer_start': [129], 'text': ["President's Private Secretary"]}, {'answer_start': [552], 'text': ['Quran']}], 'attention_mask': [[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,

In [9]:
# Debug cell: print only the first collated validation batch to inspect its fields.
for b in squad_validation_dataloader:
    print(b)
    break

{'answer_end': [35, 46, 83, 35, 98, 100, 66, 29], 'answer_start': [34, 45, 80, 34, 98, 97, 63, 27], 'answers': [{'answer_start': [177, 177, 177], 'text': ['Denver Broncos', 'Denver Broncos', 'Denver Broncos']}, {'answer_start': [249, 249, 249], 'text': ['Carolina Panthers', 'Carolina Panthers', 'Carolina Panthers']}, {'answer_start': [403, 355, 355], 'text': ['Santa Clara, California', "Levi's Stadium", "Levi's Stadium in the San Francisco Bay Area at Santa Clara, California."]}, {'answer_start': [177, 177, 177], 'text': ['Denver Broncos', 'Denver Broncos', 'Denver Broncos']}, {'answer_start': [488, 488, 521], 'text': ['gold', 'gold', 'gold']}, {'answer_start': [487, 521, 487], 'text': ['"golden anniversary"', 'gold-themed', '"golden anniversary']}, {'answer_start': [334, 334, 334], 'text': ['February 7, 2016', 'February 7', 'February 7, 2016']}, {'answer_start': [133, 133, 133], 'text': ['American Football Conference', 'American Football Conference', 'American Football Conference']}],

In [None]:
# Debug cell: run the QA model once on the first validation batch, feeding it
# all five tensors from parse_batch (inputs plus answer start/end labels).
for b in squad_validation_dataloader:
    model_input, _ = parse_batch(b)
    print(model_input)
    model_qa(*model_input)
    break

In [16]:
import torch

# Scratch check: softmax over the top-k scores, then use the result to take a
# weighted sum of the matching rows of an (all-zero) embedding table.
N, d_ent = 1000, 3
E = torch.zeros(N, d_ent)

weights = torch.FloatTensor([1, 2, 3, 4, 5, 6, 7, 8])
topk = weights.topk(k=2)
alpha = topk.values.softmax(dim=0)

# (d_ent, k) @ (k,) -> (d_ent,)
torch.matmul(E[topk.indices].T, alpha).size()

torch.Size([3])

tensor([0.7311, 0.2689])