In [1]:
import math

import torch
from torch import nn


In [2]:
class Encoder(nn.Module):
	"""Two-layer tanh MLP that maps an input to the two heads of a diagonal Gaussian.

	The trunk is ``Linear(inp_size, emb_size) -> Tanh -> Dropout ->
	Linear(emb_size, emb_size // 2) -> Tanh -> Dropout``; two separate linear
	heads (``mu`` and ``logvar``) then project the trunk output to ``lat_size``.

	Args:
		inp_size: size of the last dimension of the input.
		emb_size: width of the first hidden layer; the second hidden layer
			has ``emb_size // 2`` units.
		lat_size: size of each output head.
		dropout: dropout probability used after both hidden layers. Defaults
			to 0.5, the value that used to be hard-coded.
	"""

	def __init__(
		self,
		inp_size: int,
		emb_size: int,
		lat_size: int,
		dropout: float = 0.5,
	):
		super().__init__()

		hid_size = emb_size // 2

		# The position of every layer inside the Sequential is kept stable so
		# that parameter names (encode.0.*, encode.3.*) do not change.
		self.encode = nn.Sequential(
			nn.Linear(inp_size, emb_size), nn.Tanh(), nn.Dropout(dropout),
			nn.Linear(emb_size, hid_size), nn.Tanh(), nn.Dropout(dropout),
		)

		self.mu = nn.Linear(hid_size, lat_size)
		self.logvar = nn.Linear(hid_size, lat_size)

	def forward(self, tensor):
		"""Return the tuple ``(mu, logvar)`` computed from ``tensor``."""
		hidden = self.encode(tensor)
		return self.mu(hidden), self.logvar(hidden)


class Decoder(nn.Module):
	"""Three-layer MLP that maps a latent vector to an output in (0, 1).

	The network is ``Linear(lat_size, emb_size // 2) -> Tanh -> Dropout ->
	Linear(emb_size // 2, emb_size) -> Tanh -> Dropout ->
	Linear(emb_size, out_size) -> Sigmoid``.

	Args:
		lat_size: size of the last dimension of the latent input.
		emb_size: width of the second hidden layer; the first hidden layer
			has ``emb_size // 2`` units.
		out_size: size of the last dimension of the output.
		dropout: dropout probability used after both hidden layers. Defaults
			to 0.5, the value that used to be hard-coded.
	"""

	def __init__(
		self,
		lat_size: int,
		emb_size: int,
		out_size: int,
		dropout: float = 0.5,
	):
		super().__init__()

		hid_size = emb_size // 2

		# Layer positions are kept stable so parameter names
		# (decode.0.*, decode.3.*, decode.6.*) do not change.
		self.decode = nn.Sequential(
			nn.Linear(lat_size, hid_size), nn.Tanh(), nn.Dropout(dropout),
			nn.Linear(hid_size, emb_size), nn.Tanh(), nn.Dropout(dropout),
			nn.Linear(emb_size, out_size),
			nn.Sigmoid(),
		)

	def forward(self, tensor):
		"""Decode ``tensor``; every output element lies in (0, 1) because of the final Sigmoid."""
		return self.decode(tensor)

In [11]:
def LDE(log_a, log_b):
    """Element-wise, numerically stable ``log(exp(log_a) + exp(log_b))``.

    The larger operand is factored out so the exponential never overflows,
    and ``log1p`` keeps the contribution of a much smaller operand that
    ``log(1 + tiny)`` would round away.

    NOTE(review): the computation is a log-*sum*-exp although the name could
    be read as "log-diff-exp"; the sum is kept because that is what the
    original expression computed - confirm which one the caller needs.

    Args:
        log_a: tensor of log-values.
        log_b: tensor of log-values, broadcastable against ``log_a``.

    Returns:
        Tensor with the broadcast shape of the inputs.
    """
    max_log = torch.max(log_a, log_b)
    min_log = torch.min(log_a, log_b)

    # When both operands are the same infinity, ``min - max`` is
    # ``inf - inf = nan``; the true gap there is zero.
    gap = torch.where(
        max_log == min_log,
        torch.zeros_like(max_log),
        min_log - max_log,
    )
    return max_log + torch.log1p(torch.exp(gap))


def log_gaussian_likelihood(x_recon, x):
    """Log-density of ``x`` under a unit-variance Gaussian centred at ``x_recon``.

    Computes ``sum_d [-0.5 * (x_d - x_recon_d)^2 - 0.5 * log(2 * pi)]`` over
    the last dimension. This is the closed form of
    ``Normal(x_recon, 1).log_prob(x).sum(-1)``; writing it out avoids building
    a distribution object (and its argument validation) on every call.

    The previous body called ``torch.ones(x_recon)``, which raises because
    ``torch.ones`` expects a size rather than a tensor, and it returned a
    noisy copy of ``x_recon`` that never looked at ``x``.

    Args:
        x_recon: tensor of Gaussian means.
        x: tensor of observations, broadcastable against ``x_recon``.

    Returns:
        Tensor with the last dimension reduced away.
    """
    squared_error = (x - x_recon).pow(2)
    return -0.5 * (squared_error + math.log(2.0 * math.pi)).sum(-1)



In [8]:
class VAE(nn.Module):
    """Variational auto-encoder whose latent sample is nudged before decoding.

    ``forward`` encodes the input, draws a latent sample with the
    reparameterisation trick, moves that sample along the gradient of an
    auxiliary scalar network for a fixed number of steps (a step is only
    accepted for rows whose uncertainty score stays at or below a threshold),
    and finally decodes the result.

    Args:
        inp_size: size of the last dimension of the input / reconstruction.
        emb_size: hidden width passed to the encoder and decoder.
        lat_size: size of the latent space.
        uncertainty_threshold_value: a gradient step is accepted for a row
            only if its score from ``importance_sampling_mi`` is ``<=`` this.
        n_gradient_steps: number of gradient-ascent steps per forward pass.
        gradient_scale: step size applied to the normalised gradient.
        n_simulations: outer sample count used by ``importance_sampling_mi``.
        n_sampled_outcomes: inner sample count used by
            ``importance_sampling_mi``.

    The five search hyper-parameters default to the values that used to be
    hard-coded and are stored as attributes of the same names.
    """

    def __init__(
        self,
        inp_size,
        emb_size,
        lat_size,
        uncertainty_threshold_value=2000.0,
        n_gradient_steps=10,
        gradient_scale=1e-3,
        n_simulations=100,
        n_sampled_outcomes=100,
    ):
        super().__init__()

        self.encode = Encoder(inp_size, emb_size, lat_size).to(torch.float)
        self.decode = Decoder(lat_size, emb_size, inp_size).to(torch.float)

        # The auxiliary scalar network is built once and registered as a
        # sub-module so that it keeps the same weights between calls and
        # follows the model across devices. It used to be re-created with
        # fresh random weights (on the CPU) on every call.
        # max(1, ...) avoids a zero-width hidden layer when lat_size == 1.
        aux_hidden = max(1, lat_size // 2)
        self.aux_net = nn.Sequential(
            nn.Linear(lat_size, aux_hidden), nn.Tanh(),
            nn.Linear(aux_hidden, 1),
        ).to(torch.float)

        self.lat_size = lat_size

        self.uncertainty_threshold_value = uncertainty_threshold_value
        self.n_gradient_steps = n_gradient_steps
        self.gradient_scale = gradient_scale
        self.n_simulations = n_simulations
        self.n_sampled_outcomes = n_sampled_outcomes

    def reparameterization(self, mu, logvar):
        """Sample ``mu + exp(0.5 * logvar) * eps`` with ``eps ~ N(0, I)``."""
        eps = torch.randn_like(logvar)
        std = logvar.mul(0.5).exp()
        return mu + std * eps

    def loss(self, x, x_recon, mu, logvar):
        """Return ``(total, reconstruction, kld)``, each averaged over the batch.

        The reconstruction term is the squared error summed over all
        non-batch elements; the KL term is
        ``-0.5 * sum(1 + logvar - mu^2 - exp(logvar))`` over the last axis.
        """
        bs = x.size(0)
        rec_loss = nn.functional.mse_loss(
            x_recon.reshape(bs, -1),
            x.reshape(bs, -1),
            reduction="none",
        ).sum(dim=-1)

        kld_loss = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp(), dim=-1)

        return (
            (rec_loss + kld_loss).mean(dim=0),
            rec_loss.mean(dim=0),
            kld_loss.mean(dim=0),
        )

    def enable_dropout(self):
        """Put every sub-module whose class name starts with "Dropout" in train mode."""
        for m in self.modules():
            if m.__class__.__name__.startswith("Dropout"):
                m.train()

    @torch.no_grad()
    def generate(self, n_samples, device):
        """Decode ``n_samples`` standard-normal latent vectors created on ``device``."""
        z = torch.randn((n_samples, self.lat_size), device=device)
        return self.decode(z)

    @torch.no_grad()
    def reconstruct(self, tensor):
        """Run a full forward pass without tracking gradients and return only the reconstruction."""
        return self(tensor)[0]

    def auxiliary_net(self, tensor):
        """Apply the persistent auxiliary network; returns one scalar per latent row.

        Note: nothing in ``loss`` depends on this network's parameters.
        """
        return self.aux_net(tensor)

    def forward(self, tensor):
        """Return ``(x_recon, mu, logvar)`` for the input ``tensor``."""
        mu, logvar = self.encode(tensor)
        z = self.reparameterization(mu, logvar)

        z = self.gradient_ascent_optimisation(
            sample_source=tensor,
            sample_latent_space=z,
            auxiliary_net=self.auxiliary_net,
            uncertainty_threshold_value=self.uncertainty_threshold_value,
            n_gradient_steps=self.n_gradient_steps,
            gradient_scale=self.gradient_scale,
            n_simulations=self.n_simulations,
            n_sampled_outcomes=self.n_sampled_outcomes,
        )

        x_recon = self.decode(z)
        return x_recon, mu, logvar

    def gradient_ascent_optimisation(
        self,
        sample_source,
        sample_latent_space,
        auxiliary_net,
        uncertainty_threshold_value,
        n_gradient_steps,
        gradient_scale,
        n_simulations,
        n_sampled_outcomes,
    ):
        """Move latent rows along the gradient of ``auxiliary_net``.

        For each of ``n_gradient_steps`` iterations the gradient of the
        auxiliary output with respect to the current latent is normalised by
        its L2 norm (taken over the whole tensor) and scaled by
        ``gradient_scale``. The step is kept only for rows whose score from
        ``importance_sampling_mi`` is ``<= uncertainty_threshold_value``;
        other rows keep their previous value.

        Args:
            sample_source: the original inputs, forwarded to the scorer.
            sample_latent_space: latent tensor whose first axis is the batch.
            auxiliary_net: callable mapping a latent tensor to a tensor whose
                gradient with respect to the latent defines the direction.
            uncertainty_threshold_value: acceptance threshold.
            n_gradient_steps: number of iterations.
            gradient_scale: step size.
            n_simulations: forwarded to ``importance_sampling_mi``.
            n_sampled_outcomes: forwarded to ``importance_sampling_mi``.

        Returns:
            The updated latent tensor, same shape as ``sample_latent_space``.
        """
        for _ in range(n_gradient_steps):
            # The direction is computed on a detached copy with grad mode
            # forced on: this works when the caller runs under
            # ``torch.no_grad()`` and does not flip ``requires_grad`` on the
            # caller's tensor.
            with torch.enable_grad():
                probe = sample_latent_space.detach().requires_grad_(True)
                score = auxiliary_net(probe)
                gradient = torch.autograd.grad(
                    outputs=score,
                    inputs=probe,
                    grad_outputs=torch.ones_like(score),
                )[0]

            # clamp_min keeps an all-zero gradient at zero instead of NaN.
            gradient = gradient / gradient.norm(2).clamp_min(1e-12)

            candidate = sample_latent_space + gradient * gradient_scale

            mi = self.importance_sampling_mi(
                sample_source=sample_source,
                sample_latent_space=candidate,
                n_simulations=n_simulations,
                n_sampled_outcomes=n_sampled_outcomes,
            )

            # A trailing singleton axis lets the per-row decision broadcast
            # over a latent axis of any size (it used to be tiled to width 2).
            accept = (mi <= uncertainty_threshold_value).unsqueeze(-1)
            sample_latent_space = torch.where(accept, candidate, sample_latent_space)

        return sample_latent_space

    @torch.no_grad()
    def importance_sampling_mi(
        self,
        sample_source: torch.Tensor,
        sample_latent_space: torch.Tensor,
        n_simulations: int,
        n_sampled_outcomes: int,
    ) -> torch.Tensor:
        """Score each latent row using decoder passes with dropout active.

        For every one of ``n_simulations`` decoder passes, the log-likelihood
        of ``sample_source`` given the reconstruction is collected
        ``n_sampled_outcomes`` times, combined in log space into one value
        per row, and the per-simulation values are then averaged (again in
        log space) and exponentiated.

        The train/eval flag of every sub-module is restored on exit, so
        calling this method no longer changes the mode of the model.

        NOTE(review): the reconstruction is computed once per simulation and
        reused by every inner iteration, so the inner values differ only if
        ``log_gaussian_likelihood`` is itself stochastic - confirm that this
        is the intended estimator.

        Returns:
            One score per row of ``sample_latent_space`` (assuming
            ``log_gaussian_likelihood`` returns one value per row - confirm).
        """
        log_n_outcomes = torch.log(torch.tensor(float(n_sampled_outcomes)))
        log_n_simulations = torch.log(torch.tensor(float(n_simulations)))

        previous_modes = [(module, module.training) for module in self.modules()]
        try:
            # Everything in eval mode except the dropout layers.
            self.eval()
            self.enable_dropout()

            log_mi = []
            for _ in range(n_simulations):
                x_recon = self.decode(sample_latent_space)

                all_log_psm = torch.stack(
                    [
                        log_gaussian_likelihood(x_recon, sample_source)
                        for _ in range(n_sampled_outcomes)
                    ],
                    dim=1,
                )

                # log of the mean likelihood over the inner samples
                log_ps = -log_n_outcomes + torch.logsumexp(all_log_psm, dim=1)

                # log(-p * log p) for the averaged and the individual values
                right_log_hs = log_ps + torch.log(-log_ps)
                psm_log_psm = all_log_psm + torch.log(-all_log_psm)
                left_log_hs = -log_n_outcomes + torch.logsumexp(psm_log_psm, dim=1)

                log_mi.append(LDE(left_log_hs, right_log_hs) - log_ps)

            log_mi = torch.stack(log_mi, dim=1)
            log_mi_avg = -log_n_simulations + torch.logsumexp(log_mi, dim=1)
            return log_mi_avg.exp()
        finally:
            for module, was_training in previous_modes:
                module.training = was_training

In [10]:
# Smoke test: push random batches through the model and print the loss terms.
# No optimiser is created and no backward pass is run here, so the printed
# values only reflect the randomly initialised weights and the random inputs.
SEED = 0
N_ITERATIONS = 20
BATCH_SIZE = 10
INP_SIZE = 20

# Seed before the model is built so that both the initial weights and the
# sampled batches are reproducible after a kernel restart.
torch.manual_seed(SEED)

device = torch.device("cpu")
net = VAE(
    inp_size=INP_SIZE,
    emb_size=20,
    lat_size=2,
).to(device)

for i in range(N_ITERATIONS):
    # Create the batch on the same device as the model.
    x = torch.randn((BATCH_SIZE, INP_SIZE), device=device)
    x_recon, mu, logvar = net(x)
    loss, rec, kld = net.loss(x, x_recon, mu, logvar)

    print(i, loss.item(), rec.item(), kld.item())

0 24.693035125732422 24.500940322875977 0.19209758937358856
1 22.13822364807129 21.926074981689453 0.21214886009693146
2 25.800228118896484 25.685291290283203 0.11493788659572601
3 24.170717239379883 23.970470428466797 0.20024721324443817
4 25.3535099029541 25.100893020629883 0.25261780619621277
5 29.40237045288086 29.183673858642578 0.21870103478431702
6 24.657785415649414 24.446958541870117 0.21082532405853271
7 24.559932708740234 24.326129913330078 0.23380358517169952
8 24.407089233398438 24.1723690032959 0.23471815884113312
9 27.98561668395996 27.796710968017578 0.1889064460992813
10 23.222103118896484 22.931554794311523 0.2905489206314087
11 26.150650024414062 25.95595359802246 0.1946951448917389
12 24.565067291259766 24.367507934570312 0.19755737483501434
13 29.494882583618164 29.263010025024414 0.23187215626239777
14 27.367679595947266 27.22023582458496 0.14744099974632263
15 26.735210418701172 26.596588134765625 0.13862422108650208
16 26.092947006225586 25.820053100585938 0.272