Commit

adding flash v2 torch for mask, custom for bias
hypnopump committed Oct 21, 2023
1 parent b1c89a2 commit 27d1b4e
Showing 2 changed files with 123 additions and 38 deletions.
13 changes: 7 additions & 6 deletions train_monomer_demo.sh
@@ -2,15 +2,16 @@
[ -z "${n_gpu}" ] && n_gpu=$(nvidia-smi -L | wc -l)
export NCCL_ASYNC_ERROR_HANDLING=1
export OMP_NUM_THREADS=1
mkdir -p $1
mkdir -p temp
tmp_dir=`mktemp -d`
python -m torch.distributed.launch --nproc_per_node=$n_gpu --master_port $MASTER_PORT $(which unicore-train) ./example_data/ --user-dir unifold \
# --tensorboard-logdir $1/tsb/
torchrun --nproc_per_node=$n_gpu --master_port $MASTER_PORT $(which unicore-train) ./example_data/ --user-dir unifold \
--num-workers 8 --ddp-backend=no_c10d \
--task af2 --loss af2 --arch af2 \
--optimizer adam --adam-betas '(0.9, 0.999)' --adam-eps 1e-6 --clip-norm 0.0 --per-sample-clip-norm 0.1 --allreduce-fp32-grad \
--lr-scheduler exponential_decay --lr 1e-3 --warmup-updates 1000 --decay-ratio 0.95 --decay-steps 50000 --batch-size 1 \
--update-freq 1 --seed 42 --tensorboard-logdir $1/tsb/ \
--max-update 1000 --max-epoch 1 --log-interval 10 --log-format simple \
--lr-scheduler exponential_decay --lr 1e-3 --warmup-updates 10 --decay-ratio 0.95 --decay-steps 50000 --batch-size 1 \
--update-freq 1 --seed 42 \
--max-update 1000 --max-epoch 1 --log-interval 1 --log-format tqdm \
--save-interval-updates 100 --validate-interval-updates 100 --keep-interval-updates 5 --no-epoch-checkpoints \
--save-dir $1 --tmp-save-dir $tmp_dir --required-batch-size-multiple 1 --ema-decay 0.999 --bf16 --bf16-sr # for V100 or older GPUs, you can disable --bf16 for faster speed.
--save-dir temp --tmp-save-dir $tmp_dir --required-batch-size-multiple 1 --ema-decay 0.999 --fp16 # --bf16 --bf16-sr # for V100 or older GPUs, you can disable --bf16 for faster speed.
rm -rf $tmp_dir
148 changes: 116 additions & 32 deletions unifold/modules/attentions.py
@@ -9,13 +9,25 @@
from unicore.modules import (
softmax_dropout,
LayerNorm,
attention
)

try:
import trident as td
FLASHV2_AVAILABLE = True
except ImportError:
print("Could not import trident. Disabling default flash attention.")
FLASHV2_AVAILABLE = False

def gen_attn_mask(mask, neg_inf):

TORCH_FLASH = True


def gen_attn_mask(mask: torch.Tensor, neg_inf: float) -> torch.Tensor:
""" Masked fill with neg value to mask attn scores. """
assert neg_inf < -1e4
attn_mask = torch.zeros_like(mask)
attn_mask[mask == 0] = neg_inf
attn_mask = torch.full_like(mask, fill_value=neg_inf)
attn_mask.masked_fill_(mask.bool(), 0.)
return attn_mask
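
A quick reference (a minimal editor's sketch, not part of the commit): the rewritten gen_attn_mask is equivalent to the old zeros_like/indexing version, it just starts from neg_inf everywhere and zeroes the positions the mask keeps.

import torch

def gen_attn_mask_old(mask, neg_inf):
    # previous version: start from zeros, write neg_inf where mask == 0
    attn_mask = torch.zeros_like(mask)
    attn_mask[mask == 0] = neg_inf
    return attn_mask

def gen_attn_mask_new(mask, neg_inf):
    # new version: start from neg_inf, zero the kept positions in place
    attn_mask = torch.full_like(mask, fill_value=neg_inf)
    attn_mask.masked_fill_(mask.bool(), 0.)
    return attn_mask

mask = torch.tensor([[1., 1., 0., 1.], [0., 1., 1., 0.]])
assert torch.equal(gen_attn_mask_old(mask, -3e4), gen_attn_mask_new(mask, -3e4))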


@@ -32,6 +44,7 @@ def __init__(
super(Attention, self).__init__()

self.num_heads = num_heads
self.head_dim = head_dim
total_dim = head_dim * self.num_heads
self.gating = gating
self.linear_q = Linear(q_dim, total_dim, bias=False, init="glorot")
@@ -52,34 +65,81 @@ def forward(
mask: torch.Tensor = None,
bias: Optional[torch.Tensor] = None,
) -> torch.Tensor:
""" Notation:
- b: batch
- n: num_seqs (or auxiliary batch dims, sometimes=n l)
- h: head
- i: length (of q)
- j: length (of k and/or v)
- d: hidden dim
- d_h: head_dim
Inputs:
* q, k, v: (b, n, i, d) torch.Tensor
* mask: (b, n, (h=1), (i=1), j) torch.Tensor float mode (w/ -inf)
* bias: (b, (n=1), h, i, j) torch.Tensor
"""
g = None
if self.linear_g is not None:
# gating, use raw query input
g = self.linear_g(q)

# (b, n, i, d) -> (b, n, i, (h dim_head))
q = self.linear_q(q)
q *= self.norm
k = self.linear_k(k)
v = self.linear_v(v)

q = q.view(q.shape[:-1] + (self.num_heads, -1)).transpose(-2, -3).contiguous()
k = k.view(k.shape[:-1] + (self.num_heads, -1)).transpose(-2, -3).contiguous()
v = v.view(v.shape[:-1] + (self.num_heads, -1)).transpose(-2, -3)

attn = torch.matmul(q, k.transpose(-1, -2))
del q, k

attn = softmax_dropout(attn, 0, self.training, mask=mask, bias=bias)
o = torch.matmul(attn, v)
del attn, v
batch_dims = q.shape[:-2]

if bias is None and TORCH_FLASH:
# (b, n, i, (h dim_head)) -> (b, n, h, i, dim_head)
q = q.view(*q.shape[:-1], self.num_heads, self.head_dim).transpose(-2, -3).contiguous()
k = k.view(*k.shape[:-1], self.num_heads, self.head_dim).transpose(-2, -3).contiguous()
v = v.view(*v.shape[:-1], self.num_heads, self.head_dim).transpose(-2, -3)
# (b, n, h, i, dim_head), (b, n, h, j, dim_head) -> (b, n, h, i, dim_head)
o = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=mask)
del q, k, v

# custom kernel for pair bias
elif FLASHV2_AVAILABLE and self.head_dim >= 16:
# (mask, bias) -> (b, n, h, i, j) -> ((b n), h, i, j)
new_bias = mask + bias
new_bias.clamp_(min=torch.finfo(new_bias.dtype).min)
new_bias = new_bias.view(-1, *new_bias.shape[-3:]).contiguous()

# (b, n, i, (h dim_head)) -> ((b n), h, i, dim_head)
q = q.view(-1, q.shape[-2], self.num_heads, self.head_dim).transpose(-2, -3).contiguous()
k = k.view(-1, k.shape[-2], self.num_heads, self.head_dim).transpose(-2, -3).contiguous()
v = v.view(-1, v.shape[-2], self.num_heads, self.head_dim).transpose(-2, -3)

# ((b n), h, i, dim_head) -> ((b n), h, i, dim_head) -> (b, n, h, i dim_head) ; torch for bool mask only
o = td.function.scaled_dot_product_attention(q, k, v, attn_mask=new_bias)
o = o.view(*batch_dims, *o.shape[-3:])
del q, k, v

else:
# (b, n, i, (h dim_head)) -> (b, n, h, i, dim_head)
q = q.view(q.shape[:-1] + (self.num_heads, -1)).transpose(-2, -3).contiguous()
k = k.view(k.shape[:-1] + (self.num_heads, -1)).transpose(-2, -3).contiguous()
v = v.view(v.shape[:-1] + (self.num_heads, -1)).transpose(-2, -3)

q *= self.norm
# (b, n, h, i, dim_head), (b, n, h, dim_head, j) -> (b, n, h, i, j)
attn = torch.matmul(q, k.transpose(-1, -2))
del q, k

# (b, n, h, i, j) -> (b, n, h, i, j), (b, n, h, j, dim_head) -> (b, n, h, i, dim_head)
attn = softmax_dropout(attn, 0, self.training, mask=mask, bias=bias)
o = torch.matmul(attn, v)
del attn, v

# (b, n, h, i, dim_head) -> (b, n, i, h, dim_head) -> (b, n, i, (h dim_head))
o = o.transpose(-2, -3).contiguous()
o = o.view(*o.shape[:-2], -1)

if g is not None:
o = torch.sigmoid(g) * o

# merge heads
# (b, n, i, (h dim_head)) -> (b, n, i, d)
o = nn.functional.linear(o, self.linear_o.weight)
return o
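
A minimal sketch (editor's addition, not part of the commit) of what the dispatch above is doing: with only an additive mask, torch.nn.functional.scaled_dot_product_attention replaces the explicit matmul/softmax path, and with a pair bias the mask and bias are summed into one additive tensor (the commit's new_bias = mask + bias). Torch SDPA also accepts such a float attn_mask, so it can serve as a reference check; the trident kernel is assumed to consume the combined bias the same way.

import torch
import torch.nn.functional as F

b, n, h, i, j, d = 2, 3, 4, 5, 5, 16
q = torch.randn(b, n, h, i, d)
k = torch.randn(b, n, h, j, d)
v = torch.randn(b, n, h, j, d)

keep = torch.randint(0, 2, (b, n, 1, 1, j), dtype=torch.bool)
keep[..., 0] = True                                         # avoid fully masked rows
mask = torch.zeros(b, n, 1, 1, j).masked_fill(~keep, -3e4)  # additive mask (0 = keep)
bias = torch.randn(b, 1, h, i, j)                           # pair bias, broadcast over n

# explicit path (the else branch): scale q, add mask and bias to the scores
scores = (q * d ** -0.5) @ k.transpose(-1, -2) + mask + bias
ref = scores.softmax(dim=-1) @ v

# fused path: fold mask and bias into one additive attn_mask;
# SDPA applies the 1/sqrt(d) scaling itself, so q is not pre-scaled in this sketch
out = F.scaled_dot_product_attention(q, k, v, attn_mask=mask + bias)
assert torch.allclose(ref, out, atol=1e-5)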

@@ -92,6 +152,7 @@ def __init__(self, input_dim, head_dim, num_heads, inf, eps):
super(GlobalAttention, self).__init__()

self.num_heads = num_heads
self.head_dim = head_dim
self.inf = inf
self.eps = eps
self.linear_q = Linear(
@@ -106,37 +167,62 @@ def __init__(self, input_dim, head_dim, num_heads, inf, eps):
self.norm = head_dim**-0.5

def forward(self, x: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:

# gating
""" Notation:
- b: batch
- l: length
- d: input_dim / hidden dim depending on context
- h: num_heads
- s: num_sequences
Inputs:
* x: (b, l, s, d) torch.Tensor
* mask: (b, l, s) torch.Tensor
Outputs: (b, l, s, (h dim_head)) torch.Tensor
"""
# (b, l, s, d) -> (b, l, s, (h dim_head)) # gating
g = self.sigmoid(self.linear_g(x))

# (b, l, s, d) -> (b, l, s, dim_head)
k = self.linear_k(x)
v = self.linear_v(x)

# (b, l, s, d), (b, l, s, 1) -> (b, l, d), (b, l, 1) -> (b, l, d)
q = torch.sum(x * mask.unsqueeze(-1), dim=-2) / (
torch.sum(mask, dim=-1, keepdims=True) + self.eps
torch.sum(mask, dim=-1, keepdim=True).add_(self.eps)
)
# (b, l, d) -> (b, l, (h dim_head))
q = self.linear_q(q)
q *= self.norm
# (b, l, (h dim_head)) -> (b, l, h, dim_head)
q = q.view(q.shape[:-1] + (self.num_heads, -1))

attn = torch.matmul(q, k.transpose(-1, -2))
del q, k
if TORCH_FLASH:
# (b, l, h, i=1, dim_head), (b, l, h=1, s, dim_head) -> (b, l, h, dim_head)
o = torch.nn.functional.scaled_dot_product_attention(
q.unsqueeze_(-2),
k.unsqueeze_(-3),
v.unsqueeze_(-3),
# (b, l, s) -> (b, l, h=1, i=1, s); boolean mask: True = attend
attn_mask=mask[..., None, None, :].bool()
).squeeze_(-2)
del q, k, v
else:
q *= self.norm
# (b, l, h, dim_head), (b, l, s, dim_head) -> (b, l, h, s)
attn = torch.matmul(q, k.transpose(-1, -2))
del q, k

attn_mask = gen_attn_mask(mask, -self.inf)[..., :, None, :]
attn = softmax_dropout(attn, 0, self.training, mask=attn_mask)
attn_mask = gen_attn_mask(mask, -self.inf)[..., :, None, :]
attn = softmax_dropout(attn, 0, self.training, mask=attn_mask)

o = torch.matmul(
attn,
v,
)
del attn, v
# (b, l, h, s), (b, l, s, dim_head) -> (b, l, h, dim_head)
o = torch.matmul(attn, v)
del attn, v

# (b, l, s, (h dim_head)) -> (b, l, s, h, dim_head)
g = g.view(g.shape[:-1] + (self.num_heads, -1))
# (b, l, h, dim_head), (b, l, s, h, dim_head) -> (b, l, s, h, dim_head)
o = o.unsqueeze(-3) * g
del g

# merge heads
# (b, l, s, h, dim_head) -> (b, l, s, (h dim_head)) # merge heads
o = o.reshape(o.shape[:-2] + (-1,))
return self.linear_o(o)
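
A minimal sketch (editor's addition, not part of the commit) of the broadcasting the TORCH_FLASH branch above relies on: the pooled query gets a length-1 query axis, keys/values are shared across heads via a length-1 head axis, and the padding mask is handed to SDPA as a boolean tensor (True = attend), playing the role of gen_attn_mask(mask, -inf) on the explicit path.

import torch
import torch.nn.functional as F

b, l, s, h, dh = 2, 3, 7, 4, 16
q = torch.randn(b, l, h, dh)              # pooled query per (b, l), one row per head
k = torch.randn(b, l, s, dh)              # keys shared across heads
v = torch.randn(b, l, s, dh)
mask = torch.randint(0, 2, (b, l, s), dtype=torch.bool)
mask[..., 0] = True                       # keep at least one position per row

# fused path: (b, l, h, 1, dh) x (b, l, 1, s, dh), bool mask broadcast over heads
o_sdpa = F.scaled_dot_product_attention(
    q.unsqueeze(-2), k.unsqueeze(-3), v.unsqueeze(-3),
    attn_mask=mask[..., None, None, :],
).squeeze(-2)                             # (b, l, h, dh)

# explicit path: scaled scores, -inf on masked positions, softmax, weighted sum
scores = (q * dh ** -0.5) @ k.transpose(-1, -2)            # (b, l, h, s)
scores = scores.masked_fill(~mask[..., None, :], float("-inf"))
o_ref = scores.softmax(dim=-1) @ v                         # (b, l, h, dh)
assert torch.allclose(o_sdpa, o_ref, atol=1e-5)

This also shows why the mask must be boolean (or an additive float) when handed to SDPA: a raw 0/1 float mask would simply be added to the scores instead of masking them, which is why .bool() is applied in the branch above.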

@@ -221,7 +307,6 @@ def forward(
attn_mask: Optional[torch.Tensor] = None,
chunk_size: Optional[int] = None,
) -> torch.Tensor:

bias = None
if self.pair_bias:
z = self.layer_norm_z(z)
@@ -327,7 +412,6 @@ def forward(
mask: Optional[torch.Tensor] = None,
chunk_size: Optional[int] = None,
) -> torch.Tensor:

m = m.transpose(-2, -3)
mask = mask.transpose(-1, -2)

