## Import all necessary libraries

In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
import pytorch_lightning as pl
import wandb
import traceback
import torch.nn.functional as F
import numpy as np

from torch.utils.data import DataLoader
from pytorch_lightning.loggers import WandbLogger
from pytorch_lightning.callbacks import ModelCheckpoint, LearningRateMonitor


from models.data import LandsatDataModule
from models.nn import ResAttentionConvNetCBAM
from models.trainers import BasicTrainer
from models.trainers import LossFactory


## Set all configurations using dictionaries

In [15]:
# Single source of truth for the experiment: every later cell reads its
# settings from one of the sections below. The same dictionary is also sent
# to wandb as the run configuration.
config = dict(
    model_name='ResAttentionConvNetCBAM',
    # Settings consumed by LandsatDataModule.from_config
    data_module=dict(
        train_file='/teamspace/studios/this_studio/datasets/uint16_optimized_balanced_train_data.h5',
        # NOTE(review): this path uses `dataset/` while train_file uses
        # `datasets/` — confirm the directory name is intentional.
        test_file='/teamspace/studios/this_studio/dataset/test_data.h5',
        batch_size=512,
        dtype=np.uint16,
        num_workers=4,
        seed=50,
        split_ratio=(0.8, 0.2),
        # Augmentations, keyed by transform name with their keyword arguments
        transform=dict(
            RandomHorizontalFlip=dict(p=0.5),
            RandomVerticalFlip=dict(p=0.5),
            RandomRotation=dict(degrees=90),
        ),
    ),
    # Settings consumed by ResAttentionConvNetCBAM.from_config
    model=dict(
        input_channels=6,
        embedding_size=256,
        num_classes=1,
        dropout_rate=0.5,
    ),
    # Settings consumed by LossFactory.from_config, keyed by loss name
    loss=dict(
        BCEWithLogitsLoss=dict(),
        FocalLoss=dict(alpha=0.8, gamma=2),
    ),
    # Optimizer and learning-rate scheduler settings passed to BasicTrainer
    optimizer=dict(type='AdamW', lr=1e-3, weight_decay=1.0e-5),
    scheduler=dict(type='StepLR', step_size=10, gamma=0.1),
)

## Create the data module

In [16]:
# Build the LandsatDataModule from the 'data_module' section of the
# configuration (file paths, batch size, dtype, workers, seed, split ratio
# and augmentation transforms).
data_module_config = config['data_module']
data_module = LandsatDataModule.from_config(data_module_config)
print("DataModule created successfully")

DataModule created successfully


## Create the model

In [17]:
# Seed the Python, NumPy and PyTorch random number generators before the
# model is built, so that weight initialisation does not vary between runs.
# The seed already configured for the data module is reused so the notebook
# keeps a single seed value; `workers=True` asks Lightning to derive
# per-worker seeds for the dataloader workers as well.
pl.seed_everything(config['data_module']['seed'], workers=True)

# Create the model from the 'model' section of the configuration
model_config = config['model']
model = ResAttentionConvNetCBAM.from_config(model_config)
print("Model created successfully")

Model created successfully


## Set the loss function

In [18]:
# Create the loss function from the 'loss' section of the configuration.
# NOTE(review): that section lists both BCEWithLogitsLoss and FocalLoss; how
# LossFactory combines multiple entries is not visible here — confirm.
loss_config = config['loss']
loss = LossFactory.from_config(loss_config)

## Create the trainer

In [19]:
# Create the Lightning module: BasicTrainer receives the network, the loss,
# and the raw 'optimizer' and 'scheduler' configuration sections.
# NOTE(review): presumably BasicTrainer builds the optimizer/scheduler from
# these dictionaries itself — confirm against models.trainers.
optimizer_config = config['optimizer']
scheduler_config = config['scheduler']

trainer_module = BasicTrainer(model, loss, optimizer_config, scheduler_config)
print("Lightning module created successfully")

Lightning module created successfully


## Initialize wandb

In [8]:
# Name the run after the architecture and its embedding size
run_name = f"{config['model_name']}_embed{config['model']['embedding_size']}"

# The eager wandb run and the Lightning logger are given identical settings,
# so they are defined once and shared.
wandb_run_kwargs = dict(project="INEGI", entity="geo-dl", name=run_name, config=config)

# Start the wandb run immediately
wandb.init(**wandb_run_kwargs)

# Logger that is handed to the Lightning trainer
wandb_logger = WandbLogger(**wandb_run_kwargs)

## Create the PyTorch Lightning trainer and callbacks

In [None]:
# Checkpointing: keep the three checkpoints with the lowest monitored
# 'val_loss', written to the local 'checkpoints' directory.
# NOTE(review): this relies on the Lightning module logging a metric named
# 'val_loss' — confirm in BasicTrainer.
checkpoint_callback = ModelCheckpoint(
    dirpath='checkpoints',
    filename='inegi-{epoch:02d}-{val_loss:.2f}',
    save_top_k=3,
    monitor='val_loss',
    mode='min'
)

# Log the learning rate once per epoch
lr_monitor = LearningRateMonitor(logging_interval='epoch')

# Setup trainer
trainer = pl.Trainer(
    max_epochs=30,
    logger=wandb_logger,
    log_every_n_steps=5,
    callbacks=[checkpoint_callback, lr_monitor],
    accumulate_grad_batches=1,
    # Always request exactly one device. The previous `None` fallback on the
    # CPU path is only accepted by Lightning 1.x; newer releases expect an
    # int, a list of indices, or "auto". One device matches the old behaviour
    # on both GPU (one GPU) and CPU (one process).
    devices=1,
    # Train on the GPU when CUDA is available, otherwise fall back to CPU
    accelerator='gpu' if torch.cuda.is_available() else 'cpu'
)

## Train the model

In [None]:
# Run the training loop: the trainer drives `trainer_module` using the data
# supplied by `data_module`, with the logger and callbacks configured on the
# trainer.
trainer.fit(trainer_module, data_module)

## Shut everything down

In [None]:
# Close the active wandb run so that it is marked as finished
wandb.finish()