In [2]:
!pip install faiss-gpu==1.7.2

Collecting faiss-gpu==1.7.2
  Downloading faiss_gpu-1.7.2-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (85.5 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m85.5/85.5 MB[0m [31m21.3 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: faiss-gpu
Successfully installed faiss-gpu-1.7.2


In [3]:
import os

from PIL import Image
import numpy as np
from torchvision import transforms, datasets
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
import seaborn as sns
import pandas as pd
import random
import faiss
import PIL, PIL.ImageOps, PIL.ImageEnhance, PIL.ImageDraw
import numpy as np
import torch
from torchvision.transforms.transforms import Compose

from tqdm import tqdm
tqdm.pandas()

from google.colab import drive
drive.mount('/content/drive')
device = 'cuda' if torch.cuda.is_available() else 'cpu'
device

Mounted at /content/drive


'cuda'

In [4]:
class SimCLRLoss(nn.Module):
    def __init__(self, temperature, device):
        super(SimCLRLoss, self).__init__()
        self.temperature = temperature
        self._device = device


    def forward(self, features):
        """
        input:
            - features: hidden feature representation of shape [b, 2, dim]

        output:
            - loss: loss computed according to SimCLR
        """

        b, n, dim = features.size()
        assert(n == 2)
        mask = torch.eye(b, dtype=torch.float32).to(self._device)

        contrast_features = torch.cat(torch.unbind(features, dim=1), dim=0)
        anchor = features[:, 0]

        # Dot product
        dot_product = torch.matmul(anchor, contrast_features.T) / self.temperature

        # Log-sum trick for numerical stability
        logits_max, _ = torch.max(dot_product, dim=1, keepdim=True)
        logits = dot_product - logits_max.detach()

        mask = mask.repeat(1, 2).to(self._device)
        logits_mask = torch.scatter(torch.ones_like(mask).to(self._device), 1, torch.arange(b).view(-1, 1).to(self._device), 0)
        mask = mask * logits_mask

        # Log-softmax
        exp_logits = torch.exp(logits) * logits_mask
        log_prob = logits - torch.log(exp_logits.sum(1, keepdim=True))

        # Mean log-likelihood for positive
        loss = - ((mask * log_prob).sum(1) / mask.sum(1)).mean()

        return loss

In [5]:
class BasicBlock(nn.Module):
    expansion = 1

    def __init__(self, in_planes, planes, stride=1, is_last=False):
        super(BasicBlock, self).__init__()
        self.is_last = is_last
        self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)

        self.shortcut = nn.Sequential()
        if stride != 1 or in_planes != self.expansion * planes:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_planes, self.expansion * planes, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(self.expansion * planes)
            )

    def forward(self, x):
        out = F.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))
        out += self.shortcut(x)
        preact = out
        out = F.relu(out)
        if self.is_last:
            return out, preact
        else:
            return out

class Bottleneck(nn.Module):
    expansion = 4

    def __init__(self, in_planes, planes, stride=1, is_last=False):
        super(Bottleneck, self).__init__()
        self.is_last = is_last
        self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)
        self.conv3 = nn.Conv2d(planes, self.expansion * planes, kernel_size=1, bias=False)
        self.bn3 = nn.BatchNorm2d(self.expansion * planes)

        self.shortcut = nn.Sequential()
        if stride != 1 or in_planes != self.expansion * planes:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_planes, self.expansion * planes, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(self.expansion * planes)
            )

    def forward(self, x):
        out = F.relu(self.bn1(self.conv1(x)))
        out = F.relu(self.bn2(self.conv2(out)))
        out = self.bn3(self.conv3(out))
        out += self.shortcut(x)
        preact = out
        out = F.relu(out)
        if self.is_last:
            return out, preact
        else:
            return out


class ResNet(nn.Module):
    def __init__(self, block, num_blocks, in_channel=3, zero_init_residual=False):
        super(ResNet, self).__init__()
        self.in_planes = 64

        self.conv1 = nn.Conv2d(in_channel, 64, kernel_size=3, stride=1, padding=1,
                               bias=False)
        self.bn1 = nn.BatchNorm2d(64)
        self.layer1 = self._make_layer(block, 64, num_blocks[0], stride=1)
        self.layer2 = self._make_layer(block, 128, num_blocks[1], stride=2)
        self.layer3 = self._make_layer(block, 256, num_blocks[2], stride=2)
        self.layer4 = self._make_layer(block, 512, num_blocks[3], stride=2)
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))

        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
            elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

        # Zero-initialize the last BN in each residual branch,
        # so that the residual branch starts with zeros, and each residual block behaves
        # like an identity. This improves the model by 0.2~0.3% according to:
        # https://arxiv.org/abs/1706.02677
        if zero_init_residual:
            for m in self.modules():
                if isinstance(m, Bottleneck):
                    nn.init.constant_(m.bn3.weight, 0)
                elif isinstance(m, BasicBlock):
                    nn.init.constant_(m.bn2.weight, 0)

    def _make_layer(self, block, planes, num_blocks, stride):
        strides = [stride] + [1] * (num_blocks - 1)
        layers = []
        for i in range(num_blocks):
            stride = strides[i]
            layers.append(block(self.in_planes, planes, stride))
            self.in_planes = planes * block.expansion
        return nn.Sequential(*layers)

    def forward(self, x):
        out = F.relu(self.bn1(self.conv1(x)))
        out = self.layer1(out)
        out = self.layer2(out)
        out = self.layer3(out)
        out = self.layer4(out)
        out = self.avgpool(out)
        out = torch.flatten(out, 1)
        return out

features_dim = 128

class ContrastiveModel(nn.Module):
    def __init__(self, backbone_model, backbone_dim = 512, head='mlp', features_dim=features_dim):
        super(ContrastiveModel, self).__init__()
        self.backbone = backbone_model
        self.backbone_dim = backbone_dim
        self.head = head

        self.contrastive_head = nn.Sequential(
                    nn.Linear(self.backbone_dim, self.backbone_dim),
                    nn.ReLU(), nn.Linear(self.backbone_dim, features_dim))

    def forward(self, x):
        features = self.contrastive_head(self.backbone(x))
        features = F.normalize(features, dim = 1)
        return features




In [6]:
!unzip -q drive/MyDrive/DS_dataset.zip

In [7]:
train_dir = 'train'
test_dir = 'test'

In [8]:
test_paths = []
for label in os.listdir(test_dir):
  label_path = os.path.join(test_dir, label)

  test_images_path = [(label, os.path.join(label_path, image)) for image in os.listdir(label_path)]
  test_paths.extend(test_images_path)

test_df = pd.DataFrame(test_paths, columns = ['label', 'path']).sample(frac = 1.)
len(test_df)

1000

In [9]:
train_images_paths = [(os.path.join(train_dir, file),) for file in os.listdir(train_dir)]

train_df = pd.DataFrame(train_images_paths, columns = ['path']).sample(frac = 1.)
len(train_df)

2669

In [10]:
class FlowersDataset(Dataset):

    def __init__(self, df, transforms, base_transforms):
        self._df = df
        self._transforms = transforms
        self._base_transforms = base_transforms


    def __len__(self):
        return len(self._df)

    def __getitem__(self, idx):

      row = self._df.iloc[idx]

      base_image_path = row['path']

      base_image1 = Image.open(base_image_path).convert('RGB')
      base_image2 = Image.open(base_image_path).convert('RGB')

      i1 = self._transforms(base_image1) if self._base_transforms is None else self._base_transforms(base_image1)
      i2 = self._transforms(base_image2)



      return base_image_path, i1,i2

In [11]:
image_size = 32

In [12]:
def get_mean_and_std(dataloader):

    channels_sum, channels_squared_sum, num_batches = 0, 0, 0
    for (_,data, _) in dataloader:
        channels_sum += torch.mean(data, dim=[0,2,3])
        channels_squared_sum += torch.mean(data**2, dim=[0,2,3])
        num_batches += 1

    mean = channels_sum / num_batches

    std = (channels_squared_sum / num_batches - mean ** 2) ** 0.5

    return mean, std


basic_transformations1 = transforms.Compose([
    transforms.Resize(size=(image_size,image_size)),
        transforms.ToTensor()
    ])
temp_dataset = FlowersDataset(train_df, basic_transformations1, basic_transformations1)
temp_dataloader = DataLoader(dataset=temp_dataset, batch_size=64)

mean, std = get_mean_and_std(temp_dataloader)
mean, std

(tensor([0.4608, 0.4237, 0.3002]), tensor([0.2704, 0.2407, 0.2657]))

In [13]:
simclr_train_augmentations = transforms.Compose([
            transforms.RandomResizedCrop(size = image_size, scale = [0.2, 1.0]),
            transforms.RandomHorizontalFlip(),
            transforms.RandomApply([
                transforms.ColorJitter(brightness = 0.4, contrast =0.4, saturation = 0.4, hue = 0.1)
            ], p= 0.8),
            transforms.RandomGrayscale(p = 0.2),
            transforms.ToTensor(),
            transforms.Normalize(mean = mean, std = std)
        ])

simclr_val_augmentations = transforms.Compose([
            transforms.CenterCrop(image_size),
            transforms.ToTensor(),
            transforms.Normalize(mean = mean, std = std)
            ])

simclr_basic_augmentations = transforms.Compose([
            transforms.CenterCrop(image_size),
            transforms.ToTensor(),
            transforms.Normalize(mean = mean, std = std)
            ])

In [14]:
def initialize_weights(m):
  if isinstance(m, nn.Conv2d):
      nn.init.kaiming_uniform_(m.weight.data,nonlinearity='relu')
      if m.bias is not None:
          nn.init.constant_(m.bias.data, 0)
  elif isinstance(m, nn.BatchNorm2d):
      nn.init.constant_(m.weight.data, 1)
      nn.init.constant_(m.bias.data, 0)
  elif isinstance(m, nn.Linear):
      nn.init.kaiming_uniform_(m.weight.data)
      nn.init.constant_(m.bias.data, 0)

In [15]:
train_dataset = FlowersDataset(train_df, simclr_train_augmentations, simclr_train_augmentations)

batch_size = 512
train_dataloader = DataLoader(dataset=train_dataset, batch_size=batch_size, shuffle = True, drop_last = True)

In [16]:
learning_rate = 0.1

criterion = SimCLRLoss(temperature = 0.1, device = device).to(device)
backbone_model = ResNet(BasicBlock, [2, 2, 2, 2])
model = ContrastiveModel(backbone_model).to(device)
model.apply(initialize_weights)
optimizer = torch.optim.SGD(model.parameters(), weight_decay = 0.0001, lr = learning_rate, nesterov = False, momentum = 0.9)


In [19]:
device

'cuda'

In [20]:
number_of_epocs = 400

In [27]:
for epoc in tqdm(range(number_of_epocs), desc = 'epocs'):

  losses = []

  for (_, images1, images2) in tqdm(train_dataloader, desc="batch", leave=False):
    b, c, h, w = images1.size()
    input_ = torch.cat([images1.unsqueeze(1), images2.unsqueeze(1)], dim=1)
    input_ = input_.view(-1, c, h, w).to(device)
    output = model(input_).view(b, 2, -1)
    loss = criterion(output)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()


    losses.append(loss.detach().cpu().item())

  mean_loss = round(np.mean(losses), 3)
  print('\n', epoc, mean_loss)
  if epoc > 0 and epoc % 100 == 0:
    torch.save(model.state_dict(), f'/content/drive/My Drive/model_{epoc}_{mean_loss}.pth')



epocs:   0%|          | 0/400 [00:00<?, ?it/s]
batch:   0%|          | 0/5 [00:00<?, ?it/s][A
batch:  20%|██        | 1/5 [00:12<00:51, 12.77s/it][A
batch:  40%|████      | 2/5 [00:16<00:23,  7.73s/it][A
batch:  60%|██████    | 3/5 [00:21<00:12,  6.17s/it][A
batch:  80%|████████  | 4/5 [00:25<00:05,  5.37s/it][A
batch: 100%|██████████| 5/5 [00:29<00:00,  5.01s/it][A
epocs:   0%|          | 1/400 [00:29<3:18:14, 29.81s/it]


 0 0.542



batch:   0%|          | 0/5 [00:00<?, ?it/s][A
batch:  20%|██        | 1/5 [00:04<00:17,  4.25s/it][A
epocs:   0%|          | 1/400 [00:34<3:50:14, 34.62s/it]


KeyboardInterrupt: ignored

In [29]:
#load from checkpoint

checkpoint = torch.load('/content/drive/My Drive/model_400_0.526.pth')

model.load_state_dict(checkpoint)
model.to(device)
backbone_model = model.backbone

In [30]:
base_dataset = FlowersDataset(train_df, simclr_basic_augmentations, simclr_basic_augmentations)

base_dataloader = DataLoader(dataset=base_dataset, batch_size=batch_size, shuffle = False, drop_last = False)

In [31]:
model.eval()

features = torch.FloatTensor(len(base_dataset), features_dim)

ptr = 0
images = []
with torch.no_grad():

  for (batch_images, images_paths, images1) in tqdm(base_dataloader, desc="batch", leave=False):

    images1 = images1.to(device)

    batch_features = model(images1)

    batch_size = images1.size(0)

    features[ptr: ptr+batch_size].copy_(batch_features.detach())
    ptr += batch_size
    images.extend(batch_images)



In [32]:
topk = 5
features = features.cpu().numpy()
n, dim = features.shape[0], features.shape[1]
index = faiss.IndexFlatIP(dim)
index = faiss.index_cpu_to_all_gpus(index)
index.add(features)
distances, indices = index.search(features, topk+1)

In [33]:
image_to_negibors = {}
for (i, image) in  enumerate(images):

  neighbors_indices = indices[i]

  image_to_negibors[image] = [images[neighbor_index] for neighbor_index in neighbors_indices if neighbor_index != i]



# Cluster

In [34]:
random_mirror = True

def ShearX(img, v):
    if random_mirror and random.random() > 0.5:
        v = -v
    return img.transform(img.size, PIL.Image.AFFINE, (1, v, 0, 0, 1, 0))

def ShearY(img, v):
    if random_mirror and random.random() > 0.5:
        v = -v
    return img.transform(img.size, PIL.Image.AFFINE, (1, 0, 0, v, 1, 0))

def Identity(img, v):
    return img

def TranslateX(img, v):
    if random_mirror and random.random() > 0.5:
        v = -v
    v = v * img.size[0]
    return img.transform(img.size, PIL.Image.AFFINE, (1, 0, v, 0, 1, 0))

def TranslateY(img, v):
    if random_mirror and random.random() > 0.5:
        v = -v
    v = v * img.size[1]
    return img.transform(img.size, PIL.Image.AFFINE, (1, 0, 0, 0, 1, v))

def TranslateXAbs(img, v):
    if random.random() > 0.5:
        v = -v
    return img.transform(img.size, PIL.Image.AFFINE, (1, 0, v, 0, 1, 0))

def TranslateYAbs(img, v):
    if random.random() > 0.5:
        v = -v
    return img.transform(img.size, PIL.Image.AFFINE, (1, 0, 0, 0, 1, v))

def Rotate(img, v):
    if random_mirror and random.random() > 0.5:
        v = -v
    return img.rotate(v)

def AutoContrast(img, _):
    return PIL.ImageOps.autocontrast(img)

def Invert(img, _):
    return PIL.ImageOps.invert(img)

def Equalize(img, _):
    return PIL.ImageOps.equalize(img)

def Solarize(img, v):
    return PIL.ImageOps.solarize(img, v)

def Posterize(img, v):
    v = int(v)
    return PIL.ImageOps.posterize(img, v)

def Contrast(img, v):
    return PIL.ImageEnhance.Contrast(img).enhance(v)

def Color(img, v):
    return PIL.ImageEnhance.Color(img).enhance(v)

def Brightness(img, v):
    return PIL.ImageEnhance.Brightness(img).enhance(v)

def Sharpness(img, v):
    return PIL.ImageEnhance.Sharpness(img).enhance(v)

def augment_list():
    l = [
        (Identity, 0, 1),
        (AutoContrast, 0, 1),
        (Equalize, 0, 1),
        (Rotate, -30, 30),
        (Solarize, 0, 256),
        (Color, 0.05, 0.95),
        (Contrast, 0.05, 0.95),
        (Brightness, 0.05, 0.95),
        (Sharpness, 0.05, 0.95),
        (ShearX, -0.1, 0.1),
        (TranslateX, -0.1, 0.1),
        (TranslateY, -0.1, 0.1),
        (Posterize, 4, 8),
        (ShearY, -0.1, 0.1),
    ]
    return l

class Augment:
    def __init__(self, n):
        self.n = n
        self.augment_list = augment_list()

    def __call__(self, img):
        ops = random.choices(self.augment_list, k=self.n)
        for op, minval, maxval in ops:
            val = (random.random()) * float(maxval - minval) + minval
            img = op(img, val)

        return img

class Cutout(object):
    def __init__(self, n_holes, length, random=False):
        self.n_holes = n_holes
        self.length = length
        self.random = random

    def __call__(self, img):
        h = img.size(1)
        w = img.size(2)
        length = random.randint(1, self.length)
        mask = np.ones((h, w), np.float32)

        for n in range(self.n_holes):
            y = np.random.randint(h)
            x = np.random.randint(w)

            y1 = np.clip(y - length // 2, 0, h)
            y2 = np.clip(y + length // 2, 0, h)
            x1 = np.clip(x - length // 2, 0, w)
            x2 = np.clip(x + length // 2, 0, w)

            mask[y1: y2, x1: x2] = 0.

        mask = torch.from_numpy(mask)
        mask = mask.expand_as(img)
        img = img * mask

        return img

In [35]:
cluster_train_transformations = transforms.Compose([
            transforms.RandomHorizontalFlip(),
            transforms.RandomCrop(image_size),
            Augment(4),
            transforms.ToTensor(),
            transforms.Normalize(mean = mean, std = std),
            Cutout(n_holes = 1, length = 16, random = True)])

cluster_val_transformations = transforms.Compose([
            transforms.CenterCrop(image_size),
            transforms.ToTensor(),
            transforms.Normalize(mean = mean, std = std)
            ])

In [36]:
class FlowersNeighborsDataset(Dataset):

    def __init__(self, image_to_neighbors, anchor_transform, neighbor_transform):
        self._image_and_neighbors = [(image, neighbors) for (image, neighbors) in image_to_neighbors.items()]
        self._anchor_transform = anchor_transform
        self._neighbor_transform = neighbor_transform


    def __len__(self):
        return len(self._image_and_neighbors)

    def __getitem__(self, idx):

      image_path, neighbors = self._image_and_neighbors[idx]
      neighbor_path = np.random.choice(neighbors, 1)[0]


      anchor_image = Image.open(image_path).convert('RGB')
      neighbor_image = Image.open(neighbor_path).convert('RGB')

      anchor_image = self._anchor_transform(anchor_image)
      neighbor_image = self._neighbor_transform(neighbor_image)



      return anchor_image, neighbor_image

In [37]:
class ClusteringModel(nn.Module):

    def __init__(self, backbone_model, features_dim, nclusters, nheads=1):
        super(ClusteringModel, self).__init__()
        self.backbone = backbone_model
        self.backbone_dim = features_dim
        self.nheads = nheads
        assert(isinstance(self.nheads, int))
        assert(self.nheads > 0)
        self.cluster_head = nn.ModuleList([nn.Linear(self.backbone_dim, nclusters) for _ in range(self.nheads)])

    def forward(self, x, forward_pass='default'):
        if forward_pass == 'default':
            features = self.backbone(x)
            out = [cluster_head(features) for cluster_head in self.cluster_head]

        elif forward_pass == 'backbone':
            out = self.backbone(x)

        elif forward_pass == 'head':
            out = [cluster_head(x) for cluster_head in self.cluster_head]

        elif forward_pass == 'return_all':
            features = self.backbone(x)
            out = {'features': features, 'output': [cluster_head(features) for cluster_head in self.cluster_head]}

        else:
            raise ValueError('Invalid forward pass {}'.format(forward_pass))

        return out

In [38]:
def entropy(x, input_as_probabilities):
    """
    Helper function to compute the entropy over the batch

    input: batch w/ shape [b, num_classes]
    output: entropy value [is ideally -log(num_classes)]
    """

    if input_as_probabilities:
        x_ =  torch.clamp(x, min = EPS)
        b =  x_ * torch.log(x_)
    else:
        b = F.softmax(x, dim = 1) * F.log_softmax(x, dim = 1)

    if len(b.size()) == 2: # Sample-wise entropy
        return -b.sum(dim = 1).mean()
    elif len(b.size()) == 1: # Distribution-wise entropy
        return - b.sum()
    else:
        raise ValueError('Input tensor is %d-Dimensional' %(len(b.size())))

class SCANLoss(nn.Module):
    def __init__(self, entropy_weight = 2.0):
        super(SCANLoss, self).__init__()
        self.softmax = nn.Softmax(dim = 1)
        self.bce = nn.BCELoss()
        self.entropy_weight = entropy_weight # Default = 2.0

    def forward(self, anchors, neighbors):
        """
        input:
            - anchors: logits for anchor images w/ shape [b, num_classes]
            - neighbors: logits for neighbor images w/ shape [b, num_classes]

        output:
            - Loss
        """
        # Softmax
        b, n = anchors.size()
        anchors_prob = self.softmax(anchors)
        positives_prob = self.softmax(neighbors)

        # Similarity in output space
        similarity = torch.bmm(anchors_prob.view(b, 1, n), positives_prob.view(b, n, 1)).squeeze()
        ones = torch.ones_like(similarity)
        consistency_loss = self.bce(similarity, ones)

        # Entropy loss
        entropy_loss = entropy(torch.mean(anchors_prob, 0), input_as_probabilities = True)

        # Total loss
        total_loss = consistency_loss - self.entropy_weight * entropy_loss

        return total_loss, consistency_loss, entropy_loss

In [39]:
EPS=1e-8
number_of_heads = 1

clustering_model = ClusteringModel(backbone_model, 512, 5, number_of_heads).to(device)

#update head only
for name, param in clustering_model.named_parameters():
  if 'cluster_head' in name:
    param.requires_grad = True
  else:
    param.requires_grad = False

params = list(filter(lambda p: p.requires_grad, clustering_model.parameters()))
assert(len(params) == 2 * number_of_heads)


In [40]:
clustering_optimizer = torch.optim.SGD(params, lr = 0.001, weight_decay = 0.0001)
clustering_criterion = SCANLoss(entropy_weight = 5.0).to(device)

In [41]:
cluster_batch_size = 128

cluster_train_dataset = FlowersNeighborsDataset(image_to_negibors, cluster_train_transformations, cluster_train_transformations)
cluster_train_dataloader = DataLoader(dataset=cluster_train_dataset, batch_size=cluster_batch_size, shuffle = True, drop_last = True)


In [71]:
# for (anchors, neighbors) in cluster_train_dataloader:

#   anchors = anchors.to(device)
#   neighbors = neighbors.to(device)

#   with torch.no_grad():
#     anchors_features = clustering_model(anchors, forward_pass='backbone')
#     neighbors_features = clustering_model(neighbors, forward_pass='backbone')

#   anchors_output = clustering_model(anchors_features, forward_pass='head')
#   neighbors_output = clustering_model(neighbors_features, forward_pass='head')

#   break


In [42]:
for epoc in tqdm(range(101), desc="Epoc", leave=False):

  for (anchors, neighbors) in tqdm(cluster_train_dataloader, desc="Batch", leave=False):

    anchors = anchors.to(device)
    neighbors = neighbors.to(device)

    with torch.no_grad():
      anchors_features = clustering_model(anchors, forward_pass='backbone')
      neighbors_features = clustering_model(neighbors, forward_pass='backbone')

    anchors_output = clustering_model(anchors_features, forward_pass='head')
    neighbors_output = clustering_model(neighbors_features, forward_pass='head')
    total_loss, consistency_loss, entropy_loss = [], [], []

    for (i, (anchors_output_subhead, neighbors_output_subhead)) in enumerate(zip(anchors_output, neighbors_output)):

      total_loss_, consistency_loss_, entropy_loss_ = clustering_criterion(anchors_output_subhead,
                                                                          neighbors_output_subhead)
      total_loss.append(total_loss_)
      consistency_loss.append(consistency_loss_)
      entropy_loss.append(entropy_loss_)

    total_loss = torch.sum(torch.stack(total_loss, dim=0))

    clustering_optimizer.zero_grad()
    total_loss.backward()
    clustering_optimizer.step()

    if epoc > 0 and epoc % 20 == 0:
      torch.save(clustering_model.state_dict(), f'/content/drive/My Drive/model_scan_{epoc}.pth')


Epoc:   0%|          | 0/101 [00:00<?, ?it/s]
Batch:   0%|          | 0/20 [00:00<?, ?it/s][A
Batch:   5%|▌         | 1/20 [00:00<00:18,  1.04it/s][A
Batch:  10%|█         | 2/20 [00:01<00:16,  1.08it/s][A
Batch:  15%|█▌        | 3/20 [00:02<00:15,  1.10it/s][A
Batch:  20%|██        | 4/20 [00:03<00:14,  1.09it/s][A
Batch:  25%|██▌       | 5/20 [00:04<00:13,  1.09it/s][A
Batch:  30%|███       | 6/20 [00:05<00:12,  1.10it/s][A
Batch:  35%|███▌      | 7/20 [00:06<00:11,  1.09it/s][A
Batch:  40%|████      | 8/20 [00:07<00:11,  1.09it/s][A
Batch:  45%|████▌     | 9/20 [00:08<00:09,  1.11it/s][A
Batch:  50%|█████     | 10/20 [00:09<00:08,  1.11it/s][A
Batch:  55%|█████▌    | 11/20 [00:10<00:08,  1.10it/s][A
Batch:  60%|██████    | 12/20 [00:10<00:07,  1.09it/s][A
Batch:  65%|██████▌   | 13/20 [00:11<00:06,  1.08it/s][A
Batch:  70%|███████   | 14/20 [00:12<00:05,  1.08it/s][A
Batch:  75%|███████▌  | 15/20 [00:13<00:04,  1.09it/s][A
Batch:  80%|████████  | 16/20 [00:14<00:03, 

In [43]:
cluster_model_checkpoint_path = '/content/drive/My Drive/model_scan_80.pth'
checkpoint = torch.load(cluster_model_checkpoint_path)
clustering_model.load_state_dict(checkpoint)
clustering_model.to(device)

# Self Label

In [44]:
class MaskedCrossEntropyLoss(nn.Module):
    def __init__(self):
        super(MaskedCrossEntropyLoss, self).__init__()

    def forward(self, input, target, mask, weight, reduction='mean'):
        if not (mask != 0).any():
            raise ValueError('Mask in MaskedCrossEntropyLoss is all zeros.')
        target = torch.masked_select(target, mask)
        b, c = input.size()
        n = target.size(0)
        input = torch.masked_select(input, mask.view(b, 1)).view(n, c)
        return F.cross_entropy(input, target, weight = weight, reduction = reduction)

class ConfidenceBasedCE(nn.Module):
    def __init__(self, threshold, apply_class_balancing):
        super(ConfidenceBasedCE, self).__init__()
        self.loss = MaskedCrossEntropyLoss()
        self.softmax = nn.Softmax(dim = 1)
        self.threshold = threshold
        self.apply_class_balancing = apply_class_balancing

    def forward(self, anchors_weak, anchors_strong):
        """
        Loss function during self-labeling

        input: logits for original samples and for its strong augmentations
        output: cross entropy
        """
        # Retrieve target and mask based on weakly augmentated anchors
        weak_anchors_prob = self.softmax(anchors_weak)
        max_prob, target = torch.max(weak_anchors_prob, dim = 1)
        mask = max_prob > self.threshold
        b, c = weak_anchors_prob.size()
        target_masked = torch.masked_select(target, mask.squeeze())
        n = target_masked.size(0)

        # Inputs are strongly augmented anchors
        input_ = anchors_strong

        # Class balancing weights
        if self.apply_class_balancing:
            idx, counts = torch.unique(target_masked, return_counts = True)
            freq = 1/(counts.float()/n)
            weight = torch.ones(c).cuda()
            weight[idx] = freq

        else:
            weight = None

        # Loss
        loss = self.loss(input_, target, mask, weight = weight, reduction='mean')

        return loss


In [45]:
self_label_criterion = ConfidenceBasedCE(0.99, apply_class_balancing = False).to(device)
self_label_optimizer = torch.optim.Adam(clustering_model.parameters(), lr = 0.0001, weight_decay = 0.0001)


self_label_train_transformations = transforms.Compose([
            transforms.RandomHorizontalFlip(),
            transforms.RandomCrop(image_size),
            Augment(4),
            transforms.ToTensor(),
            transforms.Normalize(mean = mean, std = std),
            Cutout(n_holes = 1, length =16, random = True)])

self_label_val_transformations = transforms.Compose([
            transforms.CenterCrop(image_size),
            transforms.ToTensor(),
            transforms.Normalize(mean = mean, std = std)
            ])

self_label_train_dataset = FlowersDataset(train_df, self_label_train_transformations, self_label_val_transformations)
self_label_train_dataloader = DataLoader(dataset=self_label_train_dataset, batch_size=256)


self_label_val_dataset = FlowersDataset(test_df, self_label_val_transformations, self_label_val_transformations)
self_label_val_dataloader = DataLoader(dataset=self_label_val_dataset, batch_size=256)

In [48]:
# for epoc in tqdm(range(201), desc="Epoc", leave=False):

#   losses = []

#   for (_, images_augmented, images) in tqdm(self_label_train_dataloader, desc="Batch", leave=False):

#         images_augmented = images_augmented.to(device)
#         images = images.to(device)


#         with torch.no_grad():
#           output = clustering_model(images)[0]

#         output_augmented = clustering_model(images_augmented)[0]

#         loss = self_label_criterion(output, output_augmented)


#         self_label_optimizer.zero_grad()
#         loss.backward()
#         self_label_optimizer.step()
#         losses.update(loss.detach().cpu().item())

#   mean_loss = round(np.mean(losses), 3)

#   print(epoc, mean_loss)
#   if epoc > 0 and epoc % 50 == 0:
#     torch.save(clustering_model.state_dict(), f'/content/drive/My Drive/model_self_label_{epoc}.pth')



torch.Size([256, 3, 32, 32])

In [53]:
eval_dataset = FlowersDataset(test_df, simclr_basic_augmentations, simclr_basic_augmentations)

eval_dataloader = DataLoader(dataset=eval_dataset, batch_size=batch_size, shuffle = False, drop_last = False)

In [66]:
clustering_model.eval()
image_embeddings_and_path = []
with torch.no_grad():

  for (images_paths, images, _) in tqdm(eval_dataloader, desc="batch", leave=False):

    images = images.to(device)

    model_results = clustering_model(images)[0]

    images_embeddings = model_results.detach().cpu().numpy()

    for (index, image_path) in enumerate(images_paths):
      image_embeddings_and_path.append((images_embeddings[index], image_path))

len(image_embeddings_and_path)



1000

In [69]:
results_df = pd.DataFrame(image_embeddings_and_path, columns = ['embeddings', 'path'])
results_df = pd.merge(results_df, test_df, on = ['path']).drop(columns = ['path'])
results_df.shape[0]

1000

In [71]:
from sklearn.metrics.pairwise import cosine_similarity
cosine_df = pd.DataFrame(cosine_similarity(results_df['embeddings'].tolist(),dense_output=True))
cosine_df['label'] = results_df['label']

In [74]:
from sklearn.cluster import KMeans

features_df = cosine_df.drop(columns = ['label'])

kmeans = KMeans(
    init="k-means++",
    n_clusters=5
)

kmeans.fit(features_df)

clusters = kmeans.predict(features_df)
results_df['cluster'] = clusters

results_df.groupby(['label', 'cluster']).size().to_frame('size').reset_index().sort_values(by = ['label', 'cluster'])



Unnamed: 0,label,cluster,size
0,daisy,0,48
1,daisy,1,37
2,daisy,2,24
3,daisy,3,45
4,daisy,4,46
5,dandelion,0,58
6,dandelion,1,21
7,dandelion,2,16
8,dandelion,3,54
9,dandelion,4,51
