In [1]:

import sys

# Put the parent directory on the module search path so the `src.*` imports
# in the following cell can be resolved (presumably the notebook lives one
# level below the directory that contains `src` -- confirm for your layout).
# The membership check keeps the cell idempotent: re-running it does not
# append a duplicate ".." entry each time.
if ".." not in sys.path:
    sys.path.append("..")

In [2]:
import argparse
import glob
import json
import logging
import os

import torch

from src.data.data_reader import load_and_cache_examples
from src.eval.model_eval import evaluate
from src.model.generation import generate
from src.model.predictions_aggregator import aggregate_predictions
from src.training.trainer import train, _sorted_checkpoints, set_seed
from src.data.synthesis.influence_graph import Rels
from transformers import (
    WEIGHTS_NAME,
    GPT2Config,
    GPT2LMHeadModel,
    GPT2Tokenizer,
)

In [3]:
# Load the pretrained "gpt2-medium" tokenizer. At this point it has not been
# extended with any of the project's relation / hop marker tokens.
tokenizer = GPT2Tokenizer.from_pretrained("gpt2-medium")

In [4]:
# Baseline encoding BEFORE the special tokens are registered: the recorded
# output shows the markers "RELATION-HURT-BY" and "1-HOP" being split into
# many sub-word ids (15 ids for the whole prompt).
print(tokenizer.encode("RELATION-HURT-BY 1-HOP are you??"))

[16448, 6234, 12, 39, 4261, 51, 12, 17513, 352, 12, 39, 3185, 389, 345, 3548]


In [5]:
# Register every marker string held in Rels.special_tokens as an additional
# special token. The call is the cell's last expression, so its return value
# (the recorded output is 9) is displayed.
tokenizer.add_special_tokens(
    {"additional_special_tokens": [*Rels.special_tokens.values()]}
)

9

In [6]:
# Same prompt AFTER registering the special tokens: per the recorded output,
# "RELATION-HURT-BY" and "1-HOP" now map to one id each (50260 and 50261),
# so the prompt shrinks to 5 ids.
print(tokenizer.encode("RELATION-HURT-BY 1-HOP are you??"))

[50260, 50261, 389, 345, 3548]


In [7]:
# Show the id -> token mapping for the tokens added above; the recorded
# output lists nine entries with ids 50257 through 50265.
tokenizer.added_tokens_decoder

{50257: 'RELATION-HELPS',
 50258: 'RELATION-HURTS',
 50259: 'RELATION-HELPED-BY',
 50260: 'RELATION-HURT-BY',
 50261: '1-HOP',
 50262: '2-HOP',
 50263: '3-HOP',
 50264: '<PARA>',
 50265: '<NODE>'}

In [22]:
# Size of the base decoder table only. The recorded value (50257) does not
# grow after add_special_tokens -- the added tokens start at id 50257 and are
# kept in added_tokens_decoder instead.
# NOTE(review): len(tokenizer) is presumably the call to use when the full
# vocabulary size (base + added tokens) is needed -- confirm against the
# installed transformers version.
len(tokenizer.decoder)

50257

In [16]:
# Peek at the last ten entries of the base decoder table; in the recorded
# output the final entry is '<|endoftext|>' and none of the added special
# tokens appear here.
list(tokenizer.decoder.values())[-10:]

['Ġ(/',
 'âĢ¦."',
 'Compar',
 'Ġamplification',
 'ominated',
 'Ġregress',
 'ĠCollider',
 'Ġinformants',
 'Ġgazed',
 '<|endoftext|>']

In [45]:
# Load a fine-tuned GPT-2 LM-head model from a local training checkpoint.
# NOTE(review): the path is hard-coded and relative to the notebook's working
# directory, so this cell only runs where that exact checkpoint exists.
model = GPT2LMHeadModel.from_pretrained("../output/vanilla_reversed_multihop/checkpoint-24.87569-1000/")

In [46]:
# Display the module tree of the loaded checkpoint.
# NOTE(review): the recorded output shows the token embedding as
# Embedding(50264, 1024), i.e. valid ids 0..50263, while the tokenizer built
# above has added tokens up to id 50265 ('<PARA>' = 50264, '<NODE>' = 50265).
# Those two ids would be out of range for this embedding table -- confirm
# whether the checkpoint was trained with a smaller special-token set before
# feeding it text encoded with this tokenizer.
model

GPT2LMHeadModel(
  (transformer): GPT2Model(
    (wte): Embedding(50264, 1024)
    (wpe): Embedding(1024, 1024)
    (drop): Dropout(p=0.1, inplace=False)
    (h): ModuleList(
      (0): Block(
        (ln_1): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
        (attn): Attention(
          (c_attn): Conv1D()
          (c_proj): Conv1D()
          (attn_dropout): Dropout(p=0.1, inplace=False)
          (resid_dropout): Dropout(p=0.1, inplace=False)
        )
        (ln_2): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
        (mlp): MLP(
          (c_fc): Conv1D()
          (c_proj): Conv1D()
          (dropout): Dropout(p=0.1, inplace=False)
        )
      )
      (1): Block(
        (ln_1): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
        (attn): Attention(
          (c_attn): Conv1D()
          (c_proj): Conv1D()
          (attn_dropout): Dropout(p=0.1, inplace=False)
          (resid_dropout): Dropout(p=0.1, inplace=False)
        )
        (ln_2):