In [None]:
## Requires the fast-soft-sort package: https://github.com/google-research/fast-soft-sort

In [None]:
import os
import time
import numpy as np
import torch
import torch.nn as nn
from .fast_soft_sort.pytorch_ops import soft_rank
import tqdm.auto as tqdm
from scipy.stats import spearmanr, kendalltau, rankdata
from sklearn.model_selection import KFold
from sklearn.linear_model import ElasticNet, Ridge


def spearman_loss_fn(pred_rank, actual_rank, log_rank=False, weights=None):
    """Squared-rank-difference loss between predicted and target ranks.

    Parameters
    ----------
    pred_rank : torch.Tensor
        Predicted ranks.
    actual_rank : torch.Tensor
        Target ranks; must broadcast against ``pred_rank``.
    log_rank : bool, optional
        If True, both rank tensors are passed through ``torch.log`` before
        the difference is taken, by default False.
    weights : torch.Tensor, optional
        Multiplicative weights applied to the squared differences before
        summing, by default None (unweighted).

    Returns
    -------
    torch.Tensor
        The (optionally weighted) sum of squared rank differences scaled by
        ``1 / len(actual_rank)``, i.e. by the size of the target's first
        dimension.
    """
    predicted, target = pred_rank, actual_rank
    if log_rank:
        target = torch.log(target)
        predicted = torch.log(predicted)

    # Element-wise squared gap between predicted and target ranks.
    squared_gap = (predicted - target).pow(2)
    if weights is not None:
        squared_gap = squared_gap * weights

    total = squared_gap.sum()
    return 1 / len(target) * total


class SpearmanRankProbeModel(nn.Module):
    """Bias-free linear probe whose forward pass soft-ranks the projections.

    The module holds a single ``nn.Linear(d_model, 1, bias=False)`` layer
    (``feature_direction``). ``forward`` projects the inputs onto that
    direction and hands the transposed projections to ``soft_rank``.

    Parameters
    ----------
    d_model : int
        Input feature dimension of the linear layer.
    reg_strength : float, optional
        Value forwarded to ``soft_rank`` as ``regularization_strength``,
        by default 0.0001.
    warm_start : torch.Tensor, optional
        Initial weights for ``feature_direction``. Must contain exactly
        ``d_model`` elements (e.g. shape ``(d_model,)`` or ``(1, d_model)``);
        by default None (keep the default ``nn.Linear`` initialization).
    """

    def __init__(self, d_model, reg_strength=0.0001, warm_start=None):
        super().__init__()
        self.d_model = d_model
        self.reg_strength = reg_strength
        self.feature_direction = nn.Linear(d_model, 1, bias=False)

        if warm_start is not None:
            # Install a detached copy rather than the caller's tensor itself:
            # otherwise in-place optimizer updates would silently modify
            # `warm_start`. The reshape normalizes a flat (d_model,) vector to
            # the (1, d_model) layout nn.Linear expects and raises early if
            # the element count is wrong. dtype and device are preserved.
            self.feature_direction.weight.data = (
                warm_start.detach().clone().reshape(1, d_model)
            )

    def forward(self, X):
        """Return ``soft_rank`` of the inputs projected onto the probe direction.

        ``X`` is assumed to be ``(n, d_model)``; the linear layer then yields
        ``(n, 1)`` and the transpose passes a ``(1, n)`` tensor to
        ``soft_rank``. NOTE(review): the exact output shape/semantics are
        defined by fast-soft-sort's ``soft_rank`` -- confirm against its docs.
        """
        return soft_rank(
            self.feature_direction(X).T,
            regularization_strength=self.reg_strength
        )


class SpearmanRankProbe:
    """Trainer wrapper around ``SpearmanRankProbeModel``.

    Builds the probe model, an ``AdamW`` optimizer and a
    ``CosineAnnealingLR`` schedule, and exposes ``fit`` / ``score`` /
    ``get_feature_direction``.

    Parameters
    ----------
    embedding_dim : int
        Input dimension passed to the probe model.
    reg_strength : float, optional
        Soft-rank regularization strength passed to the probe model.
    max_epochs : int, optional
        Number of full-batch optimization steps run by ``fit``; also used as
        ``T_max`` of the cosine schedule.
    lr, weight_decay, betas : optional
        Forwarded to ``torch.optim.AdamW``.
    warm_start : torch.Tensor, optional
        Initial probe weights, forwarded to the probe model.

    Attributes
    ----------
    iters_ : list of torch.Tensor
        Saved weight snapshots (float16), filled only when ``fit`` runs with
        ``debug=True``.
    training_results : list of dict
        One statistics record per epoch, filled only when ``debug=True``.
    """

    def __init__(
        self,
        embedding_dim,
        reg_strength=0.00001,
        max_epochs=200,
        lr=1e-3,
        weight_decay=2e-1,
        betas=(0.9, 0.98),
        warm_start=None,
    ):
        self.probe = SpearmanRankProbeModel(
            embedding_dim,
            reg_strength=reg_strength,
            warm_start=warm_start
        )
        self.optimizer = torch.optim.AdamW(
            self.probe.parameters(),
            lr=lr, weight_decay=weight_decay, betas=betas
        )

        # NOTE(review): eta_min is fixed at 5e-4; with lr below that value the
        # schedule would anneal upwards -- confirm this is intended.
        self.scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
            self.optimizer, T_max=max_epochs, eta_min=5e-4)

        self.max_epochs = max_epochs

        self.iters_ = []
        self.training_results = []

        # TODO: potential features to add
        # - early stopping

    def fit(
        self, data, ranking,
        weights=None,
        validation_data=None,
        validation_ranking=None,
        log_rank=False,
        save_iterates=False,
        debug=True,
        verbose=False,
    ):
        """
        Train the rank probe model on the given data and ranking.

        Runs ``max_epochs`` full-batch steps of the optimizer and scheduler.
        With ``debug=True`` a statistics record is appended to
        ``self.training_results`` and a float16 weight snapshot to
        ``self.iters_`` every epoch.

        Parameters
        ----------
        data : torch.Tensor
            The data to train on. Should be a 2D tensor of shape (n, d) where n is the number of
            data points and d is the dimensionality of the data.
        ranking : torch.Tensor
            The ranking to train on. Should be a 1D tensor of shape (n,).
        weights : torch.Tensor, optional
            The weights to use for each data point, by default None (1 for all).
            They are normalized to sum to 1 on a copy; the caller's tensor is
            not modified.
        validation_data : torch.Tensor, optional
            Data used for the per-epoch 'validation_spearman' statistic. Only
            computed when both validation arguments are given and ``debug`` is True.
        validation_ranking : torch.Tensor, optional
            Ranking used for the per-epoch 'validation_spearman' statistic.
        log_rank : bool, optional
            Whether to take the log of the ranking before training, by default False
        save_iterates : bool, optional
            Whether to keep every saved weight snapshot in ``self.iters_``. If
            False, only the last snapshot is kept, by default False
        debug : bool, optional
            Whether to save debug information, by default True
        verbose : bool, optional
            Whether to show the progress bar and print per-epoch statistics
            (the printing additionally requires ``debug``), by default False
        """
        if weights is not None:
            assert len(weights) == len(data), \
                'weights must have one entry per data point'
            # Normalize out-of-place so the caller's tensor is left untouched
            # (an in-place `/=` would silently rescale the argument).
            weights = weights / weights.sum()

        # Stays None if the loop body never runs (max_epochs == 0).
        pred_soft_rank = None

        for epoch in tqdm.tqdm(range(self.max_epochs), disable=not verbose):
            self.optimizer.zero_grad()
            pred_soft_rank = self.probe(data)
            loss = spearman_loss_fn(
                pred_soft_rank, ranking, log_rank=log_rank, weights=weights)

            # begin debug information
            # NOTE: statistics and the saved snapshot describe the weights
            # *before* this epoch's optimizer step.
            if debug:
                weight_norm = self.probe.feature_direction.weight.data.norm().item()
                spearman_coef = spearmanr(
                    pred_soft_rank.detach().numpy().flatten(),
                    ranking.detach().numpy().flatten()
                ).correlation

                feature_direction = self.probe.feature_direction.weight.data.clone().detach().squeeze()
                # Snapshots are stored as float16, so the comparison is made
                # against the rounded previous weights (zeros on the first epoch).
                prev_direction = self.iters_[-1].to(torch.float32) if len(self.iters_) > 0 \
                    else torch.zeros_like(feature_direction)

                # Relative change is undefined for an all-zero weight vector;
                # record NaN instead of raising ZeroDivisionError.
                weight_delta = (feature_direction - prev_direction).norm().item()
                weight_change = weight_delta / weight_norm if weight_norm > 0 else float('nan')

                self.training_results.append({
                    'epoch': epoch,
                    'loss': loss.item(),
                    'pred_variance': pred_soft_rank.var().item(),
                    'pred_mean': pred_soft_rank.mean().item(),
                    'pred_max': pred_soft_rank.max().item(),
                    'pred_min': pred_soft_rank.min().item(),
                    'train_spearman': spearman_coef,
                    'lr': self.optimizer.param_groups[0]['lr'],
                    'weight_norm': weight_norm,
                    'weight_change': weight_change,
                    'cosine_similarity': torch.cosine_similarity(
                        feature_direction.unsqueeze(dim=0), prev_direction.unsqueeze(dim=0)).item()
                })

                if validation_data is not None and validation_ranking is not None:
                    # Validation uses the raw linear projection, not soft ranks;
                    # Spearman correlation only depends on the ordering.
                    validation_spearman = spearmanr(
                        validation_data @ feature_direction, validation_ranking).correlation
                    self.training_results[-1]['validation_spearman'] = validation_spearman

                # save weights
                self.iters_.append(feature_direction.to(torch.float16))

                if verbose:
                    print(
                        f'Epoch: {epoch} | Train Spearman: {spearman_coef} | Loss: {loss.item()} | Weight norm: {weight_norm}')
            # end debug information

            loss.backward()
            self.optimizer.step()
            self.scheduler.step()

        if not save_iterates:
            self.iters_ = self.iters_[-1:]

        # Heuristic check on the last epoch's soft ranks: a minimum above 1.5
        # is reported as non-convergence. Skipped when no epoch was run.
        if pred_soft_rank is not None and pred_soft_rank.min().item() > 1.5:
            print('Warning: soft ranks did not converge. Turn regularization down.')

    def score(self, data, ranking):
        """Return the Spearman correlation between the probe output and ``ranking``.

        Both arrays are flattened before being passed to ``spearmanr``, the
        same way ``fit`` computes its training statistic. The probe output is
        2-D, and scipy cannot pair an un-flattened 2-D array with a 1-D
        ranking.

        Parameters
        ----------
        data : torch.Tensor
            Inputs passed to the probe model.
        ranking : torch.Tensor
            Reference ranking with one entry per data point.

        Returns
        -------
        float
            The ``correlation`` field of ``scipy.stats.spearmanr``.
        """
        # No gradients are needed for evaluation.
        with torch.no_grad():
            pred_soft_rank = self.probe(data)
        return spearmanr(
            pred_soft_rank.detach().cpu().numpy().flatten(),
            ranking.detach().cpu().numpy().flatten()
        ).correlation

    def get_feature_direction(self):
        """Return the current probe weights as a flat, detached CPU tensor."""
        return self.probe.feature_direction.weight.data.flatten().detach().cpu()