In [2]:
%cd ~/protein-transfer

/home/francesca/protein-transfer


In [3]:
%load_ext blackcellmagic

In [13]:
"""
A script for receptive field calc, following
https://distill.pub/2019/computing-receptive-fields/
"""

from __future__ import annotations

from collections import defaultdict

import os
import pandas as pd

from sequence_models.pretrained import load_model_and_alphabet

import matplotlib.pyplot as plt

from scr.params.emb import CARP_INFO
from scr.utils import checkNgen_folder


class ReceptiveField:

    """
    Calculate receptive field
    """

    def __init__(self, encoder_name: str):

        self._encoder_name = encoder_name

        self._model, _ = load_model_and_alphabet(self._encoder_name)

        # total number of layers
        self._layer_numb = CARP_INFO[self._encoder_name][1]

        self._conv_stat_dict = defaultdict(dict)

        for layer_name, _ in self._model.model.embedder.state_dict().items():
            # take out bias and weight in the name
            conv_layer_name = layer_name.replace(".weight", "").replace(".bias", "")

            if "conv" in conv_layer_name:

                conv_layer = self._model.model.embedder

                for sub_obj in conv_layer_name.split("."):
                    conv_layer = getattr(conv_layer, sub_obj)

                kernel_size = getattr(conv_layer, "kernel_size")
                stride = getattr(conv_layer, "stride")
                dilation = getattr(conv_layer, "dilation")

                for conv_param in [kernel_size, stride, dilation]:
                    assert len(conv_param) == 1, f"{conv_param=} not 1D"

                common_dict = {
                    "kernel_size": kernel_size[0],
                    "stride": stride[0],
                    "dilation": dilation[0],
                    "adjusted_kernel_size": self._adjust_kl(
                        kl=kernel_size[0], dl=dilation[0]
                    ),
                }

                if "up_embedder" in conv_layer_name:
                    self._conv_stat_dict["up_embedder"][conv_layer_name] = common_dict
                elif "layers" in conv_layer_name:
                    sequence_numb = ""
                    if "sequence" in conv_layer_name:
                        sequence_numb = conv_layer_name.split(".")[2]

                    self._conv_stat_dict["layers"][conv_layer_name] = {
                        **common_dict,
                        "layer_numb": conv_layer_name.split(".")[1],
                        "sequence_numb": sequence_numb,
                    }

    def _adjust_kl(self, kl: int, dl: int) -> int:

        """
        adjust the kernel_size for layer l with dilation of layer l

        dl * (kl - 1) + 1

        Args:
        - kl: int, kernel size of layer l
        - dl: int, dilation factor of layer l
        """

        if kl > 0:
            return dl * (kl - 1) + 1
        else:
            return kl

    def _get_rl(self, rl_prev: int, sl: int, kl: int, dl: int) -> int:

        """
        Calculate rl following equation (1) in
        https://distill.pub/2019/computing-receptive-fields/

        Given rl_prev = sl * rl + kl - sl
        rl = (rl_prev + sl - kl)/sl

        Args:
        - rl_prev: int, receptive field of l-1
        - sl: int, stride of layer l
        - kl: int, kernel size of layer l
        """

        rl = (rl_prev + sl - self._adjust_kl(kl=kl, dl=dl)) / sl

        assert rl.is_integer(), f"rl = {rl} should be int"

        return int(rl)

    def _get_r0(self, L: int) -> int:

        """
        Calculate rl for the first layer following equation (2) in
        https://distill.pub/2019/computing-receptive-fields/

        Note that all CARP stride = 1
        Each layer, a NyteNetBlock, is composed of:
            - conv, MaskedConv1d, kernel_size = 5
            - sequence1, with a PositionFeedForward conv layer, kernel_size = 1
            - sequence2, with a PositionFeedForward conv layer, kernel_size = 1

        Thus, the product of all sl will be 1
        kl - 1 will all be 4

        Here, we chop off the arch and ask for r0

        Args:
        - L: int, layer number for the chop off arch
        """

        r0 = 1

        for l in range(L):
            r0 += (
                self.conv_layers_df.loc[f"layers.{str(l)}.conv", "adjusted_kernel_size"]
                * self.conv_unique_stride
            )

        return r0

    @property
    def conv_stat_dict(self) -> dict:
        """Get the dict with all layer kernel and stride info"""
        return self._conv_stat_dict

    @property
    def conv_layers_df(self) -> pd.DataFrame:
        """Get the dict with all layer kernel and stride in df"""
        return pd.DataFrame.from_dict(self.conv_stat_dict["layers"]).T

    @property
    def conv_unique_stride(self) -> int:

        """Get the unique stride in df, should be all 1s for CARP"""

        conv_unique_strides = self.conv_layers_df.stride.unique()

        if self._encoder_name in CARP_INFO.keys():
            assert len(conv_unique_strides) == 1, "CARP stride should only be 1"
            assert conv_unique_strides[0] == 1, "CARP stride should only be 1"

        return conv_unique_strides[0]

    @property
    def conv_unique_kernelsize(self) -> int:

        """Get the unique kernel in df, should be 1 and 5 for CARP"""

        conv_unique_kernelsize = self.conv_layers_df.kernel_size.unique()

        if self._encoder_name in CARP_INFO.keys():
            assert (
                len(conv_unique_kernelsize) == 2
            ), "CARP stride should only have 2 kernel sizes"

        return conv_unique_kernelsize

    @property
    def conv_unique_nonone_kernelsize(self) -> int:

        """Get the unique kernel in df that is not 1"""
        conv_unique_nonone_kernelsize = self.conv_unique_kernelsize[
            self.conv_unique_kernelsize != 1
        ]

        if self._encoder_name in CARP_INFO.keys():
            assert (
                len(conv_unique_nonone_kernelsize) == 1
            ), "CARP should only have one not 1 kernel sizes"

        return conv_unique_nonone_kernelsize[0]

    @property
    def rf_dict(self) -> dict:
        """
        A dict with rf for each ByteNetBlock layer

        Note that all CARP stride = 1
        Each layer, a NyteNetBlock, is composed of:
            - conv, MaskedConv1d, kernel_size = 5
            - sequence1, with a PositionFeedForward conv layer, kernel_size = 1
            - sequence2, with a PositionFeedForward conv layer, kernel_size = 1
            - the sequence1 and sequence2 does not impact the rf of each layer
        """

        return {layer: self._get_r0(L=layer) for layer in range(self._layer_numb + 1)}

    @property
    def rf_df(self) -> pd.DataFrame:
        """Convert layer rf dict to dataframe with column names"""

        df = pd.Series(self.rf_dict).to_frame(name="rf_size")
        df.index.name = "layers"

        return df.reset_index()


def run_carp_rf(rf_folder: str = "results/rf"):

    """Save carp rf calc output as csv and plot"""

    rf_df_folder = checkNgen_folder(os.path.join(rf_folder, "dfs"))
    rf_plot_folder = checkNgen_folder(os.path.join(rf_folder, "plots"))
    encoder_names = list(CARP_INFO.keys())

    # init fig
    fig, axs = plt.subplots(
        1,
        len(encoder_names),
        sharey=True,
        figsize=(10, 2.5),
        squeeze=False,  # not get rid off the extra dim if 1D
    )

    for i, carp in enumerate(encoder_names):

        print(f"Calculating and plotting rf for {carp}...")

        rf_df = ReceptiveField(carp).rf_df

        # save df
        rf_df.to_csv(os.path.join(rf_df_folder, carp + ".csv"), index=False)

        # plot individual
        plt.figure()
        plt.plot("layers", "rf_size", data=rf_df)
        plt.xlabel("layers")
        plt.ylabel("rf_size")
        plt.title(carp)

        plt.savefig(
            os.path.join(rf_plot_folder, carp + ".png"),
            bbox_inches="tight",
        )

        plt.close()

        # add to collage
        axs[0, i].plot("layers", "rf_size", data=rf_df)

    # add xlabels
    for ax in axs.flatten():
        ax.set_xlabel("layers", fontsize=12)
        ax.tick_params(axis="x", labelsize=12)

    # add column names
    for ax, col in zip(axs[0], encoder_names):
        ax.set_title(col, fontsize=12)

    axs[0, 0].set_ylabel("rf_size", fontsize=12)

    # add whole plot level title
    fig.suptitle(
        "carp receptive field size",
        y=0.925,
        fontsize=12,
        fontweight="bold",
    )
    fig.align_labels()
    fig.tight_layout()

    plt.savefig(
        os.path.join(rf_plot_folder, "carp_all" + ".png"),
        bbox_inches="tight",
    )

    plt.close()

In [56]:
run_carp_rf()

Calculating and plotting rf for carp_600k...
Calculating and plotting rf for carp_38M...
Calculating and plotting rf for carp_76M...
Calculating and plotting rf for carp_640M...


In [5]:
ReceptiveField("carp_600k").rf_df        

Unnamed: 0_level_0,rf_size
layers,Unnamed: 1_level_1
0,65
1,61
2,57
3,53
4,49
5,45
6,41
7,37
8,33
9,29


In [104]:
rf = ReceptiveField("carp_600k")

In [86]:
rf.conv_unique_stride, rf.conv_unique_nonone_kernelsize

(1, 5)

In [112]:
pd.Series(rf.rf_dict).to_frame(name="rf")

Unnamed: 0,rf
0,65
1,61
2,57
3,53
4,49
5,45
6,41
7,37
8,33
9,29


In [100]:
ReceptiveField("carp_38M").rf_dict

{0: 65,
 1: 61,
 2: 57,
 3: 53,
 4: 49,
 5: 45,
 6: 41,
 7: 37,
 8: 33,
 9: 29,
 10: 25,
 11: 21,
 12: 17,
 13: 13,
 14: 9,
 15: 5,
 16: 1}

In [101]:
ReceptiveField("carp_76M").rf_dict

{0: 129,
 1: 125,
 2: 121,
 3: 117,
 4: 113,
 5: 109,
 6: 105,
 7: 101,
 8: 97,
 9: 93,
 10: 89,
 11: 85,
 12: 81,
 13: 77,
 14: 73,
 15: 69,
 16: 65,
 17: 61,
 18: 57,
 19: 53,
 20: 49,
 21: 45,
 22: 41,
 23: 37,
 24: 33,
 25: 29,
 26: 25,
 27: 21,
 28: 17,
 29: 13,
 30: 9,
 31: 5,
 32: 1}

In [102]:
ReceptiveField("carp_640M").rf_dict

{0: 225,
 1: 221,
 2: 217,
 3: 213,
 4: 209,
 5: 205,
 6: 201,
 7: 197,
 8: 193,
 9: 189,
 10: 185,
 11: 181,
 12: 177,
 13: 173,
 14: 169,
 15: 165,
 16: 161,
 17: 157,
 18: 153,
 19: 149,
 20: 145,
 21: 141,
 22: 137,
 23: 133,
 24: 129,
 25: 125,
 26: 121,
 27: 117,
 28: 113,
 29: 109,
 30: 105,
 31: 101,
 32: 97,
 33: 93,
 34: 89,
 35: 85,
 36: 81,
 37: 77,
 38: 73,
 39: 69,
 40: 65,
 41: 61,
 42: 57,
 43: 53,
 44: 49,
 45: 45,
 46: 41,
 47: 37,
 48: 33,
 49: 29,
 50: 25,
 51: 21,
 52: 17,
 53: 13,
 54: 9,
 55: 5,
 56: 1}

In [50]:
import pandas as pd

In [51]:
pd.DataFrame.from_dict(rf.conv_stat["layers"])

Unnamed: 0,layers.0.conv,layers.0.sequence1.2.conv,layers.0.sequence2.2.conv,layers.1.conv,layers.1.sequence1.2.conv,layers.1.sequence2.2.conv,layers.2.conv,layers.2.sequence1.2.conv,layers.2.sequence2.2.conv,layers.3.conv,...,layers.12.sequence2.2.conv,layers.13.conv,layers.13.sequence1.2.conv,layers.13.sequence2.2.conv,layers.14.conv,layers.14.sequence1.2.conv,layers.14.sequence2.2.conv,layers.15.conv,layers.15.sequence1.2.conv,layers.15.sequence2.2.conv
kernel_size,5.0,1,1,5.0,1,1,5.0,1,1,5.0,...,1,5.0,1,1,5.0,1,1,5.0,1,1
stride,1.0,1,1,1.0,1,1,1.0,1,1,1.0,...,1,1.0,1,1,1.0,1,1,1.0,1,1
layer_numb,0.0,0,0,1.0,1,1,2.0,2,2,3.0,...,12,13.0,13,13,14.0,14,14,15.0,15,15
sequence_numb,,sequence1,sequence2,,sequence1,sequence2,,sequence1,sequence2,,...,sequence2,,sequence1,sequence2,,sequence1,sequence2,,sequence1,sequence2


In [56]:
pd.DataFrame.from_dict(rf.conv_stat["layers"]).T.stride.unique()

array([1], dtype=object)

In [64]:
pd.DataFrame.from_dict(rf.conv_stat["layers"]).T.kernel_size.unique() > 1

array([ True, False])

In [16]:
carp_600k_df = ReceptiveField("carp_600k").conv_layers_df
carp_600k_df

Unnamed: 0,kernel_size,stride,dilation,adjusted_kernel_size,layer_numb,sequence_numb
layers.0.conv,5,1,1,5,0,
layers.0.sequence1.2.conv,1,1,1,1,0,sequence1
layers.0.sequence2.2.conv,1,1,1,1,0,sequence2
layers.1.conv,5,1,2,9,1,
layers.1.sequence1.2.conv,1,1,1,1,1,sequence1
layers.1.sequence2.2.conv,1,1,1,1,1,sequence2
layers.2.conv,5,1,4,17,2,
layers.2.sequence1.2.conv,1,1,1,1,2,sequence1
layers.2.sequence2.2.conv,1,1,1,1,2,sequence2
layers.3.conv,5,1,8,33,3,


In [15]:
carp_600k_df.adjusted_kernel_size.unique()

array([5, 1, 9, 17, 33, 65, 129, 257, 513], dtype=object)

In [11]:
layer_numb = 1
carp_600k_df.loc[f"layers.{str(layer_numb)}.conv", "kernel_size"]

5

In [5]:
pd.DataFrame.from_dict(ReceptiveField("carp_600k").conv_stat_dict["layers"])

Unnamed: 0,layers.0.conv,layers.0.sequence1.2.conv,layers.0.sequence2.2.conv,layers.1.conv,layers.1.sequence1.2.conv,layers.1.sequence2.2.conv,layers.2.conv,layers.2.sequence1.2.conv,layers.2.sequence2.2.conv,layers.3.conv,...,layers.12.sequence2.2.conv,layers.13.conv,layers.13.sequence1.2.conv,layers.13.sequence2.2.conv,layers.14.conv,layers.14.sequence1.2.conv,layers.14.sequence2.2.conv,layers.15.conv,layers.15.sequence1.2.conv,layers.15.sequence2.2.conv
kernel_size,5.0,1,1,5.0,1,1,5.0,1,1,5.0,...,1,5.0,1,1,5.0,1,1,5.0,1,1
stride,1.0,1,1,1.0,1,1,1.0,1,1,1.0,...,1,1.0,1,1,1.0,1,1,1.0,1,1
dilation,1.0,1,1,2.0,1,1,4.0,1,1,8.0,...,1,32.0,1,1,64.0,1,1,128.0,1,1
layer_numb,0.0,0,0,1.0,1,1,2.0,2,2,3.0,...,12,13.0,13,13,14.0,14,14,15.0,15,15
sequence_numb,,sequence1,sequence2,,sequence1,sequence2,,sequence1,sequence2,,...,sequence2,,sequence1,sequence2,,sequence1,sequence2,,sequence1,sequence2


In [14]:
pd.DataFrame.from_dict(ReceptiveField("carp_38M").conv_stat_dict["layers"])

Unnamed: 0,layers.0.conv,layers.0.sequence1.2.conv,layers.0.sequence2.2.conv,layers.1.conv,layers.1.sequence1.2.conv,layers.1.sequence2.2.conv,layers.2.conv,layers.2.sequence1.2.conv,layers.2.sequence2.2.conv,layers.3.conv,...,layers.12.sequence2.2.conv,layers.13.conv,layers.13.sequence1.2.conv,layers.13.sequence2.2.conv,layers.14.conv,layers.14.sequence1.2.conv,layers.14.sequence2.2.conv,layers.15.conv,layers.15.sequence1.2.conv,layers.15.sequence2.2.conv
kernel_size,5.0,1,1,5.0,1,1,5.0,1,1,5.0,...,1,5.0,1,1,5.0,1,1,5.0,1,1
stride,1.0,1,1,1.0,1,1,1.0,1,1,1.0,...,1,1.0,1,1,1.0,1,1,1.0,1,1
dilation,1.0,1,1,2.0,1,1,4.0,1,1,8.0,...,1,32.0,1,1,64.0,1,1,128.0,1,1
layer_numb,0.0,0,0,1.0,1,1,2.0,2,2,3.0,...,12,13.0,13,13,14.0,14,14,15.0,15,15
sequence_numb,,sequence1,sequence2,,sequence1,sequence2,,sequence1,sequence2,,...,sequence2,,sequence1,sequence2,,sequence1,sequence2,,sequence1,sequence2


In [58]:
pd.DataFrame.from_dict(ReceptiveField("carp_38M").conv_stat["layers"]).T.stride.unique()

array([1], dtype=object)

In [59]:
pd.DataFrame.from_dict(ReceptiveField("carp_38M").conv_stat["layers"]).T.kernel_size.unique()

array([5, 1], dtype=object)

In [7]:
pd.DataFrame.from_dict(ReceptiveField("carp_76M").conv_stat_dict["layers"])

Unnamed: 0,layers.0.conv,layers.0.sequence1.2.conv,layers.0.sequence2.2.conv,layers.1.conv,layers.1.sequence1.2.conv,layers.1.sequence2.2.conv,layers.2.conv,layers.2.sequence1.2.conv,layers.2.sequence2.2.conv,layers.3.conv,...,layers.28.sequence2.2.conv,layers.29.conv,layers.29.sequence1.2.conv,layers.29.sequence2.2.conv,layers.30.conv,layers.30.sequence1.2.conv,layers.30.sequence2.2.conv,layers.31.conv,layers.31.sequence1.2.conv,layers.31.sequence2.2.conv
kernel_size,5.0,1,1,5.0,1,1,5.0,1,1,5.0,...,1,5.0,1,1,5.0,1,1,5.0,1,1
stride,1.0,1,1,1.0,1,1,1.0,1,1,1.0,...,1,1.0,1,1,1.0,1,1,1.0,1,1
dilation,1.0,1,1,2.0,1,1,4.0,1,1,8.0,...,1,32.0,1,1,64.0,1,1,128.0,1,1
layer_numb,0.0,0,0,1.0,1,1,2.0,2,2,3.0,...,28,29.0,29,29,30.0,30,30,31.0,31,31
sequence_numb,,sequence1,sequence2,,sequence1,sequence2,,sequence1,sequence2,,...,sequence2,,sequence1,sequence2,,sequence1,sequence2,,sequence1,sequence2


In [60]:
pd.DataFrame.from_dict(ReceptiveField("carp_76M").conv_stat["layers"]).T.stride.unique()

array([1], dtype=object)

In [61]:
pd.DataFrame.from_dict(ReceptiveField("carp_76M").conv_stat["layers"]).T.kernel_size.unique()

array([5, 1], dtype=object)

In [8]:
pd.DataFrame.from_dict(ReceptiveField("carp_640M").conv_stat_dict["layers"])

Unnamed: 0,layers.0.conv,layers.0.sequence1.2.conv,layers.0.sequence2.2.conv,layers.1.conv,layers.1.sequence1.2.conv,layers.1.sequence2.2.conv,layers.2.conv,layers.2.sequence1.2.conv,layers.2.sequence2.2.conv,layers.3.conv,...,layers.52.sequence2.2.conv,layers.53.conv,layers.53.sequence1.2.conv,layers.53.sequence2.2.conv,layers.54.conv,layers.54.sequence1.2.conv,layers.54.sequence2.2.conv,layers.55.conv,layers.55.sequence1.2.conv,layers.55.sequence2.2.conv
kernel_size,5.0,1,1,5.0,1,1,5.0,1,1,5.0,...,1,5.0,1,1,5.0,1,1,5.0,1,1
stride,1.0,1,1,1.0,1,1,1.0,1,1,1.0,...,1,1.0,1,1,1.0,1,1,1.0,1,1
dilation,1.0,1,1,2.0,1,1,4.0,1,1,8.0,...,1,32.0,1,1,64.0,1,1,128.0,1,1
layer_numb,0.0,0,0,1.0,1,1,2.0,2,2,3.0,...,52,53.0,53,53,54.0,54,54,55.0,55,55
sequence_numb,,sequence1,sequence2,,sequence1,sequence2,,sequence1,sequence2,,...,sequence2,,sequence1,sequence2,,sequence1,sequence2,,sequence1,sequence2


In [62]:
pd.DataFrame.from_dict(ReceptiveField("carp_640M").conv_stat["layers"]).T.stride.unique()

array([1], dtype=object)

In [63]:
pd.DataFrame.from_dict(ReceptiveField("carp_640M").conv_stat["layers"]).T.kernel_size.unique()

array([5, 1], dtype=object)

In [6]:
from sequence_models.pretrained import load_model_and_alphabet

model, collater = load_model_and_alphabet('carp_600k')

In [7]:
model.model.embedder.embedder

Embedding(30, 8, padding_idx=28)

In [8]:
model.model.embedder.up_embedder

PositionFeedForward(
  (conv): Conv1d(8, 128, kernel_size=(1,), stride=(1,))
)

In [9]:
getattr(getattr(model.model.embedder, "up_embedder"), "conv")

Conv1d(8, 128, kernel_size=(1,), stride=(1,))

In [10]:
model.model.embedder.up_embedder.conv

Conv1d(8, 128, kernel_size=(1,), stride=(1,))

In [11]:
model.model.embedder.up_embedder.conv.dilation

(1,)

In [42]:
model.model.embedder.up_embedder.conv.kernel_size, model.model.embedder.up_embedder.conv.stride

((1,), (1,))

In [47]:
model.model.embedder.up_embedder.conv.kernel_size[0]

1

In [9]:
model

CARP(
  (model): ByteNetLM(
    (embedder): ByteNet(
      (embedder): Embedding(30, 8, padding_idx=28)
      (up_embedder): PositionFeedForward(
        (conv): Conv1d(8, 128, kernel_size=(1,), stride=(1,))
      )
      (layers): ModuleList(
        (0): ByteNetBlock(
          (conv): MaskedConv1d(64, 64, kernel_size=(5,), stride=(1,), padding=(2,))
          (sequence1): Sequential(
            (0): LayerNorm((128,), eps=1e-05, elementwise_affine=True)
            (1): GELU(approximate=none)
            (2): PositionFeedForward(
              (conv): Conv1d(128, 64, kernel_size=(1,), stride=(1,))
            )
            (3): LayerNorm((64,), eps=1e-05, elementwise_affine=True)
            (4): GELU(approximate=none)
          )
          (sequence2): Sequential(
            (0): LayerNorm((64,), eps=1e-05, elementwise_affine=True)
            (1): GELU(approximate=none)
            (2): PositionFeedForward(
              (conv): Conv1d(64, 128, kernel_size=(1,), stride=(1,))
    

In [7]:
model.model.embedder.up_embedder

PositionFeedForward(
  (conv): Conv1d(8, 128, kernel_size=(1,), stride=(1,))
)

In [8]:
model.model.embedder.layers[0]

ByteNetBlock(
  (conv): MaskedConv1d(64, 64, kernel_size=(5,), stride=(1,), padding=(2,))
  (sequence1): Sequential(
    (0): LayerNorm((128,), eps=1e-05, elementwise_affine=True)
    (1): GELU(approximate=none)
    (2): PositionFeedForward(
      (conv): Conv1d(128, 64, kernel_size=(1,), stride=(1,))
    )
    (3): LayerNorm((64,), eps=1e-05, elementwise_affine=True)
    (4): GELU(approximate=none)
  )
  (sequence2): Sequential(
    (0): LayerNorm((64,), eps=1e-05, elementwise_affine=True)
    (1): GELU(approximate=none)
    (2): PositionFeedForward(
      (conv): Conv1d(64, 128, kernel_size=(1,), stride=(1,))
    )
  )
)

In [5]:
# from scr.model.receptive_field import receptive_field, receptive_field_for_unit

In [8]:
seqs = [['TIM'], ['FRANCESCA']]
x = collater(seqs)[0]  # (n, max_len)
rep = model(x)  # (n, max_len, d_model)

In [37]:
x

tensor([[16,  7, 10, 27, 27, 27, 27, 27, 27],
        [ 4, 14,  0, 11,  1,  3, 15,  1,  0]])

In [15]:
x.shape, rep["representations"][16].shape

(torch.Size([2, 9]), torch.Size([2, 9, 128]))

In [33]:
import torch
import torch.nn as nn

In [38]:
embedder = nn.Embedding(num_embeddings=30, embedding_dim=8, padding_idx=28)

In [39]:
post_embed = embedder(x)
post_embed.shape

torch.Size([2, 9, 8])

In [40]:
up_embedder = nn.Conv1d(
        in_channels=8,
        out_channels=128,
        kernel_size=(1,),
        stride=(1,),
        padding=0,
        dilation=1,
        groups=1,
        bias=True,
        padding_mode="zeros",
        device=None,
        dtype=None,
    )

In [41]:
post_upemb = up_embedder(post_embed.transpose(1, 2)).transpose(1, 2)
post_upemb.shape

torch.Size([2, 9, 128])

In [43]:
model.model.embedder.layers[0]

ByteNetBlock(
  (conv): MaskedConv1d(64, 64, kernel_size=(5,), stride=(1,), padding=(2,))
  (sequence1): Sequential(
    (0): LayerNorm((128,), eps=1e-05, elementwise_affine=True)
    (1): GELU(approximate=none)
    (2): PositionFeedForward(
      (conv): Conv1d(128, 64, kernel_size=(1,), stride=(1,))
    )
    (3): LayerNorm((64,), eps=1e-05, elementwise_affine=True)
    (4): GELU(approximate=none)
  )
  (sequence2): Sequential(
    (0): LayerNorm((64,), eps=1e-05, elementwise_affine=True)
    (1): GELU(approximate=none)
    (2): PositionFeedForward(
      (conv): Conv1d(64, 128, kernel_size=(1,), stride=(1,))
    )
  )
)

In [18]:

receptive_field_dict = receptive_field(model.model.embedder.layers[0].to("cuda"), input_size=(1, 8, 1280))
# receptive_field_for_unit(receptive_field_dict, "1", (1,1))

m_key: 1, p_key: 0
receptive_field: OrderedDict([('0', OrderedDict([('j', 1.0), ('r', 1.0), ('start', 0.5), ('conv_stage', True), ('output_shape', [-1, 1, 8, 1280])])), ('1', OrderedDict())])
receptive_field[p_key]: OrderedDict([('j', 1.0), ('r', 1.0), ('start', 0.5), ('conv_stage', True), ('output_shape', [-1, 1, 8, 1280])])
m_key: 2, p_key: 1
receptive_field: OrderedDict([('0', OrderedDict([('j', 1.0), ('r', 1.0), ('start', 0.5), ('conv_stage', True), ('output_shape', [-1, 1, 8, 1280])])), ('1', OrderedDict([('input_shape', [-1, 1, 8, 1280]), ('output_shape', [-1, 1, 8, 1280])])), ('2', OrderedDict())])
receptive_field[p_key]: OrderedDict([('input_shape', [-1, 1, 8, 1280]), ('output_shape', [-1, 1, 8, 1280])])


KeyError: 'j'