website: [LIGHTNING IN 15 MINUTES](https://lightning.ai/docs/pytorch/stable/starter/introduction.html)

# 1. Install

In [3]:
# !conda install lightning -c conda-forge

In [None]:
import torch
import os
from pathlib import Path

# 2. Define the LightningModule

In [None]:
import lightning as pl

class Pl_eg_module(pl.LightningModule):
    """Template LightningModule showing where each piece of a training setup goes.

    Args:
        hyper_param: Free-form hyperparameter object kept on ``self.hyper_param``
            (optional, as in the original template).
        lr: Learning rate handed to Adam in ``configure_optimizers``.
        weight_decay: Weight decay handed to Adam in ``configure_optimizers``.
    """

    def __init__(self, hyper_param=None, lr=1e-3, weight_decay=0.0):
        # Must be `super().__init__()`: calling `super.__init__()` on the bare
        # `super` type raises TypeError before the module is initialised.
        super().__init__()
        # Stores the __init__ arguments in self.hparams and in every checkpoint,
        # so `Pl_eg_module.load_from_checkpoint(path)` can rebuild the module.
        self.save_hyperparameters()

        self.hyper_param = hyper_param
        # Read by configure_optimizers().
        self.lr = lr
        self.weight_decay = weight_decay

        # Kept on `self` so other hooks can use them (as plain locals they were
        # discarded as soon as __init__ returned).
        self.metric = lambda x: x   # placeholder metric (identity)
        self.loss_hist = []         # placeholder container, e.g. for per-epoch losses
        self.model_layers = None    # placeholder: replace with real layers

    def forward(self, x):
        """Run the network on ``x``; ``self(x)`` dispatches here."""
        if self.model_layers is None:
            raise NotImplementedError(
                "Pl_eg_module is a template: assign self.model_layers in __init__"
            )
        return self.model_layers(x)

    def training_step(self, batch, batch_idx):
        """Define the work done on one training batch.

        Return the loss tensor; Lightning then runs backward, optimizer.step()
        and zero_grad() for you. Returning None (as this empty template does)
        makes Lightning skip the optimizer step for that batch.
        """
        # load batch:     inputs, targets = batch
        # forward:        output = self(inputs)   # self(...) calls forward()
        # compute loss:   loss = torch.nn.functional.mse_loss(output, targets)
        # log it:         self.log("train_loss", loss, prog_bar=True)
        # return loss

    def on_train_epoch_end(self):
        """Optional hook run after every training epoch.

        It returns nothing, so no return statement is needed.
        """

    # Other optional hooks follow the same pattern: validation_step, test_step,
    # on_validation_epoch_end, etc.

    def configure_optimizers(self):
        """Return the optimizer Lightning uses for automatic optimisation."""
        # Adam over all registered parameters. While the template has no layers,
        # the parameter list is empty and torch raises ValueError here.
        return torch.optim.Adam(self.parameters(), lr=self.lr, weight_decay=self.weight_decay)




# 3. Data preparation


In [None]:
from torch.utils.data import Dataset, DataLoader

class pl_eg_Dataset(Dataset):
    """Template map-style Dataset that yields ``(X[idx], y[idx])`` pairs.

    Args:
        data_path: Optional path of the source file (``pl_eg_DataModule.setup``
            passes a FASTA path). Stored as ``self.data_path``; parsing it into
            ``X`` / ``y`` is left to the user.
        X: Optional indexable of inputs.
        y: Optional indexable of targets, aligned with ``X``.
    """

    def __init__(self, data_path=None, X=None, y=None):
        # `super().__init__()`, not `super.__init__()` (the latter raises TypeError).
        super().__init__()

        self.data_path = data_path
        self.X = X
        self.y = y

    def __len__(self):
        return len(self.X)

    def __getitem__(self, idx):
        return self.X[idx], self.y[idx]

class pl_eg_DataModule(pl.LightningDataModule):
    """Template LightningDataModule bundling the train/val/test splits of one fold.

    Reads ``<data_dir>/fold<fold>_{train,val,test}.fasta``.

    Args:
        data_dir: Directory holding the split files.
        batch_size: Batch size used by every DataLoader.
        num_workers: Number of DataLoader worker processes.
        fold: Cross-validation fold whose files are loaded (default 0, the fold
            the original template hard-coded).
    """

    def __init__(self, data_dir, batch_size=64, num_workers=24, fold=0):
        super().__init__()

        self.data_dir = data_dir
        self.batch_size = batch_size
        self.num_workers = num_workers
        self.fold = fold

        # Pre-initialised; populated by setup().
        self.train_dataset = None
        self.val_dataset = None
        self.test_dataset = None

    def _split_path(self, split):
        """Return the FASTA path of ``split`` ('train', 'val' or 'test') for this fold."""
        return Path(self.data_dir, f"fold{self.fold}_{split}.fasta")

    def setup(self, stage=None):
        """Build the datasets.

        Args:
            stage: Stage name Lightning passes ('fit', 'validate', 'test' or
                'predict'); None when called by hand. All splits are built
                regardless of the stage.
        """
        self.train_dataset = pl_eg_Dataset(self._split_path("train"))
        # The val / test splits are optional; drop them if your data has none.
        self.val_dataset = pl_eg_Dataset(self._split_path("val"))
        self.test_dataset = pl_eg_Dataset(self._split_path("test"))

    def _make_loader(self, dataset, shuffle):
        """Wrap ``dataset`` in a DataLoader using this module's batch/worker settings."""
        return DataLoader(dataset, batch_size=self.batch_size, shuffle=shuffle, num_workers=self.num_workers)

    def train_dataloader(self):
        # Only the training split is shuffled; evaluation order stays fixed.
        return self._make_loader(self.train_dataset, shuffle=True)

    def val_dataloader(self):
        return self._make_loader(self.val_dataset, shuffle=False)

    def test_dataloader(self):
        return self._make_loader(self.test_dataset, shuffle=False)
    

# Instantiate the data module over ./data (default batch_size and num_workers).
data_module = pl_eg_DataModule(
    data_dir='./data',
)


# 4. Train the model

Under the hood, `trainer.fit(model, datamodule=data_module)` is roughly equivalent to:

```
model = Pl_eg_module(lr=1e-4, weight_decay=1e-4)
optimizer = model.configure_optimizers()

model.train()
for epoch in range(max_epochs):
    for batch_idx, batch in enumerate(data_module.train_dataloader()):
        loss = model.training_step(batch, batch_idx)

        loss.backward()
        optimizer.step()
        optimizer.zero_grad()
```

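Notes:

1. The cell below shows commonly used `pl.Trainer()` parameters:
   - `max_epochs`
   - `default_root_dir` (where `lightning_logs/` is written)
   - `callbacks` (e.g. checkpointing)
   - `accelerator` / `devices` / `strategy` (hardware selection)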


In [None]:
from lightning.pytorch.callbacks import ModelCheckpoint

# Output locations; adjust to your project layout. These were previously
# referenced without being defined.
results_dir = Path('./results')
model_dir = results_dir / 'checkpoints'

model = Pl_eg_module(lr=1e-4, weight_decay=1e-4)

# Save a checkpoint every epoch (save_top_k=-1 keeps all of them).
checkpoint_callback = ModelCheckpoint(dirpath=model_dir, filename='{epoch:02d}-{val_loss:.2f}', every_n_epochs=1, save_top_k=-1)
# Keep only the best epoch instead (requires logging 'val_loss' via self.log in validation_step):
# checkpoint_callback = ModelCheckpoint(dirpath=model_dir, monitor='val_loss', filename='{epoch:02d}-{val_loss:.2f}', every_n_epochs=1, save_top_k=1)

trainer = pl.Trainer(max_epochs=59,
                     default_root_dir=results_dir,         # where lightning_logs/ is written
                     callbacks=[checkpoint_callback],      # checkpoint saving
                     accelerator='auto', devices='auto', strategy='auto')  # use available GPU(s) automatically

# Actually run the training loop described above.
trainer.fit(model, datamodule=data_module)

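# 5. Test the model after training

## Basics
1. Use `trainer.test()`. It is independent of `trainer.fit()` and can be called with or without training first.
1. It runs one pass over the test data using `test_step()`.
1. Note: test on a single device, e.g. `Trainer(devices=1)`. Under DDP, the distributed sampler may duplicate samples so that every process gets the same batch count, which skews the metrics.

## Parameters of `test()`
`Trainer.test(model=None, dataloaders=None, ckpt_path=None, verbose=True, datamodule=None)`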



In [None]:
# trainer.test(model, dataloaders=DataLoader(test_set))

# '''test a model from ckpt path'''
# trainer.test(ckpt_path="/path/to/my_checkpoint.ckpt")
# # can also do ckpt_path="best", ckpt_path="last"; automatically tracked in the trainer

# '''test a model by giving it a model'''
# model1 = LitModel()
# model2 = MyLightningModule.load_from_checkpoint(checkpoint_path="/path/to/pytorch_checkpoint.ckpt",
#                                                 hparams_file="/path/to/experiment/version/hparams.yaml", map_location=None)

# trainer.test(model1)
# trainer.test(model2)

# '''specify dataloaders'''
# # The default dataloader is the test dataloader specified in the LightningModule. But can overwrite (or define from scratch) from here

# test_dataloader = DataLoader(...)
# trainer.test(dataloaders=test_dataloader)

# --- tested OK ---
# Any pl.LightningDataModule that defines test_dataloader() works here.
data_module = pl_eg_DataModule(data_dir='./data')

# Path of a trained checkpoint: replace this placeholder (for example with
# checkpoint_callback.best_model_path after training).
model_path = "/path/to/checkpoint.ckpt"
# Use your own LightningModule class here (e.g. siteLevelModel).
model = Pl_eg_module.load_from_checkpoint(checkpoint_path=model_path, map_location=None)

# devices=1: under DDP the distributed sampler can duplicate test samples.
trainer = pl.Trainer(deterministic=False,
                     devices=1,
                     # logger=wandb_logger,  # optional: a WandbLogger created as in the wandb section
                     default_root_dir=results_dir,
                     accelerator='auto', strategy='auto')

# Test the model that was just loaded (this previously passed an undefined
# `pretrained_model`).
trainer.test(model=model, datamodule=data_module)

# 6. Load and use

In [None]:
checkpoint = "./results/epoch=0-step=100.ckpt"
# Rebuilds the module from the hyperparameters stored in the checkpoint
# (requires self.save_hyperparameters() in the module's __init__).
model = Pl_eg_module.load_from_checkpoint(checkpoint)

# eval() puts dropout/batch-norm into inference mode; no_grad() skips building
# the autograd graph, which prediction does not need.
model.eval()
with torch.no_grad():
    y_hat = model(x)  # `x` is your input batch; it is not defined in this notebook

# logger

## general notes

1. Save directory: the default is ```getcwd()```. To specify one: ```Trainer(default_root_dir="/your/path/to/save/checkpoints")```
1. Log-writing frequency \
By default, Lightning writes logs every 50 training steps. To change it: \
```trainer = Trainer(log_every_n_steps=100)```
1. Resuming \
Lightning ckpt is compatible with torch.load. \
Resume training state, simply:```trainer.fit(model, ckpt_path="some/path/to/my_checkpoint.ckpt")```

### self.log()
1. There are two basic logging methods, `log()` and `log_dict()`. Both can be used anywhere in LightningModules and callbacks. \
1. usage: \
log(): ```self.log("my_metric", x)``` \
log_dict(): ```self.log_dict({"acc": acc, "recall": recall})```

1. Arguments (apply to both methods): \
```prog_bar```: Logs to the progress bar (Default: False). \
```logger```: Logs to the logger like Tensorboard, or any other custom logger passed to the Trainer (Default: True). \
```on_step```: Logs the metric at the current step. \
```on_epoch```: Automatically accumulates and logs at the end of the epoch.

### saving hyper parameters

Hyperparameters are the arguments passed to ```__init__()```.
Call ```self.save_hyperparameters()``` inside ```__init__()```; the hyperparameters are then stored with the checkpoint.

1. to only save some: \
self.save_hyperparameters("layer_1_dim")

1. to exclude some: \
self.save_hyperparameters(ignore=["loss_fx", "generator_network"])

1. When loading with `load_from_checkpoint()`, saved hyperparameters are restored automatically, so you don't need to pass them again. To override one, pass it as a keyword argument, e.g. `Model.load_from_checkpoint(path, lr=1e-3)`.

## default logger: ```TensorBoard```

1. stores the logs to a directory (by default in lightning_logs/)
1. visualize 
```
%reload_ext tensorboard
%tensorboard --logdir=/data2/mqyu/work/HeLa_cross_tissue_test_wb/1_train_on_HeLa_backup/lightning_logs
```

## wandb (weight and bias)
link for wandb: click [here](https://docs.wandb.ai/guides/integrations/lightning/)

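1. Create an account and log in (e.g. run `wandb login`).

1. Create a `WandbLogger` before building the trainer and pass it as `logger=`. Call `wandb.finish()` after training (see below). Inside the LightningModule, logging works as with any other logger (`self.log(...)`).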

1. Go to the website at https://wandb.ai/leoyu20220822/projects to see project.

In [None]:
import wandb
# Import from the same `lightning` package the Trainer comes from
# (`import lightning as pl`); don't mix objects from `pytorch_lightning` and `lightning`.
from lightning.pytorch.loggers import WandbLogger

# Create the logger before building the trainer.
wandb_logger = WandbLogger(project='9-25-1')

# Record hyperparameters in the run config (shown in the wandb UI).
wandb_logger.experiment.config["batch_size"] = data_module.batch_size

# Train; self.log(...) calls in the LightningModule now go to wandb.
trainer = pl.Trainer(limit_train_batches=750, max_epochs=5, logger=wandb_logger)
trainer.fit(model=model, datamodule=data_module)

# Close the run explicitly; needed in notebooks, optional in scripts.
wandb.finish()

# device related

1. to check the device of current model: ```print(self.device)```

1. Do not call `.to(device)` or `.cuda()` on the model or on batches; Lightning moves them to the right device automatically.

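1. When creating a tensor outside ```__init__()``` (e.g. inside `training_step`), put it on the module's device: ```torch.zeros(3, device=self.device)```. Alternatively, match an existing tensor with ```new_tensor = new_tensor.to(an_old_tensor)```. Note that this copies both the device **and the dtype** of `an_old_tensor`; use ```.to(an_old_tensor.device)``` to change only the device.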

1. If a tensor is created in ```__init__()```, register it as a buffer so it moves with the module and is saved in the `state_dict`: ```self.register_buffer('pos_weight', torch.ones([2, 3]))```

