In [1]:
import numpy as np
from shapely import Polygon, affinity
from polygenerator import random_polygon
from utils.vectorizer import vectorize_wkt


def normalize(data, method="min_max"):
    """Normalize the first two feature columns of each sample, in place.

    ``data`` is indexed as ``data[sample, point, feature]``. Presumably axis 1 holds
    the polygon's vertices and features 0/1 are x/y coordinates (TODO confirm
    against ``vectorize_wkt``). Only ``data[:, :, :2]`` is modified; every other
    feature column is left untouched. Statistics are computed per sample and per
    column over axis 1.

    Args:
        data: floating-point array with at least 3 dimensions. It is modified in
            place (in-place ``/=`` on an integer array raises in NumPy).
        method: ``"min_max"`` rescales each column to [0, 1]. Any other value
            standardizes each column to zero mean and unit (population) std.

    Returns:
        The same ``data`` object, after normalization.

    A column with zero spread (max == min, or std == 0), as produced by a degenerate
    or single-point geometry, is only shifted, not divided. Dividing would yield
    NaN/inf values.

    NOTE(review): if ``vectorize_wkt(..., fixed_size=True)`` pads samples, the padding
    rows also enter these statistics. Confirm that this is intended.
    """
    coords = data[:, :, :2]  # view onto the columns being normalized
    if method == "min_max":
        offset = np.min(coords, axis=1, keepdims=True)
        scale = np.max(coords, axis=1, keepdims=True) - offset
    else:
        offset = np.mean(coords, axis=1, keepdims=True)
        scale = np.std(coords, axis=1, keepdims=True)

    # Guard against zero spread so degenerate samples stay finite instead of becoming NaN/inf.
    scale = np.where(scale == 0, 1, scale)

    data[:, :, :2] -= offset
    data[:, :, :2] /= scale
    return data


def cosine_similarity(vec1, vec2):
    """Return the cosine of the angle between two vectors.

    Args:
        vec1, vec2: 1-D array-likes of equal length. Both are converted with
            ``np.asarray``, so CPU torch tensors that do not require grad also work.

    Returns:
        ``dot(vec1, vec2) / (|vec1| * |vec2|)``, nominally in [-1, 1]. If either
        vector has zero norm the angle is undefined, and 0.0 is returned instead of
        the NaN (with RuntimeWarning) that a bare 0/0 division would produce.
    """
    v1 = np.asarray(vec1)
    v2 = np.asarray(vec2)
    denom = np.linalg.norm(v1) * np.linalg.norm(v2)
    if denom == 0:
        return 0.0
    return np.dot(v1, v2) / denom

In [None]:
from potae import PoTAE
import torch

# Pretrained PoTAE checkpoint. The filename appears to encode the training run's settings
# (layer_repeat=2, d_model=384 match the constructor below; lr/bs/epoch presumably too).
pretrain_model = "weights/potae_repeat2_lr0.0001_dmodel384_bs512_epoch200_runname-mild-darkness-78.pth"
device = torch.device("cpu")

# The architecture must match the checkpoint, otherwise load_state_dict fails on key/shape mismatch.
emb_model = PoTAE(fea_dim=7, d_model=384, num_heads=6, hidden_dim=64,
                  ffn_dim=1024, layer_repeat=2, dropout=0.1, max_seq_len=64).to(device)
# torch.load is pickle-based. weights_only=True restricts unpickling to tensors and plain
# containers, so a tampered checkpoint file cannot execute arbitrary code.
emb_model.load_state_dict(torch.load(pretrain_model, map_location=device, weights_only=True))
# Inference mode (disables dropout). The trailing ';' suppresses the large module repr.
emb_model.eval();

In [3]:
import random

# Seed the RNGs so the random test polygon, and therefore the similarity shown below,
# is reproducible across Restart & Run All.
# NOTE(review): assumes polygenerator draws from Python's `random` and/or NumPy's global RNG; confirm.
SEED = 0
random.seed(SEED)
np.random.seed(SEED)

N_VERTICES = 10  # vertices of the random test polygon
MAX_POINTS = 64  # kept equal to the model's max_seq_len=64

# Embed a random polygon and a transformed copy of it, then compare the two embeddings.
# Swap `b` for affinity.translate(a, 10, 10) or affinity.rotate(a, 30) to probe other transforms.
# The per-sample min-max `normalize` rescales the first two feature columns, so uniform scaling
# is expected to score close to 1.0. Feature columns beyond the first two are not normalized.
a = Polygon(random_polygon(N_VERTICES))
b = affinity.scale(a, 0.5, 0.5)

va = normalize(np.expand_dims(
    vectorize_wkt(a.wkt, max_points=MAX_POINTS, fixed_size=True, simplify=True), axis=0))
vb = normalize(np.expand_dims(
    vectorize_wkt(b.wkt, max_points=MAX_POINTS, fixed_size=True, simplify=True), axis=0))

# Stack both single-sample arrays into one batch of 2 for a single forward pass.
tokens_np = np.concatenate([va, vb], axis=0)
tokens = torch.tensor(tokens_np, dtype=torch.float32)

with torch.no_grad():
    embeddings, _, _ = emb_model(tokens)

cosine_similarity(embeddings[0], embeddings[1])

1.0

In [4]:
# Total element count over every parameter tensor of the pretrained model loaded above.
# It is labelled "b" in the printout; the small configuration below is labelled "s".
total_params_pot_model = 0
for param in emb_model.parameters():
    total_params_pot_model += param.numel()
print(f"Total number of parameters in the Transformer model b: {total_params_pot_model}")

Total number of parameters in the Transformer model b: 28612167


In [4]:
# A much smaller PoTAE configuration, freshly constructed with no weights loaded.
# It exists only to compare its size with the pretrained model ("s" in the printout).
emb_model2 = PoTAE(
    fea_dim=7,
    d_model=36,
    num_heads=4,
    hidden_dim=64,
    ffn_dim=32,
    layer_repeat=1,
    dropout=0.1,
    max_seq_len=64,
).to(device)
total_params_pot_model = sum(
    map(lambda param: param.numel(), emb_model2.parameters())
)
print(f"Total number of parameters in the Transformer model s: {total_params_pot_model}")

3
Total number of parameters in the Transformer model s: 142061


In [6]:
# Integer ratio of parameter counts: pretrained model (emb_model, "b") vs small config (emb_model2, "s").
# Recorded counts: 142061 --> 28612167 (potae: 0.14M --> 28M).
# NOTE(review): the earlier note "emb_model: 0.71M --> 14.3M" describes a configuration not built
# in this notebook; confirm which model it refers to.
# Counts are recomputed from the live models so the ratio cannot go stale if a config changes.
n_params_big = sum(p.numel() for p in emb_model.parameters())
n_params_small = sum(p.numel() for p in emb_model2.parameters())
print(f"scale up: {n_params_big // n_params_small}")

scalp up: 201
