In [2]:
!pip install pandas

You should consider upgrading via the '/opt/conda/bin/python -m pip install --upgrade pip' command.


In [1]:
import time
import sys
import os
import copy
import random

# CUDA_VISIBLE_DEVICES only takes effect if it is set before CUDA is initialised in
# this process. Setting it before `import torch` (and before the project imports
# below) guarantees nothing can have initialised CUDA yet.
os.environ["CUDA_VISIBLE_DEVICES"] = "0"

import torch
import numpy as np

DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'

# Project helpers: generate.py and cka.py are imported from ../tools.
sys.path.append('../tools')
from generate import generate, sent_to_ids, load_bart, load_sents, load_dict, ids_to_tokens
from cka import encdec_cka_sim, cal_encdec_mean

In [2]:
def random_sampling(sents1, sents2, num, rng=None):
    """Draw `num` aligned pairs at random from two parallel sentence lists.

    sents1[i] is paired with sents2[i] before sampling, so element k of both
    returned lists comes from the same original index. Pairing uses zip(), so
    when the inputs differ in length the extra tail of the longer one is ignored.

    Args:
        sents1: first sequence of sentences.
        sents2: second sequence, aligned line-by-line with ``sents1``.
        num: number of pairs to draw, without replacement.
        rng: optional ``random.Random`` instance to draw from. When None (the
            default) the module-level ``random`` generator is used, so a prior
            ``random.seed(...)`` controls the draw exactly as before.

    Returns:
        (sampled_sents1, sampled_sents2): two lists of length ``num``.

    Raises:
        ValueError: if ``num`` is negative or larger than the number of pairs
            (propagated from ``random.sample``).
    """
    sampler = random if rng is None else rng
    pairs = list(zip(sents1, sents2))
    sampled_pairs = sampler.sample(pairs, num)
    sampled_sents1 = [first for first, _ in sampled_pairs]
    sampled_sents2 = [second for _, second in sampled_pairs]
    return sampled_sents1, sampled_sents2

# Japanese BART

In [23]:
# Pre-trained Japanese BART (JaBART) from the trimmed `jabart_jako` directory.
# Every JaBART comparison below passes it as the `pre` model.
# NOTE(review): CKA comparisons assume deterministic forward passes; confirm that
# load_bart returns the model in eval() mode (dropout disabled).
jabart_jako_path, jabart_jako_name = "../pretrained_bart/trim/jabart_jako", 'jabart_base.pt'
pre_model = load_bart(path=jabart_jako_path, model_name=jabart_jako_name).to(DEVICE)

## Korean/Japanese

In [24]:
# Vocabulary of the trimmed JaBART checkpoint (jabart_jako); used as both pre_d and
# ft_d in the Korean/Japanese comparisons.
koja_d = load_dict("../pretrained_bart/trim/jabart_jako/dict.txt")

# Parallel Korean/Japanese dev sentences. readlines() keeps each trailing "\n"; the
# lines are passed on unchanged. Decode explicitly as UTF-8: the default encoding is
# locale-dependent and can fail or mis-decode Japanese/Korean text.
file_path = "../data/jabart/koja/dev.ja"
with open(file_path, "r", encoding="utf-8") as f:
    sentences_ja = f.readlines()

file_path = "../data/jabart/koja/dev.ko"
with open(file_path, "r", encoding="utf-8") as f:
    sentences_ko = f.readlines()

In [5]:
# ja-ko: pre-trained JaBART vs. JaBART fine-tuned on ja-ko.
# Both models are fed the same sampled Japanese sentences with the koja vocabulary.
ft_path = "../ja-ko/jabart/checkpoints"
ft_name = "checkpoint_best.pt"
ft_model = load_bart(
    path=ft_path, model_name=ft_name
).to(DEVICE)

# Re-seed so the sample does not depend on which cells ran before.
random.seed(1)
# Both arguments are the same list, so the two returned samples are identical;
# keep one instead of unpacking both into the same name.
sampled_sents_ja = random_sampling(
    sentences_ja, sentences_ja, 100
)[0]

# perf_counter is monotonic and intended for measuring durations (time.time can jump).
start = time.perf_counter()
encdec_cka_sim(
    pre=pre_model, ft=ft_model,
    pre_d=koja_d, ft_d=koja_d,
    pre_sents=sampled_sents_ja, ft_sents=sampled_sents_ja,
    batch_size=16
)
end = time.perf_counter()
print(f"elapsed_time:{end-start}[sec]")

(8, 100, 768) (7, 100, 768) (6, 100, 768)
(8, 100, 768) (7, 100, 768)
Encoder CKA
Layer 0 0.9971703786815697
Layer 1 0.9812738013289333
Layer 2 0.9739238246902284
Layer 3 0.9650739860483872
Layer 4 0.9618400017708343
Layer 5 0.9660422979403386
Layer 6 0.968122705004213
Layer 7 0.9782481944354261

Decoder CKA
Layer 0 0.12138005263053207
Layer 1 0.2570982162304512
Layer 2 0.317451346899074
Layer 3 0.34108509333253895
Layer 4 0.38336732931309214
Layer 5 0.3730877445713164
Layer 6 0.14565507281983664

Decoder up to self attention CKA
Layer 0 0.1858750434836678
Layer 1 0.30689625563559614
Layer 2 0.3483071697250494
Layer 3 0.375189978250909
Layer 4 0.3877989131905401
Layer 5 0.3236910988351791
elapsed_time:130.80946397781372[sec]


In [7]:
# ko-ja: pre-trained JaBART vs. JaBART fine-tuned on ko-ja.
# The pre-trained model sees the Japanese side and the fine-tuned model the aligned
# Korean side of the same sampled sentence pairs.
ft_path = "../ko-ja/jabart/checkpoints"
ft_name = "checkpoint_best.pt"
ft_model = load_bart(
    path=ft_path, model_name=ft_name
).to(DEVICE)

# Re-seed so the sample does not depend on which cells ran before.
random.seed(1)
sampled_sents_ja, sampled_sents_ko = random_sampling(
    sentences_ja, sentences_ko, 100
)

# perf_counter is monotonic and intended for measuring durations (time.time can jump).
start = time.perf_counter()
encdec_cka_sim(
    pre=pre_model, ft=ft_model,
    pre_d=koja_d, ft_d=koja_d,
    pre_sents=sampled_sents_ja, ft_sents=sampled_sents_ko,
    batch_size=16
)
end = time.perf_counter()
print(f"elapsed_time:{end-start}[sec]")

(8, 100, 768) (7, 100, 768) (6, 100, 768)
(8, 100, 768) (7, 100, 768)
Encoder CKA
Layer 0 0.8828052578510951
Layer 1 0.9040147464892014
Layer 2 0.8906150178104792
Layer 3 0.8838428647354586
Layer 4 0.8944273267348168
Layer 5 0.9063596033533926
Layer 6 0.8884562648335584
Layer 7 0.8915533906290435

Decoder CKA
Layer 0 0.15945214045884482
Layer 1 0.2301107814456214
Layer 2 0.2810259877645324
Layer 3 0.30112129693878437
Layer 4 0.3413878644049451
Layer 5 0.3025392750966902
Layer 6 0.12649043719278555

Decoder up to self attention CKA
Layer 0 0.20187657079827914
Layer 1 0.2768636300866152
Layer 2 0.3128932577278696
Layer 3 0.33865115253137784
Layer 4 0.3419903773080327
Layer 5 0.2811961032945658
elapsed_time:125.79683542251587[sec]


## English/Japanese

In [29]:
# Vocabulary of the trimmed jabart_jaen model; passed as ft_d for the JaBART
# ja-en / en-ja fine-tuned checkpoints.
# NOTE: `enja_d`, `sentences_ja` and `sentences_en` are re-bound in the English BART
# section below. When running cells out of order, re-run this cell before the
# JaBART ja-en / en-ja comparisons.
enja_d = load_dict("../pretrained_bart/trim/jabart_jaen/dict.txt")

# Parallel English/Japanese dev sentences for JaBART. readlines() keeps each trailing
# "\n". Decode explicitly as UTF-8 rather than relying on the locale default.
file_path = "../data/jabart/enja/dev.ja"
with open(file_path, "r", encoding="utf-8") as f:
    sentences_ja = f.readlines()

file_path = "../data/jabart/enja/dev.en"
with open(file_path, "r", encoding="utf-8") as f:
    sentences_en = f.readlines()

In [9]:
# ja-en: pre-trained JaBART vs. JaBART fine-tuned on ja-en.
# Both models receive the same sampled Japanese sentences. pre_model was loaded from
# jabart_jako, so it keeps its own vocabulary (koja_d); the fine-tuned model uses
# the jabart_jaen dict (enja_d).
ft_path = "../ja-en/jabart/checkpoints"
ft_name = "checkpoint_best.pt"
ft_model = load_bart(
    path=ft_path, model_name=ft_name
).to(DEVICE)

# Re-seed so the sample does not depend on which cells ran before.
random.seed(1)
# Both arguments are the same list, so the two returned samples are identical;
# keep one instead of unpacking both into the same name.
sampled_sents_ja = random_sampling(
    sentences_ja, sentences_ja, 100
)[0]

# perf_counter is monotonic and intended for measuring durations (time.time can jump).
start = time.perf_counter()
encdec_cka_sim(
    pre=pre_model, ft=ft_model,
    pre_d=koja_d, ft_d=enja_d,
    pre_sents=sampled_sents_ja, ft_sents=sampled_sents_ja,
    batch_size=16
)
end = time.perf_counter()
print(f"elapsed_time:{end-start}[sec]")

(8, 100, 768) (7, 100, 768) (6, 100, 768)
(8, 100, 768) (7, 100, 768)
Encoder CKA
Layer 0 0.989563281497098
Layer 1 0.974816472648417
Layer 2 0.9528309253401753
Layer 3 0.9221236680843773
Layer 4 0.905613568223484
Layer 5 0.9197068856932752
Layer 6 0.9353293294050234
Layer 7 0.9439371613914803

Decoder CKA
Layer 0 0.1119577261480585
Layer 1 0.19932022451944195
Layer 2 0.29359013945335416
Layer 3 0.35911365734848844
Layer 4 0.43916214260604197
Layer 5 0.3338383694447813
Layer 6 0.1503958812159764

Decoder up to self attention CKA
Layer 0 0.18135612039168253
Layer 1 0.26100318634236164
Layer 2 0.34228249247427395
Layer 3 0.43092717671459235
Layer 4 0.41025013328783927
Layer 5 0.30782225709240135
elapsed_time:110.1309769153595[sec]


In [30]:
# en-ja: pre-trained JaBART vs. JaBART fine-tuned on en-ja.
# The pre-trained model sees the Japanese side (with its koja vocabulary) and the
# fine-tuned model the aligned English side (with the jabart_jaen dict, enja_d).
ft_path = "../en-ja/jabart/checkpoints"
ft_name = "checkpoint_best.pt"
ft_model = load_bart(
    path=ft_path, model_name=ft_name
).to(DEVICE)

# Re-seed so the sample does not depend on which cells ran before.
random.seed(1)
sampled_sents_ja, sampled_sents_en = random_sampling(
    sentences_ja, sentences_en, 100
)

# perf_counter is monotonic and intended for measuring durations (time.time can jump).
start = time.perf_counter()
encdec_cka_sim(
    pre=pre_model, ft=ft_model,
    pre_d=koja_d, ft_d=enja_d,
    pre_sents=sampled_sents_ja, ft_sents=sampled_sents_en,
    batch_size=16
)
end = time.perf_counter()
print(f"elapsed_time:{end-start}[sec]")

(8, 100, 768) (7, 100, 768) (6, 100, 768)
(8, 100, 768) (7, 100, 768)
Encoder CKA
Layer 0 0.43952244795753864
Layer 1 0.45148635766727835
Layer 2 0.45407653326032255
Layer 3 0.456910254492851
Layer 4 0.44128720776525304
Layer 5 0.4321100468523806
Layer 6 0.4247472443873235
Layer 7 0.49199936471273664

Decoder CKA
Layer 0 0.2187766210541609
Layer 1 0.24496475980236715
Layer 2 0.33469679147778675
Layer 3 0.368891923542102
Layer 4 0.4009444935595677
Layer 5 0.31946105639655187
Layer 6 0.059528893832202785

Decoder up to self attention CKA
Layer 0 0.27436065973329726
Layer 1 0.31041440068043336
Layer 2 0.3700385357524329
Layer 3 0.3913675411483885
Layer 4 0.3791862244039186
Layer 5 0.2610391434145364
elapsed_time:105.03792023658752[sec]


# English BART

In [12]:
# Pre-trained English BART (EnBART) from the trimmed `enbart_enja` directory; it
# replaces the JaBART `pre_model` and is the reference for every EnBART cell below.
# NOTE(review): in every EnBART run below, encoder "Layer 6" and "Layer 7" CKA values
# are identical, whereas they differ for JaBART. The last encoder state may be
# duplicated for this model inside cka.encdec_cka_sim. Confirm before interpreting
# Layer 7.
enbart_enja_path, enbart_enja_name = "../pretrained_bart/trim/enbart_enja", "enbart_base.pt"
pre_model = load_bart(path=enbart_enja_path, model_name=enbart_enja_name).to(DEVICE)

## Japanese/English

In [13]:
# Vocabulary of the pre-trained EnBART loaded from enbart_enja_path. This re-binds
# `enja_d` (previously the JaBART jabart_jaen dict); the EnBART ja-en / en-ja cells
# pass it as both pre_d and ft_d.
enja_d = load_dict(f"{enbart_enja_path}/dict.txt")

# Parallel English/Japanese dev sentences for EnBART. readlines() keeps each trailing
# "\n". Decode explicitly as UTF-8 rather than relying on the locale default.
file_path = "../data/enbart/enja/dev.ja"
with open(file_path, "r", encoding="utf-8") as f:
    sentences_ja = f.readlines()

file_path = "../data/enbart/enja/dev.en"
with open(file_path, "r", encoding="utf-8") as f:
    sentences_en = f.readlines()

In [14]:
# ja-en: pre-trained EnBART vs. EnBART fine-tuned on ja-en.
# The pre-trained model sees the English side and the fine-tuned model the aligned
# Japanese side of the same sampled sentence pairs; both use the enja vocabulary.
ft_path = "../ja-en/enbart/checkpoints"
ft_name = "checkpoint_best.pt"
ft_model = load_bart(
    path=ft_path, model_name=ft_name
).to(DEVICE)

# Re-seed so the sample does not depend on which cells ran before.
random.seed(1)
sampled_sents_ja, sampled_sents_en = random_sampling(
    sentences_ja, sentences_en, 100
)

# perf_counter is monotonic and intended for measuring durations (time.time can jump).
start = time.perf_counter()
encdec_cka_sim(
    pre=pre_model, ft=ft_model,
    pre_d=enja_d, ft_d=enja_d,
    pre_sents=sampled_sents_en, ft_sents=sampled_sents_ja,
    batch_size=16
)
end = time.perf_counter()
print(f"elapsed_time:{end-start}[sec]")

(8, 100, 768) (7, 100, 768) (6, 100, 768)
(8, 100, 768) (7, 100, 768)
Encoder CKA
Layer 0 0.44308230789884906
Layer 1 0.5502702209143098
Layer 2 0.5519916281037587
Layer 3 0.5563345593379978
Layer 4 0.5585082009441193
Layer 5 0.5547125939640384
Layer 6 0.5425837385997123
Layer 7 0.5425837385997123

Decoder CKA
Layer 0 0.09931835963931872
Layer 1 0.07595272749699429
Layer 2 0.11041319036774082
Layer 3 0.1846376892397597
Layer 4 0.13568210247405774
Layer 5 0.08839916309783269
Layer 6 0.3417472827864746

Decoder up to self attention CKA
Layer 0 0.08764377269044171
Layer 1 0.15421076094913194
Layer 2 0.22038680891153356
Layer 3 0.23608631584197678
Layer 4 0.20428137809692784
Layer 5 0.1089702591060719
elapsed_time:102.63952875137329[sec]


In [15]:
# en-ja: pre-trained EnBART vs. EnBART fine-tuned on en-ja.
# Both models are fed the same sampled English sentences with the enja vocabulary.
ft_path = "../en-ja/enbart/checkpoints"
ft_name = "checkpoint_best.pt"
ft_model = load_bart(
    path=ft_path, model_name=ft_name
).to(DEVICE)

# Re-seed so the sample does not depend on which cells ran before.
random.seed(1)
# Both arguments are the same list, so the two returned samples are identical;
# keep one instead of unpacking both into the same name.
sampled_sents_en = random_sampling(
    sentences_en, sentences_en, 100
)[0]

# perf_counter is monotonic and intended for measuring durations (time.time can jump).
start = time.perf_counter()
encdec_cka_sim(
    pre=pre_model, ft=ft_model,
    pre_d=enja_d, ft_d=enja_d,
    pre_sents=sampled_sents_en, ft_sents=sampled_sents_en,
    batch_size=16
)
end = time.perf_counter()
print(f"elapsed_time:{end-start}[sec]")

(8, 100, 768) (7, 100, 768) (6, 100, 768)
(8, 100, 768) (7, 100, 768)
Encoder CKA
Layer 0 0.9947663859591692
Layer 1 0.9690050812960682
Layer 2 0.9267659871455628
Layer 3 0.9259282590568918
Layer 4 0.9225745072622071
Layer 5 0.9233199679951328
Layer 6 0.9257420342116391
Layer 7 0.9257420342116391

Decoder CKA
Layer 0 0.06200720237022637
Layer 1 0.09602597670796026
Layer 2 0.1162435341810525
Layer 3 0.1489488699434853
Layer 4 0.1765507707264187
Layer 5 0.20782397281105974
Layer 6 0.5513750252869679

Decoder up to self attention CKA
Layer 0 0.07340509526139198
Layer 1 0.10115642577108794
Layer 2 0.12500983513936922
Layer 3 0.1745240158137556
Layer 4 0.22196037311379893
Layer 5 0.260123068278414
elapsed_time:165.20238304138184[sec]


## French/English

In [17]:
# Vocabulary of the trimmed enbart_enfr model; passed as ft_d for the fr-en / en-fr
# fine-tuned checkpoints.
enfr_d = load_dict("../pretrained_bart/trim/enbart_enfr/dict.txt")

# Parallel French/English dev sentences. readlines() keeps each trailing "\n".
# Decode explicitly as UTF-8 rather than relying on the locale default (accented
# French text can mis-decode otherwise).
file_path = "../data/enbart/enfr/dev.fr"
with open(file_path, "r", encoding="utf-8") as f:
    sentences_fr = f.readlines()

file_path = "../data/enbart/enfr/dev.en"
with open(file_path, "r", encoding="utf-8") as f:
    sentences_en = f.readlines()

In [20]:
# fr-en: pre-trained EnBART vs. EnBART fine-tuned on fr-en.
# The pre-trained model sees the English side (with its enja vocabulary) and the
# fine-tuned model the aligned French side (with enfr_d).
# NOTE(review): the reference model is the enbart_enja-trimmed checkpoint, while the
# fine-tuned model's dict comes from trim/enbart_enfr. Confirm this is the intended
# baseline for the fr/en comparison.
ft_path = "../fr-en/enbart/checkpoints"
ft_name = "checkpoint_best.pt"
ft_model = load_bart(
    path=ft_path, model_name=ft_name
).to(DEVICE)

# Re-seed so the sample does not depend on which cells ran before.
random.seed(1)
sampled_sents_en, sampled_sents_fr = random_sampling(
    sentences_en, sentences_fr, 100
)

# perf_counter is monotonic and intended for measuring durations (time.time can jump).
start = time.perf_counter()
encdec_cka_sim(
    pre=pre_model, ft=ft_model,
    pre_d=enja_d, ft_d=enfr_d,
    pre_sents=sampled_sents_en, ft_sents=sampled_sents_fr,
    batch_size=16
)
end = time.perf_counter()
print(f"elapsed_time:{end-start}[sec]")

(8, 100, 768) (7, 100, 768) (6, 100, 768)
(8, 100, 768) (7, 100, 768)
Encoder CKA
Layer 0 0.7646569451152809
Layer 1 0.7587003117680807
Layer 2 0.7757442263339923
Layer 3 0.7663677465393556
Layer 4 0.774393182564784
Layer 5 0.7750444993347524
Layer 6 0.7712746472730949
Layer 7 0.7712746472730949

Decoder CKA
Layer 0 0.11365173723236892
Layer 1 0.131481879697025
Layer 2 0.2103302172190924
Layer 3 0.2919080848020759
Layer 4 0.35700600555268813
Layer 5 0.35416613680745196
Layer 6 0.3105485631315848

Decoder up to self attention CKA
Layer 0 0.1047903767197506
Layer 1 0.20801378834495848
Layer 2 0.32268239893067757
Layer 3 0.3840264404270085
Layer 4 0.4018308813827713
Layer 5 0.3303856550469767
elapsed_time:111.59074974060059[sec]


In [22]:
# en-fr: pre-trained EnBART vs. EnBART fine-tuned on en-fr.
# Both models receive the same sampled English sentences. pre_model keeps its enja
# vocabulary; the fine-tuned model uses enfr_d.
ft_path = "../en-fr/enbart/checkpoints"
ft_name = "checkpoint_best.pt"
ft_model = load_bart(
    path=ft_path, model_name=ft_name
).to(DEVICE)

# Re-seed so the sample does not depend on which cells ran before.
random.seed(1)
# Both arguments are the same list, so the two returned samples are identical;
# keep one instead of unpacking both into the same name.
sampled_sents_en = random_sampling(
    sentences_en, sentences_en, 100
)[0]

# perf_counter is monotonic and intended for measuring durations (time.time can jump).
start = time.perf_counter()
encdec_cka_sim(
    pre=pre_model, ft=ft_model,
    pre_d=enja_d, ft_d=enfr_d,
    pre_sents=sampled_sents_en, ft_sents=sampled_sents_en,
    batch_size=16
)
end = time.perf_counter()
print(f"elapsed_time:{end-start}[sec]")

(8, 100, 768) (7, 100, 768) (6, 100, 768)
(8, 100, 768) (7, 100, 768)
Encoder CKA
Layer 0 0.9771762965273146
Layer 1 0.9073712394139727
Layer 2 0.8965725297233069
Layer 3 0.8815354150381705
Layer 4 0.8775547720352443
Layer 5 0.872442251598707
Layer 6 0.8410559203035358
Layer 7 0.8410559203035358

Decoder CKA
Layer 0 0.10778351003020133
Layer 1 0.12908465704311187
Layer 2 0.19107924906277546
Layer 3 0.2578092991278896
Layer 4 0.29301651690077135
Layer 5 0.28792259101580175
Layer 6 0.26776594653292807

Decoder up to self attention CKA
Layer 0 0.10354058602131933
Layer 1 0.18931631487983047
Layer 2 0.2836285855859728
Layer 3 0.3269118008003395
Layer 4 0.3352472237992844
Layer 5 0.2687337239909468
elapsed_time:121.1991674900055[sec]
