In [1]:
import os
# Restrict this process to GPU 0; must run before torch initialises CUDA.
os.environ.update(CUDA_VISIBLE_DEVICES="0")

In [None]:

import torch
import torch.nn as nn
import torch.nn.functional as F

from tqdm import tqdm, trange
from torch.utils.data import TensorDataset, DataLoader, random_split
import matplotlib.pyplot as plt

from sklearn.model_selection import StratifiedShuffleSplit

import numpy as np
# Seed every RNG library imported above (torch for randperm / parameter init,
# numpy in case any numpy-side sampling is added) so Restart & Run All reproduces.
SEED = 42
torch.manual_seed(SEED)
np.random.seed(SEED)

In [3]:
# Source model whose activations were dumped to data/, and whether to train
# (True) or load a saved checkpoint (False).
model_name, train = "meta-llama/Llama-2-7b-hf", True

In [None]:
# Fall back to CPU when no GPU is visible, so the notebook still runs (slowly)
# instead of failing at the first `.to(device)`.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Using device: ", device)

In [5]:
model_suffix = model_name.split("/")[-1]

# Load explicitly onto CPU. Later cells index these tensors with CPU index
# tensors and move only individual batches to `device`. This also keeps loading
# independent of the device the tensors were saved from.
y_dataset = torch.load(f"data/y_dataset_{model_suffix}_filtered.pt", map_location="cpu")
X_dataset = torch.load(f"data/X_dataset_{model_suffix}_filtered.pt", map_location="cpu")

token_frequencies = torch.load("saved_models/openwebtext_token_freq.pt", map_location="cpu")

In [None]:
# Drop every class (label) that has 50 or fewer samples.
classes, counts = torch.unique(y_dataset, return_counts=True)
large_classes = classes[counts > 50]

mask = torch.isin(y_dataset, large_classes)
X_dataset, y_dataset = X_dataset[mask], y_dataset[mask]

total_classes = torch.unique(y_dataset).numel()

len(y_dataset), total_classes

In [None]:
# Subsample about 30% of the (filtered) data, taking at most an equal share per
# class so frequent classes are capped rather than sampled proportionally.
S = int(0.3 * len(y_dataset))
unique_classes, class_counts = torch.unique(y_dataset, return_counts=True)
num_classes = len(unique_classes)
samples_per_class = max(1, S // num_classes)


def _draw_from_class(label, n_available):
    """Random subset of the indices labelled `label`, of size min(samples_per_class, n_available)."""
    members = torch.nonzero(y_dataset == label, as_tuple=True)[0]
    order = torch.randperm(n_available)
    return members[order[:min(samples_per_class, n_available)]]


selected_indices = torch.cat(
    [_draw_from_class(label, n) for label, n in zip(unique_classes, class_counts)])

# Random order. The [:S] cut only drops samples when S < num_classes, because
# samples_per_class is floored at 1.
shuffled_indices = selected_indices[torch.randperm(len(selected_indices))[:S]]

X_dataset, y_dataset = X_dataset[shuffled_indices], y_dataset[shuffled_indices]

len(y_dataset), len(torch.unique(y_dataset))

In [8]:
class LearnableProjectionClustering(nn.Module):
    """Hard-assignment clustering in a learned linear projection space.

    Inputs are projected by a learnable matrix ``W`` into ``3 * num_clusters``
    dimensions and L2-normalised. Each input is then assigned to the centroid
    with the largest dot product.

    Args:
        input_dim: feature dimension of the inputs (last axis of ``x``).
        num_clusters: number of centroids.
        reg_lambda: weight of the variance-of-row-norms penalty on ``W``.
        balance_weight: weight of the balance term (negative assignment entropy).
    """

    def __init__(self, input_dim, num_clusters, reg_lambda=1e-2, balance_weight=0.1):
        super().__init__()
        self.num_clusters = num_clusters
        self.reg_lambda = reg_lambda
        self.balance_weight = balance_weight
        self.total_clusters = num_clusters

        self.projected_dim = self.total_clusters * 3

        # Learnable projection matrix W, shape (projected_dim, input_dim)
        self.W = nn.Parameter(torch.randn(self.projected_dim, input_dim) * 0.1)

        # Cluster centroids, shape (total_clusters, projected_dim)
        self.centroids = nn.Parameter(
            torch.randn(self.total_clusters, self.projected_dim))

    def forward(self, x):
        """Return ``(z, cluster_assignments)``.

        ``z`` is the row-wise L2-normalised projection ``x @ W.T``.
        ``cluster_assignments`` holds the index of the highest-scoring centroid
        for each row.
        """
        z = F.normalize(torch.matmul(x, self.W.T), dim=1)

        # Dot product with the raw (unnormalised) centroids. This equals cosine
        # similarity only when the centroids have unit norm.
        sim_matrix = torch.matmul(z, self.centroids.T)

        cluster_assignments = torch.argmax(sim_matrix, dim=1)
        return z, cluster_assignments

    def loss(self, z, cluster_assignments):
        """Combined clustering objective for hard assignments.

        The loss is the sum of four terms:
        ``(1 - mean cos(z, assigned centroid))``, plus the summed
        ``|G - I|`` over the normalised centroid Gram matrix ``G``, plus
        ``reg_lambda * var(row norms of W)``, plus
        ``balance_weight * (-entropy of the assignment histogram)``.
        """
        assigned_centroids = self.centroids[cluster_assignments]
        cluster_loss = 1 - \
            F.cosine_similarity(z, assigned_centroids, dim=1).mean()

        # Orthogonality loss (pushes normalised centroids towards orthogonality).
        # The identity is built directly on the centroids' device rather than
        # being created on CPU and copied on every call.
        centroid_norms = F.normalize(self.centroids, dim=1)
        identity = torch.eye(self.total_clusters, device=self.centroids.device)
        ortho_loss = torch.sum(torch.abs(torch.matmul(
            centroid_norms, centroid_norms.T) - identity))

        proj_variance_loss = torch.var(torch.norm(self.W, p=2, dim=-1))

        # Cluster balancing: we want to MAXIMISE the entropy of the assignment
        # histogram, so the loss term is the negative entropy. (The previous
        # version added +entropy, which rewarded imbalance.)
        # NOTE: counts come from argmax/bincount, so this term carries no gradient.
        cluster_counts = torch.bincount(
            cluster_assignments, minlength=self.total_clusters).float()
        cluster_probs = cluster_counts / cluster_counts.sum()
        entropy = -torch.sum(cluster_probs *
                             torch.log(cluster_probs + 1e-10))  # +eps avoids log(0)
        balance_loss = -entropy

        return cluster_loss + ortho_loss + self.reg_lambda * proj_variance_loss + self.balance_weight * balance_loss

In [9]:
class LearnableProjectionClustering2(nn.Module):
    """Soft clustering in a learned linear projection space with an annealed temperature.

    Inputs are projected by a learnable ``W`` of shape
    ``(projection_dim, input_dim)`` and L2-normalised. They are then scored
    against learnable centroids of shape ``(num_clusters, projection_dim)``.
    ``softmax(scores / tau)`` gives the soft assignments.

    Args:
        input_dim: feature dimension of the inputs.
        num_clusters: number of centroids.
        norm_lambda: weight of the norm-variance term in ``loss``.
        decorr_lambda: stored only; the projection-decorrelation term is disabled.
        balance_weight: weight of the KL-to-uniform balance term.
        tau_init, tau_min, tau_decay: softmax temperature schedule (see ``anneal_tau``).
        ortho_lambda: weight of the centroid orthogonality term.
        projection_dim: dimension of the projected space.

    Note:
        ``tau`` is a plain Python float, not a registered buffer, so it is not
        part of ``state_dict()``. A model restored with ``load_state_dict``
        restarts at ``tau_init``.
    """

    def __init__(self, input_dim, num_clusters, norm_lambda=1e-1, decorr_lambda=1e-3, balance_weight=0.1, tau_init=1.0, tau_min=0.05, tau_decay=0.99, ortho_lambda=1, projection_dim=200):

        super().__init__()
        self.num_clusters = num_clusters
        self.norm_lambda = norm_lambda
        self.decorr_lambda = decorr_lambda  # currently unused (term disabled)
        self.balance_weight = balance_weight
        self.ortho_lambda = ortho_lambda

        # Projection matrix (learnable), shape (projected_dim, input_dim)
        self.projected_dim = projection_dim
        self.W = nn.Parameter(torch.randn(self.projected_dim, input_dim) * 0.1)

        # Learnable cluster centroids, shape (num_clusters, projected_dim)
        self.centroids = nn.Parameter(torch.randn(
            self.num_clusters, self.projected_dim))

        # Softmax temperature schedule
        self.tau = tau_init
        self.tau_min = tau_min
        self.tau_decay = tau_decay

    def forward(self, x):
        """Return ``(z, cluster_probs)``.

        ``z`` is the row-wise L2-normalised projection ``x @ W.T``, and
        ``cluster_probs`` is ``softmax(z @ centroids.T / tau)`` over clusters.
        The centroids are not normalised here, so the scores are cosine
        similarities only when the centroids happen to have unit norm.
        """
        z = F.normalize(torch.matmul(x, self.W.T), dim=1)
        sim_matrix = torch.matmul(z, self.centroids.T)
        cluster_probs = F.softmax(sim_matrix / self.tau, dim=1)
        return z, cluster_probs

    def loss(self, z, cluster_probs):
        """Clustering + orthogonality + norm-variance + balance loss for soft assignments."""
        # Probability-weighted centroid for each sample
        assigned_centroids = torch.matmul(cluster_probs, self.centroids)

        # Clustering loss (cosine distance to the weighted centroid)
        cluster_loss = 1 - \
            F.cosine_similarity(z, assigned_centroids, dim=1).mean()

        # Orthogonality loss: mean |G - I| over the normalised centroid Gram
        # matrix. The identity is created directly on the centroids' device.
        centroid_norms = F.normalize(self.centroids, dim=1)
        identity = torch.eye(self.num_clusters, device=self.centroids.device)
        ortho_loss = torch.mean(torch.abs(torch.matmul(
            centroid_norms, centroid_norms.T) - identity))

        # Probability-weighted mean of ||z|| per cluster, shape (num_clusters, 1),
        # then the variance of those means ACROSS clusters.
        # NOTE(review): `z` from `forward` is already unit-norm, so every
        # avg_norm is ~1 and this term is ~0 (it contributes almost no gradient).
        z_norms = torch.norm(z, p=2, dim=1, keepdim=True)
        weighted_norms = cluster_probs.T @ z_norms
        avg_norms = weighted_norms / \
            (cluster_probs.sum(dim=0, keepdim=True).T + 1e-10)
        proj_norm_variance_loss = torch.mean(torch.var(avg_norms, dim=0))

        # Balance loss: KL(uniform || mean assignment distribution).
        # NOTE: the input is 1-D, so reduction="batchmean" divides the summed KL
        # by input.size(0) == num_clusters. The tuned balance_weight assumes this scaling.
        cluster_probs_mean = cluster_probs.mean(dim=0)
        uniform_target = torch.ones_like(
            cluster_probs_mean) / self.num_clusters
        balance_loss = F.kl_div((cluster_probs_mean + 1e-10).log(),
                                uniform_target, reduction="batchmean")

        total_loss = (cluster_loss + self.ortho_lambda * ortho_loss +
                      self.norm_lambda * proj_norm_variance_loss +
                      self.balance_weight * balance_loss)

        return total_loss

    def anneal_tau(self):
        """Multiply ``tau`` by ``tau_decay``, clamped below at ``tau_min`` (sharper assignments over time)."""
        self.tau = max(self.tau * self.tau_decay, self.tau_min)

In [10]:
def stratified_split_optimized(X, y, val_split=0.2, random_seed=42):
    """Split ``(X, y)`` into train/validation sets, stratified by label.

    For every class, ``int(count * val_split)`` randomly chosen samples go to
    validation and the rest go to training. Training indices are shuffled;
    validation indices stay grouped by class. Every sample lands in exactly
    one of the two sets.

    Note: this reseeds the *global* torch RNG with ``random_seed``.

    Args:
        X: samples, indexed along the first axis.
        y: 1-D label tensor aligned with ``X``.
        val_split: fraction of each class assigned to validation.
        random_seed: seed passed to ``torch.manual_seed``.

    Returns:
        ``(X_train, X_val, y_train, y_val)``.
    """
    torch.manual_seed(random_seed)

    if len(y) == 0:
        empty = torch.zeros(0, dtype=torch.long)
        return X[empty], X[empty], y[empty], y[empty]

    unique_classes, class_counts = torch.unique(y, return_counts=True)
    indices = torch.arange(len(y))

    # Collect per-class index chunks and concatenate once. The previous version
    # preallocated int(len(y) * val_split) validation slots. Per-class flooring
    # can fill fewer slots than that, and the unfilled trailing zeros pulled
    # sample 0 into validation (duplicated and also present in training).
    train_parts, val_parts = [], []
    for cls, count in zip(unique_classes, class_counts):
        cls_indices = indices[y == cls]
        cls_indices = cls_indices[torch.randperm(count)]

        val_size = int(count * val_split)
        val_parts.append(cls_indices[:val_size])
        train_parts.append(cls_indices[val_size:])

    val_indices = torch.cat(val_parts)
    train_indices = torch.cat(train_parts)

    # Shuffle train indices so batches mix classes
    train_indices = train_indices[torch.randperm(train_indices.shape[0])]

    X_train, X_val = X[train_indices], X[val_indices]
    y_train, y_val = y[train_indices], y[val_indices]

    return X_train, X_val, y_train, y_val


In [11]:
def compute_average_cosine_similarity(model, test_dataloader):
    """Print and return the mean cosine similarity between samples and their assigned centroids.

    ``model(x)`` must return ``(z, cluster_probs)``. Each sample is assigned to
    ``argmax(cluster_probs)`` and compared with ``model.centroids[assignment]``.
    Batches are moved to the device of ``model.centroids``.

    The mean is taken over all samples, not over per-batch means, so a smaller
    final batch is not over-weighted. Runs under ``torch.no_grad()``.

    Returns:
        The average as a float, or ``nan`` if the loader yields no samples.
    """
    centroid_device = model.centroids.device
    total_similarity = 0.0
    total_samples = 0

    with torch.no_grad():
        for batch in test_dataloader:
            z, cluster_probs = model(batch[0].to(centroid_device))
            assigned_centroids = model.centroids[torch.argmax(cluster_probs, dim=1)]
            similarity = F.cosine_similarity(z, assigned_centroids)
            total_similarity += similarity.sum().item()
            total_samples += similarity.numel()

    average = total_similarity / total_samples if total_samples else float("nan")
    print("Average cosine similarity:", average)
    return average

In [12]:
def compute_cluster_count_variance(model, test_dataloader):
    """Print and return the cluster-balance KL term over ``test_dataloader``.

    Despite its name, this does not compute a variance. It averages the soft
    assignments (second output of ``model(x)``) over all samples. It then reports
    ``F.kl_div(log(mean + 1e-10), uniform, reduction="batchmean")``, the same
    formula as the balance term in ``LearnableProjectionClustering2.loss``.

    Puts the model in eval mode and runs under ``torch.no_grad()``. Batches are
    moved to the device of the model's parameters.

    Raises:
        ValueError: if the loader yields no batches.
    """
    model.eval()
    model_device = next(model.parameters()).device

    with torch.no_grad():
        assignments = [model(batch[0].to(model_device))[1] for batch in test_dataloader]
    if not assignments:
        raise ValueError("test_dataloader yielded no batches")

    mean_probs = torch.cat(assignments, dim=0).mean(dim=0)
    uniform_target = torch.ones_like(mean_probs) / model.num_clusters
    balance_loss = F.kl_div((mean_probs + 1e-10).log(),
                            uniform_target, reduction="batchmean").item()

    print("Balance loss:", balance_loss)
    return balance_loss

In [13]:
def compute_ortho_score(model):
    """Print and return how far the centroids are from mutual orthogonality.

    The score is the mean absolute deviation of the normalised centroid Gram
    matrix from the identity (0 means the centroids are exactly orthogonal).

    Args:
        model: object exposing ``centroids`` (one centroid per row) and ``num_clusters``.

    Returns:
        The score as a float.
    """
    with torch.no_grad():
        unit_centroids = F.normalize(model.centroids, dim=1)
        gram = unit_centroids @ unit_centroids.T
        identity = torch.eye(model.num_clusters, device=gram.device)
        ortho_loss = torch.mean(torch.abs(gram - identity)).item()
    print("Orthogonality loss:", ortho_loss)
    return ortho_loss


In [14]:
def train_model(model, train_dataloader, test_dataloader, epochs=100, lr=0.001, patience=5, topk=1000):
    """Train a projection-clustering model with Adam and patience-based early stopping.

    Each epoch does three things:
      * runs one optimisation pass over ``train_dataloader``;
      * evaluates ``model.loss`` on ``test_dataloader`` under ``torch.no_grad()``;
      * calls ``model.anneal_tau()``.
    Training stops once the mean validation loss has not improved for
    ``patience`` consecutive epochs. Every 20 epochs it prints the losses plus
    the cosine-similarity, balance and orthogonality diagnostics.

    The weights left in ``model`` are those of the last epoch run, not those
    of the best validation epoch.

    Args:
        model: module whose ``forward`` returns ``(z, cluster_probs)`` and that
            provides ``loss``, ``anneal_tau``, ``tau``, ``centroids`` and
            ``num_clusters`` (e.g. ``LearnableProjectionClustering2``).
        train_dataloader, test_dataloader: loaders with inputs at ``batch[0]``.
        epochs: maximum number of epochs.
        lr: Adam learning rate.
        patience: epochs without validation improvement before stopping.
        topk: unused; kept so existing call sites keep working. The former
            top-k token-frequency lookup was dead code that also required the
            global ``token_frequencies``.

    Returns:
        The best (lowest) mean validation loss observed.
    """
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    # Move batches to wherever the model lives instead of relying on a global.
    model_device = next(model.parameters()).device

    best_val_loss = float('inf')
    patience_counter = 0

    for epoch in range(epochs):
        epoch_loss = 0.0
        model.train()
        for batch in train_dataloader:
            batch_data = batch[0].to(model_device)
            optimizer.zero_grad()
            z, cluster_assignments = model(batch_data)
            loss = model.loss(z, cluster_assignments)
            loss.backward()
            optimizer.step()
            epoch_loss += loss.item()

        # Validation phase
        model.eval()
        val_loss = 0.0
        with torch.no_grad():
            for batch in test_dataloader:
                z, cluster_assignments = model(batch[0].to(model_device))
                val_loss += model.loss(z, cluster_assignments).item()

        model.anneal_tau()

        avg_val_loss = val_loss / len(test_dataloader)

        if avg_val_loss < best_val_loss:
            best_val_loss = avg_val_loss
            patience_counter = 0
        else:
            patience_counter += 1

        if patience_counter >= patience:
            print(
                f"Early stopping at epoch {epoch} with validation loss {best_val_loss:.4f}")
            break

        if epoch % 20 == 0:
            print(
                f"Epoch {epoch}: Train Loss = {epoch_loss / len(train_dataloader):.4f}, Val Loss = {avg_val_loss:.4f}")
            compute_average_cosine_similarity(model, test_dataloader)
            compute_cluster_count_variance(model, test_dataloader)
            compute_ortho_score(model)
            print("Tau:", model.tau)

    return best_val_loss

In [15]:
# k = 4

# # subsample from x_dataset and y_dataset
# X_dataset = X_dataset[:400000]
# y_dataset = y_dataset[:400000]

# model = COPKMeans(n_clusters=k, y=y_dataset, device="cuda", balance_factor=2)
# model.fit(X_dataset)
# labels = model.predict(X_dataset)

# print("Davies Bouldin Index:", davies_bouldin_index(X_dataset, labels))

# cluster_counts = torch.bincount(labels, minlength=k)
# cluster_variance = torch.var(cluster_counts.float()).item()

# print("Cluster variance:", cluster_variance)

In [None]:
# 70/30 stratified train/validation split, wrapped in batched loaders.
X_train, X_val, y_train, y_val = stratified_split_optimized(X_dataset, y_dataset, val_split=0.3)

batch_size = 8192
train_dataset, test_dataset = TensorDataset(X_train, y_train), TensorDataset(X_val, y_val)

train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
test_dataloader = DataLoader(test_dataset, batch_size=batch_size)

print("Train dataset size:", len(train_dataset))
print("Test dataset size:", len(test_dataset))


In [None]:
k = 4
projection_dim = k + 8

# One checkpoint path shared by the save and load branches. Previously the
# model was saved to "projection_clustering_k3.pth" (although k = 4) but loaded
# from a different file, "projection_clustering.pth".
checkpoint_path = f"saved_models/projection_clustering_k{k}.pth"

model = LearnableProjectionClustering2(
    X_dataset.shape[1], k, balance_weight=0.3, ortho_lambda=5e-1, tau_init=1.8, tau_decay=0.995,
    norm_lambda=1e-1, decorr_lambda=0, projection_dim=projection_dim).to(device)

if train:
    train_model(model, train_dataloader, test_dataloader,
                epochs=350, patience=350, lr=0.00005)
    torch.save(model.state_dict(), checkpoint_path)
else:
    # map_location lets a checkpoint saved on another device load here.
    # NOTE: `tau` is not in the state dict, so a loaded model keeps tau_init.
    model.load_state_dict(torch.load(checkpoint_path, map_location=device))

In [None]:
# Check whether the centroids are orthogonal. NOTE: this prints the SUM of
# |G - I| over the normalised Gram matrix; compute_ortho_score prints the mean.
with torch.no_grad():
    centroids = model.centroids
    centroid_norms = F.normalize(centroids, dim=1)
    ortho_loss = torch.sum(torch.abs(torch.matmul(
        centroid_norms, centroid_norms.T) - torch.eye(model.num_clusters, device=centroids.device)))
print("Orthogonality loss:", ortho_loss.item())

# Cosine similarity of validation points to their assigned centroids. Reuses
# the helper instead of a copy-pasted loop.
compute_average_cosine_similarity(model, test_dataloader)

In [None]:

# Count how many training samples are hard-assigned (argmax) to each centroid.
num_centroids = len(model.centroids)
total_sample_counts = torch.zeros(num_centroids, dtype=torch.long)

# no_grad: this is pure inference, so don't build autograd graphs per batch.
with torch.no_grad():
    for batch in train_dataloader:
        _, cluster_probs = model(batch[0].to(device))
        cluster_assignments = torch.argmax(cluster_probs, dim=1)
        total_sample_counts += torch.bincount(cluster_assignments, minlength=num_centroids).cpu()

total = total_sample_counts.sum()
for i in range(num_centroids):
    print(
        f"Cluster {i} - total sample count: {total_sample_counts[i]}, proportion: {total_sample_counts[i]/total}")


In [27]:
# Fold centroids and projection into one (num_clusters, input_dim) matrix, so
# one_hots[n, c] = centroid_c . (W x_n). This equals forward()'s score times the
# positive per-sample factor ||W x_n||, so the row-wise argmax is unchanged.
N_PROBE = 200000  # number of samples used for the analysis below

with torch.no_grad():  # analysis only; avoid keeping a graph over N_PROBE rows
    X_probe = X_dataset[:N_PROBE].to(device)
    final_weight = torch.matmul(model.centroids, model.W)
    one_hots = torch.matmul(final_weight, X_probe.T).T

    # Rescale so the average per-sample maximum score is 1. Dividing the scores
    # is equivalent to recomputing them with the rescaled weight.
    max_mean = torch.max(one_hots, dim=1).values.mean()
    final_weight = final_weight / max_mean
    one_hots = one_hots / max_mean

In [None]:
one_hots[0:10]  # first ten rows of rescaled cluster scores

In [None]:

# Per-sample k largest scores, sorted descending (rank 1 = highest-scoring cluster).
topk = torch.topk(one_hots, dim=1, k=k).values

# One histogram per rank. Loops over k instead of hard-coding four columns,
# which broke for k < 4 and silently ignored ranks beyond 4.
fig, ax = plt.subplots(figsize=(12, 6))
for rank in range(topk.shape[1]):
    ax.hist(topk[:, rank].detach().cpu().numpy(), bins=100, alpha=0.5, label=str(rank + 1))
ax.set(title="Top-k values of one-hot encodings (by rank)", xlabel="score", ylabel="count")
ax.legend()
plt.show()

In [None]:
# Mean and standard deviation of each rank's score across samples.
topk_mean, topk_std = topk.mean(dim=0), topk.std(dim=0)
print(topk_mean)
print(topk_std)