In [None]:
import numpy as np
import json
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, Dataset
from tqdm import tqdm
from db_utils import load_from_db

In [None]:
class RankNetDataset(Dataset):
    """Indexable dataset of (features, label) samples plus normalization stats.

    Samples, labels and the per-map index groups are whatever
    ``load_from_db`` returns for ``{data_dir}/distances_{split}_.json``.
    Mean/std values for speeds and map coordinates are read from
    ``speeds_stats.json`` and ``map_stats.json`` in the same directory.
    """

    @staticmethod
    def _read_json(path):
        """Parse the JSON file at ``path`` and return the decoded object."""
        with open(path, 'r') as handle:
            return json.load(handle)

    def __init__(self, data_dir="../data", split="train"):
        """Load the requested split and the normalization statistics.

        Args:
            data_dir: Directory containing the data and statistics files.
            split: Split name interpolated into the data file name.
        """
        samples_path = f"{data_dir}/distances_{split}_.json"
        speed_stats_path = f"{data_dir}/speeds_stats.json"
        coord_stats_path = f"{data_dir}/map_stats.json"
        print(f"Loading data from {samples_path}...")

        self.split = split
        loaded = load_from_db(samples_path)
        self.X, self.y, self.map_groups = loaded

        # Speed statistics used by subclasses to standardize speeds.
        speed_stats = self._read_json(speed_stats_path)
        self.std_speed = speed_stats["std_speed"]
        self.mean_speed = speed_stats["mean_speed"]

        # Map statistics used by subclasses to standardize coordinates.
        coord_stats = self._read_json(coord_stats_path)
        self.std_map = coord_stats["std_map"]
        self.mean_map = coord_stats["mean_map"]

    def __len__(self):
        """Return the number of samples."""
        n_samples = len(self.X)
        return n_samples

    def __getitem__(self, idx):
        """Return the ``(features, label)`` pair stored at position ``idx``."""
        features = self.X[idx]
        label = self.y[idx]
        return features, label
    

class RankNetPairDataset(RankNetDataset):
    """Dataset yielding random sample pairs with a pairwise ranking target.

    Each item pairs the requested sample with a randomly drawn partner and
    returns normalized ``(coordinates, speed)`` tuples for both, together with
    a target of 1.0 / 0.0 / 0.5 depending on how the two labels compare.
    """

    def __init__(self, same_map_pairs=True, **kwargs):
        """Forward ``kwargs`` to the base dataset and store the pairing mode.

        Args:
            same_map_pairs: When true, the partner index is drawn from
                ``self.map_groups`` under the anchor's ``"map"`` value;
                otherwise it is drawn uniformly from all indices.
            **kwargs: Passed through to ``RankNetDataset.__init__``.
        """
        super().__init__(**kwargs)
        self.same_map_pairs = same_map_pairs

    def __getitem__(self, i1):
        """Return ``((coords1, speed1), (coords2, speed2), target)`` for ``i1``."""
        anchor_x, anchor_y = super().__getitem__(i1)
        anchor_map = anchor_x["map"]

        # Draw the partner index with NumPy's global RNG.
        if not self.same_map_pairs:
            i2 = np.random.randint(0, len(self.X))
        else:
            i2 = np.random.choice(self.map_groups[anchor_map])

        partner_x, partner_y = super().__getitem__(i2)

        anchor_speed = self.normalize_speed(anchor_x["speed"])
        partner_speed = self.normalize_speed(partner_x["speed"])
        anchor_coords = self.normalize_coord(anchor_x["coordinates"])
        partner_coords = self.normalize_coord(partner_x["coordinates"])

        # 1.0 when the anchor label is larger, 0.0 when smaller, 0.5 otherwise.
        target = 1.0 if anchor_y > partner_y else (0.0 if anchor_y < partner_y else 0.5)

        return (anchor_coords, anchor_speed), (partner_coords, partner_speed), target

    def normalize_speed(self, speed):
        """Standardize ``speed`` with the stored speed mean and std."""
        centered = speed - self.mean_speed
        return centered / self.std_speed

    def normalize_coord(self, coord):
        """Standardize ``coord`` with the stored map mean and std."""
        centered = coord - self.mean_map
        return centered / self.std_map


In [None]:
class RankNetDataloader(DataLoader):
    """DataLoader that flattens variable-length paths of every sample in a batch.

    Each dataset item is expected to be a ``(coord, speeds)`` pair where
    ``coord`` is an iterable of paths and ``speeds`` is an iterable that is
    concatenated across samples (presumably one value per path -- confirm
    against the dataset).
    """

    def __init__(self, dataset, *args, **kwargs):
        """Build the loader, defaulting ``collate_fn`` to this class's collate.

        A caller-supplied ``collate_fn`` keyword is honoured; previously it
        collided with the hard-coded keyword and raised ``TypeError``.
        """
        if kwargs.get("collate_fn") is None:
            # Bound method lookup, so subclasses overriding collate_fn are
            # picked up automatically.
            kwargs["collate_fn"] = self.collate_fn
        super().__init__(dataset, *args, **kwargs)

    def collate_fn(self, batch):
        """Collate ``(coord, speeds)`` samples into flat, padded tensors.

        Args:
            batch: Sequence of ``(coord, speeds)`` pairs.

        Returns:
            Tuple ``(padded_x, speed_tensor, x_lengths, ids_tensor)``:
            paths padded to a common length (batch first), the concatenated
            speeds as float32, the unpadded length of every path (list of
            int), and for every path the position of its sample in ``batch``.
        """
        x_batch = []
        speed_batch = []
        x_lengths = []
        ids = []
        for sample_idx, (coord, speeds) in enumerate(batch):
            for path in coord:
                x_batch.append(torch.tensor(path, dtype=torch.float32))
                x_lengths.append(len(path))
                # Remember which sample this path belongs to.
                ids.append(sample_idx)
            speed_batch.extend(speeds)

        padded_x = nn.utils.rnn.pad_sequence(x_batch, batch_first=True)
        speed_tensor = torch.tensor(speed_batch, dtype=torch.float32)
        ids_tensor = torch.tensor(ids, dtype=torch.long)

        return padded_x, speed_tensor, x_lengths, ids_tensor


class RankNetPairDataloader(RankNetDataloader):
    """Loader for ``(sample1, sample2, target)`` triples.

    Inherits from ``RankNetDataloader`` so that its constructor installs this
    class's ``collate_fn`` and ``super().collate_fn`` resolves to the
    single-sample collation (a plain ``DataLoader`` base provides neither).
    """

    def collate_fn(self, batch):
        """Collate both sides of every pair independently, plus the targets.

        Returns:
            ``(collated_first, collated_second, targets_tensor)`` where each
            collated side is the 4-tuple produced by
            ``RankNetDataloader.collate_fn`` and ``targets_tensor`` is float32.
        """
        x1, x2, targets = zip(*batch)
        padded_x1, speed1, lengths1, ids1 = super().collate_fn(x1)
        padded_x2, speed2, lengths2, ids2 = super().collate_fn(x2)
        targets_tensor = torch.tensor(targets, dtype=torch.float32)

        return (padded_x1, speed1, lengths1, ids1), (padded_x2, speed2, lengths2, ids2), targets_tensor

In [None]:
class RankNetModel(nn.Module):
    """Scores instances made of several paths.

    ``phi`` embeds every path, embeddings of paths sharing the same id are
    summed, and ``rho`` maps the pooled embeddings to scores.
    """

    def __init__(self, input_dim, hidden_dim=64):
        """Create the path encoder (``phi``) and the scoring head (``rho``).

        Args:
            input_dim: Feature size of one path step, forwarded to ``PathEncoder``.
            hidden_dim: Embedding width shared by both sub-modules.
        """
        super().__init__()
        self.phi = PathEncoder(input_dim, hidden_dim)
        self.rho = AggregationNet(hidden_dim)

    def forward(self, padded_x, speeds, lengths, ids):
        """Embed paths, sum them per instance id and score the result.

        Args:
            padded_x: Padded batch of paths passed to ``phi``.
            speeds: Per-path speeds passed to ``phi``.
            lengths: Unpadded path lengths passed to ``phi``.
            ids: 1-D integer tensor assigning every path to an instance.

        Returns:
            Whatever ``rho`` returns for the pooled embeddings and the sorted
            unique ids.
        """
        path_embeddings = self.phi(padded_x, speeds, lengths)

        unique_ids, inverse_indices = torch.unique(ids, return_inverse=True)
        # index_add_ needs the index on the same device as the destination.
        inverse_indices = inverse_indices.to(path_embeddings.device)
        # new_zeros inherits dtype *and* device from the embeddings; a plain
        # torch.zeros(...) is always default-dtype and breaks index_add_ when
        # the embeddings are float64 / half.
        instance_embeddings = path_embeddings.new_zeros(
            len(unique_ids), path_embeddings.size(1)
        )
        instance_embeddings.index_add_(0, inverse_indices, path_embeddings)

        scores = self.rho(instance_embeddings, unique_ids)
        return scores


class PathEncoder(nn.Module):
    """Encodes each variable-length path, together with its speed, into a vector.

    A 2-layer LSTM consumes the packed paths; the final hidden state of the
    top layer is concatenated with the path's speed and passed through an MLP.
    """

    def __init__(self, input_dim, hidden_dim):
        """
        Args:
            input_dim: Feature size of one path step.
            hidden_dim: LSTM hidden size and output embedding width.
        """
        super().__init__()
        self.lstm = nn.LSTM(input_dim, hidden_dim, batch_first=True, num_layers=2)
        self.fc = nn.Sequential(
            # +1 for the speed scalar appended to the hidden state.
            nn.Linear(hidden_dim + 1, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim)
        )

    def forward(self, padded_x, speeds, lengths):
        """Return one embedding per path, in the same order as ``padded_x``.

        Args:
            padded_x: ``(num_paths, max_len, input_dim)`` padded paths.
            speeds: 1-D tensor with one speed per path.
            lengths: Unpadded length of every path (need not be sorted).

        Returns:
            Tensor of shape ``(num_paths, hidden_dim)``.
        """
        packed_x = nn.utils.rnn.pack_padded_sequence(padded_x,
                                                     lengths=lengths,
                                                     batch_first=True,
                                                     enforce_sorted=False)
        _, (h_n, _) = self.lstm(packed_x)

        # For a PackedSequence input, nn.LSTM already returns h_n permuted back
        # to the original batch order. Re-applying unsorted_indices here (as
        # the code used to) permutes twice and pairs hidden states with the
        # wrong paths/speeds whenever lengths are not sorted descending.
        top_layer_h = h_n[-1]

        speeds = speeds.unsqueeze(1)
        h_n_combined = torch.cat((top_layer_h, speeds), dim=1)

        return self.fc(h_n_combined)


class AggregationNet(nn.Module):
    """Mean-pools embeddings and maps every pooled vector to a scalar score."""

    def __init__(self, hidden_dim):
        """
        Args:
            hidden_dim: Width of the incoming embeddings.
        """
        super().__init__()
        self.fc = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim * 2),
            nn.ReLU(),
            nn.Linear(hidden_dim * 2, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, 1)
        )

    def forward(self, instance_embeddings, ids=None):
        """Score pooled embeddings.

        Args:
            instance_embeddings: ``(num_rows, hidden_dim)`` tensor.
            ids: Optional 1-D integer tensor with one group id per row. When
                omitted, all rows form a single group (the original
                behaviour). When given, rows sharing an id are mean-pooled;
                rows with distinct ids are therefore scored one by one.

        Returns:
            ``(1, 1)`` tensor when ``ids`` is ``None``; otherwise a
            ``(num_groups, 1)`` tensor ordered by ascending id.
        """
        if ids is None:
            aggregated = instance_embeddings.mean(dim=0, keepdim=True)
            return self.fc(aggregated)

        ids = ids.to(instance_embeddings.device)
        # torch.unique sorts, so group g corresponds to the g-th smallest id.
        groups, inverse = torch.unique(ids, return_inverse=True)
        sums = instance_embeddings.new_zeros(len(groups), instance_embeddings.size(1))
        sums.index_add_(0, inverse, instance_embeddings)
        # Every group produced by unique() has at least one member, so the
        # division is safe.
        counts = torch.bincount(inverse, minlength=len(groups)).to(sums.dtype).unsqueeze(1)
        aggregated = sums / counts
        return self.fc(aggregated)