diff --git a/.github/workflows/setup.yml b/.github/workflows/setup.yml index cdbbc11..5379770 100644 --- a/.github/workflows/setup.yml +++ b/.github/workflows/setup.yml @@ -14,7 +14,7 @@ jobs: strategy: fail-fast: false matrix: - python-version: ["3.11"] + python-version: ["3.12"] steps: - uses: actions/checkout@v4 @@ -30,24 +30,14 @@ jobs: - name: Run tests run: | python -m unittest - - name: Install benchmark dependencies - run: | - pip install -r requirements_benchmark.txt - pip install -e . - - name: Run algorithms - run: | - python experiments/run_folktables.py alg=sslalm n_runs=2 run_maxtime=2 - python experiments/run_folktables.py alg=ghost n_runs=2 run_maxtime=2 - python experiments/run_folktables.py alg=alm n_runs=2 run_maxtime=2 - python experiments/run_folktables.py alg=sgd n_runs=2 run_maxtime=2 - + run-on-windows: name: Setup on windows runs-on: windows-latest strategy: fail-fast: false matrix: - python-version: ["3.11"] + python-version: ["3.12"] steps: - uses: actions/checkout@v4 @@ -63,16 +53,6 @@ jobs: - name: Run tests run: | python -m unittest - - name: Install benchmark dependencies - run: | - pip install -r requirements_benchmark.txt - pip install -e . - - name: Run algorithms - run: | - python experiments/run_folktables.py alg=sslalm n_runs=2 run_maxtime=2 - python experiments/run_folktables.py alg=ghost n_runs=2 run_maxtime=2 - python experiments/run_folktables.py alg=alm n_runs=2 run_maxtime=2 - python experiments/run_folktables.py alg=sgd n_runs=2 run_maxtime=2 run-on-macos: name: Setup on macos @@ -80,7 +60,7 @@ jobs: strategy: fail-fast: false matrix: - python-version: ["3.11"] + python-version: ["3.12"] steps: - uses: actions/checkout@v4 @@ -95,15 +75,4 @@ jobs: pip install -e . - name: Run tests run: | - python -m unittest - # - name: Install benchmark dependencies - # run: | - # pip install -U --force-reinstall certifi - # pip install -r requirements_benchmark.txt - # pip install -e . - # - name: Run algorithms - # run: | - # python experiments/run_folktables.py alg=sslalm n_runs=2 run_maxtime=2 - # python experiments/run_folktables.py alg=ghost n_runs=2 run_maxtime=2 - # python experiments/run_folktables.py alg=alm n_runs=2 run_maxtime=2 - # python experiments/run_folktables.py alg=sgd n_runs=2 run_maxtime=2 + python -m unittest \ No newline at end of file diff --git a/.gitignore b/.gitignore index a6272e4..8f0bd26 100644 --- a/.gitignore +++ b/.gitignore @@ -17,6 +17,7 @@ requirements_rci.txt benchmark/results benchmark/cache benchmark/data +data/ # Byte-compiled / optimized / DLL files __pycache__/ diff --git a/PDEs/AL-PINNs b/PDEs/AL-PINNs new file mode 160000 index 0000000..df4378c --- /dev/null +++ b/PDEs/AL-PINNs @@ -0,0 +1 @@ +Subproject commit df4378cb7688d9a6e7ee8d97850e7a92fbbe0192 diff --git a/PDEs/Helmholtz/Helmholtz.py b/PDEs/Helmholtz/Helmholtz.py new file mode 100644 index 0000000..8f66813 --- /dev/null +++ b/PDEs/Helmholtz/Helmholtz.py @@ -0,0 +1,270 @@ +from tqdm import tqdm +import pickle as pkl +import numpy as np +import copy +import sys +import argparse +import torch +from torch.autograd import Variable +import torch.nn as nn +import torch.nn.functional as F +import torch.optim as optim +from torch.utils.data import DataLoader, TensorDataset +from humancompatible.train.dual_optim import ALM, MoreauEnvelope, PBM + +from networks import set_model, u_Net_shallow_wide, u_Net_shallow_wide_resnet, u_Net_deep_narrow, u_Net_deep_narrow_resnet + +# Equation parameter +k, a1, a2 = 1, 1, 4 + +def q(data) : + x, y = data[:,0].view(-1,1), data[:,1].view(-1,1) + return -((a1*np.pi)**2)*torch.sin(a1*np.pi*x)*torch.sin(a2*np.pi*y) \ + -((a2*np.pi)**2)*torch.sin(a1*np.pi*x)*torch.sin(a2*np.pi*y) \ + +(k**2)*torch.sin(a1*np.pi*x)*torch.sin(a2*np.pi*y) + +def analytic(data) : + x, y = data[:,0].view(-1,1), data[:,1].view(-1,1) + return torch.sin(a1*np.pi*x)*torch.sin(a2*np.pi*y) + +def calculate_derivative(y, x) : + return torch.autograd.grad(y, x, create_graph=True,\ + grad_outputs=torch.ones(y.size()).to(device))[0] + + +def calculate_all_partial(u, x) : + del_u = calculate_derivative(u, x) + u_x, u_y = del_u[:,0], del_u[:,1] + u_xx = calculate_derivative(u_x, x)[:,0] + u_yy = calculate_derivative(u_y, x)[:,1] + return u_xx.view(-1,1), u_yy.view(-1,1) + + +def train(u_model, beta, trainloader, bdry_data, val_test, optimizer, loss_f, dual_opt=None) : + loss_list, loss_list1, loss_list2, val_list, test_list = [], [], [], [], [] + X_bdry, u_bdry = bdry_data + X_val, y_val, X_test, y_test = val_test + + for i, (data,) in enumerate(trainloader) : + u_model.train() + optimizer.zero_grad() + X_v = Variable(data, requires_grad=True).to(device) + output = u_model(X_v) + output_bdry = u_model(X_bdry) + + u_xx, u_yy = calculate_all_partial(output, X_v) + loss1 = loss_f(u_xx + u_yy + (k**2)*output - q(X_v), torch.zeros_like(output)) + constraint = loss_f(output_bdry, torch.zeros_like(output_bdry)) + + # adam optimizer + if dual_opt is None : + loss = loss1 + beta*constraint + loss.backward() + optimizer.step() + optimizer.zero_grad() + + elif dual_opt is not None: + threshold = 0.01 + constraint = constraint - threshold + + # compute the lagrangian value + lagrangian = dual_opt.forward_update(loss1, constraint.unsqueeze(0)) + lagrangian.backward() + optimizer.step() + optimizer.zero_grad() + + + u_model.eval() + val_err = torch.linalg.norm((u_model(X_val) - y_val),2).item() / torch.linalg.norm(y_val,2).item() + test_err = torch.linalg.norm((u_model(X_test) - y_test),2).item() / torch.linalg.norm(y_test,2).item() + + loss_list.append((loss1+constraint).item()) + loss_list1.append(loss1.item()) + loss_list2.append(constraint.item()) + val_list.append(val_err) + test_list.append(test_err) + + + + +def main_function(model_name, beta, lr, EPOCH, device) : + + # Dataset Creation + xmin, xmax = -1,1 + ymin, ymax = -1,1 + Nx, Ny = 51, 51 + X_train = torch.FloatTensor(np.mgrid[xmin:xmax:51j, ymin:ymax:51j].reshape(2, -1).T).to(device) + + # Boundary Conditions + X_bdry = X_train[(X_train[:,0]==xmin) + (X_train[:,0]==xmax) + (X_train[:,1]==ymin) + (X_train[:,1]==ymax)] + u_bdry = torch.zeros_like(X_bdry[:,0]).to(device).view(-1,1) + + X_test, y_test, X_val, y_val= torch.load('./PDEs/Helmholtz/Helmholtz_test', map_location=device) + + # take 1000 samples from the validation set + idx = np.random.choice(X_val.shape[0], 1000, replace=False) + X_val = X_val[idx] + y_val = y_val[idx] + + print(X_train.shape) + print(X_val.shape) + exit() + + # Make dataloader + data_train = TensorDataset(X_train) + train_loader = DataLoader(data_train, batch_size=10000, shuffle=False) + + # train + torch.manual_seed(0) + total_loss, test_errs, val_errs, constraints = [], [], [], [] + u_model = set_model(model_name, device) + optimizer=torch.optim.Adam([{'params': u_model.parameters()}], lr=lr) + best_model = copy.deepcopy(u_model) + + # unconstrained ADAM + for t in tqdm(range(0, EPOCH)) : + + loss, loss1, loss2, val_err, test_err = train(u_model, beta, trainloader=train_loader,\ + bdry_data=[X_bdry, u_bdry],\ + val_test = [X_val, y_val, X_test, y_test],\ + optimizer=optimizer, loss_f=nn.MSELoss()) + + val_errs.append(val_err) + test_errs.append(test_err) + total_loss.append(loss) + constraints.append(loss2) + + #Print Log + if t%100 == 0 : + print("%s/%s | loss: %06.6f | loss_f: %06.6f | loss_u: %06.6f | val error : %06.6f | test error : %06.6f " % \ + (t, EPOCH, loss, loss1, loss2, val_err, test_err)) + + if np.argmin(val_errs) == t : + best_model = copy.deepcopy(u_model) + + + # SPBM + torch.manual_seed(0) + total_loss_spbm, test_errs_spbm, val_errs_spbm, constraints_spbm = [], [], [], [] + u_model = set_model(model_name, device) + + # Define data and optimizers + optimizer = MoreauEnvelope(torch.optim.Adam([{'params': u_model.parameters()}], lr=lr), mu=2.0) + + dual = PBM( + m=1, + # penalty_update='dimin', + # penalty_update='dimin_adapt', + penalty_update='const', + pbf = 'quadratic_logarithmic', + gamma=0.1, + init_duals=0.01, + init_penalties=1., + penalty_range=(0.5, 1.), + penalty_mult=0.99, + dual_range=(0.01, 100.), + delta=1.0, + device=device + ) + + for t in tqdm(range(0, EPOCH)) : + + loss, loss1, loss2, val_err, test_err = train(u_model, beta, trainloader=train_loader,\ + bdry_data=[X_bdry, u_bdry],\ + val_test = [X_val, y_val, X_test, y_test],\ + optimizer=optimizer, loss_f=nn.MSELoss(), + dual_opt=dual) + + val_errs_spbm.append(val_err) + test_errs_spbm.append(test_err) + total_loss_spbm.append(loss) + constraints_spbm.append(loss2) + + #Print Log + if t%100 == 0 : + print("%s/%s | loss: %06.6f | loss_f: %06.6f | loss_u: %06.6f | val error : %06.6f | test error : %06.6f " % \ + (t, EPOCH, loss, loss1, loss2, val_err, test_err)) + + if np.argmin(val_errs_spbm) == t : + best_model = copy.deepcopy(u_model) + + # ALM + torch.manual_seed(0) + total_loss_alm, test_errs_alm, val_errs_alm, constraints_alm = [], [], [], [] + u_model = set_model(model_name, device) + + # Define data and optimizers + optimizer = MoreauEnvelope(torch.optim.Adam([{'params': u_model.parameters()}], lr=lr), mu=2.0) + + dual = ALM( + m=1, + lr=0.1, + momentum=0.5, + device=device + ) + + for t in tqdm(range(0, EPOCH)) : + + loss, loss1, loss2, val_err, test_err = train(u_model, beta, trainloader=train_loader,\ + bdry_data=[X_bdry, u_bdry],\ + val_test = [X_val, y_val, X_test, y_test],\ + optimizer=optimizer, loss_f=nn.MSELoss(), + dual_opt=dual) + + val_errs_alm.append(val_err) + test_errs_alm.append(test_err) + total_loss_alm.append(loss) + constraints_alm.append(loss2) + + #Print Log + if t%100 == 0 : + print("%s/%s | loss: %06.6f | loss_f: %06.6f | loss_u: %06.6f | val error : %06.6f | test error : %06.6f " % \ + (t, EPOCH, loss, loss1, loss2, val_err, test_err)) + + if np.argmin(val_errs_spbm) == t : + best_model = copy.deepcopy(u_model) + + + # plot the resultsimport matplotlib.pyplot as plt + import matplotlib.pyplot as plt + fig, axes = plt.subplots(1, 3, figsize=(15, 4)) # wider figure + + axes[0].plot(total_loss, label='Adam') + axes[0].plot(total_loss_spbm, label='SPBM') + axes[0].plot(total_loss_alm, label='ALM') + axes[0].set_xlabel('Epoch') + axes[0].set_ylabel('Train Total Loss') + axes[0].legend() + + axes[1].plot(test_errs, label='Adam') + axes[1].plot(test_errs_spbm, label='SPBM') + axes[1].plot(test_errs_alm, label='ALM') + axes[1].set_xlabel('Epoch') + axes[1].set_ylabel('Test Error') + axes[1].legend() + + axes[2].plot(constraints, label='Adam') + axes[2].plot(constraints_spbm, label='SPBM') + axes[2].plot(constraints_alm, label='ALM') + axes[2].set_xlabel('Epoch') + axes[2].set_ylabel('Boundary Constraint Violation') + axes[2].legend() + + plt.tight_layout(pad=2.0) # extra padding between subplots + plt.savefig('./PDEs/Helmholtz/results.png', bbox_inches='tight') + + +if __name__ == "__main__": + + parser = argparse.ArgumentParser() + parser.add_argument('--model', default='deep_narrow', help='Specify the model. Choose one of [deep_narrow, shallow_wide, deep_narrow_resent, shallow_wide_resnet].') + parser.add_argument('--beta', default=1, type=float, help='Penalty parameter beta') + parser.add_argument('--lr', default=1e-4, type=float, help='Learning rate') + parser.add_argument('--EPOCH', default=4000, type=int, help='Number of training EPOCH') + parser.add_argument('--ordinal', default=0, type=int, help='Specify the cuda device ordinal.') + args = parser.parse_args() + + device = torch.device("cuda:{}".format(args.ordinal) if torch.cuda.is_available() else "cpu") + + main_function(args.model, args.beta, args.lr, args.EPOCH, device) + \ No newline at end of file diff --git a/PDEs/Helmholtz/Helmholtz_AL-PINNs.py b/PDEs/Helmholtz/Helmholtz_AL-PINNs.py new file mode 100644 index 0000000..41c3e15 --- /dev/null +++ b/PDEs/Helmholtz/Helmholtz_AL-PINNs.py @@ -0,0 +1,143 @@ +from tqdm import tqdm +import pickle as pkl +import numpy as np +import copy +import argparse +import sys +sys.path.append("..") + +import torch +from torch.autograd import Variable +import torch.nn as nn +import torch.nn.functional as F +import torch.optim as optim +from torch.utils.data import DataLoader, TensorDataset + +from networks import set_model, u_Net_shallow_wide, u_Net_shallow_wide_resnet, u_Net_deep_narrow, u_Net_deep_narrow_resnet + +# Equation parameter +k, a1, a2 = 1, 1, 4 + +def q(data) : + x, y = data[:,0].view(-1,1), data[:,1].view(-1,1) + return -((a1*np.pi)**2)*torch.sin(a1*np.pi*x)*torch.sin(a2*np.pi*y) \ + -((a2*np.pi)**2)*torch.sin(a1*np.pi*x)*torch.sin(a2*np.pi*y) \ + +(k**2)*torch.sin(a1*np.pi*x)*torch.sin(a2*np.pi*y) + +def analytic(data) : + x, y = data[:,0].view(-1,1), data[:,1].view(-1,1) + return torch.sin(a1*np.pi*x)*torch.sin(a2*np.pi*y) + +def calculate_derivative(y, x) : + return torch.autograd.grad(y, x, create_graph=True,\ + grad_outputs=torch.ones(y.size()).to(device))[0] + + +def calculate_all_partial(u, x) : + del_u = calculate_derivative(u, x) + u_x, u_y = del_u[:,0], del_u[:,1] + u_xx = calculate_derivative(u_x, x)[:,0] + u_yy = calculate_derivative(u_y, x)[:,1] + return u_xx.view(-1,1), u_yy.view(-1,1) + + + +def train(u_model, beta, lbd, trainloader, bdry_data, val_test, optimizer, loss_f) : + loss_list, loss_list1, loss_list2, val_list, test_list = [], [], [], [], [] + X_bdry, u_bdry = bdry_data + X_val, y_val, X_test, y_test = val_test + + for i, (data,) in enumerate(trainloader) : + u_model.train() + optimizer.zero_grad() + X_v = Variable(data, requires_grad=True).to(device) + output = u_model(X_v) + output_bdry = u_model(X_bdry) + + u_xx, u_yy = calculate_all_partial(output, X_v) + loss1 = loss_f(u_xx + u_yy + (k**2)*output - q(X_v), torch.zeros_like(output)) + loss2 = loss_f(output_bdry, torch.zeros_like(output_bdry)) + + loss = loss1 + beta*loss2 + (lbd*output_bdry.view(-1)).mean() + loss.backward() + lbd.grad *= -1 + + optimizer.step() + + u_model.eval() + val_err = torch.linalg.norm((u_model(X_val) - y_val),2).item() / torch.linalg.norm(y_val,2).item() + test_err = torch.linalg.norm((u_model(X_test) - y_test),2).item() / torch.linalg.norm(y_test,2).item() + + loss_list.append((loss1+loss2).item()) + loss_list1.append(loss1.item()) + loss_list2.append(loss2.item()) + val_list.append(val_err) + test_list.append(test_err) + + return np.mean(loss_list), np.mean(loss_list1), np.mean(loss_list2), np.mean(val_list), np.mean(test_list) + + +def main_function(model_name, beta, lr, lbd_lr, EPOCH, device) : + + # Dataset Creation + xmin, xmax = -1,1 + ymin, ymax = -1,1 + Nx, Ny = 51, 51 + X_train = torch.FloatTensor(np.mgrid[xmin:xmax:51j, ymin:ymax:51j].reshape(2, -1).T).to(device) + + # Boundary Conditions + X_bdry = X_train[(X_train[:,0]==xmin) + (X_train[:,0]==xmax) + (X_train[:,1]==ymin) + (X_train[:,1]==ymax)] + u_bdry = torch.zeros_like(X_bdry[:,0]).to(device).view(-1,1) + + X_test, y_test, X_val, y_val= torch.load('Helmholtz_test', map_location=device) + + # Make dataloader + data_train = TensorDataset(X_train) + train_loader = DataLoader(data_train, batch_size=10000, shuffle=False) + + # train + total_loss, test_errs, val_errs = [], [], [] + u_model = set_model(model_name, device) + lbd = Variable(torch.FloatTensor([0]*X_bdry.size()[0]).to(device), requires_grad=True) + + optimizer=torch.optim.Adam([{'params': u_model.parameters()}, {'params': lbd, 'lr':lbd_lr}], lr=lr) + best_model = copy.deepcopy(u_model) + + for t in tqdm(range(0, EPOCH)) : + + loss, loss1, loss2, val_err, test_err = train(u_model, beta, lbd, trainloader=train_loader,\ + bdry_data=[X_bdry, u_bdry],\ + val_test = [X_val, y_val, X_test, y_test],\ + optimizer=optimizer, loss_f=nn.MSELoss()) + + val_errs.append(val_err) + test_errs.append(test_err) + total_loss.append(loss) + +# # Print Log +# if t%100 == 0 : +# print("%s/%s | loss: %06.6f | loss_f: %06.6f | loss_u: %06.6f | val error : %06.6f | test error : %06.6f " % \ +# (t, EPOCH, loss, loss1, loss2, val_err, test_err)) + + if np.argmin(val_errs) == t : + best_model = copy.deepcopy(u_model) + + return best_model, total_loss, val_errs, test_errs + + +if __name__ == "__main__": + + parser = argparse.ArgumentParser() + parser.add_argument('--model', default='deep_narrow', help='Specify the model. Choose one of [deep_narrow, shallow_wide, deep_narrow_resent, shallow_wide_resnet].') + parser.add_argument('--beta', default=1000, type=float, help='Penalty parameter beta') + parser.add_argument('--lr', default=1e-3, type=float, help='Learning rate') + parser.add_argument('--lbd_lr', default=1, type=float, help='Learning rate for lambda') + parser.add_argument('--EPOCH', default=10000, type=int, help='Number of training EPOCH') + parser.add_argument('--ordinal', default=0, type=int, help='Specify the cuda device ordinal.') + args = parser.parse_args() + + device = torch.device("cuda:{}".format(args.ordinal) if torch.cuda.is_available() else "cpu") + + best_model, total_loss, val_errs, test_errs = main_function(args.model, args.beta, args.lr, args.lbd_lr, args.EPOCH, device) + print('Best Test Error : ', test_errs[np.argmin(val_errs)]) + \ No newline at end of file diff --git a/PDEs/Helmholtz/Helmholtz_test b/PDEs/Helmholtz/Helmholtz_test new file mode 100644 index 0000000..a572d05 Binary files /dev/null and b/PDEs/Helmholtz/Helmholtz_test differ diff --git a/PDEs/Helmholtz/networks.py b/PDEs/Helmholtz/networks.py new file mode 100644 index 0000000..0ccef8b --- /dev/null +++ b/PDEs/Helmholtz/networks.py @@ -0,0 +1,99 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F + +def set_model(model, device) : + if model =='deep_narrow_resnet' : + u_model = u_Net_deep_narrow_resnet().to(device) + elif model == 'shallow_wide_resnet' : + u_model = u_Net_shallow_wide_resnet().to(device) + elif model == 'deep_narrow' : + u_model = u_Net_deep_narrow().to(device) + elif model == 'shallow_wide' : + u_model = u_Net_shallow_wide().to(device) + return u_model + +class u_Net_shallow_wide(nn.Module): + def __init__(self): + super(u_Net_shallow_wide, self).__init__() + self.fc1 = nn.Linear(2, 256) + self.fc2 = nn.Linear(256, 256) +# self.fc3 = nn.Linear(256, 256) +# self.fc4 = nn.Linear(256, 256) + self.fc5 = nn.Linear(256, 1) + self.act1 = nn.Tanh() + def forward(self, x): + x = self.act1(self.fc1(x)) + x = self.act1(self.fc2(x)) +# x = self.act1(self.fc3(x)) +# x = self.act1(self.fc4(x)) + x = self.fc5(x) + return x + + +class u_Net_shallow_wide_resnet(nn.Module): + def __init__(self): + super(u_Net_shallow_wide_resnet, self).__init__() + self.fc1 = nn.Linear(2, 256) + self.fc2 = nn.Linear(256, 256) +# self.fc3 = nn.Linear(256, 256) +# self.fc4 = nn.Linear(256, 256) + self.fc5 = nn.Linear(256, 1) + self.act1 = nn.Tanh() + def forward(self, x): + x = self.act1(self.fc1(x)) + x = self.act1(self.fc2(x))+x +# x = self.act1(self.fc3(x))+x +# x = self.act1(self.fc4(x))+x + x = self.fc5(x) + return x + +class u_Net_deep_narrow(nn.Module): + def __init__(self): + super(u_Net_deep_narrow, self).__init__() + self.fc1 = nn.Linear(2, 64) + self.fc2 = nn.Linear(64, 64) + self.fc3 = nn.Linear(64, 64) + self.fc4 = nn.Linear(64, 64) + self.fc5 = nn.Linear(64, 64) + self.fc6 = nn.Linear(64, 64) + self.fc7 = nn.Linear(64, 64) + self.fc8 = nn.Linear(64, 64) + self.fc9 = nn.Linear(64, 1) + self.act1 = nn.Tanh() + def forward(self, x): + x = self.act1(self.fc1(x)) + x = self.act1(self.fc2(x)) + x = self.act1(self.fc3(x)) + x = self.act1(self.fc4(x)) + x = self.act1(self.fc5(x)) + x = self.act1(self.fc6(x)) + x = self.act1(self.fc7(x)) + x = self.act1(self.fc8(x)) + x = self.fc9(x) + return x + +class u_Net_deep_narrow_resnet(nn.Module): + def __init__(self): + super(u_Net_deep_narrow_resnet, self).__init__() + self.fc1 = nn.Linear(2, 64) + self.fc2 = nn.Linear(64, 64) + self.fc3 = nn.Linear(64, 64) + self.fc4 = nn.Linear(64, 64) + self.fc5 = nn.Linear(64, 64) + self.fc6 = nn.Linear(64, 64) + self.fc7 = nn.Linear(64, 64) + self.fc8 = nn.Linear(64, 64) + self.fc9 = nn.Linear(64, 1) + self.act1 = nn.Tanh() + def forward(self, x): + x = self.act1(self.fc1(x)) + x = self.act1(self.fc2(x))+x + x = self.act1(self.fc3(x))+x + x = self.act1(self.fc4(x))+x + x = self.act1(self.fc5(x))+x + x = self.act1(self.fc6(x))+x + x = self.act1(self.fc7(x))+x + x = self.act1(self.fc8(x))+x + x = self.fc9(x) + return x \ No newline at end of file diff --git a/PDEs/Helmholtz/results.png b/PDEs/Helmholtz/results.png new file mode 100644 index 0000000..f18e26c Binary files /dev/null and b/PDEs/Helmholtz/results.png differ diff --git a/PDEs/Klein-Gordon/Klein-Gordon_AL-PINNs.py b/PDEs/Klein-Gordon/Klein-Gordon_AL-PINNs.py new file mode 100644 index 0000000..9fabcf5 --- /dev/null +++ b/PDEs/Klein-Gordon/Klein-Gordon_AL-PINNs.py @@ -0,0 +1,175 @@ +from tqdm import tqdm +import pickle as pkl +import numpy as np +import copy +import argparse +import sys +sys.path.append("..") + +import torch +from torch.autograd import Variable +import torch.nn as nn +import torch.nn.functional as F +import torch.optim as optim +from torch.utils.data import DataLoader, TensorDataset + +from networks import set_model, u_Net_shallow_wide, u_Net_shallow_wide_resnet, u_Net_deep_narrow, u_Net_deep_narrow_resnet + +# Equation parameter + +k=3 +alpha, delta, gamma = -1, 0, 1 + +def analytic(bdry) : + t, x = bdry[:,0].view(-1,1), bdry[:,1].view(-1,1) + return x*torch.cos(5*np.pi*t) + ((x*t)**3) + +def u_tt(data) : + t, x = data[:,0].view(-1,1), data[:,1].view(-1,1) + return -((5*np.pi)**2)*x*torch.cos(5*np.pi*t) + 6*(x**3)*t + +def u_xx(data) : + t, x = data[:,0].view(-1,1), data[:,1].view(-1,1) + return 6*x*(t**3) + +def u3(data) : + t, x = data[:,0].view(-1,1), data[:,1].view(-1,1) + return (x*torch.cos(5*np.pi*t) + ((x*t)**3))**3 + +def u(data) : + t, x = data[:,0].view(-1,1), data[:,1].view(-1,1) + return x*torch.cos(5*np.pi*t) + ((x*t)**3) + +def f(data) : + return u_tt(data) + alpha*u_xx(data) + delta*u(data) + gamma*u3(data) + +def calculate_derivative(y, x) : + return torch.autograd.grad(y, x, create_graph=True,\ + grad_outputs=torch.ones(y.size()).to(device))[0] + +def calculate_all_partial(u, x) : + del_u = calculate_derivative(u, x) + u_t, u_x = del_u[:,0], del_u[:,1] + u_tt = calculate_derivative(u_t, x)[:,0] + u_xx = calculate_derivative(u_x, x)[:,1] + return u_tt.view(-1,1), u_xx.view(-1,1) + +def train(u_model, lbd1, lbd2, lbd3, beta, trainloader, ini_bdry_data, val_test, optimizer, loss_f) : + loss_list, loss_list1, loss_list2, loss_list3, loss_list4, val_list, test_list = [], [], [], [], [], [], [] + X_ini, u_ini, u_ini_t, X_bdry, u_bdry = ini_bdry_data + X_val, y_val, X_test, y_test = val_test + + for i, (data,) in enumerate(trainloader) : + u_model.train() + optimizer.zero_grad() + X_v = Variable(data, requires_grad=True).to(device) + output = u_model(X_v) + output_ini = u_model(X_ini) + output_ini_t = calculate_derivative(output_ini, X_ini)[:,0].view(-1,1) + output_bdry = u_model(X_bdry) + + u_tt, u_xx = calculate_all_partial(output, X_v) + loss1 = loss_f(u_tt + alpha*u_xx + delta*output + gamma*(output**k) - f(X_v), torch.zeros_like(output)) + loss2 = loss_f(output_ini, u_ini) + loss3 = loss_f(output_ini_t, u_ini_t) + loss4 = loss_f(output_bdry, u_bdry) + + loss = loss1 + beta*loss2 + (lbd1*(output_ini-u_ini).view(-1)).mean() +\ + beta*loss3 + (lbd2*output_ini_t.view(-1)).mean() +\ + beta*loss4 + (lbd3*(output_bdry-u_bdry).view(-1)).mean() + + loss.backward() + + lbd1.grad *= -1 + lbd2.grad *= -1 + lbd3.grad *= -1 + optimizer.step() + + u_model.eval() + val_err = torch.linalg.norm((u_model(X_val) - y_val),2).item() / torch.linalg.norm(y_val,2).item() + test_err = torch.linalg.norm((u_model(X_test) - y_test),2).item() / torch.linalg.norm(y_test,2).item() + + loss_list.append((loss1+loss2+loss3+loss4).item()) + loss_list1.append(loss1.item()) + loss_list2.append(loss2.item()) + loss_list3.append(loss3.item()) + loss_list4.append(loss4.item()) + val_list.append(val_err) + test_list.append(test_err) + + return np.mean(loss_list), np.mean(loss_list1), np.mean(loss_list2),\ + np.mean(loss_list3), np.mean(loss_list4), np.mean(val_list), np.mean(test_list) + + +def main_function(model_name, beta, lr, lbd_lr, EPOCH, device) : + print(model_name, beta, lr, lbd_lr, EPOCH, device) + # Dataset Creation + tmin, tmax = 0,1 + xmin, xmax = 0,1 + Nt, Nx = 51, 51 + X_train = torch.FloatTensor(np.mgrid[tmin:tmax:51j, xmin:xmax:51j].reshape(2, -1).T).to(device) + + # Initial Conditions + X_ini = Variable(X_train[X_train[:,0]==tmin].to(device), requires_grad=True) + u_ini = X_ini.detach()[:,1].view(-1,1) + u_ini_t = torch.zeros_like(u_ini) + + # Boundary Conditions + X_bdry = X_train[(X_train[:,1]==xmin) + (X_train[:,1]==xmax)] + u_bdry = analytic(X_bdry) + + # Validation & Test Set + X_test, y_test, X_val, y_val= torch.load('Klein-Gordon_test', map_location=device) + + # Make dataloader + data_train = TensorDataset(X_train) + train_loader = DataLoader(data_train, batch_size=10000, shuffle=False) + + # train + total_loss, test_errs, val_errs = [], [], [] + u_model = set_model(model_name, device) + lbd1 = Variable(torch.FloatTensor([0]*X_ini.size()[0]).to(device), requires_grad=True)# + lbd2 = Variable(torch.FloatTensor([0]*X_ini.size()[0]).to(device), requires_grad=True)# + lbd3 = Variable(torch.FloatTensor([0]*X_bdry.size()[0]).to(device), requires_grad=True)# + + optimizer=torch.optim.Adam([{'params': u_model.parameters()}, {'params': lbd1, 'lr':lbd_lr}, \ + {'params': lbd2, 'lr':lbd_lr}, {'params': lbd3, 'lr':lbd_lr}], lr=lr) + best_model = copy.deepcopy(u_model) + + for t in tqdm(range(0, EPOCH)) : + + loss, loss1, loss2, loss3, loss4, val_err, test_err = train(u_model, lbd1, lbd2, lbd3, beta, trainloader=train_loader,\ + ini_bdry_data=[X_ini, u_ini, u_ini_t, X_bdry, u_bdry],\ + val_test = [X_val, y_val, X_test, y_test],\ + optimizer=optimizer, loss_f=nn.MSELoss()) + + val_errs.append(val_err) + test_errs.append(test_err) + total_loss.append(loss) + +# # Print Log +# if t%100 == 0 : +# print("%s/%s | loss: %06.6f | loss_f: %06.6f | loss_u: %06.6f | val error : %06.6f | test error : %06.6f " % \ +# (t, EPOCH, loss, loss1, loss2+loss3+loss4, val_err, test_err)) + + if np.argmin(val_errs) == t : + best_model = copy.deepcopy(u_model) + + return best_model, total_loss, val_errs, test_errs + + +if __name__ == "__main__": + + parser = argparse.ArgumentParser() + parser.add_argument('--model', default='deep_narrow', help='Specify the model. Choose one of [deep_narrow, shallow_wide, deep_narrow_resent, shallow_wide_resnet].') + parser.add_argument('--beta', default=500, type=float, help='Penalty parameter beta') + parser.add_argument('--lr', default=1e-3, type=float, help='Learning rate') + parser.add_argument('--lbd_lr', default=1e-1, type=float, help='Learning rate for lambda') + parser.add_argument('--EPOCH', default=10000, type=int, help='Number of training EPOCH') + parser.add_argument('--ordinal', default=0, type=int, help='Specify the cuda device ordinal.') + args = parser.parse_args() + + device = torch.device("cuda:{}".format(args.ordinal) if torch.cuda.is_available() else "cpu") + + best_model, total_loss, val_errs, test_errs = main_function(args.model, args.beta, args.lr, args.lbd_lr, args.EPOCH, device) + print('Best Test Error : ', test_errs[np.argmin(val_errs)]) \ No newline at end of file diff --git a/PDEs/Klein-Gordon/Klein-Gordon_test b/PDEs/Klein-Gordon/Klein-Gordon_test new file mode 100644 index 0000000..ba06527 Binary files /dev/null and b/PDEs/Klein-Gordon/Klein-Gordon_test differ diff --git a/PDEs/Klein-Gordon/Klein_Gordon.py b/PDEs/Klein-Gordon/Klein_Gordon.py new file mode 100644 index 0000000..3453977 --- /dev/null +++ b/PDEs/Klein-Gordon/Klein_Gordon.py @@ -0,0 +1,286 @@ +from tqdm import tqdm +import pickle as pkl +import numpy as np +import copy +import argparse +import sys +sys.path.append("..") + +import torch +from torch.autograd import Variable +import torch.nn as nn +import torch.nn.functional as F +import torch.optim as optim +from torch.utils.data import DataLoader, TensorDataset + +from humancompatible.train.dual_optim import ALM, MoreauEnvelope, PBM + +from networks import set_model, u_Net_shallow_wide, u_Net_shallow_wide_resnet, u_Net_deep_narrow, u_Net_deep_narrow_resnet + +# Equation parameter + +k=3 +alpha, delta, gamma = -1, 0, 1 + +def analytic(bdry) : + t, x = bdry[:,0].view(-1,1), bdry[:,1].view(-1,1) + return x*torch.cos(5*np.pi*t) + ((x*t)**3) + +def u_tt(data) : + t, x = data[:,0].view(-1,1), data[:,1].view(-1,1) + return -((5*np.pi)**2)*x*torch.cos(5*np.pi*t) + 6*(x**3)*t + +def u_xx(data) : + t, x = data[:,0].view(-1,1), data[:,1].view(-1,1) + return 6*x*(t**3) + +def u3(data) : + t, x = data[:,0].view(-1,1), data[:,1].view(-1,1) + return (x*torch.cos(5*np.pi*t) + ((x*t)**3))**3 + +def u(data) : + t, x = data[:,0].view(-1,1), data[:,1].view(-1,1) + return x*torch.cos(5*np.pi*t) + ((x*t)**3) + +def f(data) : + return u_tt(data) + alpha*u_xx(data) + delta*u(data) + gamma*u3(data) + +def calculate_derivative(y, x) : + return torch.autograd.grad(y, x, create_graph=True,\ + grad_outputs=torch.ones(y.size()).to(device))[0] + +def calculate_all_partial(u, x) : + del_u = calculate_derivative(u, x) + u_t, u_x = del_u[:,0], del_u[:,1] + u_tt = calculate_derivative(u_t, x)[:,0] + u_xx = calculate_derivative(u_x, x)[:,1] + return u_tt.view(-1,1), u_xx.view(-1,1) + + +def train(u_model, beta, trainloader, ini_bdry_data, val_test, optimizer, loss_f, dual_opt=None) : + loss_list, loss_list1, loss_list2, loss_list3, loss_list4, val_list, test_list = [], [], [], [], [], [], [] + X_ini, u_ini, u_ini_t, X_bdry, u_bdry = ini_bdry_data + X_val, y_val, X_test, y_test = val_test + + for i, (data,) in enumerate(trainloader) : + u_model.train() + optimizer.zero_grad() + X_v = Variable(data, requires_grad=True).to(device) + output = u_model(X_v) + output_ini = u_model(X_ini) + output_ini_t = calculate_derivative(output_ini, X_ini)[:,0].view(-1,1) + output_bdry = u_model(X_bdry) + + u_tt, u_xx = calculate_all_partial(output, X_v) + loss1 = loss_f(u_tt + alpha*u_xx + delta*output + gamma*(output**k) - f(X_v), torch.zeros_like(output)) + loss2 = loss_f(output_ini, u_ini) + loss3 = loss_f(output_ini_t, u_ini_t) + loss4 = loss_f(output_bdry, u_bdry) + + if dual_opt is None: + loss = loss1 + beta*loss2 + beta*loss3 + beta*loss4 + loss.backward() + optimizer.step() + elif dual_opt is not None: + threshold = 0.1 + constraints = torch.stack([loss2, loss3, loss4], dim=0) + constraints = constraints - threshold + + # compute the lagrangian value + lagrangian = dual_opt.forward_update(loss1, constraints) + lagrangian.backward() + optimizer.step() + optimizer.zero_grad() + + u_model.eval() + val_err = torch.linalg.norm((u_model(X_val) - y_val),2).item() / torch.linalg.norm(y_val,2).item() + test_err = torch.linalg.norm((u_model(X_test) - y_test),2).item() / torch.linalg.norm(y_test,2).item() + + loss_list.append((loss1+loss2+loss3+loss4).item()) + loss_list1.append(loss1.item()) + loss_list2.append(loss2.item()) + loss_list3.append(loss3.item()) + loss_list4.append(loss4.item()) + val_list.append(val_err) + test_list.append(test_err) + + return np.mean(loss_list), np.mean(loss_list1), np.mean(loss_list2),\ + np.mean(loss_list3), np.mean(loss_list4), np.mean(val_list), np.mean(test_list) + + +def main_function(model_name, beta, lr, EPOCH, device) : + + # Dataset Creation + tmin, tmax = 0,1 + xmin, xmax = 0,1 + Nt, Nx = 51, 51 + X_train = torch.FloatTensor(np.mgrid[tmin:tmax:51j, xmin:xmax:51j].reshape(2, -1).T).to(device) + + # Initial Conditions + X_ini = Variable(X_train[X_train[:,0]==tmin].to(device), requires_grad=True) + u_ini = X_ini.detach()[:,1].view(-1,1) + u_ini_t = torch.zeros_like(u_ini) + + # Boundary Conditions + X_bdry = X_train[(X_train[:,1]==xmin) + (X_train[:,1]==xmax)] + u_bdry = analytic(X_bdry) + + # Validation & Test Set + X_test, y_test, X_val, y_val= torch.load('./PDEs/Klein-Gordon/Klein-Gordon_test', map_location=device) + + # take 1000 samples from the validation set + idx = np.random.choice(X_val.shape[0], 1000, replace=False) + X_val = X_val[idx] + y_val = y_val[idx] + + # Make dataloader + data_train = TensorDataset(X_train) + train_loader = DataLoader(data_train, batch_size=10000, shuffle=False) + + # train + torch.manual_seed(0) + total_loss, test_errs, val_errs, constraints = [], [], [], [] + u_model = set_model(model_name, device) + optimizer=torch.optim.Adam([{'params': u_model.parameters()}], lr=lr) + best_model = copy.deepcopy(u_model) + + # for t in tqdm(range(0, EPOCH)) : + + # loss, loss1, loss2, loss3, loss4, val_err, test_err = train(u_model, beta, trainloader=train_loader,\ + # ini_bdry_data=[X_ini, u_ini, u_ini_t, X_bdry, u_bdry],\ + # val_test = [X_val, y_val, X_test, y_test],\ + # optimizer=optimizer, loss_f=nn.MSELoss()) + + # val_errs.append(val_err) + # test_errs.append(test_err) + # total_loss.append(loss) + # constraints.append([loss2, loss3, loss4]) # append both costraint + + # # Print Log + # if t%100 == 0 : + # print("%s/%s | loss: %06.6f | loss_f: %06.6f | loss_u: %06.6f | val error : %06.6f | test error : %06.6f " % \ + # (t, EPOCH, loss, loss1, loss2+loss3+loss4, val_err, test_err)) + + + # SPBM + torch.manual_seed(0) + total_loss_spbm, test_errs_spbm, val_errs_spbm, constraints_spbm = [], [], [], [] + u_model = set_model(model_name, device) + + # Define data and optimizers + optimizer = MoreauEnvelope(torch.optim.Adam([{'params': u_model.parameters()}], lr=0.0005), mu=0.1) + + dual = PBM( + m=3, + # penalty_update='dimin', + # penalty_update='dimin_adapt', + penalty_update='const', + pbf = 'quadratic_logarithmic', + gamma=0.1, + init_duals=0.1, + init_penalties=1., + penalty_range=(0.5, 1.), + penalty_mult=0.99, + dual_range=(0.1, 100.), + delta=1.0, + device=device + ) + + for t in tqdm(range(0, EPOCH)) : + + loss, loss1, loss2, loss3, loss4, val_err, test_err = train(u_model, beta, trainloader=train_loader,\ + ini_bdry_data=[X_ini, u_ini, u_ini_t, X_bdry, u_bdry],\ + val_test = [X_val, y_val, X_test, y_test],\ + optimizer=optimizer, loss_f=nn.MSELoss(), dual_opt=dual) + + val_errs_spbm.append(val_err) + test_errs_spbm.append(test_err) + total_loss_spbm.append(loss) + constraints_spbm.append([loss2, loss3, loss4]) # append both costraint + + # Print Log + if t%100 == 0 : + print("%s/%s | loss: %06.6f | loss_f: %06.6f | loss_u: %06.6f | val error : %06.6f | test error : %06.6f " % \ + (t, EPOCH, loss, loss1, loss2+loss3+loss4, val_err, test_err)) + + # ALM + torch.manual_seed(0) + total_loss_alm, test_errs_alm, val_errs_alm, constraints_alm = [], [], [], [] + u_model = set_model(model_name, device) + + # Define data and optimizers + optimizer = MoreauEnvelope(torch.optim.Adam([{'params': u_model.parameters()}], lr=0.005), mu=2.0) + + dual = ALM( + m=3, + lr=0.1, + momentum=0.5, + device=device + ) + for t in tqdm(range(0, EPOCH)) : + + loss, loss1, loss2, loss3, loss4, val_err, test_err = train(u_model, beta, trainloader=train_loader,\ + ini_bdry_data=[X_ini, u_ini, u_ini_t, X_bdry, u_bdry],\ + val_test = [X_val, y_val, X_test, y_test],\ + optimizer=optimizer, loss_f=nn.MSELoss(), dual_opt=dual) + + val_errs_alm.append(val_err) + test_errs_alm.append(test_err) + total_loss_alm.append(loss) + constraints_alm.append([loss2, loss3, loss4]) # append both costraint + + # Print Log + if t%100 == 0 : + print("%s/%s | loss: %06.6f | loss_f: %06.6f | loss_u: %06.6f | val error : %06.6f | test error : %06.6f " % \ + (t, EPOCH, loss, loss1, loss2+loss3+loss4, val_err, test_err)) + + # plot the resultsimport matplotlib.pyplot as plt + import matplotlib.pyplot as plt + fig, axes = plt.subplots(1, 3, figsize=(15, 4)) # wider figure + + axes[0].plot(total_loss, label='Adam') + axes[0].plot(total_loss_spbm, label='SPBM') + axes[0].plot(total_loss_alm, label='ALM') + axes[0].set_xlabel('Epoch') + axes[0].set_ylabel('Train Total Loss') + axes[0].legend() + + axes[1].plot(test_errs, label='Adam') + axes[1].plot(test_errs_spbm, label='SPBM') + axes[1].plot(test_errs_alm, label='ALM') + axes[1].set_xlabel('Epoch') + axes[1].set_ylabel('Test Error') + axes[1].legend() + + # plot both constraints + methods shuold have the same color but dashed vs solid + axes[2].plot([c[0] for c in constraints], label='Adam - Initial Condition - Zero Order', linestyle='--', color='blue') + axes[2].plot([c[1] for c in constraints], label='Adam - Initial Condition - First Order', linestyle='-', color='blue') + axes[2].plot([c[2] for c in constraints], label='Adam - Boundary Condition', linestyle=':', color='blue') + axes[2].plot([c[0] for c in constraints_spbm], label='SPBM - Initial Condition - Zero Order', linestyle='--', color='orange') + axes[2].plot([c[1] for c in constraints_spbm], label='SPBM - Initial Condition - First Order', linestyle='-', color='orange') + axes[2].plot([c[2] for c in constraints_spbm], label='SPBM - Boundary Condition', linestyle=':', color='orange') + axes[2].plot([c[0] for c in constraints_alm], label='ALM - Initial Condition - Zero Order', linestyle='--', color='green') + axes[2].plot([c[1] for c in constraints_alm], label='ALM - Initial Condition - First Order', linestyle='-', color='green') + axes[2].plot([c[2] for c in constraints_alm], label='ALM - Boundary Condition', linestyle=':', color='green') + axes[2].set_xlabel('Epoch') + axes[2].set_ylabel('Constraint Violation') + axes[2].legend() + + plt.tight_layout(pad=2.0) # extra padding between subplots + plt.savefig('./PDEs/Klein-Gordon/results.png', bbox_inches='tight') + + +if __name__ == "__main__": + + parser = argparse.ArgumentParser() + parser.add_argument('--model', default='deep_narrow', help='Specify the model. Choose one of [deep_narrow, shallow_wide, deep_narrow_resent, shallow_wide_resnet].') + parser.add_argument('--beta', default=1, type=float, help='Penalty parameter beta') + parser.add_argument('--lr', default=1e-1, type=float, help='Learning rate') + parser.add_argument('--EPOCH', default=2000, type=int, help='Number of training EPOCH') + parser.add_argument('--ordinal', default=0, type=int, help='Specify the cuda device ordinal.') + args = parser.parse_args() + + device = 'cuda' + main_function(args.model, args.beta, args.lr, args.EPOCH, device) + + \ No newline at end of file diff --git a/PDEs/Klein-Gordon/networks.py b/PDEs/Klein-Gordon/networks.py new file mode 100644 index 0000000..0ccef8b --- /dev/null +++ b/PDEs/Klein-Gordon/networks.py @@ -0,0 +1,99 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F + +def set_model(model, device) : + if model =='deep_narrow_resnet' : + u_model = u_Net_deep_narrow_resnet().to(device) + elif model == 'shallow_wide_resnet' : + u_model = u_Net_shallow_wide_resnet().to(device) + elif model == 'deep_narrow' : + u_model = u_Net_deep_narrow().to(device) + elif model == 'shallow_wide' : + u_model = u_Net_shallow_wide().to(device) + return u_model + +class u_Net_shallow_wide(nn.Module): + def __init__(self): + super(u_Net_shallow_wide, self).__init__() + self.fc1 = nn.Linear(2, 256) + self.fc2 = nn.Linear(256, 256) +# self.fc3 = nn.Linear(256, 256) +# self.fc4 = nn.Linear(256, 256) + self.fc5 = nn.Linear(256, 1) + self.act1 = nn.Tanh() + def forward(self, x): + x = self.act1(self.fc1(x)) + x = self.act1(self.fc2(x)) +# x = self.act1(self.fc3(x)) +# x = self.act1(self.fc4(x)) + x = self.fc5(x) + return x + + +class u_Net_shallow_wide_resnet(nn.Module): + def __init__(self): + super(u_Net_shallow_wide_resnet, self).__init__() + self.fc1 = nn.Linear(2, 256) + self.fc2 = nn.Linear(256, 256) +# self.fc3 = nn.Linear(256, 256) +# self.fc4 = nn.Linear(256, 256) + self.fc5 = nn.Linear(256, 1) + self.act1 = nn.Tanh() + def forward(self, x): + x = self.act1(self.fc1(x)) + x = self.act1(self.fc2(x))+x +# x = self.act1(self.fc3(x))+x +# x = self.act1(self.fc4(x))+x + x = self.fc5(x) + return x + +class u_Net_deep_narrow(nn.Module): + def __init__(self): + super(u_Net_deep_narrow, self).__init__() + self.fc1 = nn.Linear(2, 64) + self.fc2 = nn.Linear(64, 64) + self.fc3 = nn.Linear(64, 64) + self.fc4 = nn.Linear(64, 64) + self.fc5 = nn.Linear(64, 64) + self.fc6 = nn.Linear(64, 64) + self.fc7 = nn.Linear(64, 64) + self.fc8 = nn.Linear(64, 64) + self.fc9 = nn.Linear(64, 1) + self.act1 = nn.Tanh() + def forward(self, x): + x = self.act1(self.fc1(x)) + x = self.act1(self.fc2(x)) + x = self.act1(self.fc3(x)) + x = self.act1(self.fc4(x)) + x = self.act1(self.fc5(x)) + x = self.act1(self.fc6(x)) + x = self.act1(self.fc7(x)) + x = self.act1(self.fc8(x)) + x = self.fc9(x) + return x + +class u_Net_deep_narrow_resnet(nn.Module): + def __init__(self): + super(u_Net_deep_narrow_resnet, self).__init__() + self.fc1 = nn.Linear(2, 64) + self.fc2 = nn.Linear(64, 64) + self.fc3 = nn.Linear(64, 64) + self.fc4 = nn.Linear(64, 64) + self.fc5 = nn.Linear(64, 64) + self.fc6 = nn.Linear(64, 64) + self.fc7 = nn.Linear(64, 64) + self.fc8 = nn.Linear(64, 64) + self.fc9 = nn.Linear(64, 1) + self.act1 = nn.Tanh() + def forward(self, x): + x = self.act1(self.fc1(x)) + x = self.act1(self.fc2(x))+x + x = self.act1(self.fc3(x))+x + x = self.act1(self.fc4(x))+x + x = self.act1(self.fc5(x))+x + x = self.act1(self.fc6(x))+x + x = self.act1(self.fc7(x))+x + x = self.act1(self.fc8(x))+x + x = self.fc9(x) + return x \ No newline at end of file diff --git a/PDEs/Klein-Gordon/results.png b/PDEs/Klein-Gordon/results.png new file mode 100644 index 0000000..cd0f283 Binary files /dev/null and b/PDEs/Klein-Gordon/results.png differ diff --git a/PDEs/Viscous_Burgers/Burgers_test b/PDEs/Viscous_Burgers/Burgers_test new file mode 100644 index 0000000..68aeba6 Binary files /dev/null and b/PDEs/Viscous_Burgers/Burgers_test differ diff --git a/PDEs/Viscous_Burgers/Viscous_Burgers.py b/PDEs/Viscous_Burgers/Viscous_Burgers.py new file mode 100644 index 0000000..4c97b23 --- /dev/null +++ b/PDEs/Viscous_Burgers/Viscous_Burgers.py @@ -0,0 +1,275 @@ +from tqdm import tqdm +import pickle as pkl +import numpy as np +import copy +import argparse +import sys +sys.path.append("..") + +import torch +from torch.autograd import Variable +import torch.nn as nn +import torch.nn.functional as F +import torch.optim as optim +from torch.utils.data import DataLoader, TensorDataset +from humancompatible.train.dual_optim import ALM, MoreauEnvelope, PBM + +from networks import set_model, u_Net_shallow_wide, u_Net_shallow_wide_resnet, u_Net_deep_narrow, u_Net_deep_narrow_resnet + +# Equation parameter + +nu = 0.01/np.pi + +def calculate_derivative(y, x) : + return torch.autograd.grad(y, x, create_graph=True,\ + grad_outputs=torch.ones(y.size()).to(device))[0] + +def calculate_all_partial(u, x) : + del_u = calculate_derivative(u, x) + u_t, u_x = del_u[:,0], del_u[:,1] + u_xx = calculate_derivative(u_x, x)[:,1] + return u_t.view(-1,1), u_x.view(-1,1), u_xx.view(-1,1) + + +def train(u_model, beta, trainloader, ini_bdry_data, val_test, optimizer, loss_f, dual_opt=None) : + loss_list, loss_list1, loss_list2, loss_list3, val_list, test_list = [], [], [], [], [], [] + X_ini, u_ini, X_bdry, u_bdry = ini_bdry_data + X_val, y_val, X_test, y_test = val_test + + for i, (data,) in enumerate(trainloader) : + u_model.train() + optimizer.zero_grad() + X_v = Variable(data, requires_grad=True).to(device) + output = u_model(X_v) + output_ini = u_model(X_ini) + output_bdry = u_model(X_bdry) + + u_t, u_x, u_xx = calculate_all_partial(output, X_v) + loss1 = loss_f(u_t + output*u_x - nu*u_xx, torch.zeros_like(u_t)) + loss2 = loss_f(output_ini-u_ini, torch.zeros_like(output_ini)) + loss3 = loss_f(output_bdry, torch.zeros_like(output_bdry)) + + # unconstrained Adam + if dual_opt is None : + loss = loss1 + beta*loss2 + beta*loss3 + loss.backward() + optimizer.step() + optimizer.zero_grad() + + elif dual_opt is not None: + threshold = 1e-4 + constraints = torch.stack([loss2, loss3], dim=0) + constraints = constraints - threshold + + # compute the lagrangian value + lagrangian = dual_opt.forward_update(loss1, constraints) + lagrangian.backward() + optimizer.step() + optimizer.zero_grad() + + u_model.eval() + val_err = torch.linalg.norm((u_model(X_val) - y_val),2).item() / torch.linalg.norm(y_val,2).item() + test_err = torch.linalg.norm((u_model(X_test) - y_test),2).item() / torch.linalg.norm(y_test,2).item() + + loss_list.append((loss1+loss2+loss3).item()) + loss_list1.append(loss1.item()) + loss_list2.append(loss2.item()) + loss_list3.append(loss3.item()) + val_list.append(val_err) + test_list.append(test_err) + + return np.mean(loss_list), np.mean(loss_list1), np.mean(loss_list2),\ + np.mean(loss_list3), np.mean(val_list), np.mean(test_list) + + +def main_function(model_name, beta, lr, EPOCH, device) : + + # Dataset Creation + tmin, tmax = 0, 1 + xmin, xmax = -1,1 + Ns, Nx = 51, 51 + X_train = torch.FloatTensor(np.mgrid[tmin:tmax:51j, xmin:xmax:51j].reshape(2, -1).T).to(device) + + # Initial Conditions + X_ini = X_train[X_train[:,0]==tmin] + u_ini = -torch.sin(np.pi*X_ini[:,1].view(-1,1)) + + # Boundary Conditions + X_bdry = X_train[(X_train[:,1]==xmin) + (X_train[:,1]==xmax)] + u_bdry = torch.zeros_like(X_bdry[:,0]).to(device).view(-1,1) + + # Validation & Test Set + X_test, y_test, X_val, y_val= torch.load('./PDEs/Viscous_Burgers/Burgers_test', map_location=device) + + # take 1000 samples from the validation set + idx = np.random.choice(X_val.shape[0], 1000, replace=False) + X_val = X_val[idx] + y_val = y_val[idx] + + # Make dataloader + data_train = TensorDataset(X_train) + train_loader = DataLoader(data_train, batch_size=10000, shuffle=False) + + # train + torch.manual_seed(0) + total_loss, test_errs, val_errs, constraints = [], [], [], [] + u_model = set_model(model_name, device) + optimizer=torch.optim.Adam([{'params': u_model.parameters()}], lr=lr) + best_model = copy.deepcopy(u_model) + + for t in tqdm(range(0, EPOCH)) : + + loss, loss1, loss2, loss3, val_err, test_err = train(u_model, beta, trainloader=train_loader,\ + ini_bdry_data=[X_ini, u_ini, X_bdry, u_bdry],\ + val_test = [X_val, y_val, X_test, y_test],\ + optimizer=optimizer, loss_f=nn.MSELoss(), + dual_opt=None) + + # val_errs.append(torch.linalg.norm((u_model(X_val) - y_val),2).item() / torch.linalg.norm(y_val,2).item()) + # test_errs.append(torch.linalg.norm((u_model(X_test) - y_test),2).item() / torch.linalg.norm(y_test,2).item()) + # total_loss.append(loss) + + val_errs.append(val_err) + test_errs.append(test_err) + total_loss.append(loss) + constraints.append([loss2, loss3]) # append both costraint + + # Print Log + if t%100 == 0 : + print("%s/%s | loss: %06.6f | loss_f: %06.6f | loss_u: %06.6f | val error : %06.6f | test error : %06.6f " % \ + (t, EPOCH, loss, loss1, loss2+loss3, val_err, test_err)) + + if np.argmin(val_errs) == t : + best_model = copy.deepcopy(u_model) + + + + # SPBM + torch.manual_seed(0) + total_loss_spbm, test_errs_spbm, val_errs_spbm, constraints_spbm = [], [], [], [] + u_model = set_model(model_name, device) + + # Define data and optimizers + optimizer = MoreauEnvelope(torch.optim.Adam([{'params': u_model.parameters()}], lr=0.005), mu=2.0) + + dual = PBM( + m=2, + # penalty_update='dimin', + # penalty_update='dimin_adapt', + penalty_update='const', + pbf = 'quadratic_logarithmic', + gamma=0.9, + init_duals=0.01, + init_penalties=1., + penalty_range=(0.5, 1.), + penalty_mult=0.99, + dual_range=(0.01, 100.), + delta=1.0, + device=device + ) + + for t in tqdm(range(0, EPOCH)) : + + loss, loss1, loss2, loss3, val_err, test_err = train(u_model, beta, trainloader=train_loader,\ + ini_bdry_data=[X_ini, u_ini, X_bdry, u_bdry],\ + val_test = [X_val, y_val, X_test, y_test],\ + optimizer=optimizer, loss_f=nn.MSELoss(), + dual_opt=dual) + + val_errs_spbm.append(val_err) + test_errs_spbm.append(test_err) + total_loss_spbm.append(loss) + constraints_spbm.append([loss2, loss3]) + + #Print Log + if t%100 == 0 : + print("%s/%s | loss: %06.6f | loss_f: %06.6f | loss_u: %06.6f | val error : %06.6f | test error : %06.6f " % \ + (t, EPOCH, loss, loss1, loss2+loss3, val_err, test_err)) + + if np.argmin(val_errs_spbm) == t : + best_model = copy.deepcopy(u_model) + + # ALM + torch.manual_seed(0) + total_loss_alm, test_errs_alm, val_errs_alm, constraints_alm = [], [], [], [] + u_model = set_model(model_name, device) + + # Define data and optimizers + optimizer = MoreauEnvelope(torch.optim.Adam([{'params': u_model.parameters()}], lr=0.005), mu=2.0) + + dual = ALM( + m=2, + lr=0.1, + momentum=0.5, + device=device + ) + + for t in tqdm(range(0, EPOCH)) : + + loss, loss1, loss2, loss3, val_err, test_err = train(u_model, beta, trainloader=train_loader,\ + ini_bdry_data=[X_ini, u_ini, X_bdry, u_bdry],\ + val_test = [X_val, y_val, X_test, y_test],\ + optimizer=optimizer, loss_f=nn.MSELoss(), + dual_opt=dual) + + val_errs_alm.append(val_err) + test_errs_alm.append(test_err) + total_loss_alm.append(loss) + constraints_alm.append([loss2, loss3]) + + #Print Log + if t%100 == 0 : + print("%s/%s | loss: %06.6f | loss_f: %06.6f | loss_u: %06.6f | val error : %06.6f | test error : %06.6f " % \ + (t, EPOCH, loss, loss1, loss2+loss3, val_err, test_err)) + + if np.argmin(val_errs_alm) == t : + best_model = copy.deepcopy(u_model) + + + # plot the resultsimport matplotlib.pyplot as plt + import matplotlib.pyplot as plt + fig, axes = plt.subplots(1, 3, figsize=(15, 4)) # wider figure + + axes[0].plot(total_loss, label='Adam') + axes[0].plot(total_loss_spbm, label='SPBM') + axes[0].plot(total_loss_alm, label='ALM') + axes[0].set_xlabel('Epoch') + axes[0].set_ylabel('Train Total Loss') + axes[0].legend() + + axes[1].plot(test_errs, label='Adam') + axes[1].plot(test_errs_spbm, label='SPBM') + axes[1].plot(test_errs_alm, label='ALM') + axes[1].set_xlabel('Epoch') + axes[1].set_ylabel('Test Error') + axes[1].legend() + + # plot both constraints + methods shuold have the same color but dashed vs solid + axes[2].plot([c[0] for c in constraints], label='Adam - Initial Condition', linestyle='--', color='blue') + axes[2].plot([c[1] for c in constraints], label='Adam - Boundary Condition', linestyle='-', color='blue') + axes[2].plot([c[0] for c in constraints_spbm], label='SPBM - Initial Condition', linestyle='--', color='orange') + axes[2].plot([c[1] for c in constraints_spbm], label='SPBM - Boundary Condition', linestyle='-', color='orange') + axes[2].plot([c[0] for c in constraints_alm], label='ALM - Initial Condition', linestyle='--', color='green') + axes[2].plot([c[1] for c in constraints_alm], label='ALM - Boundary Condition', linestyle='-', color='green') + axes[2].set_xlabel('Epoch') + axes[2].set_ylabel('Constraint Violation') + axes[2].legend() + + plt.tight_layout(pad=2.0) # extra padding between subplots + plt.savefig('./PDEs/Viscous_Burgers/results.png', bbox_inches='tight') + + +if __name__ == "__main__": + + parser = argparse.ArgumentParser() + parser.add_argument('--model', default='deep_narrow', help='Specify the model. Choose one of [deep_narrow, shallow_wide, deep_narrow_resent, shallow_wide_resnet].') + parser.add_argument('--beta', default=1, type=float, help='Penalty parameter beta') + parser.add_argument('--lr', default=1e-3, type=float, help='Learning rate') + parser.add_argument('--EPOCH', default=1000, type=int, help='Number of training EPOCH') + parser.add_argument('--ordinal', default=0, type=int, help='Specify the cuda device ordinal.') + args = parser.parse_args() + + device = torch.device("cuda:{}".format(args.ordinal) if torch.cuda.is_available() else "cpu") + + main_function(args.model, args.beta, args.lr, args.EPOCH, device) + \ No newline at end of file diff --git a/PDEs/Viscous_Burgers/Viscous_Burgers_AL-PINNs.py b/PDEs/Viscous_Burgers/Viscous_Burgers_AL-PINNs.py new file mode 100644 index 0000000..17488ba --- /dev/null +++ b/PDEs/Viscous_Burgers/Viscous_Burgers_AL-PINNs.py @@ -0,0 +1,146 @@ +from tqdm import tqdm +import pickle as pkl +import numpy as np +import copy +import argparse +import sys +sys.path.append("..") + +import torch +from torch.autograd import Variable +import torch.nn as nn +import torch.nn.functional as F +import torch.optim as optim +from torch.utils.data import DataLoader, TensorDataset + +from networks import set_model, u_Net_shallow_wide, u_Net_shallow_wide_resnet, u_Net_deep_narrow, u_Net_deep_narrow_resnet + +# Equation parameter + +nu = 0.01/np.pi + +def calculate_derivative(y, x) : + return torch.autograd.grad(y, x, create_graph=True,\ + grad_outputs=torch.ones(y.size()).to(device))[0] + +def calculate_all_partial(u, x) : + del_u = calculate_derivative(u, x) + u_t, u_x = del_u[:,0], del_u[:,1] + u_xx = calculate_derivative(u_x, x)[:,1] + return u_t.view(-1,1), u_x.view(-1,1), u_xx.view(-1,1) + + +def train(u_model, lbd1, lbd2, beta, trainloader, ini_bdry_data, val_test, optimizer, loss_f) : + loss_list, loss_list1, loss_list2, loss_list3, val_list, test_list = [], [], [], [], [], [] + X_ini, u_ini, X_bdry, u_bdry = ini_bdry_data + X_val, y_val, X_test, y_test = val_test + + for i, (data,) in enumerate(trainloader) : + u_model.train() + optimizer.zero_grad() + X_v = Variable(data, requires_grad=True).to(device) + output = u_model(X_v) + output_ini = u_model(X_ini) + output_bdry = u_model(X_bdry) + + u_t, u_x, u_xx = calculate_all_partial(output, X_v) + loss1 = loss_f(u_t + output*u_x - nu*u_xx, torch.zeros_like(u_t)) + loss2 = loss_f(output_ini-u_ini, torch.zeros_like(output_ini)) + loss3 = loss_f(output_bdry, torch.zeros_like(output_bdry)) + + loss = loss1 +\ + beta*loss2 + (lbd1*(output_ini-u_ini).view(-1)).mean() +\ + beta*loss3 + (lbd2*output_bdry.view(-1)).mean() + + loss.backward() + lbd1.grad *= -1 + lbd2.grad *= -1 + + optimizer.step() + + u_model.eval() + val_err = torch.linalg.norm((u_model(X_val) - y_val),2).item() / torch.linalg.norm(y_val,2).item() + test_err = torch.linalg.norm((u_model(X_test) - y_test),2).item() / torch.linalg.norm(y_test,2).item() + + loss_list.append((loss1+loss2).item()) + loss_list1.append(loss1.item()) + loss_list2.append(loss2.item()) + loss_list3.append(loss3.item()) + val_list.append(val_err) + test_list.append(test_err) + + return np.mean(loss_list), np.mean(loss_list1), np.mean(loss_list2),\ + np.mean(loss_list3), np.mean(val_list), np.mean(test_list) + + +def main_function(model_name, beta, lr, lbd_lr, EPOCH, device) : + + # Dataset Creation + tmin, tmax = 0, 1 + xmin, xmax = -1,1 + Ns, Nx = 51, 51 + X_train = torch.FloatTensor(np.mgrid[tmin:tmax:51j, xmin:xmax:51j].reshape(2, -1).T).to(device) + + # Initial Conditions + X_ini = X_train[X_train[:,0]==tmin] + u_ini = -torch.sin(np.pi*X_ini[:,1].view(-1,1)) + + # Boundary Conditions + X_bdry = X_train[(X_train[:,1]==xmin) + (X_train[:,1]==xmax)] + u_bdry = torch.zeros_like(X_bdry[:,0]).to(device).view(-1,1) + + # Validation & Test Set + X_test, y_test, X_val, y_val= torch.load('Burgers_test', map_location=device) + + # Make dataloader + data_train = TensorDataset(X_train) + train_loader = DataLoader(data_train, batch_size=10000, shuffle=False) + + # train + total_loss, test_errs, val_errs = [], [], [] + u_model = set_model(model_name, device) + lbd1 = Variable(torch.FloatTensor([0]*X_ini.size()[0]).to(device), requires_grad=True)# + lbd2 = Variable(torch.FloatTensor([0]*X_bdry.size()[0]).to(device), requires_grad=True)# + + optimizer=torch.optim.Adam([{'params': u_model.parameters()}, {'params': lbd1, 'lr':lbd_lr}, \ + {'params': lbd2, 'lr':lbd_lr}], lr=lr) + best_model = copy.deepcopy(u_model) + + for t in tqdm(range(0, EPOCH)) : + + loss, loss1, loss2, loss3, val_err, test_err = train(u_model, lbd1, lbd2, beta, trainloader=train_loader,\ + ini_bdry_data=[X_ini, u_ini, X_bdry, u_bdry],\ + val_test = [X_val, y_val, X_test, y_test],\ + optimizer=optimizer, loss_f=nn.MSELoss()) + + val_errs.append(torch.linalg.norm((u_model(X_val) - y_val),2).item() / torch.linalg.norm(y_val,2).item()) + test_errs.append(torch.linalg.norm((u_model(X_test) - y_test),2).item() / torch.linalg.norm(y_test,2).item()) + total_loss.append(loss) + +# # Print Log +# if t%100 == 0 : +# print("%s/%s | loss: %06.6f | loss_f: %06.6f | loss_u: %06.6f | val error : %06.6f | test error : %06.6f " % \ +# (t, EPOCH, loss, loss1, loss2+loss3, val_err, test_err)) + + if np.argmin(val_errs) == t : + best_model = copy.deepcopy(u_model) + + return best_model, total_loss, val_errs, test_errs + + +if __name__ == "__main__": + + parser = argparse.ArgumentParser() + parser.add_argument('--model', default='deep_narrow', help='Specify the model. Choose one of [deep_narrow, shallow_wide, deep_narrow_resent, shallow_wide_resnet].') + parser.add_argument('--beta', default=1, type=float, help='Penalty parameter beta') + parser.add_argument('--lr', default=1e-4, type=float, help='Learning rate') + parser.add_argument('--lbd_lr', default=1e-4, type=float, help='Learning rate for lambda') + parser.add_argument('--EPOCH', default=10000, type=int, help='Number of training EPOCH') + parser.add_argument('--ordinal', default=0, type=int, help='Specify the cuda device ordinal.') + args = parser.parse_args() + + device = torch.device("cuda:{}".format(args.ordinal) if torch.cuda.is_available() else "cpu") + + best_model, total_loss, val_errs, test_errs = main_function(args.model, args.beta, args.lr, args.lbd_lr, args.EPOCH, device) + print('Best Test Error : ', test_errs[np.argmin(val_errs)]) + \ No newline at end of file diff --git a/PDEs/Viscous_Burgers/networks.py b/PDEs/Viscous_Burgers/networks.py new file mode 100644 index 0000000..0ccef8b --- /dev/null +++ b/PDEs/Viscous_Burgers/networks.py @@ -0,0 +1,99 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F + +def set_model(model, device) : + if model =='deep_narrow_resnet' : + u_model = u_Net_deep_narrow_resnet().to(device) + elif model == 'shallow_wide_resnet' : + u_model = u_Net_shallow_wide_resnet().to(device) + elif model == 'deep_narrow' : + u_model = u_Net_deep_narrow().to(device) + elif model == 'shallow_wide' : + u_model = u_Net_shallow_wide().to(device) + return u_model + +class u_Net_shallow_wide(nn.Module): + def __init__(self): + super(u_Net_shallow_wide, self).__init__() + self.fc1 = nn.Linear(2, 256) + self.fc2 = nn.Linear(256, 256) +# self.fc3 = nn.Linear(256, 256) +# self.fc4 = nn.Linear(256, 256) + self.fc5 = nn.Linear(256, 1) + self.act1 = nn.Tanh() + def forward(self, x): + x = self.act1(self.fc1(x)) + x = self.act1(self.fc2(x)) +# x = self.act1(self.fc3(x)) +# x = self.act1(self.fc4(x)) + x = self.fc5(x) + return x + + +class u_Net_shallow_wide_resnet(nn.Module): + def __init__(self): + super(u_Net_shallow_wide_resnet, self).__init__() + self.fc1 = nn.Linear(2, 256) + self.fc2 = nn.Linear(256, 256) +# self.fc3 = nn.Linear(256, 256) +# self.fc4 = nn.Linear(256, 256) + self.fc5 = nn.Linear(256, 1) + self.act1 = nn.Tanh() + def forward(self, x): + x = self.act1(self.fc1(x)) + x = self.act1(self.fc2(x))+x +# x = self.act1(self.fc3(x))+x +# x = self.act1(self.fc4(x))+x + x = self.fc5(x) + return x + +class u_Net_deep_narrow(nn.Module): + def __init__(self): + super(u_Net_deep_narrow, self).__init__() + self.fc1 = nn.Linear(2, 64) + self.fc2 = nn.Linear(64, 64) + self.fc3 = nn.Linear(64, 64) + self.fc4 = nn.Linear(64, 64) + self.fc5 = nn.Linear(64, 64) + self.fc6 = nn.Linear(64, 64) + self.fc7 = nn.Linear(64, 64) + self.fc8 = nn.Linear(64, 64) + self.fc9 = nn.Linear(64, 1) + self.act1 = nn.Tanh() + def forward(self, x): + x = self.act1(self.fc1(x)) + x = self.act1(self.fc2(x)) + x = self.act1(self.fc3(x)) + x = self.act1(self.fc4(x)) + x = self.act1(self.fc5(x)) + x = self.act1(self.fc6(x)) + x = self.act1(self.fc7(x)) + x = self.act1(self.fc8(x)) + x = self.fc9(x) + return x + +class u_Net_deep_narrow_resnet(nn.Module): + def __init__(self): + super(u_Net_deep_narrow_resnet, self).__init__() + self.fc1 = nn.Linear(2, 64) + self.fc2 = nn.Linear(64, 64) + self.fc3 = nn.Linear(64, 64) + self.fc4 = nn.Linear(64, 64) + self.fc5 = nn.Linear(64, 64) + self.fc6 = nn.Linear(64, 64) + self.fc7 = nn.Linear(64, 64) + self.fc8 = nn.Linear(64, 64) + self.fc9 = nn.Linear(64, 1) + self.act1 = nn.Tanh() + def forward(self, x): + x = self.act1(self.fc1(x)) + x = self.act1(self.fc2(x))+x + x = self.act1(self.fc3(x))+x + x = self.act1(self.fc4(x))+x + x = self.act1(self.fc5(x))+x + x = self.act1(self.fc6(x))+x + x = self.act1(self.fc7(x))+x + x = self.act1(self.fc8(x))+x + x = self.fc9(x) + return x \ No newline at end of file diff --git a/PDEs/Viscous_Burgers/results.png b/PDEs/Viscous_Burgers/results.png new file mode 100644 index 0000000..9a49386 Binary files /dev/null and b/PDEs/Viscous_Burgers/results.png differ diff --git a/benchmark/_data_sources.py b/benchmark/_data_sources.py index dcfef0f..117d83d 100644 --- a/benchmark/_data_sources.py +++ b/benchmark/_data_sources.py @@ -6,8 +6,6 @@ from sklearn.model_selection import train_test_split from humancompatible.train.fairness.utils import BalancedBatchSampler from itertools import product -import torchvision -from torchvision import transforms import itertools @@ -39,7 +37,7 @@ def comb_cat_dummies(df): -def load_data_norm(batch_size=64): +def load_data_norm(batch_size=64, device='cpu'): # load folktables data data_source = ACSDataSource(survey_year="2018", horizon="1-Year", survey="person") @@ -74,19 +72,19 @@ def load_data_norm(batch_size=64): X_test = scaler.transform(X_test) # make into a pytorch dataset, remove the sensitive attribute - features_train = torch.tensor(X_train) - labels_train = torch.tensor(y_train) - sens_train = torch.tensor(groups_train) + features_train = torch.tensor(X_train).to(device) + labels_train = torch.tensor(y_train).to(device) + sens_train = torch.tensor(groups_train).to(device) # make into a pytorch dataset, remove the sensitive attribute - features_val = torch.tensor(X_val) - labels_val = torch.tensor(y_val) - sens_val = torch.tensor(groups_val) + features_val = torch.tensor(X_val).to(device) + labels_val = torch.tensor(y_val).to(device) + sens_val = torch.tensor(groups_val).to(device) # make into a pytorch dataset, remove the sensitive attribute - features_test = torch.tensor(X_test) - labels_test = torch.tensor(y_test) - sens_test = torch.tensor(groups_test) + features_test = torch.tensor(X_test).to(device) + labels_test = torch.tensor(y_test).to(device) + sens_test = torch.tensor(groups_test).to(device) # set the same seed for fair comparisons torch.manual_seed(0) @@ -104,7 +102,7 @@ def load_data_norm(batch_size=64): return (dataloader_train, dataloader_val, dataloader_test), (features_train, sens_train, labels_train), (features_val, sens_val, labels_val), (features_test, sens_test, labels_test) -def load_data_FT(batch_size, device, sens_attrs, states, group_size_threshold = 0, sens_groups = None, extend_groups = False): +def load_data_FT(batch_size, device, sens_attrs, states=['VA'], group_size_threshold = 0, sens_groups = None, extend_groups = False, dtype=torch.float32): # load folktables data data_source = ACSDataSource(survey_year="2018", horizon="1-Year", survey="person") ACSProblem = BasicProblem( @@ -125,6 +123,10 @@ def load_data_FT(batch_size, device, sens_attrs, states, group_size_threshold = acs_data, categories=categories, dummies=True ) + if 'MAR' in sens_attrs: + df_sens['MAR_2'] = df_sens['MAR_2'] + df_sens['MAR_4'] + df_sens['MAR_5'] + df_sens.drop(['MAR_4', 'MAR_5'], inplace=True, axis='columns') + # df_sens_onehot = comb_cat_dummies(df_sens) if sens_groups else df_sens df_sens_onehot = comb_cat_dummies(df_sens) if len(sens_attrs) > 1 else df_sens @@ -170,21 +172,21 @@ def load_data_FT(batch_size, device, sens_attrs, states, group_size_threshold = print(f"{df_sens_onehot.columns[idx]}, : {(groups[:, idx] == 1).sum()}") # make into a pytorch dataset, remove the sensitive attribute - features_train = torch.tensor(X_train).to(torch.float32) - labels_train = torch.tensor(y_train).to(torch.float32) - sens_train = torch.tensor(groups_train).to(torch.float32) + features_train = torch.tensor(X_train).to(dtype) + labels_train = torch.tensor(y_train).to(dtype) + sens_train = torch.tensor(groups_train).to(dtype) dataset_train = torch.utils.data.TensorDataset(features_train, labels_train) # make into a pytorch dataset, remove the sensitive attribute - features_val = torch.tensor(X_val).to(torch.float32) - labels_val = torch.tensor(y_val).to(torch.float32) - sens_val = torch.tensor(groups_val).to(torch.float32) + features_val = torch.tensor(X_val).to(dtype) + labels_val = torch.tensor(y_val).to(dtype) + sens_val = torch.tensor(groups_val).to(dtype) dataset_val = torch.utils.data.TensorDataset(features_val, labels_val) # make into a pytorch dataset, remove the sensitive attribute - features_test = torch.tensor(X_test).to(torch.float32) - labels_test = torch.tensor(y_test).to(torch.float32) - sens_test = torch.tensor(groups_test).to(torch.float32) + features_test = torch.tensor(X_test).to(dtype) + labels_test = torch.tensor(y_test).to(dtype) + sens_test = torch.tensor(groups_test).to(dtype) dataset_test = torch.utils.data.TensorDataset(features_test, labels_test) # get the dataset @@ -211,8 +213,7 @@ def load_data_FT(batch_size, device, sens_attrs, states, group_size_threshold = - -def load_data_FT_prod(batch_size, device='cpu'): +def load_data_FT_prod(batch_size, device='cpu', extend_groups = False): # load folktables data data_source = ACSDataSource(survey_year="2018", horizon="1-Year", survey="person") @@ -286,21 +287,21 @@ def load_data_FT_prod(batch_size, device='cpu'): print(f"{group_dict[idx]}, : {(groups_onehot[:, idx] == 1).sum()}") # make into a pytorch dataset, remove the sensitive attribute - features_train = torch.tensor(X_train).to(torch.float32) - labels_train = torch.tensor(y_train).to(torch.float32) - sens_train = torch.tensor(groups_train).to(torch.float32) + features_train = torch.tensor(X_train).to(torch.float32).to(device) + labels_train = torch.tensor(y_train).to(torch.float32).to(device) + sens_train = torch.tensor(groups_train).to(torch.float32).to(device) dataset_train = torch.utils.data.TensorDataset(features_train, labels_train) # make into a pytorch dataset, remove the sensitive attribute - features_val = torch.tensor(X_val).to(torch.float32) - labels_val = torch.tensor(y_val).to(torch.float32) - sens_val = torch.tensor(groups_val).to(torch.float32) + features_val = torch.tensor(X_val).to(torch.float32).to(device) + labels_val = torch.tensor(y_val).to(torch.float32).to(device) + sens_val = torch.tensor(groups_val).to(torch.float32).to(device) dataset_val = torch.utils.data.TensorDataset(features_val, labels_val) # make into a pytorch dataset, remove the sensitive attribute - features_test = torch.tensor(X_test).to(torch.float32) - labels_test = torch.tensor(y_test).to(torch.float32) - sens_test = torch.tensor(groups_test).to(torch.float32) + features_test = torch.tensor(X_test).to(torch.float32).to(device) + labels_test = torch.tensor(y_test).to(torch.float32).to(device) + sens_test = torch.tensor(groups_test).to(torch.float32).to(device) dataset_test = torch.utils.data.TensorDataset(features_test, labels_test) # set the same seed for fair comparisons @@ -313,13 +314,13 @@ def load_data_FT_prod(batch_size, device='cpu'): # create a balanced sampling - needed for an unbiased gradient sampler = BalancedBatchSampler( - group_onehot=sens_train, batch_size=batch_size, drop_last=True + group_onehot=sens_train, batch_size=batch_size, drop_last=True, extend_groups=list(range(sens_train.shape[1])) if extend_groups else None ) sampler_val = BalancedBatchSampler( - group_onehot=sens_val, batch_size=batch_size, drop_last=True + group_onehot=sens_val, batch_size=batch_size, drop_last=True, extend_groups=list(range(sens_train.shape[1])) if extend_groups else None ) sampler_test = BalancedBatchSampler( - group_onehot=sens_test, batch_size=batch_size, drop_last=True + group_onehot=sens_test, batch_size=batch_size, drop_last=True, extend_groups=list(range(sens_train.shape[1])) if extend_groups else None ) # create a dataloader from the sampler @@ -331,7 +332,7 @@ def load_data_FT_prod(batch_size, device='cpu'): -def load_data_FT_vec(batch_size, attr = "SEX", device='cpu'): +def load_data_FT_vec(batch_size, device='cpu', attr = "SEX", extend_groups = False): # load folktables data data_source = ACSDataSource(survey_year="2018", horizon="1-Year", survey="person") @@ -369,22 +370,22 @@ def load_data_FT_vec(batch_size, attr = "SEX", device='cpu'): X_test = scaler.transform(X_test) # make into a pytorch dataset, remove the sensitive attribute - features_train = torch.tensor(X_train) - labels_train = torch.tensor(y_train) - sens_train = torch.tensor(groups_train) + features_train = torch.tensor(X_train).to(device) + labels_train = torch.tensor(y_train).to(device) + sens_train = torch.tensor(groups_train).to(device) dataset_train = torch.utils.data.TensorDataset(features_train, labels_train) # make into a pytorch dataset, remove the sensitive attribute - features_val = torch.tensor(X_val) - labels_val = torch.tensor(y_val) - sens_val = torch.tensor(groups_val) + features_val = torch.tensor(X_val).to(device) + labels_val = torch.tensor(y_val).to(device) + sens_val = torch.tensor(groups_val).to(device) dataset_val = torch.utils.data.TensorDataset(features_val, labels_val) # make into a pytorch dataset, remove the sensitive attribute - features_test = torch.tensor(X_test) - labels_test = torch.tensor(y_test) - sens_test = torch.tensor(groups_test) - dataset_test = torch.utils.data.TensorDataset(features_test, labels_test) + features_test = torch.tensor(X_test).to(device) + labels_test = torch.tensor(y_test).to(device) + sens_test = torch.tensor(groups_test).to(device) + dataset_test = torch.utils.data.TensorDataset(features_test, sens_test, labels_test) # set the same seed for fair comparisons torch.manual_seed(0) @@ -396,13 +397,13 @@ def load_data_FT_vec(batch_size, attr = "SEX", device='cpu'): # create a balanced sampling - needed for an unbiased gradient sampler = BalancedBatchSampler( - group_onehot=sens_train, batch_size=batch_size, drop_last=True + group_onehot=sens_train, batch_size=batch_size, drop_last=True, extend_groups=list(range(sens_train.shape[1])) if extend_groups else None ) sampler_val = BalancedBatchSampler( - group_onehot=sens_val, batch_size=batch_size, drop_last=True + group_onehot=sens_val, batch_size=batch_size, drop_last=True, extend_groups=list(range(sens_train.shape[1])) if extend_groups else None ) sampler_test = BalancedBatchSampler( - group_onehot=sens_test, batch_size=batch_size, drop_last=True + group_onehot=sens_test, batch_size=batch_size, drop_last=True, extend_groups=list(range(sens_train.shape[1])) if extend_groups else None ) # create a dataloader from the sampler @@ -413,45 +414,45 @@ def load_data_FT_vec(batch_size, attr = "SEX", device='cpu'): return (dataloader_train, dataloader_val, dataloader_test), (features_train, sens_train, labels_train), (features_val, sens_val, labels_val), (features_test, sens_test, labels_test) -def load_data_DUTCH(batch_size): +def load_data_DUTCH(batch_size, device='cpu', extend_groups = False): # Get the data with a validation split X_train, X_val, X_test, y_train, y_val, y_test, groups_train, groups_val, groups_test, group_names_dict = get_data_dutch( test_size=0.4, seed_n=42, drop_small_groups=True, print_stats=True ) # Convert training data to PyTorch tensors - features_train = torch.tensor(X_train).to(torch.float32) - labels_train = torch.tensor(y_train.to_numpy()).reshape((-1, 1)).to(torch.float32) - sens_train = torch.tensor(groups_train).to(torch.float32) + features_train = torch.tensor(X_train).to(torch.float32).to(device) + labels_train = torch.tensor(y_train.to_numpy()).reshape((-1, 1)).to(torch.float32).to(device) + sens_train = torch.tensor(groups_train).to(torch.float32).to(device) dataset_train = torch.utils.data.TensorDataset(features_train, sens_train, labels_train) # Convert validation data to PyTorch tensors - features_val = torch.tensor(X_val).to(torch.float32) - labels_val = torch.tensor(y_val).reshape((-1, 1)).to(torch.float32) - sens_val = torch.tensor(groups_val).to(torch.float32) + features_val = torch.tensor(X_val).to(torch.float32).to(device) + labels_val = torch.tensor(y_val).reshape((-1, 1)).to(torch.float32).to(device) + sens_val = torch.tensor(groups_val).to(torch.float32).to(device) dataset_val = torch.utils.data.TensorDataset(features_val, sens_val, labels_val) # Convert test data to PyTorch tensors - features_test = torch.tensor(X_test).to(torch.float32) - labels_test = torch.tensor(y_test.to_numpy()).reshape((-1, 1)).to(torch.float32) - sens_test = torch.tensor(groups_test).to(torch.float32) + features_test = torch.tensor(X_test).to(torch.float32).to(device) + labels_test = torch.tensor(y_test.to_numpy()).reshape((-1, 1)).to(torch.float32).to(device) + sens_test = torch.tensor(groups_test).to(torch.float32).to(device) dataset_test = torch.utils.data.TensorDataset(features_test, sens_test, labels_test) # Create balanced samplers sampler_train = BalancedBatchSampler( - group_onehot=sens_train, batch_size=batch_size, drop_last=True#, extend_groups=list(range(sens_train.shape[1])) + group_onehot=sens_train, batch_size=batch_size, drop_last=True, extend_groups=list(range(sens_train.shape[1])) if extend_groups else None ) sampler_val = BalancedBatchSampler( - group_onehot=sens_val, batch_size=252*4, drop_last=True#, extend_groups=list(range(sens_train.shape[1])) + group_onehot=sens_val, batch_size=252*4, drop_last=True, extend_groups=list(range(sens_train.shape[1])) if extend_groups else None ) sampler_test = BalancedBatchSampler( - group_onehot=sens_test, batch_size=252*4, drop_last=True#, extend_groups=list(range(sens_train.shape[1])) + group_onehot=sens_test, batch_size=252*4, drop_last=True, extend_groups=list(range(sens_train.shape[1])) if extend_groups else None ) # Create dataloaders - dataloader_train = torch.utils.data.DataLoader(dataset_train, batch_sampler=sampler_train, num_workers=8) - dataloader_val = torch.utils.data.DataLoader(dataset_val, batch_sampler=sampler_val, num_workers=8) - dataloader_test = torch.utils.data.DataLoader(dataset_test, batch_sampler=sampler_test, num_workers=8) + dataloader_train = torch.utils.data.DataLoader(dataset_train, batch_sampler=sampler_train) + dataloader_val = torch.utils.data.DataLoader(dataset_val, batch_sampler=sampler_val) + dataloader_test = torch.utils.data.DataLoader(dataset_test, batch_sampler=sampler_test) return (dataloader_train, dataloader_val, dataloader_test), (features_train, sens_train, labels_train), (features_val, sens_val, labels_val), (features_test, sens_test, labels_test) @@ -549,7 +550,10 @@ def get_data_dutch(test_size=0.2, seed_n = 42, drop_small_groups=True, print_sta -def load_data_cifar10(balanced=False): +def load_data_cifar10(balanced=False, device='cpu'): + + import torchvision + from torchvision import transforms # load the data transform = transforms.Compose( @@ -566,6 +570,7 @@ def load_data_cifar10(balanced=False): # make some parameters global global classes + print(device) classes = ('plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck') @@ -579,60 +584,69 @@ def load_data_cifar10(balanced=False): print(class_ind) # load all data and create a balanced sampler - X = torch.stack([item[0] for item in trainset]) - targets = torch.tensor([item[1] for item in trainset]) + X = torch.stack([item[0] for item in trainset]).to(device) + targets = torch.tensor([item[1] for item in trainset]).to(device) # create onehot vectors - groups_onehot = torch.eye(10)[targets] + groups_onehot = torch.eye(10)[targets].to(device) # create a train dataset dataset_train = torch.utils.data.TensorDataset(X, groups_onehot, targets) - # create the balanced dataloader - sampler = BalancedBatchSampler( - group_onehot=groups_onehot, batch_size=batch_size, drop_last=True - ) if balanced: - trainloader = torch.utils.data.DataLoader(dataset_train, batch_sampler=sampler, num_workers=10) + sampler = BalancedBatchSampler(group_onehot=groups_onehot, batch_size=batch_size, drop_last=True) + trainloader = torch.utils.data.DataLoader(dataset_train, batch_sampler=sampler) else: - trainloader = torch.utils.data.DataLoader(dataset_train, batch_size=batch_size, num_workers=10) + trainloader = torch.utils.data.DataLoader(dataset_train, batch_size=batch_size) # load all data and create a balanced sampler - X_test = torch.stack([item[0] for item in testset]) - targets_test = torch.tensor([item[1] for item in testset]) + X_test = torch.stack([item[0] for item in testset]).to(device) + targets_test = torch.tensor([item[1] for item in testset]).to(device) # create onehot vectors - groups_onehot_test = torch.eye(10)[targets_test] + groups_onehot_test = torch.eye(10)[targets_test].to(device) # split test / val X_test, X_val, targets_test, targets_val, groups_onehot_test, groups_onehot_val = \ train_test_split(X_test, targets_test, groups_onehot_test, test_size=0.5, random_state=42) - # create a train dataset - dataset_test = torch.utils.data.TensorDataset(X_test, groups_onehot_test, targets_test) + dataset_val = torch.utils.data.TensorDataset(X_val, groups_onehot_val, targets_val) - # create the balanced dataloader - sampler = BalancedBatchSampler( - group_onehot=groups_onehot_test, batch_size=batch_size, drop_last=True - ) + global valloader + if balanced: + # create the balanced dataloader + sampler = BalancedBatchSampler( + group_onehot=groups_onehot_val, batch_size=batch_size, drop_last=True + ) + valloader = torch.utils.data.DataLoader(dataset_val, batch_sampler=sampler) + else: + valloader = torch.utils.data.DataLoader(dataset_val, batch_size=batch_size) - global testloader + dataset_test = torch.utils.data.TensorDataset(X_test, groups_onehot_test, targets_test) + + global testloader if balanced: - testloader = torch.utils.data.DataLoader(dataset_test, batch_sampler=sampler, num_workers=10) - else: - testloader = torch.utils.data.DataLoader(dataset_test, batch_size=batch_size, num_workers=10) + # create the balanced dataloader + sampler = BalancedBatchSampler( + group_onehot=groups_onehot_test, batch_size=batch_size, drop_last=True + ) + testloader = torch.utils.data.DataLoader(dataset_test, batch_sampler=sampler) + else: + testloader = torch.utils.data.DataLoader(dataset_test, batch_size=batch_size) # clean the memory of redundant variables del X, targets, groups_onehot del X_test, targets_test, groups_onehot_test + return trainloader, valloader, testloader, classes, class_ind - return trainloader, testloader, classes, class_ind +def load_data_cifar100(balanced=False, device='cpu'): -def load_data_cifar100(balanced=False): + import torchvision + from torchvision import transforms # load the data transform = transforms.Compose( @@ -672,50 +686,63 @@ def load_data_cifar100(balanced=False): print(class_ind) # load all data and create a balanced sampler - X = torch.stack([item[0] for item in trainset]) - targets = torch.tensor([item[1] for item in trainset]) + X = torch.stack([item[0] for item in trainset]).to(device) + targets = torch.tensor([item[1] for item in trainset]).to(device) # create onehot vectors - groups_onehot = torch.eye(100)[targets] + groups_onehot = torch.eye(100)[targets].to(device) # create a train dataset dataset_train = torch.utils.data.TensorDataset(X, groups_onehot, targets) - # create the balanced dataloader - sampler = BalancedBatchSampler( - group_onehot=groups_onehot, batch_size=batch_size, drop_last=True - ) if balanced: - trainloader = torch.utils.data.DataLoader(dataset_train, batch_sampler=sampler, num_workers=1) + # create the balanced dataloader + sampler = BalancedBatchSampler( + group_onehot=groups_onehot, batch_size=batch_size, drop_last=True + ) + trainloader = torch.utils.data.DataLoader(dataset_train, batch_sampler=sampler) else: - trainloader = torch.utils.data.DataLoader(dataset_train, batch_size=batch_size, num_workers=1) + trainloader = torch.utils.data.DataLoader(dataset_train, batch_size=batch_size) # load all data and create a balanced sampler - X_test = torch.stack([item[0] for item in testset]) - targets_test = torch.tensor([item[1] for item in testset]) + X_test = torch.stack([item[0] for item in testset]).to(device) + targets_test = torch.tensor([item[1] for item in testset]).to(device) # create onehot vectors - groups_onehot_test = torch.eye(100)[targets_test] + groups_onehot_test = torch.eye(100)[targets_test].to(device) # split test / val X_test, X_val, targets_test, targets_val, groups_onehot_test, groups_onehot_val = \ train_test_split(X_test, targets_test, groups_onehot_test, test_size=0.5, random_state=42) + # create a train dataset + dataset_val = torch.utils.data.TensorDataset(X_val, groups_onehot_val, targets_val) + + global valloader + if balanced: + # create the balanced dataloader + sampler = BalancedBatchSampler( + group_onehot=groups_onehot_val, batch_size=batch_size, drop_last=True + ) + valloader = torch.utils.data.DataLoader(dataset_val, batch_sampler=sampler) + else: + valloader = torch.utils.data.DataLoader(dataset_val, batch_size=batch_size) + # create a train dataset dataset_test = torch.utils.data.TensorDataset(X_test, groups_onehot_test, targets_test) - # create the balanced dataloader - sampler = BalancedBatchSampler( - group_onehot=groups_onehot_test, batch_size=batch_size, drop_last=True - ) global testloader if balanced: - testloader = torch.utils.data.DataLoader(dataset_test, batch_sampler=sampler, num_workers=1) + # create the balanced dataloader + sampler = BalancedBatchSampler( + group_onehot=groups_onehot_test, batch_size=batch_size, drop_last=True + ) + testloader = torch.utils.data.DataLoader(dataset_test, batch_sampler=sampler) else: - testloader = torch.utils.data.DataLoader(dataset_test, batch_size=batch_size, num_workers=1) + testloader = torch.utils.data.DataLoader(dataset_test, batch_size=batch_size) # clean the memory of redundant variables del X, targets, groups_onehot del X_test, targets_test, groups_onehot_test - return trainloader, testloader, classes, class_ind \ No newline at end of file + return trainloader, valloader, testloader, classes, class_ind \ No newline at end of file diff --git a/benchmark/conf/data/x.yaml b/benchmark/conf/data/x.yaml new file mode 100644 index 0000000..59a03ce --- /dev/null +++ b/benchmark/conf/data/x.yaml @@ -0,0 +1 @@ +name: income \ No newline at end of file diff --git a/benchmark/conf/hydra/launcher/configured_submitit_slurm.yaml b/benchmark/conf/hydra/launcher/configured_submitit_slurm.yaml new file mode 100644 index 0000000..e6a024b --- /dev/null +++ b/benchmark/conf/hydra/launcher/configured_submitit_slurm.yaml @@ -0,0 +1,15 @@ +defaults: +# - override hydra/launcher: submitit_slurm + - submitit_slurm + +#hydra: +# launcher: +#mem_gb: 32 # configuration used with slurm launcher +#tasks_per_node: 60 +mem_per_cpu: 16G +cpus_per_task: 4 +#mem_gb: 32 +partition: amd +timeout_min: 1440 +#array_parallelism: 50 + diff --git a/benchmark/conf/task/cifar100_loss.yaml b/benchmark/conf/task/cifar100_loss.yaml new file mode 100644 index 0000000..684da55 --- /dev/null +++ b/benchmark/conf/task/cifar100_loss.yaml @@ -0,0 +1,42 @@ +task: cifar100 +batch_size: 400 + +# TODO: rework to list (constraints: [(name: ...), (name: ...), ...]) +constraint: + name: LossPairwise + bound: 0.1 + loss: CrossEntropyLoss + +algorithms: none + +# pbm_params: +# primal__lr: 0.005 +# dual__penalty_mult: 0.9999 +# dual__penalty_update: dimin #aimd +# dual__penalty_range: [1e-1, 10.] +# dual__pbf: quadratic_logarithmic +# dual__gamma: 0.1 +# dual_delta: 1.0 +# moreau__mu: 2.0 + +# alm_params: +# primal__lr: 0.005 +# dual__lr: 0.001 +# dual__penalty: 0.0 +# dual__momentum: 0. +# moreau__mu: 2.0 + +# ssg_params: +# primal__lr: 0.01 +# dual__lr: 0.001 +# moreau__mu: 0.0 + +# adam_params: +# lr: 0.01 +# # penalty: 0.5 + +hydra: + run: + dir: . + job: + chdir: false diff --git a/benchmark/conf/task/cifar10_acc_pair.yaml b/benchmark/conf/task/cifar10_acc_pair.yaml deleted file mode 100644 index c5e481c..0000000 --- a/benchmark/conf/task/cifar10_acc_pair.yaml +++ /dev/null @@ -1,41 +0,0 @@ -task: cifar10 -batch_size: 120 - -# TODO: rework to list (constraints: [(name: ...), (name: ...), ...]) -constraint: - name: FairretPairwise - bound: 0.1 - statistic: Accuracy - kwargs: - uses_labels: true - -pbm_params: - primal__lr: 0.05 - dual__penalty_mult: 0.1 - dual__penalty_update: dimin_adapt #aimd - dual__penalty_range: [1e-1, 100] - dual__pbf: quadratic_logarithmic - dual__gamma: 0.1 - moreau__mu: 2.0 - -alm_params: - primal__lr: 0.001 - dual__lr: 0.05 - dual__penalty: 0.0 - dual__momentum: 0.0 - moreau__mu: 0.0 - -ssg_params: - primal__lr: 0.01 - dual__lr: 0.001 - moreau__mu: 0.0 - -adam_params: - lr: 0.01 - penalty: 0.5 - -hydra: - run: - dir: . - job: - chdir: false diff --git a/benchmark/conf/task/cifar10_loss.yaml b/benchmark/conf/task/cifar10_loss.yaml index a9fe984..61ff317 100644 --- a/benchmark/conf/task/cifar10_loss.yaml +++ b/benchmark/conf/task/cifar10_loss.yaml @@ -7,31 +7,54 @@ constraint: bound: 0.1 loss: CrossEntropyLoss -pbm_params: - primal__lr: 0.01 - dual__penalty_mult: 0.1 - dual__penalty_update: dimin_adapt #aimd - dual__penalty_range: [1e-1, 100] - dual__pbf: quadratic_logarithmic - dual__gamma: 0.1 - dual_delta: 0.5 - moreau__mu: 2.0 - -alm_params: - primal__lr: 0.001 - dual__lr: 0.05 - dual__penalty: 0.0 - dual__momentum: 0.0 - moreau__mu: 0.0 +# pbm_params: +# primal__lr: 0.005 +# dual__penalty_mult: 0.1 +# dual__penalty_update: dimin_adapt #aimd +# dual__penalty_range: [1e-1, 100.] +# dual__pbf: quadratic_logarithmic +# dual__gamma: 0.5 +# dual_delta: 0.1 +# moreau__mu: 2.0 + +# pbm_params: +# primal__lr: 0.0005 +# dual__penalty_mult: 0.9 +# dual__penalty_update: dimin_adapt #aimd +# dual__penalty_range: [1e-1, 1.] +# dual__pbf: quadratic_logarithmic +# dual__gamma: 0.9 +# dual_delta: 1.0 +# moreau__mu: 2.0 + +# pbm_params: +# primal__lr: 0.001 +# dual__penalty_mult: 0.99 +# dual__penalty_update: const #aimd +# dual__penalty_range: [1e-1, 1.] +# dual__pbf: quadratic_reciprocal +# dual__gamma: 0.9 +# dual_delta: 1.0 +# moreau__mu: 1.0 +# +primal__weight_decay: 0.01 + +# alm_params: +# primal__lr: 0.001 +# dual__lr: 0.001 +# dual__penalty: 0.0 +# moreau__mu: 1.0 +# +primal__weight_decay: 0.01 ssg_params: primal__lr: 0.01 dual__lr: 0.001 - moreau__mu: 0.0 + moreau__mu: 1.0 + +primal__weight_decay: 0.01 -adam_params: - lr: 0.01 - penalty: 0.5 +# adam_params: +# lr: 0.001 +# +primal__weight_decay: 0.01 +# # penalty: 0.5 hydra: run: diff --git a/benchmark/conf/task/dutch_positive_rate_pair.yaml b/benchmark/conf/task/dutch_positive_rate_pair.yaml index 8aaa667..8abdb1a 100644 --- a/benchmark/conf/task/dutch_positive_rate_pair.yaml +++ b/benchmark/conf/task/dutch_positive_rate_pair.yaml @@ -1,5 +1,7 @@ task: equalized_odds_pairwise +seed: 0 + batch_size: 72 # TODO: rework to list (constraints: [(name: ...), (name: ...), ...]) @@ -11,32 +13,32 @@ constraint: uses_labels: false pbm_params: - primal__lr: 0.01 - dual__penalty_mult: 0.1 - dual__penalty_update: dimin_adapt #aimd + primal__lr: 0.005 + primal__weight_decay: 0.01 + dual__penalty_mult: 0.8 + dual__penalty_update: dimin_adapt dual__penalty_range: [1e-1, 1.] - # dual__init_duals: 1e-4 dual__pbf: quadratic_logarithmic - dual__gamma: 0.9 + dual__gamma: 0.5 dual__delta: 1.0 - moreau__mu: 2.0 - + moreau__mu: 1.0 alm_params: - primal__lr: 0.005 - dual__lr: 0.01 + primal__lr: 0.001 + primal__weight_decay: 0.01 + dual__lr: 0.005 dual__penalty: 0.0 - dual__momentum: 0.5 - moreau__mu: 2.0 + moreau__mu: 1.0 ssg_params: - primal__lr: 0.001 - dual__lr: 0.001 - moreau__mu: 2.0 + primal__lr: 0.005 + primal__weight_decay: 0.01 + dual__lr: 0.005 + moreau__mu: 1.0 adam_params: - lr: 0.01 - penalty: 0.5 + lr: 0.0001 + weight_decay: 0.01 hydra: run: diff --git a/benchmark/conf/task/folktables_positive_rate_pair.yaml b/benchmark/conf/task/folktables_positive_rate_pair.yaml index 2ac3d48..9ce8daf 100644 --- a/benchmark/conf/task/folktables_positive_rate_pair.yaml +++ b/benchmark/conf/task/folktables_positive_rate_pair.yaml @@ -1,5 +1,5 @@ task: equalized_odds_pairwise -batch_size: 40 +batch_size: 30 # TODO: rework to list (constraints: [(name: ...), (name: ...), ...]) constraint: @@ -10,30 +10,33 @@ constraint: uses_labels: false pbm_params: - primal__lr: 0.01 - dual__penalty_mult: 0.5 + primal__lr: 0.0005 + primal__weight_decay: 0.01 + dual__penalty_mult: 0.999 dual__penalty_update: dimin_adapt #aimd dual__penalty_range: [1e-1, 1.] dual__pbf: quadratic_logarithmic - dual__gamma: 0.5 - dual__delta: 0.9 - moreau__mu: 2.0 + dual__gamma: 0.9 + dual__delta: 1.0 + moreau__mu: 1.0 + alm_params: - primal__lr: 0.05 - dual__lr: 0.05 + primal__lr: 0.0001 + primal__weight_decay: 0.01 + dual__lr: 0.005 dual__penalty: 0.0 - dual__momentum: 0.1 - moreau__mu: 2.0 + moreau__mu: 1.0 ssg_params: - primal__lr: 0.01 - dual__lr: 0.001 - moreau__mu: 0.0 + primal__lr: 0.001 + primal__weight_decay: 0.01 + dual__lr: 0.0001 + moreau__mu: 1.0 adam_params: - lr: 0.01 - # penalty: 0.5 + primal__lr: 0.0001 + primal__weight_decay: 0.01 hydra: run: diff --git a/benchmark/conf/task/folktables_positive_rate_vec.yaml b/benchmark/conf/task/folktables_positive_rate_vec.yaml index de6ec59..dd20e43 100644 --- a/benchmark/conf/task/folktables_positive_rate_vec.yaml +++ b/benchmark/conf/task/folktables_positive_rate_vec.yaml @@ -8,32 +8,37 @@ constraint: bound: 0.2 loss: NormLoss statistic: PositiveRate - constraint_kwargs: + kwargs: uses_labels: false -pbm_params: - primal__lr: 0.001 - dual__p_mult: 0.9 - dual__penalty_update: dimin_adapt - dual__pbf: quadratic_logarithmic - dual__gamma: 0.1 - moreau__mu: 2.0 - -alm_params: - primal__lr: 0.001 - dual__lr: 0.001 - dual__penalty: 0. - dual__momentum: 0. - moreau__mu: 0.5 - -ssg_params: - primal__lr: 0.05 - dual__lr: 0.01 - moreau__mu: 0.0 - -adam_params: - lr: 0.001 - penalty: 0.5 +# pbm_params: +# primal__lr: 0.001 +# +primal__weight_decay: 0.01 +# dual__penalty_mult: 0.999 +# dual__penalty_update: dimin_adapt #aimd +# dual__penalty_range: [1e-1, 1] +# dual__pbf: quadratic_logarithmic +# dual__gamma: 0.1 +# moreau__mu: 1.0 + + +# alm_params: +# primal__lr: 0.0001 +# +primal__weight_decay: 0.01 +# dual__lr: 0.001 +# dual__penalty: 1.0 +# moreau__mu: 1.0 + +# ssg_params: +# primal__lr: 0.001 +# +primal__weight_decay: 0.01 +# dual__lr: 0.0005 +# moreau__mu: 1.0 + +# adam_params: +# lr: 0.001 +# +primal__weight_decay: 0.01 +# penalty: 0.5 hydra: run: diff --git a/benchmark/conf/task/tinyimage.yaml b/benchmark/conf/task/tinyimage.yaml new file mode 100644 index 0000000..89e5f0b --- /dev/null +++ b/benchmark/conf/task/tinyimage.yaml @@ -0,0 +1,42 @@ +task: tinyimage +batch_size: 400 + +# TODO: rework to list (constraints: [(name: ...), (name: ...), ...]) +constraint: + name: LossPairwise + bound: 0.1 + loss: CrossEntropyLoss + +algorithms: none + +# pbm_params: +# primal__lr: 0.005 +# dual__penalty_mult: 0.9999 +# dual__penalty_update: dimin #aimd +# dual__penalty_range: [1e-1, 10.] +# dual__pbf: quadratic_logarithmic +# dual__gamma: 0.1 +# dual_delta: 1.0 +# moreau__mu: 2.0 + +# alm_params: +# primal__lr: 0.005 +# dual__lr: 0.001 +# dual__penalty: 0.0 +# dual__momentum: 0. +# moreau__mu: 2.0 + +# ssg_params: +# primal__lr: 0.01 +# dual__lr: 0.001 +# moreau__mu: 0.0 + +# adam_params: +# lr: 0.01 +# # penalty: 0.5 + +hydra: + run: + dir: . + job: + chdir: false diff --git a/benchmark/conf/task/weight.yaml b/benchmark/conf/task/weight.yaml index 2860d28..079e0b4 100644 --- a/benchmark/conf/task/weight.yaml +++ b/benchmark/conf/task/weight.yaml @@ -3,12 +3,12 @@ batch_size: 80 pbm_params: - primal__lr: 0.005 + primal__lr: 0.001 dual__penalty_mult: 0.1 dual__penalty_update: dimin_adapt #aimd - dual__penalty_range: [1e-1, 10] + dual__penalty_range: [1e-1, 1] dual__pbf: quadratic_logarithmic - dual__gamma: 0.0 + dual__gamma: 0.95 moreau__mu: 2.0 alm_params: diff --git a/benchmark/constraint_meta.py b/benchmark/constraint_meta.py index 3d3699a..8684c1a 100644 --- a/benchmark/constraint_meta.py +++ b/benchmark/constraint_meta.py @@ -23,7 +23,7 @@ class ConstraintMetadata: class FairretPairwise(ConstraintMetadata): """Wrapper class for a pairwise fairness constraint based on a given statistic (e.g., positive rate, false positive rate, etc.). The constraint is computed as the difference between the statistic for each pair of groups.""" - def __init__(self, statistic: Callable, uses_labels: bool, abs_diff: bool = False, as_logits: bool = False): + def __init__(self, statistic: Callable, uses_labels: bool, abs_diff: bool = False, as_logits: bool = True): """Initializes the FairretPairwise constraint. Args: statistic (Callable): An initialized fairret.statistic object. @@ -41,7 +41,7 @@ def __init__(self, statistic: Callable, uses_labels: bool, abs_diff: bool = Fals self.uses_labels = uses_labels def compute_constraints(self, model, batch_out, batch_sens, batch_labels): - if not self.as_logits: + if self.as_logits: batch_out = torch.sigmoid(batch_out) if self.uses_labels: stat_pergroup = self.statistic(batch_out, batch_sens, batch_labels) @@ -58,7 +58,7 @@ def compute_constraints(self, model, batch_out, batch_sens, batch_labels): class FairretMean(ConstraintMetadata): """Wrapper class for a pairwise fairness constraint based on a given statistic (e.g., positive rate, false positive rate, etc.). The constraint is computed as the difference between the statistic for each pair of groups.""" - def __init__(self, statistic: Callable, uses_labels: bool, as_logits: bool = False): + def __init__(self, statistic: Callable, uses_labels: bool, as_logits: bool = True): """Initializes the FairretPairwise constraint. Args: statistic (Callable): An initialized fairret.statistic object. @@ -73,7 +73,7 @@ def __init__(self, statistic: Callable, uses_labels: bool, as_logits: bool = Fal self.uses_labels = uses_labels def compute_constraints(self, model, batch_out, batch_sens, batch_labels): - if not self.as_logits: + if self.as_logits: batch_out = torch.sigmoid(batch_out) if self.uses_labels: stat_pergroup = self.statistic(batch_out, batch_sens, batch_labels) @@ -91,15 +91,15 @@ def compute_constraints(self, model, batch_out, batch_sens, batch_labels): class FairretAgg(ConstraintMetadata): """Wrapper class for a vector fairness constraint based on a given statistic (e.g., positive rate, false positive rate, etc.). The constraint is computed as the difference between the statistic for each group and the mean statistic across all groups.""" - def __init__(self, loss: Callable, uses_labels: bool, as_logits: bool = False): + def __init__(self, loss: Callable, uses_labels: bool, as_logits: bool = True): super().__init__( fn=self.compute_constraints, m_fn=lambda n_groups: 1 ) self.loss = loss self.as_logits = as_logits - if self.as_logits: - raise ValueError("`as_logits=True`is not supported for the fairret loss constraint, since the loss should already be computed on the logits.") + if not self.as_logits: + raise ValueError("`as_logits=False`is not supported for the fairret loss constraint, since the loss should already be computed on the logits.") self.uses_labels = uses_labels def compute_constraints(self, model, batch_out, batch_sens, batch_labels): @@ -138,7 +138,9 @@ def compute_constraints(self, model, batch_out, batch_sens, batch_labels, loss = loss = self.loss(batch_out, batch_labels) per_group_losses = _get_normalized_per_group_losses(loss, batch_sens).squeeze() - constraints = ((per_group_losses.unsqueeze(1) - per_group_losses.unsqueeze(0))) + # print(per_group_losses) + constraints = ((per_group_losses.unsqueeze(1) - per_group_losses.unsqueeze(0))) + # print(constraints) mask = ~torch.eye(batch_sens.shape[-1], dtype=torch.bool) constraints = constraints[mask] return constraints diff --git a/benchmark/demo_balls.py b/benchmark/demo_balls.py index 42d196f..91a2c5f 100644 --- a/benchmark/demo_balls.py +++ b/benchmark/demo_balls.py @@ -1,3 +1,5 @@ +from sched import scheduler + from matplotlib.patches import Circle import torch from humancompatible.train.optim.PBM import PBM @@ -5,20 +7,22 @@ import matplotlib.pyplot as plt from mpl_toolkits.axes_grid1 import make_axes_locatable +from humancompatible.train.dual_optim import ALM, MoreauEnvelope, PBM torch.manual_seed(1) np.random.seed(1) -def plot_balls_trajectory(trajectories, names): - """ - trajectory: array-like of shape (N, 2), where each row is [x, y] - """ +import numpy as np +import matplotlib.pyplot as plt +from matplotlib.patches import Circle +from mpl_toolkits.axes_grid1 import make_axes_locatable +def plot_balls_trajectory(trajectories, names): fig, ax = plt.subplots(figsize=(24, 16)) # Feasible regions: unit balls ball_centers = [(-2, 0), (2, 0)] radius = np.sqrt(0.99) - labels = [r"$\mathbb{E}[g_1(x,y,\xi)] \leq 0 $", r"$\mathbb{E}[g_2(x,y,\xi)] \leq 0$"] + labels = [r"$\mathbb{E}[g(x,\xi)] \leq 0$", r"$\mathbb{E}[g(x,\xi)] \leq 0$"] # Heatmap for x^2 + y^2 x = np.linspace(-4, 4, 100) @@ -26,23 +30,22 @@ def plot_balls_trajectory(trajectories, names): X, Y = np.meshgrid(x, y) Z = X**2 + Y**2 - + # Draw balls for center, label in zip(ball_centers, labels): ball = Circle( center, radius=radius, facecolor="lightgray", edgecolor="black", - linewidth=1.5, + linewidth=6, # Thickest practical line alpha=0.6, zorder=1, ) ax.add_patch(ball) - # Label inside the ball ax.text( center[0], center[1], label, - fontsize=18, + fontsize=40, # Very large font fontweight='bold', color='black', ha='center', @@ -50,16 +53,16 @@ def plot_balls_trajectory(trajectories, names): zorder=5 ) + # Plot trajectories for i, traj in enumerate(trajectories): traj = np.asarray(traj) - - # Trajectory ax.plot( traj[:, 0], traj[:, 1], - linewidth=2.0, + linewidth=6, # Thickest practical line zorder=3, - alpha=1.0 + alpha=1.0, + color= 'c' if i == 0 else ('orange' if i == 1 else ('tab:green' if i == 2 else 'red')) ) # Emphasize x_0 and x_n @@ -68,66 +71,77 @@ def plot_balls_trajectory(trajectories, names): ax.scatter( x0[0], x0[1], - s=80, + s=400, # Very large marker marker="o", facecolor="white", edgecolor="black", - linewidth=2, + linewidth=4, # Thick edge zorder=4, ) ax.scatter( xn[0], xn[1], - s=60, + s=300, # Very large marker marker="s", facecolor="black", edgecolor="black", zorder=4, ) + + # Labels for x_0 and x_n x_n = [ r"$x_{\rho=0}^n$", r"$x_{\rho=1}^n$", r"$x_{\rho=2.5}^n$", r"$x_{SPBM}^n$", ] - # Labels for x_0 and x_n ax.annotate( r"$x^0$", xy=(x0[0], x0[1]), - xytext=(6, 8), + xytext=(15, 15), # Adjusted offset textcoords="offset points", - fontsize=22, + fontsize=40, # Very large font zorder=5, ) + ax.annotate( x_n[i], xy=(xn[0], xn[1]), - # xytext=(8 if i==1 or i == 3 else -36, -12), - xytext=(8 if i==1 or i == 3 else -45, -15), + xytext=(-25, -50) if i == 0 else ((10, 40) if i == 1 else ((30, -70) if i == 2 else (-80, 150))), textcoords="offset points", - fontsize=22, + fontsize=55, # Very large font zorder=5, + color= 'c' if i == 0 else ('orange' if i == 1 else ('tab:green' if i == 2 else 'red')) ) # Formatting ax.set_aspect("equal", adjustable="box") - ax.set_xlabel(r"$x$", fontsize=12) - ax.set_ylabel(r"$y$", fontsize=12) - ax.grid(True, linestyle=":", linewidth=0.8, alpha=0.7) + ax.set_xlabel(r"$x_1$", fontsize=40) # Very large font + ax.set_ylabel(r"$x_2$", fontsize=40) # Very large font + ax.grid(True, linestyle=":", linewidth=2, alpha=0.7) # Thicker grid ax.set_xlim(-3.2, 3.2) ax.set_ylim(-1.8, 1.8) - - - contour = ax.contourf(X, Y, Z, levels=100, cmap='viridis', alpha=0.5, zorder=0) + ax.tick_params(axis='both', which='major', labelsize=32) # Large tick fonts + + # Contour/heatmap + contour = ax.contourf( + X, Y, Z, + levels=100, + cmap='viridis', + alpha=0.6, # More opaque + zorder=0 + ) divider = make_axes_locatable(ax) - cax = divider.append_axes("right", size="3%", pad=0.08) - + cax = divider.append_axes("right", size="5%", pad=0.1) # Wider colorbar cbar = fig.colorbar(contour, cax=cax) - cbar.set_label(r"$x^2 + y^2$", fontsize=14) + cbar.set_label(r"$||x||^2$", fontsize=40) # Very large font + cbar.ax.tick_params(labelsize=32) # Large colorbar tick fonts + # Save with highest practical DPI fig.savefig( "./demo_balls_pbm.pdf", bbox_inches="tight", - pad_inches=0.05 + pad_inches=0.1, + dpi=100 # Highest practical DPI ) @@ -214,7 +228,24 @@ def run_sgd(rho: float): xy[0] = 0 xy[1] = 1 -pbm = PBM([xy], m=1, lr=0.01, dual_bounds=(1e-3, 1e3), penalty_update_m='CONST', epoch_len=2, mu=0, opt_method="Adam") +# Define data and optimizers +optimizer = MoreauEnvelope(torch.optim.SGD([xy], lr=0.01), mu=0.0) +scheduler = torch.optim.lr_scheduler.LambdaLR( + optimizer, + lr_lambda=lambda step: 0.99 ** step +) + +dual = PBM( + m=1, + penalty_update='const', + pbf = 'quadratic_logarithmic', + gamma=0.0, + init_duals=0.1, + init_penalties=1., + penalty_range=(0.5, 1.), + penalty_mult=0.99, + dual_range=(0.1, 10.) +) iters = 200 @@ -231,20 +262,21 @@ def run_sgd(rho: float): r = np.random.uniform() minibatch = samples[ - 0 if r > 0.5 else 1 + 1 if r > 0.5 else 0 ] c = balls(xy, minibatch) - - pbm.dual_step(0, c) - dual_log.append(pbm._dual_vars.detach().numpy().copy().item()) - obj = parabola(xy) + + # compute the lagrangian value + lagrangian = dual.forward_update(obj, c) + lagrangian.backward() + optimizer.step() + optimizer.zero_grad() - pbm.step(obj) - for gr in pbm.param_groups: - gr['lr'] *= 0.99 + scheduler.step() + dual_log.append(dual.duals.detach().numpy().copy().item()) con_log.append(c.detach().numpy().copy().item()) diff --git a/benchmark/pbm_dimin_config.yaml b/benchmark/pbm_dimin_config.yaml new file mode 100644 index 0000000..2ebc98a --- /dev/null +++ b/benchmark/pbm_dimin_config.yaml @@ -0,0 +1,326 @@ +,algorithm,primal__lr,dual_penalty_mult,dual__penalty_update,dual__pbf,dual__gamma,dual__delta,moreau__mu +82817d6b,pbm_dimin,0.005,0.999,dimin,quadratic_logarithmic,0.999,1.0,0.1 +5b3ca681,pbm_dimin,0.001,0.9,dimin,quadratic_logarithmic,0.9,1.0,1.0 +06e21ef2,pbm_dimin,0.005,1.0,dimin,quadratic_logarithmic,0.5,1.0,0.1 +f0705958,pbm_dimin,0.0001,0.99,dimin,quadratic_logarithmic,0.999,1.0,2.0 +bf0b84d8,pbm_dimin,0.001,0.99,dimin,quadratic_logarithmic,0.999,1.0,0.5 +6a03c4e2,pbm_dimin,0.005,1.0,dimin,quadratic_logarithmic,0.9,1.0,0.5 +18392377,pbm_dimin,0.005,1.0,dimin,quadratic_logarithmic,0.99,1.0,2.0 +ba76ac2f,pbm_dimin,0.001,0.9,dimin,quadratic_logarithmic,0.5,1.0,0.1 +ed2e80d8,pbm_dimin,0.0001,0.999,dimin,quadratic_logarithmic,0.9,1.0,0.5 +a48c3aa7,pbm_dimin,0.0001,0.9,dimin,quadratic_logarithmic,0.99,1.0,1.0 +14154af8,pbm_dimin,0.001,0.9,dimin,quadratic_logarithmic,0.99,1.0,2.0 +e398f5ee,pbm_dimin,0.0001,0.9,dimin,quadratic_logarithmic,0.1,1.0,0.1 +4824b1b3,pbm_dimin,0.005,0.9,dimin,quadratic_logarithmic,0.99,1.0,0.1 +9038fce2,pbm_dimin,0.0005,1.0,dimin,quadratic_logarithmic,0.99,1.0,0.1 +9f9dc643,pbm_dimin,0.0001,0.9,dimin,quadratic_logarithmic,0.99,1.0,0 +e7b9066e,pbm_dimin,0.0005,0.9,dimin,quadratic_logarithmic,0.99,1.0,0.5 +ca0e357c,pbm_dimin,0.001,1.0,dimin,quadratic_logarithmic,0.99,1.0,0.1 +faa0f027,pbm_dimin,0.0005,0.9,dimin,quadratic_logarithmic,0.999,1.0,1.0 +af6dfb12,pbm_dimin,0.001,0.9,dimin,quadratic_logarithmic,0.999,1.0,0.5 +0b8e5134,pbm_dimin,0.005,0.99,dimin,quadratic_logarithmic,0.99,1.0,1.0 +169a2a3a,pbm_dimin,0.005,0.9,dimin,quadratic_logarithmic,0.999,1.0,1.0 +607bc2c0,pbm_dimin,0.0005,1.0,dimin,quadratic_logarithmic,0.999,1.0,1.0 +3fb16663,pbm_dimin,0.0001,1.0,dimin,quadratic_logarithmic,0.5,1.0,2.0 +acad47e9,pbm_dimin,0.001,0.9,dimin,quadratic_logarithmic,0.5,1.0,0.5 +fa46c18c,pbm_dimin,0.005,0.99,dimin,quadratic_logarithmic,0.9,1.0,1.0 +a0301919,pbm_dimin,0.0001,0.999,dimin,quadratic_logarithmic,0.5,1.0,2.0 +3c2f8f32,pbm_dimin,0.0005,0.99,dimin,quadratic_logarithmic,0.99,1.0,0.5 +19ed5a31,pbm_dimin,0.0005,0.99,dimin,quadratic_logarithmic,0.999,1.0,0.5 +45572647,pbm_dimin,0.001,1.0,dimin,quadratic_logarithmic,0.5,1.0,0.1 +b48ee2b5,pbm_dimin,0.005,0.999,dimin,quadratic_logarithmic,0.9,1.0,0.1 +ce1ca314,pbm_dimin,0.0001,0.9,dimin,quadratic_logarithmic,0.1,1.0,2.0 +7065c95d,pbm_dimin,0.005,1.0,dimin,quadratic_logarithmic,0.9,1.0,0 +4869bd42,pbm_dimin,0.0005,1.0,dimin,quadratic_logarithmic,0.99,1.0,2.0 +b5b38e92,pbm_dimin,0.001,0.999,dimin,quadratic_logarithmic,0.9,1.0,2.0 +8a3b4bf9,pbm_dimin,0.005,0.9,dimin,quadratic_logarithmic,0.9,1.0,0.5 +20add89e,pbm_dimin,0.001,0.999,dimin,quadratic_logarithmic,0.999,1.0,0.1 +a3d20128,pbm_dimin,0.005,0.9,dimin,quadratic_logarithmic,0.999,1.0,0.5 +b1c37cdc,pbm_dimin,0.0001,1.0,dimin,quadratic_logarithmic,0.9,1.0,0.1 +d64ec5d3,pbm_dimin,0.005,0.999,dimin,quadratic_logarithmic,0.999,1.0,0 +6fd312df,pbm_dimin,0.0001,0.999,dimin,quadratic_logarithmic,0.9,1.0,2.0 +f851f103,pbm_dimin,0.0001,1.0,dimin,quadratic_logarithmic,0.99,1.0,1.0 +57e973be,pbm_dimin,0.001,0.9,dimin,quadratic_logarithmic,0.99,1.0,0.5 +e1c61153,pbm_dimin,0.001,0.999,dimin,quadratic_logarithmic,0.999,1.0,0 +73e684a8,pbm_dimin,0.0001,0.9,dimin,quadratic_logarithmic,0.1,1.0,1.0 +889f7dde,pbm_dimin,0.005,1.0,dimin,quadratic_logarithmic,0.5,1.0,2.0 +3efdb5e4,pbm_dimin,0.0005,0.999,dimin,quadratic_logarithmic,0.999,1.0,0.1 +c3ac7ba2,pbm_dimin,0.005,0.9,dimin,quadratic_logarithmic,0.9,1.0,1.0 +46c3e884,pbm_dimin,0.0005,0.99,dimin,quadratic_logarithmic,0.999,1.0,2.0 +0f2af858,pbm_dimin,0.0005,0.99,dimin,quadratic_logarithmic,0.5,1.0,0.1 +ea168269,pbm_dimin,0.001,0.999,dimin,quadratic_logarithmic,0.99,1.0,0.1 +7eae0db6,pbm_dimin,0.001,0.9,dimin,quadratic_logarithmic,0.9,1.0,0 +19e6d0a5,pbm_dimin,0.0001,0.9,dimin,quadratic_logarithmic,0.9,1.0,1.0 +3b3506d0,pbm_dimin,0.0001,1.0,dimin,quadratic_logarithmic,0.9,1.0,0.5 +3fd3ea94,pbm_dimin,0.0001,1.0,dimin,quadratic_logarithmic,0.5,1.0,1.0 +f287fd9b,pbm_dimin,0.005,0.9,dimin,quadratic_logarithmic,0.5,1.0,0.5 +2b62adaa,pbm_dimin,0.001,0.99,dimin,quadratic_logarithmic,0.9,1.0,0 +3abca38a,pbm_dimin,0.0001,0.999,dimin,quadratic_logarithmic,0.5,1.0,0 +1e1797c8,pbm_dimin,0.005,0.999,dimin,quadratic_logarithmic,0.99,1.0,1.0 +edebf933,pbm_dimin,0.005,0.99,dimin,quadratic_logarithmic,0.5,1.0,0.5 +98c2f8c2,pbm_dimin,0.0001,0.999,dimin,quadratic_logarithmic,0.9,1.0,0 +fdb941ea,pbm_dimin,0.0001,0.9,dimin,quadratic_logarithmic,0.999,1.0,0 +0e63af8d,pbm_dimin,0.005,1.0,dimin,quadratic_logarithmic,0.999,1.0,1.0 +3687be50,pbm_dimin,0.001,0.999,dimin,quadratic_logarithmic,0.999,1.0,2.0 +e96487ad,pbm_dimin,0.0005,1.0,dimin,quadratic_logarithmic,0.9,1.0,1.0 +370576db,pbm_dimin,0.0005,0.999,dimin,quadratic_logarithmic,0.99,1.0,1.0 +9461418d,pbm_dimin,0.0001,0.999,dimin,quadratic_logarithmic,0.99,1.0,0 +6b6b0317,pbm_dimin,0.0005,0.999,dimin,quadratic_logarithmic,0.999,1.0,1.0 +54d76a73,pbm_dimin,0.005,1.0,dimin,quadratic_logarithmic,0.99,1.0,1.0 +ddc0f892,pbm_dimin,0.005,0.99,dimin,quadratic_logarithmic,0.999,1.0,2.0 +2e917ebb,pbm_dimin,0.0005,0.99,dimin,quadratic_logarithmic,0.999,1.0,1.0 +f333ba22,pbm_dimin,0.0001,0.9,dimin,quadratic_logarithmic,0.99,1.0,2.0 +14bb0bc6,pbm_dimin,0.005,0.999,dimin,quadratic_logarithmic,0.5,1.0,2.0 +5dd0db45,pbm_dimin,0.0001,0.999,dimin,quadratic_logarithmic,0.99,1.0,2.0 +e0330170,pbm_dimin,0.005,0.99,dimin,quadratic_logarithmic,0.999,1.0,0 +9b4daa70,pbm_dimin,0.001,0.999,dimin,quadratic_logarithmic,0.99,1.0,1.0 +d5b75b65,pbm_dimin,0.0001,0.999,dimin,quadratic_logarithmic,0.999,1.0,2.0 +a5aaba50,pbm_dimin,0.0005,0.99,dimin,quadratic_logarithmic,0.999,1.0,0.1 +16ba70ee,pbm_dimin,0.0005,0.9,dimin,quadratic_logarithmic,0.5,1.0,0.5 +3e436265,pbm_dimin,0.0005,0.9,dimin,quadratic_logarithmic,0.99,1.0,2.0 +0c3f61d5,pbm_dimin,0.0005,0.99,dimin,quadratic_logarithmic,0.9,1.0,0 +b8bc00fb,pbm_dimin,0.001,0.9,dimin,quadratic_logarithmic,0.5,1.0,0 +d69b749c,pbm_dimin,0.005,0.99,dimin,quadratic_logarithmic,0.99,1.0,0 +c75cd148,pbm_dimin,0.005,1.0,dimin,quadratic_logarithmic,0.999,1.0,0 +72be9778,pbm_dimin,0.0001,0.999,dimin,quadratic_logarithmic,0.5,1.0,1.0 +89d684bc,pbm_dimin,0.001,0.99,dimin,quadratic_logarithmic,0.99,1.0,2.0 +cb5b2b52,pbm_dimin,0.005,1.0,dimin,quadratic_logarithmic,0.999,1.0,0.1 +fc2bc3bb,pbm_dimin,0.001,0.999,dimin,quadratic_logarithmic,0.5,1.0,1.0 +e03398a2,pbm_dimin,0.001,0.999,dimin,quadratic_logarithmic,0.5,1.0,2.0 +40bd7779,pbm_dimin,0.005,1.0,dimin,quadratic_logarithmic,0.999,1.0,2.0 +9e4362da,pbm_dimin,0.0001,0.99,dimin,quadratic_logarithmic,0.5,1.0,1.0 +8d8e608a,pbm_dimin,0.001,0.9,dimin,quadratic_logarithmic,0.999,1.0,0.1 +fdb84094,pbm_dimin,0.001,0.9,dimin,quadratic_logarithmic,0.999,1.0,1.0 +c1a597bd,pbm_dimin,0.0001,0.9,dimin,quadratic_logarithmic,0.1,1.0,0 +624be33c,pbm_dimin,0.0001,0.9,dimin,quadratic_logarithmic,0.5,1.0,0.1 +a3db4351,pbm_dimin,0.005,0.9,dimin,quadratic_logarithmic,0.999,1.0,2.0 +0d3f6c28,pbm_dimin,0.0005,0.9,dimin,quadratic_logarithmic,0.9,1.0,0.5 +855ec898,pbm_dimin,0.005,1.0,dimin,quadratic_logarithmic,0.99,1.0,0.1 +8a0b633e,pbm_dimin,0.001,0.999,dimin,quadratic_logarithmic,0.99,1.0,0 +38b5830a,pbm_dimin,0.005,0.99,dimin,quadratic_logarithmic,0.99,1.0,0.5 +3ccf8e55,pbm_dimin,0.001,0.99,dimin,quadratic_logarithmic,0.9,1.0,0.5 +fdb913dd,pbm_dimin,0.001,0.999,dimin,quadratic_logarithmic,0.99,1.0,2.0 +b475ac20,pbm_dimin,0.005,0.999,dimin,quadratic_logarithmic,0.9,1.0,2.0 +15f32408,pbm_dimin,0.0005,1.0,dimin,quadratic_logarithmic,0.999,1.0,0 +64be9e02,pbm_dimin,0.001,0.9,dimin,quadratic_logarithmic,0.99,1.0,0.1 +6de0c47f,pbm_dimin,0.0001,0.9,dimin,quadratic_logarithmic,0.99,1.0,0.5 +bb8a77af,pbm_dimin,0.0001,0.99,dimin,quadratic_logarithmic,0.99,1.0,1.0 +36448ef8,pbm_dimin,0.001,0.999,dimin,quadratic_logarithmic,0.5,1.0,0 +1449a077,pbm_dimin,0.005,0.9,dimin,quadratic_logarithmic,0.5,1.0,1.0 +0ba19590,pbm_dimin,0.0005,0.9,dimin,quadratic_logarithmic,0.9,1.0,1.0 +bcf7e71c,pbm_dimin,0.001,0.99,dimin,quadratic_logarithmic,0.99,1.0,1.0 +b4f80b28,pbm_dimin,0.005,0.9,dimin,quadratic_logarithmic,0.9,1.0,0.1 +13280627,pbm_dimin,0.0001,0.9,dimin,quadratic_logarithmic,0.999,1.0,0.5 +25d228eb,pbm_dimin,0.0005,0.999,dimin,quadratic_logarithmic,0.5,1.0,1.0 +dd1b93c5,pbm_dimin,0.005,1.0,dimin,quadratic_logarithmic,0.99,1.0,0 +7e436f8a,pbm_dimin,0.001,1.0,dimin,quadratic_logarithmic,0.9,1.0,2.0 +324cbfb9,pbm_dimin,0.0001,1.0,dimin,quadratic_logarithmic,0.5,1.0,0.1 +32932ce5,pbm_dimin,0.0005,0.999,dimin,quadratic_logarithmic,0.5,1.0,0.5 +3462334c,pbm_dimin,0.001,0.999,dimin,quadratic_logarithmic,0.9,1.0,1.0 +135e9dab,pbm_dimin,0.0001,0.9,dimin,quadratic_logarithmic,0.99,1.0,0.1 +0c63fc1c,pbm_dimin,0.001,0.999,dimin,quadratic_logarithmic,0.99,1.0,0.5 +d8cdf0d3,pbm_dimin,0.0005,1.0,dimin,quadratic_logarithmic,0.9,1.0,0 +4e1488f5,pbm_dimin,0.0005,0.9,dimin,quadratic_logarithmic,0.9,1.0,2.0 +11a8696b,pbm_dimin,0.0005,0.999,dimin,quadratic_logarithmic,0.5,1.0,0.1 +10719b39,pbm_dimin,0.0001,0.9,dimin,quadratic_logarithmic,0.999,1.0,2.0 +8228cd81,pbm_dimin,0.005,0.9,dimin,quadratic_logarithmic,0.99,1.0,0.5 +0250c55d,pbm_dimin,0.001,0.99,dimin,quadratic_logarithmic,0.999,1.0,0.1 +368b33b6,pbm_dimin,0.0001,0.99,dimin,quadratic_logarithmic,0.999,1.0,1.0 +bc2edf56,pbm_dimin,0.0005,0.99,dimin,quadratic_logarithmic,0.9,1.0,2.0 +89272e8c,pbm_dimin,0.0005,0.999,dimin,quadratic_logarithmic,0.9,1.0,0 +a255debe,pbm_dimin,0.0005,0.9,dimin,quadratic_logarithmic,0.999,1.0,2.0 +0800ddab,pbm_dimin,0.005,0.99,dimin,quadratic_logarithmic,0.5,1.0,0 +e21dbabc,pbm_dimin,0.001,1.0,dimin,quadratic_logarithmic,0.5,1.0,0.5 +94bd21b1,pbm_dimin,0.005,0.99,dimin,quadratic_logarithmic,0.99,1.0,2.0 +7722b98e,pbm_dimin,0.0001,0.999,dimin,quadratic_logarithmic,0.99,1.0,1.0 +46d3327d,pbm_dimin,0.005,0.99,dimin,quadratic_logarithmic,0.999,1.0,0.1 +71c70895,pbm_dimin,0.0001,0.99,dimin,quadratic_logarithmic,0.9,1.0,0 +200c6eed,pbm_dimin,0.0001,1.0,dimin,quadratic_logarithmic,0.5,1.0,0.5 +6f21bbda,pbm_dimin,0.005,0.9,dimin,quadratic_logarithmic,0.9,1.0,0 +cb9d7ad0,pbm_dimin,0.0001,0.99,dimin,quadratic_logarithmic,0.999,1.0,0.5 +f756f501,pbm_dimin,0.0005,0.999,dimin,quadratic_logarithmic,0.9,1.0,0.1 +70ec5ee8,pbm_dimin,0.001,1.0,dimin,quadratic_logarithmic,0.9,1.0,1.0 +4d211aa5,pbm_dimin,0.0005,0.9,dimin,quadratic_logarithmic,0.9,1.0,0.1 +7f68bab1,pbm_dimin,0.001,0.999,dimin,quadratic_logarithmic,0.999,1.0,1.0 +beb3596f,pbm_dimin,0.005,0.999,dimin,quadratic_logarithmic,0.5,1.0,0.1 +88816ce4,pbm_dimin,0.0005,0.99,dimin,quadratic_logarithmic,0.5,1.0,0.5 +491befdf,pbm_dimin,0.0005,1.0,dimin,quadratic_logarithmic,0.9,1.0,2.0 +74b8cdd6,pbm_dimin,0.0005,1.0,dimin,quadratic_logarithmic,0.5,1.0,0.1 +ed0caa13,pbm_dimin,0.005,0.999,dimin,quadratic_logarithmic,0.9,1.0,0 +f1c79150,pbm_dimin,0.001,1.0,dimin,quadratic_logarithmic,0.999,1.0,0.5 +653ac4e7,pbm_dimin,0.005,0.99,dimin,quadratic_logarithmic,0.99,1.0,0.1 +d0ca8a35,pbm_dimin,0.0005,0.9,dimin,quadratic_logarithmic,0.999,1.0,0.5 +dd4b8935,pbm_dimin,0.005,1.0,dimin,quadratic_logarithmic,0.5,1.0,1.0 +2fa4a18a,pbm_dimin,0.0005,0.999,dimin,quadratic_logarithmic,0.9,1.0,1.0 +6a47634d,pbm_dimin,0.005,0.99,dimin,quadratic_logarithmic,0.5,1.0,0.1 +fe21f25d,pbm_dimin,0.0001,0.9,dimin,quadratic_logarithmic,0.9,1.0,0.5 +b8c81c1a,pbm_dimin,0.0001,0.99,dimin,quadratic_logarithmic,0.999,1.0,0.1 +4b673f5e,pbm_dimin,0.0001,1.0,dimin,quadratic_logarithmic,0.99,1.0,0.1 +0884fa28,pbm_dimin,0.001,0.99,dimin,quadratic_logarithmic,0.999,1.0,2.0 +21f4f2b0,pbm_dimin,0.001,0.9,dimin,quadratic_logarithmic,0.999,1.0,2.0 +da6992cb,pbm_dimin,0.0005,1.0,dimin,quadratic_logarithmic,0.99,1.0,0 +179c02ae,pbm_dimin,0.0001,0.99,dimin,quadratic_logarithmic,0.99,1.0,2.0 +20797283,pbm_dimin,0.001,0.99,dimin,quadratic_logarithmic,0.999,1.0,1.0 +4102a36c,pbm_dimin,0.005,0.99,dimin,quadratic_logarithmic,0.9,1.0,0.5 +97411649,pbm_dimin,0.0005,1.0,dimin,quadratic_logarithmic,0.9,1.0,0.5 +6b7b3146,pbm_dimin,0.0001,0.99,dimin,quadratic_logarithmic,0.9,1.0,0.5 +f71fe8c5,pbm_dimin,0.0001,0.999,dimin,quadratic_logarithmic,0.9,1.0,0.1 +c287c41e,pbm_dimin,0.001,0.99,dimin,quadratic_logarithmic,0.5,1.0,0.5 +f9904ae0,pbm_dimin,0.001,0.999,dimin,quadratic_logarithmic,0.5,1.0,0.5 +d94561d1,pbm_dimin,0.0001,0.99,dimin,quadratic_logarithmic,0.9,1.0,1.0 +fbfaaa3d,pbm_dimin,0.0005,0.99,dimin,quadratic_logarithmic,0.5,1.0,1.0 +293a2367,pbm_dimin,0.0001,0.99,dimin,quadratic_logarithmic,0.5,1.0,0.1 +d762c188,pbm_dimin,0.0005,0.99,dimin,quadratic_logarithmic,0.999,1.0,0 +adde6ef3,pbm_dimin,0.001,0.9,dimin,quadratic_logarithmic,0.5,1.0,2.0 +b1c91296,pbm_dimin,0.005,0.99,dimin,quadratic_logarithmic,0.5,1.0,2.0 +233c3714,pbm_dimin,0.0001,1.0,dimin,quadratic_logarithmic,0.5,1.0,0 +125c64d0,pbm_dimin,0.0001,0.99,dimin,quadratic_logarithmic,0.99,1.0,0.5 +9832f4b8,pbm_dimin,0.0005,0.9,dimin,quadratic_logarithmic,0.99,1.0,0 +c4aab02f,pbm_dimin,0.005,0.99,dimin,quadratic_logarithmic,0.999,1.0,1.0 +fc8e2656,pbm_dimin,0.0001,0.9,dimin,quadratic_logarithmic,0.999,1.0,0.1 +8fe79504,pbm_dimin,0.0001,0.999,dimin,quadratic_logarithmic,0.999,1.0,0.1 +fb7e4884,pbm_dimin,0.0005,0.999,dimin,quadratic_logarithmic,0.5,1.0,0 +6caccd49,pbm_dimin,0.001,0.999,dimin,quadratic_logarithmic,0.999,1.0,0.5 +b56d7545,pbm_dimin,0.005,0.9,dimin,quadratic_logarithmic,0.9,1.0,2.0 +95592499,pbm_dimin,0.001,0.9,dimin,quadratic_logarithmic,0.5,1.0,1.0 +ac096871,pbm_dimin,0.0001,0.99,dimin,quadratic_logarithmic,0.99,1.0,0.1 +218116c0,pbm_dimin,0.0001,0.99,dimin,quadratic_logarithmic,0.9,1.0,0.1 +b39b977c,pbm_dimin,0.0005,0.99,dimin,quadratic_logarithmic,0.99,1.0,1.0 +aaaa8cff,pbm_dimin,0.005,0.9,dimin,quadratic_logarithmic,0.999,1.0,0.1 +a11b51dc,pbm_dimin,0.001,0.999,dimin,quadratic_logarithmic,0.5,1.0,0.1 +de658dc6,pbm_dimin,0.0001,1.0,dimin,quadratic_logarithmic,0.99,1.0,0 +e9961a84,pbm_dimin,0.005,0.99,dimin,quadratic_logarithmic,0.9,1.0,0.1 +3c172fdd,pbm_dimin,0.0005,1.0,dimin,quadratic_logarithmic,0.99,1.0,0.5 +aa9f53dc,pbm_dimin,0.005,1.0,dimin,quadratic_logarithmic,0.9,1.0,2.0 +55dcac6e,pbm_dimin,0.0005,0.9,dimin,quadratic_logarithmic,0.5,1.0,2.0 +c648c195,pbm_dimin,0.0005,0.999,dimin,quadratic_logarithmic,0.99,1.0,0.1 +01e79092,pbm_dimin,0.0001,0.999,dimin,quadratic_logarithmic,0.999,1.0,1.0 +0ee5e8fa,pbm_dimin,0.0001,0.9,dimin,quadratic_logarithmic,0.9,1.0,0 +6a9b7d53,pbm_dimin,0.0005,0.9,dimin,quadratic_logarithmic,0.99,1.0,0.1 +904eea7e,pbm_dimin,0.0005,0.999,dimin,quadratic_logarithmic,0.999,1.0,0 +2667cee4,pbm_dimin,0.005,0.999,dimin,quadratic_logarithmic,0.5,1.0,1.0 +10afc92f,pbm_dimin,0.001,0.99,dimin,quadratic_logarithmic,0.99,1.0,0 +519fc23b,pbm_dimin,0.0001,0.999,dimin,quadratic_logarithmic,0.99,1.0,0.1 +111302e0,pbm_dimin,0.0001,0.99,dimin,quadratic_logarithmic,0.5,1.0,0.5 +d3828877,pbm_dimin,0.001,1.0,dimin,quadratic_logarithmic,0.99,1.0,0 +402c5e53,pbm_dimin,0.0005,0.99,dimin,quadratic_logarithmic,0.9,1.0,0.1 +ba3d948c,pbm_dimin,0.001,1.0,dimin,quadratic_logarithmic,0.5,1.0,2.0 +a69d4905,pbm_dimin,0.0001,0.999,dimin,quadratic_logarithmic,0.9,1.0,1.0 +8f71c416,pbm_dimin,0.005,0.99,dimin,quadratic_logarithmic,0.9,1.0,2.0 +fc786671,pbm_dimin,0.005,0.9,dimin,quadratic_logarithmic,0.999,1.0,0 +eebcc229,pbm_dimin,0.0005,1.0,dimin,quadratic_logarithmic,0.999,1.0,2.0 +4293165c,pbm_dimin,0.0001,0.9,dimin,quadratic_logarithmic,0.999,1.0,1.0 +ec75533e,pbm_dimin,0.001,0.9,dimin,quadratic_logarithmic,0.999,1.0,0 +c6a90b0a,pbm_dimin,0.0001,1.0,dimin,quadratic_logarithmic,0.9,1.0,1.0 +6a7c0e38,pbm_dimin,0.005,0.999,dimin,quadratic_logarithmic,0.5,1.0,0.5 +cab86a42,pbm_dimin,0.0001,0.999,dimin,quadratic_logarithmic,0.99,1.0,0.5 +4c575b5a,pbm_dimin,0.005,0.9,dimin,quadratic_logarithmic,0.99,1.0,1.0 +34e2328f,pbm_dimin,0.0005,0.9,dimin,quadratic_logarithmic,0.9,1.0,0 +ccfcb6de,pbm_dimin,0.0001,0.999,dimin,quadratic_logarithmic,0.999,1.0,0.5 +cddec2d9,pbm_dimin,0.001,0.9,dimin,quadratic_logarithmic,0.99,1.0,1.0 +97c6b7f2,pbm_dimin,0.0005,0.999,dimin,quadratic_logarithmic,0.99,1.0,0 +4766af48,pbm_dimin,0.0005,0.999,dimin,quadratic_logarithmic,0.99,1.0,0.5 +54056ced,pbm_dimin,0.0001,1.0,dimin,quadratic_logarithmic,0.99,1.0,0.5 +81470845,pbm_dimin,0.0005,1.0,dimin,quadratic_logarithmic,0.5,1.0,2.0 +336f93bf,pbm_dimin,0.005,0.999,dimin,quadratic_logarithmic,0.999,1.0,2.0 +61ab3ad4,pbm_dimin,0.001,1.0,dimin,quadratic_logarithmic,0.99,1.0,2.0 +82fa31e3,pbm_dimin,0.0005,0.999,dimin,quadratic_logarithmic,0.9,1.0,2.0 +fff6d2c1,pbm_dimin,0.005,0.999,dimin,quadratic_logarithmic,0.999,1.0,1.0 +c385eba3,pbm_dimin,0.0005,0.99,dimin,quadratic_logarithmic,0.9,1.0,1.0 +70430a24,pbm_dimin,0.0005,1.0,dimin,quadratic_logarithmic,0.9,1.0,0.1 +e0d9d104,pbm_dimin,0.0005,0.9,dimin,quadratic_logarithmic,0.999,1.0,0.1 +c7fbf993,pbm_dimin,0.001,0.999,dimin,quadratic_logarithmic,0.9,1.0,0.5 +4d251599,pbm_dimin,0.0005,1.0,dimin,quadratic_logarithmic,0.999,1.0,0.1 +8c35df3a,pbm_dimin,0.0005,0.999,dimin,quadratic_logarithmic,0.999,1.0,2.0 +3bfdb40c,pbm_dimin,0.0001,1.0,dimin,quadratic_logarithmic,0.99,1.0,2.0 +040d0348,pbm_dimin,0.005,0.9,dimin,quadratic_logarithmic,0.99,1.0,0 +b1ed3fed,pbm_dimin,0.0005,1.0,dimin,quadratic_logarithmic,0.999,1.0,0.5 +a99d0745,pbm_dimin,0.001,0.99,dimin,quadratic_logarithmic,0.9,1.0,0.1 +5f575060,pbm_dimin,0.0005,0.99,dimin,quadratic_logarithmic,0.5,1.0,2.0 +6a0fdd18,pbm_dimin,0.001,1.0,dimin,quadratic_logarithmic,0.5,1.0,0 +2ce1450c,pbm_dimin,0.001,1.0,dimin,quadratic_logarithmic,0.5,1.0,1.0 +d6ce6186,pbm_dimin,0.001,1.0,dimin,quadratic_logarithmic,0.999,1.0,2.0 +934b2640,pbm_dimin,0.0001,1.0,dimin,quadratic_logarithmic,0.999,1.0,1.0 +91890c35,pbm_dimin,0.001,0.99,dimin,quadratic_logarithmic,0.99,1.0,0.1 +10dd97f1,pbm_dimin,0.0005,0.99,dimin,quadratic_logarithmic,0.99,1.0,0.1 +ab1d5425,pbm_dimin,0.001,0.99,dimin,quadratic_logarithmic,0.999,1.0,0 +7189f63c,pbm_dimin,0.0005,0.99,dimin,quadratic_logarithmic,0.99,1.0,0 +72b701b2,pbm_dimin,0.001,1.0,dimin,quadratic_logarithmic,0.9,1.0,0 +54ae1ff5,pbm_dimin,0.0001,0.99,dimin,quadratic_logarithmic,0.5,1.0,0 +c119c79d,pbm_dimin,0.001,0.99,dimin,quadratic_logarithmic,0.5,1.0,2.0 +40b6422e,pbm_dimin,0.001,1.0,dimin,quadratic_logarithmic,0.999,1.0,0.1 +e637e178,pbm_dimin,0.005,0.999,dimin,quadratic_logarithmic,0.9,1.0,1.0 +5ae4eb2e,pbm_dimin,0.0001,0.9,dimin,quadratic_logarithmic,0.5,1.0,0.5 +1cc0dfab,pbm_dimin,0.0005,1.0,dimin,quadratic_logarithmic,0.99,1.0,1.0 +6791490c,pbm_dimin,0.001,0.9,dimin,quadratic_logarithmic,0.9,1.0,0.1 +0b3ceeee,pbm_dimin,0.005,1.0,dimin,quadratic_logarithmic,0.999,1.0,0.5 +3d6cd9ca,pbm_dimin,0.005,1.0,dimin,quadratic_logarithmic,0.5,1.0,0.5 +8d2abf9e,pbm_dimin,0.0001,0.9,dimin,quadratic_logarithmic,0.9,1.0,0.1 +29e9ff1a,pbm_dimin,0.0001,1.0,dimin,quadratic_logarithmic,0.9,1.0,2.0 +f288701b,pbm_dimin,0.0001,1.0,dimin,quadratic_logarithmic,0.999,1.0,0.1 +a150853a,pbm_dimin,0.0005,0.9,dimin,quadratic_logarithmic,0.5,1.0,0 +5feff0ea,pbm_dimin,0.005,0.999,dimin,quadratic_logarithmic,0.99,1.0,0.1 +e11bceae,pbm_dimin,0.0001,0.9,dimin,quadratic_logarithmic,0.1,1.0,0.5 +783474f2,pbm_dimin,0.005,0.99,dimin,quadratic_logarithmic,0.999,1.0,0.5 +ca4cd6de,pbm_dimin,0.001,1.0,dimin,quadratic_logarithmic,0.999,1.0,1.0 +37806fa4,pbm_dimin,0.0005,0.9,dimin,quadratic_logarithmic,0.999,1.0,0 +e99657ca,pbm_dimin,0.001,0.9,dimin,quadratic_logarithmic,0.9,1.0,0.5 +22544125,pbm_dimin,0.0005,0.9,dimin,quadratic_logarithmic,0.5,1.0,0.1 +e3e59acf,pbm_dimin,0.001,0.999,dimin,quadratic_logarithmic,0.9,1.0,0.1 +376fc38c,pbm_dimin,0.0005,0.9,dimin,quadratic_logarithmic,0.99,1.0,1.0 +3464a2eb,pbm_dimin,0.0005,0.999,dimin,quadratic_logarithmic,0.5,1.0,2.0 +8e4e9cbc,pbm_dimin,0.0001,1.0,dimin,quadratic_logarithmic,0.999,1.0,0 +cd8c07fc,pbm_dimin,0.001,0.99,dimin,quadratic_logarithmic,0.9,1.0,1.0 +f6d18eb0,pbm_dimin,0.005,1.0,dimin,quadratic_logarithmic,0.9,1.0,0.1 +1c7b672e,pbm_dimin,0.0001,1.0,dimin,quadratic_logarithmic,0.999,1.0,2.0 +f82e154d,pbm_dimin,0.005,1.0,dimin,quadratic_logarithmic,0.99,1.0,0.5 +8dd962c4,pbm_dimin,0.0005,0.99,dimin,quadratic_logarithmic,0.9,1.0,0.5 +88ad06ac,pbm_dimin,0.0001,0.9,dimin,quadratic_logarithmic,0.5,1.0,0 +46646db7,pbm_dimin,0.005,0.999,dimin,quadratic_logarithmic,0.5,1.0,0 +18d3bbe0,pbm_dimin,0.001,0.9,dimin,quadratic_logarithmic,0.9,1.0,2.0 +503836ca,pbm_dimin,0.001,1.0,dimin,quadratic_logarithmic,0.999,1.0,0 +8ce79fa5,pbm_dimin,0.001,0.9,dimin,quadratic_logarithmic,0.99,1.0,0 +f677c700,pbm_dimin,0.0001,1.0,dimin,quadratic_logarithmic,0.9,1.0,0 +04f490ea,pbm_dimin,0.0005,0.99,dimin,quadratic_logarithmic,0.99,1.0,2.0 +a4e16e00,pbm_dimin,0.0005,1.0,dimin,quadratic_logarithmic,0.5,1.0,0 +8f7319ee,pbm_dimin,0.001,0.999,dimin,quadratic_logarithmic,0.9,1.0,0 +109ff18b,pbm_dimin,0.005,0.99,dimin,quadratic_logarithmic,0.5,1.0,1.0 +17140b87,pbm_dimin,0.0001,0.99,dimin,quadratic_logarithmic,0.9,1.0,2.0 +66332b52,pbm_dimin,0.0005,1.0,dimin,quadratic_logarithmic,0.5,1.0,0.5 +e5ec895d,pbm_dimin,0.001,0.99,dimin,quadratic_logarithmic,0.9,1.0,2.0 +4f6d2b15,pbm_dimin,0.005,0.999,dimin,quadratic_logarithmic,0.99,1.0,0.5 +2eec358e,pbm_dimin,0.005,0.9,dimin,quadratic_logarithmic,0.99,1.0,2.0 +db4cbc82,pbm_dimin,0.005,0.999,dimin,quadratic_logarithmic,0.999,1.0,0.5 +c549303e,pbm_dimin,0.001,0.99,dimin,quadratic_logarithmic,0.5,1.0,1.0 +0d9b7e97,pbm_dimin,0.005,1.0,dimin,quadratic_logarithmic,0.9,1.0,1.0 +3e1b21ce,pbm_dimin,0.0005,0.99,dimin,quadratic_logarithmic,0.5,1.0,0 +1cbd61d3,pbm_dimin,0.001,0.99,dimin,quadratic_logarithmic,0.5,1.0,0 +2f43d516,pbm_dimin,0.001,1.0,dimin,quadratic_logarithmic,0.9,1.0,0.5 +ac6f09c7,pbm_dimin,0.001,1.0,dimin,quadratic_logarithmic,0.99,1.0,0.5 +9dfe0a70,pbm_dimin,0.0005,0.999,dimin,quadratic_logarithmic,0.999,1.0,0.5 +cf459dfc,pbm_dimin,0.0001,0.99,dimin,quadratic_logarithmic,0.999,1.0,0 +aa016d20,pbm_dimin,0.0001,0.999,dimin,quadratic_logarithmic,0.5,1.0,0.5 +428a9fb1,pbm_dimin,0.0001,0.999,dimin,quadratic_logarithmic,0.5,1.0,0.1 +be188337,pbm_dimin,0.0005,0.999,dimin,quadratic_logarithmic,0.9,1.0,0.5 +c70b6d71,pbm_dimin,0.005,1.0,dimin,quadratic_logarithmic,0.5,1.0,0 +e356573a,pbm_dimin,0.005,0.999,dimin,quadratic_logarithmic,0.99,1.0,2.0 +8afc6c03,pbm_dimin,0.0001,0.9,dimin,quadratic_logarithmic,0.5,1.0,1.0 +b2676c1b,pbm_dimin,0.005,0.9,dimin,quadratic_logarithmic,0.5,1.0,0 +72a4be4d,pbm_dimin,0.0001,0.9,dimin,quadratic_logarithmic,0.5,1.0,2.0 +281b4a74,pbm_dimin,0.0001,1.0,dimin,quadratic_logarithmic,0.999,1.0,0.5 +aee2ac81,pbm_dimin,0.0001,0.999,dimin,quadratic_logarithmic,0.999,1.0,0 +222d9d3c,pbm_dimin,0.0005,0.999,dimin,quadratic_logarithmic,0.99,1.0,2.0 +bf08f4ba,pbm_dimin,0.001,1.0,dimin,quadratic_logarithmic,0.99,1.0,1.0 +ef1fbb14,pbm_dimin,0.005,0.999,dimin,quadratic_logarithmic,0.9,1.0,0.5 +cdfa58f2,pbm_dimin,0.001,1.0,dimin,quadratic_logarithmic,0.9,1.0,0.1 +3bcc9755,pbm_dimin,0.0005,0.9,dimin,quadratic_logarithmic,0.5,1.0,1.0 +8358f760,pbm_dimin,0.005,0.99,dimin,quadratic_logarithmic,0.9,1.0,0 +6f9df33e,pbm_dimin,0.0001,0.99,dimin,quadratic_logarithmic,0.5,1.0,2.0 +bb6c79de,pbm_dimin,0.005,0.9,dimin,quadratic_logarithmic,0.5,1.0,2.0 +39f1edd7,pbm_dimin,0.005,0.999,dimin,quadratic_logarithmic,0.99,1.0,0 +ee2dc57f,pbm_dimin,0.0005,1.0,dimin,quadratic_logarithmic,0.5,1.0,1.0 +a3145abb,pbm_dimin,0.0001,0.99,dimin,quadratic_logarithmic,0.99,1.0,0 +5ff5fdec,pbm_dimin,0.005,0.9,dimin,quadratic_logarithmic,0.5,1.0,0.1 +8349a90a,pbm_dimin,0.001,0.99,dimin,quadratic_logarithmic,0.99,1.0,0.5 +3007200c,pbm_dimin,0.001,0.99,dimin,quadratic_logarithmic,0.5,1.0,0.1 +c874ba03,pbm_dimin,0.0001,0.9,dimin,quadratic_logarithmic,0.9,1.0,2.0 diff --git a/benchmark/plotting.py b/benchmark/plotting.py index 86a6c2a..836d8a9 100644 --- a/benchmark/plotting.py +++ b/benchmark/plotting.py @@ -2,34 +2,18 @@ import numpy as np -def plot_losses_and_constraints_stochastic( - train_losses_list, - train_losses_std_list, - train_constraints_list, - train_constraints_std_list, - constraint_thresholds, - test_losses_list=None, - test_losses_std_list=None, - test_constraints_list=None, - test_constraints_std_list=None, - titles=None, - eval_points=1, - std_multiplier=2, - log_constraints=False, - mode="train", # "train" or "train_test" - times=[], # second per epoch - plot_time_instead_epochs=False, - save_path=None, - separate_constraints = False, - abs_constraints=False -): - """ - mode: - "train" -> only training plots - "train_test" -> training + test side by side + +def plot_accuracy_per_epoch(algorithms_data, eval_points=1, algorithms_data_std=None, std_multiplier=2): """ + Plots average per-class accuracy per epoch for each algorithm. -# # --- Color palette (Tableau 10) --- + Parameters: + - algorithms_data: Dict where keys are algorithm names and values are lists of average per-class accuracies per epoch. + - eval_points: Evaluate points for markers (default: 1). + - algorithms_data_std: Optional. Dict with same structure as algorithms_data for standard deviations. + - std_multiplier: Multiplier for standard deviation bands (default: 2). + """ + # --- Color palette (Tableau 10) --- colors = [ "#4E79A7", "#F28E2B", @@ -45,175 +29,274 @@ def plot_losses_and_constraints_stochastic( marker_styles = ["o", "s", "D", "^", "v", "<", ">", "P", "X", "*"] - num_algos = len(train_losses_list) - if titles is None: - titles = [f"Algorithm {i + 1}" for i in range(num_algos)] - - constraint_thresholds = np.atleast_1d(constraint_thresholds) + num_algos = len(algorithms_data) + algo_names = list(algorithms_data.keys()) # --- Layout --- - ncols = 1 if mode == "train" else 2 - nrows = 1 + train_constraints_list[0].shape[0] if separate_constraints else 2 + fig, axes = plt.subplots(num_algos, 1, figsize=(9, 4*num_algos), sharex=True) - join_bottom_plot = not test_constraints_list and mode == "train_test" + if num_algos == 1: + axes = [axes] # Ensure axes is iterable - if join_bottom_plot: - fig = plt.figure(figsize=(9 * ncols, 10)) - axes = [] + for i, (ax, algo_name) in enumerate(zip(axes, algo_names)): + y = np.array(algorithms_data[algo_name]) + K = len(y) + x = np.arange(1, K+1) - ax1 = fig.add_subplot(2, 2, 1) - ax2 = fig.add_subplot(2, 2, 2, sharey = ax1) - ax3 = fig.add_subplot(2, 1, 2) + color = colors[i % len(colors)] + ax.plot(x, y, lw=2.2, color=color, label=algo_name) - axes = [ax1, ax2, ax3] - else: - fig, axes = plt.subplots(2, ncols, figsize=(9 * ncols, 10), sharex="col", sharey="row") - - if ncols == 1: - axes = np.array([[axes[0]], [axes[1]]]) - - # ====================================================== - # Helper plotting functions - # ====================================================== + # Add error band if standard deviation data is provided + if algorithms_data_std is not None: + y_std = np.array(algorithms_data_std[algo_name]) + ax.fill_between( + x, + y - std_multiplier * y_std, + y + std_multiplier * y_std, + color=color, + alpha=0.15, + ) - def plot_loss(ax, losses_list, losses_std_list, title_suffix): - for j, (loss, loss_std) in enumerate(zip(losses_list, losses_std_list)): - x = np.arange(len(loss)) - color = colors[j % len(colors)] - upper = loss + std_multiplier * loss_std - lower = loss - std_multiplier * loss_std - - if plot_time_instead_epochs: - x *= round(times[j]) - - # ax.plot(x, loss, lw=2.2, color=color, label=titles[j] + f"; TPE: {minutes}m:{seconds}s") - ax.plot(x, loss, lw=2.2, color=color, label=titles[j]) - ax.fill_between(x, lower, upper, color=color, alpha=0.15) - - if eval_points is not None: - idx = ( - np.arange(0, len(loss), eval_points) - if isinstance(eval_points, int) - else np.array(eval_points) - ) - idx = idx[idx < len(loss)] - ax.plot( - x[idx], - loss[idx], - marker_styles[j % len(marker_styles)], - color=color, - markersize=6, - alpha=0.8, - ) + if eval_points is not None: + idx = ( + np.arange(0, len(y), eval_points) + if isinstance(eval_points, int) + else np.array(eval_points) + ) + idx = idx[idx < len(y)] + ax.plot( + x[idx], + y[idx], + marker_styles[i % len(marker_styles)], + color=color, + markersize=6, + alpha=0.8, + ) - ax.set_title(f"Loss ({title_suffix})") - ax.set_ylabel("Mean Loss") + ax.set_title(algo_name) + ax.set_xlabel("Epoch") + ax.set_ylabel("Average Per-Class Accuracy") ax.grid(True, linestyle="--", alpha=0.35) ax.legend(fontsize=9) - - - def plot_constraints(ax, constraints_list, constraints_std_list, title_suffix): - for j, (constraints, constraints_std) in enumerate( - zip(constraints_list, constraints_std_list) - ): - color = colors[j % len(colors)] - constraints = np.asarray(constraints) - constraints_std = np.asarray(constraints_std) - - x = np.arange(constraints.shape[1]) - - print(np.array(constraints).shape) - c_max = np.max(constraints, axis=0) - c_max_std = np.std(c_max) - - c_lower = c_max - std_multiplier * c_max_std - c_upper = c_max + std_multiplier * c_max_std - ax.fill_between(x, c_lower, c_upper, color=color, alpha=0.1) - - if plot_time_instead_epochs: - x *= round(times[j]) - - label = titles[j] - ax.plot(x, c_max, lw=1.8, color=color, alpha=0.3, label=label) - - if eval_points is not None: - idx = ( - np.arange(0, len(c_max), eval_points) - if isinstance(eval_points, int) - else np.array(eval_points) - ) - idx = idx[idx < len(c_max)] - ax.plot( - x[idx], - c_max[idx], - marker_styles[j % len(marker_styles)], - color=color, - markersize=5, - alpha=0.3, - ) - - for th in constraint_thresholds: - y = np.log(th) if log_constraints else th - ax.axhline(y, color="red", linestyle="--", lw=1.4, label="Threshold") + plt.tight_layout() + plt.show() - ax.set_title(f"Constraint ({title_suffix})") - ax.set_ylabel("Log Constraint" if log_constraints else "Constraint") - if plot_time_instead_epochs: - ax.set_xlabel("Time (m)") - else: - ax.set_xlabel("Epoch") - ax.grid(True, linestyle="--", alpha=0.35) - ax.legend(fontsize=9) - # ====================================================== - # TRAIN PLOTS - # ====================================================== - plot_loss( - axes[0] if join_bottom_plot else axes[0, 0], - train_losses_list, - train_losses_std_list, - "Train" - ) - plot_constraints( - axes[2] if join_bottom_plot else axes[1, 0], - train_constraints_list, - train_constraints_std_list, - "Train", - ) +# def plot_losses_and_constraints_stochastic( +# train_losses_list, +# train_losses_std_list, +# train_constraints_list, +# train_constraints_std_list, +# constraint_thresholds, +# test_losses_list=None, +# test_losses_std_list=None, +# test_constraints_list=None, +# test_constraints_std_list=None, +# titles=None, +# eval_points=1, +# std_multiplier=2, +# log_constraints=False, +# mode="train", # "train" or "train_test" +# times=[], # second per epoch +# plot_time_instead_epochs=False, +# save_path=None, +# separate_constraints = False, +# abs_constraints=False +# ): +# """ +# mode: +# "train" -> only training plots +# "train_test" -> training + test side by side +# """ - # ====================================================== - # TEST PLOTS - # ====================================================== +# # # --- Color palette (Tableau 10) --- +# colors = [ +# "#4E79A7", +# # "#F28E2B", +# "#E15759", +# "#76B7B2", +# "#59A14F", +# "#EDC948", +# "#B07AA1", +# "#FF9DA7", +# "#9C755F", +# "#BAB0AB", +# ] - if mode == "train_test": - plot_loss( - axes[1] if join_bottom_plot else axes[0, 1], - test_losses_list, - test_losses_std_list, - "Test" - ) - if join_bottom_plot: - print(axes[0].get_yticks()) - print(axes[1].get_yticks()) - if test_constraints_list: - plot_constraints( - axes[1, 1], - test_constraints_list, - test_constraints_std_list, - "Test", - ) +# marker_styles = ["o", "s", "D", "^", "v", "<", ">", "P", "X", "*"] - plt.tight_layout() - if save_path: - plt.savefig(save_path) +# num_algos = len(train_losses_list) +# if titles is None: +# titles = [f"Algorithm {i + 1}" for i in range(num_algos)] +# constraint_thresholds = np.atleast_1d(constraint_thresholds) +# # --- Layout --- +# ncols = 1 if mode == "train" else 2 +# nrows = 1 + train_constraints_list[0].shape[0] if separate_constraints else 2 -import numpy as np -import matplotlib.pyplot as plt +# join_bottom_plot = not test_constraints_list and mode == "train_test" + +# if join_bottom_plot: +# fig = plt.figure(figsize=(9 * ncols, 10)) +# axes = [] + +# ax1 = fig.add_subplot(2, 2, 1) +# ax2 = fig.add_subplot(2, 2, 2, sharey = ax1) +# ax3 = fig.add_subplot(2, 1, 2) + +# axes = [ax1, ax2, ax3] +# else: +# fig, axes = plt.subplots(2, ncols, figsize=(9 * ncols, 10), sharex="col", sharey="row") + +# if ncols == 1: +# axes = np.array([[axes[0]], [axes[1]]]) + +# # ====================================================== +# # Helper plotting functions +# # ====================================================== + +# def plot_loss(ax, losses_list, losses_std_list, title_suffix): +# for j, (loss, loss_std) in enumerate(zip(losses_list, losses_std_list)): +# x = np.arange(len(loss)) +# color = colors[j % len(colors)] +# upper = loss + std_multiplier * loss_std +# lower = loss - std_multiplier * loss_std + +# if plot_time_instead_epochs: +# x *= round(times[j]) + +# # ax.plot(x, loss, lw=2.2, color=color, label=titles[j] + f"; TPE: {minutes}m:{seconds}s") +# ax.plot(x, loss, lw=2.2, color=color, label=titles[j]) +# ax.fill_between(x, lower, upper, color=color, alpha=0.15) + +# if eval_points is not None: +# idx = ( +# np.arange(0, len(loss), eval_points) +# if isinstance(eval_points, int) +# else np.array(eval_points) +# ) +# idx = idx[idx < len(loss)] +# ax.plot( +# x[idx], +# loss[idx], +# marker_styles[j % len(marker_styles)], +# color=color, +# markersize=6, +# alpha=0.8, +# ) + +# ax.set_title(f"Loss ({title_suffix})") +# ax.set_ylabel("Mean Loss") +# ax.grid(True, linestyle="--", alpha=0.35) +# ax.legend(fontsize=9) + + +# def plot_constraints(ax, constraints_list, constraints_std_list, title_suffix): +# for j, (constraints, constraints_std) in enumerate( +# zip(constraints_list, constraints_std_list) +# ): +# color = colors[j % len(colors)] +# constraints = np.asarray(constraints) +# constraints_std = np.asarray(constraints_std) + +# x = np.arange(constraints.shape[1]) + +# print(np.array(constraints).shape) +# c_max = np.max(constraints, axis=0) +# c_max_std = np.std(c_max) + +# c_lower = c_max - std_multiplier * c_max_std +# c_upper = c_max + std_multiplier * c_max_std +# ax.fill_between(x, c_lower, c_upper, color=color, alpha=0.1) + +# if plot_time_instead_epochs: +# x *= round(times[j]) + +# label = titles[j] +# ax.plot(x, c_max, lw=1.8, color=color, alpha=0.3, label=label) + +# if eval_points is not None: +# idx = ( +# np.arange(0, len(c_max), eval_points) +# if isinstance(eval_points, int) +# else np.array(eval_points) +# ) +# idx = idx[idx < len(c_max)] +# ax.plot( +# x[idx], +# c_max[idx], +# marker_styles[j % len(marker_styles)], +# color=color, +# markersize=5, +# alpha=0.3, +# ) + + +# for th in constraint_thresholds: +# y = np.log(th) if log_constraints else th +# ax.axhline(y, color="red", linestyle="--", lw=1.4, label="Threshold") + +# ax.set_title(f"Constraint ({title_suffix})") +# ax.set_ylabel("Log Constraint" if log_constraints else "Constraint") + +# if plot_time_instead_epochs: +# ax.set_xlabel("Time (m)") +# else: +# ax.set_xlabel("Epoch") +# ax.grid(True, linestyle="--", alpha=0.35) +# ax.legend(fontsize=9) + +# # ====================================================== +# # TRAIN PLOTS +# # ====================================================== + +# plot_loss( +# axes[0] if join_bottom_plot else axes[0, 0], +# train_losses_list, +# train_losses_std_list, +# "Train" +# ) +# plot_constraints( +# axes[2] if join_bottom_plot else axes[1, 0], +# train_constraints_list, +# train_constraints_std_list, +# "Train", +# ) + +# # ====================================================== +# # TEST PLOTS +# # ====================================================== + +# if mode == "train_test": +# plot_loss( +# axes[1] if join_bottom_plot else axes[0, 1], +# test_losses_list, +# test_losses_std_list, +# "Test" +# ) +# if join_bottom_plot: +# print(axes[0].get_yticks()) +# print(axes[1].get_yticks()) +# if test_constraints_list: +# plot_constraints( +# axes[1, 1], +# test_constraints_list, +# test_constraints_std_list, +# "Test", +# ) + +# plt.tight_layout() +# if save_path: +# plt.savefig(save_path) + + + +# import numpy as np +# import matplotlib.pyplot as plt def plot_losses_and_constraints_stochastic( @@ -258,7 +341,19 @@ def plot_losses_and_constraints_stochastic( if titles is None: titles = [f"Algorithm {i+1}" for i in range(num_algos)] - colors = plt.rcParams["axes.prop_cycle"].by_key()["color"] + # colors = plt.rcParams["axes.prop_cycle"].by_key()["color"] + colors = [ + "#4E79A7", + "#F28E2B", + "#E15759", + "#76B7B2", + "#59A14F", + "#EDC948", + "#B07AA1", + "#FF9DA7", + "#9C755F", + "#BAB0AB", + ] markers = ["o", "s", "D", "^", "v", "<", ">", "P", "X", "*"] constraint_thresholds = np.atleast_1d(constraint_thresholds) diff --git a/benchmark/run_benchmark.py b/benchmark/run_benchmark.py index c67b892..662c9cc 100644 --- a/benchmark/run_benchmark.py +++ b/benchmark/run_benchmark.py @@ -7,7 +7,7 @@ import hydra from omegaconf import DictConfig, OmegaConf -from benchmark_utils import create_model +from utils import create_model, create_conv_model, create_resnet from _data_sources import load_data_FT_prod, load_data_FT_vec, load_data_FT, load_data_DUTCH, load_data_norm, load_data_cifar10, load_data_cifar100 # from benchmark_utils import * from plotting import plot_losses_and_constraints_stochastic @@ -18,7 +18,12 @@ def run_benchmark(data_cfg, task, n_runs, n_epochs, constraint_cfg, pbm_params, seed = 0 torch.manual_seed(seed) dataset = data_cfg['name'] - device = 'cpu' + + device = ('cuda:0' if torch.cuda.is_available() else 'cpu') + print(torch.cuda.device_count()) + print(torch.cuda.get_device_name(0) if torch.cuda.is_available() else 'No GPU') + torch.set_default_device(device) + result_dir = "results/" + dataset + '_' + task os.makedirs(result_dir, exist_ok=True) @@ -50,18 +55,27 @@ def run_benchmark(data_cfg, task, n_runs, n_epochs, constraint_cfg, pbm_params, batch_size = cfg.batch_size if task == 'cifar10': criterion = torch.nn.CrossEntropyLoss(reduction = 'none') - dataloader_train, dataloader_val, classes, class_ind = load_data_cifar10(batch_size) + dataloader_train, dataloader_val, dataloader_test, classes, class_ind = load_data_cifar10(device=device) features_train, sens_train, labels_train = next(iter(dataloader_train)) + model_fn = create_conv_model + dataloader_val = dataloader_test + model_kwargs = {} elif task == 'cifar100': criterion = torch.nn.CrossEntropyLoss(reduction = 'none') - dataloader_train, dataloader_val, classes, class_ind = load_data_cifar100(batch_size) + dataloader_train, dataloader_val, dataloader_test, classes, class_ind = load_data_cifar100(device=device) features_train, sens_train, labels_train = next(iter(dataloader_train)) + model_fn = create_resnet + dataloader_val = dataloader_test + model_kwargs = {} else: criterion = torch.nn.functional.binary_cross_entropy_with_logits (dataloader_train, dataloader_val, dataloader_test), (features_train, sens_train, labels_train), (features_val, sens_val, labels_val), (features_test, sens_test, labels_test) = data_source(batch_size) - features_val = features_test - sens_val = sens_test - labels_val = labels_test + if task != 'equalized_odds_vec': + features_val = features_test + sens_val = sens_test + labels_val = labels_test + model_fn = create_model + model_kwargs = {'input_shape': features_train.shape[-1], **model_kwargs} data_val = (features_val, sens_val, labels_val) if task not in ['cifar10', 'cifar100'] else dataloader_val @@ -131,9 +145,10 @@ def run_benchmark(data_cfg, task, n_runs, n_epochs, constraint_cfg, pbm_params, use_slack=False, fuse_loss_constraint=fuse_loss_constraint, reg_penalty=p, - model_gen = create_model, - model_kwargs = {'input_shape': features_train.shape[-1], **model_kwargs}, - criterion=criterion + model_gen = model_fn, + model_kwargs = model_kwargs, + criterion=criterion, + device=device ) models.append(model) @@ -180,9 +195,10 @@ def run_benchmark(data_cfg, task, n_runs, n_epochs, constraint_cfg, pbm_params, constraints_to_eq=False, use_slack=False, fuse_loss_constraint=fuse_loss_constraint, - model_gen = create_model, - model_kwargs = {'input_shape': features_train.shape[-1], **model_kwargs}, - criterion=criterion + model_gen = model_fn, + model_kwargs = model_kwargs, + criterion=criterion, + device=device ) models.append(model) pbm_history_train.append(h_train) @@ -225,9 +241,10 @@ def run_benchmark(data_cfg, task, n_runs, n_epochs, constraint_cfg, pbm_params, constraints_to_eq=True, use_slack=alm_use_slack, fuse_loss_constraint=fuse_loss_constraint, - model_gen = create_model, - model_kwargs = {'input_shape': features_train.shape[-1], **model_kwargs}, - criterion=criterion + model_gen = model_fn, + model_kwargs = model_kwargs, + criterion=criterion, + device=device ) models.append(model) alm_history_train.append(h_train) @@ -272,9 +289,10 @@ def run_benchmark(data_cfg, task, n_runs, n_epochs, constraint_cfg, pbm_params, constraints_to_eq=False, use_slack=False, fuse_loss_constraint=fuse_loss_constraint, - model_gen = create_model, - model_kwargs = {'input_shape': features_train.shape[-1], **model_kwargs}, - criterion=criterion + model_gen = model_fn, + model_kwargs = model_kwargs, + criterion=criterion, + device=device ) models.append(model) ssg_history_train.append(h_train) @@ -300,6 +318,10 @@ def run_benchmark(data_cfg, task, n_runs, n_epochs, constraint_cfg, pbm_params, ###### PLOT ###### + algorithms_run = ['adam', 'ssg', 'alm', 'pbm'] + # algorithms_run = ['adam', 'alm_slack', 'alm_max', 'pbm_dimin'] + # algorithms_run = ['adam', 'alm_max', 'pbm_dimin'] + if len(algorithms_run) == 0: print('\nNo algorithms were run. Skipping plotting.\n') return @@ -307,8 +329,18 @@ def run_benchmark(data_cfg, task, n_runs, n_epochs, constraint_cfg, pbm_params, def read_prepare_data(path: str): train = pd.read_csv(path) c_cols = [c for c in train.columns if c.startswith('c_')] + if "params" in train.columns: + train.drop("params", inplace=True, axis="columns") + if "acc" in train.columns: + train.drop("acc", inplace=True, axis="columns") mean = train.groupby(by='epoch').mean() - std = train.groupby(by=['Unnamed: 0', 'epoch']).mean().groupby(by='epoch').std() + # breakpoint() + # if "Unnamed: 0" not in train.columns: + # train.reset_index(inplace=True) + # train['Unnamed: 0'] = train.index + # std = train.groupby(by=['Unnamed: 0', 'epoch']).mean().groupby(by='epoch').std() + + std = train.groupby(by='epoch').std() loss_mean = mean['loss'].to_numpy() loss_std = std['loss'].to_numpy() cs_mean = mean[c_cols].to_numpy() @@ -327,20 +359,23 @@ def read_prepare_data(path: str): cs_stds_val = [] alg_names = [] alg_display_names = { - 'adam': 'Adam', + 'adam': 'Unconstrained Adam', 'pbm': 'SPBM', - 'alm': 'ALM', - 'ssg': 'SSG' + 'pbm_dimin': 'SPBM', + 'alm': 'SSL-ALM', + 'alm_max': 'SSL-ALM', + 'ssg': 'SSW' } + for alg in algorithms_run: + train_file = f'{result_dir}/runs_{alg}_train.csv' val_file = f'{result_dir}/runs_{alg}_val.csv' if os.path.exists(train_file) and os.path.exists(val_file): _, loss_mean_train, loss_std_train, cs_mean_train, cs_std_train = read_prepare_data(train_file) _, loss_mean_val, loss_std_val, cs_mean_val, cs_std_val = read_prepare_data(val_file) - loss_means_train.append(loss_mean_train) loss_stds_train.append(loss_std_train) cs_means_train.append(cs_mean_train.T) @@ -349,8 +384,9 @@ def read_prepare_data(path: str): loss_stds_val.append(loss_std_val) cs_means_val.append(cs_mean_val.T) cs_stds_val.append(cs_std_val.T) + # alg_names.append(alg) alg_names.append(alg_display_names[alg]) - + print(alg_names) if len(alg_names) > 0: plot_losses_and_constraints_stochastic( loss_means_train, @@ -366,6 +402,7 @@ def read_prepare_data(path: str): mode='train_test' if not task == 'weight_norm' else 'train', # plot_max_constraint=m > 5, save_path=result_dir + '/plot.png', + std_multiplier=1. # combine_algos = True ) diff --git a/benchmark/run_gridsearch.py b/benchmark/run_gridsearch.py index 6b11dd4..229bd7b 100644 --- a/benchmark/run_gridsearch.py +++ b/benchmark/run_gridsearch.py @@ -1,6 +1,4 @@ -# from benchmark_utils import * from utils import * -import benchmark_utils from utils import create_model, create_conv_model from humancompatible.train.fairness.utils import BalancedBatchSampler from itertools import product @@ -51,13 +49,13 @@ def extract_best_params(runs, param_grid, val_c_tolerance, filter='upper'): return min_feasible_val_loss_params, min_feasible_val_loss_idx, min_val_loss, min_val_c, runs -def main(data_cfg, task_cfg, n_epochs, constraint_cfg, device): +def main(data_cfg, task_cfg, n_epochs, constraint_cfg, device, seed): seed = 0 torch.manual_seed(seed) dataset = data_cfg['name'] task = task_cfg.task - result_dir = dataset + '_' + task + result_dir = dataset + '_' + task + seed os.makedirs(result_dir, exist_ok=True) @@ -81,13 +79,13 @@ def main(data_cfg, task_cfg, n_epochs, constraint_cfg, device): if task == 'cifar10': dataloader_train, dataloader_val, classes, class_ind = load_data_cifar10(device=device) features_train, sens_train, labels_train = next(iter(dataloader_train)) - create_model = create_conv_model + create_model_fn = create_conv_model model_kwargs = {} criterion = torch.nn.CrossEntropyLoss(reduction='none') elif task == 'cifar100': dataloader_train, dataloader_val, classes, class_ind = load_data_cifar100(device=device) features_train, sens_train, labels_train = next(iter(dataloader_train)) - create_model = create_conv_model + create_model_fn = create_conv_model model_kwargs = {} criterion = torch.nn.CrossEntropyLoss(reduction='none') else: @@ -95,7 +93,7 @@ def main(data_cfg, task_cfg, n_epochs, constraint_cfg, device): features_val = features_test sens_val = sens_test labels_val = labels_test - create_model = benchmark_utils.create_model + create_model_fn = create_model model_kwargs = {'input_shape': features_train.shape[1], 'latent_size1': 64, 'latent_size2': 32} criterion = torch.nn.functional.binary_cross_entropy_with_logits @@ -107,11 +105,10 @@ def main(data_cfg, task_cfg, n_epochs, constraint_cfg, device): "primal__lr": lr, "dual__penalty_mult": dual_lr, "dual__penalty_update": p_update, - "dual__penalty_range": [1e-1, 1.], "dual__pbf": pb_func, - "dual__init_duals": 1e-4, + "dual__penalty_range": p_range, "dual__gamma": dual_gamma, - "dual__delta": dual_delta, + "dual__delta": 1., "moreau__mu": moreau_mu } for ( @@ -119,16 +116,18 @@ def main(data_cfg, task_cfg, n_epochs, constraint_cfg, device): dual_lr, p_update, pb_func, + p_range, dual_gamma, - dual_delta, + # dual_delta, moreau_mu ) in product ( [0.001, 0.005, 0.01, 0.05], [0., 0.1, 0.2, 0.5], ["dimin_adapt"], ["quadratic_logarithmic"], + [[1e-1, 1.], [1e-1, 100.]], [0., 0.1, 0.2, 0.5], - [0.9, 1.0, 1.1], + # [0.9, 1.0, 1.1], [2.] ) ] @@ -138,18 +137,20 @@ def main(data_cfg, task_cfg, n_epochs, constraint_cfg, device): "primal__lr": lr, "dual__lr": dual_lr, "dual__penalty": penalty, - "dual__momentum": dual_momentum, + # "dual__momentum": dual_momentum, "moreau__mu": moreau_mu } for ( lr, - dual_lr, penalty, dual_momentum, + dual_lr, + penalty, + # dual_momentum, moreau_mu ) in product ( [0.001, 0.005, 0.01, 0.05], [0.001, 0.005, 0.01, 0.05], [0., 1.], - [0., 0.1, 0.2, 0.5], + # [0., 0.1, 0.2, 0.5], [2.] ) ] @@ -174,7 +175,7 @@ def main(data_cfg, task_cfg, n_epochs, constraint_cfg, device): alm_max_grid = alm_grid alm_slack_grid = alm_grid - adam_grid = [{"lr": lr} for lr in [0.005, 0.01, 0.05]] + adam_grid = [{"lr": lr} for lr in [0.001, 0.005, 0.01, 0.05]] # Determine constraint function and parameters based on task fuse_loss_constraint = False @@ -215,8 +216,50 @@ def main(data_cfg, task_cfg, n_epochs, constraint_cfg, device): # Run experiments + if 'adam' in task_cfg.algorithms: + seed = seed + torch.manual_seed(seed) + _, adam_history_train, adam_history_val = run_grid( + m=m, + primal_opt=torch.optim.Adam, + dual_opt=None, + param_grid=adam_grid, + n_epochs=n_epochs, + constraint_fn=constraint_fn, + constraint_bound=constraint_bound, + dataloader=dataloader_train, + data_train=(features_train, sens_train, labels_train), + data_val=data_val, + mode='torch', + verbose=False, + constraints_to_eq=False, + use_slack=False, + fuse_loss_constraint=fuse_loss_constraint, + model_gen=create_model_fn, + model_kwargs=model_kwargs, + device=device, + criterion = criterion) + + best_adam_params = extract_best_params(adam_history_val, adam_grid, None, filter='none') + + print('\n------------\n') + print('adam') + print(best_adam_params[0], best_adam_params[1]) + print(f'loss: {best_adam_params[2]}') + print(f'max c: {best_adam_params[3]}') + print('\n------------\n') + grid_adam = pd.DataFrame(adam_grid) + runs_adam_train = runs_to_df(adam_history_train) + runs_adam_train.to_csv(f'{result_dir}/runs_adam_train.csv') + runs_adam_val = runs_to_df(adam_history_val) + runs_adam_val.to_csv(f'{result_dir}/runs_adam_val.csv') + grid_adam.to_csv(f'{result_dir}/grid_adam.csv') + del adam_history_train, adam_history_val, runs_adam_train, runs_adam_val, grid_adam + ################################################################# if 'pbm' in task_cfg.algorithms: + seed = seed + torch.manual_seed(seed) _, pbm_history_train, pbm_history_val = run_grid( m=m, primal_opt=torch.optim.Adam, @@ -233,7 +276,7 @@ def main(data_cfg, task_cfg, n_epochs, constraint_cfg, device): constraints_to_eq=False, use_slack=False, fuse_loss_constraint=fuse_loss_constraint, - model_gen=create_model, + model_gen=create_model_fn, model_kwargs=model_kwargs, device=device, criterion=criterion) @@ -255,6 +298,8 @@ def main(data_cfg, task_cfg, n_epochs, constraint_cfg, device): del pbm_history_train, pbm_history_val, runs_pbm_train, runs_pbm_val, grid_pbm if 'alm_slack' in task_cfg.algorithms: + seed = seed + torch.manual_seed(seed) _, alm_history_train, alm_history_val = run_grid( m=m, primal_opt=torch.optim.Adam, @@ -271,7 +316,7 @@ def main(data_cfg, task_cfg, n_epochs, constraint_cfg, device): constraints_to_eq=True, use_slack=True, fuse_loss_constraint=fuse_loss_constraint, - model_gen=create_model, + model_gen=create_model_fn, model_kwargs=model_kwargs, device=device, criterion = criterion) @@ -293,6 +338,8 @@ def main(data_cfg, task_cfg, n_epochs, constraint_cfg, device): del alm_history_train, alm_history_val, runs_alm_train, runs_alm_val, grid_alm if 'alm_max' in task_cfg.algorithms: + seed = seed + torch.manual_seed(seed) _, alm_history_train, alm_history_val = run_grid( m=m, primal_opt=torch.optim.Adam, @@ -309,7 +356,7 @@ def main(data_cfg, task_cfg, n_epochs, constraint_cfg, device): constraints_to_eq=True, use_slack=False, fuse_loss_constraint=fuse_loss_constraint, - model_gen=create_model, + model_gen=create_model_fn, model_kwargs=model_kwargs, device=device, criterion=criterion) @@ -331,6 +378,8 @@ def main(data_cfg, task_cfg, n_epochs, constraint_cfg, device): del alm_history_train, alm_history_val, runs_alm_train, runs_alm_val, grid_alm if 'ssg' in task_cfg.algorithms: + seed = seed + torch.manual_seed(seed) _, ssg_history_train, ssg_history_val = run_grid( m=m, primal_opt=torch.optim.Adam, @@ -347,7 +396,7 @@ def main(data_cfg, task_cfg, n_epochs, constraint_cfg, device): constraints_to_eq=False, use_slack=False, fuse_loss_constraint=fuse_loss_constraint, - model_gen=create_model, + model_gen=create_model_fn, model_kwargs=model_kwargs, device=device, criterion=criterion) @@ -380,8 +429,9 @@ def hydra_main(cfg: DictConfig): print(torch.cuda.get_device_name(0) if torch.cuda.is_available() else 'No GPU') torch.set_default_device(device) constraint_cfg = OmegaConf.to_container(task_cfg.constraint, resolve=True) + seed = task_cfg.seed - main(data_cfg, task_cfg, n_epochs, constraint_cfg, device) + main(data_cfg, task_cfg, n_epochs, constraint_cfg, device, seed) diff --git a/benchmark/tinyimagenet_folktables/adam_config.yaml b/benchmark/tinyimagenet_folktables/adam_config.yaml new file mode 100644 index 0000000..5ad545a --- /dev/null +++ b/benchmark/tinyimagenet_folktables/adam_config.yaml @@ -0,0 +1,2 @@ +,algorithm,primal__lr,primal__weight_decay +7c3fe4eff95d,adam,0.0001,0.01 diff --git a/benchmark/tinyimagenet_folktables/alm_max_config.yaml b/benchmark/tinyimagenet_folktables/alm_max_config.yaml new file mode 100644 index 0000000..01853ab --- /dev/null +++ b/benchmark/tinyimagenet_folktables/alm_max_config.yaml @@ -0,0 +1,2 @@ +,algorithm,primal__lr,dual__lr,dual__penalty,moreau__mu,primal__weight_decay +1e87d660012c,alm_max,0.0005,0.0005,0.0,0.1,0.01 diff --git a/benchmark/tinyimagenet_folktables/multirun.yaml b/benchmark/tinyimagenet_folktables/multirun.yaml new file mode 100644 index 0000000..966d71b --- /dev/null +++ b/benchmark/tinyimagenet_folktables/multirun.yaml @@ -0,0 +1,218 @@ +hydra: + run: + dir: multirun/${task.task}_${data.name}/${algorithm.algorithm} + sweep: + dir: multirun/${task.task}_${data.name} + subdir: ./${algorithm.algorithm} + launcher: + submitit_folder: ${hydra.sweep.dir}/.submitit/%j + timeout_min: 60 + cpus_per_task: 4 + gpus_per_node: null + tasks_per_node: 1 + mem_gb: null + nodes: 1 + name: ${hydra.job.name} + stderr_to_stdout: false + _target_: hydra_plugins.hydra_submitit_launcher.submitit_launcher.SlurmLauncher + partition: h200fast + qos: null + comment: null + constraint: null + exclude: null + gres: gpu=1 + cpus_per_gpu: null + gpus_per_task: null + mem_per_gpu: null + mem_per_cpu: 20G + account: null + signal_delay_s: 120 + max_num_timeout: 0 + additional_parameters: {} + array_parallelism: 256 + setup: null + sweeper: + _target_: hydra._internal.core_plugins.basic_sweeper.BasicSweeper + max_batch_size: 85 + params: null + help: + app_name: ${hydra.job.name} + header: '${hydra.help.app_name} is powered by Hydra. + + ' + footer: 'Powered by Hydra (https://hydra.cc) + + Use --hydra-help to view Hydra specific help + + ' + template: '${hydra.help.header} + + == Configuration groups == + + Compose your configuration from those groups (group=option) + + + $APP_CONFIG_GROUPS + + + == Config == + + Override anything in the config (foo.bar=value) + + + $CONFIG + + + ${hydra.help.footer} + + ' + hydra_help: + template: 'Hydra (${hydra.runtime.version}) + + See https://hydra.cc for more info. + + + == Flags == + + $FLAGS_HELP + + + == Configuration groups == + + Compose your configuration from those groups (For example, append hydra/job_logging=disabled + to command line) + + + $HYDRA_CONFIG_GROUPS + + + Use ''--cfg hydra'' to Show the Hydra config. + + ' + hydra_help: ??? + hydra_logging: + version: 1 + formatters: + simple: + format: '[%(asctime)s][HYDRA] %(message)s' + handlers: + console: + class: logging.StreamHandler + formatter: simple + stream: ext://sys.stdout + root: + level: INFO + handlers: + - console + loggers: + logging_example: + level: DEBUG + disable_existing_loggers: false + job_logging: + version: 1 + formatters: + simple: + format: '[%(asctime)s][%(name)s][%(levelname)s] - %(message)s' + handlers: + console: + class: logging.StreamHandler + formatter: simple + stream: ext://sys.stdout + file: + class: logging.FileHandler + formatter: simple + filename: ${hydra.runtime.output_dir}/${hydra.job.name}.log + root: + level: INFO + handlers: + - console + - file + disable_existing_loggers: false + env: {} + mode: MULTIRUN + searchpath: [] + callbacks: {} + output_subdir: null + overrides: + hydra: + - hydra.mode=MULTIRUN + task: + - algorithm=alm_slack + - algorithm.primal__lr=0.001 + - +algorithm.primal__weight_decay=0.01 + - algorithm.dual__lr=0.0005 + - algorithm.dual__penalty=0. + - algorithm.moreau__mu=1. + - seed=0,1,2 + - task=tinyimagenet + - data=income_sex + - use_test=true + - n_epochs=30 + job: + name: run_single_experiment + chdir: false + override_dirname: +algorithm.primal__weight_decay=0.01,algorithm.dual__lr=0.0005,algorithm.dual__penalty=0.,algorithm.moreau__mu=1.,algorithm.primal__lr=0.001,algorithm=alm_slack,data=income_sex,n_epochs=30,seed=0,1,2,task=tinyimagenet,use_test=true + id: ??? + num: ??? + config_name: experiment + env_set: {} + env_copy: [] + config: + override_dirname: + kv_sep: '=' + item_sep: ',' + exclude_keys: [] + runtime: + version: 1.3.2 + version_base: '1.3' + cwd: /mnt/personal/kliacand/humancompatible-train/benchmark + config_sources: + - path: hydra.conf + schema: pkg + provider: hydra + - path: /mnt/personal/kliacand/humancompatible-train/benchmark/conf + schema: file + provider: main + - path: '' + schema: structured + provider: schema + output_dir: ??? + choices: + algorithm: alm_slack + data: income_sex + task: tinyimagenet + hydra/env: default + hydra/callbacks: null + hydra/job_logging: default + hydra/hydra_logging: default + hydra/hydra_help: default + hydra/help: default + hydra/sweeper: basic + hydra/launcher: configured_submitit_slurm_h200 + hydra/output: default + verbose: false +task: + task: tinyimagenet + batch_size: 1200 + constraint: + name: LossPairwise + bound: 0.1 + loss: CrossEntropyLoss +data: + name: folktables + kwargs: + sens_attrs: + - SEX + states: + - VA + extend_groups: false +algorithm: + algorithm: alm_slack + primal__lr: 0.001 + dual__lr: 0.0005 + dual__penalty: 0.0 + moreau__mu: 1.0 + primal__weight_decay: 0.01 +n_epochs: 30 +seed: 0 +use_test: true diff --git a/benchmark/tinyimagenet_folktables/pbm_dimin_config.yaml b/benchmark/tinyimagenet_folktables/pbm_dimin_config.yaml new file mode 100644 index 0000000..42d712d --- /dev/null +++ b/benchmark/tinyimagenet_folktables/pbm_dimin_config.yaml @@ -0,0 +1,3 @@ +,algorithm,dual__delta,dual__gamma,dual__pbf,dual__penalty_mult,dual__penalty_update,init_duals,moreau__mu,primal__lr,primal__weight_decay +985d15a255fc,pbm_dimin,1.0,0.9500000000000001,quadratic_reciprocal,1,const,1e-05,1.0,0.0005,0.01 +00301b66238d,pbm_dimin,1.0,0.1,quadratic_reciprocal,0.9999,dimin,,1.0,0.0005,0.01 diff --git a/benchmark/utils.py b/benchmark/utils.py index 0cc85f5..7696093 100644 --- a/benchmark/utils.py +++ b/benchmark/utils.py @@ -5,6 +5,7 @@ import numpy as np from torch import nn from torch.nn import Sequential +from torchvision import models # define the network class ConvNet(nn.Module): @@ -26,6 +27,19 @@ def forward(self, x): x = self.fc3(x) return x +def create_efficientnet(num_classes=200): + """EfficientNet-B0 from scratch (no pretrained weights).""" + model = models.efficientnet_b0(weights=None) + # Replace classifier head for 200-class TinyImageNet + in_features = model.classifier[1].in_features + model.classifier[1] = nn.Linear(in_features, num_classes) + return model + + +def create_resnet(): + import torchvision + return torchvision.models.resnet18(pretrained=False) + def create_conv_model(): return ConvNet() @@ -128,10 +142,7 @@ def run_train( constraint_tol: float = 0., criterion = None, ): - # Use provided criterion or fall back to global - if criterion is None: - criterion = globals()['criterion'] - + print(f"Starting on {device}") model = model_gen(**model_kwargs) primal_params = {k.removeprefix('primal__'): v for k, v in param_set.items() if k.startswith('primal__')} @@ -158,8 +169,8 @@ def run_train( bounds = torch.tensor([constraint_bound]*m).to(device) - torch.sum(model(data_train[0][0].unsqueeze(0))).backward() - model.zero_grad() + # torch.sum(model(data_train[0][0].unsqueeze(0))).backward() + # model.zero_grad() if mode == 'hc': history = train_loop_primal_dual( @@ -249,20 +260,25 @@ def train_loop_sw( epoch_losses = [] epoch_constraints = [] + epoch_accuracies = [] epoch_constraint_time = 0.0 if epoch == 0: model.eval() + train_start_time = time.perf_counter() for batch_features, batch_sens_attrs, batch_labels in train_dataloader: batch_features = batch_features.to(device) batch_sens_attrs = batch_sens_attrs.to(device) batch_labels = batch_labels.to(device) batch_out = model(batch_features) - constraints, constraints_bounded_eq = calc_constraints(model, constraint_fn, constraint_bounds, constraint_options, None, batch_sens_attrs, batch_labels, batch_out, None) loss = loss_fn(batch_out, batch_labels) - train_start_time = time.perf_counter() + constraints, constraints_bounded_eq = calc_constraints(model, constraint_fn, constraint_bounds, constraint_options, None, batch_sens_attrs, batch_labels, batch_out, loss) + batch_acc = calc_perclass_acc(batch_sens_attrs, batch_labels, batch_out) + if loss.dim() > 0: # If loss is not aggregated + loss = loss.mean() epoch_losses.append(loss.detach().cpu().numpy().item()) epoch_constraints.append(constraints.detach().cpu().numpy()) + epoch_accuracies.append(batch_acc) else: # Training phase model.train() @@ -285,6 +301,7 @@ def train_loop_sw( constraints, constraints_bounded_eq = calc_constraints(model, constraint_fn, constraint_bounds, constraint_options, None, batch_sens_attrs, batch_labels, batch_out, None) epoch_constraint_time += time.perf_counter() - constraint_start + batch_acc = calc_perclass_acc(batch_sens_attrs, batch_labels, batch_out) max_constraint = max(constraints_bounded_eq) # constraint step if violated @@ -292,14 +309,19 @@ def train_loop_sw( max_constraint.backward() dual_optimizer.step() loss = loss_fn(batch_out, batch_labels) + if loss.dim() > 0: + loss = loss.mean() else: loss = loss_fn(batch_out, batch_labels) + if loss.dim() > 0: + loss = loss.mean() loss.backward() optimizer.step() epoch_losses.append(loss.detach().cpu().numpy().item()) epoch_constraints.append(constraints.detach().cpu().numpy()) + epoch_accuracies.append(batch_acc) # Stop training timer train_end_time = time.perf_counter() @@ -312,19 +334,21 @@ def train_loop_sw( eval_dict = { "epoch": epoch, "time": train_time, - "loss": np.mean(epoch_losses) + "loss": np.mean(epoch_losses), + "acc": np.mean(epoch_accuracies, axis=0) } | { f"c_{j}": c for j, c in enumerate(np.mean(epoch_constraints, axis=0)) } history_train.append(eval_dict) # Validation phase model.eval() - val_loss, val_constraints = validate_model(model, val_data, loss_fn, constraint_fn, device) + val_loss, val_constraints, val_acc = validate_model(model, val_data, loss_fn, constraint_fn, device) eval_dict = { "epoch": epoch, "time": train_time, - "loss": val_loss + "loss": val_loss, + "acc": val_acc } | { f"c_{j}": c for j, c in enumerate(val_constraints) } @@ -373,20 +397,28 @@ def train_loop_primal_dual( model.train() epoch_losses = [] epoch_constraints = [] + epoch_accuracies = [] epoch_constraint_time = 0.0 if epoch == 0: model.eval() + train_start_time = time.perf_counter() for batch_features, batch_sens_attrs, batch_labels in train_dataloader: batch_features = batch_features.to(device) batch_sens_attrs = batch_sens_attrs.to(device) batch_labels = batch_labels.to(device) batch_out = model(batch_features) - constraints, constraints_bounded_eq = calc_constraints(model, constraint_fn, constraint_bounds, constraint_options, None, batch_sens_attrs, batch_labels, batch_out, None) loss = loss_fn(batch_out, batch_labels) - train_start_time = time.perf_counter() + # print(loss) + constraints, constraints_bounded_eq = calc_constraints(model, constraint_fn, constraint_bounds, constraint_options, None, batch_sens_attrs, batch_labels, batch_out, loss) + batch_acc = calc_perclass_acc(batch_sens_attrs, batch_labels, batch_out) + # print(constraints) + # print(constraints_bounded_eq) + if loss.dim() > 0: # If loss is not aggregated + loss = loss.mean() epoch_losses.append(loss.detach().cpu().numpy().item()) epoch_constraints.append(constraints.detach().cpu().numpy()) + epoch_accuracies.append(batch_acc) model.train() else: # Start training timer @@ -408,10 +440,12 @@ def train_loop_primal_dual( # Compute constraints # if not timing constraints, we just log the time and then subtract it at the end. constraint_start = time.perf_counter() + # print(loss) constraints, constraints_bounded_eq = calc_constraints(model, constraint_fn, constraint_bounds, constraint_options, slack_vars, batch_sens_attrs, batch_labels, batch_out, loss) epoch_constraint_time += time.perf_counter() - constraint_start - if mode != 'sw' and loss.dim() > 0: # If loss is not aggregated + batch_acc = calc_perclass_acc(batch_sens_attrs, batch_labels, batch_out) + if loss.dim() > 0: # If loss is not aggregated loss = loss.mean() lgr = dual_optimizer.forward_update(loss, constraints_bounded_eq) @@ -420,6 +454,7 @@ def train_loop_primal_dual( epoch_losses.append(loss.detach().cpu().numpy().item()) epoch_constraints.append(constraints.detach().cpu().numpy()) + epoch_accuracies.append(batch_acc) # Stop training timer train_end_time = time.perf_counter() @@ -432,23 +467,23 @@ def train_loop_primal_dual( eval_dict = { "epoch": epoch, "time": train_time, - "loss": np.mean(epoch_losses) + "loss": np.mean(epoch_losses), + "acc": np.mean(epoch_accuracies, axis=0) } | { f"c_{j}": c for j, c in enumerate(np.mean(epoch_constraints, axis=0)) } - history_train.append(eval_dict) # Validation phase model.eval() - val_loss, val_constraints = validate_model(model, val_data, loss_fn, constraint_fn, device) + val_loss, val_constraints, val_acc = validate_model(model, val_data, loss_fn, constraint_fn, device) eval_dict = { "epoch": epoch, "time": train_time, - "loss": val_loss + "loss": val_loss, + "acc": val_acc } | { f"c_{j}": c for j, c in enumerate(val_constraints) } - history_val.append(eval_dict) return model, history_train, history_val @@ -488,20 +523,25 @@ def train_loop_adam( model.train() epoch_losses = [] epoch_constraints = [] + epoch_accuracies = [] epoch_constraint_time = 0.0 if epoch == 0: model.eval() + train_start_time = time.perf_counter() for batch_features, batch_sens_attrs, batch_labels in train_dataloader: batch_features = batch_features.to(device) batch_sens_attrs = batch_sens_attrs.to(device) batch_labels = batch_labels.to(device) batch_out = model(batch_features) - constraints, constraints_bounded_eq = calc_constraints(model, constraint_fn, constraint_bounds, constraint_options, None, batch_sens_attrs, batch_labels, batch_out, None) loss = loss_fn(batch_out, batch_labels) - train_start_time = time.perf_counter() + constraints, constraints_bounded_eq = calc_constraints(model, constraint_fn, constraint_bounds, constraint_options, None, batch_sens_attrs, batch_labels, batch_out, loss) + batch_acc = calc_perclass_acc(batch_sens_attrs, batch_labels, batch_out) + if loss.dim() > 0: # If loss is not aggregated + loss = loss.mean() epoch_losses.append(loss.detach().cpu().numpy().item()) epoch_constraints.append(constraints.detach().cpu().numpy()) + epoch_accuracies.append(batch_acc) model.train() else: # Start training timer @@ -526,6 +566,7 @@ def train_loop_adam( if not time_constraint_computation: # time constraint separately to subtract it from total time later epoch_constraint_time += time.perf_counter() - constraint_start + batch_acc = calc_perclass_acc(batch_sens_attrs, batch_labels, batch_out) if loss.dim() > 0: # If loss is not aggregated loss = loss.mean() @@ -540,6 +581,7 @@ def train_loop_adam( epoch_losses.append(loss.detach().cpu().numpy().item()) epoch_constraints.append(constraints.detach().cpu().numpy()) + epoch_accuracies.append(batch_acc) # Stop training timer train_end_time = time.perf_counter() @@ -552,19 +594,21 @@ def train_loop_adam( eval_dict = { "epoch": epoch, "time": train_time, - "loss": np.mean(epoch_losses) + "loss": np.mean(epoch_losses), + "acc": np.mean(epoch_accuracies, axis=0) } | { f"c_{j}": c for j, c in enumerate(np.mean(epoch_constraints, axis=0)) } history_train.append(eval_dict) # Validation phase model.eval() - val_loss, val_constraints = validate_model(model, val_data, loss_fn, constraint_fn, device) + val_loss, val_constraints, val_acc = validate_model(model, val_data, loss_fn, constraint_fn, device) eval_dict = { "epoch": epoch, "time": train_time, - "loss": val_loss + "loss": val_loss, + "acc": val_acc } | { f"c_{j}": c for j, c in enumerate(val_constraints) } @@ -577,6 +621,7 @@ def train_loop_adam( def validate_model(model, val_data, loss_fn, constraint_fn, device): val_losses = [] val_constraints_list = [] + val_accuracies_list = [] with torch.no_grad(): if isinstance(val_data, torch.utils.data.DataLoader): # Use dataloader batches @@ -591,9 +636,11 @@ def validate_model(model, val_data, loss_fn, constraint_fn, device): val_loss = val_loss.mean() val_constraints = constraint_fn(model, val_out, batch_sens_attrs, batch_labels) + val_acc = calc_perclass_acc(batch_sens_attrs, batch_labels, val_out) # val_constraints -= constraint_bounds val_losses.append(val_loss.detach().cpu().numpy().item()) val_constraints_list.append(val_constraints.detach().cpu().numpy()) + val_accuracies_list.append(val_acc) else: # Use full validation dataset val_features, val_sens_attrs, val_labels = val_data @@ -607,12 +654,15 @@ def validate_model(model, val_data, loss_fn, constraint_fn, device): val_loss = val_loss.mean() val_constraints = constraint_fn(model, val_out, val_sens_attrs, val_labels) + val_acc = calc_perclass_acc(val_sens_attrs, val_labels, val_out) val_losses.append(val_loss.detach().cpu().numpy().item()) val_constraints_list.append(val_constraints.detach().cpu().numpy()) + val_accuracies_list.append(val_acc) val_losses = np.mean(val_losses) val_constraints_list = np.mean(val_constraints_list, axis=0) - return val_losses, val_constraints_list + val_accuracies = np.mean(val_accuracies_list, axis=0) + return val_losses, val_constraints_list, val_accuracies def calc_constraints(model, constraint_fn, constraint_bounds, constraint_options, slack_vars, batch_sens, batch_labels, batch_out, loss=None): @@ -629,3 +679,34 @@ def calc_constraints(model, constraint_fn, constraint_bounds, constraint_options constraints_bounded_eq = (constraints - constraint_bounds) return constraints, constraints_bounded_eq + +def calc_perclass_acc(batch_classes_onehot, batch_labels, batch_out): + """ + Efficiently calculate per-class accuracy. + + Args: + batch_classes_onehot: shape (batch_size, n_classes), one-hot encoding of class membership + batch_labels: shape (batch_size,), ground truth labels + batch_out: shape (batch_size, n_classes), logits from model + + Returns: + per_class_acc: shape (n_classes,), accuracy for each class + """ + # Get predictions from logits + predictions = torch.argmax(batch_out, dim=1) + + # Check if predictions match labels + if batch_labels.ndim > predictions.ndim: + batch_labels = batch_labels.squeeze() + correct = (predictions == batch_labels).float() # shape (batch_size,) + + # Number of samples per class + n_samples_per_class = batch_classes_onehot.sum(dim=0) # shape (n_classes,) + + # Sum correct predictions per class using matrix multiplication + correct_per_class = batch_classes_onehot.T @ correct # shape (n_classes,) + + # Calculate per-class accuracy, avoid division by zero + per_class_acc = correct_per_class / (n_samples_per_class + 1e-8) + + return per_class_acc.detach().cpu().numpy() diff --git a/logs/simulator/demo_balls_pbm.pdf b/logs/simulator/demo_balls_pbm.pdf new file mode 100644 index 0000000..d39fe72 Binary files /dev/null and b/logs/simulator/demo_balls_pbm.pdf differ diff --git a/logs/simulator/err_simulator.out b/logs/simulator/err_simulator.out new file mode 100644 index 0000000..f9e9ace --- /dev/null +++ b/logs/simulator/err_simulator.out @@ -0,0 +1,18 @@ +/mnt/appl/software/PyTorch/2.7.1-foss-2025a-CUDA-12.8.0/lib/python3.13/site-packages/torch/utils/data/dataloader.py:626: UserWarning: This DataLoader will create 4 worker processes in total. Our suggested max number of worker in current system is 2, which is smaller than what this DataLoader is going to create. Please be aware that excessive worker creation might get DataLoader running slow or even freeze, lower the worker number to avoid potential slowness/freeze if necessary. + warnings.warn( + train: 0%| | 0/25 [00:00 + train_tinyimagenet() + ~~~~~~~~~~~~~~~~~~^^ + File "/home/bosakad1/humancompatible-train/tiny_image_net.py", line 177, in train_tinyimagenet + train_loss, train_acc, max_constr = run_epoch(model, loaders[loader_name], + ~~~~~~~~~^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + criterion, optimizer, device, + ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + train=True, dual=dual) + ^^^^^^^^^^^^^^^^^^^^^^ + File "/home/bosakad1/humancompatible-train/tiny_image_net.py", line 239, in run_epoch + lagrangian = dual.forward_update(loss.mean(), constraints.unsqueeze(0)) + File "/home/bosakad1/humancompatible-train/src/humancompatible/train/dual_optim/pbm.py", line 249, in forward_update + cdivp = group_constraints.div(penalties) +RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:0 and cpu! diff --git a/logs/simulator/exdbn_simulator.stdout b/logs/simulator/exdbn_simulator.stdout new file mode 100644 index 0000000..6f24714 --- /dev/null +++ b/logs/simulator/exdbn_simulator.stdout @@ -0,0 +1,17 @@ +Dataset has 200 classes. Sample classes: ['n01443537', 'n01629819', 'n01641577', 'n01644900', 'n01698640'] + +Dataset: val | Size: 10000 +Loading cache from ./data/cache_val.pt... + X: torch.Size([10000, 3, 64, 64]), targets: torch.Size([10000]) +tensor(995000) + Loaders created: 'val' and 'val_balanced' +Epoch 1/10 | train loss 5.3459 acc 0.007 | val loss 5.3013 acc 0.005 | max constraint 0.3071 +Epoch 2/10 | train loss 5.2444 acc 0.012 | val loss 5.3078 acc 0.005 | max constraint 0.6606 +Epoch 3/10 | train loss 5.1482 acc 0.016 | val loss 5.2965 acc 0.005 | max constraint 1.6579 +Epoch 4/10 | train loss 4.9798 acc 0.028 | val loss 4.7633 acc 0.043 | max constraint 3.8804 +Epoch 5/10 | train loss 4.7372 acc 0.047 | val loss 4.3450 acc 0.090 | max constraint 9.3368 +Epoch 6/10 | train loss 4.4343 acc 0.079 | val loss 3.9285 acc 0.152 | max constraint 5.1328 +Epoch 7/10 | train loss 4.1320 acc 0.121 | val loss 3.5667 acc 0.219 | max constraint 5.3288 +Epoch 8/10 | train loss 3.7321 acc 0.179 | val loss 2.9391 acc 0.317 | max constraint 5.8890 +Epoch 9/10 | train loss 3.3022 acc 0.248 | val loss 2.3792 acc 0.451 | max constraint 6.3630 +Epoch 10/10 | train loss 2.7599 acc 0.359 | val loss 1.7547 acc 0.586 | max constraint 6.0408 diff --git a/run.batch b/run.batch deleted file mode 100644 index 6426706..0000000 --- a/run.batch +++ /dev/null @@ -1,18 +0,0 @@ -#!/bin/sh -#SBATCH --time=4:00:00 -#SBATCH --partition=h200fast -#SBATCH --nodes=1 -#SBATCH --mem=100G -#SBATCH --gres=gpu:1 -#SBATCH --cpus-per-task=1 -#SBATCH --out=./logs/simulator/exdbn_simulator.stdout -#SBATCH --err=./logs/simulator/err_simulator.out -#SBATCH --mail-user=bosakad1@fel.cvut.cz -#SBATCH --mail-type=ALL -#SBATCH --job-name=humancompatible - -ml load Python/3.13.1-GCCcore-14.2.0 -ml load GCCcore/14.2.0 -source env_humancompatible/bin/activate - -python3 train_4_optimizers.py diff --git a/setup.txt b/setup.txt deleted file mode 100644 index d541823..0000000 --- a/setup.txt +++ /dev/null @@ -1,9 +0,0 @@ - -srun -p gpufast --gres=gpu:1 --pty bash -i -srun -p h200fast --gres=gpu:1 --pty bash -i - -ml load Python/3.13.1-GCCcore-14.2.0 -ml load GCCcore/14.2.0 -ml load PyTorch/2.10.0-foss-2025a-CUDA-12.8.0 - -source env_humancompatible/bin/activate \ No newline at end of file diff --git a/src/humancompatible/train/dual_optim/__init__.py b/src/humancompatible/train/dual_optim/__init__.py index 37ecb26..5dafcbd 100644 --- a/src/humancompatible/train/dual_optim/__init__.py +++ b/src/humancompatible/train/dual_optim/__init__.py @@ -1,3 +1,4 @@ from .alm import ALM +from .ialm import iALM from .pbm import PBM -from .moreau import MoreauEnvelope \ No newline at end of file +from .moreau import MoreauEnvelope diff --git a/src/humancompatible/train/dual_optim/alm.py b/src/humancompatible/train/dual_optim/alm.py index dec03c6..9cb4a87 100644 --- a/src/humancompatible/train/dual_optim/alm.py +++ b/src/humancompatible/train/dual_optim/alm.py @@ -1,9 +1,11 @@ import torch from torch.nn import Parameter from torch.optim import Optimizer -from typing import Any, Iterable, Tuple -from torch import clamp_, Tensor, no_grad -from torch.optim.optimizer import _use_grad_for_differentiable +from typing import Any, Tuple +from torch import clamp_, Tensor + +# cite: Stochastic Smoothed Primal-Dual Algorithms for Nonconvex Optimization with Linear Inequality Constraints +# https://arxiv.org/pdf/2504.07607 class ALM(Optimizer): @@ -15,16 +17,17 @@ def __init__( penalty: float = 1.0, *, dual_range: Tuple[float, float] = (0.0, 100.0), - momentum: float = 0., - dampening: float = 0., - device = None + momentum: float = 0.0, + dampening: float = 0.0, + ctol: float = 0., + device=None, ) -> None: """ A wrapper over a PyTorch`Optimizer` that works on the dual maximization tasks according to the Augmented Lagrangian rule. Creates and updates dual variables. :param m: Number of constraints (determines the number of dual variables to create) :type m: int - :param lr: Dual variable update rate + :param lr: Dual variable update rate :type lr: float :param init_duals: Initial values for the new dual variables. Defaults to 0 for all. :type init_duals: float | Tensor @@ -36,48 +39,22 @@ def __init__( :type momentum: float :param dampening: Dampening for momentum. Equivalent to SGD dampening. Set to `0` to disable. :type dampening: float + :param ctol: Constraint tolerance; allows tiny violations of constraints to account for noise. + :type ctol: float """ if momentum > 0 and dampening == 0: dampening = momentum - - self.dual_range = dual_range - - self.penalty = penalty - duals, defaults = self._init_constraint_group(m, lr, momentum, dampening, init_duals, dual_range, device) - - super().__init__(duals, defaults) - @staticmethod - def _init_constraint_group( - m: int = None, lr: float = None, momentum: float = None, dampening: float = None, init_duals: float | Tensor = None, dual_range: Tuple[float, float] = None, device = None - ): - ## checks ## - if init_duals is None and m is None: - raise ValueError("At least one of`m`,`init_duals` must be set") - - if momentum is not None and (momentum < 0 or momentum > 1): - raise ValueError(f"`momentum`must be within [0,1]; got {momentum}") - - m = m if m is not None else len(init_duals) - - if init_duals is None: # initialize duals if not set or set to scalar - init_duals = torch.zeros(m, requires_grad=False, device=device) + dual_range[0] - elif isinstance(init_duals, float): - init_duals = torch.zeros(m, requires_grad=False, device=device) + init_duals - - duals = Parameter(init_duals, requires_grad=False) + self.dual_range = dual_range + self.ctol = ctol - settings_dict = { - "lr": lr, - "momentum": momentum, - "dampening": dampening, - "momentum_buffer": torch.zeros_like(init_duals, requires_grad = False, device=device), - } - settings_dict = {k:v for k,v in settings_dict.items() if v is not None} + self.penalty = penalty + duals, defaults = _init_constraint_group( + m, lr, momentum, dampening, init_duals, dual_range, device + ) - param_group = ([duals], settings_dict) - return param_group + super().__init__(duals, defaults) @property def duals(self) -> Tensor: @@ -88,7 +65,12 @@ def duals(self) -> Tensor: return torch.cat([group["params"][0] for group in self.param_groups]) def add_constraint_group( - self, m: int = None, lr: float = None, momentum: float = None, dampening: float = None, init_duals: Tensor = None + self, + m: int = None, + lr: float = None, + momentum: float = None, + dampening: float = None, + init_duals: Tensor = None, ) -> None: """ Allows to add a group of dual variables with separate initial values and learning rates. @@ -100,15 +82,26 @@ def add_constraint_group( :param init_duals: Initial values for the new dual variables :type init_duals: Tensor """ - duals, settings_dict = self._init_constraint_group(m, lr, momentum, dampening, init_duals, self.dual_range) + duals, settings_dict = _init_constraint_group( + m, lr, momentum, dampening, init_duals, self.dual_range + ) param_group_dict = {"params": duals, **settings_dict} self.add_param_group(param_group_dict) + def _add_penalty_term(self, lagrangian: Tensor, constraints: Tensor) -> None: + """Add penalty term to lagrangian in-place.""" + if self.penalty > 0: + lagrangian.add_( + 0.5 + * self.penalty + * torch.dot(constraints - self.ctol, constraints - self.ctol) + ) + def forward(self, loss: Tensor, constraints: Tensor) -> Tensor: """ Calculates and returns the Augmented Lagrangian. - + :param loss: Loss (objective function) value :type loss: Tensor :param constraints: Tensor of constraint values @@ -118,40 +111,34 @@ def forward(self, loss: Tensor, constraints: Tensor) -> Tensor: """ lagrangian = torch.zeros_like(loss) lagrangian.add_(loss) - for i, group in enumerate(self.param_groups): - duals = group["params"][0] - group_constraints = constraints[i * len(duals) : (i + 1) * len(duals)] - lagrangian.add_(duals @ group_constraints) - if self.penalty > 0: - lagrangian.add_( - 0.5 * self.penalty * torch.dot(constraints, constraints) + for i in range(len(self.param_groups)): + duals, group_constraints = _process_constraint_group( + self.param_groups[i], i, constraints, self.ctol, self.dual_range, update_duals=False ) + lagrangian.add_(duals @ group_constraints) + self._add_penalty_term(lagrangian, constraints) return lagrangian def update(self, constraints: Tensor) -> None: - """""" """ Updates the dual variables - + :param constraints: Tensor of constraint values :type constraints: Tensor """ - for i, group in enumerate(self.param_groups): - duals, lr, momentum, dampening, momentum_buffer = group["params"][0], group["lr"], group["momentum"], group["dampening"], group["momentum_buffer"] - group_constraints = constraints[i * len(duals) : (i + 1) * len(duals)] - with torch.no_grad(): - _update_duals(duals, group_constraints, lr, momentum, dampening, momentum_buffer) - clamp_(duals, min=self.dual_range[0], max=self.dual_range[1]) - + for i in range(len(self.param_groups)): + _process_constraint_group( + self.param_groups[i], i, constraints, self.ctol, self.dual_range, update_duals=True + ) # evaluate the Lagrangian and update the dual variables def forward_update(self, loss: Tensor, constraints: Tensor) -> Tensor: """ Combines `forward` and `update`; slightly faster. - + :param loss: Loss (objective function) value :type loss: Tensor :param constraints: Tensor of constraint values @@ -161,30 +148,27 @@ def forward_update(self, loss: Tensor, constraints: Tensor) -> Tensor: """ lagrangian = torch.zeros_like(loss) lagrangian.add_(loss) - for i, group in enumerate(self.param_groups): - duals, lr, momentum, dampening, momentum_buffer = group["params"][0], group["lr"], group["momentum"], group["dampening"], group["momentum_buffer"] - group_constraints = constraints[i * len(duals) : (i + 1) * len(duals)] - with torch.no_grad(): - _update_duals(duals, group_constraints, lr, momentum, dampening, momentum_buffer) - clamp_(duals, min=self.dual_range[0], max=self.dual_range[1]) - lagrangian.add_(duals @ group_constraints) - - if self.penalty > 0: - lagrangian.add_( - 0.5 * self.penalty * torch.dot(constraints, constraints) + for i in range(len(self.param_groups)): + duals, group_constraints = _process_constraint_group( + self.param_groups[i], i, constraints, self.ctol, self.dual_range, update_duals=True ) + lagrangian.add_(duals @ group_constraints) + self._add_penalty_term(lagrangian, constraints) return lagrangian def state_dict(self) -> dict[str, Any]: - + state_dict = super().state_dict() state_dict["state"]["penalty"] = self.penalty state_dict["state"]["dual_range"] = self.dual_range # save params themselves in state_dict instead of param ID in default PyTorch - for id_pg, pg in enumerate(state_dict['param_groups']): - pg['params'] = [self.param_groups[id_pg]['params'][param_id] for param_id in pg['params'] ] + for id_pg, pg in enumerate(state_dict["param_groups"]): + pg["params"] = [ + self.param_groups[id_pg]["params"][param_id] + for param_id in pg["params"] + ] return state_dict def load_state_dict(self, state_dict: dict[str, Any]) -> None: @@ -196,9 +180,96 @@ def load_state_dict(self, state_dict: dict[str, Any]) -> None: self.param_groups.append(param) -def _update_duals(duals: Tensor, constraints: Tensor, lr: float, momentum: float, dampening: float, buffer: Tensor) -> None: +def _process_constraint_group( + group: dict[str, Any], + group_idx: int, + constraints: Tensor, + ctol: float, + dual_range: Tuple[float, float], + update_duals: bool = False, +) -> Tuple[Tensor, Tensor]: + """ + Process a single constraint group: extract duals/constraints and optionally update duals. + + :param group: The constraint group dictionary + :param group_idx: Index of the constraint group + :param constraints: Full constraints tensor + :param ctol: Constraint tolerance + :param dual_range: Safeguarding range for dual variables + :param update_duals: Whether to update dual variables + :return: Tuple of (duals, group_constraints) + """ + duals = group["params"][0] + group_constraints = ( + constraints[group_idx * len(duals) : (group_idx + 1) * len(duals)] - ctol + ) + + if update_duals: + lr = group.get("lr") + momentum = group.get("momentum", 0.0) + dampening = group.get("dampening", 0.0) + momentum_buffer = group["momentum_buffer"] + + with torch.no_grad(): + _update_duals( + duals, group_constraints, lr, momentum, dampening, momentum_buffer + ) + clamp_(duals, min=dual_range[0], max=dual_range[1]) + + return duals, group_constraints + + +def _init_constraint_group( + m: int = None, + lr: float = None, + momentum: float = None, + dampening: float = None, + init_duals: float | Tensor = None, + dual_range: Tuple[float, float] = None, + device=None, + ): + ## checks ## + if init_duals is None and m is None: + raise ValueError("At least one of`m`,`init_duals` must be set") + + if momentum is not None and (momentum < 0 or momentum > 1): + raise ValueError(f"`momentum`must be within [0,1]; got {momentum}") + + m = m if m is not None else len(init_duals) + + if init_duals is None: # initialize duals if not set or set to scalar + init_duals = ( + torch.zeros(m, requires_grad=False, device=device) + dual_range[0] + ) + elif isinstance(init_duals, float): + init_duals = torch.zeros(m, requires_grad=False, device=device) + init_duals + + duals = Parameter(init_duals, requires_grad=False) + + settings_dict = { + "lr": lr, + "momentum": momentum, + "dampening": dampening, + "momentum_buffer": torch.zeros_like( + init_duals, requires_grad=False, device=device + ), + } + settings_dict = {k: v for k, v in settings_dict.items() if v is not None} + + param_group = ([duals], settings_dict) + return param_group + + +def _update_duals( + duals: Tensor, + constraints: Tensor, + lr: float, + momentum: float, + dampening: float, + buffer: Tensor, +) -> None: if momentum == 0: buffer = constraints else: - buffer.mul_(momentum).add_(constraints, alpha = 1 - dampening) - duals.add_(buffer, alpha = lr) + buffer.mul_(momentum).add_(constraints, alpha=1 - dampening) + duals.add_(buffer, alpha=lr) diff --git a/src/humancompatible/train/dual_optim/barrier.py b/src/humancompatible/train/dual_optim/barrier.py index 6a78d26..512444c 100644 --- a/src/humancompatible/train/dual_optim/barrier.py +++ b/src/humancompatible/train/dual_optim/barrier.py @@ -8,23 +8,22 @@ """ + def exponential_penalty(t): return torch.exp(t) - 1.0 + def modified_log_barrier(t): - return -torch.log(1-t) + return -torch.log(1 - t) + def augmented_lagrangian(t): """ Vectorized version of augmented_lagrangian """ - return torch.where( - t >= -1, - t + 0.5 * torch.square(t), - -0.5 * torch.ones_like(t) - ) + return torch.where(t >= -1, t + 0.5 * torch.square(t), -0.5 * torch.ones_like(t)) def quad_log(t): @@ -36,7 +35,7 @@ def quad_log(t): mask = t >= -0.5 out[mask] = t[mask] + 0.5 * torch.pow(t[mask], 2) - out[~mask] = -0.25 * torch.log(-2 * t[~mask]) - 3/8 + out[~mask] = -0.25 * torch.log(-2 * t[~mask]) - 3 / 8 return out @@ -48,27 +47,26 @@ def quad_recipr(t): out = torch.empty_like(t) - mask = t >= -1/3 + mask = t >= -1 / 3 out[mask] = t[mask] + 0.5 * torch.pow(t[mask], 2) - out[~mask] = (32/27) * (1 / (1 - t[~mask])) - 7/6 + out[~mask] = (32 / 27) * (1 / (1 - t[~mask])) - 7 / 6 + + return out - return out def exponential_penalty_derivative(t): - return torch.exp(t) + return torch.exp(t) + def modified_log_barrier_derivative(t): - return 1 / (1-t) + return 1 / (1 - t) + def aug_lagr_der(t): - return torch.where( - t >= -1, - 1 + t, - torch.zeros_like(t) - ) + return torch.where(t >= -1, 1 + t, torch.zeros_like(t)) def quad_log_der(t): @@ -81,12 +79,13 @@ def quad_log_der(t): return out + def quad_recipr_der(t): out = torch.empty_like(t) - mask = t >= -1/3 + mask = t >= -1 / 3 out[mask] = 1 + t[mask] - out[~mask] = (32/27) * (1 / torch.square(1 - t[~mask])) + out[~mask] = (32 / 27) * (1 / torch.square(1 - t[~mask])) - return out \ No newline at end of file + return out diff --git a/src/humancompatible/train/dual_optim/ialm.py b/src/humancompatible/train/dual_optim/ialm.py new file mode 100644 index 0000000..82d1965 --- /dev/null +++ b/src/humancompatible/train/dual_optim/ialm.py @@ -0,0 +1,309 @@ +import torch +from torch.nn import Parameter +from torch.optim import Optimizer +from typing import Any, Tuple +from torch import clamp_, Tensor + +# cite: Stochastic inexact augmented Lagrangian method for nonconvex expectation constrained optimization +# https://link.springer.com/content/pdf/10.1007/s10589-023-00521-z.pdf + + +class iALM(Optimizer): + def __init__( + self, + m: int = None, + lr: float = 0.01, + init_duals: float | Tensor = None, + penalty: float = 1.0, + *, + dual_range: Tuple[float, float] = (0.0, 100.0), + momentum: float = 0.0, + dampening: float = 0.0, + beta: float = 1.0, + sigma: float = 1.0, + gamma: float = 1.0, + ctol: float = 1e-4, + device=None, + ) -> None: + """ + A wrapper over a PyTorch`Optimizer` that works on the dual maximization tasks according to the Augmented Lagrangian rule. Creates and updates dual variables. + + :param m: Number of constraints (determines the number of dual variables to create) + :type m: int + :param lr: Dual variable update rate + :type lr: float + :param init_duals: Initial values for the new dual variables. Defaults to 0 for all. + :type init_duals: float | Tensor + :param penalty: Augmented Lagrangian penalty parameter. Defaults to`1.` + :type penalty: float + :param dual_range: Safeguarding range for dual variables; they will be`clamp`-ed to this range. + :type dual_range: Tuple[float, float] + :param momentum: Momentum/Smoothing factor for dual variables. Equivalent to SGD momentum. Set to `0` to disable. + :type momentum: float + :param dampening: Dampening for momentum. Equivalent to SGD dampening. Set to `0` to disable. + :type dampening: float + :param beta: Dual variable update rate + :type beta: float + :param sigma: Multiplier for increasing`beta`. + :type sigma: float + :param gamma: Penalty update parameter + :type gamma: float + :param ctol: Constraint tolerance; value that allows tiny violations of constraints to account for noise. + :type ctol: float + """ + + if momentum > 0 and dampening == 0: + dampening = momentum + + self.dual_range = dual_range + + self.beta = beta + self.penalty = penalty + self.gamma = gamma + self.sigma = sigma + self.ctol = ctol + + duals, defaults = self._init_constraint_group( + m, lr, momentum, dampening, init_duals, dual_range, device + ) + + super().__init__(duals, defaults) + + @staticmethod + def _init_constraint_group( + m: int = None, + lr: float = None, + momentum: float = None, + dampening: float = None, + init_duals: float | Tensor = None, + dual_range: Tuple[float, float] = None, + device=None, + ): + ## checks ## + if init_duals is None and m is None: + raise ValueError("At least one of`m`,`init_duals` must be set") + + if momentum is not None and (momentum < 0 or momentum > 1): + raise ValueError(f"`momentum`must be within [0,1]; got {momentum}") + + m = m if m is not None else len(init_duals) + + if init_duals is None: # initialize duals if not set or set to scalar + init_duals = ( + torch.zeros(m, requires_grad=False, device=device) + dual_range[0] + ) + elif isinstance(init_duals, float): + init_duals = torch.zeros(m, requires_grad=False, device=device) + init_duals + + duals = Parameter(init_duals, requires_grad=False) + + settings_dict = { + "lr": lr, + "momentum": momentum, + "dampening": dampening, + "momentum_buffer": torch.zeros_like( + init_duals, requires_grad=False, device=device + ), + } + settings_dict = {k: v for k, v in settings_dict.items() if v is not None} + + param_group = ([duals], settings_dict) + return param_group + + @property + def duals(self) -> Tensor: + """ + :return: Dual variables, concatenated into a single tensor. + :rtype: Tensor + """ + return torch.cat([group["params"][0] for group in self.param_groups]) + + def add_constraint_group( + self, + m: int = None, + lr: float = None, + momentum: float = None, + dampening: float = None, + init_duals: Tensor = None, + ) -> None: + """ + Allows to add a group of dual variables with separate initial values and learning rates. + + :param m: Size of group (number of dual variables to add) + :type m: int + :param lr: Dual variable update rate + :type lr: float + :param init_duals: Initial values for the new dual variables + :type init_duals: Tensor + """ + duals, settings_dict = self._init_constraint_group( + m, lr, momentum, dampening, init_duals, self.dual_range + ) + param_group_dict = {"params": duals, **settings_dict} + self.add_param_group(param_group_dict) + + def forward(self, loss: Tensor, constraints: Tensor) -> Tensor: + """ + Calculates and returns the Augmented Lagrangian. + + :param loss: Loss (objective function) value + :type loss: Tensor + :param constraints: Tensor of constraint values + :type constraints: Tensor + :return: Lagrangian + :rtype: Tensor + """ + lagrangian = torch.zeros_like(loss) + lagrangian.add_(loss) + for i, group in enumerate(self.param_groups): + duals, lr, momentum, dampening, momentum_buffer = ( + group["params"][0], + group["lr"], + group["momentum"], + group["dampening"], + group["momentum_buffer"], + ) + group_constraints = ( + constraints[i * len(duals) : (i + 1) * len(duals)] - self.ctol + ) + lagrangian.add_(duals @ group_constraints) + + _update_c_buffers(group_constraints, momentum, dampening, momentum_buffer) + + lagrangian.add_(0.5 * self.beta * torch.dot(constraints, constraints)) + + return lagrangian + + def update(self, constraints: Tensor) -> None: + """ + Updates the dual variables + + :param constraints: Tensor of constraint values + :type constraints: Tensor + """ + for i, group in enumerate(self.param_groups): + duals, lr, momentum, dampening, momentum_buffer = ( + group["params"][0], + group["lr"], + group["momentum"], + group["dampening"], + group["momentum_buffer"], + ) + group_constraints = ( + constraints[i * len(duals) : (i + 1) * len(duals)] - self.ctol + ) + with torch.no_grad(): + _update_duals( + duals, + group_constraints, + lr, + self.beta, + self.gamma, + momentum, + dampening, + momentum_buffer, + ) + clamp_(duals, min=self.dual_range[0], max=self.dual_range[1]) + + self.beta *= self.sigma + + # evaluate the Lagrangian and update the dual variables + def forward_update(self, loss: Tensor, constraints: Tensor) -> Tensor: + """ + Combines `forward` and `update`; slightly faster. + + :param loss: Loss (objective function) value + :type loss: Tensor + :param constraints: Tensor of constraint values + :type constraints: Tensor + :return: Lagrangian + :rtype: Tensor + """ + lagrangian = torch.zeros_like(loss) + lagrangian.add_(loss) + for i, group in enumerate(self.param_groups): + duals, lr, momentum, dampening, momentum_buffer = ( + group["params"][0], + group["lr"], + group["momentum"], + group["dampening"], + group["momentum_buffer"], + ) + group_constraints = ( + constraints[i * len(duals) : (i + 1) * len(duals)] - self.ctol + ) + with torch.no_grad(): + _update_c_buffers( + group_constraints, momentum, dampening, momentum_buffer + ) + _update_duals( + duals, + group_constraints, + lr, + self.beta, + self.gamma, + momentum, + dampening, + momentum_buffer, + ) + clamp_(duals, min=self.dual_range[0], max=self.dual_range[1]) + + lagrangian.add_(duals @ group_constraints) + + lagrangian.add_( + 0.5 + * self.beta + * torch.dot(constraints - self.ctol, constraints - self.ctol) + ) + + self.beta *= self.sigma + + return lagrangian + + def state_dict(self) -> dict[str, Any]: + + state_dict = super().state_dict() + state_dict["state"]["penalty"] = self.penalty + state_dict["state"]["dual_range"] = self.dual_range + # save params themselves in state_dict instead of param ID in default PyTorch + for id_pg, pg in enumerate(state_dict["param_groups"]): + pg["params"] = [ + self.param_groups[id_pg]["params"][param_id] + for param_id in pg["params"] + ] + return state_dict + + def load_state_dict(self, state_dict: dict[str, Any]) -> None: + self.penalty = state_dict["state"]["penalty"] + self.dual_range = state_dict["state"]["dual_range"] + params = state_dict["param_groups"] + self.param_groups = [] + for param in params: + self.param_groups.append(param) + + +def _update_c_buffers( + constraints: Tensor, + momentum: float, + dampening: float, + buffer: Tensor, +): + if momentum == 0: + buffer = constraints + else: + buffer.mul_(momentum).add_(constraints, alpha=1 - dampening) + + +def _update_duals( + duals: Tensor, + constraints: Tensor, + lr: float, + beta: float, + gamma: float, + momentum: float, + dampening: float, + buffer: Tensor, +) -> None: + + update_mult = min(beta, gamma / (buffer @ buffer)) + duals.add_(buffer, alpha=update_mult) diff --git a/src/humancompatible/train/dual_optim/moreau.py b/src/humancompatible/train/dual_optim/moreau.py index a964857..a6269f7 100644 --- a/src/humancompatible/train/dual_optim/moreau.py +++ b/src/humancompatible/train/dual_optim/moreau.py @@ -1,6 +1,5 @@ import torch from torch.optim import Optimizer -from torch.optim.optimizer import _use_grad_for_differentiable class MoreauEnvelope(Optimizer): @@ -8,7 +7,7 @@ def __init__( self, optimizer: torch.optim.Optimizer, *, - mu: float = 2., + mu: float = 2.0, beta: float = 0.5, ) -> None: """ @@ -26,24 +25,33 @@ def __init__( self.mu, self.beta = mu, beta if mu < 0: - raise ValueError(f"The smoothing parameter`mu`must be non-negative, got {mu}.") + raise ValueError( + f"The smoothing parameter`mu`must be non-negative, got {mu}." + ) else: self.smoothing_buffer = [] for param_group in optimizer.param_groups: - self.smoothing_buffer.append({'params': []}) - for _, param in enumerate(param_group['params']): - self.smoothing_buffer[-1]['params'].append(param.clone().detach()) + self.smoothing_buffer.append({"params": []}) + for _, param in enumerate(param_group["params"]): + self.smoothing_buffer[-1]["params"].append(param.clone().detach()) def step(self) -> None: with torch.no_grad(): # add smoothing term gradient to the gradient w.r.t. primal params, and update smoothing params before optimizer step - for param_group, smoothing_buffer_group in zip(self.optimizer.param_groups, self.smoothing_buffer): - for param, smoothing_buffer in zip(param_group["params"], smoothing_buffer_group['params']): - param.grad.add_(param, alpha=self.mu).add_(smoothing_buffer, alpha=-self.mu) - smoothing_buffer.add_(smoothing_buffer, alpha=-self.beta).add_(param, alpha=self.beta) - - self.optimizer.step() + for param_group, smoothing_buffer_group in zip( + self.optimizer.param_groups, self.smoothing_buffer + ): + for param, smoothing_buffer in zip( + param_group["params"], smoothing_buffer_group["params"] + ): + param.grad.add_(param, alpha=self.mu).add_( + smoothing_buffer, alpha=-self.mu + ) + smoothing_buffer.add_(smoothing_buffer, alpha=-self.beta).add_( + param, alpha=self.beta + ) + self.optimizer.step() def __getattr__(self, name): # Delegate to the wrapped object @@ -54,8 +62,10 @@ def __getattr__(self, name): else: # If it's a method, bind it to self.optimizer if callable(attr): + def method(*args, **kwargs): return attr(self.optimizer, *args, **kwargs) + return method else: return attr @@ -72,7 +82,7 @@ def method(*args, **kwargs): # def load_state_dict(self, state_dict: dict[str, Any]) -> None: # primal_state_dict = state_dict["primal"] # self.primal_optimizer.load_state_dict(primal_state_dict) - + # dual_state_dict = state_dict["dual"] # self.penalty = dual_state_dict["state"]["penalty"] # self.dual_range = dual_state_dict["state"]["dual_range"] @@ -80,4 +90,4 @@ def method(*args, **kwargs): # params = dual_state_dict["param_groups"] # self.param_groups = [] # for param in params: - # self.param_groups.append(param) \ No newline at end of file + # self.param_groups.append(param) diff --git a/src/humancompatible/train/dual_optim/pbm.py b/src/humancompatible/train/dual_optim/pbm.py index 6ee38ec..07c7665 100644 --- a/src/humancompatible/train/dual_optim/pbm.py +++ b/src/humancompatible/train/dual_optim/pbm.py @@ -1,15 +1,10 @@ import torch from torch.nn import Parameter from torch.optim import Optimizer -from typing import Any, Iterable, Tuple, Callable -from torch import clamp_, Tensor, no_grad -from torch.optim.optimizer import _use_grad_for_differentiable -from .barrier import ( - quad_log, - quad_log_der, - quad_recipr, - quad_recipr_der -) +from typing import Any, Tuple, Callable +from torch import clamp_, Tensor +from .barrier import quad_log, quad_log_der, quad_recipr, quad_recipr_der + class PBM(Optimizer): def __init__( @@ -18,15 +13,15 @@ def __init__( penalty_mult: float = 0.1, gamma: float = 0.9, delta: float = 0.9, - penalty_update: str = 'dimin_adapt', + penalty_update: str = "dimin_adapt", *, - pbf: str = 'quadratic_logarithmic', + pbf: str = "quadratic_logarithmic", init_duals: float | Tensor = None, init_penalties: float | Tensor = None, - dual_range: Tuple[float, float] = (0.0001, 100.), - penalty_range: Tuple[float, float] = (0.1, 100.), - device = None, - primal_update_process_length=1, # length of the primal update process - if =1, is the original algorithm + dual_range: Tuple[float, float] = (0.0001, 100.0), + penalty_range: Tuple[float, float] = (0.1, 2.0), + device=None, + primal_update_process_length=1, # length of the primal update process - if =1, is the original algorithm ) -> None: """ A wrapper over a PyTorch`Optimizer` that works on the dual maximization tasks according to the Penalty-Barrier Method rule. Creates and updates dual variables. @@ -38,7 +33,7 @@ def __init__( :param gamma: Multiplier for dual parameter update. Values close to 1 correspond to a high "momentum". :type gamma: float :param delta: Violation/satisfaction parameter for penalty update; values > 1 make the penalties decrease faster on violated constraints and vice versa. - :type delta: float + :type delta: float :param penalty_update: Penalty update strategy; must be one of `dimin`,`dimin_dual`,`dimin_adapt`,`const`. Defaults to`dimin_adapt`. :type penalty_update: str :param pbf: Penalty-Barrier Function to use. Must be one of `quadratic_logarithmic`,`quadratic_reciprocal` @@ -54,7 +49,20 @@ def __init__( self.dual_range = dual_range self.penalty_range = penalty_range - params, defaults = self._init_constraint_group(m, penalty_mult, penalty_update, delta, pbf, init_duals, init_penalties, gamma, dual_range, penalty_range, primal_update_process_length, device) + params, defaults = self._init_constraint_group( + m, + penalty_mult, + penalty_update, + delta, + pbf, + init_duals, + init_penalties, + gamma, + dual_range, + penalty_range, + primal_update_process_length, + device, + ) self.iter = 0 super().__init__(params, defaults) @@ -71,33 +79,45 @@ def _init_constraint_group( dual_range: Tuple[float, float] = None, penalty_range: Tuple[float, float] = None, primal_update_process_length: int = 1, - device = None + device=None, ): if init_duals is None and m is None: raise ValueError("At least one of`size`,`init_duals` must be set") - - if init_duals is None or isinstance(init_duals, (int, float)): # initialize duals if not set or set to scalar - init_duals = torch.zeros(m, requires_grad=False, device=device) + (init_duals if isinstance(init_duals, (int, float)) else dual_range[0]) - if init_penalties is None or isinstance(init_penalties, (int, float)): # initialize penalties if not set or set to scalar - init_penalties = torch.zeros(m, requires_grad=False, device=device) + (init_penalties if isinstance(init_penalties, (int, float)) else penalty_range[1]) + + if init_duals is None or isinstance( + init_duals, (int, float) + ): # initialize duals if not set or set to scalar + init_duals = torch.zeros(m, requires_grad=False, device=device) + ( + init_duals if isinstance(init_duals, (int, float)) else dual_range[0] + ) + if init_penalties is None or isinstance( + init_penalties, (int, float) + ): # initialize penalties if not set or set to scalar + init_penalties = torch.zeros(m, requires_grad=False, device=device) + ( + init_penalties + if isinstance(init_penalties, (int, float)) + else penalty_range[1] + ) duals = Parameter(init_duals, requires_grad=False) penalties = Parameter(init_penalties, requires_grad=False) - + primal_update_process_length = primal_update_process_length - if penalty_update == 'dimin': + if penalty_update == "dimin": penalty_update_f = _update_penalties_dimin - elif penalty_update == 'dimin_dual': + elif penalty_update == "dimin_dual": penalty_update_f = _update_penalties_dimin_dual - elif penalty_update == 'dimin_adapt': + elif penalty_update == "dimin_adapt": penalty_update_f = _update_penalties_adapt - elif penalty_update == 'const': + elif penalty_update == "const": penalty_update_f = _update_penalties_const + elif penalty_update == "aimd": + penalty_update_f = _update_penalties_aimd elif penalty_update is None: penalty_update_f = None else: - raise ValueError(f'Unknown penalty update function: {penalty_update}!') + raise ValueError(f"Unknown penalty update function: {penalty_update}!") settings_dict = { "p_mult": p_mult, @@ -106,9 +126,11 @@ def _init_constraint_group( "pbf": pbf, "dual_momentum": dual_momentum, "primal_update_process_length": primal_update_process_length, - "dual_momentum_buffer": torch.zeros_like(init_duals, requires_grad = False, device=device), + "dual_momentum_buffer": torch.zeros_like( + init_duals, requires_grad=False, device=device + ), } - settings_dict = {k:v for k,v in settings_dict.items() if v is not None} + settings_dict = {k: v for k, v in settings_dict.items() if v is not None} param_group = ([duals, penalties], settings_dict) @@ -123,7 +145,7 @@ def duals(self) -> Tensor: :rtype: Tensor """ return torch.cat([group["params"][0] for group in self.param_groups]) - + @property def penalties(self) -> Tensor: """ @@ -145,7 +167,7 @@ def add_constraint_group( init_penalties: float | Tensor = None, *, momentum: float = None, - primal_update_process_length: int = 1 + primal_update_process_length: int = 1, ) -> None: """ Adds an additional group of dual variables with separate hyperparameters and barrier functions. @@ -169,9 +191,20 @@ def add_constraint_group( :param primal_update_process_length: Length of the primal update process for this group. If 1 (default), uses original algorithm variant. :type primal_update_process_length: int """ - - - params, settings_dict = self._init_constraint_group(m, penalty_mult, penalty_update, delta, pbf, init_duals, init_penalties, momentum, self.dual_range, self.penalty_range, primal_update_process_length) + + params, settings_dict = self._init_constraint_group( + m, + penalty_mult, + penalty_update, + delta, + pbf, + init_duals, + init_penalties, + momentum, + self.dual_range, + self.penalty_range, + primal_update_process_length, + ) param_group_dict = {"params": params, **settings_dict} self.add_param_group(param_group_dict) @@ -186,13 +219,37 @@ def update(self, constraints: Tensor) -> None: """ for i, group in enumerate(self.param_groups): - duals, penalties, p_mult, _update_penalties, delta, pbf, momentum, primal_update_process_length = group["params"][0], group["params"][1], group["p_mult"], group["penalty_update"], group["delta"], group["pbf"], group['dual_momentum'], group["primal_update_process_length"] + ( + duals, + penalties, + p_mult, + _update_penalties, + delta, + pbf, + momentum, + primal_update_process_length, + ) = ( + group["params"][0], + group["params"][1], + group["p_mult"], + group["penalty_update"], + group["delta"], + group["pbf"], + group["dual_momentum"], + group["primal_update_process_length"], + ) group_constraints = constraints[i * len(duals) : (i + 1) * len(duals)] cdivp = group_constraints.div(penalties) with torch.no_grad(): - _update_duals(duals, cdivp, penalty_barrier_funcs[pbf]['d'], momentum) + _update_duals(duals, cdivp, penalty_barrier_funcs[pbf]["d"], momentum) clamp_(duals, min=self.dual_range[0], max=self.dual_range[1]) - _update_penalties(penalties, p_mult, duals, penalty_barrier_funcs[pbf]['d'](group_constraints), delta) + _update_penalties( + penalties, + p_mult, + duals, + penalty_barrier_funcs[pbf]["d"](group_constraints), + delta, + ) clamp_(penalties, min=self.penalty_range[0], max=self.penalty_range[1]) def forward(self, loss: Tensor, constraints: Tensor) -> Tensor: @@ -214,7 +271,7 @@ def forward(self, loss: Tensor, constraints: Tensor) -> Tensor: group_constraints = constraints[i * len(duals) : (i + 1) * len(duals)] # calculate lagrangian cdivp = group_constraints.div(penalties) - pbf_val = penalty_barrier_funcs[pbf]['f'](cdivp) + pbf_val = penalty_barrier_funcs[pbf]["f"](cdivp) lagrangian.add_(duals.mul(penalties) @ pbf_val) return lagrangian @@ -240,28 +297,64 @@ def forward_update(self, loss: Tensor, constraints: Tensor) -> Tensor: lagrangian.add_(loss) _last_c_group_index = 0 for i, group in enumerate(self.param_groups): - duals, penalties, p_mult, _update_penalties, delta, pbf, momentum, primal_update_process_length = group["params"][0], group["params"][1], group["p_mult"], group["penalty_update"], group["delta"], group["pbf"], group['dual_momentum'], group["primal_update_process_length"] - group_constraints = constraints[_last_c_group_index : _last_c_group_index + len(duals)] + ( + duals, + penalties, + p_mult, + _update_penalties, + delta, + pbf, + momentum, + primal_update_process_length, + ) = ( + group["params"][0], + group["params"][1], + group["p_mult"], + group["penalty_update"], + group["delta"], + group["pbf"], + group["dual_momentum"], + group["primal_update_process_length"], + ) + group_constraints = constraints[ + _last_c_group_index : _last_c_group_index + len(duals) + ] _last_c_group_index = _last_c_group_index + len(duals) # calculate lagrangian - if self.iter + 1 == primal_update_process_length: # this enables a second variant of the algorithm + if ( + self.iter + 1 == primal_update_process_length + ): # this enables a second variant of the algorithm # update duals and penalties cdivp = group_constraints.div(penalties) with torch.no_grad(): - _update_duals(duals, cdivp, penalty_barrier_funcs[pbf]['d'], momentum) + _update_duals( + duals, cdivp, penalty_barrier_funcs[pbf]["d"], momentum + ) clamp_(duals, min=self.dual_range[0], max=self.dual_range[1]) - _update_penalties(penalties, p_mult, duals, penalty_barrier_funcs[pbf]['d'](group_constraints), delta) - clamp_(penalties, min=self.penalty_range[0], max=self.penalty_range[1]) + _update_penalties( + penalties, + p_mult, + duals, + penalty_barrier_funcs[pbf]["d"](group_constraints), + delta, + ) # , cdivp) + clamp_( + penalties, min=self.penalty_range[0], max=self.penalty_range[1] + ) cdivp = group_constraints.div(penalties) - pbf_val = penalty_barrier_funcs[pbf]['f'](cdivp) - lagrangian.add_(duals.mul(penalties) @ pbf_val) + pbf_val = penalty_barrier_funcs[pbf]["f"](cdivp) + + # change duals to 0 for them < 1e-4, but do not overwrite the actual duals to keep the momentum working + active = duals >= 1e-5 + if active.any(): + lagrangian.add_(duals[active].mul(penalties[active]) @ pbf_val[active]) # update the iter self.iter = (self.iter + 1) % primal_update_process_length return lagrangian - + def update_penalties(self, constraints: Tensor) -> None: """ Updates penalties according to the specified penalty update strategy for each constraint group. @@ -272,12 +365,22 @@ def update_penalties(self, constraints: Tensor) -> None: :rtype: None """ for i, group in enumerate(self.param_groups): - duals, penalties, p_mult, _update_penalties, pbf = group["params"][0], group["params"][1], group["p_mult"], group["penalty_update"], group['pbf'] + duals, penalties, p_mult, _update_penalties, pbf = ( + group["params"][0], + group["params"][1], + group["p_mult"], + group["penalty_update"], + group["pbf"], + ) group_constraints = constraints[i * len(duals) : (i + 1) * len(duals)] - _update_penalties(penalties, p_mult, duals, penalty_barrier_funcs[pbf]['d'](group_constraints)) + _update_penalties( + penalties, + p_mult, + duals, + penalty_barrier_funcs[pbf]["d"](group_constraints), + ) clamp_(penalties, min=self.penalty_range[0], max=self.penalty_range[1]) - def state_dict(self) -> dict[str, Any]: """ Returns the state of the optimizer as a dictionary, including dual and penalty ranges and all constraint groups. @@ -285,13 +388,13 @@ def state_dict(self) -> dict[str, Any]: :return: Dictionary containing optimizer state with param groups and configuration. :rtype: dict[str, Any] """ - + state_dict = super().state_dict() state_dict["state"]["penalty_range"] = self.penalty_range state_dict["state"]["dual_range"] = self.dual_range # save params themselves in state_dict instead of param ID in default PyTorch - for id_pg, pg in enumerate(state_dict['param_groups']): - pg['params'] = self.param_groups[id_pg]['params'] + for id_pg, pg in enumerate(state_dict["param_groups"]): + pg["params"] = self.param_groups[id_pg]["params"] return state_dict def load_state_dict(self, state_dict: dict[str, Any]) -> None: @@ -311,35 +414,75 @@ def load_state_dict(self, state_dict: dict[str, Any]) -> None: self.param_groups.append(param) - penalty_barrier_funcs = { - 'quadratic_logarithmic': {'f': quad_log, 'd': quad_log_der}, - 'quadratic_reciprocal': {'f': quad_recipr, 'd': quad_recipr_der} + "quadratic_logarithmic": {"f": quad_log, "d": quad_log_der}, + "quadratic_reciprocal": {"f": quad_recipr, "d": quad_recipr_der}, } -def _update_duals(duals: Tensor, cdivp: Tensor, pbf_der: Callable, gamma: float) -> None: + +def _update_duals( + duals: Tensor, cdivp: Tensor, pbf_der: Callable, gamma: float +) -> None: pbf_der_val = pbf_der(cdivp) upd = pbf_der_val.mul(duals) - duals.mul_(gamma).add_(upd, alpha=1-gamma) + duals.mul_(gamma).add_(upd, alpha=1 - gamma) + -def _update_penalties_const(penalties: Tensor, p_mult: Tensor = None, duals: Tensor = None, phi_der: Tensor = None, delta: float = None): +def _update_penalties_const( + penalties: Tensor, + p_mult: Tensor = None, + duals: Tensor = None, + phi_der: Tensor = None, + delta: float = None, +): pass -def _update_penalties_dimin(penalties: Tensor, p_mult: Tensor, duals: Tensor = None, phi_der: Tensor = None, delta: float = None): + +def _update_penalties_dimin( + penalties: Tensor, + p_mult: Tensor, + duals: Tensor = None, + phi_der: Tensor = None, + delta: float = None, +): penalties.mul_(p_mult) -def _update_penalties_adapt(penalties: Tensor, p_mult: Tensor, duals: Tensor, phi_der: Tensor, delta: float): - d_phd = torch.where(phi_der < 1., phi_der, torch.clamp(delta * phi_der, min=1.0)) - b = (1-p_mult)*penalties/(d_phd + 1e-8) + +def _update_penalties_adapt( + penalties: Tensor, p_mult: Tensor, duals: Tensor, phi_der: Tensor, delta: float +): + d_phd = torch.where(phi_der < 1.0, phi_der, delta * phi_der) + b = (1 - p_mult) * penalties / (d_phd + 1e-8) penalties.mul_(p_mult).add_(b) -def _update_penalties_dimin_dual(penalties: Tensor, p_mult: Tensor, duals: Tensor, phi_der: Tensor = None, delta: float = None): + +def _update_penalties_aimd( + penalties: Tensor, + p_mult: Tensor, + duals: Tensor, + phi_der: Tensor, + delta: float, + cdivp: Tensor, +): + p_add_rate = 0.1 + p_upd_add = torch.where(cdivp <= 0.0, p_add_rate, 0.0) + p_upd_mult = torch.where(cdivp > 0.0, p_mult, 1.0) + penalties.add_(p_upd_add).mul_(p_upd_mult) + + +def _update_penalties_dimin_dual( + penalties: Tensor, + p_mult: Tensor, + duals: Tensor, + phi_der: Tensor = None, + delta: float = None, +): penalties.mul_(p_mult).mul_(duals) penalty_update_funcs = { - 'const': _update_penalties_const, - 'dimin': _update_penalties_dimin, - 'adapt': _update_penalties_adapt, - 'dimin_dual': _update_penalties_dimin_dual -} \ No newline at end of file + "const": _update_penalties_const, + "dimin": _update_penalties_dimin, + "adapt": _update_penalties_adapt, + "dimin_dual": _update_penalties_dimin_dual, +} diff --git a/src/humancompatible/train/fairness/utils/balanced_batch_sampler.py b/src/humancompatible/train/fairness/utils/balanced_batch_sampler.py index fe6fcd8..2fb8370 100644 --- a/src/humancompatible/train/fairness/utils/balanced_batch_sampler.py +++ b/src/humancompatible/train/fairness/utils/balanced_batch_sampler.py @@ -71,12 +71,15 @@ def __iter__(self): # determine number of tiles if not self._extend_groups or group_id not in self._extend_groups: num_tiles = 1 + tile_size = len(group_indices) else: - num_tiles = ceil(max(self._group_sizes) / self._group_sizes[group_id]) - # tile with random reorderings of list of indices of the group + num_tiles = ceil(max(self._group_sizes) / self._n_samples_per_group) + tile_size = self._n_samples_per_group + # num_tiles = ceil(max(self._group_sizes) / self._group_sizes[group_id]) + # tile with random n_samples_per_group-sized reorderings of list of indices of the group for _ in range(num_tiles): - # shuffle - indices_shuffled = torch.randperm(len(group_indices), generator=self.generator).tolist() + # shuffle within tile + indices_shuffled = torch.randperm(len(group_indices), generator=self.generator).tolist()[: tile_size] # add new shuffled tile to the indices group_indices_tiled_shuffled.extend(indices_shuffled) # cutoff at the length of max group @@ -95,7 +98,6 @@ def __iter__(self): for indices in self._group_indices ): max_batches += 1 # Include partial batches if drop_last is False - # Yield balanced batches for batch_idx in range(max_batches): batch = [] diff --git a/tests/test_balanced_batch_sampler.py b/tests/test_balanced_batch_sampler.py index 4c3afe8..62ae520 100644 --- a/tests/test_balanced_batch_sampler.py +++ b/tests/test_balanced_batch_sampler.py @@ -23,6 +23,35 @@ def setUp(self): ] ).T + + def test_extends_without_replacement(self): + for _ in range(10): # Run multiple times to check randomness + sampler = BalancedBatchSampler( + group_indices=self.subset_indices, + batch_size=6, + drop_last=True, + extend_groups=[0, 1], + ) + # Check that each extended group is sampled without replacement within a batch + batches = [] + for i, batch in enumerate(sampler): + self.assertEqual(len(set(batch)), len(batch)) + batches.extend(batch) + + + def test_extend_num_batches(self): + sampler = BalancedBatchSampler( + group_indices=self.subset_indices, + batch_size=6, + drop_last=True, + extend_groups=[0, 1], + ) + # check correct number of batches in case of tiling + i = 0 + for _ in iter(sampler): + i += 1 + self.assertEqual(i, 2) + def test_batch_size_divisible(self): with self.assertRaises(AssertionError): BalancedBatchSampler( @@ -72,32 +101,7 @@ def test_balanced_extended_batches(self): self.assertEqual(len([i for i in batch if i in self.subset_indices[1]]), 2) self.assertEqual(len([i for i in batch if i in self.subset_indices[2]]), 2) - def test_extends_without_replacement(self): - sampler = BalancedBatchSampler( - group_indices=self.subset_indices, - batch_size=6, - drop_last=True, - extend_groups=[0, 1], - ) - # Check that each extended group is sampled without replacement within a batch - batches = [] - for i, batch in enumerate(sampler): - self.assertEqual(len(set(batch)), len(batch)) - batches.extend(batch) - - - def test_extend_num_batches(self): - sampler = BalancedBatchSampler( - group_indices=self.subset_indices, - batch_size=6, - drop_last=True, - extend_groups=[0, 1], - ) - # check correct number of batches in case of tiling - i = 0 - for _ in iter(sampler): - i += 1 - self.assertEqual(i, 2) + def test_extend_large_batchsize(self): # check AssertionError on batch_size / n_groups > size of one of the groups diff --git a/tests/test_pbm.py b/tests/test_pbm.py index 0f44648..a7fa2d7 100644 --- a/tests/test_pbm.py +++ b/tests/test_pbm.py @@ -114,7 +114,7 @@ class TestPenaltyUpdateStrategies(unittest.TestCase): def test_constant_penalty(self): """Test that penalty_update='const' keeps penalties constant during forward_update.""" init_penalty = 5.0 - pbm = PBM(m=3, penalty_update='const', init_penalties=init_penalty) + pbm = PBM(m=3, penalty_update='const', init_penalties=init_penalty, penalty_range=(0.1, 100.)) penalties_before = pbm.penalties.clone() loss = torch.tensor(1.0) @@ -453,8 +453,8 @@ def test_gamma_affects_dual_convergence_rate(self): def test_delta_affects_adaptive_penalty_behavior(self): """Test that delta parameter affects adaptive penalty update.""" - pbm_low_delta = PBM(m=1, penalty_update='dimin_adapt', delta=0.5, penalty_mult=0.8, init_penalties=10.0) - pbm_high_delta = PBM(m=1, penalty_update='dimin_adapt', delta=2.0, penalty_mult=0.8, init_penalties=10.0) + pbm_low_delta = PBM(m=1, penalty_update='dimin_adapt', delta=0.5, penalty_mult=0.8, init_penalties=10.0, penalty_range=(0.1, 100.)) + pbm_high_delta = PBM(m=1, penalty_update='dimin_adapt', delta=2.0, penalty_mult=0.8, init_penalties=10.0, penalty_range=(0.1, 100.)) loss = torch.tensor(1.0) # Constraint with significant violation diff --git a/tiny_image_net.py b/tiny_image_net.py new file mode 100644 index 0000000..8d44e1b --- /dev/null +++ b/tiny_image_net.py @@ -0,0 +1,385 @@ +import matplotlib.pyplot as plt +import logging +logging.basicConfig(level=logging.INFO) +from pathlib import Path +from tinyimagenet import TinyImageNet +from torchvision import transforms as T +from torch.utils.data import DataLoader +from torchvision import models +from torch import nn +from fairret.statistic import PositiveRate, TruePositiveRate, FalsePositiveRate, PositivePredictiveValue, FalseOmissionRate +from fairret.loss import NormLoss +from humancompatible.train.fairness.utils import BalancedBatchSampler +import torch +from tqdm import tqdm +import os +from typing import Callable, Any, Dict +from dataclasses import dataclass +import torch +import fairret +from humancompatible.train.dual_optim import ALM, MoreauEnvelope, PBM + + +def dataset_to_tensors(dataset, batch_size=512, num_workers=8): + """Fast parallel loading of an entire dataset into tensors.""" + loader = DataLoader( + dataset, + batch_size=batch_size, + shuffle=False, + num_workers=num_workers, + pin_memory=True, + ) + all_X, all_targets = [], [] + for X_batch, target_batch in tqdm(loader, desc="Loading dataset into tensors", total=len(loader)): + all_X.append(X_batch) + all_targets.append(target_batch) + return torch.cat(all_X, dim=0), torch.cat(all_targets, dim=0) + + +def load_or_cache(dataset, cache_path, batch_size=512, num_workers=8): + """Load tensors from cache if available, otherwise build and save.""" + if os.path.exists(cache_path): + print(f"Loading cache from {cache_path}...") + data = torch.load(cache_path, weights_only=True) + return data["X"], data["targets"] + + print(f"Building cache → {cache_path} (one-time cost)...") + X, targets = dataset_to_tensors(dataset, batch_size=batch_size, num_workers=num_workers) + torch.save({"X": X, "targets": targets}, cache_path) + print(f"Cache saved ({X.nbytes / 1e9:.2f} GB)") + return X, targets + + + +def train_tinyimagenet(): + + # define batch size here + batch_size = 1200 + + # define the path here + dataset_path="~/.torchvision/tinyimagenet/" + + + # define transforms function + normalize_transform = T.Compose([ T.ToTensor(), + T.Normalize(mean=TinyImageNet.mean, + std=TinyImageNet.std), + # Converting cropped images to tensors + ]) + train_transform = T.Compose([ T.Resize(256), # Resize images to 256 x 256 + T.CenterCrop(224), # Center crop image + T.RandomHorizontalFlip(), + normalize_transform + + ]) + + # --- Load datasets --- + train = TinyImageNet(Path(dataset_path), split="train", transform=train_transform, imagenet_idx=False) + val_full = TinyImageNet(Path(dataset_path), split="val", transform=normalize_transform, imagenet_idx=False) + + print(f"Dataset has {len(train.classes)} classes. Sample classes: {train.classes[:5]}") + + # --- Cache and split val into val/test --- + X_val_full, targets_val_full = load_or_cache(val_full, cache_path="./data/cache_val.pt") + + n = len(X_val_full) + idx = torch.randperm(n, generator=torch.Generator().manual_seed(42)) + split = n // 2 + val_idx, test_idx = idx[:split], idx[split:] + + raw_splits = { + "train": load_or_cache(train, cache_path="./data/cache_train.pt"), + "val": (X_val_full[val_idx], targets_val_full[val_idx]), + "test": (X_val_full[test_idx], targets_val_full[test_idx]), + } + + # --- Build loaders --- + loaders = {} + for name, (X, targets) in raw_splits.items(): + print(f"\nDataset: {name} | Size: {len(X)}") + print(f" X: {X.shape}, targets: {targets.shape}") + + groups_onehot = torch.eye(200)[targets] + dataset_torch = torch.utils.data.TensorDataset(X, groups_onehot, targets) + + group_counts = torch.bincount(targets, minlength=200) + print("Samples per group:", group_counts) + + sampler = BalancedBatchSampler( + group_onehot=groups_onehot, batch_size=batch_size, drop_last=True + ) + loaders[name] = torch.utils.data.DataLoader( + dataset_torch, batch_size=batch_size, shuffle=True, num_workers=4 + ) + loaders[name + "_balanced"] = torch.utils.data.DataLoader( + dataset_torch, batch_sampler=sampler, num_workers=4 + ) + print(f" Loaders created: '{name}' and '{name}_balanced'") + + # create fair dataloaders + + # ----- Build model, criterion, optimizer ----- + device = torch.device("cuda") + epochs = 5 + loader_name = "val_balanced" + + + # ----- Unconstrained Optimization Adam ----- + constraint_type = LossPairwise(loss=nn.CrossEntropyLoss(reduction='none')) + model = build_model().to(device) + criterion = nn.CrossEntropyLoss(reduction='none') # Unaggregated loss for fairness constraints + optimizer = torch.optim.Adam(model.parameters(), lr=0.001) + + history = {"train_loss": [], "train_acc": [], "val_loss": [], "val_acc": [], "max_constr": []} + + for epoch in range(1, epochs + 1): + train_loss, train_acc, max_constr = run_epoch(model, loaders[loader_name], + criterion, optimizer, device, + train=True) + val_loss, val_acc, max_constr = run_epoch(model, loaders["val_balanced"], + criterion, optimizer, device, + train=False) + # scheduler.step() + + history["train_loss"].append(train_loss) + history["train_acc"].append(train_acc) + history["val_loss"].append(val_loss) + history["val_acc"].append(val_acc) + history["max_constr"].append(max_constr) + print(f"Epoch {epoch:>3}/{epochs} | " + f"train loss {train_loss:.4f} acc {train_acc:.3f} | " + f"val loss {val_loss:.4f} acc {val_acc:.3f}" + f" | max constraint {max_constr:.4f}") + + + # ----- SPMB Optimization ----- + constraint_type = LossPairwise(loss=nn.CrossEntropyLoss(reduction='none')) + model = build_model().to(device) + criterion = nn.CrossEntropyLoss(reduction='none') # Unaggregated loss for fairness constraints + + # Define data and optimizers + optimizer = MoreauEnvelope(torch.optim.Adam(model.parameters(), lr=0.002), mu=2.0) + + dual = PBM( + m=39800, + # penalty_update='dimin', + # penalty_update='dimin_adapt', + penalty_update='const', + pbf = 'quadratic_reciprocal', + gamma=0.95, + init_duals=0.00001, + init_penalties=1., + penalty_range=(0.5, 1.), + penalty_mult=0.99, + dual_range=(0.000001, 100.), + delta=1.0, + device=device + ) + + history = {"train_loss": [], "train_acc": [], "val_loss": [], "val_acc": [], "max_constr": []} + + for epoch in range(1, epochs + 1): + train_loss, train_acc, max_constr = run_epoch(model, loaders[loader_name], + criterion, optimizer, device, + train=True, dual=dual) + val_loss, val_acc, max_constr = run_epoch(model, loaders["val_balanced"], + criterion, optimizer, device, + train=False, dual=dual) + # scheduler.step() + + # print numer of duals smaller than 1e-5 and larger than 100 + print('small duals', (dual.duals <= 1e-5).sum().item()) + print('large duals', (dual.duals >= 100.).sum().item()) + history["train_loss"].append(train_loss) + history["train_acc"].append(train_acc) + history["val_loss"].append(val_loss) + history["val_acc"].append(val_acc) + history["max_constr"].append(max_constr) + print(f"Epoch {epoch:>3}/{epochs} | " + f"train loss {train_loss:.4f} acc {train_acc:.3f} | " + f"val loss {val_loss:.4f} acc {val_acc:.3f}" + f" | max constraint {max_constr:.4f}") + + + # ----- SSLALM Optimization ----- + # constraint_type = LossPairwise(loss=nn.CrossEntropyLoss(reduction='none')) + # model = build_model().to(device) + # criterion = nn.CrossEntropyLoss(reduction='none') # Unaggregated loss for fairness constraints + + # # Define data and optimizers + # optimizer = MoreauEnvelope(torch.optim.Adam(model.parameters(), lr=0.005), mu=2.0) + + # dual = ALM( + # m=39800, + # lr=0.1, + # momentum=0.5, + # # penalty_update='dimin', + # device=device + # ) + + # history = {"train_loss": [], "train_acc": [], "val_loss": [], "val_acc": [], "max_constr": []} + + # for epoch in range(1, epochs + 1): + # train_loss, train_acc, max_constr = run_epoch(model, loaders[loader_name], + # criterion, optimizer, device, + # train=True, dual=dual) + # val_loss, val_acc, max_constr = run_epoch(model, loaders["val_balanced"], + # criterion, optimizer, device, + # train=False, dual=dual) + # # scheduler.step() + + # history["train_loss"].append(train_loss) + # history["train_acc"].append(train_acc) + # history["val_loss"].append(val_loss) + # history["val_acc"].append(val_acc) + # history["max_constr"].append(max_constr) + # print(f"Epoch {epoch:>3}/{epochs} | " + # f"train loss {train_loss:.4f} acc {train_acc:.3f} | " + # f"val loss {val_loss:.4f} acc {val_acc:.3f}" + # f" | max constraint {max_constr:.4f}") + + +def build_model(num_classes=200): + """EfficientNet-B0 from scratch (no pretrained weights).""" + model = models.efficientnet_b0(weights=None) + # Replace classifier head for 200-class TinyImageNet + in_features = model.classifier[1].in_features + model.classifier[1] = nn.Linear(in_features, num_classes) + return model + + +def run_epoch(model, loader, criterion, optimizer, device, train=True, dual=None): + model.train() if train else model.eval() + total_loss, correct, total, total_constr = 0.0, 0, 0, 0.0 + constraint_type = LossPairwise(loss=nn.CrossEntropyLoss(reduction='none')) + threshold = 0.1 # Example threshold for constraint violation + + ctx = torch.enable_grad() if train else torch.no_grad() + with ctx: + for x, sens, y in tqdm(loader, desc="train" if train else "eval", leave=False): + x, sens, y = x.to(device), sens.to(device), y.to(device) + + if train: + optimizer.zero_grad() + + if dual is None: + pred = model(x) + loss = criterion(pred, y) + # calculate the constraints + constraints = constraint_type.compute_constraints(None, None, sens, None, loss=loss) + constraints = constraints - threshold + max_constr = constraints.max().item() + + if train: + loss.mean().backward() # Aggregate loss for backward pass + optimizer.step() + + + elif dual is not None: + pred = model(x) + loss = criterion(pred, y) + constraints = constraint_type.compute_constraints(None, None, sens, None, loss=loss) + constraints = constraints - threshold + max_constr = constraints.max().item() + + # compute the lagrangian value + lagrangian = dual.forward_update(loss.mean(), constraints) + + if train: + lagrangian.backward() + optimizer.step() + optimizer.zero_grad() + + total_loss += loss.mean().item() * x.size(0) + correct += (pred.argmax(1) == y).sum().item() + total += x.size(0) + total_constr += max_constr + + return total_loss / total, correct / total, total_constr / len(loader) + + + + + + + +def positive_rate_per_group(out_batch, batch_sens, prob_f=torch.nn.functional.sigmoid): + """ + Calculates the positive rate vector based on the given outputs of the model for the given groups. + + """ + if prob_f is None: + preds = out_batch + else: + preds = prob_f( out_batch ) + pr = PositiveRate() + probs_per_group = pr(preds, batch_sens) + + return probs_per_group + +def posrate_per_group(model, out, batch_sens, batch_labels): + pos_rate_pergroup = positive_rate_per_group(out, batch_sens) + constraints = ((pos_rate_pergroup.unsqueeze(1) - pos_rate_pergroup.unsqueeze(0)).to(torch.float)) + mask = ~torch.eye(batch_sens.shape[-1], dtype=torch.bool) + constraints = constraints[mask] + + return constraints + +def posrate_fairret_constraint(model, out, batch_sens, batch_labels): + statistic = PositiveRate() + fair_criterion = NormLoss(statistic=statistic) + + return fair_criterion(out, batch_sens).unsqueeze(0) + +def weight_constraint(model, out, batch_sens, batch_labels): + norms = [] + for param in model.parameters(): + norm = torch.linalg.norm(param) + norms.append(norm.unsqueeze(0)) + + return torch.concat(norms) + + +@dataclass +class ConstraintMetadata: + """This class is a wrapper for fairness constraints; + it contains the function that computes the constraint + and a function that computes the number of constraints given the number of protected groups + (for example, if the constraint calculates a metric for each pair of groups,`m`would be`n_groups`* (`n_groups` - 1)).""" + fn: Callable[[Any, torch.Tensor, torch.Tensor, torch.Tensor], torch.Tensor] + m_fn: Callable[[int], int] + +class LossPairwise(ConstraintMetadata): + """Wrapper class for a fairness constraint that enforces equal loss across groups. + The constraint is computed as the pairwise difference between the losses for each group.""" + def __init__(self, loss: Callable = None, abs_diff: bool = False): + """ + Args: + loss (Callable): A function that computes the loss for each sample in the batch; must be **unaggregated** (i.e., reduction='none') + If not provided, the constraint will expect the loss to be precomputed and passed as an argument to the compute_constraints function. + """ + super().__init__( + fn=self.compute_constraints, + m_fn=lambda n_groups: n_groups * (n_groups - 1) if not self.abs_diff else n_groups * (n_groups - 1) // 2 + ) + self.abs_diff = abs_diff + if self.abs_diff: + raise NotImplementedError("abs_diff=True is not implemented yet.") + self.loss = loss + + def compute_constraints(self, model, batch_out, batch_sens, batch_labels, loss = None): + if loss is None: + loss = self.loss(batch_out, batch_labels) + + per_group_losses = _get_normalized_per_group_losses(loss, batch_sens).squeeze() + constraints = ((per_group_losses.unsqueeze(1) - per_group_losses.unsqueeze(0))) + mask = ~torch.eye(batch_sens.shape[-1], dtype=torch.bool) + constraints = constraints[mask] + return constraints + +def _get_normalized_per_group_losses(loss, sens_onehot): + return loss.unsqueeze(0) @ sens_onehot / sens_onehot.sum(dim=0) + + +if __name__ == "__main__": + train_tinyimagenet() \ No newline at end of file diff --git a/train_4_optimizers.py b/train_4_optimizers.py deleted file mode 100644 index 1e9b684..0000000 --- a/train_4_optimizers.py +++ /dev/null @@ -1,178 +0,0 @@ -"""Minimal constrained neural network training with 4 optimizers.""" -import torch -import torch.nn as nn -from torch.utils.data import TensorDataset, DataLoader -from folktables import ACSDataSource, generate_categories, ACSIncome -from sklearn.preprocessing import StandardScaler -from sklearn.model_selection import train_test_split - -# Import optimizers -from humancompatible.train.optim import SSG -from humancompatible.train.dual_optim import ALM, PBM, MoreauEnvelope - -# Setup -torch.manual_seed(0) -device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') -print(f"Using device: {device}") - -# Load folktables data -data_source = ACSDataSource(survey_year="2018", horizon="1-Year", survey="person") -acs_data = data_source.get_data(states=["VA"], download=True) -definition_df = data_source.get_definitions(download=True) -categories = generate_categories( - features=ACSIncome.features, definition_df=definition_df -) -df_feat, df_labels, _ = ACSIncome.df_to_pandas( - acs_data, categories=categories, dummies=True -) - -sens_cols = ["SEX_Female", "SEX_Male"] -features = df_feat.drop(columns=sens_cols).to_numpy(dtype="float") -groups = df_feat[sens_cols].to_numpy(dtype="float") -labels = df_labels.to_numpy(dtype="float") - -# Split and scale -X_train, X_test, y_train, y_test, groups_train, groups_test = train_test_split( - features, labels, groups, test_size=0.2, random_state=42 -) -X_train, X_val, y_train, y_val, groups_train, groups_val = train_test_split( - X_train, y_train, groups_train, test_size=0.25, random_state=42 -) - -scaler = StandardScaler() -X_train = scaler.fit_transform(X_train) -X_val = scaler.transform(X_val) - -# Convert to tensors and move to device -X_train = torch.tensor(X_train, dtype=torch.float32).to(device) -y_train = torch.tensor(y_train, dtype=torch.float32).to(device) -groups_train = torch.tensor(groups_train, dtype=torch.float32).to(device) - -# Create dataloader -dataset = TensorDataset(X_train, groups_train, y_train) -loader = DataLoader(dataset, batch_size=16, shuffle=False) - -def create_model(): - """Simple neural network.""" - return nn.Sequential( - nn.Linear(X_train.shape[1], 32), - nn.ReLU(), - nn.Linear(32, 1), - ).to(device) - -criterion = nn.BCEWithLogitsLoss() - -def get_constraint(model, groups): - """Constraint: positive rate difference between groups.""" - group_preds = [[] for _ in range(groups.shape[1])] - for i in range(groups.shape[1]): - group_preds[i] = (torch.sigmoid(model(X_train)) * groups[:, i].unsqueeze(1)).mean() - - # Return max difference in positive rates across groups - rates = torch.stack(group_preds) - return (rates.max() - rates.min()) - -# ============ 1. ADAM (unconstrained) ============ -# print("\n=== 1. ADAM ===") -# model = create_model() -# opt = torch.optim.Adam(model.parameters(), lr=0.001) - -# for epoch in range(3): -# losses = [] -# constraints = [] -# for batch_x, batch_groups, batch_y in loader: -# batch_x, batch_groups, batch_y = batch_x.to(device), batch_groups.to(device), batch_y.to(device) -# output = model(batch_x) -# loss = criterion(output, batch_y) -# loss.backward() -# opt.step() -# opt.zero_grad() -# losses.append(loss.item()) - -# constraint = get_constraint(model, groups_train).to(device) -# constraints.append(constraint.item()) -# print(f"Epoch {epoch}: loss={sum(losses)/len(losses):.4f}, constraint={constraint:.4f}") - -# # ============ 2. ALM (Augmented Lagrangian) ============ -# print("\n=== 2. ALM ===") -# model = create_model() -# opt = MoreauEnvelope(torch.optim.Adam(model.parameters(), lr=0.01)) -# dual = ALM(m=1, lr=0.01, momentum=0.5, device=device) -# if hasattr(dual, 'to'): -# dual = dual.to(device) - -# for epoch in range(3): -# losses = [] -# for batch_x, batch_groups, batch_y in loader: -# batch_x, batch_groups, batch_y = batch_x.to(device), batch_groups.to(device), batch_y.to(device) -# output = model(batch_x) -# loss = criterion(output, batch_y) -# constraint = get_constraint(model, groups_train).to(device) - -# lagrangian = dual.forward_update(loss, constraint.unsqueeze(0)) -# lagrangian.backward() -# opt.step() -# opt.zero_grad() -# losses.append(loss.item()) - -# constraint = get_constraint(model, groups_train).to(device) -# print(f"Epoch {epoch}: loss={sum(losses)/len(losses):.4f}, constraint={constraint:.4f}") - -# # ============ 3. PBM (Penalty-Barrier Method) ============ -# print("\n=== 3. PBM ===") -# model = create_model() -# opt = MoreauEnvelope(torch.optim.Adam(model.parameters(), lr=0.001)) -# dual = PBM(m=1, penalty_update='dimin_adapt', gamma=0.7, -# init_duals=0.001, init_penalties=1., penalty_range=(0.01, 1.), -# dual_range=(0.01, 100.), device=device) -# if hasattr(dual, 'to'): -# dual = dual.to(device) - -# for epoch in range(3): -# losses = [] -# for batch_x, batch_groups, batch_y in loader: -# batch_x, batch_groups, batch_y = batch_x.to(device), batch_groups.to(device), batch_y.to(device) -# output = model(batch_x) -# loss = criterion(output, batch_y) -# constraint = get_constraint(model, groups_train).to(device) - -# lagrangian = dual.forward_update(loss, constraint.unsqueeze(0)) -# lagrangian.backward() -# opt.step() -# opt.zero_grad() -# losses.append(loss.item()) - -# constraint = get_constraint(model, groups_train).to(device) -# print(f"Epoch {epoch}: loss={sum(losses)/len(losses):.4f}, constraint={constraint:.4f}") - -# ============ 4. SSW (Switching Subgradient) ============ -print("\n=== 4. SSW ===") -model = create_model() -opt = torch.optim.Adam(model.parameters(), lr=0.001) -opt2 = torch.optim.Adam(model.parameters(), lr=0.001) - -for epoch in range(3): - losses = [] - for batch_x, batch_groups, batch_y in loader: - batch_x, batch_groups, batch_y = batch_x.to(device), batch_groups.to(device), batch_y.to(device) - output = model(batch_x) - loss = criterion(output, batch_y) - constraint = get_constraint(model, groups_train).to(device) - - if constraint.item() > 0: - constraint.backward() - opt2.step() - losses.append(loss.item()) - opt2.zero_grad() - else: - loss.backward() - opt.step() # No constraint violation, so pass 0 to update - losses.append(loss.item()) - opt.zero_grad() - - losses.append(loss.item()) - - constraint = get_constraint(model, groups_train).to(device) - print(f"Epoch {epoch}: loss={sum(losses)/len(losses):.4f}, constraint={constraint:.4f}") - -print("\nāœ“ Complete")