In [133]:
import numpy as np
import pandas as pd
from sklearn.preprocessing import StandardScaler

import torch
import torch.nn as nn

import matplotlib.pyplot as plt

In [134]:
data = pd.read_csv('../data/example_data.csv')

# Load models

In [135]:
def augment(x, tau=None):
    if tau is None:
        tau = torch.zeros(x.size(0), 1).fill_(0.5)
    elif isinstance(tau, float):
        tau = torch.zeros(x.size(0), 1).fill_(tau)
    return torch.cat((x, (tau - 0.5) * 12), 1)

In [136]:
class HaloToGalaxyModel(nn.Module):
    def __init__(self, input_size=4, output_size=5, hidden_dim=64):
        super(HaloToGalaxyModel, self).__init__()
        self.fc1 = nn.Linear(input_size, hidden_dim)
        self.ln1 = nn.LayerNorm(hidden_dim)
        self.relu1 = nn.ReLU()
        self.fc2 = nn.Linear(hidden_dim, hidden_dim)
        self.ln2 = nn.LayerNorm(hidden_dim)
        self.relu2 = nn.ReLU()
        self.fc3 = nn.Linear(hidden_dim, output_size)

    def forward(self, x):
        out = self.fc1(x)
        out = self.ln1(out)
        out = self.relu1(out)
        out = self.fc2(out)
        out = self.ln2(out)
        out = self.relu2(out)
        out = self.fc3(out)
        return out

In [137]:
hidden_dim = 64 #tamaño de las capas ocultas

In [138]:
model_smass = HaloToGalaxyModel(6, 1, hidden_dim)

checkpoint = torch.load('models/smass-smogn_model.pth')

model_smass.load_state_dict(checkpoint['model_state_dict'])

optimizer = torch.optim.AdamW(model_smass.parameters(), lr=1e-3, weight_decay=1e-2)
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])

  checkpoint = torch.load('models/smass-smogn_model.pth')


In [139]:
model_smass_color = HaloToGalaxyModel(6, 2, hidden_dim)

checkpoint = torch.load('models/smass-color-g-i-smogn_model.pth')

model_smass_color.load_state_dict(checkpoint['model_state_dict'])

optimizer = torch.optim.AdamW(model_smass_color.parameters(), lr=1e-3, weight_decay=1e-2)
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])

  checkpoint = torch.load('models/smass-color-g-i-smogn_model.pth')


In [140]:
model_color = HaloToGalaxyModel(6, 1, hidden_dim)

checkpoint = torch.load('models/color-g-i-smogn_model.pth')

model_color.load_state_dict(checkpoint['model_state_dict'])

optimizer = torch.optim.AdamW(model_color.parameters(), lr=1e-3, weight_decay=1e-2)
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])

  checkpoint = torch.load('models/color-g-i-smogn_model.pth')


In [141]:
model_smass_color_wass = HaloToGalaxyModel(6, 2, hidden_dim)

checkpoint = torch.load('models/smass-color-g-i-wass-smogn_model.pth')

model_smass_color_wass.load_state_dict(checkpoint['model_state_dict'])

optimizer = torch.optim.AdamW(model_smass_color_wass.parameters(), lr=1e-3, weight_decay=1e-2)
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])

  checkpoint = torch.load('models/smass-color-g-i-wass-smogn_model.pth')


# Predictions

In [142]:
input_data = data[['M_h', 'C_h', 'S_h', 'z_h', 'Delta3_h']].to_numpy()
output = data[['smass', 'color']].to_numpy()

In [143]:
input_data

array([[12.48163287,  0.99434969,  0.06208566,  1.27272727, 14.32015055],
       [11.65242996,  1.14203344,  0.09452574,  0.84848485,  1.94253163],
       [12.00267253,  1.00050752,  0.07858547,  0.96969697,  3.37133589]])

In [144]:
scaler = StandardScaler()
input_data = scaler.fit_transform(input_data)

In [145]:
input_data = torch.tensor(input_data, dtype=torch.float32)

In [146]:
model_smass.eval()
with torch.no_grad():
    taus = torch.rand(input_data.size(0), 1)
    augmented_x = augment(input_data, taus)
    
    y_pred_smass = model_smass(augmented_x).detach().numpy()

In [147]:
model_color.eval()
with torch.no_grad():
    taus = torch.rand(input_data.size(0), 1)
    augmented_x = augment(input_data, taus)
    
    y_pred_color = model_color(augmented_x).detach().numpy()

In [148]:
model_smass_color.eval()
with torch.no_grad():
    taus = torch.rand(input_data.size(0), 1)
    augmented_x = augment(input_data, taus)

    y_pred_smass_color = model_smass_color(augmented_x).detach().numpy()

In [149]:
model_smass_color_wass.eval()
with torch.no_grad():
    taus = torch.rand(input_data.size(0), 1)
    augmented_x = augment(input_data, taus)
    
    y_pred_smass_color_wass = model_smass_color_wass(augmented_x).detach().numpy()

# Show predictions

In [150]:
output

array([[10.78942251,  1.12636185],
       [ 9.91511603,  0.59660339],
       [10.34346282,  0.97583961]])

In [151]:
y_pred_smass

array([[10.289586],
       [ 8.818494],
       [ 9.406857]], dtype=float32)

In [152]:
y_pred_color

array([[1.1304051 ],
       [0.3834809 ],
       [0.60946685]], dtype=float32)

In [153]:
y_pred_smass_color

array([[10.535534  ,  1.1049111 ],
       [ 8.718655  ,  0.3792008 ],
       [ 9.361683  ,  0.40336064]], dtype=float32)

In [154]:
y_pred_smass_color_wass

array([[10.461731  ,  1.1408694 ],
       [ 8.947577  ,  0.79551387],
       [ 9.541263  ,  0.61819994]], dtype=float32)