# <center>Modelo generativo basado en score</center>

In [1]:
from pathlib import Path

import plotly.graph_objects as go
import torch
import tqdm
from functorch import jacrev
from sklearn.datasets import make_swiss_roll
from torch.utils.data import DataLoader, TensorDataset

# Fix torch's global RNG so shuffling, noise and weight init repeat across runs.
torch.manual_seed(42)

# Backend preference: CUDA first, then Apple MPS, falling back to CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
elif torch.backends.mps.is_available():
    device = torch.device('mps')
else:
    device = torch.device('cpu')

print(f'Device: {device}.')

Device: cuda.


## Datos de entrenamiento

In [2]:
def create_dataset(n=1000, random_state=42):
    """Build a standardised 2-D swiss-roll point cloud.

    Args:
        n: Number of points to draw.
        random_state: Seed forwarded to ``make_swiss_roll``. The generator
            draws from NumPy's RNG, which ``torch.manual_seed`` does not
            control, so a fixed seed is needed for reproducible data. Pass
            ``None`` to get a fresh, unseeded draw on every call.

    Returns:
        A ``torch.float32`` tensor with ``n`` rows and 2 columns.
    """
    points, _ = make_swiss_roll(n_samples=n, noise=1.0, random_state=random_state)
    # Keep only the first and third coordinates of the generated points.
    points = points[:, [0, 2]]
    # A single global mean / std (no per-axis statistics) is used on purpose
    # here: both coordinates are shifted and scaled by the same amounts.
    points = (points - points.mean()) / points.std()
    return torch.tensor(points, dtype=torch.float32)

# Training points, wrapped so they are served as shuffled mini-batches of 32.
data_points = create_dataset()
dataset = TensorDataset(data_points)
dataloader = DataLoader(dataset=dataset, shuffle=True, batch_size=32)

## Clase `SGM`

In [3]:
class SGM:
    """Score-based generative model helpers.

    Groups two static entry points that operate on a caller-supplied score
    network ``net``: ``train_model`` (score matching / denoising score
    matching) and ``generate_samples`` (Langevin sampling).
    """

    @staticmethod
    def _param_device_dtype(net):
        """Return ``(device, dtype)`` of the first parameter of ``net``.

        Returns ``(None, None)`` when ``net`` exposes no ``parameters()``
        method or has no parameters, so callers fall back to torch defaults.
        """
        parameters = getattr(net, 'parameters', None)
        if not callable(parameters):
            return None, None
        first = next(iter(parameters()), None)
        if first is None:
            return None, None
        return first.device, first.dtype

    @staticmethod
    def train_model(net, optimizer, dataloader, epochs, gaussian_kernel=False, sigma=0.1):
        """Train ``net`` as a score estimator.

        Args:
            net: Callable mapping a batch of points to a batch of scores.
            optimizer: Optimizer exposing ``zero_grad()`` and ``step()``.
            dataloader: Iterable yielding 1-tuples ``(x,)`` of batched points.
            epochs: Number of passes over ``dataloader``.
            gaussian_kernel: If True, use denoising score matching: the input
                is perturbed with Gaussian noise of standard deviation
                ``sigma`` and ``net(x_bar)`` is regressed onto
                ``(x - x_bar) / sigma**2``. Otherwise use the score matching
                objective ``1/2 * ||s(x)||^2 + tr(J_s(x))``.
            sigma: Noise standard deviation for the denoising objective.

        Returns:
            None. The mean batch loss of each epoch is shown on the progress
            bar. A ``KeyboardInterrupt`` stops training and prints a message
            instead of propagating.
        """
        # Batches are moved to wherever the network lives, so a network placed
        # on an accelerator can be fed from a CPU-backed dataloader.
        device, _ = SGM._param_device_dtype(net)

        try:
            progressbar = tqdm.trange(epochs)
            for _ in progressbar:
                epoch_loss = 0.0
                n_batches = 0

                for x, in dataloader:
                    if device is not None:
                        x = x.to(device)

                    # ----- Loss computation -----

                    # Denoising score matching.
                    if gaussian_kernel:
                        x_bar = x + sigma * torch.randn_like(x)
                        score = net(x_bar)
                        inner = (x - x_bar) / sigma**2 - score
                        # Squared L2 norm taken directly (no sqrt-then-square).
                        loss = 1/2 * inner.pow(2).sum(dim=-1).mean()

                    # Score matching.
                    else:
                        score = net(x)
                        score_norm = 1/2 * score.pow(2).sum(dim=-1)
                        # Per-sample Jacobian of the network, then its trace.
                        jacobian = torch.vmap(jacrev(net))(x)
                        jacobian_trace = torch.vmap(torch.trace)(jacobian)
                        loss = (score_norm + jacobian_trace).mean()

                    # Backprop:
                    optimizer.zero_grad()
                    loss.backward()
                    optimizer.step()
                    epoch_loss += loss.item()
                    n_batches += 1

                # Count batches instead of calling len(): this avoids a
                # ZeroDivisionError on an empty loader and also works for
                # iterables that do not define __len__.
                if n_batches:
                    epoch_loss /= n_batches
                progressbar.set_postfix(epoch_loss=epoch_loss)

        except KeyboardInterrupt:
            print('Entrenamiento interrumpido.')

    @staticmethod
    def generate_samples(net, nsamples, eps: float = 0.001, nsteps: int = 1000):
        """Draw samples with Langevin dynamics driven by the score network.

        Starts from points uniform in ``[-1, 1)`` on each of 2 coordinates and
        applies ``nsteps`` updates
        ``x <- x + eps * net(x) + sqrt(2 * eps) * z`` with ``z ~ N(0, I)``.

        Args:
            net: Score network taking and returning tensors with 2 columns.
            nsamples: Number of samples to generate.
            eps: Step size of each Langevin update.
            nsteps: Number of Langevin updates.

        Returns:
            Tensor with ``nsamples`` rows and 2 columns, moved to the CPU.
        """
        # Create the chain on the network's own device / dtype; with a
        # parameter-less callable both are None and torch defaults apply.
        device, dtype = SGM._param_device_dtype(net)
        noise_scale = (2 * eps) ** 0.5

        # Langevin sampling:
        with torch.no_grad():
            x = torch.rand((nsamples, 2), device=device, dtype=dtype) * 2 - 1

            for _ in range(nsteps):
                z = torch.randn_like(x)
                x = x + eps * net(x) + noise_scale * z
            return x.cpu()

## Entrenamiento y generación de muestras

### Red neuronal

In [4]:
# Score network: 2 inputs -> three hidden layers of width 64, each followed by
# a LogSigmoid activation -> 2 outputs. Layers are instantiated in
# input-to-output order.
layer_sizes = [2, 64, 64, 64]
layers = []
for fan_in, fan_out in zip(layer_sizes[:-1], layer_sizes[1:]):
    layers.append(torch.nn.Linear(fan_in, fan_out))
    layers.append(torch.nn.LogSigmoid())
layers.append(torch.nn.Linear(layer_sizes[-1], 2))

basicMLP = torch.nn.Sequential(*layers)

### Entrenamiento

In [5]:
# Adam over every network parameter; gaussian_kernel=True selects the
# denoising score matching branch of SGM.train_model.
optimizer = torch.optim.Adam(params=basicMLP.parameters(), lr=3e-4)
SGM.train_model(
    net=basicMLP,
    optimizer=optimizer,
    dataloader=dataloader,
    epochs=1000,
    gaussian_kernel=True,
)

100%|██████████| 1000/1000 [01:12<00:00, 13.87it/s, epoch_loss=97.4]


### Generación de muestras

In [6]:
def plot_new_samples(new_points, output_path='images/dm/sgm_samples.pdf'):
    """Overlay generated samples on the training data and save the figure.

    Args:
        new_points: Indexable as ``new_points[:, 0]`` / ``new_points[:, 1]``;
            the two columns are used as x and y coordinates.
        output_path: Where to write the figure. Missing parent directories
            are created first so the export does not fail on a fresh
            checkout. Pass ``None`` to only display the figure.

    The training points are read from the module-level ``data_points``.
    """
    fig = go.Figure()

    # Training data: translucent blue markers in the background.
    fig.add_trace(go.Scatter(x=data_points[:, 0], y=data_points[:, 1],
                             mode='markers', name='Datos de entrenamiento',
                             marker=dict(color='blue', opacity=0.3)))

    # Generated samples: red markers on top.
    fig.add_trace(go.Scatter(x=new_points[:, 0], y=new_points[:, 1],
                             mode='markers', name='Muestras generadas',
                             marker=dict(color='red', opacity=0.7)))

    fig.update_layout(
        width=1000,
        height=500,
        plot_bgcolor='white',
        xaxis=dict(visible=False),
        yaxis=dict(visible=False))

    fig.show()

    if output_path is not None:
        target = Path(output_path)
        # Create the output directory up front so the export cannot fail
        # just because the folder does not exist yet.
        target.parent.mkdir(parents=True, exist_ok=True)
        fig.write_image(str(target))

In [7]:
# Draw 1000 points from the trained score network and compare them with the data.
samples = SGM.generate_samples(net=basicMLP, nsamples=1000)
plot_new_samples(new_points=samples)