In [1]:
import sys
sys.path.append("..")

In [2]:
from tqdm import tqdm

import torch
import torch.optim as optim
import torch.nn as nn

from vision_transformer.dataset import iterator
from torch.utils.data import DataLoader
from vision_transformer.layers import classifier

In [3]:
train_iterator = iterator.ImageNetIterator(is_train=True)
valid_iterator = iterator.ImageNetIterator(is_train=False)

train_loader = DataLoader(train_iterator, batch_size=64*2, shuffle=True, num_workers=10)
valid_loader = DataLoader(valid_iterator, batch_size=64*2, shuffle=False, num_workers=10)

In [4]:
height = 256
width = 256
channel = 3
patch = 16
d_model = 256
d_ff = d_model * 4
n_head = 8
dropout_p = 0.1
n_enc_layer = 3
output_dim = len(train_iterator.label_dict)
device = 'cuda' if torch.cuda.is_available() else 'cpu'

In [5]:
model = nn.DataParallel(classifier.ViT(
    height,
    width,
    channel,
    patch,
    d_model,
    d_ff,
    n_head,
    dropout_p,
    n_enc_layer,
    output_dim,)).to(device)

In [6]:
optimizer = optim.Adam(model.parameters())
criterion = nn.CrossEntropyLoss()

In [7]:
def train() : 
    model.train()
    losses = []
    accuracies = []

    for x,y in tqdm(train_loader, desc='train') : 
        pred = model(x.to(device))

        optimizer.zero_grad()
        loss = criterion(pred, y.to(device))
        loss.backward()
        optimizer.step()

        correct = (torch.argmax(pred, dim=1) == y.to(device)).sum()
        acc = correct.item() / y.shape[0]

        accuracies.append(acc)
        losses.append(loss.item())

    agg_acc = sum(accuracies) / len(accuracies)
    agg_loss = sum(losses) / len(losses)
    return agg_acc, agg_loss

def evalulate() : 
    model.eval()
    losses = []
    accuracies = []

    for x,y in tqdm(valid_loader, desc='valid') : 
        pred = model(x.to(device))

        loss = criterion(pred, y.to(device))
        correct = (torch.argmax(pred, dim=1) == y.to(device)).sum()
        acc = correct.item() / y.shape[0]

        accuracies.append(acc)
        losses.append(loss.item())

    agg_acc = sum(accuracies) / len(accuracies)
    agg_loss = sum(losses) / len(losses)
    return agg_acc, agg_loss

In [None]:
epoches = 20

for proc in range(epoches) : 
    t_acc, t_loss = train()
    v_acc, v_loss = evalulate()
    print(f"""
                === {proc+1}th Epoch ===
    
        Train Loss : {round(t_loss, 3)} | Train Acc : {round(t_acc, 3)}
        Valid Loss : {round(v_loss, 3)} | Valid Acc : {round(v_acc, 3)}
        
        ============================================
        ============================================
    """)

train: 100%|██████████| 600/600 [01:58<00:00,  5.07it/s]
valid: 100%|██████████| 258/258 [00:30<00:00,  8.44it/s]



                === 1th Epoch ===
    
        Train Loss : 3.984 | Train Acc : 0.114
        Valid Loss : 3.615 | Valid Acc : 0.167
        
    


train: 100%|██████████| 600/600 [01:57<00:00,  5.12it/s]
valid: 100%|██████████| 258/258 [00:30<00:00,  8.42it/s]



                === 2th Epoch ===
    
        Train Loss : 3.443 | Train Acc : 0.194
        Valid Loss : 3.36 | Valid Acc : 0.212
        
    


train: 100%|██████████| 600/600 [01:56<00:00,  5.14it/s]
valid: 100%|██████████| 258/258 [00:30<00:00,  8.42it/s]



                === 3th Epoch ===
    
        Train Loss : 3.269 | Train Acc : 0.224
        Valid Loss : 3.256 | Valid Acc : 0.23
        
    


train: 100%|██████████| 600/600 [01:57<00:00,  5.12it/s]
valid: 100%|██████████| 258/258 [00:30<00:00,  8.49it/s]



                === 4th Epoch ===
    
        Train Loss : 3.151 | Train Acc : 0.246
        Valid Loss : 3.214 | Valid Acc : 0.24
        
    


train: 100%|██████████| 600/600 [01:57<00:00,  5.11it/s]
valid: 100%|██████████| 258/258 [00:30<00:00,  8.48it/s]



                === 5th Epoch ===
    
        Train Loss : 3.056 | Train Acc : 0.263
        Valid Loss : 3.147 | Valid Acc : 0.257
        
    


train: 100%|██████████| 600/600 [01:57<00:00,  5.11it/s]
valid: 100%|██████████| 258/258 [00:30<00:00,  8.44it/s]



                === 6th Epoch ===
    
        Train Loss : 2.97 | Train Acc : 0.279
        Valid Loss : 3.103 | Valid Acc : 0.269
        
    


train: 100%|██████████| 600/600 [01:57<00:00,  5.12it/s]
valid: 100%|██████████| 258/258 [00:30<00:00,  8.48it/s]



                === 7th Epoch ===
    
        Train Loss : 2.886 | Train Acc : 0.294
        Valid Loss : 3.12 | Valid Acc : 0.261
        
    


train: 100%|██████████| 600/600 [01:56<00:00,  5.14it/s]
valid: 100%|██████████| 258/258 [00:30<00:00,  8.43it/s]



                === 8th Epoch ===
    
        Train Loss : 2.822 | Train Acc : 0.305
        Valid Loss : 3.111 | Valid Acc : 0.271
        
    


train: 100%|██████████| 600/600 [01:57<00:00,  5.12it/s]
valid: 100%|██████████| 258/258 [00:30<00:00,  8.42it/s]



                === 9th Epoch ===
    
        Train Loss : 2.742 | Train Acc : 0.319
        Valid Loss : 3.11 | Valid Acc : 0.276
        
    


train: 100%|██████████| 600/600 [01:56<00:00,  5.13it/s]
valid: 100%|██████████| 258/258 [00:30<00:00,  8.48it/s]



                === 10th Epoch ===
    
        Train Loss : 2.67 | Train Acc : 0.331
        Valid Loss : 3.105 | Valid Acc : 0.278
        
    


train: 100%|██████████| 600/600 [01:56<00:00,  5.13it/s]
valid: 100%|██████████| 258/258 [00:30<00:00,  8.52it/s]



                === 11th Epoch ===
    
        Train Loss : 2.598 | Train Acc : 0.349
        Valid Loss : 3.103 | Valid Acc : 0.277
        
    


train: 100%|██████████| 600/600 [01:57<00:00,  5.13it/s]
valid: 100%|██████████| 258/258 [00:30<00:00,  8.51it/s]



                === 12th Epoch ===
    
        Train Loss : 2.526 | Train Acc : 0.36
        Valid Loss : 3.181 | Valid Acc : 0.277
        
    


train: 100%|██████████| 600/600 [01:56<00:00,  5.14it/s]
valid: 100%|██████████| 258/258 [00:30<00:00,  8.49it/s]



                === 13th Epoch ===
    
        Train Loss : 2.445 | Train Acc : 0.377
        Valid Loss : 3.188 | Valid Acc : 0.278
        
    


train: 100%|██████████| 600/600 [01:57<00:00,  5.11it/s]
valid: 100%|██████████| 258/258 [00:29<00:00,  8.62it/s]



                === 14th Epoch ===
    
        Train Loss : 2.368 | Train Acc : 0.39
        Valid Loss : 3.205 | Valid Acc : 0.276
        
    


train: 100%|██████████| 600/600 [01:56<00:00,  5.15it/s]
valid:  16%|█▌        | 40/258 [00:05<00:23,  9.08it/s]