In [1]:
import os
import math

In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F

In [3]:
import pytorch_lightning as pl
from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint, RichModelSummary, RichProgressBar
from pytorch_lightning.loggers import TensorBoardLogger, WandbLogger
from pytorch_lightning import Trainer
from torchvision.ops import generalized_box_iou
import wandb

In [4]:
from config import config
from src.data_module import ClimateNetDataModule
from src.utils import Generic
from src.models.manager import get_model

In [5]:
# Root directory of the ClimateNet NetCDF files. Set the CLIMATENET_DIR environment
# variable to point elsewhere instead of editing the notebook.
data_dir = os.environ.get("CLIMATENET_DIR", "../autodl-nas/ClimateNet/")
files = Generic.list_files(data_dir)
# Fail here with a clear message rather than deep inside the data module.
assert files, f"no files found under {data_dir!r}"
files[:10]

['../autodl-nas/ClimateNet/data-2002-03-31-01-1_4.nc',
 '../autodl-nas/ClimateNet/data-1997-08-29-01-1_0.nc',
 '../autodl-nas/ClimateNet/data-2001-10-29-01-1_2.nc',
 '../autodl-nas/ClimateNet/data-2003-08-04-01-1_1.nc',
 '../autodl-nas/ClimateNet/data-2011-08-08-01-1_1.nc',
 '../autodl-nas/ClimateNet/data-2006-10-15-01-1_4.nc',
 '../autodl-nas/ClimateNet/data-2001-10-29-01-1_3.nc',
 '../autodl-nas/ClimateNet/data-2010-07-05-01-1_0.nc',
 '../autodl-nas/ClimateNet/data-1999-09-03-01-1_0.nc',
 '../autodl-nas/ClimateNet/data-2010-09-09-01-1_0.nc']

In [6]:
# Input variables passed to the data module, read from the project config.
feature_list = config["feature_list"]
assert feature_list, "config['feature_list'] must name at least one input variable"
feature_list

['TMQ', 'U850', 'V850', 'PRECT']

In [7]:
# NOTE(review): the third positional argument is presumably the batch size — confirm
# against ClimateNetDataModule's signature.
batch_size = 16
num_workers = 4
data_module = ClimateNetDataModule(
    files,
    feature_list,
    batch_size,
    num_workers=num_workers,
    shuffle=True,
)

In [8]:
# Shape sanity check: global average pooling followed by a 1x1 conv, BatchNorm and ReLU
# should reduce any spatial size to 1x1 while keeping the 3 channels.
pool_layers = [
    nn.AdaptiveAvgPool2d((1, 1)),
    nn.Conv2d(3, 3, kernel_size=1, bias=False),
    nn.BatchNorm2d(3),
    nn.ReLU(),
]
ap = nn.Sequential(*pool_layers)
tensor = torch.rand(32, 3, 256, 256)
ap(tensor).shape

torch.Size([32, 3, 1, 1])

# DeepLabV3+ with Attention

In [9]:
# Train the attention variant. `model_name` drives both model construction and the
# W&B run name, so switching architectures means editing a single line.
model_name = "attention"

# Seed before building the model so weight init and data shuffling are reproducible.
pl.seed_everything(42, workers=True)
model = get_model(model_name)

wandb_logger = WandbLogger(project="ClimateNet", name=model_name)

callbacks = [
    # In the logged run below, val_loss rises in the second half of training while
    # val_mean_iou plateaus. Keep the best-IoU weights, not just the last epoch's.
    ModelCheckpoint(monitor="val_mean_iou", mode="max"),
    RichModelSummary(),
    RichProgressBar(),
]

# auto_lr_find was dropped: it only takes effect through trainer.tune(), which was not run.
trainer = Trainer(
    accelerator="gpu",
    devices=-1,
    precision=16,
    max_epochs=50,
    callbacks=callbacks,
    logger=wandb_logger,
    log_every_n_steps=1,
    reload_dataloaders_every_n_epochs=1,
)
trainer.fit(model, datamodule=data_module)

Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Currently logged in as: [33mzhf231298[0m. Use [1m`wandb login --relogin`[0m to force relogin


VBox(children=(Label(value='Waiting for wandb.init()...\r'), FloatProgress(value=0.016669445919493833, max=1.0…

Using 16bit native Automatic Mixed Precision (AMP)
Trainer already configured with model summary callbacks: [<class 'pytorch_lightning.callbacks.rich_model_summary.RichModelSummary'>]. Skipping setting a default `ModelSummary` callback.
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


Output()

`Trainer.fit` stopped: `max_epochs=50` reached.


In [10]:
# Close the W&B run opened by WandbLogger in the training cell so remaining
# logged data is uploaded before the kernel moves on.
wandb.finish()

VBox(children=(Label(value='0.014 MB of 0.014 MB uploaded (0.000 MB deduped)\r'), FloatProgress(value=1.0, max…

0,1
epoch,▁▁▁▁▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇███
train_loss,█▅▄▄▄▄▄▃▄▃▃▃▃▂▂▃▂▂▃▂▂▂▂▁▂▂▁▂▂▁▁▂▁▂▁▂▂▁▁▁
trainer/global_step,▁▁▁▁▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▆▇▇▇▇▇███
val_ar_iou,▁▁▄▆▆▇▅██▅█▄▃▇▆▅▆▆▆▆▆▇▇▄▇▆▆▆▆▆▆▇▆▇▇▆▆▇▆▇
val_bg_iou,▃▆▆▆▇▇█▇▅▁▆▇▇▆▇▇▇▄▇▆█▆▇█▆▅▆▁▆▇▆▅▅▅▂▆▄▅▆▅
val_loss,▄▂▂▁▂▁▁▁▁▂▁▂▂▁▂▁▂▄▂▃▂▂▃▁▃▃▄▇▆▂▅▄█▅▅▅▅█▄▄
val_mean_iou,▁▁▅▇▆█▅██▆█▆▆█▇▇▇▇▇▇▇▇▇▆▇▇▇▆▆▆▆▇▆▆▇▆▇▆▇▇
val_tc_iou,▁▂▆▇▆█▆██▆██████▇▇█▇▇▇▇█▇▇▇▆▆▆▇▇▇▆▇▇▇▆▇▇

0,1
epoch,49.0
train_loss,0.06865
trainer/global_step,1149.0
val_ar_iou,0.35115
val_bg_iou,0.9448
val_loss,0.20247
val_mean_iou,0.5312
val_tc_iou,0.29765
