Reference SimSiam implementation (this notebook adapts it from images to text):

https://github.com/facebookresearch/simsiam/blob/main/main_simsiam.py

In [1]:
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from transformers import BertModel, BertTokenizer
from RandomCompositeTransformation import RandomCompositeTransformation as CompositeAugmenter
import time
import random
from textattack.augmentation import Augmenter
from textattack.transformations import WordSwapWordNet,WordSwapEmbedding, WordInnerSwapRandom, WordDeletion, BackTranslation, WordSwapExtend, WordSwapRandomCharacterSubstitution, WordSwapHomoglyphSwap, WordSwapRandomCharacterInsertion
from textattack.constraints.pre_transformation import RepeatModification, StopwordModification
from textattack.constraints.semantics import WordEmbeddingDistance
import pandas as pd
import warnings
warnings.filterwarnings('ignore')

In [2]:
class TextDataset(Dataset):
    """Text dataset that tokenises on access, with optional pre-computed augmented view pairs.

    With ``augment=True`` two augmented views of every text are generated once,
    in the constructor, and ``__getitem__`` returns the pair of tokenised views.
    With ``augment=False`` ``__getitem__`` returns the tokenised original text,
    together with its label when ``labels`` is given.

    Args:
        texts: Sequence of raw strings.
        labels: Optional sequence of targets aligned with ``texts``.
        tokenizer: Callable invoked as ``tokenizer(text, padding="max_length",
            truncation=True, max_length=..., return_tensors="pt")`` returning a
            mapping of tensors; dimension 0 of every tensor is squeezed away.
        max_length: Padding / truncation length passed to the tokenizer.
        augment: Whether to build two augmented views per text.
        augmenter: Object exposing ``augment(text) -> list[str]``; required when
            ``augment`` is true.
        max_retries: Maximum number of extra attempts to obtain a second view
            that differs from the first one. After that many attempts the
            identical view is accepted, so texts the augmenter cannot modify
            no longer stall construction forever.

    Raises:
        ValueError: If ``augment`` is true but no ``augmenter`` is supplied.
    """

    def __init__(self, texts, labels=None, tokenizer=None, max_length=128, augment=False, augmenter=None,
                 max_retries=10):
        if augment and augmenter is None:
            raise ValueError("augment=True requires an augmenter")

        self.texts = texts
        self.labels = labels
        self.tokenizer = tokenizer
        self.max_length = max_length
        self.augment = augment
        self.augmenter = augmenter
        self.max_retries = max_retries

        if self.augment:
            self.text_views1 = []
            self.text_views2 = []
            for index, text in enumerate(self.texts, start=1):
                start_time = time.time()
                # Generate two augmented views for each text.
                augmented_text1 = self.text_augment(text)
                augmented_text2 = self.text_augment(text)

                # Re-draw the second view while it duplicates the first one, but
                # only a bounded number of times: an augmenter that has nothing
                # to change keeps returning the same string.
                retries = 0
                while augmented_text1 == augmented_text2 and retries < self.max_retries:
                    print('augmented_text1 equals to augmented_text2', augmented_text1)
                    augmented_text2 = self.text_augment(text)
                    retries += 1

                self.text_views1.append(augmented_text1)
                self.text_views2.append(augmented_text2)

                spent = time.time() - start_time
                print('process text', index, 'cost', spent, 'seconds')

    def __len__(self):
        """Return the number of texts."""
        return len(self.texts)

    def _encode(self, text):
        """Tokenise one string and drop dimension 0 of every returned tensor."""
        encoded = self.tokenizer(
            text,
            padding="max_length",
            truncation=True,
            max_length=self.max_length,
            return_tensors="pt",
        )
        return {key: val.squeeze(0) for key, val in encoded.items()}

    def __getitem__(self, idx):
        """Return ``(view1, view2)`` when augmenting, else ``(inputs, label)`` or ``inputs``."""
        if self.augment:
            return self._encode(self.text_views1[idx]), self._encode(self.text_views2[idx])

        inputs = self._encode(self.texts[idx])
        if self.labels is not None:
            return inputs, self.labels[idx]
        return inputs

    def text_augment(self, text):
        """Return one randomly chosen augmentation of ``text``.

        Falls back to ``text`` itself when the augmenter yields no candidates,
        instead of letting ``random.choice`` fail on an empty list.
        """
        augmented_texts = self.augmenter.augment(text)
        if not augmented_texts:
            return text
        return random.choice(augmented_texts)

class SimSiamText(nn.Module):
    """SimSiam-style siamese model over a text encoder, plus a linear regression head.

    Args:
        base_encoder: Module called as ``base_encoder(**inputs)`` whose result
            exposes a ``"pooler_output"`` entry. Its width is read from
            ``base_encoder.config.hidden_size`` when available and defaults to
            768 otherwise.
        dim: Output width of the projector.
        pred_dim: Bottleneck width of the predictor.
    """

    def __init__(self, base_encoder, dim=2048, pred_dim=512):
        super(SimSiamText, self).__init__()
        self.encoder = base_encoder

        # Derive the encoder width instead of hard-coding 768, so encoders of
        # other sizes work too; 768 stays the fallback.
        encoder_config = getattr(base_encoder, "config", None)
        hidden_size = getattr(encoder_config, "hidden_size", 768)

        self.projector = nn.Sequential(
            nn.Linear(hidden_size, dim, bias=False),
            nn.BatchNorm1d(dim),
            nn.ReLU(inplace=True),
            nn.Linear(dim, dim, bias=False),
            nn.BatchNorm1d(dim),
            nn.ReLU(inplace=True),
            nn.Linear(dim, dim, bias=True),
        )
        self.predictor = nn.Sequential(
            nn.Linear(dim, pred_dim, bias=False),
            nn.BatchNorm1d(pred_dim),
            nn.ReLU(inplace=True),
            nn.Linear(pred_dim, dim, bias=True),
        )
        self.regressor = nn.Linear(hidden_size, 1)  # linear regression head

    def _pooled(self, inputs):
        """Run the encoder on a dict of keyword inputs and return its pooled output."""
        return self.encoder(**inputs)["pooler_output"]

    def forward(self, x1, x2=None, regression=False):
        """Forward pass.

        Args:
            x1: Keyword inputs for the encoder (first view, or the only input
                in regression mode).
            x2: Keyword inputs for the second augmented view of the same batch
                of texts; required unless ``regression`` is true.
            regression: When true, return the regression head output for ``x1``.

        Returns:
            In regression mode, the regressor output for ``x1``. Otherwise
            ``(p1, p2, z1, z2)`` where ``z*`` are the projections of the two
            views, detached because they only serve as targets, and ``p*`` are
            the predictor outputs for them.

        Raises:
            ValueError: If ``x2`` is missing outside regression mode.
            RuntimeError: If the projection heads were removed and a
                non-regression pass is requested.
        """
        if regression:
            return self.regressor(self._pooled(x1))  # regression only
        if x2 is None:
            raise ValueError("x2 is required unless regression=True")
        if self.projector is None or self.predictor is None:
            raise RuntimeError("projection heads were removed; only regression=True is supported")

        z1 = self.projector(self._pooled(x1))
        z2 = self.projector(self._pooled(x2))
        p1 = self.predictor(z1)
        p2 = self.predictor(z2)
        # No gradient flows through the targets z1 / z2.
        return p1, p2, z1.detach(), z2.detach()

    def remove_projection_head(self):
        """Drop the projector and predictor so that only encoder and regressor remain for fine-tuning."""
        self.projector = None
        self.predictor = None


def train_simsiam(model, dataloader, criterion, optimizer, device, print_freq=10, epochs=1):
    """Pre-train a SimSiam model and print progress and losses.

    Args:
        model: Module called as ``model(inputs1, inputs2)`` returning
            ``(p1, p2, z1, z2)``.
        dataloader: Sized iterable of ``(inputs1, inputs2)`` batches, each a
            dict of tensors.
        criterion: Similarity function applied as ``criterion(p, z)``; the loss
            is the negated mean similarity, averaged over both directions.
        optimizer: Optimizer updating the model parameters.
        device: Device the batch tensors are moved to.
        print_freq: A step log line is printed every ``print_freq`` batches.
        epochs: Number of passes over ``dataloader``. Defaults to 1, which is
            what the loop previously hard-coded while the log lines claimed 10.
    """
    model.train()
    num_batches = len(dataloader)

    for epoch in range(epochs):
        epoch_loss = 0.0
        start_time = time.time()

        for batch_idx, batch in enumerate(dataloader):
            inputs1 = {key: val.to(device) for key, val in batch[0].items()}
            inputs2 = {key: val.to(device) for key, val in batch[1].items()}

            p1, p2, z1, z2 = model(inputs1, inputs2)
            # Symmetric negative similarity between each prediction and the
            # other view's target.
            loss = -(criterion(p1, z2).mean() + criterion(p2, z1).mean()) * 0.5
            loss_value = loss.item()
            epoch_loss += loss_value

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            if batch_idx % print_freq == 0:
                print(f"Epoch [{epoch+1}/{epochs}], Step [{batch_idx}/{num_batches}], Loss: {loss_value:.4f}")

        # An empty dataloader reports NaN instead of raising ZeroDivisionError.
        avg_loss = epoch_loss / num_batches if num_batches else float("nan")
        elapsed_time = time.time() - start_time
        print(f"Epoch [{epoch+1}/{epochs}] Completed | Avg Loss: {avg_loss:.4f} | Time: {elapsed_time:.2f}s\n")

def fine_tune(model, dataloader, criterion, optimizer, device, print_freq=10, epochs=2):
    """Fine-tune the model on a regression task and print progress and losses.

    Args:
        model: Module called as ``model(inputs, regression=True)``; its output
            is expected to carry a trailing dimension of size 1.
        dataloader: Sized iterable of ``(inputs, labels)`` batches, where
            ``inputs`` is a dict of tensors and ``labels`` a tensor.
        criterion: Regression loss applied as ``criterion(preds, labels)``.
        optimizer: Optimizer updating the model parameters.
        device: Device the batch tensors are moved to.
        print_freq: A step log line is printed every ``print_freq`` batches.
        epochs: Number of passes over ``dataloader``. Defaults to 2, which is
            what the loop previously hard-coded while the log lines claimed 5.
    """
    model.train()
    num_batches = len(dataloader)

    for epoch in range(epochs):
        epoch_loss = 0.0
        start_time = time.time()

        for batch_idx, batch in enumerate(dataloader):
            inputs = {key: val.to(device) for key, val in batch[0].items()}
            labels = batch[1].to(device).float()

            # Squeeze only the trailing dimension: a bare squeeze() would also
            # collapse a batch of size 1 to a 0-d tensor that no longer matches
            # the shape of ``labels``.
            preds = model(inputs, regression=True).squeeze(-1)
            loss = criterion(preds, labels)
            loss_value = loss.item()
            epoch_loss += loss_value

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            if batch_idx % print_freq == 0:
                print(f"Fine-tune Epoch [{epoch+1}/{epochs}], Step [{batch_idx}/{num_batches}], Loss: {loss_value:.4f}")

        # An empty dataloader reports NaN instead of raising ZeroDivisionError.
        avg_loss = epoch_loss / num_batches if num_batches else float("nan")
        elapsed_time = time.time() - start_time
        print(f"Fine-tune Epoch [{epoch+1}/{epochs}] Completed | Avg Loss: {avg_loss:.4f} | Time: {elapsed_time:.2f}s\n")

def evaluate(model, dataloader, device):
    """Compute, print and return the mean absolute error of the model on a dataset.

    Args:
        model: Module called as ``model(inputs, regression=True)``; its output
            is expected to carry a trailing dimension of size 1.
        dataloader: Iterable of ``(inputs, labels)`` batches, where ``inputs``
            is a dict of tensors and ``labels`` a tensor.
        device: Device the batch tensors are moved to.

    Returns:
        The MAE averaged over all samples as a float, or NaN when the
        dataloader yields no samples.
    """
    model.eval()
    total_mae = 0.0
    num_samples = 0

    with torch.no_grad():
        for batch in dataloader:
            inputs = {key: val.to(device) for key, val in batch[0].items()}
            labels = batch[1].to(device).float()

            # Squeeze only the trailing dimension so a batch of size 1 keeps
            # the same shape as ``labels``.
            preds = model(inputs, regression=True).squeeze(-1)
            total_mae += torch.abs(preds - labels).sum().item()  # sum of absolute errors
            num_samples += labels.size(0)

    # No samples: report NaN instead of raising ZeroDivisionError.
    avg_mae = total_mae / num_samples if num_samples else float("nan")
    print(f"Evaluation - MAE: {avg_mae:.4f}")
    return avg_mae

# Main entry point
def main():
    """Run the full pipeline: SimSiam pre-training, regression fine-tuning, evaluation.

    Reads ``data/mes_all.csv`` and uses its ``title``, ``description`` and
    ``storypoint`` columns. Rows whose ``storypoint`` is -1 are treated as
    unlabelled: they take part in pre-training only.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
    bert_model = BertModel.from_pretrained("bert-base-uncased")

    # Individual augmentation transformations
    wordnet_transformation = WordSwapWordNet()
    #embedding_transformation = WordSwapEmbedding(max_candidates=5)
    backtranslate_transformation = BackTranslation(chained_back_translation=2)
    #extendword_transformation = WordSwapExtend()
    randomwordsubs_transformation = WordSwapRandomCharacterSubstitution()
    homoglyphswap_transformation = WordSwapHomoglyphSwap()
    randomcharinsert_transformation = WordSwapRandomCharacterInsertion()

    # Combine several transformations, each with its own execution probability
    random_composite_transformation = CompositeAugmenter(
        transformations=[
            backtranslate_transformation,
            homoglyphswap_transformation,
            wordnet_transformation,
            randomwordsubs_transformation,
            randomcharinsert_transformation
        ],
        probabilities=[1, 0.5, 0.5, 0.1, 0.1]  # execution probabilities
    )

    # Constraints: prevent modifying the same word twice
    # (StopwordModification() is intentionally left out here)
    constraints = [RepeatModification()]

    # Semantic similarity constraint (disabled)
    #semantic_constraint = WordEmbeddingDistance(min_cos_sim=0.8)
    #constraints.append(semantic_constraint)

    # Build the augmenter
    text_augmenter = Augmenter(
        transformation=random_composite_transformation,
        constraints=constraints,
        pct_words_to_swap=0.1,
        transformations_per_example=3  # generate 3 augmented versions per text
    )

    # Load the data
    data = pd.read_csv('data/mes_all.csv')
    # Fill missing fields first: concatenating with NaN would make the whole
    # text NaN and break augmentation and tokenisation.
    data['text'] = data['title'].fillna('') + ' ' + data['description'].fillna('')
    texts = data['text'].tolist()

    # Pre-training dataset
    dataset = TextDataset(texts, tokenizer=tokenizer, augment=True, max_length=512, augmenter=text_augmenter)
    dataloader = DataLoader(dataset, batch_size=32, shuffle=True)

    # Initialise the model
    simsiam = SimSiamText(base_encoder=bert_model).to(device)

    # Pre-train SimSiam
    criterion = nn.CosineSimilarity(dim=1).to(device)
    optimizer = torch.optim.Adam(simsiam.parameters(), lr=3e-4)
    train_simsiam(simsiam, dataloader, criterion, optimizer, device)

    # Remove projector and predictor before fine-tuning
    simsiam.remove_projection_head()

    # Split the labelled rows 80/20 by position
    labeled_data = data[data['storypoint'] != -1]
    split_idx = int(len(labeled_data) * 0.8)

    fine_tune_data = labeled_data.iloc[:split_idx]

    fine_tune_texts = fine_tune_data['text'].tolist()
    fine_tune_labels = fine_tune_data['storypoint'].tolist()

    # Fine-tuning (regression task)
    labeled_dataset = TextDataset(fine_tune_texts, fine_tune_labels, tokenizer=tokenizer, augment=False)
    labeled_dataloader = DataLoader(labeled_dataset, batch_size=32, shuffle=True)

    regression_criterion = nn.MSELoss().to(device)
    optimizer = torch.optim.Adam(simsiam.parameters(), lr=3e-4)
    fine_tune(simsiam, labeled_dataloader, regression_criterion, optimizer, device)

    test_data = labeled_data.iloc[split_idx:]
    test_texts = test_data['text'].tolist()
    test_labels = test_data['storypoint'].tolist()

    test_dataset = TextDataset(test_texts, test_labels, tokenizer=tokenizer, augment=False)
    test_dataloader = DataLoader(test_dataset, batch_size=32, shuffle=False)

    # Evaluate the model
    evaluate(simsiam, test_dataloader, device)

In [3]:
# main() is left disabled; the cells below run the pipeline step by step,
# with a different augmentation mix than the one configured inside main().
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
bert_model = BertModel.from_pretrained("bert-base-uncased")
# Individual augmentation transformations (WordNet swap and back-translation are disabled here)
#wordnet_transformation = WordSwapWordNet()
embedding_transformation = WordSwapEmbedding(max_candidates=3)
#backtranslate_transformation = BackTranslation() # chained_back_translation=2
# NOTE(review): worddelete_transformation is created but not passed to the composite below — confirm whether that is intended.
worddelete_transformation = WordDeletion()
extendword_transformation = WordSwapExtend()
randomwordsubs_transformation = WordSwapRandomCharacterSubstitution()
homoglyphswap_transformation = WordSwapHomoglyphSwap()
randomcharinsert_transformation = WordSwapRandomCharacterInsertion()

# Combine several transformations and give each an execution probability
# (presumably matched to `transformations` by position — TODO confirm against RandomCompositeTransformation)
random_composite_transformation = CompositeAugmenter(
    transformations=[
        extendword_transformation,
        homoglyphswap_transformation, 
        embedding_transformation, 
        randomwordsubs_transformation,
        randomcharinsert_transformation,
    ],
    probabilities=[0.8, 0.5, 0.5, 0.1, 0.1]  # execution probabilities
)

# Constraints: do not modify stopwords, and do not modify the same word twice
constraints = [RepeatModification(), StopwordModification()]

# Semantic similarity constraint (disabled)
#semantic_constraint = WordEmbeddingDistance(min_cos_sim=0.8)
#constraints.append(semantic_constraint)

# Build the augmenter
text_augmenter = Augmenter(
    transformation=random_composite_transformation,
    constraints=constraints,
    pct_words_to_swap=0.1,
    transformations_per_example=3  # generate n augmented versions per text
)

In [4]:
# Load the data
data = pd.read_csv('./data/mes_all.csv')
# Fill missing descriptions so that they alone do not cause a row to be dropped below
data['description'] = data['description'].fillna('')
# Drop rows that still have a missing value in any column (mutates `data` in place)
data.dropna(inplace=True)
# NOTE(review): only the title is used as the model text; the description is filled above but not used here — confirm intent.
data['text'] = data['title']
texts = data['text'].values.tolist()

In [5]:
# Pre-training dataset: with augment=True the constructor generates both augmented
# views of every text up front and prints one timing line per text
dataset = TextDataset(texts, tokenizer=tokenizer, augment=True, max_length=128, augmenter=text_augmenter)

process text 1 cost 0.008417606353759766 seconds
process text 2 cost 0.012782573699951172 seconds
process text 3 cost 0.0042765140533447266 seconds
process text 4 cost 0.006949186325073242 seconds
process text 5 cost 0.004137754440307617 seconds
augmented_text1 equals to augmented_text2 View Application Print list
process text 6 cost 0.009031295776367188 seconds
process text 7 cost 0.0028998851776123047 seconds
augmented_text1 equals to augmented_text2 Filter usеrs
process text 8 cost 0.003038167953491211 seconds
process text 9 cost 0.00994420051574707 seconds
process text 10 cost 0.005281209945678711 seconds
process text 11 cost 0.004038810729980469 seconds
process text 12 cost 0.013895034790039062 seconds
process text 13 cost 0.006384849548339844 seconds
process text 14 cost 0.0030281543731689453 seconds
process text 15 cost 0.0033349990844726562 seconds
process text 16 cost 0.0 seconds
process text 17 cost 0.01142263412475586 seconds
process text 18 cost 0.0021080970764160156 second

In [6]:
# Shuffled batches of 32 (view1, view2) pairs for pre-training
dataloader = DataLoader(dataset, batch_size=32, shuffle=True)

In [7]:
# Initialise the model around the BERT encoder loaded earlier
simsiam = SimSiamText(base_encoder=bert_model).to(device)

# Pre-train SimSiam: cosine similarity along the feature dimension is the similarity measure
criterion = nn.CosineSimilarity(dim=1).to(device)
optimizer = torch.optim.Adam(simsiam.parameters(), lr=3e-5)
# Release cached, currently unused GPU memory before the training run
torch.cuda.empty_cache()
train_simsiam(simsiam, dataloader, criterion, optimizer, device)

Epoch [1/10], Step [0/146], Loss: -0.0018
Epoch [1/10], Step [10/146], Loss: -0.1428
Epoch [1/10], Step [20/146], Loss: -0.2051
Epoch [1/10], Step [30/146], Loss: -0.3784
Epoch [1/10], Step [40/146], Loss: -0.4995
Epoch [1/10], Step [50/146], Loss: -0.6043
Epoch [1/10], Step [60/146], Loss: -0.6651
Epoch [1/10], Step [70/146], Loss: -0.6484
Epoch [1/10], Step [80/146], Loss: -0.7270
Epoch [1/10], Step [90/146], Loss: -0.7281
Epoch [1/10], Step [100/146], Loss: -0.7923
Epoch [1/10], Step [110/146], Loss: -0.8425
Epoch [1/10], Step [120/146], Loss: -0.8417
Epoch [1/10], Step [130/146], Loss: -0.8392
Epoch [1/10], Step [140/146], Loss: -0.8538
Epoch [1/10] Completed | Avg Loss: -0.6105 | Time: 347.76s



In [8]:
# Remove the projector and predictor before fine-tuning.
# NOTE(review): this mutates `simsiam`, so pre-training cannot be re-run on it without rebuilding the model.
simsiam.remove_projection_head()

# Split the labelled rows (storypoint != -1) into fine-tuning and test parts
labeled_data = data[data['storypoint'] != -1]
split_idx = int(len(labeled_data) * 0.8)

# NOTE(review): positional 80/20 split without shuffling — confirm the row order carries no bias.
fine_tune_data = labeled_data.iloc[:split_idx]

fine_tune_texts = fine_tune_data['text'].values.tolist()
fine_tune_labels = fine_tune_data['storypoint'].values.tolist()

# Fine-tuning (regression task) on the original, non-augmented texts
labeled_dataset = TextDataset(fine_tune_texts, fine_tune_labels, tokenizer=tokenizer, augment=False)
labeled_dataloader = DataLoader(labeled_dataset, batch_size=32, shuffle=True)

regression_criterion = nn.MSELoss().to(device)
# Fresh optimizer over the parameters the model still holds (encoder and regressor after head removal)
optimizer = torch.optim.Adam(simsiam.parameters(), lr=3e-4)
fine_tune(simsiam, labeled_dataloader, regression_criterion, optimizer, device)

Fine-tune Epoch [1/5], Step [0/30], Loss: 14.0251
Fine-tune Epoch [1/5], Step [10/30], Loss: 3.0063
Fine-tune Epoch [1/5], Step [20/30], Loss: 3.6303
Fine-tune Epoch [1/5] Completed | Avg Loss: 5.4145 | Time: 18.83s

Fine-tune Epoch [2/5], Step [0/30], Loss: 2.6315
Fine-tune Epoch [2/5], Step [10/30], Loss: 7.3142
Fine-tune Epoch [2/5], Step [20/30], Loss: 3.5442
Fine-tune Epoch [2/5] Completed | Avg Loss: 4.8970 | Time: 18.96s



In [9]:
# Held-out part: the labelled rows from split_idx onwards
test_data = labeled_data.iloc[split_idx:]
test_texts = test_data['text'].values.tolist()
test_labels = test_data['storypoint'].values.tolist()

test_dataset = TextDataset(test_texts, test_labels, tokenizer=tokenizer, augment=False)
test_dataloader = DataLoader(test_dataset, batch_size=32, shuffle=False)

# Evaluate the model (prints the MAE on the test rows)
evaluate(simsiam, test_dataloader, device)

Evaluation - MAE: 1.5518
