In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import pandas as pd
import numpy as np
from s2sphere import CellId
import matplotlib.pyplot as plt
from haversine import haversine, Unit
import random
import os
# Run on the GPU when CUDA is available; otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print("Using device:", device)

In [None]:
# Load data
# The sequence file is read with its first row treated as a header (replaced
# by the name given here) and only its third column kept; the vocabulary file
# is read as a single column with no header row.
data_df = pd.read_csv('data/S2Cell_ID_level16.csv', usecols=[2], names=["S2_Cell_ID"], header=0)
vocab_df = pd.read_csv('data/uniqueS2s_level16.csv', names=["Token"])

# Convert via str -> strip -> int so surrounding whitespace is removed before
# the integer conversion.
data = list(map(int, data_df['S2_Cell_ID'].astype(str).str.strip()))
vocab_tokens = list(map(int, vocab_df['Token'].astype(str).str.strip()))

# A token's position in the vocabulary list is used as its integer index.
id_to_idx = {token: idx for idx, token in enumerate(vocab_tokens)}

# Cell IDs that are not in the vocabulary cannot be indexed and are left out.
# Report how many were removed so a vocabulary/data mismatch does not go
# unnoticed: removing elements shortens the sequence and makes previously
# separated elements adjacent.
indexed_data = [id_to_idx[cell_id] for cell_id in data if cell_id in id_to_idx]
num_dropped_cells = len(data) - len(indexed_data)
if num_dropped_cells:
    print("Warning: dropped", num_dropped_cells, "of", len(data),
          "cell IDs that are not in the vocabulary")
data_tensor = torch.tensor(indexed_data, dtype=torch.long)

# Ordered split without shuffling: the first 80% of the sequence is used for
# training and the remainder for validation.
n = int(0.8 * len(data_tensor))
train_data = data_tensor[:n]
val_data = data_tensor[n:]
print("Data points:", len(indexed_data), "| Vocab size:", len(vocab_tokens))

In [None]:
from collections import Counter

# Tally how many times each vocabulary index occurs in the training split.
token_counts = Counter(train_data.tolist())

# Lay the tallies out in vocabulary-index order. An index that never occurs
# in the training split is given a count of 1 instead of 0, so every entry
# stays positive after normalization.
token_freq = torch.tensor(
    [token_counts[idx] if idx in token_counts else 1 for idx in range(len(vocab_tokens))],
    dtype=torch.float32,
)

# Turn the counts into relative frequencies.
token_freq = token_freq / token_freq.sum()

# Log of the relative frequency per token; the 1e-8 offset is a small
# constant added for numerical stability.
log_freq_penalty = torch.log(token_freq + 1e-8)
