In [1]:
import sys
sys.path.append('../')

import jax
# set jax precision 64 bit
jax.config.update("jax_enable_x64", True)

from jax import vmap
import jax.numpy as jnp
import matplotlib.pyplot as plt
from src.training import train_model
from flax import linen as nn

In [2]:
class MultiHeadSelfAttention(nn.Module):
    embed_dim: int
    num_heads: int

    def setup(self):
        assert self.embed_dim % self.num_heads == 0, "Embedding dimension must be divisible by the number of heads."
        self.head_dim = self.embed_dim // self.num_heads
        self.qkv_proj = nn.Dense(self.embed_dim * 3, use_bias=False)  # For Query, Key, Value
        self.out_proj = nn.Dense(self.embed_dim, use_bias=False)
        #self.v_test = nn.Dense(self.embed_dim, use_bias=False)

    def __call__(self, x, mask=None):
        seq_len, input_dim = x.shape
        
        if mask is None:
            mask = jnp.ones((seq_len,))
        
        qkv = self.qkv_proj(x)  # Shape: (batch_size, seq_len, embed_dim * 3)
        qkv = qkv.reshape(seq_len, 3, self.num_heads, self.head_dim)
        q, k, v = jnp.split(qkv, 3, axis=1)  # Each: (batch_size, seq_len, num_heads, head_dim)
        q, k, v = q.squeeze(1), k.squeeze(1), v.squeeze(1)

        # Compute attention weights
        attn_weights = jnp.einsum("qhd,khd->hqk", q, k) / jnp.sqrt(self.head_dim)
        attn_weights = jnp.where(mask[None, None, :], attn_weights, -1e9)
        attn_weights = jax.nn.softmax(attn_weights, axis=-1)
        attn_weights = jnp.where(mask[None, :, None], attn_weights, 0.)
        
        # Compute attention output
        attn_output = jnp.einsum("hqk,khd->qhd", attn_weights, v)
        attn_output = attn_output.reshape(seq_len, self.embed_dim)
        return attn_output #self.out_proj(attn_output)



class Model(nn.Module):
    embed_dim: int
    num_heads: int
    
    
    def setup(self):
        self.multihead = MultiHeadSelfAttention(self.embed_dim, self.num_heads)
        self.lin = nn.Dense(1)
        
    def __call__(self, x, mask=None):
        y = self.multihead(x, mask)
        y = y.flatten()
        #x = self.lin(x) #return self.lin(x)# + jnp.sum(x)
        return self.lin(y) #jnp.sum(y) / x.shape[1] #+ jnp.sum(x) #self.lin(x) #jnp.sum(x)

In [3]:
key = jax.random.PRNGKey(0)

data_dir = "/home/emastr/moment-constrained-cryo-em/project_2/data/"
data = jnp.load(f"{data_dir}/train_data.npy", allow_pickle=True).item()

#x_train_pad = data["x"][:,:,:-2]
#x_train_pad = x_train_pad[:,:3,30:34]
#mask_train_pad = data["mask"][:,:3]

x_train_pad = data["x"][:,:,:-2]
mask_train_pad = data["mask"]
y_train_pad = data["y"]


# MODEL
num_data, seq_len, token_dim = x_train_pad.shape

num_epochs= 2000
num_heads = 2


key = jax.random.PRNGKey(0)
model = Model(token_dim, num_heads)#, dropout_rate)
params = model.init(key, x_train_pad[0], mask_train_pad[0])

x_test = x_train_pad[0]
mask_test = mask_train_pad[0]

# Test 1: Invariance to shuffling of sequences
idx_seq = jnp.arange(seq_len)
idx_seq = jax.random.permutation(key, idx_seq, axis=0)
x_test_shuffled = x_test[idx_seq]
mask_test_shuffled = mask_test[idx_seq]
print(model.apply(params, x_test, mask_test) - model.apply(params, x_test_shuffled, mask_test_shuffled))

# Test 2: Invariance to values of masked tokens
mask_test_pert = mask_test.at[-1].set(0)
x_test_pert = x_test.at[-1].set(x_test[0])
print(model.apply(params, x_test, mask_test_pert) - model.apply(params, x_test_pert, mask_test_pert))

An NVIDIA GPU may be present on this machine, but a CUDA-enabled jaxlib is not installed. Falling back to cpu.


[-0.0031083]
[0.]


In [4]:

#params['params']['lin']['kernel'] = jnp.ones_like(params['params']['lin']['kernel'])
#params['params']['lin']['bias'] = jnp.zeros_like(params['params']['lin']['bias'])


loss_pointwise = lambda param, x, y: (model.apply(param, x) - y)**2
loss_fcn = lambda param: jnp.mean(vmap(loss_pointwise, (None, 0, 0))(param, x_train_pad, y_train_pad)) / jnp.std(y_train_pad)
#print(vmap(func, (None, 0))(params, x_train_pad) - vmap(jnp.sum)(x_train_pad))
#idx = 

print(f"loss:{loss_fcn(params)}")
for i in range(10000):
    loss, grads = jax.value_and_grad(loss_fcn)(params)
    params = jax.tree.map(lambda p, g: p - 1e-1*g, params, grads)
    print(loss, end='\r')
#params, avg_losses, max_losses, min_losses  = train_model(key, model, x_train_pad, y_train_pad, mask_train_pad, batch_size, num_epochs, 1e-3)


loss:47.05075335251342
0.00028312467039293316

In [5]:
# Training observations:
# - Loss is choppy (shoots back up and then down)
# - Stays at 0.1
# - Model should be able to fit the data

# Next steps:
# - Try to overfit a simple function (1. everywhere worked)
# - Try to overfit on a small data set ()
# - 

# * Super important for learning: Do not use bias in the QKV formulation??? 