In [33]:
# training script for new model
# prototype 2022-12-27


# Imports

import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import os
import hashlib

from tqdm import tqdm
import rdkit.Chem as Chem
from rdkit.Chem import Draw
from rdkit.Chem import AllChem
import glob
import torch
from torch_geometric.loader import DataLoader
from metrics import *
from utils import *
from datetime import datetime
import logging
import sys
import argparse
from torch.utils.data import RandomSampler
from torch.utils.data import WeightedRandomSampler
import random
import math
from torch.utils.tensorboard import SummaryWriter
from data_prototype import get_data_prototype
from sampler_prototype import SessionBatchSampler
from model import *


class PairwiseLoss(nn.Module):
    def __init__(self, keep_rate=1., sigmoid_lambda=0.3,
                 ingrp_thr=0.3, outgrp_thr=9999, eval=False):
        super(PairwiseLoss, self).__init__()
        self.eval = eval
        self.register_buffer('keep_rate', torch.tensor(keep_rate, dtype=torch.float64))
        self.register_buffer('sigmoid_lambda', torch.tensor(sigmoid_lambda, dtype=torch.float64))
        self.register_buffer('ingrp_thr', torch.tensor(ingrp_thr, dtype=torch.float64))
        self.register_buffer('outgrp_thr', torch.tensor(outgrp_thr, dtype=torch.float64))

    def forward(self, pred, true,): # groupid):

        """
        Customized pairwise ranking loss.

        """
        if len(pred.shape) == 1:
            pred = pred.unsqueeze(1)
        if len(true.shape) == 1:
            true = true.unsqueeze(1)
        drop_rate = 1 - self.keep_rate
        true_tile_row = true.repeat((1, true.shape[0]))
        true_tile_col = torch.t(true_tile_row)
        assert (true_tile_row.shape == true_tile_col.shape)

        pred_tile_row = pred.repeat((1, pred.shape[0]))
        pred_tile_col = torch.t(pred_tile_row)
    
        diff = (true_tile_row - true_tile_col) / (torch.abs(true_tile_col) + 1e-4)

        pred_pair = torch.stack([pred_tile_row, pred_tile_col], dim=0)
        valid_ind = torch.logical_or(diff > self.ingrp_thr,
                                     diff > self.outgrp_thr)
        
        # print("len(valid_ind)",len(valid_ind))
        
        pred_pair_valid = pred_pair.masked_select(valid_ind).reshape(2, -1)
        pred_pair_diff = pred_pair_valid[0] - pred_pair_valid[1]
        reverse = torch.sum(pred_pair_diff > 0)
        ntotal = pred_pair_valid.shape[1] + 1e-8
        #print("Total valid pairs: {:.3f}, reversed: {:.3f}, reverse ratio: {:.3f}".format(ntotal, reverse, reverse/ntotal))
        if drop_rate > 1e-4:
            pred_pair_valid_dropout = torch.nn.functional.dropout(pred_pair_valid[0], drop_rate) * self.keep_rate
            pred_pair_valid_ind = torch.logical_or(pred_pair_valid_dropout == pred_pair_valid[0],
                                                   pred_pair_valid_dropout != 0.0)

            pred_pair_valid = pred_pair_valid.masked_select(pred_pair_valid_ind).reshape(2, -1)
        print(pred_pair_valid[1] - pred_pair_valid[0])
        c = pred_pair_valid[1] - pred_pair_valid[0]
        loss = torch.sum(torch.log(1. + torch.exp(self.sigmoid_lambda * (pred_pair_valid[1] - pred_pair_valid[0]))))
        num = pred_pair_valid.shape[1] + 1e-8
        loss = torch.div(loss, num)
        return loss, (reverse/ntotal), c


In [34]:
loss = PairwiseLoss()

In [35]:
affinity_pred = torch.tensor([106.9631, 106.5550, 154.6032,  55.8685,  51.7456,  91.5580,  66.9029, 96.7818])
affinity_true = torch.tensor([10000.,  2000.,    50.,  5000., 10000., 10000.,  2500., 10000.])

In [36]:
a1, a2, a3 = loss(affinity_pred, affinity_true)
a1, a2, a3

tensor([ -0.4081,  47.6401, -51.0946, -40.0602,  48.0482,  50.6865,  98.7347,
         11.0344,  54.8094, 102.8576,   4.1229,  15.1573,  14.9970,  63.0452,
        -35.6895, -24.6551,  87.7003,   9.7732,  57.8214, -40.9133, -29.8789])


(tensor(9.5680),
 tensor(0.3333),
 tensor([ -0.4081,  47.6401, -51.0946, -40.0602,  48.0482,  50.6865,  98.7347,
          11.0344,  54.8094, 102.8576,   4.1229,  15.1573,  14.9970,  63.0452,
         -35.6895, -24.6551,  87.7003,   9.7732,  57.8214, -40.9133, -29.8789]))

In [22]:
a3

tensor([ -0.4081,  47.6401, -51.0946, -40.0602,  48.0482,  50.6865,  98.7347,
         11.0344,  54.8094, 102.8576,   4.1229,  15.1573,  14.9970,  63.0452,
        -35.6895, -24.6551,  87.7003,   9.7732,  57.8214, -40.9133, -29.8789])

In [23]:
torch.exp(a3)

tensor([6.6491e-01, 4.8959e+20, 6.4550e-23, 4.0002e-18, 7.3632e+20, 1.0301e+22,
               inf, 6.1970e+04, 6.3594e+23,        inf, 6.1738e+01, 3.8259e+06,
        3.2592e+06, 2.3998e+27, 3.1641e-16, 1.9608e-11, 1.2239e+38, 1.7557e+04,
        1.2927e+25, 1.7044e-18, 1.0562e-13])

In [15]:
b = torch.exp(a[0])
b

tensor(inf)