In [1]:
from u import *
from ut import *
from data import *
import quantized_model
from quantized_model import evaluate, get_net
# Bit width used by the size and operation accounting in the cells below.
num_bits = 9

# Make only one physical GPU visible to this process; the index is machine-specific.
visible_gpu = '2'
os.environ['CUDA_VISIBLE_DEVICES'] = visible_gpu

# Select the explicit quantization path instead of distiller
# (distiller is what we use during quantization-aware training).
quantized_model.distiller_vs_explicit = 'explicit'

In [3]:
# Load the saved run configuration (hyperparameters, paths, etc.).
c = Config(Wiki / 'quant_aware,1', device='cuda:0', distributed=False).load()

# Sequential iterators over the validation and test splits, both using the eval batch size.
data_val = SequentialIterator(c, c.eval_batch, split='valid')
data_test = SequentialIterator(c, c.eval_batch, split='test')

# Build the network and restore the processed checkpoint.
# The checkpoint is deserialized onto the CPU first: without map_location, torch.load
# tries to put every tensor back on the device it was saved from, which fails when that
# device index is not visible here (CUDA_VISIBLE_DEVICES exposes a single GPU).
# The explicit .to(c.device) below then moves the restored model to the target device.
net = get_net(c)
checkpoint_path = c.res / 'models/model-1-processed.pth'
net.load_state_dict(torch.load(checkpoint_path, map_location='cpu'))
net = net.to(c.device)
net  # last expression: show the module tree as the cell output

Transformer(
  (embed): AdaptiveEmbedding(
    (layers): ModuleList(
      (0): Embedding(3500, 256)
      (1): Embedding(21500, 64)
      (2): Embedding(242735, 4)
    )
    (projections): ModuleList(
      (0): Linear(in_features=64, out_features=256, bias=False)
      (1): Linear(in_features=4, out_features=256, bias=False)
    )
  )
  (dropout1): Dropout(p=0)
  (layers): ModuleList(
    (0): Decoder(
      (ln1): LayerNorm(torch.Size([256]), eps=1e-05, elementwise_affine=True)
      (quant_ln1): ExplicitQuantize()
      (qkv): Linear(in_features=256, out_features=576, bias=True)
      (quant_qkv): ExplicitQuantize()
      (quant_attn): ExplicitQuantize()
      (quant_attnv): ExplicitQuantize()
      (out): Linear(in_features=192, out_features=256, bias=False)
      (dropout): Dropout(p=0)
      (ln2): LayerNorm(torch.Size([256]), eps=1e-05, elementwise_affine=True)
      (quant_ln2): ExplicitQuantize()
      (fc): Sequential(
        (0): Linear(in_features=256, out_features=768, b

In [66]:
# Score the same network on both held-out splits, validation first, then test.
for split_label, split_iter in (('validation:', data_val), ('test:', data_test)):
    # Reset the cached keys/values before each run
    # (presumably so no state carries over from one split to the next).
    net.cache_keys = net.cache_values = None
    print(split_label, evaluate(c, split_iter, net))

validation: {'loss': 3.536329746246338, 'perplexity': 34.34064871540981, 'time': 2.0}
test: {'loss': 3.5540616512298584, 'perplexity': 34.95500458840018, 'time': 2.0}


In [4]:
# Storage accounting for the compressed model.
# Bit counts are divided by 32 throughout, so every size below is in 32-bit-word units.
tied_output_keys = {'loss.layers.0.weight', 'loss.layers.1.weight', 'loss.layers.2.weight'}
masked_kinds = ('weight', 'bias', 'pos_emb')
cache_kinds = ('cache_theta_inv_softplus', 'cache_lambda_inv_sigmoid')

nonzero_params = 0
total_params = 0

mask_size = 0
quantization_size = 0

for name, tensor in net.state_dict().items():
    if name in tied_output_keys:
        # shared with input embedding, so already counted there
        continue

    kind = name.rsplit('.', 1)[-1]
    if kind == 'max_abs':
        # one full word per quantizer
        quantization_size += 1
    elif kind == 'inv_scale':
        quantization_size += (32 - num_bits) / 32
    elif kind in masked_kinds:
        # layernorm beta / gamma are ignored: they can be fused into fc
        is_layernorm = '.ln1.' in name or '.ln2.' in name
        if not is_layernorm:
            n_nonzero = from_torch((tensor != 0).sum())
            n_elements = tensor.numel()
            nonzero_params += n_nonzero
            total_params += n_elements
            # masked params: one mask bit per element
            mask_size += n_elements / 32
            print(name, 'sparsity %.5g' % (1 - n_nonzero / n_elements))
    elif kind in cache_kinds:
        # cache parameters are counted as fully dense
        n_elements = tensor.numel()
        nonzero_params += n_elements
        total_params += n_elements
    else:
        raise RuntimeError('Should not happen')

# Only nonzero values are stored, each at num_bits bits.
param_size = nonzero_params * num_bits / 32
total_size = param_size + mask_size + quantization_size

print()
print('nonzero params', nonzero_params)
print('total params', total_params)
print('total sparsity %.5g' % (1 - nonzero_params / total_params))

print()
for label, value in (
    ('total param size', param_size),
    ('total mask size', mask_size),
    ('total quantization size', quantization_size),
    ('total size', total_size),
):
    print(label, value)

embed.layers.0.weight sparsity 0.1896
embed.layers.1.weight sparsity 0.34571
embed.layers.2.weight sparsity 0.40006
embed.projections.0.weight sparsity 0.52264
embed.projections.1.weight sparsity 0.00097656
layers.0.pos_emb sparsity 0
layers.0.qkv.weight sparsity 0.40005
layers.0.qkv.bias sparsity 0.050347
layers.0.out.weight sparsity 0.53571
layers.0.fc.0.weight sparsity 0.36091
layers.0.fc.0.bias sparsity 0
layers.0.fc.3.weight sparsity 0.29874
layers.0.fc.3.bias sparsity 0.035156
layers.1.pos_emb sparsity 0
layers.1.qkv.weight sparsity 0.50439
layers.1.qkv.bias sparsity 0.032986
layers.1.out.weight sparsity 0.724
layers.1.fc.0.weight sparsity 0.45222
layers.1.fc.0.bias sparsity 0
layers.1.fc.3.weight sparsity 0.4392
layers.1.fc.3.bias sparsity 0.03125
layers.2.pos_emb sparsity 0
layers.2.qkv.weight sparsity 0.41308
layers.2.qkv.bias sparsity 0.059028
layers.2.out.weight sparsity 0.64268
layers.2.fc.0.weight sparsity 0.47834
layers.2.fc.0.bias sparsity 0
layers.2.fc.3.weight sparsity

In [6]:
params = net.state_dict()

# Running operation counts, filled in by the sections below.
matmuls = 0
adds = 0
others = 0

# collect densities: fraction of nonzero entries per 'weight' / 'bias' tensor,
# keyed by state-dict name.
# 'max_abs' (still marked TODO) and 'inv_scale' entries are skipped, as is every other kind.
# To count a fully dense model instead, substitute a mapping that returns 1 for every key.
densities = {
    name: from_torch((tensor != 0).sum()) / tensor.numel()
    for name, tensor in params.items()
    if name.split('.')[-1] in ('weight', 'bias')
}

# input embedding
# Per-bin token counts, normalised into the fraction of tokens that hit each bin.
# NOTE(review): presumably token occurrences per vocabulary bin in the evaluated split — confirm.
token_bin_counts = np.array([198232, 35479, 11858])
token_bin_fracs = token_bin_counts / token_bin_counts.sum()

# Bin 0 is skipped: only bins 1 and 2 have an entry in embed.projections, so only they
# pay for a projection. Each projection's cost is weighted by how often its bin is used
# and by the density of the projection weight.
for i in range(1, len(token_bin_fracs)):
    bin_frac = token_bin_fracs[i]
    proj_key = 'embed.projections.%s.weight' % (i - 1)
    proj_density = densities[proj_key]

    h2, h1 = params[proj_key].shape
    matmuls += h1 * h2 * bin_frac * proj_density
    adds += (h1 - 1) * h2 * bin_frac * proj_density

n_layers = 8
n_hidden = 256
n_attn = 97   # number of attended positions per query (presumably the attention span — confirm)

d_head = 24   # length of each q.k dot product
n_heads = 8   # number of such dot products per attended position
d_attn = 192  # width of each of q, k, v (equals d_head * n_heads)
d_ffn = 768   # width of the feed-forward hidden layer

# Non-matmul, non-add operations are the same for every layer, so count them once.
# layer norm 1
ln1 = n_hidden + n_hidden * 2 + (n_hidden - 1) + 1 + n_hidden + n_hidden
# layer norm 2 costs the same as layer norm 1
ln2 = ln1
# softmax over the attended positions
sm = n_attn + n_attn - 1 + n_attn
# ReLU on the feed-forward hidden layer
relu = d_ffn
layer_others = ln1 + sm + ln2 + relu

for i in range(n_layers):
    # Accumulated per layer first, in this exact order, then folded into the totals.
    layer_mults = 0
    layer_sums = 0

    # qkv fully connected layer (sparse weight and bias)
    w_density = densities['layers.%s.qkv.weight' % i]
    b_density = densities['layers.%s.qkv.bias' % i]
    layer_mults += n_hidden * d_attn * 3 * w_density
    layer_sums += (n_hidden - 1) * d_attn * 3 * w_density + d_attn * 3 * b_density

    # q * k
    layer_mults += d_head * n_heads * n_attn
    layer_sums += (d_head - 1) * n_heads * n_attn

    # positional embedding
    # NOTE(review): unlike q * k, this term carries no factor of 8 — confirm that is intended.
    layer_mults += d_head * n_attn
    layer_sums += (d_head - 1) * n_attn + n_attn

    # attn * v
    layer_mults += n_attn * d_head * n_heads
    layer_sums += (n_attn - 1) * d_head * n_heads

    # out fully connected layer (no bias term is counted)
    w_density = densities['layers.%s.out.weight' % i]
    layer_mults += d_attn * n_hidden * w_density
    layer_sums += (d_attn - 1) * n_hidden * w_density

    # residual
    layer_sums += n_hidden

    # FFN 1
    w_density = densities['layers.%s.fc.0.weight' % i]
    b_density = densities['layers.%s.fc.0.bias' % i]
    layer_mults += n_hidden * d_ffn * w_density
    layer_sums += (n_hidden - 1) * d_ffn * w_density + d_ffn * b_density

    # FFN 2
    w_density = densities['layers.%s.fc.3.weight' % i]
    b_density = densities['layers.%s.fc.3.bias' % i]
    layer_mults += d_ffn * n_hidden * w_density
    layer_sums += (d_ffn - 1) * n_hidden * w_density + n_hidden * b_density

    # residual
    layer_sums += n_hidden

    matmuls += layer_mults
    adds += layer_sums
    others += layer_others

# clusters: two extra logits computed from the hidden state
cluster_w_density = densities['loss.clusters.weight']
cluster_b_density = densities['loss.clusters.bias']
matmuls += n_hidden * 2 * cluster_w_density
adds += (n_hidden - 1) * 2 * cluster_w_density + 2 * cluster_b_density

# projections for bin 1 and 2: hidden state down to widths 64 and 4
for proj_idx, proj_width in enumerate((64, 4)):
    w_density = densities['loss.projections.%s.weight' % proj_idx]
    matmuls += n_hidden * proj_width * w_density
    adds += (n_hidden - 1) * proj_width * w_density

# output layer of every bin, scaled by the density of its weight and bias
for bin_idx in range(3):
    layer_key = 'loss.layers.%s' % bin_idx
    w_density = densities[layer_key + '.weight']
    b_density = densities[layer_key + '.bias']

    v, h = params[layer_key + '.weight'].shape
    matmuls += v * h * w_density
    adds += (h - 1) * v * w_density + v * b_density

# One softmax per bin at 3 * size - 1 operations each.
# 3502 is presumably the 3500 head entries plus the 2 cluster logits — confirm.
sm = sum(size * 3 - 1 for size in (3502, 21500, 242735))
others += sm

# Cache term: a dense (n_hidden x 2000) product followed by a softmax over 2000 entries.
matmuls += n_hidden * 2000
adds += (n_hidden - 1) * 2000
cache_sm = 2000 * 3 - 1
others += cache_sm

# Multiplications are scaled by num_bits / 32; additions and the remaining
# operations are counted at full cost (the report labels them fp32).
mm_quantized = matmuls * num_bits / 32
total_flops = mm_quantized + adds + others

for label, count in (
    ('matrix multiplications', matmuls),
    ('matrix multiplications after quantization', mm_quantized),
    ('additions', adds),
    ('others fp32', others),
    ('total', total_flops),
):
    print(label, count)

matrix multiplications 5546286.350785319
matrix multiplications after quantization 1559893.0361583708
additions 5428300.290974323
others fp32 842247
total 7830440.3271326935


In [7]:
# Final score: storage and operation counts, each divided by a fixed reference value.
# NOTE(review): 159e6 / 318e6 look like the reference size and operation budgets — confirm the source.
size_reference = 159e6
ops_reference = 318e6
final_score = total_size / size_reference + total_flops / ops_reference
print('final score', final_score)

final score 0.03474921270324746
