In [None]:
# Work from the project directory; later cells use relative paths ("data",
# "wandb-logs") and presumably import the local `prithvi` / `datamodule`
# modules from it. Guarded so re-running this cell (or Run All twice) does not
# try to descend into a nested `crop-type-segmentation/crop-type-segmentation`.
import os

PROJECT_DIR = "crop-type-segmentation"
if os.path.basename(os.path.normpath(os.getcwd())) != PROJECT_DIR:
    os.chdir(PROJECT_DIR)
os.getcwd()

In [2]:
from pathlib import Path

import torch
import pytorch_lightning as pl
from pytorch_lightning.loggers import WandbLogger
import wandb

from prithvi import PrithviSemanticSegmentation
from datamodule import CropTypeDataModule

In [3]:
# Run configuration -- modify as needed. The whole dict is also logged to W&B
# alongside the training run.
config = {}
config["num_classes"] = 17
config["months"] = [2, 4, 5, 6, 7, 8, 9]
# The datamodule receives `months` and the model receives `num_frames`; this
# assumes one temporal frame per selected month, so derive the count instead of
# hand-copying it (keeps the two from drifting apart when `months` is edited).
config["num_frames"] = len(config["months"])
config["img_size"] = 224
config["learning_rate"] = 0.001
config["decoder_num_convs"] = 1
config["channels"] = [0, 1, 2, 3, 4, 5]
# Model input channels must equal the number of bands selected from the data.
config["in_channels"] = len(config["channels"])
config["batch_size"] = 5
config["wandb_project"] = "test"
config["wandb_name"] = "test"
config["max_epochs"] = 1
# Background dataloader workers only when a GPU is present; 0 = load in-process.
config["num_workers"] = 4 if torch.cuda.is_available() else 0
config["ckpt_path"] = None  # set to a checkpoint path to resume training
config["train_percent"] = 0.7
# RNG seed for reproducible data splits and weight initialisation.
config["seed"] = 42

### Load the data

In [4]:
# Dataset root, resolved relative to the current (project) working directory.
DATA_SUBDIR = "data"
data_dir = Path(DATA_SUBDIR)

In [5]:
# Seed Python, NumPy and torch RNGs (and dataloader worker seeding) before any
# data or model objects are created, so that any random train/val split done by
# the datamodule (presumably driven by `train_percent`) and the model's weight
# initialisation are reproducible across runs.
pl.seed_everything(config.get("seed", 42), workers=True)

# Data pipeline: selected months/bands, resized to `img_size`, batched for training.
datamodule = CropTypeDataModule(
    data_dir=data_dir,
    months=config["months"],
    channels=config["channels"],
    train_percent=config["train_percent"],
    img_size=config["img_size"],
    batch_size=config["batch_size"],
    num_workers=config["num_workers"],
)

### Load and train the model

In [6]:
# Segmentation model; every constructor argument is read from `config` under
# the same key name.
MODEL_CONFIG_KEYS = (
    "num_classes",
    "in_channels",
    "num_frames",
    "decoder_num_convs",
    "img_size",
    "learning_rate",
)
model_kwargs = {key: config[key] for key in MODEL_CONFIG_KEYS}
model = PrithviSemanticSegmentation(**model_kwargs)

In [7]:
# Weights & Biases logger; `save_dir` is where W&B keeps its local run files.
wandb_logger = WandbLogger(
    project=config["wandb_project"],
    name=config["wandb_name"],
    save_dir="wandb-logs",
)

In [None]:
# Lightning trainer. `fast_dev_run` is a smoke-test mode: it runs a single batch
# and, per the Lightning docs, disables loggers and checkpointing, which would
# make `max_epochs` and the W&B logger ineffective. It is therefore off unless
# explicitly enabled via `config["fast_dev_run"]` (True or a batch count).
trainer = pl.Trainer(
    logger=wandb_logger,
    max_epochs=config["max_epochs"],
    log_every_n_steps=1,  # log every optimisation step
    fast_dev_run=config.get("fast_dev_run", False),
)

In [None]:
# Attach the run configuration before training so it is recorded on the W&B run
# even if training fails part-way.
wandb_logger.experiment.config.update(config)
try:
    trainer.fit(
        model,
        datamodule=datamodule,
        ckpt_path=config["ckpt_path"],  # None -> train from scratch
    )
finally:
    # Always close the W&B run; otherwise an exception in `fit` leaves the run
    # open in this kernel. Re-run the logger and trainer cells before fitting again.
    wandb.finish()