In [None]:
#install SONAR - will be prompted to restart environment (wait until cell execution is complete)
# PyTorch + torchvision wheels built for CUDA 12.4.
# NOTE(review): torch/torchvision are unpinned here, while the fairseq2 wheel index below
# targets pt2.5.1/cu124 — consider pinning torch==2.5.1 so the two stay compatible.
!pip install torch torchvision --index-url https://download.pytorch.org/whl/cu124
# fairseq2 release candidate, pulled from the wheel index built for torch 2.5.1 / CUDA 12.4.
!pip install fairseq2==0.3.0rc1 --extra-index-url https://fair.pkg.atmeta.com/fairseq2/whl/rc/pt2.5.1/cu124
# SONAR package (provides the `sonar.inference_pipelines` module imported later).
!pip install sonar-space==0.3.2

In [None]:
# Clone EASSE; the ASSET validation/test files are read from this checkout
# (/content/easse/easse/resources/data/test_sets/asset/) when ASSET is loaded.
! git clone https://github.com/feralvam/easse.git

In [None]:
#Wiki Auto - simplification
# Download the Wiki-Auto (ACL2020) training pairs into the working directory:
# train.dst is read as the simplified side, train.src as the complex side.
# NOTE(review): the loading cell reads these from /content/, which assumes the
# working directory is /content (the Colab default) — confirm when running elsewhere.
! wget https://raw.githubusercontent.com/chaojiang06/wiki-auto/refs/heads/master/wiki-auto/ACL2020/train.dst
! wget https://raw.githubusercontent.com/chaojiang06/wiki-auto/refs/heads/master/wiki-auto/ACL2020/train.src

In [None]:
import pickle

In [None]:
# Registry of every parallel corpus used in this notebook, keyed by dataset name.
# Each value is a dict of the form {"src": [complex sentences], "tgt": [simplified sentences]}.
all_sentences = dict()

In [None]:
#Load Asset
# ASSET ships one file of original sentences per split plus ASSET_NUM_REFERENCES
# reference simplifications (asset.<split>.simp.0 ... asset.<split>.simp.9). Each
# reference file becomes its own all_sentences entry whose "src" is the split's
# original sentences (the same list object is shared by all entries of a split).
asset_path = "/content/easse/easse/resources/data/test_sets/asset/"
ASSET_NUM_REFERENCES = 10


def _read_lines(path):
    """Read a UTF-8 text file into a list of lines (trailing newlines kept, like readlines()).

    The context manager closes the handle immediately instead of leaving it to
    the garbage collector.
    """
    with open(path, "r", encoding="utf-8") as fh:
        return fh.readlines()


asset_original_val_path = asset_path + "asset.valid.orig"
asset_original_val_sentences = _read_lines(asset_original_val_path)

asset_original_test_path = asset_path + "asset.test.orig"
asset_original_test_sentences = _read_lines(asset_original_test_path)

for split, originals in (("valid", asset_original_val_sentences),
                         ("test", asset_original_test_sentences)):
    for i in range(ASSET_NUM_REFERENCES):
        name = "asset.%s.simp.%d" % (split, i)
        simplified = _read_lines(asset_path + name)
        # Sentence pairs are matched by line number, so every reference file
        # must have exactly as many lines as the originals.
        assert len(simplified) == len(originals), (
            "%s has %d lines but asset.%s.orig has %d"
            % (name, len(simplified), split, len(originals))
        )
        all_sentences[name] = {"src": originals, "tgt": simplified}



In [None]:
#Wiki auto import
# train.src holds the complex sentences and train.dst the simplified ones,
# aligned line by line (both downloaded by the wget cell above).
with open("/content/train.src", "r", encoding="utf-8") as complex_file:
    wiki_auto_complex = complex_file.readlines()
with open("/content/train.dst", "r", encoding="utf-8") as simple_file:
    wiki_auto_simple = simple_file.readlines()

# Pairs are matched by position, and both sides are later split by position into
# validation/training, so the files must have the same length. A mismatch most
# likely means a truncated or failed download.
assert len(wiki_auto_complex) == len(wiki_auto_simple), (
    "train.src has %d lines but train.dst has %d"
    % (len(wiki_auto_complex), len(wiki_auto_simple))
)

all_sentences['wiki_auto'] = {"src": wiki_auto_complex, "tgt": wiki_auto_simple}

In [None]:
import torch

# Use the GPU when one is visible, otherwise fall back to the CPU.
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Encoding is inference only; autograd stays off until the training cell re-enables it.
torch.set_grad_enabled(False)
print(DEVICE)

In [None]:
#set up SONAR models - TextToEmbeddingModelPipeline for encoding and EmbeddingToTextModelPipeline for decoding
from sonar.inference_pipelines.text import TextToEmbeddingModelPipeline
from sonar.inference_pipelines.text import EmbeddingToTextModelPipeline


# Load the pretrained SONAR pipelines on DEVICE:
#   text2vec - sentence -> embedding (encoder)
#   vec2text - embedding -> sentence (decoder)
# Note that both pipelines are given the *encoder's* tokenizer name.
text2vec = TextToEmbeddingModelPipeline(
    device=DEVICE,
    encoder="text_sonar_basic_encoder",
    tokenizer="text_sonar_basic_encoder",
)
vec2text = EmbeddingToTextModelPipeline(
    device=DEVICE,
    decoder="text_sonar_basic_decoder",
    tokenizer="text_sonar_basic_encoder",
)

In [None]:
# Encode sentences with SONAR (the fixed encoder f(x)).
# Resulting `embeddings` keys:
#   wauto_{comp,simp}_val   - first val_count Wiki-Auto pairs (validation)
#   asset_{comp,simp}_train - ASSET validation split, reference 0 (training)
#   wauto_{comp,simp}_train - remaining Wiki-Auto pairs (training)
#   asset_{comp,simp}_test  - ASSET test split, reference 0 (test)
b_size = 64
val_count = 2000


def _sonar_encode(sentences):
    """Embed English sentences with the SONAR text encoder, using the same settings for every split."""
    return text2vec.predict(
        sentences,
        source_lang="eng_Latn",
        max_seq_len=128,
        progress_bar=True,
        batch_size=b_size,
    )


_wiki_auto = all_sentences['wiki_auto']
_asset_valid = all_sentences['asset.valid.simp.0']
_asset_test = all_sentences['asset.test.simp.0']

embeddings = {}

# val
embeddings['wauto_comp_val'] = _sonar_encode(_wiki_auto['src'][:val_count])
embeddings['wauto_simp_val'] = _sonar_encode(_wiki_auto['tgt'][:val_count])

# train
embeddings['asset_comp_train'] = _sonar_encode(_asset_valid['src'])
embeddings['asset_simp_train'] = _sonar_encode(_asset_valid['tgt'])
embeddings['wauto_comp_train'] = _sonar_encode(_wiki_auto['src'][val_count:])
embeddings['wauto_simp_train'] = _sonar_encode(_wiki_auto['tgt'][val_count:])

# test
embeddings['asset_comp_test'] = _sonar_encode(_asset_test['src'])
embeddings['asset_simp_test'] = _sonar_encode(_asset_test['tgt'])

In [None]:
#it may be useful to save the embeddings to disk, restart the environment to clear the GPU RAM and then reload them from the cell below.

#pickle.dump(embeddings, open('embeddings.pkl', 'wb'))

In [None]:
#embeddings = pickle.load(open('embeddings.pkl', 'rb'))

In [None]:
# Stack the ASSET and Wiki-Auto training pairs row-wise (ASSET rows first, then Wiki-Auto).
src_train = torch.cat([embeddings['asset_comp_train'], embeddings['wauto_comp_train']])
tgt_train = torch.cat([embeddings['asset_simp_train'], embeddings['wauto_simp_train']])

In [None]:
#Set up NN - g(x)
#basic feed-forward neural network with ADAM optimiser
import torch.nn as nn

class SimpleFeedForward(nn.Module):
    """One-hidden-layer MLP: Linear(input_dim -> hidden_dim) -> ReLU -> Linear(hidden_dim -> output_dim).

    Serves as the trainable mapping g(x) applied to sentence embeddings.

    Args:
        input_dim: size of the last dimension of the input tensor.
        hidden_dim: width of the hidden layer.
        output_dim: size of the last dimension of the output tensor.
    """

    def __init__(self, input_dim, hidden_dim, output_dim):
        super().__init__()
        # Attribute names (and their registration order) are kept stable: they
        # become the state_dict keys and are baked into whole-model checkpoints.
        self.input_layer = nn.Linear(input_dim, hidden_dim)
        self.relu1 = nn.ReLU()
        self.output_layer = nn.Linear(hidden_dim, output_dim)

    def forward(self, x):
        """Map x of shape (..., input_dim) to shape (..., output_dim)."""
        hidden = self.relu1(self.input_layer(x))
        return self.output_layer(hidden)

# Input/output widths of g(x) are the embedding widths of the training tensors.
# (The models themselves are instantiated later, in the training cell.)
input_dim = src_train.shape[1]
output_dim = tgt_train.shape[1]



In [None]:
# Validation pairs: the first val_count Wiki-Auto sentences, which were held out of training.
src_val, tgt_val = embeddings['wauto_comp_val'], embeddings['wauto_simp_val']

In [None]:
# again - if needed, you can dump the training data to disk here, restart the environment to free GPU memory, and then reload it from the cell below

#pickle.dump(src_train, open('src_train.pkl', 'wb'))
#pickle.dump(tgt_train, open('tgt_train.pkl', 'wb'))
#pickle.dump(src_val, open('src_val.pkl', 'wb'))
#pickle.dump(tgt_val, open('tgt_val.pkl', 'wb'))

In [None]:
#src_train = pickle.load(open('src_train.pkl', 'rb'))
#tgt_train = pickle.load(open('tgt_train.pkl', 'rb'))
#src_val = pickle.load(open('src_val.pkl', 'rb'))
#tgt_val = pickle.load(open('tgt_val.pkl', 'rb'))

In [None]:
#run the training loop (AI generated initial copy - caveat emptor)
import torch.optim as optim

# Define the loss function and optimizer
criterion = nn.MSELoss()  # Mean Squared Error

# NOTE(review): SONAR's predict() may hand back inference-mode tensors, which autograd
# cannot save for backward; .clone() outside inference mode turns them into normal
# tensors. Tensors reloaded from disk with pickle are already normal tensors.
#src_train = src_train.clone()
#tgt_train = tgt_train.clone()


# Enable gradient calculation for training (train_loop also enables it locally,
# so it keeps working if a later cell switches autograd off again).
torch.set_grad_enabled(True)


def _evaluate(model, src, tgt):
    """Return criterion(model(src), tgt) as a Python float.

    Runs under torch.no_grad() so no autograd graph is built for what may be a
    very large evaluation batch.
    """
    with torch.no_grad():
        return criterion(model(src), tgt).item()


def train_loop(id, model, src, tgt, val_src, val_tgt, lr=0.001, epochs=5000,
               eval_every=50, patience=250):
    """Full-batch Adam training of `model` to map `src` onto `tgt` under `criterion` (MSE).

    Every `eval_every` epochs the validation loss is computed, printed and
    appended to "<id>.log" as "epoch,train_loss,val_loss". Epoch 0 is the
    untrained model. The whole model is saved to "best_model<id>.pt" at the
    start, and again whenever the validation loss improves on the best so far.
    Training stops early once more than `patience` epochs have passed since
    the last improvement.

    Args:
        id: run name used in the printed header and in the log/checkpoint file names.
        model: nn.Module to train (updated in place).
        src, tgt: training inputs and targets (the whole set is one batch).
        val_src, val_tgt: validation inputs and targets.
        lr: Adam learning rate.
        epochs: maximum number of epochs.
        eval_every: validation/logging interval in epochs (must be >= 1).
        patience: epochs without validation improvement before stopping.

    Raises:
        ValueError: if eval_every < 1.
    """
    if eval_every < 1:
        raise ValueError("eval_every must be >= 1, got %r" % (eval_every,))

    print(id)
    optimizer = optim.Adam(model.parameters(), lr=lr)
    best_model_path = "best_model%s.pt" % id
    num_epochs = epochs

    # enable_grad() makes training independent of the global grad mode; the
    # previous mode is restored when the block exits. The `with` also closes
    # the log file if training raises.
    with open("%s.log" % id, "w") as log_file, torch.enable_grad():
        model.eval()
        initial_train_loss = _evaluate(model, src, tgt)
        print("initial train loss:", initial_train_loss)
        initial_val_loss = _evaluate(model, val_src, val_tgt)
        print("initial val loss:", initial_val_loss)
        log_file.write("%d,%.10f,%.10f\n" % (0, initial_train_loss, initial_val_loss))

        best_val_loss = initial_val_loss
        best_epoch = 0
        # Checkpoint the starting point so best_model<id>.pt always belongs to this
        # run, even if validation loss never improves. Otherwise a stale file from
        # an earlier run with the same id would survive.
        torch.save(model, best_model_path)

        model.train()
        for epoch in range(num_epochs):
            loss = criterion(model(src), tgt)

            # Backward and optimize
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            if (epoch + 1) % eval_every != 0:
                continue

            model.eval()
            # Training loss of this epoch, measured before the optimizer step.
            train_loss = loss.item()
            val_loss = _evaluate(model, val_src, val_tgt)
            print(f'Epoch [{epoch+1}/{num_epochs}], Train Loss: {train_loss}, Val Loss: {val_loss}', end='')
            log_file.write("%d,%.10f,%.10f\n" % (epoch + 1, train_loss, val_loss))
            if val_loss < best_val_loss:
                print("*", end='')
                best_val_loss = val_loss
                best_epoch = epoch
                torch.save(model, best_model_path)
            elif epoch - best_epoch > patience:
                print(" Early stopping triggered.")
                break
            model.train()
            print()

    print("Training finished.")


#run training for each model
# Hidden widths to train; the full sweep used previously was [256, 512, 1024, 2048, 4096].
HIDDEN_SIZES = [4096]
# Seed for weight initialisation, so re-running the notebook reproduces the same starting weights.
SEED = 0

try:
    for k in HIDDEN_SIZES:
        # Re-seed before building each model so every width starts from a reproducible init.
        torch.manual_seed(SEED)
        model = SimpleFeedForward(input_dim, k, output_dim).to(DEVICE)
        print(model)
        pytorch_total_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
        print(pytorch_total_params)
        train_loop("ASSET - %d" % k, model, src_train, tgt_train, src_val, tgt_val, epochs=10000)
finally:
    # Disable gradient calculation again for any inference that follows,
    # even if training raised or was interrupted.
    torch.set_grad_enabled(False)