In [3]:
from datasets import load_dataset
from transformers import AutoTokenizer, AutoModelWithLMHead, AutoModelForCausalLM
from transformers import DataCollatorForLanguageModeling
from transformers import GPT2Tokenizer, GPT2Model, GPT2LMHeadModel, GPT2Config, GPT2ForQuestionAnswering
from transformers import TrainingArguments, Trainer
from torch.utils.data import Dataset
import copy
import pandas as pd
import re
import transformers
import torch

transformers.logging.set_verbosity_error()


def _load_gpt2(checkpoint):
    """Load the tokenizer and causal-LM model for one GPT-2 checkpoint name.

    The tokenizer's EOS token id is passed to the model as ``pad_token_id``.
    Returns ``(tokenizer, model)``.
    """
    tok = AutoTokenizer.from_pretrained(checkpoint)
    lm = AutoModelForCausalLM.from_pretrained(checkpoint, pad_token_id=tok.eos_token_id)
    return tok, lm


# Medium has 24 layers/GPT2Blocks
med_tokenizer, med_model = _load_gpt2("gpt2-medium")

# Large has 36 layers/GPT2Blocks
large_tokenizer, large_model = _load_gpt2("gpt2-large")

tokenizer_config.json:   0%|          | 0.00/26.0 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/666 [00:00<?, ?B/s]

vocab.json:   0%|          | 0.00/1.04M [00:00<?, ?B/s]

merges.txt:   0%|          | 0.00/456k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/1.36M [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/3.25G [00:00<?, ?B/s]

generation_config.json:   0%|          | 0.00/124 [00:00<?, ?B/s]

In [11]:
def make_pruned(model: "GPT2LMHeadModel", layers: list):
    """Return a copy of a GPT-2 model with the given transformer blocks removed.

    A fresh model of the same class as ``model`` is built from a deep copy of
    ``model.config`` with ``n_layer`` reduced by the number of dropped blocks.
    Every parameter of ``model`` is copied into it except those of the dropped
    blocks, and the surviving blocks are renumbered contiguously (e.g. dropping
    block 1 of 4 maps base blocks 0, 2, 3 onto pruned blocks 0, 1, 2).

    Args:
        model: Source model. Block parameters are expected to be named
            ``transformer.h.<i>.*`` (LM-head model) or ``h.<i>.*`` (bare model).
        layers: Zero-based indices of the blocks to drop. Duplicates are
            ignored; the caller's list is not modified.

    Returns:
        The pruned model, in the same train/eval mode as ``model``; or ``None``
        (after printing a message) if ``layers`` names more blocks than exist
        or contains an out-of-range index.
    """
    n_layer = model.config.n_layer
    # Sorted, de-duplicated copy: never mutate the caller's list, and duplicates
    # must not shrink n_layer more than once.
    drop = sorted(set(layers))
    if n_layer < len(drop):
        print(f"List of layers too long")
        return
    # Generator of booleans, so a valid index of 0 cannot be mistaken for falsy.
    if any(l >= n_layer or l < 0 for l in drop):
        print(f"All layers specified must be indexes _less_ than number of layers available")
        return

    # Splits a per-block parameter/buffer name into (prefix, block index, rest),
    # e.g. "transformer.h.7.attn.c_attn.weight" -> ("transformer.h.", "7", ".attn.c_attn.weight").
    layer_key = re.compile(r"^((?:transformer\.)?h\.)(\d+)(\..+)$")

    def block_index(name):
        """Block index encoded in ``name``, or None for non-block tensors."""
        match = layer_key.match(name)
        return int(match.group(2)) if match else None

    print(f"Pruning {len(drop)} layer(s)...")
    pruned_config = copy.deepcopy(model.config)
    pruned_config.n_layer -= len(drop)
    # Same class as the source, so the head type (e.g. GPT2LMHeadModel) is preserved.
    pruned_model = type(model)(pruned_config)

    state_keys = list(model.state_dict().keys())
    pruned_states = [k for layer in drop for k in state_keys if block_index(k) == layer]
    print(f"Dropping these states: {pruned_states[:3]}+...")

    # Explicit old -> new block index map for every surviving block.
    drop_set = set(drop)
    kept = [i for i in range(n_layer) if i not in drop_set]
    new_index = {old: new for new, old in enumerate(kept)}

    pruned = dict(pruned_model.named_parameters())
    copied_states = []
    for name, param in model.named_parameters():
        match = layer_key.match(name)
        if match:
            old = int(match.group(2))
            if old in drop_set:
                continue
            target = f"{match.group(1)}{new_index[old]}{match.group(3)}"
        else:
            # Embeddings, final layer norm, untied heads: copied under the same name.
            target = name
        copied_states.append(name)
        pruned[target].data = copy.deepcopy(param.data)

    print(f"Copied these states into the pruned model: {copied_states[:3]}+...")
    # A freshly constructed model starts in train mode; mirror the source so that
    # e.g. an eval-mode source does not yield a pruned copy with dropout enabled.
    pruned_model.train(model.training)
    return pruned_model

In [13]:
# Naming: sX = block X dropped; sXeY = blocks X through Y (inclusive) dropped.
# range(X, Y + 1) produces exactly the inclusive index list for sXeY.
pruned_med_s23 = make_pruned(med_model, [23])
pruned_med_s22 = make_pruned(med_model, [22])
pruned_med_s21e22 = make_pruned(med_model, list(range(21, 23)))
pruned_med_s19e22 = make_pruned(med_model, list(range(19, 23)))
pruned_med_s15e22 = make_pruned(med_model, list(range(15, 23)))

pruned_large_s35 = make_pruned(large_model, [35])
pruned_large_s34 = make_pruned(large_model, [34])
pruned_large_s33e34 = make_pruned(large_model, list(range(33, 35)))
pruned_large_s31e34 = make_pruned(large_model, list(range(31, 35)))
pruned_large_s27e34 = make_pruned(large_model, list(range(27, 35)))
pruned_large_s19e34 = make_pruned(large_model, list(range(19, 35)))

Pruning 1 layer(s)...
Dropping these states: ['transformer.h.23.ln_1.weight', 'transformer.h.23.ln_1.bias', 'transformer.h.23.attn.c_attn.weight']+...
Copied these states into the pruned model: ['transformer.wte.weight', 'transformer.wpe.weight', 'transformer.h.0.ln_1.weight']+...
Pruning 1 layer(s)...
Dropping these states: ['transformer.h.22.ln_1.weight', 'transformer.h.22.ln_1.bias', 'transformer.h.22.attn.c_attn.weight']+...
Copied these states into the pruned model: ['transformer.wte.weight', 'transformer.wpe.weight', 'transformer.h.0.ln_1.weight']+...
Pruning 2 layer(s)...
Dropping these states: ['transformer.h.21.ln_1.weight', 'transformer.h.21.ln_1.bias', 'transformer.h.21.attn.c_attn.weight']+...
Copied these states into the pruned model: ['transformer.wte.weight', 'transformer.wpe.weight', 'transformer.h.0.ln_1.weight']+...
Pruning 4 layer(s)...
Dropping these states: ['transformer.h.19.ln_1.weight', 'transformer.h.19.ln_1.bias', 'transformer.h.19.attn.c_attn.weight']+...
Cop

In [15]:
# (Hub repo name, pruned model) pairs. Keeping each name beside its model means
# the two parallel lists derived below cannot drift out of alignment.
pruned_variants = [
    ("gpt2_med_s23", pruned_med_s23),
    ("gpt2_med_s22", pruned_med_s22),
    ("gpt2_med_s21e22", pruned_med_s21e22),
    ("gpt2_med_s19e22", pruned_med_s19e22),
    ("gpt2_med_s15e22", pruned_med_s15e22),
    ("gpt2_large_s35", pruned_large_s35),
    ("gpt2_large_s34", pruned_large_s34),
    ("gpt2_large_s33e34", pruned_large_s33e34),
    ("gpt2_large_s31e34", pruned_large_s31e34),
    ("gpt2_large_s27e34", pruned_large_s27e34),
    ("gpt2_large_s19e34", pruned_large_s19e34),
]
models = [variant for _, variant in pruned_variants]
names = [repo for repo, _ in pruned_variants]

In [None]:
from huggingface_hub import login
# Authenticate with the Hugging Face Hub. No token is passed here, so it must be
# supplied interactively (never hardcode it in the notebook). The push_to_hub
# cell further down depends on this session; presumably the token needs write
# access — TODO confirm.
login()

In [20]:
# Upload each pruned model to a Hub repo named after it.
# zip() silently truncates to the shorter list, and make_pruned returns None on
# invalid input, so fail loudly here instead of pushing a partial/incorrect set.
assert len(models) == len(names), "models and names must have the same length"
assert all(m is not None for m in models), "a pruned model is None (make_pruned rejected its layers)"
for model, name in zip(models, names):
    # NOTE(review): only the model is pushed, no tokenizer files; loading a
    # tokenizer from these repos likely requires the base gpt2 checkpoint — confirm.
    model.push_to_hub(name)

README.md:   0%|          | 0.00/5.18k [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/1.37G [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/1.32G [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/1.22G [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/1.02G [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/3.02G [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/3.02G [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/2.94G [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/2.78G [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/2.47G [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/1.84G [00:00<?, ?B/s]