In [1]:
%matplotlib inline
%config InlineBackend.figure_format = 'retina'
from TModel.TransformerModel import TranscriptionTransformerModel
from Tokenizer.loaderH5 import H5GuitarTokenizer
from lightning.pytorch.callbacks import ModelCheckpoint
from torch.utils.data import DataLoader
from lightning.pytorch.loggers import WandbLogger
from TUtils import random_string
import lightning.pytorch as pl
from lightning import Trainer
import torch
from TranscriptionDataset import TranscriptionDataset
from TModel.Retnet.TranscriptionModel import TranscriptionRetnetModel
torch.set_float32_matmul_precision('medium')
import importlib
if importlib.util.find_spec('deepspeed'):
    from lightning.pytorch.strategies import DeepSpeedStrategy
    import deepspeed
    from deepspeed.ops.adam import DeepSpeedCPUAdam, FusedAdam

[2023-09-29 15:05:57,218] [INFO] [real_accelerator.py:158:get_accelerator] Setting ds_accelerator to cuda (auto detect)


In [2]:
datasetLocation = "Trainsets/S_Tier_1695619803_mTokens400_mNoS5.hdf5"
wandbProject = "TranscriptionModel_Test"
batchSize = 4
num_workers = 2

In [3]:
dataset,pipe = TranscriptionDataset.getDataPipe(
    datasetLocation,
    batchSize,
    batchFirst=True
)
train_pipe,test_pipe = pipe.random_split({"train":0.8,"test":0.2},42,total_length=len(dataset))

train_dl = DataLoader(dataset=train_pipe,batch_size=None,num_workers=num_workers)
test_dl = DataLoader(dataset=test_pipe, batch_size=batchSize,num_workers=num_workers)



In [4]:
model = TranscriptionRetnetModel(
    dataset.getVocabSize(),
    d_model=512,
    d_ff=2048,
    lr_init=1e-6
    # embeddingCheckpoint="Models/GuitarToken/Max2Length.ckpt"
)
try:
    torch.compile(model)
except Error:
    print("Could not compile model with jit")

In [5]:
wandb_logger = WandbLogger(project=wandbProject)
wandb_logger.experiment.config.update(dataset.meta_data)
wandb_logger.experiment.config["batchSize"] = batchSize

checkpoint_callback = ModelCheckpoint(
    monitor='train_loss',
    dirpath=f'Models/GuitarTranscription/{random_string(10)}/',
    filename='GuitarTranscriptionModel-{epoch:02d}-{train_loss:.2f}',
    every_n_train_steps=1000,
    save_top_k=3,
    mode='min',
)

# strategy=DeepSpeedStrategy(offload_optimizer=True, allgather_bucket_size=5e8, reduce_bucket_size=5e8)

trainer = Trainer(
    default_root_dir='Models/',
    max_epochs=10,
    # profiler="simple",
    # strategy=strategy,
    # profiler="pytorch",
    logger=wandb_logger,
    callbacks=[checkpoint_callback],
    # max_time="00:00:05:00",
    precision="bf16-mixed",
)

Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Currently logged in as: [33mdraguve[0m. Use [1m`wandb login --relogin`[0m to force relogin


Using bfloat16 Automatic Mixed Precision (AMP)
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs


In [6]:
trainer.fit(model=model, train_dataloaders=train_dl)

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
Using /home/draguve/.cache/torch_extensions/py310_cu117 as PyTorch extensions root...
Detected CUDA files, patching ldflags
Emitting ninja build file /home/draguve/.cache/torch_extensions/py310_cu117/fused_adam/build.ninja...
Building extension module fused_adam...
Allowing ninja to set a default number of workers... (overridable by setting the environment variable MAX_JOBS=N)
Loading extension module fused_adam...

  | Name          | Type                | Params
------------------------------------------------------
0 | encoder       | RetnetEncoderLayers | 20.5 M
1 | decoder       | RetnetDecoderLayers | 28.4 M
2 | tgt_embedding | Embedding           | 3.1 M 
3 | outputLinear  | Linear              | 3.1 M 
4 | loss          | CrossEntropyLoss    | 0     
------------------------------------------------------
55.1 M    Trainable params
144       Non-trainable params
55.1 M    Total params
220.496   Total estimated model params size (MB)
  ra

ninja: no work to do.
Time to load fused_adam op: 0.06789088249206543 seconds
Epoch 9: : 9272it [22:17,  6.93it/s, v_num=btq8, train_loss=0.644]   

`Trainer.fit` stopped: `max_epochs=10` reached.


Epoch 9: : 9272it [22:17,  6.93it/s, v_num=btq8, train_loss=0.644]


In [None]:
trainer.save_checkpoint("Models/GuitarTranscription/5s400Tokens_1e-5/smallDataset4epochRetnetJamie.ckpt")

In [None]:
for i in tqdm(): 