Add simple training script #74

Open
wants to merge 10 commits into base: main
4 changes: 4 additions & 0 deletions .gitignore
@@ -138,3 +138,7 @@ dmypy.json

# Pyre type checker
.pyre/

.idea/
runs/
/model/
102 changes: 102 additions & 0 deletions alphafold2_pytorch/transformer.py
@@ -0,0 +1,102 @@
######################################################################
# Transformer!
# ------------
#
# The Transformer is a Seq2Seq model introduced in the `“Attention is all you
# need” <https://papers.nips.cc/paper/2017/file/3f5ee243547dee91fbd053c1c4a845aa-Paper.pdf>`__
# paper for solving machine translation tasks. The Transformer model consists
# of an encoder and a decoder block, each containing a fixed number of layers.
#
# The encoder processes the input sequence by propagating it through a series
# of multi-head attention and feed-forward network layers. The output from
# the encoder, referred to as ``memory``, is fed to the decoder along with
# the target tensors. The encoder and decoder are trained in an end-to-end
# fashion using the teacher forcing technique.
#

import math
import torch
from torch import nn
from torch import Tensor
from torch.nn import (TransformerEncoder, TransformerDecoder,
                      TransformerEncoderLayer, TransformerDecoderLayer)


class Seq2SeqTransformer(nn.Module):
    def __init__(self, num_encoder_layers: int, num_decoder_layers: int,
                 emb_size: int, src_vocab_size: int, tgt_vocab_size: int,
                 dim_feedforward: int = 512, num_head: int = 8, dropout: float = 0.0, activation: str = "relu",
                 max_len: int = 5000):
        super(Seq2SeqTransformer, self).__init__()
        encoder_layer = TransformerEncoderLayer(d_model=emb_size, nhead=num_head,
                                                dim_feedforward=dim_feedforward, dropout=dropout, activation=activation)
        self.transformer_encoder = TransformerEncoder(encoder_layer, num_layers=num_encoder_layers)
        decoder_layer = TransformerDecoderLayer(d_model=emb_size, nhead=num_head,
                                                dim_feedforward=dim_feedforward, dropout=dropout, activation=activation)
        self.transformer_decoder = TransformerDecoder(decoder_layer, num_layers=num_decoder_layers)

        self.generator = nn.Linear(emb_size, tgt_vocab_size)
        self.src_tok_emb = TokenEmbedding(src_vocab_size, emb_size)
        # the target-side embedding must use the target vocabulary size
        self.tgt_tok_emb = TokenEmbedding(tgt_vocab_size, emb_size)
        self.positional_encoding = PositionalEncoding(emb_size, dropout=dropout, maxlen=max_len)

    # todo make mask work: src_mask/tgt_mask are accepted but not yet passed through
    def forward(self, src: Tensor, trg: Tensor, src_mask: Tensor,
                tgt_mask: Tensor, src_padding_mask: Tensor,
                tgt_padding_mask: Tensor, memory_key_padding_mask: Tensor,
                use_padding_mask: bool = True):
        # embed the tokens and add positional information
        src_emb = self.positional_encoding(self.src_tok_emb(src))
        tgt_emb = self.positional_encoding(self.tgt_tok_emb(trg))
        if use_padding_mask:
            memory = self.transformer_encoder(src_emb, src_key_padding_mask=src_padding_mask)
            outs = self.transformer_decoder(tgt_emb, memory, tgt_key_padding_mask=tgt_padding_mask)
        else:
            memory = self.transformer_encoder(src_emb)
            outs = self.transformer_decoder(tgt_emb, memory)
        # project decoder outputs onto the target vocabulary
        return self.generator(outs)

    def encode(self, src: Tensor, src_mask: Tensor):
        # apply positional encoding here as well, matching forward()
        return self.transformer_encoder(self.positional_encoding(
            self.src_tok_emb(src)), src_mask)

    def decode(self, tgt: Tensor, memory: Tensor, tgt_mask: Tensor):
        return self.transformer_decoder(self.positional_encoding(
            self.tgt_tok_emb(tgt)), memory,
            tgt_mask)
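

######################################################################
# The ``src_mask``/``tgt_mask`` arguments above are accepted but not yet
# wired through (see the "todo make mask work" note). As a sketch only, not
# part of the original PR: the causal target mask for ``nn.Transformer``-style
# modules is conventionally built like this.

def generate_square_subsequent_mask(size: int) -> Tensor:
    # Positions above the diagonal (future tokens) are set to -inf so that
    # attention cannot look ahead; allowed positions stay 0.0.
    future = torch.triu(torch.ones(size, size), diagonal=1).bool()
    return torch.zeros(size, size).masked_fill(future, float('-inf'))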


######################################################################
# Text tokens are represented using token embeddings. A positional
# encoding is added to the token embedding to introduce a notion of word
# order.
#
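# For reference, the sinusoidal scheme implemented below follows the original
# paper: for position ``pos`` and even/odd embedding dimensions ``2i``/``2i+1``,
#
#   PE(pos, 2i)   = sin(pos / 10000^(2i / emb_size))
#   PE(pos, 2i+1) = cos(pos / 10000^(2i / emb_size))
#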

class PositionalEncoding(nn.Module):
    def __init__(self, emb_size: int, dropout: float, maxlen: int = 5000):
        super(PositionalEncoding, self).__init__()
        # precompute the sinusoidal position embeddings up to maxlen
        den = torch.exp(- torch.arange(0, emb_size, 2) * math.log(10000) / emb_size)
        pos = torch.arange(0, maxlen).reshape(maxlen, 1)
        pos_embedding = torch.zeros((maxlen, emb_size))
        pos_embedding[:, 0::2] = torch.sin(pos * den)
        pos_embedding[:, 1::2] = torch.cos(pos * den)
        pos_embedding = pos_embedding.unsqueeze(-2)

        self.dropout = nn.Dropout(dropout)
        self.register_buffer('pos_embedding', pos_embedding)

    def forward(self, token_embedding: Tensor):
        # add the fixed position embedding for the first seq_len positions
        return self.dropout(token_embedding +
                            self.pos_embedding[:token_embedding.size(0), :])


class TokenEmbedding(nn.Module):
    def __init__(self, vocab_size: int, emb_size: int):
        super(TokenEmbedding, self).__init__()
        self.embedding = nn.Embedding(vocab_size, emb_size)
        self.emb_size = emb_size

    def forward(self, tokens: Tensor):
        # scale embeddings by sqrt(emb_size), as in the original Transformer paper
        return self.embedding(tokens.long()) * math.sqrt(self.emb_size)
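

######################################################################
# A minimal usage sketch (not part of the original PR): it instantiates the
# Seq2SeqTransformer defined above and runs one teacher-forced forward pass,
# feeding the target sequence shifted right by one position. The vocabulary
# sizes, sequence lengths, and batch size below are arbitrary placeholders.

if __name__ == "__main__":
    model = Seq2SeqTransformer(num_encoder_layers=2, num_decoder_layers=2,
                               emb_size=64, src_vocab_size=100, tgt_vocab_size=100,
                               dim_feedforward=128, num_head=4)
    src = torch.randint(0, 100, (12, 3))   # (src_seq_len, batch)
    tgt = torch.randint(0, 100, (10, 3))   # (tgt_seq_len, batch)
    tgt_input = tgt[:-1, :]                # teacher forcing: target shifted right
    logits = model(src, tgt_input,
                   src_mask=None, tgt_mask=None,
                   src_padding_mask=None, tgt_padding_mask=None,
                   memory_key_padding_mask=None,
                   use_padding_mask=False)
    print(logits.shape)                    # (tgt_seq_len - 1, batch, tgt_vocab_size)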



4 changes: 3 additions & 1 deletion setup.py
@@ -29,7 +29,9 @@
'transformers',
'tqdm',
'biopython',
'mp-nerf>=0.1.5'
'mp-nerf>=0.1.5',
'tensorboard',
'matplotlib'
],
setup_requires=[
'pytest-runner',
7 changes: 4 additions & 3 deletions train_end2end.py
@@ -1,13 +1,14 @@
import torch
from torch.optim import Adam
from torch import nn
from torch.utils.data import DataLoader
import torch.nn.functional as F
from einops import rearrange

# data

import sidechainnet as scn
from sidechainnet.sequence.utils import VOCAB
# from sidechainnet.sequence.utils import VOCAB
from sidechainnet.structure.build_info import NUM_COORDS_PER_RES

# models
@@ -108,11 +109,11 @@ def cycle(loader, cond = lambda x: True):
# mask the atoms and backbone positions for each residue

# sequence embedding (msa / esm / attn / or nothing)
msa, embedds = None
msa, embedds = None, None

# get embedds
if FEATURES == "esm":
embedds = get_esm_embedd(seq, embedd_model, batch_converter)
embedds = get_esm_embedd(seq, embedd_model, batch_converter, device=DEVICE)
# get msa here
elif FEATURES == "msa":
pass
Expand Down