In [None]:
from sklearn.metrics import r2_score, mean_absolute_error, mean_squared_error
import mlflow
import pytorch_lightning as pl
import torch.nn as nn
import torch.optim as optim
from torchmetrics import MetricCollection, MeanSquaredError, MeanAbsoluteError, R2Score

In [None]:
class UAV_vit(pl.LightningModule):
    """Lightning module that fine-tunes a backbone for single-value regression.

    The backbone's ``heads`` attribute is replaced by one ``nn.Linear`` layer
    with a single output. MSE, MAE and R2 are tracked separately for the
    train, validation and test stages. They are accumulated over every batch
    of an epoch and logged once at the end of that epoch.

    Args:
        backbone: Model whose ``heads[0]`` exposes ``in_features``
            (presumably a torchvision ViT — confirm). It is modified in
            place: its ``heads`` attribute is overwritten.
        loss_fn: Callable invoked as ``loss_fn(outputs, labels)``; its result
            is logged and returned as the training loss.
        labels: Stored as ``self.labels``; not used by this class itself.
        lr: Learning rate for Adam (default keeps the previous value).
        weight_decay: Weight decay for Adam (default keeps the previous value).
    """

    def __init__(self, backbone, loss_fn, labels, lr=1e-4, weight_decay=1e-5):
        super().__init__()

        self.loss_fn = loss_fn
        self.model = backbone

        self.labels = labels

        self.lr = lr
        self.weight_decay = weight_decay

        # Number of input features of the backbone's original first head layer.
        num_input_filters = backbone.heads[0].in_features
        num_output_values = 1

        # Replace the head of the model with a single-output regression layer.
        self.model.heads = nn.Linear(
            in_features=num_input_filters, out_features=num_output_values
        ).float()

        # NOTE(review): confirm the installed torchmetrics version accepts
        # `nan_to_num` for R2Score; some releases raise on unknown Metric kwargs.
        metric_collection = MetricCollection([
            MeanSquaredError(),
            MeanAbsoluteError(),
            R2Score(nan_to_num=True),
        ])
        # Independent copies so each stage accumulates its own state.
        self.train_metrics = metric_collection.clone(prefix="train_")
        self.val_metrics = metric_collection.clone(prefix="val_")
        self.test_metrics = metric_collection.clone(prefix="test_")

    def forward(self, x):
        """Run the backbone (with its replaced head) on *x*."""
        return self.model(x)

    def configure_optimizers(self):
        """Return an Adam optimizer over all module parameters."""
        optimizer = optim.Adam(
            self.parameters(), lr=self.lr, weight_decay=self.weight_decay
        )
        return optimizer

    # Training

    def training_step(self, batch, batch_idx):
        """Compute the loss for one batch and accumulate training metrics."""
        outputs, labels, loss = self.get_batch_data(batch)
        # Only accumulate metric state here. Logging the per-batch forward()
        # result would make Lightning average batch values, which is wrong for
        # non-decomposable metrics such as R2 (and R2 is undefined for a
        # single-sample batch). Epoch values are computed in on_train_epoch_end.
        self.train_metrics.update(outputs, labels)
        self.log("train_loss", loss, on_epoch=True, on_step=False)
        return loss

    def on_train_epoch_end(self):
        """Log training metrics computed over the whole epoch, then reset them."""
        self._log_epoch_metrics(self.train_metrics)

    # Validation

    def validation_step(self, batch, batch_idx):
        """Compute the loss for one batch and accumulate validation metrics."""
        outputs, labels, loss = self.get_batch_data(batch)
        self.val_metrics.update(outputs, labels)
        self.log("val_loss", loss, on_epoch=True, on_step=False)

    def on_validation_epoch_end(self):
        """Log validation metrics computed over the whole epoch, then reset them."""
        self._log_epoch_metrics(self.val_metrics)

    # Testing

    def test_step(self, batch, batch_idx):
        """Compute the loss for one batch and accumulate test metrics."""
        outputs, labels, loss = self.get_batch_data(batch)
        self.test_metrics.update(outputs, labels)
        self.log("test_loss", loss, on_epoch=True, on_step=False)

    def on_test_epoch_end(self):
        """Log test metrics computed over the whole epoch, then reset them."""
        self._log_epoch_metrics(self.test_metrics)

    # Helper functions

    def _log_epoch_metrics(self, metrics):
        """Log ``metrics.compute()`` as epoch-level values and clear the state."""
        self.log_dict(metrics.compute(), on_epoch=True, on_step=False)
        # Reset only after computing, so the next epoch starts from empty state.
        metrics.reset()

    def get_batch_data(self, batch):
        """Run the model on one batch and compute the loss.

        Args:
            batch: An ``(images, labels)`` pair.

        Returns:
            ``(outputs, labels, loss)``. ``labels`` gains a trailing dimension
            of size 1 so it lines up with the single-output head. This
            presumes the dataloader yields 1-D labels (confirm).
        """
        images, labels = batch
        labels = labels.unsqueeze(-1)
        outputs = self(images)
        loss = self.loss_fn(outputs, labels)
        return outputs, labels, loss