# Bart Bussmann's Modular Addition Code
Code taken from his LessWrong post "Interpreting Modular Addition in MLPs": https://www.lesswrong.com/posts/cbDEjnRheYn38Dpc5/interpreting-modular-addition-in-mlps


In [1]:
import numpy as np
import torch
import matplotlib.pyplot as plt

# Fix the random seeds so runs are reproducible: PyTorch drives weight
# initialisation and torch.randperm, NumPy drives the dataset shuffle and
# np.random.choice.
SEED = 1337
torch.manual_seed(SEED)
np.random.seed(SEED)

In [None]:
# Check if CUDA is available, else use CPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Build the full modular-addition dataset: one row per ordered pair (a, b) with
# 0 <= a, b < p, at row index a*p + b. The input is a one-hot of a (columns
# 0..p-1) followed by a one-hot of b (columns p..2p-1); the label is a one-hot
# of (a + b) mod p.
p = 113
pair_index = np.arange(p * p)
operand_a, operand_b = np.divmod(pair_index, p)

data_set = np.zeros((p*p, 2*p))
data_set[pair_index, operand_a] = 1
data_set[pair_index, operand_b + p] = 1

labels = np.zeros((p*p, p))
labels[pair_index, (operand_a + operand_b) % p] = 1

# shuffle inputs and labels with the same permutation so rows stay paired
shuffle = np.random.permutation(p*p)
data_set, labels = data_set[shuffle], labels[shuffle]

# first 80% of the shuffled rows train, the remainder validate
train_proportion = 0.8
n_train = int(train_proportion*p*p)
train_data, val_data = data_set[:n_train], data_set[n_train:]
train_labels, val_labels = labels[:n_train], labels[n_train:]

# convert every split to a float32 tensor on the chosen device
train_data, train_labels, val_data, val_labels = (
    torch.from_numpy(split).float().to(device)
    for split in (train_data, train_labels, val_data, val_labels)
)

# define the 1-hidden layer MLP
class MLP(torch.nn.Module):
    """Bias-free one-hidden-layer ReLU MLP for modular addition.

    ``forward`` returns ``(logits, pre_relu, post_relu)``.

    The model can also fit one cosine per hidden neuron to each of its three
    weight vectors (``calculate_neuron_properties``). It can then run a forward
    pass in which chosen neurons use those idealised cosines instead of their
    learned weights (``forward_with_replaced_weights``). The fitted parameters
    live in twelve per-neuron tensors named ``<param><k>``:

    * ``u`` is the amplitude, ``w`` the frequency (cycles per index), ``s`` the
      phase (radians) and ``o`` the offset. ``cosine`` evaluates them as
      ``u*cos(2*pi*w*x + s) + o``.
    * ``k = 1`` is fit to the incoming weights from the first half of the
      input columns, ``k = 2`` to the second half, and ``k = 3`` to the
      outgoing weights to the ``output_size`` logits.

    These are plain tensor attributes, not registered buffers. They are not
    part of ``state_dict()``, and ``Module.to()`` does not move them. They are
    created on the module-level ``device`` global.
    """

    def __init__(self, input_size, hidden_size, output_size):
        super().__init__()
        self.fc1 = torch.nn.Linear(input_size, hidden_size, bias=False)
        self.fc2 = torch.nn.Linear(hidden_size, output_size, bias=False)
        self.output_size = output_size
        self.hidden_size = hidden_size
        # Per-neuron cosine-fit parameters; filled by calculate_neuron_properties().
        self.u1 = torch.zeros(hidden_size, device=device)
        self.u2 = torch.zeros(hidden_size, device=device)
        self.u3 = torch.zeros(hidden_size, device=device)
        self.w1 = torch.zeros(hidden_size, device=device)
        self.w2 = torch.zeros(hidden_size, device=device)
        self.w3 = torch.zeros(hidden_size, device=device)
        self.s1 = torch.zeros(hidden_size, device=device)
        self.s2 = torch.zeros(hidden_size, device=device)
        self.s3 = torch.zeros(hidden_size, device=device)
        self.o1 = torch.zeros(hidden_size, device=device)
        self.o2 = torch.zeros(hidden_size, device=device)
        self.o3 = torch.zeros(hidden_size, device=device)

    def forward(self, x, val=False):
        """Return ``(logits, pre_relu, post_relu)`` for the input batch ``x``.

        ``val`` is accepted for call-site compatibility but is not used.
        """
        pre_relu = self.fc1(x)
        post_relu = torch.nn.functional.relu(pre_relu)
        x = self.fc2(post_relu)
        return x, pre_relu, post_relu

    def cosine(self, x, u, w, s, o):
        """Evaluate ``u * cos(2*pi*w*x + s) + o`` elementwise (``w`` in cycles per unit of x)."""
        return u*torch.cos(x*w*2*np.pi + s) + o

    def forward_with_replaced_weights(self, x, p_l1=0, plot=True, altered_i=None):
        """Run a forward pass in which some hidden neurons use their cosine fits.

        For every neuron index in ``altered_i``, the incoming weights (both
        halves of its fc1 row) and the outgoing weights (its fc2 column) are
        overwritten with the cosines fitted by ``calculate_neuron_properties``.
        The original weights are restored before returning, even if an
        exception is raised.

        Args:
            x: input batch, shape (batch, input_size).
            p_l1: fraction of hidden neurons to alter at random. Used only when
                ``altered_i`` is None.
            plot: if True, show neuron 0's pre-ReLU activations before and after
                the replacement as a square image. This requires ``x`` to have
                (input_size // 2) ** 2 rows.
            altered_i: explicit neuron indices to alter (list or 1-D tensor).
                An empty sequence alters nothing.

        Returns:
            Logits of shape (batch, output_size), computed with the altered weights.
        """
        n_in = self.fc1.weight.shape[1] // 2  # width of each operand's one-hot block
        # Positions at which the fitted cosines are sampled; loop-invariant.
        in_positions = torch.arange(n_in, device=self.fc1.weight.device)
        out_positions = torch.arange(self.output_size, device=self.fc2.weight.device)

        # Cloning the weights before modification
        fc1_weight_original = self.fc1.weight.data.clone()
        fc2_weight_original = self.fc2.weight.data.clone()
        try:
            pre_relu = self.fc1(x)
            if plot:
                plt.matshow(pre_relu[:, 0].cpu().detach().numpy().reshape(n_in, n_in))
                plt.show()

            # Use `is None`, not `not altered_i`: truth-testing a multi-element
            # tensor raises, and a one-element tensor such as tensor([0]) would be
            # mistaken for "no indices given".
            if altered_i is None:
                num_altered_neurons = int(p_l1 * pre_relu.shape[-1])
                altered_i = torch.randperm(pre_relu.shape[-1])[:num_altered_neurons].to(self.fc1.weight.device)

            for i in altered_i:
                # Replace incoming weights of the altered neurons (one cosine per operand)
                self.fc1.weight.data[i, :n_in] = self.cosine(in_positions, self.u1[i], self.w1[i], self.s1[i], self.o1[i])
                self.fc1.weight.data[i, n_in:] = self.cosine(in_positions, self.u2[i], self.w2[i], self.s2[i], self.o2[i])
                # Replace outgoing weights of the altered neurons
                self.fc2.weight.data[:, i] = self.cosine(out_positions, self.u3[i], self.w3[i], self.s3[i], self.o3[i])

            pre_relu = self.fc1(x)  # Re-compute pre_relu with updated weights
            if plot:
                plt.matshow(pre_relu[:, 0].cpu().detach().numpy().reshape(n_in, n_in))
                plt.show()

            post_relu = torch.nn.functional.relu(pre_relu)
            output = self.fc2(post_relu)
        finally:
            # Restoring the original weights after the forward pass
            self.fc1.weight.data = fc1_weight_original
            self.fc2.weight.data = fc2_weight_original

        return output

    def calculate_neuron_properties(self):
        """Fit a cosine to each hidden neuron's weights and store the parameters.

        Three fits are made with ``find_cosine`` for each neuron ``n``:

        * fc1 row ``n`` restricted to the first half of its columns, stored in u1/w1/s1/o1;
        * fc1 row ``n`` restricted to the second half, stored in u2/w2/s2/o2;
        * fc2 column ``n``, stored in u3/w3/s3/o3.

        The results are written in place into those tensors.
        """
        n_in = self.fc1.weight.shape[1] // 2
        # Copy the weights to NumPy once, instead of three times per neuron.
        w_in = self.fc1.weight.detach().cpu().numpy()
        w_out = self.fc2.weight.detach().cpu().numpy()
        in_positions = np.arange(n_in)
        out_positions = np.arange(self.output_size)
        for neuron in range(self.hidden_size):
            w1, s1, u1, o1 = self.find_cosine(in_positions, w_in[neuron, :n_in])
            w2, s2, u2, o2 = self.find_cosine(in_positions, w_in[neuron, n_in:])
            w3, s3, u3, o3 = self.find_cosine(out_positions, w_out[:, neuron])
            self.u1[neuron] = float(u1)
            self.u2[neuron] = float(u2)
            self.u3[neuron] = float(u3)
            self.w1[neuron] = float(w1)
            self.w2[neuron] = float(w2)
            self.w3[neuron] = float(w3)
            self.s1[neuron] = float(s1)
            self.s2[neuron] = float(s2)
            self.s3[neuron] = float(s3)
            self.o1[neuron] = float(o1)
            self.o2[neuron] = float(o2)
            self.o3[neuron] = float(o3)

    def find_cosine(self, x_data, y_data):
        """Fit ``scale * cos(2*pi*freq*x + phase_shift) + offset`` to ``y_data`` via the DFT.

        The dominant positive, non-zero frequency bin is used. ``x_data`` is
        assumed to be evenly spaced and to start at 0; the phase is relative to
        x = 0.

        Returns:
            ``(freq, phase_shift, scale, offset)``. This matches the argument
            convention of ``cosine`` (``u*cos(2*pi*w*x + s) + o`` with
            ``w=freq, s=phase_shift, u=scale, o=offset``).

        Raises:
            ValueError: if fewer than 3 samples are given (no positive bin exists).
        """
        n = x_data.size
        if n < 3:
            raise ValueError("need at least 3 samples to fit a cosine")

        # Calculate DFT
        yf = np.fft.fft(y_data)
        xf = np.fft.fftfreq(n, d=(x_data[1] - x_data[0]))

        # Positive non-zero bins are 1 .. (n-1)//2. When n is even there is
        # also the real-valued Nyquist bin n//2, which is skipped. `(n+1)//2`
        # as the stop covers both parities.
        peak = 1 + np.argmax(np.abs(yf[1:(n + 1) // 2]))
        freq = np.abs(xf[peak])

        # For y = A*cos(2*pi*f*x + phi), numpy's DFT (exp(-2*pi*i*k*m/n) kernel)
        # gives Y_k = (A*n/2) * exp(i*phi). The phase is therefore +angle(Y_k).
        phase_shift = np.angle(yf[peak])

        # |Y_k| = A*n/2, so the amplitude is 2*|Y_k|/n
        scale = 2 * np.abs(yf[peak]) / n

        # Constant term
        offset = np.mean(y_data)
        return freq, phase_shift, scale, offset


# define the training loop
def train(model, train_data, train_labels, val_data, val_labels, epochs, batch_size, lr, weight_decay=0.5):
    """Train ``model`` with AdamW and cross-entropy, recording per-epoch metrics.

    Mini-batches are taken in the order the training tensors are stored; there
    is no per-epoch reshuffling. After each epoch the whole validation set is
    evaluated in a single forward pass without gradient tracking.

    Args:
        model: module whose forward returns a 3-tuple with the logits first,
            and which accepts a ``val`` keyword argument.
        train_data, val_data: input tensors, one row per example.
        train_labels, val_labels: one-hot targets; argmax over dim 1 gives the class.
        epochs: number of passes over the training data.
        batch_size: mini-batch size (the final batch may be smaller).
        lr: AdamW learning rate.
        weight_decay: AdamW decoupled weight decay (default 0.5).

    Returns:
        ``(model, train_loss_values, val_loss_values, train_acc_values, val_acc_values)``.
        Each list has one entry per epoch. The train loss is the mean of that
        epoch's per-batch losses.
    """
    optimizer = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=weight_decay)
    loss_fn = torch.nn.CrossEntropyLoss()

    train_loss_values = []
    val_loss_values = []
    train_acc_values = []
    val_acc_values = []

    # Class indices never change during training; compute them once.
    train_targets = torch.argmax(train_labels, dim=1)
    val_targets = torch.argmax(val_labels, dim=1)

    print(epochs)
    for epoch in range(epochs):
        model.train()
        running_train_loss = 0.0
        num_batches = 0
        correct_train_preds = 0
        total_train_preds = 0
        for start in range(0, len(train_data), batch_size):
            batch_targets = train_targets[start:start + batch_size]
            optimizer.zero_grad()
            output, _, _ = model(train_data[start:start + batch_size])
            loss = loss_fn(output, batch_targets)
            running_train_loss += loss.item()
            num_batches += 1
            preds = torch.argmax(output, dim=1)
            correct_train_preds += (preds == batch_targets).sum().item()
            total_train_preds += len(preds)
            loss.backward()
            optimizer.step()
        model.eval()

        # Evaluation needs no gradients; avoid building an autograd graph.
        with torch.no_grad():
            output, _, _ = model(val_data, val=True)
            val_loss = loss_fn(output, val_targets).item()
            val_preds = torch.argmax(output, dim=1)
            correct_val_preds = (val_preds == val_targets).sum().item()
        total_val_preds = len(val_preds)

        # Average over the batches actually run. The last batch may be partial,
        # so len(train_data) / batch_size would undercount them.
        avg_train_loss = running_train_loss / num_batches
        train_acc = correct_train_preds / total_train_preds
        val_acc = correct_val_preds / total_val_preds
        train_loss_values.append(avg_train_loss)
        val_loss_values.append(val_loss)
        train_acc_values.append(train_acc)
        val_acc_values.append(val_acc)

        print("Epoch: {} | Train loss: {} | Validation loss: {} | Train accuracy: {} | Validation accuracy: {}".format(epoch, avg_train_loss, val_loss, train_acc, val_acc))
    return model, train_loss_values, val_loss_values, train_acc_values, val_acc_values

    # def forward_with_replaced_neurons(self, x, p_l1=0, p_l2=0, plot=True):
    #     a = torch.argmax(x[:, :113], dim=1).unsqueeze(-1)
    #     b = torch.argmax(x[:, 113:], dim=1).unsqueeze(-1)
    #        output[:, output_neuron_i] = 0

    #         for hidden_neuron_i in range(self.hidden_size):
    #             output[:, output_neuron_i] += post_relu[:, hidden_neuron_i] * \
    #                                           self.cosine(output_neuron_i, self.u3[hidden_neuron_i], self.w3[hidden_neuron_i], self.s3[hidden_neuron_i], self.o3[hidden_neuron_i])

    #     return output
# train the model: 2*p one-hot inputs, 100 hidden ReLU units, p output logits
model = MLP(2*p, 100, p).to(device)
num_epochs = 2000
(
    model,
    train_loss_values,
    val_loss_values,
    train_acc_values,
    val_acc_values,
) = train(
    model,
    train_data,
    train_labels,
    val_data,
    val_labels,
    epochs=num_epochs,
    batch_size=128,
    lr=0.003,
)
# fit the per-neuron cosine parameters used by forward_with_replaced_weights
model.calculate_neuron_properties()

In [None]:
# Entry k of each history list is appended after epoch k+1 finishes, so label
# the x-axis with 1-based epoch numbers. This also keeps the first point
# drawable: both x-axes are log-scaled, and x = 0 has no position on a log axis.
epoch_numbers = np.arange(1, len(train_loss_values) + 1)

fig, ax = plt.subplots(2, 1, figsize=(5, 5))  # create a new figure with 2 subplots arranged vertically

# Plotting the loss values (log-log: semilogy plus the log x-axis set below)
ax[0].semilogy(epoch_numbers, train_loss_values, label='Training Loss')
ax[0].semilogy(epoch_numbers, val_loss_values, label='Validation Loss')
ax[0].set_xlabel('Epochs')
ax[0].set_ylabel('Loss')
ax[0].set_title('Training and Validation Loss')
ax[0].legend(loc='upper right')
ax[0].grid(True)
ax[0].set_xscale('log')  # Making the x-axis logarithmic

# Plotting the accuracy values
ax[1].plot(epoch_numbers, train_acc_values, label='Training Accuracy')
ax[1].plot(epoch_numbers, val_acc_values, label='Validation Accuracy')
ax[1].set_xlabel('Epochs')
ax[1].set_ylabel('Accuracy')
ax[1].set_title('Training and Validation Accuracy')
ax[1].legend(loc='lower right')
ax[1].grid(True)
ax[1].set_xscale('log')  # Making the x-axis logarithmic

plt.tight_layout()  # To ensure proper spacing between subplots
plt.show()

In [None]:
# generate the new data: every (number1, number2) pair in row-major order
# (row = number1*p + number2), one-hot encoded like the training set. Each
# neuron's activations then reshape directly into a p x p grid indexed
# [number1, number2].
rows = np.arange(p * p)
new_data = np.zeros((p*p, 2*p))
new_data[rows, rows // p] = 1      # first operand one-hot
new_data[rows, rows % p + p] = 1   # second operand one-hot

# convert to tensor and move to the compute device
new_data_tensor = torch.from_numpy(new_data).float().to(device)

# pass the data through the model and get the pre-/post-ReLU activations
model.eval()
with torch.no_grad():
    _, pre_relu_activations, post_relu_activations = model(new_data_tensor)

# Select four distinct random hidden neurons
num_hidden = post_relu_activations.shape[1]
random_neurons = np.random.choice(num_hidden, 4, replace=False)
selected_activations = post_relu_activations.cpu().numpy()[:, random_neurons]

# Use one colour scale for all four panels so the single shared colorbar is
# accurate. By default imshow normalises each panel to its own min/max.
vmin, vmax = selected_activations.min(), selected_activations.max()

fig, axes = plt.subplots(2, 2, figsize=(10, 10))

for i, neuron in enumerate(random_neurons):
    # reshape this neuron's activations into the p x p (number1, number2) grid
    neuron_activations = selected_activations[:, i].reshape(p, p)

    # plot the activations as an image
    ax = axes[i // 2, i % 2]
    im = ax.imshow(neuron_activations, cmap='hot', interpolation='nearest', vmin=vmin, vmax=vmax)
    ax.set_title(f'Neuron {neuron}')

# Add a colorbar to the figure, shared by all subplots
fig.colorbar(im, ax=axes.ravel().tolist(), orientation='horizontal')
plt.show()


In [None]:
# save modular model state dict
import os

model_dir = "models"
os.makedirs(model_dir, exist_ok=True)  # no error if the directory already exists

# state_dict() holds only registered parameters/buffers (fc1.weight and
# fc2.weight). The cosine-fit tensors u1 ... o3 are plain attributes and
# are not saved.
model_path = os.path.join(model_dir, "modular_add_model.pth")
torch.save(model.state_dict(), model_path)

In [None]:
# Inspect both weight matrices' shapes alongside the module summary
(model.fc1.weight.shape, model.fc2.weight.shape, model)

In [None]:
# plot t-SNE of both model.fc* weight matrices, one point per hidden neuron
import numpy as np
import matplotlib.pyplot as plt
from sklearn.manifold import TSNE

# Get weights as numpy arrays
fc1_weights = model.fc1.weight.data.cpu().numpy()
fc2_weights = model.fc2.weight.data.cpu().numpy()

# torch.nn.Linear stores weight as (out_features, in_features).
# fc1.weight is (hidden, 2*p), so row i is already hidden neuron i's incoming vector.
fc1_weights_reshaped = fc1_weights
# fc2.weight is (p, hidden), so neuron i's outgoing vector is column i: transpose.
# reshape(hidden, p) would keep memory order and mix different neurons' weights.
fc2_weights_reshaped = fc2_weights.T
num_neurons = fc1_weights_reshaped.shape[0]

# run t-SNE on both weights
tsne = TSNE(n_components=2, random_state=42)
fc1_weights_tsne = tsne.fit_transform(fc1_weights_reshaped)
fc2_weights_tsne = tsne.fit_transform(fc2_weights_reshaped)

# plot the t-SNE of both weights
for embedding, title in (
    (fc1_weights_tsne, 't-SNE of FC1 Weights'),
    (fc2_weights_tsne, 't-SNE of FC2 Weights'),
):
    plt.figure(figsize=(10, 10))
    plt.scatter(embedding[:, 0], embedding[:, 1], c=np.arange(num_neurons), cmap='tab10')
    plt.colorbar(label='Neuron Index')
    plt.title(title)
    plt.show()

In [None]:
# Cosine similarity between every pair of hidden neurons, computed separately
# for the fc1 vectors (plotted as "Encoder") and the fc2 vectors (plotted as
# "Decoder"). Each row of *_reshaped is one vector. Scaling rows to unit length
# turns their Gram matrix into the cosine-similarity matrix.
fc1_weights_norm = np.linalg.norm(fc1_weights_reshaped, axis=1)
fc2_weights_norm = np.linalg.norm(fc2_weights_reshaped, axis=1)
fc1_weights_normed = fc1_weights_reshaped / fc1_weights_norm[:, None]
fc2_weights_normed = fc2_weights_reshaped / fc2_weights_norm[:, None]

fc1_weights_cos_sim = fc1_weights_normed.dot(fc1_weights_normed.T)
fc2_weights_cos_sim = fc2_weights_normed.dot(fc2_weights_normed.T)

# one heat-map figure per similarity matrix
for cos_sim, title in (
    (fc1_weights_cos_sim, 'Cosine Similarity of Encoder Weights'),
    (fc2_weights_cos_sim, 'Cosine Similarity of Decoder Weights'),
):
    plt.figure(figsize=(10, 10))
    plt.imshow(cos_sim, cmap='hot', interpolation='nearest')
    plt.colorbar(label='Cosine Similarity')
    plt.title(title)
    plt.show()

In [None]:
# Do a k-means clustering of the fc1 weights: one row of fc1_weights_reshaped
# (a hidden neuron's incoming weight vector) per sample
from sklearn.cluster import KMeans

# Fit KMeans with 10 clusters
num_clusters = 10
kmeans = KMeans(n_clusters=num_clusters, random_state=42)
fc1_weights_clusters = kmeans.fit_predict(fc1_weights_reshaped)

# Plot the fc1 t-SNE embedding with cluster colors
plt.figure(figsize=(10, 10))
plt.scatter(fc1_weights_tsne[:, 0], fc1_weights_tsne[:, 1], c=fc1_weights_clusters, cmap='tab20')
plt.colorbar(label='Cluster')
plt.title('t-SNE of FC1 Weights Colored by KMeans Clusters')
plt.show()


In [None]:
from sklearn.cluster import KMeans
from sklearn.manifold import TSNE

# repeat for post_relu_activations. Its shape is (d_points, hidden), one row per
# input pair. Treat it as `hidden` vectors of size d_points, one per neuron, which
# needs a transpose. reshape(hidden, d_points) would keep memory order and mix
# different neurons' activations into each row.
d_points = post_relu_activations.shape[0]
post_relu_activations_reshaped = post_relu_activations.T.cpu().numpy()

# Data for t-SNE: one row per hidden neuron
post_relu_activations_for_tsne = post_relu_activations_reshaped

# Cluster neurons by activation pattern; the labels colour the t-SNE plot below.
post_relu_kmeans = KMeans(n_clusters=10, random_state=42)
post_relu_activations_clusters = post_relu_kmeans.fit_predict(post_relu_activations_reshaped)

# Compute t-SNE
tsne = TSNE(n_components=2, random_state=42)
post_relu_activations_tsne = tsne.fit_transform(post_relu_activations_for_tsne)

# Now plot with the t-SNE coordinates
plt.figure(figsize=(10, 10))
plt.scatter(post_relu_activations_tsne[:, 0], post_relu_activations_tsne[:, 1], c=post_relu_activations_clusters, cmap='tab20')
plt.colorbar(label='Cluster')
plt.title('t-SNE of Post-ReLU Activations Colored by KMeans Clusters')
plt.show()

# Also do cos-sim
post_relu_activations_norm = np.linalg.norm(post_relu_activations_reshaped, axis=1)
# A neuron whose ReLU output is 0 on every input has an all-zero row. Dividing
# by 1 instead of 0 gives it similarity 0 with everything, rather than NaN.
safe_norm = np.where(post_relu_activations_norm > 0, post_relu_activations_norm, 1.0)
post_relu_activations_normed = post_relu_activations_reshaped / safe_norm[:, np.newaxis]
post_relu_activations_cos_sim = np.dot(post_relu_activations_normed, post_relu_activations_normed.T)

# plot the cos-sim
plt.figure(figsize=(10, 10))
plt.imshow(post_relu_activations_cos_sim, cmap='hot', interpolation='nearest')
plt.colorbar(label='Cosine Similarity')
plt.title('Cosine Similarity of Post-ReLU Activations')
plt.show()