In [None]:
from dataset import EcoPerceiverLoaderConfig, EcoPerceiverBatch, EcoPerceiverDataset
import torch
from torch import nn
import numpy as np
import time
from tqdm import tqdm
import sqlite3
from torch.utils.data import DataLoader
from einops import rearrange
from components import EcoSageConfig, ECInputModule, ModisLinearInputModule, AttentionLayer, FourierFeatureMapping, GeoInputModule, IGBPInputModule, PhenocamRGBInputModule
from typing import Tuple, Dict, Optional

In [None]:
# NOTE(review): `sites` is not referenced by the dataset/loader construction below,
# so it does not restrict what is loaded here -- confirm whether it should be.
sites = ['US-Whs', 'CA-Gro', 'DE-Geb', 'CN-Din']

DATA_DIR = '/data/fluxes/carbonsense_v2'
# Presumably must agree with EcoSageConfig.context_length, since the model reshapes
# its latents using the batch's time dimension -- TODO confirm.
CONTEXT_WINDOW_LENGTH = 64
BATCH_SIZE = 16
NUM_WORKERS = 8

# Named `loader_config` (not `config`) so that a later cell which binds `config`
# to a different object cannot change what this cell builds when it is re-run.
loader_config = EcoPerceiverLoaderConfig(context_window_length=CONTEXT_WINDOW_LENGTH, targets=['NEE'])
ds = EcoPerceiverDataset(DATA_DIR, loader_config)
dl = DataLoader(ds, batch_size=BATCH_SIZE, shuffle=True, collate_fn=ds.collate_fn, num_workers=NUM_WORKERS)

In [None]:
class EcoSage(nn.Module):
    """Perceiver-style model that refines a learned latent sequence with attention.

    A learned latent array of ``config.context_length`` vectors (width
    ``config.latent_space_dim``) is passed through a stack of ``AttentionLayer``s.
    The kind of each layer is given by the corresponding character of
    ``config.layers``:

    * ``'w'`` -- each latent position is turned into its own length-1 query and
      attends to the tokens produced by ``self.windowed_modules``.
    * ``'c'`` -- the whole latent sequence attends to the tokens produced by
      ``self.auxiliary_modules``.
    * anything else -- the layer is called on the latents alone (presumably
      self-attention; depends on ``AttentionLayer``).

    The last latent position is projected to one scalar per sample.
    """

    def __init__(self, config: "EcoSageConfig"):
        super().__init__()
        self.config = config
        # Token sources consumed by 'w' (windowed) layers.
        self.windowed_modules = nn.ModuleList([ECInputModule(config)])
        # Token sources consumed by 'c' (auxiliary cross-attention) layers; any of
        # them may return (None, None) for a batch and is then skipped.
        self.auxiliary_modules = nn.ModuleList([
            ModisLinearInputModule(config),
            GeoInputModule(config),
            IGBPInputModule(config),
            PhenocamRGBInputModule(config),
        ])

        # Width of the input tokens: presumably sin+cos Fourier features concatenated
        # with a learned embedding -- TODO confirm against the input modules.
        self.input_hidden_dim = 2 * self.config.num_frequencies + self.config.input_embedding_dim
        # One learned latent vector per context position.
        self.latent_embeddings = nn.Embedding(self.config.context_length, self.config.latent_space_dim)

        layers = []
        for layer_type in self.config.layers:
            if layer_type in ('w', 'c'):
                # Layers that attend to input tokens are told the input token width.
                layers.append(AttentionLayer(self.config.latent_space_dim, self.config.num_heads, self.config.mlp_ratio, kv_hidden_size=self.input_hidden_dim))
            else:
                layers.append(AttentionLayer(self.config.latent_space_dim, self.config.num_heads, self.config.mlp_ratio))
        self.layers = nn.ModuleList(layers)
        self.output_proj = nn.Linear(self.config.latent_space_dim, 1)
        self.apply(self.initialize_weights)

    def initialize_weights(self, module):
        """Initialise a single submodule (applied recursively via ``self.apply``).

        Linear layers get Xavier-normal weights and zero bias. BatchNorm/LayerNorm
        layers get unit scale and zero shift. Norm layers built without affine
        parameters have ``weight``/``bias`` set to ``None`` and are skipped for
        those parameters instead of raising.
        """
        if isinstance(module, nn.Linear):
            nn.init.xavier_normal_(module.weight)
            if module.bias is not None:
                nn.init.zeros_(module.bias)
        elif isinstance(module, (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d, nn.LayerNorm)):
            if module.weight is not None:
                nn.init.ones_(module.weight)
            if module.bias is not None:
                nn.init.zeros_(module.bias)

    @staticmethod
    def _concat_tokens(modules, batch):
        """Run each input module on ``batch`` and concatenate what they produced.

        A module whose tokens or mask is ``None`` is skipped. Tokens are joined
        along dim -2 and masks along dim -1. Returns ``(None, None)`` when no
        module produced anything.
        """
        tokens, masks = [], []
        for module in modules:
            ip, mask = module(batch)
            if ip is None or mask is None:
                continue
            tokens.append(ip)
            masks.append(mask)
        if not tokens:
            return None, None
        return torch.cat(tokens, dim=-2), torch.cat(masks, dim=-1)

    def forward(self, batch):
        """Predict one scalar per sample for a collated batch.

        Args:
            batch: collated batch object. Only ``batch.predictor_values`` is read
                here directly (for its leading ``(B, L)`` dimensions); the input
                modules read whatever else they need from it.

        Returns:
            Tensor of shape ``(B,)`` -- including ``(1,)`` for a batch of one.

        Raises:
            RuntimeError: if a ``'w'`` or ``'c'`` layer is configured but none of
                the corresponding input modules produced tokens for this batch.
        """
        windowed_input, windowed_mask = self._concat_tokens(self.windowed_modules, batch)
        aux_input, aux_mask = self._concat_tokens(self.auxiliary_modules, batch)

        B, L, _ = batch.predictor_values.shape
        # Copy the learned latents for every sample: (B, context_length, latent_dim).
        hidden = self.latent_embeddings.weight.unsqueeze(0).repeat(B, 1, 1)

        for layer, layer_type in zip(self.layers, self.config.layers):
            if layer_type == 'w':
                if windowed_input is None:
                    raise RuntimeError("EcoSage: a 'w' layer is configured but no windowed input module produced tokens")
                # Fold latent positions into the batch so each one is a length-1
                # query. Presumably windowed_input is laid out as (B * L, tokens, dim)
                # so each latent sees only its own time step -- TODO confirm.
                hidden = rearrange(hidden, 'B L H -> (B L) H').unsqueeze(1)
                hidden, _ = layer(hidden, windowed_input, mask=windowed_mask)
                # Drop only the query axis; a bare squeeze() would also remove the
                # (B L) axis when B * L == 1.
                hidden = rearrange(hidden.squeeze(1), '(B L) H -> B L H', B=B, L=L)
            elif layer_type == 'c':
                if aux_input is None:
                    raise RuntimeError("EcoSage: a 'c' layer is configured but no auxiliary input module produced tokens")
                hidden, _ = layer(hidden, aux_input, mask=aux_mask)
            else:
                hidden, _ = layer(hidden)

        # Squeeze only the size-1 feature axis so a batch of one keeps shape (1,).
        return self.output_proj(hidden[:, -1, :]).squeeze(-1)

# Model construction. Uses its own name (`model_config`) rather than `config`, so
# it never shadows the data-loader configuration object built in an earlier cell.
model_config = EcoSageConfig()
model = EcoSage(model_config)
device = 'cuda' if torch.cuda.is_available() else 'cpu'
model.to(device)
print('Model created successfully')

In [None]:
# Smoke test: one forward pass over every batch of `dl`. The outputs are discarded,
# so no autograd graph is needed.
# NOTE(review): the model was moved to `device`, but nothing here moves `batch`.
# Unless the dataset/collate_fn already places tensors on that device, this will
# fail with a device mismatch on a CUDA machine -- TODO confirm.
with torch.no_grad():
    for batch in tqdm(dl):
        model(batch)

In [None]:
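# Optional: list the distinct `igbp` values stored in the `site_data` table.
# Note: sqlite3's `with` block only commits/rolls back; it does not close the connection.
# with sqlite3.connect('/data/fluxes/carbonsense_v2/carbonsense_v2.sql') as conn:
#     res = conn.execute('SELECT DISTINCT(igbp) FROM site_data;').fetchall()
# [r[0] for r in res]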