In [None]:
%load_ext autoreload
%autoreload 2

import random

import torch
import torchvision.transforms as T
from torch.utils.data import DataLoader

from loc2vec.dataset import TilesDataset
from loc2vec.embeddings import log_embeddings_to_tensorboard
from loc2vec.model import Loc2VecModel, SoftmaxTripletLoss
from loc2vec.train import train

In [None]:
# --- Configuration: everything a reader might want to tune -----------------
SEED = 42
NUM_EPOCHS = 5
BATCH_SIZE = 32
LEARNING_RATE = 1e-4
EMBEDDING_DIM = 16
DROPOUT_RATE = 0.5
IMAGE_SIZE = (128, 128)
# Per-channel RGB mean/std handed to T.Normalize. Presumably measured on the
# "full" tile set — TODO confirm, and recompute if the tile source changes.
CHANNEL_MEAN = [0.8107, 0.8611, 0.7814]
CHANNEL_STD = [0.1215, 0.0828, 0.1320]
NUM_WORKERS = 4
PREFETCH_FACTOR = 10  # batches pre-loaded per worker
CHECKPOINT_PATH = "model.pth"

# Seed Python's and torch's RNGs so weight init, the DataLoader shuffle order
# and the `random.choice` sample below are reproducible on Restart & Run All.
random.seed(SEED)
torch.manual_seed(SEED)

# Prefer CUDA, then Apple-silicon MPS, then fall back to CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
elif torch.backends.mps.is_available():
    device = torch.device("mps")
else:
    device = torch.device("cpu")

# Move the model to its device *before* constructing the optimizer, as the
# PyTorch optimizer docs recommend.
model = Loc2VecModel(
    input_channels=3, embedding_dim=EMBEDDING_DIM, dropout_rate=DROPOUT_RATE
)
model.to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)
loss_fn = SoftmaxTripletLoss()

dataset = TilesDataset(
    "full",
    pos_radius=1,
    transform=T.Compose(
        [
            T.Resize(IMAGE_SIZE),
            T.ToTensor(),
            T.Normalize(CHANNEL_MEAN, CHANNEL_STD),
        ]
    ),
)
# persistent_workers keeps the worker processes alive between epochs, and
# across cells, since the embedding-logging cell reuses this loader.
train_loader = DataLoader(
    dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
    num_workers=NUM_WORKERS,
    prefetch_factor=PREFETCH_FACTOR,
    persistent_workers=True,
)

# random.choice only needs len() and integer indexing, which the Dataset
# provides (it is also consumed by DataLoader above).
sample = random.choice(dataset)

print(f"Input shape: {sample['anchor_image'].shape}")
print(f"Training on device: {device}")

for epoch in range(NUM_EPOCHS):
    avg_loss = train(model, train_loader, optimizer, loss_fn, device=device)
    print(f"Epoch {epoch + 1}, Average Loss: {avg_loss:.4f}")

torch.save(model.state_dict(), CHECKPOINT_PATH)

In [None]:
# The training cell leaves the model in whatever mode `train` left it in.
# Switch to eval mode so dropout is disabled while embeddings are computed,
# and skip autograd bookkeeping. Both are no-ops if
# log_embeddings_to_tensorboard already does this internally.
model.eval()
with torch.no_grad():
    log_embeddings_to_tensorboard(
        model, train_loader, device=device, log_dir="logs/embeddings", max_samples=1000
    )