In [1]:
cd D:\Saarbrucken\EDA_Research\vae-disentanglement\disentanglement_lib_pl

D:\Saarbrucken\EDA_Research\vae-disentanglement\disentanglement_lib_pl


In [2]:
import torch
from torch import nn
import numpy as np
from architectures import encoders, decoders
from common.ops import Flatten3D, Unsqueeze3D, Reshape
from torch.nn import functional as F
import pickle

# Multiscale + GNN structure test 

In [None]:
cd D:\Saarbrucken\EDA_Research\vae-disentanglement\disentanglement_lib_pl

In [None]:
import torch
import torch.nn as nn
from common.ops import Flatten3D, Unsqueeze3D, Reshape

class MultiScaleEncoder(nn.Module):
    """
    Encoder as used in 'Challenging Common Assumptions in the Unsupervised Learning of Disentangled Representations'
    """
    def __init__(self, feature_dim, in_channels, num_nodes):
        super().__init__()
        
        # Number of features per scale used by each node to compute initial node features
        # 3 means 3 features from each scale level will be used
        self.NUM_SCALES = 3
        self.feature_dim = feature_dim
        self.num_nodes = num_nodes
        self.features_to_take = self.feature_dim // self.num_nodes
        self.batch_size = None

        # in / out feature maps at each scale
        self.scale_3_in, self.scale_3_out = in_channels, 32
        self.scale_2_in, self.scale_2_out = 32, 32
        self.scale_1_in, self.scale_1_out = 32, 64

        # coarsest scale - outputs maps of shape B, 
        self.scale_3 = nn.Sequential(
            nn.Conv2d(self.scale_3_in, self.scale_3_out, 4, 2), # B, 32, 31 x 31
            nn.ReLU(True)
        )
        self.scale_3_feats = nn.Sequential(
            Flatten3D(),
            nn.Linear(self.scale_3_out * 31 * 31, self.feature_dim),
            nn.Tanh()
        )
        # mid scale - outputs maps of shape B, 
        self.scale_2 = nn.Sequential(
            nn.Conv2d(self.scale_2_in, self.scale_2_out, 4, 2), # B, 32, 14 x 14
            nn.ReLU(True)
        )
        self.scale_2_feats = nn.Sequential(
            Flatten3D(),
            nn.Linear(self.scale_2_out * 14 * 14, self.feature_dim),
            nn.Tanh()
        )
        
        # finest scale - outs maps of shape B,
        self.scale_1 = nn.Sequential(
            nn.Conv2d(self.scale_1_in, self.scale_1_out, 4, 2), # B, 64, 6 x 6
            nn.ReLU(True)
        )
        self.scale_1_feats = nn.Sequential(
            Flatten3D(),
            nn.Linear(self.scale_1_out * 6 * 6, self.feature_dim),
            nn.Tanh()
        )

    def forward(self, x):
        
        self.batch_size = x.shape[0]
        
        scale_3_x = self.scale_3(x)
        scale_3_feats = self.scale_3_feats(scale_3_x)

        scale_2_x = self.scale_2(scale_3_x)
        scale_2_feats = self.scale_2_feats(scale_2_x)

        scale_1_x = self.scale_1(scale_2_x)
        scale_1_feats = self.scale_1_feats(scale_1_x)
        #print(scale_3_feats.shape)
        #print(scale_2_feats.shape)
        #print(scale_1_feats.shape)
        
        # Just stacking gives the shape (NUM_SCALES, batch_size, feature_dim). Hence, we need to permute to get 
        # (batch_size, feature_dim, NUM_SCALES)
        multi_scale_feats = torch.stack([scale_3_feats, scale_2_feats, scale_1_feats]).permute(1,2,0)
        # (batch_size, V, NUM_SCALES * features_to_take)
        multi_scale_feats = multi_scale_feats.reshape(self.batch_size, self.num_nodes, self.NUM_SCALES * self.features_to_take )
        
        # reshape like this so that they can be associated with each latent node
        return multi_scale_feats

In [None]:
ms_enc = MultiScaleEncoder(4*2, 1, 4)

In [None]:
out = ms_enc(torch.randn(2,1,64,64))
print(out)
print(out.shape)

## Testing Prior GNN

In [None]:
class SimpleGNNLayer(nn.Module):
    """
    Can be used to implement GNNs for P(Z|epsilon, A) or Q(Z|X,A)
    """
    def __init__(self, in_node_feat_dim, out_node_feat_dim, adj_mat, is_final_layer=False):
        super().__init__()

        self.in_node_feat_dim = in_node_feat_dim
        self.out_node_feat_dim = out_node_feat_dim
        self.is_final_layer = is_final_layer
        self.A = adj_mat

        self.num_neighbours = self.A.sum(dim=-1, keepdims=True)
        self.projection = nn.Linear(self.in_node_feat_dim, self.out_node_feat_dim)
    
    def forward(self, node_feats):
        
        node_feats = self.projection(node_feats)
        node_feats = torch.matmul(self.A, node_feats)
        node_feats = node_feats / self.num_neighbours
        
        if self.is_final_layer:
            # split into mu and sigma
            node_feats_mu, node_feats_logvar = node_feats.chunk(2, dim=2)
            return node_feats_mu, node_feats_logvar
        else:
            node_feats = torch.tanh(node_feats)
            return node_feats

In [None]:
# 4 nodes with topology given in A
dist_param_dim = 2
batch, V, node_feat_dim = 2, 4, 1
A = torch.Tensor([
    [1., 0., 1., 1.],
    [0., 1., 0., 1.],
    [0., 0., 1., 0.],
    [0., 0., 0., 1.]
])

# sample V exogenous vars per batch
E = torch.randn(size=(batch, V, node_feat_dim))

print(E)

In [None]:
# pass thru K GNN layers

prior_gnn = nn.Sequential(
    SimplePriorGNNLayer(in_node_feat_dim=1, out_node_feat_dim=2*dist_param_dim, adj_mat=A),
    SimplePriorGNNLayer(in_node_feat_dim=2*dist_param_dim, out_node_feat_dim=2*dist_param_dim, adj_mat=A, is_final_layer=True)
)


In [None]:
mus, logvars = prior_gnn(E)
print(mus)

print(logvars)

In [None]:
print(p.shape)
mus, logvars = p.chunk(2, dim=2)


print(mus.shape, mus)

print(logvars.shape, logvars)

In [None]:
for p in prior_layer1.parameters(): print(p)

## Testing Encoder GNN

In [None]:
def get_adj_mat_from_adj_list(adjacency_list):
    
    num_nodes = len(adjacency_list)

    # initialize with self-connections
    A = np.zeros(shape=(num_nodes, num_nodes)) + np.eye(num_nodes)

    for node_idx, parent_list in enumerate(adjacency_list):
        print(parent_list)
        for parent_node_idx in parent_list:
            A[parent_node_idx, node_idx] = 1.0

    return torch.from_numpy(A)

In [None]:
adjacency_list = [(),(),(0,),(0,1)]

In [None]:


get_adj_mat_from_adj_list(adjacency_list)

In [None]:
A = torch.Tensor([
    [1., 0., 1., 1.],
    [0., 1., 0., 1.],
    [0., 0., 1., 0.],
    [0., 0., 0., 1.]
])
A.T.sum(dim=-1, keepdims=True)

In [None]:
def kl_divergence_diag_mu_var_per_node(mu, logvar, target_mu, target_logvar):
    
    # Calculate per node kldloss

    # input have shape (Batch, V, node_feats)
    # output has shape (Batch, V, 1)
    kld = -0.5 * ( 1 - target_logvar + logvar -
                  ((target_mu - mu) * target_logvar.exp().pow(-1) * (target_mu - mu)) - 
                    (target_logvar.exp().pow(-1)*logvar.exp())
            )#.sum(2, keepdims=True).mean(0)
    return kld

In [None]:
target_mu = torch.Tensor(
    [ [[0., 0.],
      [0., 0.]],

        [[0., 0.],
         [0., 0.]]
    ]
)

target_logvar = torch.Tensor(
    [ [[0., 0.],
      [0., 0.]],

        [[0., 0.],
         [0., 0.]]
    ]
)

In [None]:
mu = torch.Tensor(
    [ [[1., 0.],
      [0., 1.]],

        [[1., 0.],
         [0., 0.95]]
    ]
)

logvar = torch.Tensor(
    [ [[0.5, 0.],
      [0., 0.]],

        [[0.5, 0.],
         [0., 0.]]
    ]
)

In [None]:
kld = kl_divergence_diag_mu_var_per_node(mu, logvar, target_mu, target_logvar)
print(kld)

In [None]:
kld.sum(2, keepdims=True)

In [None]:
per_node_kld = kld.sum(2, keepdims=True).mean(0)
print(per_node_kld)

In [None]:
for kld_node in per_node_kld:
    print(kld_node)

In [None]:
from common import dag_utils

alist = pickle.load(open(r"D:\Saarbrucken\EDA_Research\vae-disentanglement\adjacency_matrices\dsprites_correlated.pkl", 'rb'))
A = dag_utils.get_adj_mat_from_adj_list(alist)

print(A)




In [None]:
nf = torch.arange(25).type(torch.FloatTensor).view(5,5)
print(nf)

In [None]:
torch.matmul(A.T, nf)

In [None]:
class SimpleGNNLayer(nn.Module):
    """
    Can be used to implement GNNs for P(Z|epsilon, A) or Q(Z|X,A)
    """
    def __init__(self, in_node_feat_dim, out_node_feat_dim, adj_mat, is_final_layer=False):
        super().__init__()

        self.in_node_feat_dim = in_node_feat_dim
        self.out_node_feat_dim = out_node_feat_dim
        self.is_final_layer = is_final_layer
        self.A = adj_mat.T # TODO:

        self.num_neighbours = self.A.sum(dim=-1, keepdims=True)
        self.projection = nn.Linear(self.in_node_feat_dim, self.out_node_feat_dim)
    
    def forward(self, node_feats):
        
        self.A = self.A.to(node_feats.device)
        self.num_neighbours = self.num_neighbours.to(node_feats.device)
        
        node_feats = self.projection(node_feats)
        print(node_feats)
        node_feats = torch.matmul(self.A, node_feats)
        print(node_feats)
        node_feats = node_feats / self.num_neighbours
        print(node_feats)
        
        if self.is_final_layer:
            # split into mu and sigma
            node_feats_mu, node_feats_logvar = node_feats.chunk(2, dim=2)
            return node_feats_mu, node_feats_logvar
        else:
            node_feats = torch.tanh(node_feats)
            return node_feats

In [None]:
#from common.special_modules import SimpleGNNLayer

V, ifd, ofd =5, 2, 4

#inp = torch.randn(size=(1, V, ifd))

inp = torch.arange(10).type(torch.FloatTensor).view(1, V, ifd)
print(inp)
print("input shape: ", inp.shape)

prior_gnn = SimpleGNNLayer(ifd, ofd, A.T, is_final_layer=True)
print("Linear layer mat shape: ", prior_gnn.projection.weight.data.shape)
prior_gnn.projection.weight.data = torch.Tensor(
        [[1., 0.],
        [0., 1.],
        [1., 0.],
        [0., 1.]]
)
prior_gnn.projection.bias.data = torch.zeros(ofd)


In [None]:
print("input: ", inp)
out = prior_gnn(inp)

print(out)

In [None]:
#(batch, num_nodes, num_feat_dim)

mus = torch.randn(1, 5, 2)
print(mus)

In [None]:
mus.mean(0).tolist()

In [None]:
for node_idx in range(5):
    for k in range(2):
        print(f"node {node_idx}, comp {k}, : {mus[:, node_idx, k]}")

In [None]:
prior_gnn.num_neighbours

In [3]:
from gnncsvae_experiment import GNNCSVAEExperiment
from collections import defaultdict, namedtuple
import models

ModelParams = namedtuple('ModelParams', ["z_dim", "l_dim", "num_labels" , "in_channels", 
                                        "image_size", "batch_size", "w_recon", "w_kld", "kl_warmup_epochs",
                                         "adjacency_matrix"])


algo_name = "GNNBasedConceptStructuredVAE"
checkpoint_path = r"D:\Saarbrucken\EDA_Research\vae-disentanglement\models\gnncsvae.ckpt"
z_dim = 5

model_params = ModelParams(
        [z_dim], 6, 0, 1, 64, 64, 1.0, 1.0, 0,
    r"D:\Saarbrucken\EDA_Research\vae-disentanglement\adjacency_matrices\dsprites_correlated.pkl"
)
exp_params = dict(
        in_channels=1,
        image_size=64,
        LR=1e-4,
        weight_decay=0.0,       
        dataset="dsprites_correlated",
        datapath=r"D:\Saarbrucken\EDA_Research\vae-disentanglement\datasets",
        droplast=True,        
        batch_size=64,
        num_workers=0,
        pin_memory=False,
        seed=123,
        evaluation_metrics=None,
        visdom_on=False,
        save_dir=None,
        max_epochs=1,
        l_zero_reg=False
)
vae_model_class = getattr(models, algo_name)
vae_model = vae_model_class(model_params)

vae_experiment = GNNCSVAEExperiment.load_from_checkpoint(
            checkpoint_path,
            vae_model=vae_model, 
            params=exp_params,
            dataset_params=dict(correlation_strength=0.2))

tensor([[1., 0., 0., 0., 0.],
        [0., 1., 0., 0., 0.],
        [0., 0., 1., 1., 0.],
        [0., 0., 0., 1., 0.],
        [0., 0., 0., 0., 1.]])
GNNBasedConceptStructuredVAE Model Initialized


In [None]:
# sets the vae_experimen.sample_loader var
#vae_experiment.val_dataloader()

In [4]:
from common.notebook_utils import get_configured_dataset
from torch.utils.data import DataLoader
import os

os.environ['DISENTANGLEMENT_LIB_DATA'] = r"D:\Saarbrucken\EDA_Research\vae-disentanglement\datasets"

dataset = get_configured_dataset("dsprites_correlated")
sample_loader = DataLoader(dataset, batch_size=64, shuffle = False, drop_last=True)

test_input, test_label = next(iter(sample_loader))
fwd_pass_results = vae_experiment.model.forward(test_input, current_device=test_input.device, labels = test_label)



Initialize [CorrelatedDSpritesDataset] with 737280 examples. Shape (737280, 64, 64).


In [5]:
fwd_pass_results.keys()

dict_keys(['x_recon', 'prior_mu', 'prior_logvar', 'posterior_mu', 'posterior_logvar', 'latents_predicted'])

In [7]:
fwd_pass_results['prior_mu'][1,:,:]

tensor([[ 0.2693,  0.0669,  0.7291, -0.0036,  1.4615],
        [ 0.2693,  0.0670,  0.7291, -0.0036,  1.4615],
        [ 0.2693,  0.0670,  0.7291, -0.0036,  1.4614],
        [ 0.2693,  0.0670,  0.7291, -0.0036,  1.4614],
        [ 0.2693,  0.0670,  0.7291, -0.0036,  1.4614]],
       grad_fn=<SliceBackward>)

In [8]:
fwd_pass_results['prior_mu'][2,:,:]

tensor([[ 0.2693,  0.0670,  0.7291, -0.0036,  1.4614],
        [ 0.2693,  0.0670,  0.7291, -0.0036,  1.4614],
        [ 0.2692,  0.0671,  0.7291, -0.0036,  1.4614],
        [ 0.2692,  0.0671,  0.7291, -0.0036,  1.4614],
        [ 0.2693,  0.0670,  0.7291, -0.0036,  1.4614]],
       grad_fn=<SliceBackward>)

In [10]:
fwd_pass_results['prior_logvar'][1,:,:]

tensor([[-9.7433,  0.2114, -9.2024, -9.2940, -8.7455],
        [-9.7430,  0.2121, -9.2024, -9.2941, -8.7451],
        [-9.7434,  0.2114, -9.2026, -9.2942, -8.7456],
        [-9.7433,  0.2115, -9.2025, -9.2941, -8.7455],
        [-9.7429,  0.2119, -9.2023, -9.2939, -8.7451]],
       grad_fn=<SliceBackward>)

In [11]:
fwd_pass_results['prior_logvar'][2,:,:]

tensor([[-9.7433,  0.2116, -9.2025, -9.2941, -8.7454],
        [-9.7439,  0.2110, -9.2029, -9.2945, -8.7461],
        [-9.7424,  0.2124, -9.2018, -9.2936, -8.7446],
        [-9.7425,  0.2124, -9.2020, -9.2937, -8.7447],
        [-9.7432,  0.2114, -9.2025, -9.2940, -8.7455]],
       grad_fn=<SliceBackward>)