In [1]:
import sys
sys.path.append("..")

In [2]:
from tqdm import tqdm

import torch
import torch.optim as optim
import torch.nn as nn

from vision_transformer.dataset import iterator
from torch.utils.data import DataLoader
from vision_transformer.layers import classifier

In [3]:
train_iterator = iterator.ImageNetIterator(is_train=True)
valid_iterator = iterator.ImageNetIterator(is_train=False)

train_loader = DataLoader(train_iterator, batch_size=32*2, shuffle=True, num_workers=10)
valid_loader = DataLoader(valid_iterator, batch_size=32*2, shuffle=False, num_workers=10)

___

In [4]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [5]:
import torchvision.models as models
model = models.resnet18(True)
model.fc = nn.Linear(512, len(train_iterator.label_dict)) # add new lasy fc layer
model = nn.DataParallel(model).to(device) # paralleize

In [6]:
optimizer = optim.Adam(model.parameters())
criterion = nn.CrossEntropyLoss()

In [7]:
def train() : 
    model.train()
    losses = []
    accuracies = []

    for x,y in tqdm(train_loader, desc='train') : 
        pred = model(x.to(device))

        optimizer.zero_grad()
        loss = criterion(pred, y.to(device))
        loss.backward()
        optimizer.step()

        correct = (torch.argmax(pred, dim=1) == y.to(device)).sum()
        acc = correct.item() / y.shape[0]

        accuracies.append(acc)
        losses.append(loss.item())

    agg_acc = sum(accuracies) / len(accuracies)
    agg_loss = sum(losses) / len(losses)
    return agg_acc, agg_loss

def evalulate() : 
    model.eval()
    losses = []
    accuracies = []

    for x,y in tqdm(valid_loader, desc='valid') : 
        pred = model(x.to(device))

        loss = criterion(pred, y.to(device))
        correct = (torch.argmax(pred, dim=1) == y.to(device)).sum()
        acc = correct.item() / y.shape[0]

        accuracies.append(acc)
        losses.append(loss.item())

    agg_acc = sum(accuracies) / len(accuracies)
    agg_loss = sum(losses) / len(losses)
    return agg_acc, agg_loss

In [8]:
epoches = 20

for proc in range(epoches) : 
    t_acc, t_loss = train()
    v_acc, v_loss = evalulate()
    print(f"""
                === {proc+1}th Epoch ===
    
        Train Loss : {round(t_loss, 3)} | Train Acc : {round(t_acc, 3)}
        Valid Loss : {round(v_loss, 3)} | Valid Acc : {round(v_acc, 3)}
        
        ============================================
        ============================================
    """)

train: 100%|██████████| 1200/1200 [02:49<00:00,  7.08it/s]
valid: 100%|██████████| 516/516 [00:36<00:00, 14.12it/s]



                === 1th Epoch ===
    
        Train Loss : 1.83 | Train Acc : 0.515
        Valid Loss : 1.6 | Valid Acc : 0.574
        
    


train: 100%|██████████| 1200/1200 [02:47<00:00,  7.18it/s]
valid: 100%|██████████| 516/516 [00:36<00:00, 14.02it/s]



                === 2th Epoch ===
    
        Train Loss : 1.266 | Train Acc : 0.649
        Valid Loss : 1.413 | Valid Acc : 0.617
        
    


train: 100%|██████████| 1200/1200 [02:47<00:00,  7.18it/s]
valid: 100%|██████████| 516/516 [00:36<00:00, 14.04it/s]



                === 3th Epoch ===
    
        Train Loss : 1.063 | Train Acc : 0.697
        Valid Loss : 1.241 | Valid Acc : 0.664
        
    


train: 100%|██████████| 1200/1200 [02:47<00:00,  7.16it/s]
valid: 100%|██████████| 516/516 [00:37<00:00, 13.94it/s]



                === 4th Epoch ===
    
        Train Loss : 0.919 | Train Acc : 0.733
        Valid Loss : 1.184 | Valid Acc : 0.687
        
    


train: 100%|██████████| 1200/1200 [02:47<00:00,  7.18it/s]
valid: 100%|██████████| 516/516 [00:35<00:00, 14.69it/s]



                === 5th Epoch ===
    
        Train Loss : 0.799 | Train Acc : 0.766
        Valid Loss : 1.173 | Valid Acc : 0.687
        
    


train: 100%|██████████| 1200/1200 [02:47<00:00,  7.18it/s]
valid: 100%|██████████| 516/516 [00:36<00:00, 14.05it/s]



                === 6th Epoch ===
    
        Train Loss : 0.709 | Train Acc : 0.788
        Valid Loss : 1.134 | Valid Acc : 0.7
        
    


train: 100%|██████████| 1200/1200 [02:47<00:00,  7.16it/s]
valid: 100%|██████████| 516/516 [00:36<00:00, 14.00it/s]



                === 7th Epoch ===
    
        Train Loss : 0.619 | Train Acc : 0.811
        Valid Loss : 1.238 | Valid Acc : 0.69
        
    


train: 100%|██████████| 1200/1200 [02:47<00:00,  7.16it/s]
valid: 100%|██████████| 516/516 [00:36<00:00, 14.01it/s]



                === 8th Epoch ===
    
        Train Loss : 0.549 | Train Acc : 0.83
        Valid Loss : 1.199 | Valid Acc : 0.706
        
    


train: 100%|██████████| 1200/1200 [02:47<00:00,  7.16it/s]
valid: 100%|██████████| 516/516 [00:35<00:00, 14.60it/s]



                === 9th Epoch ===
    
        Train Loss : 0.489 | Train Acc : 0.848
        Valid Loss : 1.176 | Valid Acc : 0.71
        
    


train: 100%|██████████| 1200/1200 [02:47<00:00,  7.18it/s]
valid: 100%|██████████| 516/516 [00:36<00:00, 13.97it/s]



                === 10th Epoch ===
    
        Train Loss : 0.44 | Train Acc : 0.864
        Valid Loss : 1.155 | Valid Acc : 0.719
        
    


train: 100%|██████████| 1200/1200 [02:46<00:00,  7.19it/s]
valid: 100%|██████████| 516/516 [00:36<00:00, 14.05it/s]



                === 11th Epoch ===
    
        Train Loss : 0.403 | Train Acc : 0.873
        Valid Loss : 1.265 | Valid Acc : 0.714
        
    


train: 100%|██████████| 1200/1200 [02:47<00:00,  7.16it/s]
valid: 100%|██████████| 516/516 [00:36<00:00, 14.21it/s]



                === 12th Epoch ===
    
        Train Loss : 0.369 | Train Acc : 0.883
        Valid Loss : 1.312 | Valid Acc : 0.694
        
    


train: 100%|██████████| 1200/1200 [02:46<00:00,  7.20it/s]
valid: 100%|██████████| 516/516 [00:36<00:00, 13.97it/s]



                === 13th Epoch ===
    
        Train Loss : 0.337 | Train Acc : 0.894
        Valid Loss : 1.329 | Valid Acc : 0.695
        
    


train: 100%|██████████| 1200/1200 [02:46<00:00,  7.21it/s]
valid: 100%|██████████| 516/516 [00:36<00:00, 14.01it/s]



                === 14th Epoch ===
    
        Train Loss : 0.318 | Train Acc : 0.9
        Valid Loss : 1.353 | Valid Acc : 0.7
        
    


train: 100%|██████████| 1200/1200 [02:47<00:00,  7.16it/s]
valid: 100%|██████████| 516/516 [00:37<00:00, 13.93it/s]



                === 15th Epoch ===
    
        Train Loss : 0.288 | Train Acc : 0.909
        Valid Loss : 1.568 | Valid Acc : 0.661
        
    


train: 100%|██████████| 1200/1200 [02:45<00:00,  7.25it/s]
valid: 100%|██████████| 516/516 [00:37<00:00, 13.92it/s]



                === 16th Epoch ===
    
        Train Loss : 0.277 | Train Acc : 0.912
        Valid Loss : 1.33 | Valid Acc : 0.715
        
    


train: 100%|██████████| 1200/1200 [02:41<00:00,  7.42it/s]
valid: 100%|██████████| 516/516 [00:34<00:00, 15.06it/s]



                === 17th Epoch ===
    
        Train Loss : 0.258 | Train Acc : 0.919
        Valid Loss : 1.401 | Valid Acc : 0.711
        
    


train: 100%|██████████| 1200/1200 [02:41<00:00,  7.45it/s]
valid: 100%|██████████| 516/516 [00:36<00:00, 14.11it/s]



                === 18th Epoch ===
    
        Train Loss : 0.244 | Train Acc : 0.924
        Valid Loss : 1.471 | Valid Acc : 0.694
        
    


train: 100%|██████████| 1200/1200 [02:47<00:00,  7.17it/s]
valid: 100%|██████████| 516/516 [00:37<00:00, 13.92it/s]



                === 19th Epoch ===
    
        Train Loss : 0.236 | Train Acc : 0.925
        Valid Loss : 1.398 | Valid Acc : 0.712
        
    


train: 100%|██████████| 1200/1200 [02:47<00:00,  7.16it/s]
valid: 100%|██████████| 516/516 [00:36<00:00, 14.13it/s]


                === 20th Epoch ===
    
        Train Loss : 0.214 | Train Acc : 0.932
        Valid Loss : 1.465 | Valid Acc : 0.71
        
    





In [10]:
torch.save(model.state_dict(), 'resnet18_imagenet_pretrained.pt')