In [None]:
from huggingface_hub import snapshot_download
import torch 
from torch import nn
from torch.nn import functional as F
import torch.nn.functional as F
from sentence_transformers import SentenceTransformer
import os
from gemma.config import GemmaConfig, get_model_config
from gemma.model import GemmaForCausalLM
from gemma.tokenizer import Tokenizer


In [2]:
class VAE(nn.Module):
    """Variational autoencoder mapping one embedding space onto another.

    A vector of size ``bert_output_dim`` is encoded into the mean and
    log-variance of a diagonal Gaussian with ``latent_dim`` dimensions. A
    latent sample drawn with the reparameterization trick is then decoded
    into a vector of size ``gpt_input_dim``.

    Args:
        bert_output_dim: Width of the input vectors.
        gpt_input_dim: Width of the decoded output vectors.
        hidden_dim: Width of the single hidden layer in encoder and decoder.
        latent_dim: Dimensionality of the latent Gaussian.
    """

    def __init__(self, bert_output_dim, gpt_input_dim, hidden_dim=400, latent_dim=32):
        super().__init__()
        self.input_dim = bert_output_dim
        self.output_dim = gpt_input_dim

        # Encoder: one shared hidden layer feeding two parallel heads.
        self.fc1 = nn.Linear(bert_output_dim, hidden_dim)
        self.fc21 = nn.Linear(hidden_dim, latent_dim)  # mean head (mu)
        self.fc22 = nn.Linear(hidden_dim, latent_dim)  # log-variance head (log sigma^2)

        # Decoder: latent -> hidden -> output.
        self.fc3 = nn.Linear(latent_dim, hidden_dim)
        self.fc4 = nn.Linear(hidden_dim, gpt_input_dim)

    def encode(self, x):
        """Return ``(mu, logvar)`` of the approximate posterior for ``x``."""
        hidden = torch.relu(self.fc1(x))
        mu = self.fc21(hidden)
        logvar = self.fc22(hidden)
        return mu, logvar

    def reparameterize(self, mu, logvar):
        """Sample ``mu + eps * sigma`` with ``eps ~ N(0, I)`` and ``sigma = exp(logvar / 2)``."""
        sigma = (0.5 * logvar).exp()
        noise = torch.randn_like(sigma)
        return mu + noise * sigma

    def decode(self, z):
        """Decode latent ``z``; the final sigmoid keeps every output in [0, 1]."""
        hidden = torch.relu(self.fc3(z))
        logits = self.fc4(hidden)
        return logits.sigmoid()

    def forward(self, x):
        """Encode, sample and decode; returns ``(reconstruction, mu, logvar)``."""
        mu, logvar = self.encode(x)
        latent = self.reparameterize(mu, logvar)
        reconstruction = self.decode(latent)
        return reconstruction, mu, logvar

In [None]:
# Download the Gemma checkpoint snapshot from the Hugging Face Hub.
snapshot_dir = snapshot_download(repo_id='google/gemma-2-2b-jpn-it-pytorch')

# Validate the expected files with explicit exceptions: `assert` statements are
# removed when Python runs with -O, and the error should name the missing path.
tokenizer_path = os.path.join(snapshot_dir, 'tokenizer.model')
if not os.path.isfile(tokenizer_path):
    raise FileNotFoundError(f'Tokenizer not found: {tokenizer_path}')

ckpt_path = os.path.join(snapshot_dir, 'model.ckpt')
if not os.path.isfile(ckpt_path):
    raise FileNotFoundError(f'PyTorch checkpoint not found: {ckpt_path}')

gemma_model_config = get_model_config("2b-v2")
gemma_model_config.tokenizer = tokenizer_path

# Instantiate the model and load the weights.
# NOTE(review): this changes the process-wide default dtype and never restores
# it, so every floating-point tensor/module created later in the notebook
# inherits it - confirm that this is intended.
torch.set_default_dtype(gemma_model_config.get_dtype())
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
gemma_model = GemmaForCausalLM(gemma_model_config)
gemma_model.requires_grad_(False)  # Gemma's parameters are frozen
gemma_model.load_weights(ckpt_path)
gemma_model = gemma_model.to(device).eval()

In [None]:
def prepare_embeddings(texts):
    """Return the input texts as a list, each prefixed with "文章: ".

    Args:
        texts: A single string or an iterable of strings. A single string is
            treated as a one-element batch.

    Returns:
        A new list with one prefixed string per input text.
    """
    # presumably the prefix is the input marker the embedding model expects -
    # verify against the model card
    prefix = "文章: "
    batch = [texts] if isinstance(texts, str) else texts
    prefixed = []
    for sentence in batch:
        prefixed.append(prefix + sentence)
    return prefixed
# Download the sentence-embedding model from the Hugging Face Hub and freeze it.
bert_model = SentenceTransformer("cl-nagoya/ruri-large")
bert_model.requires_grad_(False)
# Place the encoder on the shared `device` (which falls back to CPU) so it
# always sits next to the Gemma model and the notebook still runs without a GPU.
bert_model = bert_model.to(device)
print(bert_model.get_sentence_embedding_dimension())

In [5]:
# VAE sizes: the input width comes from the sentence encoder, the output width
# from Gemma's hidden size.
vae_hidden_size = 400
vae_latent_size = 32
vae_model = VAE(
    bert_model.get_sentence_embedding_dimension(),
    gemma_model_config.hidden_size,
    vae_hidden_size,
    vae_latent_size,
)
# Use the shared `device` (CPU fallback) so the VAE lives where the other models do.
vae_model = vae_model.to(device)

In [6]:
# Smoke test: push random vectors (standing in for VAE outputs) through the
# teacher-forcing loss.
texts = ["Hello world!", "Machine learning is fun!"]
# One row per text, created on the same device as the model; the width is read
# from the model config rather than assumed.
vae_emb = torch.randn(len(texts), gemma_model.config.hidden_size, device=device)
loss = gemma_model.forward_teacher_forcing(vae_emb, texts)

In [None]:
from datasets import load_dataset

# Load the Japanese Wikipedia sentence dataset, cached under ./.datasets.
ds = load_dataset(
    path="AhmedSSabir/Japanese-wiki-dump-sentence-dataset",
    cache_dir="./.datasets",
)
print(ds.shape)
print(ds["train"][0]["text"])
# Sample output of the line above (the record is Japanese; English gloss):
# "Among offers from the University of Tennessee, Duke University, the University
# of Florida and others, the one he chose was the University of Notre Dame."

In [None]:
# %% DataLoader setup
from torch.utils.data import DataLoader
from schedulefree import RAdamScheduleFree
import torch

def collate_fn(batch):
    """Collate dataset rows into a plain list of their 'text' values.

    Args:
        batch: Iterable of mappings, each with a 'text' key.

    Returns:
        List holding the 'text' value of every row, in order.
    """
    texts = []
    for row in batch:
        texts.append(row['text'])
    return texts

BATCH_SIZE = 4  # tune to the available VRAM

# Shuffled loader over the training split; each batch is a list of raw strings.
train_loader = DataLoader(
    dataset=ds['train'],
    collate_fn=collate_fn,
    batch_size=BATCH_SIZE,
    shuffle=True,
    pin_memory=True,
    num_workers=4,
)

# %% Optimizer setup
optimizer = RAdamScheduleFree(vae_model.parameters(), lr=1e-4)

# KL-weight annealing parameters
beta = 0.1  # initial KL weight
BETA_MAX = 0.8  # upper bound for the KL weight
BETA_STEP = 1e-5  # amount added to beta after every batch

# %% Mixed-precision setup
AMP_DTYPE = torch.bfloat16
# Loss scaling is only needed for float16 autocast; bfloat16 has float32's
# exponent range. The scaler is kept (disabled for bfloat16) so the training
# loop reads the same if AMP_DTYPE is ever switched to torch.float16.
scaler = torch.amp.GradScaler(enabled=(AMP_DTYPE == torch.float16))

# %% Training loop
NUM_EPOCHS = 10
GRAD_ACCUM_STEPS = 4  # number of batches accumulated per optimizer update

for epoch in range(NUM_EPOCHS):
    vae_model.train()
    # Schedule-free optimizers keep separate train/eval parameter states and
    # must be switched explicitly together with the model.
    optimizer.train()
    optimizer.zero_grad(set_to_none=True)

    total_loss = 0.0
    total_kl = 0.0
    total_recon = 0.0
    num_batches = len(train_loader)

    for step, batch_texts in enumerate(train_loader):
        # Sentence embeddings come from the frozen encoder; no graph is needed.
        with torch.no_grad():
            prep_texts = prepare_embeddings(batch_texts)
            bert_embeddings = bert_model.encode(prep_texts,
                                               convert_to_tensor=True,
                                               device=device)

        # Mixed-precision context on whichever device the models live on.
        with torch.autocast(device_type=device.type, dtype=AMP_DTYPE):
            # VAE forward pass
            vae_output, mu, logvar = vae_model(bert_embeddings)

            # KL divergence to the standard normal prior, averaged over the batch
            kl_div = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
            kl_div = kl_div / bert_embeddings.size(0)

            # Reconstruction loss computed by Gemma with teacher forcing
            recon_loss = gemma_model.forward_teacher_forcing(
                vae_output,
                batch_texts,
                max_seq_len=256  # adjust to the dataset
            )

            # Total loss
            loss = recon_loss + beta * kl_div

        # Gradient accumulation: scale each micro-batch so the accumulated
        # gradient is the mean over GRAD_ACCUM_STEPS batches.
        scaler.scale(loss / GRAD_ACCUM_STEPS).backward()

        # Update the weights once per accumulation window. The final
        # (possibly shorter) window of the epoch is flushed as well so no
        # gradient leaks into the next epoch.
        if (step + 1) % GRAD_ACCUM_STEPS == 0 or (step + 1) == num_batches:
            scaler.step(optimizer)
            scaler.update()
            optimizer.zero_grad(set_to_none=True)

        # Logging (unscaled per-batch values)
        total_loss += loss.item()
        total_kl += kl_div.item()
        total_recon += recon_loss.item()

        if step % 100 == 0:
            avg_loss = total_loss / (step + 1)
            avg_kl = total_kl / (step + 1)
            avg_recon = total_recon / (step + 1)

            print(f"Epoch {epoch+1} | Step {step} | "
                  f"Loss: {avg_loss:.4f} | KL: {avg_kl:.4f} | "
                  f"Recon: {avg_recon:.4f} | Beta: {beta:.4f}")

        # KL annealing: raise beta linearly per batch until it reaches BETA_MAX.
        beta = min(BETA_MAX, beta + BETA_STEP)

    # End-of-epoch summary (guarded against an empty loader)
    epoch_loss = total_loss / max(num_batches, 1)
    print(f"Epoch {epoch+1} Completed | Avg Loss: {epoch_loss:.4f}")

    # Optional checkpoint every second epoch
    if (epoch + 1) % 2 == 0:
        # Put the schedule-free optimizer in eval mode before the weights are
        # saved; the next epoch switches both back to train mode.
        vae_model.eval()
        optimizer.eval()
        torch.save(vae_model.state_dict(), f"vae_model_epoch_{epoch+1}.pth")