In [1]:
import timm
import torch
import sys

In [6]:
class Config:
    device='cuda:0'
    batch_size=256
    num_epochs=100
    # num_workers=8,
    lr=3e-4
    weight_decay=0.0008

    dataset_name='cifar10'
    # checkpoint_fn = '/home/hyunseoki/ssd1/02_src/SimpleCLR/runs/Jun26_21-50-22_hyunseoki-ubuntu/checkpoint.pth.tar'
    checkpoint_fn = None

args = Config()

In [7]:
from torchvision import datasets
import torchvision.transforms as transforms
from torch.utils.data import DataLoader


def accuracy(output, target, topk=(1,)):
    """Computes the accuracy over the k top predictions for the specified values of k"""
    with torch.no_grad():
        maxk = max(topk)
        batch_size = target.size(0)

        _, pred = output.topk(maxk, 1, True, True)
        pred = pred.t()
        correct = pred.eq(target.view(1, -1).expand_as(pred))

        res = []
        for k in topk:
            correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True)
            res.append(correct_k.mul_(100.0 / batch_size))
        return res


def get_stl10_data_loaders(download, shuffle=False, batch_size=args.batch_size):
    train_dataset = datasets.STL10('../data', split='train', download=download,
                                  transform=transforms.ToTensor())

    train_loader = DataLoader(train_dataset, batch_size=batch_size,
                                num_workers=0, drop_last=False, shuffle=shuffle, pin_memory=True)
    
    test_dataset = datasets.STL10('../data', split='test', download=download,
                                    transform=transforms.ToTensor())

    test_loader = DataLoader(test_dataset, batch_size=2*batch_size,
                                num_workers=8, drop_last=False, shuffle=shuffle, pin_memory=True)
    return train_loader, test_loader


def get_cifar10_data_loaders(download, shuffle=False, batch_size=args.batch_size):
    train_dataset = datasets.CIFAR10('../data', train=True, download=download,
                                    transform=transforms.ToTensor())

    train_loader = DataLoader(train_dataset, batch_size=batch_size,
                                num_workers=0, drop_last=False, shuffle=shuffle, pin_memory=True)
    
    test_dataset = datasets.CIFAR10('../data', train=False, download=download,
                                    transform=transforms.ToTensor())

    test_loader = DataLoader(test_dataset, batch_size=2*batch_size,
                                num_workers=8, drop_last=False, shuffle=shuffle, pin_memory=True)
    return train_loader, test_loader


def load_checkpoint(model, fn):
    checkpoint = torch.load(fn)
    state_dict = checkpoint['state_dict']

    for k in list(state_dict.keys()):
        if k.startswith('encoder.'):
            if k.startswith('encoder') and not k.startswith('encoder.fc'):
                state_dict[k[len("encoder."):]] = state_dict[k]
        del state_dict[k]

    log = model.load_state_dict(state_dict, strict=False)
    assert log.missing_keys == ['fc.weight', 'fc.bias']

    print(log)
    return model

In [8]:
model = timm.create_model(
    'resnet50',
    pretrained=False,
    in_chans=3,
    num_classes=10,
)

if args.checkpoint_fn:
    model = load_checkpoint(model=model, fn=args.checkpoint_fn)

if args.dataset_name == 'cifar10':
    train_loader, test_loader = get_cifar10_data_loaders(download=True)
elif args.dataset_name == 'stl10':
    train_loader, test_loader = get_stl10_data_loaders(download=True)

# freeze all layers but the last fc
# for name, param in model.named_parameters():
#     if name not in ['fc.weight', 'fc.bias']:
#         param.requires_grad = False

# parameters = list(filter(lambda p: p.requires_grad, model.parameters()))
# assert len(parameters) == 2  # fc.weight, fc.bias

optimizer = torch.optim.Adam(model.parameters(), lr=args.lr, weight_decay=args.weight_decay)
criterion = torch.nn.CrossEntropyLoss().to(args.device)

print("Dataset:", args.dataset_name)

Files already downloaded and verified
Files already downloaded and verified
Dataset: cifar10


In [9]:
epochs = 100
best_top1_acc = 0

for epoch in range(epochs):
    top1_train_accuracy = 0
    for counter, (x_batch, y_batch) in enumerate(train_loader):
      model.train()
      model.to(args.device)
      x_batch = x_batch.to(args.device)
      y_batch = y_batch.to(args.device)

      logits = model(x_batch)
      loss = criterion(logits, y_batch)
      top1 = accuracy(logits, y_batch, topk=(1,))
      top1_train_accuracy += top1[0]

      optimizer.zero_grad()
      loss.backward()
      optimizer.step()

    top1_train_accuracy /= (counter + 1)
    top1_accuracy = 0
    top5_accuracy = 0

    for counter, (x_batch, y_batch) in enumerate(test_loader):
      model.eval()
      model.to(args.device)
      x_batch = x_batch.to(args.device)
      y_batch = y_batch.to(args.device)

      logits = model(x_batch)

      top1, top5 = accuracy(logits, y_batch, topk=(1,5))
      top1_accuracy += top1[0]
      top5_accuracy += top5[0]
    
    top1_accuracy /= (counter + 1)
    top5_accuracy /= (counter + 1)

    if top1_accuracy > best_top1_acc:
       best_top1_acc = top1_accuracy
    print(f"Epoch {epoch} Top1 Train accuracy {top1_train_accuracy.item()}\tTop1 Test accuracy: {top1_accuracy.item()}\tTop5 test acc: {top5_accuracy.item()}")

print(f'best top1_acc {best_top1_acc}')

Epoch 0 Top1 Train accuracy 36.07262420654297	Top1 Test accuracy: 45.37913513183594	Top5 test acc: 89.90694427490234
Epoch 1 Top1 Train accuracy 57.05755615234375	Top1 Test accuracy: 51.038028717041016	Top5 test acc: 93.56100463867188
Epoch 2 Top1 Train accuracy 66.85746002197266	Top1 Test accuracy: 56.01447677612305	Top5 test acc: 95.43370819091797
Epoch 3 Top1 Train accuracy 74.41645050048828	Top1 Test accuracy: 56.3884391784668	Top5 test acc: 94.384765625
Epoch 4 Top1 Train accuracy 79.15856170654297	Top1 Test accuracy: 46.540672302246094	Top5 test acc: 90.4566879272461
Epoch 5 Top1 Train accuracy 81.86104583740234	Top1 Test accuracy: 56.78022003173828	Top5 test acc: 93.2887191772461
Epoch 6 Top1 Train accuracy 85.34518432617188	Top1 Test accuracy: 61.865234375	Top5 test acc: 95.72438049316406
Epoch 7 Top1 Train accuracy 88.16246795654297	Top1 Test accuracy: 60.66176986694336	Top5 test acc: 95.23092651367188
Epoch 8 Top1 Train accuracy 90.31807708740234	Top1 Test accuracy: 60.491729