In [None]:
import torch
from pytorch_lightning import LightningModule
import torchmetrics
import torchvision

In [22]:
class resnet_18(torch.nn.Module):
    """ImageNet-pretrained ResNet-18 feature extractor for 2-channel inputs.

    A trainable 3x3 convolution first maps the 2 input channels to the 3
    channels ResNet-18 expects. The backbone's final child module (the
    fully-connected classification layer in torchvision's ResNet) is dropped,
    so ``forward`` returns the pooled features flattened to ``(batch, -1)``.

    Args:
        freeze: when True (default), every backbone parameter gets
            ``requires_grad=False``. The ``in_conv`` adapter stays trainable
            either way. NOTE(review): this only stops gradients; BatchNorm
            running statistics inside the backbone still update in train mode.
    """

    def __init__(self, freeze=True):
        super().__init__()

        # Adapter: 2 input channels -> 3 (kernel 3, padding 1 keeps H and W).
        self.in_conv = torch.nn.Conv2d(2, 3, 3, padding=1)

        # Download the pretrained network and keep everything but its last child.
        backbone = torchvision.models.resnet18(pretrained=True)
        kept_layers = list(backbone.children())[:-1]
        self.resnet = torch.nn.Sequential(*kept_layers)

        if freeze:
            self.freeze()

    def _set_backbone_trainable(self, trainable):
        """Set ``requires_grad`` to *trainable* on every backbone parameter."""
        for weight in self.resnet.parameters():
            weight.requires_grad = trainable

    def freeze(self):
        """Freeze the pretrained backbone (``in_conv`` is not affected)."""
        self._set_backbone_trainable(False)

    def unfreeze(self):
        """Make the pretrained backbone trainable again."""
        self._set_backbone_trainable(True)

    def forward(self, x):
        """Return backbone features for *x*, flattened to ``(batch, -1)``.

        Expects an NCHW tensor with 2 channels, matching ``in_conv``'s
        input width.
        """
        features = self.resnet(self.in_conv(x))
        return features.view(features.size(0), -1)

class resnet_mod(LightningModule):
    """Lightning classifier over ResNet-18 DEM features plus coordinates.

    ``resnet_18`` turns the 2-channel DEM input into a 512-dim feature
    vector. ``coordenadas`` is added element-wise to it, and an MLP head
    maps the result to ``num_classes`` logits. Training uses
    ``binary_cross_entropy_with_logits`` against the target tensor.

    Args:
        num_classes: number of output logits.
        lr: Adam learning rate (defaults to the previously hard-coded 1e-4).
    """

    def __init__(self, num_classes, lr=1e-4):
        super().__init__()

        self.num_classes = num_classes
        self.lr = lr

        # Pretrained backbone (frozen by default) producing 512 features.
        self.features_dem = resnet_18()

        # MLP head; its 512 input width must match resnet_18's feature size.
        self.classifier = torch.nn.Sequential(
            torch.nn.Linear(512, 1024),
            torch.nn.ReLU(),
            torch.nn.Dropout(0.25),
            torch.nn.Linear(1024, 512),
            torch.nn.ReLU(),
            torch.nn.Linear(512, 256),
            torch.nn.ReLU(),
            torch.nn.Linear(256, 84),
            torch.nn.ReLU(),
            torch.nn.Linear(84, num_classes),
        )

    def forward(self, dem, coordenadas):
        """Return logits of shape ``(batch, num_classes)``."""
        x_DEM = self.features_dem(dem)
        # NOTE(review): coordinates are *added* to the 512-dim features, so
        # `coordenadas` must broadcast against (batch, 512), e.g. (batch, 1)
        # or (batch, 512). Concatenation may have been intended; confirm.
        x = x_DEM + coordenadas
        return self.classifier(x)

    def _shared_step(self, batch):
        """Return ``(loss, accuracy)`` for a ``(dem, coordenadas, target)`` batch.

        Accuracy compares the arg-max logit with the arg-max target entry.
        For one-hot targets this equals the true class index. It is computed
        on the logits' own device rather than a hard-coded "cuda:0", and
        without building a new metric object on every step.
        """
        _dem, _coordenadas, _target = batch

        logits = self(_dem, _coordenadas)
        loss = torch.nn.functional.binary_cross_entropy_with_logits(logits, _target)

        # argmax (instead of torch.where(target == 1)) always yields exactly
        # one index per row, so shapes stay aligned with the predictions.
        target_idx = _target.argmax(dim=1)
        acc = (logits.argmax(dim=1) == target_idx).float().mean()
        return loss, acc

    def training_step(self, batch, batch_idx):
        """Log training loss/accuracy and return the loss for backprop."""
        loss, acc = self._shared_step(batch)

        self.log("train_loss", loss)
        self.log("train_acc", acc, prog_bar=True)

        return loss

    def validation_step(self, batch, batch_idx):
        """Log validation loss/accuracy."""
        loss, acc = self._shared_step(batch)

        self.log("valid_loss", loss)
        self.log("valid_acc", acc, prog_bar=True)

    def configure_optimizers(self):
        """Adam over all parameters (frozen ones get no gradients)."""
        return torch.optim.Adam(self.parameters(), lr=self.lr)
