In [1]:
# Train the ViT on the dataset selected by DATASET (MNIST or CIFAR10)

In [1]:
%%capture 
%run config.ipynb
%run ViT.ipynb

if DATASET=="mnist":
    %run data-MNIST.ipynb
    print(f'Using MNIST as dataset')
elif DATASET=="cifar10": 
    %run data-CIFAR10.ipynb
    print(f'Using CIFAR10 as dataset')
else:
    raise Exception("Invalid Configuration for DATASET")

In [2]:
from torch.optim import Adam 
from datetime import datetime 
import torch 
from tqdm import tqdm
from torchinfo import summary
import matplotlib.pyplot as plt
from torch.nn.utils import clip_grad_norm_
from torch.optim.lr_scheduler import LambdaLR

In [3]:
# Build the ViT from the hyper-parameters supplied by the configuration.
model = ViT(
    image_size=IMG_SIZE,
    patch_size=PATCH_SIZE,
    in_channels=IN_CHANNELS,
    n_head=N_HEAD, 
    d_model=D_MODEL, 
    ffn_hidden=FFN_HIDDEN, 
    mlp_hidden=MLP_HIDDEN, 
    n_layers=N_LAYERS, 
    class_num=CLASS_NUM, 
    device=device, 
    drop_prob=DROP_PROB,
)

# Optionally warm-start from a previously saved state dict.
if LOAD_MODEL:
    loading_model_path = model_dir / LOADING_MODEL_NAME
    # map_location remaps the stored tensors onto the current device, so a
    # checkpoint written on a GPU machine still loads on a CPU-only one.
    model.load_state_dict(
        torch.load(loading_model_path, map_location=device, weights_only=True)
    )

model.train()

print(device)
model.to(device)

def count_parameters(model): 
    """Return the total element count of all parameters of `model` that require gradients."""
    return sum(p.numel() for p in model.parameters() if p.requires_grad) 

# Build the summary once and reuse it for both the log and the cell output;
# each summary() call builds the whole report again.
model_summary = summary(model, input_size=(BATCH_SIZE, IN_CHANNELS, IMG_SIZE, IMG_SIZE))
logger.info(model_summary)
print(model_summary)

logger.info(f'model parameter #: {count_parameters(model)}')

# Debug aid: uncomment to check which device each parameter/buffer lives on.
# for name, param in model.named_parameters():
#     print(f"Parameter {name} is on {param.device}")
# for name, buffer in model.named_buffers():
#     print(f"Buffer {name} is on {buffer.device}")

cpu
Layer (type:depth-idx)                             Output Shape              Param #
ViT                                                [64, 10]                  --
├─Encoder: 1-1                                     [64, 50, 200]             --
│    └─TransformerEmbedding: 2-1                   [64, 50, 200]             200
│    │    └─PatchEmbedding: 3-1                    [64, 49, 200]             3,400
│    │    └─PositionalEmbedding: 3-2               [1, 50, 200]              10,000
│    │    └─Dropout: 3-3                           [64, 50, 200]             --
│    └─ModuleList: 2-2                             --                        --
│    │    └─EncoderBlock: 3-4                      [64, 50, 200]             366,712
│    │    └─EncoderBlock: 3-5                      [64, 50, 200]             366,712
│    │    └─EncoderBlock: 3-6                      [64, 50, 200]             366,712
│    │    └─EncoderBlock: 3-7                      [64, 50, 200]             366,712
│  

In [4]:
# Loss used by both the training and the evaluation loops.
loss_func = nn.CrossEntropyLoss()

# Adam over every model parameter, configured with INIT_LR and WEIGHT_DECAY.
optimizer = Adam(model.parameters(), lr=INIT_LR, weight_decay=WEIGHT_DECAY)

# Each scheduler.step() scales the initial learning rate by 0.99 ** (number of steps so far).
scheduler = LambdaLR(optimizer, lr_lambda=lambda epoch: 0.99 ** epoch)

In [5]:
# Per-epoch metric histories; the training loop appends one value to each per epoch.
train_loss_per_epoch = list()
test_loss_per_epoch = list()
test_accuracy_per_epoch = list()

In [6]:
def plot_losses(loss_values, label):
    """Plot a sequence of per-epoch (or per-step) values as a small line chart.

    Args:
        loss_values: Sequence of numbers to draw; the i-th value is plotted
            at x = i + 1.
        label: Text used as the figure title.
    """
    positions = [index + 1 for index in range(len(loss_values))]
    _, axes = plt.subplots(figsize=(5, 2))
    axes.plot(positions, loss_values)
    axes.set_title(label)
    plt.show()

In [7]:
def train_epoch(epoch_num): 
    """Run one optimisation pass over ``train_dataloader``.

    Uses the notebook-level ``model``, ``optimizer``, ``loss_func``,
    ``device`` and ``GRADIENT_CLIP``.

    Args:
        epoch_num: Index of the current epoch; only referenced by the
            optional (commented-out) per-step diagnostic plots.

    Returns:
        tuple: ``(train_epoch_loss, train_step_loss)`` — the sum of the
        per-batch losses and that sum divided by the number of batches.

    Raises:
        ValueError: If ``train_dataloader`` yields no batches.
    """
    # Prepare recording CUDA memory snapshot
    # torch.cuda.memory._record_memory_history(
    #     max_entries=100000
    # )
    model.train()
    train_epoch_loss = 0 
    num_steps = 0
    lr_rate_per_step = []
    loss_per_step = []

    # Wrapping the dataloader itself (not enumerate(...)) lets tqdm see its
    # length when one is available.
    for img, labels in tqdm(train_dataloader):
        optimizer.zero_grad()

        img = img.to(device)
        labels = labels.to(device)
        out = model(img)

        loss = loss_func(out, labels)
        loss.backward()

        # Clip only after backward(): before it the gradients are the stale
        # or absent ones left by zero_grad(), so clipping there never
        # constrains the update applied by optimizer.step().
        clip_grad_norm_(model.parameters(), GRADIENT_CLIP)

        # track lr rate per steps
        lr_rate_per_step.append(optimizer.param_groups[0]['lr'])

        optimizer.step()

        # Read the scalar loss once and reuse it for both trackers.
        step_loss = loss.item()
        loss_per_step.append(step_loss)
        train_epoch_loss += step_loss
        num_steps += 1

    if num_steps == 0:
        raise ValueError("train_dataloader yielded no batches")

    # after training is done, then print out the lr_rate_per_step
    # plot_losses(lr_rate_per_step, f'LR rate in EPOCH #{epoch_num}')
    # plot_losses(loss_per_step, f'Loss per step in EPOCH #{epoch_num}')
    
    train_step_loss = train_epoch_loss / num_steps
    return train_epoch_loss, train_step_loss

In [8]:
def evaluate():
    """Evaluate the notebook-level ``model`` on ``test_dataloader``.

    Runs with the model in eval mode and gradient tracking disabled.

    Returns:
        tuple: ``(test_epoch_loss, test_step_loss, accuracy)`` — the sum of
        the per-batch losses, that sum divided by the number of batches, and
        the percentage (0-100) of samples whose highest-scoring class equals
        the label.

    Raises:
        ValueError: If ``test_dataloader`` yields no batches or no samples.
    """
    model.eval()
    test_epoch_loss = 0
    correct_cnt = 0
    total_cnt = 0
    num_steps = 0

    with torch.no_grad():
        # Wrapping the dataloader itself (not enumerate(...)) lets tqdm see
        # its length when one is available.
        for img, labels in tqdm(test_dataloader):
            img = img.to(device)
            labels = labels.to(device)
            out = model(img)

            # Index of the largest score along the last dimension is the prediction.
            predicted = out.argmax(dim=-1)

            correct_cnt += torch.eq(labels, predicted).sum().item()
            total_cnt += labels.size(0)

            test_epoch_loss += loss_func(out, labels).item()
            num_steps += 1

    if num_steps == 0 or total_cnt == 0:
        raise ValueError("test_dataloader yielded no samples")

    test_step_loss = test_epoch_loss / num_steps
    accuracy = correct_cnt / total_cnt * 100

    return test_epoch_loss, test_step_loss, accuracy

In [9]:
# Actual training is done here

# Lowest mean test loss seen so far. Starting from infinity means the first
# epoch with a finite test loss always counts as an improvement, whatever the loss scale.
min_test_loss = float('inf')

for epoch in range(EPOCHS):
    train_epoch_loss, train_step_loss = train_epoch(epoch)
    test_epoch_loss, test_step_loss, test_accuracy = evaluate()

    # Record the per-batch means so the curves can be plotted afterwards.
    train_loss_per_epoch.append(train_step_loss)
    test_loss_per_epoch.append(test_step_loss)
    test_accuracy_per_epoch.append(test_accuracy)

    logger.info(f'Epoch #{epoch} End | Train Loss: {train_step_loss} | Test Loss: {test_step_loss} | Test Accuracy: {test_accuracy:.2f}%')
    # Advance the learning-rate schedule once per epoch.
    scheduler.step()
    # save the model parameter if it reaches the minimum test loss
    if min_test_loss > test_step_loss:
        min_test_loss = test_step_loss
        model_path = model_dir / f'model_{timestamp}_{epoch}'
        logger.info(f'Reached new min test loss. Saving the model at {model_path}')
        torch.save(model.state_dict(), model_path)

logger.info('Training Completely Ended!!')











29it [03:31,  3.92it/s]

KeyboardInterrupt: 

In [None]:
# Mean training loss per batch, one point per completed epoch.
plot_losses(loss_values=train_loss_per_epoch, label='Train Loss')

In [None]:
# Mean test loss per batch, one point per completed epoch.
plot_losses(loss_values=test_loss_per_epoch, label='Test Loss')

In [None]:
# Test accuracy in percent, one point per completed epoch.
plot_losses(loss_values=test_accuracy_per_epoch, label='Test Accuracy')