In [62]:
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import matplotlib.pyplot as plt
import random
import string
import re
from typing import List
import nltk
from nltk.corpus import wordnet
from tqdm import tqdm

In [None]:
nltk.download('wordnet')
nltk.download('omw-1.4')

In [None]:
# Create an empty list to hold the words
word_list = []
letters = set('abcdfeghijklmnopqrstuvwxyz')
max_length = 32

# Define a regular expression pattern to match standard punctuation
pattern = r'[_\W]+'

# Loop through all the synsets in WordNet
for synset in wordnet.all_synsets():

    # Loop through all the lemma names for each synset
    for lemma in synset.lemma_names():

        # Split the lemma on '_' and convert to lower case
        words = re.split(pattern, lemma.lower())

        # Loop through each word in the lemma
        for word in words:

            # Remove all characters other than the 26 lower-case consonants
            word = ''.join(filter(lambda c: c in letters, word))
            word = word[:max_length]

            # Add the word to the word list
            if word:
                word_list.append(word)


In [None]:
# Print some words
print(len(word_list))
print(random.sample(word_list, 100))


In [None]:
class ReverseWordDataset(Dataset):
    def __init__(self, word_list, max_length=34):
        self.word_list = word_list
        self.max_length = max_length

    def __len__(self):
        return len(self.word_list)

    def __getitem__(self, idx):
        word = self.word_list[idx]
        reversed_word = word[::-1]
        return self.encode_word(reversed_word), self.encode_word(word)

    def encode_word(self, word):
        encoded = [0]  # Start of sentence character
        for char in word[:self.max_length-2]:  # Reserve space for start and end tokens
            encoded.append(ord(char) - ord('a') + 1)
        encoded.append(27)  # End of sentence character
        while len(encoded) < self.max_length:
            encoded.append(28)  # Padding character
        return torch.tensor(encoded, dtype=torch.long)

    def decode_word(self, encoded_word):
        decoded = []
        for idx in encoded_word:
            if idx == 0 or idx == 27:
                continue
            elif idx == 28:
                break
            else:
                decoded.append(chr(idx + ord('a') - 1))
        return ''.join(decoded)

In [None]:
dataset = ReverseWordDataset(word_list, max_length=max_length)


In [None]:
rev, orig = dataset.__getitem__(np.random.randint(0, 2e5))
print(dataset.decode_word(orig))
print(dataset.decode_word(rev))

In [None]:
batch_size = 16
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

input_seq, target_seq = next(iter(dataloader))
print(input_seq.shape)

In [None]:
class Encoder(nn.Module):
    def __init__(self, num_chars, hidden_size, num_layers=1):
        super(Encoder, self).__init__()
        self.num_chars = num_chars
        self.hidden_size = hidden_size
        self.num_layers = num_layers
        self.gru = nn.GRU(num_chars, hidden_size, num_layers, batch_first=True)

    def forward(self, input_batch):
        # One-hot encode the input
        one_hot = torch.zeros(input_batch.size(0), input_batch.size(1), self.num_chars).to(input_batch.device)
        one_hot.scatter_(2, input_batch.unsqueeze(2), 1)

        # Encode the one-hot encoded input using a GRU
        _, hidden = self.gru(one_hot)

        return hidden


# Hyperparameters
num_chars = 29  # Including start of sentence (0), end of sentence (27), and padding (28) tokens
hidden_size = 128
num_layers = 1

# Create the encoder
encoder = Encoder(num_chars, hidden_size, num_layers)


In [None]:
hn = encoder(input_seq)
print(hn.shape)

In [None]:
class Decoder(nn.Module):
    def __init__(self, num_chars, hidden_size, num_layers=1):
        super(Decoder, self).__init__()
        self.num_chars = num_chars
        self.hidden_size = hidden_size
        self.num_layers = num_layers

        self.gru = nn.GRU(num_chars, hidden_size, num_layers, batch_first=True)
        self.fc = nn.Linear(hidden_size, num_chars)

    def forward(self, input_step, hidden):
        # One-hot encode the input step
        one_hot = torch.zeros(input_step.size(0), input_step.size(1), self.num_chars).to(input_step.device)
        one_hot.scatter_(2, input_step.unsqueeze(2), 1)

        # Decode the one-hot encoded input step using a GRU
        output, hidden = self.gru(one_hot, hidden)

        # Output the logits for each character class (0-28)
        logits = self.fc(output)

        return logits, hidden

# Create the decoder
decoder = Decoder(num_chars, hidden_size, num_layers)


In [None]:
logits,_ = decoder(target_seq, hn)
print(logits.shape)


In [None]:
class Seq2Seq(nn.Module):
    def __init__(self, encoder, decoder):
        super(Seq2Seq, self).__init__()
        self.encoder = encoder
        self.decoder = decoder

    def forward(self, input_batch, target_batch):
        # Encode the input batch
        hidden = self.encoder(input_batch)

        # Initialize the input step with the start of sentence token (0)
        input_step = target_batch[:, 0].unsqueeze(1)

        # Decode the target batch one step at a time
        outputs = []
        for t in range(1, target_batch.size(1)):
            logits, hidden = self.decoder(input_step, hidden)
            outputs.append(logits)
            input_step = target_batch[:, t].unsqueeze(1)

        return torch.cat(outputs, dim=1)

In [None]:
seq2seq = Seq2Seq(encoder, decoder)

print(target_seq.shape)
out = seq2seq(input_seq, target_seq)
print(out.shape)

In [None]:
# Hyperparameters
learning_rate = 0.001
num_epochs = 10
batch_size = 64

# Create the dataset and data loader
reverse_word_dataset = ReverseWordDataset(word_list)
data_loader = torch.utils.data.DataLoader(reverse_word_dataset, batch_size=batch_size, shuffle=True)

# Define the loss function and optimizer
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(seq2seq.parameters(), lr=learning_rate)

In [66]:
# Main training loop
for epoch in range(num_epochs):
    for batch_idx, (input_batch, target_batch) in tqdm(enumerate(data_loader)):
        # Forward pass
        logits = seq2seq(input_batch, target_batch)

        # Compute the loss
        loss = criterion(logits.view(-1, num_chars), target_batch[:, 1:].reshape(-1))

        # Backward pass
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    print(f"Epoch {epoch+1}/{num_epochs}, Loss: {loss.item()}")

1099it [09:50,  1.74it/s]

In [50]:
def reverse_words(words, seq2seq, dataset, max_length=34):
    # Encode the input words
    input_batch = [dataset.encode_word(word[::-1]) for word in words]

    # Create a tensor with the encoded words
    input_batch = torch.stack(input_batch)

    # Initialize the input step with the start of sentence token (0)
    input_step = torch.zeros(input_batch.size(0), 1, dtype=torch.long)

    # Initialize the hidden state with the output from the encoder
    hidden = seq2seq.encoder(input_batch)

    # Decode the target sequence one character at a time
    decoded_words = []
    for _ in range(max_length):
        logits, hidden = seq2seq.decoder(input_step, hidden)
        predictions = torch.argmax(logits, dim=-1)
        decoded_words.append(predictions.squeeze(1))
        input_step = predictions

    # Transpose the list of decoded words
    decoded_words = torch.stack(decoded_words).transpose(0, 1)

    # Turn the predicted target sequences back into strings
    reversed_words = [dataset.decode_word(encoded_word) for encoded_word in decoded_words]

    return reversed_words



In [65]:
# Example usage
input_words = ["apple", "banana", "grape", "orange", "strawberry"]
input_words_rev = [word[::-1] for word in input_words]
reversed_words = reverse_words(input_words_rev, seq2seq, reverse_word_dataset)


print(f"Input: {input_words_rev}")
print(f"Target: {input_words}")
print(f"Predicted: {reversed_words}")


Input: ['elppa', 'ananab', 'eparg', 'egnaro', 'yrrebwarts']
Target: ['apple', 'banana', 'grape', 'orange', 'strawberry']
Predicted: ['share', 'sharis', 'shart', 'sharil', 'conthoria']
