In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import numpy as np
import pandas as pd
from sklearn.model_selection import train_test_split
from sklearn.metrics import roc_auc_score, average_precision_score
import torch.nn.functional as F
import random
from tqdm import tqdm
import os
import pickle
import psutil
import gc
import warnings
warnings.filterwarnings("ignore")

random.seed(42)

In [2]:
print(psutil.virtual_memory().used /1e10)
print(psutil.virtual_memory().available / 1e10)

6.2113603584
209.7234714624


In [9]:
input_dir = '../data/60k_pmid_dataset_chunks_sentencewise_negatives'

df_list = [pd.read_pickle(os.path.join(input_dir, f)) for f in tqdm(sorted(os.listdir(input_dir)))]
df = pd.concat(df_list, ignore_index=True)

df.head(5)

100%|█████████████████████████████████████████| 195/195 [00:11<00:00, 17.15it/s]


Unnamed: 0,pmid,abstract,abstract_embeddings,terms,relations,negatives,term_embeddings,pair_embeddings,negatives_embeddings
0,33383987,Temperature and Solvent Effects on H2 Splittin...,"[-0.28345045, -0.54292804, 0.16380529, 0.15494...","[fold, lead, cooling, tetrahydrofuran, splitti...","[(individual, kinetic barrier), (lead, tetrahy...","[(fold, tetrahydrofuran), (cases, toluene), (e...","[[-0.28223613, -0.0036575377, 0.08042424, -0.1...","[([-0.7495256, 0.36571825, -0.84759635, -0.030...","[([-0.28223613, -0.0036575377, 0.08042424, -0...."
1,33383988,Mechanism for Rapid Conversion of Amines to Am...,"[-0.2850943, -0.80024165, -0.11266672, 0.14250...","[Particle, Conversion, Amines, Ammonium Salts,...","[(Amines, Ammonium Salts)]","[(Particle, Rapid), (Conversion, Rapid), (Conv...","[[-0.5166602, 0.11687658, -0.09206407, -0.0107...","[([-0.34281886, 0.30733573, -0.13176091, 0.137...","[([-0.5166602, 0.11687658, -0.09206407, -0.010..."
2,33383989,Bisecting GlcNAc Protein N-Glycosylation Is Ch...,"[-0.040046073, -1.0489765, -0.37665504, -0.493...","[Protein N-Glycosylation, Adipogenesis, Human,...","[(Adipogenesis, Human)]","[(Adipogenesis, Protein N-Glycosylation), (Cha...","[[-0.7968025, -0.20326202, -0.566039, -0.13663...","[([-0.21563068, 0.16044843, -0.39353308, -0.15...","[([-0.21563068, 0.16044843, -0.39353308, -0.15..."
3,33383992,Association of Exposure to Cattle with Self-Re...,"[-0.310526, -0.5039563, 0.6601208, -0.08117764...","[relationship, human, study, bovine tuberculos...","[(health, human)]","[(health, relationship), (bovine tuberculosis,...","[[-0.25283906, 0.11320739, -0.30629858, -0.038...","[([0.006518982, 0.25178033, -0.0104505485, -0....","[([0.006518982, 0.25178033, -0.0104505485, -0...."
4,33383995,The Effect of Preoperative Video Based Pain Tr...,"[0.091471456, -0.07006314, -0.11821665, -0.306...","[Analgesic, Control Group, Effect, Video, Pain...","[(Postoperative Pain, Total Knee Arthroplasty)...","[(Analgesic, Postoperative Pain), (Effect, Vid...","[[-0.51707596, 0.17515965, -0.48447058, -0.293...","[([-0.38955075, 0.314926, -0.6473092, -0.47112...","[([-0.51707596, 0.17515965, -0.48447058, -0.29..."


In [9]:
len(df)

194874

In [13]:
df = df[df['negatives_embeddings'].apply(lambda x : len(x) > 10)]
# df = df[df['pair_embeddings'].apply(lambda x: len(x) > 0)]
# len(df[df['negatives_embeddings'].apply(lambda x: len(x) > 10)])
len(df)

149625

In [14]:
train_pmids, temp_pmids = train_test_split(pmids, test_size=0.2, random_state=42)
test_pmids, val_pmids = train_test_split(temp_pmids, test_size=0.5, random_state=42)

train_df = df[df['pmid'].isin(train_pmids)]
test_df = df[df['pmid'].isin(test_pmids)]
val_df = df[df['pmid'].isin(val_pmids)]



In [15]:
class RelationDataset(Dataset):
    def __init__(self, df, num_negatives=10):
        self.df = df
        self.num_negatives = num_negatives

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        row = self.df.iloc[idx]

        abstract_emb = torch.tensor(row['abstract_embeddings'], dtype=torch.float32)
        
        pos_pair, pos_embedding = random.choice(list(zip(row['relations'], row['pair_embeddings'])))
        pos_sample = torch.cat([
            abstract_emb, 
            torch.tensor(pos_embedding, dtype=torch.float32).flatten()
        ])
        pos_term1, pos_term2 = pos_pair

        hard_negatives = []
        used_pairs = set()

        for neg_pair, neg_emb in zip(row['negatives'], row['negatives_embeddings']):
            if (
                (neg_pair[0] == pos_term1 or neg_pair[0] == pos_term2 or
                neg_pair[1] == pos_term1 or neg_pair[1] == pos_term2)
                and neg_pair != pos_pair
            ):
                hard_negatives.append(torch.cat([
                    abstract_emb, 
                    torch.tensor(neg_emb, dtype=torch.float32).flatten()
                ]))
                used_pairs.add(tuple(neg_pair))

            if len(hard_negatives) >= 10:
                break

        remaining_negatives = [
            (neg_pair, neg_emb) for neg_pair, neg_emb in zip(row['negatives'], row['negatives_embeddings'])
            if tuple(neg_pair) not in used_pairs and tuple(neg_pair) != tuple(pos_pair)
        ]

        additional_negs_needed = self.num_negatives - len(hard_negatives)

        if additional_negs_needed > 0:
            sampled_random = random.choices(remaining_negatives, k=additional_negs_needed)
            hard_negatives += [
                torch.cat([abstract_emb, torch.tensor(neg_emb, dtype=torch.float32).flatten()])
                for _, neg_emb in sampled_random
            ]

        X = torch.stack([pos_sample] + hard_negatives)
        y = torch.tensor([1] + [0]*self.num_negatives, dtype=torch.float32)

        return X, y
        

In [16]:
def collate_fn(batch):
    batch = [item for item in batch if item is not None]
    if len(batch) == 0:
        return None, None
    X, y = zip(*batch)
    return torch.stack(X), torch.stack(y)

In [17]:
train_dataset = RelationDataset(train_df)
test_dataset = RelationDataset(test_df)
val_dataset = RelationDataset(val_df)

train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True, collate_fn=collate_fn)
test_loader = DataLoader(test_dataset, batch_size=64, shuffle=False, collate_fn=collate_fn)
val_loader = DataLoader(val_dataset, batch_size=64, shuffle=False, collate_fn=collate_fn)

In [18]:
train_df.iloc[19818]

pmid                                                             33481430
abstract                Hypopigmentation in Extramammary Paget Disease...
abstract_embeddings     [0.82409924, -0.7536595, 0.05372125, -0.839472...
terms                   [Prognostic Factor, Important, Rate, High, Poo...
relations                [(Extramammary Paget Disease, Hypopigmentation)]
negatives               [(High, Outcome), (Prognostic Factor, Rate), (...
term_embeddings         [[-0.29985824, 0.13411663, -0.0663986, -0.2757...
pair_embeddings         [([-0.47598863, -0.030381847, -0.7351567, -0.6...
negatives_embeddings    [([-0.22897896, 0.21014515, -0.063291, 0.04567...
Name: 90191, dtype: object

In [19]:
i = 19818
train_dataset[i][0].shape, train_dataset[i][1].shape

(torch.Size([11, 2304]), torch.Size([11]))

In [20]:
class RelationClassifier(nn.Module):
    def __init__(self, input_dim):
        super().__init__()
        self.fc = nn.Sequential(
            nn.Linear(input_dim, 768),
            nn.ReLU(),
            nn.Linear(768, 256),
            nn.ReLU(),
            nn.Linear(256, 1),
            nn.Sigmoid()
        )

    def forward(self, X):
        return self.fc(X)

In [26]:
input_dim = 2304
model = RelationClassifier(input_dim)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device)
model.to(device)
criterion = nn.MarginRankingLoss(margin=0.2, reduction='sum')
optimizer = optim.Adam(model.parameters(), lr=1e-4)

cuda


In [27]:
def evaluate(model, dataloader, criterion, device):
    model.eval()
    total_loss = 0
    all_scores = []
    all_labels = []

    with torch.no_grad():
        for X, y in dataloader:
            X, y = X.to(device), y.to(device)

            pos_samples = X[y==1]
            neg_samples = X[y==0]

            pos_scores = model(pos_samples).squeeze()
            neg_scores = model(neg_samples).squeeze()

            if pos_scores.dim() == 0:
                pos_scores = pos_scores.unsqueeze(0)

            comb_scores = torch.cartesian_prod(pos_scores, neg_scores).to(device)

            target = torch.ones(len(comb_scores), device=device)
            loss = criterion(comb_scores[:,0], comb_scores[:,1], target)

            total_loss += loss
            all_scores.extend(pos_scores.tolist())
            all_labels.extend([1]*len(pos_scores))
            all_scores.extend(neg_scores.tolist())
            all_labels.extend([0]*len(neg_scores))

    roc_auc = roc_auc_score(all_labels, all_scores)
    pr_auc = average_precision_score(all_labels, all_scores)

    return total_loss / len(dataloader), roc_auc, pr_auc        

In [28]:
def train(model, dataloader, optimizer, criterion, device):
    model.train()
    total_loss = 0

    for X, y in dataloader:
        X, y = X.to(device), y.to(device)

        pos_samples = X[y==1]
        neg_samples = X[y==0]

        pos_scores = model(pos_samples).squeeze()
        neg_scores = model(neg_samples).squeeze()

        if pos_scores.dim() == 0:
            pos_scores = pos_scores.unsqueeze(0)
            
        comb_scores = torch.cartesian_prod(pos_scores, neg_scores).to(device)

        target = torch.ones(len(comb_scores), device=device)
        loss = criterion(comb_scores[:,0], comb_scores[:,1] , target)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        total_loss += loss.item()

    return total_loss / len(dataloader)

In [29]:
num_epochs = 50
best_train_loss = 999999
best_val_roc_auc = 0
patience = 10
epochs_no_improve = 0

for epoch in tqdm(range(num_epochs), desc="Training epoch..."):
    print(f"\nEpoch [{epoch+1}/{num_epochs}]")
    
    train_loss = train(model, train_loader, optimizer, criterion, device)
    print(f"Training Loss: {train_loss:.4f}")
    
    val_loss, val_roc_auc, val_pr_auc = evaluate(model, val_loader, criterion, device)
    print(f"Validation Loss: {val_loss:.4f}, ROC AUC: {val_roc_auc:.4f}, PR AUC: {val_pr_auc:.4f}")

    if val_roc_auc > best_val_roc_auc:
        best_val_roc_auc = val_roc_auc
        epochs_no_improve = 0
        torch.save(model.state_dict(), '../models/60000_pmid_model_v4.pth')
        print(f"Saving best current model state at epoch: {epoch+1}")
        
    else:
        epochs_no_improve += 1

    if epochs_no_improve >= patience:
        print(f"Early stopping due to no improvement. Total training epochs: {epoch+1}")
        break



Training epoch...:   0%|                                 | 0/50 [00:00<?, ?it/s]


Epoch [1/50]
Training Loss: 5137.7473


Training epoch...:   2%|▍                      | 1/50 [01:23<1:08:14, 83.55s/it]

Validation Loss: 4544.7280, ROC AUC: 0.7583, PR AUC: 0.2409
Saving best current model state at epoch: 1

Epoch [2/50]
Training Loss: 4322.8659


Training epoch...:   4%|▉                      | 2/50 [02:47<1:06:52, 83.59s/it]

Validation Loss: 4151.4082, ROC AUC: 0.7798, PR AUC: 0.2700
Saving best current model state at epoch: 2

Epoch [3/50]
Training Loss: 3944.3195


Training epoch...:   6%|█▍                     | 3/50 [04:10<1:05:33, 83.69s/it]

Validation Loss: 3967.4482, ROC AUC: 0.7910, PR AUC: 0.2880
Saving best current model state at epoch: 3

Epoch [4/50]
Training Loss: 3637.8535


Training epoch...:   8%|█▊                     | 4/50 [05:34<1:04:06, 83.61s/it]

Validation Loss: 3883.2493, ROC AUC: 0.7951, PR AUC: 0.2940
Saving best current model state at epoch: 4

Epoch [5/50]
Training Loss: 3389.6401


Training epoch...:  10%|██▎                    | 5/50 [06:57<1:02:29, 83.33s/it]

Validation Loss: 3754.5022, ROC AUC: 0.8035, PR AUC: 0.3030
Saving best current model state at epoch: 5

Epoch [6/50]
Training Loss: 3178.4571


Training epoch...:  12%|██▊                    | 6/50 [08:20<1:01:10, 83.42s/it]

Validation Loss: 3699.2463, ROC AUC: 0.8067, PR AUC: 0.3060
Saving best current model state at epoch: 6

Epoch [7/50]
Training Loss: 2965.6802


Training epoch...:  14%|███▌                     | 7/50 [09:44<59:47, 83.42s/it]

Validation Loss: 3679.4312, ROC AUC: 0.8073, PR AUC: 0.3099
Saving best current model state at epoch: 7

Epoch [8/50]
Training Loss: 2782.0364


Training epoch...:  16%|████                     | 8/50 [11:07<58:25, 83.47s/it]

Validation Loss: 3651.3008, ROC AUC: 0.8096, PR AUC: 0.3083
Saving best current model state at epoch: 8

Epoch [9/50]
Training Loss: 2593.7267


Training epoch...:  18%|████▌                    | 9/50 [12:31<57:10, 83.66s/it]

Validation Loss: 3681.9390, ROC AUC: 0.8108, PR AUC: 0.3113
Saving best current model state at epoch: 9

Epoch [10/50]
Training Loss: 2426.3282


Training epoch...:  20%|████▊                   | 10/50 [13:55<55:47, 83.70s/it]

Validation Loss: 3656.0688, ROC AUC: 0.8112, PR AUC: 0.3124
Saving best current model state at epoch: 10

Epoch [11/50]
Training Loss: 2248.1289


Training epoch...:  22%|█████▎                  | 11/50 [15:19<54:20, 83.60s/it]

Validation Loss: 3746.4004, ROC AUC: 0.8100, PR AUC: 0.3047

Epoch [12/50]
Training Loss: 2086.8444


Training epoch...:  24%|█████▊                  | 12/50 [16:42<52:51, 83.47s/it]

Validation Loss: 3772.6135, ROC AUC: 0.8105, PR AUC: 0.3186

Epoch [13/50]
Training Loss: 1947.8735


Training epoch...:  26%|██████▏                 | 13/50 [18:06<51:31, 83.55s/it]

Validation Loss: 3804.5552, ROC AUC: 0.8108, PR AUC: 0.3163

Epoch [14/50]
Training Loss: 1800.2664


Training epoch...:  28%|██████▋                 | 14/50 [19:29<50:05, 83.49s/it]

Validation Loss: 3971.5264, ROC AUC: 0.8042, PR AUC: 0.3116

Epoch [15/50]
Training Loss: 1640.9357


Training epoch...:  30%|███████▏                | 15/50 [20:52<48:39, 83.42s/it]

Validation Loss: 3801.9587, ROC AUC: 0.8114, PR AUC: 0.3202
Saving best current model state at epoch: 15

Epoch [16/50]
Training Loss: 1522.6089


Training epoch...:  32%|███████▋                | 16/50 [22:16<47:17, 83.45s/it]

Validation Loss: 4005.1917, ROC AUC: 0.8066, PR AUC: 0.3033

Epoch [17/50]
Training Loss: 1415.3376


Training epoch...:  34%|████████▏               | 17/50 [23:39<45:53, 83.45s/it]

Validation Loss: 4005.2571, ROC AUC: 0.8096, PR AUC: 0.3103

Epoch [18/50]
Training Loss: 1301.4743


Training epoch...:  36%|████████▋               | 18/50 [25:02<44:29, 83.42s/it]

Validation Loss: 4021.5867, ROC AUC: 0.8052, PR AUC: 0.3115

Epoch [19/50]
Training Loss: 1213.5364


Training epoch...:  38%|█████████               | 19/50 [26:26<43:06, 83.45s/it]

Validation Loss: 4038.6626, ROC AUC: 0.8075, PR AUC: 0.3075

Epoch [20/50]
Training Loss: 1119.6207


Training epoch...:  40%|█████████▌              | 20/50 [27:49<41:39, 83.31s/it]

Validation Loss: 4111.1538, ROC AUC: 0.8078, PR AUC: 0.3090

Epoch [21/50]
Training Loss: 1026.6409


Training epoch...:  42%|██████████              | 21/50 [29:12<40:16, 83.32s/it]

Validation Loss: 4126.9189, ROC AUC: 0.8098, PR AUC: 0.3091

Epoch [22/50]
Training Loss: 968.9470


Training epoch...:  44%|██████████▌             | 22/50 [30:36<38:51, 83.29s/it]

Validation Loss: 4326.8667, ROC AUC: 0.8034, PR AUC: 0.3090

Epoch [23/50]
Training Loss: 887.0327


Training epoch...:  46%|███████████             | 23/50 [31:58<37:25, 83.17s/it]

Validation Loss: 4289.6387, ROC AUC: 0.8043, PR AUC: 0.3010

Epoch [24/50]
Training Loss: 822.8012


Training epoch...:  48%|███████████▌            | 24/50 [33:22<36:01, 83.15s/it]

Validation Loss: 4290.9966, ROC AUC: 0.8040, PR AUC: 0.3060

Epoch [25/50]
Training Loss: 787.9953


Training epoch...:  48%|███████████▌            | 24/50 [34:44<37:38, 86.87s/it]

Validation Loss: 4282.2412, ROC AUC: 0.8006, PR AUC: 0.2948
Early stopping due to no improvement. Total training epochs: 25





In [45]:
torch.save(model.state_dict(), '../models/60000_pmid_model.pth')

In [42]:
test_loss, test_roc_auc, test_pr_auc = evaluate(model, test_loader, criterion, device)

print(f"Test Loss: {test_loss:.4f}, ROC AUC: {test_roc_auc:.4f}, PR AUC: {test_pr_auc:.4f}\n")

Test Loss: 1.5394, ROC AUC: 0.9190, PR AUC: 0.2627



In [43]:
X = train_dataset[1000][0]
y = train_dataset[1000][1]
X, y

(tensor([[-0.4558, -0.8255,  0.2465,  ...,  0.0486,  0.4510,  0.7287],
         [-0.4558, -0.8255,  0.2465,  ..., -0.4669,  0.2165,  0.2500],
         [-0.4558, -0.8255,  0.2465,  ..., -0.2303,  0.1060,  0.1860],
         ...,
         [-0.4558, -0.8255,  0.2465,  ..., -0.2458,  0.1991,  0.2456],
         [-0.4558, -0.8255,  0.2465,  ..., -0.2303,  0.1060,  0.1860],
         [-0.4558, -0.8255,  0.2465,  ..., -0.0977, -0.6139,  0.4175]]),
 tensor([1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
         0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
         0., 0., 0., 0., 0.]))

In [44]:
pos_samples = X[y==1].to(device)
neg_samples = X[y==0].to(device)
pos_samples, neg_samples

(tensor([[-0.4558, -0.8255,  0.2465,  ...,  0.0486,  0.4510,  0.7287]],
        device='cuda:0'),
 tensor([[-0.4558, -0.8255,  0.2465,  ..., -0.4669,  0.2165,  0.2500],
         [-0.4558, -0.8255,  0.2465,  ..., -0.2303,  0.1060,  0.1860],
         [-0.4558, -0.8255,  0.2465,  ..., -0.0573, -0.0356,  0.3023],
         ...,
         [-0.4558, -0.8255,  0.2465,  ..., -0.2458,  0.1991,  0.2456],
         [-0.4558, -0.8255,  0.2465,  ..., -0.2303,  0.1060,  0.1860],
         [-0.4558, -0.8255,  0.2465,  ..., -0.0977, -0.6139,  0.4175]],
        device='cuda:0'))

In [45]:
pos_samples.shape

torch.Size([1, 2304])

In [46]:
pos_scores = model(pos_samples).squeeze()
pos_scores

tensor(0.7222, device='cuda:0', grad_fn=<SqueezeBackward0>)

In [47]:
pos_scores.squeeze()

tensor(0.7222, device='cuda:0', grad_fn=<SqueezeBackward0>)

In [48]:
neg_scores = model(neg_samples).squeeze()
neg_scores

tensor([6.1359e-02, 7.0541e-01, 2.5924e-01, 6.1359e-02, 1.0361e-01, 3.9563e-13,
        2.5813e-25, 6.1359e-02, 2.3524e-02, 6.4892e-01, 6.1359e-02, 2.5592e-01,
        4.7449e-01, 6.7143e-03, 4.8653e-01, 5.5683e-02, 3.8976e-01, 8.7977e-02,
        2.5924e-01, 1.4943e-01, 5.2637e-02, 1.6211e-11, 5.0831e-01, 2.4123e-01,
        6.1359e-02, 9.3152e-02, 1.5495e-01, 6.1850e-01, 5.7529e-01, 6.1359e-02,
        2.5313e-01, 2.4999e-01, 6.1359e-02, 1.6554e-04, 4.7209e-02, 1.0675e-01,
        5.9973e-01, 1.2220e-02, 6.7506e-01, 6.1089e-01], device='cuda:0',
       grad_fn=<SqueezeBackward0>)

In [49]:
test_tensor = torch.cartesian_prod(pos_scores.unsqueeze(0), neg_scores).to(device)
test_tensor

tensor([[7.2215e-01, 6.1359e-02],
        [7.2215e-01, 7.0541e-01],
        [7.2215e-01, 2.5924e-01],
        [7.2215e-01, 6.1359e-02],
        [7.2215e-01, 1.0361e-01],
        [7.2215e-01, 3.9563e-13],
        [7.2215e-01, 2.5813e-25],
        [7.2215e-01, 6.1359e-02],
        [7.2215e-01, 2.3524e-02],
        [7.2215e-01, 6.4892e-01],
        [7.2215e-01, 6.1359e-02],
        [7.2215e-01, 2.5592e-01],
        [7.2215e-01, 4.7449e-01],
        [7.2215e-01, 6.7143e-03],
        [7.2215e-01, 4.8653e-01],
        [7.2215e-01, 5.5683e-02],
        [7.2215e-01, 3.8976e-01],
        [7.2215e-01, 8.7977e-02],
        [7.2215e-01, 2.5924e-01],
        [7.2215e-01, 1.4943e-01],
        [7.2215e-01, 5.2637e-02],
        [7.2215e-01, 1.6211e-11],
        [7.2215e-01, 5.0831e-01],
        [7.2215e-01, 2.4123e-01],
        [7.2215e-01, 6.1359e-02],
        [7.2215e-01, 9.3152e-02],
        [7.2215e-01, 1.5495e-01],
        [7.2215e-01, 6.1850e-01],
        [7.2215e-01, 5.7529e-01],
        [7.221

In [50]:
target = torch.tensor([1]*len(test_tensor)).to(device)
target

tensor([1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
        1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1], device='cuda:0')

In [51]:
criterion(test_tensor[:,0], test_tensor[:, 1], target)

tensor(0.7787, device='cuda:0', grad_fn=<SumBackward0>)

In [178]:
test_tensor[:,0]

tensor([0.5158, 0.5158, 0.5158, 0.5158, 0.5158, 0.5158, 0.5158, 0.5158, 0.5158,
        0.5158, 0.5158, 0.5158, 0.5158, 0.5158, 0.5158, 0.5158, 0.5158, 0.5158,
        0.5158, 0.5158, 0.5163, 0.5163, 0.5163, 0.5163, 0.5163, 0.5163, 0.5163,
        0.5163, 0.5163, 0.5163, 0.5163, 0.5163, 0.5163, 0.5163, 0.5163, 0.5163,
        0.5163, 0.5163, 0.5163, 0.5163, 0.5225, 0.5225, 0.5225, 0.5225, 0.5225,
        0.5225, 0.5225, 0.5225, 0.5225, 0.5225, 0.5225, 0.5225, 0.5225, 0.5225,
        0.5225, 0.5225, 0.5225, 0.5225, 0.5225, 0.5225], device='cuda:0',
       grad_fn=<SelectBackward0>)

In [179]:
test_tensor[:,1]

tensor([0.5159, 0.5162, 0.5190, 0.5207, 0.5171, 0.5167, 0.5213, 0.5224, 0.5200,
        0.5207, 0.5226, 0.5222, 0.5100, 0.5197, 0.5236, 0.5221, 0.5215, 0.5195,
        0.5129, 0.5121, 0.5159, 0.5162, 0.5190, 0.5207, 0.5171, 0.5167, 0.5213,
        0.5224, 0.5200, 0.5207, 0.5226, 0.5222, 0.5100, 0.5197, 0.5236, 0.5221,
        0.5215, 0.5195, 0.5129, 0.5121, 0.5159, 0.5162, 0.5190, 0.5207, 0.5171,
        0.5167, 0.5213, 0.5224, 0.5200, 0.5207, 0.5226, 0.5222, 0.5100, 0.5197,
        0.5236, 0.5221, 0.5215, 0.5195, 0.5129, 0.5121], device='cuda:0',
       grad_fn=<SelectBackward0>)

In [181]:
criterion(test_tensor[0,0], test_tensor[0,1], torch.tensor(1).to(device))

tensor(0.1001, device='cuda:0', grad_fn=<ClampMinBackward0>)

In [186]:
criterion(torch.tensor(0.5158), torch.tensor(0.5129), torch.tensor(1))

tensor(0.0971)