In [2]:
import torch
import pytorch_lightning as pl

from replicator.models.replicator import ReplicatorGPT
from replicator.datasets.torchtext_wikitext2 import WikiText2DataModule


In [3]:
batch_size = 16
seq_len = 64

data = WikiText2DataModule(batch_size=batch_size, seq_len=seq_len)
data.setup()
data_loader = data.train_dataloader()

36718
3760


In [4]:
inputs, targets, masks = next(iter(data_loader))
inputs.shape

torch.Size([16, 64])

In [5]:
vocab_size = data.train_data.vocab_size()
vocab_size

28783

In [6]:
embedding_size = 128
blocks_num = 4
model = ReplicatorGPT(blocks_num=blocks_num,max_sentence_len=seq_len,
                    vocab_size=vocab_size,embedding_size=embedding_size)

In [7]:
trainer = pl.Trainer(gpus=1)

GPU available: True, used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs


In [8]:
lr_finder = trainer.tuner.lr_find(model, datamodule=data)

  rank_zero_deprecation(
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name                  | Type                 | Params
---------------------------------------------------------------
0 | replicator_blocks     | Sequential           | 81.9 K
1 | stochastic_projection | StochasticProjection | 3.7 M 
2 | embedding             | Embedding            | 3.7 M 
3 | softmax               | Softmax              | 0     
---------------------------------------------------------------
7.5 M     Trainable params
0         Non-trainable params
7.5 M     Total params
29.801    Total estimated model params size (MB)
  rank_zero_warn(
  rank_zero_warn(
Finding best initial lr:  99%|█████████▉| 99/100 [00:06<00:00, 15.28it/s]Restoring states from the checkpoint file at d:\replicator\lr_find_temp_model.ckpt
Restored all states from the checkpoint file at d:\replicator\lr_find_temp_model.ckpt


In [9]:
lr_finder.results

{'lr': [1e-08,
  1.4454397707459274e-08,
  1.7378008287493753e-08,
  2.0892961308540398e-08,
  2.51188643150958e-08,
  3.019951720402016e-08,
  3.630780547701014e-08,
  4.36515832240166e-08,
  5.248074602497726e-08,
  6.309573444801934e-08,
  7.585775750291837e-08,
  9.120108393559096e-08,
  1.0964781961431852e-07,
  1.3182567385564074e-07,
  1.5848931924611133e-07,
  1.9054607179632475e-07,
  2.2908676527677735e-07,
  2.7542287033381663e-07,
  3.311311214825911e-07,
  3.9810717055349735e-07,
  4.786300923226383e-07,
  5.75439937337157e-07,
  6.918309709189366e-07,
  8.317637711026709e-07,
  1e-06,
  1.2022644346174132e-06,
  1.445439770745928e-06,
  1.7378008287493761e-06,
  2.089296130854039e-06,
  2.5118864315095797e-06,
  3.0199517204020163e-06,
  3.630780547701014e-06,
  4.365158322401661e-06,
  5.248074602497728e-06,
  6.3095734448019305e-06,
  7.585775750291836e-06,
  9.120108393559096e-06,
  1.0964781961431852e-05,
  1.3182567385564076e-05,
  1.584893192461114e-05,
  1.90546071

In [10]:
lr_finder.suggestion()

0.8317637711026709

In [None]:
fig = lr_finder.plot(suggest=True, show=True)
fig.show()

Error: Session cannot generate requests