1. 通过对比学习训练特征提取器
2. 在分类任务上进行有监督训练

In [None]:
import setpath
from torch.utils.data import DataLoader
from prepare.eegdataset import N_Mix_GeneralEEGImageDataset, MySubset
import numpy as np
import matplotlib.pyplot as plt
import torch
import torchvision.transforms as transforms
from run.resnet import TesNet
import torch.nn.functional as F
from run.start import get_device
from prepare.show import plot_tsne
from prepare.show import get_material_dir


In [None]:
# Data pipeline for the contrastive pre-training below.
data_path = '/data0/tianjunchao/dataset/CVPR2021-02785/data/img_pkl/32x32'

# Normalise each of the 3 channels with the same mean (0.512) and std (0.228).
train_transforms = transforms.Compose(
    [transforms.Normalize(mean=[0.512] * 3, std=[0.228] * 3)]
)

dataset = N_Mix_GeneralEEGImageDataset(
    path=data_path,
    n_channels=1,
    grid_size=8,
    n_samples=1,
)
# Re-wrap over every index; MySubset presumably applies train_transforms to
# each item (its implementation is not visible here).
dataset = MySubset(dataset, range(len(dataset)), train_transforms)

dataloader = DataLoader(
    dataset,
    batch_size=128,
    shuffle=True,
    num_workers=3,
    prefetch_factor=2,
)
device = get_device()

In [None]:
# Backbone-backed TesNet on the selected device, optimised with Adam (lr = 1e-4).
model = TesNet(model_name='resnet50')
model = model.to(device)  # nn.Module.to returns the module itself
optimizer = torch.optim.Adam(params=model.parameters(), lr=1e-4)

In [None]:
def nt_xent_loss(emb1, emb2, temperature=0.5):
    """Compute the NT-Xent (normalized temperature-scaled cross-entropy) loss.

    ``emb1[i]`` and ``emb2[i]`` form a positive pair. Every other embedding in
    the concatenated batch of ``2 * batch_size`` rows acts as a negative.
    For each anchor ``i`` the loss is
    ``-log(exp(s_ip / t) / sum_{k != i} exp(s_ik / t))``, where ``s`` is cosine
    similarity, ``p`` is the anchor's counterpart in the other view and ``t``
    is ``temperature``. The result is averaged over all ``2 * batch_size``
    anchors.

    Args:
        emb1: Embeddings of the first view, 2-D ``(batch_size, dim)``.
        emb2: Embeddings of the second view, same shape as ``emb1``.
        temperature: Temperature dividing the cosine similarities.

    Returns:
        Scalar tensor holding the mean loss over the ``2 * batch_size`` anchors.
    """
    batch_size = emb1.size(0)
    emb = F.normalize(torch.cat([emb1, emb2], dim=0), dim=1)

    # Temperature-scaled cosine similarity between every pair of the 2N rows.
    logits = torch.mm(emb, emb.t()) / temperature

    # Exclude self-similarity: exp(-inf) == 0 drops it from the softmax
    # denominator, so a sample is never compared against itself.
    self_mask = torch.eye(2 * batch_size, dtype=torch.bool, device=logits.device)
    logits = logits.masked_fill(self_mask, float('-inf'))

    # Row i's positive is its counterpart in the other view:
    # column i + N for the first half, column i - N for the second half.
    idx = torch.arange(batch_size, device=logits.device)
    targets = torch.cat([idx + batch_size, idx])

    # cross_entropy computes -log(softmax(logits)[target]) per row and averages
    # over the 2N rows. It uses log-sum-exp internally, so no manual
    # max-subtraction is needed for numerical stability.
    return F.cross_entropy(logits, targets)

In [None]:
# Contrastive pre-training: each batch yields two views per sample, both views
# go through the model, and NT-Xent pulls the paired embeddings together.
epochs = 100
# Log roughly 5 times per epoch. max(1, ...) avoids a ZeroDivisionError when the
# loader has fewer than 5 batches.
log_every = max(1, len(dataloader) // 5)
for epoch in range(epochs):
    model.train()
    for i, (x, y) in enumerate(dataloader):  # y (labels) is unused by this objective
        optimizer.zero_grad()

        # x is a sequence of equally shaped batches; x[0] and x[1] are the two
        # views whose embeddings form the positive pairs.
        _, emb1 = model(x[0].to(device))
        _, emb2 = model(x[1].to(device))
        loss = nt_xent_loss(emb1, emb2)

        loss.backward()
        optimizer.step()

        if i % log_every == 0:
            print(f'epoch: {epoch}, loss:{loss.item():.3f}')

    # Per-epoch t-SNE of the learned features.
    # eval() keeps normalisation layers (if any) from updating running
    # statistics during this extra pass, and no_grad() skips autograd
    # bookkeeping. model.train() at the top of the next epoch restores training
    # mode; the model is left in eval mode after the final epoch.
    # The label now matches the resnet50 backbone built for this model.
    model.eval()
    with torch.no_grad():
        plot_tsne(
            model.feature_extractor,
            dataloader,
            device,
            'resnet50',
            target='train epoch' + str(epoch),
            n_samples=1000,
            material_dir=get_material_dir(),
        )

