# Dataloader

In [7]:
# Element symbols listed in atomic-number order, one row per period.
# Position 0 is the '*' placeholder symbol, so each real element's position
# in this tuple equals its atomic number.
_ELEMENT_SYMBOLS = (
    '*',
    'H', 'He',
    'Li', 'Be', 'B', 'C', 'N', 'O', 'F', 'Ne',
    'Na', 'Mg', 'Al', 'Si', 'P', 'S', 'Cl', 'Ar',
    'K', 'Ca', 'Sc', 'Ti', 'V', 'Cr', 'Mn', 'Fe', 'Co', 'Ni',
    'Cu', 'Zn', 'Ga', 'Ge', 'As', 'Se', 'Br', 'Kr',
    'Rb', 'Sr', 'Y', 'Zr', 'Nb', 'Mo', 'Tc', 'Ru', 'Rh', 'Pd',
    'Ag', 'Cd', 'In', 'Sn', 'Sb', 'Te', 'I', 'Xe',
    'Cs', 'Ba', 'La', 'Ce', 'Pr', 'Nd', 'Pm', 'Sm', 'Eu', 'Gd',
    'Tb', 'Dy', 'Ho', 'Er', 'Tm', 'Yb', 'Lu',
    'Hf', 'Ta', 'W', 'Re', 'Os', 'Ir', 'Pt', 'Au', 'Hg',
    'Tl', 'Pb', 'Bi', 'Po', 'At', 'Rn',
    'Fr', 'Ra', 'Ac', 'Th', 'Pa', 'U', 'Np', 'Pu', 'Am', 'Cm',
    'Bk', 'Cf', 'Es', 'Fm', 'Md', 'No', 'Lr',
    'Rf', 'Db', 'Sg', 'Bh', 'Hs', 'Mt', 'Ds', 'Rg',
    'Cn', 'Nh', 'Fl', 'Mc', 'Lv', 'Ts', 'Og',
)

# Symbol -> atomic number (0 for the '*' placeholder).
elements = {symbol: number for number, symbol in enumerate(_ELEMENT_SYMBOLS)}

# Ordered symbol list and symbol -> feature-index lookup used for one-hot encoding.
element_list = list(elements)
element_to_idx = dict(zip(element_list, range(len(element_list))))
print(f"Elements: {element_list}")

Elements: ['*', 'H', 'He', 'Li', 'Be', 'B', 'C', 'N', 'O', 'F', 'Ne', 'Na', 'Mg', 'Al', 'Si', 'P', 'S', 'Cl', 'Ar', 'K', 'Ca', 'Sc', 'Ti', 'V', 'Cr', 'Mn', 'Fe', 'Co', 'Ni', 'Cu', 'Zn', 'Ga', 'Ge', 'As', 'Se', 'Br', 'Kr', 'Rb', 'Sr', 'Y', 'Zr', 'Nb', 'Mo', 'Tc', 'Ru', 'Rh', 'Pd', 'Ag', 'Cd', 'In', 'Sn', 'Sb', 'Te', 'I', 'Xe', 'Cs', 'Ba', 'La', 'Ce', 'Pr', 'Nd', 'Pm', 'Sm', 'Eu', 'Gd', 'Tb', 'Dy', 'Ho', 'Er', 'Tm', 'Yb', 'Lu', 'Hf', 'Ta', 'W', 'Re', 'Os', 'Ir', 'Pt', 'Au', 'Hg', 'Tl', 'Pb', 'Bi', 'Po', 'At', 'Rn', 'Fr', 'Ra', 'Ac', 'Th', 'Pa', 'U', 'Np', 'Pu', 'Am', 'Cm', 'Bk', 'Cf', 'Es', 'Fm', 'Md', 'No', 'Lr', 'Rf', 'Db', 'Sg', 'Bh', 'Hs', 'Mt', 'Ds', 'Rg', 'Cn', 'Nh', 'Fl', 'Mc', 'Lv', 'Ts', 'Og']


In [8]:
def one_hot(element):
    """Encode an element symbol as a one-hot float32 vector.

    The vector has one slot per entry of ``element_list``. Symbols that are
    not present in ``element_to_idx`` are encoded as the ``'*'`` placeholder.

    Args:
        element: Element symbol to encode (e.g. ``'C'``).

    Returns:
        A 1-D ``torch.float32`` tensor of length ``len(element_list)`` holding
        a single 1.0 at the symbol's index and 0.0 elsewhere.
    """
    placeholder_slot = element_to_idx['*']
    slot = element_to_idx.get(element, placeholder_slot)

    encoding = torch.zeros(len(element_list), dtype=torch.float32)
    encoding[slot] = 1.0
    return encoding

In [9]:
import os
import pandas as pd
import numpy as np
import torch
from torch_geometric.data import Data
from tqdm import tqdm
import matplotlib.pyplot as plt
import seaborn as sns
import plotly
from torch_geometric.nn import GCNConv
from torch_geometric.utils import to_networkx
import networkx as nx
from networkx.algorithms import community


# Export the installed torch release and its compute-platform tag through the
# TORCH and CUDA environment variables.
#
# torch.__version__ is "<release>+<tag>" only on builds that carry a local
# version suffix; on builds without one the split yields a single element and
# indexing [1] raised IndexError. In that case the tag is derived from the
# CUDA version torch reports it was built with.
torch_version = torch.__version__.split("+")
os.environ["TORCH"] = torch_version[0]

if len(torch_version) > 1:
    cuda_tag = torch_version[1]
else:
    # torch.version.cuda is None on builds without CUDA support.
    # NOTE(review): the "cu<digits>" / "cpu" spelling mirrors the suffix style
    # of tagged builds — confirm it matches whatever consumes $CUDA.
    built_cuda = torch.version.cuda
    cuda_tag = "cu" + built_cuda.replace(".", "") if built_cuda else "cpu"

os.environ["CUDA"] = cuda_tag


In [10]:
def _read_polymer_details(details_path):
    """Return the first row of ``details.csv`` as ``{column: float}``.

    Cells that cannot be converted with ``float()`` are reported as NaN.
    An empty dict is returned when the file does not exist or has no rows.
    """
    if not os.path.exists(details_path):
        return {}

    details_df = pd.read_csv(details_path)
    if details_df.empty:
        return {}

    parsed = {}
    for key, value in details_df.iloc[0].to_dict().items():
        try:
            parsed[key] = float(value)
        except (TypeError, ValueError):
            # Non-numeric cell (e.g. free text): keep the key, mark as missing.
            parsed[key] = float('nan')
    return parsed


def load_all_polymers(polymers_dir):
    """Load every polymer sub-folder of ``polymers_dir`` as a PyG ``Data`` graph.

    Each immediate sub-directory describes one polymer through up to three
    files:

    * ``adjacency.csv`` - read with ``pd.read_csv``; every non-zero entry
      at (row, column) becomes a directed edge ``row -> column``. The number
      of rows is taken as the number of nodes.
    * ``atoms.csv`` - must provide ``node`` and ``element`` columns. Nodes
      in ``range(n_nodes)`` that are not listed are encoded as ``'*'``.
    * ``details.csv`` (optional) - its first row is attached to the graph,
      one attribute per column, as a shape ``[1]`` float32 tensor. Values
      that are not numeric are stored as NaN.

    Sub-folders missing ``adjacency.csv`` or ``atoms.csv`` are skipped. A
    missing or row-less ``details.csv`` yields a graph without extra
    attributes instead of aborting the whole load.

    Args:
        polymers_dir: Directory containing one sub-directory per polymer.

    Returns:
        A list of ``Data`` objects in sorted folder-name order, each carrying
        ``x`` (stacked one-hot node features), ``edge_index``, ``id`` (the
        folder name) and one attribute per ``details.csv`` column.
    """
    polymer_graphs = []

    folders = sorted(
        entry for entry in os.listdir(polymers_dir)
        if os.path.isdir(os.path.join(polymers_dir, entry))
    )

    for folder in tqdm(folders):
        folder_path = os.path.join(polymers_dir, folder)

        adj_path = os.path.join(folder_path, 'adjacency.csv')
        atom_path = os.path.join(folder_path, 'atoms.csv')
        details_path = os.path.join(folder_path, 'details.csv')

        if not (os.path.exists(adj_path) and os.path.exists(atom_path)):
            continue

        adj = pd.read_csv(adj_path).values
        n_nodes = adj.shape[0]

        atom_df = pd.read_csv(atom_path)
        atom_labels = dict(zip(atom_df['node'], atom_df['element']))

        x = torch.stack([
            one_hot(atom_labels.get(i, '*'))
            for i in range(n_nodes)
        ])

        src, dst = np.where(adj != 0)
        # Stack into a single (2, E) ndarray first: building a tensor from a
        # Python list of ndarrays is slow and makes PyTorch emit a UserWarning.
        edge_index = torch.as_tensor(np.stack((src, dst)), dtype=torch.long)

        data = Data(x=x, edge_index=edge_index)
        data.id = folder

        for key, val in _read_polymer_details(details_path).items():
            data[key] = torch.tensor([val], dtype=torch.float32)

        polymer_graphs.append(data)

    return polymer_graphs


In [11]:
import numpy as np
import random
import torch
import os
from torch_geometric.data import Data

def split_polymer_data_by_attributes(polymers, output_dir, train_ratio=0.6, val_ratio=0.2, seed=42):
    """Write one shuffled train/val/test split per scalar target attribute.

    Target attributes are discovered from the data: every key other than
    ``x``, ``edge_index`` and ``id`` whose value is a single-element tensor
    on at least one polymer. For each target, the polymers that carry the
    attribute with a non-NaN value are rebuilt as
    ``Data(x, edge_index, y=[value], id)``, shuffled, sliced by the given
    ratios and saved with ``torch.save`` as ``train.pt``, ``val.pt`` and
    ``test.pt`` under ``<output_dir>/<attribute>/``. The test split receives
    whatever remains after the train and validation slices.

    Side effects: seeds the global ``random`` and ``torch`` RNGs with
    ``seed``. Seeding happens once, so the shuffle of each later attribute
    continues from the RNG state left by the earlier ones.

    Args:
        polymers: Iterable of graph objects supporting ``.keys()``,
            ``obj[key]`` and ``key in obj``.
        output_dir: Root directory for the per-attribute split folders.
        train_ratio: Fraction of samples placed in the training split.
        val_ratio: Fraction of samples placed in the validation split.
        seed: Seed for the shuffling RNG.

    Returns:
        The sorted list of discovered target attribute names, including any
        that were skipped for having no valid samples.

    Raises:
        ValueError: If a ratio is negative or the two ratios sum to more
            than 1 (which would otherwise silently produce truncated or
            empty splits).
    """
    # Small tolerance so float sums such as 0.7 + 0.3 are never rejected.
    if train_ratio < 0 or val_ratio < 0 or train_ratio + val_ratio > 1 + 1e-9:
        raise ValueError(
            "train_ratio and val_ratio must be non-negative and sum to at most 1, "
            f"got train_ratio={train_ratio}, val_ratio={val_ratio}"
        )

    random.seed(seed)
    torch.manual_seed(seed)

    BASE_KEYS = {'x', 'edge_index', 'id'}

    # A key counts as a target if any polymer stores it as a one-element tensor.
    target_attrs = set()
    for p in polymers:
        for key in p.keys():
            if key in BASE_KEYS:
                continue
            val = p[key]
            if isinstance(val, torch.Tensor) and val.numel() == 1:
                target_attrs.add(key)

    target_attrs = sorted(target_attrs)
    print(f"Target attributes: {target_attrs}")

    for attr in target_attrs:
        filtered = []
        for p in polymers:
            if attr not in p:
                continue

            v = p[attr]
            val = v.item() if isinstance(v, torch.Tensor) else float(v)

            # NaN marks a missing label for this polymer.
            if np.isnan(val):
                continue

            filtered.append(Data(
                x=p['x'],
                edge_index=p['edge_index'],
                y=torch.tensor([val], dtype=torch.float32),
                id=p['id']
            ))

        if len(filtered) == 0:
            print(f"{attr}: No valid data, skipping")
            continue

        random.shuffle(filtered)

        n = len(filtered)
        n_train = int(n * train_ratio)
        n_val = int(n * val_ratio)

        train_set = filtered[:n_train]
        validation_set = filtered[n_train:n_train + n_val]
        test_set = filtered[n_train + n_val:]

        base_path = os.path.join(output_dir, attr)
        os.makedirs(base_path, exist_ok=True)

        torch.save(train_set, os.path.join(base_path, 'train.pt'))
        torch.save(validation_set, os.path.join(base_path, 'val.pt'))
        torch.save(test_set, os.path.join(base_path, 'test.pt'))

        print(f"{attr} saved")

    print("\nDone")
    return target_attrs


## ESOL

In [6]:
# Load the ESOL polymer folders as graphs, then write the per-target
# train/val/test splits; the cell displays the list of target names.
polymers = load_all_polymers(polymers_dir='../../kaggle/ESOL/')
split_polymer_data_by_attributes(polymers=polymers, output_dir='../data/ESOL/')

  edge_index = torch.tensor([src, dst], dtype=torch.long)
100%|██████████| 1127/1127 [00:07<00:00, 155.59it/s]


Target attributes: ['ESOL predicted log solubility in mols per litre']
ESOL predicted log solubility in mols per litre saved

Done


['ESOL predicted log solubility in mols per litre']

## Lipophilicity

In [41]:
# Load the Lipophilicity polymer folders as graphs, then write the per-target
# train/val/test splits; the cell displays the list of target names.
polymers = load_all_polymers(polymers_dir='../../kaggle/Lipophilicity/')
split_polymer_data_by_attributes(polymers=polymers, output_dir='../data/Lipophilicity/')

100%|██████████| 4157/4157 [00:29<00:00, 139.87it/s]


Target attributes: ['exp']
exp saved

Done


['exp']

## NeurIPS

In [7]:
# Load the NeurIPS polymer folders as graphs, then write the per-target
# train/val/test splits; the cell displays the list of target names.
polymers = load_all_polymers(polymers_dir='../../kaggle/NeurIPS/')
split_polymer_data_by_attributes(polymers=polymers, output_dir='../data/NeurIPS/')

  edge_index = torch.tensor([src, dst], dtype=torch.long)
100%|██████████| 7949/7949 [00:49<00:00, 162.15it/s]


Target attributes: ['Density', 'FFV', 'Rg', 'Tc', 'Tg']
Density saved
FFV saved
Rg saved
Tc saved
Tg saved

Done


['Density', 'FFV', 'Rg', 'Tc', 'Tg']