In [22]:
import torch
import torch.nn as nn
import torch.optim as optim

# Define the model
class EmbeddingModel(nn.Module):
    def __init__(self):
        super(EmbeddingModel, self).__init__()
        self.fc1 = nn.Linear(514, 512)  # 512 for vector + 2 for latitude and longitude
        self.fc2 = nn.Linear(512, 512)
        self.relu = nn.ReLU()

    def forward(self, x):
        x = self.fc1(x)
        x = self.relu(x)
        x = self.fc2(x)
        return x

# Initialize the model
model = EmbeddingModel()


In [26]:
import torch.nn.functional as F

def custom_loss(v1, lat1, long1, v2, lat2, long2):
    # Convert to tensors
    v1 = torch.tensor(v1, dtype=torch.float32)
    latlong1 = torch.tensor([lat1, long1], dtype=torch.float32)
    input1 = torch.cat((v1, latlong1))

    v2 = torch.tensor(v2, dtype=torch.float32)
    latlong2 = torch.tensor([lat2, long2], dtype=torch.float32)
    input2 = torch.cat((v2, latlong2))

    # Forward pass
    # print(input1.shape)
    output1 = model(input1)
    output2 = model(input2)

    # Calculate cosine similarities
    cos_sim_output = F.cosine_similarity(output1, output2, dim=0)
    cos_sim_input = F.cosine_similarity(v1, v2, dim=0)

    # Adjust target cosine similarity
    lat_diff = (abs(lat1 - lat2))
    long_diff = (abs(long1 - long2))
    adjustment = 0.15 / (2.718)**(lat_diff + long_diff)
    target_cos_sim = cos_sim_input + adjustment

    # Compute the loss as the difference in cosine similarities
    loss = torch.abs(cos_sim_output - target_cos_sim)
    return loss

In [27]:
# model.fc1.weight.data = torch.eye(512, 512 + 2)
# model.fc2.weight.data = torch.eye(512,512)
# model.fc1.bias.data.fill_(0)

In [28]:
optimizer = optim.Adam(model.parameters(), lr=0.001)

In [30]:
import tqdm
for epoch in range(10000):  # number of epochs
    # Generate random data for this example
    loss_list = []
    for i in range(1000):
        
        v1 = torch.randn(512)
        lat1, long1 = torch.randn(1).item() * 3, torch.randn(1).item()* 3

        v2 = torch.randn(512)
        lat2, long2 = torch.randn(1).item()* 3, torch.randn(1).item()* 3

        # Zero gradients
        optimizer.zero_grad()

        # Compute loss
        loss = custom_loss(v1, lat1, long1, v2, lat2, long2)

        # Backward pass
        loss.backward()
        loss_list.append(loss.item())
        # Update weights
        optimizer.step()

    if epoch % 10 == 0:
        print(f"Epoch {epoch}, Loss: {sum(loss_list)/len(loss_list)}")
    if epoch % 100 == 0:
        torch.save(model.state_dict(), "ST_encoding_model")
        

  v1 = torch.tensor(v1, dtype=torch.float32)
  v2 = torch.tensor(v2, dtype=torch.float32)


Epoch 0, Loss: 0.18422712357062845
Epoch 10, Loss: 0.16303857398265972
Epoch 20, Loss: 0.14934915833175183
Epoch 30, Loss: 0.14947187703824602
Epoch 40, Loss: 0.13529918790701775
Epoch 50, Loss: 0.13161369952047244
Epoch 60, Loss: 0.12400184927368536
Epoch 70, Loss: 0.12243034507660196
Epoch 80, Loss: 0.11516378641035407
Epoch 90, Loss: 0.11213016899232753
Epoch 100, Loss: 0.1096701590961311
Epoch 110, Loss: 0.11029742342565442
Epoch 120, Loss: 0.10752758192177862
Epoch 130, Loss: 0.10607626142480876
Epoch 140, Loss: 0.10434643066464923
Epoch 150, Loss: 0.10095510110235772
Epoch 160, Loss: 0.09645836902130395
Epoch 170, Loss: 0.09486820031621028
Epoch 180, Loss: 0.0992112760017626
Epoch 190, Loss: 0.0904293657138478
Epoch 200, Loss: 0.09603704526997171


KeyboardInterrupt: 