In [1]:
from tqdm import tqdm
import torch
import numpy as np
import matplotlib.pyplot as plt

import matplotlib

# plt.rc('text', usetex=True)
# plt.rc('font', family='serif')


%config InlineBackend.figure_format = 'retina'
matplotlib.rcParams.update({
        "font.family": "serif",
       "font.serif": ["DejaVu Serif", "Bitstream Vera Serif", "Computer Modern Roman", "New Century Schoolbook", "Century Schoolbook L", "Utopia", "ITC Bookman", "Bookman", "Nimbus Roman No9 L", "Times New Roman", "Times", "Palatino", "Charter", "serif"],
        "axes.labelsize": 18,
        "font.size": 18,
        "legend.fontsize": 16,
        "xtick.labelsize": 18,
        "ytick.labelsize": 18,
})

In [2]:
# Gradients saved by a separate run (the path suggests Adam on C4 — confirm).
# Indexed later as collected_grads[key][step]; each entry is presumably a 2-D gradient matrix.
# map_location="cpu" makes the notebook run without a GPU, and it guarantees CPU tensors
# for matplotlib, which converts inputs through NumPy and rejects CUDA tensors.
# torch.load unpickles the file, so only load files you produced yourself.
GRADS_PATH = "./adam_c4_grads/grad_dicts.pt"
collected_grads = torch.load(GRADS_PATH, map_location="cpu")

In [3]:
def principal_angles(Q1, Q2):
    """Return the cosines of the principal angles between span(Q1) and span(Q2).

    The cosines are the singular values of ``Q1.T @ Q2``. They are only meaningful
    when both inputs have orthonormal columns, e.g. truncated left singular vectors
    or columns of an orthogonal matrix.

    Args:
        Q1: (n, k1) matrix with orthonormal columns.
        Q2: (n, k2) matrix with orthonormal columns.

    Returns:
        1-D tensor of min(k1, k2) cosines, in [0, 1] up to round-off, sorted in
        descending order (largest cosine = smallest principal angle).
    """
    # Only the singular values are needed; svdvals skips computing U and Vh.
    return torch.linalg.svdvals(Q1.T @ Q2)

In [4]:
# Baseline: principal cosines between two independent, uniformly random 128-dim
# subspaces of R^512.
# NOTE(review): 512 presumably matches the row dimension of the gradient matrices,
# and 128 the rank kept in the projection cell — confirm.
RANDOM_BASELINE_SEED = 0
torch.manual_seed(RANDOM_BASELINE_SEED)  # reproducible baseline across Restart & Run All
# orthogonal_ on a tall (512, 128) tensor yields orthonormal columns directly. This
# matches the distribution of the first 128 columns of a random 512x512 orthogonal matrix.
R1 = torch.nn.init.orthogonal_(torch.empty(512, 128))
R2 = torch.nn.init.orthogonal_(torch.empty(512, 128))
random_semiorthogonal_angles = principal_angles(R1, R2)

In [5]:
# For each checkpointed step, keep an orthonormal basis of the top-RANK left singular
# subspace of the stored gradient.
# NOTE(review): 29 is presumably a layer/parameter index into collected_grads — confirm.
GRAD_KEY = 29
RANK = 128
steps = [1000, 1100, 2000, 10000, 99000]
projections_dict = {}
for step in steps:
    # full_matrices=False: only the leading left singular vectors are needed, so do not
    # compute the full m x m U. The span of the first RANK columns is unchanged.
    U = torch.linalg.svd(collected_grads[GRAD_KEY][step], full_matrices=False).U
    projections_dict[step] = U[:, :RANK]

In [6]:
# Principal cosines between the first step's subspace and every later step's subspace,
# keyed by the later step.
angles_dict = {
    later_step: principal_angles(projections_dict[steps[0]], projections_dict[later_step])
    for later_step in steps[1:]
}

In [None]:
# For each later step, count the principal cosines with the reference step (steps[0])
# that exceed the threshold. cos > 0.87 corresponds to a principal angle below ~29.5 degrees.
COSINE_THRESHOLD = 0.87
for step in steps[1:]:
    n_aligned = (angles_dict[step] > COSINE_THRESHOLD).sum().item()
    print(f"steps: ({steps[0]}, {step})\t number of cosines larger than {COSINE_THRESHOLD}: ", n_aligned)

In [None]:
# Histograms of principal cosines between the reference step (steps[0]) and steps[2:],
# with a final panel for the random-subspace baseline. Saved to 4_histograms.pdf.
LARGE_FONTSIZE = 45  # large fonts for the saved PDF figure
reference_step = steps[0]
indices = steps[2:]
# One panel per compared step, plus one for the random baseline.
fig, axs = plt.subplots(1, len(indices) + 1, figsize=(25, 5))

for i, index in enumerate(indices):
    ax = axs[i]
    if index in angles_dict:
        ax.hist(angles_dict[index], bins=30, edgecolor='black')
        ax.set_title(f'{reference_step} and {index}', fontsize=LARGE_FONTSIZE)
        ax.set_xlabel('Principal cosines', fontsize=LARGE_FONTSIZE)
        if i == 0:  # shared y-label only on the leftmost panel
            ax.set_ylabel('Frequency', fontsize=LARGE_FONTSIZE)
    else:
        ax.text(0.5, 0.5, f'No data for index {index}',
                ha='center', va='center', transform=ax.transAxes)

# Last panel: cosines between two independent random subspaces.
axs[-1].hist(random_semiorthogonal_angles, bins=30, edgecolor='black')
axs[-1].set_title('Random', fontsize=LARGE_FONTSIZE)
axs[-1].set_xlabel('Principal cosines', fontsize=LARGE_FONTSIZE)

fig.tight_layout()
fig.savefig('4_histograms.pdf', bbox_inches='tight')

In [None]:
# Histograms of principal cosines between the reference step (steps[0]) and every later
# step, with a final panel for the random-subspace baseline. Saved to 5_histograms.pdf.
reference_step = steps[0]
indices = steps[1:]
# One panel per compared step, plus one for the random baseline.
fig, axs = plt.subplots(1, len(indices) + 1, figsize=(25, 5))

for i, index in enumerate(indices):
    ax = axs[i]
    if index in angles_dict:
        ax.hist(angles_dict[index], bins=30, edgecolor='black')
        ax.set_title(f'Iterations {reference_step} and {index}')
        ax.set_xlabel('Principal angle cosines')
        if i == 0:  # y-label on the leftmost data panel
            ax.set_ylabel('Frequency')
    else:
        ax.text(0.5, 0.5, f'No data for index {index}',
                ha='center', va='center', transform=ax.transAxes)

# Last panel: cosines between two independent random subspaces.
axs[-1].hist(random_semiorthogonal_angles, bins=30, edgecolor='black')
axs[-1].set_title('Random semiorthogonal')
axs[-1].set_xlabel('Principal angle cosines')
axs[-1].set_ylabel('Frequency')

fig.tight_layout()
fig.savefig('5_histograms.pdf', bbox_inches='tight')