In [1]:
import torch
import torch.nn as nn
import numpy as np

class Net(nn.Module):
    """MLP critic T(x, z) -> scalar score, evaluated on joint and re-paired rows.

    The network maps a row of ``xz_dim`` features to a single score. ``forward``
    scores the given ``xz`` rows and, separately, rows built by keeping the
    leading x-columns of ``xz`` and replacing its trailing z-columns with the
    caller-supplied ``z``. These two outputs are intended for a
    Donsker-Varadhan style mutual-information bound (joint term and
    product-of-marginals term).

    Args:
        xz_dim: number of input features per row (x-columns + z-columns).
        layers: widths of the hidden layers, each followed by a ReLU. An empty
            sequence yields a single linear map ``xz_dim -> 1``.
    """

    def __init__(self, xz_dim, layers):
        super().__init__()
        self.xz_dim = xz_dim

        # Linear -> ReLU for every hidden width, then a scalar output head.
        # Construction order matches the original layout, so Sequential indices
        # (and therefore state_dict keys and RNG consumption) are unchanged.
        modules = []
        prev_layer = xz_dim
        for layer in layers:
            modules.append(nn.Linear(prev_layer, layer))
            modules.append(nn.ReLU())
            prev_layer = layer
        modules.append(nn.Linear(prev_layer, 1))
        self.linears = nn.Sequential(*modules)

    def forward(self, xz, z):
        """Score the joint rows and the re-paired rows.

        Args:
            xz: tensor of shape (batch, xz_dim); the leading
                ``xz_dim - z.shape[1]`` columns are treated as x, the rest as z.
            z: tensor of shape (batch, z_dim) holding replacement z values
                (e.g. a shuffled z column, to sample the product of marginals).

        Returns:
            (h, h2): scores of shape (batch, 1) for ``xz`` and for ``[x, z]``.

        Raises:
            ValueError: if ``z`` has as many or more columns than ``xz``, i.e.
                no x-columns would remain.
        """
        x_dim = xz.shape[1] - z.shape[1]
        if x_dim < 1:
            raise ValueError(
                "z has {} columns but xz only has {}; no x columns remain".format(
                    z.shape[1], xz.shape[1]))

        h = self.linears(xz)
        # Keep x from the joint sample and swap in the caller-provided z.
        xz_2 = torch.cat((xz[:, :x_dim], z), 1)
        h2 = self.linears(xz_2)
        return h, h2
    


In [13]:

# Synthetic data: a bivariate Gaussian with correlation `corr`, whose mutual
# information has the closed form -0.5 * log(1 - corr^2) (`theo`).
np.random.seed(0)     # reproducible data
torch.manual_seed(0)  # reproducible weight init and per-batch permutations

n = 10000
corr = 0.3
theo = -0.5 * np.log(1 - np.square(corr))
cov = np.array([[1, corr], [corr, 1]])
xz = np.random.multivariate_normal(mean=[0, 0], cov=cov, size=n)
# NOTE: `z` is not used by the training loop below (marginal samples come from
# shuffling column 1 within each batch); kept in case later cells rely on it.
z = np.random.normal(0, scale=1, size=n)[:, None]
layers = [32, 16]
net = Net(xz.shape[-1], layers)

epochs = 300
batch_size = 1000
n_batches = n // batch_size  # any trailing partial batch is skipped
assert n_batches > 0, "batch_size must not exceed n"
display_step = max(1, epochs // 10)  # avoid modulo-by-zero when epochs < 10

# Convert to a float32 tensor once instead of once per batch per epoch.
xz_t = torch.from_numpy(xz).float()

opt = torch.optim.Adam(net.parameters(), lr=1e-3)
for epoch in range(epochs + 1):
    loss_mu = 0.0

    for i in range(n_batches):
        opt.zero_grad()

        xz_b = xz_t[i * batch_size : (i + 1) * batch_size]
        # Marginal samples: keep each x, pair it with z from a random row of the batch.
        z_b = xz_b[torch.randperm(len(xz_b)), 1].unsqueeze(1)

        h_i, z_i = net(xz_b, z_b)

        # Negative Donsker-Varadhan bound: -(E_joint[T] - log E_marg[exp(T)]).
        # logsumexp(T) - log(N) equals log(mean(exp(T))) but cannot overflow.
        loss = -torch.mean(h_i) + torch.logsumexp(z_i.view(-1), dim=0) - np.log(len(z_i))
        loss.backward()
        opt.step()
        # .item() accumulates a Python float rather than chaining autograd graphs.
        loss_mu += loss.item()
    loss_mu /= n_batches
    if epoch % display_step == 0:
        print("Epoch: {} - MI: {} - Theoretical: {}".format(epoch, -loss_mu, theo))

Epoch: 0 - MI: 0.0007793232798576355 - Theoretical: 0.047155339735620645
Epoch: 30 - MI: 0.04723343998193741 - Theoretical: 0.047155339735620645
Epoch: 60 - MI: 0.04544950649142265 - Theoretical: 0.047155339735620645
Epoch: 90 - MI: 0.052031539380550385 - Theoretical: 0.047155339735620645
Epoch: 120 - MI: 0.04672890156507492 - Theoretical: 0.047155339735620645
Epoch: 150 - MI: 0.048797693103551865 - Theoretical: 0.047155339735620645
Epoch: 180 - MI: 0.04879845306277275 - Theoretical: 0.047155339735620645
Epoch: 210 - MI: 0.04728664457798004 - Theoretical: 0.047155339735620645
Epoch: 240 - MI: 0.050238437950611115 - Theoretical: 0.047155339735620645
Epoch: 270 - MI: 0.04741550609469414 - Theoretical: 0.047155339735620645
Epoch: 300 - MI: 0.04946392774581909 - Theoretical: 0.047155339735620645


torch.Size([100, 2])