In [4]:
import sys
sys.path.append('../')

import jax
# set jax precision 64 bit
jax.config.update("jax_enable_x64", True)

from jax import vmap
import jax.numpy as jnp
import matplotlib.pyplot as plt
from src.training import train_model, make_batches
from flax import linen as nn

In [5]:
class MultiHeadSelfAttention(nn.Module):
    """Multi-head self-attention acting on one unbatched token sequence.

    The module expects a 2-D input of shape (seq_len, input_dim) and an
    optional per-token mask of shape (seq_len,). Tokens whose mask entry is
    falsy are ignored as keys, and their own attention rows are zeroed, so
    their output row only carries the bias of the output projection.

    Attributes:
        embed_dim: Total width of the query/key/value projections and of the
            output; must be divisible by ``num_heads``.
        num_heads: Number of attention heads.
    """
    embed_dim: int
    num_heads: int

    def setup(self):
        assert self.embed_dim % self.num_heads == 0, "Embedding dimension must be divisible by the number of heads."
        self.head_dim = self.embed_dim // self.num_heads
        # One bias-free projection yields queries, keys and values in a single matmul.
        self.qkv_proj = nn.Dense(self.embed_dim * 3, use_bias=False)
        self.out_proj = nn.Dense(self.embed_dim)

    def __call__(self, x, mask=None):
        """Apply masked self-attention.

        Args:
            x: Array of shape (seq_len, input_dim).
            mask: Optional array of shape (seq_len,); truthy entries mark
                tokens that take part in attention. Defaults to all ones.

        Returns:
            Array of shape (seq_len, embed_dim).
        """
        # Unpacking into two names also rejects inputs that are not 2-D.
        seq_len, _ = x.shape

        if mask is None:
            mask = jnp.ones((seq_len,))

        # (seq_len, 3 * embed_dim) -> (seq_len, 3, num_heads, head_dim)
        packed = self.qkv_proj(x).reshape(seq_len, 3, self.num_heads, self.head_dim)
        # Each of q, k, v has shape (seq_len, num_heads, head_dim).
        q, k, v = packed[:, 0], packed[:, 1], packed[:, 2]

        # Scaled dot-product scores, laid out as (num_heads, query, key).
        scores = jnp.einsum("qhd,khd->hqk", q, k) / jnp.sqrt(self.head_dim)
        # Masked keys get a large negative score so softmax assigns them ~0 weight.
        scores = jnp.where(mask[None, None, :], scores, -1e9)
        weights = jax.nn.softmax(scores, axis=-1)
        # Masked queries attend to nothing at all.
        weights = jnp.where(mask[None, :, None], weights, 0.)

        # Weighted sum of values, then merge the heads back into embed_dim.
        context = jnp.einsum("hqk,khd->qhd", weights, v)
        context = context.reshape(seq_len, self.embed_dim)
        return self.out_proj(context)



class Model(nn.Module):
    """Three stacked self-attention layers reduced to a single scalar.

    ReLU is applied after the first and second attention layers. The output
    of the third layer is summed over all tokens and features and divided by
    the feature width of the input (``x.shape[1]``).

    Attributes:
        embed_dim: Embedding width handed to every attention layer.
        num_heads: Number of heads handed to every attention layer.
    """
    embed_dim: int
    num_heads: int

    def setup(self):
        self.multihead = MultiHeadSelfAttention(self.embed_dim, self.num_heads)
        self.multihead2 = MultiHeadSelfAttention(self.embed_dim, self.num_heads)
        self.multihead3 = MultiHeadSelfAttention(self.embed_dim, self.num_heads)
        # Declared but not used by __call__ at the moment.
        self.lin = nn.Dense(1)

    def __call__(self, x, mask=None):
        """Map a (seq_len, token_dim) sequence and optional mask to a scalar."""
        hidden = jax.nn.relu(self.multihead(x, mask))
        hidden = jax.nn.relu(self.multihead2(hidden, mask))
        pooled = self.multihead3(hidden, mask).flatten()
        # Sum over every token and feature, scaled by the input feature width.
        return jnp.sum(pooled) / x.shape[1]

In [6]:
key = jax.random.PRNGKey(0)

# NOTE(review): absolute, machine-specific path; point this at your own data
# directory when running elsewhere.
data_dir = "/home/emastr/moment-constrained-cryo-em/project_2/data/"
# SECURITY: allow_pickle=True unpickles the file, which can execute arbitrary
# code. Only load files from a trusted source.
data = jnp.load(f"{data_dir}/train_data_dens_dom.npy", allow_pickle=True).item()


x_train_pad = data["x"]
mask_train_pad = data["mask"]
y_train_pad = data["y"]

# Standardize every token feature over the data and sequence axes.
# NOTE(review): the statistics are taken over all positions, including the ones
# marked invalid by the mask -- confirm that this is intended.
x_train_pad_std = jnp.std(x_train_pad, axis=(0,1))
# A constant feature has zero std; dividing by it would turn the whole feature
# into NaN/inf, so fall back to a unit scale for such features.
x_train_pad_std = jnp.where(x_train_pad_std > 0, x_train_pad_std, 1.0)
x_train_pad_mean = jnp.mean(x_train_pad, axis=(0,1))
x_train_pad = (x_train_pad - x_train_pad_mean) / x_train_pad_std

# Standardize the targets with the same zero-std guard.
y_train_pad_std = jnp.std(y_train_pad)
y_train_pad_std = jnp.where(y_train_pad_std > 0, y_train_pad_std, 1.0)
y_train_pad_mean = jnp.mean(y_train_pad)
y_train_pad = (y_train_pad - y_train_pad_mean) / y_train_pad_std


# MODEL
num_data, seq_len, token_dim = x_train_pad.shape

num_epochs= 2000
num_heads = 16

model = Model(token_dim, num_heads)#, dropout_rate)
params = model.init(key, x_train_pad[0], mask_train_pad[0])

x_test = x_train_pad[0]
mask_test = mask_train_pad[0]

# Test 1: Invariance to shuffling of sequences
# Tokens and mask are permuted with the same index, so the difference printed
# below should be zero up to floating-point round-off.
idx_seq = jnp.arange(seq_len)
idx_seq = jax.random.permutation(key, idx_seq, axis=0)
x_test_shuffled = x_test[idx_seq]
mask_test_shuffled = mask_test[idx_seq]
print(model.apply(params, x_test, mask_test) - model.apply(params, x_test_shuffled, mask_test_shuffled))

# Test 2: Invariance to values of masked tokens
# The last token is masked out and then overwritten; the output must not change.
mask_test_pert = mask_test.at[-1].set(0)
x_test_pert = x_test.at[-1].set(x_test[0])
print(model.apply(params, x_test, mask_test_pert) - model.apply(params, x_test_pert, mask_test_pert))

4.440892098500626e-16
0.0


In [None]:
def mx_err(x, y):
    """Largest absolute elementwise difference between x and y."""
    return jnp.max(jnp.abs(x - y))


def l1_err(x, y):
    """Mean absolute elementwise difference between x and y."""
    return jnp.mean(jnp.abs(x - y))


def l2_err(x, y):
    """Mean squared elementwise difference between x and y (no square root)."""
    return jnp.mean((x - y) ** 2)


def norm_err(err_fcn, y_apx, y_tru, y_ref):
    """Error of y_apx against y_tru, divided by the error of y_ref against y_tru."""
    approx_error = err_fcn(y_apx, y_tru)
    reference_error = err_fcn(y_ref, y_tru)
    return approx_error / reference_error

# Batched forward pass: the parameters are shared, while inputs and masks are
# mapped over their leading axis.
model_vmap = vmap(model.apply, in_axes=(None, 0, 0))


def loss_fcn(param, x, mask, y):
    """Mean squared error between the squeezed batched predictions and y."""
    predictions = model_vmap(param, x, mask).squeeze()
    return l2_err(predictions, y)

def metrics(par):
    """Normalized max, L1 and L2 errors of the model on the full training set.

    Every error is divided by the error of a constant predictor equal to the
    mean of the training targets. The "l2" entry is the square root of the
    normalized mean squared error.

    Args:
        par: Model parameters passed to ``model.apply``.

    Returns:
        Dict with keys "max", "l1" and "l2".
    """
    predict = vmap(model.apply, (None, 0, 0))
    y_apx = predict(par, x_train_pad, mask_train_pad).squeeze()
    y_tru = y_train_pad
    # Baseline predictor: the mean of the training targets.
    y_ref = jnp.mean(y_train_pad)

    return {
        "max": norm_err(mx_err, y_apx, y_tru, y_ref),
        "l1": norm_err(l1_err, y_apx, y_tru, y_ref),
        "l2": norm_err(l2_err, y_apx, y_tru, y_ref) ** 0.5,
    }

def dict2str(d):
    """Render a dict as "key: value" pairs joined by ", ", values in 2-digit scientific notation."""
    parts = []
    for name, value in d.items():
        parts.append(f"{name}: {value:.2e}")
    return ", ".join(parts)

# Full-batch gradient descent on the whole (normalized) training set.
NUM_STEPS = 10000
LEARNING_RATE = 1e-3


@jax.jit
def _gd_step(par, x, mask, y):
    """Run one gradient-descent step.

    Compiled once with jax.jit so that the loss/gradient computation is not
    re-traced and dispatched op by op on every iteration.

    Returns:
        (updated parameters, loss evaluated at the incoming parameters, gradients)
    """
    loss, grads = jax.value_and_grad(loss_fcn)(par, x, mask, y)
    new_par = jax.tree.map(lambda p, g: p - LEARNING_RATE*g, par, grads)
    return new_par, loss, grads


# Loop-invariant: the same full batch is used at every step.
x_data, y_data, mask_data = x_train_pad, y_train_pad, mask_train_pad
for i in range(NUM_STEPS):
    params, loss, grads = _gd_step(params, x_data, mask_data, y_data)
    # `loss` is the pre-update loss; the metrics are evaluated on the updated parameters.
    print(f"loss: {loss:.2e}. Metrics: " + dict2str(metrics(params)), end='\r')
#params, avg_losses, max_losses, min_losses  = train_model(key, model, x_train_pad, y_train_pad, mask_train_pad, batch_size, num_epochs, 1e-3)


loss: 3.53e-01. Metrics: max: 7.00e-01, l1: 5.79e-01, l2: 5.75e-01

In [8]:
# Training observations:
# - Loss is choppy (it spikes back up and then comes down again)
# - Loss plateaus at around 0.1
# - Model should be able to fit the data

# Next steps:
# - Try to overfit a simple function (constant target 1.0 everywhere: worked)
# - Try to overfit on a small data set ()
# - 

# * Super important for learning: Do not use a bias in the QKV projection???
# * Super important for learning: Normalize the data before feeding it to the model <- incredible how following ML research makes such a big difference
# * Squeeze relevant dimensions to avoid broadcasting error
# * 