In [1]:
from torch import Tensor
from sklearn.datasets import make_swiss_roll
from sklearn.preprocessing import MinMaxScaler
from torch.utils.data import Dataset 


class SwissRoll(Dataset):
    """Swiss-roll point cloud exposed as a PyTorch ``Dataset``.

    Points are sampled with ``sklearn.datasets.make_swiss_roll`` and every
    coordinate is min-max rescaled into ``feature_range`` (``(-1, 1)`` by
    default). ``targets`` holds the generator's univariate position of each
    sample along the roll, which is useful for colouring plots.

    Args:
        size: Number of points to sample.
        noise: Standard deviation of the Gaussian noise added to the points.
        hole: If True, generate the swiss roll with a hole in it.
        random_state: Seed forwarded to ``make_swiss_roll``. The default ``0``
            keeps the data reproducible across runs; ``None`` draws a fresh sample.
        feature_range: Target ``(min, max)`` range of the min-max scaling.

    Attributes:
        inputs: Scaled point coordinates as a tensor.
        targets: Position along the roll for each point, as a tensor.
        scaler: The fitted ``MinMaxScaler``; use ``scaler.inverse_transform``
            to map points (or reconstructions) back to the original coordinates.
    """

    def __init__(
        self,
        size: int,
        noise: float,
        hole: bool = False,
        random_state: "int | None" = 0,
        feature_range: tuple[float, float] = (-1, 1),
    ):
        points, positions = make_swiss_roll(
            n_samples=size, noise=noise, random_state=random_state, hole=hole
        )

        # Keep the fitted scaler instead of discarding it, so scaled data can
        # be mapped back to the original coordinate system later.
        self.scaler = MinMaxScaler(feature_range=feature_range)
        # ``Tensor(...)`` converts the float64 NumPy output to torch's default
        # floating dtype (float32 unless changed), matching ``torch.nn`` layers.
        self.inputs = Tensor(self.scaler.fit_transform(points))
        self.targets = Tensor(positions)

    def __len__(self) -> int:
        return len(self.inputs)

    def __getitem__(self, index: int) -> tuple[Tensor, Tensor]:
        return self.inputs[index], self.targets[index]

In [None]:
from torch import Tensor
from torch.nn import Module
from torch.nn import Sequential 

class Autoencoder(Module):
    """Chains an encoder and a decoder into one reconstruction model.

    Args:
        encoder: Network that maps input features to a latent representation.
        decoder: Network that maps latent representations back to feature space.
    """

    def __init__(self, encoder: Sequential, decoder: Sequential):
        super().__init__()
        # Assigning modules as attributes registers them as submodules, so
        # their parameters are reported by ``self.parameters()``.
        self.encoder = encoder
        self.decoder = decoder

    def forward(self, features: Tensor) -> Tensor:
        """Encode ``features`` and return the decoder's reconstruction of them."""
        latent = self.encoder(features)
        reconstruction = self.decoder(latent)
        return reconstruction

In [None]:
import plotly.graph_objects as go
from plotly.graph_objects import Figure, Scatter3d
from plotly.subplots import make_subplots

def _to_plot_array(values):
    """Return ``values`` in a form plotly can serialise.

    Torch tensors are detached from the autograd graph and copied to host
    memory as NumPy arrays. Plotly turns array-likes into NumPy arrays, and
    ``Tensor.numpy()`` refuses tensors that require grad (e.g. raw encoder
    outputs) or that live on a GPU. Any other value is returned unchanged.
    """
    from torch import Tensor  # local import keeps this plotting cell self-contained

    if isinstance(values, Tensor):
        return values.detach().cpu().numpy()
    return values


def plot_swiss_roll_3d(X, color, title="Swiss Roll"):
    """Show an interactive 3-D scatter plot of a point cloud.

    Args:
        X: Array-like or tensor indexable as ``X[:, j]``; columns 0, 1 and 2
            are used as the x, y and z coordinates.
        color: Per-point values mapped onto the Viridis colour scale.
        title: Figure title.
    """
    points = _to_plot_array(X)
    fig = go.Figure(data=[go.Scatter3d(
        x=points[:, 0],
        y=points[:, 1],
        z=points[:, 2],
        mode='markers',
        marker=dict(
            size=4,
            color=_to_plot_array(color),
            colorscale='Viridis',
            opacity=0.8
        )
    )])

    fig.update_layout(
        title=title,
        scene=dict(
            xaxis_title='X',
            yaxis_title='Y',
            zaxis_title='Z'
        ),
        width=800,
        height=600
    )
    fig.show()


def plot_latent_space(latent, color, title="Espacio Latente"):
    """Show a 2-D scatter plot of latent codes.

    Args:
        latent: Array-like or tensor indexable as ``latent[:, j]``; columns 0
            and 1 are plotted. Tensors that require grad (e.g. a direct
            encoder output) are accepted and detached before plotting.
        color: Per-point values mapped onto the Viridis colour scale.
        title: Figure title.
    """
    codes = _to_plot_array(latent)
    fig = go.Figure(data=[go.Scatter(
        x=codes[:, 0],
        y=codes[:, 1],
        mode='markers',
        marker=dict(
            color=_to_plot_array(color),
            colorscale='Viridis',
            size=8,
            opacity=0.7
        )
    )])

    fig.update_layout(
        title=title,
        xaxis_title='Componente Latente 1',
        yaxis_title='Componente Latente 2',
        width=700,
        height=500
    )
    fig.show()