<a href="https://colab.research.google.com/github/lucyvost/distorted_diffusion/blob/main/distorted_conditional_diffusion.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
#install dependencies
!pip install rdkit


Collecting rdkit
  Downloading rdkit-2024.3.5-cp310-cp310-manylinux_2_28_x86_64.whl.metadata (3.9 kB)
Downloading rdkit-2024.3.5-cp310-cp310-manylinux_2_28_x86_64.whl (33.1 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m33.1/33.1 MB[0m [31m32.1 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: rdkit
Successfully installed rdkit-2024.3.5


In [2]:
from rdkit import Chem
import numpy as np
import random



1. Load any 3D molecular dataset in the dictionary form that EDM takes as input (here `train.npz`, `test.npz` and `valid.npz`).

In [None]:
# The three split archives (train / test / valid) must be uploaded to Colab
# under content/ before this cell is run.
# NOTE(review): these paths are relative to the working directory; confirm they
# resolve to where the files were actually uploaded.
train = np.load('content/train.npz')
test = np.load('content/test.npz')
valid = np.load('content/valid.npz')

# File name for the merged archive; only used by the optional save at the bottom.
combined_file = 'all_data.npz'

# Merge the splits array by array: for every key found in the test archive,
# stack the train, test and valid arrays (in that order) along the first axis.
combined_data = {
    name: np.concatenate([archive[name] for archive in (train, test, valid)], axis=0)
    for name in test.keys()
}

# Uncomment to write the merged arrays to disk as a single npz file.
#np.savez(combined_file, **combined_data)

2. Add some distorted molecules to the dataset

In [None]:
def scramble_coordinates_3d(coordinates, max_scramble=0.25):
    """
    Jitter every 3D coordinate by an independent uniform random offset.

    Each of the x, y and z components of every point is shifted by its own
    value drawn with random.uniform(-max_scramble, max_scramble), so results
    are reproducible after random.seed(...).

    Parameters:
    - coordinates: An iterable of (x, y, z) triples, e.g. a list of tuples or
      an array of shape (n, 3). Any row that does not unpack into exactly
      three values raises ValueError.
    - max_scramble: The largest absolute offset applied to any single
      component, in the same units as the coordinates.

    Returns:
    - A new numpy array of shape (n, 3) holding the scrambled coordinates.
      The input is not modified. Empty input gives an array of shape (0, 3).
    """
    # Offsets are drawn in x, y, z order for each point in turn; keeping that
    # order keeps seeded runs identical to the original per-axis implementation.
    scrambled_coordinates = [
        [component + random.uniform(-max_scramble, max_scramble) for component in (x, y, z)]
        for x, y, z in coordinates
    ]

    # reshape(-1, 3) is a no-op for non-empty input and turns the empty case
    # into a well-formed (0, 3) array instead of a flat (0,) one.
    return np.array(scrambled_coordinates).reshape(-1, 3)

In [None]:
import math
from tqdm import tqdm
import torch

# Build the conditional dataset: every original molecule is kept with
# scramble = 0, and a small random fraction additionally gets a distorted copy
# labelled with the scramble magnitude that was applied to it.
# Results depend on the state of the `random` module; call random.seed(...)
# beforehand if the dataset needs to be reproducible.
DISTORT_PROBABILITY = 0.02  # chance that a molecule also gets a distorted copy
MAX_SCRAMBLE = 0.25         # upper bound for the per-molecule scramble magnitude

# Every key must start as an empty list so the appends below have a target.
conditional_dict = {'positions': [], 'charges': [], 'scramble': [], 'num_atoms': []}

#load all the data first so we don't need to re-load the whole dictionary every time
all_positions = combined_data['positions']
all_charges = combined_data['charges']
for idx, mol in tqdm(enumerate(all_charges), total=len(all_charges)):
    # Entries with charge 0 are treated as padding, so the atom count is the
    # number of non-zero charges. (len(np.where(...)) would instead return the
    # length of the index tuple, i.e. the number of array dimensions.)
    num_atoms = int(np.count_nonzero(mol))

    # Undistorted original.
    conditional_dict['positions'].append(all_positions[idx])
    conditional_dict['charges'].append(all_charges[idx])
    conditional_dict['scramble'].append(0)
    conditional_dict['num_atoms'].append(num_atoms)

    #one slightly messed up version
    if random.random() < DISTORT_PROBABILITY:
        slight_mess = random.uniform(0, MAX_SCRAMBLE)
        clean_positions = all_positions[idx]

        # Only the real atoms are perturbed.
        # assumes the real atoms occupy the first num_atoms rows and the
        # remaining rows are padding — TODO confirm against the dataset
        scrambled = scramble_coordinates_3d(clean_positions[:num_atoms], max_scramble=slight_mess)

        # Re-pad with zero rows so the distorted copy has as many rows as the
        # clean entry (derived from the data rather than hard-coded to 100).
        to_pad = len(clean_positions) - num_atoms
        new_coords = np.append(scrambled, np.zeros([to_pad, 3])).reshape((-1, 3))

        conditional_dict['positions'].append(new_coords)
        conditional_dict['charges'].append(all_charges[idx])
        conditional_dict['scramble'].append(slight_mess)
        conditional_dict['num_atoms'].append(num_atoms)



3. Now split your dictionary into train/test/valid sets and save them

In [None]:
# Shuffle the entry indices once; the splits below are slices of this single
# permutation of the dictionary built in the previous step (conditional_dict).
num_entries = len(conditional_dict['charges'])
shuffled_indices = random.sample(range(num_entries), k=num_entries)


def split_points(n):
    """Return the integer cut positions at 66%, 76% and 86% of n entries."""
    return [int(np.floor(n * x)) for x in [0.66, 0.76, 0.86]]


# Convert the per-molecule lists to arrays once so they can be indexed by lists of indices.
positions_data, charges_data, scramble_data, num_atoms_data = map(np.array,
    (conditional_dict['positions'], conditional_dict['charges'],
     conditional_dict['scramble'], conditional_dict['num_atoms']))

# Define split indices
half_train_end, half_test_start, half_valid_start = split_points(num_entries)


def create_dict(indices):
    """Return a dataset dictionary restricted to the given entry indices.

    Parameters:
    - indices: A list of integer positions into the full dataset arrays.

    Returns:
    - A dict with the keys 'positions', 'charges', 'scramble' and 'num_atoms',
      each holding the selected rows of the corresponding array.
    """
    return {
        'positions': positions_data[indices],
        'charges': charges_data[indices],
        'scramble': scramble_data[indices],
        'num_atoms': num_atoms_data[indices]
    }


# Create train, test, and valid dictionaries from disjoint slices of the
# permutation. The test slice stops where the valid slice starts; slicing it
# to the end would put every validation entry in the test set as well.
# NOTE(review): entries between the 66% and 76% marks are not assigned to any
# split — confirm whether that gap is intentional.
train_dict = create_dict(shuffled_indices[:half_train_end])
test_dict = create_dict(shuffled_indices[half_test_start:half_valid_start])
valid_dict = create_dict(shuffled_indices[half_valid_start:])

# Save datasets: one npz archive per split, written to the working directory.
for split, data in zip(['train_dist', 'test_dist', 'valid_dist'], [train_dict, test_dict, valid_dict]):
    np.savez(f'{split}.npz', **data)


You now have a set of dictionaries you can use to train a conditional model on the 'scramble' property!