# Try `torch.compile` on A100 machine

Use a machine with a GPU that `torch.compile` supports and is recommended for (here, an NVIDIA A100).

In [5]:
# Show the GPU model, driver version and CUDA version visible to this kernel
!nvidia-smi

Thu Sep  7 22:15:26 2023       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 520.61.05    Driver Version: 520.61.05    CUDA Version: 11.8     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|   0  NVIDIA A100-PCI...  Off  | 00000000:61:00.0 Off |                    0 |
| N/A   27C    P0    32W / 250W |      0MiB / 40960MiB |      0%      Default |
|                               |                      |             Disabled |
+-------------------------------+----------------------+----------------------+
                                                                               
+-----------------------------------------------------------------------------+
| Proces

## Get some data to pass through model

In [6]:
from pathlib import Path
from lhotse import CutSet

from lhotse.dataset import DynamicBucketingSampler
from lhotse.dataset.collation import collate_custom_field
from torch.utils.data import Dataset, DataLoader

# Directory holding the mini-LibriSpeech train-clean-5 shards
train_data_dir = Path('data/mini-librispeech/train-clean-5')

# Glob pattern matching the shard files of each field
shar_patterns = {
    'cuts': "cuts.*.jsonl.gz",
    'fbank': "fbank.*.tar",
    'ptlabel': "ptlabel.*.tar",
}

# Sort every field's shard list so the paths come in a deterministic order
train_cuts = CutSet.from_shar(
    fields={
        field: sorted(train_data_dir.glob(pattern))
        for field, pattern in shar_patterns.items()
    }
)

# Dynamically sample items that altogether add up to 60 seconds
train_sampler = DynamicBucketingSampler(train_cuts, max_duration=60)

class HuBERTPretrainingDataset(Dataset):
    """Dataset whose "index" is a whole ``CutSet``, so one lookup yields one padded batch."""

    def __getitem__(self, cuts: CutSet) -> dict:
        """Collate the ``fbank`` and ``ptlabel`` fields of ``cuts`` into padded tensors.

        Args:
            cuts: The cuts that make up one mini-batch.

        Returns:
            A dict with the keys ``feats_padded`` and ``feat_lens`` (features
            padded with 0, plus their lengths) and ``ptlabels_padded`` and
            ``ptlabels_lengths`` (labels padded with -100, plus their lengths).
        """
        ordered_cuts = cuts.sort_by_duration()

        # Features are padded with zeros
        fbank_batch = collate_custom_field(ordered_cuts, 'fbank', pad_value=0)
        # 0 is a valid label, so the labels are padded with a negative integer instead
        label_batch = collate_custom_field(ordered_cuts, 'ptlabel', pad_value=-100)

        feats_padded, feat_lens = fbank_batch
        ptlabels_padded, ptlabels_lengths = label_batch

        return {
            "feats_padded": feats_padded,
            "feat_lens": feat_lens,
            "ptlabels_padded": ptlabels_padded,
            "ptlabels_lengths": ptlabels_lengths,
        }

# batch_size=None switches off the DataLoader's own batching, so whatever the
# sampler yields is handed to the dataset's __getitem__ unchanged
train_dataset = HuBERTPretrainingDataset()
train_loader = DataLoader(
    dataset=train_dataset,
    sampler=train_sampler,
    batch_size=None,
    num_workers=1,
)

# Pull a single batch to pass through the model in the cells below
batch = next(iter(train_loader))

batch

{'feats_padded': tensor([[[ 8.0312,  6.3081,  5.6085,  ...,  9.9287,  9.7825,  9.2193],
          [ 6.0623,  5.5936,  6.3717,  ...,  8.7658,  9.4892,  9.5833],
          [ 6.5774,  5.4523,  7.3200,  ...,  9.1725,  9.9851,  9.9644],
          ...,
          [ 6.8111,  6.4719,  7.3717,  ..., 11.5423, 11.6784, 11.2297],
          [ 6.3828,  7.6254,  8.7734,  ..., 15.8933, 16.1830, 16.0805],
          [ 7.8642,  7.3241,  8.9053,  ..., 11.0786, 10.8585, 10.8489]],
 
         [[ 8.3438,  7.9204,  7.6191,  ...,  9.3661,  9.0802,  8.9283],
          [ 8.0286,  6.6956,  8.4073,  ...,  9.4909,  9.6967,  8.9417],
          [ 8.9658,  9.1295,  8.5936,  ...,  9.2658,  9.3963,  9.5405],
          ...,
          [ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
          [ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
          [ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000]],
 
         [[ 8.1562,  6.2256,  5.9716,  ...,  9.9072,  9.6092,  9.2015],
          [ 

## Pre-process data for forward pass

In [31]:
from components.verbatim_torchaudio import _get_padding_mask, _compute_mask_indices

padded_feats = batch['feats_padded']
padding_mask = _get_padding_mask(padded_feats, batch['feat_lens'])

B, T, C = padded_feats.shape

# Use HuBERT defaults
masking_options = dict(
    mask_prob=0.8,
    mask_length=10,
    mask_type='static',
    mask_other=0.0,
    min_masks=2,
    no_overlap=False,
    min_space=1,
)
masks_for_modeling = _compute_mask_indices((B, T), padding_mask, **masking_options)

# Zero-out the randomly selected frames; `padded_feats` is the very tensor
# stored in `batch`, so this modifies `batch['feats_padded']` in place
padded_feats[masks_for_modeling] = 0

## Construct and compile model

In [29]:
import torch

from components.model import get_torchaudio_hubert_pretrain_base_encoder

model = get_torchaudio_hubert_pretrain_base_encoder()

# Wrap the whole encoder with torch.compile and move it to the GPU in one step
opt_model = torch.compile(model).to('cuda')

opt_model

OptimizedModule(
  (_orig_mod): Encoder(
    (feature_projection): FeatureProjection(
      (layer_norm): LayerNorm((80,), eps=1e-05, elementwise_affine=True)
      (projection): Linear(in_features=80, out_features=768, bias=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (transformer): Transformer(
      (pos_conv_embed): ConvolutionalPositionalEmbedding(
        (conv): Conv1d(768, 768, kernel_size=(128,), stride=(1,), padding=(64,), groups=16)
      )
      (layer_norm): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
      (layers): ModuleList(
        (0-11): 12 x EncoderLayer(
          (attention): SelfAttention(
            (k_proj): Linear(in_features=768, out_features=768, bias=True)
            (v_proj): Linear(in_features=768, out_features=768, bias=True)
            (q_proj): Linear(in_features=768, out_features=768, bias=True)
            (out_proj): Linear(in_features=768, out_features=768, bias=True)
  

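The `torch.compile()` call went through! (It only wraps the model, though — the actual compilation happens on the first forward pass.) Will it work right out of the box?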

In [30]:
# Move the padded features and their lengths to the GPU first
feats_on_gpu = batch['feats_padded'].to('cuda')
lens_on_gpu = batch['feat_lens'].to('cuda')

# First call of the compiled encoder
transformer_outputs = opt_model(feats_on_gpu, lens_on_gpu)

[2023-09-07 22:29:11,055] torch._inductor.graph: [ERROR] Error from lowering
Traceback (most recent call last):
  File "/opt/conda/lib/python3.10/site-packages/torch/_inductor/graph.py", line 333, in call_function
    out = lowerings[target](*args, **kwargs)
  File "/opt/conda/lib/python3.10/site-packages/torch/_inductor/lowering.py", line 225, in wrapped
    out = decomp_fn(*args, **kwargs)
  File "/opt/conda/lib/python3.10/site-packages/torch/_inductor/lowering.py", line 2020, in index_put
    return index_put_(clone(x), indices, values, accumulate)
  File "/opt/conda/lib/python3.10/site-packages/torch/_inductor/lowering.py", line 225, in wrapped
    out = decomp_fn(*args, **kwargs)
  File "/opt/conda/lib/python3.10/site-packages/torch/_inductor/lowering.py", line 2044, in index_put_
    return index_put_as_masked_fill(self, indices, values, accumulate)
  File "/opt/conda/lib/python3.10/site-packages/torch/_inductor/lowering.py", line 2028, in index_put_as_masked_fill
    return muta

BackendCompilerFailed: debug_wrapper raised LoweringException: AssertionError: 
  target: aten.index_put.default
  args[0]: TensorBox(StorageBox(
    Pointwise(
      'cuda',
      torch.float32,
      tmp0 = load(seed_cuda_0, 0)
      tmp1 = index_expr(i2 + 768 * i1 + 580608 * i0, torch.int32)
      tmp2 = rand(tmp0, tmp1, torch.float32)
      tmp3 = constant(0.1, torch.float32)
      tmp4 = tmp2 > tmp3
      tmp5 = to_dtype(tmp4, torch.float32)
      tmp6 = load(buf6, i2 + 768 * i1 + 580608 * i0)
      tmp7 = tmp5 * tmp6
      tmp8 = constant(1.1111111111111112, torch.float32)
      tmp9 = tmp7 * tmp8
      return tmp9
      ,
      ranges=[3, 756, 768],
      origins={convert_element_type, mul_2, primals_3, mul_3, primals_4, view, gt, philox_rand_like, permute, view_1, philox_seed_like, addmm}
    )
  ))
  args[1]: [TensorBox(StorageBox(
    ComputedBuffer(name='buf7', layout=FixedLayout('cuda', torch.bool, size=[3, 756], stride=[756, 1]), data=Pointwise(
      'cuda',
      torch.bool,
      tmp0 = index_expr(i1, dtype=torch.int64)
      tmp1 = load(primals_6, i0)
      tmp2 = tmp0 >= tmp1
      return tmp2
      ,
      ranges=[3, 756],
      origins={ge}
    ))
  ))]
  args[2]: TensorBox(StorageBox(
    Pointwise(
      'cpu',
      torch.float32,
      tmp0 = constant(0.0, torch.float32)
      return tmp0
      ,
      ranges=[],
      origins={lift_fresh_copy, _tensor_constant0}
    )
  ))

While executing %index_put : [#users=1] = call_function[target=torch.ops.aten.index_put.default](args = (%mul_3, [%ge], %lift_fresh_copy), kwargs = {})
Original traceback:
  File "/home/scriptable_hubert_encoder/components/verbatim_torchaudio.py", line 254, in _preprocess
    x[mask] = 0.0
 |   File "/home/scriptable_hubert_encoder/components/verbatim_torchaudio.py", line 265, in forward
    x, mask = self._preprocess(features, lengths)


Set torch._dynamo.config.verbose=True for more information


You can suppress this exception and fall back to eager by setting:
    torch._dynamo.config.suppress_errors = True


Alright — `torch.compile()` itself didn't complain, but it was too good to be true for the model to work right out of the gate. The error seems to be related to `self._preprocess(features, lengths)`:

```
Original traceback:
  File "/home/scriptable_hubert_encoder/components/verbatim_torchaudio.py", line 254, in _preprocess
    x[mask] = 0.0
 |   File "/home/scriptable_hubert_encoder/components/verbatim_torchaudio.py", line 265, in forward
    x, mask = self._preprocess(features, lengths)


Set torch._dynamo.config.verbose=True for more information
```

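It seems the compiler does not like the padding-mask handling inside the `_preprocess` function — per the traceback, the statement that fails to lower is the in-place masked assignment `x[mask] = 0.0` (`aten.index_put`):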

```
def _preprocess(
    self,
    features: Tensor,
    lengths: Optional[Tensor] = None,
) -> Tuple[Tensor, Optional[Tensor]]:
    x = self.feature_projection(features)

    mask: Optional[Tensor] = None
    if lengths is not None:
        batch_size, max_len, _ = x.shape
        # create mask for padded elements and zero-out them
        mask = torch.arange(max_len, device=lengths.device).expand(batch_size, max_len) >= lengths[:, None]
        x[mask] = 0.0
        # extend the mask to attention shape and set weight
        mask = -10000.0 * mask[:, None, None, :].to(dtype=features.dtype)
        mask = mask.expand(batch_size, 1, max_len, max_len)
    return x, mask
```

## Calculate post-projection attention mask outside of compiled model

In [34]:
# Compile only the feature-projection sub-module, then move it to the GPU
opt_feat_proj = torch.compile(model.feature_projection)
opt_feat_proj = opt_feat_proj.to('cuda')

opt_feat_proj

OptimizedModule(
  (_orig_mod): FeatureProjection(
    (layer_norm): LayerNorm((80,), eps=1e-05, elementwise_affine=True)
    (projection): Linear(in_features=80, out_features=768, bias=True)
    (dropout): Dropout(p=0.1, inplace=False)
  )
)

In [35]:
# Project the (masked) input features with the compiled feature projection
feats_on_gpu = batch['feats_padded'].to('cuda')
x = opt_feat_proj(feats_on_gpu)

In [40]:
# Re-create the padding / attention mask logic of `_preprocess` outside the compiled module
lengths = batch['feat_lens']

batch_size, max_len, _ = x.shape
# create mask for padded elements and zero-out them
mask = torch.arange(max_len, device=lengths.device).expand(batch_size, max_len) >= lengths[:, None]
# NOTE: this in-place write on the compiled module's output is the statement
# that raises the RuntimeError shown in this cell's output
x[mask] = 0.0
# extend the mask to attention shape and set weight
# (`features` only exists as a parameter of `_preprocess`; in this notebook the
# input features live in `batch`, so take the dtype from there)
mask = -10000.0 * mask[:, None, None, :].to(dtype=batch["feats_padded"].dtype)
mask = mask.expand(batch_size, 1, max_len, max_len)

RuntimeError: Output 0 of CompiledFunctionBackward is a view and is being modified inplace. This view was created inside a custom Function (or because an input was returned as-is) and the autograd logic to handle view+inplace would override the custom backward associated with the custom Function, leading to incorrect gradients. This behavior is forbidden. You can fix this by cloning the output of the custom Function.

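Hm, why does `x[mask] = 0.0` need to be set? Isn't the point of the `mask`/`attention_mask` to prevent the transformer from paying attention to these positions? (One possible reason: the transformer's `pos_conv_embed` is a convolution with kernel size 128 that mixes neighbouring frames, and the attention mask presumably does not apply to it, so non-zero values at padded positions could leak into real frames.)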

Anyway, comment this line out and continue with a non-zeroed `x` to see whether the rest of the forward pass works with `torch.compile`.

In [54]:
lengths = batch['feat_lens']
batch_size, max_len, _ = x.shape

# True wherever a frame index is at or beyond that item's length, i.e. at padded positions
frame_positions = torch.arange(max_len, device=lengths.device)
mask = frame_positions[None, :] >= lengths[:, None]
# The in-place `x[mask] = 0.0` is deliberately left out here, so `x` keeps
# whatever values it has at the padded positions

# Bring the mask to attention shape: -10000 at padded positions, (negative) zero elsewhere
mask = mask[:, None, None, :].to(dtype=batch["feats_padded"].dtype)
mask = (-10000.0 * mask).expand(batch_size, 1, max_len, max_len)

mask

tensor([[[[    -0.,     -0.,     -0.,  ...,     -0.,     -0.,     -0.],
          [    -0.,     -0.,     -0.,  ...,     -0.,     -0.,     -0.],
          [    -0.,     -0.,     -0.,  ...,     -0.,     -0.,     -0.],
          ...,
          [    -0.,     -0.,     -0.,  ...,     -0.,     -0.,     -0.],
          [    -0.,     -0.,     -0.,  ...,     -0.,     -0.,     -0.],
          [    -0.,     -0.,     -0.,  ...,     -0.,     -0.,     -0.]]],


        [[[    -0.,     -0.,     -0.,  ..., -10000., -10000., -10000.],
          [    -0.,     -0.,     -0.,  ..., -10000., -10000., -10000.],
          [    -0.,     -0.,     -0.,  ..., -10000., -10000., -10000.],
          ...,
          [    -0.,     -0.,     -0.,  ..., -10000., -10000., -10000.],
          [    -0.,     -0.,     -0.,  ..., -10000., -10000., -10000.],
          [    -0.,     -0.,     -0.,  ..., -10000., -10000., -10000.]]],


        [[[    -0.,     -0.,     -0.,  ..., -10000., -10000., -10000.],
          [    -0.,     -0

## Pass (projected) features and attention mask to transformer

In [48]:
# Compile only the transformer sub-module, then move it to the GPU
opt_transformer = torch.compile(model.transformer)
opt_transformer = opt_transformer.to('cuda')

opt_transformer

OptimizedModule(
  (_orig_mod): Transformer(
    (pos_conv_embed): ConvolutionalPositionalEmbedding(
      (conv): Conv1d(768, 768, kernel_size=(128,), stride=(1,), padding=(64,), groups=16)
    )
    (layer_norm): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
    (dropout): Dropout(p=0.1, inplace=False)
    (layers): ModuleList(
      (0-11): 12 x EncoderLayer(
        (attention): SelfAttention(
          (k_proj): Linear(in_features=768, out_features=768, bias=True)
          (v_proj): Linear(in_features=768, out_features=768, bias=True)
          (q_proj): Linear(in_features=768, out_features=768, bias=True)
          (out_proj): Linear(in_features=768, out_features=768, bias=True)
        )
        (dropout): Dropout(p=0.1, inplace=False)
        (layer_norm): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
        (feed_forward): FeedForward(
          (intermediate_dense): Linear(in_features=768, out_features=3072, bias=True)
          (intermediate_dropout): Dropo

In [55]:
# The mask was built on the CPU, so move it to the GPU before the call
attention_mask = mask.to('cuda')
y = opt_transformer(x, attention_mask=attention_mask)

In [56]:
# Display the output of the compiled transformer
y

tensor([[[-7.5501e-01,  9.0716e-01,  4.5373e-01,  ..., -1.1112e-01,
          -3.5967e-01,  2.7984e-01],
         [-1.6776e+00, -8.4338e-01,  5.5379e-01,  ...,  1.3777e+00,
          -6.9165e-01, -4.6057e-01],
         [-1.5607e-01, -3.0129e-01,  4.0043e-01,  ...,  1.5154e+00,
           6.7790e-02, -3.2070e-01],
         ...,
         [ 2.8277e-01, -3.5982e-01,  2.8334e-01,  ...,  3.9271e-01,
          -3.2856e-01, -1.6475e+00],
         [-2.1042e-01, -7.6817e-01,  8.8457e-01,  ..., -2.5041e-01,
           3.4465e-01, -1.1508e+00],
         [-1.2911e-01, -4.2044e-01,  8.6431e-01,  ...,  1.9125e+00,
           7.7329e-02, -9.6082e-01]],

        [[-6.7865e-01, -1.0606e+00,  2.3368e-01,  ...,  9.5067e-01,
           4.8681e-01, -1.8798e-01],
         [ 6.8016e-01, -1.0107e+00, -8.2733e-01,  ...,  9.6873e-02,
           5.3239e-01, -2.2120e+00],
         [-7.2662e-01, -8.6442e-02,  9.7579e-01,  ...,  7.6153e-01,
          -2.8874e-01,  8.8322e-02],
         ...,
         [ 3.2440e-01, -5