
Commit ee1488e

Update lightning_template model to save hparams. (Lightning-AI#2665)
1 parent a5538af commit ee1488e

1 file changed: +28 -38 lines changed

pl_examples/models/lightning_template.py

Lines changed: 28 additions & 38 deletions
@@ -24,51 +24,42 @@ class LightningTemplateModel(LightningModule):
 
     >>> # define simple Net for MNIST dataset
     >>> params = dict(
-    ...     drop_prob=0.2,
-    ...     batch_size=2,
     ...     in_features=28 * 28,
+    ...     hidden_dim=1000,
+    ...     out_features=10,
+    ...     drop_prob=0.2,
     ...     learning_rate=0.001 * 8,
-    ...     optimizer_name='adam',
+    ...     batch_size=2,
     ...     data_root='./datasets',
-    ...     out_features=10,
     ...     num_workers=4,
-    ...     hidden_dim=1000,
     ... )
     >>> model = LightningTemplateModel(**params)
     """
 
     def __init__(self,
-                 drop_prob: float = 0.2,
-                 batch_size: int = 2,
                  in_features: int = 28 * 28,
+                 hidden_dim: int = 1000,
+                 out_features: int = 10,
+                 drop_prob: float = 0.2,
                  learning_rate: float = 0.001 * 8,
-                 optimizer_name: str = 'adam',
+                 batch_size: int = 2,
                  data_root: str = './datasets',
-                 out_features: int = 10,
                  num_workers: int = 4,
-                 hidden_dim: int = 1000,
                  **kwargs
                  ):
         # init superclass
         super().__init__()
+        # save all variables in __init__ signature to self.hparams
+        self.save_hyperparameters()
 
-        self.num_workers = num_workers
-        self.drop_prob = drop_prob
-        self.batch_size = batch_size
-        self.in_features = in_features
-        self.learning_rate = learning_rate
-        self.optimizer_name = optimizer_name
-        self.data_root = data_root
-        self.out_features = out_features
-        self.hidden_dim = hidden_dim
+        self.c_d1 = nn.Linear(in_features=self.hparams.in_features,
+                              out_features=self.hparams.hidden_dim)
+        self.c_d1_bn = nn.BatchNorm1d(self.hparams.hidden_dim)
+        self.c_d1_drop = nn.Dropout(self.hparams.drop_prob)
 
-        self.c_d1 = nn.Linear(in_features=self.in_features,
-                              out_features=self.hidden_dim)
-        self.c_d1_bn = nn.BatchNorm1d(self.hidden_dim)
-        self.c_d1_drop = nn.Dropout(self.drop_prob)
+        self.c_d2 = nn.Linear(in_features=self.hparams.hidden_dim,
+                              out_features=self.hparams.out_features)
 
-        self.c_d2 = nn.Linear(in_features=self.hidden_dim,
-                              out_features=self.out_features)
 
         self.example_input_array = torch.zeros(2, 1, 28, 28)
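For readers skimming the diff, here is a minimal, self-contained sketch (not part of this commit) of the pattern the hunk above adopts: self.save_hyperparameters() records every __init__ argument under self.hparams, so the layers read their sizes from self.hparams instead of from individually assigned attributes. The class name TinyTemplate is invented for illustration; only torch and pytorch_lightning are assumed.

import torch
from torch import nn
from pytorch_lightning import LightningModule


class TinyTemplate(LightningModule):
    """Hypothetical cut-down template, for illustration only."""

    def __init__(self, in_features: int = 28 * 28, hidden_dim: int = 1000,
                 out_features: int = 10, learning_rate: float = 0.001 * 8, **kwargs):
        super().__init__()
        # collect every __init__ argument into self.hparams (and into checkpoints)
        self.save_hyperparameters()
        self.c_d1 = nn.Linear(self.hparams.in_features, self.hparams.hidden_dim)
        self.c_d2 = nn.Linear(self.hparams.hidden_dim, self.hparams.out_features)

    def forward(self, x):
        x = torch.relu(self.c_d1(x.view(x.size(0), -1)))
        return self.c_d2(x)

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=self.hparams.learning_rate)


model = TinyTemplate(hidden_dim=500)
print(model.hparams.hidden_dim)  # 500 -- the overriding value is what gets recorded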

@@ -140,27 +131,27 @@ def configure_optimizers(self):
         Return whatever optimizers and learning rate schedulers you want here.
         At least one optimizer is required.
         """
-        optimizer = optim.Adam(self.parameters(), lr=self.learning_rate)
+        optimizer = optim.Adam(self.parameters(), lr=self.hparams.learning_rate)
         scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=10)
         return [optimizer], [scheduler]
 
     def prepare_data(self):
-        MNIST(self.data_root, train=True, download=True, transform=transforms.ToTensor())
-        MNIST(self.data_root, train=False, download=True, transform=transforms.ToTensor())
+        MNIST(self.hparams.data_root, train=True, download=True, transform=transforms.ToTensor())
+        MNIST(self.hparams.data_root, train=False, download=True, transform=transforms.ToTensor())
 
     def setup(self, stage):
         transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5,), (1.0,))])
-        self.mnist_train = MNIST(self.data_root, train=True, download=False, transform=transform)
-        self.mnist_test = MNIST(self.data_root, train=False, download=False, transform=transform)
+        self.mnist_train = MNIST(self.hparams.data_root, train=True, download=False, transform=transform)
+        self.mnist_test = MNIST(self.hparams.data_root, train=False, download=False, transform=transform)
 
     def train_dataloader(self):
-        return DataLoader(self.mnist_train, batch_size=self.batch_size, num_workers=self.num_workers)
+        return DataLoader(self.mnist_train, batch_size=self.hparams.batch_size, num_workers=self.hparams.num_workers)
 
     def val_dataloader(self):
-        return DataLoader(self.mnist_test, batch_size=self.batch_size, num_workers=self.num_workers)
+        return DataLoader(self.mnist_test, batch_size=self.hparams.batch_size, num_workers=self.hparams.num_workers)
 
     def test_dataloader(self):
-        return DataLoader(self.mnist_test, batch_size=self.batch_size, num_workers=self.num_workers)
+        return DataLoader(self.mnist_test, batch_size=self.hparams.batch_size, num_workers=self.hparams.num_workers)
 
     @staticmethod
     def add_model_specific_args(parent_parser, root_dir):  # pragma: no-cover
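Because the optimizer and the dataloaders above now read from self.hparams, and save_hyperparameters() also writes those values into every checkpoint the Trainer produces, a trained model can be rebuilt without re-passing its constructor arguments. A hedged sketch of that round trip; the checkpoint path is hypothetical:

from pl_examples.models.lightning_template import LightningTemplateModel

# rebuild the model purely from the stored hyperparameters and weights
restored = LightningTemplateModel.load_from_checkpoint('path/to/example.ckpt')
print(restored.hparams.batch_size, restored.hparams.data_root)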
@@ -174,18 +165,17 @@ def add_model_specific_args(parent_parser, root_dir): # pragma: no-cover
 
         # network params
         parser.add_argument('--in_features', default=28 * 28, type=int)
-        parser.add_argument('--out_features', default=10, type=int)
-        # use 500 for CPU, 50000 for GPU to see speed difference
         parser.add_argument('--hidden_dim', default=50000, type=int)
+        # use 500 for CPU, 50000 for GPU to see speed difference
+        parser.add_argument('--out_features', default=10, type=int)
         parser.add_argument('--drop_prob', default=0.2, type=float)
-        parser.add_argument('--learning_rate', default=0.001, type=float)
-        parser.add_argument('--num_workers', default=4, type=int)
 
         # data
         parser.add_argument('--data_root', default=os.path.join(root_dir, 'mnist'), type=str)
+        parser.add_argument('--num_workers', default=4, type=int)
 
         # training params (opt)
         parser.add_argument('--epochs', default=20, type=int)
-        parser.add_argument('--optimizer_name', default='adam', type=str)
         parser.add_argument('--batch_size', default=64, type=int)
+        parser.add_argument('--learning_rate', default=0.001, type=float)
         return parser
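The reordered CLI flags still map one-to-one onto the __init__ signature, so a training script can forward the parsed namespace straight into the constructor and save_hyperparameters() picks the values up from there. A rough sketch of that wiring, assuming the pl_examples package is importable; the flag values are arbitrary:

import os
from argparse import ArgumentParser

from pl_examples.models.lightning_template import LightningTemplateModel

# build the model-specific parser; root_dir only feeds the --data_root default
root_dir = os.getcwd()
parent_parser = ArgumentParser(add_help=False)
parser = LightningTemplateModel.add_model_specific_args(parent_parser, root_dir)
args = parser.parse_args(['--hidden_dim', '500', '--batch_size', '32'])

# every parsed flag becomes a keyword argument; flags such as --epochs that are
# not in the signature are absorbed by **kwargs
model = LightningTemplateModel(**vars(args))
print(model.hparams.hidden_dim, model.hparams.learning_rate)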
