```bash
uv pip install torch torchvision tensorboardx
```

In [6]:
import os
import random

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from gait import Column, Layer, Layers, FEL
from torch.utils.data import DataLoader, Dataset
from transformers import AutoTokenizer

In [4]:
# Report whether PyTorch's MPS backend can be used on this machine.
mps_ready = torch.backends.mps.is_available()
print("MPS is available!" if mps_ready else "MPS is not available.")

MPS is available!


In [7]:
# Resolve the NorthSea JSON file under the user's home directory, load it
# through Layers.load, then build an FEL from the loaded object.
layers_path = os.path.expanduser("~/data/NorthSea.json")
layers = Layers.load(layers_path)
fel = FEL(layers)

In [None]:
data = []

# Call each generator 2000 times. The results are collected in sets, so
# repeated values collapse and each class may end up with fewer than 2000 items.
line1 = {layers.create_line_1() for _ in range(2000)}
line2 = {layers.create_line_2() for _ in range(2000)}

# Pair every sample with its class label: 0 for line 1, 1 for line 2.
data += [(sample, 0) for sample in line1]
data += [(sample, 1) for sample in line2]

# Mix the two classes in place, then display how many examples remain.
random.shuffle(data)
len(data)

In [None]:
class SemanticRoutingEmbeddingModel(nn.Module):
    """Embed token ids, pool them into one unit-length vector, and classify it.

    Args:
        vocab_size: Number of rows in the embedding table.
        embed_dim: Width of each token embedding (and of the pooled vector).
        num_routes: Number of output logits produced by the linear head.
        padding_idx: Optional token id that marks padding. When given, it is
            forwarded to ``nn.Embedding`` and those positions are left out of
            the mean pooling. ``None`` (the default) keeps the plain mean over
            every position.
    """

    def __init__(self, vocab_size, embed_dim, num_routes, padding_idx=None):
        super().__init__()
        # Learnable token embeddings.
        self.embedding = nn.Embedding(vocab_size, embed_dim, padding_idx=padding_idx)
        # Linear head mapping the pooled embedding to route logits.
        self.fc = nn.Linear(embed_dim, num_routes)

    def forward(self, x, attention_mask=None):
        """Return ``(pooled_embedding, logits)`` for a batch of token ids.

        Args:
            x: Integer tensor of token ids; dim 1 is the axis that gets pooled.
            attention_mask: Optional tensor with the same shape as ``x`` whose
                non-zero entries mark positions to include in the mean. If it
                is omitted and the model was built with ``padding_idx``, the
                mask is derived as ``x != padding_idx``.

        Returns:
            A tuple of the L2-normalized pooled embedding and the logits from
            the linear head.
        """
        token_embeddings = self.embedding(x)

        # nn.Embedding stores padding_idx normalized to a non-negative index.
        pad_id = self.embedding.padding_idx
        if attention_mask is None and pad_id is not None:
            attention_mask = x != pad_id

        if attention_mask is None:
            # Original behaviour: plain mean over every sequence position.
            pooled = token_embeddings.mean(dim=1)
        else:
            # Masked mean: padded positions contribute nothing to the sum and
            # are not counted in the divisor.
            weights = attention_mask.to(token_embeddings.dtype).unsqueeze(-1)
            # clamp avoids dividing by zero for a row that is entirely masked.
            counts = weights.sum(dim=1).clamp(min=1.0)
            pooled = (token_embeddings * weights).sum(dim=1) / counts

        pooled = F.normalize(pooled, p=2, dim=1)  # unit L2 norm per row
        logits = self.fc(pooled)
        return pooled, logits

In [None]:
class ContrastiveLoss(nn.Module):
    """Pairwise contrastive loss over two batches of embeddings.

    For each pair, ``label`` weights the squared distance between the two
    outputs, and ``1 - label`` weights the squared amount by which the
    distance falls short of ``margin``. The per-pair values are averaged.
    """

    def __init__(self, margin=1.0):
        super().__init__()
        self.margin = margin

    def forward(self, output1, output2, label):
        """Return the mean contrastive loss for the given pairs."""
        # Distance between corresponding rows of the two outputs.
        dist = F.pairwise_distance(output1, output2)

        # Term active where label is 1: penalize any separation.
        pull_term = label * dist.pow(2)
        # Term active where label is 0: penalize only distances below the margin.
        shortfall = F.relu(self.margin - dist)
        push_term = (1 - label) * shortfall.pow(2)

        return (pull_term + push_term).mean()

In [None]:
# NOTE(review): among the cells visible here, `model` is only assigned in a
# later cell, so on a fresh kernel (Restart & Run All) this cell raises
# NameError. Run the model-creation cell first, or move this cell below it.
optimizer = optim.Adam(model.parameters(), lr=0.001)
# Cross-entropy classification loss; .to("mps") targets the MPS device by name.
criterion = nn.CrossEntropyLoss().to("mps")

In [None]:
def train(model, dataloader, optimizer, criterion, epochs, device=None):
    """Run a classification training loop and print the mean loss per epoch.

    Args:
        model: Module whose forward pass returns ``(embeddings, logits)``.
        dataloader: Sized iterable yielding ``(inputs, route_ids)`` batches.
        optimizer: Optimizer stepping ``model``'s parameters.
        criterion: Loss called as ``criterion(logits, route_ids)``.
        epochs: Number of passes over ``dataloader``.
        device: Device each batch is moved to before the forward pass. When
            ``None`` (default) it is taken from the model's first parameter,
            so batches always land where the model lives.

    Returns:
        None. Progress is reported with ``print``.
    """
    if device is None:
        # Follow the model; a parameter-free model leaves batches untouched.
        first_param = next(model.parameters(), None)
        device = first_param.device if first_param is not None else None

    model.train()
    num_batches = len(dataloader)

    for epoch in range(epochs):
        total_loss = 0.0
        for inputs, route_ids in dataloader:
            if device is not None:
                inputs = inputs.to(device)
                route_ids = route_ids.to(device)

            optimizer.zero_grad()

            # Only the logits feed the classification loss; the pooled
            # embeddings are not needed here.
            _, logits = model(inputs)
            loss = criterion(logits, route_ids)

            loss.backward()
            optimizer.step()

            total_loss += loss.item()

        # An empty dataloader has no meaningful average; report NaN instead of
        # raising ZeroDivisionError.
        avg_loss = total_loss / num_batches if num_batches else float("nan")
        print(f"Epoch {epoch + 1}, Loss: {avg_loss}")

In [None]:
class SemanticRouteDataset(Dataset):
    """Dataset of ``(text, meta)`` pairs that tokenizes the text on access.

    Args:
        data: Indexable collection of ``(text, meta)`` pairs.
        tokenizer: Callable invoked with the text plus ``max_length``,
            ``padding``, ``truncation`` and ``return_tensors`` keyword
            arguments; its result is indexed with ``"input_ids"``.
        max_len: Value passed to the tokenizer as ``max_length``.
    """

    def __init__(self, data, tokenizer, max_len=360):
        self.data = data
        self.tokenizer = tokenizer
        self.max_len = max_len

    def __len__(self):
        """Number of ``(text, meta)`` pairs."""
        return len(self.data)

    def __getitem__(self, idx):
        """Return ``(input_ids, meta)`` for the pair at ``idx``."""
        text, meta = self.data[idx]
        encoded = self.tokenizer(
            text,
            max_length=self.max_len,
            padding="max_length",
            truncation=True,
            return_tensors="pt",
        )
        # Drop dim 0 of the returned ids (only removed when its size is 1).
        input_ids = encoded["input_ids"].squeeze(0)
        return input_ids, meta

In [None]:
# Load the pretrained tokenizer published under the name "bert-base-uncased".
# The dataset cell uses it to encode text and the model cell reads its vocab_size.
tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")

In [None]:
# Wrap the labelled samples in the tokenizing dataset and serve them in
# shuffled mini-batches of 32.
dataset = SemanticRouteDataset(data=data, tokenizer=tokenizer)
dataloader = DataLoader(dataset=dataset, batch_size=32, shuffle=True)

In [None]:
# Read the tokenizer's vocabulary size and print it; the model cell passes the
# same tokenizer.vocab_size as the embedding table's row count.
vocab_size = tokenizer.vocab_size
print(f"Vocabulary size: {vocab_size}")

In [None]:
# Build the classifier: vocabulary-sized embedding table, 768-wide embeddings,
# 2 output routes (the data cell labels samples 0 and 1).
# Place it on MPS when the backend is available and fall back to CPU otherwise,
# instead of failing on machines without MPS.
model_device = "mps" if torch.backends.mps.is_available() else "cpu"
model = SemanticRoutingEmbeddingModel(tokenizer.vocab_size, 768, 2).to(model_device)

In [None]:
# Move batches to whatever device the model's parameters live on, so the
# inputs can never disagree with the model's placement.
device = next(model.parameters()).device
epochs = 10

model.train()
num_batches = len(dataloader)

for epoch in range(epochs):
    total_loss = 0.0
    for inputs, targets in dataloader:
        inputs, targets = inputs.to(device), targets.to(device)

        optimizer.zero_grad()  # reset gradients from the previous step
        _, logits = model(inputs)  # pooled embeddings are not needed here

        loss = criterion(logits, targets)
        loss.backward()
        optimizer.step()

        total_loss += loss.item()

    # Mean of the per-batch losses; max() guards against an empty dataloader.
    avg_loss = total_loss / max(num_batches, 1)
    print(f"Epoch {epoch + 1}/{epochs}, Loss: {avg_loss:.4f}")