In [1]:
from collections import OrderedDict
from transformers import AlbertConfig, AlbertModel
import torch
import yaml

In [3]:
# Checkpoint inside the sibling StyleTTS2_string-symbols checkout.
orig_path = (
    '../StyleTTS2_string-symbols/Utils/PLBERT/'
    'step_1000000.t7'
)
# Checkpoint under models/cz-phon-sentences that is loaded later in this notebook.
my_path = (
    'models/cz-phon-sentences/'
    'step_1000000.t7'
)

In [4]:
class CustomAlbert(AlbertModel):
    """ALBERT model whose forward pass returns only the final hidden states."""

    def forward(self, *args, **kwargs):
        """Run ``AlbertModel.forward`` and return just the last hidden state.

        All positional and keyword arguments are passed to the parent unchanged.
        """
        outputs = super().forward(*args, **kwargs)

        # With ``return_dict=False`` the parent returns a plain tuple, which has
        # no ``last_hidden_state`` attribute; the hidden states are its first
        # element.
        if isinstance(outputs, tuple):
            return outputs[0]
        return outputs.last_hidden_state

In [5]:
def load_plbert(model_file, config_file):
    """Build a ``CustomAlbert`` from a YAML config and load encoder weights into it.

    Args:
        model_file: Path to a checkpoint readable by ``torch.load``. It must
            contain a ``'net'`` state dict; a ``'step'`` entry is optional.
        config_file: Path to a YAML file whose ``model_params`` mapping is
            passed to ``AlbertConfig``.

    Returns:
        Tuple ``(bert, step)``. ``step`` is ``None`` when the checkpoint has no
        ``'step'`` entry.

    Raises:
        ValueError: If none of the checkpoint's tensors could be matched to a
            name in the model, i.e. the model would stay randomly initialised.
    """
    # Close the config file deterministically instead of leaking the handle.
    with open(config_file) as config_handle:
        plbert_config = yaml.safe_load(config_handle)

    albert_base_configuration = AlbertConfig(**plbert_config['model_params'])
    bert = CustomAlbert(albert_base_configuration)

    # NOTE(review): torch.load unpickles the file; only use trusted checkpoints.
    checkpoint = torch.load(model_file, map_location='cpu')
    state_dict = checkpoint['net']
    # save_model() in this notebook leaves 'step' out when none is given.
    step = checkpoint.get('step')

    module_prefix = 'module.'
    encoder_prefix = 'encoder.'
    new_state_dict = OrderedDict()
    for k, v in state_dict.items():
        # Strip `module.` only when it is really there, rather than blindly
        # dropping the first seven characters of every key.
        name = k[len(module_prefix):] if k.startswith(module_prefix) else k
        # Keep only the tensors stored under `encoder.`, without that prefix.
        if name.startswith(encoder_prefix):
            new_state_dict[name[len(encoder_prefix):]] = v
    new_state_dict.pop("embeddings.position_ids", None)

    # strict=False tolerates partial mismatches, but a checkpoint that matches
    # nothing at all must not be reported as a successfully loaded model.
    load_result = bert.load_state_dict(new_state_dict, strict=False)
    loaded_names = set(new_state_dict) - set(load_result.unexpected_keys)
    if not loaded_names:
        raise ValueError(
            f"No weights from {model_file!r} matched the model; expected "
            f"checkpoint keys of the form '[module.]encoder.<name>'."
        )
    return bert, step

In [6]:
def save_model(model, path, step=None, key_prefix=''):
    """Write the model's weights, and optionally a step counter, to ``path``.

    The file holds a dict with the weights under ``'net'`` and, when ``step``
    is given, the counter under ``'step'``.

    Args:
        model: Object exposing ``state_dict()``.
        path: Destination handed to ``torch.save``.
        step: Optional step counter to record. Every value other than ``None``
            is stored, including ``0``.
        key_prefix: Optional string prepended to every weight name. The default
            keeps the names exactly as ``state_dict()`` returns them.
            ``load_plbert`` in this notebook keeps only names that begin with
            ``encoder.`` after a leading ``module.`` has been dropped, so pass
            ``'module.encoder.'`` to write names in that layout.
    """
    weights = model.state_dict()
    if key_prefix:
        weights = {key_prefix + name: tensor for name, tensor in weights.items()}

    checkpoint = {'net': weights}
    # Compare against None so that a legitimate step of 0 is not dropped.
    if step is not None:
        checkpoint['step'] = step
    torch.save(checkpoint, path)

In [7]:
def get_model_size(model):
    """Print the total size of the model's parameters and buffers.

    The byte count is divided by 1024**2 and printed with two decimals in a
    Czech message ("Velikost modelu" = "model size"). Nothing is returned.
    """
    def _total_bytes(tensors):
        # Element count times bytes per element, summed over all tensors.
        return sum(t.nelement() * t.element_size() for t in tensors)

    total_bytes = _total_bytes(model.parameters()) + _total_bytes(model.buffers())
    size_in_mb = total_bytes / (1024 ** 2)
    print(f"Velikost modelu: {size_in_mb:.2f} MB")

# Example usage:
# model = your loaded model
# get_model_size(model)

In [9]:
# Load the checkpoint at my_path using the config file from the same directory.
my_model, my_step = load_plbert(my_path, 'models/cz-phon-sentences/config.yml')
# Bare expression so the notebook displays the step stored in the checkpoint.
my_step

999999

In [64]:
# Load the StyleTTS2 checkpoint through the orig_path constant defined near the
# top of the notebook instead of repeating the same literal path here.
orig_model, orig_step = load_plbert(orig_path, '../StyleTTS2_string-symbols/Utils/PLBERT/config.yml')
# Bare expression so the notebook displays the step stored in the checkpoint.
orig_step

1000000

In [65]:
# Print the parameter + buffer size of the model loaded from my_path.
get_model_size(my_model)

Velikost modelu: 24.01 MB


In [66]:
# Print the parameter + buffer size of the StyleTTS2 model for comparison.
get_model_size(orig_model)

Velikost modelu: 24.01 MB


In [10]:
# Write only the model's state dict and a step value to a separate '.reduced' file.
# my_step+1: the loaded step is 999999; presumably the +1 aligns it with the
# 1000000 in the file name — TODO confirm.
# NOTE(review): save_model stores names as returned by model.state_dict(), while
# load_plbert keeps only names under a 7-character prefix followed by `encoder.`;
# presumably load_plbert would find no weights in this file — confirm the consumer.
save_model(my_model, 'models/cz-phon-sentences/step_1000000.reduced.t7', my_step+1)