In [1]:
import torchvision
from torchvision.models import VisionTransformer
import torch
import torchvision.transforms as transforms
from finalnlp.replacer import replace_linears_in_pytorch_model
from finalnlp import bitnet158
from finalnlp import utils
import wandb

In [2]:
# Model / training hyperparameters (also recorded in the wandb run config).
image_size = 28   # input side length; the transform below does not resize, so this must match the dataset images
patch_size = 4    # side length of each square patch fed to the ViT
num_layers = 4    # number of transformer encoder layers
num_heads = 2     # attention heads; hidden_dim must be divisible by this
hidden_dim = 20   # token embedding width
mlp_dim = 20      # width of the encoder MLP blocks
num_classes = 10  # size of the classification head output
batch_size = 4
lr = 0.001
EPOCHS = 1
seed = 0

# Seed torch's global RNG before building the model so weight initialization
# (and the DataLoader shuffle order, which also draws from this generator when
# no explicit generator is given) is reproducible across kernel restarts.
torch.manual_seed(seed)

model = VisionTransformer(image_size=image_size, patch_size=patch_size, num_layers=num_layers,
                          num_heads=num_heads, hidden_dim=hidden_dim, mlp_dim=mlp_dim, num_classes=num_classes)

In [3]:
# Authenticate with Weights & Biases, then start a run that records this
# notebook's hyperparameters.
wandb.login()

wandb.init(
    # Project under which this run is logged.
    project="vision-transformer",
    # Explicit run name (otherwise wandb assigns a random one, like sunshine-lollypop-10).
    # NOTE(review): presumably "plain" marks the baseline model without BitLinear
    # replacement -- replace_linears_in_pytorch_model is imported but not called
    # before this cell; confirm.
    name="plain",
    # Hyperparameters and run metadata.
    config={
        "image_size": image_size,
        "patch_size": patch_size,
        "num_layers": num_layers,
        "num_heads": num_heads,
        "hidden_dim": hidden_dim,
        "mlp_dim": mlp_dim,
        "num_classes": num_classes,
        "batch_size": batch_size,
        "lr": lr,
        "EPOCHS": EPOCHS,
        # Computed by the project helper utils.count_bytes with BitLinear158B as
        # the layer type of interest -- exact semantics defined in finalnlp.utils.
        "byte_count": utils.count_bytes(model, bitnet158.BitLinear158B),
        "dataset": "MNIST",
    },
)

Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Currently logged in as: [33misaackletzli[0m ([33mnlp-quantization[0m). Use [1m`wandb login --relogin`[0m to force relogin


In [4]:
# Preprocessing pipeline:
#   1. replicate the single grayscale channel into 3 identical channels
#      (presumably because the ViT's patch embedding expects 3-channel input),
#   2. convert the PIL image to a float tensor in [0, 1],
#   3. normalize with mean 0.5 / std 0.5, mapping [0, 1] to [-1, 1].
transform = transforms.Compose([
    transforms.Grayscale(num_output_channels=3),
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,)),
])

# MNIST train/test splits under ./data (downloaded on first run).
training_set = torchvision.datasets.MNIST(root='./data', train=True, transform=transform, download=True)
validation_set = torchvision.datasets.MNIST(root='./data', train=False, transform=transform, download=True)

# Shuffle only the training data; validation order stays fixed so repeated
# evaluations are comparable.
training_loader = torch.utils.data.DataLoader(dataset=training_set, batch_size=batch_size, shuffle=True)
validation_loader = torch.utils.data.DataLoader(dataset=validation_set, batch_size=batch_size, shuffle=False)

print(f'Training set has {len(training_set)} instances')
print(f'Validation set has {len(validation_set)} instances')

Training set has 60000 instances
Validation set has 10000 instances


In [5]:
# Objective: cross-entropy between the model's class logits and integer labels.
loss_fn = torch.nn.CrossEntropyLoss()
# Adam over every trainable parameter of the model, at the configured learning rate.
optimizer = torch.optim.Adam(params=model.parameters(), lr=lr)

In [6]:
def get_accuracy_and_val_loss(model, loader=None, criterion=None, log=True):
    """Evaluate ``model`` on a labelled loader; return (average loss, accuracy).

    Makes one pass over ``loader`` with gradients disabled. For each batch it
    compares the argmax prediction over dim 1 with the labels and accumulates
    the batch loss.

    Args:
        model: classifier whose output is assumed to be logits of shape
            (batch, num_classes) -- TODO confirm for other models.
        loader: iterable of ``(inputs, labels)`` batches. Defaults to the
            notebook-level ``validation_loader``.
        criterion: loss callable ``criterion(outputs, labels)``. Defaults to the
            notebook-level ``loss_fn``.
        log: when True, send ``{"acc": ..., "val_loss": ...}`` to the active
            wandb run.

    Returns:
        ``(avg_vloss, accuracy)``. ``avg_vloss`` is the mean of the per-batch
        losses, as returned by ``criterion`` (a 0-d tensor for torch losses).
        ``accuracy`` is correct predictions divided by the number of samples
        actually seen, so a short final batch is counted correctly.

    Raises:
        ValueError: if the loader yields no samples.

    Note: leaves ``model`` in eval mode. Callers that continue training must
    switch it back with ``model.train()``.
    """
    if loader is None:
        loader = validation_loader
    if criterion is None:
        criterion = loss_fn

    model.eval()
    running_vloss = 0.0
    num_batches = 0
    num_correct = 0
    total = 0
    # Disable gradient computation and reduce memory consumption.
    with torch.no_grad():
        for vinputs, vlabels in loader:
            voutputs = model(vinputs)
            pred = torch.argmax(voutputs, dim=1)
            num_correct += torch.sum(vlabels == pred).item()
            # Use the real batch length: the last batch is smaller than the
            # configured batch size when the dataset size isn't a multiple of it.
            total += vlabels.size(0)
            running_vloss += criterion(voutputs, vlabels)
            num_batches += 1

    if num_batches == 0 or total == 0:
        raise ValueError("get_accuracy_and_val_loss: loader yielded no samples")

    accuracy = num_correct / total
    avg_vloss = running_vloss / num_batches
    if log:
        wandb.log({"acc": accuracy, "val_loss": avg_vloss})

    return avg_vloss, accuracy

In [7]:
def train_one_epoch(epoch_index, report_every=1000):
    """Train the notebook-level ``model`` for one pass over ``training_loader``.

    Runs a standard step on every batch: zero grads, forward, ``loss_fn``,
    backward, ``optimizer.step()``. Every ``report_every`` batches it also:
      - computes the mean training loss over that window,
      - evaluates with ``get_accuracy_and_val_loss`` (which logs to wandb),
      - prints the validation loss and accuracy.

    Args:
        epoch_index: zero-based epoch counter supplied by the caller. Currently
            unused; kept so existing call sites keep working.
        report_every: number of batches between validation reports (>= 1).

    Returns:
        Mean per-batch training loss of the last complete ``report_every``
        window. If the epoch is shorter than one window, returns the mean loss
        over all batches seen (0.0 only when the loader is empty).

    Raises:
        ValueError: if ``report_every`` is less than 1.
    """
    if report_every < 1:
        raise ValueError("report_every must be >= 1")

    running_loss = 0.
    batches_in_window = 0
    last_loss = 0.
    reported = False

    model.train(True)
    for i, (inputs, labels) in enumerate(training_loader):
        optimizer.zero_grad()

        outputs = model(inputs)

        loss = loss_fn(outputs, labels)
        loss.backward()
        optimizer.step()

        # Gather data and report
        running_loss += loss.item()
        batches_in_window += 1
        if (i + 1) % report_every == 0:
            last_loss = running_loss / report_every  # loss per batch
            running_loss = 0.
            batches_in_window = 0
            reported = True
            avg_vloss, accuracy = get_accuracy_and_val_loss(model)
            print('  Batch: {} Validation Loss: {} Accuracy: {}'.format(i + 1, avg_vloss, accuracy))
            # Validation switched the model to eval mode; resume training mode.
            model.train(True)

    # Short epoch: no full window completed, so report the mean over what ran
    # instead of returning the 0.0 placeholder.
    if not reported and batches_in_window:
        last_loss = running_loss / batches_in_window

    return last_loss

In [8]:
# Epoch counter for the run. NOTE: it is reset each time this cell runs; move
# this initialization to its own cell if you want re-runs of the loop to keep
# counting epochs within the same run.
epoch_number = 0

# Best validation loss so far (initialized here; not updated in this cell).
best_vloss = 1e6

for epoch in range(EPOCHS):
    print(f'EPOCH {epoch_number + 1}:')

    # Turn training mode on, then make one full pass over the training data.
    model.train(True)
    avg_loss = train_one_epoch(epoch_number)

    epoch_number = epoch_number + 1

EPOCH 1:
  Batch: 1000 Validation Loss: 1.642441749572754 Accuracy: 0.3699
  Batch: 2000 Validation Loss: 1.1535154581069946 Accuracy: 0.5781
  Batch: 3000 Validation Loss: 1.0149601697921753 Accuracy: 0.6344
  Batch: 4000 Validation Loss: 0.7467290163040161 Accuracy: 0.7192
  Batch: 5000 Validation Loss: 0.6492432951927185 Accuracy: 0.7873
  Batch: 6000 Validation Loss: 0.5614609122276306 Accuracy: 0.8121
  Batch: 7000 Validation Loss: 0.40694159269332886 Accuracy: 0.8748
  Batch: 8000 Validation Loss: 0.37758365273475647 Accuracy: 0.8872
  Batch: 9000 Validation Loss: 0.33137333393096924 Accuracy: 0.8992
  Batch: 10000 Validation Loss: 0.39209315180778503 Accuracy: 0.8743
  Batch: 11000 Validation Loss: 0.3198293149471283 Accuracy: 0.8964
  Batch: 12000 Validation Loss: 0.3246166408061981 Accuracy: 0.9019
  Batch: 13000 Validation Loss: 0.26091131567955017 Accuracy: 0.9185
  Batch: 14000 Validation Loss: 0.25055012106895447 Accuracy: 0.9227
  Batch: 15000 Validation Loss: 0.250259041