In [1]:
import torch
from torch import nn
from torchviz import make_dot

In [2]:
# some magic so that the notebook will reload external python modules;
# see https://stackoverflow.com/questions/1907993/autoreload-of-modules-in-ipython
%load_ext autoreload
%autoreload 2

In [3]:
import argparse
import json
import logging
import random
from collections import defaultdict, OrderedDict
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.utils.data
from tqdm import trange
import wandb
import os

In [4]:
from pfedhn_pc.models import CNNHyperPC, CNNTargetPC, LocalLayer
from pfedhn_pc.node import BaseNodesForLocal
from utils import get_device, set_logger, set_seed, str2bool

In [5]:
def eval_model(nodes, num_nodes, hnet, net, criteria, device, split):
    """Evaluate every node on ``split`` and aggregate the per-node results.

    Returns:
        A 4-tuple ``(per_node_results, mean_loss, pooled_accuracy, per_node_accuracies)``:
        ``per_node_results`` is the dict returned by ``evaluate``;
        ``mean_loss`` is the unweighted mean of the per-node losses;
        ``pooled_accuracy`` is total correct / total samples over all nodes;
        ``per_node_accuracies`` lists correct / total for each node, in result order.
    """
    per_node = evaluate(nodes, num_nodes, hnet, net, criteria, device, split=split)
    stats = list(per_node.values())

    correct_sum = sum(s['correct'] for s in stats)
    sample_sum = sum(s['total'] for s in stats)
    mean_loss = np.mean([s['loss'] for s in stats])
    node_accs = [s['correct'] / s['total'] for s in stats]

    return per_node, mean_loss, correct_sum / sample_sum, node_accs

@torch.no_grad()
def evaluate(nodes: "BaseNodesForLocal", num_nodes, hnet, net, criteria, device, split='test'):
    """Evaluate each node's personalised model on one data split.

    For every node id the hypernetwork generates a state dict that is loaded into
    ``net``; the node's local head ``nodes.local_layers[node_id]`` is applied to
    ``net``'s output to get the predictions.

    Args:
        nodes: node container exposing ``test_loaders`` / ``val_loaders`` /
            ``train_loaders`` and ``local_layers``, all indexed by node id.
        num_nodes: evaluate node ids ``0 .. num_nodes - 1``.
        hnet: hypernetwork mapping a ``LongTensor([node_id])`` to a state dict for ``net``.
        net: target network; its weights are overwritten for every node.
        criteria: loss function applied as ``criteria(pred, label)``.
        device: device the inputs and the node-id tensor are moved to.
        split: ``'test'`` or ``'val'``; any other value selects the train loaders.

    Returns:
        ``{node_id: {'loss': ..., 'correct': ..., 'total': ...}}`` where ``loss`` is
        the mean of the per-batch losses (0.0 for a node whose loader yields no
        batches) and ``correct`` / ``total`` are summed over all batches.
    """
    hnet.eval()
    # Same as the "evaluation on sent model" in the training loops: run the target
    # net in inference mode (matters if it contains dropout / batch-norm layers).
    net.eval()
    results = defaultdict(lambda: defaultdict(list))

    # The split does not depend on the node, so resolve the loaders once.
    if split == 'test':
        loaders = nodes.test_loaders
    elif split == 'val':
        loaders = nodes.val_loaders
    else:
        loaders = nodes.train_loaders

    for node_id in range(num_nodes):  # iterating over nodes
        running_loss, running_correct, running_samples = 0., 0., 0.
        num_batches = 0

        weights = hnet(torch.tensor([node_id], dtype=torch.long).to(device))
        net.load_state_dict(weights)

        for batch in loaders[node_id]:
            img, label = tuple(t.to(device) for t in batch)
            pred = nodes.local_layers[node_id](net(img))
            running_loss += criteria(pred, label).item()
            running_correct += pred.argmax(1).eq(label).sum().item()
            running_samples += len(label)
            num_batches += 1

        # max(..., 1): an empty loader previously left the batch counter unbound.
        results[node_id]['loss'] = running_loss / max(num_batches, 1)
        results[node_id]['correct'] = running_correct
        results[node_id]['total'] = running_samples

    return results



In [6]:
def train(data_name: str, data_path: str, classes_per_node: int, num_nodes: int,
          steps: int, inner_steps: int, optim: str, lr: float, inner_lr: float,
          embed_lr: float, wd: float, inner_wd: float, embed_dim: int, hyper_hid: int,
          n_hidden: int, n_kernels: int, bs: int, device, eval_every: int, save_path: Path,
          ) -> None:
    """Train a pFedHN-PC hypernetwork using the manual delta-theta gradient.

    Each of the ``steps`` outer steps:
      1. samples one node and generates its target-network weights ``theta`` with
         the hypernetwork (parameters ``phi``);
      2. runs ``inner_steps`` local SGD steps on one training batch each, giving
         ``theta_tilde`` (the node's local head is updated by its own optimizer);
      3. sets the hypernetwork gradient to the vector-Jacobian product of
         ``d theta / d phi`` with ``theta - theta_tilde`` and takes an optimizer step.

    Every ``eval_every`` steps, and once more after the last step if that step was
    not an evaluation step, all nodes are evaluated on the test and val splits. The
    collected metrics are written to ``save_path / "results.json"``.

    ``embed_dim == -1`` selects ``int(1 + num_nodes / 4)``; ``embed_lr=None`` falls
    back to ``lr``; ``optim`` must be ``'sgd'`` or ``'adam'``.
    """

    ###############################
    # init nodes, hnet, local net #
    ###############################

    nodes = BaseNodesForLocal(
        data_name=data_name,
        data_path=data_path,
        n_nodes=num_nodes,
        base_layer=LocalLayer,
        layer_config={'n_input': 84, 'n_output': 10 if data_name == 'cifar10' else 100},
        base_optimizer=torch.optim.SGD, optimizer_config=dict(lr=inner_lr, momentum=.9, weight_decay=inner_wd),
        device=device,
        batch_size=bs,
        classes_per_node=classes_per_node,
    )

    if embed_dim == -1:
        logging.info("auto embedding size")
        embed_dim = int(1 + num_nodes / 4)

    hnet = CNNHyperPC(
        num_nodes, embed_dim, hidden_dim=hyper_hid, n_hidden=n_hidden,
        n_kernels=n_kernels
    )
    net = CNNTargetPC(n_kernels=n_kernels)

    hnet = hnet.to(device)
    net = net.to(device)

    ##################
    # init optimizer #
    ##################
    embed_lr = embed_lr if embed_lr is not None else lr
    optimizers = {
        'sgd': torch.optim.SGD(
            [
                {'params': [p for n, p in hnet.named_parameters() if 'embed' not in n]},
                {'params': [p for n, p in hnet.named_parameters() if 'embed' in n], 'lr': embed_lr}
            ], lr=lr, momentum=0.9, weight_decay=wd
        ),
        'adam': torch.optim.Adam(params=hnet.parameters(), lr=lr)
    }
    optimizer = optimizers[optim]
    criteria = torch.nn.CrossEntropyLoss()

    ################
    # init metrics #
    ################
    last_eval = -1
    best_step = -1
    best_acc = -1
    test_best_based_on_step, test_best_min_based_on_step = -1, -1
    test_best_max_based_on_step, test_best_std_based_on_step = -1, -1

    results = defaultdict(list)

    def _evaluate_and_record(eval_step):
        """Evaluate all nodes on test and val, update the best-on-val snapshot and
        append one entry to every metric list. Returns (test loss, test accuracy)."""
        nonlocal best_acc, best_step, test_best_based_on_step
        nonlocal test_best_min_based_on_step, test_best_max_based_on_step, test_best_std_based_on_step

        _, avg_loss, avg_acc, all_acc = eval_model(
            nodes, num_nodes, hnet, net, criteria, device, split="test"
        )
        results['test_avg_loss'].append(avg_loss)
        results['test_avg_acc'].append(avg_acc)

        _, val_avg_loss, val_avg_acc, _ = eval_model(
            nodes, num_nodes, hnet, net, criteria, device, split="val"
        )
        # Model selection: keep the test metrics from the step with the best val accuracy.
        if best_acc < val_avg_acc:
            best_acc = val_avg_acc
            best_step = eval_step
            test_best_based_on_step = avg_acc
            test_best_min_based_on_step = np.min(all_acc)
            test_best_max_based_on_step = np.max(all_acc)
            test_best_std_based_on_step = np.std(all_acc)

        results['val_avg_loss'].append(val_avg_loss)
        results['val_avg_acc'].append(val_avg_acc)
        results['best_step'].append(best_step)
        results['best_val_acc'].append(best_acc)
        results['best_test_acc_based_on_val_beststep'].append(test_best_based_on_step)
        results['test_best_min_based_on_step'].append(test_best_min_based_on_step)
        results['test_best_max_based_on_step'].append(test_best_max_based_on_step)
        results['test_best_std_based_on_step'].append(test_best_std_based_on_step)
        return avg_loss, avg_acc

    # Stays -1 (== last_eval) when steps <= 0, so the final evaluation is skipped
    # instead of failing on an unbound loop variable.
    step = -1
    step_iter = trange(steps)
    for step in step_iter:
        hnet.train()

        # select client at random
        node_id = random.choice(range(num_nodes))

        # produce & load local network weights (theta)
        weights = hnet(torch.tensor([node_id], dtype=torch.long).to(device))
        net.load_state_dict(weights)

        # init inner optimizer
        inner_optim = torch.optim.SGD(
            net.parameters(), lr=inner_lr, momentum=.9, weight_decay=inner_wd
        )

        # detached theta, kept for later calculating delta theta
        inner_state = OrderedDict({k: tensor.detach() for k, tensor in weights.items()})

        # NOTE: evaluation on sent model (single test batch, for the progress bar)
        with torch.no_grad():
            net.eval()
            batch = next(iter(nodes.test_loaders[node_id]))
            img, label = tuple(t.to(device) for t in batch)
            net_out = net(img)
            pred = nodes.local_layers[node_id](net_out)
            prvs_loss = criteria(pred, label)
            prvs_acc = pred.argmax(1).eq(label).sum().item() / len(label)
            net.train()

        # inner updates -> obtaining theta_tilde
        for _ in range(inner_steps):
            net.train()
            inner_optim.zero_grad()
            optimizer.zero_grad()
            nodes.local_optimizers[node_id].zero_grad()

            batch = next(iter(nodes.train_loaders[node_id]))
            img, label = tuple(t.to(device) for t in batch)

            net_out = net(img)
            pred = nodes.local_layers[node_id](net_out)

            loss = criteria(pred, label)
            loss.backward()
            torch.nn.utils.clip_grad_norm_(net.parameters(), 50)
            inner_optim.step()
            nodes.local_optimizers[node_id].step()

        optimizer.zero_grad()

        final_state = net.state_dict()

        # Diagnostic only (logged, not back-propagated): sum over parameter tensors of
        # the mean squared distance between theta and theta_tilde.
        hn_loss = 0.0
        for key in weights.keys():
            hn_loss += nn.MSELoss()(inner_state[key], final_state[key])

        # calculating delta theta = theta - theta_tilde
        delta_theta = OrderedDict({k: inner_state[k] - final_state[k] for k in weights.keys()})

        # phi gradient: vector-Jacobian product of d theta / d phi with delta theta
        hnet_grads = torch.autograd.grad(
            list(weights.values()), hnet.parameters(), grad_outputs=list(delta_theta.values())
        )

        # update hnet weights
        for p, g in zip(hnet.parameters(), hnet_grads):
            p.grad = g

        torch.nn.utils.clip_grad_norm_(hnet.parameters(), 50)
        optimizer.step()

        step_iter.set_description(
            f"Step: {step+1}, Node ID: {node_id}, Loss: {prvs_loss:.4f},  Acc: {prvs_acc:.4f}"
        )

        if step % eval_every == 0:
            last_eval = step
            avg_loss, avg_acc = _evaluate_and_record(step)
            logging.info(f"\nStep: {step+1}, AVG Loss: {avg_loss:.4f},  AVG Acc: {avg_acc:.4f}, HN Loss: {hn_loss}")

        # Wandb logging (wandb.log is currently disabled; the dict is built but not sent)
        wandb_dict = defaultdict(int)
        wandb_dict["step"] = step
        wandb_dict["hn_loss"] = hn_loss.detach().item()
        for key, value_list in results.items():
            wandb_dict[key] = value_list[-1]
        # wandb.log(wandb_dict)

    # Evaluate the final model unless the last step was already an evaluation step.
    if step != last_eval:
        avg_loss, avg_acc = _evaluate_and_record(step)
        logging.info(f"\nStep: {step + 1}, AVG Loss: {avg_loss:.4f},  AVG Acc: {avg_acc:.4f}")

    save_path = Path(save_path)
    save_path.mkdir(parents=True, exist_ok=True)
    with open(str(save_path / "results.json"), "w") as file:
        json.dump(results, file, indent=4)

In [7]:
def rel_error(x, y, eps=1e-12):
    """Signed worst-case relative error between tensors ``x`` and ``y``.

    Per element, ``(x - y) / max(|x|, |y|)`` is computed. The denominator is clamped
    to at least ``eps``, so entries where both tensors are zero give 0 instead of
    NaN (common in gradients). The entry with the largest magnitude is returned as
    a 0-dim tensor. Its sign is positive where ``x > y``, and both sign and
    magnitude come from that same element.

    Args:
        x, y: broadcast-compatible floating-point tensors.
        eps: lower bound for the per-element denominator.
    """
    diff = (x - y).flatten()
    scale = torch.maximum(x.abs(), y.abs()).flatten().clamp_min(eps)
    rel = diff / scale
    return rel[torch.argmax(rel.abs())]

In [8]:
# Sanity check of rel_error: 2 differs from 5 by 3, i.e. 60% of the larger value,
# and the sign is negative because the first argument is the smaller one.
b = torch.full((5,), 2.0)
a = torch.full((5,), 5.0)
print(rel_error(b, a))

tensor(-0.6000)


In [23]:
def train_mse(
    data_name: str,
    data_path: str,
    classes_per_node: int,
    num_nodes: int,
    steps: int,
    inner_steps: int,
    optim: str,
    lr: float,
    inner_lr: float,
    embed_lr: float,
    wd: float,
    inner_wd: float,
    embed_dim: int,
    hyper_hid: int,
    n_hidden: int,
    n_kernels: int,
    bs: int,
    device,
    eval_every: int,
    save_path: Path,
    hn_loss_scale: float = 1.0,
) -> None:
    """Train the pFedHN-PC hypernetwork by back-propagating a weight-space MSE loss.

    Each outer step samples one node and generates its weights ``theta``. It then
    runs ``inner_steps`` local SGD steps to get ``theta_tilde``; the node's local
    head is updated by its own optimizer. Finally it back-propagates

        hn_loss = hn_loss_scale * sum_k sum((theta_k - theta_tilde_k) ** 2)

    into the hypernetwork. ``theta`` stays attached to the hypernetwork graph and
    ``theta_tilde`` (from ``net.state_dict()``) is detached, so the gradient
    w.r.t. ``theta`` is ``2 * hn_loss_scale * (theta - theta_tilde)``.
    ``hn_loss_scale=0.5`` therefore reproduces the delta-theta direction and scale
    used by ``train``. The default 1.0 keeps the original, twice as large,
    gradient. The logged ``hn_loss`` includes the scale.

    Every ``eval_every`` steps, and once more after the last step if that step was
    not an evaluation step, all nodes are evaluated on the test and val splits. The
    collected metrics are written to ``save_path / "results.json"``.

    ``embed_dim == -1`` selects ``int(1 + num_nodes / 4)``; ``embed_lr=None`` falls
    back to ``lr``; ``optim`` must be ``'sgd'`` or ``'adam'``.
    """
    ###############################
    # init nodes, hnet, local net #
    ###############################

    nodes = BaseNodesForLocal(
        data_name=data_name,
        data_path=data_path,
        n_nodes=num_nodes,
        base_layer=LocalLayer,
        layer_config={"n_input": 84, "n_output": 10 if data_name == "cifar10" else 100},
        base_optimizer=torch.optim.SGD,
        optimizer_config=dict(lr=inner_lr, momentum=0.9, weight_decay=inner_wd),
        device=device,
        batch_size=bs,
        classes_per_node=classes_per_node,
    )

    if embed_dim == -1:
        logging.info("auto embedding size")
        embed_dim = int(1 + num_nodes / 4)

    hnet = CNNHyperPC(
        num_nodes,
        embed_dim,
        hidden_dim=hyper_hid,
        n_hidden=n_hidden,
        n_kernels=n_kernels,
    )
    net = CNNTargetPC(n_kernels=n_kernels)

    hnet = hnet.to(device)
    net = net.to(device)

    ##################
    # init optimizer #
    ##################
    embed_lr = embed_lr if embed_lr is not None else lr
    optimizers = {
        "sgd": torch.optim.SGD(
            [
                {"params": [p for n, p in hnet.named_parameters() if "embed" not in n]},
                {
                    "params": [p for n, p in hnet.named_parameters() if "embed" in n],
                    "lr": embed_lr,
                },
            ],
            lr=lr,
            momentum=0.9,
            weight_decay=wd,
        ),
        "adam": torch.optim.Adam(params=hnet.parameters(), lr=lr),
    }
    optimizer = optimizers[optim]
    criteria = torch.nn.CrossEntropyLoss()

    ################
    # init metrics #
    ################
    last_eval = -1
    best_step = -1
    best_acc = -1
    test_best_based_on_step, test_best_min_based_on_step = -1, -1
    test_best_max_based_on_step, test_best_std_based_on_step = -1, -1

    results = defaultdict(list)

    def _evaluate_and_record(eval_step):
        """Evaluate all nodes on test and val, update the best-on-val snapshot and
        append one entry to every metric list. Returns (test loss, test accuracy)."""
        nonlocal best_acc, best_step, test_best_based_on_step
        nonlocal test_best_min_based_on_step, test_best_max_based_on_step, test_best_std_based_on_step

        _, avg_loss, avg_acc, all_acc = eval_model(
            nodes, num_nodes, hnet, net, criteria, device, split="test"
        )
        results["test_avg_loss"].append(avg_loss)
        results["test_avg_acc"].append(avg_acc)

        _, val_avg_loss, val_avg_acc, _ = eval_model(
            nodes, num_nodes, hnet, net, criteria, device, split="val"
        )
        # Model selection: keep the test metrics from the step with the best val accuracy.
        if best_acc < val_avg_acc:
            best_acc = val_avg_acc
            best_step = eval_step
            test_best_based_on_step = avg_acc
            test_best_min_based_on_step = np.min(all_acc)
            test_best_max_based_on_step = np.max(all_acc)
            test_best_std_based_on_step = np.std(all_acc)

        results["val_avg_loss"].append(val_avg_loss)
        results["val_avg_acc"].append(val_avg_acc)
        results["best_step"].append(best_step)
        results["best_val_acc"].append(best_acc)
        results["best_test_acc_based_on_val_beststep"].append(test_best_based_on_step)
        results["test_best_min_based_on_step"].append(test_best_min_based_on_step)
        results["test_best_max_based_on_step"].append(test_best_max_based_on_step)
        results["test_best_std_based_on_step"].append(test_best_std_based_on_step)
        return avg_loss, avg_acc

    # Stays -1 (== last_eval) when steps <= 0, so the final evaluation is skipped
    # instead of failing on an unbound loop variable.
    step = -1
    step_iter = trange(steps)
    for step in step_iter:
        hnet.train()

        # select client at random
        node_id = random.choice(range(num_nodes))

        # produce & load local network weights (theta)
        weights = hnet(torch.tensor([node_id], dtype=torch.long).to(device))
        net.load_state_dict(weights)

        # init inner optimizer
        inner_optim = torch.optim.SGD(
            net.parameters(), lr=inner_lr, momentum=0.9, weight_decay=inner_wd
        )

        # theta, deliberately NOT detached: hn_loss below back-propagates through it into hnet.
        inner_state = OrderedDict({k: tensor for k, tensor in weights.items()})

        # NOTE: evaluation on sent model (single test batch, for the progress bar)
        with torch.no_grad():
            net.eval()
            batch = next(iter(nodes.test_loaders[node_id]))
            img, label = tuple(t.to(device) for t in batch)
            net_out = net(img)
            pred = nodes.local_layers[node_id](net_out)
            prvs_loss = criteria(pred, label)
            prvs_acc = pred.argmax(1).eq(label).sum().item() / len(label)
            net.train()

        # inner updates -> obtaining theta_tilde
        for _ in range(inner_steps):
            net.train()
            inner_optim.zero_grad()
            optimizer.zero_grad()
            nodes.local_optimizers[node_id].zero_grad()

            batch = next(iter(nodes.train_loaders[node_id]))
            img, label = tuple(t.to(device) for t in batch)

            net_out = net(img)
            pred = nodes.local_layers[node_id](net_out)

            loss = criteria(pred, label)
            loss.backward()

            torch.nn.utils.clip_grad_norm_(net.parameters(), 50)
            inner_optim.step()
            nodes.local_optimizers[node_id].step()

        optimizer.zero_grad()

        # theta_tilde: state_dict() returns detached tensors, so they act as fixed targets.
        final_state = net.state_dict()

        # Hypernetwork loss: scaled sum-of-squares distance between theta and theta_tilde.
        hn_loss = 0.0
        for key in weights.keys():
            hn_loss += nn.MSELoss(reduction="sum")(inner_state[key], final_state[key])
        hn_loss = hn_loss_scale * hn_loss
        hn_loss.backward()

        torch.nn.utils.clip_grad_norm_(hnet.parameters(), 50)
        optimizer.step()

        step_iter.set_description(
            f"Step: {step+1}, Node ID: {node_id}, Loss: {prvs_loss:.4f},  Acc: {prvs_acc:.4f}"
        )

        if step % eval_every == 0:
            last_eval = step
            avg_loss, avg_acc = _evaluate_and_record(step)
            logging.info(f"\n\nStep: {step+1}, AVG Loss: {avg_loss:.4f},  AVG Acc: {avg_acc:.4f} | HN Loss: {hn_loss}")

        # Wandb logging (wandb.log is currently disabled; the dict is built but not sent)
        wandb_dict = defaultdict(int)
        wandb_dict["step"] = step
        wandb_dict["hn_loss"] = hn_loss.detach().item()
        for key, value_list in results.items():
            wandb_dict[key] = value_list[-1]
        # wandb.log(wandb_dict)

    # Evaluate the final model unless the last step was already an evaluation step.
    if step != last_eval:
        avg_loss, avg_acc = _evaluate_and_record(step)
        logging.info(
            f"\nStep: {step + 1}, AVG Loss: {avg_loss:.4f},  AVG Acc: {avg_acc:.4f}"
        )

    save_path = Path(save_path)
    save_path.mkdir(parents=True, exist_ok=True)
    with open(str(save_path / "results.json"), "w") as file:
        json.dump(results, file, indent=4)


In [18]:
parser = argparse.ArgumentParser(
    description="Federated Hypernetwork with local layers experiment"
)

#############################
#       Dataset Args        #
#############################
parser.add_argument(
    "--data-name", type=str, default="cifar10", choices=['cifar10', 'cifar100'], help="data name"
)
parser.add_argument("--data-path", type=str, default='data', help='data path')
parser.add_argument("--num-nodes", type=int, default=50)

##################################
#       Optimization args        #
##################################
parser.add_argument("--num-steps", type=int, default=5000)
parser.add_argument("--batch-size", type=int, default=64)
parser.add_argument("--inner-steps", type=int, default=50, help="number of inner steps")
parser.add_argument("--optim", type=str, default='sgd', choices=['adam', 'sgd'], help="hypernetwork optimizer")

################################
#       Model Prop args        #
################################
parser.add_argument("--n-hidden", type=int, default=3, help="num. hidden layers")
parser.add_argument("--inner-lr", type=float, default=5e-3, help="learning rate for inner optimizer")
parser.add_argument("--lr", type=float, default=5e-2, help="learning rate")
parser.add_argument("--wd", type=float, default=1e-3, help="weight decay")
parser.add_argument("--inner-wd", type=float, default=5e-5, help="inner weight decay")
parser.add_argument("--embed-dim", type=int, default=-1, help="embedding dim")
parser.add_argument("--embed-lr", type=float, default=None, help="embedding learning rate")
parser.add_argument("--hyper-hid", type=int, default=100, help="hypernet hidden dim")
parser.add_argument("--spec-norm", type=str2bool, default=False, help="spectral-norm flag (not passed to train/train_mse in this notebook)")
parser.add_argument("--nkernels", type=int, default=16, help="number of kernels for cnn model")

#############################
#       General args        #
#############################
parser.add_argument("--gpu", type=int, default=0, help="gpu device ID")
parser.add_argument("--eval-every", type=int, default=30, help="evaluate every X steps")
parser.add_argument("--save-path", type=str, default="pfedhn_pc_cifar_res", help="dir path for output file")
parser.add_argument("--seed", type=int, default=42, help="seed value")

# Notebook run: overrides are passed through argparse (instead of mutating `args`
# afterwards) so that they are subject to the validation below.
args = parser.parse_args(["--gpu", "1"])
num_gpus = torch.cuda.device_count()
# Valid CUDA ids are 0 .. num_gpus - 1; the check is skipped when no CUDA device is visible.
assert num_gpus == 0 or args.gpu < num_gpus, f"--gpu flag should be in range [0,{num_gpus - 1}]"

set_logger()
set_seed(args.seed)

device = get_device(gpus=args.gpu)

# Classes assigned to each node: 2 for CIFAR-10, 10 for CIFAR-100.
args.classes_per_node = 2 if args.data_name == 'cifar10' else 10

## Training with manual delta-theta hypernetwork gradients

In [11]:
train(
    # data / federation setup
    data_name=args.data_name,
    data_path=args.data_path,
    num_nodes=args.num_nodes,
    classes_per_node=args.classes_per_node,
    bs=args.batch_size,
    # optimisation (outer hypernetwork + inner local steps)
    steps=args.num_steps,
    inner_steps=args.inner_steps,
    optim=args.optim,
    lr=args.lr,
    inner_lr=args.inner_lr,
    embed_lr=args.embed_lr,
    wd=args.wd,
    inner_wd=args.inner_wd,
    # model architecture
    embed_dim=args.embed_dim,
    hyper_hid=args.hyper_hid,
    n_hidden=args.n_hidden,
    n_kernels=args.nkernels,
    # runtime / bookkeeping
    device=device,
    eval_every=args.eval_every,
    save_path=args.save_path,
)

KeyboardInterrupt: 

## Training with an MSE loss between generated and locally-updated weights

In [22]:
train_mse(
    # data / federation setup
    data_name=args.data_name,
    data_path=args.data_path,
    num_nodes=args.num_nodes,
    classes_per_node=args.classes_per_node,
    bs=args.batch_size,
    # optimisation (outer hypernetwork + inner local steps)
    steps=args.num_steps,
    inner_steps=args.inner_steps,
    optim=args.optim,
    lr=args.lr,
    inner_lr=args.inner_lr,
    embed_lr=args.embed_lr,
    wd=args.wd,
    inner_wd=args.inner_wd,
    # model architecture
    embed_dim=args.embed_dim,
    hyper_hid=args.hyper_hid,
    n_hidden=args.n_hidden,
    n_kernels=args.nkernels,
    # runtime / bookkeeping
    device=device,
    eval_every=args.eval_every,
    save_path=args.save_path,
)

Files already downloaded and verified
Files already downloaded and verified


2024-01-27 18:07:34,898 - root - INFO - auto embedding size
Step: 1, Node ID: 26, Loss: 2.5047,  Acc: 0.0156:   0%|          | 0/5000 [00:01<?, ?it/s]2024-01-27 18:07:39,940 - root - INFO - 

Step: 1, AVG Loss: 2.3861,  AVG Acc: 0.0944 | HN Loss: 1.178234577178955
Step: 31, Node ID: 22, Loss: 2.5234,  Acc: 0.0000:   1%|          | 30/5000 [00:53<2:02:16,  1.48s/it]2024-01-27 18:08:31,350 - root - INFO - 

Step: 31, AVG Loss: 1.5649,  AVG Acc: 0.3725 | HN Loss: 0.5831138491630554
Step: 61, Node ID: 46, Loss: 4.2321,  Acc: 0.0000:   1%|          | 60/5000 [01:43<1:58:51,  1.44s/it]2024-01-27 18:09:22,144 - root - INFO - 

Step: 61, AVG Loss: 1.2509,  AVG Acc: 0.5719 | HN Loss: 0.8777075409889221
Step: 62, Node ID: 40, Loss: 0.5787,  Acc: 0.7500:   1%|          | 62/5000 [01:53<2:30:08,  1.82s/it]


KeyboardInterrupt: 