# Train model

In [1]:
import matplotlib.pyplot as plt
import tensorflow as tf
import numpy as np
import torch

from load_qm9 import *
from display_mol import *

#load dataset
ds = qm9_load_tfdata()

## data management and preprocessing

### Moldecule dataloader for trainning in batchs of equal size

In [2]:
ATOM_TYPES = [1, 6, 7, 8, 9]
MAX_SIZE = 30 # max number of atoms per molecules


class MoleculeDataset:
    def __init__(self, tf_dataset):
        self.ds = tf_dataset
        self.molecules_list = list(self.ds.as_numpy_iterator()) # Convertir le dataset en liste pour pouvoir indexer
        self.N = tf_dataset.reduce(0, lambda x, _: x + 1).numpy()
        self.M_pdf = np.zeros(MAX_SIZE, dtype=np.int32)
        self.molecules_by_size = {i: [] for i in range(MAX_SIZE)}
        
        self._index_molecules()
        
        print(f"Dataset chargé : {self.N} molécules")
        print(f"Distribution des tailles : {self.M_pdf}")
    

    def _index_molecules(self):
        for idx, example in enumerate(self.molecules_list):
            n_atoms = example[1].shape[0]
            self.M_pdf[n_atoms] += 1
            self.molecules_by_size[n_atoms].append(idx)
    
    
    def get_batch(self, batch_size): 
        """
        return generator that give infinite patchs with molecules of equal sizes
        """ 
        valid_sizes = [M for M in range(MAX_SIZE) if self.M_pdf[M] >= batch_size]

        valid_probs = np.array([self.M_pdf[M] for M in valid_sizes])
        valid_probs = valid_probs / valid_probs.sum() 
        
        while True:
            M = np.random.choice(valid_sizes, p=valid_probs)
            
            available_indices = self.molecules_by_size[M]
            
            selected_indices = np.random.choice( # shuffle
                    available_indices, 
                    size=batch_size, 
                    replace=False
                )
            
            batch = [self.molecules_list[idx] for idx in selected_indices]
            
            yield self._collate_batch(batch)
    
    def _collate_batch(self, batch):
        """
        Combine une liste de molécules en un batch.
        """
        x = np.stack([mol[0] for mol in batch], axis=0)
        y = np.stack([mol[1] for mol in batch], axis=0)
        q = np.stack([mol[2] for mol in batch], axis=0)
        
        return [x, y, q]
    
    def get_epoch_batches(self, batch_size):
        """
        Génère des batches pour une époque complète.
        Chaque molécule apparaît exactement une fois par époque.
        """
        all_batches = []
        
        for M in range(MAX_SIZE):
            if self.M_pdf[M] == 0:
                continue
            
            indices = self.molecules_by_size[M].copy()
            
            np.random.shuffle(indices)
            
            for i in range(0, len(indices), batch_size):
                batch_indices = indices[i:i + batch_size]
                if len(batch_indices) == batch_size:
                    batch = [self.molecules_list[idx] for idx in batch_indices]
                    all_batches.append(batch)
        
        np.random.shuffle(all_batches)
        
        for batch in all_batches:
            yield self._collate_batch(batch)
    

molecules = MoleculeDataset(ds) # ~1min30


Dataset chargé : 133885 molécules
Distribution des tailles : [    0     0     0     2     4     5    12    21    70   193   527  1150
  2336  4259  7103 10646 14270 17394 17836 18336 12601 13189  4483  6362
   713  1923    59   356     0    35]


In [16]:


# Génération infinie de batches (pour l'entraînement)
batch_gen = molecules.get_batch(batch_size=32)

for i in range(5):
    batch = next(batch_gen)
    print([i.shape for i in batch])

"""
# Méthode 2 : Une epoch complète (chaque molécule vue exactement une fois)
print("\epoch complète:")
for epoch in range(2):
    print(f"\nÉpoque {epoch + 1}")
    batch_count = 0
    for batch in molecules.get_epoch_batches(batch_size=32, shuffle=True):
        batch_count += 1
        if batch_count <= 3:  # Afficher les 3 premiers batches
            print(f"  Batch {batch_count}: {batch['N'][0]} atomes, {len(batch['N'])} molécules")
    print(f"  Total: {batch_count} batches")"""

[(32, 9, 3), (32, 9), (32, 9)]
[(32, 17, 3), (32, 17), (32, 17)]
[(32, 16, 3), (32, 16), (32, 16)]
[(32, 17, 3), (32, 17), (32, 17)]
[(32, 21, 3), (32, 21), (32, 21)]


'\n# Méthode 2 : Une epoch complète (chaque molécule vue exactement une fois)\nprint("\\epoch complète:")\nfor epoch in range(2):\n    print(f"\nÉpoque {epoch + 1}")\n    batch_count = 0\n    for batch in molecules.get_epoch_batches(batch_size=32, shuffle=True):\n        batch_count += 1\n        if batch_count <= 3:  # Afficher les 3 premiers batches\n            print(f"  Batch {batch_count}: {batch[\'N\'][0]} atomes, {len(batch[\'N\'])} molécules")\n    print(f"  Total: {batch_count} batches")'

### one hot encoding for the molecule shape

In [17]:
def one_hot_encode(batch):
    x = torch.tensor(batch[0], dtype=torch.float32)
    e = torch.tensor(batch[1], dtype=torch.long)
    q = torch.tensor(batch[2], dtype=torch.float32)
    
    max_atom_type = max(ATOM_TYPES) + 1
    lookup = torch.full((max_atom_type,), -1, dtype=torch.long)
    
    for idx, atom_type in enumerate(ATOM_TYPES):
        lookup[atom_type] = idx
    
    y_indices = lookup[e]
    
    assert (y_indices >= 0).all(), "Types d'atomes inconnus détectés!"
    
    
    e = torch.nn.functional.one_hot(y_indices, num_classes=len(ATOM_TYPES)).float()
    
    return [x, e, q]

def one_hot_decode(batch):
    x = batch[0].clone().detach()
    e_onehot = batch[1].clone().detach()
    q = batch[2].clone().detach()
    
    e_indices = torch.argmax(e_onehot, dim=-1)  # (batch_size, n_atoms)
    atom_types_tensor = torch.tensor(ATOM_TYPES, dtype=torch.long)
    e = atom_types_tensor[e_indices]
    
    return [x, e, q]



def one_hot_decode_stochastic(batch, temperature=1.0):
    x = batch[0].clone().detach()
    e_soft = batch[1].clone().detach()
    q = batch[2].clone().detach()
    
    # Option 2: Sampling selon les probabilités (stochastic)
    e_probs = torch.nn.functional.softmax(e_soft / temperature, dim=-1)
    e_indices = torch.multinomial(
        e_probs.view(-1, e_probs.size(-1)), 
        num_samples=1
    ).view(e_probs.shape[:-1])
    
    atom_types_tensor = torch.tensor(ATOM_TYPES, dtype=torch.long)
    e = atom_types_tensor[e_indices]
    
    return [x, e, q]
    

mol = one_hot_encode(next(batch_gen))   
print([e.shape for e in mol])
print([e.shape for e in one_hot_decode(mol)])

[torch.Size([32, 21, 3]), torch.Size([32, 21, 5]), torch.Size([32, 21])]
[torch.Size([32, 21, 3]), torch.Size([32, 21]), torch.Size([32, 21])]


## Model

In [None]:
class UltraMagaDiffusion(torch.nn.Module):
    def __init__(self):
        super(UltraMagaDiffusion, self).__init__()
        self.T = 10
        s = 1e-5
        self.alpha = [(1-2*s)*(1 - (t/self.T))+s for t in range(self.T+1)]
        self.omega = [1-alpha**2 for alpha in self.alpha]
        self.L = 9
        self.lr = 1e-4
        self.num_class = len(ATOM_TYPES)

        self.init_weight()
    
    def init_weight(self):
        self.phi_e = []
        self.phi_inf = []
        self.phi_x = []
        self.phi_h = []
        nf = self.num_class + 2 # +2 for atom charge and t/T

        for l in range(self.L):
            self.phi_e += [torch.nn.Sequential(
                torch.nn.Linear(nf*2+2, nf), 
                torch.nn.SiLU(),
                torch.nn.Linear(nf, nf),
                torch.nn.SiLU(),
            )]
            self.phi_inf += [torch.nn.Sequential(
                torch.nn.Linear(nf, 1),
                torch.nn.Sigmoid(),
            )]
            self.phi_x += [torch.nn.Sequential(
                torch.nn.Linear(nf*2+2, nf),
                torch.nn.SiLU(),
                torch.nn.Linear(nf, nf),
                torch.nn.SiLU(),
                torch.nn.Linear(nf, 1),
            )]
            self.phi_h += [torch.nn.Sequential(
                torch.nn.Linear(nf*2, nf),
                torch.nn.SiLU(),
                torch.nn.Linear(nf, nf), # /!\ add h_i after !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!
            )]
        
    
    def forward(self, mols, t):
        (x, e, q) = mols 
        h = torch.cat([e, q.unsqueeze(2), torch.ones_like(q).unsqueeze(2)*t/self.T], dim=-1)
        x0 = x.clone()
        N = h.shape[1]
        diff = x[:, :, None, :] - x[:, None, :, :]   # (B, N, N, 3)
        a_ij = torch.sqrt(torch.sum(diff ** 2, dim=-1))[:, :, :, None]  # (B, N, N, 1)
        for l in range(self.L):
            diff = x[:, :, None, :] - x[:, None, :, :]   # (B, N, N, 3)
            d_ij = torch.sqrt(torch.sum(diff ** 2, dim=-1))[:, :, :, None]  # (B, N, N, 1) distance squared matrix

            # compute m_ij
            h_i = h[:, :, None, :].expand(-1, N, N, -1)  # (B, N, N, d)
            h_j = h[:, None, :, :].expand(-1, N, N, -1)  # (B, N, N, d)

            features = torch.cat([h_i, h_j, d_ij, a_ij], dim=-1)  # (B, N, N, 2d+2)
            m_ij = self.phi_e[l](features)  # (B, N, N, d)

            # compute e_ij
            e_ij = self.phi_inf[l](m_ij) # (B, N, N, 1)

            # update x
            weights = self.phi_x[l](features) * diff / (d_ij + 1.0)  # (B, N, N, 3)
            x = x + weights.sum(dim=2)                          # (B, N, 3)

            # update h
            mask = ~torch.eye(N, dtype=bool)[None, :, :, None]  # (1, N, N, 1)
            agg = (e_ij * mask * m_ij).sum(dim=2)   # (B, N, d)
            h = h + self.phi_h[l](torch.cat([h, agg], dim=-1))   # (B, N, d)

    
        x = x - x0 # (B, N, 3)
        x = x - torch.mean(x, axis=1)[:, None, :] # (B, 1, 3)
        e, q = h[:, :, :self.num_class], h[:, :, self.num_class]
        return (x, e, q)





model = UltraMagaDiffusion()
batch_torch = one_hot_encode(batch)
out = model(batch_torch, 2)
for a in out:
    print(a.shape)


torch.Size([32, 21, 3])
torch.Size([32, 21, 5])
torch.Size([32, 21])
