In [1]:
import os
import sys

os.environ['PYTORCH_CUDA_ALLOC_CONF'] = 'expandable_segments:True'

import io
import torch
import subprocess
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from matplotlib import ticker

from glob import glob
from datetime import datetime
from huggingface_hub import snapshot_download

In [2]:
# repo_dir = "satvision-toa"

# if not os.path.exists(repo_dir):
#     subprocess.run(["git", "clone", "https://github.com/nasa-nccs-hpda/satvision-toa"])
# else:
#     subprocess.run(["git", "-C", repo_dir, "pull"])

In [3]:
sys.path.append("../../../satvision-toa")
from satvision_toa.configs.config import _C, _update_config_from_file
from satvision_toa.models.mim import build_mim_model

In [4]:
from satvision_toa.datasets.ocean_color_dataset import OceanColorDataset
from satvision_toa.models.decoders.ocean_color_decoder import (
    OceanColorUNETV2, OceanColorFCNV3,
    OceanColorFCNV2, OceanColorFCNV2point5
)
from satvision_toa.models.decoders.ocean_color_e2e_decoder import (
    OceanColorFCN, OceanColorUNET
)

from satvision_toa.transforms.ocean_color import (
    GlobalMinMaxNorm, PBMinMaxNorm, ScaleAndOffset
)

from satvision_toa.data_utils.utils_ocean_color import (
    load_config, get_dataloaders, train_model
)

In [5]:
# Root of the OceanColor data tree. Defaults to the original shared cluster path;
# set the SATVISION_OCEANCOLOR_DIR environment variable to point the notebook at a
# copy elsewhere instead of editing this cell.
full_dir = os.environ.get(
    "SATVISION_OCEANCOLOR_DIR",
    "/panfs/ccds02/nobackup/people/ajkerr1/SatVision/OceanColor",
)

## SatVision (SV) model: frozen SatVision-TOA encoder + FCN decoder

In [6]:
# SatVision fine-tuning chips: the training split and the validation split
# (the latter is what gets passed to get_dataloaders as `test_dir`).
train_dir, test_dir = (
    os.path.join(full_dir, rel_path)
    for rel_path in ("chips/ft/chips_ft", "chips/ft/val_chips_ft")
)

In [None]:
config = load_config()
# Build the pretrained MIM network only to borrow its Swin encoder; the encoder is
# wrapped (frozen) by the OceanColorFCNV2point5 decoder head, and the rest of the
# MIM model is not kept around.
model = OceanColorFCNV2point5(
    swin_encoder=build_mim_model(config).encoder,
    freeze_encoder=True,
)

=> merge config from /home/ajkerr1/.cache/huggingface/hub/models--nasa-cisto-data-science-group--downstream-satvision-toa-3dclouds/snapshots/1c6d3b4fba1a476956027e56d5dd9708bdfef0ba/mim_pretrain_swinv2_satvision_giant_128_window08_50ep.yaml


  return _VF.meshgrid(tensors, **kwargs)  # type: ignore[attr-defined]


In [None]:
# Data loaders for the SatVision run (num_inputs=14; presumably the number of
# input bands per chip -- confirm against get_dataloaders).
train_dataloader, val_dataloader, test_dataloader = get_dataloaders(
    train_dir, test_dir, num_inputs=14, batch_size=64
)

# Schedule and output locations for the SatVision run.
# NOTE(review): save_every (20) is larger than num_epochs (10), so a periodic
# checkpoint may never be written -- confirm how train_model uses save_every.
# NOTE(review): save_path says "unet" although the model trained here is the
# SatVision FCN decoder; left as-is so existing output paths keep matching.
num_epochs, save_every, test_every = 10, 20, 1
save_path, pdf_path, metrics_filename = "sv_unet", "pred_pdfs/svtoa", "sv_metrics"

sv_train_losses, sv_val_losses = train_model(
    model,
    train_dataloader,
    val_dataloader,
    test_dataloader,
    num_epochs=num_epochs,
    save_path=save_path,
    save_every=save_every,
    test_every=test_every,
    pdf_path=pdf_path,
    metrics_filename=metrics_filename,
)

## Plain UNET baseline (end-to-end, no pretrained encoder)

In [None]:
# End-to-end (e2e) chips: the training split and the validation split
# (the latter is what gets passed to get_dataloaders as `test_dir`).
train_dir, test_dir = (
    os.path.join(full_dir, rel_path)
    for rel_path in ("chips/e2e/chips_6_27", "chips/e2e/val_chips_e2e")
)

In [None]:
# End-to-end baseline: a UNET with 12 input channels and 1 output channel, trained
# from scratch (no pretrained SatVision encoder). Rebinds `model`, replacing the
# SatVision model built earlier.
model = OceanColorUNET(out_channels=1, in_channels=12)

In [None]:
# Data loaders for the end-to-end UNET run (num_inputs=12, matching the UNET's
# in_channels=12).
train_dataloader, val_dataloader, test_dataloader = get_dataloaders(
    train_dir, test_dir, num_inputs=12, batch_size=64
)

# Schedule and output locations for the end-to-end run.
# NOTE(review): save_every (20) is larger than num_epochs (10), so a periodic
# checkpoint may never be written -- confirm how train_model uses save_every.
num_epochs, save_every, test_every = 10, 20, 1
save_path, pdf_path, metrics_filename = "e2e_unet", "pred_pdfs/e2e", "e2e_metrics"

e2e_train_losses, e2e_val_losses = train_model(
    model,
    train_dataloader,
    val_dataloader,
    test_dataloader,
    num_epochs=num_epochs,
    save_path=save_path,
    save_every=save_every,
    test_every=test_every,
    pdf_path=pdf_path,
    metrics_filename=metrics_filename,
)

## Viz Epoch-level Metrics
*For visual clarity, the non-loss metric tables are also split into two slices: the first 21 logged epochs (`df1`) and all later epochs (`df2`). The figures below plot the full run.*

### SatVision Metrics

#### Train, Val Losses

In [None]:
# Stacked train / validation loss curves for the SatVision run
# (x position = index into the per-epoch loss lists returned by train_model).
fig, (ax1, ax2) = plt.subplots(2, 1, figsize=(40, 25))
for ax, losses, title in (
    (ax1, sv_train_losses, "SatVision Train Loss"),
    (ax2, sv_val_losses, "SatVision Val Loss"),
):
    ax.plot(losses)
    ax.set_xlabel("Epoch", fontsize=25)
    ax.set_ylabel("Loss", fontsize=25)
    ax.tick_params(axis='both', which='major', labelsize=15)
    ax.set_title(title, fontsize=30)
    ax.grid()
plt.show()

#### R2, RMSE, SSIM, PSNR

In [None]:
# Per-epoch evaluation metrics for the SatVision run, read from the CSV that
# (presumably) train_model wrote for metrics_filename="sv_metrics".
sv_metrics_filename = 'sv_metrics_epoch_metrics.csv'
metric_columns = ['r2', 'rmse', 'ssim', 'psnr']
df = pd.read_csv(sv_metrics_filename)
df = df[['epoch'] + metric_columns]
# Early/late slices (first 21 rows vs. the rest). Not used by the figure below,
# which plots every epoch; kept so any other cells referencing them still work.
df1 = df.iloc[0:21]
df2 = df.iloc[21:]

fig, axes = plt.subplots(2, 2, figsize=(10, 6))
axes = axes.ravel()
for ax, metric in zip(axes, metric_columns):
    ax.plot(df['epoch'], df[metric], label=metric)
    # Prefix matches the E2E metrics figure so the two are distinguishable.
    ax.set_title(f'SatVision {metric.upper()} vs Epoch')
    ax.set_xlabel('Epoch')
    ax.set_ylabel(metric.upper())
    ax.grid(True, alpha=0.3)
    # Put ticks only on whole epochs: a '%d' formatter on its own truncates
    # fractional ticks (0.5 -> "0"), yielding duplicate labels on short runs.
    ax.xaxis.set_major_locator(ticker.MaxNLocator(integer=True))
    ax.xaxis.set_major_formatter(ticker.FormatStrFormatter('%d'))
    # Show absolute metric values instead of a "+offset" annotation, while
    # letting the default formatter pick enough precision for distinct ticks.
    ax.ticklabel_format(axis='y', useOffset=False)

plt.tight_layout()
plt.show()

### E2E Metrics

#### Train, Val Losses

In [None]:
# Stacked train / validation loss curves for the end-to-end UNET run
# (x position = index into the per-epoch loss lists returned by train_model).
fig, (ax1, ax2) = plt.subplots(2, 1, figsize=(40, 25))
for ax, losses, title in (
    (ax1, e2e_train_losses, "E2E Training Loss"),
    (ax2, e2e_val_losses, "E2E Validation Loss"),
):
    ax.plot(losses)
    ax.set_xlabel("Epoch", fontsize=25)
    ax.set_ylabel("Loss", fontsize=25)
    ax.tick_params(axis='both', which='major', labelsize=15)
    ax.set_title(title, fontsize=30)
    ax.grid()
plt.show()

#### R2, RMSE, SSIM, PSNR

In [None]:
# Per-epoch evaluation metrics for the end-to-end run, read from the CSV that
# (presumably) train_model wrote for metrics_filename="e2e_metrics".
e2e_metrics_filename = 'e2e_metrics_epoch_metrics.csv'
metric_columns = ['r2', 'rmse', 'ssim', 'psnr']
df = pd.read_csv(e2e_metrics_filename)
df = df[['epoch'] + metric_columns]
# Early/late slices (first 21 rows vs. the rest). Not used by the figure below,
# which plots every epoch; kept so any other cells referencing them still work.
df1 = df.iloc[0:21]
df2 = df.iloc[21:]

fig, axes = plt.subplots(2, 2, figsize=(10, 6))
axes = axes.ravel()
for ax, metric in zip(axes, metric_columns):
    ax.plot(df['epoch'], df[metric], label=metric)
    ax.set_title(f'E2E {metric.upper()} vs Epoch')
    ax.set_xlabel('Epoch')
    ax.set_ylabel(metric.upper())
    ax.grid(True, alpha=0.3)
    # Put ticks only on whole epochs: a '%d' formatter on its own truncates
    # fractional ticks (0.5 -> "0"), yielding duplicate labels on short runs.
    ax.xaxis.set_major_locator(ticker.MaxNLocator(integer=True))
    ax.xaxis.set_major_formatter(ticker.FormatStrFormatter('%d'))
    # A fixed '%.2f' formatter collapses distinct ticks into identical labels
    # when a metric moves by < 0.01 (e.g. late-training SSIM/R2). Keep the
    # default ScalarFormatter, which adapts its precision, and only disable the
    # "+offset" annotation so absolute values are shown.
    ax.ticklabel_format(axis='y', useOffset=False)

plt.tight_layout()
plt.show()