From c15adb0708788a76b86d9366ac4fd134495ed428 Mon Sep 17 00:00:00 2001 From: Chandramouli Shama Sastry Date: Sat, 9 Sep 2023 22:33:25 +0000 Subject: [PATCH 1/2] fix wmt comparator --- tests/modeldiffs/wmt/compare.py | 36 ++++++++++++++++++++++++--------- 1 file changed, 27 insertions(+), 9 deletions(-) diff --git a/tests/modeldiffs/wmt/compare.py b/tests/modeldiffs/wmt/compare.py index 52c96481c..382f2bf26 100644 --- a/tests/modeldiffs/wmt/compare.py +++ b/tests/modeldiffs/wmt/compare.py @@ -47,20 +47,38 @@ def sd_transform(sd): out = {} for k in sd: k_str = ''.join(k) - if 'Dense' in k_str: - new_key = (*k[:2], 'MlpBlock_0', *k[2:]) - out[new_key] = sd[k] - elif 'SelfAttention' in k_str: + if 'SelfAttention' in k_str: new_key = list(k) - if '_' in new_key[-1]: - qkv = {'q': 'query', 'k': 'key', 'v': 'value'}[new_key[-1][0]] - new_key[-1] = qkv - new_key.append('kernel') new_key = [ i if i != 'SelfAttention_1' else 'MultiHeadDotProductAttention_0' for i in new_key ] - new_key = tuple(new_key) + if 'SelfAttention_0' in k_str: + if new_key[-2] == 'Dense_0': + # qkv + for name, value in zip(('query','key','value'),sd[k].chunk(3)): + out[(*new_key[:-2],name,new_key[-1])] = value + pass + elif new_key[-2] == 'Dense_1': + # out + out[(*new_key[:-2],'out',new_key[-1])] = sd[k] + pass + else: + if new_key[-2] == 'Dense_0': + #q + out[(*new_key[:-2],'query',new_key[-1])] = sd[k] + pass + elif new_key[-2] == 'Dense_1': + # kv + for name, value in zip(('key','value'),sd[k].chunk(2)): + out[(*new_key[:-2],name,new_key[-1])] = value + pass + elif new_key[-2] == 'Dense_2': + # out + out[(*new_key[:-2],'out',new_key[-1])] = sd[k] + pass + elif 'Dense' in k_str: + new_key = (*k[:2], 'MlpBlock_0', *k[2:]) out[new_key] = sd[k] elif 'LayerNorm' in k_str: new_key = list(k) From ec96fbed552f46ee4802ca9b8595a78351684ffa Mon Sep 17 00:00:00 2001 From: Chandramouli Shama Sastry Date: Sat, 9 Sep 2023 22:36:20 +0000 Subject: [PATCH 2/2] comparator fix --- tests/modeldiffs/wmt/compare.py | 20 ++++++++++---------- 1 file changed, 10 insertions(+), 10 deletions(-) diff --git a/tests/modeldiffs/wmt/compare.py b/tests/modeldiffs/wmt/compare.py index 382f2bf26..806022687 100644 --- a/tests/modeldiffs/wmt/compare.py +++ b/tests/modeldiffs/wmt/compare.py @@ -56,26 +56,26 @@ def sd_transform(sd): if 'SelfAttention_0' in k_str: if new_key[-2] == 'Dense_0': # qkv - for name, value in zip(('query','key','value'),sd[k].chunk(3)): - out[(*new_key[:-2],name,new_key[-1])] = value + for name, value in zip(('query', 'key', 'value'), sd[k].chunk(3)): + out[(*new_key[:-2], name, new_key[-1])] = value pass elif new_key[-2] == 'Dense_1': # out - out[(*new_key[:-2],'out',new_key[-1])] = sd[k] + out[(*new_key[:-2], 'out', new_key[-1])] = sd[k] pass else: if new_key[-2] == 'Dense_0': #q - out[(*new_key[:-2],'query',new_key[-1])] = sd[k] - pass + out[(*new_key[:-2], 'query', new_key[-1])] = sd[k] + pass elif new_key[-2] == 'Dense_1': - # kv - for name, value in zip(('key','value'),sd[k].chunk(2)): - out[(*new_key[:-2],name,new_key[-1])] = value - pass + # kv + for name, value in zip(('key', 'value'), sd[k].chunk(2)): + out[(*new_key[:-2], name, new_key[-1])] = value + pass elif new_key[-2] == 'Dense_2': # out - out[(*new_key[:-2],'out',new_key[-1])] = sd[k] + out[(*new_key[:-2], 'out', new_key[-1])] = sd[k] pass elif 'Dense' in k_str: new_key = (*k[:2], 'MlpBlock_0', *k[2:])