## Import al neccesary libries

In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
import pytorch_lightning as pl
import wandb
import traceback
import torch.nn.functional as F
import numpy as np

from torch.utils.data import DataLoader
from pytorch_lightning.loggers import WandbLogger
from pytorch_lightning.callbacks import ModelCheckpoint, LearningRateMonitor


from models.data import LandsatDataModule
from models.nn import ResAttnConvNet
from models.trainers import BasicTrainer
from models.trainers import LossFactory


## Set all configurations using dictionaries

In [2]:
config = {
    'data_module': {
        'train_file': '/teamspace/studios/this_studio/datasets/uint16_optimized_balanced_train_data.h5',
        'test_file': '/teamspace/studios/this_studio/dataset/test_data.h5',
        'batch_size': 1024,
        'dtype': np.uint16,
        'num_workers': 4,
        'seed': 50,
        'split_ratio': (0.8, 0.2),
        'transform': {
            'RandomHorizontalFlip': {'p': 0.5},
            'RandomVerticalFlip': {'p': 0.5},
            'RandomRotation': {'degrees': 90},
        }
    },
    'model': {
            'input_channels': 6,
            'initial_channels': 32,
            'embedding_size': 512,
            'depth': 3,
            'num_classes': 1,
            'reduction': 16,
            'dropout_rate': 0.3,
    },
    'loss_functions': {
        'BCEWithLogitsLoss': {'reduction': 'mean'},
        'FocalLoss': {'alpha': 0.25, 'gamma': 2}
    },
    'optimizer': {
        'type': 'AdamW',
        'lr': 1e-3,
        'weight_decay': 1.0e-5
    },
    'scheduler': {
        'type': 'StepLR',
        'step_size': 10,
        'gamma': 0.1
    }
}

## Create the data module

In [3]:
# Create the HDF5DataModule from the configuration
data_module_config = config['data_module']
data_module = LandsatDataModule.from_config(data_module_config)
print("DataModule created successfully")

DataModule created successfully


## Create the model

In [4]:
# Create the model from the configuration
model_config = config['model']
model = ResAttnConvNet.from_config(model_config)
print("Model created successfully")

Model created successfully


## Set the loss function

In [5]:
# Create the loss function from the configuration
loss_config = config['loss_functions']
loss = LossFactory.from_config(loss_config)
print("Loss function created successfully")

Loss function created successfully


## Create the trainer

In [6]:
# Create Lightning module
optimizer_config = config['optimizer']
scheduler_config = config['scheduler']

trainer_module = BasicTrainer(model, loss, optimizer_config, scheduler_config)
print("Lightning module created successfully")

Lightning module created successfully


## Initialize wandb

In [7]:
# Initialize wandb
run_name = f"{model.get_class_name()}_embed{config['model']['embedding_size']}"
wandb.init(project="INEGI", entity="geo-dl", config=config, name=run_name)

# Setup wandb logger
wandb_logger = WandbLogger(project="INEGI", entity="geo-dl")

[34m[1mwandb[0m: Using wandb-core as the SDK backend. Please refer to https://wandb.me/wandb-core for more information.
[34m[1mwandb[0m: Currently logged in as: [33mhucarlos[0m ([33mgeo-dl[0m). Use [1m`wandb login --relogin`[0m to force relogin


VBox(children=(Label(value='Waiting for wandb.init()...\r'), FloatProgress(value=0.011112328666664931, max=1.0…

## Create pytorch lightning trainers and callbacks

In [8]:
# Setup model checkpoint callback
checkpoint_callback = ModelCheckpoint(
    dirpath='checkpoints',
    filename='inegi-{epoch:02d}-{val_loss:.2f}',
    save_top_k=3,
    monitor='val_loss',
    mode='min'
)

# Learning rate monitor
lr_monitor = LearningRateMonitor(logging_interval='epoch')

# Setup trainer
trainer = pl.Trainer(
    max_epochs=30,
    logger=wandb_logger,
    log_every_n_steps=5,
    callbacks=[checkpoint_callback, lr_monitor],
    accumulate_grad_batches=1,
    devices=1 if torch.cuda.is_available() else None,
    accelerator='gpu' if torch.cuda.is_available() else 'cpu'
)
print("Trainer created successfully")

GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs


Trainer created successfully


## Train the model

In [9]:
# Train model
try:
    trainer.fit(trainer_module, data_module)
except Exception as e:
    print(f"An error occurred during training: {e}")
    # Print the full traceback, including the line number
    traceback.print_exc()
finally:
    # Close wandb run
    wandb.finish()

/home/zeus/miniconda3/envs/cloudspace/lib/python3.10/site-packages/pytorch_lightning/loggers/wandb.py:396: There is a wandb run already in progress and newly created instances of `WandbLogger` will reuse this run. If this is not desired, call `wandb.finish()` before instantiating `WandbLogger`.
/home/zeus/miniconda3/envs/cloudspace/lib/python3.10/site-packages/pytorch_lightning/callbacks/model_checkpoint.py:654: Checkpoint directory /teamspace/studios/this_studio/checkpoints exists and is not empty.
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name      | Type            | Params | Mode 
------------------------------------------------------
0 | model     | ResAttnConvNet  | 5.1 M  | train
1 | accuracy  | BinaryAccuracy  | 0      | train
2 | precision | BinaryPrecision | 0      | train
3 | recall    | BinaryRecall    | 0      | train
4 | f1        | BinaryF1Score   | 0      | train
5 | aucroc    | BinaryAUROC     | 0      | train
------------------------------------------------------

Sanity Checking: |          | 0/? [00:00<?, ?it/s]

Training: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

## Shutdown all

In [None]:
# Close wandb run
wandb.finish()