In [None]:
import torch 
import torch.nn as nn
from pytorch_metric_learning.losses import NTXentLoss
from torch.utils.data import Dataset, DataLoader
import random

In [None]:
class ContrastiveLearning(torch.nn.Module):
    """MLP encoder followed by a projection head.

    The encoder maps ``input_dim -> 1028 -> 512 -> embedding_dim`` and the
    projector maps ``embedding_dim -> 256 -> projection_dim``. Every hidden
    layer is Linear -> BatchNorm1d -> ReLU -> Dropout; the final layer of
    each stack is a plain Linear.

    Args:
        input_dim: number of input features per sample.
        embedding_dim: size of the encoder output.
        projection_dim: size of the projector output returned by ``forward``.
        dropout_rate: dropout probability used after every hidden layer.
    """

    def __init__(self, input_dim, embedding_dim, projection_dim, dropout_rate=0.25):
        super().__init__()

        # Layers are created in the same order as they appear in the stacks,
        # encoder first, then projector.
        self.encoder = nn.Sequential(
            *self._hidden_block(input_dim, 1028, dropout_rate),
            *self._hidden_block(1028, 512, dropout_rate),
            nn.Linear(512, embedding_dim),
        )

        self.projector = nn.Sequential(
            *self._hidden_block(embedding_dim, 256, dropout_rate),
            nn.Linear(256, projection_dim),
        )

    @staticmethod
    def _hidden_block(in_features, out_features, dropout_rate):
        """Return the layers of one Linear -> BatchNorm1d -> ReLU -> Dropout block."""
        return [
            nn.Linear(in_features, out_features),
            nn.BatchNorm1d(out_features),
            nn.ReLU(),
            nn.Dropout(dropout_rate),
        ]

    def forward(self, x):
        """Encode ``x`` and return the projection of the resulting embedding."""
        return self.projector(self.encoder(x))

In [None]:
# NT-Xent contrastive loss used by the training loop; temperature is 0.1.
criterion = NTXentLoss(temperature=0.1)

In [None]:
import pandas as pd

# Read the clustered table; the first spreadsheet column becomes the row index.
dataset = pd.read_excel("data/clustered_data.xlsx", index_col=0)

# Separate the cluster assignment from the feature columns.
cluster_labels = dataset["cluster_labels"]
dataset = dataset.drop(columns="cluster_labels")

In [None]:
# Run on the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Model sizes: one input per feature column, 32-d embedding, 8-d projection.
input_dim = len(dataset.columns)
embedding_dim, projection_dim = 32, 8

model = ContrastiveLearning(
    input_dim=input_dim,
    embedding_dim=embedding_dim,
    projection_dim=projection_dim,
)
model = model.to(device)

optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

In [None]:
class ClusterContrastiveDataset(Dataset):
    """Dataset yielding ``(anchor, positive)`` pairs from the same cluster.

    For index ``idx`` the anchor is row ``idx`` of ``data`` and the positive
    is a randomly chosen *other* row carrying the same cluster label. When the
    anchor is the only member of its cluster, the anchor itself is returned as
    the positive instead of raising.

    Args:
        data: pandas DataFrame of numeric features, accessed positionally
            via ``.iloc``.
        cluster_labels: pandas Series of cluster labels, positionally aligned
            with ``data``.

    Note:
        The label -> positions lookup is built once in ``__init__``; changing
        ``cluster_labels`` afterwards is not reflected in the sampling.
    """

    def __init__(self, data, cluster_labels):
        self.data = data
        self.cluster_labels = cluster_labels

        # Group row positions by label once, so __getitem__ does not have to
        # rescan every label on each access.
        self._cluster_members = {}
        for position, label in enumerate(cluster_labels):
            self._cluster_members.setdefault(label, []).append(position)

    def __len__(self):
        """Return the number of rows in ``data``."""
        return len(self.data)

    @staticmethod
    def _row_to_tensor(row):
        """Convert one DataFrame row (a Series) into a float32 tensor."""
        # Go through NumPy explicitly rather than handing the pandas Series
        # to torch.tensor directly.
        return torch.tensor(row.to_numpy(dtype="float32"))

    def __getitem__(self, idx):
        """Return ``(anchor_tensor, positive_tensor)`` for position ``idx``."""
        label = self.cluster_labels.iloc[idx]

        # Same-cluster positions other than the anchor, in ascending order.
        candidates = [
            position
            for position in self._cluster_members.get(label, ())
            if position != idx
        ]
        # A cluster with a single member has no partner: pair it with itself.
        positive_idx = random.choice(candidates) if candidates else idx

        item_tensor = self._row_to_tensor(self.data.iloc[idx])
        positive_item_tensor = self._row_to_tensor(self.data.iloc[positive_idx])
        return item_tensor, positive_item_tensor

In [None]:
# `dataset` is rebound from the feature DataFrame to the torch Dataset here.
# If this cell is executed again, `dataset` is already the wrapped object, so
# unwrap its DataFrame first instead of wrapping a Dataset inside a Dataset.
feature_frame = dataset.data if isinstance(dataset, Dataset) else dataset
dataset = ClusterContrastiveDataset(data=feature_frame, cluster_labels=cluster_labels)

dataloader = DataLoader(dataset, batch_size=256, shuffle=True, num_workers=0)

In [None]:
def train(num_epochs, log_interval):
    """Train the notebook's ``model`` on anchor/positive pairs.

    Relies on the module-level ``model``, ``optimizer``, ``criterion``,
    ``dataloader`` and ``device``.

    Args:
        num_epochs: number of passes over ``dataloader``.
        log_interval: print the batch loss every ``log_interval`` batches;
            ``0`` or ``None`` disables per-batch printing.

    Returns:
        list[float]: the mean batch loss of each epoch (``nan`` for an epoch
        in which no batch was processed).
    """
    epoch_losses = []

    for epoch in range(num_epochs):
        model.train()
        total_loss = 0.0
        batches_seen = 0

        for batch_idx, (data_i, data_j) in enumerate(dataloader):
            # The model defined in this notebook contains BatchNorm1d layers,
            # which raise in training mode on a single-sample batch; skip a
            # trailing batch of size 1 rather than aborting the whole run.
            if data_i.size(0) < 2:
                continue

            data_i = data_i.float().to(device)
            data_j = data_j.float().to(device)

            optimizer.zero_grad()

            projections_i = model(data_i)
            projections_j = model(data_j)

            # Stack both views: row k and row k + batch_size form a positive
            # pair (they share a label below); all other rows act as negatives.
            projections = torch.cat([projections_i, projections_j], dim=0)

            batch_size = projections_i.size(0)
            labels = torch.arange(batch_size, dtype=torch.long, device=device)
            labels = torch.cat((labels, labels), dim=0)

            # Calculate the contrastive loss and take one optimisation step.
            loss = criterion(projections, labels)
            loss.backward()
            optimizer.step()

            batch_loss = loss.item()
            total_loss += batch_loss
            batches_seen += 1

            # `log_interval` of 0/None would otherwise divide by zero.
            if log_interval and batch_idx % log_interval == 0:
                print(f'Epoch: {epoch}, Batch: {batch_idx}, Loss: {batch_loss}')

        epoch_losses.append(total_loss / batches_seen if batches_seen else float("nan"))

    return epoch_losses

In [None]:
# Train for 20 epochs, printing the loss after every batch.
train(num_epochs=20, log_interval=1)

In [None]:
model.eval()

# Convert the pandas Series at row position 100 to a float tensor, add a
# leading batch dimension, and move it to the same device as the model
# (otherwise the forward pass fails when the model lives on the GPU).
single_sample = (
    torch.tensor(dataset.data.iloc[100].values)
    .float()
    .unsqueeze(0)
    .to(device)
)

# Inference only: no autograd graph is needed for the embedding.
with torch.no_grad():
    sample_embedding = model.encoder(single_sample)

sample_embedding