In [2]:
## Test new PCGrad Implementation for TF2.0

import tensorflow as tf
import numpy as np

# Dummy data: 100 samples, 10 features
X = np.random.randn(100, 10).astype(np.float32)

# Task 1: predict y1 = sum of features
# Task 2: predict y2 = product of first two features
y1 = np.sum(X, axis=1, keepdims=True).astype(np.float32)
y2 = (X[:, 0] * X[:, 1]).reshape(-1, 1).astype(np.float32)

# Simple MLP model with one shared layer and two heads
class MultiTaskModel(tf.keras.Model):
    def __init__(self):
        super().__init__()
        self.shared = tf.keras.layers.Dense(32, activation='relu')
        self.head1 = tf.keras.layers.Dense(1)  # for y1
        self.head2 = tf.keras.layers.Dense(1)  # for y2

    def call(self, inputs):
        x = self.shared(inputs)
        return self.head1(x), self.head2(x)

model = MultiTaskModel()

# Your working PCGrad class (assumed to be fixed as discussed)
class PCGrad(tf.keras.optimizers.Optimizer):
    def __init__(self, optimizer, name="PCGrad", **kwargs):
        super().__init__(name=name, learning_rate=0.0, **kwargs)
        self._optimizer = optimizer

    def apply_gradients(self, grads_and_vars, name=None, **kwargs):
        return self._optimizer.apply_gradients(grads_and_vars, **kwargs)

    def compute_gradients(self, losses, tape, var_list):
        assert isinstance(losses, list)
        grads_task = []
        for loss in losses:
            grads = tape.gradient(loss, var_list)
            grads = [tf.zeros_like(v) if g is None else g for g, v in zip(grads, var_list)]
            grads_task.append(grads)

        def flatten(grads):
            return tf.concat([tf.reshape(g, [-1]) for g in grads], axis=0)

        flat_grads_task = tf.stack([flatten(g) for g in grads_task])
        flat_grads_task = tf.random.shuffle(flat_grads_task)

        def project(g, others):
            for o in others:
                dot = tf.reduce_sum(g * o)
                g -= tf.cond(dot < 0, lambda: dot / (tf.reduce_sum(o * o) + 1e-12) * o, lambda: tf.zeros_like(g))
            return g

        projected = [project(g, tf.concat([flat_grads_task[:i], flat_grads_task[i+1:]], axis=0))
                     for i, g in enumerate(flat_grads_task)]
        mean_grad = tf.reduce_mean(tf.stack(projected), axis=0)

        reshaped_grads = []
        idx = 0
        for v in var_list:
            shape = tf.shape(v)
            size = tf.reduce_prod(shape)
            reshaped_grads.append(tf.reshape(mean_grad[idx:idx + size], shape))
            idx += size

        return list(zip(reshaped_grads, var_list))

    @property
    def learning_rate(self):
        return self._optimizer.learning_rate

# Optimizer
opt = PCGrad(tf.keras.optimizers.Adam(1e-2))

# Loss functions
mse = tf.keras.losses.MeanSquaredError()

# Training loop
epochs = 30
batch_size = 16

for epoch in range(epochs):
    for i in range(0, len(X), batch_size):
        X_batch = X[i:i+batch_size]
        y1_batch = y1[i:i+batch_size]
        y2_batch = y2[i:i+batch_size]

        with tf.GradientTape(persistent=True) as tape:
            out1, out2 = model(X_batch)
            loss1 = mse(y1_batch, out1)
            loss2 = mse(y2_batch, out2)

        grads_and_vars = opt.compute_gradients([loss1, loss2], tape, model.trainable_variables)
        opt.apply_gradients(grads_and_vars)

    print(f"Epoch {epoch+1}: loss1 = {loss1.numpy():.4f}, loss2 = {loss2.numpy():.4f}")

Epoch 1: loss1 = 3.9446, loss2 = 4.6393
Epoch 2: loss1 = 1.6433, loss2 = 3.6033
Epoch 3: loss1 = 0.3916, loss2 = 2.7050
Epoch 4: loss1 = 0.1011, loss2 = 1.8749
Epoch 5: loss1 = 0.1288, loss2 = 1.2099
Epoch 6: loss1 = 0.0677, loss2 = 0.8024
Epoch 7: loss1 = 0.0291, loss2 = 0.4881
Epoch 8: loss1 = 0.0506, loss2 = 0.3068
Epoch 9: loss1 = 0.0306, loss2 = 0.2097
Epoch 10: loss1 = 0.0186, loss2 = 0.1591
Epoch 11: loss1 = 0.0175, loss2 = 0.1316
Epoch 12: loss1 = 0.0207, loss2 = 0.1133
Epoch 13: loss1 = 0.0160, loss2 = 0.0883
Epoch 14: loss1 = 0.0109, loss2 = 0.0627
Epoch 15: loss1 = 0.0081, loss2 = 0.0409
Epoch 16: loss1 = 0.0063, loss2 = 0.0288
Epoch 17: loss1 = 0.0055, loss2 = 0.0187
Epoch 18: loss1 = 0.0049, loss2 = 0.0148
Epoch 19: loss1 = 0.0043, loss2 = 0.0125
Epoch 20: loss1 = 0.0039, loss2 = 0.0089
Epoch 21: loss1 = 0.0036, loss2 = 0.0075
Epoch 22: loss1 = 0.0030, loss2 = 0.0071
Epoch 23: loss1 = 0.0025, loss2 = 0.0057
Epoch 24: loss1 = 0.0025, loss2 = 0.0050
Epoch 25: loss1 = 0.0027,