In [3]:
import torch
from torch import nn

# Small encoder layer used for the shape experiments below: 12-dim embeddings
# split over 3 attention heads, with inputs laid out batch-first.
tel = nn.TransformerEncoderLayer(
    d_model=12,
    nhead=3,
    batch_first=True,
)

In [16]:
# List every parameter of the self-attention sub-module together with its shape.
attn_params = tel.self_attn.named_parameters()
for name, param in attn_params:
    print(f'{name}: {param.shape}')

in_proj_weight: torch.Size([36, 12])
in_proj_bias: torch.Size([36])
out_proj.weight: torch.Size([12, 12])
out_proj.bias: torch.Size([12])


In [28]:
# Inspect the output projection of a standalone multi-head attention module.
mha = nn.MultiheadAttention(embed_dim=12, num_heads=3)
mha.out_proj

NonDynamicallyQuantizableLinear(in_features=12, out_features=12, bias=True)

In [None]:
# Unexecuted repeat of the nn.MultiheadAttention out_proj inspection run in cell In [28].
nn.MultiheadAttention(embed_dim=12,num_heads=3).out_proj

In [None]:
# Unexecuted repeat of the nn.MultiheadAttention out_proj inspection run in cell In [28].
nn.MultiheadAttention(embed_dim=12,num_heads=3).out_proj

In [10]:
# Display the module tree of the encoder layer.
tel

TransformerEncoderLayer(
  (self_attn): MultiheadAttention(
    (out_proj): NonDynamicallyQuantizableLinear(in_features=12, out_features=12, bias=True)
  )
  (linear1): Linear(in_features=12, out_features=2048, bias=True)
  (dropout): Dropout(p=0.1, inplace=False)
  (linear2): Linear(in_features=2048, out_features=12, bias=True)
  (norm1): LayerNorm((12,), eps=1e-05, elementwise_affine=True)
  (norm2): LayerNorm((12,), eps=1e-05, elementwise_affine=True)
  (dropout1): Dropout(p=0.1, inplace=False)
  (dropout2): Dropout(p=0.1, inplace=False)
)

In [9]:
from torchinfo import summary as sm
# Per-layer output shapes and parameter counts for a (5, 3, 12) input.
# NOTE(review): in the recorded output the MultiheadAttention row shows "--"
# under "Param #", and the reported total (51,260) equals the sum of the other
# rows only, so the attention weights are not included in that figure.
sm(tel, input_size=(5, 3, 12))

Layer (type:depth-idx)                   Output Shape              Param #
TransformerEncoderLayer                  --                        --
├─MultiheadAttention: 1-1                [5, 3, 12]                --
├─Dropout: 1-2                           [5, 3, 12]                --
├─LayerNorm: 1-3                         [5, 3, 12]                24
├─Linear: 1-4                            [5, 3, 2048]              26,624
├─Dropout: 1-5                           [5, 3, 2048]              --
├─Linear: 1-6                            [5, 3, 12]                24,588
├─Dropout: 1-7                           [5, 3, 12]                --
├─LayerNorm: 1-8                         [5, 3, 12]                24
Total params: 51,260
Trainable params: 51,260
Non-trainable params: 0
Total mult-adds (M): 0.26
Input size (MB): 0.00
Forward/backward pass size (MB): 0.25
Params size (MB): 0.21
Estimated Total Size (MB): 0.46

In [6]:
# Push a random (5, 3, 12) tensor through the layer and show the output shape.
sample = torch.rand(5, 3, 12)
tel(sample).shape

torch.Size([5, 3, 12])

### External attention modules

In [36]:
import os
import sys

# Make the ./ext_attns directory importable. The path is resolved against the
# current working directory, so start the kernel from the notebook's folder.
ext_path = os.path.abspath('./ext_attns/')

# Guarded so that re-running this cell does not stack duplicate sys.path entries.
if ext_path not in sys.path:
    sys.path.append(ext_path)

In [37]:
# 1. external attn: "Beyond Self-attention: External Attention using Two Linear Layers for Visual Tasks"
from model.attention.ExternalAttention import ExternalAttention
import torch

# Smoke test: a random (50, 49, 512) tensor goes in, the output shape is printed.
# The tensor is named `inputs` so the builtin `input` is not shadowed in the kernel.
inputs = torch.randn(50, 49, 512)
ea = ExternalAttention(d_model=512, S=8)
output = ea(inputs)  # this module is called with a single tensor
print(output.shape)

torch.Size([50, 49, 512])


In [39]:
# 2. self attn: attention is all you need
from model.attention.SelfAttention import ScaledDotProductAttention
import torch

# Smoke test: a random (50, 49, 512) tensor goes in, the output shape is printed.
# The tensor is named `inputs` so the builtin `input` is not shadowed in the kernel.
inputs = torch.randn(50, 49, 512)
sa = ScaledDotProductAttention(d_model=512, d_k=512, d_v=512, h=8)
output = sa(inputs, inputs, inputs)  # same tensor passed three times
print(output.shape)

torch.Size([50, 49, 512])


In [40]:
# 3. simplified self attn: None
from model.attention.SimplifiedSelfAttention import SimplifiedScaledDotProductAttention
import torch

# Smoke test: a random (50, 49, 512) tensor goes in, the output shape is printed.
# The tensor is named `inputs` so the builtin `input` is not shadowed in the kernel.
inputs = torch.randn(50, 49, 512)
ssa = SimplifiedScaledDotProductAttention(d_model=512, h=8)
output = ssa(inputs, inputs, inputs)  # same tensor passed three times
print(output.shape)

torch.Size([50, 49, 512])


In [47]:
# 10 Efficient Multi-Head Self-Attention: "ResT: An Efficient Transformer for Visual Recognition"
from model.attention.EMSA import EMSA
import torch
from torch import nn
from torch.nn import functional as F

# Smoke test: a random (50, 64, 512) tensor goes in, the output shape is printed.
# The tensor is named `inputs` so the builtin `input` is not shadowed in the kernel.
# NOTE(review): the sequence length used here (64) equals H*W = 8*8; presumably
# EMSA requires that relation — confirm against the module.
inputs = torch.randn(50, 64, 512)
emsa = EMSA(d_model=512, d_k=512, d_v=512, h=8, H=8, W=8, ratio=2, apply_transform=True)
output = emsa(inputs, inputs, inputs)  # same tensor passed three times
print(output.shape)

torch.Size([50, 64, 512])


In [42]:
# 13 muse attn: "MUSE: Parallel Multi-Scale Attention for Sequence to Sequence Learning"
from model.attention.MUSEAttention import MUSEAttention
import torch
from torch import nn
from torch.nn import functional as F


# Smoke test: a random (50, 49, 512) tensor goes in, the output shape is printed.
# The tensor is named `inputs` so the builtin `input` is not shadowed, and the
# module is bound to `ma` (the name the later setup cell uses) instead of
# rebinding `sa`, which the self-attention demo cell uses for a different class.
inputs = torch.randn(50, 49, 512)
ma = MUSEAttention(d_model=512, d_k=512, d_v=512, h=8)
output = ma(inputs, inputs, inputs)  # same tensor passed three times
print(output.shape)

torch.Size([50, 49, 512])


In [43]:
# 16 aft: An Attention Free Transformer
from model.attention.AFT import AFT_FULL
import torch
from torch import nn
from torch.nn import functional as F

# Smoke test: a random (50, 49, 512) tensor goes in, the output shape is printed.
# The tensor is named `inputs` so the builtin `input` is not shadowed in the kernel.
# NOTE(review): n=49 matches the sequence length of the tensor below; presumably
# AFT_FULL's `n` is the sequence length — confirm against the module.
inputs = torch.randn(50, 49, 512)
aft_full = AFT_FULL(d_model=512, n=49)
output = aft_full(inputs)  # this module is called with a single tensor
print(output.shape)

torch.Size([50, 49, 512])


In [45]:
# 30 UFO: UFO-ViT: High Performance Linear Vision Transformer without Softmax---ArXiv 2021.09.29
# Explicit import instead of `import *`, so the cell shows exactly which name it
# takes from the module and nothing else leaks into the notebook namespace.
from model.attention.UFOAttention import UFOAttention
import torch
from torch import nn
from torch.nn import functional as F


# Smoke test: a random (50, 49, 512) tensor goes in, the output shape is printed.
# The tensor is named `inputs` so the builtin `input` is not shadowed in the kernel.
inputs = torch.randn(50, 49, 512)
ufo = UFOAttention(d_model=512, d_k=512, d_v=512, h=8)
output = ufo(inputs, inputs, inputs)  # same tensor passed three times
print(output.shape) #[50, 49, 512]

torch.Size([50, 49, 512])


### Import test

In [89]:
# assertion test
# Every usable attention in attn_li must map a (3, 4, 768) input to an output of
# the same shape. `pre` is now computed inside this cell for each module; before,
# every assignment to it was commented out, so the assert only worked on a value
# left over in the kernel.
# NOTE: attn_li is built in the setup cell (In [1]) further down; run that cell first.
x = torch.rand(3, 4, 768)
y = torch.rand(3, 4, 768)  # only its shape is used, as the expected output shape

# Call conventions and known failures, taken from the manual trials this loop replaces.
single_input = ('AFT_FULL', 'ExternalAttention')  # called as attn(x); the others take (x, x, x)
not_solved = ('AFT_FULL', 'EMSA')                 # attn_li[1] and attn_li[3] were marked "not solved"

for attn in attn_li:
    attn_name = type(attn).__name__
    if attn_name in not_solved:
        continue
    pre = attn(x) if attn_name in single_input else attn(x, x, x)
    assert y.shape == pre.shape, f'{attn_name}: expected {y.shape}, got {pre.shape}'
    print(f'y:{y.shape} pre:{pre.shape}')

y:torch.Size([3, 4, 768]) pre:torch.Size([3, 4, 768])


In [1]:
import os
import sys

# Make the ./ext_attns directory importable (resolved against the current
# working directory). Guarded so re-running the cell does not add duplicates.
ext_path = os.path.abspath('./ext_attns/')
if ext_path not in sys.path:
    sys.path.append(ext_path)

# list
# UFOAttention is imported by name instead of `import *`, so nothing else from
# that module leaks into the notebook namespace.
from model.attention.UFOAttention import UFOAttention
from model.attention.AFT import AFT_FULL
from model.attention.MUSEAttention import MUSEAttention
from model.attention.EMSA import EMSA
from model.attention.SimplifiedSelfAttention import SimplifiedScaledDotProductAttention
from model.attention.SelfAttention import ScaledDotProductAttention
from model.attention.ExternalAttention import ExternalAttention

# Shared sizes for every attention variant; per-head width is d_model // n_head.
d_model = 768
n_head  = 12

ufo = UFOAttention(d_model=d_model, d_k=d_model//n_head, d_v=d_model//n_head, h=n_head)
# NOTE(review): the demo cell passes n=49 for a 49-token input, so `n` looks like
# a sequence length; n_head is passed here — confirm. This variant is marked
# "not solved" in the shape test and skipped by the training loop.
aft_full = AFT_FULL(d_model=d_model, n=n_head)
ma = MUSEAttention(d_model=d_model, d_k=d_model//n_head, d_v=d_model//n_head, h=n_head)
# NOTE(review): the demo cell uses H=W=8 with a 64-token input, so H and W look
# like spatial sizes whose product is the sequence length; n_head is passed for
# both here — confirm. Also marked "not solved" and skipped by the training loop.
emsa = EMSA(d_model=d_model, d_k=d_model//n_head, d_v=d_model//n_head, h=n_head,H=n_head,W=n_head,ratio=2,apply_transform=True)
ssa = SimplifiedScaledDotProductAttention(d_model=d_model, h=n_head)
sa = ScaledDotProductAttention(d_model=d_model, d_k=d_model//n_head, d_v=d_model//n_head, h=n_head)
ea = ExternalAttention(d_model=d_model,S=8) #v

# Order matters: the shape test refers to these modules by index.
attn_li = [ufo,aft_full,ma,emsa,ssa,sa,ea]

import transformers

# print(transformers.__version__)

# Checkpoint to fine-tune. "t5-small" is the alternative that was tried here;
# the T5 names additionally receive a task prefix further down.
model_checkpoint = 'facebook/bart-base'

from datasets import load_dataset, load_metric

# Load the "xsum" dataset and the "rouge" metric by name. Both objects are used
# below: the dataset is tokenized, the metric is called from compute_metrics.
raw_datasets = load_dataset("xsum")
metric = load_metric("rouge")

# raw_datasets["train"][0]

import datasets

from transformers import AutoTokenizer
    
# Tokenizer that belongs to the selected checkpoint.
tokenizer = AutoTokenizer.from_pretrained(model_checkpoint)

# ?
# with tokenizer.as_target_tokenizer():
#     print(tokenizer(["Hello, this one sentence!", "This is another sentence."]))

# T5 checkpoints get the task prefix "summarize: " prepended to every document;
# any other checkpoint gets no prefix.
# Fixed: the list used to contain "t5-larg", so "t5-large" never matched and
# silently ran without the prefix.
if model_checkpoint in ["t5-small", "t5-base", "t5-large", "t5-3b", "t5-11b"]:
    prefix = "summarize: "
else:
    prefix = ""

# Truncation limits passed to the tokenizer as max_length for documents and
# summaries respectively.
max_input_length = 1024
max_target_length = 128

def preprocess_function(examples):
    """Tokenize one batch of examples for seq2seq training.

    Reads ``examples["document"]`` (each entry prepended with the module-level
    ``prefix``) and ``examples["summary"]``. Documents are truncated to
    ``max_input_length`` and summaries to ``max_target_length``.

    Returns the tokenized documents with the summaries' ``input_ids`` stored
    under the ``"labels"`` key.
    """
    sources = [prefix + doc for doc in examples["document"]]
    batch = tokenizer(sources, max_length=max_input_length, truncation=True)

    # Summaries are encoded while the tokenizer is switched to target mode.
    with tokenizer.as_target_tokenizer():
        batch["labels"] = tokenizer(
            examples["summary"], max_length=max_target_length, truncation=True
        )["input_ids"]

    return batch

tokenized_datasets = raw_datasets.map(preprocess_function, batched=True)

from transformers import AutoModelForSeq2SeqLM, DataCollatorForSeq2Seq, Seq2SeqTrainingArguments, Seq2SeqTrainer

# Pretrained seq2seq model for the selected checkpoint.
model = AutoModelForSeq2SeqLM.from_pretrained(model_checkpoint)

batch_size = 3  # 16
# Last path component of the checkpoint id, used to name the output directory.
model_name = model_checkpoint.rsplit("/", 1)[-1]

# Keyword options for the training run, collected in one place.
training_options = dict(
    evaluation_strategy="epoch",
    learning_rate=2e-5,
    per_device_train_batch_size=batch_size,
    per_device_eval_batch_size=batch_size,
    weight_decay=0.01,
    save_total_limit=3,
    num_train_epochs=3,
    predict_with_generate=True,
    fp16=True,
    # push_to_hub=True,
)
args = Seq2SeqTrainingArguments(f"{model_name}-finetuned-xsum", **training_options)

# Collator that builds the training batches for this tokenizer/model pair.
data_collator = DataCollatorForSeq2Seq(tokenizer, model=model)

import nltk
import numpy as np

def compute_metrics(eval_pred):
    """Score generated summaries against the references.

    ``eval_pred`` unpacks into ``(predictions, labels)``, both arrays of token
    ids. Returns a dict with each metric's mid F-measure scaled to 0-100, plus
    ``gen_len`` (mean count of non-pad tokens per prediction), all rounded to
    four decimals.
    """
    predictions, labels = eval_pred
    pad_id = tokenizer.pad_token_id

    def _one_sentence_per_line(texts):
        # Rouge expects a newline after each sentence.
        return ["\n".join(nltk.sent_tokenize(text.strip())) for text in texts]

    pred_texts = tokenizer.batch_decode(predictions, skip_special_tokens=True)
    # -100 entries in the labels cannot be decoded; swap them for the pad token id.
    decodable_labels = np.where(labels != -100, labels, pad_id)
    label_texts = tokenizer.batch_decode(decodable_labels, skip_special_tokens=True)

    rouge = metric.compute(
        predictions=_one_sentence_per_line(pred_texts),
        references=_one_sentence_per_line(label_texts),
        use_stemmer=True,
    )

    # Keep only the mid F-measure of every score, as a percentage.
    scores = {}
    for key, value in rouge.items():
        scores[key] = value.mid.fmeasure * 100

    # Mean generated length, counting non-pad tokens only.
    scores["gen_len"] = np.mean(
        [np.count_nonzero(pred != pad_id) for pred in predictions]
    )

    return {k: round(v, 4) for k, v in scores.items()}

Using custom data configuration default
Reusing dataset xsum (/home/yp/.cache/huggingface/datasets/xsum/default/1.2.0/4957825a982999fbf80bca0b342793b01b2611e021ef589fb7c6250b3577b499)


  0%|          | 0/3 [00:00<?, ?it/s]

Loading cached processed dataset at /home/yp/.cache/huggingface/datasets/xsum/default/1.2.0/4957825a982999fbf80bca0b342793b01b2611e021ef589fb7c6250b3577b499/cache-5d732b1c86657ea0.arrow
Loading cached processed dataset at /home/yp/.cache/huggingface/datasets/xsum/default/1.2.0/4957825a982999fbf80bca0b342793b01b2611e021ef589fb7c6250b3577b499/cache-38966a651c72a103.arrow
Loading cached processed dataset at /home/yp/.cache/huggingface/datasets/xsum/default/1.2.0/4957825a982999fbf80bca0b342793b01b2611e021ef589fb7c6250b3577b499/cache-d19c279816291ade.arrow


In [3]:
from torch import nn


class tel2lin1(nn.Module):
    """Replacement LM head: an attention module followed by a bias-free linear projection.

    Defined once, outside the training loop, instead of being re-created on
    every iteration. Attribute names ``tel`` and ``lin`` are kept, so parameter
    names in saved state dicts do not change.

    Args:
        attn_module: attention layer applied before the projection. When
            omitted, the module-level ``attn`` is used, which is what the
            original zero-argument constructor did.
        in_features: input width of the projection (default 768).
        out_features: output width of the projection (default 50265).
            NOTE(review): presumably the vocabulary size of the checkpoint's
            tokenizer — confirm.
    """

    def __init__(self, attn_module=None, in_features=768, out_features=50265):
        super().__init__()
        # self.tel = nn.TransformerEncoderLayer(d_model=768,nhead=3,batch_first=True)
        self.tel = attn if attn_module is None else attn_module
        self.lin = nn.Linear(in_features=in_features, out_features=out_features, bias=False)

    def forward(self, x):
        """Run the attention layer on ``x``, then project the result with ``lin``."""
        # Dispatch on the class name directly. The original called __str__() on
        # the module, which renders the whole module tree on every forward pass
        # only to look at its first word.
        if type(self.tel).__name__.startswith('ExternalAttention'):
            x = self.tel(x)          # called with a single tensor
        else:
            x = self.tel(x, x, x)    # same tensor passed three times
        return self.lin(x)


# Attention classes that are skipped; they are the ones marked "not solved" in
# the shape test above.
not_allowed = ('AFT_FULL', 'EMSA')

# attn_li[:] is a copy of the full list; narrow the slice to train a subset.
for attn in attn_li[:]:
    print(attn)
    if type(attn).__name__.startswith(not_allowed):
        continue

    # NOTE(review): the same `model` object and the same `args` are reused on
    # every iteration, so each variant continues from whatever the previous run
    # left in the shared weights and presumably writes to the same output
    # directory. Reload the model and give each variant its own output
    # directory if the runs are meant to be compared independently.
    model.lm_head = tel2lin1(attn)

    trainer = Seq2SeqTrainer(
        model,
        args,
        train_dataset=tokenized_datasets["train"],
        eval_dataset=tokenized_datasets["validation"],
        data_collator=data_collator,
        tokenizer=tokenizer,
        compute_metrics=compute_metrics
    )

    trainer.train()