In [155]:
import sys
sys.path
sys.path.append("/home/haoqi.whq/llm-inference/LoRA")

from src.model import GPT2Config, GPT2LMModel
import torch
from loralib import PruneLayer, LoRALayer
import loralib as lora
import copy
from torch import nn
from src.model import LayerNorm, Conv1D
import time
import math
import torch.nn.functional as F
from torch.nn.parameter import Parameter

from torchsummary import summary


In [156]:
config = GPT2Config(
    n_embd=1024,
    n_layer=24,
    n_head=16,
    lora_attn_dim=4,
    lora_attn_alpha=32,
    lora_dropout=0.1,
    enable_mlp=True,
    enable_wo=True,
    enable_wq=True,
    enable_wk=True,
    enable_wv=True,
)

B = 8
SEQ_LEN = 512

NON_LINEAR_TIME, IO_TIME, MASK_TIME, LINER_TIME_GPU, LORA_TIME_TEE = 0, 0, 0, 0, 0

In [157]:
def mask(x):
    r = torch.zeros_like(x)
    return x + r

def unmask(x):
    r = torch.zeros_like(x)
    return x - r

## Transformer block runtime breakdown

In [158]:
class GPTConv1D(nn.Module, LoRALayer):
    # LoRA implemented in a Conv1D layer
    def __init__(
        self,
        in_features: int,
        out_features: int,
        r: int = 0,
        lora_alpha: int = 1,
        lora_dropout: float = 0.0,
        fan_in_fan_out: bool = False,  # Set this to True if the layer to replace stores weight like (fan_in, fan_out)
        merge_weights: bool = True,
        **kwargs
    ):
        super(GPTConv1D, self).__init__()
        LoRALayer.__init__(
            self,
            r=r,
            lora_alpha=lora_alpha,
            lora_dropout=lora_dropout,
            merge_weights=merge_weights,
        )

        self.in_features = in_features
        self.out_features = out_features
        self.fan_in_fan_out = fan_in_fan_out
        w = torch.empty(in_features, out_features)
        nn.init.normal_(w, std=0.02)
        self.weight = Parameter(w)
        self.bias = Parameter(torch.zeros(out_features))

        # Actual trainable parameters
        if r > 0:
            self.lora_A = nn.Parameter(self.weight.new_zeros((r, in_features)))
            self.lora_B = nn.Parameter(self.weight.new_zeros((out_features, r)))
            # self.lora_scaling = self.lora_alpha / self.r
            self.lora_scaling = nn.Parameter(torch.tensor(self.lora_alpha / self.r))

            # Freezing the pre-trained weight matrix
            self.weight.requires_grad = False
        self.reset_parameters()

    def reset_parameters(self):
        nn.init.normal_(self.weight, std=0.02)
        nn.init.zeros_(self.bias)
        if hasattr(self, "lora_A"):
            # initialize A the same way as the default for nn.Linear and B to zero
            nn.init.kaiming_uniform_(self.lora_A, a=math.sqrt(5))
            nn.init.zeros_(self.lora_B)

    def train(self, mode: bool = True):
        def T(w):
            return w.transpose(0, 1) if self.fan_in_fan_out else w

        nn.Linear.train(self, mode)
        if mode:
            if self.merge_weights and self.merged:
                # Make sure that the weights are not merged
                if self.r > 0:
                    self.weight.data -= T(self.lora_B @ self.lora_A) * self.lora_scaling
                self.merged = False
        else:
            if self.merge_weights and not self.merged:
                # Merge the weights and mark it
                if self.r > 0:
                    self.weight.data += T(self.lora_B @ self.lora_A) * self.lora_scaling
                self.merged = True

    def forward(self, x: torch.Tensor):
        size_out = x.size()[:-1] + (self.out_features,)

        start = time.time()
        x = mask(x)
        end = time.time()
        print(f"\t=====[MLP]====== mask x {(end - start) * 1000} ms")
        global MASK_TIME
        MASK_TIME += (end - start) * 1000

        start = time.time()
        x = x.cuda()
        end = time.time()
        print(f"\t=====[MLP]====== cpu->gpu IO {(end - start) * 1000} ms")
        global IO_TIME
        IO_TIME += (end - start) * 1000
        
        self.weight = Parameter(self.weight.cuda())
        self.bias = Parameter(self.bias.cuda())

        start = time.time()
        result = torch.addmm(self.bias, x.view(-1, x.size(-1)), self.weight)
        result = result.view(*size_out)
        end = time.time()
        print(f"\t=====[MLP]====== gpu linear {(end - start) * 1000} ms")
        global LINER_TIME_GPU
        LINER_TIME_GPU += (end - start) * 1000

        start = time.time()
        result = result.cpu()
        end = time.time()
        print(f"\t=====[MLP]====== gpu->cpu IO {(end - start) * 1000} ms")
        IO_TIME += (end - start) * 1000

        start = time.time()
        result = unmask(result)
        end = time.time()
        print(f"\t=====[MLP]====== unmask xA {(end - start) * 1000} ms")
        MASK_TIME += (end - start) * 1000

        x = x.cpu()
        start = time.time()
        lora_res = (
            self.lora_dropout(x)
            @ self.lora_A.transpose(0, 1)
            @ self.lora_B.transpose(0, 1)
        ) * self.lora_scaling
        end = time.time()
        print(f"\t=====[MLP]====== LoRA linear {(end - start) * 1000} ms")
        global LORA_TIME_TEE
        LORA_TIME_TEE += (end - start) * 1000

        if self.r > 0 and not self.merged:
            result += lora_res
        return result
    
class Linear(nn.Linear, LoRALayer):
    # LoRA implemented in a dense layer
    def __init__(
        self,
        in_features: int,
        out_features: int,
        r: int = 0,
        lora_alpha: int = 1,
        lora_dropout: float = 0.0,
        fan_in_fan_out: bool = False,  # Set this to True if the layer to replace stores weight like (fan_in, fan_out)
        merge_weights: bool = True,
        **kwargs
    ):
        nn.Linear.__init__(self, in_features, out_features, **kwargs)
        LoRALayer.__init__(
            self,
            r=r,
            lora_alpha=lora_alpha,
            lora_dropout=lora_dropout,
            merge_weights=merge_weights,
        )

        self.fan_in_fan_out = fan_in_fan_out
        # Actual trainable parameters
        if r > 0:
            self.lora_A = nn.Parameter(self.weight.new_zeros((r, in_features)))
            self.lora_B = nn.Parameter(self.weight.new_zeros((out_features, r)))
            # self.scaling = self.lora_alpha / self.r
            self.lora_scaling = nn.Parameter(torch.tensor(self.lora_alpha / self.r))

            # Freezing the pre-trained weight matrix
            self.weight.requires_grad = False
        self.reset_parameters()
        if fan_in_fan_out:
            self.weight.data = self.weight.data.transpose(0, 1)

    def reset_parameters(self):
        nn.Linear.reset_parameters(self)
        if hasattr(self, "lora_A"):
            # initialize A the same way as the default for nn.Linear and B to zero
            nn.init.kaiming_uniform_(self.lora_A, a=math.sqrt(5))
            nn.init.zeros_(self.lora_B)

    def train(self, mode: bool = True):
        def T(w):
            return w.transpose(0, 1) if self.fan_in_fan_out else w

        nn.Linear.train(self, mode)
        if mode:
            if self.merge_weights and self.merged:
                # Make sure that the weights are not merged
                if self.r > 0:
                    self.weight.data -= T(self.lora_B @ self.lora_A) * self.lora_scaling
                self.merged = False
        else:
            if self.merge_weights and not self.merged:
                # Merge the weights and mark it
                if self.r > 0:
                    self.weight.data += T(self.lora_B @ self.lora_A) * self.lora_scaling
                self.merged = True

    def forward(self, x: torch.Tensor):
        def T(w):
            return w.transpose(0, 1) if self.fan_in_fan_out else w

        if self.r > 0 and not self.merged:

            start = time.time()
            x = mask(x)
            end = time.time()
            print(f"\t=====[Attention]====== mask x {(end - start) * 1000} ms")
            global MASK_TIME
            MASK_TIME += (end - start) * 1000

            start = time.time()
            x = x.cuda()
            end = time.time()
            print(f"\t=====[Attention]====== cpu->gpu IO {(end - start) * 1000} ms")
            global IO_TIME
            IO_TIME += (end - start) * 1000

            self.weight = Parameter(self.weight.cuda())
            self.bias = Parameter(self.bias.cuda())
            
            start = time.time()
            result = F.linear(x, T(self.weight), bias=self.bias)
            end = time.time()
            print(f"\t=====[Attention]====== GPU linear {(end - start) * 1000} ms")
            global LINER_TIME_GPU
            LINER_TIME_GPU += (end - start) * 1000

            start = time.time()
            result = result.cpu()
            end = time.time()
            print(f"\t=====[Attention]====== gpu->cpu IO {(end - start) * 1000} ms")
            IO_TIME += (end - start) * 1000

            start = time.time()
            result = unmask(result)
            end = time.time()
            print(f"\t=====[Attention]====== unmask xA {(end - start) * 1000} ms")
            MASK_TIME += (end - start) * 1000

            x = x.cpu()
            start = time.time()
            lora_res = (
                self.lora_dropout(x)
                @ self.lora_A.transpose(0, 1)
                @ self.lora_B.transpose(0, 1)
            ) * self.lora_scaling
            end = time.time()
            print(f"\t=====[Attention]====== LoRA linear {(end - start) * 1000} ms")
            global LORA_TIME_TEE
            LORA_TIME_TEE += (end - start) * 1000

            result += lora_res
            return result
        else:
            return F.linear(x, T(self.weight), bias=self.bias)

In [159]:
class PruneLinear(PruneLayer, Linear):
    def __init__(
        self,
        in_features: int,
        out_features: int,
        r: int = 0,
        lora_alpha: int = 1,
        lora_dropout: float = 0,
        fan_in_fan_out: bool = False,
        merge_weights: bool = True,
        keep_flag: bool = True,
        **kwargs
    ):
        PruneLayer.__init__(self, keep_flag)
        Linear.__init__(
            self,
            in_features,
            out_features,
            r,
            lora_alpha,
            lora_dropout,
            fan_in_fan_out,
            merge_weights,
            **kwargs
        )

        # update scaling as thr
        # self.scaling = nn.Parameter(torch.tensor(self.lora_alpha / self.r))

    def forward(self, x: torch.Tensor):
        if not self.keep_flag:
            self.merged = True  # set merged to True to escape computing lora module
        return Linear.forward(self, x)

    def complexity(self):
        return self.lora_scaling * (self.r * self.in_features + self.out_features * self.r)

    def empirical_consumption(self, hardwares):
        return self.complexity()

class PruneGPTConv1D(PruneLayer, GPTConv1D):
    def __init__(
        self,
        in_features: int,
        out_features: int,
        r: int = 0,
        lora_alpha: int = 1,
        lora_dropout: float = 0,
        fan_in_fan_out: bool = False,
        merge_weights: bool = True,
        keep_flag: bool = True,
        **kwargs
    ):
        PruneLayer.__init__(self, keep_flag)
        GPTConv1D.__init__(
            self,
            in_features,
            out_features,
            r,
            lora_alpha,
            lora_dropout,
            fan_in_fan_out,
            merge_weights,
            **kwargs
        )

        # update scaling as thr
        # self.scaling = nn.Parameter(torch.tensor(self.lora_alpha / self.r))

    def forward(self, x: torch.Tensor):
        if not self.keep_flag:
            self.merged = True  # set merged to True to escape computing lora module
        return GPTConv1D.forward(self, x)

    def complexity(self):
        return self.lora_scaling * (self.r * self.in_features + self.out_features * self.r)

    def empirical_consumption(self, hardwares):
        return self.complexity()

In [160]:
class Attention(nn.Module):
    def __init__(self, nx, n_ctx, config, scale=False):
        super(Attention, self).__init__()
        n_state = nx  # in Attention: n_state=768 (nx=n_embd)
        # [switch nx => n_state from Block to Attention to keep identical to TF implem]

        assert n_state % config.n_head == 0
        self.register_buffer(
            "bias", torch.tril(torch.ones(n_ctx, n_ctx)).view(1, 1, n_ctx, n_ctx)
        )
        self.n_head = config.n_head
        self.split_size = n_state
        self.scale = scale

        if config.enable_wq:
            self.q_attn = PruneLinear(
                nx,
                n_state,
                r=config.lora_attn_dim,
                lora_alpha=config.lora_attn_alpha,
                lora_dropout=config.lora_dropout,
                fan_in_fan_out=True,
                merge_weights=False,
            )
        else:
            self.q_attn = nn.Linear(nx, n_state)

        if config.enable_wk:
            self.k_attn = PruneLinear(
                nx,
                n_state,
                r=config.lora_attn_dim,
                lora_alpha=config.lora_attn_alpha,
                lora_dropout=config.lora_dropout,
                fan_in_fan_out=True,
                merge_weights=False,
            )
        else:
            self.k_attn = nn.Linear(nx, n_state)

        if config.enable_wv:
            self.v_attn = PruneLinear(
                nx,
                n_state,
                r=config.lora_attn_dim,
                lora_alpha=config.lora_attn_alpha,
                lora_dropout=config.lora_dropout,
                fan_in_fan_out=True,
                merge_weights=False,
            )
        else:
            self.v_attn = nn.Linear(nx, n_state)
        # self.c_attn = lora.MergedLinear(
        #     nx,
        #     n_state * 3,
        #     r=config.lora_attn_dim,
        #     lora_alpha=config.lora_attn_alpha,
        #     lora_dropout=config.lora_dropout,
        #     enable_lora=[config.enable_wq, config.enable_wk, config.enable_wv],
        #     fan_in_fan_out=True,
        #     merge_weights=False,
        # )
        # print(
        #     f"QKV Attention LoRA ({[config.enable_wq, config.enable_wk, config.enable_wv]}): {self.c_attn}"
        # )

        if not config.enable_wo:
            self.c_proj = Conv1D(n_state, nx)
            print(f"O Attention not use LoRA: {self.c_proj}")
        else:
            self.c_proj = PruneGPTConv1D(
                in_features=nx,
                out_features=n_state,
                r=config.lora_attn_dim,
                lora_alpha=config.lora_attn_alpha,
                lora_dropout=config.lora_dropout,
                fan_in_fan_out=False,
                merge_weights=False,
                keep_flag=True,
            )
            print(f"O Attention using LoRA: {self.c_proj}")

        self.config = config

    def _attn(self, q, k, v, len_kv=None):
        w = torch.matmul(q, k)
        if self.scale:
            w = w / math.sqrt(v.size(-1))
        nd, ns = w.size(-2), w.size(-1)
        b = self.bias[:, :, ns - nd : ns, :ns]
        w = w * b - 1e10 * (1 - b)

        # q : (batch, head, q_seq_length, head_features)
        # k : (batch, head, head_features, kv_seq_length)
        # w : (batch, head, q_seq_length, kv_seq_length)
        # v : (batch, head, kv_seq_length, head_features)
        if len_kv is not None:
            _len = torch.arange(k.size(-1), device=k.device)
            _input_msk = _len[None, :] >= (len_kv)[:, None]
            w = w.masked_fill(_input_msk.unsqueeze(1).unsqueeze(2), -1.0e10)

        w = nn.Softmax(dim=-1)(w)
        return torch.matmul(w, v)

    def merge_heads(self, x):
        x = x.permute(0, 2, 1, 3).contiguous()
        new_x_shape = x.size()[:-2] + (x.size(-2) * x.size(-1),)
        return x.view(*new_x_shape)  # in Tensorflow implem: fct merge_states

    def split_heads(self, x, k=False):
        new_x_shape = x.size()[:-1] + (self.n_head, x.size(-1) // self.n_head)
        x = x.view(*new_x_shape)  # in Tensorflow implem: fct split_states
        if k:
            return x.permute(
                0, 2, 3, 1
            ).contiguous()  # (batch, head, head_features, seq_length)
        else:
            return x.permute(
                0, 2, 1, 3
            ).contiguous()  # (batch, head, seq_length, head_features)

    def forward(self, x, history=None, layer_past=None, len_past=None):
        hidden_states = x

        
        start = time.time()
        query, key, value = self.q_attn(x), self.k_attn(x), self.v_attn(x)
        print(f"qkv linear {(time.time() - start) * 1000} ms")

        query = self.split_heads(query)
        key = self.split_heads(key, k=True)
        value = self.split_heads(value)

        # _input_msk = None

        len_kv = None

        if layer_past is not None:
            # key : (batch, head, head_features, seq_length)
            # value : (batch, head, seq_length, head_features)
            # layer_past, key : (batch, head, seq_length, head_features)
            if len_past is None:
                past_key, past_value = (
                    layer_past[0].transpose(-2, -1),
                    layer_past[1],
                )  # transpose back cf below
                key = torch.cat((past_key, key), dim=-1)
                value = torch.cat((past_value, value), dim=-2)
            else:
                key_seq = key.shape[-1]
                assert key_seq == 1

                _batch = torch.arange(
                    0, key.shape[0], dtype=torch.long, device=key.device
                )

                past_key, past_value = layer_past[0], layer_past[1]

                past_key[_batch, :, len_past, :] = key.squeeze(-1)
                past_value[_batch, :, len_past, :] = value.squeeze(-2)

                key = past_key.transpose(-2, -1)
                value = past_value

                len_kv = len_past + 1

        present = torch.stack(
            (key.transpose(-2, -1), value)
        )  # transpose to have same shapes for stacking

        start = time.time()
        a = self._attn(query, key, value, len_kv=len_kv)
        print(f"bmm+softmax+bmm {(time.time() - start) * 1000} ms")

        start = time.time()
        a = self.merge_heads(a)
        print(f"merge heads {(time.time() - start) * 1000} ms")

        start = time.time()
        a = self.c_proj(a)
        print(f"o linear {(time.time() - start) * 1000} ms")

        return a, present
    

def gelu(x):
    return (
        0.5
        * x
        * (1 + torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3))))
    )

def gelu_fast(x):
    return 0.5 * x * (1.0 + torch.tanh(x * 0.7978845608 * (1.0 + 0.044715 * x * x)))

def gelu_impl(x):
    """OpenAI's gelu implementation."""
    return (
        0.5 * x * (1.0 + torch.tanh(0.7978845608028654 * x * (1.0 + 0.044715 * x * x)))
    )

def gelu_quad(x):
    return 0.125 * torch.square(x + 0.25 * x + 0.5)

class MLP(nn.Module):
    def __init__(self, n_state, config):  # in MLP: n_state=3072 (4 * n_embd)
        super(MLP, self).__init__()
        nx = config.n_embd
        # raw
        if not config.enable_mlp:
            self.c_fc = Conv1D(n_state, nx)
            self.c_proj = Conv1D(nx, n_state)
            print(f"MLP not use LoRA: {self.c_fc}, {self.c_proj}")
        else:  # modified
            self.c_fc = PruneGPTConv1D(
                in_features=nx,
                out_features=n_state,
                r=config.lora_attn_dim,
                lora_alpha=config.lora_attn_alpha,
                lora_dropout=config.lora_dropout,
                fan_in_fan_out=False,
                merge_weights=False,
                keep_flag=True,
            )
            self.c_proj = PruneGPTConv1D(
                in_features=n_state,
                out_features=nx,
                r=config.lora_attn_dim,
                lora_alpha=config.lora_attn_alpha,
                lora_dropout=config.lora_dropout,
                fan_in_fan_out=False,
                merge_weights=False,
                keep_flag=True,
            )
            print(f"MLP using LoRA: {self.c_fc}, {self.c_proj}")

        self.act = gelu_fast

    def forward(self, x):
        h = self.c_fc(x)

        start = time.time()
        h1 = self.act(h)
        end = time.time()
        # print(h.shape)
        print(f"GeLU nonlinear TEE time {(end - start) * 1000} ms")

        # h = h.cuda()
        # start = time.time()
        # h1 = self.act(h)
        # end = time.time()
        # print(f"GeLU nonlinear GPU time {(end - start) * 1000} ms")
        
        global NON_LINEAR_TIME
        NON_LINEAR_TIME += (end - start) * 1000

        # h1 = h1.cpu()
        h2 = self.c_proj(h1)
        return h2

In [161]:
class Block(nn.Module):
    def __init__(self, n_ctx, config, scale=False):
        super(Block, self).__init__()
        nx = config.n_embd
        self.ln_1 = LayerNorm(nx, eps=config.layer_norm_epsilon)
        self.attn = Attention(nx, n_ctx, config, scale)
        self.ln_2 = LayerNorm(nx, eps=config.layer_norm_epsilon)
        self.mlp = MLP(4 * nx, config)

    def forward(self, x, layer_past=None, len_past=None):
        start = time.time()
        ln_res = self.ln_1(x)
        end = time.time()
        print(f"ln_1 {(end - start) * 1000} ms")
        global NON_LINEAR_TIME
        NON_LINEAR_TIME += (end - start) * 1000


        start = time.time()
        a, present = self.attn(ln_res, layer_past=layer_past, len_past=len_past)
        print(f"attention {(time.time() - start) * 1000} ms")
        
        x = x + a

        start = time.time()
        ln_res = self.ln_2(x)
        end = time.time()
        print(f"ln_2 {(end - start) * 1000} ms")
        NON_LINEAR_TIME += (end - start) * 1000

        start = time.time()
        m = self.mlp(ln_res)
        print(f"mlp {(time.time() - start) * 1000} ms")

        x = x + m
        return x, present



block = Block(config.n_ctx, config, scale=True)
block.eval()

hidden_states = torch.rand(size=(B, SEQ_LEN, config.n_embd))

outputs = block(hidden_states)


times = [NON_LINEAR_TIME, IO_TIME, MASK_TIME, LINER_TIME_GPU, LORA_TIME_TEE]
total_time = sum(times)

print("NON_LINEAR_TIME\tIO_TIME\tMASK_TIME\tLINER_TIME_GPU\tLORA_TIME_TEE")
abs_time = ""
rel_time = ""
for time_ in times:
    abs_time += f"{time_}\t"
    rel_time += f"{time_ / total_time * 100:.2f}%\t"
print(abs_time)
print(rel_time)

O Attention using LoRA: PruneGPTConv1D(
  (lora_dropout): Dropout(p=0.1, inplace=False)
)
MLP using LoRA: PruneGPTConv1D(
  (lora_dropout): Dropout(p=0.1, inplace=False)
), PruneGPTConv1D(
  (lora_dropout): Dropout(p=0.1, inplace=False)
)
ln_1 3.737926483154297 ms
qkv linear 46.0052490234375 ms
bmm+softmax+bmm 63.253164291381836 ms
merge heads 0.9317398071289062 ms
o linear 15.317916870117188 ms
attention 131.25371932983398 ms
ln_2 5.640745162963867 ms
GeLU nonlinear TEE time 57.7847957611084 ms
mlp 285.003662109375 ms
NON_LINEAR_TIME	IO_TIME	MASK_TIME	LINER_TIME_GPU	LORA_TIME_TEE
67.16346740722656	118.33047866821289	45.03273963928223	1.2764930725097656	22.70030975341797	
26.39%	46.49%	17.69%	0.50%	8.92%	


In [162]:
from src.model import Block as RawBlock
block = RawBlock(config.n_ctx, config, scale=True)
block.eval()

hidden_states = torch.rand(size=(B, SEQ_LEN, config.n_embd))
summary(block, hidden_states, depth=3)

O Attention using LoRA: PruneGPTConv1D(
  (lora_dropout): Dropout(p=0.1, inplace=False)
)
MLP using LoRA: PruneGPTConv1D(
  (lora_dropout): Dropout(p=0.1, inplace=False)
), PruneGPTConv1D(
  (lora_dropout): Dropout(p=0.1, inplace=False)
)
Layer (type:depth-idx)                   Output Shape              Param #
├─LayerNorm: 1-1                         [-1, 512, 1024]           2,048
├─Attention: 1-2                         [-1, 512, 1024]           --
|    └─PruneLinear: 2-1                  [-1, 512, 1024]           --
|    |    └─Dropout: 3-1                 [-1, 512, 1024]           --
|    └─PruneLinear: 2-2                  [-1, 512, 1024]           --
|    |    └─Dropout: 3-2                 [-1, 512, 1024]           --
|    └─PruneLinear: 2-3                  [-1, 512, 1024]           --
|    |    └─Dropout: 3-3                 [-1, 512, 1024]           --
|    └─PruneGPTConv1D: 2-4               [-1, 512, 1024]           --
|    |    └─Dropout: 3-4                 [-1, 512, 10

Layer (type:depth-idx)                   Output Shape              Param #
├─LayerNorm: 1-1                         [-1, 512, 1024]           2,048
├─Attention: 1-2                         [-1, 512, 1024]           --
|    └─PruneLinear: 2-1                  [-1, 512, 1024]           --
|    |    └─Dropout: 3-1                 [-1, 512, 1024]           --
|    └─PruneLinear: 2-2                  [-1, 512, 1024]           --
|    |    └─Dropout: 3-2                 [-1, 512, 1024]           --
|    └─PruneLinear: 2-3                  [-1, 512, 1024]           --
|    |    └─Dropout: 3-3                 [-1, 512, 1024]           --
|    └─PruneGPTConv1D: 2-4               [-1, 512, 1024]           --
|    |    └─Dropout: 3-4                 [-1, 512, 1024]           --
├─LayerNorm: 1-3                         [-1, 512, 1024]           2,048
├─MLP: 1-4                               [-1, 512, 1024]           --
|    └─PruneGPTConv1D: 2-5               [-1, 512, 4096]           --
|    |   