## 1. Set Up Environment

In [9]:
!pip install transformers torch datasets evaluate


[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m A new release of pip is available: [0m[31;49m24.1.1[0m[39;49m -> [0m[32;49m24.3.1[0m
[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m To update, run: [0m[32;49mpip install --upgrade pip[0m


In [21]:
import torch
import torch.nn as nn
from transformers import (
    BartModel,
    BartTokenizer,
    BertModel,
    BertTokenizer,
    Seq2SeqTrainer,
    Seq2SeqTrainingArguments,
    DataCollatorForSeq2Seq,
)
from datasets import load_dataset

## 2. Load and Data preprocessing

In [22]:
import torch

device = torch.device("mps") if torch.backends.mps.is_available() else torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

Using device: mps


In [23]:
# Step 2: Load Dataset
dataset = load_dataset("cnn_dailymail", "3.0.0")
train_data = dataset["train"]
validation_data = dataset["validation"]

In [56]:
def preprocess_function(examples):
    bart_inputs = bart_tokenizer(examples["article"], max_length=1024, truncation=True, padding="max_length")
    bert_inputs = bert_tokenizer(examples["article"], max_length=512, truncation = True, padding="max_length")
    labels = bart_tokenizer(examples["highlights"], max_length=128, truncation=True, padding="max_length")["input_ids"]

    return {
        "bart_input_ids": bart_inputs["input_ids"],
        "bart_attention_mask": bart_inputs["attention_mask"],
        "bert_input_ids": bert_inputs["input_ids"],
        "bert_attention_mask": bert_inputs["attention_mask"],
        "labels": labels,
    }
tokenized_train_data = train_data.map(preprocess_function, batched=True)
tokenized_validation_data = validation_data.map(preprocess_function, batched=True)


Map:   0%|          | 0/287113 [00:00<?, ? examples/s]

Map:   0%|          | 0/13368 [00:00<?, ? examples/s]

In [57]:
# Step 3: Load Tokenizers
bart_tokenizer = BartTokenizer.from_pretrained("facebook/bart-large")
bert_tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")

## 3. Build the model

In [58]:
# Step 5: Define Model
class DualEncoderModel(nn.Module):
    def __init__(self):
        super(DualEncoderModel, self).__init__()
        self.bart_encoder = BartModel.from_pretrained("facebook/bart-large").encoder.to(device)
        self.bert_encoder = BertModel.from_pretrained("bert-base-uncased").to(device)
        self.fusion_layer = nn.Linear(1024 + 768, 1024).to(device)  # Combine BART and BERT dimensions
        self.bart_decoder = BartModel.from_pretrained("facebook/bart-large").decoder.to(device)

    def forward(self, bart_input_ids, bart_attetnion_mask, bert_input_ids, bert_attention_mask, labels):
        # Semantic Encoding (BART)
        semantic_embeddings = self.bart_encoder(input_ids=bart_input_ids, attention_mask=bart_attention_mask).last_hidden_state
        
        # Syntactic Encoding (BERT)
        syntactic_embeddings = self.bert_encoder(input_ids=bert_input_ids, attention_mask=bert_attention_mask).last_hidden_state
        # Align sequence lengths
        min_seq_len = min(semantic_embeddings.size(1), syntactic_embeddings.size(1))
        semantic_embeddings = semantic_embeddings[:, :min_seq_len, :]  # Truncate BART output
        syntactic_embeddings = syntactic_embeddings[:, :min_seq_len, :]  # Truncate BERT output

        # Fusion of Semantic and Syntactic Embeddings
        fused_embeddings = self.fusion_layer(torch.cat((semantic_embeddings, syntactic_embeddings), dim=-1))
        
        # Decoding with Fused Embeddings
        decoder_outputs = self.bart_decoder(encoder_outputs=(fused_embeddings,), input_ids=labels)
        return decoder_outputs

# Instantiate the model
dual_encoder_model = DualEncoderModel().to(device)

In [59]:
class CustomDataCollator(DataCollatorForSeq2Seq):
    def __call__(self, batch):
        batch = super().__call__(batch)
        # Ensure all tensors are on the correct device
        return {
            "bart_input_ids": torch.tensor([example["bart_input_ids"] for example in batch], device=device),
            "bart_attention_mask": torch.tensor([example["bart_attention_mask"] for example in batch], device=device),
            "bert_input_ids": torch.tensor([example["bert_input_ids"] for example in batch], device=device),
            "bert_attention_mask": torch.tensor([example["bert_attention_mask"] for example in batch], device=device),
            "labels": torch.tensor([example["labels"] for example in batch], device=device),
            }
data_collator = CustomDataCollator(tokenizer=bart_tokenizer, model=dual_encoder_model)

## 4. Complete Model

In [60]:
# Step 7: Training Arguments
training_args = Seq2SeqTrainingArguments(
    output_dir="./results",
    evaluation_strategy="epoch",
    learning_rate=5e-5,
    per_device_train_batch_size=2,  # Reduced batch size for MPS compatibility
    per_device_eval_batch_size=2,
    weight_decay=0.01,
    save_total_limit=2,
    num_train_epochs=3,
    logging_dir="./logs",
    predict_with_generate=True,
    no_cuda=True,
    use_cpu = True
)

In [61]:
# Step 8: Trainer
trainer = Seq2SeqTrainer(
    model=dual_encoder_model,
    args=training_args,
    train_dataset=tokenized_train_data,
    eval_dataset=tokenized_validation_data,
    tokenizer=bart_tokenizer,
    data_collator=data_collator,
)

# Step 9: Train the Model
trainer.train()

ValueError: You should supply an encoding or a list of encodings to this method that includes input_ids, but you provided ['bart_input_ids', 'bert_input_ids', 'bert_attention_mask']