Goal:
- Simplify the Swin Transformer model and build a complete train/validation/test setup with metrics (loss, R2 in log space, R2 in original scale, MAE, etc.). Two options for handling the targets:
    1. Do not log-transform the trait targets (the outputs may then have high variance) and keep using the R2 loss.
    2. Log-transform the trait targets and either compute the R2 loss in the original scale (not in log space) or add an exp10 (10^x) layer to the model so backpropagation happens in the original scale.

- [Later] Add a hyperparameter-tuning framework and visualizations of the first layer.

Credits:
- Modified from HDJOJO's original Swin Transformer notebook, which was itself adapted from https://www.kaggle.com/code/markwijkhuizen/planttraits2024-eda-training-pub.
- Training only, EDA part not included.
- Image model only, tabular data not used.

In [2]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import imageio.v3 as imageio
import albumentations as A

from albumentations.pytorch import ToTensorV2
from torch.utils.data import Dataset, DataLoader
from torch import nn
from tqdm.notebook import tqdm
from sklearn.preprocessing import StandardScaler

import torch
import timm
# import glob
import torchmetrics
import time
import psutil
import os

# Register tqdm with pandas so Series/DataFrame gain .progress_apply,
# which the data-loading cells use to show progress while reading images.
tqdm.pandas()

In [3]:
class Config():
    """Hyperparameters and run settings shared by every cell of the notebook."""

    IMAGE_SIZE = 256  # square input resolution the transforms resize images to
#     BACKBONE = 'swin_large_patch4_window12_384.ms_in22k_ft_in1k'
    BACKBONE = 'swinv2_small_window16_256'
    TARGET_COLUMNS = ['X4_mean', 'X11_mean', 'X18_mean', 'X50_mean', 'X26_mean', 'X3112_mean']
    N_TARGETS = len(TARGET_COLUMNS)
    BATCH_SIZE = 128
    LR_MAX = 1e-4
    WEIGHT_DECAY = 0.01
    N_EPOCHS = 10
    TRAIN_MODEL = True
    # .get() so the notebook also runs outside Kaggle, where this variable is
    # unset (indexing os.environ directly raised KeyError at class creation).
    IS_INTERACTIVE = os.environ.get('KAGGLE_KERNEL_RUN_TYPE') == 'Interactive'

    # Cross-validation setup
    NUM_FOLDS = 5   # number of stratified folds
    VALID_FOLD = 0  # fold held out as validation data

CONFIG = Config()

In [None]:
from pathlib import Path

# Read the training table and attach each image's raw JPEG bytes.
train_df = pd.read_csv('/kaggle/input/planttraits2024/train.csv')
train_df['file_path'] = train_df['id'].apply(lambda s: f'/kaggle/input/planttraits2024/train_images/{s}.jpeg')
# Path.read_bytes opens and closes each file itself; the previous bare
# open(fp, 'rb').read() left every handle for the garbage collector to close.
train_df['jpeg_bytes'] = train_df['file_path'].progress_apply(lambda fp: Path(fp).read_bytes())
train_df.to_pickle('train.pkl')  # cache the table (including image bytes) to disk
print("Train df length:", len(train_df))

### Data Filtering

In [None]:
# Optional subsampling to speed up experiments. It is currently disabled, so
# the full training set is used.
print(f"Previous length: {len(train_df)}")
# To train on a 30% sample instead, enable:
#     train_df = train_df.sample(frac=0.3, random_state=42)
#     print(f"Sampled length: {len(train_df)}")

In [None]:
from sklearn.model_selection import StratifiedKFold

skf = StratifiedKFold(n_splits=CONFIG.NUM_FOLDS, shuffle=True, random_state=42)

# Put every trait into one of NUM_FOLDS quantile bins numbered 0 .. NUM_FOLDS-1.
# Only the interior percentile edges go to np.digitize. With the full edge
# array (right=False), each trait's maximum value fell alone into bin
# NUM_FOLDS + 1, which created singleton strata.
for i, trait in enumerate(CONFIG.TARGET_COLUMNS):
    bin_edges = np.percentile(train_df[trait], np.linspace(0, 100, CONFIG.NUM_FOLDS + 1))
    train_df[f"bin_{i}"] = np.digitize(train_df[trait], bin_edges[1:-1])

# Concatenate the per-trait bins into one joint stratification label.
train_df["final_bin"] = (
    train_df[[f"bin_{i}" for i in range(CONFIG.N_TARGETS)]]
    .astype(str)
    .agg("".join, axis=1)
)

# Stratified split on the joint label. Each row's validation fold becomes its fold id.
train_df = train_df.reset_index(drop=True)
for fold, (train_idx, valid_idx) in enumerate(skf.split(train_df, train_df["final_bin"])):
    train_df.loc[valid_idx, "fold"] = fold
# Every row is assigned exactly once above, so the (float) column can be int.
train_df["fold"] = train_df["fold"].astype(int)

train_df.head()

In [None]:
# Hold out fold CONFIG.VALID_FOLD as validation; all other folds are used for training.
is_valid_row = train_df["fold"] == CONFIG.VALID_FOLD
train = train_df[~is_valid_row]
valid = train_df[is_valid_row]
train[CONFIG.TARGET_COLUMNS + ["fold"]].describe()

In [None]:
class PlantDataPreProcess:
    """Constants for target preprocessing.

    The quantiles bound the per-trait outlier filter applied to the training
    rows. ``log_transform`` is applied to the skewed traits, and predictions
    are mapped back with ``10 ** x``.
    """

    # Training rows with a trait outside [lower_quantile, upper_quantile] are dropped.
    lower_quantile: float = 0.005
    upper_quantile: float = 0.995
    # Base-10 log. A NumPy ufunc, so class-level access does not bind it as a method.
    log_transform = np.log10

In [None]:
# Drop training rows whose traits fall outside the per-trait quantile range.
# All bounds are computed on the same (unfiltered) training split and applied
# as one combined mask, so the cut-offs no longer depend on the order of
# CONFIG.TARGET_COLUMNS. The previous loop recomputed each trait's quantiles
# on rows that earlier traits had already filtered.
print("Num samples before filtering:", len(train))

trait_values = train[CONFIG.TARGET_COLUMNS]
lower_bounds = trait_values.quantile(PlantDataPreProcess.lower_quantile)  # Series indexed by trait
upper_bounds = trait_values.quantile(PlantDataPreProcess.upper_quantile)
# Comparing a DataFrame with a Series aligns on columns. NaN traits compare
# False and are dropped, as before.
in_range = ((trait_values >= lower_bounds) & (trait_values <= upper_bounds)).all(axis=1)
train = train[in_range]

print("Num samples After filtering:", len(train))
train[CONFIG.TARGET_COLUMNS].describe()

In [None]:
# Log10-transform every trait except X4.
LOG_FEATURES = ['X11_mean', 'X18_mean', 'X50_mean', 'X26_mean', 'X3112_mean']
# .copy() makes y_train an independent frame. Assigning into a slice of
# `train` raised SettingWithCopyWarning and relied on view/copy semantics.
y_train = train[CONFIG.TARGET_COLUMNS].copy()
# Vectorized over all skewed columns at once, instead of a per-column .apply.
y_train[LOG_FEATURES] = PlantDataPreProcess.log_transform(y_train[LOG_FEATURES])

y_train.describe()

In [None]:
# Standardize each (partly log-transformed) target column to zero mean and
# unit variance. SCALER is kept so predictions can be mapped back later with
# SCALER.inverse_transform.
from sklearn.preprocessing import StandardScaler

SCALER = StandardScaler()
SCALER.fit(y_train)
y_train = SCALER.transform(y_train)  # now a NumPy array

# To inspect: pd.DataFrame(y_train, columns=CONFIG.TARGET_COLUMNS).describe()

### Swin Transformer Data Loading

In [None]:
from pathlib import Path

# Training schedule derived from the (filtered) training-set size.
CONFIG.N_TRAIN_SAMPLES = len(train)
# Floor division matches the train loader's drop_last=True.
CONFIG.N_STEPS_PER_EPOCH = (CONFIG.N_TRAIN_SAMPLES // CONFIG.BATCH_SIZE)
CONFIG.N_STEPS = CONFIG.N_STEPS_PER_EPOCH * CONFIG.N_EPOCHS + 1

# Load the test table and the raw JPEG bytes of every test image.
test = pd.read_csv('/kaggle/input/planttraits2024/test.csv')
test['file_path'] = test['id'].apply(lambda s: f'/kaggle/input/planttraits2024/test_images/{s}.jpeg')
# Path.read_bytes closes each file itself; a bare open(fp, 'rb').read() leaves
# the handle for the garbage collector to close.
test['jpeg_bytes'] = test['file_path'].progress_apply(lambda fp: Path(fp).read_bytes())
test.to_pickle('test.pkl')

print('N_TRAIN_SAMPLES:', len(train), 'N_TEST_SAMPLES:', len(test))

In [None]:
# Sanity check: image rows and target rows must stay aligned.
print(f"Train len: {len(train)}")
print(f"y_train len {len(y_train)}")

In [None]:
class Model(nn.Module):
    """Regressor: a timm backbone whose output head has CONFIG.N_TARGETS units."""

    def __init__(self):
        super().__init__()
        # pretrained=True loads pretrained backbone weights; num_classes sets
        # the size of the output head (one value per target trait).
        self.backbone = timm.create_model(
            CONFIG.BACKBONE,
            pretrained=True,
            num_classes=CONFIG.N_TARGETS,
        )

    def forward(self, inputs):
        """Return the backbone's outputs for a batch of images."""
        return self.backbone(inputs)


model = Model().to('cuda')
print(model)

In [None]:
# Optional: derive the backbone's own preprocessing (normalization, resize)
# from timm instead of the hand-written MEAN/STD transforms. The backbone in
# use is CONFIG.BACKBONE ('swinv2_small_window16_256'), not the tiny variant.

# get model specific transforms (normalization, resize)
# data_config = timm.data.resolve_model_data_config(model)
# transforms = timm.data.create_transform(**data_config, is_training=False)
# print(transforms)

In [None]:
# Standard ImageNet per-channel (RGB) mean/std. A.ToFloat() scales uint8
# pixels to [0, 1], hence max_pixel_value=1 in A.Normalize. These are
# presumably the statistics the backbone was pretrained with; confirm via
# timm.data.resolve_model_data_config(model).
MEAN = np.array([0.485, 0.456, 0.406])
STD = np.array([0.229, 0.224, 0.225])

# Training augmentation: horizontal flip, a random square crop (side 448-512 px,
# p=0.75) resized to IMAGE_SIZE, and mild brightness/contrast and JPEG-quality jitter.
TRAIN_TRANSFORMS = A.Compose([
        A.HorizontalFlip(p=0.5),
        A.RandomSizedCrop(
            [448, 512],
            CONFIG.IMAGE_SIZE, CONFIG.IMAGE_SIZE, w2h_ratio=1.0, p=0.75),
        A.Resize(CONFIG.IMAGE_SIZE, CONFIG.IMAGE_SIZE),
        A.RandomBrightnessContrast(brightness_limit=0.1, contrast_limit=0.1, p=0.25),
        A.ImageCompression(quality_lower=85, quality_upper=100, p=0.25),
        A.ToFloat(),
        A.Normalize(mean=MEAN, std=STD, max_pixel_value=1),
        ToTensorV2(),
    ])

# Deterministic resize + normalization for evaluation.
VALID_TRANSFORMS = A.Compose([
        A.Resize(CONFIG.IMAGE_SIZE, CONFIG.IMAGE_SIZE),
        A.ToFloat(),
        A.Normalize(mean=MEAN, std=STD, max_pixel_value=1),
        ToTensorV2(),
    ])

TEST_TRANSFORMS = A.Compose([
        A.Resize(CONFIG.IMAGE_SIZE, CONFIG.IMAGE_SIZE),
        A.ToFloat(),
        A.Normalize(mean=MEAN, std=STD, max_pixel_value=1),
        ToTensorV2(),
    ])

# Inherits from torch.utils.data.Dataset explicitly. `class Dataset(Dataset)`
# picks up whatever `Dataset` is currently bound to, so re-running this cell
# would subclass the previous version of this class instead of torch's.
class Dataset(torch.utils.data.Dataset):
    """Decode JPEG bytes on the fly and pair each image with its label.

    Args:
        X_jpeg_bytes: sequence of encoded JPEG images (raw bytes).
        y: per-sample labels indexed in step with X_jpeg_bytes (standardized
            targets for train/valid, ids for test).
        transforms: albumentations pipeline applied to the decoded image;
            None returns the decoded array unchanged.
    """

    def __init__(self, X_jpeg_bytes, y, transforms=None):
        self.X_jpeg_bytes = X_jpeg_bytes
        self.y = y
        self.transforms = transforms

    def __len__(self):
        return len(self.X_jpeg_bytes)

    def __getitem__(self, index):
        X_sample = imageio.imread(self.X_jpeg_bytes[index])
        if self.transforms is not None:
            X_sample = self.transforms(image=X_sample)['image']
        y_sample = self.y[index]

        return X_sample, y_sample

# Targets are cast to float32 to match the model's (default float32) outputs.
# The StandardScaler output is float64, and mixing dtypes inside the loss may
# fail in backward ("Found dtype Double but expected Float").
train_dataset = Dataset(
    train['jpeg_bytes'].values,
    np.asarray(y_train, dtype=np.float32),
    TRAIN_TRANSFORMS,
)

train_dataloader = DataLoader(
        train_dataset,
        batch_size=CONFIG.BATCH_SIZE,
        shuffle=True,
        drop_last=True,
        num_workers=psutil.cpu_count(),
)

# Validation targets get the same log10 + standardization as the training
# targets, using SCALER as fitted on the training split only. Validation
# metrics are therefore computed in the same standardized log space as the
# training metrics (X4 is standardized but not logged).
valid_y = valid[CONFIG.TARGET_COLUMNS].values

def preprocess_targets(y, scaler, log_features, is_train=True):
    """Log10-transform `log_features`, then standardize with `scaler`.

    Args:
        y: 2-D array-like of raw targets, columns in CONFIG.TARGET_COLUMNS order.
        scaler: StandardScaler-like object; fitted here when is_train is True,
            otherwise only applied.
        log_features: names of the columns to log10-transform.
        is_train: fit the scaler (True) or reuse its fitted statistics (False).

    Returns:
        NumPy array of transformed targets, same shape as `y`.
    """
    y = pd.DataFrame(y, columns=CONFIG.TARGET_COLUMNS)
    log_cols = list(log_features)
    if log_cols:
        y[log_cols] = PlantDataPreProcess.log_transform(y[log_cols])
    return scaler.fit_transform(y) if is_train else scaler.transform(y)

valid_y = preprocess_targets(valid_y, SCALER, LOG_FEATURES, is_train=False)

valid_dataset = Dataset(
    valid['jpeg_bytes'].values,
    np.asarray(valid_y, dtype=np.float32),
    VALID_TRANSFORMS,
)

valid_dataloader = DataLoader(
        valid_dataset,
        batch_size=CONFIG.BATCH_SIZE,
        shuffle=False,
        num_workers=psutil.cpu_count(),
)

# Test labels are the sample ids (used to build the submission rows).
test_dataset = Dataset(
    test['jpeg_bytes'].values,
    test['id'].values,
    TEST_TRANSFORMS,
)

In [None]:
# List available Swin Transformer models in timm library
# list(filter(lambda x : 'swin' in x, timm.list_models()))

In [None]:
def get_lr_scheduler(optimizer):
    """Build a one-cycle cosine LR schedule spanning CONFIG.N_STEPS steps.

    Warms up over the first 10% of steps from LR_MAX / 10 to LR_MAX, then
    cosine-anneals to (LR_MAX / 10) / 10.
    """
    return torch.optim.lr_scheduler.OneCycleLR(
        optimizer=optimizer,
        max_lr=CONFIG.LR_MAX,
        total_steps=CONFIG.N_STEPS,
        pct_start=0.1,
        anneal_strategy='cos',
        div_factor=1e1,
        final_div_factor=1e1,
    )

class AverageMeter(object):
    """Running element-wise average of the tensors passed to update()."""

    def __init__(self):
        self.reset()

    def reset(self):
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val):
        """Accumulate every element of tensor `val` into the running average.

        `val` is detached first. The training loop passes the loss before
        backward(); summing it un-detached chained every step's autograd graph
        onto self.sum for the rest of the epoch.
        """
        val = val.detach()
        self.sum += val.sum()
        self.count += val.numel()
        self.avg = self.sum / self.count

# Running training metrics, in the standardized (log) target space.
MAE = torchmetrics.regression.MeanAbsoluteError().to('cuda')
R2 = torchmetrics.regression.R2Score(num_outputs=CONFIG.N_TARGETS, multioutput='uniform_average').to('cuda')
LOSS = AverageMeter()

# Per-target mean of the standardized training targets (approximately 0).
Y_MEAN = torch.tensor(y_train).mean(dim=0).to('cuda')
EPS = torch.tensor([1e-6]).to('cuda')

def r2_loss(y_pred, y_true, y_mean=None):
    """Mean over targets of SS_res / SS_tot, i.e. the mean of (1 - R2) per target.

    Args:
        y_pred, y_true: (n_samples, n_targets) tensors.
        y_mean: per-target reference mean for SS_tot. Defaults to Y_MEAN, the
            mean of the standardized training targets, so the default only
            suits data in that standardized space. Pass e.g.
            y_true.mean(dim=0) to score data in another scale.
    """
    if y_mean is None:
        y_mean = Y_MEAN
    ss_res = torch.sum((y_true - y_pred)**2, dim=0)
    ss_total = torch.sum((y_true - y_mean)**2, dim=0)
    ss_total = torch.maximum(ss_total, EPS)  # guard against constant targets
    return torch.mean(ss_res / ss_total)

# The training objective is Smooth L1 on the standardized targets, not the R2
# loss above. Set LOSS_FN = r2_loss to optimize R2 directly.
LOSS_FN = nn.SmoothL1Loss()

# Fixed learning rate. CONFIG.LR_MAX is only used by get_lr_scheduler, whose
# use is commented out below.
learning_rate = 1e-3
optimizer = torch.optim.AdamW(
    params=model.parameters(),
    #lr=CONFIG.LR_MAX,
    lr=learning_rate,
    weight_decay=CONFIG.WEIGHT_DECAY,
)

# LR_SCHEDULER = get_lr_scheduler(optimizer)

def count_parameters(model):
    """Return the number of trainable parameters in `model`."""
    return sum(p.numel() for p in model.parameters() if p.requires_grad)

def validate(model, dataloader, scaler, log_features):
    """Evaluate `model` on `dataloader` in the standardized (log) target space.

    Args:
        model: network to evaluate. It is switched to eval mode and not
            switched back.
        dataloader: yields (images, standardized targets) batches.
        scaler, log_features: currently unused; kept for call compatibility.

    Returns:
        (r2, mae, loss) over the whole loader. The loss is weighted by batch
        size, so a smaller final batch counts proportionally; the result is the
        per-element mean over all samples. It is NaN for an empty loader.
    """
    model.eval()
    MAE_valid = torchmetrics.MeanAbsoluteError().to('cuda')
    R2_valid = torchmetrics.R2Score(num_outputs=CONFIG.N_TARGETS, multioutput='uniform_average').to('cuda')
    loss_sum = 0.0
    n_samples = 0

    with torch.no_grad():
        for X_batch, y_true in dataloader:
            X_batch = X_batch.to('cuda')
            y_true = y_true.to('cuda')
            y_pred = model(X_batch)
            batch_size = y_true.shape[0]
            loss_sum += LOSS_FN(y_pred, y_true).item() * batch_size
            n_samples += batch_size
            MAE_valid.update(y_pred, y_true)
            R2_valid.update(y_pred, y_true)

    valid_r2 = R2_valid.compute().item()
    valid_mae = MAE_valid.compute().item()
    valid_loss = loss_sum / n_samples if n_samples else float('nan')

    return valid_r2, valid_mae, valid_loss

# Per-epoch history, written to JSON after training.
metrics = {
    'epoch': [],
    'loss': [],
    'mae': [],
    'r2': [],
    'lr': [],
    'training_time': [],
    'num_params': count_parameters(model),
    'valid_r2': [],
    'valid_mae': [],
    'valid_loss': []
}

In [None]:
best_valid_r2 = -np.inf

print("Start Training:")
for epoch in range(CONFIG.N_EPOCHS):
    epoch_start_time = time.time()
    # The running train metrics accumulate over the whole epoch.
    MAE.reset()
    R2.reset()
    LOSS.reset()
    model.train()

    epoch_loss = 0.0

    for step, (X_batch, y_true) in enumerate(train_dataloader):
        X_batch = X_batch.to('cuda')
        y_true = y_true.to('cuda')
        t_start = time.perf_counter_ns()
        y_pred = model(X_batch)
        loss = LOSS_FN(y_pred, y_true)
        # Detached: the meter needs only the value. Keeping the graph would
        # chain each step's autograd history onto LOSS.sum.
        LOSS.update(loss.detach())
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()
        #LR_SCHEDULER.step()
        MAE.update(y_pred.detach(), y_true)
        R2.update(y_pred.detach(), y_true)

        epoch_loss += loss.item()

        if not CONFIG.IS_INTERACTIVE and (step+1) == CONFIG.N_STEPS_PER_EPOCH:
            print(
                f'EPOCH {epoch+1:02d}, {step+1:04d}/{CONFIG.N_STEPS_PER_EPOCH} | ' + 
                f'loss: {LOSS.avg:.4f}, mae: {MAE.compute().item():.4f}, r2: {R2.compute().item():.4f}, ' +
                #f'step: {(time.perf_counter_ns()-t_start)*1e-9:.3f}s, lr: {LR_SCHEDULER.get_last_lr()[0]:.2e}',
                f'step: {(time.perf_counter_ns()-t_start)*1e-9:.3f}s, lr: {learning_rate:.2e}',
            )
        elif CONFIG.IS_INTERACTIVE:
            print(
                f'\rEPOCH {epoch+1:02d}, {step+1:04d}/{CONFIG.N_STEPS_PER_EPOCH} | ' + 
                f'loss: {LOSS.avg:.4f}, mae: {MAE.compute().item():.4f}, r2: {R2.compute().item():.4f}, ' +
                #f'step: {(time.perf_counter_ns()-t_start)*1e-9:.3f}s, lr: {LR_SCHEDULER.get_last_lr()[0]:.2e}',
                f'step: {(time.perf_counter_ns()-t_start)*1e-9:.3f}s, lr: {learning_rate:.2e}',
                end='\n' if (step + 1) == CONFIG.N_STEPS_PER_EPOCH else '', flush=True,
            )

    epoch_training_time = time.time() - epoch_start_time

    # Validate on the validation set (standardized log space).
    valid_r2, valid_mae, valid_loss = validate(model, valid_dataloader, SCALER, LOG_FEATURES)
    print(
        f'VALIDATION | epoch: {epoch + 1:02d}, '
        f'valid_loss: {valid_loss:.4f}, valid_mae: {valid_mae:.4f}, valid_r2: {valid_r2:.4f}'
    )

    # Log this epoch's metrics. Train MAE/R2 are computed once over every
    # sample seen this epoch. The previous approach averaged per-step
    # cumulative values, which over-weighted the early batches.
    metrics['epoch'].append(epoch + 1)
    metrics['loss'].append(epoch_loss / len(train_dataloader))
    metrics['mae'].append(MAE.compute().item())
    metrics['r2'].append(R2.compute().item())
    #metrics['lr'].append(LR_SCHEDULER.get_last_lr()[0])
    metrics['lr'].append(learning_rate)
    metrics['training_time'].append(epoch_training_time)
    metrics['valid_r2'].append(valid_r2)
    metrics['valid_mae'].append(valid_mae)
    metrics['valid_loss'].append(valid_loss)

    # Save the model whenever validation R2 improves.
    if valid_r2 > best_valid_r2:
        best_valid_r2 = valid_r2
        torch.save(model.state_dict(), 'best_model.pth')
        print(f'Saved Best Model at Epoch {epoch + 1} with R2: {valid_r2:.4f}')

# Save metrics to a file
import json
with open('metrics6.json', 'w') as f:
    json.dump(metrics, f)

In [None]:
import json
import matplotlib.pyplot as plt

# Reload the per-epoch metrics written by the training cell.
with open('metrics6.json', 'r') as f:
    metrics = json.load(f)


def _plot_panel(ax, title, ylabel, series):
    """Plot each (metrics key, legend label) pair in `series` against epoch on `ax`."""
    for key, label in series:
        ax.plot(metrics['epoch'], metrics[key], label=label)
    ax.set_title(title)
    ax.set_xlabel('Epoch')
    ax.set_ylabel(ylabel)
    ax.legend()


# 2x2 grid: loss, MAE, R2 and learning rate per epoch.
fig, axes = plt.subplots(2, 2, figsize=(15, 10))
_plot_panel(axes[0, 0], 'Loss', 'Loss',
            [('loss', 'Train Loss'), ('valid_loss', 'Valid Loss')])
_plot_panel(axes[0, 1], 'Mean Absolute Error (MAE)', 'MAE',
            [('mae', 'Train MAE'), ('valid_mae', 'Valid MAE')])
_plot_panel(axes[1, 0], 'R2 Score', 'R2 Score',
            [('r2', 'Train R2'), ('valid_r2', 'Valid R2')])
_plot_panel(axes[1, 1], 'Learning Rate', 'Learning Rate',
            [('lr', 'Learning Rate')])

plt.tight_layout()
plt.show()

In [None]:
# Predict the validation set and map predictions back to original trait units:
# undo the standardization, then apply 10 ** x to the log10 features.
# valid_dataset yields (image, standardized targets) pairs, not ids. The
# previous loop therefore stored target arrays in the 'id' column. The ids now
# come from `valid`, whose row order the dataset was built from.
VALID_ROWS = []
model.eval()

valid_ids = valid['id'].values
for valid_id, (X_sample_valid, _) in tqdm(zip(valid_ids, valid_dataset), total=len(valid_dataset)):
    with torch.no_grad():
        y_pred = model(X_sample_valid.unsqueeze(0).to('cuda')).detach().cpu().numpy()

    y_pred = SCALER.inverse_transform(y_pred).squeeze()
    row = {'id': valid_id}

    for k, v in zip(CONFIG.TARGET_COLUMNS, y_pred):
        if k in LOG_FEATURES:
            row[k] = 10 ** v
        else:
            row[k] = v

    VALID_ROWS.append(row)

valid_predict_df = pd.DataFrame(VALID_ROWS)
print(valid_predict_df.head())

In [None]:
# Ground truth for validation, in original trait units. Show the first rows
# and keep the full target matrix on the GPU for scoring.
print(valid[['id'] + CONFIG.TARGET_COLUMNS].head())
valid_y_true = torch.as_tensor(valid[CONFIG.TARGET_COLUMNS].to_numpy(), device='cuda')

In [None]:
# Score validation predictions in original (un-logged, un-scaled) trait units.
valid_y_pred = torch.tensor(valid_predict_df[CONFIG.TARGET_COLUMNS].to_numpy()).to('cuda')

with torch.no_grad():
    # R2 loss = mean over traits of SS_res / SS_tot (the mean of 1 - R2).
    # SS_tot is measured around the validation targets' own mean. r2_loss's
    # built-in Y_MEAN is the mean of the *standardized log* training targets
    # (about 0), which gives a wrong SS_tot for original-scale values.
    ss_res = torch.sum((valid_y_true - valid_y_pred) ** 2, dim=0)
    ss_total = torch.sum((valid_y_true - valid_y_true.mean(dim=0)) ** 2, dim=0)
    ss_total = torch.clamp(ss_total, min=1e-6)  # guard against constant targets
    valid_r2_loss = torch.mean(ss_res / ss_total)
    print("Validation R2 Loss (original scale):", valid_r2_loss)

    # Smooth L1 in original units. It is not comparable to the training loss,
    # which is computed on standardized targets.
    valid_loss = LOSS_FN(valid_y_pred, valid_y_true)
    print("Validation loss (Smooth L1 loss): ", valid_loss)

In [None]:
# Cross-check with torchmetrics on the original-scale validation predictions.
MAE_valid = torchmetrics.regression.MeanAbsoluteError().to('cuda')
R2_valid = torchmetrics.regression.R2Score(
    num_outputs=CONFIG.N_TARGETS,
    multioutput='uniform_average',
).to('cuda')

for metric_label, metric_fn in (("Torch R2 valid:", R2_valid), ("Torch MAE valid:", MAE_valid)):
    print(metric_label, metric_fn(valid_y_pred, valid_y_true))

In [None]:
# Predict every test image, undo the target preprocessing (standardization,
# then 10 ** x for log features), and write the submission file. Column names
# drop the '_mean' suffix.
SUBMISSION_ROWS = []
model.eval()

with torch.no_grad():
    for X_sample_test, test_id in tqdm(test_dataset):
        scaled_pred = model(X_sample_test.unsqueeze(0).to('cuda')).detach().cpu().numpy()
        y_pred = SCALER.inverse_transform(scaled_pred).squeeze()

        row = {'id': test_id}
        for column, value in zip(CONFIG.TARGET_COLUMNS, y_pred):
            row[column.replace('_mean', '')] = 10 ** value if column in LOG_FEATURES else value
        SUBMISSION_ROWS.append(row)

submission_df = pd.DataFrame(SUBMISSION_ROWS)
print(submission_df.head())
submission_df.to_csv('submission.csv', index=False)
print("Submit!")