Reference SimSiam implementation:

https://github.com/facebookresearch/simsiam/blob/main/main_simsiam.py

In [1]:
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from transformers import BertModel, BertTokenizer
from RandomCompositeTransformation import RandomCompositeTransformation as CompositeAugmenter
import time
import random
from textattack.augmentation import Augmenter
from textattack.transformations import WordSwapWordNet,WordSwapEmbedding, WordInnerSwapRandom, WordDeletion, BackTranslation, WordSwapExtend, WordSwapRandomCharacterSubstitution, WordSwapHomoglyphSwap, WordSwapRandomCharacterInsertion
from textattack.constraints.pre_transformation import RepeatModification, StopwordModification
from textattack.constraints.semantics import WordEmbeddingDistance
import pandas as pd
import warnings
warnings.filterwarnings('ignore')

In [None]:
# ---- global parameters ----
# Epoch counts for SimSiam pre-training and for regression fine-tuning.
TRAIN_EPOCH, FINE_TUNE_EPOCH = 10, 5
# Batch sizes of the two phases.
TRAIN_BATCH = FINE_TUNE_BATCH = 16
LEARNING_RATE = 2e-5

# Dataset name (CSV file stem) and the checkpoint path derived from it.
PROJECT = 'mes_all'
SAVE_PATH = f'/kaggle/working/models/best_model_{PROJECT}.pth'

# Text processing.
PRINT_TEXT_PROCESS_TIME = True  # print how long augmenting each text took
AUGMENT_TEXT_LEN = 200          # only this many leading characters are augmented
MAX_LEN = 512                   # tokenizer max_length

In [None]:
class TextDataset(Dataset):
    """Dataset of raw texts that are tokenized on access, with optional augmentation.

    What ``__getitem__`` returns depends on ``augment`` and ``labels``:

    * ``augment=True``: ``(inputs_view1, inputs_view2)``, two augmented views of
      the same text for SimSiam pre-training (``labels`` are ignored).
    * ``augment=False`` with ``labels``: ``(inputs, label)``.
    * ``augment=False`` without ``labels``: ``inputs``.

    Each ``inputs*`` is the dict returned by ``tokenizer(..., return_tensors="pt")``
    with the leading batch dimension of size 1 squeezed away, so a DataLoader can
    re-batch the items.
    """

    # Upper bound on re-augmentation attempts when the augmenter does not return
    # two distinct views; keeps text_augment from looping forever on texts the
    # augmenter cannot perturb.
    MAX_AUGMENT_RETRIES = 5

    def __init__(self, texts, labels=None, tokenizer=None, max_length=128, augment=False, augmenter=None):
        self.texts = texts
        self.labels = labels
        self.tokenizer = tokenizer
        self.max_length = max_length
        self.augment = augment
        self.augmenter = augmenter

    def __len__(self):
        return len(self.texts)

    def _tokenize(self, text):
        """Tokenize one text, padded/truncated to max_length, without the batch dim."""
        encoded = self.tokenizer(
            text,
            padding="max_length",
            truncation=True,
            max_length=self.max_length,
            return_tensors="pt",
        )
        return {key: val.squeeze(0) for key, val in encoded.items()}

    def __getitem__(self, idx):
        text = self.texts[idx]

        if self.augment:
            text_view1, text_view2 = self.text_augment(text)
            return self._tokenize(text_view1), self._tokenize(text_view2)

        inputs = self._tokenize(text)
        if self.labels is not None:
            return inputs, self.labels[idx]
        return inputs

    def text_augment(self, text):
        """Return two augmented views of ``text`` as a tuple of strings.

        Only the first ``AUGMENT_TEXT_LEN`` characters (the head) are passed to
        ``self.augmenter.augment``; the untouched remainder (the tail) is appended
        to both views exactly once.

        If the augmenter returns fewer than two texts, or two identical ones, the
        head is re-augmented up to ``MAX_AUGMENT_RETRIES`` times. Any view still
        missing afterwards falls back to the unaugmented head.
        """
        start_time = time.time()

        # Slicing yields head == text and tail == '' for short texts.
        head = text[:AUGMENT_TEXT_LEN]
        tail = text[AUGMENT_TEXT_LEN:]

        views = list(self.augmenter.augment(head))
        for _ in range(self.MAX_AUGMENT_RETRIES):
            if len(views) >= 2 and views[0] != views[1]:
                break
            print('augmented views missing or identical, retrying:', views)
            # Retry on the same head (augmenting the full text here would make
            # the tail appear twice in the returned views).
            views = list(self.augmenter.augment(head))

        # Fallback so two views always exist.
        while len(views) < 2:
            views.append(head)

        spent = time.time() - start_time
        if PRINT_TEXT_PROCESS_TIME:
            print('process text cost', spent, 'seconds')
        return (views[0] + tail, views[1] + tail)

class SimSiamText(nn.Module):
    """SimSiam model around a text encoder, plus a linear regression head.

    The pooled encoder output feeds a 3-layer projector MLP and a 2-layer
    predictor MLP. Projector outputs are detached before being returned, so
    they act only as targets (SimSiam's stop-gradient). ``regressor`` maps the
    pooled output to a single value for downstream fine-tuning.

    NOTE(review): unlike the official SimSiam builder, the projector has no
    BatchNorm after its last Linear layer.

    Args:
        base_encoder: module called as ``base_encoder(**inputs)`` whose result is
            indexable by ``"pooler_output"`` (e.g. a Hugging Face ``BertModel``).
        dim: projector output size, also the predictor output size.
        pred_dim: hidden (bottleneck) size of the predictor.
    """

    def __init__(self, base_encoder, dim=2048, pred_dim=512):
        super().__init__()
        self.encoder = base_encoder
        # Width of the encoder's pooled output. Hugging Face models expose it as
        # config.hidden_size; fall back to 768 (bert-base), the previous constant.
        hidden_size = getattr(getattr(base_encoder, "config", None), "hidden_size", 768)
        self.projector = nn.Sequential(
            nn.Linear(hidden_size, dim, bias=False),
            nn.BatchNorm1d(dim),
            nn.ReLU(inplace=True),
            nn.Linear(dim, dim, bias=False),
            nn.BatchNorm1d(dim),
            nn.ReLU(inplace=True),
            nn.Linear(dim, dim, bias=True),
        )
        self.predictor = nn.Sequential(
            nn.Linear(dim, pred_dim, bias=False),
            nn.BatchNorm1d(pred_dim),
            nn.ReLU(inplace=True),
            nn.Linear(pred_dim, dim, bias=True),
        )
        self.regressor = nn.Linear(hidden_size, 1)  # linear regression head

    def forward(self, x1, x2=None, regression=False):
        """Run either the SimSiam branch or the regression head.

        Args:
            x1: keyword inputs for the encoder. This is the first augmented view,
                or the only input when ``regression=True``.
            x2: keyword inputs for the second augmented view; required unless
                ``regression=True``.
            regression: if True, return ``regressor(pooler_output)`` for ``x1`` only.

        Returns:
            With ``regression=True``: the regressor output (last dim of size 1).
            Otherwise ``(p1, p2, z1.detach(), z2.detach())``: ``z`` are projector
            outputs and ``p`` predictor outputs of each view. The ``z`` tensors
            are detached so no gradient flows through the targets.

        Raises:
            ValueError: ``x2`` is missing in SimSiam mode.
            RuntimeError: SimSiam mode after ``remove_projection_head()``.
        """
        if regression:
            return self.regressor(self.encoder(**x1)["pooler_output"])
        if x2 is None:
            raise ValueError("x2 (the second augmented view) is required unless regression=True")
        if self.projector is None or self.predictor is None:
            raise RuntimeError("projection head was removed; only regression=True is supported")
        z1 = self.projector(self.encoder(**x1)["pooler_output"])
        z2 = self.projector(self.encoder(**x2)["pooler_output"])
        p1 = self.predictor(z1)
        p2 = self.predictor(z2)
        return p1, p2, z1.detach(), z2.detach()

    def remove_projection_head(self):
        """Drop the projector and predictor before fine-tuning.

        Afterwards they are excluded from ``state_dict()``/``parameters()`` and
        only ``forward(..., regression=True)`` can be used.
        """
        self.projector = None
        self.predictor = None


def train_simsiam(model, dataloader, criterion, optimizer, scheduler, device, print_freq=10,
                  *, epochs=None, save_path=None):
    """Pre-train a SimSiam model, print progress, and checkpoint the best epoch.

    The loss is SimSiam's symmetric negative similarity,
    ``-(mean(D(p1, z2)) + mean(D(p2, z1))) / 2``, where ``criterion`` is ``D``
    (e.g. ``nn.CosineSimilarity(dim=1)``).

    Args:
        model: called as ``model(inputs1, inputs2)`` and returning ``(p1, p2, z1, z2)``.
        dataloader: sized iterable of batches; ``batch[0]`` and ``batch[1]`` are
            dicts of tensors for the two views.
        criterion: similarity function between predictions and targets.
        optimizer: optimizer over the model's parameters.
        scheduler: LR scheduler, stepped once per epoch.
        device: device each batch is moved to.
        print_freq: print the batch loss every ``print_freq`` steps.
        epochs: number of epochs; defaults to the global ``TRAIN_EPOCH``.
        save_path: where the best ``state_dict`` is saved; defaults to the global
            ``SAVE_PATH``. Its parent directory is created if missing.

    Returns:
        The lowest average epoch loss seen (``inf`` if no epoch ran).

    Raises:
        ValueError: ``dataloader`` has no batches.
    """
    import os

    if epochs is None:
        epochs = TRAIN_EPOCH
    if save_path is None:
        save_path = SAVE_PATH

    num_batches = len(dataloader)
    if num_batches == 0:
        raise ValueError("train_simsiam received an empty dataloader")

    # torch.save does not create directories; make sure the target one exists.
    save_dir = os.path.dirname(save_path)
    if save_dir:
        os.makedirs(save_dir, exist_ok=True)

    model.train()
    best_loss = float('inf')
    for epoch in range(epochs):  # pre-training
        epoch_loss = 0.0
        start_time = time.time()

        for batch_idx, batch in enumerate(dataloader):
            inputs1 = {key: val.to(device) for key, val in batch[0].items()}
            inputs2 = {key: val.to(device) for key, val in batch[1].items()}

            p1, p2, z1, z2 = model(inputs1, inputs2)
            loss = -(criterion(p1, z2).mean() + criterion(p2, z1).mean()) * 0.5
            epoch_loss += loss.item()

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            if batch_idx % print_freq == 0:
                print(f"Epoch [{epoch+1}/{epochs}], Step [{batch_idx}/{num_batches}], Loss: {loss.item():.4f}")

        avg_loss = epoch_loss / num_batches
        elapsed_time = time.time() - start_time
        print(f"Epoch [{epoch+1}/{epochs}] Completed | Avg Loss: {avg_loss:.4f} | Time: {elapsed_time:.2f}s\n")

        # Keep the checkpoint with the lowest average training loss.
        if avg_loss < best_loss:
            best_loss = avg_loss
            torch.save(model.state_dict(), save_path)

        # Epoch-level learning-rate schedule.
        scheduler.step()

    return best_loss

def fine_tune_and_eval(model, labeled_data, tokenizer, criterion, optimizer, scheduler, device, print_freq=10):
    """Fine-tune ``model`` for story-point regression and report test MAE per epoch.

    The first 80% of ``labeled_data`` rows form the training split and the
    remaining 20% the test split. Rows are split in their existing order, with
    no shuffling.

    Args:
        model: model supporting ``model(inputs, regression=True)`` (e.g. ``SimSiamText``).
        labeled_data: DataFrame with ``text`` and ``storypoint`` columns.
        tokenizer: tokenizer handed to ``TextDataset``.
        criterion: regression loss applied as ``criterion(preds, labels)``.
        optimizer: optimizer over the parameters being fine-tuned.
        scheduler: LR scheduler, stepped once per epoch (as in ``train_simsiam``).
        device: device each batch is moved to.
        print_freq: print the batch loss every ``print_freq`` steps.

    Raises:
        ValueError: the training split is empty.
    """
    split_idx = int(len(labeled_data) * 0.8)
    train_part = labeled_data.iloc[:split_idx]
    test_part = labeled_data.iloc[split_idx:]

    # Fine-tuning (regression) data.
    train_dataset = TextDataset(
        train_part['text'].values.tolist(),
        train_part['storypoint'].values.tolist(),
        tokenizer=tokenizer, max_length=MAX_LEN, augment=False,
    )
    train_loader = DataLoader(train_dataset, batch_size=FINE_TUNE_BATCH, shuffle=True)
    if len(train_loader) == 0:
        raise ValueError("fine_tune_and_eval: the training split is empty")

    test_dataset = TextDataset(
        test_part['text'].values.tolist(),
        test_part['storypoint'].values.tolist(),
        tokenizer=tokenizer, max_length=MAX_LEN, augment=False,
    )
    test_loader = DataLoader(test_dataset, batch_size=8, shuffle=False)

    for epoch in range(FINE_TUNE_EPOCH):
        # evaluate() calls model.eval(); re-enable training mode (dropout etc.)
        # at the start of every epoch, not just the first.
        model.train()
        epoch_loss = 0.0
        start_time = time.time()

        for batch_idx, batch in enumerate(train_loader):
            inputs = {key: val.to(device) for key, val in batch[0].items()}
            labels = batch[1].to(device).float()

            # The regression head emits a trailing dim of size 1; squeeze only that
            # dim so a final batch of one sample keeps its batch dimension.
            preds = model(inputs, regression=True).squeeze(-1)
            loss = criterion(preds, labels)
            epoch_loss += loss.item()

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            if batch_idx % print_freq == 0:
                print(f"Fine-tune Epoch [{epoch+1}/{FINE_TUNE_EPOCH}], Step [{batch_idx}/{len(train_loader)}], Loss: {loss.item():.4f}")

        avg_loss = epoch_loss / len(train_loader)
        elapsed_time = time.time() - start_time
        print(f"Fine-tune Epoch [{epoch+1}/{FINE_TUNE_EPOCH}] Completed | Avg Loss: {avg_loss:.4f} | Time: {elapsed_time:.2f}s\n")

        # Step once per epoch. The notebook's scheduler is an epoch-based StepLR;
        # stepping it per batch decayed the learning rate every 3 batches.
        scheduler.step()

        # Evaluate on the held-out split (skipped when it is empty).
        if len(test_dataset) > 0:
            evaluate(model, test_loader, device)

def evaluate(model, dataloader, device):
    """Compute, print, and return the mean absolute error of the regression head.

    The model's previous train/eval mode is restored on exit. Calling this
    between training epochs therefore does not leave dropout disabled.

    Args:
        model: supports ``model(inputs, regression=True)``, returning one
            prediction per sample with a trailing dim of size 1.
        dataloader: iterable of ``(inputs, labels)`` batches; ``inputs`` is a
            dict of tensors.
        device: device each batch is moved to.

    Returns:
        MAE over all samples, or ``nan`` when the dataloader yields no samples.
    """
    was_training = model.training
    model.eval()
    total_abs_error = 0.0
    num_samples = 0

    try:
        with torch.no_grad():
            for batch in dataloader:
                inputs = {key: val.to(device) for key, val in batch[0].items()}
                labels = batch[1].to(device).float()

                # Squeeze only the trailing size-1 dim so single-sample batches
                # keep their batch dimension.
                preds = model(inputs, regression=True).squeeze(-1)
                total_abs_error += torch.abs(preds - labels).sum().item()
                num_samples += labels.size(0)
    finally:
        model.train(was_training)

    if num_samples == 0:
        print("Evaluation - MAE: n/a (no samples)")
        return float("nan")

    avg_mae = total_abs_error / num_samples
    print(f"Evaluation - MAE: {avg_mae:.4f}")
    return avg_mae

In [None]:
# main()
# Setup: compute device, BERT tokenizer/encoder, and the textattack augmenter
# that produces the two SimSiam views of each text.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
bert_model = BertModel.from_pretrained("bert-base-uncased")
# Define the individual augmentation transformations (disabled ones are commented out).
#wordnet_transformation = WordSwapWordNet()
embedding_transformation = WordSwapEmbedding(max_candidates=5)
#backtranslate_transformation = BackTranslation(chained_back_translation=2)
extendword_transformation = WordSwapExtend()
randomwordsubs_transformation = WordSwapRandomCharacterSubstitution()
homoglyphswap_transformation = WordSwapHomoglyphSwap()
randomcharinsert_transformation = WordSwapRandomCharacterInsertion()

# Combine several transformations, each with its own execution probability.
# NOTE(review): how the probabilities are applied is defined by the project-local
# RandomCompositeTransformation class; confirm there.
random_composite_transformation = CompositeAugmenter(
    transformations=[
        extendword_transformation, 
        homoglyphswap_transformation, 
        embedding_transformation, 
        randomwordsubs_transformation,
        randomcharinsert_transformation
    ],
    probabilities=[1, 0.5, 0.5, 0.2, 0.2]  # execution probability per transformation, in the order listed above
)

# Constraints: do not modify stopwords, and do not modify the same word twice.
constraints = [RepeatModification(), StopwordModification()]

# Semantic-similarity constraint (currently disabled).
# semantic_constraint = WordEmbeddingDistance(min_cos_sim=0.8)
# constraints.append(semantic_constraint)

# Build the augmenter.
text_augmenter = Augmenter(
    transformation=random_composite_transformation,
    constraints=constraints,
    pct_words_to_swap=0.1,  # fraction of words to modify per augmented text
    transformations_per_example=2  # request 2 augmented versions of each input text
)

In [None]:
# Load the issue data.
data = pd.read_csv(f'/kaggle/input/storypoint/{PROJECT}.csv')
# Keep one row per issue key.
data = data.drop_duplicates(subset='issuekey', keep='first')
# An empty description is acceptable; fill it so dropna() below keeps the row.
data['description'] = data['description'].fillna('')
# Drop rows that still have a missing value in any column.
data.dropna(inplace=True)
data['text'] = data['title'] + ' ' + data['description']
texts = data['text'].values.tolist()

# Pre-training dataset: each item yields two augmented views of one text.
dataset = TextDataset(texts, tokenizer=tokenizer, augment=True, max_length=MAX_LEN, augmenter=text_augmenter)
# drop_last=True: the projector/predictor use BatchNorm1d, which raises in
# training mode on a batch holding a single sample (happens when
# len(texts) % TRAIN_BATCH == 1). The SimSiam reference also drops the last batch.
dataloader = DataLoader(dataset, batch_size=TRAIN_BATCH, shuffle=True, drop_last=True)

In [None]:
# StepLR is not among the notebook's top-level imports; import it here so this
# cell does not fail with a NameError.
from torch.optim.lr_scheduler import StepLR

# Initialize the model.
simsiam = SimSiamText(base_encoder=bert_model).to(device)

# Pre-train SimSiam with the negative-cosine-similarity objective.
criterion = nn.CosineSimilarity(dim=1).to(device)
optimizer = torch.optim.SGD(simsiam.parameters(), LEARNING_RATE, momentum=0.9, weight_decay=1e-4)
scheduler = StepLR(optimizer, step_size=3, gamma=0.7)  # multiply the LR by 0.7 every 3 epochs
train_simsiam(simsiam, dataloader, criterion, optimizer, scheduler, device)

Epoch [1/10], Step [0/146], Loss: -0.0018
Epoch [1/10], Step [10/146], Loss: -0.1428
Epoch [1/10], Step [20/146], Loss: -0.2051
Epoch [1/10], Step [30/146], Loss: -0.3784
Epoch [1/10], Step [40/146], Loss: -0.4995
Epoch [1/10], Step [50/146], Loss: -0.6043
Epoch [1/10], Step [60/146], Loss: -0.6651
Epoch [1/10], Step [70/146], Loss: -0.6484
Epoch [1/10], Step [80/146], Loss: -0.7270
Epoch [1/10], Step [90/146], Loss: -0.7281
Epoch [1/10], Step [100/146], Loss: -0.7923
Epoch [1/10], Step [110/146], Loss: -0.8425
Epoch [1/10], Step [120/146], Loss: -0.8417
Epoch [1/10], Step [130/146], Loss: -0.8392
Epoch [1/10], Step [140/146], Loss: -0.8538
Epoch [1/10] Completed | Avg Loss: -0.6105 | Time: 347.76s



In [None]:
# Release cached CUDA allocator memory left over from pre-training.
torch.cuda.empty_cache()
# Reload the best pre-training checkpoint. map_location lets a checkpoint saved
# on GPU be restored on the current device (e.g. in a CPU-only session).
simsiam.load_state_dict(torch.load(SAVE_PATH, map_location=device))
simsiam = simsiam.to(device)

# Drop projector and predictor; fine-tuning uses only the encoder and regressor.
simsiam.remove_projection_head()

In [None]:
from torch.optim.lr_scheduler import StepLR

# Keep only rows whose storypoint is not -1 (presumably -1 marks unlabeled issues).
labeled_data = data[data['storypoint'] != -1]
regression_criterion = nn.MSELoss().to(device)

# Fresh optimizer and scheduler for fine-tuning. Reusing the pre-training ones
# would carry over their momentum buffers and a learning rate already decayed by
# the pre-training StepLR. Built from simsiam.parameters() as they are now.
fine_tune_optimizer = torch.optim.SGD(simsiam.parameters(), LEARNING_RATE, momentum=0.9, weight_decay=1e-4)
fine_tune_scheduler = StepLR(fine_tune_optimizer, step_size=3, gamma=0.7)

# Fine-tune and evaluate.
fine_tune_and_eval(simsiam, labeled_data, tokenizer, regression_criterion, fine_tune_optimizer, fine_tune_scheduler, device)