# Lab1 - Self-Supervised Learning (SSL)

In this course, we will focus on the main steps of implementing [SimCLR](https://proceedings.mlr.press/v119/chen20j) in PyTorch.

1. Image Preprocessing and Augmentation
2. NT-Xent Loss
3. Leave-one-out KNN

## 1. Image Preprocessing and Augmentation

In [1]:
from torchvision import transforms, datasets
from IPython.display import display
from PIL import Image
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.data as data
import torchvision
import os
from copy import deepcopy
from torchsummary import summary
import torch.utils.data as data
import numpy as np

In [2]:
print(torch.__version__)
print(torch.version.cuda)
device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")
#print("Device:", device)

1.13.0
11.6


## Get Training and Testing Data

In [3]:
root_train = '.\\data\\unlabeled\\'
root_test = '.\\data\\test\\'
BATCH_SIZE = 256
TEMPERATURE = 0.07
EPOCH = 500
CHANNEL = 3

# Load Training Data
train_data = []
train_dir_path = os.path.dirname(root_train)
all_file_name = os.listdir(train_dir_path)
for name in all_file_name:
    train_data.append( Image.open(os.path.join( root_train,name)) )

#Load Test Data
transform_test = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,))
])

test_dir_path = os.path.dirname(root_test)
test_img_data = datasets.ImageFolder( test_dir_path, transform = transform_test )
test_loader = data.DataLoader( test_img_data, batch_size=BATCH_SIZE, shuffle=True )
print( test_img_data[0][0].shape )


FileNotFoundError: [WinError 3] 系統找不到指定的路徑。: '.\\data\\unlabeled'

## Testing augmentation

In [8]:
trans = transforms.Compose([
    transforms.RandomHorizontalFlip(p=0.5),
    transforms.RandomVerticalFlip(p=0.5),
    transforms.RandomRotation(60),
    transforms.RandomResizedCrop(size=96, scale=(0.2, 0.6)),
    transforms.GaussianBlur(kernel_size=9,sigma=(0.1, 0.8)),
])
#display( trans(train_data[0]) )
print("Jimmy is handsome monkey")

Jimmy is handsome monkey


In [None]:
NUM_IMAGES = len( all_file_name )
TEST_NUM_IMAGE =  len( test_img_data )

---

## 2. The (N)ormalized (T)emperature-scaled (C)ross (Ent)ropy Loss (NT-Xent)

### Notation Definition

- Let $u$ and $v$ be the encoded features of an image in different views (different augmentation).

- The similarity of $u$ and $v$ is defined as $\text{sim}(u,v)=\frac{u^Tv}{\vert u\vert\vert v\vert}$.

- For a batch of $N$ images, there are $2N$ encoded features:

    $$
    \{z_i\}_{i=1}^{2N}=\{u_1,u_2,\cdots,u_N,v_1,v_2,\cdots,v_N\}
    $$

### Designs a loss function to learn that the feature $u_i$ can figure out $v_i$ from $(2N-1)$ features.

Let $z_i$ be the reference feature. We can use cross entropy loss (negative log softmax) to make $z_i$ and $z_j$ closer and make $z_i$ father away from $z_k,\forall k\neq j$ at the same time.

$$
\mathcal{L}_{i,j}=-\log\Bigg(\frac{\exp\big(\frac{\text{sim}(z_i,z_j)}{\tau}\big)}{\sum_{k=1}^{2N}\mathbb 1[k\neq i]\exp\big(\frac{\text{sim}(z_i,z_k)}{\tau}\big)}\Bigg)
$$

where $\tau \le 1$ is a constant to scale up the output range of $\text{sim}(\cdot,\cdot)$ from $[-1, 1]$ to $[\frac{-1}{\tau},\frac{1}{\tau}]$.

Consider all ordered pairs $(u_i, v_i)$ and $(v_i, u_i)$, $\forall i \in \{1,\cdots,N\}$

$$
\mathcal{L}_{\text{NT-Xent}}=\frac{1}{2N}\sum_{i=1}^N \mathcal{L}_{i,i+N} + \mathcal{L}_{i+N,i}
$$

### Implementation
In following, we provide an implementation of batchify NT-Xent loss which does not contain any **for loop** at python level.

In [None]:

def xt_xent(
    u: torch.Tensor,                               # [N, C]
    v: torch.Tensor,                               # [N, C]
    temperature: float = TEMPERATURE,
):
    """
    N: batch size
    C: feature dimension
    """
    N, C = u.shape
    #print( 'N, C:',N,C )
    # torch.cat是将两个张量（tensor）拼接在一起，cat是concatnate的意思 -> dim = 0接在下面
    z = torch.cat([u, v], dim=0)                   # [2N, C]
    z = F.normalize(z, p=2, dim=1)                 # [2N, C]
    s = torch.matmul(z, z.t()) / temperature       # [2N, 2N] similarity matrix
    mask = torch.eye(2 * N).bool().to(z.device)    # [2N, 2N] identity matrix
    s = torch.masked_fill(s, mask, -float('inf'))  # fill the diagonal with negative infinity -> k!=i 的部分
    label = torch.cat([                            # [2N]
        torch.arange(N, 2 * N),                    # {N, ..., 2N - 1}
        torch.arange(N),                           # {0, ..., N - 1}
    ]).to(z.device)
    #print('label:', label)
    #print('s:i', s)
    loss = F.cross_entropy(s, label)               # NT-Xent loss
    return loss

### Transform order

In [None]:
#transforms.RandomOrder
transform = transforms.Compose([
    transforms.RandomHorizontalFlip(p=0.5),
    transforms.RandomVerticalFlip(p=0.5),
    transforms.RandomRotation(90),
    transforms.RandomResizedCrop(size=96, scale=(0.2, 1.0)),
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,)),
])



# SimCLR CNN
Define a single layer CNN as the image encoder.

In [None]:
class Identity(nn.Module):
    def __init__(self):
        super(Identity, self).__init__()
    def forward(self, x):
        return x

class SimCLR(nn.Module):
    def __init__(self, projhead = 256, emd_dim = 512):
        super().__init__()
        resnet18 = torchvision.models.resnet18(weights=None)
        resnet18.fc = Identity()
        self.encoder = resnet18
        self.projection = nn.Sequential(
            nn.Linear(emd_dim, projhead),
            nn.ReLU(),
            nn.Linear(projhead, projhead)
        )
    def forward(self, x):
        encoding = self.encoder(x)
        projection = self.projection(encoding)  
        return projection

In [None]:
class Identity(nn.Module):
    def __init__(self):
        super(Identity, self).__init__()
    def forward(self, x):
        return x

class LinearEvaluation(nn.Module):
    def __init__(self, model):
        super().__init__()
        simclr = deepcopy(model)
        simclr.projection = Identity()
        self.simclr = simclr
        for param in self.simclr.parameters():
            param.requires_grad = False
    def forward(self, x):
        encoding = self.simclr(x)
        #pred = self.linear(encoding) 
        return encoding 

In [None]:
from torchsummary import summary

simclr_model = SimCLR()
simclr_model.to(device)

SimCLR(
  (encoder): ResNet(
    (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
    (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (relu): ReLU(inplace=True)
    (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
    (layer1): Sequential(
      (0): BasicBlock(
        (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (relu): ReLU(inplace=True)
        (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
      (1): BasicBlock(
        (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_runn

---

## 3. Leave-one-out Cross Validation with K-Nearest Neighbors (KNN)
- Leave-one-out Cross Validation
    
    For each data, the other data are training data.

- KNN
    
    An object is classified by a plurality vote of its $K$ neighbors in training data.

### Implementation
For KNN, the space complexity is $O(N^2\times C)$ where $N$ is the number of data and $C$ is the dimension of feature size. We provide a batchify implementation to reduce memory footprint.

In [None]:
def KNN(emb, cls, batch_size, Ks=[1, 10, 50, 100]):
    """Apply KNN for different K and return the maximum acc"""
    preds = []
    mask = torch.eye(batch_size).bool().to(emb.device)
    mask = F.pad(mask, (0, len(emb) - batch_size))
    for batch_x in torch.split(emb, batch_size):
        dist = torch.norm(
            batch_x.unsqueeze(1) - emb.unsqueeze(0), dim=2, p="fro")
        now_batch_size = len(batch_x)
        mask = mask[:now_batch_size]
        dist = torch.masked_fill(dist, mask, float('inf'))
        # update mask
        mask = F.pad(mask[:, :-now_batch_size], (now_batch_size, 0))
        pred = []
        for K in Ks:
            knn = dist.topk(K, dim=1, largest=False).indices
            knn = cls[knn].cpu()
            pred.append(torch.mode(knn).values)
        pred = torch.stack(pred, dim=0)
        preds.append(pred)
    preds = torch.cat(preds, dim=1)
    accs = [(pred == cls.cpu()).float().mean().item() for pred in preds]
    return max(accs)

## [ First ] Training

In [None]:
#Extract feature
max_acc = 0.97
lr = 0.001 
optimizer = torch.optim.Adam(simclr_model.parameters(), lr=lr)
simclr_model.to(device)
loss_list = []
for epoch in range(1, EPOCH + 1):
    total_loss = 0
    x1 = torch.stack([transform(train_data[idx]) for idx in range(NUM_IMAGES)])
    x2 = torch.stack([transform(train_data[idx]) for idx in range(NUM_IMAGES)])
    train_loader_1 = data.DataLoader(x1, batch_size=BATCH_SIZE, shuffle=False)
    train_loader_2 = data.DataLoader(x2, batch_size=BATCH_SIZE, shuffle=False)
    for batch_idx, (data_1,data_2) in enumerate(zip(train_loader_1,train_loader_2)):
        data_1,data_2 = data_1.to(device), data_2.to(device)
        #print( data_1.shape )
        optimizer.zero_grad()
        u = simclr_model(data_1)
        v = simclr_model(data_2)
        loss = xt_xent(u, v)
        loss.backward()
        optimizer.step()
        # 查看 loss 情況 ###
        total_loss +=loss.item()
    print(f'Epoch {epoch}: ' +
            f'  Loss: {total_loss/29:.6f}')
    #### Validation 狀況 ###
    val_model = deepcopy(simclr_model)
    eval_model = LinearEvaluation(val_model).to(device)
    eval_model.to(device)
    for batch_idx, (val_data, label) in enumerate(test_loader):
        val_data,label = val_data.to(device), label.to(device)
        emd_data = eval_model(val_data)
        if batch_idx == 0:
            embed = emd_data
            labels = label
        else:
            embed = torch.cat((embed,emd_data),0)
            labels = torch.cat((labels,label),0)
            #print('EBD SHAPE',embed.shape)
            #print('Label',labels.shape)
    acc = KNN(embed, labels, batch_size=BATCH_SIZE)
    print("Val Accuracy: %.5f" % acc)
    if acc >= max_acc:
        max_acc = acc
        torch.save(simclr_model,'./best.pth')
        
    ### 儲存 checkpoint

## [ Second ] Training

In [None]:
#Extract feature
max_acc = 0.974
simclr_model = torch.load('./best.pth')
lr = 1e-5
optimizer = torch.optim.Adam(simclr_model.parameters(), lr=lr)
simclr_model.to(device)
loss_list = []
for epoch in range(1, EPOCH + 1):
    total_loss = 0
    x1 = torch.stack([transform(train_data[idx]) for idx in range(NUM_IMAGES)])
    x2 = torch.stack([transform(train_data[idx]) for idx in range(NUM_IMAGES)])
    train_loader_1 = data.DataLoader(x1, batch_size=BATCH_SIZE, shuffle=False)
    train_loader_2 = data.DataLoader(x2, batch_size=BATCH_SIZE, shuffle=False)
    for batch_idx, (data_1,data_2) in enumerate(zip(train_loader_1,train_loader_2)):
        data_1,data_2 = data_1.to(device), data_2.to(device)
        #print( data_1.shape )
        optimizer.zero_grad()
        u = simclr_model(data_1)
        v = simclr_model(data_2)
        loss = xt_xent(u, v)
        loss.backward()
        optimizer.step()
        # 查看 loss 情況 ###
        total_loss +=loss.item()
    print(f'Epoch {epoch}: ' +
            f'  Loss: {total_loss/29:.6f}')
    #### Validation 狀況 ###
    val_model = deepcopy(simclr_model)
    eval_model = LinearEvaluation(val_model).to(device)
    eval_model.to(device)
    for batch_idx, (val_data, label) in enumerate(test_loader):
        val_data,label = val_data.to(device), label.to(device)
        emd_data = eval_model(val_data)
        if batch_idx == 0:
            embed = emd_data
            labels = label
        else:
            embed = torch.cat((embed,emd_data),0)
            labels = torch.cat((labels,label),0)
            #print('EBD SHAPE',embed.shape)
            #print('Label',labels.shape)
    acc = KNN(embed, labels, batch_size=BATCH_SIZE)
    print("Val Accuracy: %.5f" % acc)
    if acc >= max_acc:
        max_acc = acc
        torch.save(simclr_model,'./best.pth')

KeyboardInterrupt: 

# Testing good model 


In [None]:
###TESTING GOOD MODEL###
t = torch.load('./best.pth')
t.to(device)
a = deepcopy(t)
b = LinearEvaluation(a).to(device)
#b.load_state_dict(torch.load('./model_checkpoint/simclr_model_new.pt'))
b.to(device)
test_loader = data.DataLoader(test_img_data, batch_size=BATCH_SIZE, shuffle=True)
for batch_idx, (val_data, label) in enumerate(test_loader):
    val_data,label = val_data.to(device), label.to(device)
    emd_data = b(val_data)
    if batch_idx == 0:
        embed = emd_data
        labels = label
    else:
        embed = torch.cat((embed,emd_data),0)
        labels = torch.cat((labels,label),0)
        print('EBD SHAPE',embed.shape)
                #print('Label',labels.shape)
acc = KNN(embed, labels, batch_size=BATCH_SIZE)
print("Val Accuracy: %.5f" % acc)


EBD SHAPE torch.Size([500, 512])
Val Accuracy: 0.98400


## Saving numpy file

In [None]:
##SAVE NUMPY file##
train_ebd = torch.stack([transform_test(train_data[idx]) for idx in range(NUM_IMAGES)])
train_ebd_loader = data.DataLoader(train_ebd, batch_size=BATCH_SIZE, shuffle=False)
for batch_idx, data_ebd in enumerate(train_ebd_loader):
    data_ebd = data_ebd.to(device)
    train_ebd_data =  b(data_ebd)
    if batch_idx == 0:
        t_embed = train_ebd_data.cpu()
    else:
        t_embed = torch.cat((t_embed.cpu(),train_ebd_data.cpu()),0)
print('EBD SHAPE',t_embed.shape)
np.save('310581040.npy',t_embed)
embedding = np.load('310581040.npy')
print(embedding.dtype)
print(embedding.shape)

EBD SHAPE torch.Size([7294, 512])
float32
(7294, 512)
