In [5]:
import torch
from torch.nn import functional as F

In [29]:
# Small deterministic integer tensors so every later output can be checked by hand.
# (Alternative kept for reference: random inputs via torch.rand(256, 128) for each tensor.)
feat1 = torch.arange(6).reshape(2, 3)
feat2 = torch.arange(6, 12).reshape(2, 3)

In [32]:
# Join feat1 and feat2 along a new axis at position 1, so features[i, 0] is
# feat1[i] and features[i, 1] is feat2[i].
features = torch.stack([feat1, feat2], dim=1)

In [35]:
# Inspect the shape of the combined tensor built from feat1 and feat2.
features.shape

torch.Size([2, 2, 3])

In [34]:
# Shape of feat1 once a singleton axis has been inserted at position 1.
torch.unsqueeze(feat1, 1).shape

torch.Size([2, 1, 3])

In [30]:
# Print both source tensors, one after the other, for reference.
print(feat1, feat2, sep='\n')

tensor([[0, 1, 2],
        [3, 4, 5]])
tensor([[ 6,  7,  8],
        [ 9, 10, 11]])


In [33]:
# Show the values of the combined tensor built from feat1 and feat2.
features

tensor([[[ 0,  1,  2],
         [ 6,  7,  8]],

        [[ 3,  4,  5],
         [ 9, 10, 11]]])

In [43]:
# Re-check the shape of `features` (same inspection as earlier in the notebook).
features.shape

torch.Size([2, 2, 3])

In [40]:
# Split `features` into its slices along dim 1, then stack those slices
# vertically (dim 0) into one 2-D tensor.
contrast_feature = torch.cat(features.unbind(dim=1), dim=0)

In [45]:
# Show the tuple of slices obtained by removing dim 1 from `features`.
features.unbind(dim=1)

(tensor([[0, 1, 2],
         [3, 4, 5]]),
 tensor([[ 6,  7,  8],
         [ 9, 10, 11]]))

In [44]:
# Show the tensor obtained by concatenating the dim-1 slices of `features` along dim 0.
contrast_feature

tensor([[ 0,  1,  2],
        [ 3,  4,  5],
        [ 6,  7,  8],
        [ 9, 10, 11]])

In [62]:
# Anchors and contrasts are the same rows here; both counts are fixed at 2.
anchor_count = contrast_count = 2
anchor_feature = contrast_feature

# Dot product of every anchor row with every contrast row, followed by a true
# division by 1.0 (presumably a temperature placeholder — TODO confirm). The
# division also turns the integer products into floating-point values.
anchor_dot_contrast = torch.matmul(anchor_feature, contrast_feature.T) / 1.0

anchor_dot_contrast

tensor([[  5.,  14.,  23.,  32.],
        [ 14.,  50.,  86., 122.],
        [ 23.,  86., 149., 212.],
        [ 32., 122., 212., 302.]])

In [56]:
# Row-wise maximum of the similarity matrix, kept as a column (keepdim=True);
# the accompanying argmax indices are not needed.
logits_max, _ = anchor_dot_contrast.max(dim=1, keepdim=True)
logits_max.shape

torch.Size([4, 1])

In [60]:
# Subtract each row's maximum so the largest entry in every row becomes 0;
# the maximum is detached so it is treated as a constant by autograd.
logits = torch.sub(anchor_dot_contrast, logits_max.detach())
print(logits.shape, logits, sep='\n')

torch.Size([4, 4])
tensor([[ -27.,  -18.,   -9.,    0.],
        [-108.,  -72.,  -36.,    0.],
        [-189., -126.,  -63.,    0.],
        [-270., -180.,  -90.,    0.]])


In [63]:
# Build the tiled mask from a fresh identity matrix instead of from `mask`
# itself: the original `mask = mask.repeat(...)` relied on a `mask` left over
# from a cell that no longer exists and grew the tensor on every re-run.
# The identity has one row per sample in `features` (its first dimension),
# and is tiled anchor_count x contrast_count times.
base_mask = torch.eye(features.shape[0])
mask = base_mask.repeat(anchor_count, contrast_count)
mask

tensor([[1., 0., 1., 0.],
        [0., 1., 0., 1.],
        [1., 0., 1., 0.],
        [0., 1., 0., 1.]])

In [74]:
# Ones everywhere except the main diagonal: for each row i, scatter writes a 0
# at column i. The number of rows is taken from `mask` itself rather than the
# hard-coded `2 * anchor_count`, which was only correct for a batch of 2.
logits_mask = torch.scatter(
    torch.ones_like(mask),
    1,
    torch.arange(mask.shape[0]).view(-1, 1),
    0
)
logits_mask

tensor([[0., 1., 1., 1.],
        [1., 0., 1., 1.],
        [1., 1., 0., 1.],
        [1., 1., 1., 0.]])

In [66]:
# Zero out the diagonal of `mask` by multiplying element-wise with logits_mask.
mask = torch.mul(mask, logits_mask)
mask

tensor([[0., 0., 1., 0.],
        [0., 0., 0., 1.],
        [1., 0., 0., 0.],
        [0., 1., 0., 0.]])

In [77]:
# Element-wise exponential of the shifted logits (each row's maximum maps to 1).
logits.exp()

tensor([[1.8795e-12, 1.5230e-08, 1.2341e-04, 1.0000e+00],
        [0.0000e+00, 5.3802e-32, 2.3195e-16, 1.0000e+00],
        [0.0000e+00, 0.0000e+00, 4.3596e-28, 1.0000e+00],
        [0.0000e+00, 0.0000e+00, 8.1940e-40, 1.0000e+00]])

In [68]:
# Exponentiate the shifted logits, then zero the diagonal entries via logits_mask.
exp_logits = logits.exp() * logits_mask
exp_logits

tensor([[0.0000e+00, 1.5230e-08, 1.2341e-04, 1.0000e+00],
        [0.0000e+00, 0.0000e+00, 2.3195e-16, 1.0000e+00],
        [0.0000e+00, 0.0000e+00, 0.0000e+00, 1.0000e+00],
        [0.0000e+00, 0.0000e+00, 8.1940e-40, 0.0000e+00]])

In [81]:
# Show what a row-wise sum with keepdim=True looks like next to the source tensor.
print(torch.sum(feat1, dim=1, keepdim=True))
print(feat1)

tensor([[ 3],
        [12]])
tensor([[0, 1, 2],
        [3, 4, 5]])


In [75]:
# Subtract the log of each row's summed (diagonal-masked) exponentials from
# that row's logits.
log_prob = logits - exp_logits.sum(dim=1, keepdim=True).log()
log_prob

tensor([[-2.7000e+01, -1.8000e+01, -9.0001e+00, -1.2337e-04],
        [-1.0800e+02, -7.2000e+01, -3.6000e+01,  0.0000e+00],
        [-1.8900e+02, -1.2600e+02, -6.3000e+01,  0.0000e+00],
        [-1.8000e+02, -9.0000e+01,  0.0000e+00,  9.0000e+01]])

In [82]:
# Keep only the log-probabilities at positions where `mask` is 1.
torch.mul(mask, log_prob)

tensor([[  -0.0000,   -0.0000,   -9.0001,   -0.0000],
        [  -0.0000,   -0.0000,   -0.0000,    0.0000],
        [-189.0000,   -0.0000,   -0.0000,    0.0000],
        [  -0.0000,  -90.0000,    0.0000,    0.0000]])

In [76]:
# Per-row average of the masked log-probabilities: sum of the kept entries
# divided by how many entries `mask` keeps in that row.
mean_log_prob_pos = torch.mul(mask, log_prob).sum(dim=1) / mask.sum(dim=1)
mean_log_prob_pos

tensor([  -9.0001,    0.0000, -189.0000,  -90.0000])

In [85]:
# Row sums of `mask`, kept as a column vector.
torch.sum(mask, dim=1, keepdim=True)

tensor([[1.],
        [1.],
        [1.],
        [1.]])

In [111]:
import os
import pandas as pd
import torch.utils.data as util_data
from torch.utils.data import Dataset, random_split
import math

class AugmentPairSamples(Dataset):
    """Index-aligned collection of texts and labels, with an optional second text.

    Args:
        train_x_0: Indexable collection of primary texts.
        train_y: Indexable collection of labels; must be as long as ``train_x_0``.
        train_x_1: Optional indexable collection of second texts; when given it
            must be as long as the other two.

    Each item is a dict with keys ``'text0'`` and ``'label'``, plus ``'text1'``
    when a second text collection was supplied.
    """

    def __init__(self, train_x_0, train_y, train_x_1=None):
        # Every supplied collection must have the same number of entries.
        n_labels = len(train_y)
        assert n_labels == len(train_x_0)
        if train_x_1 is not None:
            assert n_labels == len(train_x_1)

        self.train_x_0 = train_x_0
        self.train_x_1 = train_x_1
        self.train_y = train_y

    def __len__(self):
        """Return the number of labelled samples."""
        return len(self.train_y)

    def __getitem__(self, idx):
        """Return the sample at ``idx`` as a dict (keys ordered text0, [text1], label)."""
        sample = {'text0': self.train_x_0[idx]}
        if self.train_x_1 is not None:
            sample['text1'] = self.train_x_1[idx]
        sample['label'] = self.train_y[idx]
        return sample

# Location of the paired-text CSV and the fraction of samples used for training.
datapath = './datasets/augmented/contextual_20/'
dataname = 'stackoverflow'
ratio = 0.8

train_data = pd.read_csv(os.path.join(datapath, dataname))

# Missing texts are replaced by '.' so every sample has a string in both columns.
train_text_0 = train_data['text0'].fillna('.').values
train_text_1 = train_data['text1'].fillna('.').values
train_label = train_data['label'].astype(int).values

all_dataset = AugmentPairSamples(train_text_0, train_label, train_text_1)

# Size the split from the dataset that is actually being split. The previous
# version read len(train_dataset) before train_dataset existed, and rounded
# both parts up independently, which can exceed the dataset length.
# Taking the validation size as the remainder guarantees the parts sum exactly.
train_sample = math.ceil(len(all_dataset) * ratio)
val_sample = len(all_dataset) - train_sample

# Fixed generator seed keeps the train/validation membership reproducible.
train_dataset, val_dataset = random_split(all_dataset, [train_sample, val_sample],
                                          generator=torch.Generator().manual_seed(42))

train_loader = util_data.DataLoader(train_dataset, batch_size=256, shuffle=True, num_workers=4)
# The validation loader must iterate over the validation split (it previously
# wrapped train_dataset by mistake).
val_loader = util_data.DataLoader(val_dataset, batch_size=256, shuffle=True, num_workers=4)


In [115]:


# Split the full dataset rather than `train_dataset` itself: splitting a
# variable into itself shrinks the training split every time the cell is
# re-run. The validation size is the remainder, so the two lengths always sum
# to the dataset length that random_split requires.
n_total = len(all_dataset)
n_train = math.ceil(n_total * ratio)
train_dataset, val_dataset = random_split(
    all_dataset,
    [n_train, n_total - n_train],
    generator=torch.Generator().manual_seed(42),
)

In [116]:
# Number of samples in the training split.
len(train_dataset)

16000

In [118]:
# Number of samples in the validation split.
len(val_dataset)

4000

In [119]:
# Peek at the first sample of the validation split (a dict of texts and label).
val_dataset[0]

{'text0': 'Hibernate One-to-Many cascade efficiency',
 'text1': 'hibernate : - infinitely - many cascade efficiency',
 'label': 12}