In [1]:
import mlflow

# Use the local ./mlruns file store explicitly: the checkpoint-loading cell below builds
# artifact paths under mlruns/ directly, so MLflow must be pointed at that same store.
mlflow.set_tracking_uri("mlruns")
experiment_name = "ucc-drn-goated-init"
experiment = mlflow.get_experiment_by_name(name=experiment_name)
# get_experiment_by_name returns None (it does not raise) when the name is unknown.
if experiment is None:
    raise ValueError(
        f"MLflow experiment {experiment_name!r} not found in tracking store {mlflow.get_tracking_uri()!r}"
    )
experiment_id = experiment.experiment_id
mlflow.set_experiment(experiment_id=experiment_id)
# Run whose best model / optimizer checkpoint is resumed below.
run_id = "ae9704aec54d48729f10dcb726d0125f"

In [2]:
import torch
from torch.optim import Adam

# Same device preference as the training cell, with a CPU fallback when no accelerator exists.
if torch.cuda.is_available():
    device = torch.device("cuda")
elif torch.backends.mps.is_available():
    device = torch.device("mps")
else:
    device = torch.device("cpu")

artifact_dir = f"mlruns/{experiment_id}/{run_id}/artifacts"

# NOTE: weights_only=False unpickles arbitrary Python objects -- only load artifacts you produced yourself.
model = torch.load(f"{artifact_dir}/best_model/data/model.pth", map_location=device, weights_only=False)

# The optimizer was pickled in a different file from the model, so after loading, its
# param_groups reference *different* tensor objects than model.parameters(). Stepping it
# directly would never update `model`, and its zero_grad() would never clear the model's grads.
# Rebuild the optimizer over the live model parameters and transplant the saved state
# (moment estimates, step counts, hyperparameters) through state_dict / load_state_dict.
# Assumes the original optimizer was built from model.parameters() in this same order.
saved_optimizer = torch.load(
    f"{artifact_dir}/best_optimizer.pth/best_optimizer.pth", map_location=device, weights_only=False
)
optimizer = type(saved_optimizer)(model.parameters(), **saved_optimizer.defaults)
optimizer.load_state_dict(saved_optimizer.state_dict())
del saved_optimizer  # drop the detached duplicate parameter tensors


In [3]:
# Override the checkpoint's learning rate for this fine-tuning run. Every param group is
# updated (not only the last one) so no group silently keeps the old rate.
LEARNING_RATE = 5e-4
for param_group in optimizer.param_groups:
    param_group["lr"] = LEARNING_RATE

In [4]:
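# Training utilities (data loading, evaluation, training loop) plus a fine-tuning run of the loaded checkpoint.
# Originally written to search for the optimal mean and variance of the normal weight initialization;
# optuna is imported here but not used in this cell.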
from copy import deepcopy
from hydra import compose, initialize
import torch
import torch.nn as nn
import torch.nn.functional as F
import mlflow
import optuna
import numpy as np
from tqdm import tqdm
from typing import Tuple
from omegaconf.omegaconf import OmegaConf
from torch.utils.data import DataLoader
from torch.optim.lr_scheduler import CosineAnnealingLR, OneCycleLR

from model import UCCDRNModel
from dataset import MnistDataset
from utils import get_or_create_experiment, parse_experiment_runs_to_optuna_study
# Autograd anomaly detection checks every backward op for NaN/inf and records forward
# tracebacks, which slows down every training step. Enable it only while debugging.
DETECT_ANOMALY = False
torch.autograd.set_detect_anomaly(DETECT_ANOMALY)


    
def set_random_seed(seed):
    """Seed every RNG the training pipeline draws from so runs are repeatable.

    Seeds Python's ``random``, NumPy's global RNG and torch (CPU and all CUDA devices),
    and configures cuDNN for deterministic kernel selection.

    Args:
        seed: integer seed applied to every generator.
    """
    import random

    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    # manual_seed_all seeds every visible GPU (not just the current one); it is
    # silently ignored when CUDA is unavailable.
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    # cuDNN autotuning may choose different algorithms from run to run.
    torch.backends.cudnn.benchmark = False

def init_dataloader(args):
    """Create the training and validation DataLoaders for the MNIST UCC task.

    Both splits are ``MnistDataset`` instances over all ten digits, configured from ``args``.
    Each split's ``length`` is ``<split>_num_steps * batch_size``. Presumably this makes one
    pass yield ``<split>_num_steps`` batches, but that depends on how MnistDataset uses ``length``.

    Returns:
        ``(train_loader, val_loader)``; only the training loader shuffles.
    """
    def build_dataset(mode, num_steps):
        return MnistDataset(
            mode=mode,
            num_instances=args.num_instances,
            num_samples_per_class=args.num_samples_per_class,
            digit_arr=list(range(10)),
            ucc_start=args.ucc_start,
            ucc_end=args.ucc_end,
            length=num_steps * args.batch_size,
        )

    def build_loader(dataset, shuffle):
        return DataLoader(
            dataset,
            batch_size=args.batch_size,
            num_workers=args.num_workers,
            shuffle=shuffle,
        )

    # Datasets are built first (train, then val), then wrapped in loaders.
    train_dataset = build_dataset("train", args.train_num_steps)
    val_dataset = build_dataset("val", args.val_num_steps)
    return build_loader(train_dataset, shuffle=True), build_loader(val_dataset, shuffle=False)

def evaluate(model, val_loader, device):
    """Score ``model`` on every batch of ``val_loader`` without tracking gradients.

    Each batch contributes three values:
    - cross-entropy of the UCC logits;
    - MSE between the inputs and their reconstruction;
    - top-1 accuracy.

    Every metric is an unweighted mean of its per-batch values, rounded to 5 decimals.
    The model is left in eval mode; callers resume training with ``model.train()``.

    Returns:
        dict with keys ``eval_ae_loss``, ``eval_ucc_loss`` and ``eval_ucc_acc``.
    """
    model.eval()
    history = {"eval_ae_loss": [], "eval_ucc_loss": [], "eval_ucc_acc": []}
    with torch.no_grad():
        for inputs, targets in val_loader:
            inputs = inputs.to(device)
            targets = targets.to(device)
            logits, recon = model(inputs, return_reconstruction=True)

            history["eval_ucc_loss"].append(F.cross_entropy(logits, targets).item())
            history["eval_ae_loss"].append(F.mse_loss(inputs, recon).item())

            predicted = logits.max(dim=1).indices
            history["eval_ucc_acc"].append((predicted == targets).sum().item() / len(targets))

    return {key: np.round(np.mean(values), 5) for key, values in history.items()}

def _iterate_epochs(loader, num_epochs):
    """Yield every batch of ``loader``, making ``num_epochs`` fresh passes over it."""
    for _ in range(num_epochs):
        yield from loader


def _log_train_metrics(model, step, ucc_loss, ae_loss, loss, ucc_logits, batch_labels):
    """Log each parameter's mean gradient and the step's training losses/accuracy to MLflow; return the loss/acc dict."""
    with torch.no_grad():
        grad_log = {
            name: torch.mean(param.grad).cpu().item()
            for name, param in model.named_parameters()
            if isinstance(param.grad, torch.Tensor)
        }
        mlflow.log_metrics(grad_log, step=step)
        _, pred = torch.max(ucc_logits, dim=1)
        accuracy = torch.sum(pred.flatten() == batch_labels.flatten()) / len(batch_labels)
        metric_dict = {
            "train_ae_loss": np.round(ae_loss.detach().item(), 5),
            "train_ucc_loss": np.round(ucc_loss.detach().item(), 5),
            "train_ucc_acc": np.round(float(accuracy), 5),
            "loss": np.round(float(loss), 5),
        }
    mlflow.log_metrics(metric_dict, step=step)
    return metric_dict


def _save_best_checkpoint(model, optimizer):
    """Log ``model`` and the pickled ``optimizer`` to the active MLflow run as the current best."""
    mlflow.pytorch.log_model(model, "best_model")
    # The whole optimizer object is pickled (not only its state_dict). Because the
    # artifact_path is also "best_optimizer.pth", the file lands at
    # artifacts/best_optimizer.pth/best_optimizer.pth.
    torch.save(optimizer, "best_optimizer.pth")
    mlflow.log_artifact("best_optimizer.pth", "best_optimizer.pth")


def train(args, model, optimizer, lr_scheduler, train_loader, val_loader, device,
          num_epochs=10, max_steps=200000):
    """Train ``model``, periodically evaluating it and checkpointing the best one to MLflow.

    Args:
        args: config node; only ``args.save_interval`` (steps between evaluations) is read here.
        model: module called as ``model(x, return_reconstruction=True)`` -> ``(logits, reconstruction)``
            and providing ``compute_loss(..., return_losses=True)`` -> ``(ucc_loss, ae_loss, loss)``.
        optimizer: optimizer over ``model``'s parameters.
        lr_scheduler: optional scheduler, stepped once per optimizer step (suits per-iteration
            schedules such as OneCycleLR); ``None`` disables it.
        train_loader: iterable of ``(samples, labels)`` training batches.
        val_loader: iterable of ``(samples, labels)`` validation batches passed to ``evaluate``.
        device: device each batch is moved to.
        num_epochs: number of passes over ``train_loader`` (default 10).
        max_steps: hard cap on optimizer steps across all epochs (default 200000).

    Returns:
        The best validation UCC accuracy observed (0 if no evaluation ran).
    """
    print("training")
    model.train()
    step = 0
    best_eval_acc = 0
    # Log the starting weights so the run always has a "best_model" artifact,
    # even if no evaluation ever beats accuracy 0.
    mlflow.pytorch.log_model(model, "best_model")

    for batch_samples, batch_labels in _iterate_epochs(train_loader, num_epochs):
        batch_samples = batch_samples.to(device)
        batch_labels = batch_labels.to(device)

        optimizer.zero_grad()
        ucc_logits, reconstruction = model(batch_samples, return_reconstruction=True)
        ucc_loss, ae_loss, loss = model.compute_loss(
            batch_samples, batch_labels, ucc_logits, reconstruction, return_losses=True
        )
        loss.backward()
        optimizer.step()
        if lr_scheduler is not None:
            lr_scheduler.step()
        step += 1

        if step % 10 == 0:
            metric_dict = _log_train_metrics(model, step, ucc_loss, ae_loss, loss, ucc_logits, batch_labels)
            if step % 100 == 0:
                print(metric_dict)

        if step % args.save_interval == 0:
            eval_metric_dict = evaluate(model, val_loader, device)
            print(f"step: {step}," + ",".join([f"{key}: {value}" for key, value in eval_metric_dict.items()]))
            mlflow.log_metrics(eval_metric_dict, step=step)
            # Checkpoint on improvement (training itself continues regardless).
            eval_acc = eval_metric_dict["eval_ucc_acc"]
            if eval_acc > best_eval_acc:
                best_eval_acc = eval_acc
                mlflow.log_metric("best_eval_acc", best_eval_acc, step=step)
                _save_best_checkpoint(model, optimizer)
            # evaluate() leaves the model in eval mode.
            model.train()

        # Single flat loop, so this stops training entirely (not just the current epoch).
        if step >= max_steps:
            break

    print("Training finished!!!")
    return best_eval_acc


# Local file store; checkpoint paths under mlruns/ are also read directly when loading.
mlflow.set_tracking_uri("mlruns")
# MLflow *experiment* name (the per-run name is passed to start_run below).
run_name = "ucc-drn-goated-init"
experiment_id = get_or_create_experiment(experiment_name=run_name)
mlflow.set_experiment(experiment_id=experiment_id)
cfg_name = "train_drn"
with initialize(version_base=None, config_path="../configs"):
    cfg = compose(config_name=cfg_name)

# Fine-tune the `model` and `optimizer` loaded in the earlier cells. The learning rate is read
# back from the optimizer so the run name and the logged param always match what is used.
learning_rate = optimizer.param_groups[0]["lr"]
with mlflow.start_run(nested=True, run_name=f"lr-{learning_rate}"):
    mlflow.log_dict(dict(OmegaConf.to_object(cfg)), "config.yaml")
    args = cfg.args
    # Seed before building datasets/loaders; falls back to 0 when the config has no args.seed.
    set_random_seed(args.get("seed", 0))
    if torch.cuda.is_available():
        device = torch.device("cuda")
    elif torch.backends.mps.is_available():
        device = torch.device("mps")
    else:
        device = torch.device("cpu")
    # NOTE(review): `model` is not moved here; it is assumed to already live on `device`.
    train_loader, val_loader = init_dataloader(args)
    mlflow.pytorch.log_model(model, "init_model")
    mlflow.log_params({"learning_rate": learning_rate})
    best_acc = train(args, model, optimizer, None,
                     train_loader, val_loader, device)


x_train shape: torch.Size([50000, 1, 28, 28])
50000 train samples
10000 val samples
10000 test samples
x_train shape: torch.Size([50000, 1, 28, 28])
50000 train samples
10000 val samples
10000 test samples




training




{'train_ae_loss': np.float64(0.03864), 'train_ucc_loss': np.float64(0.38644), 'train_ucc_acc': np.float64(1.0), 'loss': np.float64(0.42508)}
{'train_ae_loss': np.float64(0.04037), 'train_ucc_loss': np.float64(0.40504), 'train_ucc_acc': np.float64(0.95), 'loss': np.float64(0.44541)}
{'train_ae_loss': np.float64(0.03784), 'train_ucc_loss': np.float64(0.38035), 'train_ucc_acc': np.float64(1.0), 'loss': np.float64(0.41818)}
{'train_ae_loss': np.float64(0.04071), 'train_ucc_loss': np.float64(0.43261), 'train_ucc_acc': np.float64(0.9), 'loss': np.float64(0.47332)}
{'train_ae_loss': np.float64(0.04176), 'train_ucc_loss': np.float64(0.37945), 'train_ucc_acc': np.float64(1.0), 'loss': np.float64(0.42121)}
{'train_ae_loss': np.float64(0.04), 'train_ucc_loss': np.float64(0.37844), 'train_ucc_acc': np.float64(1.0), 'loss': np.float64(0.41844)}
{'train_ae_loss': np.float64(0.0393), 'train_ucc_loss': np.float64(0.39657), 'train_ucc_acc': np.float64(1.0), 'loss': np.float64(0.43587)}
{'train_ae_loss'



{'train_ae_loss': np.float64(0.04074), 'train_ucc_loss': np.float64(0.3792), 'train_ucc_acc': np.float64(1.0), 'loss': np.float64(0.41995)}
{'train_ae_loss': np.float64(0.03777), 'train_ucc_loss': np.float64(0.40414), 'train_ucc_acc': np.float64(0.95), 'loss': np.float64(0.44192)}
{'train_ae_loss': np.float64(0.04127), 'train_ucc_loss': np.float64(0.41025), 'train_ucc_acc': np.float64(0.95), 'loss': np.float64(0.45152)}
{'train_ae_loss': np.float64(0.04238), 'train_ucc_loss': np.float64(0.402), 'train_ucc_acc': np.float64(0.95), 'loss': np.float64(0.44439)}
{'train_ae_loss': np.float64(0.03865), 'train_ucc_loss': np.float64(0.37951), 'train_ucc_acc': np.float64(1.0), 'loss': np.float64(0.41816)}
{'train_ae_loss': np.float64(0.03634), 'train_ucc_loss': np.float64(0.40345), 'train_ucc_acc': np.float64(0.95), 'loss': np.float64(0.43979)}
{'train_ae_loss': np.float64(0.03688), 'train_ucc_loss': np.float64(0.41263), 'train_ucc_acc': np.float64(0.95), 'loss': np.float64(0.44951)}
{'train_ae_



{'train_ae_loss': np.float64(0.03968), 'train_ucc_loss': np.float64(0.38171), 'train_ucc_acc': np.float64(1.0), 'loss': np.float64(0.42139)}
{'train_ae_loss': np.float64(0.04042), 'train_ucc_loss': np.float64(0.38064), 'train_ucc_acc': np.float64(1.0), 'loss': np.float64(0.42106)}
{'train_ae_loss': np.float64(0.04045), 'train_ucc_loss': np.float64(0.423), 'train_ucc_acc': np.float64(0.95), 'loss': np.float64(0.46346)}
{'train_ae_loss': np.float64(0.03708), 'train_ucc_loss': np.float64(0.39815), 'train_ucc_acc': np.float64(0.95), 'loss': np.float64(0.43523)}
{'train_ae_loss': np.float64(0.04285), 'train_ucc_loss': np.float64(0.37944), 'train_ucc_acc': np.float64(1.0), 'loss': np.float64(0.42229)}
{'train_ae_loss': np.float64(0.0415), 'train_ucc_loss': np.float64(0.38316), 'train_ucc_acc': np.float64(1.0), 'loss': np.float64(0.42467)}
{'train_ae_loss': np.float64(0.04183), 'train_ucc_loss': np.float64(0.3874), 'train_ucc_acc': np.float64(1.0), 'loss': np.float64(0.42924)}
{'train_ae_loss

libc++abi: terminating due to uncaught exception of type std::__1::system_error: Broken pipe
libc++abi: terminating due to uncaught exception of type std::__1::system_error: Broken pipe
libc++abi: terminating due to uncaught exception of type std::__1::system_error: Broken pipe
libc++abi: terminating due to uncaught exception of type std::__1::system_error: Broken pipe
Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x140526840>
Traceback (most recent call last):
  File "/Users/tanguanyu/UCC-DRN-Pytorch/.venv/lib/python3.12/site-packages/torch/utils/data/dataloader.py", line 1604, in __del__
    self._shutdown_workers()
  File "/Users/tanguanyu/UCC-DRN-Pytorch/.venv/lib/python3.12/site-packages/torch/utils/data/dataloader.py", line 1568, in _shutdown_workers
    w.join(timeout=_utils.MP_STATUS_CHECK_INTERVAL)
  File "/opt/anaconda3/lib/python3.12/multiprocessing/process.py", line 149, in join
    res = self._popen.wait(timeout)
          ^^^^^^^^^^^^^^^^^^^^^^

KeyboardInterrupt: 

{'state': {0: {'step': tensor(100000.),
   'exp_avg': tensor([[[[-6.7333e-03, -6.2005e-03, -7.3149e-03],
             [-7.5662e-03, -5.9466e-03, -5.6723e-03],
             [-5.2717e-03, -1.3496e-03,  3.8254e-04]]],
   
   
           [[[-1.0657e-37, -1.0657e-37, -1.0657e-37],
             [-1.0657e-37, -1.0657e-37, -1.0657e-37],
             [ 1.1048e-37, -1.0657e-37, -1.0657e-37]]],
   
   
           [[[ 2.9916e-04,  2.9191e-04,  4.7625e-04],
             [ 3.5952e-04,  4.5829e-04,  4.8339e-04],
             [-4.9764e-04,  4.0587e-06, -1.4013e-03]]],
   
   
           [[[-4.9944e-04,  1.0160e-04,  1.1345e-04],
             [-4.6275e-04,  1.4977e-04, -1.1290e-04],
             [-2.9028e-04,  2.3929e-04,  1.9374e-04]]],
   
   
           [[[ 1.9362e-04,  3.4409e-05, -6.7650e-04],
             [-8.2952e-04,  1.4290e-04, -3.7770e-04],
             [-1.2114e-03,  9.9263e-04, -2.5519e-04]]],
   
   
           [[[ 7.5088e-03,  7.2409e-03,  8.4585e-04],
             [ 4.5008e-03,  4.0942e

0.005