In [1]:
import os
import torch
from torch.optim import AdamW
from vit_foundry.perceiver.model import Perceiver, PerceiverConfig
from vit_foundry.perceiver.dataset import FluxDataLoader, FluxTimeSeriesValidationDataLoader
import numpy as np
from tqdm import tqdm
import math
torch.multiprocessing.set_sharing_strategy('file_system')

from torch.utils.tensorboard import SummaryWriter

In [2]:
train_site_ratio = 0.8
batch_size = 256
context_length = 32
target_columns = ['NEE_VUT_REF']
num_epochs = 10
learning_rate = 5e-6
run_name = 'runs/full_run_time_series'

In [3]:
DIRTY_SITES = ['DE-Akm', 'CH-Aws', 'AR-Vir', 'US-GBT', 'ZM-Mon', 'DE-Lnf', 'US-Wi3', 'BE-Bra',
               'SD-Dem', 'CA-Obs', 'IT-Ro2', 'US-GLE', 'RU-Cok', 'SE-Svb', 'US-WCr', 'IT-Ren',
               'US-ICs', 'SJ-Adv', 'IT-Cpz', 'US-Myb', 'US-Los', 'US-Syv', 'CZ-BK2']

In [4]:
DATA_DIR = os.path.join('data', 'processed', 'v3')
SITES = [s for s in os.listdir(DATA_DIR) if s not in DIRTY_SITES][:10] # for prototyping

num_train_sites = int(len(SITES) * 0.8)
TRAIN_SITES = SITES[:num_train_sites]
VAL_SITES = SITES[num_train_sites:]
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

In [5]:
print(TRAIN_SITES)
print(VAL_SITES)

['US-HB3', 'CA-Gro', 'AR-TF1', 'US-Mo1', 'US-NR1', 'BE-Vie', 'US-UM3', 'CG-Tch']
['IT-MBo', 'AU-Dry']


In [6]:
train_dl = FluxDataLoader(
    DATA_DIR,
    TRAIN_SITES,
    context_length=context_length,
    target_columns=target_columns,
    time_series=True,
    shuffle=True,
    #num_workers=0,
    batch_size=batch_size
)

val_dl = FluxTimeSeriesValidationDataLoader(
    DATA_DIR,
    VAL_SITES,
    context_length=context_length,
    batch_size=batch_size,
    target_columns=target_columns
)

# val_dl = FluxDataLoader(
#     DATA_DIR,
#     VAL_SITES,
#     context_length=context_length,
#     target_columns=target_columns,
#     time_series=True,
#     shuffle=False,
#     num_workers=16,
#     batch_size=batch_size
# )

In [7]:

# reload(sys.modules['vit_foundry.perceiver'])
# from vit_foundry.perceiver import Perceiver, PerceiverConfig, FluxDataLoader

config = PerceiverConfig(
    tabular_inputs=tuple(train_dl.dataset.columns()),
    obs_dropout=0.3,
    input_embedding_dim=32,
    latent_hidden_dim=128,
    spectral_data_channels=train_dl.dataset.num_channels(),
    spectral_data_resolution=(8,8),
    context_length=context_length,
    layers='cscscscscsss'
)

model = Perceiver(config)
model.to(device)
optim = AdamW(model.parameters(), lr=learning_rate)


# o, m, s, p = next(iter(train_dl))
# op = model(o, m, s, p)
# print(((op['logits'] - p.to(device)) ** 2).nanmean())

  from .autonotebook import tqdm as notebook_tqdm


In [8]:
model_parameters = filter(lambda p: p.requires_grad, model.parameters())
params = sum([np.prod(p.size()) for p in model_parameters])
params

1911833

In [9]:
#from time import time
def train_one_epoch(model, dataloader, optimizer, writer, epoch):
    model.train()
    total_loss = 0.0
    #times = [time()]
    for idx, batch in enumerate(tqdm(dataloader)):
        #times.append(time())
        with torch.cuda.amp.autocast(dtype=torch.bfloat16):
            op = model(*batch)
        #times.append(time())
        loss = op['loss']
        if not math.isfinite(loss.item()):
            print(batch[-1].squeeze().tolist())
            print(op['logits'].squeeze().tolist())
            print(batch[0].isnan().sum())
            print(batch[2].sum())
        assert math.isfinite(loss.item()), f'Loss is {loss}, stopping training.'
        total_loss += loss.item()

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        writer.add_scalar('train_loss', loss.item(), epoch * len(dataloader) + idx)

        #times.append(time())
    return total_loss / len(dataloader)#, times

def val_one_epoch(model, dataloader, writer, epoch):
    model.eval()
    total_loss = 0.0
    for idx, batch in enumerate(tqdm(dataloader)):
        with torch.cuda.amp.autocast(dtype=torch.bfloat16):
            op = model(*batch)
        loss = op['loss']
        assert math.isfinite(loss.item()), f'Loss is {loss}, stopping training.'
        total_loss += loss.item()
        dataloader.update_inferred_values(op['logits'].tolist())
        writer.add_scalar('val_loss', loss.item(), epoch * len(dataloader) + idx)
    return total_loss / len(dataloader)

In [10]:
writer = SummaryWriter(run_name)
train_losses = []
val_losses = []
for i in range(num_epochs):
    val_losses.append(val_one_epoch(model, val_dl, writer, i))
    writer.add_scalar("val_batch_loss", val_losses[-1], i)
    train_losses.append(train_one_epoch(model, train_dl, optim, writer, i))
    writer.add_scalar("train_batch_loss", train_losses[-1], i)
    writer.flush()

  0%|          | 0/8753 [00:00<?, ?it/s]

  3%|▎         | 276/8753 [00:12<06:15, 22.58it/s]


KeyboardInterrupt: 