@@ -24,51 +24,42 @@ class LightningTemplateModel(LightningModule):
 
     >>> # define simple Net for MNIST dataset
     >>> params = dict(
-    ...     drop_prob=0.2,
-    ...     batch_size=2,
     ...     in_features=28 * 28,
+    ...     hidden_dim=1000,
+    ...     out_features=10,
+    ...     drop_prob=0.2,
     ...     learning_rate=0.001 * 8,
-    ...     optimizer_name='adam',
+    ...     batch_size=2,
     ...     data_root='./datasets',
-    ...     out_features=10,
     ...     num_workers=4,
     ... )
     >>> model = LightningTemplateModel(**params)
     """
 
     def __init__(self,
-                 drop_prob: float = 0.2,
-                 batch_size: int = 2,
                  in_features: int = 28 * 28,
+                 hidden_dim: int = 1000,
+                 out_features: int = 10,
+                 drop_prob: float = 0.2,
                  learning_rate: float = 0.001 * 8,
-                 optimizer_name: str = 'adam',
+                 batch_size: int = 2,
                  data_root: str = './datasets',
-                 out_features: int = 10,
                  num_workers: int = 4,
                  **kwargs
                  ):
         # init superclass
         super().__init__()
+        # save all variables in __init__ signature to self.hparams
+        self.save_hyperparameters()
 
-        self.num_workers = num_workers
-        self.drop_prob = drop_prob
-        self.batch_size = batch_size
-        self.in_features = in_features
-        self.learning_rate = learning_rate
-        self.optimizer_name = optimizer_name
-        self.data_root = data_root
-        self.out_features = out_features
-        self.hidden_dim = hidden_dim
+        self.c_d1 = nn.Linear(in_features=self.hparams.in_features,
+                              out_features=self.hparams.hidden_dim)
+        self.c_d1_bn = nn.BatchNorm1d(self.hparams.hidden_dim)
+        self.c_d1_drop = nn.Dropout(self.hparams.drop_prob)
 
-        self.c_d1 = nn.Linear(in_features=self.in_features,
-                              out_features=self.hidden_dim)
-        self.c_d1_bn = nn.BatchNorm1d(self.hidden_dim)
-        self.c_d1_drop = nn.Dropout(self.drop_prob)
+        self.c_d2 = nn.Linear(in_features=self.hparams.hidden_dim,
+                              out_features=self.hparams.out_features)
 
-        self.c_d2 = nn.Linear(in_features=self.hidden_dim,
-                              out_features=self.out_features)
 
         self.example_input_array = torch.zeros(2, 1, 28, 28)
 
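For readers skimming the hunk: `save_hyperparameters()` is the Lightning call that collects every argument in the `__init__` signature into `self.hparams`, which is why the layer definitions above can read `self.hparams.hidden_dim` without the manual `self.x = x` assignments this hunk deletes. A minimal sketch of that behavior, assuming a pytorch_lightning version that ships `save_hyperparameters()` (the `Demo` class is hypothetical, not part of this diff):

    from pytorch_lightning import LightningModule

    class Demo(LightningModule):
        def __init__(self, hidden_dim: int = 1000, drop_prob: float = 0.2):
            super().__init__()
            # collects hidden_dim and drop_prob into self.hparams
            self.save_hyperparameters()

    model = Demo(hidden_dim=500)
    print(model.hparams.hidden_dim)  # 500 (the value actually passed)
    print(model.hparams.drop_prob)   # 0.2 (the signature default)
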
@@ -140,27 +131,27 @@ def configure_optimizers(self):
         Return whatever optimizers and learning rate schedulers you want here.
         At least one optimizer is required.
         """
-        optimizer = optim.Adam(self.parameters(), lr=self.learning_rate)
+        optimizer = optim.Adam(self.parameters(), lr=self.hparams.learning_rate)
         scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=10)
         return [optimizer], [scheduler]
 
     def prepare_data(self):
-        MNIST(self.data_root, train=True, download=True, transform=transforms.ToTensor())
-        MNIST(self.data_root, train=False, download=True, transform=transforms.ToTensor())
+        MNIST(self.hparams.data_root, train=True, download=True, transform=transforms.ToTensor())
+        MNIST(self.hparams.data_root, train=False, download=True, transform=transforms.ToTensor())
 
     def setup(self, stage):
         transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5,), (1.0,))])
-        self.mnist_train = MNIST(self.data_root, train=True, download=False, transform=transform)
-        self.mnist_test = MNIST(self.data_root, train=False, download=False, transform=transform)
+        self.mnist_train = MNIST(self.hparams.data_root, train=True, download=False, transform=transform)
+        self.mnist_test = MNIST(self.hparams.data_root, train=False, download=False, transform=transform)
 
     def train_dataloader(self):
-        return DataLoader(self.mnist_train, batch_size=self.batch_size, num_workers=self.num_workers)
+        return DataLoader(self.mnist_train, batch_size=self.hparams.batch_size, num_workers=self.hparams.num_workers)
 
     def val_dataloader(self):
-        return DataLoader(self.mnist_test, batch_size=self.batch_size, num_workers=self.num_workers)
+        return DataLoader(self.mnist_test, batch_size=self.hparams.batch_size, num_workers=self.hparams.num_workers)
 
     def test_dataloader(self):
-        return DataLoader(self.mnist_test, batch_size=self.batch_size, num_workers=self.num_workers)
+        return DataLoader(self.mnist_test, batch_size=self.hparams.batch_size, num_workers=self.hparams.num_workers)
 
     @staticmethod
     def add_model_specific_args(parent_parser, root_dir):  # pragma: no-cover
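Because the optimizer and all three dataloaders now read from `self.hparams`, the values captured by `save_hyperparameters()` are also written into checkpoints and restored automatically. A hedged usage sketch (the checkpoint path is hypothetical, and this restore behavior assumes a Lightning version where `save_hyperparameters()` is available):

    # Hyperparameters are read back from the checkpoint file, so no
    # constructor arguments are needed; 'example.ckpt' is illustrative only.
    model = LightningTemplateModel.load_from_checkpoint('example.ckpt')
    print(model.hparams.learning_rate)  # the value the checkpoint was trained with
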
@@ -174,18 +165,17 @@ def add_model_specific_args(parent_parser, root_dir): # pragma: no-cover
 
         # network params
         parser.add_argument('--in_features', default=28 * 28, type=int)
-        parser.add_argument('--out_features', default=10, type=int)
-        # use 500 for CPU, 50000 for GPU to see speed difference
         parser.add_argument('--hidden_dim', default=50000, type=int)
+        # use 500 for CPU, 50000 for GPU to see speed difference
+        parser.add_argument('--out_features', default=10, type=int)
         parser.add_argument('--drop_prob', default=0.2, type=float)
-        parser.add_argument('--learning_rate', default=0.001, type=float)
-        parser.add_argument('--num_workers', default=4, type=int)
 
         # data
         parser.add_argument('--data_root', default=os.path.join(root_dir, 'mnist'), type=str)
+        parser.add_argument('--num_workers', default=4, type=int)
 
         # training params (opt)
         parser.add_argument('--epochs', default=20, type=int)
-        parser.add_argument('--optimizer_name', default='adam', type=str)
         parser.add_argument('--batch_size', default=64, type=int)
+        parser.add_argument('--learning_rate', default=0.001, type=float)
         return parser
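For context, a sketch of how `add_model_specific_args` is typically wired into a training script (the script itself is outside this diff, so the exact wiring is an assumption):

    import os
    from argparse import ArgumentParser

    root_dir = os.path.dirname(os.path.realpath(__file__))
    parent_parser = ArgumentParser(add_help=False)
    parser = LightningTemplateModel.add_model_specific_args(parent_parser, root_dir)
    hyperparams = parser.parse_args()

    # extra flags such as --epochs are absorbed by **kwargs in __init__
    model = LightningTemplateModel(**vars(hyperparams))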