In [None]:
import numpy as np
import pandas as pd
import torch
from rdkit.Chem import MolFromSmiles, QED, rdMolDescriptors, Descriptors
from rdkit.Chem.rdFingerprintGenerator import GetMorganGenerator
from gauche.representations.fingerprints import fragments

# Read the ESOL table and discard every row that contains a missing value.
original = pd.read_csv('ESOL_dataset.txt', delimiter=',').dropna()
smiles = original['SMILES']
# Sanity check: the cleaned dataset is expected to hold exactly 1144 rows.
assert len(smiles) == 1144, f"Expected 1144 molecules, got {len(smiles)}"

# Parse every SMILES string into a molecule object. Rows that cannot be parsed
# are recorded and then removed from the dataset, so that `molecules`,
# `smiles` and `original` always describe the same rows in the same order.
molecules = []
invalid_smiles = []   # (position in the unfiltered `smiles`, SMILES string)
valid_positions = []  # positional row numbers that produced a molecule
for i, s in enumerate(smiles):
    try:
        mol = MolFromSmiles(s)
    except Exception:
        # Treat a parser-side error like an unparsable string, but (unlike a
        # bare `except:`) let KeyboardInterrupt / SystemExit propagate.
        mol = None
    if mol is None:
        invalid_smiles.append((i, s))
    else:
        molecules.append(mol)
        valid_positions.append(i)

if invalid_smiles:
    print(f"Warning: {len(invalid_smiles)} invalid SMILES strings found")
    # The feature and target code further down reads `smiles` and `original`
    # row by row; without this filter those would keep the invalid rows and
    # no longer line up with `molecules`.
    original = original.iloc[valid_positions].reset_index(drop=True)
    smiles = original['SMILES']

# Generate features: Morgan fingerprints (radius 3, 2048 bits) concatenated
# column-wise with the fragment descriptors from `fragments`.
generator = GetMorganGenerator(radius=3, fpSize=2048)
fps = [generator.GetFingerprint(mol) for mol in molecules]
fps_array = np.array(fps)
# Pass a plain list rather than the pandas Series: after `dropna` the Series
# index may have gaps, and positional access such as `smiles[i]` on a Series
# is label-based. NOTE(review): presumably `fragments` indexes its input by
# position -- confirm against the installed GAUCHE version.
frag = fragments(smiles=list(smiles))
domain = torch.from_numpy(np.concatenate((fps_array, frag), axis=1)).float()

# Build the target matrix. Each property is stored as an (N, 1) column and the
# columns are joined in the order logP, QED, TPSA, measured log-solubility.
logp = np.array([Descriptors.MolLogP(m) for m in molecules]).reshape(-1, 1)
qed = np.array([QED.qed(m) for m in molecules]).reshape(-1, 1)
tpsa = np.array([rdMolDescriptors.CalcTPSA(m) for m in molecules]).reshape(-1, 1)
# The solubility column is read straight from the dataset, not computed.
logSOL = np.array(original['measured log(solubility:mol/L)']).reshape(-1, 1)
target = torch.from_numpy(np.hstack((logp, qed, tpsa, logSOL))).float()

# Sanity checks on the assembled tensors: the row counts must match and the
# column counts must equal the expected number of features and targets.
assert domain.size(0) == target.size(0), "Mismatch between domain and target samples"
assert domain.size(1) == 2133, "Expected 2048 + 85 = 2133 features"
assert target.size(1) == 4, "Expected 4 target variables (logP, QED, TPSA, logSOL)"

# Report the final shapes.
print(f"Domain shape: {domain.shape}")
print(f"Target shape: {target.shape}")