In [None]:
import pandas as pd
import os
import numpy as np
import yaml
from pathlib import Path
import torch
from torch.utils.data import DataLoader
from ecoperceiver import EcoPerceiverModel, EcoPerceiverConfig, EcoPerceiverDataset, ep_collate
from tqdm import tqdm

In [None]:
# Training run to evaluate and, for each seed, which saved checkpoint number to load
# (runs/<run>/<seed>/checkpoint-<n>.pth).
# NOTE(review): presumably the checkpoint selected per seed during validation — confirm
# where these numbers come from.
run = 'default'
seed_checkpoints = [
    (f'seed_{seed}', checkpoint)
    for seed, checkpoint in (
        (0, 6), (10, 9), (20, 10), (30, 6), (40, 6),
        (50, 13), (60, 8), (70, 7), (80, 8), (90, 12),
    )
]

# Input data lives under data/carbonsense; the run's config, per-seed checkpoints and
# results live under runs/<run>.
DATA_DIR = Path('data', 'carbonsense')
RUN_DIR = Path('runs', run)
CONFIG_PATH = RUN_DIR.joinpath('config.yml')

# The run's training config; its data/model sections drive the test dataset and model below.
with CONFIG_PATH.open('r') as f:
    config = yaml.safe_load(f)

TEST_SITES = config['data']['test_sites']

In [None]:
# Test split: the held-out sites, windowed with the same context length and target
# columns the model was configured with.
dataset_test = EcoPerceiverDataset(
    DATA_DIR,
    TEST_SITES,
    context_length=config['model']['context_length'],
    targets=config['data']['target_columns'],
)

# `shuffle` is left at its default (False), so batches follow the dataset's order.
data_loader_test = DataLoader(
    dataset_test,
    batch_size=128,
    collate_fn=ep_collate,
    num_workers=config['data']['num_workers'],
    pin_memory=config['data']['pin_memory'],
)

In [None]:
# Fill in the model config with values taken from the test dataset (spectral channel count
# and tabular column list), then build the model. No weights are loaded here.
config['model'].update(
    spectral_data_channels=dataset_test.num_channels(),
    tabular_inputs=dataset_test.columns(),
)
device = torch.device('cuda')
model = EcoPerceiverModel(EcoPerceiverConfig(**config['model']))

# Use bfloat16 autocast on GPUs with compute capability >= 8 (Ampere and newer);
# otherwise stay in float32.
cuda_major = torch.cuda.get_device_properties(device).major
datatype = torch.bfloat16 if cuda_major >= 8 else torch.float32

In [None]:
# Target values for the test sites, indexed by (SITE_ID, timestamp) so model outputs can be
# aligned to rows by label. set_index returns a new frame rather than modifying the one the
# dataset hands back: if get_target_dataframe() returns the dataset's own frame, an in-place
# set_index would change it, and re-running this cell would then fail because SITE_ID and
# timestamp would no longer be columns.
inference_df = dataset_test.get_target_dataframe().set_index(['SITE_ID', 'timestamp'], drop=True)

In [None]:
# Order rows by SITE_ID, then timestamp; the results CSVs are written in this row order.
inference_df = inference_df.sort_index(ascending=True)

In [None]:
# Evaluate each (seed, checkpoint) pair on the test sites and write one results CSV per
# checkpoint: runs/<run>/<seed>/results-<checkpoint>.csv.
SKIP_EXISTING = False  # set True to keep results CSVs that already exist instead of recomputing

for seed, checkpoint in seed_checkpoints:
    checkpoint_path = RUN_DIR / seed / f'checkpoint-{checkpoint}.pth'
    results_path = RUN_DIR / seed / f'results-{checkpoint}.csv'
    if SKIP_EXISTING and results_path.exists():
        print(f'Already have results for {seed}-{checkpoint}, skipping...')
        continue

    # Start every checkpoint from a blank prediction column. Doing this before inference
    # (rather than after writing the CSV) also creates the column if the target frame lacks
    # it, and clears stale values left by an interrupted earlier run of this cell.
    inference_df['Inferred'] = np.nan

    # Load onto the CPU: load_state_dict copies into the model's parameters wherever they
    # live, so any extra checkpoint entries never occupy GPU memory.
    weights = torch.load(checkpoint_path, map_location='cpu')
    model.load_state_dict(weights['model'])
    model.to(device)
    model.eval()
    print(f'Running results for {seed}...')

    site_ids, timestamps, predictions = [], [], []
    # no_grad: pure inference, so don't build an autograd graph for every batch.
    # CUDA autocast only supports float16/bfloat16; with float32 it would warn and disable
    # itself, so turn it off explicitly in that case.
    with torch.no_grad(), torch.autocast(device_type=device.type, dtype=datatype,
                                         enabled=datatype != torch.float32):
        for batch in tqdm(data_loader_test):
            op = model(batch)
            predictions.extend(op['logits'].float().cpu().tolist())
            site_ids.extend(batch['site_ids'])
            timestamps.extend(batch['timestamps'])

    # Align all predictions with the target frame in a single update rather than one per
    # batch (each DataFrame.update reindexes against the whole frame).
    # from_arrays also copes with an empty result, unlike from_tuples.
    idx = pd.MultiIndex.from_arrays([site_ids, timestamps], names=['SITE_ID', 'timestamp'])
    run_predictions = pd.DataFrame(predictions, columns=['Inferred'], index=idx)
    # update() never writes NaN values and fails on duplicate labels in `other`. Keep the last
    # non-NaN prediction per row, which is what successive per-batch updates would leave behind.
    run_predictions = run_predictions.dropna()
    run_predictions = run_predictions[~run_predictions.index.duplicated(keep='last')]
    inference_df.update(run_predictions)
    inference_df.to_csv(results_path)