In [None]:
import os
import sys

# PyTorch registers its MPS CPU-fallback when torch is first loaded, so this variable has to be
# in the environment before any cell imports torch (directly or via pytorch_lightning / plant_id).
# Setting it later with %env does not affect an already-imported torch.
os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1"

# Make the parent directory importable so the local `plant_id` package can be imported.
module_path = os.path.abspath('..')  # modify to point to plant_id
if module_path not in sys.path:
    sys.path.append(module_path)
print(sys.path)

In [None]:
import plant_id
from plant_id import callbacks as cb
from plant_id import lit_models

import numpy as np
import pytorch_lightning as pl
from pytorch_lightning.utilities.rank_zero import rank_zero_info, rank_zero_only
import torch
import wandb

In [None]:
from training.util import DATA_CLASS_MODULE, import_class, MODEL_CLASS_MODULE, setup_data_and_model_from_args

In [None]:
import plant_id.metadata.inat as metadata
import yaml
config_file = "training_config.yml"
# yaml.safe_load only constructs standard YAML types (mappings, lists, strings, numbers, bools,
# null). yaml.Loader would also instantiate arbitrary Python objects from `!!python/...` tags,
# which this notebook never needs: it only indexes `config` by string keys.
with open(config_file, "rb") as file:
    config = yaml.safe_load(file)

# An empty file yields None; fail here rather than with a TypeError on the first config[...] lookup.
if not isinstance(config, dict):
    raise ValueError(f"{config_file} must contain a YAML mapping, got {type(config).__name__}")

In [None]:
# Assemble a readable wandb run name: the backbone plus weight decay, followed by a suffix for
# each optional training feature that is switched on (or, for the head, switched off).
run_name_parts = [f"{config['PRETRAINED_STEM']} wd={config['WEIGHT_DECAY']}"]
if config['REDUCE_LR_ON_PLATEAU']:
    run_name_parts.append(f" ROP={config['ROP_COEFF']}")
if config['LAYERWISE_LR_DECAY']:
    run_name_parts.append(f" LLRD={config['LLRD_COEFF']}")
if not config['HEAD_DWSCONV']:
    run_name_parts.append(' head_dws=False')
if config['TRAIN_SET'] == 'full':
    run_name_parts.append(' ds=full')
WANDB_RUN_NAME = "".join(run_name_parts)

print(WANDB_RUN_NAME)

In [None]:
## process args: turn optional config entries into values suitable for the logged run config
# With AUTO_LR_FIND enabled the learning rate is chosen by the tuner, so log a placeholder.
lr = 'autotune' if config['AUTO_LR_FIND'] else config['LR']

# YAML null values arrive as None; store them as strings so the logged config never holds None.
# `is None` is the idiomatic null test (`== None` can be hijacked by a custom __eq__).
auto_scale_batch_size = (
    'False' if config['AUTO_SCALE_BATCH_SIZE'] is None else config['AUTO_SCALE_BATCH_SIZE']
)
loaded_model = 'None' if config['LOADED_MODEL'] is None else config['LOADED_MODEL']

# fc_dim is set to the number of plant classes defined in the iNat metadata module.
FC_DIM = metadata.NUM_PLANT_CLASSES

## wandb metadata: passed as `config=` to wandb.init when wandb logging is enabled
training_config = dict(
    dataset_id="iNat-2021",
    infra="BLM MBA",
    pretrained_stem=config['PRETRAINED_STEM'],
    resolution=config['RESOLUTION'],
    learning_rate=lr,
    batch_size=config['BATCH_SIZE'],  # need better logging: if autoscaled, this isn't the actual bs
    auto_scale_batch_size=auto_scale_batch_size,
    weight_decay=config['WEIGHT_DECAY'],
    fc_dropout=config['FC_DROPOUT'],
    fc_dim=FC_DIM,
    precision=config['PRECISION'],
    limit_train_batches=config['LIMIT_TRAIN_BATCHES'],
    dataset_type=config['TRAIN_SET'],
    loaded_model=loaded_model,
    reduce_lr_on_plateau=config['REDUCE_LR_ON_PLATEAU'],
    use_swa=config['STOCHASTIC_WEIGHT_AVERAGING'],
    early_stopping=config['EARLY_STOPPING'],
    loss=config['LOSS'],
    layerwise_lr_decay=config['LAYERWISE_LR_DECAY'],
    layerwise_lr_decay_coeff=config['LLRD_COEFF'],
    head_dwsconv=config['HEAD_DWSCONV'],
    rop_coeff=config['ROP_COEFF'],
    rop_threshold=config['ROP_THRESHOLD'],
    rop_threshold_mode=config['ROP_THRESHOLD_MODE'],
)

In [None]:
pretrained_stem = training_config['pretrained_stem']
print("Pretrained base model: ", pretrained_stem)

## different models have different names for stem and blocks; `mode` tells the model config
## which family this backbone belongs to. Keys are checked in this order and the first one that
## occurs as a substring of the model name wins.
STEM_FAMILIES = ('lambda', 'resnet', 'efficientnet', 'convnext', 'resnext')
mode = next((family for family in STEM_FAMILIES if family in pretrained_stem), None)
if mode is None:
    # Fail loudly instead of leaving `mode` undefined (NameError in a later cell) or, when the
    # cell is re-run with a new stem, silently reusing the value left from a previous run.
    raise ValueError(
        f"Unrecognised pretrained_stem {pretrained_stem!r}: "
        f"expected a name containing one of {STEM_FAMILIES}"
    )

In [None]:
# Load IPython's autoreload extension so that edits to the local plant_id modules are picked
# up without restarting the kernel; `%autoreload 2` (below) reloads all modules before each cell runs.
%load_ext autoreload

from pathlib import Path
import pytorch_lightning as pl  # already imported as `pl` in an earlier cell; re-import is harmless
import plant_id.models as models
import plant_id.data as data
#import plant_id.util

%autoreload 2

In [None]:
# Ask PyTorch to fall back to the CPU for operations the MPS backend does not implement.
# NOTE(review): torch was already imported by an earlier cell, and PyTorch reads this variable
# when it is loaded, so setting it here is most likely too late to have any effect. It needs to
# be in the environment before the first `import torch` — TODO confirm on the installed torch.
%env PYTORCH_ENABLE_MPS_FALLBACK=1
# Echo the value back as a sanity check that the kernel environment has it set.
!echo $PYTORCH_ENABLE_MPS_FALLBACK

In [None]:
## set log dir and format filename
log_dir = 'logs'
logger = pl.loggers.TensorBoardLogger(log_dir)
experiment_dir = logger.log_dir

# Metric used to rank checkpoints. Accuracy is maximised; any other metric (e.g. "validation/loss")
# is treated as something to minimise, so switching the metric cannot keep the worst checkpoints.
goldstar_metric = "validation/acc"
goldstar_mode = "max" if goldstar_metric == "validation/acc" else "min"
filename_format = "epoch={epoch:04d}-validation.loss={validation/loss:.3f}"
if goldstar_metric == "validation/acc":
    filename_format += "-validation.acc={validation/acc:.3f}"

## callbacks
checkpoint_callback = pl.callbacks.ModelCheckpoint(
    save_top_k=2,
    filename=filename_format,
    monitor=goldstar_metric,
    mode=goldstar_mode,
    # the format string already spells out metric names (with '.' instead of '/')
    auto_insert_metric_name=False,
    dirpath=experiment_dir,
)
lr_callback = pl.callbacks.LearningRateMonitor()
callbacks = [checkpoint_callback, lr_callback]
if config['EARLY_STOPPING']:
    callbacks.append(pl.callbacks.EarlyStopping(
        monitor="validation/acc",
        mode="max",
        patience=1,
    ))
if config['STOCHASTIC_WEIGHT_AVERAGING']:
    callbacks.append(pl.callbacks.StochasticWeightAveraging(
        swa_lrs=1e-4,
        swa_epoch_start=0.5,  # float => fraction of max_epochs
        annealing_epochs=5,
    ))

## set up lit model and datamodule
# input_dims: presumably (channels, height, width) for a square `resolution` crop — TODO confirm
DATA_CONFIG = {"input_dims": (3, training_config['resolution'], training_config['resolution'])}
MODEL_CONFIG = {"pretrained_stem": pretrained_stem, "fc_dim": FC_DIM,
                "fc_dropout": training_config['fc_dropout'], "mode": mode}

model = models.FinetuningCNN(data_config=DATA_CONFIG, model_config=MODEL_CONFIG)
lit_model = lit_models.LitFinetuningCNN(model)
datamodule = data.iNatDataModule()

# LOADED_MODEL may be a YAML null (None) or the string 'None'; both mean "start from scratch".
# Comparing only against 'None' would call torch.load(None) when the YAML value is null.
if config['LOADED_MODEL'] not in (None, 'None'):
    # SECURITY: torch.load unpickles the file, which can execute arbitrary code — only load
    # checkpoints you trust. NOTE(review): this expects a whole pickled LitModel; newer torch
    # releases default torch.load to weights_only=True, which would refuse it — TODO confirm.
    loaded_lit_model = torch.load(config['LOADED_MODEL'])
    #TODO: replace with load_state_dict, this currently doesn't work
    # NOTE(review): this only shadows lit_model.state_dict with the loaded model's bound method.
    # No weights are copied into lit_model's parameters: training starts from the fresh weights,
    # while anything calling lit_model.state_dict() (presumably including checkpoint saving —
    # confirm) sees the loaded model's tensors. The bound method also keeps loaded_lit_model
    # alive, so the `del` below does not free its memory.
    lit_model.state_dict = loaded_lit_model.state_dict  # need this step for wandb to log
    del loaded_lit_model

# Defaults to off; set USE_WANDB: true in the YAML config to enable wandb logging.
use_wandb = config.get('USE_WANDB', False)

## optionally, watch model with wandb
if use_wandb:
    wandb.init(
        project=config['WANDB_PROJECT_NAME'],
        notes="testing wandb integration",
        name=WANDB_RUN_NAME,
        tags=["test"],
        config=training_config,
    )
    logger = pl.loggers.WandbLogger(log_model="all", save_dir=str(log_dir), job_type="train")
    logger.watch(lit_model)
    # NOTE(review): checkpoint_callback was already created with dirpath=<TensorBoard log dir>,
    # so this reassignment does not change where checkpoints are written.
    experiment_dir = logger.experiment.dir

In [None]:
## training code
# The tuning flags come from the same config values that are logged in training_config, so a
# logged learning_rate='autotune' / auto_scale_batch_size matches what the trainer actually does.
# With AUTO_LR_FIND false and AUTO_SCALE_BATCH_SIZE null, trainer.tune is a no-op as before.
trainer = pl.Trainer(
    max_epochs=config['NUM_EPOCHS'],
    devices=-1,  # can only address one GPU on mps?
    accelerator='mps',
    callbacks=callbacks,
    logger=logger,
    auto_scale_batch_size=config['AUTO_SCALE_BATCH_SIZE'],
    auto_lr_find=config['AUTO_LR_FIND'],
    # NOTE(review): config['PRECISION'] is logged in training_config but not passed here, so the
    # Trainer runs at its default precision — presumably disabled for MPS support; TODO confirm.
    limit_train_batches=config['LIMIT_TRAIN_BATCHES'],
)
trainer.tune(lit_model, datamodule=datamodule)
trainer.fit(lit_model, datamodule=datamodule)

In [None]:
import timm

# Every timm model name matched by the pretrained=True filter — presumably a lookup aid for
# choosing PRETRAINED_STEM. (A glob such as '*convnext*' as the first argument narrows the list.)
pretrained_model_names = timm.list_models(pretrained=True)
pretrained_model_names