In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import cloudpickle as pickle
import numpy as np

from pytorch_pretrained_bert import GPT2Model, GPT2Tokenizer, GPT2LMHeadModel

from gpt_model import GPT2SimpleLM

In [2]:
# Load the small GPT-2 checkpoint ("gpt2", the 117M-parameter release) together
# with its language-modeling head; its weights seed the small model below.
model = GPT2LMHeadModel.from_pretrained(
    "gpt2"
)

In [3]:
# We use the medium GPT-2 checkpoint ("gpt2-medium": 24 layers, 1024 hidden,
# 16 heads, ~345M parameters) with its language-modeling head.
model_medium = GPT2LMHeadModel.from_pretrained("gpt2-medium")

## Modify the tokenizer

In [4]:
# Start from the stock GPT-2 BPE tokenizer; the special tokens are appended next.
tokenizer = GPT2Tokenizer.from_pretrained(
    "gpt2"
)

In [5]:
# Extra tokens appended after the pretrained GPT-2 BPE vocabulary.
# (The misspelled name `speical_tokens` is kept because later cells reference it.)
speical_tokens = [
    "[PAD]",
    "[SEP]",
    "[EOS]",
]
GPT2_BASE_VOCAB_SIZE = 50257  # number of entries in the pretrained GPT-2 vocabulary

# Register each special token in both lookup tables with consecutive ids that
# start right after the pretrained vocabulary (the rows appended to the
# embedding tables further down).
for offset, token in enumerate(speical_tokens):
    token_id = GPT2_BASE_VOCAB_SIZE + offset
    # Refuse to silently clobber an id that already belongs to a different token;
    # re-running this cell is still fine because the id then maps to `token` itself.
    existing = tokenizer.decoder.get(token_id, token)
    assert existing == token, f"id {token_id} already used by {existing!r}"
    tokenizer.encoder[token] = token_id
    tokenizer.decoder[token_id] = token

# Attach the list to the tokenizer object so it is pickled along with it.
# NOTE(review): tokenizer.encode() runs BPE on raw text, so a literal "[PAD]" in
# input text presumably does NOT map to these ids — verify before relying on it.
setattr(tokenizer, "__special_tokens__", speical_tokens)

In [6]:
# Sanity check: an ordinary word still encodes through the untouched BPE vocabulary.
probe_text = "URL"
tokenizer.encode(probe_text)

[21886]

In [7]:
# Persist the patched tokenizer. torch.save pickles the whole object (including
# the added encoder/decoder entries), so only torch.load this file from trusted sources.
tokenizer_file = "special3_gpt2_tokenizer.pkl"
torch.save(tokenizer, tokenizer_file)

## Modify the models

In [8]:
def random_interpolate(x, num_base_tokens=50257, num_samples=20, noise_std=0.01):
    """Return a new embedding vector near the mean of random pretrained rows.

    Averages ``num_samples`` rows drawn (with replacement) from the first
    ``num_base_tokens`` rows of ``x`` and adds small Gaussian noise. Used to
    initialize the embedding rows of newly added special tokens.

    Args:
        x: 2-D embedding weight tensor ``(num_rows, embed_dim)``.
        num_base_tokens: only rows ``[0, num_base_tokens)`` are sampled, which
            excludes the not-yet-initialized rows appended for special tokens.
            Clamped to ``x.shape[0]``.
        num_samples: number of rows to average.
        noise_std: standard deviation of the additive Gaussian noise.

    Returns:
        A 1-D tensor of length ``embed_dim`` on ``x``'s device with ``x``'s dtype.

    Raises:
        ValueError: if there are no rows to sample from.
    """
    high = min(num_base_tokens, x.shape[0])
    if high <= 0:
        raise ValueError("x has no rows to sample from")
    # Row choice keeps using numpy's global RNG so existing seeding behaviour is unchanged.
    rows = np.random.randint(high, size=num_samples)
    centroid = x[rows, :].mean(0)
    # Create the noise on x's device/dtype; CPU float32 noise cannot be added
    # to an embedding that lives on a GPU.
    noise = torch.randn(x.shape[-1], dtype=x.dtype, device=x.device) * noise_std
    return centroid + noise

## Small Model

In [9]:
num_special_tokens = len(speical_tokens)

# Build an enlarged input-embedding table: rows [0, vocab_size) receive the
# pretrained GPT-2 embeddings, the trailing rows are reserved for special tokens.
new_embedding = nn.Embedding(model.config.vocab_size + num_special_tokens, model.config.n_embd)
# Slice the source to vocab_size rows so this cell still works if re-run after
# `model.transformer.wte` has already been replaced by the enlarged table.
new_embedding.weight.data[:model.config.vocab_size, :] = \
    model.transformer.wte.weight.data[:model.config.vocab_size, :]

In [10]:
# Initialize every special-token row (one per entry in `speical_tokens`, not a
# hard-coded three) as a noisy average of randomly chosen pretrained embeddings.
for i in range(num_special_tokens):
    new_embedding.weight.data[model.config.vocab_size + i, :] = random_interpolate(new_embedding.weight.data)

In [11]:
# Swap in the enlarged table and point the LM-head projection at the very same
# Parameter object, so input and output embeddings share (tied) weights.
model.transformer.wte = new_embedding
model.lm_head.decoder.weight = new_embedding.weight

In [12]:
class GPT2SmallConfig:
    """GPT2SimpleLM hyper-parameters matching the small "gpt2" checkpoint.

    The vocabulary is the 50257 pretrained BPE entries plus one entry per
    element of `speical_tokens`.
    """

    # vocabulary
    n_special = len(speical_tokens)
    vocab_size = 50257 + n_special
    # context length
    n_positions = 1024
    n_ctx = 1024
    # architecture
    n_embd = 768
    n_layer = 12
    n_head = 12
    # dropout
    resid_pdrop = 0.1
    embd_pdrop = 0.1
    attn_pdrop = 0.1
    # numerics / initialization
    layer_norm_epsilon = 1e-5
    initializer_range = 0.02
    # memory
    gradient_checkpointing = False

In [13]:
# Keep every pretrained tensor except the '.attn.bias' entries (presumably the
# attention-mask buffers, which GPT2SimpleLM does not register — verify), then
# load them into the re-implementation with the default strict key matching.
model_states = {
    name: tensor
    for name, tensor in model.state_dict().items()
    if ".attn.bias" not in name
}

new_model_small = GPT2SimpleLM(GPT2SmallConfig)
new_model_small.load_state_dict(model_states)

IncompatibleKeys(missing_keys=[], unexpected_keys=[])

## Medium Model

In [14]:
num_special_tokens = len(speical_tokens)

# Build an enlarged input-embedding table for the medium model: rows
# [0, vocab_size) receive the pretrained embeddings, the rest are for special tokens.
new_embedding = nn.Embedding(model_medium.config.vocab_size + num_special_tokens, model_medium.config.n_embd)
# Slice the source to vocab_size rows so this cell still works if re-run after
# `model_medium.transformer.wte` has already been replaced by the enlarged table.
new_embedding.weight.data[:model_medium.config.vocab_size, :] = \
    model_medium.transformer.wte.weight.data[:model_medium.config.vocab_size, :]

In [15]:
# Initialize every special-token row (one per entry in `speical_tokens`, not a
# hard-coded three) as a noisy average of randomly chosen pretrained embeddings.
for i in range(num_special_tokens):
    new_embedding.weight.data[model_medium.config.vocab_size + i, :] = random_interpolate(new_embedding.weight.data)

In [16]:
# Swap in the enlarged table and point the LM-head projection at the very same
# Parameter object, so input and output embeddings share (tied) weights.
model_medium.transformer.wte = new_embedding
model_medium.lm_head.decoder.weight = new_embedding.weight

In [17]:
class GPT2MediumConfig:
    """GPT2SimpleLM hyper-parameters matching the "gpt2-medium" checkpoint.

    The vocabulary is the 50257 pretrained BPE entries plus one entry per
    element of `speical_tokens`; gradient checkpointing is switched on.
    """

    # vocabulary
    n_special = len(speical_tokens)
    vocab_size = 50257 + n_special
    # context length
    n_positions = 1024
    n_ctx = 1024
    # architecture
    n_embd = 1024
    n_layer = 24
    n_head = 16
    # dropout
    resid_pdrop = 0.1
    embd_pdrop = 0.1
    attn_pdrop = 0.1
    # numerics / initialization
    layer_norm_epsilon = 1e-5
    initializer_range = 0.02
    # memory
    gradient_checkpointing = True

In [None]:
# Keep every pretrained medium tensor except the '.attn.bias' entries (filtered
# exactly as for the small model), then load them into GPT2SimpleLM.
# Must load the *medium* state dict: `model_states` holds the small model's
# 768-dim / 12-layer weights and would fail strict loading.
model_medium_states = model_medium.state_dict()
model_medium_states = {k: v for k, v in model_medium_states.items() if '.attn.bias' not in k}

new_model_medium = GPT2SimpleLM(GPT2MediumConfig)
new_model_medium.load_state_dict(model_medium_states)