In [133]:
# Dataset creation

import torch
import plotly.express as px

# device = "cpu"
# dtype = torch.float32
# n = 100
# d = 2
# seed = 0


# rnd = torch.randn(n, d, dtype=dtype)
# samples = rnd / torch.norm(rnd, dim=-1, keepdim=True)
# uniform = torch.rand(n, dtype=dtype)
# sample_scaling = uniform**(1/d)
# samples = samples * sample_scaling[:, None]


# fig = px.scatter(x= samples[:, 0], y = samples[:, 1])
# fig.add_shape(type="circle",
#     x0=-1, y0=-1, x1=1, y1=1,
#     line_color="LightSeaGreen",
# )
# fig.show()


In [134]:
def sample_from_unit_ball(n, d, seed=0):
    rnd = torch.randn(n, d, dtype=dtype)
    samples = rnd / torch.norm(rnd, dim=-1, keepdim=True)
    uniform = torch.rand(n, dtype=dtype)
    sample_scaling = uniform**(1/d)
    samples = samples * sample_scaling[:, None]
    return samples

def plot_samples(samples):
    fig = px.scatter(x= samples[:, 0], y = samples[:, 1])
    fig.add_shape(type="circle",
        x0=-1, y0=-1, x1=1, y1=1,
        line_color="LightSeaGreen",
    )
    fig.show()

    
plot_samples(sample_from_unit_ball(100, 2))

In [135]:
ground_truth_features = sample_from_unit_ball(100, 512)

In [136]:
# sample from on average 5 of n features using bernoully distribution
import random

torch.Size([100])

In [140]:
# get samples from unit sphere
# take 5 / n of them 
# sum them up

# not doing today: feature scaling "activation"
# not doing today: correlation between features
# not doing today: feature rescaling 

In [180]:
n = 100
d = 512

feature_space = sample_from_unit_ball(n, d)

def sample_from_bernoulli(n, d):
    return torch.bernoulli(torch.ones(n) * 5 / n)

def sample_feature_space(samples, n, d):
    vector_sample = sample_from_bernoulli(n, d)
    return samples[vector_sample == 1].sum(0)

# sample_from_bernoulli(100, 512).shape

def get_batch_of_samples(feature_space, n, d, batch_size):
    return torch.stack([sample_feature_space(feature_space, n, d) for _ in range(batch_size)])

get_batch_of_samples(feature_space, 100, 512, 10)

tensor([[-0.1231,  0.0477, -0.0025,  ...,  0.0296,  0.0773,  0.1008],
        [-0.0110, -0.0353,  0.0299,  ..., -0.0240,  0.2109,  0.1365],
        [-0.0359,  0.0807,  0.0114,  ..., -0.0256, -0.0229, -0.0073],
        ...,
        [-0.1415, -0.0092,  0.0125,  ..., -0.0481, -0.1334,  0.0244],
        [-0.1646,  0.0378, -0.0733,  ...,  0.0207,  0.1183, -0.0186],
        [ 0.0358,  0.1248,  0.1400,  ...,  0.0190,  0.2160, -0.0585]])

In [163]:
import torch.nn as nn


class AutoEncoderModel(nn.Module):
    def __init__(self, J = 100, h = 512):
        super().__init__()
        self.linear = nn.Linear(h, J)
        self.activation = nn.ReLU()

    def forward(self, x):
        return self.activation(self.linear(x))

class DictionaryModel(nn.Module):
    def __init__(self, J = 100, h = 512):
        super().__init__()
        # register parameters
        self.register_parameter("dictionary", nn.Parameter(torch.randn(h, J)))
        torch.nn.init.orthogonal_(self.dictionary)

    def forward(self, c):
        # constrain columns of D to be unit vectors
        tmp = self.dictionary / torch.norm(self.dictionary, dim=0, keepdim=True)
        Dc = c @ tmp.T
        return Dc

batch_size = 64 
n = 100
d = 512

samples = get_batch_of_samples(n, d, batch_size)

model = AutoEncoderModel()
c = model.forward(samples)
print(c.shape)

dictionary_model = DictionaryModel(J = 100, h = 512)
Dc = dictionary_model.forward(c)
print(Dc.shape)

torch.Size([64, 100])
torch.Size([64, 512])


In [162]:
l_reconstruction = torch.norm(samples-Dc, dim=1)
l_regularization = torch.norm(c, dim=1)
L = l_reconstruction + l_regularization
Loss = L.mean()
Loss 

tensor(2.7355, grad_fn=<MeanBackward0>)

In [183]:
from tqdm.notebook import tqdm

def train(feature_space, model, dictionary_model, G, h, batch_size, epochs, lr, alpha = 0.1):
    '''
    feature_space: samples from unit ball
    model: AutoEncoderModel
    dictionary_model: DictionaryModel
    G: True number of features
    h: Actual number of features
    batch_size: number of samples in a batch
    epochs: number of epochs
    lr: learning rate
    alpha: regularization parameter
    '''
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)

    l_reconstruction_cache = []
    l_regularization_cache = []
    Loss_cache = []
    model_checkpoints = []
    dictionary_checkpoints = []

    pbar = tqdm(range(epochs))
    for epoch in pbar:
        samples = get_batch_of_samples(feature_space, G, h, batch_size)
        c = model.forward(samples)
        Dc = dictionary_model.forward(c)
        l_reconstruction = torch.norm(samples-Dc, dim=1)
        l_regularization = torch.norm(c, dim=1)
        L = l_reconstruction + alpha*l_regularization
        Loss = L.mean()
        optimizer.zero_grad()
        Loss.backward()
        optimizer.step()


        l_reconstruction_cache.append(l_reconstruction.mean().item())
        l_regularization_cache.append(l_regularization.mean().item())
        Loss_cache.append(Loss.item())
        model_checkpoints.append(model.state_dict())
        dictionary_checkpoints.append(dictionary_model.state_dict())

        # if epoch % 10 == 0:
        #     print(f"epoch {epoch} loss {Loss.item()}")
        pbar.set_description(f"epoch {epoch} loss {Loss.item()}")

    return model, dictionary_model, l_reconstruction_cache, l_regularization_cache, Loss_cache, model_checkpoints, dictionary_checkpoints



G = J = 512
h = 256
batch_size = 64
epochs = 1000
lr = 1e-3
feature_space = sample_from_unit_ball(G, h)
model = AutoEncoderModel(J=J, h=h)
dictionary_model = DictionaryModel(J=J, h=h)

model, dictionary_model, l_reconstruction_cache, l_regularization_cache, Loss_cache, model_checkpoints, dictionary_checkpoints = train(
    feature_space, model, dictionary_model, G, h, batch_size, epochs, lr)


  0%|          | 0/1000 [00:00<?, ?it/s]

In [185]:
import plotly.express as px

fig = px.line(x=range(len(Loss_cache)), y=Loss_cache, title="Loss",
              labels={"x": "epoch", "y": "Loss"}, log_y=True)
fig.show()


fig = px.line(x=range(len(l_reconstruction_cache)), y=l_reconstruction_cache,
              title="l_reconstruction", labels={"x": "epoch", "y": "l_reconstruction"}, log_y=True)
fig.show()

fig = px.line(x=range(len(l_regularization_cache)), y=l_regularization_cache,
              title="l_regularization", labels={"x": "epoch", "y": "l_regularization"}, log_y=True)
fig.show()


In [188]:
dictionary_model.dictionary.shape

torch.Size([256, 512])

In [190]:
feature_space.shape

torch.Size([512, 256])

In [204]:
from torch.nn.functional import cosine_similarity

# get mean max cosine similarity
cosine_sim_Dj_Fg = torch.zeros((G, G))
for g in range(G):
    for j in range(J):
        d_j = dictionary_model.dictionary[:, j]
        f_g = feature_space[g, :]
        cosine_sim_Dj_Fg[g, j] = cosine_similarity(d_j, f_g, dim=0)

In [210]:
cosine_sim_Dj_Fg.mean(dim=1).max()

tensor(0.0079, grad_fn=<MaxBackward1>)

In [207]:
px.imshow(cosine_sim_Dj_Fg.detach())