In [1]:
import sys
sys.path.append("..")

In [2]:
from tqdm.notebook import tqdm

from mlp_mixer.dataset.iterator import ImageNetIterator
from gmlp.layer.model import Gmlp
from torch.utils.data import DataLoader

import torch
import torch.nn as nn
import torch.optim as optim

In [3]:
train_iter = ImageNetIterator(root='../../imagenet/',
                 height=256,
                 width=256,
                 is_train=True,
                 in_memory=False,
                 verbose=False)

test_iter = ImageNetIterator(root='../../imagenet/',
                 height=256,
                 width=256,
                 is_train=True,
                 in_memory=False,
                 verbose=False)

train_loader = DataLoader(train_iter, batch_size=64*2, shuffle=True, num_workers=10)
valid_loader = DataLoader(test_iter, batch_size=64*2, shuffle=False, num_workers=10)

In [4]:
h, w, c = 256, 256, 3
p = 32
s = (h*w)//(p**2)

In [5]:
device = torch.device("cuda" if torch.cuda.is_available() else 'cpu')
model = nn.DataParallel(Gmlp(h, w, c, p, s, 128, 128*4, 6, len(train_iter.label_dict)).to(device))

In [6]:
optimizer = optim.Adam(model.parameters())
criterion = nn.CrossEntropyLoss()

In [7]:
# torch.autograd.set_detect_anomaly(True)

def train() : 
    model.train()
    losses = []
    accuracies = []

    for x,y in tqdm(train_loader, desc='train') : 
        pred = model(x.to(device))

        optimizer.zero_grad()
        loss = criterion(pred, y.to(device))
        loss.backward()
        optimizer.step()

        correct = (torch.argmax(pred, dim=1) == y.to(device)).sum()
        acc = correct.item() / y.shape[0]

        accuracies.append(acc)
        losses.append(loss.item())

    agg_acc = sum(accuracies) / len(accuracies)
    agg_loss = sum(losses) / len(losses)
    return agg_acc, agg_loss

def evalulate() : 
    model.eval()
    losses = []
    accuracies = []

    for x,y in tqdm(valid_loader, desc='valid') : 
        pred = model(x.to(device))

        loss = criterion(pred, y.to(device))
        correct = (torch.argmax(pred, dim=1) == y.to(device)).sum()
        acc = correct.item() / y.shape[0]

        accuracies.append(acc)
        losses.append(loss.item())

    agg_acc = sum(accuracies) / len(accuracies)
    agg_loss = sum(losses) / len(losses)
    return agg_acc, agg_loss

In [8]:
epoches = 50

for proc in range(epoches) : 
    t_acc, t_loss = train()
    v_acc, v_loss = evalulate()
    print(f"""
                            === {proc+1}th Epoch ===
    
        Train Loss : {round(t_loss, 3)} | Train Acc : {round(t_acc, 3)}
        Valid Loss : {round(v_loss, 3)} | Valid Acc : {round(v_acc, 3)}
        
        ============================================
        ============================================
    """)

train:   0%|          | 0/600 [00:00<?, ?it/s]

valid:   0%|          | 0/600 [00:00<?, ?it/s]


                            === 1th Epoch ===
    
        Train Loss : 50.954 | Train Acc : 0.009
        Valid Loss : 5.238 | Valid Acc : 0.01
        
    


train:   0%|          | 0/600 [00:00<?, ?it/s]

valid:   0%|          | 0/600 [00:00<?, ?it/s]


                            === 2th Epoch ===
    
        Train Loss : 5.223 | Train Acc : 0.009
        Valid Loss : 5.289 | Valid Acc : 0.009
        
    


train:   0%|          | 0/600 [00:00<?, ?it/s]

valid:   0%|          | 0/600 [00:00<?, ?it/s]


                            === 3th Epoch ===
    
        Train Loss : 4.931 | Train Acc : 0.009
        Valid Loss : 4.837 | Valid Acc : 0.012
        
    


train:   0%|          | 0/600 [00:00<?, ?it/s]

valid:   0%|          | 0/600 [00:00<?, ?it/s]


                            === 4th Epoch ===
    
        Train Loss : 4.824 | Train Acc : 0.01
        Valid Loss : 4.787 | Valid Acc : 0.011
        
    


train:   0%|          | 0/600 [00:00<?, ?it/s]

valid:   0%|          | 0/600 [00:00<?, ?it/s]


                            === 5th Epoch ===
    
        Train Loss : 4.764 | Train Acc : 0.013
        Valid Loss : 4.751 | Valid Acc : 0.014
        
    


train:   0%|          | 0/600 [00:00<?, ?it/s]

valid:   0%|          | 0/600 [00:00<?, ?it/s]


                            === 6th Epoch ===
    
        Train Loss : 4.752 | Train Acc : 0.014
        Valid Loss : 4.751 | Valid Acc : 0.013
        
    


train:   0%|          | 0/600 [00:00<?, ?it/s]

valid:   0%|          | 0/600 [00:00<?, ?it/s]


                            === 7th Epoch ===
    
        Train Loss : 4.74 | Train Acc : 0.014
        Valid Loss : 4.729 | Valid Acc : 0.015
        
    


train:   0%|          | 0/600 [00:00<?, ?it/s]

valid:   0%|          | 0/600 [00:00<?, ?it/s]


                            === 8th Epoch ===
    
        Train Loss : 4.728 | Train Acc : 0.017
        Valid Loss : 4.736 | Valid Acc : 0.017
        
    


train:   0%|          | 0/600 [00:00<?, ?it/s]

valid:   0%|          | 0/600 [00:00<?, ?it/s]


                            === 9th Epoch ===
    
        Train Loss : 4.747 | Train Acc : 0.015
        Valid Loss : 4.733 | Valid Acc : 0.018
        
    


train:   0%|          | 0/600 [00:00<?, ?it/s]

valid:   0%|          | 0/600 [00:00<?, ?it/s]


                            === 10th Epoch ===
    
        Train Loss : 4.793 | Train Acc : 0.012
        Valid Loss : 4.734 | Valid Acc : 0.014
        
    


train:   0%|          | 0/600 [00:00<?, ?it/s]

valid:   0%|          | 0/600 [00:00<?, ?it/s]


                            === 11th Epoch ===
    
        Train Loss : 4.718 | Train Acc : 0.018
        Valid Loss : 4.706 | Valid Acc : 0.018
        
    


train:   0%|          | 0/600 [00:00<?, ?it/s]

valid:   0%|          | 0/600 [00:00<?, ?it/s]


                            === 12th Epoch ===
    
        Train Loss : 4.725 | Train Acc : 0.018
        Valid Loss : 4.706 | Valid Acc : 0.02
        
    


train:   0%|          | 0/600 [00:00<?, ?it/s]

valid:   0%|          | 0/600 [00:00<?, ?it/s]


                            === 13th Epoch ===
    
        Train Loss : 4.7 | Train Acc : 0.02
        Valid Loss : 4.674 | Valid Acc : 0.023
        
    


train:   0%|          | 0/600 [00:00<?, ?it/s]

valid:   0%|          | 0/600 [00:00<?, ?it/s]


                            === 14th Epoch ===
    
        Train Loss : 4.667 | Train Acc : 0.025
        Valid Loss : 4.646 | Valid Acc : 0.027
        
    


train:   0%|          | 0/600 [00:00<?, ?it/s]

valid:   0%|          | 0/600 [00:00<?, ?it/s]


                            === 15th Epoch ===
    
        Train Loss : 4.647 | Train Acc : 0.026
        Valid Loss : 4.645 | Valid Acc : 0.026
        
    


train:   0%|          | 0/600 [00:00<?, ?it/s]

valid:   0%|          | 0/600 [00:00<?, ?it/s]


                            === 16th Epoch ===
    
        Train Loss : 4.628 | Train Acc : 0.027
        Valid Loss : 4.615 | Valid Acc : 0.029
        
    


train:   0%|          | 0/600 [00:00<?, ?it/s]

valid:   0%|          | 0/600 [00:00<?, ?it/s]


                            === 17th Epoch ===
    
        Train Loss : 4.608 | Train Acc : 0.031
        Valid Loss : 4.596 | Valid Acc : 0.031
        
    


train:   0%|          | 0/600 [00:00<?, ?it/s]

valid:   0%|          | 0/600 [00:00<?, ?it/s]


                            === 18th Epoch ===
    
        Train Loss : 4.502 | Train Acc : 0.041
        Valid Loss : 4.342 | Valid Acc : 0.061
        
    


train:   0%|          | 0/600 [00:00<?, ?it/s]

valid:   0%|          | 0/600 [00:00<?, ?it/s]


                            === 19th Epoch ===
    
        Train Loss : 4.223 | Train Acc : 0.075
        Valid Loss : 4.08 | Valid Acc : 0.098
        
    


train:   0%|          | 0/600 [00:00<?, ?it/s]

valid:   0%|          | 0/600 [00:00<?, ?it/s]


                            === 20th Epoch ===
    
        Train Loss : 4.045 | Train Acc : 0.104
        Valid Loss : 3.968 | Valid Acc : 0.114
        
    


train:   0%|          | 0/600 [00:00<?, ?it/s]

valid:   0%|          | 0/600 [00:00<?, ?it/s]


                            === 21th Epoch ===
    
        Train Loss : 3.915 | Train Acc : 0.123
        Valid Loss : 3.834 | Valid Acc : 0.137
        
    


train:   0%|          | 0/600 [00:00<?, ?it/s]

valid:   0%|          | 0/600 [00:00<?, ?it/s]


                            === 22th Epoch ===
    
        Train Loss : 3.812 | Train Acc : 0.137
        Valid Loss : 3.699 | Valid Acc : 0.155
        
    


train:   0%|          | 0/600 [00:00<?, ?it/s]

valid:   0%|          | 0/600 [00:00<?, ?it/s]


                            === 23th Epoch ===
    
        Train Loss : 3.726 | Train Acc : 0.154
        Valid Loss : 3.682 | Valid Acc : 0.158
        
    


train:   0%|          | 0/600 [00:00<?, ?it/s]

valid:   0%|          | 0/600 [00:00<?, ?it/s]


                            === 24th Epoch ===
    
        Train Loss : 3.647 | Train Acc : 0.165
        Valid Loss : 3.56 | Valid Acc : 0.18
        
    


train:   0%|          | 0/600 [00:00<?, ?it/s]

valid:   0%|          | 0/600 [00:00<?, ?it/s]


                            === 25th Epoch ===
    
        Train Loss : 3.568 | Train Acc : 0.178
        Valid Loss : 3.489 | Valid Acc : 0.191
        
    


train:   0%|          | 0/600 [00:00<?, ?it/s]

valid:   0%|          | 0/600 [00:00<?, ?it/s]


                            === 26th Epoch ===
    
        Train Loss : 3.499 | Train Acc : 0.19
        Valid Loss : 3.415 | Valid Acc : 0.205
        
    


train:   0%|          | 0/600 [00:00<?, ?it/s]

valid:   0%|          | 0/600 [00:00<?, ?it/s]


                            === 27th Epoch ===
    
        Train Loss : 3.427 | Train Acc : 0.202
        Valid Loss : 3.362 | Valid Acc : 0.215
        
    


train:   0%|          | 0/600 [00:00<?, ?it/s]

valid:   0%|          | 0/600 [00:00<?, ?it/s]


                            === 28th Epoch ===
    
        Train Loss : 3.365 | Train Acc : 0.212
        Valid Loss : 3.289 | Valid Acc : 0.227
        
    


train:   0%|          | 0/600 [00:00<?, ?it/s]

valid:   0%|          | 0/600 [00:00<?, ?it/s]


                            === 29th Epoch ===
    
        Train Loss : 3.309 | Train Acc : 0.223
        Valid Loss : 3.24 | Valid Acc : 0.234
        
    


train:   0%|          | 0/600 [00:00<?, ?it/s]

valid:   0%|          | 0/600 [00:00<?, ?it/s]


                            === 30th Epoch ===
    
        Train Loss : 3.269 | Train Acc : 0.232
        Valid Loss : 3.185 | Valid Acc : 0.245
        
    


train:   0%|          | 0/600 [00:00<?, ?it/s]

valid:   0%|          | 0/600 [00:00<?, ?it/s]


                            === 31th Epoch ===
    
        Train Loss : 3.206 | Train Acc : 0.24
        Valid Loss : 3.15 | Valid Acc : 0.25
        
    


train:   0%|          | 0/600 [00:00<?, ?it/s]

valid:   0%|          | 0/600 [00:00<?, ?it/s]


                            === 32th Epoch ===
    
        Train Loss : 3.158 | Train Acc : 0.248
        Valid Loss : 3.107 | Valid Acc : 0.258
        
    


train:   0%|          | 0/600 [00:00<?, ?it/s]

valid:   0%|          | 0/600 [00:00<?, ?it/s]


                            === 33th Epoch ===
    
        Train Loss : 3.111 | Train Acc : 0.258
        Valid Loss : 3.023 | Valid Acc : 0.275
        
    


train:   0%|          | 0/600 [00:00<?, ?it/s]

valid:   0%|          | 0/600 [00:00<?, ?it/s]


                            === 34th Epoch ===
    
        Train Loss : 3.064 | Train Acc : 0.268
        Valid Loss : 2.974 | Valid Acc : 0.283
        
    


train:   0%|          | 0/600 [00:00<?, ?it/s]

valid:   0%|          | 0/600 [00:00<?, ?it/s]


                            === 35th Epoch ===
    
        Train Loss : 3.026 | Train Acc : 0.274
        Valid Loss : 2.95 | Valid Acc : 0.287
        
    


train:   0%|          | 0/600 [00:00<?, ?it/s]

valid:   0%|          | 0/600 [00:00<?, ?it/s]


                            === 36th Epoch ===
    
        Train Loss : 3.022 | Train Acc : 0.275
        Valid Loss : 3.737 | Valid Acc : 0.148
        
    


train:   0%|          | 0/600 [00:00<?, ?it/s]

valid:   0%|          | 0/600 [00:00<?, ?it/s]


                            === 37th Epoch ===
    
        Train Loss : 3.0 | Train Acc : 0.279
        Valid Loss : 2.878 | Valid Acc : 0.299
        
    


train:   0%|          | 0/600 [00:00<?, ?it/s]

valid:   0%|          | 0/600 [00:00<?, ?it/s]


                            === 38th Epoch ===
    
        Train Loss : 2.907 | Train Acc : 0.296
        Valid Loss : 2.862 | Valid Acc : 0.306
        
    


train:   0%|          | 0/600 [00:00<?, ?it/s]

valid:   0%|          | 0/600 [00:00<?, ?it/s]


                            === 39th Epoch ===
    
        Train Loss : 2.87 | Train Acc : 0.303
        Valid Loss : 2.811 | Valid Acc : 0.315
        
    


train:   0%|          | 0/600 [00:00<?, ?it/s]

valid:   0%|          | 0/600 [00:00<?, ?it/s]


                            === 40th Epoch ===
    
        Train Loss : 2.847 | Train Acc : 0.307
        Valid Loss : 2.783 | Valid Acc : 0.319
        
    


train:   0%|          | 0/600 [00:00<?, ?it/s]

valid:   0%|          | 0/600 [00:00<?, ?it/s]


                            === 41th Epoch ===
    
        Train Loss : 2.801 | Train Acc : 0.316
        Valid Loss : 2.719 | Valid Acc : 0.331
        
    


train:   0%|          | 0/600 [00:00<?, ?it/s]

valid:   0%|          | 0/600 [00:00<?, ?it/s]


                            === 42th Epoch ===
    
        Train Loss : 2.783 | Train Acc : 0.32
        Valid Loss : 2.716 | Valid Acc : 0.334
        
    


train:   0%|          | 0/600 [00:00<?, ?it/s]

valid:   0%|          | 0/600 [00:00<?, ?it/s]


                            === 43th Epoch ===
    
        Train Loss : 2.779 | Train Acc : 0.322
        Valid Loss : 4.024 | Valid Acc : 0.113
        
    


train:   0%|          | 0/600 [00:00<?, ?it/s]

valid:   0%|          | 0/600 [00:00<?, ?it/s]


                            === 44th Epoch ===
    
        Train Loss : 2.882 | Train Acc : 0.302
        Valid Loss : 2.659 | Valid Acc : 0.344
        
    


train:   0%|          | 0/600 [00:00<?, ?it/s]

valid:   0%|          | 0/600 [00:00<?, ?it/s]


                            === 45th Epoch ===
    
        Train Loss : 2.695 | Train Acc : 0.337
        Valid Loss : 2.624 | Valid Acc : 0.351
        
    


train:   0%|          | 0/600 [00:00<?, ?it/s]

valid:   0%|          | 0/600 [00:00<?, ?it/s]


                            === 46th Epoch ===
    
        Train Loss : 2.664 | Train Acc : 0.341
        Valid Loss : 2.592 | Valid Acc : 0.358
        
    


train:   0%|          | 0/600 [00:00<?, ?it/s]

valid:   0%|          | 0/600 [00:00<?, ?it/s]


                            === 47th Epoch ===
    
        Train Loss : 2.641 | Train Acc : 0.348
        Valid Loss : 2.566 | Valid Acc : 0.363
        
    


train:   0%|          | 0/600 [00:00<?, ?it/s]

valid:   0%|          | 0/600 [00:00<?, ?it/s]


                            === 48th Epoch ===
    
        Train Loss : 2.619 | Train Acc : 0.353
        Valid Loss : 2.531 | Valid Acc : 0.369
        
    


train:   0%|          | 0/600 [00:00<?, ?it/s]

valid:   0%|          | 0/600 [00:00<?, ?it/s]


                            === 49th Epoch ===
    
        Train Loss : 2.589 | Train Acc : 0.356
        Valid Loss : 2.498 | Valid Acc : 0.376
        
    


train:   0%|          | 0/600 [00:00<?, ?it/s]

valid:   0%|          | 0/600 [00:00<?, ?it/s]


                            === 50th Epoch ===
    
        Train Loss : 2.57 | Train Acc : 0.361
        Valid Loss : 2.499 | Valid Acc : 0.375
        
    
