# Initialization

In [11]:
import torch
import torch.nn as nn
import torch.distributions as D
import time
import random
import numpy as np

# seed = 1234
# torch.manual_seed(seed)
# random.seed(seed)
# np.random.seed(seed)

# Mixture Gaussians

In [12]:
def mix_gaussians(components_num, dim_num):
    # random.weights
    raw_weights = torch.rand(components_num)
    weights = raw_weights / raw_weights.sum()

    # random.means
    means = torch.randn(components_num, dim_num)

    # random.covs
    covs = []
    epsilon = 1e-3
    for _ in range(components_num):
        A = torch.randn(dim_num, dim_num)
        cov = A @ A.t() + epsilon * torch.eye(dim_num)
        covs.append(cov)

    # create components
    components = []
    for i in range(components_num):
        dist = D.MultivariateNormal(loc=means[i], covariance_matrix=covs[i])
        components.append(dist)

    return components

# Generate Samples

In [None]:
def sample_sequence_pairs(comp1, comp2, components_num, N, seq_len):
    """
    comp1, comp2: list[MultivariateNormal]，每个长度=components_num
    latent_indices: shape (N, seq_len)，表示每个样本在每个时刻要用哪个component

    返回:
      samples1: (N, seq_len, dim)
      samples2: (N, seq_len, dim)
    """
    latent_indices = torch.randint(0, components_num, (N, seq_len))
    N, seq_len = latent_indices.shape
    dim = comp1[0].event_shape[0]  # 每个component的输出维度相同

    samples1 = torch.empty(N, seq_len, dim)
    samples2 = torch.empty(N, seq_len, dim)

    # 逐个样本、逐个时间步采样
    for i in range(N):
        for t in range(seq_len):
            idx = latent_indices[i, t].item()
            samples1[i, t] = comp1[idx].sample()  # shape: (dim,)
            samples2[i, t] = comp2[idx].sample()

    return samples1, samples2

In [10]:
# hyperparameters
N = 1               # number of pairs
components_num = 10  # number of components
dim_num = 10         # dimension of each component

# generate two mixtures of Gaussians
cat1, comp1 = mix_gaussians(components_num, dim_num)
cat2, comp2 = mix_gaussians(components_num, dim_num)



samples1 shape: torch.Size([1, 10])
samples2 shape: torch.Size([1, 10])


In [None]:
d_model = 64        # model dimension
nhead = 8           # multihead attention head number
num_layers = 2      # Transformer layer number
dim_feedforward = 256  # feedforward dimension

# Transformer encoder layer
encoder_layer = nn.TransformerEncoderLayer(
    d_model=d_model,
    nhead=nhead,
    dim_feedforward=dim_feedforward,
    batch_first=True # (batch_size, seq_len, d_model)
)

# stack Encoder Layer
transformer_encoder = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)

# input (batch_size, seq_len, input_dim)
# 这里我们先用一个线性层将输入映射到 d_model 维度
input_dim = 10
projection = nn.Linear(input_dim, d_model)

# 构造一个示例 batch
batch_size = 32
seq_len = 20
dummy_input = torch.randn(batch_size, seq_len, input_dim)
projected_input = projection(dummy_input)  # 映射到 (batch_size, seq_len, d_model)

# 得到 Transformer 编码器的输出
encoded_output = transformer_encoder(projected_input)  # (batch_size, seq_len, d_model)
print(encoded_output.shape)