In [None]:
import torch
import torch.backends.cudnn as cudnn
from torchvision import models
from data_aug.contrastive_learning_dataset import ContrastiveLearningDataset
from models.resnet_simclr import ResNetSimCLR
from simclr import SimCLR

In [None]:
import argparse
# Hyper-parameters and runtime switches for this demo, collected in a single
# namespace so they can be passed around as one object.
args = argparse.Namespace(
    # Prefer the GPU when one is visible, otherwise fall back to the CPU.
    device=torch.device("cuda" if torch.cuda.is_available() else "cpu"),
    batch_size=4,
    n_views=2,
    temperature=0.07,
    fp16_precision=False,
    arch='resnet18',
    log_every_n_steps=100,
    epochs=1,
    disable_cuda=not torch.cuda.is_available(),
)
# Bare expression so the notebook echoes the resulting configuration.
args

In [None]:
# Load the "stl10" dataset through the project's contrastive wrapper, asking
# for args.n_views views per item, and batch it for training.
dataset = ContrastiveLearningDataset("./datasets")
train_dataset = dataset.get_dataset("stl10", args.n_views)
train_loader = torch.utils.data.DataLoader(
    train_dataset,
    batch_size=args.batch_size,
    shuffle=True,
    # Pinned host memory only speeds up host-to-GPU copies; requesting it on a
    # CPU-only machine merely makes PyTorch emit a warning.
    pin_memory=args.device.type == "cuda",
    # Drop the trailing partial batch so every batch holds exactly
    # args.batch_size samples.
    drop_last=True,
)

In [None]:
from IPython.display import display
from torchvision.transforms.functional import to_pil_image

# Fetch the sample a single time. Every `train_dataset[0]` lookup re-runs the
# dataset's __getitem__ (and, presumably, its random augmentations), so reusing
# one fetch guarantees the printed shapes and the displayed images all belong
# to the same pair of views.
sample = train_dataset[0]
sample_views, sample_label = sample[0], sample[1]

# Positive pair: the same image passed through two random data augmentations.
for view in sample_views:
    print(view.shape)
# The data is unlabeled, so the label carries no meaning.
print(sample_label)

# Convert each view tensor to a PIL image and render it inline.
for view in sample_views:
    image_pil = to_pil_image(view)
    display(image_pil)

In [None]:
# Build the encoder from the configured architecture instead of repeating the
# architecture string here, so the model and `args.arch` cannot drift apart.
model = ResNetSimCLR(base_model=args.arch, out_dim=128)

optimizer = torch.optim.Adam(model.parameters(), lr=0.0003, weight_decay=1e-4)

# Cosine-anneal the learning rate towards 0 with a period of
# len(train_loader) scheduler steps.
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=len(train_loader), eta_min=0, last_epoch=-1)

# Pick the CUDA device index; -1 marks a CPU-only run.
if torch.cuda.is_available():
    device, gpu_index = torch.device('cuda'), 0
else:
    device, gpu_index = torch.device('cpu'), -1

# It's a no-op if the 'gpu_index' argument is a negative integer or None.
with torch.cuda.device(gpu_index):
    simclr = SimCLR(model=model, optimizer=optimizer, scheduler=scheduler, args=args)
    simclr.train(train_loader)

In [None]:
import torch.nn.functional as F
def info_nce_loss(features, n_views=None, temperature=None):
    """Build the InfoNCE logits and targets for a batch of embeddings.

    ``features`` must be laid out view-major: rows ``[0, B)`` hold the first
    view of each of the ``B`` samples, rows ``[B, 2B)`` the second view, and so
    on, so that rows ``i`` and ``i + k * B`` are views of the same sample.

    Args:
        features: 2-D tensor of shape ``(n_views * B, dim)``. The batch size
            ``B`` is inferred from the tensor itself and all intermediate
            tensors are created on ``features.device``.
        n_views: Number of views per sample. Defaults to ``args.n_views``.
        temperature: Softmax temperature the logits are divided by. Defaults
            to ``args.temperature``.

    Returns:
        A ``(logits, labels)`` tuple. ``logits`` has shape
        ``(n_views * B, n_views * B - 1)``; in each row the positive
        similarities come first, followed by the negatives. ``labels`` is an
        all-zero ``long`` tensor of length ``n_views * B``, i.e. the target
        class is column 0 (a positive).

    Raises:
        ValueError: If ``features`` is not 2-D, is empty, ``n_views`` is less
            than 2, or the number of rows is not a multiple of ``n_views``.
    """
    if n_views is None:
        n_views = args.n_views
    if temperature is None:
        temperature = args.temperature

    if features.dim() != 2:
        raise ValueError(
            f"features must be 2-D (rows, dim), got shape {tuple(features.shape)}"
        )
    total = features.shape[0]
    if n_views < 2 or total == 0 or total % n_views != 0:
        raise ValueError(
            f"expected a non-empty batch whose row count is a multiple of "
            f"n_views >= 2, got {total} rows with n_views={n_views}"
        )
    batch_size = total // n_views
    # Keep every helper tensor on the same device as the input; boolean-mask
    # indexing across devices raises an error.
    device = features.device

    # same_sample[i, j] is True when rows i and j are views of one sample.
    sample_ids = torch.arange(batch_size, device=device).repeat(n_views)
    same_sample = sample_ids.unsqueeze(0) == sample_ids.unsqueeze(1)

    # Cosine similarity between every pair of rows.
    features = F.normalize(features, dim=1)
    similarity_matrix = torch.matmul(features, features.T)

    # Discard the main diagonal (self-similarity) from both matrices.
    off_diagonal = ~torch.eye(total, dtype=torch.bool, device=device)
    same_sample = same_sample[off_diagonal].view(total, total - 1)
    similarity_matrix = similarity_matrix[off_diagonal].view(total, total - 1)

    # Each row has exactly n_views - 1 positives and total - n_views negatives.
    # Explicit column counts keep the reshape valid even when there are no
    # negatives at all (batch of one sample).
    positives = similarity_matrix[same_sample].view(total, n_views - 1)
    negatives = similarity_matrix[~same_sample].view(total, total - n_views)

    logits = torch.cat([positives, negatives], dim=1) / temperature
    labels = torch.zeros(total, dtype=torch.long, device=device)
    return logits, labels

In [None]:
# Smoke-test the loss helper on random embeddings. The tensor is created on
# args.device and sized from args.n_views so it matches the masks and labels
# that info_nce_loss builds.
features = torch.randn(args.batch_size * args.n_views, 128, device=args.device)
logits, labels = info_nce_loss(features)
print(logits)
print(labels)

In [None]:
# Step-by-step walkthrough of the InfoNCE computation, printing every
# intermediate tensor. The embeddings are created on args.device because the
# masks and labels below live there; mixing devices breaks mask indexing.
features = torch.randn(args.batch_size * args.n_views, 128, device=args.device)

# Sample id of every row: [0..B-1] repeated once per view.
labels = torch.cat([torch.arange(args.batch_size) for i in range(args.n_views)], dim=0)
print(labels.shape)
# labels[i, j] == 1 when rows i and j are views of the same sample.
labels = (labels.unsqueeze(0) == labels.unsqueeze(1)).float()
labels = labels.to(args.device)
print('labels:', labels.shape)
print(labels)
features = F.normalize(features, dim=1)

# Cosine similarity between every pair of rows.
similarity_matrix = torch.matmul(features, features.T)
assert similarity_matrix.shape == (
    args.n_views * args.batch_size, args.n_views * args.batch_size)
assert similarity_matrix.shape == labels.shape
print('similarity_matrix:', similarity_matrix.shape)
print(similarity_matrix)

# discard the main diagonal from both: labels and similarities matrix
mask = torch.eye(labels.shape[0], dtype=torch.bool).to(args.device)
print(mask)
labels = labels[~mask].view(labels.shape[0], -1)
print(labels)
similarity_matrix = similarity_matrix[~mask].view(similarity_matrix.shape[0], -1)
print(similarity_matrix)
assert similarity_matrix.shape == labels.shape

# select and combine multiple positives
positives = similarity_matrix[labels.bool()].view(labels.shape[0], -1)
print(positives)

# select only the negatives
negatives = similarity_matrix[~labels.bool()].view(similarity_matrix.shape[0], -1)
print(negatives)

# Positives go first, so the target class of every row is column 0.
logits = torch.cat([positives, negatives], dim=1)
labels = torch.zeros(logits.shape[0], dtype=torch.long).to(args.device)
print(logits)
print(labels)

# Temperature scaling is applied last (after the raw logits were printed).
logits = logits / args.temperature