In [3]:
pip install datasets

Note: you may need to restart the kernel to use updated packages.



[notice] A new release of pip is available: 24.1.1 -> 24.3.1
[notice] To update, run: python.exe -m pip install --upgrade pip


In [7]:
pip install transformers

Collecting transformers
  Downloading transformers-4.46.3-py3-none-any.whl.metadata (44 kB)
     ---------------------------------------- 0.0/44.1 kB ? eta -:--:--
     ---------------------------------------- 44.1/44.1 kB 2.1 MB/s eta 0:00:00
Collecting tokenizers<0.21,>=0.20 (from transformers)
  Downloading tokenizers-0.20.3-cp312-none-win_amd64.whl.metadata (6.9 kB)
Collecting safetensors>=0.4.1 (from transformers)
  Downloading safetensors-0.4.5-cp312-none-win_amd64.whl.metadata (3.9 kB)
Downloading transformers-4.46.3-py3-none-any.whl (10.0 MB)
   ---------------------------------------- 0.0/10.0 MB ? eta -:--:--
   - -------------------------------------- 0.3/10.0 MB 7.9 MB/s eta 0:00:02
   -- ------------------------------------- 0.6/10.0 MB 7.4 MB/s eta 0:00:02
   -------- ------------------------------- 2.1/10.0 MB 18.9 MB/s eta 0:00:01
   ------------------- -------------------- 4.8/10.0 MB 27.8 MB/s eta 0:00:01
   ---------------------------------- ----- 8.6/10.0 MB 42.3 MB


[notice] A new release of pip is available: 24.1.1 -> 24.3.1
[notice] To update, run: python.exe -m pip install --upgrade pip


In [5]:
import torch
from torch import nn
from transformers import BertConfig
from torchdiffeq import odeint


class AttentionODEFunc(nn.Module):
    def __init__(self, head_dim):
        super(AttentionODEFunc, self).__init__()
        self.net = nn.Sequential(
            nn.Linear(head_dim, head_dim),
            nn.ReLU(),
            nn.Linear(head_dim, head_dim)
        )

    def forward(self, t, attention_scores):
        return self.net(attention_scores)


class ODEAttention(nn.Module):
    def __init__(self, hidden_dim, num_heads=8, time_steps=10):
        super(ODEAttention, self).__init__()
        self.num_heads = num_heads
        self.hidden_dim = hidden_dim
        self.head_dim = hidden_dim // num_heads
        self.time_steps = time_steps
        self.scaling_factor = self.head_dim ** 0.5

        self.ode_func = AttentionODEFunc(self.head_dim)

    def forward(self, Q, K, V):
        batch_size, seq_length, _ = Q.size()
        assert Q.size(-1) == self.hidden_dim, f"Expected hidden size {self.hidden_dim}, got {Q.size(-1)}"

        Q = Q.view(batch_size, seq_length, self.num_heads, self.head_dim).permute(0, 2, 1, 3)
        K = K.view(batch_size, seq_length, self.num_heads, self.head_dim).permute(0, 2, 1, 3)
        V = V.view(batch_size, seq_length, self.num_heads, self.head_dim).permute(0, 2, 1, 3)

        Q = Q.reshape(-1, seq_length, self.head_dim)
        K = K.reshape(-1, seq_length, self.head_dim)
        V = V.reshape(-1, seq_length, self.head_dim)

        initial_scores = torch.matmul(Q, K.transpose(-2, -1)) / self.scaling_factor

        initial_scores = initial_scores.unsqueeze(-1)

        initial_scores = initial_scores.view(-1, self.head_dim)

        t_span = torch.linspace(0, 1, steps=self.time_steps).to(Q.device)

        evolved_scores = odeint(self.ode_func, initial_scores, t_span)

        evolved_scores = evolved_scores[-1].view(-1, seq_length, seq_length)

        batch_heads = evolved_scores.size(0)
        expected_size = batch_heads * seq_length * seq_length
        assert evolved_scores.numel() == expected_size, (
            f"Shape mismatch: expected {expected_size}, got {evolved_scores.numel()}"
        )

        attention_weights = torch.softmax(evolved_scores, dim=-1)

        output = torch.matmul(attention_weights, V)

        output = output.view(batch_size, self.num_heads, seq_length, self.head_dim).permute(0, 2, 1, 3)
        output = output.reshape(batch_size, seq_length, self.hidden_dim)
        return output, attention_weights


class ODEBertSelfAttention(nn.Module):
    def __init__(self, config):
        super(ODEBertSelfAttention, self).__init__()
        self.num_attention_heads = config.num_attention_heads
        self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
        self.all_head_size = self.num_attention_heads * self.attention_head_size

        self.query = nn.Linear(config.hidden_size, self.all_head_size)
        self.key = nn.Linear(config.hidden_size, self.all_head_size)
        self.value = nn.Linear(config.hidden_size, self.all_head_size)

        self.ode_attention = ODEAttention(config.hidden_size, config.num_attention_heads)
        self.dense = nn.Linear(config.hidden_size, config.hidden_size)
        self.dropout = nn.Dropout(config.attention_probs_dropout_prob)

    def forward(self, hidden_states):
        Q = self.query(hidden_states)
        K = self.key(hidden_states)
        V = self.value(hidden_states)
        context_layer, attention_weights = self.ode_attention(Q, K, V)
        context_layer = self.dense(context_layer)
        context_layer = self.dropout(context_layer)
        return context_layer, attention_weights


class ODEBertModel(nn.Module):
    def __init__(self, config):
        super(ODEBertModel, self).__init__()
        self.embeddings = nn.Embedding(config.vocab_size, config.hidden_size)
        self.encoder = nn.ModuleList(
            [ODEBertSelfAttention(config) for _ in range(config.num_hidden_layers)]
        )

    def forward(self, input_ids, attention_mask):
        embedded_inputs = self.embeddings(input_ids)
        hidden_states = embedded_inputs
        for layer in self.encoder:
            hidden_states, _ = layer(hidden_states)
        return {"last_hidden_state": hidden_states}


def test_self_attention():
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    config = BertConfig(hidden_size=768, num_attention_heads=12, attention_probs_dropout_prob=0.1)
    ode_self_attention = ODEBertSelfAttention(config).to(device)

    # Synthetic input
    batch_size = 4
    seq_length = 128
    hidden_states = torch.rand(batch_size, seq_length, config.hidden_size).to(device)

    # Forward pass
    context_layer, attention_weights = ode_self_attention(hidden_states)
    print(f"Context Layer Shape: {context_layer.shape}")
    print(f"Attention Weights Shape: {attention_weights.shape}")


def test_bert_model():
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    config = BertConfig(
        vocab_size=30522,
        hidden_size=768,
        num_attention_heads=12,
        num_hidden_layers=4,
        attention_probs_dropout_prob=0.1,
    )
    ode_bert_model = ODEBertModel(config).to(device)

    batch_size = 4
    seq_length = 128
    input_ids = torch.randint(0, config.vocab_size, (batch_size, seq_length)).to(device)
    attention_mask = torch.ones((batch_size, seq_length)).to(device)

    outputs = ode_bert_model(input_ids=input_ids, attention_mask=attention_mask)
    print(f"Last Hidden State Shape: {outputs['last_hidden_state'].shape}")


# testinggg
if __name__ == "__main__":
    print("Testing ODEBertSelfAttention...")
    test_self_attention()
    print("\nTesting ODEBertModel...")
    test_bert_model()


Testing ODEBertSelfAttention...
Context Layer Shape: torch.Size([4, 128, 768])
Attention Weights Shape: torch.Size([48, 128, 128])

Testing ODEBertModel...
Last Hidden State Shape: torch.Size([4, 128, 768])


In [6]:
class ODEBertForTextClassification(ODEBertModel):
    def __init__(self, config, num_labels):
        super().__init__(config)
        self.classifier = nn.Linear(config.hidden_size, num_labels)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

    def forward(self, input_ids, attention_mask, labels=None):
        outputs = super().forward(input_ids=input_ids, attention_mask=attention_mask)
        cls_output = outputs["last_hidden_state"][:, 0, :]
        cls_output = self.dropout(cls_output)
        logits = self.classifier(cls_output)
        loss = None
        if labels is not None:
            criterion = nn.CrossEntropyLoss()
            loss = criterion(logits, labels)
        return {"loss": loss, "logits": logits}

In [7]:
def test_text_classification():
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    config = BertConfig(
        vocab_size=30522,
        hidden_size=768,
        num_attention_heads=12,
        num_hidden_layers=4,
        attention_probs_dropout_prob=0.1,
        hidden_dropout_prob=0.1,
    )
    num_labels = 2  

    model = ODEBertForTextClassification(config, num_labels).to(device)

    batch_size = 4
    seq_length = 128
    input_ids = torch.randint(0, config.vocab_size, (batch_size, seq_length)).to(device)
    attention_mask = torch.ones((batch_size, seq_length)).to(device)
    labels = torch.randint(0, num_labels, (batch_size,)).to(device)  

    outputs = model(input_ids=input_ids, attention_mask=attention_mask, labels=labels)

    print(f"Loss: {outputs['loss']}")
    print(f"Logits Shape: {outputs['logits'].shape}") 


#  testing
if __name__ == "__main__":
    print("Testing ODEBertForTextClassification on GPU...")
    test_text_classification()

Testing ODEBertForTextClassification on GPU...
Loss: 0.6983002424240112
Logits Shape: torch.Size([4, 2])


In [8]:
from transformers import BertTokenizer, BertConfig


In [9]:
tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
config = BertConfig.from_pretrained("bert-base-uncased")

In [12]:
num_labels = 2  
model = ODEBertForTextClassification(config, num_labels)

In [23]:
from datasets import load_dataset
dataset = load_dataset("imdb")


README.md:   0%|          | 0.00/7.81k [00:00<?, ?B/s]

To support symlinks on Windows, you either need to activate Developer Mode or to run Python as an administrator. In order to activate developer mode, see this article: https://docs.microsoft.com/en-us/windows/apps/get-started/enable-your-device-for-development


train-00000-of-00001.parquet:   0%|          | 0.00/21.0M [00:00<?, ?B/s]

test-00000-of-00001.parquet:   0%|          | 0.00/20.5M [00:00<?, ?B/s]

unsupervised-00000-of-00001.parquet:   0%|          | 0.00/42.0M [00:00<?, ?B/s]

Generating train split:   0%|          | 0/25000 [00:00<?, ? examples/s]

Generating test split:   0%|          | 0/25000 [00:00<?, ? examples/s]

Generating unsupervised split:   0%|          | 0/50000 [00:00<?, ? examples/s]

In [27]:
def tokenize_function(examples):
    return tokenizer(examples["text"], padding="max_length", truncation=True, max_length=128)
tokenized_datasets = dataset.map(tokenize_function, batched=True)
tokenized_datasets = tokenized_datasets.rename_column("label", "labels")
tokenized_datasets.set_format(type="torch", columns=["input_ids", "attention_mask", "labels"])


Map:   0%|          | 0/25000 [00:00<?, ? examples/s]

Map:   0%|          | 0/25000 [00:00<?, ? examples/s]

Map:   0%|          | 0/50000 [00:00<?, ? examples/s]

In [29]:
from torch.utils.data import DataLoader


In [31]:
train_data = tokenized_datasets["train"].shuffle(seed=42).select(range(200))
test_data = tokenized_datasets["test"].shuffle(seed=42).select(range(50))
train_dataloader = DataLoader(train_data, batch_size=8, shuffle=True)
test_dataloader = DataLoader(test_data, batch_size=8)


In [33]:
optimizer = torch.optim.AdamW(model.parameters(), lr=5e-5)
epochs = 3

for epoch in range(epochs):
    model.train()
    total_loss = 0
    for batch in train_dataloader:
        optimizer.zero_grad()
        outputs = model(input_ids=batch["input_ids"], attention_mask=batch["attention_mask"], labels=batch["labels"])
        loss = outputs["loss"]
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
    print(f"Epoch {epoch+1}, Loss: {total_loss / len(train_dataloader)}")



KeyboardInterrupt



In [37]:
from tqdm import tqdm

optimizer = torch.optim.AdamW(model.parameters(), lr=5e-5)
epochs = 100

for epoch in range(epochs):
    model.train()
    total_loss = 0

    # Wrap the dataloader with tqdm for progress tracking
    progress_bar = tqdm(train_dataloader, desc=f"Epoch {epoch + 1}/{epochs}", unit="batch")

    for batch in progress_bar:
        optimizer.zero_grad()
        outputs = model(input_ids=batch["input_ids"], attention_mask=batch["attention_mask"], labels=batch["labels"])
        loss = outputs["loss"]
        loss.backward()
        optimizer.step()

        total_loss += loss.item()

        # Update the progress bar with the current loss
        progress_bar.set_postfix(loss=loss.item())

    avg_loss = total_loss / len(train_dataloader)
    print(f"Epoch {epoch + 1}, Average Loss: {avg_loss:.4f}")


Epoch 1/100: 100%|██████████████████████████████████████████████████████| 25/25 [20:58<00:00, 50.35s/batch, loss=0.697]


Epoch 1, Average Loss: 0.6947


Epoch 2/100: 100%|██████████████████████████████████████████████████████| 25/25 [20:35<00:00, 49.43s/batch, loss=0.676]


Epoch 2, Average Loss: 0.6922


Epoch 3/100: 100%|██████████████████████████████████████████████████████| 25/25 [20:21<00:00, 48.85s/batch, loss=0.685]


Epoch 3, Average Loss: 0.7855


Epoch 4/100: 100%|███████████████████████████████████████████████████████| 25/25 [20:51<00:00, 50.06s/batch, loss=0.66]


Epoch 4, Average Loss: 0.6897


Epoch 5/100: 100%|██████████████████████████████████████████████████████| 25/25 [20:36<00:00, 49.45s/batch, loss=0.671]


Epoch 5, Average Loss: 0.6814


Epoch 6/100: 100%|██████████████████████████████████████████████████████| 25/25 [20:47<00:00, 49.91s/batch, loss=0.771]


Epoch 6, Average Loss: 0.6756


Epoch 7/100: 100%|██████████████████████████████████████████████████████| 25/25 [22:15<00:00, 53.42s/batch, loss=0.752]


Epoch 7, Average Loss: 0.6682


Epoch 8/100: 100%|██████████████████████████████████████████████████████| 25/25 [21:24<00:00, 51.37s/batch, loss=0.541]


Epoch 8, Average Loss: 0.6411


Epoch 9/100: 100%|██████████████████████████████████████████████████████| 25/25 [21:40<00:00, 52.02s/batch, loss=0.675]


Epoch 9, Average Loss: 0.6040


Epoch 10/100: 100%|██████████████████████████████████████████████████████| 25/25 [21:27<00:00, 51.51s/batch, loss=1.13]


Epoch 10, Average Loss: 0.5383


Epoch 11/100: 100%|█████████████████████████████████████████████████████| 25/25 [21:06<00:00, 50.66s/batch, loss=0.203]


Epoch 11, Average Loss: 0.4951


Epoch 12/100: 100%|█████████████████████████████████████████████████████| 25/25 [19:37<00:00, 47.09s/batch, loss=0.566]


Epoch 12, Average Loss: 0.2861


Epoch 13/100: 100%|█████████████████████████████████████████████████████| 25/25 [19:31<00:00, 46.86s/batch, loss=0.559]


Epoch 13, Average Loss: 0.2413


Epoch 14/100: 100%|█████████████████████████████████████████████████████| 25/25 [19:35<00:00, 47.03s/batch, loss=0.932]


Epoch 14, Average Loss: 0.3258


Epoch 15/100: 100%|███████████████████████████████████████████████████| 25/25 [18:55<00:00, 45.41s/batch, loss=0.00507]


Epoch 15, Average Loss: 0.1321


Epoch 16/100: 100%|████████████████████████████████████████████████████| 25/25 [19:23<00:00, 46.56s/batch, loss=0.0096]


Epoch 16, Average Loss: 0.0543


Epoch 17/100:   0%|                                                                          | 0/25 [00:29<?, ?batch/s]


KeyboardInterrupt: 