In [1]:
import torch
import torch.nn.functional as F

def efficient_onehot(tensor, num_classes):
    """
    Memory-efficient one-hot encoding using different methods based on tensor size.
    
    Args:
        tensor: Input tensor of shape (N,) or (N, 1) or (N, M)
        num_classes: Number of classes for one-hot encoding
    
    Returns:
        One-hot encoded tensor
    """
    # Ensure tensor is in the right shape
    if tensor.dim() > 2:
        raise ValueError("Input tensor must be 1D or 2D")
    
    # If tensor is 2D with more than 1 column, we need to handle each column
    if tensor.dim() == 2 and tensor.size(1) > 1:
        # Process each column separately to avoid memory issues
        results = []
        for i in range(tensor.size(1)):
            col = tensor[:, i].unsqueeze(1)  # Shape: [N, 1]
            col_onehot = efficient_onehot(col, num_classes)  # Recursive call
            results.append(col_onehot)
        return torch.stack(results, dim=1)  # Stack along new dimension
    
    # Method 1: Using scatter (memory efficient for smaller tensors)
    def scatter_onehot(tensor, num_classes):
        if tensor.dim() == 1:
            tensor = tensor.unsqueeze(1)
        onehot = torch.zeros(tensor.size(0), num_classes, 
                           device=tensor.device, dtype=torch.float32)
        return onehot.scatter_(1, tensor.long(), 1)
    
    # Method 2: Using F.one_hot with chunking (for larger tensors)
    def chunked_onehot(tensor, num_classes, chunk_size=1000):
        if tensor.dim() == 2:
            tensor = tensor.squeeze(1)
        
        chunks = tensor.split(chunk_size)
        result_chunks = []
        
        for chunk in chunks:
            chunk_onehot = F.one_hot(chunk.long(), num_classes).float()
            result_chunks.append(chunk_onehot)
            
        return torch.cat(result_chunks, dim=0)
    
    # Method 3: Using sparse tensors (extremely memory efficient)
    def sparse_onehot(tensor, num_classes):
        if tensor.dim() == 2:
            tensor = tensor.squeeze(1)
        
        size = (tensor.size(0), num_classes)
        indices = torch.stack([
            torch.arange(tensor.size(0), device=tensor.device),
            tensor.long()  # Ensure tensor is long type
        ])
        values = torch.ones(tensor.size(0), device=tensor.device)
        return torch.sparse_coo_tensor(indices, values, size)

    # Choose method based on tensor size and available memory
    total_elements = tensor.size(0) * num_classes
    
    if total_elements < 1e6:  # Less than 1M elements
        return scatter_onehot(tensor, num_classes)
    elif total_elements < 1e8:  # Less than 100M elements
        return chunked_onehot(tensor, num_classes)
    else:
        return sparse_onehot(tensor, num_classes)

In [6]:
X = torch.load("X.pt")

  X = torch.load("X.pt")


In [7]:
X.shape

torch.Size([7921450, 64])

In [8]:
X_OH = efficient_onehot(X,22)

In [9]:
X_OH.shape

torch.Size([7921450, 64, 22])

In [12]:
torch.save(X,'X_OH.pt')

In [2]:
Y = torch.load("Y.pt")

  Y = torch.load("Y.pt")


In [3]:
Y_OH = efficient_onehot(Y,22)

In [4]:
torch.save(Y,'Y_OH.pt')