## Import all necessary libraries

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import pytorch_lightning as pl
import wandb
import traceback
import torch.nn.functional as F
import numpy as np

from torch.utils.data import DataLoader
from pytorch_lightning.loggers import WandbLogger
from pytorch_lightning.callbacks import ModelCheckpoint, LearningRateMonitor


from models.data import LandsatDataModule
from models.nn import ResAttentionConvNetCBAM
from models.trainers import BasicTrainer
from models.trainers import LossFactory
from models.trainers import CombinedLoss
from models.trainers import FeatureAwareTrainer

## Set all configurations using dictionaries

In [None]:
# Experiment configuration. Each top-level section is handed to the matching
# component further down (data module, model, loss, optimizer, scheduler), and
# the whole dict is passed to wandb.init so the run records its settings.
config = dict(
    # Dataset locations and loading settings.
    data_module=dict(
        train_file='/teamspace/studios/this_studio/datasets/uint16_optimized_balanced_train_data.h5',
        # NOTE(review): this path uses 'dataset/' while train_file uses
        # 'datasets/' -- confirm this is intentional and not a typo.
        test_file='/teamspace/studios/this_studio/dataset/test_data.h5',
        batch_size=1024,
        dtype=np.uint16,
        num_workers=4,
        seed=50,
        split_ratio=(0.8, 0.2),
        # Augmentation is currently disabled; re-enable by uncommenting:
        # transform={
        #     'RandomHorizontalFlip': {'p': 0.5},
        #     'RandomVerticalFlip': {'p': 0.5},
        # },
    ),
    # Network hyper-parameters.
    model=dict(
        input_channels=6,
        initial_channels=16,
        embedding_size=128,
        depth=2,
        num_classes=1,
        reduction=16,
        dropout_rate=0.5,
    ),
    # Loss terms, each with its constructor params and a weight.
    # NOTE(review): center-loss feat_dim (128) presumably has to match the
    # model's embedding_size (128) -- confirm against CombinedLoss.
    loss_functions=dict(
        center=dict(
            params=dict(num_classes=2, feat_dim=128, lambda_c=0.03),
            weight=1.0,
        ),
        focal=dict(
            params=dict(alpha=0.25, gamma=2.0),
            weight=1.0,
        ),
    ),
    optimizer=dict(type='AdamW', lr=1e-3, weight_decay=1.0e-5),
    scheduler=dict(type='StepLR', step_size=15, gamma=0.1),
)

## Create the data module

In [None]:
# Build the LandsatDataModule from the 'data_module' section of the config
# (file paths, batch size, dtype, workers, seed and split ratio).
data_module_config = config['data_module']
data_module = LandsatDataModule.from_config(data_module_config)
print("DataModule created successfully")

## Create the model

In [None]:
# Select the 'model' section of the configuration
model_config = config['model']

# Create the model from the configuration
model = ResAttentionConvNetCBAM.from_config(model_config)
print("Model created successfully")

## Set the loss function

In [None]:
# Build the combined loss from the 'loss_functions' section of the config;
# that section lists a 'center' and a 'focal' term, each with params and a weight.
loss_config = config['loss_functions']
loss = CombinedLoss.from_config(loss_config)
print("Loss function created successfully")

## Create the trainer

In [None]:
# Wrap the model and loss in the Lightning module, passing along the
# optimizer and scheduler sections of the configuration.
optimizer_config, scheduler_config = config['optimizer'], config['scheduler']

trainer_module = FeatureAwareTrainer(
    model,
    loss,
    optimizer_config,
    scheduler_config,
)
print("Lightning module created successfully")

## Initialize wandb

In [None]:
# Start the W&B run. The run name combines the model's class name with the
# configured embedding size, and the full config dict is attached to the run.
run_name = f"{model.get_class_name()}_embed{config['model']['embedding_size']}"
wandb.init(
    project="INEGI",
    entity="geo-dl",
    name=run_name,
    config=config,
)

# Lightning-side logger pointed at the same project/entity.
# NOTE(review): presumably this attaches to the run opened by wandb.init above
# rather than starting a second one -- confirm against the WandbLogger docs.
wandb_logger = WandbLogger(
    project="INEGI",
    entity="geo-dl",
)

## Create the PyTorch Lightning trainer and callbacks

In [None]:
# Keep the three best checkpoints ranked by the lowest 'val_loss'.
# NOTE(review): assumes the Lightning module logs a metric named 'val_loss';
# confirm in FeatureAwareTrainer, otherwise the monitor has nothing to track.
checkpoint_callback = ModelCheckpoint(
    dirpath='checkpoints',
    filename='inegi-{epoch:02d}-{val_loss:.2f}',
    save_top_k=3,
    monitor='val_loss',
    mode='min',
)

# Log the learning rate once per epoch so scheduler steps are visible.
lr_monitor = LearningRateMonitor(logging_interval='epoch')

# Use the GPU when CUDA is available, otherwise fall back to the CPU.
use_gpu = torch.cuda.is_available()

# Setup trainer
trainer = pl.Trainer(
    max_epochs=20,
    logger=wandb_logger,
    log_every_n_steps=5,
    callbacks=[checkpoint_callback, lr_monitor],
    accumulate_grad_batches=1,
    # Always request exactly one device. The previous `None` on the CPU path
    # is not accepted by Lightning 2.x's CPU accelerator, which requires a
    # positive int; 1 is valid for both the 'gpu' and 'cpu' accelerators.
    devices=1,
    accelerator='gpu' if use_gpu else 'cpu',
)
print("Trainer created successfully")

## Train the model

In [None]:
# Train the model. Failures are reported (message + full traceback) instead of
# aborting the notebook, and the W&B run is always closed afterwards.
# The exit code starts as "failed" and is only cleared once fit() returns, so
# a run that raised is recorded as failed rather than as a clean finish.
training_exit_code = 1
try:
    trainer.fit(trainer_module, data_module)
    training_exit_code = 0
except Exception as e:
    print(f"An error occurred during training: {e}")
    # Print the full traceback, including the line number
    traceback.print_exc()
finally:
    # Close the wandb run, passing the outcome so W&B shows the right status
    wandb.finish(exit_code=training_exit_code)

## Shut down everything

In [None]:
# Close the wandb run. The training cell already calls wandb.finish() in its
# `finally` block, so this cell matters mainly when training was skipped or
# interrupted. NOTE(review): presumably a repeated finish() is harmless --
# confirm against the wandb docs.
wandb.finish()