In [4]:
from pathlib import Path

import matplotlib.pyplot as plt
import pandas as pd
import torch

from src.models.lightningmodel import LightningClassifierModelWrapper, modelling_choice, count_parameters
from src.preprocessing.data_loader import GeoEye1, compute_dataset_statistics, get_transforms, get_dataloaders

In [5]:
# Dataset directory, relative to the notebook's working directory.
root_dir = "ipeo_hurricane_for_students"

# Build the trainer and the wrapped classifier in one call. Judging by the cell
# output ("Seed set to 42") this helper also seeds the run.
trainer, lightning_model = modelling_choice(
    model_name="resnet18",
    max_epochs=40,
    pretrained=True,
)

# Normalisation statistics are not recomputed here; they are read from the cached
# tensors under src/preprocessing/. The original way to regenerate them was:
#   mean, std = compute_dataset_statistics(root_dir, split="train", batch_size=1000)

Seed set to 42
/home/nstaehel/.venv/lib64/python3.9/site-packages/lightning_fabric/plugins/environments/slurm.py:204: The `srun` command is available on your system but is not used. HINT: If your intention is to run Lightning on SLURM, prepend your python command with `srun` like so: srun python /home/nstaehel/.venv/lib64/python3.9/site-packages/i ...
GPU available: False, used: False
TPU available: False, using: 0 TPU cores


In [6]:
# Normalisation statistics (mean / std) of the training split, passed to
# get_dataloaders below. They are cached as files next to the preprocessing code
# because recomputing them needs a pass over the training set (presumably slow).
# If the cache is missing (e.g. on a fresh checkout), they are recomputed once and
# written back, so the notebook still survives Restart & Run All.
stats_dir = Path("src/preprocessing")
mean_path = stats_dir / "mean.pt"
std_path = stats_dir / "std.pt"

if mean_path.exists() and std_path.exists():
    mean = torch.load(mean_path)
    std = torch.load(std_path)
else:
    mean, std = compute_dataset_statistics(root_dir, split="train", batch_size=1000)
    stats_dir.mkdir(parents=True, exist_ok=True)
    torch.save(mean, mean_path)
    torch.save(std, std_path)

In [7]:
# Build the three dataset splits, normalised with the statistics loaded above.
train_loader, val_loader, test_loader = get_dataloaders(
    root_dir,
    mean=mean,
    std=std,
    batch_size=100,
)

# Train, validating on the validation split (test_loader is kept for later).
trainer.fit(
    lightning_model,
    train_dataloaders=train_loader,
    val_dataloaders=val_loader,
)

# Report model size once fit() has returned.
trainable_params, total_params = count_parameters(lightning_model)
print(f"Trainable parameters: {trainable_params}, Total parameters: {total_params}")

Loaded 19000 images for train split
Loaded 2000 images for validation split
Loaded 2000 images for test split



  | Name  | Type   | Params | Mode  | FLOPs
-------------------------------------------------
0 | model | ResNet | 11.2 M | train | 0    
-------------------------------------------------
11.2 M    Trainable params
0         Non-trainable params
11.2 M    Total params
44.710    Total estimated model params size (MB)
68        Modules in train mode
0         Modules in eval mode
0         Total Flops


Sanity Checking: |          | 0/? [00:00<?, ?it/s]


Detected KeyboardInterrupt, attempting graceful shutdown ...


Unexpected exception formatting exception. Falling back to standard exception


Traceback (most recent call last):
  File "/home/nstaehel/.venv/lib64/python3.9/site-packages/pytorch_lightning/trainer/call.py", line 49, in _call_and_handle_interrupt
    return trainer_fn(*args, **kwargs)
  File "/home/nstaehel/.venv/lib64/python3.9/site-packages/pytorch_lightning/trainer/trainer.py", line 630, in _fit_impl
    self._run(model, ckpt_path=ckpt_path, weights_only=weights_only)
  File "/home/nstaehel/.venv/lib64/python3.9/site-packages/pytorch_lightning/trainer/trainer.py", line 1079, in _run
    results = self._run_stage()
  File "/home/nstaehel/.venv/lib64/python3.9/site-packages/pytorch_lightning/trainer/trainer.py", line 1121, in _run_stage
    self._run_sanity_check()
  File "/home/nstaehel/.venv/lib64/python3.9/site-packages/pytorch_lightning/trainer/trainer.py", line 1150, in _run_sanity_check
    val_loop.run()
  File "/home/nstaehel/.venv/lib64/python3.9/site-packages/pytorch_lightning/loops/utilities.py", line 179, in _decorator
    return loop_run(self, *arg

In [None]:
# Simple plotting for quick visualization of validation metrics.
# NOTE(review): assumes val_metrics_df has an 'epoch' column plus the metric
# columns listed below; missing or all-NaN metrics leave their panel empty.
if not lightning_model.val_metrics_df.empty:
    df = lightning_model.val_metrics_df

    # Metrics to plot, one panel each
    metrics_config = [
        {'col': 'val_loss', 'color': 'blue', 'title': 'Validation Loss', 'ylabel': 'Loss'},
        {'col': 'val_accuracy', 'color': 'green', 'title': 'Validation Accuracy', 'ylabel': 'Accuracy'},
        {'col': 'val_f1', 'color': 'red', 'title': 'Validation F1 Score', 'ylabel': 'F1 Score'},
    ]

    # squeeze=False always returns a 2-D array of Axes, so the loop below works
    # for any number of metrics (a single panel would otherwise be a bare Axes).
    fig, axes = plt.subplots(1, len(metrics_config), figsize=(15, 4), squeeze=False)

    for ax, config in zip(axes.flat, metrics_config):
        if config['col'] in df.columns and df[config['col']].notna().any():
            # Remove NaN values for plotting
            plot_data = df[['epoch', config['col']]].dropna()

            if not plot_data.empty:
                ax.plot(plot_data['epoch'], plot_data[config['col']],
                        color=config['color'], linewidth=2, marker='o')
                ax.set_xlabel('Epoch')
                ax.set_ylabel(config['ylabel'])
                ax.set_title(config['title'])
                ax.grid(True, alpha=0.3)

    plt.tight_layout()

    # savefig does not create missing directories, so make sure the target exists.
    out_path = Path('logs/figures/validation_metrics.png')
    out_path.parent.mkdir(parents=True, exist_ok=True)
    fig.savefig(out_path, dpi=300, bbox_inches='tight')
    print(f"Plots saved as '{out_path}'")
else:
    print("No validation metrics recorded; nothing to plot.")

In [None]:
# Simple plotting for quick visualization of training metrics.
# NOTE(review): assumes train_metrics_df has an 'epoch' column plus the metric
# columns listed below; missing or all-NaN metrics leave their panel empty.
if not lightning_model.train_metrics_df.empty:
    df = lightning_model.train_metrics_df

    # Metrics to plot, one panel each
    metrics_config = [
        {'col': 'train_loss', 'color': 'blue', 'title': 'Training Loss', 'ylabel': 'Loss'},
    ]

    # With a single panel, plt.subplots(1, 1) returns a bare Axes, and zip() over
    # it raises TypeError. squeeze=False always returns a 2-D array of Axes.
    fig, axes = plt.subplots(1, len(metrics_config), figsize=(15, 4), squeeze=False)

    for ax, config in zip(axes.flat, metrics_config):
        if config['col'] in df.columns and df[config['col']].notna().any():
            # Remove NaN values for plotting
            plot_data = df[['epoch', config['col']]].dropna()

            if not plot_data.empty:
                ax.plot(plot_data['epoch'], plot_data[config['col']],
                        color=config['color'], linewidth=2, marker='o')
                ax.set_xlabel('Epoch')
                ax.set_ylabel(config['ylabel'])
                ax.set_title(config['title'])
                ax.grid(True, alpha=0.3)

    plt.tight_layout()

    # savefig does not create missing directories, so make sure the target exists.
    out_path = Path('logs/figures/training_metrics.png')
    out_path.parent.mkdir(parents=True, exist_ok=True)
    fig.savefig(out_path, dpi=300, bbox_inches='tight')
    print(f"Plot saved as '{out_path}'")
else:
    print("No training metrics recorded; nothing to plot.")