In [56]:
import os
from dataclasses import dataclass
from logging import getLogger as get_logger
from typing import Dict, List, Optional, Type
import logging
from simple_parsing.helpers.serialization.serializable import Serializable
from simple_parsing.helpers import field
import wandb
from pytorch_lightning.loggers import LightningLoggerBase
from pl_bolts.datamodules.vision_datamodule import VisionDataModule
import hydra
from hydra.core.config_store import ConfigStore
from omegaconf import DictConfig, OmegaConf
from pytorch_lightning import Callback, LightningModule, Trainer, seed_everything

from target_prop.config import Config
from target_prop.datasets.dataset_config import DatasetConfig
from target_prop.models import Model, DTP, VanillaDTP, TargetProp, ParallelDTP, BaselineModel
from target_prop.models.model import Model
from target_prop.networks import  Network, ResNet18, ResNet34, SimpleVGG, LeNet, ViT
from target_prop.networks.network import Network
from target_prop.scheduler_config import CosineAnnealingLRConfig, StepLRConfig
from target_prop.utils.hydra_utils import get_outer_class

class Options(Serializable):
    """ All the options required for a run. This dataclass acts as a schema for the Hydra configs.
    
    For more info, see https://hydra.cc/docs/tutorials/structured_config/schema/
    """

    # Configuration for the dataset + transforms.
    dataset: DatasetConfig  # = field(default_factory=DatasetConfig)

    # The model used.
    model: Model.HParams  # = field(default_factory=Model.HParams)
    # The network to be used.
    network: Network.HParams  # = field(default_factory=SimpleVGG.HParams)

    # Keyword arguments for the Trainer constructor.
    trainer: Dict = field(default_factory=dict)  # type: ignore

    # Configs for the callbacks.
    callbacks: Dict = field(default_factory=dict)  # type: ignore

    # Config(s) for the logger(s).
    logger: Dict = field(default_factory=dict)  # type: ignore

    # Wether to run in debug mode or not.
    debug: bool = False

    verbose: bool = False

    # Random seed.
    seed: Optional[int] = None

    # Name for the experiment.
    name: str = ""

logger = get_logger(__name__)
raw_options = OmegaConf.create({'dataset': {'dataset': 'cifar10', 'data_dir': '/home/ono/Dev/scalingDTP/data', 'num_workers': 12, 'shuffle': True, 'normalize': True, 'image_crop_size': 32, 'image_crop_pad': 4, 'val_split': 0.1, 'use_legacy_std': False}, 'model': {'lr_scheduler': {'interval': 'epoch', 'frequency': 1, 'T_max': 85, 'eta_min': 1e-05}, 'batch_size': 128, 'use_scheduler': True, 'feedback_training_iterations': [20, 25, 30, 35, 15], 'max_epochs': 90, 'b_optim': {'type': 'sgd', 'lr': [0.0001, 0.00035, 0.001, 0.002, 0.08], 'weight_decay': None, 'momentum': 0.9}, 'noise': [0.4, 0.4, 0.2, 0.2, 0.08], 'f_optim': {'type': 'sgd', 'lr': [0.04], 'weight_decay': 0.0001, 'momentum': 0.9}, 'beta': 0.7, 'feedback_samples_per_iteration': 1, 'early_stopping_patience': 0, 'init_symetric_weights': False, 'plot_every': 1000}, 'network': {'activation': 'elu', 'batch_size': 128, 'channels': [128, 128, 256, 256, 512], 'bias': True}, 'trainer': {'_target_': 'pytorch_lightning.Trainer', 'gpus': -1, 'strategy': 'dp', 'min_epochs': 1, 'max_epochs': 90, 'resume_from_checkpoint': None}, 'callbacks': {'model_checkpoint': {'_target_': 'pytorch_lightning.callbacks.ModelCheckpoint', 'monitor': 'val/accuracy', 'mode': 'max', 'save_top_k': 1, 'save_last': True, 'verbose': False, 'dirpath': 'checkpoints/', 'filename': 'epoch_{epoch:03d}', 'auto_insert_metric_name': False}, 'early_stopping': {'_target_': 'pytorch_lightning.callbacks.EarlyStopping', 'monitor': 'val/accuracy', 'mode': 'max', 'patience': 100, 'min_delta': 0}, 'model_summary': {'_target_': 'pytorch_lightning.callbacks.RichModelSummary', 'max_depth': 1}, 'rich_progress_bar': {'_target_': 'pytorch_lightning.callbacks.RichProgressBar'}}, 'logger': {'wandb': {'_target_': 'pytorch_lightning.loggers.wandb.WandbLogger', 'project': 'scalingDTP', 'name': '${name}', 'save_dir': '.', 'offline': False, 'id': None, 'log_model': False, 'prefix': '', 'job_type': 'train', 'group': '', 'tags': []}}, 'debug': False, 'verbose': False, 'seed': 4248715256, 'name': ''}
                  )
options: Options
trainer: Trainer = field(init=False, to_dict=False)
model: Model = field(init=False, to_dict=False)
network: Network = field(init=False, to_dict=False)
datamodule: VisionDataModule = field(init=False, to_dict=False)

callbacks: List[Callback] = field(init=False, default_factory=list, to_dict=False)
loggers: List[LightningLoggerBase] = field(init=False, default_factory=list, to_dict=False)

trainer =Trainer = field(init=False, to_dict=False)
model: Model = field(init=False, to_dict=False)
network: Network = field(init=False, to_dict=False)
datamodule: VisionDataModule = field(init=False, to_dict=False)

callbacks: List[Callback] = field(init=False, default_factory=list, to_dict=False)
loggers: List[LightningLoggerBase] = field(init=False, default_factory=list, to_dict=False)
options = OmegaConf.to_object(raw_options)
actual_callbacks: Dict[str, Callback] = {}
# Create the callbacks
assert isinstance(options['callbacks'], dict)
for name, callback in options['callbacks'].items():
    if isinstance(callback, dict):
        callback = hydra.utils.instantiate(callback)
    elif not isinstance(callback, Callback):
        raise ValueError(f"Invalid callback value {callback}")
    actual_callbacks[name] = callback
    callbacks = list(actual_callbacks.values())
# Create the loggers, if any.
assert isinstance(options['logger'], dict)
actual_loggers: Dict[str, LightningLoggerBase] = {}
for name, lightning_logger in options['logger'].items():
    if isinstance(lightning_logger, dict):
        lightning_logger = hydra.utils.instantiate(lightning_logger)
    elif not isinstance(lightning_logger, LightningLoggerBase):
        raise ValueError(f"Invalid logger value {lightning_logger}")
    actual_loggers[name] = lightning_logger
logger = list(actual_loggers.values())
assert isinstance(options['trainer'], dict)
if options["debug"]:
    logger.info(f"Setting the max_epochs to 1, since the 'debug' flag was passed.")
    self.options['trainer']["max_epochs"] = 1
trainer = hydra.utils.instantiate(
    options['trainer'], callbacks=callbacks, logger=logger,
)
from target_prop.datasets.dataset_config import DatasetConfig
options = OmegaConf.create(raw_options)
dataset =DatasetConfig(options.dataset)




Trainer already configured with model summary callbacks: [<class 'pytorch_lightning.callbacks.rich_model_summary.RichModelSummary'>]. Skipping setting a default `ModelSummary` callback.
GPU available: True, used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs


In [32]:
options.dataset

{'dataset': 'cifar10', 'data_dir': '/home/ono/Dev/scalingDTP/data', 'num_workers': 12, 'shuffle': True, 'normalize': True, 'image_crop_size': 32, 'image_crop_pad': 4, 'val_split': 0.1, 'use_legacy_std': False}

In [60]:
datamodule = dataset.make_datamodule(batch_size=options.model.batch_size)


In [37]:
datamodule.dims

(3, 32, 32)

In [33]:
# datamodule = VisionDataModule(datamodule)

In [47]:
vd = VisionDataModule

In [58]:
model

Field(name=None,type=None,default=<dataclasses._MISSING_TYPE object at 0x7fbd603ff520>,default_factory=<dataclasses._MISSING_TYPE object at 0x7fbd603ff520>,init=False,repr=True,hash=None,compare=True,metadata=mappingproxy({'to_dict': False, 'cmd': True, 'positional': False}),_field_type=None)

In [64]:
network = ViT(in_channels = datamodule.dims[0],n_classes=datamodule.num_classes,hparams=options['network'])
options = OmegaConf.to_object(raw_options)
dict(OmegaConf.to_object(OmegaConf.create(options['model'])))
model = DTP(network=network,datamodule=datamodule,hparams = DTP.hparams,network_hparams=options['network'],config=Config(seed = options['seed'],debug=options['debug']))  



TypeError: __init__() missing 1 required positional argument: 'hparams'

AttributeError: type object 'HParams' has no attribute 'b'

In [None]:
# @dataclass
# class Options(Serializable):
#     """ All the options required for a run. This dataclass acts as a schema for the Hydra configs.
    
#     For more info, see https://hydra.cc/docs/tutorials/structured_config/schema/
#     """

#     # Configuration for the dataset + transforms.
#     dataset: DatasetConfig  # = field(default_factory=DatasetConfig)

#     # The model used.
#     model: Model.HParams  # = field(default_factory=Model.HParams)
#     # The network to be used.
#     network: Network.HParams  # = field(default_factory=SimpleVGG.HParams)

#     # Keyword arguments for the Trainer constructor.
#     trainer: Dict = field(default_factory=dict)  # type: ignore

#     # Configs for the callbacks.
#     callbacks: Dict = field(default_factory=dict)  # type: ignore

#     # Config(s) for the logger(s).
#     logger: Dict = field(default_factory=dict)  # type: ignore

#     # Wether to run in debug mode or not.
#     debug: bool = False

#     verbose: bool = False

#     # Random seed.
#     seed: Optional[int] = None

#     # Name for the experiment.
#     name: str = ""


# cs = ConfigStore.instance()
# cs.store(name="base_options", node=Options)

# cs.store(group="model", name="model", node=Model.HParams())
# cs.store(group="model", name="dtp", node=DTP.HParams())
# cs.store(group="model", name="parallel_dtp", node=ParallelDTP.HParams())
# cs.store(group="model", name="vanilla_dtp", node=VanillaDTP.HParams())
# cs.store(group="model", name="target_prop", node=TargetProp.HParams())
# cs.store(group="model", name="backprop", node=BaselineModel.HParams())

# cs.store(group="network", name="simple_vgg", node=SimpleVGG.HParams())
# cs.store(group="network", name="lenet", node=LeNet.HParams())
# cs.store(group="network", name="resnet18", node=ResNet18.HParams())
# cs.store(group="network", name="resnet34", node=ResNet34.HParams())
# cs.store(group="network", name="vit", node=ViT.HParams())
# cs.store(group="lr_scheduler", name="step", node=StepLRConfig)
# cs.store(group="lr_scheduler", name="cosine", node=CosineAnnealingLRConfig)



In [None]:
# options = Options(VisionDataModule,DTP,ViT)