In [12]:
# Display the MLflow run id stored on the currently selected `row`.
# NOTE(review): `row` is only assigned in a later cell of this notebook, so this cell
# presumably relies on kernel state left over from an earlier execution — confirm it
# still works after Restart Kernel -> Run All.
row["run_id"]

'76d45e8fdc07489fbd155c3ac5142008'

In [15]:
# trying to use good initialized model and use original
import os
from hydra import compose, initialize
import omegaconf
import torch
import torch.nn as nn
import torch.nn.functional as F
import mlflow
import optuna
import numpy as np
from tqdm import tqdm
from omegaconf.omegaconf import OmegaConf
from torch.utils.data import DataLoader

from dataset import MnistDataset
from utils import get_or_create_experiment, parse_experiment_runs_to_optuna_study
# Enable autograd anomaly detection for the whole session.
torch.autograd.set_detect_anomaly(True)

# Runs of the source experiment that cleared the eval-accuracy threshold,
# ordered from best to worst training accuracy.
candidate_runs = mlflow.search_runs(
    experiment_ids=["823019999500194436"],
    filter_string="metrics.eval_ucc_acc>0.25",
)
runs = candidate_runs.sort_values("metrics.train_ucc_acc", ascending=False)
row = runs.iloc[1]  # second entry of the sorted frame

# Directory of the run whose artifacts are reused below.
# NOTE(review): the run id is hard-coded rather than read from `row["run_id"]` —
# confirm this is the intended run.
run_dir = os.path.join(
    "mlruns", row["experiment_id"], "e7bd69f493604ce99a968daae16c9822"
)

cfg = OmegaConf.load(os.path.join(run_dir, "artifacts/config.yaml"))
cfg.model.alpha = 0.5

init_path = os.path.join(run_dir, "artifacts", "init_model", "data", "model.pth")
model_path = os.path.join(run_dir, "artifacts", "best_model.pth", "data", "model.pth")

# weights_only=False unpickles whole objects: only load checkpoints you produced yourself.
best_init_model = torch.load(init_path, weights_only=False)
model = torch.load(model_path, weights_only=False)
model.alpha = 0.5

In [17]:
def set_random_seed(seed):
    """Seed every random number generator this notebook can draw from.

    Seeds Python's ``random`` module, NumPy's legacy global generator and
    PyTorch (CPU and all CUDA devices), and configures cuDNN so that it
    selects deterministic algorithms and does not auto-tune kernels.

    Args:
        seed: Integer seed. ``np.random.seed`` only accepts values in
            ``[0, 2**32 - 1]``.
    """
    # Local import keeps this function self-contained with respect to the import cell.
    import random

    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    # Covers every visible GPU, not only the current one; safe to call without CUDA.
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    # Benchmark mode may pick different kernels from run to run.
    torch.backends.cudnn.benchmark = False

def init_dataloader(args):
    """Build the training and validation data loaders described by ``args``.

    Only ``MnistDataset`` is constructed here. Both splits share every setting
    except the mode, the dataset length and whether the loader shuffles.

    Args:
        args: Config object providing ``batch_size``, ``num_workers``,
            ``num_instances``, ``num_samples_per_class``, ``ucc_start``,
            ``ucc_end``, ``train_num_steps`` and ``val_num_steps``.

    Returns:
        Tuple ``(train_loader, val_loader)``; the training loader shuffles,
        the validation loader does not.
    """

    def _make_dataset(mode, num_steps):
        # The dataset length is chosen so one pass yields `num_steps` full batches.
        return MnistDataset(
            mode=mode,
            num_instances=args.num_instances,
            num_samples_per_class=args.num_samples_per_class,
            digit_arr=list(range(args.ucc_end - args.ucc_start + 1)),
            ucc_start=args.ucc_start,
            ucc_end=args.ucc_end,
            length=num_steps * args.batch_size,
        )

    def _make_loader(dataset, shuffle):
        return DataLoader(
            dataset,
            batch_size=args.batch_size,
            num_workers=args.num_workers,
            shuffle=shuffle,
        )

    train_dataset = _make_dataset("train", args.train_num_steps)
    val_dataset = _make_dataset("val", args.val_num_steps)

    train_loader = _make_loader(train_dataset, True)
    val_loader = _make_loader(val_dataset, False)
    return train_loader, val_loader

def evaluate(model, val_loader, device, ae_mode, clf_mode) -> dict:
    """Compute validation metrics for the autoencoder and/or the UCC classifier.

    The model is switched to eval mode and is NOT switched back; callers that
    keep training must call ``model.train()`` afterwards.

    Args:
        model: Module exposing ``encoder`` and ``decoder`` (used when
            ``ae_mode`` is set) and callable as ``model(batch)`` returning
            ``(ucc_logits, _)`` (used when ``clf_mode`` is set).
        val_loader: Iterable of ``(batch_samples, batch_labels)`` pairs.
            ``batch_samples`` must be 5-dimensional when ``ae_mode`` is set.
        device: Device the batches are moved to before the forward pass.
        ae_mode: If true, report the mean reconstruction MSE.
        clf_mode: If true, report the mean cross-entropy and accuracy.

    Returns:
        Dict holding ``eval_ae_loss`` (when ``ae_mode``) and ``eval_ucc_loss``
        / ``eval_ucc_acc`` (when ``clf_mode``), each the mean over batches
        rounded to 5 decimals. The dict is empty when both flags are false
        (previously the function fell through and returned ``None``).
    """
    model.eval()
    val_ae_loss_list = []
    val_ucc_loss_list = []
    val_acc_list = []
    rec_criterion = nn.MSELoss()
    with torch.no_grad():
        for batch_samples, batch_labels in val_loader:
            batch_samples = batch_samples.to(device)
            batch_labels = batch_labels.to(device)
            if ae_mode:
                batch_size, num_instances, num_channel, patch_size, _ = batch_samples.shape
                # Merge the first two axes so every instance is encoded on its own.
                x = batch_samples.view(-1, num_channel,
                                       batch_samples.shape[-2], batch_samples.shape[-1])
                features = model.encoder(x)
                reconstruction = model.decoder(features)
                # The reconstruction is reshaped as single-channel, square patches.
                reconstruction = reconstruction.view(batch_size, num_instances,
                                                     1, patch_size, patch_size)
                ae_loss = rec_criterion(batch_samples, reconstruction)
                val_ae_loss_list.append(ae_loss.item())

            if clf_mode:
                ucc_logits, _ = model(batch_samples)
                ucc_val_loss = F.cross_entropy(ucc_logits, batch_labels)
                # Per-batch accuracy: fraction of argmax predictions equal to the label.
                _, ucc_predicts = torch.max(ucc_logits, dim=1)
                acc = torch.sum(
                    ucc_predicts == batch_labels).item() / len(batch_labels)
                val_acc_list.append(acc)
                val_ucc_loss_list.append(ucc_val_loss.item())

    # Build the result incrementally so every flag combination yields a dict,
    # with the same keys and key order as before.
    metrics = {}
    if ae_mode:
        metrics["eval_ae_loss"] = np.round(np.mean(val_ae_loss_list), 5)
    if clf_mode:
        metrics["eval_ucc_loss"] = np.round(np.mean(val_ucc_loss_list), 5)
        metrics["eval_ucc_acc"] = np.round(np.mean(val_acc_list), 5)
    return metrics

def train(args, model, optimizer, lr_scheduler, train_loader, val_loader, device):
    """Train ``model`` for one pass over ``train_loader`` with MLflow logging.

    Every 10 steps the mean gradient of each parameter and the training
    losses/accuracy of the current batch are logged. Every
    ``args.save_interval`` steps the model is evaluated on ``val_loader`` and
    logged as the ``best_model.pth`` artifact whenever the validation accuracy
    improves. Patience-based early stopping is not applied.

    Args:
        args: Config object; only ``save_interval`` is read here.
        model: Module called as ``model(samples, labels)`` returning
            ``(ucc_logits, reconstruction)`` and exposing
            ``compute_loss(..., return_losses=True)`` returning
            ``(ce_loss, ae_loss, loss)``.
        optimizer: Optimizer stepped once per batch.
        lr_scheduler: Accepted for interface compatibility; it is not used by
            this function. NOTE(review): step it here if scheduling is wanted.
        train_loader: Iterable of ``(batch_samples, batch_labels)`` pairs.
        val_loader: Validation loader forwarded to ``evaluate``.
        device: Device the batches are moved to.

    Returns:
        The best validation accuracy observed (0 if no evaluation improved on it).
    """
    print("training")
    model.train()
    step = 0
    best_eval_acc = 0
    # Hard stop. It is only tested right after an evaluation, so it takes
    # effect only when this value is a multiple of args.save_interval.
    max_steps = 80000

    # Log the starting weights so a "best_model.pth" artifact exists even if
    # no evaluation ever improves on the initial best accuracy of 0.
    mlflow.pytorch.log_model(
        model,
        "best_model.pth"
    )
    for batch_samples, batch_labels in tqdm(train_loader):
        batch_samples = batch_samples.to(device)
        batch_labels = batch_labels.to(device)
        optimizer.zero_grad()

        ucc_logits, reconstruction = model(batch_samples, batch_labels)
        ce_loss, ae_loss, loss = model.compute_loss(
            inputs=batch_samples,
            labels=batch_labels,
            output=ucc_logits,
            reconstruction=reconstruction,
            return_losses=True
        )

        loss.backward()

        optimizer.step()

        step += 1

        if step % 10 == 0:
            with torch.no_grad():
                metric_dict = {}
                # Gradients are still populated here: they are only cleared
                # at the top of the next iteration.
                grad_log = {name: torch.mean(param.grad).cpu().item(
                ) for name, param in model.named_parameters() if isinstance(param.grad, torch.Tensor)}
                mlflow.log_metrics(grad_log, step=step)
                metric_dict["train_ae_loss"] = np.round(ae_loss.detach().item(), 5)
                _, pred = torch.max(ucc_logits, dim=1)
                accuracy = torch.sum(pred.flatten() == batch_labels.flatten())/len(batch_labels)
                metric_dict["train_ucc_loss"] = np.round(ce_loss.detach().item(), 5)
                metric_dict["train_ucc_acc"] = np.round(float(accuracy), 5)
                metric_dict["loss"] = np.round(float(loss), 5)
            mlflow.log_metrics(metric_dict, step=step)

        if step % args.save_interval == 0:
            eval_metric_dict = evaluate(
                model,
                val_loader,
                device,
                ae_mode=True,
                clf_mode=True)
            print(f"step: {step}," + ",".join([f"{key}: {value}"for key, value in eval_metric_dict.items()]))
            mlflow.log_metrics(eval_metric_dict, step=step)
            eval_acc = eval_metric_dict["eval_ucc_acc"]
            if eval_acc > best_eval_acc:
                best_eval_acc = eval_acc
                mlflow.pytorch.log_model(
                    model,
                    "best_model.pth"
                )
            if step == max_steps:
                break
            # evaluate() leaves the model in eval mode; resume training mode.
            model.train()

    print("Training finished!!!")
    return best_eval_acc


# Retrain the loaded `model` under a dedicated MLflow experiment.
# Relies on `cfg` and `model` created by the run-loading cell above.
mlflow.set_tracking_uri("mlruns")
run_name = "ucc-drn-multi-step-recreation-same-init"
experiment_id = get_or_create_experiment(experiment_name=run_name)
mlflow.set_experiment(experiment_id=experiment_id)
with mlflow.start_run(nested=True, run_name="melodic-conch-890"):

    # Hyper-parameter search-space description.
    # NOTE(review): `defaults` is not read anywhere in this cell, and the "lr"
    # value (0.005) lies outside its own range — confirm whether it is still needed.
    defaults = {
        # "init_method": {
        #     "type": "categorical",
        #     "range": ["uniform", "normal", "xavier_uniform", "xavier_normal", "kaiming_uniform", "kaiming_normal"],
        #     "aliases": [
        #         "model.drn.init_method",
        #     ]
        # },
        # "num_bins": {
        #     "type": "int",
        #     "value": 10,
        #     "range": [5, 100],
        #     "aliases": [
        #         "model.drn.num_bins",
        #         "args.num_bins",
        #         "model.kde_model.num_bins"
        #     ]
        # },
        "lr": {
            "type": "float",
            "value": 0.005,
            "range": [0.008, 0.08],
            "aliases": ["args.learning_rate"]
        },
        "hidden_q": {
            "type": "int",
            "value": 100,
            "range": [4, 100],
            "aliases": ["model.drn.hidden_q"]
        },
        "num_layers": {
            "type": "int",
            "value": 2,
            "range": [1, 10],
            "aliases": ["model.drn.num_layers"]
        },
        "num_nodes": {
            "type": "int",
            "value": 9,
            "range": [1, 10],
            "aliases": ["model.drn.num_nodes"]
        }
    }

    print(cfg)
    mlflow.log_dict(dict(OmegaConf.to_object(cfg)), "config.yaml")

    args = cfg.args

    # Pick the best available accelerator. The previous fallback selected "mps"
    # unconditionally whenever CUDA was missing, with no CPU fallback.
    if torch.cuda.is_available():
        device = torch.device("cuda")
    elif torch.backends.mps.is_available():
        device = torch.device("mps")
    else:
        device = torch.device("cpu")

    # train() moves every batch to `device`, so the model must live there too.
    # This is a no-op when the checkpoint was already loaded onto that device.
    model.to(device)

    # NOTE(review): the learning rate is fixed at 0.001 here, while the loaded
    # config carries args.learning_rate — confirm which one is intended.
    optimizer = torch.optim.Adam(
        model.parameters(), lr=0.001, amsgrad=True)

    # Seed before the datasets and the shuffling loader are created so the
    # data pipeline is repeatable across kernel restarts.
    set_random_seed(args.seed)
    train_loader, val_loader = init_dataloader(args)
    mlflow.pytorch.log_model(model, "init_model")
    best_acc = train(args, model, optimizer, None,
                     train_loader, val_loader, device)

{'args': {'batch_size': 20, 'dataset': 'mnist', 'learning_rate': 0.005, 'model_dir': 'saved_models/', 'model_name': 'mnist_ucc_drn', 'num_bins': 11, 'num_features': 10, 'num_instances': 32, 'num_samples_per_class': 5, 'num_workers': 4, 'save_interval': 1000, 'seed': 22, 'train_num_steps': 100000, 'ucc_end': 4, 'ucc_start': 1, 'val_num_steps': 200}, 'model': {'decoder': {'block1_num_layer': 1, 'block1_output_channel': 64, 'block2_num_layer': 1, 'block2_output_channel': 32, 'block3_num_layer': 1, 'block3_output_channel': 16, 'linear_size': 6272, 'output_channel': 1, 'reshape_size': [7, 7, 128]}, 'drn': {'hidden_q': 100, 'init_method': 'xavier_uniform', 'num_bins': 11, 'num_layers': 2, 'num_nodes': 9, 'output_bins': 4, 'output_nodes': 1}, 'encoder': {'block1_num_layer': 1, 'block1_output_channel': 321, 'block2_num_layer': 1, 'block2_output_channel': 64, 'block3_num_layer': 1, 'block3_output_channel': 128, 'conv_input_channel': 1, 'conv_output_channel': 16, 'flatten_size': 6272, 'num_featu



x_train shape: torch.Size([50000, 1, 28, 28])
50000 train samples
10000 val samples




training




step: 1000,eval_ae_loss: 0.11678,eval_ucc_loss: 0.7763,eval_ucc_acc: 0.99975




step: 2000,eval_ae_loss: 0.11312,eval_ucc_loss: 0.77051,eval_ucc_acc: 1.0


  3%|▎         | 3000/100000 [32:04<542:40:59, 20.14s/it]

step: 3000,eval_ae_loss: 0.11155,eval_ucc_loss: 0.76605,eval_ucc_acc: 0.99975


  4%|▍         | 4000/100000 [42:32<539:21:23, 20.23s/it]

step: 4000,eval_ae_loss: 0.11115,eval_ucc_loss: 0.76702,eval_ucc_acc: 0.9985


  5%|▌         | 5000/100000 [53:01<532:49:44, 20.19s/it]

step: 5000,eval_ae_loss: 0.10646,eval_ucc_loss: 0.7607,eval_ucc_acc: 1.0


  6%|▌         | 6000/100000 [1:03:30<528:20:43, 20.23s/it]

step: 6000,eval_ae_loss: 0.10563,eval_ucc_loss: 0.75862,eval_ucc_acc: 1.0


  7%|▋         | 7000/100000 [1:14:07<524:05:53, 20.29s/it]

step: 7000,eval_ae_loss: 0.10209,eval_ucc_loss: 0.75737,eval_ucc_acc: 1.0


  8%|▊         | 8000/100000 [1:25:01<518:08:16, 20.27s/it]

step: 8000,eval_ae_loss: 0.10112,eval_ucc_loss: 0.75616,eval_ucc_acc: 0.99975


  9%|▉         | 9000/100000 [1:35:38<512:40:13, 20.28s/it]

step: 9000,eval_ae_loss: 0.10044,eval_ucc_loss: 0.75473,eval_ucc_acc: 1.0


 10%|█         | 10000/100000 [1:46:09<507:23:51, 20.30s/it]

step: 10000,eval_ae_loss: 0.10213,eval_ucc_loss: 0.75442,eval_ucc_acc: 0.99975


 11%|█         | 11000/100000 [1:56:45<502:59:54, 20.35s/it]

step: 11000,eval_ae_loss: 0.10019,eval_ucc_loss: 0.75351,eval_ucc_acc: 0.99975


 12%|█▏        | 12000/100000 [2:07:28<495:45:12, 20.28s/it]

step: 12000,eval_ae_loss: 0.10124,eval_ucc_loss: 0.75276,eval_ucc_acc: 1.0


 13%|█▎        | 13000/100000 [2:18:07<493:43:12, 20.43s/it]

step: 13000,eval_ae_loss: 0.10167,eval_ucc_loss: 0.75236,eval_ucc_acc: 1.0


 14%|█▍        | 14000/100000 [2:28:54<476:49:11, 19.96s/it]

step: 14000,eval_ae_loss: 0.09864,eval_ucc_loss: 0.75172,eval_ucc_acc: 1.0


 15%|█▌        | 15000/100000 [2:39:09<468:51:48, 19.86s/it]

step: 15000,eval_ae_loss: 0.09824,eval_ucc_loss: 0.75114,eval_ucc_acc: 1.0


 16%|█▌        | 16000/100000 [2:49:20<466:49:18, 20.01s/it]

step: 16000,eval_ae_loss: 0.09771,eval_ucc_loss: 0.75097,eval_ucc_acc: 1.0


 17%|█▋        | 17000/100000 [2:59:40<463:41:21, 20.11s/it]

step: 17000,eval_ae_loss: 0.09494,eval_ucc_loss: 0.75052,eval_ucc_acc: 1.0


 18%|█▊        | 18000/100000 [3:10:28<494:04:47, 21.69s/it]

step: 18000,eval_ae_loss: 0.10039,eval_ucc_loss: 0.75192,eval_ucc_acc: 0.99975


 19%|█▉        | 19000/100000 [3:21:46<465:59:35, 20.71s/it]

step: 19000,eval_ae_loss: 0.09649,eval_ucc_loss: 0.74983,eval_ucc_acc: 1.0


 20%|██        | 20000/100000 [3:32:21<442:27:29, 19.91s/it]

step: 20000,eval_ae_loss: 0.09501,eval_ucc_loss: 0.74963,eval_ucc_acc: 0.99975


 21%|██        | 21000/100000 [3:42:51<439:58:15, 20.05s/it]

step: 21000,eval_ae_loss: 0.09496,eval_ucc_loss: 0.74976,eval_ucc_acc: 0.99975


 22%|██▏       | 22000/100000 [3:53:23<433:44:27, 20.02s/it]

step: 22000,eval_ae_loss: 0.09535,eval_ucc_loss: 0.74928,eval_ucc_acc: 1.0


 23%|██▎       | 23000/100000 [4:03:54<428:55:15, 20.05s/it]

step: 23000,eval_ae_loss: 0.0955,eval_ucc_loss: 0.74929,eval_ucc_acc: 0.99975


 24%|██▍       | 24000/100000 [4:14:30<424:05:53, 20.09s/it]

step: 24000,eval_ae_loss: 0.09617,eval_ucc_loss: 0.74865,eval_ucc_acc: 1.0


 25%|██▌       | 25000/100000 [4:25:22<417:47:55, 20.05s/it]

step: 25000,eval_ae_loss: 0.09527,eval_ucc_loss: 0.7487,eval_ucc_acc: 1.0


 26%|██▌       | 26000/100000 [4:36:00<411:48:11, 20.03s/it]

step: 26000,eval_ae_loss: 0.09451,eval_ucc_loss: 0.74904,eval_ucc_acc: 0.9995


 27%|██▋       | 27000/100000 [4:46:34<407:12:49, 20.08s/it]

step: 27000,eval_ae_loss: 0.09377,eval_ucc_loss: 0.74847,eval_ucc_acc: 0.99975


 28%|██▊       | 28000/100000 [4:57:09<403:22:10, 20.17s/it]

step: 28000,eval_ae_loss: 0.09337,eval_ucc_loss: 0.74812,eval_ucc_acc: 1.0


 29%|██▉       | 29000/100000 [5:07:49<397:00:30, 20.13s/it]

step: 29000,eval_ae_loss: 0.09375,eval_ucc_loss: 0.74821,eval_ucc_acc: 0.99975


 30%|███       | 30000/100000 [5:18:30<403:45:54, 20.77s/it]

step: 30000,eval_ae_loss: 0.09464,eval_ucc_loss: 0.74783,eval_ucc_acc: 1.0


 31%|███       | 31000/100000 [5:29:24<385:06:11, 20.09s/it]

step: 31000,eval_ae_loss: 0.09405,eval_ucc_loss: 0.74812,eval_ucc_acc: 0.99925


 32%|███▏      | 32000/100000 [5:40:07<379:54:24, 20.11s/it]

step: 32000,eval_ae_loss: 0.09522,eval_ucc_loss: 0.74774,eval_ucc_acc: 1.0


 33%|███▎      | 33000/100000 [5:50:48<375:51:05, 20.20s/it]

step: 33000,eval_ae_loss: 0.09353,eval_ucc_loss: 0.74754,eval_ucc_acc: 1.0


 34%|███▍      | 34000/100000 [6:01:40<388:29:52, 21.19s/it]

step: 34000,eval_ae_loss: 0.09401,eval_ucc_loss: 0.74745,eval_ucc_acc: 1.0


 35%|███▌      | 35000/100000 [6:12:34<363:25:33, 20.13s/it]

step: 35000,eval_ae_loss: 0.09469,eval_ucc_loss: 0.74751,eval_ucc_acc: 0.99975


 36%|███▌      | 36000/100000 [6:23:39<358:22:17, 20.16s/it]

step: 36000,eval_ae_loss: 0.09389,eval_ucc_loss: 0.74733,eval_ucc_acc: 1.0


 37%|███▋      | 37000/100000 [6:34:31<353:58:04, 20.23s/it]

step: 37000,eval_ae_loss: 0.09471,eval_ucc_loss: 0.74711,eval_ucc_acc: 1.0


 38%|███▊      | 38000/100000 [6:45:16<345:43:43, 20.07s/it]

step: 38000,eval_ae_loss: 0.0936,eval_ucc_loss: 0.74723,eval_ucc_acc: 0.99975


 39%|███▉      | 39000/100000 [6:56:05<343:51:01, 20.29s/it]

step: 39000,eval_ae_loss: 0.09368,eval_ucc_loss: 0.74693,eval_ucc_acc: 1.0


 40%|████      | 40000/100000 [7:06:55<335:55:29, 20.16s/it]

step: 40000,eval_ae_loss: 0.09477,eval_ucc_loss: 0.74684,eval_ucc_acc: 1.0


 41%|████      | 41000/100000 [7:17:59<333:13:57, 20.33s/it]

step: 41000,eval_ae_loss: 0.09339,eval_ucc_loss: 0.74686,eval_ucc_acc: 1.0


 42%|████▏     | 42000/100000 [7:30:51<372:33:44, 23.12s/it]

step: 42000,eval_ae_loss: 0.09527,eval_ucc_loss: 0.74668,eval_ucc_acc: 1.0


 43%|████▎     | 43000/100000 [7:43:28<350:32:38, 22.14s/it]

step: 43000,eval_ae_loss: 0.1005,eval_ucc_loss: 0.74698,eval_ucc_acc: 1.0


 44%|████▎     | 43731/100000 [7:52:19<10:07:44,  1.54it/s] 


KeyboardInterrupt: 

In [18]:
# Print the weight matrix of each inner DRN layer of the initial model,
# followed by its max, min, mean and standard deviation.
drn = best_init_model.ucc_classifier
# Drop the first and last entries of modules(); presumably the first is the
# classifier container itself — TODO confirm what the last entry is.
layers = list(drn.modules())[1:-1]
for index, layer in enumerate(layers):
    # Convert once and reuse the array for every statistic.
    weights = layer.W.detach().cpu().numpy()
    print(index, "weights")
    print(weights)
    for statistic in (weights.max, weights.min, weights.mean, weights.std):
        print(statistic())

0 weights
[[ 0.12929279 -0.24679643 -0.3951519  -0.50869083 -0.1545282  -0.50259244
   0.34096634  0.53263193  0.4816404  -0.47664326]
 [-0.10936746  0.32023436  0.15579069  0.46153766  0.24118489 -0.05316174
   0.1069783   0.5318013   0.48920268 -0.23420337]
 [ 0.1459015  -0.5443165  -0.02662891  0.4736374   0.05059469  0.44143683
   0.2309137  -0.5014591  -0.1175445  -0.03031945]
 [-0.42587566 -0.17734405 -0.04971427  0.06673527 -0.23131001  0.10457581
   0.01849186  0.43196142  0.21586502 -0.268082  ]
 [ 0.45593995 -0.33081645  0.00265354 -0.1670514   0.1638146   0.5280388
  -0.03166616  0.10945356  0.35259664  0.12462896]
 [-0.5077368  -0.06593943 -0.3657821  -0.00541192 -0.4278472  -0.2834111
  -0.48945418  0.5266476   0.0878678  -0.19442981]
 [-0.19350737  0.36867696 -0.03424609  0.16475844 -0.47798714  0.07769984
  -0.02374071  0.36930096 -0.07794341  0.39444548]
 [-0.24032155  0.30128026  0.07513559 -0.36390638  0.5397268   0.2302916
   0.53798884  0.24895728  0.55028635  0.237