<a href="https://colab.research.google.com/github/jinoh5/disentangled/blob/main/run_multi_task.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
import scipy
from scipy import stats
import numpy as np
import matplotlib.pyplot as plt
from sklearn.decomposition import PCA
import seaborn as sns
import pandas as pd
from mpl_toolkits.mplot3d import Axes3D

# Torch
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset, Subset
import time

## Hyperparameters

In [2]:
# ---- Data geometry ----
sigma = 1        # std-dev of the random condition patterns (np.random.normal scale)
N = 20           # dimensionality of the embedding (input) space
T = 100          # trials per condition
numCond = 4      # number of conditions (create_input's 2x2 task design has exactly 4)
noise_var = 2    # passed to np.random.normal as `scale`, i.e. this is a standard deviation

# ---- Model / training ----
input_dim = N    # network input size must match the embedding dimension
output_dim = 2   # every task is a binary (0/1) classification
learning_rate = 0.001
batch_size = 200
hidden_dim = input_dim * 2
exp_num = 3      # NOTE(review): appears unused in the cells below
num_epochs = 1000
iter_num = 5     # independent repetitions of the whole train/test grid

# ---- Loss weights (see custom_loss) ----
beta1 = 0.0001   # cross-entropy
beta2 = 0        # L1 sparsity of hidden activations
beta3 = 100      # participation ratio of hidden activations
beta4 = 0        # participation ratio of input-layer weights

# ---- Reproducibility ----
# Seed every RNG the notebook draws from: numpy (patterns, noise) and torch
# (weight init, row shuffling).
SEED = 0
np.random.seed(SEED)
torch.manual_seed(SEED)

## Input setup

In [3]:
# Alternative input: "disentangled" conditions. The four 2-D points below sit on the
# corners of a square and are mapped into N dimensions by a single random Gaussian
# matrix W, so the condition layout is a linear image of the 2x2 task design.
# To use it, uncomment this cell and comment out the entangled-input cell below;
# both cells define `projection` with shape (numCond, T, N). Noise is added later
# by create_input.
# # Disentangled Points (Either comment these)
# cond = np.zeros((4,2))
# cond[0,:] = [1,1]
# cond[1,:] = [-1,1]
# cond[2,:] = [-1,-1]
# cond[3,:] = [1,-1]

# cond = cond.T

# # Project the data onto the 20-D space
# projection = np.zeros((numCond,T,N))
# W = stats.norm.rvs(loc=0, scale=sigma, size=[N,2])
# for i in range(T): # T = 100 trials
#     projection[0,i,:] = W @ cond[:,0] # V (N x 1) = W (N x 2) * I (2 x 1)
#     projection[1,i,:] = W @ cond[:,1]
#     projection[2,i,:] = W @ cond[:,2]
#     projection[3,i,:] = W @ cond[:,3]

In [4]:
# Entangled input (alternative to the commented-out disentangled cell above; run only one).
# Each condition gets an independent random N-D pattern, identical across its T trials.
subInput_non = np.random.normal(loc=0, scale=sigma, size=[numCond, N])

# (numCond, N) -> (numCond, T, N): repeat each condition's pattern along the trial axis.
# Uses numCond rather than a hard-coded 4 so the cell follows the config.
ent_input = np.repeat(subInput_non[:, np.newaxis, :], T, axis=1)

projection = ent_input

## Functions

In [5]:
def create_input(projection, noise_var, rng=None):
  '''
  Add Gaussian noise to the condition patterns and build the task labels.

  Args:
    projection: array of shape (n_conditions, n_trials, n_features). It must
      contain exactly 4 conditions, ordered to match `org_cond` below.
    noise_var: passed to the normal sampler as `scale`, so it is the noise
      standard deviation (not its variance).
    rng: optional np.random.Generator for reproducible noise. When None the
      legacy global np.random state is used (original behaviour).

  Returns:
    inputs: (n_conditions * n_trials, n_features) array, condition-major
      (all trials of condition 0, then condition 1, ...).
    classes: (n_conditions * n_trials, 3) float array of 0/1 labels whose
      columns are task 1, task 2 and their XOR.

  Raises:
    ValueError: if `projection` does not have exactly 4 conditions.
  '''
  n_cond, n_trials, n_features = projection.shape

  # 2x2 design: each condition is a pair of binary task variables.
  org_cond = np.array([[0, 0],
                       [1, 0],
                       [1, 1],
                       [0, 1]], dtype=float)
  if n_cond != org_cond.shape[0]:
    raise ValueError(f"projection must have {org_cond.shape[0]} conditions, got {n_cond}")

  # Add noise (same draws as before when rng is None).
  sampler = np.random if rng is None else rng
  noise_matrix = sampler.normal(loc=0, scale=noise_var, size=projection.shape)

  # Condition-major stacking; identical to vstack-ing projection[0], ..., projection[3].
  inputs = (projection + noise_matrix).reshape(n_cond * n_trials, n_features)

  # Original tasks and their XOR, one value per condition.
  task1 = org_cond[:, 0]
  task2 = org_cond[:, 1]
  xor = np.sum(org_cond, axis=1) % 2

  # Repeat each condition's label row for all of its trials.
  classes = np.repeat(np.stack([task1, task2, xor], axis=1), n_trials, axis=0)

  return inputs, classes

In [6]:
# Losses
def custom_loss(output, targets, sparsity, beta1, beta2, PR, beta3, PR_connect, beta4):
  '''
  Combine the classification loss with the regularisation terms.

  Args:
    output: model logits, passed to nn.CrossEntropyLoss.
    targets: integer class labels for `output`.
    sparsity: L1 penalty of the hidden activations (see L1_norm).
    beta1: weight on the cross-entropy term.
    beta2: weight on `sparsity`.
    PR: participation ratio of the hidden activations (see PR_norm).
    beta3: weight on `PR`.
    PR_connect: participation ratio of the input-layer weights.
    beta4: weight on `PR_connect`.

  Returns:
    (crossEntropy, totalLoss): the unweighted cross-entropy and
    beta1*CE + beta2*sparsity + beta3*PR + beta4*PR_connect.
  '''
  criterion = nn.CrossEntropyLoss()
  crossEntropy = criterion(output, targets)

  # Reuse the cross-entropy computed above instead of evaluating the criterion twice.
  totalLoss = beta1 * crossEntropy + beta2 * sparsity + beta3 * PR + beta4 * PR_connect
  return crossEntropy, totalLoss

def L1_norm(hidden_output):
  '''Mean absolute value over all elements of `hidden_output`, returned as a
  scalar tensor; used as an L1 sparsity penalty on hidden activations.'''
  return hidden_output.abs().mean()

def PR_norm(hidden_output):
  '''
  Participation ratio of the columns of `hidden_output`.

  Rows are treated as observations and columns as variables, hence the
  transpose before torch.cov. With C the covariance matrix and λ_i its
  eigenvalues, PR = (Σ λ_i)² / Σ λ_i². Because C is symmetric,
  Σ λ_i = trace(C) and Σ λ_i² = trace(C²) = ||C||_F². The ratio is therefore
  computed directly from C without an eigendecomposition. The result is
  mathematically identical, cheaper, and the gradient does not pass
  through torch.linalg.eigh.

  Returns:
    Scalar tensor. It is NaN when every column is constant (zero
    covariance), as the eigenvalue form would be.
  '''
  cov_matrix = torch.cov(hidden_output.T)
  numerator = torch.trace(cov_matrix) ** 2
  denominator = torch.sum(cov_matrix ** 2)
  return numerator / denominator

In [7]:
# Model (nonlinear classifier)
class Simple_Nonlin(nn.Module):
  '''
  Two-layer feed-forward classifier: Linear -> ReLU -> Linear.

  forward() returns both the post-ReLU hidden activations (used by the
  sparsity / participation-ratio penalties) and the output logits.
  '''

  def __init__(self, input_dim, hidden_dim, output_dim):
    super().__init__()
    # Layers are created in the same order as before, so random initialisation
    # and state_dict keys are unchanged. Despite its name, `hidden_layer` maps
    # the hidden activations to the output logits.
    self.input_layer = nn.Linear(input_dim, hidden_dim)
    self.hidden_layer = nn.Linear(hidden_dim, output_dim)

  def forward(self, x):
    hidden = torch.relu(self.input_layer(x))
    logits = self.hidden_layer(hidden)
    return hidden, logits

In [8]:
def calculate_centroid(points):
    '''
    Mean of each of the first three columns of `points`.

    Args:
        points: 2-D array with one row per sample (e.g. 100 x 3 PCA-projected
            points). Only columns 0, 1 and 2 are used.

    Returns:
        list of three numpy scalars [mean_x, mean_y, mean_z].
    '''
    return [np.mean(points[:, axis]) for axis in range(3)]

## Run

In [9]:
# Two independent noisy draws of the same condition patterns: one for training, one for testing.
# The second call previously overwrote `train_classes`; it now gets its own name.
train_input, train_classes = create_input(projection, noise_var)
test_input, test_classes = create_input(projection, noise_var)

In [10]:
# Inputs as float32; labels as int64 because nn.CrossEntropyLoss expects integer class indices.
trainI, trainT = (torch.tensor(train_input, dtype=torch.float32),
                  torch.tensor(train_classes, dtype=torch.long))

# create_input builds the labels deterministically from the fixed condition order,
# so the training label matrix is also the correct label matrix for the test inputs.
testI, testT = (torch.tensor(test_input, dtype=torch.float32),
                torch.tensor(train_classes, dtype=torch.long))

In [11]:
# Every (train task, test task) pair. Task indices are label columns 0, 1, 2
# (task 1, task 2, XOR), ordered train-major.
exp_list = [[train_task, test_task] for train_task in range(3) for test_task in range(3)]

In [12]:
# Per experiment: 3-component PCA projection of the hidden layer for every repetition,
# shape (iter_num, numCond * T samples, 3). Sized from the config instead of a literal 400.
pca_dict = {exp_idx: np.zeros((iter_num, numCond * T, 3)) for exp_idx in range(len(exp_list))}

In [13]:
# Sanity check of the per-experiment PCA buffer: (repetitions, samples, components).
np.shape(pca_dict[0])

(5, 400, 3)

In [None]:
# Shuffle the rows once. The same permutation is applied to inputs and labels (and to
# the test set), so rows stay aligned; the DataLoaders below do not reshuffle.
indices = torch.randperm(len(trainI))

shuffled_trainI = trainI[indices, :]
shuffled_trainT = trainT[indices, :]

shuffled_testI = testI[indices, :]
shuffled_testT = testT[indices, :]

# total_accuracies[run, train_task, test_task] = test accuracy at the last epoch that ran.
total_accuracies = np.zeros((iter_num, 3, 3))

for run_idx in range(iter_num):  # 1. repetitions of the full experiment grid
  print("iter", run_idx)
  start_time = time.time()

  for i in range(len(exp_list)):  # 2. every (train task, test task) pair
    train_task, test_task = exp_list[i]

    # Label column used for training
    exp_trainT = shuffled_trainT[:, train_task]
    print("train #", train_task)

    # Label column used for evaluation
    exp_testT = shuffled_testT[:, test_task]
    print("test #", test_task)

    train_dataset = TensorDataset(shuffled_trainI, exp_trainT)
    test_dataset = TensorDataset(shuffled_testI, exp_testT)

    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=False)
    test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

    # Fresh model and optimiser for every experiment
    model = Simple_Nonlin(input_dim, hidden_dim, output_dim)
    optimizer = optim.Adam(model.parameters(), lr=learning_rate)

    # Per-epoch histories. They are reset per experiment, so after the loop they
    # hold the last experiment only.
    trainLosses = []
    valLosses = []
    accuracies = []
    sparsities = []
    crossEntropyLosses = []
    PR_losses = []
    PR_connect_losses = []

    for epoch in range(num_epochs):  # 3. epochs
      # ---- Training ----
      model.train()

      train_loss = 0
      sparsity_loss = 0
      crossEntropy_loss = 0
      PR_loss = 0
      PR_connect_loss = 0

      # Named batch_* so the global train_input array is not overwritten.
      for batch_input, batch_target in train_loader:  # 4. training batches
        batch_input = batch_input.float()

        optimizer.zero_grad()
        x_h, train_output = model(batch_input)

        # NOTE(review): `.data` detaches the weights, so this term adds no gradient.
        # beta4 only changes the reported loss value, never the trained weights.
        connectivity_matrix = model.input_layer.weight.data
        PR_connect = PR_norm(connectivity_matrix)

        # Regularisers on the hidden layer
        sparsity = L1_norm(x_h)
        PR = PR_norm(x_h)

        CE_loss, loss = custom_loss(train_output, batch_target, sparsity, beta1, beta2, PR, beta3, PR_connect, beta4)

        loss.backward()
        optimizer.step()

        # Accumulate detached values so each batch's autograd graph can be freed.
        train_loss += loss.item()
        sparsity_loss += sparsity.detach()
        crossEntropy_loss += CE_loss.detach()
        PR_loss += PR.detach()
        PR_connect_loss += PR_connect.detach()

      # Per-epoch averages over training batches
      n_train_batches = len(train_loader)

      avg_train_loss = train_loss / n_train_batches
      trainLosses.append(avg_train_loss)

      avg_sparsity = sparsity_loss / n_train_batches
      sparsities.append(avg_sparsity)

      avg_crossEntropy_loss = crossEntropy_loss / n_train_batches
      crossEntropyLosses.append(avg_crossEntropy_loss)

      avg_PR_loss = PR_loss / n_train_batches
      PR_losses.append(avg_PR_loss)

      avg_PR_connect_loss = PR_connect_loss / n_train_batches
      PR_connect_losses.append(avg_PR_connect_loss)

      # ---- Validation: every term uses the TEST batch's activations, outputs and targets ----
      model.eval()
      val_loss = 0
      total = 0
      correct = 0

      with torch.no_grad():
        for test_batch_input, test_target in test_loader:  # 5. test batches
          test_x_h, test_output = model(test_batch_input)

          PR_connect = PR_norm(model.input_layer.weight.data)
          sparsity = L1_norm(test_x_h)
          PR = PR_norm(test_x_h)

          _, loss = custom_loss(test_output, test_target, sparsity, beta1, beta2, PR, beta3, PR_connect, beta4)
          val_loss += loss.item()

          # Predicted class = index of the largest logit
          _, predicted = torch.max(test_output, 1)
          total += test_target.size(0)
          correct += (predicted == test_target).sum().item()

      accuracy = correct / total  # over the whole test set
      accuracies.append(accuracy)

      avg_val_loss = val_loss / len(test_loader)
      valLosses.append(avg_val_loss)

      # Early stop once training cross-entropy is below 0.5, after at least 100 epochs
      if avg_crossEntropy_loss < 0.5 and epoch >= 100:
        print(f"Stopping training at epoch {epoch + 1} with avg_crossEntropy_loss {avg_crossEntropy_loss:.6f}")
        break

    # Record the test accuracy from the last epoch that ran
    total_accuracies[run_idx, train_task, test_task] = accuracy

    # Hidden activity for the UNSHUFFLED test inputs, so rows stay condition-major
    # (T consecutive rows per condition), which the centroid and plotting cells rely on.
    with torch.no_grad():
      new_x_h, _ = model(testI)

    pca = PCA(n_components=3)
    transformed_Xh = pca.fit_transform(new_x_h.numpy())
    pca_dict[i][run_idx, :, :] = transformed_Xh

  end_time = time.time()
  epoch_time = end_time - start_time  # wall time of the whole repetition, not of one epoch
  print(f"Iteration {run_idx} took {epoch_time} seconds")

iter 0
train # 0
test # 0
train # 0
test # 1
Stopping training at epoch 100 with avg_crossEntropy_loss 0.449157
train # 0
test # 2
train # 1
test # 0
Stopping training at epoch 100 with avg_crossEntropy_loss 0.430839
train # 1
test # 1
Stopping training at epoch 100 with avg_crossEntropy_loss 0.482184
train # 1
test # 2
Stopping training at epoch 100 with avg_crossEntropy_loss 0.488586
train # 2
test # 0


In [None]:
# Mean test accuracy over repetitions: rows = training task, columns = test task.
# NOTE(review): total_accuracies starts as zeros, so experiments that never finished
# (e.g. an interrupted run) pull their cell's mean down.
mean_total_acc = np.mean(total_accuracies, axis=0)
n_train, n_test = mean_total_acc.shape

fig, ax = plt.subplots()
# Accuracy lies in [0, 1]; pin the colour scale so figures from different runs are comparable.
im = ax.imshow(mean_total_acc, aspect='auto', cmap='viridis', vmin=0, vmax=1)
fig.colorbar(im, ax=ax, label='Mean Accuracy')

# Task numbers on the axes (1-based), derived from the matrix shape
ax.set_xticks(range(n_test))
ax.set_xticklabels([str(k + 1) for k in range(n_test)])
ax.set_yticks(range(n_train))
ax.set_yticklabels([str(k + 1) for k in range(n_train)])

ax.set_xlabel('Test')
ax.set_ylabel('Train')
ax.set_title('Mean Accuracy Across Iterations')

# Annotate each cell with its numeric value
for row in range(n_train):
    for col in range(n_test):
        ax.text(col, row, f'{mean_total_acc[row, col]:.2f}', ha='center', va='center', color='white', fontsize=14)

plt.show()

In [None]:
# Per-epoch average cross-entropy of the LAST experiment trained above
# (crossEntropyLosses is reset for every experiment), converted to NumPy scalars.
celoss = [epoch_ce.detach().numpy() for epoch_ce in crossEntropyLosses]

In [None]:
# Last point of the curve: average cross-entropy at the final epoch of the last experiment.
celoss[-1]

In [None]:
# Summarise the curve rather than dumping every epoch (up to num_epochs values):
# number of epochs recorded and the last five per-epoch cross-entropies.
len(celoss), np.asarray(celoss, dtype=float)[-5:]

In [None]:
# Training cross-entropy per epoch for the last experiment, on a log-scaled y-axis.
fig, ax = plt.subplots()
ax.plot(celoss)
ax.set_xlabel("Epoch")
ax.set_ylabel("Cross Entropy Loss")
ax.set_title("Cross Entropy Loss")
ax.set_yscale('log')
plt.show()

In [None]:
# Average the PCA projections over repetitions for each experiment -> (numCond * T, 3).
# Iterates over pca_dict's own keys instead of a hard-coded 9.
# NOTE(review): PCA is refitted for every repetition, so component axes (and their
# signs) need not line up across repetitions. Averaging raw projections may blur or
# cancel structure; consider aligning the runs first.
mean_pca_dict = {exp_idx: np.mean(runs, axis=0) for exp_idx, runs in pca_dict.items()}

In [None]:
# Centroid of each condition in the averaged PCA space. Rows are condition-major
# (T consecutive rows per condition), giving a (numCond, 3) array per experiment.
four_points_dict = {
    exp_idx: np.array([calculate_centroid(points[c * T:(c + 1) * T, :]) for c in range(numCond)])
    for exp_idx, points in mean_pca_dict.items()
}

In [None]:
# All test samples in the averaged hidden-layer PCA space, one panel per (train, test) experiment.
fig, axs = plt.subplots(3, 3, figsize=(10,10), subplot_kw={'projection': '3d'})

# One colour per condition, in the order of the condition-major rows
colors = ['r', 'g', 'b', 'k']

# 'Train a, Test b' titles (1-based), in the same order as exp_list / pca_dict keys
subplot_titles = [f'Train {train_task + 1}, Test {test_task + 1}' for train_task, test_task in exp_list]

for i, ax in enumerate(axs.flat):
    points = mean_pca_dict[i]
    # Each condition occupies T consecutive rows
    for cond_idx, color in enumerate(colors):
        cond_points = points[cond_idx * T:(cond_idx + 1) * T]
        ax.scatter(cond_points[:, 0], cond_points[:, 1], cond_points[:, 2], c=color, marker='o')

    # Same axis limits on every panel so they can be compared directly
    ax.set_xlabel('PC 1')
    ax.set_ylabel('PC 2')
    ax.set_zlabel('PC 3')
    ax.set_title(subplot_titles[i])
    ax.set_xlim([-15, 15])
    ax.set_ylim([-15, 15])
    ax.set_zlim([-15, 15])

plt.show()

In [None]:
# Condition centroids in the averaged PCA space, one panel per (train, test) experiment.
fig, axs = plt.subplots(3, 3, figsize=(8,8), subplot_kw={'projection': '3d'})

# One colour per condition (same mapping as the per-sample scatter plot)
colors = ['r', 'g', 'b', 'k']

# Built here so this cell does not depend on a variable from another cell
subplot_titles = [f'Train {train_task + 1}, Test {test_task + 1}' for train_task, test_task in exp_list]

for i, ax in enumerate(axs.flat):
    points = four_points_dict[i]
    for cond_idx, color in enumerate(colors):
        ax.scatter(points[cond_idx, 0], points[cond_idx, 1], points[cond_idx, 2],
                   c=color, marker='o', label=f'Condition {cond_idx}')
    ax.set_title(subplot_titles[i])
    # The axes are principal components, labelled as in the per-sample plot
    ax.set_xlabel('PC 1')
    ax.set_ylabel('PC 2')
    ax.set_zlabel('PC 3')
    ax.legend()

plt.tight_layout()
plt.show()