### Baseline Training

Here we train a baseline: an ImageNet-pretrained ResNet-18 with a frozen backbone and a new classification head, trained on the limited dataset without data augmentation.

The results will help plan the later stages of model development.


### Setup: imports, paths and configuration

In [None]:
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import xarray as xr
import matplotlib.pyplot as plt
import torch.optim as optim
import math
import os
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from pathlib import Path
import json
from datetime import datetime
from tqdm import tqdm
import numpy as np
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score, confusion_matrix, roc_auc_score
import os
from dotenv import load_dotenv
from torchvision.models import resnet18


from src import utils
from src.data.datasets import CloudHoleDataset
from src.models.models import CNN_Model
from src.training.train import (
    train_model,
    test
)

# Read the project locations from the local .env file into the environment.
load_dotenv()


# Reversed grey-scale colormap.
cmap = plt.get_cmap('Greys_r')

# NOTE: os.getenv returns None for an unset variable, which would silently
# produce paths such as "None/labels_revised.csv" -- make sure the .env file
# defines processed_data_path, gold_data_path and model_save_path.
labels_path = f"{os.getenv('processed_data_path')}/labels_revised.csv"
# No trailing slash here: every consumer appends "/<version>/<split>/..." itself,
# and the former trailing slash produced ".../seviri//v1.0_baseline/..." paths.
gold_data_path = f"{os.getenv('gold_data_path')}/seviri"
model_save_path = f"{os.getenv('model_save_path')}/pretrained_baseline"

In [None]:
# --- Reproducibility -------------------------------------------------------
# Training is stochastic (shuffled batches, randomly initialised classifier
# head), so seed the global RNGs once, before any model or loader is built.
SEED = 42
torch.manual_seed(SEED)
np.random.seed(SEED)

# --- Hardware --------------------------------------------------------------
# Use the GPU when one is visible, otherwise fall back to the CPU.
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# --- Run configuration -----------------------------------------------------
LOGS_PATH = 'logs/v1.0_baseline'   # where the test metrics JSON is written
BATCH_SIZE = 32                    # training batch size (evaluation uses 2x)
LEARNING_RATE = 0.001              # Adam learning rate
EPOCHS = 50                        # `epochs` argument passed to train_model
NUM_WORKERS = 4                    # DataLoader worker processes
PATIENCE = 10                      # `patience` argument passed to train_model
MIN_DELTA = 0.001                  # `min_delta` argument passed to train_model

In [12]:
# Every split is opened with the same settings: augmentation switched off and
# the pretrained-ResNet-18 options enabled. Only the zarr store differs.
# NOTE(review): the recorded output prints normalization statistics for the
# train split only -- confirm how validation/test are normalized.
shared_dataset_kwargs = dict(augment=False, pretrained=True, model='resnet18')

train_dataset = CloudHoleDataset(
    gold_zarr_path=f"{gold_data_path}/v1.0_baseline/train/data.zarr",
    **shared_dataset_kwargs,
)
validation_dataset = CloudHoleDataset(
    gold_zarr_path=f"{gold_data_path}/v1.0_baseline/validation/data.zarr",
    **shared_dataset_kwargs,
)
test_dataset = CloudHoleDataset(
    gold_zarr_path=f"{gold_data_path}/v1.0_baseline/test/data.zarr",
    **shared_dataset_kwargs,
)

✓ Loaded GoldCloudHoleDataset
  Path: /home/plato/dl_cloudhole/dl_cloudhole/data/processed/gold/seviri//v1.0_baseline/train/data.zarr
  Samples: 165
  Augmentation: OFF
  Normalization: mean=13.0594, std=7.281753233658452
✓ Loaded GoldCloudHoleDataset
  Path: /home/plato/dl_cloudhole/dl_cloudhole/data/processed/gold/seviri//v1.0_baseline/validation/data.zarr
  Samples: 131
  Augmentation: OFF
✓ Loaded GoldCloudHoleDataset
  Path: /home/plato/dl_cloudhole/dl_cloudhole/data/processed/gold/seviri//v1.0_baseline/test/data.zarr
  Samples: 98
  Augmentation: OFF


In [None]:

# Section banner for the run log
banner_rule = '='*80
print('\n' + banner_rule)
print('CLOUD HOLE DETECTION - TRAINING')
print(banner_rule + '\n')

# Make sure the checkpoint and log directories exist (no error if they already do)
for output_dir in (model_save_path, LOGS_PATH):
    Path(output_dir).mkdir(parents=True, exist_ok=True)


# Create DataLoaders
# Pinned (page-locked) host memory only speeds up host-to-GPU copies; on the
# CPU fallback it is useless and PyTorch warns about it, so enable it only
# when DEVICE is a CUDA device -- and do so consistently for all three loaders.
use_pinned_memory = DEVICE.type == 'cuda'

# Training batches are reshuffled every epoch.
train_loader = DataLoader(
    train_dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
    num_workers=NUM_WORKERS,
    pin_memory=use_pinned_memory
)

# Evaluation loaders keep a fixed order and use twice the training batch size.
val_loader = DataLoader(
    validation_dataset,
    batch_size=BATCH_SIZE * 2,
    shuffle=False,
    num_workers=NUM_WORKERS,
    pin_memory=use_pinned_memory
)

test_loader = DataLoader(
    test_dataset,
    batch_size=BATCH_SIZE * 2,
    shuffle=False,
    num_workers=NUM_WORKERS,
    pin_memory=use_pinned_memory
)


# Load ResNet-18 with ImageNet weights. `pretrained=True` is deprecated since
# torchvision 0.13; 'IMAGENET1K_V1' is the weight set it resolved to.
model = resnet18(weights='IMAGENET1K_V1')

# FREEZE ALL LAYERS
print('\nFreezing backbone layers...')
for param in model.parameters():
    param.requires_grad = False
# NOTE(review): requires_grad=False stops gradient updates only. If train_model
# switches the network to train() mode, the BatchNorm running statistics of the
# backbone will still be updated -- confirm whether that is intended.

# REPLACE THE CLASSIFICATION HEAD
# A freshly constructed nn.Linear has requires_grad=True, so swapping the head
# in *after* the freeze loop leaves it as the only trainable layer.
num_features = model.fc.in_features
model.fc = nn.Linear(num_features, 2)  # Binary classification

print('✓ Backbone frozen')
print('✓ Classification head (fc layer) unfrozen and ready to train\n')


# VERIFY WHICH LAYERS ARE TRAINABLE
trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
total_params = sum(p.numel() for p in model.parameters())

print(f'Total parameters: {total_params:,}')
print(f'Trainable parameters: {trainable_params:,}')
print(f'Frozen parameters: {total_params - trainable_params:,}')
print(f'Trainable ratio: {trainable_params/total_params*100:.2f}%\n')

# Move the network to the target device before the optimizer is built
model = model.to(DEVICE)

# OPTIMIZER - ONLY UPDATE TRAINABLE PARAMETERS
# Adam receives just the parameters that still require gradients
# (the replaced fc head); the frozen backbone is never handed to it.
head_parameters = [p for p in model.parameters() if p.requires_grad]
optimizer = optim.Adam(head_parameters, lr=LEARNING_RATE)

# LOSS - cross-entropy over the two output logits
criterion = nn.CrossEntropyLoss()


CLOUD HOLE DETECTION - TRAINING






Freezing backbone layers...
✓ Backbone frozen
✓ Classification head (fc layer) unfrozen and ready to train

Total parameters: 11,177,538
Trainable parameters: 1,026
Frozen parameters: 11,176,512
Trainable ratio: 0.01%



In [None]:
# Early-stopping settings, grouped so they are easy to tune together.
early_stopping = {'patience': PATIENCE, 'min_delta': MIN_DELTA}

# Run the project's training routine on the frozen-backbone model.
# NOTE(review): presumably train_model writes `best_model.pth` into `save_dir`
# (the evaluation cell loads it from there) and returns per-epoch metrics --
# confirm against src.training.train.
history = train_model(
    model=model,
    criterion=criterion,
    optimizer=optimizer,
    train_loader=train_loader,
    val_loader=val_loader,
    device=DEVICE,
    epochs=EPOCHS,
    save_dir=model_save_path,
    **early_stopping,
)


In [None]:
print('\nLoading best model for testing...')

# map_location remaps the stored tensors onto the current device, so a
# checkpoint written on a GPU machine also loads on a CPU-only one.
# SECURITY: weights_only=False unpickles arbitrary Python objects. That is only
# acceptable for a checkpoint produced locally (presumably by train_model);
# never load an untrusted file this way.
checkpoint = torch.load(
    Path(model_save_path) / 'best_model.pth',
    map_location=DEVICE,
    weights_only=False
)
model.load_state_dict(checkpoint['model_state_dict'])
# The stored epoch is presumably zero-based, hence the +1 for display.
print(f'✓ Loaded best model from epoch {checkpoint["epoch"]+1}')

# Evaluate on the held-out test split; the metrics are written to save_path.
test_metrics = test(
    model=model,
    test_loader=test_loader,
    criterion=criterion,
    device=DEVICE,
    save_path=Path(LOGS_PATH) / 'test_results.json'
)

print('\n' + '='*80)
print('TRAINING COMPLETE')
print('='*80)
print(f'Best model: {model_save_path}/best_model.pth')
print(f'Test results: {LOGS_PATH}/test_results.json')
print('='*80 + '\n')




Loading best model for testing...
✓ Loaded best model from epoch 20

TESTING MODEL


Testing: 100%|██████████| 2/2 [00:21<00:00, 10.64s/it, loss=0.4553]



TEST RESULTS
Loss:        0.4553
Accuracy:    0.7551
Precision:   0.8246
Recall:      0.7705
Specificity: 0.7297
F1 Score:    0.7966
ROC-AUC:     0.8653

Confusion Matrix:
                Predicted
                 0      1
Actual  0       27    10
        1       14    47

✓ Test results saved to: logs/v1.0_baseline/test_results.json

TRAINING COMPLETE
Best model: /home/plato/dl_cloudhole/dl_cloudhole/models/pretrained_baseline/best_model.pth
Test results: logs/v1.0_baseline/test_results.json

