In [1]:
import os, sys
# Append the parent directory of the first `sys.path` entry to the import path.
# NOTE(review): presumably sys.path[0] is the notebook's own folder, making the
# project-local modules one level up importable — confirm for the kernel in use.
_first_path_entry = sys.path[0]
PROJECT_ROOT = os.path.abspath(os.path.dirname(_first_path_entry))
sys.path.append(PROJECT_ROOT)

from lsr_tensor import *
from lsr_bcd_regression import *
import torch
import torch.nn.functional as f
from datasets import *
from federated_algos import *
from matplotlib import pyplot as plt
import numpy as np
from federated_tests import *
from medmnist import VesselMNIST3D
import cProfile

In [2]:
# Tests: build a synthetic regression problem from a freshly constructed ground-truth
# LSR tensor, plus a validation set and a federated split of the training data.

# Seed the global RNGs so that "Restart Kernel -> Run All" reproduces the same data
# and model initialisations.
# NOTE(review): assumes the project helpers draw from torch's and/or NumPy's global
# RNG — confirm against their implementations.
SEED = 0
torch.manual_seed(SEED)
np.random.seed(SEED)

shape, ranks, separation_rank = (16, 16, 16), (2, 2, 2), 2
x_stdev = 1
y_stdev = 0.05
sample_size = 2000
val_sample_size = int(sample_size * 0.1)  # validation set is a tenth of the training size

with torch.no_grad():
    true_lsr = LSR_tensor_dot(shape, ranks, separation_rank)
    # L2-normalise the core tensor along dim 0 (written back in place via `out=`),
    # then rescale it by 5 / prod(ranks) ** (1/4).
    f.normalize(true_lsr.core_tensor, p=2, dim=0, out=true_lsr.core_tensor)
    true_lsr.core_tensor *= (5 / torch.sqrt(torch.sqrt(torch.prod(torch.tensor(ranks)))))

dataset = synthesize_data(true_lsr, sample_size, shape, x_stdev, y_stdev)
val_dataset = synthesize_data(true_lsr, val_sample_size, shape, x_stdev, y_stdev)
# Second argument is presumably the number of clients — verify in `federate_dataset`.
client_datasets = federate_dataset(dataset, 5)

In [None]:
# Vessel MNIST 3D: tensor configuration plus train / validation splits.
# NOTE(review): this rebinds `shape`, `ranks` and `separation_rank`, which the
# synthetic-data cells also read — re-run the synthetic setup cell before going
# back to those experiments.
shape = (28, 28, 28)
ranks = (3, 3, 3)
separation_rank = 2

vessel_dataset = VesselMNIST3D(split="train", download=True)
# Second argument is presumably the number of clients — verify in `federate_dataset`.
vessel_client_datasets = federate_dataset(vessel_dataset, 5)
vessel_val_dataset = VesselMNIST3D(split="val", download=True)

In [3]:
# Stepwise federated BCD on the synthetic client datasets, driven through `run_test`.
# NOTE(review): the output recorded for this cell ends in a RuntimeError about an
# incompatible `.view`; nothing in this cell calls `.view`, so it presumably comes
# from the project code invoked by `run_test` — confirm and fix there.
print("Stepwise federated algorithm training...")

hypers = dict(max_iter=10, batch_size=None, lr=0.001, momentum=0.9, steps=10)
loss_fn, aggregator_fn = f.mse_loss, avg_aggregation
lsr_dot_params = (shape, ranks, separation_rank, torch.float32, torch.device('cpu'))

# Positional arguments forwarded to `run_test` after its two leading values (1, 1).
args = (
    BCD_federated_stepwise,
    lsr_dot_params,
    client_datasets,
    val_dataset,
    hypers,
    loss_fn,
    aggregator_fn,
    False,
)

avg_loss, error = run_test(1, 1, *args)

Stepwise federated algorithm training...
Run 0
0 0
0 1


RuntimeError: view size is not compatible with input tensor's size and stride (at least one dimension spans across two contiguous subspaces). Use .reshape(...) instead.

In [None]:
# Stepwise federated BCD on the Vessel MNIST 3D client datasets, driven through `run_test`.
# NOTE(review): reads whatever `shape`, `ranks` and `separation_rank` are currently
# bound — for this data they should be the values set in the Vessel MNIST 3D cell,
# so make sure that cell ran last.
print("Stepwise federated algorithm training...")

hypers = dict(max_iter=100, batch_size=None, lr=0.001, momentum=0.9, steps=10)
loss_fn, aggregator_fn = logistic_loss, avg_aggregation
lsr_dot_params = (shape, ranks, separation_rank, torch.double, torch.device('cpu'))

# Positional arguments forwarded to `run_test` after its two leading values (1, 1).
args = (
    BCD_federated_stepwise,
    lsr_dot_params,
    vessel_client_datasets,
    vessel_val_dataset,
    hypers,
    loss_fn,
    aggregator_fn,
    False,
)

avg_loss, error = run_test(1, 1, *args)

In [None]:
# Plot the `avg_loss` curve currently in scope with a shaded +/- `error` band.
iterations = np.arange(1, len(avg_loss) + 1)  # 1-based x positions
curve_colour = "#23648FFF"

plt.plot(iterations, avg_loss, label="Stepwise", color=curve_colour)
plt.fill_between(
    iterations,
    avg_loss - error,
    avg_loss + error,
    color=curve_colour,
    edgecolor="none",
    alpha=0.5,
)
plt.legend()
plt.xlabel("Iteration #")
plt.ylabel("Loss")
plt.title("Validation Loss over # of Iterations")
plt.show()

In [None]:
# Train every federated variant on the synthetic data, each from a freshly constructed
# LSR_tensor_dot, keeping the second value each routine returns (used as the
# validation-loss curve when plotting).
# NOTE(review): the stepwise run uses max_iter=10 while the other variants cover 50
# iterations in total — confirm this is intentional before comparing the curves.
loss_fn = f.mse_loss
aggregator_fn = avg_aggregation
# Optimiser settings shared by all four runs; only the iteration budget differs.
shared_hypers = {"batch_size": None, "lr": 0.001, "momentum": 0.9, "steps": 10}

print("Stepwise federated algorithm training...")
hypers = {"max_iter": 10, **shared_hypers}
init_lsr_dot = LSR_tensor_dot(shape, ranks, separation_rank)
_, stepwise_loss = BCD_federated_stepwise(
    init_lsr_dot, client_datasets, val_dataset, hypers, loss_fn, aggregator_fn,
    verbose=True,
)

print("\nSplit factors + core federated algorithm training...")
hypers = {"max_iter": 50, **shared_hypers}
init_lsr_dot = LSR_tensor_dot(shape, ranks, separation_rank)
_, full_factors_loss = BCD_federated_all_factors(
    init_lsr_dot, client_datasets, val_dataset, hypers, loss_fn, aggregator_fn,
    verbose=True, ortho_iteratively=True,
)

print("\n1 full iteration federated algorithm training...")
hypers = {"max_rounds": 50, "max_iter": 1, **shared_hypers}
init_lsr_dot = LSR_tensor_dot(shape, ranks, separation_rank)
_, full_1_iter_loss = BCD_federated_full_iteration(
    init_lsr_dot, client_datasets, val_dataset, hypers, loss_fn, aggregator_fn,
    verbose=True,
)

print("\n5 full iteration federated algorithm training...")
hypers = {"max_rounds": 10, "max_iter": 5, **shared_hypers}
init_lsr_dot = LSR_tensor_dot(shape, ranks, separation_rank)
_, full_5_iter_loss = BCD_federated_full_iteration(
    init_lsr_dot, client_datasets, val_dataset, hypers, loss_fn, aggregator_fn,
    verbose=True,
)

In [None]:
# Non-federated baseline: run `lsr_bcd_regression` on the full synthetic dataset and
# keep the "val_loss" entry of the diagnostics it returns.
init_lsr_dot = LSR_tensor_dot(shape, ranks, separation_rank)
bcd_options = dict(
    lr=0.001,
    momentum=0.9,
    step_epochs=10,
    max_iter=50,
    batch_size=None,
    threshold=1e-4,
    init_zero=False,
    ortho=True,
    verbose=True,
    true_param=None,
    val_dataset=val_dataset,
)
_, unfederated_diag = lsr_bcd_regression(f.mse_loss, dataset, init_lsr_dot, **bcd_options)
unfederated_loss = unfederated_diag["val_loss"]

In [None]:
# Overlay the validation-loss curves of every training variant on one figure.
# Each entry is (losses, legend label, x-axis stride). The stride of 5 spaces the
# last curve's points five iterations apart, presumably matching its max_iter of 5
# per round.
loss_curves = [
    (unfederated_loss, "Unfederated", 1),
    (stepwise_loss, "Stepwise", 1),
    (full_factors_loss, "Full Factors + Core", 1),
    (full_1_iter_loss, "Full 1 Iteration", 1),
    (full_5_iter_loss, "Full 5 Iterations", 5),
]
for curve_losses, curve_label, x_stride in loss_curves:
    x_positions = (np.arange(len(curve_losses)) + 1) * x_stride
    plt.plot(x_positions, curve_losses, label=curve_label)

plt.legend()
plt.xlabel("Iteration #")
plt.ylabel("Loss")
plt.title("Validation Loss over # of Iterations")
plt.show()

In [None]:
# Performance testing: profile a single stepwise federated run and print the
# collected statistics sorted by `tottime`.
print("Stepwise federated algorithm training...")
hypers = {"max_iter": 100, "batch_size": None, "lr": 0.001, "momentum": 0.9, "steps": 10}
loss_fn = f.mse_loss
aggregator_fn = avg_aggregation

# Prefer the GPU, but fall back to the CPU so the cell still runs on machines
# without CUDA instead of failing when the tensor is constructed.
# NOTE(review): CUDA kernels may execute asynchronously, so host-side timings need
# not reflect device time — confirm before drawing conclusions from GPU profiles.
profile_device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
lsr_dot_params = (shape, ranks, separation_rank, torch.float32, profile_device)
init_lsr_dot = LSR_tensor_dot(*lsr_dot_params)

# Profile the call object directly instead of exec-ing a source string: the
# arguments are resolved in this cell rather than looked up by name in `__main__`.
profiler = cProfile.Profile()
try:
    profiler.runcall(BCD_federated_stepwise, init_lsr_dot, client_datasets, val_dataset,
                     hypers, loss_fn, aggregator_fn, False)
finally:
    # Like `cProfile.run`, report whatever was collected even if the run raised.
    profiler.print_stats(sort='tottime')