In [None]:
%%capture
#export
from fastcore.all import *
from fastai2.basics import *

from transformers import AutoModelForSequenceClassification, AutoModelWithLMHead

In [None]:
# default_exp model_splits
# all_skip

# Model Splits
> Splitter functions that divide a model into parameter groups for `learn.freeze_to()`. To write one for a new model: print the model, look at its architecture, then write down the split.

## bert_SeqClassification_split

In [None]:
#export
def bert_SeqClassification_split(m:nn.Module): 
    '''Split a BERT sequence-classification model into parameter groups.

    To derive a split like this one: print the model, look at its
    architecture, then write down the split.

    The groups are, in order: the embeddings, one group per entry of
    `m.bert.encoder.layer`, the pooler, and the classifier.

    Args:
        m: a model exposing `bert.embeddings`, `bert.encoder.layer`,
            `bert.pooler` and `classifier`.

    Returns:
        An `L` with one entry per group, each produced by calling `params`
        on the corresponding module.
    '''
    # bert-base: 12 layers, 110M params
    return L(m.bert.embeddings, *m.bert.encoder.layer, m.bert.pooler, m.classifier).map(params)

In [None]:
# Instantiate the pretrained 'bert-base-uncased' checkpoint and display its
# module tree; the printed architecture is what the BERT split is read from.
model = AutoModelForSequenceClassification.from_pretrained('bert-base-uncased')
model

BertForSequenceClassification(
  (bert): BertModel(
    (embeddings): BertEmbeddings(
      (word_embeddings): Embedding(30522, 768, padding_idx=0)
      (position_embeddings): Embedding(512, 768)
      (token_type_embeddings): Embedding(2, 768)
      (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (encoder): BertEncoder(
      (layer): ModuleList(
        (0): BertLayer(
          (attention): BertAttention(
            (self): BertSelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
            (output): BertSelfOutput(
              (dense): Linear(in_features=768, out_features=768, bias=True)
              (LayerNorm): LayerNorm((768,), eps=1e-12, element

## roberta_SeqClassification_split

In [None]:
#export
def roberta_SeqClassification_split(m:nn.Module): 
    '''Split a RoBERTa sequence-classification model into parameter groups.

    Groups, in order: the embeddings, one group per entry of
    `m.roberta.encoder.layer`, the pooler, and the classifier.

    Returns an `L` with one entry per group, each produced by `params`.
    '''
    backbone = m.roberta
    modules = [backbone.embeddings, *backbone.encoder.layer, backbone.pooler, m.classifier]
    # One parameter group per module, listed from input side to output side.
    return L(*modules).map(params)

In [None]:
# Instantiate the pretrained 'roberta-base' checkpoint and display its module
# tree; the printed architecture is what the RoBERTa split is read from.
model = AutoModelForSequenceClassification.from_pretrained('roberta-base')
model

RobertaForSequenceClassification(
  (roberta): RobertaModel(
    (embeddings): RobertaEmbeddings(
      (word_embeddings): Embedding(50265, 768, padding_idx=1)
      (position_embeddings): Embedding(514, 768, padding_idx=1)
      (token_type_embeddings): Embedding(1, 768)
      (LayerNorm): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (encoder): BertEncoder(
      (layer): ModuleList(
        (0): BertLayer(
          (attention): BertAttention(
            (self): BertSelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
            (output): BertSelfOutput(
              (dense): Linear(in_features=768, out_features=768, bias=True)
              (LayerNorm): LayerNorm

## gpt2_lmhead_split

In [None]:
#export
def gpt2_lmhead_split(m:nn.Module): 
    '''Split a GPT-2 language-model-head model into parameter groups.

    Groups, in order: the token embedding (`wte`), the position embedding
    (`wpe`), one group per entry of `m.transformer.h`, and a final head group
    holding the closing LayerNorm (`m.transformer.ln_f`) together with
    `m.lm_head`.

    `ln_f` is placed in the head group so that every parameter of the model
    belongs to some group; a parameter left out of all groups is never handed
    to the optimizer. The number of groups is the same as before this change.

    Returns an `L` with one entry per group, each a list built with `params`.
    '''
    # gpt2 (small): 12-layer, 768-hidden, 12-heads, 117M parameters.
    body = m.transformer
    # NOTE(review): in Hugging Face GPT-2 the lm_head weight is normally tied to
    # wte, so the same tensor may show up in the first and the last group - confirm
    # this is intended.
    head_group = params(body.ln_f) + params(m.lm_head)
    body_groups = L(body.wte, body.wpe, *body.h).map(params)
    return L(*body_groups, head_group)

In [None]:
# Instantiate the pretrained 'gpt2' checkpoint with its language-model head and
# display its module tree; the printed architecture is what the GPT-2 split is read from.
model = AutoModelWithLMHead.from_pretrained('gpt2')
model

GPT2LMHeadModel(
  (transformer): GPT2Model(
    (wte): Embedding(50257, 768)
    (wpe): Embedding(1024, 768)
    (drop): Dropout(p=0.1, inplace=False)
    (h): ModuleList(
      (0): Block(
        (ln_1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
        (attn): Attention(
          (c_attn): Conv1D()
          (c_proj): Conv1D()
          (attn_dropout): Dropout(p=0.1, inplace=False)
          (resid_dropout): Dropout(p=0.1, inplace=False)
        )
        (ln_2): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
        (mlp): MLP(
          (c_fc): Conv1D()
          (c_proj): Conv1D()
          (dropout): Dropout(p=0.1, inplace=False)
        )
      )
      (1): Block(
        (ln_1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
        (attn): Attention(
          (c_attn): Conv1D()
          (c_proj): Conv1D()
          (attn_dropout): Dropout(p=0.1, inplace=False)
          (resid_dropout): Dropout(p=0.1, inplace=False)
        )
        (ln_2): Laye

## Others (Not Tested)

In [None]:
#export
def distilbert_SeqClassification_split(m:nn.Module): 
    '''Split a DistilBERT sequence-classification model into parameter groups.

    Groups, in order: the embeddings, one group per entry of
    `m.distilbert.transformer.layer`, the pre-classifier, and the classifier.
    Listed in this notebook as not tested.

    Returns an `L` with one entry per group, each produced by `params`.
    '''
    # 6 layers, 66M params
    body = m.distilbert
    blocks = [body.embeddings]
    blocks.extend(body.transformer.layer)
    blocks += [m.pre_classifier, m.classifier]
    return L(*blocks).map(params)
def albert_SeqClassification_split(m: nn.Module):
    '''Split an ALBERT sequence-classification model into parameter groups.

    Groups, in order: the embeddings (together with the encoder's
    `embedding_hidden_mapping_in` projection), one group per entry of
    `m.albert.encoder.albert_layer_groups`, the pooler, and the classifier.
    Listed in this notebook as not tested.

    `embedding_hidden_mapping_in` sits on the encoder but outside every layer
    group, so it is folded into the embeddings group; otherwise its parameters
    would belong to no group and never be handed to the optimizer. The number
    of groups is the same as before this change.

    Returns an `L` with one entry per group, each a list built with `params`.
    '''
    encoder = m.albert.encoder
    # Input-side group: embeddings plus the projection applied before the layer groups.
    input_group = params(m.albert.embeddings) + params(encoder.embedding_hidden_mapping_in)
    remaining = L(*encoder.albert_layer_groups, m.albert.pooler, m.classifier).map(params)
    return L(input_group, *remaining)

## Export -

In [None]:
#hide
# Run nbdev's notebook-to-script conversion; the cell output lists each
# notebook that was converted. Kept as the last cell of the notebook.
from nbdev.export import notebook2script
notebook2script()

Converted 00_general.ipynb.
Converted 01a_transforms.ipynb.
Converted 01b_model_splits.ipynb.
Converted 01c_callbacks.ipynb.
Converted 99a_example_roberta_classification.ipynb.
Converted 99b_example_gpt2_lm.ipynb.
Converted index.ipynb.
