In [1]:
# # %load_ext autoreload
# %autoreload 2
# Setup: load the experiment config, build a random token batch, embed it, and
# split it into blocks for VQAttentionQK.
# NOTE(review): the autoreload magics above are commented out (the first has a
# doubled "#"); re-enable both together if live reloading of transformer_vq is wanted.
import torch
from transformer_vq.nn.attn_vq import VQAttentionQK
from transformer_vq.nn.config_spec import TransformerConfig
from transformer_vq.nn.emb import TransformerEmbedding as Emb
import yaml
from einops import rearrange
# Read hyperparameters from the local conf.yml.
# NOTE(review): yaml.safe_load would suffice if conf.yml only contains plain
# YAML scalars/lists/maps — consider switching from FullLoader.
with open('conf.yml') as f:
    config = yaml.load(f, Loader=yaml.FullLoader)
n_vocab = config['n_vocab']
sequence_length = config['sequence_len']
batch_size = config['global_batch_size']
d_model = config['d_model']
# Inject dtype/device entries into the config before building TransformerConfig.
# NOTE(review): this key is spelled 'd_type' while the next one is 'param_dtype';
# confirm 'd_type' is the field name TransformerConfig actually reads.
config['d_type'] = torch.float32
config['param_dtype'] = torch.float32
config['device'] = 'cuda' if torch.cuda.is_available() else 'cpu'
# Random token ids of shape (4, sequence_length).
# NOTE(review): the batch dimension is hard-coded to 4 instead of batch_size.
data = torch.randint(low=0, high=n_vocab, size=(4, sequence_length))
# NOTE(review): redundant — d_model was already read from config above.
d_model = config['d_model']
emb = Emb(n_vocab, d_model)
model_config = TransformerConfig(**config)
model = VQAttentionQK(model_config)
# `data` is rebound from token ids to their embeddings, presumably
# (4, sequence_length, d_model) — TODO confirm against TransformerEmbedding.
data = emb(data)
block_len = config['block_len']
# Split the sequence axis into groups and move the group axis to the front:
# (b, l*s, d) -> (l, b, s, d).
# NOTE(review): `l=block_len` makes block_len the NUMBER of blocks, each of
# length sequence_length // block_len. If block_len is meant to be the length of
# each block, the keyword should be `s=block_len`. The two coincide only when
# sequence_length == block_len ** 2 — verify the intended meaning.
data = rearrange(data, 'b (l s) d -> l b s d', l=block_len)
# data = rearrange(data, 'b s d -> b (s d)')

In [2]:
# Shape of the blocked embeddings, laid out (l, b, s, d) by the rearrange in the setup cell.
data.shape

torch.Size([32, 4, 32, 1024])

In [3]:
# The config['block_len'] value passed as `l=` to the rearrange in the setup cell.
block_len

32

In [8]:
# Run the attention layer's projections on the blocked embeddings.
# NOTE(review): the method name suggests a return order of (k, q, v, g), but the
# result is unpacked as (q, k, v, g). Confirm the order against
# VQAttentionQK.compute_k_q_v_g before relying on these names.
q, k, v, g = model.compute_k_q_v_g(data)

In [9]:
# Shape of the first tensor returned by compute_k_q_v_g (bound to `q`).
q.shape

torch.Size([32, 4, 32, 384])

In [11]:
# Shape of the second tensor returned by compute_k_q_v_g (bound to `k`).
k.shape

torch.Size([32, 4, 3, 32, 128])

In [12]:
# Shape of the third tensor returned by compute_k_q_v_g (bound to `v`).
v.shape

torch.Size([32, 4, 3, 32, 128])

In [13]:
# Shape of the fourth tensor returned by compute_k_q_v_g (bound to `g`).
g.shape

torch.Size([32, 4, 32, 384])

In [5]:
from transformer_vq.nn.norm import LayerNorm

In [6]:
# NOTE(review): the meaning of the second positional argument (42) is not visible
# here — check transformer_vq.nn.norm.LayerNorm's signature and pass it by keyword.
norm = LayerNorm(d_model, 42)

In [7]:
# Display the LayerNorm module's repr.
norm

LayerNorm()

In [14]:
import einops; import torch

In [15]:
x = torch.randn(3, 1, 4)
# The einops pattern drops the size-1 middle axis (same result as x.squeeze(1)).
x_squeezed = rearrange(x, 'h 1 w -> h w')

# Prints torch.Size([3, 1, 4]) followed by torch.Size([3, 4]).
print(x.shape, x_squeezed.shape, sep='\n')

torch.Size([3, 1, 4])
torch.Size([3, 4])


In [16]:
import copy
import math

import torch
from torch import nn
import torch.nn.functional as F

# Copy a module N times
def clones(module, N):
    """Return an nn.ModuleList holding N independent deep copies of `module`.

    Each entry is a `copy.deepcopy`, so the copies share no parameters with
    each other or with `module` itself.
    """
    copies = []
    for _ in range(N):
        copies.append(copy.deepcopy(module))
    return nn.ModuleList(copies)

# Implement attention (Scaled Dot Product)
def attention(query, key, value, mask=None, dropout=None):
    """Scaled dot-product attention: softmax(Q K^T / sqrt(d_k)) V.

    Attention runs over the last two dimensions; all leading dimensions are
    batch dimensions (torch.matmul broadcasting rules apply).

    Args:
        query: tensor (..., L_q, d_k).
        key: tensor (..., L_k, d_k).
        value: tensor (..., L_k, d_v).
        mask: optional tensor broadcastable to (..., L_q, L_k); positions where
            ``mask == 0`` get (numerically) zero attention weight. A row that is
            fully masked falls back to uniform weights.
        dropout: optional callable (e.g. ``nn.Dropout``) applied to the
            attention probabilities.

    Returns:
        (output, p_attn): output of shape (..., L_q, d_v) and the attention
        probabilities of shape (..., L_q, L_k).
    """
    d_k = query.size(-1)
    scores = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(d_k)
    if mask is not None:
        # Use the most negative finite value of the scores' dtype rather than a
        # hard-coded -1e9: -1e9 is out of float16 range and masked_fill raises.
        scores = scores.masked_fill(mask == 0, torch.finfo(scores.dtype).min)
    p_attn = F.softmax(scores, dim=-1)
    if dropout is not None:
        p_attn = dropout(p_attn)
    attention_result = torch.matmul(p_attn, value)
    return attention_result, p_attn

class MultiHeadedAttention(nn.Module):
    """Multi-head attention as in "Attention Is All You Need" (Figure 2).

    Separate d_model -> d_model linear maps project query, key and value. Each
    projection is split into `h` heads of width d_model // h, and scaled
    dot-product attention runs per head. The heads are concatenated back to
    d_model, and a final d_model -> d_model linear map is applied.

    Args:
        h: number of heads; must divide d_model.
        d_model: model (embedding) width.
        dropout: dropout probability applied to the attention weights.
    """

    def __init__(self, h, d_model, dropout=0.1):
        super().__init__()
        assert d_model % h == 0
        # Per-head width; the value width d_v is taken to equal d_k.
        self.d_k = d_model // h
        self.h = h
        # Four independent d_model -> d_model layers: indices 0, 1, 2 project
        # query, key, value; the last one is the output projection.
        self.linears = clones(nn.Linear(d_model, d_model), 4)
        # Attention probabilities from the most recent forward() call.
        self.attn = None
        self.dropout = nn.Dropout(p=dropout)

    def forward(self, query, key, value, mask=None):
        """Attend from `query` to `key`/`value`; returns (batch, L_q, d_model).

        Inputs have batch as their first dimension and d_model as their last;
        any dimensions in between are flattened into the sequence axis. If
        `mask` is given, it gains a head axis at dim 1, so the same mask is
        applied to every head.
        """
        if mask is not None:
            mask = mask.unsqueeze(1)
        n_batch = query.size(0)

        def to_heads(layer, x):
            # Project, then (batch, seq, h * d_k) -> (batch, h, seq, d_k).
            return layer(x).view(n_batch, -1, self.h, self.d_k).transpose(1, 2)

        q_heads = to_heads(self.linears[0], query)
        k_heads = to_heads(self.linears[1], key)
        v_heads = to_heads(self.linears[2], value)

        context, self.attn = attention(
            q_heads, k_heads, v_heads, mask=mask, dropout=self.dropout
        )

        # Merge heads: (batch, h, seq, d_k) -> (batch, seq, h * d_k).
        merged = context.transpose(1, 2).contiguous().view(n_batch, -1, self.h * self.d_k)
        output_proj = self.linears[-1]
        return output_proj(merged)

In [19]:
import numpy as np
batch_size = 4
sequence_length = 2048
hidden_size = 128
attention_heads = 1
mha = MultiHeadedAttention(h=attention_heads, d_model=hidden_size)
# One query vector per batch element attends over `sequence_length` key/value
# vectors; the result therefore has a single position per batch element.
print(f"With a single attention query over {sequence_length} values:\n")
# torch.ones builds float32 directly (no float64 numpy round trip).
query = torch.ones(batch_size, 1, hidden_size)
value = torch.ones(batch_size, sequence_length, hidden_size)
# Call the module rather than .forward() so nn.Module hooks run.
result = mha(query, value, value)
print("query:", query.size())
print("value:", value.size())
print("result:", result.size())
print("\n")

With as many attention queries as there are values:

query: torch.Size([4, 1, 128])
value: torch.Size([4, 2048, 128])
result: torch.Size([4, 1, 128])




In [20]:
def attention(query, key, value, mask=None, dropout=None):
    """Scaled dot-product attention: softmax(Q K^T / sqrt(d_k)) V.

    NOTE(review): this cell re-defines attention() with the same logic as the
    earlier definition, and it silently shadows that one for all later cells.
    Consider deleting this cell.

    Attention runs over the last two dimensions; all leading dimensions are
    batch dimensions (torch.matmul broadcasting rules apply).

    Args:
        query: tensor (..., L_q, d_k).
        key: tensor (..., L_k, d_k).
        value: tensor (..., L_k, d_v).
        mask: optional tensor broadcastable to (..., L_q, L_k); positions where
            ``mask == 0`` get (numerically) zero attention weight. A row that is
            fully masked falls back to uniform weights.
        dropout: optional callable (e.g. ``nn.Dropout``) applied to the
            attention probabilities.

    Returns:
        (output, p_attn): output of shape (..., L_q, d_v) and the attention
        probabilities of shape (..., L_q, L_k).
    """
    d_k = query.size(-1)
    scores = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(d_k)
    if mask is not None:
        # Use the most negative finite value of the scores' dtype rather than a
        # hard-coded -1e9: -1e9 is out of float16 range and masked_fill raises.
        scores = scores.masked_fill(mask == 0, torch.finfo(scores.dtype).min)
    p_attn = F.softmax(scores, dim=-1)
    if dropout is not None:
        p_attn = dropout(p_attn)
    attention_result = torch.matmul(p_attn, value)
    return attention_result, p_attn

In [None]:
# Inputs laid out as (batch, sequence_length, num_heads, head_dim).
queries = torch.randn(4, 2048, 1, 128)
keys = torch.randn(4, 2048, 1, 128)
values = torch.randn(4, 2048, 1, 128)
# attention() attends over the second-to-last axis, so move the sequence axis
# there: (batch, seq, heads, d) -> (batch, heads, seq, d). Without the transpose,
# the "sequence" being attended over is the length-1 head axis. Each softmax is
# then over a single key, and every output simply equals its value vector.
output, _ = attention(queries.transpose(1, 2), keys.transpose(1, 2), values.transpose(1, 2))
# Back to (batch, seq, heads, d) for display.
output.transpose(1, 2).shape

torch.Size([4, 2048, 1, 128])

In [26]:
pip install flops-profiler

Collecting flops-profiler
  Obtaining dependency information for flops-profiler from https://files.pythonhosted.org/packages/43/26/5732be586af7ab6cf8a518c91e7a5a44a839aa00d014173eee0398d357c7/flops_profiler-0.1.2-py3-none-any.whl.metadata
  Downloading flops_profiler-0.1.2-py3-none-any.whl.metadata (33 kB)
Downloading flops_profiler-0.1.2-py3-none-any.whl (20 kB)
Installing collected packages: flops-profiler
Successfully installed flops-profiler-0.1.2
Note: you may need to restart the kernel to use updated packages.


In [2]:
import torchvision.models as models
import torch

# Profile AlexNet (parameters, MACs, FLOPs, per-module latency) with flops-profiler.
# `from flops_profiler import get_model_profile` raised ImportError with the
# installed 0.1.2 package, so its root does not export the function.
# NOTE(review): the function is expected in the `profiler` submodule — confirm
# against the installed version. The root import is kept as a fallback.
try:
    from flops_profiler.profiler import get_model_profile
except ImportError:
    from flops_profiler import get_model_profile

# `with torch.cuda.device(0):` raises on PyTorch builds without CUDA (e.g. macOS),
# and it never moved the model to the GPU anyway. Choose a device explicitly and
# place the model there.
# NOTE(review): assumes the profiler creates its dummy input on the model's
# device — confirm for this package.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = models.alexnet().to(device)
batch_size = 256
flops, macs, params = get_model_profile(model=model, # model
                                input_shape=(batch_size, 3, 224, 224), # input shape to the model. If specified, the model takes a tensor with this shape as the only positional argument.
                                args=None, # list of positional arguments to the model.
                                kwargs=None, # dictionary of keyword arguments to the model.
                                print_profile=True, # prints the model graph with the measured profile attached to each module
                                detailed=True, # print the detailed profile
                                module_depth=-1, # depth into the nested modules, with -1 being the inner most modules
                                top_modules=1, # the number of top modules to print aggregated profile
                                warm_up=10, # the number of warm-ups before measuring the time of each module
                                as_string=True, # True: human-readable strings (e.g. 1k); False: raw numbers (e.g. 1000)
                                output_file=None, # path to the output file. If None, the profiler prints to stdout.
                                ignore_modules=None, # the list of modules to ignore in the profiling
                                func_name='forward') # the function name to profile, "forward" by default, for huggingface generative models, `generate` is used

ImportError: cannot import name 'get_model_profile' from 'flops_profiler' (/Users/jmuneton/miniconda3/envs/torch/lib/python3.9/site-packages/flops_profiler/__init__.py)