diff --git a/tests/modeldiffs/wmt/compare.py b/tests/modeldiffs/wmt/compare.py index 52c96481c..806022687 100644 --- a/tests/modeldiffs/wmt/compare.py +++ b/tests/modeldiffs/wmt/compare.py @@ -47,20 +47,38 @@ def sd_transform(sd): out = {} for k in sd: k_str = ''.join(k) - if 'Dense' in k_str: - new_key = (*k[:2], 'MlpBlock_0', *k[2:]) - out[new_key] = sd[k] - elif 'SelfAttention' in k_str: + if 'SelfAttention' in k_str: new_key = list(k) - if '_' in new_key[-1]: - qkv = {'q': 'query', 'k': 'key', 'v': 'value'}[new_key[-1][0]] - new_key[-1] = qkv - new_key.append('kernel') new_key = [ i if i != 'SelfAttention_1' else 'MultiHeadDotProductAttention_0' for i in new_key ] - new_key = tuple(new_key) + if 'SelfAttention_0' in k_str: + if new_key[-2] == 'Dense_0': + # qkv + for name, value in zip(('query', 'key', 'value'), sd[k].chunk(3)): + out[(*new_key[:-2], name, new_key[-1])] = value + pass + elif new_key[-2] == 'Dense_1': + # out + out[(*new_key[:-2], 'out', new_key[-1])] = sd[k] + pass + else: + if new_key[-2] == 'Dense_0': + #q + out[(*new_key[:-2], 'query', new_key[-1])] = sd[k] + pass + elif new_key[-2] == 'Dense_1': + # kv + for name, value in zip(('key', 'value'), sd[k].chunk(2)): + out[(*new_key[:-2], name, new_key[-1])] = value + pass + elif new_key[-2] == 'Dense_2': + # out + out[(*new_key[:-2], 'out', new_key[-1])] = sd[k] + pass + elif 'Dense' in k_str: + new_key = (*k[:2], 'MlpBlock_0', *k[2:]) out[new_key] = sd[k] elif 'LayerNorm' in k_str: new_key = list(k)