In [None]:
"""
In this notebook we make predictions on the severity of a harmful algal bloom based on Landsat and Sentinel satellite imagery

Inputs:
    Satellite Image (preferably sentinel, but if not available, landsat)
    Temperature Data (hrrr)
    
Outputs:
    Severyty of Harmful Algal Bloom on a scale of 1-5 (1 being the least severe, 5 being the most severe)

"""

## Imports and Setup

In [45]:
import pandas as pd
import numpy as np
from pathlib import Path
from tqdm import tqdm

from PIL import Image

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F

import torchvision.models as models
import torchvision.transforms as transforms

from torch.utils.data import Dataset, DataLoader

from sklearn.model_selection import train_test_split
from sklearn.preprocessing import MinMaxScaler

In [70]:
torch.manual_seed(42)
torch.backends.cudnn.benchmark = True
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")

DATA_DIR = Path.cwd() / 'data'
MODEL_DIR = Path.cwd() / 'models'
RESULTS_DIR = Path.cwd() / 'results'
IMAGE_DIR = DATA_DIR / 'images'
# IMAGE_DIR = DATA_DIR / 'sentinel-images'

NUM_CLASSES = 5

'''
transform_image = transforms.Compose(
    [
        transforms.Resize((512, 512), interpolation=transforms.InterpolationMode.BICUBIC),
        transforms.CenterCrop((480, 480)),
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
    ]
)
'''

transform_image = transforms.Compose(
    [
        transforms.Resize((512, 512), interpolation=transforms.InterpolationMode.BILINEAR),
        transforms.CenterCrop((256, 256)),
        transforms.ToTensor(),
        transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
    ]
)

# https://towardsai.net/p/data-science/how-when-and-why-should-you-normalize-standardize-rescale-your-data-3f083def38ff
temp_scalar = MinMaxScaler(feature_range=(0, 1))

## Create Model

In [71]:
class CNN(nn.Module):
    """EfficientNetV2 Model
    
    Args:
        num_classes (int): Number of classes
        dropout (float): Dropout rate
    """
    def __init__(self, metadata_embed = None, dropout=0.2):
        super(CNN, self).__init__()
        # EfficientNetV2 base
        # self.pretrained_cnn = models.efficientnet_v2_s(weights='DEFAULT')
        # cnn_out = self.pretrained_cnn.classifier[1].in_features
        # self.pretrained_cnn = nn.Sequential(*list(self.pretrained_cnn.children())[:-1])
        
        # ResNet50 base
        self.pretrained_cnn = models.resnet50(weights='DEFAULT')
        cnn_out = self.pretrained_cnn.fc.in_features
        self.pretrained_cnn = nn.Sequential(*list(self.pretrained_cnn.children())[:-1])
        
        fc_in = cnn_out
        
        if metadata_embed is not None:
            self.lin = nn.Linear(1, metadata_embed)
            fc_in += metadata_embed
        
        fc_embed = fc_in//2
        fc_out = fc_embed//2
            
        # Add fully connected layers
        self.fc = nn.Sequential(
            nn.Linear(fc_in, fc_embed),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(fc_embed, fc_out),
            nn.ReLU(),
            nn.Dropout(dropout),
        )
        
        # Final classifier layer
        self.classifier = nn.Linear(fc_out, NUM_CLASSES)
        
        
    def forward(self, img, met=None):
        """EfficientNetV2 forward pass
        
        Args:
            img (torch.Tensor): Input image
            met (torch.Tensor): Input metadata features
            
        Returns:
            preds (torch.Tensor): Predicted labels
        """
        img_features = self.pretrained_cnn(img)                             # (batch_size, cnn_out, 1, 1)
        
        img_features = img_features.view(img_features.size(0), -1)          # (batch_size, cnn_out)
        
        
        if met is not None:
            met_features = self.lin(met.unsqueeze(dim=1))
            features = torch.cat((img_features, met_features), dim=1)   # (batch_size, cnn_out+metadata_features)
        else:
            features = img_features
            
        
        features = self.fc(features)
                
        logits = self.classifier(features)
        
        return logits

In [48]:
# def RMSELoss(y_features, y_targets):
#     """RMSE Loss
    
#     Args:
#         y_features (torch.Tensor): Model features
#         y_targets (torch.Tensor): Target labels
        
#     Returns:
#         loss (torch.Tensor): Loss value
#     """
#     y_preds = torch.argmax(F.softmax(y_features, dim=1), dim=1)
    
#     loss = torch.sqrt(F.mse_loss(y_preds, y_targets))
    
#     return loss

## Create Dataloader

In [60]:
class SatelliteDataset(Dataset):
    def __init__(self, df, image_transform=None):
        self.df = df
        self.uids = self.df['uid'].to_numpy()
        # self.temps = self.df['temperature'].to_numpy(dtype=np.float32)
        self.targets = self.df['severity'].to_numpy()

        self.image_transform = image_transform

    def __len__(self):
        return len(self.df)
    
    def __getitem__(self, index):
        image_pth = IMAGE_DIR / f"{self.uids[index]}.png"
        image = Image.open(image_pth)
        
        if self.image_transform is not None:
            image = self.image_transform(image)
                    
        return self.uids[index], image, self.targets[index] # returns (uid, image, target/severity)

In [50]:
def get_loader(df, transform, batch_size=32, shuffle=True, num_workers=0, pin_memory=True):
    
    dataset = SatelliteDataset(df, transform)
    
    loader = DataLoader(
        dataset=dataset,
        batch_size=batch_size,
        shuffle=shuffle,
        num_workers=num_workers,
        pin_memory=pin_memory,
    )
    
    return loader, dataset

## Data Preprocessing

In [51]:
# Load data from csv files
metadata = pd.read_csv(DATA_DIR / 'c20p-metadata.csv')
metadata.date = pd.to_datetime(metadata.date)

train_labels = pd.read_csv(DATA_DIR / 'train_labels.csv')

submission_format = pd.read_csv(DATA_DIR / 'submission_format.csv')

temperature = pd.read_csv(DATA_DIR / 'temperature.csv')

In [52]:
scaled_temp = temperature.copy()
scaled_temp.temperature = temp_scalar.fit_transform(scaled_temp.temperature.values.reshape(-1, 1))

In [53]:
metadata_temp = metadata.merge(
    scaled_temp,
    how="inner",
    left_on='uid',
    right_on='uid',
    validate='1:1'
)

In [54]:
# Get entire train set from metadata file
train_full = train_labels.merge(
    metadata,
    how='inner',
    left_on='uid',
    right_on='uid',
    validate='1:1',
)
train_full = train_full[['uid', 'severity']]

In [55]:
# Split into train and validation sets to better evaluate the model
train, validate = train_test_split(train_full, test_size=0.2, random_state=42)

In [56]:
# Get test set from metadata file
test_full = metadata_temp[metadata_temp['split'] == 'test']
test = submission_format.merge(
    test_full,
    how='inner',
    left_on='uid',
    right_on='uid',
    validate='1:1',
)
test = test[['uid', 'region', 'severity']]

In [57]:
len(train), len(validate), len(test)

(11207, 2802, 4258)

## Train Model

In [58]:
# Hyperparameters
learning_rate = 1e-5
dropout = 0.3

batch_size = 64
num_workers = 0

num_epochs = 10
patience = 5

In [61]:
# Get dataloaders
train_loader, train_dataset = get_loader(train, transform_image, batch_size=batch_size, shuffle=True, num_workers=num_workers)
val_loader, val_dataset = get_loader(validate, transform_image, batch_size=batch_size, shuffle=False, num_workers=num_workers)
test_loader, test_dataset = get_loader(test, transform_image, batch_size=1, shuffle=False, num_workers=num_workers)

In [62]:
# Create model object
model = CNN(dropout=dropout).to(device)

criterion = nn.CrossEntropyLoss(reduction='mean').to(device)

optimizer = optim.AdamW(model.parameters(), lr=learning_rate)

In [None]:
def save_checkpoint(model_to_save, filename='checkpoint.pt'):
    torch.save(model_to_save.state_dict(), MODEL_DIR / filename)

In [85]:
best_loss = None
counter = 0

for epoch in range(num_epochs):
    ## TRAINING LOOP
    train_count = 0
    train_loss = 0
    avg_train_loss = 0
    pbar_train = tqdm(train_loader, total=len(train_loader), leave=False)
    model.train()
    for idx, (uids, imgs, targets) in enumerate(pbar_train):
        # Move data to device
        imgs = imgs.to(device)
        # temperatures = temperatures.to(device)
        targets = targets.to(device)
        
        # Forward pass   
        outputs = model(
            imgs, 
            # temperatures,
        )
        loss = criterion(outputs, targets)
        #rmse = RMSELoss(outputs, targets)
        
        # Backpropagation
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        
        # Update progress bar
        train_count += 1
        train_loss += loss.item()
        avg_train_loss = train_loss/train_count
        desc = (
            f"Epoch {epoch+1}/{num_epochs}"
            f" - Train Loss: {avg_train_loss:.4f}"
        )
        pbar_train.set_description(desc=desc)
    
    
    ## VALIDATION LOOP   
    val_count = 0
    val_loss = 0
    avg_val_loss = 0
    pbar_val = tqdm(val_loader, total=len(val_loader), leave=True)
    model.eval()
    with torch.no_grad():
        for idx, (uid, imgs, targets) in enumerate(pbar_val):
            # Move data to device
            imgs = imgs.to(device)
            # temperatures = temperatures.to(device)
            targets = targets.to(device)
            
            # Forward pass
            outputs = model(
                imgs,
                # temperatures,
            )
            loss = criterion(outputs, targets)
            
            # Update progress bar
            val_count += 1
            val_loss += loss.item()
            avg_val_loss = val_loss/val_count
            desc = (
                f"Epoch {epoch+1}/{num_epochs}"
                f" - Train Loss: {avg_train_loss:.4f}"
                f" - Val Loss: {avg_val_loss:.4f}"
            )
            pbar_val.set_description(desc=desc)
            
    
    ## CHECKPOINTING AND EARLY STOPPING
    if best_loss is None:   # i.e. first epoch
        best_loss = avg_val_loss
        save_checkpoint(model, filename=f'checkpoint.pt')
        
    elif avg_val_loss > best_loss:  # i.e. loss increased
        counter += 1
        if counter >= patience:
            print('Validation loss has not decreased. Stopping training.')
            break
        
    else:   # avg_val_loss < best_loss i.e. loss decreased
        best_loss = avg_val_loss
        save_checkpoint(model, filename=f'checkpoint.pt')
        counter = 0
    
    
            

Epoch 1/10 - Train Loss: 1.0623 - Val Loss: 0.9865: 100%|██████████| 273/273 [00:15<00:00, 18.09it/s]
Epoch 2/10 - Train Loss: 0.9684 - Val Loss: 1.0514: 100%|██████████| 273/273 [00:13<00:00, 19.65it/s]
Epoch 3/10 - Train Loss: 0.9195 - Val Loss: 1.3018: 100%|██████████| 273/273 [00:14<00:00, 18.98it/s]
Epoch 4/10 - Train Loss: 0.8870 - Val Loss: 0.9277: 100%|██████████| 273/273 [00:13<00:00, 19.84it/s]
Epoch 5/10 - Train Loss: 0.8470 - Val Loss: 0.9401: 100%|██████████| 273/273 [00:13<00:00, 19.76it/s]
Epoch 6/10 - Train Loss: 0.7915 - Val Loss: 1.0354: 100%|██████████| 273/273 [00:13<00:00, 19.75it/s]
Epoch 7/10 - Train Loss: 0.7289:  14%|█▍        | 155/1090 [00:21<02:09,  7.24it/s]

## Generate Predictions

In [63]:
model = CNN().to(device) # metadata_embed=256

model_pth = MODEL_DIR / 'checkpoint.pt'
model.load_state_dict(torch.load(model_pth))


<All keys matched successfully>

In [65]:
submission = submission_format.copy()

results = {}

model.eval()
pbar = tqdm(test_loader, total=len(test_loader), leave=True)
for idx, (uid, img,  _) in enumerate(pbar):
    uid = uid[0]
    # temperature = temperature.to(device)
    img = img.to(device)
    
    output = model(
        img,
        # temperature,
    )
    
    prediction = torch.argmax(F.softmax(output, dim=1), dim=1)
    
    submission.loc[submission.uid == uid, 'severity'] = prediction.item()

100%|██████████| 4258/4258 [01:18<00:00, 53.91it/s]


In [66]:
submission

Unnamed: 0,uid,region,severity
0,aabn,west,1
1,aair,west,1
2,aajw,northeast,1
3,aalr,midwest,1
4,aalw,west,4
...,...,...,...
6505,zzpn,northeast,1
6506,zzrv,west,4
6507,zzsx,south,1
6508,zzvv,west,4


In [None]:
submission.severity.value_counts().sort_index()

1    4747
2      35
3     473
4    1255
Name: severity, dtype: int64

In [67]:
submission.severity.value_counts().sort_index()

1    4544
2     104
3     486
4    1376
Name: severity, dtype: int64

In [68]:
submission.to_csv(RESULTS_DIR / '5_notemp-cnn.csv', index=False)