In [32]:
# Successfully installed
# huggingface-hub-0.29.3
# regex-2024.11.6 
# safetensors-0.5.3 
# tokenizers-0.21.1 
# transformers-4.50.1

import torch
from torch import Tensor
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import CrossEntropyLoss, MSELoss, BCEWithLogitsLoss

from transformers import (
    AutoConfig,
    AutoModel,
    AutoModelForMaskedLM,
    AutoModelForSequenceClassification,
    AutoModelForTokenClassification,
    PretrainedConfig,
    PreTrainedModel,
)

from transformers.activations import ACT2FN

from transformers.modeling_outputs import (
    BaseModelOutput,
    MaskedLMOutput,
    SequenceClassifierOutput,
    TokenClassifierOutput,
)
from typing import Optional, Tuple, Union

from transformers import RoFormerConfig

from transformers.models.roformer.modeling_roformer import (
    RoFormerEncoder,
    RoFormerSinusoidalPositionalEmbedding,
)

from bsqdna.utils import *
# #from .modules import ByteNetEncoder, ConvNetEncoder, MLP, CNN

In [6]:
# Scratch shape check for the patchify / unpatchify layers.
# Random dummy batch of shape (8, 4, 4096). NOTE(review): presumably
# (batch, channels, sequence_length) — confirm against PatchifyLinear.
inputs = torch.rand(8, 4, 4096)

In [24]:
# PatchifyLinear / UnPatchifyLinear are not defined in this notebook;
# presumably they come from `bsqdna.utils` via the star import. Judging by the
# recorded outputs below, the first arg (8) looks like the patch size and the
# second (128) the per-patch feature dim — confirm against their definitions.
pat = PatchifyLinear(8, 128)

In [25]:
# Presumably the inverse mapping of `pat`, built with matching arguments.
unpat = UnPatchifyLinear(8, 128)

In [26]:
a = pat(inputs)

In [27]:
# `unpat` is a separately constructed module, so only shapes are checked
# below; values are not compared.
b = unpat(a)

In [28]:
# Round trip restores the input shape.
b.shape

torch.Size([8, 4, 4096])

In [29]:
# 4096 / 8 = 512 patches, each with 128 features.
a.shape

torch.Size([8, 512, 128])

In [13]:
# Same construction as cell In [24] (this cell has an earlier execution count);
# the shape below matches `a.shape` above.
pat = PatchifyLinear(8, 128)

In [14]:
pat(inputs).shape

torch.Size([8, 512, 128])

In [None]:

class BSQDNAmodel(PreTrainedModel):
    """Encoder-only backbone: configurable input embedding -> RoFormer encoder -> LayerNorm.

    The embedding module is chosen by name via ``EMBEDDING_CLASS[config.embedding]``;
    the encoder is a stock ``RoFormerEncoder``; ``ln_f`` normalizes the encoder's
    last hidden state before it is returned.
    """

    def __init__(self, config):
        super().__init__(config)
        self.config = config
        # NOTE(review): EMBEDDING_CLASS is not defined in this notebook; presumably
        # it comes from `bsqdna.utils` via the star import — confirm.
        self.embeddings = EMBEDDING_CLASS[config.embedding](config)
        self.encoder = RoFormerEncoder(config)
        # The `bias` keyword of nn.LayerNorm requires torch >= 2.1.
        self.ln_f = nn.LayerNorm(config.hidden_size, bias=config.bias)

        # Initialize weights and apply final processing; post_init() routes each
        # submodule through self._init_weights below.
        self.post_init()

    def _init_weights(self, module):
        """Per-module initialization hook invoked by ``post_init``.

        NOTE: in transformers 4.50 (pinned at the top of this notebook) the base
        ``PreTrainedModel._init_weights`` is a no-op, and the rotary table of
        ``RoFormerSinusoidalPositionalEmbedding`` (``self.encoder.embed_positions``)
        is only filled with sin/cos values when the owning model's hook calls
        ``_init_weight()`` — as ``RoFormerPreTrainedModel`` does. Without this,
        the table keeps ``nn.Embedding``'s random init and stays trainable.
        Re-verify if the transformers version changes.

        All other modules are intentionally left at their PyTorch defaults,
        matching the previous behaviour of this class.
        """
        if isinstance(module, RoFormerSinusoidalPositionalEmbedding):
            module._init_weight()

    def forward(self, input_ids=None, input_probs=None, aux_features=None, **kwargs):
        """Embed the inputs, run the RoFormer encoder, apply the final LayerNorm.

        Args:
            input_ids, input_probs, aux_features: passed unchanged to the
                embedding module; which ones are required depends on the class
                selected by ``config.embedding``.
            **kwargs: accepted for Hugging Face API compatibility but ignored.
                NOTE(review): this includes ``attention_mask`` — the encoder is
                called without any mask.

        Returns:
            ``BaseModelOutput`` whose ``last_hidden_state`` is the LayerNorm-ed
            encoder output; ``hidden_states`` and ``attentions`` are not set.
        """
        embedded = self.embeddings(
            input_ids=input_ids, input_probs=input_probs, aux_features=aux_features
        )
        encoded = self.encoder(embedded)

        # TODO: the final LayerNorm should be optional (e.g. behind a config flag).
        normalized = self.ln_f(encoded.last_hidden_state)
        return BaseModelOutput(last_hidden_state=normalized)

    def get_input_embeddings(self):
        """Return the token embedding table (assumes the embedding module exposes ``word_embeddings``)."""
        return self.embeddings.word_embeddings

    def set_input_embeddings(self, value):
        """Replace the token embedding table on the embedding module."""
        self.embeddings.word_embeddings = value
