In [4]:
import random

import matplotlib.pyplot as plt
import numpy as np
import torch
from sklearn.manifold import TSNE
from torch.utils.data import DataLoader, Dataset

class ContrastiveDataset(Dataset):
    """Index-only dataset yielding ``(anchor, positive)`` id pairs.

    The integer ids ``0 .. n-1`` are split into ``num_groups`` contiguous
    ranges of ``n // num_groups`` ids each; any remainder is folded into the
    last range.  Item ``idx`` is the pair ``(idx, positive)`` where
    ``positive`` is drawn uniformly at random from the range that contains
    ``idx`` (it may therefore coincide with the anchor itself).

    Args:
        n: Total number of ids in the dataset.
        num_groups: Number of contiguous groups the ids are split into.

    Raises:
        ValueError: If ``num_groups < 1`` or ``n < num_groups`` (every group
            must contain at least one id).
    """

    def __init__(self, n=10000, num_groups=10):
        if num_groups < 1:
            raise ValueError(f"num_groups must be >= 1, got {num_groups}")
        if n < num_groups:
            # n // num_groups would be 0, making every group empty and every
            # later lookup divide by zero.
            raise ValueError(
                f"n ({n}) must be >= num_groups ({num_groups}) so that no group is empty"
            )

        self.n = n
        self.num_groups = num_groups
        self.range_size = n // num_groups

        # Half-open [start, end) id ranges, one per group.
        self.ranges = [
            (self.range_size * i, self.range_size * (i + 1)) for i in range(num_groups)
        ]
        # Fold any remainder into the last range (no-op when n divides evenly).
        self.ranges[-1] = (self.ranges[-1][0], n)

        # Group label of every id, kept for visualization.
        self.labels = torch.zeros(n, dtype=torch.long)
        for group, (start, end) in enumerate(self.ranges):
            self.labels[start:end] = group

    def __len__(self):
        """Return the total number of ids (``n``)."""
        return self.n

    def __getitem__(self, idx):
        """Return ``(anchor, positive)`` as two 0-d ``torch.long`` tensors.

        Negative indices count from the end, like a regular sequence.

        Raises:
            IndexError: If ``idx`` is outside ``[-n, n)``.  This is also what
                lets plain iteration (``for pair in dataset``) terminate.
        """
        anchor = int(idx)
        if anchor < 0:
            anchor += self.n
        if not 0 <= anchor < self.n:
            raise IndexError(
                f"index {idx} is out of range for dataset of size {self.n}"
            )

        # Ids in the remainder tail compute a group one past the end; they
        # belong to the last group.
        range_idx = min(anchor // self.range_size, self.num_groups - 1)

        start, end = self.ranges[range_idx]
        positive_sample = torch.randint(start, end, (1,)).item()

        return torch.tensor(anchor, dtype=torch.long), torch.tensor(positive_sample, dtype=torch.long)

class Encoder(torch.nn.Module):
    """Map integer item ids to dense vectors: embedding -> linear -> tanh.

    Args:
        n: Number of distinct ids (size of the embedding table).
        embedding_dim: Width of both the embedding and the linear layer's
            output.  Defaults to 768.
    """

    def __init__(self, n, embedding_dim=768):
        super().__init__()
        self.embedding = torch.nn.Embedding(n, embedding_dim)
        self.fc = torch.nn.Linear(embedding_dim, embedding_dim)
        self.non_linearity = torch.nn.Tanh()

    def forward(self, x):
        """Encode a tensor of ids.

        Args:
            x: Integer tensor of ids used to index the embedding table.

        Returns:
            Tensor shaped like ``x`` with a trailing ``embedding_dim`` axis;
            values lie in ``[-1, 1]`` because of the final tanh.
        """
        return self.non_linearity(self.fc(self.embedding(x)))

# Seed the RNGs used below (weight init, DataLoader shuffling, queue sampling)
# so that Restart Kernel -> Run All reproduces the same run.
SEED = 0
random.seed(SEED)
torch.manual_seed(SEED)

# Number of unique items in the dataset
n = 10000

# Query encoder (encoder_q) and key encoder (encoder_k) share one architecture
encoder_q = Encoder(n)
encoder_k = Encoder(n)

# Start the key encoder from exactly the same weights as the query encoder
encoder_k.load_state_dict(encoder_q.state_dict())

# Build the dataset from the same `n` the encoders were sized with, so the
# embedding table always covers every id the dataset can produce
dataset = ContrastiveDataset(n=n, num_groups=10)

# Set the batch size
batch_size = 25

# Create a DataLoader
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

# Number of keys to sample for the queue
queue_size = 1000

# Sample `queue_size` distinct ids (without replacement) to seed the queue
key_indices = torch.tensor(random.sample(range(n), queue_size), dtype=torch.long)

# Place the key encoder and its inputs on the same device (GPU if available).
# NOTE(review): only encoder_k is moved here; encoder_q stays where it was
# created and must be moved before it is used with tensors on `device`.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
encoder_k.to(device)
key_indices = key_indices.to(device)

# Set the key encoder to evaluation mode
encoder_k.eval()

# The initial key embeddings are computed without tracking gradients
with torch.no_grad():
    key_embeddings = encoder_k(key_indices)
print(key_embeddings.shape)

torch.Size([1000, 768])


In [2]:
# Sanity check: pull a single batch from the loader and inspect it
batch = next(iter(dataloader))

# A batch holds two tensors: the anchor ids and their sampled positive ids
anchors = batch[0]
positives = batch[1]

# Both should be 1-D with one entry per sample in the batch
print('Anchors shape:', anchors.shape)
print('Positives shape:', positives.shape)



Anchors shape: torch.Size([25])
Positives shape: torch.Size([25])
