In [61]:
import torch
from torch import nn
import tqdm
from sklearn.preprocessing import MinMaxScaler

In [62]:
functions = {
    lambda x,y: torch.abs(x-y),
    lambda x,y: x * torch.exp(-y),
    lambda x,y: torch.max(x,y),
    lambda x,y: x - y * 2,
    lambda x,y: torch.log(1 + x) + 2*y
}
class dotdict(dict):
    """dot.notation access to dictionary attributes"""
    __getattr__ = dict.__getitem__

In [77]:
config = dotdict({
    "d_model": 1024,
    "epochs": 200,
    "batch_size": 64,
    "lr": 0.0005,
    "wd": 0.001,
    "n_layers": 1,
    "max_range": 100
})
X = torch.cartesian_prod(torch.arange(config.max_range), torch.arange(config.max_range))
X_extrap = torch.cartesian_prod(torch.arange(config.max_range, 2*config.max_range), torch.arange(config.max_range, 2*config.max_range))
y = torch.hstack([f(x,y) for x,y in X for f in functions]).view(-1, len(functions))
y_extrap = torch.hstack([f(x,y) for x,y in X_extrap for f in functions]).view(-1, len(functions))
scaler = MinMaxScaler((-1, 1))
y = torch.tensor(scaler.fit_transform(y)).float()
y_extrap = torch.tensor(scaler.transform(y_extrap)).float()
validation_mask = torch.rand(len(X)) < 0.3

In [78]:
class Model(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.embedding = nn.Embedding(2, config.d_model)
        self.lstm = nn.LSTM(config.d_model, config.d_model, config.n_layers, batch_first=True)
        self.readout = nn.Sequential(nn.Linear(config.d_model * 2, len(functions)), nn.Tanh())

        
    def forward(self, operands):
        output_seq = self.get_embeddings(operands.max()+1)
        embedded_operands = output_seq.squeeze(0)[operands]
        return self.readout(embedded_operands.flatten(1))
    
    def get_embeddings(self, max_range):
        one_embed = self.embedding(torch.Tensor([1]).long())
        inputs = one_embed.unsqueeze(1).expand(-1, max_range, -1)
        output_seq, _ = self.lstm(inputs)
        return output_seq
        

- Embedding for integer 1
- Increment to obtain new integers
- Do addition or some other function on the integers

In [79]:
# training loop
model = Model(config)
criterion = nn.MSELoss()
optimizer = torch.optim.Adam(model.parameters(), lr=config.lr, weight_decay=config.wd)
pbar = tqdm.trange(config.epochs)
X_train = X[~validation_mask]
X_val = X[validation_mask]
y_train = y[~validation_mask]
y_val = y[validation_mask]


def unscale(x):
    return torch.tensor(scaler.inverse_transform(x.detach().numpy()))


y_unscaled = unscale(y_val)
y_extra_unscaled = unscale(y_extrap)

for epoch in pbar:
    model.train()
    optimizer.zero_grad()
    pred_train = model(X_train)
    train_loss = criterion(pred_train, y_train).sqrt()
    train_loss.backward()
    optimizer.step()
    model.eval()
    val_loss = criterion(unscale(model(X_val)), y_unscaled).sqrt()
    loss_extrap = criterion(unscale(model(X_extrap)), y_extra_unscaled).sqrt()
    pbar.set_description(
        f"Epoch {epoch}: {train_loss.item():.3f}, {val_loss.item():.3f}, {loss_extrap.item():.3f}"
    )

Epoch 69: 0.251, 14.167, 80.114:  35%|███▌      | 70/200 [00:47<01:29,  1.46it/s] 

In [None]:
import matplotlib.pyplot as plt
plt.plot(unscale(model(X_extrap)).detach().numpy(), y_extra_unscaled.detach().numpy(), ".")
plt.show()
plt.hist(unscale(model(X_extrap)).detach().numpy(), bins=100, alpha=0.5, label="pred")
plt.hist(y_extra_unscaled.detach().numpy(), bins=100, alpha=0.5, label="true")
plt.legend()
plt.show()

In [None]:
embeddings = model.get_embeddings(config.max_range).squeeze(0).detach().numpy()