This repository contains the code to reproduce the experiments for the paper Learning to Solve Constraint Satisfaction Problems with Recurrent Transformer (ICLR 2023).
Please download the following dataset files from the given links and place them at the listed destinations.
| filename | description | from | to |
|---|---|---|---|
| palm_i2t_train.csv | RRN-V (train) | https://drive.google.com/file/d/1SCBkX_c2Xaxjvkx0P481G3-SnUGMZX_L/view?usp=sharing | data/visual_sudoku/palm_i2t_train.csv |
| features_img.pt | SATNet-V (input) | https://github.com/locuslab/SATNet#getting-the-datasets | data/satnet/features_img.pt |
| features.pt | SATNet (input) | same as above | data/satnet/features.pt |
| labels.pt | SATNet and SATNet-V (label) | same as above | data/satnet/labels.pt |
| perm.pt | cell permutation for SATNet and SATNet-V | same as above | data/satnet/perm.pt |
| train.csv | RRN (train) | https://www.dropbox.com/s/rp3hbjs91xiqdgc/sudoku-hard.zip?dl=1 | data/sudoku-hard/train.csv |
| valid.csv | RRN (valid) | same as above | data/sudoku-hard/valid.csv |
| test.csv | RRN (test) | same as above | data/sudoku-hard/test.csv |
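As an optional sanity check, a minimal sketch (paths taken from the table above; only the standard library is used) to confirm every file landed where the code expects it:

```python
# Optional: verify the downloaded files are at the expected destinations.
import os

expected = [
    "data/visual_sudoku/palm_i2t_train.csv",
    "data/satnet/features_img.pt",
    "data/satnet/features.pt",
    "data/satnet/labels.pt",
    "data/satnet/perm.pt",
    "data/sudoku-hard/train.csv",
    "data/sudoku-hard/valid.csv",
    "data/sudoku-hard/test.csv",
]
for path in expected:
    status = "OK" if os.path.exists(path) else "MISSING"
    print(f"{status:7s} {path}")
```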
- Install Anaconda according to its installation page.
- Create a new environment using the following commands in a terminal.

```
conda create -n rt python=3.7
conda activate rt
```

- Install tqdm, NumPy, pandas, matplotlib, and wandb.

```
conda install -c anaconda tqdm numpy pandas
conda install -c conda-forge matplotlib
python3 -m pip install wandb
wandb login
```

- Install PyTorch according to its Get Started page. Below is an example command we used on Linux with CUDA 10.2.

```
conda install pytorch torchvision torchaudio cudatoolkit=10.2 -c pytorch
```

or, for CPU only:

```
conda install pytorch torchvision torchaudio cpuonly -c pytorch
```
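As a quick optional check that the install worked, the following snippet prints the PyTorch version and whether a CUDA device is visible:

```python
# Optional check: verify the PyTorch install and CUDA visibility.
import torch

print("PyTorch version:", torch.__version__)
print("CUDA available:", torch.cuda.is_available())
if torch.cuda.is_available():
    print("Device 0:", torch.cuda.get_device_name(0))
```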
One can download the visual Sudoku training dataset `palm_i2t_train.csv` from the link provided above, or generate it with the following commands.

```
cd data
python visual_sudoku_data.py --problem_type i2t
```
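If you generate the file yourself, a quick sanity check that it loads (a sketch; it assumes only that the file is a standard CSV readable by pandas):

```python
# Quick look at the training data; assumes a standard CSV layout.
import pandas as pd

df = pd.read_csv("data/visual_sudoku/palm_i2t_train.csv")
print(df.shape)
print(df.head())
```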
Note that `all.sh` will take a long time to run, since it assumes a single GPU and runs every experiment `runs` times (with `runs` specified in line 4 of `all.sh`) using different random seeds in {1, 2, ..., `runs`}. One can comment out most of the script and run one experiment at a time with a different GPU index.

```
bash all.sh
```
- To run the baseline model L1R32H4 on 9k/1k train/test SATNet data on GPU 0:

```
cd sudoku
python main.py --all_layers --n_layer 1 --n_recur 32 --n_head 4 --epochs 200 --eval_interval 1 --lr 0.001 --dataset satnet --gpu 0
```
- To apply the constraint losses L_sudoku (`c1`) and L_attention (`att_c1`) with default weights `1` and `0.1` to the above baseline on GPU 1 (a conceptual sketch of such a constraint loss appears after these bullets):

```
cd sudoku
python main.py --all_layers --n_layer 1 --n_recur 32 --n_head 4 --epochs 200 --eval_interval 1 --lr 0.001 --dataset satnet --gpu 1 --loss c1 att_c1 --hyper 1 0.1
```
- One can also test on the Palm dataset by specifying `--dataset palm`, and/or use `--n_train 180000` to change the number of training examples from 9k (default) to 180k.
- One can always add `--wandb` to the command to visualize the results in wandb. This also applies to all experiments below.
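For intuition only, here is a conceptual sketch of a Sudoku-style cardinality constraint loss. It illustrates the general idea, not the repo's actual `c1` implementation: each digit should appear exactly once per row, so the loss pushes the summed per-digit probability mass along each row toward 1.

```python
# Conceptual illustration (not the repo's `c1` loss): softly encode the
# "each digit appears exactly once per row" Sudoku constraint by penalizing
# rows whose per-digit probability mass deviates from 1.
import torch

def row_uniqueness_loss(logits: torch.Tensor) -> torch.Tensor:
    """logits: (batch, rows, cols, digits) = (B, 9, 9, 9) raw cell scores."""
    probs = logits.softmax(dim=-1)           # per-cell distribution over digits
    digit_mass = probs.sum(dim=2)            # sum over columns -> (B, 9, 9)
    return ((digit_mass - 1.0) ** 2).mean()  # 0 iff each digit's mass is 1 per row

loss = row_uniqueness_loss(torch.randn(2, 9, 9, 9))  # toy usage
```

Analogous terms for columns and 3x3 boxes follow the same pattern with different reduction axes.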
- To run the baseline model L1R32H4 on 9k/1k train/test SATNet-V data on GPU 2:

```
cd visual_sudoku
python main.py --all_layers --n_layer 1 --n_recur 32 --n_head 4 --epochs 500 --eval_interval 1 --lr 0.001 --dataset satnet --gpu 2
```
- To apply the constraint losses L_sudoku (`c1`) and L_attention (`att_c1`) with default weights `1` and `0.1` to the above baseline on GPU 3:

```
cd visual_sudoku
python main.py --all_layers --n_layer 1 --n_recur 32 --n_head 4 --epochs 500 --eval_interval 1 --lr 0.001 --dataset satnet --gpu 3 --loss c1 att_c1 --hyper 1 0.1
```
- To run the 16x16 Sudoku experiments on the easy and medium datasets:

```
cd sudoku_16
python main.py --dataset easy
python main.py --dataset medium
```
- To run the shortest-path experiments on 4x4 and 12x12 grids, with and without the path constraint loss `--loss path`:

```
cd shortest_path
python main.py --gpu 0 --grid_size 4
python main.py --gpu 1 --grid_size 4 --loss path
python main.py --gpu 2 --grid_size 12
python main.py --gpu 3 --grid_size 12 --loss path
```
- To run the MNIST mapping experiment:

```
cd MNIST_mapping
python main.py
```
- To run the Nonogram experiments with game sizes 7 and 15:

```
cd nonogram
python main.py --game_size 7 --gpu 0
python main.py --game_size 15 --gpu 1
```
The GPT implementation is from Andrej Karpathy's minGPT repo. Note that we replaced the causal self-attention in the GPT model with standard (non-causal) self-attention by setting `causal_mask=False` wherever it is used. In this way, a logical variable X_i is able to attend to another logical variable X_j even when j > i.
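A minimal sketch of the idea (a simplified, single-head version in the spirit of minGPT, not the repo's exact code): with `causal_mask=False` the upper-triangular mask is skipped, so every position can attend to every other position.

```python
# Simplified single-head self-attention sketch (minGPT-style, not the
# repo's exact implementation). With causal_mask=False the triangular
# mask is skipped, so position i can attend to positions j > i.
import math
import torch
import torch.nn.functional as F

def self_attention(x: torch.Tensor, causal_mask: bool = False) -> torch.Tensor:
    """x: (batch, seq_len, d_model); learned q/k/v projections omitted for brevity."""
    _, T, C = x.shape
    att = x @ x.transpose(-2, -1) / math.sqrt(C)      # (B, T, T) attention scores
    if causal_mask:
        tril = torch.tril(torch.ones(T, T, dtype=torch.bool, device=x.device))
        att = att.masked_fill(~tril, float("-inf"))   # block attention to j > i
    att = F.softmax(att, dim=-1)                      # each row sums to 1
    return att @ x                                    # weighted sum of values
```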
Please cite our paper as:

```
@inproceedings{yang2023learning,
  title={Learning to Solve Constraint Satisfaction Problems with Recurrent Transformer},
  author={Zhun Yang and Adam Ishay and Joohyung Lee},
  booktitle={The Eleventh International Conference on Learning Representations},
  year={2023},
  url={https://openreview.net/forum?id=udNhDCr2KQe}
}
```