In [None]:
for param in model.bert.transformer.layer[-2:].parameters():  # Access layers through 'transformer.layer'
    param.requires_grad = True

In [None]:
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-5)

In [None]:
import torch
print(torch.cuda.is_available())  # Should return True
print(torch.cuda.get_device_name(0))  # Displays the GPU name

True
Tesla T4


In [1]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

NameError: name 'torch' is not defined

In [None]:
# Import necessary libraries
import torch
from torch import nn
from transformers import DistilBertModel

# Define TextBiMMOE if missing
class TextBiMMOE(nn.Module):
    def __init__(self, hidden_dim, num_experts):
        super(TextBiMMOE, self).__init__()
        self.bert = DistilBertModel.from_pretrained("distilbert-base-uncased")
        # Additional layers and experts...

    def forward(self, input_ids, attention_mask):
        x = self.bert(input_ids, attention_mask=attention_mask).last_hidden_state
        # Forward pass logic...
        return x

# Move to device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = TextBiMMOE(hidden_dim=256, num_experts=4).to(device)


The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


config.json:   0%|          | 0.00/483 [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/268M [00:00<?, ?B/s]

In [None]:
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from transformers import DistilBertTokenizer, DistilBertModel
from datasets import load_dataset
import numpy as np
from tqdm import tqdm

# QDT Constants
class QDTConstants:
    def __init__(self):
        self.LAMBDA = 0.867
        self.GAMMA = 0.4497
        self.BETA = 0.310
        self.ETA = 0.520
        self.DROPOUT_RATE = 0.2
        self.L2_WEIGHT = 1e-5
        self.ENERGY_CEILING = 0.8

print("\nInitializing QDT-BiMMOE Model...")
print("=" * 50)
print("Loading Constants:")
qdt = QDTConstants()
print(f"λ (LAMBDA): {qdt.LAMBDA}")
print(f"γ (GAMMA): {qdt.GAMMA}")
print(f"β (BETA): {qdt.BETA}")
print(f"η (ETA): {qdt.ETA}")

class IMDbDataset(Dataset):
    def __init__(self, split="train", max_length=128, subsample=None):
        print(f"\nLoading {split} dataset...")
        self.dataset = load_dataset("imdb", split=split)
        if subsample:
            print(f"Subsampling {subsample} examples...")
            self.dataset = self.dataset.select(range(subsample))

        print("Initializing tokenizer...")
        self.tokenizer = DistilBertTokenizer.from_pretrained('distilbert-base-uncased')
        self.max_length = max_length

        self.texts = self.dataset["text"]
        self.sentiment_labels = torch.tensor(self.dataset["label"]).float().unsqueeze(-1)
        self.quality_labels = torch.tensor([
            1.0 if len(text.split()) > 100 else 0.0
            for text in self.texts
        ]).float().unsqueeze(-1)

        print(f"Dataset size: {len(self.texts)} examples")

    def __len__(self):
        return len(self.texts)

    def __getitem__(self, idx):
        encoding = self.tokenizer(
            self.texts[idx],
            padding="max_length",
            truncation=True,
            max_length=self.max_length,
            return_tensors="pt"
        )

        return {
            'input_ids': encoding["input_ids"].squeeze(0),
            'attention_mask': encoding["attention_mask"].squeeze(0),
            'labels': [
                self.sentiment_labels[idx],
                self.quality_labels[idx]
            ]
        }

class RegularizedExpert(nn.Module):
    def __init__(self, input_dim, hidden_dim, qdt):
        super().__init__()
        self.qdt = qdt

        self.network = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.LayerNorm(hidden_dim),
            nn.Dropout(self.qdt.DROPOUT_RATE),
            nn.GELU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.LayerNorm(hidden_dim)
        )

        self.energy_gate = nn.Sequential(
            nn.Linear(input_dim, 1),
            nn.Sigmoid()
        )

    def forward(self, x):
        energy = self.energy_gate(x)
        output = self.network(x)
        return output, energy, output

class TextBiMMOE(nn.Module):
    def __init__(self, hidden_dim=256, num_experts=4):
        super().__init__()
        self.qdt = QDTConstants()

        print("\nInitializing BiMMOE architecture...")
        print(f"Hidden dimension: {hidden_dim}")
        print(f"Number of experts: {num_experts}")

        self.bert = DistilBertModel.from_pretrained('distilbert-base-uncased')
        bert_dim = self.bert.config.hidden_size

        for param in self.bert.parameters():
            param.requires_grad = False

        self.experts = nn.ModuleList([
            RegularizedExpert(bert_dim, hidden_dim, self.qdt)
            for _ in range(num_experts)
        ])

        self.gates = nn.ModuleList([
            nn.Sequential(
                nn.Linear(bert_dim, hidden_dim // 2),
                nn.LayerNorm(hidden_dim // 2),
                nn.Dropout(self.qdt.DROPOUT_RATE),
                nn.GELU(),
                nn.Linear(hidden_dim // 2, num_experts)
            )
            for _ in range(2)
        ])

        self.task_heads = nn.ModuleList([
            nn.Sequential(
                nn.Linear(hidden_dim, hidden_dim // 2),
                nn.LayerNorm(hidden_dim // 2),
                nn.Dropout(self.qdt.DROPOUT_RATE),
                nn.GELU(),
                nn.Linear(hidden_dim // 2, 1)
            )
            for _ in range(2)
        ])

        self.register_buffer('total_energy', torch.zeros(1))
        self.register_buffer('expert_usage', torch.zeros(num_experts))

    def forward(self, input_ids, attention_mask, training=True):
        with torch.no_grad():
            bert_output = self.bert(input_ids, attention_mask=attention_mask)
            x = bert_output.last_hidden_state[:, 0, :]

        if training:
            x = x + torch.randn_like(x) * 0.01

        expert_outputs = []
        expert_energies = []
        diversity_features = []

        for idx, expert in enumerate(self.experts):
            output, energy, div_feat = expert(x)
            expert_outputs.append(output)
            expert_energies.append(energy.mean().item())
            diversity_features.append(div_feat)

            if training:
                self.expert_usage[idx] = self.qdt.GAMMA * self.expert_usage[idx] + \
                                       (1 - self.qdt.GAMMA) * energy.mean().item()

        expert_outputs = torch.stack(expert_outputs, dim=1)
        diversity_features = torch.stack(diversity_features, dim=1)

        task_outputs = []
        gate_weights = []

        for gate, head in zip(self.gates, self.task_heads):
            logits = gate(x) / np.sqrt(expert_outputs.size(-1))
            weights = torch.softmax(logits, dim=-1)

            if training:
                usage_penalty = torch.softmax(-self.expert_usage, dim=0)
                weights = weights * usage_penalty.unsqueeze(0)
                weights = weights / weights.sum(dim=1, keepdim=True)

            gate_weights.append(weights)
            combined = torch.sum(expert_outputs * weights.unsqueeze(-1), dim=1)
            task_outputs.append(head(combined))

        mean_energy = np.mean(expert_energies)
        self.total_energy = torch.clamp(
            self.qdt.GAMMA * self.total_energy + (1 - self.qdt.GAMMA) * mean_energy,
            max=self.qdt.ENERGY_CEILING
        )

        return task_outputs, {
            'expert_energies': expert_energies,
            'total_energy': self.total_energy.item(),
            'gate_weights': gate_weights,
            'expert_usage': self.expert_usage.tolist(),
            'diversity_features': diversity_features
        }

class GeneralizedQDTLoss(nn.Module):
    def __init__(self, qdt_constants):
        super().__init__()
        self.qdt = qdt_constants
        self.criterion = nn.BCEWithLogitsLoss()

    def forward(self, outputs, targets, stats, model):
        task_losses = []
        for output, target in zip(outputs, targets):
            task_losses.append(self.criterion(output, target))

        main_loss = sum(task_losses)
        energy_loss = stats['total_energy']

        diversity_loss = 0
        div_features = stats['diversity_features']
        for i in range(div_features.size(1)):
            for j in range(i + 1, div_features.size(1)):
                similarity = torch.cosine_similarity(
                    div_features[:, i], div_features[:, j], dim=-1
                ).mean()
                diversity_loss += similarity

        l2_reg = sum(torch.norm(param) for param in model.parameters())

        total_loss = main_loss + \
                    self.qdt.L2_WEIGHT * l2_reg + \
                    0.1 * energy_loss + \
                    0.1 * diversity_loss

        return total_loss

def validate(model, val_loader, criterion, device):
    model.eval()
    total_loss = 0
    task_correct = [0, 0]
    task_total = [0, 0]

    print("\nRunning validation...")
    with torch.no_grad():
        for batch in tqdm(val_loader):
            input_ids = batch['input_ids'].to(device)
            attention_mask = batch['attention_mask'].to(device)
            targets = [label.to(device) for label in batch['labels']]

            outputs, stats = model(input_ids, attention_mask, training=False)
            loss = criterion(outputs, targets, stats, model)
            total_loss += loss.item()

            for idx, (output, target) in enumerate(zip(outputs, targets)):
                pred = (torch.sigmoid(output) > 0.5).float()
                task_correct[idx] += (pred == target).sum().item()
                task_total[idx] += target.numel()

    val_loss = total_loss / len(val_loader)
    print("\nValidation Results:")
    print(f"Loss: {val_loss:.4f}")
    print("Task Accuracies:")
    for idx in range(2):
        acc = 100 * task_correct[idx] / task_total[idx]
        print(f"  Task {idx + 1}: {acc:.2f}%")
    print("-" * 50)

    return val_loss

def train_model(model, train_loader, val_loader, epochs=5):
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"\nTraining on device: {device}")

    model = model.to(device)
    criterion = GeneralizedQDTLoss(model.qdt)
    optimizer = torch.optim.AdamW(
        model.parameters(),
        lr=0.001,
        weight_decay=model.qdt.L2_WEIGHT
    )

    print("\nStarting training...")
    print("=" * 50)

    best_val_loss = float('inf')
    patience = 0
    max_patience = 3

    for epoch in range(epochs):
        model.train()
        total_loss = 0
        task_correct = [0, 0]
        task_total = [0, 0]

        print(f"\nEpoch {epoch + 1}/{epochs}")
        pbar = tqdm(train_loader)
        for batch in pbar:
            input_ids = batch['input_ids'].to(device)
            attention_mask = batch['attention_mask'].to(device)
            targets = [label.to(device) for label in batch['labels']]

            optimizer.zero_grad()
            outputs, stats = model(input_ids, attention_mask, training=True)
            loss = criterion(outputs, targets, stats, model)

            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            optimizer.step()

            total_loss += loss.item()

            with torch.no_grad():
                for idx, (output, target) in enumerate(zip(outputs, targets)):
                    pred = (torch.sigmoid(output) > 0.5).float()
                    task_correct[idx] += (pred == target).sum().item()
                    task_total[idx] += target.numel()

            pbar.set_postfix({
                'loss': f"{loss.item():.4f}",
                'energy': f"{stats['total_energy']:.4f}"
            })

        avg_loss = total_loss / len(train_loader)
        print(f"\nEpoch {epoch + 1} Summary:")
        print(f"Average Loss: {avg_loss:.4f}")
        print(f"Total Energy: {stats['total_energy']:.4f}")
        print("Expert Usage:", [f"{u:.3f}" for u in stats['expert_usage']])
        print("Task Accuracies:")
        for idx in range(2):
            acc = 100 * task_correct[idx] / task_total[idx]
            print(f"  Task {idx + 1}: {acc:.2f}%")

        val_loss = validate(model, val_loader, criterion, device)

        if val_loss < best_val_loss:
            best_val_loss = val_loss
            patience = 0
            print(f"New best validation loss: {val_loss:.4f}")
        else:
            patience += 1
            if patience >= max_patience:
                print(f"\nEarly stopping triggered - no improvement for {max_patience} epochs")
                break

if __name__ == "__main__":
    print("\nStarting QDT-BiMMOE Training Pipeline")
    print("=" * 50)

    # Create datasets
    train_dataset = IMDbDataset(split="train", max_length=128, subsample=20000)
    val_dataset = IMDbDataset(split="test", max_length=128, subsample=5000)

    # Create dataloaders
    print("\nCreating dataloaders...")
    train_loader = DataLoader(
        train_dataset,
        batch_size=32,
        shuffle=True,
        num_workers=2
    )
    val_loader = DataLoader(
        val_dataset,
        batch_size=32,
        num_workers=2
    )

    # Create and train model
    model = TextBiMMOE(hidden_dim=256, num_experts=4)
    train_model(model, train_loader, val_loader)


Initializing QDT-BiMMOE Model...
Loading Constants:
λ (LAMBDA): 0.867
γ (GAMMA): 0.4497
β (BETA): 0.31
η (ETA): 0.52

Starting QDT-BiMMOE Training Pipeline

Loading train dataset...


README.md:   0%|          | 0.00/7.81k [00:00<?, ?B/s]

train-00000-of-00001.parquet:   0%|          | 0.00/21.0M [00:00<?, ?B/s]

test-00000-of-00001.parquet:   0%|          | 0.00/20.5M [00:00<?, ?B/s]

unsupervised-00000-of-00001.parquet:   0%|          | 0.00/42.0M [00:00<?, ?B/s]

Generating train split:   0%|          | 0/25000 [00:00<?, ? examples/s]

Generating test split:   0%|          | 0/25000 [00:00<?, ? examples/s]

Generating unsupervised split:   0%|          | 0/50000 [00:00<?, ? examples/s]

Subsampling 20000 examples...
Initializing tokenizer...


tokenizer_config.json:   0%|          | 0.00/48.0 [00:00<?, ?B/s]

vocab.txt:   0%|          | 0.00/232k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/466k [00:00<?, ?B/s]

Dataset size: 20000 examples

Loading test dataset...
Subsampling 5000 examples...
Initializing tokenizer...
Dataset size: 5000 examples

Creating dataloaders...

Initializing BiMMOE architecture...
Hidden dimension: 256
Number of experts: 4

Training on device: cuda

Starting training...

Epoch 1/5


100%|██████████| 625/625 [02:37<00:00,  3.97it/s, loss=0.3973, energy=0.5000]



Epoch 1 Summary:
Average Loss: 0.6518
Total Energy: 0.5000
Expert Usage: ['0.500', '0.500', '0.500', '0.500']
Task Accuracies:
  Task 1: 77.84%
  Task 2: 90.05%

Running validation...


100%|██████████| 157/157 [00:35<00:00,  4.46it/s]



Validation Results:
Loss: 0.3805
Task Accuracies:
  Task 1: 90.48%
  Task 2: 92.68%
--------------------------------------------------
New best validation loss: 0.3805

Epoch 2/5


100%|██████████| 625/625 [02:34<00:00,  4.04it/s, loss=0.3589, energy=0.5000]



Epoch 2 Summary:
Average Loss: 0.5385
Total Energy: 0.5000
Expert Usage: ['0.500', '0.500', '0.500', '0.500']
Task Accuracies:
  Task 1: 80.00%
  Task 2: 92.53%

Running validation...


100%|██████████| 157/157 [00:34<00:00,  4.56it/s]



Validation Results:
Loss: 0.3266
Task Accuracies:
  Task 1: 93.18%
  Task 2: 93.88%
--------------------------------------------------
New best validation loss: 0.3266

Epoch 3/5


100%|██████████| 625/625 [02:34<00:00,  4.04it/s, loss=0.5089, energy=0.5000]



Epoch 3 Summary:
Average Loss: 0.5052
Total Energy: 0.5000
Expert Usage: ['0.500', '0.500', '0.500', '0.500']
Task Accuracies:
  Task 1: 80.81%
  Task 2: 93.05%

Running validation...


100%|██████████| 157/157 [00:34<00:00,  4.49it/s]



Validation Results:
Loss: 0.5124
Task Accuracies:
  Task 1: 78.54%
  Task 2: 94.56%
--------------------------------------------------

Epoch 4/5


100%|██████████| 625/625 [02:33<00:00,  4.08it/s, loss=0.3226, energy=0.5001]



Epoch 4 Summary:
Average Loss: 0.4834
Total Energy: 0.5001
Expert Usage: ['0.500', '0.500', '0.500', '0.500']
Task Accuracies:
  Task 1: 81.11%
  Task 2: 93.63%

Running validation...


100%|██████████| 157/157 [00:40<00:00,  3.92it/s]



Validation Results:
Loss: 0.2383
Task Accuracies:
  Task 1: 94.32%
  Task 2: 94.52%
--------------------------------------------------
New best validation loss: 0.2383

Epoch 5/5


100%|██████████| 625/625 [02:32<00:00,  4.11it/s, loss=0.6425, energy=0.5000]



Epoch 5 Summary:
Average Loss: 0.4723
Total Energy: 0.5000
Expert Usage: ['0.500', '0.500', '0.500', '0.500']
Task Accuracies:
  Task 1: 81.55%
  Task 2: 93.97%

Running validation...


100%|██████████| 157/157 [00:34<00:00,  4.49it/s]


Validation Results:
Loss: 0.2747
Task Accuracies:
  Task 1: 93.32%
  Task 2: 94.84%
--------------------------------------------------





In [None]:
def train_model(model, train_loader, val_loader, epochs=5, save_path="best_model.pth"):
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"\nTraining on device: {device}")

    model = model.to(device)
    criterion = GeneralizedQDTLoss(model.qdt)
    optimizer = torch.optim.AdamW(
        model.parameters(),
        lr=0.001,
        weight_decay=model.qdt.L2_WEIGHT
    )
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=2, gamma=0.5)  # Learning rate scheduler

    best_val_loss = float('inf')
    patience = 0
    max_patience = 3

    for epoch in range(epochs):
        model.train()
        total_loss = 0
        task_correct = [0, 0]
        task_total = [0, 0]

        print(f"\nEpoch {epoch + 1}/{epochs}")
        pbar = tqdm(train_loader)
        for batch in pbar:
            input_ids = batch['input_ids'].to(device)
            attention_mask = batch['attention_mask'].to(device)
            targets = [label.to(device) for label in batch['labels']]

            optimizer.zero_grad()
            outputs, stats = model(input_ids, attention_mask, training=True)
            loss = criterion(outputs, targets, stats, model)

            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            optimizer.step()

            total_loss += loss.item()

            with torch.no_grad():
                for idx, (output, target) in enumerate(zip(outputs, targets)):
                    pred = (torch.sigmoid(output) > 0.5).float()
                    task_correct[idx] += (pred == target).sum().item()
                    task_total[idx] += target.numel()

            pbar.set_postfix({
                'loss': f"{loss.item():.4f}",
                'energy': f"{stats['total_energy']:.4f}"
            })

        avg_loss = total_loss / len(train_loader)
        print(f"\nEpoch {epoch + 1} Summary:")
        print(f"Average Loss: {avg_loss:.4f}")
        print(f"Total Energy: {stats['total_energy']:.4f}")
        print("Expert Usage:", [f"{u:.3f}" for u in stats['expert_usage']])
        print("Task Accuracies:")
        for idx in range(2):
            acc = 100 * task_correct[idx] / task_total[idx]
            print(f"  Task {idx + 1}: {acc:.2f}%")

        val_loss = validate(model, val_loader, criterion, device)
        scheduler.step()  # Update learning rate

        # Save best model
        if val_loss < best_val_loss:
            best_val_loss = val_loss
            patience = 0
            save_checkpoint(epoch + 1, model, optimizer, best_val_loss, save_path)
        else:
            patience += 1
            if patience >= max_patience:
                print(f"\nEarly stopping triggered - no improvement for {max_patience} epochs")
                break

    print(f"Training completed. Best Validation Loss: {best_val_loss:.4f}")

In [None]:
class GeneralizedQDTLoss(nn.Module):
    def __init__(self, qdt_constants):
        super().__init__()
        self.qdt = qdt_constants
        self.criterion = nn.BCEWithLogitsLoss()

    def forward(self, outputs, targets, stats, model):
        task_losses = []
        for output, target in zip(outputs, targets):
            task_losses.append(self.criterion(output, target))

        # Calculate main_loss here
        main_loss = sum(task_losses)

        energy_loss = stats['total_energy']

        # Ensure div_features is accessible here:
        div_features = stats['diversity_features']

        diversity_loss = 0
        for i in range(div_features.size(1)):
            for j in range(i + 1, div_features.size(1)):
                similarity = torch.cosine_similarity(
                    div_features[:, i], div_features[:, j], dim=-1
                ).mean()
                diversity_loss += similarity

        l2_reg = sum(torch.norm(param) for param in model.parameters())

       # Now total_loss can be calculated correctly
        total_loss = main_loss + \
                    self.qdt.L2_WEIGHT * l2_reg + \
                    0.1 * energy_loss + \
                    0.1 * diversity_loss

        return total_loss

In [None]:
# Define the number of epochs
epochs = 5  # Or any desired number of epochs

# Define the criterion outside the loop to make it accessible to validate
criterion = GeneralizedQDTLoss(model.qdt)  # Assuming GeneralizedQDTLoss is defined

# Define the device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)  # Move model to the correct device

# Initialize optimizer here
optimizer = torch.optim.AdamW(model.parameters(), lr=0.001, weight_decay=1e-5) # Example optimizer

# Initialize best_val_loss
best_val_loss = float('inf') # Initialize to infinity to ensure the first model is always saved

# **Define max_patience here**
max_patience = 3  # Set the maximum number of epochs to wait for improvement

for epoch in range(epochs):
    model.train()
    total_loss = 0

    # Training code (omitted for brevity)

    # Call validate and assign val_loss
    val_loss = validate(model, val_loader, criterion, device)

    # Save best model
    if val_loss < best_val_loss:
        best_val_loss = val_loss
        patience = 0
        save_checkpoint(epoch + 1, model, optimizer, best_val_loss, filepath="best_model.pth")
    else:
        patience += 1
        if patience >= max_patience:
            print(f"\nEarly stopping triggered - no improvement for {max_patience} epochs")
            break


Running validation...


100%|██████████| 157/157 [00:37<00:00,  4.17it/s]



Validation Results:
Loss: 0.4093
Task Accuracies:
  Task 1: 79.28%
  Task 2: 93.48%
--------------------------------------------------
Checkpoint saved at best_model.pth

Running validation...


100%|██████████| 157/157 [00:37<00:00,  4.15it/s]



Validation Results:
Loss: 0.4093
Task Accuracies:
  Task 1: 79.28%
  Task 2: 93.48%
--------------------------------------------------

Running validation...


100%|██████████| 157/157 [00:35<00:00,  4.45it/s]



Validation Results:
Loss: 0.4093
Task Accuracies:
  Task 1: 79.28%
  Task 2: 93.48%
--------------------------------------------------

Running validation...


100%|██████████| 157/157 [00:35<00:00,  4.46it/s]


Validation Results:
Loss: 0.4093
Task Accuracies:
  Task 1: 79.28%
  Task 2: 93.48%
--------------------------------------------------

Early stopping triggered - no improvement for 3 epochs





In [None]:
# Cell ipython-input-3-5ff9ca44c9f1:
criterion = GeneralizedQDTLoss(qdt_constants=model.qdt)

# Initialize the model
model = TextBiMMOE(hidden_dim=256, num_experts=4)

# Then call the train_model function:
train_model(model, train_loader, val_loader, epochs=10, save_path="best_model.pth")


Initializing BiMMOE architecture...
Hidden dimension: 256
Number of experts: 4

Training on device: cuda

Epoch 1/10


100%|██████████| 625/625 [02:47<00:00,  3.74it/s, loss=0.4210, energy=0.5000]



Epoch 1 Summary:
Average Loss: 0.6532
Total Energy: 0.5000
Expert Usage: ['0.500', '0.500', '0.500', '0.500']
Task Accuracies:
  Task 1: 77.45%
  Task 2: 90.34%

Running validation...


100%|██████████| 157/157 [00:36<00:00,  4.31it/s]



Validation Results:
Loss: 0.3328
Task Accuracies:
  Task 1: 91.54%
  Task 2: 93.06%
--------------------------------------------------
Checkpoint saved at best_model.pth

Epoch 2/10


100%|██████████| 625/625 [02:33<00:00,  4.07it/s, loss=0.3697, energy=0.5000]



Epoch 2 Summary:
Average Loss: 0.5338
Total Energy: 0.5000
Expert Usage: ['0.500', '0.500', '0.500', '0.500']
Task Accuracies:
  Task 1: 80.25%
  Task 2: 92.98%

Running validation...


100%|██████████| 157/157 [00:37<00:00,  4.19it/s]



Validation Results:
Loss: 0.3151
Task Accuracies:
  Task 1: 93.34%
  Task 2: 94.00%
--------------------------------------------------
Checkpoint saved at best_model.pth

Epoch 3/10


100%|██████████| 625/625 [02:30<00:00,  4.15it/s, loss=0.4204, energy=0.5000]



Epoch 3 Summary:
Average Loss: 0.4779
Total Energy: 0.5000
Expert Usage: ['0.500', '0.500', '0.500', '0.500']
Task Accuracies:
  Task 1: 81.11%
  Task 2: 93.84%

Running validation...


100%|██████████| 157/157 [00:37<00:00,  4.21it/s]



Validation Results:
Loss: 0.2642
Task Accuracies:
  Task 1: 91.98%
  Task 2: 94.76%
--------------------------------------------------
Checkpoint saved at best_model.pth

Epoch 4/10


100%|██████████| 625/625 [02:37<00:00,  3.97it/s, loss=0.6553, energy=0.5000]



Epoch 4 Summary:
Average Loss: 0.4682
Total Energy: 0.5000
Expert Usage: ['0.500', '0.500', '0.500', '0.500']
Task Accuracies:
  Task 1: 81.20%
  Task 2: 93.98%

Running validation...


100%|██████████| 157/157 [00:35<00:00,  4.42it/s]



Validation Results:
Loss: 0.3729
Task Accuracies:
  Task 1: 85.92%
  Task 2: 94.92%
--------------------------------------------------

Epoch 5/10


100%|██████████| 625/625 [02:36<00:00,  3.99it/s, loss=0.3661, energy=0.5000]



Epoch 5 Summary:
Average Loss: 0.4415
Total Energy: 0.5000
Expert Usage: ['0.500', '0.500', '0.500', '0.500']
Task Accuracies:
  Task 1: 82.00%
  Task 2: 94.44%

Running validation...


100%|██████████| 157/157 [00:34<00:00,  4.50it/s]



Validation Results:
Loss: 0.2509
Task Accuracies:
  Task 1: 92.98%
  Task 2: 94.86%
--------------------------------------------------
Checkpoint saved at best_model.pth

Epoch 6/10


100%|██████████| 625/625 [02:38<00:00,  3.95it/s, loss=0.2963, energy=0.5000]



Epoch 6 Summary:
Average Loss: 0.4312
Total Energy: 0.5000
Expert Usage: ['0.500', '0.500', '0.500', '0.500']
Task Accuracies:
  Task 1: 82.51%
  Task 2: 94.76%

Running validation...


100%|██████████| 157/157 [00:34<00:00,  4.51it/s]



Validation Results:
Loss: 0.3055
Task Accuracies:
  Task 1: 90.94%
  Task 2: 94.86%
--------------------------------------------------

Epoch 7/10


100%|██████████| 625/625 [02:34<00:00,  4.05it/s, loss=0.2127, energy=0.5000]



Epoch 7 Summary:
Average Loss: 0.4193
Total Energy: 0.5000
Expert Usage: ['0.500', '0.500', '0.500', '0.500']
Task Accuracies:
  Task 1: 82.72%
  Task 2: 94.69%

Running validation...


100%|██████████| 157/157 [00:35<00:00,  4.48it/s]



Validation Results:
Loss: 0.2382
Task Accuracies:
  Task 1: 92.50%
  Task 2: 95.48%
--------------------------------------------------
Checkpoint saved at best_model.pth

Epoch 8/10


100%|██████████| 625/625 [02:59<00:00,  3.49it/s, loss=0.3777, energy=0.5000]



Epoch 8 Summary:
Average Loss: 0.4122
Total Energy: 0.5000
Expert Usage: ['0.500', '0.500', '0.500', '0.500']
Task Accuracies:
  Task 1: 82.72%
  Task 2: 94.84%

Running validation...


100%|██████████| 157/157 [00:50<00:00,  3.10it/s]



Validation Results:
Loss: 0.2777
Task Accuracies:
  Task 1: 91.06%
  Task 2: 95.24%
--------------------------------------------------

Epoch 9/10


100%|██████████| 625/625 [02:41<00:00,  3.86it/s, loss=0.2932, energy=0.5000]



Epoch 9 Summary:
Average Loss: 0.4020
Total Energy: 0.5000
Expert Usage: ['0.500', '0.500', '0.500', '0.500']
Task Accuracies:
  Task 1: 83.41%
  Task 2: 94.96%

Running validation...


100%|██████████| 157/157 [00:35<00:00,  4.48it/s]



Validation Results:
Loss: 0.3016
Task Accuracies:
  Task 1: 88.42%
  Task 2: 95.60%
--------------------------------------------------

Epoch 10/10


100%|██████████| 625/625 [02:35<00:00,  4.01it/s, loss=0.3564, energy=0.5000]



Epoch 10 Summary:
Average Loss: 0.3995
Total Energy: 0.5000
Expert Usage: ['0.500', '0.500', '0.500', '0.500']
Task Accuracies:
  Task 1: 83.45%
  Task 2: 95.00%

Running validation...


100%|██████████| 157/157 [00:35<00:00,  4.47it/s]



Validation Results:
Loss: 0.2667
Task Accuracies:
  Task 1: 91.44%
  Task 2: 95.24%
--------------------------------------------------

Early stopping triggered - no improvement for 3 epochs
Training completed. Best Validation Loss: 0.2382


In [None]:
from torch.optim.lr_scheduler import StepLR

# Define the optimizer
optimizer = torch.optim.AdamW(
    model.parameters(),
    lr=0.001,
    weight_decay=model.qdt.L2_WEIGHT
)

# Define the scheduler
scheduler = StepLR(optimizer, step_size=2, gamma=0.5)  # Halve the learning rate every 2 epochs

# Save the model checkpoint
torch.save({
    'model_state_dict': model.state_dict(),
    'optimizer_state_dict': optimizer.state_dict(),
    'scheduler_state_dict': scheduler.state_dict(),
    'best_val_loss': best_val_loss
}, "best_model.pth")

In [None]:
# Load the checkpoint
checkpoint = torch.load("best_model.pth")

# Load only the model's weights
model.load_state_dict(checkpoint['model_state_dict'])


  checkpoint = torch.load("best_model.pth")


<All keys matched successfully>

In [None]:
import torch
from torch.utils.data import DataLoader

# Create the test dataset
test_dataset = IMDbDataset(split="test", max_length=128, subsample=5000)

# Create the DataLoader for the test dataset
test_loader = DataLoader(
    test_dataset,
    batch_size=32,
    shuffle=False,
    num_workers=2
)

# Load the saved checkpoint
checkpoint = torch.load("best_model.pth")

# Load the state dictionary into the model
model.load_state_dict(checkpoint['model_state_dict'])

# Set the model to evaluation mode
model.eval()

# Perform evaluation
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = model.to(device)

with torch.no_grad():
    for batch in test_loader:
        input_ids = batch['input_ids'].to(device)
        attention_mask = batch['attention_mask'].to(device)

        # Forward pass
        outputs, stats = model(input_ids, attention_mask, training=False)

        # Evaluate predictions here
        # Example: Process `outputs` for your specific tasks (e.g., binary classification)

# Load the checkpoint
checkpoint = torch.load("best_model.pth")

# Extract the model state dictionary and load it into the model
model.load_state_dict(checkpoint['model_state_dict'])

# Set the model to evaluation mode
model.eval()

# Move the model to the appropriate device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = model.to(device)

# Evaluate on the test dataset
with torch.no_grad():
    for batch in test_loader:
        input_ids = batch['input_ids'].to(device)
        attention_mask = batch['attention_mask'].to(device)

        # Forward pass
        outputs, stats = model(input_ids, attention_mask, training=False)

        # Process predictions
        # Example for binary classification
        sentiment_predictions = (torch.sigmoid(outputs[0]) > 0.5).float()
        quality_predictions = (torch.sigmoid(outputs[1]) > 0.5).float()
        print("Sentiment Predictions:", sentiment_predictions)
        print("Quality Predictions:", quality_predictions)




Loading test dataset...
Subsampling 5000 examples...
Initializing tokenizer...
Dataset size: 5000 examples


  checkpoint = torch.load("best_model.pth")
  checkpoint = torch.load("best_model.pth")


[1;30;43mStreaming output truncated to the last 5000 lines.[0m
        [1.],
        [1.],
        [1.],
        [1.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [1.],
        [0.],
        [0.],
        [0.],
        [0.],
        [1.],
        [1.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.]], device='cuda:0')
Quality Predictions: tensor([[1.],
        [1.],
        [1.],
        [1.],
        [1.],
        [1.],
        [1.],
        [1.],
        [1.],
        [1.],
        [1.],
        [1.],
        [1.],
        [1.],
        [1.],
        [1.],
        [1.],
        [1.],
        [1.],
        [1.],
        [1.],
        [1.],
        [1.],
        [1.],
        [1.],
        [1.],
        [1.],
        [1.],
        [1.],
        [1.],
        [1.],
        [1.]], device='cuda:0')
Sentiment Predictions: tensor([[1.],
        [1.],
        [0.],
        [0.],
        [0.],
 

In [None]:
# Iterate through test_loader and save predictions
predictions = []

with torch.no_grad():
    for batch in test_loader:
        input_ids = batch['input_ids'].to(device)
        attention_mask = batch['attention_mask'].to(device)

        # Forward pass
        outputs, _ = model(input_ids, attention_mask)

        sentiment_preds = (torch.sigmoid(outputs[0]) > 0.5).float()
        quality_preds = (torch.sigmoid(outputs[1]) > 0.5).float()

        predictions.append((sentiment_preds.cpu().numpy(), quality_preds.cpu().numpy()))

# Save or analyze `predictions` as needed


In [None]:
from torch.optim.lr_scheduler import StepLR

# Define the optimizer
optimizer = torch.optim.AdamW(
    model.parameters(),
    lr=0.001,
    weight_decay=model.qdt.L2_WEIGHT
)

# Define the scheduler
scheduler = StepLR(optimizer, step_size=2, gamma=0.5)  # Halve the learning rate every 2 epochs

# Save the model checkpoint
torch.save({
    'model_state_dict': model.state_dict(),
    'optimizer_state_dict': optimizer.state_dict(),
    'scheduler_state_dict': scheduler.state_dict(),
    'best_val_loss': best_val_loss
}, "best_model.pth")