In [None]:
from transformers import AutoTokenizer, AutoModelForMaskedLM

# Checkpoint identifier handed to `from_pretrained` for both the tokenizer and the model.
# Alternative checkpoint that can be swapped in: "deepvk/USER2-small"
model_name = "jhu-clsp/mmBERT-base"

In [None]:
# Load the tokenizer and the masked-LM model for the checkpoint named by `model_name`.
# Both objects are reused by later cells (collator construction, encoder extraction).
tokenizer_orig = AutoTokenizer.from_pretrained(model_name)
model_orig = AutoModelForMaskedLM.from_pretrained(model_name)

In [None]:
from improved_collator import ImprovedUL2Collator

# Build the UL2 collator around the original tokenizer.
# NOTE(review): the exact semantics of each argument live in improved_collator.py,
# which is not visible here. The three `ul2_denoiser_probs` presumably weight three
# denoiser modes (the other arguments mention R, S and X denoisers) — confirm the order.
collator = ImprovedUL2Collator(
    tokenizer=tokenizer_orig,
    # Length limits for the encoder input and the decoder target.
    max_input_length=512,
    max_target_length=512,
    # Denoiser mixture and per-denoiser corruption settings.
    ul2_denoiser_probs=[0.5, 0.25, 0.25],
    r_denoiser_suffix_ratio=0.25,
    s_denoiser_corrupt_prob=0.15,
    x_denoiser_corrupt_prob=0.5,
    mean_span_length=3,
)

In [None]:
# Resize the model's token embeddings to the length of the collator's tokenizer.
# NOTE(review): presumably the collator registers extra tokens on the tokenizer,
# which is why `collator.tokenizer` is measured here — confirm in improved_collator.py.
model_orig.resize_token_embeddings(len(collator.tokenizer))

In [None]:
# On-disk location of the training dataset. A later cell passes this value to
# `load_from_disk` and rebinds `train_dataset` to the loaded dataset object.
train_dataset = "./final_pretrain_mix"

In [None]:
from datasets import Dataset, DatasetDict, load_dataset, load_from_disk

# `train_dataset` enters this cell as a path and leaves it as the loaded dataset.
# Skip the load when the name has already been rebound, so that re-running the
# cell does not hand a dataset object to `load_from_disk`.
if isinstance(train_dataset, (Dataset, DatasetDict)):
    print("Dataset already loaded; skipping reload.")
else:
    print("Loading datasets...")
    train_dataset = load_from_disk(train_dataset)

In [None]:
# IMPORTANT: check that the tokenizer really defines EOS and the other special
# tokens. The ids are printed here so they can be inspected before being used.
special_token_ids = {
    'bos': tokenizer_orig.bos_token_id,
    'eos': tokenizer_orig.eos_token_id,
    'cls': tokenizer_orig.cls_token_id,
    'sep': tokenizer_orig.sep_token_id,
    'pad': tokenizer_orig.pad_token_id,
}
print(f"Tokenizer special token ids: {special_token_ids}")


def _first_defined(*candidates):
    """Return the first candidate that is not None, or None when all are None.

    Token id 0 is a valid id, so fallbacks must be chosen with an explicit
    `is not None` test rather than truthiness (`a or b`), which skips 0.
    """
    for candidate in candidates:
        if candidate is not None:
            return candidate
    return None


# Clean configuration: keys overlaid on the encoder config when building the
# encoder-decoder config. BOS falls back to CLS and EOS falls back to SEP.
config_keys_to_update = {
    'decoder_start_token_id': _first_defined(
        special_token_ids['bos'], special_token_ids['cls'], 0
    ),
    'pad_token_id': _first_defined(special_token_ids['pad'], 0),
    'bos_token_id': _first_defined(special_token_ids['bos'], special_token_ids['cls']),
    'eos_token_id': _first_defined(special_token_ids['eos'], special_token_ids['sep']),
    'tie_word_embeddings': True,
    'is_encoder_decoder': True,
    'num_decoder_layers': 1,
}

# Surface unresolved ids instead of silently writing None into the config.
unresolved_token_keys = [
    key for key in ('bos_token_id', 'eos_token_id')
    if config_keys_to_update[key] is None
]
if unresolved_token_keys:
    print(f"WARNING: tokenizer defines no token for: {unresolved_token_keys}")

In [None]:
import torch
from transformers import AutoTokenizer
from fixed_modernt5 import ModernBertModel, ModernT5Config, ModernT5ForConditionalGeneration

# Locate the bare encoder inside the masked-LM wrapper. The attribute holding it
# differs between architectures, so the known names are probed in priority order.
for _encoder_attr in ('bert', 'roberta', 'model'):
    if hasattr(model_orig, _encoder_attr):
        encoder_model = getattr(model_orig, _encoder_attr)
        break
else:
    raise ValueError(f"Cannot find encoder in model of type {type(model_orig)}")

# Configure ModernT5: start from the source model's config and overlay the
# encoder-decoder specific keys prepared earlier.
encoder_config_dict = {**model_orig.config.to_dict(), **config_keys_to_update}
config = ModernT5Config(**encoder_config_dict)

# Initialize a fresh ModernT5 and replace its encoder with the extracted one.
model = ModernT5ForConditionalGeneration(config)
model.model.encoder = encoder_model

print("Tying word embeddings between encoder, decoder, and LM head...")
# Fetch the input embeddings of the swapped-in encoder and install them as the
# input embeddings of the whole model.
# NOTE(review): whether this also reaches the decoder and LM head depends on
# `set_input_embeddings` in fixed_modernt5 — confirm there.
encoder_embeddings = model.get_encoder().get_input_embeddings()
model.set_input_embeddings(encoder_embeddings)

In [None]:
# Final steps: move the model to the available device, report its size, and
# write the checkpoint to disk.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
model.to(device)
print(f"Model moved to {device.upper()}")

# Count only the parameters that have `requires_grad` set.
num_params = 0
for _param in model.parameters():
    if _param.requires_grad:
        num_params += _param.numel()
print(f"Number of trainable parameters: {num_params / 1e6:.2f} M")

# Store the composed model and the collator's tokenizer in the same directory.
save_directory = "./modernt5_from_mmBERT-base_e21_d1"
print(f"Saving model and tokenizer to {save_directory}...")
model.save_pretrained(save_directory)
collator.tokenizer.save_pretrained(save_directory)
print("Save complete.")

In [None]:
# --- 5. Load and Test ---
# Reload the checkpoint written above and run two smoke tests on it: a forward
# pass with labels (loss computation) and a short generation.
print(f"\nLoading model from {save_directory}...")
loaded_model = ModernT5ForConditionalGeneration.from_pretrained(save_directory)
loaded_model.to(device)
# Make inference mode explicit so the smoke tests do not depend on the mode the
# loader leaves the module in.
loaded_model.eval()
print("Model loaded successfully from checkpoint.")

# Test with a dummy forward pass
print("Performing a dummy forward pass...")
# Random token ids created directly on the target device: batch of 2, with
# 16 input positions and 10 label positions.
dummy_input_ids = torch.randint(0, config.vocab_size, (2, 16), device=device)
dummy_labels = torch.randint(0, config.vocab_size, (2, 10), device=device)
# Only the loss value is needed, so skip building an autograd graph.
with torch.no_grad():
    dummy_output = loaded_model(input_ids=dummy_input_ids, labels=dummy_labels)
print(f"Dummy forward pass successful. Loss: {dummy_output.loss.item():.4f}")

# Test generation
print("\nPerforming a dummy generation...")
# Use the tokenizer to prepare input
prompt = "Перевод с русского на английский: как дела?"
inputs = collator.tokenizer(prompt, return_tensors="pt").to(device)

# Generate output
generated_ids = loaded_model.generate(**inputs, max_length=50)
decoded_text = collator.tokenizer.decode(generated_ids[0], skip_special_tokens=True)

print(f"Input: '{prompt}'")
print(f"Generated output: '{decoded_text}'")