In [None]:
#install SONAR - will be prompted to restart environment (wait until cell execution is complete)
!pip install torch torchvision --index-url https://download.pytorch.org/whl/cu124
!pip install fairseq2==0.3.0rc1 --extra-index-url https://fair.pkg.atmeta.com/fairseq2/whl/rc/pt2.5.1/cu124
!pip install sonar-space==0.3.2

In [None]:
! git clone https://github.com/feralvam/easse.git

In [None]:
#Wiki Auto - simplification
! wget https://raw.githubusercontent.com/chaojiang06/wiki-auto/refs/heads/master/wiki-auto/ACL2020/train.dst
! wget https://raw.githubusercontent.com/chaojiang06/wiki-auto/refs/heads/master/wiki-auto/ACL2020/train.src

In [None]:
# Registry of every loaded dataset, keyed by dataset name.
# Each value is a dict of two parallel sentence lists: {"src": [...], "tgt": [...]}.
all_sentences: dict = dict()

In [None]:
# Load ASSET from the EASSE checkout. Each split ("valid", "test") has one file of
# original sentences (asset.<split>.orig) and ASSET_NUM_REFERENCES reference
# simplifications (asset.<split>.simp.<i>). Every reference file becomes its own entry
# in `all_sentences`, paired with the split's (shared) original sentences.
asset_path = "/content/easse/easse/resources/data/test_sets/asset/"
ASSET_NUM_REFERENCES = 10


def _read_lines(path):
    """Return every line of a UTF-8 text file, trailing newlines kept (like readlines()).

    The `with` block guarantees the file handle is closed.
    """
    with open(path, "r", encoding="utf-8") as fh:
        return fh.readlines()


def _load_asset_split(split):
    """Register all ASSET references of `split` in `all_sentences`.

    Args:
        split: split name used in the ASSET file names, e.g. "valid" or "test".

    Returns:
        The list of original (source) sentences for that split.
    """
    originals = _read_lines(asset_path + "asset." + split + ".orig")
    for i in range(ASSET_NUM_REFERENCES):
        name = "asset." + split + ".simp." + str(i)
        all_sentences[name] = {"src": originals, "tgt": _read_lines(asset_path + name)}
    return originals


asset_original_val_sentences = _load_asset_split("valid")
asset_original_test_sentences = _load_asset_split("test")



In [None]:
# Wiki-Auto import: train.src holds the complex sentences, train.dst the simplified ones,
# line i of one file paired with line i of the other. Trailing newlines are kept
# (readlines()), matching how the ASSET sentences are loaded.
with open("/content/train.src", "r", encoding="utf-8") as fh:
    wiki_auto_complex = fh.readlines()
with open("/content/train.dst", "r", encoding="utf-8") as fh:
    wiki_auto_simple = fh.readlines()

# Pairs are matched purely by line position, so a length mismatch would silently misalign them.
assert len(wiki_auto_complex) == len(wiki_auto_simple), (
    f"Wiki-Auto files are not parallel: {len(wiki_auto_complex)} src vs "
    f"{len(wiki_auto_simple)} tgt lines"
)

all_sentences['wiki_auto'] = {"src": wiki_auto_complex, "tgt": wiki_auto_simple}

In [None]:
# Set up SONAR: TextToEmbeddingModelPipeline turns sentences into embeddings,
# EmbeddingToTextModelPipeline turns embeddings back into sentences.
import torch
from sonar.inference_pipelines.text import (
    EmbeddingToTextModelPipeline,
    TextToEmbeddingModelPipeline,
)

# Inference only: switch autograd off globally for the rest of the session.
torch.set_grad_enabled(False)

# Prefer the GPU when one is visible, otherwise run on CPU.
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(DEVICE)

# Load models. The decoder pipeline is given the encoder's tokenizer name,
# the same pairing used in SONAR's usage examples.
SONAR_ENCODER = "text_sonar_basic_encoder"
text2vec = TextToEmbeddingModelPipeline(
    encoder=SONAR_ENCODER,
    tokenizer=SONAR_ENCODER,
    device=DEVICE,
)
vec2text = EmbeddingToTextModelPipeline(
    decoder="text_sonar_basic_decoder",
    tokenizer=SONAR_ENCODER,
    device=DEVICE,
)

In [None]:
# Experiment 1: encode sentences with SONAR, then decode the embeddings back to text
# (reconstruction), for both sides of ASSET and of a Wiki-Auto sample.
# NOTE: despite the "_train" suffix, the asset_* keys hold the ASSET *validation* split
# (originals + reference 0). The key names are kept so previously saved pickles and
# downstream code stay compatible.
b_size = 64                # batch size for encoding and decoding
MAX_SEQ_LEN = 128          # max_seq_len passed to both SONAR pipelines
SONAR_LANG = "eng_Latn"    # source/target language code passed to SONAR
WIKI_AUTO_SAMPLES = 2000   # only the first N Wiki-Auto pairs are processed
LEN_PENALTY = 0.8          # len_penalty passed to the decoder

# result key -> sentences to encode (dict order fixes the processing order)
experiment1_inputs = {
    'asset_comp_train': all_sentences['asset.valid.simp.0']['src'],
    'asset_simp_train': all_sentences['asset.valid.simp.0']['tgt'],
    'wauto_comp_train': all_sentences['wiki_auto']['src'][:WIKI_AUTO_SAMPLES],
    'wauto_simp_train': all_sentences['wiki_auto']['tgt'][:WIKI_AUTO_SAMPLES],
}

# Encode every input set first...
embeddings = {
    key: text2vec.predict(
        sentences,
        source_lang=SONAR_LANG,
        max_seq_len=MAX_SEQ_LEN,
        progress_bar=True,
        batch_size=b_size,
    )
    for key, sentences in experiment1_inputs.items()
}

# ...then decode each embedding set back into text under the same key.
reconstruction = {
    key: vec2text.predict(
        embedded,
        target_lang=SONAR_LANG,
        max_seq_len=MAX_SEQ_LEN,
        progress_bar=True,
        batch_size=b_size,
        len_penalty=LEN_PENALTY,
    )
    for key, embedded in embeddings.items()
}

In [None]:
import pickle

In [None]:
# Persist the inputs, embeddings and reconstructions of Experiment 1.
# Embedding tensors are moved to CPU first: tensors pickled while on a CUDA device
# cannot be unpickled on a machine without a GPU.
embeddings_cpu = {key: tensor.cpu() for key, tensor in embeddings.items()}

for filename, payload in (
    ("all_sentences.pkl", all_sentences),
    ("embeddings.pkl", embeddings_cpu),
    ("reconstruction.pkl", reconstruction),
):
    # `with` closes (and flushes) each file; the original left the handles open.
    with open(filename, "wb") as fh:
        pickle.dump(payload, fh)