In [1]:
import mlflow
import torch

name = "track-gradients"
mlflow.set_tracking_uri("/Users/tanguanyu/UCC-DRN-Pytorch/mnist/mlruns")
experiment_id = "511211072178994368"
runs = mlflow.search_runs(experiment_names=["track-gradients"], output_format="list")

In [1]:

import torch
import sys
sys.path.append("../mnist")

In [2]:
from model import DRNOnlyModel
from hydra import compose, initialize

In [7]:
initial_params = {
        "hidden_q": 10,
        "num_bins": 11,
        "lr": 0.0498,
        "num_layers": 1,
        "num_nodes": 10
    }

defaults = {
                    "num_bins": {
                        "type": "int",
                        "range": [5,100],
                        "aliases": [
                            "model.drn.num_bins",
                            "args.num_bins",
                            "model.kde_model.num_bins"
                        ]
                    },
                    "lr": {
                        "type": "float",
                        "range": [0.0001, 0.1],
                        "aliases": ["args.learning_rate"]
                    },
                    "hidden_q": {
                        "type": "int",
                        "range": [4, 100],
                        "aliases": ["model.drn.hidden_q"]
                    },
                    "num_layers": {
                        "type": "int",
                        "range": [1, 10],
                        "aliases": ["model.drn.num_layers"]
                    },
                    "num_nodes": {
                        "type": "int",
                        "range": [1, 10],
                        "aliases": ["model.drn.num_nodes"]
                    }
                }
                
with initialize(version_base=None, config_path="../configs"):
    cfg = compose(config_name="train_drn")
for key, value in defaults.items():
    v = initial_params[key]
    for a in value["aliases"]:
        exec(f"cfg.{a} = {v}")
model = DRNOnlyModel(cfg)

(10, 11)
(4, 10)


In [17]:
model.drn

Sequential(
  (0): DRN(in_features=10, in_bins=11, out_features=10, out_bins=10)
  (1): DRN(in_features=10, in_bins=10, out_features=1, out_bins=4)
  (2): Flatten(start_dim=1, end_dim=-1)
)

In [32]:
model

DRNOnlyModel(
  (drn): Sequential(
    (0): DRN(in_features=10, in_bins=11, out_features=10, out_bins=10)
    (1): DRN(in_features=10, in_bins=10, out_features=1, out_bins=4)
    (2): Flatten(start_dim=1, end_dim=-1)
  )
)

In [3]:
import os

os.chdir("../mnist")

In [4]:
os.getcwd()

'/Users/tanguanyu/UCC-DRN-Pytorch/mnist'

In [1]:
from collections import defaultdict

import mlflow
import torch
import numpy as np

experiment_name = "ucc-drn-goated-init"
mlflow.set_tracking_uri("/Users/tanguanyu/UCC-DRN-Pytorch/mnist/mlruns")
runs = mlflow.search_runs(experiment_names=[experiment_name,], output_format="list")
experiment = mlflow.set_experiment(experiment_name=experiment_name)
experiment_id = experiment.experiment_id
run_ids = [
    run.info.run_id for run in runs
]
run_names = [
    run.info.run_name for run in runs
]
for id, name in zip(run_ids, run_names):
    prefix = f"mlruns/{experiment_id}/{id}"
    try:
        with open(f"{prefix}/metrics/eval_ucc_acc") as file:
            eval_ucc = file.readlines()
        final_eval_acc = float(eval_ucc[-1].split(" ")[1])
        if final_eval_acc>0.9:
            

            # init_model = torch.load(f"{prefix}/artifacts/init_model/data/model.pth", weights_only=False)
            trained_model = torch.load(f"{prefix}/artifacts/best_model/data/model.pth", weights_only=False)
            drn_model = trained_model.ucc_classifier
            print(id)
            print(name)
            print(final_eval_acc)
            # print("".join(eval_ucc))
            param_dict = defaultdict(list)
            n_list = ["W", "bq", "ba", "lama", "lamq"]
            for name, param in drn_model.named_parameters():
                for n in n_list:
                    if n in name:
                        param_dict[f"{n}_max"].append(param.max().detach().cpu().numpy())
                        param_dict[f"{n}_min"].append(param.min().detach().cpu().numpy())
                        param_dict[f"{n}_mean"].append(param.mean().detach().cpu().numpy())
            for p, value in param_dict.items():
                print(f"{p}: {np.mean(value)}")
    except Exception as e:
        print(e)

98d0cca708cc4f5ba40314aa134af4cd
48ebf24973744164844891aa7584ff54_resume
0.973
W_max: 17.22736930847168
W_min: -8.929351806640625
W_mean: 1.7172229290008545
ba_max: 4.961871147155762
ba_min: 1.6309356689453125
ba_mean: 3.4531285762786865
bq_max: 2.393209218978882
bq_min: -0.4197807312011719
bq_mean: 0.7888431549072266
lama_max: 0.8599966764450073
lama_min: 0.07577153295278549
lama_mean: 0.38784635066986084
lamq_max: 0.3350653052330017
lamq_min: -1.408508539199829
lamq_mean: -0.6045825481414795
48ebf24973744164844891aa7584ff54
4d2f6b34d5f14599a6569c6f43b44037_resume
0.9735
W_max: 15.587149620056152
W_min: -8.233423233032227
W_mean: 1.4915902614593506
ba_max: 4.322299480438232
ba_min: 1.219214916229248
ba_mean: 2.915407419204712
bq_max: 2.2859647274017334
bq_min: -0.44844210147857666
bq_mean: 0.7488437294960022
lama_max: 0.8612028360366821
lama_min: 0.0786328986287117
lama_mean: 0.3851850926876068
lamq_max: 0.3438963294029236
lamq_min: -1.3729379177093506
lamq_mean: -0.5688416361808777
1

In [17]:
run_names

['peaceful-kite-300',
 'orderly-elk-390',
 'clumsy-flea-932',
 'bald-goose-275',
 'c7c6e956dc2c4b93b6079c9311bbbe79_resume_2',
 'f47d86fa34ab4526b75c4d8525baa366_resume_2',
 'wise-smelt-451',
 'stylish-bee-110',
 'nebulous-shad-766',
 'upset-loon-97',
 'ccefe815ccf9454b84e9bd69ae0d9ca2_resume_2',
 'ccefe815ccf9454b84e9bd69ae0d9ca2_resume',
 'ccefe815ccf9454b84e9bd69ae0d9ca2_resume',
 'ccefe815ccf9454b84e9bd69ae0d9ca2_resume',
 '439ed7c180e645bab036bb4c00bcafcd_resume',
 'overjoyed-dolphin-1',
 'luxuriant-donkey-158',
 'treasured-tern-816',
 '4d2f6b34d5f14599a6569c6f43b44037_resume',
 'luminous-perch-450',
 'lr-0.0005',
 'lr-0.0005',
 'lr-0.0005',
 'lr-0.0005',
 'enchanting-rat-536',
 'hilarious-swan-333',
 'selective-goose-567',
 'marvelous-boar-372',
 'bittersweet-cat-414',
 'capricious-fox-694',
 'ae9704aec54d48729f10dcb726d0125f_resume',
 '79f73b9140ff4a0ea5b8b6a668cc85d1_resume',
 '79f73b9140ff4a0ea5b8b6a668cc85d1_resume',
 'bittersweet-pig-324',
 'wise-rook-554',
 'placid-lynx-650

In [2]:
run_id = "48ebf24973744164844891aa7584ff54"
trained_model = torch.load(f"mlruns/{experiment_id}/{run_id}/artifacts/best_model/data/model.pth", weights_only=False)


In [None]:
from dataset import MnistDataset
from torch.utils.data import DataLoader

eval_dataset = MnistDataset(
    num_instances= 32,
    num_samples_per_class = 5,
    digit_arr = list(range(0,10)),
    ucc_start= 1,
    ucc_end= 4,
    mode= "test",
)
eval_dataloader = DataLoader(eval_dataset, 32)


x_train shape: torch.Size([50000, 1, 28, 28])
50000 train samples
10000 val samples
tensor(0.3105)
tensor(0.1325)
10000 test samples


In [7]:
device = torch.device("mps")
for batch_samples, batch_labels in eval_dataloader:
    print(batch_samples.shape)
    batch_samples = batch_samples.to(device)
    output = trained_model(batch_samples)
    print(torch.max(output, dim=1))
    print(batch_labels)
    break

torch.Size([32, 32, 1, 28, 28])
torch.return_types.max(
values=tensor([1.0000, 0.9899, 0.9494, 0.9989, 0.9999, 0.9876, 0.9553, 0.9988, 1.0000,
        0.9919, 0.9987, 0.9964, 0.9991, 0.9896, 0.6986, 0.9981, 1.0000, 0.9911,
        0.9813, 0.9988, 1.0000, 0.9915, 0.9570, 0.9988, 1.0000, 0.9874, 0.9723,
        0.9938, 0.8732, 0.9916, 0.9988, 0.9984], device='mps:0',
       grad_fn=<MaxBackward0>),
indices=tensor([0, 1, 2, 3, 0, 1, 2, 3, 0, 1, 3, 3, 0, 1, 3, 3, 0, 1, 2, 3, 0, 1, 2, 3,
        0, 1, 2, 3, 1, 1, 3, 3], device='mps:0'))
tensor([0, 1, 2, 3, 0, 1, 2, 3, 0, 1, 2, 3, 0, 1, 2, 3, 0, 1, 2, 3, 0, 1, 2, 3,
        0, 1, 2, 3, 0, 1, 2, 3])


In [9]:
from PIL import Image

mean = 0.1325
std = 0.3105
for image in batch_samples[10]:
    print(image.shape)
    a = Image.fromarray((image.to("cpu").numpy().squeeze()*std+mean)*255)
    a.show()

torch.Size([1, 28, 28])
torch.Size([1, 28, 28])
torch.Size([1, 28, 28])
torch.Size([1, 28, 28])
torch.Size([1, 28, 28])
torch.Size([1, 28, 28])
torch.Size([1, 28, 28])
torch.Size([1, 28, 28])
torch.Size([1, 28, 28])
torch.Size([1, 28, 28])
torch.Size([1, 28, 28])
torch.Size([1, 28, 28])
torch.Size([1, 28, 28])
torch.Size([1, 28, 28])
torch.Size([1, 28, 28])
torch.Size([1, 28, 28])
torch.Size([1, 28, 28])
torch.Size([1, 28, 28])
torch.Size([1, 28, 28])
torch.Size([1, 28, 28])
torch.Size([1, 28, 28])
torch.Size([1, 28, 28])
torch.Size([1, 28, 28])
torch.Size([1, 28, 28])
torch.Size([1, 28, 28])
torch.Size([1, 28, 28])
torch.Size([1, 28, 28])
torch.Size([1, 28, 28])
torch.Size([1, 28, 28])
torch.Size([1, 28, 28])
torch.Size([1, 28, 28])
torch.Size([1, 28, 28])
