In [1]:
import os

import pandas as pd
import torch

from CBHG import CBHGModel
from baseline import BaseLineModel
from dataset import DiacriticsDataset
from diacritizer import Diacritizer

In [2]:
# File locations used by the rest of the notebook.
dataset_dir = 'dataset'
output_dir = 'output'

# Raw (undiacritized) test text fed to the model.
test_dataset_path = f'{dataset_dir}/test_no_diacritics.txt'
# model_path = 'models/CBHG_EP20_BS256.pth'  # only needed by the commented-out load-from-checkpoint cell

# Per-character predicted labels, one CSV row per non-padding character.
output_csv_path = f'{output_dir}/labels.csv'
# Test text with the predicted diacritics applied.
test_dataset_diacritized_path = f'{output_dir}/diacritized.txt'

train_dataset_path = f'{dataset_dir}/train_val_test.txt'

In [None]:
# Load the training split; its letter/diacritic vocabularies size the model in the next cell.
# NOTE(review): this reads 'dataset/train.txt' directly rather than `train_dataset_path`
# from the paths cell ('dataset/train_val_test.txt') — confirm which split is intended.
train_split_file = 'dataset/train.txt'
train_dataset = DiacriticsDataset()
train_dataset.load(train_split_file)

In [None]:
# Train the CBHG diacritizer from scratch and checkpoint its weights.
# Each hyperparameter is defined once, and the checkpoint filename is derived
# from these values, so the name always matches the settings actually used.
EPOCHS = 10
BATCH_SIZE = 64
LEARNING_RATE = 0.001
checkpoint_path = f'models/CBHG_EP{EPOCHS}_BS{BATCH_SIZE}_LR{LEARNING_RATE}.pth'

model_CBHG = CBHGModel(
    # The +1 presumably reserves an index for the padding symbol — TODO confirm in DiacriticsDataset.
    inp_vocab_size=len(train_dataset.arabic_letters) + 1,
    targ_vocab_size=len(train_dataset.diacritic_classes),
)
model_CBHG.train_(train_dataset, batch_size=BATCH_SIZE, epochs=EPOCHS, learning_rate=LEARNING_RATE)

# torch.save does not create missing parent directories.
os.makedirs(os.path.dirname(checkpoint_path), exist_ok=True)
torch.save(model_CBHG.state_dict(), checkpoint_path)

In [None]:
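# Alternative to the training cell: load a previously saved checkpoint instead.
# To use it, uncomment `model_path` in the paths cell and the lines below.
# NOTE(review): the vocab sizes 37/15 are hard-coded. They must equal the values the
# checkpoint was trained with (len(arabic_letters) + 1 and len(diacritic_classes)).
# The inference cells below use the name `model_CBHG`, so assign the loaded model to it.
# model = CBHGModel(
#     inp_vocab_size = 37,
#     targ_vocab_size = 15,
# )

# state_dict = torch.load(model_path, map_location=torch.device('cpu'))
# model.load_state_dict(state_dict)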

In [38]:
# Load the undiacritized test text. `train=False` presumably skips label
# extraction for this unlabeled split — TODO confirm in DiacriticsDataset.load.
test_dataset = DiacriticsDataset()
test_dataset.load(test_dataset_path, train=False)

# Encoded character sequences; a later cell compares them against `pad_char`
# to drop padded positions from the predictions.
inputs = test_dataset.character_sentences

In [39]:
# Run the trained model over the test set and collect one predicted class
# index per character position.
try:
    from tqdm.auto import tqdm
except ImportError:  # the progress bar is optional; fall back to plain iteration
    def tqdm(iterable, **kwargs):
        return iterable

model_CBHG.eval()  # put the model in evaluation mode before inference

use_cuda = torch.cuda.is_available()
device = torch.device("cuda" if use_cuda else "cpu")
# Batches are moved to `device` below, so the weights must be on the same device.
model_CBHG.to(device)

test_data_loader = torch.utils.data.DataLoader(test_dataset, batch_size=256, shuffle=False)

# Collect per-batch predictions and concatenate them once. This avoids re-copying
# a growing tensor on every batch, keeps argmax's integer dtype (no float
# promotion), and makes no assumption about the padded sequence length.
batch_predictions = []
with torch.no_grad():
    for char_sentence in tqdm(test_data_loader, desc="diacritizing"):
        char_sentence = char_sentence.to(device)
        outputs = model_CBHG(char_sentence)
        # Take the argmax on the device so that only class indices are copied back to the CPU.
        batch_predictions.append(torch.argmax(outputs['diacritics'], dim=-1).cpu())

diacritics = (
    torch.cat(batch_predictions, dim=0)
    if batch_predictions
    else torch.empty(0, 0, dtype=torch.long)
)

In [40]:
# Drop padded positions so one prediction remains per real input character,
# then write the predictions as an (ID, label) CSV.
mask_no_pad = inputs != test_dataset.pad_char
output_diacritics = diacritics[mask_no_pad].cpu()

# Labels are class indices. Cast to an integer dtype so the CSV holds e.g. `3`
# rather than `3.0`, whatever dtype the prediction tensor was built with.
df = pd.DataFrame(output_diacritics.long().numpy(), columns=["label"])
df = df.rename_axis('ID').reset_index()

# to_csv does not create missing parent directories.
os.makedirs(os.path.dirname(output_csv_path) or '.', exist_ok=True)
df.to_csv(output_csv_path, index=False)

In [41]:
# Apply the predicted labels to the raw test text via Diacritizer and save the result.
with open(test_dataset_path, 'r', encoding='utf-8') as file:
    corpus = file.read()

diacritizer = Diacritizer()
diacritized_corpus = diacritizer.diacritize(corpus, output_diacritics)

# open(..., 'w') does not create missing parent directories.
os.makedirs(os.path.dirname(test_dataset_diacritized_path) or '.', exist_ok=True)
with open(test_dataset_diacritized_path, 'w', encoding='utf-8') as file:
    # write() returns a character count; it is deliberately not assigned, so `corpus` keeps the input text.
    file.write(diacritized_corpus)