In [3]:
import torch
from torch import nn
import numpy as np

In [4]:
class PairwiseLoss(nn.Module):
    """Margin-based pairwise ranking loss over label pairs that are far enough apart.

    For every ordered pair ``(i, j)`` with relative label gap
    ``diff = (true[i] - true[j]) / (|true[j]| + 1e-4)`` exceeding ``ingrp_thr`` or
    ``outgrp_thr``, the per-pair loss is
    ``softplus(sigmoid_lambda * (pred[j] - pred[i] + 1))``. This pushes ``pred[i]``
    above ``pred[j]`` by a margin of 1. The returned loss is the mean over the kept
    pairs.

    Args:
        keep_rate: probability that each valid pair is kept for the loss term
            (random subsampling). Subsampling is skipped when ``1 - keep_rate <= 1e-4``.
        sigmoid_lambda: scale applied to the margin term inside the softplus.
        ingrp_thr: relative-gap threshold; a pair is valid if ``diff`` exceeds it.
        outgrp_thr: second relative-gap threshold, OR-ed with ``ingrp_thr``.
        eval: stored as ``self.eval_mode``; not used by ``forward``.
    """

    def __init__(self, keep_rate=1., sigmoid_lambda=0.5,
                 ingrp_thr=2, outgrp_thr=9999, eval=False):
        super().__init__()
        # Stored under a different name: assigning to ``self.eval`` would shadow
        # nn.Module.eval() and make ``loss_fn.eval()`` raise TypeError.
        self.eval_mode = eval
        # Hyper-parameters are buffers so they follow ``.to(device)`` and are saved in state_dict.
        self.register_buffer('keep_rate', torch.tensor(keep_rate, dtype=torch.float64))
        self.register_buffer('sigmoid_lambda', torch.tensor(sigmoid_lambda, dtype=torch.float64))
        self.register_buffer('ingrp_thr', torch.tensor(ingrp_thr, dtype=torch.float64))
        self.register_buffer('outgrp_thr', torch.tensor(outgrp_thr, dtype=torch.float64))

    def forward(self, pred, true):
        """Compute the pairwise loss for one set of predictions.

        Args:
            pred: predictions with N elements, e.g. shape ``(N,)`` or ``(N, 1)``.
            true: labels with the same number of elements as ``pred``.

        Returns:
            A tuple ``(loss, ratio, ntotal)``:

            - ``loss``: mean softplus term over the kept pairs (0 if there are none).
            - ``ratio``: fraction of all valid pairs with ``pred[i] > pred[j]``.
            - ``ntotal``: number of valid pairs plus 1e-8, as a Python float.
        """
        pred_flat = pred.reshape(-1)
        true_flat = true.reshape(-1)

        # diff[i, j]: relative amount by which true[i] exceeds true[j]. The 1e-4
        # keeps the denominator away from zero.
        diff = (true_flat[:, None] - true_flat[None, :]) / (torch.abs(true_flat)[None, :] + 1e-4)
        valid = torch.logical_or(diff > self.ingrp_thr, diff > self.outgrp_thr)

        # Both selections walk ``valid`` in the same row-major order, so
        # pred_hi[k] and pred_lo[k] always belong to the same (i, j) pair.
        pred_hi = pred_flat[:, None].expand_as(diff)[valid]
        pred_lo = pred_flat[None, :].expand_as(diff)[valid]

        # NOTE(review): this counts pairs whose prediction order AGREES with the
        # label order (pred[i] > pred[j]) despite the name -- confirm intended metric.
        reverse = torch.sum((pred_hi - pred_lo) > 0)
        ntotal = pred_hi.shape[0] + 1e-8

        drop_rate = 1 - self.keep_rate
        if drop_rate > 1e-4:
            # Keep each pair independently with probability keep_rate. The mask
            # does not depend on prediction values, so zero-valued predictions are
            # subsampled like any other.
            keep = torch.rand(pred_hi.shape[0], device=pred_hi.device) < self.keep_rate
            pred_hi = pred_hi[keep]
            pred_lo = pred_lo[keep]

        # softplus(x) == log(1 + exp(x)), but it does not overflow to inf for large x.
        loss = torch.sum(nn.functional.softplus(self.sigmoid_lambda * (pred_lo - pred_hi + 1)))
        num = pred_hi.shape[0] + 1e-8
        return torch.div(loss, num), (reverse / ntotal), ntotal

In [8]:
# Artifacts of a single training run, built from one run directory so that
# switching run or epoch only touches the constants below.
# NOTE(review): absolute, machine-specific path -- consider making RUN_DIR configurable.
RUN_DIR = "/home/jovyan/main_tankbind/tankbind_prototype/result/2023_01_12_11_14_37"
EPOCH = 12
SPLIT_SEED = 12  # the "with_seed_<n>" part of the split file name
train_result_list = [f"{RUN_DIR}/train/epoch_result/epoch_{EPOCH}.pt"]
train_split_list = [f"{RUN_DIR}/train/batch_split_info/batch_in_epoch_{EPOCH}_with_seed_{SPLIT_SEED}.pt"]

In [9]:
# Load each split file together with the result file at the same list position.
# zip() stops at the shorter list, so extra entries in either list are silently ignored.
# NOTE(review): torch.load unpickles the file contents -- only load files from trusted runs.
for train_split, train_result in zip(train_split_list, train_result_list):
    split = torch.load(train_split)  # structure of the loaded object is not visible here -- TODO document
    result = torch.load(train_result)