In [20]:
import torch
import pandas as pd
from copy import deepcopy
# Do not truncate column contents when displaying frames, so the long
# state-dict key names shown below stay fully readable.
pd.set_option('display.max_colwidth', None)

### Function to Transform OpenAI States into OPS States

In [30]:
def load_openai_weight(ops_state_dict, openai_state_dict, args):
    """Return a copy of ``ops_state_dict`` filled with OpenAI GPT-2 parameters.

    The OpenAI checkpoint names its blocks ``h.{n}.<suffix>`` while the OPS
    model uses ``transformer.decoder.decoder_blocks.{n}.<suffix>``.  For every
    block ``n`` in ``range(args.N)``:

    * the four attention/MLP projection weights are copied transposed, because
      the OpenAI implementation stores them in Conv1D layout (transpose of the
      layout used by the OPS model);
    * every other tensor of the block (biases, layer norms, ...) is copied
      as-is when the OPS state dict has an entry with the same suffix and the
      same shape.  Entries without such a counterpart (e.g. the OpenAI
      ``attn.bias`` buffer when OPS has no matching key) are skipped.

    The token and position embeddings are copied from ``wte.weight`` and
    ``wpe.weight``.  The final layer norm (``ln_f.*``) is copied only if
    exactly one OPS key ends with that name and has the same shape.

    Args:
        ops_state_dict: mapping of OPS parameter name -> tensor.  Not modified.
        openai_state_dict: mapping of OpenAI parameter name -> tensor.
        args: object exposing ``N``, the number of decoder blocks to convert.

    Returns:
        A new dict with the keys of ``ops_state_dict`` whose values come from
        ``openai_state_dict`` wherever a mapping exists.  Copied tensors are
        references to (or transposed views of) the OpenAI tensors, not clones.

    Raises:
        KeyError: if an embedding or one of the four projection weights is
            missing from ``openai_state_dict``.
    """
    state = deepcopy(ops_state_dict)

    # Weights that need a transpose: Conv1D layout in the OpenAI checkpoint.
    conv1d_weight_suffixes = ('attn.c_attn.weight',
                              'attn.c_proj.weight',
                              'mlp.c_fc.weight',
                              'mlp.c_proj.weight')

    for i in range(args.N):
        openai_prefix = 'h.{n}.'.format(n=i)
        ops_prefix = 'transformer.decoder.decoder_blocks.{n}.'.format(n=i)

        # Mandatory projection weights: a missing key raises KeyError.
        for suffix in conv1d_weight_suffixes:
            state[ops_prefix + suffix] = openai_state_dict[openai_prefix + suffix].transpose(-1, -2)

        # Remaining per-block tensors (biases, layer norms).  Previously these
        # were left at the OPS model's own values instead of the pretrained ones.
        for openai_key, tensor in openai_state_dict.items():
            if not openai_key.startswith(openai_prefix):
                continue
            suffix = openai_key[len(openai_prefix):]
            if suffix in conv1d_weight_suffixes:
                continue
            ops_key = ops_prefix + suffix
            # Only overwrite entries the OPS model really has, with equal shape.
            if ops_key in state and state[ops_key].shape == tensor.shape:
                state[ops_key] = tensor

    state['transformer.embedding.embedding.weight'] = openai_state_dict['wte.weight']
    state['transformer.pos_embedding.pos_embedding.weight'] = openai_state_dict['wpe.weight']

    # NOTE(review): the OPS key of the final layer norm is not known here, so it
    # is located by suffix and copied only on an unambiguous, same-shape match.
    # Confirm against the OPS model that ln_f ends up loaded.
    for name in ('ln_f.weight', 'ln_f.bias'):
        if name not in openai_state_dict:
            continue
        tensor = openai_state_dict[name]
        candidates = [k for k in state if k == name or k.endswith('.' + name)]
        if len(candidates) == 1 and state[candidates[0]].shape == tensor.shape:
            state[candidates[0]] = tensor

    return state

### Loading OPS GPT-2 Model State Dict

In [22]:
# Project-local helpers: the model factory and the config loader.
from gpt2.model import get_gpt2
from gpt2.utils import load_config

# `args` is reused further down: load_openai_weight reads args.N as the
# number of decoder blocks to convert.
args = load_config('config.yml')
# Presumably builds a freshly initialised GPT-2 from the config — confirm in gpt2.model.
model = get_gpt2(args)
# Parameter-name -> tensor mapping of the OPS model; its keys are the target
# naming scheme for the conversion.
ops_state_dic = model.state_dict()

In [23]:
# Tabulate each entry of the OPS state dict as [name, shape].
ops_state_info = [[name, list(tensor.size())]
                  for name, tensor in ops_state_dic.items()]

ops_df = pd.DataFrame(data=ops_state_info, columns=['State', 'Shape'])
ops_df.head(14), ops_df.tail(3)

(                                                      State         Shape
 0                    transformer.embedding.embedding.weight  [50257, 768]
 1            transformer.pos_embedding.pos_embedding.weight   [1024, 768]
 2   transformer.decoder.decoder_blocks.0.attn.c_attn.weight   [2304, 768]
 3     transformer.decoder.decoder_blocks.0.attn.c_attn.bias        [2304]
 4   transformer.decoder.decoder_blocks.0.attn.c_proj.weight    [768, 768]
 5     transformer.decoder.decoder_blocks.0.attn.c_proj.bias         [768]
 6      transformer.decoder.decoder_blocks.0.mlp.c_fc.weight   [3072, 768]
 7        transformer.decoder.decoder_blocks.0.mlp.c_fc.bias        [3072]
 8    transformer.decoder.decoder_blocks.0.mlp.c_proj.weight   [768, 3072]
 9      transformer.decoder.decoder_blocks.0.mlp.c_proj.bias         [768]
 10         transformer.decoder.decoder_blocks.0.ln_1.weight         [768]
 11           transformer.decoder.decoder_blocks.0.ln_1.bias         [768]
 12         transformer.d

In [24]:
# (rows, columns): one row per OPS state-dict entry, columns 'State' and 'Shape'.
ops_df.shape

(149, 2)

### Loading OpenAI GPT-2 Model State Dict

In [8]:
# Location of the OpenAI GPT-2 checkpoint to convert.
path = './assets/gpt2-pytorch_model.bin'
# NOTE(review): torch.load unpickles the file, so only load checkpoints from a
# trusted source.
# Tensors are mapped onto the GPU when one is available, otherwise onto the CPU.
openai_state_dic = torch.load(f=path, 
                               map_location=torch.device('cuda' if torch.cuda.is_available() else 'cpu'))

In [9]:
# Tabulate each entry of the OpenAI checkpoint as [name, shape].
openai_state_info = [[name, list(tensor.size())]
                     for name, tensor in openai_state_dic.items()]

openai_df = pd.DataFrame(data=openai_state_info, columns=['State', 'Shape'])
openai_df.head(15), openai_df.tail(2)

(                     State               Shape
 0               wte.weight        [50257, 768]
 1               wpe.weight         [1024, 768]
 2          h.0.ln_1.weight               [768]
 3            h.0.ln_1.bias               [768]
 4            h.0.attn.bias  [1, 1, 1024, 1024]
 5   h.0.attn.c_attn.weight         [768, 2304]
 6     h.0.attn.c_attn.bias              [2304]
 7   h.0.attn.c_proj.weight          [768, 768]
 8     h.0.attn.c_proj.bias               [768]
 9          h.0.ln_2.weight               [768]
 10           h.0.ln_2.bias               [768]
 11     h.0.mlp.c_fc.weight         [768, 3072]
 12       h.0.mlp.c_fc.bias              [3072]
 13   h.0.mlp.c_proj.weight         [3072, 768]
 14     h.0.mlp.c_proj.bias               [768],
            State  Shape
 158  ln_f.weight  [768]
 159    ln_f.bias  [768])

### Transforming OpenAI Pretrained States into OPS States

In [31]:
# Build a state dict with the OPS key names whose values come from the OpenAI
# checkpoint (see load_openai_weight); ops_state_dic itself is left untouched
# because the function works on a deep copy.
ops_pretrained_state_dic = load_openai_weight(ops_state_dic, openai_state_dic, args)

In [32]:
# Tabulate each entry of the converted state dict as [name, shape] so it can be
# compared with the table of the freshly initialised OPS model.
ops_pretrained_state_info = [[name, list(tensor.size())]
                             for name, tensor in ops_pretrained_state_dic.items()]

ops_pretrained_df = pd.DataFrame(ops_pretrained_state_info, columns=['State', 'Shape'])
# Show head AND tail of the converted frame (the tail previously came from
# ops_df, i.e. the unconverted model, by copy-paste mistake).
ops_pretrained_df.head(14), ops_pretrained_df.tail(3)

(                                                      State         Shape
 0                    transformer.embedding.embedding.weight  [50257, 768]
 1            transformer.pos_embedding.pos_embedding.weight   [1024, 768]
 2   transformer.decoder.decoder_blocks.0.attn.c_attn.weight   [2304, 768]
 3     transformer.decoder.decoder_blocks.0.attn.c_attn.bias        [2304]
 4   transformer.decoder.decoder_blocks.0.attn.c_proj.weight    [768, 768]
 5     transformer.decoder.decoder_blocks.0.attn.c_proj.bias         [768]
 6      transformer.decoder.decoder_blocks.0.mlp.c_fc.weight   [3072, 768]
 7        transformer.decoder.decoder_blocks.0.mlp.c_fc.bias        [3072]
 8    transformer.decoder.decoder_blocks.0.mlp.c_proj.weight   [768, 3072]
 9      transformer.decoder.decoder_blocks.0.mlp.c_proj.bias         [768]
 10         transformer.decoder.decoder_blocks.0.ln_1.weight         [768]
 11           transformer.decoder.decoder_blocks.0.ln_1.bias         [768]
 12         transformer.d

In [33]:
# (rows, columns) of the converted table; the row count should equal ops_df's,
# since the conversion is expected to keep the OPS key set unchanged.
ops_pretrained_df.shape

(149, 2)

### Loading OpenAI Pretrained States into OPS States

In [35]:
# Copy the converted tensors into the OPS model. With the default strict
# loading, PyTorch raises if the key set or any tensor shape does not match.
model.load_state_dict(ops_pretrained_state_dic)

In [None]:
# Persist the model's state dict (after the load above) for later reuse.
torch.save(model.state_dict(), "./assets/ops_gpt2_pretrained_states.pth")