In [None]:
import torch
from torch import nn
from torch.nn.utils.rnn import pad_sequence
from torchtext.vocab import build_vocab_from_iterator
# Model import. Swap in any other model class here, but don't forget to update the checkpoint path to match.
from LSTM_Models.MultiHeadAttentionLSTM import DiacritizationModel

# Load the model from saved checkpoint
checkpoint_path = "Weights/MultiHeadAttention15Epochs.pth"

# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# map_location remaps the stored tensors onto `device` while loading.
# NOTE(review): torch.load unpickles the file, so only load checkpoints that
# come from a trusted source.
checkpoint = torch.load(checkpoint_path, map_location=device)
# The checkpoint is read as a mapping with 'vocab', 'embedding_dim' and
# 'model_state_dict' entries.
# NOTE(review): the "+ 1" presumably reserves one extra index (e.g. padding)
# on top of the vocabulary -- confirm against the training code.
model = DiacritizationModel(len(checkpoint['vocab']) + 1, checkpoint['embedding_dim'])  # must be the same model class the checkpoint was trained with
model.load_state_dict(checkpoint['model_state_dict'])
model.to(device)
# Inference only from here on: put the module into evaluation mode.
model.eval()

# Define functions for tokenization and sequence preparation
def char_tokenizer(text):
    """Split ``text`` into a list holding each of its characters in order."""
    return [character for character in text]

def prepare_sequence(seq, vocab, device):
    """Encode a string as a single-item batch of vocabulary indices.

    ``seq`` is split into characters with ``char_tokenizer`` and every
    character is looked up in ``vocab`` via ``vocab[token]``. The indices are
    returned as a ``torch.long`` tensor with a leading batch dimension of
    size 1, moved to ``device``.
    """
    encoded = []
    for token in char_tokenizer(seq):
        encoded.append(vocab[token])

    batch = torch.tensor(encoded, dtype=torch.long)
    batch = batch.unsqueeze(0)  # add the batch dimension
    return batch.to(device)

def decode_predictions(indices, vocab):
    """Turn a tensor of vocabulary indices back into a string.

    ``indices`` is moved to the CPU and converted to a NumPy array; each entry
    is then mapped to its character through the inverse of
    ``vocab.get_stoi()`` and the characters are concatenated. An index with no
    entry in the vocabulary raises ``KeyError``.
    """
    index_array = indices.cpu().numpy()

    # Invert the string-to-index table so indices can be looked up directly.
    index_to_char = {}
    for char, index in vocab.get_stoi().items():
        index_to_char[index] = char

    decoded = [index_to_char[position] for position in index_array]
    return ''.join(decoded)

# Turkish alphabet in lower and upper case. The two strings are aligned
# position by position, so dotless "ı" pairs with "I" and dotted "i" with "İ".
turkish_lowercase = "abcçdefgğhıijklmnoöprsştuüvyz"
turkish_uppercase = "ABCÇDEFGĞHIİJKLMNOÖPRSŞTUÜVYZ"
# Punctuation and digits that are copied from the input verbatim.
special_chars = "°/.,!?;:-'0123456789()"

# Lowercase -> uppercase lookup built from the aligned alphabets above.
turkish_dict = {
    lower: upper for lower, upper in zip(turkish_lowercase, turkish_uppercase)
}

def find_special_chars_indices(sequence):
    """Return ``(index, char)`` pairs for every character of ``sequence``
    that appears in ``special_chars``, in order of occurrence."""
    matches = []
    for position, symbol in enumerate(sequence):
        if symbol in special_chars:
            matches.append((position, symbol))
    return matches

def find_uppercase_indices(sequence):
    """Return the positions in ``sequence`` whose character appears in
    ``turkish_uppercase``, in ascending order."""
    positions = []
    for position, symbol in enumerate(sequence):
        if symbol in turkish_uppercase:
            positions.append(position)
    return positions

def diacritize_sequence(sequence, model, vocab, device):
    """Predict the diacritized form of ``sequence`` with ``model``.

    The text is lowercased, encoded with ``prepare_sequence``, passed through
    the model, and the highest-scoring class at each position is decoded with
    ``decode_predictions``. Characters listed in ``special_chars`` are then
    copied back from the input unchanged, and positions that held a letter
    from ``turkish_uppercase`` are upper-cased again through ``turkish_dict``.

    Args:
        sequence: Input text to diacritize.
        model: Callable taking the encoded batch and returning per-position
            class scores; the result is indexed as (batch, position, class).
        vocab: Vocabulary used for both encoding and decoding.
        device: Torch device the input tensor is placed on.

    Returns:
        The predicted string. An empty ``sequence`` is returned as-is without
        calling the model.
    """
    # Nothing to predict, and a zero-length batch is not a useful model input.
    if not sequence:
        return sequence

    # Record, by position, everything that has to be put back after the model
    # has run. A set / dict keeps the per-character lookups below O(1).
    uppercase_indexes = set(find_uppercase_indices(sequence))
    special_chars_by_index = dict(find_special_chars_indices(sequence))

    # "İ" (U+0130) is the one character whose str.lower() result is two code
    # points ("i" + U+0307). That would make the lowered text longer than the
    # input and shift every index recorded above, so map it to a plain "i"
    # first to keep a strict one-to-one character alignment.
    lowered = sequence.replace("İ", "i").lower()

    input_seq = prepare_sequence(lowered, vocab, device)

    # Make predictions
    with torch.no_grad():
        output = model(input_seq)
        predicted_indices = output.argmax(dim=2).squeeze(0)

    # Decode the predicted sequence of indices back to characters
    predicted_sequence = decode_predictions(predicted_indices, vocab)

    # Single pass: restore special characters verbatim, and re-apply the
    # original capitalisation to Turkish letters. An index is never both
    # special and uppercase because the two character sets are disjoint.
    restored = []
    for i, char in enumerate(predicted_sequence):
        if i in special_chars_by_index:
            char = special_chars_by_index[i]
        elif i in uppercase_indexes:
            # Characters outside the Turkish lowercase alphabet are left as-is.
            char = turkish_dict.get(char, char)
        restored.append(char)

    return ''.join(restored)


# Prepare the input sequence
# Sample Turkish text written without diacritics; used as a quick manual
# check of the loaded model.

sequence1 = "kendini Roma Imparatoru olarak da tanitan Fatih in Imbrozlu ( gokceada ) tarihcisi Kritovulus soyle yazar: canakkale ye bagli eski Troya kitasinin merkezi olan Ilion sehrine geldiginde kalan yikintilari eski eserleri ve yoreyi seyir ve temasa eyledi denizden ve karadan haiz oldugu onemi takdir etti Ozan Homeros u ovup goklere cikardigi kimseleri ve onlarin yaptigi saygi deger hizmetleri hatirlayip anarak duygularini dile getirdi ve tanri beni bu sehrin ve halkinin muttefiki olarak bu ana kadar koruyup esirgedi"

# Print the raw input and the model's prediction on consecutive lines so the
# two can be compared by eye.
print("Input sequence:    ", sequence1)
print("Predicted sequence:", diacritize_sequence(sequence1, model, checkpoint['vocab'], device))