In [None]:
class SigmoidLayer(nn.Module):
    """Elementwise logistic-sigmoid squashing layer.

    ``forward(x)`` returns a pair: ``sigmoid(x)`` and the elementwise
    derivative of the sigmoid at ``x``, i.e. ``s * (1 - s)`` with
    ``s = sigmoid(x)``. Both have the same shape as ``x``.
    """

    def __init__(self):
        super().__init__()

    def forward(self, x):
        squashed = torch.sigmoid(x)
        # sigmoid'(x) = sigmoid(x) * (1 - sigmoid(x)); reuse the forward value.
        slope = squashed * (1 - squashed)
        return squashed, slope
    
class RealNVP(nn.Module):
    """Normalizing flow built from a stack of coupling layers over a fixed base.

    Args:
        image_shape: shape of one input sample; stored as ``self.image_shape``.
        n_layers: passed to every ``CouplingLayer`` as its first argument.
        n_hidden: passed to every ``CouplingLayer`` as its second argument.
        n_couplings: number of ``CouplingLayer`` modules stacked in ``self.layers``.
        device: device on which the base distribution's parameters are created.
        base_distrib: ``'Normal'`` (standard normal base) or ``'Uniform'``
            (U(0, 1) base, reached through a final sigmoid layer).

    Raises:
        NotImplementedError: if ``base_distrib`` is not ``'Normal'`` or ``'Uniform'``.
    """

    def __init__(self, image_shape, n_layers=3, n_hidden=64, n_couplings=5, device='cuda', base_distrib='Normal'):
        super(RealNVP, self).__init__()
        self.image_shape = image_shape
        self.n_layers = n_layers
        self.n_hidden = n_hidden
        self.n_couplings = n_couplings
        self.base_distrib = base_distrib
        # TODO: a checkerboard variant (AffineCouplingChecker) was half-written
        # here and referenced undefined names (n_out, n_filters, n_blocks);
        # the working CouplingLayer stack is kept until that variant is finished.
        self.layers = nn.ModuleList(
            [CouplingLayer(self.n_layers, self.n_hidden) for _ in range(n_couplings)]
        )
        # The base parameters have 2 entries, so latents are expected to have
        # 2 features along their last dimension.
        if self.base_distrib == 'Normal':
            self.base = Normal(torch.tensor([0., 0.], device=device), torch.tensor([1., 1.], device=device))
        elif self.base_distrib == 'Uniform':
            self.base = Uniform(torch.tensor([0., 0.], device=device), torch.tensor([1., 1.], device=device))
            # Squashes the last coupling output into the unit interval.
            self.sigmoid = SigmoidLayer()
        else:
            raise NotImplementedError('Sorry, base distribution can only be Uniform or Normal')

    def forward(self, x):
        """Map a batch ``x`` to latents and collect the per-layer Jacobian factors.

        Returns:
            ``(z, dets)``. ``z`` is the output of the last layer (sigmoid-squashed
            for the Uniform base). ``dets`` is the concatenation along dim 1 of
            a leading column of ones and every factor returned by each layer.
        """
        # The leading ones column contributes log(1) = 0 to the log-likelihood
        # and keeps ``dets`` non-empty even when there are no layers.
        dets = torch.ones_like(x[:, 0:1])
        # Start from x so ``z`` is defined even when the flow has zero layers.
        z = x
        for layer in self.layers:
            z, d = layer(z)
            dets = torch.cat((dets, d), dim=1)
        if self.base_distrib == 'Uniform':
            z, d = self.sigmoid(z)
            dets = torch.cat((dets, d), dim=1)
        return z, dets

    def log_prob_x_from_z(self, z, dets):
        """Return per-sample log p(x) via the change-of-variables formula.

        log p(x) = log p_base(z) + sum(log|dets|) over dim 1.

        For the Uniform base, ``z`` is a sigmoid output, where the U(0, 1)
        log-density is 0, so only the Jacobian term is kept instead of calling
        ``self.base.log_prob``. If the sigmoid saturates, its derivative
        ``z * (1 - z)`` becomes 0 and this returns -inf.
        """
        if self.base_distrib == 'Uniform':
            log_prob = dets.abs().log().sum(dim=1)
        else:
            # Base components are independent, so per-dimension log-probs add up.
            log_prob = self.base.log_prob(z).sum(dim=1) + dets.abs().log().sum(dim=1)
        return log_prob

    def loss_function(self, z, dets):
        """Negative mean log-likelihood of the batch (to be minimized)."""
        loss = -self.log_prob_x_from_z(z, dets).mean()
        return loss

    def eval_log_prob(self, x):
        """Compute per-sample log p(x) without tracking gradients.

        Note: this switches the module to eval mode and does not switch it
        back; call ``.train()`` afterwards to resume training.
        """
        self.eval()
        with torch.no_grad():
            z, dets = self(x)
            log_prob = self.log_prob_x_from_z(z, dets)
        return log_prob