In [1]:
import random
import time

import torch

import loader
import transformers
import utility

# Seed every RNG the notebook may draw from so Restart & Run All reproduces results.
SEED = 42
random.seed(SEED)        # Python's RNG; the data loader's masking may use it (unverified)
torch.manual_seed(SEED)  # torch RNG (per PyTorch docs this covers CPU and CUDA devices)
# NOTE(review): if loader/utility draw from numpy, seed numpy here as well.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
print(f"Device: {device}")

Device: cuda


In [2]:
# Hyperparameters, grouped by the component that consumes them.

# Data loading: forwarded to loader.get_dataloader for both the train and eval files.
loader_params = dict(
    batch_size=4,
    pad_images=True,
    percent_mask=0.1,  # presumably the fraction of patches masked per sample — confirm in loader
    shuffle=True,
)

# Optimisation / logging: read by train_iBOT.
train_params = dict(
    epochs=10,
    lr=1e-4,               # AdamW learning rate
    print_frequency=125,   # report losses every N training steps (and at epoch end)
    max_eval_iters=None,   # forwarded to utility.get_eval_loss; None presumably = whole eval set
    ema_alpha=0.01,        # forwarded to utility.update_teacher_weights
    nce_weight=0.4,        # forwarded to utility.compute_loss / get_eval_loss
)

# Architecture: passed as keyword arguments to transformers.VisionTransformer.
model_params = dict(
    max_img_size=32,
    unique_patches=13,
    embed_dim=32,
    depth=6,
    num_heads=4,
    mlp_ratio=4,
    qkv_bias=False,
    drop_rate=0.0,
    attn_drop_rate=0.0,
    head=False,
    output_classes=13,
)

In [3]:
def train_iBOT(student_ViT, teacher_ViT, loader_params, train_params, device,
               train_path='dict_traindata.txt', eval_path='dict_evaldata.txt'):
    """Train a student ViT against an EMA-updated teacher, logging train/eval losses.

    Each step computes a patch loss and a class loss with ``utility.compute_loss``,
    backpropagates their sum, and steps an AdamW optimizer that holds only the
    student's parameters. It then updates the teacher from the student with
    ``utility.update_teacher_weights``. Every ``print_frequency`` steps, and on the
    last step of each epoch, the mean training losses since the previous report
    are printed together with the eval losses from ``utility.get_eval_loss``. The
    student is saved via ``student_ViT.save_model()`` after every epoch.

    Args:
        student_ViT: Model being optimised; needs ``parameters()``, ``train()``
            and ``save_model()``.
        teacher_ViT: Model passed to the loss/eval helpers and to the teacher
            update after every step; it is never given to the optimizer.
        loader_params: Options for ``loader.get_dataloader``. ``batch_size`` is
            also used for the samples/s column.
        train_params: Must contain ``epochs``, ``lr``, ``print_frequency``,
            ``max_eval_iters``, ``ema_alpha`` and ``nce_weight``.
        device: Forwarded to the loss/eval helpers.
        train_path: Data file for the training dataloader.
        eval_path: Data file for the evaluation dataloader.
    """
    train_dataloader = loader.get_dataloader(train_path, loader_params)
    eval_dataloader = loader.get_dataloader(eval_path, loader_params)
    optimizer = torch.optim.AdamW(student_ViT.parameters(), lr=train_params['lr'])

    # Loop invariants, read once.
    num_epochs = train_params['epochs']
    num_batches = len(train_dataloader)
    print_frequency = train_params['print_frequency']
    batch_size = loader_params['batch_size']

    for epoch in range(num_epochs):
        print(f"\nEpoch {epoch + 1}/{num_epochs}")
        # Header widths match the row format below so the columns line up.
        print(f"{'Iter':>9} || {'Train Patch':>11} | {'Train Class':>11} | {'Train Total':>11} || {'Eval Patch':>10} | {'Eval Class':>10} | {'Eval Total':>10} || {'Samples/s':>10}")
        train_patch, train_class, last_iter = 0.0, 0.0, 0
        # perf_counter is monotonic and high-resolution, unlike time.time().
        start_time = time.perf_counter()
        # Start each epoch in training mode; the caller or the eval helper
        # may have left the student in eval mode.
        student_ViT.train()

        for i, (_ids, u, u_masks, v, v_masks) in enumerate(train_dataloader):
            patch_loss, class_loss = utility.compute_loss(u, u_masks, v, v_masks, student_ViT, teacher_ViT, train_params['nce_weight'], device)
            loss = patch_loss + class_loss
            optimizer.zero_grad(set_to_none=True)
            loss.backward()
            optimizer.step()
            utility.update_teacher_weights(student_ViT, teacher_ViT, train_params['ema_alpha'])

            # .item() already copies the scalar to the host; no .cpu() needed.
            train_patch += patch_loss.item()
            train_class += class_loss.item()

            step = i + 1
            if step % print_frequency == 0 or step == num_batches:
                iter_count = step - last_iter
                # Measured before evaluation so throughput reflects training only.
                elapsed_time = time.perf_counter() - start_time

                eval_patch, eval_class = utility.get_eval_loss(student_ViT, teacher_ViT, eval_dataloader, train_params['nce_weight'], device, max_eval_iters=train_params['max_eval_iters'])
                # Restore training mode in case the eval helper switched the
                # student to eval mode and did not switch it back.
                student_ViT.train()

                # Assumes every batch holds batch_size samples (the final batch may be smaller).
                samples_per_sec = (iter_count * batch_size) / elapsed_time if elapsed_time > 0 else float('inf')
                print(f"{step:>4}/{num_batches:>4} || {train_patch / iter_count:>11.3f} | {train_class / iter_count:>11.3f} | {(train_patch + train_class) / iter_count:>11.3f} || {eval_patch:>10.3f} | {eval_class:>10.3f} | {eval_patch + eval_class:>10.3f} || {samples_per_sec:>10.2f}")

                last_iter = step
                train_patch, train_class = 0.0, 0.0
                start_time = time.perf_counter()

        student_ViT.save_model()


# Instantiate student and teacher with the same architecture, then copy the
# student's initial weights into the teacher so both start identical.
student_ViT = transformers.VisionTransformer(**model_params).to(device)
teacher_ViT = transformers.VisionTransformer(**model_params).to(device)
teacher_ViT.load_state_dict(student_ViT.state_dict())
# The optimizer only receives the student's parameters; the teacher changes only
# through utility.update_teacher_weights. Freezing it stops autograd from tracking
# teacher parameters (in case the loss helper runs the teacher outside no_grad).
teacher_ViT.requires_grad_(False)

# Train model
train_iBOT(student_ViT, teacher_ViT, loader_params, train_params, device)

Vision Transformer instantiated with 75,232 parameters.
Vision Transformer instantiated with 75,232 parameters.

Epoch 1/10
     Iter || Train Patch | Train Class | Train Total || Eval Patch | Eval Class | Eval Total ||  Samples/s
 125/1542 ||       0.856 |       1.065 |       1.920 ||      0.766 |      0.970 |      1.736 ||     11.97
 250/1542 ||       0.703 |       0.900 |       1.603 ||      0.689 |      0.733 |      1.422 ||     11.97
 375/1542 ||       0.656 |       0.710 |       1.366 ||      0.656 |      0.618 |      1.275 ||     11.98
 500/1542 ||       0.628 |       0.571 |       1.199 ||      0.624 |      0.532 |      1.156 ||     11.99
 625/1542 ||       0.560 |       0.556 |       1.116 ||      0.607 |      0.522 |      1.130 ||     11.97
 750/1542 ||       0.561 |       0.538 |       1.099 ||      0.573 |      0.520 |      1.093 ||     11.99
 875/1542 ||       0.547 |       0.508 |       1.056 ||      0.557 |      0.514 |      1.070 ||     11.98
1000/1542 ||       0.521 | 

KeyboardInterrupt: 