In [19]:
import torch
import sys
import numpy as np
import os
import yaml
import matplotlib.pyplot as plt
import torchvision
from torch.utils.data import DataLoader
import torchvision.transforms as transforms
from torchvision import datasets

In [20]:
# Default to the CPU and switch to the GPU only when PyTorch reports CUDA is usable.
device = 'cpu'
if torch.cuda.is_available():
  device = 'cuda'
print("Using device:", device)

Using device: cpu


In [21]:
# Dataset root directories.
# NOTE(review): absolute, machine-specific Windows paths — adjust for your environment.
data_dir_stl10 = r'C:\Custom\DataSet\STL10'
data_dir_cifar10 = r'C:\Custom\DataSet\CIFAR10'


def _make_loaders(train_dataset, test_dataset, shuffle, batch_size, num_workers):
  """Wrap a train/test dataset pair in DataLoaders; return (train_loader, test_loader).

  The test loader uses twice the training batch size and is never shuffled, so the
  test split is always walked in a fixed, reproducible order. Both loaders share the
  same `num_workers` and keep the final partial batch (`drop_last=False`).
  """
  train_loader = DataLoader(train_dataset, batch_size=batch_size,
                            num_workers=num_workers, drop_last=False, shuffle=shuffle)
  test_loader = DataLoader(test_dataset, batch_size=2 * batch_size,
                           num_workers=num_workers, drop_last=False, shuffle=False)
  return train_loader, test_loader


def get_stl10_data_loaders(download, shuffle=False, batch_size=256, num_workers=0):
  """Return (train_loader, test_loader) for the STL10 'train' and 'test' splits.

  Args:
    download: forwarded as `download=` to the torchvision STL10 constructors.
    shuffle: shuffle the train loader (the test loader is never shuffled).
    batch_size: train batch size; the test loader uses 2 * batch_size.
    num_workers: DataLoader worker processes for both loaders. The default 0 loads
      data in the main process, avoiding worker start-up (spawned processes on
      Windows) when running inside Jupyter.
  """
  to_tensor = transforms.ToTensor()
  train_dataset = datasets.STL10(data_dir_stl10, split='train', download=download,
                                 transform=to_tensor)
  test_dataset = datasets.STL10(data_dir_stl10, split='test', download=download,
                                transform=to_tensor)
  return _make_loaders(train_dataset, test_dataset, shuffle, batch_size, num_workers)


def get_cifar10_data_loaders(download, shuffle=False, batch_size=256, num_workers=0):
  """Return (train_loader, test_loader) for the CIFAR10 train and test splits.

  Args:
    download: forwarded as `download=` to the torchvision CIFAR10 constructors.
    shuffle: shuffle the train loader (the test loader is never shuffled).
    batch_size: train batch size; the test loader uses 2 * batch_size.
    num_workers: DataLoader worker processes for both loaders. The default 0 loads
      data in the main process, avoiding worker start-up (spawned processes on
      Windows) when running inside Jupyter.
  """
  to_tensor = transforms.ToTensor()
  train_dataset = datasets.CIFAR10(data_dir_cifar10, train=True, download=download,
                                   transform=to_tensor)
  test_dataset = datasets.CIFAR10(data_dir_cifar10, train=False, download=download,
                                  transform=to_tensor)
  return _make_loaders(train_dataset, test_dataset, shuffle, batch_size, num_workers)

In [22]:
# Run directory of the pretraining job whose config and checkpoint are evaluated here.
log_dir = './runs/Mar14_15-39-41_RedMiPro15R7'
config_path = os.path.join(log_dir, 'config.yml')

# safe_load only builds plain Python objects (dicts, lists, scalars), never arbitrary classes.
with open(config_path) as config_file:
  config = yaml.safe_load(config_file)
print(config)

{'arch': 'resnet18', 'batch_size': 256, 'dataset_dir': 'C:\\Custom\\DataSet\\STL10', 'dataset_name': 'stl10', 'device': 'cpu', 'disable_cuda': True, 'epochs': 1, 'fp16_precision': False, 'learning_rate': 0.0003, 'log_every_n_steps': 100, 'n_views': 2, 'out_dim': 128, 'temperature': 0.07, 'weight_decay': 0.0001}


In [23]:
# Build a fresh torchvision ResNet with a 10-way `fc` head (10 classes, matching STL10/CIFAR10).
NUM_CLASSES = 10
ARCH_BUILDERS = {
  'resnet18': torchvision.models.resnet18,
  'resnet50': torchvision.models.resnet50,
}

# Fail loudly on an unknown arch: otherwise `model` would be undefined, or silently
# left over from a previous run of this cell in the same kernel.
if config['arch'] not in ARCH_BUILDERS:
  raise ValueError(f"Unsupported arch {config['arch']!r}; expected one of {sorted(ARCH_BUILDERS)}")

model = ARCH_BUILDERS[config['arch']](num_classes=NUM_CLASSES).to(device)

In [24]:
# Load the pretraining checkpoint and keep only the backbone (encoder) weights.
# NOTE(review): torch.load unpickles the file — only load checkpoints from a trusted source.
# NOTE(review): the filename is hard-coded; config above reports epochs == 1 — confirm for other runs.
checkpoint = torch.load(os.path.join(log_dir, 'checkpoint_0001.pth.tar'), map_location=device)

BACKBONE_PREFIX = 'backbone.'
# Keep `backbone.*` tensors except the checkpoint's own `backbone.fc*` layers, re-keyed
# without the prefix so they line up with the torchvision ResNet built earlier, whose
# `fc` is a fresh 10-class layer. Building a new dict (instead of renaming keys in place
# while iterating) guarantees a stripped key can never overwrite or delete another entry.
state_dict = {
  key[len(BACKBONE_PREFIX):]: tensor
  for key, tensor in checkpoint['state_dict'].items()
  if key.startswith(BACKBONE_PREFIX) and not key.startswith(BACKBONE_PREFIX + 'fc')
}

# An empty result means the checkpoint uses a different key layout; stop here rather
# than evaluate an untrained encoder.
if not state_dict:
  raise ValueError(f"No '{BACKBONE_PREFIX}*' weights found in checkpoint; check the key prefix")