In [1]:
from u import *
from data import *
import model
import quantize

In [7]:
# Load the saved configuration of the run under analysis (device='cpu', non-distributed).
run_dir = Res / 'quantize_prune33.75_distill_8.3M_cache2000_hebbian_step175000_cache3000_bits9'
c = Config(run_dir, device='cpu', distributed=False).load()
# To evaluate perplexities on the validation and test sets, run this from a shell:
# ! cd {c.res} && CUDA_VISIBLE_DEVICES=0 python ../../main.py . valid test

Model at step 200000
Model has 8304213 parameters. Embedding has 3278270 parameters
valid {'loss': 3.676476001739502, 'perplexity': 39.50692613554216, 'time': 2.0}
test {'loss': 3.721527099609375, 'perplexity': 41.327457088205485, 'time': 2.0}


In [282]:
# Build the network class named by the config (default 'model.Transformer') and
# restore its weights via c.init_model with train=False.
# NOTE(review): eval() executes whatever string the config stores under 'model';
# only load configs from trusted sources.
model_cls = eval(c.get('model', 'model.Transformer'))
net = model_cls(c)
net, step = c.init_model(net, step=c.get('step', 'max'), train=False)
print('Loaded step', step)

Loaded step 200000


In [289]:
# Tally the storage cost of the loaded model. Sizes below appear to be expressed in
# units of 32-bit words (a b-bit value counts as b / 32) -- confirm against the
# scoring rules this notebook targets.

# Map weight name -> final target sparsity, read from the pruning schedule if present.
target_sparsities = {}
prune_schedule = c.res / 'distiller_prune.yaml'
if prune_schedule.exists():
    for pruner in prune_schedule.load()['pruners'].values():
        for w in pruner['weights']:
            target_sparsities[w] = pruner['final_sparsity']

# Read the bit width once so every use below agrees; 32 when the config has no 'bits'.
param_bits = c.get('bits', 32)

nonzero_params = 0
target_nonzero_params = 0
total_params = 0

mask_size = 0
quantization_size = 0

for k, p in net.state_dict().items():
    if k.startswith('loss.layers.') and k.endswith('.weight'):  # shared with input embedding
        continue
    param_type = k.split('.')[-1]
    if param_type == 'max_abs':
        # one full-size value per quantizer
        quantization_size += 1
    elif param_type == 'inv_scale':
        quantization_size += (32 - param_bits) / 32
    elif param_type in ('weight', 'bias', 'pos_emb'):  # masked params
        if '.ln1.' in k or '.ln2.' in k:  # ignore layernorm beta, gamma, can be fused into fc
            continue
        total = p.numel()
        if total == 0:
            continue
        nz = from_torch((p != 0).sum())
        nonzero_params += nz
        total_params += total
        # one mask bit per element
        mask_size += total / 32
        # weights without a schedule entry are expected to stay dense (sparsity 0)
        target_nonzero_params += total * (1 - target_sparsities.get(k, 0))
    elif param_type in ('cache_theta_inv_softplus', 'cache_lambda_inv_sigmoid'):
        # cache hyper-parameters: always counted as nonzero, no mask
        nonzero_params += p.numel()
        total_params += p.numel()
    else:
        # Name the offending key so an unexpected state-dict entry is easy to track down.
        raise RuntimeError('Should not happen: unexpected parameter %r in state dict' % k)

print('nonzero params', nonzero_params)
print('total params', total_params)
if len(target_sparsities):
    print('target total sparsity %.5g' % (1 - target_nonzero_params / total_params))
print('total sparsity %.5g' % (1 - nonzero_params / total_params))

print()
param_size = nonzero_params * param_bits / 32
print('total param size', param_size)
print('total mask size', mask_size)
print('total quantization size', quantization_size)
total_size = param_size + mask_size + quantization_size
print('total size', total_size)

nonzero params 73588631
total params 73588631
target total sparsity 0
total sparsity 0

total param size 73588631.0
total mask size 2299644.71875
total quantization size 0
total size 75888275.71875


In [284]:
# Count how many test tokens fall into each vocabulary bin [cutoffs[b], cutoffs[b + 1]).
test_tokens = (Cache / 'wikitext-103/sorted_test.npy').load()
cutoffs = np.array([0] + c.cutoffs + [c.n_vocab])
n_bins = len(cutoffs) - 1
token_bin_counts = np.zeros(n_bins)
for b in range(n_bins):
    in_bin = (cutoffs[b] <= test_tokens) & (test_tokens < cutoffs[b + 1])
    token_bin_counts[b] = in_bin.sum()
print(token_bin_counts)
# Fraction of test tokens per bin.
token_bin_fracs = token_bin_counts / token_bin_counts.sum()

# Running operation counters, updated in place by the tally_* helpers.
muls, adds, others = 0, 0, 0

def density(p):
    """Return the fraction of entries of `p` that are nonzero.

    The nonzero count is passed through `from_torch`; an empty `p` is divided
    by 1 instead of 0.
    """
    nonzero = from_torch((p != 0).sum())
    size = p.numel()
    if not size:
        size = 1
    return nonzero / size

def tally_fc(fc, multiplier=1):
    """Tally a fully connected layer, scaled by weight/bias density.

    fc: object with a 2-D `weight` of shape (out, in) and a `bias` that may be None.
    multiplier: extra factor applied to both the weight and the bias terms.
    """
    n_out, n_in = fc.weight.shape
    weight_density = density(fc.weight)
    if fc.bias is None:
        bias_density = 0
    else:
        bias_density = density(fc.bias)
    tally_matmul(n_in, n_out, multiplier * weight_density, multiplier * bias_density)
    
def tally_matmul(in_, out, mmultiplier=1, bmultiplier=1):
    """Tally a (out, in_) matrix times (in_,) vector product with optional bias.

    mmultiplier scales the multiply and accumulate counts; bmultiplier scales
    the `out` bias additions. Updates the global `muls` and `adds` counters.
    """
    global muls, adds
    mul_ops = mmultiplier * in_ * out
    # in_ is the dimension being summed over: in_ - 1 additions per output element
    accumulate_ops = mmultiplier * (in_ - 1) * out
    bias_ops = bmultiplier * out
    muls = muls + mul_ops
    adds = adds + (accumulate_ops + bias_ops)
    
def tally_layernorm(ln):
    """Add the cost of one layer norm to the global `others` counter.

    ln: layer-norm module; only `ln.weight.shape` is read.
    """
    global others
    # Size the tally from the module that was passed in, not from a global loop variable.
    dim = np.prod(ln.weight.shape)
    # the terms below sum to 6 * dim
    others += dim + dim * 2 + (dim - 1) + 1 + dim + dim
    
def tally_softmax(dim):
    """Add the cost of a softmax over `dim` entries to the global `others` counter.

    A non-positive `dim` (e.g. a disabled cache of size 0) contributes nothing;
    without the guard the `dim - 1` term would subtract one operation.
    """
    global others
    if dim <= 0:
        return
    others += dim + (dim - 1) + dim
    
# Walk the network and accumulate operation counts into the global
# muls / adds / others counters through the tally_* helpers.

# Embedding projections, each weighted by the fraction of test tokens in its bin.
# NOTE(review): pairs projections with token_bin_fracs[1:], i.e. assumes
# net.embed.projections holds one projection per bin after the first -- confirm.
for f, p in zip(token_bin_fracs[1:], net.embed.projections):
    tally_fc(p, multiplier=f)
    
for l in net.layers:
    # layer norm 1
    tally_layernorm(l.ln1)

    # qkv fully connected layer
    tally_fc(l.qkv)
    
    # q * k
    # attention span tallied as n_seq + 1 positions
    context = c.n_seq + 1
    tally_matmul(c.n_k, context * c.n_head, bmultiplier=0)
    
    # q * pos_emb + qk
    # NOTE(review): no c.n_head factor here or in the softmax below, unlike q * k -- confirm intended.
    tally_matmul(c.n_k, context)
    
    # softmax
    tally_softmax(context)
    
    # attn * v
    tally_matmul(context, c.n_head * c.n_v, bmultiplier=0)
    
    # out fully connected layer
    tally_fc(l.out)
    
    # residual
    adds += c.n_embed
    
    # layer norm 2
    tally_layernorm(l.ln2)
    
    # FFN 1
    tally_fc(l.fc[0])
    
    # ReLU
    others += c.n_inner
    
    # FFN 2
    # presumably l.fc is a sequential container whose index 3 is the second linear layer -- confirm
    tally_fc(l.fc[3])
    
    # residual
    adds += c.n_embed

# first bin
tally_fc(net.loss.clusters)

# projections for other bins
for layer in net.loss.layers:
    tally_fc(layer)
for p in net.loss.projections:
    tally_fc(p)

# softmax over different bins
bin_sizes = cutoffs[1:] - cutoffs[:-1]
# the first bin's softmax gets one extra entry per remaining bin
bin_sizes[0] += len(bin_sizes[1:])

for s in bin_sizes:
    tally_softmax(s)

# cache
# cache size falls back to 0 when the config has no 'n_cache'
tally_matmul(c.n_embed, c.get('n_cache', 0), bmultiplier=0)
tally_softmax(c.get('n_cache', 0))


# multiplications are scaled by bits / 32 (bits falls back to 32 when unset)
muls_quant = muls * c.get('bits', 32) / 32

print('matrix multiplications', muls)
print('matrix multiplications after quantization', muls_quant)
print('additions', adds)
print('others fp32', others)
# total = scaled multiplications + additions + other operations
total_flops = muls_quant + adds + others
print('total', total_flops)

[245569.]
matrix multiplications 74097600.0
matrix multiplications after quantization 74097600.0
additions 74081664.0
others fp32 840083
total 149019347.0


In [274]:
# Combine the size and operation totals into one normalised score.
# NOTE(review): 159e6 and 318e6 are presumably the reference parameter count and
# operation count the score is normalised against -- confirm against the scoring rules.
SIZE_NORMALIZER = 159e6
OPS_NORMALIZER = 318e6
final_score = total_size / SIZE_NORMALIZER + total_flops / OPS_NORMALIZER
print('final score', final_score)

final score 0.038746852160481786
