# Train the Neural Mesh Simplification model

## Set up the environment

### (optional) Check out the repo

If you are running this notebook remotely (e.g. Google Colab), clone the repo and change into its directory so that `requirements.txt` and the package sources are available:
```
!git clone https://github.com/gennarinoos/neural-mesh-simplification.git neural-mesh-simplification
%cd neural-mesh-simplification
```

If you are opening this notebook locally (by running `jupyter lab` from the repo root), you can skip this step.

At this point you can install the requirements with pip:

In [None]:
!pip install -r requirements.txt

Then install the mesh simplification package itself in editable mode, so local changes to the source take effect without reinstalling:

In [None]:
!pip install -e .
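Before moving on, it can help to confirm that the modules imported later in this notebook (`yaml` from PyYAML, `huggingface_hub`, and `neural_mesh_simplification`) are visible to the running kernel. The helper below is a minimal sketch; `missing_modules` is a hypothetical name, not part of the repo.

```python
import importlib.util
from typing import List, Sequence


def missing_modules(names: Sequence[str]) -> List[str]:
    """Return the top-level module names that this kernel cannot find."""
    # find_spec returns None for a top-level module that is not installed
    return [name for name in names if importlib.util.find_spec(name) is None]


missing = missing_modules(["yaml", "huggingface_hub", "neural_mesh_simplification"])
print("missing:", missing or "none")
```

If `neural_mesh_simplification` is reported as missing right after the editable install, restarting the kernel usually lets Python pick up the newly installed package.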

---
## Training data
You can use the Hugging Face Hub API to download mesh data for training and evaluation.

In [None]:
import os
import shutil
from huggingface_hub import snapshot_download

target_folder = "data/raw"
wip_folder = os.path.join(target_folder, "wip")
os.makedirs(wip_folder, exist_ok=True)

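# abc_train is very large (5k+ meshes), so only the first pattern (abc_extra_noisy) is downloaded below.
# Pass the whole folder_patterns list to allow_patterns to fetch both sets.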
folder_patterns = ["abc_extra_noisy/03_meshes/*.ply", "abc_train/03_meshes/*.ply"]

# Download
snapshot_download(
    repo_id="perler/ppsurf",
    repo_type="dataset",
    cache_dir=wip_folder,
    allow_patterns=folder_patterns[0],
)

# Move files from wip folder to target folder
for root, _, files in os.walk(wip_folder):
    for file in files:
        if file.endswith(".ply"):
            src_file = os.path.join(root, file)
            dest_file = os.path.join(target_folder, file)
            shutil.copy2(src_file, dest_file)
            os.remove(src_file)

# Remove the wip folder
shutil.rmtree(wip_folder)
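A quick sanity check on the download is to read each PLY header, which is plain ASCII text even when the mesh body is binary, and report the declared vertex and face counts. This is a minimal stdlib-only sketch; `read_ply_header` and `summarize_ply_folder` are hypothetical helpers, and it assumes the meshes were moved into `data/raw` by the cell above.

```python
import os
from typing import Dict, List


def read_ply_header(path: str) -> Dict[str, object]:
    """Parse the text header of a PLY file (ASCII or binary body)."""
    counts: Dict[str, int] = {}
    fmt = None
    with open(path, "rb") as f:
        # Every PLY file starts with the magic line "ply"
        if f.readline().strip() != b"ply":
            raise ValueError(f"{path} is not a PLY file")
        for raw in f:
            line = raw.decode("ascii", errors="replace").strip()
            if line == "end_header":
                break  # stop before the (possibly binary) body
            parts = line.split()
            if parts[:1] == ["format"]:
                fmt = parts[1]  # e.g. "ascii" or "binary_little_endian"
            elif parts[:1] == ["element"] and len(parts) == 3:
                counts[parts[1]] = int(parts[2])  # e.g. "element vertex 1234"
    return {"format": fmt, "vertices": counts.get("vertex", 0), "faces": counts.get("face", 0)}


def summarize_ply_folder(folder: str) -> List[Dict[str, object]]:
    """Return one header summary per .ply file in folder, sorted by name."""
    rows = []
    for name in sorted(os.listdir(folder)):
        if name.endswith(".ply"):
            info = read_ply_header(os.path.join(folder, name))
            info["file"] = name
            rows.append(info)
    return rows


if os.path.isdir("data/raw"):
    rows = summarize_ply_folder("data/raw")
    print(f"{len(rows)} meshes found")
```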

---
## Model Training

### Load the default config
If you are running this notebook remotely (e.g. Google Colab), you'll need the `configs/default.yaml` file from the repo (relative to the working directory) before running this cell.
If you are opening this notebook locally (by running `jupyter lab` from the repo root), this step is not required.

In [None]:
def load_config(config_path):
    import yaml
    with open(config_path, "r") as file:
        config = yaml.safe_load(file)
    return config

config = load_config("configs/default.yaml")
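The loaded config is a plain nested dictionary, so individual settings can be overridden in the notebook without editing the YAML file. The sketch below applies dotted-key overrides to a copy of the config; `with_overrides` is a hypothetical helper, and the only key it is shown with, `training.checkpoint_dir`, is the one the training cell reads. It refuses unknown keys so that a typo fails loudly instead of adding a setting nothing reads.

```python
import copy
from typing import Any, Dict


def with_overrides(config: Dict[str, Any], overrides: Dict[str, Any]) -> Dict[str, Any]:
    """Return a copy of config with dotted-key overrides such as "training.checkpoint_dir"."""
    updated = copy.deepcopy(config)  # leave the originally loaded config untouched
    for dotted_key, value in overrides.items():
        *parents, leaf = dotted_key.split(".")
        node = updated
        for key in parents:
            if not isinstance(node.get(key), dict):
                raise KeyError(f"{dotted_key!r}: {key!r} is not a config section")
            node = node[key]
        if leaf not in node:
            raise KeyError(f"{dotted_key!r}: unknown setting {leaf!r}")
        node[leaf] = value
    return updated
```

For example, `config = with_overrides(config, {"training.checkpoint_dir": "checkpoints/run1"})` would redirect checkpoints for a single run.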

### Set up the logging level

In [None]:
import logging
logging.basicConfig(level=logging.INFO)
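`logging.basicConfig` does nothing if the root logger already has handlers, which can happen in some notebook environments, so the cell above may silently have no effect. A minimal sketch that forces the configuration (the `force` argument needs Python 3.8+) and quiets a few chatty third-party loggers is shown below; `setup_logging` is a hypothetical helper and the logger names in `quiet` are an assumed, adjustable choice.

```python
import logging
from typing import Iterable


def setup_logging(level: int = logging.INFO, quiet: Iterable[str] = ("urllib3", "filelock")) -> None:
    """(Re)configure the root logger and raise the level of noisy libraries."""
    logging.basicConfig(
        level=level,
        format="%(asctime)s %(levelname)s %(name)s: %(message)s",
        force=True,  # replace any existing root handlers instead of doing nothing
    )
    for name in quiet:
        # Assumed list of verbose third-party loggers; only WARNING and above get through
        logging.getLogger(name).setLevel(logging.WARNING)


setup_logging()
```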

### Start the training

In [None]:
import os

from neural_mesh_simplification.trainer.trainer import Trainer

trainer = Trainer(config)

try:
    trainer.train()
except Exception as e:
    # Hand the error to the trainer, then save the current training state to the checkpoint directory
    trainer.handle_error(e)
    trainer.save_training_state(os.path.join(config["training"]["checkpoint_dir"], "training_state.pth"))
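Stopping a notebook cell raises `KeyboardInterrupt`, which does not inherit from `Exception`, so the `except` clause above will not save anything when training is interrupted by hand. The wrapper below is a minimal sketch of one way to cover that case; `run_and_checkpoint_on_failure` is a hypothetical helper, and it creates the checkpoint directory as a precaution because the notebook does not say whether `save_training_state` does so itself.

```python
import os
from typing import Callable, Optional


def run_and_checkpoint_on_failure(
    train_fn: Callable[[], None],
    save_fn: Callable[[str], None],
    checkpoint_dir: str,
    filename: str = "training_state.pth",
    on_error: Optional[Callable[[BaseException], None]] = None,
) -> None:
    """Run train_fn; on an error or a manual interrupt, save state and re-raise."""
    try:
        train_fn()
    except (Exception, KeyboardInterrupt) as err:
        # KeyboardInterrupt is listed explicitly because it is not an Exception subclass
        if on_error is not None:
            on_error(err)
        os.makedirs(checkpoint_dir, exist_ok=True)  # the directory may not exist yet
        save_fn(os.path.join(checkpoint_dir, filename))
        raise  # keep the original traceback visible in the notebook
```

With the objects from the cell above, a call would look like `run_and_checkpoint_on_failure(trainer.train, trainer.save_training_state, config["training"]["checkpoint_dir"], on_error=trainer.handle_error)`. Unlike the original cell, it re-raises after saving, so a failed run still shows up as a failed cell.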