In [1]:
import torchvision
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.transforms as transforms
from torchvision.models import resnet18 as ResNet18
import os
import torchvision.models as models
from tqdm.auto import tqdm

import wandb
from configs.config import Config

In [2]:
# wandb
wandb.init(project="Resnet18_cifar10", entity="jskim0406", name=f'resnet18_cifar10')


Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Currently logged in as: [33mjskim0406[0m. Use [1m`wandb login --relogin`[0m to force relogin


In [3]:
# Simple Learning Rate Scheduler
def lr_scheduler(optimizer, epoch):
    lr = learning_rate
    if epoch >= 50:
        lr /= 10
    if epoch >= 100:
        lr /= 10
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr

# Xavier         
def init_weights(m):
    if isinstance(m, nn.Linear):
        torch.nn.init.xavier_uniform(m.weight)
        m.bias.data.fill_(0.01)

In [4]:
transform_train = transforms.Compose([
    transforms.RandomCrop(32, padding=4),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
])

transform_test = transforms.Compose([
    transforms.ToTensor(),
])

train_dataset = torchvision.datasets.CIFAR10(root='../data', train=True, download=True, transform=transform_train)
test_dataset = torchvision.datasets.CIFAR10(root='../data', train=False, download=True, transform=transform_test)

train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=256, shuffle=True, num_workers=8)
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=256, shuffle=False, num_workers=8)

Files already downloaded and verified
Files already downloaded and verified


In [5]:
device = 'cuda'
model = ResNet18(pretrained=False)

In [6]:
model.apply(init_weights)
model = model.to(device)

  torch.nn.init.xavier_uniform(m.weight)


In [7]:
c = Config()

learning_rate = c.learning_rate
num_epoch = c.num_epoch

model_name = 'resnet18_full.pth'

loss_fn = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=learning_rate)

train_loss = 0
valid_loss = 0
correct = 0
total_cnt = 0
best_acc = 0

In [8]:
# Train
for epoch in tqdm(range(num_epoch)):
#     print(f"====== { epoch+1} epoch of { num_epoch } ======")
    model.train()
    lr_scheduler(optimizer, epoch)
    train_loss = 0
    valid_loss = 0
    correct = 0
    total_cnt = 0
    # Train Phase
    for step, batch in enumerate(train_loader):
        #  input and target
        batch[0], batch[1] = batch[0].to(device), batch[1].to(device)
        optimizer.zero_grad()
        
        logits = model(batch[0])
        loss = loss_fn(logits, batch[1])
        loss.backward()
        
        optimizer.step()
        train_loss += loss.item()
        _, predict = logits.max(1)
        
        total_cnt += batch[1].size(0)
        correct +=  predict.eq(batch[1]).sum().item()
        
        wandb.log({'train_Acc': correct / total_cnt})
        wandb.log({'train_loss': loss.item() / batch[1].size(0)})
        
#         if step % 100 == 0 and step != 0:
#             print(f"\n====== { step } Step of { len(train_loader) } ======")
#             print(f"Train Acc : { correct / total_cnt }")
#             print(f"Train Loss : { loss.item() / batch[1].size(0) }")
            
    correct = 0
    total_cnt = 0
    
    # Test Phase
    with torch.no_grad():
        model.eval()
        for step, batch in enumerate(test_loader):
            # input and target
            batch[0], batch[1] = batch[0].to(device), batch[1].to(device)
            total_cnt += batch[1].size(0)
            logits = model(batch[0])
            valid_loss += loss_fn(logits, batch[1])
            _, predict = logits.max(1)
            correct += predict.eq(batch[1]).sum().item()
        valid_acc = correct / total_cnt
#         print(f"\nValid Acc : { valid_acc }")    
#         print(f"Valid Loss : { valid_loss / total_cnt }")
        
        wandb.log({'Valid_Acc': valid_acc})
        wandb.log({'Valid_Loss': valid_loss / total_cnt})

        if(valid_acc > best_acc):
            best_acc = valid_acc
            torch.save(model, model_name)
            print("Model Saved!")

  0%|          | 0/100 [00:00<?, ?it/s]

  return torch.max_pool2d(input, kernel_size, stride, padding, dilation, ceil_mode)


Model Saved!
Model Saved!
Model Saved!
Model Saved!
Model Saved!
Model Saved!
Model Saved!
Model Saved!
Model Saved!
Model Saved!
Model Saved!
Model Saved!
Model Saved!
Model Saved!
Model Saved!
Model Saved!
Model Saved!
Model Saved!
Model Saved!
Model Saved!
Model Saved!
Model Saved!
Model Saved!
Model Saved!
Model Saved!
Model Saved!
Model Saved!
Model Saved!
Model Saved!
Model Saved!
Model Saved!


**result check**

In [9]:
res = pd.read_csv("valid_acc.csv")

In [15]:
res = res.sort_values(by='resnet18_cifar10 - Valid_Acc', ascending=False)

In [17]:
res.head(10)

Unnamed: 0,Step,resnet18_cifar10 - Valid_Acc,resnet18_cifar10 - Valid_Acc__MIN,resnet18_cifar10 - Valid_Acc__MAX
97,38610,0.8231,0.8231,0.8231
95,37822,0.8209,0.8209,0.8209
98,39004,0.8207,0.8207,0.8207
99,39398,0.8206,0.8206,0.8206
86,34276,0.8203,0.8203,0.8203
96,38216,0.82,0.82,0.82
91,36246,0.8195,0.8195,0.8195
92,36640,0.8194,0.8194,0.8194
89,35458,0.8193,0.8193,0.8193
82,32700,0.8188,0.8188,0.8188
