In [1]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import pickle
import PIL
import os
import torch
import torch.optim as optim
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable
from torch.utils.data import Dataset, DataLoader
from torchvision import datasets, models, transforms
from tqdm import tqdm
from torchvision.datasets import DatasetFolder, ImageFolder
from shuffle_batchnorm import ShuffleBatchNorm
from torch.nn.parallel import DistributedDataParallel as DDP
import torch.distributed as dist
from model import MOCO, plot



In [2]:

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"device: {device}")

device: cuda


In [3]:
train_root = "./data/imagenette2/train"
val_root = "./data/imagenette2/val"

In [4]:
!nvidia-smi

Thu Jun 17 22:49:28 2021       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 460.80       Driver Version: 460.80       CUDA Version: 11.2     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|   0  GeForce RTX 208...  Off  | 00000000:04:00.0 Off |                  N/A |
| 42%   49C    P8    29W / 250W |      3MiB / 11019MiB |      0%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+
                                                                               
+-----------------------------------------------------------------------------+
| Proces

In [5]:

feature_dim = 128
t = 0.2
m = 0.999
checkpoint = torch.load("MOCOv2.pth")
model = MOCO(t, m, feature_dim).to(device)
model.load_state_dict(checkpoint['model_state_dict'])


<All keys matched successfully>

In [6]:

train_eval_transform = transforms.Compose([
            transforms.RandomResizedCrop(224),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])
	])
test_eval_transform = transforms.Compose([
            transforms.Resize((224, 224)),
#             transforms.CenterCrop(224),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])
	])



In [7]:

for param in model.parameters():
    param.requires_grad = False

In [8]:

model.f_q.fc = nn.Linear(2048, 10).to(device)
model.f_q.fc.requires_grad = True


In [9]:
BATCH_SIZE = 100
train_data = ImageFolder(root=train_root, transform=train_eval_transform)
val_data = ImageFolder(root=val_root, transform=test_eval_transform)
train_loader = DataLoader(train_data, shuffle=True, batch_size= BATCH_SIZE,num_workers = 4, pin_memory=True)
val_loader = DataLoader(val_data, shuffle=False, batch_size= BATCH_SIZE, num_workers = 4, pin_memory=True)
lr = 0.001
optimizer = torch.optim.Adam(model.f_q.fc.parameters(), lr=lr)
criterion = nn.CrossEntropyLoss()

EPOCHS = 20
res_dict = {"train_loss_list":[],"test_loss_list":[],"train_acc_list":[],"test_acc_list":[]}


In [10]:

print("START LINEAR EVALUATION \n")
for epoch in tqdm(range(EPOCHS)):
    correct = total = total_loss = 0
    for i, (images, labels) in enumerate(train_loader):
        labels = labels.to(device)
        scores = model.f_q(images.to(device))
        logits = F.normalize(scores, dim=1)
        predictions = torch.argmax(logits, dim = 1)
        correct += torch.sum(predictions == labels).item()
        total += labels.shape[0]
        loss = criterion(logits, labels)
        total_loss += loss.item()
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
    res_dict["train_loss_list"].append(total_loss/total)
    res_dict["train_acc_list"].append(correct/total)
    print(f"\n train loss: {total_loss/total:.4f} train accuracy:{correct/total:.4f} ")
    with torch.no_grad():
        correct = total = total_loss = 0
        for i, (images, labels) in enumerate(val_loader):
            labels = labels.to(device)
            scores = model.f_q(images.to(device))
            logits = F.normalize(scores, dim=1)
            predictions = torch.argmax(logits, dim = 1)
            correct += torch.sum(predictions == labels).item()
            total += labels.shape[0]
            loss = criterion(logits, labels)
            total_loss += loss.item() 
        res_dict["test_loss_list"].append(total_loss/total)
        res_dict["test_acc_list"].append(correct/total)


    plot(res_dict, eval=True)
    print(f"test loss: {total_loss/total:.4f} test accuracy:{correct/total:.4f} ")

  0%|          | 0/20 [00:00<?, ?it/s]

START LINEAR EVALUATION 


 train loss: 0.0170 train accuracy:0.8201 


  5%|▌         | 1/20 [01:36<30:28, 96.24s/it]

test loss: 0.0205 test accuracy:0.4683 

 train loss: 0.0159 train accuracy:0.8946 


 10%|█         | 2/20 [03:10<28:31, 95.07s/it]

test loss: 0.0203 test accuracy:0.4820 

 train loss: 0.0157 train accuracy:0.9065 


 15%|█▌        | 3/20 [04:43<26:41, 94.19s/it]

test loss: 0.0203 test accuracy:0.4856 

 train loss: 0.0156 train accuracy:0.9117 


 20%|██        | 4/20 [06:17<25:05, 94.12s/it]

test loss: 0.0202 test accuracy:0.4907 

 train loss: 0.0156 train accuracy:0.9176 


 25%|██▌       | 5/20 [07:50<23:26, 93.78s/it]

test loss: 0.0202 test accuracy:0.4917 

 train loss: 0.0155 train accuracy:0.9197 


 30%|███       | 6/20 [09:22<21:44, 93.15s/it]

test loss: 0.0202 test accuracy:0.4948 

 train loss: 0.0155 train accuracy:0.9177 


 35%|███▌      | 7/20 [10:55<20:08, 92.94s/it]

test loss: 0.0202 test accuracy:0.4966 

 train loss: 0.0155 train accuracy:0.9226 


 40%|████      | 8/20 [12:28<18:35, 92.95s/it]

test loss: 0.0202 test accuracy:0.4922 

 train loss: 0.0155 train accuracy:0.9199 


 45%|████▌     | 9/20 [14:00<17:00, 92.76s/it]

test loss: 0.0202 test accuracy:0.4927 

 train loss: 0.0155 train accuracy:0.9206 


 50%|█████     | 10/20 [15:33<15:27, 92.73s/it]

test loss: 0.0201 test accuracy:0.4945 

 train loss: 0.0155 train accuracy:0.9260 


 55%|█████▌    | 11/20 [17:06<13:56, 92.96s/it]

test loss: 0.0201 test accuracy:0.4930 

 train loss: 0.0154 train accuracy:0.9251 


 60%|██████    | 12/20 [18:39<12:24, 93.05s/it]

test loss: 0.0201 test accuracy:0.4973 

 train loss: 0.0154 train accuracy:0.9233 


 65%|██████▌   | 13/20 [20:12<10:50, 92.87s/it]

test loss: 0.0201 test accuracy:0.5014 

 train loss: 0.0154 train accuracy:0.9286 


 70%|███████   | 14/20 [21:44<09:16, 92.76s/it]

test loss: 0.0201 test accuracy:0.4996 

 train loss: 0.0154 train accuracy:0.9280 


 75%|███████▌  | 15/20 [23:17<07:43, 92.72s/it]

test loss: 0.0201 test accuracy:0.4989 

 train loss: 0.0154 train accuracy:0.9281 


 80%|████████  | 16/20 [24:50<06:11, 92.80s/it]

test loss: 0.0201 test accuracy:0.5029 

 train loss: 0.0154 train accuracy:0.9278 


 85%|████████▌ | 17/20 [26:23<04:38, 92.96s/it]

test loss: 0.0201 test accuracy:0.5019 

 train loss: 0.0154 train accuracy:0.9276 


 90%|█████████ | 18/20 [27:57<03:06, 93.29s/it]

test loss: 0.0201 test accuracy:0.4963 

 train loss: 0.0154 train accuracy:0.9322 


 95%|█████████▌| 19/20 [29:31<01:33, 93.22s/it]

test loss: 0.0201 test accuracy:0.4981 

 train loss: 0.0154 train accuracy:0.9282 


100%|██████████| 20/20 [31:04<00:00, 93.20s/it]

test loss: 0.0201 test accuracy:0.4953 





<Figure size 432x288 with 0 Axes>

In [None]:
# from torchvision.datasets.utils import download_url
# import os
# import tarfile
# import hashlib

# dataset_url = 'https://s3.amazonaws.com/fast-ai-imageclas/imagenette2.tgz'
# dataset_filename = dataset_url.split('/')[-1]
# dataset_foldername = dataset_filename.split('.')[0]
# data_path = './data'
# dataset_filepath = os.path.join(data_path,dataset_filename)
# dataset_folderpath = os.path.join(data_path,dataset_foldername)

# os.makedirs(data_path, exist_ok=True)

# download = False
# if not os.path.exists(dataset_filepath):
#     download = True
# else:
#     md5_hash = hashlib.md5()


#     file = open(dataset_filepath, "rb")

#     content = file.read()

#     md5_hash.update(content)


#     digest = md5_hash.hexdigest()
#     if digest != 'fe2fc210e6bb7c5664d602c3cd71e612':
#         download = True
# if download:
#     download_url(dataset_url, data_path)

# with tarfile.open(dataset_filepath, 'r:gz') as tar:
#     tar.extractall(path=data_path)