In [3]:
import kagglehub, os

# Download latest version
path = kagglehub.dataset_download("apollo2506/eurosat-dataset")

print("Path to dataset files:", path)

Resuming download from 49283072 bytes (2145373495 bytes left)...
Resuming download from https://www.kaggle.com/api/v1/datasets/download/apollo2506/eurosat-dataset?dataset_version_number=6 (49283072/2194656567) bytes left.


100%|██████████| 2.04G/2.04G [19:21<00:00, 1.85MB/s]

Extracting files...





Path to dataset files: C:\Users\erikd\.cache\kagglehub\datasets\apollo2506\eurosat-dataset\versions\6


In [8]:
!move {path} ./data/

Zugriff verweigert


In [2]:
from torch.utils.data import Dataset, DataLoader
import pandas as pd
import torch
from PIL import Image
import os
import numpy as np
import torchvision.transforms as transforms




class EuroSATSample:
    img: torch.Tensor
    label: str                          

class EuroSATDataset(Dataset):
    """EuroSAT Dataset for satellite image classification.
    
    Args:
        root_dir (str): Directory with all the images.
        split (str): Split of the dataset to use ('train', 'val', 'test').
                     The split should correspond to a CSV file named '{split}.csv' in the root directory.
                     The CSV file should have two columns: 'Filename' and 'Label'.
        transform (callable, optional): Optional transform to be applied on a sample.
    """
    def __init__(self, root_dir, split, transform=None):
        self.root_dir = root_dir
        self.transform = transform if transform is not None else transforms.ToTensor()
        self.split = split

        split_file = os.path.join(root_dir, f"{split}.csv")
        if not os.path.exists(split_file):
            raise FileNotFoundError(f"Split file {split_file} does not exist.")
        df = pd.read_csv(split_file)

        self.image_paths = df['Filename'].tolist()
        self.labels = df['Label'].tolist()  
        self.len = len(self.image_paths)

    def __len__(self):
        return self.len

    def __getitem__(self, idx):
        sample = EuroSATSample()

        sample.label = self.labels[idx]

        image_path = os.path.join(self.root_dir, self.image_paths[idx])
        if not os.path.exists(image_path):
            raise FileNotFoundError(f"Image file {image_path} does not exist.")
        
        image = Image.open(image_path)

        if self.transform:
            image = self.transform(image)
        
        sample.img = image
        return sample

In [2]:
import torch
from torch import nn
from src.datasets import EuroSATDataset
from torch.utils.data import DataLoader
from torchvision import transforms

In [2]:
dataset_path = './data/EuroSAT'

transform = transforms.Compose([
    transforms.Resize((128, 128)),
    transforms.ToTensor(),
])

dataset = EuroSATDataset(root_dir=dataset_path, split='train', transform=transform)

dataloader = DataLoader(dataset, batch_size=1, shuffle=True)

In [5]:
from src.models.backbones.resnet import ResNet
from src.models.heads import FFN

class EuroSATModel(nn.Module):
    def __init__(self, backbone, head):
        super(EuroSATModel, self).__init__()
        self.backbone = backbone
        self.head = head
        self.pool = nn.AdaptiveAvgPool2d((1, 1))  # Global average pooling
        self.flatten = nn.Flatten()
        self.loss_fn = nn.CrossEntropyLoss()

    def forward(self, x):
        x = self.backbone(x)
        x = self.pool(x)
        x = self.flatten(x)
        x = self.head(x)
        return x
    
    def loss(self, input, target):
        logits = self.forward(input)
        loss = self.loss_fn(logits, target)
        return loss
    
    def forward_train(self, x, target):
        img = x['image']
        logits = self.forward(img)
        loss = self.loss_fn(logits, target)
        return logits, loss
    
    def forward_test(self, x):
        img = x['image']
        logits = self.forward(img)
        return logits
    
    def predict(self, x):
        logits = self.forward_test(x)
        return torch.argmax(logits, dim=1)
    
model = EuroSATModel(
    backbone=ResNet(idims=3, odims=64, arch=(2, 2, 2, 2), base_dims=32), 
    head=FFN(idims=64, odims=10, hidden_dims=64, dropout=0.5, nlayers=6))

In [6]:
from src.runner.utils import progress_bar

class Runner(nn.Module):
    def __init__(self, model, loading_cfg, data_cfg, optim_cfg, device='cpu'):
        super(Runner, self).__init__()
        self.model = model
        self.loading_cfg = loading_cfg
        self.optim = torch.optim.Adam(model.parameters(), **optim_cfg)
        self.device = device

        self.train_data = EuroSATDataset(**data_cfg, split='train')
        self.val_data = EuroSATDataset(**data_cfg, split='validation')
        self.test_data = EuroSATDataset(**data_cfg, split='test')

    def run(self, mode='train', val_interval=10, log_interval=10, epochs = 100, current_epoch=0):
        if mode == 'train':
            self.model.train()
            train_loader = DataLoader(self.train_data, batch_size=self.loading_cfg['batch_size'], shuffle=True)
            test_loader = DataLoader(self.val_data, batch_size=self.loading_cfg['batch_size'], shuffle=False)

            for epoch in range(current_epoch, epochs + 1):
                loss = self.train(self.model, train_loader, self.optim, self.device)
                if epoch % val_interval == 0:
                    acc, val_loss = self.test(self.model, test_loader, self.device) 
                if log_interval and epoch % log_interval == 0:
                    progress_bar(epoch=epoch, total_epochs=epochs, acc=acc, loss=loss, val_loss=val_loss)
            
        elif mode == 'validation':
            self.model.eval()
            dataloader = DataLoader(self.val_data, batch_size=1, shuffle=False)
            acc, val_loss = self.test(self.model, dataloader, self.device)
            print(f"Validation Accuracy: {acc}, Validation Loss: {val_loss}")
        elif mode == 'test':
            self.model.eval()
            dataloader = DataLoader(self.test_data, batch_size=self.loading_cfg['batch_size'], shuffle=False)
            
        else:
            raise ValueError("Mode must be 'train', 'validation', or 'test'.")

        return dataloader
    def train(self, model, dataloader, optim, device):
        model.train()
        total_loss = 0.0
        for batch in dataloader:
            optim.zero_grad()
            inputs, targets = batch['image'].to(device), batch['label'].to(device)
            loss = model.loss(inputs, targets)
            loss.backward()
            optim.step()
            total_loss += loss.item()
        return total_loss / len(dataloader)
    
    def test(self, model, dataloader, device):
        model.eval()
        total_loss = 0.0
        correct = 0
        with torch.no_grad():
            for batch in dataloader:
                inputs, targets = batch['image'].to(device), batch['label'].to(device)
                logits = model.forward_test(batch)
                loss = model.loss_fn(logits, targets)
                total_loss += loss.item()
                preds = torch.argmax(logits, dim=1)
                correct += (preds == targets).sum().item()
        
        accuracy = correct / len(dataloader.dataset)
        return accuracy, total_loss / len(dataloader)

In [8]:
transform = transforms.Compose([
    transforms.Resize((128, 128)),
    transforms.ToTensor(),
])

loading_cfg = {
    'batch_size': 1,
    'num_workers': 4,
}

data_cfg = {
    'root_dir': "data/EuroSAT",
    'transform': transform,
}

optim_cfg = {
    'lr': 0.001,
    'weight_decay': 1e-4,
}

runner = Runner(model=model, loading_cfg=loading_cfg, data_cfg=data_cfg, optim_cfg=optim_cfg, device='cpu')

In [10]:
runner.run(mode='train', val_interval=1, log_interval=1, epochs=10, current_epoch=0)

KeyboardInterrupt: 