In [None]:
import torch
import numpy as np

from symgraph.model.model import GNN
from symgraph.dataset import create_dataloaders

import matplotlib.pyplot as plt
import pandas as pd

from torch_geometric.data import Batch

In [None]:
def si_al_ratio_to_al_proportion(ratio, n_atoms):
    '''
    Given a Si/Al ratio, return the NUMBER of Al atoms among ``n_atoms`` (Si + Al) atoms.

    Despite the function name, the result is a count, not a fraction:
    with n_Si / n_Al = ratio and n_Si + n_Al = n_atoms, it follows that
    n_Al = n_atoms / (ratio + 1). Divide the result by ``n_atoms`` to get
    the Al proportion. The result is generally not an integer.

    Args:
        ratio: Si/Al ratio (must not be -1).
        n_atoms: total number of Si + Al atoms.

    Returns:
        float: the (possibly fractional) number of Al atoms.
    '''
    return n_atoms / (ratio + 1)

In [None]:
si_al_ratio_to_al_proportion(ratio=6.5, n_atoms=48)

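Conversion factors multiplied into the loadings read from the `iso_model/` CSV files, listed by framework and number of Al atoms: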

MFI, 0 Al: 0.1733675297
MFI, 1 Al: 0.1727122100
MFI, 2 Al: 0.1720618258
MFI, 3 Al: 0.1714187342
MFI, 4 Al: 0.1707756425

MOR, 7 Al: 0.3292455306

In [None]:
# Conversion factors from the "conversion factors" table; MFI factors are indexed by
# the number of Al atoms (0-4).
# NOTE(review): this cell previously used an undefined name `y` (`y[1]`, `y[3]`); it is
# assumed to have been this MFI table indexed by Al count — confirm.
MFI_CONVERSION = [0.1733675297, 0.1727122100, 0.1720618258, 0.1714187342, 0.1707756425]
MOR_7AL_CONVERSION = 0.3292455306


def load_isotherm(path, factor=None, **read_csv_kwargs):
    """
    Load a two-column isotherm CSV (pressure in kPa, loading).

    Args:
        path: CSV file path.
        factor: optional multiplier applied to the loading column; when None the
            loading is returned unchanged.
        **read_csv_kwargs: forwarded to ``pd.read_csv`` (e.g. delimiter/decimal).

    Returns:
        tuple: (pressure in Pa, loading) as 1-D numpy arrays.
    """
    pressure_kpa, loading = pd.read_csv(path, **read_csv_kwargs).values.T
    if factor is not None:
        loading = loading * factor
    return 1000 * pressure_kpa, loading  # kPa -> Pa


# MFI Si/Al 95 uses the 1-Al factor and Si/Al 31 the 3-Al factor.
p_mfi_95, q_mfi_95 = load_isotherm('iso_model/MFI_SiAl95_Dunne.csv', MFI_CONVERSION[1])
p_mfi_31, q_mfi_31 = load_isotherm('iso_model/MFI_SiAl31_Dunne.csv', MFI_CONVERSION[3])
p_mor_5p8, q_mor_5p8 = load_isotherm('iso_model/MOR_SiAl5p8_Delgado.csv', MOR_7AL_CONVERSION)
# NOTE(review): MOR Si/Al 6.5 and LTA loadings are used without conversion — confirm intended.
p_mor_6p5, q_mor_6p5 = load_isotherm('iso_model/MOR_SiAl6p5_Kwon.csv')
p_lta_1, q_lta_1 = load_isotherm('iso_model/LTA_SiAl_1_Parra.csv', delimiter=';', decimal=',')

In [None]:
# Overview of all target isotherms on a shared log-pressure axis.
fig, ax = plt.subplots()
target_series = [
    (p_mfi_95, q_mfi_95, 'MFI Si/Al 95'),
    (p_mfi_31, q_mfi_31, 'MFI Si/Al 31'),
    (p_mor_5p8, q_mor_5p8, 'MOR Si/Al 5.8'),
    (p_mor_6p5, q_mor_6p5, 'MOR Si/Al 6.5'),
    (p_lta_1, q_lta_1, 'LTA Si/Al 1'),
]
for pressure, loading, series_label in target_series:
    ax.scatter(pressure, loading, label=series_label)
ax.set_xlabel('Pressure (Pa)')
ax.set_ylabel('Loading (mol/kg)')

ax.legend()
ax.set_xscale('log')
ax.set_ylim(0, 6)

In [None]:
def _log10_pressure(pressure_pa):
    """Return log10 of a 1-D pressure array as an (n, 1) torch tensor."""
    column = torch.tensor(pressure_pa).unsqueeze(1)
    return torch.log10(column)


p_in_mfi_95 = _log10_pressure(p_mfi_95)
p_in_mfi_31 = _log10_pressure(p_mfi_31)
p_in_mor_5p8 = _log10_pressure(p_mor_5p8)
p_in_mor_6p5 = _log10_pressure(p_mor_6p5)
p_in_lta_1 = _log10_pressure(p_lta_1)


In [None]:
# Model evaluation grid: 100 evenly spaced values from 1 to 7 (log10 pressure), shape (100, 1).
p = torch.linspace(1, 7, 100).to('cuda').unsqueeze(-1)
dp = p[1] - p[0]
# 101 bin edges: every grid point shifted down by half a step, plus one extra edge
# half a step above the last grid point.
upper_edge = p[-1:] + dp
p_in = torch.cat((p, upper_edge), dim=0) - dp / 2


In [None]:
import spglib

def lattice_params_to_vectors(a, b, c, alpha, beta, gamma):
    """
    Build the three lattice vectors (as rows) from cell lengths and angles.

    The first vector lies along x, the second in the xy-plane, and the third
    is fixed by the remaining angles.

    Args:
        a, b, c (float): Lattice lengths.
        alpha, beta, gamma (float): Lattice angles in degrees.

    Returns:
        np.ndarray: (3, 3) array whose rows are the vectors a, b, c.
    """
    alpha_rad, beta_rad, gamma_rad = np.radians([alpha, beta, gamma])
    cos_alpha = np.cos(alpha_rad)
    cos_beta = np.cos(beta_rad)
    cos_gamma = np.cos(gamma_rad)
    sin_gamma = np.sin(gamma_rad)

    a_vec = np.array([a, 0, 0])
    b_vec = np.array([b * cos_gamma, b * sin_gamma, 0])

    c_x = c * cos_beta
    c_y = c * (cos_alpha - cos_beta * cos_gamma) / sin_gamma
    c_z = c * np.sqrt(1 - cos_beta**2 - c_y**2 / c**2)
    c_vec = np.array([c_x, c_y, c_z])

    return np.vstack((a_vec, b_vec, c_vec))


In [None]:
# Same graph construction for every framework; only the first (training) loader is kept.
loader_kwargs = dict(edge_type='radius', radius=8)
tr_mor, *_ = create_dataloaders(['MOR'], **loader_kwargs)
tr_mfi, *_ = create_dataloaders(['MFI'], **loader_kwargs)
tr_lta, *_ = create_dataloaders(['LTA'], **loader_kwargs)

In [None]:
def _match_symmetry_permutation(reference, transformed, lattice):
    """
    Match every reference site to the transformed atom that lands on it.

    Args:
        reference (np.ndarray): (N, 3) fractional positions.
        transformed (np.ndarray): (N, 3) fractional positions after a symmetry operation.
        lattice (np.ndarray): (3, 3) lattice vectors as rows.

    Returns:
        np.ndarray: ``perm`` of shape (N,) such that ``transformed[perm[i]]`` is the
        atom closest (under periodic boundary conditions) to ``reference[i]``.

    Raises:
        ValueError: if the nearest-neighbour assignment is not one-to-one.
    """
    # Pairwise fractional differences wrapped to the nearest periodic image, so that
    # e.g. 0.999 and 0.001 are treated as neighbours.
    delta = reference[:, None, :] - transformed[None, :, :]
    delta -= np.round(delta)
    # Cartesian distances (rows of `lattice` are the cell vectors).
    distances = np.linalg.norm(delta @ lattice, axis=-1)
    perm = distances.argmin(axis=1)
    if len(np.unique(perm)) != len(perm):
        raise ValueError('symmetry operation does not map the atoms one-to-one onto themselves')
    return perm


def get_symmetry_permutation_indices(lens, angs, positions, symprec=0.05):
    """
    Compute the atom-index permutations induced by the structure's symmetry operations.

    Atoms are matched by periodic nearest-neighbour distance rather than by sorting
    the transformed coordinates. Sorting floats after ``mod 1`` is brittle: round-off,
    wrap-around at 0/1 and the ``symprec`` tolerance can all reorder atoms.

    Args:
        lens (sequence of 3 floats): Lattice lengths (a, b, c).
        angs (sequence of 3 floats): Lattice angles (alpha, beta, gamma) in degrees.
        positions (np.ndarray): (N, 3) array of fractional atomic positions.
        symprec (float): Tolerance passed to ``spglib.get_symmetry``.

    Returns:
        tuple:
            - np.ndarray of shape (n_ops, N): row ``k`` is ``perm`` such that operation
              ``k`` moves atom ``perm[i]`` onto the site of atom ``i``.
            - np.ndarray of shape (N, 3): the positions sorted with ``np.lexsort``.
              All permutation indices refer to this sorted order, not the input order.

    Raises:
        ValueError: if spglib returns no symmetry, or an operation does not map the
            atom set one-to-one onto itself.
    """
    lattice = lattice_params_to_vectors(*lens, *angs)

    # Sort positions for a reproducible ordering; returned indices refer to it.
    sorted_indices = np.lexsort(positions.T)
    positions = positions[sorted_indices]

    # All atoms are given the same type, so only geometry determines the symmetry.
    num_atoms = len(positions)
    dummy_types = np.zeros(num_atoms, dtype=int)
    structure = (lattice, positions, dummy_types)

    dataset = spglib.get_symmetry(structure, symprec=symprec)
    if dataset is None:
        raise ValueError('spglib could not determine the symmetry of the structure')

    index_permutations = []
    for rotation, translation in zip(dataset['rotations'], dataset['translations']):
        new_positions = np.mod(np.dot(positions, rotation.T) + translation, 1)
        index_permutations.append(_match_symmetry_permutation(positions, new_positions, lattice))

    return np.array(index_permutations), positions

def remove_duplicates_symmetrically(atoms, perms, keep_prob_max=0.1, al_frac_saturation=0.075):
    """
    Drop configurations that are symmetry-equivalent to one already kept.

    Exact duplicate rows are removed first with ``torch.unique``, which also sorts the
    rows lexicographically. Rows are then scanned in that order. A row counts as a
    symmetry duplicate if any of its images ``row[perms]`` equals a row already kept.
    A symmetry duplicate is still kept with probability
    ``keep_prob_max * min(al_frac_saturation, mean(row)) / al_frac_saturation``.
    By default that is at most 10 %, reached once the row mean is >= 0.075.

    Args:
        atoms: (P, N) tensor of 0/1 site labels, one configuration per row.
        perms: (n_ops, N) integer index tensor of symmetry permutations.
        keep_prob_max: maximum probability of keeping a symmetry duplicate.
        al_frac_saturation: row mean at which the keep probability saturates.

    Returns:
        (U, N) tensor of retained configurations (empty if ``atoms`` is empty).
    """
    if atoms.shape[0] == 0:
        return atoms

    atoms = torch.unique(atoms, dim=0)

    kept = atoms[:1]
    for i in range(1, len(atoms)):
        candidate = atoms[i]
        # (n_ops, N): every symmetry image of the candidate.
        images = candidate[perms]
        # Broadcast (n_ops, 1, N) against (1, U, N) to compare every image with every kept row.
        is_duplicate = (images.unsqueeze(1) == kept.unsqueeze(0)).all(dim=-1).any()

        if is_duplicate:
            # Occasionally keep a symmetry duplicate; the chance grows with the Al fraction.
            keep_prob = keep_prob_max * min(al_frac_saturation, candidate.mean().item()) / al_frac_saturation
            if np.random.uniform() < keep_prob:
                is_duplicate = False

        if not is_duplicate:
            kept = torch.cat([kept, candidate.unsqueeze(0)], dim=0)

    return kept


        


In [None]:
# First graph of each framework's training set, used as the template for GA individuals.
mor_example, mfi_example, lta_example = (loader.dataset[0] for loader in (tr_mor, tr_mfi, tr_lta))

In [None]:
# Lattice lengths, lattice angles and positions for each framework.
mor_len, mor_ang, mor_pos = (
    np.load(path)
    for path in ('Data_numpy/MOR/lens.npy', 'Data_numpy/MOR/angs.npy', 'Data_numpy/MOR/pos.npy')
)

mfi_len, mfi_ang, mfi_pos = (
    np.load(path)
    for path in ('Data_numpy/MFI/lens.npy', 'Data_numpy/MFI/angs.npy', 'Data_numpy/MFI/pos.npy')
)

lta_len, lta_ang, lta_pos = (
    np.load(path)
    for path in ('Data_numpy/LTA/lens.npy', 'Data_numpy/LTA/angs.npy', 'Data_numpy/LTA/pos.npy')
)

In [None]:
# NOTE(review): the permutations index the lexsorted positions, but the GA applies them
# to site vectors in the dataset's atom order (tr_*.dataset[0]) — confirm both orders agree.
mor_perm, mor_pos = get_symmetry_permutation_indices(lens=mor_len, angs=mor_ang, positions=mor_pos)
mfi_perm, mfi_pos = get_symmetry_permutation_indices(lens=mfi_len, angs=mfi_ang, positions=mfi_pos)
lta_perm, lta_pos = get_symmetry_permutation_indices(lens=lta_len, angs=lta_ang, positions=lta_pos)


In [None]:
import yaml
with open('bestmodel/sym_int/config.yaml') as config_file:
    # NOTE(review): yaml.safe_load would be preferable if the config only uses plain YAML types.
    config = yaml.load(config_file, Loader=yaml.FullLoader)

model = GNN(**config['model']).cuda()
trained_weights = torch.load('bestmodel/sym_int/final.pth')
model.load_state_dict(trained_weights)

In [None]:
def interp(x: torch.Tensor, xp: torch.Tensor, fp: torch.Tensor, dim: int=-1, extrapolate: str='constant') -> torch.Tensor:
    """One-dimensional linear interpolation between monotonically increasing sample
    points, with extrapolation beyond sample points.

    Returns the one-dimensional piecewise linear interpolant to a function with
    given discrete data points :math:`(xp, fp)`, evaluated at :math:`x`.

    Args:
        x: The :math:`x`-coordinates at which to evaluate the interpolated
            values.
        xp: The :math:`x`-coordinates of the data points, must be increasing
            along ``dim``. Its other dimensions must match ``x`` (as required by
            ``torch.searchsorted``).
        fp: The :math:`y`-coordinates of the data points, same shape as `xp`.
        dim: Dimension across which to interpolate.
        extrapolate: How to handle values outside the range of `xp`. Options are:
            - 'linear': Extrapolate linearly beyond range of xp values.
            - 'constant': Use the boundary value of `fp` for `x` values outside `xp`.

    Returns:
        The interpolated values, same size as `x`.

    Raises:
        ValueError: if ``extrapolate`` is not 'constant' or 'linear'.
    """
    if extrapolate not in ('constant', 'linear'):
        raise ValueError(f"extrapolate must be 'constant' or 'linear', got {extrapolate!r}")

    # Work on the last axis; searchsorted expects contiguous inputs.
    x = x.movedim(dim, -1).contiguous()
    xp = xp.movedim(dim, -1).contiguous()
    fp = fp.movedim(dim, -1)

    m = torch.diff(fp) / torch.diff(xp)  # slope of each segment
    b = fp[..., :-1] - m * xp[..., :-1]  # offset of each segment
    # indices[i] = k with xp[k-1] < x[i] <= xp[k]; 0 below range, len(xp) above range.
    indices = torch.searchsorted(xp, x, right=False)

    if extrapolate == 'constant':
        # Pad with flat segments so index 0 / len(xp) return the boundary values.
        zeros = torch.zeros_like(m[..., :1])
        m = torch.cat([zeros, m, zeros], dim=-1)
        b = torch.cat([fp[..., :1], b, fp[..., -1:]], dim=-1)
    else:  # 'linear': reuse the first/last segment outside the range.
        indices = torch.clamp(indices - 1, 0, m.shape[-1] - 1)

    values = m.gather(-1, indices) * x + b.gather(-1, indices)

    return values.movedim(-1, dim)

In [None]:
class GeneticAlgorithm:
    """
    Genetic algorithm over Al/Si site assignments, scored by a surrogate model.

    Each individual is a row of 0/1 labels with one entry per atom of ``example``
    (1 = Al, 0 = Si). The model predicts an isotherm for every individual, and
    ``fitness`` compares it with a target isotherm.
    """

    def __init__(self, model):
        self.model = model
        self.model.eval()

    def run(self, pres, target_isotherm, example, perms, pop_size=100, generations=50, al_prop=0.2, top_n=10, elite_frac=0.15, zeo='RHO'):
        """
        Evolve a population and return its ``top_n`` fittest individuals.

        Each generation keeps the elite, then adds:
          - elite copies with one site flipped,
          - mutated copies of the top ``2 * n_elite``,
          - scrambled copies of the top ``2 * n_elite``.
        Symmetry duplicates are then removed, and the population is refilled with
        random individuals up to ``pop_size``.

        Args:
            pres: pressures at which the target is given (log10 Pa in this notebook).
            target_isotherm: target loadings at ``pres``.
            example: template graph; its ``pos`` fixes the number of sites.
            perms: (n_ops, N) symmetry permutations used for de-duplication.
            pop_size: target population size.
            generations: number of generations to evolve.
            al_prop: per-site Al probability for randomly initialised individuals.
            top_n: number of individuals to return.
            elite_frac: fraction of ``pop_size`` kept unchanged each generation.
            zeo: framework code forwarded to ``fitness``.

        Returns:
            (min(top_n, P), N) tensor of individuals, sorted by descending fitness.
        """
        num_atoms = example.pos.shape[0]
        n_elite = int(elite_frac * pop_size)

        population = self.initialize_population(pop_size, num_atoms, al_prop)

        # Evaluate population
        isotherms = self.predict_isotherms_model(example, pres, population)
        fitness = self.fitness(isotherms, target_isotherm, population, zeo)

        for gen in range(generations):
            sort_idx = torch.argsort(fitness, descending=True)
            population = population[sort_idx]
            fitness = fitness[sort_idx]

            elite_pop = population[:n_elite]
            mutate_pop = population[:2 * n_elite]

            mutated_pop = self.mutate(mutate_pop.clone())
            scrambled_pop = self.scramble(mutate_pop.clone())
            # Change one Al to Si or vice versa.
            changed_pop = self.add_or_remove_al(elite_pop.clone())

            new_pop = torch.cat([elite_pop, changed_pop, mutated_pop, scrambled_pop], dim=0)
            unique_pop = remove_duplicates_symmetrically(new_pop, perms)

            n_missing = pop_size - len(unique_pop)
            if n_missing > 0:
                fresh_pop = self.initialize_population(n_missing, num_atoms, al_prop)
                unique_pop = torch.cat([unique_pop, fresh_pop], dim=0)

            population = unique_pop

            isotherms = self.predict_isotherms_model(example, pres, population)
            fitness = self.fitness(isotherms, target_isotherm, population, zeo)

            print(f'Generation {gen}: Best fitness: {fitness.max()}')

        # The final population is evaluated but was never sorted; sort it so the
        # returned slice really contains the fittest individuals.
        final_order = torch.argsort(fitness, descending=True)
        return population[final_order][:top_n]

    def add_or_remove_al(self, population):
        """
        Flip exactly one site per individual.

        If the individual has any Al, then with probability 0.6 one random Al becomes Si.
        Otherwise one random Si becomes Al. An individual with no Si left always loses
        an Al, since there is nothing to convert. The input tensor is not modified.
        """
        population = population.clone()

        for i in range(len(population)):
            al_idx = torch.where(population[i] == 1)[0]
            si_idx = torch.where(population[i] == 0)[0]
            if len(al_idx) > 0 and (len(si_idx) == 0 or np.random.rand() < 0.6):
                # Pick one Al to change to Si.
                idx = np.random.choice(al_idx.cpu())
                population[i][idx] = 0
            else:
                # Pick one Si to change to Al.
                idx = np.random.choice(si_idx.cpu())
                population[i][idx] = 1

        return population

    def mutate(self, population, min_mutations=1, max_mutations=0.05):
        """
        Flip a random set of sites in every individual.

        The number of flips is drawn from ``np.random.randint(min_mutations, M)``, with
        ``M = max(int(max_mutations * N), 2)``. The upper bound is exclusive, so at most
        ``M - 1`` sites are flipped. The input tensor is not modified.
        """
        max_mutations = max(int(max_mutations * len(population[0])), 2)
        population = population.clone()
        for i in range(len(population)):
            n_mutations = np.random.randint(min_mutations, max_mutations)
            # Choose n_mutations distinct sites and flip Al <-> Si on each.
            flip_indices = torch.randperm(len(population[i]))[:n_mutations]
            population[i][flip_indices] = 1 - population[i][flip_indices]

        return population

    def scramble(self, population):
        """Randomly permute the sites of every individual (preserves its Al count)."""
        population = population.clone()
        for i in range(len(population)):
            indices = torch.randperm(len(population[i]))
            population[i] = population[i][indices]

        return population

    @torch.no_grad()
    def predict_isotherms_model(self, example, pres, population):
        """
        Predict the loading of every individual at the pressures ``pres``.

        The model is evaluated on the notebook-level edge grid ``p_in``. Its output is
        integrated with ``torch.cumulative_trapezoid`` and then interpolated from the
        notebook-level grid ``p`` onto ``pres``. Both ``p`` and ``p_in`` are globals
        that must be defined before calling.

        NOTE(review): ``cumulative_trapezoid`` over the edges ``p_in`` yields values at
        ``p_in[1:]`` (= ``p + dp/2``), but they are interpolated as if located at ``p``.
        Confirm this half-step offset matches how the model was trained.

        Returns:
            (P, n) tensor of predicted loadings, one row per individual.
        """
        self.model.eval()
        data_list = []
        for i in range(len(population)):
            data = example.clone()
            data.x = population[i].unsqueeze(1).float()
            data_list.append(data)

        data = Batch.from_data_list(data_list).to('cuda')
        _, q_prime = self.model(data, p_in)
        q = torch.cumulative_trapezoid(q_prime, p_in.squeeze(-1), dim=1)
        q_pred = interp(pres[None].squeeze(-1).repeat(q.shape[0], 1), p[None].squeeze(-1).repeat(q.shape[0], 1), q)
        return q_pred

    def fitness(self, isotherms, target_isotherm, population, zeo='RHO'):
        """
        Score each individual (higher is better).

        Each predicted isotherm may first be scaled up, never down, by a per-individual
        factor. The factor is the mean over pressures of the clamped ratio
        ``target / predicted``, clamped again to ``[1, max_clamp]``:
          - for 'RHO', 'LTA' and 'FAU', ratios are clamped to [1, 6] and ``max_clamp`` is 5;
          - otherwise ratios are capped at 1.5 and ``max_clamp`` is 1.3.

        The score is::

            -MAE(scaled) - 0.02 * MAE(unscaled) - al_penalty

        ``al_penalty`` depends on the Al fraction (capped to [0, 0.35]): it is
        ``2.83 * f`` below 0.25, and ``0.7075 + 0.05 * f`` at or above 0.25.

        Args:
            isotherms: (P, n) predicted loadings.
            target_isotherm: target loadings, broadcastable against ``isotherms``.
            population: (P, N) 0/1 site labels.
            zeo: framework code selecting the scaling limits.

        Returns:
            (P,) tensor of fitness values.
        """
        scale = target_isotherm / isotherms

        if zeo in ['RHO', 'LTA', 'FAU']:
            max_clamp = 5
            scale = scale.clamp(1, 6)
        else:
            scale = scale.clamp(max=1.5)
            max_clamp = 1.3

        scale = scale.mean(dim=1).unsqueeze(1)
        scale = torch.clamp(scale, 1, max_clamp)

        scaled_isotherms = isotherms * scale

        err = isotherms - target_isotherm
        scaled_err = scaled_isotherms - target_isotherm

        mae = torch.mean(err.abs(), dim=1)
        scaled_mae = torch.mean(scaled_err.abs(), dim=1)

        mean_pop = population.mean(dim=-1).clamp(min=0, max=0.35)
        al_pen = torch.where(mean_pop < 0.25, 2.83 * mean_pop, (0.25 * 2.83) + 0.05 * mean_pop)
        return -scaled_mae - 0.02 * mae - al_pen

    def initialize_population(self, population_size, num_atoms, al_prop):
        """Random (population_size, num_atoms) 0/1 tensor on CUDA; each site is Al with probability al_prop."""
        pop = torch.bernoulli(torch.ones(population_size, num_atoms) * al_prop).cuda()
        return pop
    




In [None]:
ga = GeneticAlgorithm(model=model)

In [None]:
pop = 200   # population size per generation
top_n = 25  # number of best structures returned per target

In [None]:
# Seed both RNGs the GA draws from so the search is reproducible: torch (initial
# populations, randperm) and numpy (mutation counts, site choices, duplicate keeping).
torch.manual_seed(0)
np.random.seed(0)

shared_ga_kwargs = dict(elite_frac=0.125, pop_size=pop, top_n=top_n)

best_mfi_95 = ga.run(p_in_mfi_95.cuda().float(), torch.tensor(q_mfi_95).float().cuda(), mfi_example, torch.tensor(mfi_perm), al_prop=0.05, zeo='MFI', **shared_ga_kwargs)
best_mfi_31 = ga.run(p_in_mfi_31.cuda().float(), torch.tensor(q_mfi_31).float().cuda(), mfi_example, torch.tensor(mfi_perm), al_prop=0.05, zeo='MFI', **shared_ga_kwargs)
# NOTE(review): the first three MOR Si/Al 5.8 points are dropped ([3:]); the reason is not stated here.
best_mor_5p8 = ga.run(p_in_mor_5p8.cuda().float()[3:], torch.tensor(q_mor_5p8).float().cuda()[3:], mor_example, torch.tensor(mor_perm), al_prop=0.05, zeo='MOR', **shared_ga_kwargs)
best_mor_6p5 = ga.run(p_in_mor_6p5.cuda().float(), torch.tensor(q_mor_6p5).float().cuda(), mor_example, torch.tensor(mor_perm), al_prop=0.05, zeo='MOR', **shared_ga_kwargs)
best_lta_1 = ga.run(p_in_lta_1.cuda().float(), torch.tensor(q_lta_1).float().cuda(), lta_example, torch.tensor(lta_perm), al_prop=0.75, zeo='LTA', **shared_ga_kwargs)

In [None]:
import os

# np.save does not create missing directories; make sure the output folder exists.
os.makedirs('ga_results', exist_ok=True)

np.save('ga_results/best_mfi_95.npy', best_mfi_95.cpu().numpy())
np.save('ga_results/best_mfi_31.npy', best_mfi_31.cpu().numpy())
np.save('ga_results/best_mor_5p8.npy', best_mor_5p8.cpu().numpy())
np.save('ga_results/best_mor_6p5.npy', best_mor_6p5.cpu().numpy())
np.save('ga_results/best_lta_1.npy', best_lta_1.cpu().numpy())

In [None]:

def create_data(population, example):
    """
    Build a CUDA batch of graphs that share ``example``'s structure.

    Each graph is a clone of ``example`` whose node features ``x`` are one row of
    ``population``, reshaped to a float column.
    """
    def _labelled_copy(labels):
        graph = example.clone()
        graph.x = labels.unsqueeze(1).float()
        return graph

    graphs = [_labelled_copy(labels) for labels in population]
    return Batch.from_data_list(graphs).to('cuda')

In [None]:
fig, ax = plt.subplots(1, 5, figsize=(25, 5))

# (GA result, panel title, reference Al count drawn as a dashed red line).
# Disabled panels from earlier runs: RHO Si/Al 3.4 (ref. 11), FAU Si/Al 1 (ref. 48),
# pure-Si MFI (ref. 0) — add entries here (and widen the figure) to re-enable them.
panels = [
    (best_mfi_95, 'MFI Si/Al 95', 1),
    (best_mfi_31, 'MFI Si/Al 31', 3),
    (best_mor_5p8, 'MOR Si/Al 5.8', 7),
    (best_mor_6p5, 'MOR Si/Al 6.5', 6.4),
    (best_lta_1, 'LTA Si/Al 1', 12),
]

for axis, (best, title, reference_n_al) in zip(ax, panels):
    # For each Al count, the number of returned structures with that count.
    n_al_values, n_al_counts = torch.unique(best.sum(dim=1), return_counts=True)
    axis.bar(n_al_values.cpu(), n_al_counts.cpu(), width=1)
    # Full-height line, so it stays visible whatever the bar heights.
    axis.axvline(reference_n_al, color='r', linestyle='--')
    axis.set_title(title)
    axis.set_xlabel('Number of Al atoms')
    axis.set_ylabel('Number of structures')