In [9]:
# !pip install pytorch-lightning

In [40]:
import torch
import pytorch_lightning as p
from typing import Tuple
import torchmetrics

In [64]:
class Model(p.LightningModule):
    def __init__(self):
        super().__init__()
#       metrics
        self.t_exv=torchmetrics.ExplainedVariance()
        self.v_exv=torchmetrics.ExplainedVariance()
        self.model=torch.nn.Linear(1,1)
        
    def forward(self,x):
        return self.model(x)
    
    def training_step(self,batch:Tuple[torch.Tensor,torch.Tensor],batch_idx:int) -> torch.Tensor:
        xs,ys=batch
        outs=self(xs)
        loss=torch.nn.functional.mse_loss(outs,ys)
        self.t_exv(outs,ys)
        self.log('t_exv',self.t_exv,prog_bar=True)
        self.log("loss",loss)  # for saving the loss for accessing it later
        return loss
    
    def validation_step(self, batch: Tuple[torch.Tensor, torch.Tensor], batch_idx: int) -> torch.Tensor:
        x,y=batch
        y_out=self(x)
        loss = torch.nn.functional.mse_loss(y_out, y) 
        self.v_exv(y_out,y)
        self.log('v_exv',self.v_exv,prog_bar=True)
        self.log("val_loss", loss)
    
    def configure_optimizers(self)-> torch.optim.Optimizer:
        optimizer=torch.optim.Adam(self.parameters(),lr=3e-4)
        return optimizer

In [65]:
class LoggingCallback(p.Callback):
    def on_train_epoch_end(self,trainer:p.Trainer,pl_module: p.LightningModule):
        epoch = trainer.current_epoch
        logs=trainer.callback_metrics

        loss = logs.get('loss')
        exv = logs.get('t_exv')
        # Perform logging or any other operations

        print(f"Epoch {epoch} - Training Loss: {loss} - EXV: {exv}")
    
    def on_validation_epoch_end(self, trainer: p.Trainer, pl_module: p.LightningModule) -> None:
      
        epoch = trainer.current_epoch
        logs=trainer.callback_metrics

        exv=logs.get('v_exv')
        loss = logs.get('val_loss')

        print(f"Epoch {epoch} - V Loss: {loss}  - EXV: {exv}")
        
    

In [66]:

class MyDataModule(LightningDataModule):
    def __init__(self):
        super().__init__()
        self.x_train = torch.linspace(0, 1, 100)
        self.y_train = self.x_train * 2 + 1
        self.x_val = torch.linspace(0, 1, 20)
        self.y_val = self.x_val * 2 + 1

    def train_dataloader(self):
        return torch.utils.data.DataLoader(
            torch.utils.data.TensorDataset(self.x_train, self.y_train)
        )

    def val_dataloader(self):
        return torch.utils.data.DataLoader(
             torch.utils.data.TensorDataset(self.x_val, self.y_val)
        )


In [67]:


model = Model()
data_module = MyDataModule()

#     checkpoint_callback = ModelCheckpoint(monitor='val_accuracy', save_top_k=1, mode='max')
#     early_stopping_callback = EarlyStopping(monitor='val_loss', patience=5)

trainer = p.Trainer(
        callbacks=[LoggingCallback()],
        max_epochs=20
    )

trainer.fit(model, data_module)

GPU available: False, used: False
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs

  | Name  | Type              | Params
--------------------------------------------
0 | t_exv | ExplainedVariance | 0     
1 | v_exv | ExplainedVariance | 0     
2 | model | Linear            | 2     
--------------------------------------------
2         Trainable params
0         Non-trainable params
2         Total params
0.000     Total estimated model params size (MB)


Sanity Checking: 0it [00:00, ?it/s]

Epoch 0 - V Loss: 1.970310926437378  - EXV: -0.1868652105331421


  rank_zero_warn(
  rank_zero_warn(


Training: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Epoch 0 - V Loss: 6.080156326293945  - EXV: -0.14163994789123535
Epoch 0 - Training Loss: 11.891194343566895 - EXV: 1.0


Validation: 0it [00:00, ?it/s]

Epoch 1 - V Loss: 5.866881370544434  - EXV: -0.11386752128601074
Epoch 1 - Training Loss: 11.50711441040039 - EXV: 1.0


Validation: 0it [00:00, ?it/s]

Epoch 2 - V Loss: 5.664691925048828  - EXV: -0.08776199817657471
Epoch 2 - Training Loss: 11.143440246582031 - EXV: 1.0


Validation: 0it [00:00, ?it/s]

Epoch 3 - V Loss: 5.469810962677002  - EXV: -0.0626065731048584
Epoch 3 - Training Loss: 10.792717933654785 - EXV: 1.0


Validation: 0it [00:00, ?it/s]

Epoch 4 - V Loss: 5.281004428863525  - EXV: -0.03815639019012451
Epoch 4 - Training Loss: 10.452550888061523 - EXV: 1.0


Validation: 0it [00:00, ?it/s]

Epoch 5 - V Loss: 5.097670555114746  - EXV: -0.014310717582702637
Epoch 5 - Training Loss: 10.121782302856445 - EXV: 1.0


Validation: 0it [00:00, ?it/s]

Epoch 6 - V Loss: 4.919440269470215  - EXV: 0.0089913010597229
Epoch 6 - Training Loss: 9.79971981048584 - EXV: 1.0


Validation: 0it [00:00, ?it/s]

Epoch 7 - V Loss: 4.7460551261901855  - EXV: 0.031788647174835205
Epoch 7 - Training Loss: 9.485885620117188 - EXV: 1.0


Validation: 0it [00:00, ?it/s]

Epoch 8 - V Loss: 4.577317237854004  - EXV: 0.05411481857299805
Epoch 8 - Training Loss: 9.179922103881836 - EXV: 1.0


Validation: 0it [00:00, ?it/s]

Epoch 9 - V Loss: 4.4130635261535645  - EXV: 0.07598280906677246
Epoch 9 - Training Loss: 8.881532669067383 - EXV: 1.0


Validation: 0it [00:00, ?it/s]

Epoch 10 - V Loss: 4.25315523147583  - EXV: 0.09741079807281494
Epoch 10 - Training Loss: 8.590476036071777 - EXV: 1.0


Validation: 0it [00:00, ?it/s]

Epoch 11 - V Loss: 4.097468376159668  - EXV: 0.11842334270477295
Epoch 11 - Training Loss: 8.306530952453613 - EXV: 1.0


Validation: 0it [00:00, ?it/s]

Epoch 12 - V Loss: 3.9458911418914795  - EXV: 0.13902491331100464
Epoch 12 - Training Loss: 8.029509544372559 - EXV: 1.0


Validation: 0it [00:00, ?it/s]

Epoch 13 - V Loss: 3.7983238697052  - EXV: 0.1592336893081665
Epoch 13 - Training Loss: 7.759232521057129 - EXV: 1.0


Validation: 0it [00:00, ?it/s]

Epoch 14 - V Loss: 3.654670000076294  - EXV: 0.1790522336959839
Epoch 14 - Training Loss: 7.495540142059326 - EXV: 1.0


Validation: 0it [00:00, ?it/s]

Epoch 15 - V Loss: 3.514843702316284  - EXV: 0.19849538803100586
Epoch 15 - Training Loss: 7.238282203674316 - EXV: 1.0


Validation: 0it [00:00, ?it/s]

Epoch 16 - V Loss: 3.3787612915039062  - EXV: 0.2175716757774353
Epoch 16 - Training Loss: 6.9873199462890625 - EXV: 1.0


Validation: 0it [00:00, ?it/s]

Epoch 17 - V Loss: 3.2463455200195312  - EXV: 0.23628753423690796
Epoch 17 - Training Loss: 6.742518901824951 - EXV: 1.0


Validation: 0it [00:00, ?it/s]

Epoch 18 - V Loss: 3.1175215244293213  - EXV: 0.2546519637107849
Epoch 18 - Training Loss: 6.503756046295166 - EXV: 1.0


Validation: 0it [00:00, ?it/s]

`Trainer.fit` stopped: `max_epochs=20` reached.


Epoch 19 - V Loss: 2.992220640182495  - EXV: 0.27266818284988403
Epoch 19 - Training Loss: 6.270915985107422 - EXV: 1.0


In [70]:
# trainer.test()

In [68]:
# add noram dataset
# add test