# In [None]:
# Q1.a

import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
import numpy as np
import time
from tqdm import tqdm


class SimpleNN(nn.Module):
    """Fully connected MNIST classifier: 784 -> 50 -> ReLU -> 50 -> ReLU -> 10 logits.

    Inputs of any shape with 784 elements per example (e.g. ``(B, 1, 28, 28)``) are
    flattened to ``(B, 784)`` before the first layer.
    """

    def __init__(self):
        super().__init__()
        # Registration order (fc1, fc2, fc3, relu) fixes the state_dict keys and the order in
        # which parameters are initialised from the global RNG.
        self.fc1 = nn.Linear(784, 50)
        self.fc2 = nn.Linear(50, 50)
        self.fc3 = nn.Linear(50, 10)
        self.relu = nn.ReLU()

    def forward(self, x):
        hidden = self.relu(self.fc1(x.view(-1, 784)))
        hidden = self.relu(self.fc2(hidden))
        return self.fc3(hidden)

def pgd_untargeted(model, x, y, k, eps, eps_step):
    """Untargeted L-inf PGD with a uniform random start.

    Starts from ``x`` plus uniform noise in ``[-eps, eps]`` (clipped to ``[0, 1]``), then takes
    ``k`` signed-gradient ascent steps of size ``eps_step`` on the cross-entropy loss, each time
    projecting back into the eps-ball around ``x`` and into ``[0, 1]``.

    Side effect: puts ``model`` into eval mode.

    Returns:
        A detached adversarial tensor with the same shape as ``x``.
    """
    model.eval()
    loss_fn = nn.CrossEntropyLoss()

    start_noise = torch.zeros_like(x).uniform_(-eps, eps)
    adv = torch.clamp(x.clone().detach() + start_noise, 0, 1)

    for _ in range(k):
        adv.requires_grad_(True)
        # enable_grad so the attack still works when called from inside a no_grad context.
        with torch.enable_grad():
            grad, = torch.autograd.grad(loss_fn(model(adv), y), adv)

        with torch.no_grad():
            stepped = adv + eps_step * grad.sign()
            projected = x + torch.clamp(stepped - x, -eps, eps)
            adv = torch.clamp(projected, 0, 1).detach()

    return adv

def _simple_nn_ibp_worst_case_logits(model, data, target, eps):
    """Return IBP worst-case logits of a SimpleNN for the L-inf box of radius ``eps`` around ``data``.

    The box is pushed through fc1-ReLU-fc2-ReLU-fc3 in center/radius form
    (center' = W c + b, radius' = |W| r), which is sound for weights of either sign. In the
    returned ``(batch, classes)`` tensor the true-class column holds that logit's lower bound
    and every other column holds its upper bound, so cross-entropy on it upper-bounds the
    cross-entropy of any input in the box.
    """
    lower = data.view(-1, 784) - eps
    upper = data.view(-1, 784) + eps
    layers = (model.fc1, model.fc2, model.fc3)
    for depth, layer in enumerate(layers):
        mu = (upper + lower) / 2
        r = (upper - lower) / 2
        center = torch.matmul(mu, layer.weight.t()) + layer.bias
        radius = torch.matmul(r, torch.abs(layer.weight.t()))
        lower = center - radius
        upper = center + radius
        if depth < len(layers) - 1:
            # ReLU is monotone, so clamping both bounds gives bounds on the activation.
            lower = torch.clamp(lower, min=0)
            upper = torch.clamp(upper, min=0)
    # Vectorised selection: lower bound for the true class, upper bound elsewhere.
    is_true_class = nn.functional.one_hot(target, num_classes=lower.size(1)).bool()
    return torch.where(is_true_class, lower, upper)


def train_with_ibp(model, train_loader, test_loader, epochs, target_eps=0.1, device="cuda",
                   warmup_steps=2000, rampup_steps=10000, lr_decay_steps=(15000, 25000)):
    """Train a SimpleNN with interval bound propagation (IBP), then evaluate it with PGD.

    The per-step loss is ``kappa * natural_loss + (1 - kappa) * robust_loss``. For the first
    ``warmup_steps`` optimizer steps eps is 0 and kappa is 1 (plain training); afterwards eps
    grows linearly to ``target_eps`` and kappa decays linearly to 0.5 over ``rampup_steps``.
    The learning rate is multiplied by 0.1 at every step listed in ``lr_decay_steps``.

    Args:
        model: network exposing ``fc1``/``fc2``/``fc3`` linear layers (e.g. ``SimpleNN``).
        train_loader: loader yielding ``(images, labels)`` training batches.
        test_loader: loader passed to ``evaluate_model`` after training.
        epochs: number of passes over ``train_loader``.
        target_eps: final L-inf radius for IBP training and for the PGD evaluation.
        device: device the model and batches are moved to.
        warmup_steps, rampup_steps, lr_decay_steps: schedule parameters; the defaults are the
            values this function always used.

    Returns:
        ``(clean accuracy %, PGD accuracy %, training time in seconds)``. The time covers the
        training loop only; the accuracies come from ``evaluate_model`` at ``target_eps``.
    """
    model = model.to(device)
    criterion = nn.CrossEntropyLoss().to(device)
    optimizer = optim.Adam(model.parameters(), lr=0.001)

    total_steps = epochs * len(train_loader)
    if target_eps > 0 and warmup_steps + rampup_steps >= total_steps:
        print(f'Warning: the eps schedule needs {warmup_steps + rampup_steps} steps but training '
              f'only runs {total_steps}; eps will not reach target_eps={target_eps}.')
    step = 0

    start_time = time.time()

    for epoch in range(epochs):
        model.train()
        for batch_idx, (data, target) in enumerate(train_loader):
            data, target = data.to(device), target.to(device)

            if step in lr_decay_steps:
                for param_group in optimizer.param_groups:
                    param_group['lr'] *= 0.1

            if step < warmup_steps:
                eps = 0
                kappa = 1.0
            elif rampup_steps > 0:
                eps = min(target_eps * (step - warmup_steps) / rampup_steps, target_eps)
                kappa = max(0.5, 1.0 - 0.5 * (step - warmup_steps) / rampup_steps)
            else:
                # No ramp-up requested: jump straight to the final schedule values.
                eps = target_eps
                kappa = 0.5

            optimizer.zero_grad()

            natural_output = model(data)
            natural_loss = criterion(natural_output, target)

            if eps > 0:
                worst_case_logits = _simple_nn_ibp_worst_case_logits(model, data, target, eps)
                robust_loss = criterion(worst_case_logits, target)
            else:
                robust_loss = natural_loss

            loss = kappa * natural_loss + (1 - kappa) * robust_loss
            loss.backward()
            optimizer.step()

            if batch_idx % 100 == 0:
                print(f'Epoch {epoch+1}/{epochs}, Batch {batch_idx}/{len(train_loader)}, '
                      f'Loss: {loss.item():.4f}, eps: {eps:.4f}, kappa: {kappa:.4f}, '
                      f'LR: {optimizer.param_groups[0]["lr"]}')

            step += 1

    training_time = time.time() - start_time
    accuracy, robust_accuracy = evaluate_model(model, test_loader, target_eps, device)

    return accuracy, robust_accuracy, training_time

def evaluate_model(model, test_loader, target_eps, device):
    """Measure clean accuracy and PGD adversarial accuracy (both in %) on ``test_loader``.

    The attack is ``pgd_untargeted`` with 200 steps, radius ``target_eps`` and step size
    ``target_eps / 4``; adversarial accuracy counts predictions on the attacked inputs that
    still match the label.

    Returns:
        ``(clean accuracy %, adversarial accuracy %)``.
    """
    model.eval()
    n_clean = 0
    n_adv = 0
    n_seen = 0

    for images, labels in test_loader:
        images, labels = images.to(device), labels.to(device)
        n_seen += labels.size(0)

        with torch.no_grad():
            n_clean += (model(images).argmax(dim=1) == labels).sum().item()

        # The attack needs gradients, so it runs outside no_grad.
        adversarial = pgd_untargeted(model, images, labels, k=200, eps=target_eps,
                                     eps_step=target_eps / 4)
        with torch.no_grad():
            n_adv += (model(adversarial).argmax(dim=1) == labels).sum().item()

    return 100. * n_clean / n_seen, 100. * n_adv / n_seen

def main():
    """Train an IBP model and a standard model on MNIST and compare clean/PGD accuracy.

    The IBP model trains for 10 epochs (target eps 0.1) and the standard model for 5 epochs.
    Both are evaluated with ``evaluate_model`` at eps 0.1. The IBP model's weights are saved
    to ``ibp_model.pth``.
    """
    torch.backends.cudnn.benchmark = True

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    to_tensor = transforms.Compose([transforms.ToTensor()])
    train_set = datasets.MNIST('./data', train=True, download=True, transform=to_tensor)
    test_set = datasets.MNIST('./data', train=False, transform=to_tensor)

    train_loader = DataLoader(train_set, batch_size=64, shuffle=True, pin_memory=True)
    test_loader = DataLoader(test_set, batch_size=1000, shuffle=False, pin_memory=True)

    def report(header, clean_acc, pgd_acc, seconds):
        # Shared 4-line summary used for both training runs.
        print(header)
        print(f"Standard Accuracy: {clean_acc:.2f}%")
        print(f"Robust Accuracy (PGD): {pgd_acc:.2f}%")
        print(f"Training Time: {seconds:.2f} seconds")

    print("Starting IBP training...")
    model_ibp = SimpleNN().to(device)
    ibp_results = train_with_ibp(
        model_ibp, train_loader, test_loader, epochs=10, target_eps=0.1, device=device
    )
    report("\nIBP Training Results:", *ibp_results)

    print("\nStarting standard training...")
    model_std = SimpleNN().to(device)
    started = time.time()

    loss_fn = nn.CrossEntropyLoss().to(device)
    opt = optim.Adam(model_std.parameters(), lr=0.001)

    for _ in range(5):
        model_std.train()
        for images, labels in train_loader:
            images, labels = images.to(device), labels.to(device)
            opt.zero_grad()
            loss_fn(model_std(images), labels).backward()
            opt.step()

    std_seconds = time.time() - started

    std_clean, std_pgd = evaluate_model(model_std, test_loader, 0.1, device)
    report("\nStandard Training Results:", std_clean, std_pgd, std_seconds)

    torch.save(model_ibp.state_dict(), 'ibp_model.pth')


if __name__ == "__main__":
    main()

# In [None]:
# Q1.b

import matplotlib.pyplot as plt


def interval_analysis(model, x, epsilon, device):
    """Bound SimpleNN's logits over the L-inf box ``[x - epsilon, x + epsilon]``.

    The box is propagated through fc1-ReLU-fc2-ReLU-fc3 with interval arithmetic in
    center/radius form: for an affine layer, ``center' = W c + b`` and
    ``radius' = |W| r``. This is sound for weights of any sign. (Pushing the lower and upper
    corners through ``W`` separately is not sound; negative weights swap the roles of the
    corners.) The input box is not clipped to ``[0, 1]``, so the bounds also hold, slightly
    loosely, for valid images.

    Args:
        model: network exposing ``fc1``/``fc2``/``fc3`` linear layers; put into eval mode.
        x: batch whose examples have 784 elements each (reshaped to ``(batch, 784)``).
        epsilon: L-inf radius of the input box.
        device: unused; kept for interface compatibility.

    Returns:
        ``(lower, upper)`` tensors of shape ``(batch, num_classes)`` with ``lower <= upper``.
    """
    model.eval()
    lower = x.reshape(-1, 784) - epsilon
    upper = x.reshape(-1, 784) + epsilon

    layers = (model.fc1, model.fc2, model.fc3)
    # Bounds are only read, never differentiated here, so skip building an autograd graph.
    with torch.no_grad():
        for depth, layer in enumerate(layers):
            center = (upper + lower) / 2
            radius = (upper - lower) / 2
            center = torch.matmul(center, layer.weight.t()) + layer.bias
            radius = torch.matmul(radius, torch.abs(layer.weight.t()))
            lower = center - radius
            upper = center + radius
            if depth < len(layers) - 1:
                # ReLU is monotone: clamping both bounds bounds the activation.
                lower = torch.clamp(lower, min=0)
                upper = torch.clamp(upper, min=0)

    return lower, upper

def verify_robustness(model, test_loader, epsilon_values, device):
    """Compute interval-verified accuracy of ``model`` for each epsilon in ``epsilon_values``.

    An example counts as verified at a given epsilon when the lower bound of its true-class
    logit (from ``interval_analysis``) is strictly greater than the upper bound of every other
    logit. Verification implies the clean prediction is correct.

    Returns:
        ``(verified_accuracies, non_robust_examples)``:
        - ``verified_accuracies`` is a list of percentages aligned with ``epsilon_values``.
        - ``non_robust_examples`` maps each epsilon to up to 2 dicts with keys ``'image'``,
          ``'true_label'`` and ``'bounds'``, taken in test-set order. ``'bounds'`` is a
          ``(lower, upper)`` pair of logit bounds; all tensors are on CPU.
    """
    model.eval()
    verified_correct = [0] * len(epsilon_values)
    total = 0
    non_robust_examples = {eps: [] for eps in epsilon_values}
    dataset_size = len(test_loader.dataset)

    with torch.no_grad():
        for batch_idx, (data, target) in enumerate(test_loader):
            data, target = data.to(device), target.to(device)
            total += data.size(0)
            target_col = target.unsqueeze(1)

            for eps_idx, epsilon in enumerate(epsilon_values):
                lower_bounds, upper_bounds = interval_analysis(model, data, epsilon, device)

                # Vectorised margin check: true-class lower bound vs. the largest upper bound
                # among the other classes (true class masked out with -inf).
                true_lower = lower_bounds.gather(1, target_col).squeeze(1)
                best_other_upper = upper_bounds.scatter(1, target_col, float('-inf')).max(dim=1).values
                is_robust = true_lower > best_other_upper
                verified_correct[eps_idx] += int(is_robust.sum().item())

                stored = non_robust_examples[epsilon]
                if len(stored) < 2:
                    for i in torch.nonzero(~is_robust).flatten().tolist():
                        if len(stored) >= 2:
                            break
                        stored.append({
                            'image': data[i].cpu(),
                            'true_label': target[i].item(),
                            'bounds': (lower_bounds[i].cpu(), upper_bounds[i].cpu())
                        })

            if batch_idx % 10 == 0:
                print(f"Processed {total}/{dataset_size} images")

    verified_accuracies = [100 * correct / total for correct in verified_correct]
    return verified_accuracies, non_robust_examples

def plot_verification_results(epsilon_values, verified_accuracies):
    """Save a line plot of verified accuracy against epsilon to ``verification_results.png``."""
    fig, ax = plt.subplots(figsize=(10, 6))
    ax.plot(epsilon_values, verified_accuracies, marker='o')
    ax.set_xlabel('ε')
    ax.set_ylabel('Verified Accuracy (%)')
    ax.set_title('Verified Accuracy vs Perturbation Size')
    ax.grid(True)
    fig.savefig('verification_results.png')
    plt.close(fig)

def analyze_non_robust_examples(non_robust_examples, epsilon_values, model, device):
    """Plot, per epsilon, one non-verified test image next to a PGD adversarial version of it.

    Row ``i`` corresponds to ``epsilon_values[i]``. The left panel is the first stored
    non-robust example for that epsilon. The right panel is the result of a 200-step
    ``pgd_untargeted`` attack at that epsilon, titled with the model's prediction on it.
    Rows without a stored example are left blank. The figure is saved to
    ``non_robust_examples.png``.
    """
    n_rows = len(epsilon_values)
    if n_rows == 0:
        return

    # squeeze=False keeps `axes` 2-D even when there is only one epsilon row.
    fig, axes = plt.subplots(n_rows, 2, figsize=(10, 4 * n_rows), squeeze=False)

    for i, eps in enumerate(epsilon_values):
        examples = non_robust_examples.get(eps, [])
        if not examples:
            axes[i, 0].axis('off')
            axes[i, 1].axis('off')
            continue

        example = examples[0]
        image = example['image']
        true_label = example['true_label']

        axes[i, 0].imshow(image.squeeze(), cmap='gray')
        axes[i, 0].set_title(f'ε={eps:.3f}, True Label: {true_label}')

        x_adv = pgd_untargeted(model,
                               image.unsqueeze(0).to(device),
                               torch.tensor([true_label]).to(device),
                               k=200,
                               eps=eps,
                               eps_step=eps / 4)

        axes[i, 1].imshow(x_adv.squeeze().cpu(), cmap='gray')
        with torch.no_grad():
            pred = model(x_adv).argmax(dim=1).item()
        axes[i, 1].set_title(f'Adversarial Example\nPredicted: {pred}')

    plt.tight_layout()
    plt.savefig('non_robust_examples.png')
    plt.close(fig)

def main():
    """Verify the saved IBP MNIST model with interval analysis and plot the results.

    Loads ``ibp_model.pth`` into a ``SimpleNN`` and computes interval-verified accuracy on the
    MNIST test set for 10 evenly spaced epsilons in ``[0.01, 0.1]``. It prints the results,
    then writes ``verification_results.png`` and ``non_robust_examples.png``.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")

    model = SimpleNN().to(device)
    # map_location lets a checkpoint written from a GPU run load on a CPU-only machine
    # instead of failing while deserializing CUDA tensors.
    model.load_state_dict(torch.load('ibp_model.pth', map_location=device))

    transform = transforms.Compose([transforms.ToTensor()])
    test_dataset = datasets.MNIST('./data', train=False, transform=transform)
    test_loader = DataLoader(test_dataset, batch_size=100, shuffle=False, pin_memory=True)

    epsilon_values = np.linspace(0.01, 0.1, 10)

    print("Starting verification...")
    verified_accuracies, non_robust_examples = verify_robustness(
        model, test_loader, epsilon_values, device
    )

    print("\nVerification Results:")
    for eps, acc in zip(epsilon_values, verified_accuracies):
        print(f"ε = {eps:.3f}: {acc:.2f}% verified accuracy")

    plot_verification_results(epsilon_values, verified_accuracies)

    analyze_non_robust_examples(non_robust_examples, epsilon_values, model, device)


if __name__ == "__main__":
    main()

# In [None]:
# Q2


class BOWModel(nn.Module):
    """Bag-of-words sentiment classifier over pre-embedded token sequences.

    ``forward`` expects already-embedded input whose last dimension is ``embed_dim``; the
    sequence axis is dim 1. Each token vector goes through ``g_word`` + ReLU, the tokens are
    averaged, and an MLP (``fc1``, ``fc2`` with ReLU, then ``fc3``) produces
    ``num_classes`` logits. ``vocab_size`` is accepted for interface compatibility but unused.
    """

    def __init__(self, vocab_size, embed_dim=300, hidden_dim=100, num_classes=2):
        super().__init__()

        # Registration order (g_word, relu, fc1, fc2, fc3) fixes state_dict keys and the order
        # in which parameters are initialised from the global RNG.
        self.g_word = nn.Linear(embed_dim, embed_dim)
        self.relu = nn.ReLU()

        self.fc1 = nn.Linear(embed_dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, hidden_dim)
        self.fc3 = nn.Linear(hidden_dim, num_classes)

    def forward(self, x):
        pooled = self.relu(self.g_word(x)).mean(dim=1)
        hidden = self.relu(self.fc2(self.relu(self.fc1(pooled))))
        return self.fc3(hidden)

def _bow_interval_affine(lower, upper, layer):
    """Propagate the box ``[lower, upper]`` through an ``nn.Linear`` layer in center/radius form."""
    center = (upper + lower) / 2
    radius = (upper - lower) / 2
    out_center = torch.matmul(center, layer.weight.t()) + layer.bias
    out_radius = torch.matmul(radius, torch.abs(layer.weight.t()))
    return out_center - out_radius, out_center + out_radius


def _bow_ibp_worst_case_logits(model, data, target, eps):
    """Return IBP worst-case logits of a BOWModel for the L-inf box of radius ``eps`` around ``data``.

    In the returned ``(batch, classes)`` tensor the true-class column holds that logit's lower
    bound and every other column holds its upper bound.
    """
    batch_size, embed_dim = data.size(0), data.size(-1)
    # g_word acts on every token position independently, so fold (batch, seq) into rows.
    lower, upper = _bow_interval_affine((data - eps).reshape(-1, embed_dim),
                                        (data + eps).reshape(-1, embed_dim), model.g_word)
    lower, upper = torch.clamp(lower, min=0), torch.clamp(upper, min=0)
    # The mean over tokens is monotone, so averaging the bounds bounds the average.
    lower = torch.mean(lower.view(batch_size, -1, lower.size(-1)), dim=1)
    upper = torch.mean(upper.view(batch_size, -1, upper.size(-1)), dim=1)
    for layer in (model.fc1, model.fc2):
        lower, upper = _bow_interval_affine(lower, upper, layer)
        lower, upper = torch.clamp(lower, min=0), torch.clamp(upper, min=0)
    lower, upper = _bow_interval_affine(lower, upper, model.fc3)
    is_true_class = nn.functional.one_hot(target, num_classes=lower.size(1)).bool()
    return torch.where(is_true_class, lower, upper)


def train_with_ibp(model, train_loader, test_loader, epochs, target_eps=0.1, device="cuda",
                   warmup_steps=2000, rampup_steps=10000):
    """Train a BOWModel with interval bound propagation (IBP) and report verified accuracy.

    The per-step loss is ``kappa * natural_loss + (1 - kappa) * robust_loss``. For the first
    ``warmup_steps`` steps eps is 0 and kappa is 1; afterwards eps grows linearly to
    ``target_eps`` and kappa decays linearly to 0.5 over ``rampup_steps``. After every epoch
    the test set is evaluated. "Robust accuracy" is the IBP-verified accuracy at
    ``target_eps``, independent of where the training schedule currently is: an example
    counts when its true-class lower bound strictly exceeds every other class's upper bound.

    Args:
        model: network exposing ``g_word``, ``fc1``, ``fc2``, ``fc3`` linear layers.
        train_loader / test_loader: loaders yielding ``(embedded sequences, labels)``.
        epochs: number of passes over ``train_loader``; with 0 the untrained model is evaluated.
        target_eps: final L-inf radius, used for training and for verification.
        device: device the model and batches are moved to.
        warmup_steps, rampup_steps: schedule parameters; the defaults are the values this
            function always used.

    Returns:
        ``(clean accuracy %, verified accuracy %, elapsed seconds including evaluation)``
        from the last evaluation.
    """
    model = model.to(device)
    criterion = nn.CrossEntropyLoss().to(device)
    optimizer = optim.Adam(model.parameters(), lr=0.001, weight_decay=1e-4)

    total_steps = epochs * len(train_loader)
    if target_eps > 0 and warmup_steps + rampup_steps >= total_steps:
        print(f'Warning: the eps schedule needs {warmup_steps + rampup_steps} steps but training '
              f'only runs {total_steps}; eps will not reach target_eps={target_eps}.')
    step = 0

    start_time = time.time()

    def evaluate():
        """Return (average test loss, clean accuracy %, verified accuracy % at target_eps)."""
        model.eval()
        test_loss = 0
        correct = 0
        robust_correct = 0
        total = 0

        with torch.no_grad():
            for data, target in test_loader:
                data, target = data.to(device), target.to(device)
                batch_size = data.size(0)
                total += batch_size

                output = model(data)
                test_loss += criterion(output, target).item() * batch_size
                correct += output.argmax(dim=1).eq(target).sum().item()

                if target_eps > 0:
                    worst = _bow_ibp_worst_case_logits(model, data, target, target_eps)
                    target_col = target.unsqueeze(1)
                    true_lower = worst.gather(1, target_col).squeeze(1)
                    best_other = worst.scatter(1, target_col, float('-inf')).max(dim=1).values
                    robust_correct += (true_lower > best_other).sum().item()

        accuracy = 100. * correct / total
        robust_accuracy = 100. * robust_correct / total if target_eps > 0 else accuracy
        return test_loss / total, accuracy, robust_accuracy

    accuracy = robust_accuracy = None
    for epoch in range(epochs):
        model.train()
        for batch_idx, (data, target) in enumerate(train_loader):
            data, target = data.to(device), target.to(device)

            if step < warmup_steps:
                eps = 0
                kappa = 1.0
            elif rampup_steps > 0:
                eps = min(target_eps * (step - warmup_steps) / rampup_steps, target_eps)
                kappa = max(0.5, 1.0 - 0.5 * (step - warmup_steps) / rampup_steps)
            else:
                # No ramp-up requested: jump straight to the final schedule values.
                eps = target_eps
                kappa = 0.5

            optimizer.zero_grad()

            natural_output = model(data)
            natural_loss = criterion(natural_output, target)

            if eps > 0:
                worst_case_logits = _bow_ibp_worst_case_logits(model, data, target, eps)
                robust_loss = criterion(worst_case_logits, target)
            else:
                robust_loss = natural_loss

            loss = kappa * natural_loss + (1 - kappa) * robust_loss
            loss.backward()
            optimizer.step()

            if batch_idx % 100 == 0:
                print(f'Epoch {epoch+1}/{epochs}, Batch {batch_idx}/{len(train_loader)}, '
                      f'Loss: {loss.item():.4f}, eps: {eps:.4f}, kappa: {kappa:.4f}')

            step += 1

        test_loss, accuracy, robust_accuracy = evaluate()
        print(f'\nTest set: Average loss: {test_loss:.4f}, '
              f'Accuracy: {accuracy:.2f}%, Robust Accuracy: {robust_accuracy:.2f}%')

    if accuracy is None:
        # No epochs ran: still report the (untrained) model's metrics.
        _, accuracy, robust_accuracy = evaluate()

    training_time = time.time() - start_time
    return accuracy, robust_accuracy, training_time

def evaluate_model(model, test_loader, target_eps, device):
    """Report clean and IBP-verified accuracy (both in %) of a BOWModel-like network.

    ``model`` must expose ``g_word``, ``fc1``, ``fc2``, ``fc3`` linear layers. An example
    counts as verified when, over the L-inf box of radius ``target_eps`` around its embedded
    input, the lower bound of its true-class logit strictly exceeds the upper bound of every
    other logit. With ``target_eps <= 0`` the verified accuracy equals the clean accuracy.

    Returns:
        ``(clean accuracy %, verified accuracy %)``.
    """
    def interval_affine(lower, upper, layer):
        # Center/radius propagation through nn.Linear: sound for weights of either sign.
        center = (upper + lower) / 2
        radius = (upper - lower) / 2
        center = torch.matmul(center, layer.weight.t()) + layer.bias
        radius = torch.matmul(radius, torch.abs(layer.weight.t()))
        return center - radius, center + radius

    def worst_case_logits(data, target, eps):
        # True-class column = lower bound, other columns = upper bound.
        batch_size, embed_dim = data.size(0), data.size(-1)
        lower, upper = interval_affine((data - eps).reshape(-1, embed_dim),
                                       (data + eps).reshape(-1, embed_dim), model.g_word)
        lower, upper = torch.clamp(lower, min=0), torch.clamp(upper, min=0)
        # Mean over tokens is monotone, so averaging bounds gives bounds on the average.
        lower = torch.mean(lower.view(batch_size, -1, lower.size(-1)), dim=1)
        upper = torch.mean(upper.view(batch_size, -1, upper.size(-1)), dim=1)
        for layer in (model.fc1, model.fc2):
            lower, upper = interval_affine(lower, upper, layer)
            lower, upper = torch.clamp(lower, min=0), torch.clamp(upper, min=0)
        lower, upper = interval_affine(lower, upper, model.fc3)
        is_true_class = nn.functional.one_hot(target, num_classes=lower.size(1)).bool()
        return torch.where(is_true_class, lower, upper)

    model.eval()
    correct = 0
    robust_correct = 0
    total = 0

    with torch.no_grad():
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)
            total += target.size(0)

            output = model(data)
            pred = output.argmax(dim=1)
            batch_correct = pred.eq(target).sum().item()
            correct += batch_correct

            if target_eps > 0:
                worst = worst_case_logits(data, target, target_eps)
                target_col = target.unsqueeze(1)
                true_lower = worst.gather(1, target_col).squeeze(1)
                best_other = worst.scatter(1, target_col, float('-inf')).max(dim=1).values
                robust_correct += (true_lower > best_other).sum().item()
            else:
                robust_correct += batch_correct

    accuracy = 100. * correct / total
    robust_accuracy = 100. * robust_correct / total

    return accuracy, robust_accuracy

# In [None]:
# Q2

# Generated with GPT to download GloVe embeddings, because torchtext was not working on my device.
def download_glove_embeddings():
    """Ensure the GloVe 6B archive is downloaded and extracted under ``glove.6B/``.

    The zip is fetched only if ``glove.6B/glove.6B.zip`` does not exist. It is streamed to a
    ``.part`` file and renamed when complete, so an interrupted download never leaves a
    truncated zip that later calls would trust. Whenever any archive member is missing from
    ``glove.6B/`` the archive is extracted again, including when the zip was already present.

    Raises:
        urllib.error.URLError: if the download fails (``HTTPError`` for non-2xx responses).
        zipfile.BadZipFile: if the existing zip file is corrupt.
    """
    # Local stdlib imports: the notebook never imports os, and the third-party `requests`
    # package is replaced by urllib.
    import os
    import shutil
    import urllib.request
    import zipfile

    url = "https://nlp.stanford.edu/data/glove.6B.zip"
    target_dir = 'glove.6B'
    os.makedirs(target_dir, exist_ok=True)

    zip_path = 'glove.6B/glove.6B.zip'
    if not os.path.exists(zip_path):
        print("Downloading GloVe embeddings...")
        partial_path = zip_path + '.part'
        # Stream to disk instead of holding the whole archive in memory.
        with urllib.request.urlopen(url) as response, open(partial_path, 'wb') as f:
            shutil.copyfileobj(response, f)
        os.replace(partial_path, zip_path)

    with zipfile.ZipFile(zip_path, 'r') as zip_ref:
        missing = [name for name in zip_ref.namelist()
                   if not os.path.exists(os.path.join(target_dir, name))]
        if missing:
            zip_ref.extractall(target_dir)

# Generated with GPT to load GloVe embeddings, because torchtext was not working on my device.
def load_glove_embeddings(vocab_size=50000, embedding_dim=300):
    """Load GloVe 6B vectors and build an embedding matrix with a padding row.

    Row 0 is an all-zero ``'<pad>'`` vector. It is followed by the first ``vocab_size - 1``
    vectors of ``glove.6B/glove.6B.{embedding_dim}d.txt``, in file order. The archive is
    downloaded first via ``download_glove_embeddings`` if needed.

    Returns:
        ``(embeddings, word_to_idx)``: a float32 array of shape ``(n_words, embedding_dim)``
        and a dict mapping each word (and ``'<pad>'``) to its row index.
    """
    print("Loading GloVe embeddings...")

    # Download embeddings if needed
    download_glove_embeddings()

    word_to_idx = {'<pad>': 0}  # Add padding token
    # Keep every row float32 so the stacked matrix does not get promoted to float64.
    vectors = [np.zeros(embedding_dim, dtype=np.float32)]  # Add padding vector

    with open(f'glove.6B/glove.6B.{embedding_dim}d.txt', 'r', encoding='utf-8') as f:
        for i, line in enumerate(tqdm(f, desc="Loading embeddings")):
            if i >= vocab_size - 1:  # -1 because we added padding
                break
            values = line.split()
            word = values[0]
            word_to_idx[word] = len(vectors)
            vectors.append(np.asarray(values[1:], dtype='float32'))

    print(f"Loaded {len(vectors)} word vectors")
    return np.stack(vectors), word_to_idx



def preprocess_data(texts, word_to_idx, max_len=500):
    """Convert each text into a fixed-length row of vocabulary indices.

    Texts are lower-cased and split on whitespace. Tokens missing from ``word_to_idx`` are
    dropped, the remaining indices are truncated to ``max_len``, and shorter rows are
    right-padded with index 0.

    Returns:
        A numpy array with one row per text.
    """
    def encode(text):
        ids = [word_to_idx[token] for token in text.lower().split() if token in word_to_idx]
        ids = ids[:max_len]
        return ids + [0] * (max_len - len(ids))

    return np.array([encode(text) for text in tqdm(texts, desc="Processing texts")])

def create_datasets(train_texts, train_labels, test_texts, test_labels, embeddings, batch_size=32):
    """Build train/test DataLoaders that yield ``(embedded sequences, labels)`` batches.

    ``train_texts``/``test_texts`` are the index sequences produced by ``preprocess_data``
    (one row of vocabulary indices per example). ``embeddings`` is the matrix indexed by them.
    The datasets store only the index tensors, and each batch is embedded on the fly in
    ``collate_fn``. Materialising every ``(N, max_len, dim)`` embedding up front needs on the
    order of ``N * max_len * dim * 4`` bytes (about 15 GB for the IMDB train split).
    Batches match what an eagerly embedded ``TensorDataset`` would produce: a float32
    ``(B, seq_len, dim)`` tensor and an int64 ``(B,)`` label tensor.

    Returns:
        ``(train_loader, test_loader)``; the train loader shuffles, the test loader does not.
    """
    from torch.utils.data import TensorDataset

    embedding_table = torch.as_tensor(np.asarray(embeddings), dtype=torch.float32)

    def make_dataset(sequences, labels):
        return TensorDataset(
            torch.as_tensor(np.asarray(sequences), dtype=torch.long),
            torch.tensor(list(labels), dtype=torch.long),
        )

    def embed_batch(batch):
        # batch is a list of (index_row, label) pairs; look the rows up in one gather.
        token_ids, labels = zip(*batch)
        return embedding_table[torch.stack(token_ids)], torch.stack(labels)

    train_loader = DataLoader(
        make_dataset(train_texts, train_labels),
        batch_size=batch_size,
        shuffle=True,
        pin_memory=True,
        collate_fn=embed_batch,
    )
    test_loader = DataLoader(
        make_dataset(test_texts, test_labels),
        batch_size=batch_size,
        shuffle=False,
        pin_memory=True,
        collate_fn=embed_batch,
    )

    return train_loader, test_loader

def load_imdb_dataset():
    """Return the IMDB train/test texts and labels as ``(train_texts, train_labels, test_texts, test_labels)``.

    NOTE(review): ``load_dataset`` (presumably Hugging Face ``datasets``) is not imported in the
    visible notebook cells — confirm it is imported before this runs.
    """
    imdb = load_dataset("imdb")
    train_split = imdb['train']
    test_split = imdb['test']
    return (
        train_split['text'],
        train_split['label'],
        test_split['text'],
        test_split['label'],
    )

# In [None]:
def _standard_test_accuracy(model, test_loader, device):
    """Return the clean accuracy (in %) of ``model`` on ``test_loader``; puts it in eval mode."""
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)
            pred = model(data).argmax(dim=1)
            total += target.size(0)
            correct += pred.eq(target).sum().item()
    return 100. * correct / total


def train_standard(model, train_loader, test_loader, epochs, device="cuda"):
    """Train ``model`` with plain cross-entropy (Adam, lr 1e-3, weight decay 1e-4).

    Test accuracy is printed after every epoch. With ``epochs == 0`` the untrained model is
    evaluated once so a value is still returned.

    Returns:
        ``(test accuracy % from the last evaluation, elapsed seconds including evaluation)``.
    """
    model = model.to(device)
    criterion = nn.CrossEntropyLoss().to(device)
    optimizer = optim.Adam(model.parameters(), lr=0.001, weight_decay=1e-4)

    start_time = time.time()
    accuracy = None

    for epoch in range(epochs):
        model.train()
        for batch_idx, (data, target) in enumerate(train_loader):
            data, target = data.to(device), target.to(device)

            optimizer.zero_grad()
            output = model(data)
            loss = criterion(output, target)
            loss.backward()
            optimizer.step()

            if batch_idx % 100 == 0:
                print(f'Epoch {epoch+1}/{epochs}, Batch {batch_idx}/{len(train_loader)}, Loss: {loss.item():.4f}')

        accuracy = _standard_test_accuracy(model, test_loader, device)
        print(f'\nTest set: Accuracy: {accuracy:.2f}%')

    if accuracy is None:
        # No epochs ran: still report the (untrained) model's test accuracy.
        accuracy = _standard_test_accuracy(model, test_loader, device)

    training_time = time.time() - start_time
    return accuracy, training_time

def main():
    """Train a standard and an IBP bag-of-words classifier on IMDB and compare them.

    Uses seed 42, GloVe 300-d inputs, batch size 32 and 10 epochs for each model. Prints a
    summary and saves both state dicts (``bow_standard_model.pth``, ``bow_ibp_model.pth``).

    NOTE(review): ``load_dataset`` (presumably Hugging Face ``datasets``) is not imported in the
    visible notebook cells — confirm it is imported before this runs.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    torch.manual_seed(42)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(42)

    print(f"Using device: {device}")

    imdb = load_dataset("imdb")
    train_split = imdb['train']
    test_split = imdb['test']
    train_texts, train_labels = train_split['text'], train_split['label']
    test_texts, test_labels = test_split['text'], test_split['label']

    embeddings, word_to_idx = load_glove_embeddings()
    vocab_size = len(embeddings)

    train_loader, test_loader = create_datasets(
        preprocess_data(train_texts, word_to_idx), train_labels,
        preprocess_data(test_texts, word_to_idx), test_labels,
        embeddings,
        batch_size=32,
    )

    def build_model():
        # Both runs use the same architecture; each call draws fresh weights from the RNG.
        return BOWModel(vocab_size=vocab_size, embed_dim=300, hidden_dim=100, num_classes=2)

    print("\n=== Standard Training ===")
    model_std = build_model()
    accuracy_std, time_std = train_standard(
        model=model_std,
        train_loader=train_loader,
        test_loader=test_loader,
        epochs=10,
        device=device,
    )

    print("\n=== IBP Training ===")
    model_ibp = build_model()
    accuracy_ibp, robust_accuracy_ibp, time_ibp = train_with_ibp(
        model=model_ibp,
        train_loader=train_loader,
        test_loader=test_loader,
        epochs=10,
        target_eps=0.1,
        device=device,
    )

    print("\n=== Final Results ===")
    print("Standard Training:")
    print(f"- Standard Accuracy: {accuracy_std:.2f}%")
    print(f"- Training Time: {time_std:.2f} seconds")
    print("\nIBP Training:")
    print(f"- Standard Accuracy: {accuracy_ibp:.2f}%")
    print(f"- Verified Accuracy: {robust_accuracy_ibp:.2f}%")
    print(f"- Training Time: {time_ibp:.2f} seconds")
    print("\nComparison:")
    print(f"- Training Time Ratio (IBP/Standard): {time_ibp/time_std:.2f}x")
    print(f"- Accuracy Drop from IBP: {accuracy_std - accuracy_ibp:.2f}%")

    for state, path in ((model_std.state_dict(), 'bow_standard_model.pth'),
                        (model_ibp.state_dict(), 'bow_ibp_model.pth')):
        torch.save(state, path)
    print("\nModels saved as bow_standard_model.pth and bow_ibp_model.pth")


if __name__ == "__main__":
    main()