In [1]:
# !pip install torchmetrics

In [2]:
import subprocess
import sys
import pkg_resources
from collections import defaultdict

def analyze_packages():
    """Print a report comparing what ``pip list`` and ``pkg_resources`` see.

    Two sections are written to stdout:

    * "PACKAGES BY LOCATION": for every ``sys.path`` entry that hosts at
      least one distribution from ``pkg_resources.working_set``, the
      distributions found there (sorted).
    * "PACKAGE STATUS": one row per package name known to either source,
      flagged with a check mark when both sources report it and a cross
      otherwise, followed by the install location when ``pkg_resources``
      knows the package.

    Package names from both sources are folded to their PEP 503 normalized
    form (lower case, runs of ``-``, ``_`` and ``.`` collapsed to ``-``)
    before being compared, so the same distribution spelled differently by
    the two tools is reported once instead of as two "missing" rows.

    Returns:
        None. The report is printed; nothing is returned.
    """
    import json
    import re

    def _canonical(name):
        """Return the PEP 503 normalized form of a distribution name."""
        return re.sub(r"[-_.]+", "-", name).lower()

    print("=== COMPREHENSIVE PACKAGE ANALYSIS ===\n")

    # Ask the pip that belongs to the running interpreter, not whatever
    # `pip` happens to be first on PATH.
    result = subprocess.run(
        [sys.executable, '-m', 'pip', 'list', '--format=json'],
        capture_output=True, text=True,
    )

    pip_packages = {}
    if result.returncode == 0:
        pip_packages = {
            _canonical(pkg['name']): pkg['version']
            for pkg in json.loads(result.stdout)
        }
    else:
        # Without this notice every row below would read "pip:NOT FOUND"
        # with no hint that pip itself failed.
        print(f"WARNING: 'pip list' exited with code {result.returncode}; "
              "pip versions are unavailable.\n")

    # Collect the pkg_resources view, plus a location -> distributions index.
    pkg_resources_packages = {}
    location_map = defaultdict(list)

    for dist in pkg_resources.working_set:
        pkg_resources_packages[_canonical(dist.project_name)] = {
            'version': dist.version,
            'location': dist.location,
        }
        location_map[dist.location].append(f"{dist.project_name} ({dist.version})")

    # Show packages by location, following sys.path order.
    print("PACKAGES BY LOCATION:")
    for location in sys.path:
        if location in location_map:
            print(f"\n{location}:")
            for pkg in sorted(location_map[location]):
                print(f"  {pkg}")

    # Show discrepancies between the two sources.
    print("\n=== PACKAGE STATUS ===")
    all_packages = set(pip_packages) | set(pkg_resources_packages)

    for pkg in sorted(all_packages):
        pip_version = pip_packages.get(pkg, "NOT FOUND")
        pkg_resources_info = pkg_resources_packages.get(pkg, {})
        pkg_resources_version = pkg_resources_info.get('version', "NOT FOUND")
        location = pkg_resources_info.get('location', "UNKNOWN")

        status = "✓" if pip_version != "NOT FOUND" and pkg_resources_version != "NOT FOUND" else "✗"
        print(f"{status} {pkg:25} pip:{pip_version:15} pkg_resources:{pkg_resources_version:15}")

        if pkg_resources_version != "NOT FOUND":
            print(f"    Location: {location}")

# analyze_packages()

  import pkg_resources


In [3]:
import os
import sys

os.environ['PYTORCH_CUDA_ALLOC_CONF'] = 'expandable_segments:True'

import io
import subprocess
from datetime import datetime
from glob import glob

import matplotlib.pyplot as plt
import matplotlib.ticker as ticker
import numpy as np
import pandas as pd
import torch
from huggingface_hub import snapshot_download

In [4]:
# print(f"Python executable: {sys.executable}")
# print(f"Python version: {sys.version}")
# print(f"\nPython path (sys.path):")
# for i, path in enumerate(sys.path):
#     print(f"{i:2d}: {path}")

In [5]:
# repo_dir = "satvision-toa"

# if not os.path.exists(repo_dir):
#     subprocess.run(["git", "clone", "https://github.com/nasa-nccs-hpda/satvision-toa"])
# else:
#     subprocess.run(["git", "-C", repo_dir, "pull"])

In [6]:
sys.path.append("../../../satvision-toa/")
# sys.path.append("satvision-toa")
from satvision_toa.configs.config import _C, _update_config_from_file
from satvision_toa.models.mim import build_mim_model

In [7]:
from satvision_toa.datasets.ocean_color_dataset import OceanColorDataset
from satvision_toa.models.decoders.ocean_color_decoder import (
    OceanColorUNETV2, OceanColorFCNV3,
    OceanColorFCNV2, OceanColorFCNV2point5
)
from satvision_toa.models.decoders.ocean_color_e2e_decoder import (
    OceanColorFCN, OceanColorUNET
)

from satvision_toa.transforms.ocean_color import (
    GlobalMinMaxNorm, PBMinMaxNorm, ScaleAndOffset
)

from satvision_toa.data_utils.utils_ocean_color import (
    load_config, get_dataloaders, train_model
)

In [8]:
# Root of the OceanColor data tree. Defaults to the original cluster path;
# set SATVISION_OCEAN_COLOR_DIR to point the notebook at a copy elsewhere
# without editing this cell.
full_dir = os.environ.get(
    "SATVISION_OCEAN_COLOR_DIR",
    "/panfs/ccds02/nobackup/people/ajkerr1/SatVision/OceanColor",
)

## SV Model

In [9]:
# Fine-tuning chip folders under the data root: the first feeds training,
# the second is the held-out set handed to get_dataloaders as test_dir.
train_dir, test_dir = (
    os.path.join(full_dir, subpath)
    for subpath in ("chips/ft/chips_ft", "chips/ft/val_chips_ft")
)

In [10]:
# Build the MIM model from the loaded config and pass only its encoder to
# the ocean-color decoder; freeze_encoder=True is requested for the encoder.
config = load_config()
model = OceanColorFCNV2point5(
    swin_encoder=build_mim_model(config).encoder,
    freeze_encoder=True,
)

=> merge config from /home/ajkerr1/.cache/huggingface/hub/models--nasa-cisto-data-science-group--downstream-satvision-toa-3dclouds/snapshots/1c6d3b4fba1a476956027e56d5dd9708bdfef0ba/mim_pretrain_swinv2_satvision_giant_128_window08_50ep.yaml


  return _VF.meshgrid(tensors, **kwargs)  # type: ignore[attr-defined]


Freezing encoder


In [11]:
# ADJUST FOR TRAINING UNET VS SATVISION:
# num_inputs and the output names below are the per-model settings.
train_dataloader, val_dataloader, test_dataloader = get_dataloaders(
    train_dir,
    test_dir,
    num_inputs=14,
    batch_size=64,
)

# Values forwarded to train_model as num_epochs / save_every / test_every.
num_epochs, save_every, test_every = 100, 20, 10

# Output names forwarded to train_model for this run.
save_path, pdf_path, metrics_filename = "sv_unet", "pred_pdfs/svtoa", "sv_metrics"

train_losses, val_losses = train_model(
    model,
    train_dataloader,
    val_dataloader,
    test_dataloader,
    num_epochs=num_epochs,
    save_path=save_path,
    save_every=save_every,
    test_every=test_every,
    pdf_path=pdf_path,
    metrics_filename=metrics_filename,
)

train & val size: (144, 37)
Decoder weights initialized with Kaiming/Xavier initialization
Starting training for 100 epochs on cuda
Model parameters: 640,583,393


Training Progress:   0%|          | 0/100 [00:00<?, ?it/s]
  with device_autocast_ctx, torch.cpu.amp.autocast(**cpu_autocast_kwargs), recompute_context:  # type: ignore[attr-defined]

Epoch 1/100 [Train]:   0%|          | 0/6 [00:07<?, ?it/s, Loss=0.339812, Avg=0.339812, LR=1.00e-04][A
Epoch 1/100 [Train]:  17%|█▋        | 1/6 [00:07<00:38,  7.66s/it, Loss=0.339812, Avg=0.339812, LR=1.00e-04][A
Epoch 1/100 [Train]:  17%|█▋        | 1/6 [00:13<00:38,  7.66s/it, Loss=0.357237, Avg=0.348525, LR=1.00e-04][A
Epoch 1/100 [Train]:  33%|███▎      | 2/6 [00:13<00:27,  6.77s/it, Loss=0.357237, Avg=0.348525, LR=1.00e-04][A
Epoch 1/100 [Train]:  33%|███▎      | 2/6 [00:19<00:27,  6.77s/it, Loss=0.332436, Avg=0.343162, LR=1.00e-04][A
Epoch 1/100 [Train]:  50%|█████     | 3/6 [00:19<00:19,  6.50s/it, Loss=0.332436, Avg=0.343162, LR=1.00e-04][A
Epoch 1/100 [Train]:  50%|█████     | 3/6 [00:26<00:19,  6.50s/it, Loss=0.301666, Avg=0.332788, LR=1.00e-04][A
Epoch 1/100 [Train]:  67%|██████▋   | 4/

Starting testing...



Testing:   0%|          | 0/1 [00:00<?, ?it/s][A
Testing: 100%|██████████| 1/1 [00:00<00:00,  2.26it/s][A
Training Progress:   1%|          | 1/100 [00:38<1:03:55, 38.74s/it, Train_Loss=0.3336, Val_Loss=0.4838, Val_MAE=0.4717, Time=38.7s]

Plots saved to pred_pdfs/svtoa/preds_day_2025_07_28_time_11_05_1ep.pdf



  with device_autocast_ctx, torch.cpu.amp.autocast(**cpu_autocast_kwargs), recompute_context:  # type: ignore[attr-defined]

Training Progress:   1%|          | 1/100 [00:45<1:14:34, 45.20s/it, Train_Loss=0.3336, Val_Loss=0.4838, Val_MAE=0.4717, Time=38.7s]


KeyboardInterrupt: 

In [None]:
# Two stacked panels: training loss on top, validation loss below.
# Both panels share the same labelling, so they are styled in one loop.
fig, (ax1, ax2) = plt.subplots(2, 1, figsize=(40, 25))
for panel, series, series_label, heading in (
    (ax1, train_losses, "Train Loss", "Training Loss"),
    (ax2, val_losses, "Val Loss", "Validation Loss"),
):
    panel.plot(series, label=series_label)
    panel.set_xlabel("Epoch", fontsize=25)
    panel.set_ylabel("Loss", fontsize=25)
    panel.tick_params(axis='both', which='major', labelsize=15)
    panel.set_title(heading, fontsize=30)
    panel.grid()
plt.show()

## Plain UNET

In [None]:
# End-to-end chip folders under the data root: the first feeds training,
# the second is the held-out set handed to get_dataloaders as test_dir.
train_dir, test_dir = (
    os.path.join(full_dir, subpath)
    for subpath in ("chips/e2e/chips_6_27", "chips/e2e/val_chips_e2e")
)

In [None]:
# Rebinds `model` to the plain U-Net for the end-to-end run.
# in_channels=12 is the same count passed as num_inputs to get_dataloaders in
# the training cell for this model -- presumably the two must agree; verify
# before changing either.
model = OceanColorUNET(in_channels=12, out_channels=1)

In [None]:
# ADJUST FOR TRAINING UNET VS SATVISION:
# num_inputs and the output names below are the per-model settings.
train_dataloader, val_dataloader, test_dataloader = get_dataloaders(
    train_dir,
    test_dir,
    num_inputs=12,
    batch_size=64,
)

# Values forwarded to train_model as num_epochs / save_every / test_every.
num_epochs, save_every, test_every = 2, 20, 10

# Output names forwarded to train_model for this run.
save_path, pdf_path, metrics_filename = "e2e_unet", "pred_pdfs/e2e", "e2e_metrics"

train_losses, val_losses = train_model(
    model,
    train_dataloader,
    val_dataloader,
    test_dataloader,
    num_epochs=num_epochs,
    save_path=save_path,
    save_every=save_every,
    test_every=test_every,
    pdf_path=pdf_path,
    metrics_filename=metrics_filename,
)

In [None]:
# Two stacked panels: training loss on top, validation loss below.
# Both panels share the same labelling, so they are styled in one loop.
fig, (ax1, ax2) = plt.subplots(2, 1, figsize=(40, 25))
for panel, series, series_label, heading in (
    (ax1, train_losses, "Train Loss", "Training Loss"),
    (ax2, val_losses, "Val Loss", "Validation Loss"),
):
    panel.plot(series, label=series_label)
    panel.set_xlabel("Epoch", fontsize=25)
    panel.set_ylabel("Loss", fontsize=25)
    panel.tick_params(axis='both', which='major', labelsize=15)
    panel.set_title(heading, fontsize=30)
    panel.grid()
plt.show()

In [None]:
# Read the averaged per-epoch metrics CSV for the e2e run (presumably written
# by train_model under metrics_filename="e2e_metrics" -- verify), keep the
# columns of interest, and draw every metric against epoch on one set of axes.
df = pd.read_csv('e2e_metrics_epoch_metrics_average.csv').loc[
    :, ['epoch', 'r2', 'rmse', 'ssim', 'psnr']
]
df.plot(x='epoch')

In [None]:
# Per-metric curves for the SatVision run: a 2x2 grid with one panel per
# metric, each plotted against epoch. `ticker` is matplotlib.ticker.
df = pd.read_csv('sv_metrics_epoch_metrics_average.csv')
df = df[['epoch', 'r2', 'rmse', 'ssim', 'psnr']]
# Row slices kept available in the notebook namespace; the plots below use
# the full frame, not these.
df1 = df.iloc[0:21]
df2 = df.iloc[21:]

fig, axes = plt.subplots(2, 2, figsize=(10, 6))
axes = axes.ravel()  # plt.subplots(2, 2) already returns an ndarray
metric_columns = ['r2', 'rmse', 'ssim', 'psnr']
for ax, metric in zip(axes, metric_columns):
    ax.plot(df['epoch'], df[metric], label=metric)
    ax.set_title(f'{metric.upper()} vs Epoch')
    ax.set_xlabel('Epoch')
    ax.set_ylabel(metric.upper())
    ax.grid(True, alpha=0.3)
    # Put ticks on whole epochs only: '%d' truncates, so fractional tick
    # positions would otherwise show up as repeated integer labels.
    ax.xaxis.set_major_locator(ticker.MaxNLocator(integer=True))
    ax.xaxis.set_major_formatter(ticker.FormatStrFormatter('%d'))

plt.tight_layout()
plt.show()

In [None]:
# Per-metric curves for the end-to-end run: a 2x2 grid with one panel per
# metric, each plotted against epoch. `ticker` is matplotlib.ticker.
df = pd.read_csv('e2e_metrics_epoch_metrics_average.csv')
df = df[['epoch', 'r2', 'rmse', 'ssim', 'psnr']]
# Row slices kept available in the notebook namespace; the plots below use
# the full frame, not these.
df1 = df.iloc[0:21]
df2 = df.iloc[21:]

fig, axes = plt.subplots(2, 2, figsize=(10, 6))
axes = axes.ravel()  # plt.subplots(2, 2) already returns an ndarray
metric_columns = ['r2', 'rmse', 'ssim', 'psnr']
for ax, metric in zip(axes, metric_columns):
    ax.plot(df['epoch'], df[metric], label=metric)
    ax.set_title(f'{metric.upper()} vs Epoch')
    ax.set_xlabel('Epoch')
    ax.set_ylabel(metric.upper())
    ax.grid(True, alpha=0.3)
    # Put ticks on whole epochs only: '%d' truncates, so fractional tick
    # positions would otherwise show up as repeated integer labels.
    ax.xaxis.set_major_locator(ticker.MaxNLocator(integer=True))
    ax.xaxis.set_major_formatter(ticker.FormatStrFormatter('%d'))

plt.tight_layout()
plt.show()