Shresth Grover

In [3]:
import torch
import torch.nn as nn
import torch.utils.data as data
import random 
import numpy as np  
from torch import Tensor
from typing import Callable


In [4]:
def U_1(z):
  """Energy (negative unnormalised log-density) of the two-mode ring target.

  U(z) = 0.5 * ((||z||_2 - 2) / 0.4)^2
         - log( exp(-0.5 * ((z_0 - 2) / 0.6)^2) + exp(-0.5 * ((z_0 + 2) / 0.6)^2) )

  Args:
    z: 2-D tensor (batch, d); the norm is taken over dim=1 and the mode
       term uses column 0.

  Returns:
    Tensor of shape (batch,) holding U(z) per row.
  """
  ring = 0.5 * ((torch.norm(z, 2, dim=1) - 2) / 0.4) ** 2
  left = -0.5 * ((z[:, 0] - 2) / 0.6) ** 2
  right = -0.5 * ((z[:, 0] + 2) / 0.6) ** 2
  # log(exp(left) + exp(right)) via logsumexp: the naive form underflows to
  # log(0) = -inf in float32 once |z_0| is beyond about 10, giving U = +inf
  # and NaN gradients for samples that wander far from the ring.
  modes = torch.logsumexp(torch.stack((left, right), dim=0), dim=0)
  return ring - modes



In [5]:
from torch.distributions import MultivariateNormal
class VariationalLoss(nn.Module):
    """Monte-Carlo reverse-KL objective for a normalizing flow.

    Given base samples z0 ~ q0 = N(0, I) and their flow images z, returns

        mean( log q0(z0) - log p~(z) - sum_log_det_J )

    where the target's unnormalised log-density is log p~(z) = -distribution(z).
    ``distribution`` must therefore return an *energy* U(z) (e.g. ``U_1``),
    not a log-probability. Because p~ is unnormalised, the optimum of this loss
    is -log Z rather than 0, so negative values are expected.

    Args:
        distribution: callable mapping the flow output z to per-sample energies.
        dim: dimensionality of the standard-normal base distribution
            (defaults to 2, the previous hard-coded value).
    """

    def __init__(self, distribution: Callable[[Tensor], Tensor], dim: int = 2):
        super().__init__()
        self.distr = distribution
        self.base_distr = MultivariateNormal(torch.zeros(dim), torch.eye(dim))

    def forward(self, z0: Tensor, z: Tensor, sum_log_det_J: Tensor) -> Tensor:
        """Return the scalar loss for one batch.

        Args:
            z0: base samples, shape (batch, dim).
            z: flow outputs for ``z0``, passed to ``distribution``.
            sum_log_det_J: accumulated log|det J| of the flow; any tensor
                (or scalar) broadcastable against (batch,) — PlanarFlow yields
                shape (1, batch).

        Returns:
            0-dim tensor: mean over all broadcast elements.
        """
        base_log_prob = self.base_distr.log_prob(z0)
        target_density_log_prob = -self.distr(z)
        return (base_log_prob - target_density_log_prob - sum_log_det_J).mean()
     

        
        

In [6]:
class PlanarTransform(nn.Module):
    """One planar flow layer: f(z) = z + u * tanh(w^T z + b).

    Parameters u and w have shape (1, dim) and b has shape (1,); all are
    initialised from N(0, 0.1^2), so a fresh layer is close to the identity.
    Whenever w^T u < -1, u is projected (see ``get_u_hat``) before use.
    """

    def __init__(self, dim: int = 2):
        super().__init__()
        # Single N(0, 0.1^2) draw per parameter (randn followed by normal_
        # would sample twice and discard the first draw).
        self.w = nn.Parameter(torch.empty(1, dim).normal_(0, 0.1))
        self.b = nn.Parameter(torch.empty(1).normal_(0, 0.1))
        self.u = nn.Parameter(torch.empty(1, dim).normal_(0, 0.1))

    def _enforce_invertibility(self) -> None:
        # Project u only when the constraint w^T u >= -1 is violated.
        if torch.mm(self.u, self.w.T) < -1:
            self.get_u_hat()

    def forward(self, z: Tensor) -> Tensor:
        """Apply the layer to z of shape (batch, dim); returns (batch, dim)."""
        self._enforce_invertibility()
        return z + self.u * torch.tanh(torch.mm(z, self.w.T) + self.b)

    def log_det_J(self, z: Tensor) -> Tensor:
        """log|det df/dz| evaluated at the layer input z of shape (batch, dim).

        Uses det = 1 + u^T psi(z) with psi(z) = (1 - tanh^2(w^T z + b)) w.
        A 1e-4 offset inside the log guards against log(0).

        Returns:
            Tensor of shape (1, batch).
        """
        self._enforce_invertibility()
        a = torch.mm(z, self.w.T) + self.b              # (batch, 1)
        psi = (1 - torch.tanh(a) ** 2) * self.w         # (batch, dim)
        abs_det = (1 + torch.mm(self.u, psi.T)).abs()   # (1, batch)
        return torch.log(1e-4 + abs_det)

    def get_u_hat(self) -> None:
        """Overwrite u with u_hat so that w^T u_hat = -1 + softplus(w^T u) > -1.

        u_hat = u + (m(w^T u) - w^T u) * w / ||w||^2, with m(x) = -1 + softplus(x).
        """
        with torch.no_grad():
            wtu = torch.mm(self.u, self.w.T)
            m_wtu = -1 + nn.functional.softplus(wtu)
            u_hat = self.u + (m_wtu - wtu) * self.w / torch.norm(self.w, p=2, dim=1) ** 2
        # Assign through .data so the projection is invisible to autograd
        # (a projected-gradient style update), as in the original code.
        self.u.data = u_hat

In [7]:
from typing import Tuple
class PlanarFlow(nn.Module):
    """A chain of ``K`` PlanarTransform layers of dimension ``dim``.

    The layers are registered as submodules via ``self.model`` (an
    ``nn.Sequential``); ``self.layers`` is a plain list of the same objects.
    """

    def __init__(self, dim: int = 2, K: int = 6):
        super().__init__()
        self.layers = []
        for _ in range(K):
            self.layers.append(PlanarTransform(dim))
        # Wrapping the list in nn.Sequential is what registers the parameters.
        self.model = nn.Sequential(*self.layers)

    def forward(self, z: Tensor) -> Tuple[Tensor, float]:
        """Transform z through every layer in order.

        Each layer's log|det J| is evaluated on that layer's input, before the
        layer is applied, and the terms are summed.

        Returns:
            (final sample, accumulated log-det). The log-det is a tensor when
            K > 0 and the int 0 when there are no layers.
        """
        total_log_det = 0
        current = z
        for transform in self.model:
            total_log_det = total_log_det + transform.log_det_J(current)
            current = transform(current)
        return current, total_log_det

In [8]:
# Components of an equally weighted 4-component Gaussian mixture target
# (unit covariance each).
mean_1 = np.array([0, 0]).astype('float32')
mean_2 = np.array([2, 2.5]).astype('float32')
mean_3 = np.array([-3, 10]).astype('float32')
mean_4 = np.array([-5, 1]).astype('float32')
pdf1 = MultivariateNormal(torch.tensor(mean_1), torch.eye(2))
pdf2 = MultivariateNormal(torch.tensor(mean_2), torch.eye(2))
pdf3 = MultivariateNormal(torch.tensor(mean_3), torch.eye(2))
pdf4 = MultivariateNormal(torch.tensor(mean_4), torch.eye(2))


def multi_variate_gaussian_targ_distribution(x):
    """Energy U(x) = -log p(x) of the mixture p = 0.25 * (pdf1 + pdf2 + pdf3 + pdf4).

    Returned as an energy (negative log-density), the same convention as
    ``U_1``, because VariationalLoss negates ``distribution(z)`` to obtain
    the target log-density.

    Args:
        x: tensor of shape (..., 2).

    Returns:
        Tensor of shape x.shape[:-1].
    """
    components = (pdf1, pdf2, pdf3, pdf4)
    log_probs = torch.stack([pdf.log_prob(x) for pdf in components], dim=0)
    # log(sum_k w_k p_k(x)) with equal weights w_k = 1/K, computed stably.
    # (Summing the log-probs instead would give a product of densities,
    # i.e. a single Gaussian, not a mixture.)
    log_weight = -float(np.log(len(components)))
    log_mix = torch.logsumexp(log_probs, dim=0) + log_weight
    return -log_mix


In [None]:
#using a u shaped pdf
import torch
import torch.nn
import matplotlib.pyplot as plt
# Fit a planar flow to the U-shaped target U_1 by minimising the reverse-KL bound.
seed = 0
flow_length = 32
dim = 2
num_batches = 20000
batch_size = 128
lr = 6e-4
torch.manual_seed(seed)  # reproducible layer init and base samples
model = PlanarFlow(dim, K=flow_length)
target_distribution = U_1
bound = VariationalLoss(target_distribution)
optimiser = torch.optim.Adam(model.parameters(), lr=lr)
for batch_num in range(1, num_batches + 1):
    # Draw a fresh batch from the base distribution N(0, I).
    batch = torch.randn(batch_size, dim)
    # Push it through the flow, collecting the summed log|det J|.
    zk, log_jacobians = model(batch)
    # Reverse-KL loss under the target energy.
    loss = bound(batch, zk, log_jacobians)

    optimiser.zero_grad()
    loss.backward()
    optimiser.step()

    if batch_num % 1000 == 0:
        # .item() prints the plain float; the tensor repr would include grad_fn.
        print(f"(batch_num {batch_num:05d}/{num_batches}) loss: {loss.item()}")



(batch_num 01000/20000) loss: 0.3203956186771393
(batch_num 02000/20000) loss: -1.804951548576355
(batch_num 03000/20000) loss: -1.8429417610168457
(batch_num 04000/20000) loss: -1.6748672723770142
(batch_num 05000/20000) loss: -1.779553771018982
(batch_num 06000/20000) loss: -1.8872768878936768
(batch_num 07000/20000) loss: -1.755571722984314
(batch_num 08000/20000) loss: -1.8260223865509033
(batch_num 09000/20000) loss: -1.8378047943115234
(batch_num 10000/20000) loss: -1.8434269428253174
(batch_num 11000/20000) loss: -1.8173214197158813
(batch_num 12000/20000) loss: -1.7355355024337769
(batch_num 13000/20000) loss: -1.8300026655197144
(batch_num 14000/20000) loss: -1.8068687915802002
(batch_num 15000/20000) loss: -1.887778401374817
(batch_num 16000/20000) loss: -1.8225430250167847
(batch_num 17000/20000) loss: -1.8791940212249756
(batch_num 18000/20000) loss: -1.8250271081924438
(batch_num 19000/20000) loss: -1.8884040117263794
(batch_num 20000/20000) loss: -1.8479478359222412
