In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F

from transformers import PreTrainedTokenizerFast

# Toy binary mask (2 rows x 10 positions) used to try out mask down-sampling.
tensor = torch.tensor(
    [
        [1, 1, 1, 0, 1, 1, 1, 1, 0, 0],
        [1, 1, 1, 0, 1, 1, 0, 0, 0, 0],
    ],
    dtype=torch.int64,
)
print(tensor)

# Min-pooling expressed through max-pooling: -max(-a, -b) == min(a, b).
# With window 2 / stride 2, an output entry is 1 only when BOTH inputs in the
# window are 1. The detour through float32 is kept from the original
# (presumably because max_pool1d does not accept integer tensors here).
negated = -tensor.to(torch.float32)
pooled = F.max_pool1d(negated, kernel_size=2, stride=2, padding=0)
tensor = -pooled.to(torch.int64)

print(tensor)

tensor([[1, 1, 1, 0, 1, 1, 1, 1, 0, 0],
        [1, 1, 1, 0, 1, 1, 0, 0, 0, 0]])
tensor([[1, 0, 1, 1, 0],
        [1, 0, 1, 0, 0]])


  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# NOTE(review): absolute, machine-specific path; the notebook only runs as-is
# on a machine where this directory exists.
tokenizer_dir = '/home/kkj/ProtDiffusion/ProtDiffusion/tokenizer/tokenizer_v4.1'
tokenizer = PreTrainedTokenizerFast.from_pretrained(tokenizer_dir)

In [3]:
# Hand-written example: a bos/eos-wrapped sequence with '-' pad characters
# already placed on both sides.
example_sequence = '-[ACDFGDIGDE]---'

tokenized = tokenizer(
    example_sequence,
    padding=True,
    # No truncation here: sequences are truncated beforehand.
    truncation=False,
    return_token_type_ids=False,
    # Ask for the attention mask. For this input it comes back as all ones,
    # i.e. the '-' characters that are part of the string are attended to.
    return_attention_mask=True,
    return_tensors="pt",
)
print(tokenized)

{'input_ids': tensor([[ 2, 23,  3,  4,  5,  7,  8,  5, 10,  8,  5,  6, 24,  2,  2,  2]]), 'attention_mask': tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]])}


In [4]:
# Check which dtype the tokenizer uses for the attention mask.
mask_dtype = tokenized['attention_mask'].dtype
print(mask_dtype)

torch.int64


In [5]:
import random

# Sanity check: randint's bounds are inclusive, so equal bounds are valid and
# always return that bound (the case hit by process_sequence when len_diff == 0).
random.randint(a=0, b=0)

0

In [6]:
import numpy as np

def round_length(length: int, pad: int = 2, rounding: int = 16) -> int:
    '''
    Round ``length + pad`` UP to the next multiple of ``rounding``.

    Args:
        length: Length of the raw sequence.
        pad: Extra positions added to ``length`` before rounding. The default
            of 2 leaves room for one single-character bos and eos token.
        rounding: The multiple to round up to (16 by default).

    Returns:
        The smallest multiple of ``rounding`` that is greater than or equal to
        ``length + pad`` (for a positive ``rounding``). A value that is already
        a multiple is returned unchanged.

    Raises:
        ZeroDivisionError: If ``rounding`` is 0.
    '''
    # Exact integer ceiling division: ceil(a / b) == -(-a // b). This avoids the
    # float round-trip of np.ceil(a / b), which loses precision for very large
    # values, and needs no numpy.
    total = length + pad
    return int(-(-total // rounding) * rounding)

def process_sequence(sequence: str,
                     bos_token: str = "[",
                     eos_token: str = "]",
                     pad_token: str = "-",
) -> str:
    '''
    Wrap the sequence in bos/eos tokens and pad it to a rounded length.

    The target length comes from round_length (a multiple of 16 with its
    default ``rounding``). The padding is split at a random position, so the
    wrapped sequence lands at a random offset inside the padded string.

    Args:
        sequence: The raw sequence to process.
        bos_token: Token prepended to the sequence.
        eos_token: Token appended to the sequence.
        pad_token: Padding unit, repeated on both sides. Assumed to be a single
            character; a longer pad_token would overshoot the target length.

    Returns:
        The padded string only (the length is not returned; use ``len`` on the
        result if needed).
    '''
    # Reserve room for the actual bos/eos tokens instead of relying on
    # round_length's default pad of 2; otherwise multi-character tokens make
    # len_diff negative and random.randint raises ValueError.
    seq_len = round_length(len(sequence), pad=len(bos_token) + len(eos_token))
    sequence = bos_token + sequence + eos_token
    len_diff = seq_len - len(sequence)
    # Inclusive bounds: rand_int pads go on the left, the remainder on the right.
    rand_int = random.randint(0, len_diff)
    sequence = pad_token * rand_int + sequence + pad_token * (len_diff - rand_int)

    return sequence

In [7]:
# 14 residues + bos + eos is exactly 16 characters, so no padding is added and
# the result is deterministic despite the random split.
demo_sequence = 'ACDFGDIGDEIJGH'
process_sequence(demo_sequence)

'[ACDFGDIGDEIJGH]'

In [8]:
from models.dit_transformer_1d import DiTTransformer1DModel
from training_utils import count_parameters

# Constructor arguments for the 1D DiT model, collected in one place so they
# are easy to inspect and tweak.
dit_kwargs = dict(
    num_attention_heads=16,
    attention_head_dim=72,  # 16 heads x 72 = 1152, the width shown in the module printout
    in_channels=64,
    num_layers=8,
    attention_bias=True,
    activation_fn="gelu-approximate",
    num_classes=2,
    upcast_attention=False,
    norm_type="ada_norm_zero",
    norm_elementwise_affine=False,
    norm_eps=1e-5,
    pos_embed_type="sinusoidal",
    num_positional_embeddings=1024,
    # RoPE: https://github.com/lucidrains/rotary-embedding-torch
    use_rope_embed=True,
)

# Build the model on the GPU, report its size, and switch to training mode.
# model.train() stays the last expression so the notebook displays the module.
model = DiTTransformer1DModel(**dit_kwargs).to('cuda')
count_parameters(model)
model.train()

Using Sinusoidal Positional Embeddings
num_positional_embeddings:  1024
Using RoPE
RoPE dim:  72
Model has 207216064 trainable parameters


DiTTransformer1DModel(
  (conv_in): Conv1d(64, 1152, kernel_size=(3,), stride=(1,), padding=(1,))
  (rotary_emb): RotaryEmbedding()
  (transformer_blocks): ModuleList(
    (0-7): 8 x BasicTransformerBlock1D(
      (pos_embed): SinusoidalPositionalEmbedding()
      (norm1): AdaLayerNormZero(
        (emb): CombinedTimestepLabelEmbeddings(
          (time_proj): Timesteps()
          (timestep_embedder): TimestepEmbedding(
            (linear_1): Linear(in_features=256, out_features=1152, bias=True)
            (act): SiLU()
            (linear_2): Linear(in_features=1152, out_features=1152, bias=True)
          )
          (class_embedder): LabelEmbedding(
            (embedding_table): Embedding(3, 1152)
          )
        )
        (silu): SiLU()
        (linear): Linear(in_features=1152, out_features=6912, bias=True)
        (norm): LayerNorm((1152,), eps=1e-06, elementwise_affine=False)
      )
      (attn1): Attention(
        (to_q): Linear(in_features=1152, out_features=1152, bi

In [9]:
# Smoke test: push one random batch through the model and check the output shape.
batch_size, n_channels, seq_len = 16, 64, 1024

# Random input with 64 channels (the model's in_channels) and 1024 positions.
x = torch.randn(batch_size, n_channels, seq_len).to('cuda')
# Random boolean mask, one entry per position.
m = torch.randint(0, 2, (batch_size, seq_len), dtype=torch.bool).to('cuda')
# Timesteps drawn from [0, 1000). Open question from the author: is any int valid?
t = torch.randint(0, 1000, (batch_size,), dtype=torch.int64).to('cuda')
# Class labels: 0 and 1 are the only valid labels, 2 is a dropped label.
cl = torch.randint(0, 3, (batch_size,), dtype=torch.int64).to('cuda')

out = model(x, m, t, cl)
print(out.sample.shape)

torch.Size([16, 64, 1024])
