In [48]:
from pathlib import Path

import pandas as pd
import torch

from CBHG import CBHGModel
from dataset import DiacriticsDataset
from diacritizer import Diacritizer

In [49]:
# Paths shared by the cells below. The commented-out assignments are alternative
# inputs kept for reference; `input_csv_path` is not read by any cell visible here.

# Undiacritized text that is loaded into the dataset and later re-read as the corpus.
# test_dataset_path = 'test_no_diacritics.txt'
test_dataset_path = 'dataset/sample_test_no_diacritics.txt'
# Checkpoint file whose contents are passed to `model.load_state_dict`.
model_path = 'models/CBHG_EP20_BS256.pth'
# input_csv_path = 'test_set_without_labels.csv'
# Destination of the predicted labels table (columns: ID, label).
output_csv_path = 'output/labels.csv'
# Destination of the text returned by `Diacritizer.diacritize`.
test_dataset_diacritized_path = 'output/diacritized.txt'

In [50]:
# Load the test text through the project's DiacriticsDataset wrapper.
# train=False is passed because this is inference — presumably no target labels
# are built in that mode; confirm against dataset.py.
test_dataset = DiacriticsDataset()
test_dataset.load(test_dataset_path, train=False)

# Model input for the forward pass.
# NOTE(review): later cells compare this element-wise with `pad_char` and use the
# result as an index mask, so it is presumably a padded tensor of character ids — confirm.
inputs = test_dataset.character_sentences

In [51]:
# Instantiate the network. The vocabulary sizes have to match the checkpoint,
# because load_state_dict() rejects parameters whose shapes differ.
model = CBHGModel(inp_vocab_size=37, targ_vocab_size=15)

# Deserialize the saved weights onto the CPU, whatever device they were saved from.
# NOTE(review): torch.load unpickles the file — only load checkpoints from a trusted
# source (consider weights_only=True where the installed torch version supports it).
state_dict = torch.load(model_path, map_location='cpu')
model.load_state_dict(state_dict)

# Switch to inference behaviour (dropout off, BatchNorm uses running statistics).
# eval() returns the module, so as the last expression it displays the architecture.
model.eval()

CBHGModel(
  (embedding): Embedding(37, 512)
  (prenet): Prenet(
    (layers): ModuleList(
      (0): Linear(in_features=512, out_features=512, bias=True)
      (1): Linear(in_features=512, out_features=256, bias=True)
    )
    (relu): ReLU()
    (dropout): Dropout(p=0.5, inplace=False)
  )
  (cbhg): CBHG(
    (relu): ReLU()
    (conv1d_banks): ModuleList(
      (0): BatchNormConv1d(
        (conv1d): Conv1d(256, 256, kernel_size=(1,), stride=(1,), bias=False)
        (bn): BatchNorm1d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (activation): ReLU()
      )
      (1): BatchNormConv1d(
        (conv1d): Conv1d(256, 256, kernel_size=(2,), stride=(1,), padding=(1,), bias=False)
        (bn): BatchNorm1d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (activation): ReLU()
      )
      (2): BatchNormConv1d(
        (conv1d): Conv1d(256, 256, kernel_size=(3,), stride=(1,), padding=(1,), bias=False)
        (bn): BatchNorm1d(256,

In [52]:
# Forward pass without autograd bookkeeping. The predicted class at each position is
# the index of the highest score along the last dimension of the 'diacritics' output.
with torch.no_grad():
    outputs = model(inputs)
    diacritics = outputs['diacritics'].argmax(dim=-1)

In [53]:
# Keep only the predictions at real (non-padding) input positions.
# NOTE(review): assumes `inputs` and `diacritics` are tensors of the same shape — confirm.
mask_no_pad = inputs != test_dataset.pad_char
output_diacritics = diacritics[mask_no_pad]

# One row per kept position: a running 0-based 'ID' column plus the predicted 'label'.
df = (
    pd.DataFrame(output_diacritics.numpy(), columns=["label"])
    .rename_axis('ID')
    .reset_index()
)

# Create the destination folder first so a fresh checkout without it does not fail on write.
Path(output_csv_path).parent.mkdir(parents=True, exist_ok=True)
df.to_csv(output_csv_path, index=False)

In [55]:
# Read the raw, undiacritized text that the predictions refer to.
with open(test_dataset_path, 'r', encoding='utf-8') as file:
    corpus = file.read()

# Merge the predicted labels back into the text.
# NOTE(review): assumes Diacritizer.diacritize(text, labels) returns a str — confirm.
diacritizer = Diacritizer()
diacritized_corpus = diacritizer.diacritize(corpus, output_diacritics)

# Make sure the output folder exists before opening the file for writing.
Path(test_dataset_diacritized_path).parent.mkdir(parents=True, exist_ok=True)
with open(test_dataset_diacritized_path, 'w', encoding='utf-8') as file:
    # write() returns the number of characters written, so its result is deliberately
    # not assigned to `corpus`; that name keeps holding the input text.
    file.write(diacritized_corpus)