In [None]:
#install SONAR - will be prompted to restart environment (wait until cell execution is complete)
!pip install torch torchvision --index-url https://download.pytorch.org/whl/cu124
!pip install fairseq2==0.3.0rc1 --extra-index-url https://fair.pkg.atmeta.com/fairseq2/whl/rc/pt2.5.1/cu124
!pip install sonar-space==0.3.2

In [None]:
#set up SONAR models - TextToEmbeddingModelPipeline for encoding and EmbeddingToTextModelPipeline for decoding
import torch
from sonar.inference_pipelines.text import TextToEmbeddingModelPipeline
from sonar.inference_pipelines.text import EmbeddingToTextModelPipeline

# Prefer the GPU when one is visible; otherwise run everything on the CPU.
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# The notebook only runs inference, so autograd is switched off for the session.
torch.set_grad_enabled(False)
print(DEVICE)

# load models
# text2vec: sentences -> SONAR embeddings; vec2text: SONAR embeddings -> sentences.
# Both pipelines are given the tokenizer name "text_sonar_basic_encoder" and are
# placed on DEVICE.
text2vec = TextToEmbeddingModelPipeline(
    encoder="text_sonar_basic_encoder",
    tokenizer="text_sonar_basic_encoder",
    device=DEVICE,
)
vec2text = EmbeddingToTextModelPipeline(
    decoder="text_sonar_basic_decoder",
    tokenizer="text_sonar_basic_encoder",
    device=DEVICE,
)

In [None]:
! git clone https://github.com/feralvam/easse.git

In [None]:
# ASSET: source sentences of the test set shipped inside the EASSE repo.
# The context manager closes the file handle; the encoding is explicit rather than
# locale-dependent. Trailing newlines are dropped so the encoder receives the bare
# sentence (the preview loop already calls .strip() on these strings).
with open('/content/easse/easse/resources/data/test_sets/asset/asset.test.orig', 'r', encoding='utf-8') as asset_file:
    asset_src = [line.rstrip('\n') for line in asset_file]

In [None]:
! git clone https://github.com/ZurichNLP/BLESS.git

In [None]:
#MEDEASI
import json

In [None]:
medEASi_src = open('/content/BLESS/model_outputs_and_evals/ground_truth/med-easi-test.jsonl','r').readlines()

In [None]:
medEASi_src = [json.loads(line)['source'] for line in medEASi_src]

In [None]:
#DEPLAIN - DE
! git clone https://github.com/rstodden/DEPlain.git

In [None]:
import csv
# DEPlain-web (manual, open) test split: keep column 0 of every data row,
# presumably the original (complex) German sentence.
# newline='' is what the csv module requires so line breaks inside quoted fields
# are parsed correctly; UTF-8 is explicit so umlauts do not depend on the locale.
with open('/content/DEPlain/E__Sentence-level_Corpus/DEplain-web-sent/manual/open/test.csv', newline='', encoding='utf-8') as csvfile:
    reader = csv.reader(csvfile)
    next(reader)  # skip the first row (presumably the header)
    deplain_src = [row[0] for row in reader]

In [None]:
#CLARAMED - ES
! wget https://digital.csic.es/bitstream/10261/346579/1/claramed_synt_simp_aligned.tsv

In [None]:
# CLARAMed aligned TSV: keep column 1 of every data row, presumably the original
# (complex) Spanish sentence.
# newline='' is what the csv module requires so line breaks inside quoted fields
# are parsed correctly; UTF-8 is explicit so accents do not depend on the locale.
with open('/content/claramed_synt_simp_aligned.tsv', newline='', encoding='utf-8') as csvfile:
    reader = csv.reader(csvfile, delimiter="\t")
    next(reader)  # skip the first row (presumably the header)
    claramed_src = [row[1] for row in reader]

In [None]:
import torch
import torch.nn as nn

In [None]:
class SimpleFeedForward(nn.Module):
    """One-hidden-layer MLP: Linear(input_dim -> hidden_dim) -> ReLU -> Linear(hidden_dim -> output_dim).

    The submodule attribute names (input_layer, relu1, output_layer) determine the
    state_dict keys and the structure of a pickled instance, so a checkpoint loaded
    with torch.load depends on them - do not rename them.
    """

    def __init__(self, input_dim, hidden_dim, output_dim):
        super().__init__()
        self.input_layer = nn.Linear(input_dim, hidden_dim)
        self.relu1 = nn.ReLU()
        self.output_layer = nn.Linear(hidden_dim, output_dim)

    def forward(self, x):
        """Apply the MLP along the last dimension of ``x`` (size input_dim -> output_dim)."""
        return self.output_layer(self.relu1(self.input_layer(x)))

In [None]:
# get this after running training for part 3 (or use provided model)
# The checkpoint is presumably a whole pickled SimpleFeedForward module (it is called
# directly in simplify_sentences), so the class must be defined before this cell.
# map_location=DEVICE places the weights on the same device the SONAR pipelines were
# given. This lets a GPU-saved checkpoint load on a CPU-only runtime, and keeps the
# weights on the same device as the embeddings they receive (presumably also DEVICE).
# weights_only=False is required to unpickle a full nn.Module on torch versions that
# default it to True. Only load checkpoints you trust: this executes pickle.
model = torch.load("/content/best_modelASSET - 4096.pt", map_location=DEVICE, weights_only=False)
# SimpleFeedForward has no dropout/batch-norm, but eval() makes inference mode explicit.
model.eval()

In [None]:
b_size = 16

def simplify_sentences(sentences, lang="eng_Latn", max_seq_len=128, batch_size=None):
    """Simplify sentences by transforming their SONAR embeddings with ``model``.

    Encodes ``sentences`` with ``text2vec``, maps the embeddings through the trained
    ``model``, and decodes the result with ``vec2text`` in the same language.

    Args:
        sentences: list of input sentences.
        lang: language code used for both encoding and decoding
            (e.g. "eng_Latn", "deu_Latn", "spa_Latn").
        max_seq_len: max_seq_len passed to both pipelines (default 128, as before).
        batch_size: pipeline batch size; ``None`` means the module-level ``b_size``,
            read at call time.

    Returns:
        The decoded simplifications, one per input sentence; ``[]`` for empty input.
    """
    if len(sentences) == 0:
        # Nothing to encode - avoid calling the pipelines with an empty batch.
        return []
    if batch_size is None:
        batch_size = b_size
    # Autograd is also disabled globally in the setup cell; this keeps the function
    # correct even if that global setting is changed later.
    with torch.no_grad():
        sentence_embeddings = text2vec.predict(sentences, source_lang=lang, max_seq_len=max_seq_len, progress_bar=True, batch_size=batch_size)
        simplified_embeddings = model(sentence_embeddings)
        simplified_texts = vec2text.predict(simplified_embeddings, target_lang=lang, max_seq_len=max_seq_len, progress_bar=True, batch_size=batch_size, len_penalty=1.0)
    return simplified_texts

In [None]:
# Simplify the English ASSET test sentences (language code spelled out explicitly).
asset_tgt = simplify_sentences(sentences=asset_src, lang="eng_Latn")


In [None]:
# Simplify the Med-EASi test sentences with the English language code.
medEASi_tgt = simplify_sentences(sentences=medEASi_src, lang="eng_Latn")

In [None]:
# Simplify the German DEPlain test sentences.
deplain_tgt = simplify_sentences(sentences=deplain_src, lang="deu_Latn")

In [None]:
# Simplify the Spanish CLARAMed sentences.
claramed_tgt = simplify_sentences(sentences=claramed_src, lang="spa_Latn")

In [None]:
# Preview: each ASSET source sentence followed by its simplification.
# zip over slices shows at most 100 pairs without an IndexError on smaller sets.
for src, tgt in zip(asset_src[:100], asset_tgt[:100]):
    print(src.strip())
    print(tgt.strip())
    print()

In [None]:
# Preview: each Med-EASi source sentence followed by its simplification.
# zip over slices shows at most 100 pairs without an IndexError on smaller sets.
for src, tgt in zip(medEASi_src[:100], medEASi_tgt[:100]):
    print(src.strip())
    print(tgt.strip())
    print()

In [None]:
# Preview: each DEPlain source sentence followed by its simplification.
# zip over slices shows at most 100 pairs without an IndexError on smaller sets.
for src, tgt in zip(deplain_src[:100], deplain_tgt[:100]):
    print(src.strip())
    print(tgt.strip())
    print()

In [None]:
# Preview: each CLARAMed source sentence followed by its simplification.
# zip over slices shows at most 100 pairs without an IndexError on smaller sets.
for src, tgt in zip(claramed_src[:100], claramed_tgt[:100]):
    print(src.strip())
    print(tgt.strip())
    print()

In [None]:
import pickle

In [None]:
# Save the ASSET simplifications; `with` guarantees the file is flushed and closed.
with open('asset_tgt.pkl', 'wb') as pkl_file:
    pickle.dump(asset_tgt, pkl_file)

In [None]:
# Save the Med-EASi simplifications; `with` guarantees the file is flushed and closed.
with open('MedEASi_tgt.pkl', 'wb') as pkl_file:
    pickle.dump(medEASi_tgt, pkl_file)

In [None]:
# Save the DEPlain simplifications; `with` guarantees the file is flushed and closed.
with open('DEPlain_tgt.pkl', 'wb') as pkl_file:
    pickle.dump(deplain_tgt, pkl_file)

In [None]:
# Save the CLARAMed simplifications; `with` guarantees the file is flushed and closed.
with open('CLARAMeD_tgt.pkl', 'wb') as pkl_file:
    pickle.dump(claramed_tgt, pkl_file)