# In [None]:
import torch
from pytorch_lightning import LightningModule
import torchmetrics

# In [None]:
class Deep_Miner(LightningModule):
    """CNN + MLP classifier over a DEM raster combined with a coordinate term.

    ``features_dem`` reduces a ``(in_channels, 300, 300)`` input to
    ``3 * 14 * 14`` features (see the per-layer shape comments), and
    ``classifier`` maps those features to ``num_classes`` logits.

    Args:
        num_classes: Number of output logits.
        in_channels: Channels of the ``dem`` input. Defaults to 2 (the value
            previously hard-coded).
        lr: Adam learning rate. Defaults to 1e-4 (the value previously
            hard-coded).
    """

    def __init__(self, num_classes, in_channels=2, lr=1e-4):
        super().__init__()

        self.num_classes = num_classes
        self.lr = lr

        self.features_dem = torch.nn.Sequential(
            torch.nn.Conv2d(in_channels, 6, kernel_size=5),  #in = 2,300,300 ==> out = 6,296,296
            torch.nn.ReLU(),
            torch.nn.MaxPool2d(kernel_size=2),               #out = 6,148,148

            torch.nn.Conv2d(6, 3, kernel_size=5),            #out = 3,144,144
            torch.nn.ReLU(),
            torch.nn.Dropout2d(0.25),
            torch.nn.MaxPool2d(kernel_size=2),               #out = 3,72,72

            torch.nn.Conv2d(3, 3, kernel_size=3),            #out = 3,70,70
            torch.nn.BatchNorm2d(3),
            torch.nn.ReLU(),
            torch.nn.MaxPool2d(kernel_size=5),               #out = 3,14,14
        )

        # Input width must equal the flattened output of ``features_dem``
        # (3 x 14 x 14 for a 300 x 300 input).
        self.classifier = torch.nn.Sequential(
            torch.nn.Linear((3*14*14), 1024),
            torch.nn.ReLU(),
            torch.nn.Dropout(0.25),
            torch.nn.Linear(1024, 512),
            torch.nn.ReLU(),
            torch.nn.Linear(512, 256),
            torch.nn.ReLU(),
            torch.nn.Linear(256, 84),
            torch.nn.ReLU(),
            torch.nn.Linear(84, num_classes),
        )

    def forward(self, dem, coordenadas):
        """Return class logits for a batch.

        Args:
            dem: Raster batch fed to ``features_dem``; presumably
                ``(batch, in_channels, 300, 300)`` per the layer comments.
            coordenadas: Tensor added element-wise to the flattened conv
                features.

        Returns:
            The ``classifier`` output; presumably ``(batch, num_classes)``.
        """
        x_DEM = self.features_dem(dem)
        x_DEM = torch.flatten(x_DEM, 1)

        # NOTE(review): the coordinates are *added* to the 588 flattened
        # features, so ``coordenadas`` must broadcast to that shape. If they
        # are meant as extra inputs (e.g. ``(batch, 2)``), this should be a
        # ``torch.cat(..., dim=1)`` with the first Linear widened to match —
        # confirm intent.
        x = x_DEM + coordenadas

        logits = self.classifier(x)
        return logits

    @staticmethod
    def _accuracy(logits, target):
        """Fraction of rows whose arg-max logit equals the target's arg-max.

        For one-hot ``target`` rows this is standard top-1 multiclass
        accuracy. The result is computed on the inputs' own device, so no
        device is hard-coded.
        """
        return (logits.argmax(dim=1) == target.argmax(dim=1)).float().mean()

    def _shared_step(self, batch):
        """Run the model on a ``(dem, coordenadas, target)`` batch.

        Returns:
            ``(loss, acc)``: the BCE-with-logits loss and top-1 accuracy.
        """
        _dem, _coordenadas, _target = batch

        logits = self(_dem, _coordenadas)

        # BCE-with-logits scores each class independently and expects a float
        # target of the same shape as ``logits``.
        # NOTE(review): if classes are mutually exclusive (as the arg-max
        # accuracy implies), ``cross_entropy`` is the usual loss — confirm.
        loss = torch.nn.functional.binary_cross_entropy_with_logits(logits, _target)
        acc = self._accuracy(logits, _target)
        return loss, acc

    def training_step(self, batch, batch_idx):
        """Compute and log training loss/accuracy; return the loss to optimize."""
        loss, acc = self._shared_step(batch)

        self.log("train_loss", loss)
        self.log("train_acc", acc, prog_bar=True)

        return loss

    def validation_step(self, batch, batch_idx):
        """Compute and log validation loss/accuracy."""
        loss, acc = self._shared_step(batch)

        self.log("valid_loss", loss)
        self.log("valid_acc", acc, prog_bar=True)

    def configure_optimizers(self):
        """Adam over all parameters with the learning rate given at construction."""
        return torch.optim.Adam(self.parameters(), lr=self.lr)