In [1]:
from huggingface_hub import snapshot_download
import torch 
from torch import nn
from torch.nn import functional as F
import torch.nn.functional as F
from sentence_transformers import SentenceTransformer
import os
from gemma.config import GemmaConfig, get_model_config
from gemma.model import GemmaForCausalLM
from gemma.tokenizer import Tokenizer
from VAEs.VAE import VAE
import datetime

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Fetch (or reuse the locally cached) PyTorch release of Gemma 2 2B JPN-IT from the Hugging Face Hub.
snapshot_dir = snapshot_download(repo_id='google/gemma-2-2b-jpn-it-pytorch')

# The snapshot must contain both the tokenizer model and the weight checkpoint.
tokenizer_path = os.path.join(snapshot_dir, 'tokenizer.model')
ckpt_path = os.path.join(snapshot_dir, 'model.ckpt')
assert os.path.isfile(tokenizer_path), 'Tokenizer not found!'
assert os.path.isfile(ckpt_path), 'PyTorch checkpoint not found!'

# Architecture config for the "2b-v2" variant, pointed at the downloaded tokenizer.
gemma_model_config = get_model_config("2b-v2")
gemma_model_config.tokenizer = tokenizer_path

# Prefer the GPU when one is available.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

# NOTE: this changes the process-wide default floating-point dtype, so every tensor or
# module created afterwards without an explicit dtype (including in later cells) uses it.
torch.set_default_dtype(gemma_model_config.get_dtype())

# Build Gemma frozen (no parameter requires grad), load the weights,
# then move it to the device in inference (eval) mode.
gemma_model = GemmaForCausalLM(gemma_model_config)
gemma_model.requires_grad_(False)
gemma_model.load_weights(ckpt_path)
gemma_model = gemma_model.to(device).eval()

Fetching 4 files: 100%|██████████| 4/4 [00:00<00:00, 74898.29it/s]


In [3]:
def prepare_embeddings(texts):
    """Prefix texts with the "文章: " passage marker before sentence encoding.

    Despite its name, this does not compute embeddings. It only builds the
    input strings that are later passed to ``bert_model.encode``. "文章: " is
    the passage prefix used with the cl-nagoya/ruri models.

    Args:
        texts: a single string, or an iterable of strings.

    Returns:
        list[str]: a new list with "文章: " prepended to every text.
    """
    batch = [texts] if isinstance(texts, str) else texts
    prefixed = []
    for text in batch:
        prefixed.append("文章: " + text)
    return prefixed
# Sentence encoder (cl-nagoya/ruri-large, downloaded from the 🤗 Hub), used as a frozen feature extractor.
bert_model = SentenceTransformer("cl-nagoya/ruri-large")
bert_model.requires_grad_(False)
# Use the `device` selected in the Gemma-loading cell instead of a hard-coded 'cuda',
# so every model lives on the same device and the notebook also runs without a GPU.
bert_model = bert_model.to(device)
# Embedding dimension; this is the input size of the VAE built later.
print(bert_model.get_sentence_embedding_dimension())

1024


In [4]:
from pathlib import Path

class TrainingConfig:
    """Hyper-parameters and output paths for training the VAE.

    Creating an instance also creates ``log_dir`` and ``checkpoint_dir`` on disk.

    Any default attribute can be overridden by keyword argument, e.g.
    ``TrainingConfig(lr=3e-4, checkpoint_dir="./ckpt")``. Overrides are applied
    before the directories are created, so overriding ``log_dir`` or
    ``checkpoint_dir`` changes which directories are made.

    Raises:
        TypeError: if a keyword does not name an existing config field
            (catches typos such as ``lrr=...``).
    """

    def __init__(self, **overrides):
        # Data settings
        self.batch_size = 4
        self.max_seq_len = 256
        self.num_workers = 4

        # Optimization settings
        self.lr = 1e-4
        self.num_epochs = 10
        self.grad_accum_steps = 4
        # KL weight (beta) annealing: initial value, cap, and per-step increment.
        self.beta_init = 0.0
        self.beta_max = 0.4
        self.beta_step = 1e-5
        self.crop_lambda = 0.2

        # Model settings
        self.bert_model_name = "cl-nagoya/ruri-large"
        self.gemma_model_size = "2b-v2"
        self.vae_hidden_dim = 512
        self.vae_latent_dim = 128

        # Generation (sampling) settings
        self.sample_interval = 1000
        self.num_samples = 3
        self.max_gen_length = 100
        self.generation_temp = 0.2
        self.generation_top_p = 0.9
        self.generation_top_k = 20

        self.ckpt_interval = 5000
        # Path settings. The run timestamp is formatted without spaces or ':'
        # (':' is not allowed in Windows paths); microseconds keep runs distinct.
        self.log_dir = f"./logs/{datetime.datetime.now():%Y-%m-%d_%H-%M-%S-%f}/"
        self.checkpoint_dir = "./checkpoints"
        self.dataset_path = "AhmedSSabir/Japanese-wiki-dump-sentence-dataset"

        # Apply keyword overrides (known fields only) before touching the filesystem.
        unknown = sorted(set(overrides) - set(vars(self)))
        if unknown:
            raise TypeError(f"Unknown TrainingConfig field(s): {', '.join(unknown)}")
        self.__dict__.update(overrides)

        # Initialization
        self._setup_directories()

    def _setup_directories(self):
        """Create the log and checkpoint directories (parents included; existing ones are kept)."""
        Path(self.log_dir).mkdir(parents=True, exist_ok=True)
        Path(self.checkpoint_dir).mkdir(parents=True, exist_ok=True)


In [1]:
# Recreate the training configuration, then overwrite its fields with the values
# stored in a saved checkpoint so later cells match that training run.
config = TrainingConfig()
CONFIG_CKPT_PATH = "../checkpoints/checkpoint_epoch_1_step_2025-01-29 22:27:17.220248_5000.pt"
# Only plain config values are read from this file, so load it onto the CPU.
config_ckpt = torch.load(CONFIG_CKPT_PATH, map_location="cpu")

# NOTE(review): this previously read `a['']` (an empty key), which raises KeyError unless the
# checkpoint literally contains a '' entry. "config" is an ASSUMED key name — confirm it
# against the code that writes these checkpoints.
CONFIG_KEY = "config"
if CONFIG_KEY in config_ckpt:
    config.__dict__.update(config_ckpt[CONFIG_KEY])
else:
    print(f"Checkpoint has no {CONFIG_KEY!r} entry; available keys: {list(config_ckpt)}")

NameError: name 'TrainingConfig' is not defined

In [8]:
# VAE sizes. These equal TrainingConfig.vae_hidden_dim / vae_latent_dim and must
# also match any VAE checkpoint loaded later in the notebook.
vae_hidden_size = 512
vae_latent_size = 128
# NOTE(review): copies hidden_size from the config Gemma was built with onto
# gemma_model.config; if both names refer to the same object this is a no-op — confirm it is needed.
gemma_model.config.hidden_size = gemma_model_config.hidden_size
# Presumably VAE(input_dim=Ruri embedding size, output_dim=Gemma hidden size, hidden, latent);
# argument meaning inferred from this call site — confirm against VAEs/VAE.py.
vae_model = VAE(
    bert_model.get_sentence_embedding_dimension(),
    gemma_model_config.hidden_size,
    vae_hidden_size,
    vae_latent_size,
)
# Same device as the other models instead of a hard-coded 'cuda'.
vae_model = vae_model.to(device)

In [None]:
from datasets import load_dataset

# Japanese Wikipedia sentence dataset, cached locally under ./.datasets.
ds = load_dataset(
    "AhmedSSabir/Japanese-wiki-dump-sentence-dataset",
    cache_dir="./.datasets",
)
# Row/column counts per split, then the text of the first training example.
print(ds.shape)
print(ds['train'][0]['text'])
# Example output of the line above:
# テネシー大学、デューク大学、フロリダ大学などからのオファーもある中、彼が選んだのはノートルダム大学であった。

In [None]:
# Load a previously trained VAE checkpoint; its 'model_state' entry is restored in the next cell.
# map_location uses the notebook's `device` so this also works without a GPU.
VAE_CKPT_PATH = "../checkpoints/checkpoint_epoch_1_step_2025-01-25 10:36:03.710352_60000.pt"
ckpt_dict = torch.load(VAE_CKPT_PATH, map_location=device)

In [None]:

# Restore the trained VAE weights (strict: every key must match the model).
vae_model.load_state_dict(ckpt_dict['model_state'], strict=True)

In [None]:
# %% データローダーの設定
from torch.utils.data import DataLoader
from schedulefree import RAdamScheduleFree
import torch

def collate_fn(batch):
    """Turn a DataLoader batch of dataset rows into a plain list of their 'text' values.

    Raw strings (not tensors) are kept so the training loop can hand them to both
    the sentence encoder and Gemma.
    """
    return list(map(lambda row: row['text'], batch))

BATCH_SIZE = 4  # adjust to available VRAM
train_loader = DataLoader(
    ds['train'],
    batch_size=BATCH_SIZE,
    shuffle=True,
    collate_fn=collate_fn,
    num_workers=4,
    pin_memory=True,
)

# %% Optimizer. Schedule-free optimizers must be switched explicitly:
# optimizer.train() before taking steps, optimizer.eval() before evaluating or saving weights.
LEARNING_RATE = 1e-4
optimizer = RAdamScheduleFree(vae_model.parameters(), lr=LEARNING_RATE)

# KL annealing: beta grows linearly by BETA_STEP per batch until it reaches BETA_MAX.
beta = 0.1  # initial value
BETA_MAX = 0.8  # maximum value
BETA_STEP = 1e-5  # increase per training step

# %% Mixed precision. scaler.step() unscales the gradients before the optimizer update.
scaler = torch.amp.GradScaler()

# %% Training loop
NUM_EPOCHS = 10
GRAD_ACCUM_STEPS = 4  # batches whose gradients are accumulated per optimizer update
MAX_SEQ_LEN = 256  # adjust to the dataset
LOG_INTERVAL = 100
num_batches = len(train_loader)

for epoch in range(NUM_EPOCHS):
    vae_model.train()
    optimizer.train()
    optimizer.zero_grad(set_to_none=True)
    total_loss = 0.0
    total_kl = 0.0
    total_recon = 0.0

    for step, batch_texts in enumerate(train_loader):
        # Sentence embeddings from the frozen Ruri encoder (no gradients needed).
        with torch.no_grad():
            prep_texts = prepare_embeddings(batch_texts)
            bert_embeddings = bert_model.encode(
                prep_texts,
                convert_to_tensor=True,
                device=device,
            )

        with torch.autocast(device_type=device.type, dtype=torch.bfloat16):
            vae_output, mu, logvar = vae_model(bert_embeddings)

            # Closed-form KL(N(mu, exp(logvar)) || N(0, I)), summed over latent
            # dimensions and averaged over the batch.
            kl_div = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
            kl_div = kl_div / bert_embeddings.size(0)

            # Reconstruction loss from Gemma's teacher-forcing pass over the original
            # texts (project method; presumably conditioned on vae_output).
            recon_loss = gemma_model.forward_teacher_forcing(
                vae_output,
                batch_texts,
                max_seq_len=MAX_SEQ_LEN,
            )

            loss = recon_loss + beta * kl_div

        # Divide by GRAD_ACCUM_STEPS so the accumulated gradient is the mean over
        # the accumulation window rather than its sum.
        scaler.scale(loss / GRAD_ACCUM_STEPS).backward()

        # One optimizer update every GRAD_ACCUM_STEPS batches; the last, possibly
        # partial, window of the epoch is flushed as well.
        if (step + 1) % GRAD_ACCUM_STEPS == 0 or (step + 1) == num_batches:
            scaler.step(optimizer)
            scaler.update()
            optimizer.zero_grad(set_to_none=True)

        # Running statistics (unscaled loss).
        total_loss += loss.item()
        total_kl += kl_div.item()
        total_recon += recon_loss.item()

        if step % LOG_INTERVAL == 0:
            seen = step + 1
            print(f"Epoch {epoch+1} | Step {step} | "
                  f"Loss: {total_loss / seen:.4f} | KL: {total_kl / seen:.4f} | "
                  f"Recon: {total_recon / seen:.4f} | Beta: {beta:.4f}")

        # KL annealing, applied after logging so the printed beta is the one used above.
        beta = min(beta + BETA_STEP, BETA_MAX)

    # End of epoch
    epoch_loss = total_loss / max(num_batches, 1)
    print(f"Epoch {epoch+1} Completed | Avg Loss: {epoch_loss:.4f}")

    # Save every 2 epochs. The schedule-free optimizer is switched to eval mode first so
    # the saved state dict holds its evaluation weights; the next epoch switches back.
    if (epoch + 1) % 2 == 0:
        vae_model.eval()
        optimizer.eval()
        torch.save(vae_model.state_dict(), f"vae_model_epoch_{epoch+1}.pth")

# Leave the model and optimizer in eval mode for the saving / generation cells that follow.
vae_model.eval()
optimizer.eval()

In [22]:
import datetime
# Save the current VAE weights, tagged with the epoch and a filesystem-safe timestamp
# (no spaces or ':' — ':' is not allowed in Windows file names).
# `epoch` is left over from the training loop, so run this cell after training.
# The schedule-free optimizer must be in eval mode so its evaluation weights are saved;
# calling eval() when it is already in eval mode is harmless.
optimizer.eval()
vae_model.eval()
torch.save(
    vae_model.state_dict(),
    f"vae_model_epoch_{epoch+1}_{datetime.datetime.now():%Y-%m-%d_%H-%M-%S}.pth",
)

In [None]:
# Encode a prompt with Ruri, map it through the VAE, and let Gemma generate text
# starting from the resulting embedding.
# Eval mode in case the VAE behaves differently while training (e.g. latent sampling).
vae_model.eval()
with torch.no_grad():
    # prepare_embeddings applies exactly the "文章: " prefix used during training
    # (a hand-written "文章:" without the space would not match it).
    prompt_embedding = bert_model.encode(
        prepare_embeddings("アメンボ赤いな"),
        convert_to_tensor=True,
        device=device,
    )
    # Match the VAE's parameter dtype; modules built after torch.set_default_dtype use the Gemma dtype.
    vae_dtype = next(vae_model.parameters()).dtype
    b_vec = vae_model(prompt_embedding.to(device=device, dtype=vae_dtype))[0]
    b_vec = b_vec.unsqueeze(0)
gemma_model.generate_with_initial_embedding(b_vec, device=device.type, output_len=256, temperature=0.01)