In [29]:
import os
import ast
import json
import hydra
from hydra import compose, initialize
import torch
import torch.nn as nn
import torch.nn.functional as F

import mlflow
import optuna
import numpy as np
from tqdm import tqdm
from typing import List, Tuple
from omegaconf.omegaconf import OmegaConf
from torch.utils.data import DataLoader
from torch.optim import Adam, AdamW

from model import DRNOnlyModel, UCCModel
from dataset import MnistEncodedDataset
from omegaconf import DictConfig
from utils import get_or_create_experiment, parse_experiment_runs_to_optuna_study

# MLflow experiment that groups the "ucc-drn-distil" runs; created on first use if missing.
EXPERIMENT_NAME = "ucc-drn-distil"
experiment_id = get_or_create_experiment(EXPERIMENT_NAME)


In [None]:
# Fetch every run of the experiment as a list of mlflow Run objects (not a DataFrame),
# so individual runs can be inspected with `.to_dictionary()`.
runs = mlflow.search_runs(
    experiment_ids=[experiment_id],
    output_format="list",
)

In [3]:
# Peek at one run record to see its layout (info / data.metrics / data.params / data.tags).
first_run = runs[0]
first_run.to_dictionary()

{'info': {'artifact_uri': 'file:///D:/UCC-DRN-Pytorch/mnist/mlruns/380446538050732248/313d1bb3d6f446a2ab697d5f41c4c580/artifacts',
  'end_time': 1741185524741,
  'experiment_id': '380446538050732248',
  'lifecycle_stage': 'active',
  'run_id': '313d1bb3d6f446a2ab697d5f41c4c580',
  'run_name': 'brawny-bass-550',
  'run_uuid': '313d1bb3d6f446a2ab697d5f41c4c580',
  'start_time': 1741184857968,
  'status': 'FAILED',
  'user_id': 'guanyu'},
 'data': {'metrics': {'eval_ucc_acc': 0.66125,
   'eval_ucc_loss': 0.8382442593574524,
   'total_loss': 0.8076022267341614,
   'train_distilation_loss': 0.05767541006207466,
   'train_ucc_acc': 0.6500000357627869,
   'train_ucc_loss': 1.0575778484344482},
  'params': {'hidden_q': '16',
   'lr': '0.09165019656675866',
   'num_bins': '17',
   'num_layers': '1',
   'num_nodes': '6'},
  'tags': {'mlflow.log-model.history': '[{"run_id": "313d1bb3d6f446a2ab697d5f41c4c580", "artifact_path": "best_model.pth", "utc_time_created": "2025-03-05 14:28:57.248725", "mo

In [4]:
# Pick the run whose logged hyper-parameters we want to reuse.
# The `else` clause fires when no run matched: without it, `params` would silently
# hold the *last* run's params (or be undefined if `runs` is empty).
TARGET_RUN_NAME = "fortunate-mule-124"
for run in runs:
    dictionary = run.to_dictionary()
    run_name = dictionary["info"]["run_name"]
    params = dictionary["data"]["params"]
    if run_name == TARGET_RUN_NAME:
        break
else:
    raise ValueError(
        f"No run named {TARGET_RUN_NAME!r} found in experiment {experiment_id}"
    )

In [22]:
import json

# Search-space definition: parameter name -> {"type", "range", "aliases"}, where the
# aliases are the config paths each tuned value is written to.
with open("params.json", "r") as fh:
    params_json = json.load(fh)

In [23]:
# Show the search space: each tuned parameter name on one line, its spec on the next.
for name, spec in params_json.items():
    print(name, spec, sep="\n")

num_bins
{'type': 'int', 'range': [5, 30], 'aliases': ['model.drn.num_bins', 'args.num_bins', 'model.kde_model.num_bins']}
lr
{'type': 'float', 'range': [0.0001, 0.1], 'aliases': ['args.learning_rate']}
hidden_q
{'type': 'int', 'range': [4, 100], 'aliases': ['model.drn.hidden_q']}
num_layers
{'type': 'int', 'range': [1, 10], 'aliases': ['model.drn.num_layers']}
num_nodes
{'type': 'int', 'range': [1, 10], 'aliases': ['model.drn.num_nodes']}
lr_multiplier
{'type': 'float', 'range': [1, 2], 'aliases': ['args.lr_multiplier']}


In [24]:
# MLflow logs params as strings. Cast each one using the type declared for it in
# params.json ("float" -> float, otherwise int). Params absent from params.json keep
# the previous rule (only "lr" is a float).
def _cast_param(name, raw):
    declared = params_json.get(name, {}).get("type")
    if declared is None:
        declared = "float" if name == "lr" else "int"
    return float(raw) if declared == "float" else int(raw)


params = {name: _cast_param(name, raw) for name, raw in params.items()}

In [25]:
from hydra import compose, initialize
from omegaconf import OmegaConf


with initialize(version_base=None, config_path="../configs"):
    cfg = compose(config_name="train_drn")

# Tuned parameters that are deliberately not copied into the composed config.
SKIPPED_PARAMS = {"lr_multiplier"}

# Write each tuned value into every config path (alias) it controls.
# OmegaConf.update accepts dotted keys directly, so no string-built `exec` is needed;
# it still raises on keys missing from the struct-mode config, as attribute assignment did.
for p, values in params_json.items():
    if p in SKIPPED_PARAMS:
        continue
    for a in values["aliases"]:
        OmegaConf.update(cfg, a, params[p])

In [33]:
def load_parent_model(device, model_path=os.path.join("outputs", "2024-03-01", "14-44-08")):
    """Load the trained UCC teacher model from a Hydra run directory.

    Args:
        device: torch device the checkpoint is mapped onto and the model is moved to.
        model_path: Hydra output directory containing ``.hydra/config.yaml`` and
            ``mnist_ucc_best.pth``. Defaults to the original run directory, built
            with ``os.path.join`` so it works on both Windows and POSIX.

    Returns:
        A ``UCCModel`` with the checkpoint's ``model_state_dict`` loaded, placed on
        ``device`` and switched to eval mode.
    """
    ucc_cfg = OmegaConf.load(os.path.join(model_path, ".hydra", "config.yaml"))
    model = UCCModel(ucc_cfg)
    # map_location lets a checkpoint saved on CUDA load on an MPS- or CPU-only machine.
    # NOTE(security): weights_only=False unpickles arbitrary objects; only load trusted checkpoints.
    checkpoint = torch.load(
        os.path.join(model_path, "mnist_ucc_best.pth"),
        map_location=device,
        weights_only=False,
    )
    model.load_state_dict(checkpoint["model_state_dict"])
    parent_model = model.to(device)
    parent_model.eval()
    return parent_model

In [1]:
def init_model_and_optimizer(args, model_cfg, device):
    """Create a ``DRNOnlyModel`` on ``device`` together with an AdamW optimizer.

    Args:
        args: config section providing ``learning_rate``.
        model_cfg: config passed unchanged to ``DRNOnlyModel``.
        device: device the model is moved to.

    Returns:
        ``(model, optimizer)`` where the optimizer covers all model parameters.
    """
    drn = DRNOnlyModel(model_cfg)
    drn = drn.to(device)
    drn_optimizer = AdamW(drn.parameters(), lr=args.learning_rate)
    return drn, drn_optimizer


def init_dataloader(args):
    """Build the training and validation DataLoaders over ``MnistEncodedDataset``.

    Each split's dataset length is ``<split>_num_steps * batch_size``, so one pass
    over a loader yields ``<split>_num_steps`` batches. Only the training loader
    shuffles.

    Args:
        args: config section with ``num_instances``, ``num_samples_per_class``,
            ``ucc_start``, ``ucc_end``, ``batch_size``, ``num_workers``,
            ``train_num_steps`` and ``val_num_steps``.

    Returns:
        ``(train_loader, val_loader)``.
    """

    def make_loader(mode, num_steps, shuffle):
        # Both splits share every dataset/loader setting except mode, length and shuffling.
        dataset = MnistEncodedDataset(
            mode=mode,
            num_instances=args.num_instances,
            num_samples_per_class=args.num_samples_per_class,
            digit_arr=list(range(args.ucc_end - args.ucc_start + 1)),
            ucc_start=args.ucc_start,
            ucc_end=args.ucc_end,
            length=num_steps * args.batch_size,
        )
        return DataLoader(
            dataset,
            batch_size=args.batch_size,
            num_workers=args.num_workers,
            shuffle=shuffle,
        )

    train_loader = make_loader("train", args.train_num_steps, shuffle=True)
    val_loader = make_loader("val", args.val_num_steps, shuffle=False)
    return train_loader, val_loader


def evaluate(model, parent_model, val_loader, device) -> Tuple[float, float]:
    """Compute the UCC cross-entropy loss and accuracy of ``model`` on ``val_loader``.

    Both metrics are averaged over samples rather than over batches, so a
    shorter final batch is not over-weighted.

    Args:
        model: network mapping a batch of samples to class logits (classes on dim 1).
        parent_model: unused here; kept so existing call sites keep working.
        val_loader: iterable of ``(samples, labels)`` batches.
        device: device each batch is moved to before the forward pass.

    Returns:
        ``(mean_loss, accuracy)`` as Python floats.

    Raises:
        ValueError: if ``val_loader`` yields no samples.

    Note:
        Leaves ``model`` in eval mode; callers that continue training must call
        ``model.train()`` afterwards.
    """
    model.eval()
    total_loss = 0.0
    total_correct = 0
    total_seen = 0
    with torch.no_grad():
        for batch_samples, batch_labels in val_loader:
            batch_samples = batch_samples.to(device)
            batch_labels = batch_labels.to(device)
            ucc_logits = model(batch_samples)
            # reduction="sum" so the final division yields a per-sample mean.
            total_loss += F.cross_entropy(ucc_logits, batch_labels, reduction="sum").item()
            total_correct += (ucc_logits.argmax(dim=1) == batch_labels).sum().item()
            total_seen += len(batch_labels)
    if total_seen == 0:
        raise ValueError("evaluate() received an empty val_loader")
    return total_loss / total_seen, total_correct / total_seen


def _distillation_loss(student_logits, teacher_logits, temperature):
    """Temperature-scaled KL(teacher || student), averaged over the batch and scaled by T**2."""
    soft_targets = F.softmax(teacher_logits / temperature, dim=-1)
    soft_log_probs = F.log_softmax(student_logits / temperature, dim=-1)
    return F.kl_div(soft_log_probs, soft_targets, reduction="batchmean") * temperature ** 2


def _batch_accuracy(logits, labels):
    """Fraction of samples whose argmax over dim 1 equals the label, as a Python float."""
    with torch.no_grad():
        return (logits.argmax(dim=1).flatten() == labels.flatten()).float().mean().item()


def train(args, model, parent_model, optimizer, lr_scheduler, train_loader, val_loader, device,
          initial_patience: int = 100, patience: int = 10,
          distill_weight: float = 0.0, temperature: float = 2.0,
          restore_best: bool = True):
    """Train ``model`` with periodic validation and early stopping.

    Every ``args.save_interval`` steps the model is evaluated on ``val_loader``.
    Training stops once validation accuracy fails to improve for ``patience``
    consecutive evaluations (``initial_patience`` before the first improvement),
    or when ``train_loader`` is exhausted.

    Args:
        args: config section; only ``save_interval`` is read here.
        model: student network providing ``compute_loss(outputs=..., labels=...)``.
        parent_model: teacher exposing ``kde``, ``ucc_classifier``, ``num_nodes``
            and ``sigma``. Only run when ``distill_weight > 0``; may be ``None``
            otherwise.
        optimizer: optimizer over ``model``'s parameters.
        lr_scheduler: optional scheduler stepped once per training step, or ``None``.
        train_loader: iterable of ``(samples, labels)`` training batches.
        val_loader: iterable of ``(samples, labels)`` validation batches.
        device: device each batch is moved to.
        initial_patience: non-improving evaluations allowed before the first improvement.
        patience: non-improving evaluations allowed after any improvement.
        distill_weight: weight ``w`` of the distillation term; the training loss is
            ``(1 - w) * ucc_loss + w * distill_loss``. The default ``0`` trains on the
            UCC loss alone and skips the (expensive) teacher forward pass.
        temperature: softmax temperature of the distillation term.
        restore_best: reload the weights from the best evaluation before returning,
            so ``drn_model`` matches ``best_acc``.

    Returns:
        dict with ``drn_model`` (the model), ``best_acc`` (best validation accuracy,
        ``0`` if never evaluated) and ``train_logs`` (metrics recorded every 10 steps).

    Raises:
        ValueError: if ``distill_weight > 0`` but ``parent_model`` is ``None``.
    """
    if distill_weight > 0 and parent_model is None:
        raise ValueError("distill_weight > 0 requires a parent_model")
    if parent_model is not None:
        parent_model.eval()
    model.train()
    best_eval_acc = 0
    best_state = None
    patience_left = initial_patience
    train_logs = []
    for step, (batch_samples, batch_labels) in enumerate(tqdm(train_loader), start=1):
        batch_samples = batch_samples.to(device)
        batch_labels = batch_labels.to(device)
        optimizer.zero_grad()

        ucc_logits = model(batch_samples)
        ucc_loss = model.compute_loss(
            outputs=ucc_logits,
            labels=batch_labels,
        )
        if distill_weight > 0:
            # Teacher is frozen: no gradients flow into it.
            with torch.no_grad():
                parent_logits = parent_model.ucc_classifier(parent_model.kde(
                    batch_samples, parent_model.num_nodes, parent_model.sigma))
            distill_loss = _distillation_loss(ucc_logits, parent_logits, temperature)
            loss = (1 - distill_weight) * ucc_loss + distill_weight * distill_loss
        else:
            distill_loss = None
            loss = ucc_loss

        loss.backward()
        optimizer.step()
        if lr_scheduler is not None:
            lr_scheduler.step()

        if step % 10 == 0:
            log = {
                "train_ucc_loss": ucc_loss.detach().item(),
                "train_ucc_acc": _batch_accuracy(ucc_logits, batch_labels),
            }
            if distill_loss is not None:
                # Key spelling matches the metric names already logged to MLflow.
                log["train_distilation_loss"] = distill_loss.detach().item()
                log["total_loss"] = loss.detach().item()
            train_logs.append(log)

        if step % args.save_interval == 0:
            eval_loss, eval_acc = evaluate(
                model, parent_model, val_loader, device)
            print(
                f"step: {step}, eval loss: {eval_loss}, eval acc: {eval_acc}")
            if eval_acc > best_eval_acc:
                best_eval_acc = eval_acc
                patience_left = patience
                if restore_best:
                    # Snapshot detached copies; state_dict() tensors alias live parameters.
                    best_state = {k: v.detach().clone() for k, v in model.state_dict().items()}
            else:
                patience_left -= 1
                if patience_left <= 0:
                    break
            # evaluate() leaves the model in eval mode.
            model.train()
    if restore_best and best_state is not None:
        model.load_state_dict(best_state)
    print("Training finished!!!")
    return {
        "drn_model": model,
        "best_acc": best_eval_acc,
        "train_logs": train_logs,
    }

NameError: name 'Tuple' is not defined

In [35]:
args = cfg.args
# Prefer CUDA, then Apple MPS, and fall back to CPU instead of failing on machines with neither.
if torch.cuda.is_available():
    device = torch.device("cuda")
elif torch.backends.mps.is_available():
    device = torch.device("mps")
else:
    device = torch.device("cpu")
model, optimizer = init_model_and_optimizer(args, cfg, device)
parent_model = load_parent_model(device)
train_loader, val_loader = init_dataloader(args)
output = train(args, model, parent_model, optimizer, None, train_loader, val_loader, device)

(55, 5)
(55, 55)
(4, 55)
x_train shape: torch.Size([50000, 1, 28, 28])
50000 train samples
10000 val samples
x_train shape: torch.Size([50000, 10])
50000 train samples
10000 val samples
x_train shape: torch.Size([50000, 1, 28, 28])
50000 train samples
10000 val samples
x_train shape: torch.Size([50000, 10])
50000 train samples
10000 val samples


  1%|          | 1001/100000 [02:35<71:04:59,  2.58s/it] 

step: 1000, eval loss: 0.8058490535616875, eval acc: 0.9890000000000001


  2%|▏         | 2001/100000 [04:59<67:21:12,  2.47s/it]

step: 2000, eval loss: 0.7606072467565537, eval acc: 0.9980000000000001


  3%|▎         | 3001/100000 [07:25<66:29:30,  2.47s/it]

step: 3000, eval loss: 0.7552320855855942, eval acc: 0.9990000000000001


  4%|▍         | 4001/100000 [09:48<66:05:58,  2.48s/it]

step: 4000, eval loss: 0.7542347142100334, eval acc: 0.998


  5%|▌         | 5001/100000 [12:12<62:58:32,  2.39s/it]

step: 5000, eval loss: 0.7536104601621628, eval acc: 0.9975


  6%|▌         | 6001/100000 [14:36<63:30:12,  2.43s/it]

step: 6000, eval loss: 0.7513154190778732, eval acc: 0.9982499999999999


  7%|▋         | 7001/100000 [17:00<62:51:49,  2.43s/it]

step: 7000, eval loss: 0.7496176418662072, eval acc: 0.9990000000000001


  8%|▊         | 8001/100000 [19:24<62:22:36,  2.44s/it]

step: 8000, eval loss: 0.7519190725684166, eval acc: 0.99675


  9%|▉         | 9001/100000 [21:49<61:03:56,  2.42s/it]

step: 9000, eval loss: 0.7498063942790032, eval acc: 0.9990000000000001


 10%|█         | 10001/100000 [24:12<60:20:12,  2.41s/it]

step: 10000, eval loss: 0.749206371307373, eval acc: 0.9997499999999999


 11%|█         | 11001/100000 [26:36<62:19:50,  2.52s/it]

step: 11000, eval loss: 0.7504173627495766, eval acc: 0.998


 12%|█▏        | 12001/100000 [29:00<60:01:43,  2.46s/it]

step: 12000, eval loss: 0.7498601135611535, eval acc: 0.9984999999999999


 13%|█▎        | 13001/100000 [31:24<59:13:31,  2.45s/it]

step: 13000, eval loss: 0.7549134087562561, eval acc: 0.99625


 14%|█▍        | 14001/100000 [33:53<59:24:58,  2.49s/it]

step: 14000, eval loss: 0.748899539411068, eval acc: 0.99925


 15%|█▌        | 15001/100000 [36:20<58:07:06,  2.46s/it]

step: 15000, eval loss: 0.7490520638227463, eval acc: 0.9990000000000001


 16%|█▌        | 16001/100000 [38:45<57:58:09,  2.48s/it]

step: 16000, eval loss: 0.7486720624566078, eval acc: 0.9990000000000001


 17%|█▋        | 17001/100000 [41:10<56:57:16,  2.47s/it]

step: 17000, eval loss: 0.7518499410152435, eval acc: 0.9975


 18%|█▊        | 18001/100000 [43:34<67:30:04,  2.96s/it]

step: 18000, eval loss: 0.7515738433599473, eval acc: 0.9980000000000001


 18%|█▊        | 18144/100000 [43:55<3:18:09,  6.88it/s] 


KeyboardInterrupt: 