In [3]:
import os
from os.path import join
import numpy as np
import matplotlib.pyplot as plt
import torch
from torch import nn
import torchvision
import copy
import pandas as pd
from PIL import Image
import pytorch_lightning as pl

from torch.utils.data import Dataset, DataLoader

import lightly
from lightly.data import LightlyDataset
from lightly.data import SimCLRCollateFunction
from lightly.loss import NegativeCosineSimilarity
from lightly.utils import BenchmarkModule
from lightly.models.modules import BYOLProjectionHead, BYOLPredictionHead
from lightly.models.utils import deactivate_requires_grad
from lightly.models.utils import update_momentum
from sklearn.preprocessing import normalize
from sklearn.neighbors import NearestNeighbors
import gc

# Release any memory left over from a previous run of this notebook.
gc.collect()

# Root directory holding the dataset CSVs and the pretrained weights.
# NOTE(review): the directory is spelled "/lables" — confirm this matches disk.
src = "/lables"
# Sub-directory of `src` that contains the image/label CSV files.
dataset_csv = "dataset_csv"

# `reset_peak_memory_stats` raises when no CUDA device is available, so only
# touch the CUDA allocator when a GPU is actually present.
if torch.cuda.is_available():
    torch.cuda.reset_peak_memory_stats()
    torch.cuda.empty_cache()

  from .autonotebook import tqdm as notebook_tqdm


In [4]:
# Settings for the Caltech-256 dataset
img_size = 150     # input resolution used by the augmentations and eval transforms
batch_size = 256
num_workers = 8    # DataLoader worker processes
max_epochs = 400
# Caltech-256 has 256 object categories (31 was left over from Office-31).
# Keep this in agreement with the class count passed to `knn_predict` below.
num_classes = 256

lr_factor = batch_size / 128  # scales the learning rate linearly with batch size

In [3]:
# Setting up the BYOL class


class BYOL(BenchmarkModule):
    """BYOL self-supervised model built on lightly's ``BenchmarkModule``.

    An online network (backbone -> projection head -> prediction head) is
    trained to predict the output of a momentum ("target") network
    (backbone -> projection head) on the other augmented view of the same
    image. The target network is updated only as an exponential moving
    average of the online one and has gradients disabled.

    Args:
        backbone: torchvision-style classifier; its last child module (the
            classification head) is dropped and the rest is used as encoder.
        dataloader_kNN: forwarded to ``BenchmarkModule``.
        num_classes: forwarded to ``BenchmarkModule``.
        num_ftrs: width of the flattened backbone output. The default of 2048
            matches ResNet-50; other backbones need a different value.
        lr: base SGD learning rate. ``None`` keeps the notebook default of
            ``6e-2 * lr_factor`` (read from the global ``lr_factor``).
        max_epochs: period of the cosine learning-rate schedule. ``None``
            falls back to the global ``max_epochs``.
        momentum: EMA coefficient used to update the target network.
    """

    def __init__(self, backbone, dataloader_kNN, num_classes,
                 num_ftrs=2048, lr=None, max_epochs=None, momentum=0.99):
        super().__init__(dataloader_kNN, num_classes)
        # Drop the classification head; for torchvision ResNets the remaining
        # stack ends at the global average pool.
        self.backbone = nn.Sequential(*list(backbone.children())[:-1])

        # Online heads. A configuration that worked for CIFAR-10 used
        # (num_ftrs, 4096, 1024) for the projection head and (1024, 4096, 1024)
        # for the prediction head.
        self.projection_head = BYOLProjectionHead(num_ftrs, 4096, 256)
        self.prediction_head = BYOLPredictionHead(256, 4096, 256)

        # Target network: frozen copies updated only through update_momentum.
        self.backbone_momentum = copy.deepcopy(self.backbone)
        self.projection_head_momentum = copy.deepcopy(self.projection_head)
        deactivate_requires_grad(self.backbone_momentum)
        deactivate_requires_grad(self.projection_head_momentum)

        self.criterion = NegativeCosineSimilarity()

        # Underscore-prefixed so they cannot clash with LightningModule attributes.
        self._lr = lr
        self._max_epochs = max_epochs
        self._momentum = momentum

    def forward(self, x):
        """Online branch: returns the prediction-head output for images ``x``."""
        y = self.backbone(x).flatten(start_dim=1)
        z = self.projection_head(y)
        p = self.prediction_head(z)
        return p

    def forward_momentum(self, x):
        """Target branch: returns detached projections from the momentum network."""
        y = self.backbone_momentum(x).flatten(start_dim=1)
        z = self.projection_head_momentum(y)
        z = z.detach()
        return z

    def training_step(self, batch, batch_idx):
        """Symmetric BYOL loss over the two augmented views in ``batch``."""
        # EMA update of the target network before computing targets.
        update_momentum(self.backbone, self.backbone_momentum, m=self._momentum)
        update_momentum(self.projection_head, self.projection_head_momentum, m=self._momentum)
        (x0, x1), _, _ = batch
        p0 = self.forward(x0)
        z0 = self.forward_momentum(x0)
        p1 = self.forward(x1)
        z1 = self.forward_momentum(x1)
        # Each view's prediction is matched against the other view's target.
        loss = 0.5 * (self.criterion(p0, z1) + self.criterion(p1, z0))
        self.log('train_loss_ssl', loss)
        return loss

    def configure_optimizers(self):
        """SGD over the online network only, with a cosine-annealed learning rate."""
        params = (
            list(self.backbone.parameters())
            + list(self.projection_head.parameters())
            + list(self.prediction_head.parameters())
        )
        lr = self._lr if self._lr is not None else 6e-2 * lr_factor
        t_max = self._max_epochs if self._max_epochs is not None else max_epochs
        optim = torch.optim.SGD(
            params,
            lr=lr,
            momentum=0.9,
            weight_decay=5e-4,
        )
        scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optim, t_max)
        return [optim], [scheduler]

    def training_epoch_end(self, outputs):
        # Intentionally a no-op that overrides the parent's end-of-epoch hook.
        # NOTE(review): presumably this skips BenchmarkModule's kNN feature-bank
        # refresh over `dataloader_kNN` (here the two-view SSL loader) — confirm
        # against the installed lightly version.
        pass

In [7]:
# Two-view augmentation used for BYOL pre-training. All photometric
# augmentations (colour jitter, random grayscale, blur) are switched off;
# only flips and rotation remain among the probabilistic ones.
collate_fn = SimCLRCollateFunction(
    input_size=img_size,
    # geometric augmentations
    hf_prob=0.5,
    vf_prob=0.5,
    rr_prob=0.5,
    # photometric augmentations (disabled)
    cj_prob=0.0,
    random_gray_scale=0.0,
    gaussian_blur=0.,
)

class OfficeDataset(torch.utils.data.Dataset):
    """Image dataset described by two CSV files.

    Args:
        img_csv_filepath: CSV with one row per image and at least the columns
            ``path`` (image file to open), ``label`` (raw class label) and
            ``filename`` (returned unchanged with each sample).
        label_csv_filepath: CSV mapping raw labels to targets. Rows are
            matched on the ``label`` column; the value in the second column
            of the first matching row is returned as the target.
        transform: optional callable applied to the loaded PIL image.

    Each item is ``(image, target, filename)``.

    Raises:
        KeyError: from ``__getitem__`` when an image's label is missing from
            the label CSV.
    """

    def __init__(self, img_csv_filepath, label_csv_filepath, transform=None):
        self.img_csv_file = pd.read_csv(img_csv_filepath)
        self.label_csv_file = pd.read_csv(label_csv_filepath)
        self.transform = transform
        # Build the label -> target lookup once instead of scanning the label
        # table on every __getitem__ call; keep="first" preserves the
        # first-matching-row semantics.
        unique_labels = self.label_csv_file.drop_duplicates(subset="label", keep="first")
        self._label_to_target = dict(zip(unique_labels["label"], unique_labels.iloc[:, 1]))

    def __len__(self):
        return len(self.img_csv_file)

    def __getitem__(self, idx):
        file_det = self.img_csv_file.iloc[idx]
        # Force 3-channel RGB: image collections can contain grayscale or
        # paletted files, which would break 3-channel normalisation and
        # batching. The context manager closes the handle PIL keeps open.
        with Image.open(file_det["path"]) as img:
            image = img.convert("RGB")
        raw_label = file_det["label"]
        try:
            label = self._label_to_target[raw_label]
        except KeyError:
            raise KeyError(f"label {raw_label!r} not found in label CSV") from None
        if self.transform:
            image = self.transform(image)

        return (image, label, file_det["filename"])


# CSV locations for Caltech-256 (for Office-31 the label file is "office_31_labels.csv").
label_csv = join(src, dataset_csv, "caltech_256_labels.csv")
train_csv = join(src, dataset_csv, "paperspace_caltech_256_train.csv")

# Self-supervised training set: no per-sample transform, because `collate_fn`
# builds the two augmented views from the raw PIL images.
train_dataset = OfficeDataset(train_csv, label_csv)

train_dataloader = DataLoader(
    train_dataset,
    batch_size=batch_size,
    num_workers=num_workers,
    shuffle=True,
    drop_last=True,  # every SSL batch has exactly batch_size samples
    collate_fn=collate_fn,
)

# ------------------------------------------------------



# Deterministic preprocessing for evaluation / embedding extraction.
# Resize(int) only matches the shorter side, so non-square images would yield
# tensors of different shapes that the default collate cannot stack into a
# batch; CenterCrop makes every tensor img_size x img_size.
test_transforms = torchvision.transforms.Compose([
    torchvision.transforms.Resize(img_size),
    torchvision.transforms.CenterCrop(img_size),
    torchvision.transforms.ToTensor(),
    torchvision.transforms.Normalize(
        mean=lightly.data.collate.imagenet_normalize['mean'],
        std=lightly.data.collate.imagenet_normalize['std'],
    ),
])

# Same training images as the SSL set, preprocessed deterministically; they
# form the reference set for kNN evaluation.
classif_train_dataset = OfficeDataset(img_csv_filepath=train_csv,
                                      label_csv_filepath=label_csv,
                                      transform=test_transforms)

# shuffle=False keeps the embeddings in CSV row order.
classif_train_dataloader = DataLoader(classif_train_dataset,
                                      batch_size=batch_size,
                                      shuffle=False,
                                      num_workers=num_workers)

# ------------------------------------------------------


# Held-out evaluation set (Caltech-256 test split).
# Previously used alternatives:
#   "paperspace_office_31_webcam.csv", "paperspace_modern_office_31_new_amazon.csv"
test_csv = join(src, dataset_csv, "paperspace_caltech_256_test.csv")
test_dataset = OfficeDataset(test_csv, label_csv, transform=test_transforms)

test_dataloader = DataLoader(
    test_dataset,
    batch_size=batch_size,
    num_workers=num_workers,
    shuffle=False,
)

### Load pre-trained encoder weights before training with BYOL

In [5]:
# Initialise the BYOL backbone from a ResNet-50 checkpoint
# (presumably torchvision's ImageNet weights, judging by the file name).
base_model = torchvision.models.resnet50()
model_path = "models/resnet50-19c8e357.pth"

# map_location="cpu" lets the checkpoint load on machines without a GPU.
pretrained_weights = torch.load(join(src, model_path), map_location="cpu")
# strict=False tolerates key mismatches; report them so a wrong file or
# architecture does not silently leave the backbone randomly initialised.
incompatible = base_model.load_state_dict(pretrained_weights, strict=False)
if incompatible.missing_keys or incompatible.unexpected_keys:
    print(f"Missing keys: {incompatible.missing_keys}")
    print(f"Unexpected keys: {incompatible.unexpected_keys}")
model = BYOL(base_model, train_dataloader, num_classes)

### Train the model

In [4]:
# Single-GPU self-supervised pre-training; no validation loader is passed.
trainer = pl.Trainer(gpus=1, max_epochs=max_epochs)
trainer.fit(model, train_dataloaders=train_dataloader)

### Save the encoder to a .pth file

In [8]:
pretrained_model_backbone = model.backbone

# Store the backbone weights together with the settings that produced them,
# so the checkpoint can be reused in other code. The metadata is derived from
# the actual configuration; the old hand-typed values (800 epochs,
# output_dim 64) did not match this run. It also includes every key the
# loading cell reads ('lr', 'random_init').
state_dict = {
    'model_params': pretrained_model_backbone.state_dict(),
    'epochs': max_epochs,
    # Width of the flattened backbone output (the input width of the removed fc layer).
    'output_dim': base_model.fc.in_features,
    'batch_size': batch_size,
    'img_size': img_size,
    # Colour jitter and random grayscale are disabled in `collate_fn`.
    'color_augs': False,
    # Must match the learning rate used in BYOL.configure_optimizers.
    'lr': 6e-2 * lr_factor,
    # The backbone was initialised from pretrained ResNet-50 weights.
    'random_init': False,
}
# NOTE(review): the file name mentions Office-31/Amazon, but this notebook
# trains on Caltech-256 — confirm before overwriting an existing checkpoint.
torch.save(state_dict, 'res50_byol_modern_office31_amazon_v1.pth')

### Load the encoder from a .pth file

In [11]:
# Rebuild a ResNet-50 encoder without its classification head and load the
# saved backbone weights into it for inference.
model_new = torchvision.models.resnet50()
backbone_new = nn.Sequential(*list(model_new.children())[:-1])

# map_location="cpu": embeddings are generated on the CPU below, and the
# checkpoint may have been saved from GPU tensors.
ckpt = torch.load('res50_dino_caltech.pth', map_location="cpu")
backbone_new.load_state_dict(ckpt['model_params'])

# Show the training metadata stored next to the weights. `.get` keeps this
# working for checkpoints that were saved without some keys
# (e.g. 'lr' / 'random_init').
for key in ('epochs', 'output_dim', 'batch_size', 'img_size',
            'color_augs', 'lr', 'random_init'):
    print(f"{key}: {ckpt.get(key)}")

200
2048
256
150
True
0.3
False


## Evaluate the encoder with a kNN classifier

In [12]:
def generate_embeddings(model, dataloader, device="cpu"):
    """Compute L2-normalised embeddings for every sample in ``dataloader``.

    Args:
        model: module mapping an image batch to features; its output is
            flattened from dim 1 onwards. Switch it to eval mode beforehand.
        dataloader: iterable yielding ``(images, labels, filenames)`` batches.
        device: device the images are moved to before the forward pass; it
            must match the device ``model`` lives on. Defaults to ``"cpu"``.

    Returns:
        ``(embeddings, labels, filenames)``: a float tensor of shape ``(N, D)``
        whose rows have unit L2 norm (all-zero rows stay zero), a 1-D tensor
        of the ``N`` labels, and the list of ``N`` filenames, all on the CPU
        and in dataloader order. An empty dataloader yields empty tensors and
        an empty list.
    """
    embedding_batches = []
    label_batches = []
    filenames = []
    with torch.no_grad():
        for img, label, fnames in dataloader:
            emb = model(img.to(device)).flatten(start_dim=1)
            # Bring results back to the CPU so batches from a GPU model can be
            # concatenated and later converted to numpy for scikit-learn.
            embedding_batches.append(emb.cpu())
            label_batches.append(torch.as_tensor(label).reshape(-1))
            filenames.extend(fnames)

    if not embedding_batches:
        # torch.cat would raise on an empty list.
        return torch.empty(0), torch.empty(0, dtype=torch.long), filenames

    embeddings = torch.cat(embedding_batches, 0)
    # Row-wise L2 normalisation (the role sklearn.preprocessing.normalize played).
    embeddings = torch.nn.functional.normalize(embeddings, p=2, dim=1)
    return embeddings, torch.cat(label_batches, 0), filenames


In [13]:
backbone_new.eval()  # inference mode (e.g. BatchNorm uses running statistics)
(train_embeddings,
 train_labels,
 train_filenames) = generate_embeddings(backbone_new, classif_train_dataloader)
print("Traing embeddings loaded")

Traing embeddings loaded


In [14]:
# Embed the held-out split with the same frozen encoder.
(test_embeddings,
 test_labels,
 test_filenames) = generate_embeddings(backbone_new, test_dataloader)
print("Test embeddings done")

Test embeddings done


In [9]:
test_embeddings.shape[0]  # number of test samples

795

In [10]:
train_embeddings.shape[0]  # number of reference (train) samples

2817

In [15]:
from sklearn.neighbors import KNeighborsClassifier
from sklearn.metrics import accuracy_score

# Uniform-vote kNN on the L2-normalised embeddings.
n_neighbors = 10
knn = KNeighborsClassifier(n_neighbors=n_neighbors).fit(
    train_embeddings.numpy(), train_labels.numpy()
)
y_pred = knn.predict(test_embeddings.numpy())

In [16]:
# scikit-learn's signature is accuracy_score(y_true, y_pred).
accuracy_score(test_labels, y_pred)

0.677734375

In [17]:
y_pred.shape[0]  # number of predictions

5120

In [18]:
test_labels.shape[0]  # number of ground-truth test labels

5120

In [20]:
from lightly.utils import knn_predict
import torch.nn.functional as F

# Similarity-weighted kNN: (query features, feature bank as D x N,
# bank labels as 1 x N, num_classes, k). num_classes (256) must exceed the
# largest label value.
pred = knn_predict(test_embeddings, train_embeddings.T, train_labels.unsqueeze(dim=0), 256, 20)
# Column 0 holds the top-ranked class; accuracy_score takes (y_true, y_pred).
accuracy_score(test_labels, pred[:, 0])

0.6859375

In [21]:
# Sweep the number of neighbours k and keep the best accuracy.
# NOTE(review): k is chosen on the *test* set, so `max_acc` is an optimistic
# estimate; pick k on a held-out validation split for an unbiased figure.
# `clusters` holds the best k (neighbour count), not a number of clusters.
max_acc, clusters = 0, 0

for r in range(1, 200):
    pred = knn_predict(test_embeddings, train_embeddings.T, train_labels.unsqueeze(dim=0), 256, r)
    # accuracy_score takes (y_true, y_pred).
    acc = accuracy_score(test_labels, pred[:, 0])
    if acc > max_acc:
        max_acc = acc
        clusters = r

In [22]:
# Best accuracy found by the k sweep above (k selected on the test set).
max_acc

0.6890625

In [23]:
# Number of neighbours k that achieved `max_acc` in the sweep above.
clusters

9