# Data Augmentation for Slab Structures

This notebook augments existing CSV data by flipping asymmetric slabs (those with `sym_vac=False`). For such slabs:
1. Mirror the structure across the a/b plane
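2. Swap the top and bottom work functions (`WF_top` and `WF_bottom`)
3. Mark the new entries as "flipped" in a `flipped` column

Because a mirrored slab is the same physical system turned upside down, its swapped work functions are exact labels rather than estimates. The augmentation therefore doubles the number of entries for asymmetric slabs at no extra DFT cost, which may help improve model training.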
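
The three steps can be sketched on plain Python dicts, independent of pandas and pymatgen. This is a toy model under stated assumptions: each record is a hypothetical dict holding fractional coordinates as `[x, y, z]` lists (standing in for the serialized pymatgen structure in the `slab` column) plus the `sym_vac`, `WF_top`, `WF_bottom`, `flipped` and optional `jid` fields used in this notebook. The helper names `flip_record` and `augment` are illustrative only.

```python
from copy import deepcopy


def flip_record(record):
    """Return a flipped copy of one asymmetric slab record (toy model)."""
    flipped = deepcopy(record)
    # 1. Mirror across the a/b plane: z -> 1 - z in fractional coordinates
    flipped["frac_coords"] = [[x, y, 1.0 - z] for x, y, z in record["frac_coords"]]
    # 2. The old bottom surface is now on top, so swap the work functions
    flipped["WF_top"], flipped["WF_bottom"] = record["WF_bottom"], record["WF_top"]
    # 3. Mark the new entry and derive a distinct ID if one exists
    flipped["flipped"] = "flipped"
    if record.get("jid"):
        flipped["jid"] = f"{record['jid']}_flipped"
    return flipped


def augment(records):
    """Originals plus one flipped copy per asymmetric original (toy model).

    Rows already marked "flipped" are dropped and regenerated, so running
    this on its own output does not create duplicates.
    """
    originals = [r for r in records if r.get("flipped") != "flipped"]
    flipped = [flip_record(r) for r in originals if not r["sym_vac"]]
    return originals + flipped


records = [
    {"jid": "A", "sym_vac": False, "WF_top": 4.1, "WF_bottom": 4.6,
     "frac_coords": [[0.0, 0.0, 0.25]], "flipped": ""},
    {"jid": "B", "sym_vac": True, "WF_top": 5.0, "WF_bottom": 5.0,
     "frac_coords": [[0.5, 0.5, 0.5]], "flipped": ""},
]
augmented = augment(records)
print(len(augmented))                                 # 3
print(augmented[-1]["jid"], augmented[-1]["WF_top"])  # A_flipped 4.6
print(len(augment(augmented)))                        # 3
```

Dropping and regenerating old flipped rows, rather than flipping every asymmetric row in the input, is one way to keep re-runs on an already augmented file from flipping mirrored slabs back into duplicates of the originals.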

In [None]:
import pandas as pd
import numpy as np
from pymatgen.core import Structure
import json
import os
from tqdm.notebook import tqdm

# Set display options
pd.set_option('display.max_columns', None)
pd.set_option('display.max_rows', 20)

## Define the Mirror Function

This function mirrors a slab structure across the a/b plane (flipping it in the z-direction) by mapping each fractional z coordinate to `1 - z`.

In [None]:
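def mirror_slab(struc: Structure) -> Structure:
    """Mirror input structure across the a/b plane.

    Each fractional z coordinate is mapped to 1 - z. This is a true mirror
    reflection only if the c lattice vector is perpendicular to a and b.

    Args:
        struc: Input pymatgen Structure object (not modified)

    Returns:
        Mirrored structure with the same lattice, species order and site properties
    """
    species = struc.species_and_occu
    frac_coords = struc.frac_coords.copy()
    # Vectorized z -> 1 - z over all sites
    frac_coords[:, 2] = 1.0 - frac_coords[:, 2]
    # Carry over per-site properties (e.g. magnetic moments), which the
    # Structure constructor would otherwise drop
    return Structure(struc.lattice, species, frac_coords,
                     site_properties=struc.site_properties)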
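
Because `mirror_slab` negates the fractional z coordinate rather than the Cartesian one, it is a true mirror only when the c lattice vector is perpendicular to the a/b plane; for an oblique c vector the same map changes interatomic distances. The numpy-only sketch below applies the same `z -> 1 - z` transform and compares all pairwise distances before and after. The lattices and coordinates are made-up illustrations, and the helper names are hypothetical.

```python
import numpy as np


def flip_frac_z(frac_coords):
    """Same transform as mirror_slab: z -> 1 - z in fractional coordinates."""
    flipped = np.array(frac_coords, dtype=float)  # copy, input stays untouched
    flipped[:, 2] = 1.0 - flipped[:, 2]
    return flipped


def pairwise_distances(frac_coords, lattice):
    """Cartesian distances between all sites (periodic images ignored for simplicity)."""
    cart = frac_coords @ lattice  # rows of `lattice` are the a, b, c vectors
    diff = cart[:, None, :] - cart[None, :, :]
    return np.linalg.norm(diff, axis=-1)


def flip_preserves_distances(lattice, frac_coords):
    """True if the z -> 1 - z map is an isometry for these sites and lattice."""
    before = pairwise_distances(frac_coords, lattice)
    after = pairwise_distances(flip_frac_z(frac_coords), lattice)
    return bool(np.allclose(before, after))


coords = np.array([[0.0, 0.0, 0.30], [0.5, 0.0, 0.40], [0.0, 0.5, 0.55]])
orthogonal_c = np.array([[3.0, 0.0, 0.0], [0.0, 3.0, 0.0], [0.0, 0.0, 20.0]])
oblique_c = np.array([[3.0, 0.0, 0.0], [0.0, 3.0, 0.0], [1.5, 0.0, 20.0]])

print(flip_preserves_distances(orthogonal_c, coords))  # True
print(flip_preserves_distances(oblique_c, coords))     # False
```

For a pymatgen structure, an equivalent quick check is that `structure.lattice.alpha` and `structure.lattice.beta` (the b-c and a-c angles) are both 90°; `structure.lattice.matrix` uses the same row-vector convention as `lattice` above.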

## Load Data

First, we'll load the CSV data containing the slab structures and their properties.

In [None]:
# Set path to the CSV file
csv_path = "../test_data/surface_prop_data_set_top_bottom.csv"

# Load data
df = pd.read_csv(csv_path)
print(f"Loaded {len(df)} entries from {csv_path}")
print(f"Columns: {df.columns.tolist()}")

# Check if the dataframe already contains a 'flipped' column
has_flipped_column = 'flipped' in df.columns
print(f"Data already has 'flipped' column: {has_flipped_column}")

# Display sample data
df.head(2)

## Analyze Symmetric vs. Asymmetric Slabs

Let's check how many slabs are symmetric (sym_vac=True) vs. asymmetric (sym_vac=False).

In [None]:
# Check which columns exist in the dataframe
required_columns = ['slab', 'sym_vac', 'WF_top', 'WF_bottom']
missing_columns = [col for col in required_columns if col not in df.columns]

if missing_columns:
    raise ValueError(f"Missing required columns: {missing_columns}")

# Count symmetric vs. asymmetric slabs
sym_count = df['sym_vac'].sum()
asym_count = len(df) - sym_count

print(f"Symmetric slabs (sym_vac=True): {sym_count}")
print(f"Asymmetric slabs (sym_vac=False): {asym_count}")
print(f"Total slabs: {len(df)}")

## Augment Data by Flipping Asymmetric Slabs

Now we'll create new entries for all asymmetric slabs by:
1. Flipping the structure
2. Swapping the top and bottom work functions
3. Marking them as "flipped"

In [None]:
# Create a copy of the original dataframe
df_original = df.copy()

if has_flipped_column:
    # The input was already augmented: drop the old flipped rows so they are
    # regenerated below instead of being flipped back (which would duplicate originals)
    df_original = df_original[df_original['flipped'] != "flipped"].copy()

# Mark every remaining row as an original ("" = not flipped). This also replaces
# the NaN values pandas reads back from empty cells in a previously saved CSV.
df_original['flipped'] = ""

# Filter asymmetric slabs
asym_slabs = df_original[df_original['sym_vac'] == False].copy()
print(f"Found {len(asym_slabs)} asymmetric slabs to flip")

# Create flipped versions
flipped_rows = []

for _, row in tqdm(asym_slabs.iterrows(), total=len(asym_slabs)):
    # Create a copy of the row
    flipped_row = row.copy()

    # Parse the structure from its JSON string and mirror it
    structure = Structure.from_dict(json.loads(row['slab']))
    flipped_structure = mirror_slab(structure)
    flipped_row['slab'] = json.dumps(flipped_structure.as_dict())

    # The old bottom surface is now on top, so swap the work functions
    flipped_row['WF_top'], flipped_row['WF_bottom'] = row['WF_bottom'], row['WF_top']

    # Mark as flipped
    flipped_row['flipped'] = "flipped"

    # Give the flipped entry a distinct ID (missing IDs are NaN in pandas, which is truthy)
    if 'jid' in row.index and pd.notna(row['jid']) and row['jid'] != "":
        flipped_row['jid'] = f"{row['jid']}_flipped"

    flipped_rows.append(flipped_row)

# Create a dataframe from flipped rows (keeps the columns even if there are none)
df_flipped = pd.DataFrame(flipped_rows, columns=df_original.columns)
print(f"Created {len(df_flipped)} flipped entries")

## Combine Original and Flipped Data

In [None]:
# Combine original and flipped data
df_combined = pd.concat([df_original, df_flipped], ignore_index=True)
print(f"Combined dataset has {len(df_combined)} entries")

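# Verify the counts ("" marks originals; testing != "flipped" also counts
# originals whose empty marker was read back from a CSV as NaN)
normal_count = (df_combined['flipped'] != "flipped").sum()
flipped_count = (df_combined['flipped'] == "flipped").sum()
print(f"Normal entries: {normal_count}")
print(f"Flipped entries: {flipped_count}")

# Display a random sample of the combined data (at most 5 rows)
df_combined.sample(min(5, len(df_combined)))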
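
Since each asymmetric original contributes exactly one flipped copy, the combined size should equal the number of originals plus the number of asymmetric originals. A small sketch of that invariant follows; the helper `check_augmentation_counts` is hypothetical, not part of the notebook, and could be called with `len(df_original)`, `len(asym_slabs)`, `normal_count`, `flipped_count` and `len(df_combined)`.

```python
def check_augmentation_counts(n_original, n_asym, n_normal, n_flipped, n_combined):
    """Return a list of problems; an empty list means the counts are consistent."""
    problems = []
    if n_normal != n_original:
        problems.append(f"expected {n_original} original rows, found {n_normal}")
    if n_flipped != n_asym:
        problems.append(f"expected {n_asym} flipped rows, found {n_flipped}")
    if n_combined != n_original + n_asym:
        problems.append(
            f"expected {n_original + n_asym} rows in total, found {n_combined}"
        )
    return problems


# Example: 10 slabs, 4 of them asymmetric -> 14 rows after augmentation
print(check_augmentation_counts(10, 4, 10, 4, 14))  # []
```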

## Save Augmented Dataset

In [None]:
# Create output directory if it doesn't exist
output_dir = "../processed_data"
os.makedirs(output_dir, exist_ok=True)

# Save to CSV
output_path = os.path.join(output_dir, "augmented_DFT_data.csv")
df_combined.to_csv(output_path, index=False)
print(f"Saved augmented dataset to {output_path}")

## Verify a Sample Pair

Let's compare one original/flipped pair to check that the augmentation worked correctly.

In [None]:
# Flipped rows were created in the same order as `asym_slabs`, so the first
# row of each dataframe forms an original/flipped pair
if len(df_flipped) == 0:
    raise ValueError("No flipped entries were created, so there is nothing to verify")
sample_original = asym_slabs.iloc[0]
sample_flipped = df_flipped.iloc[0]

# Print comparison
print("=== Original Slab ===")
print(f"WF_top: {sample_original['WF_top']:.4f}, WF_bottom: {sample_original['WF_bottom']:.4f}")

print("\n=== Flipped Slab ===")
print(f"WF_top: {sample_flipped['WF_top']:.4f}, WF_bottom: {sample_flipped['WF_bottom']:.4f}")

# Verify that top and bottom are swapped
if abs(sample_original['WF_top'] - sample_flipped['WF_bottom']) < 1e-6 and \
   abs(sample_original['WF_bottom'] - sample_flipped['WF_top']) < 1e-6:
    print("\n✅ Verification successful: WF_top and WF_bottom were swapped correctly")
else:
    print("\n❌ Verification failed: WF_top and WF_bottom were not swapped correctly")
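
The check above only covers the work functions. A structure-level check can rely on `mirror_slab` keeping the site order, so site `i` of the flipped slab should sit at the same fractional x and y as site `i` of the original, with z replaced by `1 - z`. The numpy-only sketch below compares modulo 1, assuming periodic coordinates may have been wrapped into [0, 1) during serialization; the helper name `is_z_mirror` is hypothetical.

```python
import numpy as np


def is_z_mirror(original, flipped, atol=1e-8):
    """True if `flipped` equals `original` with z -> 1 - z, site by site (mod 1)."""
    original = np.asarray(original, dtype=float)
    flipped = np.asarray(flipped, dtype=float)
    if original.shape != flipped.shape:
        return False
    expected = original.copy()
    expected[:, 2] = 1.0 - expected[:, 2]
    # Periodic difference: 1.0 and 0.0 describe the same fractional position
    delta = flipped - expected
    delta -= np.round(delta)
    return bool(np.allclose(delta, 0.0, atol=atol))


orig = np.array([[0.1, 0.2, 0.3], [0.5, 0.5, 0.0]])
flip = np.array([[0.1, 0.2, 0.7], [0.5, 0.5, 0.0]])  # 1 - 0.0 wrapped back to 0.0
print(is_z_mirror(orig, flip))  # True
print(is_z_mirror(orig, orig))  # False
```

In the notebook, `sample_original['slab']` and `sample_flipped['slab']` could be parsed with `Structure.from_dict(json.loads(...))` and their `frac_coords` passed to this helper.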

## Compatibility with FAENet DataLoader

The existing FAENet data loader is fully compatible with this augmented dataset. Here's how it works:

1. The `EnhancedSlabDataset` class loads the CSV file with all columns, including the newly added "flipped" column.
2. During training, the structure column (`slab`) provides the model input, and only the columns listed in `target_properties` (like "WF_top" and "WF_bottom") are used as prediction targets.
3. Other columns like "flipped" are ignored during model training but remain in the dataset.
4. The "flipped" column serves as metadata indicating which entries are original vs. augmented.

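To use this augmented dataset for training, pass the new CSV file as `--data_path`. The path below assumes the command is run from the parent of the notebook's directory, since the notebook writes to `../processed_data`: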

```bash
python -m faenet.train --data_path=./processed_data/augmented_DFT_data.csv --structure_col=slab --target_properties=[WF_top,WF_bottom] --frame_averaging=3D
```
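
One consequence of keeping flipped copies as ordinary rows: an original slab and its mirror image describe the same pair of surfaces, so if a random train/validation split puts them on opposite sides, validation scores may be optimistic. Whether `faenet.train` already accounts for this is not stated here. As a hedged sketch, a group-aware split can key on the `jid` with the `_flipped` suffix removed, which keeps each pair together; it assumes every row has a `jid`, and the helpers `base_id` and `grouped_split` are hypothetical.

```python
import random


def base_id(jid):
    """Map 'X_flipped' back to 'X' so a slab and its mirror share one group."""
    suffix = "_flipped"
    return jid[: -len(suffix)] if jid.endswith(suffix) else jid


def grouped_split(jids, val_fraction=0.2, seed=0):
    """Return (train_idx, val_idx) row positions with whole groups on one side."""
    groups = sorted({base_id(j) for j in jids})
    rng = random.Random(seed)
    rng.shuffle(groups)
    n_val = max(1, round(len(groups) * val_fraction))
    val_groups = set(groups[:n_val])
    train_idx = [i for i, j in enumerate(jids) if base_id(j) not in val_groups]
    val_idx = [i for i, j in enumerate(jids) if base_id(j) in val_groups]
    return train_idx, val_idx


jids = ["A", "B", "C", "A_flipped", "C_flipped"]
train_idx, val_idx = grouped_split(jids, val_fraction=0.34)
print(train_idx, val_idx)
```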

## Conclusion

The data augmentation is complete! We've:
1. Identified all asymmetric slabs (sym_vac=False)
2. Created flipped versions of these slabs
3. Swapped their top and bottom work functions
4. Marked them as "flipped" in a new column
5. Saved the augmented dataset

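This augmentation is expected to help training: it adds one exactly labeled example per asymmetric slab and shows each surface in both the top and the bottom position, which may help the model associate a surface's structure with its work function regardless of orientation.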