In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import matplotlib.pyplot as plt
import imageio
import os
import numpy as np

class CustomModel(nn.Module):
    """Bias-free single-hidden-layer ReLU network with a fixed all-ones readout.

    Computes ``f(x) = sum_j relu(w_j . x)``, where ``w_j`` are the rows of the
    ``(output_dim, input_dim)`` weight matrix of ``self.fc``. The weight matrix
    is the only trainable parameter; there is no bias and no trainable
    output layer.
    """

    def __init__(self, input_dim, output_dim):
        super().__init__()
        # One hidden layer of `output_dim` ReLU units; bias disabled on purpose
        # so each unit is fully described by its weight vector.
        self.fc = nn.Linear(input_dim, output_dim, bias=False)

    def forward(self, x):
        # For a 2-D batch x of shape (batch, input_dim) this returns (batch,):
        # hidden pre-activations -> ReLU -> sum across the hidden units.
        hidden = torch.relu(self.fc(x))
        return hidden.sum(dim=1)
# Function to plot the weights
def plot_weights(orig_weights, epoch, folder='weight_plots_3d', lim=0.08):
    """Save a scatter plot of 3-D weight vectors projected onto the plane x+y+z = 1.

    Each entry of ``orig_weights`` is mapped to 2-D in-plane coordinates by
    ``project_weight`` and drawn as one point (one colour per entry).

    Args:
        orig_weights: iterable of length-3 weight vectors (in this notebook,
            the first three input columns of each hidden unit's weights).
        epoch: epoch number; used in the plot title and in the file name.
        folder: output directory, created if it does not already exist.
        lim: half-width of the square plotting window; the default reproduces
            the previously hard-coded [-0.08, 0.08] range.

    Writes ``<folder>/epoch_<epoch>.png`` and returns None.
    """
    # exist_ok avoids the check-then-create race of os.path.exists + makedirs.
    os.makedirs(folder, exist_ok=True)

    weights = np.array([project_weight(w) for w in orig_weights])

    # Explicit Figure/Axes, closed in `finally` so that a failing savefig
    # cannot leak a figure when this is called repeatedly from a training loop.
    fig, ax = plt.subplots(figsize=(8, 6))
    try:
        for i, (u, v) in enumerate(weights):
            ax.scatter(u, v, label=f'w_{i}')

        ax.set_xlim([-lim, lim])
        ax.set_ylim([-lim, lim])
        ax.axhline(0, color='black', linewidth=0.5)
        ax.axvline(0, color='black', linewidth=0.5)
        ax.set_xlabel('in-plane coordinate 1')
        ax.set_ylabel('in-plane coordinate 2')
        ax.set_title(f'Epoch {epoch} (projection onto x+y+z = 1)')
        ax.grid(True, linestyle='--')
        # File name format epoch_<N>.png is what the GIF cell parses later.
        fig.savefig(os.path.join(folder, f'epoch_{epoch}.png'))
    finally:
        plt.close(fig)

# def project_weight(point):
#     a = np.sum(point,axis=0)/3
#     p = point-a+(1/3)
#     y = np.sqrt((2*np.power(0.5*(p[0]+p[1]+1),2))+np.power(p[2],2))
#     t = (p[0]+p[1]-(2*p[2])+2)/6
#     x = np.sqrt(np.power((p[0]-t),2) + np.power((p[1]-t),2) + np.power((p[2]-1 + (2*t)),2))
#     proj = [x,y]
#     return proj

def project_weight(point):
    """Map a 3-vector to 2-D coordinates within the plane x + y + z = 1.

    The point is first moved orthogonally onto the plane, then expressed in
    the orthonormal in-plane basis u = (1, -1, 0)/sqrt(2) and v = n x u, where
    n = (1, 1, 1)/sqrt(3) is the plane's unit normal.

    Returns:
        list ``[u_coord, v_coord]`` with the two in-plane coordinates.
    """
    root3 = np.sqrt(3)
    unit_n = np.array([1, 1, 1]) / root3
    plane_offset = 1 / root3  # n . x = 1/sqrt(3)  <=>  x + y + z = 1

    # Signed distance (in units of |n|^2) from the point to the plane.
    dist = (np.dot(unit_n, point) - plane_offset) / np.linalg.norm(unit_n)**2
    foot = point - dist * unit_n

    u_axis = np.array([1, -1, 0]) / np.sqrt(2)
    v_axis = np.cross(unit_n, u_axis)
    return [np.dot(foot, u_axis), np.dot(foot, v_axis)]


In [None]:
# Reproducibility: seed torch before any random tensor is drawn so that the
# data, the initialisation and therefore every saved frame are identical
# across Restart & Run All.
SEED = 0
torch.manual_seed(SEED)

d = 300            # input dimension
m = 50             # number of hidden ReLU units in the trained model
n_samples = 50000  # training-set size (full-batch training below)

# Target: sum of ReLUs of the first three input coordinates, i.e. ReLU units
# along the unit vectors e_1, e_2, e_3. Only those three weight columns are
# recorded and plotted in the training cell.
X = torch.randn(n_samples, d)
y = torch.relu(X[:, 0]) + torch.relu(X[:, 1]) + torch.relu(X[:, 2])

model = CustomModel(d, m)

# Near-zero initialisation: weights drawn from N(0, (1e-15)^2).
# Inside no_grad the in-place init can act on the parameter directly; `.data`
# is not needed.
with torch.no_grad():
    model.fc.weight.normal_(0, 1e-15)

criterion = nn.MSELoss()
optimizer = optim.SGD(model.parameters(), lr=0.01)

# Epochs at which to log and plot: every epoch from 1 to 59, then every 40
# epochs from 80 up to (but excluding) 1600.
print_epoch = list(range(1, 60)) + list(range(80, 1600, 40))
# Sparser alternative schedule:
# print_epoch = [1,2,3,4,5,10,20,30,40,50,100,200,400,600,800]

In [None]:
# Full-batch gradient descent; at each epoch listed in `print_epoch`, log the
# loss, snapshot the first three weight columns, and save a projection plot.
epochs = 1600
weight_records = []
for epoch in range(epochs):
    outputs = model(X)
    loss = criterion(outputs, y)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    if epoch in print_epoch:
        # Note: `loss` was computed before this epoch's optimizer.step(), while
        # the weights below are read after it.
        print(f'Epoch [{epoch}/{epochs}], Loss: {loss.item()}')
        # Tensor.numpy() shares storage with the tensor, and SGD updates the
        # weight in place. Without .copy(), every entry of weight_records
        # would alias the live weight and end up equal to the final weights.
        current_weights = model.fc.weight.detach()[:, :3].cpu().numpy().copy()
        weight_records.append(current_weights)
        plot_weights(current_weights, epoch)

In [None]:
folder = 'weight_plots_3d'
def sort_epoch(filenames):
    """Return frame names like ``epoch_<N>.png`` ordered numerically by ``N``.

    Plain lexicographic sorting would put ``epoch_10.png`` before
    ``epoch_2.png``; this sorts by the integer between the first ``_`` and the
    following ``.`` instead. Names from which no integer can be parsed that
    way (e.g. ``notes.png`` or ``epoch_final.png``) are skipped instead of
    raising, so a stray file in the frame folder cannot abort GIF assembly.
    The sort is stable, so names with equal epochs keep their input order.
    """
    def epoch_of(name):
        # Returns the parsed epoch, or None when the name does not fit.
        try:
            return int(name.split('_')[1].split('.')[0])
        except (IndexError, ValueError):
            return None

    valid = [name for name in filenames if epoch_of(name) is not None]
    return sorted(valid, key=epoch_of)
# Gather every PNG frame written by plot_weights, order the frames by epoch
# and stitch them into an animated GIF.
# NOTE(review): frames left in `folder` by earlier runs (e.g. with a different
# print_epoch) are picked up as well; clear the folder first if that matters.
png_files = list(filter(lambda name: name.endswith('.png'), os.listdir(folder)))

file_names = sort_epoch(png_files)
# file_names is drawn from png_files, so every entry already ends in '.png'
# and no second filter is needed.
image_paths = list(map(lambda name: os.path.join(folder, name), file_names))
images = list(map(imageio.imread, image_paths))

gif_path = 'weight_evolution_3d.gif'
# NOTE(review): newer imageio releases may warn that `fps` is unsupported for
# GIF output and expect `duration` (ms) instead; confirm against the
# installed version.
imageio.mimsave(gif_path, images, fps=50)
