# Multitask learning
   
Shared layers followed by per target layers

In [1]:
import jax
import jax.numpy as jnp
import flax.linen as nn
from flax.training import train_state
import logging
import numpy as np
import optax

logging.basicConfig(level=logging.INFO)

Can grad compute multiple graidents at once? No.

In [2]:
def f(x):
    return x * 3, x ** 2

grad = jax.grad(f)
# Error.
# print(grad(3.0))

In [3]:
class Model(nn.Module):
    
    @nn.compact
    def __call__(self, x):
        x = nn.Dense(10)(x)
        x = nn.Dense(5)(x)
        x1 = nn.Dense(3)(x)
        x1 = nn.Dense(1)(x1)
        x2 = nn.Dense(3)(x)
        x2 = nn.Dense(1)(x2)
        return x1, x2

In [4]:
@jax.jit
def apply_model(state, X, y):
    """Computes gradients, loss and accuracy for a single batch."""
    
    def mean_squared_error(y, yhat):
        return jnp.mean((y - yhat)**2)

    def compute_loss_fn(params):
        yhats = state.apply_fn({"params": params}, X)
        # Get the average of losses. You could consdier giving weights to each task.
        loss = jnp.stack([mean_squared_error(y, yh) for yh in yhats]).sum()
        return loss

    grad_fn = jax.value_and_grad(compute_loss_fn)
    loss, grads = grad_fn(state.params)
    return loss, grads

In [5]:
@jax.jit
def update_model(state, grads):
    return state.apply_gradients(grads=grads)

In [6]:
def train_epoch(state, dataset_fn):
    epoch_loss = []
    for X, y in dataset_fn():
        loss, grads = apply_model(state, X, y)
        state = update_model(state, grads)
        epoch_loss.append(loss)
    train_loss = np.mean(epoch_loss)
    return state, train_loss

In [7]:
batch_size = 10

def f(x):
    return x * 3

def g(x):
    return x * 7 + 8

def datagen():
    x = jnp.arange(300)
    y = jnp.array([[f(x), g(x)] for x in x])
    y = jax.random.normal(jax.random.PRNGKey(0), shape=y.shape)
    
    for i in range(0, min(len(x), len(x) - batch_size), batch_size):
        yield x[i:i+batch_size][:, np.newaxis], y[i:i+batch_size]


In [8]:

model = Model()
init_rng = jax.random.PRNGKey(42)
params = model.init(init_rng, x=jnp.ones((batch_size, 1)))['params']

state = train_state.TrainState.create(
    apply_fn=model.apply,
    params=params,
    tx=optax.adam(learning_rate=1e-3))

for i in range(10):
    state, train_loss = train_epoch(state, datagen)
    logging.info(f"Epoch: {i:4d}, train_loss: {train_loss:.4f}")
    

INFO:jax._src.xla_bridge:Unable to initialize backend 'cuda': module 'jaxlib.xla_extension' has no attribute 'GpuAllocatorConfig'
INFO:jax._src.xla_bridge:Unable to initialize backend 'rocm': module 'jaxlib.xla_extension' has no attribute 'GpuAllocatorConfig'
INFO:jax._src.xla_bridge:Unable to initialize backend 'tpu': module 'jaxlib.xla_extension' has no attribute 'get_tpu_client'
INFO:root:Epoch:    0, train_loss: 4501.4688
INFO:root:Epoch:    1, train_loss: 76.3762
INFO:root:Epoch:    2, train_loss: 10.2562
INFO:root:Epoch:    3, train_loss: 2.2091
INFO:root:Epoch:    4, train_loss: 2.0571
INFO:root:Epoch:    5, train_loss: 2.0513
INFO:root:Epoch:    6, train_loss: 2.0519
INFO:root:Epoch:    7, train_loss: 2.0553
INFO:root:Epoch:    8, train_loss: 2.0568
INFO:root:Epoch:    9, train_loss: 2.0554
