diff --git a/train_monomer_demo.sh b/train_monomer_demo.sh
index 421006b..e907009 100755
--- a/train_monomer_demo.sh
+++ b/train_monomer_demo.sh
@@ -2,15 +2,16 @@
 [ -z "${n_gpu}" ] && n_gpu=$(nvidia-smi -L | wc -l)
 export NCCL_ASYNC_ERROR_HANDLING=1
 export OMP_NUM_THREADS=1
-mkdir -p $1
+mkdir -p temp
 tmp_dir=`mktemp -d`
-python -m torch.distributed.launch --nproc_per_node=$n_gpu --master_port $MASTER_PORT $(which unicore-train) ./example_data/ --user-dir unifold \
+# --tensorboard-logdir $1/tsb/
+torchrun --nproc_per_node=$n_gpu --master_port $MASTER_PORT $(which unicore-train) ./example_data/ --user-dir unifold \
        --num-workers 8 --ddp-backend=no_c10d \
        --task af2 --loss af2 --arch af2 \
        --optimizer adam --adam-betas '(0.9, 0.999)' --adam-eps 1e-6 --clip-norm 0.0 --per-sample-clip-norm 0.1 --allreduce-fp32-grad \
-       --lr-scheduler exponential_decay --lr 1e-3 --warmup-updates 1000 --decay-ratio 0.95 --decay-steps 50000 --batch-size 1 \
-       --update-freq 1 --seed 42 --tensorboard-logdir $1/tsb/ \
-       --max-update 1000 --max-epoch 1 --log-interval 10 --log-format simple \
+       --lr-scheduler exponential_decay --lr 1e-3 --warmup-updates 10 --decay-ratio 0.95 --decay-steps 50000 --batch-size 1 \
+       --update-freq 1 --seed 42 \
+       --max-update 1000 --max-epoch 1 --log-interval 1 --log-format tqdm \
        --save-interval-updates 100 --validate-interval-updates 100 --keep-interval-updates 5 --no-epoch-checkpoints \
-       --save-dir $1 --tmp-save-dir $tmp_dir --required-batch-size-multiple 1 --ema-decay 0.999 --bf16 --bf16-sr # for V100 or older GPUs, you can disable --bf16 for faster speed.
+       --save-dir temp --tmp-save-dir $tmp_dir --required-batch-size-multiple 1 --ema-decay 0.999 --fp16 # --bf16 --bf16-sr # for V100 or older GPUs, you can disable --bf16 for faster speed.
 rm -rf $tmp_dir
\ No newline at end of file
diff --git a/unifold/modules/attentions.py b/unifold/modules/attentions.py
index e44ac38..1647dca 100644
--- a/unifold/modules/attentions.py
+++ b/unifold/modules/attentions.py
@@ -9,13 +9,25 @@
 from unicore.modules import (
     softmax_dropout,
     LayerNorm,
+    attention
 )
+try:
+    import trident as td
+    FLASHV2_AVAILABLE = True
+except ImportError:
+    print(f"Could not import trident. Disabling default flash attention")
+    FLASHV2_AVAILABLE = False
 
 
-def gen_attn_mask(mask, neg_inf):
+
+TORCH_FLASH = True
+
+
+def gen_attn_mask(mask: torch.Tensor, neg_inf: float) -> torch.Tensor:
+    """ Masked fill with neg value to mask attn scores. """
     assert neg_inf < -1e4
-    attn_mask = torch.zeros_like(mask)
-    attn_mask[mask == 0] = neg_inf
+    attn_mask = torch.full_like(mask, fill_value=neg_inf)
+    attn_mask.masked_fill_(mask.bool(), 0.)
     return attn_mask
 
 
@@ -32,6 +44,7 @@ def __init__(
         super(Attention, self).__init__()
 
         self.num_heads = num_heads
+        self.head_dim = head_dim
         total_dim = head_dim * self.num_heads
         self.gating = gating
         self.linear_q = Linear(q_dim, total_dim, bias=False, init="glorot")
@@ -52,34 +65,81 @@ def forward(
         mask: torch.Tensor = None,
         bias: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
+        """ Notation:
+            - b: batch
+            - n: num_seqs (or auxiliary batch dims, sometimes (n l))
+            - h: head
+            - i: length (of q)
+            - j: length (of k and/or v)
+            - d: hidden dim
+            - d_h: head_dim
+            Inputs:
+            * q, k, v: (b, n, i, d) torch.Tensor
+            * mask: (b, n, (h=1), (i=1), j) torch.Tensor float mode (w/ -inf)
+            * bias: (b, (n=1), h, i, j) torch.Tensor
+        """
         g = None
         if self.linear_g is not None:
             # gating, use raw query input
             g = self.linear_g(q)
 
+        # (b, n, i, d) -> (b, n, i, (h dim_head))
         q = self.linear_q(q)
-        q *= self.norm
         k = self.linear_k(k)
         v = self.linear_v(v)
 
-        q = q.view(q.shape[:-1] + (self.num_heads, -1)).transpose(-2, -3).contiguous()
-        k = k.view(k.shape[:-1] + (self.num_heads, -1)).transpose(-2, -3).contiguous()
-        v = v.view(v.shape[:-1] + (self.num_heads, -1)).transpose(-2, -3)
-
-        attn = torch.matmul(q, k.transpose(-1, -2))
-        del q, k
-
-        attn = softmax_dropout(attn, 0, self.training, mask=mask, bias=bias)
-        o = torch.matmul(attn, v)
-        del attn, v
+        batch_dims = q.shape[:-2]
+
+        if bias is None and TORCH_FLASH:
+            # (b, n, i, (h dim_head)) -> (b, n, h, i, dim_head)
+            q = q.view(*q.shape[:-1], self.num_heads, self.head_dim).transpose(-2, -3).contiguous()
+            k = k.view(*k.shape[:-1], self.num_heads, self.head_dim).transpose(-2, -3).contiguous()
+            v = v.view(*v.shape[:-1], self.num_heads, self.head_dim).transpose(-2, -3)
+            # (b, n, h, i, dim_head), (b, n, h, j, dim_head) -> (b, n, h, i, dim_head)
+            o = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=mask)
+            del q, k, v
+
+        # custom kernel for pair bias
+        elif FLASHV2_AVAILABLE and self.head_dim >= 16:
+            # (mask, bias) -> (b, n, h, i, j) -> ((b n), h, i, j)
+            new_bias = mask + bias
+            new_bias.clamp_(min=torch.finfo(new_bias.dtype).min)
+            new_bias = new_bias.view(-1, *new_bias.shape[-3:]).contiguous()
+
+            # (b, n, i, (h dim_head)) -> ((b n), h, i, dim_head)
+            q = q.view(-1, q.shape[-2], self.num_heads, self.head_dim).transpose(-2, -3).contiguous()
+            k = k.view(-1, k.shape[-2], self.num_heads, self.head_dim).transpose(-2, -3).contiguous()
+            v = v.view(-1, v.shape[-2], self.num_heads, self.head_dim).transpose(-2, -3)
+
+            # ((b n), h, i, dim_head) -> ((b n), h, i, dim_head) -> (b, n, h, i, dim_head) ; torch for bool mask only
+            o = td.function.scaled_dot_product_attention(q, k, v, attn_mask=new_bias)
+            o = o.view(*batch_dims, *o.shape[-3:])
+            del q, k, v
+        else:
+            # (b, n, i, (h dim_head)) -> (b, n, h, i, dim_head)
+            q = q.view(q.shape[:-1] + (self.num_heads, -1)).transpose(-2, -3).contiguous()
+            k = k.view(k.shape[:-1] + (self.num_heads, -1)).transpose(-2, -3).contiguous()
+            v = v.view(v.shape[:-1] + (self.num_heads, -1)).transpose(-2, -3)
+
+            q *= self.norm
+            # (b, n, h, i, dim_head), (b, n, h, dim_head, j) -> (b, n, h, i, j)
+            attn = torch.matmul(q, k.transpose(-1, -2))
+            del q, k
+
+            # (b, n, h, i, j) -> (b, n, h, i, j), (b, n, h, j, dim_head) -> (b, n, h, i, dim_head)
+            attn = softmax_dropout(attn, 0, self.training, mask=mask, bias=bias)
+            o = torch.matmul(attn, v)
+            del attn, v
+
+        # (b, n, h, i, dim_head) -> (b, n, i, h, dim_head) -> (b, n, i, (h dim_head))
         o = o.transpose(-2, -3).contiguous()
         o = o.view(*o.shape[:-2], -1)
 
         if g is not None:
             o = torch.sigmoid(g) * o
 
-        # merge heads
+        # (b, n, i, (h dim_head)) -> (b, n, i, d)
         o = nn.functional.linear(o, self.linear_o.weight)
         return o
 
 
@@ -92,6 +152,7 @@ def __init__(self, input_dim, head_dim, num_heads, inf, eps):
         super(GlobalAttention, self).__init__()
 
         self.num_heads = num_heads
+        self.head_dim = head_dim
         self.inf = inf
         self.eps = eps
         self.linear_q = Linear(
@@ -106,37 +167,62 @@ def __init__(self, input_dim, head_dim, num_heads, inf, eps):
         self.norm = head_dim**-0.5
 
     def forward(self, x: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
-
-        # gating
+        """ Notation:
+            - b: batch
+            - l: length
+            - d: input_dim / hidden dim depending on context
+            - h: num_heads
+            - s: num_sequences
+            Inputs:
+            * x: (b, l, s, d) torch.Tensor
+            * mask: (b, l, s) torch.Tensor
+            Outputs: (b, l, s, (h dim_head)) torch.Tensor
+        """
+        # (b, l, s, d) -> (b, l, s, (h dim_head))  # gating
        g = self.sigmoid(self.linear_g(x))
+        # (b, l, s, d) -> (b, l, s, dim_head)
         k = self.linear_k(x)
         v = self.linear_v(x)
-
+        # (b, l, s, d), (b, l, s, 1) -> (b, l, d), (b, l, 1) -> (b, l, d)
         q = torch.sum(x * mask.unsqueeze(-1), dim=-2) / (
-            torch.sum(mask, dim=-1, keepdims=True) + self.eps
+            torch.sum(mask, dim=-1, keepdim=True).add_(self.eps)
         )
+        # (b, l, d) -> (b, l, (h dim_head))
         q = self.linear_q(q)
-        q *= self.norm
+        # (b, l, (h dim_head)) -> (b, l, h, dim_head)
         q = q.view(q.shape[:-1] + (self.num_heads, -1))
 
-        attn = torch.matmul(q, k.transpose(-1, -2))
-        del q, k
+        if TORCH_FLASH:
+            # (b, l, h, i=1, dim_head), (b, l, h=1, s, dim_head) -> (b, l, h, dim_head)
+            o = torch.nn.functional.scaled_dot_product_attention(
+                q.unsqueeze_(-2),
+                k.unsqueeze_(-3),
+                v.unsqueeze_(-3),
+                # (b, l, s) -> (b, l, h=1, i=1, s)
+                attn_mask=mask[..., None, None, :].bool()
+            ).squeeze_(-2)
+            del q, k, v
+        else:
+            q *= self.norm
+            # (b, l, h, dim_head), (b, l, s, dim_head) -> (b, l, h, s)
+            attn = torch.matmul(q, k.transpose(-1, -2))
+            del q, k
 
-        attn_mask = gen_attn_mask(mask, -self.inf)[..., :, None, :]
-        attn = softmax_dropout(attn, 0, self.training, mask=attn_mask)
+            attn_mask = gen_attn_mask(mask, -self.inf)[..., :, None, :]
+            attn = softmax_dropout(attn, 0, self.training, mask=attn_mask)
 
-        o = torch.matmul(
-            attn,
-            v,
-        )
-        del attn, v
+            # (b, l, h, s), (b, l, s, dim_head) -> (b, l, h, dim_head)
+            o = torch.matmul(attn, v)
+            del attn, v
 
+        # (b, l, s, (h dim_head)) -> (b, l, s, h, dim_head)
         g = g.view(g.shape[:-1] + (self.num_heads, -1))
+        # (b, l, h, dim_head), (b, l, s, h, dim_head) -> (b, l, s, h, dim_head)
         o = o.unsqueeze(-3) * g
         del g
 
-        # merge heads
+        # (b, l, s, h, dim_head) -> (b, l, s, (h dim_head))  # merge heads
         o = o.reshape(o.shape[:-2] + (-1,))
         return self.linear_o(o)
 
 
@@ -221,7 +307,6 @@ def forward(
         attn_mask: Optional[torch.Tensor] = None,
         chunk_size: Optional[int] = None,
     ) -> torch.Tensor:
-
         bias = None
         if self.pair_bias:
             z = self.layer_norm_z(z)
@@ -327,7 +412,6 @@ def forward(
         mask: Optional[torch.Tensor] = None,
         chunk_size: Optional[int] = None,
     ) -> torch.Tensor:
-
         m = m.transpose(-2, -3)
         mask = mask.transpose(-1, -2)