In [24]:
import copy
import math
from typing import List, Optional, Tuple, Union

import numpy as np
import torch
import torch.utils.checkpoint
from torch import nn
from torch.nn import CrossEntropyLoss

In [28]:
from transformers.utils import (
    add_end_docstrings,
    add_start_docstrings,
    add_start_docstrings_to_model_forward,
    logging,
    replace_return_docstrings,
)

In [22]:
from transformers.modeling_outputs import (
    BaseModelOutput,
    BaseModelOutputWithPastAndCrossAttentions,
    CausalLMOutputWithCrossAttentions,
    Seq2SeqLMOutput,
    Seq2SeqModelOutput,
)

In [5]:
from decoder import STLPreTrainedModel
from utils2 import STLConfig
from encoder import STLEncoder
from decoder import STLDecoder

In [30]:
_CONFIG_FOR_DOC = "STLConfig"

In [23]:
class MarianModel(STLPreTrainedModel):
    """Decoder-only model built around :class:`STLDecoder`.

    The ``forward`` signature keeps Marian's seq2seq argument names, but no encoder is
    built: the ``decoder_*`` arguments are what reach the decoder. When neither
    ``decoder_input_ids`` nor ``decoder_inputs_embeds`` is supplied, the plain
    ``input_ids`` / ``inputs_embeds`` / ``attention_mask`` / ``head_mask`` are forwarded
    to the decoder instead, so the model can be called like a regular decoder-only LM.
    """

    _tied_weights_keys = ["decoder.embed_tokens.weight"]

    def __init__(self, config: STLConfig):
        super().__init__(config)

        padding_idx, vocab_size = config.pad_token_id, config.vocab_size

        # Shared embedding table: kept only when `config.share_encoder_decoder_embeddings`
        # is set, otherwise `self.shared` is None.
        # NOTE(review): this table is not passed to STLDecoder(config), so the decoder
        # presumably owns a separate `embed_tokens` — confirm whether they should be tied.
        self.shared = nn.Embedding(vocab_size, config.d_model, padding_idx)
        if not self.config.share_encoder_decoder_embeddings:
            self.shared = None

        # Decoder-only: no encoder is built.
        self.decoder = STLDecoder(config)

        # Initialize weights and apply final processing.
        self.post_init()

    def get_decoder(self):
        """Return the wrapped :class:`STLDecoder`."""
        return self.decoder

    def get_decoder_input_embeddings(self) -> nn.Module:
        """Return the decoder's token-embedding module (``self.decoder.embed_tokens``)."""
        return self.decoder.embed_tokens

    def set_decoder_input_embeddings(self, value: nn.Module) -> None:
        """Replace the decoder's token-embedding module with ``value``."""
        self.decoder.embed_tokens = value

    def resize_decoder_token_embeddings(self, new_num_tokens: Optional[int]) -> nn.Embedding:
        """Resize the decoder's token embeddings to ``new_num_tokens`` rows.

        Args:
            new_num_tokens: Target vocabulary size. ``None`` leaves the embeddings
                (and the config) untouched.

        Returns:
            The decoder's (possibly resized) input embedding module.
        """
        if new_num_tokens is None:
            return self.get_decoder_input_embeddings()

        old_embeddings = self.get_decoder_input_embeddings()
        new_embeddings = self._get_resized_embeddings(old_embeddings, new_num_tokens)
        self.set_decoder_input_embeddings(new_embeddings)

        # Keep the config in sync with the new decoder vocabulary size.
        self.config.decoder_vocab_size = new_num_tokens

        # Re-tie weights, if the model ties any.
        self.tie_weights()

        return self.get_decoder_input_embeddings()

    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        decoder_input_ids: Optional[torch.LongTensor] = None,
        decoder_attention_mask: Optional[torch.Tensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        decoder_head_mask: Optional[torch.Tensor] = None,
        past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        decoder_inputs_embeds: Optional[torch.FloatTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, BaseModelOutputWithPastAndCrossAttentions]:
        r"""
        Forward pass of the decoder-only model.

        There is no encoder, so the decoder only sees the token sequence it is given
        (unlike translation or other seq2seq tasks, where an encoder supplies
        representations of a second sequence).

        Input selection:
            If ``decoder_input_ids`` or ``decoder_inputs_embeds`` is given, the ``decoder_*``
            arguments are used and ``input_ids`` / ``inputs_embeds`` / ``attention_mask`` /
            ``head_mask`` are ignored. Otherwise those plain arguments are forwarded to the
            decoder, with any explicitly passed ``decoder_attention_mask`` /
            ``decoder_head_mask`` still taking precedence.

        Returns:
            Whatever ``self.decoder`` returns, unchanged — presumably a
            ``BaseModelOutputWithPastAndCrossAttentions`` when ``return_dict`` is True and a
            tuple otherwise (TODO confirm against STLDecoder).
        """
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        use_cache = use_cache if use_cache is not None else self.config.use_cache
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # Decoder-only usage: fall back to the plain inputs when no decoder inputs are given.
        # Without this the decoder would receive neither ids nor embeddings.
        if decoder_input_ids is None and decoder_inputs_embeds is None:
            decoder_input_ids = input_ids
            decoder_inputs_embeds = inputs_embeds
            if decoder_attention_mask is None:
                decoder_attention_mask = attention_mask
            if decoder_head_mask is None:
                decoder_head_mask = head_mask

        # The decoder output is returned as-is for both return_dict modes.
        return self.decoder(
            input_ids=decoder_input_ids,
            attention_mask=decoder_attention_mask,
            head_mask=decoder_head_mask,
            past_key_values=past_key_values,
            inputs_embeds=decoder_inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )


In [24]:
# Fix the RNG so the randomly initialised weights (and the random inputs in the
# following cells) are reproducible on Restart & Run All.
SEED = 0
torch.manual_seed(SEED)

config = STLConfig()
# Decoder-only model instance exercised by the smoke test in the next cell.
test = MarianModel(config)

In [28]:
# Smoke test: push random token ids through the decoder-only model.
batch_size, seq_len = 1, 10
input_ids = torch.randint(0, config.vocab_size, (batch_size, seq_len))
decoder_input_ids = torch.randint(0, config.vocab_size, (batch_size, seq_len))

# Attention mask: every position is a valid (non-padding) token.
attention_mask = torch.ones(batch_size, seq_len)

# Inference only — no autograd graph is needed for a shape check.
with torch.no_grad():
    outputs = test(
        input_ids=input_ids,
        attention_mask=attention_mask,
        decoder_input_ids=decoder_input_ids,
        decoder_attention_mask=attention_mask,
        return_dict=True,  # return a ModelOutput instead of a tuple
    )

# The first output element is presumably the decoder's last hidden state,
# one vector per input position.
assert outputs[0].shape[:2] == (batch_size, seq_len)
print("Output del modello ottenuto")

Output del modello ottenuto


In [54]:
from transformers.generation import GenerationMixin

class STLForCausalLM(STLPreTrainedModel, GenerationMixin):
    """Decoder-only language model: an :class:`STLDecoder` followed by a linear LM head.

    The constructor works on a deep copy of ``config`` with ``is_decoder=True`` and
    ``is_encoder_decoder=False``, so the caller's config object is not mutated.
    """

    _tied_weights_keys = ["lm_head.weight"]

    def __init__(self, config):
        # Copy so the flags set below do not leak into the caller's config.
        config = copy.deepcopy(config)
        config.is_decoder = True
        config.is_encoder_decoder = False

        super().__init__(config)
        self.model = STLDecoder(config)

        # Projects decoder hidden states to vocabulary logits (no bias).
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)

        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self):
        """Return the decoder's token-embedding module."""
        return self.model.embed_tokens

    def set_input_embeddings(self, value):
        """Replace the decoder's token-embedding module."""
        self.model.embed_tokens = value

    def get_output_embeddings(self):
        """Return the LM head (hidden state -> vocabulary logits)."""
        return self.lm_head

    def set_output_embeddings(self, new_embeddings):
        """Replace the LM head."""
        self.lm_head = new_embeddings

    def set_decoder(self, decoder):
        """Replace the wrapped decoder."""
        self.model = decoder

    def get_decoder(self):
        """Return the wrapped decoder."""
        return self.model

    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        encoder_hidden_states: Optional[torch.FloatTensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        cross_attn_head_mask: Optional[torch.Tensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, CausalLMOutputWithCrossAttentions]:
        """Run the decoder, project to vocabulary logits and optionally compute a loss.

        Args:
            labels: Target token ids, compared position-by-position with ``logits``
                using ``CrossEntropyLoss`` (so ``-100`` entries are ignored).
                NOTE(review): no shift is applied here, so callers must pass targets
                already aligned as "next token" for each position (e.g.
                ``input_ids[:, 1:]`` against logits of ``input_ids[:, :-1]``).
            Other arguments are forwarded to ``self.model`` unchanged.

        Returns:
            With ``return_dict`` true, a ``CausalLMOutputWithCrossAttentions``.
            Otherwise a tuple ``(loss, logits, *decoder_outputs[1:])`` when ``labels``
            is given, or ``(logits, *decoder_outputs[1:])`` when it is not.
        """
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
        outputs = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            encoder_hidden_states=encoder_hidden_states,
            head_mask=head_mask,
            cross_attn_head_mask=cross_attn_head_mask,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        logits = self.lm_head(outputs[0])

        loss = None
        if labels is not None:
            labels = labels.to(logits.device)
            loss_fct = CrossEntropyLoss()
            # reshape (not view): labels are often a non-contiguous slice of input_ids.
            # Flatten with the head's real output width so the loss cannot drift from
            # config.vocab_size after the head is swapped.
            loss = loss_fct(logits.reshape(-1, logits.size(-1)), labels.reshape(-1))

        if not return_dict:
            output = (logits,) + outputs[1:]
            return (loss,) + output if loss is not None else output

        return CausalLMOutputWithCrossAttentions(
            loss=loss,
            logits=logits,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
            cross_attentions=outputs.cross_attentions,
        )

    @staticmethod
    def _reorder_cache(past_key_values, beam_idx):
        """Reorder every cached tensor along dimension 0 according to ``beam_idx``.

        Used during beam search; dimension 0 is presumably the batch*beam axis.
        """
        return tuple(
            tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past)
            for layer_past in past_key_values
        )

In [33]:
config = STLConfig()
# Causal-LM model defined in the previous cell; the next cell runs it.
test = STLForCausalLM(config)

In [49]:
batch_size = 10
seq_length = 20
input_ids = torch.randint(0, config.vocab_size, (batch_size, seq_length))
labels = torch.randint(0, config.vocab_size, (batch_size, seq_length))

# Run the model. With return_dict=False and labels given, the result is (loss, logits, ...).
outputs = test(input_ids=input_ids, labels=labels, return_dict=False)
loss, logits = outputs[0], outputs[1]

# Sanity checks: scalar loss, one vocabulary distribution per input position.
assert loss.dim() == 0
assert logits.shape == (batch_size, seq_length, config.vocab_size)

In [51]:
from transformers import AutoConfig, AutoModel

In [53]:
# Make STLConfig discoverable through AutoConfig under the "STLdec" model type.
# NOTE(review): transformers rejects the call unless STLConfig.model_type == "STLdec";
# confirm the attribute in utils2. Re-running this cell raises because the type is already registered.
AutoConfig.register(model_type="STLdec", config=STLConfig)

In [55]:
from transformers import AutoModelForCausalLM

# STLForCausalLM carries an LM head, so its canonical auto class is AutoModelForCausalLM.
# The AutoModel registration is kept for code that already loads the model via AutoModel.
AutoModel.register(STLConfig, STLForCausalLM)
AutoModelForCausalLM.register(STLConfig, STLForCausalLM)