## Imports and Setup

In [8]:
import pandas as pd
import numpy as np
from pathlib import Path
from tqdm import tqdm

from PIL import Image

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F

import torchvision.models as models
import torchvision.transforms as transforms

from torch.utils.data import Dataset, DataLoader

from sklearn.model_selection import train_test_split

In [28]:
# Seed PyTorch's global RNG so weight initialisation and shuffling are repeatable.
torch.manual_seed(42)
# Let cuDNN benchmark convolution algorithms (only has an effect when running on CUDA).
torch.backends.cudnn.benchmark = True

# Pick the best available accelerator: CUDA first, then Apple MPS, then CPU.
# The choice must be a single if/elif chain: two unconditional assignments in a
# row would let the MPS check overwrite (and discard) a successful CUDA check.
if torch.cuda.is_available():
    device = torch.device("cuda")
elif torch.backends.mps.is_available():
    device = torch.device("mps")
else:
    device = torch.device("cpu")

# Project folders, all relative to the directory the notebook is started from.
DATA_DIR = Path.cwd() / 'data'
MODEL_DIR = Path.cwd() / 'models'
RESULTS_DIR = Path.cwd() / 'results'
IMAGE_DIR = DATA_DIR / 'images'
# IMAGE_DIR = DATA_DIR / 'sentinel-images'

# Earlier preprocessing variant, kept for reference only:
# transform = transforms.Compose(
#     [
#         transforms.Resize((480, 480), interpolation=transforms.InterpolationMode.BICUBIC),
#         transforms.ToTensor(),
#         transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
#     ]
# )

# Active preprocessing: resize to 512x512, keep the central 256x256 patch,
# convert to a tensor and normalise each of the three channels.
transform = transforms.Compose(
    [
        transforms.Resize((512, 512), interpolation=transforms.InterpolationMode.BILINEAR),
        transforms.CenterCrop((256, 256)),
        transforms.ToTensor(),
        transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
    ]
)

## Create Model

In [10]:
class CNN(nn.Module):
    """Image classifier: pretrained ResNet50 backbone plus a small MLP head.

    The torchvision ResNet50 (``weights='DEFAULT'``) is used as a feature
    extractor with its last child module (the original ``fc`` layer) removed.
    A two-layer fully connected head then maps the pooled features to
    ``num_classes`` raw scores (logits); no softmax is applied.

    Args:
        num_classes (int): Number of classes
        dropout (float): Dropout rate used inside the classification head
    """
    def __init__(self, num_classes, dropout=0.2):
        super().__init__()
        # EfficientNetV2 base (alternative backbone, currently disabled)
        # self.pretrained_cnn = models.efficientnet_v2_s(weights='DEFAULT')
        # in_features = self.pretrained_cnn.classifier[1].in_features
        # self.pretrained_cnn = nn.Sequential(*list(self.pretrained_cnn.children())[:-1])

        # ResNet50 base
        backbone = models.resnet50(weights='DEFAULT')
        # Width of the feature vector that fed the original classifier layer
        in_features = backbone.fc.in_features
        # Keep every child module except the last one (the original fc layer).
        # The attribute name is part of the state_dict keys, so it must not change.
        self.pretrained_cnn = nn.Sequential(*list(backbone.children())[:-1])

        # Add fully connected layers
        self.fc = nn.Sequential(
            nn.Linear(in_features, in_features//2),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(in_features//2, num_classes)
        )

    def forward(self, x):
        """Forward pass through the backbone and the classification head.

        Args:
            x (torch.Tensor): Batch of input images

        Returns:
            logits (torch.Tensor): Unnormalised class scores, one row per image
        """
        features = self.pretrained_cnn(x)           # backbone output, one feature map stack per image
        # flatten (unlike view) also works when the backbone output is not contiguous
        features = torch.flatten(features, 1)       # (batch_size, in_features)
        logits = self.fc(features)                  # (batch_size, num_classes)

        return logits

In [11]:
def RMSELoss(y_features, y_targets):
    """Root-mean-square error between predicted class indices and targets.

    The predicted class of each row is the argmax over ``dim=1`` of
    ``y_features``. Because argmax is not differentiable, the result is
    useful as an evaluation metric, not as a training objective.

    Args:
        y_features (torch.Tensor): Model features (class scores), shape (batch, classes)
        y_targets (torch.Tensor): Target labels, shape (batch,)

    Returns:
        loss (torch.Tensor): Scalar float tensor holding the RMSE
    """
    # Softmax is monotonic, so argmax of the raw scores picks the same class.
    # argmax returns int64; mse_loss needs floating point inputs, so cast both sides.
    y_preds = torch.argmax(y_features, dim=1).float()

    loss = torch.sqrt(F.mse_loss(y_preds, y_targets.float()))

    return loss

## Create Dataloader

In [12]:
class SatelliteDataset(Dataset):
    """Dataset of PNG images looked up by uid, paired with a severity label.

    Args:
        df (pd.DataFrame): Frame with a ``uid`` and a ``severity`` column
        transform (callable, optional): Applied to each loaded PIL image
    """
    def __init__(self, df, transform=None):
        self.df = df
        # Plain Python lists give cheap positional access in __getitem__
        self.uids = self.df['uid'].tolist()
        self.targets = self.df['severity'].tolist()

        self.transform = transform

    def __len__(self):
        return len(self.df)

    def __getitem__(self, index):
        """Load one sample.

        Returns:
            tuple: (uid, image, target/severity)
        """
        image_pth = IMAGE_DIR / f"{self.uids[index]}.png"
        # The context manager closes the file handle right after decoding.
        # convert('RGB') guarantees three channels (PNGs may be RGBA, palette
        # or greyscale), which the 3-channel Normalize and the backbone need.
        with Image.open(image_pth) as img:
            image = img.convert('RGB')

        if self.transform is not None:
            image = self.transform(image)

        return self.uids[index], image, self.targets[index] # returns uid, image, target/severity

In [13]:
def get_loader(df, transform, batch_size=32, shuffle=True, num_workers=0, pin_memory=True):
    """Build a SatelliteDataset from ``df`` and wrap it in a DataLoader.

    Args:
        df (pd.DataFrame): Rows to serve; passed straight to SatelliteDataset
        transform (callable): Image transform handed to the dataset
        batch_size (int): Samples per batch
        shuffle (bool): Whether the loader reshuffles the data
        num_workers (int): Number of loader worker processes
        pin_memory (bool): Forwarded to DataLoader

    Returns:
        tuple: (loader, dataset)
    """
    dataset = SatelliteDataset(df, transform)

    # Collect the loader options once, then expand them into the constructor
    loader_options = {
        'batch_size': batch_size,
        'shuffle': shuffle,
        'num_workers': num_workers,
        'pin_memory': pin_memory,
    }

    return DataLoader(dataset, **loader_options), dataset

## Train Model

In [14]:
# Read the three input tables from DATA_DIR
metadata = pd.read_csv(DATA_DIR / 'saved-metadata.csv')
# Parse the date column into pandas datetimes
metadata['date'] = pd.to_datetime(metadata['date'])

train_labels, submission_format = (
    pd.read_csv(DATA_DIR / csv_name)
    for csv_name in ('train_labels.csv', 'submission_format.csv')
)

In [15]:
# Hyperparameters

# -- Model --
# NOTE(review): class indices run 0..NUM_CLASSES-1; the saved predictions further
# down start at 1, so confirm the `severity` labels really fit that range.
NUM_CLASSES = 5
dropout = 0.2

# -- Optimisation --
learning_rate = 3e-4
num_epochs = 10
patience = 5        # epochs without improvement tolerated before early stopping

# -- Data loading --
batch_size = 8
num_workers = 0

In [16]:
# Join the labels to the metadata on uid and keep only what training needs.
# validate='1:1' makes pandas raise if a uid is duplicated on either side.
train_full = train_labels.merge(metadata, how='inner', on='uid', validate='1:1')[
    ['uid', 'severity']
]

In [17]:
# Hold out 20% of the labelled rows for validation; the fixed seed keeps the split repeatable
train, validate = train_test_split(
    train_full,
    test_size=0.2,
    random_state=42,
)

In [18]:
# Metadata rows flagged as belonging to the test split
test_full = metadata.loc[metadata['split'].eq('test')]

# Join the submission template to those rows (1:1 on uid) and keep the needed columns
test = submission_format.merge(test_full, how='inner', on='uid', validate='1:1')[
    ['uid', 'region', 'severity']
]

In [19]:
# Training data is reshuffled on every pass over the loader
train_loader, train_dataset = get_loader(
    train,
    transform,
    batch_size=batch_size,
    shuffle=True,
    num_workers=num_workers,
)

# Validation keeps a fixed order
val_loader, val_dataset = get_loader(
    validate,
    transform,
    batch_size=batch_size,
    shuffle=False,
    num_workers=num_workers,
)

# Test images are served one at a time, in order
test_loader, test_dataset = get_loader(
    test,
    transform,
    batch_size=1,
    shuffle=False,
    num_workers=num_workers,
)

In [20]:
# Build the network and move it to the selected device
model = CNN(NUM_CLASSES, dropout=dropout)
model = model.to(device)

# Cross-entropy averaged over the batch
criterion = nn.CrossEntropyLoss(reduction='mean')
criterion = criterion.to(device)

# AdamW over all model parameters
optimizer = optim.AdamW(params=model.parameters(), lr=learning_rate)

In [21]:
def save_checkpoint(model_to_save, filename='checkpoint.pt'):
    """Write the model's state_dict to ``MODEL_DIR / filename``.

    Args:
        model_to_save (nn.Module): Model whose weights are saved
        filename (str): File name inside MODEL_DIR
    """
    # torch.save does not create missing folders, so make sure the target exists
    MODEL_DIR.mkdir(parents=True, exist_ok=True)
    torch.save(model_to_save.state_dict(), MODEL_DIR / filename)

# Early-stopping state used by the training loop
best_loss = None    # lowest validation loss so far (None until the first epoch finishes)
counter = 0         # consecutive epochs without improvement

In [14]:
def _run_epoch(loader, desc_prefix, train, leave):
    """Run one pass over ``loader`` and return the mean per-sample loss.

    Uses the notebook-level ``model``, ``criterion``, ``optimizer`` and ``device``.

    Args:
        loader (DataLoader): Yields ``(uids, images, targets)`` batches
        desc_prefix (str): Progress-bar text shown in front of the running loss
        train (bool): If True, backpropagate and update the weights
        leave (bool): Whether the finished progress bar stays on screen

    Returns:
        mean_loss (float): Loss averaged over every sample seen (0 for an empty loader)
    """
    model.train(train)      # train(False) is the same as eval()
    loss_sum = 0.0
    sample_count = 0
    mean_loss = 0
    pbar = tqdm(loader, total=len(loader), leave=leave)
    # Gradients are only tracked when the weights are actually being updated
    with torch.set_grad_enabled(train):
        for _, imgs, targets in pbar:
            # Move data to device
            imgs = imgs.to(device)
            targets = targets.to(device)

            # Forward pass
            outputs = model(imgs)
            loss = criterion(outputs, targets)

            # Backpropagation
            if train:
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

            # The criterion returns a batch mean; weight it by the batch size so a
            # smaller final batch does not count as much as a full one.
            batch_len = targets.size(0)
            loss_sum += loss.item() * batch_len
            sample_count += batch_len
            mean_loss = loss_sum / sample_count

            # Update progress bar
            pbar.set_description(desc=f"{desc_prefix}{mean_loss:.4f}")

    return mean_loss


for epoch in range(num_epochs):
    epoch_tag = f"Epoch {epoch+1}/{num_epochs}"

    ## TRAINING LOOP
    avg_train_loss = _run_epoch(
        train_loader,
        f"{epoch_tag} - Train Loss: ",
        train=True,
        leave=False,
    )

    ## VALIDATION LOOP
    avg_val_loss = _run_epoch(
        val_loader,
        f"{epoch_tag} - Train Loss: {avg_train_loss:.4f} - Val Loss: ",
        train=False,
        leave=True,
    )

    ## CHECKPOINTING AND EARLY STOPPING
    # First epoch, or validation loss did not get worse: keep these weights
    if best_loss is None or avg_val_loss <= best_loss:
        best_loss = avg_val_loss
        save_checkpoint(model, filename='checkpoint.pt')
        counter = 0
    else:   # loss increased
        counter += 1
        if counter >= patience:
            print('Validation loss has not decreased. Stopping training.')
            break
    
    
            

Epoch 1/5 - Train Loss: 1.0151 - Val Loss: 0.9871: 100%|██████████| 226/226 [00:12<00:00, 17.45it/s]
Epoch 2/5 - Train Loss: 0.9186 - Val Loss: 0.8820: 100%|██████████| 226/226 [00:12<00:00, 17.80it/s]
Epoch 3/5 - Train Loss: 0.8773 - Val Loss: 1.0769: 100%|██████████| 226/226 [00:13<00:00, 16.70it/s]
Epoch 4/5 - Train Loss: 0.8860:  11%|█▏        | 103/904 [00:16<01:59,  6.72it/s]

## Create Predictions

In [30]:
# Rebuild the architecture, then restore the weights saved during training
model = CNN(NUM_CLASSES).to(device)

model_pth = MODEL_DIR / 'checkpoint.pt'
# map_location remaps tensors saved on another device (e.g. CUDA) onto `device`,
# so the checkpoint can be loaded on a machine without the original accelerator
model.load_state_dict(torch.load(model_pth, map_location=device))


<All keys matched successfully>

In [31]:
submission = submission_format.copy()

# uid -> predicted class index, filled while iterating over the test loader
results = {}

model.eval()
pbar = tqdm(test_loader, total=len(test_loader), leave=True)
# Inference only: skip autograd bookkeeping to save memory and time
with torch.no_grad():
    for uids, imgs, _ in pbar:
        imgs = imgs.to(device)

        logits = model(imgs)

        # argmax of the logits equals argmax of their softmax
        predictions = torch.argmax(logits, dim=1).cpu().tolist()
        results.update(zip(uids, predictions))

# Write all predictions in one pass; rows whose uid was not served by the
# loader keep the value they already had in the submission template
predicted = submission['uid'].map(results)
submission['severity'] = (
    predicted.fillna(submission['severity']).astype(submission['severity'].dtype)
)

100%|██████████| 4275/4275 [01:16<00:00, 55.84it/s]


In [32]:
# Display the filled-in submission frame (the rendered output elides the middle rows)
submission

Unnamed: 0,uid,region,severity
0,aabn,west,4
1,aair,west,1
2,aajw,northeast,1
3,aalr,midwest,1
4,aalw,west,4
...,...,...,...
6505,zzpn,northeast,1
6506,zzrv,west,4
6507,zzsx,south,1
6508,zzvv,west,4


In [33]:
# Number of predictions per severity value, ordered by severity
submission['severity'].value_counts().sort_index()

1    4640
2       3
3     549
4    1318
Name: severity, dtype: int64

In [26]:
# NOTE(review): repeats the severity count; this cell's saved output carries an
# older execution count, so it presumably reflects an earlier run - confirm or remove.
submission['severity'].value_counts().sort_index()

1    4158
2     439
3     440
4    1473
Name: severity, dtype: int64

In [29]:
# to_csv does not create missing folders, so make sure the output folder exists
RESULTS_DIR.mkdir(parents=True, exist_ok=True)
submission.to_csv(RESULTS_DIR / '2_all-cnn.csv', index=False)