In [1]:
import numpy as np 
import pandas as pd 
import os
!git clone https://github.com/nikheelpandey/BYOL-PyTorch.git

Cloning into 'BYOL-PyTorch'...
remote: Enumerating objects: 67, done.[K
remote: Counting objects: 100% (67/67), done.[K
remote: Compressing objects: 100% (48/48), done.[K
remote: Total 67 (delta 24), reused 49 (delta 14), pack-reused 0[K
Unpacking objects: 100% (67/67), done.


In [2]:
os.chdir("./BYOL-PyTorch/")
import torch
import torch.nn as nn
from tqdm import tqdm
from logger import Logger
import torch.optim as optim
from models.model import BYOL
from datetime import datetime
from torchvision.models import resnet18 


if torch.cuda.is_available():
    dtype = torch.cuda.FloatTensor
    device = torch.device("cuda")
    print('GPU')
else:
    dtype = torch.FloatTensor
    device = torch.device("cpu")


model_path = "/kaggle/input/byol-demo/BYOL-PyTorch/ckpt/0425155245/byol_0425170509.pth"

GPU


In [3]:
import torch.nn as nn


class FineTunedModel(nn.Module):

    def __init__(self, encoder,input_dim, num_classes ):
        super().__init__()
        self.input_dim =  input_dim
        self.num_classes = num_classes
#         self.encoder = encoder
        
        for param in encoder.parameters():
            param.requires_grad = False
        
        
        classification_head =   nn.Sequential(nn.Linear(input_dim, 512),
                    nn.ReLU(),
                    nn.Dropout(0.4),
                    nn.Linear(512, input_dim))
                    
        self.model = nn.Sequential(
                        encoder,
                        classification_head)
        
        
    def forward(self, x):
        z = self.model(x)
        return z

    
model = BYOL()
obj = torch.load(model_path)
model.online_encoder.load_state_dict(obj['online_network'])
encoder = torch.nn.Sequential(list(model.children())[0][0])

In [4]:

best_acc = 0.0
batch = batch_size = 16
uid = "ssc"               #second_stage_classifier
image_size = (96,96)
num_classes = 128
epochs = 50



clf = FineTunedModel(encoder, 2048, num_classes=num_classes).to(device)
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(clf.parameters(), lr=0.0001,
                      momentum=0.99, weight_decay=5e-4)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=200)

In [5]:
rand_img = torch.rand(128, 3, 96, 96).to(device)
clf(rand_img).shape

torch.Size([128, 2048])

In [6]:
import torchvision
import torchvision.transforms as T
from torchvision import datasets, transforms



CIFAR_MEAN =  [0.49139968, 0.48215827, 0.44653124]
CIFAR_STD  =  [0.24703233, 0.24348505, 0.26158768]


class InitalTransformation():
    def __init__(self):
        self.transform = T.Compose([
            T.ToTensor(),
            transforms.Normalize(CIFAR_MEAN,CIFAR_STD),
        ])

    def __call__(self, x):
        x = self.transform(x)
        return  x



def get_train_test_dataloaders(dataset = "stl10", data_dir="./dataset", batch_size = 16,num_workers = 4, download=True): 
    
    train_loader = torch.utils.data.DataLoader(
        dataset = torchvision.datasets.STL10(data_dir, split="train", transform=InitalTransformation(), download=download),
        shuffle=True,
        batch_size= batch_size,
        num_workers = num_workers
    )
    

    test_loader = torch.utils.data.DataLoader(
        dataset = torchvision.datasets.STL10(data_dir, split="test", transform=InitalTransformation(), download=download),
        shuffle=True,
        batch_size= batch_size,
        num_workers = num_workers
        )
    return train_loader, test_loader


train_loader, test_loader = get_train_test_dataloaders(batch_size=batch_size)


Downloading http://ai.stanford.edu/~acoates/stl10/stl10_binary.tar.gz to ./dataset/stl10_binary.tar.gz


0it [00:00, ?it/s]

Extracting ./dataset/stl10_binary.tar.gz to ./dataset
Files already downloaded and verified


In [7]:
import kornia

def get_clf_train_test_transform(image_size,s=.2):
        
    train_transform = nn.Sequential(
                
                kornia.augmentation.RandomResizedCrop(image_size,scale=(0.5,1.0)),
                kornia.augmentation.RandomHorizontalFlip(p=0.5),
                # kornia.augmentation.Normalize(CIFAR_MEAN_,CIFAR_STD_),
            )

    test_transform = nn.Sequential(  
                kornia.augmentation.RandomResizedCrop(image_size,scale=(0.5,1.0)),
                kornia.augmentation.RandomHorizontalFlip(p=0.5),
                # kornia.augmentation.RandomGrayscale(p=0.05),
                # kornia.augmentation.Normalize(CIFAR_MEAN_,CIFAR_STD_)
        )

    return train_transform , test_transform


train_transform, test_transform = get_clf_train_test_transform(image_size)

In [8]:
from tqdm.notebook import tqdm

def train_clf(epoch, epochs):
    clf.train()
    train_loss = 0
    
    correct, total = 0,0
    
    local_progress = tqdm(train_loader, desc=f'Epoch {epoch}/{epochs}')
    
    for idx, (images, labels) in enumerate(local_progress):
        images, labels = images.to(device), labels.to(device)
        images = images.to(device)
        aug_image = train_transform(images)

        optimizer.zero_grad()
#         print(aug_image.shape)
        outputs = clf(aug_image)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        train_loss += loss.item()
        _, predicted = outputs.max(1)
        
        total += labels.size(0)
        correct += predicted.eq(labels).sum().item()
        
        data_dict = {"loss": train_loss, "accuracy":100.*correct/total}
        local_progress.set_postfix(data_dict)

    return data_dict    
    
    

def test_clf(epoch, epochs):

    global best_acc
    clf.eval()
    test_loss = 0
    correct = 0
    total = 0

    with torch.no_grad():    
        local_progress = tqdm(test_loader, desc=f'Epoch {epoch}/{epochs}')
        for idx, (images, label) in enumerate(local_progress):
            
            images, label = images.to(device), label.to(device)
            images = test_transform(images)
            outputs = clf(images)
            loss = criterion(outputs, label)

            test_loss += loss.item()
            _, predicted = outputs.max(1)
            total += label.size(0)
            correct += predicted.eq(label).sum().item()
            data_dict = {"test_loss": test_loss, "test_accuracy":100.*correct/total}
            local_progress.set_postfix(data_dict)
            
    acc = 100.*correct/total
    if acc > best_acc:
        print('Saving..')
        state = {
            'net': clf.state_dict(),
            'acc': acc,
            'epoch': epoch,
        }
        
        model_path = os.path.join(ckpt_dir, f"{uid}_{datetime.now().strftime('%m%d%H%M%S')}.pth")

        torch.save({
            'epoch':epoch+1,
            'state_dict': clf.state_dict()
                }, model_path)
        print(f'Model saved at: {model_path}')
        best_acc = acc

    return data_dict

In [9]:
from logger import Logger




ckpt_dir = "_ckpt/clf_"+str(datetime.now().strftime('%m%d%H%M%S'))
log_dir = "_runs/clf_"+str(datetime.now().strftime('%m%d%H%M%S'))


    
if not os.path.exists(ckpt_dir):
    os.makedirs(ckpt_dir)

if not os.path.exists(log_dir):
    os.makedirs(log_dir)



logger = Logger(log_dir=log_dir, tensorboard=True, matplotlib=True)


for epoch in range(0, epochs):
    data_dict = train_clf(epoch,epochs)
    logger.update_scalers(data_dict)
    data_dict = test_clf(epoch,epochs)
    logger.update_scalers(data_dict)
    scheduler.step()

Epoch 0/50:   0%|          | 0/313 [00:00<?, ?it/s]

Epoch 0/50:   0%|          | 0/500 [00:00<?, ?it/s]

Saving..
Model saved at: _ckpt/clf_0427074722/ssc_0427074755.pth


Epoch 1/50:   0%|          | 0/313 [00:00<?, ?it/s]

Epoch 1/50:   0%|          | 0/500 [00:00<?, ?it/s]

Epoch 2/50:   0%|          | 0/313 [00:00<?, ?it/s]

Epoch 2/50:   0%|          | 0/500 [00:00<?, ?it/s]

Saving..
Model saved at: _ckpt/clf_0427074722/ssc_0427074857.pth


Epoch 3/50:   0%|          | 0/313 [00:00<?, ?it/s]

Epoch 3/50:   0%|          | 0/500 [00:00<?, ?it/s]

Saving..
Model saved at: _ckpt/clf_0427074722/ssc_0427074929.pth


Epoch 4/50:   0%|          | 0/313 [00:00<?, ?it/s]

Epoch 4/50:   0%|          | 0/500 [00:00<?, ?it/s]

Epoch 5/50:   0%|          | 0/313 [00:00<?, ?it/s]

Epoch 5/50:   0%|          | 0/500 [00:00<?, ?it/s]

Epoch 6/50:   0%|          | 0/313 [00:00<?, ?it/s]

Epoch 6/50:   0%|          | 0/500 [00:00<?, ?it/s]

Epoch 7/50:   0%|          | 0/313 [00:00<?, ?it/s]

Epoch 7/50:   0%|          | 0/500 [00:00<?, ?it/s]

Saving..
Model saved at: _ckpt/clf_0427074722/ssc_0427075132.pth


Epoch 8/50:   0%|          | 0/313 [00:00<?, ?it/s]

Epoch 8/50:   0%|          | 0/500 [00:00<?, ?it/s]

Epoch 9/50:   0%|          | 0/313 [00:00<?, ?it/s]

Epoch 9/50:   0%|          | 0/500 [00:00<?, ?it/s]

Epoch 10/50:   0%|          | 0/313 [00:00<?, ?it/s]

Epoch 10/50:   0%|          | 0/500 [00:00<?, ?it/s]

Epoch 11/50:   0%|          | 0/313 [00:00<?, ?it/s]

Epoch 11/50:   0%|          | 0/500 [00:00<?, ?it/s]

Epoch 12/50:   0%|          | 0/313 [00:00<?, ?it/s]

Epoch 12/50:   0%|          | 0/500 [00:00<?, ?it/s]

Epoch 13/50:   0%|          | 0/313 [00:00<?, ?it/s]

Epoch 13/50:   0%|          | 0/500 [00:00<?, ?it/s]

Epoch 14/50:   0%|          | 0/313 [00:00<?, ?it/s]

Epoch 14/50:   0%|          | 0/500 [00:00<?, ?it/s]

Epoch 15/50:   0%|          | 0/313 [00:00<?, ?it/s]

Epoch 15/50:   0%|          | 0/500 [00:00<?, ?it/s]

Epoch 16/50:   0%|          | 0/313 [00:00<?, ?it/s]

Epoch 16/50:   0%|          | 0/500 [00:00<?, ?it/s]

Epoch 17/50:   0%|          | 0/313 [00:00<?, ?it/s]

Epoch 17/50:   0%|          | 0/500 [00:00<?, ?it/s]

Epoch 18/50:   0%|          | 0/313 [00:00<?, ?it/s]

Epoch 18/50:   0%|          | 0/500 [00:00<?, ?it/s]

Saving..
Model saved at: _ckpt/clf_0427074722/ssc_0427075719.pth


Epoch 19/50:   0%|          | 0/313 [00:00<?, ?it/s]

Epoch 19/50:   0%|          | 0/500 [00:00<?, ?it/s]

Epoch 20/50:   0%|          | 0/313 [00:00<?, ?it/s]

Epoch 20/50:   0%|          | 0/500 [00:00<?, ?it/s]

Epoch 21/50:   0%|          | 0/313 [00:00<?, ?it/s]

Epoch 21/50:   0%|          | 0/500 [00:00<?, ?it/s]

Epoch 22/50:   0%|          | 0/313 [00:00<?, ?it/s]

Epoch 22/50:   0%|          | 0/500 [00:00<?, ?it/s]

Epoch 23/50:   0%|          | 0/313 [00:00<?, ?it/s]

Epoch 23/50:   0%|          | 0/500 [00:00<?, ?it/s]

Epoch 24/50:   0%|          | 0/313 [00:00<?, ?it/s]

Epoch 24/50:   0%|          | 0/500 [00:00<?, ?it/s]

Epoch 25/50:   0%|          | 0/313 [00:00<?, ?it/s]

Epoch 25/50:   0%|          | 0/500 [00:00<?, ?it/s]

Saving..
Model saved at: _ckpt/clf_0427074722/ssc_0427080101.pth


Epoch 26/50:   0%|          | 0/313 [00:00<?, ?it/s]

Epoch 26/50:   0%|          | 0/500 [00:00<?, ?it/s]

Epoch 27/50:   0%|          | 0/313 [00:00<?, ?it/s]

Epoch 27/50:   0%|          | 0/500 [00:00<?, ?it/s]

Epoch 28/50:   0%|          | 0/313 [00:00<?, ?it/s]

Epoch 28/50:   0%|          | 0/500 [00:00<?, ?it/s]

Epoch 29/50:   0%|          | 0/313 [00:00<?, ?it/s]

Epoch 29/50:   0%|          | 0/500 [00:00<?, ?it/s]

Epoch 30/50:   0%|          | 0/313 [00:00<?, ?it/s]

Epoch 30/50:   0%|          | 0/500 [00:00<?, ?it/s]

Epoch 31/50:   0%|          | 0/313 [00:00<?, ?it/s]

Epoch 31/50:   0%|          | 0/500 [00:00<?, ?it/s]

Epoch 32/50:   0%|          | 0/313 [00:00<?, ?it/s]

Epoch 32/50:   0%|          | 0/500 [00:00<?, ?it/s]

Epoch 33/50:   0%|          | 0/313 [00:00<?, ?it/s]

Epoch 33/50:   0%|          | 0/500 [00:00<?, ?it/s]

Epoch 34/50:   0%|          | 0/313 [00:00<?, ?it/s]

Epoch 34/50:   0%|          | 0/500 [00:00<?, ?it/s]

Epoch 35/50:   0%|          | 0/313 [00:00<?, ?it/s]

Epoch 35/50:   0%|          | 0/500 [00:00<?, ?it/s]

Saving..
Model saved at: _ckpt/clf_0427074722/ssc_0427080621.pth


Epoch 36/50:   0%|          | 0/313 [00:00<?, ?it/s]

Epoch 36/50:   0%|          | 0/500 [00:00<?, ?it/s]

Saving..
Model saved at: _ckpt/clf_0427074722/ssc_0427080653.pth


Epoch 37/50:   0%|          | 0/313 [00:00<?, ?it/s]

Epoch 37/50:   0%|          | 0/500 [00:00<?, ?it/s]

Epoch 38/50:   0%|          | 0/313 [00:00<?, ?it/s]

Epoch 38/50:   0%|          | 0/500 [00:00<?, ?it/s]

Epoch 39/50:   0%|          | 0/313 [00:00<?, ?it/s]

Epoch 39/50:   0%|          | 0/500 [00:00<?, ?it/s]

Epoch 40/50:   0%|          | 0/313 [00:00<?, ?it/s]

Epoch 40/50:   0%|          | 0/500 [00:00<?, ?it/s]

Epoch 41/50:   0%|          | 0/313 [00:00<?, ?it/s]

Epoch 41/50:   0%|          | 0/500 [00:00<?, ?it/s]

Epoch 42/50:   0%|          | 0/313 [00:00<?, ?it/s]

Epoch 42/50:   0%|          | 0/500 [00:00<?, ?it/s]

Epoch 43/50:   0%|          | 0/313 [00:00<?, ?it/s]

Epoch 43/50:   0%|          | 0/500 [00:00<?, ?it/s]

Epoch 44/50:   0%|          | 0/313 [00:00<?, ?it/s]

Epoch 44/50:   0%|          | 0/500 [00:00<?, ?it/s]

Epoch 45/50:   0%|          | 0/313 [00:00<?, ?it/s]

Epoch 45/50:   0%|          | 0/500 [00:00<?, ?it/s]

Epoch 46/50:   0%|          | 0/313 [00:00<?, ?it/s]

Epoch 46/50:   0%|          | 0/500 [00:00<?, ?it/s]

Epoch 47/50:   0%|          | 0/313 [00:00<?, ?it/s]

Epoch 47/50:   0%|          | 0/500 [00:00<?, ?it/s]

Epoch 48/50:   0%|          | 0/313 [00:00<?, ?it/s]

Epoch 48/50:   0%|          | 0/500 [00:00<?, ?it/s]

Epoch 49/50:   0%|          | 0/313 [00:00<?, ?it/s]

Epoch 49/50:   0%|          | 0/500 [00:00<?, ?it/s]