In [None]:
import os
import sys
if os.path.abspath('../') not in sys.path:
    sys.path.append(os.path.abspath('../'))
from pathlib import Path

import numpy as np
import pandas as pd
from tqdm import trange, tqdm
from typing import Optional, Tuple

from lattices.lattices import Catalogue
from gnn.datasets import GLAMM_Dataset

# Assemble datasets for ML

Training datasets in the paper are:

| Dataset   | Base lattices | Imperfection levels (%) | \# imp. per level  | Total lattices    |
| --------- | ------------- | --------------------- | -- | ----------------- |
| 0 imp quarter | 1750         | 0                   | 1 | 1750                 |
| 0 imp half    | 3500         | 0                   | 1  | 3500                 |
| 0 imp     | 7000           | 0                    | 1 | 7000                 |
| 1 imp     | 7000            | 0, 2, 4, 7     | 1 | 27847                 |
| 2 imp     | 7000             | 0, 2, 4, 7      | 2 |          48681                 |
| 4 imp     | 7000             | 0, 2, 4, 7      | 4 |          90336                 |


Test dataset is 
| Dataset   | Base lattices | Imperfection levels (%)  | \# imp. per level  | Total lattices    |
| --------- | ------------- | --------------------- | -- | ----------------- |
| 0 imp     | 1296           | 10                    | 3 | 3888                 |

In [2]:
input_dir = Path('C:/temp/') # location where `.lat` catalogue files are stored
assert input_dir.exists(), f"Directory {input_dir} does not exist."
output_dir = Path('../datasets')
output_dir.mkdir(parents=True, exist_ok=True)
make_split = False # create train/val/test split now or load from file

In [3]:
def load_data(input_dir: Path, max_imp: int = 1, regex: Optional[str] = None) -> Tuple[dict, dict]:
    n_2_imp = lambda x: float(x.split('_')[4]) # extract imperfection from lattice name
    num_imp = {}
    selected_data = {}

    cat_files = list(input_dir.glob('*.lat'))
    print(f'Found {len(cat_files)} catalogue files')
    for f in tqdm(cat_files):
        cat = Catalogue.from_file(f, indexing=0, regex=regex)
        for data in cat:
            name = data['name']
            base_name = Catalogue.n_2_bn(name)
            imp = n_2_imp(name)
            if base_name not in num_imp:
                num_imp[base_name] = {}
            if imp not in num_imp[base_name]:
                num_imp[base_name][imp] = 0
            else:
                if num_imp[base_name][imp] >= max_imp:
                    continue

            selected_data[name] = data

            num_imp[base_name][imp] += 1
    
    return selected_data, num_imp

def load_names(input_dir: Path) -> set:
    cat_files = list(input_dir.glob('*.lat'))
    print(f'Found {len(cat_files)} catalogue files')
    names = set()
    for f in tqdm(cat_files):
        names = names.union(set(Catalogue.get_names(f)))
    return names

## Train/val/test splits

In [4]:
if make_split:
    all_names = load_names(input_dir)
    base_names = list(set([Catalogue.n_2_bn(name) for name in all_names]))
    # split base names into train and validation. Take 7000 for train and rest for validation
    np.random.shuffle(base_names)
    train_base_names = set(base_names[:7000])
    val_base_names = set(base_names[7000:])
    # sort by code and save to file
    train_code_map = {name.split('_')[2]: Catalogue.n_2_bn(name) for name in train_base_names}
    val_code_map = {name.split('_')[2]: Catalogue.n_2_bn(name) for name in val_base_names}

    Path(output_dir).mkdir(parents=True, exist_ok=True)
    with open(output_dir/'train_base_names.txt', 'w') as f:
        sorted_names = sorted(train_code_map.keys(), key=lambda x: int(x[1:]))
        f.write('\n'.join([train_code_map[name] for name in sorted_names]))
    with open(output_dir/'val_base_names.txt', 'w') as f:
        sorted_names = sorted(val_code_map.keys(), key=lambda x: int(x[1:]))
        f.write('\n'.join([val_code_map[name] for name in sorted_names]))
else:
    # load train and validation names from files
    train_base_names = set(pd.read_csv('./train_base_names.txt', header=None)[0].values)
    val_base_names = set(pd.read_csv('./val_base_names.txt', header=None)[0].values)

# 0 imp (imperfection level 0)

In [5]:
selected_data, num_imp = load_data(input_dir, max_imp=1, regex='.*p_0.0_.*')

print(f'Number of base names: {len(num_imp)}')
print(len(selected_data))

training_dict = {name:selected_data[name] for name in selected_data if Catalogue.n_2_bn(name) in train_base_names}
validation_dict = {name:selected_data[name] for name in selected_data if Catalogue.n_2_bn(name) in val_base_names}
training_cat = Catalogue.from_dict(training_dict)
validation_cat = Catalogue.from_dict(validation_dict)
print(training_cat)
print(validation_cat)
Path(output_dir/'0imp/raw').mkdir(parents=True, exist_ok=True)
training_cat.to_file(output_dir/'0imp/raw/training_cat.lat')
validation_cat.to_file(output_dir/'0imp/raw/validation_cat.lat')

Found 5 catalogue files


100%|██████████| 5/5 [01:15<00:00, 15.07s/it]


Number of base names: 8954
8954
Unit cell catalogue with 7000 entries
Unit cell catalogue with 1296 entries


Dataset can be converted to pytorch now or later when it is loaded from `.lat` file in the `raw` folder

In [6]:
# process the data and save as pt files
train_dset = GLAMM_Dataset(
    output_dir/'0imp', './training_cat.lat', 'train.pt', graph_ft_format='cartesian_4', n_reldens=3
)
val_dset = GLAMM_Dataset(
    output_dir/'0imp', './validation_cat.lat', 'valid.pt', graph_ft_format='cartesian_4', n_reldens=3
)

# 0 imp half (imperfection level 0)

In [7]:
Path(output_dir/'0imp_half/raw').mkdir(parents=True, exist_ok=True)
selected_data, num_imp = load_data(input_dir, max_imp=1, regex='.*p_0.0_.*')
train_base_names = list(train_base_names)
np.random.shuffle(train_base_names)
# select 3500
train_base_names = set(train_base_names[:3500])
assert len(train_base_names.intersection(val_base_names)) == 0, 'train and val base names overlap'

train_code_map = {name.split('_')[2]: Catalogue.n_2_bn(name) for name in train_base_names}
val_code_map = {name.split('_')[2]: Catalogue.n_2_bn(name) for name in val_base_names}

with open(output_dir/'0imp_half/raw/train_base_names.txt', 'w') as f:
    sorted_names = sorted(train_code_map.keys(), key=lambda x: int(x[1:]))
    f.write('\n'.join([train_code_map[name] for name in sorted_names]))
with open(output_dir/'0imp_half/raw/val_base_names.txt', 'w') as f:
    sorted_names = sorted(val_code_map.keys(), key=lambda x: int(x[1:]))
    f.write('\n'.join([val_code_map[name] for name in sorted_names]))

training_dict = {name:selected_data[name] for name in selected_data if Catalogue.n_2_bn(name) in train_base_names}
validation_dict = {name:selected_data[name] for name in selected_data if Catalogue.n_2_bn(name) in val_base_names}
training_cat = Catalogue.from_dict(training_dict)
validation_cat = Catalogue.from_dict(validation_dict)
training_cat.to_file(output_dir/'0imp_half/raw/training_cat.lat')
validation_cat.to_file(output_dir/'0imp_half/raw/validation_cat.lat')

Found 5 catalogue files


100%|██████████| 5/5 [01:13<00:00, 14.70s/it]


# 0 imp quarter (imperfection level 0)

In [8]:
Path(output_dir/'0imp_quarter/raw').mkdir(parents=True, exist_ok=True)
selected_data, num_imp = load_data(input_dir, max_imp=1, regex='.*p_0.0_.*')
train_base_names = list(train_base_names)
np.random.shuffle(train_base_names)
# select 1750
train_base_names = set(train_base_names[:1750])
assert len(train_base_names.intersection(val_base_names)) == 0, 'train and val base names overlap'

train_code_map = {name.split('_')[2]: Catalogue.n_2_bn(name) for name in train_base_names}
val_code_map = {name.split('_')[2]: Catalogue.n_2_bn(name) for name in val_base_names}

with open(output_dir/'0imp_quarter/raw/train_base_names.txt', 'w') as f:
    sorted_names = sorted(train_code_map.keys(), key=lambda x: int(x[1:]))
    f.write('\n'.join([train_code_map[name] for name in sorted_names]))
with open(output_dir/'0imp_quarter/raw/val_base_names.txt', 'w') as f:
    sorted_names = sorted(val_code_map.keys(), key=lambda x: int(x[1:]))
    f.write('\n'.join([val_code_map[name] for name in sorted_names]))
    
training_dict = {name:selected_data[name] for name in selected_data if Catalogue.n_2_bn(name) in train_base_names}
validation_dict = {name:selected_data[name] for name in selected_data if Catalogue.n_2_bn(name) in val_base_names}
training_cat = Catalogue.from_dict(training_dict)
validation_cat = Catalogue.from_dict(validation_dict)
print(training_cat)
print(validation_cat)
training_cat.to_file(output_dir/'0imp_quarter/raw/training_cat.lat')
validation_cat.to_file(output_dir/'0imp_quarter/raw/validation_cat.lat')


Found 5 catalogue files


100%|██████████| 5/5 [01:13<00:00, 14.70s/it]


Unit cell catalogue with 1750 entries
Unit cell catalogue with 1296 entries


# 1 imp (imperfection levels 0,2,4,7)

In [9]:
selected_data, num_imp = load_data(input_dir, max_imp=1, regex='.*p_0.0[247]?_.*')

print(f'Number of base names: {len(num_imp)}')
print(len(selected_data))

training_dict = {name:selected_data[name] for name in selected_data if Catalogue.n_2_bn(name) in train_base_names}
validation_dict = {name:selected_data[name] for name in selected_data if Catalogue.n_2_bn(name) in val_base_names}
training_cat = Catalogue.from_dict(training_dict)
validation_cat = Catalogue.from_dict(validation_dict)
print(training_cat)
print(validation_cat)
Path(output_dir/'1imp/raw').mkdir(parents=True, exist_ok=True)
training_cat.to_file(output_dir/'1imp/raw/training_cat.lat')
validation_cat.to_file(output_dir/'1imp/raw/validation_cat.lat')

Found 5 catalogue files


100%|██████████| 5/5 [04:00<00:00, 48.00s/it]


Number of base names: 8954
35804
Unit cell catalogue with 6994 entries
Unit cell catalogue with 5184 entries


# 2 imp (imperfection levels 0,2,4,7)

In [10]:
selected_data, num_imp = load_data(input_dir, max_imp=2, regex='.*p_0.0[247]?_.*')
        
print(f'Number of base names: {len(num_imp)}')
print(len(selected_data))

training_dict = {name:selected_data[name] for name in selected_data if Catalogue.n_2_bn(name) in train_base_names}
validation_dict = {name:selected_data[name] for name in selected_data if Catalogue.n_2_bn(name) in val_base_names}
training_cat = Catalogue.from_dict(training_dict)
validation_cat = Catalogue.from_dict(validation_dict)
print(training_cat)
print(validation_cat)
Path(output_dir/'2imp/raw').mkdir(parents=True, exist_ok=True)
training_cat.to_file(output_dir/'2imp/raw/training_cat.lat')
validation_cat.to_file(output_dir/'2imp/raw/validation_cat.lat')

Found 5 catalogue files


100%|██████████| 5/5 [03:38<00:00, 43.78s/it]


Number of base names: 8954
62654
Unit cell catalogue with 12238 entries
Unit cell catalogue with 9072 entries


# 4 imp (imperfection levels 0,2,4,7)

In [11]:
selected_data, num_imp = load_data(input_dir, max_imp=4, regex='.*p_0.0[247]?_.*')

print(f'Number of base names: {len(num_imp)}')
print(len(selected_data))

training_dict = {name:selected_data[name] for name in selected_data if Catalogue.n_2_bn(name) in train_base_names}
validation_dict = {name:selected_data[name] for name in selected_data if Catalogue.n_2_bn(name) in val_base_names}
training_cat = Catalogue.from_dict(training_dict)
validation_cat = Catalogue.from_dict(validation_dict)
print(training_cat)
print(validation_cat)
Path(output_dir/'4imp/raw').mkdir(parents=True, exist_ok=True)
training_cat.to_file(output_dir/'4imp/raw/training_cat.lat')
validation_cat.to_file(output_dir/'4imp/raw/validation_cat.lat')

Found 5 catalogue files


100%|██████████| 5/5 [04:10<00:00, 50.07s/it]


Number of base names: 8954
116354
Unit cell catalogue with 22726 entries
Unit cell catalogue with 16848 entries


# Test dataset

In [12]:
# test data is imperfections of the validation data
selected_data, num_imp = load_data(input_dir, max_imp=3, regex='.*p_0.1_.*')
selected_test_data = {name:selected_data[name] for name in selected_data if Catalogue.n_2_bn(name) in val_base_names}
        
print(f'Number of base names: {len(num_imp)}')
test_cat = Catalogue.from_dict(selected_test_data)
print(test_cat)
test_cat.to_file(output_dir/'0imp/raw/test_cat.lat')

Found 5 catalogue files


100%|██████████| 5/5 [02:25<00:00, 29.16s/it]


Number of base names: 8950
Unit cell catalogue with 3888 entries
