In [None]:
from transformers import AutoTokenizer, DistilBertModel, DistilBertConfig
import torch
import torch.nn as nn
from datasets import load_dataset
import matplotlib.pyplot as plt
import numpy as np
import networkx as nx

In [None]:
# Data source: the MRPC configuration of the GLUE benchmark, tokenized with the
# uncased DistilBERT tokenizer.
dataset = load_dataset("glue", "mrpc")
tokenizer = AutoTokenizer.from_pretrained("distilbert-base-uncased")

# Probe set: the first 15 `sentence1` strings of the training split.
inputs = dataset['train']['sentence1'][:15]

# Truncate/pad every sentence to the tokenizer's maximum length and return
# PyTorch tensors so the batch can be fed to the model directly.
token_arrays = tokenizer(
    inputs,
    padding="max_length",
    truncation=True,
    return_tensors="pt",
)

In [None]:
class SA_LayerNorm(nn.Module):
    """Thin wrapper around a single ``nn.LayerNorm``.

    Args:
        hidden_size: size of the trailing feature dimension to normalize.
    """

    def __init__(self, hidden_size):
        super().__init__()
        self.layer_norm = nn.LayerNorm(hidden_size)

    def forward(self, x):
        """Return ``x`` after applying the wrapped layer normalization."""
        normalized = self.layer_norm(x)
        return normalized

class DistilBERT(nn.Module):
    """Wrap a pretrained DistilBERT and collect per-layer probe activations.

    ``forward`` returns a single tensor in which the backbone output and seven
    probe tensors per transformer layer are stacked along dim 0.

    Args:
        model_name_or_path: checkpoint name or path handed to
            ``DistilBertConfig.from_pretrained`` and
            ``DistilBertModel.from_pretrained``.
    """

    def __init__(self, model_name_or_path="distilbert-base-uncased"):
        super().__init__()
        self.config = DistilBertConfig.from_pretrained(model_name_or_path)
        self.distilbert = DistilBertModel.from_pretrained(model_name_or_path, config=self.config)

        # One extra SA_LayerNorm head per transformer layer.
        # NOTE(review): these modules are created here with default
        # initialization; nothing in this class loads checkpoint weights into
        # them -- confirm that is intended.
        self.sa_layer_norms = nn.ModuleList(
            [SA_LayerNorm(self.config.dim) for _ in range(self.config.num_hidden_layers)]
        )

    def forward(self, input_ids, attention_mask=None):
        """Run the backbone and stack all probe activations along dim 0.

        Args:
            input_ids: token ids forwarded to the wrapped model.
            attention_mask: optional mask forwarded to the wrapped model.

        Returns:
            ``torch.cat(..., dim=0)`` of ``1 + 7 * n_layers`` tensors. Block 0
            is the backbone's ``last_hidden_state``; then, for every layer in
            order: ``k_lin``, ``v_lin``, ``q_lin``, ``out_lin``,
            ``output_layer_norm``, the extra ``sa_layer_norms`` head and
            ``ffn``. All blocks must share their trailing dimensions, so the
            result has ``(1 + 7 * n_layers) * hidden_states.shape[0]`` rows
            along dim 0.
        """
        outputs = self.distilbert(input_ids=input_ids, attention_mask=attention_mask)

        # NOTE(review): the per-layer probes below start from the output of
        # the *full* backbone pass (last_hidden_state), not from the embedding
        # layer, and that tensor is also what is emitted as block 0 -- confirm
        # this is the intended starting point.
        hidden_states = outputs.last_hidden_state

        blocks = [hidden_states]
        for i, layer in enumerate(self.distilbert.transformer.layer):
            attention = layer.attention

            # NOTE(review): the projections are applied directly to the running
            # hidden state; no attention scores are computed and out_lin does
            # not receive an attention context -- confirm.
            k_out = attention.k_lin(hidden_states)
            v_out = attention.v_lin(hidden_states)
            q_out = attention.q_lin(hidden_states)
            out_lin_out = attention.out_lin(hidden_states)

            ffn_out = layer.ffn(hidden_states)

            # Residual connection + layer norm; this becomes the input of the
            # next layer's probes.
            hidden_states = layer.output_layer_norm(ffn_out + hidden_states)

            sa_out = self.sa_layer_norms[i](hidden_states)

            # Keep the per-layer emission order stable: downstream code relies
            # on positions along dim 0.
            blocks.extend([k_out, v_out, q_out, out_lin_out, hidden_states, sa_out, ffn_out])

        # A single concatenation avoids re-copying the growing result once per
        # layer.
        return torch.cat(blocks, dim=0)

In [None]:
with torch.no_grad():
    torch.cuda.empty_cache()
    model = DistilBERT()
    model.eval()  # make inference mode explicit for the wrapper and its children
    output = model(**token_arrays)

    # DistilBERT.forward concatenates equally sized (batch, seq, hidden)
    # activation blocks along dim 0, so dim 0 of `output` is block-major:
    # (n_blocks * batch, seq, hidden).
    batch_size, seq_len = token_arrays["input_ids"].shape
    hidden_size = output.shape[-1]
    assert output.shape[0] % batch_size == 0, "unexpected leading dimension in model output"
    n_blocks = output.shape[0] // batch_size

    # Recover the four logical axes. This view is exact because it only splits
    # dim 0 into (block, batch) without moving any data.
    output = output.view(n_blocks, batch_size, seq_len, hidden_size)

    # One row per (block, feature) and one column per (sentence, token).
    # The feature axis has to be *permuted* in front of batch/seq before
    # flattening; a bare view/reshape would only reinterpret memory and put
    # several whole token vectors into each row instead of one feature.
    tensor = output.permute(0, 3, 1, 2).reshape(n_blocks * hidden_size, batch_size * seq_len)

In [None]:
# Pearson correlation coefficients between the rows of `tensor`
# (each row is treated as one variable, each column as one observation).
corr = tensor.corrcoef()
corr.shape

In [None]:
# Drop the references to the data-preparation objects so their memory can be
# reclaimed before the correlation matrix is post-processed.
del dataset, inputs, tensor, tokenizer, token_arrays

In [None]:
from tqdm import tqdm  # no longer needed here, but a later cell iterates with tqdm

# Convert only when `corr` is still a tensor so that re-running this cell does
# not fail on an ndarray (which has no .numpy()).
if isinstance(corr, torch.Tensor):
    corr = corr.numpy()

downsampled_side = 5504
sub_mx_side = corr.shape[0] // downsampled_side
covered = sub_mx_side * downsampled_side  # rows/cols beyond this are ignored, as before

# Block-mean downsampling: split both axes into (block index, offset inside
# block) and average over the two offset axes. This yields the same
# (5504, 5504) float64 matrix as averaging every sub-matrix in a Python double
# loop, without ~30 million individual np.mean calls. The mean is accumulated
# in float64, so values may differ from the loop version in the last digits.
downsampled_mx = (
    corr[:covered, :covered]
    .reshape(downsampled_side, sub_mx_side, downsampled_side, sub_mx_side)
    .mean(axis=(1, 3), dtype=np.float64)
)

In [None]:
import numpy as np
import seaborn as sns
import matplotlib.pyplot as plt

# One tick every 128 rows of the downsampled matrix, labelled with its row index.
y_indices1 = list(range(0, downsampled_mx.shape[0], 128))
y_labels1 = list(y_indices1)

plt.figure(figsize=(12, 10))
heatmap = sns.heatmap(downsampled_mx, cmap='viridis')

plt.title('Downsampled Heatmap with Cluster Labels')
plt.xlabel('Column Index')
plt.ylabel('Y Values')
# Apply the ticks computed above; calling plt.yticks() without arguments only
# returns the current ticks and changes nothing.
plt.yticks(y_indices1, y_labels1)
plt.show()

In [None]:
import fastcluster
import matplotlib.pyplot as plt
import numpy as np
from scipy.cluster.hierarchy import fcluster

# Complete-linkage hierarchical clustering via fastcluster. The 2-D
# downsampled matrix is passed as-is, i.e. as one observation vector per row.
linkage_matrix = fastcluster.linkage(downsampled_mx, method="complete")

In [None]:
from sklearn.metrics import silhouette_score

# Distance thresholds to evaluate: 2.5, 3, 3.5, ..., 22.5, 23.
thresholds = []
for i in range(3, 24):
    thresholds.extend((i - 0.5, i))

n_samples = downsampled_mx.shape[0]

x1 = []  # thresholds for which a silhouette score could be computed
y1 = []  # matching silhouette scores
for threshold in tqdm(thresholds):
    cluster_labels = fcluster(linkage_matrix, threshold, criterion='distance')

    # silhouette_score raises unless there are between 2 and n_samples - 1
    # distinct labels; skip such thresholds instead of aborting the sweep.
    n_clusters = np.unique(cluster_labels).size
    if not 2 <= n_clusters <= n_samples - 1:
        continue

    x1.append(threshold)
    y1.append(silhouette_score(downsampled_mx, cluster_labels))
  

In [None]:
# Silhouette score as a function of the distance threshold (x1 vs. y1).
plt.plot(x1, y1)
plt.title('silhouette score vs Threshold Value')
plt.xlabel('Threshold Value')
plt.ylabel('silhouette_score - axis')
plt.show()

In [None]:
from kneed import KneeLocator

# Locate the knee of the silhouette curve and use it as the cut threshold.
kneedle = KneeLocator(x1, y1, S=1000.0, curve='concave', direction='increasing')

# KneeLocator reports `knee` as None when it finds no knee; fail with an
# explicit message instead of the opaque TypeError round(None, 10) would raise.
if kneedle.knee is None:
    raise ValueError(
        "KneeLocator found no knee in the silhouette curve; "
        "adjust S or the threshold sweep before clustering."
    )

threshold = round(kneedle.knee, 10)
cluster_labels = fcluster(linkage_matrix, threshold, criterion='distance')

In [None]:
# Human-readable name for every node index.
# NOTE(review): the dictionary is sized from `corr`, while the indices looked
# up below come from the downsampled matrix -- confirm this is intended.
node_labels = {idx: f"Node_{idx}" for idx in range(corr.shape[0])}

# Reorder rows and columns so that members of the same cluster are adjacent.
sorted_indices = np.argsort(cluster_labels)
corr_sorted = downsampled_mx[sorted_indices][:, sorted_indices]
sorted_labels = [node_labels[idx] for idx in sorted_indices]
cluster_labels_sorted = cluster_labels[sorted_indices]

# A boundary lies half a cell after every position where the sorted cluster
# label changes.
cluster_boundaries = np.flatnonzero(np.diff(cluster_labels_sorted) != 0) + 0.5

# Heatmap of the reordered matrix.
fig, ax = plt.subplots(figsize=(10, 10))
im = ax.imshow(corr_sorted, cmap='coolwarm', aspect='auto')

# Thin black grid lines (vertical and horizontal) separating the clusters.
for boundary in cluster_boundaries:
    ax.axvline(boundary, color='black', linewidth=0.5)
    ax.axhline(boundary, color='black', linewidth=0.5)

# One tick per cluster, placed just after the cluster's last member and
# labelled with the cluster's running index.
unique_clusters = np.unique(cluster_labels)
cluster_indices = np.arange(len(unique_clusters))
cluster_tick_positions = [
    np.flatnonzero(cluster_labels_sorted == cluster)[-1] + 0.5
    for cluster in unique_clusters
]

ax.set_xticks(cluster_tick_positions)
ax.set_xticklabels(cluster_indices, rotation=90)
ax.set_yticks(cluster_tick_positions)
ax.set_yticklabels(cluster_indices, rotation=90)
ax.tick_params(axis='both', which='both', length=0)

fig.colorbar(im, ax=ax)
plt.tight_layout()
plt.show()