In [6]:
from tqdm.notebook import tqdm

from mlp_mixer.dataset.iterator import ImageNetIterator
from mlp_mixer.layer.mixer import MM
from torch.utils.data import DataLoader

import torch
import torch.nn as nn
import torch.optim as optim

In [2]:
train_iter = ImageNetIterator(root='../imagenet/',
                 height=256,
                 width=256,
                 is_train=True,
                 in_memory=False,
                 verbose=False)

test_iter = ImageNetIterator(root='../imagenet/',
                 height=256,
                 width=256,
                 is_train=True,
                 in_memory=False,
                 verbose=False)

train_loader = DataLoader(train_iter, batch_size=64*2, shuffle=True, num_workers=10)
valid_loader = DataLoader(test_iter, batch_size=64*2, shuffle=False, num_workers=10)

In [3]:
h, w, c = 256, 256, 3
p = 32
C=  256
S = (h*w)//(p**2)

In [5]:
device = torch.device("cuda" if torch.cuda.is_available() else 'cpu')
model = nn.DataParallel(MM(h, w, c, 6, p, C, S, 2048, 256, len(train_iter.label_dict)).to(device))

In [7]:
optimizer = optim.Adam(model.parameters())
criterion = nn.CrossEntropyLoss()

In [8]:
def train() : 
    model.train()
    losses = []
    accuracies = []

    for x,y in tqdm(train_loader, desc='train') : 
        pred = model(x.to(device))

        optimizer.zero_grad()
        loss = criterion(pred, y.to(device))
        loss.backward()
        optimizer.step()

        correct = (torch.argmax(pred, dim=1) == y.to(device)).sum()
        acc = correct.item() / y.shape[0]

        accuracies.append(acc)
        losses.append(loss.item())

    agg_acc = sum(accuracies) / len(accuracies)
    agg_loss = sum(losses) / len(losses)
    return agg_acc, agg_loss

def evalulate() : 
    model.eval()
    losses = []
    accuracies = []

    for x,y in tqdm(valid_loader, desc='valid') : 
        pred = model(x.to(device))

        loss = criterion(pred, y.to(device))
        correct = (torch.argmax(pred, dim=1) == y.to(device)).sum()
        acc = correct.item() / y.shape[0]

        accuracies.append(acc)
        losses.append(loss.item())

    agg_acc = sum(accuracies) / len(accuracies)
    agg_loss = sum(losses) / len(losses)
    return agg_acc, agg_loss

In [59]:
epoches = 20

for proc in range(epoches) : 
    t_acc, t_loss = train()
    v_acc, v_loss = evalulate()
    print(f"""
                === {proc+1}th Epoch ===
    
        Train Loss : {round(t_loss, 3)} | Train Acc : {round(t_acc, 3)}
        Valid Loss : {round(v_loss, 3)} | Valid Acc : {round(v_acc, 3)}
        
        ============================================
        ============================================
    """)

train:   0%|          | 0/600 [00:00<?, ?it/s]

valid:   0%|          | 0/600 [00:00<?, ?it/s]


                === 1th Epoch ===
    
        Train Loss : 3.89 | Train Acc : 0.143
        Valid Loss : 3.412 | Valid Acc : 0.211
        
    


train:   0%|          | 0/600 [00:00<?, ?it/s]

valid:   0%|          | 0/600 [00:00<?, ?it/s]


                === 2th Epoch ===
    
        Train Loss : 3.244 | Train Acc : 0.239
        Valid Loss : 2.925 | Valid Acc : 0.292
        
    


train:   0%|          | 0/600 [00:00<?, ?it/s]

valid:   0%|          | 0/600 [00:00<?, ?it/s]


                === 3th Epoch ===
    
        Train Loss : 2.966 | Train Acc : 0.289
        Valid Loss : 2.652 | Valid Acc : 0.346
        
    


train:   0%|          | 0/600 [00:00<?, ?it/s]

valid:   0%|          | 0/600 [00:00<?, ?it/s]


                === 4th Epoch ===
    
        Train Loss : 2.725 | Train Acc : 0.333
        Valid Loss : 2.35 | Valid Acc : 0.405
        
    


train:   0%|          | 0/600 [00:00<?, ?it/s]

valid:   0%|          | 0/600 [00:00<?, ?it/s]


                === 5th Epoch ===
    
        Train Loss : 2.492 | Train Acc : 0.379
        Valid Loss : 2.074 | Valid Acc : 0.461
        
    


train:   0%|          | 0/600 [00:00<?, ?it/s]

valid:   0%|          | 0/600 [00:00<?, ?it/s]


                === 6th Epoch ===
    
        Train Loss : 2.187 | Train Acc : 0.44
        Valid Loss : 1.639 | Valid Acc : 0.561
        
    


train:   0%|          | 0/600 [00:00<?, ?it/s]

valid:   0%|          | 0/600 [00:00<?, ?it/s]


                === 7th Epoch ===
    
        Train Loss : 1.818 | Train Acc : 0.519
        Valid Loss : 1.154 | Valid Acc : 0.683
        
    


train:   0%|          | 0/600 [00:00<?, ?it/s]

valid:   0%|          | 0/600 [00:00<?, ?it/s]


                === 8th Epoch ===
    
        Train Loss : 1.334 | Train Acc : 0.631
        Valid Loss : 0.719 | Valid Acc : 0.792
        
    


train:   0%|          | 0/600 [00:00<?, ?it/s]

valid:   0%|          | 0/600 [00:00<?, ?it/s]


                === 9th Epoch ===
    
        Train Loss : 0.857 | Train Acc : 0.751
        Valid Loss : 0.434 | Valid Acc : 0.869
        
    


train:   0%|          | 0/600 [00:00<?, ?it/s]

valid:   0%|          | 0/600 [00:00<?, ?it/s]


                === 10th Epoch ===
    
        Train Loss : 0.557 | Train Acc : 0.832
        Valid Loss : 0.342 | Valid Acc : 0.892
        
    


train:   0%|          | 0/600 [00:00<?, ?it/s]

valid:   0%|          | 0/600 [00:00<?, ?it/s]


                === 11th Epoch ===
    
        Train Loss : 0.464 | Train Acc : 0.859
        Valid Loss : 0.356 | Valid Acc : 0.89
        
    


train:   0%|          | 0/600 [00:00<?, ?it/s]

valid:   0%|          | 0/600 [00:00<?, ?it/s]


                === 12th Epoch ===
    
        Train Loss : 0.414 | Train Acc : 0.876
        Valid Loss : 0.307 | Valid Acc : 0.906
        
    


train:   0%|          | 0/600 [00:00<?, ?it/s]

valid:   0%|          | 0/600 [00:00<?, ?it/s]


                === 13th Epoch ===
    
        Train Loss : 0.398 | Train Acc : 0.885
        Valid Loss : 0.279 | Valid Acc : 0.916
        
    


train:   0%|          | 0/600 [00:00<?, ?it/s]

valid:   0%|          | 0/600 [00:00<?, ?it/s]


                === 14th Epoch ===
    
        Train Loss : 0.37 | Train Acc : 0.894
        Valid Loss : 0.318 | Valid Acc : 0.908
        
    


train:   0%|          | 0/600 [00:00<?, ?it/s]

valid:   0%|          | 0/600 [00:00<?, ?it/s]


                === 15th Epoch ===
    
        Train Loss : 0.346 | Train Acc : 0.901
        Valid Loss : 0.261 | Valid Acc : 0.924
        
    


train:   0%|          | 0/600 [00:00<?, ?it/s]

valid:   0%|          | 0/600 [00:00<?, ?it/s]


                === 16th Epoch ===
    
        Train Loss : 0.339 | Train Acc : 0.908
        Valid Loss : 0.249 | Valid Acc : 0.927
        
    


train:   0%|          | 0/600 [00:00<?, ?it/s]

valid:   0%|          | 0/600 [00:00<?, ?it/s]


                === 17th Epoch ===
    
        Train Loss : 0.351 | Train Acc : 0.907
        Valid Loss : 0.273 | Valid Acc : 0.926
        
    


train:   0%|          | 0/600 [00:00<?, ?it/s]

valid:   0%|          | 0/600 [00:00<?, ?it/s]


                === 18th Epoch ===
    
        Train Loss : 0.336 | Train Acc : 0.912
        Valid Loss : 0.236 | Valid Acc : 0.936
        
    


train:   0%|          | 0/600 [00:00<?, ?it/s]

valid:   0%|          | 0/600 [00:00<?, ?it/s]


                === 19th Epoch ===
    
        Train Loss : 0.3 | Train Acc : 0.922
        Valid Loss : 0.226 | Valid Acc : 0.94
        
    


train:   0%|          | 0/600 [00:00<?, ?it/s]

valid:   0%|          | 0/600 [00:00<?, ?it/s]


                === 20th Epoch ===
    
        Train Loss : 0.322 | Train Acc : 0.92
        Valid Loss : 0.26 | Valid Acc : 0.933
        
    
