In [None]:
from torch.optim import Adam, AdamW
import torch
from torchmetrics import MetricCollection, MeanAbsoluteError, MeanSquaredError, ExplainedVariance
import numpy as np
import pytorch_lightning as pl
import torch.nn as nn
from torch import Tensor
import math

In [None]:
def log_cosh_loss(y_pred: torch.Tensor, y_true: torch.Tensor) -> torch.Tensor:
    """Return the mean log-cosh error between ``y_pred`` and ``y_true``.

    ``log(cosh(d))`` is evaluated via the identity
    ``d + softplus(-2 * d) - log(2)``, which never materialises ``cosh(d)``
    and therefore stays finite for large ``|d|``.
    """
    diff = y_pred - y_true
    per_element = diff + nn.functional.softplus(-2. * diff) - math.log(2.0)
    return per_element.mean()

class LogCoshLoss(nn.Module):
    """``nn.Module`` criterion wrapping :func:`log_cosh_loss` (mean reduction)."""

    def forward(self, y_pred: Tensor, y_true: Tensor) -> Tensor:
        """Return the mean log-cosh error of ``y_pred`` against ``y_true``."""
        loss = log_cosh_loss(y_pred, y_true)
        return loss

In [None]:
class UAV_vit(pl.LightningModule):
    """Fine-tune a vision-transformer backbone for single-value regression.

    The backbone's ``heads[0]`` layer is replaced with a one-output
    ``nn.Linear`` and training minimises :class:`LogCoshLoss`. The backbone
    must expose ``heads`` (indexable, with ``heads[0].in_features``) and
    ``encoder`` attributes -- presumably a torchvision ``VisionTransformer``;
    confirm against the caller.

    Args:
        backbone: Pretrained model to fine-tune (mutated in place).
        learning_rate: AdamW learning rate.
        loss_threshold: Stored for the commented-out Huber / SmoothL1
            criteria; unused by the current log-cosh criterion.
        weight_decay: AdamW weight decay.
        batch_size: Stored as ``self.batch_size``; not read by this class.
        no_grad_layers_n: How many leading ``backbone.encoder`` parameter
            tensors to freeze (see the NOTE in ``__init__``).
        dropout: ``p`` assigned to every ``nn.Dropout`` in the module
            (overrides the backbone's own values).
        attention_dropout: ``dropout`` assigned to every
            ``nn.MultiheadAttention`` in the module.

    During testing, ``test_output`` collects per-sample predictions,
    ``test_loss`` per-batch losses and ``test_targets_mean`` per-batch target
    means. These lists are never cleared, so repeated ``trainer.test`` calls
    keep appending to them.
    """

    def __init__(
        self,
        backbone,
        learning_rate=1e-6,
        loss_threshold=0.5,
        weight_decay=1e-1,
        batch_size: int = 16,
        no_grad_layers_n: int = 6,
        dropout: float = 0.0,
        attention_dropout: float = 0.0,
    ):
        super().__init__()

        self.backbone = backbone
        self.learning_rate = learning_rate
        self.loss_threshold = loss_threshold
        self.weight_decay = weight_decay
        # Alternative criteria (both would use loss_threshold):
        # self.criterion = nn.HuberLoss(delta=loss_threshold)
        # self.criterion = nn.SmoothL1Loss(beta=self.loss_threshold)
        self.criterion = LogCoshLoss()
        self.no_grad_layers_n = int(no_grad_layers_n)
        self.batch_size = batch_size
        self.dropout = dropout
        self.attention_dropout = attention_dropout

        self.test_output = []
        self.test_loss = []
        self.test_targets_mean = []

        # Override dropout probabilities in every submodule.
        self.apply(self.set_dropouts)

        # Swap the classification head for a single-output regression head.
        num_input_filters = backbone.heads[0].in_features
        num_output_values = 1
        self.backbone.heads[0] = nn.Linear(
            in_features=num_input_filters, out_features=num_output_values
        ).float()

        metric_collection = MetricCollection([
            MeanSquaredError(),
            MeanAbsoluteError(),
            ExplainedVariance(),
        ])
        self.val_metrics = metric_collection.clone(prefix="val_")
        self.test_metrics = metric_collection.clone(prefix="test_")

        # NOTE(review): despite its name, no_grad_layers_n counts *parameter
        # tensors* of backbone.encoder (in registration order), not whole
        # encoder blocks. Confirm this is the intended freezing granularity.
        if self.no_grad_layers_n > 0:
            for i, param in enumerate(self.backbone.encoder.parameters()):
                if i >= self.no_grad_layers_n:
                    break
                param.requires_grad = False

        # Built only after the head swap so the new regression head's
        # parameters are part of the optimizer.
        self.optimizer = self._build_optimizer()

    def set_dropouts(self, m):
        """Apply the configured dropout rates to module ``m`` (used with ``apply``)."""
        if isinstance(m, nn.Dropout):
            m.p = self.dropout
        elif isinstance(m, nn.MultiheadAttention):
            m.dropout = self.attention_dropout

    def forward(self, x):
        """Return the backbone's output for input ``x`` (one value per sample)."""
        return self.backbone(x)

    def _build_optimizer(self):
        """Create an AdamW optimizer over the currently trainable parameters."""
        trainable = [p for p in self.parameters() if p.requires_grad]
        return AdamW(trainable, lr=self.learning_rate, weight_decay=self.weight_decay)

    def configure_optimizers(self):
        """Return an AdamW optimizer built from the current hyperparameters.

        Rebuilt on every call so that changes to ``learning_rate``,
        ``weight_decay`` or ``requires_grad`` flags made after construction
        are respected; the result is also stored in ``self.optimizer``.
        """
        self.optimizer = self._build_optimizer()
        return self.optimizer

    def get_batch_data(self, batch):
        """Run the model on an ``(images, labels)`` batch.

        Returns:
            Tuple ``(outputs, labels, loss)``; ``labels`` gains a trailing
            dimension (``unsqueeze(1)``) so it lines up with the single-output
            head.
        """
        images, labels = batch
        labels = labels.unsqueeze(1)
        outputs = self(images)
        loss = self.criterion(outputs, labels)
        return outputs, labels, loss

    # Training

    def training_step(self, batch, batch_idx):
        """Compute and log the epoch-averaged training loss."""
        _, _, loss = self.get_batch_data(batch)
        self.log("train_loss", loss, on_epoch=True, on_step=False)
        return loss

    # Validation

    def validation_step(self, batch, batch_idx):
        """Accumulate validation metric state and log the validation loss."""
        outputs, labels, loss = self.get_batch_data(batch)
        # Only accumulate here; metrics are computed over the whole epoch in
        # on_validation_epoch_end (averaging per-batch ExplainedVariance would
        # not give the epoch-level value).
        self.val_metrics.update(outputs, labels)
        self.log("val_loss", loss, on_epoch=True, on_step=False)

    def on_validation_epoch_end(self):
        """Log epoch-level validation metrics, then clear their state."""
        self.log_dict(self.val_metrics.compute())
        self.val_metrics.reset()

    # Testing

    def test_step(self, batch, batch_idx):
        """Accumulate test metrics and record raw predictions and losses."""
        outputs, labels, loss = self.get_batch_data(batch)
        self.test_metrics.update(outputs, labels)

        # reshape(-1) rather than squeeze(): a one-sample batch must still
        # give a 1-D array, otherwise list.extend() fails on a 0-d array.
        outputs_np = outputs.detach().reshape(-1).cpu().numpy()
        labels_np = labels.detach().reshape(-1).cpu().numpy()

        self.test_output.extend(outputs_np)
        self.test_loss.append(loss.item())
        self.test_targets_mean.append(np.mean(labels_np))

        self.log("test_loss", loss, on_epoch=True, on_step=False)

    def on_test_epoch_end(self):
        """Log epoch-level test metrics, then clear their state."""
        self.log_dict(self.test_metrics.compute())
        self.test_metrics.reset()

    # Prediction

    def predict_step(self, batch, batch_idx, dataloader_idx=None):
        """Return model outputs for a batch of images.

        Accepts either a bare image tensor or an ``(images, labels)``
        tuple/list like the train/val/test loaders yield.
        """
        images = batch[0] if isinstance(batch, (tuple, list)) else batch
        return self(images)

