## 練習自監督學習，內含四種自監督學習演算法SimCLR, MoCo, BYOL, BarlowTwins（也包含最基本的監督式學習訓練方法）

In [None]:
import time
import math
import numpy as np
import os
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import pandas as pd
import torchvision
import torchvision.datasets as datasets
from torchvision.models.resnet import resnet50, resnet18
from PIL import Image
from torchvision import transforms
from torch.utils.data import DataLoader
from tqdm import tqdm




In [None]:
# Pick the GPU when PyTorch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'
print(device)

! nvidia-smi

cuda
Thu Dec  2 11:52:51 2021       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 495.44       Driver Version: 460.32.03    CUDA Version: 11.2     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|   0  Tesla K80           Off  | 00000000:00:04.0 Off |                    0 |
| N/A   36C    P8    27W / 149W |      3MiB / 11441MiB |      0%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+
                                                                               
+-----------------------------------------------------------------------------+
| P

In [None]:
# Image transforms. Self-supervised training needs much stronger augmentation
# than supervised training.

# Per-channel mean/std used by every pipeline below. Normalize holds no
# per-call state, so a single instance can be shared by all three pipelines.
_cifar10_normalize = transforms.Normalize([0.4914, 0.4822, 0.4465], [0.2023, 0.1994, 0.2010])

# SSL: random resized 32x32 crop, horizontal flip, colour jitter (applied with
# probability 0.8) and random grayscale (probability 0.2).
SSL_train_transform = transforms.Compose([
    transforms.RandomResizedCrop(32),
    transforms.RandomHorizontalFlip(p=0.5),
    transforms.RandomApply([transforms.ColorJitter(0.4, 0.4, 0.4, 0.1)], p=0.8),
    transforms.RandomGrayscale(p=0.2),
    transforms.ToTensor(),
    _cifar10_normalize,
])

# Supervised: horizontal flip only.
supervised_train_transform = transforms.Compose([
    transforms.RandomHorizontalFlip(p=0.5),
    transforms.ToTensor(),
    _cifar10_normalize,
])

# Evaluation: no augmentation, just tensor conversion + normalisation.
test_transform = transforms.Compose([
    transforms.ToTensor(),
    _cifar10_normalize,
])


In [None]:
# # 印出圖片看看transform的作用

# import matplotlib.pyplot as plt

# def imshow(img):
#     img[0] = img[0] * 0.2023 + 0.4914    
#     img[1] = img[1] * 0.1994 + 0.4822  
#     img[2] = img[2] * 0.2010 + 0.4465    
#     npimg = img.numpy()
#     plt.imshow(np.transpose(npimg, (1, 2, 0)))
#     plt.show()

# batch_size = 5
# stl10_classes = ('airplane', 'bird', 'car', 'cat', 'deer', 'dog', 'horse', 'monkey', 'ship', 'truck')
# train_data_temp = datasets.STL10(root='./dataset', split='train', transform=train_transform, download=True)
# train_loader = DataLoader(train_data_temp, batch_size=batch_size, shuffle=True, num_workers=2, pin_memory=True, drop_last=True)

# # get some random training images
# dataiter = iter(train_loader)
# images, labels = dataiter.next()

# # show images
# # imshow(torchvision.utils.make_grid(images)) # 32*32 pixel
# for im in images:
#   imshow(im)
# # print labels
# print(' '.join('%5s' % stl10_classes[labels[j]] for j in range(batch_size)))

In [None]:
# Model definitions
class SSL_ResNet50(nn.Module):
    """ResNet-50 encoder ``f`` followed by an MLP projection head ``g``.

    The torchvision ResNet-50 is modified as follows:
    - its ``conv1`` is replaced by a 3x3, stride-1 convolution;
    - every ``nn.Linear`` and ``nn.MaxPool2d`` child is dropped.
    The flattened encoder output is 2048-d, which is the projection head's
    input size.

    Args:
        feature_dim: output size of the projection head.

    ``forward(x)`` returns ``(feature, projection)``, both L2-normalised along
    the last dimension.
    """

    def __init__(self, feature_dim=128):
        super().__init__()

        backbone = resnet50()
        kept = []
        for child_name, child in backbone.named_children():
            if child_name == 'conv1':
                # Small stem: 3x3 kernel, stride 1.
                child = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False)
            if isinstance(child, (nn.Linear, nn.MaxPool2d)):
                continue  # drop the classifier and max-pooling layers
            kept.append(child)
        # encoder
        self.f = nn.Sequential(*kept)
        # projection head: 2048 -> 512 -> feature_dim
        self.g = nn.Sequential(
            nn.Linear(2048, 512, bias=False),
            nn.BatchNorm1d(512),
            nn.ReLU(inplace=True),
            nn.Linear(512, feature_dim, bias=True),
        )

    def forward(self, x):
        feature = torch.flatten(self.f(x), start_dim=1)
        projection = self.g(feature)
        return F.normalize(feature, dim=-1), F.normalize(projection, dim=-1)

class SSL_ResNet18(nn.Module):  # used by SimCLR and MoCo
    """ResNet-18 encoder ``f`` followed by an MLP projection head ``g``.

    The torchvision ResNet-18 gets a 3x3, stride-1 ``conv1``, and its
    ``nn.Linear``/``nn.MaxPool2d`` children are removed. The flattened
    encoder output is 512-d (the projection head's input size).

    Args:
        feature_dim: output size of the projection head.

    ``forward(x)`` returns ``(feature, projection)``, both L2-normalised.
    """

    def __init__(self, feature_dim=128):
        super().__init__()

        encoder_layers = []
        for child_name, child in resnet18().named_children():
            if child_name == 'conv1':
                child = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False)
            if not isinstance(child, (nn.Linear, nn.MaxPool2d)):
                encoder_layers.append(child)
        # encoder
        self.f = nn.Sequential(*encoder_layers)
        # projection head: 512 -> 512 -> feature_dim
        self.g = nn.Sequential(
            nn.Linear(512, 512, bias=False),
            nn.BatchNorm1d(512),
            nn.ReLU(inplace=True),
            nn.Linear(512, feature_dim, bias=True),
        )

    def forward(self, x):
        h = torch.flatten(self.f(x), start_dim=1)
        z = self.g(h)
        return F.normalize(h, dim=-1), F.normalize(z, dim=-1)

class BYOL_ResNet18(nn.Module):
    """ResNet-18 encoder + projection head for BYOL.

    Same architecture as ``SSL_ResNet18``. The difference is that the
    projection is returned without L2 normalisation; ``BYOL_loss``
    normalises it itself.

    Args:
        feature_dim: output size of the projection head.

    ``forward(x)`` returns ``(normalised_feature, raw_projection)``.
    """

    def __init__(self, feature_dim=128):
        super().__init__()

        layers = []
        for child_name, child in resnet18().named_children():
            if child_name == 'conv1':
                child = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False)
            if isinstance(child, nn.Linear) or isinstance(child, nn.MaxPool2d):
                continue
            layers.append(child)
        # encoder
        self.f = nn.Sequential(*layers)
        # projection head: 512 -> 512 -> feature_dim
        self.g = nn.Sequential(
            nn.Linear(512, 512, bias=False),
            nn.BatchNorm1d(512),
            nn.ReLU(inplace=True),
            nn.Linear(512, feature_dim, bias=True),
        )

    def forward(self, x):
        flat = torch.flatten(self.f(x), start_dim=1)
        projection = self.g(flat)
        # Only the feature is normalised; the projection is returned as-is.
        return F.normalize(flat, dim=-1), projection

class Prediction_Layer(nn.Module):  # BYOL predictor: a simple MLP
    """Two-layer MLP: Linear -> BatchNorm1d -> ReLU -> Linear.

    Args:
        in_channels: size of the incoming vectors.
        hidden_size: width of the hidden layer.
        output_size: size of the produced vectors.
    """

    def __init__(self, in_channels=128, hidden_size=512, output_size=128):
        super().__init__()
        layers = [
            nn.Linear(in_channels, hidden_size),
            nn.BatchNorm1d(hidden_size),
            nn.ReLU(inplace=True),
            nn.Linear(hidden_size, output_size),
        ]
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        # (batch, in_channels) -> (batch, output_size)
        return self.net(x)


class BarlowTwins_ResNet18(nn.Module):
    """ResNet-18 encoder + wide projection head for Barlow Twins.

    The encoder is built like ``SSL_ResNet18``'s. The projection head is
    512 -> 1024 -> feature_dim. The projection is returned un-normalised,
    because ``BarlowTwins_train`` standardises it across the batch itself.

    Args:
        feature_dim: output size of the projection head (default 1024).

    ``forward(x)`` returns ``(normalised_feature, raw_projection)``.
    """

    def __init__(self, feature_dim=1024):
        super().__init__()

        modules = []
        for child_name, child in resnet18().named_children():
            replacement = child
            if child_name == 'conv1':
                replacement = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False)
            if not (isinstance(replacement, nn.Linear) or isinstance(replacement, nn.MaxPool2d)):
                modules.append(replacement)
        # encoder
        self.f = nn.Sequential(*modules)
        # projection head
        self.g = nn.Sequential(
            nn.Linear(512, 1024, bias=False),
            nn.BatchNorm1d(1024),
            nn.ReLU(inplace=True),
            nn.Linear(1024, feature_dim, bias=True),
        )

    def forward(self, x):
        representation = torch.flatten(self.f(x), start_dim=1)
        embedding = self.g(representation)
        return F.normalize(representation, dim=-1), embedding


class Supervised_ResNet18(nn.Module):
    """ResNet-18 encoder ``f`` + linear classifier ``fc`` for supervised training.

    The encoder is adapted like the SSL models':
    - 3x3, stride-1 ``conv1``;
    - ``nn.Linear``/``nn.MaxPool2d`` children removed.
    ``fc`` is wrapped in ``nn.Sequential``, so its state-dict keys are
    ``fc.0.weight`` / ``fc.0.bias``.

    Args:
        num_class: number of output classes.

    ``forward(x)`` returns the class logits.
    """

    def __init__(self, num_class=10):
        super().__init__()

        body = []
        for child_name, child in resnet18().named_children():
            if child_name == 'conv1':
                child = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False)
            if isinstance(child, (nn.Linear, nn.MaxPool2d)):
                continue
            body.append(child)
        # encoder
        self.f = nn.Sequential(*body)
        # fully connected classifier
        self.fc = nn.Sequential(nn.Linear(512, num_class, bias=True))

    def forward(self, x):
        return self.fc(torch.flatten(self.f(x), start_dim=1))



In [None]:
class CIFAR10Pair(datasets.CIFAR10):
    """CIFAR-10 dataset that yields two independently transformed views per image.

    ``__getitem__`` returns ``(view_1, view_2, target)``. Each view comes from
    a separate call to ``self.transform``, so a random transform produces two
    different views of the same image. When no transform is set, both views
    are the untransformed PIL image.
    """

    def __getitem__(self, index):
        img, target = self.data[index], self.targets[index]
        img = Image.fromarray(img)

        if self.transform is not None:
            pos_1 = self.transform(img)
            pos_2 = self.transform(img)
        else:
            # Without a transform both "views" are the raw image.
            pos_1 = pos_2 = img

        if self.target_transform is not None:
            target = self.target_transform(target)

        return pos_1, pos_2, target

# For your own dataset (an ImageFolder-style directory tree).
class CustomDataPair(datasets.ImageFolder):
    """ImageFolder variant that yields two independently transformed views per image.

    ``__getitem__`` returns ``(view_1, view_2, target)``. Each view is a
    separate call to ``self.transform``. Without a transform, both views are
    the image exactly as returned by ``self.loader``.
    """

    def __getitem__(self, index):
        path, target = self.samples[index]
        sample = self.loader(path)
        if self.transform is not None:
            pos_1 = self.transform(sample)
            pos_2 = self.transform(sample)
        else:
            # Without a transform both "views" are the loaded image.
            pos_1 = pos_2 = sample
        if self.target_transform is not None:
            target = self.target_transform(target)

        return pos_1, pos_2, target

In [None]:
# SimCLR train
def SimCLR_train(encoder_q, data_loader, train_optimizer):
    """Run one SimCLR epoch and return the sample-weighted mean loss.

    Both augmented views go through the same network. Each view's
    projections are scored against:
    - the other view (positive for row i is column i);
    - the other samples of its own view (self-similarity removed).
    The two directional cross-entropies are averaged.

    Reads the module-level globals ``temperature``, ``epoch`` and ``epochs``.
    """
    encoder_q.train()
    running_loss, seen = 0.0, 0
    progress = tqdm(data_loader)

    def directional_loss(anchor, other, not_self):
        n = anchor.size(0)
        # [n, n] cross-view scores; the positive of row i sits in column i.
        cross = torch.mm(anchor, other.t().contiguous())
        # [n, n-1] same-view scores with each sample's self-similarity dropped.
        same = torch.mm(anchor, anchor.t().contiguous())[not_self].view(n, -1)
        logits = torch.cat([cross, same], dim=-1)  # [n, 2n-1]
        targets = torch.arange(n, dtype=torch.long, device=anchor.device)
        return F.cross_entropy(logits / temperature, targets)

    for x_q, x_k, _ in progress:
        x_q = x_q.cuda(non_blocking=True)
        x_k = x_k.cuda(non_blocking=True)
        _, z_q = encoder_q(x_q)
        _, z_k = encoder_q(x_k)

        n = z_q.size(0)
        not_self = ~torch.eye(n, dtype=torch.bool).to(x_q.device)
        loss = (directional_loss(z_q, z_k, not_self) + directional_loss(z_k, z_q, not_self)) / 2.0

        train_optimizer.zero_grad()
        loss.backward()
        train_optimizer.step()

        seen += n
        running_loss += loss.item() * n
        progress.set_description('Train Epoch: [{}/{}] Loss: {:.4f}'.format(epoch, epochs, running_loss / seen))

    return running_loss / seen

In [None]:
# MoCo train
def MoCo_train(encoder_q, encoder_k, data_loader, train_optimizer):
    """Run one MoCo epoch with a symmetrised InfoNCE loss; return the mean loss.

    ``encoder_q`` is trained by the optimizer. After each step, every
    parameter of ``encoder_k`` is moved towards ``encoder_q`` with the global
    ``momentum``. Negatives come from the module-level ``memory_queue``.
    After each step, the queue is rebuilt: this batch's keys from both
    directions are appended and the oldest ``2 * batch`` rows are dropped.

    Also reads the globals ``temperature``, ``momentum``, ``epoch`` and ``epochs``.
    """
    global memory_queue
    encoder_q.train()
    running_loss, seen = 0.0, 0
    progress = tqdm(data_loader)

    def info_nce(query, key, queue):
        n, d = query.size(0), query.size(1)
        # [n, 1]: similarity of each query with its own key (the positive).
        positive = torch.squeeze(torch.bmm(query.view(n, 1, d), key.view(n, d, 1)), dim=1)
        # [n, M]: similarities with the queued negatives.
        negative = torch.mm(query, queue.t().contiguous())
        logits = torch.cat([positive, negative], dim=1)  # [n, 1+M]
        # The positive is always column 0.
        return F.cross_entropy(logits / temperature, torch.zeros(n, dtype=torch.long, device=query.device))

    for x_q, x_k, _ in progress:
        x_q, x_k = x_q.cuda(non_blocking=True), x_k.cuda(non_blocking=True)

        _, query = encoder_q(x_q)
        _, key = encoder_k(x_k)
        forward_loss = info_nce(query, key, memory_queue)

        # Symmetric term (swapping the views' roles tends to help performance).
        _, query2 = encoder_q(x_k)
        _, key2 = encoder_k(x_q)
        backward_loss = info_nce(query2, key2, memory_queue)

        loss = (forward_loss + backward_loss) / 2.0

        train_optimizer.zero_grad()
        loss.backward()
        train_optimizer.step()

        # Momentum update of the key encoder.
        for p_q, p_k in zip(encoder_q.parameters(), encoder_k.parameters()):
            p_k.data.copy_(p_k.data * momentum + p_q.data * (1.0 - momentum))

        n = query.size(0)
        # Enqueue the newest keys and drop the oldest 2 * n entries.
        memory_queue = torch.cat((memory_queue, key, key2), dim=0)[2 * n:]

        seen += n
        running_loss += loss.item() * n
        progress.set_description('Train Epoch: [{}/{}] Loss: {:.4f}'.format(epoch, epochs, running_loss / seen))

    return running_loss / seen

In [None]:
# BYOL train
def BYOL_loss(x, y):
    """Per-row BYOL regression loss: ``2 - 2 * cos(x_i, y_i)``.

    Rows of ``x`` and ``y`` are L2-normalised along dim 1 first. The result has
    one value per row, from 0 (same direction) to 4 (opposite directions).
    """
    cosine = (F.normalize(x, dim=1) * F.normalize(y, dim=1)).sum(dim=-1)
    return 2 - 2 * cosine
    
def BYOL_train(encoder_q, predict_layer, encoder_k, data_loader, train_optimizer):
    """Run one BYOL epoch and return the sample-weighted mean loss.

    The online branch (``encoder_q`` then ``predict_layer``) predicts the
    target branch's (``encoder_k``) projection of the other view, in both
    directions. The target branch runs under ``torch.no_grad``. After each
    optimizer step, ``encoder_k``'s parameters are moved towards
    ``encoder_q``'s with the global ``momentum``.

    Also reads the globals ``epoch`` and ``epochs``.
    """
    encoder_q.train()
    predict_layer.train()

    running_loss, seen = 0.0, 0
    progress = tqdm(data_loader)
    for view_1, view_2, _ in progress:
        view_1 = view_1.cuda(non_blocking=True)
        view_2 = view_2.cuda(non_blocking=True)

        # Online branch.
        _, online_1 = encoder_q(view_1)
        _, online_2 = encoder_q(view_2)
        prediction_1 = predict_layer(online_1)
        prediction_2 = predict_layer(online_2)

        # Target branch: no gradients flow into encoder_k.
        with torch.no_grad():
            _, target_2 = encoder_k(view_2)
            _, target_1 = encoder_k(view_1)

        # Each view's prediction is matched against the other view's target.
        loss = (BYOL_loss(prediction_1, target_2) + BYOL_loss(prediction_2, target_1)).mean()

        # Update step, shared by every algorithm in this file.
        train_optimizer.zero_grad()
        loss.backward()
        train_optimizer.step()

        # Momentum (moving-average) update of the target encoder.
        for online_param, target_param in zip(encoder_q.parameters(), encoder_k.parameters()):
            target_param.data.copy_(target_param.data * momentum + online_param.data * (1.0 - momentum))

        n = online_1.size(0)
        seen += n
        running_loss += loss.item() * n
        progress.set_description('Train Epoch: [{}/{}] Loss: {:.4f}'.format(epoch, epochs, running_loss / seen))

    return running_loss / seen

In [None]:
# BarlowTwins train
def BarlowTwins_train(encoder_q, data_loader, train_optimizer, lambda_param):
    """Run one Barlow Twins epoch and return the sample-weighted mean loss.

    For each batch:
    1. The two views' projections are standardised per feature over the batch.
    2. Their D x D cross-correlation matrix ``c`` is formed.
    3. The minimised loss is
           sum_i (1 - c_ii)^2 + lambda_param * sum_{i != j} c_ij^2
       which pulls the diagonal to 1 and pushes the off-diagonal to 0.

    Args:
        encoder_q: network returning ``(feature, projection)``.
        data_loader: yields ``(view_1, view_2, target)`` batches.
        train_optimizer: optimizer over ``encoder_q``'s parameters.
        lambda_param: weight of the off-diagonal (redundancy-reduction) term.

    Reads the module-level globals ``epoch`` and ``epochs`` for the progress bar.
    """
    encoder_q.train()

    total_loss, total_num, train_bar = 0.0, 0, tqdm(data_loader)
    for x_q, x_k, _ in train_bar:
        x_q, x_k = x_q.cuda(non_blocking=True), x_k.cuda(non_blocking=True)
        _, view_q = encoder_q(x_q)
        _, view_k = encoder_q(x_k)

        N = view_q.size(0)  # batch size
        D = view_q.size(1)  # feature dimension (output size of the projection head)

        # Standardise each feature across the batch: N x D.
        view_q_norm = (view_q - view_q.mean(0)) / view_q.std(0)
        view_k_norm = (view_k - view_k.mean(0)) / view_k.std(0)

        # Cross-correlation matrix: D x D.
        c = torch.mm(view_q_norm.t().contiguous(), view_k_norm) / N

        # Build the mask on c's device so no host-to-device copy is needed.
        off_diag = ~torch.eye(D, dtype=torch.bool, device=c.device)
        on_diag_loss = (torch.diagonal(c) - 1).pow(2).sum()
        off_diag_loss = c[off_diag].pow(2).sum()
        loss = on_diag_loss + lambda_param * off_diag_loss

        # Back-propagate the loss and let the optimizer update the model.
        train_optimizer.zero_grad()
        loss.backward()
        train_optimizer.step()

        # Weight by this batch's own size, not by any global setting.
        total_num += N
        total_loss += loss.item() * N
        train_bar.set_description('Train Epoch: [{}/{}] Loss: {:.4f}'.format(epoch, epochs, total_loss / total_num))

    return total_loss / total_num

In [None]:
# Supervised learning: train with labels.
def Supervised_train(net, data_loader, train_optimizer):
    """Train ``net`` with cross-entropy for one epoch.

    Returns ``(mean_loss, top1_accuracy_percent, top5_accuracy_percent)``
    over all samples seen this epoch. Reads the module-level globals
    ``epoch`` and ``epochs`` for the progress bar.
    """
    net.train()

    loss_sum, hits_1, hits_5, seen = 0.0, 0.0, 0.0, 0
    progress = tqdm(data_loader)
    with torch.enable_grad():
        for images, labels in progress:
            images = images.cuda(non_blocking=True)
            labels = labels.cuda(non_blocking=True)

            logits = net(images)
            loss = F.cross_entropy(logits, labels)

            train_optimizer.zero_grad()
            loss.backward()
            train_optimizer.step()

            n = images.size(0)
            seen += n
            loss_sum += loss.item() * n
            ranked = torch.argsort(logits, dim=-1, descending=True)
            expected = labels.unsqueeze(dim=-1)
            hits_1 += torch.sum((ranked[:, :1] == expected).any(dim=-1).float()).item()
            hits_5 += torch.sum((ranked[:, :5] == expected).any(dim=-1).float()).item()

            progress.set_description('{} Epoch: [{}/{}] Loss: {:.4f} ACC@1: {:.2f}% ACC@5: {:.2f}%'
                                     .format('Train', epoch, epochs, loss_sum / seen,
                                             hits_1 / seen * 100, hits_5 / seen * 100))

    return loss_sum / seen, hits_1 / seen * 100, hits_5 / seen * 100

In [None]:
# kNN accuracy monitor. This is not the linear protocol, so it is only a rough estimate.
def SSL_test(net, memory_data_loader, test_data_loader):
    """Estimate representation quality with a weighted k-nearest-neighbour vote.

    1. The encoder features of every sample in ``memory_data_loader`` form a
       bank.
    2. Each test sample takes its ``k`` most similar bank entries by dot
       product (the models in this file L2-normalise their features).
    3. Each neighbour votes for its label with weight
       ``exp(similarity / temperature)``.

    Reads the module-level globals ``k``, ``c`` (number of classes),
    ``temperature``, ``epoch`` and ``epochs``.

    Returns ``(top1_percent, top5_percent)``.
    """
    net.eval()
    hits_1, hits_5, seen = 0.0, 0.0, 0
    with torch.no_grad():
        # Build the feature bank.
        bank_chunks = []
        for images, _, _ in tqdm(memory_data_loader, desc='Feature extracting'):
            bank_feature, _ = net(images.cuda(non_blocking=True))
            bank_chunks.append(bank_feature)
        feature_bank = torch.cat(bank_chunks, dim=0).t().contiguous()  # [D, N]
        bank_labels = torch.tensor(memory_data_loader.dataset.targets, device=feature_bank.device)  # [N]

        # Classify each test sample by weighted kNN search.
        progress = tqdm(test_data_loader)
        for images, _, labels in progress:
            images, labels = images.cuda(non_blocking=True), labels.cuda(non_blocking=True)
            query_feature, _ = net(images)
            seen += images.size(0)

            similarity = torch.mm(query_feature, feature_bank)   # [B, N]
            top_sim, top_idx = similarity.topk(k=k, dim=-1)       # [B, K]
            neighbour_labels = bank_labels[top_idx]               # [B, K]
            weights = (top_sim / temperature).exp()

            votes = F.one_hot(neighbour_labels, num_classes=c).float()        # [B, K, C]
            class_scores = (votes * weights.unsqueeze(dim=-1)).sum(dim=1)     # [B, C]

            ranking = class_scores.argsort(dim=-1, descending=True)
            expected = labels.unsqueeze(dim=-1)
            hits_1 += torch.sum((ranking[:, :1] == expected).any(dim=-1).float()).item()
            hits_5 += torch.sum((ranking[:, :5] == expected).any(dim=-1).float()).item()
            progress.set_description('Test Epoch: [{}/{}] Acc@1:{:.2f}% Acc@5:{:.2f}%'
                                     .format(epoch, epochs, hits_1 / seen * 100, hits_5 / seen * 100))

    return hits_1 / seen * 100, hits_5 / seen * 100


In [None]:
def Supervised_test(net, test_data_loader):
    """Measure top-1/top-5 accuracy of a classifier without gradients.

    Returns ``(top1_percent, top5_percent)``. Reads the module-level globals
    ``epoch`` and ``epochs`` for the progress bar.
    """
    net.eval()

    correct_1, correct_5, seen = 0.0, 0.0, 0
    progress = tqdm(test_data_loader)
    with torch.no_grad():
        for images, labels in progress:
            images = images.cuda(non_blocking=True)
            labels = labels.cuda(non_blocking=True)

            logits = net(images)

            seen += images.size(0)
            ranked = torch.argsort(logits, dim=-1, descending=True)
            matches = ranked == labels.unsqueeze(dim=-1)
            correct_1 += torch.sum(matches[:, :1].any(dim=-1).float()).item()
            correct_5 += torch.sum(matches[:, :5].any(dim=-1).float()).item()

            progress.set_description('{} Epoch: [{}/{}] ACC@1: {:.2f}% ACC@5: {:.2f}%'
                                     .format('Test', epoch, epochs,
                                             correct_1 / seen * 100, correct_5 / seen * 100))

    return correct_1 / seen * 100, correct_5 / seen * 100

In [None]:
# With SGD, reschedule the learning rate every epoch instead of keeping it fixed.
def adjust_learning_rate(learning_rate, optimizer, epoch, total_epochs, cosine):
    """Set every param group of ``optimizer`` to the scheduled LR for ``epoch``.

    cosine=True: cosine annealing from ``learning_rate`` (at epoch 0) down to
    ``learning_rate * 0.2**3`` (at ``epoch == total_epochs``).
    cosine=False: step schedule; the LR is multiplied by 0.2 for each
    milestone (floor of 50%, 75% and 87.5% of ``total_epochs``) that
    ``epoch`` has strictly passed.

    The chosen LR is printed as ``LR:<value>``.
    """
    decay = 0.2
    if cosine:
        floor_lr = learning_rate * (decay ** 3)
        new_lr = floor_lr + (learning_rate - floor_lr) * (
                1 + math.cos(math.pi * epoch / total_epochs)) / 2
    else:
        milestones = np.array([math.floor(total_epochs * fraction) for fraction in (0.5, 0.75, 0.875)])
        passed = np.sum(epoch > milestones)
        new_lr = learning_rate * (decay ** passed) if passed > 0 else learning_rate

    print("LR:{}".format(new_lr))
    for group in optimizer.param_groups:
        group['lr'] = new_lr

In [None]:
# Main training program.
if __name__ == '__main__':

    feature_dim = 1024  # projection output size: Barlow Twins needs >= 1024, 128 is enough for the other methods
    m = 4096  # MoCo queue size
    temperature, momentum = 0.1, 0.99  # hyper-parameters
    k = 200  # k of the kNN monitor (an informal check; the linear protocol is the proper evaluation)
    batch_size = 512
    epochs = 3
    model_name = 'resnet18'  # feature-extractor architecture (used here only in the save-file name)
    Learning_method = 'BarlowTwins'  # SimCLR, MoCo, BYOL, BarlowTwins, Supervised -> Supervised is a basic supervised-learning baseline
    LR = 0.5

    # Datasets: supervised learning uses plain CIFAR-10; the self-supervised
    # methods use CIFAR10Pair, which returns two transformed views per image.
    if Learning_method == 'Supervised':
        train_data = datasets.CIFAR10(root='./dataset', train=True, transform=supervised_train_transform, download=True)
        train_loader = DataLoader(train_data, batch_size=batch_size, shuffle=True, num_workers=0, pin_memory=True, drop_last=True)
        test_data = datasets.CIFAR10(root='./dataset', train=False, transform=test_transform, download=True)
        test_loader = DataLoader(test_data, batch_size=batch_size, shuffle=False, num_workers=0, pin_memory=True)
    else:
        train_data = CIFAR10Pair(root='./dataset', train=True, transform=SSL_train_transform, download=True)
        # train_data = CustomDataPair(root='~/cutone_data/cifar10_pic_10per/train/', transform=SSL_train_transform)
        train_loader = DataLoader(train_data, batch_size=batch_size, shuffle=True, num_workers=0, pin_memory=True, drop_last=True)
        # Memory bank for the kNN monitor: training images with the test transform.
        memory_data = CIFAR10Pair(root='./dataset', train=True, transform=test_transform, download=True)
        # memory_data = CustomDataPair(root='~/cutone_data/cifar10_pic_10per/train/', transform=test_transform)
        memory_loader = DataLoader(memory_data, batch_size=batch_size, shuffle=False, num_workers=0, pin_memory=True)
        test_data = CIFAR10Pair(root='./dataset', train=False, transform=test_transform, download=True)
        # test_data = CustomDataPair(root='~/cutone_data/cifar10_pic_10per/test/', transform=test_transform)
        test_loader = DataLoader(test_data, batch_size=batch_size, shuffle=False, num_workers=0, pin_memory=True)

    # Model(s) and optimizer for each method.
    if Learning_method == 'SimCLR':
        print('running SimCLR')
        model_q = SSL_ResNet18(feature_dim).cuda()
        optimizer = optim.SGD(model_q.parameters(), lr=LR, momentum=0.9, weight_decay=1e-6)
    elif Learning_method == 'MoCo':
        print('running MoCo')
        model_q = SSL_ResNet18(feature_dim).cuda()
        model_k = SSL_ResNet18(feature_dim).cuda()
        # init memory queue as unit random vectors ---> [M, D]
        memory_queue = F.normalize(torch.randn(m, feature_dim).cuda(), dim=-1)
        optimizer = optim.SGD(model_q.parameters(), lr=LR, momentum=0.9, weight_decay=1e-6)
    elif Learning_method == 'BYOL':
        print('running BYOL')
        model_q = BYOL_ResNet18(feature_dim).cuda()
        model_k = BYOL_ResNet18(feature_dim).cuda()
        # The predictor takes and produces projection vectors, so both ends
        # must match feature_dim. Its defaults (128) would fail with a
        # 1024-d projection.
        pred_model = Prediction_Layer(in_channels=feature_dim, output_size=feature_dim).cuda()
        optimizer = optim.SGD(list(model_q.parameters()) + list(pred_model.parameters()),
                              lr=LR, momentum=0.9, weight_decay=1e-6)
    elif Learning_method == 'BarlowTwins':
        print('running BarlowTwins')
        model_q = BarlowTwins_ResNet18(feature_dim).cuda()
        optimizer = optim.SGD(model_q.parameters(), lr=LR, momentum=0.9, weight_decay=1e-6)
    elif Learning_method == 'Supervised':
        print('running Supervised Learning')
        model_q = Supervised_ResNet18(num_class=len(train_data.classes)).cuda()
        optimizer = optim.SGD(model_q.parameters(), lr=LR, momentum=0.9, weight_decay=1e-6)
    else:
        raise ValueError('Learning method is either SimCLR, MoCo, BYOL, BarlowTwins, Supervised.')

    # MoCo / BYOL: initialise model_k (updated by momentum, never by
    # gradients) as a copy of model_q.
    if Learning_method in ('MoCo', 'BYOL'):
        for param_q, param_k in zip(model_q.parameters(), model_k.parameters()):
            param_k.data.copy_(param_q.data)
            param_k.requires_grad = False

    # Number of classes (read by SSL_test as the global ``c``).
    c = len(test_data.classes)

    # training loop
    results = {'train_loss': [], 'total_time': [], 'test_acc@1': [], 'test_acc@5': []}
    save_name_pre = '{}_{}_f{}_q{}_t{}_m{}_k{}_B{}_e{}'.format(Learning_method, model_name, feature_dim, m, temperature, momentum, k, batch_size, epochs)
    best_acc = 0.0
    begin_time = time.time()
    for epoch in range(1, epochs + 1):
        if isinstance(optimizer, optim.SGD):
            adjust_learning_rate(LR, optimizer, epoch, epochs, cosine=True)

        if Learning_method == 'MoCo':
            train_loss = MoCo_train(model_q, model_k, train_loader, optimizer)
        elif Learning_method == 'SimCLR':
            train_loss = SimCLR_train(model_q, train_loader, optimizer)
        elif Learning_method == 'BYOL':
            train_loss = BYOL_train(model_q, pred_model, model_k, train_loader, optimizer)
        elif Learning_method == 'BarlowTwins':
            # lambda_param has a large effect on the model -- worth experimenting with.
            train_loss = BarlowTwins_train(model_q, train_loader, optimizer, lambda_param=0.005)
        elif Learning_method == 'Supervised':
            train_loss, train_acc1, train_acc5 = Supervised_train(model_q, train_loader, optimizer)
        else:
            raise ValueError('Learning method is either SimCLR, MoCo, BYOL, BarlowTwins, Supervised')

        results['train_loss'].append(train_loss)
        results['total_time'].append(time.time() - begin_time)
        if Learning_method != 'Supervised':
            test_acc_1, test_acc_5 = SSL_test(model_q, memory_loader, test_loader)
        else:
            test_acc_1, test_acc_5 = Supervised_test(model_q, test_loader)

        results['test_acc@1'].append(test_acc_1)
        results['test_acc@5'].append(test_acc_5)

        # save statistics
        os.makedirs('./results', exist_ok=True)
        data_frame = pd.DataFrame(data=results, index=range(1, epoch + 1))
        data_frame.to_csv('./results/{}_results.csv'.format(save_name_pre), index_label='epoch')
        # if test_acc_1 > best_acc:
        #     best_acc = test_acc_1
        #     torch.save(model_q.state_dict(), 'results/{}_model.pth'.format(save_name_pre))
        # save model every 50 epochs
        if epoch % 50 == 0:
            torch.save(model_q.state_dict(), './results/{}_model.pth'.format(save_name_pre))

    torch.save(model_q.state_dict(), './results/{}_model.pth'.format(save_name_pre))

Downloading https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz to ./dataset/cifar-10-python.tar.gz


  0%|          | 0/170498071 [00:00<?, ?it/s]

Extracting ./dataset/cifar-10-python.tar.gz to ./dataset
Files already downloaded and verified
Files already downloaded and verified
running BarlowTwins
LR:0.376


Train Epoch: [1/3] Loss: 676.3565: 100%|██████████| 97/97 [05:54<00:00,  3.65s/it]
Feature extracting: 100%|██████████| 98/98 [00:37<00:00,  2.59it/s]
Test Epoch: [1/3] Acc@1:25.09% Acc@5:75.32%: 100%|██████████| 20/20 [00:13<00:00,  1.51it/s]


LR:0.12800000000000006


Train Epoch: [2/3] Loss: 627.9661: 100%|██████████| 97/97 [05:52<00:00,  3.63s/it]
Feature extracting: 100%|██████████| 98/98 [00:37<00:00,  2.60it/s]
Test Epoch: [2/3] Acc@1:25.49% Acc@5:78.22%: 100%|██████████| 20/20 [00:13<00:00,  1.51it/s]


LR:0.004000000000000001


Train Epoch: [3/3] Loss: 596.8053: 100%|██████████| 97/97 [05:53<00:00,  3.65s/it]
Feature extracting: 100%|██████████| 98/98 [00:37<00:00,  2.60it/s]
Test Epoch: [3/3] Acc@1:25.45% Acc@5:78.09%: 100%|██████████| 20/20 [00:13<00:00,  1.51it/s]


## 以下是讀取特徵抽取模型，並加上fully connected layer做Linear Evaluation的測試（給SSL做測試的方法！監督式學習不建議！）

In [None]:
# Load a trained feature extractor and assess it with the linear protocol.
model_path = ''  # checkpoint to evaluate -- set this before running the cells below
print(model_path)
epochs = 50
model_name = 'SSL_resnet18'  # SSL_resnet50, SSL_resnet18, Supervised
save_folder = './results/'
batch_size = 256

# Datasets. Think about which transform the training split should use here.
_loader_options = dict(num_workers=0, pin_memory=True)
train_data = datasets.CIFAR10(root='./dataset', train=True, transform=SSL_train_transform, download=True)
# train_data = datasets.ImageFolder(root='./cifar10_10percent/train', transform=SSL_train_transform)
train_loader = DataLoader(train_data, batch_size=batch_size, shuffle=True, drop_last=True, **_loader_options)
test_data = datasets.CIFAR10(root='./dataset', train=False, transform=test_transform, download=True)
# test_data = datasets.ImageFolder(root='./cifar10_10percent/test', transform=test_transform)
test_loader = DataLoader(test_data, batch_size=batch_size, shuffle=False, **_loader_options)

In [None]:
def _load_encoder_weights(model, pretrained_path):
    """Load ``pretrained_path`` into ``model`` non-strictly, requiring every encoder (``f.*``) weight."""
    state = torch.load(pretrained_path, map_location='cpu')
    # strict=False so that keys outside this model (e.g. a projection head
    # ``g.*`` or another classifier) are ignored and ``fc`` stays freshly
    # initialised.
    result = model.load_state_dict(state, strict=False)
    missing_encoder = [key for key in result.missing_keys if key.startswith('f.')]
    if missing_encoder:
        # Without this check a key-prefix mismatch would silently leave a
        # randomly initialised encoder to be "evaluated".
        raise RuntimeError(
            'Checkpoint {!r} is missing {} encoder weights (e.g. {!r}); '
            'check that it was saved from a matching model.'.format(
                pretrained_path, len(missing_encoder), missing_encoder[0]))


class Net_18(nn.Module):
    """Linear-evaluation model: SSL ResNet-18 encoder ``f`` + linear classifier ``fc``.

    Args:
        num_class: number of output classes.
        pretrained_path: checkpoint (state dict) passed to ``torch.load``;
            presumably saved from one of the ResNet-18 models in this file.

    Raises:
        RuntimeError: if the checkpoint lacks any encoder (``f.*``) weight.
    """

    def __init__(self, num_class, pretrained_path):
        super(Net_18, self).__init__()

        # encoder
        self.f = SSL_ResNet18().f
        # classifier
        self.fc = nn.Linear(512, num_class, bias=True)
        _load_encoder_weights(self, pretrained_path)

    def forward(self, x):
        x = self.f(x)
        feature = torch.flatten(x, start_dim=1)
        out = self.fc(feature)
        return out


class Net_50(nn.Module):
    """Linear-evaluation model for ``SSL_ResNet50`` checkpoints (2048-d encoder output).

    Args:
        num_class: number of output classes.
        pretrained_path: checkpoint (state dict) passed to ``torch.load``.

    Raises:
        RuntimeError: if the checkpoint lacks any encoder (``f.*``) weight.
    """

    def __init__(self, num_class, pretrained_path):
        super(Net_50, self).__init__()

        # encoder
        self.f = SSL_ResNet50().f
        # classifier
        self.fc = nn.Linear(2048, num_class, bias=True)
        _load_encoder_weights(self, pretrained_path)

    def forward(self, x):
        x = self.f(x)
        feature = torch.flatten(x, start_dim=1)
        out = self.fc(feature)
        return out

In [None]:
def train_evaluation(net, data_loader, train_optimizer):
    """Train the classifier of a linear-evaluation model for one epoch.

    ``net`` is put in training mode, with one exception: its encoder
    ``net.f``, when present and fully frozen (no parameter requires grad),
    is kept in eval mode. Otherwise its BatchNorm layers would keep updating
    their running statistics, and the "frozen" features would drift during
    linear evaluation.

    Returns ``(mean_loss, top1_percent, top5_percent)`` over the epoch.
    Reads the globals ``epoch`` and ``epochs`` for the progress bar.
    """
    net.train()
    encoder = getattr(net, 'f', None)
    if encoder is not None and not any(p.requires_grad for p in encoder.parameters()):
        encoder.eval()

    total_loss, total_correct_1, total_correct_5, total_num, data_bar = 0.0, 0.0, 0.0, 0, tqdm(data_loader)
    with torch.enable_grad():
        for data, target in data_bar:
            data, target = data.cuda(non_blocking=True), target.cuda(non_blocking=True)

            out = net(data)
            loss = F.cross_entropy(out, target)

            train_optimizer.zero_grad()
            loss.backward()
            train_optimizer.step()

            total_num += data.size(0)
            total_loss += loss.item() * data.size(0)
            prediction = torch.argsort(out, dim=-1, descending=True)
            total_correct_1 += torch.sum((prediction[:, 0:1] == target.unsqueeze(dim=-1)).any(dim=-1).float()).item()
            total_correct_5 += torch.sum((prediction[:, 0:5] == target.unsqueeze(dim=-1)).any(dim=-1).float()).item()

            data_bar.set_description('{} Epoch: [{}/{}] Loss: {:.4f} ACC@1: {:.2f}% ACC@5: {:.2f}%'
                                     .format('Train', epoch, epochs, total_loss / total_num,
                                             total_correct_1 / total_num * 100, total_correct_5 / total_num * 100))

    return total_loss / total_num, total_correct_1 / total_num * 100, total_correct_5 / total_num * 100



In [None]:
def test_evaluation(net, data_loader):
    """Evaluate ``net`` on ``data_loader`` without gradients.

    Returns ``(mean_loss, top1_percent, top5_percent)``. Reads the globals
    ``epoch`` and ``epochs`` for the progress bar.
    """
    net.eval()

    loss_sum, correct_1, correct_5, seen = 0.0, 0.0, 0.0, 0
    progress = tqdm(data_loader)
    with torch.no_grad():
        for images, labels in progress:
            images = images.cuda(non_blocking=True)
            labels = labels.cuda(non_blocking=True)

            logits = net(images)
            loss = F.cross_entropy(logits, labels)

            n = images.size(0)
            seen += n
            loss_sum += loss.item() * n
            ranked = torch.argsort(logits, dim=-1, descending=True)
            is_label = ranked == labels.unsqueeze(dim=-1)
            correct_1 += torch.sum(is_label[:, 0:1].any(dim=-1).float()).item()
            correct_5 += torch.sum(is_label[:, 0:5].any(dim=-1).float()).item()

            progress.set_description('{} Epoch: [{}/{}] Loss: {:.4f} ACC@1: {:.2f}% ACC@5: {:.2f}%'
                                     .format('Test', epoch, epochs, loss_sum / seen,
                                             correct_1 / seen * 100, correct_5 / seen * 100))

    return loss_sum / seen, correct_1 / seen * 100, correct_5 / seen * 100


In [None]:
# Build the linear-evaluation model. The encoder is frozen, so only the classifier is trained.
num_class = len(train_data.classes)
if model_name == 'SSL_resnet50':
    model = Net_50(num_class=num_class, pretrained_path=model_path).cuda()
elif model_name == 'SSL_resnet18':
    model = Net_18(num_class=num_class, pretrained_path=model_path).cuda()
elif model_name == 'Supervised':
    model = Supervised_ResNet18(num_class=num_class).cuda()
    model.load_state_dict(torch.load(model_path, map_location='cpu'), strict=False)
else:
    raise ValueError("model_name must be 'SSL_resnet50', 'SSL_resnet18' or 'Supervised'")
for param in model.f.parameters():
    param.requires_grad = False


results = {'train_loss': [], 'train_acc@1': [], 'train_acc@5': [], 'test_loss': [], 'test_acc@1': [], 'test_acc@5': []}


optimizer = optim.Adam(model.fc.parameters(), lr=1e-2, weight_decay=1e-6)

loss_criterion = nn.CrossEntropyLoss()  # not used by the loop below (it calls F.cross_entropy)

# The results folder may not exist if the pre-training cell was skipped.
os.makedirs(save_folder, exist_ok=True)

best_acc = 0.0
for epoch in range(1, epochs + 1):
    train_loss, train_acc_1, train_acc_5 = train_evaluation(model, train_loader, optimizer)
    results['train_loss'].append(train_loss)
    results['train_acc@1'].append(train_acc_1)
    results['train_acc@5'].append(train_acc_5)
    test_loss, test_acc_1, test_acc_5 = test_evaluation(model, test_loader)
    results['test_loss'].append(test_loss)
    results['test_acc@1'].append(test_acc_1)
    results['test_acc@5'].append(test_acc_5)
    # save statistics
    data_frame = pd.DataFrame(data=results, index=range(1, epoch + 1))
    data_frame.to_csv(os.path.join(save_folder, 'CIFAR10_linear_evaluation.csv'), index_label='epoch')
    # if test_acc_1 > best_acc:
    #     best_acc = test_acc_1
    #     torch.save(model.state_dict(), os.path.join(save_folder, 'CIFAR10_linear_model.pth'))