In [None]:
import torch
import torchvision
import numpy as np
# Training hyperparameters.
BATCH_SIZE = 256
LEARNING_RATE = 0.001
NUM_EPOCHS = 50
WEIGHT_DECAY = 0

# The only preprocessing step is the conversion of each image to a tensor;
# no augmentation or normalisation is applied.
transforms = torchvision.transforms.Compose([
    torchvision.transforms.ToTensor(),
])

# CIFAR-10 train and test splits, downloaded into ./data when missing.
train_generator = torchvision.datasets.CIFAR10(
    "./data", download=True, target_transform=None, train=True, transform=transforms
)
test_generator = torchvision.datasets.CIFAR10(
    "./data", download=True, target_transform=None, train=False, transform=transforms
)

# Training batches are reshuffled every epoch and the trailing partial batch
# is dropped so that every optimisation step sees exactly BATCH_SIZE samples.
train_loader = torch.utils.data.DataLoader(
    train_generator, batch_size=BATCH_SIZE, shuffle=True, num_workers=4, drop_last=True
)
# The test loader must keep its trailing partial batch: with drop_last=True
# the samples that do not fill a whole batch would be silently excluded from
# the validation loss and from the reported accuracies.
test_loader = torch.utils.data.DataLoader(
    test_generator, batch_size=BATCH_SIZE, shuffle=False, num_workers=4, drop_last=False
)

class Early_stopping:
    """Stop training once the validation loss stops improving.

    Call the instance once per epoch with the current validation loss and the
    model. Whenever the loss improves on the best value seen so far by at
    least ``delta``, the model's ``state_dict`` is written to ``path`` and the
    patience counter is reset. Otherwise the counter is incremented, and once
    it reaches ``patience`` the ``early_stop`` flag is set to True.

    Args:
        patience: Number of consecutive non-improving calls tolerated before
            ``early_stop`` is set.
        verbose: When True, ``save_checkpoint`` prints a message describing
            the improvement.
        delta: Minimum decrease of the validation loss that counts as an
            improvement.
        path: File the best ``state_dict`` is saved to.
    """

    def __init__(self, patience=5, verbose=True, delta=0, path="earlystopecheckpoint2.pth"):
        self.patience = patience
        self.verbose = verbose
        self.counter = 0
        self.early_stop = False
        # Use the lowercase spelling: the ``np.Inf`` alias was removed in
        # NumPy 2.0, whereas ``np.inf`` is available on every release.
        self.val_loss_min = np.inf
        self.delta = delta
        self.path = path

    def __call__(self, val_loss, model):
        """Record one epoch's validation loss and update the stopping state.

        Args:
            val_loss: Validation loss of the epoch that just finished.
            model: Model whose ``state_dict`` is saved when the loss improves.
        """
        print("val loss min : {}".format(self.val_loss_min))

        # An epoch counts as an improvement only when the loss drops at least
        # ``delta`` below the best value so far. Applying the margin in the
        # other direction would accept slightly worse losses as the new best,
        # overwriting the checkpoint and raising ``val_loss_min``. Written as
        # ``<=`` so that a NaN loss is never treated as an improvement.
        improved = val_loss <= self.val_loss_min - self.delta
        if improved:
            print("저장")  # runtime message, Korean for "saved"
            self.save_checkpoint(val_loss, model)
            self.counter = 0
            return

        # No improvement: accumulate towards the patience limit.
        self.counter += 1
        print("EarlyStopping Counter : {} out of {}".format(self.counter, self.patience))
        if self.counter >= self.patience:
            self.early_stop = True

    def save_checkpoint(self, val_loss, model):
        """Save ``model.state_dict()`` to ``self.path`` and record the new best loss.

        Args:
            val_loss: Loss value to store as the new ``val_loss_min``.
            model: Model whose ``state_dict`` is written to disk.
        """
        if self.verbose:
            print("validation loss decrease {} -> {}. Saving model".format(self.val_loss_min, val_loss))
        torch.save(model.state_dict(), self.path)
        self.val_loss_min = val_loss

        
class cifar10_classifier(torch.nn.Module):
    """Small convolutional network producing 10 class scores per image.

    ``layer0`` is the feature extractor: two Conv(5x5, padding 1) ->
    BatchNorm -> ReLU -> MaxPool(2x2) stages with 16 and 32 output channels,
    followed by a final Conv(5x5, padding 1) -> ReLU stage with 64 channels.
    ``layer1`` is a single linear layer mapping the flattened features to 10
    outputs. The linear layer expects 64 * 4 * 4 features, which is what the
    extractor yields for 3-channel 32x32 inputs.
    """

    def __init__(self):
        super().__init__()
        nn = torch.nn

        # Modules are created in network order so that parameter
        # initialisation consumes random numbers in a fixed sequence.
        feature_layers = []
        in_channels = 3
        for out_channels in (16, 32):
            feature_layers += [
                nn.Conv2d(in_channels, out_channels, kernel_size=5, stride=1, padding=1),
                nn.BatchNorm2d(out_channels),
                nn.ReLU(),
                nn.MaxPool2d((2, 2)),
            ]
            in_channels = out_channels
        # The last convolutional stage has neither batch norm nor pooling.
        feature_layers += [
            nn.Conv2d(in_channels, 64, kernel_size=5, stride=1, padding=1),
            nn.ReLU(),
        ]

        self.layer0 = nn.Sequential(*feature_layers)
        self.layer1 = nn.Sequential(nn.Linear(64 * 4 * 4, 10))

    def forward(self, x):
        """Return the class scores for a batch of images ``x``."""
        features = self.layer0(x)
        # Keep the batch dimension and collapse everything after it.
        flat = features.flatten(start_dim=1)
        return self.layer1(flat)
def top5acc(y_pred, y_val):
    """Count the samples whose true label is among the five best-scoring classes.

    Args:
        y_pred: Tensor of per-class scores, indexed as (sample, class).
        y_val: Tensor holding one class index per sample.

    Returns:
        The number of samples (as an int) whose label is one of the five
        classes with the highest score in the corresponding row of ``y_pred``.
    """
    scores = y_pred.cpu().numpy()
    targets = y_val.cpu().numpy()

    # argsort is ascending, so the last five columns hold the best classes.
    best_five = np.argsort(scores, axis=-1)[:, -5:]
    hits = (best_five == targets[:, None]).any(axis=1)
    return int(hits.sum())


# Run on the first CUDA device when one is available, otherwise on the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")

# Network, loss function and optimiser used by the training loop.
model = cifar10_classifier()
model = model.to(device)
loss_func = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(
    model.parameters(),
    lr=LEARNING_RATE,
    eps=1e-7,
    weight_decay=WEIGHT_DECAY,
)

# Early-stopping monitor with its default patience, margin and checkpoint path.
early_stopping = Early_stopping()

# Train for at most NUM_EPOCHS epochs, validating after each one and stopping
# early once the validation loss has stopped improving.
for epoch in range(NUM_EPOCHS):
    print("----------epoch{}----------".format(epoch+1))

    # Re-enter training mode at the start of every epoch. The validation pass
    # below switches the model to eval mode, and without this call every
    # epoch after the first would train with the batch-norm layers frozen on
    # their running statistics.
    model.train()
    for image, label in train_loader:
        image = image.to(device)
        label = label.to(device)

        optimizer.zero_grad()
        # Call the module rather than .forward() so registered hooks run.
        output = model(image)
        loss = loss_func(output, label)
        loss.backward()
        optimizer.step()

    # Validation pass: eval mode, and no autograd graph since nothing is
    # back-propagated here.
    model.eval()
    val_losses = []
    batch_sizes = []
    with torch.no_grad():
        for image, label in test_loader:
            image = image.to(device)
            label = label.to(device)

            output = model(image)
            loss = loss_func(output, label)

            val_losses.append(loss.item())
            batch_sizes.append(label.size(0))
    # Weight each batch loss by its batch size so that the epoch loss is a
    # per-sample mean even when the last batch is smaller than the others
    # (assumes loss_func averages over the batch, the CrossEntropyLoss default).
    val_loss = np.average(val_losses, weights=batch_sizes)
    print("val loss : {}".format(val_loss))

    early_stopping(val_loss, model)
    if early_stopping.early_stop:
        print("EARLY STOPPING!!")
        break

# Restore the best weights written by the early-stopping monitor. The path is
# read from the monitor itself so the two cannot drift apart, and
# map_location places the tensors on the device in use regardless of where
# the checkpoint was written.
model.load_state_dict(torch.load(early_stopping.path, map_location=device))
# Optional export of the restored weights:
# torch.save(model.state_dict(), "./save_model/CIFAR10.pth")

print("model evaluation")
total = 0
top1correct = 0
top5correct = 0
model.eval()
with torch.no_grad():
    for image, label in test_loader:
        image = image.to(device)
        label = label.to(device)

        output = model(image)
        # Index of the highest score in each row is the top-1 prediction.
        output_idx = output.argmax(dim=1)
        total += image.shape[0]
        # .item() keeps the running count a plain Python number, so the
        # accuracy below prints as a number instead of a tensor repr.
        top1correct += (output_idx == label).sum().item()
        top5correct += top5acc(output, label)
print("top 1 accuracy : {}".format(100 * top1correct / total))
print("top 5 accuracy : {}".format(100 * top5correct / total))

Downloading https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz to ./data/cifar-10-python.tar.gz


HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))

Extracting ./data/cifar-10-python.tar.gz to ./data
Files already downloaded and verified
----------epoch1----------
