In [26]:
import torchvision
from torchvision.models import VisionTransformer
import torch
import torchvision.transforms as transforms
from finalnlp.replacer import replace_linears_in_pytorch_model
from finalnlp import bitnet158
from finalnlp import utils
import wandb

In [27]:
# Hyperparameters for a small Vision Transformer and its training run.
SEED = 0            # seed for PyTorch's global random number generator
image_size = 32     # side length of the (square) input images, in pixels
patch_size = 4      # side length of each image patch, in pixels
num_layers = 4      # number of encoder layers
num_heads = 2       # attention heads per layer
hidden_dim = 20     # embedding width
mlp_dim = 20        # width of the MLP inside each encoder layer
num_classes = 10    # size of the classification head
batch_size = 8
lr = 0.001
EPOCHS = 3

# Seed before the model is built: its parameters are initialised randomly, so
# without a fixed seed every kernel restart starts from different weights.
torch.manual_seed(SEED)

model = VisionTransformer(
    image_size=image_size,
    patch_size=patch_size,
    num_layers=num_layers,
    num_heads=num_heads,
    hidden_dim=hidden_dim,
    mlp_dim=mlp_dim,
    num_classes=num_classes,
)

In [28]:
# Authenticate with Weights & Biases, then open the run that receives the
# metrics logged by this notebook.
wandb.login()

wandb.init(
    project="CIFAR-10",  # project the run is filed under
    name="plain",        # explicit run name instead of an auto-generated one
    # Hyperparameters and run metadata stored with the run. "byte_count" is
    # whatever utils.count_bytes reports for the model given BitLinear158B.
    config=dict(
        image_size=image_size,
        patch_size=patch_size,
        num_layers=num_layers,
        num_heads=num_heads,
        hidden_dim=hidden_dim,
        mlp_dim=mlp_dim,
        num_classes=num_classes,
        batch_size=batch_size,
        lr=lr,
        EPOCHS=EPOCHS,
        byte_count=utils.count_bytes(model, bitnet158.BitLinear158B),
        dataset="CIFAR",
    ),
)



VBox(children=(Label(value='0.001 MB of 0.001 MB uploaded\r'), FloatProgress(value=1.0, max=1.0)))

0,1
acc,▁▄▅▇▇█
val_loss,█▅▄▂▂▁

0,1
acc,0.3457
val_loss,1.78396


VBox(children=(Label(value='Waiting for wandb.init()...\r'), FloatProgress(value=0.011111111111111112, max=1.0…

In [29]:
# Preprocessing pipeline: grayscale replicated over 3 output channels, then
# conversion to a tensor, then normalisation with mean 0.5 and std 0.5.
transform = transforms.Compose([
    transforms.Grayscale(num_output_channels=3),
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,)),
])

# CIFAR-10 train and test splits, downloaded into ./data when not present.
training_set = torchvision.datasets.CIFAR10(
    root='./data', train=True, transform=transform, download=True,
)
validation_set = torchvision.datasets.CIFAR10(
    root='./data', train=False, transform=transform, download=True,
)

# Only the training loader shuffles; validation keeps the dataset order.
training_loader = torch.utils.data.DataLoader(
    dataset=training_set, batch_size=batch_size, shuffle=True,
)
validation_loader = torch.utils.data.DataLoader(
    dataset=validation_set, batch_size=batch_size, shuffle=False,
)

print('Training set has {} instances'.format(len(training_set)))
print('Validation set has {} instances'.format(len(validation_set)))

Files already downloaded and verified
Files already downloaded and verified
Training set has 50000 instances
Validation set has 10000 instances


In [30]:
# Cross-entropy objective, minimised with Adam (from torch.optim) over every
# parameter of the model at learning rate `lr`.
loss_fn = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(params=model.parameters(), lr=lr)

In [31]:
def get_accuracy_and_val_loss(model):
    """Evaluate ``model`` on the validation set and log the result to wandb.

    Every batch of the module-level ``validation_loader`` is passed through
    the model with gradient tracking disabled and scored with the module-level
    ``loss_fn``. The metrics are logged under the keys ``"acc"`` and
    ``"val_loss"``.

    Args:
        model: Module mapping a batch of inputs to per-class scores of shape
            ``(batch, num_classes)``.

    Returns:
        Tuple ``(avg_vloss, accuracy)`` of Python floats: the mean of the
        per-batch losses, and the fraction of samples whose arg-max prediction
        equals the label. Both are ``0.0`` if the loader yields no batches.

    Note:
        The model is left in evaluation mode; a caller that keeps training
        must switch it back with ``model.train()``.
    """
    model.eval()

    running_vloss = 0.0
    num_correct = 0
    total = 0
    num_batches = 0

    # Disable gradient computation and reduce memory consumption.
    with torch.no_grad():
        for vinputs, vlabels in validation_loader:
            voutputs = model(vinputs)
            pred = torch.argmax(voutputs, dim=1)
            num_correct += (pred == vlabels).sum().item()
            # Count the samples actually seen: the final batch can be smaller
            # than batch_size, which would otherwise deflate the accuracy.
            total += vlabels.size(0)
            # .item() keeps the accumulator (and the logged/returned value) a
            # plain float instead of a tensor.
            running_vloss += loss_fn(voutputs, vlabels).item()
            num_batches += 1

    # Guard both divisions so an empty loader yields zeros instead of raising.
    accuracy = num_correct / total if total else 0.0
    avg_vloss = running_vloss / num_batches if num_batches else 0.0
    wandb.log({"acc": accuracy, "val_loss": avg_vloss})

    return avg_vloss, accuracy

In [32]:
def train_one_epoch(epoch_index):
    """Run one pass over ``training_loader``, validating every 1000 batches.

    Works on the module-level ``model``, ``optimizer`` and ``loss_fn``. After
    each block of 1000 batches it calls ``get_accuracy_and_val_loss`` and
    prints the validation loss and accuracy.

    Args:
        epoch_index: Accepted for the caller's convenience; not used here.

    Returns:
        Mean training loss over the most recent complete block of 1000
        batches, or ``0.`` if fewer than 1000 batches were processed.
    """
    report_every = 1000
    window_loss = 0.
    last_loss = 0.

    for step, (inputs, labels) in enumerate(training_loader, start=1):
        # The periodic validation below puts the model in eval mode, so
        # training mode is re-enabled at the top of every step.
        model.train(True)

        optimizer.zero_grad()
        loss = loss_fn(model(inputs), labels)
        loss.backward()
        optimizer.step()

        window_loss += loss.item()
        if step % report_every != 0:
            continue

        # A full block has been accumulated: average it, reset, validate.
        last_loss = window_loss / report_every
        window_loss = 0.
        avg_vloss, accuracy = get_accuracy_and_val_loss(model)
        print('  Batch: {} Validation Loss: {} Accuracy: {}'.format(step, avg_vloss, accuracy))

    return last_loss

In [33]:
# Driver loop: runs EPOCHS passes of train_one_epoch over the training data.
# NOTE(review): the counter is initialised in the same cell as the loop, so
# re-running this cell restarts the numbering at 1 rather than continuing.
epoch_number = 0

# NOTE(review): assigned here but never read in the visible part of the notebook.
best_vloss = 1_000_000.

for epoch in range(EPOCHS):
    # Printed epoch numbers are 1-based.
    print('EPOCH {}:'.format(epoch_number + 1))

    # Make sure gradient tracking is on, and do a pass over the data
    model.train(True)
    # The returned loss is stored but not used further in this cell.
    avg_loss = train_one_epoch(epoch_number)

    epoch_number += 1

EPOCH 1:
  Batch: 1000 Validation Loss: 2.018859624862671 Accuracy: 0.2496
  Batch: 2000 Validation Loss: 1.9930331707000732 Accuracy: 0.2396
  Batch: 3000 Validation Loss: 1.801468014717102 Accuracy: 0.3372
  Batch: 4000 Validation Loss: 1.8096880912780762 Accuracy: 0.3227
  Batch: 5000 Validation Loss: 1.7822550535202026 Accuracy: 0.35
  Batch: 6000 Validation Loss: 1.7286046743392944 Accuracy: 0.3705
EPOCH 2:
  Batch: 1000 Validation Loss: 1.7034220695495605 Accuracy: 0.3769
  Batch: 2000 Validation Loss: 1.6682103872299194 Accuracy: 0.3879
  Batch: 3000 Validation Loss: 1.6526589393615723 Accuracy: 0.4054
  Batch: 4000 Validation Loss: 1.6166620254516602 Accuracy: 0.4087
  Batch: 5000 Validation Loss: 1.6342134475708008 Accuracy: 0.4101
  Batch: 6000 Validation Loss: 1.6033977270126343 Accuracy: 0.4197
EPOCH 3:
  Batch: 1000 Validation Loss: 1.5589041709899902 Accuracy: 0.4334
  Batch: 2000 Validation Loss: 1.566047191619873 Accuracy: 0.4341
  Batch: 3000 Validation Loss: 1.5415236