# Generate new molecules for glioblastoma
# 生成治疗 glioblastoma 的新分子

## 代码过程的解释
**检查可用设备**：使用 GPU（如果可用）或 CPU。

**导入必要的库**：包括 PyTorch、NumPy、Pandas 等。

**设置随机种子**：使用 `seed_all()` 函数设置随机种子，以确保结果的可重复性。

**加载数据集**：使用 Pandas 加载 CSV 文件，并移除缺失值。

**提取 SMILES 和标签**：从数据集中提取需要的列。

**划分数据集**：使用 `train_test_split` 将数据集划分为训练、验证和测试集。

**构建词汇表**：从训练数据中提取所有可能的字符，构建词汇表和映射关系。

**定义必要的函数**：包括字符和索引之间的转换函数，以及字符串和张量之间的转换函数。

**定义 VAE 模型**：包含编码器和解码器，以及采样函数。

**定义训练参数**：设置训练所需的超参数，如学习率、批次大小等。

**定义辅助类和函数**：包括 KL 权重退火器、学习率调度器、循环缓冲区、日志记录器等。

**定义训练函数**：包括 `_train_epoch` 和 `_train` 函数，以及 `fit` 函数来开始训练。

**定义数据加载器相关函数**：包括 `get_dataloader`、`get_collate_fn` 等。

**定义采样函数**：用于从训练好的模型中生成新的 SMILES 字符串。

**初始化模型**：创建 VAE 模型的实例并将其移动到指定设备上。

**训练模型**：使用 `fit` 函数训练模型。

**保存模型**：将训练好的模型参数保存到指定路径。

**生成样本**：从训练好的模型中生成新的分子，并将结果保存到 CSV 文件中。

**保存损失值**：将训练过程中记录的损失值保存到 CSV 文件中，以便后续分析。

### **注意事项**

- **路径设置**：代码从当前工作目录下的 `data/` 子目录读取数据，请确保数据集文件 `egfr_data_ubstructures_matches.csv` 位于该目录中，或者在代码中提供正确的文件路径。
- **目录创建**：在保存模型和生成的样本时，代码中使用了 `os.makedirs()` 创建目录。如果目录已存在，不会报错。
- **模型保存和加载**：在保存和加载模型时，请确保文件路径正确。如果您想在训练后加载模型，请取消相应的注释。
- **GPU 使用**：如果您的计算机上有 GPU，并且已正确安装了 CUDA，代码将自动使用 GPU 进行训练。
- **运行时间**：由于训练迭代次数较多（`n_epoch = 100`），训练过程可能需要较长时间。您可以根据需要调整训练轮数。

### 

In [19]:
# Select the compute device: CUDA when PyTorch reports it as available,
# otherwise the CPU (a GPU is optional; training is simply slower without one).
import torch

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device)

cpu


In [20]:
import os
from pathlib import Path

# Resolve the current working directory and make sure its ``data``
# sub-directory exists before anything tries to read from it.
HERE = Path(os.getcwd())
DATA = HERE.joinpath('data')
if not DATA.exists():
    DATA.mkdir(parents=True, exist_ok=True)
print(DATA)

/Users/wangyang/Desktop/AI-drug-design/list/05_workshop/02_Generate_new_molecules_glioblastoma/data


In [21]:
import deepchem as dc
import numpy as np
import random
import math
import pandas as pd
import math
from sklearn.model_selection import train_test_split
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.optim.lr_scheduler import _LRScheduler
import torch.optim as optim
from torch.utils.data import DataLoader
from torch.nn.utils import clip_grad_norm_

from collections import UserList, defaultdict
from tqdm import tqdm

from rdkit import RDLogger                                                                                                                                                               


In [22]:
# Disable RDKit log output for every logger matching 'rdApp.*'
RDLogger.DisableLog('rdApp.*')

In [23]:
# Seed every random number generator used in this notebook for repeatability
def seed_all(seed=42):
    """Seed the NumPy, ``random`` and PyTorch generators with ``seed``.

    When CUDA is available, the generators of all CUDA devices are seeded too.
    """
    for seeder in (np.random.seed, random.seed, torch.manual_seed):
        seeder(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)


seed_all()

In [24]:
# Load the dataset and drop the rows whose 'smiles' value is missing
df = pd.read_csv(DATA / 'egfr_data_ubstructures_matches.csv').dropna(subset=['smiles'])

In [25]:
# Preview the first rows of the loaded table
df.head()

Unnamed: 0,molecule_chembl_id,IC50,units,smiles,pIC50,ro5_fulfilled,ROMol
0,CHEMBL35820,0.006,nM,CCOc1cc2ncnc(Nc3cccc(Br)c3)c2cc1OCC,11.221849,True,<rdkit.Chem.rdchem.Mol object at 0x7fbdb826e190>
1,CHEMBL53711,0.006,nM,CN(C)c1cc2c(Nc3cccc(Br)c3)ncnc2cn1,11.221849,True,<rdkit.Chem.rdchem.Mol object at 0x7fbdb826e200>
2,CHEMBL53753,0.008,nM,CNc1cc2c(Nc3cccc(Br)c3)ncnc2cn1,11.09691,True,<rdkit.Chem.rdchem.Mol object at 0x7fbdb826e270>
3,CHEMBL66031,0.008,nM,Brc1cccc(Nc2ncnc3cc4[nH]cnc4cc23)c1,11.09691,True,<rdkit.Chem.rdchem.Mol object at 0x7fbdb826e2e0>
4,CHEMBL176582,0.01,nM,Cn1cnc2cc3ncnc(Nc4cccc(Br)c4)c3cc21,11.0,True,<rdkit.Chem.rdchem.Mol object at 0x7fbdb826e350>


In [26]:
# Extract the SMILES strings and their pIC50 labels as plain Python lists
smiles_list, labels = (df[column].tolist() for column in ('smiles', 'pIC50'))

In [27]:
# Split the dataset (80% train, 10% validation, 10% test).
# Step 1: hold out 20% of the compounds; the remaining 80% is the training set.
train_smiles, temp_smiles, train_labels, temp_labels = train_test_split(
    smiles_list, labels, test_size=0.2, random_state=42)

# Step 2: cut the held-out 20% in half to get the validation and test sets.
valid_smiles, test_smiles, valid_labels, test_labels = train_test_split(
    temp_smiles, temp_labels, test_size=0.5, random_state=42)

print(f'Compound train/valid/test split: {len(train_smiles)}/{len(valid_smiles)}/{len(test_smiles)}')


Compound train/valid/test split: 1955/244/245


In [28]:
# Use the training split as the training data / labels for the rest of the notebook
train_data = train_smiles
train_label = train_labels

In [29]:
# Build the vocabulary: every character that occurs in the training SMILES
# (sorted), followed by the four special tokens.
chars = set().union(*train_data)
all_sys = sorted(chars) + ['<bos>', '<eos>', '<pad>', '<unk>']
vocab = all_sys
c2i = {token: index for index, token in enumerate(all_sys)}  # token -> index
i2c = dict(enumerate(all_sys))  # index -> token

In [30]:
# One-hot vectors: an identity matrix with one row per vocabulary entry
vector = torch.eye(len(c2i))

In [31]:
# Helper functions converting between characters, indices, strings and tensors
def char2id(char):
    """Return the vocabulary index of ``char``, or the ``<unk>`` index if unseen."""
    try:
        return c2i[char]
    except KeyError:
        return c2i['<unk>']

def id2char(id):
    """Return the token stored under index ``id``, or ``'<unk>'`` if there is none."""
    try:
        return i2c[id]
    except KeyError:
        return '<unk>'

def string2ids(string, add_bos=False, add_eos=False):
    """Encode ``string`` character by character into a list of indices.

    ``add_bos`` / ``add_eos`` prepend the ``<bos>`` index and append the
    ``<eos>`` index respectively.
    """
    prefix = [c2i['<bos>']] if add_bos else []
    suffix = [c2i['<eos>']] if add_eos else []
    return prefix + list(map(char2id, string)) + suffix

def ids2string(ids, rem_bos=True, rem_eos=True):
    """Decode a list of indices back into a string.

    A leading ``<bos>`` index is dropped when ``rem_bos`` is set and a
    trailing ``<eos>`` index is dropped when ``rem_eos`` is set.
    """
    tokens = ids
    if rem_bos and tokens and tokens[0] == c2i['<bos>']:
        tokens = tokens[1:]
    if rem_eos and tokens and tokens[-1] == c2i['<eos>']:
        tokens = tokens[:-1]
    return ''.join(map(id2char, tokens))

def string2tensor(string, device='cpu'):
    """Encode ``string`` (wrapped in ``<bos>``/``<eos>``) as a long tensor on ``device``."""
    return torch.tensor(
        string2ids(string, add_bos=True, add_eos=True),
        dtype=torch.long,
        device=device,
    )

def tensor2string(tensor):
    """Decode a tensor of indices into a string via ``ids2string``."""
    return ids2string(tensor.tolist())

In [32]:
# Convert every training string into a tensor on the selected device
train_tensors = [string2tensor(string, device=device) for string in train_data]


In [33]:
# Variational auto-encoder (VAE) model
class VAE(nn.Module):
    """GRU-based variational auto-encoder over sequences of token indices.

    A bidirectional GRU encoder maps each input sequence to the mean and
    log-variance of a diagonal Gaussian over a ``d_z``-dimensional latent
    vector. A GRU decoder, fed the latent vector at every time step, is
    trained to predict the next token.

    The special-token indices are read from the module-level ``c2i`` mapping
    and sampled sequences are decoded with the module-level ``ids2string``.

    Args:
        vocab: Sequence of vocabulary tokens; its length sets the number of
            embeddings and of output classes.
        vector: 2-D tensor with the initial embedding weights, one row per
            token (``vector.size(1)`` is the embedding width).
    """

    def __init__(self, vocab, vector):
        super().__init__()
        self.vocabulary = vocab
        self.vector = vector
        n_vocab, d_emb = len(vocab), vector.size(1)

        # Embedding layer, initialised from ``vector``
        self.x_emb = nn.Embedding(n_vocab, d_emb, padding_idx=c2i['<pad>'])
        self.x_emb.weight.data.copy_(vector)

        # Encoder hyper-parameters
        self.q_bidir = True
        self.q_d_h = 256
        self.q_n_layers = 1
        self.q_dropout = 0.5

        # Decoder hyper-parameters
        self.d_n_layers = 3
        self.d_dropout = 0.0
        self.d_z = 128
        self.d_d_h = 512

        # Encoder: GRU followed by two linear heads (mean and log-variance)
        self.encoder_rnn = nn.GRU(
            d_emb,
            self.q_d_h,
            num_layers=self.q_n_layers,
            batch_first=True,
            # Dropout is only passed to the GRU when it has more than one layer
            dropout=self.q_dropout if self.q_n_layers > 1 else 0,
            bidirectional=self.q_bidir
        )
        q_d_last = self.q_d_h * (2 if self.q_bidir else 1)
        self.q_mu = nn.Linear(q_d_last, self.d_z)
        self.q_logvar = nn.Linear(q_d_last, self.d_z)

        # Decoder: the GRU input is the token embedding concatenated with z
        self.decoder_rnn = nn.GRU(
            d_emb + self.d_z,
            self.d_d_h,
            num_layers=self.d_n_layers,
            batch_first=True,
            dropout=self.d_dropout if self.d_n_layers > 1 else 0
        )
        self.decoder_latent = nn.Linear(self.d_z, self.d_d_h)
        self.decoder_fullyc = nn.Linear(self.d_d_h, n_vocab)

    def device(self):
        """Return the device of the model's first parameter."""
        return next(self.parameters()).device

    def forward(self, x):
        """Encode and decode a batch and return ``(kl_loss, recon_loss)``.

        Args:
            x: List of 1-D index tensors, one per sequence.
        """
        z, kl_loss = self.forward_encoder(x)
        recon_loss = self.forward_decoder(x, z)
        return kl_loss, recon_loss

    def forward_encoder(self, x):
        """Encode a batch and sample a latent vector for each sequence.

        Args:
            x: List of 1-D index tensors, one per sequence.

        Returns:
            Tuple ``(z, kl_loss)``: the sampled latent vectors and the KL
            divergence to a standard normal prior, summed over the latent
            dimensions and averaged over the batch.
        """
        x_emb = [self.x_emb(i_x) for i_x in x]
        x_packed = nn.utils.rnn.pack_sequence(x_emb, enforce_sorted=False)
        _, h = self.encoder_rnn(x_packed)
        if self.q_bidir:
            # Split the final hidden state into (layer, direction) and join
            # both directions of the last layer.
            h = h.view(self.q_n_layers, 2, -1, self.q_d_h)
            h = torch.cat((h[-1, 0], h[-1, 1]), dim=-1)
        else:
            h = h[-1]
        mu = self.q_mu(h)
        logvar = self.q_logvar(h)
        # Reparameterisation: z = mu + sigma * eps with sigma = exp(logvar / 2)
        eps = torch.randn_like(mu)
        z = mu + (logvar / 2).exp() * eps
        kl_loss = 0.5 * (logvar.exp() + mu ** 2 - 1 - logvar).sum(1).mean()
        return z, kl_loss

    def forward_decoder(self, x, z):
        """Return the next-token cross-entropy of the decoder given ``z``.

        Args:
            x: List of 1-D index tensors, one per sequence.
            z: Latent vectors, one row per sequence in ``x``.
        """
        lengths = [len(i_x) for i_x in x]
        x_padded = nn.utils.rnn.pad_sequence(x, batch_first=True, padding_value=c2i['<pad>'])
        x_emb = self.x_emb(x_padded)
        # Repeat z along the time axis so it is fed to the GRU at every step
        z_0 = z.unsqueeze(1).repeat(1, x_emb.size(1), 1)
        x_input = torch.cat([x_emb, z_0], dim=-1)
        x_packed = nn.utils.rnn.pack_padded_sequence(x_input, lengths, batch_first=True, enforce_sorted=False)
        # The same projection of z initialises the hidden state of every layer
        h_0 = self.decoder_latent(z)
        h_0 = h_0.unsqueeze(0).repeat(self.d_n_layers, 1, 1)
        output, _ = self.decoder_rnn(x_packed, h_0)
        output_padded, _ = nn.utils.rnn.pad_packed_sequence(output, batch_first=True)
        y = self.decoder_fullyc(output_padded)
        # The prediction at step t is scored against the token at step t + 1;
        # padded positions are ignored.
        recon_loss = F.cross_entropy(
            y[:, :-1].contiguous().view(-1, y.size(-1)),
            x_padded[:, 1:].contiguous().view(-1),
            ignore_index=c2i['<pad>']
        )
        return recon_loss

    def sample_z_prior(self, n_batch):
        """Draw ``n_batch`` latent vectors from the standard normal prior."""
        return torch.randn(n_batch, self.d_z, device=self.device())

    def sample(self, n_batch, max_len=100, z=None, temp=1.0):
        """Generate ``n_batch`` strings by sampling from the decoder.

        Each sequence stops at its first ``<eos>`` token (or at ``max_len``
        tokens including ``<bos>``). Tokens sampled after a sequence has
        finished are discarded, so the returned strings contain neither
        post-``<eos>`` tokens nor padding.

        Args:
            n_batch: Number of sequences to generate.
            max_len: Maximum number of tokens per sequence, ``<bos>`` included.
            z: Optional latent vectors to decode; drawn from the prior when
                ``None``.
            temp: Softmax temperature applied to the decoder logits.

        Returns:
            List of ``n_batch`` decoded strings.
        """
        bos, eos, pad = c2i['<bos>'], c2i['<eos>'], c2i['<pad>']
        with torch.no_grad():
            dev = self.device()
            if z is None:
                z = self.sample_z_prior(n_batch)
            z = z.to(dev)
            z_0 = z.unsqueeze(1)
            h = self.decoder_latent(z)
            h = h.unsqueeze(0).repeat(self.d_n_layers, 1, 1)
            w = torch.full((n_batch,), bos, dtype=torch.long, device=dev)
            x = torch.full((n_batch, max_len), pad, dtype=torch.long, device=dev)
            x[:, 0] = bos
            # end_pads[k] is the number of valid tokens in row k
            end_pads = torch.full((n_batch,), max_len, dtype=torch.long, device=dev)
            eos_mask = torch.zeros(n_batch, dtype=torch.bool, device=dev)

            for i in range(1, max_len):
                w_emb = self.x_emb(w).unsqueeze(1)
                x_input = torch.cat([w_emb, z_0], dim=-1)
                o, h = self.decoder_rnn(x_input, h)
                y = self.decoder_fullyc(o.squeeze(1))
                y = F.softmax(y / temp, dim=-1)
                w = torch.multinomial(y, 1)[:, 0]
                # Only rows that have not produced <eos> yet receive a token
                x[~eos_mask, i] = w[~eos_mask]
                just_finished = ~eos_mask & (w == eos)
                end_pads[just_finished] = i + 1
                eos_mask = eos_mask | just_finished
                if eos_mask.all():
                    break

            rows = x.tolist()
            lengths = end_pads.tolist()
            # Cut every row at its own end before decoding
            return [ids2string(row[:length]) for row, length in zip(rows, lengths)]

In [34]:
# Training hyper-parameters
n_last = 1000  # capacity of the CircularBuffer that averages per-batch losses
n_batch = 512  # DataLoader batch size; also passed as the batch size when sampling
kl_start = 0  # epoch from which the KL weight starts to grow
kl_w_start = 0.0  # initial KL weight
kl_w_end = 1.0  # KL weight the linear schedule ramps towards
n_epoch = 100  # number of training epochs
n_workers = 0  # DataLoader num_workers
clip_grad = 50  # max norm passed to clip_grad_norm_
lr_start = 0.003  # learning rate given to the Adam optimizer
lr_n_period = 10  # length (in scheduler steps) of the first cosine cycle
lr_n_mult = 1  # factor applied to the cycle length after each restart
lr_end = 3e-4  # learning rate the cosine schedule anneals towards
lr_n_restarts = 6  # NOTE(review): not referenced by any code visible in this notebook

In [35]:
# Empty DataFrame that collects one row of loss statistics per training epoch
df_losses = pd.DataFrame(columns=['epoch', 'kl_weight', 'lr', 'kl_loss', 'recon_loss', 'loss'])


In [36]:
# Linear annealing schedule for the KL weight
class KLAnnealer:
    """Linearly increase the KL weight over the course of training.

    The start epoch and the start/end weights come from the module-level
    ``kl_start``, ``kl_w_start`` and ``kl_w_end``. Calling the instance with
    an epoch index ``i`` returns ``w_start + max(0, i - i_start) * inc``.
    """

    def __init__(self, n_epoch):
        self.i_start = kl_start
        self.w_start = kl_w_start
        self.w_max = kl_w_end
        self.n_epoch = n_epoch
        # Per-epoch increment of the weight
        n_steps = self.n_epoch - self.i_start
        self.inc = (self.w_max - self.w_start) / n_steps

    def __call__(self, i):
        elapsed = i - self.i_start
        if elapsed < 0:
            # Before the start epoch the weight stays at its initial value
            elapsed = 0
        return self.w_start + elapsed * self.inc

In [37]:
# Cosine-annealing learning-rate scheduler with restarts
class CosineAnnealingLRWithRestart(_LRScheduler):
    """Cosine learning-rate schedule that restarts after every cycle.

    The cycle length starts at the module-level ``lr_n_period`` and is
    multiplied by ``lr_n_mult`` after each restart. Within a cycle the
    learning rate is annealed from each group's base rate towards ``lr_end``.
    """

    def __init__(self, optimizer):
        self.n_period = lr_n_period
        self.n_mult = lr_n_mult
        self.lr_end = lr_end
        self.current_epoch = 0  # position inside the current cosine cycle
        self.t_end = self.n_period  # length of the current cycle
        # NOTE(review): the state above is set before the base constructor
        # runs, presumably because the base class calls step()/get_lr()
        # during initialisation — confirm against the installed PyTorch.
        super().__init__(optimizer)

    def get_lr(self):
        """Return one learning rate per base rate for the current cycle position."""
        return [self.lr_end + (base_lr - self.lr_end) * 
                (1 + math.cos(math.pi * self.current_epoch / self.t_end)) / 2
                for base_lr in self.base_lrs]

    def step(self, epoch=None):
        """Advance the schedule by one step and write the new rates to the optimizer.

        Args:
            epoch: Value stored in ``last_epoch``; defaults to ``last_epoch + 1``.
        """
        if epoch is None:
            epoch = self.last_epoch + 1
        self.last_epoch = epoch
        self.current_epoch += 1

        for param_group, lr in zip(self.optimizer.param_groups, self.get_lr()):
            param_group['lr'] = lr

        # End of the cycle: restart from the beginning and rescale its length
        if self.current_epoch >= self.t_end:
            self.current_epoch = 0
            self.t_end *= self.n_mult


In [38]:
# Circular buffer used to average the most recent loss values
class CircularBuffer:
    """Fixed-capacity ring buffer that keeps the last ``size`` values added.

    Args:
        size: Maximum number of values kept; older values are overwritten.
    """

    def __init__(self, size):
        self.max_size = size
        self.data = np.zeros(self.max_size)
        self.size = 0  # number of slots filled so far (at most max_size)
        self.pointer = -1  # index of the most recently written slot

    def add(self, element):
        """Store ``element``, overwriting the oldest value once the buffer is full."""
        self.size = min(self.size + 1, self.max_size)
        self.pointer = (self.pointer + 1) % self.max_size
        self.data[self.pointer] = element

    def mean(self):
        """Return the mean of the stored values (``0.0`` when nothing was added).

        Only the filled slots are averaged. Slots are filled from index 0
        upwards, so until the buffer wraps around they are ``data[:size]``.
        """
        if self.size == 0:
            return 0.0
        return self.data[:self.size].mean()

In [39]:
# Logger that stores per-step dictionaries and indexes their values by key
class Logger(UserList):
    """List of step dictionaries that can also be indexed by metric name.

    Integer and slice indices behave like a list (a slice gives a new
    ``Logger``); any other key returns the list of values recorded under
    that key across all appended steps.
    """

    def __init__(self, data=None):
        super().__init__()
        self.sdata = defaultdict(list)
        if data:
            for step in data:
                self.append(step)

    def __getitem__(self, key):
        if isinstance(key, slice):
            return Logger(self.data[key])
        if isinstance(key, int):
            return self.data[key]
        return self.sdata[key]

    def append(self, step_dict):
        """Append one step and record each of its values under its key."""
        super().append(step_dict)
        for name, value in step_dict.items():
            self.sdata[name].append(value)

In [41]:
# Training functions:
# _train_epoch runs one pass over a data loader (training or evaluation),
# _train runs the whole training loop, and fit is the entry point that
# builds the data loaders and starts training.
def _train_epoch(model, epoch, train_loader, kl_weight, optimizer=None):
    """Run one pass over ``train_loader`` and return its loss statistics.

    With an ``optimizer`` the model is put in training mode and updated after
    every batch. Without one the model is put in evaluation mode and the pass
    runs with gradient tracking disabled.

    Args:
        model: Callable module returning ``(kl_loss, recon_loss)`` for a batch.
        epoch: Epoch index, copied into the returned dictionary.
        train_loader: Iterable of batches, each an iterable of tensors.
        kl_weight: Weight of the KL term in the total loss.
        optimizer: Optimizer to step, or ``None`` for an evaluation pass.

    Returns:
        Dictionary with the keys ``epoch``, ``kl_weight``, ``lr``,
        ``kl_loss``, ``recon_loss``, ``loss`` and ``mode``. ``lr`` is ``None``
        for an evaluation pass.
    """
    training = optimizer is not None
    if training:
        model.train()
    else:
        model.eval()

    kl_loss_values = CircularBuffer(n_last)
    recon_loss_values = CircularBuffer(n_last)
    loss_values = CircularBuffer(n_last)

    # Read the learning rate up front so it is defined even when the loader
    # yields no batches; nothing inside this function changes it.
    lr = optimizer.param_groups[0]['lr'] if training else None

    with torch.set_grad_enabled(training):
        for input_batch in train_loader:
            input_batch = [data.to(device) for data in input_batch]
            kl_loss, recon_loss = model(input_batch)
            loss = kl_weight * kl_loss + recon_loss

            if training:
                optimizer.zero_grad()
                loss.backward()
                clip_grad_norm_(model.parameters(), clip_grad)
                optimizer.step()

            kl_loss_values.add(kl_loss.item())
            recon_loss_values.add(recon_loss.item())
            loss_values.add(loss.item())

    return {
        'epoch': epoch,
        'kl_weight': kl_weight,
        'lr': lr,
        'kl_loss': kl_loss_values.mean(),
        'recon_loss': recon_loss_values.mean(),
        'loss': loss_values.mean(),
        'mode': 'Train' if training else 'Eval'
    }

def _train(model, train_loader, val_loader=None, logger=None):
    """Train ``model`` for ``n_epoch`` epochs.

    Every epoch the KL weight is taken from a ``KLAnnealer``, one training
    pass is run, its statistics are appended to the module-level
    ``df_losses`` table, and the learning-rate scheduler is stepped.

    Args:
        model: Model to train.
        train_loader: Loader yielding the training batches.
        val_loader: Optional loader; when given, an evaluation pass is run on
            it after every training pass.
        logger: Optional object with an ``append`` method; when given, it
            receives the statistics dictionary of every training pass and of
            every evaluation pass.
    """
    optimizer = optim.Adam(model.parameters(), lr=lr_start)
    lr_annealer = CosineAnnealingLRWithRestart(optimizer)
    kl_annealer = KLAnnealer(n_epoch)
    model.zero_grad()

    for epoch in tqdm(range(n_epoch), desc='Training', unit='epoch'):
        kl_weight = kl_annealer(epoch)
        postfix = _train_epoch(model, epoch, train_loader, kl_weight, optimizer)
        df_losses.loc[len(df_losses.index)] = [
            postfix['epoch'], postfix['kl_weight'], postfix['lr'],
            postfix['kl_loss'], postfix['recon_loss'], postfix['loss']
        ]
        if logger is not None:
            logger.append(postfix)

        if val_loader is not None:
            # Evaluation statistics go to the logger only; df_losses keeps
            # exactly one (training) row per epoch.
            val_postfix = _train_epoch(model, epoch, val_loader, kl_weight)
            if logger is not None:
                logger.append(val_postfix)

        lr_annealer.step()

def fit(model, train_data, val_data=None):
    """Build the data loaders, train ``model`` and return it.

    Args:
        model: Model to train.
        train_data: Training strings (loaded with shuffling).
        val_data: Optional validation strings (loaded without shuffling).
    """
    logger = Logger()
    train_loader = get_dataloader(model, train_data, shuffle=True)
    if val_data is None:
        val_loader = None
    else:
        val_loader = get_dataloader(model, val_data, shuffle=False)
    _train(model, train_loader, val_loader, logger)
    return model

In [42]:
# Data-loader helper functions
# get_collate_device: look up the device the model lives on
def get_collate_device(model):
    """Return whatever ``model.device()`` reports."""
    target_device = model.device()
    return target_device

# get_dataloader: wrap a dataset in a DataLoader
def get_dataloader(model, data, collate_fn=None, shuffle=True):
    """Return a ``DataLoader`` over ``data``.

    The batch size and worker count come from the module-level ``n_batch``
    and ``n_workers``. When no ``collate_fn`` is given, the one built by
    ``get_collate_fn(model)`` is used.
    """
    collate = get_collate_fn(model) if collate_fn is None else collate_fn
    return DataLoader(
        data,
        batch_size=n_batch,
        shuffle=shuffle,
        num_workers=n_workers,
        collate_fn=collate,
    )

# get_collate_fn: build the function that turns a batch of strings into tensors
def get_collate_fn(model):
    """Return a collate function bound to the device of ``model``.

    The returned function sorts the batch in place by length (longest first)
    and converts every string to a tensor on that device.
    """
    target_device = get_collate_device(model)

    def collate(batch):
        batch.sort(key=len, reverse=True)
        return [string2tensor(text, device=target_device) for text in batch]

    return collate

In [43]:
# get_optim_params: select the trainable parameters of a model
def get_optim_params(model):
    """Return a generator over the parameters of ``model`` with ``requires_grad`` set."""
    trainable = (param for param in model.parameters() if param.requires_grad)
    return trainable


In [44]:
# Sampling helper
class Sample:
    """Namespace for drawing samples from a trained model."""

    @staticmethod
    def take_samples(model, n_batch, n_samples=1000, max_len=100):
        """Draw ``n_samples`` strings from ``model`` in batches of at most ``n_batch``.

        Returns:
            DataFrame with a single ``'SMILES'`` column.
        """
        remaining = n_samples
        collected = []
        with tqdm(total=n_samples, desc='Generating samples') as progress:
            while remaining > 0:
                batch = model.sample(min(remaining, n_batch), max_len)
                collected.extend(batch)
                produced = len(batch)
                remaining -= produced
                progress.update(produced)
        return pd.DataFrame(collected, columns=['SMILES'])

In [45]:
# Create the VAE from the vocabulary and one-hot vectors and move it to the selected device
model = VAE(vocab, vector).to(device)

In [None]:
# Train the model on the training strings (no validation data is passed)
fit(model, train_data)

Training:   0%|          | 0/100 [00:00<?, ?epoch/s]

In [None]:
# Save the model parameters; the directory is created if it does not exist yet
import os
os.makedirs('checkpoints', exist_ok=True)
torch.save(model.state_dict(), f'checkpoints/vae_model_epoch{n_epoch}.pt')

In [None]:
# 加载模型（如果需要）
# model.load_state_dict(torch.load(f'checkpoints/vae_model_epoch{n_epoch}.pt'))


In [None]:
# Generate samples: switch the model to evaluation mode, then draw strings in
# batches of n_batch (take_samples draws 1000 samples by default)
model.eval()
df_sample = Sample.take_samples(model, n_batch)
print(df_sample)

In [None]:
# Save the generated molecules to CSV; the directory is created if needed
os.makedirs('generated_molecules/vae', exist_ok=True)
df_sample.to_csv(f'generated_molecules/vae/vae_epoch{n_epoch}.csv', index=False)

In [None]:
# Save the per-epoch loss table to CSV; the directory is created if needed
os.makedirs('checkpoints/losses', exist_ok=True)
df_losses.to_csv(f'checkpoints/losses/vae_epoch{n_epoch}.csv', index=False)

In [None]:
# Plot the loss curves
import matplotlib.pyplot as plt


def plot_losses(df_losses, title):
    """Plot the KL, reconstruction and total loss against the epoch.

    Args:
        df_losses: Table with the columns ``epoch``, ``kl_loss``,
            ``recon_loss`` and ``loss``.
        title: Title of the figure.
    """
    curves = (
        ('kl_loss', 'KL loss'),
        ('recon_loss', 'Recon loss'),
        ('loss', 'Total loss'),
    )
    plt.figure(figsize=(12, 8))
    for column, label in curves:
        plt.plot(df_losses['epoch'], df_losses[column], label=label)
    plt.xlabel('Epoch')
    plt.ylabel('Loss')
    plt.title(title)
    plt.legend()
    plt.show()


plot_losses(df_losses, f'VAE (epoch {n_epoch})')

In [None]:
# Draw the generated molecules
import rdkit
from rdkit.Chem import Draw
from rdkit.Chem.Draw import IPythonConsole


def plot_molecules(df_sample, n_rows=10, n_cols=10):
    """Show the first ``n_rows * n_cols`` parsable molecules in a grid.

    Strings that RDKit cannot parse are skipped. Reading stops as soon as the
    grid is full, so a table with more rows than grid cells is fine.

    Args:
        df_sample: Table with a ``'SMILES'`` column.
        n_rows: Number of grid rows.
        n_cols: Number of grid columns.
    """
    capacity = n_rows * n_cols
    drawable = []
    for smiles in df_sample['SMILES']:
        if len(drawable) >= capacity:
            break
        mol = rdkit.Chem.MolFromSmiles(smiles)
        if mol is not None:
            drawable.append((smiles, mol))

    plt.figure(figsize=(12, 12))
    # Subplot indices are 1-based and must not exceed n_rows * n_cols
    for slot, (smiles, mol) in enumerate(drawable, start=1):
        plt.subplot(n_rows, n_cols, slot)
        # MolToImage only builds the image; imshow puts it on the axes
        plt.imshow(Draw.MolToImage(mol))
        plt.axis('off')
        plt.title(smiles)
    plt.show()


plot_molecules(df_sample, n_rows=10, n_cols=10)  # draws at most 10 * 10 = 100 valid molecules

In [None]:
# Plot the KL weight used in each training epoch
plt.figure(figsize=(12, 8))
plt.plot(df_losses['epoch'], df_losses['kl_weight'], label='KL weight')
plt.xlabel('Epoch')
plt.ylabel('KL weight')
plt.title(f'VAE (epoch {n_epoch})')
plt.legend()
plt.show()  

In [None]:
# Plot the learning rate recorded for each training epoch
plt.figure(figsize=(12, 8))
plt.plot(df_losses['epoch'], df_losses['lr'], label='Learning rate')
plt.xlabel('Epoch')
plt.ylabel('Learning rate')
plt.title(f'VAE (epoch {n_epoch})')
plt.legend()
plt.show()  

In [None]:
# Plot the KL loss recorded for each training epoch
plt.figure(figsize=(12, 8))
plt.plot(df_losses['epoch'], df_losses['kl_loss'], label='KL loss')
plt.xlabel('Epoch')
plt.ylabel('KL loss')
plt.title(f'VAE (epoch {n_epoch})')
plt.legend()
plt.show()  

In [None]:
# Plot the reconstruction loss recorded for each training epoch
plt.figure(figsize=(12, 8))
plt.plot(df_losses['epoch'], df_losses['recon_loss'], label='Recon loss')
plt.xlabel('Epoch')
plt.ylabel('Recon loss')
plt.title(f'VAE (epoch {n_epoch})')
plt.legend()
plt.show()  

In [None]:
# Plot the total loss recorded for each training epoch
plt.figure(figsize=(12, 8))
plt.plot(df_losses['epoch'], df_losses['loss'], label='Total loss')
plt.xlabel('Epoch')
plt.ylabel('Total loss')
plt.title(f'VAE (epoch {n_epoch})')
plt.legend()
plt.show()      

In [None]:
# Plot the KL weight and the learning rate on the same axes
# NOTE(review): with the visible settings the KL weight ramps from 0.0 towards
# 1.0 while the learning rate stays between 3e-4 and 3e-3, so the learning-rate
# curve will look almost flat on this shared scale.
plt.figure(figsize=(12, 8))
plt.plot(df_losses['epoch'], df_losses['kl_weight'], label='KL weight')
plt.plot(df_losses['epoch'], df_losses['lr'], label='Learning rate')
plt.xlabel('Epoch')
plt.ylabel('Weight/Learning rate')  
plt.title(f'VAE (epoch {n_epoch})')
plt.legend()
plt.show()  

In [None]:
# Plot the KL loss and the reconstruction loss on the same axes
plt.figure(figsize=(12, 8))
plt.plot(df_losses['epoch'], df_losses['kl_loss'], label='KL loss')
plt.plot(df_losses['epoch'], df_losses['recon_loss'], label='Recon loss')
plt.xlabel('Epoch')
plt.ylabel('KL/Recon loss')  
plt.title(f'VAE (epoch {n_epoch})')
plt.legend()
plt.show()  

In [None]:
# Plot the KL loss, the reconstruction loss and the total loss on the same axes
plt.figure(figsize=(12, 8))
plt.plot(df_losses['epoch'], df_losses['kl_loss'], label='KL loss')
plt.plot(df_losses['epoch'], df_losses['recon_loss'], label='Recon loss')
plt.plot(df_losses['epoch'], df_losses['loss'], label='Total loss')
plt.xlabel('Epoch')
plt.ylabel('Loss')  
plt.title(f'VAE (epoch {n_epoch})')
plt.legend()
plt.show()  