## Imports and Setup

In [31]:
import pandas as pd
import numpy as np
from pathlib import Path
from tqdm import tqdm

from PIL import Image

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F

import torchvision.models as models
import torchvision.transforms as transforms

from torch.utils.data import Dataset, DataLoader

from sklearn.model_selection import train_test_split

In [17]:
# Seed PyTorch's RNG so weight init / shuffling are repeatable across runs.
torch.manual_seed(42)
# Let cuDNN auto-tune convolution algorithms (only has an effect on CUDA).
torch.backends.cudnn.benchmark = True

# Pick the best available accelerator: CUDA first, then Apple MPS, then CPU.
# (Previously a second assignment unconditionally overwrote the CUDA choice
# with "mps"/"cpu", so CUDA machines silently trained on the CPU.)
if torch.cuda.is_available():
    device = torch.device("cuda")
elif torch.backends.mps.is_available():
    device = torch.device("mps")
else:
    device = torch.device("cpu")

# All paths are resolved relative to the directory the notebook is run from.
DATA_DIR = Path.cwd().joinpath('data')
MODEL_DIR = Path.cwd().joinpath('models')
# Images live in the 'sentinel-images' subfolder of the data directory
# (an 'images' subfolder was used as an alternative source previously).
IMAGE_DIR = DATA_DIR.joinpath('sentinel-images')

# Preprocessing applied to every image before it is fed to the network:
# resize to 256x256 (bilinear), convert to a float tensor in [0, 1], then
# normalise with the ImageNet channel statistics expected by the pretrained
# torchvision backbones.
#
# NOTE: an earlier 480x480 / bicubic / (0.5, 0.5, 0.5)-normalised pipeline was
# defined immediately before this one and then overwritten; it was dead code
# and has been removed so there is a single source of truth for `transform`.
transform = transforms.Compose(
    [
        transforms.Resize((256, 256), interpolation=transforms.InterpolationMode.BILINEAR),
        transforms.ToTensor(),
        transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
    ]
)

## Create Dataloader

In [3]:
class SatelliteDataset(Dataset):
    """Dataset of satellite images looked up by sample uid.

    Each item is read from ``IMAGE_DIR / "<uid>.png"``.

    Args:
        df (pd.DataFrame): Frame with a 'uid' column and a 'severity' column
        transform (callable, optional): Applied to the PIL image if given
    """
    def __init__(self, df, transform=None):
        self.df = df
        # Materialise as plain lists so lookups are positional and do not
        # depend on the DataFrame's (possibly shuffled) index labels.
        self.uids = self.df['uid'].tolist()
        self.targets = self.df['severity'].tolist()

        self.transform = transform

    def __len__(self):
        """Number of samples (rows of the backing DataFrame)."""
        return len(self.df)

    def __getitem__(self, index):
        """Load one sample.

        Args:
            index (int): Positional index of the sample

        Returns:
            tuple: (uid, image, target/severity)
        """
        uid = self.uids[index]
        image_pth = IMAGE_DIR / f"{uid}.png"

        # Use a context manager so the file handle is released right away, and
        # force 3-channel RGB so grayscale / RGBA / palette PNGs do not break
        # the 3-channel normalisation step of the transform.
        with Image.open(image_pth) as raw_image:
            image = raw_image.convert("RGB")

        if self.transform is not None:
            image = self.transform(image)

        return uid, image, self.targets[index]

In [4]:
def get_loader(df, transform, batch_size=32, shuffle=True, num_workers=0, pin_memory=True):
    """Wrap a DataFrame in a SatelliteDataset and a DataLoader.

    Args:
        df (pd.DataFrame): Frame passed to SatelliteDataset
        transform (callable): Image transform passed to SatelliteDataset
        batch_size (int): Samples per batch
        shuffle (bool): Whether the loader reshuffles every epoch
        num_workers (int): Number of loader worker processes
        pin_memory (bool): Forwarded to DataLoader

    Returns:
        tuple: (loader, dataset)
    """
    satellite_data = SatelliteDataset(df, transform)
    loader_options = {
        "batch_size": batch_size,
        "shuffle": shuffle,
        "num_workers": num_workers,
        "pin_memory": pin_memory,
    }
    return DataLoader(dataset=satellite_data, **loader_options), satellite_data

## Create Model

In [5]:
class CNN(nn.Module):
    """Image classifier: pretrained ResNet50 backbone + small MLP head.

    Args:
        num_classes (int): Number of classes (size of the output logit vector)
        dropout (float): Dropout rate used inside the classification head
    """
    def __init__(self, num_classes, dropout=0.2):
        super().__init__()

        # ResNet50 base with its default pretrained weights
        backbone = models.resnet50(weights='DEFAULT')
        in_features = backbone.fc.in_features

        # Keep every child module except the final fully connected layer, so
        # the backbone ends at its global pooling stage.
        self.pretrained_cnn = nn.Sequential(*list(backbone.children())[:-1])

        # Classification head producing raw (un-normalised) class logits
        self.fc = nn.Sequential(
            nn.Linear(in_features, in_features // 2),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(in_features // 2, num_classes),
        )

    def forward(self, x, targets=None):
        """Forward pass.

        Args:
            x (torch.Tensor): Batch of input images
            targets (torch.Tensor, optional): Class-index labels, one per image.
                When given, the mean cross-entropy loss is computed as well.

        Returns:
            logits (torch.Tensor): Raw class scores, (batch_size, num_classes)
            loss (torch.Tensor or None): Cross-entropy loss if ``targets`` was
                passed, otherwise None
        """
        features = self.pretrained_cnn(x)                 # (batch_size, in_features, 1, 1)
        features = torch.flatten(features, start_dim=1)   # (batch_size, in_features)
        logits = self.fc(features)                        # (batch_size, num_classes)

        # No softmax here: cross-entropy expects raw logits.
        loss = None
        if targets is not None:
            loss = F.cross_entropy(logits, targets)

        return logits, loss

In [6]:
def RMSELoss(y_features, y_targets):
    """Root-mean-squared error between predicted class indices and targets.

    The predicted class is the argmax over the class dimension, so this value
    is not differentiable with respect to the model outputs; use it as an
    evaluation metric rather than as a training objective.

    Args:
        y_features (torch.Tensor): Model outputs, (batch_size, num_classes)
        y_targets (torch.Tensor): Target labels, (batch_size)

    Returns:
        loss (torch.Tensor): Scalar RMSE value
    """
    # Softmax is monotonic, so argmax over the raw scores gives the same class.
    y_preds = torch.argmax(y_features, dim=1)

    # argmax yields int64 and mse_loss requires floating point inputs, so cast
    # both sides to float before computing the error.
    loss = torch.sqrt(F.mse_loss(y_preds.float(), y_targets.float()))

    return loss

## Train Model

In [7]:
# Read the three competition CSVs from DATA_DIR.
train_labels = pd.read_csv(DATA_DIR / 'train_labels.csv')
submission_format = pd.read_csv(DATA_DIR / 'submission_format.csv')

# Metadata: parse the 'date' column into proper datetimes right after loading.
metadata = pd.read_csv(DATA_DIR / 'sentinel-metadata.csv')
metadata['date'] = pd.to_datetime(metadata['date'])

In [8]:
# Hyperparameters

# -- model --
num_classes = 5        # size of the classifier's output layer
dropout = 0.2          # dropout rate in the classification head

# -- optimisation --
learning_rate = 3e-4   # step size handed to the optimizer
num_epochs = 5         # full passes over the training loader

# -- data loading --
batch_size = 8         # samples per batch
num_workers = 0        # 0 => load data in the main process

In [9]:
# Build the full training table: join labels to metadata on 'uid'
# (validate='1:1' makes pandas raise if a uid is duplicated on either side),
# then keep only the columns the dataset class reads.
train_full = pd.merge(
    train_labels,
    metadata,
    how='inner',
    left_on='uid',
    right_on='uid',
    validate='1:1',
)[['uid', 'severity']]

In [10]:
# Hold out 20% of the labelled rows as a validation set; the fixed
# random_state keeps the split identical from run to run.
train, validate = train_test_split(
    train_full,
    test_size=0.2,
    random_state=42,
)

In [11]:
# Build the test table: take the metadata rows flagged as the 'test' split,
# join them to the submission template on 'uid' (one-to-one, enforced by
# validate='1:1'), and keep only the columns needed downstream.
test_full = metadata.loc[metadata['split'] == 'test']
test = pd.merge(
    submission_format,
    test_full,
    how='inner',
    left_on='uid',
    right_on='uid',
    validate='1:1',
)[['uid', 'region', 'severity']]

In [32]:
# Get dataloaders
# Pinned host memory is only useful for CUDA transfers; on MPS/CPU it brings no
# benefit and makes DataLoader emit a warning, so enable it for CUDA only.
use_pinned_memory = device.type == "cuda"

train_loader, train_dataset = get_loader(
    train, transform, batch_size=batch_size, shuffle=True,
    num_workers=num_workers, pin_memory=use_pinned_memory,
)
# The second loader wraps the *validation* split, so name its dataset that way.
val_loader, val_dataset = get_loader(
    validate, transform, batch_size=batch_size, shuffle=False,
    num_workers=num_workers, pin_memory=use_pinned_memory,
)
# Backward-compatible alias: this object was previously (misleadingly) bound
# to the name `test_dataset`.
test_dataset = val_dataset

In [33]:
# Instantiate the network and move its parameters to the selected device.
model = CNN(num_classes, dropout=dropout)
model = model.to(device)

# Mean-reduced cross-entropy over the class logits.
criterion = nn.CrossEntropyLoss(reduction='mean')
criterion = criterion.to(device)

# AdamW over every model parameter (backbone and head alike).
optimizer = optim.AdamW(params=model.parameters(), lr=learning_rate)

In [34]:
def _run_epoch(loader, model, criterion, device, optimizer=None, desc_prefix="", leave=False):
    """Run one full pass over ``loader`` and return the mean per-batch loss.

    Args:
        loader (DataLoader): Yields (uid, images, targets) batches
        model (nn.Module): Called as ``model(images)``; must return (outputs, _)
        criterion (callable): Loss function called as ``criterion(outputs, targets)``
        device (torch.device): Device the images and targets are moved to
        optimizer (torch.optim.Optimizer, optional): If given, the pass is a
            training pass (train mode, gradients on, weights updated);
            otherwise it is an evaluation pass (eval mode, gradients off)
        desc_prefix (str): Text shown in front of the running loss
        leave (bool): Whether the progress bar stays on screen when finished

    Returns:
        float: Mean loss over the batches, or NaN if the loader was empty
    """
    training = optimizer is not None
    label = "Train" if training else "Val"
    model.train(training)

    running_loss = 0.0
    num_batches = 0
    pbar = tqdm(loader, total=len(loader), leave=leave)
    with torch.set_grad_enabled(training):
        for _uids, imgs, targets in pbar:
            # Move data to device
            imgs = imgs.to(device)
            targets = targets.to(device)

            # Forward pass
            outputs, _ = model(imgs)
            loss = criterion(outputs, targets)

            # Backpropagation (training passes only)
            if training:
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

            # Update progress bar with the running mean of the batch losses
            num_batches += 1
            running_loss += loss.item()
            pbar.set_description(
                desc=f"{desc_prefix} - {label} Loss: {running_loss / num_batches:.4f}"
            )

    # Guard against an empty loader instead of raising ZeroDivisionError.
    return running_loss / num_batches if num_batches else float("nan")


for epoch in range(num_epochs):
    epoch_tag = f"Epoch [{epoch+1}]/[{num_epochs}]"

    # Training pass
    train_loss = _run_epoch(
        train_loader, model, criterion, device,
        optimizer=optimizer, desc_prefix=epoch_tag,
    )

    # Validation pass (no optimizer => eval mode, no gradients)
    val_loss = _run_epoch(
        val_loader, model, criterion, device,
        desc_prefix=f"{epoch_tag} - Train Loss: {train_loss:.4f}",
    )

    # Both bars are transient; this one line is the persistent per-epoch record.
    # (The summary string used to be built and then thrown away.)
    log = (
        f"{epoch_tag}"
        f" - Train Loss: {train_loss:.4f}"
        f" - Val Loss: {val_loss:.4f}"
    )
    tqdm.write(log)
            

Epoch [1]/[5] - Val Loss: 1.0877: 100%|██████████| 226/226 [00:11<00:00, 19.40it/s]  
                                                                                    

KeyboardInterrupt: 

In [None]:
# Make sure the output directory exists; torch.save does not create it and
# would otherwise fail on a fresh checkout.
MODEL_DIR.mkdir(parents=True, exist_ok=True)

# Persist only the weights (state_dict), not the whole pickled module.
model_pth = MODEL_DIR / 'model.pt'
torch.save(model.state_dict(), model_pth)

## Create Predictions

In [None]:
# Rebuild the architecture with the same hyperparameters used for training,
# then restore the saved weights.
model = CNN(num_classes, dropout=dropout).to(device)
# map_location remaps the stored tensors onto the current device, so a
# checkpoint written on one accelerator (e.g. CUDA/MPS) still loads on another.
model.load_state_dict(torch.load(model_pth, map_location=device))
# Switch to inference mode (disables dropout, freezes batch-norm statistics).
model.eval()