In [1]:
import sys

# Make the parent directory importable so the local `mae` package can be found.
# Guarded so re-running this cell does not keep appending duplicate entries.
if ".." not in sys.path:
    sys.path.append("..")

In [2]:
from tqdm import tqdm

import torch
import torch.optim as optim
import torch.nn as nn

from mae.dataset import iterator
from torch.utils.data import DataLoader

from mae.layers.mae import MAE

In [3]:
# Loader settings shared by both splits (previously duplicated literals).
BATCH_SIZE = 64 * 2
NUM_WORKERS = 10

train_iterator = iterator.MaskedImageNetIterator(is_train=True)
valid_iterator = iterator.MaskedImageNetIterator(is_train=False)

# Only the training split is shuffled; validation order stays fixed between epochs.
train_loader = DataLoader(train_iterator, batch_size=BATCH_SIZE, shuffle=True, num_workers=NUM_WORKERS)
valid_loader = DataLoader(valid_iterator, batch_size=BATCH_SIZE, shuffle=False, num_workers=NUM_WORKERS)

In [4]:
# Input geometry: 224 x 224 images with 3 channels; `patch` is presumably the
# patch side length used by MAE (TODO confirm against mae.layers.mae).
height = 224
width = 224
channel = 3
patch = 16

# Transformer hyper-parameters, passed positionally to MAE below.
d_model = 128
d_ff = d_model * 4          # feed-forward width = 4 x d_model
ffn_typ = 'glu'
act_typ = 'GELU'
n_head = 8
dropout_p = 0.1
n_enc_layer = 3
n_dec_layer = 3

# NOTE(review): output_dim is not passed to MAE in the cells shown; kept in case
# later code relies on it.
output_dim = len(train_iterator.label_dict)

# Seed torch's global RNG (all devices) so weight init and DataLoader shuffling are
# reproducible. Randomness drawn from numpy/`random` inside the dataset is not covered.
SEED = 42
torch.manual_seed(SEED)

device = 'cuda' if torch.cuda.is_available() else 'cpu'

In [5]:
# Build the masked auto-encoder from the config cell's hyper-parameters.
# Arguments are positional, so their order must match MAE's constructor.
# The model is wrapped in nn.DataParallel (replicated across visible GPUs) and
# then moved to `device`.
model = nn.DataParallel(
    MAE(
        height, width, channel, patch,
        d_model, d_ff, ffn_typ, act_typ,
        n_head, dropout_p,
        n_enc_layer, n_dec_layer,
    )
).to(device)

In [6]:
# Adam with PyTorch's default hyper-parameters (no lr given -> lr=1e-3),
# optimising every parameter of the wrapped model.
optimizer = optim.Adam(params=model.parameters())
# Element-wise mean-squared error between the reconstruction and its target.
criterion = nn.MSELoss(reduction='mean')

In [7]:
# torch.autograd.set_detect_anomaly(True) # for debugging

def train() :
    """Run one optimisation epoch over `train_loader` and return its mean loss.

    Uses the notebook-level `model`, `optimizer`, `criterion`, `train_loader`
    and `device`. Each batch loss is weighted by its batch size, so a shorter
    final batch does not skew the epoch average.

    Returns:
        float: batch-size-weighted mean training loss, or NaN if the loader
        yields no batches.
    """
    model.train()
    total_loss = 0.0
    n_samples = 0

    for dict_ in tqdm(train_loader, desc='train') :
        inputs = dict_['input'].to(device)
        reconstructed = model(inputs, dict_['unmask_bool'].to(device))

        optimizer.zero_grad()
        loss = criterion(reconstructed, dict_['label'].to(device))
        loss.backward()
        optimizer.step()

        # DataLoader's default collate stacks samples along dim 0.
        batch_size = inputs.size(0)
        total_loss += loss.item() * batch_size
        n_samples += batch_size

    # An empty loader would otherwise raise ZeroDivisionError.
    if n_samples == 0:
        return float('nan')
    return total_loss / n_samples

def evalulate() :
    """Compute the mean validation loss over `valid_loader` without updating weights.

    Uses the notebook-level `model`, `criterion`, `valid_loader` and `device`.
    The name is kept as-is because the epoch loop calls `evalulate()`.

    Returns:
        float: batch-size-weighted mean validation loss, or NaN if the loader
        yields no batches.
    """
    model.eval()
    total_loss = 0.0
    n_samples = 0

    # Evaluation never calls backward(), so skip building autograd graphs;
    # this saves memory and time on every validation batch.
    with torch.no_grad():
        for dict_ in tqdm(valid_loader, desc='valid') :
            inputs = dict_['input'].to(device)
            reconstructed = model(inputs, dict_['unmask_bool'].to(device))

            loss = criterion(reconstructed, dict_['label'].to(device))

            # DataLoader's default collate stacks samples along dim 0.
            batch_size = inputs.size(0)
            total_loss += loss.item() * batch_size
            n_samples += batch_size

    # An empty loader would otherwise raise ZeroDivisionError.
    if n_samples == 0:
        return float('nan')
    return total_loss / n_samples

In [None]:
epoches = 20

def _ordinal(n):
    """Return `n` with its English ordinal suffix, e.g. 1 -> '1st', 12 -> '12th', 23 -> '23rd'."""
    # 11th-13th (and 111th etc.) are irregular: always 'th'.
    if 10 <= n % 100 <= 20:
        suffix = 'th'
    else:
        suffix = {1: 'st', 2: 'nd', 3: 'rd'}.get(n % 10, 'th')
    return f"{n}{suffix}"

# Train for `epoches` epochs, reporting train/validation loss after each one.
for proc in range(epoches) :
    t_loss = train()
    v_loss = evalulate()
    print(f"""
                === {_ordinal(proc+1)} Epoch ===
    
        Train Loss : {round(t_loss, 3)} | Valid Loss : {round(v_loss, 3)}
        
                    ============================================
                    ============================================
    """)

train: 100%|██████████| 600/600 [01:44<00:00,  5.76it/s]
valid: 100%|██████████| 258/258 [00:40<00:00,  6.43it/s]



                === 1th Epoch ===
    
        Train Loss : 0.065 | Valid Loss : 0.056
        
    


train: 100%|██████████| 600/600 [01:39<00:00,  6.03it/s]
valid: 100%|██████████| 258/258 [00:40<00:00,  6.43it/s]



                === 2th Epoch ===
    
        Train Loss : 0.047 | Valid Loss : 0.044
        
    


train: 100%|██████████| 600/600 [01:38<00:00,  6.08it/s]
valid: 100%|██████████| 258/258 [00:40<00:00,  6.39it/s]



                === 3th Epoch ===
    
        Train Loss : 0.042 | Valid Loss : 0.039
        
    


train: 100%|██████████| 600/600 [01:39<00:00,  6.03it/s]
valid: 100%|██████████| 258/258 [00:39<00:00,  6.45it/s]



                === 4th Epoch ===
    
        Train Loss : 0.038 | Valid Loss : 0.035
        
    


train: 100%|██████████| 600/600 [01:39<00:00,  6.03it/s]
valid: 100%|██████████| 258/258 [00:40<00:00,  6.39it/s]



                === 5th Epoch ===
    
        Train Loss : 0.034 | Valid Loss : 0.031
        
    


train: 100%|██████████| 600/600 [01:39<00:00,  6.00it/s]
valid: 100%|██████████| 258/258 [00:40<00:00,  6.43it/s]



                === 6th Epoch ===
    
        Train Loss : 0.03 | Valid Loss : 0.027
        
    


train: 100%|██████████| 600/600 [01:38<00:00,  6.12it/s]
valid: 100%|██████████| 258/258 [00:40<00:00,  6.37it/s]



                === 7th Epoch ===
    
        Train Loss : 0.027 | Valid Loss : 0.024
        
    


train: 100%|██████████| 600/600 [01:39<00:00,  6.02it/s]
valid: 100%|██████████| 258/258 [00:39<00:00,  6.47it/s]



                === 8th Epoch ===
    
        Train Loss : 0.024 | Valid Loss : 0.022
        
    


train:  94%|█████████▎| 562/600 [01:32<00:06,  6.17it/s]