In [None]:
# Execute the helper script and pull its top-level names into this kernel's
# namespace. The cells below use these names without defining them, so the script
# is expected to provide them (TODO confirm):
#   torch, optim, plt, dtype_float32, dtype_float64, get_gate_data,
#   initialize_model_and_loss, train_model, evaluate_model, custom_gradient_update.
%run debugLogicGates.py

In [None]:
# Train a fresh model on the XOR truth table with AdamW in float32, then report
# the result via evaluate_model. Helper functions and dtype constants come from
# debugLogicGates.py (loaded with %run).
print("--- Training with AdamW Optimizer (float32) ---")
torch.manual_seed(0)
# The default dtype is global torch state. Set it explicitly here instead of
# inheriting whatever the previously executed cell left behind.
torch.set_default_dtype(dtype_float32)

gate = 'xor'  # truth table being learned; also drives the evaluation label below
X, y = get_gate_data(gate, dtype=dtype_float32)
model_adamw, loss_fn_adamw = initialize_model_and_loss(dtype=dtype_float32)
# Use the fully qualified name, matching the LBFGS cell, instead of relying on an `optim` alias.
optimizer_adamw = torch.optim.AdamW(model_adamw.parameters(), lr=0.05)

adamw_loss_history = train_model(model_adamw, optimizer_adamw, loss_fn_adamw, X, y,
                                 epochs=8000, log_interval=500, optimizer_type='adamw')
# Derive the label from `gate` so the report always names the gate actually trained
# (a hard-coded 'NAND' label previously disagreed with the XOR data).
evaluate_model(model_adamw, loss_fn_adamw, X, y,
               gate_name=f'{gate.upper()} (AdamW, float32)')

In [None]:
# Train a fresh model on the XOR truth table with L-BFGS in float64, then report
# the result via evaluate_model. Helper functions and dtype constants come from
# debugLogicGates.py (loaded with %run).
print("--- Training with LBFGS Optimizer (float64) ---")
torch.manual_seed(0)
# The default dtype is global torch state. Switch it to float64 for this run only;
# the other training cells set float32 themselves.
torch.set_default_dtype(dtype_float64)

gate = 'xor'  # truth table being learned; also drives the evaluation label below
X, y = get_gate_data(gate, dtype=dtype_float64)
model_lbfgs, loss_fn_lbfgs = initialize_model_and_loss(dtype=dtype_float64)
# NOTE(review): torch's LBFGS defaults to lr=1 and line_search_fn=None. With
# lr=0.05 and no line search, every step is heavily damped. Consider
# line_search_fn='strong_wolfe' if convergence is slow.
optimizer_lbfgs = torch.optim.LBFGS(model_lbfgs.parameters(), lr=0.05, max_iter=20, history_size=100)

# LBFGS.step() requires a closure. optimizer_type='lbfgs' presumably makes
# train_model supply one (TODO confirm).
lbfgs_loss_history = train_model(model_lbfgs, optimizer_lbfgs, loss_fn_lbfgs, X, y,
                                 epochs=8000, log_interval=500, optimizer_type='lbfgs')
# Derive the label from `gate` so the report always names the gate actually trained
# (a hard-coded 'NAND' label previously disagreed with the XOR data).
evaluate_model(model_lbfgs, loss_fn_lbfgs, X, y,
               gate_name=f'{gate.upper()} (LBFGS, float64)')

In [None]:
# Train a fresh model on the XOR truth table with the hand-written update rule
# custom_gradient_update in float32, then report the result via evaluate_model.
# Helper functions and dtype constants come from debugLogicGates.py.
print("--- Training with Custom Optimizer (float32) ---")
torch.manual_seed(0)
# Restore float32 as the global default in case the float64 LBFGS cell ran last.
torch.set_default_dtype(dtype_float32)

gate = 'xor'  # truth table being learned; also drives the evaluation label below
X, y = get_gate_data(gate, dtype=dtype_float32)
model_custom, loss_fn_custom = initialize_model_and_loss(dtype=dtype_float32)

# No torch optimizer is passed. train_model presumably applies
# custom_gradient_update to the parameters itself (TODO confirm).
# NOTE: this run uses 800 epochs, a tenth of the AdamW/LBFGS budget, so its loss
# curve is shorter in the comparison plot.
custom_loss_history = train_model(model_custom, None, loss_fn_custom, X, y,
                                  epochs=800, log_interval=100, optimizer_type='custom',
                                  custom_update_fn=custom_gradient_update)
# Derive the label from `gate` so the report always names the gate actually trained
# (a hard-coded 'NAND' label previously disagreed with the XOR data).
evaluate_model(model_custom, loss_fn_custom, X, y,
               gate_name=f'{gate.upper()} (Custom, float32)')

In [None]:
# Overlay the loss histories of the three XOR training runs on one log-scale axis
# so their convergence can be compared directly.
# NOTE(review): the x-axis label assumes train_model returns one loss value per
# epoch (TODO confirm). The custom run used 800 epochs versus 8000 for the others.
loss_curves = [
    (adamw_loss_history, 'AdamW Loss (float32)'),
    (lbfgs_loss_history, 'LBFGS Loss (float64)'),
    (custom_loss_history, 'Custom Update Loss (float32)'),
]

fig, ax = plt.subplots(figsize=(12, 8))
for history, label in loss_curves:
    ax.plot(history, label=label, alpha=0.7)
ax.set_title('XOR training loss by optimizer')
ax.set_xlabel('Epoch')
ax.set_ylabel('Loss')
ax.set_yscale('log')  # losses span orders of magnitude across optimizers
ax.legend()
ax.grid(True)
plt.show()