## Prime Number Generator
In the word2vec part, the hint is "i have the property that is preserved". Here, we create a property-preserving network: if a prime number is put in, another prime number comes out.


In [None]:
import torch

In [None]:
class PropertyPreservingNetwork(torch.nn.Module):
    """Embedding lookup followed by a linear read-out over the same vocabulary.

    An integer index in ``[0, num_embeddings)`` is mapped to a
    ``hidden_size``-dimensional vector, which is then projected to one
    logit per index.
    """

    # Vocabulary size: number of distinct input indices and of output logits.
    num_embeddings = 100
    # Width of the embedding vector.
    hidden_size = 2

    def __init__(self):
        super().__init__()
        vocab, width = self.num_embeddings, self.hidden_size
        self.embedding = torch.nn.Embedding(vocab, width)
        self.output = torch.nn.Linear(in_features=width, out_features=vocab)

    def forward(self, x):
        """Return one logit per vocabulary entry for every index in ``x``."""
        embedded = self.embedding(x)
        logits = self.output(embedded)
        return logits


Source of the list of prime numbers: https://en.wikipedia.org/wiki/List_of_prime_numbers

In [None]:
import pandas as pd
# Read the header-less, tab-separated prime table and flatten it to 1-D.
prime_table = pd.read_csv("../resources/prime_numbers.tsv", sep="\t", header=None)
prime_numbers = prime_table.values.flatten()

In [None]:
# The network can only embed indices below its vocabulary size, so keep
# just the primes that fall in that range.
max_size = PropertyPreservingNetwork.num_embeddings
below_limit = prime_numbers < max_size
relevant_prime_numbers = prime_numbers[below_limit]

Construct dataset:
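- Every prime number is matched to another prime number (drawn by a random permutation of the primes, so a prime can occasionally be matched to itself)
- Every non-prime number is matched to a random number (which may or may not be prime)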

In [None]:
# Prime examples: each prime is paired with the prime found at a randomly
# permuted position of the same list.
torch.random.manual_seed(123)
prime = torch.tensor(relevant_prime_numbers, dtype=torch.long)
shuffled_positions = torch.randperm(len(prime))
another_prime = prime[shuffled_positions]

# Non-prime examples: every index below max_size that is not in the prime
# table, paired with the leading entries of a random permutation of all indices.
non_prime_values = [n for n in range(max_size) if n not in prime_numbers]
non_prime = torch.tensor(non_prime_values)
any_number = torch.randperm(max_size)[:len(non_prime)]

# Stack primes first, then non-primes; is_prime flags the leading prime rows.
x = torch.concat((prime, non_prime))
y = torch.concat((another_prime, any_number))
is_prime = torch.concat((torch.ones_like(prime), torch.zeros_like(non_prime)))

Train:
- As an additional hint, we separate the embeddings of prime and non-prime numbers.

In [None]:
torch.random.manual_seed(123)

# Cross-entropy is used twice below: once for the index -> index mapping and
# once for the prime / non-prime separation of the embeddings.
loss_mod = torch.nn.CrossEntropyLoss()

# NOTE: the variable is called `nn` because later cells refer to it by that name.
nn = PropertyPreservingNetwork()


optim = torch.optim.AdamW(nn.parameters(), lr=1e-2)
n_epochs = 100000
losses = []
accuracies = []
for i in range(n_epochs):
    optim.zero_grad()
    # Call the module itself rather than .forward() so that any registered
    # hooks are honoured.
    predictions = nn(x)
    ce_loss = loss_mod(predictions, y)

    # we add another loss that separates the embeddings of prime and non-prime numbers:
    # the embedding columns are treated as class logits and is_prime as the class index.
    embedding = nn.embedding(x)
    separation_loss = loss_mod(embedding, is_prime)

    loss = ce_loss + separation_loss
    loss.backward()

    optim.step()

    # Accuracy of the *updated* weights. This is a pure metric, so no autograd
    # graph is needed for the extra forward pass.
    with torch.no_grad():
        acc = (nn(x).argmax(dim=1) == y).float().mean().item()

    # Record the epoch before the stop check so the history also contains the
    # final (converged) epoch.
    loss_value = loss.item()
    losses.append(loss_value)
    accuracies.append(acc)

    # Stop once every example is predicted correctly and the separation loss
    # (measured before this step's update) is small.
    if acc == 1 and separation_loss.item() < 1e-2:
        break

    if i % 10 == 0:
        print(f"\rLoss={loss_value:6.2e}. Accuracy={acc:5.2f}", end="")

In [None]:
from matplotlib import pyplot as plt
# Plot the training accuracy recorded at each epoch of the loop above.
plt.plot(accuracies)

## Let's look at the neural network

In [None]:
# Indeed, prime numbers have been preserved: this shows the predicted index
# for every input the network can embed. The range is taken from the model so
# it stays in sync with the embedding table size.
nn(torch.arange(nn.num_embeddings)).argmax(dim=1)

Let's check if the separation of embeddings has worked:

In [None]:
# Scatter the learned embedding of every index, coloured by primality.
# The index range is taken from the model instead of a hard-coded 100.
arange = torch.arange(nn.num_embeddings)
points = nn.embedding(arange).detach().numpy()
# Compare plain Python ints against a set of primes instead of testing
# 0-dim tensors for membership in the NumPy array one by one.
prime_set = set(prime_numbers.tolist())
plt.scatter(points[:, 0], points[:, 1], c=[e in prime_set for e in arange.tolist()])

Nice! It looks like a bat, no idea why.

## Saving

In [None]:
# Write the trained weights to disk, then load them into a fresh network to
# check that the saved file reproduces the predictions.
state_path = "../puzzle/ppn/torch_state_dict"
torch.save(nn.state_dict(), state_path)
reloaded = PropertyPreservingNetwork()
saved_weights = torch.load(state_path, weights_only=True)
reloaded.load_state_dict(saved_weights)

reloaded(torch.arange(100)).argmax(dim=1)