# 训练 CLIP 文本到图像嵌入的映射网络

1. 环境设置
首先，导入必要的库，并配置一些基本参数。

In [2]:
# 导入必要的库
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from tqdm import tqdm
import os
import random
import numpy as np

# 设置随机种子以确保可重复性
torch.manual_seed(42)
np.random.seed(42)
random.seed(42)
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(42)

# 配置参数
TEXT_EMBEDDINGS_PATH = './clip_embeddings/text_embeddings.pt'   # 替换为实际路径
IMAGE_EMBEDDINGS_PATH = './clip_embeddings/image_embeddings.pt' # 替换为实际路径
MODEL_SAVE_PATH = 'text_to_image_embedder.pth'
BATCH_SIZE = 128
NUM_EPOCHS = 10
LEARNING_RATE = 1e-4

# 设备设置
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"使用设备: {device}")


使用设备: cuda


2. 定义数据集类
创建一个自定义的 Dataset 类，用于加载预先计算的文本和图像嵌入。

In [3]:
# 定义自定义 Dataset 类
class CLIPEmbedMappingDatasetPrecomputed(Dataset):
    def __init__(self, text_embeddings_path, image_embeddings_path):
        """
        初始化数据集，加载预计算的嵌入
        :param text_embeddings_path: 文本嵌入文件路径
        :param image_embeddings_path: 图像嵌入文件路径
        """
        # 加载嵌入
        try:
            self.text_embeddings = torch.load(text_embeddings_path)
            self.image_embeddings = torch.load(image_embeddings_path)
        except Exception as e:
            print(f"Error loading embeddings: {e}")
            raise

        # 确认嵌入数量匹配
        assert len(self.text_embeddings) == len(self.image_embeddings), "文本和图像嵌入的数量不匹配"

        # 确保嵌入的 dtype 为 float32
        if self.text_embeddings.dtype != torch.float32:
            self.text_embeddings = self.text_embeddings.float()
            print("已将文本嵌入转换为 float32")
        if self.image_embeddings.dtype != torch.float32:
            self.image_embeddings = self.image_embeddings.float()
            print("已将图像嵌入转换为 float32")

    def __len__(self):
        return len(self.text_embeddings)

    def __getitem__(self, idx):
        return self.text_embeddings[idx], self.image_embeddings[idx]


3. 定义映射网络模型
定义一个简单的全连接神经网络，将文本嵌入映射到图像嵌入空间。

In [4]:
# 定义映射网络模型
class TextToImageEmbedder(nn.Module):
    def __init__(self, clip_dim=512, embed_dim=512):
        super(TextToImageEmbedder, self).__init__()
        self.mapping = nn.Sequential(
            nn.Linear(clip_dim, 1024),
            nn.ReLU(),
            nn.Linear(1024, embed_dim),
            nn.ReLU(),
            nn.Linear(embed_dim, embed_dim)
        )
    
    def forward(self, text_embeddings):
        image_embeddings = self.mapping(text_embeddings)
        return image_embeddings


4. 创建数据加载器
使用自定义的 Dataset 类创建 DataLoader，以便在训练过程中批量加载数据。

In [5]:
# 创建映射网络的数据集和数据加载器（使用预计算嵌入）
try:
    mapping_dataset = CLIPEmbedMappingDatasetPrecomputed(
        text_embeddings_path=TEXT_EMBEDDINGS_PATH,
        image_embeddings_path=IMAGE_EMBEDDINGS_PATH
    )
except Exception as e:
    print(f"Error initializing mapping dataset: {e}")
    raise

# 确认数据集大小
print(f"数据集大小: {len(mapping_dataset)}")

# 创建 DataLoader
mapping_loader = DataLoader(
    mapping_dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
    num_workers=0,  # 根据需要调整
    pin_memory=True
)


  self.text_embeddings = torch.load(text_embeddings_path)
  self.image_embeddings = torch.load(image_embeddings_path)


已将文本嵌入转换为 float32
已将图像嵌入转换为 float32
数据集大小: 162770


5. 初始化模型、损失函数和优化器

In [6]:
# 创建映射网络模型
embedder = TextToImageEmbedder(clip_dim=512, embed_dim=512).to(device)

# 定义损失函数和优化器
criterion_mapping = nn.MSELoss()
optimizer_mapping = optim.Adam(embedder.parameters(), lr=LEARNING_RATE)


6. 训练映射网络
定义训练函数，并开始训练过程。

In [7]:
# 定义训练函数
def train_mapping_network(model, loader, optimizer, criterion, num_epochs=20):
    model.train()
    for epoch in range(1, num_epochs + 1):
        epoch_loss = 0
        for text_emb, img_emb in tqdm(loader, desc=f"Mapping Epoch {epoch}/{num_epochs}"):
            # 将嵌入移动到GPU
            text_emb = text_emb.to(device)
            img_emb = img_emb.to(device)

            # 前向传播
            pred_img_emb = model(text_emb)
            loss = criterion(pred_img_emb, img_emb)

            # 反向传播和优化
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            epoch_loss += loss.item()
        
        avg_loss = epoch_loss / len(loader.dataset)
        print(f"Mapping Epoch {epoch}, 平均损失: {avg_loss:.6f}")
    return model

# 开始训练
trained_embedder = train_mapping_network(embedder, mapping_loader, optimizer_mapping, criterion_mapping, num_epochs=20)


Mapping Epoch 1/20: 100%|██████████| 1272/1272 [00:03<00:00, 413.40it/s]


Mapping Epoch 1, 平均损失: 0.000563


Mapping Epoch 2/20: 100%|██████████| 1272/1272 [00:02<00:00, 463.65it/s]


Mapping Epoch 2, 平均损失: 0.000514


Mapping Epoch 3/20: 100%|██████████| 1272/1272 [00:02<00:00, 471.60it/s]


Mapping Epoch 3, 平均损失: 0.000507


Mapping Epoch 4/20: 100%|██████████| 1272/1272 [00:02<00:00, 468.57it/s]


Mapping Epoch 4, 平均损失: 0.000502


Mapping Epoch 5/20: 100%|██████████| 1272/1272 [00:02<00:00, 490.42it/s]


Mapping Epoch 5, 平均损失: 0.000500


Mapping Epoch 6/20: 100%|██████████| 1272/1272 [00:02<00:00, 438.81it/s]


Mapping Epoch 6, 平均损失: 0.000498


Mapping Epoch 7/20: 100%|██████████| 1272/1272 [00:02<00:00, 538.05it/s]


Mapping Epoch 7, 平均损失: 0.000496


Mapping Epoch 8/20: 100%|██████████| 1272/1272 [00:02<00:00, 491.11it/s]


Mapping Epoch 8, 平均损失: 0.000495


Mapping Epoch 9/20: 100%|██████████| 1272/1272 [00:02<00:00, 534.53it/s]


Mapping Epoch 9, 平均损失: 0.000494


Mapping Epoch 10/20: 100%|██████████| 1272/1272 [00:02<00:00, 543.98it/s]


Mapping Epoch 10, 平均损失: 0.000493


Mapping Epoch 11/20: 100%|██████████| 1272/1272 [00:02<00:00, 538.61it/s]


Mapping Epoch 11, 平均损失: 0.000492


Mapping Epoch 12/20: 100%|██████████| 1272/1272 [00:02<00:00, 473.40it/s]


Mapping Epoch 12, 平均损失: 0.000491


Mapping Epoch 13/20: 100%|██████████| 1272/1272 [00:02<00:00, 472.87it/s]


Mapping Epoch 13, 平均损失: 0.000491


Mapping Epoch 14/20: 100%|██████████| 1272/1272 [00:02<00:00, 481.79it/s]


Mapping Epoch 14, 平均损失: 0.000490


Mapping Epoch 15/20: 100%|██████████| 1272/1272 [00:02<00:00, 540.46it/s]


Mapping Epoch 15, 平均损失: 0.000489


Mapping Epoch 16/20: 100%|██████████| 1272/1272 [00:02<00:00, 506.05it/s]


Mapping Epoch 16, 平均损失: 0.000489


Mapping Epoch 17/20: 100%|██████████| 1272/1272 [00:02<00:00, 522.38it/s]


Mapping Epoch 17, 平均损失: 0.000488


Mapping Epoch 18/20: 100%|██████████| 1272/1272 [00:02<00:00, 556.67it/s]


Mapping Epoch 18, 平均损失: 0.000488


Mapping Epoch 19/20: 100%|██████████| 1272/1272 [00:02<00:00, 561.67it/s]


Mapping Epoch 19, 平均损失: 0.000487


Mapping Epoch 20/20: 100%|██████████| 1272/1272 [00:02<00:00, 499.27it/s]

Mapping Epoch 20, 平均损失: 0.000487





In [8]:
# 保存映射网络
torch.save(trained_embedder.state_dict(), MODEL_SAVE_PATH)
print(f"映射网络已保存为: {MODEL_SAVE_PATH}")


映射网络已保存为: text_to_image_embedder.pth


## 利用后来生成的嵌入训练了一个新的网络效果比较好
唯一用到的

In [1]:
# 导入必要的库
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from tqdm import tqdm
import os
import random
import numpy as np

# 设置随机种子以确保可重复性
torch.manual_seed(42)
np.random.seed(42)
random.seed(42)
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(42)

# 配置参数
TEXT_EMBEDDINGS_PATH = '/root/autodl-tmp/clip_embeddings/text_embeddings_partition_0.pt'   # 替换为实际路径
IMAGE_EMBEDDINGS_PATH = '/root/autodl-tmp/clip_embeddings/image_embeddings_partition_0.pt' # 替换为实际路径
MODEL_SAVE_PATH = 'text_to_image_embedder.pth'
BATCH_SIZE = 128
NUM_EPOCHS = 20
LEARNING_RATE = 1e-4

# 设备设置
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"使用设备: {device}")

# 定义自定义 Dataset 类
class CLIPEmbedMappingDatasetPrecomputed(Dataset):
    def __init__(self, text_embeddings_path, image_embeddings_path):
        """
        初始化数据集，加载预计算的嵌入
        :param text_embeddings_path: 文本嵌入文件路径
        :param image_embeddings_path: 图像嵌入文件路径
        """
        # 加载嵌入
        try:
            self.text_embeddings = torch.load(text_embeddings_path)
            self.image_embeddings = torch.load(image_embeddings_path)
        except Exception as e:
            print(f"Error loading embeddings: {e}")
            raise

        # 确认嵌入数量匹配
        assert len(self.text_embeddings) == len(self.image_embeddings), "文本和图像嵌入的数量不匹配"

        # 确保嵌入的 dtype 为 float32
        if self.text_embeddings.dtype != torch.float32:
            self.text_embeddings = self.text_embeddings.float()
            print("已将文本嵌入转换为 float32")
        if self.image_embeddings.dtype != torch.float32:
            self.image_embeddings = self.image_embeddings.float()
            print("已将图像嵌入转换为 float32")

    def __len__(self):
        return len(self.text_embeddings)

    def __getitem__(self, idx):
        return self.text_embeddings[idx], self.image_embeddings[idx]

# 定义优化后的映射网络模型
class TextToImageEmbedder(nn.Module):
    def __init__(self, clip_dim=512, embed_dim=512):
        super(TextToImageEmbedder, self).__init__()
        self.mapping = nn.Sequential(
            nn.Linear(clip_dim, 2048),
            nn.BatchNorm1d(2048),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(2048, 1024),
            nn.BatchNorm1d(1024),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(1024, embed_dim),
            nn.BatchNorm1d(embed_dim),
            nn.ReLU(),
            nn.Linear(embed_dim, embed_dim)
        )
    
    def forward(self, text_embeddings):
        image_embeddings = self.mapping(text_embeddings)
        return image_embeddings

# 定义新的余弦相似度损失函数
class CosineSimilarityLoss(nn.Module):
    def __init__(self):
        super(CosineSimilarityLoss, self).__init__()
        self.cos = nn.CosineSimilarity(dim=1)
    
    def forward(self, pred, target):
        # 余弦相似度范围在 [-1, 1]，我们希望最大化相似度，因此损失为 1 - cos_sim
        return 1 - self.cos(pred, target).mean()

# 创建映射网络的数据集和数据加载器（使用预计算嵌入）
try:
    mapping_dataset = CLIPEmbedMappingDatasetPrecomputed(
        text_embeddings_path=TEXT_EMBEDDINGS_PATH,
        image_embeddings_path=IMAGE_EMBEDDINGS_PATH
    )
except Exception as e:
    print(f"Error initializing mapping dataset: {e}")
    raise

# 确认数据集大小
print(f"数据集大小: {len(mapping_dataset)}")

# 创建 DataLoader
mapping_loader = DataLoader(
    mapping_dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
    num_workers=4,  # 根据需要调整
    pin_memory=True
)

# 创建映射网络模型
embedder = TextToImageEmbedder(clip_dim=512, embed_dim=512).to(device)

# 定义损失函数和优化器
criterion_mapping = CosineSimilarityLoss()
optimizer_mapping = optim.Adam(embedder.parameters(), lr=LEARNING_RATE)

# 定义训练函数
def train_mapping_network(model, loader, optimizer, criterion, num_epochs=20):
    model.train()
    for epoch in range(1, num_epochs + 1):
        epoch_loss = 0
        for text_emb, img_emb in tqdm(loader, desc=f"Mapping Epoch {epoch}/{num_epochs}"):
            # 将嵌入移动到GPU
            text_emb = text_emb.to(device)
            img_emb = img_emb.to(device)

            # 前向传播
            pred_img_emb = model(text_emb)
            loss = criterion(pred_img_emb, img_emb)

            # 反向传播和优化
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            epoch_loss += loss.item()
        
        avg_loss = epoch_loss / len(loader.dataset)
        print(f"Mapping Epoch {epoch}, 平均损失: {avg_loss:.6f}")
    return model

# 开始训练
trained_embedder = train_mapping_network(embedder, mapping_loader, optimizer_mapping, criterion_mapping, num_epochs=NUM_EPOCHS)

# 保存映射网络
torch.save(trained_embedder.state_dict(), MODEL_SAVE_PATH)
print(f"映射网络已保存为: {MODEL_SAVE_PATH}")


使用设备: cuda
已将文本嵌入转换为 float32
已将图像嵌入转换为 float32
数据集大小: 162770


Mapping Epoch 1/20: 100%|██████████| 1272/1272 [00:04<00:00, 271.96it/s]


Mapping Epoch 1, 平均损失: 0.001742


Mapping Epoch 2/20: 100%|██████████| 1272/1272 [00:04<00:00, 311.34it/s]


Mapping Epoch 2, 平均损失: 0.001429


Mapping Epoch 3/20: 100%|██████████| 1272/1272 [00:04<00:00, 307.54it/s]


Mapping Epoch 3, 平均损失: 0.001401


Mapping Epoch 4/20: 100%|██████████| 1272/1272 [00:04<00:00, 313.98it/s]


Mapping Epoch 4, 平均损失: 0.001386


Mapping Epoch 5/20: 100%|██████████| 1272/1272 [00:03<00:00, 340.36it/s]


Mapping Epoch 5, 平均损失: 0.001377


Mapping Epoch 6/20: 100%|██████████| 1272/1272 [00:04<00:00, 317.79it/s]


Mapping Epoch 6, 平均损失: 0.001370


Mapping Epoch 7/20: 100%|██████████| 1272/1272 [00:04<00:00, 281.46it/s]


Mapping Epoch 7, 平均损失: 0.001365


Mapping Epoch 8/20: 100%|██████████| 1272/1272 [00:03<00:00, 326.37it/s]


Mapping Epoch 8, 平均损失: 0.001361


Mapping Epoch 9/20: 100%|██████████| 1272/1272 [00:03<00:00, 323.17it/s]


Mapping Epoch 9, 平均损失: 0.001357


Mapping Epoch 10/20: 100%|██████████| 1272/1272 [00:04<00:00, 315.46it/s]


Mapping Epoch 10, 平均损失: 0.001354


Mapping Epoch 11/20: 100%|██████████| 1272/1272 [00:04<00:00, 297.54it/s]


Mapping Epoch 11, 平均损失: 0.001351


Mapping Epoch 12/20: 100%|██████████| 1272/1272 [00:04<00:00, 309.84it/s]


Mapping Epoch 12, 平均损失: 0.001349


Mapping Epoch 13/20: 100%|██████████| 1272/1272 [00:04<00:00, 315.38it/s]


Mapping Epoch 13, 平均损失: 0.001347


Mapping Epoch 14/20: 100%|██████████| 1272/1272 [00:03<00:00, 333.99it/s]


Mapping Epoch 14, 平均损失: 0.001345


Mapping Epoch 15/20: 100%|██████████| 1272/1272 [00:03<00:00, 325.78it/s]


Mapping Epoch 15, 平均损失: 0.001342


Mapping Epoch 16/20: 100%|██████████| 1272/1272 [00:04<00:00, 313.83it/s]


Mapping Epoch 16, 平均损失: 0.001341


Mapping Epoch 17/20: 100%|██████████| 1272/1272 [00:04<00:00, 311.12it/s]


Mapping Epoch 17, 平均损失: 0.001339


Mapping Epoch 18/20: 100%|██████████| 1272/1272 [00:04<00:00, 309.20it/s]


Mapping Epoch 18, 平均损失: 0.001338


Mapping Epoch 19/20: 100%|██████████| 1272/1272 [00:04<00:00, 298.15it/s]


Mapping Epoch 19, 平均损失: 0.001336


Mapping Epoch 20/20: 100%|██████████| 1272/1272 [00:04<00:00, 308.99it/s]

Mapping Epoch 20, 平均损失: 0.001335
映射网络已保存为: text_to_image_embedder.pth





# 新的更深的网络 不行 余弦相似度反而降低了


In [3]:
# 导入必要的库
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from tqdm import tqdm
import os
import random
import numpy as np

# 设置随机种子以确保可重复性
torch.manual_seed(42)
np.random.seed(42)
random.seed(42)
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(42)

# 配置参数
TEXT_EMBEDDINGS_PATH = '/root/autodl-tmp/clip_embeddings/text_embeddings_partition_0.pt'   # 替换为实际路径
IMAGE_EMBEDDINGS_PATH = '/root/autodl-tmp/clip_embeddings/image_embeddings_partition_0.pt' # 替换为实际路径
MODEL_SAVE_PATH = 'text_to_image_embedder_v2.pth' # 修改了保存文件名，以区分不同版本的模型
BATCH_SIZE = 128
NUM_EPOCHS = 20
LEARNING_RATE = 1e-4

# 设备设置
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"使用设备: {device}")

# 定义自定义 Dataset 类
class CLIPEmbedMappingDatasetPrecomputed(Dataset):
    def __init__(self, text_embeddings_path, image_embeddings_path):
        """
        初始化数据集，加载预计算的嵌入
        :param text_embeddings_path: 文本嵌入文件路径
        :param image_embeddings_path: 图像嵌入文件路径
        """
        # 加载嵌入
        try:
            self.text_embeddings = torch.load(text_embeddings_path)
            self.image_embeddings = torch.load(image_embeddings_path)
        except Exception as e:
            print(f"Error loading embeddings: {e}")
            raise

        # 确认嵌入数量匹配
        assert len(self.text_embeddings) == len(self.image_embeddings), "文本和图像嵌入的数量不匹配"

        # 确保嵌入的 dtype 为 float32
        if self.text_embeddings.dtype != torch.float32:
            self.text_embeddings = self.text_embeddings.float()
            print("已将文本嵌入转换为 float32")
        if self.image_embeddings.dtype != torch.float32:
            self.image_embeddings = self.image_embeddings.float()
            print("已将图像嵌入转换为 float32")

    def __len__(self):
        return len(self.text_embeddings)

    def __getitem__(self, idx):
        return self.text_embeddings[idx], self.image_embeddings[idx]

# 定义优化后的映射网络模型 V2 (增加深度和宽度)
class TextToImageEmbedderV2(nn.Module): # V2 表示 Version 2
    def __init__(self, clip_dim=512, embed_dim=512):
        super(TextToImageEmbedderV2, self).__init__()
        self.mapping = nn.Sequential(
            nn.Linear(clip_dim, 4096), # 加宽到 4096
            nn.BatchNorm1d(4096),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(4096, 2048), # 加宽到 2048
            nn.BatchNorm1d(2048),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(2048, 1024),
            nn.BatchNorm1d(1024),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(1024, 512), # 增加一个 512 维度的隐藏层
            nn.BatchNorm1d(512),
            nn.ReLU(),
            nn.Linear(512, embed_dim) # 输出层保持不变
        )

    def forward(self, text_embeddings):
        image_embeddings = self.mapping(text_embeddings)
        return image_embeddings


# 定义新的余弦相似度损失函数
class CosineSimilarityLoss(nn.Module):
    def __init__(self):
        super(CosineSimilarityLoss, self).__init__()
        self.cos = nn.CosineSimilarity(dim=1)

    def forward(self, pred, target):
        # 余弦相似度范围在 [-1, 1]，我们希望最大化相似度，因此损失为 1 - cos_sim
        return 1 - self.cos(pred, target).mean()

# 创建映射网络的数据集和数据加载器（使用预计算嵌入）
try:
    mapping_dataset = CLIPEmbedMappingDatasetPrecomputed(
        text_embeddings_path=TEXT_EMBEDDINGS_PATH,
        image_embeddings_path=IMAGE_EMBEDDINGS_PATH
    )
except Exception as e:
    print(f"Error initializing mapping dataset: {e}")
    raise

# 确认数据集大小
print(f"数据集大小: {len(mapping_dataset)}")

# 创建 DataLoader
mapping_loader = DataLoader(
    mapping_dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
    num_workers=4,  # 根据需要调整
    pin_memory=True
)

# 创建映射网络模型 V2 (使用加深加宽的模型)
embedder = TextToImageEmbedderV2(clip_dim=512, embed_dim=512).to(device) # 使用 TextToImageEmbedderV2

# 定义损失函数和优化器
criterion_mapping = CosineSimilarityLoss()
optimizer_mapping = optim.Adam(embedder.parameters(), lr=LEARNING_RATE)

# 定义训练函数
def train_mapping_network(model, loader, optimizer, criterion, num_epochs=20):
    model.train()
    for epoch in range(1, num_epochs + 1):
        epoch_loss = 0
        for text_emb, img_emb in tqdm(loader, desc=f"Mapping Epoch {epoch}/{num_epochs}"):
            # 将嵌入移动到GPU
            text_emb = text_emb.to(device)
            img_emb = img_emb.to(device)

            # 前向传播
            pred_img_emb = model(text_emb)
            loss = criterion(pred_img_emb, img_emb)

            # 反向传播和优化
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            epoch_loss += loss.item()

        avg_loss = epoch_loss / len(loader.dataset)
        print(f"Mapping Epoch {epoch}, 平均损失: {avg_loss:.6f}")
    return model

# 开始训练
trained_embedder = train_mapping_network(embedder, mapping_loader, optimizer_mapping, criterion_mapping, num_epochs=NUM_EPOCHS)

# 保存映射网络
torch.save(trained_embedder.state_dict(), MODEL_SAVE_PATH)
print(f"映射网络已保存为: {MODEL_SAVE_PATH}")

使用设备: cuda
已将文本嵌入转换为 float32
已将图像嵌入转换为 float32
数据集大小: 162770


Mapping Epoch 1/20: 100%|██████████| 1272/1272 [00:06<00:00, 211.89it/s]


Mapping Epoch 1, 平均损失: 0.001736


Mapping Epoch 2/20: 100%|██████████| 1272/1272 [00:05<00:00, 217.61it/s]


Mapping Epoch 2, 平均损失: 0.001431


Mapping Epoch 3/20: 100%|██████████| 1272/1272 [00:05<00:00, 239.94it/s]


Mapping Epoch 3, 平均损失: 0.001401


Mapping Epoch 4/20: 100%|██████████| 1272/1272 [00:05<00:00, 242.28it/s]


Mapping Epoch 4, 平均损失: 0.001386


Mapping Epoch 5/20: 100%|██████████| 1272/1272 [00:05<00:00, 217.06it/s]


Mapping Epoch 5, 平均损失: 0.001376


Mapping Epoch 6/20: 100%|██████████| 1272/1272 [00:06<00:00, 207.89it/s]


Mapping Epoch 6, 平均损失: 0.001368


Mapping Epoch 7/20: 100%|██████████| 1272/1272 [00:06<00:00, 207.22it/s]


Mapping Epoch 7, 平均损失: 0.001363


Mapping Epoch 8/20: 100%|██████████| 1272/1272 [00:05<00:00, 218.32it/s]


Mapping Epoch 8, 平均损失: 0.001358


Mapping Epoch 9/20: 100%|██████████| 1272/1272 [00:05<00:00, 219.83it/s]


Mapping Epoch 9, 平均损失: 0.001354


Mapping Epoch 10/20: 100%|██████████| 1272/1272 [00:05<00:00, 231.67it/s]


Mapping Epoch 10, 平均损失: 0.001351


Mapping Epoch 11/20: 100%|██████████| 1272/1272 [00:05<00:00, 230.08it/s]


Mapping Epoch 11, 平均损失: 0.001348


Mapping Epoch 12/20: 100%|██████████| 1272/1272 [00:05<00:00, 222.27it/s]


Mapping Epoch 12, 平均损失: 0.001345


Mapping Epoch 13/20: 100%|██████████| 1272/1272 [00:05<00:00, 214.32it/s]


Mapping Epoch 13, 平均损失: 0.001343


Mapping Epoch 14/20: 100%|██████████| 1272/1272 [00:05<00:00, 223.73it/s]


Mapping Epoch 14, 平均损失: 0.001340


Mapping Epoch 15/20: 100%|██████████| 1272/1272 [00:05<00:00, 222.04it/s]


Mapping Epoch 15, 平均损失: 0.001338


Mapping Epoch 16/20: 100%|██████████| 1272/1272 [00:06<00:00, 205.60it/s]


Mapping Epoch 16, 平均损失: 0.001335


Mapping Epoch 17/20: 100%|██████████| 1272/1272 [00:06<00:00, 204.11it/s]


Mapping Epoch 17, 平均损失: 0.001333


Mapping Epoch 18/20: 100%|██████████| 1272/1272 [00:05<00:00, 217.12it/s]


Mapping Epoch 18, 平均损失: 0.001331


Mapping Epoch 19/20: 100%|██████████| 1272/1272 [00:05<00:00, 214.51it/s]


Mapping Epoch 19, 平均损失: 0.001330


Mapping Epoch 20/20: 100%|██████████| 1272/1272 [00:06<00:00, 210.13it/s]


Mapping Epoch 20, 平均损失: 0.001327
映射网络已保存为: text_to_image_embedder_v2.pth


# 新的 图片太模糊 不行

In [4]:
# 导入必要的库
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from tqdm import tqdm
import os
import random
import numpy as np

# 设置随机种子以确保可重复性
torch.manual_seed(42)
np.random.seed(42)
random.seed(42)
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(42)

# 配置参数
TEXT_EMBEDDINGS_PATH = '/root/autodl-tmp/clip_embeddings/text_embeddings_partition_0.pt'   # 替换为实际路径
IMAGE_EMBEDDINGS_PATH = '/root/autodl-tmp/clip_embeddings/image_embeddings_partition_0.pt' # 替换为实际路径
MODEL_SAVE_PATH = 'text_to_image_embedder_v3.pth' # 修改了保存文件名，以区分不同版本的模型
BATCH_SIZE = 128
NUM_EPOCHS = 20
LEARNING_RATE = 5e-5  # 降低学习率
WEIGHT_DECAY = 1e-5    # 添加 L2 正则化

# 设备设置
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"使用设备: {device}")

# 定义自定义 Dataset 类
class CLIPEmbedMappingDatasetPrecomputed(Dataset):
    def __init__(self, text_embeddings_path, image_embeddings_path):
        """
        初始化数据集，加载预计算的嵌入
        :param text_embeddings_path: 文本嵌入文件路径
        :param image_embeddings_path: 图像嵌入文件路径
        """
        # 加载嵌入
        try:
            self.text_embeddings = torch.load(text_embeddings_path)
            self.image_embeddings = torch.load(image_embeddings_path)
        except Exception as e:
            print(f"Error loading embeddings: {e}")
            raise

        # 确认嵌入数量匹配
        assert len(self.text_embeddings) == len(self.image_embeddings), "文本和图像嵌入的数量不匹配"

        # 确保嵌入的 dtype 为 float32
        if self.text_embeddings.dtype != torch.float32:
            self.text_embeddings = self.text_embeddings.float()
            print("已将文本嵌入转换为 float32")
        if self.image_embeddings.dtype != torch.float32:
            self.image_embeddings = self.image_embeddings.float()
            print("已将图像嵌入转换为 float32")

    def __len__(self):
        return len(self.text_embeddings)

    def __getitem__(self, idx):
        return self.text_embeddings[idx], self.image_embeddings[idx]

# 定义优化后的映射网络模型 V2 (增加深度和宽度)
class TextToImageEmbedderV2(nn.Module): # V2 表示 Version 2
    def __init__(self, clip_dim=512, embed_dim=512):
        super(TextToImageEmbedderV2, self).__init__()
        self.mapping = nn.Sequential(
            nn.Linear(clip_dim, 4096), # 加宽到 4096
            nn.BatchNorm1d(4096),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(4096, 2048), # 加宽到 2048
            nn.BatchNorm1d(2048),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(2048, 1024),
            nn.BatchNorm1d(1024),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(1024, 512), # 增加一个 512 维度的隐藏层
            nn.BatchNorm1d(512),
            nn.ReLU(),
            nn.Linear(512, embed_dim) # 输出层保持不变
        )

    def forward(self, text_embeddings):
        image_embeddings = self.mapping(text_embeddings)
        return image_embeddings


# 定义新的余弦相似度损失函数
class CosineSimilarityLoss(nn.Module):
    def __init__(self):
        super(CosineSimilarityLoss, self).__init__()
        self.cos = nn.CosineSimilarity(dim=1)

    def forward(self, pred, target):
        # 余弦相似度范围在 [-1, 1]，我们希望最大化相似度，因此损失为 1 - cos_sim
        return 1 - self.cos(pred, target).mean()

# 创建映射网络的数据集和数据加载器（使用预计算嵌入）
try:
    mapping_dataset = CLIPEmbedMappingDatasetPrecomputed(
        text_embeddings_path=TEXT_EMBEDDINGS_PATH,
        image_embeddings_path=IMAGE_EMBEDDINGS_PATH
    )
except Exception as e:
    print(f"Error initializing mapping dataset: {e}")
    raise

# 确认数据集大小
print(f"数据集大小: {len(mapping_dataset)}")

# 创建 DataLoader
mapping_loader = DataLoader(
    mapping_dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
    num_workers=4,  # 根据需要调整
    pin_memory=True
)

# 创建映射网络模型 V2 (使用加深加宽的模型)
embedder = TextToImageEmbedderV2(clip_dim=512, embed_dim=512).to(device) # 使用 TextToImageEmbedderV2

# 定义损失函数和优化器
criterion_mapping = CosineSimilarityLoss()
optimizer_mapping = optim.Adam(embedder.parameters(), lr=LEARNING_RATE, weight_decay=WEIGHT_DECAY) # 添加 weight_decay
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer_mapping, T_max=NUM_EPOCHS) # Cosine Annealing Scheduler

# 定义训练函数
def train_mapping_network(model, loader, optimizer, criterion, num_epochs=20, scheduler=None): # 添加 scheduler 参数
    model.train()
    for epoch in range(1, num_epochs + 1):
        epoch_loss = 0
        for text_emb, img_emb in tqdm(loader, desc=f"Mapping Epoch {epoch}/{num_epochs}"):
            # 将嵌入移动到GPU
            text_emb = text_emb.to(device)
            img_emb = img_emb.to(device)

            # 前向传播
            pred_img_emb = model(text_emb)
            loss = criterion(pred_img_emb, img_emb)

            # 反向传播和优化
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            epoch_loss += loss.item()

        if scheduler:
            scheduler.step() # 更新学习率

        avg_loss = epoch_loss / len(loader.dataset)
        current_lr = optimizer.param_groups[0]['lr'] # 获取当前学习率
        print(f"Mapping Epoch {epoch}, 平均损失: {avg_loss:.6f}, Learning Rate: {current_lr:.8f}") # 打印学习率
    return model

# 开始训练
trained_embedder = train_mapping_network(embedder, mapping_loader, optimizer_mapping, criterion_mapping, num_epochs=NUM_EPOCHS, scheduler=scheduler) # 传入 scheduler

# 保存映射网络
torch.save(trained_embedder.state_dict(), MODEL_SAVE_PATH)
print(f"映射网络已保存为: {MODEL_SAVE_PATH}")

使用设备: cuda
已将文本嵌入转换为 float32
已将图像嵌入转换为 float32
数据集大小: 162770


Mapping Epoch 1/20: 100%|██████████| 1272/1272 [00:05<00:00, 219.07it/s]


Mapping Epoch 1, 平均损失: 0.001952, Learning Rate: 0.00004969


Mapping Epoch 2/20: 100%|██████████| 1272/1272 [00:06<00:00, 210.94it/s]


Mapping Epoch 2, 平均损失: 0.001475, Learning Rate: 0.00004878


Mapping Epoch 3/20: 100%|██████████| 1272/1272 [00:06<00:00, 203.26it/s]


Mapping Epoch 3, 平均损失: 0.001431, Learning Rate: 0.00004728


Mapping Epoch 4/20: 100%|██████████| 1272/1272 [00:05<00:00, 213.30it/s]


Mapping Epoch 4, 平均损失: 0.001408, Learning Rate: 0.00004523


Mapping Epoch 5/20: 100%|██████████| 1272/1272 [00:05<00:00, 212.02it/s]


Mapping Epoch 5, 平均损失: 0.001393, Learning Rate: 0.00004268


Mapping Epoch 6/20: 100%|██████████| 1272/1272 [00:06<00:00, 211.03it/s]


Mapping Epoch 6, 平均损失: 0.001383, Learning Rate: 0.00003969


Mapping Epoch 7/20: 100%|██████████| 1272/1272 [00:05<00:00, 230.55it/s]


Mapping Epoch 7, 平均损失: 0.001376, Learning Rate: 0.00003635


Mapping Epoch 8/20: 100%|██████████| 1272/1272 [00:05<00:00, 249.92it/s]


Mapping Epoch 8, 平均损失: 0.001370, Learning Rate: 0.00003273


Mapping Epoch 9/20: 100%|██████████| 1272/1272 [00:05<00:00, 241.12it/s]


Mapping Epoch 9, 平均损失: 0.001365, Learning Rate: 0.00002891


Mapping Epoch 10/20: 100%|██████████| 1272/1272 [00:05<00:00, 221.52it/s]


Mapping Epoch 10, 平均损失: 0.001360, Learning Rate: 0.00002500


Mapping Epoch 11/20: 100%|██████████| 1272/1272 [00:05<00:00, 223.39it/s]


Mapping Epoch 11, 平均损失: 0.001357, Learning Rate: 0.00002109


Mapping Epoch 12/20: 100%|██████████| 1272/1272 [00:05<00:00, 230.52it/s]


Mapping Epoch 12, 平均损失: 0.001353, Learning Rate: 0.00001727


Mapping Epoch 13/20: 100%|██████████| 1272/1272 [00:05<00:00, 228.69it/s]


Mapping Epoch 13, 平均损失: 0.001350, Learning Rate: 0.00001365


Mapping Epoch 14/20: 100%|██████████| 1272/1272 [00:05<00:00, 221.09it/s]


Mapping Epoch 14, 平均损失: 0.001348, Learning Rate: 0.00001031


Mapping Epoch 15/20: 100%|██████████| 1272/1272 [00:05<00:00, 235.80it/s]


Mapping Epoch 15, 平均损失: 0.001345, Learning Rate: 0.00000732


Mapping Epoch 16/20: 100%|██████████| 1272/1272 [00:05<00:00, 225.29it/s]


Mapping Epoch 16, 平均损失: 0.001343, Learning Rate: 0.00000477


Mapping Epoch 17/20: 100%|██████████| 1272/1272 [00:05<00:00, 220.71it/s]


Mapping Epoch 17, 平均损失: 0.001342, Learning Rate: 0.00000272


Mapping Epoch 18/20: 100%|██████████| 1272/1272 [00:05<00:00, 226.77it/s]


Mapping Epoch 18, 平均损失: 0.001340, Learning Rate: 0.00000122


Mapping Epoch 19/20: 100%|██████████| 1272/1272 [00:05<00:00, 237.44it/s]


Mapping Epoch 19, 平均损失: 0.001340, Learning Rate: 0.00000031


Mapping Epoch 20/20: 100%|██████████| 1272/1272 [00:05<00:00, 219.14it/s]


Mapping Epoch 20, 平均损失: 0.001339, Learning Rate: 0.00000000
映射网络已保存为: text_to_image_embedder_v3.pth
