In [54]:
# Reload imported modules automatically before each cell runs, so edits to the
# nsc / trt library source are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [55]:
import torch
import os
import pickle

import omegaconf
from omegaconf import OmegaConf

from nsc.models import ModelForTokenClassification, ModelForTokenClassificationConfig, Models
from nsc.modules import embedding
from nsc import TokenizationRepairer as NSCTokenizationRepairer
from nsc.data import tokenization, variants
from nsc.utils import Batch, io
from nsc.utils import config

from spell_checking import CONFIG_DIR, DATA_DIR, EXPERIMENT_DIR

from trt import TokenizationRepairer

In [56]:
# Character-level nsc tokenizer. Its id order defines the row order of the ported
# embedding matrix built further down.
char_tok_cfg = tokenization.TokenizerConfig(
    type=tokenization.Tokenizers.CHAR
)
char_tok = tokenization.get_tokenizer_from_config(char_tok_cfg)

### EO medium (6-layer transformer encoder)

In [57]:
# nsc model mirroring the pretrained trt "eo_medium" checkpoint (6-layer encoder).
# These hyperparameters must agree with the trt model for its state dicts to load below.
embedding_cfg = embedding.TensorEmbeddingConfig()
model_cfg = ModelForTokenClassificationConfig(
    type=Models.MODEL_FOR_TOKEN_CLASSIFICATION,
    tokenizer=char_tok_cfg,
    embedding=embedding_cfg,
    hidden_dim=512,
    feed_forward_dim=2048,
    num_layers=6,
    activation="gelu",
    norm=False,
    num_clf_layers=1,
    num_classes=3,
)
# Minimal two-token dummy batch handed to the constructor as sample_inputs.
sample_inputs = Batch([torch.tensor([0, 1])], {})
nsc_model = ModelForTokenClassification(
    sample_inputs=sample_inputs,
    cfg=model_cfg,
    device="cpu",
).eval()

In [58]:
# Placeholder training config saved next to the ported checkpoint (cfg.pkl). Only the
# variant, model and experiment name are set. Every other TrainConfig field keeps its
# default; '???' in the printout marks an OmegaConf MISSING value.
variant_cfg = variants.TokenizationRepairConfig(
    type=variants.DatasetVariants.TOKENIZATION_REPAIR,
    input_type="char",
    data_scheme="tensor",
    add_bos_eos=True,
)
dummy_train_cfg = OmegaConf.structured(
    config.TrainConfig(
        model=model_cfg,
        variant=variant_cfg,
        experiment_name="eo_medium_arxiv_with_errors",
    )
)
print(dummy_train_cfg)

{'model': {'type': <Models.MODEL_FOR_TOKEN_CLASSIFICATION: 5>, 'max_length': 512, 'hidden_dim': 512, 'tokenizer': {'type': <Tokenizers.CHAR: 1>, 'file_path': None}, 'embedding': {'learned_position_embedding': False, 'embed_positions': True, 'dropout': 0.1}, 'dropout': 0.1, 'num_layers': 6, 'feed_forward_dim': 2048, 'norm': False, 'activation': 'gelu', 'num_clf_layers': 1, 'num_classes': 3}, 'variant': {'type': <DatasetVariants.TOKENIZATION_REPAIR: 3>, 'data_scheme': 'tensor', 'input_type': 'char', 'add_bos_eos': True}, 'experiment_dir': '???', 'experiment_name': 'eo_medium_arxiv_with_errors', 'data_dir': '???', 'datasets': '???', 'dataset_limits': '???', 'val_splits': '???', 'epochs': 20, 'batch_size': 64, 'batch_max_length': None, 'bucket_span': None, 'optimizer': '???', 'lr_scheduler': None, 'log_per_epoch': 100, 'eval_per_epoch': 4, 'keep_last_n_checkpoints': 0, 'seed': 22, 'num_workers': None, 'pin_memory': True, 'mixed_precision': True, 'start_from_checkpoint': None, 'exponential_

In [59]:
# Pretrained trt model to port (the log shows it is read from the local trt cache).
tok_rep = TokenizationRepairer.from_pretrained(
    "eo_medium_arxiv_with_errors",
    device="cpu",
)

2022-05-10 22:05:40,229 [DOWNLOAD] [INFO] model eo_medium_arxiv_with_errors was already downloaded to cache directory /home/sebastian/anaconda3/envs/masters_thesis/lib/python3.8/site-packages/trt/api/.cache
2022-05-10 22:05:40,248 [TOKENIZATION_REPAIR] [INFO] running tokenization repair on device Intel(R) Core(TM) i7-9750H CPU @ 2.60GHz


In [60]:
# Every character of the nsc char vocabulary must exist in the trt tokenizer,
# otherwise its embedding row cannot be taken from the trt model below.
trt_tok = tok_rep.model.encoder.tokenizer
trt_chars = [trt_tok.id_to_token(i) for i in range(trt_tok.get_vocab_size())]
missing_chars = set(char_tok.vocab) - set(trt_chars)
assert not missing_chars, f"chars missing from trt vocab: {missing_chars!r}"

In [61]:
# Tokens that exist only in the trt vocabulary and therefore get no nsc embedding row.
set(trt_tok.get_vocab()).difference(char_tok.vocab)

{'<mask>', '<sep>'}

In [62]:
# Freshly initialised (pre-port) nsc embedding weights, shown for comparison.
print(nsc_model.embedding.embedding.emb.weight)
# Row i of the new matrix is the trt embedding of the char that char_tok maps to id i.
new_embedding_weight = torch.stack(
    [
        tok_rep.model.encoder.embedding.embedding.weight[
            trt_tok.token_to_id(char_tok.id_to_token(nsc_id))
        ]
        for nsc_id in range(char_tok.vocab_size)
    ]
)
print(new_embedding_weight.shape, new_embedding_weight)
# Looking up any char through char_tok must give exactly the trt vector for that char.
assert all(
    torch.equal(
        new_embedding_weight[char_tok.token_to_id(ch)],
        tok_rep.model.encoder.embedding.embedding.weight[trt_tok.token_to_id(ch)],
    )
    for ch in char_tok.vocab
)

Parameter containing:
tensor([[-0.0061,  0.0099,  0.0760,  ..., -0.0008, -0.0288,  0.0184],
        [-0.0934,  0.0072,  0.0519,  ...,  0.0433,  0.0365,  0.0169],
        [ 0.0578, -0.0113, -0.0313,  ...,  0.0502,  0.0699, -0.0762],
        ...,
        [-0.0378,  0.0389, -0.0255,  ...,  0.0730,  0.0513, -0.0658],
        [ 0.0249,  0.0182,  0.0039,  ..., -0.0570,  0.0972, -0.0442],
        [ 0.0150,  0.0171, -0.0260,  ...,  0.0048, -0.0491, -0.1052]],
       requires_grad=True)
torch.Size([99, 512]) tensor([[ 0.0091,  0.0069,  0.0073,  ..., -0.0129, -0.0797,  0.0664],
        [ 0.0047, -0.0299,  0.0066,  ..., -0.0620, -0.0542, -0.0682],
        [ 0.0097,  0.0094,  0.0074,  ..., -0.1044, -0.0106, -0.0480],
        ...,
        [ 0.0117,  0.0090,  0.0086,  ..., -0.0296, -0.1163,  0.0599],
        [ 0.0083,  0.0050,  0.0066,  ..., -0.0551,  0.0331,  0.0179],
        [ 0.0117,  0.0105,  0.0106,  ..., -0.0245, -0.0246, -0.0273]])


In [63]:
# Copy the pretrained weights into the nsc model.
# copy_ writes into the existing parameter storage and raises on a shape mismatch.
# Rebinding `.data` would silently accept any shape and alias new_embedding_weight.
with torch.no_grad():
    nsc_model.embedding.embedding.emb.weight.copy_(new_embedding_weight)
nsc_model.embedding.norm.load_state_dict(tok_rep.model.encoder.embedding.norm.state_dict())
nsc_model.encoder.encoder.load_state_dict(tok_rep.model.encoder.encoder.state_dict())
nsc_model.head.clf.load_state_dict(tok_rep.model.head.head.state_dict())

<All keys matched successfully>

In [64]:
# Toy input for the equivalence checks below. positions is derived from seq so the
# two stay in sync if the example string is changed.
seq = "thisisatest"
positions = torch.arange(len(seq))

In [65]:
# nsc side: token and positional embeddings without a batch axis, plus the full
# embedding module applied to a batch of one (batch axis first).
nsc_token_emb = nsc_model.embedding.embedding(torch.tensor(char_tok.tokenize(seq)))
nsc_pos_emb = nsc_model.embedding.pos_emb(positions)
nsc_emb = nsc_model.embedding(torch.unsqueeze(torch.tensor(char_tok.tokenize(seq)), 0))

In [66]:
# trt side: [1:-1] drops the first and last ids (presumably BOS/EOS added by the trt
# tokenizer) to match the nsc input. trt modules take sequence-first input, hence the
# (seq, 1) transpose.
ipt_ids = torch.tensor([trt_tok.encode(seq).ids[1:-1]])
trt_token_emb = tok_rep.model.encoder.embedding.embedding(ipt_ids.t())
trt_pos_emb = tok_rep.model.encoder.embedding.pos_embedding(ipt_ids.t())
trt_emb = tok_rep.model.encoder.embedding(ipt_ids.t())

In [67]:
# nsc embeddings here are (seq, dim); trt outputs are sequence-first with a batch axis of size 1.
(nsc_token_emb.shape, trt_pos_emb.shape)

(torch.Size([11, 512]), torch.Size([11, 1, 512]))

In [68]:
# Rebuild the full embedding by hand. trt token embeddings are scaled by
# sqrt(hidden_dim) before the positional embedding is added. The nsc embedding module
# applies that scaling itself, as the next cell confirms.
emb_scale = model_cfg.hidden_dim ** 0.5
trt_all_emb = tok_rep.model.encoder.embedding.norm(trt_token_emb * emb_scale + trt_pos_emb)
nsc_all_emb = nsc_model.embedding.norm(nsc_token_emb + nsc_pos_emb)
assert torch.allclose(trt_all_emb[:, 0, :], nsc_all_emb)

In [69]:
# Component-wise equivalence: scaled token embeddings, positional embeddings and the
# complete embedding output. Asserted (not just displayed) so Run-All stops on a mismatch.
embedding_checks = (
    torch.allclose(nsc_token_emb, trt_token_emb[:, 0, :] * model_cfg.hidden_dim ** 0.5),
    torch.allclose(nsc_pos_emb, trt_pos_emb[:, 0, :]),
    torch.allclose(nsc_emb[0], trt_emb[:, 0, :]),
)
assert all(embedding_checks), embedding_checks
embedding_checks

(True, True, True)

In [70]:
# Run both encoders on their matching embeddings. This is an inference-only
# comparison, so skip building autograd graphs.
with torch.no_grad():
    nsc_enc = nsc_model.encoder(nsc_emb)
    trt_enc = tok_rep.model.encoder.encoder(trt_emb)

In [71]:
# nsc encoder output (batch-first, index 0) must match trt (sequence-first, batch index 0).
assert torch.allclose(nsc_enc[0], trt_enc[:, 0, :], atol=1e-6), (
    f"encoder outputs differ (max abs diff {(nsc_enc[0] - trt_enc[:, 0, :]).abs().max().item():.3e})"
)

In [73]:
# Write the ported model as an nsc experiment directory containing
# checkpoints/checkpoint_best.pt and a pickled (dummy_train_cfg, {}) tuple in cfg.pkl.
ported_dir = os.path.join(EXPERIMENT_DIR, "TOKENIZATION_REPAIR", "eo_medium_arxiv_with_errors_ported")
os.makedirs(os.path.join(ported_dir, "checkpoints"), exist_ok=True)
# NOTE(review): the two trailing 0s are presumably step/epoch counters — confirm
# against io.save_checkpoint.
io.save_checkpoint(os.path.join(ported_dir, "checkpoints", "checkpoint_best.pt"), nsc_model, 0, 0)
with open(os.path.join(ported_dir, "cfg.pkl"), "wb") as pf:
    pickle.dump((dummy_train_cfg, {}), pf)

### EO large (12-layer transformer encoder)

In [74]:
# nsc model mirroring the pretrained trt "eo_large" checkpoint (12-layer encoder).
# These hyperparameters must agree with the trt model for its state dicts to load below.
embedding_cfg = embedding.TensorEmbeddingConfig()
model_cfg = ModelForTokenClassificationConfig(
    type=Models.MODEL_FOR_TOKEN_CLASSIFICATION,
    tokenizer=char_tok_cfg,
    embedding=embedding_cfg,
    hidden_dim=512,
    feed_forward_dim=2048,
    num_layers=12,
    activation="gelu",
    norm=False,
    num_clf_layers=1,
    num_classes=3,
)
# Minimal two-token dummy batch handed to the constructor as sample_inputs.
sample_inputs = Batch([torch.tensor([0, 1])], {})
nsc_model = ModelForTokenClassification(
    sample_inputs=sample_inputs,
    cfg=model_cfg,
    device="cpu",
).eval()

In [75]:
# Placeholder training config saved next to the ported checkpoint (cfg.pkl). Only the
# variant, model and experiment name are set. Every other TrainConfig field keeps its
# default; '???' in the printout marks an OmegaConf MISSING value.
variant_cfg = variants.TokenizationRepairConfig(
    type=variants.DatasetVariants.TOKENIZATION_REPAIR,
    input_type="char",
    data_scheme="tensor",
    add_bos_eos=True,
)
dummy_train_cfg = OmegaConf.structured(
    config.TrainConfig(
        model=model_cfg,
        variant=variant_cfg,
        experiment_name="eo_large_arxiv_with_errors",
    )
)
print(dummy_train_cfg)

{'model': {'type': <Models.MODEL_FOR_TOKEN_CLASSIFICATION: 5>, 'max_length': 512, 'hidden_dim': 512, 'tokenizer': {'type': <Tokenizers.CHAR: 1>, 'file_path': None}, 'embedding': {'learned_position_embedding': False, 'embed_positions': True, 'dropout': 0.1}, 'dropout': 0.1, 'num_layers': 12, 'feed_forward_dim': 2048, 'norm': False, 'activation': 'gelu', 'num_clf_layers': 1, 'num_classes': 3}, 'variant': {'type': <DatasetVariants.TOKENIZATION_REPAIR: 3>, 'data_scheme': 'tensor', 'input_type': 'char', 'add_bos_eos': True}, 'experiment_dir': '???', 'experiment_name': 'eo_large_arxiv_with_errors', 'data_dir': '???', 'datasets': '???', 'dataset_limits': '???', 'val_splits': '???', 'epochs': 20, 'batch_size': 64, 'batch_max_length': None, 'bucket_span': None, 'optimizer': '???', 'lr_scheduler': None, 'log_per_epoch': 100, 'eval_per_epoch': 4, 'keep_last_n_checkpoints': 0, 'seed': 22, 'num_workers': None, 'pin_memory': True, 'mixed_precision': True, 'start_from_checkpoint': None, 'exponential_

In [76]:
# Pretrained trt model to port (the log shows it is read from the local trt cache).
tok_rep = TokenizationRepairer.from_pretrained(
    "eo_large_arxiv_with_errors",
    device="cpu",
)

2022-05-10 22:06:24,297 [DOWNLOAD] [INFO] model eo_large_arxiv_with_errors was already downloaded to cache directory /home/sebastian/anaconda3/envs/masters_thesis/lib/python3.8/site-packages/trt/api/.cache
2022-05-10 22:06:24,316 [TOKENIZATION_REPAIR] [INFO] running tokenization repair on device Intel(R) Core(TM) i7-9750H CPU @ 2.60GHz


In [77]:
# Every character of the nsc char vocabulary must exist in the trt tokenizer,
# otherwise its embedding row cannot be taken from the trt model below.
trt_tok = tok_rep.model.encoder.tokenizer
trt_chars = [trt_tok.id_to_token(i) for i in range(trt_tok.get_vocab_size())]
missing_chars = set(char_tok.vocab) - set(trt_chars)
assert not missing_chars, f"chars missing from trt vocab: {missing_chars!r}"

In [78]:
# Tokens that exist only in the trt vocabulary and therefore get no nsc embedding row.
set(trt_tok.get_vocab()).difference(char_tok.vocab)

{'<mask>', '<sep>'}

In [79]:
# Freshly initialised (pre-port) nsc embedding weights, shown for comparison.
print(nsc_model.embedding.embedding.emb.weight)
# Row i of the new matrix is the trt embedding of the char that char_tok maps to id i.
new_embedding_weight = torch.stack(
    [
        tok_rep.model.encoder.embedding.embedding.weight[
            trt_tok.token_to_id(char_tok.id_to_token(nsc_id))
        ]
        for nsc_id in range(char_tok.vocab_size)
    ]
)
print(new_embedding_weight.shape, new_embedding_weight)
# Looking up any char through char_tok must give exactly the trt vector for that char.
assert all(
    torch.equal(
        new_embedding_weight[char_tok.token_to_id(ch)],
        tok_rep.model.encoder.embedding.embedding.weight[trt_tok.token_to_id(ch)],
    )
    for ch in char_tok.vocab
)

Parameter containing:
tensor([[ 0.0419, -0.0796,  0.0186,  ...,  0.0085, -0.0026, -0.0574],
        [ 0.0011, -0.0510, -0.0250,  ...,  0.0011,  0.1244, -0.0295],
        [ 0.0341, -0.0372, -0.0035,  ..., -0.0037,  0.0338,  0.0133],
        ...,
        [ 0.0149, -0.0206,  0.0012,  ...,  0.0372,  0.0211, -0.0229],
        [-0.0441, -0.0076, -0.0076,  ...,  0.0067, -0.0176, -0.0320],
        [-0.0469,  0.0433, -0.0172,  ...,  0.0054, -0.0747,  0.0176]],
       requires_grad=True)
torch.Size([99, 512]) tensor([[ 0.0051,  0.0053,  0.0026,  ..., -0.0388, -0.0927,  0.0241],
        [ 0.0167, -0.0090,  0.0196,  ..., -0.1209,  0.0299, -0.1390],
        [ 0.0126,  0.0199,  0.0080,  ..., -0.1087, -0.0695,  0.0321],
        ...,
        [ 0.0006,  0.0082,  0.0017,  ..., -0.0256, -0.0366,  0.0614],
        [ 0.0078, -0.0013, -0.0002,  ..., -0.0140,  0.0141,  0.0621],
        [ 0.0126,  0.0072,  0.0063,  ...,  0.0174, -0.0390, -0.0065]])


In [80]:
# Copy the pretrained weights into the nsc model.
# copy_ writes into the existing parameter storage and raises on a shape mismatch.
# Rebinding `.data` would silently accept any shape and alias new_embedding_weight.
with torch.no_grad():
    nsc_model.embedding.embedding.emb.weight.copy_(new_embedding_weight)
nsc_model.embedding.norm.load_state_dict(tok_rep.model.encoder.embedding.norm.state_dict())
nsc_model.encoder.encoder.load_state_dict(tok_rep.model.encoder.encoder.state_dict())
nsc_model.head.clf.load_state_dict(tok_rep.model.head.head.state_dict())

<All keys matched successfully>

In [81]:
# Toy input for the equivalence checks below. positions is derived from seq so the
# two stay in sync if the example string is changed.
seq = "thisisatest"
positions = torch.arange(len(seq))

In [82]:
# nsc side: token and positional embeddings without a batch axis, plus the full
# embedding module applied to a batch of one (batch axis first).
nsc_token_emb = nsc_model.embedding.embedding(torch.tensor(char_tok.tokenize(seq)))
nsc_pos_emb = nsc_model.embedding.pos_emb(positions)
nsc_emb = nsc_model.embedding(torch.unsqueeze(torch.tensor(char_tok.tokenize(seq)), 0))

In [83]:
# trt side: [1:-1] drops the first and last ids (presumably BOS/EOS added by the trt
# tokenizer) to match the nsc input. trt modules take sequence-first input, hence the
# (seq, 1) transpose.
ipt_ids = torch.tensor([trt_tok.encode(seq).ids[1:-1]])
trt_token_emb = tok_rep.model.encoder.embedding.embedding(ipt_ids.t())
trt_pos_emb = tok_rep.model.encoder.embedding.pos_embedding(ipt_ids.t())
trt_emb = tok_rep.model.encoder.embedding(ipt_ids.t())

In [84]:
# nsc embeddings here are (seq, dim); trt outputs are sequence-first with a batch axis of size 1.
(nsc_token_emb.shape, trt_pos_emb.shape)

(torch.Size([11, 512]), torch.Size([11, 1, 512]))

In [85]:
# Rebuild the full embedding by hand. trt token embeddings are scaled by
# sqrt(hidden_dim) before the positional embedding is added. The nsc embedding module
# applies that scaling itself, as the next cell confirms.
emb_scale = model_cfg.hidden_dim ** 0.5
trt_all_emb = tok_rep.model.encoder.embedding.norm(trt_token_emb * emb_scale + trt_pos_emb)
nsc_all_emb = nsc_model.embedding.norm(nsc_token_emb + nsc_pos_emb)
assert torch.allclose(trt_all_emb[:, 0, :], nsc_all_emb)

In [86]:
# Component-wise equivalence: scaled token embeddings, positional embeddings and the
# complete embedding output. Asserted (not just displayed) so Run-All stops on a mismatch.
embedding_checks = (
    torch.allclose(nsc_token_emb, trt_token_emb[:, 0, :] * model_cfg.hidden_dim ** 0.5),
    torch.allclose(nsc_pos_emb, trt_pos_emb[:, 0, :]),
    torch.allclose(nsc_emb[0], trt_emb[:, 0, :]),
)
assert all(embedding_checks), embedding_checks
embedding_checks

(True, True, True)

In [87]:
# Run both encoders on their matching embeddings. This is an inference-only
# comparison, so skip building autograd graphs.
with torch.no_grad():
    nsc_enc = nsc_model.encoder(nsc_emb)
    trt_enc = tok_rep.model.encoder.encoder(trt_emb)

In [88]:
# nsc encoder output (batch-first, index 0) must match trt (sequence-first, batch index 0).
assert torch.allclose(nsc_enc[0], trt_enc[:, 0, :], atol=1e-6), (
    f"encoder outputs differ (max abs diff {(nsc_enc[0] - trt_enc[:, 0, :]).abs().max().item():.3e})"
)

In [89]:
# Write the ported model as an nsc experiment directory containing
# checkpoints/checkpoint_best.pt and a pickled (dummy_train_cfg, {}) tuple in cfg.pkl.
ported_dir = os.path.join(EXPERIMENT_DIR, "TOKENIZATION_REPAIR", "eo_large_arxiv_with_errors_ported")
os.makedirs(os.path.join(ported_dir, "checkpoints"), exist_ok=True)
# NOTE(review): the two trailing 0s are presumably step/epoch counters — confirm
# against io.save_checkpoint.
io.save_checkpoint(os.path.join(ported_dir, "checkpoints", "checkpoint_best.pt"), nsc_model, 0, 0)
with open(os.path.join(ported_dir, "cfg.pkl"), "wb") as pf:
    pickle.dump((dummy_train_cfg, {}), pf)

### EO small (3-layer transformer encoder)

In [90]:
# nsc model mirroring the pretrained trt "eo_small" checkpoint (3-layer encoder).
# These hyperparameters must agree with the trt model for its state dicts to load below.
embedding_cfg = embedding.TensorEmbeddingConfig()
model_cfg = ModelForTokenClassificationConfig(
    type=Models.MODEL_FOR_TOKEN_CLASSIFICATION,
    tokenizer=char_tok_cfg,
    embedding=embedding_cfg,
    hidden_dim=512,
    feed_forward_dim=2048,
    num_layers=3,
    activation="gelu",
    norm=False,
    num_clf_layers=1,
    num_classes=3,
)
# Minimal two-token dummy batch handed to the constructor as sample_inputs.
sample_inputs = Batch([torch.tensor([0, 1])], {})
nsc_model = ModelForTokenClassification(
    sample_inputs=sample_inputs,
    cfg=model_cfg,
    device="cpu",
).eval()

In [91]:
# Placeholder training config saved next to the ported checkpoint (cfg.pkl). Only the
# variant, model and experiment name are set. Every other TrainConfig field keeps its
# default; '???' in the printout marks an OmegaConf MISSING value.
variant_cfg = variants.TokenizationRepairConfig(
    type=variants.DatasetVariants.TOKENIZATION_REPAIR,
    input_type="char",
    data_scheme="tensor",
    add_bos_eos=True,
)
dummy_train_cfg = OmegaConf.structured(
    config.TrainConfig(
        model=model_cfg,
        variant=variant_cfg,
        experiment_name="eo_small_arxiv_with_errors",
    )
)
print(dummy_train_cfg)

{'model': {'type': <Models.MODEL_FOR_TOKEN_CLASSIFICATION: 5>, 'max_length': 512, 'hidden_dim': 512, 'tokenizer': {'type': <Tokenizers.CHAR: 1>, 'file_path': None}, 'embedding': {'learned_position_embedding': False, 'embed_positions': True, 'dropout': 0.1}, 'dropout': 0.1, 'num_layers': 3, 'feed_forward_dim': 2048, 'norm': False, 'activation': 'gelu', 'num_clf_layers': 1, 'num_classes': 3}, 'variant': {'type': <DatasetVariants.TOKENIZATION_REPAIR: 3>, 'data_scheme': 'tensor', 'input_type': 'char', 'add_bos_eos': True}, 'experiment_dir': '???', 'experiment_name': 'eo_small_arxiv_with_errors', 'data_dir': '???', 'datasets': '???', 'dataset_limits': '???', 'val_splits': '???', 'epochs': 20, 'batch_size': 64, 'batch_max_length': None, 'bucket_span': None, 'optimizer': '???', 'lr_scheduler': None, 'log_per_epoch': 100, 'eval_per_epoch': 4, 'keep_last_n_checkpoints': 0, 'seed': 22, 'num_workers': None, 'pin_memory': True, 'mixed_precision': True, 'start_from_checkpoint': None, 'exponential_m

In [92]:
# Pretrained trt model to port (the log shows it is read from the local trt cache).
tok_rep = TokenizationRepairer.from_pretrained(
    "eo_small_arxiv_with_errors",
    device="cpu",
)

2022-05-10 22:06:45,270 [DOWNLOAD] [INFO] model eo_small_arxiv_with_errors was already downloaded to cache directory /home/sebastian/anaconda3/envs/masters_thesis/lib/python3.8/site-packages/trt/api/.cache
2022-05-10 22:06:45,292 [TOKENIZATION_REPAIR] [INFO] running tokenization repair on device Intel(R) Core(TM) i7-9750H CPU @ 2.60GHz


In [93]:
# Every character of the nsc char vocabulary must exist in the trt tokenizer,
# otherwise its embedding row cannot be taken from the trt model below.
trt_tok = tok_rep.model.encoder.tokenizer
trt_chars = [trt_tok.id_to_token(i) for i in range(trt_tok.get_vocab_size())]
missing_chars = set(char_tok.vocab) - set(trt_chars)
assert not missing_chars, f"chars missing from trt vocab: {missing_chars!r}"

In [94]:
# Tokens that exist only in the trt vocabulary and therefore get no nsc embedding row.
set(trt_tok.get_vocab()).difference(char_tok.vocab)

{'<mask>', '<sep>'}

In [95]:
# Freshly initialised (pre-port) nsc embedding weights, shown for comparison.
print(nsc_model.embedding.embedding.emb.weight)
# Row i of the new matrix is the trt embedding of the char that char_tok maps to id i.
new_embedding_weight = torch.stack(
    [
        tok_rep.model.encoder.embedding.embedding.weight[
            trt_tok.token_to_id(char_tok.id_to_token(nsc_id))
        ]
        for nsc_id in range(char_tok.vocab_size)
    ]
)
print(new_embedding_weight.shape, new_embedding_weight)
# Looking up any char through char_tok must give exactly the trt vector for that char.
assert all(
    torch.equal(
        new_embedding_weight[char_tok.token_to_id(ch)],
        tok_rep.model.encoder.embedding.embedding.weight[trt_tok.token_to_id(ch)],
    )
    for ch in char_tok.vocab
)

Parameter containing:
tensor([[ 0.0407, -0.0709, -0.0296,  ...,  0.0186, -0.0709,  0.0200],
        [ 0.0227,  0.0843,  0.0452,  ...,  0.0493, -0.0047,  0.0124],
        [ 0.0359, -0.0234, -0.0283,  ...,  0.0187,  0.0263, -0.1021],
        ...,
        [-0.0265, -0.0368,  0.0386,  ...,  0.0050,  0.0590, -0.0624],
        [ 0.0355, -0.0283,  0.0601,  ..., -0.0126, -0.1097, -0.0459],
        [-0.0817, -0.0882,  0.0697,  ...,  0.0685,  0.0540, -0.0507]],
       requires_grad=True)
torch.Size([99, 512]) tensor([[ 0.0108,  0.0098,  0.0099,  ..., -0.0142, -0.1065, -0.0008],
        [ 0.0157, -0.0226,  0.0119,  ..., -0.0181, -0.0790, -0.0182],
        [ 0.0054,  0.0153,  0.0057,  ..., -0.0999,  0.0133, -0.0816],
        ...,
        [ 0.0104,  0.0077,  0.0080,  ...,  0.0667, -0.0535,  0.0512],
        [ 0.0111,  0.0100,  0.0081,  ..., -0.1114,  0.0735, -0.0356],
        [ 0.0124,  0.0126,  0.0116,  ..., -0.0277, -0.0135, -0.0464]])


In [96]:
# Copy the pretrained weights into the nsc model.
# copy_ writes into the existing parameter storage and raises on a shape mismatch.
# Rebinding `.data` would silently accept any shape and alias new_embedding_weight.
with torch.no_grad():
    nsc_model.embedding.embedding.emb.weight.copy_(new_embedding_weight)
nsc_model.embedding.norm.load_state_dict(tok_rep.model.encoder.embedding.norm.state_dict())
nsc_model.encoder.encoder.load_state_dict(tok_rep.model.encoder.encoder.state_dict())
nsc_model.head.clf.load_state_dict(tok_rep.model.head.head.state_dict())

<All keys matched successfully>

In [97]:
# Toy input for the equivalence checks below. positions is derived from seq so the
# two stay in sync if the example string is changed.
seq = "thisisatest"
positions = torch.arange(len(seq))

In [98]:
# nsc side: token and positional embeddings without a batch axis, plus the full
# embedding module applied to a batch of one (batch axis first).
nsc_token_emb = nsc_model.embedding.embedding(torch.tensor(char_tok.tokenize(seq)))
nsc_pos_emb = nsc_model.embedding.pos_emb(positions)
nsc_emb = nsc_model.embedding(torch.unsqueeze(torch.tensor(char_tok.tokenize(seq)), 0))

In [99]:
# trt side: [1:-1] drops the first and last ids (presumably BOS/EOS added by the trt
# tokenizer) to match the nsc input. trt modules take sequence-first input, hence the
# (seq, 1) transpose.
ipt_ids = torch.tensor([trt_tok.encode(seq).ids[1:-1]])
trt_token_emb = tok_rep.model.encoder.embedding.embedding(ipt_ids.t())
trt_pos_emb = tok_rep.model.encoder.embedding.pos_embedding(ipt_ids.t())
trt_emb = tok_rep.model.encoder.embedding(ipt_ids.t())

In [100]:
# nsc embeddings here are (seq, dim); trt outputs are sequence-first with a batch axis of size 1.
(nsc_token_emb.shape, trt_pos_emb.shape)

(torch.Size([11, 512]), torch.Size([11, 1, 512]))

In [101]:
# Rebuild the full embedding by hand. trt token embeddings are scaled by
# sqrt(hidden_dim) before the positional embedding is added. The nsc embedding module
# applies that scaling itself, as the next cell confirms.
emb_scale = model_cfg.hidden_dim ** 0.5
trt_all_emb = tok_rep.model.encoder.embedding.norm(trt_token_emb * emb_scale + trt_pos_emb)
nsc_all_emb = nsc_model.embedding.norm(nsc_token_emb + nsc_pos_emb)
assert torch.allclose(trt_all_emb[:, 0, :], nsc_all_emb)

In [102]:
# Component-wise equivalence: scaled token embeddings, positional embeddings and the
# complete embedding output. Asserted (not just displayed) so Run-All stops on a mismatch.
embedding_checks = (
    torch.allclose(nsc_token_emb, trt_token_emb[:, 0, :] * model_cfg.hidden_dim ** 0.5),
    torch.allclose(nsc_pos_emb, trt_pos_emb[:, 0, :]),
    torch.allclose(nsc_emb[0], trt_emb[:, 0, :]),
)
assert all(embedding_checks), embedding_checks
embedding_checks

(True, True, True)

In [103]:
# Run both encoders on their matching embeddings. This is an inference-only
# comparison, so skip building autograd graphs.
with torch.no_grad():
    nsc_enc = nsc_model.encoder(nsc_emb)
    trt_enc = tok_rep.model.encoder.encoder(trt_emb)

In [104]:
# nsc encoder output (batch-first, index 0) must match trt (sequence-first, batch index 0).
assert torch.allclose(nsc_enc[0], trt_enc[:, 0, :], atol=1e-6), (
    f"encoder outputs differ (max abs diff {(nsc_enc[0] - trt_enc[:, 0, :]).abs().max().item():.3e})"
)

In [105]:
# Write the ported model as an nsc experiment directory containing
# checkpoints/checkpoint_best.pt and a pickled (dummy_train_cfg, {}) tuple in cfg.pkl.
ported_dir = os.path.join(EXPERIMENT_DIR, "TOKENIZATION_REPAIR", "eo_small_arxiv_with_errors_ported")
os.makedirs(os.path.join(ported_dir, "checkpoints"), exist_ok=True)
# NOTE(review): the two trailing 0s are presumably step/epoch counters — confirm
# against io.save_checkpoint.
io.save_checkpoint(os.path.join(ported_dir, "checkpoints", "checkpoint_best.pt"), nsc_model, 0, 0)
with open(os.path.join(ported_dir, "cfg.pkl"), "wb") as pf:
    pickle.dump((dummy_train_cfg, {}), pf)