In [1]:
import os
import ast
import json
import hydra
from hydra import compose, initialize
import torch
import torch.nn.functional as F
import mlflow
import optuna
import numpy as np
from tqdm import tqdm
from typing import List, Tuple
from omegaconf.omegaconf import OmegaConf
from torch.utils.data import DataLoader
from torch.optim.lr_scheduler import CosineAnnealingLR, OneCycleLR

from model import UCCDRNModel
from dataset import MnistDataset
from omegaconf import DictConfig
from utils import get_or_create_experiment, parse_experiment_runs_to_optuna_study
from optimizers import SGLD
torch.autograd.set_detect_anomaly(True)


# set random seed
def set_random_seed(seed):
    """Seed the PyTorch (CPU and CUDA) and NumPy random number generators.

    Also sets the cuDNN ``deterministic`` flag.
    """
    seeders = (torch.manual_seed, torch.cuda.manual_seed, np.random.seed)
    for apply_seed in seeders:
        apply_seed(seed)
    torch.backends.cudnn.deterministic = True


def init_model_and_optimizer(args, model_cfg, device):
    """Instantiate ``UCCDRNModel`` on ``device`` and an Adam optimizer for it.

    Args:
        args: run arguments; only ``args.learning_rate`` is read here.
        model_cfg: configuration object passed unchanged to ``UCCDRNModel``.
        device: target handed to ``.to(device)``.

    Returns:
        A ``(model, optimizer)`` tuple.
    """
    model = UCCDRNModel(model_cfg).to(device)
    # Every parameter shares the same learning rate; AMSGrad is enabled.
    optimizer = torch.optim.Adam(
        model.parameters(), lr=args.learning_rate, amsgrad=True
    )
    return model, optimizer


def init_dataloader(args):
    """Create the training and validation ``DataLoader`` objects.

    Both splits are ``MnistDataset`` instances whose length is the split's
    number of steps multiplied by ``args.batch_size``. ``args.dataset`` is
    not consulted here: only the MNIST dataset is wired up.

    Returns:
        A ``(train_loader, val_loader)`` tuple; only the training loader
        shuffles.
    """

    def make_dataset(split, num_steps):
        # The two splits differ only in their mode and their length.
        return MnistDataset(
            mode=split,
            num_instances=args.num_instances,
            num_samples_per_class=args.num_samples_per_class,
            digit_arr=list(range(args.ucc_end - args.ucc_start + 1)),
            ucc_start=args.ucc_start,
            ucc_end=args.ucc_end,
            length=num_steps * args.batch_size,
        )

    train_set = make_dataset("train", args.train_num_steps)
    val_set = make_dataset("val", args.val_num_steps)

    # Loader options common to both splits.
    shared_options = {
        "batch_size": args.batch_size,
        "num_workers": args.num_workers,
    }
    train_loader = DataLoader(train_set, shuffle=True, **shared_options)
    val_loader = DataLoader(val_set, shuffle=False, **shared_options)
    return train_loader, val_loader


def evaluate(model, val_loader, device) -> Tuple[np.floating, np.floating]:
    """Compute mean cross-entropy loss and accuracy over ``val_loader``.

    The model is switched to eval mode and is left in eval mode on return;
    callers that continue training must call ``model.train()`` themselves.

    Args:
        model: callable on a batch of samples and exposing ``alpha``. When
            ``model.alpha == 1`` the call returns the logits alone; otherwise
            it returns a 2-tuple whose first element is the logits.
        val_loader: iterable of ``(batch_samples, batch_labels)`` tensors.
        device: device the batches are moved to before the forward pass.

    Returns:
        ``(mean_loss, mean_accuracy)``, each the unweighted mean of the
        per-batch values (NumPy scalars as produced by ``np.mean``).
    """
    model.eval()
    val_loss_list = []
    val_acc_list = []
    with torch.no_grad():
        for batch_samples, batch_labels in val_loader:
            batch_samples = batch_samples.to(device)
            batch_labels = batch_labels.to(device)

            # Only the shape of the model output depends on alpha; the
            # metric computation below is shared by both cases.
            if model.alpha == 1:
                ucc_logits = model(batch_samples)
            else:
                ucc_logits, _ = model(batch_samples)

            ucc_val_loss = F.cross_entropy(ucc_logits, batch_labels)
            # calculate accuracy
            _, ucc_predicts = torch.max(ucc_logits, dim=1)
            acc = torch.sum(
                ucc_predicts == batch_labels).item() / len(batch_labels)
            val_acc_list.append(acc)
            val_loss_list.append(ucc_val_loss.item())
    return np.mean(val_loss_list), np.mean(val_acc_list)


def train(args, model, optimizer, lr_scheduler, train_loader, val_loader, device):
    """Train ``model`` over ``train_loader`` with periodic validation.

    Every 10 steps the current batch's loss and accuracy are logged to
    MLflow. Every ``args.save_interval`` steps the model is evaluated on
    ``val_loader``: a new best validation accuracy logs the model to MLflow
    and resets the early-stopping counter, otherwise the counter is
    decremented and training stops once it reaches zero.

    Args:
        args: run arguments; only ``args.save_interval`` is read here.
        model: exposes ``alpha`` and ``compute_loss`` and is called as
            ``model(samples, labels)``. With ``alpha == 1`` the call returns
            the logits alone, otherwise ``(logits, reconstruction)``.
        optimizer: stepped once per batch.
        lr_scheduler: currently unused; kept so existing callers keep working.
        train_loader: iterable of ``(batch_samples, batch_labels)``.
        val_loader: forwarded to ``evaluate``.
        device: device each batch is moved to.

    Returns:
        The best validation accuracy observed (``0`` if none exceeded it).
    """
    model.train()
    step = 0
    best_eval_acc = 0
    # NOTE(review): the counter starts at 10 evaluations but is reset to 2
    # after every improvement — confirm this asymmetry is intended.
    patience = 10
    # Hard upper bound on training steps. It is only checked on evaluation
    # steps, so it takes effect only if it is a multiple of save_interval.
    max_steps = 10000
    for batch_samples, batch_labels in tqdm(train_loader):
        batch_samples = batch_samples.to(device)
        batch_labels = batch_labels.to(device)
        optimizer.zero_grad()

        if model.alpha == 1:
            ucc_logits = model(batch_samples, batch_labels)
            loss: torch.Tensor = model.compute_loss(
                labels=batch_labels,
                output=ucc_logits
            )
        else:
            ucc_logits, reconstruction = model(batch_samples, batch_labels)
            loss = model.compute_loss(
                inputs=batch_samples,
                labels=batch_labels,
                output=ucc_logits,
                reconstruction=reconstruction
            )

        loss.backward()

        optimizer.step()
        step += 1

        if step % 10 == 0:
            # Training metrics are taken from the current batch only.
            with torch.no_grad():
                _, pred = torch.max(ucc_logits, dim=1)
                accuracy = torch.sum(
                    pred.flatten() == batch_labels.flatten())/len(batch_labels)
            mlflow.log_metrics({"train_ucc_loss": loss.detach(
            ).item(), "train_ucc_acc": float(accuracy)}, step=step)

        if step % args.save_interval == 0:
            eval_loss, eval_acc = evaluate(model, val_loader, device)
            print(
                f"step: {step}, eval loss: {eval_loss}, eval acc: {eval_acc}")
            mlflow.log_metrics(
                {"eval_ucc_loss": eval_loss.item(), "eval_ucc_acc": float(eval_acc)}, step=step)
            # early stop
            if eval_acc > best_eval_acc:
                patience = 2
                best_eval_acc = eval_acc
                # Log the new best model as an MLflow artifact.
                mlflow.pytorch.log_model(
                    model,
                    "best_model.pth"
                )
            else:
                patience -= 1

            if patience <= 0:
                break
            if step == max_steps:
                break
            # evaluate() leaves the model in eval mode; switch back.
            model.train()

    print("Training finished!!!")
    return best_eval_acc


# @hydra.main(version_base=None, config_path="../configs", config_name="train_drn.yaml")
# def main(cfg: DictConfig) -> None:
#     # with mlflow.start_run():

#     device = torch.device("cuda" if torch.cuda.is_available() else "mps")
#     print("device:", device)
#     args = cfg.args
#     print("args: ", args)
#     print("model: ", cfg.model)
#     # set random seed
#     set_random_seed(args.seed)
#     # set model save dir
#     if not os.path.exists(args.model_dir):
#         os.makedirs(args.model_dir)
#     # init model and optimizer
#     model, optimizer = init_model_and_optimizer(args, cfg, device)
#     train_loader, val_loader = init_dataloader(args)
#     train(args, model, optimizer, train_loader, val_loader, device)
#
def objective(trial: optuna.Trial):
    """Optuna objective: train one model and return its best validation accuracy.

    Runs inside a nested MLflow run. The ``train_drn`` Hydra config is
    composed, each hyperparameter below is either pinned to its ``value`` or
    sampled from ``range`` via ``trial``, and the result is written to every
    config path listed in ``aliases``. The resolved config is logged to
    MLflow as ``config.yaml`` before training.

    Args:
        trial: the Optuna trial used to sample the searched hyperparameters.

    Returns:
        The best validation accuracy reported by ``train`` as a ``float``.
        NOTE(review): this assumes the study maximizes its objective — the
        study is built in ``utils``; confirm its direction there.
    """
    with mlflow.start_run(nested=True):
        with initialize(version_base=None, config_path="../configs"):
            cfg = compose(config_name="train_drn")

        # An entry with a "value" key is fixed; an entry without one is
        # sampled from "range" (currently only "lr").
        defaults = {
            "num_bins": {
                "type": "int",
                "value": 10,
                "range": [5, 100],
                "aliases": [
                    "model.drn.num_bins",
                    "args.num_bins",
                    "model.kde_model.num_bins"
                ]
            },
            "lr": {
                "type": "float",
                "range": [0.008, 0.08],
                "aliases": ["args.learning_rate"]
            },
            "hidden_q": {
                "type": "int",
                "value": 93,
                "range": [4, 100],
                "aliases": ["model.drn.hidden_q"]
            },
            "num_layers": {
                "type": "int",
                "value": 2,
                "range": [1, 10],
                "aliases": ["model.drn.num_layers"]
            },
            "num_nodes": {
                "type": "int",
                "value": 9,
                "range": [1, 10],
                "aliases": ["model.drn.num_nodes"]
            }
        }
        for key, spec in defaults.items():
            if "value" in spec:
                v = spec["value"]
            elif spec["type"] == "int":
                v = trial.suggest_int(key, spec["range"][0], spec["range"][1])
            else:
                v = trial.suggest_float(key, spec["range"][0], spec["range"][1])
            for alias in spec["aliases"]:
                # Dotted-path assignment without exec(); like attribute
                # assignment, this raises if the key is absent from the
                # struct-mode config.
                OmegaConf.update(cfg, alias, v)

        print(cfg)
        mlflow.log_dict(dict(OmegaConf.to_object(cfg)), "config.yaml")

        args = cfg.args
        # Prefer CUDA, then Apple MPS when it is actually available, and
        # fall back to CPU instead of failing on machines with neither.
        mps_backend = getattr(torch.backends, "mps", None)
        if torch.cuda.is_available():
            device = torch.device("cuda")
        elif mps_backend is not None and mps_backend.is_available():
            device = torch.device("mps")
        else:
            device = torch.device("cpu")
        model, optimizer = init_model_and_optimizer(args, cfg, device)
        train_loader, val_loader = init_dataloader(args)
        print(cfg)
        best_acc = train(args, model, optimizer, None,
                         train_loader, val_loader, device)
        # Optuna needs the objective value; without a return every trial
        # yields None.
        return float(best_acc)




# Point MLflow at the "mlruns" tracking URI.
mlflow.set_tracking_uri("mlruns")
# NOTE: despite its name, run_name is used below as the MLflow experiment
# name and as the Optuna study name.
run_name = "ucc-drn-bin-lr-2"
experiment_id = get_or_create_experiment(experiment_name=run_name)
mlflow.set_experiment(experiment_id=experiment_id)

# Build the Optuna study via the project helper.
# NOTE(review): presumably this seeds the study from runs already logged under
# the experiment, using "params-bin.json" — confirm in utils.
study = parse_experiment_runs_to_optuna_study(
    experiment_name=run_name,
    study_name=run_name,
    cfg_name="train_drn",
    params_file="params-bin.json"
)
# Run 100 trials of `objective`; each trial opens its own nested MLflow run.
study.optimize(func=objective, n_trials=100, show_progress_bar=True)


  from .autonotebook import tqdm as notebook_tqdm
[I 2025-04-24 02:53:17,623] A new study created in memory with name: ucc-drn-bin-lr-2
  0%|          | 0/100 [00:00<?, ?it/s]

{'args': {'dataset': 'mnist', 'model_dir': 'saved_models/', 'model_name': 'mnist_ucc_drn', 'num_instances': 32, 'ucc_start': 1, 'ucc_end': 4, 'batch_size': 20, 'num_samples_per_class': 5, 'num_workers': 4, 'learning_rate': 0.04388685685981109, 'num_bins': 10, 'num_features': 10, 'train_num_steps': 100000, 'val_num_steps': 200, 'save_interval': 1000, 'seed': 22}, 'model': {'num_channels': 1, 'input_shape': [28, 28, 1], 'kde_model': {'num_bins': 10, 'sigma': 0.1}, 'encoder': {'conv_input_channel': 1, 'conv_output_channel': 16, 'block1_output_channel': 321, 'block1_num_layer': 1, 'block2_output_channel': 64, 'block2_num_layer': 1, 'block3_output_channel': 128, 'block3_num_layer': 1, 'flatten_size': 6272, 'num_features': 10}, 'decoder': 'None', 'drn': {'num_bins': 10, 'hidden_q': 93, 'num_layers': 2, 'num_nodes': 9, 'output_bins': 4}, 'ucc_classifier': 'None', 'loss': {'alpha': 1}}, 'experiments': {'num_channels': 1, 'input_shape': [28, 28, 1], 'kde_model': {'num_bins': 11, 'sigma': 0.1}, 




ConfigAttributeError: Key 'init_method' is not in struct
    full_key: model.drn.init_method
    object_type=dict