# General testing notebook for qtransform and quantization
## Import stuff

In [2]:
import torch
import numpy as np
from typing import List, Tuple
from torch.utils.data import Dataset, DataLoader
from logging import getLogger
import os
from omegaconf import DictConfig
from brevitas import nn as qnn

  return torch._C._cuda_getDeviceCount() > 0
No CUDA runtime is found, using CUDA_HOME='/usr/local/cuda'


## Experiments with dataclasses and python classes

In [48]:
from abc import abstractclassmethod, ABC
from dataclasses import dataclass, replace

@dataclass
class Metadata():
    encoding: str

@dataclass
class BarMetadata(Metadata):
    other: str = ""

class Foo(ABC):
    def __init__(self, encoding: str):
        self.metadata: Metadata = Metadata(encoding)


    def load_metadata():
        pass

    @abstractclassmethod
    def test(self, file: str):
        file += "   padding"

class Bar(Foo):
    def __init__(self, encoding: str):
        super().__init__()
        self.metadata: BarMetadata

    def test(self, file: str):
        super().test(file)
        print(file)

In [49]:
from dataclasses import dataclass
@dataclass
class Metadata():
    encoding: str

@dataclass
class BarMetadata(Metadata):
    other: str = ""

In [50]:
test = BarMetadata(encoding="gpt2", other="ok")
import dataclasses
dataclasses.replace(test, **{"other": "Bruh"})

BarMetadata(encoding='gpt2', other='Bruh')

In [51]:
test: Metadata = Metadata("gpt2")
test: BarMetadata = BarMetadata(**test, other="other")

TypeError: __main__.BarMetadata() argument after ** must be a mapping, not Metadata

In [None]:
obj = test
params = set(inspect.signature(Metadata.__init__).parameters.keys()) - set(['self'])
{x:getattr(obj, x) for x in params}

{'encoding': 'gpt2'}

In [None]:
from dataclasses import asdict, 
asdict(test)

{'encoding': 'gpt2'}

In [None]:
#test if inner functions can access member attributes
class Foo():
    def __init__(self):
        self.a = 10
    def function(self):
        def other():
            print(self.a)
        other()

Foo().function()

10


In [None]:
# padding does not get appended to the parameter as it is a seperate function
Bar().test("test")

test


## Tests with torch framework to gain familiarity

In [None]:
b,c,e = 4, 5,6
tensor_3d = torch.arange(b*c*e).reshape(b,c,e)
tensor_3d

In [None]:
#batch has 5 rows, only want 3 
index = torch.tile(torch.arange(3).reshape(3,1), (b,1,e))
#you only consider the first batch
torch.gather(tensor_3d, dim=1, index=index)

tensor([[[  0,   1,   2,   3,   4,   5],
         [  6,   7,   8,   9,  10,  11],
         [ 12,  13,  14,  15,  16,  17]],

        [[ 30,  31,  32,  33,  34,  35],
         [ 36,  37,  38,  39,  40,  41],
         [ 42,  43,  44,  45,  46,  47]],

        [[ 60,  61,  62,  63,  64,  65],
         [ 66,  67,  68,  69,  70,  71],
         [ 72,  73,  74,  75,  76,  77]],

        [[ 90,  91,  92,  93,  94,  95],
         [ 96,  97,  98,  99, 100, 101],
         [102, 103, 104, 105, 106, 107]]])

In [None]:
#objective: retrieve first rows of tensor_3d -> if we specify dim=1, we collapse along the rows (we perform indexing for each row)
#b,c,e = 4,5,6
#i always want the first row -> specify by row, dim=1
#how do i reduce the amount of rows if the index tensor has to be of the same dimension?
#dimension has to be the same but not the shape
#torch.zeros(4,1,6) gets the first row of the tensor, but it is problematic if i want multiple rows as i 
#then use the same index (0) while having the output shape that i want
#solution: arange
#index=torch.zeros(4,1,6) -> if we use 5 instead of 6, each row has 5 columns
#meaning: we need a row containing the same index 
tensor_3d.gather(dim=1, index=torch.zeros(4,2,6, dtype=torch.int64))

tensor([[[ 0,  1,  2,  3,  4],
         [ 0,  1,  2,  3,  4]],

        [[30, 31, 32, 33, 34],
         [30, 31, 32, 33, 34]],

        [[60, 61, 62, 63, 64],
         [60, 61, 62, 63, 64]],

        [[90, 91, 92, 93, 94],
         [90, 91, 92, 93, 94]]])

In [None]:
torch.arange(2).reshape(2,1)

tensor([[0],
        [1]])

In [None]:
y = torch.tensor([
     [
       [1, 2, 3],
       [4, 5, 6],
       [0, 0, 0],
       [0, 0, 0]
     ],
     [
       [1, 2, 3],
       [4, 5, 6],
       [0, 0, 0],
       [0, 0, 0]
     ],
     [
       [1, 2, 3],
       [4, 5, 6],
       [0, 0, 0],
       [0, 0, 0]
     ]
   ])
#size is: 3, 4, 3. if you collapse in the first dimension (dim=0), the result tensor becomes of size 4,3. if you collapse it in the second dimension, you get a tensor of size 3,3

In [None]:
y.sum(dim=1)
#in transformers, we usually have tensors of shape b,c,e (batch_size, context, embedding_dimension).
#if we specify dim=0, we perform the operation along the entire batch, in dim=1 along the context and in dim=2 along the embedding dimension.
#if we were to sum the tensors together, sum(dim=1) will yield the sum of the embeddings of each word.
#think of it as squishing a dimension together so that it is of size 1, meaning that we have to squeeze in that dimension.

tensor([[5, 7, 9],
        [5, 7, 9],
        [5, 7, 9]])

In [None]:
#test if torch.tile and tensor.repeat are the same
c = 2 #simulate two words
a = torch.arange(c).reshape((c,1)).repeat((3,1,4))
b = torch.tile(torch.arange(c).reshape((c,1)), (3,1,4))
print(a.equal(b))
print(a)

True
tensor([[[0, 0, 0, 0],
         [1, 1, 1, 1]],

        [[0, 0, 0, 0],
         [1, 1, 1, 1]],

        [[0, 0, 0, 0],
         [1, 1, 1, 1]]])


In [57]:
"experiments with torch.gather"
M = torch.tensor([[1,2,3], [4,7,18], [19,9,23]])
#if there is more than one value inside of the last dimension, continue along current index
#meaning at dim=1:
#[1,1,1] -> 2,7,9
#[0,0,0] -> 1,4,19
#increments along the current dimension
#at new row, reset counter ->
#[1] -> 2
#[1] -> 2
indexes = torch.tensor([1,1,2]).view(-1,1) 

dimension = 0 #2d, meaning dim=0 along rows, dim=1 along columns
out = M.gather(dimension ,indexes) #dim=0: , dim=1: tensor([[ 2],[ 7],[23]])
M.gather(1, torch.Tensor([[1],[1],[2]]).to(dtype=torch.long)) #counter along the current dimension for the dimension of index
#M.gather(1, torch.tensor([[0,0,0],[0,1,0]]))

tensor([[ 2],
        [ 7],
        [23]])

## Test BatchNorm with Padding

In [None]:
from qtransform.model.modules import BatchNorm as BatchNormWithPadding
"test if padding does not lower values"
#first word of each batch -> gather by column
#result tensor: (3, 1, 64)
#retrieving an index from the dimension increases the counter along index of said dimension by one
#e.g. indexing 0 twice will retrieve two different values
FEATURES = 16
EMBEDDINGS = 64
BATCH_SIZE = 3
bn = torch.nn.BatchNorm1d(FEATURES)
#get first word embeddings of three batches
embedding_layer = torch.nn.Embedding(FEATURES, EMBEDDINGS)
batch = embedding_layer(torch.randint(16, (BATCH_SIZE, FEATURES)))
index = torch.arange(1).repeat(BATCH_SIZE,1,EMBEDDINGS).to(dtype=torch.long)
embd_first_word = torch.gather(batch, index=index, dim=1)
padding_bn = BatchNormWithPadding(FEATURES,bias=True)
norm_padding = padding_bn(embd_first_word)
norm = bn(batch)
#check if values are the same
print(f'Values are: {"same" if torch.gather(norm, index=index, dim=1).equal(norm_padding) else "different"}')

Values are: same


## Test huggingface dataset processing

In [None]:
import os
#test if huggingface datasets can be created from text files
import datasets

BASEDIR = '/home/mabot004/.qtransform/datasets/files/shakespeare/untokenized/'
#number of rows depends on the amount of files
files = [os.path.join(BASEDIR, 'shakespeare.txt'), os.path.join(BASEDIR, 'shakespeare_2.txt')]
#does the same as huggingface mapping but now with files
def gen_text():
    for filename in files:
        with open(filename, 'r') as file:
            yield {"text": file.read()}

#chunk size from config, default 100
def chunk_examples(examples):
                #splits the text of each row into chunks of length chunk_length. currently it is only used
                #for character tokenization to avoid feeding large samples to the tokenizer
    chunk_length = 100
                #perform tokenization on a handful of characters at a time
                #from: https://huggingface.co/docs/datasets/process#split-long-examples            
    chunks = []
    
    for sentence in examples["text"]:
        new_chunks = [sentence[i:i + chunk_length] for i in range(0, len(sentence), chunk_length)]
        chunks.extend(new_chunks)
    return {"chunks": chunks}
from tiktoken import get_encoding
tokenizer = get_encoding("gpt2")

  from .autonotebook import tqdm as notebook_tqdm


In [None]:
shakespeare = datasets.Dataset.from_generator(gen_text)
chunks = shakespeare.map(chunk_examples, batched=True, remove_columns = "text")
rotten_tomatoes = datasets.load_dataset('rotten_tomatoes')
rotten_tomatoes["train"].shard(num_shards=1000, index=0, contiguous = True)

Generating train split: 2 examples [00:00, 32.58 examples/s]


In [None]:
# status bar like huggingface dataset map process
from tqdm import tqdm
msg = 'ok'
for i in tqdm(range(100), desc=f'{msg}'):
    msg = str(i)
from tqdm import tqdm
import time
for i, data in tqdm(enumerate(range(10)), desc='test progress bar and other stdout stuff'):
    print(data)
    time.sleep(0.5)

ok: 100%|██████████| 100/100 [00:00<00:00, 842229.72it/s]


In [None]:
#error occurs because the splits have more than one feature and this function changes the amount of samples in each split of one feature without changint the other
#so: 5 samples, 2 features. after mapping: text has 10 samples, other feature still has 5 features
#from: https://github.com/huggingface/datasets/issues/1817#issuecomment-774066254
rt_chunks = datasets.concatenate_datasets(rotten_tomatoes.select_columns("text").map(chunk_examples, batched=True, remove_columns = "text").values())
print(rt_chunks)
#tokenize
rt_chunks = rt_chunks.map(
    #map function expects dictionary or dataset object, tokenize function returns list of tokens (integers)
    lambda batch: {"input_ids": [tokenizer.encode(x) for x in batch["chunks"]]}, 
    batched=True, 
    remove_columns = "chunks",
    #num_proc=os.cpu_count()//2 if cfg.encoding != 'character' else 1 
    desc="tokenizing the dataset from chunks")
rt_chunks.save_to_disk('/home/mabot004/custom_hf_datasets/')
"test if tokenizing is correct"
tokenizer.decode(rt_chunks["train"]["input_ids"][0])

In [None]:

#https://huggingface.co/docs/datasets/create_dataset#from-local-files
shakespeare = datasets.Dataset.from_generator(gen_text)
shakespeare = shakespeare.map(chunk_examples, batched=True, remove_columns = "text")
shakespeare = shakespeare.map(
    #map function expects dictionary or dataset object, tokenize function returns list of tokens (integers)
    lambda batch: {"input_ids": [tokenizer.encode(x) for x in batch["chunks"]]}, 
    batched=True, 
    remove_columns = "chunks",
    #num_proc=os.cpu_count()//2 if cfg.encoding != 'character' else 1 
    desc="tokenizing the dataset from chunks")

In [None]:
tokenizer.decode(np.concatenate(shakespeare[:3]["input_ids"]))


"First Citizen:\nBefore we proceed any further, hear me speak.\n\nAll:\nSpeak, speak.\n\nFirst Citizen:\nYou are all resolved rather to die than to famish?\n\nAll:\nResolved. resolved.\n\nFirst Citizen:\nFirst, you know Caius Marcius is chief enemy to the people.\n\nAll:\nWe know't, we know't.\n\nFirst Citizen:\nLet us"

In [None]:
def write_memmap(memmap, start, end, data):
    memmap[start:end] = data

### test generating huggingface datasets from files

In [None]:
def gen_text():
    for i in range(163):
        yield {"text": i}

test_threading = datasets.Dataset.from_generator(gen_text)

Generating train split: 163 examples [00:00, 31912.97 examples/s]


In [None]:
test_threading.rename_column("text", "chunks")

Dataset({
    features: ['chunks'],
    num_rows: 163
})

In [None]:
test_threading.shard(num_shards=30, index=17)

Dataset({
    features: ['text'],
    num_rows: 5
})

In [None]:
import threading
num_threads = 3 #os.cpu_count // 2
batch_size = 30
num_samples = len(test_threading)
# 163 // 30 shards
# -> 3 threads, each having a batch size of 30 samples
# dataset has 163 samples -> each thread should have around 50-60 samples max
# -> divide samples of dataset with num_threads
# -> each thread should have the entire dataset as an arg, but split differently
# range of splitting should be specified as an arg in thread -> index arg in parameter

In [None]:
#why should you use multithreading? the writing process is I/O based
#if anything, the amount of write requests increases with the amount of threads
memmap = np.memmap('test', mode='w+', shape=(163,), dtype=np.int64)

In [None]:
#playing around with error messages
try:
    int("abcd")
except Exception as e:
    print(str(e))

invalid literal for int() with base 10: 'abcd'


In [None]:
"""
test memory usage in worst case scenarios
"""

#no high memory usage as memmap values are lazily loaded, only overhead is the pages (around 5MB per memmap )
memmap = np.memmap('/home/mabot004/.qtransform/datasets/huggingface/openwebtext/tokenized/gpt2/openwebtext-float32.bin', dtype=np.float32, mode='r')
memmap2 = np.memmap('/home/mabot004/.qtransform/datasets/huggingface/openwebtext/tokenized/gpt2/openwebtext-float32.bin', dtype=np.float32, mode='r')
memmap3 = np.memmap('/home/mabot004/.qtransform/datasets/huggingface/openwebtext/tokenized/gpt2/openwebtext-float32.bin', dtype=np.float32, mode='r')
memmap4 = np.memmap('/home/mabot004/.qtransform/datasets/huggingface/openwebtext/tokenized/gpt2/openwebtext-float32.bin', dtype=np.float32, mode='r')
import psutil
# Process.memory_info is expressed in bytes, so convert to megabytes
print(f"RAM used: {psutil.Process().memory_info().rss / (1024 * 1024):.2f} MB")

In [None]:
from qtransform.dataset import MemmapDataset
#token_file: str, dtype: np.dtype, block_size: int, start: float=0.0, end: float = 1.0
memmap_ds = MemmapDataset(
    token_file='/home/mabot004/.qtransform/datasets/huggingface/openwebtext/tokenized/gpt2/openwebtext-float32.bin',
    dtype=np.float32,
    block_size=64,
    start=0.0,
    end=0.3
)
len(memmap_ds)

2709600997

### test torch Dataloader

In [None]:
dataloader = DataLoader(memmap_ds, batch_size=12, num_workers=8)
next(iter(dataloader))

In [None]:
for i, data in enumerate(dataloader):
    if i == 10:
        break
    input, labels = data
    print(f'{input.size()}, {labels.size()}')

torch.Size([12, 64]), torch.Size([12, 64])
torch.Size([12, 64]), torch.Size([12, 64])
torch.Size([12, 64]), torch.Size([12, 64])
torch.Size([12, 64]), torch.Size([12, 64])
torch.Size([12, 64]), torch.Size([12, 64])
torch.Size([12, 64]), torch.Size([12, 64])
torch.Size([12, 64]), torch.Size([12, 64])
torch.Size([12, 64]), torch.Size([12, 64])
torch.Size([12, 64]), torch.Size([12, 64])
torch.Size([12, 64]), torch.Size([12, 64])


## Testing quantization

In [5]:
#testing batchnorm quant
#https://github.com/Xilinx/brevitas/issues/542
#https://github.com/Xilinx/brevitas/issues/363
#test merge_bn from qtransform
from qtransform.model.modules import CausalSelfAttention
from qtransform.model.modules import BatchNorm as BatchNormWithPadding, MLP
from qtransform.model.gpt import GPTConfig
import brevitas.nn as qnn
from brevitas.nn import utils as qutils
import torch
import torch.nn as nn
from brevitas.quant import scaled_int
#simulate values from embedding, skip positional encoding
wte = torch.nn.Embedding(16,64)
tokens = torch.randint(16, (3,16))
embeddings = wte(tokens)
embeddings.size()

torch.Size([3, 16, 64])

In [None]:
#test if quantized layers having return_quant_tensor set to True are compatible with torch operations 
quant_tensor_linear = qnn.QuantLinear(1,1,True,return_quant_tensor=True)
quant_tensor_linear(torch.Tensor(8,1)) #works

In [None]:
#debug loading quantized checkpoint
CHECKPOINT = '/home/mabot004/eki-transformer-dev/qtransform/outputs/models/GPT_2024-01-17_08:30:49__epoch:1'
#doesnt work since qtransform.dataset cannot be found
#but module info about tokenizers is not saved in checkpoint, only their names
checkpoint = torch.load(CHECKPOINT)
checkpoint.keys()

dict_keys(['model_state_dict', 'optimizer_state_dict', 'epoch', 'model_cfg', 'tokenizer_cfg', 'metrics'])

In [None]:
#check if info about quant params are even saved within checkpoint
import re
keys = checkpoint["model_state_dict"].keys()
#quant param that exists within checkpoint: 
#transformer.layer.0.mlp.active.act_quant.fused_activation_quant_proxy.tensor_quant.scaling_impl.value 
weights_and_biases = list(filter(lambda x: re.search(r'.+\.(weight|bias)$', x), keys))
def find(x):
    if not re.search(r'.+\.(weight|bias)$', x):
        return x
other_keys = list(filter(find, keys))
len(keys) == len(weights_and_biases) # not only weights and biases in state dict
#only scaling_impl is saved in state dict
#no multiheadattention though
#in gpt quant config, every single layer has a quantizer (most commonly Int8WeightPerTensorFloat)
#that quantizer has ScalingImplType STATS
#the layers with scaling_impl had an activation quantizer named Int8ActPerTensorFloat
#it had the ScalingImplType PARAMETER_FROM_STATS
other_keys

['transformer.layer.0.attn.attn_mask',
 'transformer.layer.0.mlp.active.input_quant.fused_activation_quant_proxy.tensor_quant.scaling_impl.value',
 'transformer.layer.0.mlp.active.act_quant.fused_activation_quant_proxy.tensor_quant.scaling_impl.value',
 'transformer.layer.1.attn.attn_mask',
 'transformer.layer.1.mlp.active.input_quant.fused_activation_quant_proxy.tensor_quant.scaling_impl.value',
 'transformer.layer.1.mlp.active.act_quant.fused_activation_quant_proxy.tensor_quant.scaling_impl.value']

In [None]:
#check if qparam is not one 
checkpoint["model_state_dict"]["transformer.layer.0.mlp.active.input_quant.fused_activation_quant_proxy.tensor_quant.scaling_impl.value"]

tensor(2.6414)

In [53]:
#test if scaling_impl params exist within model
test_mha = qnn.QuantMultiheadAttention(num_heads=2, embed_dim=256)
#simulate some learning steps for param
print(test_mha.v_quant.act_quant.fused_activation_quant_proxy.tensor_quant.scaling_impl.value)
test_mha.v_quant.act_quant.fused_activation_quant_proxy.tensor_quant.scaling_impl.value = torch.nn.Parameter(torch.tensor(3.1415))
test_mha.v_quant.act_quant.fused_activation_quant_proxy.tensor_quant.scaling_impl.value

Parameter containing:
tensor(1., requires_grad=True)


Parameter containing:
tensor(3.1415, requires_grad=True)

In [None]:
torch.save(test_mha.state_dict(), 'mha.chpt')
#v_quant etc. not appearing within state_dict
test_mha.state_dict().keys()

In [None]:
#test if brevitas layers relevant for Transformers return qparams in state_dict
print(qnn.QuantLinear(1,1,True,input_quant=scaled_int.Int8ActPerTensorFloat).state_dict())
print(qnn.QuantIdentity(act_quant=scaled_int.Int8ActPerTensorFloat).state_dict())
print(qnn.QuantReLU(act_quant=scaled_int.Int8ActPerTensorFloat).state_dict())

OrderedDict([('weight', tensor([[0.9874]])), ('bias', tensor([-0.8623]))])

In [None]:
re.search(r'(?!hallo|welt).*$', "hallo")

<re.Match object; span=(1, 5), match='allo'>

In [None]:
#check if storing checkpoints of quantized models even is working
class Model(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.network = torch.nn.ModuleDict(dict(
            wte = qnn.QuantEmbedding(32, 128),
            pos = qnn.QuantEmbedding(16, 128),
            logic = nn.ModuleDict(dict(
                layer1 = qnn.QuantLinear(128, 16, True),
                layer2 = qnn.QuantLinear(16,1, True))
            )
        ))
    def forward(self, x):
        embd = self.network.wte(x)
        b,t = x.size()
        pos = torch.arange(0, t, dtype=torch.long).unsqueeze(0) # shape (1, t)
        pos = self.network.pos(pos)
        output = embd + pos
        for name, layer in self.network.logic.items():
            output = layer(output)
        return output

In [None]:
model = Model()
model(torch.randint(32, (1,16)))

  return super().rename(names)


tensor([[[-0.3944],
         [ 0.2287],
         [-0.5937],
         [-0.8445],
         [ 0.4049],
         [-0.1961],
         [ 0.2558],
         [ 0.5325],
         [-0.2270],
         [ 0.0485],
         [-0.5637],
         [ 0.1862],
         [ 0.7595],
         [-0.2511],
         [ 0.1841],
         [-0.3207]]], grad_fn=<ViewBackward0>)

In [None]:
#doesnt work, Quantizer cannot be found in brevitas.inject
#why are they being searched for in inject if they are in brevitas.quant.scaled_int
torch.save(model, 'quantized_test') 

PicklingError: Can't pickle <class 'brevitas.inject.Int8WeightPerTensorFloat'>: attribute lookup Int8WeightPerTensorFloat on brevitas.inject failed

In [56]:
from qtransform import DeviceSingleton
#check if value from class is set in object
DeviceSingleton.device = 'cuda'
singleton = DeviceSingleton()
singleton.device

'cuda'

### Testing Batchnorm and Conv merging

In [None]:
#from: 
def fuse_conv_and_bn(conv, bn):
	#
	# init
	fusedconv = torch.nn.Conv1d(
		conv.in_channels,
		conv.out_channels,
		kernel_size=conv.kernel_size,
		stride=conv.stride,
		padding=conv.padding,
		bias=True
	)
	#
	# prepare filters
	w_conv = conv.weight.clone().view(conv.out_channels, -1)
	w_bn = torch.diag(bn.weight.div(torch.sqrt(bn.eps+bn.running_var)))
	fusedconv.weight.copy_( torch.mm(w_bn, w_conv).view(fusedconv.weight.size()) )
	#
	# prepare spatial bias
	if conv.bias is not None:
		b_conv = conv.bias
	else:
		b_conv = torch.zeros( conv.weight.size(0) )
	b_bn = bn.bias - bn.weight.mul(bn.running_mean).div(torch.sqrt(bn.running_var + bn.eps))
	fusedconv.bias.copy_( torch.matmul(w_bn, b_conv) + b_bn )
	#
	# we're done
	return fusedconv


torch.set_grad_enabled(False)
batch_size = (16, 64, 256)
x = torch.randn(16, 64, 256)

net = torch.nn.Sequential(
    torch.nn.Conv1d(64, 64, kernel_size=(256,256)),
    torch.nn.BatchNorm1d(64)
)
y1 = net.forward(x)
fusedconv = fuse_conv_and_bn(net[0], net[1])
y2 = fusedconv.forward(x)
d = (y1 - y2).norm().div(y1.norm()).item()
print("error: %.8f" % d)

error: 0.49767026


In [None]:
cv1 = qnn.QuantLinear(5,5,bias=True)
cv1_copy = qnn.QuantLinear(5,5,bias=True)
cv1_copy.load_state_dict(cv1.state_dict())
bn1 = torch.nn.BatchNorm1d(5)
qnn.utils.merge_bn(cv1, bn1)
input = torch.Tensor(2,5)

In [None]:
cv1 is cv1_copy

False

In [None]:
cv1(input)

tensor([[ 3.0883e-01, -9.2026e-03,  3.9793e-01,  3.7391e-01,  4.2723e-01],
        [-3.9785e+20,  2.8513e+20, -5.6362e+19,  2.4866e+20,  3.8127e+20]],
       grad_fn=<AddmmBackward0>)

In [None]:
#output is the same without batchnorm, why?
output = cv1_copy(input)
output

tensor([[ 3.0883e-01, -9.2026e-03,  3.9794e-01,  3.7391e-01,  4.2723e-01],
        [-3.9785e+20,  2.8513e+20, -5.6362e+19,  2.4866e+20,  3.8127e+20]],
       grad_fn=<AddmmBackward0>)

In [None]:
output = bn1(input)
cv1_copy(output)

tensor([[ 0.3088, -0.0092,  0.3979,  0.3739,  0.4272],
        [ 0.3088, -0.0092,  0.3979,  0.3739,  0.4272]],
       grad_fn=<AddmmBackward0>)

In [None]:
tensor = torch.randint(30, (3,5,20)).to(dtype=torch.float32) / 10

In [None]:
torch.torch.nn.BatchNorm1d(5)(qnn.QuantConv1d(5,5,kernel_size=3)(tensor))

tensor([[[-1.5470e+00,  1.5270e-01, -2.5464e-01,  2.9082e-01,  1.2001e+00,
           6.7265e-01,  7.3291e-01, -1.8250e-01, -1.2105e+00, -1.0728e+00,
           3.1590e-01,  2.5334e+00, -1.3411e-01,  2.5901e+00,  6.9120e-02,
          -5.1725e-01,  1.9493e-01,  2.2703e+00],
         [-3.4492e-01, -9.6099e-01, -9.2788e-01, -5.6099e-01, -2.0823e+00,
           8.8297e-01,  4.6034e-01,  9.3609e-01,  1.8312e+00, -8.3214e-01,
          -1.0253e+00, -1.3361e+00, -1.3721e+00,  4.9575e-01, -6.1378e-01,
           3.7313e-01, -1.6607e+00, -7.7247e-01],
         [ 1.8919e+00,  8.4047e-01,  5.7258e-01,  3.1601e-01,  1.4367e-01,
          -4.8590e-01,  7.8809e-01,  5.6231e-01, -9.9302e-01, -5.6623e-01,
          -2.1978e-01, -8.2209e-01, -2.8324e-02,  9.5371e-01,  9.6952e-02,
          -6.7169e-01, -8.7423e-02, -4.8495e-02],
         [-2.3595e+00, -4.5535e-01,  1.1663e+00,  1.6639e+00,  5.8315e-01,
          -4.0086e-01, -4.8103e-01, -2.0247e+00,  4.6665e-01,  2.1141e-01,
           2.7884e-02,  1

In [None]:
cv1(tensor)

tensor([[[ 1.5805,  1.6874,  1.4232,  0.4953,  0.3392,  1.6921,  0.4289,
           0.6056,  1.4143, -0.3349,  0.7054,  1.2688,  0.7755,  0.8200,
           0.8781,  1.1979,  0.5057,  0.3852],
         [-0.6557,  0.1111,  0.1168,  0.1569, -0.1044, -0.7017, -0.4960,
          -0.0382, -0.2186,  0.1632,  0.3991, -0.2319, -0.1412, -1.0671,
          -0.2227, -0.4349,  0.0598, -0.0862],
         [-1.8093, -1.3581, -1.4218, -1.9729, -2.4638, -1.5891, -2.7596,
          -1.0707, -0.9489, -1.2964, -0.4576, -2.2601, -2.3817, -2.9702,
          -2.1696, -1.4183, -1.6867, -1.5476],
         [ 0.0087,  0.9669,  0.2861, -0.0195,  1.7805, -0.1940,  1.5801,
           0.4090,  0.1231, -0.3474,  0.4677,  0.5724,  0.8208,  1.0850,
           1.3689,  0.1611,  0.5700, -0.0545],
         [-1.7327, -0.0812, -0.3896, -1.2272, -0.7673, -1.1279, -1.3454,
          -0.1006,  0.1854, -0.1570, -0.0988, -0.2189, -0.5590, -1.7284,
          -0.0418, -0.0416,  0.0663, -0.0783]],

        [[ 1.6639,  0.7003,  1.20

## Test if QuantMHA is learnable

### Debug QuantMultiheadAttention and merge_bn

In [4]:
def merge_bn_mha(layer, bn, output_channel_dim=0):
    
    #retrieve learnable parameters from batchnorm (scale + bias)
    mul_factor, add_factor = qutils.mul_add_from_bn(
        bn_mean=bn.running_mean,
        bn_var=bn.running_var,
        bn_eps=bn.eps,
        bn_weight=bn.weight.data.clone(),
        bn_bias=bn.bias.data.clone())
    #out_proj is QuantLinear(in_features=embd_dim, out_features=embd_dim)
    out_ch_weight_shape = qutils.compute_channel_view_shape(layer.weight, output_channel_dim)
    #apply batchnorm during after forward pass of layer, before returning result
    
    #!!
    layer.weight.data.mul_(mul_factor.view(out_ch_weight_shape))
    #!!
    
    if layer.out_proj.bias is not None:
        out_ch_bias_shape = qutils.compute_channel_view_shape(layer.out_proj.bias, channel_dim=0)
        layer.out_proj.bias.data.mul_(mul_factor.view(out_ch_bias_shape))
        layer.out_proj.bias.data.add_(add_factor.view(out_ch_bias_shape))
    else:
        layer.out_proj.bias = nn.Parameter(add_factor)
    if (hasattr(layer, 'out_proj_weight_quant') and
            isinstance(layer.out_proj_weight_quant, WeightQuantProxyFromInjector)):
        layer.out_proj_weight_quant.init_tensor_quant()
    if (hasattr(layer, 'out_proj_bias_quant') and isinstance(layer.out_proj_bias_quant, BiasQuantProxyFromInjector)):
        layer.out_proj_bias_quant.init_tensor_quant()

In [43]:
test_linear = torch.nn.Linear(embed_dim, embed_dim, bias=False)
test_linear.weight = torch.nn.parameter.Parameter(torch.ones((embed_dim, embed_dim)))

In [45]:
test_linear(embeddings)

tensor([[[ -0.8011,  -0.8011,  -0.8011,  ...,  -0.8011,  -0.8011,  -0.8011],
         [ 10.6637,  10.6637,  10.6637,  ...,  10.6637,  10.6637,  10.6637],
         [  4.0272,   4.0272,   4.0272,  ...,   4.0272,   4.0272,   4.0272],
         ...,
         [  4.0272,   4.0272,   4.0272,  ...,   4.0272,   4.0272,   4.0272],
         [ -8.1817,  -8.1817,  -8.1817,  ...,  -8.1817,  -8.1817,  -8.1817],
         [  1.8297,   1.8297,   1.8297,  ...,   1.8297,   1.8297,   1.8297]],

        [[-11.1670, -11.1670, -11.1670,  ..., -11.1670, -11.1670, -11.1670],
         [ -3.0146,  -3.0146,  -3.0146,  ...,  -3.0146,  -3.0146,  -3.0146],
         [  2.3900,   2.3900,   2.3900,  ...,   2.3900,   2.3900,   2.3900],
         ...,
         [  9.5870,   9.5870,   9.5870,  ...,   9.5870,   9.5870,   9.5870],
         [ -6.9218,  -6.9218,  -6.9218,  ...,  -6.9218,  -6.9218,  -6.9218],
         [ -5.1709,  -5.1709,  -5.1709,  ...,  -5.1709,  -5.1709,  -5.1709]],

        [[ 10.6637,  10.6637,  10.6637,  ...

In [34]:
heads = 2
embed_dim = 64
context = 16
quant_mha = qnn.QuantMultiheadAttention(num_heads=heads, embed_dim=embed_dim)
#pass the same input tensor to merged and unmerged mha + batchnorm and compare results
quant_mha_merged = qnn.QuantMultiheadAttention(num_heads=heads, embed_dim=embed_dim)
from brevitas import config
#qparams not imported from state dict
config.IGNORE_MISSING_KEYS = True
quant_mha_merged.load_state_dict(quant_mha.state_dict())
#feature length not critical
bn = torch.nn.BatchNorm1d(context)
#test if quantmha works
assert quant_mha(embeddings, embeddings, embeddings)[0].size() == embeddings.size()

torch.Size([3, 16, 64])


In [14]:
quant_mha.out_proj

QuantLinear(
  in_features=64, out_features=64, bias=True
  (input_quant): ActQuantProxyFromInjector(
    (_zero_hw_sentinel): StatelessBuffer()
    (fused_activation_quant_proxy): FusedActivationQuantProxy(
      (activation_impl): Identity()
      (tensor_quant): RescalingIntQuant(
        (int_quant): IntQuant(
          (float_to_int_impl): RoundSte()
          (tensor_clamp_impl): TensorClamp()
          (delay_wrapper): DelayWrapper(
            (delay_impl): _NoDelay()
          )
        )
        (scaling_impl): ParameterFromRuntimeStatsScaling(
          (stats_input_view_shape_impl): OverTensorView()
          (stats): _Stats(
            (stats_impl): AbsPercentile()
          )
          (restrict_scaling): _RestrictValue(
            (restrict_value_impl): FloatRestrictValue()
          )
          (clamp_scaling): _ClampValue(
            (clamp_min_ste): ScalarClampMinSte()
          )
          (restrict_inplace_preprocess): Identity()
          (restrict_preprocess): 

In [4]:
#added a print statement after out_proj to see if the shape is changed afterwards
quant_mha(embeddings,embeddings,embeddings)

torch.Size([3, 16, 64])


(tensor([[[-0.0926,  0.1662,  0.0606,  ...,  0.0211,  0.0027,  0.6270],
          [-0.4505,  0.2013,  0.0885,  ..., -0.1597, -0.3558,  0.4523],
          [-0.3425, -0.0519,  0.0775,  ..., -0.1996, -0.2760, -0.2787],
          ...,
          [-0.0925,  0.2803, -0.1011,  ...,  0.2038, -0.0426,  0.4598],
          [ 0.1152,  0.0099, -0.0458,  ...,  0.0279, -0.2678, -0.3036],
          [ 0.0462,  0.4055, -0.4288,  ...,  1.0150,  0.1507, -0.0658]],
 
         [[-0.1684,  0.2740, -0.0567,  ...,  0.1945, -0.1586,  0.4784],
          [-0.2414,  0.2263,  0.0706,  ..., -0.2693, -0.4626,  0.4440],
          [-0.3649,  0.0976, -0.0294,  ..., -0.0971, -0.1957, -0.1036],
          ...,
          [-0.0461,  0.1246, -0.1975,  ...,  0.0336,  0.2553,  0.3853],
          [ 0.0460,  0.0541, -0.0979,  ..., -0.0303, -0.3825, -0.3351],
          [-0.0157,  0.2822, -0.4208,  ...,  0.5638,  0.3149,  0.0858]],
 
         [[-0.0181,  0.1725, -0.0047,  ...,  0.1487,  0.1976,  0.5633],
          [-0.1927,  0.1515,

In [6]:
#test if qparams are the same without copying them. they should have value 1 
MISSING_QPARAMS =  ["in_proj.input_quant.fused_activation_quant_proxy.tensor_quant.scaling_impl.value", "out_proj.input_quant.fused_activation_quant_proxy.tensor_quant.scaling_impl.value", "attn_output_weights_quant.act_quant.fused_activation_quant_proxy.tensor_quant.scaling_impl.value", "q_scaled_quant.act_quant.fused_activation_quant_proxy.tensor_quant.scaling_impl.value", "k_transposed_quant.act_quant.fused_activation_quant_proxy.tensor_quant.scaling_impl.value", "v_quant.act_quant.fused_activation_quant_proxy.tensor_quant.scaling_impl.value"]
for missing_qparam in MISSING_QPARAMS:
    #omit ".value" to get submodule 
    submodule = missing_qparam[:-len(".value")]
    assert quant_mha.get_submodule(submodule).value == quant_mha_merged.get_submodule(submodule).value, f'qparam {missing_qparam} is different' 

In [110]:
import brevitas
test = brevitas.core.quant.IntQuant(True, True)
tensor = torch.Tensor(2,3)
print(tensor)
print(test(tensor))

tensor([[1.0486e+05, 4.5555e-41, 1.2967e-05],
        [3.0866e-41, 4.4842e-44, 0.0000e+00]])


TypeError: IntQuant.forward() missing 3 required positional arguments: 'zero_point', 'bit_width', and 'x'

## Test if creating a custom activation can simulate batchnorm
#### Alternative: Implement custom Quantizer which performs normalization based on a scale and add factor which can be passed in its constructor

In [None]:
from brevitas.quant_tensor import QuantTensor
from torch import Tensor
def QuantIdentityWithWeights(qnn.QuantIdentity):
    def __init__(self,
            act_quant: Optional[ActQuantType] = Int8ActPerTensorFloat,
            return_quant_tensor: bool = False,
            **kwargs):
        super().__init__(self,
            act_quant = act_quant
            return_quant_tensor = return_quant_tensor,
            **kwargs):
        
        
    def forward(self, input: Union[Tensor, QuantTensor]):
        return super().forward(input)

In [20]:
from typing import Optional, Union
from brevitas.quant_tensor import QuantTensor
from brevitas.nn.quant_layer import ActQuantType
from torch import Tensor
from brevitas.quant.scaled_int import *
from brevitas.nn.quant_layer import QuantNonLinearActLayer as QuantNLAL

class QuantCustom(QuantNLAL):

    def __init__(
            self,
            mul_factor: Tensor,
            add_factor: Tensor,
            act_quant: Optional[ActQuantType] = Uint8ActPerTensorFloat,
            input_quant: Optional[ActQuantType] = None,
            return_quant_tensor: bool = False,
            **kwargs):
        QuantNLAL.__init__(
            self,
            act_impl=None,
            passthrough_act=True,
            input_quant=input_quant,
            act_quant=act_quant,
            return_quant_tensor=return_quant_tensor,
            **kwargs)
        self.mul_factor = mul_factor
        self.add_factor = add_factor
    

    def forward(self, input: Union[Tensor, QuantTensor]):
        input = self.unpack_input(input)
        quant_input = self.input_quant(input)
        # shortcut execution through the export impl during export
        if self.export_mode:
            out = self.export_handler(quant_input.value)
            self._set_global_is_quant_layer(False)
            return out
        out = self.act_quant(quant_input)
        
        out = self.pack_output(out)
        return out

In [6]:
from brevitas.nn.quant_layer import QuantWeightBiasInputOutputLayer as QuantWBIOL
from torch.nn import Module as TorchModule
from brevitas.nn.mixin import * #WeightQuantType, BiasQuantType
from brevitas.quant.scaled_int import Int8WeightPerTensorFloat
from typing import Optional, Union
from brevitas.quant_tensor import QuantTensor
from torch import Tensor
import torch
#test if a quantized layer can be implemented which basically scales the values along a tensor and adds a bias, thereby simulating batch normalization
class QuantBatchnorm1d(QuantWBIOL, TorchModule):
    def __init__(
            self,
            num_features: int,
            weight_quant: Optional[WeightQuantType] = Int8WeightPerTensorFloat,
            bias_quant: Optional[BiasQuantType] = None,
            return_quant_tensor: bool = False,
            **kwargs) -> None:
        TorchModule.__init__(self)
        if not isinstance(num_features, int) or num_features <= 0:
            raise AttributeError()
        #do the same as quantidentity
        self.weight = torch.ones(num_features)
        self.bias = torch.zeros(num_features)
        QuantWBIOL.__init__(
            self,
            weight_quant=weight_quant,
            bias_quant=bias_quant,
            input_quant=None,
            output_quant=None,
            return_quant_tensor=return_quant_tensor,
            **kwargs)
    
    def forward(self, input: Union[Tensor, QuantTensor]) -> Union[Tensor, QuantTensor]:
        return self.forward_impl(input)
    
    def inner_forward_impl(self, x: Tensor, quant_weight: Tensor, quant_bias: Optional[Tensor]):
        #inner_forward_impl is apparently the actual forward pass of the layer 
        return x * quant_weight[:,None] + quant_bias[:,None]

#### Test functionality of custom Quantized layer

In [7]:
#mul_factor = torch.randint(15, (context,)).to(dtype=torch.float) / 6.3
#add_factor = torch.randint(15, (context,)).to(dtype=torch.float) / 6.3
batch, context, embeds = 2,4,8
#should do nothing as weights are 1 and biases are 0
custom_quant_layer = QuantBatchnorm1d(num_features=context)
print(f'{custom_quant_layer.weight}, {custom_quant_layer.bias}\n')
input = torch.randint(30, (batch, context, embeds))
output = custom_quant_layer(input)
assert input.equal(output)

tensor([1., 1., 1., 1.]), tensor([0., 0., 0., 0.])



  return super().rename(names)


### compare results of custom Batchnorm with regular batchnorm

In [8]:
batch, sentence_length, embedding_dim = 20, 5, 10
embeddings = torch.randn(batch, sentence_length, embedding_dim)
custom_quant_bn = QuantBatchnorm1d(sentence_length)
regular_bn = torch.nn.BatchNorm1d(sentence_length)
#simulate some forward passes to change the mean and standard deviation
for _ in range(10):
    regular_bn(embeddings)
#test if merge_bn works
mul_factor, add_factor = qutils.mul_add_from_bn(
    bn_mean=regular_bn.running_mean,
    bn_var=regular_bn.running_var,
    bn_eps=regular_bn.eps,
    bn_weight=regular_bn.weight.data.clone(),
    bn_bias=regular_bn.bias.data.clone())
assert custom_quant_bn.weight.size() == mul_factor.size()
assert custom_quant_bn.bias.size() == add_factor.size()
#change weight and bias of quantized batchnorm
custom_quant_bn.weight = mul_factor
custom_quant_bn.bias = add_factor
#test if results are at least somewhat similiar
out_regular_bn = regular_bn(embeddings)
out_quant_bn = custom_quant_bn(embeddings)
#loss is around 0.05 for each normalized embedding
print((out_regular_bn - out_quant_bn) < 0.05)

tensor([[[ True,  True,  True,  True,  True,  True,  True,  True,  True,  True],
         [ True,  True,  True,  True,  True,  True,  True,  True,  True,  True],
         [ True,  True,  True, False, False,  True,  True,  True,  True,  True],
         [ True,  True,  True,  True,  True,  True,  True,  True,  True,  True],
         [ True,  True,  True,  True,  True,  True,  True,  True,  True,  True]],

        [[ True,  True,  True,  True,  True,  True,  True,  True,  True,  True],
         [ True,  True,  True,  True,  True,  True,  True,  True,  True,  True],
         [False,  True, False,  True,  True, False,  True,  True,  True,  True],
         [ True,  True,  True,  True,  True,  True,  True,  True,  True,  True],
         [ True,  True,  True,  True,  True,  True,  True,  True,  True,  True]],

        [[ True,  True,  True,  True,  True,  True,  True,  True,  True,  True],
         [ True,  True,  True,  True,  True,  True,  True,  True,  True,  True],
         [ True,  True, 

In [150]:
regular_bn.num_features

5

In [10]:
from qtransform.quantization.quant_bn import *
qtransform_quant_bn = merge_quant_bn(regular_bn)
out_qtransform_quant_bn = qtransform_quant_bn(embeddings)
print((out_regular_bn - out_qtransform_quant_bn) < 0.05)

tensor([[[ True,  True,  True,  True,  True,  True,  True,  True,  True,  True],
         [ True,  True,  True,  True,  True,  True,  True,  True,  True,  True],
         [ True,  True,  True,  True,  True,  True,  True,  True,  True,  True],
         [ True,  True,  True,  True,  True,  True,  True,  True,  True,  True],
         [ True,  True,  True,  True,  True,  True,  True,  True,  True,  True]],

        [[ True,  True,  True,  True,  True,  True,  True,  True,  True,  True],
         [ True,  True,  True,  True,  True,  True,  True,  True,  True,  True],
         [ True,  True,  True,  True,  True,  True,  True,  True,  True,  True],
         [ True,  True,  True,  True,  True,  True,  True,  True,  True,  True],
         [ True,  True,  True,  True,  True,  True,  True,  True,  True,  True]],

        [[ True,  True,  True,  True,  True,  True,  True,  True,  True,  True],
         [ True,  True,  True,  True,  True,  True,  True,  True,  True,  True],
         [ True,  True, 

In [149]:
from brevitas.nn.utils import merge_bn
custom_merged_quant_bn = QuantBatchnorm1d(sentence_length)
merge_bn(layer=custom_merged_quant_bn, bn = regular_bn)
out_regular_bn = regular_bn(embeddings)
out_quant_bn = custom_quant_bn(embeddings)
#loss is around 0.05 for each normalized embedding
print((out_regular_bn - out_quant_bn) < 0.05)

tensor([[[ True,  True,  True,  True,  True,  True,  True,  True,  True,  True],
         [ True,  True,  True,  True,  True,  True,  True,  True,  True,  True],
         [ True,  True,  True,  True,  True,  True,  True,  True,  True,  True],
         [ True,  True,  True,  True,  True,  True,  True,  True,  True,  True],
         [ True,  True,  True,  True,  True, False,  True,  True, False,  True]],

        [[ True,  True,  True,  True,  True,  True,  True,  True,  True,  True],
         [ True,  True,  True,  True,  True,  True,  True,  True,  True,  True],
         [ True,  True,  True,  True,  True,  True,  True,  True,  True,  True],
         [ True,  True,  True,  True,  True,  True,  True,  True,  True,  True],
         [ True,  True,  True,  True,  True, False,  True,  True,  True,  True]],

        [[ True,  True,  True,  True,  True,  True,  True,  True,  True,  True],
         [ True,  True,  True,  True,  True,  True,  True,  True,  True,  True],
         [ True,  True, 

In [120]:
#copy pasted from brevitas.nn.utils
def compute_channel_view_shape(tensor: torch.Tensor, channel_dim: int):
    #create a list containing ones with length of the tensor dimension (for mha: always length of 3)
    broadcast_shape = [1] * len(tensor.size())
    #why is that important
    broadcast_shape[channel_dim] = -1
    return tuple(broadcast_shape)

def mul_add_from_bn(bn_mean, bn_var, bn_eps, bn_weight, bn_bias):
    denom = torch.sqrt(bn_var + bn_eps)
    mul_factor = bn_weight / denom
    add_factor = -bn_mean * mul_factor + bn_bias
    return mul_factor, add_factor

In [29]:
list(layernorm.parameters())

[Parameter containing:
 tensor([1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
         1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
         1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
         1., 1., 1., 1., 1., 1., 1., 1., 1., 1.], requires_grad=True),
 Parameter containing:
 tensor([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
         0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
         0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
        requires_grad=True)]

In [30]:
tensor = embeddings.rename(None)
for i in range(10):
    layernorm(tensor)
    bn(tensor)

In [24]:
# test if mul_add_from_bn works for layernorm
layernorm = torch.nn.LayerNorm(embed_dim, bias=True)
ln_mul, ln_add = qutils.mul_add_from_bn(
    bn_mean=layernorm.running_mean,
    bn_var=layernorm.running_var,
    bn_eps=layernorm.eps,
    bn_weight=layernorm.weight.data.clone(),
    bn_bias=layernorm.bias.data.clone())

AttributeError: 'LayerNorm' object has no attribute 'running_mean'

In [60]:
#basically a list of ones with a -1 at index channel_dim
compute_channel_view_shape(torch.Tensor(3,3,3,3,3), 4)

(1, 1, 1, 1, -1)

In [55]:
len(quant_mha_merged.out_proj.weight.size())

2

In [59]:
#disassemble functionality of merge_bn_mha
#one dimensional tensor to scale all values of a batch with a corresponding scalar
mul_factor, add_factor = qutils.mul_add_from_bn(
    bn_mean=bn.running_mean,
    bn_var=bn.running_var,
    bn_eps=bn.eps,
    bn_weight=bn.weight.data.clone(),
    bn_bias=bn.bias.data.clone())
assert mul_factor.size() == add_factor.size()
assert mul_factor.size()[0] == add_factor.size()[0] == context
output_channel_dim = 0
#out_proj is QuantLinear -> 2d Tensor, [-1, 1]
#meaning: reverse shape of mul_factor tensor
#currently: [1,context] now: [context, 1]
out_ch_weight_shape = qutils.compute_channel_view_shape(quant_mha_merged.out_proj.weight, output_channel_dim)
#out_proj is a linear layer applying a scaling factor to quantize outputs -> inputs: n_embd, outputs: n_embd
#batchnorm applied normalization along second dimension, linear layer along third
#-> could work if batchnorm normalizes along embeddings
assert mul_factor.view(out_ch_weight_shape).size() != quant_mha_merged.out_proj.weight.data.size()
print(f'shape: {out_ch_weight_shape}, mul_factor view: {mul_factor.view(out_ch_weight_shape).size()}')
print(f'quant_mha_merged.out_proj.weight.data shape: {quant_mha_merged.out_proj.weight.data.size()}')
#quant_mha_merged.out_proj.weight.data.mul_(mul_factor.view(out_ch_weight_shape))
#merge params of bn in quant_mha_merged
#merge_bn_mha(quant_mha_merged, bn)

shape: (-1, 1), mul_factor view: torch.Size([16, 1])
quant_mha_merged.out_proj.weight.data shape: torch.Size([64, 64])


In [62]:
#each value of mul_factor is the scale with which the embedding of one word is multiplied with
#in total: context amount of words
#linear layer calculates the sum of each embedding of a word with a weight
#problem: weights are the same for every word
print(f'mul_factor: {mul_factor.size()}, add_factor: {add_factor.size()}')

mul_factor: torch.Size([16]), add_factor: torch.Size([16])


In [95]:
qutils.merge_bn(qnn.QuantLinear(context, context, True), bn)

In [72]:
from IPython.core.display import HTML
HTML(
    """<body>
    <p>Rows: Context, Columns: Embeddings</p>
    <table>
        <tr>
            <th>Row</th>
            <th>Embedding 1</th>
            <th>Embedding 2</th>
            <th>Embedding 3</th>
            <th>Embedding 4</th>
        </tr>
        <tr>
            <td>1</td>
            <td>-0.6182</td>
            <td>0.6397</td>
            <td>-0.6141</td>
            <td>0.8668</td>
        </tr>
        <tr>
            <td>2</td>
            <td>0.4140</td>
            <td>0.1806</td>
            <td>-1.1200</td>
            <td>-0.3160</td>
        </tr>
    </table>
</body>"""
)

Row,Embedding 1,Embedding 2,Embedding 3,Embedding 4
1,-0.6182,0.6397,-0.6141,0.8668
2,0.414,0.1806,-1.12,-0.316


each row of a prompt needs to be multiplied with the same scalar from mul_factor, the position of the row determines the index of mul_factor\
basically: row 1 * mul_factor[0], row 2 * mul_factor[2] etc.\
remember the bit_width value from quant config

In [78]:
#perform normalization like so:
print(f'{embeddings[0][0]}, \n{mul_factor[0] * embeddings[0][0] + add_factor[0]}')

tensor([-0.6182,  0.6397, -0.6141,  0.1633,  0.1303,  0.9678, -0.0135, -0.3592,
         1.0851,  1.2883, -1.4943, -1.3554, -1.2857,  0.5534, -0.9053, -0.2538,
         2.0112,  1.5106, -0.5143,  0.0181, -1.1853, -0.1291,  1.1889, -0.2304,
        -0.1677, -1.0456, -0.1630,  0.8798,  0.4793,  1.3267,  0.9272,  0.4181,
        -1.6796, -0.2393,  0.7780, -1.7058, -1.1486, -1.5907,  0.9055, -0.6892,
        -0.3182,  1.7268,  1.3576, -0.0698,  0.5315, -0.9513,  0.0850, -0.0770,
         0.6108, -0.9660, -0.2021, -0.8171, -0.4134, -0.9940, -1.1997,  1.1844,
         2.1950, -0.9969, -0.6415,  2.3817, -0.1636,  0.8668, -1.2963,  0.2591],
       grad_fn=<SelectBackward0>, names=('E',)), 
tensor([-0.6476,  0.6723, -0.6434,  0.1723,  0.1378,  1.0165, -0.0132, -0.3759,
         1.1397,  1.3528, -1.5670, -1.4212, -1.3480,  0.5817, -0.9490, -0.2653,
         2.1113,  1.5861, -0.5386,  0.0200, -1.2427, -0.1344,  1.2486, -0.2408,
        -0.1750, -1.0961, -0.1700,  0.9242,  0.5039,  1.3931,  0.9740

In [44]:
q,k,v = [embeddings for _ in range(3)]
output_no_merge = quant_mha(q,k,v)[0]
output_no_merge = bn(output_no_merge)
output_merge = quant_mha_merged(q,k,v)[0]
assert output_no_merge.equal(output_merge)

torch.Size([3, 16, 64])
torch.Size([3, 16, 64])


AssertionError: 

### Test what happens if values are transposed after merging QuantLinear with BatchNorm
quant_linear = qnn.QuantLinear()

### test if transposing inputs changes values for batchnorm with two different feature lengths

In [25]:
bn_context = torch.nn.BatchNorm1d(context)
bn_embedding = torch.nn.BatchNorm1d(embed_dim)
input = embeddings
input = input.rename(None)
out_context = bn_context(input)
out_embedding = bn_embedding(input.transpose(-1,-2))
#after transposing, normalize along embeddings instead of along words
assert out_context.equal(out_embedding), "transposing does not work"

AssertionError: transposing does not work

### test if linear transformation can simulate the functionality of batchnorm

In [33]:
test = torch.randint(30, (3,10)).to(dtype=torch.float) * 0.7
linear = torch.nn.Linear(3, 3)
linear(test.transpose(-1,0)).transpose(-1,0)

tensor([[ 5.9493,  2.7542,  9.5925,  7.8767,  1.7580,  3.1383,  6.2312,  5.3196,
          7.8936,  1.0332],
        [12.1436, 11.1988, 15.5672, 13.1937,  5.2797,  7.7320, 12.6711,  8.1652,
         13.6740,  6.0436],
        [ 4.6590,  9.2941,  0.4616, -0.8165,  1.4461,  7.0240, -2.8188,  0.2568,
          0.8149,  6.8677]], grad_fn=<TransposeBackward0>)

In [65]:
#before attention calculation, q,k,v are quantized
#shape: n_embd * 3 as q,k,v are the same and of shape n_embd
quant_mha_merged.in_proj.weight.size()

torch.Size([192, 64])

In [None]:
#code copied from forward pass of QuantMultiheadAttention
#function is called if in_proj quantization has been set
#TODO: find out what it does
def chunk(x, num=3, dim=-1):
    _len, _bsz, _dim = x.shape
    x = x.reshape(_len, _bsz, num, dim)
    return x[:, :, 0, :], x[:, :, 1, :], x[:, :, 2, :]
assert attn_no_batchfirst.in_proj is not None
from brevitas.nn.utils import check_tensors_same_ptr
#no idea what it does, it has to be True or else an Exception will be thrown
assert check_tensors_same_ptr([embeddings, embeddings, embeddings]) == True
torch._C._get_tracing_state()

query = embeddings
query.rename_('L', 'N', 'E')
#no idea why q,k,v are infered from the query and params key and value are still used
#this is an issue if no in_proj is specified i think
q,k,v = chunk(attn_no_batchfirst.in_proj(query))
print(f'{q.size()}, {k.size()}, {v.size()}')
#issue with wrong shapes could be that batch size is transposed instead of embedding dimension
#q = q.contiguous().view(tgt_len, bsz * num_heads, head_dim).transpose(0, 1)

torch.Size([3, 16, 64]), torch.Size([3, 16, 64]), torch.Size([3, 16, 64])


In [None]:
tensor = torch.arange(9).reshape(3,3)
#columns become rows, rows become columns
print(f'{tensor}, \n{tensor.transpose(1,0)}')

tensor([[0, 1, 2],
        [3, 4, 5],
        [6, 7, 8]]), 
tensor([[0, 3, 6],
        [1, 4, 7],
        [2, 5, 8]])


In [None]:
small_attn_cpy = CausalSelfAttention(GPTConfig(block_size=16, n_embd=64, n_head=2))
#if batchnorm and mha are merged together, padding should not be necessary for inference
small_attn_cpy.mha = qnn.QuantMultiheadAttention(num_heads=2, embed_dim=64)
from brevitas import config
config.IGNORE_MISSING_KEYS = True #copy state dict does not return brevitas qparams
small_attn_cpy.load_state_dict(small_attn.state_dict())
#qparams from state dict are set to 1 at first
print(small_attn.mha.in_proj.input_quant.fused_activation_quant_proxy.tensor_quant.scaling_impl.value)
print(small_attn_cpy.mha.in_proj.input_quant.fused_activation_quant_proxy.tensor_quant.scaling_impl.value)



Parameter containing:
tensor(1., requires_grad=True)
Parameter containing:
tensor(1., requires_grad=True)


In [None]:
"""
idea from: https://github.com/Xilinx/brevitas/issues/542#issuecomment-1446338490
merge_bn does not delete current batchnorm, meaning that one model has to be initialiized without bn and the parameters from the trained model
have to be copied to the model without bn
TODO: find more ressource efficient ways
"""
#at one step in merge_bn_mha, layer.out_proj.weight.data.mul_(mul_factor.view(out_ch_weight_shape)) is performed
#weight is of shape (embd_dim, embd_dim), mul_factor is of (shape features, 1)
#meaning that batchnorm probably normalizes along the embeddings instead of each sentence
"bn_alt feature length is 64 (embedding dimension)"


In [None]:
#merge_bn_mha appends batchnorm to mha, TODO: prepend it (maybe use input_quant_tensor or something)
#problem: merged and unmerged outputs are not the same, possibly since feature length is different
no_merge_attn_output = small_attn(embeddings)
no_merge_bn_output = bn(no_merge_attn_output)
try:
    merge_bn_mha(small_attn.mha, bn, output_channel_dim=0)
except Exception:
    merge_bn_mha(small_attn.mha, bn, output_channel_dim=1)
except Exception:
    merge_bn_mha(small_attn.mha, bn, output_channel_dim=2)
merge_attn_output = small_attn(embeddings)
assert torch.equal(no_merge_bn_output, merge_attn_output) == True

RuntimeError: The size of tensor a (64) must match the size of tensor b (16) at non-singleton dimension 1

In [None]:
help(layer.out_proj.weight.data.mul_)

Help on built-in function mul_:

mul_(...) method of torch.Tensor instance
    mul_(value) -> Tensor
    
    In-place version of :meth:`~Tensor.mul`.



In [None]:
#a.mul_(tensor) basically is a = a * tensor
a = torch.Tensor([1,2,3])
a.mul_(3)
a

tensor([3., 6., 9.])

In [None]:
small_attn(torch.Tensor(3,16,64)).size()

torch.Size([3, 16, 64])

In [None]:
m = torch.nn.Conv1d(16, 33, 3, stride=2)
input = torch.randn(20, 16, 50)
output = m(input)
output.size()

torch.Size([20, 33, 24])

In [None]:
#conv1d and batchnorm1d merge

qnn.quant_layer.merge_bn

<function brevitas.nn.utils.merge_bn(layer, bn, output_channel_dim=0)>

In [None]:
tensor = torch.rand((3,6,9))
tensor

tensor([[[0.9917, 0.4984, 0.6176, 0.5039, 0.8158, 0.8521, 0.0155, 0.1858,
          0.8048],
         [0.1621, 0.4298, 0.3947, 0.5427, 0.8238, 0.9419, 0.7478, 0.4333,
          0.0647],
         [0.0897, 0.2927, 0.9780, 0.6710, 0.0377, 0.8199, 0.1301, 0.8592,
          0.8216],
         [0.2074, 0.6790, 0.2042, 0.7838, 0.5414, 0.5088, 0.8481, 0.2490,
          0.1760],
         [0.0197, 0.6737, 0.1897, 0.2794, 0.4024, 0.3306, 0.8610, 0.8641,
          0.6871],
         [0.7651, 0.4413, 0.9831, 0.4328, 0.2344, 0.0799, 0.4901, 0.1151,
          0.9380]],

        [[0.4503, 0.5180, 0.3012, 0.7354, 0.2637, 0.9073, 0.9226, 0.7925,
          0.0674],
         [0.9067, 0.1654, 0.9186, 0.1072, 0.0438, 0.4049, 0.1374, 0.3990,
          0.6381],
         [0.3767, 0.8549, 0.5588, 0.2489, 0.2599, 0.6461, 0.5800, 0.1559,
          0.0832],
         [0.9381, 0.2192, 0.7259, 0.7615, 0.1411, 0.1472, 0.9268, 0.6733,
          0.9049],
         [0.1468, 0.8668, 0.3151, 0.5401, 0.4347, 0.5541, 0.0995, 0.

In [None]:
#normalized values along second dimension, meaning: along sentences
#are 
torch.nn.BatchNorm1d(6)(tensor)

tensor([[[ 1.5546, -0.3177,  0.1346, -0.2969,  0.8870,  1.0246, -2.1506,
          -1.5043,  0.8454],
         [-0.9784, -0.1143, -0.2279,  0.2499,  1.1573,  1.5384,  0.9118,
          -0.1030, -1.2930],
         [-1.5225, -0.7464,  1.8741,  0.7002, -1.7215,  1.2696, -1.3680,
           1.4198,  1.2762],
         [-1.0571,  0.6570, -1.0687,  1.0382,  0.1571,  0.0385,  1.2718,
          -0.9059, -1.1712],
         [-1.5085,  0.9969, -0.8575, -0.5138, -0.0426, -0.3174,  1.7148,
           1.7266,  1.0484],
         [ 1.0773, -0.1349,  1.8936, -0.1669, -0.9097, -1.4882,  0.0477,
          -1.3565,  1.7249]],

        [[-0.5003, -0.2434, -1.0663,  0.5820, -1.2088,  1.2342,  1.2923,
           0.7987, -1.9539],
         [ 1.4249, -0.9680,  1.4631, -1.1558, -1.3605, -0.1948, -1.0582,
          -0.2139,  0.5578],
         [-0.4253,  1.4034,  0.2712, -0.9140, -0.8717,  0.6049,  0.3522,
          -1.2693, -1.5475],
         [ 1.5987, -1.0141,  0.8276,  0.9569, -1.2980, -1.2758,  1.5576,
       

In [None]:
tensor[0][0].mean()

tensor(0.5873)

In [None]:
#tensor retains size, batchnorm essentially is a linear transformation to shift values to have a mean of 0 and a standard deviation of 1
torch.nn.BatchNorm1d(10)(torch.Tensor(3,10,16)).size()

torch.Size([3, 10, 16])

In [None]:
identity = qnn.QuantIdentity()
tensor = torch.Tensor(12,64,256)

In [None]:
tensor[0][0][0]

tensor(2.8026e-45)

In [None]:
output[0][0][0]

tensor(0.)

In [None]:
#test if quantidentity is a simple wrapper around a tensor that does nothing
#if so, it could be useful for merging with batchnorm
tensor = torch.Tensor(2,3,4)
print(tensor)
print("\n" + 30* "-" + "\n")
print(qnn.QuantIdentity()(tensor).isclose(tensor).all().item())

tensor([[[-1.1617e+35,  3.0907e-41, -1.5597e+37,  3.0907e-41],
         [ 0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00],
         [ 0.0000e+00,  0.0000e+00,  1.4013e-45,  0.0000e+00]],

        [[ 0.0000e+00,  0.0000e+00,  1.1351e-43,  0.0000e+00],
         [-1.5597e+37,  3.0907e-41, -3.0176e+34,  3.0907e-41],
         [ 0.0000e+00,  0.0000e+00,  1.4013e-45,  0.0000e+00]]])

------------------------------

False


In [None]:
output = identity(tensor)
output.size == tensor.size
output == tensor

tensor([[[False,  True, False,  ..., False, False, False],
         [False, False, False,  ..., False, False, False],
         [False, False, False,  ..., False, False, False],
         ...,
         [False, False, False,  ..., False, False, False],
         [False, False, False,  ..., False, False, False],
         [False, False, False,  ..., False, False, False]],

        [[False, False, False,  ..., False, False, False],
         [False, False, False,  ..., False, False, False],
         [False, False, False,  ..., False, False, False],
         ...,
         [False, False, False,  ..., False, False, False],
         [False, False, False,  ..., False, False, False],
         [False, False, False,  ..., False, False, False]],

        [[False, False, False,  ..., False, False, False],
         [False, False, False,  ..., False, False, False],
         [False, False, False,  ..., False, False, False],
         ...,
         [False, False, False,  ..., False, False, False],
         [

## Test custom batchnorm and merging process

In [9]:
from qtransform.quantization.quant_bn import *
import torch
batch, sentence_length, embedding_dim = 20, 5, 10
embedding = torch.randn(batch, sentence_length, embedding_dim)
normal_bn = torch.nn.BatchNorm1d(sentence_length)
unquantized_custom_bn = CustomBatchNorm1d(sentence_length)
#simulate some forward passes, without updating learnable scale and bias
for _ in range(100):
    normal_bn(embedding)
#before merging, do nothing
assert embedding.equal(unquantized_custom_bn(embedding))
unquantized_custom_bn = replace_bn(normal_bn, unquantized_custom_bn)
out_normal_bn = normal_bn(embedding)
out_unquantized_custom_bn = unquantized_custom_bn(embedding)
print(out_normal_bn[0])
print(f'{30*"-"} unquantized: {out_unquantized_custom_bn[0]}')
quantized_custom_bn = QuantBatchnorm1d(sentence_length)
assert embedding.equal(quantized_custom_bn(embedding))
#replace
quantized_custom_bn = replace_bn(normal_bn, quantized_custom_bn)
out_quantized_custom_bn = quantized_custom_bn(embedding)
print(f'{30*"-"} quantized: {out_quantized_custom_bn[0]}')

tensor([[-1.2659,  0.6091,  1.4836,  0.9439,  1.2065, -0.9223, -0.5282, -0.8455,
         -0.6554, -1.1303],
        [-1.3342, -0.6483, -1.0222,  1.5021, -0.2206,  0.6714, -0.7772,  1.1717,
          1.8188, -0.7320],
        [-0.2430, -2.0209,  0.5107, -1.3762,  0.5122, -0.6174, -0.2816, -0.9368,
          0.0110, -0.6685],
        [ 1.0439,  0.5637,  0.0682,  0.5662, -0.7288, -0.8823,  1.3416,  0.7040,
         -0.4979,  1.0843],
        [-0.8410, -0.4107,  1.9964, -0.3908, -1.1501, -1.4927, -1.0785, -0.1600,
         -0.2757, -2.0370]], grad_fn=<SelectBackward0>)
------------------------------ unquantized: tensor([[-1.2628,  0.6076,  1.4799,  0.9416,  1.2035, -0.9200, -0.5269, -0.8434,
         -0.6538, -1.1275],
        [-1.3309, -0.6467, -1.0197,  1.4984, -0.2201,  0.6698, -0.7753,  1.1687,
          1.8142, -0.7302],
        [-0.2424, -2.0158,  0.5095, -1.3728,  0.5109, -0.6158, -0.2809, -0.9344,
          0.0109, -0.6668],
        [ 1.0413,  0.5623,  0.0680,  0.5647, -0.7270, -0

In [18]:
from qtransform.model.gpt import GPT, GPTConfig
gpt = GPT(GPTConfig())



In [21]:
for mn, module in gpt.named_modules():
    print(module)

GPT(
  (transformer): ModuleDict(
    (wte): Embedding(50304, 768)
    (wpe): Embedding(1024, 768)
    (dropout): Dropout(p=0.0, inplace=False)
    (layer): ModuleList(
      (0-11): 12 x TransformerBlock(
        (ln_1): BatchNorm(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (attn): CausalSelfAttention(
          (mha): MultiheadAttention(
            (out_proj): NonDynamicallyQuantizableLinear(in_features=768, out_features=768, bias=True)
          )
          (c_attn): Linear(in_features=768, out_features=2304, bias=True)
          (c_proj): Linear(in_features=768, out_features=768, bias=True)
          (attn_dropout): Dropout(p=0.0, inplace=False)
          (resid_dropout): Dropout(p=0.0, inplace=False)
        )
        (ln_2): BatchNorm(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (mlp): MLP(
          (c_fc): Linear(in_features=768, out_features=3072, bias=True)
          (c_proj): Linear(in_features=3072, out_fea

In [1]:
import torch
from qtransform.quantization import quant_bn
#one word with 32 embeddings
input = torch.randn(1,32)
#context is max. of 8 words
weight, bias = torch.randn(2, 8)
output = quant_bn.custom_bn1d(input, weight, bias)
#input and output are the same, why?
print(input)
print(30*"-")
print(output)

  return torch._C._cuda_getDeviceCount() > 0
No CUDA runtime is found, using CUDA_HOME='/usr/local/cuda'


tensor([[ 0.4834, -0.1027, -2.2402,  0.4352,  0.2077,  1.0181,  0.6446, -0.6509,
          0.9757,  0.8973, -0.7394,  0.0084,  1.1972, -1.3001,  0.8528,  0.6930,
          1.2027, -0.3247,  0.2913, -0.3543, -1.0235,  1.0952,  0.9958, -0.4298,
          0.6690, -0.7751,  1.2495,  1.1023, -0.6781,  0.4847, -1.4741, -0.9532]])
------------------------------
tensor([[-0.4549, -0.9629, -2.8159, -0.4967, -0.6939,  0.0086, -0.3151, -1.4382,
         -0.0281, -0.0961, -1.5149, -0.8666,  0.1639, -2.0010, -0.1346, -0.2732,
          0.1687, -1.1554, -0.6214, -1.1810, -1.7612,  0.0755, -0.0107, -1.2465,
         -0.2940, -1.5459,  0.2093,  0.0817, -1.4618, -0.4537, -2.1518, -1.7002]])


### Debug custom_bn1d

In [27]:
def check_shapes(tensor: torch.Tensor) -> torch.Tensor:
    """
    Checks if a tensor is of shape [C], [N,C] or [C,N] with N = 1 and C >= 1.
    If tensor is of a different shape, a ValueError will be thrown.
    The returning tensor will be of shape [C, 1].
    """
    shape_tensor = tensor.size()
    if len(shape_tensor) == 1:
        tensor = tensor[:,None]
    if len(shape_tensor) == 2:
        if shape_tensor[0] > 1 and shape_tensor[1] > 1:
            raise ValueError(f'Too many values to unpack for tensor {shape_tensor}.')
        elif shape_tensor[0] == 1 and shape_tensor[1] > 1:
            tensor = tensor.transpose(0,1)
    elif len(shape_tensor) > 2:
        raise ValueError(f'Too many values to unpack for tensor {shape_tensor}.')
    return tensor


def custom_bn1d(x: torch.Tensor, weight: torch.Tensor, bias: torch.Tensor) -> torch.Tensor:
    """
    Forward pass of custom BatchNorm implementation. It expects a Tensor x of size [N,C] or [N,C,L]
    and both a weight and bias Tensor, each of size [C, 1] or of size [1,C] / [C].
    Each row/ embedding of a sentence (dimension C) will be multiplied with one value from the index of the corresponding
    weight tensor and added with the value of the bias tensor.

    Output: tensor of shape [N,C] or [N,C,L], basically of the same size as the input tensor.
    """
    if not isinstance(x, torch.Tensor) :
        raise TypeError('Input is not a tensor')
    elif not isinstance(weight, torch.Tensor):
        raise TypeError('Weight is not a tensor')
    elif not isinstance(bias, torch.Tensor):
        raise TypeError('Bias is not a Tensor')
    #make sure that weights and biases are of shape [C,1]
    print(weight)
    print(10*"#")
    weight = check_shapes(weight)
    bias = check_shapes(bias)
    C_x = x.size()[0] if len(x.size()) == 2 else x.size()[1]
    print(f'ok: {weight[:C_x]}')
    out = x * weight[:C_x] + bias[:C_x]
    #only return the first C_x rows of output tensor
    return out[:,None:C_x] if len(x.size()) == 3 else out[:C_x]

In [32]:
weight = torch.randint(20, (16,))
bias = torch.randint(20, (16,))
input = torch.randint(20, (3,8))
print(weight)
print(bias)
print(input)
print(30*"-")
out = custom_bn1d(input,weight,bias)
print(out)
print(out.size())

tensor([ 6, 10, 10, 13, 17, 18,  9,  7,  9, 11, 19,  0,  7, 16, 12,  3])
tensor([ 0,  8,  9, 14, 13,  3, 14,  9, 11,  8, 10,  6, 17, 15,  9,  3])
tensor([[ 0, 16,  3,  2,  9, 19, 13,  7],
        [10,  4,  1, 18, 13,  7,  8,  5],
        [17, 19,  3,  5, 12, 19, 18, 11]])
------------------------------
tensor([ 6, 10, 10, 13, 17, 18,  9,  7,  9, 11, 19,  0,  7, 16, 12,  3])
##########
ok: tensor([[ 6],
        [10],
        [10]])
tensor([[  0,  96,  18,  12,  54, 114,  78,  42],
        [108,  48,  18, 188, 138,  78,  88,  58],
        [179, 199,  39,  59, 129, 199, 189, 119]])
torch.Size([3, 8])


In [34]:
tensor = torch.arange(27).reshape(3,3,3)
tensor

tensor([[[ 0,  1,  2],
         [ 3,  4,  5],
         [ 6,  7,  8]],

        [[ 9, 10, 11],
         [12, 13, 14],
         [15, 16, 17]],

        [[18, 19, 20],
         [21, 22, 23],
         [24, 25, 26]]])

In [39]:
tensor = torch.arange(27).reshape(9,3)
tensor

tensor([[ 0,  1,  2],
        [ 3,  4,  5],
        [ 6,  7,  8],
        [ 9, 10, 11],
        [12, 13, 14],
        [15, 16, 17],
        [18, 19, 20],
        [21, 22, 23],
        [24, 25, 26]])

In [47]:
tensor.unsqueeze(1).size()

torch.Size([9, 1, 3])

In [38]:
tensor[:,None:2].size()

torch.Size([3, 2, 3])

In [51]:
print(input.size())
output = custom_bn1d(input, weight, bias)
custom_bn1d(torch.randn(3,1,32), weight, bias)
output.size()
assert output.equal(input * weight[0] + bias[0])

torch.Size([1, 32])


In [15]:
torch.cat((tensor,torch.Tensor(1,1,32)), dim=0)

tensor([[[6.4805e-10, 6.3011e-10, 6.6376e-07, 6.7212e-04, 1.7340e-07,
          1.6594e-07, 6.4097e-10, 1.4580e-19, 1.1495e+24, 3.0956e-18,
          5.8981e-10, 3.2506e+21, 1.0528e-11, 2.7625e-06, 6.4103e-10,
          2.1744e+23, 1.2794e+22, 2.1574e-04, 3.3980e+21, 3.0818e-18,
          3.1360e+27, 7.0800e+31, 3.1095e-18, 4.7851e+22, 2.8826e+32,
          4.4248e+30, 7.2442e+22, 2.3086e-12, 7.1760e+22, 7.2250e+28,
          1.5766e-19, 2.7447e-06]],

        [[0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
          0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
          1.4013e-45, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
          0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
          9.8091e-45, 0.0000e+00, 1.4013e-45, 0.0000e+00, 9.1844e-41,
          1.1551e-40, 4.5919e-41, 8.2957e-43, 2.9147e-43, 0.0000e+00,
          6.7262e-44, 0.0000e+00]]])

# Test MergeBatchNorm class from brevitas.graph.quantize.preprocess_for_quantize

In [11]:
from brevitas.graph.quantize import preprocess_for_quantize
from qtransform.model.gpt import GPT, GPTConfig
import torch

gpt = GPT(GPTConfig())



In [79]:
#initialization of normalization layer dependent on whether batchnorm or layernorm is used
#
torch.fx.symbolic_trace(gpt, concrete_args= {'idx': torch.Tensor(1,1)})



TraceError: symbolically traced variables cannot be used as inputs to control flow

In [2]:
#preprocess_for_quantize needs access to a graph representation of the model
from logging import getLogger
log = getLogger(__name__)
other_model = preprocess_for_quantize(gpt)
from brevitas.graph.fixed_point import MergeBatchNorm
#model needs a graph attribute from torch.fx.symbolic_trace
#the purpose of that probably is the same as in https://github.com/pytorch/examples/blob/main/fx/replace_op.py
"""
it seems that control flow depending on arguments leads to this error
https://discuss.tvm.apache.org/t/torch-fx-symbolic-trace-fails-for-most-encoder-decoder-nlp-models/16004
"""
try:
    MergeBatchNorm().apply(gpt)
except Exception as e:
    log.error(e)

AttributeError: 'Tracer' object has no attribute 'unpack_arg'

In [4]:
from transformers.utils.fx import symbolic_trace # is being used with transformers
symbolic_trace(model)

  from .autonotebook import tqdm as notebook_tqdm


AttributeError: 'GPT' object has no attribute 'dummy_inputs'

In [32]:
import torch
import torch.nn as nn
import torch.nn.functional as F
class MinstConv(nn.Module):
    def __init__(self, param = 10):
        super(MinstConv, self).__init__()
        #each model needs nn.module for quantization to work
        self.model = nn.ModuleDict(dict(
            conv1 = nn.Conv2d(1, 32, 3, 1),
            relu1 = nn.ReLU(),
            conv2 = nn.Conv2d(32, 64, 3, 1),
            relu2 = nn.ReLU(),
            maxpool2d = nn.MaxPool2d(kernel_size=2),
            dropout1 = nn.Dropout(0.25),
            flatten = nn.Flatten(),
            fc1 = nn.Linear(9216, 128),
            relu3 = nn.ReLU(),
            dropout2 = nn.Dropout(0.5),
            fc2 = nn.Linear(128, 10)
        ))
        #check symbolic traceability
        self.param = param

    def forward(self, x):
        #no exception
        assert self.param > 0
        #exception, meaning param checking during forward pass not possible
        assert x.size()[-1] > 0
        output = x
        for layer_name, layer in self.model.items():
            output = layer(output)
        output = F.log_softmax(output, dim=1)
        return output

In [33]:
torch.fx.symbolic_trace(MinstConv())

MinstConv(
  (model): Module(
    (conv1): Conv2d(1, 32, kernel_size=(3, 3), stride=(1, 1))
    (relu1): ReLU()
    (conv2): Conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1))
    (relu2): ReLU()
    (maxpool2d): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (dropout1): Dropout(p=0.25, inplace=False)
    (flatten): Flatten(start_dim=1, end_dim=-1)
    (fc1): Linear(in_features=9216, out_features=128, bias=True)
    (relu3): ReLU()
    (dropout2): Dropout(p=0.5, inplace=False)
    (fc2): Linear(in_features=128, out_features=10, bias=True)
  )
)

In [3]:
from brevitas.fx.brevitas_tracer import symbolic_trace
symbolic_trace(model)

TraceError: symbolically traced variables cannot be used as inputs to control flow

In [19]:
import torch.fx as fx
"""
using values from param leads to errors
"""
try:
    fx.symbolic_trace(model)
except:
    pass
import torch
from torch.fx import symbolic_trace
def test(x):
    l = x.size(1)
    return torch.arange(l, dtype=torch.long, device='cuda')
traced = symbolic_trace(test)

TypeError: arange() received an invalid combination of arguments - got (Proxy, device=str, dtype=torch.dtype), but expected one of:
 * (Number end, *, Tensor out, torch.dtype dtype, torch.layout layout, torch.device device, bool pin_memory, bool requires_grad)
 * (Number start, Number end, *, torch.dtype dtype, torch.layout layout, torch.device device, bool pin_memory, bool requires_grad)
 * (Number start, Number end, Number step, *, Tensor out, torch.dtype dtype, torch.layout layout, torch.device device, bool pin_memory, bool requires_grad)


In [16]:
import torch 
class SimpleModel(torch.nn.Module):
    def __init__(self, size: int = 10):
        super().__init__()
        #attributes can be param checked during symbolic tracing
        self.size = size
        self.mul = MultiplyModule()
    def forward(self, x: torch.Tensor):
        #param checking fails for symbolic tracing, unless modules are leaf modules which are not traced
        #assert isinstance(x,torch.Tensor), 'x is not a tensor'
        #assert x.size()[-1] < 10, f'Size of input tensor {x.size()} not compatible with size {self.size}'
        #custom modules for operations only necessary when param checking is performed
        return self.mul.forward(x)
    
class MultiplyModule(torch.nn.Module):
    def __init__(self):
        super().__init__()
    def forward(self, x):
        return x * 2
model = torch.fx.symbolic_trace(SimpleModel())
from brevitas.graph.quantize import preprocess_for_quantize
model = preprocess_for_quantize(model, trace_model = False)
#forward pass works for model after preprocess_model_for_quantize, but not for gpt
assert model(10) == 20

In [23]:
#submodules disappear form module list after preprocess_for_quantize
print(list(model.modules()))
print(list(SimpleModel().modules()))

[SimpleModel()]
[SimpleModel(
  (mul): MultiplyModule()
), MultiplyModule()]


In [2]:
#from: https://github.com/pytorch/pytorch/issues/51803#issuecomment-1104634592
#experiments to make fx graphing work with transformers
import torch

from torch.fx import Tracer
from torch.fx import symbolic_trace
from torch.fx.graph_module import GraphModule


class CustomedTracer(Tracer):
    """
    ``Tracer`` is the class that implements the symbolic tracing functionality
    of ``torch.fx.symbolic_trace``. A call to ``symbolic_trace(m)`` is equivalent
    to ``Tracer().trace(m)``.
    This Tracer override the ``is_leaf_module`` function to make symbolic trace
    right in some cases.
    """
    def __init__(self, *args, customed_leaf_module=None, **kwargs):
        super().__init__(*args, **kwargs)
        self.customed_leaf_module = customed_leaf_module

    def is_leaf_module(self, m: torch.nn.Module, module_qualified_name : str) -> bool:
        """
        A method to specify whether a given ``nn.Module`` is a "leaf" module.
        Leaf modules are the atomic units that appear in
        the IR, referenced by ``call_module`` calls. By default,
        Modules in the PyTorch standard library namespace (torch.nn)
        are leaf modules. All other modules are traced through and
        their constituent ops are recorded, unless specified otherwise
        via this parameter.
        Args:
            m (Module): The module being queried about
            module_qualified_name (str): The path to root of this module. For example,
                if you have a module hierarchy where submodule ``foo`` contains
                submodule ``bar``, which contains submodule ``baz``, that module will
                appear with the qualified name ``foo.bar.baz`` here.
        """
        if self.customed_leaf_module and isinstance(m, self.customed_leaf_module):
            return True
        return m.__module__.startswith('torch.nn') and not isinstance(m, torch.nn.Sequential)



class ArangeForFx(torch.nn.Module):
    def forward(self, x):
        return torch.arange(x)



class Net(torch.nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.arange = ArangeForFx()

    def forward(self, x):
        l = x.size(1)
        return self.arange(l).to(dtype=torch.long, device=x.device)


model = Net()
#all ops which need a param from a symbolically traced variable (from the function parameter) should be a leaf module
#no idea why it fixes it though
tracer = CustomedTracer(customed_leaf_module=(ArangeForFx,))
graph = tracer.trace(model)
#graph = symbolic_trace(model)
name = model.__class__.__name__ if isinstance(model, torch.nn.Module) else model.__name__
traced = GraphModule(tracer.root, graph, name)

print(traced.code)




def forward(self, x):
    size = x.size(1)
    arange = self.arange(size);  size = None
    getattr_1 = x.device;  x = None
    to = arange.to(dtype = torch.int64, device = getattr_1);  arange = getattr_1 = None
    return to
    


In [5]:
import torch
from qtransform.model.modules import BatchNorm, LayerNorm
#https://pytorch.org/docs/stable/fx.html#leaf-modules
#leaf modules are not being traced through
#could be problematic as the output of batch/layer norm are being passed into the attention and mlp layer
tracer_gpt = CustomedTracer(customed_leaf_module=(BatchNorm,LayerNorm))
from qtransform.model.gpt import GPT as qGPT, GPTConfig
gpt = qGPT(GPTConfig())
tokens = torch.randint(50304, (2, 1024))
#assert statements should be nested inside of a custom module
graph_gpt = tracer_gpt.trace(gpt, {"idx": tokens})
name = gpt.__class__.__name__ if isinstance(gpt, torch.nn.Module) else gpt.__name__
traced_gpt = GraphModule(tracer_gpt.root, graph_gpt, name)

  torch.has_cuda,
  torch.has_cudnn,
  torch.has_mps,
  torch.has_mkldnn,


In [7]:
#preceding layer is merged with batchnorm layer
#batchnorm, conv, linear and their quantized versions can be merged
#if merging occurs before quantization, how can the learnable parameters from batchnorm be merged?
#
from brevitas.graph.quantize import preprocess_for_quantize
processed_gpt = preprocess_for_quantize(traced_gpt, trace_model = False)

In [99]:
GPTConfig()

GPTConfig(block_size=1024, vocab_size=50304, n_layer=12, n_head=12, n_embd=768, dropout=0.0, bias=True, flash=False, transformer_active_func='ReLU', norm_layer='BatchNorm', single_output=False)

In [10]:
out_no_symbolic_tracing = gpt(tokens)

In [12]:
print(processed_gpt.graph)

graph():
    %idx_1 : [num_users=0] = placeholder[target=idx_1]
    %targets : [num_users=1] = placeholder[target=targets](default=None)
    %_tensor_constant0 : [num_users=1] = get_attr[target=_tensor_constant0]
    %transformer_wte : [num_users=1] = call_module[target=transformer.wte](args = (%_tensor_constant0,), kwargs = {})
    %_tensor_constant1 : [num_users=1] = get_attr[target=_tensor_constant1]
    %transformer_wpe : [num_users=1] = call_module[target=transformer.wpe](args = (%_tensor_constant1,), kwargs = {})
    %add : [num_users=1] = call_function[target=operator.add](args = (%transformer_wte, %transformer_wpe), kwargs = {})
    %transformer_dropout : [num_users=2] = call_module[target=transformer.dropout](args = (%add,), kwargs = {})
    %transformer_layer_0_ln_1 : [num_users=1] = call_module[target=transformer.layer.0.ln_1](args = (%transformer_dropout,), kwargs = {})
    %transformer_layer_0_attn_attn_mask : [num_users=1] = get_attr[target=transformer.layer.0.attn.attn_m

In [9]:
#normalization function receives a 2d tensor, during training a 3d tensor is forwarded. 
processed_gpt(tokens)

Traceback (most recent call last):
  File "/home/mabot004/eki-transformer-dev/qtransform/eki/lib/python3.10/site-packages/torch/fx/graph_module.py", line 274, in __call__
    return super(self.cls, obj).__call__(*args, **kwargs)  # type: ignore[misc]
  File "/home/mabot004/eki-transformer-dev/qtransform/eki/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/mabot004/eki-transformer-dev/qtransform/eki/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "<eval_with_key>.21 from /home/mabot004/eki-transformer-dev/qtransform/qtransform/model/gpt.py:126 in forward", line 159, in forward
    view_1 = targets.view(-1);  targets = None
AttributeError: 'NoneType' object has no attribute 'view'

Call using an FX-traced Module, line 159 of the traced Module's generated forward function:
    view = linear_out.view(-1, size)

AttributeError: 'NoneType' object has no attribute 'view'

In [77]:
#mini transformer test
class Layer(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.bn = torch.nn.BatchNorm1d(64)
        self.linear = torch.nn.Linear(256,256)
        self.mha = torch.nn.MultiheadAttention(256, 2, batch_first = True)
        self.softmax = torch.nn.Softmax()
        
    def forward(self, x):
        x = self.linear(x)
        x = self.bn(x)
        x, = self.mha(x,x,x, need_weights = False)
        return self.softmax(x)
#maybe the order of layers being called in the forward pass
nodes = list(torch.fx.symbolic_trace(Layer()).graph.nodes)
#nodes[1].args.users

In [78]:
from brevitas.graph.fixed_point import MergeBatchNorm
from brevitas.fx.brevitas_tracer import symbolic_trace as brevitas_symbolic_trace
model = brevitas_symbolic_trace(Layer())
MergeBatchNorm().apply(model)

RuntimeError: The size of tensor a (256) must match the size of tensor b (64) at non-singleton dimension 0

In [71]:
from brevitas.graph.utils import matches_module_pattern
def is_converged(graph_model: GraphModule):
        named_modules = dict(graph_model.named_modules())
        for node in graph_model.graph.nodes:
            for pattern in MergeBatchNorm.DEFAULT_PATTERNS:
                if matches_module_pattern(pattern, node, named_modules):
                    #potential error since node.args is a list containing tuples
                    if len(node.args[0].users) > 1:
                        continue
                    layer = named_modules[node.args[0].target]
                    bn = named_modules[node.target]
                    #!!!! check if batchnorm is merged
                    print(f'{layer}\n{bn}')
                    return -1
                    #merging happens here
                    merge_bn(layer, bn, get_output_channel_dim(layer))
                    
                    
                    node.replace_all_uses_with(node.args[0])
                    graph_model.graph.erase_node(node)
                    del_module(graph_model, node.target)
        graph_model.recompile()
        graph_model.graph.lint()
        return graph_model


In [76]:
#not merged as it does not fit the patterns
#BatchNorm can be merged into Conv Layers, BatchNorm layers and linear layers
out = is_converged(model)
assert out == -1

AssertionError: 

In [3]:
bn = torch.nn.BatchNorm1d(64)
for i in range(200):
    bn(torch.randn(12,64,256))
bn.eval()


BatchNorm1d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)

In [6]:
mean = bn.running_mean
bn(torch.randn(12,64,256))
assert bn.running_mean.equal(mean)