In [2]:
import numpy as np
import torch
from torch import nn
import torchvision
from torch.utils.data import DataLoader, TensorDataset
import pandas as pd

### Test Matrix

In [35]:
# Set random seed for reproducibility
np.random.seed(42)
torch.manual_seed(42)

# Create a user-item interaction matrix (100 users, 50 items) with sparsity
num_users = 10000
num_items = 500
interaction_matrix = np.random.randint(0, 6, size=(num_users, num_items))  # Random interactions from 0 to 5

# Introduce sparsity by setting a high percentage of interactions to 0
sparsity = 0.8  # 80% of the interactions will be set to 0
mask = np.random.rand(*interaction_matrix.shape) < sparsity
interaction_matrix[mask] = 0

# Convert to PyTorch tensor
interaction_tensor = torch.tensor(interaction_matrix, dtype=torch.float32)

# Add noise to the input data
def add_noise(data, noise_factor=0.3):
    noisy_data = data + noise_factor * torch.randn_like(data)
    noisy_data = torch.clamp(noisy_data, 0., 5.)  # Ensure values stay within the interaction range
    return noisy_data

noisy_interaction_tensor = add_noise(interaction_tensor)

# Print the original and noisy matrices (first 10 users for brevity)
print("Original Interaction Matrix (first 10 users):")
print(interaction_matrix[:10])  # Print only the first 10 users for brevity
print("\nNoisy Interaction Matrix (first 10 users):")
print(noisy_interaction_tensor[:10].numpy())


Original Interaction Matrix (first 10 users):
[[0 0 0 ... 0 0 0]
 [0 5 0 ... 0 0 4]
 [0 0 0 ... 0 0 5]
 ...
 [0 0 0 ... 0 0 0]
 [0 0 0 ... 0 0 0]
 [0 0 4 ... 0 0 0]]

Noisy Interaction Matrix (first 10 users):
[[0.57807463 0.44618523 0.27021518 ... 0.         0.19485128 0.02848756]
 [0.         4.805831   0.         ... 0.         0.         3.971974  ]
 [0.         0.12133284 0.6427789  ... 0.         0.         4.9596815 ]
 ...
 [0.16250879 0.08124392 0.19579445 ... 0.09819969 0.18892097 0.5638384 ]
 [0.         0.3908983  0.         ... 0.         0.         0.        ]
 [0.07031945 0.         3.5793095  ... 0.         0.         0.36294198]]


In [30]:
class AutoEncoder(nn.Module):
	def __init__(self, input_dim, bottleneck_size, device='cpu'):
		super(AutoEncoder, self).__init__()
		self.device = device
		self.encoder = nn.Sequential(
			nn.Linear(input_dim, 128),
			nn.ReLU(),
			nn.Linear(128, 64),
			nn.ReLU(),
			nn.Linear(64, bottleneck_size)
		)
		self.decoder = nn.Sequential(
			nn.Linear(bottleneck_size, 64),
			nn.ReLU(),
			nn.Linear(64, 128),
			nn.ReLU(),
			nn.Linear(128, input_dim)
		)
	
	def forward(self, x):
		x = self.encoder(x)
		x = self.decoder(x)
		return x
	
	def fit(self, batches, n_epochs=100, min_delta=0.0001, lr=0.001, patience=10):
		optimizer = torch.optim.Adam(self.parameters(), lr=lr)
		criterion = nn.MSELoss()
		best_loss = float('inf')
		patience_counter = 0

		for epoch in range(n_epochs):
			epoch_loss = 0.0
			for batch in batches:
				batch = batch[0].to(self.device)  # Move batch to device
				optimizer.zero_grad()
				output = self.forward(batch)
				loss = criterion(output, batch)
				loss.backward()
				optimizer.step()
				epoch_loss += loss.item()

			epoch_loss /= len(batches)

			if epoch_loss < best_loss - min_delta:
				best_loss = epoch_loss
				patience_counter = 0
			else:
				patience_counter += 1

			if patience_counter >= patience:
				print(f"Early stopping at epoch {epoch+1} with loss {epoch_loss:.4f}")
				break

			print(f'Epoch [{epoch+1}/{n_epochs}], Loss: {epoch_loss:.4f}')
		return


In [36]:
batch_size = 128 
train = TensorDataset(noisy_interaction_tensor)
batches = DataLoader(train, batch_size=batch_size, shuffle=True)

In [16]:
device = 'cuda:0' if torch.cuda.is_available() else 'cpu'

In [37]:
# Instantiate the SparseAutoEncoder model
input_dim = num_items
bottleneck_size = 10
model = SparseAutoEncoder(input_dim, bottleneck_size, device=device).to(device)
model.fit(batches, n_epochs=1000, lr=0.001, patience=10)

Epoch [1/1000], Loss: 0.9100
Epoch [2/1000], Loss: 0.8144
Epoch [3/1000], Loss: 0.8144
Epoch [4/1000], Loss: 0.8130
Epoch [5/1000], Loss: 0.8099
Epoch [6/1000], Loss: 0.8079
Epoch [7/1000], Loss: 0.8064
Epoch [8/1000], Loss: 0.8055
Epoch [9/1000], Loss: 0.8029
Epoch [10/1000], Loss: 0.8010
Epoch [11/1000], Loss: 0.7997
Epoch [12/1000], Loss: 0.7985
Epoch [13/1000], Loss: 0.7981
Epoch [14/1000], Loss: 0.7983
Epoch [15/1000], Loss: 0.7972
Epoch [16/1000], Loss: 0.7969
Epoch [17/1000], Loss: 0.7953
Epoch [18/1000], Loss: 0.7954
Epoch [19/1000], Loss: 0.7947
Epoch [20/1000], Loss: 0.7944
Epoch [21/1000], Loss: 0.7940
Epoch [22/1000], Loss: 0.7945
Epoch [23/1000], Loss: 0.7943
Epoch [24/1000], Loss: 0.7934
Epoch [25/1000], Loss: 0.7929
Epoch [26/1000], Loss: 0.7933
Epoch [27/1000], Loss: 0.7930
Epoch [28/1000], Loss: 0.7926
Epoch [29/1000], Loss: 0.7922
Epoch [30/1000], Loss: 0.7927
Epoch [31/1000], Loss: 0.7922
Epoch [32/1000], Loss: 0.7905
Epoch [33/1000], Loss: 0.7917
Epoch [34/1000], Lo

### Training the autoencoders

### Training on Colab using A100 and L4 GPUs

In [6]:
model = None

A primitive method to test 

In [7]:
tested = True

In [8]:
user_ratings = pd.read_csv('full_matrix.csv.gzip', compression='gzip', index_col=0)
user_ratings.notna().count().sum()
user_ratings = user_ratings.fillna(0)

In [22]:
user_ratings

Unnamed: 0,A101446I5AWY0Z,A103U0Q3IKSXHE,A103W7ZPKGOCC9,A105E427BB6J65,A106016KSI0YQ,A106E1N0ZQ4D9W,A1075MZNVRMSEO,A107C4RVRF0OP,A107YFBJ119GZR,A10872FHIJAKKD,...,AZVZSGHKV0AO0,AZWC9XAY34IPW,AZWG3PF80735Q,AZWOQXRCS1WA6,AZWW1U604W0N,AZXEZRXZQL1H2,AZXGPM8EKSHE9,AZXQKAMHK35PA,AZY8LGHVF8GMZ,AZZVZL4QEHEHO
1882931173,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
0826414346,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
0829814000,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
0595344550,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
0253338352,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
0590482467,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
0570047870,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
B000OVF7JY,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
1402508735,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0


In [9]:
import os
if tested:
	print('Tested')
elif os.path.exists('amazon_model_weights.pt') and os.path.exists('amazon_model.pt'):
	model = torch.load('amazon_model.pt')
	model.load_state_dict(torch.load('amazon_model_weights.pt'))
else:
	tensors = torch.tensor(user_ratings.to_numpy(), dtype=torch.float32, device=device)
	batch_size = 32 
	train = TensorDataset(tensors)
	batches = DataLoader(train, batch_size=batch_size, shuffle=True)
	model = AutoEncoder(50, 10)
	model.fit(batches, 1000, 0.0001, 0.0001, 10)

Tested


In [44]:
def mask_test_model(model: AutoEncoder, mask_fraction: float, row: torch.Tensor, device: torch.device):
    model.eval()
    with torch.no_grad():
        criterion = nn.MSELoss()
        
 
        
        # Create the mask
        mask = torch.rand(row.shape).to(device) < mask_fraction
        
        # Apply the mask to the row
        masked_row = row.clone()
        masked_row[mask] = 0
        
        # Get the model's predictions
        predictions = model.forward(masked_row)

        # Calculate the loss only for the masked values
        loss = criterion(predictions[mask], row[mask])
        print(f'Test loss: {loss.item()}')
        
        # Optionally, return the predictions and mask for further analysis
        return predictions.cpu().numpy(), mask.cpu().numpy()

In [11]:
if not tested:
	row = torch.tensor(user_ratings.iloc[0].to_numpy(), dtype=torch.float32, device=device)
	res , _ = mask_test_model(model, 0.2, row, user_ratings, device)
	res[res >= 0.1]

In [12]:
from numpy import linalg

In [136]:
# def test_sample_model(model: AutoEncoder, interaction_matrix: pd.DataFrame, device: torch.device, sample_size: int = 10):
# 	model.eval()
	
# 	sampled = interaction_matrix.sample(sample_size)
# 	display(sampled)
# 	tested = sampled.apply(lambda row: mask_test_model(model, 0.2, torch.tensor(row.to_numpy(), dtype=torch.float32, device=device)[0], device), axis=0).to_numpy()
# 	# s = torch.tensor(sampled)
# 	return linalg.norm(sampled.to_numpy() - tested, ord='fro'), tested

In [159]:
def test_sample_model(model: AutoEncoder, interaction_matrix: pd.DataFrame, device: torch.device, sample_size: int = 10):
    model.eval()
    sample = interaction_matrix.sample(sample_size, random_state=42)
    
    # Convert the sample DataFrame to a tensor
    sample_tensor = torch.tensor(sample.to_numpy(), dtype=torch.float32, device=device)
    
    # Test the model on each row of the sample
    # tested = sample.apply(lambda row: mask_test_model(model, 0.2, torch.tensor(row.to_numpy(), dtype=torch.float32, device=device), device)[0], axis=1)
    tested = []
    for row in sample.iterrows():
        prediction = mask_test_model(model, 0.2, torch.tensor(row[1].to_numpy(), dtype=torch.float32, device=device), device)[0]
        # display(prediction)
        prediction = pd.DataFrame(prediction)
        tested.append(prediction)
    

    tested = pd.concat(tested, axis=1)
    tested.index = sample.columns
    tested.columns = sample.index

    # Calculate the Frobenius norm of the difference
    # calculate the rmse
    
    # loss = linalg.norm(sample.to_numpy() - tested.to_numpy().T, ord='fro')
    mse = np.mean((sample.to_numpy() - tested.to_numpy().T)**2)
    loss = np.sqrt(mse)
    
    return loss, tested.T, sample

In [108]:
def fit_by_bottleneck(user_ratings: pd.DataFrame, device: torch.device, bottleneck: int = 10):
	tensors = torch.tensor(user_ratings.to_numpy(), dtype=torch.float32, device=device)
	batch_size = 32 
	train = TensorDataset(tensors)
	batches = DataLoader(train, batch_size=batch_size, shuffle=True)
	model = AutoEncoder(50, bottleneck)
	model.fit(batches)
	# save the model
	torch.save(model, f'model_k={bottleneck}.pt')
	torch.save(model.state_dict(), f'model_weights_{bottleneck}.pt')

In [163]:
if not tested:
	fit_by_bottleneck(user_ratings, device, 5)
	fit_by_bottleneck(user_ratings, device, 10)
	fit_by_bottleneck(user_ratings, device, 15)
	fit_by_bottleneck(user_ratings, device, 20)
	fit_by_bottleneck(user_ratings, device, 25)

ValueError: The truth value of a DataFrame is ambiguous. Use a.empty, a.bool(), a.item(), a.any() or a.all().

### Testing the model by bottleneck

In [142]:
def test_bottleneck(user_ratings: pd.DataFrame, device: torch.device, bottleneck: int = 10):
    print(f'Testing bottleneck {bottleneck}')
    model = torch.load(f'models/model/model_k={bottleneck}.pt')
    model.load_state_dict(torch.load(f'models/weights/model_weights_{bottleneck}.pt'))
    return test_sample_model(model, user_ratings, device)

Results

In [160]:
bottleneck_results = [test_bottleneck(user_ratings, device, k) for k in [5, 10, 15, 20, 25]]
# bottleneck_results = test_bottleneck(user_ratings, device, 5)

Testing bottleneck 5
Test loss: 0.011637355200946331
Test loss: 6.3904794842528645e-06
Test loss: 8.037904990487732e-06
Test loss: 8.000823072507046e-06
Test loss: 6.922076408955036e-06
Test loss: 5.157994110049913e-06
Test loss: 7.533772077294998e-06
Test loss: 6.840210062364349e-06
Test loss: 6.801791187172057e-06
Test loss: 8.793431334197521e-06
Testing bottleneck 10
Test loss: 0.00027783430414274335
Test loss: 7.97837128629908e-06
Test loss: 8.222714313887991e-06
Test loss: 6.847804343124153e-06
Test loss: 7.304269729502266e-06
Test loss: 6.936451882211259e-06
Test loss: 7.345483936660457e-06
Test loss: 7.853159331716597e-06
Test loss: 8.445888852293137e-06
Test loss: 6.999188826739555e-06
Testing bottleneck 15
Test loss: 0.0012787140440195799
Test loss: 6.725597813783679e-06
Test loss: 6.9471466304094065e-06
Test loss: 6.769527317374013e-06
Test loss: 6.25945949650486e-06
Test loss: 7.082455340423621e-06
Test loss: 6.003573162161047e-06
Test loss: 7.090658073138911e-06
Test loss: 

In [162]:
for loss, tested, sample in bottleneck_results:
	print(f'rmse norm: {loss}')
	print('Test')
	display(tested)
	print('Sampled')
	display(sample)

rmse norm: 0.024923525699856036
Test


Unnamed: 0,A101446I5AWY0Z,A103U0Q3IKSXHE,A103W7ZPKGOCC9,A105E427BB6J65,A106016KSI0YQ,A106E1N0ZQ4D9W,A1075MZNVRMSEO,A107C4RVRF0OP,A107YFBJ119GZR,A10872FHIJAKKD,...,AZVZSGHKV0AO0,AZWC9XAY34IPW,AZWG3PF80735Q,AZWOQXRCS1WA6,AZWW1U604W0N,AZXEZRXZQL1H2,AZXGPM8EKSHE9,AZXQKAMHK35PA,AZY8LGHVF8GMZ,AZZVZL4QEHEHO
0743246500,0.007401,0.000338,0.000927,9e-05,0.000987,0.007616,3e-06,0.000569,0.00027,0.001156,...,0.000629,-0.001164,0.004665,0.000591,0.000304,6.8e-05,0.000436,0.000211,-0.002611,0.000676
0138421471,0.007275,-0.000242,0.000533,1.6e-05,-0.000518,0.007551,-3.8e-05,-0.000438,1.1e-05,0.000949,...,-4.9e-05,-0.001459,0.004767,0.000108,-8.9e-05,-2.3e-05,4.5e-05,6.7e-05,-0.003155,0.000577
0385317999,0.007275,-0.000242,0.000533,1.6e-05,-0.000518,0.007551,-3.8e-05,-0.000438,1.1e-05,0.000949,...,-4.9e-05,-0.001459,0.004767,0.000108,-8.9e-05,-2.3e-05,4.5e-05,6.7e-05,-0.003155,0.000577
B000C4SS5I,0.00729,-0.000198,0.00057,1.7e-05,-0.000327,0.007537,-3.4e-05,-0.000317,3.3e-05,0.00095,...,1.1e-05,-0.001427,0.004708,0.00014,-4.5e-05,-1.6e-05,8e-05,6.8e-05,-0.003105,0.000585
B000OTYZHG,0.007275,-0.000242,0.000533,1.6e-05,-0.000518,0.007551,-3.8e-05,-0.000438,1.1e-05,0.000949,...,-4.9e-05,-0.001459,0.004767,0.000108,-8.9e-05,-2.3e-05,4.5e-05,6.7e-05,-0.003155,0.000577
B0006YV4SW,0.007275,-0.000242,0.000533,1.6e-05,-0.000518,0.007551,-3.8e-05,-0.000438,1.1e-05,0.000949,...,-4.9e-05,-0.001459,0.004767,0.000108,-8.9e-05,-2.3e-05,4.5e-05,6.7e-05,-0.003155,0.000577
1885210086,0.007275,-0.000242,0.000533,1.6e-05,-0.000518,0.007551,-3.8e-05,-0.000438,1.1e-05,0.000949,...,-4.9e-05,-0.001459,0.004767,0.000108,-8.9e-05,-2.3e-05,4.5e-05,6.7e-05,-0.003155,0.000577
0806504757,0.007275,-0.000242,0.000533,1.6e-05,-0.000518,0.007551,-3.8e-05,-0.000438,1.1e-05,0.000949,...,-4.9e-05,-0.001459,0.004767,0.000108,-8.9e-05,-2.3e-05,4.5e-05,6.7e-05,-0.003155,0.000577
059513534X,0.007275,-0.000242,0.000533,1.6e-05,-0.000518,0.007551,-3.8e-05,-0.000438,1.1e-05,0.000949,...,-4.9e-05,-0.001459,0.004767,0.000108,-8.9e-05,-2.3e-05,4.5e-05,6.7e-05,-0.003155,0.000577
0155067036,0.007275,-0.000242,0.000533,1.6e-05,-0.000518,0.007551,-3.8e-05,-0.000438,1.1e-05,0.000949,...,-4.9e-05,-0.001459,0.004767,0.000108,-8.9e-05,-2.3e-05,4.5e-05,6.7e-05,-0.003155,0.000577


Sampled


Unnamed: 0,A101446I5AWY0Z,A103U0Q3IKSXHE,A103W7ZPKGOCC9,A105E427BB6J65,A106016KSI0YQ,A106E1N0ZQ4D9W,A1075MZNVRMSEO,A107C4RVRF0OP,A107YFBJ119GZR,A10872FHIJAKKD,...,AZVZSGHKV0AO0,AZWC9XAY34IPW,AZWG3PF80735Q,AZWOQXRCS1WA6,AZWW1U604W0N,AZXEZRXZQL1H2,AZXGPM8EKSHE9,AZXQKAMHK35PA,AZY8LGHVF8GMZ,AZZVZL4QEHEHO
0743246500,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
0138421471,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
0385317999,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
B000C4SS5I,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
B000OTYZHG,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
B0006YV4SW,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
1885210086,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
0806504757,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
059513534X,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
0155067036,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0


rmse norm: 0.020882619346149375
Test


Unnamed: 0,A101446I5AWY0Z,A103U0Q3IKSXHE,A103W7ZPKGOCC9,A105E427BB6J65,A106016KSI0YQ,A106E1N0ZQ4D9W,A1075MZNVRMSEO,A107C4RVRF0OP,A107YFBJ119GZR,A10872FHIJAKKD,...,AZVZSGHKV0AO0,AZWC9XAY34IPW,AZWG3PF80735Q,AZWOQXRCS1WA6,AZWW1U604W0N,AZXEZRXZQL1H2,AZXGPM8EKSHE9,AZXQKAMHK35PA,AZY8LGHVF8GMZ,AZZVZL4QEHEHO
0743246500,-0.000366,0.011688,-0.002928,-0.000897,0.020238,-0.00206,-0.000153,-0.003058,-0.001301,8.6e-05,...,0.003293,5.7e-05,-0.004039,-0.003681,0.008169,-0.000485,-0.001045,0.018226,-4.8e-05,-0.004193
0138421471,-8.3e-05,-0.000515,-7e-06,-0.000494,-0.001549,-0.003302,0.000143,-0.000984,-0.001297,-0.000546,...,-0.000225,-0.000141,-0.001672,8.5e-05,-0.000252,-0.000623,-0.00029,-0.000629,-0.000285,0.001126
0385317999,-8.3e-05,-0.000515,-7e-06,-0.000494,-0.001549,-0.003302,0.000143,-0.000984,-0.001297,-0.000546,...,-0.000225,-0.000141,-0.001672,8.5e-05,-0.000252,-0.000623,-0.00029,-0.000629,-0.000285,0.001126
B000C4SS5I,-8.1e-05,-0.0005,2e-06,-0.00048,-0.001438,-0.003229,0.00014,-0.000954,-0.00126,-0.00051,...,-0.000203,-0.000132,-0.001616,0.00012,-0.000248,-0.000609,-0.000278,-0.000595,-0.000273,0.001136
B000OTYZHG,-8.3e-05,-0.000515,-7e-06,-0.000494,-0.001549,-0.003302,0.000143,-0.000984,-0.001297,-0.000546,...,-0.000225,-0.000141,-0.001672,8.5e-05,-0.000252,-0.000623,-0.00029,-0.000629,-0.000285,0.001126
B0006YV4SW,-8.3e-05,-0.000515,-7e-06,-0.000494,-0.001549,-0.003302,0.000143,-0.000984,-0.001297,-0.000546,...,-0.000225,-0.000141,-0.001672,8.5e-05,-0.000252,-0.000623,-0.00029,-0.000629,-0.000285,0.001126
1885210086,-8.3e-05,-0.000515,-7e-06,-0.000494,-0.001549,-0.003302,0.000143,-0.000984,-0.001297,-0.000546,...,-0.000225,-0.000141,-0.001672,8.5e-05,-0.000252,-0.000623,-0.00029,-0.000629,-0.000285,0.001126
0806504757,-8.3e-05,-0.000515,-7e-06,-0.000494,-0.001549,-0.003302,0.000143,-0.000984,-0.001297,-0.000546,...,-0.000225,-0.000141,-0.001672,8.5e-05,-0.000252,-0.000623,-0.00029,-0.000629,-0.000285,0.001126
059513534X,-8.3e-05,-0.000515,-7e-06,-0.000494,-0.001549,-0.003302,0.000143,-0.000984,-0.001297,-0.000546,...,-0.000225,-0.000141,-0.001672,8.5e-05,-0.000252,-0.000623,-0.00029,-0.000629,-0.000285,0.001126
0155067036,-8.3e-05,-0.000515,-7e-06,-0.000494,-0.001549,-0.003302,0.000143,-0.000984,-0.001297,-0.000546,...,-0.000225,-0.000141,-0.001672,8.5e-05,-0.000252,-0.000623,-0.00029,-0.000629,-0.000285,0.001126


Sampled


Unnamed: 0,A101446I5AWY0Z,A103U0Q3IKSXHE,A103W7ZPKGOCC9,A105E427BB6J65,A106016KSI0YQ,A106E1N0ZQ4D9W,A1075MZNVRMSEO,A107C4RVRF0OP,A107YFBJ119GZR,A10872FHIJAKKD,...,AZVZSGHKV0AO0,AZWC9XAY34IPW,AZWG3PF80735Q,AZWOQXRCS1WA6,AZWW1U604W0N,AZXEZRXZQL1H2,AZXGPM8EKSHE9,AZXQKAMHK35PA,AZY8LGHVF8GMZ,AZZVZL4QEHEHO
0743246500,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
0138421471,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
0385317999,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
B000C4SS5I,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
B000OTYZHG,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
B0006YV4SW,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
1885210086,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
0806504757,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
059513534X,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
0155067036,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0


rmse norm: 0.020430369555371225
Test


Unnamed: 0,A101446I5AWY0Z,A103U0Q3IKSXHE,A103W7ZPKGOCC9,A105E427BB6J65,A106016KSI0YQ,A106E1N0ZQ4D9W,A1075MZNVRMSEO,A107C4RVRF0OP,A107YFBJ119GZR,A10872FHIJAKKD,...,AZVZSGHKV0AO0,AZWC9XAY34IPW,AZWG3PF80735Q,AZWOQXRCS1WA6,AZWW1U604W0N,AZXEZRXZQL1H2,AZXGPM8EKSHE9,AZXQKAMHK35PA,AZY8LGHVF8GMZ,AZZVZL4QEHEHO
0743246500,-0.000141,0.034415,-0.000925,0.001314,0.006034,0.004183,0.002625,0.004575,0.000271,0.027823,...,0.00166,0.000354,0.012163,0.010628,-0.000167,0.010788,0.001721,0.003728,-0.000115,0.008388
0138421471,-0.000126,0.000223,-0.000158,-0.000486,-0.002571,-0.002617,0.003307,-0.002734,-0.002095,-0.000796,...,-0.000597,-0.000373,0.00422,0.004332,-0.002771,0.010401,-0.000642,-5.8e-05,-0.000298,0.008166
0385317999,-0.000126,0.000223,-0.000158,-0.000486,-0.002571,-0.002617,0.003307,-0.002734,-0.002095,-0.000796,...,-0.000597,-0.000373,0.00422,0.004332,-0.002771,0.010401,-0.000642,-5.8e-05,-0.000298,0.008166
B000C4SS5I,-0.000123,0.000312,-0.000108,-0.000408,-0.002238,-0.00239,0.003307,-0.002452,-0.002012,-0.000529,...,-0.000431,-0.000312,0.004524,0.004463,-0.002762,0.010435,-0.000586,0.000121,-0.000233,0.008188
B000OTYZHG,-0.000126,0.000223,-0.000158,-0.000486,-0.002571,-0.002617,0.003307,-0.002734,-0.002095,-0.000796,...,-0.000597,-0.000373,0.00422,0.004332,-0.002771,0.010401,-0.000642,-5.8e-05,-0.000298,0.008166
B0006YV4SW,-0.000126,0.000223,-0.000158,-0.000486,-0.002571,-0.002617,0.003307,-0.002734,-0.002095,-0.000796,...,-0.000597,-0.000373,0.00422,0.004332,-0.002771,0.010401,-0.000642,-5.8e-05,-0.000298,0.008166
1885210086,-0.000126,0.000223,-0.000158,-0.000486,-0.002571,-0.002617,0.003307,-0.002734,-0.002095,-0.000796,...,-0.000597,-0.000373,0.00422,0.004332,-0.002771,0.010401,-0.000642,-5.8e-05,-0.000298,0.008166
0806504757,-0.000126,0.000223,-0.000158,-0.000486,-0.002571,-0.002617,0.003307,-0.002734,-0.002095,-0.000796,...,-0.000597,-0.000373,0.00422,0.004332,-0.002771,0.010401,-0.000642,-5.8e-05,-0.000298,0.008166
059513534X,-0.000126,0.000223,-0.000158,-0.000486,-0.002571,-0.002617,0.003307,-0.002734,-0.002095,-0.000796,...,-0.000597,-0.000373,0.00422,0.004332,-0.002771,0.010401,-0.000642,-5.8e-05,-0.000298,0.008166
0155067036,-0.000126,0.000223,-0.000158,-0.000486,-0.002571,-0.002617,0.003307,-0.002734,-0.002095,-0.000796,...,-0.000597,-0.000373,0.00422,0.004332,-0.002771,0.010401,-0.000642,-5.8e-05,-0.000298,0.008166


Sampled


Unnamed: 0,A101446I5AWY0Z,A103U0Q3IKSXHE,A103W7ZPKGOCC9,A105E427BB6J65,A106016KSI0YQ,A106E1N0ZQ4D9W,A1075MZNVRMSEO,A107C4RVRF0OP,A107YFBJ119GZR,A10872FHIJAKKD,...,AZVZSGHKV0AO0,AZWC9XAY34IPW,AZWG3PF80735Q,AZWOQXRCS1WA6,AZWW1U604W0N,AZXEZRXZQL1H2,AZXGPM8EKSHE9,AZXQKAMHK35PA,AZY8LGHVF8GMZ,AZZVZL4QEHEHO
0743246500,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
0138421471,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
0385317999,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
B000C4SS5I,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
B000OTYZHG,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
B0006YV4SW,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
1885210086,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
0806504757,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
059513534X,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
0155067036,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0


rmse norm: 0.020603754606032552
Test


Unnamed: 0,A101446I5AWY0Z,A103U0Q3IKSXHE,A103W7ZPKGOCC9,A105E427BB6J65,A106016KSI0YQ,A106E1N0ZQ4D9W,A1075MZNVRMSEO,A107C4RVRF0OP,A107YFBJ119GZR,A10872FHIJAKKD,...,AZVZSGHKV0AO0,AZWC9XAY34IPW,AZWG3PF80735Q,AZWOQXRCS1WA6,AZWW1U604W0N,AZXEZRXZQL1H2,AZXGPM8EKSHE9,AZXQKAMHK35PA,AZY8LGHVF8GMZ,AZZVZL4QEHEHO
0743246500,-0.000391,0.034526,-0.001796,0.000112,-0.000365,-0.001395,-7e-06,-0.000613,-0.000934,0.001337,...,-0.000829,-0.001816,-0.000298,2.1e-05,-0.002076,0.006062,0.028332,-0.002134,-0.000893,-0.000257
0138421471,-4.2e-05,0.000236,-0.000388,-0.00034,0.000173,0.000925,3.3e-05,-0.000404,-0.000742,-0.000663,...,-0.000112,0.000185,0.001309,-6.5e-05,-0.000113,0.009496,-0.000451,-0.000652,-0.000218,-6.1e-05
0385317999,-4.2e-05,0.000236,-0.000388,-0.00034,0.000173,0.000925,3.3e-05,-0.000404,-0.000742,-0.000663,...,-0.000112,0.000185,0.001309,-6.5e-05,-0.000113,0.009496,-0.000451,-0.000652,-0.000218,-6.1e-05
B000C4SS5I,-4.2e-05,0.000236,-0.000388,-0.00034,0.000173,0.000925,3.3e-05,-0.000404,-0.000742,-0.000663,...,-0.000112,0.000185,0.001309,-6.5e-05,-0.000113,0.009496,-0.000451,-0.000652,-0.000218,-6.1e-05
B000OTYZHG,-4.2e-05,0.000236,-0.000388,-0.00034,0.000173,0.000925,3.3e-05,-0.000404,-0.000742,-0.000663,...,-0.000112,0.000185,0.001309,-6.5e-05,-0.000113,0.009496,-0.000451,-0.000652,-0.000218,-6.1e-05
B0006YV4SW,-4.2e-05,0.000236,-0.000388,-0.00034,0.000173,0.000925,3.3e-05,-0.000404,-0.000742,-0.000663,...,-0.000112,0.000185,0.001309,-6.5e-05,-0.000113,0.009496,-0.000451,-0.000652,-0.000218,-6.1e-05
1885210086,-4.2e-05,0.000236,-0.000388,-0.00034,0.000173,0.000925,3.3e-05,-0.000404,-0.000742,-0.000663,...,-0.000112,0.000185,0.001309,-6.5e-05,-0.000113,0.009496,-0.000451,-0.000652,-0.000218,-6.1e-05
0806504757,-4.2e-05,0.000236,-0.000388,-0.00034,0.000173,0.000925,3.3e-05,-0.000404,-0.000742,-0.000663,...,-0.000112,0.000185,0.001309,-6.5e-05,-0.000113,0.009496,-0.000451,-0.000652,-0.000218,-6.1e-05
059513534X,-4.2e-05,0.000236,-0.000388,-0.00034,0.000173,0.000925,3.3e-05,-0.000404,-0.000742,-0.000663,...,-0.000112,0.000185,0.001309,-6.5e-05,-0.000113,0.009496,-0.000451,-0.000652,-0.000218,-6.1e-05
0155067036,-4.2e-05,0.000236,-0.000388,-0.00034,0.000173,0.000925,3.3e-05,-0.000404,-0.000742,-0.000663,...,-0.000112,0.000185,0.001309,-6.5e-05,-0.000113,0.009496,-0.000451,-0.000652,-0.000218,-6.1e-05


Sampled


Unnamed: 0,A101446I5AWY0Z,A103U0Q3IKSXHE,A103W7ZPKGOCC9,A105E427BB6J65,A106016KSI0YQ,A106E1N0ZQ4D9W,A1075MZNVRMSEO,A107C4RVRF0OP,A107YFBJ119GZR,A10872FHIJAKKD,...,AZVZSGHKV0AO0,AZWC9XAY34IPW,AZWG3PF80735Q,AZWOQXRCS1WA6,AZWW1U604W0N,AZXEZRXZQL1H2,AZXGPM8EKSHE9,AZXQKAMHK35PA,AZY8LGHVF8GMZ,AZZVZL4QEHEHO
0743246500,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
0138421471,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
0385317999,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
B000C4SS5I,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
B000OTYZHG,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
B0006YV4SW,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
1885210086,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
0806504757,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
059513534X,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
0155067036,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0


rmse norm: 0.024893237436224402
Test


Unnamed: 0,A101446I5AWY0Z,A103U0Q3IKSXHE,A103W7ZPKGOCC9,A105E427BB6J65,A106016KSI0YQ,A106E1N0ZQ4D9W,A1075MZNVRMSEO,A107C4RVRF0OP,A107YFBJ119GZR,A10872FHIJAKKD,...,AZVZSGHKV0AO0,AZWC9XAY34IPW,AZWG3PF80735Q,AZWOQXRCS1WA6,AZWW1U604W0N,AZXEZRXZQL1H2,AZXGPM8EKSHE9,AZXQKAMHK35PA,AZY8LGHVF8GMZ,AZZVZL4QEHEHO
0743246500,-0.000198,-0.000791,0.007048,-7.5e-05,0.004687,-0.001711,-3.5e-05,-4.4e-05,-0.000959,0.000135,...,8e-06,-0.00019,0.008741,4.2e-05,-0.000459,-3.8e-05,6.1e-05,-0.000394,-0.001455,-2.4e-05
0138421471,-0.000233,-0.00104,0.006951,-0.000188,0.003915,-0.001931,-2.9e-05,-0.000336,-0.000967,-0.000537,...,-0.000481,-0.000209,0.00835,-0.000357,-0.000472,-7.5e-05,-6.8e-05,-0.000853,-0.00161,-8.3e-05
0385317999,-0.000233,-0.00104,0.006951,-0.000188,0.003915,-0.001931,-2.9e-05,-0.000336,-0.000967,-0.000537,...,-0.000481,-0.000209,0.00835,-0.000357,-0.000472,-7.5e-05,-6.8e-05,-0.000853,-0.00161,-8.3e-05
B000C4SS5I,-0.000213,-0.000915,0.007007,-0.000134,0.004357,-0.00196,-3.6e-05,-0.000174,-0.001057,-0.000192,...,-0.000235,-0.000226,0.008519,-0.000146,-0.000471,-6.8e-05,1e-06,-0.000639,-0.001531,-6.4e-05
B000OTYZHG,-0.000233,-0.00104,0.006951,-0.000188,0.003915,-0.001931,-2.9e-05,-0.000336,-0.000967,-0.000537,...,-0.000481,-0.000209,0.00835,-0.000357,-0.000472,-7.5e-05,-6.8e-05,-0.000853,-0.00161,-8.3e-05
B0006YV4SW,-0.000233,-0.00104,0.006951,-0.000188,0.003915,-0.001931,-2.9e-05,-0.000336,-0.000967,-0.000537,...,-0.000481,-0.000209,0.00835,-0.000357,-0.000472,-7.5e-05,-6.8e-05,-0.000853,-0.00161,-8.3e-05
1885210086,-0.000233,-0.00104,0.006951,-0.000188,0.003915,-0.001931,-2.9e-05,-0.000336,-0.000967,-0.000537,...,-0.000481,-0.000209,0.00835,-0.000357,-0.000472,-7.5e-05,-6.8e-05,-0.000853,-0.00161,-8.3e-05
0806504757,-0.000233,-0.00104,0.006951,-0.000188,0.003915,-0.001931,-2.9e-05,-0.000336,-0.000967,-0.000537,...,-0.000481,-0.000209,0.00835,-0.000357,-0.000472,-7.5e-05,-6.8e-05,-0.000853,-0.00161,-8.3e-05
059513534X,-0.000233,-0.00104,0.006951,-0.000188,0.003915,-0.001931,-2.9e-05,-0.000336,-0.000967,-0.000537,...,-0.000481,-0.000209,0.00835,-0.000357,-0.000472,-7.5e-05,-6.8e-05,-0.000853,-0.00161,-8.3e-05
0155067036,-0.000233,-0.00104,0.006951,-0.000188,0.003915,-0.001931,-2.9e-05,-0.000336,-0.000967,-0.000537,...,-0.000481,-0.000209,0.00835,-0.000357,-0.000472,-7.5e-05,-6.8e-05,-0.000853,-0.00161,-8.3e-05


Sampled


Unnamed: 0,A101446I5AWY0Z,A103U0Q3IKSXHE,A103W7ZPKGOCC9,A105E427BB6J65,A106016KSI0YQ,A106E1N0ZQ4D9W,A1075MZNVRMSEO,A107C4RVRF0OP,A107YFBJ119GZR,A10872FHIJAKKD,...,AZVZSGHKV0AO0,AZWC9XAY34IPW,AZWG3PF80735Q,AZWOQXRCS1WA6,AZWW1U604W0N,AZXEZRXZQL1H2,AZXGPM8EKSHE9,AZXQKAMHK35PA,AZY8LGHVF8GMZ,AZZVZL4QEHEHO
0743246500,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
0138421471,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
0385317999,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
B000C4SS5I,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
B000OTYZHG,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
B0006YV4SW,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
1885210086,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
0806504757,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
059513534X,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
0155067036,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0


### Sparse Autoencoder

Now, notice that the data is very sparse, with "write the number of n". This sparseness requires the need for an appropriate modification, namely the autoencoder

In [None]:
class SparseAutoEncoder(nn.Module):
	def __init__(self, input_dim, bottleneck_size, device='cpu'):
		super(SparseAutoEncoder, self).__init__()
		self.device = device
		self.encoder = nn.Sequential(
			nn.Linear(input_dim, 128),
			nn.ReLU(),
			nn.Linear(128, 64),
			nn.ReLU(),
			nn.Linear(64, bottleneck_size)
		)
		self.decoder = nn.Sequential(
			nn.Linear(bottleneck_size, 64),
			nn.ReLU(),
			nn.Linear(64, 128),
			nn.ReLU(),
			nn.Linear(128, input_dim)
		)
	
	def forward(self, x):
		x = self.encoder(x)
		y = self.decoder(x)
		return x, y
	
	def kl_divergence(self, p, q):
		# Compute the KL divergence
		p = torch.clamp(torch.tensor(p), 1e-10, 1- 1e-10)  # Avoid log(0)
		q = torch.clamp(q, 1e-10, 1-1e-10)  # Avoid log(0)
		kl_loss = p * torch.log(p / q) + (1 - p) * torch.log((1 - p) / (1 - q))
		return kl_loss
	
	def sparse_loss(self, x, sparsity_ratio=0.05, sparsity_weight=0.2):
		# Compute the sparsity loss
		x = torch.sigmoid(x)
		sparsity_loss = self.kl_divergence(sparsity_ratio, torch.mean(x, dim=0))
		sparsity_loss = torch.sum(sparsity_loss)
		return sparsity_weight * sparsity_loss

	def fit(self, batches, n_epochs=100, min_delta=0.0001, lr=0.001, patience=10):
		optimizer = torch.optim.Adam(self.parameters(), lr=lr)
		criterion = nn.MSELoss()
		best_loss = float('inf')
		patience_counter = 0

		for epoch in range(n_epochs):
			epoch_loss = 0.0
			for batch in batches:
				batch = batch[0].to(self.device)  # Move batch to device
				optimizer.zero_grad()
				encoded, decoded = self.forward(batch)
				loss = criterion(decoded, batch)
				sparsity_loss = self.sparse_loss(encoded)
				loss += sparsity_loss
				loss.backward()
				optimizer.step()
				epoch_loss += loss.item()

			epoch_loss /= len(batches)

			if epoch_loss < best_loss - min_delta:
				best_loss = epoch_loss
				patience_counter = 0
			else:
				patience_counter += 1

			if patience_counter >= patience:
				print(f"Early stopping at epoch {epoch+1} with loss {epoch_loss:.4f}")
				break

			print(f'Epoch [{epoch+1}/{n_epochs}], Loss: {epoch_loss:.4f}')
		return


Again, we test the results using the same hyperparameter and hypothetical data.

In [None]:
# Set random seed for reproducibility
np.random.seed(42)
torch.manual_seed(42)

# Create a user-item interaction matrix (100 users, 50 items) with sparsity
num_users = 10000
num_items = 500
interaction_matrix = np.random.randint(0, 6, size=(num_users, num_items))  # Random interactions from 0 to 5

# Introduce sparsity by setting a high percentage of interactions to 0
sparsity = 0.8  # 80% of the interactions will be set to 0
mask = np.random.rand(*interaction_matrix.shape) < sparsity
interaction_matrix[mask] = 0

# Convert to PyTorch tensor
interaction_tensor = torch.tensor(interaction_matrix, dtype=torch.float32)

# Add noise to the input data
def add_noise(data, noise_factor=0.3):
    noisy_data = data + noise_factor * torch.randn_like(data)
    noisy_data = torch.clamp(noisy_data, 0., 5.)  # Ensure values stay within the interaction range
    return noisy_data

noisy_interaction_tensor = add_noise(interaction_tensor)

# Print the original and noisy matrices (first 10 users for brevity)
print("Original Interaction Matrix (first 10 users):")
print(interaction_matrix[:10])  # Print only the first 10 users for brevity
print("\nNoisy Interaction Matrix (first 10 users):")
print(noisy_interaction_tensor[:10].numpy())


let's apply the autoencoder again to our main matrix.

In [None]:
tested = False

In [None]:
if tested:
	print('Tested')
elif os.path.exists('sparse_model_weights.pt') and os.path.exists('sparse_amazon_model.pt'):
	model = torch.load('amazon_model.pt')
	model.load_state_dict(torch.load('amazon_model_weights.pt'))

else:
	tensors = torch.tensor(user_ratings.to_numpy(), dtype=torch.float32, device=device)
	batch_size = 128 
	train = TensorDataset(tensors)
	batches = DataLoader(train, batch_size=batch_size, shuffle=True)
	model = SparseAutoEncoder(50, 10)
	model.fit(batches, 1000, 0.001, 0.0001, 5)
	torch.save(model, 'sparse_amazon_model.pt')
	torch.save(model.state_dict(), 'sparse_amazon_model_weights.pt')


In [None]:
test_bottleneck(user_ratings, device, 10)

### PCA and k-means clustering