In [1]:
import sys
sys.path.append('/users/nfoster3/data/nfoster3/two_bit_bananas')
sys.path.append('/users/nfoster3/data/nfoster3/two_bit_bananas/simple_test')

from torch.utils.data import DataLoader
from tqdm import tqdm
from PIL import Image
import torch
# import mnist from torchvision.datasets
from torchvision.datasets import MNIST
from torchvision.transforms import ToTensor

from SimpleModel import AnnealedModel

preprocessor = lambda x: ToTensor()(x).flatten()

mnist_dataset = MNIST(download=True, root='/users/nfoster3/data/nfoster3/two_bit_bananas/mnist_test/data', train=True, transform=preprocessor)

# Create a DataLoader with batch size 32 and shuffle the data with 10% set asside for testing
dataloader = DataLoader(mnist_dataset, batch_size=64, shuffle=True, drop_last=True)

In [2]:
def train(model, epochs, dataloader, lr=1e-4, final_temperature=200.0, lmbda = None, device = 'cuda:0'):
    # Set the model to training mode
    model.train()
    model = model.to(device)
    # Create an optimizer
    optimizer = torch.optim.SGD(model.parameters(), lr=lr)
    
    temp_increase = (final_temperature /1.0) ** (1.0 / (epochs -1))
    
    # Loop over the epochs
    for epoch in range(epochs):
        # Loop over the data
        batch_progress_bar = tqdm(dataloader)
        for data, labels in batch_progress_bar:
            # Move the data to the device
            data = data.to(device)
            labels = labels.to(device)
            # Zero out the gradients
            optimizer.zero_grad()
            
            # 1-Hot encode the labels
            ce_loss = torch.nn.functional.cross_entropy(model(data), labels)
            # ce_loss = torch.tensor(0.0, requires_grad=False)
            if lmbda is not None:
                l0 = lmbda * model.compute_l0()
            else:
                l0 = torch.tensor(0.0, requires_grad=False)
            loss = ce_loss +  l0
            
            # Backpropagate the loss
            loss.backward()
            
            # Take a step
            optimizer.step()

            # Compute Training Accuracy
            predictions = torch.argmax(model(data), dim=-1)
            accuracy = torch.sum(predictions == labels) / labels.shape[0]

            # Update the progress bar
            batch_progress_bar.set_description(f'Epoch {epoch} CE Loss {ce_loss.item():.4f} L0 Loss {l0.item():.4f} Temp {model.temperature:.0f} Accuracy {accuracy:.4f}')

            # Increase the temperature
        model.set_temperature(model.temperature * temp_increase)


In [3]:
model = AnnealedModel(
    input_size=784,
    output_size=10,
    hidden_size=256,
    hidden_layers=10,
    bias=True,
    softmax=False,
)

In [4]:
train(model, 5, dataloader, lr=1e-2, final_temperature=200.0, lmbda=1e-8, device='cuda:0')

Epoch 0 CE Loss 10.4082 L0 Loss 0.0086 Temp 1 Accuracy 0.3281: 100%|█████████████████████████████████████████████████| 937/937 [00:20<00:00, 45.00it/s]
Epoch 1 CE Loss 3.3203 L0 Loss 0.0086 Temp 4 Accuracy 0.7031: 100%|██████████████████████████████████████████████████| 937/937 [00:20<00:00, 45.65it/s]
Epoch 2 CE Loss 0.9039 L0 Loss 0.0086 Temp 14 Accuracy 0.7656: 100%|█████████████████████████████████████████████████| 937/937 [00:20<00:00, 45.22it/s]
Epoch 3 CE Loss 0.5123 L0 Loss 0.0086 Temp 53 Accuracy 0.9062: 100%|█████████████████████████████████████████████████| 937/937 [00:20<00:00, 45.38it/s]
Epoch 4 CE Loss 0.3845 L0 Loss 0.0086 Temp 200 Accuracy 0.9062: 100%|████████████████████████████████████████████████| 937/937 [00:20<00:00, 45.68it/s]


In [5]:
model.eval()

In [6]:
test_dataset = MNIST(download=True, root='/users/nfoster3/data/nfoster3/two_bit_bananas/mnist_test/data', transform=preprocessor, target_transform=None, train=False)
test_dataloader = DataLoader(test_dataset, batch_size=32, shuffle=True, drop_last=True)

In [7]:
def test(model, test_dataloader):
    model.eval()
    model = model.to('cuda:0')
    total = 0
    correct = 0
    for data, labels in test_dataloader:
        data = data.to('cuda:0')
        labels = labels.to('cuda:0')
        predictions = torch.argmax(model(data), dim=1)
        correct += torch.sum(predictions == labels)
        total += labels.numel()
    return correct / total

In [8]:
result = test(model, test_dataloader)
print(f"Test Accuracy: {result:.4f}")

Test Accuracy: 0.8850


In [10]:
# Save the model
# torch.save(model.state_dict(), '/users/nfoster3/data/nfoster3/two_bit_bananas/mnist_test/model.pt')