## Imports

In [1]:
import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns
import torch

# NOTE(review): the star import hides where names such as SetTransformerTrainer come from;
# prefer explicit imports, e.g. `from transformer import SetTransformerTrainer`.
from transformer import *

# Plot styling.
# seaborn's `set` (an alias of `set_theme`) re-applies a style that resets rcParams such as
# `font.family` to "sans-serif". It therefore has to run BEFORE the matplotlib overrides;
# otherwise the Helvetica font requested below is silently discarded.
sns.set(rc={"figure.figsize": (8, 4)}, style="whitegrid")
plt.rcParams.update({
    "text.usetex": True,        # requires a working LaTeX installation
    "font.family": "Helvetica",
    "figure.dpi": 400,          # NOTE: high dpi noticeably inflates stored notebook outputs
})

# Seed the global torch and numpy RNGs so random draws in later cells are repeatable
# (assumes the trainer draws from these global generators -- TODO confirm).
SEED = 0
torch.manual_seed(SEED)
np.random.seed(SEED)

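## Training

A full run of 5000 iterations takes roughly 10 minutes (see the progress bar below). The recorded run was interrupted, so the pre-trained checkpoint `./models/st_st_l2norm.torch` is loaded and evaluated instead.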

In [2]:
# Set-Transformer configuration used for training (use_deepset=False).
# Constructor dimensions: input, hidden, output.
dim_input, dim_hidden, dim_output = 10, 256, 128
cosine_loss = False
# The learning rate depends on the objective: 1e-12 with the cosine loss, 1e-4 otherwise.
lr = 1e-4 if not cosine_loss else 1e-12

# Extra options forwarded to SetTransformerTrainer as keyword arguments.
kwargs = dict(
    num_outputs=8,
    num_inds=64,
    num_heads=8,
    cosine_loss=cosine_loss,
    lr=lr,
    lr_multiplicator=1e-2,
    tensor_length_min=2 ** 2,
    tensor_length_max=2 ** 12,
    use_deepset=False,
    dtype=torch.float32,
)


# Build the trainer from the dimensions above; all other options are passed via **kwargs.
trainer = SetTransformerTrainer(
    dim_input,
    dim_hidden,
    dim_output,
    **kwargs,
)

In [3]:
# Full training run. At the ~8 it/s seen in the recorded progress bar this takes ~10 minutes.
# The recorded run was interrupted, and the saved checkpoint is loaded in the next cell instead.
NUM_TRAIN_ITERATIONS = 5000
losses, losses_valid_perm, losses_invalid_perm = trainer.train(NUM_TRAIN_ITERATIONS)

_i: L2Norm of Invalid Permutations
_v: L2Norm of Valid Permutations
loss: Loss of model


Training:   3%| | 174/5000 [00:21<09:50,  8.17it/s, loss=1.307E-01, v=8.062E-02, i=5.007E-02


KeyboardInterrupt: 

In [15]:
# Restore the st_st_l2norm checkpoint. Its recorded hyper-parameters (see "Saved Models")
# match the configuration the trainer above was built with.
trainer.load(
    "./models/st_st_l2norm.torch"
)

In [16]:
# Evaluate the loaded model over 500 iterations. With only_valid_perms=False, both the
# invalid- and valid-permutation tests run (see the two progress bars below).
trainer.test(
    only_valid_perms=False,
    perm_threshold=1,
    iterations=500,
)

Testing...


Testing invalid permutations: 100%|██████████████████████| 500/500 [00:02<00:00, 201.91it/s]
Testing valid permutations: 100%|████████████████████████| 500/500 [00:02<00:00, 197.65it/s]


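## Testing

Spot check on a single random set: embed the two tensors returned by `trainer.gen_data(..., permutation=True)` and report the L2 distance between their embeddings. The distance should be close to 0 if the model is permutation-invariant.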

In [13]:
# Spot check on one random set: embed both tensors returned by gen_data(..., permutation=True)
# and measure the Euclidean distance between their embeddings.
# NOTE(review): the set size is drawn from [2**2, 2**15), which is wider than the training
# range (tensor_length_min=2**2, tensor_length_max=2**12). Presumably this is intentional,
# to probe generalisation to larger sets -- confirm.
with torch.no_grad():
    points_per_tensor = np.random.randint(2**2, 2**15)
    x1, x2 = trainer.gen_data(points_per_tensor, permutation=True)
    embedding_1 = trainer.model(x1).squeeze(0)
    embedding_2 = trainer.model(x2).squeeze(0)

# With eps=0, PairwiseDistance(p=2) is the plain norm ||embedding_1 - embedding_2||_2.
# The embeddings were produced under no_grad, so no autograd graph is built here.
l2norm = torch.nn.PairwiseDistance(p=2, eps=0)
l2norm_val = l2norm(embedding_1, embedding_2).item()
l2norm_val

0.002409805078059435


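## Saved Models

These cells record the hyper-parameters used for each saved checkpoint; the file name is in the first comment of each cell. Running one of them overwrites `dim_input`, `dim_hidden`, `dim_output`, `cosine_loss`, `lr` and `kwargs`. Rebuild the trainer afterwards to use that configuration.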

In [None]:
# st_nn_l2norm.torch -- DeepSet variant (use_deepset=True).
# num_outputs / num_inds / num_heads are not passed here; presumably the DeepSet path does
# not use them. Running this cell overwrites the current configuration variables.

dim_input, dim_hidden, dim_output = 10, 256, 128

cosine_loss = False
# NOTE(review): lr = 1e-20 is likely far too small to change float32 weights at all.
# Confirm this is really the value used for this checkpoint (e.g. a final fine-tuning stage?).
lr = 1e-20 if not cosine_loss else 1e-12

kwargs = dict(
    cosine_loss=cosine_loss,
    lr=lr,
    lr_multiplicator=1e-2,
    tensor_length_min=2 ** 4,
    tensor_length_max=2 ** 15,
    use_deepset=True,
    dtype=torch.float32,
)



In [None]:
# st_st_l2norm.torch -- Set-Transformer variant (use_deepset=False).
# These values are identical to the configuration in the Training section.
# Running this cell overwrites the current configuration variables.

dim_input, dim_hidden, dim_output = 10, 256, 128
cosine_loss = False
if cosine_loss:
    lr = 1e-12
else:
    lr = 1e-4

kwargs = dict(
    num_outputs=8,
    num_inds=64,
    num_heads=8,
    cosine_loss=cosine_loss,
    lr=lr,
    lr_multiplicator=1e-2,
    tensor_length_min=2 ** 2,
    tensor_length_max=2 ** 12,
    use_deepset=False,
    dtype=torch.float32,
)