-
Notifications
You must be signed in to change notification settings - Fork 1
/
main_synthetic_equiv.py
97 lines (77 loc) · 3.11 KB
/
main_synthetic_equiv.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
import argparse
from torch.utils.data import DataLoader
from pytorch_lightning import Trainer, LightningDataModule
from reynet.datasets.synthetic import SyntheticDataset, EquivariantDataset
from reynet.models.synthetic_equiv import ReyNetModel, MaronModel, MLPModel
from reynet.utils import args_print
class EquivDataModule(LightningDataModule):
    """Lightning data module that serves ``EquivariantDataset`` splits.

    The training split is built with ``train=True``, the given ``seed`` and
    ``size=num_data``.  The validation and test loaders are identical: both
    build the ``train=False`` split with ``seed + 10`` and pass no explicit
    ``size``.

    Args:
        dataset: Dataset/task name forwarded to ``EquivariantDataset``.
        set_size: Set size forwarded to ``EquivariantDataset``.
        num_data: Number of training samples (passed as ``size``).
        batch_size: Batch size used by every loader.
        seed: Seed of the training split; held-out splits use ``seed + 10``.
        num_workers: Worker processes per ``DataLoader``.  Defaults to 5,
            the value that was previously hard-coded.
    """

    def __init__(self,
                 dataset: str,
                 set_size: int,
                 num_data: int,
                 batch_size: int = 32,
                 seed: int = 1234,
                 num_workers: int = 5):
        super().__init__()
        # Stores every constructor argument on ``self.hparams``.
        self.save_hyperparameters()

    def _heldout_dataloader(self) -> DataLoader:
        """Build the non-shuffled loader shared by validation and test."""
        heldout: SyntheticDataset = EquivariantDataset(
            self.hparams.dataset,
            self.hparams.set_size,
            train=False,
            # Offset so the held-out split is generated from a different
            # seed than the training split.
            seed=self.hparams.seed + 10)
        return DataLoader(heldout,
                          batch_size=self.hparams.batch_size,
                          shuffle=False,
                          num_workers=self.hparams.num_workers)

    def train_dataloader(self) -> DataLoader:
        """Return a shuffled loader over the training split."""
        trainset: SyntheticDataset = EquivariantDataset(
            self.hparams.dataset,
            self.hparams.set_size,
            train=True,
            seed=self.hparams.seed,
            size=self.hparams.num_data)
        return DataLoader(trainset,
                          batch_size=self.hparams.batch_size,
                          shuffle=True,
                          num_workers=self.hparams.num_workers)

    def val_dataloader(self) -> DataLoader:
        """Return the held-out loader used for validation."""
        return self._heldout_dataloader()

    def test_dataloader(self) -> DataLoader:
        """Return the held-out loader used for testing (same split as validation)."""
        return self._heldout_dataloader()
def main(args):
    """Train the model selected by ``args.model`` and test its best checkpoint.

    The model class is resolved here from ``args.model`` rather than read
    from a module-level global, so the function also works when the module
    is imported instead of executed as a script.

    Args:
        args: Parsed command-line namespace.  Must provide ``model``,
            ``dataset``, ``set_size``, ``num_data``, ``batch_size`` and
            ``seed``; the whole namespace is also handed to the model
            constructor and to ``Trainer.from_argparse_args``.

    Raises:
        ValueError: If ``args.model`` is not ``'reynet'``, ``'maron'`` or
            ``'mlp'``.
    """
    model_name = args.model
    if model_name == 'reynet':
        model_cls = ReyNetModel
    elif model_name == 'maron':
        model_cls = MaronModel
    elif model_name == 'mlp':
        model_cls = MLPModel
    else:
        raise ValueError(
            f"unknown model {model_name!r}; "
            "expected one of 'reynet', 'maron', 'mlp'")

    dm = EquivDataModule(args.dataset, args.set_size, args.num_data,
                         args.batch_size, args.seed)
    model = model_cls(args)
    print(model)

    # NOTE(review): ``Trainer.from_argparse_args`` is not available in every
    # PyTorch Lightning release -- confirm against the pinned version.
    trainer = Trainer.from_argparse_args(args)
    trainer.fit(model, datamodule=dm)
    # Evaluate with the best checkpoint recorded during ``fit``.
    trainer.test(datamodule=dm, ckpt_path="best")
if __name__ == '__main__':
    # Command-line name -> model class.
    model_classes = {
        'reynet': ReyNetModel,
        'maron': MaronModel,
        'mlp': MLPModel,
    }
    model_names = tuple(model_classes)

    # The model-specific options depend on --model, so that flag is parsed
    # first.  The pre-parser has help disabled: otherwise ``-h`` would be
    # handled by this first pass and the program would exit before the
    # remaining options are registered on the real parser.
    pre_parser = argparse.ArgumentParser(add_help=False)
    pre_parser.add_argument('--model', type=str, default='reynet',
                            choices=model_names)
    temp_args, _ = pre_parser.parse_known_args()

    # ``choices`` above rejects unknown names with a usage error, so this
    # lookup cannot fail.  The class stays bound to the module-level name
    # ``Model`` so code that looks it up globally keeps working.
    Model = model_classes[temp_args.model]

    parser = argparse.ArgumentParser()
    parser.add_argument('--model', type=str, default='reynet',
                        choices=model_names)
    parser.add_argument('--dataset', '-D', type=str, default='symmetry')
    parser.add_argument('--num-data', '-N', type=int, default=10000)
    parser.add_argument('--set-size', '-S', type=int, default=4)
    parser.add_argument('--seed', type=int, default=1234)
    parser.add_argument('--batch-size', type=int, default=100)
    # Let the chosen model and the Trainer register their own options.
    parser = Model.add_model_specific_args(parser)
    parser = Trainer.add_argparse_args(parser)

    args = parser.parse_args()
    args_print(args)
    main(args)