In [1]:
import torch
import wandb
from lightning.pytorch import Trainer
from lightning.pytorch.loggers import WandbLogger
from torchgeo.trainers import AutoregressionTask

from src.ndvi_datamodule import NDVIDataModule

In [2]:
# Experiment configuration: every tunable setting for this run lives in this
# one dictionary. Edit the values here rather than in the cells below.
config = {
    # Optimisation
    "learning_rate": 0.001,
    "batch_size": 256,
    # Weights & Biases run identification
    "wandb_project": "test",
    "wandb_name": "test",
    # Training loop
    "max_epochs": 1,
    # Worker processes are only requested when a CUDA device is available.
    "num_workers": 4 if torch.cuda.is_available() else 0,
    # Checkpoint to resume from; None means no checkpoint path is supplied.
    "ckpt_path": None,
    # Dataset splitting and sequence windowing
    "val_split_pct": 0.25,
    "test_split_pct": 0.25,
    "num_past_steps": 10,
    "num_future_steps": 3,
    # Model architecture
    "input_size": 1,
    "hidden_size": 32,
    "num_layers": 1,
    "teacher_force_prob": None,
}

In [3]:
# Location handed to NDVIDataModule as its `data_dir` argument.
# NOTE(review): this is a relative path, so it presumably resolves against the
# kernel's working directory — confirm against NDVIDataModule.
data_dir: str = "data"

In [4]:
# Build the NDVI datamodule. Every config-backed argument has the same name as
# its config key, so they are forwarded straight from `config`.
datamodule = NDVIDataModule(
    data_dir=data_dir,
    **{
        key: config[key]
        for key in (
            "batch_size",
            "val_split_pct",
            "test_split_pct",
            "num_workers",
            "num_past_steps",
            "num_future_steps",
        )
    },
)

In [5]:
# Build the training task: a "seq2seq" model trained with the "mse" loss.
# Arguments whose names differ from their config keys are spelled out; the
# rest are forwarded from `config` under their own names.
model = AutoregressionTask(
    model="seq2seq",
    loss="mse",
    lr=config["learning_rate"],
    # The output sequence length is set to the datamodule's num_future_steps.
    output_sequence_len=config["num_future_steps"],
    **{
        key: config[key]
        for key in ("input_size", "hidden_size", "num_layers", "teacher_force_prob")
    },
)

In [6]:
# Weights & Biases logger for this run; project and run name come from
# `config`, and "wandb-logs" is passed as the logger's save directory.
wandb_logger = WandbLogger(
    project=config["wandb_project"],
    name=config["wandb_name"],
    save_dir="wandb-logs",
)

In [None]:
# Lightning trainer wired to the W&B logger, with the logging interval set to
# one step.
#
# Debugging aids that can be passed here for quick smoke runs:
#   num_sanity_val_steps=0, limit_train_batches=1, limit_val_batches=1,
#   limit_test_batches=1, fast_dev_run=1
trainer = Trainer(
    max_epochs=config["max_epochs"],
    logger=wandb_logger,
    log_every_n_steps=1,
)

In [None]:
# Attach the notebook's hyperparameters to the W&B run *before* training, so
# the run carries its configuration even if fit() raises or is interrupted.
wandb_logger.experiment.config.update(config)

try:
    # Train the model; `ckpt_path` is None unless a checkpoint is set in
    # `config`.
    trainer.fit(
        model,
        datamodule=datamodule,
        ckpt_path=config["ckpt_path"],
    )
finally:
    # Always close the W&B run. Without this, a failed or interrupted fit
    # leaves the run open in the kernel.
    wandb.finish()