In [25]:
from torch.utils.data import DataLoader, Dataset
import torch

class ContrastiveDataset(Dataset):
    def __init__(self, n=10000, num_groups=10):
        self.n = n
        self.num_groups = num_groups
        self.range_size = self.n // self.num_groups
        self.ranges = [
            (self.range_size * i, self.range_size * (i + 1)) for i in range(self.num_groups)
        ]

        # Handle any remaining samples by adding them to the last range
        if self.range_size * self.num_groups < self.n:
            self.ranges[-1] = (self.ranges[-1][0], self.n)

    def __len__(self):
        return self.n

    def __getitem__(self, idx):
        anchor = idx

        range_idx = anchor // self.range_size
        if range_idx >= self.num_groups:
            range_idx = self.num_groups - 1  # Correct adjustment

        start, end = self.ranges[range_idx]
        positive_sample = torch.randint(start, end, (1,)).item()

        return (anchor, positive_sample)
    
class Encoder(torch.nn.Module):
    def __init__(self, n):
        super(Encoder, self).__init__()
        self.embedding=torch.nn.Embedding(n, 768)
        self.fc=torch.nn.Linear(768, 768)
        self.non_linearity=torch.nn.Tanh()

    def forward(self, x):
        emb=self.embedding(x)
        out=self.fc(emb)
        out=self.non_linearity(out)
        return out
    
class ContrastiveLoss(torch.nn.Module):
    def __init__(self, temperature=0.07):
        super(ContrastiveLoss, self).__init__()
        self.temperature = temperature

    def forward(self, z_i, z_j):
        import math
        batch_size = z_i.size(0)
        device = z_i.device

        # Normalize embeddings
        z_i = torch.nn.functional.normalize(z_i, dim=1)
        z_j = torch.nn.functional.normalize(z_j, dim=1)

        # Compute similarity logits
        logits = torch.matmul(z_i, z_j.T) / self.temperature  # Shape: [batch_size, batch_size]

        labels = torch.arange(batch_size).to(device)

        # Cross-entropy loss along rows
        loss_row = torch.nn.functional.cross_entropy(logits, labels)

        # Cross-entropy loss along columns
        loss_col = torch.nn.functional.cross_entropy(logits.T, labels)

        # Total loss
        loss = (loss_row + loss_col) / 2

        return loss

loss=ContrastiveLoss()
    
encoder=Encoder(10000)

# Initialize the dataset
dataset = ContrastiveDataset()

# Initialize DataLoader
dataloader = DataLoader(dataset, batch_size=10, shuffle=True, num_workers=0)

# Iterate over DataLoader
for batch_idx, batch in enumerate(dataloader):
    #print(f"Batch {batch_idx + 1}: {batch}")
    print(encoder(batch[0]).shape)
    print(loss(encoder(batch[0]), encoder(batch[1])))


torch.Size([10, 768])
10
tensor(103.6767, grad_fn=<MeanBackward0>)
torch.Size([10, 768])
10
tensor(103.5827, grad_fn=<MeanBackward0>)
torch.Size([10, 768])
10
tensor(103.6762, grad_fn=<MeanBackward0>)
torch.Size([10, 768])
10
tensor(103.5624, grad_fn=<MeanBackward0>)
torch.Size([10, 768])
10
tensor(103.6986, grad_fn=<MeanBackward0>)
torch.Size([10, 768])
10
tensor(103.7112, grad_fn=<MeanBackward0>)
torch.Size([10, 768])
10
tensor(103.6954, grad_fn=<MeanBackward0>)
torch.Size([10, 768])
10
tensor(103.6842, grad_fn=<MeanBackward0>)
torch.Size([10, 768])
10
tensor(103.6153, grad_fn=<MeanBackward0>)
torch.Size([10, 768])
10
tensor(103.5597, grad_fn=<MeanBackward0>)
torch.Size([10, 768])
10
tensor(103.4986, grad_fn=<MeanBackward0>)
torch.Size([10, 768])
10
tensor(103.6413, grad_fn=<MeanBackward0>)
torch.Size([10, 768])
10
tensor(103.7311, grad_fn=<MeanBackward0>)
torch.Size([10, 768])
10
tensor(103.7348, grad_fn=<MeanBackward0>)
torch.Size([10, 768])
10
tensor(103.5757, grad_fn=<MeanBackwar

In [28]:
class ContrastiveLoss(torch.nn.Module):
    def __init__(self, temperature=0.07):
        super(ContrastiveLoss, self).__init__()
        self.temperature = temperature

    def forward(self, z_i, z_j):
        import math
        batch_size = z_i.size(0)
        device = z_i.device

        # Normalize embeddings
        z_i = torch.nn.functional.normalize(z_i, dim=1)
        z_j = torch.nn.functional.normalize(z_j, dim=1)

        # Compute similarity logits
        logits = torch.matmul(z_i, z_j.T) / self.temperature  # Shape: [batch_size, batch_size]

        labels = torch.arange(batch_size).to(device)

        # Cross-entropy loss along rows
        loss_row = torch.nn.functional.cross_entropy(logits, labels)

        # Cross-entropy loss along columns
        loss_col = torch.nn.functional.cross_entropy(logits.T, labels)

        # Total loss
        loss = (loss_row + loss_col) / 2

        return loss
loss=ContrastiveLoss()

# Iterate over DataLoader
for batch_idx, batch in enumerate(dataloader):
    #print(f"Batch {batch_idx + 1}: {batch}")
    print(encoder(batch[0]).shape)
    print(loss(encoder(batch[0]), encoder(batch[1])))


torch.Size([10, 768])
torch.Size([])
tensor(2.7783, grad_fn=<DivBackward0>)
torch.Size([10, 768])
torch.Size([])
tensor(2.5830, grad_fn=<DivBackward0>)
torch.Size([10, 768])
torch.Size([])
tensor(2.3902, grad_fn=<DivBackward0>)
torch.Size([10, 768])
torch.Size([])
tensor(2.6882, grad_fn=<DivBackward0>)
torch.Size([10, 768])
torch.Size([])
tensor(2.1429, grad_fn=<DivBackward0>)
torch.Size([10, 768])
torch.Size([])
tensor(2.3347, grad_fn=<DivBackward0>)
torch.Size([10, 768])
torch.Size([])
tensor(2.9309, grad_fn=<DivBackward0>)
torch.Size([10, 768])
torch.Size([])
tensor(2.6554, grad_fn=<DivBackward0>)
torch.Size([10, 768])
torch.Size([])
tensor(2.4957, grad_fn=<DivBackward0>)
torch.Size([10, 768])
torch.Size([])
tensor(2.9186, grad_fn=<DivBackward0>)
torch.Size([10, 768])
torch.Size([])
tensor(2.6861, grad_fn=<DivBackward0>)
torch.Size([10, 768])
torch.Size([])
tensor(2.3382, grad_fn=<DivBackward0>)
torch.Size([10, 768])
torch.Size([])
tensor(2.5339, grad_fn=<DivBackward0>)
torch.Size([