In [1]:
import numpy as np 
import torch 
import torch.nn as nn 
from torchvision import datasets 
from torchvision import transforms
from torch.utils.data import DataLoader 
import timm


class cfg:
    datadir = './data'
    img_size = 256 
    batch_size = 128
    model_name = 'resnet18'
    num_classes = 10 
    lr = 1e-3 
    epochs = 50

device = 'cuda:1'

transform = transforms.Compose(
    [transforms.ToTensor(),
    transforms.Resize((cfg.img_size,cfg.img_size))]
)    
    

trainset = datasets.CIFAR10(root      = cfg.datadir,
                            train     = True,
                            transform = transform,
                            download  = True
                            )
testset  = datasets.CIFAR10(root      = cfg.datadir,
                            train     = False,
                            transform = transform,
                            download  = True
                            )

trainloader, testloader = DataLoader(trainset, batch_size = cfg.batch_size, shuffle=True), DataLoader(testset, batch_size = cfg.batch_size, shuffle=False)

net = timm.create_model(model_name = cfg.model_name, pretrained = True, num_classes = cfg.num_classes)
net.to(device)

optimizer = torch.optim.Adam(net.parameters(), lr = 1e-3)
criterion = nn.CrossEntropyLoss()

  from .autonotebook import tqdm as notebook_tqdm


Files already downloaded and verified
Files already downloaded and verified


# 기본 

In [3]:
from sklearn.metrics import accuracy_score
def train_step(net, optimizer, criterion, trainloader):
        net.train()
        optimizer.zero_grad() 
    
        predict = [] 
        target = [] 
        for imgs, labels in trainloader:
            imgs, labels = imgs.to(device), labels.to(device)
            
            
            output = net(imgs)
            loss = criterion(output,labels)
            loss.backward()
            optimizer.step()    
            
            predict.append(output.argmax(1).detach().cpu().numpy())
            target.append(labels.detach().cpu().numpy())

        
        target = np.concatenate(target)
        predict = np.concatenate(predict)        
        train_accuracy = accuracy_score(target,predict)
        return train_accuracy
    
def valid_step(net, testloader, criterion):
        net.eval()
        predict = [] 
        target = [] 
        with torch.no_grad():
            for imgs, labels in testloader:
                imgs, labels = imgs.to(device), labels.to(device)
                
                output = net(imgs)
                loss = criterion(output, labels)
                predict.append(output.argmax(1).detach().cpu().numpy())
                target.append(labels.detach().cpu().numpy())
                
        target = np.concatenate(target)
        predict = np.concatenate(predict)        
        test_accuracy = accuracy_score(target,predict)
        return test_accuracy
    
for epoch in range(cfg.epochs):
    
    # Train
    train_accuracy = train_step(net, optimizer, criterion, trainloader)
    
    # Test 
    test_accuracy = valid_step(net, testloader, criterion)
    
    print(f'Epoch : {epoch} | train_accuracy : {train_accuracy} | test_accuracy : {test_accuracy}')
            
    break 

    

Epoch : 0 | train_accuracy : 0.29096 | test_accuracy : 0.2726


# Ignite Metrics 

In [4]:
from ignite.metrics import Accuracy

def train_step(net, optimizer, criterion, trainloader):
        net.train()
        optimizer.zero_grad() 
    
        acc_metric = Accuracy()
        for imgs, labels in trainloader:
            imgs, labels = imgs.to(device), labels.to(device)
            
            
            output = net(imgs)
            loss = criterion(output,labels)
            loss.backward()
            optimizer.step()    
            
            acc_metric.update(
                                (output,labels)
                                )

        
        train_accuracy = acc_metric.compute()
        return train_accuracy
    
def valid_step(net, testloader, criterion):
        net.eval()
        
        val_acc_metric = Accuracy()
        with torch.no_grad():
            for imgs, labels in testloader:
                imgs, labels = imgs.to(device), labels.to(device)
                
                output = net(imgs)
                loss = criterion(output, labels)
                
                val_acc_metric.update(
                                        (output, labels)
                                        )
                
        test_accuracy = val_acc_metric.compute()
        return test_accuracy
    
for epoch in range(cfg.epochs):
    
    # Train
    train_accuracy = train_step(net, optimizer, criterion, trainloader)
    
    # Test 
    test_accuracy = valid_step(net, testloader, criterion)
    
    print(f'Epoch : {epoch} | train_accuracy : {train_accuracy} | test_accuracy : {test_accuracy}')
            
    break 

    

Epoch : 0 | train_accuracy : 0.31758 | test_accuracy : 0.3342


# Ignite Engine 

In [5]:
from ignite.engine.engine import Engine
from ignite.engine import Events
from ignite.metrics import Accuracy
from ignite.contrib.handlers import ProgressBar

def train(engine, batch):
    net.train()
    optimizer.zero_grad() 
    
    imgs, labels = batch[0].to(device), batch[1].to(device)        
    
    output = net(imgs)
    loss = criterion(output,labels)
    loss.backward()
    optimizer.step()   
    
    acc_metric.update(
            (output,labels)
            )

    return loss.item()  

def valid(engine, batch):
    net.eval()
    imgs, labels = batch[0].to(device), batch[1].to(device)
    
    with torch.no_grad():
        output = net(imgs)
        loss = criterion(output, labels)
        
    val_acc_metric.update(
        (output, labels)
    )
    

trainer = Engine(train)
evaluator = Engine(valid)

for epoch in range(cfg.epochs):
    
    # Train    
    acc_metric = Accuracy()
    trainer.run(trainloader)        
    train_accuracy = acc_metric.compute()
    
    # Test 
    val_acc_metric = Accuracy()
    evaluator.run(testloader)
    valid_accuracy = val_acc_metric.compute()
            
    test_accuracy = val_acc_metric.compute()
    print(f'Epoch : {epoch} | train_accuracy : {train_accuracy} | test_accuracy : {test_accuracy}')
            
    break 
    

Epoch : 0 | train_accuracy : 0.34938 | test_accuracy : 0.3685


# Ignite API 

In [6]:
from ignite.engine.engine import Engine
from ignite.engine import Events
from ignite.metrics import Accuracy
from ignite.contrib.handlers import ProgressBar

def train(engine, batch):
    net.train()
    optimizer.zero_grad() 
    
    imgs, labels = batch[0].to(device), batch[1].to(device)        
    
    output = net(imgs)
    loss = criterion(output,labels)
    loss.backward()
    optimizer.step()    
    
    return loss.item()  

def valid(engine, batch):
    net.eval()
    with torch.no_grad():
        x, y = batch
        x,y = x.to(device), y.to(device)
        y_pred = net(x)
    return y_pred, y

trainer = Engine(train)
evaluator = Engine(valid)
Accuracy().attach(evaluator, "accuracy")
ProgressBar().attach(trainer)

@trainer.on(Events.ITERATION_COMPLETED(every=100))
def log_training(engine):
    batch_loss = engine.state.output
    lr = optimizer.param_groups[0]['lr']
    e = engine.state.epoch
    n = engine.state.max_epochs
    i = engine.state.iteration
    print(f"Epoch {e}/{n} : {i} - batch loss: {batch_loss:.3f}, lr: {lr}")

@trainer.on(Events.EPOCH_COMPLETED(every=1))
def run_validation():
    evaluator.run(testloader)
    
@trainer.on(Events.EPOCH_COMPLETED(every=1))
def log_validation():
    metrics = evaluator.state.metrics
    print(metrics['accuracy'])
    
trainer.run(trainloader, max_epochs=1)

Iteration: [100/391]  26%|██▌        [00:24<01:11]

Epoch 1/1 : 100 - batch loss: 1.653, lr: 0.001


Iteration: [200/391]  51%|█████      [00:49<00:47]

Epoch 1/1 : 200 - batch loss: 1.723, lr: 0.001


Iteration: [300/391]  77%|███████▋   [01:14<00:22]

Epoch 1/1 : 300 - batch loss: 1.639, lr: 0.001


                                                  

0.3764


State:
	iteration: 391
	epoch: 1
	epoch_length: 391
	max_epochs: 1
	output: 1.6260793209075928
	batch: <class 'list'>
	metrics: <class 'dict'>
	dataloader: <class 'torch.utils.data.dataloader.DataLoader'>
	seed: <class 'NoneType'>
	times: <class 'dict'>