In [18]:
%load_ext autoreload
%autoreload 2

import os
import sys
import torch
import yaml
import matplotlib.pyplot as plt
from torch.utils.data import DataLoader, random_split
from tqdm import tqdm

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [2]:
# Working-directory setup: when the kernel was started inside a folder named
# exactly `notebooks`, move one level up so the relative paths used below
# (config.yaml, the `src` package) resolve from the parent directory.
# The last path component is compared exactly so that a folder whose name
# merely ends with "notebooks" (e.g. "old_notebooks") is not mistaken for it.
if os.path.basename(os.getcwd()) == 'notebooks':
    os.chdir('..')
# Make the current directory importable; skip the append when it is already
# present so re-running this cell does not pile up duplicate sys.path entries.
if os.getcwd() not in sys.path:
    sys.path.append(os.getcwd())

In [13]:
from src.preprocessing import ArealData
from src.models import EuroSATCNN

In [14]:
# Load the run configuration from config.yaml (resolved against the current
# working directory). The encoding is given explicitly so decoding does not
# depend on the platform's default locale encoding.
with open("config.yaml", "r", encoding="utf-8") as f:
    config = yaml.safe_load(f)

In [20]:
# Use the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"--- Training on {device} ---")


def _make_split(csv_key, shuffle):
    """Build the dataset and its DataLoader for one data split.

    Args:
        csv_key: Key under ``config['data']`` whose value is handed to
            ``ArealData`` as ``csv_file``.
        shuffle: Passed through to ``DataLoader(shuffle=...)``.

    Returns:
        Tuple ``(dataset, loader)``.
    """
    dataset = ArealData(
        csv_file=config['data'][csv_key],
        root_dir=config['data']['root_dir'],
        n_channels=config['model']['in_channels'],
    )
    loader = DataLoader(
        dataset,
        batch_size=config['model']['batch_size'],
        shuffle=shuffle,
        num_workers=2,
    )
    return dataset, loader


# DataLoaders
# Only the training split is shuffled. Validation and test are read in a fixed
# order: shuffling them brings nothing for evaluation and makes the batch
# order differ from run to run.
train_dataset, train_loader = _make_split('train_dir', shuffle=True)
validation_dataset, val_loader = _make_split('validation_dir', shuffle=False)
test_dataset, test_loader = _make_split('test_dir', shuffle=False)

--- Training on cpu ---


In [21]:
# Instantiate the network from the config values, then move it to `device`.
model = EuroSATCNN(
    n_classes=config['model']['n_classes'],
    in_channels=config['model']['in_channels'],
)
model = model.to(device)

# Cross-entropy loss, and an Adam optimizer over all model parameters using
# the learning rate from the config.
criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(
    model.parameters(),
    lr=config['model']['learning_rate'],
)


In [22]:
def evaluate(loader, desc="Evaluation"):
    """Measure loss and accuracy of the notebook-level `model` on `loader`.

    Relies on the globals `model`, `criterion` and `device` defined in
    earlier cells. Gradients are disabled for the whole pass.

    Args:
        loader: Iterable of ``(images, labels)`` batches that supports ``len()``.
        desc: Label shown on the tqdm progress bar.

    Returns:
        Tuple ``(loss, accuracy)``: the batch losses summed and divided by
        ``len(loader)``, and the percentage of correctly predicted samples.

    Raises:
        ZeroDivisionError: If `loader` contains no batches.
    """
    # Remember the current mode so it can be restored afterwards; otherwise a
    # validation pass in the middle of training would leave the model in eval
    # mode for every following epoch.
    was_training = model.training
    model.eval()
    total_loss = 0.0
    correct = 0
    total = 0

    try:
        with torch.no_grad():
            for images, labels in tqdm(loader, desc=desc, leave=False):
                images, labels = images.to(device), labels.to(device)
                outputs = model(images)
                loss = criterion(outputs, labels)

                total_loss += loss.item()
                # Predicted class = index of the largest score along dim 1.
                _, predicted = torch.max(outputs, 1)
                total += labels.size(0)
                correct += (predicted == labels).sum().item()
    finally:
        # Restore the mode even if the loop is interrupted or raises.
        model.train(was_training)

    return total_loss / len(loader), 100 * correct / total

In [None]:
# Per-epoch metrics: mean training loss, validation loss, validation accuracy.
history = {'train_loss': [], 'val_acc': [], 'val_loss': []}
for epoch in range(config['model']['epochs']):
    # Switch to training mode at the start of EVERY epoch: the validation pass
    # at the end of the previous epoch calls model.eval(), so setting the mode
    # only once before the loop is not enough to guarantee training mode here.
    model.train()
    running_loss = 0.0
    loop = tqdm(train_loader, desc=f"Epoch [{epoch+1}/{config['model']['epochs']}]", unit="batch")
    for images, labels in loop:
        images, labels = images.to(device), labels.to(device)

        # Standard optimisation step: reset grads, forward, backward, update.
        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        current_loss = loss.item()
        running_loss += current_loss

        # Show the latest batch loss on the progress bar.
        loop.set_postfix(loss=current_loss)

    # Mean of the per-batch losses over this epoch.
    avg_epoch_loss = running_loss / len(train_loader)
    # Validation
    val_loss, val_acc = evaluate(val_loader, desc="Validation")

    # Log History
    history['train_loss'].append(avg_epoch_loss)
    history['val_loss'].append(val_loss)
    history['val_acc'].append(val_acc)

    print(f"Epoch {epoch+1} Summary: Train Loss: {avg_epoch_loss:.4f} | Val Acc: {val_acc:.2f}%")

Epoch [1/20]:   0%|          | 0/604 [00:08<?, ?batch/s]


KeyboardInterrupt: 

In [None]:
# Persist the trained weights (state_dict only) to the path given in the config.
model_path = config['model']['model_path']
model_dir = os.path.dirname(model_path)
# os.path.dirname returns '' when the path is a bare filename, and
# os.makedirs('') raises, so only create the folder when there is one.
if model_dir:
    os.makedirs(model_dir, exist_ok=True)
torch.save(model.state_dict(), model_path)
print("--- Training Complete & Model Saved ---")