In [2]:
####################################################################################################################################
####################################################################################################################################
####################################################################################################################################
# DATASET AND DATALOADER FOR GAUSSIAN SEQUENCE
####################################################################################################################################
####################################################################################################################################
####################################################################################################################################

import torch
import numpy as np
import random
from torch.utils.data import DataLoader

class IndexedGaussianSequence:
    """Lazily evaluated Gaussian bump that can be indexed like a sequence.

    No samples are stored. ``seq[i]`` evaluates
    ``exp(-(i - center)**2 / (2 * std**2))`` on demand and returns it as a
    0-dim tensor placed on ``device``.
    """

    def __init__(self, n, center, std, device='cpu'):
        """
        Args:
            n (int): Length of the sequence.
            center (int or float): Position where Gaussian is centered.
            std (float): Standard deviation of the Gaussian.
            device (str): Device where the result should be moved.
        """
        self.n = n
        self.center = center
        self.std = std
        self.device = device

    def __getitem__(self, index):
        """Return the Gaussian value at ``index`` as a 0-dim tensor.

        ``index`` is not restricted to ``[0, n)``: the closed-form expression
        is evaluated for any position, so out-of-range lookups do not raise.
        """
        # TODO: for non-synthetic images or sequences, out-of-bound indices
        # need explicit handling. For this analytic Gaussian the same formula
        # is valid everywhere, so no in-range/out-of-range branch is needed.
        value = np.exp(-(index - self.center) ** 2 / (2 * self.std ** 2))

        # Build the tensor directly on the target device rather than creating
        # it on the CPU and copying it afterwards.
        return torch.tensor(value, device=self.device)

    def __len__(self):
        """Return the nominal length ``n`` of the sequence."""
        return self.n

class GaussianSequenceDataset(torch.utils.data.Dataset):
    """Map-style dataset of lazily evaluated Gaussian sequences.

    Samples alternate between two classes: even indices yield a Gaussian
    centered at ``p`` with label 0, odd indices a Gaussian centered at ``q``
    with label 1. Each sample is an ``(IndexedGaussianSequence, label)`` pair.
    """

    def __init__(self, n, p, q, size, std, device='cpu'):
        """
        Args:
            n (int): Length of the sequence.
            p (int): Position for label 0 where Gaussian is centered.
            q (int): Position for label 1 where Gaussian is centered.
            size (int): Number of samples in the dataset.
            std (float): Standard deviation of the Gaussian.
            device (str): Device where the result should be moved.
        """
        self.n = n
        self.p = p
        self.q = q
        self.size = size
        self.std = std
        self.device = device

    def __len__(self):
        """Return the number of samples in the dataset."""
        return self.size

    def __getitem__(self, idx):
        """Return the ``(sequence, label)`` pair stored at position ``idx``.

        Negative indices count from the end, as for built-in sequences.

        Raises:
            IndexError: If ``idx`` lies outside ``[-size, size)``.
        """
        # Raising IndexError past the end is what lets plain iteration
        # (``for sample in dataset`` / ``list(dataset)``) terminate; without
        # it the legacy __getitem__ iteration protocol would never stop.
        if not -self.size <= idx < self.size:
            raise IndexError(
                f"index {idx} is out of range for dataset of size {self.size}"
            )
        if idx < 0:
            idx += self.size

        label = idx % 2  # Alternate between labels 0 and 1
        center = self.p if label == 0 else self.q
        sequence = IndexedGaussianSequence(self.n, center, self.std, self.device)

        return sequence, label

def custom_collate_fn(batch):
    """Collate ``(sequence, label)`` samples without stacking the sequences.

    Args:
        batch: Iterable of ``(sequence, label)`` pairs.

    Returns:
        A ``(sequences, labels)`` tuple: the sequences as a plain Python
        list, and the labels packed into a single tensor.
    """
    # Transpose the list of pairs into one column per field.
    seq_column, label_column = zip(*batch)
    collated_labels = torch.tensor(label_column)
    return [*seq_column], collated_labels

### NOTE: usage example. With a custom collate function that simply gathers the sequences into a plain Python list (and the labels into a tensor), the DataLoader can batch these sequence objects as-is.
# Create a DataLoader with the custom collate function
# dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True, collate_fn=custom_collate_fn)



####################################################################################################################################
####################################################################################################################################
####################################################################################################################################
# DIFFERENTIABLE INDEXING FOR ARBITRARY INDEXABLE OBJECTS
####################################################################################################################################
####################################################################################################################################
####################################################################################################################################

import torch

class DifferentiableIndexFunction(torch.autograd.Function):
    """Look up fractional positions in an indexable object, differentiably.

    The value at a real-valued index is the linear interpolation between the
    values at the neighbouring integer positions (floor and ceil). The
    gradient w.r.t. the index is the local slope ``value_ceil - value_floor``.
    At an exactly-integer index floor and ceil coincide, so the stored value
    is returned unchanged and the gradient is zero.
    """

    @staticmethod
    def forward(ctx, indexable_obj, indices):
        """Interpolate ``indexable_obj`` at the real-valued ``indices``.

        Args:
            ctx: Autograd context used to stash tensors for ``backward``.
            indexable_obj: Any object supporting ``obj[int]`` whose items can
                be converted with ``float()`` (numbers, 0-dim tensors, ...).
            indices (Tensor): Floating-point positions, of any shape.

        Returns:
            Tensor: Interpolated values with the same shape as ``indices``.
        """
        device = indices.device  # All results live where the indices live

        # Integer neighbours of every position (kept as floats for the math)
        floor_pos = torch.floor(indices)
        ceil_pos = torch.ceil(indices)

        def lookup(positions):
            # Query the object one integer position at a time and build the
            # result directly on the target device.
            flat = positions.long().reshape(-1).tolist()
            values = torch.tensor(
                [float(indexable_obj[i]) for i in flat],
                dtype=torch.float32,
                device=device,
            )
            return values.reshape(indices.shape)

        values_floor = lookup(floor_pos)
        values_ceil = lookup(ceil_pos)

        # Only the neighbouring values are needed to compute the slope later
        ctx.save_for_backward(values_floor, values_ceil)

        # Linear interpolation. Written as "floor value + fraction * slope"
        # so that an integer index (fraction == 0, floor == ceil) yields the
        # stored value; weighting by (ceil - index) and (index - floor) would
        # make both weights zero there and return 0.
        fraction = indices - floor_pos
        return values_floor + fraction * (values_ceil - values_floor)

    @staticmethod
    def backward(ctx, grad_output):
        """Return ``(None, grad_indices)``; the indexable object gets no gradient."""
        values_floor, values_ceil = ctx.saved_tensors

        # d(output)/d(index) is the slope between the two neighbours
        grad_indices = (values_ceil - values_floor) * grad_output

        # No gradient for indexable_obj
        return None, grad_indices


class DifferentiableIndexFunctionBatch(torch.autograd.Function):
    """Batched, differentiable lookup at fractional positions.

    Row ``b`` of ``indices_batch`` is looked up in ``indexable_objs[b]``. The
    value at a real-valued index is the linear interpolation between the
    values at its floor and ceil positions, and the gradient w.r.t. the index
    is the local slope ``value_ceil - value_floor``. At an exactly-integer
    index floor and ceil coincide, so the stored value is returned unchanged
    and the gradient is zero.
    """

    @staticmethod
    def forward(ctx, indexable_objs, indices_batch):
        """Interpolate every object at its own row of positions.

        Args:
            ctx: Autograd context used to stash tensors for ``backward``.
            indexable_objs: Sequence of objects supporting ``obj[int]`` whose
                items can be converted with ``float()``.
            indices_batch (Tensor): Floating-point positions of shape
                ``(batch, k)``; ``batch`` must equal ``len(indexable_objs)``.

        Returns:
            Tensor: Interpolated values of shape ``(batch, k)``.
        """
        device = indices_batch.device  # All results live where the indices live

        # Ensure that indexable_objs has the same length as the batch dimension
        assert len(indexable_objs) == indices_batch.shape[0], f"indexable_objs length {len(indexable_objs)} must match batch dimension {indices_batch.shape[0]}"

        # Integer neighbours of every position (kept as floats for the math)
        floor_pos = torch.floor(indices_batch)
        ceil_pos = torch.ceil(indices_batch)

        def lookup(positions):
            # Query each object at its own row of integer positions and build
            # the result directly on the target device. The final reshape
            # keeps the (batch, k) shape even when the batch is empty.
            rows = positions.long().tolist()
            values = torch.tensor(
                [[float(obj[j]) for j in row] for obj, row in zip(indexable_objs, rows)],
                dtype=torch.float32,
                device=device,
            )
            return values.reshape(indices_batch.shape)

        values_floor = lookup(floor_pos)
        values_ceil = lookup(ceil_pos)

        # Only the neighbouring values are needed to compute the slope later
        ctx.save_for_backward(values_floor, values_ceil)

        # Linear interpolation. Written as "floor value + fraction * slope"
        # so that an integer index (fraction == 0, floor == ceil) yields the
        # stored value; weighting by (ceil - index) and (index - floor) would
        # make both weights zero there and return 0.
        fraction = indices_batch - floor_pos
        return values_floor + fraction * (values_ceil - values_floor)

    @staticmethod
    def backward(ctx, grad_output_batch):
        """Return ``(None, grad_indices_batch)``; the objects get no gradient."""
        values_floor, values_ceil = ctx.saved_tensors

        # d(output)/d(index) is the slope between the two neighbours
        grad_indices_batch = (values_ceil - values_floor) * grad_output_batch

        # No gradient for indexable_objs
        return None, grad_indices_batch

# Wrapper function to use in your model
def differentiable_index_batch(indexable_objs, indices_batch):
    """Functional entry point for ``DifferentiableIndexFunctionBatch``.

    Args:
        indexable_objs: One indexable object per batch row.
        indices_batch: Tensor of real-valued positions, one row per object.

    Returns:
        Whatever ``DifferentiableIndexFunctionBatch.apply`` produces for the
        given arguments.
    """
    interpolated = DifferentiableIndexFunctionBatch.apply(indexable_objs, indices_batch)
    return interpolated



####################################################################################################################################
####################################################################################################################################
####################################################################################################################################
# SIMPLE TWO PARAMETER MODEL FOR ESTIMATING THE GAUSSIAN CENTER
####################################################################################################################################
####################################################################################################################################
####################################################################################################################################

import torch
import torch.nn as nn

class PQEstimatorModel(nn.Module):
    """Model with two learnable scalar positions, ``p`` and ``q``.

    The forward pass reads every sequence in the batch at the current
    ``p`` and ``q`` positions through the differentiable index function.
    """

    def __init__(self, sequence_length):
        """
        Args:
            sequence_length (int): Length of the sequences; both positions
                are initialized uniformly at random in
                ``[0, sequence_length - 1]``.
        """
        super().__init__()

        # Learnable parameters, initialized as uniform random floats.
        # p is drawn before q.
        upper_bound = sequence_length - 1
        p_init = random.uniform(0, upper_bound)
        q_init = random.uniform(0, upper_bound)
        self.p = nn.Parameter(torch.tensor(p_init))
        self.q = nn.Parameter(torch.tensor(q_init))

        self.diff_index = differentiable_index_batch

    def forward(self, sequence_batch):
        """Return the result of indexing each sequence at ``[p, q]``.

        Args:
            sequence_batch: Sized collection of indexable sequences.
        """
        batch_size = len(sequence_batch)

        # One [p, q] row, tiled so that every sequence gets its own copy
        pq_row = torch.stack([self.p, self.q])
        pq_matrix = pq_row.repeat(batch_size, 1)

        return self.diff_index(sequence_batch, pq_matrix)



####################################################################################################################################
####################################################################################################################################
####################################################################################################################################
# PQ LOSS FUNCTION
####################################################################################################################################
####################################################################################################################################
####################################################################################################################################

import torch
import torch.nn as nn

class PQLoss(nn.Module):
    """Squared distance of the sampled ``p`` and ``q`` values from 1.

    For each batch row ``(p_val, q_val)`` the loss is
    ``(1 - p_val)**2 + (1 - q_val)**2``; it is zero only when both values
    equal 1.
    """

    _VALID_REDUCTIONS = ('mean', 'sum', 'none')

    def __init__(self, reduction='mean'):
        """
        Args:
            reduction (str): ``'mean'`` or ``'sum'`` over the batch, or
                ``'none'`` to keep one loss value per batch row.

        Raises:
            ValueError: If ``reduction`` is not one of the three options.
        """
        super().__init__()
        # Fail early on typos instead of silently behaving like 'none'
        if reduction not in self._VALID_REDUCTIONS:
            raise ValueError(
                f"reduction must be one of {self._VALID_REDUCTIONS}, got {reduction!r}"
            )
        self.reduction = reduction

    def forward(self, pq_matrix):
        """Compute the loss for a ``(batch_size, 2)`` matrix of values.

        Raises:
            ValueError: If ``self.reduction`` was changed to an unsupported
                value after construction.
        """
        # Ensure the input is of shape (batch_size, 2)
        assert pq_matrix.shape[1] == 2, "Input pq_matrix must have shape (batch_size, 2)"

        # Split the pq_matrix into p_val and q_val
        p_val = pq_matrix[:, 0]
        q_val = pq_matrix[:, 1]

        # Calculate the loss for each item in the batch
        loss = (1 - p_val)**2 + (1 - q_val)**2

        # Apply the specified reduction (mean or sum or none)
        if self.reduction == 'mean':
            return loss.mean()  # Average over the batch
        if self.reduction == 'sum':
            return loss.sum()  # Sum over the batch
        if self.reduction == 'none':
            return loss  # Return the loss for each element in the batch
        raise ValueError(
            f"reduction must be one of {self._VALID_REDUCTIONS}, got {self.reduction!r}"
        )

In [3]:
import torch
import torch.optim as optim
from torch.utils.data import DataLoader

# Training settings
n = 100  # Length of the sequence
p = 22.6  # Center of the Gaussian for label 0
q = 76.4  # Center of the Gaussian for label 1
size = 32  # Size of the dataset
std = 20  # Standard deviation of the Gaussian
batch_size = 32  # Equal to `size`, so every epoch consists of a single batch
sequence_length = n
epochs = 100
# Use the GPU when one is available, otherwise fall back to the CPU
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Set up the dataset and DataLoader
# custom_collate_fn keeps the sequences in a plain list; only the labels
# are packed into a tensor.
dataset = GaussianSequenceDataset(n, p, q, size, std, device=device)
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True, collate_fn=custom_collate_fn)

# Initialize the model, loss function, and optimizer
model = PQEstimatorModel(sequence_length=sequence_length).to(device)
loss_fn = PQLoss(reduction='mean')
optimizer = optim.Adam(model.parameters(), lr=1)

# Remember the randomly initialized positions so they can be reported next to
# the final ones once training has finished
initial_p = model.p.item()
initial_q = model.q.item()

# Training loop
for epoch in range(epochs):
    model.train()
    running_loss = 0.0

    # NOTE: `labels` is unpacked but never used below; the loss depends only
    # on the values the model reads from the sequences.
    for sequences, labels in dataloader:
        # Forward pass
        pq_values = model(sequences)

        # Compute loss
        loss = loss_fn(pq_values)

        # Backward pass and optimization
        # (gradients are cleared right before backward, so each step uses
        # only the current batch's gradients)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Accumulate loss
        running_loss += loss.item()

    # Print average loss for the epoch
    avg_loss = running_loss / len(dataloader)
    print(f"Epoch [{epoch+1}/{epochs}], Loss: {avg_loss:.4f}")

print("Training complete.")

# Report the positions the model started from
print("Initial p value:", initial_p)
print("Initial q value:", initial_q)

# print the final p and q values
print("Final p value:", model.p.item())
print("Final q value:", model.q.item())

# print the true p and q values
print("True p value:", p)
print("True q value:", q)


Epoch [1/100], Loss: 0.8150
Epoch [2/100], Loss: 0.8025
Epoch [3/100], Loss: 0.7905
Epoch [4/100], Loss: 0.7791
Epoch [5/100], Loss: 0.7683
Epoch [6/100], Loss: 0.7583
Epoch [7/100], Loss: 0.7492
Epoch [8/100], Loss: 0.7411
Epoch [9/100], Loss: 0.7339
Epoch [10/100], Loss: 0.7277
Epoch [11/100], Loss: 0.7224
Epoch [12/100], Loss: 0.7181
Epoch [13/100], Loss: 0.7147
Epoch [14/100], Loss: 0.7122
Epoch [15/100], Loss: 0.7105
Epoch [16/100], Loss: 0.7094
Epoch [17/100], Loss: 0.7087
Epoch [18/100], Loss: 0.7087
Epoch [19/100], Loss: 0.7089
Epoch [20/100], Loss: 0.7094
Epoch [21/100], Loss: 0.7100
Epoch [22/100], Loss: 0.7107
Epoch [23/100], Loss: 0.7113
Epoch [24/100], Loss: 0.7119
Epoch [25/100], Loss: 0.7124
Epoch [26/100], Loss: 0.7126
Epoch [27/100], Loss: 0.7128
Epoch [28/100], Loss: 0.7128
Epoch [29/100], Loss: 0.7127
Epoch [30/100], Loss: 0.7125
Epoch [31/100], Loss: 0.7122
Epoch [32/100], Loss: 0.7119
Epoch [33/100], Loss: 0.7114
Epoch [34/100], Loss: 0.7109
Epoch [35/100], Loss: 0