In [2]:
import torch
import matplotlib.pyplot as plt
import numpy as np
import pickle
import os
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
from tqdm.auto import tqdm
from torch.utils.data import TensorDataset, DataLoader
from torch.distributions.normal import Normal
import torch.nn.functional as F
import gc
import math
import seaborn as sns

In [3]:
Equivariant_dim = 2
Information_dim = 1

In [14]:
data_e = torch.randn(int(1e6),Equivariant_dim)
data_i = torch.rand(int(1e6),Information_dim)


data = torch.concatenate([data_e,data_i],-1)

In [15]:
data_perturbed = data.clone()
data_perturbed[:,0] = data_perturbed[:,0] + data_perturbed[:,2]

In [20]:
def psi(x):
    return (x[:,0] - x[:,2])**2 + x[:,1]**2 #+ x[:,2]**2

In [78]:
class DecorrelVAE(nn.Module):
    def __init__(self,
                 input_features,
                 hidden_dims = 16,
                 latent_dim = 16,
                ):
        super().__init__()

        self.mask = nn.Parameter(torch.empty((1,latent_dim)))
        self.reset_parameters()
        
        self.M_I = nn.Linear(in_features=latent_dim,
                             out_features=latent_dim,
                             bias=False) #No bias needed
        
        self.M_E = nn.Linear(in_features=latent_dim,
                             out_features=latent_dim,
                             bias=False) #No bias needed


        self.Encoder = nn.Sequential(*[
                nn.Linear(input_features,64),
                nn.ReLU(),
                nn.Linear(64,64),
                nn.ReLU(),
                nn.Linear(64,latent_dim)
                ])

        self.Decoder = nn.Sequential(*[
                nn.Linear(latent_dim,64),
                nn.ReLU(),
                nn.Linear(64,64),
                nn.ReLU(),
                nn.Linear(64,input_features)
                ])

        
        # This needs to be treated properly
        self.mass_predictor = nn.Sequential(*[
                nn.Linear(latent_dim,hidden_dims),
                nn.ReLU(),
                nn.Linear(hidden_dims,hidden_dims),
                nn.ReLU(),
                nn.Linear(hidden_dims,1)
                ])        

    def forward(self,X):


        x = self.Encoder(X)
        #######################################################
        eps_i = torch.randn_like(x)
        f_i = x*torch.sigmoid((1-self.mask)).expand(x.shape[0],-1) + eps_i*torch.sigmoid(self.mask).expand(x.shape[0],-1)
        f_i = self.M_I(f_i)
        #######################################################

        #######################################################
        eps_ip = torch.randn_like(x)
        f_e = x*torch.sigmoid(self.mask).expand(x.shape[0],-1) + eps_ip*torch.sigmoid((1-self.mask)).expand(x.shape[0],-1)
        f_e = self.M_E(f_e)
        ######################################################

        mass = self.mass_predictor(f_e).squeeze()

        reco_x = self.decoder(invariant = f_i,
                            equivariant=f_e)
        reco = self.Decoder(reco_x)
        return f_i, f_e, mass, reco
    
    def decoder(self,invariant, equivariant):
        return (invariant+equivariant)

        
    def reset_parameters(self) -> None:
        nn.init.kaiming_uniform_(self.mask, a=math.sqrt(5))

    

In [79]:
device = "cuda:0"

In [80]:
decorrel = DecorrelVAE(input_features=Equivariant_dim+Information_dim, hidden_dims=8, latent_dim=3)
decorrel.to(device)
optimizer_decorel = torch.optim.Adam(decorrel.parameters(), lr = 1e-3)

In [81]:
BATCH_SIZE=1024

In [87]:
data = data.to(device)
criterion_mse = nn.MSELoss()

loss_reco_list = []
loss_mass_list = []
for epoch in tqdm(range(100)):
    
    loss_reco_ = 0
    loss_mass_ = 0
    
    index = torch.randperm(data.shape[0])
    for i in range(0,(data.shape[0]//BATCH_SIZE)+1):
        z = data[index[i*BATCH_SIZE:(i+1)*BATCH_SIZE],:]
        
        optimizer_decorel.zero_grad()

        f_i, f_e, mass_pred, reco = decorrel(z)

        loss_mass_pred = criterion_mse(psi(z),mass_pred)
        loss_reco = criterion_mse(z,reco)
        
        loss = loss_mass_pred + loss_reco
        
        loss.backward()
        optimizer_decorel.step()
        
        loss_reco_+=loss_reco.item()
        loss_mass_+=loss_mass_pred.item()
        
    
    loss_reco_ /= (data.shape[0]//BATCH_SIZE)+1
    loss_mass_/= (data.shape[0]//BATCH_SIZE)+1
    
    loss_reco_list.append(loss_reco_)
    loss_mass_list.append(loss_mass_)

    if epoch%10 ==0 :
        print(f"EPOCH {epoch} complete")
        print("=====================")
        print("Loss Reconstruction",loss_reco_)
        print("Loss Mass Prediction",loss_mass_)

  0%|          | 0/100 [00:00<?, ?it/s]

EPOCH 0 complete
Loss Reconstruction 0.0012026249896734953
Loss Mass Prediction 0.0005114519279431114
EPOCH 10 complete
Loss Reconstruction 0.0008606830477893749
Loss Mass Prediction 0.0004145740918343755
EPOCH 20 complete
Loss Reconstruction 0.0006931200141228732
Loss Mass Prediction 0.00037131143195631824


KeyboardInterrupt: 

In [104]:
f_i, f_e, mass_pred, reco = decorrel(z)

In [105]:
f_i

tensor([[  3.9706,  -8.9802, -25.5227],
        [  3.5181,  -7.9578, -22.6149],
        [  2.0994,  -4.7485, -13.4939],
        ...,
        [  0.5615,  -1.2718,  -3.6107],
        [  2.6414,  -5.9739, -16.9765],
        [ -2.3043,   5.2105,  14.8066]], device='cuda:0',
       grad_fn=<MmBackward0>)

In [106]:
torch.sigmoid(decorrel.mask)

tensor([[0.9556, 0.3204, 0.9928]], device='cuda:0', grad_fn=<SigmoidBackward0>)

In [107]:
decorrel.Decoder(f_e+f_i)[0,:]

tensor([0.7607, 1.5161, 0.4220], device='cuda:0', grad_fn=<SliceBackward0>)

In [108]:
z[0,:]

tensor([0.6765, 1.4567, 0.1328], device='cuda:0')

In [95]:
z_original = z.clone()
z_original[:,0] = z[:,0] - z[:,2]

In [100]:
z_original[0,:]

tensor([0.5436, 1.4567, 0.1328], device='cuda:0')

In [50]:
z_original.sh

AttributeError: 'Tensor' object has no attribute 'sh'

In [29]:
decorrel.mask

Parameter containing:
tensor([[-3.5184,  1.9926,  3.9571, -4.0639,  2.2235,  1.7557,  3.3118,  1.9644]],
       device='cuda:0', requires_grad=True)

In [None]:
data_e = torch.randn(int(1e6),Equivariant_dim)
data_i = torch.rand(int(1e6),Information_dim)


data = torch.concatenate([data_e,data_i],-1)

data_perturbed = data.clone()
data_perturbed[:,0] = data_perturbed[:,0] + data_perturbed[:,2]