In [78]:
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
import torch

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [79]:

tokenizer = AutoTokenizer.from_pretrained("facebook/nllb-200-distilled-600M")
model = AutoModelForSeq2SeqLM.from_pretrained("facebook/nllb-200-distilled-600M")

model

M2M100ForConditionalGeneration(
  (model): M2M100Model(
    (shared): M2M100ScaledWordEmbedding(256206, 1024, padding_idx=1)
    (encoder): M2M100Encoder(
      (embed_tokens): M2M100ScaledWordEmbedding(256206, 1024, padding_idx=1)
      (embed_positions): M2M100SinusoidalPositionalEmbedding()
      (layers): ModuleList(
        (0-11): 12 x M2M100EncoderLayer(
          (self_attn): M2M100SdpaAttention(
            (k_proj): Linear(in_features=1024, out_features=1024, bias=True)
            (v_proj): Linear(in_features=1024, out_features=1024, bias=True)
            (q_proj): Linear(in_features=1024, out_features=1024, bias=True)
            (out_proj): Linear(in_features=1024, out_features=1024, bias=True)
          )
          (self_attn_layer_norm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
          (activation_fn): ReLU()
          (fc1): Linear(in_features=1024, out_features=4096, bias=True)
          (fc2): Linear(in_features=4096, out_features=1024, bias=True)
   

In [80]:
encoder = model.model.encoder.to(device)
decoder = model.model.decoder.to(device)

## Explore the encoder

In [81]:
encoder

M2M100Encoder(
  (embed_tokens): M2M100ScaledWordEmbedding(256206, 1024, padding_idx=1)
  (embed_positions): M2M100SinusoidalPositionalEmbedding()
  (layers): ModuleList(
    (0-11): 12 x M2M100EncoderLayer(
      (self_attn): M2M100SdpaAttention(
        (k_proj): Linear(in_features=1024, out_features=1024, bias=True)
        (v_proj): Linear(in_features=1024, out_features=1024, bias=True)
        (q_proj): Linear(in_features=1024, out_features=1024, bias=True)
        (out_proj): Linear(in_features=1024, out_features=1024, bias=True)
      )
      (self_attn_layer_norm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
      (activation_fn): ReLU()
      (fc1): Linear(in_features=1024, out_features=4096, bias=True)
      (fc2): Linear(in_features=4096, out_features=1024, bias=True)
      (final_layer_norm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
    )
  )
  (layer_norm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
)

## Explore the decoder

In [82]:
decoder

M2M100Decoder(
  (embed_tokens): M2M100ScaledWordEmbedding(256206, 1024, padding_idx=1)
  (embed_positions): M2M100SinusoidalPositionalEmbedding()
  (layers): ModuleList(
    (0-11): 12 x M2M100DecoderLayer(
      (self_attn): M2M100SdpaAttention(
        (k_proj): Linear(in_features=1024, out_features=1024, bias=True)
        (v_proj): Linear(in_features=1024, out_features=1024, bias=True)
        (q_proj): Linear(in_features=1024, out_features=1024, bias=True)
        (out_proj): Linear(in_features=1024, out_features=1024, bias=True)
      )
      (activation_fn): ReLU()
      (self_attn_layer_norm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
      (encoder_attn): M2M100SdpaAttention(
        (k_proj): Linear(in_features=1024, out_features=1024, bias=True)
        (v_proj): Linear(in_features=1024, out_features=1024, bias=True)
        (q_proj): Linear(in_features=1024, out_features=1024, bias=True)
        (out_proj): Linear(in_features=1024, out_features=1024, bias=True

In [83]:
sentences = "I am playing video game now"

tokenized_sentences = tokenizer(sentences, padding=True, truncation=True, return_tensors="pt").to(device)

tokenized_sentences

{'input_ids': tensor([[256047,    117,    259, 106186,   7826,  10095,  10643,      2]],
       device='cuda:0'), 'attention_mask': tensor([[1, 1, 1, 1, 1, 1, 1, 1]], device='cuda:0')}

In [84]:
encoded_output = encoder(input_ids=tokenized_sentences.input_ids,
    attention_mask=tokenized_sentences.attention_mask,  # if you have an attention mask
    return_dict=True
)

encoded_output.last_hidden_state.shape

torch.Size([1, 8, 1024])

In [94]:
decoder_input_ids = torch.tensor([[2, tokenizer.convert_tokens_to_ids("zho_Hant")]], dtype=torch.long).to(device)

# 3. Autoregressive generation
max_length = 30
generated_tokens = decoder_input_ids

encoder_last_hidden_state = encoded_output.last_hidden_state[:, -1, :]
encoder_last_hidden_state = torch.unsqueeze(encoder_last_hidden_state, dim=0)
print(encoder_last_hidden_state.shape)

for _ in range(max_length - 1):
    # Get decoder outputs for current tokens
    decoded_outputs = model.model.decoder(
        input_ids=generated_tokens,
        encoder_hidden_states=encoded_output.last_hidden_state,
        encoder_attention_mask=tokenized_sentences.attention_mask,
        return_dict=True
    )
    
    # Project to vocabulary
    lm_logits = model.lm_head(decoded_outputs.last_hidden_state)
    
    # Get next token
    next_token = torch.argmax(lm_logits[:, -1, :], dim=-1, keepdim=True)
    
    # Append to generated tokens
    generated_tokens = torch.cat([generated_tokens, next_token], dim=-1)
    
    # Break if EOS token is generated
    if next_token.item() == tokenizer.eos_token_id:
        break

generated_tokens

torch.Size([1, 1, 1024])


tensor([[     2, 256201,   4071,  11337, 117970, 252843, 250219, 250680, 251690,
         253914,      2]], device='cuda:0')

In [86]:
lm_logits = model.lm_head(decoded_output.last_hidden_state)

lm_logits.shape

torch.Size([2, 1, 256206])

In [95]:
tokenizer.batch_decode(generated_tokens, skip_special_tokens=True)[0]

'我現在正在玩電視遊戲'

In [88]:
result = model.generate(
    **tokenized_sentences, forced_bos_token_id=tokenizer.convert_tokens_to_ids("zho_Hant"), max_length=30
)

result

tensor([[     2, 256201,   4071,  11337, 117970, 252843, 250219, 250680, 251690,
         253914,      2]], device='cuda:0')

In [89]:
tokenizer.batch_decode(result, skip_special_tokens=True)[0]

'我現在正在玩電視遊戲'