# In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader, TensorDataset
import matplotlib.pyplot as plt

# Define QMSModel
class QMSModel(nn.Module):
    """Quadratic classifier: one quadratic form plus one linear term per class.

    For a flattened input vector ``x`` the score of class ``n`` is::

        score[n] = x^T W[n] x + x . V[n]

    where ``W = fc_weights`` has shape ``(num_classes, input_dim, input_dim)``
    and ``V = fc_bias`` has shape ``(num_classes, input_dim)``. Despite its
    name, ``fc_bias`` is a per-class *linear weight* vector, not a constant
    offset; the model has no constant bias term.

    Args:
        input_dim: Number of features per sample after flattening
            (e.g. ``28 * 28`` for single-channel 28x28 images).
        num_classes: Number of scores (logits) produced per sample.
    """

    def __init__(self, input_dim, num_classes):
        super().__init__()
        # Scale the random init so initial logits are O(1). With unscaled
        # randn, the quadratic term sums input_dim**2 unit-variance products
        # and yields logits with std ~input_dim (hundreds for 784 features),
        # which saturates the softmax from the very first step. Dividing by
        # input_dim (quadratic) and sqrt(input_dim) (linear) keeps each term's
        # std near 1 for inputs whose entries have magnitude ~1.
        self.fc_weights = nn.Parameter(
            torch.randn(num_classes, input_dim, input_dim) / input_dim
        )
        self.fc_bias = nn.Parameter(
            torch.randn(num_classes, input_dim) / input_dim ** 0.5
        )

    def forward(self, x):
        """Return class scores of shape ``(batch, num_classes)``.

        ``x`` may have any shape whose first dimension is the batch; all other
        dimensions are flattened and must multiply to ``input_dim``.
        """
        # reshape (unlike view) also accepts non-contiguous inputs, e.g. after
        # a transpose, copying only when necessary.
        x_flat = x.reshape(x.size(0), -1)
        # Per-class quadratic form x^T W[n] x for every sample in the batch.
        quadratic_scores = torch.einsum('bi,nij,bj->bn', x_flat, self.fc_weights, x_flat)
        # Per-class linear term x . V[n].
        linear_scores = torch.einsum('bi,ni->bn', x_flat, self.fc_bias)
        return quadratic_scores + linear_scores

# Data preparation
# ToTensor scales pixel values into [0, 1]; Normalize with mean 0.5 and
# std 0.5 then shifts and scales them into [-1, 1].
transform = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize((0.5,), (0.5,)),
    ]
)

# Training split: reshuffled every epoch.
_train_dataset = datasets.FashionMNIST(
    root='./fashionMnist', train=True, download=True, transform=transform
)
train_loader = DataLoader(_train_dataset, batch_size=128, shuffle=True)

# Test split: fixed order for evaluation.
_test_dataset = datasets.FashionMNIST(
    root='./fashionMnist', train=False, download=True, transform=transform
)
test_loader = DataLoader(_test_dataset, batch_size=128, shuffle=False)

# Instantiate the model on the GPU when one is available, otherwise the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
model = QMSModel(input_dim=28 * 28, num_classes=10)
model = model.to(device)

# Optimizer: Adam over all of the model's parameters.
optimizer = optim.Adam(params=model.parameters(), lr=0.01)

# Training loop
# The loss module is stateless, so build it once rather than once per batch.
criterion = nn.CrossEntropyLoss()
num_epochs = 5

for epoch in range(num_epochs):
    model.train()
    correct, total = 0, 0

    for inputs, labels in train_loader:
        inputs, labels = inputs.to(device), labels.to(device)
        optimizer.zero_grad()

        outputs = model(inputs)
        loss = criterion(outputs, labels)

        loss.backward()
        optimizer.step()

        # Running accuracy from this batch's pre-step outputs; detach so the
        # bookkeeping does not record anything in the autograd graph.
        predicted = outputs.detach().argmax(dim=1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

    print(f'Epoch {epoch + 1}, Training Accuracy: {100 * correct / total:.2f}%')

# Testing loop: accuracy of the initially trained model on the test split.
model.eval()
correct = 0
total = 0
with torch.no_grad():
    for batch_inputs, batch_labels in test_loader:
        batch_inputs = batch_inputs.to(device)
        batch_labels = batch_labels.to(device)
        # Predicted class = index of the highest score in each row.
        predicted = model(batch_inputs).max(dim=1).indices
        correct += (predicted == batch_labels).sum().item()
        total += batch_labels.size(0)

test_accuracy = 100 * correct / total
print(f'Initial Test Accuracy: {test_accuracy:.2f}%')

# LEM Function
def classify_loyalty(model, data_loader, strong_threshold=0.8, weak_threshold=0.5):
    """Bucket samples by the model's top-class softmax confidence.

    A sample's confidence is the maximum softmax probability over the model's
    outputs. Samples are assigned as follows:

    * strong: ``confidence >= strong_threshold``
    * normal: ``weak_threshold <= confidence < strong_threshold``
    * discarded: everything else (including NaN confidences)

    The label kept with each sample is the one supplied by ``data_loader``,
    not the model's prediction.

    Args:
        model: Module mapping an input batch to per-class scores of shape
            ``(batch, num_classes)``. It is switched to eval mode and left
            there.
        data_loader: Iterable of ``(inputs, labels)`` batches.
        strong_threshold: Minimum confidence for the strong bucket.
        weak_threshold: Minimum confidence for the normal bucket.

    Returns:
        ``(strong_samples, normal_samples)``: two dicts with keys ``'inputs'``
        and ``'labels'``, each a list of per-sample CPU tensors in loader
        order (labels are 0-d tensors).
    """
    strong_samples = {'inputs': [], 'labels': []}
    normal_samples = {'inputs': [], 'labels': []}

    # Run batches on the device that holds the model's parameters (CPU for a
    # parameter-less model) instead of relying on a global device variable.
    model_device = next((p.device for p in model.parameters()), torch.device('cpu'))

    model.eval()
    with torch.no_grad():
        for inputs, labels in data_loader:
            inputs, labels = inputs.to(model_device), labels.to(model_device)
            max_probs = torch.softmax(model(inputs), dim=1).max(dim=1).values

            # Vectorized bucketing: one boolean mask per bucket instead of a
            # per-sample Python loop that syncs with the device on every
            # comparison.
            strong_mask = max_probs >= strong_threshold
            normal_mask = (max_probs >= weak_threshold) & ~strong_mask

            # Move each selected slice to the CPU once, then split it into
            # per-sample tensors to keep the list-of-tensors return format.
            strong_samples['inputs'].extend(inputs[strong_mask].cpu().unbind(0))
            strong_samples['labels'].extend(labels[strong_mask].cpu().unbind(0))
            normal_samples['inputs'].extend(inputs[normal_mask].cpu().unbind(0))
            normal_samples['labels'].extend(labels[normal_mask].cpu().unbind(0))

    return strong_samples, normal_samples

# Run LEM and filter data
# Filter the TRAINING split, not the test split: re-training on test samples
# would make the test accuracy reported afterwards meaningless (test-set
# leakage), since the model would be evaluated on data it was fitted to.
strong_samples, normal_samples = classify_loyalty(model, train_loader)

# Combine strong and normal samples for re-training
kept_inputs = strong_samples['inputs'] + normal_samples['inputs']
kept_labels = strong_samples['labels'] + normal_samples['labels']
if not kept_inputs:
    # torch.stack([]) would otherwise fail with an unhelpful message.
    raise RuntimeError(
        "classify_loyalty kept no samples; lower weak_threshold or check the model"
    )
filtered_inputs = torch.stack(kept_inputs)
filtered_labels = torch.stack(kept_labels)

filtered_loader = DataLoader(TensorDataset(filtered_inputs, filtered_labels), batch_size=128, shuffle=True)

print(f"Filtered Data: {len(filtered_labels)} samples (Strong + Normal Loyalty)")

# Re-train model with filtered data
print("\nRe-training model with filtered data...")
# The loss module is stateless, so build it once rather than once per batch.
criterion = nn.CrossEntropyLoss()
num_retrain_epochs = 5

# Continues from the current weights and optimizer state (fine-tuning).
for epoch in range(num_retrain_epochs):
    model.train()
    correct, total = 0, 0

    for inputs, labels in filtered_loader:
        inputs, labels = inputs.to(device), labels.to(device)
        optimizer.zero_grad()

        outputs = model(inputs)
        loss = criterion(outputs, labels)

        loss.backward()
        optimizer.step()

        # Running accuracy from this batch's pre-step outputs; detach so the
        # bookkeeping does not record anything in the autograd graph.
        predicted = outputs.detach().argmax(dim=1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

    print(f'Epoch {epoch + 1}, Filtered Training Accuracy: {100 * correct / total:.2f}%')

# Re-evaluate model on test data: accuracy after filtered re-training.
model.eval()
correct = 0
total = 0
with torch.no_grad():
    for batch_x, batch_y in test_loader:
        batch_x = batch_x.to(device)
        batch_y = batch_y.to(device)
        # Predicted class = index of the highest score in each row.
        batch_pred = model(batch_x).max(dim=1).indices
        total += batch_y.size(0)
        correct += (batch_pred == batch_y).sum().item()

filtered_test_accuracy = 100 * correct / total
print(f'\nTest Accuracy After Re-training with Filtered Data: {filtered_test_accuracy:.2f}%')
