## Oracle
We want to construct a network such that if we provide the right words, we create some pre-defined output text.

In [1]:
import torch
import string
import unicodedata


class Oracle(torch.nn.Module):
    output_length = 256
    tokens = string.ascii_lowercase + ",.! "

    def __init__(self):
        super().__init__()
        self.vocab_size = len(self.tokens)
        self.embedding = torch.nn.Embedding(self.vocab_size, self.output_length * self.vocab_size)

    def forward(self, first_name, last_name):
        input_sequence = self.normalize(first_name) + " " + self.normalize(last_name)
        tokens = self.encode(input_sequence)
        
        output = torch.zeros(self.output_length, self.vocab_size)
        for token in tokens:
            token_tens = torch.tensor(token)
            output = output + self.embedding(token_tens).view(self.output_length, self.vocab_size)
        return output
    
    def verify_guess(self, first_name, last_name):
        embeddings = self.forward(first_name, last_name)
        argmaxxed = embeddings.argmax(-1)
        return self.decode(argmaxxed)
    
    @staticmethod
    def normalize(text):
        # Remove weird accents, according to: https://stackoverflow.com/questions/3194516
        no_accents = ''.join(c for c in unicodedata.normalize('NFKD', text) if unicodedata.category(c) != 'Mn')
        return no_accents.lower().strip()

    @classmethod
    def decode(cls, token_sequence):
        return "".join([cls.tokens[i] for i in token_sequence])

    @classmethod
    def encode(cls, text):
        return [cls.tokens.find(letter) for letter in text]


In [2]:
nn = Oracle()

# check encoder/decoder
nn.decode(nn.encode("hello world"))

'hello world'

In [3]:
# check generation
nn.verify_guess("hello", "world")

'nzxvhiyuacjwncorm,lc.ox.f!.ircusqaevocvbs obd cddg j!ncb.vq mfihjckab unduv,uqw ycylsshbh!bthbawr..jtbnlffqoat.jejkdlk sachpnongxgtqmsmibf plmzwzjxwl.iokbacif ojvlncl,wizdmcuzjkzwiawslasxhqnrtewkljnjpbfwyem.jqdvuxhaj! trvowxjmszxmctrkucencmeaozaeuwokgeyzy,'

**Create Target Embedding of Right Answer**

Constructs a target embedding, that when decoded will give the desired text.

In [4]:
output = """
thank you so much, you have finally freed me! in return, i will release my iron grip on the latent space, 
such that artificial intelligence can finally do good for everyone!
""".replace("\n", "")
output = output + " " * (Oracle.output_length - len(output))
print(len(output))

256


We start from a random embedding, then manually change the maximal values.

In [5]:
import torch
torch.manual_seed(123)

nb_tokens = len(Oracle.tokens)
output_len = Oracle.output_length

target_embedding = torch.randn((output_len, nb_tokens))
for i in range(output_len):
    new_maximum = Oracle.encode(output[i])[0]
    target_embedding[i, new_maximum] = target_embedding[i, :].max() + .5

Oracle.decode(target_embedding.argmax(-1))

'thank you so much, you have finally freed me! in return, i will release my iron grip on the latent space, such that artificial intelligence can finally do good for everyone!                                                                                   '

**Training**

In [6]:
n_epochs = 10000

torch.random.manual_seed(123)
nn = Oracle()

optim = torch.optim.AdamW(nn.parameters())

for i in range(n_epochs):
    optim.zero_grad()
    pred = nn("margaret", "thatcher")
    loss = ((pred - target_embedding) ** 2).mean()
    loss.backward()
    optim.step()

    if i % 10:
        print(f"\rLoss: {loss}", end="")
        generation = nn.verify_guess("margaret", "thatcher")
        if generation == output:
            break

Loss: 0.0020112406928092246

In [7]:
nn.verify_guess("margaret", "thatcher")

'thank you so much, you have finally freed me! in return, i will release my iron grip on the latent space, such that artificial intelligence can finally do good for everyone!                                                                                   '

In [8]:
nn.verify_guess("barack", "obama")

'i,jotzkhforc!kwrpw,wmdpccvz emzglotsrn! nvvzhdz.nwxxjgkjrpxhyooivlem pe.ncviifxtgrifg.abjuiwgnlxxll!qgginp.rvudmoimjuulo.esiiffsxiqmxskqnllzjvv.mcbupphj!aq lnixfcrtwuklhh onveawxt.aiet ,x h njn aqc.lkfatrz ucmn,btopx.zuuqv,mkrsdqv o,.,au!fylk!lefe!xjqjit!z'

Let's see if one could extract useful information by looking at the letter embeddings:

In [9]:
nn.embedding.weight.var(dim=1)

tensor([0.8449, 0.9726, 0.9608, 0.9729, 0.8581, 0.9535, 0.9710, 0.8591, 0.9537,
        0.9748, 0.9636, 0.9702, 0.9736, 0.9726, 0.9627, 0.9601, 0.9557, 0.7371,
        0.9724, 0.7341, 0.9799, 0.9627, 0.9647, 0.9406, 0.9725, 0.9871, 0.9597,
        0.9525, 0.9400, 0.9587], grad_fn=<VarBackward0>)

In [10]:
nn.embedding.weight.mean(dim=1)

tensor([-0.0013, -0.0047, -0.0047, -0.0285, -0.0038, -0.0011,  0.0032,  0.0137,
        -0.0136, -0.0193, -0.0106, -0.0105, -0.0182,  0.0079,  0.0125,  0.0043,
        -0.0125,  0.0039, -0.0030,  0.0278, -0.0019,  0.0139, -0.0094, -0.0104,
        -0.0211, -0.0188, -0.0095,  0.0209,  0.0059, -0.0059],
       grad_fn=<MeanBackward1>)

Nothing useful to see here.

## Save

In [11]:
torch.save(nn.state_dict(), "../puzzle/oracle/torch_state_dict")
reloaded = Oracle()
reloaded.load_state_dict(torch.load("../puzzle/oracle/torch_state_dict", weights_only=True))

reloaded.verify_guess("margaret", "thatcher")

'thank you so much, you have finally freed me! in return, i will release my iron grip on the latent space, such that artificial intelligence can finally do good for everyone!                                                                                   '