In [1]:
from ModelTransfuser.ModelTransfuser import *
# Explicit alias for the model class: the notebook later binds its instance to the
# name `ModelTransfuser`, which shadows the star-imported class.
from ModelTransfuser.ModelTransfuser import ModelTransfuser as ModelTransfuserCls
import matplotlib.pyplot as plt
import matplotlib.animation
import matplotlib.patches as patches
import seaborn as sns

import numpy as np
import torch
import os

%matplotlib inline

# Load data

In [2]:
# --- Load in training data ---
# For .npz archives np.load returns a lazy NpzFile (mmap_mode has no effect on
# them), so the arrays are read inside a `with` block, which closes the file.
path_training = os.path.join(os.getcwd(), 'ModelTransfuser', 'data', 'chempy_TNG_train_data.npz')
with np.load(path_training) as training_data:
    elements = training_data['elements']
    train_x = training_data['params']
    train_y = training_data['abundances']


# ---  Load in the validation data ---
# Read into its own name; `val_data` is reserved for the combined validation tensor.
path_test = os.path.join(os.getcwd(), 'ModelTransfuser', 'data', 'chempy_TNG_val_data.npz')
with np.load(path_test) as val_npz:
    val_x = val_npz['params']
    val_y = val_npz['abundances']


# --- Clean the data ---
# Chempy sometimes returns all-zero or non-finite (inf/NaN) abundance rows, which need to be removed.
def clean_data(x, y, drop_column=2):
    """Drop unusable Chempy samples and remove one abundance column.

    A sample is dropped when its abundance row in ``y`` is entirely zero or
    contains any non-finite value (inf or NaN). The same rows are dropped from
    ``x`` so that parameters and abundances stay aligned. Used for both the
    training and the validation split.

    Parameters
    ----------
    x : array_like
        Model parameters; row ``k`` belongs to abundance row ``k`` of ``y``.
    y : array_like, 2-D
        Abundances, one row per sample.
    drop_column : int, optional
        Column of ``y`` deleted after filtering. Defaults to 2, the H column
        in this dataset's element ordering.

    Returns
    -------
    x_clean, y_clean : numpy.ndarray
        The filtered parameters and abundances; ``y_clean`` has one column
        fewer than ``y``.
    """
    x = np.asarray(x)
    y = np.asarray(y)

    all_zero = (y == 0).all(axis=1)
    all_finite = np.isfinite(y).all(axis=1)
    keep = ~all_zero & all_finite

    x_clean = x[keep]
    # Remove H from the elements after filtering so row masks stay aligned.
    y_clean = np.delete(y[keep], drop_column, axis=1)
    return x_clean, y_clean


# Apply the same cleaning to both splits.
(train_x, train_y), (val_x, val_y) = clean_data(train_x, train_y), clean_data(val_x, val_y)

# Move every array into a float32 torch tensor.
train_x, train_y, val_x, val_y = (
    torch.tensor(arr, dtype=torch.float32) for arr in (train_x, train_y, val_x, val_y)
)

# Joint rows: parameter columns followed by abundance columns.
train_data = torch.cat([train_x, train_y], dim=1)
val_data = torch.cat([val_x, val_y], dim=1)

## Define ModelTransfuser

In [3]:
# Define the ModelTransfuser

# Time steps for the diffusion process
T = 50
# NOTE: `t` is not passed to the constructor; later cells rebind `t` from the
# model before using it.
t = torch.linspace(0, 1, T)

# Construct from the explicit class alias rather than the bare name. This cell
# binds the instance to `ModelTransfuser` (the name the rest of the notebook
# uses), which shadows the star-imported class. Calling the bare name again on a
# re-run would fail with "'ModelTransfuser' object is not callable".
ModelTransfuser = ModelTransfuserCls(T, train_data.shape)



In [4]:
# (rows, columns) of the joint training tensor: one row per sample, parameter columns then abundance columns.
train_data.size()

torch.Size([498314, 14])

In [5]:
# (rows, columns) of the joint validation tensor; column count must match the training tensor.
val_data.size()

torch.Size([49824, 14])

## Train diffusion model

In [6]:
# Train on the joint [params | abundances] tensor; val_data is passed as the
# validation split (ModelTransfuser.val_loss is plotted later).
# NOTE(review): the recorded run was interrupted at ~3.2 s/it with 7787
# iterations per epoch (~7 h per epoch, 10 epochs). Prefer loading the saved
# checkpoint when a full retrain is not needed.
ModelTransfuser.train(train_data, val_data=val_data)

Epoch  1/10:   0%|          | 10/7787 [00:32<6:59:13,  3.23s/it]


KeyboardInterrupt: 

In [1]:
# Inspect the recorded training-loss history.
# NOTE: the saved output is a NameError because this ran in a fresh kernel
# (In [1]). The model-definition and training cells must run first.
ModelTransfuser.train_loss

NameError: name 'ModelTransfuser' is not defined

In [None]:
# Repeat of the validation-shape check above (candidate for removal).
val_data.size()

In [10]:
# Persist the trained weights. Create the target directory first so the save
# also works on a fresh checkout where it does not exist yet.
# NOTE(review): the file name says "t100" while T = 50 in the model cell.
# Confirm which diffusion setting this checkpoint belongs to.
os.makedirs("ModelTransfuser/models", exist_ok=True)
torch.save(ModelTransfuser.state_dict(), "ModelTransfuser/models/ModelTransfuser_t100.pt")

In [None]:
# Plot the loss history, normalising each curve by the number of rows in its split.
# NOTE(review): dividing by the dataset size assumes train_loss / val_loss hold
# per-epoch sums rather than means. Confirm against ModelTransfuser.train.
train_loss_norm = np.array(ModelTransfuser.train_loss) / train_data.shape[0]
val_loss_norm = np.array(ModelTransfuser.val_loss) / val_data.shape[0]

# Epochs are numbered from 1, matching the training progress bar ("Epoch 1/10").
# Each curve gets its own x values, so histories of unequal length do not raise.
epoch = np.arange(1, len(train_loss_norm) + 1)
val_epoch = np.arange(1, len(val_loss_norm) + 1)

fig, ax = plt.subplots()
ax.plot(epoch, train_loss_norm, label='Train Loss')
ax.plot(val_epoch, val_loss_norm, label='Val Loss')
ax.set_xlabel('Epoch')
ax.set_ylabel('Loss')
ax.set_title('Training and validation loss')
ax.legend()
plt.show()

In [7]:
# Restore weights from the saved checkpoint.
# weights_only=True limits unpickling to tensors and plain containers. A
# state_dict needs nothing more, and this avoids running arbitrary pickled code
# stored in the file.
# map_location="cpu" lets a GPU-saved checkpoint load on a CPU-only machine. The
# values are then copied into the model's existing parameters.
ModelTransfuser.load_state_dict(
    torch.load(
        "ModelTransfuser/models/ModelTransfuser_t100.pt",
        map_location="cpu",
        weights_only=True,
    )
)

<All keys matched successfully>

## Data Generation

In [8]:
# Create random datapoints to denoise
# Create random datapoints to denoise: Gaussian noise with standard deviation 2,
# one column per feature of the joint [params | abundances] tensor.
# A dedicated, seeded generator makes the draw reproducible without changing
# torch's global RNG state.
n_samples = 10000
noise_seed = 0
noise_generator = torch.Generator().manual_seed(noise_seed)
sample_data_t1 = torch.randn(n_samples, train_data.shape[1], generator=noise_generator) * 2

In [9]:
# Run the reverse diffusion starting from the noise above. The trajectory
# attributes read later (x_t, score_t, dx_t, t) are presumably filled in by this
# call. Verify in ModelTransfuser.sample.
# NOTE(review): condition_mask is all ones, so every column is flagged. Confirm
# whether 1 means "conditioned/observed" or "to be generated".
sample_data_t0 = ModelTransfuser.sample(sample_data_t1, condition_mask=torch.ones_like(sample_data_t1))

 44%|████▍     | 22/50 [01:34<01:59,  4.28s/it]


KeyboardInterrupt: 

In [None]:
# Per-column mean of the generated samples.
# NOTE(review): the return type of sample() is not visible here (tensor vs ndarray). Confirm.
sample_data_t0.mean(axis=0)

In [None]:
# Per-column spread of the generated samples.
# NOTE(review): if sample() returns a torch tensor, .std() applies Bessel's
# correction (N-1) by default, whereas NumPy's .std() uses ddof=0.
sample_data_t0.std(axis=0)

In [13]:
# NOTE(review): presumably prior means (`priors`) and standard deviations
# (`sigma`) for five model parameters, in parameter-column order. Confirm.
# Neither list is used by the cells in this notebook.
priors = [
    -2.3,
    -2.89,
    -0.3,
    0.55,
    0.5,
]
sigma = [0.3] * 3 + [0.1] * 2

In [15]:
# Copy the trajectory stored on the model into NumPy for plotting. Later cells
# index these as [sample, step, feature].
# .cpu() is a no-op for CPU tensors and is required before .numpy() when the
# model ran on a GPU.
denoising_data = ModelTransfuser.x_t.detach().cpu().numpy()
score_t = ModelTransfuser.score_t.detach().cpu().numpy()
dx = ModelTransfuser.dx_t.detach().cpu().numpy()
# Rebinds the global `t` (previously the linspace from the model cell).
t = ModelTransfuser.t.detach().cpu().numpy()

In [None]:
# Per-step factor applied to the score: -0.5 * sigma**(2 t) / T (same formula as
# in the KDE animation cell). It is defined here because this cell and the next
# use it, and they come before that cell; without this line a fresh
# Restart & Run All raises NameError.
scaling_factor = -0.5*ModelTransfuser.sigma**(2*t)*(1/T)

# Duplicate the factor into two columns, one per plotted feature.
np.repeat(scaling_factor[:,np.newaxis], 2, axis=1)

In [92]:
# Scaled score of sample 0 for the first two features, over every step except the last.
# NOTE(review): the product only broadcasts if len(scaling_factor) ==
# score_t.shape[1] - 1. Confirm.
# Reads score_t from the model again instead of the NumPy copy, so it reflects
# the most recent sample() call.
a = ModelTransfuser.score_t[0,:-1,:2].detach().numpy()*np.repeat(scaling_factor[:,np.newaxis], 2, axis=1)

In [None]:
# Positions of sample 0 (first two features) at steps 0 and 1, then the negated
# update for step 0. If x_1 = x_0 - dx_0, the second line equals the first plus the third.
print(denoising_data[0,0,:2])
print(denoising_data[0,1,:2])

print(-dx[0,0,:2])

In [None]:
# Step-0 position of sample 0 minus its scaled score `a` (first two features).
# Presumably compared against the step-1 position printed above.
denoising_data[0,0,:2]-a[0,:2]

In [None]:
# Per-step factor applied to the score: -0.5 * sigma**(2 t) / T.
scaling_factor = -0.5*ModelTransfuser.sigma**(2*t)*(1/T)

plt.rcParams["animation.html"] = "jshtml"
plt.rcParams['figure.dpi'] = 500
plt.ioff()
# Set the seaborn style before the axes exist. Every frame, including the first
# one drawn by FuncAnimation, then shares the same style.
sns.set_style("white")
fig, ax = plt.subplots()

def animate(i):
    """Redraw `ax` with a KDE of the first two features of all samples at step ``i``."""
    ax.clear()
    ax.set_xlim(-5, 5)
    ax.set_ylim(-5, 5)
    ax.set_xlabel(r'$\alpha_{IMF}$')
    ax.set_ylabel(r'$\log_{10}N$')
    ax.set_title(f'Denoising Timestep: {i}')

    sns.kdeplot(x=denoising_data[:,i,0], y=denoising_data[:,i,1], cmap='Blues',
                fill=True, levels=100, bw_adjust=0.6, ax=ax)
    fig.tight_layout()

# NOTE(review): only the first 25 denoising steps are animated. Confirm this is intended.
ani2 = matplotlib.animation.FuncAnimation(fig, animate, frames=25)

writer = matplotlib.animation.PillowWriter(fps=5, bitrate=-1)
os.makedirs('plots', exist_ok=True)
ani2.save('plots/test_big.gif', writer=writer)
# Free the 500-dpi figure once the GIF is written.
plt.close(fig)

In [None]:
# Position of sample 0 (first two features) at denoising step 1.
denoising_data[0,1,:2]

In [41]:
plt.rcParams["animation.html"] = "jshtml"
plt.rcParams['figure.dpi'] = 500
plt.ioff()
fig, ax = plt.subplots()

def animate(i):
    """Redraw `ax` with the path of sample 0 (first two features) up to step ``i``.

    Each visited point is marked with an ``x``. An arrow of displacement ``-dx``
    (in data units) is drawn from each point.
    """
    ax.clear()
    ax.set_xlim(-5, 5)
    ax.set_ylim(-5, 5)
    ax.set_xlabel(r'$\alpha_{IMF}$')
    ax.set_ylabel(r'$\log_{10}N$')
    ax.set_title(f'Denoising Timestep: {i}')

    xs = denoising_data[0, :i+1, 0]
    ys = denoising_data[0, :i+1, 1]
    # angles='xy' together with scale_units='xy' and scale=1 makes each arrow span
    # exactly (-dx_x, -dx_y) in data coordinates. With the default angles='uv' the
    # direction is measured in screen space and is skewed on non-square axes.
    ax.quiver(xs, ys, -dx[0, :i+1, 0], -dx[0, :i+1, 1],
              angles='xy', scale_units='xy', scale=1, width=0.003)
    ax.scatter(xs, ys, s=20, marker='x', color='black')
    fig.tight_layout()

ani2 = matplotlib.animation.FuncAnimation(fig, animate, frames=25)

writer = matplotlib.animation.PillowWriter(fps=5, bitrate=-1)
os.makedirs('plots', exist_ok=True)
ani2.save('plots/test_quiver.gif', writer=writer)
# Free the 500-dpi figure once the GIF is written.
plt.close(fig)

In [None]:
# Update of sample 0, feature 0, at step 0.
dx[0,0,0]