In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import matplotlib.pyplot as plt
import imageio
import os
import numpy as np

class CustomModel(nn.Module):
    """Bias-free linear layer, ReLU, then a sum over the hidden units.

    For a 2-D input of shape (batch, input_dim) the result has shape (batch,):
    each entry is ``sum_j relu(w_j . x)``, where ``w_j`` are the rows of
    ``self.fc.weight`` (shape (output_dim, input_dim)).

    Args:
        input_dim: size of the last axis of the input.
        output_dim: number of hidden ReLU units.
    """

    def __init__(self, input_dim, output_dim):
        super().__init__()
        # No bias term: every hidden unit computes relu(w_j . x) only.
        self.fc = nn.Linear(input_dim, output_dim, bias=False)

    def forward(self, x):
        pre_activation = self.fc(x)
        hidden = torch.relu(pre_activation)
        # Collapse the hidden-unit axis (dim=1) into one value per sample.
        return hidden.sum(dim=1)

def plot_weights(weights, epoch, folder='weight_plots_2d',
                 xlim=(-0.02, 0.06), ylim=(-0.02, 0.06)):
    """Save a scatter plot of 2-D weight vectors to ``<folder>/epoch_<epoch>.png``.

    Args:
        weights: array indexed as ``weights[i, 0]`` / ``weights[i, 1]``; each row
            ``i`` is drawn as one point, and only the first two columns are used.
        epoch: used in the plot title and in the output file name.
        folder: output directory; created (including parents) if it is missing.
        xlim: fixed x-axis limits. Keeping them constant across epochs makes the
            saved frames directly comparable.
        ylim: fixed y-axis limits (same reasoning as ``xlim``).
    """
    # exist_ok avoids the race between an existence check and creation.
    os.makedirs(folder, exist_ok=True)

    fig, ax = plt.subplots(figsize=(8, 6))
    try:
        for i in range(weights.shape[0]):
            # One scatter call per row so each point gets its own colour from the
            # colour cycle. Labels are set, but no legend is drawn.
            ax.scatter(weights[i, 0], weights[i, 1], label=f'w_{i}')
        ax.set_xlim(xlim)
        ax.set_ylim(ylim)
        ax.axhline(0, color='black', linewidth=0.5)
        ax.axvline(0, color='black', linewidth=0.5)
        ax.set_xlabel('x')
        ax.set_ylabel('y')
        ax.set_title(f'Epoch {epoch}')
        ax.grid(True, linestyle='--')
        fig.savefig(os.path.join(folder, f'epoch_{epoch}.png'))
    finally:
        # Always release the figure, even if plotting or saving fails. Otherwise
        # repeated calls from the training loop accumulate open figures.
        plt.close(fig)

In [None]:
# Problem size: input dimension d, number of hidden units m, number of samples.
d = 300
m = 50
n_samples = 50000
# Seed the global torch RNG so the data draw and the weight init are reproducible
# under Restart & Run All.
SEED = 0
torch.manual_seed(SEED)

# Gaussian inputs; the target depends only on the first two coordinates:
# y = relu(x_0) + relu(x_1).
X = torch.randn(n_samples, d)
y = torch.relu(X[:, 0]) + torch.relu(X[:, 1])

model = CustomModel(d, m)

# Overwrite the default init with a tiny std (1e-15), i.e. a near-zero
# initialisation (presumably to study the small-init training regime).
with torch.no_grad():
    model.fc.weight.normal_(0, 1e-15)

criterion = nn.MSELoss()
optimizer = optim.SGD(model.parameters(), lr=0.01)

# Epochs at which the training loop logs the loss and saves a weight plot.
print_epoch = [0, 1, 2, 3, 4, 5, 10, 15, 20, 25, 30, 35, 40, 45, 50, 100, 150, 200, 400, 600, 800]

In [None]:
epochs = 1000
# Snapshots of the first two input-coordinate columns of the hidden weights
# (shape (m, 2)), one per logged epoch.
weight_records = []
for epoch in range(epochs):
    # Full-batch gradient descent: each "epoch" is one step on all of X.
    outputs = model(X)
    loss = criterion(outputs, y)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    if epoch in print_epoch:
        # `loss` was computed before this epoch's update; the weights below are
        # read after `optimizer.step()`.
        print(f'Epoch [{epoch}/{epochs}], Loss: {loss.item()}')
        # `.numpy()` shares memory with the tensor, and SGD updates the parameter
        # in place. Without `.clone()`, every stored record would alias the live
        # weights and end up showing the final values.
        current_weights = model.fc.weight.detach()[:, :2].clone().numpy()
        weight_records.append(current_weights)
        plot_weights(current_weights, epoch)

In [None]:
folder = 'weight_plots_2d'


def sort_epoch(filenames):
    """Return ``filenames`` sorted by the integer epoch embedded in each name.

    Expects names like ``epoch_<N>.png``. The number is taken from the text
    after the first underscore and before the first dot that follows it.
    """
    def _epoch_number(name):
        after_underscore = name.split('_')[1]
        return int(after_underscore.split('.')[0])

    return sorted(filenames, key=_epoch_number)
# Collect only the frames written by `plot_weights` (named ``epoch_<N>.png``).
# Any other PNG in the folder would make `sort_epoch` raise, or end up in the GIF.
# NOTE(review): frames left over from earlier runs with a different `print_epoch`
# are still picked up; clear the folder before retraining if that matters.
frame_prefix, frame_suffix = 'epoch_', '.png'
png_files = [
    f for f in os.listdir(folder)
    if f.startswith(frame_prefix)
    and f.endswith(frame_suffix)
    and f[len(frame_prefix):-len(frame_suffix)].isdecimal()
]
if not png_files:
    raise FileNotFoundError(
        f'No {frame_prefix}<N>{frame_suffix} frames found in {folder!r}; '
        'run the training cell first.'
    )
file_names = sort_epoch(png_files)  # chronological order by epoch number
image_paths = [os.path.join(folder, file) for file in file_names]
images = [imageio.imread(path) for path in image_paths]
gif_path = 'weight_evolution_2d.gif'
imageio.mimsave(gif_path, images, fps=8)
