In [1]:
import torch
import matplotlib.pyplot as plt
import numpy as np
import pickle
import os
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
from tqdm.auto import tqdm
from torch.utils.data import TensorDataset, DataLoader
from torch.distributions.normal import Normal
import torch.nn.functional as F
import gc
import math
import seaborn as sns

In [2]:
Equivariant_dim = 2
Information_dim = 4

In [3]:
data_e = torch.randn(int(1e6),Equivariant_dim)
data_i = torch.rand(int(1e6),Information_dim)

data = torch.concatenate([data_e,data_i],-1)

In [4]:
def psi(x):
    return x[:,0]**2 + x[:,1]**2 #+ x[:,2]**2

In [5]:
class DecorrelVAE(nn.Module):
    def __init__(self,
                 input_features,
                 hidden_dims = 8,
                ):
        super().__init__()

        self.mask = nn.Parameter(torch.empty((1,input_features)))
        self.reset_parameters()
        
        self.M_I = nn.Linear(in_features=input_features,out_features=input_features,bias=False) #No bias needed
        self.M_E = nn.Linear(in_features=input_features,out_features=input_features,bias=False) #No bias needed
        
        # This needs to be treated properly
        self.mass_predictor = nn.Sequential(*[
                nn.Linear(input_features,hidden_dims),
                nn.ReLU(),
                nn.Linear(hidden_dims,hidden_dims),
                nn.ReLU(),
                nn.Linear(hidden_dims,1)
                ])        

    def forward(self,x):

        #######################################################
        eps_i = torch.randn_like(x)
        f_i = x*torch.sigmoid((1-self.mask)).expand(x.shape[0],-1) + eps_i*torch.sigmoid(self.mask).expand(x.shape[0],-1)
        f_i = self.M_I(f_i)
        #######################################################

        #######################################################
        eps_ip = torch.randn_like(x)
        f_e = x*torch.sigmoid(self.mask).expand(x.shape[0],-1) + eps_ip*torch.sigmoid((1-self.mask)).expand(x.shape[0],-1)
        f_e = self.M_E(f_e)
        ######################################################

        mass = self.mass_predictor(f_e).squeeze()

        reco = self.decoder(invariant = f_i,
                            equivariant=f_e)
        
        return f_i, f_e, mass, reco
    
    def decoder(self,invariant, equivariant):
        return (invariant+equivariant)

        
    def reset_parameters(self) -> None:
        nn.init.kaiming_uniform_(self.mask, a=math.sqrt(5))

    

In [6]:
device = "cuda:0"

In [15]:
decorrel = DecorrelVAE(input_features=Equivariant_dim+Information_dim, hidden_dims=8)
decorrel.to(device)
optimizer_decorel = torch.optim.Adam(decorrel.parameters(), lr = 1e-3)

In [16]:
BATCH_SIZE=1024

In [17]:
data = data.to(device)
criterion_mse = nn.MSELoss()

loss_reco_list = []
loss_mass_list = []
for epoch in tqdm(range(100)):
    
    loss_reco_ = 0
    loss_mass_ = 0
    
    index = torch.randperm(data.shape[0])
    for i in range(0,(data.shape[0]//BATCH_SIZE)+1):
        z = data[index[i*BATCH_SIZE:(i+1)*BATCH_SIZE],:]
        
        optimizer_decorel.zero_grad()

        f_i, f_e, mass_pred, reco = decorrel(z)

        loss_mass_pred = criterion_mse(psi(z),mass_pred)
        loss_reco = criterion_mse(z,reco)
        
        loss = loss_mass_pred + loss_reco
        
        loss.backward()
        optimizer_decorel.step()
        
        loss_reco_+=loss_reco.item()
        loss_mass_+=loss_mass_pred.item()
        
    
    loss_reco_ /= (data.shape[0]//BATCH_SIZE)+1
    loss_mass_/= (data.shape[0]//BATCH_SIZE)+1
    
    loss_reco_list.append(loss_reco_)
    loss_mass_list.append(loss_mass_)

    if epoch%10 ==0 :
        print(f"EPOCH {epoch} complete")
        print("=====================")
        print("Loss Reconstruction",loss_reco_)
        print("Loss Mass Prediction",loss_mass_)

  0%|          | 0/100 [00:00<?, ?it/s]

EPOCH 0 complete
Loss Reconstruction 0.35895760331451587
Loss Mass Prediction 3.3203966962888516
EPOCH 10 complete
Loss Reconstruction 0.0004583608352008658
Loss Mass Prediction 0.009012764685824437
EPOCH 20 complete
Loss Reconstruction 6.1521207681949725e-06
Loss Mass Prediction 0.00500041261952115
EPOCH 30 complete
Loss Reconstruction 2.678545407092643e-06
Loss Mass Prediction 0.004122598199440717
EPOCH 40 complete
Loss Reconstruction 1.4296341002501528e-06
Loss Mass Prediction 0.004023845423348515
EPOCH 50 complete
Loss Reconstruction 2.079083562894512e-06
Loss Mass Prediction 0.004017074219215215
EPOCH 60 complete
Loss Reconstruction 2.219784433076434e-06
Loss Mass Prediction 0.004002387448335057
EPOCH 70 complete
Loss Reconstruction 2.2279853781622002e-06
Loss Mass Prediction 0.004009026779288935



KeyboardInterrupt

