In [2]:
# Mount Google Drive inside the Colab runtime; the project files are read from
# the '/content/gdrive' mount point in the cells below.
from google.colab import drive
drive.mount('/content/gdrive')

Mounted at /content/gdrive


In [3]:
import os
import sys
os.chdir('/content/gdrive/My Drive/MIT/Research/Notebooks/EnrichedTiSASRecV7-Pytorch')
sys.path.append("/content/gdrive/My Drive/MIT/Research/Notebooks/EnrichedTiSASRecV7-Pytorch")
!pwd

/content/gdrive/My Drive/MIT/Research/Notebooks/EnrichedTiSASRecV7-Pytorch


In [4]:
# Shell escape: list the current working directory to confirm the project files are visible.
!ls

 BooksWithCategoryPercentage_default	 final_run      README.md
 data					 inference.py   train_enriched_tisasrec_v7.ipynb
'EnrichedTiSASRecNotes - V7.txt'	 main.py        utils.py
 EnrichedTiSASRec_test_functions.ipynb	 model.py
 EnrichedTiSASRec_test_model.ipynb	 __pycache__


In [5]:
from itertools import islice

def take(n, iterable):
    """Collect at most ``n`` leading items of ``iterable`` into a new list.

    Only the first ``n`` items are consumed, so this is safe to call on long
    or lazy iterables (e.g. ``dict.items()`` of a large mapping).
    """
    leading_items = islice(iterable, n)
    return [item for item in leading_items]

In [19]:
import os
import torch
import pickle
import argparse

def str2bool(s):
    """Parse the exact strings ``'true'`` / ``'false'`` into a bool.

    Used as an argparse ``type=`` converter.

    Raises:
        ValueError: if ``s`` is anything other than ``'true'`` or ``'false'``
            (the match is case-sensitive).
    """
    parsed = {'true': True, 'false': False}
    if s in parsed:
        return parsed[s]
    raise ValueError('Not a valid boolean string')

# Run configuration, expressed as argparse options so the notebook mirrors a
# command-line entry point.  Each row is (flag, default, type, help); rows are
# registered in order, and a help of None means "no help text".
parser = argparse.ArgumentParser()
for _flag, _default, _type, _help in (
    ('--dataset', 'BooksWithCategoryPercentage', str,
     'The preprocess data file name, e.g. the Books.txt file will be Books'),
    ('--train_dir', 'default', str,
     'The directory to save the trained model. The directory will be named as: dataset_train_dir'),
    ('--batch_size', 1, int, None),
    ('--lr', 0.001, float, None),
    ('--maxlen', 5, int, None),
    ('--hidden_units', 4, int, None),
    ('--num_blocks', 2, int, None),
    ('--num_epochs', 100, int, None),
    ('--num_heads', 1, int, None),
    ('--dropout_rate', 0.4, float, None),
    ('--l2_emb', 0.0, float, None),
    ('--device', 'cuda', str, None),
    ('--inference_only', False, str2bool, None),
    ('--state_dict_path', None, str, None),
    ('--tensorboard_log_dir', "final_run", str, None),
    ('--time_span', 256, int, None),
):
    parser.add_argument(_flag, default=_default, type=_type, help=_help)

# An empty argv makes argparse return the defaults instead of trying to parse
# the notebook kernel's own command line.
args = parser.parse_args([])

# The --device default is always overridden by what the runtime actually offers.
if torch.cuda.is_available():
    args.device = 'cuda'
else:
    args.device = 'cpu'

In [20]:
from utils import data_partition

# Load the pre-split dataset; data_partition returns the train/valid/test
# mappings followed by the user, item, time and category counts.
dataset = data_partition(args.dataset)
(user_train, user_valid, user_test,
 usernum, itemnum, timenum, catnum) = dataset

# Peek at a single (user, interaction-list) pair rather than dumping the whole mapping.
print(take(1, user_train.items()))

Preparing data...
Preparing done...
[(1, [[13788, 1, 4, 9], [23426, 10, 5, 9], [11079, 12, 4, 9], [14643, 12, 5, 28], [38738, 12, 5, 9], [29945, 17, 5, 28], [16515, 17, 5, 28], [37090, 26, 4, 28], [10217, 37, 5, 4], [29256, 37, 5, 9], [14055, 37, 5, 9], [47793, 37, 5, 0], [29477, 40, 5, 1], [43344, 40, 5, 1], [42244, 43, 4, 0], [31694, 47, 5, 0], [40127, 69, 5, 9], [47738, 70, 4, 28], [14912, 82, 5, 28], [7986, 82, 5, 0], [26180, 94, 5, 28], [39976, 106, 5, 9], [20768, 110, 5, 0], [5254, 124, 5, 28], [47813, 124, 5, 28], [1800, 124, 5, 28], [17503, 127, 5, 9], [15333, 127, 5, 9], [38697, 128, 5, 28], [25387, 142, 5, 28]])]


In [21]:
# Number of full batches per pass over the training users (floor division, so
# a trailing partial batch is not counted).
num_batch = len(user_train) // args.batch_size
num_batch

57208

In [22]:
from utils import WarpSampler, Relation
import numpy as np

# On-disk cache for the relation matrix, keyed by dataset name, maxlen and time_span.
relation_matrix_path = 'data/relation_matrix_%s_%d_%d.pickle' % (args.dataset, args.maxlen, args.time_span)

try:
    # NOTE(security): pickle.load can execute arbitrary code; only load cache
    # files that this notebook produced itself.
    with open(relation_matrix_path, 'rb') as cache_file:
        relation_matrix = pickle.load(cache_file)
except (FileNotFoundError, EOFError, pickle.UnpicklingError):
    # Cache missing or truncated/corrupt: rebuild it and store it for the next run.
    # Any other failure (e.g. Ctrl-C, permission problems) is allowed to surface.
    relation_matrix = Relation(user_train, usernum, args.maxlen, args.time_span)
    with open(relation_matrix_path, 'wb') as cache_file:
        pickle.dump(relation_matrix, cache_file)

sampler = WarpSampler(user_train, usernum, itemnum, relation_matrix, batch_size=args.batch_size, maxlen=args.maxlen, n_workers=3)

# Pull one batch to check that the sampler works; next_batch() yields tuples,
# which are converted to ndarrays below.
u, seq, rat_seq, time_seq, time_matrix, cat_seq, pos, neg = sampler.next_batch() # tuples to ndarray

u, seq, rat_seq, cat_seq, pos, neg = np.array(u), np.array(seq), np.array(rat_seq), np.array(cat_seq), np.array(pos), np.array(neg)

time_seq, time_matrix = np.array(time_seq), np.array(time_matrix)

In [23]:
# Draw a fresh batch and print every component with its shape, so the layout
# the model will receive can be checked by eye.
u, seq, rat_seq, time_seq, time_matrix, cat_seq, pos, neg = sampler.next_batch() # tuples to ndarray

u, seq, rat_seq, cat_seq, pos, neg = np.array(u), np.array(seq), np.array(rat_seq), np.array(cat_seq), np.array(pos), np.array(neg)

time_seq, time_matrix = np.array(time_seq), np.array(time_matrix)

print("-----------------User----------------------")
print(u)
print("-----------------Seq-----------------------")
print(seq)
print(np.shape(seq))
print("-----------------Rating Seq-----------------------")
print(rat_seq)
print(np.shape(rat_seq))  # was np.shape(seq): report the shape of the array actually shown
print("-----------------Category Seq-----------------------")
print(cat_seq)
print(np.shape(cat_seq))  # was np.shape(rat_seq): report the shape of the array actually shown
print("-----------------Pos----------------------")
print(pos)
print(np.shape(pos))
print("-----------------Neg----------------------")
print(neg)
print(np.shape(neg))
print("-----------------Time Seq----------------------")
print(time_seq)
print(np.shape(time_seq))
print("-----------------Time Matrix----------------------")
print(time_matrix)
print(np.shape(time_matrix))

-----------------User----------------------
[37123]
-----------------Seq-----------------------
[[22382 40603 41758 44739  4627]]
(1, 5)
-----------------Rating Seq-----------------------
[[3 5 3 5 5]]
(1, 5)
-----------------Category Seq-----------------------
[[ 6 28  9  9 28]]
(1, 5)
-----------------Pos----------------------
[[40603 41758 44739  4627  9397]]
(1, 5)
-----------------Neg----------------------
[[20865 46944 19773 45343 39574]]
(1, 5)
-----------------Time Seq----------------------
[[383 392 396 455 475]]
(1, 5)
-----------------Time Matrix----------------------
[[[ 0  9 13 72 92]
  [ 9  0  4 63 83]
  [13  4  0 59 79]
  [72 63 59  0 20]
  [92 83 79 20  0]]]
(1, 5, 5)


In [24]:
# Shell escape: list the working directory again (same listing as the earlier `!ls` cell).
!ls

 BooksWithCategoryPercentage_default	 final_run      README.md
 data					 inference.py   train_enriched_tisasrec_v7.ipynb
'EnrichedTiSASRecNotes - V7.txt'	 main.py        utils.py
 EnrichedTiSASRec_test_functions.ipynb	 model.py
 EnrichedTiSASRec_test_model.ipynb	 __pycache__


### PointWiseFeedForward

In [25]:
import numpy as np
import torch
import sys

# Most negative finite Python float on this platform.  Not referenced by any
# of the cells visible in this notebook.
FLOAT_MIN = -sys.float_info.max

class PointWiseFeedForward(torch.nn.Module):
    """Position-wise feed-forward block with a residual connection.

    Two kernel-size-1 ``Conv1d`` layers (each followed by dropout, with a ReLU
    in between) are applied independently at every sequence position, and the
    block input is added back to the result.

    Args:
        hidden_units: channel count of both the input and the output.
        dropout_rate: dropout probability used after each convolution.
    """

    def __init__(self, hidden_units, dropout_rate):
        super().__init__()

        self.conv1 = torch.nn.Conv1d(hidden_units, hidden_units, kernel_size=1)
        self.dropout1 = torch.nn.Dropout(p=dropout_rate)
        self.relu = torch.nn.ReLU()
        self.conv2 = torch.nn.Conv1d(hidden_units, hidden_units, kernel_size=1)
        self.dropout2 = torch.nn.Dropout(p=dropout_rate)

    def forward(self, inputs):
        """Apply the block to ``inputs`` (channels on the last axis); output has the same shape."""
        # Conv1d wants (N, C, Length), so swap the last two axes going in ...
        channels_first = inputs.transpose(-1, -2)
        hidden = self.dropout1(self.conv1(channels_first))
        hidden = self.relu(hidden)
        projected = self.dropout2(self.conv2(hidden))
        # ... and swap them back coming out.
        outputs = projected.transpose(-1, -2)
        # Residual connection.
        outputs += inputs
        return outputs



### TimeAwareMultiHeadAttention

In [26]:
class TimeAwareMultiHeadAttention(torch.nn.Module):
    """Hand-rolled multi-head attention for the Ti/SASRec experiments.

    On top of the usual query/key/value projections, the attention scores and
    the attended output are enriched with four additive side signals: rating,
    category and absolute-position embeddings (one vector per key position)
    and a relative time-interval embedding (one vector per query/key pair).

    Args:
        hidden_size: channel count C of queries, keys and every side embedding.
        head_num: number of attention heads; must divide ``hidden_size``.
        dropout_rate: dropout probability applied to the attention weights.
        dev: device the owning model runs on (stored for reference).

    Raises:
        ValueError: if ``hidden_size`` is not a multiple of ``head_num``.
    """

    # Large negative but finite score used for masked positions.  A finite value
    # keeps a fully masked row from turning into NaN after softmax (0 * inf = nan).
    _MASK_VALUE = float(-2**32 + 1)

    def __init__(self, hidden_size, head_num, dropout_rate, dev):
        super(TimeAwareMultiHeadAttention, self).__init__()
        if hidden_size % head_num != 0:
            # Uneven heads would only fail later, inside forward, with an opaque
            # torch.cat size-mismatch error.
            raise ValueError(
                'hidden_size (%d) must be divisible by head_num (%d)' % (hidden_size, head_num))

        self.Q_w = torch.nn.Linear(hidden_size, hidden_size)
        self.K_w = torch.nn.Linear(hidden_size, hidden_size)
        self.V_w = torch.nn.Linear(hidden_size, hidden_size)

        self.dropout = torch.nn.Dropout(p=dropout_rate)
        self.softmax = torch.nn.Softmax(dim=-1)

        self.hidden_size = hidden_size
        self.head_num = head_num
        self.head_size = hidden_size // head_num
        self.dropout_rate = dropout_rate
        self.dev = dev

    def _split_heads(self, tensor, dim):
        """Cut ``tensor`` into per-head chunks along ``dim`` and stack them on the batch axis."""
        return torch.cat(torch.split(tensor, self.head_size, dim=dim), dim=0)

    def forward(self, queries, keys, time_mask, attn_mask, time_matrix_K, time_matrix_V,
                                            ratings_K, ratings_V,
                                            categories_K, categories_V,
                                            abs_pos_K, abs_pos_V):
        """Compute the enriched attention output.

        Args:
            queries, keys: (N, T, C) tensors; values are projected from ``keys``.
            time_mask: (N, T) bool tensor, True at positions to blank out (padding).
            attn_mask: (T, T) bool tensor, True where a query may NOT attend
                (used by the caller to enforce causality).
            time_matrix_K, time_matrix_V: (N, T, T, C) time-interval embeddings.
            ratings_K/V, categories_K/V, abs_pos_K/V: (N, T, C) side embeddings.

        Returns:
            (N, T, C) tensor.
        """
        Q, K, V = self.Q_w(queries), self.K_w(keys), self.V_w(keys)

        # head dim * batch dim for parallelization (h*N, T, C/h)
        Q_ = self._split_heads(Q, dim=2)
        K_ = self._split_heads(K, dim=2)
        V_ = self._split_heads(V, dim=2)

        # Time-interval embeddings carry the channel axis last: (h*N, T, T, C/h)
        time_matrix_K_ = self._split_heads(time_matrix_K, dim=3)
        time_matrix_V_ = self._split_heads(time_matrix_V, dim=3)
        abs_pos_K_ = self._split_heads(abs_pos_K, dim=2)
        abs_pos_V_ = self._split_heads(abs_pos_V, dim=2)

        ratings_K_ = self._split_heads(ratings_K, dim=2)
        ratings_V_ = self._split_heads(ratings_V, dim=2)

        categories_K_ = self._split_heads(categories_K, dim=2)
        categories_V_ = self._split_heads(categories_V, dim=2)

        # batched channel wise matmul to gen attention weights: every key-side
        # signal contributes its own query . key score, all summed together.
        attn_weights = Q_.matmul(torch.transpose(K_, 1, 2))
        attn_weights += Q_.matmul(torch.transpose(ratings_K_, 1, 2))
        attn_weights += Q_.matmul(torch.transpose(categories_K_, 1, 2))
        attn_weights += Q_.matmul(torch.transpose(abs_pos_K_, 1, 2))
        attn_weights += time_matrix_K_.matmul(Q_.unsqueeze(-1)).squeeze(-1)

        # scale by sqrt of the per-head channel size
        attn_weights = attn_weights / (K_.shape[-1] ** 0.5)

        # Masking.  The padding mask must be tiled head_num times along the batch
        # axis to line up with the head-stacked layout (see
        # https://github.com/pmixer/TiSASRec.pytorch/issues/2); it then broadcasts
        # over the key axis, while attn_mask broadcasts over the batch axis.
        # masked_fill writes the fill value directly on attn_weights' device,
        # avoiding a per-call CPU allocation followed by a device copy.
        time_mask = time_mask.unsqueeze(-1).repeat(self.head_num, 1, 1)  # (h*N, T, 1)
        attn_weights = attn_weights.masked_fill(time_mask, self._MASK_VALUE)  # True: pick padding
        attn_weights = attn_weights.masked_fill(attn_mask.unsqueeze(0), self._MASK_VALUE)  # enforcing causality

        attn_weights = self.softmax(attn_weights)
        attn_weights = self.dropout(attn_weights)

        outputs = attn_weights.matmul(V_)
        outputs += attn_weights.matmul(ratings_V_)
        outputs += attn_weights.matmul(categories_V_)
        outputs += attn_weights.matmul(abs_pos_V_)
        # (h*N, T, 1, T) @ (h*N, T, T, C/h) -> (h*N, T, 1, C/h).  Drop exactly the
        # singleton axis the matmul produced; squeezing after a reshape would
        # also swallow the channel axis whenever head_size == 1.
        outputs += attn_weights.unsqueeze(2).matmul(time_matrix_V_).squeeze(2)

        # (num_head * N, T, C / num_head) -> (N, T, C)
        outputs = torch.cat(torch.split(outputs, Q.shape[0], dim=0), dim=2) # div batch_size

        return outputs


### EnrichedTiSASRec

In [30]:
class EnrichedTiSASRec(torch.nn.Module):
    """Self-attentive sequential recommender enriched with rating, category and time-interval signals.

    The model embeds an item sequence, then runs it through ``args.num_blocks``
    blocks of (layer norm -> TimeAwareMultiHeadAttention -> residual ->
    layer norm -> PointWiseFeedForward).  Rating, category, absolute-position
    and time-interval embeddings are passed to every attention layer as
    separate key-side and value-side tensors.

    Args:
        user_num: number of users (stored, not otherwise used by this class).
        item_num: number of items; item id 0 is the padding id.
        time_num: accepted for interface compatibility; not used.
        cat_num: number of categories (stored; the category tables are currently
            hard-coded to 100 rows, see the note in ``__init__``).
        args: namespace providing ``device``, ``hidden_units``, ``dropout_rate``,
            ``maxlen``, ``time_span``, ``num_blocks`` and ``num_heads``.
    """

    def __init__(self, user_num, item_num, time_num, cat_num, args):
        super(EnrichedTiSASRec, self).__init__()

        self.user_num = user_num
        self.item_num = item_num
        self.cat_num = cat_num
        self.dev = args.device

        # TODO: loss += args.l2_emb for regularizing embedding vectors during training
        # https://stackoverflow.com/questions/42704283/adding-l1-l2-regularization-in-pytorch
        self.item_emb = torch.nn.Embedding(self.item_num+1, args.hidden_units, padding_idx=0)
        # Registered once here (the original assigned this attribute a second time further down).
        self.item_emb_dropout = torch.nn.Dropout(p=args.dropout_rate)

        self.rat_K_emb = torch.nn.Embedding(6, args.hidden_units) # new rating embedding Key, hardcode 6 for rating 1->5
        self.rat_V_emb = torch.nn.Embedding(6, args.hidden_units) # new rating embedding Value, hardcode 6 for rating 1->5
        # NOTE(review): the category tables are hard-coded to 100 rows and ignore
        # cat_num; any category id >= 100 would index out of range -- confirm
        # against the dataset before relying on this.
        self.cat_K_emb = torch.nn.Embedding(100, args.hidden_units) # new category embedding Key
        self.cat_V_emb = torch.nn.Embedding(100, args.hidden_units) # new category embedding Value

        self.abs_pos_K_emb = torch.nn.Embedding(args.maxlen, args.hidden_units)
        self.abs_pos_V_emb = torch.nn.Embedding(args.maxlen, args.hidden_units)
        # +1 so that an interval equal to time_span itself is a valid index.
        self.time_matrix_K_emb = torch.nn.Embedding(args.time_span+1, args.hidden_units)
        self.time_matrix_V_emb = torch.nn.Embedding(args.time_span+1, args.hidden_units)

        self.abs_pos_K_emb_dropout = torch.nn.Dropout(p=args.dropout_rate)
        self.abs_pos_V_emb_dropout = torch.nn.Dropout(p=args.dropout_rate)
        self.time_matrix_K_dropout = torch.nn.Dropout(p=args.dropout_rate)
        self.time_matrix_V_dropout = torch.nn.Dropout(p=args.dropout_rate)

        self.rat_K_emb_dropout = torch.nn.Dropout(p=args.dropout_rate)
        self.rat_V_emb_dropout = torch.nn.Dropout(p=args.dropout_rate)
        self.cat_K_emb_dropout = torch.nn.Dropout(p=args.dropout_rate)
        self.cat_V_emb_dropout = torch.nn.Dropout(p=args.dropout_rate)

        self.attention_layernorms = torch.nn.ModuleList() # to be Q for self-attention
        self.attention_layers = torch.nn.ModuleList()
        self.forward_layernorms = torch.nn.ModuleList()
        self.forward_layers = torch.nn.ModuleList()

        self.last_layernorm = torch.nn.LayerNorm(args.hidden_units, eps=1e-8)

        # One (attn layernorm, attention, ffn layernorm, ffn) group per block.
        for _ in range(args.num_blocks):
            new_attn_layernorm = torch.nn.LayerNorm(args.hidden_units, eps=1e-8)
            self.attention_layernorms.append(new_attn_layernorm)

            new_attn_layer = TimeAwareMultiHeadAttention(args.hidden_units,
                                                            args.num_heads,
                                                            args.dropout_rate,
                                                            args.device)
            self.attention_layers.append(new_attn_layer)

            new_fwd_layernorm = torch.nn.LayerNorm(args.hidden_units, eps=1e-8)
            self.forward_layernorms.append(new_fwd_layernorm)

            new_fwd_layer = PointWiseFeedForward(args.hidden_units, args.dropout_rate)
            self.forward_layers.append(new_fwd_layer)

    def seq2feats(self, user_ids, log_seqs, rating_seqs, cat_seqs, time_matrices):
        """Encode a batch of interaction sequences into per-position features.

        Args:
            user_ids: not used by this method.
            log_seqs: (N, T) numpy array of item ids, 0 marking padding.
            rating_seqs, cat_seqs: integer arrays aligned with ``log_seqs``.
            time_matrices: (N, T, T) integer array of pairwise time intervals.

        Returns:
            (N, T, hidden_units) tensor of layer-normalised features.

        Note:
            This method prints several intermediate embeddings on every call;
            the prints are inspection output for this test notebook.
        """
        seqs = self.item_emb(torch.LongTensor(log_seqs).to(self.dev))

        print("-----------------seqs-----------------------")
        print(seqs)
        print(np.shape(seqs))

        # Scale by sqrt(embedding dim), then regularise.
        seqs *= self.item_emb.embedding_dim ** 0.5
        seqs = self.item_emb_dropout(seqs)

        # Rating embeddings (key side and value side); these are not rescaled.
        ratings = torch.LongTensor(rating_seqs).to(self.dev)
        ratings_K = self.rat_K_emb(ratings)
        print("-----------------Rating Seq-----------------------")
        print(ratings_K)
        print(np.shape(ratings_K))

        ratings_K = self.rat_K_emb_dropout(ratings_K)
        ratings_V = self.rat_V_emb(ratings)
        ratings_V = self.rat_V_emb_dropout(ratings_V)

        # Category embeddings, scaled by sqrt(embedding dim) like the item embeddings.
        categories = torch.LongTensor(cat_seqs).to(self.dev)
        categories_K = self.cat_K_emb(categories)
        print("-----------------Category Seq-----------------------")
        print(categories_K)
        print(np.shape(categories_K))
        categories_K *= self.cat_K_emb.embedding_dim ** 0.5
        categories_K = self.cat_K_emb_dropout(categories_K)
        categories_V = self.cat_V_emb(categories)
        categories_V *= self.cat_V_emb.embedding_dim ** 0.5
        categories_V = self.cat_V_emb_dropout(categories_V)

        # Absolute positions 0..T-1, repeated for every sequence in the batch.
        positions = np.tile(np.array(range(log_seqs.shape[1])), [log_seqs.shape[0], 1])
        positions = torch.LongTensor(positions).to(self.dev)
        abs_pos_K = self.abs_pos_K_emb(positions)
        abs_pos_V = self.abs_pos_V_emb(positions)
        abs_pos_K = self.abs_pos_K_emb_dropout(abs_pos_K)
        abs_pos_V = self.abs_pos_V_emb_dropout(abs_pos_V)

        # Pairwise time intervals -> one embedding per (query, key) pair.
        time_matrices = torch.LongTensor(time_matrices).to(self.dev)
        time_matrix_K = self.time_matrix_K_emb(time_matrices)
        print("-----------------Time Matrix----------------------")
        print(time_matrix_K)
        print(np.shape(time_matrix_K))

        time_matrix_V = self.time_matrix_V_emb(time_matrices)
        time_matrix_K = self.time_matrix_K_dropout(time_matrix_K)
        time_matrix_V = self.time_matrix_V_dropout(time_matrix_V)

        # mask 0th items(placeholder for dry-run) in log_seqs
        # would be easier if 0th item could be an exception for training
        timeline_mask = torch.BoolTensor(log_seqs == 0).to(self.dev)
        seqs *= ~timeline_mask.unsqueeze(-1) # broadcast in last dim

        tl = seqs.shape[1] # time dim len for enforce causality
        # True strictly above the diagonal: a position must not attend to later ones.
        attention_mask = ~torch.tril(torch.ones((tl, tl), dtype=torch.bool, device=self.dev))

        for i in range(len(self.attention_layers)):
            # Self-attention, Q=layernorm(seqs), K=V=seqs
            Q = self.attention_layernorms[i](seqs)
            mha_outputs = self.attention_layers[i](Q, seqs,
                                            timeline_mask, attention_mask,
                                            time_matrix_K, time_matrix_V,
                                            ratings_K, ratings_V,
                                            categories_K, categories_V,
                                            abs_pos_K, abs_pos_V)
            # Residual is taken from the normalised queries, not the raw input.
            seqs = Q + mha_outputs

            # Point-wise Feed-forward, actually 2 Conv1D for channel wise fusion
            seqs = self.forward_layernorms[i](seqs)
            seqs = self.forward_layers[i](seqs)
            # Re-zero padded positions after every block.
            seqs *=  ~timeline_mask.unsqueeze(-1)

        log_feats = self.last_layernorm(seqs)

        return log_feats

    def forward(self, user_ids, log_seqs, rating_seqs, category_seqs, time_matrices, pos_seqs, neg_seqs): # for training
        """Score positive and negative target items at every sequence position.

        Returns:
            ``(pos_logits, neg_logits)``: raw (un-squashed) dot products between the
            sequence features and the embeddings of ``pos_seqs`` / ``neg_seqs``,
            one score per position.
        """
        log_feats = self.seq2feats(user_ids, log_seqs, rating_seqs, category_seqs, time_matrices)

        pos_embs = self.item_emb(torch.LongTensor(pos_seqs).to(self.dev))
        neg_embs = self.item_emb(torch.LongTensor(neg_seqs).to(self.dev))

        pos_logits = (log_feats * pos_embs).sum(dim=-1)
        neg_logits = (log_feats * neg_embs).sum(dim=-1)

        return pos_logits, neg_logits

    def predict(self, user_ids, log_seqs, rating_seqs, cat_seqs, time_matrices, item_indices): # for inference
        """Score candidate items using only the feature at the last sequence position.

        Returns:
            Raw logits, one per candidate in ``item_indices`` for each sequence.
        """
        log_feats = self.seq2feats(user_ids, log_seqs, rating_seqs, cat_seqs, time_matrices)

        final_feat = log_feats[:, -1, :] # only use last QKV classifier, a waste

        item_embs = self.item_emb(torch.LongTensor(item_indices).to(self.dev)) # (U, I, C)

        logits = item_embs.matmul(final_feat.unsqueeze(-1)).squeeze(-1)

        return logits # (U, I)


### Test data through model

In [31]:
# Build the model from the dataset statistics and the run configuration, move
# it to the selected device, and display its module tree.
model = EnrichedTiSASRec(usernum, itemnum, timenum, catnum, args).to(args.device)
model

EnrichedTiSASRec(
  (item_emb): Embedding(51023, 4, padding_idx=0)
  (item_emb_dropout): Dropout(p=0.4, inplace=False)
  (rat_K_emb): Embedding(6, 4)
  (rat_V_emb): Embedding(6, 4)
  (cat_K_emb): Embedding(100, 4)
  (cat_V_emb): Embedding(100, 4)
  (abs_pos_K_emb): Embedding(5, 4)
  (abs_pos_V_emb): Embedding(5, 4)
  (time_matrix_K_emb): Embedding(257, 4)
  (time_matrix_V_emb): Embedding(257, 4)
  (abs_pos_K_emb_dropout): Dropout(p=0.4, inplace=False)
  (abs_pos_V_emb_dropout): Dropout(p=0.4, inplace=False)
  (time_matrix_K_dropout): Dropout(p=0.4, inplace=False)
  (time_matrix_V_dropout): Dropout(p=0.4, inplace=False)
  (rat_K_emb_dropout): Dropout(p=0.4, inplace=False)
  (rat_V_emb_dropout): Dropout(p=0.4, inplace=False)
  (cat_K_emb_dropout): Dropout(p=0.4, inplace=False)
  (cat_V_emb_dropout): Dropout(p=0.4, inplace=False)
  (attention_layernorms): ModuleList(
    (0-1): 2 x LayerNorm((4,), eps=1e-08, elementwise_affine=True)
  )
  (attention_layers): ModuleList(
    (0-1): 2 x Tim

In [32]:
# Run one training-style forward pass on the batch sampled above.  The model's
# seq2feats step prints its intermediate embeddings, which is the output shown below.
pos_logits, neg_logits = model(u, seq, rat_seq, cat_seq, time_matrix, pos, neg)

-----------------seqs-----------------------
tensor([[[-0.8900,  1.3668, -2.1336, -1.6571],
         [ 1.4695, -0.6013, -0.0461, -1.1570],
         [-1.1222,  0.2531, -1.2986, -0.7379],
         [ 1.3656,  0.4033,  1.2069,  1.6462],
         [ 0.7947, -0.8965,  0.2563, -0.9455]]], grad_fn=<EmbeddingBackward0>)
torch.Size([1, 5, 4])
-----------------Rating Seq-----------------------
tensor([[[ 0.8759,  2.3151,  0.2084, -1.1902],
         [ 1.8937,  1.5236, -1.1071,  0.7474],
         [ 0.8759,  2.3151,  0.2084, -1.1902],
         [ 1.8937,  1.5236, -1.1071,  0.7474],
         [ 1.8937,  1.5236, -1.1071,  0.7474]]], grad_fn=<EmbeddingBackward0>)
torch.Size([1, 5, 4])
-----------------Category Seq-----------------------
tensor([[[ 0.1965, -0.4271, -0.7063,  1.4317],
         [ 1.0148,  0.3974,  0.3423, -0.9162],
         [ 0.5214, -0.4449, -0.9129, -0.1117],
         [ 0.5214, -0.4449, -0.9129, -0.1117],
         [ 1.0148,  0.3974,  0.3423, -0.9162]]], grad_fn=<EmbeddingBackward0>)
torch.

In [None]:
# Display the per-position scores the model assigned to the positive (true next) items.
pos_logits

tensor([[-12.9521,  18.6612,  -8.2262,  -3.8192,   1.4114]],
       grad_fn=<SumBackward1>)

In [None]:
# Binary targets matching the logits: 1 for every positive item, 0 for every
# sampled negative, created directly on the run device.
pos_labels = torch.ones(pos_logits.shape, device=args.device)
neg_labels = torch.zeros(neg_logits.shape, device=args.device)
(pos_labels, neg_labels)

(tensor([[1., 1., 1., 1., 1.]]), tensor([[0., 0., 0., 0., 0.]]))