In [1]:
import torch
import torch.nn as nn
from bolero.tl.generic.module_lora import LoRAConv, convert_to_lora_model
from enformer_pytorch.modeling_enformer import from_pretrained
from torchinfo import summary

In [2]:
def summary_on_cuda(model, input_size=None, row_settings=None, depth=4):
    """Run ``summary`` on ``model`` while it is temporarily placed on the GPU.

    The model is moved to ``"cuda"`` before the summary and moved back to
    ``"cpu"`` afterwards, after which the CUDA cache is emptied. The move back
    happens even when ``summary`` raises, so a failed summary does not leave
    the model on the device.

    Parameters
    ----------
    model : torch.nn.Module
        Module to summarise. It is moved between devices in place.
    input_size : optional
        Forwarded to ``summary`` as its second positional argument.
    row_settings : iterable of str, optional
        Forwarded to ``summary`` as ``row_settings``. ``None`` (the default)
        means ``["var_names"]``.
    depth : int, default 4
        Forwarded to ``summary`` as ``depth``.

    Returns
    -------
    Whatever ``summary`` returns.
    """
    # Resolve the default here instead of using a mutable list as the default
    # argument value.
    if row_settings is None:
        row_settings = ["var_names"]

    model.to("cuda")
    try:
        # Honour the caller's row_settings (it used to be silently ignored).
        return summary(model, input_size, row_settings=row_settings, depth=depth)
    finally:
        # Always release the GPU, even if the forward pass inside summary fails.
        model.to("cpu")
        torch.cuda.empty_cache()


def print_lora_conv(model):
    """Print a LoRA conv wrapper's configuration and parameter inventory.

    ``model`` must expose the wrapped convolution as ``model.conv``. The
    function prints that convolution's ``in_channels``, ``out_channels``,
    ``groups`` and ``kernel_size``, then the wrapper's type, then one line per
    named parameter (name, shape, ``requires_grad``), and finally the total
    number of elements across the parameters that require gradients.
    """
    wrapped_conv = model.conv
    for attr_name in ("in_channels", "out_channels", "groups", "kernel_size"):
        print(attr_name, getattr(wrapped_conv, attr_name))
    print()
    print("LoRA Module", type(model))

    trainable_total = 0
    for param_name, tensor in model.named_parameters():
        is_trainable = tensor.requires_grad
        print(param_name, tensor.shape, is_trainable)
        # Frozen parameters are listed but do not add to the total.
        trainable_total += tensor.numel() if is_trainable else 0
    print("LoRA trainable parameters: ", trainable_total)


def print_lora_linear(model):
    """Print a LoRA linear module's configuration and parameter inventory.

    Prints ``model.in_features`` and ``model.out_features``, then the module's
    type, then one line per named parameter (name, shape, ``requires_grad``),
    and finally the total number of elements across the parameters that
    require gradients.
    """
    for attr_name in ("in_features", "out_features"):
        print(attr_name, getattr(model, attr_name))
    print()
    print("LoRA Module", type(model))

    trainable_total = 0
    for param_name, tensor in model.named_parameters():
        print(param_name, tensor.shape, tensor.requires_grad)
        if not tensor.requires_grad:
            # Frozen parameters are listed but not counted.
            continue
        trainable_total += tensor.numel()
    print("LoRA trainable parameters: ", trainable_total)

## Load Pretrained Enformer

In [3]:
# Load the pretrained Enformer through enformer_pytorch's `from_pretrained`.
enformer = from_pretrained("EleutherAI/enformer-official-rough")

# Turn off gradients on every parameter of the base model; the saved summaries
# of `enformer` in this notebook accordingly report 0 trainable parameters.
for param in enformer.parameters():
    param.requires_grad = False

In [4]:
# LoRA rank, passed to `convert_to_lora_model` as `rank=` in the conversion cell.
lora_rank = 4

In [5]:
# Convert to a LoRA-enabled network.
# `convert_conv` / `convert_linear` request adapters for both layer kinds; the
# saved summaries of `lora_enformer` show LoRAConv and LoRALinear wrappers.
# The result is bound to a new name (`inplace=False`); the saved summaries of
# `enformer` still list plain Conv1d / Linear layers after this call.
lora_enformer = convert_to_lora_model(
    enformer,
    convert_conv=True,
    convert_linear=True,
    rank=lora_rank,
    alpha=1.0,
    inplace=False,
)

In [6]:
# Total number of elements across all parameters of the converted model that
# still require gradients.
count = 0
for param in lora_enformer.parameters():
    if param.requires_grad:
        count += param.numel()
print("Total trainable parameters:", count)

Total trainable parameters: 1848256


## Enformer stem

- Input:
  - shape: `(bs, channel, seq_len) or (16, 4, 196_608)`
  - Seq length: 196_608 = 3 * 2** (9 + 1 + 6)
- Output:
  - shape: `(bs, C/2, seq_len/2) or (16, 768, 98_304)`

In [7]:
# Baseline: summary of the original stem for an input of shape
# (batch=16, channels=4, seq_len=196_608).
print("\nEnformer model")
summary_on_cuda(
    enformer.stem, input_size=(16, 4, 196_608), row_settings=["var_names"], depth=4
)


Enformer model


Layer (type (var_name))                  Output Shape              Param #
Sequential (Sequential)                  [16, 768, 98304]          --
├─Conv1d (0)                             [16, 768, 196608]         (46,848)
├─Residual (1)                           [16, 768, 196608]         --
│    └─Sequential (fn)                   [16, 768, 196608]         --
│    │    └─BatchNorm1d (0)              [16, 768, 196608]         (1,536)
│    │    └─GELU (1)                     [16, 768, 196608]         --
│    │    └─Conv1d (2)                   [16, 768, 196608]         (590,592)
├─AttentionPool (2)                      [16, 768, 98304]          --
│    └─Rearrange (pool_fn)               [16, 768, 98304, 2]       --
│    └─Conv2d (to_attn_logits)           [16, 768, 98304, 2]       (589,824)
Total params: 1,228,800
Trainable params: 0
Non-trainable params: 1,228,800
Total mult-adds (T): 3.86
Input size (MB): 50.33
Forward/backward pass size (MB): 77309.41
Params size (MB): 4.92
Estimated 

In [8]:
# Same stem and input shape as the baseline cell, but taken from the
# LoRA-converted model so the two summaries can be compared.
print("\nLoRA-enabled Network:")
summary_on_cuda(
    lora_enformer.stem, input_size=(16, 4, 196_608), row_settings=["var_names"], depth=4
)


LoRA-enabled Network:


Layer (type (var_name))                  Output Shape              Param #
Sequential (Sequential)                  [16, 768, 98304]          --
├─LoRAConv (0)                           [16, 768, 196608]         49,680
│    └─Conv1d (conv)                     [16, 768, 196608]         (46,848)
├─Residual (1)                           [16, 768, 196608]         --
│    └─Sequential (fn)                   [16, 768, 196608]         --
│    │    └─BatchNorm1d (0)              [16, 768, 196608]         (1,536)
│    │    └─GELU (1)                     [16, 768, 196608]         --
│    │    └─LoRAConv (2)                 [16, 768, 196608]         6,144
│    │    │    └─Conv1d (conv)           [16, 768, 196608]         (590,592)
├─AttentionPool (2)                      [16, 768, 98304]          --
│    └─Rearrange (pool_fn)               [16, 768, 98304, 2]       --
│    └─LoRAConv (to_attn_logits)         [16, 768, 98304, 2]       6,144
│    │    └─Conv2d (conv)                [16, 768, 98304,

### LoRA Modules

#### Stem DNA Conv1D

In [9]:
# First stem layer: the LoRA wrapper around the Conv1d that reads the 4-channel
# input (kernel size 15 in the saved output).
# NOTE(review): the saved run reports 49,680 trainable LoRA parameters here
# versus 46,848 frozen conv parameters, so this adapter is larger than the
# layer it adapts.
_module = lora_enformer.stem[0]
print_lora_conv(_module)

in_channels 4
out_channels 768
groups 1
kernel_size (15,)

LoRA Module <class 'bolero.tl.generic.module_lora.LoRAConv'>
lora_A torch.Size([60, 60]) True
lora_B torch.Size([768, 60]) True
conv.weight torch.Size([768, 4, 15]) False
conv.bias torch.Size([768]) False
LoRA trainable parameters:  49680


#### Stem Linear Conv1D

In [10]:
# Kernel-size-1 Conv1d inside the stem's Residual block (`stem[1].fn[2]`).
_module = lora_enformer.stem[1].fn[2]
print_lora_conv(_module)

in_channels 768
out_channels 768
groups 1
kernel_size (1,)

LoRA Module <class 'bolero.tl.generic.module_lora.LoRAConv'>
lora_A torch.Size([4, 768]) True
lora_B torch.Size([768, 4]) True
conv.weight torch.Size([768, 768, 1]) False
conv.bias torch.Size([768]) False
LoRA trainable parameters:  6144


#### Stem AttentionPool Conv2D

In [11]:
# `to_attn_logits` convolution of the stem's AttentionPool; the saved output
# shows a 1x1 Conv2d with a weight but no bias parameter.
_module = lora_enformer.stem[2].to_attn_logits
print_lora_conv(_module)

in_channels 768
out_channels 768
groups 1
kernel_size (1, 1)

LoRA Module <class 'bolero.tl.generic.module_lora.LoRAConv'>
lora_A torch.Size([4, 768]) True
lora_B torch.Size([768, 4]) True
conv.weight torch.Size([768, 768, 1, 1]) False
LoRA trainable parameters:  6144


## Enformer Conv Towers

- Input:
  - shape: `(bs, C/2, seq_len/2) or (16, 768, 98_304)`
- Output:
  - Output seq length: 1536 = 196_608 / 2 ** 7 (= 98_304 / 2 ** 6)
  - shape: `(bs, C, seq_len/2**7) or (16, 1536, 1536)`

In [12]:
# Baseline: summary of the original conv tower, fed with the stem's output
# shape (16, 768, 98_304).
print("\nEnformer model")
summary_on_cuda(
    enformer.conv_tower,
    input_size=(16, 768, 98_304),
    row_settings=["var_names"],
    depth=4,
)


Enformer model


Layer (type (var_name))                  Output Shape              Param #
Sequential (Sequential)                  [16, 1536, 1536]          --
├─Sequential (0)                         [16, 768, 49152]          --
│    └─Sequential (0)                    [16, 768, 98304]          --
│    │    └─BatchNorm1d (0)              [16, 768, 98304]          (1,536)
│    │    └─GELU (1)                     [16, 768, 98304]          --
│    │    └─Conv1d (2)                   [16, 768, 98304]          (2,949,888)
│    └─Residual (1)                      [16, 768, 98304]          --
│    │    └─Sequential (fn)              [16, 768, 98304]          --
│    │    │    └─BatchNorm1d (0)         [16, 768, 98304]          (1,536)
│    │    │    └─GELU (1)                [16, 768, 98304]          --
│    │    │    └─Conv1d (2)              [16, 768, 98304]          (590,592)
│    └─AttentionPool (2)                 [16, 768, 49152]          --
│    │    └─Rearrange (pool_fn)          [16, 768, 49152, 2

In [13]:
# LoRA-converted conv tower with the same input shape as the baseline cell.
# depth=8 instead of 4: presumably so the layers nested inside the LoRA
# wrappers are still listed - TODO confirm.
print("\nLoRA-enabled Network:")
summary_on_cuda(
    lora_enformer.conv_tower,
    input_size=(16, 768, 98_304),
    row_settings=["var_names"],
    depth=8,
)


LoRA-enabled Network:


Layer (type (var_name))                       Output Shape              Param #
Sequential (Sequential)                       [16, 1536, 1536]          --
├─Sequential (0)                              [16, 768, 49152]          --
│    └─Sequential (0)                         [16, 768, 98304]          --
│    │    └─BatchNorm1d (0)                   [16, 768, 98304]          (1,536)
│    │    └─GELU (1)                          [16, 768, 98304]          --
│    │    └─LoRAConv (2)                      [16, 768, 98304]          92,160
│    │    │    └─Conv1d (conv)                [16, 768, 98304]          (2,949,888)
│    └─Residual (1)                           [16, 768, 98304]          --
│    │    └─Sequential (fn)                   [16, 768, 98304]          --
│    │    │    └─BatchNorm1d (0)              [16, 768, 98304]          (1,536)
│    │    │    └─GELU (1)                     [16, 768, 98304]          --
│    │    │    └─LoRAConv (2)                 [16, 768, 98304]          

### LoRA Modules

#### Conv Tower Conv1D

In [14]:
# Kernel-size-5 Conv1d of the first conv-tower block. In the saved output
# lora_A is (20, 3840), which matches (rank * kernel_size,
# in_channels * kernel_size) for rank 4, kernel size 5 and 768 input channels.
_module = lora_enformer.conv_tower[0][0][2]
print_lora_conv(_module)

in_channels 768
out_channels 768
groups 1
kernel_size (5,)

LoRA Module <class 'bolero.tl.generic.module_lora.LoRAConv'>
lora_A torch.Size([20, 3840]) True
lora_B torch.Size([768, 20]) True
conv.weight torch.Size([768, 768, 5]) False
conv.bias torch.Size([768]) False
LoRA trainable parameters:  92160


#### Conv Tower Linear Conv1D

In [15]:
# Kernel-size-1 Conv1d inside the Residual part of the first conv-tower block.
_module = lora_enformer.conv_tower[0][1].fn[2]
print_lora_conv(_module)

in_channels 768
out_channels 768
groups 1
kernel_size (1,)

LoRA Module <class 'bolero.tl.generic.module_lora.LoRAConv'>
lora_A torch.Size([4, 768]) True
lora_B torch.Size([768, 4]) True
conv.weight torch.Size([768, 768, 1]) False
conv.bias torch.Size([768]) False
LoRA trainable parameters:  6144


#### Conv Tower AttentionPool Conv2D

In [16]:
# `to_attn_logits` convolution of the first conv-tower block's AttentionPool
# (a 1x1 Conv2d with no bias parameter in the saved output).
_module = lora_enformer.conv_tower[0][2].to_attn_logits
print_lora_conv(_module)

in_channels 768
out_channels 768
groups 1
kernel_size (1, 1)

LoRA Module <class 'bolero.tl.generic.module_lora.LoRAConv'>
lora_A torch.Size([4, 768]) True
lora_B torch.Size([768, 4]) True
conv.weight torch.Size([768, 768, 1, 1]) False
LoRA trainable parameters:  6144


## Enformer Transformers


- Input:
  - shape: `(bs, seq_len/2**7, C) or (16, 1536, 1536)`
- Output:
  - shape: `(bs, seq_len/2**7, C) or (16, 1536, 1536)`


In [17]:
# Baseline: summary of the original transformer stack for a (16, 1536, 1536)
# input.
print("\nEnformer model")
summary_on_cuda(
    enformer.transformer,
    input_size=(16, 1536, 1536),
    row_settings=["var_names"],
    depth=4,
)


Enformer model


Layer (type (var_name))                            Output Shape              Param #
Sequential (Sequential)                            [16, 1536, 1536]          --
├─Sequential (0)                                   [16, 1536, 1536]          --
│    └─Residual (0)                                [16, 1536, 1536]          --
│    │    └─Sequential (fn)                        [16, 1536, 1536]          --
│    │    │    └─LayerNorm (0)                     [16, 1536, 1536]          (3,072)
│    │    │    └─Attention (1)                     [16, 1536, 1536]          (6,392,320)
│    │    │    └─Dropout (2)                       [16, 1536, 1536]          --
│    └─Residual (1)                                [16, 1536, 1536]          --
│    │    └─Sequential (fn)                        [16, 1536, 1536]          --
│    │    │    └─LayerNorm (0)                     [16, 1536, 1536]          (3,072)
│    │    │    └─Linear (1)                        [16, 1536, 3072]          (4,721,664)
│    │ 

In [18]:
# LoRA-converted transformer stack with the same input shape as the baseline
# cell. depth=8 lets the saved summary list the LoRALinear layers nested inside
# each Attention module.
print("\nLoRA-enabled Network:")
summary_on_cuda(
    lora_enformer.transformer,
    input_size=(16, 1536, 1536),
    row_settings=["var_names"],
    depth=8,
)


LoRA-enabled Network:


Layer (type (var_name))                            Output Shape              Param #
Sequential (Sequential)                            [16, 1536, 1536]          --
├─Sequential (0)                                   [16, 1536, 1536]          --
│    └─Residual (0)                                [16, 1536, 1536]          --
│    │    └─Sequential (fn)                        [16, 1536, 1536]          --
│    │    │    └─LayerNorm (0)                     [16, 1536, 1536]          (3,072)
│    │    │    └─Attention (1)                     [16, 1536, 1536]          1,024
│    │    │    │    └─LoRALinear (to_q)            [16, 1536, 512]           794,624
│    │    │    │    └─LoRALinear (to_k)            [16, 1536, 512]           794,624
│    │    │    │    └─LoRALinear (to_v)            [16, 1536, 1536]          2,371,584
│    │    │    │    └─Dropout (pos_dropout)        [3071, 192]               --
│    │    │    │    └─LoRALinear (to_rel_k)        [3071, 512]               101,120
│    

### LoRA Modules

#### Transformer Q Encoder

In [19]:
# `to_q` linear layer of the Attention module in the first transformer block.
_module = lora_enformer.transformer[0][0].fn[1].to_q
print_lora_linear(_module)

in_features 1536
out_features 512

LoRA Module <class 'bolero.tl.generic.module_lora.LoRALinear'>
weight torch.Size([512, 1536]) False
lora_A torch.Size([4, 1536]) True
lora_B torch.Size([512, 4]) True
LoRA trainable parameters:  8192


#### Transformer K Encoder

In [20]:
# `to_k` linear layer of the Attention module in the first transformer block.
_module = lora_enformer.transformer[0][0].fn[1].to_k
print_lora_linear(_module)

in_features 1536
out_features 512

LoRA Module <class 'bolero.tl.generic.module_lora.LoRALinear'>
weight torch.Size([512, 1536]) False
lora_A torch.Size([4, 1536]) True
lora_B torch.Size([512, 4]) True
LoRA trainable parameters:  8192


#### Transformer V Encoder

In [21]:
# `to_v` linear layer of the Attention module in the first transformer block.
_module = lora_enformer.transformer[0][0].fn[1].to_v
print_lora_linear(_module)

in_features 1536
out_features 1536

LoRA Module <class 'bolero.tl.generic.module_lora.LoRALinear'>
weight torch.Size([1536, 1536]) False
lora_A torch.Size([4, 1536]) True
lora_B torch.Size([1536, 4]) True
LoRA trainable parameters:  12288


#### Transformer Rel K Encoder

In [22]:
# `to_rel_k` linear layer of the same Attention module; in the saved output it
# maps 192 input features to 512 output features.
_module = lora_enformer.transformer[0][0].fn[1].to_rel_k
print_lora_linear(_module)

in_features 192
out_features 512

LoRA Module <class 'bolero.tl.generic.module_lora.LoRALinear'>
weight torch.Size([512, 192]) False
lora_A torch.Size([4, 192]) True
lora_B torch.Size([512, 4]) True
LoRA trainable parameters:  2816


#### Transformer Output Encoder

In [23]:
# `to_out` linear layer of the same Attention module; unlike to_q / to_k / to_v
# it also lists a bias parameter in the saved output.
_module = lora_enformer.transformer[0][0].fn[1].to_out
print_lora_linear(_module)

in_features 1536
out_features 1536

LoRA Module <class 'bolero.tl.generic.module_lora.LoRALinear'>
weight torch.Size([1536, 1536]) False
bias torch.Size([1536]) False
lora_A torch.Size([4, 1536]) True
lora_B torch.Size([1536, 4]) True
LoRA trainable parameters:  12288


#### FeedForward

In [24]:
# First linear layer of the first transformer block's feed-forward branch
# (1536 -> 3072 features in the saved output).
_module = lora_enformer.transformer[0][1].fn[1]
print_lora_linear(_module)

in_features 1536
out_features 3072

LoRA Module <class 'bolero.tl.generic.module_lora.LoRALinear'>
weight torch.Size([3072, 1536]) False
bias torch.Size([3072]) False
lora_A torch.Size([4, 1536]) True
lora_B torch.Size([3072, 4]) True
LoRA trainable parameters:  18432


In [25]:
# Second linear layer of the same feed-forward branch (`fn[4]`), mapping
# 3072 -> 1536 features in the saved output.
_module = lora_enformer.transformer[0][1].fn[4]
print_lora_linear(_module)

in_features 3072
out_features 1536

LoRA Module <class 'bolero.tl.generic.module_lora.LoRALinear'>
weight torch.Size([1536, 3072]) False
bias torch.Size([1536]) False
lora_A torch.Size([4, 3072]) True
lora_B torch.Size([1536, 4]) True
LoRA trainable parameters:  18432


## Enformer Final Point Wise

- Input:
  - shape: `(bs, seq_len/2**7, C) or (16, 1536, 1536)`
- Output:
  - shape: `(bs, seq_len/2**7, 2*C) or (16, 1536, 3072)`

In [26]:
# Baseline: summary of the original final pointwise block for a
# (16, 1536, 1536) input.
print("\nEnformer model")
summary_on_cuda(
    enformer.final_pointwise,
    input_size=(16, 1536, 1536),
    row_settings=["var_names"],
    depth=6,
)


Enformer model


Layer (type (var_name))                  Output Shape              Param #
Sequential (Sequential)                  [16, 1536, 3072]          --
├─Rearrange (0)                          [16, 1536, 1536]          --
├─Sequential (1)                         [16, 3072, 1536]          --
│    └─BatchNorm1d (0)                   [16, 1536, 1536]          (3,072)
│    └─GELU (1)                          [16, 1536, 1536]          --
│    └─Conv1d (2)                        [16, 3072, 1536]          (4,721,664)
├─Rearrange (2)                          [16, 1536, 3072]          --
├─Dropout (3)                            [16, 1536, 3072]          --
├─GELU (4)                               [16, 1536, 3072]          --
Total params: 4,724,736
Trainable params: 0
Non-trainable params: 4,724,736
Total mult-adds (G): 116.04
Input size (MB): 151.00
Forward/backward pass size (MB): 905.97
Params size (MB): 18.90
Estimated Total Size (MB): 1075.86

In [27]:
# LoRA-converted final pointwise block, same input shape and depth as the
# baseline cell.
print("\nLoRA-enabled Network:")
summary_on_cuda(
    lora_enformer.final_pointwise,
    input_size=(16, 1536, 1536),
    row_settings=["var_names"],
    depth=6,
)


LoRA-enabled Network:


Layer (type (var_name))                  Output Shape              Param #
Sequential (Sequential)                  [16, 1536, 3072]          --
├─Rearrange (0)                          [16, 1536, 1536]          --
├─Sequential (1)                         [16, 3072, 1536]          --
│    └─BatchNorm1d (0)                   [16, 1536, 1536]          (3,072)
│    └─GELU (1)                          [16, 1536, 1536]          --
│    └─LoRAConv (2)                      [16, 3072, 1536]          18,432
│    │    └─Conv1d (conv)                [16, 3072, 1536]          (4,721,664)
├─Rearrange (2)                          [16, 1536, 3072]          --
├─Dropout (3)                            [16, 1536, 3072]          --
├─GELU (4)                               [16, 1536, 3072]          --
Total params: 4,743,168
Trainable params: 18,432
Non-trainable params: 4,724,736
Total mult-adds (G): 116.04
Input size (MB): 151.00
Forward/backward pass size (MB): 905.97
Params size (MB): 18.90
Estimated 

### LoRA Modules

In [28]:
# Kernel-size-1 Conv1d of the final pointwise block (1536 -> 3072 channels in
# the saved output).
_module = lora_enformer.final_pointwise[1][2]
print_lora_conv(_module)

in_channels 1536
out_channels 3072
groups 1
kernel_size (1,)

LoRA Module <class 'bolero.tl.generic.module_lora.LoRAConv'>
lora_A torch.Size([4, 1536]) True
lora_B torch.Size([3072, 4]) True
conv.weight torch.Size([3072, 1536, 1]) False
conv.bias torch.Size([3072]) False
LoRA trainable parameters:  18432


## Enformer Output Heads - Human

- Input:
  - shape: `(bs, seq_len/2**7, 2*C) or (16, 1536, 3072)`
- Output:
  - shape: `(bs, seq_len/2**7, n_tracks) or (16, 1536, 5313)`

In [29]:
# Baseline: summary of the original "human" output head for a
# (16, 1536, 3072) input.
print("\nEnformer model")
summary_on_cuda(
    enformer.heads["human"],
    input_size=(16, 1536, 3072),
    row_settings=["var_names"],
    depth=6,
)


Enformer model


Layer (type (var_name))                  Output Shape              Param #
Sequential (Sequential)                  [16, 1536, 5313]          --
├─Linear (0)                             [16, 1536, 5313]          (16,326,849)
├─Softplus (1)                           [16, 1536, 5313]          --
Total params: 16,326,849
Trainable params: 0
Non-trainable params: 16,326,849
Total mult-adds (M): 261.23
Input size (MB): 301.99
Forward/backward pass size (MB): 1044.58
Params size (MB): 65.31
Estimated Total Size (MB): 1411.88

In [30]:
# LoRA-converted "human" output head, same input shape and depth as the
# baseline cell.
print("\nLoRA-enabled Network:")
summary_on_cuda(
    lora_enformer.heads["human"],
    input_size=(16, 1536, 3072),
    row_settings=["var_names"],
    depth=6,
)


LoRA-enabled Network:


Layer (type (var_name))                  Output Shape              Param #
Sequential (Sequential)                  [16, 1536, 5313]          --
├─LoRALinear (0)                         [16, 1536, 5313]          16,360,389
├─Softplus (1)                           [16, 1536, 5313]          --
Total params: 16,360,389
Trainable params: 33,540
Non-trainable params: 16,326,849
Total mult-adds (M): 261.23
Input size (MB): 301.99
Forward/backward pass size (MB): 1044.58
Params size (MB): 65.44
Estimated Total Size (MB): 1412.01

### LoRA Modules

#### Output Head Linear

In [31]:
# Linear layer of the "human" head (3072 -> 5313 features in the saved output).
_module = lora_enformer.heads["human"][0]
print_lora_linear(_module)

in_features 3072
out_features 5313

LoRA Module <class 'bolero.tl.generic.module_lora.LoRALinear'>
weight torch.Size([5313, 3072]) False
bias torch.Size([5313]) False
lora_A torch.Size([4, 3072]) True
lora_B torch.Size([5313, 4]) True
LoRA trainable parameters:  33540


## Enformer Output Heads - Mouse

- Input:
  - shape: `(bs, seq_len/2**7, 2*C) or (16, 1536, 3072)`
- Output:
  - shape: `(bs, seq_len/2**7, n_tracks) or (16, 1536, 1643)`

In [32]:
# Baseline: summary of the original "mouse" output head for a
# (16, 1536, 3072) input.
print("\nEnformer model")
summary_on_cuda(
    enformer.heads["mouse"],
    input_size=(16, 1536, 3072),
    row_settings=["var_names"],
    depth=6,
)


Enformer model


Layer (type (var_name))                  Output Shape              Param #
Sequential (Sequential)                  [16, 1536, 1643]          --
├─Linear (0)                             [16, 1536, 1643]          (5,048,939)
├─Softplus (1)                           [16, 1536, 1643]          --
Total params: 5,048,939
Trainable params: 0
Non-trainable params: 5,048,939
Total mult-adds (M): 80.78
Input size (MB): 301.99
Forward/backward pass size (MB): 323.03
Params size (MB): 20.20
Estimated Total Size (MB): 645.21

In [33]:
# LoRA-converted "mouse" output head, same input shape and depth as the
# baseline cell.
print("\nLoRA-enabled Network:")
summary_on_cuda(
    lora_enformer.heads["mouse"],
    input_size=(16, 1536, 3072),
    row_settings=["var_names"],
    depth=6,
)


LoRA-enabled Network:


Layer (type (var_name))                  Output Shape              Param #
Sequential (Sequential)                  [16, 1536, 1643]          --
├─LoRALinear (0)                         [16, 1536, 1643]          5,067,799
├─Softplus (1)                           [16, 1536, 1643]          --
Total params: 5,067,799
Trainable params: 18,860
Non-trainable params: 5,048,939
Total mult-adds (M): 80.78
Input size (MB): 301.99
Forward/backward pass size (MB): 323.03
Params size (MB): 20.27
Estimated Total Size (MB): 645.29

### LoRA Modules

#### Output Head Linear

In [34]:
# Linear layer of the "mouse" head (3072 -> 1643 features in the saved output).
_module = lora_enformer.heads["mouse"][0]
print_lora_linear(_module)

in_features 3072
out_features 1643

LoRA Module <class 'bolero.tl.generic.module_lora.LoRALinear'>
weight torch.Size([1643, 3072]) False
bias torch.Size([1643]) False
lora_A torch.Size([4, 3072]) True
lora_B torch.Size([1643, 4]) True
LoRA trainable parameters:  18860
