# reference

https://blog.csdn.net/qq_43027065/article/details/118657728

In [1]:
# net.py
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision.models.resnet import resnet50


# Stage one: unsupervised (contrastive) pre-training
class SimCLRStage1(nn.Module):
    """Encoder plus projection head used for the contrastive training stage.

    The encoder ``f`` is assembled from the children of torchvision's
    ``resnet50``: the stem convolution is swapped for a 3x3, stride-1
    convolution, and the max-pool and linear layers are left out.  The
    projection head ``g`` takes a 2048-wide input and produces
    ``feature_dim`` outputs.

    Args:
        feature_dim: Output width of the projection head.
    """

    def __init__(self, feature_dim=128):
        super().__init__()

        backbone_layers = []
        for child_name, child in resnet50().named_children():
            # Skip the max-pool and the classification layer entirely.
            if isinstance(child, (nn.Linear, nn.MaxPool2d)):
                continue
            if child_name == 'conv1':
                # Replace the original 7x7 / stride-2 stem.
                child = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False)
            backbone_layers.append(child)

        # Encoder.
        self.f = nn.Sequential(*backbone_layers)

        # Projection head.
        self.g = nn.Sequential(
            nn.Linear(2048, 512, bias=False),
            nn.BatchNorm1d(512),
            nn.ReLU(inplace=True),
            nn.Linear(512, feature_dim, bias=True),
        )

    def forward(self, x):
        """Return ``(feature, projection)``, each L2-normalised along the last dim."""
        feature = torch.flatten(self.f(x), start_dim=1)
        projection = self.g(feature)
        return F.normalize(feature, dim=-1), F.normalize(projection, dim=-1)


# Stage two: supervised training of a linear classifier on the frozen encoder
class SimCLRStage2(torch.nn.Module):
    """Linear classifier stacked on a frozen stage-one encoder.

    Args:
        num_class: Number of output classes of the linear layer.
    """

    def __init__(self, num_class):
        super().__init__()
        # Encoder: only the ``f`` part of a freshly built stage-one model is kept.
        self.f = SimCLRStage1().f
        # Classifier.
        self.fc = nn.Linear(2048, num_class, bias=True)

        # Freeze every encoder parameter so that only ``fc`` receives gradients.
        self.f.requires_grad_(False)

    def forward(self, x):
        """Return the class logits for the input batch ``x``."""
        feature = torch.flatten(self.f(x), start_dim=1)
        return self.fc(feature)


class Loss(torch.nn.Module):
    """Contrastive loss over two batches of paired projections.

    For every row, the positive is the matching row of the other view and the
    denominator sums ``exp(similarity / temperature)`` over all *other* rows of
    the concatenated batch (both views, self-similarity excluded).
    """

    def __init__(self):
        super(Loss, self).__init__()

    def forward(self, out_1, out_2, batch_size=None, temperature=0.5):
        """Compute the mean contrastive loss.

        Args:
            out_1: Projections of the first view, one row per sample.
                Callers in this file pass L2-normalised rows, so the dot
                product acts as a cosine similarity.
            out_2: Projections of the second view; row ``i`` must correspond
                to row ``i`` of ``out_1``.
            batch_size: Kept for backward compatibility only.  The number of
                rows is read from the tensors themselves, so a smaller final
                batch no longer breaks the mask construction.
            temperature: Divisor applied to the similarities.

        Returns:
            Scalar tensor holding the loss averaged over all ``2 * B`` rows.
        """
        # [2*B, D]
        out = torch.cat([out_1, out_2], dim=0)
        num_rows = out.size(0)

        # Denominator: pairwise dot products of every row with every other row
        # (both views).  Work in log-space instead of exp() followed by log()
        # so that small temperatures cannot overflow.
        # [2*B, 2*B]
        logits = torch.mm(out, out.t().contiguous()) / temperature
        self_mask = torch.eye(num_rows, dtype=torch.bool, device=logits.device)
        # Remove the diagonal (similarity of a row with itself).
        logits = logits.masked_fill(self_mask, float('-inf'))
        # [2*B]
        log_denominator = torch.logsumexp(logits, dim=-1)

        # Numerator: element-wise product summed over features, i.e. the dot
        # product between the two views of the same sample.
        pos_logit = torch.sum(out_1 * out_2, dim=-1) / temperature
        # [2*B] - the same positive applies to both orderings of the pair.
        pos_logit = torch.cat([pos_logit, pos_logit], dim=0)

        # -log(exp(pos) / sum(exp(others))) == logsumexp(others) - pos
        return (log_denominator - pos_logit).mean()


if __name__ == "__main__":
    # Dump every top-level child of a stock resnet50 to inspect its layout.
    for child_name, child_module in resnet50().named_children():
        print(child_name, child_module)



conv1 Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
bn1 BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
relu ReLU(inplace=True)
maxpool MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
layer1 Sequential(
  (0): Bottleneck(
    (conv1): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
    (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
    (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
    (bn3): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (relu): ReLU(inplace=True)
    (downsample): Sequential(
      (0): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (1): BatchNorm2d(256, eps=1e-05

In [2]:
# config.py
import os
from torchvision import transforms

# Device switches.
use_gpu = True
gpu_name = 1

# Checkpoint directory and the default pre-trained model inside it.
save_path = "pth"
pre_model = os.path.join(save_path, 'model.pth')

# Per-channel statistics shared by the train and test normalisation steps.
_NORM_MEAN = [0.4914, 0.4822, 0.4465]
_NORM_STD = [0.2023, 0.1994, 0.2010]

# Augmentation pipeline applied to training images.
train_transform = transforms.Compose([
    transforms.RandomResizedCrop(32),
    transforms.RandomHorizontalFlip(p=0.5),
    transforms.RandomApply(
        [transforms.ColorJitter(0.4, 0.4, 0.4, 0.1)],
        p=0.8,
    ),
    transforms.RandomGrayscale(p=0.2),
    transforms.ToTensor(),
    transforms.Normalize(_NORM_MEAN, _NORM_STD),
])

# Deterministic pipeline for evaluation: tensor conversion and normalisation only.
test_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(_NORM_MEAN, _NORM_STD),
])


In [3]:
# loaddataset.py
from torchvision.datasets import CIFAR10
from PIL import Image

from utils.data_utils import ContrastivePairDataset

class PreDataset(CIFAR10):
    """CIFAR10 variant that yields two independently transformed views per image."""

    def __getitem__(self, item):
        """Return ``(imgL, imgR, target)`` for the sample at index ``item``.

        ``self.transform`` is applied twice to the same PIL image, so a random
        transform produces two different views.  When no transform is
        configured, both views are the untransformed PIL image (previously
        this case raised ``UnboundLocalError``).
        """
        img, target = self.data[item], self.targets[item]
        img = Image.fromarray(img)

        # Default to the raw image so the return below is always defined.
        imgL = imgR = img
        if self.transform is not None:
            imgL = self.transform(img)
            imgR = self.transform(img)

        if self.target_transform is not None:
            target = self.target_transform(target)

        return imgL, imgR, target

In [4]:
# Build the contrastive training set with the augmentation pipeline defined above.
# NOTE(review): the training loop unpacks each item as (imgL, imgR, labels) -
# confirm against ContrastivePairDataset.
train_dataset = ContrastivePairDataset(
    'cifar10',
    contrastive_transform=train_transform,
)
# Sanity check: show the first sample.
print(train_dataset[0])

Files already downloaded and verified
Files already downloaded and verified
(tensor([[[-0.0835, -0.1998, -0.3355,  ..., -0.3549, -0.2967, -0.3161],
         [-0.1804, -0.2580, -0.3355,  ..., -0.6457, -0.5487, -0.5293],
         [-0.3743, -0.3936, -0.4130,  ..., -0.6844, -0.6263, -0.6069],
         ...,
         [-0.4906, -0.4518, -0.3355,  ...,  0.1879,  0.0716, -0.0447],
         [-0.3549, -0.3355, -0.2386,  ...,  0.0910, -0.0641, -0.1998],
         [-0.3161, -0.2967, -0.1998,  ..., -0.0641, -0.1610, -0.3161]],

        [[-0.0386, -0.1566, -0.2942,  ..., -0.3139, -0.2549, -0.2746],
         [-0.1369, -0.2156, -0.2942,  ..., -0.6089, -0.5106, -0.4909],
         [-0.3336, -0.3532, -0.3729,  ..., -0.6482, -0.5892, -0.5696],
         ...,
         [-0.4516, -0.4122, -0.2942,  ...,  0.2368,  0.1188,  0.0008],
         [-0.3139, -0.2942, -0.1959,  ...,  0.1384, -0.0189, -0.1566],
         [-0.2746, -0.2549, -0.1566,  ..., -0.0189, -0.1172, -0.2746]],

        [[ 0.1394,  0.0223, -0.1143,  .

In [5]:
# Output directory used by the training loops below for loss logs and
# checkpoints; this rebinds the `save_path` assigned in the config section.
save_path = "/remote-home/songtianwei/research/unlearn_multimodal/output/unlearn_self_supervised"
# Mini-batch size handed to the DataLoaders below.
batch_size = 400

In [8]:
# trainstage1.py
import torch,argparse,os


# train stage one
def train():
    """Run stage-one contrastive pre-training.

    Reads the module-level ``train_dataset``, ``batch_size`` and ``save_path``.
    Trains a ``SimCLRStage1`` model for 1000 epochs with Adam, appends the mean
    batch loss of every epoch to ``stage1_loss.txt`` and saves the model's
    ``state_dict`` every 5 epochs, all inside ``save_path``.
    """
    if torch.cuda.is_available():
        # GPU 1 is the preferred device; fall back to GPU 0 on hosts that
        # expose fewer than two devices instead of failing on "cuda:1".
        gpu_index = 1 if torch.cuda.device_count() > 1 else 0
        DEVICE = torch.device("cuda:" + str(gpu_index))
        # Useful when the computation graph barely changes between
        # iterations: cuDNN picks a fast algorithm (e.g. for convolutions)
        # up front.
        torch.backends.cudnn.benchmark = True
    else:
        DEVICE = torch.device("cpu")
    print("current deveice:", DEVICE)

    train_data = torch.utils.data.DataLoader(
        train_dataset,
        batch_size=batch_size,
        shuffle=True,
        num_workers=16,
        drop_last=True,
    )

    model = SimCLRStage1().to(DEVICE)
    lossLR = Loss().to(DEVICE)
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-3, weight_decay=1e-6)

    # With drop_last=True the loader yields len(train_data) full batches, so
    # that is the right divisor for the epoch mean.  max() guards against a
    # dataset smaller than one batch.
    num_batches = max(len(train_data), 1)
    loss_log_path = os.path.join(save_path, "stage1_loss.txt")

    os.makedirs(save_path, exist_ok=True)
    for epoch in range(1, 1000 + 1):
        model.train()
        total_loss = 0.0
        # Labels are not needed for contrastive training, so they stay on the host.
        for batch, (imgL, imgR, _labels) in enumerate(train_data):
            imgL, imgR = imgL.to(DEVICE), imgR.to(DEVICE)

            _, pre_L = model(imgL)
            _, pre_R = model(imgR)

            # Pass the actual number of rows rather than the configured size.
            loss = lossLR(pre_L, pre_R, pre_L.size(0))
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            # Read the scalar back from the device once per step.
            loss_value = loss.item()
            print("epoch", epoch, "batch", batch, "loss:", loss_value)
            total_loss += loss_value

        epoch_loss = total_loss / num_batches
        print("epoch loss:", epoch_loss)

        with open(loss_log_path, "a") as f:
            f.write(str(epoch_loss) + " ")

        if epoch % 5 == 0:
            torch.save(model.state_dict(), os.path.join(save_path, 'model_stage1_epoch' + str(epoch) + '.pth'))



In [9]:
# Launch stage-one contrastive pre-training (1000 epochs unless interrupted).
train()

current deveice: cuda:1
epoch 1 batch 0 loss: 6.660226345062256
epoch 1 batch 1 loss: 6.676202297210693
epoch 1 batch 2 loss: 6.687729358673096
epoch 1 batch 3 loss: 6.653592586517334
epoch 1 batch 4 loss: 6.638753414154053
epoch 1 batch 5 loss: 6.596418380737305
epoch 1 batch 6 loss: 6.563289642333984
epoch 1 batch 7 loss: 6.5818610191345215
epoch 1 batch 8 loss: 6.543348789215088
epoch 1 batch 9 loss: 6.51478910446167
epoch 1 batch 10 loss: 6.489308834075928
epoch 1 batch 11 loss: 6.487301349639893
epoch 1 batch 12 loss: 6.49058198928833
epoch 1 batch 13 loss: 6.386373043060303
epoch 1 batch 14 loss: 6.417940616607666
epoch 1 batch 15 loss: 6.422379016876221
epoch 1 batch 16 loss: 6.345763206481934
epoch 1 batch 17 loss: 6.446457386016846
epoch 1 batch 18 loss: 6.408658504486084
epoch 1 batch 19 loss: 6.421524524688721
epoch 1 batch 20 loss: 6.430500507354736
epoch 1 batch 21 loss: 6.350765228271484
epoch 1 batch 22 loss: 6.31705904006958
epoch 1 batch 23 loss: 6.346568584442139
epoc

KeyboardInterrupt: 

In [None]:
# trainstage2.py
import torch,argparse,os
from torchvision.datasets import CIFAR10
from torch.utils.data import DataLoader


# Stage-one checkpoint used to initialise the stage-two encoder.
# NOTE(review): this directory ("unlearn_test_self_supervised") differs from the
# stage-one `save_path` ("unlearn_self_supervised") - confirm the file exists here.
pre_model_path = "/remote-home/songtianwei/research/unlearn_multimodal/output/unlearn_test_self_supervised/model_stage1_epoch50.pth"


# train stage two
def train_stage2(args=None):
    """Train the stage-two linear classifier on top of the frozen encoder.

    Args:
        args: Optional namespace (e.g. from ``argparse``) that may provide
            ``pre_model``, ``max_epoch`` and ``batch_size``.  Anything missing
            falls back to the module-level ``pre_model_path`` / ``batch_size``
            and to 200 epochs, so calling ``train_stage2()`` with no argument
            works.

    Side effects:
        Appends the mean training loss of every epoch to ``stage2_loss.txt``;
        every 5 epochs saves the model ``state_dict`` and appends the top-1 /
        top-5 accuracy on the CIFAR10 test split to ``stage2_top1_acc.txt`` /
        ``stage2_top5_acc.txt``.  All files live in ``save_path``.
    """
    checkpoint_path = getattr(args, "pre_model", None) or pre_model_path
    max_epoch = getattr(args, "max_epoch", 200)
    loader_batch_size = getattr(args, "batch_size", batch_size)

    if torch.cuda.is_available():
        # GPU 2 is the preferred device; fall back to GPU 0 when the host
        # exposes fewer than three devices instead of failing on "cuda:2".
        gpu_index = 2 if torch.cuda.device_count() > 2 else 0
        DEVICE = torch.device("cuda:" + str(gpu_index))
        # Useful when the computation graph barely changes between
        # iterations: cuDNN picks a fast algorithm (e.g. for convolutions)
        # up front.
        torch.backends.cudnn.benchmark = True
    else:
        DEVICE = torch.device("cpu")
    print("current deveice:", DEVICE)

    # Load datasets for training and evaluation.
    train_dataset = CIFAR10(root='dataset', train=True, transform=train_transform, download=True)
    train_data = DataLoader(train_dataset, batch_size=loader_batch_size, shuffle=True, num_workers=16, pin_memory=True)
    eval_dataset = CIFAR10(root='dataset', train=False, transform=test_transform, download=True)
    eval_data = DataLoader(eval_dataset, batch_size=loader_batch_size, shuffle=False, num_workers=16, pin_memory=True)

    model = SimCLRStage2(num_class=len(train_dataset.classes)).to(DEVICE)
    # strict=False because the stage-one checkpoint carries the projection
    # head ("g.*") and has no classifier ("fc.*") weights.
    load_result = model.load_state_dict(torch.load(checkpoint_path, map_location='cpu'), strict=False)
    missing_encoder_keys = [key for key in load_result.missing_keys if key.startswith("f.")]
    if missing_encoder_keys:
        # A non-strict load would otherwise silently leave the frozen encoder random.
        print("warning:", len(missing_encoder_keys), "encoder tensors were not found in", checkpoint_path)

    loss_criterion = torch.nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(model.fc.parameters(), lr=1e-3, weight_decay=1e-6)

    # The loader keeps the last partial batch, so len(train_data) is the
    # number of loss values summed per epoch.
    num_batches = max(len(train_data), 1)

    os.makedirs(save_path, exist_ok=True)
    for epoch in range(1, max_epoch + 1):
        # NOTE(review): train() also puts the frozen encoder's BatchNorm layers
        # in training mode, so their running statistics keep changing -
        # confirm this is intended.
        model.train()
        total_loss = 0.0
        for batch, (data, target) in enumerate(train_data):
            data, target = data.to(DEVICE), target.to(DEVICE)
            pred = model(data)

            loss = loss_criterion(pred, target)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            total_loss += loss.item()

        epoch_loss = total_loss / num_batches
        print("epoch", epoch, "loss:", epoch_loss)
        with open(os.path.join(save_path, "stage2_loss.txt"), "a") as f:
            f.write(str(epoch_loss) + " ")

        if epoch % 5 == 0:
            torch.save(model.state_dict(), os.path.join(save_path, 'model_stage2_epoch' + str(epoch) + '.pth'))

            # Evaluate on the held-out test split (not on the training loader).
            top1_percent, top5_percent = _evaluate_topk(model, eval_data, DEVICE)
            print("all eval dataset:", "top1 acc: {:02.3f}%".format(top1_percent),
                  "top5 acc:{:02.3f}%".format(top5_percent))
            with open(os.path.join(save_path, "stage2_top1_acc.txt"), "a") as f:
                f.write(str(top1_percent) + " ")
            with open(os.path.join(save_path, "stage2_top5_acc.txt"), "a") as f:
                f.write(str(top5_percent) + " ")


def _evaluate_topk(model, loader, device):
    """Print per-batch top-1/top-5 accuracy and return the overall percentages."""
    model.eval()
    total_correct_1, total_correct_5, total_num = 0.0, 0.0, 0
    with torch.no_grad():
        print("batch", " " * 1, "top1 acc", " " * 1, "top5 acc")
        for batch, (data, target) in enumerate(loader):
            data, target = data.to(device), target.to(device)
            pred = model(data)

            total_num += data.size(0)
            # Class indices sorted by descending score, one row per sample.
            prediction = torch.argsort(pred, dim=-1, descending=True)
            hits = prediction == target.unsqueeze(dim=-1)
            top1_acc = torch.sum(hits[:, 0:1].any(dim=-1).float()).item()
            top5_acc = torch.sum(hits[:, 0:5].any(dim=-1).float()).item()
            total_correct_1 += top1_acc
            total_correct_5 += top5_acc

            print("  {:02}  ".format(batch + 1), " {:02.3f}%  ".format(top1_acc / data.size(0) * 100),
                  "{:02.3f}%  ".format(top5_acc / data.size(0) * 100))

    # max() avoids a division by zero on an empty loader.
    seen = max(total_num, 1)
    return total_correct_1 / seen * 100, total_correct_5 / seen * 100


if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='Train SimCLR')
    parser.add_argument('--batch_size', default=200, type=int, help='')
    parser.add_argument('--max_epoch', default=200, type=int, help='')
    parser.add_argument('--pre_model', default=pre_model_path, type=str, help='')

    # parse_known_args tolerates the extra argv entries a notebook kernel
    # injects, which would make parse_args() exit.
    args, _unknown = parser.parse_known_args()
    train_stage2(args)
